This commit is contained in:
2025-12-31 17:47:39 +08:00
38 changed files with 4435 additions and 1343 deletions

488
utils/tool_executor.py Normal file
View File

@@ -0,0 +1,488 @@
"""
工具执行器模块
提供工具调用的高级执行逻辑:
- 批量工具执行(支持并行)
- 工具调用链处理
- 执行日志和审计
- 结果聚合
使用示例:
from utils.tool_executor import ToolExecutor, ToolCallRequest
executor = ToolExecutor()
# 单个工具执行
result = await executor.execute_single(
tool_call={"id": "call_1", "function": {"name": "get_weather", "arguments": "{}"}},
bot=bot,
from_wxid=wxid,
)
# 批量工具执行
results = await executor.execute_batch(
tool_calls=[...],
bot=bot,
from_wxid=wxid,
parallel=True,
)
"""
from __future__ import annotations
import asyncio
import json
import time
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional
from loguru import logger
@dataclass
class ToolCallRequest:
"""工具调用请求"""
id: str
name: str
arguments: Dict[str, Any]
raw_arguments: str = "" # 原始 JSON 字符串
@dataclass
class ToolCallResult:
"""工具调用结果"""
id: str
name: str
success: bool = True
message: str = ""
raw_result: Dict[str, Any] = field(default_factory=dict)
# 控制标志
need_ai_reply: bool = False
already_sent: bool = False
send_result_text: bool = False
no_reply: bool = False
save_to_memory: bool = False
# 执行信息
execution_time_ms: float = 0.0
error: Optional[str] = None
def to_message(self) -> Dict[str, Any]:
"""转换为 OpenAI 兼容的 tool message"""
content = self.message if self.success else f"错误: {self.error or self.message}"
return {
"role": "tool",
"tool_call_id": self.id,
"content": content
}
@dataclass
class ExecutionStats:
"""执行统计"""
total_calls: int = 0
successful_calls: int = 0
failed_calls: int = 0
timeout_calls: int = 0
total_time_ms: float = 0.0
avg_time_ms: float = 0.0
class ToolExecutor:
"""
工具执行器
提供统一的工具执行接口:
- 参数解析和校验
- 超时保护
- 错误处理
- 执行统计
"""
def __init__(
self,
default_timeout: float = 60.0,
max_parallel: int = 5,
validate_args: bool = True,
):
self.default_timeout = default_timeout
self.max_parallel = max_parallel
self.validate_args = validate_args
self._stats = ExecutionStats()
def parse_tool_call(self, tool_call: Dict[str, Any]) -> ToolCallRequest:
"""
解析 OpenAI 格式的工具调用
Args:
tool_call: OpenAI 返回的 tool_call 对象
Returns:
ToolCallRequest 对象
"""
call_id = tool_call.get("id", "")
function = tool_call.get("function", {})
name = function.get("name", "")
raw_args = function.get("arguments", "{}")
# 解析 arguments JSON
try:
arguments = json.loads(raw_args) if raw_args else {}
except json.JSONDecodeError as e:
logger.warning(f"[ToolExecutor] 解析参数失败: {e}, raw={raw_args[:100]}")
arguments = {}
return ToolCallRequest(
id=call_id,
name=name,
arguments=arguments,
raw_arguments=raw_args,
)
async def execute_single(
self,
tool_call: Dict[str, Any],
bot,
from_wxid: str,
timeout_override: Optional[float] = None,
) -> ToolCallResult:
"""
执行单个工具调用
Args:
tool_call: OpenAI 格式的 tool_call
bot: WechatHookClient 实例
from_wxid: 消息来源 wxid
timeout_override: 覆盖默认超时
Returns:
ToolCallResult 对象
"""
from utils.tool_registry import get_tool_registry
from utils.llm_tooling import validate_tool_arguments, ToolResult
start_time = time.time()
request = self.parse_tool_call(tool_call)
registry = get_tool_registry()
result = ToolCallResult(
id=request.id,
name=request.name,
)
# 获取工具定义
tool_def = registry.get(request.name)
if not tool_def:
result.success = False
result.error = f"工具 {request.name} 不存在"
result.message = result.error
self._update_stats(False, time.time() - start_time)
return result
# 参数校验
if self.validate_args:
schema = tool_def.schema.get("function", {}).get("parameters", {})
ok, error_msg, validated_args = validate_tool_arguments(
request.name, request.arguments, schema
)
if not ok:
result.success = False
result.error = error_msg
result.message = error_msg
self._update_stats(False, time.time() - start_time)
return result
request.arguments = validated_args
# 执行工具
timeout = timeout_override or tool_def.timeout or self.default_timeout
try:
logger.debug(f"[ToolExecutor] 执行工具: {request.name}")
raw_result = await asyncio.wait_for(
tool_def.executor(request.name, request.arguments, bot, from_wxid),
timeout=timeout
)
# 解析结果
tool_result = ToolResult.from_raw(raw_result)
if tool_result:
result.success = tool_result.success
result.message = tool_result.message
result.need_ai_reply = tool_result.need_ai_reply
result.already_sent = tool_result.already_sent
result.send_result_text = tool_result.send_result_text
result.no_reply = tool_result.no_reply
result.save_to_memory = tool_result.save_to_memory
else:
result.message = str(raw_result) if raw_result else "执行完成"
result.raw_result = raw_result if isinstance(raw_result, dict) else {"result": raw_result}
execution_time = time.time() - start_time
result.execution_time_ms = execution_time * 1000
self._update_stats(result.success, execution_time)
logger.debug(
f"[ToolExecutor] 工具 {request.name} 执行完成 "
f"({result.execution_time_ms:.1f}ms)"
)
except asyncio.TimeoutError:
result.success = False
result.error = f"执行超时 ({timeout}s)"
result.message = result.error
self._update_stats(False, time.time() - start_time, timeout=True)
logger.warning(f"[ToolExecutor] 工具 {request.name} 执行超时")
except asyncio.CancelledError:
raise
except Exception as e:
result.success = False
result.error = str(e)
result.message = f"执行失败: {e}"
self._update_stats(False, time.time() - start_time)
logger.error(f"[ToolExecutor] 工具 {request.name} 执行异常: {e}")
return result
async def execute_batch(
self,
tool_calls: List[Dict[str, Any]],
bot,
from_wxid: str,
parallel: bool = True,
stop_on_error: bool = False,
) -> List[ToolCallResult]:
"""
批量执行工具调用
Args:
tool_calls: 工具调用列表
bot: WechatHookClient 实例
from_wxid: 消息来源 wxid
parallel: 是否并行执行
stop_on_error: 遇到错误是否停止
Returns:
ToolCallResult 列表
"""
if not tool_calls:
return []
if parallel and len(tool_calls) > 1:
return await self._execute_parallel(tool_calls, bot, from_wxid, stop_on_error)
else:
return await self._execute_sequential(tool_calls, bot, from_wxid, stop_on_error)
async def _execute_sequential(
self,
tool_calls: List[Dict[str, Any]],
bot,
from_wxid: str,
stop_on_error: bool,
) -> List[ToolCallResult]:
"""顺序执行"""
results = []
for tool_call in tool_calls:
result = await self.execute_single(tool_call, bot, from_wxid)
results.append(result)
if stop_on_error and not result.success:
logger.warning(f"[ToolExecutor] 工具 {result.name} 失败,停止批量执行")
break
return results
async def _execute_parallel(
self,
tool_calls: List[Dict[str, Any]],
bot,
from_wxid: str,
stop_on_error: bool,
) -> List[ToolCallResult]:
"""并行执行(带并发限制)"""
semaphore = asyncio.Semaphore(self.max_parallel)
async def execute_with_limit(tool_call):
async with semaphore:
return await self.execute_single(tool_call, bot, from_wxid)
tasks = [execute_with_limit(tc) for tc in tool_calls]
if stop_on_error:
# 使用 gather 但不 return_exceptions让第一个错误停止执行
results = []
for coro in asyncio.as_completed(tasks):
try:
result = await coro
results.append(result)
if not result.success:
# 取消剩余任务
for task in tasks:
if isinstance(task, asyncio.Task) and not task.done():
task.cancel()
break
except Exception as e:
logger.error(f"[ToolExecutor] 并行执行异常: {e}")
break
return results
else:
# 全部执行,收集所有结果
return await asyncio.gather(*tasks, return_exceptions=False)
def _update_stats(self, success: bool, execution_time: float, timeout: bool = False):
"""更新执行统计"""
self._stats.total_calls += 1
if success:
self._stats.successful_calls += 1
else:
self._stats.failed_calls += 1
if timeout:
self._stats.timeout_calls += 1
self._stats.total_time_ms += execution_time * 1000
self._stats.avg_time_ms = self._stats.total_time_ms / self._stats.total_calls
def get_stats(self) -> Dict[str, Any]:
"""获取执行统计"""
return {
"total_calls": self._stats.total_calls,
"successful_calls": self._stats.successful_calls,
"failed_calls": self._stats.failed_calls,
"timeout_calls": self._stats.timeout_calls,
"total_time_ms": self._stats.total_time_ms,
"avg_time_ms": self._stats.avg_time_ms,
"success_rate": (
self._stats.successful_calls / self._stats.total_calls
if self._stats.total_calls > 0 else 0
),
}
def reset_stats(self):
"""重置统计"""
self._stats = ExecutionStats()
class ToolCallChain:
"""
工具调用链
用于处理需要多轮工具调用的场景,记录调用历史。
"""
def __init__(self, max_rounds: int = 10):
self.max_rounds = max_rounds
self.history: List[ToolCallResult] = []
self.current_round = 0
def add_result(self, result: ToolCallResult):
"""添加调用结果"""
self.history.append(result)
def add_results(self, results: List[ToolCallResult]):
"""添加多个调用结果"""
self.history.extend(results)
def increment_round(self):
"""增加轮次"""
self.current_round += 1
def can_continue(self) -> bool:
"""检查是否可以继续调用"""
return self.current_round < self.max_rounds
def get_tool_messages(self) -> List[Dict[str, Any]]:
"""获取所有工具调用的消息(用于发送给 LLM"""
return [result.to_message() for result in self.history]
def get_last_results(self, n: int = 1) -> List[ToolCallResult]:
"""获取最后 n 个结果"""
return self.history[-n:] if self.history else []
def has_special_flags(self) -> Dict[str, bool]:
"""检查是否有特殊标志"""
flags = {
"need_ai_reply": False,
"already_sent": False,
"no_reply": False,
"save_to_memory": False,
"send_result_text": False,
}
for result in self.history:
if result.need_ai_reply:
flags["need_ai_reply"] = True
if result.already_sent:
flags["already_sent"] = True
if result.no_reply:
flags["no_reply"] = True
if result.save_to_memory:
flags["save_to_memory"] = True
if result.send_result_text:
flags["send_result_text"] = True
return flags
def get_summary(self) -> str:
"""获取调用链摘要"""
if not self.history:
return "无工具调用"
successful = sum(1 for r in self.history if r.success)
failed = len(self.history) - successful
total_time = sum(r.execution_time_ms for r in self.history)
tools_called = [r.name for r in self.history]
return (
f"调用链: {len(self.history)} 个工具, "
f"成功 {successful}, 失败 {failed}, "
f"总耗时 {total_time:.1f}ms, "
f"工具: {', '.join(tools_called)}"
)
# ==================== 便捷函数 ====================
_default_executor: Optional[ToolExecutor] = None
def get_tool_executor(
default_timeout: float = 60.0,
max_parallel: int = 5,
) -> ToolExecutor:
"""获取默认工具执行器"""
global _default_executor
if _default_executor is None:
_default_executor = ToolExecutor(
default_timeout=default_timeout,
max_parallel=max_parallel,
)
return _default_executor
async def execute_tool_calls(
tool_calls: List[Dict[str, Any]],
bot,
from_wxid: str,
parallel: bool = True,
) -> List[ToolCallResult]:
"""便捷函数:执行工具调用列表"""
executor = get_tool_executor()
return await executor.execute_batch(tool_calls, bot, from_wxid, parallel=parallel)
# ==================== 导出 ====================
__all__ = [
'ToolCallRequest',
'ToolCallResult',
'ExecutionStats',
'ToolExecutor',
'ToolCallChain',
'get_tool_executor',
'execute_tool_calls',
]