6069 lines
272 KiB
Python
6069 lines
272 KiB
Python
"""
|
||
AI 聊天插件
|
||
|
||
支持自定义模型、API 和人设
|
||
支持 Redis 存储对话历史和限流
|
||
"""
|
||
|
||
import asyncio
|
||
import tomllib
|
||
import aiohttp
|
||
import json
|
||
import re
|
||
import time
|
||
import copy
|
||
from contextlib import asynccontextmanager
|
||
from pathlib import Path
|
||
from datetime import datetime
|
||
from loguru import logger
|
||
from utils.plugin_base import PluginBase
|
||
from utils.decorators import on_text_message, on_quote_message, on_image_message, on_emoji_message
|
||
from utils.redis_cache import get_cache
|
||
from utils.image_processor import ImageProcessor, MediaConfig
|
||
from utils.tool_executor import ToolExecutor
|
||
from utils.tool_registry import get_tool_registry
|
||
from utils.member_info_service import get_member_service
|
||
import xml.etree.ElementTree as ET
|
||
import base64
|
||
import uuid
|
||
|
||
# 可选导入代理支持
|
||
try:
|
||
from aiohttp_socks import ProxyConnector
|
||
PROXY_SUPPORT = True
|
||
except ImportError:
|
||
PROXY_SUPPORT = False
|
||
logger.warning("aiohttp_socks 未安装,代理功能将不可用")
|
||
|
||
# 可选导入 Chroma 向量数据库
|
||
try:
|
||
import chromadb
|
||
from chromadb.api.types import EmbeddingFunction, Documents, Embeddings
|
||
CHROMA_SUPPORT = True
|
||
except ImportError:
|
||
CHROMA_SUPPORT = False
|
||
|
||
|
||
if CHROMA_SUPPORT:
    class SiliconFlowEmbedding(EmbeddingFunction):
        """Chroma embedding function backed by an OpenAI-compatible embeddings API
        (SiliconFlow by default).

        Each call POSTs the documents to ``api_url`` and returns one vector per
        input document.

        Note: the HTTP request is synchronous (``httpx.post``). Chroma invokes this
        from ``collection.add`` / ``collection.query``, so async callers should run
        those calls in a worker thread to keep the event loop responsive.
        """

        def __init__(self, api_url: str, api_key: str, model: str, timeout: float = 30):
            """
            Args:
                api_url: Full URL of the embeddings endpoint.
                api_key: Bearer token sent in the ``Authorization`` header.
                model: Embedding model name sent in the request body.
                timeout: Request timeout in seconds (default 30, as before).
            """
            self._api_url = api_url
            self._api_key = api_key
            self._model = model
            self._timeout = timeout

        def __call__(self, input: Documents) -> Embeddings:
            # httpx is imported lazily so the plugin still loads when it is absent
            import httpx as _httpx

            resp = _httpx.post(
                self._api_url,
                headers={
                    "Authorization": f"Bearer {self._api_key}",
                    "Content-Type": "application/json",
                },
                json={"model": self._model, "input": input},
                timeout=self._timeout,
            )
            resp.raise_for_status()
            items = resp.json()["data"]
            # OpenAI-compatible responses tag each vector with the position
            # ("index") of its input. Sort on it so the vectors line up with
            # `input` even if the server returns them out of order. Items
            # without an index keep their original position.
            ordered = sorted(enumerate(items), key=lambda pair: pair[1].get("index", pair[0]))
            return [item["embedding"] for _, item in ordered]
|
||
|
||
|
||
class AIChat(PluginBase):
|
||
"""AI 聊天插件"""
|
||
|
||
# 插件元数据
|
||
description = "AI 聊天插件,支持自定义模型和人设"
|
||
author = "ShiHao"
|
||
version = "1.0.0"
|
||
|
||
def __init__(self):
    """Create empty plugin state; the real initialisation happens in ``async_init``."""
    super().__init__()

    # --- configuration / persona (populated by async_init) ---
    self.config = None
    self.system_prompt = ""

    # --- conversation memory & history ---
    self.memory = {}  # per-chat memory: {chat_id: [messages]}
    self.history_dir = None  # history directory (set by async_init when history is enabled)
    self.history_locks = {}  # one lock per chat
    self._reply_locks = {}  # one reply lock per chat, used to serialise replies

    # --- reply / tool behaviour flags (overridden from config in async_init) ---
    self._serial_reply = False
    self._tool_async = True
    self._tool_followup_ai_reply = True
    self._tool_rule_prompt_enabled = True

    # --- image-description background work ---
    self.image_desc_queue = asyncio.Queue()  # pending image-description jobs
    self.image_desc_workers = []  # worker tasks consuming the queue

    # --- storage ---
    self.persistent_memory_db = None  # path of the persistent-memory database
    self.store = None  # ContextStore instance (unified storage)

    # --- chatroom member display-name cache ---
    self._chatroom_member_cache = {}  # {chatroom_id: (ts, {wxid: display_name})}
    self._chatroom_member_cache_locks = {}  # {chatroom_id: asyncio.Lock}
    self._chatroom_member_cache_ttl_seconds = 3600  # 1 hour, to limit protocol API calls

    self._image_processor = None  # ImageProcessor instance

    # --- vector long-term memory (Chroma) ---
    self._vector_memory_enabled = False
    self._chroma_collection = None
    self._vector_watermarks = {}  # {chatroom_id: str} timestamp of the last summarised message
    self._vector_tasks = {}  # {chatroom_id: asyncio.Task} background summary tasks
    self._watermark_file = None
|
||
|
||
async def async_init(self):
    """Initialise the plugin asynchronously.

    Steps, in order:
      1. load ``config.toml`` located next to this file;
      2. derive the serial-reply / tool flags;
      3. load the persona prompt from ``prompts/``;
      4. log proxy settings;
      5. prepare the history directory;
      6. start two image-description worker tasks;
      7. create the ContextStore and its persistent-memory DB under ``data/``;
      8. optionally enable Chroma vector memory;
      9. create the ImageProcessor.

    Missing optional config keys fall back to defaults instead of raising.
    """
    plugin_dir = Path(__file__).parent

    # Read configuration
    with open(plugin_dir / "config.toml", "rb") as f:
        self.config = tomllib.load(f)

    behavior_config = self.config.get("behavior", {})
    self._serial_reply = bool(behavior_config.get("serial_reply", False))
    tools_config = self.config.get("tools", {})
    self._tool_async = bool(tools_config.get("async_execute", True))
    self._tool_followup_ai_reply = bool(tools_config.get("followup_ai_reply", True))
    self._tool_rule_prompt_enabled = bool(tools_config.get("rule_prompt_enabled", True))
    if self._serial_reply:
        # Serial replies force inline tool execution (presumably so tool work
        # stays inside the serialised reply).
        self._tool_async = False
    logger.info(
        f"AIChat 串行回复: {self._serial_reply}, 工具异步执行: {self._tool_async}, "
        f"工具后AI总结: {self._tool_followup_ai_reply}, 工具规则注入: {self._tool_rule_prompt_enabled}"
    )

    # Persona prompt. A missing config key is treated like a missing file:
    # use the default persona instead of aborting the plugin load.
    prompt_file = (self.config.get("prompt") or {}).get("system_prompt_file") or ""
    prompt_path = plugin_dir / "prompts" / prompt_file if prompt_file else None
    if prompt_path is not None and prompt_path.is_file():
        with open(prompt_path, "r", encoding="utf-8") as f:
            self.system_prompt = f.read().strip()
        logger.success(f"已加载人设: {prompt_file}")
    else:
        logger.warning(f"人设文件不存在: {prompt_file},使用默认人设")
        self.system_prompt = "你是一个友好的 AI 助手。"

    # Proxy configuration (only logged here)
    proxy_config = self.config.get("proxy", {})
    if proxy_config.get("enabled", False):
        proxy_type = proxy_config.get("type", "socks5")
        proxy_host = proxy_config.get("host", "127.0.0.1")
        proxy_port = proxy_config.get("port", 7890)
        logger.info(f"AI 聊天插件已启用代理: {proxy_type}://{proxy_host}:{proxy_port}")

    # History directory
    history_config = self.config.get("history", {})
    if history_config.get("enabled", True):
        history_dir_name = history_config.get("history_dir", "history")
        self.history_dir = plugin_dir / history_dir_name
        self.history_dir.mkdir(parents=True, exist_ok=True)
        logger.info(f"历史记录目录: {self.history_dir}")

    # Image-description workers (concurrency 2)
    for _ in range(2):
        worker = asyncio.create_task(self._image_desc_worker())
        self.image_desc_workers.append(worker)
    logger.info("已启动 2 个图片描述工作协程")

    # Persistent-memory DB and unified storage
    from utils.context_store import ContextStore

    db_dir = plugin_dir / "data"
    db_dir.mkdir(exist_ok=True)
    self.persistent_memory_db = db_dir / "persistent_memory.db"
    self.store = ContextStore(
        self.config,
        self.history_dir,
        self.memory,
        self.history_locks,
        self.persistent_memory_db,
    )
    self.store.init_persistent_memory_db()

    # Vector long-term memory (Chroma)
    vm_config = self.config.get("vector_memory", {})
    if vm_config.get("enabled", False) and CHROMA_SUPPORT:
        try:
            chroma_path = plugin_dir / vm_config.get("chroma_db_path", "data/chroma_db")
            chroma_path.mkdir(parents=True, exist_ok=True)
            embedding_fn = SiliconFlowEmbedding(
                api_url=vm_config.get("embedding_url", ""),
                api_key=vm_config.get("embedding_api_key", ""),
                model=vm_config.get("embedding_model", "BAAI/bge-m3"),
            )
            chroma_client = chromadb.PersistentClient(path=str(chroma_path))
            self._chroma_collection = chroma_client.get_or_create_collection(
                name="group_chat_summaries",
                embedding_function=embedding_fn,
                metadata={"hnsw:space": "cosine"},
            )
            self._watermark_file = plugin_dir / "data" / "vector_watermarks.json"
            self._load_watermarks()
            self._vector_memory_enabled = True
            logger.success(f"向量记忆已启用,Chroma 路径: {chroma_path}")
        except Exception as e:
            logger.error(f"向量记忆初始化失败: {e}")
            self._vector_memory_enabled = False
            # Don't keep a half-initialised collection around
            self._chroma_collection = None
    elif vm_config.get("enabled", False) and not CHROMA_SUPPORT:
        logger.warning("向量记忆已启用但 chromadb 未安装,请 pip install chromadb")

    # ImageProcessor (image / emoji / video handling)
    temp_dir = plugin_dir / "temp"
    temp_dir.mkdir(exist_ok=True)
    media_config = MediaConfig.from_dict(self.config)
    self._image_processor = ImageProcessor(media_config, temp_dir)
    logger.debug("ImageProcessor 已初始化")

    logger.info(f"AI 聊天插件已加载,模型: {self.config.get('api', {}).get('model', '')}")
|
||
|
||
async def on_disable(self):
    """Stop background work when the plugin is disabled.

    Cancels and awaits the image-description workers, then drains the queue,
    calling ``task_done`` so that any ``join()`` waiter is released, and
    replaces it. Next it cancels and awaits pending vector-summary tasks, and
    finally persists the watermarks when vector memory is enabled.
    """
    await super().on_disable()

    # Cancel the image-description workers so a reload does not stack them
    workers = self.image_desc_workers
    if workers:
        for w in workers:
            w.cancel()
        await asyncio.gather(*workers, return_exceptions=True)
        workers.clear()

    # Drain any queued jobs, then start over with a fresh queue
    queue = self.image_desc_queue
    try:
        while queue and not queue.empty():
            queue.get_nowait()
            queue.task_done()
    except Exception:
        pass
    self.image_desc_queue = asyncio.Queue()

    logger.info("AIChat 已清理后台图片描述任务")

    # Cancel vector-summary tasks and persist the watermarks
    pending = self._vector_tasks
    if pending:
        for t in pending.values():
            if not t.done():
                t.cancel()
        await asyncio.gather(*pending.values(), return_exceptions=True)
        pending.clear()
    if self._vector_memory_enabled:
        self._save_watermarks()
    logger.info("AIChat 已清理向量摘要后台任务")
|
||
|
||
def _get_reply_lock(self, chat_id: str) -> asyncio.Lock:
|
||
lock = self._reply_locks.get(chat_id)
|
||
if lock is None:
|
||
lock = asyncio.Lock()
|
||
self._reply_locks[chat_id] = lock
|
||
return lock
|
||
|
||
@asynccontextmanager
|
||
async def _reply_lock_context(self, chat_id: str):
|
||
if not self._serial_reply or not chat_id:
|
||
yield
|
||
return
|
||
lock = self._get_reply_lock(chat_id)
|
||
if lock.locked():
|
||
logger.debug(f"AI 回复排队中: chat_id={chat_id}")
|
||
async with lock:
|
||
yield
|
||
|
||
def _add_persistent_memory(self, chat_id: str, chat_type: str, user_wxid: str,
|
||
user_nickname: str, content: str) -> int:
|
||
"""添加持久记忆,返回记忆ID(委托 ContextStore)"""
|
||
if not self.store:
|
||
return -1
|
||
return self.store.add_persistent_memory(chat_id, chat_type, user_wxid, user_nickname, content)
|
||
|
||
def _get_persistent_memories(self, chat_id: str) -> list:
|
||
"""获取指定会话的所有持久记忆(委托 ContextStore)"""
|
||
if not self.store:
|
||
return []
|
||
return self.store.get_persistent_memories(chat_id)
|
||
|
||
def _delete_persistent_memory(self, chat_id: str, memory_id: int) -> bool:
|
||
"""删除指定的持久记忆(委托 ContextStore)"""
|
||
if not self.store:
|
||
return False
|
||
return self.store.delete_persistent_memory(chat_id, memory_id)
|
||
|
||
def _clear_persistent_memories(self, chat_id: str) -> int:
|
||
"""清空指定会话的所有持久记忆(委托 ContextStore)"""
|
||
if not self.store:
|
||
return 0
|
||
return self.store.clear_persistent_memories(chat_id)
|
||
|
||
# ==================== 向量长期记忆(Chroma)====================
|
||
|
||
def _load_watermarks(self):
|
||
"""从文件加载向量摘要水位线"""
|
||
if self._watermark_file and self._watermark_file.exists():
|
||
try:
|
||
with open(self._watermark_file, "r", encoding="utf-8") as f:
|
||
self._vector_watermarks = json.load(f)
|
||
except Exception as e:
|
||
logger.warning(f"加载向量水位线失败: {e}")
|
||
self._vector_watermarks = {}
|
||
|
||
def _save_watermarks(self):
|
||
"""保存向量摘要水位线到文件"""
|
||
if not self._watermark_file:
|
||
return
|
||
try:
|
||
self._watermark_file.parent.mkdir(parents=True, exist_ok=True)
|
||
temp = Path(str(self._watermark_file) + ".tmp")
|
||
with open(temp, "w", encoding="utf-8") as f:
|
||
json.dump(self._vector_watermarks, f, ensure_ascii=False, indent=2)
|
||
temp.replace(self._watermark_file)
|
||
except Exception as e:
|
||
logger.warning(f"保存向量水位线失败: {e}")
|
||
|
||
async def _maybe_trigger_summarize(self, chatroom_id: str):
    """Start a background vector summary once enough new messages have accumulated.

    The watermark is the timestamp of the last message already summarised, so
    trimming old history does not move it. When at least
    ``vector_memory.summarize_every`` messages (default 80) are newer than the
    watermark, the oldest batch of that size goes to
    ``_do_summarize_and_store`` as a background task.

    The watermark is advanced *optimistically*, before the summary succeeds.
    At most one summary task runs per chatroom. Errors are logged, not raised.
    """
    if not self._vector_memory_enabled:
        return
    existing = self._vector_tasks.get(chatroom_id)
    if existing and not existing.done():
        return
    try:
        history_chat_id = self._get_group_history_chat_id(chatroom_id)
        history = await self._load_history(history_chat_id) or []
        try:
            every = int(self.config.get("vector_memory", {}).get("summarize_every", 80))
        except (TypeError, ValueError):
            every = 80
        # A non-positive batch size would yield an empty batch (IndexError below)
        every = max(1, every)
        watermark_ts = self._vector_watermarks.get(chatroom_id, "")

        # Select messages newer than the watermark
        if watermark_ts:
            # Compare numerically when both sides parse as timestamps (epoch or
            # ISO). Plain string comparison mis-orders mixed formats and
            # numbers of different lengths. Fall back to the original string
            # comparison otherwise.
            watermark_num = self._parse_history_timestamp(watermark_ts)

            def _is_new(msg) -> bool:
                raw = msg.get("timestamp", "")
                if watermark_num is not None:
                    ts_num = self._parse_history_timestamp(raw)
                    if ts_num is not None:
                        return ts_num > watermark_num
                return str(raw) > str(watermark_ts)

            new_msgs = [m for m in history if _is_new(m)]
        else:
            new_msgs = list(history)

        logger.info(f"[VectorMemory] 检查: chatroom={chatroom_id}, history={len(history)}, new={len(new_msgs)}, watermark_ts={watermark_ts}, every={every}")

        if len(new_msgs) >= every:
            # Summarise the oldest `every` new messages
            batch = new_msgs[:every]
            # The watermark must be a real timestamp. An empty one would make
            # the whole history look "new" again and re-summarise the same
            # batch on every check.
            last_ts = next(
                (str(m.get("timestamp")) for m in reversed(batch) if m.get("timestamp") not in (None, "")),
                "",
            )
            if not last_ts:
                logger.warning(f"[VectorMemory] 本批消息缺少时间戳,跳过摘要: {chatroom_id}")
                return
            # Optimistic watermark update
            self._vector_watermarks[chatroom_id] = last_ts
            self._save_watermarks()
            task = asyncio.create_task(
                self._do_summarize_and_store(chatroom_id, batch, last_ts)
            )
            self._vector_tasks[chatroom_id] = task
    except Exception as e:
        logger.warning(f"[VectorMemory] 触发检查失败: {e}")
|
||
|
||
async def _do_summarize_and_store(self, chatroom_id: str, messages: list, watermark_ts: str):
    """Background task: summarise *messages* with the LLM and store the result in Chroma.

    The document id is ``"<chatroom_id>_<watermark_ts>"``, with ``:`` and
    ``.`` replaced by ``-``. Its metadata records the chatroom, the first and
    last message timestamps, and the watermark. Summaries shorter than 10
    characters are skipped. All errors are logged, not raised.

    NOTE(review): the watermark was already advanced by the caller, so a failed
    batch is not retried automatically.
    """
    try:
        logger.info(f"[VectorMemory] 触发后台摘要: {chatroom_id}, 消息数={len(messages)}")
        text_block = self._format_messages_for_summary(messages)
        summary = await self._call_summary_llm(text_block)
        if not summary or len(summary.strip()) < 10:
            logger.warning("[VectorMemory] 摘要结果过短,跳过本批")
            return

        ts_start = str(messages[0].get("timestamp", "")) if messages else ""
        ts_end = str(messages[-1].get("timestamp", "")) if messages else ""
        safe_ts = str(watermark_ts).replace(":", "-").replace(".", "-")
        doc_id = f"{chatroom_id}_{safe_ts}"

        # collection.add() embeds the document through a synchronous HTTP call
        # (the embedding function); run it in a worker thread so the event loop
        # is not blocked for the duration of that request.
        await asyncio.to_thread(
            self._chroma_collection.add,
            ids=[doc_id],
            documents=[summary],
            metadatas=[{
                "chatroom_id": chatroom_id,
                "ts_start": ts_start,
                "ts_end": ts_end,
                "watermark_ts": watermark_ts,
            }],
        )
        logger.success(f"[VectorMemory] 摘要已存储: {doc_id}, 长度={len(summary)}")
    except Exception as e:
        logger.error(f"[VectorMemory] 摘要存储失败: {e}")
|
||
|
||
def _format_messages_for_summary(self, messages: list) -> str:
|
||
"""将历史消息格式化为文本块供 LLM 摘要"""
|
||
lines = []
|
||
for m in messages:
|
||
nick = m.get("nickname", "未知")
|
||
content = m.get("content", "")
|
||
ts = m.get("timestamp", "")
|
||
if ts:
|
||
try:
|
||
from datetime import datetime
|
||
dt = datetime.fromtimestamp(float(ts))
|
||
ts_str = dt.strftime("%m-%d %H:%M")
|
||
except Exception:
|
||
ts_str = str(ts)[:16]
|
||
else:
|
||
ts_str = ""
|
||
prefix = f"[{ts_str}] " if ts_str else ""
|
||
lines.append(f"{prefix}{nick}: {content}")
|
||
return "\n".join(lines)
|
||
|
||
async def _call_summary_llm(self, text_block: str) -> str:
    """Ask the chat-completions endpoint for a summary of *text_block*.

    Settings:
      - model: ``vector_memory.summary_model``, falling back to ``api.model``
        when that is empty;
      - ``max_tokens``: ``vector_memory.summary_max_tokens`` (default 2048);
      - endpoint, key and total timeout (default 120 s): taken from ``[api]``.

    Returns the stripped reply text, or "" on any failure (logged).
    """
    vm_cfg = self.config.get("vector_memory", {})
    api_cfg = self.config.get("api", {})

    prompt = "".join((
        "请将以下群聊记录总结为一段简洁的摘要,保留关键话题、重要观点、",
        "参与者的核心发言和结论。摘要应便于日后检索,不要遗漏重要信息。\n\n",
        f"--- 群聊记录 ---\n{text_block}\n--- 结束 ---\n\n请输出摘要:",
    ))
    request_body = {
        "model": vm_cfg.get("summary_model", "") or api_cfg.get("model", ""),
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": vm_cfg.get("summary_max_tokens", 2048),
        "temperature": 0.3,
    }
    request_headers = {
        "Authorization": f"Bearer {api_cfg.get('api_key', '')}",
        "Content-Type": "application/json",
    }
    endpoint = api_cfg.get("url", "")
    total_seconds = api_cfg.get("timeout", 120)

    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(
                endpoint,
                json=request_body,
                headers=request_headers,
                timeout=aiohttp.ClientTimeout(total=total_seconds),
            ) as resp:
                resp.raise_for_status()
                reply = await resp.json()
                return reply["choices"][0]["message"]["content"].strip()
    except Exception as e:
        logger.error(f"[VectorMemory] LLM 摘要调用失败: {e}")
        return ""
|
||
|
||
async def _retrieve_vector_memories(self, chatroom_id: str, query_text: str) -> str:
    """Fetch stored summaries relevant to *query_text* for prompt injection.

    Steps:
      1. Query the collection (filtered to this chatroom) for the
         ``retrieval_top_k`` nearest summaries.
      2. Keep those whose distance is ``<= retrieval_min_score``; despite its
         name, this acts as a maximum distance.
      3. Concatenate the kept summaries up to ``max_inject_chars``
         characters. A summary that would overflow the limit is truncated
         when more than 50 characters of room remain, otherwise skipped.

    Returns:
        A block starting with the 【历史记忆】 header, or "" when vector memory
        is disabled, the query is empty, nothing qualifies, or retrieval fails.
    """
    if not self._vector_memory_enabled or not self._chroma_collection:
        return ""
    # Nothing meaningful to embed; this also avoids sending empty or non-text
    # input to the embedding API.
    if not isinstance(query_text, str) or not query_text.strip():
        return ""

    vm_config = self.config.get("vector_memory", {})
    top_k = vm_config.get("retrieval_top_k", 3)
    min_score = vm_config.get("retrieval_min_score", 0.35)
    max_chars = vm_config.get("max_inject_chars", 2000)
    log_candidates = bool(vm_config.get("retrieval_log_candidates", True))
    log_injected = bool(vm_config.get("retrieval_log_injected", True))
    try:
        log_max_chars = int(vm_config.get("retrieval_log_max_chars", 180))
    except (TypeError, ValueError):
        log_max_chars = 180
    log_max_chars = max(60, log_max_chars)

    try:
        logger.info(f"[VectorMemory] 检索: chatroom={chatroom_id}, query={query_text[:50]}")
        # collection.query() embeds the query through a synchronous HTTP call;
        # run it in a worker thread so the event loop stays responsive.
        results = await asyncio.to_thread(
            self._chroma_collection.query,
            query_texts=[query_text],
            where={"chatroom_id": chatroom_id},
            n_results=top_k,
        )
        if not results or not results.get("documents") or not results["documents"][0]:
            logger.info(f"[VectorMemory] 检索无结果: chatroom={chatroom_id}")
            return ""

        docs = results["documents"][0]
        distances = results["distances"][0] if results.get("distances") else [0] * len(docs)
        logger.info(f"[VectorMemory] 检索到 {len(docs)} 条候选, distances={[round(d, 3) for d in distances]}")
        ids = results.get("ids", [])
        ids = ids[0] if ids and isinstance(ids[0], list) else []

        pieces = []
        total_len = 0
        for idx, (doc, dist) in enumerate(zip(docs, distances), start=1):
            doc_id = ids[idx - 1] if idx - 1 < len(ids) else "-"
            raw_doc = doc or ""
            keep = dist <= min_score
            if log_candidates:
                snippet = re.sub(r"\s+", " ", raw_doc).strip()
                if len(snippet) > log_max_chars:
                    snippet = snippet[:log_max_chars] + "..."
                status = "命中" if keep else "过滤"
                logger.info(
                    f"[VectorMemory] 候选#{idx} {status} "
                    f"(dist={dist:.4f}, threshold<={min_score:.4f}, id={doc_id}) "
                    f"内容: {snippet}"
                )

            if not keep:
                continue
            if total_len + len(raw_doc) > max_chars:
                remaining = max_chars - total_len
                if remaining > 50:
                    pieces.append(raw_doc[:remaining] + "...")
                    # Count the truncated piece so the injected-length log is accurate
                    total_len += remaining
                    if log_candidates:
                        logger.info(
                            f"[VectorMemory] 候选#{idx} 达到注入上限,截断后加入 "
                            f"{remaining} 字符 (max_inject_chars={max_chars})"
                        )
                elif log_candidates:
                    logger.info(
                        f"[VectorMemory] 候选#{idx} 达到注入上限,剩余空间 {remaining} 字符,已跳过"
                    )
                break
            pieces.append(raw_doc)
            total_len += len(raw_doc)

        if not pieces:
            logger.info(
                f"[VectorMemory] 候选均未命中阈值或被注入长度限制拦截 "
                f"(threshold<={min_score:.4f}, max_inject_chars={max_chars})"
            )
            return ""

        logger.info(f"[VectorMemory] 检索到 {len(pieces)} 条相关记忆 (chatroom={chatroom_id})")
        if log_injected:
            preview = re.sub(r"\s+", " ", "\n---\n".join(pieces)).strip()
            if len(preview) > log_max_chars:
                preview = preview[:log_max_chars] + "..."
            logger.info(
                f"[VectorMemory] 最终注入预览 ({len(pieces)}条, {total_len}字): {preview}"
            )
        return "\n\n【历史记忆】以下是与当前话题相关的历史摘要:\n" + "\n---\n".join(pieces)
    except Exception as e:
        logger.warning(f"[VectorMemory] 检索失败: {e}")
        return ""
|
||
|
||
def _get_vector_memories_for_display(self, chatroom_id: str) -> list:
|
||
"""获取指定群的所有向量记忆摘要(用于展示)"""
|
||
if not self._vector_memory_enabled or not self._chroma_collection:
|
||
return []
|
||
try:
|
||
results = self._chroma_collection.get(
|
||
where={"chatroom_id": chatroom_id},
|
||
include=["documents", "metadatas"],
|
||
)
|
||
if not results or not results.get("ids"):
|
||
return []
|
||
items = []
|
||
for i, doc_id in enumerate(results["ids"]):
|
||
meta = results["metadatas"][i] if results.get("metadatas") else {}
|
||
doc = results["documents"][i] if results.get("documents") else ""
|
||
items.append({
|
||
"id": doc_id,
|
||
"summary": doc,
|
||
"ts_start": meta.get("ts_start", ""),
|
||
"ts_end": meta.get("ts_end", ""),
|
||
"watermark": meta.get("watermark", 0),
|
||
})
|
||
items.sort(key=lambda x: x.get("watermark", 0))
|
||
return items
|
||
except Exception as e:
|
||
logger.warning(f"[VectorMemory] 获取展示数据失败: {e}")
|
||
return []
|
||
|
||
def _build_vector_memory_html(self, items: list, chatroom_id: str) -> str:
|
||
"""构建向量记忆展示的 HTML"""
|
||
from datetime import datetime as _dt
|
||
|
||
# 构建摘要卡片 HTML
|
||
cards_html = ""
|
||
for idx, item in enumerate(items, 1):
|
||
summary = item["summary"].replace("&", "&").replace("<", "<").replace(">", ">").replace("\n", "<br>")
|
||
# 时间范围
|
||
time_range = ""
|
||
try:
|
||
if item["ts_start"] and item["ts_end"]:
|
||
t1 = _dt.fromtimestamp(float(item["ts_start"]))
|
||
t2 = _dt.fromtimestamp(float(item["ts_end"]))
|
||
time_range = f'{t1.strftime("%m/%d %H:%M")} — {t2.strftime("%m/%d %H:%M")}'
|
||
except Exception:
|
||
pass
|
||
if not time_range:
|
||
time_range = f'片段 #{idx}'
|
||
|
||
cards_html += f'''
|
||
<div class="memory-card">
|
||
<div class="card-header">
|
||
<span class="card-index">#{idx}</span>
|
||
<span class="card-time">{time_range}</span>
|
||
</div>
|
||
<div class="card-body">{summary}</div>
|
||
</div>'''
|
||
|
||
now_str = _dt.now().strftime("%Y-%m-%d %H:%M")
|
||
watermark_ts = self._vector_watermarks.get(chatroom_id, "")
|
||
if watermark_ts:
|
||
try:
|
||
wt = _dt.fromisoformat(str(watermark_ts))
|
||
watermark_display = wt.strftime("%m/%d %H:%M")
|
||
except Exception:
|
||
watermark_display = str(watermark_ts)[:16]
|
||
else:
|
||
watermark_display = "无"
|
||
|
||
html = f'''<!DOCTYPE html>
|
||
<html lang="zh-CN">
|
||
<head>
|
||
<meta charset="UTF-8">
|
||
<style>
|
||
* {{ margin: 0; padding: 0; box-sizing: border-box; }}
|
||
body {{ font-family: "Microsoft YaHei", "PingFang SC", sans-serif; }}
|
||
#card {{
|
||
width: 600px;
|
||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||
padding: 24px;
|
||
min-height: 200px;
|
||
}}
|
||
.title-bar {{
|
||
display: flex; justify-content: space-between; align-items: center;
|
||
margin-bottom: 18px;
|
||
}}
|
||
.title {{
|
||
font-size: 22px; font-weight: bold; color: #fff;
|
||
text-shadow: 0 1px 3px rgba(0,0,0,0.3);
|
||
}}
|
||
.subtitle {{
|
||
font-size: 13px; color: rgba(255,255,255,0.75);
|
||
}}
|
||
.stats {{
|
||
display: flex; gap: 16px; margin-bottom: 18px;
|
||
}}
|
||
.stat-box {{
|
||
flex: 1; background: rgba(255,255,255,0.15);
|
||
backdrop-filter: blur(8px);
|
||
border-radius: 10px; padding: 12px 14px;
|
||
text-align: center; color: #fff;
|
||
}}
|
||
.stat-num {{
|
||
font-size: 26px; font-weight: bold;
|
||
}}
|
||
.stat-label {{
|
||
font-size: 12px; color: rgba(255,255,255,0.8); margin-top: 2px;
|
||
}}
|
||
.memory-card {{
|
||
background: rgba(255,255,255,0.92);
|
||
border-radius: 10px;
|
||
margin-bottom: 12px;
|
||
overflow: hidden;
|
||
box-shadow: 0 2px 8px rgba(0,0,0,0.08);
|
||
}}
|
||
.card-header {{
|
||
background: rgba(102,126,234,0.12);
|
||
padding: 8px 14px;
|
||
display: flex; justify-content: space-between; align-items: center;
|
||
}}
|
||
.card-index {{
|
||
font-size: 13px; font-weight: bold; color: #667eea;
|
||
}}
|
||
.card-time {{
|
||
font-size: 12px; color: #999;
|
||
}}
|
||
.card-body {{
|
||
padding: 12px 14px;
|
||
font-size: 13px; color: #333;
|
||
line-height: 1.7;
|
||
word-break: break-all;
|
||
}}
|
||
.footer {{
|
||
text-align: center; margin-top: 14px;
|
||
font-size: 11px; color: rgba(255,255,255,0.6);
|
||
}}
|
||
</style>
|
||
</head>
|
||
<body>
|
||
<div id="card">
|
||
<div class="title-bar">
|
||
<div>
|
||
<div class="title">🧠 向量记忆</div>
|
||
<div class="subtitle">{chatroom_id}</div>
|
||
</div>
|
||
<div class="subtitle">{now_str}</div>
|
||
</div>
|
||
<div class="stats">
|
||
<div class="stat-box">
|
||
<div class="stat-num">{len(items)}</div>
|
||
<div class="stat-label">摘要总数</div>
|
||
</div>
|
||
<div class="stat-box">
|
||
<div class="stat-num">{watermark_display}</div>
|
||
<div class="stat-label">最后摘要时间</div>
|
||
</div>
|
||
</div>
|
||
{cards_html}
|
||
<div class="footer">AIChat Vector Memory · Powered by Chroma</div>
|
||
</div>
|
||
</body>
|
||
</html>'''
|
||
return html
|
||
|
||
async def _render_vector_memory_image(self, html: str) -> str | None:
    """Render *html* with Playwright and screenshot its ``#card`` element.

    Browser choice:
      - If SignInPlugin's renderer is importable, its shared browser
        (``get_browser``) is used.
      - Otherwise a private headless Chromium is launched. It is closed, and
        Playwright stopped, before this method returns, so repeated renders
        do not leak browser processes.

    The page is closed on every path.

    Returns:
        Path of the PNG written under ``data/temp``, or None on failure
        (logged).
    """
    from contextlib import suppress

    playwright = None
    owns_browser = False

    try:
        from plugins.SignInPlugin.html_renderer import get_browser
    except ImportError:
        get_browser = None

    try:
        if get_browser is not None:
            browser = await get_browser()
        else:
            from playwright.async_api import async_playwright
            playwright = await async_playwright().start()
            browser = await playwright.chromium.launch(headless=True, args=['--no-sandbox'])
            owns_browser = True
    except Exception as e:
        logger.error(f"[VectorMemory] Playwright 不可用: {e}")
        if playwright is not None:
            with suppress(Exception):
                await playwright.stop()
        return None

    page = None
    try:
        page = await browser.new_page()
        await page.set_viewport_size({"width": 600, "height": 800})
        await page.set_content(html)
        await page.wait_for_selector("#card", timeout=5000)
        element = await page.query_selector("#card")
        if not element:
            return None
        output_dir = Path(__file__).parent / "data" / "temp"
        output_dir.mkdir(parents=True, exist_ok=True)
        # Second-resolution timestamp plus a random suffix, so two renders in
        # the same second do not overwrite each other's file.
        ts = int(datetime.now().timestamp())
        output_path = output_dir / f"vector_memory_{ts}_{uuid.uuid4().hex[:8]}.png"
        await element.screenshot(path=str(output_path))
        logger.success(f"[VectorMemory] 渲染成功: {output_path}")
        return str(output_path)
    except Exception as e:
        logger.error(f"[VectorMemory] 渲染失败: {e}")
        return None
    finally:
        if page is not None:
            with suppress(Exception):
                await page.close()
        if owns_browser:
            with suppress(Exception):
                await browser.close()
        if playwright is not None:
            with suppress(Exception):
                await playwright.stop()
