"""
Tool executor module.

Higher-level execution logic for tool calls:
- batch tool execution (optionally in parallel)
- tool call chain handling
- execution logging and auditing
- result aggregation

Usage example:
    from utils.tool_executor import ToolExecutor, ToolCallRequest

    executor = ToolExecutor()

    # Execute a single tool call
    result = await executor.execute_single(
        tool_call={"id": "call_1", "function": {"name": "get_weather", "arguments": "{}"}},
        bot=bot,
        from_wxid=wxid,
    )

    # Execute a batch of tool calls
    results = await executor.execute_batch(
        tool_calls=[...],
        bot=bot,
        from_wxid=wxid,
        parallel=True,
    )
"""

from __future__ import annotations

import asyncio
import json
import time
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional

from loguru import logger


@dataclass
class ToolCallRequest:
    """A parsed tool call request."""

    id: str
    name: str
    arguments: Dict[str, Any]
    raw_arguments: str = ""  # original argument text (JSON string), kept for logging/auditing


@dataclass
class ToolCallResult:
    """The outcome of a single tool call."""

    id: str
    name: str
    success: bool = True
    message: str = ""
    raw_result: Dict[str, Any] = field(default_factory=dict)

    # Control flags, copied from the tool's ToolResult when it provides one
    need_ai_reply: bool = False
    already_sent: bool = False
    send_result_text: bool = False
    no_reply: bool = False
    save_to_memory: bool = False

    # Execution info
    execution_time_ms: float = 0.0
    error: Optional[str] = None

    def to_message(self) -> Dict[str, Any]:
        """
        Convert to an OpenAI-compatible ``tool`` role message.

        Failed calls carry ``错误: <error, or message if error is empty>`` as content.
        """
        content = self.message if self.success else f"错误: {self.error or self.message}"
        return {
            "role": "tool",
            "tool_call_id": self.id,
            "content": content,
        }


@dataclass
class ExecutionStats:
    """Aggregate execution statistics kept by a ToolExecutor."""

    total_calls: int = 0
    successful_calls: int = 0
    failed_calls: int = 0
    timeout_calls: int = 0
    total_time_ms: float = 0.0
    avg_time_ms: float = 0.0


