Files
WeChatHookBot/plugins/AIChat/main.py

6069 lines
272 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
AI 聊天插件
支持自定义模型、API 和人设
支持 Redis 存储对话历史和限流
"""
import asyncio
import tomllib
import aiohttp
import json
import re
import time
import copy
from contextlib import asynccontextmanager
from pathlib import Path
from datetime import datetime
from loguru import logger
from utils.plugin_base import PluginBase
from utils.decorators import on_text_message, on_quote_message, on_image_message, on_emoji_message
from utils.redis_cache import get_cache
from utils.image_processor import ImageProcessor, MediaConfig
from utils.tool_executor import ToolExecutor
from utils.tool_registry import get_tool_registry
from utils.member_info_service import get_member_service
import xml.etree.ElementTree as ET
import base64
import uuid
# Optional import: SOCKS proxy support via aiohttp_socks.
# PROXY_SUPPORT records whether the import succeeded.
try:
from aiohttp_socks import ProxyConnector
PROXY_SUPPORT = True
except ImportError:
PROXY_SUPPORT = False
logger.warning("aiohttp_socks 未安装,代理功能将不可用")
# Optional import: Chroma vector database.
# CHROMA_SUPPORT records whether chromadb is importable.
try:
import chromadb
from chromadb.api.types import EmbeddingFunction, Documents, Embeddings
CHROMA_SUPPORT = True
except ImportError:
CHROMA_SUPPORT = False
if CHROMA_SUPPORT:
    class SiliconFlowEmbedding(EmbeddingFunction):
        """Custom embedding function that calls the SiliconFlow API over HTTP."""

        def __init__(self, api_url: str, api_key: str, model: str):
            """Store the endpoint URL, API key and model name used for each request."""
            self._api_url = api_url
            self._api_key = api_key
            self._model = model

        def __call__(self, input: Documents) -> Embeddings:
            """POST the documents to the embedding endpoint and return one vector per item.

            Raises whatever ``raise_for_status`` raises on a non-success response.
            """
            # httpx is imported lazily, only when an embedding is requested.
            import httpx as _httpx

            request_headers = {
                "Authorization": f"Bearer {self._api_key}",
                "Content-Type": "application/json",
            }
            request_body = {"model": self._model, "input": input}
            response = _httpx.post(
                self._api_url,
                headers=request_headers,
                json=request_body,
                timeout=30,
            )
            response.raise_for_status()
            payload = response.json()
            vectors = []
            for entry in payload["data"]:
                vectors.append(entry["embedding"])
            return vectors
class AIChat(PluginBase):
"""AI 聊天插件"""
# 插件元数据
description = "AI 聊天插件,支持自定义模型和人设"
author = "ShiHao"
version = "1.0.0"
def __init__(self):
    """Create the plugin's in-memory state; config and persona are loaded in async_init."""
    super().__init__()

    # --- configuration / persona ---
    self.config = None
    self.system_prompt = ""

    # --- conversation state ---
    self.memory = {}  # per-session memory: {chat_id: [messages]}
    self.history_dir = None  # history directory
    self.history_locks = {}  # one lock per session
    self._reply_locks = {}  # one reply lock per session (serialized replies)

    # --- behaviour switches (overwritten from config in async_init) ---
    self._serial_reply = False
    self._tool_async = True
    self._tool_followup_ai_reply = True
    self._tool_rule_prompt_enabled = True

    # --- image description pipeline ---
    self.image_desc_queue = asyncio.Queue()  # queue of image-description jobs
    self.image_desc_workers = []  # worker tasks

    # --- storage ---
    self.persistent_memory_db = None  # path of the persistent-memory database
    self.store = None  # ContextStore instance (unified storage)

    # --- chatroom member (group card name) cache ---
    self._chatroom_member_cache = {}  # {chatroom_id: (ts, {wxid: display_name})}
    self._chatroom_member_cache_locks = {}  # {chatroom_id: asyncio.Lock}
    # Group card names are cached for one hour to reduce protocol API calls.
    self._chatroom_member_cache_ttl_seconds = 3600

    self._image_processor = None  # ImageProcessor instance

    # --- vector long-term memory (Chroma) ---
    self._vector_memory_enabled = False
    self._chroma_collection = None
    self._vector_watermarks = {}  # {chatroom_id: str} timestamp of last summarized message
    self._vector_tasks = {}  # {chatroom_id: asyncio.Task} background summary tasks
    self._watermark_file = None
async def async_init(self):
"""Asynchronous plugin initialization.

Loads config.toml, the persona prompt, the history directory, the image
description workers, the ContextStore, the optional Chroma vector memory
and the ImageProcessor.
"""
# Load configuration from config.toml next to this file
config_path = Path(__file__).parent / "config.toml"
with open(config_path, "rb") as f:
self.config = tomllib.load(f)
behavior_config = self.config.get("behavior", {})
self._serial_reply = bool(behavior_config.get("serial_reply", False))
tools_config = self.config.get("tools", {})
self._tool_async = bool(tools_config.get("async_execute", True))
self._tool_followup_ai_reply = bool(tools_config.get("followup_ai_reply", True))
self._tool_rule_prompt_enabled = bool(tools_config.get("rule_prompt_enabled", True))
# Serialized replies force asynchronous tool execution off
if self._serial_reply:
self._tool_async = False
logger.info(
f"AIChat 串行回复: {self._serial_reply}, 工具异步执行: {self._tool_async}, "
f"工具后AI总结: {self._tool_followup_ai_reply}, 工具规则注入: {self._tool_rule_prompt_enabled}"
)
# Load the persona (system prompt); fall back to a built-in default if the file is missing
prompt_file = self.config["prompt"]["system_prompt_file"]
prompt_path = Path(__file__).parent / "prompts" / prompt_file
if prompt_path.exists():
with open(prompt_path, "r", encoding="utf-8") as f:
self.system_prompt = f.read().strip()
logger.success(f"已加载人设: {prompt_file}")
else:
logger.warning(f"人设文件不存在: {prompt_file},使用默认人设")
self.system_prompt = "你是一个友好的 AI 助手。"
# Inspect the proxy configuration (only logged here)
proxy_config = self.config.get("proxy", {})
if proxy_config.get("enabled", False):
proxy_type = proxy_config.get("type", "socks5")
proxy_host = proxy_config.get("host", "127.0.0.1")
proxy_port = proxy_config.get("port", 7890)
logger.info(f"AI 聊天插件已启用代理: {proxy_type}://{proxy_host}:{proxy_port}")
# Initialize the history directory
history_config = self.config.get("history", {})
if history_config.get("enabled", True):
history_dir_name = history_config.get("history_dir", "history")
self.history_dir = Path(__file__).parent / history_dir_name
self.history_dir.mkdir(exist_ok=True)
logger.info(f"历史记录目录: {self.history_dir}")
# Start the image-description worker coroutines (concurrency: 2)
for i in range(2):
worker = asyncio.create_task(self._image_desc_worker())
self.image_desc_workers.append(worker)
logger.info("已启动 2 个图片描述工作协程")
# Initialize the persistent-memory database and the unified store
from utils.context_store import ContextStore
db_dir = Path(__file__).parent / "data"
db_dir.mkdir(exist_ok=True)
self.persistent_memory_db = db_dir / "persistent_memory.db"
self.store = ContextStore(
self.config,
self.history_dir,
self.memory,
self.history_locks,
self.persistent_memory_db,
)
self.store.init_persistent_memory_db()
# Initialize vector long-term memory (Chroma); requires config enabled AND chromadb importable
vm_config = self.config.get("vector_memory", {})
if vm_config.get("enabled", False) and CHROMA_SUPPORT:
try:
chroma_path = Path(__file__).parent / vm_config.get("chroma_db_path", "data/chroma_db")
chroma_path.mkdir(parents=True, exist_ok=True)
embedding_fn = SiliconFlowEmbedding(
api_url=vm_config.get("embedding_url", ""),
api_key=vm_config.get("embedding_api_key", ""),
model=vm_config.get("embedding_model", "BAAI/bge-m3"),
)
chroma_client = chromadb.PersistentClient(path=str(chroma_path))
self._chroma_collection = chroma_client.get_or_create_collection(
name="group_chat_summaries",
embedding_function=embedding_fn,
metadata={"hnsw:space": "cosine"},
)
self._watermark_file = Path(__file__).parent / "data" / "vector_watermarks.json"
self._load_watermarks()
self._vector_memory_enabled = True
logger.success(f"向量记忆已启用Chroma 路径: {chroma_path}")
except Exception as e:
# Any failure leaves vector memory disabled instead of aborting plugin load
logger.error(f"向量记忆初始化失败: {e}")
self._vector_memory_enabled = False
elif vm_config.get("enabled", False) and not CHROMA_SUPPORT:
logger.warning("向量记忆已启用但 chromadb 未安装,请 pip install chromadb")
# Initialize the ImageProcessor (image / emoji / video processor)
temp_dir = Path(__file__).parent / "temp"
temp_dir.mkdir(exist_ok=True)
media_config = MediaConfig.from_dict(self.config)
self._image_processor = ImageProcessor(media_config, temp_dir)
logger.debug("ImageProcessor 已初始化")
logger.info(f"AI 聊天插件已加载,模型: {self.config['api']['model']}")
async def on_disable(self):
"""Called when the plugin is disabled; cleans up background tasks and queues."""
await super().on_disable()
# Cancel the image-description workers so they do not stack up after a reload
if self.image_desc_workers:
for worker in self.image_desc_workers:
worker.cancel()
# return_exceptions=True keeps the resulting CancelledError from propagating
await asyncio.gather(*self.image_desc_workers, return_exceptions=True)
self.image_desc_workers.clear()
# Drain the image-description queue, then replace it with a fresh one
try:
while self.image_desc_queue and not self.image_desc_queue.empty():
self.image_desc_queue.get_nowait()
self.image_desc_queue.task_done()
except Exception:
pass
self.image_desc_queue = asyncio.Queue()
logger.info("AIChat 已清理后台图片描述任务")
# Cancel background vector-summary tasks and persist the watermarks
if self._vector_tasks:
for task in self._vector_tasks.values():
if not task.done():
task.cancel()
await asyncio.gather(*self._vector_tasks.values(), return_exceptions=True)
self._vector_tasks.clear()
if self._vector_memory_enabled:
self._save_watermarks()
logger.info("AIChat 已清理向量摘要后台任务")
def _get_reply_lock(self, chat_id: str) -> asyncio.Lock:
lock = self._reply_locks.get(chat_id)
if lock is None:
lock = asyncio.Lock()
self._reply_locks[chat_id] = lock
return lock
@asynccontextmanager
async def _reply_lock_context(self, chat_id: str):
if not self._serial_reply or not chat_id:
yield
return
lock = self._get_reply_lock(chat_id)
if lock.locked():
logger.debug(f"AI 回复排队中: chat_id={chat_id}")
async with lock:
yield
def _add_persistent_memory(self, chat_id: str, chat_type: str, user_wxid: str,
user_nickname: str, content: str) -> int:
"""添加持久记忆返回记忆ID委托 ContextStore"""
if not self.store:
return -1
return self.store.add_persistent_memory(chat_id, chat_type, user_wxid, user_nickname, content)
def _get_persistent_memories(self, chat_id: str) -> list:
"""获取指定会话的所有持久记忆(委托 ContextStore"""
if not self.store:
return []
return self.store.get_persistent_memories(chat_id)
def _delete_persistent_memory(self, chat_id: str, memory_id: int) -> bool:
"""删除指定的持久记忆(委托 ContextStore"""
if not self.store:
return False
return self.store.delete_persistent_memory(chat_id, memory_id)
def _clear_persistent_memories(self, chat_id: str) -> int:
"""清空指定会话的所有持久记忆(委托 ContextStore"""
if not self.store:
return 0
return self.store.clear_persistent_memories(chat_id)
# ==================== 向量长期记忆Chroma====================
def _load_watermarks(self):
"""从文件加载向量摘要水位线"""
if self._watermark_file and self._watermark_file.exists():
try:
with open(self._watermark_file, "r", encoding="utf-8") as f:
self._vector_watermarks = json.load(f)
except Exception as e:
logger.warning(f"加载向量水位线失败: {e}")
self._vector_watermarks = {}
def _save_watermarks(self):
"""保存向量摘要水位线到文件"""
if not self._watermark_file:
return
try:
self._watermark_file.parent.mkdir(parents=True, exist_ok=True)
temp = Path(str(self._watermark_file) + ".tmp")
with open(temp, "w", encoding="utf-8") as f:
json.dump(self._vector_watermarks, f, ensure_ascii=False, indent=2)
temp.replace(self._watermark_file)
except Exception as e:
logger.warning(f"保存向量水位线失败: {e}")
async def _maybe_trigger_summarize(self, chatroom_id: str):
    """Start a background vector summary once enough new messages have accumulated.

    The watermark is the timestamp of the last summarized message, so it is
    unaffected by history trimming. Nothing happens while vector memory is
    disabled or a summary task for this chatroom is still running.
    """
    if not self._vector_memory_enabled:
        return
    running = self._vector_tasks.get(chatroom_id)
    if running and not running.done():
        return
    try:
        history_chat_id = self._get_group_history_chat_id(chatroom_id)
        history = await self._load_history(history_chat_id)
        every = self.config.get("vector_memory", {}).get("summarize_every", 80)
        watermark_ts = self._vector_watermarks.get(chatroom_id, "")
        # Keep only the messages newer than the watermark (string comparison).
        if watermark_ts:
            threshold = str(watermark_ts)
            new_msgs = [m for m in history if str(m.get("timestamp", "")) > threshold]
        else:
            new_msgs = list(history)
        logger.info(f"[VectorMemory] 检查: chatroom={chatroom_id}, history={len(history)}, new={len(new_msgs)}, watermark_ts={watermark_ts}, every={every}")
        if len(new_msgs) < every:
            return
        # Summarize the oldest `every` new messages.
        batch = new_msgs[:every]
        last_ts = str(batch[-1].get("timestamp", ""))
        # Optimistic update: advance the watermark before the summary finishes.
        self._vector_watermarks[chatroom_id] = last_ts
        self._save_watermarks()
        self._vector_tasks[chatroom_id] = asyncio.create_task(
            self._do_summarize_and_store(chatroom_id, batch, last_ts)
        )
    except Exception as e:
        logger.warning(f"[VectorMemory] 触发检查失败: {e}")
async def _do_summarize_and_store(self, chatroom_id: str, messages: list, watermark_ts: str):
    """Background task: summarize a batch with the LLM and store it in Chroma.

    Summaries shorter than 10 characters (after stripping) are discarded.
    Every exception is logged rather than raised.
    """
    try:
        logger.info(f"[VectorMemory] 触发后台摘要: {chatroom_id}, 消息数={len(messages)}")
        summary = await self._call_summary_llm(self._format_messages_for_summary(messages))
        if not summary or len(summary.strip()) < 10:
            logger.warning("[VectorMemory] 摘要结果过短,跳过本批")
            return
        if messages:
            ts_start = str(messages[0].get("timestamp", ""))
            ts_end = str(messages[-1].get("timestamp", ""))
        else:
            ts_start = ts_end = ""
        # ':' and '.' in the watermark are replaced with '-' for the document ID.
        safe_ts = watermark_ts.replace(":", "-").replace(".", "-")
        doc_id = f"{chatroom_id}_{safe_ts}"
        metadata = {
            "chatroom_id": chatroom_id,
            "ts_start": ts_start,
            "ts_end": ts_end,
            "watermark_ts": watermark_ts,
        }
        self._chroma_collection.add(
            ids=[doc_id],
            documents=[summary],
            metadatas=[metadata],
        )
        logger.success(f"[VectorMemory] 摘要已存储: {doc_id}, 长度={len(summary)}")
    except Exception as e:
        logger.error(f"[VectorMemory] 摘要存储失败: {e}")
def _format_messages_for_summary(self, messages: list) -> str:
"""将历史消息格式化为文本块供 LLM 摘要"""
lines = []
for m in messages:
nick = m.get("nickname", "未知")
content = m.get("content", "")
ts = m.get("timestamp", "")
if ts:
try:
from datetime import datetime
dt = datetime.fromtimestamp(float(ts))
ts_str = dt.strftime("%m-%d %H:%M")
except Exception:
ts_str = str(ts)[:16]
else:
ts_str = ""
prefix = f"[{ts_str}] " if ts_str else ""
lines.append(f"{prefix}{nick}: {content}")
return "\n".join(lines)
async def _call_summary_llm(self, text_block: str) -> str:
    """Ask the LLM to summarize a block of group chat and return the summary.

    Uses ``vector_memory.summary_model`` when set, otherwise the main API
    model. Returns an empty string on any failure.
    """
    api_config = self.config.get("api", {})
    vm_config = self.config.get("vector_memory", {})
    model = vm_config.get("summary_model", "") or api_config.get("model", "")
    max_tokens = vm_config.get("summary_max_tokens", 2048)
    url = api_config.get("url", "")
    api_key = api_config.get("api_key", "")
    prompt = (
        "请将以下群聊记录总结为一段简洁的摘要,保留关键话题、重要观点、"
        "参与者的核心发言和结论。摘要应便于日后检索,不要遗漏重要信息。\n\n"
        f"--- 群聊记录 ---\n{text_block}\n--- 结束 ---\n\n请输出摘要:"
    )
    request_body = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
        "temperature": 0.3,
    }
    request_headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    timeout_sec = api_config.get("timeout", 120)
    try:
        async with aiohttp.ClientSession() as session:
            request_timeout = aiohttp.ClientTimeout(total=timeout_sec)
            async with session.post(
                url, json=request_body, headers=request_headers, timeout=request_timeout
            ) as resp:
                resp.raise_for_status()
                data = await resp.json()
                first_choice = data["choices"][0]
                return first_choice["message"]["content"].strip()
    except Exception as e:
        logger.error(f"[VectorMemory] LLM 摘要调用失败: {e}")
        return ""
async def _retrieve_vector_memories(self, chatroom_id: str, query_text: str) -> str:
"""Retrieve stored summaries from Chroma that relate to the current message.

Returns a prompt fragment containing the selected summaries, or an empty
string when vector memory is disabled, nothing qualifies, or the query fails.
"""
if not self._vector_memory_enabled or not self._chroma_collection:
return ""
vm_config = self.config.get("vector_memory", {})
top_k = vm_config.get("retrieval_top_k", 3)
min_score = vm_config.get("retrieval_min_score", 0.35)
max_chars = vm_config.get("max_inject_chars", 2000)
log_candidates = bool(vm_config.get("retrieval_log_candidates", True))
log_injected = bool(vm_config.get("retrieval_log_injected", True))
try:
log_max_chars = int(vm_config.get("retrieval_log_max_chars", 180))
except (TypeError, ValueError):
log_max_chars = 180
# Log snippets are never shorter than 60 characters
log_max_chars = max(60, log_max_chars)
try:
logger.info(f"[VectorMemory] 检索: chatroom={chatroom_id}, query={query_text[:50]}")
results = self._chroma_collection.query(
query_texts=[query_text],
where={"chatroom_id": chatroom_id},
n_results=top_k,
)
if not results or not results.get("documents") or not results["documents"][0]:
logger.info(f"[VectorMemory] 检索无结果: chatroom={chatroom_id}")
return ""
docs = results["documents"][0]
# Missing distances default to 0, which always passes the threshold below
distances = results["distances"][0] if results.get("distances") else [0] * len(docs)
logger.info(f"[VectorMemory] 检索到 {len(docs)} 条候选, distances={[round(d, 3) for d in distances]}")
ids = results.get("ids", [])
ids = ids[0] if ids and isinstance(ids[0], list) else []
pieces = []
total_len = 0
for idx, (doc, dist) in enumerate(zip(docs, distances), start=1):
doc_id = ids[idx - 1] if idx - 1 < len(ids) else "-"
raw_doc = doc or ""
# NOTE: despite its name, min_score is an upper bound on the distance
keep = dist <= min_score
if log_candidates:
snippet = re.sub(r"\s+", " ", raw_doc).strip()
if len(snippet) > log_max_chars:
snippet = snippet[:log_max_chars] + "..."
status = "命中" if keep else "过滤"
logger.info(
f"[VectorMemory] 候选#{idx} {status} "
f"(dist={dist:.4f}, threshold<={min_score:.4f}, id={doc_id}) "
f"内容: {snippet}"
)
if dist > min_score:
continue
# Enforce the injection budget: truncate the overflowing candidate
# (only if more than 50 characters remain) and stop
if total_len + len(raw_doc) > max_chars:
remaining = max_chars - total_len
if remaining > 50:
# NOTE: total_len is not incremented for the truncated piece, so the
# character count logged further down excludes it
pieces.append(raw_doc[:remaining] + "...")
if log_candidates:
logger.info(
f"[VectorMemory] 候选#{idx} 达到注入上限,截断后加入 "
f"{remaining} 字符 (max_inject_chars={max_chars})"
)
elif log_candidates:
logger.info(
f"[VectorMemory] 候选#{idx} 达到注入上限,剩余空间 {remaining} 字符,已跳过"
)
break
pieces.append(raw_doc)
total_len += len(raw_doc)
if not pieces:
logger.info(
f"[VectorMemory] 候选均未命中阈值或被注入长度限制拦截 "
f"(threshold<={min_score:.4f}, max_inject_chars={max_chars})"
)
return ""
logger.info(f"[VectorMemory] 检索到 {len(pieces)} 条相关记忆 (chatroom={chatroom_id})")
if log_injected:
preview = re.sub(r"\s+", " ", "\n---\n".join(pieces)).strip()
if len(preview) > log_max_chars:
preview = preview[:log_max_chars] + "..."
logger.info(
f"[VectorMemory] 最终注入预览 ({len(pieces)}条, {total_len}字): {preview}"
)
return "\n\n【历史记忆】以下是与当前话题相关的历史摘要:\n" + "\n---\n".join(pieces)
except Exception as e:
logger.warning(f"[VectorMemory] 检索失败: {e}")
return ""
def _get_vector_memories_for_display(self, chatroom_id: str) -> list:
"""获取指定群的所有向量记忆摘要(用于展示)"""
if not self._vector_memory_enabled or not self._chroma_collection:
return []
try:
results = self._chroma_collection.get(
where={"chatroom_id": chatroom_id},
include=["documents", "metadatas"],
)
if not results or not results.get("ids"):
return []
items = []
for i, doc_id in enumerate(results["ids"]):
meta = results["metadatas"][i] if results.get("metadatas") else {}
doc = results["documents"][i] if results.get("documents") else ""
items.append({
"id": doc_id,
"summary": doc,
"ts_start": meta.get("ts_start", ""),
"ts_end": meta.get("ts_end", ""),
"watermark": meta.get("watermark", 0),
})
items.sort(key=lambda x: x.get("watermark", 0))
return items
except Exception as e:
logger.warning(f"[VectorMemory] 获取展示数据失败: {e}")
return []
def _build_vector_memory_html(self, items: list, chatroom_id: str) -> str:
"""构建向量记忆展示的 HTML"""
from datetime import datetime as _dt
# 构建摘要卡片 HTML
cards_html = ""
for idx, item in enumerate(items, 1):
summary = item["summary"].replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;").replace("\n", "<br>")
# 时间范围
time_range = ""
try:
if item["ts_start"] and item["ts_end"]:
t1 = _dt.fromtimestamp(float(item["ts_start"]))
t2 = _dt.fromtimestamp(float(item["ts_end"]))
time_range = f'{t1.strftime("%m/%d %H:%M")}{t2.strftime("%m/%d %H:%M")}'
except Exception:
pass
if not time_range:
time_range = f'片段 #{idx}'
cards_html += f'''
<div class="memory-card">
<div class="card-header">
<span class="card-index">#{idx}</span>
<span class="card-time">{time_range}</span>
</div>
<div class="card-body">{summary}</div>
</div>'''
now_str = _dt.now().strftime("%Y-%m-%d %H:%M")
watermark_ts = self._vector_watermarks.get(chatroom_id, "")
if watermark_ts:
try:
wt = _dt.fromisoformat(str(watermark_ts))
watermark_display = wt.strftime("%m/%d %H:%M")
except Exception:
watermark_display = str(watermark_ts)[:16]
else:
watermark_display = ""
html = f'''<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<style>
* {{ margin: 0; padding: 0; box-sizing: border-box; }}
body {{ font-family: "Microsoft YaHei", "PingFang SC", sans-serif; }}
#card {{
width: 600px;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
padding: 24px;
min-height: 200px;
}}
.title-bar {{
display: flex; justify-content: space-between; align-items: center;
margin-bottom: 18px;
}}
.title {{
font-size: 22px; font-weight: bold; color: #fff;
text-shadow: 0 1px 3px rgba(0,0,0,0.3);
}}
.subtitle {{
font-size: 13px; color: rgba(255,255,255,0.75);
}}
.stats {{
display: flex; gap: 16px; margin-bottom: 18px;
}}
.stat-box {{
flex: 1; background: rgba(255,255,255,0.15);
backdrop-filter: blur(8px);
border-radius: 10px; padding: 12px 14px;
text-align: center; color: #fff;
}}
.stat-num {{
font-size: 26px; font-weight: bold;
}}
.stat-label {{
font-size: 12px; color: rgba(255,255,255,0.8); margin-top: 2px;
}}
.memory-card {{
background: rgba(255,255,255,0.92);
border-radius: 10px;
margin-bottom: 12px;
overflow: hidden;
box-shadow: 0 2px 8px rgba(0,0,0,0.08);
}}
.card-header {{
background: rgba(102,126,234,0.12);
padding: 8px 14px;
display: flex; justify-content: space-between; align-items: center;
}}
.card-index {{
font-size: 13px; font-weight: bold; color: #667eea;
}}
.card-time {{
font-size: 12px; color: #999;
}}
.card-body {{
padding: 12px 14px;
font-size: 13px; color: #333;
line-height: 1.7;
word-break: break-all;
}}
.footer {{
text-align: center; margin-top: 14px;
font-size: 11px; color: rgba(255,255,255,0.6);
}}
</style>
</head>
<body>
<div id="card">
<div class="title-bar">
<div>
<div class="title">🧠 向量记忆</div>
<div class="subtitle">{chatroom_id}</div>
</div>
<div class="subtitle">{now_str}</div>
</div>
<div class="stats">
<div class="stat-box">
<div class="stat-num">{len(items)}</div>
<div class="stat-label">摘要总数</div>
</div>
<div class="stat-box">
<div class="stat-num">{watermark_display}</div>
<div class="stat-label">最后摘要时间</div>
</div>
</div>
{cards_html}
<div class="footer">AIChat Vector Memory · Powered by Chroma</div>
</div>
</body>
</html>'''
return html
async def _render_vector_memory_image(self, html: str) -> str | None:
    """Render HTML to a PNG screenshot with Playwright and return the image path.

    The browser from ``plugins.SignInPlugin.html_renderer.get_browser`` is
    used when that module can be imported. Otherwise a private Chromium is
    launched for this call and shut down again afterwards.

    Args:
        html: full HTML document containing an element with id ``card``.

    Returns:
        Path of the written PNG, or None when rendering is not possible.
    """
    own_playwright = None  # Playwright driver started by this call, if any
    own_browser = None  # browser launched by this call, if any
    try:
        from plugins.SignInPlugin.html_renderer import get_browser
    except ImportError:
        try:
            from playwright.async_api import async_playwright
            own_playwright = await async_playwright().start()
            own_browser = await own_playwright.chromium.launch(headless=True, args=['--no-sandbox'])
            browser = own_browser
        except Exception as e:
            logger.error(f"[VectorMemory] Playwright 不可用: {e}")
            if own_playwright is not None:
                # The driver started but the launch failed: stop the driver.
                try:
                    await own_playwright.stop()
                except Exception as stop_err:
                    logger.debug(f"[VectorMemory] Playwright stop failed: {stop_err}")
            return None
    else:
        browser = await get_browser()

    page = None
    try:
        page = await browser.new_page()
        await page.set_viewport_size({"width": 600, "height": 800})
        await page.set_content(html)
        await page.wait_for_selector("#card", timeout=5000)
        element = await page.query_selector("#card")
        if not element:
            return None
        output_dir = Path(__file__).parent / "data" / "temp"
        output_dir.mkdir(parents=True, exist_ok=True)
        ts = int(datetime.now().timestamp())
        output_path = output_dir / f"vector_memory_{ts}.png"
        await element.screenshot(path=str(output_path))
        logger.success(f"[VectorMemory] 渲染成功: {output_path}")
        return str(output_path)
    except Exception as e:
        logger.error(f"[VectorMemory] 渲染失败: {e}")
        return None
    finally:
        # Always release the page. Close the browser and driver only when this
        # call created them; the get_browser() instance is left running.
        if page is not None:
            try:
                await page.close()
            except Exception as close_err:
                logger.debug(f"[VectorMemory] page close failed: {close_err}")
        if own_browser is not None:
            try:
                await own_browser.close()
            except Exception as close_err:
                logger.debug(f"[VectorMemory] browser close failed: {close_err}")
        if own_playwright is not None:
            try:
                await own_playwright.stop()
            except Exception as stop_err:
                logger.debug(f"[VectorMemory] Playwright stop failed: {stop_err}")