|
||
|
||
def _get_chat_id(self, from_wxid: str, sender_wxid: str = None, is_group: bool = False) -> str:
|
||
"""获取会话ID"""
|
||
if is_group:
|
||
# 群聊使用 "群ID:用户ID" 组合,确保每个用户有独立的对话记忆
|
||
user_wxid = sender_wxid or from_wxid
|
||
return f"{from_wxid}:{user_wxid}"
|
||
else:
|
||
return sender_wxid or from_wxid # 私聊使用用户ID
|
||
|
||
def _get_group_history_chat_id(self, from_wxid: str, user_wxid: str = None) -> str:
|
||
"""获取群聊 history 的会话ID(可配置为全群共享或按用户隔离)"""
|
||
if not from_wxid:
|
||
return ""
|
||
|
||
history_config = (self.config or {}).get("history", {})
|
||
scope = str(history_config.get("scope", "chatroom") or "chatroom").strip().lower()
|
||
if scope in ("per_user", "user", "peruser"):
|
||
if not user_wxid:
|
||
return from_wxid
|
||
return self._get_chat_id(from_wxid, user_wxid, is_group=True)
|
||
|
||
return from_wxid
|
||
|
||
def _should_capture_group_history(self, *, is_triggered: bool) -> bool:
|
||
"""判断群聊消息是否需要写入 history(减少无关上下文污染)"""
|
||
history_config = (self.config or {}).get("history", {})
|
||
capture = str(history_config.get("capture", "all") or "all").strip().lower()
|
||
|
||
if capture in ("none", "off", "disable", "disabled"):
|
||
return False
|
||
if capture in ("reply", "ai_only", "triggered"):
|
||
return bool(is_triggered)
|
||
return True
|
||
|
||
def _parse_history_timestamp(self, ts) -> float | None:
|
||
if ts is None:
|
||
return None
|
||
if isinstance(ts, (int, float)):
|
||
return float(ts)
|
||
if isinstance(ts, str):
|
||
s = ts.strip()
|
||
if not s:
|
||
return None
|
||
try:
|
||
return float(s)
|
||
except Exception:
|
||
pass
|
||
try:
|
||
return datetime.fromisoformat(s).timestamp()
|
||
except Exception:
|
||
return None
|
||
return None
|
||
|
||
def _filter_history_by_window(self, history: list) -> list:
|
||
history_config = (self.config or {}).get("history", {})
|
||
window_seconds = history_config.get("context_window_seconds", None)
|
||
if window_seconds is None:
|
||
window_seconds = history_config.get("window_seconds", 0)
|
||
try:
|
||
window_seconds = float(window_seconds or 0)
|
||
except Exception:
|
||
window_seconds = 0
|
||
if window_seconds <= 0:
|
||
return history
|
||
|
||
cutoff = time.time() - window_seconds
|
||
filtered = []
|
||
for msg in history or []:
|
||
ts = self._parse_history_timestamp((msg or {}).get("timestamp"))
|
||
if ts is None or ts >= cutoff:
|
||
filtered.append(msg)
|
||
return filtered
|
||
|
||
def _sanitize_speaker_name(self, name: str) -> str:
|
||
"""清洗昵称,避免破坏历史格式(如 [name] 前缀)。"""
|
||
if name is None:
|
||
return ""
|
||
s = str(name).strip()
|
||
if not s:
|
||
return ""
|
||
s = s.replace("\r", " ").replace("\n", " ")
|
||
s = re.sub(r"\s{2,}", " ", s)
|
||
# 避免与历史前缀 [xxx] 冲突
|
||
s = s.replace("[", "(").replace("]", ")")
|
||
return s.strip()
|
||
|
||
def _combine_display_and_nickname(self, display_name: str, wechat_nickname: str) -> str:
|
||
display_name = self._sanitize_speaker_name(display_name)
|
||
wechat_nickname = self._sanitize_speaker_name(wechat_nickname)
|
||
# 重要:群昵称(群名片) 与 微信昵称(全局) 是两个不同概念,尽量同时给 AI。
|
||
if display_name and wechat_nickname:
|
||
return f"群昵称={display_name} | 微信昵称={wechat_nickname}"
|
||
if display_name:
|
||
return f"群昵称={display_name}"
|
||
if wechat_nickname:
|
||
return f"微信昵称={wechat_nickname}"
|
||
return ""
|
||
|
||
def _get_chatroom_member_lock(self, chatroom_id: str) -> asyncio.Lock:
|
||
lock = self._chatroom_member_cache_locks.get(chatroom_id)
|
||
if lock is None:
|
||
lock = asyncio.Lock()
|
||
self._chatroom_member_cache_locks[chatroom_id] = lock
|
||
return lock
|
||
|
||
async def _get_group_display_name(self, bot, chatroom_id: str, user_wxid: str, *, force_refresh: bool = False) -> str:
|
||
"""获取群名片(群内昵称)。失败时返回空串。"""
|
||
if not chatroom_id or not user_wxid:
|
||
return ""
|
||
if not hasattr(bot, "get_chatroom_members"):
|
||
return ""
|
||
|
||
now = time.time()
|
||
if not force_refresh:
|
||
cached = self._chatroom_member_cache.get(chatroom_id)
|
||
if cached:
|
||
ts, member_map = cached
|
||
if now - float(ts or 0) < float(self._chatroom_member_cache_ttl_seconds or 0):
|
||
return self._sanitize_speaker_name(member_map.get(user_wxid, ""))
|
||
|
||
lock = self._get_chatroom_member_lock(chatroom_id)
|
||
async with lock:
|
||
now = time.time()
|
||
if not force_refresh:
|
||
cached = self._chatroom_member_cache.get(chatroom_id)
|
||
if cached:
|
||
ts, member_map = cached
|
||
if now - float(ts or 0) < float(self._chatroom_member_cache_ttl_seconds or 0):
|
||
return self._sanitize_speaker_name(member_map.get(user_wxid, ""))
|
||
|
||
try:
|
||
# 群成员列表可能较大,避免长期阻塞消息处理
|
||
members = await asyncio.wait_for(bot.get_chatroom_members(chatroom_id), timeout=8)
|
||
except Exception as e:
|
||
logger.debug(f"获取群成员列表失败: {chatroom_id}, {e}")
|
||
return ""
|
||
|
||
member_map = {}
|
||
try:
|
||
for m in members or []:
|
||
wxid = (m.get("wxid") or "").strip()
|
||
if not wxid:
|
||
continue
|
||
display_name = m.get("display_name") or m.get("displayName") or ""
|
||
member_map[wxid] = str(display_name or "").strip()
|
||
except Exception as e:
|
||
logger.debug(f"解析群成员列表失败: {chatroom_id}, {e}")
|
||
|
||
self._chatroom_member_cache[chatroom_id] = (time.time(), member_map)
|
||
return self._sanitize_speaker_name(member_map.get(user_wxid, ""))
|
||
|
||
async def _get_user_display_label(self, bot, from_wxid: str, user_wxid: str, is_group: bool) -> str:
|
||
"""用于历史记录:群聊优先使用群名片,其次微信昵称。"""
|
||
if not is_group:
|
||
return ""
|
||
wechat_nickname = await self._get_user_nickname(bot, from_wxid, user_wxid, is_group)
|
||
group_display = await self._get_group_display_name(bot, from_wxid, user_wxid)
|
||
return self._combine_display_and_nickname(group_display, wechat_nickname) or wechat_nickname or user_wxid
|
||
|
||
async def _get_user_nickname(self, bot, from_wxid: str, user_wxid: str, is_group: bool) -> str:
    """Resolve a group member's WeChat nickname.

    Sources are tried in order; the first non-empty nickname wins:
      1. the MemberSync service (chatroom-scoped when *from_wxid* is set);
      2. the Redis user-info cache;
      3. the newest non-empty nickname for the user in MessageLogger's
         ``messages`` table;
      4. the wxid itself, or "未知用户" if that is empty too.

    Args:
        bot: WechatHookClient instance (not used by the lookups here).
        from_wxid: Message source (group id or private-chat user id).
        user_wxid: The user's wxid.
        is_group: Whether the message is from a group; private chats yield "".

    Returns:
        The nickname.
    """
    if not is_group:
        return ""

    # 1. MemberSync database
    try:
        service = get_member_service()
        if from_wxid:
            info = await service.get_chatroom_member_info(from_wxid, user_wxid)
        else:
            info = await service.get_member_info(user_wxid)
        if info and info.get("nickname"):
            return info["nickname"]
    except Exception as e:
        logger.debug(f"[MemberSync数据库读取失败] {user_wxid}: {e}")

    # 2. Redis cache
    cache = get_cache()
    if cache and cache.enabled:
        cached = cache.get_user_basic_info(from_wxid, user_wxid)
        if cached and cached.get("nickname"):
            logger.debug(f"[缓存命中] 用户昵称: {user_wxid} -> {cached['nickname']}")
            return cached["nickname"]

    # 3. MessageLogger database
    def _from_message_log() -> str:
        from plugins.MessageLogger.main import MessageLogger
        msg_logger = MessageLogger.get_instance()
        if not msg_logger:
            return ""
        with msg_logger.get_db_connection() as conn:
            with conn.cursor() as cursor:
                cursor.execute(
                    "SELECT nickname FROM messages WHERE sender_wxid = %s AND nickname != '' ORDER BY create_time DESC LIMIT 1",
                    (user_wxid,)
                )
                row = cursor.fetchone()
                return row[0] if row else ""

    try:
        nickname = _from_message_log()
    except Exception as e:
        logger.debug(f"从数据库获取昵称失败: {e}")
        nickname = ""

    # 4. Last resort: the wxid itself
    return nickname or user_wxid or "未知用户"
|
||
|
||
def _check_rate_limit(self, user_wxid: str) -> tuple:
|
||
"""
|
||
检查用户是否超过限流
|
||
|
||
Args:
|
||
user_wxid: 用户wxid
|
||
|
||
Returns:
|
||
(是否允许, 剩余次数, 重置时间秒数)
|
||
"""
|
||
rate_limit_config = self.config.get("rate_limit", {})
|
||
if not rate_limit_config.get("enabled", True):
|
||
return (True, 999, 0)
|
||
|
||
redis_cache = get_cache()
|
||
if not redis_cache or not redis_cache.enabled:
|
||
return (True, 999, 0) # Redis 不可用时不限流
|
||
|
||
limit = rate_limit_config.get("ai_chat_limit", 20)
|
||
window = rate_limit_config.get("ai_chat_window", 60)
|
||
|
||
return redis_cache.check_rate_limit(user_wxid, limit, window, "ai_chat")
|
||
|
||
def _add_to_memory(self, chat_id: str, role: str, content, image_base64: str = None):
|
||
"""
|
||
添加消息到记忆
|
||
|
||
Args:
|
||
chat_id: 会话ID
|
||
role: 角色 (user/assistant)
|
||
content: 消息内容(可以是字符串或列表)
|
||
image_base64: 可选的图片base64数据
|
||
"""
|
||
if not self.store:
|
||
return
|
||
self.store.add_private_message(chat_id, role, content, image_base64=image_base64)
|
||
|
||
def _get_memory_messages(self, chat_id: str) -> list:
|
||
"""获取记忆中的消息"""
|
||
if not self.store:
|
||
return []
|
||
return self.store.get_private_messages(chat_id)
|
||
|
||
def _clear_memory(self, chat_id: str):
|
||
"""清空指定会话的记忆"""
|
||
if not self.store:
|
||
return
|
||
self.store.clear_private_messages(chat_id)
|
||
|
||
async def _download_and_encode_image(self, bot, message: dict) -> str:
|
||
"""下载图片并转换为base64,委托给 ImageProcessor(使用新接口)"""
|
||
if self._image_processor:
|
||
return await self._image_processor.download_image(bot, message)
|
||
logger.warning("ImageProcessor 未初始化,无法下载图片")
|
||
return ""
|
||
|
||
async def _download_emoji_and_encode(self, cdn_url: str, max_retries: int = 3) -> str:
|
||
"""下载表情包并转换为base64,委托给 ImageProcessor"""
|
||
if self._image_processor:
|
||
return await self._image_processor.download_emoji(cdn_url, max_retries)
|
||
logger.warning("ImageProcessor 未初始化,无法下载表情包")
|
||
return ""
|
||
|
||
async def _generate_image_description(self, image_base64: str, prompt: str, config: dict) -> str:
|
||
"""
|
||
使用 AI 生成图片描述,委托给 ImageProcessor
|
||
|
||
Args:
|
||
image_base64: 图片的 base64 数据
|
||
prompt: 描述提示词
|
||
config: 图片描述配置
|
||
|
||
Returns:
|
||
图片描述文本,失败返回空字符串
|
||
"""
|
||
if self._image_processor:
|
||
model = config.get("model")
|
||
return await self._image_processor.generate_description(image_base64, prompt, model)
|
||
logger.warning("ImageProcessor 未初始化,无法生成图片描述")
|
||
return ""
|
||
|
||
def _collect_tools_with_plugins(self) -> dict:
    """Collect LLM tool definitions from the ToolRegistry.

    The ``tools.mode`` config selects filtering:
      - "whitelist": keep only names listed in ``tools.whitelist``;
      - "blacklist": drop names listed in ``tools.blacklist``;
      - anything else (default "all"): keep everything.
    Registry entries that resolve to nothing are skipped.

    Returns:
        dict mapping tool name -> (source plugin name, tool schema).
    """
    registry = get_tool_registry()
    tool_settings = (self.config or {}).get("tools", {})
    filter_mode = tool_settings.get("mode", "all")
    allowed = set(tool_settings.get("whitelist", []))
    denied = set(tool_settings.get("blacklist", []))

    def _accepted(tool_name) -> bool:
        if filter_mode == "whitelist":
            return tool_name in allowed
        if filter_mode == "blacklist":
            return tool_name not in denied
        return True

    collected = {}
    for tool_name in registry.list_tools():
        definition = registry.get(tool_name)
        if definition and _accepted(tool_name):
            collected[tool_name] = (definition.plugin_name, definition.schema)
    return collected
|
||
|
||
def _collect_tools(self):
|
||
"""收集所有插件的LLM工具(支持白名单/黑名单过滤)"""
|
||
tools_map = self._collect_tools_with_plugins()
|
||
return [item[1] for item in tools_map.values()]
|
||
|
||
def _get_tool_schema_map(self, tools_map: dict | None = None) -> dict:
|
||
"""构建工具名到参数 schema 的映射"""
|
||
tools_map = tools_map or self._collect_tools_with_plugins()
|
||
schema_map = {}
|
||
for name, (_plugin_name, tool) in tools_map.items():
|
||
fn = tool.get("function", {})
|
||
schema_map[name] = fn.get("parameters", {}) or {}
|
||
return schema_map
|
||
|
||
def _guess_tool_param_description(self, param_name: str) -> str:
|
||
"""为缺失描述的参数补充通用说明"""
|
||
key = str(param_name or "").strip().lower()
|
||
hints = {
|
||
"query": "检索关键词或要查询的问题",
|
||
"keyword": "关键词",
|
||
"prompt": "用于生成内容的提示词",
|
||
"text": "要处理的文本内容",
|
||
"content": "主要内容",
|
||
"url": "目标链接地址",
|
||
"urls": "目标链接地址列表",
|
||
"image_url": "图片链接地址",
|
||
"image_base64": "图片的 Base64 编码数据",
|
||
"location": "地点名称(如城市/地区)",
|
||
"city": "城市名称",
|
||
"name": "名称",
|
||
"id": "目标对象 ID",
|
||
"user_id": "用户 ID",
|
||
"group_id": "群 ID",
|
||
"count": "数量",
|
||
"num": "数量",
|
||
"mode": "执行模式",
|
||
"type": "类型",
|
||
}
|
||
return hints.get(key, f"参数 {param_name} 的取值")
|
||
|
||
def _normalize_tool_schema_for_llm(self, tool: dict) -> dict | None:
|
||
"""标准化工具 schema,提升模型函数调用稳定性。"""
|
||
if not isinstance(tool, dict):
|
||
return None
|
||
|
||
normalized = copy.deepcopy(tool)
|
||
normalized["type"] = "function"
|
||
|
||
function_def = normalized.get("function")
|
||
if not isinstance(function_def, dict):
|
||
return None
|
||
|
||
function_name = str(function_def.get("name", "")).strip()
|
||
if not function_name:
|
||
return None
|
||
|
||
description = str(function_def.get("description", "")).strip()
|
||
if not description:
|
||
function_def["description"] = f"调用 {function_name} 工具完成任务,仅在用户明确需要时使用。"
|
||
|
||
parameters = function_def.get("parameters")
|
||
if not isinstance(parameters, dict):
|
||
parameters = {"type": "object", "properties": {}, "required": []}
|
||
|
||
properties = parameters.get("properties")
|
||
if not isinstance(properties, dict):
|
||
properties = {}
|
||
|
||
for param_name, param_schema in list(properties.items()):
|
||
if not isinstance(param_schema, dict):
|
||
properties[param_name] = {
|
||
"type": "string",
|
||
"description": self._guess_tool_param_description(param_name),
|
||
}
|
||
continue
|
||
|
||
if not str(param_schema.get("description", "")).strip():
|
||
param_schema["description"] = self._guess_tool_param_description(param_name)
|
||
|
||
if not str(param_schema.get("type", "")).strip() and "enum" not in param_schema:
|
||
param_schema["type"] = "string"
|
||
|
||
required = parameters.get("required", [])
|
||
if not isinstance(required, list):
|
||
required = []
|
||
required = [item for item in required if isinstance(item, str) and item in properties]
|
||
|
||
parameters["type"] = "object"
|
||
parameters["properties"] = properties
|
||
parameters["required"] = required
|
||
parameters["additionalProperties"] = False
|
||
function_def["parameters"] = parameters
|
||
normalized["function"] = function_def
|
||
return normalized
|
||
|
||
def _prepare_tools_for_llm(self, tools: list) -> list:
|
||
"""预处理工具声明(补描述、补参数 schema、收敛格式)。"""
|
||
prepared = []
|
||
for tool in tools or []:
|
||
normalized = self._normalize_tool_schema_for_llm(tool)
|
||
if normalized:
|
||
prepared.append(normalized)
|
||
return prepared
|
||
|
||
def _build_tool_rules_prompt(self, tools: list) -> str:
|
||
"""构建函数调用规则提示词(参考 Eridanus 风格)。"""
|
||
lines = [
|
||
"【函数调用规则】",
|
||
"1) 仅可基于【当前消息】决定是否调用工具;历史内容只用于语境,不可据此触发工具。",
|
||
"2) 能直接回答就直接回答;只有在需要外部能力时才调用工具。",
|
||
"3) 关键参数不完整时先追问澄清,不要臆测。",
|
||
"4) 禁止向用户输出 function_call/tool_calls 或 JSON 调用片段。",
|
||
"5) 工具执行后请输出自然语言总结:先结论,再补充细节。",
|
||
]
|
||
|
||
if not tools:
|
||
lines.append("本轮未提供可调用工具,禁止伪造函数调用。")
|
||
lines.append("严禁在回复中输出任何 JSON、function_call、tool_calls、action/actioninput 格式的内容,即使历史上下文中出现过工具调用记录也不得模仿。只用自然语言回复。")
|
||
return "\n\n" + "\n".join(lines)
|
||
|
||
lines.append("本轮可用工具:")
|
||
for tool in (tools or [])[:12]:
|
||
function_def = tool.get("function", {}) if isinstance(tool, dict) else {}
|
||
name = str(function_def.get("name", "")).strip()
|
||
description = str(function_def.get("description", "")).strip()
|
||
if not name:
|
||
continue
|
||
if len(description) > 70:
|
||
description = description[:67] + "..."
|
||
lines.append(f"- {name}: {description or '按工具定义执行'}")
|
||
|
||
return "\n\n" + "\n".join(lines)
|
||
|
||
async def _handle_list_prompts(self, bot, from_wxid: str):
    """Send the list of persona files (``prompts/*.txt`` next to this module) to ``from_wxid``.

    The file named by ``config["prompt"]["system_prompt_file"]`` is marked
    with ✅. A missing directory, an empty directory, or any error (e.g. a
    missing config key) is reported back to the chat instead of raising.
    """
    try:
        prompt_dir = Path(__file__).parent / "prompts"
        if not prompt_dir.exists():
            await bot.send_text(from_wxid, "❌ prompts 目录不存在")
            return

        persona_files = sorted(prompt_dir.glob("*.txt"))
        if not persona_files:
            await bot.send_text(from_wxid, "❌ 没有找到任何人设文件")
            return

        active_file = self.config["prompt"]["system_prompt_file"]
        pieces = ["📋 可用人设列表:\n\n"]
        for index, path in enumerate(persona_files, 1):
            marker = " ✅" if path.name == active_file else ""
            pieces.append(f"{index}. {path.name}{marker}\n")
        pieces.append("\n💡 使用方法:/切人设 文件名.txt")

        await bot.send_text(from_wxid, "".join(pieces))
        logger.info("已发送人设列表")
    except Exception as e:
        logger.error(f"获取人设列表失败: {e}")
        await bot.send_text(from_wxid, f"❌ 获取人设列表失败: {e}")
