""" 工具注册中心 集中管理所有 LLM 工具的注册、查找和执行 - O(1) 工具查找(替代 O(n) 插件遍历) - 统一的超时保护 - 工具元信息管理 使用示例: from utils.tool_registry import get_tool_registry registry = get_tool_registry() # 注册工具 registry.register( name="generate_image", plugin_name="AIChat", schema={...}, executor=some_async_func, timeout=120 ) # 执行工具 result = await registry.execute("generate_image", arguments, bot, from_wxid) """ import asyncio from dataclasses import dataclass, field from threading import Lock from typing import Any, Callable, Dict, List, Optional, Awaitable from loguru import logger @dataclass class ToolDefinition: """工具定义""" name: str plugin_name: str schema: Dict[str, Any] # OpenAI-compatible tool schema executor: Callable[..., Awaitable[Dict[str, Any]]] timeout: float = 60.0 priority: int = 50 # 同名工具时优先级高的生效 description: str = "" def __post_init__(self): # 从 schema 提取描述 if not self.description and self.schema: func_def = self.schema.get("function", {}) self.description = func_def.get("description", "") class ToolRegistry: """ 工具注册中心(线程安全单例) 功能: - 工具注册与注销 - O(1) 工具查找 - 统一超时保护执行 - 工具列表导出(供 LLM 使用) """ _instance: Optional["ToolRegistry"] = None _lock = Lock() def __new__(cls): if cls._instance is None: with cls._lock: if cls._instance is None: instance = super().__new__(cls) instance._initialized = False cls._instance = instance return cls._instance def __init__(self): if self._initialized: return self._tools: Dict[str, ToolDefinition] = {} self._tools_by_plugin: Dict[str, List[str]] = {} # plugin_name -> [tool_names] self._registry_lock = Lock() self._initialized = True logger.debug("ToolRegistry 初始化完成") def register( self, name: str, plugin_name: str, schema: Dict[str, Any], executor: Callable[..., Awaitable[Dict[str, Any]]], timeout: float = 60.0, priority: int = 50, ) -> bool: """ 注册工具 Args: name: 工具名称(唯一标识) plugin_name: 所属插件名 schema: OpenAI-compatible tool schema executor: 异步执行函数,签名: async (tool_name, arguments, bot, from_wxid) -> dict timeout: 执行超时(秒) priority: 优先级(同名工具时高优先级覆盖低优先级) Returns: 是否注册成功 """ with self._registry_lock: # 检查是否已存在同名工具 existing = self._tools.get(name) if existing: if existing.priority >= priority: logger.debug(f"工具 {name} 已存在且优先级更高,跳过注册") return False logger.info(f"工具 {name} 被 {plugin_name} 覆盖(优先级 {priority} > {existing.priority})") # 从旧插件的工具列表中移除 old_plugin_tools = self._tools_by_plugin.get(existing.plugin_name, []) if name in old_plugin_tools: old_plugin_tools.remove(name) # 注册新工具 tool_def = ToolDefinition( name=name, plugin_name=plugin_name, schema=schema, executor=executor, timeout=timeout, priority=priority, ) self._tools[name] = tool_def # 更新插件工具映射 if plugin_name not in self._tools_by_plugin: self._tools_by_plugin[plugin_name] = [] if name not in self._tools_by_plugin[plugin_name]: self._tools_by_plugin[plugin_name].append(name) logger.debug(f"注册工具: {name} (插件: {plugin_name}, 超时: {timeout}s)") return True def unregister(self, name: str) -> bool: """注销工具""" with self._registry_lock: tool_def = self._tools.pop(name, None) if tool_def: plugin_tools = self._tools_by_plugin.get(tool_def.plugin_name, []) if name in plugin_tools: plugin_tools.remove(name) logger.debug(f"注销工具: {name}") return True return False def unregister_plugin(self, plugin_name: str) -> int: """ 注销插件的所有工具 Args: plugin_name: 插件名 Returns: 注销的工具数量 """ with self._registry_lock: tool_names = self._tools_by_plugin.pop(plugin_name, []) count = 0 for name in tool_names: if self._tools.pop(name, None): count += 1 if count > 0: logger.info(f"注销插件 {plugin_name} 的 {count} 个工具") return count def get(self, name: str) -> Optional[ToolDefinition]: """获取工具定义(O(1) 查找)""" return self._tools.get(name) def get_all_schemas(self) -> List[Dict[str, Any]]: """获取所有工具的 schema 列表(供 LLM 使用)""" return [tool.schema for tool in self._tools.values()] def get_plugin_schemas(self, plugin_name: str) -> List[Dict[str, Any]]: """获取指定插件的工具 schema 列表""" tool_names = self._tools_by_plugin.get(plugin_name, []) return [self._tools[name].schema for name in tool_names if name in self._tools] def list_tools(self) -> List[str]: """列出所有工具名""" return list(self._tools.keys()) def list_plugin_tools(self, plugin_name: str) -> List[str]: """列出插件的所有工具名""" return self._tools_by_plugin.get(plugin_name, []).copy() async def execute( self, name: str, arguments: Dict[str, Any], bot, from_wxid: str, timeout_override: float = None, ) -> Dict[str, Any]: """ 执行工具(带超时保护和统一错误处理) Args: name: 工具名 arguments: 工具参数 bot: WechatHookClient 实例 from_wxid: 消息来源 wxid timeout_override: 覆盖默认超时时间 Returns: 工具执行结果字典 """ from utils.errors import ( ToolNotFoundError, ToolTimeoutError, ToolExecutionError, handle_error ) tool_def = self._tools.get(name) if not tool_def: err = ToolNotFoundError(f"工具 {name} 不存在") return err.to_dict() timeout = timeout_override if timeout_override is not None else tool_def.timeout try: result = await asyncio.wait_for( tool_def.executor(name, arguments, bot, from_wxid), timeout=timeout ) return result except asyncio.TimeoutError: err = ToolTimeoutError( message=f"工具 {name} 执行超时 ({timeout}s)", user_message=f"工具执行超时 ({timeout}s)", context={"tool_name": name, "timeout": timeout} ) logger.warning(err.message) result = err.to_dict() result["timeout"] = True return result except Exception as e: error_result = handle_error( e, context=f"执行工具 {name}", log=True, ) return error_result.to_dict() def get_stats(self) -> Dict[str, Any]: """获取注册统计信息""" return { "total_tools": len(self._tools), "plugins": len(self._tools_by_plugin), "tools_by_plugin": { plugin: len(tools) for plugin, tools in self._tools_by_plugin.items() } } def clear(self): """清空所有注册(用于测试或重置)""" with self._registry_lock: self._tools.clear() self._tools_by_plugin.clear() logger.info("ToolRegistry 已清空") # ==================== 便捷函数 ==================== def get_tool_registry() -> ToolRegistry: """获取工具注册中心实例""" return ToolRegistry() # ==================== 导出列表 ==================== __all__ = [ 'ToolDefinition', 'ToolRegistry', 'get_tool_registry', ]