4240 lines
190 KiB
Python
4240 lines
190 KiB
Python
"""
|
||
AI 聊天插件
|
||
|
||
支持自定义模型、API 和人设
|
||
支持 Redis 存储对话历史和限流
|
||
"""
|
||
|
||
import asyncio
|
||
import tomllib
|
||
import aiohttp
|
||
import json
|
||
import re
|
||
import time
|
||
from pathlib import Path
|
||
from datetime import datetime
|
||
from loguru import logger
|
||
from utils.plugin_base import PluginBase
|
||
from utils.decorators import on_text_message, on_quote_message, on_image_message, on_emoji_message
|
||
from utils.redis_cache import get_cache
|
||
from utils.llm_tooling import ToolResult, collect_tools_with_plugins, collect_tools, get_tool_schema_map, validate_tool_arguments
|
||
import xml.etree.ElementTree as ET
|
||
import base64
|
||
import uuid
|
||
|
||
# 可选导入代理支持
|
||
try:
|
||
from aiohttp_socks import ProxyConnector
|
||
PROXY_SUPPORT = True
|
||
except ImportError:
|
||
PROXY_SUPPORT = False
|
||
logger.warning("aiohttp_socks 未安装,代理功能将不可用")
|
||
|
||
|
||
class AIChat(PluginBase):
|
||
"""AI 聊天插件"""
|
||
|
||
# 插件元数据
|
||
description = "AI 聊天插件,支持自定义模型和人设"
|
||
author = "ShiHao"
|
||
version = "1.0.0"
|
||
|
||
def __init__(self):
    """Set up in-memory state only; real initialization happens in async_init()."""
    super().__init__()
    self.config = None  # parsed config.toml, loaded in async_init()
    self.system_prompt = ""  # persona prompt text
    self.memory = {}  # per-conversation memory {chat_id: [messages]}
    self.history_dir = None  # directory holding history files
    self.history_locks = {}  # one lock per conversation
    self.image_desc_queue = asyncio.Queue()  # queue of image-description tasks
    self.image_desc_workers = []  # background worker tasks
    self.persistent_memory_db = None  # path to the persistent-memory database
    self.store = None  # ContextStore instance (unified storage)
    self._chatroom_member_cache = {}  # {chatroom_id: (ts, {wxid: display_name})}
    self._chatroom_member_cache_locks = {}  # {chatroom_id: asyncio.Lock}
    self._chatroom_member_cache_ttl_seconds = 3600  # cache chatroom name cards 1h to reduce protocol API calls
|
||
|
||
async def async_init(self):
    """Asynchronous plugin initialization.

    Loads config.toml and the persona prompt file, logs proxy settings,
    prepares the history directory, starts two image-description worker
    tasks, and initializes the persistent-memory DB through ContextStore.
    """
    # Load configuration (tomllib requires binary mode)
    config_path = Path(__file__).parent / "config.toml"
    with open(config_path, "rb") as f:
        self.config = tomllib.load(f)

    # Load the persona (system prompt) file
    prompt_file = self.config["prompt"]["system_prompt_file"]
    prompt_path = Path(__file__).parent / "prompts" / prompt_file

    if prompt_path.exists():
        with open(prompt_path, "r", encoding="utf-8") as f:
            self.system_prompt = f.read().strip()
        logger.success(f"已加载人设: {prompt_file}")
    else:
        # Fall back to a default persona when the file is missing
        logger.warning(f"人设文件不存在: {prompt_file},使用默认人设")
        self.system_prompt = "你是一个友好的 AI 助手。"

    # Log proxy configuration when enabled
    proxy_config = self.config.get("proxy", {})
    if proxy_config.get("enabled", False):
        proxy_type = proxy_config.get("type", "socks5")
        proxy_host = proxy_config.get("host", "127.0.0.1")
        proxy_port = proxy_config.get("port", 7890)
        logger.info(f"AI 聊天插件已启用代理: {proxy_type}://{proxy_host}:{proxy_port}")

    # Prepare the history directory
    history_config = self.config.get("history", {})
    if history_config.get("enabled", True):
        history_dir_name = history_config.get("history_dir", "history")
        self.history_dir = Path(__file__).parent / history_dir_name
        self.history_dir.mkdir(exist_ok=True)
        logger.info(f"历史记录目录: {self.history_dir}")

    # Start image-description workers (fixed concurrency of 2)
    for i in range(2):
        worker = asyncio.create_task(self._image_desc_worker())
        self.image_desc_workers.append(worker)
    logger.info("已启动 2 个图片描述工作协程")

    # Initialize the persistent-memory database and the unified store
    from utils.context_store import ContextStore
    db_dir = Path(__file__).parent / "data"
    db_dir.mkdir(exist_ok=True)
    self.persistent_memory_db = db_dir / "persistent_memory.db"
    self.store = ContextStore(
        self.config,
        self.history_dir,
        self.memory,
        self.history_locks,
        self.persistent_memory_db,
    )
    self.store.init_persistent_memory_db()

    logger.info(f"AI 聊天插件已加载,模型: {self.config['api']['model']}")
|
||
|
||
async def on_disable(self):
    """Called when the plugin is disabled: cancel workers and drain the queue."""
    await super().on_disable()

    # Cancel image-description workers so a plugin reload does not stack them
    if self.image_desc_workers:
        for worker in self.image_desc_workers:
            worker.cancel()
        # Wait for cancellation; swallow CancelledError via return_exceptions
        await asyncio.gather(*self.image_desc_workers, return_exceptions=True)
        self.image_desc_workers.clear()

    # Drain any pending image-description tasks
    try:
        while self.image_desc_queue and not self.image_desc_queue.empty():
            self.image_desc_queue.get_nowait()
            self.image_desc_queue.task_done()
    except Exception:
        pass
    # Replace with a fresh queue so a re-enable starts clean
    self.image_desc_queue = asyncio.Queue()

    logger.info("AIChat 已清理后台图片描述任务")
|
||
|
||
def _add_persistent_memory(self, chat_id: str, chat_type: str, user_wxid: str,
|
||
user_nickname: str, content: str) -> int:
|
||
"""添加持久记忆,返回记忆ID(委托 ContextStore)"""
|
||
if not self.store:
|
||
return -1
|
||
return self.store.add_persistent_memory(chat_id, chat_type, user_wxid, user_nickname, content)
|
||
|
||
def _get_persistent_memories(self, chat_id: str) -> list:
|
||
"""获取指定会话的所有持久记忆(委托 ContextStore)"""
|
||
if not self.store:
|
||
return []
|
||
return self.store.get_persistent_memories(chat_id)
|
||
|
||
def _delete_persistent_memory(self, chat_id: str, memory_id: int) -> bool:
|
||
"""删除指定的持久记忆(委托 ContextStore)"""
|
||
if not self.store:
|
||
return False
|
||
return self.store.delete_persistent_memory(chat_id, memory_id)
|
||
|
||
def _clear_persistent_memories(self, chat_id: str) -> int:
|
||
"""清空指定会话的所有持久记忆(委托 ContextStore)"""
|
||
if not self.store:
|
||
return 0
|
||
return self.store.clear_persistent_memories(chat_id)
|
||
|
||
def _get_chat_id(self, from_wxid: str, sender_wxid: str = None, is_group: bool = False) -> str:
|
||
"""获取会话ID"""
|
||
if is_group:
|
||
# 群聊使用 "群ID:用户ID" 组合,确保每个用户有独立的对话记忆
|
||
user_wxid = sender_wxid or from_wxid
|
||
return f"{from_wxid}:{user_wxid}"
|
||
else:
|
||
return sender_wxid or from_wxid # 私聊使用用户ID
|
||
|
||
def _get_group_history_chat_id(self, from_wxid: str, user_wxid: str = None) -> str:
|
||
"""获取群聊 history 的会话ID(可配置为全群共享或按用户隔离)"""
|
||
if not from_wxid:
|
||
return ""
|
||
|
||
history_config = (self.config or {}).get("history", {})
|
||
scope = str(history_config.get("scope", "chatroom") or "chatroom").strip().lower()
|
||
if scope in ("per_user", "user", "peruser"):
|
||
if not user_wxid:
|
||
return from_wxid
|
||
return self._get_chat_id(from_wxid, user_wxid, is_group=True)
|
||
|
||
return from_wxid
|
||
|
||
def _should_capture_group_history(self, *, is_triggered: bool) -> bool:
|
||
"""判断群聊消息是否需要写入 history(减少无关上下文污染)"""
|
||
history_config = (self.config or {}).get("history", {})
|
||
capture = str(history_config.get("capture", "all") or "all").strip().lower()
|
||
|
||
if capture in ("none", "off", "disable", "disabled"):
|
||
return False
|
||
if capture in ("reply", "ai_only", "triggered"):
|
||
return bool(is_triggered)
|
||
return True
|
||
|
||
def _parse_history_timestamp(self, ts) -> float | None:
|
||
if ts is None:
|
||
return None
|
||
if isinstance(ts, (int, float)):
|
||
return float(ts)
|
||
if isinstance(ts, str):
|
||
s = ts.strip()
|
||
if not s:
|
||
return None
|
||
try:
|
||
return float(s)
|
||
except Exception:
|
||
pass
|
||
try:
|
||
return datetime.fromisoformat(s).timestamp()
|
||
except Exception:
|
||
return None
|
||
return None
|
||
|
||
def _filter_history_by_window(self, history: list) -> list:
|
||
history_config = (self.config or {}).get("history", {})
|
||
window_seconds = history_config.get("context_window_seconds", None)
|
||
if window_seconds is None:
|
||
window_seconds = history_config.get("window_seconds", 0)
|
||
try:
|
||
window_seconds = float(window_seconds or 0)
|
||
except Exception:
|
||
window_seconds = 0
|
||
if window_seconds <= 0:
|
||
return history
|
||
|
||
cutoff = time.time() - window_seconds
|
||
filtered = []
|
||
for msg in history or []:
|
||
ts = self._parse_history_timestamp((msg or {}).get("timestamp"))
|
||
if ts is None or ts >= cutoff:
|
||
filtered.append(msg)
|
||
return filtered
|
||
|
||
def _sanitize_speaker_name(self, name: str) -> str:
|
||
"""清洗昵称,避免破坏历史格式(如 [name] 前缀)。"""
|
||
if name is None:
|
||
return ""
|
||
s = str(name).strip()
|
||
if not s:
|
||
return ""
|
||
s = s.replace("\r", " ").replace("\n", " ")
|
||
s = re.sub(r"\s{2,}", " ", s)
|
||
# 避免与历史前缀 [xxx] 冲突
|
||
s = s.replace("[", "(").replace("]", ")")
|
||
return s.strip()
|
||
|
||
def _combine_display_and_nickname(self, display_name: str, wechat_nickname: str) -> str:
|
||
display_name = self._sanitize_speaker_name(display_name)
|
||
wechat_nickname = self._sanitize_speaker_name(wechat_nickname)
|
||
# 重要:群昵称(群名片) 与 微信昵称(全局) 是两个不同概念,尽量同时给 AI。
|
||
if display_name and wechat_nickname:
|
||
return f"群昵称={display_name} | 微信昵称={wechat_nickname}"
|
||
if display_name:
|
||
return f"群昵称={display_name}"
|
||
if wechat_nickname:
|
||
return f"微信昵称={wechat_nickname}"
|
||
return ""
|
||
|
||
def _get_chatroom_member_lock(self, chatroom_id: str) -> asyncio.Lock:
|
||
lock = self._chatroom_member_cache_locks.get(chatroom_id)
|
||
if lock is None:
|
||
lock = asyncio.Lock()
|
||
self._chatroom_member_cache_locks[chatroom_id] = lock
|
||
return lock
|
||
|
||
async def _get_group_display_name(self, bot, chatroom_id: str, user_wxid: str, *, force_refresh: bool = False) -> str:
    """Return the member's chatroom card (in-room nickname); "" on any failure.

    Uses a TTL cache of the whole member list per chatroom, with a
    double-checked per-room lock so concurrent callers trigger at most one
    member-list fetch. `force_refresh` bypasses the cache.
    """
    if not chatroom_id or not user_wxid:
        return ""
    # Bot implementations without this API simply get no card names
    if not hasattr(bot, "get_chatroom_members"):
        return ""

    now = time.time()
    # Fast path: fresh cache entry, no lock needed
    if not force_refresh:
        cached = self._chatroom_member_cache.get(chatroom_id)
        if cached:
            ts, member_map = cached
            if now - float(ts or 0) < float(self._chatroom_member_cache_ttl_seconds or 0):
                return self._sanitize_speaker_name(member_map.get(user_wxid, ""))

    lock = self._get_chatroom_member_lock(chatroom_id)
    async with lock:
        # Re-check under the lock: another coroutine may have refreshed already
        now = time.time()
        if not force_refresh:
            cached = self._chatroom_member_cache.get(chatroom_id)
            if cached:
                ts, member_map = cached
                if now - float(ts or 0) < float(self._chatroom_member_cache_ttl_seconds or 0):
                    return self._sanitize_speaker_name(member_map.get(user_wxid, ""))

        try:
            # The member list can be large; bound the wait so message handling
            # is not blocked indefinitely
            members = await asyncio.wait_for(bot.get_chatroom_members(chatroom_id), timeout=8)
        except Exception as e:
            logger.debug(f"获取群成员列表失败: {chatroom_id}, {e}")
            return ""

        member_map = {}
        try:
            for m in members or []:
                wxid = (m.get("wxid") or "").strip()
                if not wxid:
                    continue
                # Field name varies by protocol implementation
                display_name = m.get("display_name") or m.get("displayName") or ""
                member_map[wxid] = str(display_name or "").strip()
        except Exception as e:
            logger.debug(f"解析群成员列表失败: {chatroom_id}, {e}")

        # Cache even a partially parsed map to avoid hammering the API
        self._chatroom_member_cache[chatroom_id] = (time.time(), member_map)
        return self._sanitize_speaker_name(member_map.get(user_wxid, ""))
|
||
|
||
async def _get_user_display_label(self, bot, from_wxid: str, user_wxid: str, is_group: bool) -> str:
    """History label for a speaker: chatroom card first, then WeChat nickname, then wxid."""
    if not is_group:
        return ""
    nickname = await self._get_user_nickname(bot, from_wxid, user_wxid, is_group)
    card = await self._get_group_display_name(bot, from_wxid, user_wxid)
    combined = self._combine_display_and_nickname(card, nickname)
    return combined or nickname or user_wxid
|
||
|
||
async def _get_user_nickname(self, bot, from_wxid: str, user_wxid: str, is_group: bool) -> str:
    """Resolve a group member's WeChat nickname, preferring the Redis cache.

    Resolution order: Redis cache -> chatroom API (result is cached) ->
    MessageLogger database -> the raw wxid as a last resort.

    Args:
        bot: WechatHookClient instance
        from_wxid: message origin (chatroom id or private-chat user id)
        user_wxid: the member's wxid
        is_group: whether the message came from a group chat

    Returns:
        The nickname, or "" for private chats (no lookup needed there).
    """
    if not is_group:
        return ""

    nickname = ""

    # 1. Redis cache first
    redis_cache = get_cache()
    if redis_cache and redis_cache.enabled:
        cached_info = redis_cache.get_user_basic_info(from_wxid, user_wxid)
        if cached_info and cached_info.get("nickname"):
            logger.debug(f"[缓存命中] 用户昵称: {user_wxid} -> {cached_info['nickname']}")
            return cached_info["nickname"]

    # 2. Cache miss: ask the chatroom API
    try:
        user_info = await bot.get_user_info_in_chatroom(from_wxid, user_wxid)
        # Protocol wraps the nickname as {"nickName": {"string": ...}}
        if user_info and user_info.get("nickName", {}).get("string"):
            nickname = user_info["nickName"]["string"]
            # Store in cache for next time
            if redis_cache and redis_cache.enabled:
                redis_cache.set_user_info(from_wxid, user_wxid, user_info)
                logger.debug(f"[已缓存] 用户昵称: {user_wxid} -> {nickname}")
            return nickname
    except Exception as e:
        logger.warning(f"API获取用户昵称失败: {e}")

    # 3. Fall back to the MessageLogger database (latest non-empty nickname)
    if not nickname:
        try:
            from plugins.MessageLogger.main import MessageLogger
            msg_logger = MessageLogger.get_instance()
            if msg_logger:
                with msg_logger.get_db_connection() as conn:
                    with conn.cursor() as cursor:
                        cursor.execute(
                            "SELECT nickname FROM messages WHERE sender_wxid = %s AND nickname != '' ORDER BY create_time DESC LIMIT 1",
                            (user_wxid,)
                        )
                        result = cursor.fetchone()
                        if result:
                            nickname = result[0]
        except Exception as e:
            logger.debug(f"从数据库获取昵称失败: {e}")

    # 4. Last resort: use the wxid itself
    if not nickname:
        nickname = user_wxid or "未知用户"

    return nickname
|
||
|
||
def _check_rate_limit(self, user_wxid: str) -> tuple:
    """Check whether a user is within the AI-chat rate limit.

    Args:
        user_wxid: the user's wxid

    Returns:
        (allowed, remaining_calls, seconds_until_reset). Fails open — when
        rate limiting is disabled or Redis is unavailable, returns (True, 999, 0).
    """
    cfg = self.config.get("rate_limit", {})
    if not cfg.get("enabled", True):
        return (True, 999, 0)

    cache = get_cache()
    if not (cache and cache.enabled):
        # Without Redis we cannot count requests, so don't limit
        return (True, 999, 0)

    limit = cfg.get("ai_chat_limit", 20)
    window = cfg.get("ai_chat_window", 60)
    return cache.check_rate_limit(user_wxid, limit, window, "ai_chat")
|
||
|
||
def _add_to_memory(self, chat_id: str, role: str, content, image_base64: str = None):
|
||
"""
|
||
添加消息到记忆
|
||
|
||
Args:
|
||
chat_id: 会话ID
|
||
role: 角色 (user/assistant)
|
||
content: 消息内容(可以是字符串或列表)
|
||
image_base64: 可选的图片base64数据
|
||
"""
|
||
if not self.store:
|
||
return
|
||
self.store.add_private_message(chat_id, role, content, image_base64=image_base64)
|
||
|
||
def _get_memory_messages(self, chat_id: str) -> list:
|
||
"""获取记忆中的消息"""
|
||
if not self.store:
|
||
return []
|
||
return self.store.get_private_messages(chat_id)
|
||
|
||
def _clear_memory(self, chat_id: str):
|
||
"""清空指定会话的记忆"""
|
||
if not self.store:
|
||
return
|
||
self.store.clear_private_messages(chat_id)
|
||
|
||
async def _download_and_encode_image(self, bot, cdnurl: str, aeskey: str) -> str:
    """Download an image via CDN and return it as a base64 data URL.

    Checks the Redis media cache first; on a miss, downloads through the bot's
    CDN API to a temp file, encodes it, caches the result for 5 minutes and
    removes the temp file. Returns "" on any failure.
    """
    try:
        # 1. Try the Redis cache first
        from utils.redis_cache import RedisCache
        redis_cache = get_cache()
        # Initialize so the cache-write check below never sees an unbound name
        media_key = None
        if redis_cache and redis_cache.enabled:
            media_key = RedisCache.generate_media_key(cdnurl, aeskey)
            if media_key:
                cached_data = redis_cache.get_cached_media(media_key, "image")
                if cached_data:
                    logger.debug(f"[缓存命中] 图片从 Redis 获取: {media_key[:20]}...")
                    return cached_data

        # 2. Cache miss: download the image
        logger.debug(f"[缓存未命中] 开始下载图片...")
        temp_dir = Path(__file__).parent / "temp"
        temp_dir.mkdir(exist_ok=True)

        filename = f"temp_{uuid.uuid4().hex[:8]}.jpg"
        save_path = str((temp_dir / filename).resolve())

        # Try HD (file_type=2) first, then fall back to normal quality
        success = await bot.cdn_download(cdnurl, aeskey, save_path, file_type=2)
        if not success:
            success = await bot.cdn_download(cdnurl, aeskey, save_path, file_type=1)
        if not success:
            return ""

        # Wait for the file to be fully written (up to 10 seconds)
        import os
        for _ in range(20):
            if os.path.exists(save_path) and os.path.getsize(save_path) > 0:
                break
            await asyncio.sleep(0.5)

        if not os.path.exists(save_path):
            return ""

        with open(save_path, "rb") as f:
            image_data = base64.b64encode(f.read()).decode()

        base64_result = f"data:image/jpeg;base64,{image_data}"

        # 3. Cache for subsequent requests
        if redis_cache and redis_cache.enabled and media_key:
            redis_cache.cache_media(media_key, base64_result, "image", ttl=300)
            logger.debug(f"[已缓存] 图片缓存到 Redis: {media_key[:20]}...")

        # Best-effort temp-file cleanup
        try:
            Path(save_path).unlink()
        except OSError:
            pass

        return base64_result
    except Exception as e:
        logger.error(f"下载图片失败: {e}")
        return ""