|
||
|
||
def _estimate_tokens(self, text: str) -> int:
|
||
"""
|
||
估算文本的 token 数量
|
||
|
||
简单估算规则:
|
||
- 中文:约 1.5 字符 = 1 token
|
||
- 英文:约 4 字符 = 1 token
|
||
- 混合文本取平均
|
||
"""
|
||
if not text:
|
||
return 0
|
||
|
||
# 统计中文字符数
|
||
chinese_chars = sum(1 for c in text if '\u4e00' <= c <= '\u9fff')
|
||
# 其他字符数
|
||
other_chars = len(text) - chinese_chars
|
||
|
||
# 估算 token 数
|
||
chinese_tokens = chinese_chars / 1.5
|
||
other_tokens = other_chars / 4
|
||
|
||
return int(chinese_tokens + other_tokens)
|
||
|
||
def _estimate_message_tokens(self, message: dict) -> int:
|
||
"""估算单条消息的 token 数"""
|
||
content = message.get("content", "")
|
||
|
||
if isinstance(content, str):
|
||
return self._estimate_tokens(content)
|
||
elif isinstance(content, list):
|
||
# 多模态消息
|
||
total = 0
|
||
for item in content:
|
||
if item.get("type") == "text":
|
||
total += self._estimate_tokens(item.get("text", ""))
|
||
elif item.get("type") == "image_url":
|
||
# 图片按 85 token 估算(OpenAI 低分辨率图片)
|
||
total += 85
|
||
return total
|
||
return 0
|
||
|
||
def _extract_text_from_multimodal(self, content) -> str:
|
||
"""从多模态 content 中提取文本,模型不支持时用于降级"""
|
||
if isinstance(content, list):
|
||
texts = [item.get("text", "") for item in content if item.get("type") == "text"]
|
||
text = "".join(texts).strip()
|
||
return text or "[图片]"
|
||
if content is None:
|
||
return ""
|
||
return str(content)
|
||
|
||
def _extract_last_user_text(self, messages: list) -> str:
|
||
"""从 messages 中提取最近一条用户文本,用于工具参数兜底。"""
|
||
for msg in reversed(messages or []):
|
||
if msg.get("role") == "user":
|
||
return self._extract_text_from_multimodal(msg.get("content"))
|
||
return ""
|
||
|
||
def _sanitize_llm_output(self, text) -> str:
|
||
"""
|
||
清洗 LLM 输出,尽量满足:不输出思维链、不使用 Markdown。
|
||
|
||
说明:提示词并非强约束,因此在所有“发给用户/写入上下文”的出口统一做后处理。
|
||
"""
|
||
if text is None:
|
||
return ""
|
||
raw = str(text)
|
||
cleaned = raw
|
||
|
||
# 清理 xAI/Grok 风格的渲染卡片标签,避免被当作正文/上下文继续传播
|
||
cleaned = re.sub(
|
||
r"(?is)<grok:render\b[\s\S]*?</grok:render>",
|
||
"",
|
||
cleaned,
|
||
)
|
||
|
||
output_cfg = (self.config or {}).get("output", {})
|
||
strip_thinking = output_cfg.get("strip_thinking", True)
|
||
strip_markdown = output_cfg.get("strip_markdown", True)
|
||
|
||
# 先做一次 Markdown 清理,避免 “**思考过程:**/### 思考” 这类包裹导致无法识别
|
||
if strip_markdown:
|
||
cleaned = self._strip_markdown_syntax(cleaned)
|
||
|
||
if strip_thinking:
|
||
cleaned = self._strip_thinking_content(cleaned)
|
||
|
||
# 清理模型偶发输出的“文本工具调用”痕迹(如 tavilywebsearch{query:...} / <ctrl46>)
|
||
# 这些内容既不是正常回复,也会破坏“工具只能用 Function Calling”的约束
|
||
try:
|
||
cleaned = re.sub(r"<ctrl\\d+>", "", cleaned, flags=re.IGNORECASE)
|
||
cleaned = re.sub(
|
||
r"(?:展开阅读下文\\s*)?(?:tavilywebsearch|tavily_web_search|web_search)\\s*\\{[^{}]{0,1500}\\}",
|
||
"",
|
||
cleaned,
|
||
flags=re.IGNORECASE,
|
||
)
|
||
cleaned = re.sub(
|
||
r"(?:tavilywebsearch|tavily_web_search|web_search)\\s*\\([^\\)]{0,1500}\\)",
|
||
"",
|
||
cleaned,
|
||
flags=re.IGNORECASE,
|
||
)
|
||
cleaned = re.sub(
|
||
r"\{[^\{\}]{0,2000}[\"']name[\"']\s*:\s*[\"'](?:draw_image|generate_image|nano_ai_image_generation|flow2_ai_image_generation|jimeng_ai_image_generation|kiira2_ai_image_generation)[\"'][\s\S]{0,2000}\}",
|
||
"",
|
||
cleaned,
|
||
flags=re.IGNORECASE,
|
||
)
|
||
cleaned = cleaned.replace("展开阅读下文", "")
|
||
cleaned = re.sub(
|
||
r"[((]已触发工具处理[^))\r\n]{0,500}[))]?",
|
||
"",
|
||
cleaned,
|
||
)
|
||
cleaned = re.sub(r"(?m)^.*已触发工具处理.*$", "", cleaned)
|
||
cleaned = re.sub(r"(?m)^.*结果将发送到聊天中.*$", "", cleaned)
|
||
# 清理历史记录格式标签 [时间:...][类型:...]
|
||
cleaned = re.sub(r"\[时间:[^\]]*\]", "", cleaned)
|
||
cleaned = re.sub(r"\[类型:[^\]]*\]", "", cleaned)
|
||
# 过滤图片占位符/文件名,避免把日志占位符当成正文发出去
|
||
cleaned = re.sub(
|
||
r"\\[图片[^\\]]*\\]\\s*\\S+\\.(?:png|jpe?g|gif|webp)",
|
||
"",
|
||
cleaned,
|
||
flags=re.IGNORECASE,
|
||
)
|
||
cleaned = re.sub(r"\\[图片[^\\]]*\\]", "", cleaned)
|
||
except Exception:
|
||
pass
|
||
|
||
# 再跑一轮:部分模型会把“思考/最终”标记写成 Markdown,或在剥离标签后才露出标记
|
||
if strip_markdown:
|
||
cleaned = self._strip_markdown_syntax(cleaned)
|
||
if strip_thinking:
|
||
cleaned = self._strip_thinking_content(cleaned)
|
||
|
||
cleaned = cleaned.strip()
|
||
# 兜底:清洗后仍残留明显“思维链/大纲”标记时,再尝试一次“抽取最终段”
|
||
if strip_thinking and cleaned and self._contains_thinking_markers(cleaned):
|
||
extracted = self._extract_after_last_answer_marker(cleaned)
|
||
if not extracted:
|
||
extracted = self._extract_final_answer_from_outline(cleaned)
|
||
if extracted:
|
||
cleaned = extracted.strip()
|
||
# 仍残留标记:尽量选取最后一个“不含标记”的段落作为最终回复
|
||
if cleaned and self._contains_thinking_markers(cleaned):
|
||
parts = [p.strip() for p in re.split(r"\n{2,}", cleaned) if p.strip()]
|
||
for p in reversed(parts):
|
||
if not self._contains_thinking_markers(p):
|
||
cleaned = p
|
||
break
|
||
cleaned = cleaned.strip()
|
||
|
||
# 最终兜底:仍然像思维链就直接丢弃(宁可不发也不要把思维链发出去)
|
||
if strip_thinking and cleaned and self._contains_thinking_markers(cleaned):
|
||
return ""
|
||
|
||
if cleaned:
|
||
return cleaned
|
||
|
||
raw_stripped = raw.strip()
|
||
# 清洗后为空时,不要回退到包含思维链标记的原文(避免把 <think>... 直接发出去)
|
||
if strip_thinking and self._contains_thinking_markers(raw_stripped):
|
||
return ""
|
||
if self._contains_tool_call_markers(raw_stripped):
|
||
return ""
|
||
return raw_stripped
|
||
|
||
async def _maybe_send_voice_reply(self, bot, to_wxid: str, text: str, message: dict | None = None):
|
||
"""AI 回复后,按概率触发语音回复"""
|
||
if not text:
|
||
return
|
||
try:
|
||
voice_plugin = self.get_plugin("VoiceSynth")
|
||
if not voice_plugin or not getattr(voice_plugin, "enabled", True):
|
||
return
|
||
if not getattr(voice_plugin, "master_enabled", True):
|
||
return
|
||
handler = getattr(voice_plugin, "maybe_send_voice_reply", None)
|
||
if not handler:
|
||
return
|
||
asyncio.create_task(handler(bot, to_wxid, text, message=message))
|
||
except Exception as e:
|
||
logger.debug(f"触发语音回复失败: {e}")
|
||
|
||
def _contains_thinking_markers(self, text: str) -> bool:
|
||
"""粗略判断文本是否包含明显的“思考/推理”外显标记,用于决定是否允许回退原文。"""
|
||
if not text:
|
||
return False
|
||
|
||
lowered = text.lower()
|
||
tag_tokens = (
|
||
"<think", "</think",
|
||
"<analysis", "</analysis",
|
||
"<reasoning", "</reasoning",
|
||
"<thought", "</thought",
|
||
"<thinking", "</thinking",
|
||
"<thoughts", "</thoughts",
|
||
"<scratchpad", "</scratchpad",
|
||
"<think", "</think",
|
||
"<analysis", "</analysis",
|
||
"<reasoning", "</reasoning",
|
||
"<thought", "</thought",
|
||
"<thinking", "</thinking",
|
||
"<thoughts", "</thoughts",
|
||
"<scratchpad", "</scratchpad",
|
||
)
|
||
if any(tok in lowered for tok in tag_tokens):
|
||
return True
|
||
|
||
stripped = text.strip()
|
||
if stripped.startswith("{") and stripped.endswith("}"):
|
||
# JSON 结构化输出(常见于“analysis/final”)
|
||
json_keys = (
|
||
"\"analysis\"",
|
||
"\"reasoning\"",
|
||
"\"thought\"",
|
||
"\"thoughts\"",
|
||
"\"scratchpad\"",
|
||
"\"final\"",
|
||
"\"answer\"",
|
||
"\"response\"",
|
||
"\"output\"",
|
||
"\"text\"",
|
||
)
|
||
if any(k in lowered for k in json_keys):
|
||
return True
|
||
|
||
# YAML/KV 风格
|
||
if re.search(r"(?im)^\s*(analysis|reasoning|thoughts?|scratchpad|final|answer|response|output|text|思考|分析|推理|最终|输出)\s*[::]", text):
|
||
return True
|
||
|
||
marker_re = re.compile(
|
||
r"(?mi)^\s*(?:\d+\s*[\.\、::))\-–—]\s*)?(?:[-*•]+\s*)?"
|
||
r"(?:【\s*(?:思考过程|推理过程|分析过程|思考|分析|推理|内心独白|内心os|思维链|思路|"
|
||
r"chain\s*of\s*thought|reasoning|analysis|thinking|thoughts|thought\s*process|scratchpad)\s*】"
|
||
r"|(?:思考过程|推理过程|分析过程|思考|分析|推理|内心独白|内心os|思维链|思路|"
|
||
r"chain\s*of\s*thought|reasoning|analysis|analyze|thinking|thoughts|thought\s*process|scratchpad|internal\s*monologue|mind\s*space|final\s*polish|output\s*generation)"
|
||
r"(?:\s*】)?\s*(?:[::]|$|\s+))"
|
||
)
|
||
return marker_re.search(text) is not None
|
||
|
||
def _contains_tool_call_markers(self, text: str) -> bool:
|
||
if not text:
|
||
return False
|
||
lowered = text.lower()
|
||
if "<ctrl" in lowered or "展开阅读下文" in text:
|
||
return True
|
||
if "<grok:render" in lowered:
|
||
return True
|
||
if re.search(r"(?i)\"?function_call\"?\s*[:=]\s*\{", text):
|
||
return True
|
||
if re.search(r"(?i)\bfunction_call\s*\(", text):
|
||
return True
|
||
if re.search(r"(?i)(tavilywebsearch|tavily_web_search|web_search)\s*[\{\(]", text):
|
||
return True
|
||
if re.search(
|
||
r"(?i)[\"']name[\"']\s*:\s*[\"'](draw_image|generate_image|nano_ai_image_generation|flow2_ai_image_generation|jimeng_ai_image_generation|kiira2_ai_image_generation)[\"']",
|
||
text,
|
||
):
|
||
return True
|
||
if re.search(
|
||
r"(?i)\b(?:print\s*\(\s*)?(draw_image|generate_image|nano_ai_image_generation|flow2_ai_image_generation|jimeng_ai_image_generation|kiira2_ai_image_generation)\s*\(",
|
||
text,
|
||
):
|
||
return True
|
||
if re.search(
|
||
r"(?i)[\"']action[\"']\s*:\s*[\"'](draw_image|generate_image|nano_ai_image_generation|flow2_ai_image_generation|jimeng_ai_image_generation|kiira2_ai_image_generation|nanoaiimage_generation|flow2aiimage_generation|jimengaiimage_generation|kiira2aiimage_generation)[\"']",
|
||
text,
|
||
):
|
||
return True
|
||
return False
|
||
|
||
def _extract_tool_calls_data(self, message: dict, choice: dict | None = None) -> list:
|
||
"""统一提取工具调用,兼容 tool_calls 与旧版 function_call。"""
|
||
choice = choice if isinstance(choice, dict) else {}
|
||
message = message if isinstance(message, dict) else {}
|
||
|
||
tool_calls = message.get("tool_calls") or choice.get("tool_calls") or []
|
||
if isinstance(tool_calls, dict):
|
||
tool_calls = [tool_calls]
|
||
if isinstance(tool_calls, list) and tool_calls:
|
||
return tool_calls
|
||
|
||
function_call = message.get("function_call") or choice.get("function_call")
|
||
if not isinstance(function_call, dict):
|
||
return []
|
||
|
||
function_name = (function_call.get("name") or "").strip()
|
||
if not function_name:
|
||
return []
|
||
|
||
arguments = function_call.get("arguments", "{}")
|
||
if not isinstance(arguments, str):
|
||
try:
|
||
arguments = json.dumps(arguments, ensure_ascii=False)
|
||
except Exception:
|
||
arguments = "{}"
|
||
|
||
return [{
|
||
"id": f"legacy_fc_{uuid.uuid4().hex[:8]}",
|
||
"type": "function",
|
||
"function": {
|
||
"name": function_name,
|
||
"arguments": arguments,
|
||
},
|
||
}]
|
||
|
||
def _normalize_dialog_api_mode(self, mode: str) -> str:
|
||
value = str(mode or "").strip().lower()
|
||
aliases = {
|
||
"auto": "auto",
|
||
"openai_chat_completions": "openai_chat_completions",
|
||
"chat_completions": "openai_chat_completions",
|
||
"chat": "openai_chat_completions",
|
||
"openai_chat": "openai_chat_completions",
|
||
"openai_responses": "openai_responses",
|
||
"responses": "openai_responses",
|
||
"openai_completions": "openai_completions",
|
||
"completions": "openai_completions",
|
||
"claude_messages": "claude_messages",
|
||
"claude": "claude_messages",
|
||
"anthropic": "claude_messages",
|
||
"anthropic_messages": "claude_messages",
|
||
"gemini_generate_content": "gemini_generate_content",
|
||
"gemini": "gemini_generate_content",
|
||
"gemini_models": "gemini_generate_content",
|
||
}
|
||
return aliases.get(value, "")
|
||
|
||
def _resolve_dialog_api_mode(self, api_config: dict) -> str:
|
||
configured = self._normalize_dialog_api_mode((api_config or {}).get("mode", "auto"))
|
||
if configured and configured != "auto":
|
||
return configured
|
||
|
||
api_url = str((api_config or {}).get("url", "")).lower()
|
||
if "/v1/responses" in api_url:
|
||
return "openai_responses"
|
||
if re.search(r"/v1/messages(?:[/?]|$)", api_url):
|
||
return "claude_messages"
|
||
if "v1beta/models" in api_url or ":generatecontent" in api_url:
|
||
return "gemini_generate_content"
|
||
if re.search(r"/v1/completions(?:[/?]|$)", api_url) and "/chat/completions" not in api_url:
|
||
return "openai_completions"
|
||
return "openai_chat_completions"
|
||
|
||
def _build_gemini_generate_content_url(self, api_url: str, model: str) -> str:
|
||
url = str(api_url or "").strip().rstrip("/")
|
||
if not url:
|
||
return ""
|
||
if ":generatecontent" in url.lower():
|
||
return url
|
||
if "/models/" in url:
|
||
return f"{url}:generateContent"
|
||
return f"{url}/{model}:generateContent"
|
||
|
||
def _parse_data_url(self, data_url: str) -> tuple[str | None, str | None]:
|
||
s = str(data_url or "").strip()
|
||
if not s.startswith("data:") or "," not in s:
|
||
return None, None
|
||
header, b64_data = s.split(",", 1)
|
||
mime = "application/octet-stream"
|
||
if ";" in header:
|
||
mime = header[5:].split(";", 1)[0].strip() or mime
|
||
elif ":" in header:
|
||
mime = header.split(":", 1)[1].strip() or mime
|
||
return mime, b64_data
|
||
|
||
def _convert_openai_content_to_claude_blocks(self, content) -> list:
|
||
if isinstance(content, str):
|
||
text = content.strip()
|
||
return [{"type": "text", "text": text or " "}]
|
||
|
||
blocks = []
|
||
if isinstance(content, list):
|
||
for item in content:
|
||
if not isinstance(item, dict):
|
||
continue
|
||
item_type = item.get("type")
|
||
if item_type == "text":
|
||
text = str(item.get("text", "")).strip()
|
||
if text:
|
||
blocks.append({"type": "text", "text": text})
|
||
elif item_type == "image_url":
|
||
image_url = ((item.get("image_url") or {}).get("url") or "").strip()
|
||
mime, b64_data = self._parse_data_url(image_url)
|
||
if mime and b64_data:
|
||
blocks.append({
|
||
"type": "image",
|
||
"source": {
|
||
"type": "base64",
|
||
"media_type": mime,
|
||
"data": b64_data,
|
||
},
|
||
})
|
||
elif image_url:
|
||
blocks.append({"type": "text", "text": f"[图片链接] {image_url}"})
|
||
|
||
if not blocks:
|
||
text = self._extract_text_from_multimodal(content)
|
||
blocks.append({"type": "text", "text": text or " "})
|
||
return blocks
|
||
|
||
def _convert_openai_content_to_gemini_parts(self, content) -> list:
|
||
if isinstance(content, str):
|
||
text = content.strip()
|
||
return [{"text": text or " "}]
|
||
|
||
parts = []
|
||
if isinstance(content, list):
|
||
for item in content:
|
||
if not isinstance(item, dict):
|
||
continue
|
||
item_type = item.get("type")
|
||
if item_type == "text":
|
||
text = str(item.get("text", "")).strip()
|
||
if text:
|
||
parts.append({"text": text})
|
||
elif item_type == "image_url":
|
||
image_url = ((item.get("image_url") or {}).get("url") or "").strip()
|
||
mime, b64_data = self._parse_data_url(image_url)
|
||
if mime and b64_data:
|
||
parts.append({
|
||
"inline_data": {
|
||
"mime_type": mime,
|
||
"data": b64_data,
|
||
}
|
||
})
|
||
elif image_url:
|
||
parts.append({"text": f"[图片链接] {image_url}"})
|
||
|
||
if not parts:
|
||
text = self._extract_text_from_multimodal(content)
|
||
parts.append({"text": text or " "})
|
||
return parts
|
||
|
||
def _convert_openai_messages_to_claude(self, messages: list) -> tuple[str, list]:
|
||
system_parts = []
|
||
claude_messages = []
|
||
tool_name_by_id = {}
|
||
|
||
for msg in messages or []:
|
||
role = str(msg.get("role", "")).strip().lower()
|
||
if role == "system":
|
||
text = self._extract_text_from_multimodal(msg.get("content"))
|
||
if text:
|
||
system_parts.append(text)
|
||
continue
|
||
|
||
if role == "assistant":
|
||
content_blocks = self._convert_openai_content_to_claude_blocks(msg.get("content"))
|
||
tool_calls = msg.get("tool_calls") or []
|
||
if isinstance(tool_calls, dict):
|
||
tool_calls = [tool_calls]
|
||
for tc in tool_calls:
|
||
function = (tc or {}).get("function") or {}
|
||
fn_name = function.get("name", "")
|
||
if not fn_name:
|
||
continue
|
||
tool_id = (tc or {}).get("id") or f"claude_tool_{uuid.uuid4().hex[:8]}"
|
||
raw_args = function.get("arguments", "{}")
|
||
try:
|
||
args = json.loads(raw_args) if isinstance(raw_args, str) else (raw_args or {})
|
||
if not isinstance(args, dict):
|
||
args = {}
|
||
except Exception:
|
||
args = {}
|
||
content_blocks.append({
|
||
"type": "tool_use",
|
||
"id": tool_id,
|
||
"name": fn_name,
|
||
"input": args,
|
||
})
|
||
tool_name_by_id[tool_id] = fn_name
|
||
if content_blocks:
|
||
claude_messages.append({"role": "assistant", "content": content_blocks})
|
||
continue
|
||
|
||
if role == "tool":
|
||
tool_id = str(msg.get("tool_call_id", "")).strip()
|
||
result_text = self._extract_text_from_multimodal(msg.get("content"))
|
||
if tool_id:
|
||
block = {"type": "tool_result", "tool_use_id": tool_id, "content": result_text}
|
||
claude_messages.append({"role": "user", "content": [block]})
|
||
else:
|
||
claude_messages.append({"role": "user", "content": [{"type": "text", "text": f"[工具结果]\n{result_text}"}]})
|
||
continue
|
||
|
||
user_blocks = self._convert_openai_content_to_claude_blocks(msg.get("content"))
|
||
claude_messages.append({"role": "user", "content": user_blocks})
|
||
|
||
if not claude_messages:
|
||
claude_messages = [{"role": "user", "content": [{"type": "text", "text": "你好"}]}]
|
||
|
||
return "\n\n".join(system_parts).strip(), claude_messages
|
||
|
||
def _convert_openai_messages_to_gemini(self, messages: list) -> tuple[str, list]:
|
||
system_parts = []
|
||
contents = []
|
||
tool_name_by_id = {}
|
||
|
||
for msg in messages or []:
|
||
role = str(msg.get("role", "")).strip().lower()
|
||
if role == "system":
|
||
text = self._extract_text_from_multimodal(msg.get("content"))
|
||
if text:
|
||
system_parts.append(text)
|
||
continue
|
||
|
||
if role == "assistant":
|
||
parts = self._convert_openai_content_to_gemini_parts(msg.get("content"))
|
||
tool_calls = msg.get("tool_calls") or []
|
||
if isinstance(tool_calls, dict):
|
||
tool_calls = [tool_calls]
|
||
for tc in tool_calls:
|
||
function = (tc or {}).get("function") or {}
|
||
fn_name = function.get("name", "")
|
||
if not fn_name:
|
||
continue
|
||
raw_args = function.get("arguments", "{}")
|
||
try:
|
||
args = json.loads(raw_args) if isinstance(raw_args, str) else (raw_args or {})
|
||
if not isinstance(args, dict):
|
||
args = {}
|
||
except Exception:
|
||
args = {}
|
||
tool_id = (tc or {}).get("id") or f"gemini_tool_{uuid.uuid4().hex[:8]}"
|
||
tool_name_by_id[tool_id] = fn_name
|
||
parts.append({"functionCall": {"name": fn_name, "args": args}})
|
||
contents.append({"role": "model", "parts": parts or [{"text": " "}]})
|
||
continue
|
||
|
||
if role == "tool":
|
||
tool_id = str(msg.get("tool_call_id", "")).strip()
|
||
fn_name = tool_name_by_id.get(tool_id, "tool_result")
|
||
tool_text = self._extract_text_from_multimodal(msg.get("content"))
|
||
parts = [{
|
||
"functionResponse": {
|
||
"name": fn_name,
|
||
"response": {
|
||
"content": tool_text
|
||
},
|
||
}
|
||
}]
|
||
contents.append({"role": "user", "parts": parts})
|
||
continue
|
||
|
||
parts = self._convert_openai_content_to_gemini_parts(msg.get("content"))
|
||
contents.append({"role": "user", "parts": parts})
|
||
|
||
if not contents:
|
||
contents = [{"role": "user", "parts": [{"text": "你好"}]}]
|
||
|
||
return "\n\n".join(system_parts).strip(), contents
|
||
|
||
def _convert_openai_messages_to_plain_prompt(self, messages: list) -> str:
|
||
lines = []
|
||
for msg in messages or []:
|
||
role = str(msg.get("role", "")).strip().lower() or "user"
|
||
content = self._extract_text_from_multimodal(msg.get("content"))
|
||
if role == "tool":
|
||
tool_id = str(msg.get("tool_call_id", "")).strip()
|
||
role_label = f"tool:{tool_id}" if tool_id else "tool"
|
||
else:
|
||
role_label = role
|
||
lines.append(f"[{role_label}] {content}")
|
||
return "\n".join(lines).strip() or "你好"
|
||
|
||
def _convert_openai_messages_to_responses_input(self, messages: list) -> list:
|
||
input_messages = []
|
||
for msg in messages or []:
|
||
role = str(msg.get("role", "")).strip().lower()
|
||
if role not in ("system", "user", "assistant"):
|
||
role = "user"
|
||
|
||
content = msg.get("content")
|
||
blocks = []
|
||
if isinstance(content, list):
|
||
for item in content:
|
||
if not isinstance(item, dict):
|
||
continue
|
||
item_type = item.get("type")
|
||
if item_type == "text":
|
||
text = str(item.get("text", "")).strip()
|
||
if text:
|
||
blocks.append({"type": "input_text", "text": text})
|
||
elif item_type == "image_url":
|
||
image_url = ((item.get("image_url") or {}).get("url") or "").strip()
|
||
if image_url:
|
||
blocks.append({"type": "input_image", "image_url": image_url})
|
||
else:
|
||
text = self._extract_text_from_multimodal(content)
|
||
if text:
|
||
blocks.append({"type": "input_text", "text": text})
|
||
|
||
if not blocks:
|
||
blocks.append({"type": "input_text", "text": " "})
|
||
input_messages.append({"role": role, "content": blocks})
|
||
|
||
if not input_messages:
|
||
input_messages = [{"role": "user", "content": [{"type": "input_text", "text": "你好"}]}]
|
||
return input_messages
|
||
|
||
def _convert_tools_for_dialog_api(self, tools: list, api_mode: str):
|
||
if not tools:
|
||
return []
|
||
|
||
if api_mode == "openai_chat_completions":
|
||
return tools
|
||
if api_mode == "openai_completions":
|
||
return []
|
||
|
||
converted = []
|
||
for tool in tools:
|
||
if not isinstance(tool, dict):
|
||
continue
|
||
function = tool.get("function") or {}
|
||
name = str(function.get("name", "")).strip()
|
||
if not name:
|
||
continue
|
||
desc = str(function.get("description", "") or "").strip()
|
||
params = function.get("parameters")
|
||
if not isinstance(params, dict):
|
||
params = {"type": "object", "properties": {}}
|
||
|
||
if api_mode == "openai_responses":
|
||
converted.append({
|
||
"type": "function",
|
||
"name": name,
|
||
"description": desc,
|
||
"parameters": params,
|
||
})
|
||
elif api_mode == "claude_messages":
|
||
converted.append({
|
||
"name": name,
|
||
"description": desc,
|
||
"input_schema": params,
|
||
})
|
||
elif api_mode == "gemini_generate_content":
|
||
converted.append({
|
||
"name": name,
|
||
"description": desc,
|
||
"parameters": params,
|
||
})
|
||
|
||
if api_mode == "gemini_generate_content":
|
||
return [{"functionDeclarations": converted}] if converted else []
|
||
return converted
|
||
|
||
def _parse_dialog_api_response(self, api_mode: str, data: dict) -> tuple[str, list]:
|
||
data = data if isinstance(data, dict) else {}
|
||
|
||
def _fallback_openai():
|
||
choices = data.get("choices", [])
|
||
choice0 = choices[0] if choices else {}
|
||
message = choice0.get("message", {}) if isinstance(choice0, dict) else {}
|
||
full = message.get("content", "") or choice0.get("text", "") or ""
|
||
calls = self._extract_tool_calls_data(message, choice0)
|
||
if not isinstance(calls, list):
|
||
calls = []
|
||
return full, calls
|
||
|
||
if api_mode == "openai_chat_completions":
|
||
return _fallback_openai()
|
||
|
||
if api_mode == "openai_completions":
|
||
choices = data.get("choices", [])
|
||
if choices:
|
||
c0 = choices[0] if isinstance(choices[0], dict) else {}
|
||
text = c0.get("text", "")
|
||
if not text and isinstance(c0.get("message"), dict):
|
||
text = c0["message"].get("content", "")
|
||
return text or "", []
|
||
return "", []
|
||
|
||
if api_mode == "openai_responses":
|
||
if "choices" in data:
|
||
return _fallback_openai()
|
||
|
||
text_parts = []
|
||
tool_calls = []
|
||
if isinstance(data.get("output_text"), str) and data.get("output_text"):
|
||
text_parts.append(data.get("output_text", ""))
|
||
|
||
output = data.get("output") or []
|
||
for item in output if isinstance(output, list) else []:
|
||
if not isinstance(item, dict):
|
||
continue
|
||
item_type = item.get("type", "")
|
||
if item_type in ("function_call", "tool_call"):
|
||
name = item.get("name", "")
|
||
raw_args = item.get("arguments", "{}")
|
||
if isinstance(raw_args, dict):
|
||
raw_args = json.dumps(raw_args, ensure_ascii=False)
|
||
tool_calls.append({
|
||
"id": item.get("id") or f"resp_fc_{uuid.uuid4().hex[:8]}",
|
||
"type": "function",
|
||
"function": {"name": name, "arguments": raw_args or "{}"},
|
||
})
|
||
if item_type == "message":
|
||
for c in item.get("content") or []:
|
||
if not isinstance(c, dict):
|
||
continue
|
||
c_type = c.get("type")
|
||
if c_type in ("output_text", "text"):
|
||
txt = c.get("text", "")
|
||
if txt:
|
||
text_parts.append(txt)
|
||
elif c_type in ("function_call", "tool_call"):
|
||
name = c.get("name", "")
|
||
raw_args = c.get("arguments", "{}")
|
||
if isinstance(raw_args, dict):
|
||
raw_args = json.dumps(raw_args, ensure_ascii=False)
|
||
tool_calls.append({
|
||
"id": c.get("id") or f"resp_fc_{uuid.uuid4().hex[:8]}",
|
||
"type": "function",
|
||
"function": {"name": name, "arguments": raw_args or "{}"},
|
||
})
|
||
return "".join(text_parts), tool_calls
|
||
|
||
if api_mode == "claude_messages":
|
||
if "choices" in data:
|
||
return _fallback_openai()
|
||
|
||
text_parts = []
|
||
tool_calls = []
|
||
for block in data.get("content") or []:
|
||
if not isinstance(block, dict):
|
||
continue
|
||
block_type = block.get("type")
|
||
if block_type == "text":
|
||
txt = block.get("text", "")
|
||
if txt:
|
||
text_parts.append(txt)
|
||
elif block_type == "tool_use":
|
||
name = block.get("name", "")
|
||
args = block.get("input", {})
|
||
if not isinstance(args, dict):
|
||
args = {}
|
||
tool_calls.append({
|
||
"id": block.get("id") or f"claude_fc_{uuid.uuid4().hex[:8]}",
|
||
"type": "function",
|
||
"function": {
|
||
"name": name,
|
||
"arguments": json.dumps(args, ensure_ascii=False),
|
||
},
|
||
})
|
||
return "".join(text_parts), tool_calls
|
||
|
||
if api_mode == "gemini_generate_content":
|
||
if "choices" in data:
|
||
return _fallback_openai()
|
||
|
||
text_parts = []
|
||
tool_calls = []
|
||
candidates = data.get("candidates") or []
|
||
for candidate in candidates if isinstance(candidates, list) else []:
|
||
content = (candidate or {}).get("content") or {}
|
||
parts = content.get("parts") or []
|
||
for part in parts if isinstance(parts, list) else []:
|
||
if not isinstance(part, dict):
|
||
continue
|
||
if "text" in part and part.get("text"):
|
||
text_parts.append(str(part.get("text")))
|
||
function_call = part.get("functionCall") or part.get("function_call")
|
||
if isinstance(function_call, dict):
|
||
name = function_call.get("name", "")
|
||
args = function_call.get("args", {})
|
||
if isinstance(args, str):
|
||
try:
|
||
args = json.loads(args)
|
||
except Exception:
|
||
args = {"raw": args}
|
||
if not isinstance(args, dict):
|
||
args = {}
|
||
tool_calls.append({
|
||
"id": f"gemini_fc_{uuid.uuid4().hex[:8]}",
|
||
"type": "function",
|
||
"function": {
|
||
"name": name,
|
||
"arguments": json.dumps(args, ensure_ascii=False),
|
||
},
|
||
})
|
||
return "".join(text_parts), tool_calls
|
||
|
||
return "", []
|
||
|
||
def _create_proxy_connector(self):
|
||
connector = None
|
||
proxy_config = self.config.get("proxy", {})
|
||
if not proxy_config.get("enabled", False):
|
||
return None
|
||
|
||
proxy_type = str(proxy_config.get("type", "socks5")).upper()
|
||
proxy_host = proxy_config.get("host", "127.0.0.1")
|
||
proxy_port = proxy_config.get("port", 7890)
|
||
proxy_username = proxy_config.get("username")
|
||
proxy_password = proxy_config.get("password")
|
||
|
||
if proxy_username and proxy_password:
|
||
proxy_url = f"{proxy_type}://{proxy_username}:{proxy_password}@{proxy_host}:{proxy_port}"
|
||
else:
|
||
proxy_url = f"{proxy_type}://{proxy_host}:{proxy_port}"
|
||
|
||
if PROXY_SUPPORT:
|
||
try:
|
||
connector = ProxyConnector.from_url(proxy_url)
|
||
logger.debug(f"使用代理: {proxy_type}://{proxy_host}:{proxy_port}")
|
||
except Exception as e:
|
||
logger.warning(f"代理配置失败,将直连: {e}")
|
||
connector = None
|
||
else:
|
||
logger.warning("代理功能不可用(aiohttp_socks 未安装),将直连")
|
||
connector = None
|
||
|
||
return connector
|
||
|
||
async def _send_dialog_api_request(
    self,
    api_config: dict,
    messages: list,
    tools: list | None = None,
    *,
    request_tag: str = "",
    prefer_stream: bool = True,
    max_tokens: int | None = None,
) -> tuple[str, list]:
    """Send one dialog request to the configured backend and return ``(text, tool_calls)``.

    The wire format is chosen by ``self._resolve_dialog_api_mode(api_config)``:
    OpenAI chat completions, OpenAI responses, legacy OpenAI completions,
    Claude messages or Gemini ``generateContent``. ``messages`` and ``tools`` are
    OpenAI chat-completions style and are converted per mode. Only chat
    completions may stream (SSE), and never for model names containing
    ``gemini-3``; every other mode reads one JSON body.

    Args:
        api_config: Endpoint settings (``url``, ``api_key``, ``model``,
            ``max_tokens``, ``timeout``, optional ``anthropic_version``).
        messages: OpenAI-style chat messages.
        tools: OpenAI-style tool definitions, or ``None``.
        request_tag: Prefix for the debug log line.
        prefer_stream: Request streaming when the mode allows it.
        max_tokens: Overrides ``api_config["max_tokens"]`` (default 4096).

    Returns:
        The reply text and a list of OpenAI-style tool calls
        (``{"id", "type", "function": {"name", "arguments"}}``).

    Raises:
        Exception: for an unsupported API mode, a non-200 HTTP status, a
            network error or a timeout.
    """
    tag = str(request_tag or "").strip()
    mode = self._resolve_dialog_api_mode(api_config)
    model_name = str(api_config.get("model", "")).lower()
    allow_stream = bool(prefer_stream and mode == "openai_chat_completions" and "gemini-3" not in model_name)

    api_url = str(api_config.get("url", "")).strip()
    api_key = str(api_config.get("api_key", "")).strip()
    request_url = api_url
    limit = int(max_tokens if max_tokens is not None else api_config.get("max_tokens", 4096))
    mode_tools = self._convert_tools_for_dialog_api(tools or [], mode)

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }

    if mode == "openai_chat_completions":
        payload = {
            "model": api_config.get("model", ""),
            "messages": messages,
            "max_tokens": limit,
            "stream": allow_stream,
        }
        if mode_tools:
            payload["tools"] = mode_tools
    elif mode == "openai_responses":
        payload = {
            "model": api_config.get("model", ""),
            "input": self._convert_openai_messages_to_responses_input(messages),
            "max_output_tokens": limit,
            "stream": False,
        }
        if mode_tools:
            payload["tools"] = mode_tools
        allow_stream = False
    elif mode == "openai_completions":
        payload = {
            "model": api_config.get("model", ""),
            "prompt": self._convert_openai_messages_to_plain_prompt(messages),
            "max_tokens": limit,
            "stream": False,
        }
        allow_stream = False
    elif mode == "claude_messages":
        system_text, claude_messages = self._convert_openai_messages_to_claude(messages)
        payload = {
            "model": api_config.get("model", ""),
            "messages": claude_messages,
            "max_tokens": limit,
            "stream": False,
        }
        if system_text:
            payload["system"] = system_text
        if mode_tools:
            payload["tools"] = mode_tools
        headers["x-api-key"] = api_key
        headers["anthropic-version"] = str(api_config.get("anthropic_version", "2023-06-01"))
        allow_stream = False
    elif mode == "gemini_generate_content":
        request_url = self._build_gemini_generate_content_url(api_url, api_config.get("model", ""))
        system_text, gemini_contents = self._convert_openai_messages_to_gemini(messages)
        payload = {
            "contents": gemini_contents,
            "generationConfig": {
                "maxOutputTokens": limit,
            },
        }
        if system_text:
            payload["systemInstruction"] = {
                "parts": [{"text": system_text}]
            }
        if mode_tools:
            payload["tools"] = mode_tools
        allow_stream = False
    else:
        raise Exception(f"不支持的 API 模式: {mode}")

    timeout_seconds = int(api_config.get("timeout", 120))
    timeout = aiohttp.ClientTimeout(total=timeout_seconds)
    connector = self._create_proxy_connector()

    logger.debug(f"{tag} 对话API模式: {mode}, stream={allow_stream}, url={request_url}")

    def _merge_tool_call_delta(acc: dict, tool_call_delta: dict) -> None:
        """Fold one streamed ``tool_calls`` delta into the accumulator keyed by index."""
        index = tool_call_delta.get("index", 0)
        if not isinstance(index, int):
            index = 0
        slot = acc.setdefault(index, {
            "id": "",
            "type": "function",
            "function": {"name": "", "arguments": ""},
        })
        # Later chunks may repeat "id"/"type" as empty or null; never let them
        # erase the value received earlier.
        if tool_call_delta.get("id"):
            slot["id"] = tool_call_delta["id"]
        if tool_call_delta.get("type"):
            slot["type"] = tool_call_delta["type"]
        fn_delta = tool_call_delta.get("function")
        if isinstance(fn_delta, dict):
            slot["function"]["name"] += str(fn_delta.get("name") or "")
            slot["function"]["arguments"] += str(fn_delta.get("arguments") or "")

    try:
        async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
            async with session.post(request_url, json=payload, headers=headers) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    raise Exception(f"API 错误 {resp.status}: {error_text}")

                if allow_stream:
                    content_parts: list[str] = []
                    tool_calls_dict: dict[int, dict] = {}

                    async for raw_line in resp.content:
                        line = raw_line.decode("utf-8", errors="replace").strip()
                        # SSE data lines: "data:" followed by an optional space.
                        if not line.startswith("data:"):
                            continue
                        data_str = line[5:].strip()
                        if data_str == "[DONE]":
                            break
                        if not data_str:
                            continue
                        try:
                            chunk = json.loads(data_str)
                        except ValueError:
                            continue
                        if not isinstance(chunk, dict):
                            continue

                        choices = chunk.get("choices")
                        if not isinstance(choices, list) or not choices or not isinstance(choices[0], dict):
                            continue
                        delta = choices[0].get("delta") or {}
                        if not isinstance(delta, dict):
                            continue

                        content_piece = delta.get("content")
                        if isinstance(content_piece, str) and content_piece:
                            content_parts.append(content_piece)

                        for tool_call_delta in delta.get("tool_calls") or []:
                            if isinstance(tool_call_delta, dict):
                                _merge_tool_call_delta(tool_calls_dict, tool_call_delta)

                    tool_calls_data = [tool_calls_dict[i] for i in sorted(tool_calls_dict)]
                    return "".join(content_parts), tool_calls_data

                data = await resp.json(content_type=None)
                full_content, tool_calls_data = self._parse_dialog_api_response(mode, data)
                if not isinstance(tool_calls_data, list):
                    tool_calls_data = []
                return full_content or "", tool_calls_data

    except aiohttp.ClientError as e:
        raise Exception(f"网络请求失败: {type(e).__name__}: {str(e)}") from e
    except asyncio.TimeoutError as e:
        raise Exception(f"API 请求超时 (timeout={timeout_seconds}s)") from e
