This commit is contained in:
2025-12-31 17:47:39 +08:00
38 changed files with 4435 additions and 1343 deletions

286
utils/tool_registry.py Normal file
View File

@@ -0,0 +1,286 @@
"""
工具注册中心
集中管理所有 LLM 工具的注册、查找和执行
- O(1) 工具查找(替代 O(n) 插件遍历)
- 统一的超时保护
- 工具元信息管理
使用示例:
from utils.tool_registry import get_tool_registry
registry = get_tool_registry()
# 注册工具
registry.register(
name="generate_image",
plugin_name="AIChat",
schema={...},
executor=some_async_func,
timeout=120
)
# 执行工具
result = await registry.execute("generate_image", arguments, bot, from_wxid)
"""
import asyncio
from dataclasses import dataclass, field
from threading import Lock
from typing import Any, Callable, Dict, List, Optional, Awaitable
from loguru import logger
@dataclass
class ToolDefinition:
"""工具定义"""
name: str
plugin_name: str
schema: Dict[str, Any] # OpenAI-compatible tool schema
executor: Callable[..., Awaitable[Dict[str, Any]]]
timeout: float = 60.0
priority: int = 50 # 同名工具时优先级高的生效
description: str = ""
def __post_init__(self):
# 从 schema 提取描述
if not self.description and self.schema:
func_def = self.schema.get("function", {})
self.description = func_def.get("description", "")
class ToolRegistry:
"""
工具注册中心(线程安全单例)
功能:
- 工具注册与注销
- O(1) 工具查找
- 统一超时保护执行
- 工具列表导出(供 LLM 使用)
"""
_instance: Optional["ToolRegistry"] = None
_lock = Lock()
def __new__(cls):
if cls._instance is None:
with cls._lock:
if cls._instance is None:
instance = super().__new__(cls)
instance._initialized = False
cls._instance = instance
return cls._instance
def __init__(self):
if self._initialized:
return
self._tools: Dict[str, ToolDefinition] = {}
self._tools_by_plugin: Dict[str, List[str]] = {} # plugin_name -> [tool_names]
self._registry_lock = Lock()
self._initialized = True
logger.debug("ToolRegistry 初始化完成")
def register(
self,
name: str,
plugin_name: str,
schema: Dict[str, Any],
executor: Callable[..., Awaitable[Dict[str, Any]]],
timeout: float = 60.0,
priority: int = 50,
) -> bool:
"""
注册工具
Args:
name: 工具名称(唯一标识)
plugin_name: 所属插件名
schema: OpenAI-compatible tool schema
executor: 异步执行函数,签名: async (tool_name, arguments, bot, from_wxid) -> dict
timeout: 执行超时(秒)
priority: 优先级(同名工具时高优先级覆盖低优先级)
Returns:
是否注册成功
"""
with self._registry_lock:
# 检查是否已存在同名工具
existing = self._tools.get(name)
if existing:
if existing.priority >= priority:
logger.debug(f"工具 {name} 已存在且优先级更高,跳过注册")
return False
logger.info(f"工具 {name}{plugin_name} 覆盖(优先级 {priority} > {existing.priority}")
# 从旧插件的工具列表中移除
old_plugin_tools = self._tools_by_plugin.get(existing.plugin_name, [])
if name in old_plugin_tools:
old_plugin_tools.remove(name)
# 注册新工具
tool_def = ToolDefinition(
name=name,
plugin_name=plugin_name,
schema=schema,
executor=executor,
timeout=timeout,
priority=priority,
)
self._tools[name] = tool_def
# 更新插件工具映射
if plugin_name not in self._tools_by_plugin:
self._tools_by_plugin[plugin_name] = []
if name not in self._tools_by_plugin[plugin_name]:
self._tools_by_plugin[plugin_name].append(name)
logger.debug(f"注册工具: {name} (插件: {plugin_name}, 超时: {timeout}s)")
return True
def unregister(self, name: str) -> bool:
"""注销工具"""
with self._registry_lock:
tool_def = self._tools.pop(name, None)
if tool_def:
plugin_tools = self._tools_by_plugin.get(tool_def.plugin_name, [])
if name in plugin_tools:
plugin_tools.remove(name)
logger.debug(f"注销工具: {name}")
return True
return False
def unregister_plugin(self, plugin_name: str) -> int:
"""
注销插件的所有工具
Args:
plugin_name: 插件名
Returns:
注销的工具数量
"""
with self._registry_lock:
tool_names = self._tools_by_plugin.pop(plugin_name, [])
count = 0
for name in tool_names:
if self._tools.pop(name, None):
count += 1
if count > 0:
logger.info(f"注销插件 {plugin_name}{count} 个工具")
return count
def get(self, name: str) -> Optional[ToolDefinition]:
"""获取工具定义O(1) 查找)"""
return self._tools.get(name)
def get_all_schemas(self) -> List[Dict[str, Any]]:
"""获取所有工具的 schema 列表(供 LLM 使用)"""
return [tool.schema for tool in self._tools.values()]
def get_plugin_schemas(self, plugin_name: str) -> List[Dict[str, Any]]:
"""获取指定插件的工具 schema 列表"""
tool_names = self._tools_by_plugin.get(plugin_name, [])
return [self._tools[name].schema for name in tool_names if name in self._tools]
def list_tools(self) -> List[str]:
"""列出所有工具名"""
return list(self._tools.keys())
def list_plugin_tools(self, plugin_name: str) -> List[str]:
"""列出插件的所有工具名"""
return self._tools_by_plugin.get(plugin_name, []).copy()
async def execute(
self,
name: str,
arguments: Dict[str, Any],
bot,
from_wxid: str,
timeout_override: float = None,
) -> Dict[str, Any]:
"""
执行工具(带超时保护和统一错误处理)
Args:
name: 工具名
arguments: 工具参数
bot: WechatHookClient 实例
from_wxid: 消息来源 wxid
timeout_override: 覆盖默认超时时间
Returns:
工具执行结果字典
"""
from utils.errors import (
ToolNotFoundError, ToolTimeoutError, ToolExecutionError,
handle_error
)
tool_def = self._tools.get(name)
if not tool_def:
err = ToolNotFoundError(f"工具 {name} 不存在")
return err.to_dict()
timeout = timeout_override if timeout_override is not None else tool_def.timeout
try:
result = await asyncio.wait_for(
tool_def.executor(name, arguments, bot, from_wxid),
timeout=timeout
)
return result
except asyncio.TimeoutError:
err = ToolTimeoutError(
message=f"工具 {name} 执行超时 ({timeout}s)",
user_message=f"工具执行超时 ({timeout}s)",
context={"tool_name": name, "timeout": timeout}
)
logger.warning(err.message)
result = err.to_dict()
result["timeout"] = True
return result
except Exception as e:
error_result = handle_error(
e,
context=f"执行工具 {name}",
log=True,
)
return error_result.to_dict()
def get_stats(self) -> Dict[str, Any]:
"""获取注册统计信息"""
return {
"total_tools": len(self._tools),
"plugins": len(self._tools_by_plugin),
"tools_by_plugin": {
plugin: len(tools)
for plugin, tools in self._tools_by_plugin.items()
}
}
def clear(self):
"""清空所有注册(用于测试或重置)"""
with self._registry_lock:
self._tools.clear()
self._tools_by_plugin.clear()
logger.info("ToolRegistry 已清空")
# ==================== 便捷函数 ====================
def get_tool_registry() -> ToolRegistry:
"""获取工具注册中心实例"""
return ToolRegistry()
# ==================== 导出列表 ====================
__all__ = [
'ToolDefinition',
'ToolRegistry',
'get_tool_registry',
]