This commit is contained in:
2025-12-31 17:47:39 +08:00
38 changed files with 4435 additions and 1343 deletions

190
utils/config_manager.py Normal file
View File

@@ -0,0 +1,190 @@
"""
统一配置管理器
单例模式,提供:
- 配置缓存,避免重复读取文件
- 配置热更新检测
- 类型安全的配置访问
"""
import tomllib
from pathlib import Path
from threading import Lock
from typing import Any, Dict, Optional
from loguru import logger
class ConfigManager:
    """
    Thread-safe configuration singleton backed by main_config.toml.

    Features:
    - caches the parsed config so the file is not re-read on every access
    - hot-reload detection based on the file's mtime
    - type-safe accessors for single keys and whole sections

    Usage:
        from utils.config_manager import get_config

        admins = get_config().get("Bot", "admins", [])
        bot_config = get_config().get_section("Bot")
        if get_config().reload_if_changed():
            logger.info("config reloaded")
    """

    _instance: Optional["ConfigManager"] = None
    _lock = Lock()

    def __new__(cls):
        # Double-checked locking: cheap fast path, lock held only while
        # the first (and only) instance is being created.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    instance = super().__new__(cls)
                    instance._initialized = False
                    cls._instance = instance
        return cls._instance

    def __init__(self):
        # __init__ runs on every ConfigManager() call; initialize only once.
        if self._initialized:
            return
        self._config: Dict[str, Any] = {}
        self._config_path = Path("main_config.toml")
        self._file_mtime: float = 0
        self._config_lock = Lock()
        self._reload()
        self._initialized = True
        logger.debug("ConfigManager 初始化完成")

    def _reload(self) -> bool:
        """Reload the config file.

        Returns:
            True when the file was (re)parsed; False when it is unchanged,
            missing, or unreadable.
        """
        try:
            if not self._config_path.exists():
                logger.warning(f"配置文件不存在: {self._config_path}")
                return False
            current_mtime = self._config_path.stat().st_mtime
            with self._config_lock:
                # Fix: the mtime comparison must happen under the same lock
                # as the load; otherwise two concurrent callers can both pass
                # the check and parse the file twice.
                if current_mtime == self._file_mtime and self._config:
                    return False  # file unchanged
                with open(self._config_path, "rb") as f:
                    self._config = tomllib.load(f)
                self._file_mtime = current_mtime
            logger.debug("配置文件已重新加载")
            return True
        except Exception as e:
            logger.error(f"加载配置文件失败: {e}")
            return False

    def get(self, section: str, key: str, default: Any = None) -> Any:
        """
        Get a single config value.

        Args:
            section: section name, e.g. "Bot"
            key: key name, e.g. "admins"
            default: fallback value

        Returns:
            The configured value, or the default when missing.
        """
        return self._config.get(section, {}).get(key, default)

    def get_section(self, section: str) -> Dict[str, Any]:
        """
        Get a whole config section.

        Args:
            section: section name

        Returns:
            A shallow copy of the section dict (mutating it does not
            affect the cached config).
        """
        return self._config.get(section, {}).copy()

    def get_all(self) -> Dict[str, Any]:
        """Return a shallow (read-only by convention) copy of the config."""
        return self._config.copy()

    def reload_if_changed(self) -> bool:
        """
        Reload when the file's mtime differs from the cached one.

        Returns:
            True when a reload actually happened.
        """
        try:
            if not self._config_path.exists():
                return False
            current_mtime = self._config_path.stat().st_mtime
            if current_mtime != self._file_mtime:
                return self._reload()
        except Exception:
            # Best-effort check: a stat failure just means "no reload now".
            pass
        return False

    def force_reload(self) -> bool:
        """Force a reload by invalidating the cached mtime."""
        self._file_mtime = 0
        return self._reload()
# ==================== Convenience functions ====================

def get_config() -> ConfigManager:
    """Return the process-wide ConfigManager singleton."""
    return ConfigManager()
def get_bot_config() -> Dict[str, Any]:
    """Shortcut for the [Bot] config section."""
    return get_config().get_section("Bot")


def get_performance_config() -> Dict[str, Any]:
    """Shortcut for the [Performance] config section."""
    return get_config().get_section("Performance")


def get_database_config() -> Dict[str, Any]:
    """Shortcut for the [Database] config section."""
    return get_config().get_section("Database")


def get_scheduler_config() -> Dict[str, Any]:
    """Shortcut for the [Scheduler] config section."""
    return get_config().get_section("Scheduler")


def get_queue_config() -> Dict[str, Any]:
    """Shortcut for the [Queue] config section."""
    return get_config().get_section("Queue")


def get_concurrency_config() -> Dict[str, Any]:
    """Shortcut for the [Concurrency] config section."""
    return get_config().get_section("Concurrency")


def get_webui_config() -> Dict[str, Any]:
    """Shortcut for the [WebUI] config section."""
    return get_config().get_section("WebUI")
# ==================== Exports ====================
__all__ = [
    'ConfigManager',
    'get_config',
    'get_bot_config',
    'get_performance_config',
    'get_database_config',
    'get_scheduler_config',
    'get_queue_config',
    'get_concurrency_config',
    'get_webui_config',
]

View File

@@ -1,5 +1,12 @@
"""
消息处理装饰器模块
提供插件消息处理和定时任务的装饰器
使用工厂模式消除重复代码
"""
from functools import wraps
from typing import Callable, Union
from typing import Callable, Dict, Union
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.cron import CronTrigger
@@ -8,15 +15,16 @@ from apscheduler.triggers.interval import IntervalTrigger
scheduler = AsyncIOScheduler()
# ==================== 定时任务装饰器 ====================
def schedule(
trigger: Union[str, CronTrigger, IntervalTrigger],
**trigger_args
) -> Callable:
"""
定时任务装饰器
例子:
例子:
- @schedule('interval', seconds=30)
- @schedule('cron', hour=8, minute=30, second=30)
- @schedule('date', run_date='2024-01-01 00:00:00')
@@ -44,23 +52,16 @@ def add_job_safe(scheduler: AsyncIOScheduler, job_id: str, func: Callable, bot,
"""添加函数到定时任务中,如果存在则先删除现有的任务"""
try:
scheduler.remove_job(job_id)
except:
except Exception:
pass
# 读取调度器配置
# 使用统一配置管理器读取调度器配置
try:
import tomllib
from pathlib import Path
config_path = Path("main_config.toml")
if config_path.exists():
with open(config_path, "rb") as f:
config = tomllib.load(f)
scheduler_config = config.get("Scheduler", {})
else:
scheduler_config = {}
except:
from utils.config_manager import get_scheduler_config
scheduler_config = get_scheduler_config()
except Exception:
scheduler_config = {}
# 应用调度器配置
job_kwargs = {
"coalesce": scheduler_config.get("coalesce", True),
@@ -68,7 +69,7 @@ def add_job_safe(scheduler: AsyncIOScheduler, job_id: str, func: Callable, bot,
"misfire_grace_time": scheduler_config.get("misfire_grace_time", 30)
}
job_kwargs.update(trigger_args)
scheduler.add_job(func, trigger, args=[bot], id=job_id, **job_kwargs)
@@ -76,182 +77,106 @@ def remove_job_safe(scheduler: AsyncIOScheduler, job_id: str):
"""从定时任务中移除任务"""
try:
scheduler.remove_job(job_id)
except:
except Exception:
pass
def on_text_message(priority=50):
"""文本消息装饰器"""
# ==================== 消息装饰器工厂 ====================
def decorator(func):
if callable(priority): # 无参数调用时
f = priority
setattr(f, '_event_type', 'text_message')
setattr(f, '_priority', 50)
return f
# 有参数调用时
setattr(func, '_event_type', 'text_message')
setattr(func, '_priority', min(max(priority, 0), 99))
return func
def _create_message_decorator(event_type: str, description: str):
"""
消息装饰器工厂函数
return decorator if not callable(priority) else decorator(priority)
生成支持两种调用方式的装饰器:
- @on_xxx_message (无参数使用默认优先级50)
- @on_xxx_message(priority=80) (有参数,自定义优先级)
Args:
event_type: 事件类型字符串,如 'text_message'
description: 装饰器描述,用于生成文档字符串
def on_image_message(priority=50):
"""图片消息装饰器"""
Returns:
装饰器函数
"""
def decorator_factory(priority=50):
def decorator(func):
# 处理无参数调用: @on_xxx_message 时 priority 实际是被装饰的函数
if callable(priority):
target_func = priority
actual_priority = 50
else:
target_func = func
actual_priority = min(max(priority, 0), 99)
def decorator(func):
setattr(target_func, '_event_type', event_type)
setattr(target_func, '_priority', actual_priority)
return target_func
# 判断调用方式
if callable(priority):
f = priority
setattr(f, '_event_type', 'image_message')
setattr(f, '_priority', 50)
return f
setattr(func, '_event_type', 'image_message')
setattr(func, '_priority', min(max(priority, 0), 99))
return func
return decorator(priority)
return decorator
return decorator if not callable(priority) else decorator(priority)
decorator_factory.__doc__ = f"{description}装饰器"
decorator_factory.__name__ = f"on_{event_type}"
return decorator_factory
def on_voice_message(priority=50):
"""语音消息装饰器"""
# ==================== 消息类型定义 ====================
def decorator(func):
if callable(priority):
f = priority
setattr(f, '_event_type', 'voice_message')
setattr(f, '_priority', 50)
return f
setattr(func, '_event_type', 'voice_message')
setattr(func, '_priority', min(max(priority, 0), 99))
return func
return decorator if not callable(priority) else decorator(priority)
# 事件类型 -> 中文描述 映射表
MESSAGE_DECORATOR_TYPES: Dict[str, str] = {
'text_message': '文本消息',
'image_message': '图片消息',
'voice_message': '语音消息',
'video_message': '视频消息',
'emoji_message': '表情消息',
'file_message': '文件消息',
'quote_message': '引用消息',
'pat_message': '拍一拍',
'at_message': '@消息',
'system_message': '系统消息',
'other_message': '其他消息',
}
def on_emoji_message(priority=50):
"""表情消息装饰器"""
# ==================== 生成所有消息装饰器 ====================
def decorator(func):
if callable(priority):
f = priority
setattr(f, '_event_type', 'emoji_message')
setattr(f, '_priority', 50)
return f
setattr(func, '_event_type', 'emoji_message')
setattr(func, '_priority', min(max(priority, 0), 99))
return func
return decorator if not callable(priority) else decorator(priority)
# 使用工厂函数生成装饰器
on_text_message = _create_message_decorator('text_message', '文本消息')
on_image_message = _create_message_decorator('image_message', '图片消息')
on_voice_message = _create_message_decorator('voice_message', '语音消息')
on_video_message = _create_message_decorator('video_message', '视频消息')
on_emoji_message = _create_message_decorator('emoji_message', '表情消息')
on_file_message = _create_message_decorator('file_message', '文件消息')
on_quote_message = _create_message_decorator('quote_message', '引用消息')
on_pat_message = _create_message_decorator('pat_message', '拍一拍')
on_at_message = _create_message_decorator('at_message', '@消息')
on_system_message = _create_message_decorator('system_message', '系统消息')
on_other_message = _create_message_decorator('other_message', '其他消息')
def on_file_message(priority=50):
"""文件消息装饰器"""
# ==================== 导出列表 ====================
def decorator(func):
if callable(priority):
f = priority
setattr(f, '_event_type', 'file_message')
setattr(f, '_priority', 50)
return f
setattr(func, '_event_type', 'file_message')
setattr(func, '_priority', min(max(priority, 0), 99))
return func
return decorator if not callable(priority) else decorator(priority)
def on_quote_message(priority=50):
    """Quote-message decorator.

    Supports both @on_quote_message and @on_quote_message(priority=N);
    the priority is clamped to [0, 99].
    """
    def mark(target):
        target._event_type = 'quote_message'
        target._priority = min(max(priority, 0), 99)
        return target

    if callable(priority):
        # Bare usage: `priority` is actually the decorated function.
        target, priority = priority, 50
        return mark(target)
    return mark
def on_video_message(priority=50):
    """Video-message decorator.

    Supports @on_video_message and @on_video_message(priority=N);
    the priority is clamped to [0, 99].
    """
    def decorator(func):
        # Bare usage: `priority` is actually the decorated function.
        if callable(priority):
            f = priority
            setattr(f, '_event_type', 'video_message')
            setattr(f, '_priority', 50)
            return f
        setattr(func, '_event_type', 'video_message')
        setattr(func, '_priority', min(max(priority, 0), 99))
        return func
    return decorator if not callable(priority) else decorator(priority)
def on_pat_message(priority=50):
    """Pat ("拍一拍") message decorator.

    Supports @on_pat_message and @on_pat_message(priority=N);
    the priority is clamped to [0, 99].
    """
    def decorator(func):
        # Bare usage: `priority` is actually the decorated function.
        if callable(priority):
            f = priority
            setattr(f, '_event_type', 'pat_message')
            setattr(f, '_priority', 50)
            return f
        setattr(func, '_event_type', 'pat_message')
        setattr(func, '_priority', min(max(priority, 0), 99))
        return func
    return decorator if not callable(priority) else decorator(priority)
def on_at_message(priority=50):
    """@-mention message decorator.

    Supports @on_at_message and @on_at_message(priority=N);
    the priority is clamped to [0, 99].
    """
    def decorator(func):
        # Bare usage: `priority` is actually the decorated function.
        if callable(priority):
            f = priority
            setattr(f, '_event_type', 'at_message')
            setattr(f, '_priority', 50)
            return f
        setattr(func, '_event_type', 'at_message')
        setattr(func, '_priority', min(max(priority, 0), 99))
        return func
    return decorator if not callable(priority) else decorator(priority)
def on_system_message(priority=50):
    """System-message decorator.

    Supports @on_system_message and @on_system_message(priority=N);
    the priority is clamped to [0, 99].

    Fix: the parameterized branch previously tagged handlers with
    'other_message', so system-message handlers registered with an
    explicit priority were routed to the wrong event type.
    """
    def decorator(func):
        # Bare usage: `priority` is actually the decorated function.
        if callable(priority):
            f = priority
            setattr(f, '_event_type', 'system_message')
            setattr(f, '_priority', 50)
            return f
        setattr(func, '_event_type', 'system_message')  # was 'other_message'
        setattr(func, '_priority', min(max(priority, 0), 99))
        return func
    return decorator if not callable(priority) else decorator(priority)
def on_other_message(priority=50):
    """Catch-all "other message" decorator.

    Supports @on_other_message and @on_other_message(priority=N);
    the priority is clamped to [0, 99].
    """
    def decorator(func):
        # Bare usage: `priority` is actually the decorated function.
        if callable(priority):
            f = priority
            setattr(f, '_event_type', 'other_message')
            setattr(f, '_priority', 50)
            return f
        setattr(func, '_event_type', 'other_message')
        setattr(func, '_priority', min(max(priority, 0), 99))
        return func
    return decorator if not callable(priority) else decorator(priority)
__all__ = [
    # Scheduling
    'scheduler',
    'schedule',
    'add_job_safe',
    'remove_job_safe',
    # Message decorators
    'on_text_message',
    'on_image_message',
    'on_voice_message',
    'on_video_message',
    'on_emoji_message',
    'on_file_message',
    'on_quote_message',
    'on_pat_message',
    'on_at_message',
    'on_system_message',
    'on_other_message',
    # Utilities
    'MESSAGE_DECORATOR_TYPES',
    # NOTE(review): exporting an underscore-prefixed helper is unusual —
    # confirm external callers really need the factory.
    '_create_message_decorator',
]

438
utils/errors.py Normal file
View File

@@ -0,0 +1,438 @@
"""
统一错误处理模块
提供:
- 自定义异常类层次结构
- 错误包装和转换
- 用户友好的错误消息
- 错误日志和追踪
使用示例:
from utils.errors import PluginError, ToolExecutionError, handle_error
try:
await some_operation()
except Exception as e:
result = handle_error(e, context="执行工具")
# result = {"success": False, "error": "...", "error_type": "..."}
"""
from __future__ import annotations
import traceback
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, Optional, Type
from loguru import logger
# ==================== Error type taxonomy ====================

class ErrorType(Enum):
    """Coarse classification of errors, used in API payloads."""
    UNKNOWN = "unknown"
    PLUGIN = "plugin"
    TOOL = "tool"
    CONFIG = "config"
    NETWORK = "network"
    TIMEOUT = "timeout"
    VALIDATION = "validation"
    PERMISSION = "permission"
    RESOURCE = "resource"


# ==================== Exception base class ====================

class BotError(Exception):
    """Base class for all bot errors.

    Class attributes act as per-subclass defaults:
        error_type: category exposed in API payloads
        user_message: default user-facing text
        log_level: logger level name used by handle_error()
    """

    error_type: ErrorType = ErrorType.UNKNOWN
    user_message: str = "发生了一个错误"
    log_level: str = "error"

    def __init__(
        self,
        message: str,
        user_message: Optional[str] = None,  # fixed: was implicit-Optional `str = None`
        cause: Optional[Exception] = None,
        context: Optional[Dict[str, Any]] = None,
    ):
        """
        Args:
            message: internal, log-oriented error description
            user_message: overrides the class-level user_message when given
            cause: original exception being wrapped, if any
            context: arbitrary extra diagnostic data
        """
        super().__init__(message)
        self.message = message
        self._user_message = user_message
        self.cause = cause
        self.context = context or {}

    def get_user_message(self) -> str:
        """Return the instance override or the class default user message."""
        return self._user_message or self.user_message

    def to_dict(self) -> Dict[str, Any]:
        """Serialize for API responses.

        "details" carries the internal message only when it differs from
        the user-facing text, so the same string is never duplicated.
        """
        return {
            "success": False,
            "error": self.get_user_message(),
            "error_type": self.error_type.value,
            "details": self.message if self.message != self.get_user_message() else None,
        }
# ==================== Concrete exception classes ====================

class PluginError(BotError):
    """Plugin-related error."""
    error_type = ErrorType.PLUGIN
    user_message = "插件执行出错"


class PluginLoadError(PluginError):
    """Plugin failed to load."""
    user_message = "插件加载失败"


class PluginNotFoundError(PluginError):
    """The requested plugin does not exist."""
    user_message = "找不到指定的插件"


class ToolExecutionError(BotError):
    """Tool execution error."""
    error_type = ErrorType.TOOL
    user_message = "工具执行失败"


class ToolNotFoundError(ToolExecutionError):
    """The requested tool does not exist."""
    user_message = "找不到指定的工具"


class ToolTimeoutError(ToolExecutionError):
    """Tool execution timed out (reclassified as TIMEOUT, not TOOL)."""
    error_type = ErrorType.TIMEOUT
    user_message = "工具执行超时"


class ConfigError(BotError):
    """Configuration-related error."""
    error_type = ErrorType.CONFIG
    user_message = "配置错误"


class ConfigNotFoundError(ConfigError):
    """A required config entry is missing."""
    user_message = "找不到配置项"


class ConfigValidationError(ConfigError):
    """Config value failed validation (reclassified as VALIDATION)."""
    error_type = ErrorType.VALIDATION
    user_message = "配置格式不正确"


class NetworkError(BotError):
    """Network-related error."""
    error_type = ErrorType.NETWORK
    user_message = "网络请求失败"


class APIError(NetworkError):
    """Remote API call failed."""
    user_message = "API 调用失败"


class ValidationError(BotError):
    """Argument/input validation error."""
    error_type = ErrorType.VALIDATION
    user_message = "参数验证失败"


class PermissionError(BotError):
    """Permission denied.

    NOTE(review): this shadows the builtin PermissionError for code that
    star-imports this module — confirm that is intentional.
    """
    error_type = ErrorType.PERMISSION
    user_message = "没有权限执行此操作"


class ResourceError(BotError):
    """Resource access error (memory, files, etc.)."""
    error_type = ErrorType.RESOURCE
    user_message = "资源访问失败"
# ==================== Error-handling helpers ====================

@dataclass
class ErrorResult:
    """Outcome of handle_error(): a uniform, serializable error record."""
    success: bool = False
    error: str = ""
    error_type: str = "unknown"
    details: Optional[str] = None
    logged: bool = False
    original_exception: Optional[Exception] = field(default=None, repr=False)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize for API responses; "details" is included only when set."""
        payload = {
            "success": self.success,
            "error": self.error,
            "error_type": self.error_type,
        }
        if self.details:
            payload["details"] = self.details
        return payload