|
||
|
||
def _extract_after_last_answer_marker(self, text: str) -> str | None:
|
||
"""从文本中抽取最后一个“最终/输出/答案”标记后的内容(不要求必须是编号大纲)。"""
|
||
if not text:
|
||
return None
|
||
|
||
# 1) 明确的行首标记:Text:/Final Answer:/输出: ...
|
||
marker_re = re.compile(
|
||
r"(?im)^\s*(?:\d+\s*[\.\、::\)、))\-–—]\s*)?"
|
||
r"(?:text|final\s*answer|final\s*response|final\s*output|final|output|answer|response|输出|最终回复|最终答案|最终)\s*[::]\s*"
|
||
)
|
||
matches = list(marker_re.finditer(text))
|
||
if matches:
|
||
candidate = text[matches[-1].end():].strip()
|
||
if candidate:
|
||
return candidate
|
||
|
||
# 2) JSON/YAML 风格:final: ... / \"final\": \"...\"
|
||
kv_re = re.compile(
|
||
r"(?im)^\s*\"?(?:final|answer|response|output|text|最终|最终回复|最终答案|输出)\"?\s*[::]\s*"
|
||
)
|
||
kv_matches = list(kv_re.finditer(text))
|
||
if kv_matches:
|
||
candidate = text[kv_matches[-1].end():].strip()
|
||
if candidate:
|
||
return candidate
|
||
|
||
# 3) 纯 JSON 对象(尝试解析)
|
||
stripped = text.strip()
|
||
if stripped.startswith("{") and stripped.endswith("}"):
|
||
try:
|
||
obj = json.loads(stripped)
|
||
if isinstance(obj, dict):
|
||
for key in ("final", "answer", "response", "output", "text"):
|
||
v = obj.get(key)
|
||
if isinstance(v, str) and v.strip():
|
||
return v.strip()
|
||
except Exception:
|
||
pass
|
||
|
||
return None
|
||
|
||
def _extract_final_answer_from_outline(self, text: str) -> str | None:
|
||
"""从“分析/草稿/输出”这类结构化大纲中提取最终回复正文(用于拦截思维链)。"""
|
||
if not text:
|
||
return None
|
||
|
||
# 至少包含多个“1./2./3.”段落,才认为可能是大纲/思维链输出
|
||
heading_re = re.compile(r"(?m)^\s*\d+\s*[\.\、::\)、))\-–—]\s*\S+")
|
||
if len(heading_re.findall(text)) < 2:
|
||
return None
|
||
|
||
# 优先:提取最后一个 “Text:/Final Answer:/Output:” 之后的内容
|
||
marker_re = re.compile(
|
||
r"(?im)^\s*(?:\d+\s*[\.\、::\)、))\-–—]\s*)?"
|
||
r"(?:text|final\s*answer|final\s*response|final\s*output|output|answer|response|输出|最终回复|最终答案)\s*[::]\s*"
|
||
)
|
||
matches = list(marker_re.finditer(text))
|
||
if matches:
|
||
candidate = text[matches[-1].end():].strip()
|
||
if candidate:
|
||
return candidate
|
||
|
||
# 没有明确的最终标记时,仅在包含“分析/思考/草稿/输出”等元信息关键词的情况下兜底抽取
|
||
lowered = text.lower()
|
||
outline_keywords = (
|
||
"analyze",
|
||
"analysis",
|
||
"reasoning",
|
||
"internal monologue",
|
||
"mind space",
|
||
"draft",
|
||
"drafting",
|
||
"outline",
|
||
"plan",
|
||
"steps",
|
||
"formulating response",
|
||
"final polish",
|
||
"final answer",
|
||
"output generation",
|
||
"system prompt",
|
||
"chat log",
|
||
"previous turn",
|
||
"current situation",
|
||
)
|
||
cn_keywords = ("思考", "分析", "推理", "思维链", "草稿", "计划", "步骤", "输出", "最终")
|
||
if not any(k in lowered for k in outline_keywords) and not any(k in text for k in cn_keywords):
|
||
return None
|
||
|
||
# 次选:取最后一个非空段落(避免返回整段大纲)
|
||
parts = [p.strip() for p in re.split(r"\n{2,}", text) if p.strip()]
|
||
if not parts:
|
||
return None
|
||
|
||
last = parts[-1]
|
||
if len(heading_re.findall(last)) == 0:
|
||
return last
|
||
return None
|
||
|
||
def _strip_thinking_content(self, text: str) -> str:
    """Remove visible "thinking / reasoning" content from a model reply.

    Handles explicit tag blocks (e.g. ``<think>...</think>``), an unclosed
    thinking tag near the start of the text, and plain-text layouts such as
    ``思考:... / 最终回复:...`` or numbered analysis outlines.

    Args:
        text: Raw model output.

    Returns:
        The text with thinking content removed; ``""`` for empty input.
    """
    if not text:
        return ""

    # Normalize "\r\n" and lone "\r" line endings to "\n".
    t = text.replace("\r\n", "\n").replace("\r", "\n")

    # 1) First remove explicit tag blocks (emitted by some reasoning models).
    thinking_tags = ("think", "analysis", "reasoning", "thought", "thinking", "thoughts", "scratchpad", "reflection")
    for tag in thinking_tags:
        t = re.sub(rf"<{tag}\b[^>]*>.*?</{tag}>", "", t, flags=re.IGNORECASE | re.DOTALL)
        # Meant to also handle HTML-escaped tags (&lt;think&gt;...&lt;/think&gt;).
        # NOTE(review): as written this pattern uses literal "<" / ">" like the one
        # above (only the attribute class differs: [^&] vs [^>]), so it cannot match
        # escaped tags. The "&lt;" / "&gt;" entities may have been lost when this
        # file was copied; confirm against the original source.
        t = re.sub(rf"<{tag}\b[^&]*>.*?</{tag}>", "", t, flags=re.IGNORECASE | re.DOTALL)

    # 1.1) Fallback for tags left unclosed (streaming / truncation). If an opening
    # thinking tag starts within the first 200 characters, keep only the text
    # before it and drop everything from the tag onward.
    m = re.search(r"<(think|analysis|reasoning|thought|thinking|thoughts|scratchpad|reflection)\b[^>]*>", t, flags=re.IGNORECASE)
    if m and m.start() < 200:
        t = t[: m.start()].rstrip()
    # Same truncation using the "escaped tag" variant (see the NOTE(review) in step 1).
    m2 = re.search(r"<(think|analysis|reasoning|thought|thinking|thoughts|scratchpad|reflection)\b[^&]*>", t, flags=re.IGNORECASE)
    if m2 and m2.start() < 200:
        t = t[: m2.start()].rstrip()

    # 2) Then handle plain-text "思考:... / 最终:..." layouts, trying to strip
    # only the leading reasoning section.
    lines = t.split("\n")
    # str.split always returns at least one element, so this guard never fires.
    if not lines:
        return t

    # If self._contains_thinking_markers (defined elsewhere) reports thinking
    # markers, return whatever follows the last "final / output / answer" marker
    # anywhere in the text (numbered or not), so a whole outline is not sent out.
    if self._contains_thinking_markers(t):
        extracted_anywhere = self._extract_after_last_answer_marker(t)
        if extracted_anywhere:
            return extracted_anywhere

    # Keywords that open a reasoning section (Chinese and English).
    reasoning_kw = (
        r"思考过程|推理过程|分析过程|思考|分析|推理|思路|内心独白|内心os|思维链|"
        r"chain\s*of\s*thought|reasoning|analysis|analyze|thinking|thoughts|thought\s*process|scratchpad|plan|steps|draft|outline"
    )
    # Keywords that open the final-answer section.
    answer_kw = r"最终答案|最终回复|最终|回答|回复|答复|结论|输出|final(?:\s*answer)?|final\s*response|final\s*output|answer|response|output|text"

    # Supported layouts:
    # - 思考:... / 最终回复:...
    # - 【思考】... / 【最终】...
    # - **思考过程:** (per the original note, Markdown is stripped by an outer layer first)
    # Both patterns allow an optional number ("1." / "2、" ...) and bullet prefix.
    # A reasoning keyword must be inside 【】 or be followed by a colon, end of
    # line, or whitespace.
    reasoning_start = re.compile(
        rf"^\s*(?:\d+\s*[\.\、::\)、))\-–—]\s*)?(?:[-*•]+\s*)?"
        rf"(?:【\s*(?:{reasoning_kw})\s*】\s*[::]?\s*|(?:{reasoning_kw})(?:\s*】)?\s*(?:[::]|$|\s+))",
        re.IGNORECASE,
    )
    # An answer keyword must be inside 【】 or be followed by a colon or end of line.
    answer_start = re.compile(
        rf"^\s*(?:\d+\s*[\.\、::\)、))\-–—]\s*)?(?:[-*•]+\s*)?"
        rf"(?:【\s*(?:{answer_kw})\s*】\s*[::]?\s*|(?:{answer_kw})(?:\s*】)?\s*(?:[::]|$)\s*)",
        re.IGNORECASE,
    )

    # 2.0) If the first non-empty line starts with an answer marker such as
    # "最终回复:" / "Final answer:", drop the marker. No reasoning block is
    # required, and only that first non-empty line is examined.
    for idx, line in enumerate(lines):
        if line.strip() == "":
            continue
        m0 = answer_start.match(line)
        if m0:
            lines[idx] = line[m0.end():].lstrip()
        break

    # Reasoning markers are only looked for in the first 10 lines; answer markers anywhere.
    has_reasoning = any(reasoning_start.match(line) for line in lines[:10])
    has_answer_marker = any(answer_start.match(line) for line in lines)

    # 2.1) With both a reasoning block and an answer marker, drop lines from the
    # first reasoning marker up to the next answer marker. Lines before the
    # reasoning block and everything from the answer onward are kept, with the
    # answer marker itself removed.
    if has_reasoning and has_answer_marker:
        out_lines: list[str] = []
        skipping = False
        answer_started = False
        for line in lines:
            if answer_started:
                out_lines.append(line)
                continue

            if not skipping and reasoning_start.match(line):
                skipping = True
                continue

            if skipping:
                m = answer_start.match(line)
                if m:
                    answer_started = True
                    skipping = False
                    out_lines.append(line[m.end():].lstrip())
                # Lines inside the reasoning block are dropped.
                continue

            m = answer_start.match(line)
            if m:
                answer_started = True
                out_lines.append(line[m.end():].lstrip())
            else:
                out_lines.append(line)

        t2 = "\n".join(out_lines).strip()
        # `t` here is the text after step 1, i.e. without the step-2.0 marker removal.
        return t2 if t2 else t

    # 2.2) Fallback: a reasoning marker but no answer marker. Drop everything up to
    # the first blank line (assumed to end the reasoning paragraph), if anything remains.
    if has_reasoning:
        first_blank_idx = None
        for idx, line in enumerate(lines):
            if line.strip() == "":
                first_blank_idx = idx
                break
        if first_blank_idx is not None and first_blank_idx + 1 < len(lines):
            candidate = "\n".join(lines[first_blank_idx + 1 :]).strip()
            if candidate:
                return candidate

    # 2.3) Fallback: detect "1. Analyze... 2. ... 6. Output ... Text: ..." style
    # chain-of-thought outlines and extract the final text.
    outline_extracted = self._extract_final_answer_from_outline("\n".join(lines).strip())
    if outline_extracted:
        return outline_extracted

    # Merge the line-level result back into text (e.g. with a leading "最终回复:" removed).
    t = "\n".join(lines).strip()

    # 3) Unwrap <final>...</final> / <answer>...</answer> (keep the content, drop the tags).
    t = re.sub(r"</?\s*(final|answer)\s*>", "", t, flags=re.IGNORECASE).strip()

    return t
|
||
|
||
def _strip_markdown_syntax(self, text: str) -> str:
|
||
"""将常见 Markdown 标记转换为更像纯文本的形式(保留内容,移除格式符)。"""
|
||
if not text:
|
||
return ""
|
||
|
||
t = text.replace("\r\n", "\n").replace("\r", "\n")
|
||
|
||
# 去掉代码块围栏(保留内容)
|
||
t = re.sub(r"```[^\n]*\n", "", t)
|
||
t = t.replace("```", "")
|
||
|
||
# 图片/链接: / [text](url)
|
||
def _md_image_repl(m: re.Match) -> str:
|
||
alt = (m.group(1) or "").strip()
|
||
url = (m.group(2) or "").strip()
|
||
if alt and url:
|
||
return f"{alt}({url})"
|
||
return url or alt or ""
|
||
|
||
def _md_link_repl(m: re.Match) -> str:
|
||
label = (m.group(1) or "").strip()
|
||
url = (m.group(2) or "").strip()
|
||
if label and url:
|
||
return f"{label}({url})"
|
||
return url or label or ""
|
||
|
||
t = re.sub(r"!\[([^\]]*)\]\(([^)]+)\)", _md_image_repl, t)
|
||
t = re.sub(r"\[([^\]]+)\]\(([^)]+)\)", _md_link_repl, t)
|
||
|
||
# 行级标记:标题、引用、分割线
|
||
cleaned_lines: list[str] = []
|
||
for line in t.split("\n"):
|
||
line = re.sub(r"^\s{0,3}#{1,6}\s+", "", line) # 标题
|
||
line = re.sub(r"^\s{0,3}>\s?", "", line) # 引用
|
||
if re.match(r"^\s*(?:-{3,}|\*{3,}|_{3,})\s*$", line):
|
||
continue # 分割线整行移除
|
||
cleaned_lines.append(line)
|
||
t = "\n".join(cleaned_lines)
|
||
|
||
# 行内代码:`code`
|
||
t = re.sub(r"`([^`]+)`", r"\1", t)
|
||
|
||
# 粗体/删除线(保留文本)
|
||
t = t.replace("**", "")
|
||
t = t.replace("__", "")
|
||
t = t.replace("~~", "")
|
||
|
||
# 斜体(保留文本,避免误伤乘法/通配符,仅处理成对包裹)
|
||
t = re.sub(r"(?<!\*)\*([^*\n]+)\*(?!\*)", r"\1", t)
|
||
t = re.sub(r"(?<!_)_([^_\n]+)_(?!_)", r"\1", t)
|
||
|
||
# 压缩过多空行
|
||
t = re.sub(r"\n{3,}", "\n\n", t)
|
||
return t.strip()
|
||
|
||
def _parse_nickname_parts(self, nickname: str) -> tuple:
|
||
"""
|
||
解析昵称字符串,提取群昵称和微信昵称
|
||
|
||
输入格式可能是:
|
||
- "群昵称=xxx | 微信昵称=yyy"
|
||
- "群昵称=xxx"
|
||
- "微信昵称=yyy"
|
||
- "普通昵称"
|
||
|
||
Returns:
|
||
(group_nickname, wechat_name)
|
||
"""
|
||
if not nickname:
|
||
return ("", "")
|
||
|
||
group_nickname = ""
|
||
wechat_name = ""
|
||
|
||
if "群昵称=" in nickname or "微信昵称=" in nickname:
|
||
parts = nickname.split("|")
|
||
for part in parts:
|
||
part = part.strip()
|
||
if part.startswith("群昵称="):
|
||
group_nickname = part[4:].strip()
|
||
elif part.startswith("微信昵称="):
|
||
wechat_name = part[5:].strip()
|
||
else:
|
||
# 普通昵称,当作微信昵称
|
||
wechat_name = nickname.strip()
|
||
|
||
return (group_nickname, wechat_name)
|
||
|
||
def _format_timestamp(self, timestamp) -> str:
|
||
"""格式化时间戳为可读字符串"""
|
||
if not timestamp:
|
||
return ""
|
||
|
||
try:
|
||
if isinstance(timestamp, (int, float)):
|
||
from datetime import datetime
|
||
dt = datetime.fromtimestamp(timestamp)
|
||
return dt.strftime("%Y-%m-%d %H:%M")
|
||
elif isinstance(timestamp, str):
|
||
# 尝试解析 ISO 格式
|
||
if "T" in timestamp:
|
||
dt_str = timestamp.split(".")[0] # 去掉毫秒
|
||
from datetime import datetime
|
||
dt = datetime.fromisoformat(dt_str)
|
||
return dt.strftime("%Y-%m-%d %H:%M")
|
||
return timestamp[:16] if len(timestamp) > 16 else timestamp
|
||
except Exception:
|
||
pass
|
||
return ""
|
||
|
||
def _format_user_message_content(self, nickname: str, content: str, timestamp=None, msg_type: str = "text", user_id: str = None) -> str:
    """Prefix a user message with bracketed metadata tags for the LLM.

    Output shape::

        [时间:...][用户ID:...][群昵称:...][微信昵称:...][类型:...]
        <content>

    Tags with an empty value are omitted. ``user_id`` is shortened to its last
    6 characters so the model can tell speakers apart. When no tag remains,
    ``content`` is returned unchanged.
    """
    group_nickname, wechat_name = self._parse_nickname_parts(nickname)
    time_str = self._format_timestamp(timestamp)
    short_id = user_id[-6:] if user_id else ""

    labelled_values = (
        ("时间:", time_str),
        ("用户ID:", short_id),
        ("群昵称:", group_nickname),
        ("微信昵称:", wechat_name),
        ("类型:", msg_type),
    )
    header = "".join(f"[{label}{value}]" for label, value in labelled_values if value)

    if not header:
        return content
    return f"{header}\n{content}"
|
||
|
||
def _append_group_history_messages(self, messages: list, recent_history: list):
    """Append stored group-chat history entries to ``messages`` as LLM chat turns.

    Args:
        messages: Chat-completions message list, extended in place.
        recent_history: History entries (dicts read via ``role``, ``nickname``,
            ``content``, ``timestamp``, ``wxid``). ``content`` is a string or a
            list of OpenAI-style content parts. Non-dict entries are skipped.

    Assistant entries are flattened to text, passed through
    ``self._sanitize_llm_output`` and tagged with their time. User entries get
    the metadata prefix from ``self._format_user_message_content``. For list
    content, all text parts are joined and image parts are kept.
    """
    for msg in recent_history:
        if not isinstance(msg, dict):
            continue
        role = msg.get("role") or "user"
        msg_nickname = msg.get("nickname", "")
        msg_content = msg.get("content")
        if msg_content is None:
            # A stored null would otherwise be rendered as the text "None".
            msg_content = ""
        msg_timestamp = msg.get("timestamp")
        msg_wxid = msg.get("wxid", "")  # user's unique id

        # Earlier bot replies.
        if role == "assistant":
            if isinstance(msg_content, list):
                msg_content = self._extract_text_from_multimodal(msg_content)
            # Keep Markdown / leaked reasoning in old history from polluting the context.
            msg_content = self._sanitize_llm_output(msg_content)
            time_str = self._format_timestamp(msg_timestamp)
            if time_str:
                msg_content = f"[时间:{time_str}][类型:assistant]\n{msg_content}"
            messages.append({"role": "assistant", "content": msg_content})
            continue

        # User history messages.
        if isinstance(msg_content, list):
            # Multimodal message (may contain images).
            text_parts: list[str] = []
            image_parts: list[dict] = []
            for item in msg_content:
                if not isinstance(item, dict):
                    continue
                item_type = item.get("type")
                if item_type == "text":
                    item_text = item.get("text") or ""
                    if item_text:
                        text_parts.append(str(item_text))
                elif item_type == "image_url":
                    image_parts.append(item)

            formatted_text = self._format_user_message_content(
                msg_nickname,
                "\n".join(text_parts),
                msg_timestamp,
                "image" if image_parts else "text",
                msg_wxid,
            )
            messages.append({
                "role": "user",
                "content": [{"type": "text", "text": formatted_text}, *image_parts],
            })
        else:
            # Plain text message.
            formatted_content = self._format_user_message_content(
                msg_nickname, msg_content, msg_timestamp, "text", msg_wxid
            )
            messages.append({
                "role": "user",
                "content": formatted_content,
            })
|
||
|
||
def _get_bot_nickname(self) -> str:
|
||
try:
|
||
with open("main_config.toml", "rb") as f:
|
||
main_config = tomllib.load(f)
|
||
nickname = main_config.get("Bot", {}).get("nickname", "")
|
||
return nickname or "机器人"
|
||
except Exception:
|
||
return "机器人"
|
||
|
||
def _tool_call_to_action_text(self, function_name: str, arguments: dict) -> str:
|
||
args = arguments if isinstance(arguments, dict) else {}
|
||
|
||
if function_name == "user_signin":
|
||
return "签到"
|
||
|
||
if function_name == "check_profile":
|
||
return "查询个人信息"
|
||
|
||
return f"执行{function_name}"
|
||
|
||
def _build_tool_calls_context_note(self, tool_calls_data: list) -> str:
    """Build the note stored in the chat context after tools were triggered.

    Each call with a non-empty function name is turned into a phrase via
    ``self._tool_call_to_action_text``. Its JSON ``arguments`` string is
    decoded, or ``{}`` if missing or invalid. Phrases are joined with "；".
    When no call has a name, a generic note is returned instead.
    """
    def _decode_arguments(raw) -> dict:
        try:
            return json.loads(raw) if raw else {}
        except Exception:
            return {}

    phrases: list[str] = []
    for call in tool_calls_data or []:
        function_spec = call.get("function", {})
        name = function_spec.get("name", "")
        if name:
            parsed = _decode_arguments(function_spec.get("arguments", "{}"))
            phrases.append(self._tool_call_to_action_text(name, parsed))

    if not phrases:
        return "(已触发工具处理:上一条请求。结果将发送到聊天中。)"
    return f"(已触发工具处理:{';'.join(phrases)}。结果将发送到聊天中。)"
|
||
|
||
async def _record_tool_calls_to_context(
    self,
    tool_calls_data: list,
    *,
    from_wxid: str,
    chat_id: str,
    is_group: bool,
    user_wxid: str | None = None,
):
    """Record a note about triggered tool calls as an assistant turn.

    The note from ``self._build_tool_calls_context_note`` is added to memory
    for ``chat_id`` (when set). For group chats with a ``from_wxid`` it is
    also written to the group history as an assistant message from the bot.
    """
    note = self._build_tool_calls_context_note(tool_calls_data)

    if chat_id:
        self._add_to_memory(chat_id, "assistant", note)

    if not (is_group and from_wxid):
        return

    sender = user_wxid or ""
    await self._add_to_history(
        self._get_group_history_chat_id(from_wxid, sender),
        self._get_bot_nickname(),
        note,
        role="assistant",
        sender_wxid=sender or None,
    )
|
||
|
||
def _extract_tool_intent_text(self, user_message: str, tool_query: str | None = None) -> str:
    """Return the text used to judge which tool (if any) the user wants.

    Uses ``tool_query`` when given, otherwise ``user_message``. For composite
    messages (e.g. forwarded chat logs or videos followed by a question) only
    the part after the last ``[用户的问题]`` marker is kept, so quoted history
    does not trigger tools by mistake. The result goes through
    ``self._normalize_search_query``; if that yields nothing, the
    un-normalized text is returned.
    """
    source = user_message if tool_query is None else tool_query
    text = str(source or "").strip()
    if not text:
        return ""

    question_markers = (
        "[用户的问题]:",
        "[用户的问题]:",
        "[用户的问题]\n",
        "[用户的问题]",
    )
    # Every marker variant is applied in turn (no early exit).
    for marker in question_markers:
        if marker in text:
            text = text.rpartition(marker)[2].strip()

    return self._normalize_search_query(text) or text
|
||
|
||
def _normalize_search_query(self, text: str) -> str:
|
||
"""清洗搜索类工具的查询参数,去掉元信息/触发词"""
|
||
cleaned = str(text or "").strip()
|
||
if not cleaned:
|
||
return ""
|
||
|
||
cleaned = cleaned.replace("【当前消息】", "").strip()
|
||
|
||
if cleaned.startswith("[") and "\n" in cleaned:
|
||
first_line, rest = cleaned.split("\n", 1)
|
||
if any(token in first_line for token in ("时间:", "用户ID:", "群昵称:", "微信昵称:", "类型:")):
|
||
cleaned = rest.strip()
|
||
|
||
if cleaned.startswith("@"):
|
||
parts = cleaned.split(maxsplit=1)
|
||
if len(parts) > 1:
|
||
cleaned = parts[1].strip()
|
||
|
||
cleaned = re.sub(r"^(搜索|搜|查|查询|帮我搜|帮我查|帮我搜索|请搜索|请查)\s*", "", cleaned)
|
||
return cleaned.strip()
|
||
|
||
def _looks_like_info_query(self, text: str) -> bool:
|
||
t = str(text or "").strip().lower()
|
||
if not t:
|
||
return False
|
||
|
||
# 太短的消息不值得额外走一轮分类
|
||
if len(t) < 6:
|
||
return False
|
||
|
||
# 疑问/求评价/求推荐类
|
||
if any(x in t for x in ("?", "?")):
|
||
return True
|
||
if re.search(r"(什么|咋|怎么|如何|为啥|为什么|哪|哪里|哪个|多少|推荐|值不值得|值得吗|好不好|靠谱吗|评价|口碑|怎么样|如何评价|近况|最新|最近)", t):
|
||
return True
|
||
if re.search(r"\\b(what|who|when|where|why|how|details?|impact|latest|news|review|rating|price|info|information)\\b", t):
|
||
return True
|
||
|
||
# 明确的实体/对象询问(公会/游戏/公司/项目等)
|
||
if re.search(r"(公会|战队|服务器|区服|游戏|公司|品牌|产品|软件|插件|项目|平台|up主|主播|作者|电影|电视剧|小说)", t) and len(t) >= 8:
|
||
return True
|
||
|
||
return False
|
||
|
||
def _looks_like_lyrics_query(self, text: str) -> bool:
|
||
t = str(text or "").strip().lower()
|
||
if not t:
|
||
return False
|
||
return bool(re.search(
|
||
r"(歌词|歌名|哪首歌|哪一首歌|哪首歌曲|哪一首歌曲|谁的歌|谁唱|谁唱的|这句.*(歌|歌词)|这段.*(歌|歌词)|是什么歌|什么歌|是哪首歌|出自.*(歌|歌曲)|台词.*歌|lyric|lyrics)",
|
||
t,
|
||
))
|
||
|
||
def _looks_like_image_generation_request(self, text: str) -> bool:
|
||
"""判断是否是明确的生图/自拍请求。"""
|
||
t = str(text or "").strip().lower()
|
||
if not t:
|
||
return False
|
||
|
||
if re.search(r"(画一张|画张|画一幅|画幅|画一个|画个|画一下|画图|绘图|绘制|作画|出图|生成图片|生成照片|生成相片|文生图|图生图|以图生图)", t):
|
||
return True
|
||
if re.search(r"(生成|做|给我|帮我).{0,4}(一张|一幅|一个|张|个).{0,8}(图|图片|照片|自拍|自拍照|自画像)", t):
|
||
return True
|
||
if re.search(r"(来|发).{0,2}(一张|一幅|一个|张|个).{0,10}(图|图片|照片|自拍|自拍照|自画像)", t):
|
||
return True
|
||
if re.search(r"(来|发|给我|给|看看|看下|看一看).{0,4}(自拍|自拍照|自画像)", t):
|
||
return True
|
||
if re.search(r"(看看|看下|看一看|来点|来张|发|给我).{0,4}(腿|白丝|黑丝|丝袜|福利|福利图|色图|涩图|写真)", t):
|
||
return True
|
||
if re.search(r"(白丝|黑丝|丝袜|福利|福利图|色图|涩图|写真).{0,6}(图|图片|照片|自拍|来一张|来点|发一张)", t):
|
||
return True
|
||
if re.search(r"(看看腿|看腿|来点福利|来张福利|发点福利|来张白丝|来张黑丝)", t):
|
||
return True
|
||
|
||
return False
|
||
|
||
def _extract_legacy_text_search_tool_call(self, text: str) -> tuple[str, dict] | None:
|
||
"""
|
||
解析模型偶发输出的“文本工具调用”写法(例如 tavilywebsearch{query:...}),并转换为真实工具调用参数。
|
||
"""
|
||
raw = str(text or "")
|
||
if not raw:
|
||
return None
|
||
|
||
# 去掉 <ctrl46> 之类的控制标记
|
||
cleaned = re.sub(r"<ctrl\d+>", "", raw, flags=re.IGNORECASE)
|
||
|
||
m = re.search(
|
||
r"(?i)\b(?P<tool>tavilywebsearch|tavily_web_search|web_search)\s*\{\s*query\s*[:=]\s*(?P<q>[^{}]{1,800})\}",
|
||
cleaned,
|
||
)
|
||
if not m:
|
||
m = re.search(
|
||
r"(?i)\b(?P<tool>tavilywebsearch|tavily_web_search|web_search)\s*\(\s*query\s*[:=]\s*(?P<q>[^\)]{1,800})\)",
|
||
cleaned,
|
||
)
|
||
if not m:
|
||
return None
|
||
|
||
tool = str(m.group("tool") or "").strip().lower()
|
||
query = str(m.group("q") or "").strip().strip("\"'`")
|
||
if not query:
|
||
return None
|
||
|
||
# 统一映射到项目实际存在的工具名
|
||
if tool in ("tavilywebsearch", "tavily_web_search"):
|
||
tool_name = "tavily_web_search"
|
||
else:
|
||
tool_name = "web_search"
|
||
|
||
return tool_name, {"query": query[:400]}
|
||
|
||
def _extract_legacy_text_image_tool_call(self, text: str) -> tuple[str, dict] | None:
    """Recover an image-generation call the model wrote as text instead of a tool call.

    Tried in order:
      1. A Python-style call, e.g. ``draw_image("...")`` optionally wrapped in ``print(...)``.
      2. JSON candidates: fenced ```json``` blocks and an inline object starting
         with a ``"name"`` / ``"tool"`` / ``"action"`` key. The first one that
         yields a tool name and a string prompt wins.
      3. A lenient regex search for ``"name"`` and ``"prompt"`` values
         (tolerates single quotes / non-strict JSON). Defaults the name to
         ``"draw_image"``.

    The returned name is whatever the model wrote and is not validated here.
    NOTE(review): presumably callers map it to a real tool (for example via
    ``_resolve_image_tool_alias``); confirm.

    Returns:
        ``(tool_name, args)`` where ``args`` always has a stripped ``"prompt"``,
        or ``None`` if nothing usable is found.
    """
    raw = str(text or "")
    if not raw:
        return None

    # Python-call style: print(draw_image("...")) / draw_image("...")
    py_call = re.search(
        r"(?is)(?:print\s*\(\s*)?"
        r"(draw_image|generate_image|nano_ai_image_generation|flow2_ai_image_generation|"
        r"jimeng_ai_image_generation|kiira2_ai_image_generation)\s*"
        r"\(\s*([\"'])([\s\S]{1,2000}?)\2\s*\)\s*\)?",
        raw,
    )
    if py_call:
        name = py_call.group(1).strip()
        prompt = py_call.group(3).strip()
        if prompt:
            return name, {"prompt": prompt}

    # Collect JSON candidates: fenced code blocks first, then one inline object.
    candidates = []
    for m in re.finditer(r"```(?:json)?\s*({[\s\S]{20,2000}})\s*```", raw, flags=re.IGNORECASE):
        candidates.append(m.group(1))

    m = re.search(r"(\{\s*\"(?:name|tool|action)\"\s*:\s*\"[^\"]+\"[\s\S]{0,2000}\})", raw)
    if m:
        candidates.append(m.group(1))

    for blob in candidates:
        try:
            data = json.loads(blob)
        except Exception:
            continue

        if not isinstance(data, dict):
            continue

        name = str(
            data.get("name")
            or data.get("tool")
            or data.get("action")
            or data.get("Action")
            or ""
        ).strip()
        if not name:
            continue

        # Arguments may live under "arguments" or one of several action-input keys.
        args = data.get("arguments", None)
        if args in (None, "", {}):
            args = (
                data.get("actioninput")
                or data.get("action_input")
                or data.get("actionInput")
                or data.get("input")
                or {}
            )
        if isinstance(args, str):
            try:
                args = json.loads(args)
            except Exception:
                raw_args = str(args).strip()
                parsed_args = None

                # Some models emit actioninput as a JSON-like string with broken
                # escaping; try a best-effort repair.
                try:
                    parsed_args = json.loads(raw_args.replace('\\"', '"'))
                except Exception:
                    pass

                # Still not a dict: regex out "prompt" (and an optional aspect ratio).
                if not isinstance(parsed_args, dict):
                    prompt_match = re.search(
                        r"(?i)[\"']?prompt[\"']?\s*[:=]\s*[\"']([\s\S]{1,2000}?)[\"']",
                        raw_args,
                    )
                    if prompt_match:
                        parsed_args = {"prompt": prompt_match.group(1).strip()}
                        ratio_match = re.search(
                            r"(?i)[\"']?(?:aspectratio|aspect_ratio)[\"']?\s*[:=]\s*[\"']([^\"']{1,30})[\"']",
                            raw_args,
                        )
                        if ratio_match:
                            parsed_args["aspectratio"] = ratio_match.group(1).strip()

                # Last resort: treat the whole string as the prompt.
                args = parsed_args if isinstance(parsed_args, dict) else {"prompt": raw_args}

        if not isinstance(args, dict):
            continue

        # Accept several synonymous keys for the prompt text.
        prompt = args.get("prompt") or args.get("text") or args.get("query") or args.get("description")
        if not prompt or not isinstance(prompt, str):
            continue

        normalized_args = dict(args)
        normalized_args["prompt"] = prompt.strip()
        return name, normalized_args

    # Last resort: regex-extract name/prompt (tolerates single quotes / non-strict JSON).
    name_match = re.search(
        r"(?i)[\"'](?:name|tool|action)[\"']\s*:\s*[\"']([^\"']+)[\"']",
        raw,
    )
    prompt_match = re.search(
        r"(?i)[\"']prompt[\"']\s*:\s*[\"']([\s\S]{1,2000}?)[\"']",
        raw,
    )
    if prompt_match:
        name = name_match.group(1) if name_match else "draw_image"
        prompt = prompt_match.group(1).strip()
        if prompt:
            return name, {"prompt": prompt}

    return None