def _get_chat_id(self, from_wxid: str, sender_wxid: str = None, is_group: bool = False) -> str:
"""获取会话ID"""
if is_group:
# 群聊使用 "群ID:用户ID" 组合,确保每个用户有独立的对话记忆
user_wxid = sender_wxid or from_wxid
return f"{from_wxid}:{user_wxid}"
else:
return sender_wxid or from_wxid # 私聊使用用户ID
def _get_group_history_chat_id(self, from_wxid: str, user_wxid: str = None) -> str:
"""获取群聊 history 的会话ID可配置为全群共享或按用户隔离"""
if not from_wxid:
return ""
history_config = (self.config or {}).get("history", {})
scope = str(history_config.get("scope", "chatroom") or "chatroom").strip().lower()
if scope in ("per_user", "user", "peruser"):
if not user_wxid:
return from_wxid
return self._get_chat_id(from_wxid, user_wxid, is_group=True)
return from_wxid
def _should_capture_group_history(self, *, is_triggered: bool) -> bool:
"""判断群聊消息是否需要写入 history减少无关上下文污染"""
history_config = (self.config or {}).get("history", {})
capture = str(history_config.get("capture", "all") or "all").strip().lower()
if capture in ("none", "off", "disable", "disabled"):
return False
if capture in ("reply", "ai_only", "triggered"):
return bool(is_triggered)
return True
def _parse_history_timestamp(self, ts) -> float | None:
if ts is None:
return None
if isinstance(ts, (int, float)):
return float(ts)
if isinstance(ts, str):
s = ts.strip()
if not s:
return None
try:
return float(s)
except Exception:
pass
try:
return datetime.fromisoformat(s).timestamp()
except Exception:
return None
return None
def _filter_history_by_window(self, history: list) -> list:
history_config = (self.config or {}).get("history", {})
window_seconds = history_config.get("context_window_seconds", None)
if window_seconds is None:
window_seconds = history_config.get("window_seconds", 0)
try:
window_seconds = float(window_seconds or 0)
except Exception:
window_seconds = 0
if window_seconds <= 0:
return history
cutoff = time.time() - window_seconds
filtered = []
for msg in history or []:
ts = self._parse_history_timestamp((msg or {}).get("timestamp"))
if ts is None or ts >= cutoff:
filtered.append(msg)
return filtered
def _sanitize_speaker_name(self, name: str) -> str:
"""清洗昵称,避免破坏历史格式(如 [name] 前缀)。"""
if name is None:
return ""
s = str(name).strip()
if not s:
return ""
s = s.replace("\r", " ").replace("\n", " ")
s = re.sub(r"\s{2,}", " ", s)
# 避免与历史前缀 [xxx] 冲突
s = s.replace("[", "").replace("]", "")
return s.strip()
def _combine_display_and_nickname(self, display_name: str, wechat_nickname: str) -> str:
display_name = self._sanitize_speaker_name(display_name)
wechat_nickname = self._sanitize_speaker_name(wechat_nickname)
# 重要:群昵称(群名片) 与 微信昵称(全局) 是两个不同概念,尽量同时给 AI。
if display_name and wechat_nickname:
return f"群昵称={display_name} | 微信昵称={wechat_nickname}"
if display_name:
return f"群昵称={display_name}"
if wechat_nickname:
return f"微信昵称={wechat_nickname}"
return ""
def _get_chatroom_member_lock(self, chatroom_id: str) -> asyncio.Lock:
lock = self._chatroom_member_cache_locks.get(chatroom_id)
if lock is None:
lock = asyncio.Lock()
self._chatroom_member_cache_locks[chatroom_id] = lock
return lock
async def _get_group_display_name(self, bot, chatroom_id: str, user_wxid: str, *, force_refresh: bool = False) -> str:
"""Return a user's group card name (in-group nickname); "" on failure.

The member list of each chatroom is cached for
``_chatroom_member_cache_ttl_seconds``. ``force_refresh`` skips the cache.
"""
if not chatroom_id or not user_wxid:
return ""
# Bots without a member-list API cannot provide card names
if not hasattr(bot, "get_chatroom_members"):
return ""
now = time.time()
# First cache check, done without taking the lock
if not force_refresh:
cached = self._chatroom_member_cache.get(chatroom_id)
if cached:
ts, member_map = cached
if now - float(ts or 0) < float(self._chatroom_member_cache_ttl_seconds or 0):
return self._sanitize_speaker_name(member_map.get(user_wxid, ""))
lock = self._get_chatroom_member_lock(chatroom_id)
async with lock:
# Re-check under the lock: another coroutine may have refreshed the cache while we waited
now = time.time()
if not force_refresh:
cached = self._chatroom_member_cache.get(chatroom_id)
if cached:
ts, member_map = cached
if now - float(ts or 0) < float(self._chatroom_member_cache_ttl_seconds or 0):
return self._sanitize_speaker_name(member_map.get(user_wxid, ""))
try:
# The member list can be large; bound the wait so message handling is not blocked for long
members = await asyncio.wait_for(bot.get_chatroom_members(chatroom_id), timeout=8)
except Exception as e:
logger.debug(f"获取群成员列表失败: {chatroom_id}, {e}")
return ""
member_map = {}
try:
for m in members or []:
wxid = (m.get("wxid") or "").strip()
if not wxid:
continue
display_name = m.get("display_name") or m.get("displayName") or ""
member_map[wxid] = str(display_name or "").strip()
except Exception as e:
logger.debug(f"解析群成员列表失败: {chatroom_id}, {e}")
# Whatever was parsed (possibly a partial or empty map) is cached with the current time
self._chatroom_member_cache[chatroom_id] = (time.time(), member_map)
return self._sanitize_speaker_name(member_map.get(user_wxid, ""))
async def _get_user_display_label(self, bot, from_wxid: str, user_wxid: str, is_group: bool) -> str:
"""用于历史记录:群聊优先使用群名片,其次微信昵称。"""
if not is_group:
return ""
wechat_nickname = await self._get_user_nickname(bot, from_wxid, user_wxid, is_group)
group_display = await self._get_group_display_name(bot, from_wxid, user_wxid)
return self._combine_display_and_nickname(group_display, wechat_nickname) or wechat_nickname or user_wxid
async def _get_user_nickname(self, bot, from_wxid: str, user_wxid: str, is_group: bool) -> str:
"""
Resolve a user's nickname for group chats, trying several sources in order.
Lookup order as implemented below: MemberSync member service, Redis cache,
MessageLogger database, and finally the wxid itself.
Args:
bot: WechatHookClient instance (not referenced in this method)
from_wxid: message source (chatroom ID or private-chat user ID)
user_wxid: the user's wxid
is_group: whether this is a group chat
Returns:
The user's nickname; an empty string for non-group chats.
"""
if not is_group:
return ""
nickname = ""
# 0. Prefer the MemberSync database
try:
member_service = get_member_service()
if from_wxid:
member_info = await member_service.get_chatroom_member_info(from_wxid, user_wxid)
else:
member_info = await member_service.get_member_info(user_wxid)
if member_info and member_info.get("nickname"):
return member_info["nickname"]
except Exception as e:
logger.debug(f"[MemberSync数据库读取失败] {user_wxid}: {e}")
# 1. Then try the Redis cache
redis_cache = get_cache()
if redis_cache and redis_cache.enabled:
cached_info = redis_cache.get_user_basic_info(from_wxid, user_wxid)
if cached_info and cached_info.get("nickname"):
logger.debug(f"[缓存命中] 用户昵称: {user_wxid} -> {cached_info['nickname']}")
return cached_info["nickname"]
# 2. Query the MessageLogger database
# (nickname is still "" here because the earlier hits return directly)
if not nickname:
try:
from plugins.MessageLogger.main import MessageLogger
msg_logger = MessageLogger.get_instance()
if msg_logger:
with msg_logger.get_db_connection() as conn:
with conn.cursor() as cursor:
# Most recent non-empty nickname recorded for this sender
cursor.execute(
"SELECT nickname FROM messages WHERE sender_wxid = %s AND nickname != '' ORDER BY create_time DESC LIMIT 1",
(user_wxid,)
)
result = cursor.fetchone()
if result:
nickname = result[0]
except Exception as e:
logger.debug(f"从数据库获取昵称失败: {e}")
# 3. Last resort: fall back to the wxid
if not nickname:
nickname = user_wxid or "未知用户"
return nickname
def _check_rate_limit(self, user_wxid: str) -> tuple:
"""
检查用户是否超过限流
Args:
user_wxid: 用户wxid
Returns:
(是否允许, 剩余次数, 重置时间秒数)
"""
rate_limit_config = self.config.get("rate_limit", {})
if not rate_limit_config.get("enabled", True):
return (True, 999, 0)
redis_cache = get_cache()
if not redis_cache or not redis_cache.enabled:
return (True, 999, 0) # Redis 不可用时不限流
limit = rate_limit_config.get("ai_chat_limit", 20)
window = rate_limit_config.get("ai_chat_window", 60)
return redis_cache.check_rate_limit(user_wxid, limit, window, "ai_chat")
def _add_to_memory(self, chat_id: str, role: str, content, image_base64: str = None):
"""
添加消息到记忆
Args:
chat_id: 会话ID
role: 角色 (user/assistant)
content: 消息内容(可以是字符串或列表)
image_base64: 可选的图片base64数据
"""
if not self.store:
return
self.store.add_private_message(chat_id, role, content, image_base64=image_base64)
def _get_memory_messages(self, chat_id: str) -> list:
"""获取记忆中的消息"""
if not self.store:
return []
return self.store.get_private_messages(chat_id)
def _clear_memory(self, chat_id: str):
"""清空指定会话的记忆"""
if not self.store:
return
self.store.clear_private_messages(chat_id)
async def _download_and_encode_image(self, bot, message: dict) -> str:
"""下载图片并转换为base64委托给 ImageProcessor使用新接口"""
if self._image_processor:
return await self._image_processor.download_image(bot, message)
logger.warning("ImageProcessor 未初始化,无法下载图片")
return ""
async def _download_emoji_and_encode(self, cdn_url: str, max_retries: int = 3) -> str:
"""下载表情包并转换为base64委托给 ImageProcessor"""
if self._image_processor:
return await self._image_processor.download_emoji(cdn_url, max_retries)
logger.warning("ImageProcessor 未初始化,无法下载表情包")
return ""
async def _generate_image_description(self, image_base64: str, prompt: str, config: dict) -> str:
"""
使用 AI 生成图片描述,委托给 ImageProcessor
Args:
image_base64: 图片的 base64 数据
prompt: 描述提示词
config: 图片描述配置
Returns:
图片描述文本,失败返回空字符串
"""
if self._image_processor:
model = config.get("model")
return await self._image_processor.generate_description(image_base64, prompt, model)
logger.warning("ImageProcessor 未初始化,无法生成图片描述")
return ""
def _collect_tools_with_plugins(self) -> dict:
    """Collect tool definitions from the ToolRegistry, keeping the source plugin name.

    Returns:
        ``{tool_name: (plugin_name, schema)}`` after applying the whitelist or
        blacklist selected by ``tools.mode``.
    """
    registry = get_tool_registry()
    tools_cfg = (self.config or {}).get("tools", {})
    mode = tools_cfg.get("mode", "all")
    whitelist = set(tools_cfg.get("whitelist", []))
    blacklist = set(tools_cfg.get("blacklist", []))
    collected = {}
    for name in registry.list_tools():
        tool_def = registry.get(name)
        if not tool_def:
            continue
        rejected = (
            (mode == "whitelist" and name not in whitelist)
            or (mode == "blacklist" and name in blacklist)
        )
        if rejected:
            continue
        collected[name] = (tool_def.plugin_name, tool_def.schema)
    return collected
def _collect_tools(self):
"""收集所有插件的LLM工具支持白名单/黑名单过滤)"""
tools_map = self._collect_tools_with_plugins()
return [item[1] for item in tools_map.values()]
def _get_tool_schema_map(self, tools_map: dict | None = None) -> dict:
"""构建工具名到参数 schema 的映射"""
tools_map = tools_map or self._collect_tools_with_plugins()
schema_map = {}
for name, (_plugin_name, tool) in tools_map.items():
fn = tool.get("function", {})
schema_map[name] = fn.get("parameters", {}) or {}
return schema_map
def _guess_tool_param_description(self, param_name: str) -> str:
"""为缺失描述的参数补充通用说明"""
key = str(param_name or "").strip().lower()
hints = {
"query": "检索关键词或要查询的问题",
"keyword": "关键词",
"prompt": "用于生成内容的提示词",
"text": "要处理的文本内容",
"content": "主要内容",
"url": "目标链接地址",
"urls": "目标链接地址列表",
"image_url": "图片链接地址",
"image_base64": "图片的 Base64 编码数据",
"location": "地点名称(如城市/地区)",
"city": "城市名称",
"name": "名称",
"id": "目标对象 ID",
"user_id": "用户 ID",
"group_id": "群 ID",
"count": "数量",
"num": "数量",
"mode": "执行模式",
"type": "类型",
}
return hints.get(key, f"参数 {param_name} 的取值")
def _normalize_tool_schema_for_llm(self, tool: dict) -> dict | None:
"""标准化工具 schema提升模型函数调用稳定性。"""
if not isinstance(tool, dict):
return None
normalized = copy.deepcopy(tool)
normalized["type"] = "function"
function_def = normalized.get("function")
if not isinstance(function_def, dict):
return None
function_name = str(function_def.get("name", "")).strip()
if not function_name:
return None
description = str(function_def.get("description", "")).strip()
if not description:
function_def["description"] = f"调用 {function_name} 工具完成任务,仅在用户明确需要时使用。"
parameters = function_def.get("parameters")
if not isinstance(parameters, dict):
parameters = {"type": "object", "properties": {}, "required": []}
properties = parameters.get("properties")
if not isinstance(properties, dict):
properties = {}
for param_name, param_schema in list(properties.items()):
if not isinstance(param_schema, dict):
properties[param_name] = {
"type": "string",
"description": self._guess_tool_param_description(param_name),
}
continue
if not str(param_schema.get("description", "")).strip():
param_schema["description"] = self._guess_tool_param_description(param_name)
if not str(param_schema.get("type", "")).strip() and "enum" not in param_schema:
param_schema["type"] = "string"
required = parameters.get("required", [])
if not isinstance(required, list):
required = []
required = [item for item in required if isinstance(item, str) and item in properties]
parameters["type"] = "object"
parameters["properties"] = properties
parameters["required"] = required
parameters["additionalProperties"] = False
function_def["parameters"] = parameters
normalized["function"] = function_def
return normalized
def _prepare_tools_for_llm(self, tools: list) -> list:
"""预处理工具声明(补描述、补参数 schema、收敛格式"""
prepared = []
for tool in tools or []:
normalized = self._normalize_tool_schema_for_llm(tool)
if normalized:
prepared.append(normalized)
return prepared
def _build_tool_rules_prompt(self, tools: list) -> str:
"""构建函数调用规则提示词(参考 Eridanus 风格)。"""
lines = [
"【函数调用规则】",
"1) 仅可基于【当前消息】决定是否调用工具;历史内容只用于语境,不可据此触发工具。",
"2) 能直接回答就直接回答;只有在需要外部能力时才调用工具。",
"3) 关键参数不完整时先追问澄清,不要臆测。",
"4) 禁止向用户输出 function_call/tool_calls 或 JSON 调用片段。",
"5) 工具执行后请输出自然语言总结:先结论,再补充细节。",
]
if not tools:
lines.append("本轮未提供可调用工具,禁止伪造函数调用。")
lines.append("严禁在回复中输出任何 JSON、function_call、tool_calls、action/actioninput 格式的内容,即使历史上下文中出现过工具调用记录也不得模仿。只用自然语言回复。")
return "\n\n" + "\n".join(lines)
lines.append("本轮可用工具:")
for tool in (tools or [])[:12]:
function_def = tool.get("function", {}) if isinstance(tool, dict) else {}
name = str(function_def.get("name", "")).strip()
description = str(function_def.get("description", "")).strip()
if not name:
continue
if len(description) > 70:
description = description[:67] + "..."
lines.append(f"- {name}: {description or '按工具定义执行'}")
return "\n\n" + "\n".join(lines)
async def _handle_list_prompts(self, bot, from_wxid: str):
    """Handle the "list personas" command.

    Lists the ``*.txt`` files in the ``prompts`` directory next to this
    module, marks the one configured as
    ``config["prompt"]["system_prompt_file"]``, and sends the list with
    ``bot.send_text``. Any failure is logged and reported to the same chat.

    Args:
        bot: Object exposing an awaitable ``send_text(wxid, text)``.
        from_wxid: Chat that receives the reply.
    """
    try:
        prompts_dir = Path(__file__).parent / "prompts"
        if not prompts_dir.exists():
            await bot.send_text(from_wxid, "❌ prompts 目录不存在")
            return

        # Collect every .txt persona file.
        txt_files = sorted(prompts_dir.glob("*.txt"))
        if not txt_files:
            await bot.send_text(from_wxid, "❌ 没有找到任何人设文件")
            return

        current_file = self.config["prompt"]["system_prompt_file"]
        entries = []
        for i, file_path in enumerate(txt_files, 1):
            filename = file_path.name
            # Mark the persona currently in use; both branches used to emit
            # the same text, so the active persona could not be told apart.
            marker = " ✅" if filename == current_file else ""
            entries.append(f"{i}. {filename}{marker}\n")

        msg = "📋 可用人设列表:\n\n" + "".join(entries) + "\n💡 使用方法:/切人设 文件名.txt"
        await bot.send_text(from_wxid, msg)
        logger.info("已发送人设列表")
    except Exception as e:
        logger.error(f"获取人设列表失败: {e}")
        await bot.send_text(from_wxid, f"❌ 获取人设列表失败: {str(e)}")
