# WechatHookBot/utils/tool_executor.py
"""
工具执行器模块
提供工具调用的高级执行逻辑:
- 批量工具执行(支持并行)
- 工具调用链处理
- 执行日志和审计
- 结果聚合
使用示例:
from utils.tool_executor import ToolExecutor, ToolCallRequest
executor = ToolExecutor()
# 单个工具执行
result = await executor.execute_single(
tool_call={"id": "call_1", "function": {"name": "get_weather", "arguments": "{}"}},
bot=bot,
from_wxid=wxid,
)
# 批量工具执行
results = await executor.execute_batch(
tool_calls=[...],
bot=bot,
from_wxid=wxid,
parallel=True,
)
"""
from __future__ import annotations
import asyncio
import json
import time
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional
from loguru import logger


@dataclass
class ToolCallRequest:
    """A single tool call request."""

    id: str
    name: str
    arguments: Dict[str, Any]
    raw_arguments: str = ""  # original JSON string as returned by the LLM


@dataclass
class ToolCallResult:
    """Result of a single tool call."""

    id: str
    name: str
    success: bool = True
    message: str = ""
    raw_result: Dict[str, Any] = field(default_factory=dict)

    # Control flags
    need_ai_reply: bool = False
    already_sent: bool = False
    send_result_text: bool = False
    no_reply: bool = False
    save_to_memory: bool = False

    # Execution info
    execution_time_ms: float = 0.0
    error: Optional[str] = None

    def to_message(self) -> Dict[str, Any]:
        """Convert to an OpenAI-compatible tool message."""
        content = self.message if self.success else f"Error: {self.error or self.message}"
        return {
            "role": "tool",
            "tool_call_id": self.id,
            "content": content,
        }
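
# Example (illustrative values): ToolCallResult.to_message() produces the
# OpenAI-style "tool" message appended to the chat history after a tool call,
# e.g. for a successful weather lookup:
#
#   {"role": "tool", "tool_call_id": "call_1", "content": "Sunny, 25°C"}
#
# On failure the content is prefixed with "Error:" so the LLM can react to it.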


@dataclass
class ExecutionStats:
    """Aggregate execution statistics."""

    total_calls: int = 0
    successful_calls: int = 0
    failed_calls: int = 0
    timeout_calls: int = 0
    total_time_ms: float = 0.0
    avg_time_ms: float = 0.0


class ToolExecutor:
    """
    Tool executor.

    Provides a unified interface for executing tools:
    - Argument parsing and validation
    - Timeout protection
    - Error handling
    - Execution statistics
    """

    def __init__(
        self,
        default_timeout: float = 60.0,
        max_parallel: int = 5,
        validate_args: bool = True,
    ):
        self.default_timeout = default_timeout
        self.max_parallel = max_parallel
        self.validate_args = validate_args
        self._stats = ExecutionStats()

    def parse_tool_call(self, tool_call: Dict[str, Any]) -> ToolCallRequest:
        """
        Parse an OpenAI-format tool call.

        Args:
            tool_call: The tool_call object returned by OpenAI.

        Returns:
            A ToolCallRequest object.
        """
        call_id = tool_call.get("id", "")
        function = tool_call.get("function", {})
        name = function.get("name", "")
        raw_args = function.get("arguments", "{}")

        # Parse the arguments JSON
        try:
            arguments = json.loads(raw_args) if raw_args else {}
        except json.JSONDecodeError as e:
            logger.warning(f"[ToolExecutor] Failed to parse arguments: {e}, raw={raw_args[:100]}")
            arguments = {}

        return ToolCallRequest(
            id=call_id,
            name=name,
            arguments=arguments,
            raw_arguments=raw_args,
        )
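
    # Example input (illustrative): parse_tool_call() expects the standard
    # OpenAI tool_call shape, with "arguments" as a JSON-encoded string:
    #
    #   {
    #       "id": "call_abc123",
    #       "type": "function",
    #       "function": {"name": "get_weather", "arguments": '{"city": "Beijing"}'},
    #   }
    #
    # Malformed JSON in "arguments" is tolerated and falls back to an empty dict.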

    async def execute_single(
        self,
        tool_call: Dict[str, Any],
        bot,
        from_wxid: str,
        timeout_override: Optional[float] = None,
    ) -> ToolCallResult:
        """
        Execute a single tool call.

        Args:
            tool_call: An OpenAI-format tool_call.
            bot: The WechatHookClient instance.
            from_wxid: The wxid the message came from.
            timeout_override: Overrides the default timeout.

        Returns:
            A ToolCallResult object.
        """
        from utils.tool_registry import get_tool_registry
        from utils.llm_tooling import validate_tool_arguments, ToolResult

        start_time = time.time()
        request = self.parse_tool_call(tool_call)
        registry = get_tool_registry()

        result = ToolCallResult(
            id=request.id,
            name=request.name,
        )

        # Look up the tool definition
        tool_def = registry.get(request.name)
        if not tool_def:
            result.success = False
            result.error = f"Tool {request.name} does not exist"
            result.message = result.error
            self._update_stats(False, time.time() - start_time)
            return result

        # Validate arguments
        if self.validate_args:
            schema = tool_def.schema.get("function", {}).get("parameters", {})
            ok, error_msg, validated_args = validate_tool_arguments(
                request.name, request.arguments, schema
            )
            if not ok:
                result.success = False
                result.error = error_msg
                result.message = error_msg
                self._update_stats(False, time.time() - start_time)
                return result
            request.arguments = validated_args

        # Execute the tool
        timeout = timeout_override or tool_def.timeout or self.default_timeout
        try:
            logger.debug(f"[ToolExecutor] Executing tool: {request.name}")
            raw_result = await asyncio.wait_for(
                tool_def.executor(request.name, request.arguments, bot, from_wxid),
                timeout=timeout,
            )

            # Parse the result
            tool_result = ToolResult.from_raw(raw_result)
            if tool_result:
                result.success = tool_result.success
                result.message = tool_result.message
                result.need_ai_reply = tool_result.need_ai_reply
                result.already_sent = tool_result.already_sent
                result.send_result_text = tool_result.send_result_text
                result.no_reply = tool_result.no_reply
                result.save_to_memory = tool_result.save_to_memory
            else:
                result.message = str(raw_result) if raw_result else "Execution finished"
            result.raw_result = raw_result if isinstance(raw_result, dict) else {"result": raw_result}

            execution_time = time.time() - start_time
            result.execution_time_ms = execution_time * 1000
            self._update_stats(result.success, execution_time)

            logger.debug(
                f"[ToolExecutor] Tool {request.name} finished "
                f"({result.execution_time_ms:.1f}ms)"
            )
        except asyncio.TimeoutError:
            result.success = False
            result.error = f"Execution timed out ({timeout}s)"
            result.message = result.error
            self._update_stats(False, time.time() - start_time, timeout=True)
            logger.warning(f"[ToolExecutor] Tool {request.name} timed out")
        except asyncio.CancelledError:
            raise
        except Exception as e:
            result.success = False
            result.error = str(e)
            result.message = f"Execution failed: {e}"
            self._update_stats(False, time.time() - start_time)
            logger.error(f"[ToolExecutor] Tool {request.name} raised an exception: {e}")

        return result
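
    # Sketch of the executor contract implied by the call above (assumptions:
    # the registry stores any async callable with this signature, and a plain
    # dict return value is accepted by ToolResult.from_raw; the tool name and
    # "city" parameter are hypothetical):
    #
    #   async def get_weather_executor(name: str, arguments: dict, bot, from_wxid: str):
    #       city = arguments.get("city", "")
    #       ...
    #       return {"success": True, "message": f"Weather for {city}: sunny"}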

    async def execute_batch(
        self,
        tool_calls: List[Dict[str, Any]],
        bot,
        from_wxid: str,
        parallel: bool = True,
        stop_on_error: bool = False,
    ) -> List[ToolCallResult]:
        """
        Execute a batch of tool calls.

        Args:
            tool_calls: List of tool calls.
            bot: The WechatHookClient instance.
            from_wxid: The wxid the message came from.
            parallel: Whether to execute in parallel.
            stop_on_error: Whether to stop on the first error.

        Returns:
            A list of ToolCallResult objects.
        """
        if not tool_calls:
            return []

        if parallel and len(tool_calls) > 1:
            return await self._execute_parallel(tool_calls, bot, from_wxid, stop_on_error)
        else:
            return await self._execute_sequential(tool_calls, bot, from_wxid, stop_on_error)

    async def _execute_sequential(
        self,
        tool_calls: List[Dict[str, Any]],
        bot,
        from_wxid: str,
        stop_on_error: bool,
    ) -> List[ToolCallResult]:
        """Execute sequentially."""
        results = []
        for tool_call in tool_calls:
            result = await self.execute_single(tool_call, bot, from_wxid)
            results.append(result)
            if stop_on_error and not result.success:
                logger.warning(f"[ToolExecutor] Tool {result.name} failed; stopping batch execution")
                break
        return results

    async def _execute_parallel(
        self,
        tool_calls: List[Dict[str, Any]],
        bot,
        from_wxid: str,
        stop_on_error: bool,
    ) -> List[ToolCallResult]:
        """Execute in parallel (with a concurrency limit)."""
        semaphore = asyncio.Semaphore(self.max_parallel)

        async def execute_with_limit(tool_call):
            async with semaphore:
                return await self.execute_single(tool_call, bot, from_wxid)

        # Wrap the coroutines in tasks so they can be cancelled on error
        tasks = [asyncio.create_task(execute_with_limit(tc)) for tc in tool_calls]

        if stop_on_error:
            # Consume results as they complete so the first failure stops execution
            results = []
            for coro in asyncio.as_completed(tasks):
                try:
                    result = await coro
                    results.append(result)
                    if not result.success:
                        # Cancel the remaining tasks
                        for task in tasks:
                            if not task.done():
                                task.cancel()
                        break
                except Exception as e:
                    logger.error(f"[ToolExecutor] Parallel execution error: {e}")
                    break
            return results
        else:
            # Run everything and collect all results
            return await asyncio.gather(*tasks, return_exceptions=False)
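
    # Design note: the semaphore caps in-flight tool calls at max_parallel, so a
    # large batch never floods the bot client. With stop_on_error, the remaining
    # tasks are cancelled on the first failure; each cancelled execute_single
    # re-raises CancelledError instead of recording a bogus failure result.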

    def _update_stats(self, success: bool, execution_time: float, timeout: bool = False):
        """Update execution statistics."""
        self._stats.total_calls += 1
        if success:
            self._stats.successful_calls += 1
        else:
            self._stats.failed_calls += 1
            if timeout:
                self._stats.timeout_calls += 1
        self._stats.total_time_ms += execution_time * 1000
        self._stats.avg_time_ms = self._stats.total_time_ms / self._stats.total_calls

    def get_stats(self) -> Dict[str, Any]:
        """Return execution statistics."""
        return {
            "total_calls": self._stats.total_calls,
            "successful_calls": self._stats.successful_calls,
            "failed_calls": self._stats.failed_calls,
            "timeout_calls": self._stats.timeout_calls,
            "total_time_ms": self._stats.total_time_ms,
            "avg_time_ms": self._stats.avg_time_ms,
            "success_rate": (
                self._stats.successful_calls / self._stats.total_calls
                if self._stats.total_calls > 0 else 0
            ),
        }
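
    # Example return value of get_stats() (numbers are illustrative):
    #
    #   {"total_calls": 4, "successful_calls": 3, "failed_calls": 1,
    #    "timeout_calls": 0, "total_time_ms": 1250.0, "avg_time_ms": 312.5,
    #    "success_rate": 0.75}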

    def reset_stats(self):
        """Reset the statistics."""
        self._stats = ExecutionStats()


class ToolCallChain:
    """
    Tool call chain.

    Used for scenarios that require multiple rounds of tool calls; keeps a
    history of the results.
    """

    def __init__(self, max_rounds: int = 10):
        self.max_rounds = max_rounds
        self.history: List[ToolCallResult] = []
        self.current_round = 0

    def add_result(self, result: ToolCallResult):
        """Add a single call result."""
        self.history.append(result)

    def add_results(self, results: List[ToolCallResult]):
        """Add multiple call results."""
        self.history.extend(results)

    def increment_round(self):
        """Advance to the next round."""
        self.current_round += 1

    def can_continue(self) -> bool:
        """Check whether another round of calls is allowed."""
        return self.current_round < self.max_rounds

    def get_tool_messages(self) -> List[Dict[str, Any]]:
        """Return all tool-call messages (for sending back to the LLM)."""
        return [result.to_message() for result in self.history]

    def get_last_results(self, n: int = 1) -> List[ToolCallResult]:
        """Return the last n results."""
        return self.history[-n:] if self.history else []

    def has_special_flags(self) -> Dict[str, bool]:
        """Check whether any result in the chain set a special flag."""
        flags = {
            "need_ai_reply": False,
            "already_sent": False,
            "no_reply": False,
            "save_to_memory": False,
            "send_result_text": False,
        }
        for result in self.history:
            if result.need_ai_reply:
                flags["need_ai_reply"] = True
            if result.already_sent:
                flags["already_sent"] = True
            if result.no_reply:
                flags["no_reply"] = True
            if result.save_to_memory:
                flags["save_to_memory"] = True
            if result.send_result_text:
                flags["send_result_text"] = True
        return flags
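
    # Illustrative use of the merged flags (the handler logic here is assumed,
    # inferred only from the flag names):
    #
    #   flags = chain.has_special_flags()
    #   if flags["no_reply"] or flags["already_sent"]:
    #       return                      # a tool already answered the user
    #   if flags["need_ai_reply"]:
    #       ...                         # ask the LLM to summarize tool output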

    def get_summary(self) -> str:
        """Return a summary of the call chain."""
        if not self.history:
            return "No tool calls"

        successful = sum(1 for r in self.history if r.success)
        failed = len(self.history) - successful
        total_time = sum(r.execution_time_ms for r in self.history)
        tools_called = [r.name for r in self.history]

        return (
            f"Call chain: {len(self.history)} tool(s), "
            f"{successful} succeeded, {failed} failed, "
            f"total time {total_time:.1f}ms, "
            f"tools: {', '.join(tools_called)}"
        )
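
# Minimal sketch of a multi-round tool-call loop built on ToolCallChain
# (assumptions: `call_llm` is a hypothetical async helper returning an
# OpenAI-style assistant message, and `messages`/`bot`/`from_wxid` come from
# the surrounding chat handler):
#
#   chain = ToolCallChain(max_rounds=5)
#   while chain.can_continue():
#       assistant_msg = await call_llm(messages)              # hypothetical LLM call
#       tool_calls = assistant_msg.get("tool_calls") or []
#       if not tool_calls:
#           break                                             # final answer reached
#       messages.append(assistant_msg)
#       results = await execute_tool_calls(tool_calls, bot, from_wxid)
#       chain.add_results(results)
#       messages.extend(r.to_message() for r in results)      # feed results back
#       chain.increment_round()
#   logger.info(chain.get_summary())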


# ==================== Convenience functions ====================

_default_executor: Optional[ToolExecutor] = None


def get_tool_executor(
    default_timeout: float = 60.0,
    max_parallel: int = 5,
) -> ToolExecutor:
    """Return the default (singleton) tool executor."""
    global _default_executor
    if _default_executor is None:
        _default_executor = ToolExecutor(
            default_timeout=default_timeout,
            max_parallel=max_parallel,
        )
    return _default_executor


async def execute_tool_calls(
    tool_calls: List[Dict[str, Any]],
    bot,
    from_wxid: str,
    parallel: bool = True,
) -> List[ToolCallResult]:
    """Convenience function: execute a list of tool calls."""
    executor = get_tool_executor()
    return await executor.execute_batch(tool_calls, bot, from_wxid, parallel=parallel)


# ==================== Exports ====================

__all__ = [
    'ToolCallRequest',
    'ToolCallResult',
    'ExecutionStats',
    'ToolExecutor',
    'ToolCallChain',
    'get_tool_executor',
    'execute_tool_calls',
]