|
||
|
||
def _resolve_image_tool_alias(
|
||
self,
|
||
requested_name: str,
|
||
allowed_tool_names: set[str],
|
||
available_tool_names: set[str],
|
||
loose_image_tool: bool,
|
||
) -> str | None:
|
||
"""将模型输出的绘图工具别名映射为实际工具名。"""
|
||
name = (requested_name or "").strip().lower()
|
||
if not name:
|
||
return None
|
||
|
||
# 严格遵守本轮工具选择结果:本轮未开放绘图工具时,不允许任何文本兜底触发
|
||
if not allowed_tool_names:
|
||
return None
|
||
|
||
if name in available_tool_names:
|
||
if name in allowed_tool_names or loose_image_tool:
|
||
return name
|
||
return None
|
||
|
||
alias_map = {
|
||
"draw_image": "nano_ai_image_generation",
|
||
"image_generation": "nano_ai_image_generation",
|
||
"image_generate": "nano_ai_image_generation",
|
||
"make_image": "nano_ai_image_generation",
|
||
"create_image": "nano_ai_image_generation",
|
||
"generate_image": "generate_image",
|
||
"nanoaiimage_generation": "nano_ai_image_generation",
|
||
"flow2aiimage_generation": "flow2_ai_image_generation",
|
||
"jimengaiimage_generation": "jimeng_ai_image_generation",
|
||
"kiira2aiimage_generation": "kiira2_ai_image_generation",
|
||
}
|
||
mapped = alias_map.get(name)
|
||
if mapped and mapped in available_tool_names:
|
||
if mapped in allowed_tool_names or loose_image_tool:
|
||
return mapped
|
||
|
||
for fallback in ("nano_ai_image_generation", "generate_image", "flow2_ai_image_generation"):
|
||
if fallback in available_tool_names and (fallback in allowed_tool_names or loose_image_tool):
|
||
return fallback
|
||
|
||
return None
|
||
|
||
def _should_allow_music_followup(self, messages: list, tool_calls_data: list) -> bool:
    """Decide whether a music follow-up is allowed after this round of tool calls.

    True only when the model called a web-search tool (``tavily_web_search``
    or ``web_search``) and the most recent user message, once flattened to
    text, looks like a lyrics query.

    Args:
        messages: Chat messages (dicts with ``role``/``content``).
        tool_calls_data: Tool calls returned by the model this round.
    """
    search_tool_names = ("tavily_web_search", "web_search")
    if not tool_calls_data:
        return False

    called_search = False
    for call in tool_calls_data:
        function_name = (call or {}).get("function", {}).get("name", "")
        if function_name in search_tool_names:
            called_search = True
            break
    if not called_search:
        return False

    # Walk backwards to find the latest user turn.
    latest_user = next(
        (entry for entry in reversed(messages or []) if entry.get("role") == "user"),
        None,
    )
    if latest_user is None:
        return False

    user_text = self._extract_text_from_multimodal(latest_user.get("content"))
    return bool(user_text) and self._looks_like_lyrics_query(user_text)
|
||
|
||
async def _select_tools_for_message_async(self, tools: list, *, user_message: str, tool_query: str | None = None) -> list:
|
||
"""工具选择(与旧版一致,仅使用规则筛选)"""
|
||
return self._select_tools_for_message(tools, user_message=user_message, tool_query=tool_query)
|
||
|
||
def _select_tools_for_message(self, tools: list, *, user_message: str, tool_query: str | None = None) -> list:
|
||
tools_config = (self.config or {}).get("tools", {})
|
||
if not tools_config.get("smart_select", False):
|
||
return tools
|
||
|
||
raw_intent_text = str(tool_query if tool_query is not None else user_message).strip()
|
||
raw_t = raw_intent_text.lower()
|
||
intent_text = self._extract_tool_intent_text(user_message, tool_query=tool_query)
|
||
if not intent_text:
|
||
return tools
|
||
|
||
t = intent_text.lower()
|
||
allow: set[str] = set()
|
||
available_tool_names = {
|
||
(tool or {}).get("function", {}).get("name", "")
|
||
for tool in (tools or [])
|
||
if (tool or {}).get("function", {}).get("name")
|
||
}
|
||
|
||
# 显式搜索意图硬兜底:只要本轮可用工具里有搜索工具,就强制放行
|
||
# 注意:显式搜索意图必须基于“原始文本”判断,不能只用清洗后的 intent_text
|
||
# 否则“搜索下 xxx”会被清洗成“xxx”,导致误判为无搜索意图
|
||
raw_has_url = bool(re.search(r"(https?://|www\.)", raw_intent_text, flags=re.IGNORECASE))
|
||
explicit_read_web_intent = bool(re.search(
|
||
r"((阅读|读一下|读下|看下|看看|解析|总结|介绍).{0,8}(网页|网站|网址|链接))"
|
||
r"|((网页|网站|网址|链接).{0,8}(内容|正文|页面|信息|原文))",
|
||
raw_t,
|
||
))
|
||
if raw_has_url and re.search(r"(阅读|读一下|读下|看下|看看|解析|总结|介绍|提取)", raw_t):
|
||
explicit_read_web_intent = True
|
||
|
||
explicit_search_intent = bool(re.search(
|
||
r"(联网|搜索|搜一下|搜一搜|搜搜|搜索下|搜下|查一下|查资料|查新闻|查价格|帮我搜|帮我查)",
|
||
raw_t,
|
||
)) or explicit_read_web_intent
|
||
if explicit_search_intent:
|
||
for candidate in ("tavily_web_search", "web_search"):
|
||
if candidate in available_tool_names:
|
||
allow.add(candidate)
|
||
|
||
# 签到/个人信息
|
||
if re.search(r"(用户签到|签到|签个到)", t):
|
||
allow.add("user_signin")
|
||
if re.search(r"(个人信息|我的信息|我的积分|查积分|积分多少|连续签到|连签|我的资料)", t):
|
||
allow.add("check_profile")
|
||
|
||
# 鹿打卡
|
||
if re.search(r"(鹿打卡|鹿签到)", t):
|
||
allow.add("deer_checkin")
|
||
if re.search(r"(补签|补打卡)", t):
|
||
allow.add("makeup_checkin")
|
||
if re.search(r"(鹿.*(日历|月历|打卡日历))|((日历|月历|打卡日历).*鹿)", t):
|
||
allow.add("view_calendar")
|
||
|
||
# 搜索/资讯
|
||
if re.search(r"(联网|搜索|搜一下|搜一搜|搜搜|帮我搜|搜新闻|搜资料|查资料|查新闻|查价格|\bsearch\b|\bgoogle\b|\blookup\b|\bfind\b|\bnews\b|\blatest\b|\bdetails?\b|\bimpact\b)", t):
|
||
# 兼容旧工具名与当前插件实现
|
||
allow.add("tavily_web_search")
|
||
allow.add("web_search")
|
||
# 隐式信息检索:用户询问具体实体/口碑/评价但未明确说“搜索/联网”
|
||
if re.search(r"(怎么样|如何|评价|口碑|靠谱吗|值不值得|值得吗|好不好|推荐|牛不牛|强不强|厉不厉害|有名吗|什么来头|背景|近况|最新|最近)", t) and re.search(
|
||
r"(公会|战队|服务器|区服|游戏|公司|品牌|店|商家|产品|软件|插件|项目|平台|up主|主播|作者|电影|电视剧|小说|手游|网游)",
|
||
t,
|
||
):
|
||
allow.add("tavily_web_search")
|
||
allow.add("web_search")
|
||
if self._looks_like_lyrics_query(intent_text):
|
||
allow.add("tavily_web_search")
|
||
allow.add("web_search")
|
||
if re.search(r"(60秒|每日新闻|早报|新闻图片|读懂世界)", t):
|
||
allow.add("get_daily_news")
|
||
if re.search(r"(epic|喜加一|免费游戏)", t):
|
||
allow.add("get_epic_free_games")
|
||
|
||
# 音乐/短剧
|
||
# 仅在明确“点歌/播放/听一首/搜歌”等命令时开放,避免普通聊天误触
|
||
if re.search(r"(点歌|来(?:一首|首)|播放(?:一首|首)?|放歌|听(?:一首|首)|搜歌|找歌)", t):
|
||
allow.add("search_music")
|
||
if re.search(r"(短剧|搜短剧|找短剧)", t):
|
||
allow.add("search_playlet")
|
||
|
||
# 群聊总结
|
||
if re.search(r"(群聊总结|生成总结|总结一下|今日总结|昨天总结|群总结)", t):
|
||
allow.add("generate_summary")
|
||
|
||
# 娱乐
|
||
if re.search(r"(疯狂星期四|v我50|kfc)", t):
|
||
allow.add("get_kfc")
|
||
if re.search(r"(随机图片|来张图|来个图|随机图)", t):
|
||
allow.add("get_random_image")
|
||
if re.search(r"(随机视频|来个视频|随机短视频)", t):
|
||
allow.add("get_random_video")
|
||
|
||
# 绘图/视频生成(只在用户明确要求时开放)
|
||
if self._looks_like_image_generation_request(intent_text) or (
|
||
# 明确绘图动词/模式
|
||
re.search(r"(画一张|画张|画一幅|画幅|画一个|画个|画一下|画图|绘图|绘制|作画|出图|生成图片|生成照片|生成相片|文生图|图生图|以图生图)", t)
|
||
# “生成/做/给我”+“一张/一个/张/个”+“图/图片”类表达(例如:生成一张瑞依/做一张图)
|
||
or re.search(r"(生成|做|给我|帮我).{0,4}(一张|一幅|一个|张|个).{0,8}(图|图片|照片|自拍|自拍照|自画像)", t)
|
||
# “来/发”+“一张/张”+“图/图片”(例如:来张瑞依的图)
|
||
or re.search(r"(来|发).{0,2}(一张|一幅|一个|张|个).{0,10}(图|图片|照片|自拍|自拍照|自画像)", t)
|
||
# “发/来/给我”+“自拍/自画像”(例如:发张自拍/来个自画像)
|
||
or re.search(r"(来|发|给我|给).{0,3}(自拍|自拍照|自画像)", t)
|
||
# 口语化“看看腿/白丝/福利”等请求
|
||
or re.search(r"(看看|看下|看一看|来点|来张|发|给我).{0,4}(腿|白丝|黑丝|丝袜|福利|福利图|色图|涩图|写真)", t)
|
||
or re.search(r"(白丝|黑丝|丝袜|福利|福利图|色图|涩图|写真).{0,6}(图|图片|照片|自拍|来一张|来点|发一张)", t)
|
||
or re.search(r"(看看腿|看腿|来点福利|来张福利|发点福利|来张白丝|来张黑丝)", t)
|
||
# 二次重绘/返工(上下文里常省略“图/图片”)
|
||
or re.search(r"(重画|重新画|再画|重来一张|再来一张|重做一张)", t)
|
||
or re.fullmatch(r"(重来|再来|重来一次|再来一次|重新来)", t)
|
||
):
|
||
allow.update({
|
||
"nano_ai_image_generation",
|
||
"flow2_ai_image_generation",
|
||
"jimeng_ai_image_generation",
|
||
"kiira2_ai_image_generation",
|
||
"generate_image",
|
||
})
|
||
if re.search(
|
||
r"(生成视频|做个视频|视频生成|sora|grok|/视频)"
|
||
r"|((生成|制作|做|来|发|拍|整).{0,10}(视频|短视频|短片|片子|mv|vlog))"
|
||
r"|((视频|短视频|短片|片子|mv|vlog).{0,8}(生成|制作|做|来|发|整|安排))"
|
||
r"|(来一段.{0,8}(视频|短视频|短片))",
|
||
t,
|
||
):
|
||
allow.add("sora_video_generation")
|
||
allow.add("grok_video_generation")
|
||
|
||
# 如果已经命中特定领域工具(音乐/短剧等),且用户未明确表示“联网/网页/链接/来源”等需求,避免把联网搜索也暴露出去造成误触
|
||
explicit_web = bool(re.search(r"(联网|网页|网站|网址|链接|来源)", t))
|
||
if not explicit_web and {"search_music", "search_playlet"} & allow:
|
||
allow.discard("tavily_web_search")
|
||
allow.discard("web_search")
|
||
|
||
# 严格模式:没有明显工具意图时,不向模型暴露任何 tools,避免误触
|
||
if not allow:
|
||
return []
|
||
|
||
selected = []
|
||
for tool in tools or []:
|
||
name = tool.get("function", {}).get("name", "")
|
||
if name and name in allow:
|
||
selected.append(tool)
|
||
|
||
if explicit_search_intent:
|
||
selected_names = [tool.get("function", {}).get("name", "") for tool in selected]
|
||
logger.info(
|
||
f"[工具选择-搜索兜底] raw={raw_intent_text[:80]} | cleaned={intent_text[:80]} "
|
||
f"| allow={sorted(list(allow))} | selected={selected_names}"
|
||
)
|
||
|
||
return selected
|
||
|
||
async def _handle_context_stats(self, bot, from_wxid: str, user_wxid: str, is_group: bool):
    """Reply with an estimated token-usage breakdown of the current conversation.

    Group chats are measured from the stored group history (only the newest
    ``history.max_context`` entries, the part sent to the AI). Private chats
    are measured from the in-memory message list. Both include the system
    prompt and persistent memories. Any failure is logged and reported back
    to the chat rather than raised.

    Args:
        bot: Bot client used to send the reply (needs ``send_text``).
        from_wxid: Conversation to reply to (group id or user id).
        user_wxid: The requesting user's wxid.
        is_group: Whether the conversation is a group chat.
    """
    try:
        chat_id = self._get_chat_id(from_wxid, user_wxid, is_group)

        # Persistent memories are keyed by group id in groups and by user id in private chats.
        memory_chat_id = from_wxid if is_group else user_wxid
        persistent_memories = self._get_persistent_memories(memory_chat_id) if memory_chat_id else []
        persistent_tokens = 0
        if persistent_memories:
            # Same text that _call_ai_api appends to the system prompt, so the estimate matches.
            persistent_tokens += self._estimate_tokens("【持久记忆】以下是用户要求你记住的重要信息:\n")
            for m in persistent_memories:
                mem_time = m['time'][:10] if m['time'] else ""
                persistent_tokens += self._estimate_tokens(f"- [{mem_time}] {m['nickname']}: {m['content']}\n")

        if is_group:
            # Group chats use the stored history.
            history_chat_id = self._get_group_history_chat_id(from_wxid, user_wxid)
            history = await self._load_history(history_chat_id)
            history = self._filter_history_by_window(history)
            max_context = self.config.get("history", {}).get("max_context", 50)
            # Only the newest max_context entries are actually sent to the AI.
            context_messages = history[-max_context:] if len(history) > max_context else history

            context_tokens = 0
            for entry in context_messages:
                entry_content = entry.get("content", "")
                speaker = entry.get("nickname", "")
                if isinstance(entry_content, list):
                    # Multimodal entry: count text parts plus a flat estimate per image.
                    for item in entry_content:
                        if not isinstance(item, dict):
                            # Malformed part; skip it instead of failing the whole report.
                            continue
                        if item.get("type") == "text":
                            context_tokens += self._estimate_tokens(f"[{speaker}] {item.get('text', '')}")
                        elif item.get("type") == "image_url":
                            context_tokens += 85
                else:
                    context_tokens += self._estimate_tokens(f"[{speaker}] {entry_content}")

            title = "📊 群聊上下文统计"
            count_lines = [
                f"💬 历史总条数: {len(history)}",
                f"📤 AI可见条数: {len(context_messages)}/{max_context}",
            ]
            extra_lines = []
            # Vector long-term memory statistics (optional).
            if self._vector_memory_enabled and self._chroma_collection:
                try:
                    vm_count = self._chroma_collection.count()
                    vm_watermark_ts = self._vector_watermarks.get(from_wxid, "")
                    ts_display = vm_watermark_ts[:16] if vm_watermark_ts else "无"
                    extra_lines.append(f"🧠 向量记忆: {vm_count} 条摘要 (水位线: {ts_display})")
                except Exception:
                    # Best effort: vector stats must never break the report.
                    pass
        else:
            # Private chats use the in-memory message list.
            memory_messages = self._get_memory_messages(chat_id)
            max_messages = self.config.get("memory", {}).get("max_messages", 20)
            context_tokens = sum(self._estimate_message_tokens(m) for m in memory_messages)
            title = "📊 私聊上下文统计"
            count_lines = [f"💬 记忆条数: {len(memory_messages)}/{max_messages}"]
            extra_lines = []

        system_tokens = self._estimate_tokens(self.system_prompt)
        total_tokens = system_tokens + persistent_tokens + context_tokens

        context_limit = self.config.get("api", {}).get("context_limit", 200000)
        if not isinstance(context_limit, (int, float)) or context_limit <= 0:
            # A zero/negative/non-numeric limit would break the percentage; use the default.
            context_limit = 200000
        usage_percent = (total_tokens / context_limit) * 100
        remaining_tokens = context_limit - total_tokens

        lines = [
            *count_lines,
            f"🤖 人设 Token: ~{system_tokens}",
            f"📌 持久记忆: {len(persistent_memories)} 条 (~{persistent_tokens} token)",
            f"📝 上下文 Token: ~{context_tokens}",
            f"📦 总计 Token: ~{total_tokens}",
            f"📈 使用率: {usage_percent:.1f}% (剩余 ~{remaining_tokens:,})",
            *extra_lines,
        ]
        msg = f"{title}\n\n" + "".join(f"{line}\n" for line in lines)
        msg += "\n💡 /清空记忆 清空上下文 | /记忆列表 查看持久记忆"

        await bot.send_text(from_wxid, msg)
        logger.info(f"已发送上下文统计: {chat_id}")

    except Exception as e:
        logger.error(f"获取上下文统计失败: {e}")
        await bot.send_text(from_wxid, f"❌ 获取上下文统计失败: {str(e)}")
|
||
|
||
async def _handle_switch_prompt(self, bot, from_wxid: str, content: str):
    """Switch the active persona (system prompt) to a file in the ``prompts`` directory.

    ``content`` is the full command text, e.g. ``/切人设 name.txt``. Everything
    after the first whitespace is the file name, resolved relative to the
    ``prompts`` folder next to this module. On success the file text replaces
    ``self.system_prompt``, and ``config["prompt"]["system_prompt_file"]`` is
    updated in the in-memory config (nothing is written to disk here).
    Problems are reported to the chat, not raised.

    Args:
        bot: Bot client used for replies (needs ``send_text``).
        from_wxid: Conversation to reply to.
        content: The raw command text.
    """
    try:
        # Extract the file name after the command word.
        parts = content.split(maxsplit=1)
        if len(parts) < 2:
            await bot.send_text(from_wxid, "❌ 请指定人设文件名\n格式:/切人设 文件名.txt")
            return

        filename = parts[1].strip()

        # Only relative names inside prompts/ are allowed. Absolute paths or ".."
        # segments would let the command load arbitrary files (e.g. main_config.toml)
        # as the system prompt.
        requested = Path(filename)
        if requested.anchor or ".." in requested.parts:
            await bot.send_text(from_wxid, f"❌ 非法的人设文件名: {filename}")
            return

        prompt_path = Path(__file__).parent / "prompts" / requested
        # is_file() rather than exists(): a directory name must be rejected, not crash in open().
        if not prompt_path.is_file():
            await bot.send_text(from_wxid, f"❌ 人设文件不存在: {filename}")
            return

        new_prompt = prompt_path.read_text(encoding="utf-8").strip()

        self.system_prompt = new_prompt
        # setdefault: a config without a [prompt] section must not fail after the swap succeeded.
        self.config.setdefault("prompt", {})["system_prompt_file"] = filename

        await bot.send_text(from_wxid, f"✅ 已切换人设: {filename}")
        logger.success(f"管理员切换人设: {filename}")

    except Exception as e:
        logger.error(f"切换人设失败: {e}")
        await bot.send_text(from_wxid, f"❌ 切换人设失败: {str(e)}")
|
||
|
||
@on_text_message(priority=80)
async def handle_message(self, bot, message: dict):
    """Handle an incoming text message.

    Slash commands are checked first (persona list/switch, nickname test,
    memory clear/status, context stats, legacy-history maintenance,
    persistent and vector memory). Each handled command returns False.
    Otherwise the message may be recorded into group history. If the bot is
    triggered (see ``_should_reply``), an AI reply is generated with retries,
    sent, and stored in memory/history.

    NOTE(review): returns False after commands/rate limiting and None on the
    other paths. Presumably the plugin framework treats False as "stop
    propagating to other plugins"; confirm against the decorator/dispatcher.
    """
    content = message.get("Content", "").strip()
    from_wxid = message.get("FromWxid", "")
    sender_wxid = message.get("SenderWxid", "")
    is_group = message.get("IsGroup", False)

    # The actual sender: SenderWxid in groups, FromWxid in private chats.
    user_wxid = sender_wxid if is_group else from_wxid

    # Bot wxid/nickname and admin list come from main_config.toml, re-read on every message.
    # NOTE(review): the local import shadows the identical module-level tomllib import.
    import tomllib
    with open("main_config.toml", "rb") as f:
        main_config = tomllib.load(f)
    bot_wxid = main_config.get("Bot", {}).get("wxid", "")
    bot_nickname = main_config.get("Bot", {}).get("nickname", "")
    admins = main_config.get("Bot", {}).get("admins", [])

    # In groups, drop a leading "@bot" so commands match exactly.
    command_content = content
    if is_group and bot_nickname:
        command_content = self._strip_leading_bot_mention(content, bot_nickname)

    # Persona list command (exact match).
    if command_content == "/人设列表":
        await self._handle_list_prompts(bot, from_wxid)
        return False

    # Nickname test: reports the global WeChat nickname and the in-group display name.
    if command_content == "/昵称测试":
        if not is_group:
            await bot.send_text(from_wxid, "该指令仅支持群聊:/昵称测试")
            return False

        wechat_nickname = await self._get_user_nickname(bot, from_wxid, user_wxid, is_group)
        group_nickname = await self._get_group_display_name(bot, from_wxid, user_wxid, force_refresh=True)

        wechat_nickname = self._sanitize_speaker_name(wechat_nickname) or "(未获取到)"
        group_nickname = self._sanitize_speaker_name(group_nickname) or "(未设置/未获取到)"

        await bot.send_text(
            from_wxid,
            f"微信昵称: {wechat_nickname}\n"
            f"群昵称: {group_nickname}",
        )
        return False

    # Persona switch command (prefix match, admins only).
    if command_content.startswith("/切人设 ") or command_content.startswith("/切换人设 "):
        if user_wxid in admins:
            await self._handle_switch_prompt(bot, from_wxid, command_content)
        else:
            await bot.send_text(from_wxid, "❌ 仅管理员可以切换人设")
        return False

    # Clear-memory command (configurable command text).
    clear_command = self.config.get("memory", {}).get("clear_command", "/清空记忆")
    if command_content == clear_command:
        chat_id = self._get_chat_id(from_wxid, user_wxid, is_group)
        self._clear_memory(chat_id)

        # In groups, the stored group history is cleared too.
        if is_group and self.store:
            history_chat_id = self._get_group_history_chat_id(from_wxid, user_wxid)
            await self.store.clear_group_history(history_chat_id)
            # Reset the vector-summary watermark for this group.
            if self._vector_memory_enabled and from_wxid in self._vector_watermarks:
                self._vector_watermarks.pop(from_wxid, None)
                self._save_watermarks()
            await bot.send_text(from_wxid, "✅ 已清空当前群聊的记忆和历史记录")
        else:
            await bot.send_text(from_wxid, "✅ 已清空当前会话的记忆")
        return False

    # Context statistics command.
    if command_content == "/context" or command_content == "/上下文":
        await self._handle_context_stats(bot, from_wxid, user_wxid, is_group)
        return False

    # Scan for legacy group-history keys (admins only).
    if command_content in ("/旧群历史", "/legacy_history"):
        if user_wxid in admins and self.store:
            legacy_keys = self.store.find_legacy_group_history_keys()
            if legacy_keys:
                await bot.send_text(
                    from_wxid,
                    f"⚠️ 检测到 {len(legacy_keys)} 个旧版群历史 key(safe_id 写入)。\n"
                    f"如需清理请发送 /清理旧群历史",
                )
            else:
                await bot.send_text(from_wxid, "✅ 未发现旧版群历史 key")
        else:
            await bot.send_text(from_wxid, "❌ 仅管理员可执行该指令")
        return False

    # Delete legacy group-history keys (admins only).
    if command_content in ("/清理旧群历史", "/clean_legacy_history"):
        if user_wxid in admins and self.store:
            legacy_keys = self.store.find_legacy_group_history_keys()
            deleted = self.store.delete_legacy_group_history_keys(legacy_keys)
            await bot.send_text(
                from_wxid,
                f"✅ 已清理旧版群历史 key: {deleted} 个",
            )
        else:
            await bot.send_text(from_wxid, "❌ 仅管理员可执行该指令")
        return False

    # Memory status command (admins only).
    if command_content == "/记忆状态":
        if user_wxid in admins:
            if is_group:
                history_chat_id = self._get_group_history_chat_id(from_wxid, user_wxid)
                history = await self._load_history(history_chat_id)
                history = self._filter_history_by_window(history)
                max_context = self.config.get("history", {}).get("max_context", 50)
                context_count = min(len(history), max_context)
                msg = f"📊 群聊记忆: {len(history)} 条\n"
                msg += f"💬 AI可见: 最近 {context_count} 条"
                await bot.send_text(from_wxid, msg)
            else:
                chat_id = self._get_chat_id(from_wxid, user_wxid, is_group)
                memory = self._get_memory_messages(chat_id)
                msg = f"📊 私聊记忆: {len(memory)} 条"
                await bot.send_text(from_wxid, msg)
        else:
            await bot.send_text(from_wxid, "❌ 仅管理员可以查看记忆状态")
        return False

    # Persistent-memory commands.
    # Record a persistent memory: "/记录 xxx" (the prefix is 4 characters).
    if command_content.startswith("/记录 "):
        memory_content = command_content[4:].strip()
        if memory_content:
            nickname = await self._get_user_display_label(bot, from_wxid, user_wxid, is_group)
            # Keyed by group id in groups, by user id in private chats.
            memory_chat_id = from_wxid if is_group else user_wxid
            chat_type = "group" if is_group else "private"
            memory_id = self._add_persistent_memory(
                memory_chat_id, chat_type, user_wxid, nickname, memory_content
            )
            await bot.send_text(from_wxid, f"✅ 已记录到持久记忆 (ID: {memory_id})")
            logger.info(f"添加持久记忆: {memory_chat_id} - {memory_content[:30]}...")
        else:
            await bot.send_text(from_wxid, "❌ 请输入要记录的内容\n格式:/记录 要记住的内容")
        return False

    # List persistent memories (available to everyone).
    if command_content == "/记忆列表" or command_content == "/持久记忆":
        memory_chat_id = from_wxid if is_group else user_wxid
        memories = self._get_persistent_memories(memory_chat_id)
        if memories:
            msg = f"📋 持久记忆列表 (共 {len(memories)} 条)\n\n"
            for m in memories:
                time_str = m['time'][:16] if m['time'] else "未知"
                content_preview = m['content'][:30] + "..." if len(m['content']) > 30 else m['content']
                msg += f"[{m['id']}] {m['nickname']}: {content_preview}\n 📅 {time_str}\n"
            msg += f"\n💡 删除记忆:/删除记忆 ID (管理员)"
        else:
            msg = "📋 暂无持久记忆"
        await bot.send_text(from_wxid, msg)
        return False

    # Delete one persistent memory (admins only); the prefix "/删除记忆 " is 6 characters.
    if command_content.startswith("/删除记忆 "):
        if user_wxid in admins:
            try:
                memory_id = int(command_content[6:].strip())
                memory_chat_id = from_wxid if is_group else user_wxid
                if self._delete_persistent_memory(memory_chat_id, memory_id):
                    await bot.send_text(from_wxid, f"✅ 已删除记忆 ID: {memory_id}")
                else:
                    await bot.send_text(from_wxid, f"❌ 未找到记忆 ID: {memory_id}")
            except ValueError:
                await bot.send_text(from_wxid, "❌ 请输入有效的记忆ID\n格式:/删除记忆 ID")
        else:
            await bot.send_text(from_wxid, "❌ 仅管理员可以删除持久记忆")
        return False

    # Clear all persistent memories (admins only).
    if command_content == "/清空持久记忆":
        if user_wxid in admins:
            memory_chat_id = from_wxid if is_group else user_wxid
            deleted_count = self._clear_persistent_memories(memory_chat_id)
            await bot.send_text(from_wxid, f"✅ 已清空 {deleted_count} 条持久记忆")
        else:
            await bot.send_text(from_wxid, "❌ 仅管理员可以清空持久记忆")
        return False

    # Show vector long-term memory (group chats only).
    if command_content == "/向量记忆" or command_content == "/vector_memory":
        if not is_group:
            await bot.send_text(from_wxid, "❌ 向量记忆仅在群聊中可用")
            return False
        if not self._vector_memory_enabled:
            await bot.send_text(from_wxid, "❌ 向量记忆功能未启用")
            return False
        items = self._get_vector_memories_for_display(from_wxid)
        if not items:
            await bot.send_text(from_wxid, "📭 当前群聊暂无向量记忆")
            return False
        try:
            html = self._build_vector_memory_html(items, from_wxid)
            img_path = await self._render_vector_memory_image(html)
            if img_path:
                await bot.send_image(from_wxid, img_path)
                # Remove the temporary image file (best effort).
                try:
                    Path(img_path).unlink(missing_ok=True)
                except Exception:
                    pass
            else:
                # Rendering failed: fall back to a plain-text listing.
                msg = f"🧠 向量记忆 (共 {len(items)} 条摘要)\n\n"
                for i, item in enumerate(items, 1):
                    preview = item['summary'][:80] + "..." if len(item['summary']) > 80 else item['summary']
                    msg += f"#{i} {preview}\n\n"
                await bot.send_text(from_wxid, msg.strip())
        except Exception as e:
            logger.error(f"[VectorMemory] 展示失败: {e}")
            await bot.send_text(from_wxid, f"❌ 向量记忆展示失败: {e}")
        return False

    # Not a command: decide whether the AI should reply.
    should_reply = self._should_reply(message, content, bot_wxid)

    # Speaker label used for history entries (cached lookup).
    nickname = await self._get_user_display_label(bot, from_wxid, user_wxid, is_group)

    # Message text without the leading @mention; only computed when replying.
    actual_content = ""
    if should_reply:
        actual_content = self._extract_content(message, content)

    # Save to group history. By default every message is kept; config can limit this to
    # messages that trigger the AI, to reduce context noise. Messages re-dispatched by
    # AutoReply are skipped because the normal flow already saved them.
    if is_group and not message.get('_auto_reply_triggered'):
        if self._should_capture_group_history(is_triggered=bool(should_reply)):
            # In "mention" mode the @bot is only a trigger; store the stripped text so the same
            # sentence does not appear in context in two forms (with and without the @).
            trigger_mode = self.config.get("behavior", {}).get("trigger_mode", "mention")
            history_content = content
            if trigger_mode == "mention" and should_reply and actual_content:
                history_content = actual_content

            history_chat_id = self._get_group_history_chat_id(from_wxid, user_wxid)
            await self._add_to_history(history_chat_id, nickname, history_content, sender_wxid=user_wxid)

            # Vector long-term memory: check whether a summarization pass is due.
            if self._vector_memory_enabled:
                await self._maybe_trigger_summarize(from_wxid)

    # Nothing more to do when the bot was not triggered.
    if not should_reply:
        return

    # Rate limiting (only checked when a reply would be sent).
    allowed, remaining, reset_time = self._check_rate_limit(user_wxid)
    if not allowed:
        rate_limit_config = self.config.get("rate_limit", {})
        msg = rate_limit_config.get("rate_limit_message", "⚠️ 消息太频繁了,请 {seconds} 秒后再试~")
        msg = msg.format(seconds=reset_time)
        await bot.send_text(from_wxid, msg)
        logger.warning(f"用户 {user_wxid} 触发限流,{reset_time}秒后重置")
        return False

    if not actual_content:
        return

    chat_id = self._get_chat_id(from_wxid, user_wxid, is_group)

    # Serialize replies per chat so concurrent messages do not interleave.
    async with self._reply_lock_context(chat_id):
        logger.info(f"AI 处理消息: {actual_content[:50]}...")

        try:
            # AutoReply-triggered messages were already added to memory by the normal flow.
            if not message.get('_auto_reply_triggered'):
                self._add_to_memory(chat_id, "user", actual_content)

            # Group chat: if the message was just written to history, the history already
            # carries it, so don't append it again to the LLM messages (avoid sending it twice).
            history_enabled = bool(self.store) and self.config.get("history", {}).get("enabled", True)
            captured_to_history = bool(
                is_group
                and history_enabled
                and not message.get('_auto_reply_triggered')
                and self._should_capture_group_history(is_triggered=True)
            )
            append_user_message = not captured_to_history
            # Tools are disabled for AutoReply-driven or explicitly tool-less requests.
            disable_tools = bool(
                message.get("_auto_reply_triggered")
                or message.get("_auto_reply_context")
                or message.get("_disable_tools")
            )

            # Call the AI API with a simple retry loop.
            max_retries = self.config.get("api", {}).get("max_retries", 2)
            response = None
            last_error = None

            for attempt in range(max_retries + 1):
                try:
                    response = await self._call_ai_api(
                        actual_content,
                        bot,
                        from_wxid,
                        chat_id,
                        nickname,
                        user_wxid,
                        is_group,
                        append_user_message=append_user_message,
                        disable_tools=disable_tools,
                    )

                    # Return value contract:
                    # - None: a tool call was dispatched and handled asynchronously; do not retry
                    # - "":   a genuinely empty response; retry
                    # - text: a normal response
                    if response is None:
                        # Tool call: no retry.
                        logger.info("AI 触发工具调用,已异步处理")
                        break

                    if response == "" and attempt < max_retries:
                        logger.warning(f"AI 返回空内容,重试 {attempt + 1}/{max_retries}")
                        await asyncio.sleep(1)  # wait one second before retrying
                        continue

                    break  # success, or retries exhausted

                except Exception as e:
                    last_error = e
                    if attempt < max_retries:
                        logger.warning(f"AI API 调用失败,重试 {attempt + 1}/{max_retries}: {e}")
                        await asyncio.sleep(1)
                    else:
                        raise

            # Send the reply and store it in memory.
            # None or "" means the result was already delivered another way; send no text then.
            if response:
                cleaned_response = self._sanitize_llm_output(response)
                if cleaned_response:
                    await bot.send_text(from_wxid, cleaned_response)
                    await self._maybe_send_voice_reply(bot, from_wxid, cleaned_response, message=message)
                    self._add_to_memory(chat_id, "assistant", cleaned_response)
                    # Save the bot's reply to group history, unless a hook that syncs bot messages
                    # already does so (only possible for chatroom-scoped history).
                    history_config = self.config.get("history", {})
                    sync_bot_messages = history_config.get("sync_bot_messages", False)
                    history_scope = str(history_config.get("scope", "chatroom") or "chatroom").strip().lower()
                    can_rely_on_hook = bool(sync_bot_messages and history_scope not in ("per_user", "user", "peruser"))
                    if is_group and not can_rely_on_hook:
                        with open("main_config.toml", "rb") as f:
                            main_config = tomllib.load(f)
                        bot_nickname = main_config.get("Bot", {}).get("nickname", "机器人")
                        history_chat_id = self._get_group_history_chat_id(from_wxid, user_wxid)
                        await self._add_to_history(
                            history_chat_id,
                            bot_nickname,
                            cleaned_response,
                            role="assistant",
                            sender_wxid=user_wxid,
                        )
                    logger.success(f"AI 回复成功: {cleaned_response[:50]}...")
                else:
                    logger.warning("AI 回复清洗后为空(可能只包含思维链/格式标记),已跳过发送")
            else:
                logger.info("AI 回复为空或已通过其他方式发送(如聊天记录)")

        except Exception as e:
            import traceback
            error_detail = traceback.format_exc()
            logger.error(f"AI 处理失败: {type(e).__name__}: {str(e)}")
            logger.error(f"详细错误:\n{error_detail}")
            await bot.send_text(from_wxid, "抱歉,我遇到了一些问题,请稍后再试。")