def handle_error(
    exception: Exception,
    context: str = "",
    log: bool = True,
    include_traceback: bool = False,
) -> ErrorResult:
    """
    Unified error handling: map any exception to an ErrorResult.

    Dispatch order matters: BotError first (carries its own metadata),
    then asyncio timeouts — checked before the OSError branch so a
    timeout is not misclassified as a network error (on modern Python,
    asyncio.TimeoutError is the builtin TimeoutError, an OSError
    subclass) — then connection errors, then validation errors.

    Args:
        exception: the caught exception
        context: description of where the error happened (log prefix)
        log: whether to log the error
        include_traceback: include the full stack trace in details

    Returns:
        An ErrorResult object.
    """
    # Domain exceptions carry their own type / message / log level.
    if isinstance(exception, BotError):
        result = ErrorResult(
            success=False,
            error=exception.get_user_message(),
            error_type=exception.error_type.value,
            details=exception.message if exception.message != exception.get_user_message() else None,
            original_exception=exception,
        )
        if log:
            # Unknown level names fall back to logger.error.
            log_func = getattr(logger, exception.log_level, logger.error)
            log_func(f"[{context}] {exception.error_type.value}: {exception.message}")
            result.logged = True
        return result
    # Standard timeout exceptions.
    import asyncio
    if isinstance(exception, asyncio.TimeoutError):
        result = ErrorResult(
            success=False,
            error="操作超时",
            error_type=ErrorType.TIMEOUT.value,
            original_exception=exception,
        )
        if log:
            logger.warning(f"[{context}] 超时: {exception}")
            result.logged = True
        return result
    # Connection-level failures.
    if isinstance(exception, (ConnectionError, OSError)):
        result = ErrorResult(
            success=False,
            error="网络连接失败",
            error_type=ErrorType.NETWORK.value,
            details=str(exception),
            original_exception=exception,
        )
        if log:
            logger.error(f"[{context}] 网络错误: {exception}")
            result.logged = True
        return result
    # Validation failures.
    if isinstance(exception, (ValueError, TypeError)):
        result = ErrorResult(
            success=False,
            error="参数错误",
            error_type=ErrorType.VALIDATION.value,
            details=str(exception),
            original_exception=exception,
        )
        if log:
            logger.warning(f"[{context}] 验证错误: {exception}")
            result.logged = True
        return result
    # Everything else: unknown error; message capped at 100 chars.
    error_msg = str(exception) or exception.__class__.__name__
    details = None
    if include_traceback:
        details = traceback.format_exc()
    result = ErrorResult(
        success=False,
        error=f"发生错误: {error_msg[:100]}",
        error_type=ErrorType.UNKNOWN.value,
        details=details,
        original_exception=exception,
    )
    if log:
        logger.error(f"[{context}] 未知错误: {exception}")
        if include_traceback:
            logger.debug(traceback.format_exc())
        result.logged = True
    return result
def wrap_error(
    exception: Exception,
    error_class: Type[BotError],
    message: Optional[str] = None,  # fixed: was implicit-Optional `str = None`
    user_message: Optional[str] = None,
) -> BotError:
    """
    Wrap a standard exception in a BotError subclass.

    Args:
        exception: the original exception
        error_class: target BotError subclass to instantiate
        message: error message (defaults to str(exception))
        user_message: user-facing message override

    Returns:
        An instance of error_class with `cause` set to the original.
    """
    msg = message or str(exception)
    return error_class(
        message=msg,
        user_message=user_message,
        cause=exception,
    )
def safe_error_message(exception: Exception, max_length: int = 200) -> str:
    """
    Return a sanitized error message: secrets redacted, length capped.

    Args:
        exception: exception object
        max_length: maximum message length before truncation

    Returns:
        Safe error-message string.
    """
    import re
    msg = str(exception)
    # One combined alternation instead of four sequential re.sub passes:
    # same redaction behavior, a single scan over the message.
    msg = re.sub(
        r'(?:api[_-]?key|password|token|secret)[=:]\s*\S+',
        '[REDACTED]',
        msg,
        flags=re.IGNORECASE,
    )
    # Truncate overly long messages.
    if len(msg) > max_length:
        msg = msg[:max_length] + "..."
    return msg
# ==================== 装饰器 ====================
def catch_errors(
error_class: Type[BotError] = BotError,
context: str = "",
log: bool = True,
reraise: bool = False,
):
"""
错误捕获装饰器
Args:
error_class: 转换为的错误类
context: 上下文描述
log: 是否记录日志
reraise: 是否重新抛出
Usage:
@catch_errors(ToolExecutionError, context="执行工具")
async def my_tool():
...
"""
def decorator(func):
import asyncio
import functools
@functools.wraps(func)
async def async_wrapper(*args, **kwargs):
try:
return await func(*args, **kwargs)
except BotError:
if reraise:
raise
return None
except Exception as e:
ctx = context or func.__name__
handle_error(e, context=ctx, log=log)
if reraise:
raise wrap_error(e, error_class) from e
return None
@functools.wraps(func)
def sync_wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except BotError:
if reraise:
raise
return None
except Exception as e:
ctx = context or func.__name__
handle_error(e, context=ctx, log=log)
if reraise:
raise wrap_error(e, error_class) from e
return None
if asyncio.iscoroutinefunction(func):
return async_wrapper
return sync_wrapper
return decorator
# ==================== Exports ====================
__all__ = [
    # Enums
    'ErrorType',
    # Exception base
    'BotError',
    # Plugin errors
    'PluginError',
    'PluginLoadError',
    'PluginNotFoundError',
    # Tool errors
    'ToolExecutionError',
    'ToolNotFoundError',
    'ToolTimeoutError',
    # Config errors
    'ConfigError',
    'ConfigNotFoundError',
    'ConfigValidationError',
    # Network errors
    'NetworkError',
    'APIError',
    # Other errors
    'ValidationError',
    'PermissionError',
    'ResourceError',
    # Helper functions
    'ErrorResult',
    'handle_error',
    'wrap_error',
    'safe_error_message',
    # Decorators
    'catch_errors',
]

View File

@@ -1,71 +1,315 @@
import copy
from typing import Callable, Dict, List
"""
事件管理器模块
提供事件的注册、分发和管理:
- 优先级事件处理
- 处理器缓存优化
- 事件统计
- 异常隔离
"""
import asyncio
import time
import traceback
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Set, Tuple
from loguru import logger
@dataclass
class HandlerInfo:
    """Metadata for one registered event handler."""
    handler: Callable                       # bound method to invoke
    instance: object                        # owning plugin instance
    priority: int                           # higher value runs first
    handler_name: str = field(default="")   # "ClassName.method", filled lazily

    def __post_init__(self):
        # Derive a readable display name unless the caller supplied one.
        if not self.handler_name:
            owner = type(self.instance).__name__
            self.handler_name = f"{owner}.{self.handler.__name__}"
@dataclass
class EventStats:
    """Aggregated statistics for a single event type."""
    emit_count: int = 0        # times emit() was called for this event
    handler_calls: int = 0     # total handler invocations
    total_time_ms: float = 0   # cumulative wall time spent in emit()
    error_count: int = 0       # handlers that raised
    stopped_count: int = 0     # chains interrupted by a handler returning False
class EventManager:
_handlers: Dict[str, List[tuple[Callable, object, int]]] = {}
"""
事件管理器
特性:
- 优先级排序(高优先级先执行)
- 处理器可返回 False 中断后续处理
- 异常隔离(单个处理器异常不影响其他)
- 性能统计
"""
# 类级别存储
_handlers: Dict[str, List[HandlerInfo]] = {}
_stats: Dict[str, EventStats] = defaultdict(EventStats)
_handler_cache: Dict[str, List[HandlerInfo]] = {} # 排序后的缓存
_cache_valid: Set[str] = set()
@classmethod
def bind_instance(cls, instance: object):
"""将实例绑定到对应的事件处理函数"""
from loguru import logger
registered_count = 0
for method_name in dir(instance):
method = getattr(instance, method_name)
if hasattr(method, '_event_type'):
event_type = getattr(method, '_event_type')
priority = getattr(method, '_priority', 50)
"""
绑定实例的事件处理方法
if event_type not in cls._handlers:
cls._handlers[event_type] = []
cls._handlers[event_type].append((method, instance, priority))
# 按优先级排序,优先级高的在前
cls._handlers[event_type].sort(key=lambda x: x[2], reverse=True)
registered_count += 1
logger.debug(f"[EventManager] 注册事件处理器: {instance.__class__.__name__}.{method_name} -> {event_type} (优先级={priority})")
扫描实例的所有方法,将带有 _event_type 属性的方法注册为事件处理器。
Args:
instance: 插件实例
"""
registered_count = 0
for method_name in dir(instance):
if method_name.startswith('_'):
continue
try:
method = getattr(instance, method_name)
except Exception:
continue
if not callable(method) or not hasattr(method, '_event_type'):
continue
event_type = getattr(method, '_event_type')
priority = getattr(method, '_priority', 50)
handler_info = HandlerInfo(
handler=method,
instance=instance,
priority=priority,
)
if event_type not in cls._handlers:
cls._handlers[event_type] = []
cls._handlers[event_type].append(handler_info)
# 使缓存失效
cls._cache_valid.discard(event_type)
registered_count += 1
logger.debug(
f"[EventManager] 注册: {handler_info.handler_name} -> "
f"{event_type} (优先级={priority})"
)
if registered_count > 0:
logger.success(f"[EventManager] {instance.__class__.__name__} 注册了 {registered_count} 个事件处理器")
@classmethod
async def emit(cls, event_type: str, *args, **kwargs) -> None:
"""触发事件"""
from loguru import logger
if event_type not in cls._handlers:
logger.debug(f"[EventManager] 事件 {event_type} 没有注册的处理器")
return
logger.debug(f"[EventManager] 触发事件: {event_type}, 处理器数量: {len(cls._handlers[event_type])}")
api_client, message = args
for handler, instance, priority in cls._handlers[event_type]:
try:
logger.debug(f"[EventManager] 调用处理器: {instance.__class__.__name__}.{handler.__name__}")
# 不再深拷贝message让所有处理器共享同一个消息对象
# 这样AutoReply设置的标记可以传递给AIChat
handler_args = (api_client, message)
new_kwargs = kwargs # kwargs也不需要深拷贝
result = await handler(*handler_args, **new_kwargs)
if isinstance(result, bool):
# True 继续执行 False 停止执行
if not result:
break
else:
continue # 我也不知道你返回了个啥玩意,反正继续执行就是了
except Exception as e:
import traceback
logger.error(f"处理器 {handler.__name__} 执行失败: {e}")
logger.error(f"详细错误: {traceback.format_exc()}")
logger.success(
f"[EventManager] {instance.__class__.__name__} "
f"注册了 {registered_count} 个事件处理器"
)
@classmethod
def unbind_instance(cls, instance: object):
"""解绑实例的所有事件处理函数"""
for event_type in cls._handlers:
"""
解绑实例的所有事件处理器
Args:
instance: 插件实例
"""
unbound_count = 0
for event_type in list(cls._handlers.keys()):
original_count = len(cls._handlers[event_type])
cls._handlers[event_type] = [
(handler, inst, priority)
for handler, inst, priority in cls._handlers[event_type]
if inst is not instance
h for h in cls._handlers[event_type]
if h.instance is not instance
]
removed = original_count - len(cls._handlers[event_type])
if removed > 0:
unbound_count += removed
cls._cache_valid.discard(event_type)
# 清理空列表
if not cls._handlers[event_type]:
del cls._handlers[event_type]
if unbound_count > 0:
logger.debug(
f"[EventManager] {instance.__class__.__name__} "
f"解绑了 {unbound_count} 个事件处理器"
)
    @classmethod
    def _get_sorted_handlers(cls, event_type: str) -> List[HandlerInfo]:
        """Return handlers for event_type sorted by priority (cached).

        The sorted list is cached per event type; bind/unbind invalidate
        it by discarding the type from _cache_valid.
        """
        if event_type not in cls._cache_valid:
            handlers = cls._handlers.get(event_type, [])
            # Higher priority first.
            cls._handler_cache[event_type] = sorted(
                handlers,
                key=lambda h: h.priority,
                reverse=True
            )
            cls._cache_valid.add(event_type)
        return cls._handler_cache.get(event_type, [])
    @classmethod
    async def emit(cls, event_type: str, *args, **kwargs) -> bool:
        """
        Fire an event, invoking handlers in priority order.

        Args:
            event_type: event type name
            *args: positional args for handlers (typically api_client, message)
            **kwargs: keyword args for handlers

        Returns:
            True when every handler ran; False when a handler returned
            False and interrupted the chain.
        """
        handlers = cls._get_sorted_handlers(event_type)
        if not handlers:
            logger.debug(f"[EventManager] 事件 {event_type} 没有处理器")
            return True
        # Update statistics for this event type.
        stats = cls._stats[event_type]
        stats.emit_count += 1
        start_time = time.time()
        all_completed = True
        logger.debug(
            f"[EventManager] 触发: {event_type}, "
            f"处理器数量: {len(handlers)}"
        )
        for handler_info in handlers:
            stats.handler_calls += 1
            try:
                logger.debug(f"[EventManager] 调用: {handler_info.handler_name}")
                result = await handler_info.handler(*args, **kwargs)
                # A handler returning exactly False stops the chain;
                # any other return value lets processing continue.
                if result is False:
                    stats.stopped_count += 1
                    all_completed = False
                    logger.debug(
                        f"[EventManager] {handler_info.handler_name} "
                        f"返回 False,中断事件处理"
                    )
                    break
            except Exception as e:
                stats.error_count += 1
                logger.error(
                    f"[EventManager] {handler_info.handler_name} 执行失败: {e}"
                )
                logger.debug(f"详细错误:\n{traceback.format_exc()}")
                # Exceptions are isolated: remaining handlers still run.
        elapsed_ms = (time.time() - start_time) * 1000
        stats.total_time_ms += elapsed_ms
        return all_completed
    @classmethod
    async def emit_parallel(
        cls,
        event_type: str,
        *args,
        max_concurrency: int = 5,
        **kwargs
    ) -> List[Any]:
        """
        Fire an event with handlers running concurrently.

        Priority ordering and False-interruption are ignored; use this
        only when handlers do not depend on each other.

        Args:
            event_type: event type name
            max_concurrency: maximum number of handlers running at once
            *args, **kwargs: passed through to every handler

        Returns:
            List of handler return values (None for handlers that raised).
        """
        handlers = cls._get_sorted_handlers(event_type)
        if not handlers:
            return []
        semaphore = asyncio.Semaphore(max_concurrency)

        async def run_handler(handler_info: HandlerInfo):
            # Semaphore bounds concurrency; per-handler errors degrade to None.
            async with semaphore:
                try:
                    return await handler_info.handler(*args, **kwargs)
                except Exception as e:
                    logger.error(f"[EventManager] {handler_info.handler_name} 失败: {e}")
                    return None

        tasks = [run_handler(h) for h in handlers]
        return await asyncio.gather(*tasks, return_exceptions=True)
    @classmethod
    def get_handlers(cls, event_type: str) -> List[str]:
        """Return the handler names for an event type, in dispatch order."""
        handlers = cls._get_sorted_handlers(event_type)
        return [h.handler_name for h in handlers]

    @classmethod
    def get_all_events(cls) -> List[str]:
        """Return every event type that currently has registered handlers."""
        return list(cls._handlers.keys())

    @classmethod
    def get_stats(cls, event_type: Optional[str] = None) -> Dict[str, Any]:
        """
        Return event statistics.

        Args:
            event_type: a specific event type, or None for all of them

        Returns:
            Stats dict for one event, or a mapping of event -> stats dict.
        """
        if event_type:
            stats = cls._stats.get(event_type, EventStats())
            return {
                "emit_count": stats.emit_count,
                "handler_calls": stats.handler_calls,
                "total_time_ms": stats.total_time_ms,
                # max(..., 1) guards against division by zero when emit_count is 0.
                "avg_time_ms": stats.total_time_ms / max(stats.emit_count, 1),
                "error_count": stats.error_count,
                "stopped_count": stats.stopped_count,
            }
        return {
            event: cls.get_stats(event)
            for event in cls._stats.keys()
        }

    @classmethod
    def reset_stats(cls):
        """Reset all statistics counters."""
        cls._stats.clear()

    @classmethod
    def clear(cls):
        """Remove all handlers, caches and stats (intended for tests)."""
        cls._handlers.clear()
        cls._handler_cache.clear()
        cls._cache_valid.clear()
        cls._stats.clear()