|
||
|
||
async def _download_emoji_and_encode(self, cdn_url: str, max_retries: int = 3) -> str:
    """Download a sticker/emoji over plain HTTP and return it as a base64 data URL.

    Checks the Redis media cache first; on a miss, downloads with retries and
    linear backoff (optionally through the configured proxy), caches the
    result for 5 minutes and returns it. Returns "" when all retries fail.
    """
    # Decode the HTML entity the XML payload uses for "&" in CDN URLs.
    # (The previous replace("&", "&") was a no-op and left URLs broken.)
    cdn_url = cdn_url.replace("&amp;", "&")

    # 1. Try the Redis cache first
    from utils.redis_cache import RedisCache
    redis_cache = get_cache()
    media_key = RedisCache.generate_media_key(cdnurl=cdn_url)
    if redis_cache and redis_cache.enabled and media_key:
        cached_data = redis_cache.get_cached_media(media_key, "emoji")
        if cached_data:
            logger.debug(f"[缓存命中] 表情包从 Redis 获取: {media_key[:20]}...")
            return cached_data

    # 2. Cache miss: download with retries (in-memory; no temp file needed)
    logger.debug(f"[缓存未命中] 开始下载表情包...")
    last_error = None

    for attempt in range(max_retries):
        try:
            # Grow the timeout on every retry
            timeout = aiohttp.ClientTimeout(total=30 + attempt * 15)

            # Optional proxy support
            connector = None
            proxy_config = self.config.get("proxy", {})
            if proxy_config.get("enabled", False):
                proxy_type = proxy_config.get("type", "socks5").upper()
                proxy_host = proxy_config.get("host", "127.0.0.1")
                proxy_port = proxy_config.get("port", 7890)
                proxy_username = proxy_config.get("username")
                proxy_password = proxy_config.get("password")

                if proxy_username and proxy_password:
                    proxy_url = f"{proxy_type}://{proxy_username}:{proxy_password}@{proxy_host}:{proxy_port}"
                else:
                    proxy_url = f"{proxy_type}://{proxy_host}:{proxy_port}"

                if PROXY_SUPPORT:
                    try:
                        connector = ProxyConnector.from_url(proxy_url)
                    except Exception:
                        # Bad proxy config: fall back to a direct connection
                        connector = None

            async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
                async with session.get(cdn_url) as response:
                    if response.status == 200:
                        content = await response.read()

                        if len(content) == 0:
                            logger.warning(f"表情包下载内容为空,重试 {attempt + 1}/{max_retries}")
                            continue

                        image_data = base64.b64encode(content).decode()
                        logger.debug(f"表情包下载成功,大小: {len(content)} 字节")
                        base64_result = f"data:image/gif;base64,{image_data}"

                        # 3. Cache for subsequent requests
                        if redis_cache and redis_cache.enabled and media_key:
                            redis_cache.cache_media(media_key, base64_result, "emoji", ttl=300)
                            logger.debug(f"[已缓存] 表情包缓存到 Redis: {media_key[:20]}...")

                        return base64_result
                    else:
                        logger.warning(f"表情包下载失败,状态码: {response.status},重试 {attempt + 1}/{max_retries}")

        except asyncio.TimeoutError:
            last_error = "请求超时"
            logger.warning(f"表情包下载超时,重试 {attempt + 1}/{max_retries}")
        except aiohttp.ClientError as e:
            last_error = str(e)
            logger.warning(f"表情包下载网络错误: {e},重试 {attempt + 1}/{max_retries}")
        except Exception as e:
            last_error = str(e)
            logger.warning(f"表情包下载异常: {e},重试 {attempt + 1}/{max_retries}")

        # Back off before the next retry (grows linearly)
        if attempt < max_retries - 1:
            await asyncio.sleep(1 * (attempt + 1))

    logger.error(f"表情包下载失败,已重试 {max_retries} 次: {last_error}")
    return ""
|
||
|
||
async def _generate_image_description(self, image_base64: str, prompt: str, config: dict) -> str:
    """Generate a textual description of an image with the vision model.

    Sends a multimodal chat request (streamed) and concatenates the delta
    chunks; retries on network errors with linear backoff.

    Args:
        image_base64: the image as a base64 data URL
        prompt: the description prompt
        config: the image-description section of the plugin config

    Returns:
        The description text, or "" after exhausting all retries.
    """
    api_config = self.config["api"]
    # The description model may differ from the main chat model
    description_model = config.get("model", api_config["model"])

    # Build a single multimodal user message (text + image)
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {"type": "image_url", "image_url": {"url": image_base64}}
            ]
        }
    ]

    payload = {
        "model": description_model,
        "messages": messages,
        "max_tokens": config.get("max_tokens", 1000),
        "stream": True
    }

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_config['api_key']}"
    }

    max_retries = int(config.get("retries", 2))
    last_error = None

    for attempt in range(max_retries + 1):
        try:
            timeout = aiohttp.ClientTimeout(total=api_config["timeout"])

            # Build a fresh proxy connector per attempt (connectors are
            # bound to a session and cannot be reused after close)
            connector = None
            proxy_config = self.config.get("proxy", {})
            if proxy_config.get("enabled", False):
                proxy_type = proxy_config.get("type", "socks5").upper()
                proxy_host = proxy_config.get("host", "127.0.0.1")
                proxy_port = proxy_config.get("port", 7890)
                proxy_username = proxy_config.get("username")
                proxy_password = proxy_config.get("password")

                if proxy_username and proxy_password:
                    proxy_url = f"{proxy_type}://{proxy_username}:{proxy_password}@{proxy_host}:{proxy_port}"
                else:
                    proxy_url = f"{proxy_type}://{proxy_host}:{proxy_port}"

                if PROXY_SUPPORT:
                    try:
                        connector = ProxyConnector.from_url(proxy_url)
                    except Exception as e:
                        logger.warning(f"代理配置失败,将直连: {e}")
                        connector = None

            async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
                async with session.post(
                    api_config["url"],
                    json=payload,
                    headers=headers
                ) as resp:
                    if resp.status != 200:
                        error_text = await resp.text()
                        raise Exception(f"图片描述 API 返回错误: {resp.status}, {error_text[:200]}")

                    # Consume the SSE stream, accumulating content deltas
                    description = ""
                    async for line in resp.content:
                        line = line.decode('utf-8').strip()
                        if not line or line == "data: [DONE]":
                            continue

                        if line.startswith("data: "):
                            try:
                                data = json.loads(line[6:])
                                delta = data.get("choices", [{}])[0].get("delta", {})
                                content = delta.get("content", "")
                                if content:
                                    description += content
                            except Exception:
                                # Malformed chunks are skipped, not fatal
                                pass

                    logger.debug(f"图片描述生成成功: {description}")
                    return description.strip()

        except asyncio.CancelledError:
            # Never swallow task cancellation
            raise
        except (aiohttp.ClientError, asyncio.TimeoutError) as e:
            last_error = str(e)
            if attempt < max_retries:
                logger.warning(f"图片描述网络错误: {e},重试 {attempt + 1}/{max_retries}")
                await asyncio.sleep(1 * (attempt + 1))
                continue
        except Exception as e:
            last_error = str(e)
            if attempt < max_retries:
                logger.warning(f"图片描述生成异常: {e},重试 {attempt + 1}/{max_retries}")
                await asyncio.sleep(1 * (attempt + 1))
                continue

    logger.error(f"生成图片描述失败,已重试 {max_retries + 1} 次: {last_error}")
    return ""
|
||
|
||
def _collect_tools_with_plugins(self) -> dict:
    """Gather every plugin's LLM tools, keeping each tool's owning plugin name."""
    from utils.plugin_manager import PluginManager
    return collect_tools_with_plugins(self.config.get("tools", {}), PluginManager().plugins)
|
||
|
||
def _collect_tools(self):
    """Gather every plugin's LLM tools (honors the whitelist/blacklist filters)."""
    from utils.plugin_manager import PluginManager
    return collect_tools(self.config.get("tools", {}), PluginManager().plugins)
|
||
|
||
def _get_tool_schema_map(self, tools_map: dict | None = None) -> dict:
    """Map tool names to argument schemas, collecting tools when none are supplied."""
    effective = tools_map if tools_map else self._collect_tools_with_plugins()
    return get_tool_schema_map(effective)
|
||
|
||
def _validate_tool_arguments(self, tool_name: str, arguments: dict, schema: dict) -> tuple:
    """Lightweight argument validation that also fills in defaults (delegates to llm_tooling)."""
    return validate_tool_arguments(tool_name, arguments, schema)
|
||
|
||
async def _handle_list_prompts(self, bot, from_wxid: str):
    """Handle the "list personas" command.

    Sends the available .txt prompt files to the chat, marking the persona
    currently configured with a check mark.
    """
    try:
        prompts_dir = Path(__file__).parent / "prompts"

        if not prompts_dir.exists():
            await bot.send_text(from_wxid, "❌ prompts 目录不存在")
            return

        txt_files = sorted(prompts_dir.glob("*.txt"))

        if not txt_files:
            await bot.send_text(from_wxid, "❌ 没有找到任何人设文件")
            return

        # Build the listing, marking the active persona
        current_file = self.config["prompt"]["system_prompt_file"]
        msg = "📋 可用人设列表:\n\n"

        for i, file_path in enumerate(txt_files, 1):
            filename = file_path.name
            # Fix: the list previously printed a literal placeholder instead
            # of the file name, leaving `filename` unused
            if filename == current_file:
                msg += f"{i}. {filename} ✅\n"
            else:
                msg += f"{i}. {filename}\n"

        msg += f"\n💡 使用方法:/切人设 文件名.txt"

        await bot.send_text(from_wxid, msg)
        logger.info("已发送人设列表")

    except Exception as e:
        logger.error(f"获取人设列表失败: {e}")
        await bot.send_text(from_wxid, f"❌ 获取人设列表失败: {str(e)}")
|
||
|
||
def _estimate_tokens(self, text: str) -> int:
|
||
"""
|
||
估算文本的 token 数量
|
||
|
||
简单估算规则:
|
||
- 中文:约 1.5 字符 = 1 token
|
||
- 英文:约 4 字符 = 1 token
|
||
- 混合文本取平均
|
||
"""
|
||
if not text:
|
||
return 0
|
||
|
||
# 统计中文字符数
|
||
chinese_chars = sum(1 for c in text if '\u4e00' <= c <= '\u9fff')
|
||
# 其他字符数
|
||
other_chars = len(text) - chinese_chars
|
||
|
||
# 估算 token 数
|
||
chinese_tokens = chinese_chars / 1.5
|
||
other_tokens = other_chars / 4
|
||
|
||
return int(chinese_tokens + other_tokens)
|
||
|
||
def _estimate_message_tokens(self, message: dict) -> int:
|
||
"""估算单条消息的 token 数"""
|
||
content = message.get("content", "")
|
||
|
||
if isinstance(content, str):
|
||
return self._estimate_tokens(content)
|
||
elif isinstance(content, list):
|
||
# 多模态消息
|
||
total = 0
|
||
for item in content:
|
||
if item.get("type") == "text":
|
||
total += self._estimate_tokens(item.get("text", ""))
|
||
elif item.get("type") == "image_url":
|
||
# 图片按 85 token 估算(OpenAI 低分辨率图片)
|
||
total += 85
|
||
return total
|
||
return 0
|
||
|
||
def _extract_text_from_multimodal(self, content) -> str:
|
||
"""从多模态 content 中提取文本,模型不支持时用于降级"""
|
||
if isinstance(content, list):
|
||
texts = [item.get("text", "") for item in content if item.get("type") == "text"]
|
||
text = "".join(texts).strip()
|
||
return text or "[图片]"
|
||
if content is None:
|
||
return ""
|
||
return str(content)
|
||
|
||
def _sanitize_llm_output(self, text) -> str:
    """Clean LLM output before it reaches the user or the stored context.

    Goals: no chain-of-thought, no Markdown. The system prompt alone is not
    a hard guarantee, so every output path goes through this post-processing
    pipeline: markdown-strip -> thinking-strip -> repeat -> fallbacks
    (answer-marker extraction, outline extraction, last clean paragraph),
    and finally discard the text entirely if it still looks like reasoning.
    """
    if text is None:
        return ""
    raw = str(text)
    cleaned = raw

    output_cfg = (self.config or {}).get("output", {})
    strip_thinking = output_cfg.get("strip_thinking", True)
    strip_markdown = output_cfg.get("strip_markdown", True)

    # Strip Markdown first: wrappers like "**思考过程:**/### 思考" would
    # otherwise hide the thinking markers from the next pass
    if strip_markdown:
        cleaned = self._strip_markdown_syntax(cleaned)

    if strip_thinking:
        cleaned = self._strip_thinking_content(cleaned)

    # Second round: some models write thinking/final markers as Markdown,
    # or markers only become visible after tags are stripped
    if strip_markdown:
        cleaned = self._strip_markdown_syntax(cleaned)
    if strip_thinking:
        cleaned = self._strip_thinking_content(cleaned)

    cleaned = cleaned.strip()
    # Fallback: if obvious chain-of-thought/outline markers survive, try to
    # pull out just the final-answer section
    if strip_thinking and cleaned and self._contains_thinking_markers(cleaned):
        extracted = self._extract_after_last_answer_marker(cleaned)
        if not extracted:
            extracted = self._extract_final_answer_from_outline(cleaned)
        if extracted:
            cleaned = extracted.strip()
        # Markers still present: prefer the last paragraph free of them
        if cleaned and self._contains_thinking_markers(cleaned):
            parts = [p.strip() for p in re.split(r"\n{2,}", cleaned) if p.strip()]
            for p in reversed(parts):
                if not self._contains_thinking_markers(p):
                    cleaned = p
                    break
            cleaned = cleaned.strip()

    # Last resort: still looks like chain-of-thought -> drop it entirely
    # (better to send nothing than to leak reasoning)
    if strip_thinking and cleaned and self._contains_thinking_markers(cleaned):
        return ""

    if cleaned:
        return cleaned

    raw_stripped = raw.strip()
    # Cleaning emptied the text: never fall back to an original that still
    # carries thinking markers (would send <think>... verbatim)
    if strip_thinking and self._contains_thinking_markers(raw_stripped):
        return ""
    return raw_stripped
|
||
|
||
def _contains_thinking_markers(self, text: str) -> bool:
|
||
"""粗略判断文本是否包含明显的“思考/推理”外显标记,用于决定是否允许回退原文。"""
|
||
if not text:
|
||
return False
|
||
|
||
lowered = text.lower()
|
||
tag_tokens = (
|
||
"<think", "</think",
|
||
"<analysis", "</analysis",
|
||
"<reasoning", "</reasoning",
|
||
"<thought", "</thought",
|
||
"<thinking", "</thinking",
|
||
"<thoughts", "</thoughts",
|
||
"<scratchpad", "</scratchpad",
|
||
"<think", "</think",
|
||
"<analysis", "</analysis",
|
||
"<reasoning", "</reasoning",
|
||
"<thought", "</thought",
|
||
"<thinking", "</thinking",
|
||
"<thoughts", "</thoughts",
|
||
"<scratchpad", "</scratchpad",
|
||
)
|
||
if any(tok in lowered for tok in tag_tokens):
|
||
return True
|
||
|
||
stripped = text.strip()
|
||
if stripped.startswith("{") and stripped.endswith("}"):
|
||
# JSON 结构化输出(常见于“analysis/final”)
|
||
json_keys = (
|
||
"\"analysis\"",
|
||
"\"reasoning\"",
|
||
"\"thought\"",
|
||
"\"thoughts\"",
|
||
"\"scratchpad\"",
|
||
"\"final\"",
|
||
"\"answer\"",
|
||
"\"response\"",
|
||
"\"output\"",
|
||
"\"text\"",
|
||
)
|
||
if any(k in lowered for k in json_keys):
|
||
return True
|
||
|
||
# YAML/KV 风格
|
||
if re.search(r"(?im)^\s*(analysis|reasoning|thoughts?|scratchpad|final|answer|response|output|text|思考|分析|推理|最终|输出)\s*[::]", text):
|
||
return True
|
||
|
||
marker_re = re.compile(
|
||
r"(?mi)^\s*(?:\d+\s*[\.\、::))\-–—]\s*)?(?:[-*•]+\s*)?"
|
||
r"(?:【\s*(?:思考过程|推理过程|分析过程|思考|分析|推理|内心独白|内心os|思维链|思路|"
|
||
r"chain\s*of\s*thought|reasoning|analysis|thinking|thoughts|thought\s*process|scratchpad)\s*】"
|
||
r"|(?:思考过程|推理过程|分析过程|思考|分析|推理|内心独白|内心os|思维链|思路|"
|
||
r"chain\s*of\s*thought|reasoning|analysis|analyze|thinking|thoughts|thought\s*process|scratchpad|internal\s*monologue|mind\s*space|final\s*polish|output\s*generation)"
|
||
r"(?:\s*】)?\s*(?:[::]|$|\s+))"
|
||
)
|
||
return marker_re.search(text) is not None
|
||
|
||
def _extract_after_last_answer_marker(self, text: str) -> str | None:
    """Extract the content after the last "final/output/answer" marker.

    Unlike the outline extractor, this does not require the text to be a
    numbered outline. Tries, in order: explicit line-start markers
    ("Text:", "Final Answer:", "输出:" ...), JSON/YAML-style key markers,
    and finally parsing the whole text as a JSON object with a final-ish key.
    Returns None when nothing can be extracted.
    """
    if not text:
        return None

    # 1) Explicit line-start markers (optionally numbered):
    #    Text:/Final Answer:/输出: ...
    marker_re = re.compile(
        r"(?im)^\s*(?:\d+\s*[\.\、::\)、))\-–—]\s*)?"
        r"(?:text|final\s*answer|final\s*response|final\s*output|final|output|answer|response|输出|最终回复|最终答案|最终)\s*[::]\s*"
    )
    matches = list(marker_re.finditer(text))
    if matches:
        # Everything after the LAST marker is taken as the answer
        candidate = text[matches[-1].end():].strip()
        if candidate:
            return candidate

    # 2) JSON/YAML key style: final: ... / "final": "..."
    kv_re = re.compile(
        r"(?im)^\s*\"?(?:final|answer|response|output|text|最终|最终回复|最终答案|输出)\"?\s*[::]\s*"
    )
    kv_matches = list(kv_re.finditer(text))
    if kv_matches:
        candidate = text[kv_matches[-1].end():].strip()
        if candidate:
            return candidate

    # 3) A bare JSON object: try to parse and pull a final-ish string value
    stripped = text.strip()
    if stripped.startswith("{") and stripped.endswith("}"):
        try:
            obj = json.loads(stripped)
            if isinstance(obj, dict):
                for key in ("final", "answer", "response", "output", "text"):
                    v = obj.get(key)
                    if isinstance(v, str) and v.strip():
                        return v.strip()
        except Exception:
            pass

    return None
|
||
|
||
def _extract_final_answer_from_outline(self, text: str) -> str | None:
|
||
"""从“分析/草稿/输出”这类结构化大纲中提取最终回复正文(用于拦截思维链)。"""
|
||
if not text:
|
||
return None
|
||
|
||
# 至少包含多个“1./2./3.”段落,才认为可能是大纲/思维链输出
|
||
heading_re = re.compile(r"(?m)^\s*\d+\s*[\.\、::\)、))\-–—]\s*\S+")
|
||
if len(heading_re.findall(text)) < 2:
|
||
return None
|
||
|
||
# 优先:提取最后一个 “Text:/Final Answer:/Output:” 之后的内容
|
||
marker_re = re.compile(
|
||
r"(?im)^\s*(?:\d+\s*[\.\、::\)、))\-–—]\s*)?"
|
||
r"(?:text|final\s*answer|final\s*response|final\s*output|output|answer|response|输出|最终回复|最终答案)\s*[::]\s*"
|
||
)
|
||
matches = list(marker_re.finditer(text))
|
||
if matches:
|
||
candidate = text[matches[-1].end():].strip()
|
||
if candidate:
|
||
return candidate
|
||
|
||
# 没有明确的最终标记时,仅在包含“分析/思考/草稿/输出”等元信息关键词的情况下兜底抽取
|
||
lowered = text.lower()
|
||
outline_keywords = (
|
||
"analyze",
|
||
"analysis",
|
||
"reasoning",
|
||
"internal monologue",
|
||
"mind space",
|
||
"draft",
|
||
"drafting",
|
||
"outline",
|
||
"plan",
|
||
"steps",
|
||
"formulating response",
|
||
"final polish",
|
||
"final answer",
|
||
"output generation",
|
||
"system prompt",
|
||
"chat log",
|
||
"previous turn",
|
||
"current situation",
|
||
)
|
||
cn_keywords = ("思考", "分析", "推理", "思维链", "草稿", "计划", "步骤", "输出", "最终")
|
||
if not any(k in lowered for k in outline_keywords) and not any(k in text for k in cn_keywords):
|
||
return None
|
||
|
||
# 次选:取最后一个非空段落(避免返回整段大纲)
|
||
parts = [p.strip() for p in re.split(r"\n{2,}", text) if p.strip()]
|
||
if not parts:
|
||
return None
|
||
|
||
last = parts[-1]
|
||
if len(heading_re.findall(last)) == 0:
|
||
return last
|
||
return None
|
||
|
||
def _strip_thinking_content(self, text: str) -> str:
    """Remove visible "thinking/reasoning" content from model output.

    Handles <think>...</think>-style tag blocks, "思考:.../最终:..." sectioned
    text, outline-style chain-of-thought, and <final>/<answer> wrappers.
    Returns the cleaned text (possibly unchanged), or "" for empty input.
    """
    if not text:
        return ""

    # Normalize all newline conventions to "\n".
    t = text.replace("\r\n", "\n").replace("\r", "\n")

    # 1) Remove explicit tag blocks (common with some reasoning models).
    thinking_tags = ("think", "analysis", "reasoning", "thought", "thinking", "thoughts", "scratchpad", "reflection")
    for tag in thinking_tags:
        t = re.sub(rf"<{tag}\b[^>]*>.*?</{tag}>", "", t, flags=re.IGNORECASE | re.DOTALL)
        # Second pass for "escaped" tags. NOTE(review): the `[^&]` character
        # class suggests this pattern originally matched HTML-entity-escaped
        # tags and may have been mangled by entity decoding — verify against
        # the canonical source.
        t = re.sub(rf"<{tag}\b[^&]*>.*?</{tag}>", "", t, flags=re.IGNORECASE | re.DOTALL)

    # 1.1) Fallback: streaming/truncation can leave an unclosed opening tag.
    # If one appears near the start (< 200 chars in), truncate from there.
    m = re.search(r"<(think|analysis|reasoning|thought|thinking|thoughts|scratchpad|reflection)\b[^>]*>", t, flags=re.IGNORECASE)
    if m and m.start() < 200:
        t = t[: m.start()].rstrip()
    # Same fallback for the "escaped"-tag variant (see NOTE above).
    m2 = re.search(r"<(think|analysis|reasoning|thought|thinking|thoughts|scratchpad|reflection)\b[^&]*>", t, flags=re.IGNORECASE)
    if m2 and m2.start() < 200:
        t = t[: m2.start()].rstrip()

    # 2) Handle "思考:.../最终:..." sectioned formats, stripping only the
    # leading reasoning portion where possible.
    lines = t.split("\n")
    if not lines:
        # Defensive guard; str.split("\n") never actually returns [].
        return t

    # If the text contains clear "final/output/answer" markers (numbered or
    # not), extract the last marked segment directly instead of emitting the
    # whole outline.
    if self._contains_thinking_markers(t):
        extracted_anywhere = self._extract_after_last_answer_marker(t)
        if extracted_anywhere:
            return extracted_anywhere

    # Keyword alternations for reasoning-section and answer-section markers.
    reasoning_kw = (
        r"思考过程|推理过程|分析过程|思考|分析|推理|思路|内心独白|内心os|思维链|"
        r"chain\s*of\s*thought|reasoning|analysis|analyze|thinking|thoughts|thought\s*process|scratchpad|plan|steps|draft|outline"
    )
    answer_kw = r"最终答案|最终回复|最终|回答|回复|答复|结论|输出|final(?:\s*answer)?|final\s*response|final\s*output|answer|response|output|text"

    # Accepts, per line:
    #   - 思考:... / 最终回复:...
    #   - 【思考】... / 【最终】...
    #   - "**思考过程:**" (Markdown markers are stripped by the caller first)
    # optionally preceded by numbering ("1.", "2)") and/or bullet markers.
    reasoning_start = re.compile(
        rf"^\s*(?:\d+\s*[\.\、::\)、))\-–—]\s*)?(?:[-*•]+\s*)?"
        rf"(?:【\s*(?:{reasoning_kw})\s*】\s*[::]?\s*|(?:{reasoning_kw})(?:\s*】)?\s*(?:[::]|$|\s+))",
        re.IGNORECASE,
    )
    answer_start = re.compile(
        rf"^\s*(?:\d+\s*[\.\、::\)、))\-–—]\s*)?(?:[-*•]+\s*)?"
        rf"(?:【\s*(?:{answer_kw})\s*】\s*[::]?\s*|(?:{answer_kw})(?:\s*】)?\s*(?:[::]|$)\s*)",
        re.IGNORECASE,
    )

    # 2.0) If the very first non-blank line is a "Final answer:"-style
    # marker, strip just that marker in place (does not require a
    # reasoning block to be present).
    for idx, line in enumerate(lines):
        if line.strip() == "":
            continue
        m0 = answer_start.match(line)
        if m0:
            lines[idx] = line[m0.end():].lstrip()
        break

    has_reasoning = any(reasoning_start.match(line) for line in lines[:10])
    has_answer_marker = any(answer_start.match(line) for line in lines)

    # 2.1) Both a reasoning block and an answer marker exist: drop lines from
    # the reasoning marker up to (but not including) the answer marker, and
    # strip the marker from the answer line itself.
    if has_reasoning and has_answer_marker:
        out_lines: list[str] = []
        skipping = False
        answer_started = False
        for line in lines:
            if answer_started:
                # Past the answer marker: keep everything verbatim.
                out_lines.append(line)
                continue

            if not skipping and reasoning_start.match(line):
                # Entering a reasoning block: start dropping lines.
                skipping = True
                continue

            if skipping:
                m = answer_start.match(line)
                if m:
                    # Reasoning block ends at the answer marker.
                    answer_started = True
                    skipping = False
                    out_lines.append(line[m.end():].lstrip())
                continue

            m = answer_start.match(line)
            if m:
                answer_started = True
                out_lines.append(line[m.end():].lstrip())
            else:
                out_lines.append(line)

        t2 = "\n".join(out_lines).strip()
        # If stripping removed everything, fall back to the original text.
        return t2 if t2 else t

    # 2.2) Fallback: text starts with a reasoning marker but has no answer
    # marker — drop the first paragraph (up to the first blank line).
    if has_reasoning:
        first_blank_idx = None
        for idx, line in enumerate(lines):
            if line.strip() == "":
                first_blank_idx = idx
                break
        if first_blank_idx is not None and first_blank_idx + 1 < len(lines):
            candidate = "\n".join(lines[first_blank_idx + 1 :]).strip()
            if candidate:
                return candidate

    # 2.3) Fallback: recognize "1. Analyze... 6. Output ... Text: ..."-style
    # chain-of-thought outlines and extract the final body.
    outline_extracted = self._extract_final_answer_from_outline("\n".join(lines).strip())
    if outline_extracted:
        return outline_extracted

    # Merge the per-line edits back (e.g. a stripped leading "最终回复:" marker).
    t = "\n".join(lines).strip()

    # 3) Unwrap <final>...</final> / <answer>...</answer> (keep the body,
    # drop the tags).
    t = re.sub(r"</?\s*(final|answer)\s*>", "", t, flags=re.IGNORECASE).strip()

    return t
