Merge branch 'main' of https://gitea.functen.cn/shihao/WechatHookBot
This commit is contained in:
@@ -181,3 +181,91 @@ def validate_tool_arguments(
|
||||
|
||||
return True, "", arguments
|
||||
|
||||
|
||||
# ==================== 工具注册中心集成 ====================
|
||||
|
||||
def register_plugin_tools(
|
||||
plugin_name: str,
|
||||
plugin: Any,
|
||||
tools_config: Dict[str, Any],
|
||||
timeout_config: Optional[Dict[str, Any]] = None,
|
||||
) -> int:
|
||||
"""
|
||||
将插件的 LLM 工具注册到全局工具注册中心
|
||||
|
||||
Args:
|
||||
plugin_name: 插件名称
|
||||
plugin: 插件实例(需实现 get_llm_tools 和 execute_llm_tool)
|
||||
tools_config: 工具配置(包含 mode, whitelist, blacklist)
|
||||
timeout_config: 工具超时配置 {tool_name: timeout_seconds}
|
||||
|
||||
Returns:
|
||||
注册的工具数量
|
||||
"""
|
||||
from utils.tool_registry import get_tool_registry
|
||||
|
||||
if not hasattr(plugin, "get_llm_tools") or not hasattr(plugin, "execute_llm_tool"):
|
||||
return 0
|
||||
|
||||
registry = get_tool_registry()
|
||||
timeout_config = timeout_config or {}
|
||||
|
||||
mode = tools_config.get("mode", "all")
|
||||
whitelist = set(tools_config.get("whitelist", []))
|
||||
blacklist = set(tools_config.get("blacklist", []))
|
||||
|
||||
plugin_tools = plugin.get_llm_tools() or []
|
||||
registered_count = 0
|
||||
|
||||
for tool in plugin_tools:
|
||||
tool_name = tool.get("function", {}).get("name", "")
|
||||
if not tool_name:
|
||||
continue
|
||||
|
||||
# 应用白名单/黑名单过滤
|
||||
if mode == "whitelist" and tool_name not in whitelist:
|
||||
continue
|
||||
if mode == "blacklist" and tool_name in blacklist:
|
||||
logger.debug(f"[黑名单] 跳过注册工具: {tool_name}")
|
||||
continue
|
||||
|
||||
# 获取工具超时配置
|
||||
timeout = timeout_config.get(tool_name, timeout_config.get("default", 60))
|
||||
|
||||
# 创建执行器闭包
|
||||
async def make_executor(p, tn):
|
||||
async def executor(tool_name: str, arguments: dict, bot, from_wxid: str):
|
||||
return await p.execute_llm_tool(tool_name, arguments, bot, from_wxid)
|
||||
return executor
|
||||
|
||||
# 注册工具
|
||||
if registry.register(
|
||||
name=tool_name,
|
||||
plugin_name=plugin_name,
|
||||
schema=tool,
|
||||
executor=plugin.execute_llm_tool,
|
||||
timeout=timeout,
|
||||
):
|
||||
registered_count += 1
|
||||
if mode == "whitelist":
|
||||
logger.debug(f"[白名单] 注册工具: {tool_name}")
|
||||
|
||||
if registered_count > 0:
|
||||
logger.info(f"插件 {plugin_name} 注册了 {registered_count} 个工具")
|
||||
|
||||
return registered_count
|
||||
|
||||
|
||||
def unregister_plugin_tools(plugin_name: str) -> int:
|
||||
"""
|
||||
从全局工具注册中心注销插件的所有工具
|
||||
|
||||
Args:
|
||||
plugin_name: 插件名称
|
||||
|
||||
Returns:
|
||||
注销的工具数量
|
||||
"""
|
||||
from utils.tool_registry import get_tool_registry
|
||||
return get_tool_registry().unregister_plugin(plugin_name)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user