# ==================== 导出 ====================
__all__ = ['EventManager', 'HandlerInfo', 'EventStats']

View File

@@ -2,25 +2,48 @@
HookBot - 机器人核心类
处理消息路由和事件分发
职责单一化:仅负责消息流程编排,具体功能委托给专门模块
"""
import asyncio
import tomllib
import time
from typing import Dict, Any
import random
from typing import Any, Dict, Optional
from loguru import logger
from WechatHook import WechatHookClient, MESSAGE_TYPE_MAP, normalize_message
from utils.event_manager import EventManager
from utils.config_manager import get_bot_config, get_performance_config
from utils.message_filter import MessageFilter
from utils.message_dedup import MessageDeduplicator
from utils.message_stats import MessageStats
class HookBot:
"""
HookBot 核心类
负责消息处理、路由和事件分发
负责消息处理流程编排:
1. 接收消息
2. 去重检查
3. 格式转换
4. 过滤检查
5. 事件分发
具体功能委托给:
- MessageDeduplicator: 消息去重
- MessageFilter: 消息过滤
- MessageStats: 消息统计
"""
# API 响应消息类型(需要忽略)
API_RESPONSE_TYPES = {11032, 11174, 11230}
# 重要消息类型(始终记录日志)
IMPORTANT_MESSAGE_TYPES = {
11058, 11098, 11099, 11025, # 系统消息、群成员变动、登录信息
11051, 11047, 11052, 11055 # 视频、图片、表情、文件消息
}
def __init__(self, client: WechatHookClient):
"""
初始化 HookBot
@@ -29,92 +52,33 @@ class HookBot:
client: WechatHookClient 实例
"""
self.client = client
self.wxid = None
self.nickname = None
self.wxid: Optional[str] = None
self.nickname: Optional[str] = None
# 读取配置
with open("main_config.toml", "rb") as f:
main_config = tomllib.load(f)
# 加载配置
bot_config = get_bot_config()
perf_config = get_performance_config()
bot_config = main_config.get("Bot", {})
# 预设机器人信息
preset_wxid = bot_config.get("wxid") or bot_config.get("bot_wxid")
preset_nickname = bot_config.get("nickname") or bot_config.get("bot_nickname")
if preset_wxid:
self.wxid = preset_wxid
logger.info(f"使用配置中的机器人 wxid: {self.wxid}")
if preset_nickname:
self.nickname = preset_nickname
logger.info(f"使用配置中的机器人昵称: {self.nickname}")
self.ignore_mode = bot_config.get("ignore-mode", "None")
self.whitelist = bot_config.get("whitelist", [])
self.blacklist = bot_config.get("blacklist", [])
# 性能配置
perf_config = main_config.get("Performance", {})
# 日志采样率
self.log_sampling_rate = perf_config.get("log_sampling_rate", 1.0)
# 消息去重(部分环境会重复回调同一条消息,导致插件回复两次
self._dedup_ttl_seconds = perf_config.get("dedup_ttl_seconds", 30)
self._dedup_max_size = perf_config.get("dedup_max_size", 5000)
self._dedup_lock = asyncio.Lock()
self._recent_message_keys: Dict[str, float] = {}
# 消息计数和统计
self.message_count = 0
self.filtered_count = 0
self.processed_count = 0
# 初始化组件(职责委托
self._filter = MessageFilter.from_config(bot_config)
self._dedup = MessageDeduplicator.from_config(perf_config)
self._stats = MessageStats()
logger.info("HookBot 初始化完成")
def _extract_msg_id(self, data: Dict[str, Any]) -> str:
"""从原始回调数据中提取消息ID用于去重"""
for k in ("msgid", "msg_id", "MsgId", "id"):
v = data.get(k)
if v:
return str(v)
return ""
    async def _is_duplicate_message(self, msg_type: int, data: Dict[str, Any]) -> bool:
        """Return True when this callback is a short-window duplicate.

        Dedup is keyed on the platform message id. Without a stable id we
        never dedup — identical content within the same second may be a
        genuine repeat sent by the user.
        """
        msg_id = self._extract_msg_id(data)
        if not msg_id:
            return False
        key = f"msgid:{msg_id}"
        now = time.time()
        ttl = max(float(self._dedup_ttl_seconds or 0), 0.0)
        if ttl <= 0:
            # Dedup disabled via configuration.
            return False
        async with self._dedup_lock:
            last_seen = self._recent_message_keys.get(key)
            if last_seen is not None and (now - last_seen) < ttl:
                return True
            # Record/refresh: pop+insert moves the key to the newest slot
            # (dicts preserve insertion order, oldest entries come first).
            self._recent_message_keys.pop(key, None)
            self._recent_message_keys[key] = now
            # Evict expired keys from the oldest end.
            cutoff = now - ttl
            while self._recent_message_keys:
                first_key = next(iter(self._recent_message_keys))
                if self._recent_message_keys.get(first_key, now) >= cutoff:
                    break
                self._recent_message_keys.pop(first_key, None)
            # Cap the map size so long-running processes don't grow unbounded.
            max_size = int(self._dedup_max_size or 0)
            if max_size > 0:
                while len(self._recent_message_keys) > max_size and self._recent_message_keys:
                    first_key = next(iter(self._recent_message_keys))
                    self._recent_message_keys.pop(first_key, None)
        return False
def update_profile(self, wxid: str, nickname: str):
"""
更新机器人信息
@@ -125,9 +89,10 @@ class HookBot:
"""
self.wxid = wxid
self.nickname = nickname
self._filter.set_bot_wxid(wxid)
logger.info(f"机器人信息: wxid={wxid}, nickname={nickname}")
async def process_message(self, msg_type: int, data: dict):
async def process_message(self, msg_type: int, data: Dict[str, Any]):
"""
处理接收到的消息
@@ -135,131 +100,105 @@ class HookBot:
msg_type: 消息类型
data: 消息数据
"""
# 过滤 API 响应消息
# - 11032: 获取群成员信息响应
# - 11174/11230: 协议/上传等 API 回调
if msg_type in [11032, 11174, 11230]:
# 1. 过滤 API 响应消息
if msg_type in self.API_RESPONSE_TYPES:
return
# 去重:同一条消息重复回调时不再重复触发事件(避免“同一句话回复两次”)
# 2. 去重检查
try:
if await self._is_duplicate_message(msg_type, data):
logger.debug(f"[HookBot] 重复消息已丢弃: type={msg_type}, msgid={self._extract_msg_id(data) or 'N/A'}")
if await self._dedup.is_duplicate(data):
self._stats.record_duplicate()
logger.debug(f"[HookBot] 重复消息已丢弃: type={msg_type}")
return
except Exception as e:
# 去重失败不影响主流程
logger.debug(f"[HookBot] 消息去重检查失败: {e}")
# 消息计数
self.message_count += 1
# 日志采样 - 只记录部分消息以减少日志量
# 3. 记录收到消息
self._stats.record_received()
should_log = self._should_log_message(msg_type)
if should_log:
logger.debug(f"处理消息: type={msg_type}")
# 重要事件始终记录
if msg_type in [11098, 11099, 11058]: # 群成员变动、系统消息
if msg_type in self.IMPORTANT_MESSAGE_TYPES:
logger.info(f"重要事件: type={msg_type}")
# 获取事件类型
# 4. 获取事件类型
event_type = MESSAGE_TYPE_MAP.get(msg_type)
if should_log and event_type:
logger.info(f"[HookBot] 消息类型映射: {msg_type} -> {event_type}")
if not event_type:
# 记录未知消息类型的详细信息,帮助调试
content_preview = str(data.get('raw_msg', data.get('msg', '')))[:200]
logger.warning(f"未映射的消息类型: {msg_type}, wx_type: {data.get('wx_type')}, 内容预览: {content_preview}")
logger.warning(
f"未映射的消息类型: {msg_type}, "
f"wx_type: {data.get('wx_type')}, "
f"内容预览: {content_preview}"
)
return
# 格式转换
# 5. 格式转换
try:
message = normalize_message(msg_type, data)
except Exception as e:
logger.error(f"格式转换失败: {e}")
self._stats.record_error()
return
# 过滤消息
if not self._check_filter(message):
self.filtered_count += 1
# 6. 过滤检查
if not self._filter.should_process(message):
self._stats.record_filtered()
if should_log:
logger.debug(f"消息被过滤: {message.get('FromWxid')}")
return
self.processed_count += 1
# 7. 记录处理
self._stats.record_processed(event_type)
# 采样记录处理的消息
if should_log:
content = message.get('Content', '')
if len(content) > 50:
content = content[:50] + "..."
logger.info(f"处理消息: type={event_type}, from={message.get('FromWxid')}, content={content}")
logger.info(
f"处理消息: type={event_type}, "
f"from={message.get('FromWxid')}, "
f"content={content}"
)
# 触发事件
# 8. 触发事件
try:
await EventManager.emit(event_type, self.client, message)
except Exception as e:
logger.error(f"事件处理失败: {e}")
self._stats.record_error()
def _should_log_message(self, msg_type: int) -> bool:
"""判断是否应该记录此消息的日志"""
# 重要消息类型始终记录
important_types = {
11058, 11098, 11099, 11025, # 系统消息、群成员变动、登录信息
11051, 11047, 11052, 11055 # 视频、图片、表情、文件消息
}
if msg_type in important_types:
if msg_type in self.IMPORTANT_MESSAGE_TYPES:
return True
# 其他消息按采样率记录
import random
return random.random() < self.log_sampling_rate
def _check_filter(self, message: Dict[str, Any]) -> bool:
"""
检查消息是否通过过滤
Args:
message: 消息字典
Returns:
是否通过过滤
"""
from_wxid = message.get("FromWxid", "")
sender_wxid = message.get("SenderWxid", "")
msg_type = message.get("MsgType", 0)
# 系统消息type=11058不过滤因为包含重要的群聊事件信息
if msg_type == 11058:
return True
# 过滤机器人自己发送的消息,避免无限循环
if self.wxid and (from_wxid == self.wxid or sender_wxid == self.wxid):
return False
# None 模式:处理所有消息
if self.ignore_mode == "None":
return True
# Whitelist 模式:仅处理白名单
if self.ignore_mode == "Whitelist":
return from_wxid in self.whitelist or sender_wxid in self.whitelist
# Blacklist 模式:屏蔽黑名单
if self.ignore_mode == "Blacklist":
return from_wxid not in self.blacklist and sender_wxid not in self.blacklist
return True
def get_stats(self) -> dict:
def get_stats(self) -> Dict[str, Any]:
"""获取消息处理统计信息"""
return {
"total_messages": self.message_count,
"filtered_messages": self.filtered_count,
"processed_messages": self.processed_count,
"filter_rate": self.filtered_count / max(self.message_count, 1),
"process_rate": self.processed_count / max(self.message_count, 1)
}
stats = self._stats.get_stats()
stats["dedup"] = self._dedup.get_stats()
return stats
# ==================== 兼容旧接口 ====================
@property
def message_count(self) -> int:
"""兼容旧接口:总消息数"""
return self._stats.get_stats()["total_messages"]
@property
def filtered_count(self) -> int:
"""兼容旧接口:被过滤消息数"""
return self._stats.get_stats()["filtered_messages"]
@property
def processed_count(self) -> int:
"""兼容旧接口:已处理消息数"""
return self._stats.get_stats()["processed_messages"]

690
utils/image_processor.py Normal file
View File

@@ -0,0 +1,690 @@
"""
图片/视频处理模块
提供媒体文件的下载、编码和描述生成:
- 图片下载与 base64 编码
- 表情包下载与编码
- 视频下载与编码
- AI 图片/视频描述生成
使用示例:
from utils.image_processor import ImageProcessor, MediaConfig
config = MediaConfig(
api_url="https://api.openai.com/v1/chat/completions",
api_key="sk-xxx",
model="gpt-4-vision-preview",
)
processor = ImageProcessor(config)
# 下载图片
image_base64 = await processor.download_image(bot, cdnurl, aeskey)
# 生成描述
description = await processor.generate_description(image_base64, "描述这张图片")
"""
from __future__ import annotations
import asyncio
import base64
import json
import uuid
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Optional, TYPE_CHECKING
import aiohttp
from loguru import logger
# 可选代理支持
try:
from aiohttp_socks import ProxyConnector
PROXY_SUPPORT = True
except ImportError:
PROXY_SUPPORT = False
if TYPE_CHECKING:
pass # bot 类型提示
@dataclass
class MediaConfig:
    """Configuration for media download and AI description generation."""
    # --- vision API ---
    api_url: str = "https://api.openai.com/v1/chat/completions"
    api_key: str = ""
    model: str = "gpt-4-vision-preview"
    timeout: int = 120
    max_tokens: int = 1000
    retries: int = 2
    # --- outbound proxy ---
    proxy_enabled: bool = False
    proxy_type: str = "socks5"
    proxy_host: str = "127.0.0.1"
    proxy_port: int = 7890
    proxy_username: str = ""
    proxy_password: str = ""
    # --- video recognition endpoint (Gemini-style) ---
    video_api_url: str = ""
    video_model: str = ""
    video_max_size_mb: int = 20
    video_timeout: int = 360
    video_max_tokens: int = 8192
    # --- scratch directory for downloads ---
    temp_dir: Optional[Path] = None
    @classmethod
    def from_dict(cls, config: Dict[str, Any]) -> "MediaConfig":
        """Build a MediaConfig from a nested configuration dictionary."""
        api = config.get("api", {})
        proxy = config.get("proxy", {})
        img = config.get("image_description", {})
        video = config.get("video_recognition", {})
        return cls(
            api_url=api.get("url", "https://api.openai.com/v1/chat/completions"),
            api_key=api.get("api_key", ""),
            # per-feature model falls back to the shared API model
            model=img.get("model", api.get("model", "gpt-4-vision-preview")),
            timeout=api.get("timeout", 120),
            max_tokens=img.get("max_tokens", 1000),
            retries=img.get("retries", 2),
            proxy_enabled=proxy.get("enabled", False),
            proxy_type=proxy.get("type", "socks5"),
            proxy_host=proxy.get("host", "127.0.0.1"),
            proxy_port=proxy.get("port", 7890),
            proxy_username=proxy.get("username", ""),
            proxy_password=proxy.get("password", ""),
            video_api_url=video.get("api_url", ""),
            video_model=video.get("model", ""),
            video_max_size_mb=video.get("max_size_mb", 20),
            video_timeout=video.get("timeout", 360),
            video_max_tokens=video.get("max_tokens", 8192),
        )
@dataclass
class MediaResult:
    """Outcome of one media download/analysis operation."""
    success: bool = False          # True when the operation completed
    data: str = ""  # base64-encoded payload (data-URI form)
    description: str = ""          # AI-generated description, if requested
    error: Optional[str] = None    # human-readable failure reason
    media_type: str = "image"  # one of: image, emoji, video
class ImageProcessor:
    """
    Image/video processor.

    Unified media handling:
    - download + base64 encoding (image / emoji / video)
    - AI description generation
    - optional Redis-backed caching (best-effort; failures never break the flow)
    """
    def __init__(self, config: MediaConfig, temp_dir: Optional[Path] = None):
        """
        Args:
            config: media/API settings
            temp_dir: scratch directory (defaults to config.temp_dir or ./temp)
        """
        self.config = config
        self.temp_dir = temp_dir or config.temp_dir or Path("temp")
        self.temp_dir.mkdir(exist_ok=True)
    def _get_proxy_connector(self) -> Optional[Any]:
        """Build an aiohttp proxy connector, or None when proxying is off/unavailable."""
        if not self.config.proxy_enabled or not PROXY_SUPPORT:
            return None
        proxy_type = self.config.proxy_type.upper()
        if self.config.proxy_username and self.config.proxy_password:
            proxy_url = (
                f"{proxy_type}://{self.config.proxy_username}:"
                f"{self.config.proxy_password}@"
                f"{self.config.proxy_host}:{self.config.proxy_port}"
            )
        else:
            proxy_url = f"{proxy_type}://{self.config.proxy_host}:{self.config.proxy_port}"
        try:
            return ProxyConnector.from_url(proxy_url)
        except Exception as e:
            logger.warning(f"[ImageProcessor] 代理配置失败: {e}")
            return None
    async def download_image(
        self,
        bot,
        cdnurl: str,
        aeskey: str,
        use_cache: bool = True,
    ) -> str:
        """
        Download an image and return it base64-encoded.

        Args:
            bot: WechatHookClient instance (performs the CDN download)
            cdnurl: CDN URL
            aeskey: AES key
            use_cache: whether to consult/populate the Redis cache

        Returns:
            base64-encoded image with a data-URI prefix, "" on failure
        """
        try:
            # 1. Try the Redis cache first.
            if use_cache:
                from utils.redis_cache import RedisCache, get_cache
                redis_cache = get_cache()
                if redis_cache and redis_cache.enabled:
                    media_key = RedisCache.generate_media_key(cdnurl, aeskey)
                    if media_key:
                        cached_data = redis_cache.get_cached_media(media_key, "image")
                        if cached_data:
                            logger.debug(f"[ImageProcessor] 图片缓存命中: {media_key[:20]}...")
                            return cached_data
            # 2. Cache miss: download via CDN.
            logger.debug("[ImageProcessor] 开始下载图片...")
            filename = f"temp_{uuid.uuid4().hex[:8]}.jpg"
            save_path = str((self.temp_dir / filename).resolve())
            # Prefer the medium-size image (file_type=2), fall back to the original (1).
            success = await bot.cdn_download(cdnurl, aeskey, save_path, file_type=2)
            if not success:
                success = await bot.cdn_download(cdnurl, aeskey, save_path, file_type=1)
            if not success:
                logger.error("[ImageProcessor] CDN 下载失败")
                return ""
            # Wait (up to ~10s) for the downloader to finish writing the file.
            import os
            for _ in range(20):
                if os.path.exists(save_path) and os.path.getsize(save_path) > 0:
                    break
                await asyncio.sleep(0.5)
            if not os.path.exists(save_path):
                logger.error("[ImageProcessor] 图片文件未生成")
                return ""
            with open(save_path, "rb") as f:
                image_data = base64.b64encode(f.read()).decode()
            base64_result = f"data:image/jpeg;base64,{image_data}"
            # 3. Populate the Redis cache (best-effort).
            if use_cache:
                try:
                    from utils.redis_cache import RedisCache, get_cache
                    redis_cache = get_cache()
                    if redis_cache and redis_cache.enabled:
                        media_key = RedisCache.generate_media_key(cdnurl, aeskey)
                        if media_key:
                            redis_cache.cache_media(media_key, base64_result, "image", ttl=300)
                            logger.debug(f"[ImageProcessor] 图片已缓存: {media_key[:20]}...")
                except Exception as e:
                    logger.debug(f"[ImageProcessor] 缓存图片失败: {e}")
            # Remove the temp file (best-effort).
            try:
                Path(save_path).unlink()
            except Exception:
                pass
            return base64_result
        except Exception as e:
            logger.error(f"[ImageProcessor] 下载图片失败: {e}")
            return ""
    async def download_emoji(
        self,
        cdn_url: str,
        max_retries: int = 3,
        use_cache: bool = True,
    ) -> str:
        """
        Download a sticker/emoji and return it base64-encoded.

        Args:
            cdn_url: CDN URL (may contain HTML-escaped ampersands)
            max_retries: maximum download attempts
            use_cache: whether to consult/populate the Redis cache

        Returns:
            base64-encoded emoji with a data-URI prefix, "" on failure
        """
        # Undo HTML entity escaping carried over from the message XML.
        cdn_url = cdn_url.replace("&amp;", "&")
        # 1. Try the Redis cache first (best-effort).
        media_key = None
        if use_cache:
            try:
                from utils.redis_cache import RedisCache, get_cache
                redis_cache = get_cache()
                media_key = RedisCache.generate_media_key(cdnurl=cdn_url)
                if redis_cache and redis_cache.enabled and media_key:
                    cached_data = redis_cache.get_cached_media(media_key, "emoji")
                    if cached_data:
                        logger.debug(f"[ImageProcessor] 表情包缓存命中: {media_key[:20]}...")
                        return cached_data
            except Exception:
                pass
        # 2. Cache miss: download over HTTP with retries.
        logger.debug("[ImageProcessor] 开始下载表情包...")
        last_error = None
        for attempt in range(max_retries):
            # BUGFIX: build a fresh connector per attempt. A ClientSession
            # closes (and owns) its connector on exit, so reusing one
            # connector across retries fails with a closed-connector error.
            connector = self._get_proxy_connector()
            try:
                timeout = aiohttp.ClientTimeout(total=30 + attempt * 15)
                async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
                    async with session.get(cdn_url) as response:
                        if response.status == 200:
                            content = await response.read()
                            if len(content) == 0:
                                logger.warning(f"[ImageProcessor] 表情包内容为空,重试 {attempt + 1}/{max_retries}")
                                continue
                            image_data = base64.b64encode(content).decode()
                            base64_result = f"data:image/gif;base64,{image_data}"
                            logger.debug(f"[ImageProcessor] 表情包下载成功,大小: {len(content)} 字节")
                            # 3. Populate the Redis cache (best-effort).
                            if use_cache and media_key:
                                try:
                                    from utils.redis_cache import get_cache
                                    redis_cache = get_cache()
                                    if redis_cache and redis_cache.enabled:
                                        redis_cache.cache_media(media_key, base64_result, "emoji", ttl=300)
                                        logger.debug(f"[ImageProcessor] 表情包已缓存: {media_key[:20]}...")
                                except Exception:
                                    pass
                            return base64_result
                        else:
                            logger.warning(f"[ImageProcessor] 表情包下载失败,状态码: {response.status}")
            except asyncio.TimeoutError:
                last_error = "请求超时"
                logger.warning(f"[ImageProcessor] 表情包下载超时,重试 {attempt + 1}/{max_retries}")
            except aiohttp.ClientError as e:
                last_error = str(e)
                logger.warning(f"[ImageProcessor] 表情包下载网络错误: {e}")
            except Exception as e:
                last_error = str(e)
                logger.warning(f"[ImageProcessor] 表情包下载异常: {e}")
            # Linear backoff before the next attempt.
            if attempt < max_retries - 1:
                await asyncio.sleep(1 * (attempt + 1))
        logger.error(f"[ImageProcessor] 表情包下载失败,已重试 {max_retries} 次: {last_error}")
        return ""
    async def download_video(
        self,
        bot,
        cdnurl: str,
        aeskey: str,
        use_cache: bool = True,
    ) -> str:
        """
        Download a video and return it base64-encoded.

        Args:
            bot: WechatHookClient instance
            cdnurl: CDN URL
            aeskey: AES key
            use_cache: whether to consult/populate the Redis cache

        Returns:
            base64-encoded video with a data-URI prefix, "" on failure
        """
        try:
            # Try the Redis cache first (best-effort).
            media_key = None
            if use_cache:
                try:
                    from utils.redis_cache import RedisCache, get_cache
                    redis_cache = get_cache()
                    if redis_cache and redis_cache.enabled:
                        media_key = RedisCache.generate_media_key(cdnurl, aeskey)
                        if media_key:
                            cached_data = redis_cache.get_cached_media(media_key, "video")
                            if cached_data:
                                logger.debug(f"[ImageProcessor] 视频缓存命中: {media_key[:20]}...")
                                return cached_data
                except Exception:
                    pass
            # Cache miss: download via CDN (file_type=4 selects video).
            logger.info("[ImageProcessor] 开始下载视频...")
            filename = f"video_{uuid.uuid4().hex[:8]}.mp4"
            save_path = str((self.temp_dir / filename).resolve())
            success = await bot.cdn_download(cdnurl, aeskey, save_path, file_type=4)
            if not success:
                logger.error("[ImageProcessor] 视频 CDN 下载失败")
                return ""
            # Wait (up to ~15s) for the downloader to finish writing the file.
            import os
            for _ in range(30):
                if os.path.exists(save_path) and os.path.getsize(save_path) > 0:
                    break
                await asyncio.sleep(0.5)
            if not os.path.exists(save_path):
                logger.error("[ImageProcessor] 视频文件未生成")
                return ""
            file_size = os.path.getsize(save_path)
            logger.info(f"[ImageProcessor] 视频下载完成,大小: {file_size / 1024 / 1024:.2f} MB")
            # Enforce the configured size limit; oversized videos are rejected.
            max_size_mb = self.config.video_max_size_mb
            if file_size > max_size_mb * 1024 * 1024:
                logger.warning(f"[ImageProcessor] 视频文件过大: {file_size / 1024 / 1024:.2f} MB > {max_size_mb} MB")
                try:
                    Path(save_path).unlink()
                except Exception:
                    pass
                return ""
            # Read and base64-encode the file.
            with open(save_path, "rb") as f:
                video_data = base64.b64encode(f.read()).decode()
            video_base64 = f"data:video/mp4;base64,{video_data}"
            # Populate the Redis cache (best-effort).
            if use_cache and media_key:
                try:
                    from utils.redis_cache import get_cache
                    redis_cache = get_cache()
                    if redis_cache and redis_cache.enabled:
                        redis_cache.cache_media(media_key, video_base64, "video", ttl=600)
                        logger.debug(f"[ImageProcessor] 视频已缓存: {media_key[:20]}...")
                except Exception:
                    pass
            # Remove the temp file (best-effort).
            try:
                Path(save_path).unlink()
            except Exception:
                pass
            return video_base64
        except Exception as e:
            logger.error(f"[ImageProcessor] 下载视频失败: {e}")
            import traceback
            logger.error(traceback.format_exc())
            return ""
    async def generate_description(
        self,
        image_base64: str,
        prompt: str = "请用一句话简洁地描述这张图片的主要内容。",
        model: Optional[str] = None,
    ) -> str:
        """
        Generate an AI description for an image.

        Args:
            image_base64: base64 image data (data-URI form accepted)
            prompt: instruction sent alongside the image
            model: override for the configured vision model

        Returns:
            Description text, "" on failure.
        """
        description_model = model or self.config.model
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {"type": "image_url", "image_url": {"url": image_base64}}
                ]
            }
        ]
        payload = {
            "model": description_model,
            "messages": messages,
            "max_tokens": self.config.max_tokens,
            "stream": True
        }
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.config.api_key}"
        }
        max_retries = self.config.retries
        last_error = None
        for attempt in range(max_retries + 1):
            try:
                timeout = aiohttp.ClientTimeout(total=self.config.timeout)
                # One connector per attempt: a session closes its connector.
                connector = self._get_proxy_connector()
                async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
                    async with session.post(
                        self.config.api_url,
                        json=payload,
                        headers=headers
                    ) as resp:
                        if resp.status != 200:
                            error_text = await resp.text()
                            raise Exception(f"API 返回错误: {resp.status}, {error_text[:200]}")
                        # Consume the SSE stream, accumulating delta content.
                        description = ""
                        async for line in resp.content:
                            line = line.decode('utf-8').strip()
                            if not line or line == "data: [DONE]":
                                continue
                            if line.startswith("data: "):
                                try:
                                    data = json.loads(line[6:])
                                    delta = data.get("choices", [{}])[0].get("delta", {})
                                    content = delta.get("content", "")
                                    if content:
                                        description += content
                                except Exception:
                                    pass
                        logger.debug(f"[ImageProcessor] 图片描述生成成功: {description[:50]}...")
                        return description.strip()
            except asyncio.CancelledError:
                raise  # never swallow task cancellation
            except (aiohttp.ClientError, asyncio.TimeoutError) as e:
                last_error = str(e)
                if attempt < max_retries:
                    logger.warning(f"[ImageProcessor] 图片描述网络错误: {e},重试 {attempt + 1}/{max_retries}")
                    await asyncio.sleep(1 * (attempt + 1))
                    continue
            except Exception as e:
                last_error = str(e)
                if attempt < max_retries:
                    logger.warning(f"[ImageProcessor] 图片描述生成异常: {e},重试 {attempt + 1}/{max_retries}")
                    await asyncio.sleep(1 * (attempt + 1))
                    continue
        logger.error(f"[ImageProcessor] 生成图片描述失败,已重试 {max_retries + 1} 次: {last_error}")
        return ""
    async def analyze_video(
        self,
        video_base64: str,
        prompt: Optional[str] = None,
    ) -> str:
        """
        Analyze video content via a Gemini-style generateContent endpoint.

        Args:
            video_base64: base64 video data (data-URI prefix stripped if present)
            prompt: analysis instruction (detailed default used when omitted)

        Returns:
            Analysis text, "" on failure.
        """
        if not self.config.video_api_url or not self.config.video_model:
            logger.error("[ImageProcessor] 视频分析配置不完整")
            return ""
        # Strip a data:video/mp4;base64, prefix if present.
        if video_base64.startswith("data:"):
            video_base64 = video_base64.split(",", 1)[1]
        default_prompt = """请详细分析这个视频的内容,包括:
1. 视频的主要场景和环境
2. 出现的人物/物体及其动作
3. 视频中的文字、对话或声音(如果有)
4. 视频的整体主题或要表达的内容
5. 任何值得注意的细节
请用客观、详细的方式描述,不要加入主观评价。"""
        analyze_prompt = prompt or default_prompt
        full_url = f"{self.config.video_api_url}/{self.config.video_model}:generateContent"
        payload = {
            "contents": [
                {
                    "parts": [
                        {"text": analyze_prompt},
                        {
                            "inline_data": {
                                "mime_type": "video/mp4",
                                "data": video_base64
                            }
                        }
                    ]
                }
            ],
            "generationConfig": {
                "maxOutputTokens": self.config.video_max_tokens
            }
        }
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.config.api_key}"
        }
        timeout = aiohttp.ClientTimeout(total=self.config.video_timeout)
        max_retries = 2
        retry_delay = 5
        for attempt in range(max_retries + 1):
            try:
                logger.info(f"[ImageProcessor] 开始分析视频...{f' (重试 {attempt}/{max_retries})' if attempt > 0 else ''}")
                async with aiohttp.ClientSession(timeout=timeout) as session:
                    async with session.post(full_url, json=payload, headers=headers) as resp:
                        # Gateway errors are transient: retry after a fixed delay.
                        if resp.status in [502, 503, 504]:
                            logger.warning(f"[ImageProcessor] 视频 API 临时错误: {resp.status}")
                            if attempt < max_retries:
                                await asyncio.sleep(retry_delay)
                                continue
                            return ""
                        if resp.status != 200:
                            error_text = await resp.text()
                            logger.error(f"[ImageProcessor] 视频 API 错误: {resp.status}, {error_text[:300]}")
                            return ""
                        result = await resp.json()
                        # Respect prompt-side safety filtering.
                        if "promptFeedback" in result:
                            feedback = result["promptFeedback"]
                            if feedback.get("blockReason"):
                                logger.warning(f"[ImageProcessor] 视频内容被过滤: {feedback.get('blockReason')}")
                                return ""
                        # Extract the first text part from the candidates.
                        if "candidates" in result and result["candidates"]:
                            for candidate in result["candidates"]:
                                if candidate.get("finishReason") == "SAFETY":
                                    logger.warning("[ImageProcessor] 视频响应被安全过滤")
                                    return ""
                                content = candidate.get("content", {})
                                for part in content.get("parts", []):
                                    if "text" in part:
                                        text = part["text"]
                                        logger.info(f"[ImageProcessor] 视频分析完成,长度: {len(text)}")
                                        return text
                        logger.error("[ImageProcessor] 视频分析无有效响应")
                        return ""
            except asyncio.TimeoutError:
                logger.warning(f"[ImageProcessor] 视频分析超时{f', 将重试...' if attempt < max_retries else ''}")
                if attempt < max_retries:
                    await asyncio.sleep(retry_delay)
                    continue
                return ""
            except Exception as e:
                logger.error(f"[ImageProcessor] 视频分析失败: {e}")
                import traceback
                logger.error(traceback.format_exc())
                return ""
        return ""
# ==================== 便捷函数 ====================
# Module-wide singleton processor (None until configured).
_default_processor: Optional[ImageProcessor] = None
def get_image_processor(config: Optional[MediaConfig] = None) -> ImageProcessor:
    """Return the module-wide ImageProcessor, rebuilding it when a config is given.

    Raises:
        ValueError: if called before any configuration was supplied.
    """
    global _default_processor
    if config:
        _default_processor = ImageProcessor(config)
    if _default_processor is None:
        raise ValueError("ImageProcessor 未初始化,请先传入配置")
    return _default_processor
def init_image_processor(config_dict: Dict[str, Any], temp_dir: Optional[Path] = None) -> ImageProcessor:
    """Create, install, and return the default ImageProcessor from a raw config dict."""
    global _default_processor
    cfg = MediaConfig.from_dict(config_dict)
    if temp_dir:
        cfg.temp_dir = temp_dir
    _default_processor = ImageProcessor(cfg, temp_dir)
    return _default_processor
# ==================== 导出 ====================
__all__ = [
'MediaConfig',
'MediaResult',
'ImageProcessor',
'get_image_processor',
'init_image_processor',
]

392
utils/llm_client.py Normal file
View File

@@ -0,0 +1,392 @@
"""
LLM 客户端抽象层
提供统一的 LLM API 调用接口:
- 支持 OpenAI 兼容 API
- 自动重试和错误处理
- 流式/非流式响应
- 代理支持
- Token 估算
使用示例:
from utils.llm_client import LLMClient, LLMConfig
config = LLMConfig(
api_base="https://api.openai.com/v1",
api_key="sk-xxx",
model="gpt-4",
)
client = LLMClient(config)
response = await client.chat_completion(
messages=[{"role": "user", "content": "Hello"}],
tools=[...],
)
"""
from __future__ import annotations
import asyncio
import json
import time
from dataclasses import dataclass, field
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
import aiohttp
from loguru import logger
# 可选代理支持
try:
from aiohttp_socks import ProxyConnector
PROXY_SUPPORT = True
except ImportError:
PROXY_SUPPORT = False
@dataclass
class LLMConfig:
    """Settings for an OpenAI-compatible chat-completion endpoint."""
    api_base: str = "https://api.openai.com/v1"
    api_key: str = ""
    model: str = "gpt-4"
    temperature: float = 0.7
    max_tokens: int = 4096
    timeout: int = 120
    max_retries: int = 3
    retry_delay: float = 1.0
    # --- outbound proxy ---
    proxy_enabled: bool = False
    proxy_type: str = "socks5"
    proxy_host: str = "127.0.0.1"
    proxy_port: int = 7890
    # --- extra request parameters merged into every payload ---
    extra_params: Dict[str, Any] = field(default_factory=dict)
    @classmethod
    def from_dict(cls, config: Dict[str, Any]) -> "LLMConfig":
        """Create a config from nested {"api": ..., "proxy": ...} dictionaries."""
        api = config.get("api", {})
        proxy = config.get("proxy", {})
        return cls(
            api_base=api.get("base_url", "https://api.openai.com/v1"),
            api_key=api.get("api_key", ""),
            model=api.get("model", "gpt-4"),
            temperature=api.get("temperature", 0.7),
            max_tokens=api.get("max_tokens", 4096),
            timeout=api.get("timeout", 120),
            max_retries=api.get("max_retries", 3),
            retry_delay=api.get("retry_delay", 1.0),
            proxy_enabled=proxy.get("enabled", False),
            proxy_type=proxy.get("type", "socks5"),
            proxy_host=proxy.get("host", "127.0.0.1"),
            proxy_port=proxy.get("port", 7890),
        )
@dataclass
class LLMResponse:
    """Parsed result of one chat-completion call."""
    content: str = ""
    tool_calls: List[Dict[str, Any]] = field(default_factory=list)
    finish_reason: str = ""
    usage: Dict[str, int] = field(default_factory=dict)
    raw_response: Dict[str, Any] = field(default_factory=dict)
    error: Optional[str] = None
    @property
    def has_tool_calls(self) -> bool:
        """Whether the model requested any tool invocations."""
        return bool(self.tool_calls)
    @property
    def success(self) -> bool:
        """True when the request completed without a transport/API error."""
        return self.error is None
class LLMClient:
    """
    LLM client for OpenAI-compatible APIs.

    Features:
    - automatic retries with linear backoff
    - optional SOCKS/HTTP proxy
    - streaming (SSE) and non-streaming responses
    - lightweight token estimation helpers
    """
    def __init__(self, config: LLMConfig):
        self.config = config
        # Lazily created, reused across calls until close().
        self._session: Optional[aiohttp.ClientSession] = None
    async def _get_session(self) -> aiohttp.ClientSession:
        """Return the shared HTTP session, creating it (with proxy) on demand."""
        if self._session is None or self._session.closed:
            connector = None
            # Optional proxy support via aiohttp_socks.
            if self.config.proxy_enabled and PROXY_SUPPORT:
                proxy_url = (
                    f"{self.config.proxy_type}://"
                    f"{self.config.proxy_host}:{self.config.proxy_port}"
                )
                connector = ProxyConnector.from_url(proxy_url)
                logger.debug(f"[LLMClient] 使用代理: {proxy_url}")
            timeout = aiohttp.ClientTimeout(total=self.config.timeout)
            self._session = aiohttp.ClientSession(
                connector=connector,
                timeout=timeout,
            )
        return self._session
    async def close(self):
        """Close the underlying HTTP session (safe to call repeatedly)."""
        if self._session and not self._session.closed:
            await self._session.close()
        self._session = None
    def _build_headers(self) -> Dict[str, str]:
        """Build the common request headers (bearer auth + JSON)."""
        return {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.config.api_key}",
        }
    def _build_payload(
        self,
        messages: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]] = None,
        stream: bool = False,
        **kwargs,
    ) -> Dict[str, Any]:
        """Build the chat-completions request body; kwargs override everything."""
        payload = {
            "model": self.config.model,
            "messages": messages,
            "temperature": self.config.temperature,
            "max_tokens": self.config.max_tokens,
            "stream": stream,
        }
        if tools:
            payload["tools"] = tools
            payload["tool_choice"] = kwargs.get("tool_choice", "auto")
        # Config-level extras first, then per-call overrides.
        payload.update(self.config.extra_params)
        payload.update(kwargs)
        return payload
    async def chat_completion(
        self,
        messages: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]] = None,
        **kwargs,
    ) -> LLMResponse:
        """
        Non-streaming chat completion with automatic retries.

        Args:
            messages: OpenAI-format message list
            tools: optional tool schemas
            **kwargs: extra payload parameters

        Returns:
            LLMResponse (``error`` is set on failure).
        """
        session = await self._get_session()
        url = f"{self.config.api_base}/chat/completions"
        headers = self._build_headers()
        payload = self._build_payload(messages, tools, stream=False, **kwargs)
        last_error = None
        for attempt in range(self.config.max_retries):
            try:
                async with session.post(url, headers=headers, json=payload) as resp:
                    if resp.status == 200:
                        data = await resp.json()
                        return self._parse_response(data)
                    error_text = await resp.text()
                    last_error = f"HTTP {resp.status}: {error_text[:200]}"
                    logger.warning(f"[LLMClient] 请求失败 (尝试 {attempt + 1}): {last_error}")
                    # Client-side errors will not improve with retries.
                    if resp.status in [400, 401, 403]:
                        break
            except asyncio.TimeoutError:
                last_error = f"请求超时 ({self.config.timeout}s)"
                logger.warning(f"[LLMClient] {last_error} (尝试 {attempt + 1})")
            except Exception as e:
                last_error = str(e)
                logger.warning(f"[LLMClient] 请求异常 (尝试 {attempt + 1}): {e}")
            # Linear backoff between attempts.
            if attempt < self.config.max_retries - 1:
                await asyncio.sleep(self.config.retry_delay * (attempt + 1))
        return LLMResponse(error=last_error)
    async def chat_completion_stream(
        self,
        messages: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]] = None,
        **kwargs,
    ) -> AsyncGenerator[str, None]:
        """
        Streaming chat completion.

        Args:
            messages: OpenAI-format message list
            tools: optional tool schemas
            **kwargs: extra payload parameters

        Yields:
            Incremental text fragments decoded from the SSE stream.
        """
        session = await self._get_session()
        url = f"{self.config.api_base}/chat/completions"
        headers = self._build_headers()
        payload = self._build_payload(messages, tools, stream=True, **kwargs)
        try:
            async with session.post(url, headers=headers, json=payload) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    # BUGFIX: the error body was read but never logged,
                    # discarding the API's diagnostic detail.
                    logger.error(f"[LLMClient] 流式请求失败: HTTP {resp.status}: {error_text[:200]}")
                    return
                async for line in resp.content:
                    line = line.decode("utf-8").strip()
                    if not line or not line.startswith("data: "):
                        continue
                    data_str = line[6:]
                    if data_str == "[DONE]":
                        break
                    try:
                        data = json.loads(data_str)
                        delta = data.get("choices", [{}])[0].get("delta", {})
                        content = delta.get("content", "")
                        if content:
                            yield content
                    except json.JSONDecodeError:
                        continue
        except Exception as e:
            logger.error(f"[LLMClient] 流式请求异常: {e}")
    def _parse_response(self, data: Dict[str, Any]) -> LLMResponse:
        """Convert the raw API JSON into an LLMResponse, normalizing tool calls."""
        try:
            choice = data.get("choices", [{}])[0]
            message = choice.get("message", {})
            content = message.get("content", "") or ""
            tool_calls = message.get("tool_calls", [])
            finish_reason = choice.get("finish_reason", "")
            usage = data.get("usage", {})
            # Normalize tool_calls so downstream code can rely on the shape.
            parsed_tool_calls = []
            for tc in tool_calls:
                parsed_tool_calls.append({
                    "id": tc.get("id", ""),
                    "type": tc.get("type", "function"),
                    "function": {
                        "name": tc.get("function", {}).get("name", ""),
                        "arguments": tc.get("function", {}).get("arguments", "{}"),
                    }
                })
            return LLMResponse(
                content=content,
                tool_calls=parsed_tool_calls,
                finish_reason=finish_reason,
                usage=usage,
                raw_response=data,
            )
        except Exception as e:
            logger.error(f"[LLMClient] 解析响应失败: {e}")
            return LLMResponse(error=f"解析响应失败: {e}")
    # ==================== Token estimation ====================
    @staticmethod
    def estimate_tokens(text: str) -> int:
        """
        Estimate the token count of a text.

        Heuristic: ~4 ASCII chars per token, ~1.5 CJK chars per token.
        """
        if not text:
            return 0
        chinese_chars = sum(1 for c in text if '\u4e00' <= c <= '\u9fff')
        other_chars = len(text) - chinese_chars
        chinese_tokens = chinese_chars / 1.5
        other_tokens = other_chars / 4
        return int(chinese_tokens + other_tokens)
    @staticmethod
    def estimate_message_tokens(message: Dict[str, Any]) -> int:
        """Estimate tokens for one message (adds a fixed per-message overhead)."""
        content = message.get("content", "")
        if isinstance(content, str):
            return LLMClient.estimate_tokens(content) + 4  # role/formatting overhead
        if isinstance(content, list):
            total = 4
            for item in content:
                if isinstance(item, dict):
                    if item.get("type") == "text":
                        total += LLMClient.estimate_tokens(item.get("text", ""))
                    elif item.get("type") == "image_url":
                        total += 85  # fixed cost per image
            return total
        return 4
    @staticmethod
    def estimate_messages_tokens(messages: List[Dict[str, Any]]) -> int:
        """Estimate the total token count of a message list."""
        return sum(LLMClient.estimate_message_tokens(m) for m in messages)
# ==================== 便捷函数 ====================
_default_client: Optional[LLMClient] = None
def get_llm_client(config: Optional[LLMConfig] = None) -> LLMClient:
"""获取默认 LLM 客户端"""
global _default_client
if config:
_default_client = LLMClient(config)
if _default_client is None:
raise ValueError("LLM 客户端未初始化,请先传入配置")
return _default_client
# ==================== 导出 ====================
__all__ = [
'LLMConfig',
'LLMResponse',
'LLMClient',
'get_llm_client',
]

View File

@@ -181,3 +181,91 @@ def validate_tool_arguments(
return True, "", arguments
# ==================== 工具注册中心集成 ====================
def register_plugin_tools(
    plugin_name: str,
    plugin: Any,
    tools_config: Dict[str, Any],
    timeout_config: Optional[Dict[str, Any]] = None,
) -> int:
    """
    Register a plugin's LLM tools into the global tool registry.

    Args:
        plugin_name: plugin name
        plugin: plugin instance (must implement get_llm_tools / execute_llm_tool)
        tools_config: filtering config with keys mode, whitelist, blacklist
        timeout_config: per-tool timeout map {tool_name: seconds}, plus optional "default"

    Returns:
        Number of tools actually registered.
    """
    from utils.tool_registry import get_tool_registry
    # Duck-typed contract: the plugin must expose both hooks.
    if not hasattr(plugin, "get_llm_tools") or not hasattr(plugin, "execute_llm_tool"):
        return 0
    registry = get_tool_registry()
    timeout_config = timeout_config or {}
    mode = tools_config.get("mode", "all")
    whitelist = set(tools_config.get("whitelist", []))
    blacklist = set(tools_config.get("blacklist", []))
    plugin_tools = plugin.get_llm_tools() or []
    registered_count = 0
    for tool in plugin_tools:
        tool_name = tool.get("function", {}).get("name", "")
        if not tool_name:
            continue
        # Apply whitelist/blacklist filtering.
        if mode == "whitelist" and tool_name not in whitelist:
            continue
        if mode == "blacklist" and tool_name in blacklist:
            logger.debug(f"[黑名单] 跳过注册工具: {tool_name}")
            continue
        # Per-tool timeout, falling back to the shared default (60s).
        timeout = timeout_config.get(tool_name, timeout_config.get("default", 60))
        # NOTE(review): an unused async `make_executor` factory was removed here;
        # registration has always bound `plugin.execute_llm_tool` directly.
        if registry.register(
            name=tool_name,
            plugin_name=plugin_name,
            schema=tool,
            executor=plugin.execute_llm_tool,
            timeout=timeout,
        ):
            registered_count += 1
            if mode == "whitelist":
                logger.debug(f"[白名单] 注册工具: {tool_name}")
    if registered_count > 0:
        logger.info(f"插件 {plugin_name} 注册了 {registered_count} 个工具")
    return registered_count
def unregister_plugin_tools(plugin_name: str) -> int:
    """
    Remove every tool a plugin registered in the global tool registry.

    Args:
        plugin_name: plugin name

    Returns:
        Number of tools removed.
    """
    from utils.tool_registry import get_tool_registry
    registry = get_tool_registry()
    return registry.unregister_plugin(plugin_name)

145
utils/message_dedup.py Normal file
View File

@@ -0,0 +1,145 @@
"""
消息去重器模块
防止同一条消息被重复处理(某些环境下回调会重复触发)
"""
import asyncio
import time
from typing import Any, Dict, Optional
from loguru import logger
class MessageDeduplicator:
"""
消息去重器
使用基于时间的滑动窗口实现去重:
- 记录最近处理的消息 ID
- 在 TTL 时间内重复的消息会被过滤
- 自动清理过期记录,限制内存占用
"""
def __init__(
self,
ttl_seconds: float = 30.0,
max_size: int = 5000,
):
"""
初始化去重器
Args:
ttl_seconds: 消息 ID 的有效期0 表示禁用去重
max_size: 最大缓存条目数,防止内存泄漏
"""
self.ttl_seconds = max(float(ttl_seconds), 0.0)
self.max_size = max(int(max_size), 0)
self._cache: Dict[str, float] = {} # key -> timestamp
self._lock = asyncio.Lock()
@staticmethod
def extract_msg_id(data: Dict[str, Any]) -> str:
"""
从原始消息数据中提取消息 ID
Args:
data: 原始消息数据
Returns:
消息 ID 字符串,提取失败返回空字符串
"""
for key in ("msgid", "msg_id", "MsgId", "id"):
value = data.get(key)
if value:
return str(value)
return ""
async def is_duplicate(self, data: Dict[str, Any]) -> bool:
"""
检查消息是否重复
Args:
data: 原始消息数据
Returns:
True 表示是重复消息False 表示是新消息
"""
if self.ttl_seconds <= 0:
return False
msg_id = self.extract_msg_id(data)
if not msg_id:
# 没有消息 ID 时不做去重,避免误判
return False
key = f"msgid:{msg_id}"
now = time.time()
async with self._lock:
# 检查是否存在且未过期
last_seen = self._cache.get(key)
if last_seen is not None and (now - last_seen) < self.ttl_seconds:
return True
# 记录新消息
self._cache.pop(key, None) # 确保插入到末尾(保持顺序)
self._cache[key] = now
# 清理过期条目
self._cleanup_expired(now)
# 限制大小
self._limit_size()
return False
def _cleanup_expired(self, now: float):
"""清理过期条目(需在锁内调用)"""
cutoff = now - self.ttl_seconds
while self._cache:
first_key = next(iter(self._cache))
if self._cache[first_key] >= cutoff:
break
self._cache.pop(first_key, None)
def _limit_size(self):
"""限制缓存大小(需在锁内调用)"""
if self.max_size <= 0:
return
while len(self._cache) > self.max_size:
first_key = next(iter(self._cache))
self._cache.pop(first_key, None)
def clear(self):
"""清空缓存"""
self._cache.clear()
def get_stats(self) -> Dict[str, Any]:
"""获取统计信息"""
return {
"cached_count": len(self._cache),
"ttl_seconds": self.ttl_seconds,
"max_size": self.max_size,
}
@classmethod
def from_config(cls, perf_config: Dict[str, Any]) -> "MessageDeduplicator":
"""
从配置创建去重器
Args:
perf_config: Performance 配置节
Returns:
MessageDeduplicator 实例
"""
return cls(
ttl_seconds=perf_config.get("dedup_ttl_seconds", 30),
max_size=perf_config.get("dedup_max_size", 5000),
)
# ==================== 导出 ====================
__all__ = ['MessageDeduplicator']

128
utils/message_filter.py Normal file
View File

@@ -0,0 +1,128 @@
"""
消息过滤器模块
提供消息过滤功能:
- 白名单/黑名单过滤
- 机器人自身消息过滤
- 系统消息放行
"""
from typing import Any, Dict, List, Optional, Set
from loguru import logger
class MessageFilter:
    """
    Message filter.

    Three modes are supported:
    - "None": no filtering, every message passes
    - "Whitelist": only messages from whitelisted chats/senders pass
    - "Blacklist": messages from blacklisted chats/senders are dropped
    """

    # Message types that always pass (system messages).
    SYSTEM_MESSAGE_TYPES = {11058}

    def __init__(
        self,
        mode: str = "None",
        whitelist: Optional[List[str]] = None,
        blacklist: Optional[List[str]] = None,
        bot_wxid: Optional[str] = None,
    ):
        """
        Initialize the filter.

        Args:
            mode: Filter mode ("None", "Whitelist", "Blacklist").
            whitelist: wxids allowed through in whitelist mode.
            blacklist: wxids blocked in blacklist mode.
            bot_wxid: The bot's own wxid, used to drop its own messages.
        """
        self.mode = mode
        self.whitelist: Set[str] = set(whitelist or [])
        self.blacklist: Set[str] = set(blacklist or [])
        self.bot_wxid = bot_wxid

    def set_bot_wxid(self, wxid: str):
        """Set the bot's own wxid."""
        self.bot_wxid = wxid

    def add_to_whitelist(self, wxid: str):
        """Add a wxid to the whitelist."""
        self.whitelist.add(wxid)

    def remove_from_whitelist(self, wxid: str):
        """Remove a wxid from the whitelist."""
        self.whitelist.discard(wxid)

    def add_to_blacklist(self, wxid: str):
        """Add a wxid to the blacklist."""
        self.blacklist.add(wxid)

    def remove_from_blacklist(self, wxid: str):
        """Remove a wxid from the blacklist."""
        self.blacklist.discard(wxid)

    def should_process(self, message: Dict[str, Any]) -> bool:
        """
        Decide whether a message should be processed.

        Args:
            message: Normalized message dict (keys: FromWxid, SenderWxid, MsgType).

        Returns:
            True to process the message, False to drop it.
        """
        from_wxid = message.get("FromWxid", "")
        sender_wxid = message.get("SenderWxid", "")
        msg_type = message.get("MsgType", 0)
        # System messages always pass.
        if msg_type in self.SYSTEM_MESSAGE_TYPES:
            return True
        # Drop the bot's own messages.
        if self.bot_wxid and (from_wxid == self.bot_wxid or sender_wxid == self.bot_wxid):
            return False
        # Apply the configured list mode.
        return self._check_mode(from_wxid, sender_wxid)

    def _check_mode(self, from_wxid: str, sender_wxid: str) -> bool:
        """Apply the list mode; unknown modes pass everything."""
        if self.mode == "None":
            return True
        if self.mode == "Whitelist":
            return from_wxid in self.whitelist or sender_wxid in self.whitelist
        if self.mode == "Blacklist":
            return from_wxid not in self.blacklist and sender_wxid not in self.blacklist
        return True

    @classmethod
    def from_config(cls, bot_config: Dict[str, Any]) -> "MessageFilter":
        """
        Build a filter from the Bot config section.

        Args:
            bot_config: Bot section as a dict.

        Returns:
            A configured MessageFilter instance.
        """
        return cls(
            mode=bot_config.get("ignore-mode", "None"),
            whitelist=bot_config.get("whitelist", []),
            blacklist=bot_config.get("blacklist", []),
            bot_wxid=bot_config.get("wxid") or bot_config.get("bot_wxid"),
        )