|
||
|
||
def _strip_markdown_syntax(self, text: str) -> str:
|
||
"""将常见 Markdown 标记转换为更像纯文本的形式(保留内容,移除格式符)。"""
|
||
if not text:
|
||
return ""
|
||
|
||
t = text.replace("\r\n", "\n").replace("\r", "\n")
|
||
|
||
# 去掉代码块围栏(保留内容)
|
||
t = re.sub(r"```[^\n]*\n", "", t)
|
||
t = t.replace("```", "")
|
||
|
||
# 图片/链接: / [text](url)
|
||
def _md_image_repl(m: re.Match) -> str:
|
||
alt = (m.group(1) or "").strip()
|
||
url = (m.group(2) or "").strip()
|
||
if alt and url:
|
||
return f"{alt}({url})"
|
||
return url or alt or ""
|
||
|
||
def _md_link_repl(m: re.Match) -> str:
|
||
label = (m.group(1) or "").strip()
|
||
url = (m.group(2) or "").strip()
|
||
if label and url:
|
||
return f"{label}({url})"
|
||
return url or label or ""
|
||
|
||
t = re.sub(r"!\[([^\]]*)\]\(([^)]+)\)", _md_image_repl, t)
|
||
t = re.sub(r"\[([^\]]+)\]\(([^)]+)\)", _md_link_repl, t)
|
||
|
||
# 行级标记:标题、引用、分割线
|
||
cleaned_lines: list[str] = []
|
||
for line in t.split("\n"):
|
||
line = re.sub(r"^\s{0,3}#{1,6}\s+", "", line) # 标题
|
||
line = re.sub(r"^\s{0,3}>\s?", "", line) # 引用
|
||
if re.match(r"^\s*(?:-{3,}|\*{3,}|_{3,})\s*$", line):
|
||
continue # 分割线整行移除
|
||
cleaned_lines.append(line)
|
||
t = "\n".join(cleaned_lines)
|
||
|
||
# 行内代码:`code`
|
||
t = re.sub(r"`([^`]+)`", r"\1", t)
|
||
|
||
# 粗体/删除线(保留文本)
|
||
t = t.replace("**", "")
|
||
t = t.replace("__", "")
|
||
t = t.replace("~~", "")
|
||
|
||
# 斜体(保留文本,避免误伤乘法/通配符,仅处理成对包裹)
|
||
t = re.sub(r"(?<!\*)\*([^*\n]+)\*(?!\*)", r"\1", t)
|
||
t = re.sub(r"(?<!_)_([^_\n]+)_(?!_)", r"\1", t)
|
||
|
||
# 压缩过多空行
|
||
t = re.sub(r"\n{3,}", "\n\n", t)
|
||
return t.strip()
|
||
|
||
def _append_group_history_messages(self, messages: list, recent_history: list):
|
||
"""将群聊历史按 role 追加到 LLM messages"""
|
||
for msg in recent_history:
|
||
role = msg.get("role") or "user"
|
||
msg_nickname = msg.get("nickname", "")
|
||
msg_content = msg.get("content", "")
|
||
|
||
# 机器人历史回复
|
||
if role == "assistant":
|
||
if isinstance(msg_content, list):
|
||
msg_content = self._extract_text_from_multimodal(msg_content)
|
||
# 避免旧历史中的 Markdown/思维链污染上下文
|
||
msg_content = self._sanitize_llm_output(msg_content)
|
||
messages.append({
|
||
"role": "assistant",
|
||
"content": msg_content
|
||
})
|
||
continue
|
||
|
||
# 用户历史消息
|
||
if isinstance(msg_content, list):
|
||
content_with_nickname = []
|
||
for item in msg_content:
|
||
if item.get("type") == "text":
|
||
content_with_nickname.append({
|
||
"type": "text",
|
||
"text": f"[{msg_nickname}] {item.get('text', '')}"
|
||
})
|
||
else:
|
||
content_with_nickname.append(item)
|
||
|
||
messages.append({
|
||
"role": "user",
|
||
"content": content_with_nickname
|
||
})
|
||
else:
|
||
messages.append({
|
||
"role": "user",
|
||
"content": f"[{msg_nickname}] {msg_content}"
|
||
})
|
||
|
||
def _get_bot_nickname(self) -> str:
|
||
try:
|
||
with open("main_config.toml", "rb") as f:
|
||
main_config = tomllib.load(f)
|
||
nickname = main_config.get("Bot", {}).get("nickname", "")
|
||
return nickname or "机器人"
|
||
except Exception:
|
||
return "机器人"
|
||
|
||
def _tool_call_to_action_text(self, function_name: str, arguments: dict) -> str:
|
||
args = arguments if isinstance(arguments, dict) else {}
|
||
|
||
if function_name == "query_weather":
|
||
city = str(args.get("city") or "").strip()
|
||
return f"查询{city}天气" if city else "查询天气"
|
||
|
||
if function_name == "register_city":
|
||
city = str(args.get("city") or "").strip()
|
||
return f"注册城市{city}" if city else "注册城市"
|
||
|
||
if function_name == "user_signin":
|
||
return "签到"
|
||
|
||
if function_name == "check_profile":
|
||
return "查询个人信息"
|
||
|
||
return f"执行{function_name}"
|
||
|
||
def _build_tool_calls_context_note(self, tool_calls_data: list) -> str:
|
||
actions: list[str] = []
|
||
for tool_call in tool_calls_data or []:
|
||
function_name = tool_call.get("function", {}).get("name", "")
|
||
if not function_name:
|
||
continue
|
||
|
||
arguments_str = tool_call.get("function", {}).get("arguments", "{}")
|
||
try:
|
||
arguments = json.loads(arguments_str) if arguments_str else {}
|
||
except Exception:
|
||
arguments = {}
|
||
|
||
actions.append(self._tool_call_to_action_text(function_name, arguments))
|
||
|
||
if not actions:
|
||
return "(已触发工具处理:上一条请求。结果将发送到聊天中。)"
|
||
|
||
return f"(已触发工具处理:{';'.join(actions)}。结果将发送到聊天中。)"
|
||
|
||
async def _record_tool_calls_to_context(
    self,
    tool_calls_data: list,
    *,
    from_wxid: str,
    chat_id: str,
    is_group: bool,
    user_wxid: str | None = None,
):
    """Record a short "tool was triggered" note into conversation context.

    Writes the note to the rolling in-memory conversation (when chat_id is
    known) and, for group chats, also appends it as an assistant entry to the
    persisted group history so later turns can see a tool handled the request.
    """
    note = self._build_tool_calls_context_note(tool_calls_data)
    # Rolling per-chat memory used when building LLM context.
    if chat_id:
        self._add_to_memory(chat_id, "assistant", note)

    # Group chats additionally persist the note in the group history store.
    if is_group and from_wxid:
        history_chat_id = self._get_group_history_chat_id(from_wxid, user_wxid or "")
        await self._add_to_history(history_chat_id, self._get_bot_nickname(), note, role="assistant", sender_wxid=user_wxid or None)
|
||
|
||
def _extract_tool_intent_text(self, user_message: str, tool_query: str | None = None) -> str:
|
||
text = tool_query if tool_query is not None else user_message
|
||
text = str(text or "").strip()
|
||
if not text:
|
||
return ""
|
||
|
||
# 对“聊天记录/视频”等组合消息,尽量只取用户真实提问部分,避免历史文本触发工具误判
|
||
markers = (
|
||
"[用户的问题]:",
|
||
"[用户的问题]:",
|
||
"[用户的问题]\n",
|
||
"[用户的问题]",
|
||
)
|
||
for marker in markers:
|
||
if marker in text:
|
||
text = text.rsplit(marker, 1)[-1].strip()
|
||
return text
|
||
|
||
def _select_tools_for_message(self, tools: list, *, user_message: str, tool_query: str | None = None) -> list:
    """Narrow the exposed tool list to what the message plausibly asks for.

    When `tools.smart_select` is enabled, only tools whose intent keywords
    match the user's text are exposed to the model ("strict" mode: no match
    means no tools at all). Otherwise the full list passes through unchanged.

    Fix: the weather pattern previously used `pm2\\.?5` inside a raw string,
    which makes the regex require a literal backslash and never match
    "pm2.5"/"pm25"; it is now `pm2\.?5`.
    """
    tools_config = (self.config or {}).get("tools", {})
    if not tools_config.get("smart_select", False):
        return tools

    intent_text = self._extract_tool_intent_text(user_message, tool_query=tool_query)
    if not intent_text:
        return tools

    t = intent_text.lower()
    allow: set[str] = set()

    # Weather queries.
    if re.search(r"(天气|气温|温度|下雨|下雪|风力|空气质量|pm2\.?5|湿度|预报)", t):
        allow.add("query_weather")

    # Register/update a city (require an explicit verb so a bare city name
    # does not trigger it).
    if re.search(r"(注册|设置|更新|更换|修改|绑定|默认).{0,6}城市|城市.{0,6}(注册|设置|更新|更换|修改|绑定|默认)", t):
        allow.add("register_city")

    # Check-in / profile.
    if re.search(r"(用户签到|签到|签个到)", t):
        allow.add("user_signin")
    if re.search(r"(个人信息|我的信息|我的积分|查积分|积分多少|连续签到|连签|我的资料)", t):
        allow.add("check_profile")

    # "Deer" check-in feature.
    if re.search(r"(鹿打卡|鹿签到)", t):
        allow.add("deer_checkin")
    if re.search(r"(补签|补打卡)", t):
        allow.add("makeup_checkin")
    if re.search(r"(鹿.*(日历|月历|打卡日历))|((日历|月历|打卡日历).*鹿)", t):
        allow.add("view_calendar")

    # Web search / news.
    if re.search(r"(联网|搜索|搜一下|搜一搜|搜搜|帮我搜|搜新闻|搜资料|查资料|查新闻|查价格)", t):
        # Keep both the legacy tool name and the current plugin implementation.
        allow.add("tavily_web_search")
        allow.add("web_search")
    # Implicit lookup: asking about a concrete entity / reputation without
    # explicitly saying "search".
    if re.search(r"(怎么样|如何|评价|口碑|靠谱吗|值不值得|值得吗|好不好|推荐|牛不牛|强不强|厉不厉害|有名吗|什么来头|背景|近况|最新|最近)", t) and re.search(
        r"(公会|战队|服务器|区服|游戏|公司|品牌|店|商家|产品|软件|插件|项目|平台|up主|主播|作者|电影|电视剧|小说|手游|网游)",
        t,
    ):
        allow.add("tavily_web_search")
        allow.add("web_search")
    if re.search(r"(60秒|每日新闻|早报|新闻图片|读懂世界)", t):
        allow.add("get_daily_news")
    if re.search(r"(epic|喜加一|免费游戏)", t):
        allow.add("get_epic_free_games")

    # Music / short drama search.
    if re.search(r"(搜歌|找歌|点歌|来一首|歌名|歌曲|音乐|听.*歌|播放.*歌)", t) or ("歌" in t and re.search(r"(搜|找|点|来一首|播放|听)", t)):
        allow.add("search_music")
    if re.search(r"(短剧|搜短剧|找短剧)", t):
        allow.add("search_playlet")

    # Group-chat summaries.
    if re.search(r"(群聊总结|生成总结|总结一下|今日总结|昨天总结|群总结)", t):
        allow.add("generate_summary")

    # Entertainment.
    if re.search(r"(疯狂星期四|v我50|kfc)", t):
        allow.add("get_kfc")
    # "Fabing" copypasta: only on an explicit request, so a user's catchphrase
    # or emotional venting does not trigger the tool.
    if re.search(r"(发病文学|犯病文学|发病文|犯病文|发病语录|犯病语录)", t):
        allow.add("get_fabing")
    elif re.search(r"(来|整|给|写|讲|说|发|搞|整点).{0,4}(发病|犯病)", t):
        allow.add("get_fabing")
    elif re.search(r"(发病|犯病).{0,6}(一下|一段|一条|几句|文学|文|语录|段子)", t):
        allow.add("get_fabing")
    if re.search(r"(随机图片|来张图|来个图|随机图)", t):
        allow.add("get_random_image")
    if re.search(r"(随机视频|来个视频|随机短视频)", t):
        allow.add("get_random_video")

    # Image / video generation (only when explicitly requested).
    if (
        # Explicit drawing verbs/modes.
        re.search(r"(画一张|画一个|画个|画一下|画图|绘图|绘制|作画|出图|生成图片|文生图|图生图|以图生图)", t)
        # "generate/make/give me" + counter + "image" (e.g. 生成一张瑞依).
        or re.search(r"(生成|做|给我|帮我).{0,4}(一张|一幅|一个|张|个).{0,8}(图|图片|照片)", t)
        # "来/发" + counter + "图/图片" (e.g. 来张瑞依的图).
        or re.search(r"(来|发).{0,2}(一张|一幅|一个|张|个).{0,10}(图|图片|照片)", t)
        # Visual requests that never say "draw" (e.g. 看看腿/白丝).
        or re.search(r"(看看|看下|给我看|让我看看).{0,8}(腿|白丝|黑丝|丝袜|玉足|脚|足|写真|涩图|色图|福利图)", t)
    ):
        allow.update({
            "nano_ai_image_generation",
            "flow2_ai_image_generation",
            "jimeng_ai_image_generation",
            "kiira2_ai_image_generation",
            "generate_image",
        })
    if re.search(r"(生成视频|做个视频|视频生成|sora)", t):
        allow.add("sora_video_generation")

    # If a domain-specific tool already matched (weather/music/short drama)
    # and the user did not explicitly ask for the web, drop web search to
    # avoid accidental triggering.
    explicit_web = bool(re.search(r"(联网|网页|网站|网址|链接|来源)", t))
    if not explicit_web and {"query_weather", "search_music", "search_playlet"} & allow:
        allow.discard("tavily_web_search")
        allow.discard("web_search")

    # Strict mode: with no clear tool intent, expose no tools at all.
    if not allow:
        return []

    # Preserve the original tool ordering while filtering by the allow-set.
    selected = []
    for tool in tools or []:
        name = tool.get("function", {}).get("name", "")
        if name and name in allow:
            selected.append(tool)
    return selected
|
||
|
||
async def _handle_context_stats(self, bot, from_wxid: str, user_wxid: str, is_group: bool):
    """Handle the "/context" / "/上下文" command: report context-window usage.

    Estimates token usage of the system prompt, persistent memories and the
    visible conversation context, then sends a formatted summary to the chat.
    """
    try:
        chat_id = self._get_chat_id(from_wxid, user_wxid, is_group)

        # Estimate tokens consumed by persistent memories (group: keyed by
        # room id; private: keyed by user id).
        memory_chat_id = from_wxid if is_group else user_wxid
        persistent_memories = self._get_persistent_memories(memory_chat_id) if memory_chat_id else []
        persistent_tokens = 0
        if persistent_memories:
            persistent_tokens += self._estimate_tokens("【持久记忆】以下是用户要求你记住的重要信息:\n")
            for m in persistent_memories:
                mem_time = m['time'][:10] if m['time'] else ""
                persistent_tokens += self._estimate_tokens(f"- [{mem_time}] {m['nickname']}: {m['content']}\n")

        if is_group:
            # Group chats use the persisted history mechanism.
            history_chat_id = self._get_group_history_chat_id(from_wxid, user_wxid)
            history = await self._load_history(history_chat_id)
            history = self._filter_history_by_window(history)
            max_context = self.config.get("history", {}).get("max_context", 50)

            # The slice of history actually sent to the model.
            context_messages = history[-max_context:] if len(history) > max_context else history

            # Token estimate for the visible context.
            context_tokens = 0
            for msg in context_messages:
                msg_content = msg.get("content", "")
                nickname = msg.get("nickname", "")

                if isinstance(msg_content, list):
                    # Multimodal entry: text parts are counted normally;
                    # each image is charged a flat 85-token estimate.
                    for item in msg_content:
                        if item.get("type") == "text":
                            context_tokens += self._estimate_tokens(f"[{nickname}] {item.get('text', '')}")
                        elif item.get("type") == "image_url":
                            context_tokens += 85
                else:
                    context_tokens += self._estimate_tokens(f"[{nickname}] {msg_content}")

            # Add the system prompt on top.
            system_tokens = self._estimate_tokens(self.system_prompt)
            total_tokens = system_tokens + persistent_tokens + context_tokens

            # Usage relative to the configured context limit.
            context_limit = self.config.get("api", {}).get("context_limit", 200000)
            usage_percent = (total_tokens / context_limit) * 100
            remaining_tokens = context_limit - total_tokens

            msg = f"📊 群聊上下文统计\n\n"
            msg += f"💬 历史总条数: {len(history)}\n"
            msg += f"📤 AI可见条数: {len(context_messages)}/{max_context}\n"
            msg += f"🤖 人设 Token: ~{system_tokens}\n"
            msg += f"📌 持久记忆: {len(persistent_memories)} 条 (~{persistent_tokens} token)\n"
            msg += f"📝 上下文 Token: ~{context_tokens}\n"
            msg += f"📦 总计 Token: ~{total_tokens}\n"
            msg += f"📈 使用率: {usage_percent:.1f}% (剩余 ~{remaining_tokens:,})\n"
            msg += f"\n💡 /清空记忆 清空上下文 | /记忆列表 查看持久记忆"

        else:
            # Private chats use the in-memory conversation mechanism.
            memory_messages = self._get_memory_messages(chat_id)
            max_messages = self.config.get("memory", {}).get("max_messages", 20)

            # Token estimate for the remembered messages.
            context_tokens = 0
            for msg in memory_messages:
                context_tokens += self._estimate_message_tokens(msg)

            # Add the system prompt on top.
            system_tokens = self._estimate_tokens(self.system_prompt)
            total_tokens = system_tokens + persistent_tokens + context_tokens

            # Usage relative to the configured context limit.
            context_limit = self.config.get("api", {}).get("context_limit", 200000)
            usage_percent = (total_tokens / context_limit) * 100
            remaining_tokens = context_limit - total_tokens

            msg = f"📊 私聊上下文统计\n\n"
            msg += f"💬 记忆条数: {len(memory_messages)}/{max_messages}\n"
            msg += f"🤖 人设 Token: ~{system_tokens}\n"
            msg += f"📌 持久记忆: {len(persistent_memories)} 条 (~{persistent_tokens} token)\n"
            msg += f"📝 上下文 Token: ~{context_tokens}\n"
            msg += f"📦 总计 Token: ~{total_tokens}\n"
            msg += f"📈 使用率: {usage_percent:.1f}% (剩余 ~{remaining_tokens:,})\n"
            msg += f"\n💡 /清空记忆 清空上下文 | /记忆列表 查看持久记忆"

        await bot.send_text(from_wxid, msg)
        logger.info(f"已发送上下文统计: {chat_id}")

    except Exception as e:
        logger.error(f"获取上下文统计失败: {e}")
        await bot.send_text(from_wxid, f"❌ 获取上下文统计失败: {str(e)}")