|
||
|
||
def _should_reply(self, message: dict, content: str, bot_wxid: str = None) -> bool:
    """Decide whether the bot should answer this message.

    Rules, in order:
    - messages re-dispatched by the AutoReply plugin always get a reply;
    - ``behavior.reply_group`` / ``behavior.reply_private`` switch group/private replies off;
    - ``trigger_mode == "all"`` replies to everything;
    - ``"mention"`` replies to every private message, and to group messages
      whose ``Ats`` list contains the bot wxid or whose text contains ``@<bot nickname>``;
    - ``"keyword"`` replies when any non-empty configured keyword occurs in ``content``.

    Args:
        message: Raw message dict (``FromWxid``, ``IsGroup``, ``Ats``, ...).
        content: Message text.
        bot_wxid: Bot wxid; read from ``main_config.toml`` when not given.
    """
    from_wxid = message.get("FromWxid", "")
    logger.debug(f"[AIChat] _should_reply 检查: from={from_wxid}, content={content[:30]}")

    # Messages re-dispatched by the AutoReply plugin.
    if message.get('_auto_reply_triggered'):
        logger.debug(f"[AIChat] AutoReply 触发,返回 True")
        return True

    is_group = message.get("IsGroup", False)
    # .get() so a config without a [behavior] section falls back to the defaults below.
    behavior = self.config.get("behavior", {})

    # Group / private reply switches.
    if is_group and not behavior.get("reply_group", True):
        logger.debug(f"[AIChat] 群聊回复未启用,返回 False")
        return False
    if not is_group and not behavior.get("reply_private", True):
        return False

    trigger_mode = behavior.get("trigger_mode", "mention")

    if trigger_mode == "all":
        return True

    if trigger_mode == "mention":
        if not is_group:
            # Private chats are always "addressed" to the bot.
            return True

        ats = message.get("Ats", [])
        # Read main_config.toml once: the nickname is always needed for the text
        # fallback, the wxid only when the caller did not supply one.
        with open("main_config.toml", "rb") as f:
            bot_config = tomllib.load(f).get("Bot", {})
        if not bot_wxid:
            bot_wxid = bot_config.get("wxid", "")
        bot_nickname = bot_config.get("nickname", "")

        # 1) The @-list contains the bot's wxid.
        if ats and bot_wxid and bot_wxid in ats:
            return True

        # 2) Fallback when the API returns no at_user_list: "@<nickname>" in the text.
        if bot_nickname and f"@{bot_nickname}" in content:
            logger.debug(f"通过内容检测到 @{bot_nickname},触发回复")
            return True

        return False

    if trigger_mode == "keyword":
        keywords = behavior.get("keywords", [])
        # Skip empty keywords: "" is a substring of every message.
        return any(kw and kw in content for kw in keywords)

    return False
|
||
|
||
def _extract_content(self, message: dict, content: str) -> str:
    """Return the user's actual text, without a leading ``@mention`` token in group chats.

    Group messages usually look like ``"@nickname message body"``. The first
    whitespace-separated token is dropped only when it starts with ``@`` and
    more text follows it. Otherwise, and for private chats, the whole text is
    returned stripped.
    """
    stripped = content.strip()
    if not message.get("IsGroup", False):
        return stripped

    pieces = content.split(maxsplit=1)
    has_leading_mention = len(pieces) == 2 and pieces[0].startswith("@")
    return pieces[1].strip() if has_leading_mention else stripped
|
||
|
||
def _strip_leading_bot_mention(self, content: str, bot_nickname: str) -> str:
    """Remove a leading ``@<bot_nickname>`` so commands can be matched exactly.

    ``content`` is returned unchanged when there is no nickname, when it does
    not start with the mention, or when the mention is immediately followed
    by a non-whitespace character. That last case covers "@Botanist", which
    only shares a prefix with "@Bot". The exact prefix is sliced off instead
    of splitting on whitespace, so nicknames that contain spaces work. A
    message consisting only of the mention yields ``""``.
    """
    if not bot_nickname:
        return content
    prefix = f"@{bot_nickname}"
    if not content.startswith(prefix):
        return content

    remainder = content[len(prefix):]
    if remainder and not remainder[0].isspace():
        # Not a real mention boundary (WeChat puts a space such as U+2005 after mentions).
        return content
    return remainder.strip()
|
||
|
||
async def _call_ai_api(
|
||
self,
|
||
user_message: str,
|
||
bot=None,
|
||
from_wxid: str = None,
|
||
chat_id: str = None,
|
||
nickname: str = "",
|
||
user_wxid: str = None,
|
||
is_group: bool = False,
|
||
*,
|
||
append_user_message: bool = True,
|
||
tool_query: str | None = None,
|
||
disable_tools: bool = False,
|
||
) -> str:
|
||
"""调用 AI API"""
|
||
api_config = self.config["api"]
|
||
|
||
# 收集工具
|
||
if disable_tools:
|
||
all_tools = []
|
||
available_tool_names = set()
|
||
tools = []
|
||
logger.info("AutoReply 模式:已禁用工具调用")
|
||
else:
|
||
all_tools = self._collect_tools()
|
||
available_tool_names = {
|
||
t.get("function", {}).get("name", "")
|
||
for t in (all_tools or [])
|
||
if isinstance(t, dict) and t.get("function", {}).get("name")
|
||
}
|
||
selected_tools = await self._select_tools_for_message_async(all_tools, user_message=user_message, tool_query=tool_query)
|
||
tools = self._prepare_tools_for_llm(selected_tools)
|
||
logger.info(f"收集到 {len(all_tools)} 个工具函数,本次启用 {len(tools)} 个")
|
||
if tools:
|
||
tool_names = [t["function"]["name"] for t in tools]
|
||
logger.info(f"本次启用工具: {tool_names}")
|
||
|
||
# 构建消息列表
|
||
system_content = self.system_prompt
|
||
|
||
# 添加当前时间信息
|
||
current_time = datetime.now()
|
||
weekday_map = {
|
||
0: "星期一", 1: "星期二", 2: "星期三", 3: "星期四",
|
||
4: "星期五", 5: "星期六", 6: "星期日"
|
||
}
|
||
weekday = weekday_map[current_time.weekday()]
|
||
time_str = current_time.strftime(f"%Y年%m月%d日 %H:%M:%S {weekday}")
|
||
system_content += f"\n\n当前时间:{time_str}"
|
||
|
||
if nickname:
|
||
system_content += f"\n当前对话用户的昵称是:{nickname}"
|
||
if self._tool_rule_prompt_enabled:
|
||
system_content += self._build_tool_rules_prompt(tools)
|
||
|
||
# 加载持久记忆
|
||
memory_chat_id = from_wxid if is_group else user_wxid
|
||
if memory_chat_id:
|
||
persistent_memories = self._get_persistent_memories(memory_chat_id)
|
||
if persistent_memories:
|
||
system_content += "\n\n【持久记忆】以下是用户要求你记住的重要信息:\n"
|
||
for m in persistent_memories:
|
||
mem_time = m['time'][:10] if m['time'] else ""
|
||
system_content += f"- [{mem_time}] {m['nickname']}: {m['content']}\n"
|
||
|
||
# 向量长期记忆检索
|
||
if is_group and from_wxid and self._vector_memory_enabled:
|
||
vector_mem = await self._retrieve_vector_memories(from_wxid, user_message)
|
||
if vector_mem:
|
||
system_content += vector_mem
|
||
|
||
messages = [{"role": "system", "content": system_content}]
|
||
|
||
# 从 JSON 历史记录加载上下文(仅群聊)
|
||
if is_group and from_wxid:
|
||
history_chat_id = self._get_group_history_chat_id(from_wxid, user_wxid or "")
|
||
history = await self._load_history(history_chat_id)
|
||
history = self._filter_history_by_window(history)
|
||
max_context = self.config.get("history", {}).get("max_context", 50)
|
||
|
||
# 取最近的 N 条消息作为上下文
|
||
recent_history = history[-max_context:] if len(history) > max_context else history
|
||
|
||
# 转换为 AI 消息格式(按 role)
|
||
self._append_group_history_messages(messages, recent_history)
|
||
else:
|
||
# 私聊使用原有的 memory 机制
|
||
if chat_id:
|
||
memory_messages = self._get_memory_messages(chat_id)
|
||
if memory_messages and len(memory_messages) > 1:
|
||
messages.extend(memory_messages[:-1])
|
||
|
||
# 添加当前用户消息
|
||
if append_user_message:
|
||
current_marker = "【当前消息】"
|
||
if is_group and nickname:
|
||
# 群聊使用结构化格式,当前消息使用当前时间
|
||
current_time = datetime.now().strftime("%Y-%m-%d %H:%M")
|
||
formatted_content = self._format_user_message_content(nickname, user_message, current_time, "text")
|
||
formatted_content = f"{current_marker}\n{formatted_content}"
|
||
messages.append({"role": "user", "content": formatted_content})
|
||
else:
|
||
messages.append({"role": "user", "content": f"{current_marker}\n{user_message}"})
|
||
|
||
async def _finalize_response(full_content: str, tool_calls_data: list):
|
||
# 过滤掉模型“幻觉出来”的工具调用(未在本次请求提供 tools 的情况下不应执行)
|
||
allowed_tool_names = {
|
||
t.get("function", {}).get("name", "")
|
||
for t in (tools or [])
|
||
if isinstance(t, dict) and t.get("function", {}).get("name")
|
||
}
|
||
if tool_calls_data:
|
||
unsupported = []
|
||
filtered = []
|
||
for tc in tool_calls_data:
|
||
fn = (tc or {}).get("function", {}).get("name", "")
|
||
if not fn:
|
||
continue
|
||
if not allowed_tool_names or fn not in allowed_tool_names:
|
||
unsupported.append(fn)
|
||
continue
|
||
filtered.append(tc)
|
||
if unsupported:
|
||
logger.warning(f"检测到未提供/未知的工具调用,已忽略: {unsupported}")
|
||
tool_calls_data = filtered
|
||
|
||
# 兼容:模型偶发输出“文本工具调用”写法(不走 tool_calls),尝试转成真实工具调用
|
||
if not tool_calls_data and full_content:
|
||
legacy = self._extract_legacy_text_search_tool_call(full_content)
|
||
if legacy:
|
||
legacy_tool, legacy_args = legacy
|
||
# 兼容:有的模型会用旧名字/文本格式输出搜索工具调用
|
||
# 1) 优先映射到“本次提供给模型的工具”(尊重 smart_select)
|
||
# 2) 若本次未提供搜索工具但用户确实在问信息类问题,可降级启用全局可用的搜索工具(仅限搜索)
|
||
preferred = None
|
||
if legacy_tool in allowed_tool_names:
|
||
preferred = legacy_tool
|
||
elif "tavily_web_search" in allowed_tool_names:
|
||
preferred = "tavily_web_search"
|
||
elif "web_search" in allowed_tool_names:
|
||
preferred = "web_search"
|
||
elif self._looks_like_info_query(user_message):
|
||
if "tavily_web_search" in available_tool_names:
|
||
preferred = "tavily_web_search"
|
||
elif "web_search" in available_tool_names:
|
||
preferred = "web_search"
|
||
|
||
if preferred:
|
||
logger.warning(f"检测到文本形式工具调用,已转换为 Function Calling: {preferred}")
|
||
try:
|
||
if bot and from_wxid:
|
||
await bot.send_text(from_wxid, "我帮你查一下,稍等。")
|
||
except Exception:
|
||
pass
|
||
tool_calls_data = [
|
||
{
|
||
"id": f"legacy_{uuid.uuid4().hex[:8]}",
|
||
"type": "function",
|
||
"function": {
|
||
"name": preferred,
|
||
"arguments": json.dumps(legacy_args, ensure_ascii=False),
|
||
},
|
||
}
|
||
]
|
||
|
||
# 兼容:文本输出的绘图工具调用 JSON / python 调用
|
||
if not tool_calls_data and full_content:
|
||
legacy_img = self._extract_legacy_text_image_tool_call(full_content)
|
||
if legacy_img:
|
||
legacy_tool, legacy_args = legacy_img
|
||
tools_cfg = (self.config or {}).get("tools", {})
|
||
loose_image_tool = tools_cfg.get("loose_image_tool", True)
|
||
preferred = self._resolve_image_tool_alias(
|
||
legacy_tool,
|
||
allowed_tool_names,
|
||
available_tool_names,
|
||
loose_image_tool,
|
||
)
|
||
if preferred:
|
||
logger.warning(f"检测到文本绘图工具调用,已转换为 Function Calling: {preferred}")
|
||
tool_calls_data = [
|
||
{
|
||
"id": f"legacy_img_{uuid.uuid4().hex[:8]}",
|
||
"type": "function",
|
||
"function": {
|
||
"name": preferred,
|
||
"arguments": json.dumps(legacy_args, ensure_ascii=False),
|
||
},
|
||
}
|
||
]
|
||
|
||
if not tool_calls_data and allowed_tool_names and full_content:
|
||
if self._contains_tool_call_markers(full_content):
|
||
fallback_tool = None
|
||
if "tavily_web_search" in allowed_tool_names:
|
||
fallback_tool = "tavily_web_search"
|
||
elif "web_search" in allowed_tool_names:
|
||
fallback_tool = "web_search"
|
||
|
||
if fallback_tool:
|
||
fallback_query = self._extract_tool_intent_text(user_message, tool_query=tool_query) or user_message
|
||
fallback_query = str(fallback_query or "").strip()
|
||
if fallback_query:
|
||
logger.warning(f"检测到文本工具调用但未解析成功,已兜底调用: {fallback_tool}")
|
||
try:
|
||
if bot and from_wxid:
|
||
await bot.send_text(from_wxid, "我帮你查一下,稍等。")
|
||
except Exception:
|
||
pass
|
||
tool_calls_data = [
|
||
{
|
||
"id": f"fallback_{uuid.uuid4().hex[:8]}",
|
||
"type": "function",
|
||
"function": {
|
||
"name": fallback_tool,
|
||
"arguments": json.dumps({"query": fallback_query[:400]}, ensure_ascii=False),
|
||
},
|
||
}
|
||
]
|
||
|
||
if not tool_calls_data and allowed_tool_names and self._looks_like_lyrics_query(user_message):
|
||
fallback_tool = None
|
||
if "tavily_web_search" in allowed_tool_names:
|
||
fallback_tool = "tavily_web_search"
|
||
elif "web_search" in allowed_tool_names:
|
||
fallback_tool = "web_search"
|
||
|
||
if fallback_tool:
|
||
fallback_query = self._extract_tool_intent_text(user_message, tool_query=tool_query) or user_message
|
||
fallback_query = str(fallback_query or "").strip()
|
||
if fallback_query:
|
||
logger.warning(f"歌词检索未触发工具,已兜底调用: {fallback_tool}")
|
||
try:
|
||
if bot and from_wxid:
|
||
await bot.send_text(from_wxid, "我帮你查一下这句歌词,稍等。")
|
||
except Exception:
|
||
pass
|
||
tool_calls_data = [
|
||
{
|
||
"id": f"lyrics_{uuid.uuid4().hex[:8]}",
|
||
"type": "function",
|
||
"function": {
|
||
"name": fallback_tool,
|
||
"arguments": json.dumps({"query": fallback_query[:400]}, ensure_ascii=False),
|
||
},
|
||
}
|
||
]
|
||
|
||
logger.info(f"流式/非流式 API 响应完成, 内容长度: {len(full_content)}, 工具调用数: {len(tool_calls_data)}")
|
||
|
||
# 检查是否有函数调用
|
||
if tool_calls_data:
|
||
# 提示已在流式处理中发送,直接启动工具执行
|
||
logger.info(f"启动工具执行,共 {len(tool_calls_data)} 个工具")
|
||
try:
|
||
await self._record_tool_calls_to_context(
|
||
tool_calls_data,
|
||
from_wxid=from_wxid,
|
||
chat_id=chat_id,
|
||
is_group=is_group,
|
||
user_wxid=user_wxid,
|
||
)
|
||
except Exception as e:
|
||
logger.debug(f"记录工具调用到上下文失败: {e}")
|
||
if self._tool_async:
|
||
asyncio.create_task(
|
||
self._execute_tools_async(
|
||
tool_calls_data, bot, from_wxid, chat_id,
|
||
user_wxid, nickname, is_group, messages
|
||
)
|
||
)
|
||
else:
|
||
await self._execute_tools_async(
|
||
tool_calls_data, bot, from_wxid, chat_id,
|
||
user_wxid, nickname, is_group, messages
|
||
)
|
||
# 返回 None 表示工具调用已异步处理,不需要重试
|
||
return None
|
||
|
||
# 检查是否包含错误的工具调用格式
|
||
if "<tool_code>" in full_content or re.search(
|
||
r"(?i)\bprint\s*\(\s*(draw_image|generate_image|nano_ai_image_generation|flow2_ai_image_generation|jimeng_ai_image_generation|kiira2_ai_image_generation)\s*\(",
|
||
full_content,
|
||
):
|
||
logger.warning("检测到模型输出了错误的工具调用格式,拦截并返回提示")
|
||
return "抱歉,我遇到了一些技术问题,请重新描述一下你的需求~"
|
||
|
||
return self._sanitize_llm_output(full_content)
|
||
|
||
try:
|
||
if tools:
|
||
logger.debug(f"已将 {len(tools)} 个工具添加到请求中")
|
||
full_content, tool_calls_data = await self._send_dialog_api_request(
|
||
api_config,
|
||
messages,
|
||
tools,
|
||
request_tag="[对话]",
|
||
prefer_stream=True,
|
||
max_tokens=api_config.get("max_tokens", 4096),
|
||
)
|
||
return await _finalize_response(full_content, tool_calls_data)
|
||
except Exception as e:
|
||
logger.error(f"调用对话 API 失败: {e}")
|
||
raise
|
||
|
||
|
||
async def _load_history(self, chat_id: str) -> list:
|
||
"""异步读取群聊历史(委托 ContextStore)"""
|
||
if not self.store:
|
||
return []
|
||
return await self.store.load_group_history(chat_id)
|
||
|
||
async def _add_to_history(
|
||
self,
|
||
chat_id: str,
|
||
nickname: str,
|
||
content: str,
|
||
image_base64: str = None,
|
||
*,
|
||
role: str = "user",
|
||
sender_wxid: str = None,
|
||
):
|
||
"""将消息存入群聊历史(委托 ContextStore)"""
|
||
if not self.store:
|
||
return
|
||
await self.store.add_group_message(
|
||
chat_id,
|
||
nickname,
|
||
content,
|
||
image_base64=image_base64,
|
||
role=role,
|
||
sender_wxid=sender_wxid,
|
||
)
|
||
|
||
async def _add_to_history_with_id(
|
||
self,
|
||
chat_id: str,
|
||
nickname: str,
|
||
content: str,
|
||
record_id: str,
|
||
*,
|
||
role: str = "user",
|
||
sender_wxid: str = None,
|
||
):
|
||
"""带ID的历史追加, 便于后续更新(委托 ContextStore)"""
|
||
if not self.store:
|
||
return
|
||
await self.store.add_group_message(
|
||
chat_id,
|
||
nickname,
|
||
content,
|
||
record_id=record_id,
|
||
role=role,
|
||
sender_wxid=sender_wxid,
|
||
)
|
||
|
||
async def _update_history_by_id(self, chat_id: str, record_id: str, new_content: str):
|
||
"""根据ID更新历史记录(委托 ContextStore)"""
|
||
if not self.store:
|
||
return
|
||
await self.store.update_group_message_by_id(chat_id, record_id, new_content)
|
||
|
||
|
||
def _prepare_tool_calls_for_executor(
|
||
self,
|
||
tool_calls_data: list,
|
||
messages: list,
|
||
*,
|
||
user_wxid: str,
|
||
from_wxid: str,
|
||
is_group: bool,
|
||
image_base64: str | None = None,
|
||
) -> list:
|
||
prepared = []
|
||
if not tool_calls_data:
|
||
return prepared
|
||
|
||
for tool_call in tool_calls_data:
|
||
function = (tool_call or {}).get("function") or {}
|
||
function_name = function.get("name", "")
|
||
if not function_name:
|
||
continue
|
||
|
||
tool_call_id = (tool_call or {}).get("id", "")
|
||
if not tool_call_id:
|
||
tool_call_id = f"call_{uuid.uuid4().hex[:8]}"
|
||
tool_call["id"] = tool_call_id
|
||
|
||
raw_arguments = function.get("arguments", "{}")
|
||
try:
|
||
arguments = json.loads(raw_arguments) if raw_arguments else {}
|
||
if not isinstance(arguments, dict):
|
||
arguments = {}
|
||
except Exception:
|
||
arguments = {}
|
||
if "function" not in tool_call:
|
||
tool_call["function"] = {}
|
||
tool_call["function"]["arguments"] = "{}"
|
||
|
||
if function_name in ("tavily_web_search", "web_search"):
|
||
raw_query = arguments.get("query", "")
|
||
cleaned_query = self._normalize_search_query(raw_query)
|
||
if cleaned_query:
|
||
arguments["query"] = cleaned_query[:400]
|
||
if "function" not in tool_call:
|
||
tool_call["function"] = {}
|
||
tool_call["function"]["arguments"] = json.dumps(arguments, ensure_ascii=False)
|
||
elif not arguments.get("query"):
|
||
fallback_query = self._extract_tool_intent_text(self._extract_last_user_text(messages))
|
||
fallback_query = str(fallback_query or "").strip()
|
||
if fallback_query:
|
||
arguments["query"] = fallback_query[:400]
|
||
if "function" not in tool_call:
|
||
tool_call["function"] = {}
|
||
tool_call["function"]["arguments"] = json.dumps(arguments, ensure_ascii=False)
|
||
|
||
exec_args = dict(arguments)
|
||
exec_args["user_wxid"] = user_wxid or from_wxid
|
||
exec_args["is_group"] = bool(is_group)
|
||
|
||
if image_base64 and function_name in ("flow2_ai_image_generation", "nano_ai_image_generation", "grok_video_generation"):
|
||
exec_args["image_base64"] = image_base64
|
||
logger.info("[异步-图片] 图生图工具,已添加图片数据")
|
||
|
||
prepared.append({
|
||
"id": tool_call_id,
|
||
"type": "function",
|
||
"function": {
|
||
"name": function_name,
|
||
"arguments": json.dumps(exec_args, ensure_ascii=False),
|
||
},
|
||
})
|
||
|
||
return prepared
|
||
|
||
async def _execute_tools_async(self, tool_calls_data: list, bot, from_wxid: str,
                               chat_id: str, user_wxid: str, nickname: str, is_group: bool,
                               messages: list, *, image_base64: str | None = None):
    """
    Execute model tool calls in the background and deliver their results.

    The AI has already answered the user; the tools run here and their results
    are sent when done. A result is handed back to the model (via
    ``_continue_with_tool_results``, keeping context and persona) when it sets
    ``need_ai_reply``, or, with ``self._tool_followup_ai_reply`` enabled, when
    it was neither already sent nor marked ``no_reply``. Other results are sent
    directly (success text when ``send_result_text`` is set, a "❌" notice on
    failure) and optionally saved to memory.

    Args:
        tool_calls_data: Raw tool calls from the model.
        bot: Client used to send messages.
        from_wxid: Conversation to reply into.
        chat_id: Memory key; falsy disables memory writes.
        user_wxid: Requesting user id forwarded to the tools.
        nickname: Requesting user's display name.
        is_group: Whether the conversation is a group chat.
        messages: Request messages; extended in place by a follow-up call.
        image_base64: Optional image forwarded to image-to-image tools.
    """
    # Log prefix distinguishes runs started from the image entry point.
    tag = "[异步]" if image_base64 is None else "[异步-图片]"
    try:
        if image_base64 is None:
            logger.info(f"开始异步执行 {len(tool_calls_data)} 个工具调用")
        else:
            logger.info(f"{tag} 开始执行 {len(tool_calls_data)} 个工具调用")

        tools_config = (self.config or {}).get("tools", {})
        max_concurrent = tools_config.get("concurrency", {}).get("max_concurrent", 5)
        parallel_tools = True
        if self._serial_reply:
            # Serial-reply mode forces strictly sequential tool execution.
            max_concurrent = 1
            parallel_tools = False
        default_timeout = tools_config.get("timeout", {}).get("default", 60)

        executor = ToolExecutor(default_timeout=default_timeout, max_parallel=max_concurrent)
        prepared_tool_calls = self._prepare_tool_calls_for_executor(
            tool_calls_data,
            messages,
            user_wxid=user_wxid,
            from_wxid=from_wxid,
            is_group=is_group,
            image_base64=image_base64,
        )

        if not prepared_tool_calls:
            logger.info(f"{tag} 没有可执行的工具调用")
            return

        logger.info(f"{tag} 开始执行 {len(prepared_tool_calls)} 个工具 (最大并发: {max_concurrent})")
        results = await executor.execute_batch(prepared_tool_calls, bot, from_wxid, parallel=parallel_tools)

        followup_results = []
        for result in results:
            function_name = result.name
            tool_message = self._sanitize_llm_output(result.message or "")
            followup_entry = {
                "tool_call_id": result.id,
                "function_name": function_name,
                "result": tool_message,
                "success": result.success,
            }

            if result.success:
                logger.success(f"{tag} 工具 {function_name} 执行成功")
            else:
                logger.warning(f"{tag} 工具 {function_name} 执行失败: {result.error or result.message}")

            if self._tool_followup_ai_reply:
                should_followup = result.need_ai_reply or ((not result.no_reply) and (not result.already_sent))
                logger.info(f"{tag} 工具 {function_name}: need_ai_reply={result.need_ai_reply}, already_sent={result.already_sent}, no_reply={result.no_reply}, should_followup={should_followup}")
                if should_followup:
                    followup_results.append(followup_entry)
                    continue

            logger.info(f"{tag} 工具 {function_name} 结果: need_ai_reply={result.need_ai_reply}, success={result.success}")
            if result.need_ai_reply:
                logger.info(f"{tag} 工具 {function_name} 需要 AI 回复,加入 followup_results")
                followup_results.append(followup_entry)
                continue

            # Emptiness of the sanitized text is checked here (not in the outer
            # condition) so the "empty after sanitizing" warning is reachable.
            if result.success and not result.already_sent and not result.no_reply and result.send_result_text:
                if tool_message:
                    await bot.send_text(from_wxid, tool_message)
                else:
                    logger.warning(f"{tag} 工具 {function_name} 输出清洗后为空,已跳过发送")

            if not result.success and not result.no_reply:
                failure_text = f"❌ {tool_message}" if tool_message else f"❌ {function_name} 执行失败"
                try:
                    await bot.send_text(from_wxid, failure_text)
                except Exception as send_err:
                    # Best effort: a failed notice must not abort the remaining results.
                    logger.debug(f"{tag} 发送工具失败提示失败: {send_err}")

            if result.save_to_memory and chat_id and tool_message:
                self._add_to_memory(chat_id, "assistant", f"[工具 {function_name} 结果]: {tool_message}")

        if followup_results:
            await self._continue_with_tool_results(
                followup_results, bot, from_wxid, user_wxid, chat_id,
                nickname, is_group, messages, tool_calls_data
            )

        logger.info(f"{tag} 所有工具执行完成")

    except Exception as e:
        logger.error(f"{tag} 工具执行总体异常: {e}")
        import traceback
        logger.error(f"详细错误: {traceback.format_exc()}")
        try:
            await bot.send_text(from_wxid, "❌ 工具执行过程中出现错误")
        except Exception:
            # Best effort: nothing more can be reported if sending itself fails.
            pass


async def _continue_with_tool_results(self, tool_results: list, bot, from_wxid: str,
                                      user_wxid: str, chat_id: str, nickname: str, is_group: bool,
                                      messages: list, tool_calls_data: list):
    """
    Feed tool results back to the model and send its natural-language reply.

    Used for results that need an AI reply (e.g. web search). Appends to
    ``messages`` in place: an assistant turn re-declaring the tool calls whose
    results are fed back, one ``tool`` message per result, and system
    instructions for search results and for failed tools. The follow-up request
    offers no tools, except ``search_music`` when
    ``_should_allow_music_followup`` permits it. Any tool call the model returns
    for a tool that was not offered is dropped instead of executed.

    Args:
        tool_results: Dicts with ``tool_call_id``, ``function_name``, ``result`` and ``success``.
        bot: Client used to send messages.
        from_wxid: Conversation to reply into.
        user_wxid: Requesting user id (falls back to ``from_wxid``).
        chat_id: Memory key; falsy disables memory writes.
        nickname: Requesting user's display name.
        is_group: Whether the conversation is a group chat.
        messages: Conversation messages (mutated in place).
        tool_calls_data: The model's original tool calls.
    """
    try:
        logger.info(f"[工具回传] 开始基于 {len(tool_results)} 个工具结果继续对话")

        # 1. Assistant turn declaring only the tool calls whose results are fed back.
        followup_ids = {tr["tool_call_id"] for tr in tool_results}
        tool_calls_msg = []
        for tool_call in tool_calls_data:
            tool_call = tool_call or {}
            tool_call_id = tool_call.get("id", "")
            if tool_call_id not in followup_ids:
                continue
            function = tool_call.get("function") or {}
            tool_calls_msg.append({
                "id": tool_call_id,
                "type": "function",
                "function": {
                    "name": function.get("name", ""),
                    "arguments": function.get("arguments", "{}"),
                },
            })

        if tool_calls_msg:
            messages.append({
                "role": "assistant",
                "content": None,
                "tool_calls": tool_calls_msg,
            })

        # 2. One tool message per result; remember which tools failed.
        failed_items = []
        for tr in tool_results:
            if not bool(tr.get("success", True)):
                failed_items.append(tr.get("function_name", "工具"))
            messages.append({
                "role": "tool",
                "tool_call_id": tr["tool_call_id"],
                "content": tr["result"],
            })

        # Search results: require a complete answer to the original question first.
        search_tool_names = {"tavily_web_search", "web_search"}
        if any(str(tr.get("function_name", "")) in search_tool_names for tr in tool_results):
            latest_user_text = self._extract_last_user_text(messages)
            messages.append({
                "role": "system",
                "content": (
                    "你将基于联网搜索工具结果回答用户。"
                    "必须先完整回答用户原问题,覆盖所有子问题与关键细节,"
                    "并给出清晰要点与必要来源依据;"
                    "禁止只给寒暄/反问/引导句,禁止把问题再抛回用户。"
                    "若原问题包含多个子问题(例如A和B),必须逐项作答,不得漏项。"
                    "**严禁输出任何 JSON 格式、函数调用格式或工具调用格式的内容。**"
                    "**只输出自然语言文本回复。**"
                    "用户原问题如下:" + str(latest_user_text or "")
                )
            })

        if failed_items:
            failed_list = "、".join(str(x) for x in failed_items if x)
            messages.append({
                "role": "system",
                "content": (
                    "你将基于工具返回结果向用户回复。"
                    "本轮部分工具执行失败(" + failed_list + ")。"
                    "请直接给出简洁、自然、可执行的中文总结:"
                    "先说明已获取到的有效结果,再明确失败项与可能原因,"
                    "最后给出下一步建议(如更换关键词/稍后重试/补充信息)。"
                    "不要输出 JSON、代码块或函数调用片段。"
                )
            })

        # 3. Follow-up request: no tools, unless the lyrics -> search_music flow is allowed.
        api_config = self.config["api"]
        user_wxid = user_wxid or from_wxid

        followup_tools = None
        if self._should_allow_music_followup(messages, tool_calls_data):
            followup_tools = [
                t for t in (self._collect_tools() or [])
                if t.get("function", {}).get("name") == "search_music"
            ] or None

        try:
            full_content, new_tool_calls = await self._send_dialog_api_request(
                api_config,
                messages,
                followup_tools,
                request_tag="[工具回传]",
                prefer_stream=True,
                max_tokens=api_config.get("max_tokens", 4096),
            )
        except Exception as req_err:
            logger.error(f"[工具回传] AI API 调用失败: {req_err}")
            await bot.send_text(from_wxid, "❌ AI 处理工具结果失败")
            return

        # Only execute calls to tools that were actually offered in this request;
        # with no tools offered, any returned call is hallucinated and dropped.
        if new_tool_calls:
            allowed_tool_names = {
                t.get("function", {}).get("name", "")
                for t in (followup_tools or [])
                if isinstance(t, dict) and t.get("function", {}).get("name")
            }
            offered_calls = [
                tc for tc in new_tool_calls
                if ((tc or {}).get("function") or {}).get("name", "") in allowed_tool_names
            ]
            if len(offered_calls) != len(new_tool_calls):
                logger.warning(f"[工具回传] 忽略未提供的工具调用: {len(new_tool_calls) - len(offered_calls)} 个")
            new_tool_calls = offered_calls

        if new_tool_calls:
            await self._execute_tools_async(
                new_tool_calls, bot, from_wxid, chat_id,
                user_wxid, nickname, is_group, messages
            )
            return

        if full_content and full_content.strip():
            cleaned_content = self._sanitize_llm_output(full_content)
            if cleaned_content:
                await bot.send_text(from_wxid, cleaned_content)
                await self._maybe_send_voice_reply(bot, from_wxid, cleaned_content)
                logger.success(f"[工具回传] AI 回复完成,长度: {len(cleaned_content)}")
                if chat_id:
                    self._add_to_memory(chat_id, "assistant", cleaned_content)
            else:
                logger.warning("[工具回传] AI 回复清洗后为空,已跳过发送")
        else:
            logger.warning("[工具回传] AI 返回空内容")
            if failed_items:
                failed_list = "、".join(str(x) for x in failed_items if x)
                fallback_text = f"工具执行已完成,但部分步骤失败({failed_list})。请稍后重试,或换个更具体的问题我再帮你处理。"
            else:
                fallback_text = "工具执行已完成,但这次没生成可读回复。你可以让我基于结果再总结一次。"
            await bot.send_text(from_wxid, fallback_text)

    except Exception as e:
        logger.error(f"[工具回传] 继续对话失败: {e}")
        import traceback
        logger.error(f"详细错误: {traceback.format_exc()}")
        try:
            await bot.send_text(from_wxid, "❌ 处理工具结果时出错")
        except Exception:
            # Best effort: nothing more can be reported if sending itself fails.
            pass


