272 lines
8.3 KiB
Python
272 lines
8.3 KiB
Python
"""
|
||
LLM 工具体系公共模块
|
||
|
||
统一工具收集、参数校验与执行结果结构,供 AIChat 等插件使用。
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
from dataclasses import dataclass
|
||
from typing import Any, Dict, List, Optional, Tuple
|
||
|
||
from loguru import logger
|
||
|
||
|
||
@dataclass
class ToolResult:
    """Uniform record describing the outcome of one tool execution.

    The field names are exactly the keys that ``from_raw`` reads from a raw
    result dict.
    """

    success: bool = True
    message: str = ""
    need_ai_reply: bool = False
    already_sent: bool = False
    send_result_text: bool = False
    no_reply: bool = False
    save_to_memory: bool = False

    @classmethod
    def from_raw(cls, raw: Any) -> Optional["ToolResult"]:
        """Normalize whatever a tool returned into a ``ToolResult``.

        Args:
            raw: The raw return value of a tool.

        Returns:
            ``None`` when ``raw`` is ``None``; a successful result whose
            message is ``str(raw)`` when ``raw`` is not a dict; otherwise a
            result built from the dict's keys, each flag passed through
            ``bool()``.
        """
        if raw is None:
            return None
        if not isinstance(raw, dict):
            return cls(success=True, message=str(raw))

        text = raw.get("message", "")
        if not isinstance(text, str):
            # Non-string payloads are rendered as JSON; anything that cannot
            # be serialized falls back to its str() form.
            try:
                text = json.dumps(text, ensure_ascii=False)
            except Exception:
                text = str(text)

        succeeded = bool(raw.get("success", True))
        flag_names = (
            "need_ai_reply",
            "already_sent",
            "send_result_text",
            "no_reply",
            "save_to_memory",
        )
        flags = {name: bool(raw.get(name, False)) for name in flag_names}
        return cls(success=succeeded, message=text, **flags)
|
||
|
||
|
||
def collect_tools_with_plugins(
    tools_config: Dict[str, Any],
    plugins: Dict[str, Any],
) -> Dict[str, Tuple[str, Dict[str, Any]]]:
    """Gather the LLM tools of every plugin, remembering the owning plugin.

    Args:
        tools_config: Tool settings; the keys read here are ``mode``
            (defaults to ``"all"``), ``whitelist`` and ``blacklist``.
        plugins: Mapping of plugin name to plugin object. Objects without a
            ``get_llm_tools`` attribute are skipped.

    Returns:
        ``{tool_name: (plugin_name, tool_dict)}``. When two plugins expose
        the same tool name, the first one seen wins and the later one is
        logged and ignored.
    """
    mode = tools_config.get("mode", "all")
    allowed = set(tools_config.get("whitelist", []))
    denied = set(tools_config.get("blacklist", []))
    whitelist_mode = mode == "whitelist"
    blacklist_mode = mode == "blacklist"

    collected: Dict[str, Tuple[str, Dict[str, Any]]] = {}

    for plugin_name, plugin in plugins.items():
        if not hasattr(plugin, "get_llm_tools"):
            continue

        for tool in plugin.get_llm_tools() or []:
            tool_name = tool.get("function", {}).get("name", "")

            # Nameless tools and tools outside the whitelist are dropped silently.
            if not tool_name or (whitelist_mode and tool_name not in allowed):
                continue

            if blacklist_mode and tool_name in denied:
                logger.debug(f"[黑名单] 禁用工具: {tool_name}")
                continue

            if tool_name in collected:
                logger.warning(f"重复工具名 {tool_name} 来自 {plugin_name},已忽略")
                continue

            collected[tool_name] = (plugin_name, tool)
            if whitelist_mode:
                logger.debug(f"[白名单] 启用工具: {tool_name}")

    return collected
|
||
|
||
|
||
def collect_tools(
    tools_config: Dict[str, Any],
    plugins: Dict[str, Any],
) -> List[Dict[str, Any]]:
    """Return only the tool definitions, discarding the owning plugin names.

    Takes the same arguments as ``collect_tools_with_plugins`` and applies
    the same filtering.
    """
    tools_by_name = collect_tools_with_plugins(tools_config, plugins)
    return [tool for _owner, tool in tools_by_name.values()]
|
||
|
||
|
||
def get_tool_schema_map(
    tools_map: Dict[str, Tuple[str, Dict[str, Any]]],
) -> Dict[str, Dict[str, Any]]:
    """Map each tool name to the ``parameters`` schema of its definition.

    Args:
        tools_map: ``{tool_name: (plugin_name, tool_dict)}``.

    Returns:
        ``{tool_name: parameters_schema}``; a tool whose ``function`` entry
        has no (or a falsy) ``parameters`` value maps to an empty dict.
    """
    return {
        name: tool.get("function", {}).get("parameters", {}) or {}
        for name, (_owner, tool) in tools_map.items()
    }
|
||
|
||
|
||
def _coerce_tool_argument(value: Any, expected_type: Any) -> Tuple[bool, Any, str]:
    """Coerce ``value`` to the basic schema type named by ``expected_type``.

    Returns:
        ``(ok, coerced_value, type_label)``. ``type_label`` is the type name
        used in the error message and is only meaningful when ``ok`` is
        False. Unknown or missing types pass the value through unchanged.
    """
    if expected_type == "integer":
        try:
            return True, int(value), ""
        except (TypeError, ValueError, OverflowError):
            return False, value, "整数"

    if expected_type == "number":
        try:
            return True, float(value), ""
        except (TypeError, ValueError, OverflowError):
            return False, value, "数字"

    if expected_type == "boolean":
        if isinstance(value, bool):
            return True, value, ""
        if isinstance(value, str) and value.lower() in ("true", "false", "1", "0"):
            return True, value.lower() in ("true", "1"), ""
        # Accept the integers 0/1 as well, matching the accepted "0"/"1" strings.
        if isinstance(value, int) and value in (0, 1):
            return True, bool(value), ""
        return False, value, "布尔值"

    if expected_type == "string" and not isinstance(value, str):
        return True, str(value), ""

    return True, value, ""


def validate_tool_arguments(
    tool_name: str,
    arguments: Dict[str, Any],
    schema: Optional[Dict[str, Any]],
) -> Tuple[bool, str, Dict[str, Any]]:
    """Lightly validate tool arguments and fill in schema defaults.

    The caller's ``arguments`` dict is never modified: validation works on a
    shallow copy, and schema defaults are deep-copied so a mutable default
    is not shared between the schema and the returned arguments.

    Args:
        tool_name: Name of the tool (currently unused by the checks).
        arguments: Arguments supplied for the call; ``None`` is treated as
            an empty dict when a schema is given.
        schema: The tool's parameter schema (``properties`` / ``required``),
            or a falsy value to skip validation.

    Returns:
        ``(ok, error_message, new_arguments)``. With a falsy ``schema`` the
        original ``arguments`` object is returned untouched. Otherwise
        ``new_arguments`` is the working copy with defaults applied and
        values coerced as far as validation got.
    """
    if not schema:
        return True, "", arguments

    import copy  # function-scope import keeps this block self-contained

    props = schema.get("properties", {}) or {}
    required = schema.get("required", []) or []

    checked: Dict[str, Any] = dict(arguments or {})

    # Apply defaults for parameters the caller did not supply.
    for key, prop in props.items():
        if key not in checked and isinstance(prop, dict) and "default" in prop:
            checked[key] = copy.deepcopy(prop["default"])

    # A required parameter counts as missing when absent or empty.
    missing = [
        key
        for key in required
        if key not in checked or checked[key] in (None, "", [])
    ]
    if missing:
        return False, f"缺少参数: {', '.join(missing)}", checked

    # Enum and basic type checks.
    for key, prop in props.items():
        if key not in checked or not isinstance(prop, dict):
            continue

        value = checked[key]
        ok, coerced, type_label = _coerce_tool_argument(value, prop.get("type"))

        if "enum" in prop:
            # The value passes if it matches the enum either as given or
            # after coercion, so "2" is accepted for an integer enum [1, 2, 3].
            allowed = prop["enum"]
            if value not in allowed and not (ok and coerced in allowed):
                return False, f"参数 {key} 必须是 {prop['enum']}", checked

        if not ok:
            return False, f"参数 {key} 应为{type_label}", checked

        checked[key] = coerced

    return True, "", checked
|
||
|
||
|
||
# ==================== 工具注册中心集成 ====================
|
||
|
||
def register_plugin_tools(
    plugin_name: str,
    plugin: Any,
    tools_config: Dict[str, Any],
    timeout_config: Optional[Dict[str, Any]] = None,
) -> int:
    """Register a plugin's LLM tools with the global tool registry.

    Args:
        plugin_name: Name of the plugin.
        plugin: Plugin instance; it must expose both ``get_llm_tools`` and
            ``execute_llm_tool``, otherwise nothing is registered.
        tools_config: Tool settings; the keys read here are ``mode``
            (defaults to ``"all"``), ``whitelist`` and ``blacklist``.
        timeout_config: Per-tool timeouts ``{tool_name: timeout_seconds}``.
            A ``"default"`` entry applies to tools without their own value;
            without it the fallback is 60.

    Returns:
        The number of tools for which ``registry.register`` returned a
        truthy value.
    """
    if not hasattr(plugin, "get_llm_tools") or not hasattr(plugin, "execute_llm_tool"):
        return 0

    # Function-scope import, as in the original (presumably to avoid an
    # import cycle - confirm). It runs after the guard above so plugins
    # without tools never trigger it.
    from utils.tool_registry import get_tool_registry

    registry = get_tool_registry()
    timeout_config = timeout_config or {}
    default_timeout = timeout_config.get("default", 60)

    mode = tools_config.get("mode", "all")
    whitelist = set(tools_config.get("whitelist", []))
    blacklist = set(tools_config.get("blacklist", []))

    registered_count = 0

    for tool in plugin.get_llm_tools() or []:
        tool_name = tool.get("function", {}).get("name", "")
        if not tool_name:
            continue

        # Apply whitelist / blacklist filtering.
        if mode == "whitelist" and tool_name not in whitelist:
            continue
        if mode == "blacklist" and tool_name in blacklist:
            logger.debug(f"[黑名单] 跳过注册工具: {tool_name}")
            continue

        timeout = timeout_config.get(tool_name, default_timeout)

        # The plugin's bound execute_llm_tool method is the executor for
        # every one of its tools.
        if registry.register(
            name=tool_name,
            plugin_name=plugin_name,
            schema=tool,
            executor=plugin.execute_llm_tool,
            timeout=timeout,
        ):
            registered_count += 1
            if mode == "whitelist":
                logger.debug(f"[白名单] 注册工具: {tool_name}")

    if registered_count > 0:
        logger.info(f"插件 {plugin_name} 注册了 {registered_count} 个工具")

    return registered_count
|
||
|
||
|
||
def unregister_plugin_tools(plugin_name: str) -> int:
    """Remove every tool of a plugin from the global tool registry.

    Args:
        plugin_name: Name of the plugin whose tools are removed.

    Returns:
        Whatever the registry's ``unregister_plugin`` returns; per the
        annotation, the number of tools removed.
    """
    from utils.tool_registry import get_tool_registry

    registry = get_tool_registry()
    return registry.unregister_plugin(plugin_name)
|
||
|