# ==================== 导出 ====================
__all__ = ['MessageFilter']

View File

@@ -56,10 +56,18 @@ async def log_bot_message(to_wxid: str, content: str, msg_type: str = "text", me
except Exception:
pass
sync_content = content
if msg_type == "image":
sync_content = "[图片]"
elif msg_type == "video":
sync_content = "[视频]"
elif msg_type == "file":
sync_content = "[文件]"
await store.add_group_message(
to_wxid,
bot_nickname,
content,
sync_content,
role="assistant",
sender_wxid=bot_wxid or None,
)

305
utils/message_queue.py Normal file
View File

@@ -0,0 +1,305 @@
"""
消息队列模块
提供高性能的优先级消息队列,支持多种溢出策略:
- drop_oldest: 丢弃最旧的消息
- drop_lowest: 丢弃优先级最低的消息
- sampling: 按采样率丢弃消息
- reject: 拒绝新消息
"""
import asyncio
import heapq
import random
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple
from loguru import logger
# ==================== 消息优先级常量 ====================
class MessagePriority:
    """Message priority constants (higher value = higher priority)."""
    CRITICAL = 100  # system messages, login info
    HIGH = 80  # admin commands, group-member changes
    NORMAL = 50  # @bot messages (default)
    LOW = 20  # ordinary group messages
# ==================== 溢出策略 ====================
class OverflowStrategy(Enum):
    """Queue overflow strategy, chosen when a put() hits a full queue."""
    DROP_OLDEST = "drop_oldest"  # drop the oldest queued message
    DROP_LOWEST = "drop_lowest"  # drop the lowest-priority queued message
    SAMPLING = "sampling"  # keep new messages at a sampling rate, else drop them
    REJECT = "reject"  # refuse the new message
# ==================== 优先级消息 ====================
@dataclass(order=True)
class PriorityMessage:
    """A queued message ordered by (negated priority, arrival time)."""
    priority: int = field(compare=True)  # stored negated for min-heap ordering
    timestamp: float = field(compare=True)  # tie-breaker: earlier wins
    msg_type: int = field(compare=False)
    data: Dict[str, Any] = field(compare=False)

    def __init__(self, msg_type: int, data: Dict[str, Any], priority: int = None):
        # heapq is a min-heap, so a higher logical priority is stored as a
        # smaller (negated) value to make it pop first.
        effective = MessagePriority.NORMAL if priority is None else priority
        self.priority = -effective
        self.timestamp = time.time()
        self.msg_type = msg_type
        self.data = data
# ==================== 优先级消息队列 ====================
class PriorityMessageQueue:
    """
    Priority message queue.

    Features:
    - heap-based priority ordering (highest priority pops first)
    - several overflow strategies
    - safe for concurrent asyncio use (guarded by asyncio.Lock)
    - task counting with join() support
    """
    def __init__(
        self,
        maxsize: int = 1000,
        overflow_strategy: str = "drop_oldest",
        sampling_rate: float = 0.5,
    ):
        """
        Initialize the queue.

        Args:
            maxsize: Maximum queue size.
            overflow_strategy: One of drop_oldest, drop_lowest, sampling, reject.
            sampling_rate: Keep-rate for the sampling strategy (clamped to 0.0-1.0).

        Raises:
            ValueError: If overflow_strategy is not a valid OverflowStrategy value.
        """
        self.maxsize = maxsize
        self.overflow_strategy = OverflowStrategy(overflow_strategy)
        self.sampling_rate = max(0.0, min(1.0, sampling_rate))
        self._heap: List[PriorityMessage] = []  # min-heap of PriorityMessage
        self._lock = asyncio.Lock()
        self._not_empty = asyncio.Event()  # set while the heap has items
        self._unfinished_tasks = 0
        self._finished = asyncio.Event()  # set when no unfinished tasks remain
        self._finished.set()
        # Counters reported by get_stats()
        self._total_put = 0
        self._total_dropped = 0
        self._total_rejected = 0
    def qsize(self) -> int:
        """Return the current queue size."""
        return len(self._heap)
    def empty(self) -> bool:
        """Return True when the queue is empty."""
        return len(self._heap) == 0
    def full(self) -> bool:
        """Return True when the queue is at or over capacity."""
        return len(self._heap) >= self.maxsize
    async def put(
        self,
        msg_type: int,
        data: Dict[str, Any],
        priority: int = None,
    ) -> bool:
        """
        Add a message to the queue.

        Args:
            msg_type: Message type.
            data: Message payload.
            priority: Optional priority (defaults inside PriorityMessage).

        Returns:
            True if the message was accepted, False if it was rejected.
        """
        async with self._lock:
            self._total_put += 1
            # Make room (or give up) when the queue is full
            if self.full():
                if not self._handle_overflow():
                    self._total_rejected += 1
                    return False
            msg = PriorityMessage(msg_type, data, priority)
            heapq.heappush(self._heap, msg)
            self._unfinished_tasks += 1
            self._finished.clear()
            self._not_empty.set()
            return True
    def _handle_overflow(self) -> bool:
        """
        Handle a full queue according to the configured strategy.

        Must be called with the lock held (invoked from put()).

        Returns:
            True when space was made for the new message, False to reject it.
        """
        if self.overflow_strategy == OverflowStrategy.REJECT:
            logger.warning("队列已满,拒绝新消息")
            return False
        if self.overflow_strategy == OverflowStrategy.DROP_OLDEST:
            # Linear scan for the oldest message (smallest timestamp)
            if self._heap:
                oldest_idx = 0
                for i, msg in enumerate(self._heap):
                    if msg.timestamp < self._heap[oldest_idx].timestamp:
                        oldest_idx = i
                self._heap.pop(oldest_idx)
                heapq.heapify(self._heap)  # pop() at an arbitrary index breaks the heap invariant
                self._total_dropped += 1
                # NOTE(review): this decrement has no matching _finished.set()
                # when it reaches 0 — join() wakes only via task_done(); confirm intended.
                self._unfinished_tasks = max(0, self._unfinished_tasks - 1)
            return True
        elif self.overflow_strategy == OverflowStrategy.DROP_LOWEST:
            # Linear scan for the lowest-priority message (largest stored
            # value, since priorities are negated for the min-heap)
            if self._heap:
                lowest_idx = 0
                for i, msg in enumerate(self._heap):
                    if msg.priority > self._heap[lowest_idx].priority:
                        lowest_idx = i
                self._heap.pop(lowest_idx)
                heapq.heapify(self._heap)
                self._total_dropped += 1
                self._unfinished_tasks = max(0, self._unfinished_tasks - 1)
            return True
        elif self.overflow_strategy == OverflowStrategy.SAMPLING:
            # Keep the new message with probability sampling_rate
            if random.random() < self.sampling_rate:
                # Accept the new message by evicting the oldest one
                if self._heap:
                    oldest_idx = 0
                    for i, msg in enumerate(self._heap):
                        if msg.timestamp < self._heap[oldest_idx].timestamp:
                            oldest_idx = i
                    self._heap.pop(oldest_idx)
                    heapq.heapify(self._heap)
                    self._total_dropped += 1
                    self._unfinished_tasks = max(0, self._unfinished_tasks - 1)
                return True
            else:
                self._total_dropped += 1
                return False
        return False
    async def get(self, timeout: float = None) -> Tuple[int, Dict[str, Any]]:
        """
        Pop the highest-priority message, waiting if the queue is empty.

        Args:
            timeout: Seconds to wait; None waits forever.

        Returns:
            A (msg_type, data) tuple.

        Raises:
            asyncio.TimeoutError: When the timeout elapses with no message.
        """
        start_time = time.time()
        while True:
            async with self._lock:
                if self._heap:
                    msg = heapq.heappop(self._heap)
                    if not self._heap:
                        self._not_empty.clear()
                    return (msg.msg_type, msg.data)
            # Wait outside the lock; loop because another consumer may win the race
            if timeout is not None:
                elapsed = time.time() - start_time
                remaining = timeout - elapsed
                if remaining <= 0:
                    raise asyncio.TimeoutError("Queue get timeout")
                try:
                    await asyncio.wait_for(self._not_empty.wait(), timeout=remaining)
                except asyncio.TimeoutError:
                    raise asyncio.TimeoutError("Queue get timeout")
            else:
                await self._not_empty.wait()
    def get_nowait(self) -> Tuple[int, Dict[str, Any]]:
        """
        Pop the highest-priority message without waiting.

        Raises:
            asyncio.QueueEmpty: If the queue is empty.
        """
        if not self._heap:
            raise asyncio.QueueEmpty()
        msg = heapq.heappop(self._heap)
        if not self._heap:
            self._not_empty.clear()
        return (msg.msg_type, msg.data)
    def task_done(self):
        """Mark one previously-gotten task as done; may release join()."""
        self._unfinished_tasks = max(0, self._unfinished_tasks - 1)
        if self._unfinished_tasks == 0:
            self._finished.set()
    async def join(self):
        """Wait until every queued task has been marked done."""
        await self._finished.wait()
    def clear(self):
        """Drop all queued messages and reset the task counter."""
        self._heap.clear()
        self._not_empty.clear()
        self._unfinished_tasks = 0
        self._finished.set()
    def get_stats(self) -> Dict[str, Any]:
        """Return queue statistics for monitoring."""
        return {
            "current_size": len(self._heap),
            "max_size": self.maxsize,
            "total_put": self._total_put,
            "total_dropped": self._total_dropped,
            "total_rejected": self._total_rejected,
            "unfinished_tasks": self._unfinished_tasks,
            "overflow_strategy": self.overflow_strategy.value,
            "utilization": len(self._heap) / max(self.maxsize, 1),
        }
    @classmethod
    def from_config(cls, queue_config: Dict[str, Any]) -> "PriorityMessageQueue":
        """
        Build a queue from the Queue config section.

        Args:
            queue_config: Queue section as a dict.

        Returns:
            A configured PriorityMessageQueue instance.
        """
        return cls(
            maxsize=queue_config.get("max_size", 1000),
            overflow_strategy=queue_config.get("overflow_strategy", "drop_oldest"),
            sampling_rate=queue_config.get("sampling_rate", 0.5),
        )
# ==================== 导出 ====================
__all__ = [
'MessagePriority',
'OverflowStrategy',
'PriorityMessage',
'PriorityMessageQueue',
]

114
utils/message_stats.py Normal file
View File

@@ -0,0 +1,114 @@
"""
消息统计器模块
提供消息处理的统计功能:
- 消息计数
- 过滤率统计
- 按类型统计
"""
import time
from collections import defaultdict
from threading import Lock
from typing import Any, Dict
class MessageStats:
    """
    Message statistics collector.

    All counters are guarded by a single lock, so the class is safe to use
    from multiple threads.
    """

    def __init__(self):
        self._lock = Lock()
        self._total_count = 0
        self._filtered_count = 0
        self._processed_count = 0
        self._duplicate_count = 0
        self._error_count = 0
        # Per-event-type counters, created lazily on first record.
        self._by_type: Dict[str, int] = defaultdict(int)
        self._start_time = time.time()

    def record_received(self):
        """Count one incoming message."""
        with self._lock:
            self._total_count += 1

    def record_filtered(self):
        """Count one message dropped by filtering."""
        with self._lock:
            self._filtered_count += 1

    def record_processed(self, event_type: str = None):
        """
        Count one fully processed message.

        Args:
            event_type: Optional event type to tally per-type counts.
        """
        with self._lock:
            self._processed_count += 1
            if event_type:
                self._by_type[event_type] += 1

    def record_duplicate(self):
        """Count one duplicate message."""
        with self._lock:
            self._duplicate_count += 1

    def record_error(self):
        """Count one processing error."""
        with self._lock:
            self._error_count += 1

    def get_stats(self) -> Dict[str, Any]:
        """Return a snapshot of all counters and derived rates."""
        with self._lock:
            uptime = time.time() - self._start_time
            denominator = max(self._total_count, 1)  # guard against division by zero
            per_minute = (self._total_count / uptime) * 60 if uptime > 0 else 0
            return {
                "total_messages": self._total_count,
                "filtered_messages": self._filtered_count,
                "processed_messages": self._processed_count,
                "duplicate_messages": self._duplicate_count,
                "error_count": self._error_count,
                "filter_rate": self._filtered_count / denominator,
                "process_rate": self._processed_count / denominator,
                "duplicate_rate": self._duplicate_count / denominator,
                "messages_per_minute": per_minute,
                "uptime_seconds": uptime,
                "by_type": dict(self._by_type),
            }

    def reset(self):
        """Zero every counter and restart the uptime clock."""
        with self._lock:
            self._total_count = 0
            self._filtered_count = 0
            self._processed_count = 0
            self._duplicate_count = 0
            self._error_count = 0
            self._by_type.clear()
            self._start_time = time.time()
# Global singleton (optional convenience)
_global_stats: MessageStats = None  # lazily created by get_message_stats()
_stats_lock = Lock()
def get_message_stats() -> MessageStats:
    """Return the process-wide MessageStats instance (double-checked locking)."""
    global _global_stats
    if _global_stats is None:
        with _stats_lock:
            # Re-check under the lock: another thread may have created it
            # between the first check and lock acquisition.
            if _global_stats is None:
                _global_stats = MessageStats()
    return _global_stats
# ==================== 导出 ====================
__all__ = ['MessageStats', 'get_message_stats']

View File

@@ -1,35 +1,144 @@
"""
插件基类模块
提供插件的基础功能:
- 生命周期钩子on_load, on_enable, on_disable, on_unload, on_reload
- 定时任务管理
- 依赖声明
- 插件元数据
"""
from abc import ABC
from typing import List
from enum import Enum
from typing import Any, Dict, List, Optional, TYPE_CHECKING
from loguru import logger
from .decorators import scheduler, add_job_safe, remove_job_safe
if TYPE_CHECKING:
from utils.plugin_manager import PluginManager
class PluginState(Enum):
    """Plugin lifecycle state."""
    UNLOADED = "unloaded"  # not loaded
    LOADED = "loaded"  # loaded (not yet enabled)
    ENABLED = "enabled"  # enabled
    DISABLED = "disabled"  # disabled
    ERROR = "error"  # error state
class PluginBase(ABC):
"""插件基类"""
"""
插件基类
生命周期:
1. __init__() - 构造函数(同步)
2. on_load() - 加载时调用(异步,可访问其他插件)
3. async_init() - 异步初始化(异步,加载配置、资源等)
4. on_enable() - 启用时调用(异步,注册定时任务)
5. on_disable() - 禁用时调用(异步,清理定时任务)
6. on_unload() - 卸载时调用(异步,释放资源)
7. on_reload() - 重载前调用(异步,保存状态)
使用示例:
class MyPlugin(PluginBase):
description = "我的插件"
author = "作者"
version = "1.0.0"
dependencies = ["AIChat"] # 依赖的插件
load_priority = 60 # 加载优先级
async def on_load(self, plugin_manager):
# 可以访问其他插件
aichat = plugin_manager.plugins.get("AIChat")
async def async_init(self):
# 加载配置、初始化资源
self.config = load_config()
async def on_enable(self, bot):
await super().on_enable(bot) # 注册定时任务
# 额外的启用逻辑
async def on_disable(self):
await super().on_disable() # 清理定时任务
# 额外的禁用逻辑
async def on_unload(self):
# 释放资源、关闭连接
await self.close_connections()
async def on_reload(self) -> dict:
# 返回需要保存的状态
return {"counter": self.counter}
async def restore_state(self, state: dict):
# 重载后恢复状态
self.counter = state.get("counter", 0)
"""
# ==================== 插件元数据 ====================
# 插件元数据
description: str = "暂无描述"
author: str = "未知"
version: str = "1.0.0"
# 插件依赖(填写依赖的插件类名列表)
# 例如: dependencies = ["MessageLogger", "AIChat"]
dependencies: List[str] = []
# 加载优先级数值越大越先加载默认50
# 基础插件设置高优先级,依赖其他插件的设置低优先级
load_priority: int = 50
# ==================== 实例属性 ====================
def __init__(self):
self.enabled = False
self._scheduled_jobs = set()
self.state = PluginState.UNLOADED
self._scheduled_jobs: set = set()
self._bot = None
self._plugin_manager: Optional["PluginManager"] = None
self._saved_state: Dict[str, Any] = {}
# ==================== 生命周期钩子 ====================
async def on_load(self, plugin_manager: "PluginManager"):
"""
插件加载时调用
此时其他依赖的插件已经加载完成,可以安全访问。
Args:
plugin_manager: 插件管理器实例
"""
self._plugin_manager = plugin_manager
self.state = PluginState.LOADED
logger.debug(f"[{self.__class__.__name__}] on_load 调用")
async def async_init(self):
"""
插件异步初始化
用于加载配置、初始化资源等耗时操作。
在 on_load 之后、on_enable 之前调用。
"""
pass
async def on_enable(self, bot=None):
"""插件启用时调用"""
"""
插件启用时调用
# 定时任务
注册定时任务、启动后台服务等。
Args:
bot: WechatHookClient 实例
"""
self._bot = bot
self.enabled = True
self.state = PluginState.ENABLED
# 注册定时任务
for method_name in dir(self):
method = getattr(self, method_name)
if hasattr(method, '_is_scheduled'):
@@ -39,18 +148,85 @@ class PluginBase(ABC):
add_job_safe(scheduler, job_id, method, bot, trigger, **trigger_args)
self._scheduled_jobs.add(job_id)
if self._scheduled_jobs:
logger.success("插件 {} 已加载定时任务: {}", self.__class__.__name__, self._scheduled_jobs)
logger.success(f"插件 {self.__class__.__name__} 已加载定时任务: {self._scheduled_jobs}")
async def on_disable(self):
"""插件禁用时调用"""
"""
插件禁用时调用
清理定时任务、停止后台服务等。
"""
self.enabled = False
self.state = PluginState.DISABLED
# 移除定时任务
for job_id in self._scheduled_jobs:
remove_job_safe(scheduler, job_id)
logger.info("已卸载定时任务: {}", self._scheduled_jobs)
if self._scheduled_jobs:
logger.info(f"已卸载定时任务: {self._scheduled_jobs}")
self._scheduled_jobs.clear()
async def async_init(self):
"""插件异步初始化"""
return
async def on_unload(self):
"""
插件卸载时调用
释放资源、关闭连接、保存数据等。
在 on_disable 之后调用。
"""
self.state = PluginState.UNLOADED
self._bot = None
self._plugin_manager = None
logger.debug(f"[{self.__class__.__name__}] on_unload 调用")
async def on_reload(self) -> Dict[str, Any]:
"""
插件重载前调用
返回需要在重载后恢复的状态数据。
Returns:
需要保存的状态字典
"""
logger.debug(f"[{self.__class__.__name__}] on_reload 调用")
return {}
async def restore_state(self, state: Dict[str, Any]):
"""
重载后恢复状态
Args:
state: on_reload 返回的状态字典
"""
self._saved_state = state
logger.debug(f"[{self.__class__.__name__}] 状态已恢复: {list(state.keys())}")
# ==================== 辅助方法 ====================
def get_plugin(self, plugin_name: str) -> Optional["PluginBase"]:
"""
获取其他插件实例
Args:
plugin_name: 插件类名
Returns:
插件实例,不存在返回 None
"""
if self._plugin_manager:
return self._plugin_manager.plugins.get(plugin_name)
return None
def get_bot(self):
"""获取 bot 实例"""
return self._bot
@property
def plugin_name(self) -> str:
"""获取插件名称"""
return self.__class__.__name__
def __repr__(self) -> str:
return f"<{self.__class__.__name__} v{self.version} state={self.state.value}>"

213
utils/plugin_inject.py Normal file
View File

@@ -0,0 +1,213 @@
"""
插件依赖注入模块
提供插件间依赖注入功能:
- @inject 装饰器自动注入依赖
- 延迟注入lazy injection避免循环依赖
- 类型安全的依赖获取
使用示例:
from utils.plugin_inject import inject, require_plugin
class MyPlugin(PluginBase):
# 方式1: 使用装饰器注入
@inject("AIChat")
def get_aichat(self) -> "AIChat":
pass # 自动注入,无需实现
# 方式2: 使用 require_plugin
async def some_method(self):
aichat = require_plugin("AIChat")
await aichat.do_something()
# 方式3: 使用基类的 get_plugin
async def another_method(self):
aichat = self.get_plugin("AIChat")
if aichat:
await aichat.do_something()
"""
from functools import wraps
from typing import Any, Callable, Optional, Type, TypeVar, TYPE_CHECKING
from loguru import logger
if TYPE_CHECKING:
from utils.plugin_base import PluginBase
T = TypeVar('T')
class PluginNotAvailableError(Exception):
    """Raised when a required plugin is missing or not enabled."""
    pass
def _get_plugin_manager():
    """Fetch the PluginManager lazily to avoid a circular import at module load."""
    # PluginManager is a singleton elsewhere in the project, so calling the
    # constructor returns the shared instance — TODO confirm against utils.plugin_manager.
    from utils.plugin_manager import PluginManager
    return PluginManager()
def require_plugin(plugin_name: str) -> "PluginBase":
    """
    Fetch a required plugin, raising if it cannot be found.

    Args:
        plugin_name: Plugin class name.

    Returns:
        The plugin instance.

    Raises:
        PluginNotAvailableError: If the plugin is missing from the manager.
    """
    pm = _get_plugin_manager()
    plugin = pm.plugins.get(plugin_name)
    if plugin is None:
        raise PluginNotAvailableError(f"插件 {plugin_name} 不可用")
    return plugin
def get_plugin(plugin_name: str) -> Optional["PluginBase"]:
    """
    Fetch a plugin, returning None when it is not available.

    Args:
        plugin_name: Plugin class name.

    Returns:
        The plugin instance, or None.
    """
    pm = _get_plugin_manager()
    return pm.plugins.get(plugin_name)
def inject(plugin_name: str) -> Callable:
    """
    Plugin-injection decorator.

    Turns the decorated method into a read-only property that resolves to the
    named plugin instance at access time (the method body never runs).

    Args:
        plugin_name: Class name of the plugin to inject.

    Usage:
        class MyPlugin(PluginBase):
            @inject("AIChat")
            def aichat(self) -> "AIChat":
                pass  # no implementation needed

            async def handle(self, bot, message):
                await self.aichat.process(message)
    """
    def decorator(method: Callable) -> property:
        @wraps(method)
        def getter(self) -> Optional["PluginBase"]:
            # Prefer the plugin's own manager reference when one is set.
            manager = getattr(self, '_plugin_manager', None)
            if manager:
                return manager.plugins.get(plugin_name)
            # Otherwise fall back to the global PluginManager.
            return get_plugin(plugin_name)
        return property(getter)
    return decorator
def inject_required(plugin_name: str) -> Callable:
    """
    Required-plugin injection decorator.

    Like :func:`inject`, but accessing the property raises when the plugin
    cannot be resolved.

    Args:
        plugin_name: Class name of the plugin to inject.

    Raises:
        PluginNotAvailableError: At access time, if the plugin is missing.
    """
    def decorator(method: Callable) -> property:
        @wraps(method)
        def getter(self) -> "PluginBase":
            # Resolve via the plugin's own manager first, then globally.
            manager = getattr(self, '_plugin_manager', None)
            if manager:
                resolved = manager.plugins.get(plugin_name)
            else:
                resolved = get_plugin(plugin_name)
            if resolved is None:
                raise PluginNotAvailableError(
                    f"插件 {self.__class__.__name__} 依赖的 {plugin_name} 不可用"
                )
            return resolved
        return property(getter)
    return decorator
class PluginProxy:
    """
    Lazy handle to another plugin.

    Resolution is deferred until first access, which sidesteps circular
    dependencies during plugin construction. The result is cached.

    Usage:
        class MyPlugin(PluginBase):
            def __init__(self):
                super().__init__()
                self._aichat = PluginProxy("AIChat")

            async def handle(self):
                if self._aichat.available:
                    await self._aichat.instance.process()
    """

    def __init__(self, plugin_name: str):
        self._plugin_name = plugin_name
        self._cached_instance: Optional["PluginBase"] = None
        self._checked = False  # True once a lookup has been attempted

    @property
    def instance(self) -> Optional["PluginBase"]:
        """Resolve the plugin on first access and return the cached result."""
        if not self._checked:
            self._cached_instance = get_plugin(self._plugin_name)
            self._checked = True
        return self._cached_instance

    @property
    def available(self) -> bool:
        """True when the target plugin can be resolved."""
        return self.instance is not None

    def require(self) -> "PluginBase":
        """Return the plugin instance, raising if it is unavailable."""
        resolved = self.instance
        if resolved is None:
            raise PluginNotAvailableError(f"插件 {self._plugin_name} 不可用")
        return resolved

    def invalidate(self):
        """Drop the cached resolution; the next access looks it up again."""
        self._cached_instance = None
        self._checked = False

    def __repr__(self) -> str:
        status = "available" if self.available else "unavailable"
        return f"<PluginProxy({self._plugin_name}) {status}>"
# ==================== 导出 ====================
__all__ = [
'PluginNotAvailableError',
'require_plugin',
'get_plugin',
'inject',
'inject_required',
'PluginProxy',
]

View File

@@ -2,7 +2,6 @@ import importlib
import inspect
import os
import sys
import tomllib
import traceback
from typing import Dict, Type, List, Union
@@ -10,6 +9,8 @@ from loguru import logger
# from WechatAPI import WechatAPIClient # 注释掉WechatHookBot 不需要这个导入
from utils.singleton import Singleton
from utils.config_manager import get_bot_config
from utils.llm_tooling import register_plugin_tools, unregister_plugin_tools
from .event_manager import EventManager
from .plugin_base import PluginBase
@@ -22,10 +23,9 @@ class PluginManager(metaclass=Singleton):
self.bot = None
with open("main_config.toml", "rb") as f:
main_config = tomllib.load(f)
self.excluded_plugins = main_config.get("Bot", {}).get("disabled-plugins", [])
# 使用统一配置管理器
bot_config = get_bot_config()
self.excluded_plugins = bot_config.get("disabled-plugins", [])
def set_bot(self, bot):
"""设置 bot 客户端WechatHookClient"""
@@ -74,13 +74,34 @@ class PluginManager(metaclass=Singleton):
if is_disabled:
return False
# 创建插件实例
plugin = plugin_class()
EventManager.bind_instance(plugin)
await plugin.on_enable(self.bot)
# 生命周期: on_load可访问其他插件
await plugin.on_load(self)
# 生命周期: async_init加载配置、资源
await plugin.async_init()
# 绑定事件处理器
EventManager.bind_instance(plugin)
# 生命周期: on_enable注册定时任务
await plugin.on_enable(self.bot)
# 注册到插件管理器
self.plugins[plugin_name] = plugin
self.plugin_classes[plugin_name] = plugin_class
self.plugin_info[plugin_name]["enabled"] = True
# 注册插件的 LLM 工具到全局注册中心
try:
tools_config = self._get_tools_config()
timeout_config = self._get_timeout_config()
register_plugin_tools(plugin_name, plugin, tools_config, timeout_config)
except Exception as e:
logger.warning(f"注册插件 {plugin_name} 的工具时出错: {e}")
logger.success(f"加载插件 {plugin_name} 成功")
return True
except:
@@ -232,8 +253,22 @@ class PluginManager(metaclass=Singleton):
try:
plugin = self.plugins[plugin_name]
# 生命周期: on_disable清理定时任务
await plugin.on_disable()
# 解绑事件处理器
EventManager.unbind_instance(plugin)
# 从全局注册中心注销插件的工具
try:
unregister_plugin_tools(plugin_name)
except Exception as e:
logger.warning(f"注销插件 {plugin_name} 的工具时出错: {e}")
# 生命周期: on_unload释放资源
await plugin.on_unload()
del self.plugins[plugin_name]
del self.plugin_classes[plugin_name]
if plugin_name in self.plugin_info.keys():
@@ -256,7 +291,7 @@ class PluginManager(metaclass=Singleton):
return unloaded_plugins, failed_unloads
async def reload_plugin(self, plugin_name: str) -> bool:
"""重载单个插件"""
"""重载单个插件(支持状态保存和恢复)"""
if plugin_name not in self.plugin_classes:
return False
@@ -270,7 +305,15 @@ class PluginManager(metaclass=Singleton):
plugin_class = self.plugin_classes[plugin_name]
module_name = plugin_class.__module__
# 先卸载插件
# 生命周期: on_reload保存状态
saved_state = {}
if plugin_name in self.plugins:
try:
saved_state = await self.plugins[plugin_name].on_reload()
except Exception as e:
logger.warning(f"保存插件 {plugin_name} 状态失败: {e}")
# 卸载插件
if not await self.unload_plugin(plugin_name):
return False
@@ -284,8 +327,15 @@ class PluginManager(metaclass=Singleton):
issubclass(obj, PluginBase) and
obj != PluginBase and
obj.__name__ == plugin_name):
# 使用新的插件类而不是旧的
return await self.load_plugin(obj)
# 加载新插件
if await self.load_plugin(obj):
# 恢复状态
if saved_state and plugin_name in self.plugins:
try:
await self.plugins[plugin_name].restore_state(saved_state)
except Exception as e:
logger.warning(f"恢复插件 {plugin_name} 状态失败: {e}")
return True
return False
except Exception as e:
@@ -349,13 +399,42 @@ class PluginManager(metaclass=Singleton):
def get_plugin_info(self, plugin_name: str = None) -> Union[dict, List[dict]]:
"""获取插件信息
Args:
plugin_name: 插件名称如果为None则返回所有插件信息
Returns:
如果指定插件名,返回单个插件信息字典;否则返回所有插件信息列表
"""
if plugin_name:
return self.plugin_info.get(plugin_name)
return list(self.plugin_info.values())
def _get_tools_config(self) -> dict:
"""获取工具配置(用于工具注册)"""
try:
# 尝试从 AIChat 插件配置读取
import tomllib
from pathlib import Path
aichat_config_path = Path("plugins/AIChat/config.toml")
if aichat_config_path.exists():
with open(aichat_config_path, "rb") as f:
aichat_config = tomllib.load(f)
return aichat_config.get("tools", {})
except Exception:
pass
return {"mode": "all", "whitelist": [], "blacklist": []}
def _get_timeout_config(self) -> dict:
"""获取工具超时配置"""
try:
import tomllib
from pathlib import Path
aichat_config_path = Path("plugins/AIChat/config.toml")
if aichat_config_path.exists():
with open(aichat_config_path, "rb") as f:
aichat_config = tomllib.load(f)
return aichat_config.get("tools", {}).get("timeout", {})
except Exception:
pass
return {"default": 60}

