287 lines
8.8 KiB
Python
287 lines
8.8 KiB
Python
"""
|
||
工具注册中心
|
||
|
||
集中管理所有 LLM 工具的注册、查找和执行
|
||
- O(1) 工具查找(替代 O(n) 插件遍历)
|
||
- 统一的超时保护
|
||
- 工具元信息管理
|
||
|
||
使用示例:
|
||
from utils.tool_registry import get_tool_registry
|
||
|
||
registry = get_tool_registry()
|
||
|
||
# 注册工具
|
||
registry.register(
|
||
name="generate_image",
|
||
plugin_name="AIChat",
|
||
schema={...},
|
||
executor=some_async_func,
|
||
timeout=120
|
||
)
|
||
|
||
# 执行工具
|
||
result = await registry.execute("generate_image", arguments, bot, from_wxid)
|
||
"""
|
||
|
||
import asyncio
|
||
from dataclasses import dataclass, field
|
||
from threading import Lock
|
||
from typing import Any, Callable, Dict, List, Optional, Awaitable
|
||
|
||
from loguru import logger
|
||
|
||
|
||
@dataclass
|
||
class ToolDefinition:
|
||
"""工具定义"""
|
||
name: str
|
||
plugin_name: str
|
||
schema: Dict[str, Any] # OpenAI-compatible tool schema
|
||
executor: Callable[..., Awaitable[Dict[str, Any]]]
|
||
timeout: float = 60.0
|
||
priority: int = 50 # 同名工具时优先级高的生效
|
||
description: str = ""
|
||
|
||
def __post_init__(self):
|
||
# 从 schema 提取描述
|
||
if not self.description and self.schema:
|
||
func_def = self.schema.get("function", {})
|
||
self.description = func_def.get("description", "")
|
||
|
||
|
||
class ToolRegistry:
|
||
"""
|
||
工具注册中心(线程安全单例)
|
||
|
||
功能:
|
||
- 工具注册与注销
|
||
- O(1) 工具查找
|
||
- 统一超时保护执行
|
||
- 工具列表导出(供 LLM 使用)
|
||
"""
|
||
|
||
_instance: Optional["ToolRegistry"] = None
|
||
_lock = Lock()
|
||
|
||
def __new__(cls):
|
||
if cls._instance is None:
|
||
with cls._lock:
|
||
if cls._instance is None:
|
||
instance = super().__new__(cls)
|
||
instance._initialized = False
|
||
cls._instance = instance
|
||
return cls._instance
|
||
|
||
def __init__(self):
|
||
if self._initialized:
|
||
return
|
||
|
||
self._tools: Dict[str, ToolDefinition] = {}
|
||
self._tools_by_plugin: Dict[str, List[str]] = {} # plugin_name -> [tool_names]
|
||
self._registry_lock = Lock()
|
||
self._initialized = True
|
||
logger.debug("ToolRegistry 初始化完成")
|
||
|
||
def register(
|
||
self,
|
||
name: str,
|
||
plugin_name: str,
|
||
schema: Dict[str, Any],
|
||
executor: Callable[..., Awaitable[Dict[str, Any]]],
|
||
timeout: float = 60.0,
|
||
priority: int = 50,
|
||
) -> bool:
|
||
"""
|
||
注册工具
|
||
|
||
Args:
|
||
name: 工具名称(唯一标识)
|
||
plugin_name: 所属插件名
|
||
schema: OpenAI-compatible tool schema
|
||
executor: 异步执行函数,签名: async (tool_name, arguments, bot, from_wxid) -> dict
|
||
timeout: 执行超时(秒)
|
||
priority: 优先级(同名工具时高优先级覆盖低优先级)
|
||
|
||
Returns:
|
||
是否注册成功
|
||
"""
|
||
with self._registry_lock:
|
||
# 检查是否已存在同名工具
|
||
existing = self._tools.get(name)
|
||
if existing:
|
||
if existing.priority >= priority:
|
||
logger.debug(f"工具 {name} 已存在且优先级更高,跳过注册")
|
||
return False
|
||
logger.info(f"工具 {name} 被 {plugin_name} 覆盖(优先级 {priority} > {existing.priority})")
|
||
# 从旧插件的工具列表中移除
|
||
old_plugin_tools = self._tools_by_plugin.get(existing.plugin_name, [])
|
||
if name in old_plugin_tools:
|
||
old_plugin_tools.remove(name)
|
||
|
||
# 注册新工具
|
||
tool_def = ToolDefinition(
|
||
name=name,
|
||
plugin_name=plugin_name,
|
||
schema=schema,
|
||
executor=executor,
|
||
timeout=timeout,
|
||
priority=priority,
|
||
)
|
||
self._tools[name] = tool_def
|
||
|
||
# 更新插件工具映射
|
||
if plugin_name not in self._tools_by_plugin:
|
||
self._tools_by_plugin[plugin_name] = []
|
||
if name not in self._tools_by_plugin[plugin_name]:
|
||
self._tools_by_plugin[plugin_name].append(name)
|
||
|
||
logger.debug(f"注册工具: {name} (插件: {plugin_name}, 超时: {timeout}s)")
|
||
return True
|
||
|
||
def unregister(self, name: str) -> bool:
|
||
"""注销工具"""
|
||
with self._registry_lock:
|
||
tool_def = self._tools.pop(name, None)
|
||
if tool_def:
|
||
plugin_tools = self._tools_by_plugin.get(tool_def.plugin_name, [])
|
||
if name in plugin_tools:
|
||
plugin_tools.remove(name)
|
||
logger.debug(f"注销工具: {name}")
|
||
return True
|
||
return False
|
||
|
||
def unregister_plugin(self, plugin_name: str) -> int:
|
||
"""
|
||
注销插件的所有工具
|
||
|
||
Args:
|
||
plugin_name: 插件名
|
||
|
||
Returns:
|
||
注销的工具数量
|
||
"""
|
||
with self._registry_lock:
|
||
tool_names = self._tools_by_plugin.pop(plugin_name, [])
|
||
count = 0
|
||
for name in tool_names:
|
||
if self._tools.pop(name, None):
|
||
count += 1
|
||
if count > 0:
|
||
logger.info(f"注销插件 {plugin_name} 的 {count} 个工具")
|
||
return count
|
||
|
||
def get(self, name: str) -> Optional[ToolDefinition]:
|
||
"""获取工具定义(O(1) 查找)"""
|
||
return self._tools.get(name)
|
||
|
||
def get_all_schemas(self) -> List[Dict[str, Any]]:
|
||
"""获取所有工具的 schema 列表(供 LLM 使用)"""
|
||
return [tool.schema for tool in self._tools.values()]
|
||
|
||
def get_plugin_schemas(self, plugin_name: str) -> List[Dict[str, Any]]:
|
||
"""获取指定插件的工具 schema 列表"""
|
||
tool_names = self._tools_by_plugin.get(plugin_name, [])
|
||
return [self._tools[name].schema for name in tool_names if name in self._tools]
|
||
|
||
def list_tools(self) -> List[str]:
|
||
"""列出所有工具名"""
|
||
return list(self._tools.keys())
|
||
|
||
def list_plugin_tools(self, plugin_name: str) -> List[str]:
|
||
"""列出插件的所有工具名"""
|
||
return self._tools_by_plugin.get(plugin_name, []).copy()
|
||
|
||
async def execute(
|
||
self,
|
||
name: str,
|
||
arguments: Dict[str, Any],
|
||
bot,
|
||
from_wxid: str,
|
||
timeout_override: float = None,
|
||
) -> Dict[str, Any]:
|
||
"""
|
||
执行工具(带超时保护和统一错误处理)
|
||
|
||
Args:
|
||
name: 工具名
|
||
arguments: 工具参数
|
||
bot: WechatHookClient 实例
|
||
from_wxid: 消息来源 wxid
|
||
timeout_override: 覆盖默认超时时间
|
||
|
||
Returns:
|
||
工具执行结果字典
|
||
"""
|
||
from utils.errors import (
|
||
ToolNotFoundError, ToolTimeoutError, ToolExecutionError,
|
||
handle_error
|
||
)
|
||
|
||
tool_def = self._tools.get(name)
|
||
if not tool_def:
|
||
err = ToolNotFoundError(f"工具 {name} 不存在")
|
||
return err.to_dict()
|
||
|
||
timeout = timeout_override if timeout_override is not None else tool_def.timeout
|
||
|
||
try:
|
||
result = await asyncio.wait_for(
|
||
tool_def.executor(name, arguments, bot, from_wxid),
|
||
timeout=timeout
|
||
)
|
||
return result
|
||
|
||
except asyncio.TimeoutError:
|
||
err = ToolTimeoutError(
|
||
message=f"工具 {name} 执行超时 ({timeout}s)",
|
||
user_message=f"工具执行超时 ({timeout}s)",
|
||
context={"tool_name": name, "timeout": timeout}
|
||
)
|
||
logger.warning(err.message)
|
||
result = err.to_dict()
|
||
result["timeout"] = True
|
||
return result
|
||
|
||
except Exception as e:
|
||
error_result = handle_error(
|
||
e,
|
||
context=f"执行工具 {name}",
|
||
log=True,
|
||
)
|
||
return error_result.to_dict()
|
||
|
||
def get_stats(self) -> Dict[str, Any]:
|
||
"""获取注册统计信息"""
|
||
return {
|
||
"total_tools": len(self._tools),
|
||
"plugins": len(self._tools_by_plugin),
|
||
"tools_by_plugin": {
|
||
plugin: len(tools)
|
||
for plugin, tools in self._tools_by_plugin.items()
|
||
}
|
||
}
|
||
|
||
def clear(self):
|
||
"""清空所有注册(用于测试或重置)"""
|
||
with self._registry_lock:
|
||
self._tools.clear()
|
||
self._tools_by_plugin.clear()
|
||
logger.info("ToolRegistry 已清空")
|
||
|
||
|
||
# ==================== 便捷函数 ====================
|
||
|
||
def get_tool_registry() -> ToolRegistry:
|
||
"""获取工具注册中心实例"""
|
||
return ToolRegistry()
|
||
|
||
|
||
# ==================== 导出列表 ====================
|
||
|
||
__all__ = [
|
||
'ToolDefinition',
|
||
'ToolRegistry',
|
||
'get_tool_registry',
|
||
]
|