def _estimate_tokens(self, text: str) -> int:
"""
估算文本的 token 数量
简单估算规则:
- 中文:约 1.5 字符 = 1 token
- 英文:约 4 字符 = 1 token
- 混合文本取平均
"""
if not text:
return 0
# 统计中文字符数
chinese_chars = sum(1 for c in text if '\u4e00' <= c <= '\u9fff')
# 其他字符数
other_chars = len(text) - chinese_chars
# 估算 token 数
chinese_tokens = chinese_chars / 1.5
other_tokens = other_chars / 4
return int(chinese_tokens + other_tokens)
def _estimate_message_tokens(self, message: dict) -> int:
"""估算单条消息的 token 数"""
content = message.get("content", "")
if isinstance(content, str):
return self._estimate_tokens(content)
elif isinstance(content, list):
# 多模态消息
total = 0
for item in content:
if item.get("type") == "text":
total += self._estimate_tokens(item.get("text", ""))
elif item.get("type") == "image_url":
# 图片按 85 token 估算OpenAI 低分辨率图片)
total += 85
return total
return 0
def _extract_text_from_multimodal(self, content) -> str:
"""从多模态 content 中提取文本,模型不支持时用于降级"""
if isinstance(content, list):
texts = [item.get("text", "") for item in content if item.get("type") == "text"]
text = "".join(texts).strip()
return text or "[图片]"
if content is None:
return ""
return str(content)
def _extract_last_user_text(self, messages: list) -> str:
"""从 messages 中提取最近一条用户文本,用于工具参数兜底。"""
for msg in reversed(messages or []):
if msg.get("role") == "user":
return self._extract_text_from_multimodal(msg.get("content"))
return ""
def _sanitize_llm_output(self, text) -> str:
"""
清洗 LLM 输出,尽量满足:不输出思维链、不使用 Markdown。
说明:提示词并非强约束,因此在所有“发给用户/写入上下文”的出口统一做后处理。
"""
if text is None:
return ""
raw = str(text)
cleaned = raw
# 清理 xAI/Grok 风格的渲染卡片标签,避免被当作正文/上下文继续传播
cleaned = re.sub(
r"(?is)<grok:render\b[\s\S]*?</grok:render>",
"",
cleaned,
)
output_cfg = (self.config or {}).get("output", {})
strip_thinking = output_cfg.get("strip_thinking", True)
strip_markdown = output_cfg.get("strip_markdown", True)
# 先做一次 Markdown 清理,避免 “**思考过程:**/### 思考” 这类包裹导致无法识别
if strip_markdown:
cleaned = self._strip_markdown_syntax(cleaned)
if strip_thinking:
cleaned = self._strip_thinking_content(cleaned)
# 清理模型偶发输出的“文本工具调用”痕迹(如 tavilywebsearch{query:...} / <ctrl46>
# 这些内容既不是正常回复,也会破坏“工具只能用 Function Calling”的约束
try:
cleaned = re.sub(r"<ctrl\\d+>", "", cleaned, flags=re.IGNORECASE)
cleaned = re.sub(
r"(?:展开阅读下文\\s*)?(?:tavilywebsearch|tavily_web_search|web_search)\\s*\\{[^{}]{0,1500}\\}",
"",
cleaned,
flags=re.IGNORECASE,
)
cleaned = re.sub(
r"(?:tavilywebsearch|tavily_web_search|web_search)\\s*\\([^\\)]{0,1500}\\)",
"",
cleaned,
flags=re.IGNORECASE,
)
cleaned = re.sub(
r"\{[^\{\}]{0,2000}[\"']name[\"']\s*:\s*[\"'](?:draw_image|generate_image|nano_ai_image_generation|flow2_ai_image_generation|jimeng_ai_image_generation|kiira2_ai_image_generation)[\"'][\s\S]{0,2000}\}",
"",
cleaned,
flags=re.IGNORECASE,
)
cleaned = cleaned.replace("展开阅读下文", "")
cleaned = re.sub(
r"[(]已触发工具处理[^)\r\n]{0,500}[)]?",
"",
cleaned,
)
cleaned = re.sub(r"(?m)^.*已触发工具处理.*$", "", cleaned)
cleaned = re.sub(r"(?m)^.*结果将发送到聊天中.*$", "", cleaned)
# 清理历史记录格式标签 [时间:...][类型:...]
cleaned = re.sub(r"\[时间:[^\]]*\]", "", cleaned)
cleaned = re.sub(r"\[类型:[^\]]*\]", "", cleaned)
# 过滤图片占位符/文件名,避免把日志占位符当成正文发出去
cleaned = re.sub(
r"\\[图片[^\\]]*\\]\\s*\\S+\\.(?:png|jpe?g|gif|webp)",
"",
cleaned,
flags=re.IGNORECASE,
)
cleaned = re.sub(r"\\[图片[^\\]]*\\]", "", cleaned)
except Exception:
pass
# 再跑一轮:部分模型会把“思考/最终”标记写成 Markdown或在剥离标签后才露出标记
if strip_markdown:
cleaned = self._strip_markdown_syntax(cleaned)
if strip_thinking:
cleaned = self._strip_thinking_content(cleaned)
cleaned = cleaned.strip()
# 兜底:清洗后仍残留明显“思维链/大纲”标记时,再尝试一次“抽取最终段”
if strip_thinking and cleaned and self._contains_thinking_markers(cleaned):
extracted = self._extract_after_last_answer_marker(cleaned)
if not extracted:
extracted = self._extract_final_answer_from_outline(cleaned)
if extracted:
cleaned = extracted.strip()
# 仍残留标记:尽量选取最后一个“不含标记”的段落作为最终回复
if cleaned and self._contains_thinking_markers(cleaned):
parts = [p.strip() for p in re.split(r"\n{2,}", cleaned) if p.strip()]
for p in reversed(parts):
if not self._contains_thinking_markers(p):
cleaned = p
break
cleaned = cleaned.strip()
# 最终兜底:仍然像思维链就直接丢弃(宁可不发也不要把思维链发出去)
if strip_thinking and cleaned and self._contains_thinking_markers(cleaned):
return ""
if cleaned:
return cleaned
raw_stripped = raw.strip()
# 清洗后为空时,不要回退到包含思维链标记的原文(避免把 <think>... 直接发出去)
if strip_thinking and self._contains_thinking_markers(raw_stripped):
return ""
if self._contains_tool_call_markers(raw_stripped):
return ""
return raw_stripped
async def _maybe_send_voice_reply(self, bot, to_wxid: str, text: str, message: dict | None = None):
"""AI 回复后,按概率触发语音回复"""
if not text:
return
try:
voice_plugin = self.get_plugin("VoiceSynth")
if not voice_plugin or not getattr(voice_plugin, "enabled", True):
return
if not getattr(voice_plugin, "master_enabled", True):
return
handler = getattr(voice_plugin, "maybe_send_voice_reply", None)
if not handler:
return
asyncio.create_task(handler(bot, to_wxid, text, message=message))
except Exception as e:
logger.debug(f"触发语音回复失败: {e}")
def _contains_thinking_markers(self, text: str) -> bool:
"""粗略判断文本是否包含明显的“思考/推理”外显标记,用于决定是否允许回退原文。"""
if not text:
return False
lowered = text.lower()
tag_tokens = (
"<think", "</think",
"<analysis", "</analysis",
"<reasoning", "</reasoning",
"<thought", "</thought",
"<thinking", "</thinking",
"<thoughts", "</thoughts",
"<scratchpad", "</scratchpad",
"&lt;think", "&lt;/think",
"&lt;analysis", "&lt;/analysis",
"&lt;reasoning", "&lt;/reasoning",
"&lt;thought", "&lt;/thought",
"&lt;thinking", "&lt;/thinking",
"&lt;thoughts", "&lt;/thoughts",
"&lt;scratchpad", "&lt;/scratchpad",
)
if any(tok in lowered for tok in tag_tokens):
return True
stripped = text.strip()
if stripped.startswith("{") and stripped.endswith("}"):
# JSON 结构化输出常见于“analysis/final”
json_keys = (
"\"analysis\"",
"\"reasoning\"",
"\"thought\"",
"\"thoughts\"",
"\"scratchpad\"",
"\"final\"",
"\"answer\"",
"\"response\"",
"\"output\"",
"\"text\"",
)
if any(k in lowered for k in json_keys):
return True
# YAML/KV 风格
if re.search(r"(?im)^\s*(analysis|reasoning|thoughts?|scratchpad|final|answer|response|output|text|思考|分析|推理|最终|输出)\s*[:]", text):
return True
marker_re = re.compile(
r"(?mi)^\s*(?:\d+\s*[\.\、:)\-–—]\s*)?(?:[-*•]+\s*)?"
r"(?:【\s*(?:思考过程|推理过程|分析过程|思考|分析|推理|内心独白|内心os|思维链|思路|"
r"chain\s*of\s*thought|reasoning|analysis|thinking|thoughts|thought\s*process|scratchpad)\s*】"
r"|(?:思考过程|推理过程|分析过程|思考|分析|推理|内心独白|内心os|思维链|思路|"
r"chain\s*of\s*thought|reasoning|analysis|analyze|thinking|thoughts|thought\s*process|scratchpad|internal\s*monologue|mind\s*space|final\s*polish|output\s*generation)"
r"(?:\s*】)?\s*(?:[:]|$|\s+))"
)
return marker_re.search(text) is not None
def _contains_tool_call_markers(self, text: str) -> bool:
if not text:
return False
lowered = text.lower()
if "<ctrl" in lowered or "展开阅读下文" in text:
return True
if "<grok:render" in lowered:
return True
if re.search(r"(?i)\"?function_call\"?\s*[:=]\s*\{", text):
return True
if re.search(r"(?i)\bfunction_call\s*\(", text):
return True
if re.search(r"(?i)(tavilywebsearch|tavily_web_search|web_search)\s*[\{\(]", text):
return True
if re.search(
r"(?i)[\"']name[\"']\s*:\s*[\"'](draw_image|generate_image|nano_ai_image_generation|flow2_ai_image_generation|jimeng_ai_image_generation|kiira2_ai_image_generation)[\"']",
text,
):
return True
if re.search(
r"(?i)\b(?:print\s*\(\s*)?(draw_image|generate_image|nano_ai_image_generation|flow2_ai_image_generation|jimeng_ai_image_generation|kiira2_ai_image_generation)\s*\(",
text,
):
return True
if re.search(
r"(?i)[\"']action[\"']\s*:\s*[\"'](draw_image|generate_image|nano_ai_image_generation|flow2_ai_image_generation|jimeng_ai_image_generation|kiira2_ai_image_generation|nanoaiimage_generation|flow2aiimage_generation|jimengaiimage_generation|kiira2aiimage_generation)[\"']",
text,
):
return True
return False
def _extract_tool_calls_data(self, message: dict, choice: dict | None = None) -> list:
"""统一提取工具调用,兼容 tool_calls 与旧版 function_call。"""
choice = choice if isinstance(choice, dict) else {}
message = message if isinstance(message, dict) else {}
tool_calls = message.get("tool_calls") or choice.get("tool_calls") or []
if isinstance(tool_calls, dict):
tool_calls = [tool_calls]
if isinstance(tool_calls, list) and tool_calls:
return tool_calls
function_call = message.get("function_call") or choice.get("function_call")
if not isinstance(function_call, dict):
return []
function_name = (function_call.get("name") or "").strip()
if not function_name:
return []
arguments = function_call.get("arguments", "{}")
if not isinstance(arguments, str):
try:
arguments = json.dumps(arguments, ensure_ascii=False)
except Exception:
arguments = "{}"
return [{
"id": f"legacy_fc_{uuid.uuid4().hex[:8]}",
"type": "function",
"function": {
"name": function_name,
"arguments": arguments,
},
}]
def _normalize_dialog_api_mode(self, mode: str) -> str:
value = str(mode or "").strip().lower()
aliases = {
"auto": "auto",
"openai_chat_completions": "openai_chat_completions",
"chat_completions": "openai_chat_completions",
"chat": "openai_chat_completions",
"openai_chat": "openai_chat_completions",
"openai_responses": "openai_responses",
"responses": "openai_responses",
"openai_completions": "openai_completions",
"completions": "openai_completions",
"claude_messages": "claude_messages",
"claude": "claude_messages",
"anthropic": "claude_messages",
"anthropic_messages": "claude_messages",
"gemini_generate_content": "gemini_generate_content",
"gemini": "gemini_generate_content",
"gemini_models": "gemini_generate_content",
}
return aliases.get(value, "")
def _resolve_dialog_api_mode(self, api_config: dict) -> str:
configured = self._normalize_dialog_api_mode((api_config or {}).get("mode", "auto"))
if configured and configured != "auto":
return configured
api_url = str((api_config or {}).get("url", "")).lower()
if "/v1/responses" in api_url:
return "openai_responses"
if re.search(r"/v1/messages(?:[/?]|$)", api_url):
return "claude_messages"
if "v1beta/models" in api_url or ":generatecontent" in api_url:
return "gemini_generate_content"
if re.search(r"/v1/completions(?:[/?]|$)", api_url) and "/chat/completions" not in api_url:
return "openai_completions"
return "openai_chat_completions"
def _build_gemini_generate_content_url(self, api_url: str, model: str) -> str:
url = str(api_url or "").strip().rstrip("/")
if not url:
return ""
if ":generatecontent" in url.lower():
return url
if "/models/" in url:
return f"{url}:generateContent"
return f"{url}/{model}:generateContent"
def _parse_data_url(self, data_url: str) -> tuple[str | None, str | None]:
s = str(data_url or "").strip()
if not s.startswith("data:") or "," not in s:
return None, None
header, b64_data = s.split(",", 1)
mime = "application/octet-stream"
if ";" in header:
mime = header[5:].split(";", 1)[0].strip() or mime
elif ":" in header:
mime = header.split(":", 1)[1].strip() or mime
return mime, b64_data
def _convert_openai_content_to_claude_blocks(self, content) -> list:
if isinstance(content, str):
text = content.strip()
return [{"type": "text", "text": text or " "}]
blocks = []
if isinstance(content, list):
for item in content:
if not isinstance(item, dict):
continue
item_type = item.get("type")
if item_type == "text":
text = str(item.get("text", "")).strip()
if text:
blocks.append({"type": "text", "text": text})
elif item_type == "image_url":
image_url = ((item.get("image_url") or {}).get("url") or "").strip()
mime, b64_data = self._parse_data_url(image_url)
if mime and b64_data:
blocks.append({
"type": "image",
"source": {
"type": "base64",
"media_type": mime,
"data": b64_data,
},
})
elif image_url:
blocks.append({"type": "text", "text": f"[图片链接] {image_url}"})
if not blocks:
text = self._extract_text_from_multimodal(content)
blocks.append({"type": "text", "text": text or " "})
return blocks
def _convert_openai_content_to_gemini_parts(self, content) -> list:
if isinstance(content, str):
text = content.strip()
return [{"text": text or " "}]
parts = []
if isinstance(content, list):
for item in content:
if not isinstance(item, dict):
continue
item_type = item.get("type")
if item_type == "text":
text = str(item.get("text", "")).strip()
if text:
parts.append({"text": text})
elif item_type == "image_url":
image_url = ((item.get("image_url") or {}).get("url") or "").strip()
mime, b64_data = self._parse_data_url(image_url)
if mime and b64_data:
parts.append({
"inline_data": {
"mime_type": mime,
"data": b64_data,
}
})
elif image_url:
parts.append({"text": f"[图片链接] {image_url}"})
if not parts:
text = self._extract_text_from_multimodal(content)
parts.append({"text": text or " "})
return parts
def _convert_openai_messages_to_claude(self, messages: list) -> tuple[str, list]:
system_parts = []
claude_messages = []
tool_name_by_id = {}
for msg in messages or []:
role = str(msg.get("role", "")).strip().lower()
if role == "system":
text = self._extract_text_from_multimodal(msg.get("content"))
if text:
system_parts.append(text)
continue
if role == "assistant":
content_blocks = self._convert_openai_content_to_claude_blocks(msg.get("content"))
tool_calls = msg.get("tool_calls") or []
if isinstance(tool_calls, dict):
tool_calls = [tool_calls]
for tc in tool_calls:
function = (tc or {}).get("function") or {}
fn_name = function.get("name", "")
if not fn_name:
continue
tool_id = (tc or {}).get("id") or f"claude_tool_{uuid.uuid4().hex[:8]}"
raw_args = function.get("arguments", "{}")
try:
args = json.loads(raw_args) if isinstance(raw_args, str) else (raw_args or {})
if not isinstance(args, dict):
args = {}
except Exception:
args = {}
content_blocks.append({
"type": "tool_use",
"id": tool_id,
"name": fn_name,
"input": args,
})
tool_name_by_id[tool_id] = fn_name
if content_blocks:
claude_messages.append({"role": "assistant", "content": content_blocks})
continue
if role == "tool":
tool_id = str(msg.get("tool_call_id", "")).strip()
result_text = self._extract_text_from_multimodal(msg.get("content"))
if tool_id:
block = {"type": "tool_result", "tool_use_id": tool_id, "content": result_text}
claude_messages.append({"role": "user", "content": [block]})
else:
claude_messages.append({"role": "user", "content": [{"type": "text", "text": f"[工具结果]\n{result_text}"}]})
continue
user_blocks = self._convert_openai_content_to_claude_blocks(msg.get("content"))
claude_messages.append({"role": "user", "content": user_blocks})
if not claude_messages:
claude_messages = [{"role": "user", "content": [{"type": "text", "text": "你好"}]}]
return "\n\n".join(system_parts).strip(), claude_messages
def _convert_openai_messages_to_gemini(self, messages: list) -> tuple[str, list]:
system_parts = []
contents = []
tool_name_by_id = {}
for msg in messages or []:
role = str(msg.get("role", "")).strip().lower()
if role == "system":
text = self._extract_text_from_multimodal(msg.get("content"))
if text:
system_parts.append(text)
continue
if role == "assistant":
parts = self._convert_openai_content_to_gemini_parts(msg.get("content"))
tool_calls = msg.get("tool_calls") or []
if isinstance(tool_calls, dict):
tool_calls = [tool_calls]
for tc in tool_calls:
function = (tc or {}).get("function") or {}
fn_name = function.get("name", "")
if not fn_name:
continue
raw_args = function.get("arguments", "{}")
try:
args = json.loads(raw_args) if isinstance(raw_args, str) else (raw_args or {})
if not isinstance(args, dict):
args = {}
except Exception:
args = {}
tool_id = (tc or {}).get("id") or f"gemini_tool_{uuid.uuid4().hex[:8]}"
tool_name_by_id[tool_id] = fn_name
parts.append({"functionCall": {"name": fn_name, "args": args}})
contents.append({"role": "model", "parts": parts or [{"text": " "}]})
continue
if role == "tool":
tool_id = str(msg.get("tool_call_id", "")).strip()
fn_name = tool_name_by_id.get(tool_id, "tool_result")
tool_text = self._extract_text_from_multimodal(msg.get("content"))
parts = [{
"functionResponse": {
"name": fn_name,
"response": {
"content": tool_text
},
}
}]
contents.append({"role": "user", "parts": parts})
continue
parts = self._convert_openai_content_to_gemini_parts(msg.get("content"))
contents.append({"role": "user", "parts": parts})
if not contents:
contents = [{"role": "user", "parts": [{"text": "你好"}]}]
return "\n\n".join(system_parts).strip(), contents
def _convert_openai_messages_to_plain_prompt(self, messages: list) -> str:
lines = []
for msg in messages or []:
role = str(msg.get("role", "")).strip().lower() or "user"
content = self._extract_text_from_multimodal(msg.get("content"))
if role == "tool":
tool_id = str(msg.get("tool_call_id", "")).strip()
role_label = f"tool:{tool_id}" if tool_id else "tool"
else:
role_label = role
lines.append(f"[{role_label}] {content}")
return "\n".join(lines).strip() or "你好"
def _convert_openai_messages_to_responses_input(self, messages: list) -> list:
input_messages = []
for msg in messages or []:
role = str(msg.get("role", "")).strip().lower()
if role not in ("system", "user", "assistant"):
role = "user"
content = msg.get("content")
blocks = []
if isinstance(content, list):
for item in content:
if not isinstance(item, dict):
continue
item_type = item.get("type")
if item_type == "text":
text = str(item.get("text", "")).strip()
if text:
blocks.append({"type": "input_text", "text": text})
elif item_type == "image_url":
image_url = ((item.get("image_url") or {}).get("url") or "").strip()
if image_url:
blocks.append({"type": "input_image", "image_url": image_url})
else:
text = self._extract_text_from_multimodal(content)
if text:
blocks.append({"type": "input_text", "text": text})
if not blocks:
blocks.append({"type": "input_text", "text": " "})
input_messages.append({"role": role, "content": blocks})
if not input_messages:
input_messages = [{"role": "user", "content": [{"type": "input_text", "text": "你好"}]}]
return input_messages
def _convert_tools_for_dialog_api(self, tools: list, api_mode: str):
if not tools:
return []
if api_mode == "openai_chat_completions":
return tools
if api_mode == "openai_completions":
return []
converted = []
for tool in tools:
if not isinstance(tool, dict):
continue
function = tool.get("function") or {}
name = str(function.get("name", "")).strip()
if not name:
continue
desc = str(function.get("description", "") or "").strip()
params = function.get("parameters")
if not isinstance(params, dict):
params = {"type": "object", "properties": {}}
if api_mode == "openai_responses":
converted.append({
"type": "function",
"name": name,
"description": desc,
"parameters": params,
})
elif api_mode == "claude_messages":
converted.append({
"name": name,
"description": desc,
"input_schema": params,
})
elif api_mode == "gemini_generate_content":
converted.append({
"name": name,
"description": desc,
"parameters": params,
})
if api_mode == "gemini_generate_content":
return [{"functionDeclarations": converted}] if converted else []
return converted
def _parse_dialog_api_response(self, api_mode: str, data: dict) -> tuple[str, list]:
data = data if isinstance(data, dict) else {}
def _fallback_openai():
choices = data.get("choices", [])
choice0 = choices[0] if choices else {}
message = choice0.get("message", {}) if isinstance(choice0, dict) else {}
full = message.get("content", "") or choice0.get("text", "") or ""
calls = self._extract_tool_calls_data(message, choice0)
if not isinstance(calls, list):
calls = []
return full, calls
if api_mode == "openai_chat_completions":
return _fallback_openai()
if api_mode == "openai_completions":
choices = data.get("choices", [])
if choices:
c0 = choices[0] if isinstance(choices[0], dict) else {}
text = c0.get("text", "")
if not text and isinstance(c0.get("message"), dict):
text = c0["message"].get("content", "")
return text or "", []
return "", []
if api_mode == "openai_responses":
if "choices" in data:
return _fallback_openai()
text_parts = []
tool_calls = []
if isinstance(data.get("output_text"), str) and data.get("output_text"):
text_parts.append(data.get("output_text", ""))
output = data.get("output") or []
for item in output if isinstance(output, list) else []:
if not isinstance(item, dict):
continue
item_type = item.get("type", "")
if item_type in ("function_call", "tool_call"):
name = item.get("name", "")
raw_args = item.get("arguments", "{}")
if isinstance(raw_args, dict):
raw_args = json.dumps(raw_args, ensure_ascii=False)
tool_calls.append({
"id": item.get("id") or f"resp_fc_{uuid.uuid4().hex[:8]}",
"type": "function",
"function": {"name": name, "arguments": raw_args or "{}"},
})
if item_type == "message":
for c in item.get("content") or []:
if not isinstance(c, dict):
continue
c_type = c.get("type")
if c_type in ("output_text", "text"):
txt = c.get("text", "")
if txt:
text_parts.append(txt)
elif c_type in ("function_call", "tool_call"):
name = c.get("name", "")
raw_args = c.get("arguments", "{}")
if isinstance(raw_args, dict):
raw_args = json.dumps(raw_args, ensure_ascii=False)
tool_calls.append({
"id": c.get("id") or f"resp_fc_{uuid.uuid4().hex[:8]}",
"type": "function",
"function": {"name": name, "arguments": raw_args or "{}"},
})
return "".join(text_parts), tool_calls
if api_mode == "claude_messages":
if "choices" in data:
return _fallback_openai()
text_parts = []
tool_calls = []
for block in data.get("content") or []:
if not isinstance(block, dict):
continue
block_type = block.get("type")
if block_type == "text":
txt = block.get("text", "")
if txt:
text_parts.append(txt)
elif block_type == "tool_use":
name = block.get("name", "")
args = block.get("input", {})
if not isinstance(args, dict):
args = {}
tool_calls.append({
"id": block.get("id") or f"claude_fc_{uuid.uuid4().hex[:8]}",
"type": "function",
"function": {
"name": name,
"arguments": json.dumps(args, ensure_ascii=False),
},
})
return "".join(text_parts), tool_calls
if api_mode == "gemini_generate_content":
if "choices" in data:
return _fallback_openai()
text_parts = []
tool_calls = []
candidates = data.get("candidates") or []
for candidate in candidates if isinstance(candidates, list) else []:
content = (candidate or {}).get("content") or {}
parts = content.get("parts") or []
for part in parts if isinstance(parts, list) else []:
if not isinstance(part, dict):
continue
if "text" in part and part.get("text"):
text_parts.append(str(part.get("text")))
function_call = part.get("functionCall") or part.get("function_call")
if isinstance(function_call, dict):
name = function_call.get("name", "")
args = function_call.get("args", {})
if isinstance(args, str):
try:
args = json.loads(args)
except Exception:
args = {"raw": args}
if not isinstance(args, dict):
args = {}
tool_calls.append({
"id": f"gemini_fc_{uuid.uuid4().hex[:8]}",
"type": "function",
"function": {
"name": name,
"arguments": json.dumps(args, ensure_ascii=False),
},
})
return "".join(text_parts), tool_calls
return "", []
def _create_proxy_connector(self):
connector = None
proxy_config = self.config.get("proxy", {})
if not proxy_config.get("enabled", False):
return None
proxy_type = str(proxy_config.get("type", "socks5")).upper()
proxy_host = proxy_config.get("host", "127.0.0.1")
proxy_port = proxy_config.get("port", 7890)
proxy_username = proxy_config.get("username")
proxy_password = proxy_config.get("password")
if proxy_username and proxy_password:
proxy_url = f"{proxy_type}://{proxy_username}:{proxy_password}@{proxy_host}:{proxy_port}"
else:
proxy_url = f"{proxy_type}://{proxy_host}:{proxy_port}"
if PROXY_SUPPORT:
try:
connector = ProxyConnector.from_url(proxy_url)
logger.debug(f"使用代理: {proxy_type}://{proxy_host}:{proxy_port}")
except Exception as e:
logger.warning(f"代理配置失败,将直连: {e}")
connector = None
else:
logger.warning("代理功能不可用aiohttp_socks 未安装),将直连")
connector = None
return connector
async def _send_dialog_api_request(
    self,
    api_config: dict,
    messages: list,
    tools: list | None = None,
    *,
    request_tag: str = "",
    prefer_stream: bool = True,
    max_tokens: int | None = None,
) -> tuple[str, list]:
    """Send one dialog request in the configured API flavour and return the result.

    The mode comes from ``_resolve_dialog_api_mode``. Messages and tools are
    given in OpenAI chat format and converted for the target API. Streaming
    is used only for chat-completions when ``prefer_stream`` is set and the
    model name does not contain "gemini-3".

    Args:
        api_config: Reads "url", "api_key", "model", "max_tokens",
            "timeout", "mode" and "anthropic_version".
        messages: OpenAI chat-style messages.
        tools: Optional OpenAI chat-style tool declarations.
        request_tag: Prefix for the debug log line.
        prefer_stream: Whether streaming may be used.
        max_tokens: Overrides ``api_config["max_tokens"]`` when not None.

    Returns:
        ``(text, tool_calls)`` with tool calls in OpenAI chat format.

    Raises:
        Exception: On an unsupported mode, a non-200 response, a network
            error, or a timeout.
    """
    tag = str(request_tag or "").strip()
    mode = self._resolve_dialog_api_mode(api_config)
    model_name = str(api_config.get("model", "")).lower()
    allow_stream = bool(prefer_stream and mode == "openai_chat_completions" and "gemini-3" not in model_name)
    api_url = str(api_config.get("url", "")).strip()
    api_key = str(api_config.get("api_key", "")).strip()
    request_url = api_url
    limit = int(max_tokens if max_tokens is not None else api_config.get("max_tokens", 4096))
    mode_tools = self._convert_tools_for_dialog_api(tools or [], mode)

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
    payload = {}

    if mode == "openai_chat_completions":
        payload = {
            "model": api_config.get("model", ""),
            "messages": messages,
            "max_tokens": limit,
            "stream": allow_stream,
        }
        if mode_tools:
            payload["tools"] = mode_tools
    elif mode == "openai_responses":
        payload = {
            "model": api_config.get("model", ""),
            "input": self._convert_openai_messages_to_responses_input(messages),
            "max_output_tokens": limit,
            "stream": False,
        }
        if mode_tools:
            payload["tools"] = mode_tools
        allow_stream = False
    elif mode == "openai_completions":
        payload = {
            "model": api_config.get("model", ""),
            "prompt": self._convert_openai_messages_to_plain_prompt(messages),
            "max_tokens": limit,
            "stream": False,
        }
        allow_stream = False
    elif mode == "claude_messages":
        system_text, claude_messages = self._convert_openai_messages_to_claude(messages)
        payload = {
            "model": api_config.get("model", ""),
            "messages": claude_messages,
            "max_tokens": limit,
            "stream": False,
        }
        if system_text:
            payload["system"] = system_text
        if mode_tools:
            payload["tools"] = mode_tools
        # Sent in addition to the Bearer header above.
        headers["x-api-key"] = api_key
        headers["anthropic-version"] = str(api_config.get("anthropic_version", "2023-06-01"))
        allow_stream = False
    elif mode == "gemini_generate_content":
        request_url = self._build_gemini_generate_content_url(api_url, api_config.get("model", ""))
        system_text, gemini_contents = self._convert_openai_messages_to_gemini(messages)
        payload = {
            "contents": gemini_contents,
            "generationConfig": {
                "maxOutputTokens": limit,
            },
        }
        if system_text:
            payload["systemInstruction"] = {
                "parts": [{"text": system_text}]
            }
        if mode_tools:
            payload["tools"] = mode_tools
        allow_stream = False
    else:
        raise Exception(f"不支持的 API 模式: {mode}")

    timeout_seconds = int(api_config.get("timeout", 120))
    timeout = aiohttp.ClientTimeout(total=timeout_seconds)
    connector = self._create_proxy_connector()
    logger.debug(f"{tag} 对话API模式: {mode}, stream={allow_stream}, url={request_url}")

    try:
        async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
            async with session.post(request_url, json=payload, headers=headers) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    raise Exception(f"API 错误 {resp.status}: {error_text}")

                if allow_stream:
                    full_content = ""
                    tool_calls_dict = {}
                    async for raw_line in resp.content:
                        line = raw_line.decode("utf-8").strip()
                        # The space after "data:" is optional in SSE, and
                        # some providers omit it.
                        if not line.startswith("data:"):
                            continue
                        data_str = line[5:].strip()
                        if not data_str or data_str == "[DONE]":
                            continue
                        try:
                            chunk = json.loads(data_str)
                        except Exception:
                            continue
                        if not isinstance(chunk, dict):
                            continue

                        choices = chunk.get("choices") or []
                        if not choices or not isinstance(choices[0], dict):
                            continue
                        # "delta" may be explicitly null.
                        delta = choices[0].get("delta") or {}

                        content_piece = delta.get("content") or ""
                        if content_piece:
                            full_content += content_piece

                        for tool_call_delta in delta.get("tool_calls") or []:
                            if not isinstance(tool_call_delta, dict):
                                continue
                            index = tool_call_delta.get("index")
                            if index is None:
                                index = 0
                            entry = tool_calls_dict.setdefault(index, {
                                "id": "",
                                "type": "function",
                                "function": {
                                    "name": "",
                                    "arguments": "",
                                },
                            })
                            # Later fragments may repeat "id"/"type" as null
                            # or empty; never overwrite a real value.
                            if tool_call_delta.get("id"):
                                entry["id"] = tool_call_delta["id"]
                            if tool_call_delta.get("type"):
                                entry["type"] = tool_call_delta["type"]
                            fn_delta = tool_call_delta.get("function") or {}
                            # Name and arguments arrive in pieces; a null
                            # piece counts as empty.
                            entry["function"]["name"] += fn_delta.get("name") or ""
                            entry["function"]["arguments"] += fn_delta.get("arguments") or ""

                    tool_calls_data = [tool_calls_dict[i] for i in sorted(tool_calls_dict)]
                    return full_content, tool_calls_data

                # content_type=None accepts JSON bodies served with another content type.
                data = await resp.json(content_type=None)
                full_content, tool_calls_data = self._parse_dialog_api_response(mode, data)
                if not isinstance(tool_calls_data, list):
                    tool_calls_data = []
                return full_content or "", tool_calls_data
    except aiohttp.ClientError as e:
        raise Exception(f"网络请求失败: {type(e).__name__}: {str(e)}") from e
    except asyncio.TimeoutError as e:
        raise Exception(f"API 请求超时 (timeout={timeout_seconds}s)") from e
def _extract_after_last_answer_marker(self, text: str) -> str | None:
"""从文本中抽取最后一个“最终/输出/答案”标记后的内容(不要求必须是编号大纲)。"""
if not text:
return None
# 1) 明确的行首标记Text:/Final Answer:/输出: ...
marker_re = re.compile(
r"(?im)^\s*(?:\d+\s*[\.\、:\)、)\-–—]\s*)?"
r"(?:text|final\s*answer|final\s*response|final\s*output|final|output|answer|response|输出|最终回复|最终答案|最终)\s*[:]\s*"
)
matches = list(marker_re.finditer(text))
if matches:
candidate = text[matches[-1].end():].strip()
if candidate:
return candidate
# 2) JSON/YAML 风格final: ... / \"final\": \"...\"
kv_re = re.compile(
r"(?im)^\s*\"?(?:final|answer|response|output|text|最终|最终回复|最终答案|输出)\"?\s*[:]\s*"
)
kv_matches = list(kv_re.finditer(text))
if kv_matches:
candidate = text[kv_matches[-1].end():].strip()
if candidate:
return candidate
# 3) 纯 JSON 对象(尝试解析)
stripped = text.strip()
if stripped.startswith("{") and stripped.endswith("}"):
try:
obj = json.loads(stripped)
if isinstance(obj, dict):
for key in ("final", "answer", "response", "output", "text"):
v = obj.get(key)
if isinstance(v, str) and v.strip():
return v.strip()
except Exception:
pass
return None
def _extract_final_answer_from_outline(self, text: str) -> str | None:
"""从“分析/草稿/输出”这类结构化大纲中提取最终回复正文(用于拦截思维链)。"""
if not text:
return None
# 至少包含多个“1./2./3.”段落,才认为可能是大纲/思维链输出
heading_re = re.compile(r"(?m)^\s*\d+\s*[\.\、:\)、)\-–—]\s*\S+")
if len(heading_re.findall(text)) < 2:
return None
# 优先:提取最后一个 “Text:/Final Answer:/Output:” 之后的内容
marker_re = re.compile(
r"(?im)^\s*(?:\d+\s*[\.\、:\)、)\-–—]\s*)?"
r"(?:text|final\s*answer|final\s*response|final\s*output|output|answer|response|输出|最终回复|最终答案)\s*[:]\s*"
)
matches = list(marker_re.finditer(text))
if matches:
candidate = text[matches[-1].end():].strip()
if candidate:
return candidate
# 没有明确的最终标记时,仅在包含“分析/思考/草稿/输出”等元信息关键词的情况下兜底抽取
lowered = text.lower()
outline_keywords = (
"analyze",
"analysis",
"reasoning",
"internal monologue",
"mind space",
"draft",
"drafting",
"outline",
"plan",
"steps",
"formulating response",
"final polish",
"final answer",
"output generation",
"system prompt",
"chat log",
"previous turn",
"current situation",
)
cn_keywords = ("思考", "分析", "推理", "思维链", "草稿", "计划", "步骤", "输出", "最终")
if not any(k in lowered for k in outline_keywords) and not any(k in text for k in cn_keywords):
return None
# 次选:取最后一个非空段落(避免返回整段大纲)
parts = [p.strip() for p in re.split(r"\n{2,}", text) if p.strip()]
if not parts:
return None
last = parts[-1]
if len(heading_re.findall(last)) == 0:
return last
return None
def _strip_thinking_content(self, text: str) -> str:
    """Remove commonly exposed "thinking / reasoning" content from model output.

    Handles, in order: explicit tag blocks such as ``<think>...</think>``
    (also in HTML-escaped form), unclosed thinking tags near the start of
    the text, "thinking: ... / final: ..." sectioned output, numbered
    chain-of-thought outlines, and ``<final>`` / ``<answer>`` wrappers.

    Args:
        text: Raw model output.

    Returns:
        The text with reasoning content removed; ``""`` for empty input.
    """
    if not text:
        return ""
    # Normalise line endings so the line-based passes below see "\n" only.
    t = text.replace("\r\n", "\n").replace("\r", "\n")
    # 1) Remove explicit tag blocks first (common with some reasoning models).
    thinking_tags = ("think", "analysis", "reasoning", "thought", "thinking", "thoughts", "scratchpad", "reflection")
    for tag in thinking_tags:
        t = re.sub(rf"<{tag}\b[^>]*>.*?</{tag}>", "", t, flags=re.IGNORECASE | re.DOTALL)
        # Also handle escaped tags (&lt;think&gt;...&lt;/think&gt;).
        t = re.sub(rf"&lt;{tag}\b[^&]*&gt;.*?&lt;/{tag}&gt;", "", t, flags=re.IGNORECASE | re.DOTALL)
    # 1.1) Fallback: when streaming/truncation left a tag unclosed and the
    # opening tag appears within the first 200 characters, drop everything
    # from that tag onwards.
    m = re.search(r"<(think|analysis|reasoning|thought|thinking|thoughts|scratchpad|reflection)\b[^>]*>", t, flags=re.IGNORECASE)
    if m and m.start() < 200:
        t = t[: m.start()].rstrip()
    m2 = re.search(r"&lt;(think|analysis|reasoning|thought|thinking|thoughts|scratchpad|reflection)\b[^&]*&gt;", t, flags=re.IGNORECASE)
    if m2 and m2.start() < 200:
        t = t[: m2.start()].rstrip()
    # 2) Then handle the "thinking: ... / final: ..." sectioned format,
    # trying to strip only the leading reasoning part.
    lines = t.split("\n")
    if not lines:
        return t
    # If the text contains obvious "final / output / answer" markers
    # (numbered or not), extract what follows the last one so the whole
    # outline is never sent out.
    if self._contains_thinking_markers(t):
        extracted_anywhere = self._extract_after_last_answer_marker(t)
        if extracted_anywhere:
            return extracted_anywhere
    reasoning_kw = (
        r"思考过程|推理过程|分析过程|思考|分析|推理|思路|内心独白|内心os|思维链|"
        r"chain\s*of\s*thought|reasoning|analysis|analyze|thinking|thoughts|thought\s*process|scratchpad|plan|steps|draft|outline"
    )
    answer_kw = r"最终答案|最终回复|最终|回答|回复|答复|结论|输出|final(?:\s*answer)?|final\s*response|final\s*output|answer|response|output|text"
    # Supported shapes:
    # - "thinking: ..." / "final reply: ..."
    # - bracketed headers such as 【思考】... / 【最终】...
    # - bold Markdown headers (Markdown is stripped by an outer layer first)
    reasoning_start = re.compile(
        rf"^\s*(?:\d+\s*[\.\、:\)、)\-–—]\s*)?(?:[-*•]+\s*)?"
        rf"(?:【\s*(?:{reasoning_kw})\s*】\s*[:]?\s*|(?:{reasoning_kw})(?:\s*】)?\s*(?:[:]|$|\s+))",
        re.IGNORECASE,
    )
    answer_start = re.compile(
        rf"^\s*(?:\d+\s*[\.\、:\)、)\-–—]\s*)?(?:[-*•]+\s*)?"
        rf"(?:【\s*(?:{answer_kw})\s*】\s*[:]?\s*|(?:{answer_kw})(?:\s*】)?\s*(?:[:]|$)\s*)",
        re.IGNORECASE,
    )
    # 2.0) If the first non-blank line starts with an answer marker
    # ("Final answer:" etc.), drop just the marker; this does not require a
    # reasoning block to be present.
    for idx, line in enumerate(lines):
        if line.strip() == "":
            continue
        m0 = answer_start.match(line)
        if m0:
            lines[idx] = line[m0.end():].lstrip()
        break
    # Reasoning headers are only looked for in the first 10 lines.
    has_reasoning = any(reasoning_start.match(line) for line in lines[:10])
    has_answer_marker = any(answer_start.match(line) for line in lines)
    # 2.1) With both a reasoning block and an answer marker, skip the
    # reasoning block up to the answer marker.
    if has_reasoning and has_answer_marker:
        out_lines: list[str] = []
        skipping = False
        answer_started = False
        for line in lines:
            if answer_started:
                # Everything after the answer marker is kept verbatim.
                out_lines.append(line)
                continue
            if not skipping and reasoning_start.match(line):
                skipping = True
                continue
            if skipping:
                m = answer_start.match(line)
                if m:
                    answer_started = True
                    skipping = False
                    out_lines.append(line[m.end():].lstrip())
                continue
            m = answer_start.match(line)
            if m:
                answer_started = True
                out_lines.append(line[m.end():].lstrip())
            else:
                out_lines.append(line)
        t2 = "\n".join(out_lines).strip()
        # Fall back to the tag-stripped text when nothing survived.
        return t2 if t2 else t
    # 2.2) Fallback: the text opens with a reasoning header but has no
    # answer marker; drop the first paragraph (up to the first blank line).
    if has_reasoning:
        first_blank_idx = None
        for idx, line in enumerate(lines):
            if line.strip() == "":
                first_blank_idx = idx
                break
        if first_blank_idx is not None and first_blank_idx + 1 < len(lines):
            candidate = "\n".join(lines[first_blank_idx + 1 :]).strip()
            if candidate:
                return candidate
    # 2.3) Fallback: recognise "1. Analyze... 2. ... 6. Output ... Text: ..."
    # chain-of-thought outlines and extract the final body.
    outline_extracted = self._extract_final_answer_from_outline("\n".join(lines).strip())
    if outline_extracted:
        return outline_extracted
    # Merge the line-level edits back (e.g. a removed leading answer marker).
    t = "\n".join(lines).strip()
    # 3) Unwrap <final>...</final> style wrappers: keep the body, drop tags.
    t = re.sub(r"</?\s*(final|answer)\s*>", "", t, flags=re.IGNORECASE).strip()
    return t
