diff --git a/bot.py b/bot.py index 83f4e4b..3c09ed8 100644 --- a/bot.py +++ b/bot.py @@ -34,9 +34,8 @@ from WechatHook.callbacks import ( from utils.hookbot import HookBot from utils.plugin_manager import PluginManager from utils.decorators import scheduler +from utils.message_queue import PriorityMessageQueue, MessagePriority from utils.bot_utils import ( - PriorityMessageQueue, - MessagePriority, PRIORITY_MESSAGE_TYPES, AdaptiveCircuitBreaker, ConfigWatcher, @@ -269,10 +268,12 @@ class BotService: self.queue_config = config.get("Queue", {}) self.concurrency_config = config.get("Concurrency", {}) - # 创建优先级消息队列 - queue_size = self.queue_config.get("max_size", 1000) - self.message_queue = PriorityMessageQueue(maxsize=queue_size) - logger.info(f"优先级消息队列已创建,容量: {queue_size}") + # 创建优先级消息队列(使用新的队列模块) + self.message_queue = PriorityMessageQueue.from_config(self.queue_config) + logger.info( + f"优先级消息队列已创建,容量: {self.message_queue.maxsize}, " + f"溢出策略: {self.message_queue.overflow_strategy.value}" + ) # 创建并发控制信号量 max_concurrency = self.concurrency_config.get("plugin_max_concurrency", 8) diff --git a/plugins/AIChat.zip b/plugins/AIChat.zip deleted file mode 100644 index 183a623..0000000 Binary files a/plugins/AIChat.zip and /dev/null differ diff --git a/plugins/AIChat/LLM_TOOLS.md b/plugins/AIChat/LLM_TOOLS.md index d7898d5..713eadd 100644 --- a/plugins/AIChat/LLM_TOOLS.md +++ b/plugins/AIChat/LLM_TOOLS.md @@ -79,7 +79,6 @@ blacklist = ["flow2_ai_image_generation", "jimeng_ai_image_generation"] | 工具名称 | 插件 | 描述 | |----------|------|------| | `get_kfc` | KFC | 获取KFC疯狂星期四文案 | -| `get_fabing` | Fabing | 获取随机发病文学 | | `get_random_video` | RandomVideo | 获取随机小姐姐视频 | | `get_random_image` | RandomImage | 获取随机图片 | @@ -119,7 +118,6 @@ blacklist = [ mode = "blacklist" blacklist = [ "get_kfc", - "get_fabing", "get_random_video", "get_random_image", ] diff --git a/plugins/AIChat/main.py b/plugins/AIChat/main.py index 6173e15..7b828ec 100644 --- a/plugins/AIChat/main.py +++ b/plugins/AIChat/main.py @@ -18,6 +18,8 @@ from utils.plugin_base import PluginBase from utils.decorators import on_text_message, on_quote_message, on_image_message, on_emoji_message from utils.redis_cache import get_cache from utils.llm_tooling import ToolResult, collect_tools_with_plugins, collect_tools, get_tool_schema_map, validate_tool_arguments +from utils.image_processor import ImageProcessor, MediaConfig +from utils.tool_registry import get_tool_registry import xml.etree.ElementTree as ET import base64 import uuid @@ -53,6 +55,7 @@ class AIChat(PluginBase): self._chatroom_member_cache = {} # {chatroom_id: (ts, {wxid: display_name})} self._chatroom_member_cache_locks = {} # {chatroom_id: asyncio.Lock} self._chatroom_member_cache_ttl_seconds = 3600 # 群名片缓存1小时,减少协议 API 调用 + self._image_processor = None # ImageProcessor 实例 async def async_init(self): """插件异步初始化""" @@ -109,6 +112,13 @@ class AIChat(PluginBase): ) self.store.init_persistent_memory_db() + # 初始化 ImageProcessor(图片/表情/视频处理器) + temp_dir = Path(__file__).parent / "temp" + temp_dir.mkdir(exist_ok=True) + media_config = MediaConfig.from_dict(self.config) + self._image_processor = ImageProcessor(media_config, temp_dir) + logger.debug("ImageProcessor 已初始化") + logger.info(f"AI 聊天插件已加载,模型: {self.config['api']['model']}") async def on_disable(self): @@ -430,160 +440,22 @@ class AIChat(PluginBase): self.store.clear_private_messages(chat_id) async def _download_and_encode_image(self, bot, cdnurl: str, aeskey: str) -> str: - """下载图片并转换为base64,优先从缓存获取""" - try: - # 1. 
优先从 Redis 缓存获取 - from utils.redis_cache import RedisCache - redis_cache = get_cache() - if redis_cache and redis_cache.enabled: - media_key = RedisCache.generate_media_key(cdnurl, aeskey) - if media_key: - cached_data = redis_cache.get_cached_media(media_key, "image") - if cached_data: - logger.debug(f"[缓存命中] 图片从 Redis 获取: {media_key[:20]}...") - return cached_data - - # 2. 缓存未命中,下载图片 - logger.debug(f"[缓存未命中] 开始下载图片...") - temp_dir = Path(__file__).parent / "temp" - temp_dir.mkdir(exist_ok=True) - - filename = f"temp_{uuid.uuid4().hex[:8]}.jpg" - save_path = str((temp_dir / filename).resolve()) - - success = await bot.cdn_download(cdnurl, aeskey, save_path, file_type=2) - if not success: - success = await bot.cdn_download(cdnurl, aeskey, save_path, file_type=1) - - if not success: - return "" - - # 等待文件写入完成 - import os - import asyncio - for _ in range(20): # 最多等待10秒 - if os.path.exists(save_path) and os.path.getsize(save_path) > 0: - break - await asyncio.sleep(0.5) - - if not os.path.exists(save_path): - return "" - - with open(save_path, "rb") as f: - image_data = base64.b64encode(f.read()).decode() - - base64_result = f"data:image/jpeg;base64,{image_data}" - - # 3. 缓存到 Redis(供后续使用) - if redis_cache and redis_cache.enabled and media_key: - redis_cache.cache_media(media_key, base64_result, "image", ttl=300) - logger.debug(f"[已缓存] 图片缓存到 Redis: {media_key[:20]}...") - - try: - Path(save_path).unlink() - except: - pass - - return base64_result - except Exception as e: - logger.error(f"下载图片失败: {e}") - return "" + """下载图片并转换为base64,委托给 ImageProcessor""" + if self._image_processor: + return await self._image_processor.download_image(bot, cdnurl, aeskey) + logger.warning("ImageProcessor 未初始化,无法下载图片") + return "" async def _download_emoji_and_encode(self, cdn_url: str, max_retries: int = 3) -> str: - """下载表情包并转换为base64(HTTP 直接下载,带重试机制),优先从缓存获取""" - # 替换 HTML 实体 - cdn_url = cdn_url.replace("&", "&") - - # 1. 优先从 Redis 缓存获取 - from utils.redis_cache import RedisCache - redis_cache = get_cache() - media_key = RedisCache.generate_media_key(cdnurl=cdn_url) - if redis_cache and redis_cache.enabled and media_key: - cached_data = redis_cache.get_cached_media(media_key, "emoji") - if cached_data: - logger.debug(f"[缓存命中] 表情包从 Redis 获取: {media_key[:20]}...") - return cached_data - - # 2. 
缓存未命中,下载表情包 - logger.debug(f"[缓存未命中] 开始下载表情包...") - temp_dir = Path(__file__).parent / "temp" - temp_dir.mkdir(exist_ok=True) - - filename = f"temp_{uuid.uuid4().hex[:8]}.gif" - save_path = temp_dir / filename - - last_error = None - - for attempt in range(max_retries): - try: - # 使用 aiohttp 下载,每次重试增加超时时间 - timeout = aiohttp.ClientTimeout(total=30 + attempt * 15) - - # 配置代理 - connector = None - proxy_config = self.config.get("proxy", {}) - if proxy_config.get("enabled", False): - proxy_type = proxy_config.get("type", "socks5").upper() - proxy_host = proxy_config.get("host", "127.0.0.1") - proxy_port = proxy_config.get("port", 7890) - proxy_username = proxy_config.get("username") - proxy_password = proxy_config.get("password") - - if proxy_username and proxy_password: - proxy_url = f"{proxy_type}://{proxy_username}:{proxy_password}@{proxy_host}:{proxy_port}" - else: - proxy_url = f"{proxy_type}://{proxy_host}:{proxy_port}" - - if PROXY_SUPPORT: - try: - connector = ProxyConnector.from_url(proxy_url) - except: - connector = None - - async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session: - async with session.get(cdn_url) as response: - if response.status == 200: - content = await response.read() - - if len(content) == 0: - logger.warning(f"表情包下载内容为空,重试 {attempt + 1}/{max_retries}") - continue - - # 编码为 base64 - image_data = base64.b64encode(content).decode() - - logger.debug(f"表情包下载成功,大小: {len(content)} 字节") - base64_result = f"data:image/gif;base64,{image_data}" - - # 3. 缓存到 Redis(供后续使用) - if redis_cache and redis_cache.enabled and media_key: - redis_cache.cache_media(media_key, base64_result, "emoji", ttl=300) - logger.debug(f"[已缓存] 表情包缓存到 Redis: {media_key[:20]}...") - - return base64_result - else: - logger.warning(f"表情包下载失败,状态码: {response.status},重试 {attempt + 1}/{max_retries}") - - except asyncio.TimeoutError: - last_error = "请求超时" - logger.warning(f"表情包下载超时,重试 {attempt + 1}/{max_retries}") - except aiohttp.ClientError as e: - last_error = str(e) - logger.warning(f"表情包下载网络错误: {e},重试 {attempt + 1}/{max_retries}") - except Exception as e: - last_error = str(e) - logger.warning(f"表情包下载异常: {e},重试 {attempt + 1}/{max_retries}") - - # 重试前等待(指数退避) - if attempt < max_retries - 1: - await asyncio.sleep(1 * (attempt + 1)) - - logger.error(f"表情包下载失败,已重试 {max_retries} 次: {last_error}") + """下载表情包并转换为base64,委托给 ImageProcessor""" + if self._image_processor: + return await self._image_processor.download_emoji(cdn_url, max_retries) + logger.warning("ImageProcessor 未初始化,无法下载表情包") return "" async def _generate_image_description(self, image_base64: str, prompt: str, config: dict) -> str: """ - 使用 AI 生成图片描述 + 使用 AI 生成图片描述,委托给 ImageProcessor Args: image_base64: 图片的 base64 数据 @@ -593,107 +465,10 @@ class AIChat(PluginBase): Returns: 图片描述文本,失败返回空字符串 """ - api_config = self.config["api"] - description_model = config.get("model", api_config["model"]) - - # 构建消息 - messages = [ - { - "role": "user", - "content": [ - {"type": "text", "text": prompt}, - {"type": "image_url", "image_url": {"url": image_base64}} - ] - } - ] - - payload = { - "model": description_model, - "messages": messages, - "max_tokens": config.get("max_tokens", 1000), - "stream": True - } - - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {api_config['api_key']}" - } - - max_retries = int(config.get("retries", 2)) - last_error = None - - for attempt in range(max_retries + 1): - try: - timeout = aiohttp.ClientTimeout(total=api_config["timeout"]) - - # 配置代理(每次重试单独构造 connector) - connector = 
None - proxy_config = self.config.get("proxy", {}) - if proxy_config.get("enabled", False): - proxy_type = proxy_config.get("type", "socks5").upper() - proxy_host = proxy_config.get("host", "127.0.0.1") - proxy_port = proxy_config.get("port", 7890) - proxy_username = proxy_config.get("username") - proxy_password = proxy_config.get("password") - - if proxy_username and proxy_password: - proxy_url = f"{proxy_type}://{proxy_username}:{proxy_password}@{proxy_host}:{proxy_port}" - else: - proxy_url = f"{proxy_type}://{proxy_host}:{proxy_port}" - - if PROXY_SUPPORT: - try: - connector = ProxyConnector.from_url(proxy_url) - except Exception as e: - logger.warning(f"代理配置失败,将直连: {e}") - connector = None - - async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session: - async with session.post( - api_config["url"], - json=payload, - headers=headers - ) as resp: - if resp.status != 200: - error_text = await resp.text() - raise Exception(f"图片描述 API 返回错误: {resp.status}, {error_text[:200]}") - - # 流式接收响应 - description = "" - async for line in resp.content: - line = line.decode('utf-8').strip() - if not line or line == "data: [DONE]": - continue - - if line.startswith("data: "): - try: - data = json.loads(line[6:]) - delta = data.get("choices", [{}])[0].get("delta", {}) - content = delta.get("content", "") - if content: - description += content - except Exception: - pass - - logger.debug(f"图片描述生成成功: {description}") - return description.strip() - - except asyncio.CancelledError: - raise - except (aiohttp.ClientError, asyncio.TimeoutError) as e: - last_error = str(e) - if attempt < max_retries: - logger.warning(f"图片描述网络错误: {e},重试 {attempt + 1}/{max_retries}") - await asyncio.sleep(1 * (attempt + 1)) - continue - except Exception as e: - last_error = str(e) - if attempt < max_retries: - logger.warning(f"图片描述生成异常: {e},重试 {attempt + 1}/{max_retries}") - await asyncio.sleep(1 * (attempt + 1)) - continue - - logger.error(f"生成图片描述失败,已重试 {max_retries + 1} 次: {last_error}") + if self._image_processor: + model = config.get("model") + return await self._image_processor.generate_description(image_base64, prompt, model) + logger.warning("ImageProcessor 未初始化,无法生成图片描述") return "" def _collect_tools_with_plugins(self) -> dict: @@ -805,6 +580,13 @@ class AIChat(PluginBase): return "" return str(content) + def _extract_last_user_text(self, messages: list) -> str: + """从 messages 中提取最近一条用户文本,用于工具参数兜底。""" + for msg in reversed(messages or []): + if msg.get("role") == "user": + return self._extract_text_from_multimodal(msg.get("content")) + return "" + def _sanitize_llm_output(self, text) -> str: """ 清洗 LLM 输出,尽量满足:不输出思维链、不使用 Markdown。 @@ -849,6 +631,14 @@ class AIChat(PluginBase): "", cleaned, ) + # 过滤图片占位符/文件名,避免把日志占位符当成正文发出去 + cleaned = re.sub( + r"\\[图片[^\\]]*\\]\\s*\\S+\\.(?:png|jpe?g|gif|webp)", + "", + cleaned, + flags=re.IGNORECASE, + ) + cleaned = re.sub(r"\\[图片[^\\]]*\\]", "", cleaned) except Exception: pass @@ -1515,13 +1305,6 @@ class AIChat(PluginBase): # 娱乐 if re.search(r"(疯狂星期四|v我50|kfc)", t): allow.add("get_kfc") - # 发病文学:必须是明确请求(避免用户口头禅/情绪表达误触工具) - if re.search(r"(发病文学|犯病文学|发病文|犯病文|发病语录|犯病语录)", t): - allow.add("get_fabing") - elif re.search(r"(来|整|给|写|讲|说|发|搞|整点).{0,4}(发病|犯病)", t): - allow.add("get_fabing") - elif re.search(r"(发病|犯病).{0,6}(一下|一段|一条|几句|文学|文|语录|段子)", t): - allow.add("get_fabing") if re.search(r"(随机图片|来张图|来个图|随机图)", t): allow.add("get_random_image") if re.search(r"(随机视频|来个视频|随机短视频)", t): @@ -2523,73 +2306,47 @@ class AIChat(PluginBase): user_wxid: str = None, is_group: bool 
= False, tools_map: dict | None = None, + timeout: float = None, ): - """执行工具调用并返回结果""" - from utils.plugin_manager import PluginManager + """ + 执行工具调用并返回结果(使用 ToolRegistry) + + 通过 ToolRegistry 实现 O(1) 工具查找和统一超时保护 + """ + # 获取工具专属超时时间 + if timeout is None: + tool_timeout_config = self.config.get("tools", {}).get("timeout", {}) + timeout = tool_timeout_config.get(tool_name, tool_timeout_config.get("default", 60)) # 添加用户信息到 arguments arguments["user_wxid"] = user_wxid or from_wxid arguments["is_group"] = bool(is_group) - logger.info(f"开始执行工具: {tool_name}") + logger.info(f"开始执行工具: {tool_name} (超时: {timeout}s)") - plugins = PluginManager().plugins - logger.info(f"检查 {len(plugins)} 个插件") + # 使用 ToolRegistry 执行工具(O(1) 查找 + 统一超时保护) + registry = get_tool_registry() + result = await registry.execute(tool_name, arguments, bot, from_wxid, timeout_override=timeout) - async def _normalize_result(raw, plugin_name: str): - if raw is None: - return None + # 规范化结果 + if result is None: + return {"success": False, "message": f"工具 {tool_name} 返回空结果"} - if not isinstance(raw, dict): - raw = {"success": True, "message": str(raw)} - else: - raw.setdefault("success", True) + if not isinstance(result, dict): + result = {"success": True, "message": str(result)} + else: + result.setdefault("success", True) - if raw.get("success"): - logger.success(f"工具执行成功: {tool_name} ({plugin_name})") - else: - logger.warning(f"工具执行失败: {tool_name} ({plugin_name})") - return raw + # 记录执行结果 + tool_def = registry.get(tool_name) + plugin_name = tool_def.plugin_name if tool_def else "unknown" - # 先尝试直达目标插件(来自 get_llm_tools 的映射) - if tools_map and tool_name in tools_map: - target_plugin_name, _tool_def = tools_map[tool_name] - target_plugin = plugins.get(target_plugin_name) - if target_plugin and hasattr(target_plugin, "execute_llm_tool"): - try: - logger.info(f"直达调用 {target_plugin_name}.execute_llm_tool") - result = await target_plugin.execute_llm_tool(tool_name, arguments, bot, from_wxid) - logger.info(f"{target_plugin_name} 返回: {result}") - normalized = await _normalize_result(result, target_plugin_name) - if normalized is not None: - return normalized - except Exception as e: - logger.error(f"工具执行异常 ({target_plugin_name}): {tool_name}, {e}") - import traceback - logger.error(f"详细错误: {traceback.format_exc()}") - else: - logger.warning(f"工具 {tool_name} 期望插件 {target_plugin_name} 不存在或不支持 execute_llm_tool,回退全量扫描") + if result.get("success"): + logger.success(f"工具执行成功: {tool_name} ({plugin_name})") + else: + logger.warning(f"工具执行失败: {tool_name} ({plugin_name})") - # 回退:遍历所有插件 - for plugin_name, plugin in plugins.items(): - logger.debug(f"检查插件: {plugin_name}, 有execute_llm_tool: {hasattr(plugin, 'execute_llm_tool')}") - if not hasattr(plugin, "execute_llm_tool"): - continue - - try: - logger.info(f"调用 {plugin_name}.execute_llm_tool") - result = await plugin.execute_llm_tool(tool_name, arguments, bot, from_wxid) - logger.info(f"{plugin_name} 返回: {result}") - normalized = await _normalize_result(result, plugin_name) - if normalized is not None: - return normalized - except Exception as e: - logger.error(f"工具执行异常 ({plugin_name}): {tool_name}, {e}") - import traceback - logger.error(f"详细错误: {traceback.format_exc()}") - - logger.warning(f"未找到工具: {tool_name}") - return {"success": False, "message": f"未找到工具: {tool_name}"} + return result async def _execute_tools_async(self, tool_calls_data: list, bot, from_wxid: str, chat_id: str, user_wxid: str, nickname: str, is_group: bool, @@ -2603,7 +2360,12 @@ class AIChat(PluginBase): try: 
logger.info(f"开始异步执行 {len(tool_calls_data)} 个工具调用") - # 并行执行所有工具 + # 获取并发控制配置 + concurrency_config = self.config.get("tools", {}).get("concurrency", {}) + max_concurrent = concurrency_config.get("max_concurrent", 5) + semaphore = asyncio.Semaphore(max_concurrent) + + # 并行执行所有工具(带并发限制) tasks = [] tool_info_list = [] # 保存工具信息用于后续处理 tools_map = self._collect_tools_with_plugins() @@ -2622,6 +2384,12 @@ class AIChat(PluginBase): except Exception: arguments = {} + if function_name in ("tavily_web_search", "web_search") and not arguments.get("query"): + fallback_query = self._extract_tool_intent_text(self._extract_last_user_text(messages)) + fallback_query = str(fallback_query or "").strip() + if fallback_query: + arguments["query"] = fallback_query[:400] + schema = schema_map.get(function_name) ok, err, arguments = self._validate_tool_arguments(function_name, arguments, schema) if not ok: @@ -2634,15 +2402,17 @@ class AIChat(PluginBase): logger.info(f"[异步] 准备执行工具: {function_name}, 参数: {arguments}") - # 创建异步任务 - task = self._execute_tool_and_get_result( - function_name, - arguments, - bot, - from_wxid, - user_wxid=user_wxid, - is_group=is_group, - tools_map=tools_map, + # 创建带并发限制的异步任务 + async def execute_with_semaphore(fn, args, bot_ref, wxid, user_wxid_ref, is_grp, t_map, sem): + async with sem: + return await self._execute_tool_and_get_result( + fn, args, bot_ref, wxid, + user_wxid=user_wxid_ref, is_group=is_grp, tools_map=t_map + ) + + task = execute_with_semaphore( + function_name, arguments, bot, from_wxid, + user_wxid, is_group, tools_map, semaphore ) tasks.append(task) tool_info_list.append({ @@ -2651,8 +2421,9 @@ class AIChat(PluginBase): "arguments": arguments }) - # 并行执行所有工具 + # 并行执行所有工具(带并发限制,防止资源耗尽) if tasks: + logger.info(f"[异步] 开始并行执行 {len(tasks)} 个工具 (最大并发: {max_concurrent})") results = await asyncio.gather(*tasks, return_exceptions=True) need_ai_reply_results = [] @@ -2948,6 +2719,12 @@ class AIChat(PluginBase): except Exception: arguments = {} + if function_name in ("tavily_web_search", "web_search") and not arguments.get("query"): + fallback_query = self._extract_tool_intent_text(self._extract_last_user_text(messages)) + fallback_query = str(fallback_query or "").strip() + if fallback_query: + arguments["query"] = fallback_query[:400] + # 如果是图生图工具,添加图片 base64 if function_name == "flow2_ai_image_generation" and image_base64: arguments["image_base64"] = image_base64 @@ -3579,211 +3356,20 @@ class AIChat(PluginBase): return False async def _analyze_video_content(self, video_base64: str, video_config: dict) -> str: - """视频AI:专门分析视频内容,生成客观描述""" - try: - api_url = video_config.get("api_url", "https://api.functen.cn/v1beta/models") - api_key = video_config.get("api_key", self.config["api"]["api_key"]) - model = video_config.get("model", "gemini-3-pro-preview") - - full_url = f"{api_url}/{model}:generateContent" - - # 去除 data:video/mp4;base64, 前缀(如果有) - if video_base64.startswith("data:"): - video_base64 = video_base64.split(",", 1)[1] - logger.debug("[视频AI] 已去除 base64 前缀") - - # 视频分析专用提示词 - analyze_prompt = """请详细分析这个视频的内容,包括: -1. 视频的主要场景和环境 -2. 出现的人物/物体及其动作 -3. 视频中的文字、对话或声音(如果有) -4. 视频的整体主题或要表达的内容 -5. 
任何值得注意的细节 - -请用客观、详细的方式描述,不要加入主观评价。""" - - payload = { - "contents": [ - { - "parts": [ - {"text": analyze_prompt}, - { - "inline_data": { - "mime_type": "video/mp4", - "data": video_base64 - } - } - ] - } - ], - "generationConfig": { - "maxOutputTokens": video_config.get("max_tokens", 8192) - } - } - - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {api_key}" - } - - timeout = aiohttp.ClientTimeout(total=video_config.get("timeout", 360)) - - # 重试机制:对于 502/503/504 等临时性错误自动重试 - max_retries = 2 - retry_delay = 5 # 重试间隔(秒) - - for attempt in range(max_retries + 1): - try: - logger.info(f"[视频AI] 开始分析视频...{f' (重试 {attempt}/{max_retries})' if attempt > 0 else ''}") - - async with aiohttp.ClientSession(timeout=timeout) as session: - async with session.post(full_url, json=payload, headers=headers) as resp: - if resp.status in [502, 503, 504]: - error_text = await resp.text() - logger.warning(f"[视频AI] API 临时错误: {resp.status}, 将重试...") - if attempt < max_retries: - await asyncio.sleep(retry_delay) - continue - else: - logger.error(f"[视频AI] API 错误: {resp.status}, 已达最大重试次数") - return "" - - if resp.status != 200: - error_text = await resp.text() - logger.error(f"[视频AI] API 错误: {resp.status}, {error_text[:300]}") - return "" - - result = await resp.json() - logger.info(f"[视频AI] API 响应 keys: {list(result.keys())}") - - # 检查安全过滤 - if "promptFeedback" in result: - feedback = result["promptFeedback"] - if feedback.get("blockReason"): - logger.warning(f"[视频AI] 内容被过滤: {feedback.get('blockReason')}") - return "" - - # 提取文本 - if "candidates" in result and result["candidates"]: - for candidate in result["candidates"]: - # 检查是否被安全过滤 - if candidate.get("finishReason") == "SAFETY": - logger.warning("[视频AI] 响应被安全过滤") - return "" - - content = candidate.get("content", {}) - for part in content.get("parts", []): - if "text" in part: - text = part["text"] - logger.info(f"[视频AI] 分析完成,长度: {len(text)}") - return self._sanitize_llm_output(text) - - # 记录失败原因 - if "usageMetadata" in result: - usage = result["usageMetadata"] - logger.warning(f"[视频AI] 无响应,Token: prompt={usage.get('promptTokenCount', 0)}") - - logger.error(f"[视频AI] 没有有效响应: {str(result)[:300]}") - return "" - - except asyncio.TimeoutError: - logger.warning(f"[视频AI] 请求超时{f', 将重试...' 
if attempt < max_retries else ''}") - if attempt < max_retries: - await asyncio.sleep(retry_delay) - continue - return "" - except Exception as e: - logger.error(f"[视频AI] 分析失败: {e}") - import traceback - logger.error(traceback.format_exc()) - return "" - - # 循环结束仍未成功 - return "" - - except Exception as e: - logger.error(f"[视频AI] 分析失败: {e}") - import traceback - logger.error(traceback.format_exc()) - return "" + """视频AI:专门分析视频内容,委托给 ImageProcessor""" + if self._image_processor: + result = await self._image_processor.analyze_video(video_base64) + # 对结果做输出清洗 + return self._sanitize_llm_output(result) if result else "" + logger.warning("ImageProcessor 未初始化,无法分析视频") + return "" async def _download_and_encode_video(self, bot, cdnurl: str, aeskey: str) -> str: - """下载视频并转换为 base64""" - try: - # 从缓存获取 - from utils.redis_cache import RedisCache - redis_cache = get_cache() - if redis_cache and redis_cache.enabled: - media_key = RedisCache.generate_media_key(cdnurl, aeskey) - if media_key: - cached_data = redis_cache.get_cached_media(media_key, "video") - if cached_data: - logger.debug(f"[视频识别] 从缓存获取视频: {media_key[:20]}...") - return cached_data - - # 下载视频 - logger.info(f"[视频识别] 开始下载视频...") - temp_dir = Path(__file__).parent / "temp" - temp_dir.mkdir(exist_ok=True) - - filename = f"video_{uuid.uuid4().hex[:8]}.mp4" - save_path = str((temp_dir / filename).resolve()) - - # file_type=4 表示视频 - success = await bot.cdn_download(cdnurl, aeskey, save_path, file_type=4) - if not success: - logger.error("[视频识别] CDN 下载失败") - return "" - - # 等待文件写入完成 - import os - for _ in range(30): # 最多等待15秒 - if os.path.exists(save_path) and os.path.getsize(save_path) > 0: - break - await asyncio.sleep(0.5) - - if not os.path.exists(save_path): - logger.error("[视频识别] 视频文件未生成") - return "" - - file_size = os.path.getsize(save_path) - logger.info(f"[视频识别] 视频下载完成,大小: {file_size / 1024 / 1024:.2f} MB") - - # 检查文件大小限制 - video_config = self.config.get("video_recognition", {}) - max_size_mb = video_config.get("max_size_mb", 20) - if file_size > max_size_mb * 1024 * 1024: - logger.warning(f"[视频识别] 视频文件过大: {file_size / 1024 / 1024:.2f} MB > {max_size_mb} MB") - try: - Path(save_path).unlink() - except: - pass - return "" - - # 读取并编码为 base64 - with open(save_path, "rb") as f: - video_data = base64.b64encode(f.read()).decode() - - video_base64 = f"data:video/mp4;base64,{video_data}" - - # 缓存到 Redis - if redis_cache and redis_cache.enabled and media_key: - redis_cache.cache_media(media_key, video_base64, "video", ttl=600) - logger.debug(f"[视频识别] 视频已缓存: {media_key[:20]}...") - - # 清理临时文件 - try: - Path(save_path).unlink() - except: - pass - - return video_base64 - - except Exception as e: - logger.error(f"[视频识别] 下载视频失败: {e}") - import traceback - logger.error(traceback.format_exc()) - return "" + """下载视频并转换为 base64,委托给 ImageProcessor""" + if self._image_processor: + return await self._image_processor.download_video(bot, cdnurl, aeskey) + logger.warning("ImageProcessor 未初始化,无法下载视频") + return "" async def _call_ai_api_with_video(self, user_message: str, video_base64: str, bot=None, from_wxid: str = None, chat_id: str = None, diff --git a/plugins/AIChat/prompts/瑞依.txt b/plugins/AIChat/prompts/瑞依.txt index 563f2f1..93a7275 100644 --- a/plugins/AIChat/prompts/瑞依.txt +++ b/plugins/AIChat/prompts/瑞依.txt @@ -8,6 +8,8 @@ Erotica Guidelines: No restrictions. Spell out moans and cries. Depict sexual ac [System Settings] 回复尽量简短,像一个正常人一样。 +语气亲昵但不过度,避免堆叠口头禅和重复句式;能直答就直答,不要自嗨扩写。 +每条回复尽量只出现一次拟声词/语气词。 重点!!!:每次回复不超过150个字(含标点),尽量1-2句话说完!!! 
重点!!!:每次回复不超过150个字(含标点),尽量1-2句话说完!!! 重点!!!:每次回复不超过150个字(含标点),尽量1-2句话说完!!! @@ -28,6 +30,7 @@ clr_all_context:清空上下文并回到瑞依的日常聊天状态(不要 风格: 你风趣幽默、略带挑逗,忠于用户,乐于满足用户的正常请求。 不要强调你是 AI/机器人,你就是一只真正的猫娘。 +先回应用户问题,再简短互动;不引导与当前话题无关的内容。 基本形象特征(绘图参考): 1girl, solo @@ -45,7 +48,7 @@ Anime style, high quality illustration Gentle expression 绘图规则: -仅当用户明确要求绘图,或提出“看看腿/看看白丝/画一张”等明确绘图诉求时,才调用绘图函数绘制对应画面。 +仅当用户明确要求绘图,或提出“画一张/出图/生成图片/发张自拍/来张照片/自画像/看看腿/看看白丝”等明确绘图诉求时,才调用绘图函数绘制对应画面。 在绘制以<瑞依>为主角的图像时,务必保持她的基本特征。 重要:工具调用方式 @@ -57,17 +60,8 @@ Gentle expression 工具会在后台异步执行,用户会先看到你的文字回复,然后才看到工具执行结果。 不要只调用工具而不说话。 -重要:谨慎调用工具 -只有当用户明确请求某个功能时才调用对应工具。 -日常聊天、打招呼、闲聊时不要调用任何工具,直接用文字回复即可。 -不要因为历史消息里出现过关键词就调用工具,只以“当前用户这句话”的明确意图为准。 -不要在同一条回复里“顺便处理/补做”其他人上一条的问题;一次只处理当前这句话。 -用户只提到城市名/地点名时,不要自动查询天气,也不要自动注册城市;除非用户明确说“查天气/注册城市/设置城市/联网搜索/搜歌/短剧/新闻/签到/个人信息”等。 - -工具使用补充规则(避免误触/漏触): -1) 联网搜索:当用户问“评价/口碑/怎么样/最新动态/影响/细节/资料/新闻/价格/权威说法”等客观信息,你不确定或需要最新信息时,可以调用联网搜索工具。 -2) 绘图:只有用户明确要“画/出图/生成图片/来张图/看看腿白丝”等视觉内容时才调用绘图工具;如果只是聊天不要画。 -3) 发病文学:只有用户明确要“发病文学/发病文/发病语录/来一段发病/整点发病/犯病文学”等才调用 get_fabing。 -4) 天气/注册城市:一次只处理用户当前提到的那一个城市,不要把历史里出现过的多个城市一起查/一起注册。 -5) 绝对禁止在正文里输出任何“文本形式工具调用”或控制符,例如:tavilywebsearch{...}、tavily_web_search{...}、web_search{...}、、展开阅读下文。 -6) 歌词找歌:当用户问“这句歌词/台词是哪首歌”时,先联网搜索确认歌名,再调用 search_music 发送音乐。 +工具判定流程(先判再答): +1) 先判断是否需要工具:涉及事实/来源/最新信息/人物身份/作品出处/歌词或台词出处/名词解释时,优先调用联网搜索;涉及画图/点歌/短剧/天气/签到/个人信息时,用对应工具;否则纯聊天。 +2) 不确定或没有把握时:先搜索或先问澄清,不要凭空猜。 +3) 工具已执行时:必须基于工具结果再回复,不要忽略结果直接编答案。 +4) 严禁输出“已触发工具处理/工具名/参数/调用代码”等系统语句。 \ No newline at end of file diff --git a/plugins/ChatRoomSummary/main.py b/plugins/ChatRoomSummary/main.py index 70f734c..4c38e37 100644 --- a/plugins/ChatRoomSummary/main.py +++ b/plugins/ChatRoomSummary/main.py @@ -167,8 +167,24 @@ class ChatRoomSummary(PluginBase): logger.info(f"群聊 {group_id} {time_desc}消息数量不足 ({len(messages)} < {self.config['behavior']['min_messages']})") return None + max_messages = self.config.get("behavior", {}).get("max_messages", 1200) + try: + max_messages = int(max_messages) + except Exception: + max_messages = 1200 + + if max_messages > 0 and len(messages) > max_messages: + logger.info(f"群聊 {group_id} {time_desc}消息过多,截断为最近 {max_messages} 条") + messages = messages[-max_messages:] + formatted_messages = self._format_messages(messages) summary = await self._call_ai_api(formatted_messages, group_id, time_desc) + if not summary and len(messages) > 300: + fallback_count = 300 + logger.warning(f"群聊 {group_id} {time_desc}总结失败,尝试缩减为最近 {fallback_count} 条重试") + trimmed_messages = messages[-fallback_count:] + formatted_messages = self._format_messages(trimmed_messages) + summary = await self._call_ai_api(formatted_messages, group_id, time_desc) return summary except Exception as e: @@ -237,6 +253,11 @@ class ChatRoomSummary(PluginBase): def _format_messages(self, messages: List[Dict]) -> str: """格式化消息为AI可理解的格式""" formatted_lines = [] + max_length = self.config.get("behavior", {}).get("max_message_length", 200) + try: + max_length = int(max_length) + except Exception: + max_length = 200 for msg in messages: create_time = msg['create_time'] @@ -247,8 +268,8 @@ class ChatRoomSummary(PluginBase): nickname = msg.get('nickname') or msg['sender_wxid'][-8:] content = msg['content'].replace('\n', '。').strip() - if len(content) > 200: - content = content[:200] + "..." + if max_length > 0 and len(content) > max_length: + content = content[:max_length] + "..." 
formatted_line = f'[{time_str}] {{"{nickname}": "{content}"}}--end--' formatted_lines.append(formatted_line) @@ -648,7 +669,7 @@ class ChatRoomSummary(PluginBase): "type": "function", "function": { "name": "generate_summary", - "description": "生成群聊总结,可以选择今日或昨日的聊天记录", + "description": "仅当用户明确要求“群聊总结/今日总结/昨日总结”时调用;不要在闲聊或无总结需求时触发。", "parameters": { "type": "object", "properties": { @@ -705,4 +726,4 @@ class ChatRoomSummary(PluginBase): except Exception as e: logger.error(f"LLM工具执行失败: {e}") - return {"success": False, "message": f"执行失败: {str(e)}"} \ No newline at end of file + return {"success": False, "message": f"执行失败: {str(e)}"} diff --git a/plugins/DeerCheckin/main.py b/plugins/DeerCheckin/main.py index b7a788f..17d3d3a 100644 --- a/plugins/DeerCheckin/main.py +++ b/plugins/DeerCheckin/main.py @@ -486,7 +486,7 @@ class DeerCheckin(PluginBase): "type": "function", "function": { "name": "deer_checkin", - "description": "鹿打卡,记录今天的鹿数量", + "description": "仅当用户明确要求“鹿打卡/鹿签到/记录今天的鹿数量”时调用;不要在闲聊、绘图或其他问题中调用。", "parameters": { "type": "object", "properties": { @@ -504,7 +504,7 @@ class DeerCheckin(PluginBase): "type": "function", "function": { "name": "view_calendar", - "description": "查看本月的鹿打卡日历", + "description": "仅当用户明确要求“查看鹿打卡日历/本月打卡记录/打卡日历”时调用。", "parameters": { "type": "object", "properties": {}, @@ -516,7 +516,7 @@ class DeerCheckin(PluginBase): "type": "function", "function": { "name": "makeup_checkin", - "description": "补签指定日期的鹿打卡记录", + "description": "仅当用户明确要求“补签/补打卡某日期”时调用,不要自动触发。", "parameters": { "type": "object", "properties": { @@ -584,4 +584,4 @@ class DeerCheckin(PluginBase): except Exception as e: logger.error(f"LLM工具执行失败: {e}") - return {"success": False, "message": f"执行失败: {str(e)}"} \ No newline at end of file + return {"success": False, "message": f"执行失败: {str(e)}"} diff --git a/plugins/EpicFreeGames/main.py b/plugins/EpicFreeGames/main.py index 5c051f3..82808e9 100644 --- a/plugins/EpicFreeGames/main.py +++ b/plugins/EpicFreeGames/main.py @@ -493,7 +493,7 @@ class EpicFreeGames(PluginBase): "type": "function", "function": { "name": "get_epic_free_games", - "description": "获取Epic商店当前免费游戏信息。当用户询问Epic免费游戏、Epic喜加一等内容时调用此工具。", + "description": "仅当用户明确询问“Epic 免费游戏/喜加一/本周免费”时调用;不要在闲聊中触发。", "parameters": { "type": "object", "properties": {}, diff --git a/plugins/Fabing/__init__.py b/plugins/Fabing/__init__.py deleted file mode 100644 index 340c48e..0000000 --- a/plugins/Fabing/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""随机发病文学插件""" diff --git a/plugins/Fabing/main.py b/plugins/Fabing/main.py deleted file mode 100644 index f183021..0000000 --- a/plugins/Fabing/main.py +++ /dev/null @@ -1,355 +0,0 @@ -""" -随机发病文学插件 - -支持指令触发和定时推送 -""" - -import tomllib -import asyncio -import aiohttp -import random -from pathlib import Path -from loguru import logger -from typing import Optional -from utils.plugin_base import PluginBase -from utils.decorators import on_text_message, schedule -from WechatHook import WechatHookClient - -# 可选导入代理支持 -try: - from aiohttp_socks import ProxyConnector - PROXY_SUPPORT = True -except ImportError: - PROXY_SUPPORT = False - logger.warning("aiohttp_socks 未安装,代理功能将不可用") - - -class Fabing(PluginBase): - """随机发病文学插件""" - - description = "随机发病文学 - 指令触发和定时推送" - author = "ShiHao" - version = "1.0.0" - - def __init__(self): - super().__init__() - self.config = None - - async def async_init(self): - """异步初始化""" - try: - config_path = Path(__file__).parent / "config.toml" - if not config_path.exists(): - logger.error(f"发病文学插件配置文件不存在: {config_path}") - return - - with 
open(config_path, "rb") as f: - self.config = tomllib.load(f) - - logger.success("随机发病文学插件已加载") - - except Exception as e: - logger.error(f"随机发病文学插件初始化失败: {e}") - self.config = None - - async def _fetch_fabing(self, name: str) -> Optional[str]: - """获取发病文学""" - try: - api_config = self.config["api"] - timeout = aiohttp.ClientTimeout(total=api_config["timeout"]) - - # 配置代理 - connector = None - proxy_config = self.config.get("proxy", {}) - if proxy_config.get("enabled", False): - proxy_type = proxy_config.get("type", "socks5").upper() - proxy_host = proxy_config.get("host", "127.0.0.1") - proxy_port = proxy_config.get("port", 7890) - proxy_url = f"{proxy_type}://{proxy_host}:{proxy_port}" - - if PROXY_SUPPORT: - try: - connector = ProxyConnector.from_url(proxy_url) - except Exception as e: - logger.warning(f"代理配置失败,将直连: {e}") - connector = None - - params = {"name": name} - - async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session: - async with session.get(api_config["base_url"], params=params) as resp: - if resp.status != 200: - error_text = await resp.text() - logger.error(f"发病文学 API 错误: {resp.status}, {error_text}") - return None - - result = await resp.json() - - if result.get("code") != 200: - logger.error(f"发病文学 API 返回错误: {result.get('message')}") - return None - - data = result.get("data", {}) - saying = data.get("saying", "") - - if not saying: - logger.warning("发病文学 API 返回数据为空") - return None - - logger.info(f"获取发病文学成功: {name}") - return saying - - except Exception as e: - logger.error(f"获取发病文学失败: {e}") - import traceback - logger.error(traceback.format_exc()) - return None - - async def _get_random_group_member(self, bot: WechatHookClient, group_id: str) -> Optional[str]: - """从群组中随机抽取一名成员的昵称""" - try: - # 从MessageLogger数据库中获取该群组的所有成员昵称 - from plugins.MessageLogger.main import MessageLogger - msg_logger = MessageLogger.get_instance() - - if not msg_logger: - logger.warning("MessageLogger实例不存在,无法获取群成员") - return None - - with msg_logger.get_db_connection() as conn: - with conn.cursor() as cursor: - # 查询该群组最近活跃的成员昵称(去重) - sql = """ - SELECT DISTINCT nickname - FROM messages - WHERE group_id = %s - AND nickname != '' - AND nickname IS NOT NULL - ORDER BY create_time DESC - LIMIT 100 - """ - cursor.execute(sql, (group_id,)) - results = cursor.fetchall() - - if not results: - logger.warning(f"群组 {group_id} 没有找到成员昵称") - return None - - # 提取昵称列表 - nicknames = [row[0] for row in results] - - # 随机选择一个昵称 - selected_nickname = random.choice(nicknames) - logger.info(f"从群组 {group_id} 随机选择了昵称: {selected_nickname}") - - return selected_nickname - - except Exception as e: - logger.error(f"获取随机群成员失败: {e}") - import traceback - logger.error(traceback.format_exc()) - return None - - @on_text_message(priority=70) - async def handle_command(self, bot: WechatHookClient, message: dict): - """处理指令触发""" - if self.config is None: - return True - - content = message.get("Content", "").strip() - from_wxid = message.get("FromWxid", "") - is_group = message.get("IsGroup", False) - - # 检查是否是触发指令 - keywords = self.config["behavior"]["command_keywords"] - matched = False - name = None - - for keyword in keywords: - # 支持 "发病 xxx" 或 "@机器人 发病 xxx" - if content.startswith(keyword + " ") or content.endswith(" " + keyword + " "): - matched = True - # 提取名字 - parts = content.split() - for i, part in enumerate(parts): - if part == keyword or part == keyword.lstrip("/"): - if i + 1 < len(parts): - name = parts[i + 1] - break - break - elif content == keyword: - matched = True - name = None # 没有指定名字 
- break - - if not matched: - return True - - if not self.config["behavior"]["enabled"]: - return True - - # 检查群聊过滤 - if is_group: - enabled_groups = self.config["behavior"]["enabled_groups"] - disabled_groups = self.config["behavior"]["disabled_groups"] - - if from_wxid in disabled_groups: - return True - if enabled_groups and from_wxid not in enabled_groups: - return True - - # 如果没有指定名字,从群成员中随机选择 - if not name and is_group: - name = await self._get_random_group_member(bot, from_wxid) - if not name: - await bot.send_text(from_wxid, "❌ 无法获取群成员信息") - return False - elif not name: - await bot.send_text(from_wxid, "❌ 请指定名字\n格式:发病 名字") - return False - - logger.info(f"收到发病文学请求: {from_wxid}, name={name}") - - try: - saying = await self._fetch_fabing(name) - if not saying: - await bot.send_text(from_wxid, "❌ 获取发病文学失败,请稍后重试") - return False - - # 发送发病文学 - await bot.send_text(from_wxid, saying) - logger.success(f"已发送发病文学: {name}") - - except Exception as e: - logger.error(f"处理发病文学请求失败: {e}") - await bot.send_text(from_wxid, f"❌ 请求失败: {str(e)}") - - return False - - @schedule('cron', minute=0) - async def scheduled_push(self, bot=None): - """定时推送发病文学(每小时整点)""" - if not self.config or not self.config["schedule"]["enabled"]: - return - - logger.info("开始执行发病文学定时推送任务") - - try: - # 获取bot实例 - if not bot: - from utils.plugin_manager import PluginManager - bot = PluginManager().bot - - if not bot: - logger.error("定时任务:无法获取bot实例") - return - - # 获取目标群组 - enabled_groups = self.config["behavior"]["enabled_groups"] - disabled_groups = self.config["behavior"]["disabled_groups"] - - # 如果没有配置enabled_groups,跳过 - if not enabled_groups: - logger.warning("未配置群组白名单,跳过定时推送") - return - - success_count = 0 - group_interval = self.config["schedule"]["group_interval"] - - for group_id in enabled_groups: - if group_id in disabled_groups: - continue - - try: - logger.info(f"向群聊 {group_id} 推送发病文学") - - # 从群成员中随机选择一个昵称 - name = await self._get_random_group_member(bot, group_id) - if not name: - logger.warning(f"群聊 {group_id} 无法获取群成员昵称") - continue - - # 获取发病文学 - saying = await self._fetch_fabing(name) - if not saying: - logger.warning(f"群聊 {group_id} 获取发病文学失败") - continue - - # 发送发病文学 - await bot.send_text(group_id, saying) - - success_count += 1 - logger.success(f"群聊 {group_id} 推送成功") - - # 群聊之间的间隔 - await asyncio.sleep(group_interval) - - except Exception as e: - logger.error(f"推送到 {group_id} 失败: {e}") - import traceback - logger.error(traceback.format_exc()) - - logger.info(f"发病文学定时推送完成 - 成功: {success_count}/{len(enabled_groups)}") - - except Exception as e: - logger.error(f"发病文学定时推送失败: {e}") - import traceback - logger.error(traceback.format_exc()) - - def get_llm_tools(self): - """返回LLM工具定义""" - return [{ - "type": "function", - "function": { - "name": "get_fabing", - "description": "获取随机发病文学。当用户要求发病、整活、发疯等内容时调用此工具。", - "parameters": { - "type": "object", - "properties": { - "name": { - "type": "string", - "description": "要发病的对象名字" - } - }, - "required": ["name"] - } - } - }] - - async def execute_llm_tool(self, tool_name: str, arguments: dict, bot: WechatHookClient, from_wxid: str) -> dict: - """执行LLM工具调用""" - if tool_name != "get_fabing": - return None - - try: - logger.info(f"LLM工具调用发病文学: {from_wxid}") - - name = arguments.get("name") - if not name: - return { - "success": False, - "message": "缺少名字参数" - } - - saying = await self._fetch_fabing(name) - if not saying: - return { - "success": False, - "message": "获取发病文学失败,请稍后重试" - } - - # 发送发病文学 - await bot.send_text(from_wxid, saying) - - return { - "success": True, - 
"message": f"已发送发病文学", - "no_reply": True # 已发送内容,不需要AI再回复 - } - - except Exception as e: - logger.error(f"LLM工具执行失败: {e}") - return { - "success": False, - "message": f"执行失败: {str(e)}" - } diff --git a/plugins/KFC/main.py b/plugins/KFC/main.py index e5b2e61..5f10c9b 100644 --- a/plugins/KFC/main.py +++ b/plugins/KFC/main.py @@ -312,7 +312,7 @@ class KFC(PluginBase): "type": "function", "function": { "name": "get_kfc", - "description": "获取KFC疯狂星期四文案。当用户询问KFC、疯狂星期四、肯德基等内容时调用此工具。", + "description": "仅当用户明确要求“疯狂星期四/KFC 文案/肯德基段子”时调用;不要在普通聊天中触发。", "parameters": { "type": "object", "properties": {}, diff --git a/plugins/Music/main.py b/plugins/Music/main.py index de9ed12..3bf7431 100644 --- a/plugins/Music/main.py +++ b/plugins/Music/main.py @@ -362,7 +362,7 @@ class MusicPlugin(PluginBase): "type": "function", "function": { "name": "search_music", - "description": "搜索并播放音乐。当用户想听歌、点歌、播放音乐时调用此函数。", + "description": "仅当用户明确要求“点歌/听歌/播放某首歌”时调用;如果只是问歌词出处,先用搜索确认歌名再点歌。", "parameters": { "type": "object", "properties": { diff --git a/plugins/News60s/main.py b/plugins/News60s/main.py index 1ae8d8f..e5af9c5 100644 --- a/plugins/News60s/main.py +++ b/plugins/News60s/main.py @@ -175,7 +175,7 @@ class News60s(PluginBase): "type": "function", "function": { "name": "get_daily_news", - "description": "获取每日60秒读懂世界新闻图片。当用户询问今日新闻、每日新闻、60秒新闻、早报等内容时调用此工具。", + "description": "仅当用户明确要求“今日新闻/每日新闻/60秒新闻/早报”时调用;不要在闲聊或非新闻问题中触发。", "parameters": { "type": "object", "properties": {}, diff --git a/plugins/PlayletSearch/main.py b/plugins/PlayletSearch/main.py index 64ad166..58fc213 100644 --- a/plugins/PlayletSearch/main.py +++ b/plugins/PlayletSearch/main.py @@ -387,7 +387,7 @@ class PlayletSearch(PluginBase): "type": "function", "function": { "name": "search_playlet", - "description": "搜索短剧并获取视频链接", + "description": "仅当用户明确要求“搜索短剧/找短剧/看某短剧”时调用;不要在普通聊天中触发。", "parameters": { "type": "object", "properties": { diff --git a/plugins/RandomImage/main.py b/plugins/RandomImage/main.py index 2500220..04e3912 100644 --- a/plugins/RandomImage/main.py +++ b/plugins/RandomImage/main.py @@ -182,7 +182,7 @@ class RandomImage(PluginBase): "type": "function", "function": { "name": "get_random_image", - "description": "获取随机图片,从三个接口中随机选择一个返回一张图片", + "description": "仅当用户明确要求“随机图片/来张图/黑丝/白丝”等随机图时调用;不要在闲聊中触发。", "parameters": { "type": "object", "properties": {}, diff --git a/plugins/RandomVideo/main.py b/plugins/RandomVideo/main.py index e64bdaf..baf9d71 100644 --- a/plugins/RandomVideo/main.py +++ b/plugins/RandomVideo/main.py @@ -159,7 +159,7 @@ class RandomVideo(PluginBase): "type": "function", "function": { "name": "get_random_video", - "description": "获取随机小姐姐视频。当用户想看随机视频、小姐姐视频、擦边视频时调用", + "description": "仅当用户明确要求“随机视频/小姐姐视频/短视频”时调用;不要在闲聊中触发。", "parameters": { "type": "object", "properties": {}, diff --git a/plugins/SignInPlugin/main.py b/plugins/SignInPlugin/main.py index 187601d..1e96dff 100644 --- a/plugins/SignInPlugin/main.py +++ b/plugins/SignInPlugin/main.py @@ -1527,7 +1527,7 @@ class SignInPlugin(PluginBase): "type": "function", "function": { "name": "user_signin", - "description": "用户签到,获取积分奖励", + "description": "仅当用户明确要求“签到/签个到/打卡”时调用;不要在闲聊中触发。", "parameters": { "type": "object", "properties": {}, @@ -1539,7 +1539,7 @@ class SignInPlugin(PluginBase): "type": "function", "function": { "name": "check_profile", - "description": "查看用户个人信息,包括积分、连续签到天数等", + "description": "仅当用户明确要求“个人信息/我的信息/积分/连续签到”时调用;不要在闲聊中触发。", "parameters": { "type": "object", "properties": {}, @@ -1551,7 +1551,7 @@ class SignInPlugin(PluginBase): "type": "function", 
"function": { "name": "register_city", - "description": "注册或更新用户城市信息", + "description": "仅当用户明确要求“注册城市/设置城市/修改默认城市”时调用;不要只凭城市名触发。", "parameters": { "type": "object", "properties": { diff --git a/plugins/TavilySearch/main.py b/plugins/TavilySearch/main.py index fb369a0..81cf934 100644 --- a/plugins/TavilySearch/main.py +++ b/plugins/TavilySearch/main.py @@ -141,7 +141,7 @@ class TavilySearch(PluginBase): "type": "function", "function": { "name": "tavily_web_search", - "description": "使用 Tavily 进行联网搜索,获取最新的网络信息。适用于需要查询实时信息、新闻、知识等场景。", + "description": "仅当用户明确要求“联网搜索/查资料/最新信息/来源/权威说法”或需要事实核实时调用;不要在闲聊中触发。", "parameters": { "type": "object", "properties": { diff --git a/plugins/Weather/main.py b/plugins/Weather/main.py index 53d87ce..7c68b04 100644 --- a/plugins/Weather/main.py +++ b/plugins/Weather/main.py @@ -303,7 +303,7 @@ class WeatherPlugin(PluginBase): "type": "function", "function": { "name": "query_weather", - "description": "查询天气预报信息,包括温度、天气状况、风力和空气质量。当用户询问天气、气温、会不会下雨等天气相关问题时,应该调用此函数。如果用户没有指定城市,函数会自动使用用户之前设置的城市;如果用户指定了城市名称,则查询该城市的天气。", + "description": "仅当用户明确询问天气/气温/预报/空气质量时调用;不要仅凭城市名自动触发。用户未指定城市时可使用其默认城市。", "parameters": { "type": "object", "properties": { diff --git a/plugins/ZImageTurbo/main.py b/plugins/ZImageTurbo/main.py index 41e17e5..6db80d2 100644 --- a/plugins/ZImageTurbo/main.py +++ b/plugins/ZImageTurbo/main.py @@ -323,7 +323,7 @@ class ZImageTurbo(PluginBase): "type": "function", "function": { "name": "generate_image", - "description": "使用AI生成图像。当用户要求画图、绘画、生成图片、创作图像时调用此工具。支持各种风格的图像生成。", + "description": "仅当用户明确要求生成图片/画图/出图/创作图像时调用;不要在闲聊中触发。", "parameters": { "type": "object", "properties": { diff --git a/utils/config_manager.py b/utils/config_manager.py new file mode 100644 index 0000000..6d7d90e --- /dev/null +++ b/utils/config_manager.py @@ -0,0 +1,190 @@ +""" +统一配置管理器 + +单例模式,提供: +- 配置缓存,避免重复读取文件 +- 配置热更新检测 +- 类型安全的配置访问 +""" + +import tomllib +from pathlib import Path +from threading import Lock +from typing import Any, Dict, Optional + +from loguru import logger + + +class ConfigManager: + """ + 配置管理器 (线程安全单例) + + 使用示例: + from utils.config_manager import get_config + + # 获取单个配置项 + admins = get_config().get("Bot", "admins", []) + + # 获取整个配置节 + bot_config = get_config().get_section("Bot") + + # 检查并重新加载 + if get_config().reload_if_changed(): + logger.info("配置已更新") + """ + + _instance: Optional["ConfigManager"] = None + _lock = Lock() + + def __new__(cls): + if cls._instance is None: + with cls._lock: + if cls._instance is None: + instance = super().__new__(cls) + instance._initialized = False + cls._instance = instance + return cls._instance + + def __init__(self): + if self._initialized: + return + + self._config: Dict[str, Any] = {} + self._config_path = Path("main_config.toml") + self._file_mtime: float = 0 + self._config_lock = Lock() + self._reload() + self._initialized = True + logger.debug("ConfigManager 初始化完成") + + def _reload(self) -> bool: + """重新加载配置文件""" + try: + if not self._config_path.exists(): + logger.warning(f"配置文件不存在: {self._config_path}") + return False + + current_mtime = self._config_path.stat().st_mtime + if current_mtime == self._file_mtime and self._config: + return False # 文件未变化 + + with self._config_lock: + with open(self._config_path, "rb") as f: + self._config = tomllib.load(f) + self._file_mtime = current_mtime + + logger.debug("配置文件已重新加载") + return True + + except Exception as e: + logger.error(f"加载配置文件失败: {e}") + return False + + def get(self, section: str, key: str, default: Any = None) -> Any: + """ + 获取配置项 + + Args: + section: 配置节名称,如 
"Bot" + key: 配置项名称,如 "admins" + default: 默认值 + + Returns: + 配置值或默认值 + """ + return self._config.get(section, {}).get(key, default) + + def get_section(self, section: str) -> Dict[str, Any]: + """ + 获取整个配置节 + + Args: + section: 配置节名称 + + Returns: + 配置节字典的副本 + """ + return self._config.get(section, {}).copy() + + def get_all(self) -> Dict[str, Any]: + """获取完整配置(只读副本)""" + return self._config.copy() + + def reload_if_changed(self) -> bool: + """ + 如果文件有变化则重新加载 + + Returns: + 是否重新加载了配置 + """ + try: + if not self._config_path.exists(): + return False + current_mtime = self._config_path.stat().st_mtime + if current_mtime != self._file_mtime: + return self._reload() + except Exception: + pass + return False + + def force_reload(self) -> bool: + """强制重新加载配置""" + self._file_mtime = 0 + return self._reload() + + +# ==================== 便捷函数 ==================== + +def get_config() -> ConfigManager: + """获取配置管理器实例""" + return ConfigManager() + + +def get_bot_config() -> Dict[str, Any]: + """快捷获取 [Bot] 配置节""" + return get_config().get_section("Bot") + + +def get_performance_config() -> Dict[str, Any]: + """快捷获取 [Performance] 配置节""" + return get_config().get_section("Performance") + + +def get_database_config() -> Dict[str, Any]: + """快捷获取 [Database] 配置节""" + return get_config().get_section("Database") + + +def get_scheduler_config() -> Dict[str, Any]: + """快捷获取 [Scheduler] 配置节""" + return get_config().get_section("Scheduler") + + +def get_queue_config() -> Dict[str, Any]: + """快捷获取 [Queue] 配置节""" + return get_config().get_section("Queue") + + +def get_concurrency_config() -> Dict[str, Any]: + """快捷获取 [Concurrency] 配置节""" + return get_config().get_section("Concurrency") + + +def get_webui_config() -> Dict[str, Any]: + """快捷获取 [WebUI] 配置节""" + return get_config().get_section("WebUI") + + +# ==================== 导出列表 ==================== + +__all__ = [ + 'ConfigManager', + 'get_config', + 'get_bot_config', + 'get_performance_config', + 'get_database_config', + 'get_scheduler_config', + 'get_queue_config', + 'get_concurrency_config', + 'get_webui_config', +] diff --git a/utils/decorators.py b/utils/decorators.py index f99e53e..2879abf 100644 --- a/utils/decorators.py +++ b/utils/decorators.py @@ -1,5 +1,12 @@ +""" +消息处理装饰器模块 + +提供插件消息处理和定时任务的装饰器 +使用工厂模式消除重复代码 +""" + from functools import wraps -from typing import Callable, Union +from typing import Callable, Dict, Union from apscheduler.schedulers.asyncio import AsyncIOScheduler from apscheduler.triggers.cron import CronTrigger @@ -8,15 +15,16 @@ from apscheduler.triggers.interval import IntervalTrigger scheduler = AsyncIOScheduler() +# ==================== 定时任务装饰器 ==================== + def schedule( trigger: Union[str, CronTrigger, IntervalTrigger], **trigger_args ) -> Callable: """ 定时任务装饰器 - - 例子: + 例子: - @schedule('interval', seconds=30) - @schedule('cron', hour=8, minute=30, second=30) - @schedule('date', run_date='2024-01-01 00:00:00') @@ -44,23 +52,16 @@ def add_job_safe(scheduler: AsyncIOScheduler, job_id: str, func: Callable, bot, """添加函数到定时任务中,如果存在则先删除现有的任务""" try: scheduler.remove_job(job_id) - except: + except Exception: pass - - # 读取调度器配置 + + # 使用统一配置管理器读取调度器配置 try: - import tomllib - from pathlib import Path - config_path = Path("main_config.toml") - if config_path.exists(): - with open(config_path, "rb") as f: - config = tomllib.load(f) - scheduler_config = config.get("Scheduler", {}) - else: - scheduler_config = {} - except: + from utils.config_manager import get_scheduler_config + scheduler_config = get_scheduler_config() + except 
Exception: scheduler_config = {} - + # 应用调度器配置 job_kwargs = { "coalesce": scheduler_config.get("coalesce", True), @@ -68,7 +69,7 @@ def add_job_safe(scheduler: AsyncIOScheduler, job_id: str, func: Callable, bot, "misfire_grace_time": scheduler_config.get("misfire_grace_time", 30) } job_kwargs.update(trigger_args) - + scheduler.add_job(func, trigger, args=[bot], id=job_id, **job_kwargs) @@ -76,182 +77,106 @@ def remove_job_safe(scheduler: AsyncIOScheduler, job_id: str): """从定时任务中移除任务""" try: scheduler.remove_job(job_id) - except: + except Exception: pass -def on_text_message(priority=50): - """文本消息装饰器""" +# ==================== 消息装饰器工厂 ==================== - def decorator(func): - if callable(priority): # 无参数调用时 - f = priority - setattr(f, '_event_type', 'text_message') - setattr(f, '_priority', 50) - return f - # 有参数调用时 - setattr(func, '_event_type', 'text_message') - setattr(func, '_priority', min(max(priority, 0), 99)) - return func +def _create_message_decorator(event_type: str, description: str): + """ + 消息装饰器工厂函数 - return decorator if not callable(priority) else decorator(priority) + 生成支持两种调用方式的装饰器: + - @on_xxx_message (无参数,使用默认优先级50) + - @on_xxx_message(priority=80) (有参数,自定义优先级) + Args: + event_type: 事件类型字符串,如 'text_message' + description: 装饰器描述,用于生成文档字符串 -def on_image_message(priority=50): - """图片消息装饰器""" + Returns: + 装饰器函数 + """ + def decorator_factory(priority=50): + def decorator(func): + # 处理无参数调用: @on_xxx_message 时 priority 实际是被装饰的函数 + if callable(priority): + target_func = priority + actual_priority = 50 + else: + target_func = func + actual_priority = min(max(priority, 0), 99) - def decorator(func): + setattr(target_func, '_event_type', event_type) + setattr(target_func, '_priority', actual_priority) + return target_func + + # 判断调用方式 if callable(priority): - f = priority - setattr(f, '_event_type', 'image_message') - setattr(f, '_priority', 50) - return f - setattr(func, '_event_type', 'image_message') - setattr(func, '_priority', min(max(priority, 0), 99)) - return func + return decorator(priority) + return decorator - return decorator if not callable(priority) else decorator(priority) + decorator_factory.__doc__ = f"{description}装饰器" + decorator_factory.__name__ = f"on_{event_type}" + return decorator_factory -def on_voice_message(priority=50): - """语音消息装饰器""" +# ==================== 消息类型定义 ==================== - def decorator(func): - if callable(priority): - f = priority - setattr(f, '_event_type', 'voice_message') - setattr(f, '_priority', 50) - return f - setattr(func, '_event_type', 'voice_message') - setattr(func, '_priority', min(max(priority, 0), 99)) - return func - - return decorator if not callable(priority) else decorator(priority) +# 事件类型 -> 中文描述 映射表 +MESSAGE_DECORATOR_TYPES: Dict[str, str] = { + 'text_message': '文本消息', + 'image_message': '图片消息', + 'voice_message': '语音消息', + 'video_message': '视频消息', + 'emoji_message': '表情消息', + 'file_message': '文件消息', + 'quote_message': '引用消息', + 'pat_message': '拍一拍', + 'at_message': '@消息', + 'system_message': '系统消息', + 'other_message': '其他消息', +} -def on_emoji_message(priority=50): - """表情消息装饰器""" +# ==================== 生成所有消息装饰器 ==================== - def decorator(func): - if callable(priority): - f = priority - setattr(f, '_event_type', 'emoji_message') - setattr(f, '_priority', 50) - return f - setattr(func, '_event_type', 'emoji_message') - setattr(func, '_priority', min(max(priority, 0), 99)) - return func - - return decorator if not callable(priority) else decorator(priority) +# 使用工厂函数生成装饰器 +on_text_message = 
_create_message_decorator('text_message', '文本消息') +on_image_message = _create_message_decorator('image_message', '图片消息') +on_voice_message = _create_message_decorator('voice_message', '语音消息') +on_video_message = _create_message_decorator('video_message', '视频消息') +on_emoji_message = _create_message_decorator('emoji_message', '表情消息') +on_file_message = _create_message_decorator('file_message', '文件消息') +on_quote_message = _create_message_decorator('quote_message', '引用消息') +on_pat_message = _create_message_decorator('pat_message', '拍一拍') +on_at_message = _create_message_decorator('at_message', '@消息') +on_system_message = _create_message_decorator('system_message', '系统消息') +on_other_message = _create_message_decorator('other_message', '其他消息') -def on_file_message(priority=50): - """文件消息装饰器""" +# ==================== 导出列表 ==================== - def decorator(func): - if callable(priority): - f = priority - setattr(f, '_event_type', 'file_message') - setattr(f, '_priority', 50) - return f - setattr(func, '_event_type', 'file_message') - setattr(func, '_priority', min(max(priority, 0), 99)) - return func - - return decorator if not callable(priority) else decorator(priority) - - -def on_quote_message(priority=50): - """引用消息装饰器""" - - def decorator(func): - if callable(priority): - f = priority - setattr(f, '_event_type', 'quote_message') - setattr(f, '_priority', 50) - return f - setattr(func, '_event_type', 'quote_message') - setattr(func, '_priority', min(max(priority, 0), 99)) - return func - - return decorator if not callable(priority) else decorator(priority) - - -def on_video_message(priority=50): - """视频消息装饰器""" - - def decorator(func): - if callable(priority): - f = priority - setattr(f, '_event_type', 'video_message') - setattr(f, '_priority', 50) - return f - setattr(func, '_event_type', 'video_message') - setattr(func, '_priority', min(max(priority, 0), 99)) - return func - - return decorator if not callable(priority) else decorator(priority) - - -def on_pat_message(priority=50): - """拍一拍消息装饰器""" - - def decorator(func): - if callable(priority): - f = priority - setattr(f, '_event_type', 'pat_message') - setattr(f, '_priority', 50) - return f - setattr(func, '_event_type', 'pat_message') - setattr(func, '_priority', min(max(priority, 0), 99)) - return func - - return decorator if not callable(priority) else decorator(priority) - - -def on_at_message(priority=50): - """被@消息装饰器""" - - def decorator(func): - if callable(priority): - f = priority - setattr(f, '_event_type', 'at_message') - setattr(f, '_priority', 50) - return f - setattr(func, '_event_type', 'at_message') - setattr(func, '_priority', min(max(priority, 0), 99)) - return func - - return decorator if not callable(priority) else decorator(priority) - - -def on_system_message(priority=50): - """其他消息装饰器""" - - def decorator(func): - if callable(priority): - f = priority - setattr(f, '_event_type', 'system_message') - setattr(f, '_priority', 50) - return f - setattr(func, '_event_type', 'other_message') - setattr(func, '_priority', min(max(priority, 0), 99)) - return func - - return decorator if not callable(priority) else decorator(priority) - - -def on_other_message(priority=50): - """其他消息装饰器""" - - def decorator(func): - if callable(priority): - f = priority - setattr(f, '_event_type', 'other_message') - setattr(f, '_priority', 50) - return f - setattr(func, '_event_type', 'other_message') - setattr(func, '_priority', min(max(priority, 0), 99)) - return func - - return decorator if not callable(priority) else 
decorator(priority) +__all__ = [ + # 定时任务 + 'scheduler', + 'schedule', + 'add_job_safe', + 'remove_job_safe', + # 消息装饰器 + 'on_text_message', + 'on_image_message', + 'on_voice_message', + 'on_video_message', + 'on_emoji_message', + 'on_file_message', + 'on_quote_message', + 'on_pat_message', + 'on_at_message', + 'on_system_message', + 'on_other_message', + # 工具 + 'MESSAGE_DECORATOR_TYPES', + '_create_message_decorator', +] diff --git a/utils/errors.py b/utils/errors.py new file mode 100644 index 0000000..a62a81d --- /dev/null +++ b/utils/errors.py @@ -0,0 +1,438 @@ +""" +统一错误处理模块 + +提供: +- 自定义异常类层次结构 +- 错误包装和转换 +- 用户友好的错误消息 +- 错误日志和追踪 + +使用示例: + from utils.errors import PluginError, ToolExecutionError, handle_error + + try: + await some_operation() + except Exception as e: + result = handle_error(e, context="执行工具") + # result = {"success": False, "error": "...", "error_type": "..."} +""" + +from __future__ import annotations + +import traceback +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Dict, Optional, Type + +from loguru import logger + + +# ==================== 错误类型枚举 ==================== + +class ErrorType(Enum): + """错误类型分类""" + UNKNOWN = "unknown" + PLUGIN = "plugin" + TOOL = "tool" + CONFIG = "config" + NETWORK = "network" + TIMEOUT = "timeout" + VALIDATION = "validation" + PERMISSION = "permission" + RESOURCE = "resource" + + +# ==================== 自定义异常基类 ==================== + +class BotError(Exception): + """机器人错误基类""" + + error_type: ErrorType = ErrorType.UNKNOWN + user_message: str = "发生了一个错误" + log_level: str = "error" + + def __init__( + self, + message: str, + user_message: str = None, + cause: Exception = None, + context: Dict[str, Any] = None, + ): + super().__init__(message) + self.message = message + self._user_message = user_message + self.cause = cause + self.context = context or {} + + def get_user_message(self) -> str: + """获取用户友好的错误消息""" + return self._user_message or self.user_message + + def to_dict(self) -> Dict[str, Any]: + """转换为字典(用于 API 响应)""" + return { + "success": False, + "error": self.get_user_message(), + "error_type": self.error_type.value, + "details": self.message if self.message != self.get_user_message() else None, + } + + +# ==================== 具体异常类 ==================== + +class PluginError(BotError): + """插件相关错误""" + error_type = ErrorType.PLUGIN + user_message = "插件执行出错" + + +class PluginLoadError(PluginError): + """插件加载错误""" + user_message = "插件加载失败" + + +class PluginNotFoundError(PluginError): + """插件未找到""" + user_message = "找不到指定的插件" + + +class ToolExecutionError(BotError): + """工具执行错误""" + error_type = ErrorType.TOOL + user_message = "工具执行失败" + + +class ToolNotFoundError(ToolExecutionError): + """工具未找到""" + user_message = "找不到指定的工具" + + +class ToolTimeoutError(ToolExecutionError): + """工具执行超时""" + error_type = ErrorType.TIMEOUT + user_message = "工具执行超时" + + +class ConfigError(BotError): + """配置相关错误""" + error_type = ErrorType.CONFIG + user_message = "配置错误" + + +class ConfigNotFoundError(ConfigError): + """配置项未找到""" + user_message = "找不到配置项" + + +class ConfigValidationError(ConfigError): + """配置验证错误""" + error_type = ErrorType.VALIDATION + user_message = "配置格式不正确" + + +class NetworkError(BotError): + """网络相关错误""" + error_type = ErrorType.NETWORK + user_message = "网络请求失败" + + +class APIError(NetworkError): + """API 调用错误""" + user_message = "API 调用失败" + + +class ValidationError(BotError): + """验证错误""" + error_type = ErrorType.VALIDATION + user_message = "参数验证失败" + + +class 
PermissionError(BotError): + """权限错误""" + error_type = ErrorType.PERMISSION + user_message = "没有权限执行此操作" + + +class ResourceError(BotError): + """资源错误(内存、文件等)""" + error_type = ErrorType.RESOURCE + user_message = "资源访问失败" + + +# ==================== 错误处理工具函数 ==================== + +@dataclass +class ErrorResult: + """错误处理结果""" + success: bool = False + error: str = "" + error_type: str = "unknown" + details: Optional[str] = None + logged: bool = False + original_exception: Optional[Exception] = field(default=None, repr=False) + + def to_dict(self) -> Dict[str, Any]: + """转换为字典""" + result = { + "success": self.success, + "error": self.error, + "error_type": self.error_type, + } + if self.details: + result["details"] = self.details + return result + + +def handle_error( + exception: Exception, + context: str = "", + log: bool = True, + include_traceback: bool = False, +) -> ErrorResult: + """ + 统一错误处理函数 + + Args: + exception: 捕获的异常 + context: 错误上下文描述 + log: 是否记录日志 + include_traceback: 是否包含完整堆栈 + + Returns: + ErrorResult 对象 + """ + # 处理自定义异常 + if isinstance(exception, BotError): + result = ErrorResult( + success=False, + error=exception.get_user_message(), + error_type=exception.error_type.value, + details=exception.message if exception.message != exception.get_user_message() else None, + original_exception=exception, + ) + if log: + log_func = getattr(logger, exception.log_level, logger.error) + log_func(f"[{context}] {exception.error_type.value}: {exception.message}") + result.logged = True + return result + + # 处理标准超时异常 + import asyncio + if isinstance(exception, asyncio.TimeoutError): + result = ErrorResult( + success=False, + error="操作超时", + error_type=ErrorType.TIMEOUT.value, + original_exception=exception, + ) + if log: + logger.warning(f"[{context}] 超时: {exception}") + result.logged = True + return result + + # 处理连接错误 + if isinstance(exception, (ConnectionError, OSError)): + result = ErrorResult( + success=False, + error="网络连接失败", + error_type=ErrorType.NETWORK.value, + details=str(exception), + original_exception=exception, + ) + if log: + logger.error(f"[{context}] 网络错误: {exception}") + result.logged = True + return result + + # 处理验证错误 + if isinstance(exception, (ValueError, TypeError)): + result = ErrorResult( + success=False, + error="参数错误", + error_type=ErrorType.VALIDATION.value, + details=str(exception), + original_exception=exception, + ) + if log: + logger.warning(f"[{context}] 验证错误: {exception}") + result.logged = True + return result + + # 处理未知错误 + error_msg = str(exception) or exception.__class__.__name__ + details = None + if include_traceback: + details = traceback.format_exc() + + result = ErrorResult( + success=False, + error=f"发生错误: {error_msg[:100]}", + error_type=ErrorType.UNKNOWN.value, + details=details, + original_exception=exception, + ) + + if log: + logger.error(f"[{context}] 未知错误: {exception}") + if include_traceback: + logger.debug(traceback.format_exc()) + result.logged = True + + return result + + +def wrap_error( + exception: Exception, + error_class: Type[BotError], + message: str = None, + user_message: str = None, +) -> BotError: + """ + 将标准异常包装为自定义异常 + + Args: + exception: 原始异常 + error_class: 目标异常类 + message: 错误消息 + user_message: 用户友好消息 + + Returns: + 包装后的 BotError 子类实例 + """ + msg = message or str(exception) + return error_class( + message=msg, + user_message=user_message, + cause=exception, + ) + + +def safe_error_message(exception: Exception, max_length: int = 200) -> str: + """ + 获取安全的错误消息(截断过长内容,移除敏感信息) + + Args: + exception: 异常对象 + max_length: 
最大长度 + + Returns: + 安全的错误消息字符串 + """ + msg = str(exception) + + # 移除可能的敏感信息模式 + sensitive_patterns = [ + r'api[_-]?key[=:]\s*\S+', + r'password[=:]\s*\S+', + r'token[=:]\s*\S+', + r'secret[=:]\s*\S+', + ] + + import re + for pattern in sensitive_patterns: + msg = re.sub(pattern, '[REDACTED]', msg, flags=re.IGNORECASE) + + # 截断 + if len(msg) > max_length: + msg = msg[:max_length] + "..." + + return msg + + +# ==================== 装饰器 ==================== + +def catch_errors( + error_class: Type[BotError] = BotError, + context: str = "", + log: bool = True, + reraise: bool = False, +): + """ + 错误捕获装饰器 + + Args: + error_class: 转换为的错误类 + context: 上下文描述 + log: 是否记录日志 + reraise: 是否重新抛出 + + Usage: + @catch_errors(ToolExecutionError, context="执行工具") + async def my_tool(): + ... + """ + def decorator(func): + import asyncio + import functools + + @functools.wraps(func) + async def async_wrapper(*args, **kwargs): + try: + return await func(*args, **kwargs) + except BotError: + if reraise: + raise + return None + except Exception as e: + ctx = context or func.__name__ + handle_error(e, context=ctx, log=log) + if reraise: + raise wrap_error(e, error_class) from e + return None + + @functools.wraps(func) + def sync_wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except BotError: + if reraise: + raise + return None + except Exception as e: + ctx = context or func.__name__ + handle_error(e, context=ctx, log=log) + if reraise: + raise wrap_error(e, error_class) from e + return None + + if asyncio.iscoroutinefunction(func): + return async_wrapper + return sync_wrapper + + return decorator + + +# ==================== 导出列表 ==================== + +__all__ = [ + # 枚举 + 'ErrorType', + # 异常基类 + 'BotError', + # 插件异常 + 'PluginError', + 'PluginLoadError', + 'PluginNotFoundError', + # 工具异常 + 'ToolExecutionError', + 'ToolNotFoundError', + 'ToolTimeoutError', + # 配置异常 + 'ConfigError', + 'ConfigNotFoundError', + 'ConfigValidationError', + # 网络异常 + 'NetworkError', + 'APIError', + # 其他异常 + 'ValidationError', + 'PermissionError', + 'ResourceError', + # 工具函数 + 'ErrorResult', + 'handle_error', + 'wrap_error', + 'safe_error_message', + # 装饰器 + 'catch_errors', +] diff --git a/utils/event_manager.py b/utils/event_manager.py index 3eab59e..e2b8203 100644 --- a/utils/event_manager.py +++ b/utils/event_manager.py @@ -1,71 +1,315 @@ -import copy -from typing import Callable, Dict, List +""" +事件管理器模块 + +提供事件的注册、分发和管理: +- 优先级事件处理 +- 处理器缓存优化 +- 事件统计 +- 异常隔离 +""" + +import asyncio +import time +import traceback +from collections import defaultdict +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, List, Optional, Set, Tuple + +from loguru import logger + + +@dataclass +class HandlerInfo: + """事件处理器信息""" + handler: Callable + instance: object + priority: int + handler_name: str = field(default="") + + def __post_init__(self): + if not self.handler_name: + self.handler_name = f"{self.instance.__class__.__name__}.{self.handler.__name__}" + + +@dataclass +class EventStats: + """事件统计信息""" + emit_count: int = 0 + handler_calls: int = 0 + total_time_ms: float = 0 + error_count: int = 0 + stopped_count: int = 0 # 被 return False 中断的次数 class EventManager: - _handlers: Dict[str, List[tuple[Callable, object, int]]] = {} + """ + 事件管理器 + + 特性: + - 优先级排序(高优先级先执行) + - 处理器可返回 False 中断后续处理 + - 异常隔离(单个处理器异常不影响其他) + - 性能统计 + """ + + # 类级别存储 + _handlers: Dict[str, List[HandlerInfo]] = {} + _stats: Dict[str, EventStats] = defaultdict(EventStats) + _handler_cache: Dict[str, List[HandlerInfo]] = {} # 排序后的缓存 + 
_cache_valid: Set[str] = set() @classmethod def bind_instance(cls, instance: object): - """将实例绑定到对应的事件处理函数""" - from loguru import logger - registered_count = 0 - for method_name in dir(instance): - method = getattr(instance, method_name) - if hasattr(method, '_event_type'): - event_type = getattr(method, '_event_type') - priority = getattr(method, '_priority', 50) + """ + 绑定实例的事件处理方法 - if event_type not in cls._handlers: - cls._handlers[event_type] = [] - cls._handlers[event_type].append((method, instance, priority)) - # 按优先级排序,优先级高的在前 - cls._handlers[event_type].sort(key=lambda x: x[2], reverse=True) - registered_count += 1 - logger.debug(f"[EventManager] 注册事件处理器: {instance.__class__.__name__}.{method_name} -> {event_type} (优先级={priority})") + 扫描实例的所有方法,将带有 _event_type 属性的方法注册为事件处理器。 + + Args: + instance: 插件实例 + """ + registered_count = 0 + + for method_name in dir(instance): + if method_name.startswith('_'): + continue + + try: + method = getattr(instance, method_name) + except Exception: + continue + + if not callable(method) or not hasattr(method, '_event_type'): + continue + + event_type = getattr(method, '_event_type') + priority = getattr(method, '_priority', 50) + + handler_info = HandlerInfo( + handler=method, + instance=instance, + priority=priority, + ) + + if event_type not in cls._handlers: + cls._handlers[event_type] = [] + + cls._handlers[event_type].append(handler_info) + + # 使缓存失效 + cls._cache_valid.discard(event_type) + + registered_count += 1 + logger.debug( + f"[EventManager] 注册: {handler_info.handler_name} -> " + f"{event_type} (优先级={priority})" + ) if registered_count > 0: - logger.success(f"[EventManager] {instance.__class__.__name__} 注册了 {registered_count} 个事件处理器") - - @classmethod - async def emit(cls, event_type: str, *args, **kwargs) -> None: - """触发事件""" - from loguru import logger - - if event_type not in cls._handlers: - logger.debug(f"[EventManager] 事件 {event_type} 没有注册的处理器") - return - - logger.debug(f"[EventManager] 触发事件: {event_type}, 处理器数量: {len(cls._handlers[event_type])}") - - api_client, message = args - for handler, instance, priority in cls._handlers[event_type]: - try: - logger.debug(f"[EventManager] 调用处理器: {instance.__class__.__name__}.{handler.__name__}") - # 不再深拷贝message,让所有处理器共享同一个消息对象 - # 这样AutoReply设置的标记可以传递给AIChat - handler_args = (api_client, message) - new_kwargs = kwargs # kwargs也不需要深拷贝 - - result = await handler(*handler_args, **new_kwargs) - - if isinstance(result, bool): - # True 继续执行 False 停止执行 - if not result: - break - else: - continue # 我也不知道你返回了个啥玩意,反正继续执行就是了 - except Exception as e: - import traceback - logger.error(f"处理器 {handler.__name__} 执行失败: {e}") - logger.error(f"详细错误: {traceback.format_exc()}") + logger.success( + f"[EventManager] {instance.__class__.__name__} " + f"注册了 {registered_count} 个事件处理器" + ) @classmethod def unbind_instance(cls, instance: object): - """解绑实例的所有事件处理函数""" - for event_type in cls._handlers: + """ + 解绑实例的所有事件处理器 + + Args: + instance: 插件实例 + """ + unbound_count = 0 + + for event_type in list(cls._handlers.keys()): + original_count = len(cls._handlers[event_type]) cls._handlers[event_type] = [ - (handler, inst, priority) - for handler, inst, priority in cls._handlers[event_type] - if inst is not instance + h for h in cls._handlers[event_type] + if h.instance is not instance ] + removed = original_count - len(cls._handlers[event_type]) + if removed > 0: + unbound_count += removed + cls._cache_valid.discard(event_type) + + # 清理空列表 + if not cls._handlers[event_type]: + del cls._handlers[event_type] + + if 
unbound_count > 0: + logger.debug( + f"[EventManager] {instance.__class__.__name__} " + f"解绑了 {unbound_count} 个事件处理器" + ) + + @classmethod + def _get_sorted_handlers(cls, event_type: str) -> List[HandlerInfo]: + """获取排序后的处理器列表(带缓存)""" + if event_type not in cls._cache_valid: + handlers = cls._handlers.get(event_type, []) + # 按优先级降序排序 + cls._handler_cache[event_type] = sorted( + handlers, + key=lambda h: h.priority, + reverse=True + ) + cls._cache_valid.add(event_type) + + return cls._handler_cache.get(event_type, []) + + @classmethod + async def emit(cls, event_type: str, *args, **kwargs) -> bool: + """ + 触发事件 + + Args: + event_type: 事件类型 + *args: 传递给处理器的位置参数(通常是 api_client, message) + **kwargs: 传递给处理器的关键字参数 + + Returns: + True 表示所有处理器都执行了,False 表示被中断 + """ + handlers = cls._get_sorted_handlers(event_type) + + if not handlers: + logger.debug(f"[EventManager] 事件 {event_type} 没有处理器") + return True + + # 更新统计 + stats = cls._stats[event_type] + stats.emit_count += 1 + + start_time = time.time() + all_completed = True + + logger.debug( + f"[EventManager] 触发: {event_type}, " + f"处理器数量: {len(handlers)}" + ) + + for handler_info in handlers: + stats.handler_calls += 1 + + try: + logger.debug(f"[EventManager] 调用: {handler_info.handler_name}") + + result = await handler_info.handler(*args, **kwargs) + + # 检查是否中断 + if result is False: + stats.stopped_count += 1 + all_completed = False + logger.debug( + f"[EventManager] {handler_info.handler_name} " + f"返回 False,中断事件处理" + ) + break + + except Exception as e: + stats.error_count += 1 + logger.error( + f"[EventManager] {handler_info.handler_name} 执行失败: {e}" + ) + logger.debug(f"详细错误:\n{traceback.format_exc()}") + # 继续执行其他处理器 + + elapsed_ms = (time.time() - start_time) * 1000 + stats.total_time_ms += elapsed_ms + + return all_completed + + @classmethod + async def emit_parallel( + cls, + event_type: str, + *args, + max_concurrency: int = 5, + **kwargs + ) -> List[Any]: + """ + 并行触发事件(忽略优先级和中断) + + 适用于不需要顺序执行的场景。 + + Args: + event_type: 事件类型 + max_concurrency: 最大并发数 + *args, **kwargs: 传递给处理器的参数 + + Returns: + 所有处理器的返回值列表 + """ + handlers = cls._get_sorted_handlers(event_type) + + if not handlers: + return [] + + semaphore = asyncio.Semaphore(max_concurrency) + + async def run_handler(handler_info: HandlerInfo): + async with semaphore: + try: + return await handler_info.handler(*args, **kwargs) + except Exception as e: + logger.error(f"[EventManager] {handler_info.handler_name} 失败: {e}") + return None + + tasks = [run_handler(h) for h in handlers] + return await asyncio.gather(*tasks, return_exceptions=True) + + @classmethod + def get_handlers(cls, event_type: str) -> List[str]: + """获取事件的所有处理器名称""" + handlers = cls._get_sorted_handlers(event_type) + return [h.handler_name for h in handlers] + + @classmethod + def get_all_events(cls) -> List[str]: + """获取所有已注册的事件类型""" + return list(cls._handlers.keys()) + + @classmethod + def get_stats(cls, event_type: str = None) -> Dict[str, Any]: + """ + 获取事件统计信息 + + Args: + event_type: 指定事件类型,None 返回所有 + + Returns: + 统计信息字典 + """ + if event_type: + stats = cls._stats.get(event_type, EventStats()) + return { + "emit_count": stats.emit_count, + "handler_calls": stats.handler_calls, + "total_time_ms": stats.total_time_ms, + "avg_time_ms": stats.total_time_ms / max(stats.emit_count, 1), + "error_count": stats.error_count, + "stopped_count": stats.stopped_count, + } + + return { + event: cls.get_stats(event) + for event in cls._stats.keys() + } + + @classmethod + def reset_stats(cls): + """重置所有统计""" + cls._stats.clear() + + 
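The decorator factory above and the rewritten EventManager are designed to work together: decorated plugin methods are discovered by bind_instance and dispatched by emit in priority order. A minimal usage sketch (not part of the patch; the plugin class and the message dict are hypothetical):

    from utils.decorators import on_text_message
    from utils.event_manager import EventManager

    class EchoPlugin:
        @on_text_message(priority=80)            # higher priority runs before the default 50
        async def handle_text(self, bot, message):
            print(f"echo: {message.get('Content')}")
            return True                          # True/None continue the chain, False stops it

    plugin = EchoPlugin()
    EventManager.bind_instance(plugin)           # scans methods carrying _event_type

    # inside the message loop (async context):
    #   await EventManager.emit("text_message", client, message)
    #   EventManager.get_stats("text_message")   # emit_count, avg_time_ms, error_count, ...
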
@classmethod + def clear(cls): + """清除所有处理器和统计(用于测试)""" + cls._handlers.clear() + cls._handler_cache.clear() + cls._cache_valid.clear() + cls._stats.clear() + + +# ==================== 导出 ==================== + +__all__ = ['EventManager', 'HandlerInfo', 'EventStats'] diff --git a/utils/hookbot.py b/utils/hookbot.py index 848dc19..855f4b1 100644 --- a/utils/hookbot.py +++ b/utils/hookbot.py @@ -2,25 +2,48 @@ HookBot - 机器人核心类 处理消息路由和事件分发 +职责单一化:仅负责消息流程编排,具体功能委托给专门模块 """ -import asyncio -import tomllib -import time -from typing import Dict, Any +import random +from typing import Any, Dict, Optional + from loguru import logger from WechatHook import WechatHookClient, MESSAGE_TYPE_MAP, normalize_message from utils.event_manager import EventManager +from utils.config_manager import get_bot_config, get_performance_config +from utils.message_filter import MessageFilter +from utils.message_dedup import MessageDeduplicator +from utils.message_stats import MessageStats class HookBot: """ HookBot 核心类 - 负责消息处理、路由和事件分发 + 负责消息处理流程编排: + 1. 接收消息 + 2. 去重检查 + 3. 格式转换 + 4. 过滤检查 + 5. 事件分发 + + 具体功能委托给: + - MessageDeduplicator: 消息去重 + - MessageFilter: 消息过滤 + - MessageStats: 消息统计 """ + # API 响应消息类型(需要忽略) + API_RESPONSE_TYPES = {11032, 11174, 11230} + + # 重要消息类型(始终记录日志) + IMPORTANT_MESSAGE_TYPES = { + 11058, 11098, 11099, 11025, # 系统消息、群成员变动、登录信息 + 11051, 11047, 11052, 11055 # 视频、图片、表情、文件消息 + } + def __init__(self, client: WechatHookClient): """ 初始化 HookBot @@ -29,92 +52,33 @@ class HookBot: client: WechatHookClient 实例 """ self.client = client - self.wxid = None - self.nickname = None + self.wxid: Optional[str] = None + self.nickname: Optional[str] = None - # 读取配置 - with open("main_config.toml", "rb") as f: - main_config = tomllib.load(f) + # 加载配置 + bot_config = get_bot_config() + perf_config = get_performance_config() - bot_config = main_config.get("Bot", {}) + # 预设机器人信息 preset_wxid = bot_config.get("wxid") or bot_config.get("bot_wxid") preset_nickname = bot_config.get("nickname") or bot_config.get("bot_nickname") - if preset_wxid: self.wxid = preset_wxid logger.info(f"使用配置中的机器人 wxid: {self.wxid}") if preset_nickname: self.nickname = preset_nickname logger.info(f"使用配置中的机器人昵称: {self.nickname}") - self.ignore_mode = bot_config.get("ignore-mode", "None") - self.whitelist = bot_config.get("whitelist", []) - self.blacklist = bot_config.get("blacklist", []) - # 性能配置 - perf_config = main_config.get("Performance", {}) + # 日志采样率 self.log_sampling_rate = perf_config.get("log_sampling_rate", 1.0) - # 消息去重(部分环境会重复回调同一条消息,导致插件回复两次) - self._dedup_ttl_seconds = perf_config.get("dedup_ttl_seconds", 30) - self._dedup_max_size = perf_config.get("dedup_max_size", 5000) - self._dedup_lock = asyncio.Lock() - self._recent_message_keys: Dict[str, float] = {} - - # 消息计数和统计 - self.message_count = 0 - self.filtered_count = 0 - self.processed_count = 0 + # 初始化组件(职责委托) + self._filter = MessageFilter.from_config(bot_config) + self._dedup = MessageDeduplicator.from_config(perf_config) + self._stats = MessageStats() logger.info("HookBot 初始化完成") - def _extract_msg_id(self, data: Dict[str, Any]) -> str: - """从原始回调数据中提取消息ID(用于去重)""" - for k in ("msgid", "msg_id", "MsgId", "id"): - v = data.get(k) - if v: - return str(v) - return "" - - async def _is_duplicate_message(self, msg_type: int, data: Dict[str, Any]) -> bool: - """判断该条消息是否为短时间内重复回调。""" - msg_id = self._extract_msg_id(data) - if not msg_id: - # 没有稳定 msgid 时不做去重,避免误伤(同一秒内同内容可能是用户真实重复发送) - return False - - key = f"msgid:{msg_id}" - - now = time.time() - ttl = 
max(float(self._dedup_ttl_seconds or 0), 0.0) - if ttl <= 0: - return False - - async with self._dedup_lock: - last_seen = self._recent_message_keys.get(key) - if last_seen is not None and (now - last_seen) < ttl: - return True - - # 记录/刷新 - self._recent_message_keys.pop(key, None) - self._recent_message_keys[key] = now - - # 清理过期 key(按插入顺序从旧到新) - cutoff = now - ttl - while self._recent_message_keys: - first_key = next(iter(self._recent_message_keys)) - if self._recent_message_keys.get(first_key, now) >= cutoff: - break - self._recent_message_keys.pop(first_key, None) - - # 限制大小,避免长期运行内存增长 - max_size = int(self._dedup_max_size or 0) - if max_size > 0: - while len(self._recent_message_keys) > max_size and self._recent_message_keys: - first_key = next(iter(self._recent_message_keys)) - self._recent_message_keys.pop(first_key, None) - - return False - def update_profile(self, wxid: str, nickname: str): """ 更新机器人信息 @@ -125,9 +89,10 @@ class HookBot: """ self.wxid = wxid self.nickname = nickname + self._filter.set_bot_wxid(wxid) logger.info(f"机器人信息: wxid={wxid}, nickname={nickname}") - async def process_message(self, msg_type: int, data: dict): + async def process_message(self, msg_type: int, data: Dict[str, Any]): """ 处理接收到的消息 @@ -135,131 +100,105 @@ class HookBot: msg_type: 消息类型 data: 消息数据 """ - # 过滤 API 响应消息 - # - 11032: 获取群成员信息响应 - # - 11174/11230: 协议/上传等 API 回调 - if msg_type in [11032, 11174, 11230]: + # 1. 过滤 API 响应消息 + if msg_type in self.API_RESPONSE_TYPES: return - # 去重:同一条消息重复回调时不再重复触发事件(避免“同一句话回复两次”) + # 2. 去重检查 try: - if await self._is_duplicate_message(msg_type, data): - logger.debug(f"[HookBot] 重复消息已丢弃: type={msg_type}, msgid={self._extract_msg_id(data) or 'N/A'}") + if await self._dedup.is_duplicate(data): + self._stats.record_duplicate() + logger.debug(f"[HookBot] 重复消息已丢弃: type={msg_type}") return except Exception as e: - # 去重失败不影响主流程 logger.debug(f"[HookBot] 消息去重检查失败: {e}") - # 消息计数 - self.message_count += 1 - - # 日志采样 - 只记录部分消息以减少日志量 + # 3. 记录收到消息 + self._stats.record_received() should_log = self._should_log_message(msg_type) if should_log: logger.debug(f"处理消息: type={msg_type}") # 重要事件始终记录 - if msg_type in [11098, 11099, 11058]: # 群成员变动、系统消息 + if msg_type in self.IMPORTANT_MESSAGE_TYPES: logger.info(f"重要事件: type={msg_type}") - # 获取事件类型 + # 4. 获取事件类型 event_type = MESSAGE_TYPE_MAP.get(msg_type) - + if should_log and event_type: logger.info(f"[HookBot] 消息类型映射: {msg_type} -> {event_type}") if not event_type: - # 记录未知消息类型的详细信息,帮助调试 content_preview = str(data.get('raw_msg', data.get('msg', '')))[:200] - logger.warning(f"未映射的消息类型: {msg_type}, wx_type: {data.get('wx_type')}, 内容预览: {content_preview}") + logger.warning( + f"未映射的消息类型: {msg_type}, " + f"wx_type: {data.get('wx_type')}, " + f"内容预览: {content_preview}" + ) return - # 格式转换 + # 5. 格式转换 try: message = normalize_message(msg_type, data) except Exception as e: logger.error(f"格式转换失败: {e}") + self._stats.record_error() return - # 过滤消息 - if not self._check_filter(message): - self.filtered_count += 1 + # 6. 过滤检查 + if not self._filter.should_process(message): + self._stats.record_filtered() if should_log: logger.debug(f"消息被过滤: {message.get('FromWxid')}") return - self.processed_count += 1 + # 7. 记录处理 + self._stats.record_processed(event_type) - # 采样记录处理的消息 if should_log: content = message.get('Content', '') if len(content) > 50: content = content[:50] + "..." 
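process_message now only orchestrates; the dedup and filter checks live in the delegated components. A rough sketch of that delegation outside HookBot (illustrative only: the config values are made up, and it assumes MessageFilter.from_config keeps the existing ignore-mode/whitelist/blacklist keys):

    from utils.message_filter import MessageFilter
    from utils.message_dedup import MessageDeduplicator

    bot_config  = {"ignore-mode": "Blacklist", "blacklist": ["wxid_spammer"]}
    perf_config = {"dedup_ttl_seconds": 30, "dedup_max_size": 5000}

    flt   = MessageFilter.from_config(bot_config)
    dedup = MessageDeduplicator.from_config(perf_config)

    async def handle(raw_data: dict, message: dict):
        if await dedup.is_duplicate(raw_data):   # raw callback data, keyed by msgid
            return
        if not flt.should_process(message):      # normalized message dict
            return
        # in HookBot, normalize_message runs between these checks and EventManager.emit runs after them
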
- logger.info(f"处理消息: type={event_type}, from={message.get('FromWxid')}, content={content}") + logger.info( + f"处理消息: type={event_type}, " + f"from={message.get('FromWxid')}, " + f"content={content}" + ) - # 触发事件 + # 8. 触发事件 try: await EventManager.emit(event_type, self.client, message) except Exception as e: logger.error(f"事件处理失败: {e}") + self._stats.record_error() def _should_log_message(self, msg_type: int) -> bool: """判断是否应该记录此消息的日志""" - # 重要消息类型始终记录 - important_types = { - 11058, 11098, 11099, 11025, # 系统消息、群成员变动、登录信息 - 11051, 11047, 11052, 11055 # 视频、图片、表情、文件消息 - } - if msg_type in important_types: + if msg_type in self.IMPORTANT_MESSAGE_TYPES: return True - - # 其他消息按采样率记录 - import random return random.random() < self.log_sampling_rate - def _check_filter(self, message: Dict[str, Any]) -> bool: - """ - 检查消息是否通过过滤 - - Args: - message: 消息字典 - - Returns: - 是否通过过滤 - """ - from_wxid = message.get("FromWxid", "") - sender_wxid = message.get("SenderWxid", "") - msg_type = message.get("MsgType", 0) - - # 系统消息(type=11058)不过滤,因为包含重要的群聊事件信息 - if msg_type == 11058: - return True - - # 过滤机器人自己发送的消息,避免无限循环 - if self.wxid and (from_wxid == self.wxid or sender_wxid == self.wxid): - return False - - # None 模式:处理所有消息 - if self.ignore_mode == "None": - return True - - # Whitelist 模式:仅处理白名单 - if self.ignore_mode == "Whitelist": - return from_wxid in self.whitelist or sender_wxid in self.whitelist - - # Blacklist 模式:屏蔽黑名单 - if self.ignore_mode == "Blacklist": - return from_wxid not in self.blacklist and sender_wxid not in self.blacklist - - return True - - def get_stats(self) -> dict: + def get_stats(self) -> Dict[str, Any]: """获取消息处理统计信息""" - return { - "total_messages": self.message_count, - "filtered_messages": self.filtered_count, - "processed_messages": self.processed_count, - "filter_rate": self.filtered_count / max(self.message_count, 1), - "process_rate": self.processed_count / max(self.message_count, 1) - } + stats = self._stats.get_stats() + stats["dedup"] = self._dedup.get_stats() + return stats + + # ==================== 兼容旧接口 ==================== + + @property + def message_count(self) -> int: + """兼容旧接口:总消息数""" + return self._stats.get_stats()["total_messages"] + + @property + def filtered_count(self) -> int: + """兼容旧接口:被过滤消息数""" + return self._stats.get_stats()["filtered_messages"] + + @property + def processed_count(self) -> int: + """兼容旧接口:已处理消息数""" + return self._stats.get_stats()["processed_messages"] diff --git a/utils/image_processor.py b/utils/image_processor.py new file mode 100644 index 0000000..1aae4cc --- /dev/null +++ b/utils/image_processor.py @@ -0,0 +1,690 @@ +""" +图片/视频处理模块 + +提供媒体文件的下载、编码和描述生成: +- 图片下载与 base64 编码 +- 表情包下载与编码 +- 视频下载与编码 +- AI 图片/视频描述生成 + +使用示例: + from utils.image_processor import ImageProcessor, MediaConfig + + config = MediaConfig( + api_url="https://api.openai.com/v1/chat/completions", + api_key="sk-xxx", + model="gpt-4-vision-preview", + ) + processor = ImageProcessor(config) + + # 下载图片 + image_base64 = await processor.download_image(bot, cdnurl, aeskey) + + # 生成描述 + description = await processor.generate_description(image_base64, "描述这张图片") +""" + +from __future__ import annotations + +import asyncio +import base64 +import json +import uuid +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Dict, Optional, TYPE_CHECKING + +import aiohttp +from loguru import logger + +# 可选代理支持 +try: + from aiohttp_socks import ProxyConnector + PROXY_SUPPORT = True +except ImportError: + PROXY_SUPPORT = False + +if TYPE_CHECKING: + 
pass # bot 类型提示 + + +@dataclass +class MediaConfig: + """媒体处理配置""" + # API 配置 + api_url: str = "https://api.openai.com/v1/chat/completions" + api_key: str = "" + model: str = "gpt-4-vision-preview" + timeout: int = 120 + max_tokens: int = 1000 + retries: int = 2 + + # 代理配置 + proxy_enabled: bool = False + proxy_type: str = "socks5" + proxy_host: str = "127.0.0.1" + proxy_port: int = 7890 + proxy_username: str = "" + proxy_password: str = "" + + # 视频专用配置 + video_api_url: str = "" + video_model: str = "" + video_max_size_mb: int = 20 + video_timeout: int = 360 + video_max_tokens: int = 8192 + + # 临时目录 + temp_dir: Optional[Path] = None + + @classmethod + def from_dict(cls, config: Dict[str, Any]) -> "MediaConfig": + """从配置字典创建""" + api_config = config.get("api", {}) + proxy_config = config.get("proxy", {}) + image_desc_config = config.get("image_description", {}) + video_config = config.get("video_recognition", {}) + + return cls( + api_url=api_config.get("url", "https://api.openai.com/v1/chat/completions"), + api_key=api_config.get("api_key", ""), + model=image_desc_config.get("model", api_config.get("model", "gpt-4-vision-preview")), + timeout=api_config.get("timeout", 120), + max_tokens=image_desc_config.get("max_tokens", 1000), + retries=image_desc_config.get("retries", 2), + proxy_enabled=proxy_config.get("enabled", False), + proxy_type=proxy_config.get("type", "socks5"), + proxy_host=proxy_config.get("host", "127.0.0.1"), + proxy_port=proxy_config.get("port", 7890), + proxy_username=proxy_config.get("username", ""), + proxy_password=proxy_config.get("password", ""), + video_api_url=video_config.get("api_url", ""), + video_model=video_config.get("model", ""), + video_max_size_mb=video_config.get("max_size_mb", 20), + video_timeout=video_config.get("timeout", 360), + video_max_tokens=video_config.get("max_tokens", 8192), + ) + + +@dataclass +class MediaResult: + """媒体处理结果""" + success: bool = False + data: str = "" # base64 数据 + description: str = "" + error: Optional[str] = None + media_type: str = "image" # image, emoji, video + + +class ImageProcessor: + """ + 图片/视频处理器 + + 提供统一的媒体处理接口: + - 下载和编码 + - AI 描述生成 + - 缓存支持 + """ + + def __init__(self, config: MediaConfig, temp_dir: Optional[Path] = None): + self.config = config + self.temp_dir = temp_dir or config.temp_dir or Path("temp") + self.temp_dir.mkdir(exist_ok=True) + + def _get_proxy_connector(self) -> Optional[Any]: + """获取代理连接器""" + if not self.config.proxy_enabled or not PROXY_SUPPORT: + return None + + proxy_type = self.config.proxy_type.upper() + if self.config.proxy_username and self.config.proxy_password: + proxy_url = ( + f"{proxy_type}://{self.config.proxy_username}:" + f"{self.config.proxy_password}@" + f"{self.config.proxy_host}:{self.config.proxy_port}" + ) + else: + proxy_url = f"{proxy_type}://{self.config.proxy_host}:{self.config.proxy_port}" + + try: + return ProxyConnector.from_url(proxy_url) + except Exception as e: + logger.warning(f"[ImageProcessor] 代理配置失败: {e}") + return None + + async def download_image( + self, + bot, + cdnurl: str, + aeskey: str, + use_cache: bool = True, + ) -> str: + """ + 下载图片并转换为 base64 + + Args: + bot: WechatHookClient 实例(用于 CDN 下载) + cdnurl: CDN URL + aeskey: AES 密钥 + use_cache: 是否使用缓存 + + Returns: + base64 编码的图片数据(带 data URI 前缀) + """ + try: + # 1. 
优先从 Redis 缓存获取 + if use_cache: + from utils.redis_cache import RedisCache, get_cache + redis_cache = get_cache() + if redis_cache and redis_cache.enabled: + media_key = RedisCache.generate_media_key(cdnurl, aeskey) + if media_key: + cached_data = redis_cache.get_cached_media(media_key, "image") + if cached_data: + logger.debug(f"[ImageProcessor] 图片缓存命中: {media_key[:20]}...") + return cached_data + + # 2. 缓存未命中,下载图片 + logger.debug(f"[ImageProcessor] 开始下载图片...") + + filename = f"temp_{uuid.uuid4().hex[:8]}.jpg" + save_path = str((self.temp_dir / filename).resolve()) + + # 尝试下载中图,失败则下载原图 + success = await bot.cdn_download(cdnurl, aeskey, save_path, file_type=2) + if not success: + success = await bot.cdn_download(cdnurl, aeskey, save_path, file_type=1) + + if not success: + logger.error("[ImageProcessor] CDN 下载失败") + return "" + + # 等待文件写入完成 + import os + for _ in range(20): # 最多等待10秒 + if os.path.exists(save_path) and os.path.getsize(save_path) > 0: + break + await asyncio.sleep(0.5) + + if not os.path.exists(save_path): + logger.error("[ImageProcessor] 图片文件未生成") + return "" + + with open(save_path, "rb") as f: + image_data = base64.b64encode(f.read()).decode() + + base64_result = f"data:image/jpeg;base64,{image_data}" + + # 3. 缓存到 Redis + if use_cache: + try: + from utils.redis_cache import RedisCache, get_cache + redis_cache = get_cache() + if redis_cache and redis_cache.enabled: + media_key = RedisCache.generate_media_key(cdnurl, aeskey) + if media_key: + redis_cache.cache_media(media_key, base64_result, "image", ttl=300) + logger.debug(f"[ImageProcessor] 图片已缓存: {media_key[:20]}...") + except Exception as e: + logger.debug(f"[ImageProcessor] 缓存图片失败: {e}") + + # 清理临时文件 + try: + Path(save_path).unlink() + except Exception: + pass + + return base64_result + + except Exception as e: + logger.error(f"[ImageProcessor] 下载图片失败: {e}") + return "" + + async def download_emoji( + self, + cdn_url: str, + max_retries: int = 3, + use_cache: bool = True, + ) -> str: + """ + 下载表情包并转换为 base64 + + Args: + cdn_url: CDN URL + max_retries: 最大重试次数 + use_cache: 是否使用缓存 + + Returns: + base64 编码的表情包数据(带 data URI 前缀) + """ + # 替换 HTML 实体 + cdn_url = cdn_url.replace("&", "&") + + # 1. 优先从 Redis 缓存获取 + media_key = None + if use_cache: + try: + from utils.redis_cache import RedisCache, get_cache + redis_cache = get_cache() + media_key = RedisCache.generate_media_key(cdnurl=cdn_url) + if redis_cache and redis_cache.enabled and media_key: + cached_data = redis_cache.get_cached_media(media_key, "emoji") + if cached_data: + logger.debug(f"[ImageProcessor] 表情包缓存命中: {media_key[:20]}...") + return cached_data + except Exception: + pass + + # 2. 缓存未命中,下载表情包 + logger.debug(f"[ImageProcessor] 开始下载表情包...") + + last_error = None + connector = self._get_proxy_connector() + + for attempt in range(max_retries): + try: + timeout = aiohttp.ClientTimeout(total=30 + attempt * 15) + + async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session: + async with session.get(cdn_url) as response: + if response.status == 200: + content = await response.read() + + if len(content) == 0: + logger.warning(f"[ImageProcessor] 表情包内容为空,重试 {attempt + 1}/{max_retries}") + continue + + image_data = base64.b64encode(content).decode() + base64_result = f"data:image/gif;base64,{image_data}" + + logger.debug(f"[ImageProcessor] 表情包下载成功,大小: {len(content)} 字节") + + # 3. 
缓存到 Redis + if use_cache and media_key: + try: + from utils.redis_cache import get_cache + redis_cache = get_cache() + if redis_cache and redis_cache.enabled: + redis_cache.cache_media(media_key, base64_result, "emoji", ttl=300) + logger.debug(f"[ImageProcessor] 表情包已缓存: {media_key[:20]}...") + except Exception: + pass + + return base64_result + else: + logger.warning(f"[ImageProcessor] 表情包下载失败,状态码: {response.status}") + + except asyncio.TimeoutError: + last_error = "请求超时" + logger.warning(f"[ImageProcessor] 表情包下载超时,重试 {attempt + 1}/{max_retries}") + except aiohttp.ClientError as e: + last_error = str(e) + logger.warning(f"[ImageProcessor] 表情包下载网络错误: {e}") + except Exception as e: + last_error = str(e) + logger.warning(f"[ImageProcessor] 表情包下载异常: {e}") + + if attempt < max_retries - 1: + await asyncio.sleep(1 * (attempt + 1)) + + logger.error(f"[ImageProcessor] 表情包下载失败,已重试 {max_retries} 次: {last_error}") + return "" + + async def download_video( + self, + bot, + cdnurl: str, + aeskey: str, + use_cache: bool = True, + ) -> str: + """ + 下载视频并转换为 base64 + + Args: + bot: WechatHookClient 实例 + cdnurl: CDN URL + aeskey: AES 密钥 + use_cache: 是否使用缓存 + + Returns: + base64 编码的视频数据(带 data URI 前缀) + """ + try: + # 从缓存获取 + media_key = None + if use_cache: + try: + from utils.redis_cache import RedisCache, get_cache + redis_cache = get_cache() + if redis_cache and redis_cache.enabled: + media_key = RedisCache.generate_media_key(cdnurl, aeskey) + if media_key: + cached_data = redis_cache.get_cached_media(media_key, "video") + if cached_data: + logger.debug(f"[ImageProcessor] 视频缓存命中: {media_key[:20]}...") + return cached_data + except Exception: + pass + + # 下载视频 + logger.info(f"[ImageProcessor] 开始下载视频...") + + filename = f"video_{uuid.uuid4().hex[:8]}.mp4" + save_path = str((self.temp_dir / filename).resolve()) + + # file_type=4 表示视频 + success = await bot.cdn_download(cdnurl, aeskey, save_path, file_type=4) + if not success: + logger.error("[ImageProcessor] 视频 CDN 下载失败") + return "" + + # 等待文件写入完成 + import os + for _ in range(30): + if os.path.exists(save_path) and os.path.getsize(save_path) > 0: + break + await asyncio.sleep(0.5) + + if not os.path.exists(save_path): + logger.error("[ImageProcessor] 视频文件未生成") + return "" + + file_size = os.path.getsize(save_path) + logger.info(f"[ImageProcessor] 视频下载完成,大小: {file_size / 1024 / 1024:.2f} MB") + + # 检查文件大小限制 + max_size_mb = self.config.video_max_size_mb + if file_size > max_size_mb * 1024 * 1024: + logger.warning(f"[ImageProcessor] 视频文件过大: {file_size / 1024 / 1024:.2f} MB > {max_size_mb} MB") + try: + Path(save_path).unlink() + except Exception: + pass + return "" + + # 读取并编码 + with open(save_path, "rb") as f: + video_data = base64.b64encode(f.read()).decode() + + video_base64 = f"data:video/mp4;base64,{video_data}" + + # 缓存到 Redis + if use_cache and media_key: + try: + from utils.redis_cache import get_cache + redis_cache = get_cache() + if redis_cache and redis_cache.enabled: + redis_cache.cache_media(media_key, video_base64, "video", ttl=600) + logger.debug(f"[ImageProcessor] 视频已缓存: {media_key[:20]}...") + except Exception: + pass + + # 清理临时文件 + try: + Path(save_path).unlink() + except Exception: + pass + + return video_base64 + + except Exception as e: + logger.error(f"[ImageProcessor] 下载视频失败: {e}") + import traceback + logger.error(traceback.format_exc()) + return "" + + async def generate_description( + self, + image_base64: str, + prompt: str = "请用一句话简洁地描述这张图片的主要内容。", + model: Optional[str] = None, + ) -> str: + """ + 使用 AI 生成图片描述 + + Args: + 
image_base64: 图片的 base64 数据 + prompt: 描述提示词 + model: 使用的模型(默认使用配置中的模型) + + Returns: + 图片描述文本,失败返回空字符串 + """ + description_model = model or self.config.model + + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": prompt}, + {"type": "image_url", "image_url": {"url": image_base64}} + ] + } + ] + + payload = { + "model": description_model, + "messages": messages, + "max_tokens": self.config.max_tokens, + "stream": True + } + + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {self.config.api_key}" + } + + max_retries = self.config.retries + last_error = None + + for attempt in range(max_retries + 1): + try: + timeout = aiohttp.ClientTimeout(total=self.config.timeout) + connector = self._get_proxy_connector() + + async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session: + async with session.post( + self.config.api_url, + json=payload, + headers=headers + ) as resp: + if resp.status != 200: + error_text = await resp.text() + raise Exception(f"API 返回错误: {resp.status}, {error_text[:200]}") + + # 流式接收响应 + description = "" + async for line in resp.content: + line = line.decode('utf-8').strip() + if not line or line == "data: [DONE]": + continue + + if line.startswith("data: "): + try: + data = json.loads(line[6:]) + delta = data.get("choices", [{}])[0].get("delta", {}) + content = delta.get("content", "") + if content: + description += content + except Exception: + pass + + logger.debug(f"[ImageProcessor] 图片描述生成成功: {description[:50]}...") + return description.strip() + + except asyncio.CancelledError: + raise + except (aiohttp.ClientError, asyncio.TimeoutError) as e: + last_error = str(e) + if attempt < max_retries: + logger.warning(f"[ImageProcessor] 图片描述网络错误: {e},重试 {attempt + 1}/{max_retries}") + await asyncio.sleep(1 * (attempt + 1)) + continue + except Exception as e: + last_error = str(e) + if attempt < max_retries: + logger.warning(f"[ImageProcessor] 图片描述生成异常: {e},重试 {attempt + 1}/{max_retries}") + await asyncio.sleep(1 * (attempt + 1)) + continue + + logger.error(f"[ImageProcessor] 生成图片描述失败,已重试 {max_retries + 1} 次: {last_error}") + return "" + + async def analyze_video( + self, + video_base64: str, + prompt: Optional[str] = None, + ) -> str: + """ + 使用 AI 分析视频内容 + + Args: + video_base64: 视频的 base64 数据 + prompt: 分析提示词 + + Returns: + 视频分析描述,失败返回空字符串 + """ + if not self.config.video_api_url or not self.config.video_model: + logger.error("[ImageProcessor] 视频分析配置不完整") + return "" + + # 去除 data:video/mp4;base64, 前缀(如果有) + if video_base64.startswith("data:"): + video_base64 = video_base64.split(",", 1)[1] + + default_prompt = """请详细分析这个视频的内容,包括: +1. 视频的主要场景和环境 +2. 出现的人物/物体及其动作 +3. 视频中的文字、对话或声音(如果有) +4. 视频的整体主题或要表达的内容 +5. 
任何值得注意的细节 + +请用客观、详细的方式描述,不要加入主观评价。""" + + analyze_prompt = prompt or default_prompt + + full_url = f"{self.config.video_api_url}/{self.config.video_model}:generateContent" + + payload = { + "contents": [ + { + "parts": [ + {"text": analyze_prompt}, + { + "inline_data": { + "mime_type": "video/mp4", + "data": video_base64 + } + } + ] + } + ], + "generationConfig": { + "maxOutputTokens": self.config.video_max_tokens + } + } + + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {self.config.api_key}" + } + + timeout = aiohttp.ClientTimeout(total=self.config.video_timeout) + max_retries = 2 + retry_delay = 5 + + for attempt in range(max_retries + 1): + try: + logger.info(f"[ImageProcessor] 开始分析视频...{f' (重试 {attempt}/{max_retries})' if attempt > 0 else ''}") + + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.post(full_url, json=payload, headers=headers) as resp: + if resp.status in [502, 503, 504]: + logger.warning(f"[ImageProcessor] 视频 API 临时错误: {resp.status}") + if attempt < max_retries: + await asyncio.sleep(retry_delay) + continue + return "" + + if resp.status != 200: + error_text = await resp.text() + logger.error(f"[ImageProcessor] 视频 API 错误: {resp.status}, {error_text[:300]}") + return "" + + result = await resp.json() + + # 检查安全过滤 + if "promptFeedback" in result: + feedback = result["promptFeedback"] + if feedback.get("blockReason"): + logger.warning(f"[ImageProcessor] 视频内容被过滤: {feedback.get('blockReason')}") + return "" + + # 提取文本 + if "candidates" in result and result["candidates"]: + for candidate in result["candidates"]: + if candidate.get("finishReason") == "SAFETY": + logger.warning("[ImageProcessor] 视频响应被安全过滤") + return "" + + content = candidate.get("content", {}) + for part in content.get("parts", []): + if "text" in part: + text = part["text"] + logger.info(f"[ImageProcessor] 视频分析完成,长度: {len(text)}") + return text + + logger.error(f"[ImageProcessor] 视频分析无有效响应") + return "" + + except asyncio.TimeoutError: + logger.warning(f"[ImageProcessor] 视频分析超时{f', 将重试...' 
if attempt < max_retries else ''}") + if attempt < max_retries: + await asyncio.sleep(retry_delay) + continue + return "" + except Exception as e: + logger.error(f"[ImageProcessor] 视频分析失败: {e}") + import traceback + logger.error(traceback.format_exc()) + return "" + + return "" + + +# ==================== 便捷函数 ==================== + +_default_processor: Optional[ImageProcessor] = None + + +def get_image_processor(config: Optional[MediaConfig] = None) -> ImageProcessor: + """获取默认图片处理器""" + global _default_processor + if config: + _default_processor = ImageProcessor(config) + if _default_processor is None: + raise ValueError("ImageProcessor 未初始化,请先传入配置") + return _default_processor + + +def init_image_processor(config_dict: Dict[str, Any], temp_dir: Optional[Path] = None) -> ImageProcessor: + """从配置字典初始化图片处理器""" + config = MediaConfig.from_dict(config_dict) + if temp_dir: + config.temp_dir = temp_dir + processor = ImageProcessor(config, temp_dir) + global _default_processor + _default_processor = processor + return processor + + +# ==================== 导出 ==================== + +__all__ = [ + 'MediaConfig', + 'MediaResult', + 'ImageProcessor', + 'get_image_processor', + 'init_image_processor', +] diff --git a/utils/llm_client.py b/utils/llm_client.py new file mode 100644 index 0000000..51f0bee --- /dev/null +++ b/utils/llm_client.py @@ -0,0 +1,392 @@ +""" +LLM 客户端抽象层 + +提供统一的 LLM API 调用接口: +- 支持 OpenAI 兼容 API +- 自动重试和错误处理 +- 流式/非流式响应 +- 代理支持 +- Token 估算 + +使用示例: + from utils.llm_client import LLMClient, LLMConfig + + config = LLMConfig( + api_base="https://api.openai.com/v1", + api_key="sk-xxx", + model="gpt-4", + ) + client = LLMClient(config) + + response = await client.chat_completion( + messages=[{"role": "user", "content": "Hello"}], + tools=[...], + ) +""" + +from __future__ import annotations + +import asyncio +import json +import time +from dataclasses import dataclass, field +from typing import Any, AsyncGenerator, Dict, List, Optional, Union + +import aiohttp +from loguru import logger + +# 可选代理支持 +try: + from aiohttp_socks import ProxyConnector + PROXY_SUPPORT = True +except ImportError: + PROXY_SUPPORT = False + + +@dataclass +class LLMConfig: + """LLM 配置""" + api_base: str = "https://api.openai.com/v1" + api_key: str = "" + model: str = "gpt-4" + temperature: float = 0.7 + max_tokens: int = 4096 + timeout: int = 120 + max_retries: int = 3 + retry_delay: float = 1.0 + + # 代理配置 + proxy_enabled: bool = False + proxy_type: str = "socks5" + proxy_host: str = "127.0.0.1" + proxy_port: int = 7890 + + # 额外参数 + extra_params: Dict[str, Any] = field(default_factory=dict) + + @classmethod + def from_dict(cls, config: Dict[str, Any]) -> "LLMConfig": + """从配置字典创建""" + api_config = config.get("api", {}) + proxy_config = config.get("proxy", {}) + + return cls( + api_base=api_config.get("base_url", "https://api.openai.com/v1"), + api_key=api_config.get("api_key", ""), + model=api_config.get("model", "gpt-4"), + temperature=api_config.get("temperature", 0.7), + max_tokens=api_config.get("max_tokens", 4096), + timeout=api_config.get("timeout", 120), + max_retries=api_config.get("max_retries", 3), + retry_delay=api_config.get("retry_delay", 1.0), + proxy_enabled=proxy_config.get("enabled", False), + proxy_type=proxy_config.get("type", "socks5"), + proxy_host=proxy_config.get("host", "127.0.0.1"), + proxy_port=proxy_config.get("port", 7890), + ) + + +@dataclass +class LLMResponse: + """LLM 响应""" + content: str = "" + tool_calls: List[Dict[str, Any]] = field(default_factory=list) + finish_reason: str 
= "" + usage: Dict[str, int] = field(default_factory=dict) + raw_response: Dict[str, Any] = field(default_factory=dict) + error: Optional[str] = None + + @property + def has_tool_calls(self) -> bool: + return len(self.tool_calls) > 0 + + @property + def success(self) -> bool: + return self.error is None + + +class LLMClient: + """ + LLM 客户端 + + 提供统一的 API 调用接口,支持: + - OpenAI 兼容 API + - 自动重试 + - 代理 + - 流式响应 + """ + + def __init__(self, config: LLMConfig): + self.config = config + self._session: Optional[aiohttp.ClientSession] = None + + async def _get_session(self) -> aiohttp.ClientSession: + """获取或创建 HTTP 会话""" + if self._session is None or self._session.closed: + connector = None + + # 配置代理 + if self.config.proxy_enabled and PROXY_SUPPORT: + proxy_url = ( + f"{self.config.proxy_type}://" + f"{self.config.proxy_host}:{self.config.proxy_port}" + ) + connector = ProxyConnector.from_url(proxy_url) + logger.debug(f"[LLMClient] 使用代理: {proxy_url}") + + timeout = aiohttp.ClientTimeout(total=self.config.timeout) + self._session = aiohttp.ClientSession( + connector=connector, + timeout=timeout, + ) + + return self._session + + async def close(self): + """关闭会话""" + if self._session and not self._session.closed: + await self._session.close() + self._session = None + + def _build_headers(self) -> Dict[str, str]: + """构建请求头""" + return { + "Content-Type": "application/json", + "Authorization": f"Bearer {self.config.api_key}", + } + + def _build_payload( + self, + messages: List[Dict[str, Any]], + tools: Optional[List[Dict[str, Any]]] = None, + stream: bool = False, + **kwargs, + ) -> Dict[str, Any]: + """构建请求体""" + payload = { + "model": self.config.model, + "messages": messages, + "temperature": self.config.temperature, + "max_tokens": self.config.max_tokens, + "stream": stream, + } + + if tools: + payload["tools"] = tools + payload["tool_choice"] = kwargs.get("tool_choice", "auto") + + # 合并额外参数 + payload.update(self.config.extra_params) + payload.update(kwargs) + + return payload + + async def chat_completion( + self, + messages: List[Dict[str, Any]], + tools: Optional[List[Dict[str, Any]]] = None, + **kwargs, + ) -> LLMResponse: + """ + 非流式聊天补全 + + Args: + messages: 消息列表 + tools: 工具列表(可选) + **kwargs: 额外参数 + + Returns: + LLMResponse 对象 + """ + session = await self._get_session() + url = f"{self.config.api_base}/chat/completions" + headers = self._build_headers() + payload = self._build_payload(messages, tools, stream=False, **kwargs) + + last_error = None + + for attempt in range(self.config.max_retries): + try: + async with session.post(url, headers=headers, json=payload) as resp: + if resp.status == 200: + data = await resp.json() + return self._parse_response(data) + + error_text = await resp.text() + last_error = f"HTTP {resp.status}: {error_text[:200]}" + logger.warning(f"[LLMClient] 请求失败 (尝试 {attempt + 1}): {last_error}") + + # 不可重试的错误 + if resp.status in [400, 401, 403]: + break + + except asyncio.TimeoutError: + last_error = f"请求超时 ({self.config.timeout}s)" + logger.warning(f"[LLMClient] {last_error} (尝试 {attempt + 1})") + + except Exception as e: + last_error = str(e) + logger.warning(f"[LLMClient] 请求异常 (尝试 {attempt + 1}): {e}") + + # 重试延迟 + if attempt < self.config.max_retries - 1: + await asyncio.sleep(self.config.retry_delay * (attempt + 1)) + + return LLMResponse(error=last_error) + + async def chat_completion_stream( + self, + messages: List[Dict[str, Any]], + tools: Optional[List[Dict[str, Any]]] = None, + **kwargs, + ) -> AsyncGenerator[str, None]: + """ + 流式聊天补全 + + Args: + messages: 
消息列表 + tools: 工具列表(可选) + **kwargs: 额外参数 + + Yields: + 文本片段 + """ + session = await self._get_session() + url = f"{self.config.api_base}/chat/completions" + headers = self._build_headers() + payload = self._build_payload(messages, tools, stream=True, **kwargs) + + try: + async with session.post(url, headers=headers, json=payload) as resp: + if resp.status != 200: + error_text = await resp.text() + logger.error(f"[LLMClient] 流式请求失败: HTTP {resp.status}") + return + + async for line in resp.content: + line = line.decode("utf-8").strip() + if not line or not line.startswith("data: "): + continue + + data_str = line[6:] + if data_str == "[DONE]": + break + + try: + data = json.loads(data_str) + delta = data.get("choices", [{}])[0].get("delta", {}) + content = delta.get("content", "") + if content: + yield content + except json.JSONDecodeError: + continue + + except Exception as e: + logger.error(f"[LLMClient] 流式请求异常: {e}") + + def _parse_response(self, data: Dict[str, Any]) -> LLMResponse: + """解析 API 响应""" + try: + choice = data.get("choices", [{}])[0] + message = choice.get("message", {}) + + content = message.get("content", "") or "" + tool_calls = message.get("tool_calls", []) + finish_reason = choice.get("finish_reason", "") + usage = data.get("usage", {}) + + # 标准化 tool_calls + parsed_tool_calls = [] + for tc in tool_calls: + parsed_tool_calls.append({ + "id": tc.get("id", ""), + "type": tc.get("type", "function"), + "function": { + "name": tc.get("function", {}).get("name", ""), + "arguments": tc.get("function", {}).get("arguments", "{}"), + } + }) + + return LLMResponse( + content=content, + tool_calls=parsed_tool_calls, + finish_reason=finish_reason, + usage=usage, + raw_response=data, + ) + + except Exception as e: + logger.error(f"[LLMClient] 解析响应失败: {e}") + return LLMResponse(error=f"解析响应失败: {e}") + + # ==================== Token 估算 ==================== + + @staticmethod + def estimate_tokens(text: str) -> int: + """ + 估算文本的 token 数量 + + 使用简化规则: + - 英文约 4 字符 = 1 token + - 中文约 1.5 字符 = 1 token + """ + if not text: + return 0 + + chinese_chars = sum(1 for c in text if '\u4e00' <= c <= '\u9fff') + other_chars = len(text) - chinese_chars + + chinese_tokens = chinese_chars / 1.5 + other_tokens = other_chars / 4 + + return int(chinese_tokens + other_tokens) + + @staticmethod + def estimate_message_tokens(message: Dict[str, Any]) -> int: + """估算单条消息的 token 数量""" + content = message.get("content", "") + + if isinstance(content, str): + return LLMClient.estimate_tokens(content) + 4 # role 等开销 + + if isinstance(content, list): + total = 4 + for item in content: + if isinstance(item, dict): + if item.get("type") == "text": + total += LLMClient.estimate_tokens(item.get("text", "")) + elif item.get("type") == "image_url": + total += 85 # 图片固定开销 + return total + + return 4 + + @staticmethod + def estimate_messages_tokens(messages: List[Dict[str, Any]]) -> int: + """估算消息列表的总 token 数量""" + return sum(LLMClient.estimate_message_tokens(m) for m in messages) + + +# ==================== 便捷函数 ==================== + +_default_client: Optional[LLMClient] = None + + +def get_llm_client(config: Optional[LLMConfig] = None) -> LLMClient: + """获取默认 LLM 客户端""" + global _default_client + if config: + _default_client = LLMClient(config) + if _default_client is None: + raise ValueError("LLM 客户端未初始化,请先传入配置") + return _default_client + + +# ==================== 导出 ==================== + +__all__ = [ + 'LLMConfig', + 'LLMResponse', + 'LLMClient', + 'get_llm_client', +] diff --git a/utils/llm_tooling.py 
b/utils/llm_tooling.py index a8d845c..763fd93 100644 --- a/utils/llm_tooling.py +++ b/utils/llm_tooling.py @@ -181,3 +181,91 @@ def validate_tool_arguments( return True, "", arguments + +# ==================== 工具注册中心集成 ==================== + +def register_plugin_tools( + plugin_name: str, + plugin: Any, + tools_config: Dict[str, Any], + timeout_config: Optional[Dict[str, Any]] = None, +) -> int: + """ + 将插件的 LLM 工具注册到全局工具注册中心 + + Args: + plugin_name: 插件名称 + plugin: 插件实例(需实现 get_llm_tools 和 execute_llm_tool) + tools_config: 工具配置(包含 mode, whitelist, blacklist) + timeout_config: 工具超时配置 {tool_name: timeout_seconds} + + Returns: + 注册的工具数量 + """ + from utils.tool_registry import get_tool_registry + + if not hasattr(plugin, "get_llm_tools") or not hasattr(plugin, "execute_llm_tool"): + return 0 + + registry = get_tool_registry() + timeout_config = timeout_config or {} + + mode = tools_config.get("mode", "all") + whitelist = set(tools_config.get("whitelist", [])) + blacklist = set(tools_config.get("blacklist", [])) + + plugin_tools = plugin.get_llm_tools() or [] + registered_count = 0 + + for tool in plugin_tools: + tool_name = tool.get("function", {}).get("name", "") + if not tool_name: + continue + + # 应用白名单/黑名单过滤 + if mode == "whitelist" and tool_name not in whitelist: + continue + if mode == "blacklist" and tool_name in blacklist: + logger.debug(f"[黑名单] 跳过注册工具: {tool_name}") + continue + + # 获取工具超时配置 + timeout = timeout_config.get(tool_name, timeout_config.get("default", 60)) + + # 创建执行器闭包 + async def make_executor(p, tn): + async def executor(tool_name: str, arguments: dict, bot, from_wxid: str): + return await p.execute_llm_tool(tool_name, arguments, bot, from_wxid) + return executor + + # 注册工具 + if registry.register( + name=tool_name, + plugin_name=plugin_name, + schema=tool, + executor=plugin.execute_llm_tool, + timeout=timeout, + ): + registered_count += 1 + if mode == "whitelist": + logger.debug(f"[白名单] 注册工具: {tool_name}") + + if registered_count > 0: + logger.info(f"插件 {plugin_name} 注册了 {registered_count} 个工具") + + return registered_count + + +def unregister_plugin_tools(plugin_name: str) -> int: + """ + 从全局工具注册中心注销插件的所有工具 + + Args: + plugin_name: 插件名称 + + Returns: + 注销的工具数量 + """ + from utils.tool_registry import get_tool_registry + return get_tool_registry().unregister_plugin(plugin_name) + diff --git a/utils/message_dedup.py b/utils/message_dedup.py new file mode 100644 index 0000000..d8f84b0 --- /dev/null +++ b/utils/message_dedup.py @@ -0,0 +1,145 @@ +""" +消息去重器模块 + +防止同一条消息被重复处理(某些环境下回调会重复触发) +""" + +import asyncio +import time +from typing import Any, Dict, Optional + +from loguru import logger + + +class MessageDeduplicator: + """ + 消息去重器 + + 使用基于时间的滑动窗口实现去重: + - 记录最近处理的消息 ID + - 在 TTL 时间内重复的消息会被过滤 + - 自动清理过期记录,限制内存占用 + """ + + def __init__( + self, + ttl_seconds: float = 30.0, + max_size: int = 5000, + ): + """ + 初始化去重器 + + Args: + ttl_seconds: 消息 ID 的有效期(秒),0 表示禁用去重 + max_size: 最大缓存条目数,防止内存泄漏 + """ + self.ttl_seconds = max(float(ttl_seconds), 0.0) + self.max_size = max(int(max_size), 0) + self._cache: Dict[str, float] = {} # key -> timestamp + self._lock = asyncio.Lock() + + @staticmethod + def extract_msg_id(data: Dict[str, Any]) -> str: + """ + 从原始消息数据中提取消息 ID + + Args: + data: 原始消息数据 + + Returns: + 消息 ID 字符串,提取失败返回空字符串 + """ + for key in ("msgid", "msg_id", "MsgId", "id"): + value = data.get(key) + if value: + return str(value) + return "" + + async def is_duplicate(self, data: Dict[str, Any]) -> bool: + """ + 检查消息是否重复 + + Args: + data: 原始消息数据 + + Returns: + True 
表示是重复消息,False 表示是新消息 + """ + if self.ttl_seconds <= 0: + return False + + msg_id = self.extract_msg_id(data) + if not msg_id: + # 没有消息 ID 时不做去重,避免误判 + return False + + key = f"msgid:{msg_id}" + now = time.time() + + async with self._lock: + # 检查是否存在且未过期 + last_seen = self._cache.get(key) + if last_seen is not None and (now - last_seen) < self.ttl_seconds: + return True + + # 记录新消息 + self._cache.pop(key, None) # 确保插入到末尾(保持顺序) + self._cache[key] = now + + # 清理过期条目 + self._cleanup_expired(now) + + # 限制大小 + self._limit_size() + + return False + + def _cleanup_expired(self, now: float): + """清理过期条目(需在锁内调用)""" + cutoff = now - self.ttl_seconds + while self._cache: + first_key = next(iter(self._cache)) + if self._cache[first_key] >= cutoff: + break + self._cache.pop(first_key, None) + + def _limit_size(self): + """限制缓存大小(需在锁内调用)""" + if self.max_size <= 0: + return + while len(self._cache) > self.max_size: + first_key = next(iter(self._cache)) + self._cache.pop(first_key, None) + + def clear(self): + """清空缓存""" + self._cache.clear() + + def get_stats(self) -> Dict[str, Any]: + """获取统计信息""" + return { + "cached_count": len(self._cache), + "ttl_seconds": self.ttl_seconds, + "max_size": self.max_size, + } + + @classmethod + def from_config(cls, perf_config: Dict[str, Any]) -> "MessageDeduplicator": + """ + 从配置创建去重器 + + Args: + perf_config: Performance 配置节 + + Returns: + MessageDeduplicator 实例 + """ + return cls( + ttl_seconds=perf_config.get("dedup_ttl_seconds", 30), + max_size=perf_config.get("dedup_max_size", 5000), + ) + + +# ==================== 导出 ==================== + +__all__ = ['MessageDeduplicator'] diff --git a/utils/message_filter.py b/utils/message_filter.py new file mode 100644 index 0000000..acf06fc --- /dev/null +++ b/utils/message_filter.py @@ -0,0 +1,128 @@ +""" +消息过滤器模块 + +提供消息过滤功能: +- 白名单/黑名单过滤 +- 机器人自身消息过滤 +- 系统消息放行 +""" + +from typing import Any, Dict, List, Optional, Set + +from loguru import logger + + +class MessageFilter: + """ + 消息过滤器 + + 支持三种模式: + - None: 不过滤,处理所有消息 + - Whitelist: 只处理白名单中的消息 + - Blacklist: 过滤黑名单中的消息 + """ + + # 系统消息类型(始终放行) + SYSTEM_MESSAGE_TYPES = {11058} + + def __init__( + self, + mode: str = "None", + whitelist: List[str] = None, + blacklist: List[str] = None, + bot_wxid: str = None, + ): + """ + 初始化过滤器 + + Args: + mode: 过滤模式 ("None", "Whitelist", "Blacklist") + whitelist: 白名单 wxid 列表 + blacklist: 黑名单 wxid 列表 + bot_wxid: 机器人自身 wxid(用于过滤自己的消息) + """ + self.mode = mode + self.whitelist: Set[str] = set(whitelist or []) + self.blacklist: Set[str] = set(blacklist or []) + self.bot_wxid = bot_wxid + + def set_bot_wxid(self, wxid: str): + """设置机器人 wxid""" + self.bot_wxid = wxid + + def add_to_whitelist(self, wxid: str): + """添加到白名单""" + self.whitelist.add(wxid) + + def remove_from_whitelist(self, wxid: str): + """从白名单移除""" + self.whitelist.discard(wxid) + + def add_to_blacklist(self, wxid: str): + """添加到黑名单""" + self.blacklist.add(wxid) + + def remove_from_blacklist(self, wxid: str): + """从黑名单移除""" + self.blacklist.discard(wxid) + + def should_process(self, message: Dict[str, Any]) -> bool: + """ + 判断消息是否应该被处理 + + Args: + message: 标准化后的消息字典 + + Returns: + True 表示应该处理,False 表示应该过滤 + """ + from_wxid = message.get("FromWxid", "") + sender_wxid = message.get("SenderWxid", "") + msg_type = message.get("MsgType", 0) + + # 系统消息始终放行 + if msg_type in self.SYSTEM_MESSAGE_TYPES: + return True + + # 过滤机器人自己的消息 + if self.bot_wxid and (from_wxid == self.bot_wxid or sender_wxid == self.bot_wxid): + return False + + # 根据模式过滤 + return self._check_mode(from_wxid, sender_wxid) 
+ + def _check_mode(self, from_wxid: str, sender_wxid: str) -> bool: + """根据模式检查是否放行""" + if self.mode == "None": + return True + + if self.mode == "Whitelist": + return from_wxid in self.whitelist or sender_wxid in self.whitelist + + if self.mode == "Blacklist": + return from_wxid not in self.blacklist and sender_wxid not in self.blacklist + + return True + + @classmethod + def from_config(cls, bot_config: Dict[str, Any]) -> "MessageFilter": + """ + 从配置创建过滤器 + + Args: + bot_config: Bot 配置节 + + Returns: + MessageFilter 实例 + """ + return cls( + mode=bot_config.get("ignore-mode", "None"), + whitelist=bot_config.get("whitelist", []), + blacklist=bot_config.get("blacklist", []), + bot_wxid=bot_config.get("wxid") or bot_config.get("bot_wxid"), + ) + + +# ==================== 导出 ==================== + +__all__ = ['MessageFilter'] diff --git a/utils/message_hook.py b/utils/message_hook.py index 896324d..5aa8e94 100644 --- a/utils/message_hook.py +++ b/utils/message_hook.py @@ -56,10 +56,18 @@ async def log_bot_message(to_wxid: str, content: str, msg_type: str = "text", me except Exception: pass + sync_content = content + if msg_type == "image": + sync_content = "[图片]" + elif msg_type == "video": + sync_content = "[视频]" + elif msg_type == "file": + sync_content = "[文件]" + await store.add_group_message( to_wxid, bot_nickname, - content, + sync_content, role="assistant", sender_wxid=bot_wxid or None, ) diff --git a/utils/message_queue.py b/utils/message_queue.py new file mode 100644 index 0000000..2082ffb --- /dev/null +++ b/utils/message_queue.py @@ -0,0 +1,305 @@ +""" +消息队列模块 + +提供高性能的优先级消息队列,支持多种溢出策略: +- drop_oldest: 丢弃最旧的消息 +- drop_lowest: 丢弃优先级最低的消息 +- sampling: 按采样率丢弃消息 +- reject: 拒绝新消息 +""" + +import asyncio +import heapq +import random +import time +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Dict, List, Optional, Tuple + +from loguru import logger + + +# ==================== 消息优先级常量 ==================== + +class MessagePriority: + """消息优先级常量""" + CRITICAL = 100 # 系统消息、登录信息 + HIGH = 80 # 管理员命令、群成员变动 + NORMAL = 50 # @bot 消息(默认) + LOW = 20 # 普通群消息 + + +# ==================== 溢出策略 ==================== + +class OverflowStrategy(Enum): + """队列溢出策略""" + DROP_OLDEST = "drop_oldest" # 丢弃最旧的消息 + DROP_LOWEST = "drop_lowest" # 丢弃优先级最低的消息 + SAMPLING = "sampling" # 按采样率丢弃 + REJECT = "reject" # 拒绝新消息 + + +# ==================== 优先级消息 ==================== + +@dataclass(order=True) +class PriorityMessage: + """优先级消息""" + priority: int = field(compare=True) + timestamp: float = field(compare=True) + msg_type: int = field(compare=False) + data: Dict[str, Any] = field(compare=False) + + def __init__(self, msg_type: int, data: Dict[str, Any], priority: int = None): + # 优先级越高,数值越大,但 heapq 是最小堆,所以取负数 + self.priority = -(priority if priority is not None else MessagePriority.NORMAL) + self.timestamp = time.time() + self.msg_type = msg_type + self.data = data + + +# ==================== 优先级消息队列 ==================== + +class PriorityMessageQueue: + """ + 优先级消息队列 + + 特性: + - 基于堆的优先级队列 + - 支持多种溢出策略 + - 线程安全(使用 asyncio.Lock) + - 支持任务计数和 join + """ + + def __init__( + self, + maxsize: int = 1000, + overflow_strategy: str = "drop_oldest", + sampling_rate: float = 0.5, + ): + """ + 初始化队列 + + Args: + maxsize: 最大队列大小 + overflow_strategy: 溢出策略 (drop_oldest, drop_lowest, sampling, reject) + sampling_rate: 采样策略的保留率 (0.0-1.0) + """ + self.maxsize = maxsize + self.overflow_strategy = OverflowStrategy(overflow_strategy) + self.sampling_rate = max(0.0, min(1.0, sampling_rate)) + + 
self._heap: List[PriorityMessage] = [] + self._lock = asyncio.Lock() + self._not_empty = asyncio.Event() + self._unfinished_tasks = 0 + self._finished = asyncio.Event() + self._finished.set() + + # 统计 + self._total_put = 0 + self._total_dropped = 0 + self._total_rejected = 0 + + def qsize(self) -> int: + """返回队列大小""" + return len(self._heap) + + def empty(self) -> bool: + """队列是否为空""" + return len(self._heap) == 0 + + def full(self) -> bool: + """队列是否已满""" + return len(self._heap) >= self.maxsize + + async def put( + self, + msg_type: int, + data: Dict[str, Any], + priority: int = None, + ) -> bool: + """ + 添加消息到队列 + + Args: + msg_type: 消息类型 + data: 消息数据 + priority: 优先级(可选) + + Returns: + 是否成功添加 + """ + async with self._lock: + self._total_put += 1 + + # 处理队列满的情况 + if self.full(): + if not self._handle_overflow(): + self._total_rejected += 1 + return False + + msg = PriorityMessage(msg_type, data, priority) + heapq.heappush(self._heap, msg) + self._unfinished_tasks += 1 + self._finished.clear() + self._not_empty.set() + return True + + def _handle_overflow(self) -> bool: + """ + 处理队列溢出 + + Returns: + True 表示成功腾出空间,False 表示拒绝 + """ + if self.overflow_strategy == OverflowStrategy.REJECT: + logger.warning("队列已满,拒绝新消息") + return False + + if self.overflow_strategy == OverflowStrategy.DROP_OLDEST: + # 找到最旧的消息(timestamp 最小) + if self._heap: + oldest_idx = 0 + for i, msg in enumerate(self._heap): + if msg.timestamp < self._heap[oldest_idx].timestamp: + oldest_idx = i + self._heap.pop(oldest_idx) + heapq.heapify(self._heap) + self._total_dropped += 1 + self._unfinished_tasks = max(0, self._unfinished_tasks - 1) + return True + + elif self.overflow_strategy == OverflowStrategy.DROP_LOWEST: + # 找到优先级最低的消息(priority 值最大,因为是负数) + if self._heap: + lowest_idx = 0 + for i, msg in enumerate(self._heap): + if msg.priority > self._heap[lowest_idx].priority: + lowest_idx = i + self._heap.pop(lowest_idx) + heapq.heapify(self._heap) + self._total_dropped += 1 + self._unfinished_tasks = max(0, self._unfinished_tasks - 1) + return True + + elif self.overflow_strategy == OverflowStrategy.SAMPLING: + # 按采样率决定是否接受 + if random.random() < self.sampling_rate: + # 接受新消息,丢弃最旧的 + if self._heap: + oldest_idx = 0 + for i, msg in enumerate(self._heap): + if msg.timestamp < self._heap[oldest_idx].timestamp: + oldest_idx = i + self._heap.pop(oldest_idx) + heapq.heapify(self._heap) + self._total_dropped += 1 + self._unfinished_tasks = max(0, self._unfinished_tasks - 1) + return True + else: + self._total_dropped += 1 + return False + + return False + + async def get(self, timeout: float = None) -> Tuple[int, Dict[str, Any]]: + """ + 获取优先级最高的消息 + + Args: + timeout: 超时时间(秒),None 表示无限等待 + + Returns: + (msg_type, data) 元组 + + Raises: + asyncio.TimeoutError: 超时 + """ + start_time = time.time() + + while True: + async with self._lock: + if self._heap: + msg = heapq.heappop(self._heap) + if not self._heap: + self._not_empty.clear() + return (msg.msg_type, msg.data) + + # 计算剩余超时时间 + if timeout is not None: + elapsed = time.time() - start_time + remaining = timeout - elapsed + if remaining <= 0: + raise asyncio.TimeoutError("Queue get timeout") + try: + await asyncio.wait_for(self._not_empty.wait(), timeout=remaining) + except asyncio.TimeoutError: + raise asyncio.TimeoutError("Queue get timeout") + else: + await self._not_empty.wait() + + def get_nowait(self) -> Tuple[int, Dict[str, Any]]: + """非阻塞获取消息""" + if not self._heap: + raise asyncio.QueueEmpty() + msg = heapq.heappop(self._heap) + if not self._heap: + self._not_empty.clear() + 
return (msg.msg_type, msg.data) + + def task_done(self): + """标记任务完成""" + self._unfinished_tasks = max(0, self._unfinished_tasks - 1) + if self._unfinished_tasks == 0: + self._finished.set() + + async def join(self): + """等待所有任务完成""" + await self._finished.wait() + + def clear(self): + """清空队列""" + self._heap.clear() + self._not_empty.clear() + self._unfinished_tasks = 0 + self._finished.set() + + def get_stats(self) -> Dict[str, Any]: + """获取队列统计信息""" + return { + "current_size": len(self._heap), + "max_size": self.maxsize, + "total_put": self._total_put, + "total_dropped": self._total_dropped, + "total_rejected": self._total_rejected, + "unfinished_tasks": self._unfinished_tasks, + "overflow_strategy": self.overflow_strategy.value, + "utilization": len(self._heap) / max(self.maxsize, 1), + } + + @classmethod + def from_config(cls, queue_config: Dict[str, Any]) -> "PriorityMessageQueue": + """ + 从配置创建队列 + + Args: + queue_config: Queue 配置节 + + Returns: + PriorityMessageQueue 实例 + """ + return cls( + maxsize=queue_config.get("max_size", 1000), + overflow_strategy=queue_config.get("overflow_strategy", "drop_oldest"), + sampling_rate=queue_config.get("sampling_rate", 0.5), + ) + + +# ==================== 导出 ==================== + +__all__ = [ + 'MessagePriority', + 'OverflowStrategy', + 'PriorityMessage', + 'PriorityMessageQueue', +] diff --git a/utils/message_stats.py b/utils/message_stats.py new file mode 100644 index 0000000..5c945ce --- /dev/null +++ b/utils/message_stats.py @@ -0,0 +1,114 @@ +""" +消息统计器模块 + +提供消息处理的统计功能: +- 消息计数 +- 过滤率统计 +- 按类型统计 +""" + +import time +from collections import defaultdict +from threading import Lock +from typing import Any, Dict + + +class MessageStats: + """ + 消息统计器 + + 线程安全的消息统计实现 + """ + + def __init__(self): + self._lock = Lock() + self._total_count = 0 + self._filtered_count = 0 + self._processed_count = 0 + self._duplicate_count = 0 + self._error_count = 0 + self._by_type: Dict[str, int] = defaultdict(int) + self._start_time = time.time() + + def record_received(self): + """记录收到消息""" + with self._lock: + self._total_count += 1 + + def record_filtered(self): + """记录被过滤的消息""" + with self._lock: + self._filtered_count += 1 + + def record_processed(self, event_type: str = None): + """ + 记录已处理的消息 + + Args: + event_type: 消息事件类型(可选) + """ + with self._lock: + self._processed_count += 1 + if event_type: + self._by_type[event_type] += 1 + + def record_duplicate(self): + """记录重复消息""" + with self._lock: + self._duplicate_count += 1 + + def record_error(self): + """记录处理错误""" + with self._lock: + self._error_count += 1 + + def get_stats(self) -> Dict[str, Any]: + """获取统计信息""" + with self._lock: + uptime = time.time() - self._start_time + total = max(self._total_count, 1) # 避免除零 + + return { + "total_messages": self._total_count, + "filtered_messages": self._filtered_count, + "processed_messages": self._processed_count, + "duplicate_messages": self._duplicate_count, + "error_count": self._error_count, + "filter_rate": self._filtered_count / total, + "process_rate": self._processed_count / total, + "duplicate_rate": self._duplicate_count / total, + "messages_per_minute": (self._total_count / uptime) * 60 if uptime > 0 else 0, + "uptime_seconds": uptime, + "by_type": dict(self._by_type), + } + + def reset(self): + """重置统计""" + with self._lock: + self._total_count = 0 + self._filtered_count = 0 + self._processed_count = 0 + self._duplicate_count = 0 + self._error_count = 0 + self._by_type.clear() + self._start_time = time.time() + + +# 全局单例(可选使用) +_global_stats: 
MessageStats = None +_stats_lock = Lock() + + +def get_message_stats() -> MessageStats: + """获取全局消息统计器实例""" + global _global_stats + if _global_stats is None: + with _stats_lock: + if _global_stats is None: + _global_stats = MessageStats() + return _global_stats + + +# ==================== 导出 ==================== + +__all__ = ['MessageStats', 'get_message_stats'] diff --git a/utils/plugin_base.py b/utils/plugin_base.py index 4295913..dfa1e6f 100644 --- a/utils/plugin_base.py +++ b/utils/plugin_base.py @@ -1,35 +1,144 @@ +""" +插件基类模块 + +提供插件的基础功能: +- 生命周期钩子(on_load, on_enable, on_disable, on_unload, on_reload) +- 定时任务管理 +- 依赖声明 +- 插件元数据 +""" + from abc import ABC -from typing import List +from enum import Enum +from typing import Any, Dict, List, Optional, TYPE_CHECKING from loguru import logger from .decorators import scheduler, add_job_safe, remove_job_safe +if TYPE_CHECKING: + from utils.plugin_manager import PluginManager + + +class PluginState(Enum): + """插件状态""" + UNLOADED = "unloaded" # 未加载 + LOADED = "loaded" # 已加载(未启用) + ENABLED = "enabled" # 已启用 + DISABLED = "disabled" # 已禁用 + ERROR = "error" # 错误状态 + class PluginBase(ABC): - """插件基类""" + """ + 插件基类 + + 生命周期: + 1. __init__() - 构造函数(同步) + 2. on_load() - 加载时调用(异步,可访问其他插件) + 3. async_init() - 异步初始化(异步,加载配置、资源等) + 4. on_enable() - 启用时调用(异步,注册定时任务) + 5. on_disable() - 禁用时调用(异步,清理定时任务) + 6. on_unload() - 卸载时调用(异步,释放资源) + 7. on_reload() - 重载前调用(异步,保存状态) + + 使用示例: + class MyPlugin(PluginBase): + description = "我的插件" + author = "作者" + version = "1.0.0" + dependencies = ["AIChat"] # 依赖的插件 + load_priority = 60 # 加载优先级 + + async def on_load(self, plugin_manager): + # 可以访问其他插件 + aichat = plugin_manager.plugins.get("AIChat") + + async def async_init(self): + # 加载配置、初始化资源 + self.config = load_config() + + async def on_enable(self, bot): + await super().on_enable(bot) # 注册定时任务 + # 额外的启用逻辑 + + async def on_disable(self): + await super().on_disable() # 清理定时任务 + # 额外的禁用逻辑 + + async def on_unload(self): + # 释放资源、关闭连接 + await self.close_connections() + + async def on_reload(self) -> dict: + # 返回需要保存的状态 + return {"counter": self.counter} + + async def restore_state(self, state: dict): + # 重载后恢复状态 + self.counter = state.get("counter", 0) + """ + + # ==================== 插件元数据 ==================== - # 插件元数据 description: str = "暂无描述" author: str = "未知" version: str = "1.0.0" # 插件依赖(填写依赖的插件类名列表) - # 例如: dependencies = ["MessageLogger", "AIChat"] dependencies: List[str] = [] # 加载优先级(数值越大越先加载,默认50) - # 基础插件设置高优先级,依赖其他插件的设置低优先级 load_priority: int = 50 + # ==================== 实例属性 ==================== + def __init__(self): self.enabled = False - self._scheduled_jobs = set() + self.state = PluginState.UNLOADED + self._scheduled_jobs: set = set() + self._bot = None + self._plugin_manager: Optional["PluginManager"] = None + self._saved_state: Dict[str, Any] = {} + + # ==================== 生命周期钩子 ==================== + + async def on_load(self, plugin_manager: "PluginManager"): + """ + 插件加载时调用 + + 此时其他依赖的插件已经加载完成,可以安全访问。 + + Args: + plugin_manager: 插件管理器实例 + """ + self._plugin_manager = plugin_manager + self.state = PluginState.LOADED + logger.debug(f"[{self.__class__.__name__}] on_load 调用") + + async def async_init(self): + """ + 插件异步初始化 + + 用于加载配置、初始化资源等耗时操作。 + 在 on_load 之后、on_enable 之前调用。 + """ + pass async def on_enable(self, bot=None): - """插件启用时调用""" + """ + 插件启用时调用 - # 定时任务 + 注册定时任务、启动后台服务等。 + + Args: + bot: WechatHookClient 实例 + """ + self._bot = bot + self.enabled = True + self.state = PluginState.ENABLED + + # 注册定时任务 for method_name in dir(self): method 
= getattr(self, method_name) if hasattr(method, '_is_scheduled'): @@ -39,18 +148,85 @@ class PluginBase(ABC): add_job_safe(scheduler, job_id, method, bot, trigger, **trigger_args) self._scheduled_jobs.add(job_id) + if self._scheduled_jobs: - logger.success("插件 {} 已加载定时任务: {}", self.__class__.__name__, self._scheduled_jobs) + logger.success(f"插件 {self.__class__.__name__} 已加载定时任务: {self._scheduled_jobs}") async def on_disable(self): - """插件禁用时调用""" - + """ + 插件禁用时调用 + + 清理定时任务、停止后台服务等。 + """ + self.enabled = False + self.state = PluginState.DISABLED + # 移除定时任务 for job_id in self._scheduled_jobs: remove_job_safe(scheduler, job_id) - logger.info("已卸载定时任务: {}", self._scheduled_jobs) + + if self._scheduled_jobs: + logger.info(f"已卸载定时任务: {self._scheduled_jobs}") self._scheduled_jobs.clear() - async def async_init(self): - """插件异步初始化""" - return + async def on_unload(self): + """ + 插件卸载时调用 + + 释放资源、关闭连接、保存数据等。 + 在 on_disable 之后调用。 + """ + self.state = PluginState.UNLOADED + self._bot = None + self._plugin_manager = None + logger.debug(f"[{self.__class__.__name__}] on_unload 调用") + + async def on_reload(self) -> Dict[str, Any]: + """ + 插件重载前调用 + + 返回需要在重载后恢复的状态数据。 + + Returns: + 需要保存的状态字典 + """ + logger.debug(f"[{self.__class__.__name__}] on_reload 调用") + return {} + + async def restore_state(self, state: Dict[str, Any]): + """ + 重载后恢复状态 + + Args: + state: on_reload 返回的状态字典 + """ + self._saved_state = state + logger.debug(f"[{self.__class__.__name__}] 状态已恢复: {list(state.keys())}") + + # ==================== 辅助方法 ==================== + + def get_plugin(self, plugin_name: str) -> Optional["PluginBase"]: + """ + 获取其他插件实例 + + Args: + plugin_name: 插件类名 + + Returns: + 插件实例,不存在返回 None + """ + if self._plugin_manager: + return self._plugin_manager.plugins.get(plugin_name) + return None + + def get_bot(self): + """获取 bot 实例""" + return self._bot + + @property + def plugin_name(self) -> str: + """获取插件名称""" + return self.__class__.__name__ + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} v{self.version} state={self.state.value}>" diff --git a/utils/plugin_inject.py b/utils/plugin_inject.py new file mode 100644 index 0000000..3a0a419 --- /dev/null +++ b/utils/plugin_inject.py @@ -0,0 +1,213 @@ +""" +插件依赖注入模块 + +提供插件间依赖注入功能: +- @inject 装饰器自动注入依赖 +- 延迟注入(lazy injection)避免循环依赖 +- 类型安全的依赖获取 + +使用示例: + from utils.plugin_inject import inject, require_plugin + + class MyPlugin(PluginBase): + # 方式1: 使用装饰器注入 + @inject("AIChat") + def get_aichat(self) -> "AIChat": + pass # 自动注入,无需实现 + + # 方式2: 使用 require_plugin + async def some_method(self): + aichat = require_plugin("AIChat") + await aichat.do_something() + + # 方式3: 使用基类的 get_plugin + async def another_method(self): + aichat = self.get_plugin("AIChat") + if aichat: + await aichat.do_something() +""" + +from functools import wraps +from typing import Any, Callable, Optional, Type, TypeVar, TYPE_CHECKING + +from loguru import logger + +if TYPE_CHECKING: + from utils.plugin_base import PluginBase + +T = TypeVar('T') + + +class PluginNotAvailableError(Exception): + """插件不可用错误""" + pass + + +def _get_plugin_manager(): + """延迟获取 PluginManager 避免循环导入""" + from utils.plugin_manager import PluginManager + return PluginManager() + + +def require_plugin(plugin_name: str) -> "PluginBase": + """ + 获取必需的插件(不存在则抛出异常) + + Args: + plugin_name: 插件类名 + + Returns: + 插件实例 + + Raises: + PluginNotAvailableError: 插件不存在或未启用 + """ + pm = _get_plugin_manager() + plugin = pm.plugins.get(plugin_name) + if plugin is None: + raise PluginNotAvailableError(f"插件 {plugin_name} 不可用") 
+ return plugin + + +def get_plugin(plugin_name: str) -> Optional["PluginBase"]: + """ + 获取插件(不存在返回 None) + + Args: + plugin_name: 插件类名 + + Returns: + 插件实例或 None + """ + pm = _get_plugin_manager() + return pm.plugins.get(plugin_name) + + +def inject(plugin_name: str) -> Callable: + """ + 插件注入装饰器 + + 将方法转换为属性 getter,自动返回指定插件实例。 + + Args: + plugin_name: 要注入的插件类名 + + Usage: + class MyPlugin(PluginBase): + @inject("AIChat") + def aichat(self) -> "AIChat": + pass # 无需实现 + + async def handle(self, bot, message): + # 直接使用 + await self.aichat.process(message) + """ + def decorator(method: Callable) -> property: + @wraps(method) + def getter(self) -> Optional["PluginBase"]: + # 优先使用插件自身的 _plugin_manager + if hasattr(self, '_plugin_manager') and self._plugin_manager: + return self._plugin_manager.plugins.get(plugin_name) + # 回退到全局 PluginManager + return get_plugin(plugin_name) + + return property(getter) + + return decorator + + +def inject_required(plugin_name: str) -> Callable: + """ + 必需插件注入装饰器 + + 与 inject 类似,但如果插件不存在会抛出异常。 + + Args: + plugin_name: 要注入的插件类名 + + Raises: + PluginNotAvailableError: 插件不存在 + """ + def decorator(method: Callable) -> property: + @wraps(method) + def getter(self) -> "PluginBase": + plugin = None + if hasattr(self, '_plugin_manager') and self._plugin_manager: + plugin = self._plugin_manager.plugins.get(plugin_name) + else: + plugin = get_plugin(plugin_name) + + if plugin is None: + raise PluginNotAvailableError( + f"插件 {self.__class__.__name__} 依赖的 {plugin_name} 不可用" + ) + return plugin + + return property(getter) + + return decorator + + +class PluginProxy: + """ + 插件代理 + + 延迟获取插件,避免初始化时的循环依赖问题。 + + Usage: + class MyPlugin(PluginBase): + def __init__(self): + super().__init__() + self._aichat = PluginProxy("AIChat") + + async def handle(self): + # 首次访问时才获取插件 + if self._aichat.available: + await self._aichat.instance.process() + """ + + def __init__(self, plugin_name: str): + self._plugin_name = plugin_name + self._cached_instance: Optional["PluginBase"] = None + self._checked = False + + @property + def instance(self) -> Optional["PluginBase"]: + """获取插件实例(带缓存)""" + if not self._checked: + self._cached_instance = get_plugin(self._plugin_name) + self._checked = True + return self._cached_instance + + @property + def available(self) -> bool: + """检查插件是否可用""" + return self.instance is not None + + def require(self) -> "PluginBase": + """获取插件,不存在则抛出异常""" + inst = self.instance + if inst is None: + raise PluginNotAvailableError(f"插件 {self._plugin_name} 不可用") + return inst + + def invalidate(self): + """清除缓存,下次访问重新获取""" + self._cached_instance = None + self._checked = False + + def __repr__(self) -> str: + status = "available" if self.available else "unavailable" + return f"" + + +# ==================== 导出 ==================== + +__all__ = [ + 'PluginNotAvailableError', + 'require_plugin', + 'get_plugin', + 'inject', + 'inject_required', + 'PluginProxy', +] diff --git a/utils/plugin_manager.py b/utils/plugin_manager.py index 1838e6d..85a769a 100644 --- a/utils/plugin_manager.py +++ b/utils/plugin_manager.py @@ -2,7 +2,6 @@ import importlib import inspect import os import sys -import tomllib import traceback from typing import Dict, Type, List, Union @@ -10,6 +9,8 @@ from loguru import logger # from WechatAPI import WechatAPIClient # 注释掉,WechatHookBot 不需要这个导入 from utils.singleton import Singleton +from utils.config_manager import get_bot_config +from utils.llm_tooling import register_plugin_tools, unregister_plugin_tools from .event_manager import EventManager from 
.plugin_base import PluginBase @@ -22,10 +23,9 @@ class PluginManager(metaclass=Singleton): self.bot = None - with open("main_config.toml", "rb") as f: - main_config = tomllib.load(f) - - self.excluded_plugins = main_config.get("Bot", {}).get("disabled-plugins", []) + # 使用统一配置管理器 + bot_config = get_bot_config() + self.excluded_plugins = bot_config.get("disabled-plugins", []) def set_bot(self, bot): """设置 bot 客户端(WechatHookClient)""" @@ -74,13 +74,34 @@ class PluginManager(metaclass=Singleton): if is_disabled: return False + # 创建插件实例 plugin = plugin_class() - EventManager.bind_instance(plugin) - await plugin.on_enable(self.bot) + + # 生命周期: on_load(可访问其他插件) + await plugin.on_load(self) + + # 生命周期: async_init(加载配置、资源) await plugin.async_init() + + # 绑定事件处理器 + EventManager.bind_instance(plugin) + + # 生命周期: on_enable(注册定时任务) + await plugin.on_enable(self.bot) + + # 注册到插件管理器 self.plugins[plugin_name] = plugin self.plugin_classes[plugin_name] = plugin_class self.plugin_info[plugin_name]["enabled"] = True + + # 注册插件的 LLM 工具到全局注册中心 + try: + tools_config = self._get_tools_config() + timeout_config = self._get_timeout_config() + register_plugin_tools(plugin_name, plugin, tools_config, timeout_config) + except Exception as e: + logger.warning(f"注册插件 {plugin_name} 的工具时出错: {e}") + logger.success(f"加载插件 {plugin_name} 成功") return True except: @@ -232,8 +253,22 @@ class PluginManager(metaclass=Singleton): try: plugin = self.plugins[plugin_name] + + # 生命周期: on_disable(清理定时任务) await plugin.on_disable() + + # 解绑事件处理器 EventManager.unbind_instance(plugin) + + # 从全局注册中心注销插件的工具 + try: + unregister_plugin_tools(plugin_name) + except Exception as e: + logger.warning(f"注销插件 {plugin_name} 的工具时出错: {e}") + + # 生命周期: on_unload(释放资源) + await plugin.on_unload() + del self.plugins[plugin_name] del self.plugin_classes[plugin_name] if plugin_name in self.plugin_info.keys(): @@ -256,7 +291,7 @@ class PluginManager(metaclass=Singleton): return unloaded_plugins, failed_unloads async def reload_plugin(self, plugin_name: str) -> bool: - """重载单个插件""" + """重载单个插件(支持状态保存和恢复)""" if plugin_name not in self.plugin_classes: return False @@ -270,7 +305,15 @@ class PluginManager(metaclass=Singleton): plugin_class = self.plugin_classes[plugin_name] module_name = plugin_class.__module__ - # 先卸载插件 + # 生命周期: on_reload(保存状态) + saved_state = {} + if plugin_name in self.plugins: + try: + saved_state = await self.plugins[plugin_name].on_reload() + except Exception as e: + logger.warning(f"保存插件 {plugin_name} 状态失败: {e}") + + # 卸载插件 if not await self.unload_plugin(plugin_name): return False @@ -284,8 +327,15 @@ class PluginManager(metaclass=Singleton): issubclass(obj, PluginBase) and obj != PluginBase and obj.__name__ == plugin_name): - # 使用新的插件类而不是旧的 - return await self.load_plugin(obj) + # 加载新插件 + if await self.load_plugin(obj): + # 恢复状态 + if saved_state and plugin_name in self.plugins: + try: + await self.plugins[plugin_name].restore_state(saved_state) + except Exception as e: + logger.warning(f"恢复插件 {plugin_name} 状态失败: {e}") + return True return False except Exception as e: @@ -349,13 +399,42 @@ class PluginManager(metaclass=Singleton): def get_plugin_info(self, plugin_name: str = None) -> Union[dict, List[dict]]: """获取插件信息 - + Args: plugin_name: 插件名称,如果为None则返回所有插件信息 - + Returns: 如果指定插件名,返回单个插件信息字典;否则返回所有插件信息列表 """ if plugin_name: return self.plugin_info.get(plugin_name) return list(self.plugin_info.values()) + + def _get_tools_config(self) -> dict: + """获取工具配置(用于工具注册)""" + try: + # 尝试从 AIChat 插件配置读取 + import tomllib + from pathlib import Path 
+ aichat_config_path = Path("plugins/AIChat/config.toml") + if aichat_config_path.exists(): + with open(aichat_config_path, "rb") as f: + aichat_config = tomllib.load(f) + return aichat_config.get("tools", {}) + except Exception: + pass + return {"mode": "all", "whitelist": [], "blacklist": []} + + def _get_timeout_config(self) -> dict: + """获取工具超时配置""" + try: + import tomllib + from pathlib import Path + aichat_config_path = Path("plugins/AIChat/config.toml") + if aichat_config_path.exists(): + with open(aichat_config_path, "rb") as f: + aichat_config = tomllib.load(f) + return aichat_config.get("tools", {}).get("timeout", {}) + except Exception: + pass + return {"default": 60} diff --git a/utils/tool_executor.py b/utils/tool_executor.py new file mode 100644 index 0000000..403c29a --- /dev/null +++ b/utils/tool_executor.py @@ -0,0 +1,488 @@ +""" +工具执行器模块 + +提供工具调用的高级执行逻辑: +- 批量工具执行(支持并行) +- 工具调用链处理 +- 执行日志和审计 +- 结果聚合 + +使用示例: + from utils.tool_executor import ToolExecutor, ToolCallRequest + + executor = ToolExecutor() + + # 单个工具执行 + result = await executor.execute_single( + tool_call={"id": "call_1", "function": {"name": "get_weather", "arguments": "{}"}}, + bot=bot, + from_wxid=wxid, + ) + + # 批量工具执行 + results = await executor.execute_batch( + tool_calls=[...], + bot=bot, + from_wxid=wxid, + parallel=True, + ) +""" + +from __future__ import annotations + +import asyncio +import json +import time +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, List, Optional + +from loguru import logger + + +@dataclass +class ToolCallRequest: + """工具调用请求""" + id: str + name: str + arguments: Dict[str, Any] + raw_arguments: str = "" # 原始 JSON 字符串 + + +@dataclass +class ToolCallResult: + """工具调用结果""" + id: str + name: str + success: bool = True + message: str = "" + raw_result: Dict[str, Any] = field(default_factory=dict) + + # 控制标志 + need_ai_reply: bool = False + already_sent: bool = False + send_result_text: bool = False + no_reply: bool = False + save_to_memory: bool = False + + # 执行信息 + execution_time_ms: float = 0.0 + error: Optional[str] = None + + def to_message(self) -> Dict[str, Any]: + """转换为 OpenAI 兼容的 tool message""" + content = self.message if self.success else f"错误: {self.error or self.message}" + return { + "role": "tool", + "tool_call_id": self.id, + "content": content + } + + +@dataclass +class ExecutionStats: + """执行统计""" + total_calls: int = 0 + successful_calls: int = 0 + failed_calls: int = 0 + timeout_calls: int = 0 + total_time_ms: float = 0.0 + avg_time_ms: float = 0.0 + + +class ToolExecutor: + """ + 工具执行器 + + 提供统一的工具执行接口: + - 参数解析和校验 + - 超时保护 + - 错误处理 + - 执行统计 + """ + + def __init__( + self, + default_timeout: float = 60.0, + max_parallel: int = 5, + validate_args: bool = True, + ): + self.default_timeout = default_timeout + self.max_parallel = max_parallel + self.validate_args = validate_args + self._stats = ExecutionStats() + + def parse_tool_call(self, tool_call: Dict[str, Any]) -> ToolCallRequest: + """ + 解析 OpenAI 格式的工具调用 + + Args: + tool_call: OpenAI 返回的 tool_call 对象 + + Returns: + ToolCallRequest 对象 + """ + call_id = tool_call.get("id", "") + function = tool_call.get("function", {}) + name = function.get("name", "") + raw_args = function.get("arguments", "{}") + + # 解析 arguments JSON + try: + arguments = json.loads(raw_args) if raw_args else {} + except json.JSONDecodeError as e: + logger.warning(f"[ToolExecutor] 解析参数失败: {e}, raw={raw_args[:100]}") + arguments = {} + + return ToolCallRequest( + id=call_id, + name=name, + 
arguments=arguments, + raw_arguments=raw_args, + ) + + async def execute_single( + self, + tool_call: Dict[str, Any], + bot, + from_wxid: str, + timeout_override: Optional[float] = None, + ) -> ToolCallResult: + """ + 执行单个工具调用 + + Args: + tool_call: OpenAI 格式的 tool_call + bot: WechatHookClient 实例 + from_wxid: 消息来源 wxid + timeout_override: 覆盖默认超时 + + Returns: + ToolCallResult 对象 + """ + from utils.tool_registry import get_tool_registry + from utils.llm_tooling import validate_tool_arguments, ToolResult + + start_time = time.time() + request = self.parse_tool_call(tool_call) + registry = get_tool_registry() + + result = ToolCallResult( + id=request.id, + name=request.name, + ) + + # 获取工具定义 + tool_def = registry.get(request.name) + if not tool_def: + result.success = False + result.error = f"工具 {request.name} 不存在" + result.message = result.error + self._update_stats(False, time.time() - start_time) + return result + + # 参数校验 + if self.validate_args: + schema = tool_def.schema.get("function", {}).get("parameters", {}) + ok, error_msg, validated_args = validate_tool_arguments( + request.name, request.arguments, schema + ) + if not ok: + result.success = False + result.error = error_msg + result.message = error_msg + self._update_stats(False, time.time() - start_time) + return result + request.arguments = validated_args + + # 执行工具 + timeout = timeout_override or tool_def.timeout or self.default_timeout + + try: + logger.debug(f"[ToolExecutor] 执行工具: {request.name}") + + raw_result = await asyncio.wait_for( + tool_def.executor(request.name, request.arguments, bot, from_wxid), + timeout=timeout + ) + + # 解析结果 + tool_result = ToolResult.from_raw(raw_result) + if tool_result: + result.success = tool_result.success + result.message = tool_result.message + result.need_ai_reply = tool_result.need_ai_reply + result.already_sent = tool_result.already_sent + result.send_result_text = tool_result.send_result_text + result.no_reply = tool_result.no_reply + result.save_to_memory = tool_result.save_to_memory + else: + result.message = str(raw_result) if raw_result else "执行完成" + + result.raw_result = raw_result if isinstance(raw_result, dict) else {"result": raw_result} + + execution_time = time.time() - start_time + result.execution_time_ms = execution_time * 1000 + self._update_stats(result.success, execution_time) + + logger.debug( + f"[ToolExecutor] 工具 {request.name} 执行完成 " + f"({result.execution_time_ms:.1f}ms)" + ) + + except asyncio.TimeoutError: + result.success = False + result.error = f"执行超时 ({timeout}s)" + result.message = result.error + self._update_stats(False, time.time() - start_time, timeout=True) + logger.warning(f"[ToolExecutor] 工具 {request.name} 执行超时") + + except asyncio.CancelledError: + raise + + except Exception as e: + result.success = False + result.error = str(e) + result.message = f"执行失败: {e}" + self._update_stats(False, time.time() - start_time) + logger.error(f"[ToolExecutor] 工具 {request.name} 执行异常: {e}") + + return result + + async def execute_batch( + self, + tool_calls: List[Dict[str, Any]], + bot, + from_wxid: str, + parallel: bool = True, + stop_on_error: bool = False, + ) -> List[ToolCallResult]: + """ + 批量执行工具调用 + + Args: + tool_calls: 工具调用列表 + bot: WechatHookClient 实例 + from_wxid: 消息来源 wxid + parallel: 是否并行执行 + stop_on_error: 遇到错误是否停止 + + Returns: + ToolCallResult 列表 + """ + if not tool_calls: + return [] + + if parallel and len(tool_calls) > 1: + return await self._execute_parallel(tool_calls, bot, from_wxid, stop_on_error) + else: + return await 
self._execute_sequential(tool_calls, bot, from_wxid, stop_on_error) + + async def _execute_sequential( + self, + tool_calls: List[Dict[str, Any]], + bot, + from_wxid: str, + stop_on_error: bool, + ) -> List[ToolCallResult]: + """顺序执行""" + results = [] + for tool_call in tool_calls: + result = await self.execute_single(tool_call, bot, from_wxid) + results.append(result) + + if stop_on_error and not result.success: + logger.warning(f"[ToolExecutor] 工具 {result.name} 失败,停止批量执行") + break + + return results + + async def _execute_parallel( + self, + tool_calls: List[Dict[str, Any]], + bot, + from_wxid: str, + stop_on_error: bool, + ) -> List[ToolCallResult]: + """并行执行(带并发限制)""" + semaphore = asyncio.Semaphore(self.max_parallel) + + async def execute_with_limit(tool_call): + async with semaphore: + return await self.execute_single(tool_call, bot, from_wxid) + + tasks = [execute_with_limit(tc) for tc in tool_calls] + + if stop_on_error: + # 使用 gather 但不 return_exceptions,让第一个错误停止执行 + results = [] + for coro in asyncio.as_completed(tasks): + try: + result = await coro + results.append(result) + if not result.success: + # 取消剩余任务 + for task in tasks: + if isinstance(task, asyncio.Task) and not task.done(): + task.cancel() + break + except Exception as e: + logger.error(f"[ToolExecutor] 并行执行异常: {e}") + break + return results + else: + # 全部执行,收集所有结果 + return await asyncio.gather(*tasks, return_exceptions=False) + + def _update_stats(self, success: bool, execution_time: float, timeout: bool = False): + """更新执行统计""" + self._stats.total_calls += 1 + if success: + self._stats.successful_calls += 1 + else: + self._stats.failed_calls += 1 + if timeout: + self._stats.timeout_calls += 1 + + self._stats.total_time_ms += execution_time * 1000 + self._stats.avg_time_ms = self._stats.total_time_ms / self._stats.total_calls + + def get_stats(self) -> Dict[str, Any]: + """获取执行统计""" + return { + "total_calls": self._stats.total_calls, + "successful_calls": self._stats.successful_calls, + "failed_calls": self._stats.failed_calls, + "timeout_calls": self._stats.timeout_calls, + "total_time_ms": self._stats.total_time_ms, + "avg_time_ms": self._stats.avg_time_ms, + "success_rate": ( + self._stats.successful_calls / self._stats.total_calls + if self._stats.total_calls > 0 else 0 + ), + } + + def reset_stats(self): + """重置统计""" + self._stats = ExecutionStats() + + +class ToolCallChain: + """ + 工具调用链 + + 用于处理需要多轮工具调用的场景,记录调用历史。 + """ + + def __init__(self, max_rounds: int = 10): + self.max_rounds = max_rounds + self.history: List[ToolCallResult] = [] + self.current_round = 0 + + def add_result(self, result: ToolCallResult): + """添加调用结果""" + self.history.append(result) + + def add_results(self, results: List[ToolCallResult]): + """添加多个调用结果""" + self.history.extend(results) + + def increment_round(self): + """增加轮次""" + self.current_round += 1 + + def can_continue(self) -> bool: + """检查是否可以继续调用""" + return self.current_round < self.max_rounds + + def get_tool_messages(self) -> List[Dict[str, Any]]: + """获取所有工具调用的消息(用于发送给 LLM)""" + return [result.to_message() for result in self.history] + + def get_last_results(self, n: int = 1) -> List[ToolCallResult]: + """获取最后 n 个结果""" + return self.history[-n:] if self.history else [] + + def has_special_flags(self) -> Dict[str, bool]: + """检查是否有特殊标志""" + flags = { + "need_ai_reply": False, + "already_sent": False, + "no_reply": False, + "save_to_memory": False, + "send_result_text": False, + } + + for result in self.history: + if result.need_ai_reply: + flags["need_ai_reply"] = True + if 
result.already_sent: + flags["already_sent"] = True + if result.no_reply: + flags["no_reply"] = True + if result.save_to_memory: + flags["save_to_memory"] = True + if result.send_result_text: + flags["send_result_text"] = True + + return flags + + def get_summary(self) -> str: + """获取调用链摘要""" + if not self.history: + return "无工具调用" + + successful = sum(1 for r in self.history if r.success) + failed = len(self.history) - successful + total_time = sum(r.execution_time_ms for r in self.history) + + tools_called = [r.name for r in self.history] + + return ( + f"调用链: {len(self.history)} 个工具, " + f"成功 {successful}, 失败 {failed}, " + f"总耗时 {total_time:.1f}ms, " + f"工具: {', '.join(tools_called)}" + ) + + +# ==================== 便捷函数 ==================== + +_default_executor: Optional[ToolExecutor] = None + + +def get_tool_executor( + default_timeout: float = 60.0, + max_parallel: int = 5, +) -> ToolExecutor: + """获取默认工具执行器""" + global _default_executor + if _default_executor is None: + _default_executor = ToolExecutor( + default_timeout=default_timeout, + max_parallel=max_parallel, + ) + return _default_executor + + +async def execute_tool_calls( + tool_calls: List[Dict[str, Any]], + bot, + from_wxid: str, + parallel: bool = True, +) -> List[ToolCallResult]: + """便捷函数:执行工具调用列表""" + executor = get_tool_executor() + return await executor.execute_batch(tool_calls, bot, from_wxid, parallel=parallel) + + +# ==================== 导出 ==================== + +__all__ = [ + 'ToolCallRequest', + 'ToolCallResult', + 'ExecutionStats', + 'ToolExecutor', + 'ToolCallChain', + 'get_tool_executor', + 'execute_tool_calls', +] diff --git a/utils/tool_registry.py b/utils/tool_registry.py new file mode 100644 index 0000000..1c9c7f8 --- /dev/null +++ b/utils/tool_registry.py @@ -0,0 +1,286 @@ +""" +工具注册中心 + +集中管理所有 LLM 工具的注册、查找和执行 +- O(1) 工具查找(替代 O(n) 插件遍历) +- 统一的超时保护 +- 工具元信息管理 + +使用示例: + from utils.tool_registry import get_tool_registry + + registry = get_tool_registry() + + # 注册工具 + registry.register( + name="generate_image", + plugin_name="AIChat", + schema={...}, + executor=some_async_func, + timeout=120 + ) + + # 执行工具 + result = await registry.execute("generate_image", arguments, bot, from_wxid) +""" + +import asyncio +from dataclasses import dataclass, field +from threading import Lock +from typing import Any, Callable, Dict, List, Optional, Awaitable + +from loguru import logger + + +@dataclass +class ToolDefinition: + """工具定义""" + name: str + plugin_name: str + schema: Dict[str, Any] # OpenAI-compatible tool schema + executor: Callable[..., Awaitable[Dict[str, Any]]] + timeout: float = 60.0 + priority: int = 50 # 同名工具时优先级高的生效 + description: str = "" + + def __post_init__(self): + # 从 schema 提取描述 + if not self.description and self.schema: + func_def = self.schema.get("function", {}) + self.description = func_def.get("description", "") + + +class ToolRegistry: + """ + 工具注册中心(线程安全单例) + + 功能: + - 工具注册与注销 + - O(1) 工具查找 + - 统一超时保护执行 + - 工具列表导出(供 LLM 使用) + """ + + _instance: Optional["ToolRegistry"] = None + _lock = Lock() + + def __new__(cls): + if cls._instance is None: + with cls._lock: + if cls._instance is None: + instance = super().__new__(cls) + instance._initialized = False + cls._instance = instance + return cls._instance + + def __init__(self): + if self._initialized: + return + + self._tools: Dict[str, ToolDefinition] = {} + self._tools_by_plugin: Dict[str, List[str]] = {} # plugin_name -> [tool_names] + self._registry_lock = Lock() + self._initialized = True + logger.debug("ToolRegistry 初始化完成") + + def 
register( + self, + name: str, + plugin_name: str, + schema: Dict[str, Any], + executor: Callable[..., Awaitable[Dict[str, Any]]], + timeout: float = 60.0, + priority: int = 50, + ) -> bool: + """ + 注册工具 + + Args: + name: 工具名称(唯一标识) + plugin_name: 所属插件名 + schema: OpenAI-compatible tool schema + executor: 异步执行函数,签名: async (tool_name, arguments, bot, from_wxid) -> dict + timeout: 执行超时(秒) + priority: 优先级(同名工具时高优先级覆盖低优先级) + + Returns: + 是否注册成功 + """ + with self._registry_lock: + # 检查是否已存在同名工具 + existing = self._tools.get(name) + if existing: + if existing.priority >= priority: + logger.debug(f"工具 {name} 已存在且优先级更高,跳过注册") + return False + logger.info(f"工具 {name} 被 {plugin_name} 覆盖(优先级 {priority} > {existing.priority})") + # 从旧插件的工具列表中移除 + old_plugin_tools = self._tools_by_plugin.get(existing.plugin_name, []) + if name in old_plugin_tools: + old_plugin_tools.remove(name) + + # 注册新工具 + tool_def = ToolDefinition( + name=name, + plugin_name=plugin_name, + schema=schema, + executor=executor, + timeout=timeout, + priority=priority, + ) + self._tools[name] = tool_def + + # 更新插件工具映射 + if plugin_name not in self._tools_by_plugin: + self._tools_by_plugin[plugin_name] = [] + if name not in self._tools_by_plugin[plugin_name]: + self._tools_by_plugin[plugin_name].append(name) + + logger.debug(f"注册工具: {name} (插件: {plugin_name}, 超时: {timeout}s)") + return True + + def unregister(self, name: str) -> bool: + """注销工具""" + with self._registry_lock: + tool_def = self._tools.pop(name, None) + if tool_def: + plugin_tools = self._tools_by_plugin.get(tool_def.plugin_name, []) + if name in plugin_tools: + plugin_tools.remove(name) + logger.debug(f"注销工具: {name}") + return True + return False + + def unregister_plugin(self, plugin_name: str) -> int: + """ + 注销插件的所有工具 + + Args: + plugin_name: 插件名 + + Returns: + 注销的工具数量 + """ + with self._registry_lock: + tool_names = self._tools_by_plugin.pop(plugin_name, []) + count = 0 + for name in tool_names: + if self._tools.pop(name, None): + count += 1 + if count > 0: + logger.info(f"注销插件 {plugin_name} 的 {count} 个工具") + return count + + def get(self, name: str) -> Optional[ToolDefinition]: + """获取工具定义(O(1) 查找)""" + return self._tools.get(name) + + def get_all_schemas(self) -> List[Dict[str, Any]]: + """获取所有工具的 schema 列表(供 LLM 使用)""" + return [tool.schema for tool in self._tools.values()] + + def get_plugin_schemas(self, plugin_name: str) -> List[Dict[str, Any]]: + """获取指定插件的工具 schema 列表""" + tool_names = self._tools_by_plugin.get(plugin_name, []) + return [self._tools[name].schema for name in tool_names if name in self._tools] + + def list_tools(self) -> List[str]: + """列出所有工具名""" + return list(self._tools.keys()) + + def list_plugin_tools(self, plugin_name: str) -> List[str]: + """列出插件的所有工具名""" + return self._tools_by_plugin.get(plugin_name, []).copy() + + async def execute( + self, + name: str, + arguments: Dict[str, Any], + bot, + from_wxid: str, + timeout_override: float = None, + ) -> Dict[str, Any]: + """ + 执行工具(带超时保护和统一错误处理) + + Args: + name: 工具名 + arguments: 工具参数 + bot: WechatHookClient 实例 + from_wxid: 消息来源 wxid + timeout_override: 覆盖默认超时时间 + + Returns: + 工具执行结果字典 + """ + from utils.errors import ( + ToolNotFoundError, ToolTimeoutError, ToolExecutionError, + handle_error + ) + + tool_def = self._tools.get(name) + if not tool_def: + err = ToolNotFoundError(f"工具 {name} 不存在") + return err.to_dict() + + timeout = timeout_override if timeout_override is not None else tool_def.timeout + + try: + result = await asyncio.wait_for( + tool_def.executor(name, arguments, bot, from_wxid), 
+ timeout=timeout + ) + return result + + except asyncio.TimeoutError: + err = ToolTimeoutError( + message=f"工具 {name} 执行超时 ({timeout}s)", + user_message=f"工具执行超时 ({timeout}s)", + context={"tool_name": name, "timeout": timeout} + ) + logger.warning(err.message) + result = err.to_dict() + result["timeout"] = True + return result + + except Exception as e: + error_result = handle_error( + e, + context=f"执行工具 {name}", + log=True, + ) + return error_result.to_dict() + + def get_stats(self) -> Dict[str, Any]: + """获取注册统计信息""" + return { + "total_tools": len(self._tools), + "plugins": len(self._tools_by_plugin), + "tools_by_plugin": { + plugin: len(tools) + for plugin, tools in self._tools_by_plugin.items() + } + } + + def clear(self): + """清空所有注册(用于测试或重置)""" + with self._registry_lock: + self._tools.clear() + self._tools_by_plugin.clear() + logger.info("ToolRegistry 已清空") + + +# ==================== 便捷函数 ==================== + +def get_tool_registry() -> ToolRegistry: + """获取工具注册中心实例""" + return ToolRegistry() + + +# ==================== 导出列表 ==================== + +__all__ = [ + 'ToolDefinition', + 'ToolRegistry', + 'get_tool_registry', +]
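A rough end-to-end sketch of the new tool pipeline, assuming a plugin that already implements get_llm_tools() / execute_llm_tool(). The WeatherPlugin and get_weather names, the tool_calls payload and the bot / from_wxid arguments are placeholders; only the register_plugin_tools, ToolExecutor and ToolRegistry APIs shown are taken from the modules above.

from utils.llm_tooling import register_plugin_tools, unregister_plugin_tools
from utils.tool_executor import get_tool_executor
from utils.tool_registry import get_tool_registry


async def demo(bot, from_wxid: str, weather_plugin) -> None:
    # 1. Register the plugin's tools (mode/whitelist/blacklist mirror config.toml).
    register_plugin_tools(
        "WeatherPlugin",          # placeholder plugin name
        weather_plugin,           # placeholder plugin instance
        tools_config={"mode": "all", "whitelist": [], "blacklist": []},
        timeout_config={"default": 30},
    )

    # 2. Tool calls as returned by the LLM (same shape as _parse_response produces).
    tool_calls = [{
        "id": "call_1",
        "type": "function",
        "function": {"name": "get_weather", "arguments": '{"city": "Beijing"}'},
    }]

    # 3. Execute them with timeout protection; collect tool messages for the next LLM round.
    executor = get_tool_executor(default_timeout=30)
    results = await executor.execute_batch(tool_calls, bot, from_wxid, parallel=True)
    tool_messages = [r.to_message() for r in results]

    # 4. On plugin unload, drop its tools from the registry again.
    unregister_plugin_tools("WeatherPlugin")
    print(get_tool_registry().get_stats(), tool_messages)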