488
utils/tool_executor.py Normal file
View File

@@ -0,0 +1,488 @@
"""
工具执行器模块
提供工具调用的高级执行逻辑:
- 批量工具执行(支持并行)
- 工具调用链处理
- 执行日志和审计
- 结果聚合
使用示例:
from utils.tool_executor import ToolExecutor, ToolCallRequest
executor = ToolExecutor()
# 单个工具执行
result = await executor.execute_single(
tool_call={"id": "call_1", "function": {"name": "get_weather", "arguments": "{}"}},
bot=bot,
from_wxid=wxid,
)
# 批量工具执行
results = await executor.execute_batch(
tool_calls=[...],
bot=bot,
from_wxid=wxid,
parallel=True,
)
"""
from __future__ import annotations
import asyncio
import json
import time
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional
from loguru import logger
@dataclass
class ToolCallRequest:
    """A parsed tool-call request (OpenAI function-call format)."""
    id: str  # tool_call id from the model
    name: str  # tool/function name
    arguments: Dict[str, Any]  # arguments parsed from JSON
    raw_arguments: str = ""  # original JSON string, kept for logging/debugging
@dataclass
class ToolCallResult:
    """Outcome of a single tool call, including pipeline control flags."""
    id: str  # tool_call id this result answers
    name: str  # tool name
    success: bool = True
    message: str = ""  # human-readable result text
    raw_result: Dict[str, Any] = field(default_factory=dict)
    # Control flags (consumed downstream — semantics defined by the chat pipeline)
    need_ai_reply: bool = False
    already_sent: bool = False
    send_result_text: bool = False
    no_reply: bool = False
    save_to_memory: bool = False
    # Execution info
    execution_time_ms: float = 0.0
    error: Optional[str] = None
    def to_message(self) -> Dict[str, Any]:
        """Convert to an OpenAI-compatible tool message dict."""
        # On failure, surface the error (falling back to message) in the content.
        content = self.message if self.success else f"错误: {self.error or self.message}"
        return {
            "role": "tool",
            "tool_call_id": self.id,
            "content": content
        }
