Merge branch 'main' of https://gitea.functen.cn/shihao/WechatHookBot
This commit is contained in:
488
utils/tool_executor.py
Normal file
488
utils/tool_executor.py
Normal file
@@ -0,0 +1,488 @@
|
||||
"""
|
||||
工具执行器模块
|
||||
|
||||
提供工具调用的高级执行逻辑:
|
||||
- 批量工具执行(支持并行)
|
||||
- 工具调用链处理
|
||||
- 执行日志和审计
|
||||
- 结果聚合
|
||||
|
||||
使用示例:
|
||||
from utils.tool_executor import ToolExecutor, ToolCallRequest
|
||||
|
||||
executor = ToolExecutor()
|
||||
|
||||
# 单个工具执行
|
||||
result = await executor.execute_single(
|
||||
tool_call={"id": "call_1", "function": {"name": "get_weather", "arguments": "{}"}},
|
||||
bot=bot,
|
||||
from_wxid=wxid,
|
||||
)
|
||||
|
||||
# 批量工具执行
|
||||
results = await executor.execute_batch(
|
||||
tool_calls=[...],
|
||||
bot=bot,
|
||||
from_wxid=wxid,
|
||||
parallel=True,
|
||||
)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolCallRequest:
|
||||
"""工具调用请求"""
|
||||
id: str
|
||||
name: str
|
||||
arguments: Dict[str, Any]
|
||||
raw_arguments: str = "" # 原始 JSON 字符串
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolCallResult:
|
||||
"""工具调用结果"""
|
||||
id: str
|
||||
name: str
|
||||
success: bool = True
|
||||
message: str = ""
|
||||
raw_result: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
# 控制标志
|
||||
need_ai_reply: bool = False
|
||||
already_sent: bool = False
|
||||
send_result_text: bool = False
|
||||
no_reply: bool = False
|
||||
save_to_memory: bool = False
|
||||
|
||||
# 执行信息
|
||||
execution_time_ms: float = 0.0
|
||||
error: Optional[str] = None
|
||||
|
||||
def to_message(self) -> Dict[str, Any]:
|
||||
"""转换为 OpenAI 兼容的 tool message"""
|
||||
content = self.message if self.success else f"错误: {self.error or self.message}"
|
||||
return {
|
||||
"role": "tool",
|
||||
"tool_call_id": self.id,
|
||||
"content": content
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExecutionStats:
|
||||
"""执行统计"""
|
||||
total_calls: int = 0
|
||||
successful_calls: int = 0
|
||||
failed_calls: int = 0
|
||||
timeout_calls: int = 0
|
||||
total_time_ms: float = 0.0
|
||||
avg_time_ms: float = 0.0
|
||||
|
||||
|
||||
class ToolExecutor:
|
||||
"""
|
||||
工具执行器
|
||||
|
||||
提供统一的工具执行接口:
|
||||
- 参数解析和校验
|
||||
- 超时保护
|
||||
- 错误处理
|
||||
- 执行统计
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
default_timeout: float = 60.0,
|
||||
max_parallel: int = 5,
|
||||
validate_args: bool = True,
|
||||
):
|
||||
self.default_timeout = default_timeout
|
||||
self.max_parallel = max_parallel
|
||||
self.validate_args = validate_args
|
||||
self._stats = ExecutionStats()
|
||||
|
||||
def parse_tool_call(self, tool_call: Dict[str, Any]) -> ToolCallRequest:
|
||||
"""
|
||||
解析 OpenAI 格式的工具调用
|
||||
|
||||
Args:
|
||||
tool_call: OpenAI 返回的 tool_call 对象
|
||||
|
||||
Returns:
|
||||
ToolCallRequest 对象
|
||||
"""
|
||||
call_id = tool_call.get("id", "")
|
||||
function = tool_call.get("function", {})
|
||||
name = function.get("name", "")
|
||||
raw_args = function.get("arguments", "{}")
|
||||
|
||||
# 解析 arguments JSON
|
||||
try:
|
||||
arguments = json.loads(raw_args) if raw_args else {}
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"[ToolExecutor] 解析参数失败: {e}, raw={raw_args[:100]}")
|
||||
arguments = {}
|
||||
|
||||
return ToolCallRequest(
|
||||
id=call_id,
|
||||
name=name,
|
||||
arguments=arguments,
|
||||
raw_arguments=raw_args,
|
||||
)
|
||||
|
||||
async def execute_single(
|
||||
self,
|
||||
tool_call: Dict[str, Any],
|
||||
bot,
|
||||
from_wxid: str,
|
||||
timeout_override: Optional[float] = None,
|
||||
) -> ToolCallResult:
|
||||
"""
|
||||
执行单个工具调用
|
||||
|
||||
Args:
|
||||
tool_call: OpenAI 格式的 tool_call
|
||||
bot: WechatHookClient 实例
|
||||
from_wxid: 消息来源 wxid
|
||||
timeout_override: 覆盖默认超时
|
||||
|
||||
Returns:
|
||||
ToolCallResult 对象
|
||||
"""
|
||||
from utils.tool_registry import get_tool_registry
|
||||
from utils.llm_tooling import validate_tool_arguments, ToolResult
|
||||
|
||||
start_time = time.time()
|
||||
request = self.parse_tool_call(tool_call)
|
||||
registry = get_tool_registry()
|
||||
|
||||
result = ToolCallResult(
|
||||
id=request.id,
|
||||
name=request.name,
|
||||
)
|
||||
|
||||
# 获取工具定义
|
||||
tool_def = registry.get(request.name)
|
||||
if not tool_def:
|
||||
result.success = False
|
||||
result.error = f"工具 {request.name} 不存在"
|
||||
result.message = result.error
|
||||
self._update_stats(False, time.time() - start_time)
|
||||
return result
|
||||
|
||||
# 参数校验
|
||||
if self.validate_args:
|
||||
schema = tool_def.schema.get("function", {}).get("parameters", {})
|
||||
ok, error_msg, validated_args = validate_tool_arguments(
|
||||
request.name, request.arguments, schema
|
||||
)
|
||||
if not ok:
|
||||
result.success = False
|
||||
result.error = error_msg
|
||||
result.message = error_msg
|
||||
self._update_stats(False, time.time() - start_time)
|
||||
return result
|
||||
request.arguments = validated_args
|
||||
|
||||
# 执行工具
|
||||
timeout = timeout_override or tool_def.timeout or self.default_timeout
|
||||
|
||||
try:
|
||||
logger.debug(f"[ToolExecutor] 执行工具: {request.name}")
|
||||
|
||||
raw_result = await asyncio.wait_for(
|
||||
tool_def.executor(request.name, request.arguments, bot, from_wxid),
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
# 解析结果
|
||||
tool_result = ToolResult.from_raw(raw_result)
|
||||
if tool_result:
|
||||
result.success = tool_result.success
|
||||
result.message = tool_result.message
|
||||
result.need_ai_reply = tool_result.need_ai_reply
|
||||
result.already_sent = tool_result.already_sent
|
||||
result.send_result_text = tool_result.send_result_text
|
||||
result.no_reply = tool_result.no_reply
|
||||
result.save_to_memory = tool_result.save_to_memory
|
||||
else:
|
||||
result.message = str(raw_result) if raw_result else "执行完成"
|
||||
|
||||
result.raw_result = raw_result if isinstance(raw_result, dict) else {"result": raw_result}
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
result.execution_time_ms = execution_time * 1000
|
||||
self._update_stats(result.success, execution_time)
|
||||
|
||||
logger.debug(
|
||||
f"[ToolExecutor] 工具 {request.name} 执行完成 "
|
||||
f"({result.execution_time_ms:.1f}ms)"
|
||||
)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
result.success = False
|
||||
result.error = f"执行超时 ({timeout}s)"
|
||||
result.message = result.error
|
||||
self._update_stats(False, time.time() - start_time, timeout=True)
|
||||
logger.warning(f"[ToolExecutor] 工具 {request.name} 执行超时")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
result.success = False
|
||||
result.error = str(e)
|
||||
result.message = f"执行失败: {e}"
|
||||
self._update_stats(False, time.time() - start_time)
|
||||
logger.error(f"[ToolExecutor] 工具 {request.name} 执行异常: {e}")
|
||||
|
||||
return result
|
||||
|
||||
async def execute_batch(
|
||||
self,
|
||||
tool_calls: List[Dict[str, Any]],
|
||||
bot,
|
||||
from_wxid: str,
|
||||
parallel: bool = True,
|
||||
stop_on_error: bool = False,
|
||||
) -> List[ToolCallResult]:
|
||||
"""
|
||||
批量执行工具调用
|
||||
|
||||
Args:
|
||||
tool_calls: 工具调用列表
|
||||
bot: WechatHookClient 实例
|
||||
from_wxid: 消息来源 wxid
|
||||
parallel: 是否并行执行
|
||||
stop_on_error: 遇到错误是否停止
|
||||
|
||||
Returns:
|
||||
ToolCallResult 列表
|
||||
"""
|
||||
if not tool_calls:
|
||||
return []
|
||||
|
||||
if parallel and len(tool_calls) > 1:
|
||||
return await self._execute_parallel(tool_calls, bot, from_wxid, stop_on_error)
|
||||
else:
|
||||
return await self._execute_sequential(tool_calls, bot, from_wxid, stop_on_error)
|
||||
|
||||
async def _execute_sequential(
|
||||
self,
|
||||
tool_calls: List[Dict[str, Any]],
|
||||
bot,
|
||||
from_wxid: str,
|
||||
stop_on_error: bool,
|
||||
) -> List[ToolCallResult]:
|
||||
"""顺序执行"""
|
||||
results = []
|
||||
for tool_call in tool_calls:
|
||||
result = await self.execute_single(tool_call, bot, from_wxid)
|
||||
results.append(result)
|
||||
|
||||
if stop_on_error and not result.success:
|
||||
logger.warning(f"[ToolExecutor] 工具 {result.name} 失败,停止批量执行")
|
||||
break
|
||||
|
||||
return results
|
||||
|
||||
async def _execute_parallel(
|
||||
self,
|
||||
tool_calls: List[Dict[str, Any]],
|
||||
bot,
|
||||
from_wxid: str,
|
||||
stop_on_error: bool,
|
||||
) -> List[ToolCallResult]:
|
||||
"""并行执行(带并发限制)"""
|
||||
semaphore = asyncio.Semaphore(self.max_parallel)
|
||||
|
||||
async def execute_with_limit(tool_call):
|
||||
async with semaphore:
|
||||
return await self.execute_single(tool_call, bot, from_wxid)
|
||||
|
||||
tasks = [execute_with_limit(tc) for tc in tool_calls]
|
||||
|
||||
if stop_on_error:
|
||||
# 使用 gather 但不 return_exceptions,让第一个错误停止执行
|
||||
results = []
|
||||
for coro in asyncio.as_completed(tasks):
|
||||
try:
|
||||
result = await coro
|
||||
results.append(result)
|
||||
if not result.success:
|
||||
# 取消剩余任务
|
||||
for task in tasks:
|
||||
if isinstance(task, asyncio.Task) and not task.done():
|
||||
task.cancel()
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"[ToolExecutor] 并行执行异常: {e}")
|
||||
break
|
||||
return results
|
||||
else:
|
||||
# 全部执行,收集所有结果
|
||||
return await asyncio.gather(*tasks, return_exceptions=False)
|
||||
|
||||
def _update_stats(self, success: bool, execution_time: float, timeout: bool = False):
|
||||
"""更新执行统计"""
|
||||
self._stats.total_calls += 1
|
||||
if success:
|
||||
self._stats.successful_calls += 1
|
||||
else:
|
||||
self._stats.failed_calls += 1
|
||||
if timeout:
|
||||
self._stats.timeout_calls += 1
|
||||
|
||||
self._stats.total_time_ms += execution_time * 1000
|
||||
self._stats.avg_time_ms = self._stats.total_time_ms / self._stats.total_calls
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""获取执行统计"""
|
||||
return {
|
||||
"total_calls": self._stats.total_calls,
|
||||
"successful_calls": self._stats.successful_calls,
|
||||
"failed_calls": self._stats.failed_calls,
|
||||
"timeout_calls": self._stats.timeout_calls,
|
||||
"total_time_ms": self._stats.total_time_ms,
|
||||
"avg_time_ms": self._stats.avg_time_ms,
|
||||
"success_rate": (
|
||||
self._stats.successful_calls / self._stats.total_calls
|
||||
if self._stats.total_calls > 0 else 0
|
||||
),
|
||||
}
|
||||
|
||||
def reset_stats(self):
|
||||
"""重置统计"""
|
||||
self._stats = ExecutionStats()
|
||||
|
||||
|
||||
class ToolCallChain:
|
||||
"""
|
||||
工具调用链
|
||||
|
||||
用于处理需要多轮工具调用的场景,记录调用历史。
|
||||
"""
|
||||
|
||||
def __init__(self, max_rounds: int = 10):
|
||||
self.max_rounds = max_rounds
|
||||
self.history: List[ToolCallResult] = []
|
||||
self.current_round = 0
|
||||
|
||||
def add_result(self, result: ToolCallResult):
|
||||
"""添加调用结果"""
|
||||
self.history.append(result)
|
||||
|
||||
def add_results(self, results: List[ToolCallResult]):
|
||||
"""添加多个调用结果"""
|
||||
self.history.extend(results)
|
||||
|
||||
def increment_round(self):
|
||||
"""增加轮次"""
|
||||
self.current_round += 1
|
||||
|
||||
def can_continue(self) -> bool:
|
||||
"""检查是否可以继续调用"""
|
||||
return self.current_round < self.max_rounds
|
||||
|
||||
def get_tool_messages(self) -> List[Dict[str, Any]]:
|
||||
"""获取所有工具调用的消息(用于发送给 LLM)"""
|
||||
return [result.to_message() for result in self.history]
|
||||
|
||||
def get_last_results(self, n: int = 1) -> List[ToolCallResult]:
|
||||
"""获取最后 n 个结果"""
|
||||
return self.history[-n:] if self.history else []
|
||||
|
||||
def has_special_flags(self) -> Dict[str, bool]:
|
||||
"""检查是否有特殊标志"""
|
||||
flags = {
|
||||
"need_ai_reply": False,
|
||||
"already_sent": False,
|
||||
"no_reply": False,
|
||||
"save_to_memory": False,
|
||||
"send_result_text": False,
|
||||
}
|
||||
|
||||
for result in self.history:
|
||||
if result.need_ai_reply:
|
||||
flags["need_ai_reply"] = True
|
||||
if result.already_sent:
|
||||
flags["already_sent"] = True
|
||||
if result.no_reply:
|
||||
flags["no_reply"] = True
|
||||
if result.save_to_memory:
|
||||
flags["save_to_memory"] = True
|
||||
if result.send_result_text:
|
||||
flags["send_result_text"] = True
|
||||
|
||||
return flags
|
||||
|
||||
def get_summary(self) -> str:
|
||||
"""获取调用链摘要"""
|
||||
if not self.history:
|
||||
return "无工具调用"
|
||||
|
||||
successful = sum(1 for r in self.history if r.success)
|
||||
failed = len(self.history) - successful
|
||||
total_time = sum(r.execution_time_ms for r in self.history)
|
||||
|
||||
tools_called = [r.name for r in self.history]
|
||||
|
||||
return (
|
||||
f"调用链: {len(self.history)} 个工具, "
|
||||
f"成功 {successful}, 失败 {failed}, "
|
||||
f"总耗时 {total_time:.1f}ms, "
|
||||
f"工具: {', '.join(tools_called)}"
|
||||
)
|
||||
|
||||
|
||||
# ==================== 便捷函数 ====================
|
||||
|
||||
_default_executor: Optional[ToolExecutor] = None
|
||||
|
||||
|
||||
def get_tool_executor(
|
||||
default_timeout: float = 60.0,
|
||||
max_parallel: int = 5,
|
||||
) -> ToolExecutor:
|
||||
"""获取默认工具执行器"""
|
||||
global _default_executor
|
||||
if _default_executor is None:
|
||||
_default_executor = ToolExecutor(
|
||||
default_timeout=default_timeout,
|
||||
max_parallel=max_parallel,
|
||||
)
|
||||
return _default_executor
|
||||
|
||||
|
||||
async def execute_tool_calls(
|
||||
tool_calls: List[Dict[str, Any]],
|
||||
bot,
|
||||
from_wxid: str,
|
||||
parallel: bool = True,
|
||||
) -> List[ToolCallResult]:
|
||||
"""便捷函数:执行工具调用列表"""
|
||||
executor = get_tool_executor()
|
||||
return await executor.execute_batch(tool_calls, bot, from_wxid, parallel=parallel)
|
||||
|
||||
|
||||
# ==================== 导出 ====================
|
||||
|
||||
__all__ = [
|
||||
'ToolCallRequest',
|
||||
'ToolCallResult',
|
||||
'ExecutionStats',
|
||||
'ToolExecutor',
|
||||
'ToolCallChain',
|
||||
'get_tool_executor',
|
||||
'execute_tool_calls',
|
||||
]
|
||||
Reference in New Issue
Block a user