def _strip_markdown_syntax(self, text: str) -> str:
"""将常见 Markdown 标记转换为更像纯文本的形式(保留内容,移除格式符)。"""
if not text:
return ""
t = text.replace("\r\n", "\n").replace("\r", "\n")
# 去掉代码块围栏(保留内容)
t = re.sub(r"```[^\n]*\n", "", t)
t = t.replace("```", "")
# 图片/链接:![alt](url) / [text](url)
def _md_image_repl(m: re.Match) -> str:
alt = (m.group(1) or "").strip()
url = (m.group(2) or "").strip()
if alt and url:
return f"{alt}{url}"
return url or alt or ""
def _md_link_repl(m: re.Match) -> str:
label = (m.group(1) or "").strip()
url = (m.group(2) or "").strip()
if label and url:
return f"{label}{url}"
return url or label or ""
t = re.sub(r"!\[([^\]]*)\]\(([^)]+)\)", _md_image_repl, t)
t = re.sub(r"\[([^\]]+)\]\(([^)]+)\)", _md_link_repl, t)
# 行级标记:标题、引用、分割线
cleaned_lines: list[str] = []
for line in t.split("\n"):
line = re.sub(r"^\s{0,3}#{1,6}\s+", "", line) # 标题
line = re.sub(r"^\s{0,3}>\s?", "", line) # 引用
if re.match(r"^\s*(?:-{3,}|\*{3,}|_{3,})\s*$", line):
continue # 分割线整行移除
cleaned_lines.append(line)
t = "\n".join(cleaned_lines)
# 行内代码:`code`
t = re.sub(r"`([^`]+)`", r"\1", t)
# 粗体/删除线(保留文本)
t = t.replace("**", "")
t = t.replace("__", "")
t = t.replace("~~", "")
# 斜体(保留文本,避免误伤乘法/通配符,仅处理成对包裹)
t = re.sub(r"(?<!\*)\*([^*\n]+)\*(?!\*)", r"\1", t)
t = re.sub(r"(?<!_)_([^_\n]+)_(?!_)", r"\1", t)
# 压缩过多空行
t = re.sub(r"\n{3,}", "\n\n", t)
return t.strip()
def _parse_nickname_parts(self, nickname: str) -> tuple:
"""
解析昵称字符串,提取群昵称和微信昵称
输入格式可能是:
- "群昵称=xxx | 微信昵称=yyy"
- "群昵称=xxx"
- "微信昵称=yyy"
- "普通昵称"
Returns:
(group_nickname, wechat_name)
"""
if not nickname:
return ("", "")
group_nickname = ""
wechat_name = ""
if "群昵称=" in nickname or "微信昵称=" in nickname:
parts = nickname.split("|")
for part in parts:
part = part.strip()
if part.startswith("群昵称="):
group_nickname = part[4:].strip()
elif part.startswith("微信昵称="):
wechat_name = part[5:].strip()
else:
# 普通昵称,当作微信昵称
wechat_name = nickname.strip()
return (group_nickname, wechat_name)
def _format_timestamp(self, timestamp) -> str:
"""格式化时间戳为可读字符串"""
if not timestamp:
return ""
try:
if isinstance(timestamp, (int, float)):
from datetime import datetime
dt = datetime.fromtimestamp(timestamp)
return dt.strftime("%Y-%m-%d %H:%M")
elif isinstance(timestamp, str):
# 尝试解析 ISO 格式
if "T" in timestamp:
dt_str = timestamp.split(".")[0] # 去掉毫秒
from datetime import datetime
dt = datetime.fromisoformat(dt_str)
return dt.strftime("%Y-%m-%d %H:%M")
return timestamp[:16] if len(timestamp) > 16 else timestamp
except Exception:
pass
return ""
def _format_user_message_content(self, nickname: str, content: str, timestamp=None, msg_type: str = "text", user_id: str = None) -> str:
    """Prefix a user message with a structured, bracketed metadata header.

    The header lists, in this order and only when non-empty: time, short
    user id, group nickname, WeChat nickname and message type. The message
    body follows on the next line. With no metadata at all, the content is
    returned unchanged.
    """
    group_nickname, wechat_name = self._parse_nickname_parts(nickname)
    time_str = self._format_timestamp(timestamp)

    # Short id: the last 6 characters of the wxid, so the model can tell
    # users apart without seeing the full identifier.
    short_id = ""
    if user_id:
        short_id = user_id[-6:] if len(user_id) > 6 else user_id

    labelled_fields = (
        ("时间:", time_str),
        ("用户ID:", short_id),
        ("群昵称:", group_nickname),
        ("微信昵称:", wechat_name),
        ("类型:", msg_type),
    )
    tags = [f"{label}{value}" for label, value in labelled_fields if value]
    if not tags:
        return content

    prefix = "[" + "][".join(tags) + "]"
    return f"{prefix}\n{content}"
def _append_group_history_messages(self, messages: list, recent_history: list):
    """Append group-chat history entries to the LLM ``messages`` list by role.

    Assistant entries are flattened to text, sanitised and tagged with a
    time header when a timestamp is available. User entries receive the
    structured metadata header; multimodal user entries keep their image
    parts after a single leading text part.
    """
    for entry in recent_history:
        role = entry.get("role") or "user"
        speaker = entry.get("nickname", "")
        body = entry.get("content", "")
        sent_at = entry.get("timestamp")
        speaker_wxid = entry.get("wxid", "")  # unique identifier of the sender

        # --- Bot replies -------------------------------------------------
        if role == "assistant":
            if isinstance(body, list):
                body = self._extract_text_from_multimodal(body)
            # Keep Markdown / chain-of-thought in old history out of context.
            body = self._sanitize_llm_output(body)
            stamp = self._format_timestamp(sent_at)
            if stamp:
                body = f"[时间:{stamp}][类型:assistant]\n{body}"
            messages.append({"role": "assistant", "content": body})
            continue

        # --- Plain-text user messages --------------------------------------
        if not isinstance(body, list):
            formatted = self._format_user_message_content(
                speaker, body, sent_at, "text", speaker_wxid
            )
            messages.append({"role": "user", "content": formatted})
            continue

        # --- Multimodal user messages (may contain images) -----------------
        image_parts = []
        caption = ""
        saw_image = False
        for item in body:
            kind = item.get("type")
            if kind == "text":
                caption = item.get("text", "")
            elif kind == "image_url":
                saw_image = True
                image_parts.append(item)

        header = self._format_user_message_content(
            speaker, caption, sent_at, "image" if saw_image else "text", speaker_wxid
        )
        messages.append({
            "role": "user",
            "content": [{"type": "text", "text": header}, *image_parts],
        })
def _get_bot_nickname(self) -> str:
try:
with open("main_config.toml", "rb") as f:
main_config = tomllib.load(f)
nickname = main_config.get("Bot", {}).get("nickname", "")
return nickname or "机器人"
except Exception:
return "机器人"
def _tool_call_to_action_text(self, function_name: str, arguments: dict) -> str:
args = arguments if isinstance(arguments, dict) else {}
if function_name == "user_signin":
return "签到"
if function_name == "check_profile":
return "查询个人信息"
return f"执行{function_name}"
def _build_tool_calls_context_note(self, tool_calls_data: list) -> str:
actions: list[str] = []
for tool_call in tool_calls_data or []:
function_name = tool_call.get("function", {}).get("name", "")
if not function_name:
continue
arguments_str = tool_call.get("function", {}).get("arguments", "{}")
try:
arguments = json.loads(arguments_str) if arguments_str else {}
except Exception:
arguments = {}
actions.append(self._tool_call_to_action_text(function_name, arguments))
if not actions:
return "(已触发工具处理:上一条请求。结果将发送到聊天中。)"
return f"(已触发工具处理:{''.join(actions)}。结果将发送到聊天中。)"
async def _record_tool_calls_to_context(
    self,
    tool_calls_data: list,
    *,
    from_wxid: str,
    chat_id: str,
    is_group: bool,
    user_wxid: str | None = None,
):
    """Record a "tools were triggered" note in the conversation context.

    The note is added to in-memory context when ``chat_id`` is set, and
    additionally to the group history (as an assistant message) when the
    chat is a group chat with a known ``from_wxid``.
    """
    note = self._build_tool_calls_context_note(tool_calls_data)

    if chat_id:
        self._add_to_memory(chat_id, "assistant", note)

    if not (is_group and from_wxid):
        return

    sender = user_wxid or ""
    history_chat_id = self._get_group_history_chat_id(from_wxid, sender)
    await self._add_to_history(
        history_chat_id,
        self._get_bot_nickname(),
        note,
        role="assistant",
        sender_wxid=sender or None,
    )
def _extract_tool_intent_text(self, user_message: str, tool_query: str | None = None) -> str:
text = tool_query if tool_query is not None else user_message
text = str(text or "").strip()
if not text:
return ""
# 对“聊天记录/视频”等组合消息,尽量只取用户真实提问部分,避免历史文本触发工具误判
markers = (
"[用户的问题]:",
"[用户的问题]",
"[用户的问题]\n",
"[用户的问题]",
)
for marker in markers:
if marker in text:
text = text.rsplit(marker, 1)[-1].strip()
return self._normalize_search_query(text) or text
def _normalize_search_query(self, text: str) -> str:
"""清洗搜索类工具的查询参数,去掉元信息/触发词"""
cleaned = str(text or "").strip()
if not cleaned:
return ""
cleaned = cleaned.replace("【当前消息】", "").strip()
if cleaned.startswith("[") and "\n" in cleaned:
first_line, rest = cleaned.split("\n", 1)
if any(token in first_line for token in ("时间:", "用户ID:", "群昵称:", "微信昵称:", "类型:")):
cleaned = rest.strip()
if cleaned.startswith("@"):
parts = cleaned.split(maxsplit=1)
if len(parts) > 1:
cleaned = parts[1].strip()
cleaned = re.sub(r"^(搜索|搜|查|查询|帮我搜|帮我查|帮我搜索|请搜索|请查)\s*", "", cleaned)
return cleaned.strip()
def _looks_like_info_query(self, text: str) -> bool:
t = str(text or "").strip().lower()
if not t:
return False
# 太短的消息不值得额外走一轮分类
if len(t) < 6:
return False
# 疑问/求评价/求推荐类
if any(x in t for x in ("?", "")):
return True
if re.search(r"(什么|咋|怎么|如何|为啥|为什么|哪|哪里|哪个|多少|推荐|值不值得|值得吗|好不好|靠谱吗|评价|口碑|怎么样|如何评价|近况|最新|最近)", t):
return True
if re.search(r"\\b(what|who|when|where|why|how|details?|impact|latest|news|review|rating|price|info|information)\\b", t):
return True
# 明确的实体/对象询问(公会/游戏/公司/项目等)
if re.search(r"(公会|战队|服务器|区服|游戏|公司|品牌|产品|软件|插件|项目|平台|up主|主播|作者|电影|电视剧|小说)", t) and len(t) >= 8:
return True
return False
def _looks_like_lyrics_query(self, text: str) -> bool:
t = str(text or "").strip().lower()
if not t:
return False
return bool(re.search(
r"(歌词|歌名|哪首歌|哪一首歌|哪首歌曲|哪一首歌曲|谁的歌|谁唱|谁唱的|这句.*(歌|歌词)|这段.*(歌|歌词)|是什么歌|什么歌|是哪首歌|出自.*(歌|歌曲)|台词.*歌|lyric|lyrics)",
t,
))
def _looks_like_image_generation_request(self, text: str) -> bool:
"""判断是否是明确的生图/自拍请求。"""
t = str(text or "").strip().lower()
if not t:
return False
if re.search(r"(画一张|画张|画一幅|画幅|画一个|画个|画一下|画图|绘图|绘制|作画|出图|生成图片|生成照片|生成相片|文生图|图生图|以图生图)", t):
return True
if re.search(r"(生成|做|给我|帮我).{0,4}(一张|一幅|一个|张|个).{0,8}(图|图片|照片|自拍|自拍照|自画像)", t):
return True
if re.search(r"(来|发).{0,2}(一张|一幅|一个|张|个).{0,10}(图|图片|照片|自拍|自拍照|自画像)", t):
return True
if re.search(r"(来|发|给我|给|看看|看下|看一看).{0,4}(自拍|自拍照|自画像)", t):
return True
if re.search(r"(看看|看下|看一看|来点|来张|发|给我).{0,4}(腿|白丝|黑丝|丝袜|福利|福利图|色图|涩图|写真)", t):
return True
if re.search(r"(白丝|黑丝|丝袜|福利|福利图|色图|涩图|写真).{0,6}(图|图片|照片|自拍|来一张|来点|发一张)", t):
return True
if re.search(r"(看看腿|看腿|来点福利|来张福利|发点福利|来张白丝|来张黑丝)", t):
return True
return False
def _extract_legacy_text_search_tool_call(self, text: str) -> tuple[str, dict] | None:
"""
解析模型偶发输出的“文本工具调用”写法(例如 tavilywebsearch{query:...}),并转换为真实工具调用参数。
"""
raw = str(text or "")
if not raw:
return None
# 去掉 <ctrl46> 之类的控制标记
cleaned = re.sub(r"<ctrl\d+>", "", raw, flags=re.IGNORECASE)
m = re.search(
r"(?i)\b(?P<tool>tavilywebsearch|tavily_web_search|web_search)\s*\{\s*query\s*[:=]\s*(?P<q>[^{}]{1,800})\}",
cleaned,
)
if not m:
m = re.search(
r"(?i)\b(?P<tool>tavilywebsearch|tavily_web_search|web_search)\s*\(\s*query\s*[:=]\s*(?P<q>[^\)]{1,800})\)",
cleaned,
)
if not m:
return None
tool = str(m.group("tool") or "").strip().lower()
query = str(m.group("q") or "").strip().strip("\"'`")
if not query:
return None
# 统一映射到项目实际存在的工具名
if tool in ("tavilywebsearch", "tavily_web_search"):
tool_name = "tavily_web_search"
else:
tool_name = "web_search"
return tool_name, {"query": query[:400]}
def _extract_legacy_text_image_tool_call(self, text: str) -> tuple[str, dict] | None:
    """Parse an image tool call that the model wrote out as plain text.

    Three strategies are tried in order:

    1. Python-call style: ``draw_image("...")`` or ``print(draw_image("..."))``.
    2. JSON blobs, either inside fenced code blocks or starting with a
       ``"name"`` / ``"tool"`` / ``"action"`` key, with several fallbacks
       for malformed argument strings.
    3. A loose regex scan for a name and a ``prompt`` value, which also
       tolerates single quotes / non-strict JSON.

    Returns:
        ``(tool_name, arguments)`` where ``arguments`` always contains a
        ``"prompt"`` string, or ``None`` when nothing could be extracted.
    """
    raw = str(text or "")
    if not raw:
        return None
    # Python-code style: print(draw_image("...")) / draw_image("...")
    py_call = re.search(
        r"(?is)(?:print\s*\(\s*)?"
        r"(draw_image|generate_image|nano_ai_image_generation|flow2_ai_image_generation|"
        r"jimeng_ai_image_generation|kiira2_ai_image_generation)\s*"
        r"\(\s*([\"'])([\s\S]{1,2000}?)\2\s*\)\s*\)?",
        raw,
    )
    if py_call:
        name = py_call.group(1).strip()
        prompt = py_call.group(3).strip()
        if prompt:
            return name, {"prompt": prompt}
    # Collect JSON candidates: fenced blocks first, then one bare object.
    candidates = []
    for m in re.finditer(r"```(?:json)?\s*({[\s\S]{20,2000}})\s*```", raw, flags=re.IGNORECASE):
        candidates.append(m.group(1))
    m = re.search(r"(\{\s*\"(?:name|tool|action)\"\s*:\s*\"[^\"]+\"[\s\S]{0,2000}\})", raw)
    if m:
        candidates.append(m.group(1))
    for blob in candidates:
        try:
            data = json.loads(blob)
        except Exception:
            continue
        if not isinstance(data, dict):
            continue
        # The tool name may live under any of several keys.
        name = str(
            data.get("name")
            or data.get("tool")
            or data.get("action")
            or data.get("Action")
            or ""
        ).strip()
        if not name:
            continue
        args = data.get("arguments", None)
        if args in (None, "", {}):
            # Fall back to the alternative spellings of the arguments key.
            args = (
                data.get("actioninput")
                or data.get("action_input")
                or data.get("actionInput")
                or data.get("input")
                or {}
            )
        if isinstance(args, str):
            try:
                args = json.loads(args)
            except Exception:
                raw_args = str(args).strip()
                parsed_args = None
                # Some models emit actioninput as a "JSON-like string" with
                # incomplete escaping; try to repair it.
                try:
                    parsed_args = json.loads(raw_args.replace('\\"', '"'))
                except Exception:
                    pass
                if not isinstance(parsed_args, dict):
                    prompt_match = re.search(
                        r"(?i)[\"']?prompt[\"']?\s*[:=]\s*[\"']([\s\S]{1,2000}?)[\"']",
                        raw_args,
                    )
                    if prompt_match:
                        parsed_args = {"prompt": prompt_match.group(1).strip()}
                        # NOTE(review): the aspect-ratio lookup is nested under
                        # the prompt match so parsed_args is a dict here -
                        # confirm against the original file's indentation.
                        ratio_match = re.search(
                            r"(?i)[\"']?(?:aspectratio|aspect_ratio)[\"']?\s*[:=]\s*[\"']([^\"']{1,30})[\"']",
                            raw_args,
                        )
                        if ratio_match:
                            parsed_args["aspectratio"] = ratio_match.group(1).strip()
                # Last resort: treat the whole string as the prompt.
                args = parsed_args if isinstance(parsed_args, dict) else {"prompt": raw_args}
        if not isinstance(args, dict):
            continue
        prompt = args.get("prompt") or args.get("text") or args.get("query") or args.get("description")
        if not prompt or not isinstance(prompt, str):
            continue
        # Copy before normalising so the parsed data is not mutated.
        normalized_args = dict(args)
        normalized_args["prompt"] = prompt.strip()
        return name, normalized_args
    # Fallback: pull name/prompt out with regexes (single quotes and
    # non-strict JSON allowed).
    name_match = re.search(
        r"(?i)[\"'](?:name|tool|action)[\"']\s*:\s*[\"']([^\"']+)[\"']",
        raw,
    )
    prompt_match = re.search(
        r"(?i)[\"']prompt[\"']\s*:\s*[\"']([\s\S]{1,2000}?)[\"']",
        raw,
    )
    if prompt_match:
        # Without a recognisable name, default to the generic "draw_image".
        name = name_match.group(1) if name_match else "draw_image"
        prompt = prompt_match.group(1).strip()
        if prompt:
            return name, {"prompt": prompt}
    return None
def _resolve_image_tool_alias(
self,
requested_name: str,
allowed_tool_names: set[str],
available_tool_names: set[str],
loose_image_tool: bool,
) -> str | None:
"""将模型输出的绘图工具别名映射为实际工具名。"""
name = (requested_name or "").strip().lower()
if not name:
return None
# 严格遵守本轮工具选择结果:本轮未开放绘图工具时,不允许任何文本兜底触发
if not allowed_tool_names:
return None
if name in available_tool_names:
if name in allowed_tool_names or loose_image_tool:
return name
return None
alias_map = {
"draw_image": "nano_ai_image_generation",
"image_generation": "nano_ai_image_generation",
"image_generate": "nano_ai_image_generation",
"make_image": "nano_ai_image_generation",
"create_image": "nano_ai_image_generation",
"generate_image": "generate_image",
"nanoaiimage_generation": "nano_ai_image_generation",
"flow2aiimage_generation": "flow2_ai_image_generation",
"jimengaiimage_generation": "jimeng_ai_image_generation",
"kiira2aiimage_generation": "kiira2_ai_image_generation",
}
mapped = alias_map.get(name)
if mapped and mapped in available_tool_names:
if mapped in allowed_tool_names or loose_image_tool:
return mapped
for fallback in ("nano_ai_image_generation", "generate_image", "flow2_ai_image_generation"):
if fallback in available_tool_names and (fallback in allowed_tool_names or loose_image_tool):
return fallback
return None
def _should_allow_music_followup(self, messages: list, tool_calls_data: list) -> bool:
if not tool_calls_data:
return False
has_search_tool = any(
(tc or {}).get("function", {}).get("name", "") in ("tavily_web_search", "web_search")
for tc in (tool_calls_data or [])
)
if not has_search_tool:
return False
user_text = ""
for msg in reversed(messages or []):
if msg.get("role") == "user":
user_text = self._extract_text_from_multimodal(msg.get("content"))
break
if not user_text:
return False
return self._looks_like_lyrics_query(user_text)
async def _select_tools_for_message_async(self, tools: list, *, user_message: str, tool_query: str | None = None) -> list:
    """Async entry point for tool selection.

    Same as the legacy behaviour: delegates to the rule-based
    ``_select_tools_for_message`` and returns its result.
    """
    selected = self._select_tools_for_message(
        tools,
        user_message=user_message,
        tool_query=tool_query,
    )
    return selected
def _select_tools_for_message(self, tools: list, *, user_message: str, tool_query: str | None = None) -> list:
    """Select, by keyword rules, which tools to expose to the model.

    When ``tools.smart_select`` is disabled in the plugin config, or no
    intent text can be extracted, ``tools`` is returned unchanged.
    Otherwise the message is matched against a series of patterns, each of
    which adds tool names to an allow-set; only tools whose function name
    is in that set are returned. With no match at all the result is an
    empty list.

    Args:
        tools: Tool definitions with a ``function.name`` entry.
        user_message: The raw user message.
        tool_query: Optional override text used for intent detection.

    Returns:
        The selected subset of ``tools``, in their original order.
    """
    tools_config = (self.config or {}).get("tools", {})
    if not tools_config.get("smart_select", False):
        return tools
    raw_intent_text = str(tool_query if tool_query is not None else user_message).strip()
    raw_t = raw_intent_text.lower()
    intent_text = self._extract_tool_intent_text(user_message, tool_query=tool_query)
    if not intent_text:
        return tools
    t = intent_text.lower()
    allow: set[str] = set()
    available_tool_names = {
        (tool or {}).get("function", {}).get("name", "")
        for tool in (tools or [])
        if (tool or {}).get("function", {}).get("name")
    }
    # Hard fallback for explicit search intent: whenever a search tool is
    # available this round, let it through.
    # Note: explicit search intent must be judged on the ORIGINAL text, not
    # only on the cleaned intent_text; otherwise "搜索下 xxx" is cleaned to
    # "xxx" and misread as having no search intent.
    raw_has_url = bool(re.search(r"(https?://|www\.)", raw_intent_text, flags=re.IGNORECASE))
    explicit_read_web_intent = bool(re.search(
        r"((阅读|读一下|读下|看下|看看|解析|总结|介绍).{0,8}(网页|网站|网址|链接))"
        r"|((网页|网站|网址|链接).{0,8}(内容|正文|页面|信息|原文))",
        raw_t,
    ))
    if raw_has_url and re.search(r"(阅读|读一下|读下|看下|看看|解析|总结|介绍|提取)", raw_t):
        explicit_read_web_intent = True
    explicit_search_intent = bool(re.search(
        r"(联网|搜索|搜一下|搜一搜|搜搜|搜索下|搜下|查一下|查资料|查新闻|查价格|帮我搜|帮我查)",
        raw_t,
    )) or explicit_read_web_intent
    if explicit_search_intent:
        for candidate in ("tavily_web_search", "web_search"):
            if candidate in available_tool_names:
                allow.add(candidate)
    # Sign-in / personal profile
    if re.search(r"(用户签到|签到|签个到)", t):
        allow.add("user_signin")
    if re.search(r"(个人信息|我的信息|我的积分|查积分|积分多少|连续签到|连签|我的资料)", t):
        allow.add("check_profile")
    # Deer check-in
    if re.search(r"(鹿打卡|鹿签到)", t):
        allow.add("deer_checkin")
    if re.search(r"(补签|补打卡)", t):
        allow.add("makeup_checkin")
    if re.search(r"(鹿.*(日历|月历|打卡日历))|((日历|月历|打卡日历).*鹿)", t):
        allow.add("view_calendar")
    # Search / news
    if re.search(r"(联网|搜索|搜一下|搜一搜|搜搜|帮我搜|搜新闻|搜资料|查资料|查新闻|查价格|\bsearch\b|\bgoogle\b|\blookup\b|\bfind\b|\bnews\b|\blatest\b|\bdetails?\b|\bimpact\b)", t):
        # Cover both the legacy tool name and the current plugin's name.
        allow.add("tavily_web_search")
        allow.add("web_search")
    # Implicit information lookup: the user asks about a concrete entity /
    # reputation / review without explicitly saying "search" or "online".
    if re.search(r"(怎么样|如何|评价|口碑|靠谱吗|值不值得|值得吗|好不好|推荐|牛不牛|强不强|厉不厉害|有名吗|什么来头|背景|近况|最新|最近)", t) and re.search(
        r"(公会|战队|服务器|区服|游戏|公司|品牌|店|商家|产品|软件|插件|项目|平台|up主|主播|作者|电影|电视剧|小说|手游|网游)",
        t,
    ):
        allow.add("tavily_web_search")
        allow.add("web_search")
    if self._looks_like_lyrics_query(intent_text):
        allow.add("tavily_web_search")
        allow.add("web_search")
    if re.search(r"(60秒|每日新闻|早报|新闻图片|读懂世界)", t):
        allow.add("get_daily_news")
    if re.search(r"(epic|喜加一|免费游戏)", t):
        allow.add("get_epic_free_games")
    # Music / short dramas
    # Only opened on explicit commands ("play a song", "find a song", ...)
    # so ordinary chat does not trigger them.
    if re.search(r"(点歌|来(?:一首|首)|播放(?:一首|首)?|放歌|听(?:一首|首)|搜歌|找歌)", t):
        allow.add("search_music")
    if re.search(r"(短剧|搜短剧|找短剧)", t):
        allow.add("search_playlet")
    # Group chat summary
    if re.search(r"(群聊总结|生成总结|总结一下|今日总结|昨天总结|群总结)", t):
        allow.add("generate_summary")
    # Entertainment
    if re.search(r"(疯狂星期四|v我50|kfc)", t):
        allow.add("get_kfc")
    if re.search(r"(随机图片|来张图|来个图|随机图)", t):
        allow.add("get_random_image")
    if re.search(r"(随机视频|来个视频|随机短视频)", t):
        allow.add("get_random_video")
    # Image / video generation (only opened on an explicit request)
    if self._looks_like_image_generation_request(intent_text) or (
        # Explicit drawing verbs / modes
        re.search(r"(画一张|画张|画一幅|画幅|画一个|画个|画一下|画图|绘图|绘制|作画|出图|生成图片|生成照片|生成相片|文生图|图生图|以图生图)", t)
        # "generate / make / give me" + counter word + "picture / photo"
        or re.search(r"(生成|做|给我|帮我).{0,4}(一张|一幅|一个|张|个).{0,8}(图|图片|照片|自拍|自拍照|自画像)", t)
        # "come / send" + counter word + "picture"
        or re.search(r"(来|发).{0,2}(一张|一幅|一个|张|个).{0,10}(图|图片|照片|自拍|自拍照|自画像)", t)
        # "send / give me" + "selfie / self-portrait"
        or re.search(r"(来|发|给我|给).{0,3}(自拍|自拍照|自画像)", t)
        # Colloquial requests (legs / stockings / "fan service" pictures)
        or re.search(r"(看看|看下|看一看|来点|来张|发|给我).{0,4}(腿|白丝|黑丝|丝袜|福利|福利图|色图|涩图|写真)", t)
        or re.search(r"(白丝|黑丝|丝袜|福利|福利图|色图|涩图|写真).{0,6}(图|图片|照片|自拍|来一张|来点|发一张)", t)
        or re.search(r"(看看腿|看腿|来点福利|来张福利|发点福利|来张白丝|来张黑丝)", t)
        # Redraw / retry (the word "picture" is often omitted in context)
        or re.search(r"(重画|重新画|再画|重来一张|再来一张|重做一张)", t)
        or re.fullmatch(r"(重来|再来|重来一次|再来一次|重新来)", t)
    ):
        allow.update({
            "nano_ai_image_generation",
            "flow2_ai_image_generation",
            "jimeng_ai_image_generation",
            "kiira2_ai_image_generation",
            "generate_image",
        })
    if re.search(
        r"(生成视频|做个视频|视频生成|sora|grok|/视频)"
        r"|((生成|制作|做|来|发|拍|整).{0,10}(视频|短视频|短片|片子|mv|vlog))"
        r"|((视频|短视频|短片|片子|mv|vlog).{0,8}(生成|制作|做|来|发|整|安排))"
        r"|(来一段.{0,8}(视频|短视频|短片))",
        t,
    ):
        allow.add("sora_video_generation")
        allow.add("grok_video_generation")
    # If a domain-specific tool (music / short drama) already matched and
    # the user did not explicitly ask for the web / a page / a link / a
    # source, do not expose web search as well, to avoid accidental calls.
    explicit_web = bool(re.search(r"(联网|网页|网站|网址|链接|来源)", t))
    if not explicit_web and {"search_music", "search_playlet"} & allow:
        allow.discard("tavily_web_search")
        allow.discard("web_search")
    # Strict mode: with no clear tool intent, expose no tools at all.
    if not allow:
        return []
    selected = []
    for tool in tools or []:
        name = tool.get("function", {}).get("name", "")
        if name and name in allow:
            selected.append(tool)
    if explicit_search_intent:
        # Log the raw text, cleaned text and outcome of the search fallback.
        selected_names = [tool.get("function", {}).get("name", "") for tool in selected]
        logger.info(
            f"[工具选择-搜索兜底] raw={raw_intent_text[:80]} | cleaned={intent_text[:80]} "
            f"| allow={sorted(list(allow))} | selected={selected_names}"
        )
    return selected
