Merge branch 'main' of https://gitea.functen.cn/shihao/WechatHookBot
This commit is contained in:
190
utils/config_manager.py
Normal file
190
utils/config_manager.py
Normal file
@@ -0,0 +1,190 @@
|
||||
"""
|
||||
统一配置管理器
|
||||
|
||||
单例模式,提供:
|
||||
- 配置缓存,避免重复读取文件
|
||||
- 配置热更新检测
|
||||
- 类型安全的配置访问
|
||||
"""
|
||||
|
||||
import tomllib
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class ConfigManager:
|
||||
"""
|
||||
配置管理器 (线程安全单例)
|
||||
|
||||
使用示例:
|
||||
from utils.config_manager import get_config
|
||||
|
||||
# 获取单个配置项
|
||||
admins = get_config().get("Bot", "admins", [])
|
||||
|
||||
# 获取整个配置节
|
||||
bot_config = get_config().get_section("Bot")
|
||||
|
||||
# 检查并重新加载
|
||||
if get_config().reload_if_changed():
|
||||
logger.info("配置已更新")
|
||||
"""
|
||||
|
||||
_instance: Optional["ConfigManager"] = None
|
||||
_lock = Lock()
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
instance = super().__new__(cls)
|
||||
instance._initialized = False
|
||||
cls._instance = instance
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
self._config: Dict[str, Any] = {}
|
||||
self._config_path = Path("main_config.toml")
|
||||
self._file_mtime: float = 0
|
||||
self._config_lock = Lock()
|
||||
self._reload()
|
||||
self._initialized = True
|
||||
logger.debug("ConfigManager 初始化完成")
|
||||
|
||||
def _reload(self) -> bool:
|
||||
"""重新加载配置文件"""
|
||||
try:
|
||||
if not self._config_path.exists():
|
||||
logger.warning(f"配置文件不存在: {self._config_path}")
|
||||
return False
|
||||
|
||||
current_mtime = self._config_path.stat().st_mtime
|
||||
if current_mtime == self._file_mtime and self._config:
|
||||
return False # 文件未变化
|
||||
|
||||
with self._config_lock:
|
||||
with open(self._config_path, "rb") as f:
|
||||
self._config = tomllib.load(f)
|
||||
self._file_mtime = current_mtime
|
||||
|
||||
logger.debug("配置文件已重新加载")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"加载配置文件失败: {e}")
|
||||
return False
|
||||
|
||||
def get(self, section: str, key: str, default: Any = None) -> Any:
|
||||
"""
|
||||
获取配置项
|
||||
|
||||
Args:
|
||||
section: 配置节名称,如 "Bot"
|
||||
key: 配置项名称,如 "admins"
|
||||
default: 默认值
|
||||
|
||||
Returns:
|
||||
配置值或默认值
|
||||
"""
|
||||
return self._config.get(section, {}).get(key, default)
|
||||
|
||||
def get_section(self, section: str) -> Dict[str, Any]:
|
||||
"""
|
||||
获取整个配置节
|
||||
|
||||
Args:
|
||||
section: 配置节名称
|
||||
|
||||
Returns:
|
||||
配置节字典的副本
|
||||
"""
|
||||
return self._config.get(section, {}).copy()
|
||||
|
||||
def get_all(self) -> Dict[str, Any]:
|
||||
"""获取完整配置(只读副本)"""
|
||||
return self._config.copy()
|
||||
|
||||
def reload_if_changed(self) -> bool:
|
||||
"""
|
||||
如果文件有变化则重新加载
|
||||
|
||||
Returns:
|
||||
是否重新加载了配置
|
||||
"""
|
||||
try:
|
||||
if not self._config_path.exists():
|
||||
return False
|
||||
current_mtime = self._config_path.stat().st_mtime
|
||||
if current_mtime != self._file_mtime:
|
||||
return self._reload()
|
||||
except Exception:
|
||||
pass
|
||||
return False
|
||||
|
||||
def force_reload(self) -> bool:
|
||||
"""强制重新加载配置"""
|
||||
self._file_mtime = 0
|
||||
return self._reload()
|
||||
|
||||
|
||||
# ==================== 便捷函数 ====================
|
||||
|
||||
def get_config() -> ConfigManager:
|
||||
"""获取配置管理器实例"""
|
||||
return ConfigManager()
|
||||
|
||||
|
||||
def get_bot_config() -> Dict[str, Any]:
|
||||
"""快捷获取 [Bot] 配置节"""
|
||||
return get_config().get_section("Bot")
|
||||
|
||||
|
||||
def get_performance_config() -> Dict[str, Any]:
|
||||
"""快捷获取 [Performance] 配置节"""
|
||||
return get_config().get_section("Performance")
|
||||
|
||||
|
||||
def get_database_config() -> Dict[str, Any]:
|
||||
"""快捷获取 [Database] 配置节"""
|
||||
return get_config().get_section("Database")
|
||||
|
||||
|
||||
def get_scheduler_config() -> Dict[str, Any]:
|
||||
"""快捷获取 [Scheduler] 配置节"""
|
||||
return get_config().get_section("Scheduler")
|
||||
|
||||
|
||||
def get_queue_config() -> Dict[str, Any]:
|
||||
"""快捷获取 [Queue] 配置节"""
|
||||
return get_config().get_section("Queue")
|
||||
|
||||
|
||||
def get_concurrency_config() -> Dict[str, Any]:
|
||||
"""快捷获取 [Concurrency] 配置节"""
|
||||
return get_config().get_section("Concurrency")
|
||||
|
||||
|
||||
def get_webui_config() -> Dict[str, Any]:
|
||||
"""快捷获取 [WebUI] 配置节"""
|
||||
return get_config().get_section("WebUI")
|
||||
|
||||
|
||||
# ==================== 导出列表 ====================
|
||||
|
||||
__all__ = [
|
||||
'ConfigManager',
|
||||
'get_config',
|
||||
'get_bot_config',
|
||||
'get_performance_config',
|
||||
'get_database_config',
|
||||
'get_scheduler_config',
|
||||
'get_queue_config',
|
||||
'get_concurrency_config',
|
||||
'get_webui_config',
|
||||
]
|
||||
@@ -1,5 +1,12 @@
|
||||
"""
|
||||
消息处理装饰器模块
|
||||
|
||||
提供插件消息处理和定时任务的装饰器
|
||||
使用工厂模式消除重复代码
|
||||
"""
|
||||
|
||||
from functools import wraps
|
||||
from typing import Callable, Union
|
||||
from typing import Callable, Dict, Union
|
||||
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
@@ -8,15 +15,16 @@ from apscheduler.triggers.interval import IntervalTrigger
|
||||
scheduler = AsyncIOScheduler()
|
||||
|
||||
|
||||
# ==================== 定时任务装饰器 ====================
|
||||
|
||||
def schedule(
|
||||
trigger: Union[str, CronTrigger, IntervalTrigger],
|
||||
**trigger_args
|
||||
) -> Callable:
|
||||
"""
|
||||
定时任务装饰器
|
||||
|
||||
例子:
|
||||
|
||||
例子:
|
||||
- @schedule('interval', seconds=30)
|
||||
- @schedule('cron', hour=8, minute=30, second=30)
|
||||
- @schedule('date', run_date='2024-01-01 00:00:00')
|
||||
@@ -44,23 +52,16 @@ def add_job_safe(scheduler: AsyncIOScheduler, job_id: str, func: Callable, bot,
|
||||
"""添加函数到定时任务中,如果存在则先删除现有的任务"""
|
||||
try:
|
||||
scheduler.remove_job(job_id)
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 读取调度器配置
|
||||
|
||||
# 使用统一配置管理器读取调度器配置
|
||||
try:
|
||||
import tomllib
|
||||
from pathlib import Path
|
||||
config_path = Path("main_config.toml")
|
||||
if config_path.exists():
|
||||
with open(config_path, "rb") as f:
|
||||
config = tomllib.load(f)
|
||||
scheduler_config = config.get("Scheduler", {})
|
||||
else:
|
||||
scheduler_config = {}
|
||||
except:
|
||||
from utils.config_manager import get_scheduler_config
|
||||
scheduler_config = get_scheduler_config()
|
||||
except Exception:
|
||||
scheduler_config = {}
|
||||
|
||||
|
||||
# 应用调度器配置
|
||||
job_kwargs = {
|
||||
"coalesce": scheduler_config.get("coalesce", True),
|
||||
@@ -68,7 +69,7 @@ def add_job_safe(scheduler: AsyncIOScheduler, job_id: str, func: Callable, bot,
|
||||
"misfire_grace_time": scheduler_config.get("misfire_grace_time", 30)
|
||||
}
|
||||
job_kwargs.update(trigger_args)
|
||||
|
||||
|
||||
scheduler.add_job(func, trigger, args=[bot], id=job_id, **job_kwargs)
|
||||
|
||||
|
||||
@@ -76,182 +77,106 @@ def remove_job_safe(scheduler: AsyncIOScheduler, job_id: str):
|
||||
"""从定时任务中移除任务"""
|
||||
try:
|
||||
scheduler.remove_job(job_id)
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def on_text_message(priority=50):
|
||||
"""文本消息装饰器"""
|
||||
# ==================== 消息装饰器工厂 ====================
|
||||
|
||||
def decorator(func):
|
||||
if callable(priority): # 无参数调用时
|
||||
f = priority
|
||||
setattr(f, '_event_type', 'text_message')
|
||||
setattr(f, '_priority', 50)
|
||||
return f
|
||||
# 有参数调用时
|
||||
setattr(func, '_event_type', 'text_message')
|
||||
setattr(func, '_priority', min(max(priority, 0), 99))
|
||||
return func
|
||||
def _create_message_decorator(event_type: str, description: str):
|
||||
"""
|
||||
消息装饰器工厂函数
|
||||
|
||||
return decorator if not callable(priority) else decorator(priority)
|
||||
生成支持两种调用方式的装饰器:
|
||||
- @on_xxx_message (无参数,使用默认优先级50)
|
||||
- @on_xxx_message(priority=80) (有参数,自定义优先级)
|
||||
|
||||
Args:
|
||||
event_type: 事件类型字符串,如 'text_message'
|
||||
description: 装饰器描述,用于生成文档字符串
|
||||
|
||||
def on_image_message(priority=50):
|
||||
"""图片消息装饰器"""
|
||||
Returns:
|
||||
装饰器函数
|
||||
"""
|
||||
def decorator_factory(priority=50):
|
||||
def decorator(func):
|
||||
# 处理无参数调用: @on_xxx_message 时 priority 实际是被装饰的函数
|
||||
if callable(priority):
|
||||
target_func = priority
|
||||
actual_priority = 50
|
||||
else:
|
||||
target_func = func
|
||||
actual_priority = min(max(priority, 0), 99)
|
||||
|
||||
def decorator(func):
|
||||
setattr(target_func, '_event_type', event_type)
|
||||
setattr(target_func, '_priority', actual_priority)
|
||||
return target_func
|
||||
|
||||
# 判断调用方式
|
||||
if callable(priority):
|
||||
f = priority
|
||||
setattr(f, '_event_type', 'image_message')
|
||||
setattr(f, '_priority', 50)
|
||||
return f
|
||||
setattr(func, '_event_type', 'image_message')
|
||||
setattr(func, '_priority', min(max(priority, 0), 99))
|
||||
return func
|
||||
return decorator(priority)
|
||||
return decorator
|
||||
|
||||
return decorator if not callable(priority) else decorator(priority)
|
||||
decorator_factory.__doc__ = f"{description}装饰器"
|
||||
decorator_factory.__name__ = f"on_{event_type}"
|
||||
return decorator_factory
|
||||
|
||||
|
||||
def on_voice_message(priority=50):
|
||||
"""语音消息装饰器"""
|
||||
# ==================== 消息类型定义 ====================
|
||||
|
||||
def decorator(func):
|
||||
if callable(priority):
|
||||
f = priority
|
||||
setattr(f, '_event_type', 'voice_message')
|
||||
setattr(f, '_priority', 50)
|
||||
return f
|
||||
setattr(func, '_event_type', 'voice_message')
|
||||
setattr(func, '_priority', min(max(priority, 0), 99))
|
||||
return func
|
||||
|
||||
return decorator if not callable(priority) else decorator(priority)
|
||||
# 事件类型 -> 中文描述 映射表
|
||||
MESSAGE_DECORATOR_TYPES: Dict[str, str] = {
|
||||
'text_message': '文本消息',
|
||||
'image_message': '图片消息',
|
||||
'voice_message': '语音消息',
|
||||
'video_message': '视频消息',
|
||||
'emoji_message': '表情消息',
|
||||
'file_message': '文件消息',
|
||||
'quote_message': '引用消息',
|
||||
'pat_message': '拍一拍',
|
||||
'at_message': '@消息',
|
||||
'system_message': '系统消息',
|
||||
'other_message': '其他消息',
|
||||
}
|
||||
|
||||
|
||||
def on_emoji_message(priority=50):
|
||||
"""表情消息装饰器"""
|
||||
# ==================== 生成所有消息装饰器 ====================
|
||||
|
||||
def decorator(func):
|
||||
if callable(priority):
|
||||
f = priority
|
||||
setattr(f, '_event_type', 'emoji_message')
|
||||
setattr(f, '_priority', 50)
|
||||
return f
|
||||
setattr(func, '_event_type', 'emoji_message')
|
||||
setattr(func, '_priority', min(max(priority, 0), 99))
|
||||
return func
|
||||
|
||||
return decorator if not callable(priority) else decorator(priority)
|
||||
# 使用工厂函数生成装饰器
|
||||
on_text_message = _create_message_decorator('text_message', '文本消息')
|
||||
on_image_message = _create_message_decorator('image_message', '图片消息')
|
||||
on_voice_message = _create_message_decorator('voice_message', '语音消息')
|
||||
on_video_message = _create_message_decorator('video_message', '视频消息')
|
||||
on_emoji_message = _create_message_decorator('emoji_message', '表情消息')
|
||||
on_file_message = _create_message_decorator('file_message', '文件消息')
|
||||
on_quote_message = _create_message_decorator('quote_message', '引用消息')
|
||||
on_pat_message = _create_message_decorator('pat_message', '拍一拍')
|
||||
on_at_message = _create_message_decorator('at_message', '@消息')
|
||||
on_system_message = _create_message_decorator('system_message', '系统消息')
|
||||
on_other_message = _create_message_decorator('other_message', '其他消息')
|
||||
|
||||
|
||||
def on_file_message(priority=50):
|
||||
"""文件消息装饰器"""
|
||||
# ==================== 导出列表 ====================
|
||||
|
||||
def decorator(func):
|
||||
if callable(priority):
|
||||
f = priority
|
||||
setattr(f, '_event_type', 'file_message')
|
||||
setattr(f, '_priority', 50)
|
||||
return f
|
||||
setattr(func, '_event_type', 'file_message')
|
||||
setattr(func, '_priority', min(max(priority, 0), 99))
|
||||
return func
|
||||
|
||||
return decorator if not callable(priority) else decorator(priority)
|
||||
|
||||
|
||||
def on_quote_message(priority=50):
|
||||
"""引用消息装饰器"""
|
||||
|
||||
def decorator(func):
|
||||
if callable(priority):
|
||||
f = priority
|
||||
setattr(f, '_event_type', 'quote_message')
|
||||
setattr(f, '_priority', 50)
|
||||
return f
|
||||
setattr(func, '_event_type', 'quote_message')
|
||||
setattr(func, '_priority', min(max(priority, 0), 99))
|
||||
return func
|
||||
|
||||
return decorator if not callable(priority) else decorator(priority)
|
||||
|
||||
|
||||
def on_video_message(priority=50):
|
||||
"""视频消息装饰器"""
|
||||
|
||||
def decorator(func):
|
||||
if callable(priority):
|
||||
f = priority
|
||||
setattr(f, '_event_type', 'video_message')
|
||||
setattr(f, '_priority', 50)
|
||||
return f
|
||||
setattr(func, '_event_type', 'video_message')
|
||||
setattr(func, '_priority', min(max(priority, 0), 99))
|
||||
return func
|
||||
|
||||
return decorator if not callable(priority) else decorator(priority)
|
||||
|
||||
|
||||
def on_pat_message(priority=50):
|
||||
"""拍一拍消息装饰器"""
|
||||
|
||||
def decorator(func):
|
||||
if callable(priority):
|
||||
f = priority
|
||||
setattr(f, '_event_type', 'pat_message')
|
||||
setattr(f, '_priority', 50)
|
||||
return f
|
||||
setattr(func, '_event_type', 'pat_message')
|
||||
setattr(func, '_priority', min(max(priority, 0), 99))
|
||||
return func
|
||||
|
||||
return decorator if not callable(priority) else decorator(priority)
|
||||
|
||||
|
||||
def on_at_message(priority=50):
|
||||
"""被@消息装饰器"""
|
||||
|
||||
def decorator(func):
|
||||
if callable(priority):
|
||||
f = priority
|
||||
setattr(f, '_event_type', 'at_message')
|
||||
setattr(f, '_priority', 50)
|
||||
return f
|
||||
setattr(func, '_event_type', 'at_message')
|
||||
setattr(func, '_priority', min(max(priority, 0), 99))
|
||||
return func
|
||||
|
||||
return decorator if not callable(priority) else decorator(priority)
|
||||
|
||||
|
||||
def on_system_message(priority=50):
|
||||
"""其他消息装饰器"""
|
||||
|
||||
def decorator(func):
|
||||
if callable(priority):
|
||||
f = priority
|
||||
setattr(f, '_event_type', 'system_message')
|
||||
setattr(f, '_priority', 50)
|
||||
return f
|
||||
setattr(func, '_event_type', 'other_message')
|
||||
setattr(func, '_priority', min(max(priority, 0), 99))
|
||||
return func
|
||||
|
||||
return decorator if not callable(priority) else decorator(priority)
|
||||
|
||||
|
||||
def on_other_message(priority=50):
|
||||
"""其他消息装饰器"""
|
||||
|
||||
def decorator(func):
|
||||
if callable(priority):
|
||||
f = priority
|
||||
setattr(f, '_event_type', 'other_message')
|
||||
setattr(f, '_priority', 50)
|
||||
return f
|
||||
setattr(func, '_event_type', 'other_message')
|
||||
setattr(func, '_priority', min(max(priority, 0), 99))
|
||||
return func
|
||||
|
||||
return decorator if not callable(priority) else decorator(priority)
|
||||
__all__ = [
|
||||
# 定时任务
|
||||
'scheduler',
|
||||
'schedule',
|
||||
'add_job_safe',
|
||||
'remove_job_safe',
|
||||
# 消息装饰器
|
||||
'on_text_message',
|
||||
'on_image_message',
|
||||
'on_voice_message',
|
||||
'on_video_message',
|
||||
'on_emoji_message',
|
||||
'on_file_message',
|
||||
'on_quote_message',
|
||||
'on_pat_message',
|
||||
'on_at_message',
|
||||
'on_system_message',
|
||||
'on_other_message',
|
||||
# 工具
|
||||
'MESSAGE_DECORATOR_TYPES',
|
||||
'_create_message_decorator',
|
||||
]
|
||||
|
||||
438
utils/errors.py
Normal file
438
utils/errors.py
Normal file
@@ -0,0 +1,438 @@
|
||||
"""
|
||||
统一错误处理模块
|
||||
|
||||
提供:
|
||||
- 自定义异常类层次结构
|
||||
- 错误包装和转换
|
||||
- 用户友好的错误消息
|
||||
- 错误日志和追踪
|
||||
|
||||
使用示例:
|
||||
from utils.errors import PluginError, ToolExecutionError, handle_error
|
||||
|
||||
try:
|
||||
await some_operation()
|
||||
except Exception as e:
|
||||
result = handle_error(e, context="执行工具")
|
||||
# result = {"success": False, "error": "...", "error_type": "..."}
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import traceback
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Optional, Type
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
# ==================== 错误类型枚举 ====================
|
||||
|
||||
class ErrorType(Enum):
|
||||
"""错误类型分类"""
|
||||
UNKNOWN = "unknown"
|
||||
PLUGIN = "plugin"
|
||||
TOOL = "tool"
|
||||
CONFIG = "config"
|
||||
NETWORK = "network"
|
||||
TIMEOUT = "timeout"
|
||||
VALIDATION = "validation"
|
||||
PERMISSION = "permission"
|
||||
RESOURCE = "resource"
|
||||
|
||||
|
||||
# ==================== 自定义异常基类 ====================
|
||||
|
||||
class BotError(Exception):
|
||||
"""机器人错误基类"""
|
||||
|
||||
error_type: ErrorType = ErrorType.UNKNOWN
|
||||
user_message: str = "发生了一个错误"
|
||||
log_level: str = "error"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
user_message: str = None,
|
||||
cause: Exception = None,
|
||||
context: Dict[str, Any] = None,
|
||||
):
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
self._user_message = user_message
|
||||
self.cause = cause
|
||||
self.context = context or {}
|
||||
|
||||
def get_user_message(self) -> str:
|
||||
"""获取用户友好的错误消息"""
|
||||
return self._user_message or self.user_message
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典(用于 API 响应)"""
|
||||
return {
|
||||
"success": False,
|
||||
"error": self.get_user_message(),
|
||||
"error_type": self.error_type.value,
|
||||
"details": self.message if self.message != self.get_user_message() else None,
|
||||
}
|
||||
|
||||
|
||||
# ==================== 具体异常类 ====================
|
||||
|
||||
class PluginError(BotError):
|
||||
"""插件相关错误"""
|
||||
error_type = ErrorType.PLUGIN
|
||||
user_message = "插件执行出错"
|
||||
|
||||
|
||||
class PluginLoadError(PluginError):
|
||||
"""插件加载错误"""
|
||||
user_message = "插件加载失败"
|
||||
|
||||
|
||||
class PluginNotFoundError(PluginError):
|
||||
"""插件未找到"""
|
||||
user_message = "找不到指定的插件"
|
||||
|
||||
|
||||
class ToolExecutionError(BotError):
|
||||
"""工具执行错误"""
|
||||
error_type = ErrorType.TOOL
|
||||
user_message = "工具执行失败"
|
||||
|
||||
|
||||
class ToolNotFoundError(ToolExecutionError):
|
||||
"""工具未找到"""
|
||||
user_message = "找不到指定的工具"
|
||||
|
||||
|
||||
class ToolTimeoutError(ToolExecutionError):
|
||||
"""工具执行超时"""
|
||||
error_type = ErrorType.TIMEOUT
|
||||
user_message = "工具执行超时"
|
||||
|
||||
|
||||
class ConfigError(BotError):
|
||||
"""配置相关错误"""
|
||||
error_type = ErrorType.CONFIG
|
||||
user_message = "配置错误"
|
||||
|
||||
|
||||
class ConfigNotFoundError(ConfigError):
|
||||
"""配置项未找到"""
|
||||
user_message = "找不到配置项"
|
||||
|
||||
|
||||
class ConfigValidationError(ConfigError):
|
||||
"""配置验证错误"""
|
||||
error_type = ErrorType.VALIDATION
|
||||
user_message = "配置格式不正确"
|
||||
|
||||
|
||||
class NetworkError(BotError):
|
||||
"""网络相关错误"""
|
||||
error_type = ErrorType.NETWORK
|
||||
user_message = "网络请求失败"
|
||||
|
||||
|
||||
class APIError(NetworkError):
|
||||
"""API 调用错误"""
|
||||
user_message = "API 调用失败"
|
||||
|
||||
|
||||
class ValidationError(BotError):
|
||||
"""验证错误"""
|
||||
error_type = ErrorType.VALIDATION
|
||||
user_message = "参数验证失败"
|
||||
|
||||
|
||||
class PermissionError(BotError):
|
||||
"""权限错误"""
|
||||
error_type = ErrorType.PERMISSION
|
||||
user_message = "没有权限执行此操作"
|
||||
|
||||
|
||||
class ResourceError(BotError):
|
||||
"""资源错误(内存、文件等)"""
|
||||
error_type = ErrorType.RESOURCE
|
||||
user_message = "资源访问失败"
|
||||
|
||||
|
||||
# ==================== 错误处理工具函数 ====================
|
||||
|
||||
@dataclass
|
||||
class ErrorResult:
|
||||
"""错误处理结果"""
|
||||
success: bool = False
|
||||
error: str = ""
|
||||
error_type: str = "unknown"
|
||||
details: Optional[str] = None
|
||||
logged: bool = False
|
||||
original_exception: Optional[Exception] = field(default=None, repr=False)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典"""
|
||||
result = {
|
||||
"success": self.success,
|
||||
"error": self.error,
|
||||
"error_type": self.error_type,
|
||||
}
|
||||
if self.details:
|
||||
result["details"] = self.details
|
||||
return result
|
||||
|
||||
|
||||
def handle_error(
|
||||
exception: Exception,
|
||||
context: str = "",
|
||||
log: bool = True,
|
||||
include_traceback: bool = False,
|
||||
) -> ErrorResult:
|
||||
"""
|
||||
统一错误处理函数
|
||||
|
||||
Args:
|
||||
exception: 捕获的异常
|
||||
context: 错误上下文描述
|
||||
log: 是否记录日志
|
||||
include_traceback: 是否包含完整堆栈
|
||||
|
||||
Returns:
|
||||
ErrorResult 对象
|
||||
"""
|
||||
# 处理自定义异常
|
||||
if isinstance(exception, BotError):
|
||||
result = ErrorResult(
|
||||
success=False,
|
||||
error=exception.get_user_message(),
|
||||
error_type=exception.error_type.value,
|
||||
details=exception.message if exception.message != exception.get_user_message() else None,
|
||||
original_exception=exception,
|
||||
)
|
||||
if log:
|
||||
log_func = getattr(logger, exception.log_level, logger.error)
|
||||
log_func(f"[{context}] {exception.error_type.value}: {exception.message}")
|
||||
result.logged = True
|
||||
return result
|
||||
|
||||
# 处理标准超时异常
|
||||
import asyncio
|
||||
if isinstance(exception, asyncio.TimeoutError):
|
||||
result = ErrorResult(
|
||||
success=False,
|
||||
error="操作超时",
|
||||
error_type=ErrorType.TIMEOUT.value,
|
||||
original_exception=exception,
|
||||
)
|
||||
if log:
|
||||
logger.warning(f"[{context}] 超时: {exception}")
|
||||
result.logged = True
|
||||
return result
|
||||
|
||||
# 处理连接错误
|
||||
if isinstance(exception, (ConnectionError, OSError)):
|
||||
result = ErrorResult(
|
||||
success=False,
|
||||
error="网络连接失败",
|
||||
error_type=ErrorType.NETWORK.value,
|
||||
details=str(exception),
|
||||
original_exception=exception,
|
||||
)
|
||||
if log:
|
||||
logger.error(f"[{context}] 网络错误: {exception}")
|
||||
result.logged = True
|
||||
return result
|
||||
|
||||
# 处理验证错误
|
||||
if isinstance(exception, (ValueError, TypeError)):
|
||||
result = ErrorResult(
|
||||
success=False,
|
||||
error="参数错误",
|
||||
error_type=ErrorType.VALIDATION.value,
|
||||
details=str(exception),
|
||||
original_exception=exception,
|
||||
)
|
||||
if log:
|
||||
logger.warning(f"[{context}] 验证错误: {exception}")
|
||||
result.logged = True
|
||||
return result
|
||||
|
||||
# 处理未知错误
|
||||
error_msg = str(exception) or exception.__class__.__name__
|
||||
details = None
|
||||
if include_traceback:
|
||||
details = traceback.format_exc()
|
||||
|
||||
result = ErrorResult(
|
||||
success=False,
|
||||
error=f"发生错误: {error_msg[:100]}",
|
||||
error_type=ErrorType.UNKNOWN.value,
|
||||
details=details,
|
||||
original_exception=exception,
|
||||
)
|
||||
|
||||
if log:
|
||||
logger.error(f"[{context}] 未知错误: {exception}")
|
||||
if include_traceback:
|
||||
logger.debug(traceback.format_exc())
|
||||
result.logged = True
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def wrap_error(
|
||||
exception: Exception,
|
||||
error_class: Type[BotError],
|
||||
message: str = None,
|
||||
user_message: str = None,
|
||||
) -> BotError:
|
||||
"""
|
||||
将标准异常包装为自定义异常
|
||||
|
||||
Args:
|
||||
exception: 原始异常
|
||||
error_class: 目标异常类
|
||||
message: 错误消息
|
||||
user_message: 用户友好消息
|
||||
|
||||
Returns:
|
||||
包装后的 BotError 子类实例
|
||||
"""
|
||||
msg = message or str(exception)
|
||||
return error_class(
|
||||
message=msg,
|
||||
user_message=user_message,
|
||||
cause=exception,
|
||||
)
|
||||
|
||||
|
||||
def safe_error_message(exception: Exception, max_length: int = 200) -> str:
|
||||
"""
|
||||
获取安全的错误消息(截断过长内容,移除敏感信息)
|
||||
|
||||
Args:
|
||||
exception: 异常对象
|
||||
max_length: 最大长度
|
||||
|
||||
Returns:
|
||||
安全的错误消息字符串
|
||||
"""
|
||||
msg = str(exception)
|
||||
|
||||
# 移除可能的敏感信息模式
|
||||
sensitive_patterns = [
|
||||
r'api[_-]?key[=:]\s*\S+',
|
||||
r'password[=:]\s*\S+',
|
||||
r'token[=:]\s*\S+',
|
||||
r'secret[=:]\s*\S+',
|
||||
]
|
||||
|
||||
import re
|
||||
for pattern in sensitive_patterns:
|
||||
msg = re.sub(pattern, '[REDACTED]', msg, flags=re.IGNORECASE)
|
||||
|
||||
# 截断
|
||||
if len(msg) > max_length:
|
||||
msg = msg[:max_length] + "..."
|
||||
|
||||
return msg
|
||||
|
||||
|
||||
# ==================== 装饰器 ====================
|
||||
|
||||
def catch_errors(
|
||||
error_class: Type[BotError] = BotError,
|
||||
context: str = "",
|
||||
log: bool = True,
|
||||
reraise: bool = False,
|
||||
):
|
||||
"""
|
||||
错误捕获装饰器
|
||||
|
||||
Args:
|
||||
error_class: 转换为的错误类
|
||||
context: 上下文描述
|
||||
log: 是否记录日志
|
||||
reraise: 是否重新抛出
|
||||
|
||||
Usage:
|
||||
@catch_errors(ToolExecutionError, context="执行工具")
|
||||
async def my_tool():
|
||||
...
|
||||
"""
|
||||
def decorator(func):
|
||||
import asyncio
|
||||
import functools
|
||||
|
||||
@functools.wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
except BotError:
|
||||
if reraise:
|
||||
raise
|
||||
return None
|
||||
except Exception as e:
|
||||
ctx = context or func.__name__
|
||||
handle_error(e, context=ctx, log=log)
|
||||
if reraise:
|
||||
raise wrap_error(e, error_class) from e
|
||||
return None
|
||||
|
||||
@functools.wraps(func)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except BotError:
|
||||
if reraise:
|
||||
raise
|
||||
return None
|
||||
except Exception as e:
|
||||
ctx = context or func.__name__
|
||||
handle_error(e, context=ctx, log=log)
|
||||
if reraise:
|
||||
raise wrap_error(e, error_class) from e
|
||||
return None
|
||||
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
return async_wrapper
|
||||
return sync_wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
# ==================== 导出列表 ====================
|
||||
|
||||
__all__ = [
|
||||
# 枚举
|
||||
'ErrorType',
|
||||
# 异常基类
|
||||
'BotError',
|
||||
# 插件异常
|
||||
'PluginError',
|
||||
'PluginLoadError',
|
||||
'PluginNotFoundError',
|
||||
# 工具异常
|
||||
'ToolExecutionError',
|
||||
'ToolNotFoundError',
|
||||
'ToolTimeoutError',
|
||||
# 配置异常
|
||||
'ConfigError',
|
||||
'ConfigNotFoundError',
|
||||
'ConfigValidationError',
|
||||
# 网络异常
|
||||
'NetworkError',
|
||||
'APIError',
|
||||
# 其他异常
|
||||
'ValidationError',
|
||||
'PermissionError',
|
||||
'ResourceError',
|
||||
# 工具函数
|
||||
'ErrorResult',
|
||||
'handle_error',
|
||||
'wrap_error',
|
||||
'safe_error_message',
|
||||
# 装饰器
|
||||
'catch_errors',
|
||||
]
|
||||
@@ -1,71 +1,315 @@
|
||||
import copy
|
||||
from typing import Callable, Dict, List
|
||||
"""
|
||||
事件管理器模块
|
||||
|
||||
提供事件的注册、分发和管理:
|
||||
- 优先级事件处理
|
||||
- 处理器缓存优化
|
||||
- 事件统计
|
||||
- 异常隔离
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
import traceback
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class HandlerInfo:
|
||||
"""事件处理器信息"""
|
||||
handler: Callable
|
||||
instance: object
|
||||
priority: int
|
||||
handler_name: str = field(default="")
|
||||
|
||||
def __post_init__(self):
|
||||
if not self.handler_name:
|
||||
self.handler_name = f"{self.instance.__class__.__name__}.{self.handler.__name__}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class EventStats:
|
||||
"""事件统计信息"""
|
||||
emit_count: int = 0
|
||||
handler_calls: int = 0
|
||||
total_time_ms: float = 0
|
||||
error_count: int = 0
|
||||
stopped_count: int = 0 # 被 return False 中断的次数
|
||||
|
||||
|
||||
class EventManager:
|
||||
_handlers: Dict[str, List[tuple[Callable, object, int]]] = {}
|
||||
"""
|
||||
事件管理器
|
||||
|
||||
特性:
|
||||
- 优先级排序(高优先级先执行)
|
||||
- 处理器可返回 False 中断后续处理
|
||||
- 异常隔离(单个处理器异常不影响其他)
|
||||
- 性能统计
|
||||
"""
|
||||
|
||||
# 类级别存储
|
||||
_handlers: Dict[str, List[HandlerInfo]] = {}
|
||||
_stats: Dict[str, EventStats] = defaultdict(EventStats)
|
||||
_handler_cache: Dict[str, List[HandlerInfo]] = {} # 排序后的缓存
|
||||
_cache_valid: Set[str] = set()
|
||||
|
||||
@classmethod
|
||||
def bind_instance(cls, instance: object):
|
||||
"""将实例绑定到对应的事件处理函数"""
|
||||
from loguru import logger
|
||||
registered_count = 0
|
||||
for method_name in dir(instance):
|
||||
method = getattr(instance, method_name)
|
||||
if hasattr(method, '_event_type'):
|
||||
event_type = getattr(method, '_event_type')
|
||||
priority = getattr(method, '_priority', 50)
|
||||
"""
|
||||
绑定实例的事件处理方法
|
||||
|
||||
if event_type not in cls._handlers:
|
||||
cls._handlers[event_type] = []
|
||||
cls._handlers[event_type].append((method, instance, priority))
|
||||
# 按优先级排序,优先级高的在前
|
||||
cls._handlers[event_type].sort(key=lambda x: x[2], reverse=True)
|
||||
registered_count += 1
|
||||
logger.debug(f"[EventManager] 注册事件处理器: {instance.__class__.__name__}.{method_name} -> {event_type} (优先级={priority})")
|
||||
扫描实例的所有方法,将带有 _event_type 属性的方法注册为事件处理器。
|
||||
|
||||
Args:
|
||||
instance: 插件实例
|
||||
"""
|
||||
registered_count = 0
|
||||
|
||||
for method_name in dir(instance):
|
||||
if method_name.startswith('_'):
|
||||
continue
|
||||
|
||||
try:
|
||||
method = getattr(instance, method_name)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if not callable(method) or not hasattr(method, '_event_type'):
|
||||
continue
|
||||
|
||||
event_type = getattr(method, '_event_type')
|
||||
priority = getattr(method, '_priority', 50)
|
||||
|
||||
handler_info = HandlerInfo(
|
||||
handler=method,
|
||||
instance=instance,
|
||||
priority=priority,
|
||||
)
|
||||
|
||||
if event_type not in cls._handlers:
|
||||
cls._handlers[event_type] = []
|
||||
|
||||
cls._handlers[event_type].append(handler_info)
|
||||
|
||||
# 使缓存失效
|
||||
cls._cache_valid.discard(event_type)
|
||||
|
||||
registered_count += 1
|
||||
logger.debug(
|
||||
f"[EventManager] 注册: {handler_info.handler_name} -> "
|
||||
f"{event_type} (优先级={priority})"
|
||||
)
|
||||
|
||||
if registered_count > 0:
|
||||
logger.success(f"[EventManager] {instance.__class__.__name__} 注册了 {registered_count} 个事件处理器")
|
||||
|
||||
@classmethod
|
||||
async def emit(cls, event_type: str, *args, **kwargs) -> None:
|
||||
"""触发事件"""
|
||||
from loguru import logger
|
||||
|
||||
if event_type not in cls._handlers:
|
||||
logger.debug(f"[EventManager] 事件 {event_type} 没有注册的处理器")
|
||||
return
|
||||
|
||||
logger.debug(f"[EventManager] 触发事件: {event_type}, 处理器数量: {len(cls._handlers[event_type])}")
|
||||
|
||||
api_client, message = args
|
||||
for handler, instance, priority in cls._handlers[event_type]:
|
||||
try:
|
||||
logger.debug(f"[EventManager] 调用处理器: {instance.__class__.__name__}.{handler.__name__}")
|
||||
# 不再深拷贝message,让所有处理器共享同一个消息对象
|
||||
# 这样AutoReply设置的标记可以传递给AIChat
|
||||
handler_args = (api_client, message)
|
||||
new_kwargs = kwargs # kwargs也不需要深拷贝
|
||||
|
||||
result = await handler(*handler_args, **new_kwargs)
|
||||
|
||||
if isinstance(result, bool):
|
||||
# True 继续执行 False 停止执行
|
||||
if not result:
|
||||
break
|
||||
else:
|
||||
continue # 我也不知道你返回了个啥玩意,反正继续执行就是了
|
||||
except Exception as e:
|
||||
import traceback
|
||||
logger.error(f"处理器 {handler.__name__} 执行失败: {e}")
|
||||
logger.error(f"详细错误: {traceback.format_exc()}")
|
||||
logger.success(
|
||||
f"[EventManager] {instance.__class__.__name__} "
|
||||
f"注册了 {registered_count} 个事件处理器"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def unbind_instance(cls, instance: object):
|
||||
"""解绑实例的所有事件处理函数"""
|
||||
for event_type in cls._handlers:
|
||||
"""
|
||||
解绑实例的所有事件处理器
|
||||
|
||||
Args:
|
||||
instance: 插件实例
|
||||
"""
|
||||
unbound_count = 0
|
||||
|
||||
for event_type in list(cls._handlers.keys()):
|
||||
original_count = len(cls._handlers[event_type])
|
||||
cls._handlers[event_type] = [
|
||||
(handler, inst, priority)
|
||||
for handler, inst, priority in cls._handlers[event_type]
|
||||
if inst is not instance
|
||||
h for h in cls._handlers[event_type]
|
||||
if h.instance is not instance
|
||||
]
|
||||
removed = original_count - len(cls._handlers[event_type])
|
||||
if removed > 0:
|
||||
unbound_count += removed
|
||||
cls._cache_valid.discard(event_type)
|
||||
|
||||
# 清理空列表
|
||||
if not cls._handlers[event_type]:
|
||||
del cls._handlers[event_type]
|
||||
|
||||
if unbound_count > 0:
|
||||
logger.debug(
|
||||
f"[EventManager] {instance.__class__.__name__} "
|
||||
f"解绑了 {unbound_count} 个事件处理器"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _get_sorted_handlers(cls, event_type: str) -> List[HandlerInfo]:
|
||||
"""获取排序后的处理器列表(带缓存)"""
|
||||
if event_type not in cls._cache_valid:
|
||||
handlers = cls._handlers.get(event_type, [])
|
||||
# 按优先级降序排序
|
||||
cls._handler_cache[event_type] = sorted(
|
||||
handlers,
|
||||
key=lambda h: h.priority,
|
||||
reverse=True
|
||||
)
|
||||
cls._cache_valid.add(event_type)
|
||||
|
||||
return cls._handler_cache.get(event_type, [])
|
||||
|
||||
@classmethod
|
||||
async def emit(cls, event_type: str, *args, **kwargs) -> bool:
|
||||
"""
|
||||
触发事件
|
||||
|
||||
Args:
|
||||
event_type: 事件类型
|
||||
*args: 传递给处理器的位置参数(通常是 api_client, message)
|
||||
**kwargs: 传递给处理器的关键字参数
|
||||
|
||||
Returns:
|
||||
True 表示所有处理器都执行了,False 表示被中断
|
||||
"""
|
||||
handlers = cls._get_sorted_handlers(event_type)
|
||||
|
||||
if not handlers:
|
||||
logger.debug(f"[EventManager] 事件 {event_type} 没有处理器")
|
||||
return True
|
||||
|
||||
# 更新统计
|
||||
stats = cls._stats[event_type]
|
||||
stats.emit_count += 1
|
||||
|
||||
start_time = time.time()
|
||||
all_completed = True
|
||||
|
||||
logger.debug(
|
||||
f"[EventManager] 触发: {event_type}, "
|
||||
f"处理器数量: {len(handlers)}"
|
||||
)
|
||||
|
||||
for handler_info in handlers:
|
||||
stats.handler_calls += 1
|
||||
|
||||
try:
|
||||
logger.debug(f"[EventManager] 调用: {handler_info.handler_name}")
|
||||
|
||||
result = await handler_info.handler(*args, **kwargs)
|
||||
|
||||
# 检查是否中断
|
||||
if result is False:
|
||||
stats.stopped_count += 1
|
||||
all_completed = False
|
||||
logger.debug(
|
||||
f"[EventManager] {handler_info.handler_name} "
|
||||
f"返回 False,中断事件处理"
|
||||
)
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
stats.error_count += 1
|
||||
logger.error(
|
||||
f"[EventManager] {handler_info.handler_name} 执行失败: {e}"
|
||||
)
|
||||
logger.debug(f"详细错误:\n{traceback.format_exc()}")
|
||||
# 继续执行其他处理器
|
||||
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
stats.total_time_ms += elapsed_ms
|
||||
|
||||
return all_completed
|
||||
|
||||
@classmethod
|
||||
async def emit_parallel(
|
||||
cls,
|
||||
event_type: str,
|
||||
*args,
|
||||
max_concurrency: int = 5,
|
||||
**kwargs
|
||||
) -> List[Any]:
|
||||
"""
|
||||
并行触发事件(忽略优先级和中断)
|
||||
|
||||
适用于不需要顺序执行的场景。
|
||||
|
||||
Args:
|
||||
event_type: 事件类型
|
||||
max_concurrency: 最大并发数
|
||||
*args, **kwargs: 传递给处理器的参数
|
||||
|
||||
Returns:
|
||||
所有处理器的返回值列表
|
||||
"""
|
||||
handlers = cls._get_sorted_handlers(event_type)
|
||||
|
||||
if not handlers:
|
||||
return []
|
||||
|
||||
semaphore = asyncio.Semaphore(max_concurrency)
|
||||
|
||||
async def run_handler(handler_info: HandlerInfo):
|
||||
async with semaphore:
|
||||
try:
|
||||
return await handler_info.handler(*args, **kwargs)
|
||||
except Exception as e:
|
||||
logger.error(f"[EventManager] {handler_info.handler_name} 失败: {e}")
|
||||
return None
|
||||
|
||||
tasks = [run_handler(h) for h in handlers]
|
||||
return await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
@classmethod
|
||||
def get_handlers(cls, event_type: str) -> List[str]:
|
||||
"""获取事件的所有处理器名称"""
|
||||
handlers = cls._get_sorted_handlers(event_type)
|
||||
return [h.handler_name for h in handlers]
|
||||
|
||||
@classmethod
|
||||
def get_all_events(cls) -> List[str]:
|
||||
"""获取所有已注册的事件类型"""
|
||||
return list(cls._handlers.keys())
|
||||
|
||||
@classmethod
|
||||
def get_stats(cls, event_type: str = None) -> Dict[str, Any]:
|
||||
"""
|
||||
获取事件统计信息
|
||||
|
||||
Args:
|
||||
event_type: 指定事件类型,None 返回所有
|
||||
|
||||
Returns:
|
||||
统计信息字典
|
||||
"""
|
||||
if event_type:
|
||||
stats = cls._stats.get(event_type, EventStats())
|
||||
return {
|
||||
"emit_count": stats.emit_count,
|
||||
"handler_calls": stats.handler_calls,
|
||||
"total_time_ms": stats.total_time_ms,
|
||||
"avg_time_ms": stats.total_time_ms / max(stats.emit_count, 1),
|
||||
"error_count": stats.error_count,
|
||||
"stopped_count": stats.stopped_count,
|
||||
}
|
||||
|
||||
return {
|
||||
event: cls.get_stats(event)
|
||||
for event in cls._stats.keys()
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def reset_stats(cls):
|
||||
"""重置所有统计"""
|
||||
cls._stats.clear()
|
||||
|
||||
@classmethod
|
||||
def clear(cls):
|
||||
"""清除所有处理器和统计(用于测试)"""
|
||||
cls._handlers.clear()
|
||||
cls._handler_cache.clear()
|
||||
cls._cache_valid.clear()
|
||||
cls._stats.clear()
|
||||
|
||||
|
||||
# ==================== 导出 ====================
|
||||
|
||||
__all__ = ['EventManager', 'HandlerInfo', 'EventStats']
|
||||
|
||||
247
utils/hookbot.py
247
utils/hookbot.py
@@ -2,25 +2,48 @@
|
||||
HookBot - 机器人核心类
|
||||
|
||||
处理消息路由和事件分发
|
||||
职责单一化:仅负责消息流程编排,具体功能委托给专门模块
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import tomllib
|
||||
import time
|
||||
from typing import Dict, Any
|
||||
import random
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from WechatHook import WechatHookClient, MESSAGE_TYPE_MAP, normalize_message
|
||||
from utils.event_manager import EventManager
|
||||
from utils.config_manager import get_bot_config, get_performance_config
|
||||
from utils.message_filter import MessageFilter
|
||||
from utils.message_dedup import MessageDeduplicator
|
||||
from utils.message_stats import MessageStats
|
||||
|
||||
|
||||
class HookBot:
|
||||
"""
|
||||
HookBot 核心类
|
||||
|
||||
负责消息处理、路由和事件分发
|
||||
负责消息处理流程编排:
|
||||
1. 接收消息
|
||||
2. 去重检查
|
||||
3. 格式转换
|
||||
4. 过滤检查
|
||||
5. 事件分发
|
||||
|
||||
具体功能委托给:
|
||||
- MessageDeduplicator: 消息去重
|
||||
- MessageFilter: 消息过滤
|
||||
- MessageStats: 消息统计
|
||||
"""
|
||||
|
||||
# API 响应消息类型(需要忽略)
|
||||
API_RESPONSE_TYPES = {11032, 11174, 11230}
|
||||
|
||||
# 重要消息类型(始终记录日志)
|
||||
IMPORTANT_MESSAGE_TYPES = {
|
||||
11058, 11098, 11099, 11025, # 系统消息、群成员变动、登录信息
|
||||
11051, 11047, 11052, 11055 # 视频、图片、表情、文件消息
|
||||
}
|
||||
|
||||
def __init__(self, client: WechatHookClient):
|
||||
"""
|
||||
初始化 HookBot
|
||||
@@ -29,92 +52,33 @@ class HookBot:
|
||||
client: WechatHookClient 实例
|
||||
"""
|
||||
self.client = client
|
||||
self.wxid = None
|
||||
self.nickname = None
|
||||
self.wxid: Optional[str] = None
|
||||
self.nickname: Optional[str] = None
|
||||
|
||||
# 读取配置
|
||||
with open("main_config.toml", "rb") as f:
|
||||
main_config = tomllib.load(f)
|
||||
# 加载配置
|
||||
bot_config = get_bot_config()
|
||||
perf_config = get_performance_config()
|
||||
|
||||
bot_config = main_config.get("Bot", {})
|
||||
# 预设机器人信息
|
||||
preset_wxid = bot_config.get("wxid") or bot_config.get("bot_wxid")
|
||||
preset_nickname = bot_config.get("nickname") or bot_config.get("bot_nickname")
|
||||
|
||||
if preset_wxid:
|
||||
self.wxid = preset_wxid
|
||||
logger.info(f"使用配置中的机器人 wxid: {self.wxid}")
|
||||
if preset_nickname:
|
||||
self.nickname = preset_nickname
|
||||
logger.info(f"使用配置中的机器人昵称: {self.nickname}")
|
||||
self.ignore_mode = bot_config.get("ignore-mode", "None")
|
||||
self.whitelist = bot_config.get("whitelist", [])
|
||||
self.blacklist = bot_config.get("blacklist", [])
|
||||
|
||||
# 性能配置
|
||||
perf_config = main_config.get("Performance", {})
|
||||
# 日志采样率
|
||||
self.log_sampling_rate = perf_config.get("log_sampling_rate", 1.0)
|
||||
|
||||
# 消息去重(部分环境会重复回调同一条消息,导致插件回复两次)
|
||||
self._dedup_ttl_seconds = perf_config.get("dedup_ttl_seconds", 30)
|
||||
self._dedup_max_size = perf_config.get("dedup_max_size", 5000)
|
||||
self._dedup_lock = asyncio.Lock()
|
||||
self._recent_message_keys: Dict[str, float] = {}
|
||||
|
||||
# 消息计数和统计
|
||||
self.message_count = 0
|
||||
self.filtered_count = 0
|
||||
self.processed_count = 0
|
||||
# 初始化组件(职责委托)
|
||||
self._filter = MessageFilter.from_config(bot_config)
|
||||
self._dedup = MessageDeduplicator.from_config(perf_config)
|
||||
self._stats = MessageStats()
|
||||
|
||||
logger.info("HookBot 初始化完成")
|
||||
|
||||
def _extract_msg_id(self, data: Dict[str, Any]) -> str:
|
||||
"""从原始回调数据中提取消息ID(用于去重)"""
|
||||
for k in ("msgid", "msg_id", "MsgId", "id"):
|
||||
v = data.get(k)
|
||||
if v:
|
||||
return str(v)
|
||||
return ""
|
||||
|
||||
async def _is_duplicate_message(self, msg_type: int, data: Dict[str, Any]) -> bool:
|
||||
"""判断该条消息是否为短时间内重复回调。"""
|
||||
msg_id = self._extract_msg_id(data)
|
||||
if not msg_id:
|
||||
# 没有稳定 msgid 时不做去重,避免误伤(同一秒内同内容可能是用户真实重复发送)
|
||||
return False
|
||||
|
||||
key = f"msgid:{msg_id}"
|
||||
|
||||
now = time.time()
|
||||
ttl = max(float(self._dedup_ttl_seconds or 0), 0.0)
|
||||
if ttl <= 0:
|
||||
return False
|
||||
|
||||
async with self._dedup_lock:
|
||||
last_seen = self._recent_message_keys.get(key)
|
||||
if last_seen is not None and (now - last_seen) < ttl:
|
||||
return True
|
||||
|
||||
# 记录/刷新
|
||||
self._recent_message_keys.pop(key, None)
|
||||
self._recent_message_keys[key] = now
|
||||
|
||||
# 清理过期 key(按插入顺序从旧到新)
|
||||
cutoff = now - ttl
|
||||
while self._recent_message_keys:
|
||||
first_key = next(iter(self._recent_message_keys))
|
||||
if self._recent_message_keys.get(first_key, now) >= cutoff:
|
||||
break
|
||||
self._recent_message_keys.pop(first_key, None)
|
||||
|
||||
# 限制大小,避免长期运行内存增长
|
||||
max_size = int(self._dedup_max_size or 0)
|
||||
if max_size > 0:
|
||||
while len(self._recent_message_keys) > max_size and self._recent_message_keys:
|
||||
first_key = next(iter(self._recent_message_keys))
|
||||
self._recent_message_keys.pop(first_key, None)
|
||||
|
||||
return False
|
||||
|
||||
def update_profile(self, wxid: str, nickname: str):
|
||||
"""
|
||||
更新机器人信息
|
||||
@@ -125,9 +89,10 @@ class HookBot:
|
||||
"""
|
||||
self.wxid = wxid
|
||||
self.nickname = nickname
|
||||
self._filter.set_bot_wxid(wxid)
|
||||
logger.info(f"机器人信息: wxid={wxid}, nickname={nickname}")
|
||||
|
||||
async def process_message(self, msg_type: int, data: dict):
|
||||
async def process_message(self, msg_type: int, data: Dict[str, Any]):
|
||||
"""
|
||||
处理接收到的消息
|
||||
|
||||
@@ -135,131 +100,105 @@ class HookBot:
|
||||
msg_type: 消息类型
|
||||
data: 消息数据
|
||||
"""
|
||||
# 过滤 API 响应消息
|
||||
# - 11032: 获取群成员信息响应
|
||||
# - 11174/11230: 协议/上传等 API 回调
|
||||
if msg_type in [11032, 11174, 11230]:
|
||||
# 1. 过滤 API 响应消息
|
||||
if msg_type in self.API_RESPONSE_TYPES:
|
||||
return
|
||||
|
||||
# 去重:同一条消息重复回调时不再重复触发事件(避免“同一句话回复两次”)
|
||||
# 2. 去重检查
|
||||
try:
|
||||
if await self._is_duplicate_message(msg_type, data):
|
||||
logger.debug(f"[HookBot] 重复消息已丢弃: type={msg_type}, msgid={self._extract_msg_id(data) or 'N/A'}")
|
||||
if await self._dedup.is_duplicate(data):
|
||||
self._stats.record_duplicate()
|
||||
logger.debug(f"[HookBot] 重复消息已丢弃: type={msg_type}")
|
||||
return
|
||||
except Exception as e:
|
||||
# 去重失败不影响主流程
|
||||
logger.debug(f"[HookBot] 消息去重检查失败: {e}")
|
||||
|
||||
# 消息计数
|
||||
self.message_count += 1
|
||||
|
||||
# 日志采样 - 只记录部分消息以减少日志量
|
||||
# 3. 记录收到消息
|
||||
self._stats.record_received()
|
||||
should_log = self._should_log_message(msg_type)
|
||||
|
||||
if should_log:
|
||||
logger.debug(f"处理消息: type={msg_type}")
|
||||
|
||||
# 重要事件始终记录
|
||||
if msg_type in [11098, 11099, 11058]: # 群成员变动、系统消息
|
||||
if msg_type in self.IMPORTANT_MESSAGE_TYPES:
|
||||
logger.info(f"重要事件: type={msg_type}")
|
||||
|
||||
# 获取事件类型
|
||||
# 4. 获取事件类型
|
||||
event_type = MESSAGE_TYPE_MAP.get(msg_type)
|
||||
|
||||
|
||||
if should_log and event_type:
|
||||
logger.info(f"[HookBot] 消息类型映射: {msg_type} -> {event_type}")
|
||||
|
||||
if not event_type:
|
||||
# 记录未知消息类型的详细信息,帮助调试
|
||||
content_preview = str(data.get('raw_msg', data.get('msg', '')))[:200]
|
||||
logger.warning(f"未映射的消息类型: {msg_type}, wx_type: {data.get('wx_type')}, 内容预览: {content_preview}")
|
||||
logger.warning(
|
||||
f"未映射的消息类型: {msg_type}, "
|
||||
f"wx_type: {data.get('wx_type')}, "
|
||||
f"内容预览: {content_preview}"
|
||||
)
|
||||
return
|
||||
|
||||
# 格式转换
|
||||
# 5. 格式转换
|
||||
try:
|
||||
message = normalize_message(msg_type, data)
|
||||
except Exception as e:
|
||||
logger.error(f"格式转换失败: {e}")
|
||||
self._stats.record_error()
|
||||
return
|
||||
|
||||
# 过滤消息
|
||||
if not self._check_filter(message):
|
||||
self.filtered_count += 1
|
||||
# 6. 过滤检查
|
||||
if not self._filter.should_process(message):
|
||||
self._stats.record_filtered()
|
||||
if should_log:
|
||||
logger.debug(f"消息被过滤: {message.get('FromWxid')}")
|
||||
return
|
||||
|
||||
self.processed_count += 1
|
||||
# 7. 记录处理
|
||||
self._stats.record_processed(event_type)
|
||||
|
||||
# 采样记录处理的消息
|
||||
if should_log:
|
||||
content = message.get('Content', '')
|
||||
if len(content) > 50:
|
||||
content = content[:50] + "..."
|
||||
logger.info(f"处理消息: type={event_type}, from={message.get('FromWxid')}, content={content}")
|
||||
logger.info(
|
||||
f"处理消息: type={event_type}, "
|
||||
f"from={message.get('FromWxid')}, "
|
||||
f"content={content}"
|
||||
)
|
||||
|
||||
# 触发事件
|
||||
# 8. 触发事件
|
||||
try:
|
||||
await EventManager.emit(event_type, self.client, message)
|
||||
except Exception as e:
|
||||
logger.error(f"事件处理失败: {e}")
|
||||
self._stats.record_error()
|
||||
|
||||
def _should_log_message(self, msg_type: int) -> bool:
|
||||
"""判断是否应该记录此消息的日志"""
|
||||
# 重要消息类型始终记录
|
||||
important_types = {
|
||||
11058, 11098, 11099, 11025, # 系统消息、群成员变动、登录信息
|
||||
11051, 11047, 11052, 11055 # 视频、图片、表情、文件消息
|
||||
}
|
||||
if msg_type in important_types:
|
||||
if msg_type in self.IMPORTANT_MESSAGE_TYPES:
|
||||
return True
|
||||
|
||||
# 其他消息按采样率记录
|
||||
import random
|
||||
return random.random() < self.log_sampling_rate
|
||||
|
||||
def _check_filter(self, message: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
检查消息是否通过过滤
|
||||
|
||||
Args:
|
||||
message: 消息字典
|
||||
|
||||
Returns:
|
||||
是否通过过滤
|
||||
"""
|
||||
from_wxid = message.get("FromWxid", "")
|
||||
sender_wxid = message.get("SenderWxid", "")
|
||||
msg_type = message.get("MsgType", 0)
|
||||
|
||||
# 系统消息(type=11058)不过滤,因为包含重要的群聊事件信息
|
||||
if msg_type == 11058:
|
||||
return True
|
||||
|
||||
# 过滤机器人自己发送的消息,避免无限循环
|
||||
if self.wxid and (from_wxid == self.wxid or sender_wxid == self.wxid):
|
||||
return False
|
||||
|
||||
# None 模式:处理所有消息
|
||||
if self.ignore_mode == "None":
|
||||
return True
|
||||
|
||||
# Whitelist 模式:仅处理白名单
|
||||
if self.ignore_mode == "Whitelist":
|
||||
return from_wxid in self.whitelist or sender_wxid in self.whitelist
|
||||
|
||||
# Blacklist 模式:屏蔽黑名单
|
||||
if self.ignore_mode == "Blacklist":
|
||||
return from_wxid not in self.blacklist and sender_wxid not in self.blacklist
|
||||
|
||||
return True
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""获取消息处理统计信息"""
|
||||
return {
|
||||
"total_messages": self.message_count,
|
||||
"filtered_messages": self.filtered_count,
|
||||
"processed_messages": self.processed_count,
|
||||
"filter_rate": self.filtered_count / max(self.message_count, 1),
|
||||
"process_rate": self.processed_count / max(self.message_count, 1)
|
||||
}
|
||||
stats = self._stats.get_stats()
|
||||
stats["dedup"] = self._dedup.get_stats()
|
||||
return stats
|
||||
|
||||
# ==================== 兼容旧接口 ====================
|
||||
|
||||
@property
|
||||
def message_count(self) -> int:
|
||||
"""兼容旧接口:总消息数"""
|
||||
return self._stats.get_stats()["total_messages"]
|
||||
|
||||
@property
|
||||
def filtered_count(self) -> int:
|
||||
"""兼容旧接口:被过滤消息数"""
|
||||
return self._stats.get_stats()["filtered_messages"]
|
||||
|
||||
@property
|
||||
def processed_count(self) -> int:
|
||||
"""兼容旧接口:已处理消息数"""
|
||||
return self._stats.get_stats()["processed_messages"]
|
||||
|
||||
690
utils/image_processor.py
Normal file
690
utils/image_processor.py
Normal file
@@ -0,0 +1,690 @@
|
||||
"""
|
||||
图片/视频处理模块
|
||||
|
||||
提供媒体文件的下载、编码和描述生成:
|
||||
- 图片下载与 base64 编码
|
||||
- 表情包下载与编码
|
||||
- 视频下载与编码
|
||||
- AI 图片/视频描述生成
|
||||
|
||||
使用示例:
|
||||
from utils.image_processor import ImageProcessor, MediaConfig
|
||||
|
||||
config = MediaConfig(
|
||||
api_url="https://api.openai.com/v1/chat/completions",
|
||||
api_key="sk-xxx",
|
||||
model="gpt-4-vision-preview",
|
||||
)
|
||||
processor = ImageProcessor(config)
|
||||
|
||||
# 下载图片
|
||||
image_base64 = await processor.download_image(bot, cdnurl, aeskey)
|
||||
|
||||
# 生成描述
|
||||
description = await processor.generate_description(image_base64, "描述这张图片")
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, TYPE_CHECKING
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
|
||||
# 可选代理支持
|
||||
try:
|
||||
from aiohttp_socks import ProxyConnector
|
||||
PROXY_SUPPORT = True
|
||||
except ImportError:
|
||||
PROXY_SUPPORT = False
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass # bot 类型提示
|
||||
|
||||
|
||||
@dataclass
|
||||
class MediaConfig:
|
||||
"""媒体处理配置"""
|
||||
# API 配置
|
||||
api_url: str = "https://api.openai.com/v1/chat/completions"
|
||||
api_key: str = ""
|
||||
model: str = "gpt-4-vision-preview"
|
||||
timeout: int = 120
|
||||
max_tokens: int = 1000
|
||||
retries: int = 2
|
||||
|
||||
# 代理配置
|
||||
proxy_enabled: bool = False
|
||||
proxy_type: str = "socks5"
|
||||
proxy_host: str = "127.0.0.1"
|
||||
proxy_port: int = 7890
|
||||
proxy_username: str = ""
|
||||
proxy_password: str = ""
|
||||
|
||||
# 视频专用配置
|
||||
video_api_url: str = ""
|
||||
video_model: str = ""
|
||||
video_max_size_mb: int = 20
|
||||
video_timeout: int = 360
|
||||
video_max_tokens: int = 8192
|
||||
|
||||
# 临时目录
|
||||
temp_dir: Optional[Path] = None
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, config: Dict[str, Any]) -> "MediaConfig":
|
||||
"""从配置字典创建"""
|
||||
api_config = config.get("api", {})
|
||||
proxy_config = config.get("proxy", {})
|
||||
image_desc_config = config.get("image_description", {})
|
||||
video_config = config.get("video_recognition", {})
|
||||
|
||||
return cls(
|
||||
api_url=api_config.get("url", "https://api.openai.com/v1/chat/completions"),
|
||||
api_key=api_config.get("api_key", ""),
|
||||
model=image_desc_config.get("model", api_config.get("model", "gpt-4-vision-preview")),
|
||||
timeout=api_config.get("timeout", 120),
|
||||
max_tokens=image_desc_config.get("max_tokens", 1000),
|
||||
retries=image_desc_config.get("retries", 2),
|
||||
proxy_enabled=proxy_config.get("enabled", False),
|
||||
proxy_type=proxy_config.get("type", "socks5"),
|
||||
proxy_host=proxy_config.get("host", "127.0.0.1"),
|
||||
proxy_port=proxy_config.get("port", 7890),
|
||||
proxy_username=proxy_config.get("username", ""),
|
||||
proxy_password=proxy_config.get("password", ""),
|
||||
video_api_url=video_config.get("api_url", ""),
|
||||
video_model=video_config.get("model", ""),
|
||||
video_max_size_mb=video_config.get("max_size_mb", 20),
|
||||
video_timeout=video_config.get("timeout", 360),
|
||||
video_max_tokens=video_config.get("max_tokens", 8192),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MediaResult:
|
||||
"""媒体处理结果"""
|
||||
success: bool = False
|
||||
data: str = "" # base64 数据
|
||||
description: str = ""
|
||||
error: Optional[str] = None
|
||||
media_type: str = "image" # image, emoji, video
|
||||
|
||||
|
||||
class ImageProcessor:
|
||||
"""
|
||||
图片/视频处理器
|
||||
|
||||
提供统一的媒体处理接口:
|
||||
- 下载和编码
|
||||
- AI 描述生成
|
||||
- 缓存支持
|
||||
"""
|
||||
|
||||
def __init__(self, config: MediaConfig, temp_dir: Optional[Path] = None):
|
||||
self.config = config
|
||||
self.temp_dir = temp_dir or config.temp_dir or Path("temp")
|
||||
self.temp_dir.mkdir(exist_ok=True)
|
||||
|
||||
def _get_proxy_connector(self) -> Optional[Any]:
|
||||
"""获取代理连接器"""
|
||||
if not self.config.proxy_enabled or not PROXY_SUPPORT:
|
||||
return None
|
||||
|
||||
proxy_type = self.config.proxy_type.upper()
|
||||
if self.config.proxy_username and self.config.proxy_password:
|
||||
proxy_url = (
|
||||
f"{proxy_type}://{self.config.proxy_username}:"
|
||||
f"{self.config.proxy_password}@"
|
||||
f"{self.config.proxy_host}:{self.config.proxy_port}"
|
||||
)
|
||||
else:
|
||||
proxy_url = f"{proxy_type}://{self.config.proxy_host}:{self.config.proxy_port}"
|
||||
|
||||
try:
|
||||
return ProxyConnector.from_url(proxy_url)
|
||||
except Exception as e:
|
||||
logger.warning(f"[ImageProcessor] 代理配置失败: {e}")
|
||||
return None
|
||||
|
||||
async def download_image(
|
||||
self,
|
||||
bot,
|
||||
cdnurl: str,
|
||||
aeskey: str,
|
||||
use_cache: bool = True,
|
||||
) -> str:
|
||||
"""
|
||||
下载图片并转换为 base64
|
||||
|
||||
Args:
|
||||
bot: WechatHookClient 实例(用于 CDN 下载)
|
||||
cdnurl: CDN URL
|
||||
aeskey: AES 密钥
|
||||
use_cache: 是否使用缓存
|
||||
|
||||
Returns:
|
||||
base64 编码的图片数据(带 data URI 前缀)
|
||||
"""
|
||||
try:
|
||||
# 1. 优先从 Redis 缓存获取
|
||||
if use_cache:
|
||||
from utils.redis_cache import RedisCache, get_cache
|
||||
redis_cache = get_cache()
|
||||
if redis_cache and redis_cache.enabled:
|
||||
media_key = RedisCache.generate_media_key(cdnurl, aeskey)
|
||||
if media_key:
|
||||
cached_data = redis_cache.get_cached_media(media_key, "image")
|
||||
if cached_data:
|
||||
logger.debug(f"[ImageProcessor] 图片缓存命中: {media_key[:20]}...")
|
||||
return cached_data
|
||||
|
||||
# 2. 缓存未命中,下载图片
|
||||
logger.debug(f"[ImageProcessor] 开始下载图片...")
|
||||
|
||||
filename = f"temp_{uuid.uuid4().hex[:8]}.jpg"
|
||||
save_path = str((self.temp_dir / filename).resolve())
|
||||
|
||||
# 尝试下载中图,失败则下载原图
|
||||
success = await bot.cdn_download(cdnurl, aeskey, save_path, file_type=2)
|
||||
if not success:
|
||||
success = await bot.cdn_download(cdnurl, aeskey, save_path, file_type=1)
|
||||
|
||||
if not success:
|
||||
logger.error("[ImageProcessor] CDN 下载失败")
|
||||
return ""
|
||||
|
||||
# 等待文件写入完成
|
||||
import os
|
||||
for _ in range(20): # 最多等待10秒
|
||||
if os.path.exists(save_path) and os.path.getsize(save_path) > 0:
|
||||
break
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
if not os.path.exists(save_path):
|
||||
logger.error("[ImageProcessor] 图片文件未生成")
|
||||
return ""
|
||||
|
||||
with open(save_path, "rb") as f:
|
||||
image_data = base64.b64encode(f.read()).decode()
|
||||
|
||||
base64_result = f"data:image/jpeg;base64,{image_data}"
|
||||
|
||||
# 3. 缓存到 Redis
|
||||
if use_cache:
|
||||
try:
|
||||
from utils.redis_cache import RedisCache, get_cache
|
||||
redis_cache = get_cache()
|
||||
if redis_cache and redis_cache.enabled:
|
||||
media_key = RedisCache.generate_media_key(cdnurl, aeskey)
|
||||
if media_key:
|
||||
redis_cache.cache_media(media_key, base64_result, "image", ttl=300)
|
||||
logger.debug(f"[ImageProcessor] 图片已缓存: {media_key[:20]}...")
|
||||
except Exception as e:
|
||||
logger.debug(f"[ImageProcessor] 缓存图片失败: {e}")
|
||||
|
||||
# 清理临时文件
|
||||
try:
|
||||
Path(save_path).unlink()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return base64_result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[ImageProcessor] 下载图片失败: {e}")
|
||||
return ""
|
||||
|
||||
async def download_emoji(
|
||||
self,
|
||||
cdn_url: str,
|
||||
max_retries: int = 3,
|
||||
use_cache: bool = True,
|
||||
) -> str:
|
||||
"""
|
||||
下载表情包并转换为 base64
|
||||
|
||||
Args:
|
||||
cdn_url: CDN URL
|
||||
max_retries: 最大重试次数
|
||||
use_cache: 是否使用缓存
|
||||
|
||||
Returns:
|
||||
base64 编码的表情包数据(带 data URI 前缀)
|
||||
"""
|
||||
# 替换 HTML 实体
|
||||
cdn_url = cdn_url.replace("&", "&")
|
||||
|
||||
# 1. 优先从 Redis 缓存获取
|
||||
media_key = None
|
||||
if use_cache:
|
||||
try:
|
||||
from utils.redis_cache import RedisCache, get_cache
|
||||
redis_cache = get_cache()
|
||||
media_key = RedisCache.generate_media_key(cdnurl=cdn_url)
|
||||
if redis_cache and redis_cache.enabled and media_key:
|
||||
cached_data = redis_cache.get_cached_media(media_key, "emoji")
|
||||
if cached_data:
|
||||
logger.debug(f"[ImageProcessor] 表情包缓存命中: {media_key[:20]}...")
|
||||
return cached_data
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 2. 缓存未命中,下载表情包
|
||||
logger.debug(f"[ImageProcessor] 开始下载表情包...")
|
||||
|
||||
last_error = None
|
||||
connector = self._get_proxy_connector()
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
timeout = aiohttp.ClientTimeout(total=30 + attempt * 15)
|
||||
|
||||
async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
|
||||
async with session.get(cdn_url) as response:
|
||||
if response.status == 200:
|
||||
content = await response.read()
|
||||
|
||||
if len(content) == 0:
|
||||
logger.warning(f"[ImageProcessor] 表情包内容为空,重试 {attempt + 1}/{max_retries}")
|
||||
continue
|
||||
|
||||
image_data = base64.b64encode(content).decode()
|
||||
base64_result = f"data:image/gif;base64,{image_data}"
|
||||
|
||||
logger.debug(f"[ImageProcessor] 表情包下载成功,大小: {len(content)} 字节")
|
||||
|
||||
# 3. 缓存到 Redis
|
||||
if use_cache and media_key:
|
||||
try:
|
||||
from utils.redis_cache import get_cache
|
||||
redis_cache = get_cache()
|
||||
if redis_cache and redis_cache.enabled:
|
||||
redis_cache.cache_media(media_key, base64_result, "emoji", ttl=300)
|
||||
logger.debug(f"[ImageProcessor] 表情包已缓存: {media_key[:20]}...")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return base64_result
|
||||
else:
|
||||
logger.warning(f"[ImageProcessor] 表情包下载失败,状态码: {response.status}")
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
last_error = "请求超时"
|
||||
logger.warning(f"[ImageProcessor] 表情包下载超时,重试 {attempt + 1}/{max_retries}")
|
||||
except aiohttp.ClientError as e:
|
||||
last_error = str(e)
|
||||
logger.warning(f"[ImageProcessor] 表情包下载网络错误: {e}")
|
||||
except Exception as e:
|
||||
last_error = str(e)
|
||||
logger.warning(f"[ImageProcessor] 表情包下载异常: {e}")
|
||||
|
||||
if attempt < max_retries - 1:
|
||||
await asyncio.sleep(1 * (attempt + 1))
|
||||
|
||||
logger.error(f"[ImageProcessor] 表情包下载失败,已重试 {max_retries} 次: {last_error}")
|
||||
return ""
|
||||
|
||||
async def download_video(
|
||||
self,
|
||||
bot,
|
||||
cdnurl: str,
|
||||
aeskey: str,
|
||||
use_cache: bool = True,
|
||||
) -> str:
|
||||
"""
|
||||
下载视频并转换为 base64
|
||||
|
||||
Args:
|
||||
bot: WechatHookClient 实例
|
||||
cdnurl: CDN URL
|
||||
aeskey: AES 密钥
|
||||
use_cache: 是否使用缓存
|
||||
|
||||
Returns:
|
||||
base64 编码的视频数据(带 data URI 前缀)
|
||||
"""
|
||||
try:
|
||||
# 从缓存获取
|
||||
media_key = None
|
||||
if use_cache:
|
||||
try:
|
||||
from utils.redis_cache import RedisCache, get_cache
|
||||
redis_cache = get_cache()
|
||||
if redis_cache and redis_cache.enabled:
|
||||
media_key = RedisCache.generate_media_key(cdnurl, aeskey)
|
||||
if media_key:
|
||||
cached_data = redis_cache.get_cached_media(media_key, "video")
|
||||
if cached_data:
|
||||
logger.debug(f"[ImageProcessor] 视频缓存命中: {media_key[:20]}...")
|
||||
return cached_data
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 下载视频
|
||||
logger.info(f"[ImageProcessor] 开始下载视频...")
|
||||
|
||||
filename = f"video_{uuid.uuid4().hex[:8]}.mp4"
|
||||
save_path = str((self.temp_dir / filename).resolve())
|
||||
|
||||
# file_type=4 表示视频
|
||||
success = await bot.cdn_download(cdnurl, aeskey, save_path, file_type=4)
|
||||
if not success:
|
||||
logger.error("[ImageProcessor] 视频 CDN 下载失败")
|
||||
return ""
|
||||
|
||||
# 等待文件写入完成
|
||||
import os
|
||||
for _ in range(30):
|
||||
if os.path.exists(save_path) and os.path.getsize(save_path) > 0:
|
||||
break
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
if not os.path.exists(save_path):
|
||||
logger.error("[ImageProcessor] 视频文件未生成")
|
||||
return ""
|
||||
|
||||
file_size = os.path.getsize(save_path)
|
||||
logger.info(f"[ImageProcessor] 视频下载完成,大小: {file_size / 1024 / 1024:.2f} MB")
|
||||
|
||||
# 检查文件大小限制
|
||||
max_size_mb = self.config.video_max_size_mb
|
||||
if file_size > max_size_mb * 1024 * 1024:
|
||||
logger.warning(f"[ImageProcessor] 视频文件过大: {file_size / 1024 / 1024:.2f} MB > {max_size_mb} MB")
|
||||
try:
|
||||
Path(save_path).unlink()
|
||||
except Exception:
|
||||
pass
|
||||
return ""
|
||||
|
||||
# 读取并编码
|
||||
with open(save_path, "rb") as f:
|
||||
video_data = base64.b64encode(f.read()).decode()
|
||||
|
||||
video_base64 = f"data:video/mp4;base64,{video_data}"
|
||||
|
||||
# 缓存到 Redis
|
||||
if use_cache and media_key:
|
||||
try:
|
||||
from utils.redis_cache import get_cache
|
||||
redis_cache = get_cache()
|
||||
if redis_cache and redis_cache.enabled:
|
||||
redis_cache.cache_media(media_key, video_base64, "video", ttl=600)
|
||||
logger.debug(f"[ImageProcessor] 视频已缓存: {media_key[:20]}...")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 清理临时文件
|
||||
try:
|
||||
Path(save_path).unlink()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return video_base64
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[ImageProcessor] 下载视频失败: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
return ""
|
||||
|
||||
async def generate_description(
|
||||
self,
|
||||
image_base64: str,
|
||||
prompt: str = "请用一句话简洁地描述这张图片的主要内容。",
|
||||
model: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
使用 AI 生成图片描述
|
||||
|
||||
Args:
|
||||
image_base64: 图片的 base64 数据
|
||||
prompt: 描述提示词
|
||||
model: 使用的模型(默认使用配置中的模型)
|
||||
|
||||
Returns:
|
||||
图片描述文本,失败返回空字符串
|
||||
"""
|
||||
description_model = model or self.config.model
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": prompt},
|
||||
{"type": "image_url", "image_url": {"url": image_base64}}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
payload = {
|
||||
"model": description_model,
|
||||
"messages": messages,
|
||||
"max_tokens": self.config.max_tokens,
|
||||
"stream": True
|
||||
}
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.config.api_key}"
|
||||
}
|
||||
|
||||
max_retries = self.config.retries
|
||||
last_error = None
|
||||
|
||||
for attempt in range(max_retries + 1):
|
||||
try:
|
||||
timeout = aiohttp.ClientTimeout(total=self.config.timeout)
|
||||
connector = self._get_proxy_connector()
|
||||
|
||||
async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
|
||||
async with session.post(
|
||||
self.config.api_url,
|
||||
json=payload,
|
||||
headers=headers
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
error_text = await resp.text()
|
||||
raise Exception(f"API 返回错误: {resp.status}, {error_text[:200]}")
|
||||
|
||||
# 流式接收响应
|
||||
description = ""
|
||||
async for line in resp.content:
|
||||
line = line.decode('utf-8').strip()
|
||||
if not line or line == "data: [DONE]":
|
||||
continue
|
||||
|
||||
if line.startswith("data: "):
|
||||
try:
|
||||
data = json.loads(line[6:])
|
||||
delta = data.get("choices", [{}])[0].get("delta", {})
|
||||
content = delta.get("content", "")
|
||||
if content:
|
||||
description += content
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.debug(f"[ImageProcessor] 图片描述生成成功: {description[:50]}...")
|
||||
return description.strip()
|
||||
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except (aiohttp.ClientError, asyncio.TimeoutError) as e:
|
||||
last_error = str(e)
|
||||
if attempt < max_retries:
|
||||
logger.warning(f"[ImageProcessor] 图片描述网络错误: {e},重试 {attempt + 1}/{max_retries}")
|
||||
await asyncio.sleep(1 * (attempt + 1))
|
||||
continue
|
||||
except Exception as e:
|
||||
last_error = str(e)
|
||||
if attempt < max_retries:
|
||||
logger.warning(f"[ImageProcessor] 图片描述生成异常: {e},重试 {attempt + 1}/{max_retries}")
|
||||
await asyncio.sleep(1 * (attempt + 1))
|
||||
continue
|
||||
|
||||
logger.error(f"[ImageProcessor] 生成图片描述失败,已重试 {max_retries + 1} 次: {last_error}")
|
||||
return ""
|
||||
|
||||
async def analyze_video(
|
||||
self,
|
||||
video_base64: str,
|
||||
prompt: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
使用 AI 分析视频内容
|
||||
|
||||
Args:
|
||||
video_base64: 视频的 base64 数据
|
||||
prompt: 分析提示词
|
||||
|
||||
Returns:
|
||||
视频分析描述,失败返回空字符串
|
||||
"""
|
||||
if not self.config.video_api_url or not self.config.video_model:
|
||||
logger.error("[ImageProcessor] 视频分析配置不完整")
|
||||
return ""
|
||||
|
||||
# 去除 data:video/mp4;base64, 前缀(如果有)
|
||||
if video_base64.startswith("data:"):
|
||||
video_base64 = video_base64.split(",", 1)[1]
|
||||
|
||||
default_prompt = """请详细分析这个视频的内容,包括:
|
||||
1. 视频的主要场景和环境
|
||||
2. 出现的人物/物体及其动作
|
||||
3. 视频中的文字、对话或声音(如果有)
|
||||
4. 视频的整体主题或要表达的内容
|
||||
5. 任何值得注意的细节
|
||||
|
||||
请用客观、详细的方式描述,不要加入主观评价。"""
|
||||
|
||||
analyze_prompt = prompt or default_prompt
|
||||
|
||||
full_url = f"{self.config.video_api_url}/{self.config.video_model}:generateContent"
|
||||
|
||||
payload = {
|
||||
"contents": [
|
||||
{
|
||||
"parts": [
|
||||
{"text": analyze_prompt},
|
||||
{
|
||||
"inline_data": {
|
||||
"mime_type": "video/mp4",
|
||||
"data": video_base64
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"generationConfig": {
|
||||
"maxOutputTokens": self.config.video_max_tokens
|
||||
}
|
||||
}
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.config.api_key}"
|
||||
}
|
||||
|
||||
timeout = aiohttp.ClientTimeout(total=self.config.video_timeout)
|
||||
max_retries = 2
|
||||
retry_delay = 5
|
||||
|
||||
for attempt in range(max_retries + 1):
|
||||
try:
|
||||
logger.info(f"[ImageProcessor] 开始分析视频...{f' (重试 {attempt}/{max_retries})' if attempt > 0 else ''}")
|
||||
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
async with session.post(full_url, json=payload, headers=headers) as resp:
|
||||
if resp.status in [502, 503, 504]:
|
||||
logger.warning(f"[ImageProcessor] 视频 API 临时错误: {resp.status}")
|
||||
if attempt < max_retries:
|
||||
await asyncio.sleep(retry_delay)
|
||||
continue
|
||||
return ""
|
||||
|
||||
if resp.status != 200:
|
||||
error_text = await resp.text()
|
||||
logger.error(f"[ImageProcessor] 视频 API 错误: {resp.status}, {error_text[:300]}")
|
||||
return ""
|
||||
|
||||
result = await resp.json()
|
||||
|
||||
# 检查安全过滤
|
||||
if "promptFeedback" in result:
|
||||
feedback = result["promptFeedback"]
|
||||
if feedback.get("blockReason"):
|
||||
logger.warning(f"[ImageProcessor] 视频内容被过滤: {feedback.get('blockReason')}")
|
||||
return ""
|
||||
|
||||
# 提取文本
|
||||
if "candidates" in result and result["candidates"]:
|
||||
for candidate in result["candidates"]:
|
||||
if candidate.get("finishReason") == "SAFETY":
|
||||
logger.warning("[ImageProcessor] 视频响应被安全过滤")
|
||||
return ""
|
||||
|
||||
content = candidate.get("content", {})
|
||||
for part in content.get("parts", []):
|
||||
if "text" in part:
|
||||
text = part["text"]
|
||||
logger.info(f"[ImageProcessor] 视频分析完成,长度: {len(text)}")
|
||||
return text
|
||||
|
||||
logger.error(f"[ImageProcessor] 视频分析无有效响应")
|
||||
return ""
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"[ImageProcessor] 视频分析超时{f', 将重试...' if attempt < max_retries else ''}")
|
||||
if attempt < max_retries:
|
||||
await asyncio.sleep(retry_delay)
|
||||
continue
|
||||
return ""
|
||||
except Exception as e:
|
||||
logger.error(f"[ImageProcessor] 视频分析失败: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
return ""
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
# ==================== 便捷函数 ====================
|
||||
|
||||
_default_processor: Optional[ImageProcessor] = None
|
||||
|
||||
|
||||
def get_image_processor(config: Optional[MediaConfig] = None) -> ImageProcessor:
|
||||
"""获取默认图片处理器"""
|
||||
global _default_processor
|
||||
if config:
|
||||
_default_processor = ImageProcessor(config)
|
||||
if _default_processor is None:
|
||||
raise ValueError("ImageProcessor 未初始化,请先传入配置")
|
||||
return _default_processor
|
||||
|
||||
|
||||
def init_image_processor(config_dict: Dict[str, Any], temp_dir: Optional[Path] = None) -> ImageProcessor:
|
||||
"""从配置字典初始化图片处理器"""
|
||||
config = MediaConfig.from_dict(config_dict)
|
||||
if temp_dir:
|
||||
config.temp_dir = temp_dir
|
||||
processor = ImageProcessor(config, temp_dir)
|
||||
global _default_processor
|
||||
_default_processor = processor
|
||||
return processor
|
||||
|
||||
|
||||
# ==================== 导出 ====================
|
||||
|
||||
__all__ = [
|
||||
'MediaConfig',
|
||||
'MediaResult',
|
||||
'ImageProcessor',
|
||||
'get_image_processor',
|
||||
'init_image_processor',
|
||||
]
|
||||
392
utils/llm_client.py
Normal file
392
utils/llm_client.py
Normal file
@@ -0,0 +1,392 @@
|
||||
"""
|
||||
LLM 客户端抽象层
|
||||
|
||||
提供统一的 LLM API 调用接口:
|
||||
- 支持 OpenAI 兼容 API
|
||||
- 自动重试和错误处理
|
||||
- 流式/非流式响应
|
||||
- 代理支持
|
||||
- Token 估算
|
||||
|
||||
使用示例:
|
||||
from utils.llm_client import LLMClient, LLMConfig
|
||||
|
||||
config = LLMConfig(
|
||||
api_base="https://api.openai.com/v1",
|
||||
api_key="sk-xxx",
|
||||
model="gpt-4",
|
||||
)
|
||||
client = LLMClient(config)
|
||||
|
||||
response = await client.chat_completion(
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
tools=[...],
|
||||
)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
|
||||
# 可选代理支持
|
||||
try:
|
||||
from aiohttp_socks import ProxyConnector
|
||||
PROXY_SUPPORT = True
|
||||
except ImportError:
|
||||
PROXY_SUPPORT = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMConfig:
|
||||
"""LLM 配置"""
|
||||
api_base: str = "https://api.openai.com/v1"
|
||||
api_key: str = ""
|
||||
model: str = "gpt-4"
|
||||
temperature: float = 0.7
|
||||
max_tokens: int = 4096
|
||||
timeout: int = 120
|
||||
max_retries: int = 3
|
||||
retry_delay: float = 1.0
|
||||
|
||||
# 代理配置
|
||||
proxy_enabled: bool = False
|
||||
proxy_type: str = "socks5"
|
||||
proxy_host: str = "127.0.0.1"
|
||||
proxy_port: int = 7890
|
||||
|
||||
# 额外参数
|
||||
extra_params: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, config: Dict[str, Any]) -> "LLMConfig":
|
||||
"""从配置字典创建"""
|
||||
api_config = config.get("api", {})
|
||||
proxy_config = config.get("proxy", {})
|
||||
|
||||
return cls(
|
||||
api_base=api_config.get("base_url", "https://api.openai.com/v1"),
|
||||
api_key=api_config.get("api_key", ""),
|
||||
model=api_config.get("model", "gpt-4"),
|
||||
temperature=api_config.get("temperature", 0.7),
|
||||
max_tokens=api_config.get("max_tokens", 4096),
|
||||
timeout=api_config.get("timeout", 120),
|
||||
max_retries=api_config.get("max_retries", 3),
|
||||
retry_delay=api_config.get("retry_delay", 1.0),
|
||||
proxy_enabled=proxy_config.get("enabled", False),
|
||||
proxy_type=proxy_config.get("type", "socks5"),
|
||||
proxy_host=proxy_config.get("host", "127.0.0.1"),
|
||||
proxy_port=proxy_config.get("port", 7890),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMResponse:
|
||||
"""LLM 响应"""
|
||||
content: str = ""
|
||||
tool_calls: List[Dict[str, Any]] = field(default_factory=list)
|
||||
finish_reason: str = ""
|
||||
usage: Dict[str, int] = field(default_factory=dict)
|
||||
raw_response: Dict[str, Any] = field(default_factory=dict)
|
||||
error: Optional[str] = None
|
||||
|
||||
@property
|
||||
def has_tool_calls(self) -> bool:
|
||||
return len(self.tool_calls) > 0
|
||||
|
||||
@property
|
||||
def success(self) -> bool:
|
||||
return self.error is None
|
||||
|
||||
|
||||
class LLMClient:
|
||||
"""
|
||||
LLM 客户端
|
||||
|
||||
提供统一的 API 调用接口,支持:
|
||||
- OpenAI 兼容 API
|
||||
- 自动重试
|
||||
- 代理
|
||||
- 流式响应
|
||||
"""
|
||||
|
||||
def __init__(self, config: LLMConfig):
|
||||
self.config = config
|
||||
self._session: Optional[aiohttp.ClientSession] = None
|
||||
|
||||
async def _get_session(self) -> aiohttp.ClientSession:
|
||||
"""获取或创建 HTTP 会话"""
|
||||
if self._session is None or self._session.closed:
|
||||
connector = None
|
||||
|
||||
# 配置代理
|
||||
if self.config.proxy_enabled and PROXY_SUPPORT:
|
||||
proxy_url = (
|
||||
f"{self.config.proxy_type}://"
|
||||
f"{self.config.proxy_host}:{self.config.proxy_port}"
|
||||
)
|
||||
connector = ProxyConnector.from_url(proxy_url)
|
||||
logger.debug(f"[LLMClient] 使用代理: {proxy_url}")
|
||||
|
||||
timeout = aiohttp.ClientTimeout(total=self.config.timeout)
|
||||
self._session = aiohttp.ClientSession(
|
||||
connector=connector,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
return self._session
|
||||
|
||||
async def close(self):
|
||||
"""关闭会话"""
|
||||
if self._session and not self._session.closed:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
|
||||
def _build_headers(self) -> Dict[str, str]:
|
||||
"""构建请求头"""
|
||||
return {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.config.api_key}",
|
||||
}
|
||||
|
||||
def _build_payload(
|
||||
self,
|
||||
messages: List[Dict[str, Any]],
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
stream: bool = False,
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
"""构建请求体"""
|
||||
payload = {
|
||||
"model": self.config.model,
|
||||
"messages": messages,
|
||||
"temperature": self.config.temperature,
|
||||
"max_tokens": self.config.max_tokens,
|
||||
"stream": stream,
|
||||
}
|
||||
|
||||
if tools:
|
||||
payload["tools"] = tools
|
||||
payload["tool_choice"] = kwargs.get("tool_choice", "auto")
|
||||
|
||||
# 合并额外参数
|
||||
payload.update(self.config.extra_params)
|
||||
payload.update(kwargs)
|
||||
|
||||
return payload
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
messages: List[Dict[str, Any]],
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
非流式聊天补全
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
tools: 工具列表(可选)
|
||||
**kwargs: 额外参数
|
||||
|
||||
Returns:
|
||||
LLMResponse 对象
|
||||
"""
|
||||
session = await self._get_session()
|
||||
url = f"{self.config.api_base}/chat/completions"
|
||||
headers = self._build_headers()
|
||||
payload = self._build_payload(messages, tools, stream=False, **kwargs)
|
||||
|
||||
last_error = None
|
||||
|
||||
for attempt in range(self.config.max_retries):
|
||||
try:
|
||||
async with session.post(url, headers=headers, json=payload) as resp:
|
||||
if resp.status == 200:
|
||||
data = await resp.json()
|
||||
return self._parse_response(data)
|
||||
|
||||
error_text = await resp.text()
|
||||
last_error = f"HTTP {resp.status}: {error_text[:200]}"
|
||||
logger.warning(f"[LLMClient] 请求失败 (尝试 {attempt + 1}): {last_error}")
|
||||
|
||||
# 不可重试的错误
|
||||
if resp.status in [400, 401, 403]:
|
||||
break
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
last_error = f"请求超时 ({self.config.timeout}s)"
|
||||
logger.warning(f"[LLMClient] {last_error} (尝试 {attempt + 1})")
|
||||
|
||||
except Exception as e:
|
||||
last_error = str(e)
|
||||
logger.warning(f"[LLMClient] 请求异常 (尝试 {attempt + 1}): {e}")
|
||||
|
||||
# 重试延迟
|
||||
if attempt < self.config.max_retries - 1:
|
||||
await asyncio.sleep(self.config.retry_delay * (attempt + 1))
|
||||
|
||||
return LLMResponse(error=last_error)
|
||||
|
||||
async def chat_completion_stream(
|
||||
self,
|
||||
messages: List[Dict[str, Any]],
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
**kwargs,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
流式聊天补全
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
tools: 工具列表(可选)
|
||||
**kwargs: 额外参数
|
||||
|
||||
Yields:
|
||||
文本片段
|
||||
"""
|
||||
session = await self._get_session()
|
||||
url = f"{self.config.api_base}/chat/completions"
|
||||
headers = self._build_headers()
|
||||
payload = self._build_payload(messages, tools, stream=True, **kwargs)
|
||||
|
||||
try:
|
||||
async with session.post(url, headers=headers, json=payload) as resp:
|
||||
if resp.status != 200:
|
||||
error_text = await resp.text()
|
||||
logger.error(f"[LLMClient] 流式请求失败: HTTP {resp.status}")
|
||||
return
|
||||
|
||||
async for line in resp.content:
|
||||
line = line.decode("utf-8").strip()
|
||||
if not line or not line.startswith("data: "):
|
||||
continue
|
||||
|
||||
data_str = line[6:]
|
||||
if data_str == "[DONE]":
|
||||
break
|
||||
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
delta = data.get("choices", [{}])[0].get("delta", {})
|
||||
content = delta.get("content", "")
|
||||
if content:
|
||||
yield content
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[LLMClient] 流式请求异常: {e}")
|
||||
|
||||
def _parse_response(self, data: Dict[str, Any]) -> LLMResponse:
|
||||
"""解析 API 响应"""
|
||||
try:
|
||||
choice = data.get("choices", [{}])[0]
|
||||
message = choice.get("message", {})
|
||||
|
||||
content = message.get("content", "") or ""
|
||||
tool_calls = message.get("tool_calls", [])
|
||||
finish_reason = choice.get("finish_reason", "")
|
||||
usage = data.get("usage", {})
|
||||
|
||||
# 标准化 tool_calls
|
||||
parsed_tool_calls = []
|
||||
for tc in tool_calls:
|
||||
parsed_tool_calls.append({
|
||||
"id": tc.get("id", ""),
|
||||
"type": tc.get("type", "function"),
|
||||
"function": {
|
||||
"name": tc.get("function", {}).get("name", ""),
|
||||
"arguments": tc.get("function", {}).get("arguments", "{}"),
|
||||
}
|
||||
})
|
||||
|
||||
return LLMResponse(
|
||||
content=content,
|
||||
tool_calls=parsed_tool_calls,
|
||||
finish_reason=finish_reason,
|
||||
usage=usage,
|
||||
raw_response=data,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[LLMClient] 解析响应失败: {e}")
|
||||
return LLMResponse(error=f"解析响应失败: {e}")
|
||||
|
||||
# ==================== Token 估算 ====================
|
||||
|
||||
@staticmethod
|
||||
def estimate_tokens(text: str) -> int:
|
||||
"""
|
||||
估算文本的 token 数量
|
||||
|
||||
使用简化规则:
|
||||
- 英文约 4 字符 = 1 token
|
||||
- 中文约 1.5 字符 = 1 token
|
||||
"""
|
||||
if not text:
|
||||
return 0
|
||||
|
||||
chinese_chars = sum(1 for c in text if '\u4e00' <= c <= '\u9fff')
|
||||
other_chars = len(text) - chinese_chars
|
||||
|
||||
chinese_tokens = chinese_chars / 1.5
|
||||
other_tokens = other_chars / 4
|
||||
|
||||
return int(chinese_tokens + other_tokens)
|
||||
|
||||
@staticmethod
|
||||
def estimate_message_tokens(message: Dict[str, Any]) -> int:
|
||||
"""估算单条消息的 token 数量"""
|
||||
content = message.get("content", "")
|
||||
|
||||
if isinstance(content, str):
|
||||
return LLMClient.estimate_tokens(content) + 4 # role 等开销
|
||||
|
||||
if isinstance(content, list):
|
||||
total = 4
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
if item.get("type") == "text":
|
||||
total += LLMClient.estimate_tokens(item.get("text", ""))
|
||||
elif item.get("type") == "image_url":
|
||||
total += 85 # 图片固定开销
|
||||
return total
|
||||
|
||||
return 4
|
||||
|
||||
@staticmethod
|
||||
def estimate_messages_tokens(messages: List[Dict[str, Any]]) -> int:
|
||||
"""估算消息列表的总 token 数量"""
|
||||
return sum(LLMClient.estimate_message_tokens(m) for m in messages)
|
||||
|
||||
|
||||
# ==================== 便捷函数 ====================
|
||||
|
||||
_default_client: Optional[LLMClient] = None
|
||||
|
||||
|
||||
def get_llm_client(config: Optional[LLMConfig] = None) -> LLMClient:
|
||||
"""获取默认 LLM 客户端"""
|
||||
global _default_client
|
||||
if config:
|
||||
_default_client = LLMClient(config)
|
||||
if _default_client is None:
|
||||
raise ValueError("LLM 客户端未初始化,请先传入配置")
|
||||
return _default_client
|
||||
|
||||
|
||||
# ==================== 导出 ====================
|
||||
|
||||
__all__ = [
|
||||
'LLMConfig',
|
||||
'LLMResponse',
|
||||
'LLMClient',
|
||||
'get_llm_client',
|
||||
]
|
||||
@@ -181,3 +181,91 @@ def validate_tool_arguments(
|
||||
|
||||
return True, "", arguments
|
||||
|
||||
|
||||
# ==================== 工具注册中心集成 ====================
|
||||
|
||||
def register_plugin_tools(
|
||||
plugin_name: str,
|
||||
plugin: Any,
|
||||
tools_config: Dict[str, Any],
|
||||
timeout_config: Optional[Dict[str, Any]] = None,
|
||||
) -> int:
|
||||
"""
|
||||
将插件的 LLM 工具注册到全局工具注册中心
|
||||
|
||||
Args:
|
||||
plugin_name: 插件名称
|
||||
plugin: 插件实例(需实现 get_llm_tools 和 execute_llm_tool)
|
||||
tools_config: 工具配置(包含 mode, whitelist, blacklist)
|
||||
timeout_config: 工具超时配置 {tool_name: timeout_seconds}
|
||||
|
||||
Returns:
|
||||
注册的工具数量
|
||||
"""
|
||||
from utils.tool_registry import get_tool_registry
|
||||
|
||||
if not hasattr(plugin, "get_llm_tools") or not hasattr(plugin, "execute_llm_tool"):
|
||||
return 0
|
||||
|
||||
registry = get_tool_registry()
|
||||
timeout_config = timeout_config or {}
|
||||
|
||||
mode = tools_config.get("mode", "all")
|
||||
whitelist = set(tools_config.get("whitelist", []))
|
||||
blacklist = set(tools_config.get("blacklist", []))
|
||||
|
||||
plugin_tools = plugin.get_llm_tools() or []
|
||||
registered_count = 0
|
||||
|
||||
for tool in plugin_tools:
|
||||
tool_name = tool.get("function", {}).get("name", "")
|
||||
if not tool_name:
|
||||
continue
|
||||
|
||||
# 应用白名单/黑名单过滤
|
||||
if mode == "whitelist" and tool_name not in whitelist:
|
||||
continue
|
||||
if mode == "blacklist" and tool_name in blacklist:
|
||||
logger.debug(f"[黑名单] 跳过注册工具: {tool_name}")
|
||||
continue
|
||||
|
||||
# 获取工具超时配置
|
||||
timeout = timeout_config.get(tool_name, timeout_config.get("default", 60))
|
||||
|
||||
# 创建执行器闭包
|
||||
async def make_executor(p, tn):
|
||||
async def executor(tool_name: str, arguments: dict, bot, from_wxid: str):
|
||||
return await p.execute_llm_tool(tool_name, arguments, bot, from_wxid)
|
||||
return executor
|
||||
|
||||
# 注册工具
|
||||
if registry.register(
|
||||
name=tool_name,
|
||||
plugin_name=plugin_name,
|
||||
schema=tool,
|
||||
executor=plugin.execute_llm_tool,
|
||||
timeout=timeout,
|
||||
):
|
||||
registered_count += 1
|
||||
if mode == "whitelist":
|
||||
logger.debug(f"[白名单] 注册工具: {tool_name}")
|
||||
|
||||
if registered_count > 0:
|
||||
logger.info(f"插件 {plugin_name} 注册了 {registered_count} 个工具")
|
||||
|
||||
return registered_count
|
||||
|
||||
|
||||
def unregister_plugin_tools(plugin_name: str) -> int:
|
||||
"""
|
||||
从全局工具注册中心注销插件的所有工具
|
||||
|
||||
Args:
|
||||
plugin_name: 插件名称
|
||||
|
||||
Returns:
|
||||
注销的工具数量
|
||||
"""
|
||||
from utils.tool_registry import get_tool_registry
|
||||
return get_tool_registry().unregister_plugin(plugin_name)
|
||||
|
||||
|
||||
145
utils/message_dedup.py
Normal file
145
utils/message_dedup.py
Normal file
@@ -0,0 +1,145 @@
|
||||
"""
|
||||
消息去重器模块
|
||||
|
||||
防止同一条消息被重复处理(某些环境下回调会重复触发)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class MessageDeduplicator:
|
||||
"""
|
||||
消息去重器
|
||||
|
||||
使用基于时间的滑动窗口实现去重:
|
||||
- 记录最近处理的消息 ID
|
||||
- 在 TTL 时间内重复的消息会被过滤
|
||||
- 自动清理过期记录,限制内存占用
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ttl_seconds: float = 30.0,
|
||||
max_size: int = 5000,
|
||||
):
|
||||
"""
|
||||
初始化去重器
|
||||
|
||||
Args:
|
||||
ttl_seconds: 消息 ID 的有效期(秒),0 表示禁用去重
|
||||
max_size: 最大缓存条目数,防止内存泄漏
|
||||
"""
|
||||
self.ttl_seconds = max(float(ttl_seconds), 0.0)
|
||||
self.max_size = max(int(max_size), 0)
|
||||
self._cache: Dict[str, float] = {} # key -> timestamp
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
@staticmethod
|
||||
def extract_msg_id(data: Dict[str, Any]) -> str:
|
||||
"""
|
||||
从原始消息数据中提取消息 ID
|
||||
|
||||
Args:
|
||||
data: 原始消息数据
|
||||
|
||||
Returns:
|
||||
消息 ID 字符串,提取失败返回空字符串
|
||||
"""
|
||||
for key in ("msgid", "msg_id", "MsgId", "id"):
|
||||
value = data.get(key)
|
||||
if value:
|
||||
return str(value)
|
||||
return ""
|
||||
|
||||
async def is_duplicate(self, data: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
检查消息是否重复
|
||||
|
||||
Args:
|
||||
data: 原始消息数据
|
||||
|
||||
Returns:
|
||||
True 表示是重复消息,False 表示是新消息
|
||||
"""
|
||||
if self.ttl_seconds <= 0:
|
||||
return False
|
||||
|
||||
msg_id = self.extract_msg_id(data)
|
||||
if not msg_id:
|
||||
# 没有消息 ID 时不做去重,避免误判
|
||||
return False
|
||||
|
||||
key = f"msgid:{msg_id}"
|
||||
now = time.time()
|
||||
|
||||
async with self._lock:
|
||||
# 检查是否存在且未过期
|
||||
last_seen = self._cache.get(key)
|
||||
if last_seen is not None and (now - last_seen) < self.ttl_seconds:
|
||||
return True
|
||||
|
||||
# 记录新消息
|
||||
self._cache.pop(key, None) # 确保插入到末尾(保持顺序)
|
||||
self._cache[key] = now
|
||||
|
||||
# 清理过期条目
|
||||
self._cleanup_expired(now)
|
||||
|
||||
# 限制大小
|
||||
self._limit_size()
|
||||
|
||||
return False
|
||||
|
||||
def _cleanup_expired(self, now: float):
|
||||
"""清理过期条目(需在锁内调用)"""
|
||||
cutoff = now - self.ttl_seconds
|
||||
while self._cache:
|
||||
first_key = next(iter(self._cache))
|
||||
if self._cache[first_key] >= cutoff:
|
||||
break
|
||||
self._cache.pop(first_key, None)
|
||||
|
||||
def _limit_size(self):
|
||||
"""限制缓存大小(需在锁内调用)"""
|
||||
if self.max_size <= 0:
|
||||
return
|
||||
while len(self._cache) > self.max_size:
|
||||
first_key = next(iter(self._cache))
|
||||
self._cache.pop(first_key, None)
|
||||
|
||||
def clear(self):
|
||||
"""清空缓存"""
|
||||
self._cache.clear()
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""获取统计信息"""
|
||||
return {
|
||||
"cached_count": len(self._cache),
|
||||
"ttl_seconds": self.ttl_seconds,
|
||||
"max_size": self.max_size,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, perf_config: Dict[str, Any]) -> "MessageDeduplicator":
|
||||
"""
|
||||
从配置创建去重器
|
||||
|
||||
Args:
|
||||
perf_config: Performance 配置节
|
||||
|
||||
Returns:
|
||||
MessageDeduplicator 实例
|
||||
"""
|
||||
return cls(
|
||||
ttl_seconds=perf_config.get("dedup_ttl_seconds", 30),
|
||||
max_size=perf_config.get("dedup_max_size", 5000),
|
||||
)
|
||||
|
||||
|
||||
# ==================== 导出 ====================
|
||||
|
||||
__all__ = ['MessageDeduplicator']
|
||||
128
utils/message_filter.py
Normal file
128
utils/message_filter.py
Normal file
@@ -0,0 +1,128 @@
|
||||
"""
|
||||
消息过滤器模块
|
||||
|
||||
提供消息过滤功能:
|
||||
- 白名单/黑名单过滤
|
||||
- 机器人自身消息过滤
|
||||
- 系统消息放行
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class MessageFilter:
|
||||
"""
|
||||
消息过滤器
|
||||
|
||||
支持三种模式:
|
||||
- None: 不过滤,处理所有消息
|
||||
- Whitelist: 只处理白名单中的消息
|
||||
- Blacklist: 过滤黑名单中的消息
|
||||
"""
|
||||
|
||||
# 系统消息类型(始终放行)
|
||||
SYSTEM_MESSAGE_TYPES = {11058}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mode: str = "None",
|
||||
whitelist: List[str] = None,
|
||||
blacklist: List[str] = None,
|
||||
bot_wxid: str = None,
|
||||
):
|
||||
"""
|
||||
初始化过滤器
|
||||
|
||||
Args:
|
||||
mode: 过滤模式 ("None", "Whitelist", "Blacklist")
|
||||
whitelist: 白名单 wxid 列表
|
||||
blacklist: 黑名单 wxid 列表
|
||||
bot_wxid: 机器人自身 wxid(用于过滤自己的消息)
|
||||
"""
|
||||
self.mode = mode
|
||||
self.whitelist: Set[str] = set(whitelist or [])
|
||||
self.blacklist: Set[str] = set(blacklist or [])
|
||||
self.bot_wxid = bot_wxid
|
||||
|
||||
def set_bot_wxid(self, wxid: str):
|
||||
"""设置机器人 wxid"""
|
||||
self.bot_wxid = wxid
|
||||
|
||||
def add_to_whitelist(self, wxid: str):
|
||||
"""添加到白名单"""
|
||||
self.whitelist.add(wxid)
|
||||
|
||||
def remove_from_whitelist(self, wxid: str):
|
||||
"""从白名单移除"""
|
||||
self.whitelist.discard(wxid)
|
||||
|
||||
def add_to_blacklist(self, wxid: str):
|
||||
"""添加到黑名单"""
|
||||
self.blacklist.add(wxid)
|
||||
|
||||
def remove_from_blacklist(self, wxid: str):
|
||||
"""从黑名单移除"""
|
||||
self.blacklist.discard(wxid)
|
||||
|
||||
def should_process(self, message: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
判断消息是否应该被处理
|
||||
|
||||
Args:
|
||||
message: 标准化后的消息字典
|
||||
|
||||
Returns:
|
||||
True 表示应该处理,False 表示应该过滤
|
||||
"""
|
||||
from_wxid = message.get("FromWxid", "")
|
||||
sender_wxid = message.get("SenderWxid", "")
|
||||
msg_type = message.get("MsgType", 0)
|
||||
|
||||
# 系统消息始终放行
|
||||
if msg_type in self.SYSTEM_MESSAGE_TYPES:
|
||||
return True
|
||||
|
||||
# 过滤机器人自己的消息
|
||||
if self.bot_wxid and (from_wxid == self.bot_wxid or sender_wxid == self.bot_wxid):
|
||||
return False
|
||||
|
||||
# 根据模式过滤
|
||||
return self._check_mode(from_wxid, sender_wxid)
|
||||
|
||||
def _check_mode(self, from_wxid: str, sender_wxid: str) -> bool:
|
||||
"""根据模式检查是否放行"""
|
||||
if self.mode == "None":
|
||||
return True
|
||||
|
||||
if self.mode == "Whitelist":
|
||||
return from_wxid in self.whitelist or sender_wxid in self.whitelist
|
||||
|
||||
if self.mode == "Blacklist":
|
||||
return from_wxid not in self.blacklist and sender_wxid not in self.blacklist
|
||||
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, bot_config: Dict[str, Any]) -> "MessageFilter":
|
||||
"""
|
||||
从配置创建过滤器
|
||||
|
||||
Args:
|
||||
bot_config: Bot 配置节
|
||||
|
||||
Returns:
|
||||
MessageFilter 实例
|
||||
"""
|
||||
return cls(
|
||||
mode=bot_config.get("ignore-mode", "None"),
|
||||
whitelist=bot_config.get("whitelist", []),
|
||||
blacklist=bot_config.get("blacklist", []),
|
||||
bot_wxid=bot_config.get("wxid") or bot_config.get("bot_wxid"),
|
||||
)
|
||||
|
||||
|
||||
# ==================== 导出 ====================
|
||||
|
||||
__all__ = ['MessageFilter']
|
||||
@@ -56,10 +56,18 @@ async def log_bot_message(to_wxid: str, content: str, msg_type: str = "text", me
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
sync_content = content
|
||||
if msg_type == "image":
|
||||
sync_content = "[图片]"
|
||||
elif msg_type == "video":
|
||||
sync_content = "[视频]"
|
||||
elif msg_type == "file":
|
||||
sync_content = "[文件]"
|
||||
|
||||
await store.add_group_message(
|
||||
to_wxid,
|
||||
bot_nickname,
|
||||
content,
|
||||
sync_content,
|
||||
role="assistant",
|
||||
sender_wxid=bot_wxid or None,
|
||||
)
|
||||
|
||||
305
utils/message_queue.py
Normal file
305
utils/message_queue.py
Normal file
@@ -0,0 +1,305 @@
|
||||
"""
|
||||
消息队列模块
|
||||
|
||||
提供高性能的优先级消息队列,支持多种溢出策略:
|
||||
- drop_oldest: 丢弃最旧的消息
|
||||
- drop_lowest: 丢弃优先级最低的消息
|
||||
- sampling: 按采样率丢弃消息
|
||||
- reject: 拒绝新消息
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import heapq
|
||||
import random
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
# ==================== 消息优先级常量 ====================
|
||||
|
||||
class MessagePriority:
|
||||
"""消息优先级常量"""
|
||||
CRITICAL = 100 # 系统消息、登录信息
|
||||
HIGH = 80 # 管理员命令、群成员变动
|
||||
NORMAL = 50 # @bot 消息(默认)
|
||||
LOW = 20 # 普通群消息
|
||||
|
||||
|
||||
# ==================== 溢出策略 ====================
|
||||
|
||||
class OverflowStrategy(Enum):
|
||||
"""队列溢出策略"""
|
||||
DROP_OLDEST = "drop_oldest" # 丢弃最旧的消息
|
||||
DROP_LOWEST = "drop_lowest" # 丢弃优先级最低的消息
|
||||
SAMPLING = "sampling" # 按采样率丢弃
|
||||
REJECT = "reject" # 拒绝新消息
|
||||
|
||||
|
||||
# ==================== 优先级消息 ====================
|
||||
|
||||
@dataclass(order=True)
|
||||
class PriorityMessage:
|
||||
"""优先级消息"""
|
||||
priority: int = field(compare=True)
|
||||
timestamp: float = field(compare=True)
|
||||
msg_type: int = field(compare=False)
|
||||
data: Dict[str, Any] = field(compare=False)
|
||||
|
||||
def __init__(self, msg_type: int, data: Dict[str, Any], priority: int = None):
|
||||
# 优先级越高,数值越大,但 heapq 是最小堆,所以取负数
|
||||
self.priority = -(priority if priority is not None else MessagePriority.NORMAL)
|
||||
self.timestamp = time.time()
|
||||
self.msg_type = msg_type
|
||||
self.data = data
|
||||
|
||||
|
||||
# ==================== 优先级消息队列 ====================
|
||||
|
||||
class PriorityMessageQueue:
|
||||
"""
|
||||
优先级消息队列
|
||||
|
||||
特性:
|
||||
- 基于堆的优先级队列
|
||||
- 支持多种溢出策略
|
||||
- 线程安全(使用 asyncio.Lock)
|
||||
- 支持任务计数和 join
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
maxsize: int = 1000,
|
||||
overflow_strategy: str = "drop_oldest",
|
||||
sampling_rate: float = 0.5,
|
||||
):
|
||||
"""
|
||||
初始化队列
|
||||
|
||||
Args:
|
||||
maxsize: 最大队列大小
|
||||
overflow_strategy: 溢出策略 (drop_oldest, drop_lowest, sampling, reject)
|
||||
sampling_rate: 采样策略的保留率 (0.0-1.0)
|
||||
"""
|
||||
self.maxsize = maxsize
|
||||
self.overflow_strategy = OverflowStrategy(overflow_strategy)
|
||||
self.sampling_rate = max(0.0, min(1.0, sampling_rate))
|
||||
|
||||
self._heap: List[PriorityMessage] = []
|
||||
self._lock = asyncio.Lock()
|
||||
self._not_empty = asyncio.Event()
|
||||
self._unfinished_tasks = 0
|
||||
self._finished = asyncio.Event()
|
||||
self._finished.set()
|
||||
|
||||
# 统计
|
||||
self._total_put = 0
|
||||
self._total_dropped = 0
|
||||
self._total_rejected = 0
|
||||
|
||||
def qsize(self) -> int:
|
||||
"""返回队列大小"""
|
||||
return len(self._heap)
|
||||
|
||||
def empty(self) -> bool:
|
||||
"""队列是否为空"""
|
||||
return len(self._heap) == 0
|
||||
|
||||
def full(self) -> bool:
|
||||
"""队列是否已满"""
|
||||
return len(self._heap) >= self.maxsize
|
||||
|
||||
async def put(
|
||||
self,
|
||||
msg_type: int,
|
||||
data: Dict[str, Any],
|
||||
priority: int = None,
|
||||
) -> bool:
|
||||
"""
|
||||
添加消息到队列
|
||||
|
||||
Args:
|
||||
msg_type: 消息类型
|
||||
data: 消息数据
|
||||
priority: 优先级(可选)
|
||||
|
||||
Returns:
|
||||
是否成功添加
|
||||
"""
|
||||
async with self._lock:
|
||||
self._total_put += 1
|
||||
|
||||
# 处理队列满的情况
|
||||
if self.full():
|
||||
if not self._handle_overflow():
|
||||
self._total_rejected += 1
|
||||
return False
|
||||
|
||||
msg = PriorityMessage(msg_type, data, priority)
|
||||
heapq.heappush(self._heap, msg)
|
||||
self._unfinished_tasks += 1
|
||||
self._finished.clear()
|
||||
self._not_empty.set()
|
||||
return True
|
||||
|
||||
def _handle_overflow(self) -> bool:
|
||||
"""
|
||||
处理队列溢出
|
||||
|
||||
Returns:
|
||||
True 表示成功腾出空间,False 表示拒绝
|
||||
"""
|
||||
if self.overflow_strategy == OverflowStrategy.REJECT:
|
||||
logger.warning("队列已满,拒绝新消息")
|
||||
return False
|
||||
|
||||
if self.overflow_strategy == OverflowStrategy.DROP_OLDEST:
|
||||
# 找到最旧的消息(timestamp 最小)
|
||||
if self._heap:
|
||||
oldest_idx = 0
|
||||
for i, msg in enumerate(self._heap):
|
||||
if msg.timestamp < self._heap[oldest_idx].timestamp:
|
||||
oldest_idx = i
|
||||
self._heap.pop(oldest_idx)
|
||||
heapq.heapify(self._heap)
|
||||
self._total_dropped += 1
|
||||
self._unfinished_tasks = max(0, self._unfinished_tasks - 1)
|
||||
return True
|
||||
|
||||
elif self.overflow_strategy == OverflowStrategy.DROP_LOWEST:
|
||||
# 找到优先级最低的消息(priority 值最大,因为是负数)
|
||||
if self._heap:
|
||||
lowest_idx = 0
|
||||
for i, msg in enumerate(self._heap):
|
||||
if msg.priority > self._heap[lowest_idx].priority:
|
||||
lowest_idx = i
|
||||
self._heap.pop(lowest_idx)
|
||||
heapq.heapify(self._heap)
|
||||
self._total_dropped += 1
|
||||
self._unfinished_tasks = max(0, self._unfinished_tasks - 1)
|
||||
return True
|
||||
|
||||
elif self.overflow_strategy == OverflowStrategy.SAMPLING:
|
||||
# 按采样率决定是否接受
|
||||
if random.random() < self.sampling_rate:
|
||||
# 接受新消息,丢弃最旧的
|
||||
if self._heap:
|
||||
oldest_idx = 0
|
||||
for i, msg in enumerate(self._heap):
|
||||
if msg.timestamp < self._heap[oldest_idx].timestamp:
|
||||
oldest_idx = i
|
||||
self._heap.pop(oldest_idx)
|
||||
heapq.heapify(self._heap)
|
||||
self._total_dropped += 1
|
||||
self._unfinished_tasks = max(0, self._unfinished_tasks - 1)
|
||||
return True
|
||||
else:
|
||||
self._total_dropped += 1
|
||||
return False
|
||||
|
||||
return False
|
||||
|
||||
async def get(self, timeout: float = None) -> Tuple[int, Dict[str, Any]]:
|
||||
"""
|
||||
获取优先级最高的消息
|
||||
|
||||
Args:
|
||||
timeout: 超时时间(秒),None 表示无限等待
|
||||
|
||||
Returns:
|
||||
(msg_type, data) 元组
|
||||
|
||||
Raises:
|
||||
asyncio.TimeoutError: 超时
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
while True:
|
||||
async with self._lock:
|
||||
if self._heap:
|
||||
msg = heapq.heappop(self._heap)
|
||||
if not self._heap:
|
||||
self._not_empty.clear()
|
||||
return (msg.msg_type, msg.data)
|
||||
|
||||
# 计算剩余超时时间
|
||||
if timeout is not None:
|
||||
elapsed = time.time() - start_time
|
||||
remaining = timeout - elapsed
|
||||
if remaining <= 0:
|
||||
raise asyncio.TimeoutError("Queue get timeout")
|
||||
try:
|
||||
await asyncio.wait_for(self._not_empty.wait(), timeout=remaining)
|
||||
except asyncio.TimeoutError:
|
||||
raise asyncio.TimeoutError("Queue get timeout")
|
||||
else:
|
||||
await self._not_empty.wait()
|
||||
|
||||
def get_nowait(self) -> Tuple[int, Dict[str, Any]]:
|
||||
"""非阻塞获取消息"""
|
||||
if not self._heap:
|
||||
raise asyncio.QueueEmpty()
|
||||
msg = heapq.heappop(self._heap)
|
||||
if not self._heap:
|
||||
self._not_empty.clear()
|
||||
return (msg.msg_type, msg.data)
|
||||
|
||||
def task_done(self):
|
||||
"""标记任务完成"""
|
||||
self._unfinished_tasks = max(0, self._unfinished_tasks - 1)
|
||||
if self._unfinished_tasks == 0:
|
||||
self._finished.set()
|
||||
|
||||
async def join(self):
|
||||
"""等待所有任务完成"""
|
||||
await self._finished.wait()
|
||||
|
||||
def clear(self):
|
||||
"""清空队列"""
|
||||
self._heap.clear()
|
||||
self._not_empty.clear()
|
||||
self._unfinished_tasks = 0
|
||||
self._finished.set()
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""获取队列统计信息"""
|
||||
return {
|
||||
"current_size": len(self._heap),
|
||||
"max_size": self.maxsize,
|
||||
"total_put": self._total_put,
|
||||
"total_dropped": self._total_dropped,
|
||||
"total_rejected": self._total_rejected,
|
||||
"unfinished_tasks": self._unfinished_tasks,
|
||||
"overflow_strategy": self.overflow_strategy.value,
|
||||
"utilization": len(self._heap) / max(self.maxsize, 1),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, queue_config: Dict[str, Any]) -> "PriorityMessageQueue":
|
||||
"""
|
||||
从配置创建队列
|
||||
|
||||
Args:
|
||||
queue_config: Queue 配置节
|
||||
|
||||
Returns:
|
||||
PriorityMessageQueue 实例
|
||||
"""
|
||||
return cls(
|
||||
maxsize=queue_config.get("max_size", 1000),
|
||||
overflow_strategy=queue_config.get("overflow_strategy", "drop_oldest"),
|
||||
sampling_rate=queue_config.get("sampling_rate", 0.5),
|
||||
)
|
||||
|
||||
|
||||
# ==================== 导出 ====================
|
||||
|
||||
__all__ = [
|
||||
'MessagePriority',
|
||||
'OverflowStrategy',
|
||||
'PriorityMessage',
|
||||
'PriorityMessageQueue',
|
||||
]
|
||||
114
utils/message_stats.py
Normal file
114
utils/message_stats.py
Normal file
@@ -0,0 +1,114 @@
|
||||
"""
|
||||
消息统计器模块
|
||||
|
||||
提供消息处理的统计功能:
|
||||
- 消息计数
|
||||
- 过滤率统计
|
||||
- 按类型统计
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from threading import Lock
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
class MessageStats:
|
||||
"""
|
||||
消息统计器
|
||||
|
||||
线程安全的消息统计实现
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._lock = Lock()
|
||||
self._total_count = 0
|
||||
self._filtered_count = 0
|
||||
self._processed_count = 0
|
||||
self._duplicate_count = 0
|
||||
self._error_count = 0
|
||||
self._by_type: Dict[str, int] = defaultdict(int)
|
||||
self._start_time = time.time()
|
||||
|
||||
def record_received(self):
|
||||
"""记录收到消息"""
|
||||
with self._lock:
|
||||
self._total_count += 1
|
||||
|
||||
def record_filtered(self):
|
||||
"""记录被过滤的消息"""
|
||||
with self._lock:
|
||||
self._filtered_count += 1
|
||||
|
||||
def record_processed(self, event_type: str = None):
|
||||
"""
|
||||
记录已处理的消息
|
||||
|
||||
Args:
|
||||
event_type: 消息事件类型(可选)
|
||||
"""
|
||||
with self._lock:
|
||||
self._processed_count += 1
|
||||
if event_type:
|
||||
self._by_type[event_type] += 1
|
||||
|
||||
def record_duplicate(self):
|
||||
"""记录重复消息"""
|
||||
with self._lock:
|
||||
self._duplicate_count += 1
|
||||
|
||||
def record_error(self):
|
||||
"""记录处理错误"""
|
||||
with self._lock:
|
||||
self._error_count += 1
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""获取统计信息"""
|
||||
with self._lock:
|
||||
uptime = time.time() - self._start_time
|
||||
total = max(self._total_count, 1) # 避免除零
|
||||
|
||||
return {
|
||||
"total_messages": self._total_count,
|
||||
"filtered_messages": self._filtered_count,
|
||||
"processed_messages": self._processed_count,
|
||||
"duplicate_messages": self._duplicate_count,
|
||||
"error_count": self._error_count,
|
||||
"filter_rate": self._filtered_count / total,
|
||||
"process_rate": self._processed_count / total,
|
||||
"duplicate_rate": self._duplicate_count / total,
|
||||
"messages_per_minute": (self._total_count / uptime) * 60 if uptime > 0 else 0,
|
||||
"uptime_seconds": uptime,
|
||||
"by_type": dict(self._by_type),
|
||||
}
|
||||
|
||||
def reset(self):
|
||||
"""重置统计"""
|
||||
with self._lock:
|
||||
self._total_count = 0
|
||||
self._filtered_count = 0
|
||||
self._processed_count = 0
|
||||
self._duplicate_count = 0
|
||||
self._error_count = 0
|
||||
self._by_type.clear()
|
||||
self._start_time = time.time()
|
||||
|
||||
|
||||
# 全局单例(可选使用)
|
||||
_global_stats: MessageStats = None
|
||||
_stats_lock = Lock()
|
||||
|
||||
|
||||
def get_message_stats() -> MessageStats:
|
||||
"""获取全局消息统计器实例"""
|
||||
global _global_stats
|
||||
if _global_stats is None:
|
||||
with _stats_lock:
|
||||
if _global_stats is None:
|
||||
_global_stats = MessageStats()
|
||||
return _global_stats
|
||||
|
||||
|
||||
# ==================== 导出 ====================
|
||||
|
||||
__all__ = ['MessageStats', 'get_message_stats']
|
||||
@@ -1,35 +1,144 @@
|
||||
"""
|
||||
插件基类模块
|
||||
|
||||
提供插件的基础功能:
|
||||
- 生命周期钩子(on_load, on_enable, on_disable, on_unload, on_reload)
|
||||
- 定时任务管理
|
||||
- 依赖声明
|
||||
- 插件元数据
|
||||
"""
|
||||
|
||||
from abc import ABC
|
||||
from typing import List
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, TYPE_CHECKING
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from .decorators import scheduler, add_job_safe, remove_job_safe
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from utils.plugin_manager import PluginManager
|
||||
|
||||
|
||||
class PluginState(Enum):
|
||||
"""插件状态"""
|
||||
UNLOADED = "unloaded" # 未加载
|
||||
LOADED = "loaded" # 已加载(未启用)
|
||||
ENABLED = "enabled" # 已启用
|
||||
DISABLED = "disabled" # 已禁用
|
||||
ERROR = "error" # 错误状态
|
||||
|
||||
|
||||
class PluginBase(ABC):
|
||||
"""插件基类"""
|
||||
"""
|
||||
插件基类
|
||||
|
||||
生命周期:
|
||||
1. __init__() - 构造函数(同步)
|
||||
2. on_load() - 加载时调用(异步,可访问其他插件)
|
||||
3. async_init() - 异步初始化(异步,加载配置、资源等)
|
||||
4. on_enable() - 启用时调用(异步,注册定时任务)
|
||||
5. on_disable() - 禁用时调用(异步,清理定时任务)
|
||||
6. on_unload() - 卸载时调用(异步,释放资源)
|
||||
7. on_reload() - 重载前调用(异步,保存状态)
|
||||
|
||||
使用示例:
|
||||
class MyPlugin(PluginBase):
|
||||
description = "我的插件"
|
||||
author = "作者"
|
||||
version = "1.0.0"
|
||||
dependencies = ["AIChat"] # 依赖的插件
|
||||
load_priority = 60 # 加载优先级
|
||||
|
||||
async def on_load(self, plugin_manager):
|
||||
# 可以访问其他插件
|
||||
aichat = plugin_manager.plugins.get("AIChat")
|
||||
|
||||
async def async_init(self):
|
||||
# 加载配置、初始化资源
|
||||
self.config = load_config()
|
||||
|
||||
async def on_enable(self, bot):
|
||||
await super().on_enable(bot) # 注册定时任务
|
||||
# 额外的启用逻辑
|
||||
|
||||
async def on_disable(self):
|
||||
await super().on_disable() # 清理定时任务
|
||||
# 额外的禁用逻辑
|
||||
|
||||
async def on_unload(self):
|
||||
# 释放资源、关闭连接
|
||||
await self.close_connections()
|
||||
|
||||
async def on_reload(self) -> dict:
|
||||
# 返回需要保存的状态
|
||||
return {"counter": self.counter}
|
||||
|
||||
async def restore_state(self, state: dict):
|
||||
# 重载后恢复状态
|
||||
self.counter = state.get("counter", 0)
|
||||
"""
|
||||
|
||||
# ==================== 插件元数据 ====================
|
||||
|
||||
# 插件元数据
|
||||
description: str = "暂无描述"
|
||||
author: str = "未知"
|
||||
version: str = "1.0.0"
|
||||
|
||||
# 插件依赖(填写依赖的插件类名列表)
|
||||
# 例如: dependencies = ["MessageLogger", "AIChat"]
|
||||
dependencies: List[str] = []
|
||||
|
||||
# 加载优先级(数值越大越先加载,默认50)
|
||||
# 基础插件设置高优先级,依赖其他插件的设置低优先级
|
||||
load_priority: int = 50
|
||||
|
||||
# ==================== 实例属性 ====================
|
||||
|
||||
def __init__(self):
|
||||
self.enabled = False
|
||||
self._scheduled_jobs = set()
|
||||
self.state = PluginState.UNLOADED
|
||||
self._scheduled_jobs: set = set()
|
||||
self._bot = None
|
||||
self._plugin_manager: Optional["PluginManager"] = None
|
||||
self._saved_state: Dict[str, Any] = {}
|
||||
|
||||
# ==================== 生命周期钩子 ====================
|
||||
|
||||
async def on_load(self, plugin_manager: "PluginManager"):
|
||||
"""
|
||||
插件加载时调用
|
||||
|
||||
此时其他依赖的插件已经加载完成,可以安全访问。
|
||||
|
||||
Args:
|
||||
plugin_manager: 插件管理器实例
|
||||
"""
|
||||
self._plugin_manager = plugin_manager
|
||||
self.state = PluginState.LOADED
|
||||
logger.debug(f"[{self.__class__.__name__}] on_load 调用")
|
||||
|
||||
async def async_init(self):
|
||||
"""
|
||||
插件异步初始化
|
||||
|
||||
用于加载配置、初始化资源等耗时操作。
|
||||
在 on_load 之后、on_enable 之前调用。
|
||||
"""
|
||||
pass
|
||||
|
||||
async def on_enable(self, bot=None):
|
||||
"""插件启用时调用"""
|
||||
"""
|
||||
插件启用时调用
|
||||
|
||||
# 定时任务
|
||||
注册定时任务、启动后台服务等。
|
||||
|
||||
Args:
|
||||
bot: WechatHookClient 实例
|
||||
"""
|
||||
self._bot = bot
|
||||
self.enabled = True
|
||||
self.state = PluginState.ENABLED
|
||||
|
||||
# 注册定时任务
|
||||
for method_name in dir(self):
|
||||
method = getattr(self, method_name)
|
||||
if hasattr(method, '_is_scheduled'):
|
||||
@@ -39,18 +148,85 @@ class PluginBase(ABC):
|
||||
|
||||
add_job_safe(scheduler, job_id, method, bot, trigger, **trigger_args)
|
||||
self._scheduled_jobs.add(job_id)
|
||||
|
||||
if self._scheduled_jobs:
|
||||
logger.success("插件 {} 已加载定时任务: {}", self.__class__.__name__, self._scheduled_jobs)
|
||||
logger.success(f"插件 {self.__class__.__name__} 已加载定时任务: {self._scheduled_jobs}")
|
||||
|
||||
async def on_disable(self):
|
||||
"""插件禁用时调用"""
|
||||
|
||||
"""
|
||||
插件禁用时调用
|
||||
|
||||
清理定时任务、停止后台服务等。
|
||||
"""
|
||||
self.enabled = False
|
||||
self.state = PluginState.DISABLED
|
||||
|
||||
# 移除定时任务
|
||||
for job_id in self._scheduled_jobs:
|
||||
remove_job_safe(scheduler, job_id)
|
||||
logger.info("已卸载定时任务: {}", self._scheduled_jobs)
|
||||
|
||||
if self._scheduled_jobs:
|
||||
logger.info(f"已卸载定时任务: {self._scheduled_jobs}")
|
||||
self._scheduled_jobs.clear()
|
||||
|
||||
async def async_init(self):
|
||||
"""插件异步初始化"""
|
||||
return
|
||||
async def on_unload(self):
|
||||
"""
|
||||
插件卸载时调用
|
||||
|
||||
释放资源、关闭连接、保存数据等。
|
||||
在 on_disable 之后调用。
|
||||
"""
|
||||
self.state = PluginState.UNLOADED
|
||||
self._bot = None
|
||||
self._plugin_manager = None
|
||||
logger.debug(f"[{self.__class__.__name__}] on_unload 调用")
|
||||
|
||||
async def on_reload(self) -> Dict[str, Any]:
|
||||
"""
|
||||
插件重载前调用
|
||||
|
||||
返回需要在重载后恢复的状态数据。
|
||||
|
||||
Returns:
|
||||
需要保存的状态字典
|
||||
"""
|
||||
logger.debug(f"[{self.__class__.__name__}] on_reload 调用")
|
||||
return {}
|
||||
|
||||
async def restore_state(self, state: Dict[str, Any]):
|
||||
"""
|
||||
重载后恢复状态
|
||||
|
||||
Args:
|
||||
state: on_reload 返回的状态字典
|
||||
"""
|
||||
self._saved_state = state
|
||||
logger.debug(f"[{self.__class__.__name__}] 状态已恢复: {list(state.keys())}")
|
||||
|
||||
# ==================== 辅助方法 ====================
|
||||
|
||||
def get_plugin(self, plugin_name: str) -> Optional["PluginBase"]:
|
||||
"""
|
||||
获取其他插件实例
|
||||
|
||||
Args:
|
||||
plugin_name: 插件类名
|
||||
|
||||
Returns:
|
||||
插件实例,不存在返回 None
|
||||
"""
|
||||
if self._plugin_manager:
|
||||
return self._plugin_manager.plugins.get(plugin_name)
|
||||
return None
|
||||
|
||||
def get_bot(self):
|
||||
"""获取 bot 实例"""
|
||||
return self._bot
|
||||
|
||||
@property
|
||||
def plugin_name(self) -> str:
|
||||
"""获取插件名称"""
|
||||
return self.__class__.__name__
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__.__name__} v{self.version} state={self.state.value}>"
|
||||
|
||||
213
utils/plugin_inject.py
Normal file
213
utils/plugin_inject.py
Normal file
@@ -0,0 +1,213 @@
|
||||
"""
|
||||
插件依赖注入模块
|
||||
|
||||
提供插件间依赖注入功能:
|
||||
- @inject 装饰器自动注入依赖
|
||||
- 延迟注入(lazy injection)避免循环依赖
|
||||
- 类型安全的依赖获取
|
||||
|
||||
使用示例:
|
||||
from utils.plugin_inject import inject, require_plugin
|
||||
|
||||
class MyPlugin(PluginBase):
|
||||
# 方式1: 使用装饰器注入
|
||||
@inject("AIChat")
|
||||
def get_aichat(self) -> "AIChat":
|
||||
pass # 自动注入,无需实现
|
||||
|
||||
# 方式2: 使用 require_plugin
|
||||
async def some_method(self):
|
||||
aichat = require_plugin("AIChat")
|
||||
await aichat.do_something()
|
||||
|
||||
# 方式3: 使用基类的 get_plugin
|
||||
async def another_method(self):
|
||||
aichat = self.get_plugin("AIChat")
|
||||
if aichat:
|
||||
await aichat.do_something()
|
||||
"""
|
||||
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, Optional, Type, TypeVar, TYPE_CHECKING
|
||||
|
||||
from loguru import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from utils.plugin_base import PluginBase
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
class PluginNotAvailableError(Exception):
|
||||
"""插件不可用错误"""
|
||||
pass
|
||||
|
||||
|
||||
def _get_plugin_manager():
|
||||
"""延迟获取 PluginManager 避免循环导入"""
|
||||
from utils.plugin_manager import PluginManager
|
||||
return PluginManager()
|
||||
|
||||
|
||||
def require_plugin(plugin_name: str) -> "PluginBase":
|
||||
"""
|
||||
获取必需的插件(不存在则抛出异常)
|
||||
|
||||
Args:
|
||||
plugin_name: 插件类名
|
||||
|
||||
Returns:
|
||||
插件实例
|
||||
|
||||
Raises:
|
||||
PluginNotAvailableError: 插件不存在或未启用
|
||||
"""
|
||||
pm = _get_plugin_manager()
|
||||
plugin = pm.plugins.get(plugin_name)
|
||||
if plugin is None:
|
||||
raise PluginNotAvailableError(f"插件 {plugin_name} 不可用")
|
||||
return plugin
|
||||
|
||||
|
||||
def get_plugin(plugin_name: str) -> Optional["PluginBase"]:
|
||||
"""
|
||||
获取插件(不存在返回 None)
|
||||
|
||||
Args:
|
||||
plugin_name: 插件类名
|
||||
|
||||
Returns:
|
||||
插件实例或 None
|
||||
"""
|
||||
pm = _get_plugin_manager()
|
||||
return pm.plugins.get(plugin_name)
|
||||
|
||||
|
||||
def inject(plugin_name: str) -> Callable:
|
||||
"""
|
||||
插件注入装饰器
|
||||
|
||||
将方法转换为属性 getter,自动返回指定插件实例。
|
||||
|
||||
Args:
|
||||
plugin_name: 要注入的插件类名
|
||||
|
||||
Usage:
|
||||
class MyPlugin(PluginBase):
|
||||
@inject("AIChat")
|
||||
def aichat(self) -> "AIChat":
|
||||
pass # 无需实现
|
||||
|
||||
async def handle(self, bot, message):
|
||||
# 直接使用
|
||||
await self.aichat.process(message)
|
||||
"""
|
||||
def decorator(method: Callable) -> property:
|
||||
@wraps(method)
|
||||
def getter(self) -> Optional["PluginBase"]:
|
||||
# 优先使用插件自身的 _plugin_manager
|
||||
if hasattr(self, '_plugin_manager') and self._plugin_manager:
|
||||
return self._plugin_manager.plugins.get(plugin_name)
|
||||
# 回退到全局 PluginManager
|
||||
return get_plugin(plugin_name)
|
||||
|
||||
return property(getter)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def inject_required(plugin_name: str) -> Callable:
|
||||
"""
|
||||
必需插件注入装饰器
|
||||
|
||||
与 inject 类似,但如果插件不存在会抛出异常。
|
||||
|
||||
Args:
|
||||
plugin_name: 要注入的插件类名
|
||||
|
||||
Raises:
|
||||
PluginNotAvailableError: 插件不存在
|
||||
"""
|
||||
def decorator(method: Callable) -> property:
|
||||
@wraps(method)
|
||||
def getter(self) -> "PluginBase":
|
||||
plugin = None
|
||||
if hasattr(self, '_plugin_manager') and self._plugin_manager:
|
||||
plugin = self._plugin_manager.plugins.get(plugin_name)
|
||||
else:
|
||||
plugin = get_plugin(plugin_name)
|
||||
|
||||
if plugin is None:
|
||||
raise PluginNotAvailableError(
|
||||
f"插件 {self.__class__.__name__} 依赖的 {plugin_name} 不可用"
|
||||
)
|
||||
return plugin
|
||||
|
||||
return property(getter)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class PluginProxy:
|
||||
"""
|
||||
插件代理
|
||||
|
||||
延迟获取插件,避免初始化时的循环依赖问题。
|
||||
|
||||
Usage:
|
||||
class MyPlugin(PluginBase):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._aichat = PluginProxy("AIChat")
|
||||
|
||||
async def handle(self):
|
||||
# 首次访问时才获取插件
|
||||
if self._aichat.available:
|
||||
await self._aichat.instance.process()
|
||||
"""
|
||||
|
||||
def __init__(self, plugin_name: str):
|
||||
self._plugin_name = plugin_name
|
||||
self._cached_instance: Optional["PluginBase"] = None
|
||||
self._checked = False
|
||||
|
||||
@property
|
||||
def instance(self) -> Optional["PluginBase"]:
|
||||
"""获取插件实例(带缓存)"""
|
||||
if not self._checked:
|
||||
self._cached_instance = get_plugin(self._plugin_name)
|
||||
self._checked = True
|
||||
return self._cached_instance
|
||||
|
||||
@property
|
||||
def available(self) -> bool:
|
||||
"""检查插件是否可用"""
|
||||
return self.instance is not None
|
||||
|
||||
def require(self) -> "PluginBase":
|
||||
"""获取插件,不存在则抛出异常"""
|
||||
inst = self.instance
|
||||
if inst is None:
|
||||
raise PluginNotAvailableError(f"插件 {self._plugin_name} 不可用")
|
||||
return inst
|
||||
|
||||
def invalidate(self):
|
||||
"""清除缓存,下次访问重新获取"""
|
||||
self._cached_instance = None
|
||||
self._checked = False
|
||||
|
||||
def __repr__(self) -> str:
|
||||
status = "available" if self.available else "unavailable"
|
||||
return f"<PluginProxy({self._plugin_name}) {status}>"
|
||||
|
||||
|
||||
# ==================== 导出 ====================
|
||||
|
||||
__all__ = [
|
||||
'PluginNotAvailableError',
|
||||
'require_plugin',
|
||||
'get_plugin',
|
||||
'inject',
|
||||
'inject_required',
|
||||
'PluginProxy',
|
||||
]
|
||||
@@ -2,7 +2,6 @@ import importlib
|
||||
import inspect
|
||||
import os
|
||||
import sys
|
||||
import tomllib
|
||||
import traceback
|
||||
from typing import Dict, Type, List, Union
|
||||
|
||||
@@ -10,6 +9,8 @@ from loguru import logger
|
||||
|
||||
# from WechatAPI import WechatAPIClient # 注释掉,WechatHookBot 不需要这个导入
|
||||
from utils.singleton import Singleton
|
||||
from utils.config_manager import get_bot_config
|
||||
from utils.llm_tooling import register_plugin_tools, unregister_plugin_tools
|
||||
from .event_manager import EventManager
|
||||
from .plugin_base import PluginBase
|
||||
|
||||
@@ -22,10 +23,9 @@ class PluginManager(metaclass=Singleton):
|
||||
|
||||
self.bot = None
|
||||
|
||||
with open("main_config.toml", "rb") as f:
|
||||
main_config = tomllib.load(f)
|
||||
|
||||
self.excluded_plugins = main_config.get("Bot", {}).get("disabled-plugins", [])
|
||||
# 使用统一配置管理器
|
||||
bot_config = get_bot_config()
|
||||
self.excluded_plugins = bot_config.get("disabled-plugins", [])
|
||||
|
||||
def set_bot(self, bot):
|
||||
"""设置 bot 客户端(WechatHookClient)"""
|
||||
@@ -74,13 +74,34 @@ class PluginManager(metaclass=Singleton):
|
||||
if is_disabled:
|
||||
return False
|
||||
|
||||
# 创建插件实例
|
||||
plugin = plugin_class()
|
||||
EventManager.bind_instance(plugin)
|
||||
await plugin.on_enable(self.bot)
|
||||
|
||||
# 生命周期: on_load(可访问其他插件)
|
||||
await plugin.on_load(self)
|
||||
|
||||
# 生命周期: async_init(加载配置、资源)
|
||||
await plugin.async_init()
|
||||
|
||||
# 绑定事件处理器
|
||||
EventManager.bind_instance(plugin)
|
||||
|
||||
# 生命周期: on_enable(注册定时任务)
|
||||
await plugin.on_enable(self.bot)
|
||||
|
||||
# 注册到插件管理器
|
||||
self.plugins[plugin_name] = plugin
|
||||
self.plugin_classes[plugin_name] = plugin_class
|
||||
self.plugin_info[plugin_name]["enabled"] = True
|
||||
|
||||
# 注册插件的 LLM 工具到全局注册中心
|
||||
try:
|
||||
tools_config = self._get_tools_config()
|
||||
timeout_config = self._get_timeout_config()
|
||||
register_plugin_tools(plugin_name, plugin, tools_config, timeout_config)
|
||||
except Exception as e:
|
||||
logger.warning(f"注册插件 {plugin_name} 的工具时出错: {e}")
|
||||
|
||||
logger.success(f"加载插件 {plugin_name} 成功")
|
||||
return True
|
||||
except:
|
||||
@@ -232,8 +253,22 @@ class PluginManager(metaclass=Singleton):
|
||||
|
||||
try:
|
||||
plugin = self.plugins[plugin_name]
|
||||
|
||||
# 生命周期: on_disable(清理定时任务)
|
||||
await plugin.on_disable()
|
||||
|
||||
# 解绑事件处理器
|
||||
EventManager.unbind_instance(plugin)
|
||||
|
||||
# 从全局注册中心注销插件的工具
|
||||
try:
|
||||
unregister_plugin_tools(plugin_name)
|
||||
except Exception as e:
|
||||
logger.warning(f"注销插件 {plugin_name} 的工具时出错: {e}")
|
||||
|
||||
# 生命周期: on_unload(释放资源)
|
||||
await plugin.on_unload()
|
||||
|
||||
del self.plugins[plugin_name]
|
||||
del self.plugin_classes[plugin_name]
|
||||
if plugin_name in self.plugin_info.keys():
|
||||
@@ -256,7 +291,7 @@ class PluginManager(metaclass=Singleton):
|
||||
return unloaded_plugins, failed_unloads
|
||||
|
||||
async def reload_plugin(self, plugin_name: str) -> bool:
|
||||
"""重载单个插件"""
|
||||
"""重载单个插件(支持状态保存和恢复)"""
|
||||
if plugin_name not in self.plugin_classes:
|
||||
return False
|
||||
|
||||
@@ -270,7 +305,15 @@ class PluginManager(metaclass=Singleton):
|
||||
plugin_class = self.plugin_classes[plugin_name]
|
||||
module_name = plugin_class.__module__
|
||||
|
||||
# 先卸载插件
|
||||
# 生命周期: on_reload(保存状态)
|
||||
saved_state = {}
|
||||
if plugin_name in self.plugins:
|
||||
try:
|
||||
saved_state = await self.plugins[plugin_name].on_reload()
|
||||
except Exception as e:
|
||||
logger.warning(f"保存插件 {plugin_name} 状态失败: {e}")
|
||||
|
||||
# 卸载插件
|
||||
if not await self.unload_plugin(plugin_name):
|
||||
return False
|
||||
|
||||
@@ -284,8 +327,15 @@ class PluginManager(metaclass=Singleton):
|
||||
issubclass(obj, PluginBase) and
|
||||
obj != PluginBase and
|
||||
obj.__name__ == plugin_name):
|
||||
# 使用新的插件类而不是旧的
|
||||
return await self.load_plugin(obj)
|
||||
# 加载新插件
|
||||
if await self.load_plugin(obj):
|
||||
# 恢复状态
|
||||
if saved_state and plugin_name in self.plugins:
|
||||
try:
|
||||
await self.plugins[plugin_name].restore_state(saved_state)
|
||||
except Exception as e:
|
||||
logger.warning(f"恢复插件 {plugin_name} 状态失败: {e}")
|
||||
return True
|
||||
|
||||
return False
|
||||
except Exception as e:
|
||||
@@ -349,13 +399,42 @@ class PluginManager(metaclass=Singleton):
|
||||
|
||||
def get_plugin_info(self, plugin_name: str = None) -> Union[dict, List[dict]]:
|
||||
"""获取插件信息
|
||||
|
||||
|
||||
Args:
|
||||
plugin_name: 插件名称,如果为None则返回所有插件信息
|
||||
|
||||
|
||||
Returns:
|
||||
如果指定插件名,返回单个插件信息字典;否则返回所有插件信息列表
|
||||
"""
|
||||
if plugin_name:
|
||||
return self.plugin_info.get(plugin_name)
|
||||
return list(self.plugin_info.values())
|
||||
|
||||
def _get_tools_config(self) -> dict:
|
||||
"""获取工具配置(用于工具注册)"""
|
||||
try:
|
||||
# 尝试从 AIChat 插件配置读取
|
||||
import tomllib
|
||||
from pathlib import Path
|
||||
aichat_config_path = Path("plugins/AIChat/config.toml")
|
||||
if aichat_config_path.exists():
|
||||
with open(aichat_config_path, "rb") as f:
|
||||
aichat_config = tomllib.load(f)
|
||||
return aichat_config.get("tools", {})
|
||||
except Exception:
|
||||
pass
|
||||
return {"mode": "all", "whitelist": [], "blacklist": []}
|
||||
|
||||
def _get_timeout_config(self) -> dict:
|
||||
"""获取工具超时配置"""
|
||||
try:
|
||||
import tomllib
|
||||
from pathlib import Path
|
||||
aichat_config_path = Path("plugins/AIChat/config.toml")
|
||||
if aichat_config_path.exists():
|
||||
with open(aichat_config_path, "rb") as f:
|
||||
aichat_config = tomllib.load(f)
|
||||
return aichat_config.get("tools", {}).get("timeout", {})
|
||||
except Exception:
|
||||
pass
|
||||
return {"default": 60}
|
||||
|
||||
488
utils/tool_executor.py
Normal file
488
utils/tool_executor.py
Normal file
@@ -0,0 +1,488 @@
|
||||
"""
|
||||
工具执行器模块
|
||||
|
||||
提供工具调用的高级执行逻辑:
|
||||
- 批量工具执行(支持并行)
|
||||
- 工具调用链处理
|
||||
- 执行日志和审计
|
||||
- 结果聚合
|
||||
|
||||
使用示例:
|
||||
from utils.tool_executor import ToolExecutor, ToolCallRequest
|
||||
|
||||
executor = ToolExecutor()
|
||||
|
||||
# 单个工具执行
|
||||
result = await executor.execute_single(
|
||||
tool_call={"id": "call_1", "function": {"name": "get_weather", "arguments": "{}"}},
|
||||
bot=bot,
|
||||
from_wxid=wxid,
|
||||
)
|
||||
|
||||
# 批量工具执行
|
||||
results = await executor.execute_batch(
|
||||
tool_calls=[...],
|
||||
bot=bot,
|
||||
from_wxid=wxid,
|
||||
parallel=True,
|
||||
)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolCallRequest:
|
||||
"""工具调用请求"""
|
||||
id: str
|
||||
name: str
|
||||
arguments: Dict[str, Any]
|
||||
raw_arguments: str = "" # 原始 JSON 字符串
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolCallResult:
|
||||
"""工具调用结果"""
|
||||
id: str
|
||||
name: str
|
||||
success: bool = True
|
||||
message: str = ""
|
||||
raw_result: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
# 控制标志
|
||||
need_ai_reply: bool = False
|
||||
already_sent: bool = False
|
||||
send_result_text: bool = False
|
||||
no_reply: bool = False
|
||||
save_to_memory: bool = False
|
||||
|
||||
# 执行信息
|
||||
execution_time_ms: float = 0.0
|
||||
error: Optional[str] = None
|
||||
|
||||
def to_message(self) -> Dict[str, Any]:
|
||||
"""转换为 OpenAI 兼容的 tool message"""
|
||||
content = self.message if self.success else f"错误: {self.error or self.message}"
|
||||
return {
|
||||
"role": "tool",
|
||||
"tool_call_id": self.id,
|
||||
"content": content
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExecutionStats:
|
||||
"""执行统计"""
|
||||
total_calls: int = 0
|
||||
successful_calls: int = 0
|
||||
failed_calls: int = 0
|
||||
timeout_calls: int = 0
|
||||
total_time_ms: float = 0.0
|
||||
avg_time_ms: float = 0.0
|
||||
|
||||
|
||||
class ToolExecutor:
|
||||
"""
|
||||
工具执行器
|
||||
|
||||
提供统一的工具执行接口:
|
||||
- 参数解析和校验
|
||||
- 超时保护
|
||||
- 错误处理
|
||||
- 执行统计
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
default_timeout: float = 60.0,
|
||||
max_parallel: int = 5,
|
||||
validate_args: bool = True,
|
||||
):
|
||||
self.default_timeout = default_timeout
|
||||
self.max_parallel = max_parallel
|
||||
self.validate_args = validate_args
|
||||
self._stats = ExecutionStats()
|
||||
|
||||
def parse_tool_call(self, tool_call: Dict[str, Any]) -> ToolCallRequest:
|
||||
"""
|
||||
解析 OpenAI 格式的工具调用
|
||||
|
||||
Args:
|
||||
tool_call: OpenAI 返回的 tool_call 对象
|
||||
|
||||
Returns:
|
||||
ToolCallRequest 对象
|
||||
"""
|
||||
call_id = tool_call.get("id", "")
|
||||
function = tool_call.get("function", {})
|
||||
name = function.get("name", "")
|
||||
raw_args = function.get("arguments", "{}")
|
||||
|
||||
# 解析 arguments JSON
|
||||
try:
|
||||
arguments = json.loads(raw_args) if raw_args else {}
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"[ToolExecutor] 解析参数失败: {e}, raw={raw_args[:100]}")
|
||||
arguments = {}
|
||||
|
||||
return ToolCallRequest(
|
||||
id=call_id,
|
||||
name=name,
|
||||
arguments=arguments,
|
||||
raw_arguments=raw_args,
|
||||
)
|
||||
|
||||
async def execute_single(
|
||||
self,
|
||||
tool_call: Dict[str, Any],
|
||||
bot,
|
||||
from_wxid: str,
|
||||
timeout_override: Optional[float] = None,
|
||||
) -> ToolCallResult:
|
||||
"""
|
||||
执行单个工具调用
|
||||
|
||||
Args:
|
||||
tool_call: OpenAI 格式的 tool_call
|
||||
bot: WechatHookClient 实例
|
||||
from_wxid: 消息来源 wxid
|
||||
timeout_override: 覆盖默认超时
|
||||
|
||||
Returns:
|
||||
ToolCallResult 对象
|
||||
"""
|
||||
from utils.tool_registry import get_tool_registry
|
||||
from utils.llm_tooling import validate_tool_arguments, ToolResult
|
||||
|
||||
start_time = time.time()
|
||||
request = self.parse_tool_call(tool_call)
|
||||
registry = get_tool_registry()
|
||||
|
||||
result = ToolCallResult(
|
||||
id=request.id,
|
||||
name=request.name,
|
||||
)
|
||||
|
||||
# 获取工具定义
|
||||
tool_def = registry.get(request.name)
|
||||
if not tool_def:
|
||||
result.success = False
|
||||
result.error = f"工具 {request.name} 不存在"
|
||||
result.message = result.error
|
||||
self._update_stats(False, time.time() - start_time)
|
||||
return result
|
||||
|
||||
# 参数校验
|
||||
if self.validate_args:
|
||||
schema = tool_def.schema.get("function", {}).get("parameters", {})
|
||||
ok, error_msg, validated_args = validate_tool_arguments(
|
||||
request.name, request.arguments, schema
|
||||
)
|
||||
if not ok:
|
||||
result.success = False
|
||||
result.error = error_msg
|
||||
result.message = error_msg
|
||||
self._update_stats(False, time.time() - start_time)
|
||||
return result
|
||||
request.arguments = validated_args
|
||||
|
||||
# 执行工具
|
||||
timeout = timeout_override or tool_def.timeout or self.default_timeout
|
||||
|
||||
try:
|
||||
logger.debug(f"[ToolExecutor] 执行工具: {request.name}")
|
||||
|
||||
raw_result = await asyncio.wait_for(
|
||||
tool_def.executor(request.name, request.arguments, bot, from_wxid),
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
# 解析结果
|
||||
tool_result = ToolResult.from_raw(raw_result)
|
||||
if tool_result:
|
||||
result.success = tool_result.success
|
||||
result.message = tool_result.message
|
||||
result.need_ai_reply = tool_result.need_ai_reply
|
||||
result.already_sent = tool_result.already_sent
|
||||
result.send_result_text = tool_result.send_result_text
|
||||
result.no_reply = tool_result.no_reply
|
||||
result.save_to_memory = tool_result.save_to_memory
|
||||
else:
|
||||
result.message = str(raw_result) if raw_result else "执行完成"
|
||||
|
||||
result.raw_result = raw_result if isinstance(raw_result, dict) else {"result": raw_result}
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
result.execution_time_ms = execution_time * 1000
|
||||
self._update_stats(result.success, execution_time)
|
||||
|
||||
logger.debug(
|
||||
f"[ToolExecutor] 工具 {request.name} 执行完成 "
|
||||
f"({result.execution_time_ms:.1f}ms)"
|
||||
)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
result.success = False
|
||||
result.error = f"执行超时 ({timeout}s)"
|
||||
result.message = result.error
|
||||
self._update_stats(False, time.time() - start_time, timeout=True)
|
||||
logger.warning(f"[ToolExecutor] 工具 {request.name} 执行超时")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
result.success = False
|
||||
result.error = str(e)
|
||||
result.message = f"执行失败: {e}"
|
||||
self._update_stats(False, time.time() - start_time)
|
||||
logger.error(f"[ToolExecutor] 工具 {request.name} 执行异常: {e}")
|
||||
|
||||
return result
|
||||
|
||||
async def execute_batch(
|
||||
self,
|
||||
tool_calls: List[Dict[str, Any]],
|
||||
bot,
|
||||
from_wxid: str,
|
||||
parallel: bool = True,
|
||||
stop_on_error: bool = False,
|
||||
) -> List[ToolCallResult]:
|
||||
"""
|
||||
批量执行工具调用
|
||||
|
||||
Args:
|
||||
tool_calls: 工具调用列表
|
||||
bot: WechatHookClient 实例
|
||||
from_wxid: 消息来源 wxid
|
||||
parallel: 是否并行执行
|
||||
stop_on_error: 遇到错误是否停止
|
||||
|
||||
Returns:
|
||||
ToolCallResult 列表
|
||||
"""
|
||||
if not tool_calls:
|
||||
return []
|
||||
|
||||
if parallel and len(tool_calls) > 1:
|
||||
return await self._execute_parallel(tool_calls, bot, from_wxid, stop_on_error)
|
||||
else:
|
||||
return await self._execute_sequential(tool_calls, bot, from_wxid, stop_on_error)
|
||||
|
||||
async def _execute_sequential(
|
||||
self,
|
||||
tool_calls: List[Dict[str, Any]],
|
||||
bot,
|
||||
from_wxid: str,
|
||||
stop_on_error: bool,
|
||||
) -> List[ToolCallResult]:
|
||||
"""顺序执行"""
|
||||
results = []
|
||||
for tool_call in tool_calls:
|
||||
result = await self.execute_single(tool_call, bot, from_wxid)
|
||||
results.append(result)
|
||||
|
||||
if stop_on_error and not result.success:
|
||||
logger.warning(f"[ToolExecutor] 工具 {result.name} 失败,停止批量执行")
|
||||
break
|
||||
|
||||
return results
|
||||
|
||||
async def _execute_parallel(
|
||||
self,
|
||||
tool_calls: List[Dict[str, Any]],
|
||||
bot,
|
||||
from_wxid: str,
|
||||
stop_on_error: bool,
|
||||
) -> List[ToolCallResult]:
|
||||
"""并行执行(带并发限制)"""
|
||||
semaphore = asyncio.Semaphore(self.max_parallel)
|
||||
|
||||
async def execute_with_limit(tool_call):
|
||||
async with semaphore:
|
||||
return await self.execute_single(tool_call, bot, from_wxid)
|
||||
|
||||
tasks = [execute_with_limit(tc) for tc in tool_calls]
|
||||
|
||||
if stop_on_error:
|
||||
# 使用 gather 但不 return_exceptions,让第一个错误停止执行
|
||||
results = []
|
||||
for coro in asyncio.as_completed(tasks):
|
||||
try:
|
||||
result = await coro
|
||||
results.append(result)
|
||||
if not result.success:
|
||||
# 取消剩余任务
|
||||
for task in tasks:
|
||||
if isinstance(task, asyncio.Task) and not task.done():
|
||||
task.cancel()
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"[ToolExecutor] 并行执行异常: {e}")
|
||||
break
|
||||
return results
|
||||
else:
|
||||
# 全部执行,收集所有结果
|
||||
return await asyncio.gather(*tasks, return_exceptions=False)
|
||||
|
||||
def _update_stats(self, success: bool, execution_time: float, timeout: bool = False):
|
||||
"""更新执行统计"""
|
||||
self._stats.total_calls += 1
|
||||
if success:
|
||||
self._stats.successful_calls += 1
|
||||
else:
|
||||
self._stats.failed_calls += 1
|
||||
if timeout:
|
||||
self._stats.timeout_calls += 1
|
||||
|
||||
self._stats.total_time_ms += execution_time * 1000
|
||||
self._stats.avg_time_ms = self._stats.total_time_ms / self._stats.total_calls
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""获取执行统计"""
|
||||
return {
|
||||
"total_calls": self._stats.total_calls,
|
||||
"successful_calls": self._stats.successful_calls,
|
||||
"failed_calls": self._stats.failed_calls,
|
||||
"timeout_calls": self._stats.timeout_calls,
|
||||
"total_time_ms": self._stats.total_time_ms,
|
||||
"avg_time_ms": self._stats.avg_time_ms,
|
||||
"success_rate": (
|
||||
self._stats.successful_calls / self._stats.total_calls
|
||||
if self._stats.total_calls > 0 else 0
|
||||
),
|
||||
}
|
||||
|
||||
def reset_stats(self):
|
||||
"""重置统计"""
|
||||
self._stats = ExecutionStats()
|
||||
|
||||
|
||||
class ToolCallChain:
|
||||
"""
|
||||
工具调用链
|
||||
|
||||
用于处理需要多轮工具调用的场景,记录调用历史。
|
||||
"""
|
||||
|
||||
def __init__(self, max_rounds: int = 10):
|
||||
self.max_rounds = max_rounds
|
||||
self.history: List[ToolCallResult] = []
|
||||
self.current_round = 0
|
||||
|
||||
def add_result(self, result: ToolCallResult):
|
||||
"""添加调用结果"""
|
||||
self.history.append(result)
|
||||
|
||||
def add_results(self, results: List[ToolCallResult]):
|
||||
"""添加多个调用结果"""
|
||||
self.history.extend(results)
|
||||
|
||||
def increment_round(self):
|
||||
"""增加轮次"""
|
||||
self.current_round += 1
|
||||
|
||||
def can_continue(self) -> bool:
|
||||
"""检查是否可以继续调用"""
|
||||
return self.current_round < self.max_rounds
|
||||
|
||||
def get_tool_messages(self) -> List[Dict[str, Any]]:
|
||||
"""获取所有工具调用的消息(用于发送给 LLM)"""
|
||||
return [result.to_message() for result in self.history]
|
||||
|
||||
def get_last_results(self, n: int = 1) -> List[ToolCallResult]:
|
||||
"""获取最后 n 个结果"""
|
||||
return self.history[-n:] if self.history else []
|
||||
|
||||
def has_special_flags(self) -> Dict[str, bool]:
|
||||
"""检查是否有特殊标志"""
|
||||
flags = {
|
||||
"need_ai_reply": False,
|
||||
"already_sent": False,
|
||||
"no_reply": False,
|
||||
"save_to_memory": False,
|
||||
"send_result_text": False,
|
||||
}
|
||||
|
||||
for result in self.history:
|
||||
if result.need_ai_reply:
|
||||
flags["need_ai_reply"] = True
|
||||
if result.already_sent:
|
||||
flags["already_sent"] = True
|
||||
if result.no_reply:
|
||||
flags["no_reply"] = True
|
||||
if result.save_to_memory:
|
||||
flags["save_to_memory"] = True
|
||||
if result.send_result_text:
|
||||
flags["send_result_text"] = True
|
||||
|
||||
return flags
|
||||
|
||||
def get_summary(self) -> str:
|
||||
"""获取调用链摘要"""
|
||||
if not self.history:
|
||||
return "无工具调用"
|
||||
|
||||
successful = sum(1 for r in self.history if r.success)
|
||||
failed = len(self.history) - successful
|
||||
total_time = sum(r.execution_time_ms for r in self.history)
|
||||
|
||||
tools_called = [r.name for r in self.history]
|
||||
|
||||
return (
|
||||
f"调用链: {len(self.history)} 个工具, "
|
||||
f"成功 {successful}, 失败 {failed}, "
|
||||
f"总耗时 {total_time:.1f}ms, "
|
||||
f"工具: {', '.join(tools_called)}"
|
||||
)
|
||||
|
||||
|
||||
# ==================== 便捷函数 ====================
|
||||
|
||||
_default_executor: Optional[ToolExecutor] = None
|
||||
|
||||
|
||||
def get_tool_executor(
|
||||
default_timeout: float = 60.0,
|
||||
max_parallel: int = 5,
|
||||
) -> ToolExecutor:
|
||||
"""获取默认工具执行器"""
|
||||
global _default_executor
|
||||
if _default_executor is None:
|
||||
_default_executor = ToolExecutor(
|
||||
default_timeout=default_timeout,
|
||||
max_parallel=max_parallel,
|
||||
)
|
||||
return _default_executor
|
||||
|
||||
|
||||
async def execute_tool_calls(
|
||||
tool_calls: List[Dict[str, Any]],
|
||||
bot,
|
||||
from_wxid: str,
|
||||
parallel: bool = True,
|
||||
) -> List[ToolCallResult]:
|
||||
"""便捷函数:执行工具调用列表"""
|
||||
executor = get_tool_executor()
|
||||
return await executor.execute_batch(tool_calls, bot, from_wxid, parallel=parallel)
|
||||
|
||||
|
||||
# ==================== 导出 ====================
|
||||
|
||||
__all__ = [
|
||||
'ToolCallRequest',
|
||||
'ToolCallResult',
|
||||
'ExecutionStats',
|
||||
'ToolExecutor',
|
||||
'ToolCallChain',
|
||||
'get_tool_executor',
|
||||
'execute_tool_calls',
|
||||
]
|
||||
286
utils/tool_registry.py
Normal file
286
utils/tool_registry.py
Normal file
@@ -0,0 +1,286 @@
|
||||
"""
|
||||
工具注册中心
|
||||
|
||||
集中管理所有 LLM 工具的注册、查找和执行
|
||||
- O(1) 工具查找(替代 O(n) 插件遍历)
|
||||
- 统一的超时保护
|
||||
- 工具元信息管理
|
||||
|
||||
使用示例:
|
||||
from utils.tool_registry import get_tool_registry
|
||||
|
||||
registry = get_tool_registry()
|
||||
|
||||
# 注册工具
|
||||
registry.register(
|
||||
name="generate_image",
|
||||
plugin_name="AIChat",
|
||||
schema={...},
|
||||
executor=some_async_func,
|
||||
timeout=120
|
||||
)
|
||||
|
||||
# 执行工具
|
||||
result = await registry.execute("generate_image", arguments, bot, from_wxid)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass, field
|
||||
from threading import Lock
|
||||
from typing import Any, Callable, Dict, List, Optional, Awaitable
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolDefinition:
|
||||
"""工具定义"""
|
||||
name: str
|
||||
plugin_name: str
|
||||
schema: Dict[str, Any] # OpenAI-compatible tool schema
|
||||
executor: Callable[..., Awaitable[Dict[str, Any]]]
|
||||
timeout: float = 60.0
|
||||
priority: int = 50 # 同名工具时优先级高的生效
|
||||
description: str = ""
|
||||
|
||||
def __post_init__(self):
|
||||
# 从 schema 提取描述
|
||||
if not self.description and self.schema:
|
||||
func_def = self.schema.get("function", {})
|
||||
self.description = func_def.get("description", "")
|
||||
|
||||
|
||||
class ToolRegistry:
|
||||
"""
|
||||
工具注册中心(线程安全单例)
|
||||
|
||||
功能:
|
||||
- 工具注册与注销
|
||||
- O(1) 工具查找
|
||||
- 统一超时保护执行
|
||||
- 工具列表导出(供 LLM 使用)
|
||||
"""
|
||||
|
||||
_instance: Optional["ToolRegistry"] = None
|
||||
_lock = Lock()
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
instance = super().__new__(cls)
|
||||
instance._initialized = False
|
||||
cls._instance = instance
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
self._tools: Dict[str, ToolDefinition] = {}
|
||||
self._tools_by_plugin: Dict[str, List[str]] = {} # plugin_name -> [tool_names]
|
||||
self._registry_lock = Lock()
|
||||
self._initialized = True
|
||||
logger.debug("ToolRegistry 初始化完成")
|
||||
|
||||
def register(
|
||||
self,
|
||||
name: str,
|
||||
plugin_name: str,
|
||||
schema: Dict[str, Any],
|
||||
executor: Callable[..., Awaitable[Dict[str, Any]]],
|
||||
timeout: float = 60.0,
|
||||
priority: int = 50,
|
||||
) -> bool:
|
||||
"""
|
||||
注册工具
|
||||
|
||||
Args:
|
||||
name: 工具名称(唯一标识)
|
||||
plugin_name: 所属插件名
|
||||
schema: OpenAI-compatible tool schema
|
||||
executor: 异步执行函数,签名: async (tool_name, arguments, bot, from_wxid) -> dict
|
||||
timeout: 执行超时(秒)
|
||||
priority: 优先级(同名工具时高优先级覆盖低优先级)
|
||||
|
||||
Returns:
|
||||
是否注册成功
|
||||
"""
|
||||
with self._registry_lock:
|
||||
# 检查是否已存在同名工具
|
||||
existing = self._tools.get(name)
|
||||
if existing:
|
||||
if existing.priority >= priority:
|
||||
logger.debug(f"工具 {name} 已存在且优先级更高,跳过注册")
|
||||
return False
|
||||
logger.info(f"工具 {name} 被 {plugin_name} 覆盖(优先级 {priority} > {existing.priority})")
|
||||
# 从旧插件的工具列表中移除
|
||||
old_plugin_tools = self._tools_by_plugin.get(existing.plugin_name, [])
|
||||
if name in old_plugin_tools:
|
||||
old_plugin_tools.remove(name)
|
||||
|
||||
# 注册新工具
|
||||
tool_def = ToolDefinition(
|
||||
name=name,
|
||||
plugin_name=plugin_name,
|
||||
schema=schema,
|
||||
executor=executor,
|
||||
timeout=timeout,
|
||||
priority=priority,
|
||||
)
|
||||
self._tools[name] = tool_def
|
||||
|
||||
# 更新插件工具映射
|
||||
if plugin_name not in self._tools_by_plugin:
|
||||
self._tools_by_plugin[plugin_name] = []
|
||||
if name not in self._tools_by_plugin[plugin_name]:
|
||||
self._tools_by_plugin[plugin_name].append(name)
|
||||
|
||||
logger.debug(f"注册工具: {name} (插件: {plugin_name}, 超时: {timeout}s)")
|
||||
return True
|
||||
|
||||
def unregister(self, name: str) -> bool:
|
||||
"""注销工具"""
|
||||
with self._registry_lock:
|
||||
tool_def = self._tools.pop(name, None)
|
||||
if tool_def:
|
||||
plugin_tools = self._tools_by_plugin.get(tool_def.plugin_name, [])
|
||||
if name in plugin_tools:
|
||||
plugin_tools.remove(name)
|
||||
logger.debug(f"注销工具: {name}")
|
||||
return True
|
||||
return False
|
||||
|
||||
def unregister_plugin(self, plugin_name: str) -> int:
|
||||
"""
|
||||
注销插件的所有工具
|
||||
|
||||
Args:
|
||||
plugin_name: 插件名
|
||||
|
||||
Returns:
|
||||
注销的工具数量
|
||||
"""
|
||||
with self._registry_lock:
|
||||
tool_names = self._tools_by_plugin.pop(plugin_name, [])
|
||||
count = 0
|
||||
for name in tool_names:
|
||||
if self._tools.pop(name, None):
|
||||
count += 1
|
||||
if count > 0:
|
||||
logger.info(f"注销插件 {plugin_name} 的 {count} 个工具")
|
||||
return count
|
||||
|
||||
def get(self, name: str) -> Optional[ToolDefinition]:
|
||||
"""获取工具定义(O(1) 查找)"""
|
||||
return self._tools.get(name)
|
||||
|
||||
def get_all_schemas(self) -> List[Dict[str, Any]]:
|
||||
"""获取所有工具的 schema 列表(供 LLM 使用)"""
|
||||
return [tool.schema for tool in self._tools.values()]
|
||||
|
||||
def get_plugin_schemas(self, plugin_name: str) -> List[Dict[str, Any]]:
|
||||
"""获取指定插件的工具 schema 列表"""
|
||||
tool_names = self._tools_by_plugin.get(plugin_name, [])
|
||||
return [self._tools[name].schema for name in tool_names if name in self._tools]
|
||||
|
||||
def list_tools(self) -> List[str]:
|
||||
"""列出所有工具名"""
|
||||
return list(self._tools.keys())
|
||||
|
||||
def list_plugin_tools(self, plugin_name: str) -> List[str]:
|
||||
"""列出插件的所有工具名"""
|
||||
return self._tools_by_plugin.get(plugin_name, []).copy()
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
name: str,
|
||||
arguments: Dict[str, Any],
|
||||
bot,
|
||||
from_wxid: str,
|
||||
timeout_override: float = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
执行工具(带超时保护和统一错误处理)
|
||||
|
||||
Args:
|
||||
name: 工具名
|
||||
arguments: 工具参数
|
||||
bot: WechatHookClient 实例
|
||||
from_wxid: 消息来源 wxid
|
||||
timeout_override: 覆盖默认超时时间
|
||||
|
||||
Returns:
|
||||
工具执行结果字典
|
||||
"""
|
||||
from utils.errors import (
|
||||
ToolNotFoundError, ToolTimeoutError, ToolExecutionError,
|
||||
handle_error
|
||||
)
|
||||
|
||||
tool_def = self._tools.get(name)
|
||||
if not tool_def:
|
||||
err = ToolNotFoundError(f"工具 {name} 不存在")
|
||||
return err.to_dict()
|
||||
|
||||
timeout = timeout_override if timeout_override is not None else tool_def.timeout
|
||||
|
||||
try:
|
||||
result = await asyncio.wait_for(
|
||||
tool_def.executor(name, arguments, bot, from_wxid),
|
||||
timeout=timeout
|
||||
)
|
||||
return result
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
err = ToolTimeoutError(
|
||||
message=f"工具 {name} 执行超时 ({timeout}s)",
|
||||
user_message=f"工具执行超时 ({timeout}s)",
|
||||
context={"tool_name": name, "timeout": timeout}
|
||||
)
|
||||
logger.warning(err.message)
|
||||
result = err.to_dict()
|
||||
result["timeout"] = True
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
error_result = handle_error(
|
||||
e,
|
||||
context=f"执行工具 {name}",
|
||||
log=True,
|
||||
)
|
||||
return error_result.to_dict()
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""获取注册统计信息"""
|
||||
return {
|
||||
"total_tools": len(self._tools),
|
||||
"plugins": len(self._tools_by_plugin),
|
||||
"tools_by_plugin": {
|
||||
plugin: len(tools)
|
||||
for plugin, tools in self._tools_by_plugin.items()
|
||||
}
|
||||
}
|
||||
|
||||
def clear(self):
|
||||
"""清空所有注册(用于测试或重置)"""
|
||||
with self._registry_lock:
|
||||
self._tools.clear()
|
||||
self._tools_by_plugin.clear()
|
||||
logger.info("ToolRegistry 已清空")
|
||||
|
||||
|
||||
# ==================== 便捷函数 ====================
|
||||
|
||||
def get_tool_registry() -> ToolRegistry:
|
||||
"""获取工具注册中心实例"""
|
||||
return ToolRegistry()
|
||||
|
||||
|
||||
# ==================== 导出列表 ====================
|
||||
|
||||
__all__ = [
|
||||
'ToolDefinition',
|
||||
'ToolRegistry',
|
||||
'get_tool_registry',
|
||||
]
|
||||
Reference in New Issue
Block a user