184 lines
5.6 KiB
Python
184 lines
5.6 KiB
Python
"""
|
|
LLM 工具体系公共模块
|
|
|
|
统一工具收集、参数校验与执行结果结构,供 AIChat 等插件使用。
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
from dataclasses import dataclass
|
|
from typing import Any, Dict, List, Optional, Tuple
|
|
|
|
from loguru import logger
|
|
|
|
|
|
@dataclass
class ToolResult:
    """Uniform structure for the outcome of a tool execution.

    All fields are plain values with defaults; how the boolean flags are
    interpreted is decided by whatever code consumes the result, not here.
    """

    success: bool = True
    message: str = ""
    need_ai_reply: bool = False
    already_sent: bool = False
    send_result_text: bool = False
    no_reply: bool = False
    save_to_memory: bool = False

    @classmethod
    def from_raw(cls, raw: Any) -> Optional["ToolResult"]:
        """Build a ``ToolResult`` from whatever a tool returned.

        Args:
            raw: ``None``, a dict carrying any of the field names as keys, or
                any other object.

        Returns:
            ``None`` when ``raw`` is ``None``; a successful result whose
            message is ``str(raw)`` when ``raw`` is not a dict; otherwise a
            result populated from the dict, with every flag passed through
            ``bool()`` and missing keys falling back to the field defaults.
        """
        if raw is None:
            return None

        if not isinstance(raw, dict):
            return cls(success=True, message=str(raw))

        msg = raw.get("message", "")
        if msg is None:
            # An explicit null message means "no message"; serialising it
            # would otherwise produce the literal text "null".
            msg = ""
        elif not isinstance(msg, str):
            try:
                msg = json.dumps(msg, ensure_ascii=False)
            except Exception:
                # Best effort: anything JSON cannot encode is stringified.
                msg = str(msg)

        return cls(
            success=bool(raw.get("success", True)),
            message=msg,
            need_ai_reply=bool(raw.get("need_ai_reply", False)),
            already_sent=bool(raw.get("already_sent", False)),
            send_result_text=bool(raw.get("send_result_text", False)),
            no_reply=bool(raw.get("no_reply", False)),
            save_to_memory=bool(raw.get("save_to_memory", False)),
        )
|
|
|
|
|
|
def collect_tools_with_plugins(
    tools_config: Dict[str, Any],
    plugins: Dict[str, Any],
) -> Dict[str, Tuple[str, Dict[str, Any]]]:
    """
    Collect the LLM tools exposed by every plugin, remembering the source plugin.

    Plugins without a ``get_llm_tools`` attribute are skipped. Tool entries
    that are malformed (not a dict, or without a non-empty
    ``function.name``) are skipped rather than aborting the collection.
    When two tools share a name, the first one encountered wins and a
    warning is logged for the later one.

    Args:
        tools_config: Mapping read for the keys ``mode`` (``"whitelist"`` and
            ``"blacklist"`` enable filtering; any other value keeps every
            tool), ``whitelist`` and ``blacklist`` (iterables of tool names;
            missing or null means empty).
        plugins: Mapping of plugin name to plugin object.

    Returns:
        {tool_name: (plugin_name, tool_dict)}
    """
    tools_by_name: Dict[str, Tuple[str, Dict[str, Any]]] = {}

    tools_config = tools_config or {}
    mode = tools_config.get("mode", "all")
    # ``or []`` covers a key that is present but explicitly null.
    whitelist = set(tools_config.get("whitelist") or [])
    blacklist = set(tools_config.get("blacklist") or [])

    for plugin_name, plugin in (plugins or {}).items():
        if not hasattr(plugin, "get_llm_tools"):
            continue

        for tool in plugin.get_llm_tools() or []:
            # Tolerate malformed entries: a single bad tool definition must
            # not prevent the remaining tools from being collected.
            function = tool.get("function") if isinstance(tool, dict) else None
            tool_name = function.get("name", "") if isinstance(function, dict) else ""
            if not tool_name:
                continue

            if mode == "whitelist" and tool_name not in whitelist:
                continue
            if mode == "blacklist" and tool_name in blacklist:
                logger.debug(f"[黑名单] 禁用工具: {tool_name}")
                continue

            if tool_name in tools_by_name:
                logger.warning(f"重复工具名 {tool_name} 来自 {plugin_name},已忽略")
                continue

            tools_by_name[tool_name] = (plugin_name, tool)
            if mode == "whitelist":
                logger.debug(f"[白名单] 启用工具: {tool_name}")

    return tools_by_name
|
|
|
|
|
|
def collect_tools(
    tools_config: Dict[str, Any],
    plugins: Dict[str, Any],
) -> List[Dict[str, Any]]:
    """Return only the tool definitions, dropping the source plugin names."""
    collected = collect_tools_with_plugins(tools_config, plugins)
    return [tool for _plugin_name, tool in collected.values()]
|
|
|
|
|
|
def get_tool_schema_map(
    tools_map: Dict[str, Tuple[str, Dict[str, Any]]],
) -> Dict[str, Dict[str, Any]]:
    """Build a mapping from tool name to that tool's parameter schema.

    Args:
        tools_map: Mapping of tool name to ``(plugin_name, tool_dict)``.

    Returns:
        ``{tool_name: schema}`` where ``schema`` is the tool's
        ``function.parameters`` value, or an empty dict when the tool has no
        usable ``function`` or ``parameters`` entry (missing or null).
    """
    schema_map: Dict[str, Dict[str, Any]] = {}
    for name, (_plugin_name, tool) in tools_map.items():
        fn = tool.get("function") if isinstance(tool, dict) else None
        # A null/missing "function" is treated the same as null/missing
        # "parameters": the tool simply has no schema.
        if not isinstance(fn, dict):
            fn = {}
        schema_map[name] = fn.get("parameters") or {}
    return schema_map
|
|
|
|
|
|
def _coerce_tool_argument(value: Any, expected_type: Any) -> Tuple[bool, Any, str]:
    """Coerce one argument value to the schema primitive ``expected_type``.

    Returns:
        ``(ok, coerced_value, label)``. When ``ok`` is False, ``label`` is
        the type name inserted into the caller's error message and
        ``coerced_value`` is the untouched input. Types other than
        integer/number/boolean/string are passed through unchanged.
    """
    if expected_type == "integer":
        # int() truncates floats, so 3.7 would silently become 3; only
        # floats holding a whole number are accepted.
        if isinstance(value, float) and not value.is_integer():
            return False, value, "整数"
        try:
            return True, int(value), ""
        except (TypeError, ValueError, OverflowError):
            return False, value, "整数"

    if expected_type == "number":
        try:
            return True, float(value), ""
        except (TypeError, ValueError, OverflowError):
            return False, value, "数字"

    if expected_type == "boolean":
        if isinstance(value, bool):
            return True, value, ""
        if isinstance(value, str) and value.lower() in ("true", "false", "1", "0"):
            return True, value.lower() in ("true", "1"), ""
        # Numeric 0/1 is accepted for parity with the "0"/"1" strings.
        if isinstance(value, int) and value in (0, 1):
            return True, bool(value), ""
        return False, value, "布尔值"

    if expected_type == "string" and not isinstance(value, str):
        return True, str(value), ""

    return True, value, ""


def validate_tool_arguments(
    tool_name: str,
    arguments: Dict[str, Any],
    schema: Optional[Dict[str, Any]],
) -> Tuple[bool, str, Dict[str, Any]]:
    """
    Lightweight validation of tool arguments, filling in schema defaults.

    Steps, in order: apply ``default`` values for absent keys, reject the
    call if any ``required`` key is absent or is ``None``/``""``/``[]``,
    then coerce each declared property to its ``type`` and check it against
    its ``enum``. An enum is satisfied when either the value as given or
    its coerced form is listed, so ``"2"`` passes for an integer enum
    ``[1, 2, 3]``.

    Note: ``arguments`` is modified in place (defaults and coerced values
    are written into it) and the same dict is returned.

    Args:
        tool_name: Name of the tool being called (not used by the checks).
        arguments: Arguments supplied for the call; ``None`` is treated as
            an empty dict.
        schema: Parameter schema with optional ``properties`` and
            ``required`` entries; a falsy schema disables validation.

    Returns:
        (ok, error_message, new_arguments)
    """
    if arguments is None:
        arguments = {}

    if not schema:
        return True, "", arguments

    props = schema.get("properties", {}) or {}
    required = schema.get("required", []) or []

    # Apply defaults before the required check so a defaulted key is never
    # reported as missing.
    for key, prop in props.items():
        if key not in arguments and isinstance(prop, dict) and "default" in prop:
            arguments[key] = prop["default"]

    missing = [
        key
        for key in required
        if key not in arguments or arguments[key] in (None, "", [])
    ]
    if missing:
        return False, f"缺少参数: {', '.join(missing)}", arguments

    # Type coercion and enum validation.
    for key, prop in props.items():
        if key not in arguments or not isinstance(prop, dict):
            continue

        raw_value = arguments[key]
        ok, value, label = _coerce_tool_argument(raw_value, prop.get("type"))

        allowed = prop.get("enum")
        if allowed is not None:
            # Accept either the raw or the coerced form; checking only the
            # raw form would reject "2" for an integer enum [1, 2, 3].
            if raw_value not in allowed and not (ok and value in allowed):
                return False, f"参数 {key} 必须是 {allowed}", arguments

        if not ok:
            return False, f"参数 {key} 应为{label}", arguments

        arguments[key] = value

    return True, "", arguments
|
|
|