|
||
|
||
async def _handle_switch_prompt(self, bot, from_wxid: str, content: str):
    """Handle the admin "switch persona" command ("/切人设 <file>").

    Loads the named prompt file from the plugin's prompts/ directory and
    swaps it in as the active system prompt.

    Fix: the feedback/log messages contained a corrupted literal "(unknown)"
    where the filename placeholder belongs; restored `{filename}`.
    """
    try:
        # Extract the target filename from "/切人设 文件名.txt".
        parts = content.split(maxsplit=1)
        if len(parts) < 2:
            await bot.send_text(from_wxid, "❌ 请指定人设文件名\n格式:/切人设 文件名.txt")
            return

        filename = parts[1].strip()

        # Verify the prompt file exists before touching current state.
        prompt_path = Path(__file__).parent / "prompts" / filename
        if not prompt_path.exists():
            await bot.send_text(from_wxid, f"❌ 人设文件不存在: {filename}")
            return

        # Load the new persona prompt.
        with open(prompt_path, "r", encoding="utf-8") as f:
            new_prompt = f.read().strip()

        # Swap the active prompt and remember the choice in config.
        self.system_prompt = new_prompt
        self.config["prompt"]["system_prompt_file"] = filename

        await bot.send_text(from_wxid, f"✅ 已切换人设: {filename}")
        logger.success(f"管理员切换人设: {filename}")

    except Exception as e:
        logger.error(f"切换人设失败: {e}")
        await bot.send_text(from_wxid, f"❌ 切换人设失败: {str(e)}")
|
||
|
||
@on_text_message(priority=80)
async def handle_message(self, bot, message: dict):
    """Main text-message entry point.

    First routes slash commands (persona, memory, history, stats); otherwise
    decides whether to reply, records group history, rate-limits, calls the
    AI API with retries, and sends the sanitized reply.
    Returning False stops further plugin processing for command messages.
    """
    content = message.get("Content", "").strip()
    from_wxid = message.get("FromWxid", "")
    sender_wxid = message.get("SenderWxid", "")
    is_group = message.get("IsGroup", False)

    # The actual speaker: in groups the sender, in private chats the peer.
    user_wxid = sender_wxid if is_group else from_wxid

    # Load bot wxid and admin list from the main config.
    import tomllib
    with open("main_config.toml", "rb") as f:
        main_config = tomllib.load(f)
    bot_wxid = main_config.get("Bot", {}).get("wxid", "")
    admins = main_config.get("Bot", {}).get("admins", [])

    # Persona list command (exact match).
    if content == "/人设列表":
        await self._handle_list_prompts(bot, from_wxid)
        return False

    # Nickname test: report the global WeChat nickname and the in-group
    # display name (group card).
    if content == "/昵称测试":
        if not is_group:
            await bot.send_text(from_wxid, "该指令仅支持群聊:/昵称测试")
            return False

        wechat_nickname = await self._get_user_nickname(bot, from_wxid, user_wxid, is_group)
        group_nickname = await self._get_group_display_name(bot, from_wxid, user_wxid, force_refresh=True)

        wechat_nickname = self._sanitize_speaker_name(wechat_nickname) or "(未获取到)"
        group_nickname = self._sanitize_speaker_name(group_nickname) or "(未设置/未获取到)"

        await bot.send_text(
            from_wxid,
            f"微信昵称: {wechat_nickname}\n"
            f"群昵称: {group_nickname}",
        )
        return False

    # Persona switch command (admin only, exact prefix match).
    if content.startswith("/切人设 ") or content.startswith("/切换人设 "):
        if user_wxid in admins:
            await self._handle_switch_prompt(bot, from_wxid, content)
        else:
            await bot.send_text(from_wxid, "❌ 仅管理员可以切换人设")
        return False

    # Clear-memory command (command text is configurable).
    clear_command = self.config.get("memory", {}).get("clear_command", "/清空记忆")
    if content == clear_command:
        chat_id = self._get_chat_id(from_wxid, user_wxid, is_group)
        self._clear_memory(chat_id)
        await bot.send_text(from_wxid, "✅ 已清空当前会话的记忆")
        return False

    # Context-usage statistics command.
    if content == "/context" or content == "/上下文":
        await self._handle_context_stats(bot, from_wxid, user_wxid, is_group)
        return False

    # Scan for legacy group-history keys (admin only).
    if content in ("/旧群历史", "/legacy_history"):
        if user_wxid in admins and self.store:
            legacy_keys = self.store.find_legacy_group_history_keys()
            if legacy_keys:
                await bot.send_text(
                    from_wxid,
                    f"⚠️ 检测到 {len(legacy_keys)} 个旧版群历史 key(safe_id 写入)。\n"
                    f"如需清理请发送 /清理旧群历史",
                )
            else:
                await bot.send_text(from_wxid, "✅ 未发现旧版群历史 key")
        else:
            await bot.send_text(from_wxid, "❌ 仅管理员可执行该指令")
        return False

    # Purge legacy group-history keys (admin only).
    if content in ("/清理旧群历史", "/clean_legacy_history"):
        if user_wxid in admins and self.store:
            legacy_keys = self.store.find_legacy_group_history_keys()
            deleted = self.store.delete_legacy_group_history_keys(legacy_keys)
            await bot.send_text(
                from_wxid,
                f"✅ 已清理旧版群历史 key: {deleted} 个",
            )
        else:
            await bot.send_text(from_wxid, "❌ 仅管理员可执行该指令")
        return False

    # Memory status command (admin only).
    if content == "/记忆状态":
        if user_wxid in admins:
            if is_group:
                history_chat_id = self._get_group_history_chat_id(from_wxid, user_wxid)
                history = await self._load_history(history_chat_id)
                history = self._filter_history_by_window(history)
                max_context = self.config.get("history", {}).get("max_context", 50)
                context_count = min(len(history), max_context)
                msg = f"📊 群聊记忆: {len(history)} 条\n"
                msg += f"💬 AI可见: 最近 {context_count} 条"
                await bot.send_text(from_wxid, msg)
            else:
                chat_id = self._get_chat_id(from_wxid, user_wxid, is_group)
                memory = self._get_memory_messages(chat_id)
                msg = f"📊 私聊记忆: {len(memory)} 条"
                await bot.send_text(from_wxid, msg)
        else:
            await bot.send_text(from_wxid, "❌ 仅管理员可以查看记忆状态")
        return False

    # Persistent-memory commands.
    # Record a persistent memory: "/记录 xxx".
    if content.startswith("/记录 "):
        memory_content = content[4:].strip()
        if memory_content:
            nickname = await self._get_user_display_label(bot, from_wxid, user_wxid, is_group)
            # Groups key memories by room id, private chats by user id.
            memory_chat_id = from_wxid if is_group else user_wxid
            chat_type = "group" if is_group else "private"
            memory_id = self._add_persistent_memory(
                memory_chat_id, chat_type, user_wxid, nickname, memory_content
            )
            await bot.send_text(from_wxid, f"✅ 已记录到持久记忆 (ID: {memory_id})")
            logger.info(f"添加持久记忆: {memory_chat_id} - {memory_content[:30]}...")
        else:
            await bot.send_text(from_wxid, "❌ 请输入要记录的内容\n格式:/记录 要记住的内容")
        return False

    # List persistent memories (available to everyone).
    if content == "/记忆列表" or content == "/持久记忆":
        memory_chat_id = from_wxid if is_group else user_wxid
        memories = self._get_persistent_memories(memory_chat_id)
        if memories:
            msg = f"📋 持久记忆列表 (共 {len(memories)} 条)\n\n"
            for m in memories:
                time_str = m['time'][:16] if m['time'] else "未知"
                content_preview = m['content'][:30] + "..." if len(m['content']) > 30 else m['content']
                msg += f"[{m['id']}] {m['nickname']}: {content_preview}\n 📅 {time_str}\n"
            msg += f"\n💡 删除记忆:/删除记忆 ID (管理员)"
        else:
            msg = "📋 暂无持久记忆"
        await bot.send_text(from_wxid, msg)
        return False

    # Delete one persistent memory by id (admin only).
    if content.startswith("/删除记忆 "):
        if user_wxid in admins:
            try:
                memory_id = int(content[6:].strip())
                memory_chat_id = from_wxid if is_group else user_wxid
                if self._delete_persistent_memory(memory_chat_id, memory_id):
                    await bot.send_text(from_wxid, f"✅ 已删除记忆 ID: {memory_id}")
                else:
                    await bot.send_text(from_wxid, f"❌ 未找到记忆 ID: {memory_id}")
            except ValueError:
                await bot.send_text(from_wxid, "❌ 请输入有效的记忆ID\n格式:/删除记忆 ID")
        else:
            await bot.send_text(from_wxid, "❌ 仅管理员可以删除持久记忆")
        return False

    # Clear all persistent memories (admin only).
    if content == "/清空持久记忆":
        if user_wxid in admins:
            memory_chat_id = from_wxid if is_group else user_wxid
            deleted_count = self._clear_persistent_memories(memory_chat_id)
            await bot.send_text(from_wxid, f"✅ 已清空 {deleted_count} 条持久记忆")
        else:
            await bot.send_text(from_wxid, "❌ 仅管理员可以清空持久记忆")
        return False

    # Decide whether this message should get an AI reply.
    should_reply = self._should_reply(message, content, bot_wxid)

    # Display label for history entries (cached lookup).
    nickname = await self._get_user_display_label(bot, from_wxid, user_wxid, is_group)

    # Strip the @-mention only when we will actually reply.
    actual_content = ""
    if should_reply:
        actual_content = self._extract_content(message, content)

    # Persist group history (by default every message; can be configured to
    # capture only AI-triggering messages to reduce context pollution).
    # Skip when AutoReply triggered this — the message was already stored.
    if is_group and not message.get('_auto_reply_triggered'):
        if self._should_capture_group_history(is_triggered=bool(should_reply)):
            # In mention mode, store the de-mentioned text so the same
            # sentence does not appear both with and without the @-prefix.
            trigger_mode = self.config.get("behavior", {}).get("trigger_mode", "mention")
            history_content = content
            if trigger_mode == "mention" and should_reply and actual_content:
                history_content = actual_content

            history_chat_id = self._get_group_history_chat_id(from_wxid, user_wxid)
            await self._add_to_history(history_chat_id, nickname, history_content, sender_wxid=user_wxid)

    # Nothing to do when no reply is wanted.
    if not should_reply:
        return

    # Rate limiting (only checked when a reply is due).
    allowed, remaining, reset_time = self._check_rate_limit(user_wxid)
    if not allowed:
        rate_limit_config = self.config.get("rate_limit", {})
        msg = rate_limit_config.get("rate_limit_message", "⚠️ 消息太频繁了,请 {seconds} 秒后再试~")
        msg = msg.format(seconds=reset_time)
        await bot.send_text(from_wxid, msg)
        logger.warning(f"用户 {user_wxid} 触发限流,{reset_time}秒后重置")
        return False

    if not actual_content:
        return

    logger.info(f"AI 处理消息: {actual_content[:50]}...")

    try:
        # Resolve the session id and store the user message in memory.
        chat_id = self._get_chat_id(from_wxid, user_wxid, is_group)
        # AutoReply-triggered messages were already added to memory upstream.
        if not message.get('_auto_reply_triggered'):
            self._add_to_memory(chat_id, "user", actual_content)

        # Group chats: if the message was already written to history, do not
        # append it to the LLM messages again (avoid sending it twice).
        history_enabled = bool(self.store) and self.config.get("history", {}).get("enabled", True)
        captured_to_history = bool(
            is_group
            and history_enabled
            and not message.get('_auto_reply_triggered')
            and self._should_capture_group_history(is_triggered=True)
        )
        append_user_message = not captured_to_history

        # Call the AI API with a retry loop.
        max_retries = self.config.get("api", {}).get("max_retries", 2)
        response = None
        last_error = None

        for attempt in range(max_retries + 1):
            try:
                response = await self._call_ai_api(
                    actual_content,
                    bot,
                    from_wxid,
                    chat_id,
                    nickname,
                    user_wxid,
                    is_group,
                    append_user_message=append_user_message,
                )

                # Return value semantics:
                # - None: a tool call was handled asynchronously — no retry.
                # - "":   genuinely empty response — retry.
                # - text: normal reply.
                if response is None:
                    # Tool call path; nothing more to send here.
                    logger.info("AI 触发工具调用,已异步处理")
                    break

                if response == "" and attempt < max_retries:
                    logger.warning(f"AI 返回空内容,重试 {attempt + 1}/{max_retries}")
                    await asyncio.sleep(1)  # brief pause before the retry
                    continue

                break  # success, or retries exhausted

            except Exception as e:
                last_error = e
                if attempt < max_retries:
                    logger.warning(f"AI API 调用失败,重试 {attempt + 1}/{max_retries}: {e}")
                    await asyncio.sleep(1)
                else:
                    raise

        # Send the reply and record it.
        # None/empty means the result was delivered some other way (tool call,
        # chat-record message) — do not send an extra text message.
        if response:
            cleaned_response = self._sanitize_llm_output(response)
            if cleaned_response:
                await bot.send_text(from_wxid, cleaned_response)
                self._add_to_memory(chat_id, "assistant", cleaned_response)
                # Persist the bot reply into group history unless a sync hook
                # already covers it (sync_bot_messages with a non-per-user scope).
                history_config = self.config.get("history", {})
                sync_bot_messages = history_config.get("sync_bot_messages", False)
                history_scope = str(history_config.get("scope", "chatroom") or "chatroom").strip().lower()
                can_rely_on_hook = bool(sync_bot_messages and history_scope not in ("per_user", "user", "peruser"))
                if is_group and not can_rely_on_hook:
                    with open("main_config.toml", "rb") as f:
                        main_config = tomllib.load(f)
                    bot_nickname = main_config.get("Bot", {}).get("nickname", "机器人")
                    history_chat_id = self._get_group_history_chat_id(from_wxid, user_wxid)
                    await self._add_to_history(
                        history_chat_id,
                        bot_nickname,
                        cleaned_response,
                        role="assistant",
                        sender_wxid=user_wxid,
                    )
                logger.success(f"AI 回复成功: {cleaned_response[:50]}...")
            else:
                logger.warning("AI 回复清洗后为空(可能只包含思维链/格式标记),已跳过发送")
        else:
            logger.info("AI 回复为空或已通过其他方式发送(如聊天记录)")

    except Exception as e:
        import traceback
        error_detail = traceback.format_exc()
        logger.error(f"AI 处理失败: {type(e).__name__}: {str(e)}")
        logger.error(f"详细错误:\n{error_detail}")
        await bot.send_text(from_wxid, "抱歉,我遇到了一些问题,请稍后再试。")
|
||
|
||
def _should_reply(self, message: dict, content: str, bot_wxid: str = None) -> bool:
|
||
"""判断是否应该回复"""
|
||
# 检查是否由AutoReply插件触发
|
||
if message.get('_auto_reply_triggered'):
|
||
return True
|
||
|
||
is_group = message.get("IsGroup", False)
|
||
|
||
# 检查群聊/私聊开关
|
||
if is_group and not self.config["behavior"].get("reply_group", True):
|
||
return False
|
||
if not is_group and not self.config["behavior"].get("reply_private", True):
|
||
return False
|
||
|
||
trigger_mode = self.config["behavior"].get("trigger_mode", "mention")
|
||
|
||
# all 模式:回复所有消息
|
||
if trigger_mode == "all":
|
||
return True
|
||
|
||
# mention 模式:检查是否@了机器人
|
||
if trigger_mode == "mention":
|
||
if is_group:
|
||
ats = message.get("Ats", [])
|
||
|
||
# 如果没有 bot_wxid,从配置文件读取
|
||
if not bot_wxid:
|
||
import tomllib
|
||
with open("main_config.toml", "rb") as f:
|
||
main_config = tomllib.load(f)
|
||
bot_wxid = main_config.get("Bot", {}).get("wxid", "")
|
||
bot_nickname = main_config.get("Bot", {}).get("nickname", "")
|
||
else:
|
||
# 也需要读取昵称用于备用检测
|
||
import tomllib
|
||
with open("main_config.toml", "rb") as f:
|
||
main_config = tomllib.load(f)
|
||
bot_nickname = main_config.get("Bot", {}).get("nickname", "")
|
||
|
||
# 方式1:检查 @ 列表中是否包含机器人的 wxid
|
||
if ats and bot_wxid and bot_wxid in ats:
|
||
return True
|
||
|
||
# 方式2:备用检测 - 从消息内容中检查是否包含 @机器人昵称
|
||
# (当 API 没有返回 at_user_list 时使用)
|
||
if bot_nickname and f"@{bot_nickname}" in content:
|
||
logger.debug(f"通过内容检测到 @{bot_nickname},触发回复")
|
||
return True
|
||
|
||
return False
|
||
else:
|
||
# 私聊直接回复
|
||
return True
|
||
|
||
# keyword 模式:检查关键词
|
||
if trigger_mode == "keyword":
|
||
keywords = self.config["behavior"]["keywords"]
|
||
return any(kw in content for kw in keywords)
|
||
|
||
return False
|
||
|
||
def _extract_content(self, message: dict, content: str) -> str:
|
||
"""提取实际消息内容(去除@等)"""
|
||
is_group = message.get("IsGroup", False)
|
||
|
||
if is_group:
|
||
# 群聊消息,去除@部分
|
||
# 格式通常是 "@昵称 消息内容"
|
||
parts = content.split(maxsplit=1)
|
||
if len(parts) > 1 and parts[0].startswith("@"):
|
||
return parts[1].strip()
|
||
return content.strip()
|
||
|
||
return content.strip()
|
||
|
||
async def _call_ai_api(
    self,
    user_message: str,
    bot=None,
    from_wxid: str = None,
    chat_id: str = None,
    nickname: str = "",
    user_wxid: str = None,
    is_group: bool = False,
    *,
    append_user_message: bool = True,
    tool_query: str | None = None,
) -> str:
    """Call the chat-completions API with streaming enabled.

    Builds the system prompt (persona + current time + persistent
    memories), attaches conversation context (group JSON history or
    private in-memory history), selects LLM tools for this message and
    streams the response.

    Returns the sanitized reply text, or None when tool calls were
    detected (they are then executed asynchronously and the follow-up
    reply is delivered by the tool pipeline).

    Raises:
        Exception: on non-200 API status, network failure, timeout or a
            malformed response (KeyError in config/response access).
    """
    api_config = self.config["api"]

    # Collect all available tools, then narrow to the ones relevant for
    # this particular message / tool query.
    all_tools = self._collect_tools()
    tools = self._select_tools_for_message(all_tools, user_message=user_message, tool_query=tool_query)
    logger.info(f"收集到 {len(all_tools)} 个工具函数,本次启用 {len(tools)} 个")
    if tools:
        tool_names = [t["function"]["name"] for t in tools]
        logger.info(f"本次启用工具: {tool_names}")

    # Build the message list, starting from the persona system prompt.
    system_content = self.system_prompt

    # Append the current wall-clock time (with Chinese weekday name) so
    # the model can answer time-sensitive questions.
    current_time = datetime.now()
    weekday_map = {
        0: "星期一", 1: "星期二", 2: "星期三", 3: "星期四",
        4: "星期五", 5: "星期六", 6: "星期日"
    }
    weekday = weekday_map[current_time.weekday()]
    time_str = current_time.strftime(f"%Y年%m月%d日 %H:%M:%S {weekday}")
    system_content += f"\n\n当前时间:{time_str}"

    if nickname:
        system_content += f"\n当前对话用户的昵称是:{nickname}"

    # Load persistent memories for this chat (group id in groups, user
    # id in private chats) and inline them into the system prompt.
    memory_chat_id = from_wxid if is_group else user_wxid
    if memory_chat_id:
        persistent_memories = self._get_persistent_memories(memory_chat_id)
        if persistent_memories:
            system_content += "\n\n【持久记忆】以下是用户要求你记住的重要信息:\n"
            for m in persistent_memories:
                # Keep only the date part (YYYY-MM-DD) of the timestamp.
                mem_time = m['time'][:10] if m['time'] else ""
                system_content += f"- [{mem_time}] {m['nickname']}: {m['content']}\n"

    messages = [{"role": "system", "content": system_content}]

    # Context loading: groups use the JSON history store, private chats
    # use the legacy in-memory mechanism.
    if is_group and from_wxid:
        history_chat_id = self._get_group_history_chat_id(from_wxid, user_wxid or "")
        history = await self._load_history(history_chat_id)
        history = self._filter_history_by_window(history)
        max_context = self.config.get("history", {}).get("max_context", 50)

        # Keep only the most recent max_context entries as context.
        recent_history = history[-max_context:] if len(history) > max_context else history

        # Convert the entries into role-tagged chat messages.
        self._append_group_history_messages(messages, recent_history)
    else:
        # Private chat: reuse the existing memory mechanism. The last
        # entry is dropped because the current message is appended below.
        if chat_id:
            memory_messages = self._get_memory_messages(chat_id)
            if memory_messages and len(memory_messages) > 1:
                messages.extend(memory_messages[:-1])

    # Append the current user message unless it was already captured in
    # the group history above (caller controls this via the flag).
    if append_user_message:
        messages.append({"role": "user", "content": f"[{nickname}] {user_message}" if is_group and nickname else user_message})

    payload = {
        "model": api_config["model"],
        "messages": messages,
        "max_tokens": api_config.get("max_tokens", 4096)  # avoid truncated replies
    }

    if tools:
        payload["tools"] = tools
        logger.debug(f"已将 {len(tools)} 个工具添加到请求中")

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_config['api_key']}"
    }

    timeout = aiohttp.ClientTimeout(total=api_config["timeout"])

    # Optional proxy support via aiohttp_socks ProxyConnector.
    connector = None
    proxy_config = self.config.get("proxy", {})
    if proxy_config.get("enabled", False):
        proxy_type = proxy_config.get("type", "socks5").upper()
        proxy_host = proxy_config.get("host", "127.0.0.1")
        proxy_port = proxy_config.get("port", 7890)
        proxy_username = proxy_config.get("username")
        proxy_password = proxy_config.get("password")

        # Build the proxy URL, with credentials when both are provided.
        if proxy_username and proxy_password:
            proxy_url = f"{proxy_type}://{proxy_username}:{proxy_password}@{proxy_host}:{proxy_port}"
        else:
            proxy_url = f"{proxy_type}://{proxy_host}:{proxy_port}"

        if PROXY_SUPPORT:
            try:
                connector = ProxyConnector.from_url(proxy_url)
                logger.debug(f"使用代理: {proxy_type}://{proxy_host}:{proxy_port}")
            except Exception as e:
                # Fall back to a direct connection on bad proxy config.
                logger.warning(f"代理配置失败,将直连: {e}")
                connector = None
        else:
            logger.warning("代理功能不可用(aiohttp_socks 未安装),将直连")
            connector = None

    # Always request a streaming (SSE) response.
    payload["stream"] = True

    try:
        async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
            logger.debug(f"发送流式 API 请求: {api_config['url']}")
            async with session.post(
                api_config["url"],
                json=payload,
                headers=headers
            ) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    logger.error(f"API 返回错误状态码: {resp.status}, 响应: {error_text}")
                    raise Exception(f"API 错误 {resp.status}: {error_text}")

                # Consume the SSE stream line by line.
                import json
                full_content = ""
                tool_calls_dict = {}  # incremental tool-call assembly, keyed by index
                tool_call_hint_sent = False  # whether the pre-tool text was already sent

                async for line in resp.content:
                    line = line.decode('utf-8').strip()
                    if not line or line == "data: [DONE]":
                        continue

                    if line.startswith("data: "):
                        try:
                            data = json.loads(line[6:])
                            choices = data.get("choices", [])
                            if not choices:
                                continue

                            delta = choices[0].get("delta", {})

                            # Accumulate streamed text content.
                            content = delta.get("content", "")
                            if content:
                                full_content += content

                            # Assemble tool calls from incremental deltas.
                            if delta.get("tool_calls"):
                                # On the first tool-call delta, flush any
                                # text the model produced before calling
                                # tools, so the user sees it immediately.
                                if not tool_call_hint_sent and bot and from_wxid:
                                    tool_call_hint_sent = True
                                    # Only send when the model actually produced text.
                                    if full_content and full_content.strip():
                                        preview = self._sanitize_llm_output(full_content)
                                        if preview:
                                            logger.info(f"[流式] 检测到工具调用,先发送已有文本: {preview[:30]}...")
                                            await bot.send_text(from_wxid, preview)
                                        else:
                                            logger.info("[流式] 检测到工具调用,但文本清洗后为空(可能为思维链/无有效正文),跳过发送")
                                    else:
                                        # No text before the tool call; stay silent.
                                        logger.info("[流式] 检测到工具调用,AI 未输出文本")

                                for tool_call_delta in delta["tool_calls"]:
                                    index = tool_call_delta.get("index", 0)

                                    # First delta for this index: create the skeleton.
                                    if index not in tool_calls_dict:
                                        tool_calls_dict[index] = {
                                            "id": "",
                                            "type": "function",
                                            "function": {
                                                "name": "",
                                                "arguments": ""
                                            }
                                        }

                                    # id arrives whole, not incrementally.
                                    if "id" in tool_call_delta:
                                        tool_calls_dict[index]["id"] = tool_call_delta["id"]

                                    if "type" in tool_call_delta:
                                        tool_calls_dict[index]["type"] = tool_call_delta["type"]

                                    # name and arguments arrive in fragments
                                    # and must be concatenated.
                                    if "function" in tool_call_delta:
                                        func_delta = tool_call_delta["function"]
                                        if "name" in func_delta:
                                            tool_calls_dict[index]["function"]["name"] += func_delta["name"]
                                        if "arguments" in func_delta:
                                            tool_calls_dict[index]["function"]["arguments"] += func_delta["arguments"]
                        except Exception as e:
                            # A single malformed SSE line must not abort the stream.
                            logger.debug(f"解析流式数据失败: {e}")
                            pass

                # Flatten assembled tool calls, ordered by index.
                tool_calls_data = [tool_calls_dict[i] for i in sorted(tool_calls_dict.keys())] if tool_calls_dict else []

                logger.info(f"流式 API 响应完成, 内容长度: {len(full_content)}, 工具调用数: {len(tool_calls_data)}")

                # Tool calls present: hand them off to the async executor.
                if tool_calls_data:
                    # Any pre-tool text was already sent during streaming.
                    logger.info(f"启动异步工具执行,共 {len(tool_calls_data)} 个工具")
                    try:
                        await self._record_tool_calls_to_context(
                            tool_calls_data,
                            from_wxid=from_wxid,
                            chat_id=chat_id,
                            is_group=is_group,
                            user_wxid=user_wxid,
                        )
                    except Exception as e:
                        logger.debug(f"记录工具调用到上下文失败: {e}")
                    asyncio.create_task(
                        self._execute_tools_async(
                            tool_calls_data, bot, from_wxid, chat_id,
                            user_wxid, nickname, is_group, messages
                        )
                    )
                    # None signals "handled asynchronously, do not retry".
                    return None

                # Guard against models that emit pseudo-tool-call markup
                # as plain text instead of real tool calls.
                if "<tool_code>" in full_content or "print(" in full_content and "flow2_ai_image_generation" in full_content:
                    logger.warning("检测到模型输出了错误的工具调用格式,拦截并返回提示")
                    return "抱歉,我遇到了一些技术问题,请重新描述一下你的需求~"

                return self._sanitize_llm_output(full_content)
    except aiohttp.ClientError as e:
        logger.error(f"网络请求失败: {type(e).__name__}: {str(e)}")
        raise Exception(f"网络请求失败: {str(e)}")
    except asyncio.TimeoutError:
        logger.error(f"API 请求超时 (timeout={api_config['timeout']}s)")
        raise Exception(f"API 请求超时")
    except KeyError as e:
        logger.error(f"API 响应格式错误,缺少字段: {e}")
        raise Exception(f"API 响应格式错误: {e}")