async def _execute_tools_async_with_image(self, tool_calls_data: list, bot, from_wxid: str,
                                          chat_id: str, user_wxid: str, nickname: str, is_group: bool,
                                          messages: list, image_base64: str):
    """
    Execute tool calls with an attached image (e.g. image-to-image generation).

    Thin wrapper over ``_execute_tools_async``: ``image_base64`` is forwarded to
    the image-capable tools chosen by ``_prepare_tool_calls_for_executor``.
    """
    await self._execute_tools_async(
        tool_calls_data, bot, from_wxid, chat_id,
        user_wxid, nickname, is_group, messages,
        # Never pass None here so the image log prefix is kept.
        image_base64=image_base64 if image_base64 is not None else "",
    )
|
||
|
||
@on_quote_message(priority=79)
async def handle_quote_message(self, bot, message: dict):
    """Handle a quote (reference) message.

    Routing:
    * title ``/记录`` or ``/记录 <note>``: save the quoted message as a persistent memory;
    * quoted plain text (type 1): forward "quote + user comment" to the AI;
    * quoted chat record / video / image (types 49 / 43 / 3): rate-limit check,
      then delegate to the matching handler.

    Returns:
        ``False`` once the message has been handled here, ``True`` when it was
        skipped or failed. NOTE(review): presumably ``True`` lets other handlers
        run — confirm against the ``on_quote_message`` dispatcher.
    """
    content = message.get("Content", "").strip()
    from_wxid = message.get("FromWxid", "")
    sender_wxid = message.get("SenderWxid", "")
    is_group = message.get("IsGroup", False)
    user_wxid = sender_wxid if is_group else from_wxid

    try:
        # Group quotes may carry a "wxid:\n" prefix in front of the XML payload.
        xml_content = self._quote_strip_sender_prefix(content) if is_group else content
        if xml_content != content:
            logger.debug(f"去除引用消息前缀,原长度: {len(content)}, 新长度: {len(xml_content)}")

        # NOTE(review): parses message-supplied XML with the stdlib parser.
        root = ET.fromstring(xml_content)
        title = root.find(".//title")
        if title is None or not title.text:
            logger.debug("引用消息没有标题,跳过")
            return True

        title_text = title.text.strip()
        logger.info(f"收到引用消息,标题: {title_text[:50]}...")

        if title_text == "/记录" or title_text.startswith("/记录 "):
            await self._quote_save_to_persistent_memory(bot, root, title_text, from_wxid, user_wxid, is_group)
            return False

        # Reply eligibility is evaluated exactly once. Repeated calls would make
        # later checks dead code, or re-roll the decision if it is not deterministic.
        if not self._should_reply_quote(message, title_text):
            logger.debug("引用消息不满足回复条件")
            return True

        refermsg = root.find(".//refermsg")
        if refermsg is None:
            logger.debug("引用消息中没有 refermsg 节点")
            return True

        refer_content = refermsg.find("content")
        if refer_content is None or not refer_content.text:
            logger.debug("引用消息中没有 content")
            return True

        # Quoted message type: 1 = text, 3 = image, 43 = video, 49 = app message (incl. chat records).
        refer_type = self._quote_parse_int(refermsg.find("type"))
        logger.debug(f"被引用消息类型: {refer_type}")

        if refer_type == 1:
            return await self._quote_reply_to_text(bot, refermsg, title_text, from_wxid, user_wxid, is_group)

        if refer_type not in (3, 43, 49):
            logger.debug(f"引用的消息类型 {refer_type} 不支持处理")
            return True

        import html
        refer_xml = html.unescape(refer_content.text)
        # The quoted payload may carry its own "wxid:\n" prefix as well.
        stripped_refer_xml = self._quote_strip_sender_prefix(refer_xml)
        if stripped_refer_xml != refer_xml:
            logger.debug("去除被引用消息前缀")
            refer_xml = stripped_refer_xml

        try:
            refer_root = ET.fromstring(refer_xml)
        except ET.ParseError as e:
            logger.debug(f"被引用消息内容不是有效的 XML: {e}")
            return True

        recorditem = refer_root.find(".//recorditem")
        img = refer_root.find(".//img")
        video = refer_root.find(".//videomsg")
        if img is None and video is None and recorditem is None:
            logger.debug("引用的消息不是图片、视频或聊天记录")
            return True

        allowed, _remaining, reset_time = self._check_rate_limit(user_wxid)
        if not allowed:
            rate_limit_config = self.config.get("rate_limit", {})
            msg = rate_limit_config.get("rate_limit_message", "⚠️ 消息太频繁了,请 {seconds} 秒后再试~")
            await bot.send_text(from_wxid, msg.format(seconds=reset_time))
            logger.warning(f"用户 {user_wxid} 触发限流,{reset_time}秒后重置")
            return False

        nickname = await self._get_user_display_label(bot, from_wxid, user_wxid, is_group)
        chat_id = self._get_chat_id(from_wxid, user_wxid, is_group)

        if recorditem is not None:
            return await self._handle_quote_chat_record(
                bot, recorditem, title_text, from_wxid, user_wxid,
                is_group, nickname, chat_id
            )

        if video is not None:
            svrid = self._quote_parse_int(refermsg.find("svrid"))
            return await self._handle_quote_video(
                bot, video, title_text, from_wxid, user_wxid,
                is_group, nickname, chat_id, svrid
            )

        svrid_elem = refermsg.find("svrid")
        svrid = svrid_elem.text if svrid_elem is not None and svrid_elem.text else ""
        return await self._quote_reply_to_image(
            bot, title_text, svrid, from_wxid, user_wxid, is_group, nickname, chat_id
        )

    except Exception as e:
        logger.error(f"处理引用消息失败: {e}")
        return True


@staticmethod
def _quote_strip_sender_prefix(text: str) -> str:
    """Drop a leading ``"wxid:\\n"`` sender prefix in front of an XML payload.

    Only applies when ``":\\n"`` occurs; the text is cut at the first ``<?xml``
    (or, failing that, ``<msg``) when that marker is not at position 0.
    """
    if ":\n" not in text:
        return text
    xml_start = text.find("<?xml")
    if xml_start == -1:
        xml_start = text.find("<msg")
    return text[xml_start:] if xml_start > 0 else text


@staticmethod
def _quote_parse_int(elem, default: int = 0) -> int:
    """Return ``int(elem.text)``, or ``default`` when the element is missing, empty or non-numeric."""
    if elem is None or not elem.text:
        return default
    try:
        return int(elem.text.strip())
    except ValueError:
        return default


@staticmethod
def _quote_bot_config() -> dict:
    """Return the ``[Bot]`` table of ``main_config.toml`` (read from disk on each call)."""
    with open("main_config.toml", "rb") as f:
        return tomllib.load(f).get("Bot", {})


def _quote_bot_reply_needs_manual_history(self) -> bool:
    """Whether a quote handler must append the bot's reply to group history itself.

    When ``history.sync_bot_messages`` is enabled and the history scope is not
    per-user, outgoing bot messages are expected to be recorded by a hook
    elsewhere (NOTE(review): hook not visible here — confirm), so writing them
    here too would duplicate the entry.
    """
    history_config = self.config.get("history", {})
    sync_bot_messages = history_config.get("sync_bot_messages", False)
    history_scope = str(history_config.get("scope", "chatroom") or "chatroom").strip().lower()
    can_rely_on_hook = bool(sync_bot_messages and history_scope not in ("per_user", "user", "peruser"))
    return not can_rely_on_hook


async def _quote_save_to_persistent_memory(self, bot, root, title_text: str, from_wxid: str,
                                           user_wxid: str, is_group: bool):
    """Handle ``/记录``: store the quoted message as a persistent memory.

    The memory text is ``"<quoted sender>: <quoted text>"``, plus ``(备注: ...)``
    when the command carries a note after ``/记录 ``. XML payloads (media) are
    recorded as ``[多媒体消息]``, and an empty content as ``[空消息]``.
    """
    refermsg = root.find(".//refermsg")
    if refermsg is None:
        await bot.send_text(from_wxid, "❌ 无法获取被引用的消息")
        return

    refer_displayname = refermsg.find("displayname")
    refer_nickname = refer_displayname.text if refer_displayname is not None and refer_displayname.text else "未知"

    refer_content_elem = refermsg.find("content")
    if refer_content_elem is not None and refer_content_elem.text:
        refer_text = refer_content_elem.text.strip()
        # Any XML payload ("<?xml ..." or "<msg ...") is a media message, not readable text.
        if refer_text.startswith("<"):
            refer_text = "[多媒体消息]"
    else:
        refer_text = "[空消息]"

    memory_content = f"{refer_nickname}: {refer_text}"
    command_prefix = "/记录 "
    if title_text.startswith(command_prefix):
        extra_note = title_text[len(command_prefix):].strip()
        if extra_note:
            memory_content += f" (备注: {extra_note})"

    nickname = await self._get_user_display_label(bot, from_wxid, user_wxid, is_group)
    memory_chat_id = from_wxid if is_group else user_wxid
    chat_type = "group" if is_group else "private"
    memory_id = self._add_persistent_memory(
        memory_chat_id, chat_type, user_wxid, nickname, memory_content
    )
    await bot.send_text(from_wxid, f"✅ 已记录到持久记忆 (ID: {memory_id})\n📝 {memory_content[:50]}...")
    logger.info(f"通过引用添加持久记忆: {memory_chat_id} - {memory_content[:30]}...")


async def _quote_reply_to_text(self, bot, refermsg, title_text: str, from_wxid: str,
                               user_wxid: str, is_group: bool) -> bool:
    """Forward a quoted plain-text message plus the user's comment to the AI.

    The user turn is recorded in group history when history is enabled. The
    bot reply is recorded only when no hook is expected to record it (same
    rule as the image path). Returns ``False`` (handled).
    """
    refer_content_elem = refermsg.find("content")
    refer_text = refer_content_elem.text.strip() if refer_content_elem is not None and refer_content_elem.text else ""
    refer_displayname = refermsg.find("displayname")
    refer_nickname = refer_displayname.text if refer_displayname is not None and refer_displayname.text else "某人"

    # title_text looks like "@<bot nickname> comment"; drop the bot mention.
    bot_config = self._quote_bot_config()
    bot_nickname = bot_config.get("nickname", "")
    user_comment = title_text
    if bot_nickname:
        user_comment = user_comment.replace(f"@{bot_nickname}", "").strip()

    combined_message = f"[引用 {refer_nickname} 的消息:{refer_text}]\n{user_comment}"
    logger.info(f"引用纯文本消息,转发给 AI: {combined_message[:80]}...")

    nickname = await self._get_user_display_label(bot, from_wxid, user_wxid, is_group)
    chat_id = from_wxid if is_group else user_wxid

    history_enabled = bool(self.store) and self.config.get("history", {}).get("enabled", True)
    history_chat_id = None
    if is_group and history_enabled:
        history_chat_id = self._get_group_history_chat_id(from_wxid, user_wxid)
        await self._add_to_history(history_chat_id, nickname, combined_message, sender_wxid=user_wxid)

    # NOTE(review): unlike the media path, this call passes neither user_wxid /
    # is_group nor append_user_message, and skips the rate limit — confirm intent.
    async with self._reply_lock_context(chat_id):
        ai_response = await self._call_ai_api(
            combined_message,
            bot=bot,
            from_wxid=from_wxid,
            chat_id=chat_id,
            nickname=nickname
        )

    if ai_response:
        final_response = self._sanitize_llm_output(ai_response)
        if not final_response:
            logger.warning("AI 回复清洗后为空,已跳过发送")
            return False
        await bot.send_text(from_wxid, final_response)
        if history_chat_id and self._quote_bot_reply_needs_manual_history():
            await self._add_to_history(
                history_chat_id,
                bot_config.get("nickname", "AI"),
                final_response,
                role="assistant",
                sender_wxid=user_wxid,
            )
    return False


async def _quote_reply_to_image(self, bot, title_text: str, svrid: str, from_wxid: str, user_wxid: str,
                                is_group: bool, nickname: str, chat_id: str) -> bool:
    """Answer a quoted image using the copy cached under its ``svrid``.

    The image is looked up in the Redis media cache under ``image:svrid:<svrid>``.
    On a miss the user is asked to resend it. Returns ``False`` (handled).
    """
    logger.info(f"AI处理引用图片消息: {title_text[:50]}...")

    image_base64 = ""
    if svrid:
        try:
            redis_cache = get_cache()
            if redis_cache and redis_cache.enabled:
                media_key = f"image:svrid:{svrid}"
                cached_data = redis_cache.get_cached_media(media_key, "image")
                if cached_data:
                    logger.info(f"从缓存获取引用图片成功: {media_key}")
                    image_base64 = cached_data
        except Exception as e:
            logger.debug(f"从缓存获取图片失败: {e}")

    if not image_base64:
        logger.warning(f"引用图片缓存未命中: svrid={svrid}")
        await bot.send_text(from_wxid, "❌ 图片缓存已过期,请重新发送图片后再引用")
        return False

    logger.info("图片获取成功")
    self._add_to_memory(chat_id, "user", title_text, image_base64=image_base64)

    # Evaluated once and reused for the history write and for append_user_message.
    capture_history = bool(is_group and self._should_capture_group_history(is_triggered=True))
    if capture_history:
        await self._add_to_history(
            self._get_group_history_chat_id(from_wxid, user_wxid),
            nickname,
            title_text,
            image_base64=image_base64,
            sender_wxid=user_wxid,
        )

    history_enabled = bool(self.store) and self.config.get("history", {}).get("enabled", True)
    # When the user turn is already in group history, the AI call must not append it again.
    append_user_message = not (capture_history and history_enabled)
    async with self._reply_lock_context(chat_id):
        response = await self._call_ai_api_with_image(
            title_text,
            image_base64,
            bot,
            from_wxid,
            chat_id,
            nickname,
            user_wxid,
            is_group,
            append_user_message=append_user_message,
            tool_query=title_text,
        )

    if response:
        cleaned_response = self._sanitize_llm_output(response)
        if cleaned_response:
            await bot.send_text(from_wxid, cleaned_response)
            await self._maybe_send_voice_reply(bot, from_wxid, cleaned_response)
            self._add_to_memory(chat_id, "assistant", cleaned_response)
            if is_group and self._quote_bot_reply_needs_manual_history():
                await self._add_to_history(
                    self._get_group_history_chat_id(from_wxid, user_wxid),
                    self._quote_bot_config().get("nickname", "机器人"),
                    cleaned_response,
                    role="assistant",
                    sender_wxid=user_wxid,
                )
            logger.success(f"AI回复成功: {cleaned_response[:50]}...")
        else:
            logger.warning("AI 回复清洗后为空,已跳过发送")

    return False
|
||
|
||
async def _handle_quote_chat_record(self, bot, recorditem_elem, title_text: str, from_wxid: str,
                                    user_wxid: str, is_group: bool, nickname: str, chat_id: str):
    """Handle a quoted chat-record (merged forward) message, appmsg type=19.

    The ``recorditem`` element's text is itself a ``recordinfo`` XML document.
    Its ``datalist/dataitem`` entries are rendered as plain text, falling back
    to the ``desc`` summary when no item has text content. The rendered text
    and the user's question are sent to the main AI, and the reply is posted
    to ``from_wxid``.

    Args:
        bot: Client used to send messages (``send_text``).
        recorditem_elem: The quoted ``recorditem`` element.
        title_text: The user's text accompanying the quote (may start with an @mention).
        from_wxid: Conversation to reply into.
        user_wxid: Sender of the quoting message.
        is_group: Whether ``from_wxid`` is a group chat.
        nickname: Sender display name, used for group history.
        chat_id: Memory key for this conversation.

    Returns:
        Always False. User-visible failures are reported with a "❌ ..." text.
    """
    default_question = "请分析这段聊天记录"
    try:
        logger.info(f"[聊天记录] 处理引用的聊天记录: {title_text[:50]}...")

        # The recorditem payload is CDATA holding a nested recordinfo XML document.
        record_text = recorditem_elem.text
        if not record_text:
            logger.warning("[聊天记录] recorditem 内容为空")
            await bot.send_text(from_wxid, "❌ 无法读取聊天记录内容")
            return False

        try:
            record_root = ET.fromstring(record_text)
        except ET.ParseError as e:
            logger.error(f"[聊天记录] 解析 recordinfo 失败: {e}")
            await bot.send_text(from_wxid, "❌ 聊天记录格式解析失败")
            return False

        # Full message list. Items without text content (datadesc) are skipped.
        chat_records = []
        datalist = record_root.find(".//datalist")
        if datalist is not None:
            for dataitem in datalist.findall("dataitem"):
                content = dataitem.findtext("datadesc") or ""
                if content:
                    chat_records.append({
                        "sender": dataitem.findtext("sourcename") or "未知",
                        "time": dataitem.findtext("sourcetime") or "",
                        "content": content,
                        "is_summary": False,
                    })

        # Simplified quote payloads may carry no usable datalist: fall back to the
        # desc summary (typically "sender: text" lines).
        if not chat_records:
            desc_raw = record_root.findtext(".//desc")
            if desc_raw:
                desc_text = desc_raw.strip()
                logger.info(f"[聊天记录] 从 desc 获取摘要内容: {desc_text[:100]}...")
                chat_records.append({
                    "sender": "聊天记录摘要",
                    "time": "",
                    "content": desc_text,
                    "is_summary": True,
                })

        if not chat_records:
            logger.warning("[聊天记录] 没有解析到任何消息")
            await bot.send_text(from_wxid, "❌ 聊天记录中没有消息内容")
            return False

        logger.info(f"[聊天记录] 解析到 {len(chat_records)} 条消息")

        title = record_root.findtext(".//title") or "聊天记录"
        chunks = [f"【{title}】\n\n"]
        for record in chat_records:
            if record["is_summary"]:
                # Summary mode: show the desc text as-is.
                chunks.append(f"{record['content']}\n\n")
            else:
                time_part = f" ({record['time']})" if record["time"] else ""
                chunks.append(f"[{record['sender']}{time_part}]:\n{record['content']}\n\n")
        chat_text = "".join(chunks)

        user_question = title_text.strip() or default_question
        # Drop a leading "@nickname" token so only the actual question remains.
        if user_question.startswith("@"):
            parts = user_question.split(maxsplit=1)
            user_question = parts[1].strip() if len(parts) > 1 else default_question

        combined_message = f"[用户发送了一段聊天记录,请阅读并回答问题]\n\n{chat_text}\n[用户的问题]: {user_question}"
        logger.info(f"[聊天记录] 发送给 AI,消息长度: {len(combined_message)}")

        self._add_to_memory(chat_id, "user", combined_message)

        if is_group and self._should_capture_group_history(is_triggered=True):
            history_chat_id = self._get_group_history_chat_id(from_wxid, user_wxid)
            await self._add_to_history(
                history_chat_id,
                nickname,
                f"[发送了聊天记录] {user_question}",
                sender_wxid=user_wxid,
            )

        async with self._reply_lock_context(chat_id):
            response = await self._call_ai_api(
                combined_message,
                bot,
                from_wxid,
                chat_id,
                nickname,
                user_wxid,
                is_group,
                tool_query=user_question,
            )
            if response:
                cleaned_response = self._sanitize_llm_output(response)
                if cleaned_response:
                    await bot.send_text(from_wxid, cleaned_response)
                    await self._maybe_send_voice_reply(bot, from_wxid, cleaned_response)
                    self._add_to_memory(chat_id, "assistant", cleaned_response)

                    # Save the bot reply to group history ourselves unless the
                    # bot-message sync hook is enabled (not relied on for per-user scopes).
                    history_config = self.config.get("history", {})
                    sync_bot_messages = history_config.get("sync_bot_messages", False)
                    history_scope = str(history_config.get("scope", "chatroom") or "chatroom").strip().lower()
                    can_rely_on_hook = bool(sync_bot_messages and history_scope not in ("per_user", "user", "peruser"))
                    if is_group and not can_rely_on_hook:
                        # The reply is already delivered. A failure here is logged and
                        # must not be reported to the user as a processing error.
                        try:
                            with open("main_config.toml", "rb") as f:
                                main_config = tomllib.load(f)
                            bot_nickname = main_config.get("Bot", {}).get("nickname", "机器人")
                            history_chat_id = self._get_group_history_chat_id(from_wxid, user_wxid)
                            await self._add_to_history(
                                history_chat_id,
                                bot_nickname,
                                cleaned_response,
                                role="assistant",
                                sender_wxid=user_wxid,
                            )
                        except Exception as e:
                            logger.warning(f"[聊天记录] 保存机器人回复到历史记录失败: {e}")
                    logger.success(f"[聊天记录] AI 回复成功: {cleaned_response[:50]}...")
                else:
                    logger.warning("[聊天记录] AI 回复清洗后为空,已跳过发送")
            else:
                await bot.send_text(from_wxid, "❌ AI 回复生成失败")

        return False

    except Exception as e:
        logger.exception(f"[聊天记录] 处理失败: {e}")
        # Best effort: a failing send must not escape the error handler.
        try:
            await bot.send_text(from_wxid, "❌ 聊天记录处理出错")
        except Exception as send_err:
            logger.warning(f"[聊天记录] 发送错误提示失败: {send_err}")
        return False
|
||
|
||
async def _handle_quote_video(self, bot, video_elem, title_text: str, from_wxid: str,
                              user_wxid: str, is_group: bool, nickname: str, chat_id: str, svrid: int = 0):
    """Handle a quoted video message with the two-stage ("dual AI") pipeline.

    Step 1: download the quoted video by message id and let the video model
    describe it (``_analyze_video_content``).
    Step 2: give that description plus the user's question to the main AI
    (``_call_ai_api``) and send its reply.

    Args:
        bot: Client used to send messages (``send_text``).
        video_elem: The quoted video XML element. Only its ``length`` attribute is read.
        title_text: The user's text accompanying the quote (their question).
        from_wxid: Conversation to reply into.
        user_wxid: Sender of the quoting message.
        is_group: Whether ``from_wxid`` is a group chat.
        nickname: Sender display name, used for group history.
        chat_id: Memory key for this conversation.
        svrid: Server message id of the quoted video. 0 means unknown.

    Returns:
        Always False. User-visible failures are reported with a "❌ ..." text.
    """
    try:
        video_config = self.config.get("video_recognition", {})
        if not video_config.get("enabled", True):
            logger.info("[视频识别] 功能未启用")
            await bot.send_text(from_wxid, "❌ 视频识别功能未启用")
            return False

        # A missing or non-numeric length is reported as incomplete info below
        # instead of surfacing as a generic processing error.
        try:
            total_len = int(video_elem.get("length", 0))
        except (TypeError, ValueError):
            total_len = 0

        if not svrid or not total_len:
            logger.warning(f"[视频识别] 视频信息不完整: svrid={svrid}, total_len={total_len}")
            await bot.send_text(from_wxid, "❌ 无法获取视频信息")
            return False

        logger.info(f"[视频识别] 使用新协议下载引用视频: svrid={svrid}, len={total_len}")
        await bot.send_text(from_wxid, "🎬 正在分析视频,请稍候...")

        video_base64 = await self._download_video_by_id(bot, svrid, total_len)
        if not video_base64:
            logger.error("[视频识别] 视频下载失败")
            await bot.send_text(from_wxid, "❌ 视频下载失败")
            return False
        logger.info("[视频识别] 视频下载和编码成功")

        # ---- Step 1: the video model describes the video ----
        video_description = await self._analyze_video_content(video_base64, video_config)
        if not video_description:
            logger.error("[视频识别] 视频AI分析失败")
            await bot.send_text(from_wxid, "❌ 视频分析失败")
            return False
        logger.info(f"[视频识别] 视频AI分析完成: {video_description[:100]}...")

        # ---- Step 2: the main AI answers based on the description ----
        user_question = title_text.strip() or "这个视频讲了什么?"
        combined_message = f"[用户发送了一个视频,以下是视频内容描述]\n{video_description}\n\n[用户的问题]\n{user_question}"

        # Store the combined turn so the main AI's context records that a video was sent.
        self._add_to_memory(chat_id, "user", combined_message)

        if is_group and self._should_capture_group_history(is_triggered=True):
            history_chat_id = self._get_group_history_chat_id(from_wxid, user_wxid)
            await self._add_to_history(
                history_chat_id,
                nickname,
                f"[发送了一个视频] {user_question}",
                sender_wxid=user_wxid,
            )

        async with self._reply_lock_context(chat_id):
            response = await self._call_ai_api(
                combined_message,
                bot,
                from_wxid,
                chat_id,
                nickname,
                user_wxid,
                is_group,
                tool_query=user_question,
            )
            if response:
                cleaned_response = self._sanitize_llm_output(response)
                if cleaned_response:
                    await bot.send_text(from_wxid, cleaned_response)
                    await self._maybe_send_voice_reply(bot, from_wxid, cleaned_response)
                    self._add_to_memory(chat_id, "assistant", cleaned_response)

                    # Save the bot reply to group history ourselves unless the
                    # bot-message sync hook is enabled (not relied on for per-user scopes).
                    history_config = self.config.get("history", {})
                    sync_bot_messages = history_config.get("sync_bot_messages", False)
                    history_scope = str(history_config.get("scope", "chatroom") or "chatroom").strip().lower()
                    can_rely_on_hook = bool(sync_bot_messages and history_scope not in ("per_user", "user", "peruser"))
                    if is_group and not can_rely_on_hook:
                        # The reply is already delivered. A failure here is logged and
                        # must not be reported to the user as a processing error.
                        try:
                            with open("main_config.toml", "rb") as f:
                                main_config = tomllib.load(f)
                            bot_nickname = main_config.get("Bot", {}).get("nickname", "机器人")
                            history_chat_id = self._get_group_history_chat_id(from_wxid, user_wxid)
                            await self._add_to_history(
                                history_chat_id,
                                bot_nickname,
                                cleaned_response,
                                role="assistant",
                                sender_wxid=user_wxid,
                            )
                        except Exception as e:
                            logger.warning(f"[视频识别] 保存机器人回复到历史记录失败: {e}")
                    logger.success(f"[视频识别] 主AI回复成功: {cleaned_response[:50]}...")
                else:
                    logger.warning("[视频识别] 主AI回复清洗后为空,已跳过发送")
            else:
                await bot.send_text(from_wxid, "❌ AI 回复生成失败")

        return False

    except Exception as e:
        logger.exception(f"[视频识别] 处理视频失败: {e}")
        # Best effort: a failing send must not escape the error handler.
        try:
            await bot.send_text(from_wxid, "❌ 视频处理出错")
        except Exception as send_err:
            logger.warning(f"[视频识别] 发送错误提示失败: {send_err}")
        return False