async def _handle_context_stats(self, bot, from_wxid: str, user_wxid: str, is_group: bool):
    """Handle the context statistics command.

    Estimates the tokens used by the system prompt, persistent memories
    and the conversation context, then sends a summary to ``from_wxid``.
    Group chats are measured from stored history, private chats from the
    in-memory messages. Any failure is logged and reported to the chat.
    """
    try:
        chat_id = self._get_chat_id(from_wxid, user_wxid, is_group)
        # Token estimate for persistent memories
        memory_chat_id = from_wxid if is_group else user_wxid
        persistent_memories = self._get_persistent_memories(memory_chat_id) if memory_chat_id else []
        persistent_tokens = 0
        if persistent_memories:
            persistent_tokens += self._estimate_tokens("【持久记忆】以下是用户要求你记住的重要信息:\n")
            for m in persistent_memories:
                # Only the first 10 characters of the stored time are counted.
                mem_time = m['time'][:10] if m['time'] else ""
                persistent_tokens += self._estimate_tokens(f"- [{mem_time}] {m['nickname']}: {m['content']}\n")
        if is_group:
            # Group chat: uses the history mechanism
            history_chat_id = self._get_group_history_chat_id(from_wxid, user_wxid)
            history = await self._load_history(history_chat_id)
            history = self._filter_history_by_window(history)
            max_context = self.config.get("history", {}).get("max_context", 50)
            # The context that would actually be sent to the AI
            context_messages = history[-max_context:] if len(history) > max_context else history
            # Token estimate
            context_tokens = 0
            for msg in context_messages:
                msg_content = msg.get("content", "")
                nickname = msg.get("nickname", "")
                if isinstance(msg_content, list):
                    # Multimodal message
                    for item in msg_content:
                        if item.get("type") == "text":
                            context_tokens += self._estimate_tokens(f"[{nickname}] {item.get('text', '')}")
                        elif item.get("type") == "image_url":
                            # Flat per-image token estimate.
                            context_tokens += 85
                else:
                    context_tokens += self._estimate_tokens(f"[{nickname}] {msg_content}")
            # Add the system prompt tokens
            system_tokens = self._estimate_tokens(self.system_prompt)
            total_tokens = system_tokens + persistent_tokens + context_tokens
            # Usage percentage
            context_limit = self.config.get("api", {}).get("context_limit", 200000)
            usage_percent = (total_tokens / context_limit) * 100
            remaining_tokens = context_limit - total_tokens
            msg = f"📊 群聊上下文统计\n\n"
            msg += f"💬 历史总条数: {len(history)}\n"
            msg += f"📤 AI可见条数: {len(context_messages)}/{max_context}\n"
            msg += f"🤖 人设 Token: ~{system_tokens}\n"
            msg += f"📌 持久记忆: {len(persistent_memories)} 条 (~{persistent_tokens} token)\n"
            msg += f"📝 上下文 Token: ~{context_tokens}\n"
            msg += f"📦 总计 Token: ~{total_tokens}\n"
            msg += f"📈 使用率: {usage_percent:.1f}% (剩余 ~{remaining_tokens:,})\n"
            # Vector long-term memory statistics
            if self._vector_memory_enabled and self._chroma_collection:
                try:
                    vm_count = self._chroma_collection.count()
                    vm_watermark_ts = self._vector_watermarks.get(from_wxid, "")
                    ts_display = vm_watermark_ts[:16] if vm_watermark_ts else ""
                    msg += f"🧠 向量记忆: {vm_count} 条摘要 (水位线: {ts_display})\n"
                except Exception:
                    # Best effort: the vector line is omitted on any failure.
                    pass
            msg += f"\n💡 /清空记忆 清空上下文 | /记忆列表 查看持久记忆"
        else:
            # Private chat: uses the memory mechanism
            memory_messages = self._get_memory_messages(chat_id)
            max_messages = self.config.get("memory", {}).get("max_messages", 20)
            # Token estimate
            context_tokens = 0
            for msg in memory_messages:
                context_tokens += self._estimate_message_tokens(msg)
            # Add the system prompt tokens
            system_tokens = self._estimate_tokens(self.system_prompt)
            total_tokens = system_tokens + persistent_tokens + context_tokens
            # Usage percentage
            context_limit = self.config.get("api", {}).get("context_limit", 200000)
            usage_percent = (total_tokens / context_limit) * 100
            remaining_tokens = context_limit - total_tokens
            msg = f"📊 私聊上下文统计\n\n"
            msg += f"💬 记忆条数: {len(memory_messages)}/{max_messages}\n"
            msg += f"🤖 人设 Token: ~{system_tokens}\n"
            msg += f"📌 持久记忆: {len(persistent_memories)} 条 (~{persistent_tokens} token)\n"
            msg += f"📝 上下文 Token: ~{context_tokens}\n"
            msg += f"📦 总计 Token: ~{total_tokens}\n"
            msg += f"📈 使用率: {usage_percent:.1f}% (剩余 ~{remaining_tokens:,})\n"
            msg += f"\n💡 /清空记忆 清空上下文 | /记忆列表 查看持久记忆"
        await bot.send_text(from_wxid, msg)
        logger.info(f"已发送上下文统计: {chat_id}")
    except Exception as e:
        logger.error(f"获取上下文统计失败: {e}")
        await bot.send_text(from_wxid, f"❌ 获取上下文统计失败: {str(e)}")
async def _handle_switch_prompt(self, bot, from_wxid: str, content: str):
    """Handle the "switch persona" command.

    The command text is ``"<command> <filename>"``. The named file is
    loaded from the plugin's ``prompts`` directory and becomes the new
    system prompt; the chosen filename is recorded in the plugin config.

    The filename comes from a chat message, so the resolved path must stay
    inside the ``prompts`` directory and must be a regular file. Anything
    else (``../`` traversal, absolute paths, directories) is reported as
    "file does not exist".

    Args:
        bot: Bot client used to send the reply.
        from_wxid: Chat to reply to.
        content: The full command text.
    """
    try:
        # Extract the filename
        parts = content.split(maxsplit=1)
        if len(parts) < 2:
            await bot.send_text(from_wxid, "❌ 请指定人设文件名\n格式:/切人设 文件名.txt")
            return
        filename = parts[1].strip()

        # Resolve both paths so "../" segments cannot escape the prompts
        # directory. Symlinks pointing outside it are rejected too.
        prompts_dir = (Path(__file__).parent / "prompts").resolve()
        prompt_path = (prompts_dir / filename).resolve()
        if not prompt_path.is_relative_to(prompts_dir) or not prompt_path.is_file():
            await bot.send_text(from_wxid, f"❌ 人设文件不存在: {filename}")
            return

        # Load the new persona
        new_prompt = prompt_path.read_text(encoding="utf-8").strip()

        # Apply it; create the config section if it is missing.
        self.system_prompt = new_prompt
        self.config.setdefault("prompt", {})["system_prompt_file"] = filename

        await bot.send_text(from_wxid, f"✅ 已切换人设: {filename}")
        logger.success(f"管理员切换人设: {filename}")
    except Exception as e:
        logger.error(f"切换人设失败: {e}")
        await bot.send_text(from_wxid, f"❌ 切换人设失败: {str(e)}")
@on_text_message(priority=80)
async def handle_message(self, bot, message: dict):
    """Handle a text message: built-in commands first, then the AI reply.

    Steps, in order:

    1. Read the bot wxid, nickname and admin list from ``main_config.toml``.
    2. Match the text (with a leading bot mention removed in groups) against
       the built-in slash commands; a matched command is answered and the
       method returns ``False``.
    3. Decide via ``_should_reply`` whether the AI should answer, and store
       the message in the group history when capturing is enabled.
    4. Apply the per-user rate limit, call the AI with retries, send the
       sanitized answer and record it in memory / group history.

    Args:
        bot: Client object used to send text and images.
        message: Message dict; the keys read here are ``Content``,
            ``FromWxid``, ``SenderWxid``, ``IsGroup`` and the internal
            ``_auto_reply_triggered`` / ``_auto_reply_context`` /
            ``_disable_tools`` flags.

    Returns:
        ``False`` after a command or a rate-limit notice was handled,
        otherwise ``None``.
    """
    content = message.get("Content", "").strip()
    from_wxid = message.get("FromWxid", "")
    sender_wxid = message.get("SenderWxid", "")
    is_group = message.get("IsGroup", False)

    # In a group the author is the sender; in a private chat it is the peer.
    user_wxid = sender_wxid if is_group else from_wxid

    # Bot identity and admin list; loaded once and reused for the whole call.
    with open("main_config.toml", "rb") as f:
        main_config = tomllib.load(f)
    bot_config = main_config.get("Bot", {})
    bot_wxid = bot_config.get("wxid", "")
    bot_nickname = bot_config.get("nickname", "")
    admins = bot_config.get("admins", [])
    is_admin = user_wxid in admins

    command_content = content
    if is_group and bot_nickname:
        command_content = self._strip_leading_bot_mention(content, bot_nickname)

    # --- persona list (exact match) ---
    if command_content == "/人设列表":
        await self._handle_list_prompts(bot, from_wxid)
        return False

    # --- nickname test: global WeChat nickname vs. in-group display name ---
    if command_content == "/昵称测试":
        if not is_group:
            await bot.send_text(from_wxid, "该指令仅支持群聊:/昵称测试")
            return False
        wechat_nickname = await self._get_user_nickname(bot, from_wxid, user_wxid, is_group)
        group_nickname = await self._get_group_display_name(bot, from_wxid, user_wxid, force_refresh=True)
        wechat_nickname = self._sanitize_speaker_name(wechat_nickname) or "(未获取到)"
        group_nickname = self._sanitize_speaker_name(group_nickname) or "(未设置/未获取到)"
        await bot.send_text(
            from_wxid,
            f"微信昵称: {wechat_nickname}\n"
            f"群昵称: {group_nickname}",
        )
        return False

    # --- switch persona (prefix match, admins only) ---
    if command_content.startswith(("/切人设 ", "/切换人设 ")):
        if is_admin:
            await self._handle_switch_prompt(bot, from_wxid, command_content)
        else:
            await bot.send_text(from_wxid, "❌ 仅管理员可以切换人设")
        return False

    # --- clear conversation memory ---
    clear_command = self.config.get("memory", {}).get("clear_command", "/清空记忆")
    if command_content == clear_command:
        chat_id = self._get_chat_id(from_wxid, user_wxid, is_group)
        self._clear_memory(chat_id)
        if is_group and self.store:
            # Groups additionally keep a shared history that must be wiped.
            history_chat_id = self._get_group_history_chat_id(from_wxid, user_wxid)
            await self.store.clear_group_history(history_chat_id)
            # Reset the vector-summary watermark of this group.
            if self._vector_memory_enabled and from_wxid in self._vector_watermarks:
                self._vector_watermarks.pop(from_wxid, None)
                self._save_watermarks()
            await bot.send_text(from_wxid, "✅ 已清空当前群聊的记忆和历史记录")
        else:
            await bot.send_text(from_wxid, "✅ 已清空当前会话的记忆")
        return False

    # --- context statistics ---
    if command_content in ("/context", "/上下文"):
        await self._handle_context_stats(bot, from_wxid, user_wxid, is_group)
        return False

    # --- legacy group-history keys: scan (admins only) ---
    if command_content in ("/旧群历史", "/legacy_history"):
        if not is_admin:
            await bot.send_text(from_wxid, "❌ 仅管理员可执行该指令")
        elif not self.store:
            # An admin without a store must not be told "admins only".
            await bot.send_text(from_wxid, "❌ 历史存储不可用")
        else:
            legacy_keys = self.store.find_legacy_group_history_keys()
            if legacy_keys:
                await bot.send_text(
                    from_wxid,
                    f"⚠️ 检测到 {len(legacy_keys)} 个旧版群历史 key(safe_id 写入)。\n"
                    f"如需清理请发送 /清理旧群历史",
                )
            else:
                await bot.send_text(from_wxid, "✅ 未发现旧版群历史 key")
        return False

    # --- legacy group-history keys: delete (admins only) ---
    if command_content in ("/清理旧群历史", "/clean_legacy_history"):
        if not is_admin:
            await bot.send_text(from_wxid, "❌ 仅管理员可执行该指令")
        elif not self.store:
            await bot.send_text(from_wxid, "❌ 历史存储不可用")
        else:
            legacy_keys = self.store.find_legacy_group_history_keys()
            deleted = self.store.delete_legacy_group_history_keys(legacy_keys)
            await bot.send_text(
                from_wxid,
                f"✅ 已清理旧版群历史 key: {deleted}",
            )
        return False

    # --- memory status (admins only) ---
    if command_content == "/记忆状态":
        if is_admin:
            if is_group:
                history_chat_id = self._get_group_history_chat_id(from_wxid, user_wxid)
                history = await self._load_history(history_chat_id)
                history = self._filter_history_by_window(history)
                max_context = self.config.get("history", {}).get("max_context", 50)
                context_count = min(len(history), max_context)
                msg = f"📊 群聊记忆: {len(history)}\n"
                msg += f"💬 AI可见: 最近 {context_count}"
                await bot.send_text(from_wxid, msg)
            else:
                chat_id = self._get_chat_id(from_wxid, user_wxid, is_group)
                memory = self._get_memory_messages(chat_id)
                msg = f"📊 私聊记忆: {len(memory)}"
                await bot.send_text(from_wxid, msg)
        else:
            await bot.send_text(from_wxid, "❌ 仅管理员可以查看记忆状态")
        return False

    # --- persistent memory: add ---
    if command_content.startswith("/记录 "):
        memory_content = command_content.removeprefix("/记录 ").strip()
        if memory_content:
            nickname = await self._get_user_display_label(bot, from_wxid, user_wxid, is_group)
            # Keyed by group id in groups, by user id in private chats.
            memory_chat_id = from_wxid if is_group else user_wxid
            chat_type = "group" if is_group else "private"
            memory_id = self._add_persistent_memory(
                memory_chat_id, chat_type, user_wxid, nickname, memory_content
            )
            await bot.send_text(from_wxid, f"✅ 已记录到持久记忆 (ID: {memory_id})")
            logger.info(f"添加持久记忆: {memory_chat_id} - {memory_content[:30]}...")
        else:
            await bot.send_text(from_wxid, "❌ 请输入要记录的内容\n格式:/记录 要记住的内容")
        return False

    # --- persistent memory: list (everyone) ---
    if command_content in ("/记忆列表", "/持久记忆"):
        memory_chat_id = from_wxid if is_group else user_wxid
        memories = self._get_persistent_memories(memory_chat_id)
        if memories:
            msg = f"📋 持久记忆列表 (共 {len(memories)} 条)\n\n"
            for m in memories:
                time_str = m['time'][:16] if m['time'] else "未知"
                content_preview = m['content'][:30] + "..." if len(m['content']) > 30 else m['content']
                msg += f"[{m['id']}] {m['nickname']}: {content_preview}\n   📅 {time_str}\n"
            msg += f"\n💡 删除记忆:/删除记忆 ID (管理员)"
        else:
            msg = "📋 暂无持久记忆"
        await bot.send_text(from_wxid, msg)
        return False

    # --- persistent memory: delete one (admins only) ---
    if command_content.startswith("/删除记忆 "):
        if is_admin:
            try:
                memory_id = int(command_content.removeprefix("/删除记忆 ").strip())
            except ValueError:
                await bot.send_text(from_wxid, "❌ 请输入有效的记忆ID\n格式:/删除记忆 ID")
            else:
                memory_chat_id = from_wxid if is_group else user_wxid
                if self._delete_persistent_memory(memory_chat_id, memory_id):
                    await bot.send_text(from_wxid, f"✅ 已删除记忆 ID: {memory_id}")
                else:
                    await bot.send_text(from_wxid, f"❌ 未找到记忆 ID: {memory_id}")
        else:
            await bot.send_text(from_wxid, "❌ 仅管理员可以删除持久记忆")
        return False

    # --- persistent memory: delete all (admins only) ---
    if command_content == "/清空持久记忆":
        if is_admin:
            memory_chat_id = from_wxid if is_group else user_wxid
            deleted_count = self._clear_persistent_memories(memory_chat_id)
            await bot.send_text(from_wxid, f"✅ 已清空 {deleted_count} 条持久记忆")
        else:
            await bot.send_text(from_wxid, "❌ 仅管理员可以清空持久记忆")
        return False

    # --- vector memory viewer (groups only) ---
    if command_content in ("/向量记忆", "/vector_memory"):
        if not is_group:
            await bot.send_text(from_wxid, "❌ 向量记忆仅在群聊中可用")
            return False
        if not self._vector_memory_enabled:
            await bot.send_text(from_wxid, "❌ 向量记忆功能未启用")
            return False
        items = self._get_vector_memories_for_display(from_wxid)
        if not items:
            await bot.send_text(from_wxid, "📭 当前群聊暂无向量记忆")
            return False
        try:
            html = self._build_vector_memory_html(items, from_wxid)
            img_path = await self._render_vector_memory_image(html)
            if img_path:
                try:
                    await bot.send_image(from_wxid, img_path)
                finally:
                    # Remove the temporary image even when sending failed.
                    try:
                        Path(img_path).unlink(missing_ok=True)
                    except OSError as cleanup_err:
                        logger.debug(f"[VectorMemory] 临时文件清理失败: {cleanup_err}")
            else:
                # Rendering failed: fall back to a plain-text listing.
                msg = f"🧠 向量记忆 (共 {len(items)} 条摘要)\n\n"
                for i, item in enumerate(items, 1):
                    preview = item['summary'][:80] + "..." if len(item['summary']) > 80 else item['summary']
                    msg += f"#{i} {preview}\n\n"
                await bot.send_text(from_wxid, msg.strip())
        except Exception as e:
            logger.error(f"[VectorMemory] 展示失败: {e}")
            await bot.send_text(from_wxid, f"❌ 向量记忆展示失败: {e}")
        return False

    # --- not a command: decide whether the AI should answer ---
    should_reply = self._should_reply(message, content, bot_wxid)

    # Display label of the author, used for the history record.
    nickname = await self._get_user_display_label(bot, from_wxid, user_wxid, is_group)

    # The mention-stripped text is only needed when a reply will be produced.
    actual_content = ""
    if should_reply:
        actual_content = self._extract_content(message, content)

    # Store the message in the group history. Messages re-injected by
    # AutoReply are skipped because they were stored on their first pass.
    if is_group and not message.get('_auto_reply_triggered'):
        if self._should_capture_group_history(is_triggered=bool(should_reply)):
            # In mention mode the "@bot" part is only a trigger; store the
            # stripped text so the same sentence never appears in two forms.
            trigger_mode = self.config.get("behavior", {}).get("trigger_mode", "mention")
            history_content = content
            if trigger_mode == "mention" and should_reply and actual_content:
                history_content = actual_content
            history_chat_id = self._get_group_history_chat_id(from_wxid, user_wxid)
            await self._add_to_history(history_chat_id, nickname, history_content, sender_wxid=user_wxid)
            # Long-term vector memory: check whether a summary is due.
            if self._vector_memory_enabled:
                await self._maybe_trigger_summarize(from_wxid)

    if not should_reply:
        return

    # Rate limiting applies only to messages that would get a reply.
    allowed, remaining, reset_time = self._check_rate_limit(user_wxid)
    if not allowed:
        rate_limit_config = self.config.get("rate_limit", {})
        msg = rate_limit_config.get("rate_limit_message", "⚠️ 消息太频繁了,请 {seconds} 秒后再试~")
        msg = msg.format(seconds=reset_time)
        await bot.send_text(from_wxid, msg)
        logger.warning(f"用户 {user_wxid} 触发限流,{reset_time}秒后重置")
        return False

    if not actual_content:
        return

    chat_id = self._get_chat_id(from_wxid, user_wxid, is_group)
    async with self._reply_lock_context(chat_id):
        logger.info(f"AI 处理消息: {actual_content[:50]}...")
        try:
            auto_reply_triggered = message.get('_auto_reply_triggered')

            # AutoReply re-injections were already added on the first pass.
            if not auto_reply_triggered:
                self._add_to_memory(chat_id, "user", actual_content)

            # When the group message already went into the history, it must
            # not be appended again or the model would see it twice.
            history_enabled = bool(self.store) and self.config.get("history", {}).get("enabled", True)
            captured_to_history = bool(
                is_group
                and history_enabled
                and not auto_reply_triggered
                and self._should_capture_group_history(is_triggered=True)
            )
            append_user_message = not captured_to_history

            disable_tools = bool(
                auto_reply_triggered
                or message.get("_auto_reply_context")
                or message.get("_disable_tools")
            )

            # Call the AI API, retrying on exceptions and on empty answers.
            max_retries = self.config.get("api", {}).get("max_retries", 2)
            response = None
            for attempt in range(max_retries + 1):
                try:
                    response = await self._call_ai_api(
                        actual_content,
                        bot,
                        from_wxid,
                        chat_id,
                        nickname,
                        user_wxid,
                        is_group,
                        append_user_message=append_user_message,
                        disable_tools=disable_tools,
                    )
                except Exception as e:
                    if attempt >= max_retries:
                        raise
                    logger.warning(f"AI API 调用失败,重试 {attempt + 1}/{max_retries}: {e}")
                    await asyncio.sleep(1)
                    continue

                # None  -> tool calls were dispatched, nothing to retry
                # ""    -> genuinely empty answer, retry while attempts remain
                # other -> normal answer
                if response is None:
                    logger.info("AI 触发工具调用,已异步处理")
                    break
                if response == "" and attempt < max_retries:
                    logger.warning(f"AI 返回空内容,重试 {attempt + 1}/{max_retries}")
                    await asyncio.sleep(1)
                    continue
                break

            # A None/empty response was handled elsewhere; send text otherwise.
            if response:
                cleaned_response = self._sanitize_llm_output(response)
                if cleaned_response:
                    await bot.send_text(from_wxid, cleaned_response)
                    await self._maybe_send_voice_reply(bot, from_wxid, cleaned_response, message=message)
                    self._add_to_memory(chat_id, "assistant", cleaned_response)

                    # Record the bot reply in the group history, unless the
                    # sync_bot_messages setting is on and the scope is not
                    # per-user.
                    history_config = self.config.get("history", {})
                    sync_bot_messages = history_config.get("sync_bot_messages", False)
                    history_scope = str(history_config.get("scope", "chatroom") or "chatroom").strip().lower()
                    can_rely_on_hook = bool(sync_bot_messages and history_scope not in ("per_user", "user", "peruser"))
                    if is_group and not can_rely_on_hook:
                        # Reuse the config parsed above instead of re-reading the file.
                        history_bot_nickname = bot_config.get("nickname", "机器人")
                        history_chat_id = self._get_group_history_chat_id(from_wxid, user_wxid)
                        await self._add_to_history(
                            history_chat_id,
                            history_bot_nickname,
                            cleaned_response,
                            role="assistant",
                            sender_wxid=user_wxid,
                        )
                    logger.success(f"AI 回复成功: {cleaned_response[:50]}...")
                else:
                    logger.warning("AI 回复清洗后为空(可能只包含思维链/格式标记),已跳过发送")
            else:
                logger.info("AI 回复为空或已通过其他方式发送(如聊天记录)")
        except Exception as e:
            import traceback
            error_detail = traceback.format_exc()
            logger.error(f"AI 处理失败: {type(e).__name__}: {str(e)}")
            logger.error(f"详细错误:\n{error_detail}")
            await bot.send_text(from_wxid, "抱歉,我遇到了一些问题,请稍后再试。")
def _should_reply(self, message: dict, content: str, bot_wxid: str = None) -> bool:
    """Decide whether the AI should answer ``message``.

    Rules, in order:

    * a message flagged ``_auto_reply_triggered`` is always answered;
    * the ``reply_group`` / ``reply_private`` switches in
      ``self.config["behavior"]`` can disable a whole chat type;
    * ``trigger_mode`` then decides: ``"all"`` answers everything,
      ``"mention"`` answers private chats and group messages that mention the
      bot (by wxid in ``Ats`` or by ``@nickname`` in the text), ``"keyword"``
      answers when a configured keyword occurs in the text.

    Args:
        message: Message dict; ``FromWxid``, ``IsGroup``, ``Ats`` and
            ``_auto_reply_triggered`` are read.
        content: Message text.
        bot_wxid: The bot's wxid; read from ``main_config.toml`` when empty.

    Returns:
        ``True`` when a reply should be produced.
    """
    from_wxid = message.get("FromWxid", "")
    logger.debug(f"[AIChat] _should_reply 检查: from={from_wxid}, content={content[:30]}")

    # Messages injected by the AutoReply plugin bypass every other rule.
    if message.get('_auto_reply_triggered'):
        logger.debug(f"[AIChat] AutoReply 触发,返回 True")
        return True

    is_group = message.get("IsGroup", False)
    behavior = self.config["behavior"]

    # Per-chat-type switches.
    if is_group and not behavior.get("reply_group", True):
        logger.debug(f"[AIChat] 群聊回复未启用,返回 False")
        return False
    if not is_group and not behavior.get("reply_private", True):
        return False

    trigger_mode = behavior.get("trigger_mode", "mention")

    if trigger_mode == "all":
        return True

    if trigger_mode == "mention":
        # Private chats are always answered in mention mode.
        if not is_group:
            return True

        ats = message.get("Ats", [])

        # Fast path: the mention list already settles it, no file read needed.
        if ats and bot_wxid and bot_wxid in ats:
            return True

        # The nickname (and the wxid, when not supplied) come from the main config.
        with open("main_config.toml", "rb") as f:
            bot_section = tomllib.load(f).get("Bot", {})
        bot_wxid = bot_wxid or bot_section.get("wxid", "")
        bot_nickname = bot_section.get("nickname", "")

        # Check 1: the bot's wxid is in the mention list.
        if ats and bot_wxid and bot_wxid in ats:
            return True

        # Check 2 (fallback for messages without a usable mention list):
        # the text contains "@<bot nickname>".
        if bot_nickname and f"@{bot_nickname}" in content:
            logger.debug(f"通过内容检测到 @{bot_nickname},触发回复")
            return True
        return False

    if trigger_mode == "keyword":
        # A missing list means "no keywords"; empty entries are ignored
        # because "" is a substring of every message.
        keywords = behavior.get("keywords") or []
        return any(kw and kw in content for kw in keywords)

    return False
