Files
WechatHookBot/utils/llm_tooling.py

272 lines
8.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
LLM 工具体系公共模块
统一工具收集、参数校验与执行结果结构,供 AIChat 等插件使用。
"""
from __future__ import annotations
import json
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
from loguru import logger
@dataclass
class ToolResult:
    """Unified result structure returned by an LLM tool execution.

    Attributes:
        success: whether the tool reported success (defaults to True).
        message: textual result of the tool.
        need_ai_reply / already_sent / send_result_text / no_reply / save_to_memory:
            behaviour flags copied from the raw tool result dict; their meaning is
            defined by the consumer of this structure (e.g. the AIChat plugin).
    """

    success: bool = True
    message: str = ""
    need_ai_reply: bool = False
    already_sent: bool = False
    send_result_text: bool = False
    no_reply: bool = False
    save_to_memory: bool = False

    @classmethod
    def from_raw(cls, raw: Any) -> Optional["ToolResult"]:
        """Normalise whatever a tool executor returned into a ToolResult.

        - ``None`` -> ``None`` (no result at all).
        - Any non-dict value -> a successful result whose message is ``str(raw)``.
        - A dict -> fields are read by key. A missing or ``None`` ``message``
          becomes ``""``. A non-string message is JSON-encoded (non-ASCII kept
          as-is), falling back to ``str()`` if it is not JSON-serialisable.
          Flag values are converted with ``bool()``.
        """
        if raw is None:
            return None
        if not isinstance(raw, dict):
            return cls(success=True, message=str(raw))

        msg = raw.get("message")
        if msg is None:
            # An explicit null means "no message"; json.dumps would turn it into "null".
            msg = ""
        elif not isinstance(msg, str):
            try:
                msg = json.dumps(msg, ensure_ascii=False)
            except (TypeError, ValueError, RecursionError):
                # Unserialisable (TypeError), circular (ValueError) or too deeply nested.
                msg = str(msg)

        return cls(
            success=bool(raw.get("success", True)),
            message=msg,
            need_ai_reply=bool(raw.get("need_ai_reply", False)),
            already_sent=bool(raw.get("already_sent", False)),
            send_result_text=bool(raw.get("send_result_text", False)),
            no_reply=bool(raw.get("no_reply", False)),
            save_to_memory=bool(raw.get("save_to_memory", False)),
        )
def collect_tools_with_plugins(
    tools_config: Dict[str, Any],
    plugins: Dict[str, Any],
) -> Dict[str, Tuple[str, Dict[str, Any]]]:
    """
    Collect the LLM tools exposed by every plugin, keeping the providing plugin's name.

    Args:
        tools_config: the [tools] section of the AIChat config. Reads ``mode``
            ("all" / "whitelist" / "blacklist", default "all"), ``whitelist`` and
            ``blacklist``. Null lists are treated as empty.
        plugins: mapping of plugin name -> plugin instance (PluginManager().plugins).

    Returns:
        {tool_name: (plugin_name, tool_dict)}. If several plugins expose the same
        tool name, the first one encountered wins and later ones are logged and
        ignored. Tool entries without a ``function.name`` are skipped.
    """
    tools_by_name: Dict[str, Tuple[str, Dict[str, Any]]] = {}
    tools_config = tools_config or {}
    mode = tools_config.get("mode", "all")
    # `or []` tolerates keys that are present but set to null in the config.
    whitelist = set(tools_config.get("whitelist") or [])
    blacklist = set(tools_config.get("blacklist") or [])

    for plugin_name, plugin in plugins.items():
        if not hasattr(plugin, "get_llm_tools"):
            continue
        try:
            plugin_tools = plugin.get_llm_tools() or []
        except Exception:
            # One broken plugin must not prevent the other plugins' tools from loading.
            logger.exception(f"获取插件 {plugin_name} 的工具列表失败,已跳过")
            continue

        for tool in plugin_tools:
            # Tolerate malformed entries (non-dict tool, "function": None, ...).
            function = tool.get("function") if isinstance(tool, dict) else None
            tool_name = function.get("name", "") if isinstance(function, dict) else ""
            if not tool_name:
                continue
            if mode == "whitelist" and tool_name not in whitelist:
                continue
            if mode == "blacklist" and tool_name in blacklist:
                logger.debug(f"[黑名单] 禁用工具: {tool_name}")
                continue
            if tool_name in tools_by_name:
                logger.warning(f"重复工具名 {tool_name} 来自 {plugin_name},已忽略")
                continue
            tools_by_name[tool_name] = (plugin_name, tool)
            if mode == "whitelist":
                logger.debug(f"[白名单] 启用工具: {tool_name}")

    return tools_by_name
def collect_tools(
    tools_config: Dict[str, Any],
    plugins: Dict[str, Any],
) -> List[Dict[str, Any]]:
    """Return just the tool definitions (plugin names dropped), in collection order."""
    collected = collect_tools_with_plugins(tools_config, plugins)
    return [tool_def for _plugin_name, tool_def in collected.values()]
def get_tool_schema_map(
    tools_map: Dict[str, Tuple[str, Dict[str, Any]]],
) -> Dict[str, Dict[str, Any]]:
    """Map each tool name to the ``parameters`` schema of its ``function`` entry.

    A missing, null or otherwise falsy ``parameters`` value is normalised to ``{}``.
    """
    return {
        tool_name: tool_def.get("function", {}).get("parameters", {}) or {}
        for tool_name, (_, tool_def) in tools_map.items()
    }
def _coerce_tool_value(key: str, value: Any, expected_type: Any) -> Tuple[bool, str, Any]:
    """Coerce one non-null argument value to its JSON-schema ``type``.

    Returns (ok, error_message, coerced_value). Unknown or absent types pass through unchanged.
    """
    if expected_type == "integer":
        if isinstance(value, float):
            # Reject fractional floats instead of silently truncating 2.5 -> 2.
            if value.is_integer():
                return True, "", int(value)
            return False, f"参数 {key} 应为整数", value
        try:
            return True, "", int(value)
        except (TypeError, ValueError, OverflowError):
            return False, f"参数 {key} 应为整数", value

    if expected_type == "number":
        try:
            return True, "", float(value)
        except (TypeError, ValueError, OverflowError):
            return False, f"参数 {key} 应为数字", value

    if expected_type == "boolean":
        if isinstance(value, bool):
            return True, "", value
        # Accept integer 0/1 for consistency with the accepted strings "0"/"1".
        if isinstance(value, int) and value in (0, 1):
            return True, "", bool(value)
        if isinstance(value, str):
            lowered = value.strip().lower()
            if lowered in ("true", "1"):
                return True, "", True
            if lowered in ("false", "0"):
                return True, "", False
        return False, f"参数 {key} 应为布尔值", value

    if expected_type == "string" and not isinstance(value, str):
        if isinstance(value, (dict, list)):
            # JSON text rather than a Python repr like "{'a': 1}".
            try:
                return True, "", json.dumps(value, ensure_ascii=False)
            except (TypeError, ValueError, RecursionError):
                pass
        return True, "", str(value)

    return True, "", value


def validate_tool_arguments(
    tool_name: str,
    arguments: Dict[str, Any],
    schema: Optional[Dict[str, Any]],
) -> Tuple[bool, str, Dict[str, Any]]:
    """
    Lightly validate tool-call arguments against a JSON-schema-like ``parameters`` object.

    The function works on a shallow copy, so the caller's ``arguments`` dict is never
    mutated. Steps:

    1. Schema ``default`` values fill parameters that are absent or null.
    2. ``required`` parameters must not be missing, None, "" or [].
    3. Each present, non-null parameter is coerced to its ``type``
       (integer / number / boolean / string).
    4. The coerced value (or the raw value) must appear in its ``enum``, if one is given.

    Args:
        tool_name: name of the tool being called (not used by the checks themselves).
        arguments: arguments parsed from the LLM tool call. ``None`` is treated as ``{}``.
        schema: the tool's ``parameters`` schema. Falsy means "no validation".

    Returns:
        (ok, error_message, new_arguments). ``error_message`` is "" when ``ok`` is True.
    """
    if arguments is None:
        arguments = {}
    if not isinstance(arguments, dict):
        return False, "参数格式错误: 应为 JSON 对象", {}
    args: Dict[str, Any] = dict(arguments)

    if not schema:
        return True, "", args

    props = schema.get("properties", {}) or {}
    required = schema.get("required", []) or []

    # Apply defaults; an explicit null is treated the same as an absent parameter.
    for key, prop in props.items():
        if isinstance(prop, dict) and "default" in prop and args.get(key) is None:
            args[key] = prop["default"]

    missing = [key for key in required if args.get(key) in (None, "", [])]
    if missing:
        return False, f"缺少参数: {', '.join(missing)}", args

    # Type coercion first, then enum, so "2" can match an integer enum [1, 2, 3].
    for key, prop in props.items():
        if key not in args or not isinstance(prop, dict):
            continue
        raw = args[key]
        if raw is None:
            # Optional parameter sent as null: leave it alone instead of producing
            # the string "None" or a spurious type error.
            continue

        ok, error, coerced = _coerce_tool_value(key, raw, prop.get("type"))
        if not ok:
            return False, error, args

        enum = prop.get("enum")
        # Also accept the raw value so schemas whose enum and type disagree keep working.
        if isinstance(enum, (list, tuple)) and coerced not in enum and raw not in enum:
            return False, f"参数 {key} 必须是 {prop['enum']}", args
        args[key] = coerced

    return True, "", args
# ==================== 工具注册中心集成 ====================
def register_plugin_tools(
    plugin_name: str,
    plugin: Any,
    tools_config: Dict[str, Any],
    timeout_config: Optional[Dict[str, Any]] = None,
) -> int:
    """
    Register a plugin's LLM tools with the global tool registry.

    Args:
        plugin_name: name of the plugin, recorded with each tool.
        plugin: plugin instance; must provide ``get_llm_tools`` and ``execute_llm_tool``,
            otherwise nothing is registered.
        tools_config: tool filtering config (``mode``, ``whitelist``, ``blacklist``).
            Null lists are treated as empty.
        timeout_config: per-tool timeouts ``{tool_name: seconds}``. The ``"default"``
            entry applies to unlisted tools, and 60 is used when neither is set.

    Returns:
        Number of tools for which ``registry.register(...)`` returned a truthy value.
    """
    # Capability check first, so plugins without tools never touch the registry.
    if not hasattr(plugin, "get_llm_tools") or not hasattr(plugin, "execute_llm_tool"):
        return 0

    from utils.tool_registry import get_tool_registry

    registry = get_tool_registry()
    timeout_config = timeout_config or {}
    tools_config = tools_config or {}
    mode = tools_config.get("mode", "all")
    whitelist = set(tools_config.get("whitelist") or [])
    blacklist = set(tools_config.get("blacklist") or [])
    default_timeout = timeout_config.get("default", 60)

    registered_count = 0
    for tool in plugin.get_llm_tools() or []:
        # Tolerate malformed entries (non-dict tool, "function": None, ...).
        function = tool.get("function") if isinstance(tool, dict) else None
        tool_name = function.get("name", "") if isinstance(function, dict) else ""
        if not tool_name:
            continue

        # Apply whitelist/blacklist filtering.
        if mode == "whitelist" and tool_name not in whitelist:
            continue
        if mode == "blacklist" and tool_name in blacklist:
            logger.debug(f"[黑名单] 跳过注册工具: {tool_name}")
            continue

        timeout = timeout_config.get(tool_name, default_timeout)

        # The plugin's bound execute_llm_tool is passed straight through as the executor.
        if registry.register(
            name=tool_name,
            plugin_name=plugin_name,
            schema=tool,
            executor=plugin.execute_llm_tool,
            timeout=timeout,
        ):
            registered_count += 1
            if mode == "whitelist":
                logger.debug(f"[白名单] 注册工具: {tool_name}")

    if registered_count > 0:
        logger.info(f"插件 {plugin_name} 注册了 {registered_count} 个工具")
    return registered_count
def unregister_plugin_tools(plugin_name: str) -> int:
    """
    Remove every tool registered under ``plugin_name`` from the global tool registry.

    Args:
        plugin_name: the plugin name the tools were registered with.

    Returns:
        The value reported by the registry's ``unregister_plugin`` (the number of
        tools removed).
    """
    from utils.tool_registry import get_tool_registry

    registry = get_tool_registry()
    removed = registry.unregister_plugin(plugin_name)
    return removed