feat:超级屎山优化

This commit is contained in:
2025-12-31 18:39:21 +08:00
parent b25d3b4f0a
commit 820861752b

View File

@@ -17,8 +17,8 @@ from loguru import logger
from utils.plugin_base import PluginBase
from utils.decorators import on_text_message, on_quote_message, on_image_message, on_emoji_message
from utils.redis_cache import get_cache
from utils.llm_tooling import ToolResult, collect_tools_with_plugins, collect_tools, get_tool_schema_map, validate_tool_arguments
from utils.image_processor import ImageProcessor, MediaConfig
from utils.tool_executor import ToolExecutor
from utils.tool_registry import get_tool_registry
import xml.etree.ElementTree as ET
import base64
@@ -472,25 +472,39 @@ class AIChat(PluginBase):
return ""
def _collect_tools_with_plugins(self) -> dict:
"""收集所有插件的 LLM 工具,并保留来源插件名"""
from utils.plugin_manager import PluginManager
tools_config = self.config.get("tools", {})
return collect_tools_with_plugins(tools_config, PluginManager().plugins)
"""收集工具定义(来自 ToolRegistry并保留来源插件名"""
registry = get_tool_registry()
tools_config = (self.config or {}).get("tools", {})
mode = tools_config.get("mode", "all")
whitelist = set(tools_config.get("whitelist", []))
blacklist = set(tools_config.get("blacklist", []))
tools_map = {}
for name in registry.list_tools():
tool_def = registry.get(name)
if not tool_def:
continue
if mode == "whitelist" and name not in whitelist:
continue
if mode == "blacklist" and name in blacklist:
continue
tools_map[name] = (tool_def.plugin_name, tool_def.schema)
return tools_map
def _collect_tools(self):
"""收集所有插件的LLM工具支持白名单/黑名单过滤)"""
from utils.plugin_manager import PluginManager
tools_config = self.config.get("tools", {})
return collect_tools(tools_config, PluginManager().plugins)
tools_map = self._collect_tools_with_plugins()
return [item[1] for item in tools_map.values()]
def _get_tool_schema_map(self, tools_map: dict | None = None) -> dict:
"""构建工具名到参数 schema 的映射"""
tools_map = tools_map or self._collect_tools_with_plugins()
return get_tool_schema_map(tools_map)
def _validate_tool_arguments(self, tool_name: str, arguments: dict, schema: dict) -> tuple:
"""轻量校验并补全默认参数"""
return validate_tool_arguments(tool_name, arguments, schema)
schema_map = {}
for name, (_plugin_name, tool) in tools_map.items():
fn = tool.get("function", {})
schema_map[name] = fn.get("parameters", {}) or {}
return schema_map
async def _handle_list_prompts(self, bot, from_wxid: str):
"""处理人设列表指令"""
@@ -627,10 +641,12 @@ class AIChat(PluginBase):
)
cleaned = cleaned.replace("展开阅读下文", "")
cleaned = re.sub(
r"已触发工具处理[^]{0,300}结果将发送到聊天中。)",
r"[(]已触发工具处理[^)\r\n]{0,500}[)]?",
"",
cleaned,
)
cleaned = re.sub(r"(?m)^.*已触发工具处理.*$", "", cleaned)
cleaned = re.sub(r"(?m)^.*结果将发送到聊天中.*$", "", cleaned)
# 过滤图片占位符/文件名,避免把日志占位符当成正文发出去
cleaned = re.sub(
r"\\[图片[^\\]]*\\]\\s*\\S+\\.(?:png|jpe?g|gif|webp)",
@@ -2297,56 +2313,69 @@ class AIChat(PluginBase):
await self.store.update_group_message_by_id(chat_id, record_id, new_content)
async def _execute_tool_and_get_result(
def _prepare_tool_calls_for_executor(
self,
tool_name: str,
arguments: dict,
bot,
tool_calls_data: list,
messages: list,
*,
user_wxid: str,
from_wxid: str,
user_wxid: str = None,
is_group: bool = False,
tools_map: dict | None = None,
timeout: float = None,
):
"""
执行工具调用并返回结果(使用 ToolRegistry
is_group: bool,
image_base64: str | None = None,
) -> list:
prepared = []
if not tool_calls_data:
return prepared
通过 ToolRegistry 实现 O(1) 工具查找和统一超时保护
"""
# 获取工具专属超时时间
if timeout is None:
tool_timeout_config = self.config.get("tools", {}).get("timeout", {})
timeout = tool_timeout_config.get(tool_name, tool_timeout_config.get("default", 60))
for tool_call in tool_calls_data:
function = (tool_call or {}).get("function") or {}
function_name = function.get("name", "")
if not function_name:
continue
# 添加用户信息到 arguments
arguments["user_wxid"] = user_wxid or from_wxid
arguments["is_group"] = bool(is_group)
tool_call_id = (tool_call or {}).get("id", "")
if not tool_call_id:
tool_call_id = f"call_{uuid.uuid4().hex[:8]}"
tool_call["id"] = tool_call_id
logger.info(f"开始执行工具: {tool_name} (超时: {timeout}s)")
raw_arguments = function.get("arguments", "{}")
try:
arguments = json.loads(raw_arguments) if raw_arguments else {}
if not isinstance(arguments, dict):
arguments = {}
except Exception:
arguments = {}
if "function" not in tool_call:
tool_call["function"] = {}
tool_call["function"]["arguments"] = "{}"
# 使用 ToolRegistry 执行工具O(1) 查找 + 统一超时保护)
registry = get_tool_registry()
result = await registry.execute(tool_name, arguments, bot, from_wxid, timeout_override=timeout)
if function_name in ("tavily_web_search", "web_search") and not arguments.get("query"):
fallback_query = self._extract_tool_intent_text(self._extract_last_user_text(messages))
fallback_query = str(fallback_query or "").strip()
if fallback_query:
arguments["query"] = fallback_query[:400]
if "function" not in tool_call:
tool_call["function"] = {}
tool_call["function"]["arguments"] = json.dumps(arguments, ensure_ascii=False)
# 规范化结果
if result is None:
return {"success": False, "message": f"工具 {tool_name} 返回空结果"}
exec_args = dict(arguments)
exec_args["user_wxid"] = user_wxid or from_wxid
exec_args["is_group"] = bool(is_group)
if not isinstance(result, dict):
result = {"success": True, "message": str(result)}
else:
result.setdefault("success", True)
if image_base64 and function_name == "flow2_ai_image_generation":
exec_args["image_base64"] = image_base64
logger.info("[异步-图片] 图生图工具,已添加图片数据")
# 记录执行结果
tool_def = registry.get(tool_name)
plugin_name = tool_def.plugin_name if tool_def else "unknown"
prepared.append({
"id": tool_call_id,
"type": "function",
"function": {
"name": function_name,
"arguments": json.dumps(exec_args, ensure_ascii=False),
},
})
if result.get("success"):
logger.success(f"工具执行成功: {tool_name} ({plugin_name})")
else:
logger.warning(f"工具执行失败: {tool_name} ({plugin_name})")
return result
return prepared
async def _execute_tools_async(self, tool_calls_data: list, bot, from_wxid: str,
chat_id: str, user_wxid: str, nickname: str, is_group: bool,
@@ -2360,104 +2389,39 @@ class AIChat(PluginBase):
try:
logger.info(f"开始异步执行 {len(tool_calls_data)} 个工具调用")
# 获取并发控制配置
concurrency_config = self.config.get("tools", {}).get("concurrency", {})
concurrency_config = (self.config or {}).get("tools", {}).get("concurrency", {})
max_concurrent = concurrency_config.get("max_concurrent", 5)
semaphore = asyncio.Semaphore(max_concurrent)
timeout_config = (self.config or {}).get("tools", {}).get("timeout", {})
default_timeout = timeout_config.get("default", 60)
# 并行执行所有工具(带并发限制)
tasks = []
tool_info_list = [] # 保存工具信息用于后续处理
tools_map = self._collect_tools_with_plugins()
schema_map = self._get_tool_schema_map(tools_map)
for tool_call in tool_calls_data:
function_name = tool_call.get("function", {}).get("name", "")
arguments_str = tool_call.get("function", {}).get("arguments", "{}")
tool_call_id = tool_call.get("id", "")
if not function_name:
continue
try:
arguments = json.loads(arguments_str)
except Exception:
arguments = {}
if function_name in ("tavily_web_search", "web_search") and not arguments.get("query"):
fallback_query = self._extract_tool_intent_text(self._extract_last_user_text(messages))
fallback_query = str(fallback_query or "").strip()
if fallback_query:
arguments["query"] = fallback_query[:400]
schema = schema_map.get(function_name)
ok, err, arguments = self._validate_tool_arguments(function_name, arguments, schema)
if not ok:
logger.warning(f"[异步] 工具 {function_name} 参数校验失败: {err}")
try:
await bot.send_text(from_wxid, f"❌ 工具 {function_name} 参数错误: {err}")
except Exception:
pass
continue
logger.info(f"[异步] 准备执行工具: {function_name}, 参数: {arguments}")
# 创建带并发限制的异步任务
async def execute_with_semaphore(fn, args, bot_ref, wxid, user_wxid_ref, is_grp, t_map, sem):
async with sem:
return await self._execute_tool_and_get_result(
fn, args, bot_ref, wxid,
user_wxid=user_wxid_ref, is_group=is_grp, tools_map=t_map
executor = ToolExecutor(default_timeout=default_timeout, max_parallel=max_concurrent)
prepared_tool_calls = self._prepare_tool_calls_for_executor(
tool_calls_data,
messages,
user_wxid=user_wxid,
from_wxid=from_wxid,
is_group=is_group,
)
task = execute_with_semaphore(
function_name, arguments, bot, from_wxid,
user_wxid, is_group, tools_map, semaphore
)
tasks.append(task)
tool_info_list.append({
"tool_call_id": tool_call_id,
"function_name": function_name,
"arguments": arguments
})
if not prepared_tool_calls:
logger.info("[异步] 没有可执行的工具调用")
return
# 并行执行所有工具(带并发限制,防止资源耗尽)
if tasks:
logger.info(f"[异步] 开始并行执行 {len(tasks)} 个工具 (最大并发: {max_concurrent})")
results = await asyncio.gather(*tasks, return_exceptions=True)
logger.info(f"[异步] 开始并行执行 {len(prepared_tool_calls)} 个工具 (最大并发: {max_concurrent})")
results = await executor.execute_batch(prepared_tool_calls, bot, from_wxid, parallel=True)
need_ai_reply_results = []
# 处理每个工具的结果
for i, result in enumerate(results):
tool_info = tool_info_list[i]
function_name = tool_info["function_name"]
tool_call_id = tool_info["tool_call_id"]
tool_call_id = tool_info["tool_call_id"]
for result in results:
function_name = result.name
tool_call_id = result.id
tool_message = self._sanitize_llm_output(result.message or "")
if isinstance(result, Exception):
logger.error(f"[异步] 工具 {function_name} 执行异常: {result}")
try:
await bot.send_text(from_wxid, f"{function_name} 执行失败: {result}")
except Exception:
pass
continue
tool_result = ToolResult.from_raw(result)
if not tool_result:
continue
tool_message = self._sanitize_llm_output(tool_result.message) if tool_result.message is not None else ""
# 工具文本统一做一次输出清洗,避免工具内部/下游LLM把“思维链”发出来
tool_message = self._sanitize_llm_output(tool_result.message) if tool_result.message is not None else ""
if tool_result.success:
if result.success:
logger.success(f"[异步] 工具 {function_name} 执行成功")
else:
logger.warning(f"[异步] 工具 {function_name} 执行失败")
logger.warning(f"[异步] 工具 {function_name} 执行失败: {result.error or result.message}")
# 需要 AI 继续处理的结果
if tool_result.need_ai_reply:
if result.need_ai_reply:
need_ai_reply_results.append({
"tool_call_id": tool_call_id,
"function_name": function_name,
@@ -2465,30 +2429,25 @@ class AIChat(PluginBase):
})
continue
# 工具成功且需要回文本时发送
if tool_result.success and not tool_result.already_sent and tool_result.message and not tool_result.no_reply:
if tool_result.send_result_text:
if result.success and not result.already_sent and tool_message and not result.no_reply:
if result.send_result_text:
if tool_message:
await bot.send_text(from_wxid, tool_message)
else:
logger.warning(f"[异步] 工具 {function_name} 输出清洗后为空,已跳过发送")
# 工具失败默认回一条错误提示
if not tool_result.success and tool_message and not tool_result.no_reply:
if not result.success and not result.no_reply:
try:
if tool_message:
await bot.send_text(from_wxid, f" {tool_message}")
await bot.send_text(from_wxid, f"? {tool_message}")
else:
await bot.send_text(from_wxid, f" {function_name} 执行失败")
await bot.send_text(from_wxid, f"? {function_name} 执行失败")
except Exception:
pass
# 保存工具结果到记忆(可选)
if tool_result.save_to_memory and chat_id:
if tool_message:
if result.save_to_memory and chat_id and tool_message:
self._add_to_memory(chat_id, "assistant", f"[工具 {function_name} 结果]: {tool_message}")
# 如果有需要 AI 回复的工具结果,调用 AI 继续对话
if need_ai_reply_results:
await self._continue_with_tool_results(
need_ai_reply_results, bot, from_wxid, user_wxid, chat_id,
@@ -2502,7 +2461,7 @@ class AIChat(PluginBase):
import traceback
logger.error(f"详细错误: {traceback.format_exc()}")
try:
await bot.send_text(from_wxid, " 工具执行过程中出现错误")
await bot.send_text(from_wxid, "? 工具执行过程中出现错误")
except:
pass
@@ -2700,94 +2659,40 @@ class AIChat(PluginBase):
try:
logger.info(f"[异步-图片] 开始执行 {len(tool_calls_data)} 个工具调用")
# 并行执行所有工具
tasks = []
tool_info_list = []
tools_map = self._collect_tools_with_plugins()
schema_map = self._get_tool_schema_map(tools_map)
concurrency_config = (self.config or {}).get("tools", {}).get("concurrency", {})
max_concurrent = concurrency_config.get("max_concurrent", 5)
timeout_config = (self.config or {}).get("tools", {}).get("timeout", {})
default_timeout = timeout_config.get("default", 60)
for tool_call in tool_calls_data:
function_name = tool_call.get("function", {}).get("name", "")
arguments_str = tool_call.get("function", {}).get("arguments", "{}")
tool_call_id = tool_call.get("id", "")
if not function_name:
continue
try:
arguments = json.loads(arguments_str)
except Exception:
arguments = {}
if function_name in ("tavily_web_search", "web_search") and not arguments.get("query"):
fallback_query = self._extract_tool_intent_text(self._extract_last_user_text(messages))
fallback_query = str(fallback_query or "").strip()
if fallback_query:
arguments["query"] = fallback_query[:400]
# 如果是图生图工具,添加图片 base64
if function_name == "flow2_ai_image_generation" and image_base64:
arguments["image_base64"] = image_base64
logger.info(f"[异步-图片] 图生图工具,已添加图片数据")
schema = schema_map.get(function_name)
ok, err, arguments = self._validate_tool_arguments(function_name, arguments, schema)
if not ok:
logger.warning(f"[异步-图片] 工具 {function_name} 参数校验失败: {err}")
try:
await bot.send_text(from_wxid, f"❌ 工具 {function_name} 参数错误: {err}")
except Exception:
pass
continue
logger.info(f"[异步-图片] 准备执行工具: {function_name}")
task = self._execute_tool_and_get_result(
function_name,
arguments,
bot,
from_wxid,
executor = ToolExecutor(default_timeout=default_timeout, max_parallel=max_concurrent)
prepared_tool_calls = self._prepare_tool_calls_for_executor(
tool_calls_data,
messages,
user_wxid=user_wxid,
from_wxid=from_wxid,
is_group=is_group,
tools_map=tools_map,
image_base64=image_base64,
)
tasks.append(task)
tool_info_list.append({
"tool_call_id": tool_call_id,
"function_name": function_name,
"arguments": arguments
})
# 并行执行所有工具
if tasks:
results = await asyncio.gather(*tasks, return_exceptions=True)
if not prepared_tool_calls:
logger.info("[异步-图片] 没有可执行的工具调用")
return
logger.info(f"[异步-图片] 开始并行执行 {len(prepared_tool_calls)} 个工具 (最大并发: {max_concurrent})")
results = await executor.execute_batch(prepared_tool_calls, bot, from_wxid, parallel=True)
need_ai_reply_results = []
for i, result in enumerate(results):
tool_info = tool_info_list[i]
function_name = tool_info["function_name"]
tool_call_id = tool_info["tool_call_id"]
for result in results:
function_name = result.name
tool_call_id = result.id
tool_message = self._sanitize_llm_output(result.message or "")
if isinstance(result, Exception):
logger.error(f"[异步-图片] 工具 {function_name} 执行异常: {result}")
try:
await bot.send_text(from_wxid, f"{function_name} 执行失败: {result}")
except Exception:
pass
continue
tool_result = ToolResult.from_raw(result)
if not tool_result:
continue
tool_message = self._sanitize_llm_output(tool_result.message or "")
if tool_result.success:
if result.success:
logger.success(f"[异步-图片] 工具 {function_name} 执行成功")
else:
logger.warning(f"[异步-图片] 工具 {function_name} 执行失败")
logger.warning(f"[异步-图片] 工具 {function_name} 执行失败: {result.error or result.message}")
if tool_result.need_ai_reply:
if result.need_ai_reply:
need_ai_reply_results.append({
"tool_call_id": tool_call_id,
"function_name": function_name,
@@ -2795,24 +2700,23 @@ class AIChat(PluginBase):
})
continue
if tool_result.success and not tool_result.already_sent and tool_message and not tool_result.no_reply:
if tool_result.send_result_text:
if result.success and not result.already_sent and tool_message and not result.no_reply:
if result.send_result_text:
if tool_message:
await bot.send_text(from_wxid, tool_message)
else:
logger.warning(f"[异步-图片] 工具 {function_name} 输出清洗后为空,已跳过发送")
if not tool_result.success and tool_message and not tool_result.no_reply:
if not result.success and not result.no_reply:
try:
if tool_message:
await bot.send_text(from_wxid, f" {tool_message}")
await bot.send_text(from_wxid, f"? {tool_message}")
else:
await bot.send_text(from_wxid, f" {function_name} 执行失败")
await bot.send_text(from_wxid, f"? {function_name} 执行失败")
except Exception:
pass
if tool_result.save_to_memory and chat_id:
if tool_message:
if result.save_to_memory and chat_id and tool_message:
self._add_to_memory(chat_id, "assistant", f"[工具 {function_name} 结果]: {tool_message}")
if need_ai_reply_results:
@@ -2828,7 +2732,7 @@ class AIChat(PluginBase):
import traceback
logger.error(f"详细错误: {traceback.format_exc()}")
try:
await bot.send_text(from_wxid, " 工具执行过程中出现错误")
await bot.send_text(from_wxid, "? 工具执行过程中出现错误")
except:
pass