def _extract_content(self, message: dict, content: str) -> str:
"""提取实际消息内容(去除@等)"""
is_group = message.get("IsGroup", False)
if is_group:
# 群聊消息,去除@部分
# 格式通常是 "@昵称 消息内容"
parts = content.split(maxsplit=1)
if len(parts) > 1 and parts[0].startswith("@"):
return parts[1].strip()
return content.strip()
return content.strip()
def _strip_leading_bot_mention(self, content: str, bot_nickname: str) -> str:
"""去除开头的 @机器人昵称,便于识别命令"""
if not bot_nickname:
return content
prefix = f"@{bot_nickname}"
if not content.startswith(prefix):
return content
parts = content.split(maxsplit=1)
if len(parts) < 2:
return ""
return parts[1].strip()
async def _call_ai_api(
    self,
    user_message: str,
    bot=None,
    from_wxid: str = None,
    chat_id: str = None,
    nickname: str = "",
    user_wxid: str = None,
    is_group: bool = False,
    *,
    append_user_message: bool = True,
    tool_query: str | None = None,
    disable_tools: bool = False,
) -> str | None:
    """Build the prompt, call the dialog API and post-process the answer.

    The system prompt is extended with the current time, the user's nickname,
    optional tool rules, persistent memories and (in groups) retrieved vector
    memories. Context comes from the group history in groups and from the
    per-chat memory otherwise.

    Args:
        user_message: Text of the current user message.
        bot: Client used for interim notices and passed on to tool execution.
        from_wxid: Chat the message came from.
        chat_id: Key of the per-chat memory.
        nickname: Display label of the user.
        user_wxid: wxid of the user.
        is_group: Whether the chat is a group.
        append_user_message: Append ``user_message`` as the final user turn.
        tool_query: Optional text forwarded to tool selection.
        disable_tools: Offer no tools to the model.

    Returns:
        The sanitized answer text, or ``None`` when tool calls were started
        (their results are delivered by the tool-execution path).

    Raises:
        Exception: Whatever the API request raises, after being logged.
    """
    api_config = self.config["api"]

    # Collect the tools offered to the model for this request.
    if disable_tools:
        all_tools = []
        available_tool_names = set()
        tools = []
        logger.info("AutoReply 模式:已禁用工具调用")
    else:
        all_tools = self._collect_tools()
        available_tool_names = {
            t.get("function", {}).get("name", "")
            for t in (all_tools or [])
            if isinstance(t, dict) and t.get("function", {}).get("name")
        }
        selected_tools = await self._select_tools_for_message_async(all_tools, user_message=user_message, tool_query=tool_query)
        tools = self._prepare_tools_for_llm(selected_tools)
        logger.info(f"收集到 {len(all_tools)} 个工具函数,本次启用 {len(tools)}")
        if tools:
            tool_names = [t["function"]["name"] for t in tools]
            logger.info(f"本次启用工具: {tool_names}")

    # Names of the tools actually sent with this request.
    allowed_tool_names = {
        t.get("function", {}).get("name", "")
        for t in (tools or [])
        if isinstance(t, dict) and t.get("function", {}).get("name")
    }

    # ---- system prompt ----
    system_content = self.system_prompt

    now = datetime.now()
    weekdays = ("星期一", "星期二", "星期三", "星期四", "星期五", "星期六", "星期日")
    weekday = weekdays[now.weekday()]
    time_str = now.strftime(f"%Y年%m月%d日 %H:%M:%S {weekday}")
    system_content += f"\n\n当前时间:{time_str}"

    if nickname:
        system_content += f"\n当前对话用户的昵称是:{nickname}"
    if self._tool_rule_prompt_enabled:
        system_content += self._build_tool_rules_prompt(tools)

    # Persistent memories are keyed by group id in groups, user id otherwise.
    memory_chat_id = from_wxid if is_group else user_wxid
    if memory_chat_id:
        persistent_memories = self._get_persistent_memories(memory_chat_id)
        if persistent_memories:
            system_content += "\n\n【持久记忆】以下是用户要求你记住的重要信息:\n"
            for m in persistent_memories:
                mem_time = m['time'][:10] if m['time'] else ""
                system_content += f"- [{mem_time}] {m['nickname']}: {m['content']}\n"

    # Long-term vector memory retrieval (groups only).
    if is_group and from_wxid and self._vector_memory_enabled:
        vector_mem = await self._retrieve_vector_memories(from_wxid, user_message)
        if vector_mem:
            system_content += vector_mem

    messages = [{"role": "system", "content": system_content}]

    # ---- context ----
    if is_group and from_wxid:
        history_chat_id = self._get_group_history_chat_id(from_wxid, user_wxid or "")
        history = await self._load_history(history_chat_id)
        history = self._filter_history_by_window(history)
        max_context = self.config.get("history", {}).get("max_context", 50)
        # Only the most recent max_context entries are used.
        recent_history = history[-max_context:] if len(history) > max_context else history
        self._append_group_history_messages(messages, recent_history)
    elif chat_id:
        # Private chats use the per-chat memory; its last entry is left out
        # here (the current message is appended separately below).
        memory_messages = self._get_memory_messages(chat_id)
        if memory_messages and len(memory_messages) > 1:
            messages.extend(memory_messages[:-1])

    # ---- current user message ----
    if append_user_message:
        current_marker = "【当前消息】"
        if is_group and nickname:
            # Groups use the structured format, stamped with the current time.
            stamp = datetime.now().strftime("%Y-%m-%d %H:%M")
            formatted_content = self._format_user_message_content(nickname, user_message, stamp, "text")
            formatted_content = f"{current_marker}\n{formatted_content}"
            messages.append({"role": "user", "content": formatted_content})
        else:
            messages.append({"role": "user", "content": f"{current_marker}\n{user_message}"})

    def _pick_search_tool(names):
        """Return the preferred search tool contained in ``names``, else None."""
        for candidate in ("tavily_web_search", "web_search"):
            if candidate in names:
                return candidate
        return None

    def _synthetic_tool_call(id_prefix: str, name: str, arguments: dict) -> dict:
        """Build a tool-call dict in the shape used by real tool calls."""
        return {
            "id": f"{id_prefix}_{uuid.uuid4().hex[:8]}",
            "type": "function",
            "function": {
                "name": name,
                "arguments": json.dumps(arguments, ensure_ascii=False),
            },
        }

    async def _notify_user(text: str):
        """Best-effort interim notice; a send failure must not abort the reply."""
        try:
            if bot and from_wxid:
                await bot.send_text(from_wxid, text)
        except Exception as notify_err:
            logger.debug(f"发送提示消息失败: {notify_err}")

    def _fallback_query() -> str:
        """Search query derived from the user's message (may be empty)."""
        query = self._extract_tool_intent_text(user_message, tool_query=tool_query) or user_message
        return str(query or "").strip()

    async def _finalize_response(full_content: str, tool_calls_data: list):
        # Normalise so len()/`in` below cannot fail on None.
        full_content = full_content or ""
        tool_calls_data = tool_calls_data or []

        # Drop tool calls the model invented: only tools sent with this
        # request may be executed.
        if tool_calls_data:
            unsupported = []
            filtered = []
            for tc in tool_calls_data:
                fn = ((tc or {}).get("function") or {}).get("name", "")
                if not fn:
                    continue
                if fn not in allowed_tool_names:
                    unsupported.append(fn)
                    continue
                filtered.append(tc)
            if unsupported:
                logger.warning(f"检测到未提供/未知的工具调用,已忽略: {unsupported}")
            tool_calls_data = filtered

        # Compatibility: some models write a search tool call as plain text.
        if not tool_calls_data and full_content:
            legacy = self._extract_legacy_text_search_tool_call(full_content)
            if legacy:
                legacy_tool, legacy_args = legacy
                # 1) Prefer a tool that was offered in this request.
                # 2) Otherwise, for information-style questions only, fall
                #    back to a globally available search tool.
                if legacy_tool in allowed_tool_names:
                    preferred = legacy_tool
                else:
                    preferred = _pick_search_tool(allowed_tool_names)
                    if preferred is None and self._looks_like_info_query(user_message):
                        preferred = _pick_search_tool(available_tool_names)
                if preferred:
                    logger.warning(f"检测到文本形式工具调用,已转换为 Function Calling: {preferred}")
                    await _notify_user("我帮你查一下,稍等。")
                    tool_calls_data = [_synthetic_tool_call("legacy", preferred, legacy_args)]

        # Compatibility: image tool call written as text (JSON / python call).
        if not tool_calls_data and full_content:
            legacy_img = self._extract_legacy_text_image_tool_call(full_content)
            if legacy_img:
                legacy_tool, legacy_args = legacy_img
                tools_cfg = (self.config or {}).get("tools", {})
                loose_image_tool = tools_cfg.get("loose_image_tool", True)
                preferred = self._resolve_image_tool_alias(
                    legacy_tool,
                    allowed_tool_names,
                    available_tool_names,
                    loose_image_tool,
                )
                if preferred:
                    logger.warning(f"检测到文本绘图工具调用,已转换为 Function Calling: {preferred}")
                    tool_calls_data = [_synthetic_tool_call("legacy_img", preferred, legacy_args)]

        # Tool-call markers in the text that could not be parsed: fall back
        # to a search with a query taken from the user's message.
        if not tool_calls_data and allowed_tool_names and full_content:
            if self._contains_tool_call_markers(full_content):
                fallback_tool = _pick_search_tool(allowed_tool_names)
                if fallback_tool:
                    fallback_query = _fallback_query()
                    if fallback_query:
                        logger.warning(f"检测到文本工具调用但未解析成功,已兜底调用: {fallback_tool}")
                        await _notify_user("我帮你查一下,稍等。")
                        tool_calls_data = [
                            _synthetic_tool_call("fallback", fallback_tool, {"query": fallback_query[:400]})
                        ]

        # Lyrics-style question without any tool call: force a search.
        if not tool_calls_data and allowed_tool_names and self._looks_like_lyrics_query(user_message):
            fallback_tool = _pick_search_tool(allowed_tool_names)
            if fallback_tool:
                fallback_query = _fallback_query()
                if fallback_query:
                    logger.warning(f"歌词检索未触发工具,已兜底调用: {fallback_tool}")
                    await _notify_user("我帮你查一下这句歌词,稍等。")
                    tool_calls_data = [
                        _synthetic_tool_call("lyrics", fallback_tool, {"query": fallback_query[:400]})
                    ]

        logger.info(f"流式/非流式 API 响应完成, 内容长度: {len(full_content)}, 工具调用数: {len(tool_calls_data)}")

        if tool_calls_data:
            logger.info(f"启动工具执行,共 {len(tool_calls_data)} 个工具")
            try:
                await self._record_tool_calls_to_context(
                    tool_calls_data,
                    from_wxid=from_wxid,
                    chat_id=chat_id,
                    is_group=is_group,
                    user_wxid=user_wxid,
                )
            except Exception as e:
                logger.debug(f"记录工具调用到上下文失败: {e}")

            tool_run = self._execute_tools_async(
                tool_calls_data, bot, from_wxid, chat_id,
                user_wxid, nickname, is_group, messages
            )
            if self._tool_async:
                task = asyncio.create_task(tool_run)
                # The event loop keeps only weak references to tasks; hold a
                # strong one until completion so the task cannot be
                # garbage-collected while it is still running.
                pending = getattr(self, "_pending_tool_tasks", None)
                if pending is None:
                    pending = self._pending_tool_tasks = set()
                pending.add(task)
                task.add_done_callback(pending.discard)
            else:
                await tool_run
            # None tells the caller that tools took over; no retry needed.
            return None

        # Malformed tool-call syntax leaked into the answer: replace it with
        # a neutral notice instead of showing it to the user.
        if "<tool_code>" in full_content or re.search(
            r"(?i)\bprint\s*\(\s*(draw_image|generate_image|nano_ai_image_generation|flow2_ai_image_generation|jimeng_ai_image_generation|kiira2_ai_image_generation)\s*\(",
            full_content,
        ):
            logger.warning("检测到模型输出了错误的工具调用格式,拦截并返回提示")
            return "抱歉,我遇到了一些技术问题,请重新描述一下你的需求~"

        return self._sanitize_llm_output(full_content)

    try:
        if tools:
            logger.debug(f"已将 {len(tools)} 个工具添加到请求中")
        full_content, tool_calls_data = await self._send_dialog_api_request(
            api_config,
            messages,
            tools,
            request_tag="[对话]",
            prefer_stream=True,
            max_tokens=api_config.get("max_tokens", 4096),
        )
        return await _finalize_response(full_content, tool_calls_data)
    except Exception as e:
        logger.error(f"调用对话 API 失败: {e}")
        raise
async def _load_history(self, chat_id: str) -> list:
"""异步读取群聊历史(委托 ContextStore"""
if not self.store:
return []
return await self.store.load_group_history(chat_id)
async def _add_to_history(
self,
chat_id: str,
nickname: str,
content: str,
image_base64: str = None,
*,
role: str = "user",
sender_wxid: str = None,
):
"""将消息存入群聊历史(委托 ContextStore"""
if not self.store:
return
await self.store.add_group_message(
chat_id,
nickname,
content,
image_base64=image_base64,
role=role,
sender_wxid=sender_wxid,
)
async def _add_to_history_with_id(
self,
chat_id: str,
nickname: str,
content: str,
record_id: str,
*,
role: str = "user",
sender_wxid: str = None,
):
"""带ID的历史追加, 便于后续更新(委托 ContextStore"""
if not self.store:
return
await self.store.add_group_message(
chat_id,
nickname,
content,
record_id=record_id,
role=role,
sender_wxid=sender_wxid,
)
async def _update_history_by_id(self, chat_id: str, record_id: str, new_content: str):
"""根据ID更新历史记录委托 ContextStore"""
if not self.store:
return
await self.store.update_group_message_by_id(chat_id, record_id, new_content)
def _prepare_tool_calls_for_executor(
self,
tool_calls_data: list,
messages: list,
*,
user_wxid: str,
from_wxid: str,
is_group: bool,
image_base64: str | None = None,
) -> list:
prepared = []
if not tool_calls_data:
return prepared
for tool_call in tool_calls_data:
function = (tool_call or {}).get("function") or {}
function_name = function.get("name", "")
if not function_name:
continue
tool_call_id = (tool_call or {}).get("id", "")
if not tool_call_id:
tool_call_id = f"call_{uuid.uuid4().hex[:8]}"
tool_call["id"] = tool_call_id
raw_arguments = function.get("arguments", "{}")
try:
arguments = json.loads(raw_arguments) if raw_arguments else {}
if not isinstance(arguments, dict):
arguments = {}
except Exception:
arguments = {}
if "function" not in tool_call:
tool_call["function"] = {}
tool_call["function"]["arguments"] = "{}"
if function_name in ("tavily_web_search", "web_search"):
raw_query = arguments.get("query", "")
cleaned_query = self._normalize_search_query(raw_query)
if cleaned_query:
arguments["query"] = cleaned_query[:400]
if "function" not in tool_call:
tool_call["function"] = {}
tool_call["function"]["arguments"] = json.dumps(arguments, ensure_ascii=False)
elif not arguments.get("query"):
fallback_query = self._extract_tool_intent_text(self._extract_last_user_text(messages))
fallback_query = str(fallback_query or "").strip()
if fallback_query:
arguments["query"] = fallback_query[:400]
if "function" not in tool_call:
tool_call["function"] = {}
tool_call["function"]["arguments"] = json.dumps(arguments, ensure_ascii=False)
exec_args = dict(arguments)
exec_args["user_wxid"] = user_wxid or from_wxid
exec_args["is_group"] = bool(is_group)
if image_base64 and function_name in ("flow2_ai_image_generation", "nano_ai_image_generation", "grok_video_generation"):
exec_args["image_base64"] = image_base64
logger.info("[异步-图片] 图生图工具,已添加图片数据")
prepared.append({
"id": tool_call_id,
"type": "function",
"function": {
"name": function_name,
"arguments": json.dumps(exec_args, ensure_ascii=False),
},
})
return prepared
async def _execute_tools_async(self, tool_calls_data: list, bot, from_wxid: str,
                               chat_id: str, user_wxid: str, nickname: str, is_group: bool,
                               messages: list):
    """Execute the model's tool calls and deliver their results.

    The calls are prepared with ``_prepare_tool_calls_for_executor`` and run
    through a ``ToolExecutor`` batch. Each result is then either

    * queued for a follow-up AI turn (``_continue_with_tool_results``), or
    * sent to the chat as text (success output or a failure notice),

    and is added to the chat memory when the result asks for it.

    Args:
        tool_calls_data: Tool calls returned by the model.
        bot: Client used for sending messages and passed to the executor.
        from_wxid: Chat that receives the output.
        chat_id: Key of the per-chat memory.
        user_wxid: wxid of the requesting user.
        nickname: Display label of the user.
        is_group: Whether the chat is a group.
        messages: Conversation messages sent to the model so far.
    """
    try:
        logger.info(f"开始异步执行 {len(tool_calls_data)} 个工具调用")

        tools_cfg = (self.config or {}).get("tools", {})
        max_concurrent = tools_cfg.get("concurrency", {}).get("max_concurrent", 5)
        parallel_tools = True
        if self._serial_reply:
            # Serial-reply mode: run the tools one at a time.
            max_concurrent = 1
            parallel_tools = False
        default_timeout = tools_cfg.get("timeout", {}).get("default", 60)
        executor = ToolExecutor(default_timeout=default_timeout, max_parallel=max_concurrent)

        prepared_tool_calls = self._prepare_tool_calls_for_executor(
            tool_calls_data,
            messages,
            user_wxid=user_wxid,
            from_wxid=from_wxid,
            is_group=is_group,
        )
        if not prepared_tool_calls:
            logger.info("[异步] 没有可执行的工具调用")
            return

        logger.info(f"[异步] 开始执行 {len(prepared_tool_calls)} 个工具 (最大并发: {max_concurrent})")
        results = await executor.execute_batch(prepared_tool_calls, bot, from_wxid, parallel=parallel_tools)

        followup_results = []
        for result in results:
            function_name = result.name
            tool_message = self._sanitize_llm_output(result.message or "")
            followup_entry = {
                "tool_call_id": result.id,
                "function_name": function_name,
                "result": tool_message,
                "success": result.success,
            }

            if result.success:
                logger.success(f"[异步] 工具 {function_name} 执行成功")
            else:
                logger.warning(f"[异步] 工具 {function_name} 执行失败: {result.error or result.message}")

            if self._tool_followup_ai_reply:
                # Follow-up mode: hand the result back to the AI unless the
                # tool already replied or asked for no reply.
                should_followup = result.need_ai_reply or ((not result.no_reply) and (not result.already_sent))
                logger.info(f"[异步] 工具 {function_name}: need_ai_reply={result.need_ai_reply}, already_sent={result.already_sent}, no_reply={result.no_reply}, should_followup={should_followup}")
                if should_followup:
                    followup_results.append(followup_entry)
                    continue

            logger.info(f"[异步] 工具 {function_name} 结果: need_ai_reply={result.need_ai_reply}, success={result.success}")
            if result.need_ai_reply:
                logger.info(f"[异步] 工具 {function_name} 需要 AI 回复,加入 followup_results")
                followup_results.append(followup_entry)
                continue

            wants_reply = not result.no_reply

            # Successful output that the tool did not send itself.
            if result.success and wants_reply and not result.already_sent and result.send_result_text:
                if tool_message:
                    await bot.send_text(from_wxid, tool_message)
                elif result.message:
                    # There was output, but sanitizing removed all of it.
                    logger.warning(f"[异步] 工具 {function_name} 输出清洗后为空,已跳过发送")

            # Failure notice; best-effort so one failed send cannot stop the loop.
            if not result.success and wants_reply:
                failure_text = f"❌ {tool_message}" if tool_message else f"❌ {function_name} 执行失败"
                try:
                    await bot.send_text(from_wxid, failure_text)
                except Exception as send_err:
                    logger.debug(f"[异步] 发送失败提示出错: {send_err}")

            if result.save_to_memory and chat_id and tool_message:
                self._add_to_memory(chat_id, "assistant", f"[工具 {function_name} 结果]: {tool_message}")

        if followup_results:
            await self._continue_with_tool_results(
                followup_results, bot, from_wxid, user_wxid, chat_id,
                nickname, is_group, messages, tool_calls_data
            )
        logger.info(f"[异步] 所有工具执行完成")
    except Exception as e:
        logger.error(f"[异步] 工具执行总体异常: {e}")
        import traceback
        logger.error(f"详细错误: {traceback.format_exc()}")
        try:
            await bot.send_text(from_wxid, "❌ 工具执行过程中出现错误")
        except Exception as send_err:
            # `except Exception` (not bare) so task cancellation still propagates.
            logger.debug(f"[异步] 发送错误提示失败: {send_err}")
async def _continue_with_tool_results(self, tool_results: list, bot, from_wxid: str,
                                      user_wxid: str, chat_id: str, nickname: str, is_group: bool,
                                      messages: list, tool_calls_data: list):
    """Feed tool results back to the model and send its follow-up answer.

    ``messages`` is extended in place with the assistant tool-call message,
    one ``tool`` message per result and, where applicable, extra system
    instructions for search results and for failed tools. The model is then
    called again; by default no tools are offered, except ``search_music``
    when ``_should_allow_music_followup`` allows it.

    Args:
        tool_results: Dicts with ``tool_call_id``, ``function_name``,
            ``result`` and ``success``.
        bot: Client used for sending messages.
        from_wxid: Chat that receives the answer.
        user_wxid: wxid of the requesting user; ``from_wxid`` when empty.
        chat_id: Key of the per-chat memory.
        nickname: Display label of the user.
        is_group: Whether the chat is a group.
        messages: Conversation messages; mutated in place.
        tool_calls_data: The tool calls that produced ``tool_results``.
    """
    try:
        logger.info(f"[工具回传] 开始基于 {len(tool_results)} 个工具结果继续对话")

        # 1. Assistant message listing only the tool calls that have a result here.
        result_ids = {tr["tool_call_id"] for tr in tool_results}
        tool_calls_msg = [
            {
                "id": tool_call.get("id", ""),
                "type": "function",
                "function": {
                    "name": tool_call.get("function", {}).get("name", ""),
                    "arguments": tool_call.get("function", {}).get("arguments", "{}"),
                },
            }
            for tool_call in tool_calls_data
            if tool_call.get("id", "") in result_ids
        ]
        if tool_calls_msg:
            messages.append({
                "role": "assistant",
                "content": None,
                "tool_calls": tool_calls_msg
            })

        # 2. One tool message per result; remember which tools failed.
        failed_items = []
        for tr in tool_results:
            if not bool(tr.get("success", True)):
                failed_items.append(tr.get("function_name", "工具"))
            messages.append({
                "role": "tool",
                "tool_call_id": tr["tool_call_id"],
                "content": tr["result"]
            })
        # Separate the names so several failures stay readable.
        failed_list = "、".join(str(x) for x in failed_items if x)

        # Search results: instruct the model to answer the original question
        # completely before anything else.
        search_tool_names = {"tavily_web_search", "web_search"}
        has_search_tool = any(str(tr.get("function_name", "")) in search_tool_names for tr in tool_results)
        if has_search_tool:
            latest_user_text = self._extract_last_user_text(messages)
            messages.append({
                "role": "system",
                "content": (
                    "你将基于联网搜索工具结果回答用户。"
                    "必须先完整回答用户原问题,覆盖所有子问题与关键细节,"
                    "并给出清晰要点与必要来源依据;"
                    "禁止只给寒暄/反问/引导句,禁止把问题再抛回用户。"
                    "若原问题包含多个子问题例如A和B必须逐项作答不得漏项。"
                    "**严禁输出任何 JSON 格式、函数调用格式或工具调用格式的内容。**"
                    "**只输出自然语言文本回复。**"
                    "用户原问题如下:" + str(latest_user_text or "")
                )
            })

        if failed_items:
            messages.append({
                "role": "system",
                "content": (
                    "你将基于工具返回结果向用户回复。"
                    "本轮部分工具执行失败(" + failed_list + ")。"
                    "请直接给出简洁、自然、可执行的中文总结:"
                    "先说明已获取到的有效结果,再明确失败项与可能原因,"
                    "最后给出下一步建议(如更换关键词/稍后重试/补充信息)。"
                    "不要输出 JSON、代码块或函数调用片段。"
                )
            })

        # 3. Call the model again. No tools by default; the lyrics-to-song
        #    case may offer search_music.
        api_config = self.config["api"]
        user_wxid = user_wxid or from_wxid
        followup_tools = None
        if self._should_allow_music_followup(messages, tool_calls_data):
            followup_tools = [
                t for t in (self._collect_tools() or [])
                if (t.get("function", {}).get("name") == "search_music")
            ] or None

        try:
            full_content, next_tool_calls = await self._send_dialog_api_request(
                api_config,
                messages,
                followup_tools,
                request_tag="[工具回传]",
                prefer_stream=True,
                max_tokens=api_config.get("max_tokens", 4096),
            )
        except Exception as req_err:
            logger.error(f"[工具回传] AI API 调用失败: {req_err}")
            await bot.send_text(from_wxid, "❌ AI 处理工具结果失败")
            return

        # Only tools offered in this follow-up request may run; with no tools
        # offered, every returned tool call is dropped.
        if next_tool_calls:
            allowed_tool_names = {
                t.get("function", {}).get("name", "")
                for t in (followup_tools or [])
                if isinstance(t, dict) and t.get("function", {}).get("name")
            }
            filtered = []
            for tc in next_tool_calls:
                fn = ((tc or {}).get("function") or {}).get("name", "")
                if fn and fn in allowed_tool_names:
                    filtered.append(tc)
            if len(filtered) != len(next_tool_calls):
                logger.warning("[工具回传] 已忽略未提供的工具调用")
            next_tool_calls = filtered

        if next_tool_calls:
            await self._execute_tools_async(
                next_tool_calls, bot, from_wxid, chat_id,
                user_wxid, nickname, is_group, messages
            )
            return

        # 4. Send the model's answer, or a fallback text when it is empty.
        full_content = full_content or ""
        if full_content.strip():
            cleaned_content = self._sanitize_llm_output(full_content)
            if cleaned_content:
                await bot.send_text(from_wxid, cleaned_content)
                await self._maybe_send_voice_reply(bot, from_wxid, cleaned_content)
                logger.success(f"[工具回传] AI 回复完成,长度: {len(cleaned_content)}")
                # Keep the answer in the per-chat memory.
                if chat_id:
                    self._add_to_memory(chat_id, "assistant", cleaned_content)
            else:
                logger.warning("[工具回传] AI 回复清洗后为空,已跳过发送")
        else:
            logger.warning("[工具回传] AI 返回空内容")
            if failed_items:
                fallback_text = f"工具执行已完成,但部分步骤失败({failed_list})。请稍后重试,或换个更具体的问题我再帮你处理。"
            else:
                fallback_text = "工具执行已完成,但这次没生成可读回复。你可以让我基于结果再总结一次。"
            await bot.send_text(from_wxid, fallback_text)
    except Exception as e:
        logger.error(f"[工具回传] 继续对话失败: {e}")
        import traceback
        logger.error(f"详细错误: {traceback.format_exc()}")
        try:
            await bot.send_text(from_wxid, "❌ 处理工具结果时出错")
        except Exception as send_err:
            # `except Exception` (not bare) so task cancellation still propagates.
            logger.debug(f"[工具回传] 发送错误提示失败: {send_err}")
async def _execute_tools_async_with_image(self, tool_calls_data: list, bot, from_wxid: str,
                                          chat_id: str, user_wxid: str, nickname: str, is_group: bool,
                                          messages: list, image_base64: str):
    """
    Execute tool calls asynchronously while passing an image along (e.g. image-to-image).

    The AI has already replied to the user; the tools run afterwards and their
    results are either sent directly, or collected and handed to
    `_continue_with_tool_results` so the model can phrase a follow-up answer.

    Args:
        tool_calls_data: Raw tool calls returned by the model.
        bot: Client used to send messages (`send_text`).
        from_wxid: Conversation the results are sent to.
        chat_id: Memory key; tool results are stored under it when requested.
        user_wxid: Sender id, forwarded to tool preparation and the follow-up.
        nickname: Sender display name, forwarded to the follow-up.
        is_group: Whether the conversation is a group chat.
        messages: Conversation messages, forwarded to tool preparation and the follow-up.
        image_base64: Image payload handed to `_prepare_tool_calls_for_executor`.

    Returns:
        None. All failures are logged and reported to the user on a best-effort basis.
    """
    try:
        logger.info(f"[异步-图片] 开始执行 {len(tool_calls_data)} 个工具调用")

        tools_config = (self.config or {}).get("tools", {})
        max_concurrent = tools_config.get("concurrency", {}).get("max_concurrent", 5)
        parallel_tools = True
        # Serial-reply mode forces one tool at a time.
        if self._serial_reply:
            max_concurrent = 1
            parallel_tools = False
        default_timeout = tools_config.get("timeout", {}).get("default", 60)
        executor = ToolExecutor(default_timeout=default_timeout, max_parallel=max_concurrent)

        prepared_tool_calls = self._prepare_tool_calls_for_executor(
            tool_calls_data,
            messages,
            user_wxid=user_wxid,
            from_wxid=from_wxid,
            is_group=is_group,
            image_base64=image_base64,
        )
        if not prepared_tool_calls:
            logger.info("[异步-图片] 没有可执行的工具调用")
            return

        logger.info(f"[异步-图片] 开始执行 {len(prepared_tool_calls)} 个工具 (最大并发: {max_concurrent})")
        results = await executor.execute_batch(prepared_tool_calls, bot, from_wxid, parallel=parallel_tools)

        followup_results = []
        for result in results:
            function_name = result.name
            tool_call_id = result.id
            tool_message = self._sanitize_llm_output(result.message or "")
            followup_entry = {
                "tool_call_id": tool_call_id,
                "function_name": function_name,
                "result": tool_message,
                "success": result.success,
            }

            if result.success:
                logger.success(f"[异步-图片] 工具 {function_name} 执行成功")
            else:
                logger.warning(f"[异步-图片] 工具 {function_name} 执行失败: {result.error or result.message}")

            # Follow-up mode: let the model phrase the answer unless the tool
            # already replied or explicitly asked for silence.
            if self._tool_followup_ai_reply:
                should_followup = result.need_ai_reply or ((not result.no_reply) and (not result.already_sent))
                logger.info(f"[异步] 工具 {function_name}: need_ai_reply={result.need_ai_reply}, already_sent={result.already_sent}, no_reply={result.no_reply}, should_followup={should_followup}")
                if should_followup:
                    followup_results.append(followup_entry)
                    continue

            logger.info(f"[异步] 工具 {function_name} 结果: need_ai_reply={result.need_ai_reply}, success={result.success}")
            if result.need_ai_reply:
                logger.info(f"[异步] 工具 {function_name} 需要 AI 回复,加入 followup_results")
                followup_results.append(followup_entry)
                continue

            # Direct delivery of a successful result. The warning branch used to be
            # unreachable because the outer condition already required a non-empty
            # sanitized message; it now fires when sanitizing emptied a real message.
            if result.success and not result.already_sent and not result.no_reply and result.send_result_text:
                if tool_message:
                    await bot.send_text(from_wxid, tool_message)
                elif result.message:
                    logger.warning(f"[异步-图片] 工具 {function_name} 输出清洗后为空,已跳过发送")

            # Best-effort failure notice: a send error must not abort the remaining results.
            if not result.success and not result.no_reply:
                try:
                    if tool_message:
                        await bot.send_text(from_wxid, f"? {tool_message}")
                    else:
                        await bot.send_text(from_wxid, f"? {function_name} 执行失败")
                except Exception as send_err:
                    logger.warning(f"[异步-图片] 发送失败提示时出错: {send_err}")

            if result.save_to_memory and chat_id and tool_message:
                self._add_to_memory(chat_id, "assistant", f"[工具 {function_name} 结果]: {tool_message}")

        if followup_results:
            await self._continue_with_tool_results(
                followup_results, bot, from_wxid, user_wxid, chat_id,
                nickname, is_group, messages, tool_calls_data
            )
        logger.info("[异步-图片] 所有工具执行完成")
    except Exception as e:
        logger.error(f"[异步-图片] 工具执行总体异常: {e}")
        import traceback
        logger.error(f"详细错误: {traceback.format_exc()}")
        try:
            await bot.send_text(from_wxid, "? 工具执行过程中出现错误")
        except Exception as send_err:
            # `except Exception` (not a bare `except`) so task cancellation and
            # KeyboardInterrupt are not swallowed here.
            logger.warning(f"[异步-图片] 发送错误提示失败: {send_err}")
