feat:优化屎山
This commit is contained in:
183
utils/llm_tooling.py
Normal file
183
utils/llm_tooling.py
Normal file
@@ -0,0 +1,183 @@
|
||||
"""
|
||||
LLM 工具体系公共模块
|
||||
|
||||
统一工具收集、参数校验与执行结果结构,供 AIChat 等插件使用。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolResult:
|
||||
"""统一的工具执行结果结构"""
|
||||
|
||||
success: bool = True
|
||||
message: str = ""
|
||||
need_ai_reply: bool = False
|
||||
already_sent: bool = False
|
||||
send_result_text: bool = False
|
||||
no_reply: bool = False
|
||||
save_to_memory: bool = False
|
||||
|
||||
@classmethod
|
||||
def from_raw(cls, raw: Any) -> Optional["ToolResult"]:
|
||||
if raw is None:
|
||||
return None
|
||||
|
||||
if not isinstance(raw, dict):
|
||||
return cls(success=True, message=str(raw))
|
||||
|
||||
msg = raw.get("message", "")
|
||||
if not isinstance(msg, str):
|
||||
try:
|
||||
msg = json.dumps(msg, ensure_ascii=False)
|
||||
except Exception:
|
||||
msg = str(msg)
|
||||
|
||||
return cls(
|
||||
success=bool(raw.get("success", True)),
|
||||
message=msg,
|
||||
need_ai_reply=bool(raw.get("need_ai_reply", False)),
|
||||
already_sent=bool(raw.get("already_sent", False)),
|
||||
send_result_text=bool(raw.get("send_result_text", False)),
|
||||
no_reply=bool(raw.get("no_reply", False)),
|
||||
save_to_memory=bool(raw.get("save_to_memory", False)),
|
||||
)
|
||||
|
||||
|
||||
def collect_tools_with_plugins(
|
||||
tools_config: Dict[str, Any],
|
||||
plugins: Dict[str, Any],
|
||||
) -> Dict[str, Tuple[str, Dict[str, Any]]]:
|
||||
"""
|
||||
收集所有插件的 LLM 工具,并保留来源插件名。
|
||||
|
||||
Args:
|
||||
tools_config: AIChat 配置中的 [tools] 节
|
||||
plugins: PluginManager().plugins 映射
|
||||
|
||||
Returns:
|
||||
{tool_name: (plugin_name, tool_dict)}
|
||||
"""
|
||||
tools_by_name: Dict[str, Tuple[str, Dict[str, Any]]] = {}
|
||||
|
||||
mode = tools_config.get("mode", "all")
|
||||
whitelist = set(tools_config.get("whitelist", []))
|
||||
blacklist = set(tools_config.get("blacklist", []))
|
||||
|
||||
for plugin_name, plugin in plugins.items():
|
||||
if not hasattr(plugin, "get_llm_tools"):
|
||||
continue
|
||||
|
||||
plugin_tools = plugin.get_llm_tools() or []
|
||||
for tool in plugin_tools:
|
||||
tool_name = tool.get("function", {}).get("name", "")
|
||||
if not tool_name:
|
||||
continue
|
||||
|
||||
if mode == "whitelist" and tool_name not in whitelist:
|
||||
continue
|
||||
if mode == "blacklist" and tool_name in blacklist:
|
||||
logger.debug(f"[黑名单] 禁用工具: {tool_name}")
|
||||
continue
|
||||
|
||||
if tool_name in tools_by_name:
|
||||
logger.warning(f"重复工具名 {tool_name} 来自 {plugin_name},已忽略")
|
||||
continue
|
||||
|
||||
tools_by_name[tool_name] = (plugin_name, tool)
|
||||
if mode == "whitelist":
|
||||
logger.debug(f"[白名单] 启用工具: {tool_name}")
|
||||
|
||||
return tools_by_name
|
||||
|
||||
|
||||
def collect_tools(
|
||||
tools_config: Dict[str, Any],
|
||||
plugins: Dict[str, Any],
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""仅返回工具定义列表"""
|
||||
return [item[1] for item in collect_tools_with_plugins(tools_config, plugins).values()]
|
||||
|
||||
|
||||
def get_tool_schema_map(
|
||||
tools_map: Dict[str, Tuple[str, Dict[str, Any]]],
|
||||
) -> Dict[str, Dict[str, Any]]:
|
||||
"""构建工具名到参数 schema 的映射"""
|
||||
schema_map: Dict[str, Dict[str, Any]] = {}
|
||||
for name, (_plugin_name, tool) in tools_map.items():
|
||||
fn = tool.get("function", {})
|
||||
schema_map[name] = fn.get("parameters", {}) or {}
|
||||
return schema_map
|
||||
|
||||
|
||||
def validate_tool_arguments(
|
||||
tool_name: str,
|
||||
arguments: Dict[str, Any],
|
||||
schema: Optional[Dict[str, Any]],
|
||||
) -> Tuple[bool, str, Dict[str, Any]]:
|
||||
"""
|
||||
轻量校验并补全默认参数。
|
||||
|
||||
Returns:
|
||||
(ok, error_message, new_arguments)
|
||||
"""
|
||||
if not schema:
|
||||
return True, "", arguments
|
||||
|
||||
props = schema.get("properties", {}) or {}
|
||||
required = schema.get("required", []) or []
|
||||
|
||||
# 应用默认值
|
||||
for key, prop in props.items():
|
||||
if key not in arguments and isinstance(prop, dict) and "default" in prop:
|
||||
arguments[key] = prop["default"]
|
||||
|
||||
missing = []
|
||||
for key in required:
|
||||
if key not in arguments or arguments[key] in (None, "", []):
|
||||
missing.append(key)
|
||||
|
||||
if missing:
|
||||
return False, f"缺少参数: {', '.join(missing)}", arguments
|
||||
|
||||
# 枚举与基础类型校验
|
||||
for key, prop in props.items():
|
||||
if key not in arguments or not isinstance(prop, dict):
|
||||
continue
|
||||
|
||||
value = arguments[key]
|
||||
|
||||
if "enum" in prop and value not in prop["enum"]:
|
||||
return False, f"参数 {key} 必须是 {prop['enum']}", arguments
|
||||
|
||||
expected_type = prop.get("type")
|
||||
if expected_type == "integer":
|
||||
try:
|
||||
arguments[key] = int(value)
|
||||
except Exception:
|
||||
return False, f"参数 {key} 应为整数", arguments
|
||||
elif expected_type == "number":
|
||||
try:
|
||||
arguments[key] = float(value)
|
||||
except Exception:
|
||||
return False, f"参数 {key} 应为数字", arguments
|
||||
elif expected_type == "boolean":
|
||||
if isinstance(value, bool):
|
||||
continue
|
||||
if isinstance(value, str) and value.lower() in ("true", "false", "1", "0"):
|
||||
arguments[key] = value.lower() in ("true", "1")
|
||||
else:
|
||||
return False, f"参数 {key} 应为布尔值", arguments
|
||||
elif expected_type == "string":
|
||||
if not isinstance(value, str):
|
||||
arguments[key] = str(value)
|
||||
|
||||
return True, "", arguments
|
||||
|
||||
Reference in New Issue
Block a user