class ToolExecutor:
    """
    Tool executor.

    Provides a unified interface for running tools:
    - argument parsing and validation
    - timeout protection
    - error handling
    - execution statistics
    """

    def __init__(
        self,
        default_timeout: float = 60.0,
        max_parallel: int = 5,
        validate_args: bool = True,
    ):
        """
        Args:
            default_timeout: Timeout in seconds used when neither a per-call override
                nor the tool definition supplies one.
            max_parallel: Maximum number of tool calls running at once in a parallel batch.
            validate_args: Validate arguments against the tool's JSON schema before running it.
        """
        self.default_timeout = default_timeout
        self.max_parallel = max_parallel
        self.validate_args = validate_args
        self._stats = ExecutionStats()

    def parse_tool_call(self, tool_call: Dict[str, Any]) -> ToolCallRequest:
        """
        Parse an OpenAI-format tool call.

        Args:
            tool_call: The ``tool_call`` object returned by the model, e.g.
                ``{"id": ..., "function": {"name": ..., "arguments": "<json>"}}``.

        Returns:
            A ToolCallRequest whose ``arguments`` is always a dict. Missing or empty
            arguments, malformed JSON, and JSON that is not an object all yield ``{}``
            (the latter two with a warning). Arguments that arrive already decoded as
            a dict are used as-is (shallow-copied).
        """
        call_id = tool_call.get("id", "")
        # ``or {}`` also covers an explicit ``"function": None``.
        function = tool_call.get("function") or {}
        name = function.get("name", "")
        raw_args = function.get("arguments", "{}")

        arguments = self._decode_arguments(raw_args)

        # raw_arguments is declared as text; normalise whatever we received.
        if isinstance(raw_args, str):
            raw_text = raw_args
        elif isinstance(raw_args, dict):
            # Audit-only copy, so stringify anything that is not JSON-serialisable.
            raw_text = json.dumps(raw_args, ensure_ascii=False, default=str)
        else:
            raw_text = "" if raw_args is None else str(raw_args)

        return ToolCallRequest(
            id=call_id,
            name=name,
            arguments=arguments,
            raw_arguments=raw_text,
        )

    @staticmethod
    def _decode_arguments(raw_args: Any) -> Dict[str, Any]:
        """Decode a tool call's argument payload into a dict, falling back to ``{}``."""
        if isinstance(raw_args, dict):
            # Already decoded upstream; json.loads() would raise TypeError on a dict.
            return dict(raw_args)
        if not raw_args:
            return {}
        try:
            decoded = json.loads(raw_args)
        except (ValueError, TypeError) as e:  # JSONDecodeError is a ValueError subclass
            logger.warning(f"[ToolExecutor] 解析参数失败: {e}, raw={str(raw_args)[:100]}")
            return {}
        if not isinstance(decoded, dict):
            # Valid JSON such as "null", "[]" or "42" cannot serve as keyword arguments.
            logger.warning(
                f"[ToolExecutor] 参数不是 JSON 对象: "
                f"type={type(decoded).__name__}, raw={str(raw_args)[:100]}"
            )
            return {}
        return decoded

    @staticmethod
    def _mark_failed(result: ToolCallResult, error: str, message: Optional[str] = None) -> None:
        """Flag *result* as failed with *error*; ``message`` defaults to the error text."""
        result.success = False
        result.error = error
        result.message = error if message is None else message

    def _finish(
        self, result: ToolCallResult, start_time: float, timeout: bool = False
    ) -> ToolCallResult:
        """Stamp the elapsed time onto *result*, record it in the stats, and return it."""
        elapsed = time.perf_counter() - start_time
        result.execution_time_ms = elapsed * 1000
        self._update_stats(result.success, elapsed, timeout=timeout)
        return result

    async def execute_single(
        self,
        tool_call: Dict[str, Any],
        bot,
        from_wxid: str,
        timeout_override: Optional[float] = None,
    ) -> ToolCallResult:
        """
        Execute a single tool call.

        Looks the tool up in the tool registry, optionally validates the arguments
        against the tool's schema, then awaits the tool's executor under a timeout.
        An unknown tool, invalid arguments, a timeout, or an exception raised by the
        tool is reported through the returned result rather than raised; only
        cancellation propagates.

        Args:
            tool_call: OpenAI-format tool_call dict.
            bot: WechatHookClient instance, passed through to the tool executor.
            from_wxid: wxid the message came from, passed through to the tool executor.
            timeout_override: Timeout in seconds that takes precedence over the tool's
                own timeout and ``default_timeout`` (a falsy value is ignored).

        Returns:
            A ToolCallResult; ``execution_time_ms`` is filled in on every path.
        """
        # Imported here rather than at module level — presumably to avoid an import
        # cycle with the registry/tooling modules; verify before hoisting.
        from utils.tool_registry import get_tool_registry
        from utils.llm_tooling import validate_tool_arguments, ToolResult

        # perf_counter is monotonic, so durations are immune to wall-clock jumps.
        start_time = time.perf_counter()
        request = self.parse_tool_call(tool_call)
        registry = get_tool_registry()
        result = ToolCallResult(id=request.id, name=request.name)

        # Resolve the tool definition
        tool_def = registry.get(request.name)
        if not tool_def:
            self._mark_failed(result, f"工具 {request.name} 不存在")
            return self._finish(result, start_time)

        # Validate arguments against the tool's JSON schema
        if self.validate_args:
            schema = tool_def.schema.get("function", {}).get("parameters", {})
            ok, error_msg, validated_args = validate_tool_arguments(
                request.name, request.arguments, schema
            )
            if not ok:
                self._mark_failed(result, error_msg)
                return self._finish(result, start_time)
            request.arguments = validated_args

        # Run the tool
        timeout = timeout_override or tool_def.timeout or self.default_timeout
        try:
            logger.debug(f"[ToolExecutor] 执行工具: {request.name}")
            raw_result = await asyncio.wait_for(
                tool_def.executor(request.name, request.arguments, bot, from_wxid),
                timeout=timeout,
            )

            # Interpret the tool's return value
            tool_result = ToolResult.from_raw(raw_result)
            if tool_result:
                result.success = tool_result.success
                result.message = tool_result.message
                result.need_ai_reply = tool_result.need_ai_reply
                result.already_sent = tool_result.already_sent
                result.send_result_text = tool_result.send_result_text
                result.no_reply = tool_result.no_reply
                result.save_to_memory = tool_result.save_to_memory
            else:
                result.message = str(raw_result) if raw_result else "执行完成"

            result.raw_result = raw_result if isinstance(raw_result, dict) else {"result": raw_result}
        except asyncio.TimeoutError:
            self._mark_failed(result, f"执行超时 ({timeout}s)")
            logger.warning(f"[ToolExecutor] 工具 {request.name} 执行超时")
            return self._finish(result, start_time, timeout=True)
        except asyncio.CancelledError:
            raise
        except Exception as e:
            self._mark_failed(result, str(e), message=f"执行失败: {e}")
            # logger.exception keeps the traceback, which str(e) alone loses.
            logger.exception(f"[ToolExecutor] 工具 {request.name} 执行异常: {e}")
            return self._finish(result, start_time)

        self._finish(result, start_time)
        logger.debug(
            f"[ToolExecutor] 工具 {request.name} 执行完成 "
            f"({result.execution_time_ms:.1f}ms)"
        )
        return result

    async def execute_batch(
        self,
        tool_calls: List[Dict[str, Any]],
        bot,
        from_wxid: str,
        parallel: bool = True,
        stop_on_error: bool = False,
    ) -> List[ToolCallResult]:
        """
        Execute a batch of tool calls.

        Args:
            tool_calls: List of OpenAI-format tool_call dicts.
            bot: WechatHookClient instance.
            from_wxid: wxid the message came from.
            parallel: Run concurrently (only when there is more than one call).
            stop_on_error: Stop at the first failed call.

        Returns:
            ToolCallResults in request order. With ``stop_on_error`` the list holds
            only the calls that finished before the batch was stopped.
        """
        if not tool_calls:
            return []

        if parallel and len(tool_calls) > 1:
            return await self._execute_parallel(tool_calls, bot, from_wxid, stop_on_error)
        return await self._execute_sequential(tool_calls, bot, from_wxid, stop_on_error)

    async def _execute_sequential(
        self,
        tool_calls: List[Dict[str, Any]],
        bot,
        from_wxid: str,
        stop_on_error: bool,
    ) -> List[ToolCallResult]:
        """Run the calls one after another, optionally stopping at the first failure."""
        results = []
        for tool_call in tool_calls:
            result = await self.execute_single(tool_call, bot, from_wxid)
            results.append(result)

            if stop_on_error and not result.success:
                logger.warning(f"[ToolExecutor] 工具 {result.name} 失败,停止批量执行")
                break

        return results

    async def _execute_parallel(
        self,
        tool_calls: List[Dict[str, Any]],
        bot,
        from_wxid: str,
        stop_on_error: bool,
    ) -> List[ToolCallResult]:
        """
        Run the calls concurrently, at most ``max_parallel`` at a time.

        With ``stop_on_error`` the first failed call cancels every call that is
        still queued or running; results are returned in request order.
        """
        semaphore = asyncio.Semaphore(self.max_parallel)

        async def execute_with_limit(tool_call):
            async with semaphore:
                return await self.execute_single(tool_call, bot, from_wxid)

        if not stop_on_error:
            # Run everything; gather returns results in request order.
            return await asyncio.gather(
                *(execute_with_limit(tc) for tc in tool_calls), return_exceptions=False
            )

        async def run_indexed(index, tool_call):
            return index, await execute_with_limit(tool_call)

        # Real Task objects (not bare coroutines) so the leftovers can be cancelled.
        tasks = [
            asyncio.create_task(run_indexed(index, tc))
            for index, tc in enumerate(tool_calls)
        ]
        finished = []
        try:
            for next_done in asyncio.as_completed(tasks):
                try:
                    index, result = await next_done
                except Exception as e:
                    logger.error(f"[ToolExecutor] 并行执行异常: {e}")
                    break
                finished.append((index, result))
                if not result.success:
                    logger.warning(f"[ToolExecutor] 工具 {result.name} 失败,停止批量执行")
                    break
        finally:
            # Cancel whatever is still queued/running and wait for it to unwind,
            # so no task is left orphaned in the background.
            for task in tasks:
                if not task.done():
                    task.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)

        # as_completed yields in completion order; restore request order to match
        # the sequential and gather paths.
        finished.sort(key=lambda item: item[0])
        return [result for _, result in finished]

    def _update_stats(self, success: bool, execution_time: float, timeout: bool = False):
        """Fold one call into the stats; ``execution_time`` is in seconds."""
        self._stats.total_calls += 1
        if success:
            self._stats.successful_calls += 1
        else:
            self._stats.failed_calls += 1
        if timeout:
            self._stats.timeout_calls += 1

        self._stats.total_time_ms += execution_time * 1000
        self._stats.avg_time_ms = self._stats.total_time_ms / self._stats.total_calls

    def get_stats(self) -> Dict[str, Any]:
        """Return a snapshot of the stats plus ``success_rate`` (0 when no calls yet)."""
        return {
            "total_calls": self._stats.total_calls,
            "successful_calls": self._stats.successful_calls,
            "failed_calls": self._stats.failed_calls,
            "timeout_calls": self._stats.timeout_calls,
            "total_time_ms": self._stats.total_time_ms,
            "avg_time_ms": self._stats.avg_time_ms,
            "success_rate": (
                self._stats.successful_calls / self._stats.total_calls
                if self._stats.total_calls > 0 else 0
            ),
        }

    def reset_stats(self):
        """Reset all statistics to zero."""
        self._stats = ExecutionStats()