|
||
|
||
|
||
async def _load_history(self, chat_id: str) -> list:
|
||
"""异步读取群聊历史(委托 ContextStore)"""
|
||
if not self.store:
|
||
return []
|
||
return await self.store.load_group_history(chat_id)
|
||
|
||
async def _add_to_history(
|
||
self,
|
||
chat_id: str,
|
||
nickname: str,
|
||
content: str,
|
||
image_base64: str = None,
|
||
*,
|
||
role: str = "user",
|
||
sender_wxid: str = None,
|
||
):
|
||
"""将消息存入群聊历史(委托 ContextStore)"""
|
||
if not self.store:
|
||
return
|
||
await self.store.add_group_message(
|
||
chat_id,
|
||
nickname,
|
||
content,
|
||
image_base64=image_base64,
|
||
role=role,
|
||
sender_wxid=sender_wxid,
|
||
)
|
||
|
||
async def _add_to_history_with_id(
|
||
self,
|
||
chat_id: str,
|
||
nickname: str,
|
||
content: str,
|
||
record_id: str,
|
||
*,
|
||
role: str = "user",
|
||
sender_wxid: str = None,
|
||
):
|
||
"""带ID的历史追加, 便于后续更新(委托 ContextStore)"""
|
||
if not self.store:
|
||
return
|
||
await self.store.add_group_message(
|
||
chat_id,
|
||
nickname,
|
||
content,
|
||
record_id=record_id,
|
||
role=role,
|
||
sender_wxid=sender_wxid,
|
||
)
|
||
|
||
async def _update_history_by_id(self, chat_id: str, record_id: str, new_content: str):
|
||
"""根据ID更新历史记录(委托 ContextStore)"""
|
||
if not self.store:
|
||
return
|
||
await self.store.update_group_message_by_id(chat_id, record_id, new_content)
|
||
|
||
|
||
async def _execute_tool_and_get_result(
    self,
    tool_name: str,
    arguments: dict,
    bot,
    from_wxid: str,
    user_wxid: str = None,
    is_group: bool = False,
    tools_map: dict | None = None,
):
    """Execute one LLM tool call by dispatching to the owning plugin.

    Dispatch order: (1) direct call to the plugin named in *tools_map*
    (built from get_llm_tools), (2) fallback scan over every loaded
    plugin that exposes execute_llm_tool. The first non-None result
    wins; a plugin raising an exception is logged and skipped.

    Mutates *arguments* in place, injecting "user_wxid" and "is_group"
    before dispatch.

    Returns:
        A dict with at least a "success" key (non-dict plugin results
        are wrapped as {"success": True, "message": str(raw)}), or a
        {"success": False, ...} dict when no plugin handled the tool.
    """
    from utils.plugin_manager import PluginManager

    # Enrich the arguments with caller identity before dispatch.
    arguments["user_wxid"] = user_wxid or from_wxid
    arguments["is_group"] = bool(is_group)

    logger.info(f"开始执行工具: {tool_name}")

    plugins = PluginManager().plugins
    logger.info(f"检查 {len(plugins)} 个插件")

    async def _normalize_result(raw, plugin_name: str):
        # Coerce a plugin return value into the canonical result dict.
        # None means "this plugin did not handle the tool" and lets the
        # caller keep searching.
        if raw is None:
            return None

        if not isinstance(raw, dict):
            raw = {"success": True, "message": str(raw)}
        else:
            raw.setdefault("success", True)

        if raw.get("success"):
            logger.success(f"工具执行成功: {tool_name} ({plugin_name})")
        else:
            logger.warning(f"工具执行失败: {tool_name} ({plugin_name})")
        return raw

    # Fast path: go straight to the plugin that declared this tool
    # (mapping comes from get_llm_tools).
    if tools_map and tool_name in tools_map:
        target_plugin_name, _tool_def = tools_map[tool_name]
        target_plugin = plugins.get(target_plugin_name)
        if target_plugin and hasattr(target_plugin, "execute_llm_tool"):
            try:
                logger.info(f"直达调用 {target_plugin_name}.execute_llm_tool")
                result = await target_plugin.execute_llm_tool(tool_name, arguments, bot, from_wxid)
                logger.info(f"{target_plugin_name} 返回: {result}")
                normalized = await _normalize_result(result, target_plugin_name)
                if normalized is not None:
                    return normalized
            except Exception as e:
                # Direct dispatch failed; fall through to the full scan.
                logger.error(f"工具执行异常 ({target_plugin_name}): {tool_name}, {e}")
                import traceback
                logger.error(f"详细错误: {traceback.format_exc()}")
        else:
            logger.warning(f"工具 {tool_name} 期望插件 {target_plugin_name} 不存在或不支持 execute_llm_tool,回退全量扫描")

    # Fallback: ask every plugin in turn until one handles the tool.
    for plugin_name, plugin in plugins.items():
        logger.debug(f"检查插件: {plugin_name}, 有execute_llm_tool: {hasattr(plugin, 'execute_llm_tool')}")
        if not hasattr(plugin, "execute_llm_tool"):
            continue

        try:
            logger.info(f"调用 {plugin_name}.execute_llm_tool")
            result = await plugin.execute_llm_tool(tool_name, arguments, bot, from_wxid)
            logger.info(f"{plugin_name} 返回: {result}")
            normalized = await _normalize_result(result, plugin_name)
            if normalized is not None:
                return normalized
        except Exception as e:
            # A broken plugin must not prevent others from being tried.
            logger.error(f"工具执行异常 ({plugin_name}): {tool_name}, {e}")
            import traceback
            logger.error(f"详细错误: {traceback.format_exc()}")

    logger.warning(f"未找到工具: {tool_name}")
    return {"success": False, "message": f"未找到工具: {tool_name}"}
|
||
|
||
async def _execute_tools_async(self, tool_calls_data: list, bot, from_wxid: str,
                               chat_id: str, user_wxid: str, nickname: str, is_group: bool,
                               messages: list):
    """
    Execute tool calls asynchronously, without blocking the main flow.

    The AI has already replied to the user; tools run here in the
    background and their results are delivered afterwards. Results
    flagged need_ai_reply are fed back to the AI so it can continue the
    conversation with full context and persona preserved.

    Args:
        tool_calls_data: Assembled tool-call dicts from the stream.
        bot: Messaging client used to send results/errors to the chat.
        from_wxid: Target chat id for outgoing messages.
        chat_id / user_wxid / nickname / is_group: Conversation identity,
            forwarded to tool execution and the follow-up AI call.
        messages: The message list of the originating AI call, reused
            when tool results are handed back to the AI.
    """
    try:
        logger.info(f"开始异步执行 {len(tool_calls_data)} 个工具调用")

        # Validate arguments and build one task per tool call.
        tasks = []
        tool_info_list = []  # bookkeeping for post-processing
        tools_map = self._collect_tools_with_plugins()
        schema_map = self._get_tool_schema_map(tools_map)

        for tool_call in tool_calls_data:
            function_name = tool_call.get("function", {}).get("name", "")
            arguments_str = tool_call.get("function", {}).get("arguments", "{}")
            tool_call_id = tool_call.get("id", "")

            if not function_name:
                continue

            try:
                arguments = json.loads(arguments_str)
            except Exception:
                arguments = {}

            schema = schema_map.get(function_name)
            ok, err, arguments = self._validate_tool_arguments(function_name, arguments, schema)
            if not ok:
                logger.warning(f"[异步] 工具 {function_name} 参数校验失败: {err}")
                try:
                    await bot.send_text(from_wxid, f"❌ 工具 {function_name} 参数错误: {err}")
                except Exception:
                    pass
                continue

            logger.info(f"[异步] 准备执行工具: {function_name}, 参数: {arguments}")

            task = self._execute_tool_and_get_result(
                function_name,
                arguments,
                bot,
                from_wxid,
                user_wxid=user_wxid,
                is_group=is_group,
                tools_map=tools_map,
            )
            tasks.append(task)
            tool_info_list.append({
                "tool_call_id": tool_call_id,
                "function_name": function_name,
                "arguments": arguments
            })

        # Run every tool concurrently; exceptions become result values.
        if tasks:
            results = await asyncio.gather(*tasks, return_exceptions=True)

            # Results that must be handed back to the AI for a reply.
            need_ai_reply_results = []

            for i, result in enumerate(results):
                tool_info = tool_info_list[i]
                function_name = tool_info["function_name"]
                tool_call_id = tool_info["tool_call_id"]

                if isinstance(result, Exception):
                    logger.error(f"[异步] 工具 {function_name} 执行异常: {result}")
                    try:
                        await bot.send_text(from_wxid, f"❌ {function_name} 执行失败: {result}")
                    except Exception:
                        pass
                    continue

                tool_result = ToolResult.from_raw(result)
                if not tool_result:
                    continue

                # Sanitize tool text once so chain-of-thought from inner
                # or downstream LLMs never leaks to the user.
                # (This statement was previously duplicated verbatim.)
                tool_message = self._sanitize_llm_output(tool_result.message) if tool_result.message is not None else ""

                if tool_result.success:
                    logger.success(f"[异步] 工具 {function_name} 执行成功")
                else:
                    logger.warning(f"[异步] 工具 {function_name} 执行失败")

                # Defer to the AI-continuation path when requested.
                if tool_result.need_ai_reply:
                    need_ai_reply_results.append({
                        "tool_call_id": tool_call_id,
                        "function_name": function_name,
                        "result": tool_message
                    })
                    continue

                # Successful tool: send its text unless suppressed.
                if tool_result.success and not tool_result.already_sent and tool_result.message and not tool_result.no_reply:
                    if tool_result.send_result_text:
                        if tool_message:
                            await bot.send_text(from_wxid, tool_message)
                        else:
                            logger.warning(f"[异步] 工具 {function_name} 输出清洗后为空,已跳过发送")

                # Failed tool: send an error notice by default.
                if not tool_result.success and tool_result.message and not tool_result.no_reply:
                    try:
                        if tool_message:
                            await bot.send_text(from_wxid, f"❌ {tool_message}")
                        else:
                            await bot.send_text(from_wxid, f"❌ {function_name} 执行失败")
                    except Exception:
                        pass

                # Optionally record the result in conversation memory.
                if tool_result.save_to_memory and chat_id:
                    if tool_message:
                        self._add_to_memory(chat_id, "assistant", f"[工具 {function_name} 结果]: {tool_message}")

            # Hand collected results back to the AI for a final reply.
            if need_ai_reply_results:
                await self._continue_with_tool_results(
                    need_ai_reply_results, bot, from_wxid, chat_id,
                    nickname, is_group, messages, tool_calls_data
                )

        logger.info(f"[异步] 所有工具执行完成")

    except Exception as e:
        logger.error(f"[异步] 工具执行总体异常: {e}")
        import traceback
        logger.error(f"详细错误: {traceback.format_exc()}")
        try:
            # Was a bare `except:`; narrowed so KeyboardInterrupt etc. propagate.
            await bot.send_text(from_wxid, "❌ 工具执行过程中出现错误")
        except Exception:
            pass
|
||
|
||
async def _continue_with_tool_results(self, tool_results: list, bot, from_wxid: str,
                                      chat_id: str, nickname: str, is_group: bool,
                                      messages: list, tool_calls_data: list):
    """
    Continue the AI conversation using tool results, keeping the
    original context and persona.

    Used for tools flagged need_ai_reply=True (e.g. web search): the
    tool outputs are appended to *messages* as an assistant tool-call
    message plus "tool" role messages, then the API is called again
    WITHOUT the tools parameter so the model cannot recurse into more
    tool calls. The final text is sanitized and sent to the chat.
    """
    # NOTE: the redundant function-local `import json` was removed;
    # json is imported at module level.
    try:
        logger.info(f"[工具回传] 开始基于 {len(tool_results)} 个工具结果继续对话")

        # 1. Re-declare the assistant's tool calls, restricted to the
        #    ones whose results we are feeding back.
        tool_calls_msg = []
        for tool_call in tool_calls_data:
            tool_call_id = tool_call.get("id", "")
            function_name = tool_call.get("function", {}).get("name", "")
            arguments_str = tool_call.get("function", {}).get("arguments", "{}")

            for tr in tool_results:
                if tr["tool_call_id"] == tool_call_id:
                    tool_calls_msg.append({
                        "id": tool_call_id,
                        "type": "function",
                        "function": {
                            "name": function_name,
                            "arguments": arguments_str
                        }
                    })
                    break

        if tool_calls_msg:
            messages.append({
                "role": "assistant",
                "content": None,
                "tool_calls": tool_calls_msg
            })

        # 2. Append one "tool" message per result.
        for tr in tool_results:
            messages.append({
                "role": "tool",
                "tool_call_id": tr["tool_call_id"],
                "content": tr["result"]
            })

        # 3. Call the AI again (no tools parameter → no further calls).
        api_config = self.config["api"]
        proxy_config = self.config.get("proxy", {})

        payload = {
            "model": api_config["model"],
            "messages": messages,
            "max_tokens": api_config.get("max_tokens", 4096),
            "stream": True
        }

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_config['api_key']}"
        }

        # NOTE(review): aiohttp's `proxy=` argument only supports HTTP
        # proxies; a socks type configured here will not work (the main
        # _call_ai_api path uses ProxyConnector instead) — confirm.
        proxy = None
        if proxy_config.get("enabled", False):
            proxy_type = proxy_config.get("type", "http")
            proxy_host = proxy_config.get("host", "127.0.0.1")
            proxy_port = proxy_config.get("port", 7890)
            proxy = f"{proxy_type}://{proxy_host}:{proxy_port}"

        timeout = aiohttp.ClientTimeout(total=api_config.get("timeout", 120))

        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.post(
                api_config["url"],
                json=payload,
                headers=headers,
                proxy=proxy
            ) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    logger.error(f"[工具回传] AI API 错误: {resp.status}, {error_text}")
                    await bot.send_text(from_wxid, "❌ AI 处理搜索结果失败")
                    return

                # Consume the SSE stream and accumulate the text.
                full_content = ""
                async for line in resp.content:
                    line = line.decode("utf-8").strip()
                    if not line or not line.startswith("data: "):
                        continue
                    if line == "data: [DONE]":
                        break

                    try:
                        data = json.loads(line[6:])
                        delta = data.get("choices", [{}])[0].get("delta", {})
                        content = delta.get("content", "")
                        if content:
                            full_content += content
                    except Exception:
                        # Was a bare `except:`; a malformed SSE line is
                        # skipped, but BaseException now propagates.
                        continue

                # Send the sanitized reply, if any survives cleaning.
                if full_content.strip():
                    cleaned_content = self._sanitize_llm_output(full_content)
                    if cleaned_content:
                        await bot.send_text(from_wxid, cleaned_content)
                        logger.success(f"[工具回传] AI 回复完成,长度: {len(cleaned_content)}")
                    else:
                        logger.warning("[工具回传] AI 回复清洗后为空,已跳过发送")

                    # Persist the reply to conversation memory.
                    if chat_id and cleaned_content:
                        self._add_to_memory(chat_id, "assistant", cleaned_content)
                else:
                    logger.warning("[工具回传] AI 返回空内容")

    except Exception as e:
        logger.error(f"[工具回传] 继续对话失败: {e}")
        import traceback
        logger.error(f"详细错误: {traceback.format_exc()}")
        try:
            # Was a bare `except:`; narrowed to Exception.
            await bot.send_text(from_wxid, "❌ 处理搜索结果时出错")
        except Exception:
            pass
