""" AI 聊天插件 支持自定义模型、API 和人设 支持 Redis 存储对话历史和限流 """ import asyncio import tomllib import aiohttp import json import re import time from pathlib import Path from datetime import datetime from loguru import logger from utils.plugin_base import PluginBase from utils.decorators import on_text_message, on_quote_message, on_image_message, on_emoji_message from utils.redis_cache import get_cache from utils.llm_tooling import ToolResult, collect_tools_with_plugins, collect_tools, get_tool_schema_map, validate_tool_arguments import xml.etree.ElementTree as ET import base64 import uuid # 可选导入代理支持 try: from aiohttp_socks import ProxyConnector PROXY_SUPPORT = True except ImportError: PROXY_SUPPORT = False logger.warning("aiohttp_socks 未安装,代理功能将不可用") class AIChat(PluginBase): """AI 聊天插件""" # 插件元数据 description = "AI 聊天插件,支持自定义模型和人设" author = "ShiHao" version = "1.0.0" def __init__(self): super().__init__() self.config = None self.system_prompt = "" self.memory = {} # 存储每个会话的记忆 {chat_id: [messages]} self.history_dir = None # 历史记录目录 self.history_locks = {} # 每个会话一把锁 self.image_desc_queue = asyncio.Queue() # 图片描述任务队列 self.image_desc_workers = [] # 工作协程列表 self.persistent_memory_db = None # 持久记忆数据库路径 self.store = None # ContextStore 实例(统一存储) self._chatroom_member_cache = {} # {chatroom_id: (ts, {wxid: display_name})} self._chatroom_member_cache_locks = {} # {chatroom_id: asyncio.Lock} self._chatroom_member_cache_ttl_seconds = 3600 # 群名片缓存1小时,减少协议 API 调用 self._intent_cache = {} # {normalized_text: (ts, intent)} async def async_init(self): """插件异步初始化""" # 读取配置 config_path = Path(__file__).parent / "config.toml" with open(config_path, "rb") as f: self.config = tomllib.load(f) # 读取人设 prompt_file = self.config["prompt"]["system_prompt_file"] prompt_path = Path(__file__).parent / "prompts" / prompt_file if prompt_path.exists(): with open(prompt_path, "r", encoding="utf-8") as f: self.system_prompt = f.read().strip() logger.success(f"已加载人设: {prompt_file}") else: logger.warning(f"人设文件不存在: {prompt_file},使用默认人设") self.system_prompt = "你是一个友好的 AI 助手。" # 检查代理配置 proxy_config = self.config.get("proxy", {}) if proxy_config.get("enabled", False): proxy_type = proxy_config.get("type", "socks5") proxy_host = proxy_config.get("host", "127.0.0.1") proxy_port = proxy_config.get("port", 7890) logger.info(f"AI 聊天插件已启用代理: {proxy_type}://{proxy_host}:{proxy_port}") # 初始化历史记录目录 history_config = self.config.get("history", {}) if history_config.get("enabled", True): history_dir_name = history_config.get("history_dir", "history") self.history_dir = Path(__file__).parent / history_dir_name self.history_dir.mkdir(exist_ok=True) logger.info(f"历史记录目录: {self.history_dir}") # 启动图片描述工作协程(并发数为2) for i in range(2): worker = asyncio.create_task(self._image_desc_worker()) self.image_desc_workers.append(worker) logger.info("已启动 2 个图片描述工作协程") # 初始化持久记忆数据库与统一存储 from utils.context_store import ContextStore db_dir = Path(__file__).parent / "data" db_dir.mkdir(exist_ok=True) self.persistent_memory_db = db_dir / "persistent_memory.db" self.store = ContextStore( self.config, self.history_dir, self.memory, self.history_locks, self.persistent_memory_db, ) self.store.init_persistent_memory_db() logger.info(f"AI 聊天插件已加载,模型: {self.config['api']['model']}") async def on_disable(self): """插件禁用时调用,清理后台任务和队列""" await super().on_disable() # 取消图片描述工作协程,避免重载后叠加 if self.image_desc_workers: for worker in self.image_desc_workers: worker.cancel() await asyncio.gather(*self.image_desc_workers, return_exceptions=True) self.image_desc_workers.clear() # 清空图片描述队列 try: while self.image_desc_queue and not self.image_desc_queue.empty(): self.image_desc_queue.get_nowait() self.image_desc_queue.task_done() except Exception: pass self.image_desc_queue = asyncio.Queue() logger.info("AIChat 已清理后台图片描述任务") def _add_persistent_memory(self, chat_id: str, chat_type: str, user_wxid: str, user_nickname: str, content: str) -> int: """添加持久记忆,返回记忆ID(委托 ContextStore)""" if not self.store: return -1 return self.store.add_persistent_memory(chat_id, chat_type, user_wxid, user_nickname, content) def _get_persistent_memories(self, chat_id: str) -> list: """获取指定会话的所有持久记忆(委托 ContextStore)""" if not self.store: return [] return self.store.get_persistent_memories(chat_id) def _delete_persistent_memory(self, chat_id: str, memory_id: int) -> bool: """删除指定的持久记忆(委托 ContextStore)""" if not self.store: return False return self.store.delete_persistent_memory(chat_id, memory_id) def _clear_persistent_memories(self, chat_id: str) -> int: """清空指定会话的所有持久记忆(委托 ContextStore)""" if not self.store: return 0 return self.store.clear_persistent_memories(chat_id) def _get_chat_id(self, from_wxid: str, sender_wxid: str = None, is_group: bool = False) -> str: """获取会话ID""" if is_group: # 群聊使用 "群ID:用户ID" 组合,确保每个用户有独立的对话记忆 user_wxid = sender_wxid or from_wxid return f"{from_wxid}:{user_wxid}" else: return sender_wxid or from_wxid # 私聊使用用户ID def _get_group_history_chat_id(self, from_wxid: str, user_wxid: str = None) -> str: """获取群聊 history 的会话ID(可配置为全群共享或按用户隔离)""" if not from_wxid: return "" history_config = (self.config or {}).get("history", {}) scope = str(history_config.get("scope", "chatroom") or "chatroom").strip().lower() if scope in ("per_user", "user", "peruser"): if not user_wxid: return from_wxid return self._get_chat_id(from_wxid, user_wxid, is_group=True) return from_wxid def _should_capture_group_history(self, *, is_triggered: bool) -> bool: """判断群聊消息是否需要写入 history(减少无关上下文污染)""" history_config = (self.config or {}).get("history", {}) capture = str(history_config.get("capture", "all") or "all").strip().lower() if capture in ("none", "off", "disable", "disabled"): return False if capture in ("reply", "ai_only", "triggered"): return bool(is_triggered) return True def _parse_history_timestamp(self, ts) -> float | None: if ts is None: return None if isinstance(ts, (int, float)): return float(ts) if isinstance(ts, str): s = ts.strip() if not s: return None try: return float(s) except Exception: pass try: return datetime.fromisoformat(s).timestamp() except Exception: return None return None def _filter_history_by_window(self, history: list) -> list: history_config = (self.config or {}).get("history", {}) window_seconds = history_config.get("context_window_seconds", None) if window_seconds is None: window_seconds = history_config.get("window_seconds", 0) try: window_seconds = float(window_seconds or 0) except Exception: window_seconds = 0 if window_seconds <= 0: return history cutoff = time.time() - window_seconds filtered = [] for msg in history or []: ts = self._parse_history_timestamp((msg or {}).get("timestamp")) if ts is None or ts >= cutoff: filtered.append(msg) return filtered def _sanitize_speaker_name(self, name: str) -> str: """清洗昵称,避免破坏历史格式(如 [name] 前缀)。""" if name is None: return "" s = str(name).strip() if not s: return "" s = s.replace("\r", " ").replace("\n", " ") s = re.sub(r"\s{2,}", " ", s) # 避免与历史前缀 [xxx] 冲突 s = s.replace("[", "(").replace("]", ")") return s.strip() def _combine_display_and_nickname(self, display_name: str, wechat_nickname: str) -> str: display_name = self._sanitize_speaker_name(display_name) wechat_nickname = self._sanitize_speaker_name(wechat_nickname) # 重要:群昵称(群名片) 与 微信昵称(全局) 是两个不同概念,尽量同时给 AI。 if display_name and wechat_nickname: return f"群昵称={display_name} | 微信昵称={wechat_nickname}" if display_name: return f"群昵称={display_name}" if wechat_nickname: return f"微信昵称={wechat_nickname}" return "" def _get_chatroom_member_lock(self, chatroom_id: str) -> asyncio.Lock: lock = self._chatroom_member_cache_locks.get(chatroom_id) if lock is None: lock = asyncio.Lock() self._chatroom_member_cache_locks[chatroom_id] = lock return lock async def _get_group_display_name(self, bot, chatroom_id: str, user_wxid: str, *, force_refresh: bool = False) -> str: """获取群名片(群内昵称)。失败时返回空串。""" if not chatroom_id or not user_wxid: return "" if not hasattr(bot, "get_chatroom_members"): return "" now = time.time() if not force_refresh: cached = self._chatroom_member_cache.get(chatroom_id) if cached: ts, member_map = cached if now - float(ts or 0) < float(self._chatroom_member_cache_ttl_seconds or 0): return self._sanitize_speaker_name(member_map.get(user_wxid, "")) lock = self._get_chatroom_member_lock(chatroom_id) async with lock: now = time.time() if not force_refresh: cached = self._chatroom_member_cache.get(chatroom_id) if cached: ts, member_map = cached if now - float(ts or 0) < float(self._chatroom_member_cache_ttl_seconds or 0): return self._sanitize_speaker_name(member_map.get(user_wxid, "")) try: # 群成员列表可能较大,避免长期阻塞消息处理 members = await asyncio.wait_for(bot.get_chatroom_members(chatroom_id), timeout=8) except Exception as e: logger.debug(f"获取群成员列表失败: {chatroom_id}, {e}") return "" member_map = {} try: for m in members or []: wxid = (m.get("wxid") or "").strip() if not wxid: continue display_name = m.get("display_name") or m.get("displayName") or "" member_map[wxid] = str(display_name or "").strip() except Exception as e: logger.debug(f"解析群成员列表失败: {chatroom_id}, {e}") self._chatroom_member_cache[chatroom_id] = (time.time(), member_map) return self._sanitize_speaker_name(member_map.get(user_wxid, "")) async def _get_user_display_label(self, bot, from_wxid: str, user_wxid: str, is_group: bool) -> str: """用于历史记录:群聊优先使用群名片,其次微信昵称。""" if not is_group: return "" wechat_nickname = await self._get_user_nickname(bot, from_wxid, user_wxid, is_group) group_display = await self._get_group_display_name(bot, from_wxid, user_wxid) return self._combine_display_and_nickname(group_display, wechat_nickname) or wechat_nickname or user_wxid async def _get_user_nickname(self, bot, from_wxid: str, user_wxid: str, is_group: bool) -> str: """ 获取用户昵称,优先使用 Redis 缓存 Args: bot: WechatHookClient 实例 from_wxid: 消息来源(群聊ID或私聊用户ID) user_wxid: 用户wxid is_group: 是否群聊 Returns: 用户昵称 """ if not is_group: return "" nickname = "" # 1. 优先从 Redis 缓存获取 redis_cache = get_cache() if redis_cache and redis_cache.enabled: cached_info = redis_cache.get_user_basic_info(from_wxid, user_wxid) if cached_info and cached_info.get("nickname"): logger.debug(f"[缓存命中] 用户昵称: {user_wxid} -> {cached_info['nickname']}") return cached_info["nickname"] # 2. 缓存未命中,调用 API 获取 try: user_info = await bot.get_user_info_in_chatroom(from_wxid, user_wxid) if user_info and user_info.get("nickName", {}).get("string"): nickname = user_info["nickName"]["string"] # 存入缓存 if redis_cache and redis_cache.enabled: redis_cache.set_user_info(from_wxid, user_wxid, user_info) logger.debug(f"[已缓存] 用户昵称: {user_wxid} -> {nickname}") return nickname except Exception as e: logger.warning(f"API获取用户昵称失败: {e}") # 3. 从 MessageLogger 数据库查询 if not nickname: try: from plugins.MessageLogger.main import MessageLogger msg_logger = MessageLogger.get_instance() if msg_logger: with msg_logger.get_db_connection() as conn: with conn.cursor() as cursor: cursor.execute( "SELECT nickname FROM messages WHERE sender_wxid = %s AND nickname != '' ORDER BY create_time DESC LIMIT 1", (user_wxid,) ) result = cursor.fetchone() if result: nickname = result[0] except Exception as e: logger.debug(f"从数据库获取昵称失败: {e}") # 4. 最后降级使用 wxid if not nickname: nickname = user_wxid or "未知用户" return nickname def _check_rate_limit(self, user_wxid: str) -> tuple: """ 检查用户是否超过限流 Args: user_wxid: 用户wxid Returns: (是否允许, 剩余次数, 重置时间秒数) """ rate_limit_config = self.config.get("rate_limit", {}) if not rate_limit_config.get("enabled", True): return (True, 999, 0) redis_cache = get_cache() if not redis_cache or not redis_cache.enabled: return (True, 999, 0) # Redis 不可用时不限流 limit = rate_limit_config.get("ai_chat_limit", 20) window = rate_limit_config.get("ai_chat_window", 60) return redis_cache.check_rate_limit(user_wxid, limit, window, "ai_chat") def _add_to_memory(self, chat_id: str, role: str, content, image_base64: str = None): """ 添加消息到记忆 Args: chat_id: 会话ID role: 角色 (user/assistant) content: 消息内容(可以是字符串或列表) image_base64: 可选的图片base64数据 """ if not self.store: return self.store.add_private_message(chat_id, role, content, image_base64=image_base64) def _get_memory_messages(self, chat_id: str) -> list: """获取记忆中的消息""" if not self.store: return [] return self.store.get_private_messages(chat_id) def _clear_memory(self, chat_id: str): """清空指定会话的记忆""" if not self.store: return self.store.clear_private_messages(chat_id) async def _download_and_encode_image(self, bot, cdnurl: str, aeskey: str) -> str: """下载图片并转换为base64,优先从缓存获取""" try: # 1. 优先从 Redis 缓存获取 from utils.redis_cache import RedisCache redis_cache = get_cache() if redis_cache and redis_cache.enabled: media_key = RedisCache.generate_media_key(cdnurl, aeskey) if media_key: cached_data = redis_cache.get_cached_media(media_key, "image") if cached_data: logger.debug(f"[缓存命中] 图片从 Redis 获取: {media_key[:20]}...") return cached_data # 2. 缓存未命中,下载图片 logger.debug(f"[缓存未命中] 开始下载图片...") temp_dir = Path(__file__).parent / "temp" temp_dir.mkdir(exist_ok=True) filename = f"temp_{uuid.uuid4().hex[:8]}.jpg" save_path = str((temp_dir / filename).resolve()) success = await bot.cdn_download(cdnurl, aeskey, save_path, file_type=2) if not success: success = await bot.cdn_download(cdnurl, aeskey, save_path, file_type=1) if not success: return "" # 等待文件写入完成 import os import asyncio for _ in range(20): # 最多等待10秒 if os.path.exists(save_path) and os.path.getsize(save_path) > 0: break await asyncio.sleep(0.5) if not os.path.exists(save_path): return "" with open(save_path, "rb") as f: image_data = base64.b64encode(f.read()).decode() base64_result = f"data:image/jpeg;base64,{image_data}" # 3. 缓存到 Redis(供后续使用) if redis_cache and redis_cache.enabled and media_key: redis_cache.cache_media(media_key, base64_result, "image", ttl=300) logger.debug(f"[已缓存] 图片缓存到 Redis: {media_key[:20]}...") try: Path(save_path).unlink() except: pass return base64_result except Exception as e: logger.error(f"下载图片失败: {e}") return "" async def _download_emoji_and_encode(self, cdn_url: str, max_retries: int = 3) -> str: """下载表情包并转换为base64(HTTP 直接下载,带重试机制),优先从缓存获取""" # 替换 HTML 实体 cdn_url = cdn_url.replace("&", "&") # 1. 优先从 Redis 缓存获取 from utils.redis_cache import RedisCache redis_cache = get_cache() media_key = RedisCache.generate_media_key(cdnurl=cdn_url) if redis_cache and redis_cache.enabled and media_key: cached_data = redis_cache.get_cached_media(media_key, "emoji") if cached_data: logger.debug(f"[缓存命中] 表情包从 Redis 获取: {media_key[:20]}...") return cached_data # 2. 缓存未命中,下载表情包 logger.debug(f"[缓存未命中] 开始下载表情包...") temp_dir = Path(__file__).parent / "temp" temp_dir.mkdir(exist_ok=True) filename = f"temp_{uuid.uuid4().hex[:8]}.gif" save_path = temp_dir / filename last_error = None for attempt in range(max_retries): try: # 使用 aiohttp 下载,每次重试增加超时时间 timeout = aiohttp.ClientTimeout(total=30 + attempt * 15) # 配置代理 connector = None proxy_config = self.config.get("proxy", {}) if proxy_config.get("enabled", False): proxy_type = proxy_config.get("type", "socks5").upper() proxy_host = proxy_config.get("host", "127.0.0.1") proxy_port = proxy_config.get("port", 7890) proxy_username = proxy_config.get("username") proxy_password = proxy_config.get("password") if proxy_username and proxy_password: proxy_url = f"{proxy_type}://{proxy_username}:{proxy_password}@{proxy_host}:{proxy_port}" else: proxy_url = f"{proxy_type}://{proxy_host}:{proxy_port}" if PROXY_SUPPORT: try: connector = ProxyConnector.from_url(proxy_url) except: connector = None async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session: async with session.get(cdn_url) as response: if response.status == 200: content = await response.read() if len(content) == 0: logger.warning(f"表情包下载内容为空,重试 {attempt + 1}/{max_retries}") continue # 编码为 base64 image_data = base64.b64encode(content).decode() logger.debug(f"表情包下载成功,大小: {len(content)} 字节") base64_result = f"data:image/gif;base64,{image_data}" # 3. 缓存到 Redis(供后续使用) if redis_cache and redis_cache.enabled and media_key: redis_cache.cache_media(media_key, base64_result, "emoji", ttl=300) logger.debug(f"[已缓存] 表情包缓存到 Redis: {media_key[:20]}...") return base64_result else: logger.warning(f"表情包下载失败,状态码: {response.status},重试 {attempt + 1}/{max_retries}") except asyncio.TimeoutError: last_error = "请求超时" logger.warning(f"表情包下载超时,重试 {attempt + 1}/{max_retries}") except aiohttp.ClientError as e: last_error = str(e) logger.warning(f"表情包下载网络错误: {e},重试 {attempt + 1}/{max_retries}") except Exception as e: last_error = str(e) logger.warning(f"表情包下载异常: {e},重试 {attempt + 1}/{max_retries}") # 重试前等待(指数退避) if attempt < max_retries - 1: await asyncio.sleep(1 * (attempt + 1)) logger.error(f"表情包下载失败,已重试 {max_retries} 次: {last_error}") return "" async def _generate_image_description(self, image_base64: str, prompt: str, config: dict) -> str: """ 使用 AI 生成图片描述 Args: image_base64: 图片的 base64 数据 prompt: 描述提示词 config: 图片描述配置 Returns: 图片描述文本,失败返回空字符串 """ api_config = self.config["api"] description_model = config.get("model", api_config["model"]) # 构建消息 messages = [ { "role": "user", "content": [ {"type": "text", "text": prompt}, {"type": "image_url", "image_url": {"url": image_base64}} ] } ] payload = { "model": description_model, "messages": messages, "max_tokens": config.get("max_tokens", 1000), "stream": True } headers = { "Content-Type": "application/json", "Authorization": f"Bearer {api_config['api_key']}" } max_retries = int(config.get("retries", 2)) last_error = None for attempt in range(max_retries + 1): try: timeout = aiohttp.ClientTimeout(total=api_config["timeout"]) # 配置代理(每次重试单独构造 connector) connector = None proxy_config = self.config.get("proxy", {}) if proxy_config.get("enabled", False): proxy_type = proxy_config.get("type", "socks5").upper() proxy_host = proxy_config.get("host", "127.0.0.1") proxy_port = proxy_config.get("port", 7890) proxy_username = proxy_config.get("username") proxy_password = proxy_config.get("password") if proxy_username and proxy_password: proxy_url = f"{proxy_type}://{proxy_username}:{proxy_password}@{proxy_host}:{proxy_port}" else: proxy_url = f"{proxy_type}://{proxy_host}:{proxy_port}" if PROXY_SUPPORT: try: connector = ProxyConnector.from_url(proxy_url) except Exception as e: logger.warning(f"代理配置失败,将直连: {e}") connector = None async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session: async with session.post( api_config["url"], json=payload, headers=headers ) as resp: if resp.status != 200: error_text = await resp.text() raise Exception(f"图片描述 API 返回错误: {resp.status}, {error_text[:200]}") # 流式接收响应 description = "" async for line in resp.content: line = line.decode('utf-8').strip() if not line or line == "data: [DONE]": continue if line.startswith("data: "): try: data = json.loads(line[6:]) delta = data.get("choices", [{}])[0].get("delta", {}) content = delta.get("content", "") if content: description += content except Exception: pass logger.debug(f"图片描述生成成功: {description}") return description.strip() except asyncio.CancelledError: raise except (aiohttp.ClientError, asyncio.TimeoutError) as e: last_error = str(e) if attempt < max_retries: logger.warning(f"图片描述网络错误: {e},重试 {attempt + 1}/{max_retries}") await asyncio.sleep(1 * (attempt + 1)) continue except Exception as e: last_error = str(e) if attempt < max_retries: logger.warning(f"图片描述生成异常: {e},重试 {attempt + 1}/{max_retries}") await asyncio.sleep(1 * (attempt + 1)) continue logger.error(f"生成图片描述失败,已重试 {max_retries + 1} 次: {last_error}") return "" def _collect_tools_with_plugins(self) -> dict: """收集所有插件的 LLM 工具,并保留来源插件名""" from utils.plugin_manager import PluginManager tools_config = self.config.get("tools", {}) return collect_tools_with_plugins(tools_config, PluginManager().plugins) def _collect_tools(self): """收集所有插件的LLM工具(支持白名单/黑名单过滤)""" from utils.plugin_manager import PluginManager tools_config = self.config.get("tools", {}) return collect_tools(tools_config, PluginManager().plugins) def _get_tool_schema_map(self, tools_map: dict | None = None) -> dict: """构建工具名到参数 schema 的映射""" tools_map = tools_map or self._collect_tools_with_plugins() return get_tool_schema_map(tools_map) def _validate_tool_arguments(self, tool_name: str, arguments: dict, schema: dict) -> tuple: """轻量校验并补全默认参数""" return validate_tool_arguments(tool_name, arguments, schema) async def _handle_list_prompts(self, bot, from_wxid: str): """处理人设列表指令""" try: prompts_dir = Path(__file__).parent / "prompts" # 获取所有 .txt 文件 if not prompts_dir.exists(): await bot.send_text(from_wxid, "❌ prompts 目录不存在") return txt_files = sorted(prompts_dir.glob("*.txt")) if not txt_files: await bot.send_text(from_wxid, "❌ 没有找到任何人设文件") return # 构建列表消息 current_file = self.config["prompt"]["system_prompt_file"] msg = "📋 可用人设列表:\n\n" for i, file_path in enumerate(txt_files, 1): filename = file_path.name # 标记当前使用的人设 if filename == current_file: msg += f"{i}. {filename} ✅\n" else: msg += f"{i}. {filename}\n" msg += f"\n💡 使用方法:/切人设 文件名.txt" await bot.send_text(from_wxid, msg) logger.info("已发送人设列表") except Exception as e: logger.error(f"获取人设列表失败: {e}") await bot.send_text(from_wxid, f"❌ 获取人设列表失败: {str(e)}") def _estimate_tokens(self, text: str) -> int: """ 估算文本的 token 数量 简单估算规则: - 中文:约 1.5 字符 = 1 token - 英文:约 4 字符 = 1 token - 混合文本取平均 """ if not text: return 0 # 统计中文字符数 chinese_chars = sum(1 for c in text if '\u4e00' <= c <= '\u9fff') # 其他字符数 other_chars = len(text) - chinese_chars # 估算 token 数 chinese_tokens = chinese_chars / 1.5 other_tokens = other_chars / 4 return int(chinese_tokens + other_tokens) def _estimate_message_tokens(self, message: dict) -> int: """估算单条消息的 token 数""" content = message.get("content", "") if isinstance(content, str): return self._estimate_tokens(content) elif isinstance(content, list): # 多模态消息 total = 0 for item in content: if item.get("type") == "text": total += self._estimate_tokens(item.get("text", "")) elif item.get("type") == "image_url": # 图片按 85 token 估算(OpenAI 低分辨率图片) total += 85 return total return 0 def _extract_text_from_multimodal(self, content) -> str: """从多模态 content 中提取文本,模型不支持时用于降级""" if isinstance(content, list): texts = [item.get("text", "") for item in content if item.get("type") == "text"] text = "".join(texts).strip() return text or "[图片]" if content is None: return "" return str(content) def _sanitize_llm_output(self, text) -> str: """ 清洗 LLM 输出,尽量满足:不输出思维链、不使用 Markdown。 说明:提示词并非强约束,因此在所有“发给用户/写入上下文”的出口统一做后处理。 """ if text is None: return "" raw = str(text) cleaned = raw output_cfg = (self.config or {}).get("output", {}) strip_thinking = output_cfg.get("strip_thinking", True) strip_markdown = output_cfg.get("strip_markdown", True) # 先做一次 Markdown 清理,避免 “**思考过程:**/### 思考” 这类包裹导致无法识别 if strip_markdown: cleaned = self._strip_markdown_syntax(cleaned) if strip_thinking: cleaned = self._strip_thinking_content(cleaned) # 清理模型偶发输出的“文本工具调用”痕迹(如 tavilywebsearch{query:...} / ) # 这些内容既不是正常回复,也会破坏“工具只能用 Function Calling”的约束 try: cleaned = re.sub(r"", "", cleaned, flags=re.IGNORECASE) cleaned = re.sub( r"(?:展开阅读下文\\s*)?(?:tavilywebsearch|tavily_web_search|web_search)\\s*\\{[^{}]{0,1500}\\}", "", cleaned, flags=re.IGNORECASE, ) cleaned = re.sub( r"(?:tavilywebsearch|tavily_web_search|web_search)\\s*\\([^\\)]{0,1500}\\)", "", cleaned, flags=re.IGNORECASE, ) cleaned = cleaned.replace("展开阅读下文", "") cleaned = re.sub( r"(已触发工具处理:[^)]{0,300}结果将发送到聊天中。)", "", cleaned, ) except Exception: pass # 再跑一轮:部分模型会把“思考/最终”标记写成 Markdown,或在剥离标签后才露出标记 if strip_markdown: cleaned = self._strip_markdown_syntax(cleaned) if strip_thinking: cleaned = self._strip_thinking_content(cleaned) cleaned = cleaned.strip() # 兜底:清洗后仍残留明显“思维链/大纲”标记时,再尝试一次“抽取最终段” if strip_thinking and cleaned and self._contains_thinking_markers(cleaned): extracted = self._extract_after_last_answer_marker(cleaned) if not extracted: extracted = self._extract_final_answer_from_outline(cleaned) if extracted: cleaned = extracted.strip() # 仍残留标记:尽量选取最后一个“不含标记”的段落作为最终回复 if cleaned and self._contains_thinking_markers(cleaned): parts = [p.strip() for p in re.split(r"\n{2,}", cleaned) if p.strip()] for p in reversed(parts): if not self._contains_thinking_markers(p): cleaned = p break cleaned = cleaned.strip() # 最终兜底:仍然像思维链就直接丢弃(宁可不发也不要把思维链发出去) if strip_thinking and cleaned and self._contains_thinking_markers(cleaned): return "" if cleaned: return cleaned raw_stripped = raw.strip() # 清洗后为空时,不要回退到包含思维链标记的原文(避免把 ... 直接发出去) if strip_thinking and self._contains_thinking_markers(raw_stripped): return "" return raw_stripped def _contains_thinking_markers(self, text: str) -> bool: """粗略判断文本是否包含明显的“思考/推理”外显标记,用于决定是否允许回退原文。""" if not text: return False lowered = text.lower() tag_tokens = ( " str | None: """从文本中抽取最后一个“最终/输出/答案”标记后的内容(不要求必须是编号大纲)。""" if not text: return None # 1) 明确的行首标记:Text:/Final Answer:/输出: ... marker_re = re.compile( r"(?im)^\s*(?:\d+\s*[\.\、::\)、))\-–—]\s*)?" r"(?:text|final\s*answer|final\s*response|final\s*output|final|output|answer|response|输出|最终回复|最终答案|最终)\s*[::]\s*" ) matches = list(marker_re.finditer(text)) if matches: candidate = text[matches[-1].end():].strip() if candidate: return candidate # 2) JSON/YAML 风格:final: ... / \"final\": \"...\" kv_re = re.compile( r"(?im)^\s*\"?(?:final|answer|response|output|text|最终|最终回复|最终答案|输出)\"?\s*[::]\s*" ) kv_matches = list(kv_re.finditer(text)) if kv_matches: candidate = text[kv_matches[-1].end():].strip() if candidate: return candidate # 3) 纯 JSON 对象(尝试解析) stripped = text.strip() if stripped.startswith("{") and stripped.endswith("}"): try: obj = json.loads(stripped) if isinstance(obj, dict): for key in ("final", "answer", "response", "output", "text"): v = obj.get(key) if isinstance(v, str) and v.strip(): return v.strip() except Exception: pass return None def _extract_final_answer_from_outline(self, text: str) -> str | None: """从“分析/草稿/输出”这类结构化大纲中提取最终回复正文(用于拦截思维链)。""" if not text: return None # 至少包含多个“1./2./3.”段落,才认为可能是大纲/思维链输出 heading_re = re.compile(r"(?m)^\s*\d+\s*[\.\、::\)、))\-–—]\s*\S+") if len(heading_re.findall(text)) < 2: return None # 优先:提取最后一个 “Text:/Final Answer:/Output:” 之后的内容 marker_re = re.compile( r"(?im)^\s*(?:\d+\s*[\.\、::\)、))\-–—]\s*)?" r"(?:text|final\s*answer|final\s*response|final\s*output|output|answer|response|输出|最终回复|最终答案)\s*[::]\s*" ) matches = list(marker_re.finditer(text)) if matches: candidate = text[matches[-1].end():].strip() if candidate: return candidate # 没有明确的最终标记时,仅在包含“分析/思考/草稿/输出”等元信息关键词的情况下兜底抽取 lowered = text.lower() outline_keywords = ( "analyze", "analysis", "reasoning", "internal monologue", "mind space", "draft", "drafting", "outline", "plan", "steps", "formulating response", "final polish", "final answer", "output generation", "system prompt", "chat log", "previous turn", "current situation", ) cn_keywords = ("思考", "分析", "推理", "思维链", "草稿", "计划", "步骤", "输出", "最终") if not any(k in lowered for k in outline_keywords) and not any(k in text for k in cn_keywords): return None # 次选:取最后一个非空段落(避免返回整段大纲) parts = [p.strip() for p in re.split(r"\n{2,}", text) if p.strip()] if not parts: return None last = parts[-1] if len(heading_re.findall(last)) == 0: return last return None def _strip_thinking_content(self, text: str) -> str: """移除常见的“思考/推理”外显内容(如 ...、思考:...)。""" if not text: return "" t = text.replace("\r\n", "\n").replace("\r", "\n") # 1) 先移除显式标签块(常见于某些推理模型) thinking_tags = ("think", "analysis", "reasoning", "thought", "thinking", "thoughts", "scratchpad", "reflection") for tag in thinking_tags: t = re.sub(rf"<{tag}\b[^>]*>.*?", "", t, flags=re.IGNORECASE | re.DOTALL) # 兼容被转义的标签(<think>...</think>) t = re.sub(rf"<{tag}\b[^&]*>.*?</{tag}>", "", t, flags=re.IGNORECASE | re.DOTALL) # 1.1) 兜底:流式/截断导致标签未闭合时,若开头出现思考标签,直接截断后续内容 m = re.search(r"<(think|analysis|reasoning|thought|thinking|thoughts|scratchpad|reflection)\b[^>]*>", t, flags=re.IGNORECASE) if m and m.start() < 200: t = t[: m.start()].rstrip() m2 = re.search(r"<(think|analysis|reasoning|thought|thinking|thoughts|scratchpad|reflection)\b[^&]*>", t, flags=re.IGNORECASE) if m2 and m2.start() < 200: t = t[: m2.start()].rstrip() # 2) 再处理“思考:.../最终:...”这种分段格式(尽量只剥离前置思考) lines = t.split("\n") if not lines: return t # 若文本中包含明显的“最终/输出/答案”标记(不限是否编号),直接抽取最后一段,避免把大纲整体发出去 if self._contains_thinking_markers(t): extracted_anywhere = self._extract_after_last_answer_marker(t) if extracted_anywhere: return extracted_anywhere reasoning_kw = ( r"思考过程|推理过程|分析过程|思考|分析|推理|思路|内心独白|内心os|思维链|" r"chain\s*of\s*thought|reasoning|analysis|analyze|thinking|thoughts|thought\s*process|scratchpad|plan|steps|draft|outline" ) answer_kw = r"最终答案|最终回复|最终|回答|回复|答复|结论|输出|final(?:\s*answer)?|final\s*response|final\s*output|answer|response|output|text" # 兼容: # - 思考:... / 最终回复:... # - 【思考】... / 【最终】... # - **思考过程:**(Markdown 会在外层先被剥离) reasoning_start = re.compile( rf"^\s*(?:\d+\s*[\.\、::\)、))\-–—]\s*)?(?:[-*•]+\s*)?" rf"(?:【\s*(?:{reasoning_kw})\s*】\s*[::]?\s*|(?:{reasoning_kw})(?:\s*】)?\s*(?:[::]|$|\s+))", re.IGNORECASE, ) answer_start = re.compile( rf"^\s*(?:\d+\s*[\.\、::\)、))\-–—]\s*)?(?:[-*•]+\s*)?" rf"(?:【\s*(?:{answer_kw})\s*】\s*[::]?\s*|(?:{answer_kw})(?:\s*】)?\s*(?:[::]|$)\s*)", re.IGNORECASE, ) # 2.0) 若文本开头就是“最终回复:/Final answer:”之类,直接去掉标记(不强依赖出现“思考块”) for idx, line in enumerate(lines): if line.strip() == "": continue m0 = answer_start.match(line) if m0: lines[idx] = line[m0.end():].lstrip() break has_reasoning = any(reasoning_start.match(line) for line in lines[:10]) has_answer_marker = any(answer_start.match(line) for line in lines) # 2.1) 若同时存在“思考块 + 答案标记”,跳过思考块直到答案标记 if has_reasoning and has_answer_marker: out_lines: list[str] = [] skipping = False answer_started = False for line in lines: if answer_started: out_lines.append(line) continue if not skipping and reasoning_start.match(line): skipping = True continue if skipping: m = answer_start.match(line) if m: answer_started = True skipping = False out_lines.append(line[m.end():].lstrip()) continue m = answer_start.match(line) if m: answer_started = True out_lines.append(line[m.end():].lstrip()) else: out_lines.append(line) t2 = "\n".join(out_lines).strip() return t2 if t2 else t # 2.2) 兜底:若开头就是“思考:”,尝试去掉第一段(到第一个空行) if has_reasoning: first_blank_idx = None for idx, line in enumerate(lines): if line.strip() == "": first_blank_idx = idx break if first_blank_idx is not None and first_blank_idx + 1 < len(lines): candidate = "\n".join(lines[first_blank_idx + 1 :]).strip() if candidate: return candidate # 2.3) 兜底:识别“1. Analyze... 2. ... 6. Output ... Text: ...”这类思维链大纲并抽取最终正文 outline_extracted = self._extract_final_answer_from_outline("\n".join(lines).strip()) if outline_extracted: return outline_extracted # 将行级处理结果合回文本(例如去掉开头的“最终回复:”标记) t = "\n".join(lines).strip() # 3) 兼容 ... 这类包裹(保留正文,去掉标签) t = re.sub(r"", "", t, flags=re.IGNORECASE).strip() return t def _strip_markdown_syntax(self, text: str) -> str: """将常见 Markdown 标记转换为更像纯文本的形式(保留内容,移除格式符)。""" if not text: return "" t = text.replace("\r\n", "\n").replace("\r", "\n") # 去掉代码块围栏(保留内容) t = re.sub(r"```[^\n]*\n", "", t) t = t.replace("```", "") # 图片/链接:![alt](url) / [text](url) def _md_image_repl(m: re.Match) -> str: alt = (m.group(1) or "").strip() url = (m.group(2) or "").strip() if alt and url: return f"{alt}({url})" return url or alt or "" def _md_link_repl(m: re.Match) -> str: label = (m.group(1) or "").strip() url = (m.group(2) or "").strip() if label and url: return f"{label}({url})" return url or label or "" t = re.sub(r"!\[([^\]]*)\]\(([^)]+)\)", _md_image_repl, t) t = re.sub(r"\[([^\]]+)\]\(([^)]+)\)", _md_link_repl, t) # 行级标记:标题、引用、分割线 cleaned_lines: list[str] = [] for line in t.split("\n"): line = re.sub(r"^\s{0,3}#{1,6}\s+", "", line) # 标题 line = re.sub(r"^\s{0,3}>\s?", "", line) # 引用 if re.match(r"^\s*(?:-{3,}|\*{3,}|_{3,})\s*$", line): continue # 分割线整行移除 cleaned_lines.append(line) t = "\n".join(cleaned_lines) # 行内代码:`code` t = re.sub(r"`([^`]+)`", r"\1", t) # 粗体/删除线(保留文本) t = t.replace("**", "") t = t.replace("__", "") t = t.replace("~~", "") # 斜体(保留文本,避免误伤乘法/通配符,仅处理成对包裹) t = re.sub(r"(? str: try: with open("main_config.toml", "rb") as f: main_config = tomllib.load(f) nickname = main_config.get("Bot", {}).get("nickname", "") return nickname or "机器人" except Exception: return "机器人" def _tool_call_to_action_text(self, function_name: str, arguments: dict) -> str: args = arguments if isinstance(arguments, dict) else {} if function_name == "query_weather": city = str(args.get("city") or "").strip() return f"查询{city}天气" if city else "查询天气" if function_name == "register_city": city = str(args.get("city") or "").strip() return f"注册城市{city}" if city else "注册城市" if function_name == "user_signin": return "签到" if function_name == "check_profile": return "查询个人信息" return f"执行{function_name}" def _build_tool_calls_context_note(self, tool_calls_data: list) -> str: actions: list[str] = [] for tool_call in tool_calls_data or []: function_name = tool_call.get("function", {}).get("name", "") if not function_name: continue arguments_str = tool_call.get("function", {}).get("arguments", "{}") try: arguments = json.loads(arguments_str) if arguments_str else {} except Exception: arguments = {} actions.append(self._tool_call_to_action_text(function_name, arguments)) if not actions: return "(已触发工具处理:上一条请求。结果将发送到聊天中。)" return f"(已触发工具处理:{';'.join(actions)}。结果将发送到聊天中。)" async def _record_tool_calls_to_context( self, tool_calls_data: list, *, from_wxid: str, chat_id: str, is_group: bool, user_wxid: str | None = None, ): note = self._build_tool_calls_context_note(tool_calls_data) if chat_id: self._add_to_memory(chat_id, "assistant", note) if is_group and from_wxid: history_chat_id = self._get_group_history_chat_id(from_wxid, user_wxid or "") await self._add_to_history(history_chat_id, self._get_bot_nickname(), note, role="assistant", sender_wxid=user_wxid or None) def _extract_tool_intent_text(self, user_message: str, tool_query: str | None = None) -> str: text = tool_query if tool_query is not None else user_message text = str(text or "").strip() if not text: return "" # 对“聊天记录/视频”等组合消息,尽量只取用户真实提问部分,避免历史文本触发工具误判 markers = ( "[用户的问题]:", "[用户的问题]:", "[用户的问题]\n", "[用户的问题]", ) for marker in markers: if marker in text: text = text.rsplit(marker, 1)[-1].strip() return text def _looks_like_info_query(self, text: str) -> bool: t = str(text or "").strip().lower() if not t: return False # 太短的消息不值得额外走一轮分类 if len(t) < 6: return False # 疑问/求评价/求推荐类 if any(x in t for x in ("?", "?")): return True if re.search(r"(什么|咋|怎么|如何|为啥|为什么|哪|哪里|哪个|多少|推荐|值不值得|值得吗|好不好|靠谱吗|评价|口碑|怎么样|如何评价|近况|最新|最近)", t): return True if re.search(r"\\b(what|who|when|where|why|how|details?|impact|latest|news|review|rating|price|info|information)\\b", t): return True # 明确的实体/对象询问(公会/游戏/公司/项目等) if re.search(r"(公会|战队|服务器|区服|游戏|公司|品牌|产品|软件|插件|项目|平台|up主|主播|作者|电影|电视剧|小说)", t) and len(t) >= 8: return True return False def _extract_legacy_text_search_tool_call(self, text: str) -> tuple[str, dict] | None: """ 解析模型偶发输出的“文本工具调用”写法(例如 tavilywebsearch{query:...}),并转换为真实工具调用参数。 """ raw = str(text or "") if not raw: return None # 去掉 之类的控制标记 cleaned = re.sub(r"", "", raw, flags=re.IGNORECASE) m = re.search( r"(?i)\\b(?Ptavilywebsearch|tavily_web_search|web_search)\\s*\\{\\s*query\\s*[:=]\\s*(?P[^{}]{1,800})\\}", cleaned, ) if not m: m = re.search( r"(?i)\\b(?Ptavilywebsearch|tavily_web_search|web_search)\\s*\\(\\s*query\\s*[:=]\\s*(?P[^\\)]{1,800})\\)", cleaned, ) if not m: return None tool = str(m.group("tool") or "").strip().lower() query = str(m.group("q") or "").strip().strip("\"'`") if not query: return None # 统一映射到项目实际存在的工具名 if tool in ("tavilywebsearch", "tavily_web_search"): tool_name = "tavily_web_search" else: tool_name = "web_search" return tool_name, {"query": query[:400]} def _intent_to_allowed_tool_names(self, intent: str) -> set[str]: intent = str(intent or "").strip().lower() mapping = { "search": {"tavily_web_search", "web_search"}, "draw": { "nano_ai_image_generation", "flow2_ai_image_generation", "jimeng_ai_image_generation", "kiira2_ai_image_generation", "generate_image", }, "weather": {"query_weather"}, "register_city": {"register_city"}, "signin": {"user_signin"}, "profile": {"check_profile"}, "news": {"get_daily_news"}, "music": {"search_music"}, "playlet": {"search_playlet"}, "kfc": {"get_kfc"}, "fabing": {"get_fabing"}, "chat": set(), } return set(mapping.get(intent, set())) def _should_run_intent_router(self, intent_text: str) -> bool: cfg = (self.config or {}).get("intent_router", {}) if not cfg.get("enabled", False): return False mode = str(cfg.get("mode", "hybrid") or "hybrid").strip().lower() if mode == "always": return True # hybrid:只在像“信息查询/求评价”的问题时触发,避免闲聊额外增加延迟 return self._looks_like_info_query(intent_text) def _extract_intent_from_llm_output(self, content: str) -> str: raw = str(content or "").strip() if not raw: return "" # 尝试直接 JSON try: obj = json.loads(raw) if isinstance(obj, dict): intent = obj.get("intent") or obj.get("class_name") or "" return str(intent or "").strip().lower() except Exception: pass # 尝试截取 JSON 片段 try: m = re.search(r"\{[\s\S]*\}", raw) if m: obj = json.loads(m.group(0)) if isinstance(obj, dict): intent = obj.get("intent") or obj.get("class_name") or "" return str(intent or "").strip().lower() except Exception: pass # 兜底:纯文本标签 token = re.sub(r"[^a-zA-Z_]+", "", raw).lower() return token async def _classify_intent_with_llm(self, intent_text: str) -> str: cfg = (self.config or {}).get("intent_router", {}) api_cfg = (self.config or {}).get("api", {}) text = str(intent_text or "").strip() if not text: return "chat" cache_ttl = cfg.get("cache_ttl_seconds", 120) try: cache_ttl = float(cache_ttl or 0) except Exception: cache_ttl = 0 cache_key = re.sub(r"\s+", " ", text).strip().lower() if cache_ttl > 0: cached = self._intent_cache.get(cache_key) if cached: ts, intent = cached if (time.time() - float(ts or 0)) <= cache_ttl and intent: return str(intent).strip().lower() model = cfg.get("model") or api_cfg.get("model") url = api_cfg.get("url") api_key = api_cfg.get("api_key") timeout_s = cfg.get("timeout", 15) try: timeout_s = float(timeout_s or 15) except Exception: timeout_s = 15 system_prompt = ( "你是一个聊天机器人“意图分类器”,只负责判断用户当前这句话最可能的意图。\n" "只允许返回 JSON,格式固定:{\"intent\":\"