From 820861752b687ed9393d12f1942783ad87fb8b60 Mon Sep 17 00:00:00 2001 From: shihao <3127647737@qq.com> Date: Wed, 31 Dec 2025 18:39:21 +0800 Subject: [PATCH] =?UTF-8?q?feat:=E8=B6=85=E7=BA=A7=E5=B1=8E=E5=B1=B1?= =?UTF-8?q?=E4=BC=98=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- plugins/AIChat/main.py | 476 ++++++++++++++++------------------------- 1 file changed, 190 insertions(+), 286 deletions(-) diff --git a/plugins/AIChat/main.py b/plugins/AIChat/main.py index 7b828ec..4378c17 100644 --- a/plugins/AIChat/main.py +++ b/plugins/AIChat/main.py @@ -17,8 +17,8 @@ from loguru import logger from utils.plugin_base import PluginBase from utils.decorators import on_text_message, on_quote_message, on_image_message, on_emoji_message from utils.redis_cache import get_cache -from utils.llm_tooling import ToolResult, collect_tools_with_plugins, collect_tools, get_tool_schema_map, validate_tool_arguments from utils.image_processor import ImageProcessor, MediaConfig +from utils.tool_executor import ToolExecutor from utils.tool_registry import get_tool_registry import xml.etree.ElementTree as ET import base64 @@ -472,25 +472,39 @@ class AIChat(PluginBase): return "" def _collect_tools_with_plugins(self) -> dict: - """收集所有插件的 LLM 工具,并保留来源插件名""" - from utils.plugin_manager import PluginManager - tools_config = self.config.get("tools", {}) - return collect_tools_with_plugins(tools_config, PluginManager().plugins) + """收集工具定义(来自 ToolRegistry)并保留来源插件名""" + registry = get_tool_registry() + tools_config = (self.config or {}).get("tools", {}) + mode = tools_config.get("mode", "all") + whitelist = set(tools_config.get("whitelist", [])) + blacklist = set(tools_config.get("blacklist", [])) + + tools_map = {} + for name in registry.list_tools(): + tool_def = registry.get(name) + if not tool_def: + continue + if mode == "whitelist" and name not in whitelist: + continue + if mode == "blacklist" and name in blacklist: + continue + tools_map[name] = (tool_def.plugin_name, tool_def.schema) + + return tools_map def _collect_tools(self): """收集所有插件的LLM工具(支持白名单/黑名单过滤)""" - from utils.plugin_manager import PluginManager - tools_config = self.config.get("tools", {}) - return collect_tools(tools_config, PluginManager().plugins) + tools_map = self._collect_tools_with_plugins() + return [item[1] for item in tools_map.values()] def _get_tool_schema_map(self, tools_map: dict | None = None) -> dict: """构建工具名到参数 schema 的映射""" tools_map = tools_map or self._collect_tools_with_plugins() - return get_tool_schema_map(tools_map) - - def _validate_tool_arguments(self, tool_name: str, arguments: dict, schema: dict) -> tuple: - """轻量校验并补全默认参数""" - return validate_tool_arguments(tool_name, arguments, schema) + schema_map = {} + for name, (_plugin_name, tool) in tools_map.items(): + fn = tool.get("function", {}) + schema_map[name] = fn.get("parameters", {}) or {} + return schema_map async def _handle_list_prompts(self, bot, from_wxid: str): """处理人设列表指令""" @@ -627,10 +641,12 @@ class AIChat(PluginBase): ) cleaned = cleaned.replace("展开阅读下文", "") cleaned = re.sub( - r"(已触发工具处理:[^)]{0,300}结果将发送到聊天中。)", + r"[((]已触发工具处理[^))\r\n]{0,500}[))]?", "", cleaned, ) + cleaned = re.sub(r"(?m)^.*已触发工具处理.*$", "", cleaned) + cleaned = re.sub(r"(?m)^.*结果将发送到聊天中.*$", "", cleaned) # 过滤图片占位符/文件名,避免把日志占位符当成正文发出去 cleaned = re.sub( r"\\[图片[^\\]]*\\]\\s*\\S+\\.(?:png|jpe?g|gif|webp)", @@ -2297,56 +2313,69 @@ class AIChat(PluginBase): await self.store.update_group_message_by_id(chat_id, record_id, new_content) - async def _execute_tool_and_get_result( + def _prepare_tool_calls_for_executor( self, - tool_name: str, - arguments: dict, - bot, + tool_calls_data: list, + messages: list, + *, + user_wxid: str, from_wxid: str, - user_wxid: str = None, - is_group: bool = False, - tools_map: dict | None = None, - timeout: float = None, - ): - """ - 执行工具调用并返回结果(使用 ToolRegistry) + is_group: bool, + image_base64: str | None = None, + ) -> list: + prepared = [] + if not tool_calls_data: + return prepared - 通过 ToolRegistry 实现 O(1) 工具查找和统一超时保护 - """ - # 获取工具专属超时时间 - if timeout is None: - tool_timeout_config = self.config.get("tools", {}).get("timeout", {}) - timeout = tool_timeout_config.get(tool_name, tool_timeout_config.get("default", 60)) + for tool_call in tool_calls_data: + function = (tool_call or {}).get("function") or {} + function_name = function.get("name", "") + if not function_name: + continue - # 添加用户信息到 arguments - arguments["user_wxid"] = user_wxid or from_wxid - arguments["is_group"] = bool(is_group) + tool_call_id = (tool_call or {}).get("id", "") + if not tool_call_id: + tool_call_id = f"call_{uuid.uuid4().hex[:8]}" + tool_call["id"] = tool_call_id - logger.info(f"开始执行工具: {tool_name} (超时: {timeout}s)") + raw_arguments = function.get("arguments", "{}") + try: + arguments = json.loads(raw_arguments) if raw_arguments else {} + if not isinstance(arguments, dict): + arguments = {} + except Exception: + arguments = {} + if "function" not in tool_call: + tool_call["function"] = {} + tool_call["function"]["arguments"] = "{}" - # 使用 ToolRegistry 执行工具(O(1) 查找 + 统一超时保护) - registry = get_tool_registry() - result = await registry.execute(tool_name, arguments, bot, from_wxid, timeout_override=timeout) + if function_name in ("tavily_web_search", "web_search") and not arguments.get("query"): + fallback_query = self._extract_tool_intent_text(self._extract_last_user_text(messages)) + fallback_query = str(fallback_query or "").strip() + if fallback_query: + arguments["query"] = fallback_query[:400] + if "function" not in tool_call: + tool_call["function"] = {} + tool_call["function"]["arguments"] = json.dumps(arguments, ensure_ascii=False) - # 规范化结果 - if result is None: - return {"success": False, "message": f"工具 {tool_name} 返回空结果"} + exec_args = dict(arguments) + exec_args["user_wxid"] = user_wxid or from_wxid + exec_args["is_group"] = bool(is_group) - if not isinstance(result, dict): - result = {"success": True, "message": str(result)} - else: - result.setdefault("success", True) + if image_base64 and function_name == "flow2_ai_image_generation": + exec_args["image_base64"] = image_base64 + logger.info("[异步-图片] 图生图工具,已添加图片数据") - # 记录执行结果 - tool_def = registry.get(tool_name) - plugin_name = tool_def.plugin_name if tool_def else "unknown" + prepared.append({ + "id": tool_call_id, + "type": "function", + "function": { + "name": function_name, + "arguments": json.dumps(exec_args, ensure_ascii=False), + }, + }) - if result.get("success"): - logger.success(f"工具执行成功: {tool_name} ({plugin_name})") - else: - logger.warning(f"工具执行失败: {tool_name} ({plugin_name})") - - return result + return prepared async def _execute_tools_async(self, tool_calls_data: list, bot, from_wxid: str, chat_id: str, user_wxid: str, nickname: str, is_group: bool, @@ -2360,140 +2389,70 @@ class AIChat(PluginBase): try: logger.info(f"开始异步执行 {len(tool_calls_data)} 个工具调用") - # 获取并发控制配置 - concurrency_config = self.config.get("tools", {}).get("concurrency", {}) + concurrency_config = (self.config or {}).get("tools", {}).get("concurrency", {}) max_concurrent = concurrency_config.get("max_concurrent", 5) - semaphore = asyncio.Semaphore(max_concurrent) + timeout_config = (self.config or {}).get("tools", {}).get("timeout", {}) + default_timeout = timeout_config.get("default", 60) - # 并行执行所有工具(带并发限制) - tasks = [] - tool_info_list = [] # 保存工具信息用于后续处理 - tools_map = self._collect_tools_with_plugins() - schema_map = self._get_tool_schema_map(tools_map) + executor = ToolExecutor(default_timeout=default_timeout, max_parallel=max_concurrent) + prepared_tool_calls = self._prepare_tool_calls_for_executor( + tool_calls_data, + messages, + user_wxid=user_wxid, + from_wxid=from_wxid, + is_group=is_group, + ) - for tool_call in tool_calls_data: - function_name = tool_call.get("function", {}).get("name", "") - arguments_str = tool_call.get("function", {}).get("arguments", "{}") - tool_call_id = tool_call.get("id", "") + if not prepared_tool_calls: + logger.info("[异步] 没有可执行的工具调用") + return - if not function_name: + logger.info(f"[异步] 开始并行执行 {len(prepared_tool_calls)} 个工具 (最大并发: {max_concurrent})") + results = await executor.execute_batch(prepared_tool_calls, bot, from_wxid, parallel=True) + need_ai_reply_results = [] + + for result in results: + function_name = result.name + tool_call_id = result.id + tool_message = self._sanitize_llm_output(result.message or "") + + if result.success: + logger.success(f"[异步] 工具 {function_name} 执行成功") + else: + logger.warning(f"[异步] 工具 {function_name} 执行失败: {result.error or result.message}") + + if result.need_ai_reply: + need_ai_reply_results.append({ + "tool_call_id": tool_call_id, + "function_name": function_name, + "result": tool_message + }) continue - try: - arguments = json.loads(arguments_str) - except Exception: - arguments = {} + if result.success and not result.already_sent and tool_message and not result.no_reply: + if result.send_result_text: + if tool_message: + await bot.send_text(from_wxid, tool_message) + else: + logger.warning(f"[异步] 工具 {function_name} 输出清洗后为空,已跳过发送") - if function_name in ("tavily_web_search", "web_search") and not arguments.get("query"): - fallback_query = self._extract_tool_intent_text(self._extract_last_user_text(messages)) - fallback_query = str(fallback_query or "").strip() - if fallback_query: - arguments["query"] = fallback_query[:400] - - schema = schema_map.get(function_name) - ok, err, arguments = self._validate_tool_arguments(function_name, arguments, schema) - if not ok: - logger.warning(f"[异步] 工具 {function_name} 参数校验失败: {err}") + if not result.success and not result.no_reply: try: - await bot.send_text(from_wxid, f"❌ 工具 {function_name} 参数错误: {err}") + if tool_message: + await bot.send_text(from_wxid, f"? {tool_message}") + else: + await bot.send_text(from_wxid, f"? {function_name} 执行失败") except Exception: pass - continue - logger.info(f"[异步] 准备执行工具: {function_name}, 参数: {arguments}") + if result.save_to_memory and chat_id and tool_message: + self._add_to_memory(chat_id, "assistant", f"[工具 {function_name} 结果]: {tool_message}") - # 创建带并发限制的异步任务 - async def execute_with_semaphore(fn, args, bot_ref, wxid, user_wxid_ref, is_grp, t_map, sem): - async with sem: - return await self._execute_tool_and_get_result( - fn, args, bot_ref, wxid, - user_wxid=user_wxid_ref, is_group=is_grp, tools_map=t_map - ) - - task = execute_with_semaphore( - function_name, arguments, bot, from_wxid, - user_wxid, is_group, tools_map, semaphore + if need_ai_reply_results: + await self._continue_with_tool_results( + need_ai_reply_results, bot, from_wxid, user_wxid, chat_id, + nickname, is_group, messages, tool_calls_data ) - tasks.append(task) - tool_info_list.append({ - "tool_call_id": tool_call_id, - "function_name": function_name, - "arguments": arguments - }) - - # 并行执行所有工具(带并发限制,防止资源耗尽) - if tasks: - logger.info(f"[异步] 开始并行执行 {len(tasks)} 个工具 (最大并发: {max_concurrent})") - results = await asyncio.gather(*tasks, return_exceptions=True) - need_ai_reply_results = [] - - # 处理每个工具的结果 - for i, result in enumerate(results): - tool_info = tool_info_list[i] - function_name = tool_info["function_name"] - tool_call_id = tool_info["tool_call_id"] - tool_call_id = tool_info["tool_call_id"] - - if isinstance(result, Exception): - logger.error(f"[异步] 工具 {function_name} 执行异常: {result}") - try: - await bot.send_text(from_wxid, f"❌ {function_name} 执行失败: {result}") - except Exception: - pass - continue - - tool_result = ToolResult.from_raw(result) - if not tool_result: - continue - - tool_message = self._sanitize_llm_output(tool_result.message) if tool_result.message is not None else "" - - # 工具文本统一做一次输出清洗,避免工具内部/下游LLM把“思维链”发出来 - tool_message = self._sanitize_llm_output(tool_result.message) if tool_result.message is not None else "" - - if tool_result.success: - logger.success(f"[异步] 工具 {function_name} 执行成功") - else: - logger.warning(f"[异步] 工具 {function_name} 执行失败") - - # 需要 AI 继续处理的结果 - if tool_result.need_ai_reply: - need_ai_reply_results.append({ - "tool_call_id": tool_call_id, - "function_name": function_name, - "result": tool_message - }) - continue - - # 工具成功且需要回文本时发送 - if tool_result.success and not tool_result.already_sent and tool_result.message and not tool_result.no_reply: - if tool_result.send_result_text: - if tool_message: - await bot.send_text(from_wxid, tool_message) - else: - logger.warning(f"[异步] 工具 {function_name} 输出清洗后为空,已跳过发送") - - # 工具失败默认回一条错误提示 - if not tool_result.success and tool_message and not tool_result.no_reply: - try: - if tool_message: - await bot.send_text(from_wxid, f"❌ {tool_message}") - else: - await bot.send_text(from_wxid, f"❌ {function_name} 执行失败") - except Exception: - pass - - # 保存工具结果到记忆(可选) - if tool_result.save_to_memory and chat_id: - if tool_message: - self._add_to_memory(chat_id, "assistant", f"[工具 {function_name} 结果]: {tool_message}") - - # 如果有需要 AI 回复的工具结果,调用 AI 继续对话 - if need_ai_reply_results: - await self._continue_with_tool_results( - need_ai_reply_results, bot, from_wxid, user_wxid, chat_id, - nickname, is_group, messages, tool_calls_data - ) logger.info(f"[异步] 所有工具执行完成") @@ -2502,7 +2461,7 @@ class AIChat(PluginBase): import traceback logger.error(f"详细错误: {traceback.format_exc()}") try: - await bot.send_text(from_wxid, "❌ 工具执行过程中出现错误") + await bot.send_text(from_wxid, "? 工具执行过程中出现错误") except: pass @@ -2700,126 +2659,71 @@ class AIChat(PluginBase): try: logger.info(f"[异步-图片] 开始执行 {len(tool_calls_data)} 个工具调用") - # 并行执行所有工具 - tasks = [] - tool_info_list = [] - tools_map = self._collect_tools_with_plugins() - schema_map = self._get_tool_schema_map(tools_map) + concurrency_config = (self.config or {}).get("tools", {}).get("concurrency", {}) + max_concurrent = concurrency_config.get("max_concurrent", 5) + timeout_config = (self.config or {}).get("tools", {}).get("timeout", {}) + default_timeout = timeout_config.get("default", 60) - for tool_call in tool_calls_data: - function_name = tool_call.get("function", {}).get("name", "") - arguments_str = tool_call.get("function", {}).get("arguments", "{}") - tool_call_id = tool_call.get("id", "") + executor = ToolExecutor(default_timeout=default_timeout, max_parallel=max_concurrent) + prepared_tool_calls = self._prepare_tool_calls_for_executor( + tool_calls_data, + messages, + user_wxid=user_wxid, + from_wxid=from_wxid, + is_group=is_group, + image_base64=image_base64, + ) - if not function_name: + if not prepared_tool_calls: + logger.info("[异步-图片] 没有可执行的工具调用") + return + + logger.info(f"[异步-图片] 开始并行执行 {len(prepared_tool_calls)} 个工具 (最大并发: {max_concurrent})") + results = await executor.execute_batch(prepared_tool_calls, bot, from_wxid, parallel=True) + need_ai_reply_results = [] + + for result in results: + function_name = result.name + tool_call_id = result.id + tool_message = self._sanitize_llm_output(result.message or "") + + if result.success: + logger.success(f"[异步-图片] 工具 {function_name} 执行成功") + else: + logger.warning(f"[异步-图片] 工具 {function_name} 执行失败: {result.error or result.message}") + + if result.need_ai_reply: + need_ai_reply_results.append({ + "tool_call_id": tool_call_id, + "function_name": function_name, + "result": tool_message + }) continue - try: - arguments = json.loads(arguments_str) - except Exception: - arguments = {} + if result.success and not result.already_sent and tool_message and not result.no_reply: + if result.send_result_text: + if tool_message: + await bot.send_text(from_wxid, tool_message) + else: + logger.warning(f"[异步-图片] 工具 {function_name} 输出清洗后为空,已跳过发送") - if function_name in ("tavily_web_search", "web_search") and not arguments.get("query"): - fallback_query = self._extract_tool_intent_text(self._extract_last_user_text(messages)) - fallback_query = str(fallback_query or "").strip() - if fallback_query: - arguments["query"] = fallback_query[:400] - - # 如果是图生图工具,添加图片 base64 - if function_name == "flow2_ai_image_generation" and image_base64: - arguments["image_base64"] = image_base64 - logger.info(f"[异步-图片] 图生图工具,已添加图片数据") - - schema = schema_map.get(function_name) - ok, err, arguments = self._validate_tool_arguments(function_name, arguments, schema) - if not ok: - logger.warning(f"[异步-图片] 工具 {function_name} 参数校验失败: {err}") + if not result.success and not result.no_reply: try: - await bot.send_text(from_wxid, f"❌ 工具 {function_name} 参数错误: {err}") + if tool_message: + await bot.send_text(from_wxid, f"? {tool_message}") + else: + await bot.send_text(from_wxid, f"? {function_name} 执行失败") except Exception: pass - continue - logger.info(f"[异步-图片] 准备执行工具: {function_name}") + if result.save_to_memory and chat_id and tool_message: + self._add_to_memory(chat_id, "assistant", f"[工具 {function_name} 结果]: {tool_message}") - task = self._execute_tool_and_get_result( - function_name, - arguments, - bot, - from_wxid, - user_wxid=user_wxid, - is_group=is_group, - tools_map=tools_map, + if need_ai_reply_results: + await self._continue_with_tool_results( + need_ai_reply_results, bot, from_wxid, user_wxid, chat_id, + nickname, is_group, messages, tool_calls_data ) - tasks.append(task) - tool_info_list.append({ - "tool_call_id": tool_call_id, - "function_name": function_name, - "arguments": arguments - }) - - # 并行执行所有工具 - if tasks: - results = await asyncio.gather(*tasks, return_exceptions=True) - need_ai_reply_results = [] - - for i, result in enumerate(results): - tool_info = tool_info_list[i] - function_name = tool_info["function_name"] - tool_call_id = tool_info["tool_call_id"] - - if isinstance(result, Exception): - logger.error(f"[异步-图片] 工具 {function_name} 执行异常: {result}") - try: - await bot.send_text(from_wxid, f"❌ {function_name} 执行失败: {result}") - except Exception: - pass - continue - - tool_result = ToolResult.from_raw(result) - if not tool_result: - continue - - tool_message = self._sanitize_llm_output(tool_result.message or "") - - if tool_result.success: - logger.success(f"[异步-图片] 工具 {function_name} 执行成功") - else: - logger.warning(f"[异步-图片] 工具 {function_name} 执行失败") - - if tool_result.need_ai_reply: - need_ai_reply_results.append({ - "tool_call_id": tool_call_id, - "function_name": function_name, - "result": tool_message - }) - continue - - if tool_result.success and not tool_result.already_sent and tool_message and not tool_result.no_reply: - if tool_result.send_result_text: - if tool_message: - await bot.send_text(from_wxid, tool_message) - else: - logger.warning(f"[异步-图片] 工具 {function_name} 输出清洗后为空,已跳过发送") - - if not tool_result.success and tool_message and not tool_result.no_reply: - try: - if tool_message: - await bot.send_text(from_wxid, f"❌ {tool_message}") - else: - await bot.send_text(from_wxid, f"❌ {function_name} 执行失败") - except Exception: - pass - - if tool_result.save_to_memory and chat_id: - if tool_message: - self._add_to_memory(chat_id, "assistant", f"[工具 {function_name} 结果]: {tool_message}") - - if need_ai_reply_results: - await self._continue_with_tool_results( - need_ai_reply_results, bot, from_wxid, user_wxid, chat_id, - nickname, is_group, messages, tool_calls_data - ) logger.info(f"[异步-图片] 所有工具执行完成") @@ -2828,7 +2732,7 @@ class AIChat(PluginBase): import traceback logger.error(f"详细错误: {traceback.format_exc()}") try: - await bot.send_text(from_wxid, "❌ 工具执行过程中出现错误") + await bot.send_text(from_wxid, "? 工具执行过程中出现错误") except: pass