|
||
|
||
async def _execute_tools_async_with_image(self, tool_calls_data: list, bot, from_wxid: str,
                                          chat_id: str, user_wxid: str, nickname: str, is_group: bool,
                                          messages: list, image_base64: str):
    """
    Execute tool calls asynchronously with an image payload attached
    (used for image-to-image generation and similar scenarios).

    The AI has already replied to the user; tools run here in the
    background and results are sent afterwards. For the
    flow2_ai_image_generation tool the source image is injected into
    the tool arguments as "image_base64".
    """
    try:
        logger.info(f"[异步-图片] 开始执行 {len(tool_calls_data)} 个工具调用")

        # Validate arguments and build one task per tool call.
        tasks = []
        tool_info_list = []
        tools_map = self._collect_tools_with_plugins()
        schema_map = self._get_tool_schema_map(tools_map)

        for tool_call in tool_calls_data:
            function_name = tool_call.get("function", {}).get("name", "")
            arguments_str = tool_call.get("function", {}).get("arguments", "{}")
            tool_call_id = tool_call.get("id", "")

            if not function_name:
                continue

            try:
                arguments = json.loads(arguments_str)
            except Exception:
                arguments = {}

            # Inject the source image for the image-to-image tool.
            if function_name == "flow2_ai_image_generation" and image_base64:
                arguments["image_base64"] = image_base64
                logger.info(f"[异步-图片] 图生图工具,已添加图片数据")

            schema = schema_map.get(function_name)
            ok, err, arguments = self._validate_tool_arguments(function_name, arguments, schema)
            if not ok:
                logger.warning(f"[异步-图片] 工具 {function_name} 参数校验失败: {err}")
                try:
                    await bot.send_text(from_wxid, f"❌ 工具 {function_name} 参数错误: {err}")
                except Exception:
                    pass
                continue

            logger.info(f"[异步-图片] 准备执行工具: {function_name}")

            task = self._execute_tool_and_get_result(
                function_name,
                arguments,
                bot,
                from_wxid,
                user_wxid=user_wxid,
                is_group=is_group,
                tools_map=tools_map,
            )
            tasks.append(task)
            tool_info_list.append({
                "tool_call_id": tool_call_id,
                "function_name": function_name,
                "arguments": arguments
            })

        # Run every tool concurrently; exceptions become result values.
        if tasks:
            results = await asyncio.gather(*tasks, return_exceptions=True)

            for i, result in enumerate(results):
                tool_info = tool_info_list[i]
                function_name = tool_info["function_name"]

                if isinstance(result, Exception):
                    logger.error(f"[异步-图片] 工具 {function_name} 执行异常: {result}")
                    try:
                        await bot.send_text(from_wxid, f"❌ {function_name} 执行失败: {result}")
                    except Exception:
                        pass
                    continue

                tool_result = ToolResult.from_raw(result)
                if not tool_result:
                    continue

                # BUGFIX: tool_message was never assigned in this method,
                # so every branch below raised NameError whenever a tool
                # returned a message. Mirror _execute_tools_async and
                # sanitize the tool text once here.
                tool_message = self._sanitize_llm_output(tool_result.message) if tool_result.message is not None else ""

                if tool_result.success:
                    logger.success(f"[异步-图片] 工具 {function_name} 执行成功")
                else:
                    logger.warning(f"[异步-图片] 工具 {function_name} 执行失败")

                # Successful tool: send its text unless suppressed.
                if tool_result.success and not tool_result.already_sent and tool_result.message and not tool_result.no_reply:
                    if tool_result.send_result_text:
                        if tool_message:
                            await bot.send_text(from_wxid, tool_message)
                        else:
                            logger.warning(f"[异步-图片] 工具 {function_name} 输出清洗后为空,已跳过发送")

                # Failed tool: send an error notice by default.
                if not tool_result.success and tool_result.message and not tool_result.no_reply:
                    try:
                        if tool_message:
                            await bot.send_text(from_wxid, f"❌ {tool_message}")
                        else:
                            await bot.send_text(from_wxid, f"❌ {function_name} 执行失败")
                    except Exception:
                        pass

                # Optionally record the result in conversation memory.
                if tool_result.save_to_memory and chat_id:
                    if tool_message:
                        self._add_to_memory(chat_id, "assistant", f"[工具 {function_name} 结果]: {tool_message}")

        logger.info(f"[异步-图片] 所有工具执行完成")

    except Exception as e:
        logger.error(f"[异步-图片] 工具执行总体异常: {e}")
        import traceback
        logger.error(f"详细错误: {traceback.format_exc()}")
        try:
            # Was a bare `except:`; narrowed so KeyboardInterrupt etc. propagate.
            await bot.send_text(from_wxid, "❌ 工具执行过程中出现错误")
        except Exception:
            pass
|
||
|
||
@on_quote_message(priority=79)
async def handle_quote_message(self, bot, message: dict):
    """Handle a quote/reply message (quoted image, video, forwarded
    chat record, or the /记录 persistent-memory command).

    Flow: parse the quote XML → handle "/记录" (save quoted text to
    persistent memory) → check reply conditions and rate limit →
    dispatch quoted chat records / videos to their handlers → download
    a quoted image and answer via the image-capable AI call.

    Returns:
        False when the message was consumed here (stops further plugin
        processing), True to let other handlers see it.
    """
    content = message.get("Content", "").strip()
    from_wxid = message.get("FromWxid", "")
    sender_wxid = message.get("SenderWxid", "")
    is_group = message.get("IsGroup", False)

    # In groups the actual speaker is the sender, not the room id.
    user_wxid = sender_wxid if is_group else from_wxid

    try:
        # Group quote messages may carry a "wxid:\n" prefix before the
        # XML; strip everything before the XML declaration / <msg> tag.
        xml_content = content
        if is_group and ":\n" in content:
            xml_start = content.find("<?xml")
            if xml_start == -1:
                xml_start = content.find("<msg")
            if xml_start > 0:
                xml_content = content[xml_start:]
                logger.debug(f"去除引用消息前缀,原长度: {len(content)}, 新长度: {len(xml_content)}")

        # Parse the quote XML; <title> holds the user's typed text.
        root = ET.fromstring(xml_content)
        title = root.find(".//title")
        if title is None or not title.text:
            logger.debug("引用消息没有标题,跳过")
            return True

        title_text = title.text.strip()
        logger.info(f"收到引用消息,标题: {title_text[:50]}...")

        # "/记录" command: store the quoted message in persistent memory.
        if title_text == "/记录" or title_text.startswith("/记录 "):
            # The quoted message lives under <refermsg>.
            refermsg = root.find(".//refermsg")
            if refermsg is not None:
                # Display name of the quoted message's author.
                refer_displayname = refermsg.find("displayname")
                refer_nickname = refer_displayname.text if refer_displayname is not None and refer_displayname.text else "未知"

                # Quoted message body; XML bodies (e.g. images) are
                # replaced with a placeholder.
                refer_content_elem = refermsg.find("content")
                if refer_content_elem is not None and refer_content_elem.text:
                    refer_text = refer_content_elem.text.strip()
                    if refer_text.startswith("<?xml") or refer_text.startswith("<"):
                        refer_text = f"[多媒体消息]"
                else:
                    refer_text = "[空消息]"

                # Memory entry: "author: quoted text".
                memory_content = f"{refer_nickname}: {refer_text}"

                # Optional free-form note after "/记录 ".
                if title_text.startswith("/记录 "):
                    extra_note = title_text[4:].strip()
                    if extra_note:
                        memory_content += f" (备注: {extra_note})"

                # Persist under the group id (groups) or user id (private).
                nickname = await self._get_user_display_label(bot, from_wxid, user_wxid, is_group)
                memory_chat_id = from_wxid if is_group else user_wxid
                chat_type = "group" if is_group else "private"
                memory_id = self._add_persistent_memory(
                    memory_chat_id, chat_type, user_wxid, nickname, memory_content
                )
                await bot.send_text(from_wxid, f"✅ 已记录到持久记忆 (ID: {memory_id})\n📝 {memory_content[:50]}...")
                logger.info(f"通过引用添加持久记忆: {memory_chat_id} - {memory_content[:30]}...")
            else:
                await bot.send_text(from_wxid, "❌ 无法获取被引用的消息")
            return False

        # Trigger check (mention/keyword rules for quote messages).
        if not self._should_reply_quote(message, title_text):
            logger.debug("引用消息不满足回复条件")
            return True

        # Locate the quoted message payload.
        refermsg = root.find(".//refermsg")
        if refermsg is None:
            logger.debug("引用消息中没有 refermsg 节点")
            return True

        refer_content = refermsg.find("content")
        if refer_content is None or not refer_content.text:
            logger.debug("引用消息中没有 content")
            return True

        # Quoted message type:
        # 1=text, 3=image, 43=video, 49=app message (incl. chat records).
        refer_type_elem = refermsg.find("type")
        refer_type = int(refer_type_elem.text) if refer_type_elem is not None and refer_type_elem.text else 0
        logger.debug(f"被引用消息类型: {refer_type}")

        # Plain-text quotes need no media handling (type=1).
        if refer_type == 1:
            logger.debug("引用的是纯文本消息,跳过")
            return True

        # Only images (3), videos (43) and app messages (49) are handled.
        if refer_type not in [3, 43, 49]:
            logger.debug(f"引用的消息类型 {refer_type} 不支持处理")
            return True

        # The quoted body is HTML-escaped XML; unescape it first.
        import html
        refer_xml = html.unescape(refer_content.text)

        # The quoted body may also carry a "wxid:\n" prefix; strip it.
        if ":\n" in refer_xml:
            xml_start = refer_xml.find("<?xml")
            if xml_start == -1:
                xml_start = refer_xml.find("<msg")
            if xml_start > 0:
                refer_xml = refer_xml[xml_start:]
                logger.debug(f"去除被引用消息前缀")

        # Parse the quoted message's own XML.
        try:
            refer_root = ET.fromstring(refer_xml)
        except ET.ParseError as e:
            logger.debug(f"被引用消息内容不是有效的 XML: {e}")
            return True

        # Probe for a forwarded chat record (type=19 payload)...
        recorditem = refer_root.find(".//recorditem")
        # ...an image...
        img = refer_root.find(".//img")
        # ...or a video.
        video = refer_root.find(".//videomsg")

        if img is None and video is None and recorditem is None:
            logger.debug("引用的消息不是图片、视频或聊天记录")
            return True

        # Re-check reply conditions before any expensive download.
        # NOTE(review): this repeats the earlier _should_reply_quote
        # check verbatim — harmless but redundant.
        if not self._should_reply_quote(message, title_text):
            logger.debug("引用消息不满足回复条件")
            return True

        # Per-user rate limiting.
        allowed, remaining, reset_time = self._check_rate_limit(user_wxid)
        if not allowed:
            rate_limit_config = self.config.get("rate_limit", {})
            msg = rate_limit_config.get("rate_limit_message", "⚠️ 消息太频繁了,请 {seconds} 秒后再试~")
            msg = msg.format(seconds=reset_time)
            await bot.send_text(from_wxid, msg)
            logger.warning(f"用户 {user_wxid} 触发限流,{reset_time}秒后重置")
            return False

        # Resolve the user's display label (cached) and the chat id.
        nickname = await self._get_user_display_label(bot, from_wxid, user_wxid, is_group)
        chat_id = self._get_chat_id(from_wxid, user_wxid, is_group)

        # Dispatch: forwarded chat record (type=19 payload).
        if recorditem is not None:
            return await self._handle_quote_chat_record(
                bot, recorditem, title_text, from_wxid, user_wxid,
                is_group, nickname, chat_id
            )

        # Dispatch: quoted video.
        if video is not None:
            return await self._handle_quote_video(
                bot, video, title_text, from_wxid, user_wxid,
                is_group, nickname, chat_id
            )

        # Quoted image: prefer the largest CDN variant available.
        cdnurl = img.get("cdnbigimgurl", "") or img.get("cdnmidimgurl", "") or img.get("cdnthumburl", "")
        # The AES key also appears under multiple attribute names.
        aeskey = img.get("aeskey", "") or img.get("cdnthumbaeskey", "")

        if not cdnurl or not aeskey:
            logger.warning(f"图片信息不完整: cdnurl={bool(cdnurl)}, aeskey={bool(aeskey)}")
            return True

        logger.info(f"AI处理引用图片消息: {title_text[:50]}...")

        # Download and base64-encode the image.
        logger.info(f"开始下载图片: {cdnurl[:50]}...")
        image_base64 = await self._download_and_encode_image(bot, cdnurl, aeskey)
        if not image_base64:
            logger.error("图片下载失败")
            await bot.send_text(from_wxid, "❌ 无法处理图片")
            return False
        logger.info("图片下载和编码成功")

        # Record the user's quoted-image message (with the image) in memory.
        self._add_to_memory(chat_id, "user", title_text, image_base64=image_base64)

        # Also persist it to the group history store, when capturing.
        if is_group and self._should_capture_group_history(is_triggered=True):
            history_chat_id = self._get_group_history_chat_id(from_wxid, user_wxid)
            await self._add_to_history(
                history_chat_id,
                nickname,
                title_text,
                image_base64=image_base64,
                sender_wxid=user_wxid,
            )

        # Call the image-capable AI endpoint. If the message was already
        # captured in group history, don't append it a second time.
        history_enabled = bool(self.store) and self.config.get("history", {}).get("enabled", True)
        captured_to_history = bool(is_group and history_enabled and self._should_capture_group_history(is_triggered=True))
        append_user_message = not captured_to_history
        response = await self._call_ai_api_with_image(
            title_text,
            image_base64,
            bot,
            from_wxid,
            chat_id,
            nickname,
            user_wxid,
            is_group,
            append_user_message=append_user_message,
            tool_query=title_text,
        )

        if response:
            # Sanitize before sending (drops chain-of-thought/markup).
            cleaned_response = self._sanitize_llm_output(response)
            if cleaned_response:
                await bot.send_text(from_wxid, cleaned_response)
                self._add_to_memory(chat_id, "assistant", cleaned_response)
                # Persist the bot reply to group history unless the
                # sync_bot_messages hook already covers it (it does not
                # for per-user history scopes).
                history_config = self.config.get("history", {})
                sync_bot_messages = history_config.get("sync_bot_messages", False)
                history_scope = str(history_config.get("scope", "chatroom") or "chatroom").strip().lower()
                can_rely_on_hook = bool(sync_bot_messages and history_scope not in ("per_user", "user", "peruser"))
                if is_group and not can_rely_on_hook:
                    import tomllib
                    with open("main_config.toml", "rb") as f:
                        main_config = tomllib.load(f)
                    bot_nickname = main_config.get("Bot", {}).get("nickname", "机器人")
                    history_chat_id = self._get_group_history_chat_id(from_wxid, user_wxid)
                    await self._add_to_history(
                        history_chat_id,
                        bot_nickname,
                        cleaned_response,
                        role="assistant",
                        sender_wxid=user_wxid,
                    )
                logger.success(f"AI回复成功: {cleaned_response[:50]}...")
            else:
                logger.warning("AI 回复清洗后为空,已跳过发送")

        return False

    except Exception as e:
        logger.error(f"处理引用消息失败: {e}")
        return True
|
||
|
||
async def _handle_quote_chat_record(self, bot, recorditem_elem, title_text: str, from_wxid: str,
                                    user_wxid: str, is_group: bool, nickname: str, chat_id: str):
    """Handle a quoted chat-record message (appmsg type=19).

    Parses the forwarded chat record embedded in ``recorditem_elem`` (its
    ``.text`` is CDATA containing a ``recordinfo`` XML document), builds a
    combined prompt from the record plus the user's question, asks the main
    AI, and sends the reply back.

    Args:
        bot: Bot instance used for sending messages.
        recorditem_elem: ``recorditem`` XML element from the quoted appmsg.
        title_text: Text the user typed alongside the quote (their question).
        from_wxid: Chat the message came from (room id for groups).
        user_wxid: Wxid of the individual sender.
        is_group: True when the message is from a group chat.
        nickname: Display name of the sender.
        chat_id: Key for the in-process short-term memory.

    Returns:
        False always (this handler never blocks downstream plugins),
        True only when image info is unusable upstream — here it returns
        True solely from the outer exception path.
    """
    try:
        logger.info(f"[聊天记录] 处理引用的聊天记录: {title_text[:50]}...")

        # The recorditem payload is CDATA; extract it before XML parsing.
        record_text = recorditem_elem.text
        if not record_text:
            logger.warning("[聊天记录] recorditem 内容为空")
            await bot.send_text(from_wxid, "❌ 无法读取聊天记录内容")
            return False

        # Parse the embedded recordinfo XML document.
        try:
            record_root = ET.fromstring(record_text)
        except ET.ParseError as e:
            logger.error(f"[聊天记录] 解析 recordinfo 失败: {e}")
            await bot.send_text(from_wxid, "❌ 聊天记录格式解析失败")
            return False

        # Extract the individual forwarded messages.
        datalist = record_root.find(".//datalist")
        chat_records = []

        # Preferred path: parse full messages out of <datalist>.
        if datalist is not None:
            for dataitem in datalist.findall("dataitem"):
                source_name = dataitem.find("sourcename")
                source_time = dataitem.find("sourcetime")
                data_desc = dataitem.find("datadesc")

                sender = source_name.text if source_name is not None and source_name.text else "未知"
                time_str = source_time.text if source_time is not None and source_time.text else ""
                content = data_desc.text if data_desc is not None and data_desc.text else ""

                if content:
                    chat_records.append({
                        "sender": sender,
                        "time": time_str,
                        "content": content
                    })

        # Fallback: quoted records are often a stripped-down version with an
        # empty datalist; use the <desc> summary instead.
        if not chat_records:
            desc_elem = record_root.find(".//desc")
            if desc_elem is not None and desc_elem.text:
                # desc is usually formatted as "sender: content\nsender: content".
                desc_text = desc_elem.text.strip()
                logger.info(f"[聊天记录] 从 desc 获取摘要内容: {desc_text[:100]}...")
                chat_records.append({
                    "sender": "聊天记录摘要",
                    "time": "",
                    "content": desc_text
                })

        if not chat_records:
            logger.warning("[聊天记录] 没有解析到任何消息")
            await bot.send_text(from_wxid, "❌ 聊天记录中没有消息内容")
            return False

        logger.info(f"[聊天记录] 解析到 {len(chat_records)} 条消息")

        # Render the record as readable text for the LLM prompt.
        record_title = record_root.find(".//title")
        title = record_title.text if record_title is not None and record_title.text else "聊天记录"

        chat_text = f"【{title}】\n\n"
        for i, record in enumerate(chat_records, 1):
            time_part = f" ({record['time']})" if record['time'] else ""
            if record['sender'] == "聊天记录摘要":
                # Summary mode: show the content as-is without a sender header.
                chat_text += f"{record['content']}\n\n"
            else:
                chat_text += f"[{record['sender']}{time_part}]:\n{record['content']}\n\n"

        # Build the message sent to the AI; default question when none given.
        user_question = title_text.strip() if title_text.strip() else "请分析这段聊天记录"
        # Strip a leading @mention from the question text.
        if user_question.startswith("@"):
            parts = user_question.split(maxsplit=1)
            if len(parts) > 1:
                user_question = parts[1].strip()
            else:
                user_question = "请分析这段聊天记录"

        combined_message = f"[用户发送了一段聊天记录,请阅读并回答问题]\n\n{chat_text}\n[用户的问题]: {user_question}"

        logger.info(f"[聊天记录] 发送给 AI,消息长度: {len(combined_message)}")

        # Record in short-term memory.
        self._add_to_memory(chat_id, "user", combined_message)

        # Group chats: also persist to group history (abbreviated form).
        if is_group and self._should_capture_group_history(is_triggered=True):
            history_chat_id = self._get_group_history_chat_id(from_wxid, user_wxid)
            await self._add_to_history(
                history_chat_id,
                nickname,
                f"[发送了聊天记录] {user_question}",
                sender_wxid=user_wxid,
            )

        # Ask the main AI (full context is assembled inside _call_ai_api).
        response = await self._call_ai_api(
            combined_message,
            bot,
            from_wxid,
            chat_id,
            nickname,
            user_wxid,
            is_group,
            tool_query=user_question,
        )

        if response:
            cleaned_response = self._sanitize_llm_output(response)
            if cleaned_response:
                await bot.send_text(from_wxid, cleaned_response)
                self._add_to_memory(chat_id, "assistant", cleaned_response)
                # Persist the bot reply to history unless the global message
                # hook already records bot messages (and scope is not per-user).
                history_config = self.config.get("history", {})
                sync_bot_messages = history_config.get("sync_bot_messages", False)
                history_scope = str(history_config.get("scope", "chatroom") or "chatroom").strip().lower()
                can_rely_on_hook = bool(sync_bot_messages and history_scope not in ("per_user", "user", "peruser"))
                if is_group and not can_rely_on_hook:
                    import tomllib
                    # Bot nickname comes from the top-level main_config.toml.
                    with open("main_config.toml", "rb") as f:
                        main_config = tomllib.load(f)
                    bot_nickname = main_config.get("Bot", {}).get("nickname", "机器人")
                    history_chat_id = self._get_group_history_chat_id(from_wxid, user_wxid)
                    await self._add_to_history(
                        history_chat_id,
                        bot_nickname,
                        cleaned_response,
                        role="assistant",
                        sender_wxid=user_wxid,
                    )
                logger.success(f"[聊天记录] AI 回复成功: {cleaned_response[:50]}...")
            else:
                logger.warning("[聊天记录] AI 回复清洗后为空,已跳过发送")
        else:
            await bot.send_text(from_wxid, "❌ AI 回复生成失败")

        return False

    except Exception as e:
        logger.error(f"[聊天记录] 处理失败: {e}")
        import traceback
        logger.error(traceback.format_exc())
        await bot.send_text(from_wxid, "❌ 聊天记录处理出错")
        return False