@dataclass
class ExecutionStats:
    """Aggregate tool-execution counters maintained by ToolExecutor."""
    total_calls: int = 0
    successful_calls: int = 0
    failed_calls: int = 0
    timeout_calls: int = 0  # subset of failed_calls that hit the timeout
    total_time_ms: float = 0.0
    avg_time_ms: float = 0.0  # total_time_ms / total_calls
avg_time_ms: float = 0.0
class ToolExecutor:
"""
工具执行器
提供统一的工具执行接口:
- 参数解析和校验
- 超时保护
- 错误处理
- 执行统计
"""
def __init__(
self,
default_timeout: float = 60.0,
max_parallel: int = 5,
validate_args: bool = True,
):
self.default_timeout = default_timeout
self.max_parallel = max_parallel
self.validate_args = validate_args
self._stats = ExecutionStats()
def parse_tool_call(self, tool_call: Dict[str, Any]) -> ToolCallRequest:
"""
解析 OpenAI 格式的工具调用
Args:
tool_call: OpenAI 返回的 tool_call 对象
Returns:
ToolCallRequest 对象
"""
call_id = tool_call.get("id", "")
function = tool_call.get("function", {})
name = function.get("name", "")
raw_args = function.get("arguments", "{}")
# 解析 arguments JSON
try:
arguments = json.loads(raw_args) if raw_args else {}
except json.JSONDecodeError as e:
logger.warning(f"[ToolExecutor] 解析参数失败: {e}, raw={raw_args[:100]}")
arguments = {}
return ToolCallRequest(
id=call_id,
name=name,
arguments=arguments,
raw_arguments=raw_args,
)
async def execute_single(
self,
tool_call: Dict[str, Any],
bot,
from_wxid: str,
timeout_override: Optional[float] = None,
) -> ToolCallResult:
"""
执行单个工具调用
Args:
tool_call: OpenAI 格式的 tool_call
bot: WechatHookClient 实例
from_wxid: 消息来源 wxid
timeout_override: 覆盖默认超时
Returns:
ToolCallResult 对象
"""
from utils.tool_registry import get_tool_registry
from utils.llm_tooling import validate_tool_arguments, ToolResult
start_time = time.time()
request = self.parse_tool_call(tool_call)
registry = get_tool_registry()
result = ToolCallResult(
id=request.id,
name=request.name,
)
# 获取工具定义
tool_def = registry.get(request.name)
if not tool_def:
result.success = False
result.error = f"工具 {request.name} 不存在"
result.message = result.error
self._update_stats(False, time.time() - start_time)
return result
# 参数校验
if self.validate_args:
schema = tool_def.schema.get("function", {}).get("parameters", {})
ok, error_msg, validated_args = validate_tool_arguments(
request.name, request.arguments, schema
)
if not ok:
result.success = False
result.error = error_msg
result.message = error_msg
self._update_stats(False, time.time() - start_time)
return result
request.arguments = validated_args
# 执行工具
timeout = timeout_override or tool_def.timeout or self.default_timeout
try:
logger.debug(f"[ToolExecutor] 执行工具: {request.name}")
raw_result = await asyncio.wait_for(
tool_def.executor(request.name, request.arguments, bot, from_wxid),
timeout=timeout
)
# 解析结果
tool_result = ToolResult.from_raw(raw_result)
if tool_result:
result.success = tool_result.success
result.message = tool_result.message
result.need_ai_reply = tool_result.need_ai_reply
result.already_sent = tool_result.already_sent
result.send_result_text = tool_result.send_result_text
result.no_reply = tool_result.no_reply
result.save_to_memory = tool_result.save_to_memory
else:
result.message = str(raw_result) if raw_result else "执行完成"
result.raw_result = raw_result if isinstance(raw_result, dict) else {"result": raw_result}
execution_time = time.time() - start_time
result.execution_time_ms = execution_time * 1000
self._update_stats(result.success, execution_time)
logger.debug(
f"[ToolExecutor] 工具 {request.name} 执行完成 "
f"({result.execution_time_ms:.1f}ms)"
)
except asyncio.TimeoutError:
result.success = False
result.error = f"执行超时 ({timeout}s)"
result.message = result.error
self._update_stats(False, time.time() - start_time, timeout=True)
logger.warning(f"[ToolExecutor] 工具 {request.name} 执行超时")
except asyncio.CancelledError:
raise
except Exception as e:
result.success = False
result.error = str(e)
result.message = f"执行失败: {e}"
self._update_stats(False, time.time() - start_time)
logger.error(f"[ToolExecutor] 工具 {request.name} 执行异常: {e}")
return result
async def execute_batch(
self,
tool_calls: List[Dict[str, Any]],
bot,
from_wxid: str,
parallel: bool = True,
stop_on_error: bool = False,
) -> List[ToolCallResult]:
"""
批量执行工具调用
Args:
tool_calls: 工具调用列表
bot: WechatHookClient 实例
from_wxid: 消息来源 wxid
parallel: 是否并行执行
stop_on_error: 遇到错误是否停止
Returns:
ToolCallResult 列表
"""
if not tool_calls:
return []
if parallel and len(tool_calls) > 1:
return await self._execute_parallel(tool_calls, bot, from_wxid, stop_on_error)
else:
return await self._execute_sequential(tool_calls, bot, from_wxid, stop_on_error)
async def _execute_sequential(
self,
tool_calls: List[Dict[str, Any]],
bot,
from_wxid: str,
stop_on_error: bool,
) -> List[ToolCallResult]:
"""顺序执行"""
results = []
for tool_call in tool_calls:
result = await self.execute_single(tool_call, bot, from_wxid)
results.append(result)
if stop_on_error and not result.success:
logger.warning(f"[ToolExecutor] 工具 {result.name} 失败,停止批量执行")
break
return results
async def _execute_parallel(
    self,
    tool_calls: List[Dict[str, Any]],
    bot,
    from_wxid: str,
    stop_on_error: bool,
) -> "List[ToolCallResult]":
    """并行执行(带并发限制)

    Runs all tool calls concurrently, capped at ``self.max_parallel``
    in-flight executions via a semaphore.

    Args:
        tool_calls: Tool-call dicts to execute.
        bot: WechatHookClient instance.
        from_wxid: wxid the originating message came from.
        stop_on_error: When True, stop collecting after the first failed
            call and cancel still-pending tasks; results are then in
            completion order rather than submission order.

    Returns:
        List of ToolCallResult (submission order unless stop_on_error).
    """
    semaphore = asyncio.Semaphore(self.max_parallel)

    async def execute_with_limit(tool_call):
        async with semaphore:
            return await self.execute_single(tool_call, bot, from_wxid)

    # BUGFIX: wrap the coroutines in real Tasks. The previous code passed
    # bare coroutine objects to as_completed and then tried to cancel
    # entries of `tasks` guarded by `isinstance(task, asyncio.Task)` --
    # which was never true, so nothing was ever cancelled and abandoned
    # coroutines were never awaited when stopping early.
    tasks = [asyncio.ensure_future(execute_with_limit(tc)) for tc in tool_calls]
    if stop_on_error:
        results = []
        try:
            for next_done in asyncio.as_completed(tasks):
                try:
                    result = await next_done
                    results.append(result)
                    if not result.success:
                        break  # first failure stops the batch
                except Exception as e:
                    logger.error(f"[ToolExecutor] 并行执行异常: {e}")
                    break
        finally:
            # Cancel whatever is still pending so no task leaks.
            for task in tasks:
                if not task.done():
                    task.cancel()
        return results
    else:
        # Run everything to completion; gather keeps submission order.
        return await asyncio.gather(*tasks, return_exceptions=False)