|
||
|
||
async def _analyze_video_content(self, video_base64: str, video_config: dict) -> str:
|
||
"""视频AI:专门分析视频内容,委托给 ImageProcessor"""
|
||
if self._image_processor:
|
||
result = await self._image_processor.analyze_video(video_base64)
|
||
# 对结果做输出清洗
|
||
return self._sanitize_llm_output(result) if result else ""
|
||
logger.warning("ImageProcessor 未初始化,无法分析视频")
|
||
return ""
|
||
|
||
async def _download_and_encode_video(self, bot, cdnurl: str, aeskey: str) -> str:
|
||
"""下载视频并转换为 base64,委托给 ImageProcessor"""
|
||
if self._image_processor:
|
||
return await self._image_processor.download_video(bot, cdnurl, aeskey)
|
||
logger.warning("ImageProcessor 未初始化,无法下载视频")
|
||
return ""
|
||
|
||
async def _download_video_by_id(self, bot, msg_id: int, total_len: int) -> str:
|
||
"""通过消息ID下载视频并转换为 base64(用于引用消息),委托给 ImageProcessor"""
|
||
if self._image_processor:
|
||
return await self._image_processor.download_video_by_id(bot, msg_id, total_len)
|
||
logger.warning("ImageProcessor 未初始化,无法下载视频")
|
||
return ""
|
||
|
||
async def _download_image_by_id(self, bot, msg_id: int, total_len: int, to_user: str = "", from_user: str = "") -> str:
|
||
"""通过消息ID下载图片并转换为 base64(用于引用消息),委托给 ImageProcessor"""
|
||
if self._image_processor:
|
||
return await self._image_processor.download_image_by_id(bot, msg_id, total_len, to_user, from_user)
|
||
logger.warning("ImageProcessor 未初始化,无法下载图片")
|
||
return ""
|
||
|
||
async def _download_image_by_cdn(self, bot, cdnurl: str, aeskey: str) -> str:
|
||
"""通过 CDN 信息下载图片并转换为 base64(用于引用消息)"""
|
||
if not cdnurl or not aeskey:
|
||
logger.warning("CDN 参数不完整,无法下载图片")
|
||
return ""
|
||
if self._image_processor:
|
||
return await self._image_processor.download_image_by_cdn(bot, cdnurl, aeskey)
|
||
logger.warning("ImageProcessor 未初始化,无法下载图片")
|
||
return ""
|
||
|
||
async def _call_ai_api_with_video(self, user_message: str, video_base64: str, bot=None,
                                  from_wxid: str = None, chat_id: str = None,
                                  nickname: str = "", user_wxid: str = None,
                                  is_group: bool = False) -> str:
    """Ask the Gemini native ``generateContent`` API about a video, with chat context.

    The prompt is the system prompt plus the current time, the user's
    nickname, persistent memories, optional vector-memory hits, and recent
    history. Group chats use group history; private chats use the last 20
    short-term memory entries. The video is sent inline as ``video/mp4``.
    If the full request yields no text, the method retries with
    ``_call_ai_api_with_video_simple`` (question + video only).

    Returns:
        The sanitized reply, a fixed apology string when Gemini blocks the
        content, or "" on any error.
    """

    def _preview(content, limit: int = 200) -> str:
        """Reduce a (possibly multimodal) message content to one short text snippet."""
        if isinstance(content, list):
            # The first text part wins. A list without one is shown as an image placeholder.
            for item in content:
                if isinstance(item, dict) and item.get("type") == "text":
                    content = item.get("text", "")
                    break
            else:
                content = "[图片]"
        text = str(content)
        return text[:limit] + "..." if len(text) > limit else text

    try:
        if not video_base64:
            logger.warning("[视频识别] 视频数据为空,无法调用 API")
            return ""

        video_config = self.config.get("video_recognition", {})

        # Video-recognition specific model / endpoint settings.
        video_model = video_config.get("model", "gemini-3-pro-preview")
        api_url = video_config.get("api_url", "https://api.functen.cn/v1beta/models")
        api_key = video_config.get("api_key", self.config["api"]["api_key"])
        full_url = f"{api_url}/{video_model}:generateContent"

        # ---- System prompt (kept consistent with _call_ai_api) ----
        now = datetime.now()
        weekday = ("星期一", "星期二", "星期三", "星期四", "星期五", "星期六", "星期日")[now.weekday()]
        time_str = now.strftime(f"%Y年%m月%d日 %H:%M:%S {weekday}")
        system_content = self.system_prompt + f"\n\n当前时间:{time_str}"
        if nickname:
            system_content += f"\n当前对话用户的昵称是:{nickname}"

        memory_chat_id = from_wxid if is_group else user_wxid
        if memory_chat_id:
            persistent_memories = self._get_persistent_memories(memory_chat_id)
            if persistent_memories:
                system_content += "\n\n【持久记忆】以下是用户要求你记住的重要信息:\n"
                for m in persistent_memories:
                    mem_time = m['time'][:10] if m['time'] else ""
                    system_content += f"- [{mem_time}] {m['nickname']}: {m['content']}\n"

        if is_group and from_wxid and self._vector_memory_enabled:
            vector_mem = await self._retrieve_vector_memories(from_wxid, user_message)
            if vector_mem:
                system_content += vector_mem

        # ---- Recent context ----
        history_context = ""
        if is_group and from_wxid:
            # Group chat: history comes from Redis/file storage.
            history_chat_id = self._get_group_history_chat_id(from_wxid, user_wxid or "")
            history = await self._load_history(history_chat_id)
            history = self._filter_history_by_window(history)
            max_context = self.config.get("history", {}).get("max_context", 50)
            recent_history = history[-max_context:] if len(history) > max_context else history
            if recent_history:
                history_context = "\n\n【最近的群聊记录】\n" + "".join(
                    f"[{msg.get('nickname', '')}] {_preview(msg.get('content', ''))}\n"
                    for msg in recent_history
                )
        elif chat_id:
            # Private chat: history comes from short-term memory.
            memory_messages = self._get_memory_messages(chat_id)
            if memory_messages:
                lines = []
                for msg in memory_messages[-20:]:  # most recent 20 entries
                    role_name = "用户" if msg.get("role", "") == "user" else "你"
                    lines.append(f"[{role_name}] {_preview(msg.get('content', ''))}\n")
                history_context = "\n\n【最近的对话记录】\n" + "".join(lines)

        # Gemini inline_data wants raw base64, not a data: URL.
        if video_base64.startswith("data:"):
            video_base64 = video_base64.split(",", 1)[1]

        full_prompt = system_content + history_context + f"\n\n【当前】用户发送了一个视频并问:{user_message or '请描述这个视频的内容'}"

        payload = {
            "contents": [
                {
                    "parts": [
                        {"text": full_prompt},
                        {
                            "inline_data": {
                                "mime_type": "video/mp4",
                                "data": video_base64,
                            }
                        },
                    ]
                }
            ],
            "generationConfig": {
                "maxOutputTokens": video_config.get("max_tokens", 8192)
            },
        }
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}",
        }
        timeout = aiohttp.ClientTimeout(total=video_config.get("timeout", 360))

        connector = None
        proxy_config = self.config.get("proxy", {})
        if proxy_config.get("enabled", False) and PROXY_SUPPORT:
            proxy_type = proxy_config.get("type", "socks5").upper()
            proxy_host = proxy_config.get("host", "127.0.0.1")
            proxy_port = proxy_config.get("port", 7890)
            proxy_url = f"{proxy_type}://{proxy_host}:{proxy_port}"
            try:
                connector = ProxyConnector.from_url(proxy_url)
            except Exception as e:
                logger.warning(f"[视频识别] 代理配置失败: {e}")

        logger.info(f"[视频识别] 调用 Gemini API: {full_url}")
        logger.debug(f"[视频识别] 提示词长度: {len(full_prompt)} 字符")

        async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
            async with session.post(full_url, json=payload, headers=headers) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    logger.error(f"[视频识别] API 错误: {resp.status}, {error_text[:500]}")
                    return ""
                result = await resp.json()

        # ---- Parse the Gemini response (session already closed) ----
        logger.info(f"[视频识别] API 响应 keys: {list(result.keys()) if isinstance(result, dict) else type(result)}")

        if "error" in result:
            logger.error(f"[视频识别] API 返回错误: {result['error']}")
            return ""

        # promptFeedback carries safety-filter information about the request.
        if "promptFeedback" in result:
            feedback = result["promptFeedback"]
            block_reason = feedback.get("blockReason", "")
            if block_reason:
                logger.warning(f"[视频识别] 请求被阻止,原因: {block_reason}")
                logger.warning(f"[视频识别] 安全评级: {feedback.get('safetyRatings', [])}")
                return "抱歉,视频内容无法分析(内容策略限制)。"

        text_parts = []
        if "candidates" in result and result["candidates"]:
            logger.info(f"[视频识别] candidates 数量: {len(result['candidates'])}")
            for i, candidate in enumerate(result["candidates"]):
                finish_reason = candidate.get("finishReason", "")
                if finish_reason:
                    logger.info(f"[视频识别] candidate[{i}] finishReason: {finish_reason}")
                    if finish_reason == "SAFETY":
                        logger.warning(f"[视频识别] 内容被安全过滤: {candidate.get('safetyRatings', [])}")
                        return "抱歉,视频内容无法分析。"
                parts = candidate.get("content", {}).get("parts", [])
                logger.info(f"[视频识别] candidate[{i}] parts 数量: {len(parts)}")
                text_parts.extend(part["text"] for part in parts if "text" in part)
        else:
            logger.error(f"[视频识别] 响应中没有 candidates: {str(result)[:500]}")
            # Possibly caused by an oversized context; log token usage for diagnosis.
            if "usageMetadata" in result:
                usage = result["usageMetadata"]
                logger.warning(f"[视频识别] Token 使用: prompt={usage.get('promptTokenCount', 0)}, total={usage.get('totalTokenCount', 0)}")

        full_content = "".join(text_parts)
        logger.info(f"[视频识别] AI 响应完成,长度: {len(full_content)}")

        # No text at all: retry once with the context-free simplified request.
        if not full_content:
            logger.info("[视频识别] 尝试简化请求重试...")
            return await self._call_ai_api_with_video_simple(
                user_message or "请描述这个视频的内容",
                video_base64,
                video_config,
            )

        return self._sanitize_llm_output(full_content)

    except Exception as e:
        logger.exception(f"[视频识别] API 调用失败: {e}")
        return ""
|
||
|
||
async def _call_ai_api_with_video_simple(self, user_message: str, video_base64: str, video_config: dict) -> str:
    """Simplified video request without chat context, used as a fallback retry.

    Sends only ``user_message`` and the inline ``video/mp4`` data to the
    Gemini ``generateContent`` endpoint configured in ``video_config``. It
    uses the same proxy settings as the full request, so the fallback also
    works in setups that need a proxy.

    Returns:
        The first text part of the response (sanitized), or "" on failure.
    """
    try:
        api_url = video_config.get("api_url", "https://api.functen.cn/v1beta/models")
        api_key = video_config.get("api_key", self.config["api"]["api_key"])
        model = video_config.get("model", "gemini-3-pro-preview")
        full_url = f"{api_url}/{model}:generateContent"

        # Accept data URLs too: inline_data wants raw base64.
        if video_base64 and video_base64.startswith("data:"):
            video_base64 = video_base64.split(",", 1)[1]

        # Simplified request: only the question and the video.
        payload = {
            "contents": [
                {
                    "parts": [
                        {"text": user_message},
                        {
                            "inline_data": {
                                "mime_type": "video/mp4",
                                "data": video_base64,
                            }
                        },
                    ]
                }
            ],
            "generationConfig": {
                "maxOutputTokens": video_config.get("max_tokens", 8192)
            },
        }
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}",
        }
        timeout = aiohttp.ClientTimeout(total=video_config.get("timeout", 360))

        # Same proxy handling as the full video request.
        connector = None
        proxy_config = self.config.get("proxy", {})
        if proxy_config.get("enabled", False) and PROXY_SUPPORT:
            proxy_type = proxy_config.get("type", "socks5").upper()
            proxy_host = proxy_config.get("host", "127.0.0.1")
            proxy_port = proxy_config.get("port", 7890)
            try:
                connector = ProxyConnector.from_url(f"{proxy_type}://{proxy_host}:{proxy_port}")
            except Exception as e:
                logger.warning(f"[视频识别-简化] 代理配置失败: {e}")

        logger.info(f"[视频识别-简化] 调用 API: {full_url}")

        async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
            async with session.post(full_url, json=payload, headers=headers) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    logger.error(f"[视频识别-简化] API 错误: {resp.status}, {error_text[:300]}")
                    return ""
                result = await resp.json()

        if not isinstance(result, dict):
            logger.error(f"[视频识别-简化] 响应格式异常: {str(result)[:300]}")
            return ""

        logger.info(f"[视频识别-简化] API 响应 keys: {list(result.keys())}")

        # Return the first text part found in any candidate.
        for candidate in result.get("candidates") or []:
            for part in (candidate.get("content") or {}).get("parts", []):
                if "text" in part:
                    text = part["text"]
                    logger.info(f"[视频识别-简化] 成功,长度: {len(text)}")
                    return self._sanitize_llm_output(text)

        logger.error(f"[视频识别-简化] 仍然没有 candidates: {str(result)[:300]}")
        return ""

    except Exception as e:
        logger.error(f"[视频识别-简化] 失败: {e}")
        return ""
|
||
|
||
def _should_reply_quote(self, message: dict, title_text: str) -> bool:
|
||
"""判断是否应该回复引用消息"""
|
||
is_group = message.get("IsGroup", False)
|
||
|
||
# 检查群聊/私聊开关
|
||
if is_group and not self.config["behavior"]["reply_group"]:
|
||
return False
|
||
if not is_group and not self.config["behavior"]["reply_private"]:
|
||
return False
|
||
|
||
trigger_mode = self.config["behavior"]["trigger_mode"]
|
||
|
||
# all模式:回复所有消息
|
||
if trigger_mode == "all":
|
||
return True
|
||
|
||
# mention模式:检查是否@了机器人
|
||
if trigger_mode == "mention":
|
||
if is_group:
|
||
# 方式1:检查 Ats 字段(普通消息格式)
|
||
ats = message.get("Ats", [])
|
||
|
||
import tomllib
|
||
with open("main_config.toml", "rb") as f:
|
||
main_config = tomllib.load(f)
|
||
bot_wxid = main_config.get("Bot", {}).get("wxid", "")
|
||
bot_nickname = main_config.get("Bot", {}).get("nickname", "")
|
||
|
||
# 检查 Ats 列表
|
||
if bot_wxid and bot_wxid in ats:
|
||
return True
|
||
|
||
# 方式2:检查标题中是否包含 @机器人昵称(引用消息格式)
|
||
# 引用消息的 @ 信息在 title 中,如 "@瑞依 评价下"
|
||
if bot_nickname and f"@{bot_nickname}" in title_text:
|
||
logger.debug(f"引用消息标题中检测到 @{bot_nickname}")
|
||
return True
|
||
|
||
return False
|
||
else:
|
||
return True
|
||
|
||
# keyword模式:检查关键词
|
||
if trigger_mode == "keyword":
|
||
keywords = self.config["behavior"]["keywords"]
|
||
return any(kw in title_text for kw in keywords)
|
||
|
||
return False
|
||
|
||
async def _call_ai_api_with_image(
    self,
    user_message: str,
    image_base64: str,
    bot=None,
    from_wxid: str = None,
    chat_id: str = None,
    nickname: str = "",
    user_wxid: str = None,
    is_group: bool = False,
    *,
    append_user_message: bool = True,
    tool_query: str | None = None,
    disable_tools: bool = False,
) -> str | None:
    """Call the chat-completion API with an image attached.

    Builds the system prompt from the persona, current time, nickname,
    optional tool rules, and persistent/vector memories. Recent history is
    prepended. When ``append_user_message`` is true, the current user turn is
    appended with the image as an ``image_url`` part.

    Tool calls are recorded to the context and executed via
    ``_execute_tools_async_with_image``, as a background task when
    ``self._tool_async`` is set. This covers native function calls (limited
    to the tools offered in this request) and the legacy text-form
    search / image-generation calls that this method converts.

    Returns:
        The sanitized reply text; a fixed apology when the model emits a
        malformed tool-call format; or None when tool calls were dispatched
        instead of a text reply.

    Raises:
        Any exception raised from the API request onwards is logged and re-raised.
    """
    api_config = self.config["api"]

    if disable_tools:
        available_tool_names = set()
        tools = []
        logger.info("[图片] AutoReply 模式:已禁用工具调用")
    else:
        all_tools = self._collect_tools()
        # Every registered tool name, used as a fallback target for legacy conversions.
        available_tool_names = {
            t.get("function", {}).get("name", "")
            for t in (all_tools or [])
            if isinstance(t, dict) and t.get("function", {}).get("name")
        }
        selected_tools = await self._select_tools_for_message_async(all_tools, user_message=user_message, tool_query=tool_query)
        tools = self._prepare_tools_for_llm(selected_tools)
        logger.info(f"[图片] 收集到 {len(all_tools or [])} 个工具函数,本次启用 {len(tools)} 个")
        if tools:
            tool_names = [t["function"]["name"] for t in tools]
            logger.info(f"[图片] 本次启用工具: {tool_names}")

    # Names of the tools actually offered to the model in this request. Calls
    # outside this set are ignored so smart tool selection cannot be bypassed.
    allowed_tool_names = {
        t.get("function", {}).get("name", "")
        for t in (tools or [])
        if isinstance(t, dict) and t.get("function", {}).get("name")
    }

    # ---- System prompt ----
    now = datetime.now()
    weekday = ("星期一", "星期二", "星期三", "星期四", "星期五", "星期六", "星期日")[now.weekday()]
    time_str = now.strftime(f"%Y年%m月%d日 %H:%M:%S {weekday}")
    system_content = self.system_prompt + f"\n\n当前时间:{time_str}"
    if nickname:
        system_content += f"\n当前对话用户的昵称是:{nickname}"
    if self._tool_rule_prompt_enabled:
        system_content += self._build_tool_rules_prompt(tools)

    # Persistent memories (same as the text path).
    memory_chat_id = from_wxid if is_group else user_wxid
    if memory_chat_id:
        persistent_memories = self._get_persistent_memories(memory_chat_id)
        if persistent_memories:
            system_content += "\n\n【持久记忆】以下是用户要求你记住的重要信息:\n"
            for m in persistent_memories:
                mem_time = m['time'][:10] if m['time'] else ""
                system_content += f"- [{mem_time}] {m['nickname']}: {m['content']}\n"

    # Vector long-term memory retrieval (group chats only).
    if is_group and from_wxid and self._vector_memory_enabled:
        vector_mem = await self._retrieve_vector_memories(from_wxid, user_message)
        if vector_mem:
            system_content += vector_mem

    messages = [{"role": "system", "content": system_content}]

    # ---- History context ----
    if is_group and from_wxid:
        history_chat_id = self._get_group_history_chat_id(from_wxid, user_wxid or "")
        history = await self._load_history(history_chat_id)
        history = self._filter_history_by_window(history)
        max_context = self.config.get("history", {}).get("max_context", 50)
        recent_history = history[-max_context:] if len(history) > max_context else history
        self._append_group_history_messages(messages, recent_history)
    elif chat_id:
        memory_messages = self._get_memory_messages(chat_id)
        # The newest memory entry is skipped: presumably the current turn,
        # which callers add to memory before calling this method.
        if memory_messages and len(memory_messages) > 1:
            messages.extend(memory_messages[:-1])

    # ---- Current user turn (text + image) ----
    if append_user_message:
        if is_group and nickname:
            # Group chats use the structured per-sender format.
            sent_at = datetime.now().strftime("%Y-%m-%d %H:%M")
            text_value = self._format_user_message_content(nickname, user_message, sent_at, "image")
        else:
            text_value = user_message
        messages.append({
            "role": "user",
            "content": [
                {"type": "text", "text": f"【当前消息】\n{text_value}"},
                {"type": "image_url", "image_url": {"url": image_base64}},
            ],
        })

    def _spawn(coro) -> None:
        """Run ``coro`` as a background task while holding a strong reference to it."""
        # The event loop only keeps weak references to tasks. Without this set a
        # background tool run could be garbage-collected before it finishes.
        pending = getattr(self, "_background_tool_tasks", None)
        if pending is None:
            pending = set()
            self._background_tool_tasks = pending
        task = asyncio.create_task(coro)
        pending.add(task)
        task.add_done_callback(pending.discard)

    async def _dispatch_tool_calls(tool_calls: list) -> None:
        """Record ``tool_calls`` to the context, then execute them (background or inline)."""
        try:
            await self._record_tool_calls_to_context(
                tool_calls,
                from_wxid=from_wxid,
                chat_id=chat_id,
                is_group=is_group,
                user_wxid=user_wxid,
            )
        except Exception as e:
            logger.debug(f"[图片] 记录工具调用到上下文失败: {e}")
        execution = self._execute_tools_async_with_image(
            tool_calls, bot, from_wxid, chat_id,
            user_wxid, nickname, is_group, messages, image_base64,
        )
        if self._tool_async:
            _spawn(execution)
        else:
            await execution

    try:
        if tools:
            logger.debug(f"[图片] 已将 {len(tools)} 个工具添加到请求中")
        full_content, tool_calls_data = await self._send_dialog_api_request(
            api_config,
            messages,
            tools,
            request_tag="[图片]",
            prefer_stream=True,
            max_tokens=api_config.get("max_tokens", 4096),
        )

        if tool_calls_data:
            # Drop tool calls the model hallucinated for tools not offered in this request.
            unsupported = []
            filtered = []
            for tc in tool_calls_data:
                fn = (tc or {}).get("function", {}).get("name", "")
                if not fn:
                    continue
                if fn in allowed_tool_names:
                    filtered.append(tc)
                else:
                    unsupported.append(fn)
            if unsupported:
                logger.warning(f"[图片] 检测到未提供/未知的工具调用,已忽略: {unsupported}")
            tool_calls_data = filtered

        if tool_calls_data:
            # The "working on it" notice was already sent during streaming.
            logger.info(f"[图片] 启动工具执行,共 {len(tool_calls_data)} 个工具")
            await _dispatch_tool_calls(tool_calls_data)
            return None

        if full_content:
            # Compatibility: search tool call written as text.
            legacy = self._extract_legacy_text_search_tool_call(full_content)
            if legacy:
                legacy_tool, legacy_args = legacy
                preferred = None
                if legacy_tool in allowed_tool_names:
                    preferred = legacy_tool
                elif "tavily_web_search" in allowed_tool_names:
                    preferred = "tavily_web_search"
                elif "web_search" in allowed_tool_names:
                    preferred = "web_search"
                elif self._looks_like_info_query(user_message):
                    if "tavily_web_search" in available_tool_names:
                        preferred = "tavily_web_search"
                    elif "web_search" in available_tool_names:
                        preferred = "web_search"

                if preferred:
                    logger.warning(f"[图片] 检测到文本形式工具调用,已转换为 Function Calling: {preferred}")
                    try:
                        if bot and from_wxid:
                            await bot.send_text(from_wxid, "我帮你查一下,稍等。")
                    except Exception as e:
                        logger.debug(f"[图片] 发送查询提示失败: {e}")
                    await _dispatch_tool_calls([
                        {
                            "id": f"legacy_{uuid.uuid4().hex[:8]}",
                            "type": "function",
                            "function": {
                                "name": preferred,
                                "arguments": json.dumps(legacy_args, ensure_ascii=False),
                            },
                        }
                    ])
                    return None

            # Compatibility: image-generation tool call written as text JSON.
            legacy_img = self._extract_legacy_text_image_tool_call(full_content)
            if legacy_img:
                legacy_tool, legacy_args = legacy_img
                tools_cfg = (self.config or {}).get("tools", {})
                loose_image_tool = tools_cfg.get("loose_image_tool", True)
                preferred = self._resolve_image_tool_alias(
                    legacy_tool,
                    allowed_tool_names,
                    available_tool_names,
                    loose_image_tool,
                )
                if preferred:
                    logger.warning(f"[图片] 检测到文本绘图工具调用,已转换为 Function Calling: {preferred}")
                    await _dispatch_tool_calls([
                        {
                            "id": f"legacy_img_{uuid.uuid4().hex[:8]}",
                            "type": "function",
                            "function": {
                                "name": preferred,
                                "arguments": json.dumps(legacy_args, ensure_ascii=False),
                            },
                        }
                    ])
                    return None

        content = full_content or ""
        # Intercept malformed tool-call formats instead of echoing them to the user.
        if "<tool_code>" in content or re.search(
            r"(?i)\bprint\s*\(\s*(draw_image|generate_image|nano_ai_image_generation|flow2_ai_image_generation|jimeng_ai_image_generation|kiira2_ai_image_generation)\s*\(",
            content,
        ):
            logger.warning("检测到模型输出了错误的工具调用格式,拦截并返回提示")
            return "抱歉,我遇到了一些技术问题,请重新描述一下你的需求~"

        return self._sanitize_llm_output(content)

    except Exception as e:
        logger.error(f"调用AI API失败: {e}")
        raise
|
||
|
||
async def _send_chat_records(self, bot, from_wxid: str, title: str, content: str):
    """Send ``content`` to ``from_wxid`` as a WeChat "chat records" card (appmsg type 19).

    The text is split into chunks of at most 800 characters via
    ``_split_chat_record_content``. Each chunk becomes one ``dataitem`` of a
    ``recordinfo`` document attributed to "AI助手". That document is embedded
    (as CDATA) in an ``<appmsg>`` and sent with ``bot.send_xml``.

    Args:
        bot: Client exposing an awaitable ``send_xml(to_wxid, xml)``.
        from_wxid: Destination conversation id; an id ending in ``@chatroom``
            is treated as a group chat.
        title: Card title, also used as the card description.
        content: Text to pack into the record items.

    Any exception is logged and swallowed (best-effort send).
    """
    try:
        import hashlib
        from xml.sax.saxutils import escape as xml_escape

        max_length = 800
        is_group = from_wxid.endswith("@chatroom")
        chatroom_flag = "1" if is_group else "0"
        content_parts = self._split_chat_record_content(content, max_length)

        recordinfo = ET.Element("recordinfo")
        ET.SubElement(recordinfo, "info").text = title
        ET.SubElement(recordinfo, "isChatRoom").text = chatroom_flag
        datalist = ET.SubElement(recordinfo, "datalist")
        datalist.set("count", str(len(content_parts)))
        ET.SubElement(recordinfo, "desc").text = title
        ET.SubElement(recordinfo, "fromscene").text = "3"

        # Values that do not depend on the chunk index are computed once.
        source_hash = hashlib.sha256(from_wxid.encode("utf-8")).hexdigest()
        now = time.time()
        now_ms = int(now * 1000)
        now_s = int(now)
        src_local_id = str((now_ms % 90000) + 10000)

        for i, part in enumerate(content_parts):
            # Items get consecutive creation timestamps ending one second before now.
            create_time = now_s - len(content_parts) + i

            di = ET.SubElement(datalist, "dataitem")
            di.set("datatype", "1")
            di.set("dataid", uuid.uuid4().hex)

            ET.SubElement(di, "srcMsgLocalid").text = src_local_id
            ET.SubElement(di, "sourcetime").text = time.strftime("%Y-%m-%d %H:%M", time.localtime(create_time))
            ET.SubElement(di, "fromnewmsgid").text = str(now_ms + i)
            ET.SubElement(di, "srcMsgCreateTime").text = str(create_time)
            ET.SubElement(di, "sourcename").text = "AI助手"
            ET.SubElement(di, "sourceheadurl").text = ""
            ET.SubElement(di, "datatitle").text = part
            ET.SubElement(di, "datadesc").text = part
            ET.SubElement(di, "datafmt").text = "text"
            ET.SubElement(di, "ischatroom").text = chatroom_flag

            dataitemsource = ET.SubElement(di, "dataitemsource")
            ET.SubElement(dataitemsource, "hashusername").text = source_hash

        # ElementTree escapes '&', '<' and '>' in text content, so record_xml
        # cannot contain the ']]>' sequence that would end the CDATA section.
        record_xml = ET.tostring(recordinfo, encoding="unicode")

        # The <appmsg> wrapper is assembled by hand, so the title must be
        # escaped here; otherwise '<' or '&' in it yields malformed XML.
        safe_title = xml_escape(title)
        appmsg_xml = "".join([
            "<appmsg appid=\"\" sdkver=\"0\">",
            f"<title>{safe_title}</title>",
            f"<des>{safe_title}</des>",
            "<type>19</type>",
            "<url>https://support.weixin.qq.com/cgi-bin/mmsupport-bin/readtemplate?t=page/favorite_record__w_unsupport</url>",
            "<appattach><cdnthumbaeskey></cdnthumbaeskey><aeskey></aeskey></appattach>",
            f"<recorditem><![CDATA[{record_xml}]]></recorditem>",
            "<percent>0</percent>",
            "</appmsg>",
        ])

        # Send the XML message through the HTTP API.
        await bot.send_xml(from_wxid, appmsg_xml)
        logger.success(f"已发送聊天记录: {title}")

    except Exception as e:
        logger.error(f"发送聊天记录失败: {e}")

@staticmethod
def _split_chat_record_content(content: str, max_length: int = 800) -> list[str]:
    """Split ``content`` into chunks of at most ``max_length`` characters.

    Content already within the limit is returned unchanged as a single chunk.
    Otherwise lines are packed greedily, and any single line longer than
    ``max_length`` is hard-wrapped into ``max_length``-sized pieces. Packed
    chunks are stripped, and blank chunks are dropped.

    Raises:
        ValueError: if ``max_length`` is not positive (wrapping would never end).
    """
    if max_length <= 0:
        raise ValueError("max_length must be positive")
    if len(content) <= max_length:
        return [content]

    parts: list[str] = []
    current = ""
    for line in content.split("\n"):
        # Hard-wrap oversized lines so that no emitted chunk exceeds the limit.
        while len(line) > max_length:
            if current.strip():
                parts.append(current.strip())
            current = ""
            piece = line[:max_length]
            if piece.strip():
                parts.append(piece)
            line = line[max_length:]
        # Flush the buffer when appending this line would overflow it.
        if len(current) + len(line) + 1 > max_length:
            if current.strip():
                parts.append(current.strip())
            current = ""
        current += line + "\n"

    if current.strip():
        parts.append(current.strip())
    return parts
|
||
|
||
async def _process_image_to_history(self, bot, message: dict, content: str) -> bool:
    """Record an image/emoji from a group chat in history and queue its description.

    Parses the message XML (``content``) for an ``<img>`` or ``<emoji>``
    element. It then inserts a ``[图片: 处理中...]`` placeholder into the group
    history and enqueues a job on ``self.image_desc_queue``. The queue is
    consumed by ``_image_desc_worker``, which later overwrites the placeholder
    (by ``placeholder_id``) with a generated description.

    Private chats, a disabled ``image_description`` config, and messages that
    lack the needed download fields are ignored. Any exception while parsing
    or queueing is logged.

    Returns:
        Always ``True`` (presumably tells the dispatcher to keep propagating
        the event; confirm against the handler framework).
    """
    from_wxid = message.get("FromWxid", "")
    sender_wxid = message.get("SenderWxid", "")
    is_group = message.get("IsGroup", False)
    user_wxid = sender_wxid if is_group else from_wxid

    # Only group chats are handled.
    if not is_group:
        return True

    # Tolerate a missing config, as other code in this plugin does with
    # `(self.config or {})`, instead of raising AttributeError here.
    image_desc_config = (self.config or {}).get("image_description", {})
    if not image_desc_config.get("enabled", True):
        return True

    try:
        root = ET.fromstring(content)

        # Image messages carry <img>; emoji messages carry <emoji>.
        img = root.find(".//img")
        if img is None:
            img = root.find(".//emoji")
        if img is None:
            return True

        cdnbigimgurl = img.get("cdnbigimgurl", "") or img.get("cdnurl", "")
        aeskey = img.get("aeskey", "")
        # Emoji may have a cdnurl but no aeskey.
        is_emoji = img.tag == "emoji"

        if not cdnbigimgurl:
            return True
        # Image messages require an aeskey; emoji do not.
        if not is_emoji and not aeskey:
            return True

        nickname = await self._get_user_display_label(bot, from_wxid, user_wxid, is_group)
        history_chat_id = self._get_group_history_chat_id(from_wxid, user_wxid)

        # Insert the placeholder immediately (presumably so the entry keeps its
        # position in history); the worker replaces it by placeholder_id.
        placeholder_id = str(uuid.uuid4())
        await self._add_to_history_with_id(history_chat_id, nickname, "[图片: 处理中...]", placeholder_id)
        logger.info(f"已插入图片占位符: {placeholder_id}")

        # Enqueue the job; this coroutine does not wait for the description.
        task = {
            "bot": bot,
            "history_chat_id": history_chat_id,
            "nickname": nickname,
            "cdnbigimgurl": cdnbigimgurl,
            "aeskey": aeskey,
            "is_emoji": is_emoji,
            "placeholder_id": placeholder_id,
            "config": image_desc_config,
            "message": message,  # full message object, used by the image download helper
        }
        await self.image_desc_queue.put(task)
        logger.info(f"图片描述任务已加入队列,当前队列长度: {self.image_desc_queue.qsize()}")

        return True

    except Exception as e:
        logger.error(f"处理图片消息失败: {e}")
        return True
|
||
|
||
async def _image_desc_worker(self):
    """Long-running consumer of ``self.image_desc_queue``.

    Takes one job dict at a time and forwards its fields to
    ``_generate_and_update_image_description``. A failing job is logged and
    the loop moves on. ``task_done()`` is called for every job taken off the
    queue, including on failure or cancellation. Cancellation while waiting
    for a job ends the worker quietly; cancellation during a job is re-raised.
    """
    while True:
        try:
            job = await self.image_desc_queue.get()
        except asyncio.CancelledError:
            logger.info("图片描述工作协程收到取消信号,退出")
            return

        try:
            describe = self._generate_and_update_image_description
            await describe(
                job["bot"],
                job["history_chat_id"],
                job["nickname"],
                job["cdnbigimgurl"],
                job["aeskey"],
                job["is_emoji"],
                job["placeholder_id"],
                job["config"],
                job.get("message"),
            )
        except asyncio.CancelledError:
            raise
        except Exception as exc:
            logger.error(f"图片描述工作协程异常: {exc}")
        finally:
            # asyncio.Queue.task_done raises ValueError when called more times
            # than items were queued; that case is deliberately ignored.
            try:
                self.image_desc_queue.task_done()
            except ValueError:
                pass
|
||
|
||
async def _generate_and_update_image_description(self, bot, history_chat_id: str, nickname: str,
                                                 cdnbigimgurl: str, aeskey: str, is_emoji: bool,
                                                 placeholder_id: str, image_desc_config: dict, message: dict = None):
    """Download an image/emoji, describe it with the AI, and fill in its history placeholder.

    Emoji are downloaded from ``cdnbigimgurl``. Regular images are downloaded
    through ``_download_and_encode_image(bot, message)`` and need ``message``.
    On success the placeholder entry ``placeholder_id`` in ``history_chat_id``
    becomes ``[图片: <description>]``. If the download fails, the description
    is missing or sanitizes to blank, or any other error occurs, it becomes
    ``[图片]``.

    Args:
        image_desc_config: ``image_description`` config section; its optional
            ``prompt`` key overrides the default description prompt.
        message: Full incoming message object; required for non-emoji images.

    Raises:
        asyncio.CancelledError: re-raised unchanged. NOTE(review): in that case
            the placeholder is left as-is and is not replaced.
    """
    try:
        # Download and base64-encode the image or emoji.
        if is_emoji:
            image_base64 = await self._download_emoji_and_encode(cdnbigimgurl)
        elif message:
            # Preferred path: the download helper needs the full message object.
            image_base64 = await self._download_and_encode_image(bot, message)
        else:
            logger.warning("缺少 message 对象,图片下载可能失败")
            image_base64 = ""

        if not image_base64:
            logger.warning(f"{'表情包' if is_emoji else '图片'}下载失败")
            await self._update_history_by_id(history_chat_id, placeholder_id, "[图片]")
            return

        description_prompt = image_desc_config.get("prompt", "请用一句话简洁地描述这张图片的主要内容。")
        description = await self._generate_image_description(image_base64, description_prompt, image_desc_config)

        cleaned_description = self._sanitize_llm_output(description) if description else ""

        # Fall back to the bare tag when sanitizing leaves nothing useful,
        # instead of recording an empty "[图片: ]".
        if cleaned_description and cleaned_description.strip():
            await self._update_history_by_id(history_chat_id, placeholder_id, f"[图片: {cleaned_description}]")
            logger.success(f"已更新图片描述: {nickname} - {cleaned_description[:30]}...")
        else:
            await self._update_history_by_id(history_chat_id, placeholder_id, "[图片]")
            logger.warning("图片描述生成失败")

    except asyncio.CancelledError:
        raise
    except Exception as e:
        logger.error(f"异步生成图片描述失败: {e}")
        await self._update_history_by_id(history_chat_id, placeholder_id, "[图片]")
|
||
|
||
@on_image_message(priority=15)
async def handle_image_message(self, bot, message: dict):
    """Handle a directly sent image: record it in history without an AI reply.

    Logs the call, then delegates to ``_process_image_to_history`` with the
    message's ``Content`` (defaulting to ``""``) and returns its result.
    """
    logger.info("AIChat: handle_image_message 被调用")
    raw_xml = message.get("Content", "")
    result = await self._process_image_to_history(bot, message, raw_xml)
    return result
|
||
|
||
@on_emoji_message(priority=15)
async def handle_emoji_message(self, bot, message: dict):
    """Handle an emoji (sticker) message: record it in history without an AI reply.

    Logs the call, then delegates to ``_process_image_to_history`` with the
    message's ``Content`` (defaulting to ``""``) and returns its result.
    """
    logger.info("AIChat: handle_emoji_message 被调用")
    raw_xml = message.get("Content", "")
    result = await self._process_image_to_history(bot, message, raw_xml)
    return result
|