Files
WechatHookBot/utils/tool_registry.py

287 lines
8.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
工具注册中心
集中管理所有 LLM 工具的注册、查找和执行
- O(1) 工具查找(替代 O(n) 插件遍历)
- 统一的超时保护
- 工具元信息管理
使用示例:
from utils.tool_registry import get_tool_registry
registry = get_tool_registry()
# 注册工具
registry.register(
name="generate_image",
plugin_name="AIChat",
schema={...},
executor=some_async_func,
timeout=120
)
# 执行工具
result = await registry.execute("generate_image", arguments, bot, from_wxid)
"""
import asyncio
from dataclasses import dataclass, field
from threading import Lock
from typing import Any, Callable, Dict, List, Optional, Awaitable
from loguru import logger
@dataclass
class ToolDefinition:
    """Definition of a single registered LLM tool."""

    name: str
    plugin_name: str
    schema: Dict[str, Any]  # OpenAI-compatible tool schema
    executor: Callable[..., Awaitable[Dict[str, Any]]]
    timeout: float = 60.0
    priority: int = 50  # when two tools share a name, the higher priority one wins
    description: str = ""

    def __post_init__(self):
        # Fall back to the description embedded in the schema
        # (schema["function"]["description"]) when none was given explicitly.
        if self.description or not self.schema:
            return
        func_def = self.schema.get("function")
        # Tolerate malformed schemas: a missing/null "function" entry must not
        # raise AttributeError, and a null description must not leave
        # ``description`` as None (the field is declared as str).
        if isinstance(func_def, dict):
            self.description = func_def.get("description") or ""
class ToolRegistry:
    """
    Central registry for LLM tools (thread-safe singleton).

    Responsibilities:
    - register / unregister tools, individually or per plugin
    - O(1) lookup by tool name
    - execute tools with a uniform timeout and error-handling wrapper
    - export tool schemas for the LLM

    Every ``ToolRegistry()`` call returns the same instance. All reads and
    writes of the internal maps happen under ``_registry_lock``. The lock is
    held only for short dict operations and never across an ``await``.
    """

    _instance: Optional["ToolRegistry"] = None
    _lock = Lock()

    def __new__(cls):
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    instance = super().__new__(cls)
                    # Build the state here, while the class lock is held, so
                    # two threads constructing the registry at the same time
                    # cannot both run the initialisation and reset each
                    # other's maps.
                    instance._setup()
                    cls._instance = instance
        return cls._instance

    def __init__(self):
        # Python runs __init__ on every ``ToolRegistry()`` call; the singleton
        # state has already been created exactly once in ``__new__``.
        pass

    def _setup(self) -> None:
        """Create the internal maps (called once, under the class lock)."""
        self._tools: Dict[str, ToolDefinition] = {}
        # plugin_name -> names of the tools that plugin currently owns
        self._tools_by_plugin: Dict[str, List[str]] = {}
        self._registry_lock = Lock()
        self._initialized = True
        logger.debug("ToolRegistry 初始化完成")

    def _detach_from_plugin(self, plugin_name: str, name: str) -> None:
        """Remove ``name`` from a plugin's bucket and drop the bucket if it becomes empty.

        The caller must already hold ``_registry_lock``.
        """
        plugin_tools = self._tools_by_plugin.get(plugin_name)
        if plugin_tools is None:
            return
        if name in plugin_tools:
            plugin_tools.remove(name)
        if not plugin_tools:
            # Do not keep empty buckets around; they would still be counted
            # by get_stats()["plugins"].
            del self._tools_by_plugin[plugin_name]

    def register(
        self,
        name: str,
        plugin_name: str,
        schema: Dict[str, Any],
        executor: Callable[..., Awaitable[Dict[str, Any]]],
        timeout: float = 60.0,
        priority: int = 50,
    ) -> bool:
        """
        Register a tool.

        Args:
            name: tool name (unique key)
            plugin_name: name of the owning plugin
            schema: OpenAI-compatible tool schema
            executor: async callable with signature
                ``async (tool_name, arguments, bot, from_wxid) -> dict``
            timeout: execution timeout in seconds
            priority: if a tool with the same name already exists, the new
                one replaces it only when this priority is strictly higher

        Returns:
            True if the tool was registered. False if an existing tool with an
            equal or higher priority kept the name.
        """
        with self._registry_lock:
            existing = self._tools.get(name)
            if existing is not None:
                if existing.priority >= priority:
                    logger.debug(f"工具 {name} 已存在且优先级更高,跳过注册")
                    return False
                logger.info(f"工具 {name}{plugin_name} 覆盖(优先级 {priority} > {existing.priority}")
                # The previous owner no longer owns this name.
                self._detach_from_plugin(existing.plugin_name, name)

            self._tools[name] = ToolDefinition(
                name=name,
                plugin_name=plugin_name,
                schema=schema,
                executor=executor,
                timeout=timeout,
                priority=priority,
            )
            plugin_tools = self._tools_by_plugin.setdefault(plugin_name, [])
            if name not in plugin_tools:
                plugin_tools.append(name)
            logger.debug(f"注册工具: {name} (插件: {plugin_name}, 超时: {timeout}s)")
            return True

    def unregister(self, name: str) -> bool:
        """Unregister a single tool; return True if it existed."""
        with self._registry_lock:
            tool_def = self._tools.pop(name, None)
            if tool_def is None:
                return False
            self._detach_from_plugin(tool_def.plugin_name, name)
            logger.debug(f"注销工具: {name}")
            return True

    def unregister_plugin(self, plugin_name: str) -> int:
        """
        Unregister every tool owned by a plugin.

        Args:
            plugin_name: plugin name

        Returns:
            Number of tools removed.
        """
        with self._registry_lock:
            tool_names = self._tools_by_plugin.pop(plugin_name, [])
            count = 0
            for name in tool_names:
                tool_def = self._tools.get(name)
                # Only remove tools this plugin still owns, never another
                # plugin's tool that happens to share the name.
                if tool_def is not None and tool_def.plugin_name == plugin_name:
                    del self._tools[name]
                    count += 1
            if count > 0:
                logger.info(f"注销插件 {plugin_name}{count} 个工具")
            return count

    def get(self, name: str) -> Optional[ToolDefinition]:
        """Return the tool definition for ``name`` (O(1) lookup), or None."""
        with self._registry_lock:
            return self._tools.get(name)

    def get_all_schemas(self) -> List[Dict[str, Any]]:
        """Return the schemas of all registered tools (for the LLM)."""
        with self._registry_lock:
            return [tool.schema for tool in self._tools.values()]

    def get_plugin_schemas(self, plugin_name: str) -> List[Dict[str, Any]]:
        """Return the schemas of the tools owned by ``plugin_name``."""
        with self._registry_lock:
            tool_names = self._tools_by_plugin.get(plugin_name, [])
            return [self._tools[name].schema for name in tool_names if name in self._tools]

    def list_tools(self) -> List[str]:
        """Return the names of all registered tools."""
        with self._registry_lock:
            return list(self._tools.keys())

    def list_plugin_tools(self, plugin_name: str) -> List[str]:
        """Return a copy of the tool names owned by ``plugin_name``."""
        with self._registry_lock:
            return list(self._tools_by_plugin.get(plugin_name, []))

    async def execute(
        self,
        name: str,
        arguments: Dict[str, Any],
        bot,
        from_wxid: str,
        timeout_override: Optional[float] = None,
    ) -> Dict[str, Any]:
        """
        Execute a tool with timeout protection and uniform error handling.

        Args:
            name: tool name
            arguments: tool arguments
            bot: WechatHookClient instance
            from_wxid: wxid the message came from
            timeout_override: if not None, used instead of the tool's own timeout

        Returns:
            On success, the executor's result. On failure, one of these dicts:
            - ``to_dict()`` of a ToolNotFoundError for an unknown tool;
            - ``to_dict()`` of a ToolTimeoutError plus an extra
              ``"timeout": True`` key when the timeout expires;
            - ``to_dict()`` of what ``handle_error`` returns for any other
              exception.
        """
        # Imported lazily, as in the original code, presumably to avoid an
        # import cycle with utils.errors (not confirmed from here).
        from utils.errors import ToolNotFoundError, ToolTimeoutError, handle_error

        tool_def = self.get(name)
        if tool_def is None:
            return ToolNotFoundError(f"工具 {name} 不存在").to_dict()

        timeout = timeout_override if timeout_override is not None else tool_def.timeout
        try:
            return await asyncio.wait_for(
                tool_def.executor(name, arguments, bot, from_wxid),
                timeout=timeout,
            )
        except asyncio.TimeoutError:
            err = ToolTimeoutError(
                message=f"工具 {name} 执行超时 ({timeout}s)",
                user_message=f"工具执行超时 ({timeout}s)",
                context={"tool_name": name, "timeout": timeout},
            )
            logger.warning(err.message)
            result = err.to_dict()
            result["timeout"] = True
            return result
        except Exception as e:
            error_result = handle_error(
                e,
                context=f"执行工具 {name}",
                log=True,
            )
            return error_result.to_dict()

    def get_stats(self) -> Dict[str, Any]:
        """Return registry statistics: total tools, plugin count, tools per plugin."""
        with self._registry_lock:
            return {
                "total_tools": len(self._tools),
                "plugins": len(self._tools_by_plugin),
                "tools_by_plugin": {
                    plugin: len(tools)
                    for plugin, tools in self._tools_by_plugin.items()
                },
            }

    def clear(self):
        """Remove every registration (for tests or a reset)."""
        with self._registry_lock:
            self._tools.clear()
            self._tools_by_plugin.clear()
            logger.info("ToolRegistry 已清空")
# ==================== 便捷函数 ====================
def get_tool_registry() -> ToolRegistry:
    """Return the shared tool registry.

    ``ToolRegistry`` is a singleton, so every call hands back the same object.
    """
    shared_registry = ToolRegistry()
    return shared_registry
# ==================== 导出列表 ====================
# Names exported by ``from utils.tool_registry import *``
__all__ = ["ToolDefinition", "ToolRegistry", "get_tool_registry"]