@on_quote_message(priority=79)
async def handle_quote_message(self, bot, message: dict):
    """
    Handle a quote (reply-to) message.

    Covers, in order: the `/记录` persistent-memory command, quoted plain text,
    quoted chat-record bundles, quoted videos and quoted images.

    Args:
        bot: Client used to send replies (`send_text` is called directly here).
        message: Incoming message dict; this method reads `Content`, `FromWxid`,
            `SenderWxid` and `IsGroup`.

    Returns:
        bool: True when the message was not handled here (no title, reply
        conditions not met, unsupported quoted type, parse failure, or any
        exception); False once the message has been handled or answered.
        NOTE(review): presumably False stops propagation to later handlers —
        confirm against the `on_quote_message` dispatcher.
    """
    content = message.get("Content", "").strip()
    from_wxid = message.get("FromWxid", "")
    sender_wxid = message.get("SenderWxid", "")
    is_group = message.get("IsGroup", False)
    user_wxid = sender_wxid if is_group else from_wxid
    try:
        # Group quote messages may carry a "wxid:\n" prefix that must be removed
        xml_content = content
        if is_group and ":\n" in content:
            # Locate the XML declaration, or failing that the <msg> tag
            xml_start = content.find("<?xml")
            if xml_start == -1:
                xml_start = content.find("<msg")
            if xml_start > 0:
                xml_content = content[xml_start:]
                logger.debug(f"去除引用消息前缀,原长度: {len(content)}, 新长度: {len(xml_content)}")
        # Parse the XML to get the title (the user's own text) and the quoted message
        root = ET.fromstring(xml_content)
        title = root.find(".//title")
        if title is None or not title.text:
            logger.debug("引用消息没有标题,跳过")
            return True
        title_text = title.text.strip()
        logger.info(f"收到引用消息,标题: {title_text[:50]}...")
        # Check for the /记录 command (store the quoted message as a persistent memory)
        if title_text == "/记录" or title_text.startswith("/记录 "):
            # Fetch the quoted message
            refermsg = root.find(".//refermsg")
            if refermsg is not None:
                # Display name of the quoted message's sender
                refer_displayname = refermsg.find("displayname")
                refer_nickname = refer_displayname.text if refer_displayname is not None and refer_displayname.text else "未知"
                # Content of the quoted message
                refer_content_elem = refermsg.find("content")
                if refer_content_elem is not None and refer_content_elem.text:
                    refer_text = refer_content_elem.text.strip()
                    # XML-looking content (e.g. an image) is replaced by a placeholder
                    if refer_text.startswith("<?xml") or refer_text.startswith("<"):
                        refer_text = f"[多媒体消息]"
                else:
                    refer_text = "[空消息]"
                # Memory text: what the quoted person said
                memory_content = f"{refer_nickname}: {refer_text}"
                # Anything after "/记录 " is appended as an extra note
                if title_text.startswith("/记录 "):
                    extra_note = title_text[4:].strip()
                    if extra_note:
                        memory_content += f" (备注: {extra_note})"
                # Save to persistent memory
                nickname = await self._get_user_display_label(bot, from_wxid, user_wxid, is_group)
                memory_chat_id = from_wxid if is_group else user_wxid
                chat_type = "group" if is_group else "private"
                memory_id = self._add_persistent_memory(
                    memory_chat_id, chat_type, user_wxid, nickname, memory_content
                )
                await bot.send_text(from_wxid, f"✅ 已记录到持久记忆 (ID: {memory_id})\n📝 {memory_content[:50]}...")
                logger.info(f"通过引用添加持久记忆: {memory_chat_id} - {memory_content[:30]}...")
            else:
                await bot.send_text(from_wxid, "❌ 无法获取被引用的消息")
            return False
        # Check whether this quote message should be answered at all
        if not self._should_reply_quote(message, title_text):
            logger.debug("引用消息不满足回复条件")
            return True
        # Get the quoted message node
        refermsg = root.find(".//refermsg")
        if refermsg is None:
            logger.debug("引用消息中没有 refermsg 节点")
            return True
        refer_content = refermsg.find("content")
        if refer_content is None or not refer_content.text:
            logger.debug("引用消息中没有 content")
            return True
        # Type of the quoted message
        # type=1: plain text, type=3: image, type=43: video, type=49: app message (incl. chat records)
        refer_type_elem = refermsg.find("type")
        refer_type = int(refer_type_elem.text) if refer_type_elem is not None and refer_type_elem.text else 0
        logger.debug(f"被引用消息类型: {refer_type}")
        # Plain text (type=1): forward quoted text plus the user's comment to the AI
        if refer_type == 1:
            if self._should_reply_quote(message, title_text):
                # Quoted text
                refer_content_elem = refermsg.find("content")
                refer_text = refer_content_elem.text.strip() if refer_content_elem is not None and refer_content_elem.text else ""
                # Display name of the quoted sender
                refer_displayname = refermsg.find("displayname")
                refer_nickname = refer_displayname.text if refer_displayname is not None and refer_displayname.text else "某人"
                # Combine quoted content + user comment
                # title_text looks like "@瑞依 评价下"; the @nickname part has to be removed
                import tomllib
                with open("main_config.toml", "rb") as f:
                    main_config = tomllib.load(f)
                bot_nickname = main_config.get("Bot", {}).get("nickname", "")
                user_comment = title_text
                if bot_nickname:
                    # Remove "@<bot nickname>" (possibly followed by whitespace)
                    user_comment = user_comment.replace(f"@{bot_nickname}", "").strip()
                # Message handed to the AI
                combined_message = f"[引用 {refer_nickname} 的消息:{refer_text}]\n{user_comment}"
                logger.info(f"引用纯文本消息,转发给 AI: {combined_message[:80]}...")
                # Call the AI
                nickname = await self._get_user_display_label(bot, from_wxid, user_wxid, is_group)
                chat_id = from_wxid if is_group else user_wxid
                # Save the user's message to the group history
                history_enabled = bool(self.store) and self.config.get("history", {}).get("enabled", True)
                sync_bot_messages = self.config.get("history", {}).get("sync_bot_messages", True)
                if is_group and history_enabled:
                    history_chat_id = self._get_group_history_chat_id(from_wxid, user_wxid)
                    await self._add_to_history(history_chat_id, nickname, combined_message, sender_wxid=user_wxid)
                async with self._reply_lock_context(chat_id):
                    ai_response = await self._call_ai_api(
                        combined_message,
                        bot=bot,
                        from_wxid=from_wxid,
                        chat_id=chat_id,
                        nickname=nickname
                    )
                    if ai_response:
                        final_response = self._sanitize_llm_output(ai_response)
                        await bot.send_text(from_wxid, final_response)
                        # Save the AI reply to the group history
                        # (history_chat_id is bound above under the same is_group/history_enabled condition)
                        if is_group and history_enabled and sync_bot_messages:
                            bot_nickname_display = main_config.get("Bot", {}).get("nickname", "AI")
                            await self._add_to_history(history_chat_id, bot_nickname_display, final_response, role="assistant")
                return False
            else:
                logger.debug("引用的是纯文本消息且未@机器人,跳过")
                return True
        # Only images (3), videos (43) and app messages (49, incl. chat records) are handled below
        if refer_type not in [3, 43, 49]:
            logger.debug(f"引用的消息类型 {refer_type} 不支持处理")
            return True
        # Decode HTML entities
        import html
        refer_xml = html.unescape(refer_content.text)
        # The quoted content may also carry a "wxid:\n" prefix that must be removed
        if ":\n" in refer_xml:
            xml_start = refer_xml.find("<?xml")
            if xml_start == -1:
                xml_start = refer_xml.find("<msg")
            if xml_start > 0:
                refer_xml = refer_xml[xml_start:]
                logger.debug(f"去除被引用消息前缀")
        # Try to parse the quoted XML
        try:
            refer_root = ET.fromstring(refer_xml)
        except ET.ParseError as e:
            logger.debug(f"被引用消息内容不是有效的 XML: {e}")
            return True
        # Try to extract chat-record info (type=19)
        recorditem = refer_root.find(".//recorditem")
        # Try to extract image info
        img = refer_root.find(".//img")
        # Try to extract video info
        video = refer_root.find(".//videomsg")
        if img is None and video is None and recorditem is None:
            logger.debug("引用的消息不是图片、视频或聊天记录")
            return True
        # Re-check the reply condition early, before any download work
        if not self._should_reply_quote(message, title_text):
            logger.debug("引用消息不满足回复条件")
            return True
        # Rate limiting
        allowed, remaining, reset_time = self._check_rate_limit(user_wxid)
        if not allowed:
            rate_limit_config = self.config.get("rate_limit", {})
            msg = rate_limit_config.get("rate_limit_message", "⚠️ 消息太频繁了,请 {seconds} 秒后再试~")
            msg = msg.format(seconds=reset_time)
            await bot.send_text(from_wxid, msg)
            logger.warning(f"用户 {user_wxid} 触发限流,{reset_time}秒后重置")
            return False
        # Resolve the user's display label
        nickname = await self._get_user_display_label(bot, from_wxid, user_wxid, is_group)
        chat_id = self._get_chat_id(from_wxid, user_wxid, is_group)
        # Chat-record bundle (type=19)
        if recorditem is not None:
            return await self._handle_quote_chat_record(
                bot, recorditem, title_text, from_wxid, user_wxid,
                is_group, nickname, chat_id
            )
        # Video message
        if video is not None:
            # svrid (message id) is needed to download the video
            svrid_elem = refermsg.find("svrid")
            svrid = int(svrid_elem.text) if svrid_elem is not None and svrid_elem.text else 0
            return await self._handle_quote_video(
                bot, video, title_text, from_wxid, user_wxid,
                is_group, nickname, chat_id, svrid
            )
        # Image message
        # svrid is used as the cache key (kept as a string here)
        svrid_elem = refermsg.find("svrid")
        svrid = svrid_elem.text if svrid_elem is not None and svrid_elem.text else ""
        logger.info(f"AI处理引用图片消息: {title_text[:50]}...")
        # 1. Look the image up in the Redis cache by svrid
        image_base64 = ""
        if svrid:
            try:
                from utils.redis_cache import get_cache
                redis_cache = get_cache()
                if redis_cache and redis_cache.enabled:
                    media_key = f"image:svrid:{svrid}"
                    cached_data = redis_cache.get_cached_media(media_key, "image")
                    if cached_data:
                        logger.info(f"从缓存获取引用图片成功: {media_key}")
                        image_base64 = cached_data
            except Exception as e:
                logger.debug(f"从缓存获取图片失败: {e}")
        # 2. Cache miss: tell the user; there is no download fallback here
        if not image_base64:
            logger.warning(f"引用图片缓存未命中: svrid={svrid}")
            await bot.send_text(from_wxid, "❌ 图片缓存已过期,请重新发送图片后再引用")
            return False
        logger.info("图片获取成功")
        # Add the message (with the image base64) to memory
        self._add_to_memory(chat_id, "user", title_text, image_base64=image_base64)
        # Save the user's quoted-image message to the group history
        if is_group and self._should_capture_group_history(is_triggered=True):
            history_chat_id = self._get_group_history_chat_id(from_wxid, user_wxid)
            await self._add_to_history(
                history_chat_id,
                nickname,
                title_text,
                image_base64=image_base64,
                sender_wxid=user_wxid,
            )
        # Call the AI API with the image; the user message is appended by the
        # callee only when it was not already captured into the group history
        history_enabled = bool(self.store) and self.config.get("history", {}).get("enabled", True)
        captured_to_history = bool(is_group and history_enabled and self._should_capture_group_history(is_triggered=True))
        append_user_message = not captured_to_history
        async with self._reply_lock_context(chat_id):
            response = await self._call_ai_api_with_image(
                title_text,
                image_base64,
                bot,
                from_wxid,
                chat_id,
                nickname,
                user_wxid,
                is_group,
                append_user_message=append_user_message,
                tool_query=title_text,
            )
            if response:
                cleaned_response = self._sanitize_llm_output(response)
                if cleaned_response:
                    await bot.send_text(from_wxid, cleaned_response)
                    await self._maybe_send_voice_reply(bot, from_wxid, cleaned_response)
                    self._add_to_memory(chat_id, "assistant", cleaned_response)
                    # Save the bot reply to the history unless sync_bot_messages is on
                    # and the history scope is not per-user
                    history_config = self.config.get("history", {})
                    sync_bot_messages = history_config.get("sync_bot_messages", False)
                    history_scope = str(history_config.get("scope", "chatroom") or "chatroom").strip().lower()
                    can_rely_on_hook = bool(sync_bot_messages and history_scope not in ("per_user", "user", "peruser"))
                    if is_group and not can_rely_on_hook:
                        import tomllib
                        with open("main_config.toml", "rb") as f:
                            main_config = tomllib.load(f)
                        bot_nickname = main_config.get("Bot", {}).get("nickname", "机器人")
                        history_chat_id = self._get_group_history_chat_id(from_wxid, user_wxid)
                        await self._add_to_history(
                            history_chat_id,
                            bot_nickname,
                            cleaned_response,
                            role="assistant",
                            sender_wxid=user_wxid,
                        )
                    logger.success(f"AI回复成功: {cleaned_response[:50]}...")
                else:
                    logger.warning("AI 回复清洗后为空,已跳过发送")
        return False
    except Exception as e:
        logger.error(f"处理引用消息失败: {e}")
        return True
async def _handle_quote_chat_record(self, bot, recorditem_elem, title_text: str, from_wxid: str,
                                    user_wxid: str, is_group: bool, nickname: str, chat_id: str):
    """
    Answer a question about a quoted chat-record bundle (type=19).

    The `recorditem` element's text is parsed as XML; each `dataitem` under
    `datalist` becomes one record, with `desc` used as a summary fallback.
    The records are rendered to text, combined with the user's question and
    sent to the main AI via `_call_ai_api`.

    Args:
        bot: Client used to send replies.
        recorditem_elem: The `recorditem` XML element of the quoted message.
        title_text: Text the user typed with the quote (may start with an @mention).
        from_wxid: Conversation to reply to.
        user_wxid: Sender id.
        is_group: Whether the conversation is a group chat.
        nickname: Sender display name.
        chat_id: Memory key for this conversation.

    Returns:
        bool: Always False.
    """
    default_question = "请分析这段聊天记录"
    summary_sender = "聊天记录摘要"

    def _text_or(node, fallback):
        # Element text when the element exists and is non-empty, else the fallback.
        return node.text if node is not None and node.text else fallback

    try:
        logger.info(f"[聊天记录] 处理引用的聊天记录: {title_text[:50]}...")

        # The recorditem payload is CDATA holding the recordinfo XML.
        raw_record = recorditem_elem.text
        if not raw_record:
            logger.warning("[聊天记录] recorditem 内容为空")
            await bot.send_text(from_wxid, "❌ 无法读取聊天记录内容")
            return False

        try:
            record_root = ET.fromstring(raw_record)
        except ET.ParseError as parse_err:
            logger.error(f"[聊天记录] 解析 recordinfo 失败: {parse_err}")
            await bot.send_text(from_wxid, "❌ 聊天记录格式解析失败")
            return False

        # Full messages from datalist, skipping items without a description.
        chat_records = []
        datalist = record_root.find(".//datalist")
        if datalist is not None:
            for dataitem in datalist.findall("dataitem"):
                body = _text_or(dataitem.find("datadesc"), "")
                if not body:
                    continue
                chat_records.append({
                    "sender": _text_or(dataitem.find("sourcename"), "未知"),
                    "time": _text_or(dataitem.find("sourcetime"), ""),
                    "content": body,
                })

        # Simplified quotes have no datalist entries; fall back to the desc summary.
        if not chat_records:
            desc_elem = record_root.find(".//desc")
            if desc_elem is not None and desc_elem.text:
                desc_text = desc_elem.text.strip()
                logger.info(f"[聊天记录] 从 desc 获取摘要内容: {desc_text[:100]}...")
                chat_records.append({
                    "sender": summary_sender,
                    "time": "",
                    "content": desc_text,
                })

        if not chat_records:
            logger.warning("[聊天记录] 没有解析到任何消息")
            await bot.send_text(from_wxid, "❌ 聊天记录中没有消息内容")
            return False
        logger.info(f"[聊天记录] 解析到 {len(chat_records)} 条消息")

        # Render the records as text, headed by the bundle title.
        heading = _text_or(record_root.find(".//title"), "聊天记录")
        segments = [f"{heading}\n\n"]
        for record in chat_records:
            if record["sender"] == summary_sender:
                # Summary mode: show the content as-is.
                segments.append(f"{record['content']}\n\n")
                continue
            time_part = f" ({record['time']})" if record["time"] else ""
            segments.append(f"[{record['sender']}{time_part}]:\n{record['content']}\n\n")
        chat_text = "".join(segments)

        # The user's question, with a leading @mention removed.
        user_question = title_text.strip() or default_question
        if user_question.startswith("@"):
            pieces = user_question.split(maxsplit=1)
            user_question = pieces[1].strip() if len(pieces) > 1 else default_question

        combined_message = f"[用户发送了一段聊天记录,请阅读并回答问题]\n\n{chat_text}\n[用户的问题]: {user_question}"
        logger.info(f"[聊天记录] 发送给 AI消息长度: {len(combined_message)}")

        self._add_to_memory(chat_id, "user", combined_message)

        # Group chats also record a short marker in the shared history.
        if is_group and self._should_capture_group_history(is_triggered=True):
            await self._add_to_history(
                self._get_group_history_chat_id(from_wxid, user_wxid),
                nickname,
                f"[发送了聊天记录] {user_question}",
                sender_wxid=user_wxid,
            )

        async with self._reply_lock_context(chat_id):
            response = await self._call_ai_api(
                combined_message,
                bot,
                from_wxid,
                chat_id,
                nickname,
                user_wxid,
                is_group,
                tool_query=user_question,
            )
            if not response:
                await bot.send_text(from_wxid, "❌ AI 回复生成失败")
            else:
                cleaned_response = self._sanitize_llm_output(response)
                if not cleaned_response:
                    logger.warning("[聊天记录] AI 回复清洗后为空,已跳过发送")
                else:
                    await bot.send_text(from_wxid, cleaned_response)
                    await self._maybe_send_voice_reply(bot, from_wxid, cleaned_response)
                    self._add_to_memory(chat_id, "assistant", cleaned_response)

                    # Write the bot reply to history unless sync_bot_messages is on
                    # and the history scope is not per-user.
                    history_cfg = self.config.get("history", {})
                    sync_enabled = history_cfg.get("sync_bot_messages", False)
                    scope = str(history_cfg.get("scope", "chatroom") or "chatroom").strip().lower()
                    hook_covers_reply = bool(sync_enabled and scope not in ("per_user", "user", "peruser"))
                    if is_group and not hook_covers_reply:
                        import tomllib
                        with open("main_config.toml", "rb") as cfg_file:
                            main_config = tomllib.load(cfg_file)
                        bot_nickname = main_config.get("Bot", {}).get("nickname", "机器人")
                        await self._add_to_history(
                            self._get_group_history_chat_id(from_wxid, user_wxid),
                            bot_nickname,
                            cleaned_response,
                            role="assistant",
                            sender_wxid=user_wxid,
                        )
                    logger.success(f"[聊天记录] AI 回复成功: {cleaned_response[:50]}...")
        return False
    except Exception as err:
        logger.error(f"[聊天记录] 处理失败: {err}")
        import traceback
        logger.error(traceback.format_exc())
        await bot.send_text(from_wxid, "❌ 聊天记录处理出错")
        return False
async def _handle_quote_video(self, bot, video_elem, title_text: str, from_wxid: str,
                              user_wxid: str, is_group: bool, nickname: str, chat_id: str, svrid: int = 0):
    """
    Answer a question about a quoted video using two models in sequence.

    The video is downloaded by message id, described by the video model
    (`_analyze_video_content`), and that description plus the user's question
    is then answered by the main AI (`_call_ai_api`).

    Args:
        bot: Client used to send replies.
        video_elem: The `videomsg` XML element; its `length` attribute is read.
        title_text: Text the user typed with the quote.
        from_wxid: Conversation to reply to.
        user_wxid: Sender id.
        is_group: Whether the conversation is a group chat.
        nickname: Sender display name.
        chat_id: Memory key for this conversation.
        svrid: Message id of the quoted video; 0 means unknown.

    Returns:
        bool: Always False.
    """
    try:
        # Feature switch.
        video_config = self.config.get("video_recognition", {})
        if not video_config.get("enabled", True):
            logger.info("[视频识别] 功能未启用")
            await bot.send_text(from_wxid, "❌ 视频识别功能未启用")
            return False

        # Both the message id and the video length are required for the download.
        total_len = int(video_elem.get("length", 0))
        if not (svrid and total_len):
            logger.warning(f"[视频识别] 视频信息不完整: svrid={svrid}, total_len={total_len}")
            await bot.send_text(from_wxid, "❌ 无法获取视频信息")
            return False

        logger.info(f"[视频识别] 使用新协议下载引用视频: svrid={svrid}, len={total_len}")
        await bot.send_text(from_wxid, "🎬 正在分析视频,请稍候...")

        video_base64 = await self._download_video_by_id(bot, svrid, total_len)
        if not video_base64:
            logger.error("[视频识别] 视频下载失败")
            await bot.send_text(from_wxid, "❌ 视频下载失败")
            return False
        logger.info("[视频识别] 视频下载和编码成功")

        # Step 1: the video model describes the content.
        video_description = await self._analyze_video_content(video_base64, video_config)
        if not video_description:
            logger.error("[视频识别] 视频AI分析失败")
            await bot.send_text(from_wxid, "❌ 视频分析失败")
            return False
        logger.info(f"[视频识别] 视频AI分析完成: {video_description[:100]}...")

        # Step 2: the main AI answers from the description.
        user_question = title_text.strip() or "这个视频讲了什么?"
        combined_message = f"[用户发送了一个视频,以下是视频内容描述]\n{video_description}\n\n[用户的问题]\n{user_question}"

        # Record the turn so the main AI knows a video was sent.
        self._add_to_memory(chat_id, "user", combined_message)

        if is_group and self._should_capture_group_history(is_triggered=True):
            await self._add_to_history(
                self._get_group_history_chat_id(from_wxid, user_wxid),
                nickname,
                f"[发送了一个视频] {user_question}",
                sender_wxid=user_wxid,
            )

        async with self._reply_lock_context(chat_id):
            response = await self._call_ai_api(
                combined_message,
                bot,
                from_wxid,
                chat_id,
                nickname,
                user_wxid,
                is_group,
                tool_query=user_question,
            )
            if not response:
                await bot.send_text(from_wxid, "❌ AI 回复生成失败")
            else:
                cleaned_response = self._sanitize_llm_output(response)
                if not cleaned_response:
                    logger.warning("[视频识别] 主AI回复清洗后为空已跳过发送")
                else:
                    await bot.send_text(from_wxid, cleaned_response)
                    await self._maybe_send_voice_reply(bot, from_wxid, cleaned_response)
                    self._add_to_memory(chat_id, "assistant", cleaned_response)

                    # Write the bot reply to history unless sync_bot_messages is on
                    # and the history scope is not per-user.
                    history_cfg = self.config.get("history", {})
                    sync_enabled = history_cfg.get("sync_bot_messages", False)
                    scope = str(history_cfg.get("scope", "chatroom") or "chatroom").strip().lower()
                    hook_covers_reply = bool(sync_enabled and scope not in ("per_user", "user", "peruser"))
                    if is_group and not hook_covers_reply:
                        import tomllib
                        with open("main_config.toml", "rb") as cfg_file:
                            main_config = tomllib.load(cfg_file)
                        bot_nickname = main_config.get("Bot", {}).get("nickname", "机器人")
                        await self._add_to_history(
                            self._get_group_history_chat_id(from_wxid, user_wxid),
                            bot_nickname,
                            cleaned_response,
                            role="assistant",
                            sender_wxid=user_wxid,
                        )
                    logger.success(f"[视频识别] 主AI回复成功: {cleaned_response[:50]}...")
        return False
    except Exception as err:
        logger.error(f"[视频识别] 处理视频失败: {err}")
        import traceback
        logger.error(traceback.format_exc())
        await bot.send_text(from_wxid, "❌ 视频处理出错")
        return False