class ToolCallChain:
    """
    Tool call chain.

    Used when tools are called over multiple rounds; records the call history
    and caps the number of rounds.
    """

    def __init__(self, max_rounds: int = 10):
        self.max_rounds = max_rounds
        self.history: List[ToolCallResult] = []
        self.current_round = 0

    def add_result(self, result: ToolCallResult):
        """Append one call result to the history."""
        self.history.append(result)

    def add_results(self, results: List[ToolCallResult]):
        """Append several call results to the history."""
        self.history.extend(results)

    def increment_round(self):
        """Advance the round counter by one."""
        self.current_round += 1

    def can_continue(self) -> bool:
        """Whether another round is allowed (``current_round < max_rounds``)."""
        return self.current_round < self.max_rounds

    def get_tool_messages(self) -> List[Dict[str, Any]]:
        """Return every result as a tool message (for sending back to the LLM)."""
        return [result.to_message() for result in self.history]

    def get_last_results(self, n: int = 1) -> List[ToolCallResult]:
        """Return the last *n* results, oldest first; ``n <= 0`` yields an empty list."""
        if n <= 0:
            # history[-0:] would be the whole list, not an empty one.
            return []
        return self.history[-n:]

    def has_special_flags(self) -> Dict[str, bool]:
        """Report, per control flag, whether any result in the history set it."""
        flag_names = (
            "need_ai_reply",
            "already_sent",
            "no_reply",
            "save_to_memory",
            "send_result_text",
        )
        return {
            flag: any(getattr(result, flag) for result in self.history)
            for flag in flag_names
        }

    def get_summary(self) -> str:
        """Return a one-line human-readable summary of the chain."""
        if not self.history:
            return "无工具调用"

        successful = sum(1 for r in self.history if r.success)
        failed = len(self.history) - successful
        total_time = sum(r.execution_time_ms for r in self.history)
        tools_called = [r.name for r in self.history]

        return (
            f"调用链: {len(self.history)} 个工具, "
            f"成功 {successful}, 失败 {failed}, "
            f"总耗时 {total_time:.1f}ms, "
            f"工具: {', '.join(tools_called)}"
        )