def _update_stats(self, success: bool, execution_time: float, timeout: bool = False):
    """Fold one finished call into the running execution statistics.

    Args:
        success: Whether the call succeeded.
        execution_time: Wall-clock duration of the call, in seconds.
        timeout: Whether the failure was caused by a timeout (only
            counted when ``success`` is False).
    """
    stats = self._stats
    stats.total_calls += 1
    if success:
        stats.successful_calls += 1
    else:
        stats.failed_calls += 1
        if timeout:
            stats.timeout_calls += 1
    stats.total_time_ms += execution_time * 1000
    stats.avg_time_ms = stats.total_time_ms / stats.total_calls
def get_stats(self) -> Dict[str, Any]:
    """Return a snapshot of the execution statistics as a plain dict.

    ``success_rate`` is 0 when no calls have been recorded yet, avoiding
    a division by zero.
    """
    stats = self._stats
    success_rate = (
        stats.successful_calls / stats.total_calls
        if stats.total_calls > 0 else 0
    )
    return {
        "total_calls": stats.total_calls,
        "successful_calls": stats.successful_calls,
        "failed_calls": stats.failed_calls,
        "timeout_calls": stats.timeout_calls,
        "total_time_ms": stats.total_time_ms,
        "avg_time_ms": stats.avg_time_ms,
        "success_rate": success_rate,
    }
def reset_stats(self):
    """Discard the accumulated statistics and start over from zeroes."""
    self._stats = ExecutionStats()
class ToolCallChain:
    """
    A multi-round tool-call chain.

    Records the history of tool-call results across rounds so that
    scenarios requiring several consecutive tool invocations can be
    bounded (``max_rounds``) and summarized.
    """

    def __init__(self, max_rounds: int = 10):
        # Hard cap on how many rounds of tool calls may be issued.
        self.max_rounds = max_rounds
        self.history: List[ToolCallResult] = []
        self.current_round = 0

    def add_result(self, result: ToolCallResult):
        """Append a single call result to the history."""
        self.history.append(result)

    def add_results(self, results: List[ToolCallResult]):
        """Append several call results to the history."""
        self.history.extend(results)

    def increment_round(self):
        """Advance to the next round."""
        self.current_round += 1

    def can_continue(self) -> bool:
        """Whether another round of tool calls is still allowed."""
        return self.current_round < self.max_rounds

    def get_tool_messages(self) -> List[Dict[str, Any]]:
        """Render every recorded result as a message for the LLM."""
        return [entry.to_message() for entry in self.history]

    def get_last_results(self, n: int = 1) -> List[ToolCallResult]:
        """Return the most recent *n* results (empty list when no history)."""
        return self.history[-n:] if self.history else []

    def has_special_flags(self) -> Dict[str, bool]:
        """Report which special flags were raised by any recorded result."""
        flag_names = (
            "need_ai_reply",
            "already_sent",
            "no_reply",
            "save_to_memory",
            "send_result_text",
        )
        return {
            flag: any(getattr(entry, flag) for entry in self.history)
            for flag in flag_names
        }

    def get_summary(self) -> str:
        """One-line human-readable summary of the whole chain."""
        if not self.history:
            return "无工具调用"
        ok_count = sum(1 for entry in self.history if entry.success)
        fail_count = len(self.history) - ok_count
        elapsed = sum(entry.execution_time_ms for entry in self.history)
        names = [entry.name for entry in self.history]
        return (
            f"调用链: {len(self.history)} 个工具, "
            f"成功 {ok_count}, 失败 {fail_count}, "
            f"总耗时 {elapsed:.1f}ms, "
            f"工具: {', '.join(names)}"
        )
# ==================== Convenience functions ====================
# Module-level default executor, lazily created by get_tool_executor().
_default_executor: Optional[ToolExecutor] = None
def get_tool_executor(
    default_timeout: float = 60.0,
    max_parallel: int = 5,
) -> ToolExecutor:
    """Return the module-level default ToolExecutor, creating it on first use.

    Args:
        default_timeout: Per-tool timeout (seconds) applied when the executor
            is first created; ignored on subsequent calls.
        max_parallel: Concurrency cap applied when the executor is first
            created; ignored on subsequent calls.
    """
    global _default_executor
    if _default_executor is None:
        # Lazily build the shared executor with the requested settings.
        _default_executor = ToolExecutor(
            default_timeout=default_timeout,
            max_parallel=max_parallel,
        )
    return _default_executor
async def execute_tool_calls(
    tool_calls: List[Dict[str, Any]],
    bot,
    from_wxid: str,
    parallel: bool = True,
) -> List[ToolCallResult]:
    """Convenience wrapper: run *tool_calls* through the default executor."""
    return await get_tool_executor().execute_batch(
        tool_calls, bot, from_wxid, parallel=parallel
    )
# ==================== Exports ====================
# Public API of this module.
__all__ = [
    'ToolCallRequest',
    'ToolCallResult',
    'ExecutionStats',
    'ToolExecutor',
    'ToolCallChain',
    'get_tool_executor',
    'execute_tool_calls',
]

286
utils/tool_registry.py Normal file
View File

@@ -0,0 +1,286 @@
"""
工具注册中心
集中管理所有 LLM 工具的注册、查找和执行
- O(1) 工具查找(替代 O(n) 插件遍历)
- 统一的超时保护
- 工具元信息管理
使用示例:
from utils.tool_registry import get_tool_registry
registry = get_tool_registry()
# 注册工具
registry.register(
name="generate_image",
plugin_name="AIChat",
schema={...},
executor=some_async_func,
timeout=120
)
# 执行工具
result = await registry.execute("generate_image", arguments, bot, from_wxid)
"""
import asyncio
from dataclasses import dataclass, field
from threading import Lock
from typing import Any, Callable, Dict, List, Optional, Awaitable
from loguru import logger
@dataclass
class ToolDefinition:
    """Definition of a single registered LLM tool."""
    name: str                  # unique tool identifier
    plugin_name: str           # owning plugin
    schema: Dict[str, Any]     # OpenAI-compatible tool schema
    executor: Callable[..., Awaitable[Dict[str, Any]]]
    timeout: float = 60.0      # execution timeout in seconds
    priority: int = 50         # on a name clash, the higher priority wins
    description: str = ""

    def __post_init__(self):
        # Backfill the description from the schema when not given explicitly.
        if not self.description and self.schema:
            self.description = self.schema.get("function", {}).get("description", "")
class ToolRegistry:
    """
    Tool registry (thread-safe singleton).

    Responsibilities:
    - Tool registration and unregistration
    - O(1) tool lookup by name
    - Execution with unified timeout protection
    - Schema export for LLM tool-calling
    """

    _instance: Optional["ToolRegistry"] = None
    _lock = Lock()

    def __new__(cls):
        # Double-checked locking: only one instance is ever created even
        # under concurrent first calls.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    instance = super().__new__(cls)
                    instance._initialized = False
                    cls._instance = instance
        return cls._instance

    def __init__(self):
        # __init__ runs on every ToolRegistry() call; bail out once the
        # singleton has been initialized.
        if self._initialized:
            return
        self._tools: Dict[str, ToolDefinition] = {}
        self._tools_by_plugin: Dict[str, List[str]] = {}  # plugin_name -> [tool_names]
        self._registry_lock = Lock()
        self._initialized = True
        logger.debug("ToolRegistry 初始化完成")

    def register(
        self,
        name: str,
        plugin_name: str,
        schema: Dict[str, Any],
        executor: Callable[..., Awaitable[Dict[str, Any]]],
        timeout: float = 60.0,
        priority: int = 50,
    ) -> bool:
        """
        Register a tool.

        Args:
            name: Tool name (unique identifier).
            plugin_name: Name of the owning plugin.
            schema: OpenAI-compatible tool schema.
            executor: Async executor with signature
                ``async (tool_name, arguments, bot, from_wxid) -> dict``.
            timeout: Execution timeout in seconds.
            priority: Priority; on a name clash the higher priority wins.

        Returns:
            True if registered; False when an equal-or-higher priority
            tool with the same name already exists.
        """
        with self._registry_lock:
            # A same-named tool may already exist; override only when the
            # new registration has a strictly higher priority.
            existing = self._tools.get(name)
            if existing:
                if existing.priority >= priority:
                    logger.debug(f"工具 {name} 已存在且优先级更高,跳过注册")
                    return False
                # BUGFIX: the override message previously lacked its closing parenthesis.
                logger.info(f"工具 {name} 被 {plugin_name} 覆盖(优先级 {priority} > {existing.priority})")
                # Detach the tool from its previous plugin's tool list.
                old_plugin_tools = self._tools_by_plugin.get(existing.plugin_name, [])
                if name in old_plugin_tools:
                    old_plugin_tools.remove(name)
            # Store the new definition.
            tool_def = ToolDefinition(
                name=name,
                plugin_name=plugin_name,
                schema=schema,
                executor=executor,
                timeout=timeout,
                priority=priority,
            )
            self._tools[name] = tool_def
            # Keep the plugin -> tools mapping in sync.
            if plugin_name not in self._tools_by_plugin:
                self._tools_by_plugin[plugin_name] = []
            if name not in self._tools_by_plugin[plugin_name]:
                self._tools_by_plugin[plugin_name].append(name)
            logger.debug(f"注册工具: {name} (插件: {plugin_name}, 超时: {timeout}s)")
            return True

    def unregister(self, name: str) -> bool:
        """Unregister a single tool. Returns True if it existed."""
        with self._registry_lock:
            tool_def = self._tools.pop(name, None)
            if tool_def:
                plugin_tools = self._tools_by_plugin.get(tool_def.plugin_name, [])
                if name in plugin_tools:
                    plugin_tools.remove(name)
                logger.debug(f"注销工具: {name}")
                return True
            return False

    def unregister_plugin(self, plugin_name: str) -> int:
        """
        Unregister every tool owned by a plugin.

        Args:
            plugin_name: Plugin name.

        Returns:
            Number of tools removed.
        """
        with self._registry_lock:
            tool_names = self._tools_by_plugin.pop(plugin_name, [])
            count = 0
            for name in tool_names:
                if self._tools.pop(name, None):
                    count += 1
            if count > 0:
                logger.info(f"注销插件 {plugin_name}{count} 个工具")
            return count

    def get(self, name: str) -> Optional[ToolDefinition]:
        """Look up a tool definition by name (O(1))."""
        return self._tools.get(name)

    def get_all_schemas(self) -> List[Dict[str, Any]]:
        """Return every registered tool's schema (for the LLM)."""
        return [tool.schema for tool in self._tools.values()]

    def get_plugin_schemas(self, plugin_name: str) -> List[Dict[str, Any]]:
        """Return the schemas of the given plugin's tools."""
        tool_names = self._tools_by_plugin.get(plugin_name, [])
        return [self._tools[name].schema for name in tool_names if name in self._tools]

    def list_tools(self) -> List[str]:
        """List all registered tool names."""
        return list(self._tools.keys())

    def list_plugin_tools(self, plugin_name: str) -> List[str]:
        """List the tool names owned by a plugin (returns a copy)."""
        return self._tools_by_plugin.get(plugin_name, []).copy()

    async def execute(
        self,
        name: str,
        arguments: Dict[str, Any],
        bot,
        from_wxid: str,
        timeout_override: Optional[float] = None,  # FIX: was annotated `float = None`
    ) -> Dict[str, Any]:
        """
        Execute a tool with timeout protection and unified error handling.

        Args:
            name: Tool name.
            arguments: Tool arguments.
            bot: WechatHookClient instance.
            from_wxid: wxid of the message source.
            timeout_override: Overrides the tool's default timeout when set.

        Returns:
            The tool's result dict, or an error dict on failure.
        """
        # Imported lazily to avoid a module-level import cycle.
        # (Removed the previously imported-but-unused ToolExecutionError.)
        from utils.errors import (
            ToolNotFoundError, ToolTimeoutError,
            handle_error
        )
        tool_def = self._tools.get(name)
        if not tool_def:
            err = ToolNotFoundError(f"工具 {name} 不存在")
            return err.to_dict()
        timeout = timeout_override if timeout_override is not None else tool_def.timeout
        try:
            result = await asyncio.wait_for(
                tool_def.executor(name, arguments, bot, from_wxid),
                timeout=timeout
            )
            return result
        except asyncio.TimeoutError:
            err = ToolTimeoutError(
                message=f"工具 {name} 执行超时 ({timeout}s)",
                user_message=f"工具执行超时 ({timeout}s)",
                context={"tool_name": name, "timeout": timeout}
            )
            logger.warning(err.message)
            result = err.to_dict()
            result["timeout"] = True
            return result
        except Exception as e:
            # Any other failure is routed through the shared error handler.
            error_result = handle_error(
                e,
                context=f"执行工具 {name}",
                log=True,
            )
            return error_result.to_dict()

    def get_stats(self) -> Dict[str, Any]:
        """Return registry statistics: totals plus a per-plugin breakdown."""
        return {
            "total_tools": len(self._tools),
            "plugins": len(self._tools_by_plugin),
            "tools_by_plugin": {
                plugin: len(tools)
                for plugin, tools in self._tools_by_plugin.items()
            }
        }

    def clear(self):
        """Remove every registration (for tests or a full reset)."""
        with self._registry_lock:
            self._tools.clear()
            self._tools_by_plugin.clear()
            logger.info("ToolRegistry 已清空")
# ==================== 便捷函数 ====================
def get_tool_registry() -> ToolRegistry:
    """Convenience accessor for the process-wide ToolRegistry singleton."""
    # ToolRegistry.__new__ always hands back the same instance.
    return ToolRegistry()
# ==================== Export list ====================
# Public API of this module.
__all__ = [
    'ToolDefinition',
    'ToolRegistry',
    'get_tool_registry',
]