|
||
|
||
async def _handle_quote_video(self, bot, video_elem, title_text: str, from_wxid: str,
                              user_wxid: str, is_group: bool, nickname: str, chat_id: str):
    """Handle a quoted video message using a two-model pipeline.

    Stage 1: a dedicated video model (``_analyze_video_content``) produces an
    objective description of the video. Stage 2: the main chat model answers
    the user's question based on that description, inheriting the normal
    conversation context.

    Args:
        bot: Bot instance used for downloads and replies.
        video_elem: ``videomsg`` XML element carrying CDN attributes.
        title_text: Text the user typed alongside the quote.
        from_wxid: Chat the message came from (room id for groups).
        user_wxid: Wxid of the individual sender.
        is_group: True when the message is from a group chat.
        nickname: Display name of the sender.
        chat_id: Key for the in-process short-term memory.

    Returns:
        False always, so downstream plugins are not blocked.
    """
    try:
        # Feature flag check for video recognition.
        video_config = self.config.get("video_recognition", {})
        if not video_config.get("enabled", True):
            logger.info("[视频识别] 功能未启用")
            await bot.send_text(from_wxid, "❌ 视频识别功能未启用")
            return False

        # Extract video CDN download info.
        cdnvideourl = video_elem.get("cdnvideourl", "")
        aeskey = video_elem.get("aeskey", "")

        # Fall back to the raw-video CDN attributes when the primary ones
        # are missing.
        if not cdnvideourl or not aeskey:
            cdnvideourl = video_elem.get("cdnrawvideourl", "")
            aeskey = video_elem.get("cdnrawvideoaeskey", "")

        if not cdnvideourl or not aeskey:
            logger.warning(f"[视频识别] 视频信息不完整: cdnurl={bool(cdnvideourl)}, aeskey={bool(aeskey)}")
            await bot.send_text(from_wxid, "❌ 无法获取视频信息")
            return False

        logger.info(f"[视频识别] 处理引用视频: {title_text[:50]}...")

        # Let the user know processing has started (downloads can be slow).
        await bot.send_text(from_wxid, "🎬 正在分析视频,请稍候...")

        # Download and base64-encode the video.
        video_base64 = await self._download_and_encode_video(bot, cdnvideourl, aeskey)
        if not video_base64:
            logger.error("[视频识别] 视频下载失败")
            await bot.send_text(from_wxid, "❌ 视频下载失败")
            return False

        logger.info("[视频识别] 视频下载和编码成功")

        # ===== Stage 1: video model describes the content =====
        video_description = await self._analyze_video_content(video_base64, video_config)
        if not video_description:
            logger.error("[视频识别] 视频AI分析失败")
            await bot.send_text(from_wxid, "❌ 视频分析失败")
            return False

        logger.info(f"[视频识别] 视频AI分析完成: {video_description[:100]}...")

        # ===== Stage 2: main model answers based on the description =====
        # Build the user message that embeds the video description.
        user_question = title_text.strip() if title_text.strip() else "这个视频讲了什么?"
        combined_message = f"[用户发送了一个视频,以下是视频内容描述]\n{video_description}\n\n[用户的问题]\n{user_question}"

        # Record in short-term memory so the main AI knows a video was sent.
        self._add_to_memory(chat_id, "user", combined_message)

        # Group chats: persist an abbreviated entry to group history.
        if is_group and self._should_capture_group_history(is_triggered=True):
            history_chat_id = self._get_group_history_chat_id(from_wxid, user_wxid)
            await self._add_to_history(
                history_chat_id,
                nickname,
                f"[发送了一个视频] {user_question}",
                sender_wxid=user_wxid,
            )

        # Call the main AI via the existing _call_ai_api path, which inherits
        # the full conversation context.
        response = await self._call_ai_api(
            combined_message,
            bot,
            from_wxid,
            chat_id,
            nickname,
            user_wxid,
            is_group,
            tool_query=user_question,
        )

        if response:
            cleaned_response = self._sanitize_llm_output(response)
            if cleaned_response:
                await bot.send_text(from_wxid, cleaned_response)
                self._add_to_memory(chat_id, "assistant", cleaned_response)
                # Persist the bot reply to history unless the global message
                # hook already records bot messages (and scope is not per-user).
                history_config = self.config.get("history", {})
                sync_bot_messages = history_config.get("sync_bot_messages", False)
                history_scope = str(history_config.get("scope", "chatroom") or "chatroom").strip().lower()
                can_rely_on_hook = bool(sync_bot_messages and history_scope not in ("per_user", "user", "peruser"))
                if is_group and not can_rely_on_hook:
                    import tomllib
                    # Bot nickname comes from the top-level main_config.toml.
                    with open("main_config.toml", "rb") as f:
                        main_config = tomllib.load(f)
                    bot_nickname = main_config.get("Bot", {}).get("nickname", "机器人")
                    history_chat_id = self._get_group_history_chat_id(from_wxid, user_wxid)
                    await self._add_to_history(
                        history_chat_id,
                        bot_nickname,
                        cleaned_response,
                        role="assistant",
                        sender_wxid=user_wxid,
                    )
                logger.success(f"[视频识别] 主AI回复成功: {cleaned_response[:50]}...")
            else:
                logger.warning("[视频识别] 主AI回复清洗后为空,已跳过发送")
        else:
            await bot.send_text(from_wxid, "❌ AI 回复生成失败")

        return False

    except Exception as e:
        logger.error(f"[视频识别] 处理视频失败: {e}")
        import traceback
        logger.error(traceback.format_exc())
        await bot.send_text(from_wxid, "❌ 视频处理出错")
        return False
|
||
|
||
async def _analyze_video_content(self, video_base64: str, video_config: dict) -> str:
    """Video-analysis model: produce an objective description of a video.

    Sends the video inline to a Gemini-style ``generateContent`` endpoint
    with a fixed analysis prompt and retries once on transient gateway
    errors (502/503/504) or timeouts.

    Args:
        video_base64: Video payload, optionally prefixed with
            ``data:video/mp4;base64,``.
        video_config: The ``[video_recognition]`` section of the plugin
            config (api_url, api_key, model, max_tokens, timeout).

    Returns:
        Sanitized description text, or ``""`` on any failure
        (HTTP error, safety block, empty response, timeout).
    """
    try:
        api_url = video_config.get("api_url", "https://api.functen.cn/v1beta/models")
        api_key = video_config.get("api_key", self.config["api"]["api_key"])
        model = video_config.get("model", "gemini-3-pro-preview")

        full_url = f"{api_url}/{model}:generateContent"

        # Strip the data:video/mp4;base64, prefix if present — the API
        # expects raw base64.
        if video_base64.startswith("data:"):
            video_base64 = video_base64.split(",", 1)[1]
            logger.debug("[视频AI] 已去除 base64 前缀")

        # Dedicated analysis prompt (runtime string — kept verbatim).
        analyze_prompt = """请详细分析这个视频的内容,包括:
1. 视频的主要场景和环境
2. 出现的人物/物体及其动作
3. 视频中的文字、对话或声音(如果有)
4. 视频的整体主题或要表达的内容
5. 任何值得注意的细节

请用客观、详细的方式描述,不要加入主观评价。"""

        payload = {
            "contents": [
                {
                    "parts": [
                        {"text": analyze_prompt},
                        {
                            "inline_data": {
                                "mime_type": "video/mp4",
                                "data": video_base64
                            }
                        }
                    ]
                }
            ],
            "generationConfig": {
                "maxOutputTokens": video_config.get("max_tokens", 8192)
            }
        }

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}"
        }

        timeout = aiohttp.ClientTimeout(total=video_config.get("timeout", 360))

        # Retry automatically on transient 502/503/504 errors.
        max_retries = 2
        retry_delay = 5  # seconds between retries

        for attempt in range(max_retries + 1):
            try:
                logger.info(f"[视频AI] 开始分析视频...{f' (重试 {attempt}/{max_retries})' if attempt > 0 else ''}")

                async with aiohttp.ClientSession(timeout=timeout) as session:
                    async with session.post(full_url, json=payload, headers=headers) as resp:
                        # Transient gateway errors: sleep and retry.
                        if resp.status in [502, 503, 504]:
                            error_text = await resp.text()
                            logger.warning(f"[视频AI] API 临时错误: {resp.status}, 将重试...")
                            if attempt < max_retries:
                                await asyncio.sleep(retry_delay)
                                continue
                            else:
                                logger.error(f"[视频AI] API 错误: {resp.status}, 已达最大重试次数")
                                return ""

                        # Any other non-200 status is fatal (no retry).
                        if resp.status != 200:
                            error_text = await resp.text()
                            logger.error(f"[视频AI] API 错误: {resp.status}, {error_text[:300]}")
                            return ""

                        result = await resp.json()
                        logger.info(f"[视频AI] API 响应 keys: {list(result.keys())}")

                        # Prompt-level safety filtering.
                        if "promptFeedback" in result:
                            feedback = result["promptFeedback"]
                            if feedback.get("blockReason"):
                                logger.warning(f"[视频AI] 内容被过滤: {feedback.get('blockReason')}")
                                return ""

                        # Extract the first text part from the candidates.
                        if "candidates" in result and result["candidates"]:
                            for candidate in result["candidates"]:
                                # Candidate-level safety filtering.
                                if candidate.get("finishReason") == "SAFETY":
                                    logger.warning("[视频AI] 响应被安全过滤")
                                    return ""

                                content = candidate.get("content", {})
                                for part in content.get("parts", []):
                                    if "text" in part:
                                        text = part["text"]
                                        logger.info(f"[视频AI] 分析完成,长度: {len(text)}")
                                        return self._sanitize_llm_output(text)

                        # No usable text — log token usage for diagnostics.
                        if "usageMetadata" in result:
                            usage = result["usageMetadata"]
                            logger.warning(f"[视频AI] 无响应,Token: prompt={usage.get('promptTokenCount', 0)}")

                        logger.error(f"[视频AI] 没有有效响应: {str(result)[:300]}")
                        return ""

            except asyncio.TimeoutError:
                # Timeouts are also retried, up to max_retries.
                logger.warning(f"[视频AI] 请求超时{f', 将重试...' if attempt < max_retries else ''}")
                if attempt < max_retries:
                    await asyncio.sleep(retry_delay)
                    continue
                return ""
            except Exception as e:
                # Unexpected per-attempt failure: log the traceback, give up.
                logger.error(f"[视频AI] 分析失败: {e}")
                import traceback
                logger.error(traceback.format_exc())
                return ""

        # Loop exhausted without success.
        return ""

    except Exception as e:
        # Failure while building the request (before the retry loop).
        logger.error(f"[视频AI] 分析失败: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return ""
|
||
|
||
async def _download_and_encode_video(self, bot, cdnurl: str, aeskey: str) -> str:
    """Download a video from the WeChat CDN and return it as a base64 data URL.

    Checks the Redis media cache first; on a miss, downloads to a temp file,
    enforces the configured size limit, base64-encodes the bytes, caches the
    result for 10 minutes, and always removes the temp file.

    Args:
        bot: Bot instance exposing ``cdn_download``.
        cdnurl: CDN URL of the video.
        aeskey: AES key used to decrypt the CDN payload.

    Returns:
        ``data:video/mp4;base64,...`` string, or ``""`` on any failure
        (download error, file never materialized, oversize video, exception).
    """
    import os
    from contextlib import suppress

    save_path = None  # set once a temp file path is chosen; cleaned up in finally
    try:
        # Try the Redis media cache first to avoid re-downloading.
        from utils.redis_cache import RedisCache
        redis_cache = get_cache()
        media_key = None  # explicit init so the cache-write check below is unambiguous
        if redis_cache and redis_cache.enabled:
            media_key = RedisCache.generate_media_key(cdnurl, aeskey)
            if media_key:
                cached_data = redis_cache.get_cached_media(media_key, "video")
                if cached_data:
                    logger.debug(f"[视频识别] 从缓存获取视频: {media_key[:20]}...")
                    return cached_data

        # Download into a uniquely named temp file under the plugin directory.
        logger.info(f"[视频识别] 开始下载视频...")
        temp_dir = Path(__file__).parent / "temp"
        temp_dir.mkdir(exist_ok=True)

        filename = f"video_{uuid.uuid4().hex[:8]}.mp4"
        save_path = str((temp_dir / filename).resolve())

        # file_type=4 selects "video" in the CDN download API.
        success = await bot.cdn_download(cdnurl, aeskey, save_path, file_type=4)
        if not success:
            logger.error("[视频识别] CDN 下载失败")
            return ""

        # The file may be written asynchronously; poll up to ~15 seconds.
        for _ in range(30):
            if os.path.exists(save_path) and os.path.getsize(save_path) > 0:
                break
            await asyncio.sleep(0.5)

        if not os.path.exists(save_path):
            logger.error("[视频识别] 视频文件未生成")
            return ""

        file_size = os.path.getsize(save_path)
        logger.info(f"[视频识别] 视频下载完成,大小: {file_size / 1024 / 1024:.2f} MB")

        # Enforce the configured size limit before base64-encoding.
        video_config = self.config.get("video_recognition", {})
        max_size_mb = video_config.get("max_size_mb", 20)
        if file_size > max_size_mb * 1024 * 1024:
            logger.warning(f"[视频识别] 视频文件过大: {file_size / 1024 / 1024:.2f} MB > {max_size_mb} MB")
            return ""

        # Read the whole file and encode it as a data URL.
        with open(save_path, "rb") as f:
            video_data = base64.b64encode(f.read()).decode()

        video_base64 = f"data:video/mp4;base64,{video_data}"

        # Cache for 10 minutes so follow-up questions reuse the same payload.
        if redis_cache and redis_cache.enabled and media_key:
            redis_cache.cache_media(media_key, video_base64, "video", ttl=600)
            logger.debug(f"[视频识别] 视频已缓存: {media_key[:20]}...")

        return video_base64

    except Exception as e:
        logger.error(f"[视频识别] 下载视频失败: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return ""
    finally:
        # Always remove the temp file — the original leaked it when the
        # base64 encode or the Redis cache write raised.
        if save_path:
            with suppress(OSError):
                Path(save_path).unlink()
|
||
|
||
async def _call_ai_api_with_video(self, user_message: str, video_base64: str, bot=None,
                                  from_wxid: str = None, chat_id: str = None,
                                  nickname: str = "", user_wxid: str = None,
                                  is_group: bool = False) -> str:
    """Call a Gemini-style API with a video, inheriting the full chat context.

    Builds one flattened prompt (persona + current time + persistent memory +
    recent history + the user's question) and sends it together with the
    video inline. Falls back to ``_call_ai_api_with_video_simple`` when no
    text content comes back.

    Args:
        user_message: The user's question about the video (may be empty).
        video_base64: Video payload, optionally with a ``data:`` prefix.
        bot: Unused here; kept for signature parity with other call paths.
        from_wxid: Chat id (room id for groups) used for history lookup.
        chat_id: Key for the private-chat short-term memory.
        nickname: Display name of the current user.
        user_wxid: Wxid of the individual sender.
        is_group: True when the message is from a group chat.

    Returns:
        Sanitized reply text; a fixed apology string when content is blocked
        by safety filters; ``""`` on errors.
    """
    try:
        video_config = self.config.get("video_recognition", {})

        # Use the video-recognition specific model/endpoint configuration.
        video_model = video_config.get("model", "gemini-3-pro-preview")
        api_url = video_config.get("api_url", "https://api.functen.cn/v1beta/models")
        api_key = video_config.get("api_key", self.config["api"]["api_key"])

        # Build the full API URL.
        full_url = f"{api_url}/{video_model}:generateContent"

        # Build the system prompt (kept consistent with _call_ai_api).
        system_content = self.system_prompt
        current_time = datetime.now()
        weekday_map = {
            0: "星期一", 1: "星期二", 2: "星期三", 3: "星期四",
            4: "星期五", 5: "星期六", 6: "星期日"
        }
        weekday = weekday_map[current_time.weekday()]
        time_str = current_time.strftime(f"%Y年%m月%d日 %H:%M:%S {weekday}")
        system_content += f"\n\n当前时间:{time_str}"

        if nickname:
            system_content += f"\n当前对话用户的昵称是:{nickname}"

        # Load persistent memories keyed by room (group) or user (private).
        memory_chat_id = from_wxid if is_group else user_wxid
        if memory_chat_id:
            persistent_memories = self._get_persistent_memories(memory_chat_id)
            if persistent_memories:
                system_content += "\n\n【持久记忆】以下是用户要求你记住的重要信息:\n"
                for m in persistent_memories:
                    mem_time = m['time'][:10] if m['time'] else ""
                    system_content += f"- [{mem_time}] {m['nickname']}: {m['content']}\n"

        # Build the history context.
        history_context = ""
        if is_group and from_wxid:
            # Group chats: load history from Redis/file storage.
            history_chat_id = self._get_group_history_chat_id(from_wxid, user_wxid or "")
            history = await self._load_history(history_chat_id)
            history = self._filter_history_by_window(history)
            max_context = self.config.get("history", {}).get("max_context", 50)
            recent_history = history[-max_context:] if len(history) > max_context else history

            if recent_history:
                history_context = "\n\n【最近的群聊记录】\n"
                for msg in recent_history:
                    msg_nickname = msg.get("nickname", "")
                    msg_content = msg.get("content", "")
                    if isinstance(msg_content, list):
                        # Multimodal content: keep the first text part only.
                        for item in msg_content:
                            if item.get("type") == "text":
                                msg_content = item.get("text", "")
                                break
                        else:
                            msg_content = "[图片]"
                    # Truncate long messages to 200 characters.
                    if len(str(msg_content)) > 200:
                        msg_content = str(msg_content)[:200] + "..."
                    history_context += f"[{msg_nickname}] {msg_content}\n"
        else:
            # Private chats: load from the in-process memory.
            if chat_id:
                memory_messages = self._get_memory_messages(chat_id)
                if memory_messages:
                    history_context = "\n\n【最近的对话记录】\n"
                    for msg in memory_messages[-20:]:  # last 20 messages
                        role = msg.get("role", "")
                        content = msg.get("content", "")
                        if isinstance(content, list):
                            for item in content:
                                if item.get("type") == "text":
                                    content = item.get("text", "")
                                    break
                            else:
                                content = "[图片]"
                        role_name = "用户" if role == "user" else "你"
                        if len(str(content)) > 200:
                            content = str(content)[:200] + "..."
                        history_context += f"[{role_name}] {content}\n"

        # Strip the data:video/mp4;base64, prefix to get raw base64.
        if video_base64.startswith("data:"):
            video_base64 = video_base64.split(",", 1)[1]

        # Assemble the full prompt (persona + history + current question).
        full_prompt = system_content + history_context + f"\n\n【当前】用户发送了一个视频并问:{user_message or '请描述这个视频的内容'}"

        # Build the Gemini-native request body.
        payload = {
            "contents": [
                {
                    "parts": [
                        {"text": full_prompt},
                        {
                            "inline_data": {
                                "mime_type": "video/mp4",
                                "data": video_base64
                            }
                        }
                    ]
                }
            ],
            "generationConfig": {
                "maxOutputTokens": video_config.get("max_tokens", 8192)
            }
        }

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}"
        }

        timeout = aiohttp.ClientTimeout(total=video_config.get("timeout", 360))

        # Optional SOCKS/HTTP proxy.
        connector = None
        proxy_config = self.config.get("proxy", {})
        if proxy_config.get("enabled", False) and PROXY_SUPPORT:
            proxy_type = proxy_config.get("type", "socks5").upper()
            proxy_host = proxy_config.get("host", "127.0.0.1")
            proxy_port = proxy_config.get("port", 7890)
            proxy_url = f"{proxy_type}://{proxy_host}:{proxy_port}"
            try:
                connector = ProxyConnector.from_url(proxy_url)
            except Exception as e:
                logger.warning(f"[视频识别] 代理配置失败: {e}")

        logger.info(f"[视频识别] 调用 Gemini API: {full_url}")
        logger.debug(f"[视频识别] 提示词长度: {len(full_prompt)} 字符")

        async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
            async with session.post(full_url, json=payload, headers=headers) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    logger.error(f"[视频识别] API 错误: {resp.status}, {error_text[:500]}")
                    return ""

                # Parse the Gemini response format.
                result = await resp.json()
                # Verbose logging of top-level keys for debugging.
                logger.info(f"[视频识别] API 响应 keys: {list(result.keys()) if isinstance(result, dict) else type(result)}")

                # Explicit error object from the API.
                if "error" in result:
                    logger.error(f"[视频识别] API 返回错误: {result['error']}")
                    return ""

                # promptFeedback carries prompt-level safety decisions.
                if "promptFeedback" in result:
                    feedback = result["promptFeedback"]
                    block_reason = feedback.get("blockReason", "")
                    if block_reason:
                        logger.warning(f"[视频识别] 请求被阻止,原因: {block_reason}")
                        logger.warning(f"[视频识别] 安全评级: {feedback.get('safetyRatings', [])}")
                        return "抱歉,视频内容无法分析(内容策略限制)。"

                # Concatenate all text parts from all candidates.
                full_content = ""
                if "candidates" in result and result["candidates"]:
                    logger.info(f"[视频识别] candidates 数量: {len(result['candidates'])}")
                    for i, candidate in enumerate(result["candidates"]):
                        # finishReason tells us why generation stopped.
                        finish_reason = candidate.get("finishReason", "")
                        if finish_reason:
                            logger.info(f"[视频识别] candidate[{i}] finishReason: {finish_reason}")
                            if finish_reason == "SAFETY":
                                logger.warning(f"[视频识别] 内容被安全过滤: {candidate.get('safetyRatings', [])}")
                                return "抱歉,视频内容无法分析。"

                        content = candidate.get("content", {})
                        parts = content.get("parts", [])
                        logger.info(f"[视频识别] candidate[{i}] parts 数量: {len(parts)}")
                        for part in parts:
                            if "text" in part:
                                full_content += part["text"]
                else:
                    # No candidates at all — dump the response for debugging.
                    logger.error(f"[视频识别] 响应中没有 candidates: {str(result)[:500]}")
                    # Possibly caused by an oversized context; log token usage.
                    if "usageMetadata" in result:
                        usage = result["usageMetadata"]
                        logger.warning(f"[视频识别] Token 使用: prompt={usage.get('promptTokenCount', 0)}, total={usage.get('totalTokenCount', 0)}")

                logger.info(f"[视频识别] AI 响应完成,长度: {len(full_content)}")

                # Empty content: retry once with a context-free request.
                if not full_content:
                    logger.info("[视频识别] 尝试简化请求重试...")
                    return await self._call_ai_api_with_video_simple(
                        user_message or "请描述这个视频的内容",
                        video_base64,
                        video_config
                    )

                return self._sanitize_llm_output(full_content)

    except Exception as e:
        logger.error(f"[视频识别] API 调用失败: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return ""
|
||
|
||
async def _call_ai_api_with_video_simple(self, user_message: str, video_base64: str, video_config: dict) -> str:
    """Fallback video call without any chat context (degraded retry path).

    Sends only the user's question plus the inline video to the
    Gemini-style ``generateContent`` endpoint and returns the first text
    part found, sanitized. Returns ``""`` on any failure.
    """
    try:
        model_name = video_config.get("model", "gemini-3-pro-preview")
        base_url = video_config.get("api_url", "https://api.functen.cn/v1beta/models")
        bearer_token = video_config.get("api_key", self.config["api"]["api_key"])
        endpoint = f"{base_url}/{model_name}:generateContent"

        # Minimal request: just the question and the video, no persona/history.
        video_part = {
            "inline_data": {
                "mime_type": "video/mp4",
                "data": video_base64,
            }
        }
        request_body = {
            "contents": [{"parts": [{"text": user_message}, video_part]}],
            "generationConfig": {
                "maxOutputTokens": video_config.get("max_tokens", 8192),
            },
        }
        request_headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {bearer_token}",
        }
        client_timeout = aiohttp.ClientTimeout(total=video_config.get("timeout", 360))

        logger.info(f"[视频识别-简化] 调用 API: {endpoint}")

        async with aiohttp.ClientSession(timeout=client_timeout) as session:
            async with session.post(endpoint, json=request_body, headers=request_headers) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    logger.error(f"[视频识别-简化] API 错误: {resp.status}, {error_text[:300]}")
                    return ""

                result = await resp.json()
                logger.info(f"[视频识别-简化] API 响应 keys: {list(result.keys())}")

                # Return the first text part from any candidate.
                for candidate in result.get("candidates") or []:
                    for part in candidate.get("content", {}).get("parts", []):
                        if "text" in part:
                            text = part["text"]
                            logger.info(f"[视频识别-简化] 成功,长度: {len(text)}")
                            return self._sanitize_llm_output(text)

                logger.error(f"[视频识别-简化] 仍然没有 candidates: {str(result)[:300]}")
                return ""

    except Exception as e:
        logger.error(f"[视频识别-简化] 失败: {e}")
        return ""