async def _analyze_video_content(self, video_base64: str, video_config: dict) -> str:
"""视频AI专门分析视频内容委托给 ImageProcessor"""
if self._image_processor:
result = await self._image_processor.analyze_video(video_base64)
# 对结果做输出清洗
return self._sanitize_llm_output(result) if result else ""
logger.warning("ImageProcessor 未初始化,无法分析视频")
return ""
async def _download_and_encode_video(self, bot, cdnurl: str, aeskey: str) -> str:
"""下载视频并转换为 base64委托给 ImageProcessor"""
if self._image_processor:
return await self._image_processor.download_video(bot, cdnurl, aeskey)
logger.warning("ImageProcessor 未初始化,无法下载视频")
return ""
async def _download_video_by_id(self, bot, msg_id: int, total_len: int) -> str:
"""通过消息ID下载视频并转换为 base64用于引用消息委托给 ImageProcessor"""
if self._image_processor:
return await self._image_processor.download_video_by_id(bot, msg_id, total_len)
logger.warning("ImageProcessor 未初始化,无法下载视频")
return ""
async def _download_image_by_id(self, bot, msg_id: int, total_len: int, to_user: str = "", from_user: str = "") -> str:
"""通过消息ID下载图片并转换为 base64用于引用消息委托给 ImageProcessor"""
if self._image_processor:
return await self._image_processor.download_image_by_id(bot, msg_id, total_len, to_user, from_user)
logger.warning("ImageProcessor 未初始化,无法下载图片")
return ""
async def _download_image_by_cdn(self, bot, cdnurl: str, aeskey: str) -> str:
"""通过 CDN 信息下载图片并转换为 base64用于引用消息"""
if not cdnurl or not aeskey:
logger.warning("CDN 参数不完整,无法下载图片")
return ""
if self._image_processor:
return await self._image_processor.download_image_by_cdn(bot, cdnurl, aeskey)
logger.warning("ImageProcessor 未初始化,无法下载图片")
return ""
async def _call_ai_api_with_video(self, user_message: str, video_base64: str, bot=None,
                                  from_wxid: str = None, chat_id: str = None,
                                  nickname: str = "", user_wxid: str = None,
                                  is_group: bool = False) -> str:
    """
    Call the Gemini-native `generateContent` endpoint with a video, carrying the full context.

    The prompt is assembled from the system prompt, the current time, the user's
    nickname, persistent memories, optional vector memories (group chats only),
    recent history, and the user's question. The video is sent as inline base64.

    Args:
        user_message: The user's question; a default prompt is used when empty.
        video_base64: Raw base64 or a `data:...;base64,` URL.
        bot: Not read in this method.
        from_wxid: Conversation id; selects group history and memories.
        chat_id: Memory key used for private-chat context.
        nickname: Display name of the current user.
        user_wxid: Sender id.
        is_group: Whether the conversation is a group chat.

    Returns:
        str: The sanitized model text; a fixed apology string when the request
        or a candidate is blocked for safety; the result of
        `_call_ai_api_with_video_simple` when no text came back; "" on HTTP or
        API errors and on any exception.
    """
    try:
        video_config = self.config.get("video_recognition", {})
        # Video recognition uses its own model / endpoint / key settings
        video_model = video_config.get("model", "gemini-3-pro-preview")
        api_url = video_config.get("api_url", "https://api.functen.cn/v1beta/models")
        api_key = video_config.get("api_key", self.config["api"]["api_key"])
        # Build the full API URL
        full_url = f"{api_url}/{video_model}:generateContent"
        # Build the system prompt (kept in line with _call_ai_api)
        system_content = self.system_prompt
        current_time = datetime.now()
        weekday_map = {
            0: "星期一", 1: "星期二", 2: "星期三", 3: "星期四",
            4: "星期五", 5: "星期六", 6: "星期日"
        }
        weekday = weekday_map[current_time.weekday()]
        time_str = current_time.strftime(f"%Y年%m月%d日 %H:%M:%S {weekday}")
        system_content += f"\n\n当前时间:{time_str}"
        if nickname:
            system_content += f"\n当前对话用户的昵称是:{nickname}"
        # Load persistent memories (keyed by group id in groups, user id in private chats)
        memory_chat_id = from_wxid if is_group else user_wxid
        if memory_chat_id:
            persistent_memories = self._get_persistent_memories(memory_chat_id)
            if persistent_memories:
                system_content += "\n\n【持久记忆】以下是用户要求你记住的重要信息:\n"
                for m in persistent_memories:
                    # Only the first 10 characters of the timestamp are shown
                    mem_time = m['time'][:10] if m['time'] else ""
                    system_content += f"- [{mem_time}] {m['nickname']}: {m['content']}\n"
        # Vector long-term memory retrieval (group chats only)
        if is_group and from_wxid and self._vector_memory_enabled:
            vector_mem = await self._retrieve_vector_memories(from_wxid, user_message)
            if vector_mem:
                system_content += vector_mem
        # Build the history context
        history_context = ""
        if is_group and from_wxid:
            # Group chat: load the stored history
            history_chat_id = self._get_group_history_chat_id(from_wxid, user_wxid or "")
            history = await self._load_history(history_chat_id)
            history = self._filter_history_by_window(history)
            max_context = self.config.get("history", {}).get("max_context", 50)
            recent_history = history[-max_context:] if len(history) > max_context else history
            if recent_history:
                history_context = "\n\n【最近的群聊记录】\n"
                for msg in recent_history:
                    msg_nickname = msg.get("nickname", "")
                    msg_content = msg.get("content", "")
                    if isinstance(msg_content, list):
                        # Multimodal content: use the first text item
                        for item in msg_content:
                            if item.get("type") == "text":
                                msg_content = item.get("text", "")
                                break
                        else:
                            # for-else: no text item was found
                            msg_content = "[图片]"
                    # Cap the length of a single message
                    if len(str(msg_content)) > 200:
                        msg_content = str(msg_content)[:200] + "..."
                    history_context += f"[{msg_nickname}] {msg_content}\n"
        else:
            # Private chat: load from memory
            if chat_id:
                memory_messages = self._get_memory_messages(chat_id)
                if memory_messages:
                    history_context = "\n\n【最近的对话记录】\n"
                    for msg in memory_messages[-20:]:  # last 20 messages
                        role = msg.get("role", "")
                        content = msg.get("content", "")
                        if isinstance(content, list):
                            for item in content:
                                if item.get("type") == "text":
                                    content = item.get("text", "")
                                    break
                            else:
                                # for-else: no text item was found
                                content = "[图片]"
                        # NOTE(review): non-user roles get an empty label — confirm this is intended
                        role_name = "用户" if role == "user" else ""
                        if len(str(content)) > 200:
                            content = str(content)[:200] + "..."
                        history_context += f"[{role_name}] {content}\n"
        # Extract the bare base64 payload from data:video/mp4;base64,xxx
        if video_base64.startswith("data:"):
            video_base64 = video_base64.split(",", 1)[1]
        # Full prompt (persona + history + current question)
        full_prompt = system_content + history_context + f"\n\n【当前】用户发送了一个视频并问:{user_message or '请描述这个视频的内容'}"
        # Gemini-native request body
        payload = {
            "contents": [
                {
                    "parts": [
                        {"text": full_prompt},
                        {
                            "inline_data": {
                                "mime_type": "video/mp4",
                                "data": video_base64
                            }
                        }
                    ]
                }
            ],
            "generationConfig": {
                "maxOutputTokens": video_config.get("max_tokens", 8192)
            }
        }
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}"
        }
        timeout = aiohttp.ClientTimeout(total=video_config.get("timeout", 360))
        # Configure the proxy (only when enabled and aiohttp_socks was importable)
        connector = None
        proxy_config = self.config.get("proxy", {})
        if proxy_config.get("enabled", False) and PROXY_SUPPORT:
            proxy_type = proxy_config.get("type", "socks5").upper()
            proxy_host = proxy_config.get("host", "127.0.0.1")
            proxy_port = proxy_config.get("port", 7890)
            proxy_url = f"{proxy_type}://{proxy_host}:{proxy_port}"
            try:
                connector = ProxyConnector.from_url(proxy_url)
            except Exception as e:
                # On failure the connector stays None and the request is made without it
                logger.warning(f"[视频识别] 代理配置失败: {e}")
        logger.info(f"[视频识别] 调用 Gemini API: {full_url}")
        logger.debug(f"[视频识别] 提示词长度: {len(full_prompt)} 字符")
        async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
            async with session.post(full_url, json=payload, headers=headers) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    logger.error(f"[视频识别] API 错误: {resp.status}, {error_text[:500]}")
                    return ""
                # Parse the Gemini response
                result = await resp.json()
                # Log the response shape (for debugging)
                logger.info(f"[视频识别] API 响应 keys: {list(result.keys()) if isinstance(result, dict) else type(result)}")
                # Check for an error payload
                if "error" in result:
                    logger.error(f"[视频识别] API 返回错误: {result['error']}")
                    return ""
                # Check promptFeedback (safety-filter information)
                if "promptFeedback" in result:
                    feedback = result["promptFeedback"]
                    block_reason = feedback.get("blockReason", "")
                    if block_reason:
                        logger.warning(f"[视频识别] 请求被阻止,原因: {block_reason}")
                        logger.warning(f"[视频识别] 安全评级: {feedback.get('safetyRatings', [])}")
                        return "抱歉,视频内容无法分析(内容策略限制)。"
                # Extract the text, concatenating every text part of every candidate
                full_content = ""
                if "candidates" in result and result["candidates"]:
                    logger.info(f"[视频识别] candidates 数量: {len(result['candidates'])}")
                    for i, candidate in enumerate(result["candidates"]):
                        # Check finishReason
                        finish_reason = candidate.get("finishReason", "")
                        if finish_reason:
                            logger.info(f"[视频识别] candidate[{i}] finishReason: {finish_reason}")
                            if finish_reason == "SAFETY":
                                logger.warning(f"[视频识别] 内容被安全过滤: {candidate.get('safetyRatings', [])}")
                                return "抱歉,视频内容无法分析。"
                        content = candidate.get("content", {})
                        parts = content.get("parts", [])
                        logger.info(f"[视频识别] candidate[{i}] parts 数量: {len(parts)}")
                        for part in parts:
                            if "text" in part:
                                full_content += part["text"]
                else:
                    # No candidates: log the (truncated) response
                    logger.error(f"[视频识别] 响应中没有 candidates: {str(result)[:500]}")
                    # Possibly caused by an over-long context; log token usage
                    if "usageMetadata" in result:
                        usage = result["usageMetadata"]
                        logger.warning(f"[视频识别] Token 使用: prompt={usage.get('promptTokenCount', 0)}, total={usage.get('totalTokenCount', 0)}")
                logger.info(f"[视频识别] AI 响应完成,长度: {len(full_content)}")
                # No content: retry with the simplified, context-free request
                if not full_content:
                    logger.info("[视频识别] 尝试简化请求重试...")
                    return await self._call_ai_api_with_video_simple(
                        user_message or "请描述这个视频的内容",
                        video_base64,
                        video_config
                    )
                return self._sanitize_llm_output(full_content)
    except Exception as e:
        logger.error(f"[视频识别] API 调用失败: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return ""
async def _call_ai_api_with_video_simple(self, user_message: str, video_base64: str, video_config: dict) -> str:
    """
    Simplified video-recognition call without conversation context (fallback retry).

    Sends only the user's question and the video to the `generateContent`
    endpoint. The request uses the same proxy settings as
    `_call_ai_api_with_video`, so the retry takes the same network route as
    the primary request.

    Args:
        user_message: Question sent alongside the video.
        video_base64: Raw base64, or a `data:...;base64,` URL (the prefix is stripped).
        video_config: The `video_recognition` config section (model, api_url,
            api_key, max_tokens, timeout).

    Returns:
        str: The sanitized text of the first text part found, or "" on HTTP
        errors, missing candidates, or any exception.
    """
    try:
        api_url = video_config.get("api_url", "https://api.functen.cn/v1beta/models")
        api_key = video_config.get("api_key", self.config["api"]["api_key"])
        model = video_config.get("model", "gemini-3-pro-preview")
        full_url = f"{api_url}/{model}:generateContent"

        # Accept a data URL as well as raw base64; the API expects the bare payload.
        if video_base64.startswith("data:"):
            video_base64 = video_base64.partition(",")[2]

        # Simplified request: only the user's question and the video.
        payload = {
            "contents": [
                {
                    "parts": [
                        {"text": user_message},
                        {
                            "inline_data": {
                                "mime_type": "video/mp4",
                                "data": video_base64
                            }
                        }
                    ]
                }
            ],
            "generationConfig": {
                "maxOutputTokens": video_config.get("max_tokens", 8192)
            }
        }
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}"
        }
        timeout = aiohttp.ClientTimeout(total=video_config.get("timeout", 360))

        # Same proxy handling as the full request; on failure fall back to a direct connection.
        connector = None
        proxy_config = self.config.get("proxy", {})
        if proxy_config.get("enabled", False) and PROXY_SUPPORT:
            proxy_type = proxy_config.get("type", "socks5").upper()
            proxy_host = proxy_config.get("host", "127.0.0.1")
            proxy_port = proxy_config.get("port", 7890)
            proxy_url = f"{proxy_type}://{proxy_host}:{proxy_port}"
            try:
                connector = ProxyConnector.from_url(proxy_url)
            except Exception as proxy_err:
                logger.warning(f"[视频识别-简化] 代理配置失败: {proxy_err}")

        logger.info(f"[视频识别-简化] 调用 API: {full_url}")
        async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
            async with session.post(full_url, json=payload, headers=headers) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    logger.error(f"[视频识别-简化] API 错误: {resp.status}, {error_text[:300]}")
                    return ""
                result = await resp.json()
                logger.info(f"[视频识别-简化] API 响应 keys: {list(result.keys())}")
                # Return the first text part found in any candidate.
                if "candidates" in result and result["candidates"]:
                    for candidate in result["candidates"]:
                        content = candidate.get("content", {})
                        for part in content.get("parts", []):
                            if "text" in part:
                                text = part["text"]
                                logger.info(f"[视频识别-简化] 成功,长度: {len(text)}")
                                return self._sanitize_llm_output(text)
                logger.error(f"[视频识别-简化] 仍然没有 candidates: {str(result)[:300]}")
                return ""
    except Exception as e:
        logger.error(f"[视频识别-简化] 失败: {e}")
        return ""
def _should_reply_quote(self, message: dict, title_text: str) -> bool:
"""判断是否应该回复引用消息"""
is_group = message.get("IsGroup", False)
# 检查群聊/私聊开关
if is_group and not self.config["behavior"]["reply_group"]:
return False
if not is_group and not self.config["behavior"]["reply_private"]:
return False
trigger_mode = self.config["behavior"]["trigger_mode"]
# all模式回复所有消息
if trigger_mode == "all":
return True
# mention模式检查是否@了机器人
if trigger_mode == "mention":
if is_group:
# 方式1检查 Ats 字段(普通消息格式)
ats = message.get("Ats", [])
import tomllib
with open("main_config.toml", "rb") as f:
main_config = tomllib.load(f)
bot_wxid = main_config.get("Bot", {}).get("wxid", "")
bot_nickname = main_config.get("Bot", {}).get("nickname", "")
# 检查 Ats 列表
if bot_wxid and bot_wxid in ats:
return True
# 方式2检查标题中是否包含 @机器人昵称(引用消息格式)
# 引用消息的 @ 信息在 title 中,如 "@瑞依 评价下"
if bot_nickname and f"@{bot_nickname}" in title_text:
logger.debug(f"引用消息标题中检测到 @{bot_nickname}")
return True
return False
else:
return True
# keyword模式检查关键词
if trigger_mode == "keyword":
keywords = self.config["behavior"]["keywords"]
return any(kw in title_text for kw in keywords)
return False
async def _call_ai_api_with_image(
    self,
    user_message: str,
    image_base64: str,
    bot=None,
    from_wxid: str = None,
    chat_id: str = None,
    nickname: str = "",
    user_wxid: str = None,
    is_group: bool = False,
    *,
    append_user_message: bool = True,
    tool_query: str | None = None,
    disable_tools: bool = False,
) -> str | None:
    """Call the dialog model with a text prompt plus one image.

    Builds the system prompt (current time, nickname, optional tool rules,
    persistent memories and, for groups, vector memories), appends the
    conversation history, sends the request and then either returns the
    sanitized reply text or hands tool calls over to
    ``_execute_tools_async_with_image``.

    Args:
        user_message: Text that accompanies the image.
        image_base64: Value placed in the ``image_url.url`` field of the
            request (presumably a base64 data URL - confirm against callers).
        bot: Client used for the "please wait" hint and passed on to tool
            execution.
        from_wxid: Chat the message came from; used for group history,
            memories and replies.
        chat_id: Key used to read private-chat memory messages.
        nickname: Display name of the sender, added to the system prompt.
        user_wxid: Sender id; used for persistent memory in private chats.
        is_group: Whether the message comes from a group chat.
        append_user_message: When False, the current message (text and image)
            is not appended to the request.
        tool_query: Forwarded to tool selection.
        disable_tools: When True, no tools are offered to the model.

    Returns:
        The sanitized reply text; a fixed apology string when the model
        emitted a malformed tool-call format; or None when tool calls were
        dispatched instead of a direct reply.

    Raises:
        Exception: Any error raised while requesting or post-processing the
            reply is logged and re-raised.
    """

    def _function_names(tool_defs) -> set:
        """Collect the non-empty function names of tool definitions."""
        return {
            t.get("function", {}).get("name", "")
            for t in (tool_defs or [])
            if isinstance(t, dict) and t.get("function", {}).get("name")
        }

    def _legacy_tool_call(id_prefix: str, name: str, arguments) -> dict:
        """Wrap a text-form tool call into the function-calling structure."""
        return {
            "id": f"{id_prefix}{uuid.uuid4().hex[:8]}",
            "type": "function",
            "function": {
                "name": name,
                "arguments": json.dumps(arguments, ensure_ascii=False),
            },
        }

    api_config = self.config["api"]
    if disable_tools:
        available_tool_names = set()
        tools = []
        logger.info("[图片] AutoReply 模式：已禁用工具调用")
    else:
        all_tools = self._collect_tools()
        available_tool_names = _function_names(all_tools)
        selected_tools = await self._select_tools_for_message_async(
            all_tools, user_message=user_message, tool_query=tool_query
        )
        tools = self._prepare_tools_for_llm(selected_tools)
        logger.info(f"[图片] 收集到 {len(all_tools)} 个工具函数，本次启用 {len(tools)}")
        if tools:
            tool_names = [t["function"]["name"] for t in tools]
            logger.info(f"[图片] 本次启用工具: {tool_names}")

    # Build the system prompt.
    system_content = self.system_prompt

    # Add the current time, including the weekday name.
    current_time = datetime.now()
    weekday_map = {
        0: "星期一", 1: "星期二", 2: "星期三", 3: "星期四",
        4: "星期五", 5: "星期六", 6: "星期日"
    }
    weekday = weekday_map[current_time.weekday()]
    time_str = current_time.strftime(f"%Y年%m月%d日 %H:%M:%S {weekday}")
    system_content += f"\n\n当前时间：{time_str}"
    if nickname:
        system_content += f"\n当前对话用户的昵称是：{nickname}"
    if self._tool_rule_prompt_enabled:
        system_content += self._build_tool_rules_prompt(tools)

    # Load persistent memories (same as the text-only path): keyed by the
    # group id in groups and by the user id in private chats.
    memory_chat_id = from_wxid if is_group else user_wxid
    if memory_chat_id:
        persistent_memories = self._get_persistent_memories(memory_chat_id)
        if persistent_memories:
            system_content += "\n\n【持久记忆】以下是用户要求你记住的重要信息：\n"
            for m in persistent_memories:
                mem_time = m['time'][:10] if m['time'] else ""
                system_content += f"- [{mem_time}] {m['nickname']}: {m['content']}\n"

    # Vector long-term memory retrieval (groups only).
    if is_group and from_wxid and self._vector_memory_enabled:
        vector_mem = await self._retrieve_vector_memories(from_wxid, user_message)
        if vector_mem:
            system_content += vector_mem

    messages = [{"role": "system", "content": system_content}]

    # Add the history context.
    if is_group and from_wxid:
        history_chat_id = self._get_group_history_chat_id(from_wxid, user_wxid or "")
        history = await self._load_history(history_chat_id)
        history = self._filter_history_by_window(history)
        max_context = self.config.get("history", {}).get("max_context", 50)
        recent_history = history[-max_context:] if len(history) > max_context else history
        self._append_group_history_messages(messages, recent_history)
    elif chat_id:
        memory_messages = self._get_memory_messages(chat_id)
        # The last memory entry is left out; the current message is appended below.
        if memory_messages and len(memory_messages) > 1:
            messages.extend(memory_messages[:-1])

    # Add the current user message together with the image.
    if append_user_message:
        current_marker = "【当前消息】"
        if is_group and nickname:
            # Group chats use the structured message format.
            message_time = datetime.now().strftime("%Y-%m-%d %H:%M")
            text_value = self._format_user_message_content(nickname, user_message, message_time, "image")
        else:
            text_value = user_message
        text_value = f"{current_marker}\n{text_value}"
        messages.append({
            "role": "user",
            "content": [
                {"type": "text", "text": text_value},
                {"type": "image_url", "image_url": {"url": image_base64}}
            ]
        })

    # Arguments shared by every tool dispatch below.
    dispatch_context = {
        "bot": bot,
        "from_wxid": from_wxid,
        "chat_id": chat_id,
        "user_wxid": user_wxid,
        "nickname": nickname,
        "is_group": is_group,
        "messages": messages,
        "image_base64": image_base64,
    }

    try:
        if tools:
            logger.debug(f"[图片] 已将 {len(tools)} 个工具添加到请求中")
        full_content, tool_calls_data = await self._send_dialog_api_request(
            api_config,
            messages,
            tools,
            request_tag="[图片]",
            prefer_stream=True,
            max_tokens=api_config.get("max_tokens", 4096),
        )

        # Names of the tools that were actually offered to the model in this request.
        allowed_tool_names = _function_names(tools)

        # Native function calls.
        if tool_calls_data:
            # Drop hallucinated tool calls: anything that was not offered in
            # this request must not be executed.
            unsupported = []
            filtered = []
            for tc in tool_calls_data:
                fn = (tc or {}).get("function", {}).get("name", "")
                if not fn:
                    continue
                if not allowed_tool_names or fn not in allowed_tool_names:
                    unsupported.append(fn)
                    continue
                filtered.append(tc)
            if unsupported:
                logger.warning(f"[图片] 检测到未提供/未知的工具调用，已忽略: {unsupported}")
            tool_calls_data = filtered

        if tool_calls_data:
            # The hint was already sent during streaming; start tool execution directly.
            logger.info(f"[图片] 启动工具执行，共 {len(tool_calls_data)} 个工具")
            await self._dispatch_image_tool_calls(tool_calls_data, **dispatch_context)
            return None

        # Compatibility: search tool call written out as plain text.
        if full_content:
            legacy = self._extract_legacy_text_search_tool_call(full_content)
            if legacy:
                legacy_tool, legacy_args = legacy
                # Only map to tools that were really offered in this request,
                # so that smart tool selection is not bypassed.
                preferred = None
                if legacy_tool in allowed_tool_names:
                    preferred = legacy_tool
                elif "tavily_web_search" in allowed_tool_names:
                    preferred = "tavily_web_search"
                elif "web_search" in allowed_tool_names:
                    preferred = "web_search"
                elif self._looks_like_info_query(user_message):
                    if "tavily_web_search" in available_tool_names:
                        preferred = "tavily_web_search"
                    elif "web_search" in available_tool_names:
                        preferred = "web_search"
                if preferred:
                    logger.warning(f"[图片] 检测到文本形式工具调用，已转换为 Function Calling: {preferred}")
                    try:
                        if bot and from_wxid:
                            await bot.send_text(from_wxid, "我帮你查一下，稍等。")
                    except Exception as e:
                        # Best effort: a failed hint must not block the tool run.
                        logger.debug(f"[图片] 发送工具调用提示失败: {e}")
                    await self._dispatch_image_tool_calls(
                        [_legacy_tool_call("legacy_", preferred, legacy_args)],
                        **dispatch_context,
                    )
                    return None

        # Compatibility: image-generation tool call written out as JSON text.
        if full_content:
            legacy_img = self._extract_legacy_text_image_tool_call(full_content)
            if legacy_img:
                legacy_tool, legacy_args = legacy_img
                tools_cfg = (self.config or {}).get("tools", {})
                loose_image_tool = tools_cfg.get("loose_image_tool", True)
                preferred = self._resolve_image_tool_alias(
                    legacy_tool,
                    allowed_tool_names,
                    available_tool_names,
                    loose_image_tool,
                )
                if preferred:
                    logger.warning(f"[图片] 检测到文本绘图工具调用，已转换为 Function Calling: {preferred}")
                    await self._dispatch_image_tool_calls(
                        [_legacy_tool_call("legacy_img_", preferred, legacy_args)],
                        **dispatch_context,
                    )
                    return None

        # The request may yield no text at all; normalise to "" so the
        # substring and regex checks below cannot raise TypeError.
        reply_text = full_content or ""

        # Intercept malformed tool-call output instead of showing it to the user.
        if "<tool_code>" in reply_text or re.search(
            r"(?i)\bprint\s*\(\s*(draw_image|generate_image|nano_ai_image_generation|flow2_ai_image_generation|jimeng_ai_image_generation|kiira2_ai_image_generation)\s*\(",
            reply_text,
        ):
            logger.warning("检测到模型输出了错误的工具调用格式，拦截并返回提示")
            return "抱歉，我遇到了一些技术问题，请重新描述一下你的需求~"
        return self._sanitize_llm_output(reply_text)
    except Exception as e:
        logger.error(f"调用AI API失败: {e}")
        raise

async def _dispatch_image_tool_calls(
    self,
    tool_calls_data: list,
    *,
    bot,
    from_wxid,
    chat_id,
    user_wxid,
    nickname,
    is_group,
    messages,
    image_base64,
) -> None:
    """Record tool calls in the conversation context and execute them.

    Recording is best effort: a failure is logged at debug level and does not
    stop execution. When ``self._tool_async`` is set, execution runs as a
    background task; otherwise it is awaited before returning.
    """
    try:
        await self._record_tool_calls_to_context(
            tool_calls_data,
            from_wxid=from_wxid,
            chat_id=chat_id,
            is_group=is_group,
            user_wxid=user_wxid,
        )
    except Exception as e:
        logger.debug(f"[图片] 记录工具调用到上下文失败: {e}")

    execution = self._execute_tools_async_with_image(
        tool_calls_data, bot, from_wxid, chat_id,
        user_wxid, nickname, is_group, messages, image_base64
    )
    if not self._tool_async:
        await execution
        return

    # The event loop only keeps weak references to tasks, so hold a strong
    # reference until the task finishes to keep it from being collected.
    task = asyncio.create_task(execution)
    pending = getattr(self, "_image_tool_tasks", None)
    if pending is None:
        pending = set()
        self._image_tool_tasks = pending
    pending.add(task)
    task.add_done_callback(pending.discard)
async def _send_chat_records(self, bot, from_wxid: str, title: str, content: str):
"""发送聊天记录格式消息"""
try:
import uuid
import time
import hashlib
import xml.etree.ElementTree as ET
is_group = from_wxid.endswith("@chatroom")
# 自动分割内容
max_length = 800
content_parts = []
if len(content) <= max_length:
content_parts = [content]
else:
lines = content.split('\n')
current_part = ""
for line in lines:
if len(current_part + line + '\n') > max_length:
if current_part:
content_parts.append(current_part.strip())
current_part = line + '\n'
else:
content_parts.append(line[:max_length])
current_part = line[max_length:] + '\n'
else:
current_part += line + '\n'
if current_part.strip():
content_parts.append(current_part.strip())
recordinfo = ET.Element("recordinfo")
info_el = ET.SubElement(recordinfo, "info")
info_el.text = title
is_group_el = ET.SubElement(recordinfo, "isChatRoom")
is_group_el.text = "1" if is_group else "0"
datalist = ET.SubElement(recordinfo, "datalist")
datalist.set("count", str(len(content_parts)))
desc_el = ET.SubElement(recordinfo, "desc")
desc_el.text = title
fromscene_el = ET.SubElement(recordinfo, "fromscene")
fromscene_el.text = "3"
for i, part in enumerate(content_parts):
di = ET.SubElement(datalist, "dataitem")
di.set("datatype", "1")
di.set("dataid", uuid.uuid4().hex)
src_local_id = str((int(time.time() * 1000) % 90000) + 10000)
new_msg_id = str(int(time.time() * 1000) + i)
create_time = str(int(time.time()) - len(content_parts) + i)
ET.SubElement(di, "srcMsgLocalid").text = src_local_id
ET.SubElement(di, "sourcetime").text = time.strftime("%Y-%m-%d %H:%M", time.localtime(int(create_time)))
ET.SubElement(di, "fromnewmsgid").text = new_msg_id
ET.SubElement(di, "srcMsgCreateTime").text = create_time
ET.SubElement(di, "sourcename").text = "AI助手"
ET.SubElement(di, "sourceheadurl").text = ""
ET.SubElement(di, "datatitle").text = part
ET.SubElement(di, "datadesc").text = part
ET.SubElement(di, "datafmt").text = "text"
ET.SubElement(di, "ischatroom").text = "1" if is_group else "0"
dataitemsource = ET.SubElement(di, "dataitemsource")
ET.SubElement(dataitemsource, "hashusername").text = hashlib.sha256(from_wxid.encode("utf-8")).hexdigest()
record_xml = ET.tostring(recordinfo, encoding="unicode")
appmsg_parts = [
"<appmsg appid=\"\" sdkver=\"0\">",
f"<title>{title}</title>",
f"<des>{title}</des>",
"<type>19</type>",
"<url>https://support.weixin.qq.com/cgi-bin/mmsupport-bin/readtemplate?t=page/favorite_record__w_unsupport</url>",
"<appattach><cdnthumbaeskey></cdnthumbaeskey><aeskey></aeskey></appattach>",
f"<recorditem><![CDATA[{record_xml}]]></recorditem>",
"<percent>0</percent>",
"</appmsg>"
]
appmsg_xml = "".join(appmsg_parts)
# 使用新的 HTTP API 发送 XML 消息
await bot.send_xml(from_wxid, appmsg_xml)
logger.success(f"已发送聊天记录: {title}")
except Exception as e:
logger.error(f"发送聊天记录失败: {e}")
async def _process_image_to_history(self, bot, message: dict, content: str) -> bool:
"""处理图片/表情包并保存描述到 history通用方法"""
from_wxid = message.get("FromWxid", "")
sender_wxid = message.get("SenderWxid", "")
is_group = message.get("IsGroup", False)
user_wxid = sender_wxid if is_group else from_wxid
# 只处理群聊
if not is_group:
return True
# 检查是否启用图片描述功能
image_desc_config = self.config.get("image_description", {})
if not image_desc_config.get("enabled", True):
return True
try:
# 解析XML获取图片信息
root = ET.fromstring(content)
# 尝试查找 <img> 标签(图片消息)或 <emoji> 标签(表情包)
img = root.find(".//img")
if img is None:
img = root.find(".//emoji")
if img is None:
return True
cdnbigimgurl = img.get("cdnbigimgurl", "") or img.get("cdnurl", "")
aeskey = img.get("aeskey", "")
# 检查是否是表情包(有 cdnurl 但可能没有 aeskey
is_emoji = img.tag == "emoji"
if not cdnbigimgurl:
return True
# 图片消息需要 aeskey表情包不需要
if not is_emoji and not aeskey:
return True
# 获取用户昵称 - 使用缓存优化
nickname = await self._get_user_display_label(bot, from_wxid, user_wxid, is_group)
history_chat_id = self._get_group_history_chat_id(from_wxid, user_wxid)
# 立即插入占位符到 history
placeholder_id = str(uuid.uuid4())
await self._add_to_history_with_id(history_chat_id, nickname, "[图片: 处理中...]", placeholder_id)
logger.info(f"已插入图片占位符: {placeholder_id}")
# 将任务加入队列(不阻塞)
task = {
"bot": bot,
"history_chat_id": history_chat_id,
"nickname": nickname,
"cdnbigimgurl": cdnbigimgurl,
"aeskey": aeskey,
"is_emoji": is_emoji,
"placeholder_id": placeholder_id,
"config": image_desc_config,
"message": message # 添加完整的 message 对象供新接口使用
}
await self.image_desc_queue.put(task)
logger.info(f"图片描述任务已加入队列,当前队列长度: {self.image_desc_queue.qsize()}")
return True
except Exception as e:
logger.error(f"处理图片消息失败: {e}")
return True
async def _image_desc_worker(self):
    """Worker coroutine: take image-description tasks off the queue and run them.

    Runs until it is cancelled. A cancellation while waiting for a task ends
    the loop quietly; a cancellation while a task is being processed is
    re-raised. Any other error is logged and the worker continues.
    """
    task_fields = (
        "bot", "history_chat_id", "nickname",
        "cdnbigimgurl", "aeskey", "is_emoji",
        "placeholder_id", "config",
    )
    while True:
        try:
            job = await self.image_desc_queue.get()
        except asyncio.CancelledError:
            logger.info("图片描述工作协程收到取消信号，退出")
            return

        try:
            await self._generate_and_update_image_description(
                *[job[field] for field in task_fields],
                job.get("message"),
            )
        except asyncio.CancelledError:
            raise
        except Exception as e:
            logger.error(f"图片描述工作协程异常: {e}")
        finally:
            # Mark the task as done in every case; a ValueError from
            # task_done() is ignored.
            try:
                self.image_desc_queue.task_done()
            except ValueError:
                pass
async def _generate_and_update_image_description(self, bot, history_chat_id: str, nickname: str,
                                                 cdnbigimgurl: str, aeskey: str, is_emoji: bool,
                                                 placeholder_id: str, image_desc_config: dict, message: dict = None):
    """Generate a description for an image/emoji and update its history entry.

    The placeholder entry identified by ``placeholder_id`` is rewritten to
    ``[图片: <description>]`` on success and to plain ``[图片]`` when the
    download or the description fails.

    Args:
        bot: Client object, passed to the image download helper.
        history_chat_id: History key that holds the placeholder.
        nickname: Sender name, used in the success log line only.
        cdnbigimgurl: Download URL, used for emoji.
        aeskey: Not used in this method.
        is_emoji: True for emoji, False for regular images.
        placeholder_id: Id of the history entry to update.
        image_desc_config: Settings; ``prompt`` overrides the default prompt.
        message: Original message dict, needed to download regular images.
    """
    fallback_entry = "[图片]"
    try:
        # Download and encode the image or emoji.
        if is_emoji:
            encoded = await self._download_emoji_and_encode(cdnbigimgurl)
        elif message:
            # Regular images are downloaded from the full message object.
            encoded = await self._download_and_encode_image(bot, message)
        else:
            # Without the message object there is nothing to download from.
            logger.warning("缺少 message 对象，图片下载可能失败")
            encoded = ""

        if not encoded:
            media_kind = '表情包' if is_emoji else '图片'
            logger.warning(f"{media_kind}下载失败")
            await self._update_history_by_id(history_chat_id, placeholder_id, fallback_entry)
            return

        # Ask the model for a description of the image.
        prompt = image_desc_config.get("prompt", "请用一句话简洁地描述这张图片的主要内容。")
        description = await self._generate_image_description(encoded, prompt, image_desc_config)

        if not description:
            await self._update_history_by_id(history_chat_id, placeholder_id, fallback_entry)
            logger.warning("图片描述生成失败")
            return

        cleaned = self._sanitize_llm_output(description)
        await self._update_history_by_id(history_chat_id, placeholder_id, f"[图片: {cleaned}]")
        logger.success(f"已更新图片描述: {nickname} - {cleaned[:30]}...")
    except asyncio.CancelledError:
        raise
    except Exception as e:
        logger.error(f"异步生成图片描述失败: {e}")
        await self._update_history_by_id(history_chat_id, placeholder_id, fallback_entry)
@on_image_message(priority=15)
async def handle_image_message(self, bot, message: dict):
    """Handle a directly sent image message.

    Passes the message's ``Content`` to ``_process_image_to_history`` and
    returns that call's result; no AI reply is produced here.
    """
    logger.info("AIChat: handle_image_message 被调用")
    return await self._process_image_to_history(bot, message, message.get("Content", ""))
@on_emoji_message(priority=15)
async def handle_emoji_message(self, bot, message: dict):
    """Handle an emoji (sticker) message.

    Passes the message's ``Content`` to ``_process_image_to_history`` and
    returns that call's result; no AI reply is produced here.
    """
    logger.info("AIChat: handle_emoji_message 被调用")
    return await self._process_image_to_history(bot, message, message.get("Content", ""))