Merge branch 'main' of https://gitea.functen.cn/shihao/WechatHookBot
This commit is contained in:
286
utils/tool_registry.py
Normal file
286
utils/tool_registry.py
Normal file
@@ -0,0 +1,286 @@
|
||||
"""
|
||||
工具注册中心
|
||||
|
||||
集中管理所有 LLM 工具的注册、查找和执行
|
||||
- O(1) 工具查找(替代 O(n) 插件遍历)
|
||||
- 统一的超时保护
|
||||
- 工具元信息管理
|
||||
|
||||
使用示例:
|
||||
from utils.tool_registry import get_tool_registry
|
||||
|
||||
registry = get_tool_registry()
|
||||
|
||||
# 注册工具
|
||||
registry.register(
|
||||
name="generate_image",
|
||||
plugin_name="AIChat",
|
||||
schema={...},
|
||||
executor=some_async_func,
|
||||
timeout=120
|
||||
)
|
||||
|
||||
# 执行工具
|
||||
result = await registry.execute("generate_image", arguments, bot, from_wxid)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass, field
|
||||
from threading import Lock
|
||||
from typing import Any, Callable, Dict, List, Optional, Awaitable
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolDefinition:
|
||||
"""工具定义"""
|
||||
name: str
|
||||
plugin_name: str
|
||||
schema: Dict[str, Any] # OpenAI-compatible tool schema
|
||||
executor: Callable[..., Awaitable[Dict[str, Any]]]
|
||||
timeout: float = 60.0
|
||||
priority: int = 50 # 同名工具时优先级高的生效
|
||||
description: str = ""
|
||||
|
||||
def __post_init__(self):
|
||||
# 从 schema 提取描述
|
||||
if not self.description and self.schema:
|
||||
func_def = self.schema.get("function", {})
|
||||
self.description = func_def.get("description", "")
|
||||
|
||||
|
||||
class ToolRegistry:
|
||||
"""
|
||||
工具注册中心(线程安全单例)
|
||||
|
||||
功能:
|
||||
- 工具注册与注销
|
||||
- O(1) 工具查找
|
||||
- 统一超时保护执行
|
||||
- 工具列表导出(供 LLM 使用)
|
||||
"""
|
||||
|
||||
_instance: Optional["ToolRegistry"] = None
|
||||
_lock = Lock()
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
instance = super().__new__(cls)
|
||||
instance._initialized = False
|
||||
cls._instance = instance
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
self._tools: Dict[str, ToolDefinition] = {}
|
||||
self._tools_by_plugin: Dict[str, List[str]] = {} # plugin_name -> [tool_names]
|
||||
self._registry_lock = Lock()
|
||||
self._initialized = True
|
||||
logger.debug("ToolRegistry 初始化完成")
|
||||
|
||||
def register(
|
||||
self,
|
||||
name: str,
|
||||
plugin_name: str,
|
||||
schema: Dict[str, Any],
|
||||
executor: Callable[..., Awaitable[Dict[str, Any]]],
|
||||
timeout: float = 60.0,
|
||||
priority: int = 50,
|
||||
) -> bool:
|
||||
"""
|
||||
注册工具
|
||||
|
||||
Args:
|
||||
name: 工具名称(唯一标识)
|
||||
plugin_name: 所属插件名
|
||||
schema: OpenAI-compatible tool schema
|
||||
executor: 异步执行函数,签名: async (tool_name, arguments, bot, from_wxid) -> dict
|
||||
timeout: 执行超时(秒)
|
||||
priority: 优先级(同名工具时高优先级覆盖低优先级)
|
||||
|
||||
Returns:
|
||||
是否注册成功
|
||||
"""
|
||||
with self._registry_lock:
|
||||
# 检查是否已存在同名工具
|
||||
existing = self._tools.get(name)
|
||||
if existing:
|
||||
if existing.priority >= priority:
|
||||
logger.debug(f"工具 {name} 已存在且优先级更高,跳过注册")
|
||||
return False
|
||||
logger.info(f"工具 {name} 被 {plugin_name} 覆盖(优先级 {priority} > {existing.priority})")
|
||||
# 从旧插件的工具列表中移除
|
||||
old_plugin_tools = self._tools_by_plugin.get(existing.plugin_name, [])
|
||||
if name in old_plugin_tools:
|
||||
old_plugin_tools.remove(name)
|
||||
|
||||
# 注册新工具
|
||||
tool_def = ToolDefinition(
|
||||
name=name,
|
||||
plugin_name=plugin_name,
|
||||
schema=schema,
|
||||
executor=executor,
|
||||
timeout=timeout,
|
||||
priority=priority,
|
||||
)
|
||||
self._tools[name] = tool_def
|
||||
|
||||
# 更新插件工具映射
|
||||
if plugin_name not in self._tools_by_plugin:
|
||||
self._tools_by_plugin[plugin_name] = []
|
||||
if name not in self._tools_by_plugin[plugin_name]:
|
||||
self._tools_by_plugin[plugin_name].append(name)
|
||||
|
||||
logger.debug(f"注册工具: {name} (插件: {plugin_name}, 超时: {timeout}s)")
|
||||
return True
|
||||
|
||||
def unregister(self, name: str) -> bool:
|
||||
"""注销工具"""
|
||||
with self._registry_lock:
|
||||
tool_def = self._tools.pop(name, None)
|
||||
if tool_def:
|
||||
plugin_tools = self._tools_by_plugin.get(tool_def.plugin_name, [])
|
||||
if name in plugin_tools:
|
||||
plugin_tools.remove(name)
|
||||
logger.debug(f"注销工具: {name}")
|
||||
return True
|
||||
return False
|
||||
|
||||
def unregister_plugin(self, plugin_name: str) -> int:
|
||||
"""
|
||||
注销插件的所有工具
|
||||
|
||||
Args:
|
||||
plugin_name: 插件名
|
||||
|
||||
Returns:
|
||||
注销的工具数量
|
||||
"""
|
||||
with self._registry_lock:
|
||||
tool_names = self._tools_by_plugin.pop(plugin_name, [])
|
||||
count = 0
|
||||
for name in tool_names:
|
||||
if self._tools.pop(name, None):
|
||||
count += 1
|
||||
if count > 0:
|
||||
logger.info(f"注销插件 {plugin_name} 的 {count} 个工具")
|
||||
return count
|
||||
|
||||
def get(self, name: str) -> Optional[ToolDefinition]:
|
||||
"""获取工具定义(O(1) 查找)"""
|
||||
return self._tools.get(name)
|
||||
|
||||
def get_all_schemas(self) -> List[Dict[str, Any]]:
|
||||
"""获取所有工具的 schema 列表(供 LLM 使用)"""
|
||||
return [tool.schema for tool in self._tools.values()]
|
||||
|
||||
def get_plugin_schemas(self, plugin_name: str) -> List[Dict[str, Any]]:
|
||||
"""获取指定插件的工具 schema 列表"""
|
||||
tool_names = self._tools_by_plugin.get(plugin_name, [])
|
||||
return [self._tools[name].schema for name in tool_names if name in self._tools]
|
||||
|
||||
def list_tools(self) -> List[str]:
|
||||
"""列出所有工具名"""
|
||||
return list(self._tools.keys())
|
||||
|
||||
def list_plugin_tools(self, plugin_name: str) -> List[str]:
|
||||
"""列出插件的所有工具名"""
|
||||
return self._tools_by_plugin.get(plugin_name, []).copy()
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
name: str,
|
||||
arguments: Dict[str, Any],
|
||||
bot,
|
||||
from_wxid: str,
|
||||
timeout_override: float = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
执行工具(带超时保护和统一错误处理)
|
||||
|
||||
Args:
|
||||
name: 工具名
|
||||
arguments: 工具参数
|
||||
bot: WechatHookClient 实例
|
||||
from_wxid: 消息来源 wxid
|
||||
timeout_override: 覆盖默认超时时间
|
||||
|
||||
Returns:
|
||||
工具执行结果字典
|
||||
"""
|
||||
from utils.errors import (
|
||||
ToolNotFoundError, ToolTimeoutError, ToolExecutionError,
|
||||
handle_error
|
||||
)
|
||||
|
||||
tool_def = self._tools.get(name)
|
||||
if not tool_def:
|
||||
err = ToolNotFoundError(f"工具 {name} 不存在")
|
||||
return err.to_dict()
|
||||
|
||||
timeout = timeout_override if timeout_override is not None else tool_def.timeout
|
||||
|
||||
try:
|
||||
result = await asyncio.wait_for(
|
||||
tool_def.executor(name, arguments, bot, from_wxid),
|
||||
timeout=timeout
|
||||
)
|
||||
return result
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
err = ToolTimeoutError(
|
||||
message=f"工具 {name} 执行超时 ({timeout}s)",
|
||||
user_message=f"工具执行超时 ({timeout}s)",
|
||||
context={"tool_name": name, "timeout": timeout}
|
||||
)
|
||||
logger.warning(err.message)
|
||||
result = err.to_dict()
|
||||
result["timeout"] = True
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
error_result = handle_error(
|
||||
e,
|
||||
context=f"执行工具 {name}",
|
||||
log=True,
|
||||
)
|
||||
return error_result.to_dict()
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""获取注册统计信息"""
|
||||
return {
|
||||
"total_tools": len(self._tools),
|
||||
"plugins": len(self._tools_by_plugin),
|
||||
"tools_by_plugin": {
|
||||
plugin: len(tools)
|
||||
for plugin, tools in self._tools_by_plugin.items()
|
||||
}
|
||||
}
|
||||
|
||||
def clear(self):
|
||||
"""清空所有注册(用于测试或重置)"""
|
||||
with self._registry_lock:
|
||||
self._tools.clear()
|
||||
self._tools_by_plugin.clear()
|
||||
logger.info("ToolRegistry 已清空")
|
||||
|
||||
|
||||
# ==================== 便捷函数 ====================
|
||||
|
||||
def get_tool_registry() -> ToolRegistry:
|
||||
"""获取工具注册中心实例"""
|
||||
return ToolRegistry()
|
||||
|
||||
|
||||
# ==================== 导出列表 ====================
|
||||
|
||||
__all__ = [
|
||||
'ToolDefinition',
|
||||
'ToolRegistry',
|
||||
'get_tool_registry',
|
||||
]
|
||||
Reference in New Issue
Block a user