|
||
|
||
def _should_reply_quote(self, message: dict, title_text: str) -> bool:
|
||
"""判断是否应该回复引用消息"""
|
||
is_group = message.get("IsGroup", False)
|
||
|
||
# 检查群聊/私聊开关
|
||
if is_group and not self.config["behavior"]["reply_group"]:
|
||
return False
|
||
if not is_group and not self.config["behavior"]["reply_private"]:
|
||
return False
|
||
|
||
trigger_mode = self.config["behavior"]["trigger_mode"]
|
||
|
||
# all模式:回复所有消息
|
||
if trigger_mode == "all":
|
||
return True
|
||
|
||
# mention模式:检查是否@了机器人
|
||
if trigger_mode == "mention":
|
||
if is_group:
|
||
ats = message.get("Ats", [])
|
||
if not ats:
|
||
return False
|
||
|
||
import tomllib
|
||
with open("main_config.toml", "rb") as f:
|
||
main_config = tomllib.load(f)
|
||
bot_wxid = main_config.get("Bot", {}).get("wxid", "")
|
||
|
||
return bot_wxid and bot_wxid in ats
|
||
else:
|
||
return True
|
||
|
||
# keyword模式:检查关键词
|
||
if trigger_mode == "keyword":
|
||
keywords = self.config["behavior"]["keywords"]
|
||
return any(kw in title_text for kw in keywords)
|
||
|
||
return False
|
||
|
||
async def _call_ai_api_with_image(
|
||
self,
|
||
user_message: str,
|
||
image_base64: str,
|
||
bot=None,
|
||
from_wxid: str = None,
|
||
chat_id: str = None,
|
||
nickname: str = "",
|
||
user_wxid: str = None,
|
||
is_group: bool = False,
|
||
*,
|
||
append_user_message: bool = True,
|
||
tool_query: str | None = None,
|
||
) -> str:
|
||
"""调用AI API(带图片)"""
|
||
api_config = self.config["api"]
|
||
all_tools = self._collect_tools()
|
||
tools = self._select_tools_for_message(all_tools, user_message=user_message, tool_query=tool_query)
|
||
logger.info(f"[图片] 收集到 {len(all_tools)} 个工具函数,本次启用 {len(tools)} 个")
|
||
if tools:
|
||
tool_names = [t["function"]["name"] for t in tools]
|
||
logger.info(f"[图片] 本次启用工具: {tool_names}")
|
||
|
||
# 构建消息列表
|
||
system_content = self.system_prompt
|
||
|
||
# 添加当前时间信息
|
||
current_time = datetime.now()
|
||
weekday_map = {
|
||
0: "星期一", 1: "星期二", 2: "星期三", 3: "星期四",
|
||
4: "星期五", 5: "星期六", 6: "星期日"
|
||
}
|
||
weekday = weekday_map[current_time.weekday()]
|
||
time_str = current_time.strftime(f"%Y年%m月%d日 %H:%M:%S {weekday}")
|
||
system_content += f"\n\n当前时间:{time_str}"
|
||
|
||
if nickname:
|
||
system_content += f"\n当前对话用户的昵称是:{nickname}"
|
||
# 加载持久记忆(与文本模式一致)
|
||
memory_chat_id = from_wxid if is_group else user_wxid
|
||
if memory_chat_id:
|
||
persistent_memories = self._get_persistent_memories(memory_chat_id)
|
||
if persistent_memories:
|
||
system_content += "\n\n【持久记忆】以下是用户要求你记住的重要信息:\n"
|
||
for m in persistent_memories:
|
||
mem_time = m['time'][:10] if m['time'] else ""
|
||
system_content += f"- [{mem_time}] {m['nickname']}: {m['content']}\n"
|
||
|
||
messages = [{"role": "system", "content": system_content}]
|
||
|
||
# 添加历史上下文
|
||
if is_group and from_wxid:
|
||
history_chat_id = self._get_group_history_chat_id(from_wxid, user_wxid or "")
|
||
history = await self._load_history(history_chat_id)
|
||
history = self._filter_history_by_window(history)
|
||
max_context = self.config.get("history", {}).get("max_context", 50)
|
||
recent_history = history[-max_context:] if len(history) > max_context else history
|
||
self._append_group_history_messages(messages, recent_history)
|
||
else:
|
||
if chat_id:
|
||
memory_messages = self._get_memory_messages(chat_id)
|
||
if memory_messages and len(memory_messages) > 1:
|
||
messages.extend(memory_messages[:-1])
|
||
|
||
# 添加当前用户消息(带图片)
|
||
if append_user_message:
|
||
text_value = f"[{nickname}] {user_message}" if is_group and nickname else user_message
|
||
messages.append({
|
||
"role": "user",
|
||
"content": [
|
||
{"type": "text", "text": text_value},
|
||
{"type": "image_url", "image_url": {"url": image_base64}}
|
||
]
|
||
})
|
||
|
||
payload = {
|
||
"model": api_config["model"],
|
||
"messages": messages,
|
||
"stream": True,
|
||
"max_tokens": api_config.get("max_tokens", 4096) # 防止回复被截断
|
||
}
|
||
|
||
if tools:
|
||
payload["tools"] = tools
|
||
|
||
headers = {
|
||
"Content-Type": "application/json",
|
||
"Authorization": f"Bearer {api_config['api_key']}"
|
||
}
|
||
|
||
timeout = aiohttp.ClientTimeout(total=api_config["timeout"])
|
||
|
||
# 配置代理
|
||
connector = None
|
||
proxy_config = self.config.get("proxy", {})
|
||
if proxy_config.get("enabled", False):
|
||
proxy_type = proxy_config.get("type", "socks5").upper()
|
||
proxy_host = proxy_config.get("host", "127.0.0.1")
|
||
proxy_port = proxy_config.get("port", 7890)
|
||
proxy_username = proxy_config.get("username")
|
||
proxy_password = proxy_config.get("password")
|
||
|
||
if proxy_username and proxy_password:
|
||
proxy_url = f"{proxy_type}://{proxy_username}:{proxy_password}@{proxy_host}:{proxy_port}"
|
||
else:
|
||
proxy_url = f"{proxy_type}://{proxy_host}:{proxy_port}"
|
||
|
||
if PROXY_SUPPORT:
|
||
try:
|
||
connector = ProxyConnector.from_url(proxy_url)
|
||
logger.debug(f"使用代理: {proxy_type}://{proxy_host}:{proxy_port}")
|
||
except Exception as e:
|
||
logger.warning(f"代理配置失败,将直连: {e}")
|
||
connector = None
|
||
else:
|
||
logger.warning("代理功能不可用(aiohttp_socks 未安装),将直连")
|
||
connector = None
|
||
|
||
try:
|
||
async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
|
||
async with session.post(api_config["url"], json=payload, headers=headers) as resp:
|
||
if resp.status != 200:
|
||
error_text = await resp.text()
|
||
logger.error(f"API返回错误状态码: {resp.status}, 响应: {error_text}")
|
||
raise Exception(f"API错误 {resp.status}: {error_text}")
|
||
|
||
# 流式接收响应
|
||
import json
|
||
full_content = ""
|
||
tool_calls_dict = {} # 使用字典来组装工具调用 {index: tool_call}
|
||
tool_call_hint_sent = False # 是否已发送工具调用提示
|
||
|
||
async for line in resp.content:
|
||
line = line.decode('utf-8').strip()
|
||
if not line or line == "data: [DONE]":
|
||
continue
|
||
|
||
if line.startswith("data: "):
|
||
try:
|
||
data = json.loads(line[6:])
|
||
choices = data.get("choices", [])
|
||
if not choices:
|
||
continue
|
||
|
||
delta = choices[0].get("delta", {})
|
||
content = delta.get("content", "")
|
||
if content:
|
||
full_content += content
|
||
|
||
# 收集工具调用(增量式组装)
|
||
if delta.get("tool_calls"):
|
||
# 第一次检测到工具调用时,如果有文本内容则立即发送
|
||
if not tool_call_hint_sent and bot and from_wxid:
|
||
tool_call_hint_sent = True
|
||
if full_content and full_content.strip():
|
||
preview = self._sanitize_llm_output(full_content)
|
||
if preview:
|
||
logger.info("[流式-图片] 检测到工具调用,先发送已有文本")
|
||
await bot.send_text(from_wxid, preview)
|
||
else:
|
||
logger.info("[流式-图片] 检测到工具调用,但文本清洗后为空(可能为思维链/无有效正文),跳过发送")
|
||
else:
|
||
logger.info("[流式-图片] 检测到工具调用,AI 未输出文本")
|
||
|
||
for tool_call_delta in delta["tool_calls"]:
|
||
index = tool_call_delta.get("index", 0)
|
||
|
||
# 初始化工具调用
|
||
if index not in tool_calls_dict:
|
||
tool_calls_dict[index] = {
|
||
"id": "",
|
||
"type": "function",
|
||
"function": {
|
||
"name": "",
|
||
"arguments": ""
|
||
}
|
||
}
|
||
|
||
# 更新 id
|
||
if "id" in tool_call_delta:
|
||
tool_calls_dict[index]["id"] = tool_call_delta["id"]
|
||
|
||
# 更新 type
|
||
if "type" in tool_call_delta:
|
||
tool_calls_dict[index]["type"] = tool_call_delta["type"]
|
||
|
||
# 更新 function
|
||
if "function" in tool_call_delta:
|
||
func_delta = tool_call_delta["function"]
|
||
if "name" in func_delta:
|
||
tool_calls_dict[index]["function"]["name"] += func_delta["name"]
|
||
if "arguments" in func_delta:
|
||
tool_calls_dict[index]["function"]["arguments"] += func_delta["arguments"]
|
||
except Exception as e:
|
||
logger.debug(f"解析流式数据失败: {e}")
|
||
pass
|
||
|
||
# 转换为列表
|
||
tool_calls_data = [tool_calls_dict[i] for i in sorted(tool_calls_dict.keys())] if tool_calls_dict else []
|
||
|
||
# 检查是否有函数调用
|
||
if tool_calls_data:
|
||
# 提示已在流式处理中发送,直接启动异步工具执行
|
||
logger.info(f"[图片] 启动异步工具执行,共 {len(tool_calls_data)} 个工具")
|
||
try:
|
||
await self._record_tool_calls_to_context(
|
||
tool_calls_data,
|
||
from_wxid=from_wxid,
|
||
chat_id=chat_id,
|
||
is_group=is_group,
|
||
user_wxid=user_wxid,
|
||
)
|
||
except Exception as e:
|
||
logger.debug(f"[图片] 记录工具调用到上下文失败: {e}")
|
||
asyncio.create_task(
|
||
self._execute_tools_async_with_image(
|
||
tool_calls_data, bot, from_wxid, chat_id,
|
||
user_wxid, nickname, is_group, messages, image_base64
|
||
)
|
||
)
|
||
return None
|
||
|
||
# 检查是否包含错误的工具调用格式
|
||
if "<tool_code>" in full_content or "print(" in full_content and "flow2_ai_image_generation" in full_content:
|
||
logger.warning("检测到模型输出了错误的工具调用格式,拦截并返回提示")
|
||
return "抱歉,我遇到了一些技术问题,请重新描述一下你的需求~"
|
||
|
||
return self._sanitize_llm_output(full_content)
|
||
|
||
except Exception as e:
|
||
logger.error(f"调用AI API失败: {e}")
|
||
raise
|
||
|
||
async def _send_chat_records(self, bot, from_wxid: str, title: str, content: str):
|
||
"""发送聊天记录格式消息"""
|
||
try:
|
||
import uuid
|
||
import time
|
||
import hashlib
|
||
import xml.etree.ElementTree as ET
|
||
|
||
is_group = from_wxid.endswith("@chatroom")
|
||
|
||
# 自动分割内容
|
||
max_length = 800
|
||
content_parts = []
|
||
|
||
if len(content) <= max_length:
|
||
content_parts = [content]
|
||
else:
|
||
lines = content.split('\n')
|
||
current_part = ""
|
||
|
||
for line in lines:
|
||
if len(current_part + line + '\n') > max_length:
|
||
if current_part:
|
||
content_parts.append(current_part.strip())
|
||
current_part = line + '\n'
|
||
else:
|
||
content_parts.append(line[:max_length])
|
||
current_part = line[max_length:] + '\n'
|
||
else:
|
||
current_part += line + '\n'
|
||
|
||
if current_part.strip():
|
||
content_parts.append(current_part.strip())
|
||
|
||
recordinfo = ET.Element("recordinfo")
|
||
info_el = ET.SubElement(recordinfo, "info")
|
||
info_el.text = title
|
||
is_group_el = ET.SubElement(recordinfo, "isChatRoom")
|
||
is_group_el.text = "1" if is_group else "0"
|
||
datalist = ET.SubElement(recordinfo, "datalist")
|
||
datalist.set("count", str(len(content_parts)))
|
||
desc_el = ET.SubElement(recordinfo, "desc")
|
||
desc_el.text = title
|
||
fromscene_el = ET.SubElement(recordinfo, "fromscene")
|
||
fromscene_el.text = "3"
|
||
|
||
for i, part in enumerate(content_parts):
|
||
di = ET.SubElement(datalist, "dataitem")
|
||
di.set("datatype", "1")
|
||
di.set("dataid", uuid.uuid4().hex)
|
||
|
||
src_local_id = str((int(time.time() * 1000) % 90000) + 10000)
|
||
new_msg_id = str(int(time.time() * 1000) + i)
|
||
create_time = str(int(time.time()) - len(content_parts) + i)
|
||
|
||
ET.SubElement(di, "srcMsgLocalid").text = src_local_id
|
||
ET.SubElement(di, "sourcetime").text = time.strftime("%Y-%m-%d %H:%M", time.localtime(int(create_time)))
|
||
ET.SubElement(di, "fromnewmsgid").text = new_msg_id
|
||
ET.SubElement(di, "srcMsgCreateTime").text = create_time
|
||
ET.SubElement(di, "sourcename").text = "AI助手"
|
||
ET.SubElement(di, "sourceheadurl").text = ""
|
||
ET.SubElement(di, "datatitle").text = part
|
||
ET.SubElement(di, "datadesc").text = part
|
||
ET.SubElement(di, "datafmt").text = "text"
|
||
ET.SubElement(di, "ischatroom").text = "1" if is_group else "0"
|
||
|
||
dataitemsource = ET.SubElement(di, "dataitemsource")
|
||
ET.SubElement(dataitemsource, "hashusername").text = hashlib.sha256(from_wxid.encode("utf-8")).hexdigest()
|
||
|
||
record_xml = ET.tostring(recordinfo, encoding="unicode")
|
||
|
||
appmsg_parts = [
|
||
"<appmsg appid=\"\" sdkver=\"0\">",
|
||
f"<title>{title}</title>",
|
||
f"<des>{title}</des>",
|
||
"<type>19</type>",
|
||
"<url>https://support.weixin.qq.com/cgi-bin/mmsupport-bin/readtemplate?t=page/favorite_record__w_unsupport</url>",
|
||
"<appattach><cdnthumbaeskey></cdnthumbaeskey><aeskey></aeskey></appattach>",
|
||
f"<recorditem><![CDATA[{record_xml}]]></recorditem>",
|
||
"<percent>0</percent>",
|
||
"</appmsg>"
|
||
]
|
||
appmsg_xml = "".join(appmsg_parts)
|
||
|
||
await bot._send_data_async(11214, {"to_wxid": from_wxid, "content": appmsg_xml})
|
||
logger.success(f"已发送聊天记录: {title}")
|
||
|
||
except Exception as e:
|
||
logger.error(f"发送聊天记录失败: {e}")
|
||
|
||
async def _process_image_to_history(self, bot, message: dict, content: str) -> bool:
|
||
"""处理图片/表情包并保存描述到 history(通用方法)"""
|
||
from_wxid = message.get("FromWxid", "")
|
||
sender_wxid = message.get("SenderWxid", "")
|
||
is_group = message.get("IsGroup", False)
|
||
user_wxid = sender_wxid if is_group else from_wxid
|
||
|
||
# 只处理群聊
|
||
if not is_group:
|
||
return True
|
||
|
||
# 检查是否启用图片描述功能
|
||
image_desc_config = self.config.get("image_description", {})
|
||
if not image_desc_config.get("enabled", True):
|
||
return True
|
||
|
||
try:
|
||
# 解析XML获取图片信息
|
||
root = ET.fromstring(content)
|
||
|
||
# 尝试查找 <img> 标签(图片消息)或 <emoji> 标签(表情包)
|
||
img = root.find(".//img")
|
||
if img is None:
|
||
img = root.find(".//emoji")
|
||
|
||
if img is None:
|
||
return True
|
||
|
||
cdnbigimgurl = img.get("cdnbigimgurl", "") or img.get("cdnurl", "")
|
||
aeskey = img.get("aeskey", "")
|
||
|
||
# 检查是否是表情包(有 cdnurl 但可能没有 aeskey)
|
||
is_emoji = img.tag == "emoji"
|
||
|
||
if not cdnbigimgurl:
|
||
return True
|
||
|
||
# 图片消息需要 aeskey,表情包不需要
|
||
if not is_emoji and not aeskey:
|
||
return True
|
||
|
||
# 获取用户昵称 - 使用缓存优化
|
||
nickname = await self._get_user_display_label(bot, from_wxid, user_wxid, is_group)
|
||
|
||
# 立即插入占位符到 history
|
||
placeholder_id = str(uuid.uuid4())
|
||
await self._add_to_history_with_id(from_wxid, nickname, "[图片: 处理中...]", placeholder_id)
|
||
logger.info(f"已插入图片占位符: {placeholder_id}")
|
||
|
||
# 将任务加入队列(不阻塞)
|
||
task = {
|
||
"bot": bot,
|
||
"from_wxid": from_wxid,
|
||
"nickname": nickname,
|
||
"cdnbigimgurl": cdnbigimgurl,
|
||
"aeskey": aeskey,
|
||
"is_emoji": is_emoji,
|
||
"placeholder_id": placeholder_id,
|
||
"config": image_desc_config
|
||
}
|
||
await self.image_desc_queue.put(task)
|
||
logger.info(f"图片描述任务已加入队列,当前队列长度: {self.image_desc_queue.qsize()}")
|
||
|
||
return True
|
||
|
||
except Exception as e:
|
||
logger.error(f"处理图片消息失败: {e}")
|
||
return True
|
||
|
||
async def _image_desc_worker(self):
|
||
"""图片描述工作协程,从队列中取任务并处理"""
|
||
while True:
|
||
try:
|
||
task = await self.image_desc_queue.get()
|
||
except asyncio.CancelledError:
|
||
logger.info("图片描述工作协程收到取消信号,退出")
|
||
break
|
||
|
||
try:
|
||
await self._generate_and_update_image_description(
|
||
task["bot"], task["from_wxid"], task["nickname"],
|
||
task["cdnbigimgurl"], task["aeskey"], task["is_emoji"],
|
||
task["placeholder_id"], task["config"]
|
||
)
|
||
except asyncio.CancelledError:
|
||
raise
|
||
except Exception as e:
|
||
logger.error(f"图片描述工作协程异常: {e}")
|
||
finally:
|
||
try:
|
||
self.image_desc_queue.task_done()
|
||
except ValueError:
|
||
pass
|
||
|
||
async def _generate_and_update_image_description(self, bot, from_wxid: str, nickname: str,
|
||
cdnbigimgurl: str, aeskey: str, is_emoji: bool,
|
||
placeholder_id: str, image_desc_config: dict):
|
||
"""异步生成图片描述并更新 history"""
|
||
try:
|
||
# 下载并编码图片/表情包
|
||
if is_emoji:
|
||
image_base64 = await self._download_emoji_and_encode(cdnbigimgurl)
|
||
else:
|
||
image_base64 = await self._download_and_encode_image(bot, cdnbigimgurl, aeskey)
|
||
|
||
if not image_base64:
|
||
logger.warning(f"{'表情包' if is_emoji else '图片'}下载失败")
|
||
await self._update_history_by_id(from_wxid, placeholder_id, "[图片]")
|
||
return
|
||
|
||
# 调用 AI 生成图片描述
|
||
description_prompt = image_desc_config.get("prompt", "请用一句话简洁地描述这张图片的主要内容。")
|
||
description = await self._generate_image_description(image_base64, description_prompt, image_desc_config)
|
||
|
||
if description:
|
||
cleaned_description = self._sanitize_llm_output(description)
|
||
await self._update_history_by_id(from_wxid, placeholder_id, f"[图片: {cleaned_description}]")
|
||
logger.success(f"已更新图片描述: {nickname} - {cleaned_description[:30]}...")
|
||
else:
|
||
await self._update_history_by_id(from_wxid, placeholder_id, "[图片]")
|
||
logger.warning(f"图片描述生成失败")
|
||
|
||
except asyncio.CancelledError:
|
||
raise
|
||
except Exception as e:
|
||
logger.error(f"异步生成图片描述失败: {e}")
|
||
await self._update_history_by_id(from_wxid, placeholder_id, "[图片]")
|
||
|
||
@on_image_message(priority=15)
|
||
async def handle_image_message(self, bot, message: dict):
|
||
"""处理直接发送的图片消息(生成描述并保存到 history,不触发 AI 回复)"""
|
||
logger.info("AIChat: handle_image_message 被调用")
|
||
content = message.get("Content", "")
|
||
return await self._process_image_to_history(bot, message, content)
|
||
|
||
@on_emoji_message(priority=15)
|
||
async def handle_emoji_message(self, bot, message: dict):
|
||
"""处理表情包消息(生成描述并保存到 history,不触发 AI 回复)"""
|
||
logger.info("AIChat: handle_emoji_message 被调用")
|
||
content = message.get("Content", "")
|
||
return await self._process_image_to_history(bot, message, content)
|