# ==================== Convenience functions ====================

_default_executor: Optional[ToolExecutor] = None


def get_tool_executor(
    default_timeout: float = 60.0,
    max_parallel: int = 5,
) -> ToolExecutor:
    """
    Return the module-level shared ToolExecutor, creating it on first use.

    Note: ``default_timeout`` and ``max_parallel`` only take effect on the call
    that creates the executor; later calls return the existing instance and
    ignore these arguments.
    """
    global _default_executor
    if _default_executor is None:
        _default_executor = ToolExecutor(
            default_timeout=default_timeout,
            max_parallel=max_parallel,
        )
    return _default_executor
async def execute_tool_calls(
    tool_calls: List[Dict[str, Any]],
    bot,
    from_wxid: str,
    parallel: bool = True,
    stop_on_error: bool = False,
) -> List[ToolCallResult]:
    """
    Convenience wrapper: run a list of tool calls on the shared default executor.

    Args:
        tool_calls: OpenAI-format tool_call dicts.
        bot: WechatHookClient instance, passed through to each tool.
        from_wxid: wxid the message came from, passed through to each tool.
        parallel: Run the calls concurrently.
        stop_on_error: Stop at the first failed call instead of running all of them
            (defaults to False, matching the previous behavior).

    Returns:
        The ToolCallResults produced by ``execute_batch``.
    """
    executor = get_tool_executor()
    return await executor.execute_batch(
        tool_calls, bot, from_wxid, parallel=parallel, stop_on_error=stop_on_error
    )


# ==================== Exports ====================

__all__ = [
    'ToolCallRequest',
    'ToolCallResult',
    'ExecutionStats',
    'ToolExecutor',
    'ToolCallChain',
    'get_tool_executor',
    'execute_tool_calls',
]