""" 机器人核心工具模块 包含: - 优先级消息队列 - 自适应熔断器 - 请求重试机制 - 配置热更新 - 性能监控 """ import asyncio import time import heapq import os import tomllib import functools from pathlib import Path from typing import Dict, List, Optional, Callable, Any, Tuple from dataclasses import dataclass, field from loguru import logger import aiohttp # ==================== 优先级消息队列 ==================== # 消息优先级定义 class MessagePriority: """消息优先级常量""" CRITICAL = 100 # 系统消息、登录信息 HIGH = 80 # 管理员命令 NORMAL = 50 # @机器人的消息 LOW = 20 # 普通群消息 # 高优先级消息类型 PRIORITY_MESSAGE_TYPES = { 11025: MessagePriority.CRITICAL, # 登录信息 11058: MessagePriority.CRITICAL, # 系统消息 11098: MessagePriority.HIGH, # 群成员加入 11099: MessagePriority.HIGH, # 群成员退出 11100: MessagePriority.HIGH, # 群信息变更 11056: MessagePriority.HIGH, # 好友请求 } @dataclass(order=True) class PriorityMessage: """优先级消息包装""" priority: int timestamp: float = field(compare=False) msg_type: int = field(compare=False) data: dict = field(compare=False) def __init__(self, msg_type: int, data: dict, priority: int = None): # 优先级越高,数值越小(因为heapq是最小堆) self.priority = -(priority or PRIORITY_MESSAGE_TYPES.get(msg_type, MessagePriority.NORMAL)) self.timestamp = time.time() self.msg_type = msg_type self.data = data class PriorityMessageQueue: """优先级消息队列""" def __init__(self, maxsize: int = 1000): self.maxsize = maxsize self._heap: List[PriorityMessage] = [] self._lock = asyncio.Lock() self._not_empty = asyncio.Event() self._unfinished_tasks = 0 self._finished = asyncio.Event() self._finished.set() def qsize(self) -> int: """返回队列大小""" return len(self._heap) def empty(self) -> bool: """队列是否为空""" return len(self._heap) == 0 def full(self) -> bool: """队列是否已满""" return len(self._heap) >= self.maxsize async def put(self, msg_type: int, data: dict, priority: int = None): """添加消息到队列""" async with self._lock: msg = PriorityMessage(msg_type, data, priority) heapq.heappush(self._heap, msg) self._unfinished_tasks += 1 self._finished.clear() self._not_empty.set() async def get(self) -> Tuple[int, dict]: """获取优先级最高的消息""" while True: async with self._lock: if self._heap: msg = heapq.heappop(self._heap) if not self._heap: self._not_empty.clear() return (msg.msg_type, msg.data) # 等待新消息 await self._not_empty.wait() def get_nowait(self) -> Tuple[int, dict]: """非阻塞获取消息""" if not self._heap: raise asyncio.QueueEmpty() msg = heapq.heappop(self._heap) if not self._heap: self._not_empty.clear() return (msg.msg_type, msg.data) def task_done(self): """标记任务完成""" self._unfinished_tasks -= 1 if self._unfinished_tasks == 0: self._finished.set() async def join(self): """等待所有任务完成""" await self._finished.wait() def drop_lowest_priority(self) -> bool: """丢弃优先级最低的消息""" if not self._heap: return False # 找到优先级最低的消息(priority值最大,因为是负数所以最小) min_idx = 0 for i, msg in enumerate(self._heap): if msg.priority > self._heap[min_idx].priority: min_idx = i # 删除该消息 self._heap.pop(min_idx) heapq.heapify(self._heap) self._unfinished_tasks -= 1 return True # ==================== 自适应熔断器 ==================== class AdaptiveCircuitBreaker: """自适应熔断器""" # 熔断器状态 STATE_CLOSED = "closed" # 正常状态 STATE_OPEN = "open" # 熔断状态 STATE_HALF_OPEN = "half_open" # 半开状态(尝试恢复) def __init__( self, failure_threshold: int = 5, success_threshold: int = 3, initial_recovery_time: float = 5.0, max_recovery_time: float = 300.0, recovery_multiplier: float = 2.0 ): """ 初始化熔断器 Args: failure_threshold: 触发熔断的连续失败次数 success_threshold: 恢复正常的连续成功次数 initial_recovery_time: 初始恢复等待时间(秒) max_recovery_time: 最大恢复等待时间(秒) recovery_multiplier: 恢复时间增长倍数 """ self.failure_threshold = failure_threshold self.success_threshold = success_threshold self.initial_recovery_time = initial_recovery_time self.max_recovery_time = max_recovery_time self.recovery_multiplier = recovery_multiplier # 状态 self.state = self.STATE_CLOSED self.failure_count = 0 self.success_count = 0 self.last_failure_time = 0 self.current_recovery_time = initial_recovery_time # 统计 self.total_failures = 0 self.total_successes = 0 self.total_rejections = 0 def is_open(self) -> bool: """检查熔断器是否开启(是否应该拒绝请求)""" if self.state == self.STATE_CLOSED: return False if self.state == self.STATE_OPEN: # 检查是否可以尝试恢复 elapsed = time.time() - self.last_failure_time if elapsed >= self.current_recovery_time: self.state = self.STATE_HALF_OPEN self.success_count = 0 logger.info(f"熔断器进入半开状态,尝试恢复(等待了 {elapsed:.1f}s)") return False return True # 半开状态,允许请求通过 return False def record_success(self): """记录成功""" self.total_successes += 1 if self.state == self.STATE_HALF_OPEN: self.success_count += 1 if self.success_count >= self.success_threshold: # 恢复正常 self.state = self.STATE_CLOSED self.failure_count = 0 self.success_count = 0 self.current_recovery_time = self.initial_recovery_time logger.success(f"熔断器已恢复正常(连续成功 {self.success_threshold} 次)") else: # 正常状态,重置失败计数 self.failure_count = 0 def record_failure(self): """记录失败""" self.total_failures += 1 self.failure_count += 1 self.last_failure_time = time.time() if self.state == self.STATE_HALF_OPEN: # 半开状态下失败,重新熔断 self.state = self.STATE_OPEN self.success_count = 0 # 增加恢复时间 self.current_recovery_time = min( self.current_recovery_time * self.recovery_multiplier, self.max_recovery_time ) logger.warning(f"熔断器重新开启,下次恢复等待 {self.current_recovery_time:.1f}s") elif self.state == self.STATE_CLOSED: if self.failure_count >= self.failure_threshold: self.state = self.STATE_OPEN logger.warning(f"熔断器开启,连续失败 {self.failure_count} 次") def record_rejection(self): """记录被拒绝的请求""" self.total_rejections += 1 def get_stats(self) -> dict: """获取统计信息""" return { "state": self.state, "failure_count": self.failure_count, "success_count": self.success_count, "current_recovery_time": self.current_recovery_time, "total_failures": self.total_failures, "total_successes": self.total_successes, "total_rejections": self.total_rejections } # ==================== 请求重试机制 ==================== class RetryConfig: """重试配置""" def __init__( self, max_retries: int = 3, initial_delay: float = 1.0, max_delay: float = 30.0, exponential_base: float = 2.0, retryable_exceptions: tuple = ( aiohttp.ClientError, asyncio.TimeoutError, ConnectionError, ) ): self.max_retries = max_retries self.initial_delay = initial_delay self.max_delay = max_delay self.exponential_base = exponential_base self.retryable_exceptions = retryable_exceptions def retry_async(config: RetryConfig = None): """异步重试装饰器""" if config is None: config = RetryConfig() def decorator(func): @functools.wraps(func) async def wrapper(*args, **kwargs): last_exception = None for attempt in range(config.max_retries + 1): try: return await func(*args, **kwargs) except config.retryable_exceptions as e: last_exception = e if attempt == config.max_retries: logger.error(f"重试 {config.max_retries} 次后仍然失败: {func.__name__}") raise # 计算延迟时间(指数退避) delay = min( config.initial_delay * (config.exponential_base ** attempt), config.max_delay ) logger.warning( f"请求失败,{delay:.1f}s 后重试 " f"(第 {attempt + 1}/{config.max_retries} 次): {e}" ) await asyncio.sleep(delay) raise last_exception return wrapper return decorator async def request_with_retry( session: aiohttp.ClientSession, method: str, url: str, max_retries: int = 3, **kwargs ) -> aiohttp.ClientResponse: """带重试的 HTTP 请求""" config = RetryConfig(max_retries=max_retries) last_exception = None for attempt in range(config.max_retries + 1): try: response = await session.request(method, url, **kwargs) return response except config.retryable_exceptions as e: last_exception = e if attempt == config.max_retries: raise delay = min( config.initial_delay * (config.exponential_base ** attempt), config.max_delay ) logger.warning(f"HTTP 请求失败,{delay:.1f}s 后重试: {e}") await asyncio.sleep(delay) raise last_exception # ==================== 配置热更新 ==================== class ConfigWatcher: """配置文件监听器""" def __init__(self, config_path: str, check_interval: float = 5.0): """ 初始化配置监听器 Args: config_path: 配置文件路径 check_interval: 检查间隔(秒) """ self.config_path = Path(config_path) self.check_interval = check_interval self.last_mtime = 0 self.callbacks: List[Callable[[dict], Any]] = [] self.current_config: dict = {} self._running = False self._task: Optional[asyncio.Task] = None def register_callback(self, callback: Callable[[dict], Any]): """注册配置更新回调""" self.callbacks.append(callback) def unregister_callback(self, callback: Callable[[dict], Any]): """取消注册回调""" if callback in self.callbacks: self.callbacks.remove(callback) def _load_config(self) -> dict: """加载配置文件""" try: with open(self.config_path, "rb") as f: return tomllib.load(f) except Exception as e: logger.error(f"加载配置文件失败: {e}") return {} def get_config(self) -> dict: """获取当前配置""" return self.current_config async def start(self): """启动配置监听""" if self._running: return self._running = True # 初始加载 if self.config_path.exists(): self.last_mtime = os.path.getmtime(self.config_path) self.current_config = self._load_config() self._task = asyncio.create_task(self._watch_loop()) logger.info(f"配置监听器已启动: {self.config_path}") async def stop(self): """停止配置监听""" self._running = False if self._task: self._task.cancel() try: await self._task except asyncio.CancelledError: pass logger.info("配置监听器已停止") async def _watch_loop(self): """监听循环""" while self._running: try: await asyncio.sleep(self.check_interval) if not self.config_path.exists(): continue mtime = os.path.getmtime(self.config_path) if mtime > self.last_mtime: logger.info("检测到配置文件变化,重新加载...") new_config = self._load_config() if new_config: old_config = self.current_config self.current_config = new_config self.last_mtime = mtime # 通知所有回调 for callback in self.callbacks: try: if asyncio.iscoroutinefunction(callback): await callback(new_config) else: callback(new_config) except Exception as e: logger.error(f"配置更新回调执行失败: {e}") logger.success("配置已热更新") except asyncio.CancelledError: break except Exception as e: logger.error(f"配置监听异常: {e}") # ==================== 性能监控 ==================== class PerformanceMonitor: """性能监控器""" def __init__(self): self.start_time = time.time() # 消息统计 self.message_received = 0 self.message_processed = 0 self.message_failed = 0 self.message_dropped = 0 # 处理时间统计 self.processing_times: List[float] = [] self.max_processing_times = 1000 # 保留最近1000条记录 # 插件统计 self.plugin_stats: Dict[str, dict] = {} # 队列统计 self.queue_size_history: List[Tuple[float, int]] = [] self.max_queue_history = 100 # 熔断器统计 self.circuit_breaker_stats: dict = {} def record_message_received(self): """记录收到消息""" self.message_received += 1 def record_message_processed(self, processing_time: float): """记录消息处理完成""" self.message_processed += 1 self.processing_times.append(processing_time) # 限制历史记录数量 if len(self.processing_times) > self.max_processing_times: self.processing_times = self.processing_times[-self.max_processing_times:] def record_message_failed(self): """记录消息处理失败""" self.message_failed += 1 def record_message_dropped(self): """记录消息被丢弃""" self.message_dropped += 1 def record_queue_size(self, size: int): """记录队列大小""" self.queue_size_history.append((time.time(), size)) if len(self.queue_size_history) > self.max_queue_history: self.queue_size_history = self.queue_size_history[-self.max_queue_history:] def record_plugin_execution(self, plugin_name: str, execution_time: float, success: bool): """记录插件执行""" if plugin_name not in self.plugin_stats: self.plugin_stats[plugin_name] = { "total_calls": 0, "success_calls": 0, "failed_calls": 0, "total_time": 0, "max_time": 0, "recent_times": [] } stats = self.plugin_stats[plugin_name] stats["total_calls"] += 1 stats["total_time"] += execution_time stats["max_time"] = max(stats["max_time"], execution_time) stats["recent_times"].append(execution_time) if len(stats["recent_times"]) > 100: stats["recent_times"] = stats["recent_times"][-100:] if success: stats["success_calls"] += 1 else: stats["failed_calls"] += 1 def update_circuit_breaker_stats(self, stats: dict): """更新熔断器统计""" self.circuit_breaker_stats = stats def get_stats(self) -> dict: """获取完整统计信息""" uptime = time.time() - self.start_time # 计算平均处理时间 avg_processing_time = 0 if self.processing_times: avg_processing_time = sum(self.processing_times) / len(self.processing_times) # 计算处理速率 processing_rate = self.message_processed / uptime if uptime > 0 else 0 # 计算成功率 total = self.message_processed + self.message_failed success_rate = self.message_processed / total if total > 0 else 1.0 return { "uptime_seconds": uptime, "uptime_formatted": self._format_uptime(uptime), "messages": { "received": self.message_received, "processed": self.message_processed, "failed": self.message_failed, "dropped": self.message_dropped, "success_rate": f"{success_rate * 100:.1f}%", "processing_rate": f"{processing_rate:.2f}/s" }, "processing_time": { "average_ms": f"{avg_processing_time * 1000:.1f}", "max_ms": f"{max(self.processing_times) * 1000:.1f}" if self.processing_times else "0", "min_ms": f"{min(self.processing_times) * 1000:.1f}" if self.processing_times else "0" }, "queue": { "current_size": self.queue_size_history[-1][1] if self.queue_size_history else 0, "max_size": max(s[1] for s in self.queue_size_history) if self.queue_size_history else 0 }, "circuit_breaker": self.circuit_breaker_stats, "plugins": self._get_plugin_summary() } def _get_plugin_summary(self) -> List[dict]: """获取插件统计摘要""" summary = [] for name, stats in self.plugin_stats.items(): avg_time = stats["total_time"] / stats["total_calls"] if stats["total_calls"] > 0 else 0 summary.append({ "name": name, "calls": stats["total_calls"], "success_rate": f"{stats['success_calls'] / stats['total_calls'] * 100:.1f}%" if stats["total_calls"] > 0 else "N/A", "avg_time_ms": f"{avg_time * 1000:.1f}", "max_time_ms": f"{stats['max_time'] * 1000:.1f}" }) # 按平均时间排序 summary.sort(key=lambda x: float(x["avg_time_ms"]), reverse=True) return summary def _format_uptime(self, seconds: float) -> str: """格式化运行时间""" days = int(seconds // 86400) hours = int((seconds % 86400) // 3600) minutes = int((seconds % 3600) // 60) secs = int(seconds % 60) if days > 0: return f"{days}天 {hours}小时 {minutes}分钟" elif hours > 0: return f"{hours}小时 {minutes}分钟" elif minutes > 0: return f"{minutes}分钟 {secs}秒" else: return f"{secs}秒" def print_stats(self): """打印统计信息到日志""" stats = self.get_stats() logger.info("=" * 50) logger.info("性能监控报告") logger.info("=" * 50) logger.info(f"运行时间: {stats['uptime_formatted']}") logger.info(f"消息统计: 收到 {stats['messages']['received']}, " f"处理 {stats['messages']['processed']}, " f"失败 {stats['messages']['failed']}, " f"丢弃 {stats['messages']['dropped']}") logger.info(f"成功率: {stats['messages']['success_rate']}, " f"处理速率: {stats['messages']['processing_rate']}") logger.info(f"平均处理时间: {stats['processing_time']['average_ms']}ms") logger.info(f"队列大小: {stats['queue']['current_size']}") logger.info(f"熔断器状态: {stats['circuit_breaker'].get('state', 'N/A')}") if stats['plugins']: logger.info("插件耗时排行:") for i, p in enumerate(stats['plugins'][:5], 1): logger.info(f" {i}. {p['name']}: {p['avg_time_ms']}ms (调用 {p['calls']} 次)") logger.info("=" * 50) # ==================== 全局实例 ==================== # 性能监控器单例 _performance_monitor: Optional[PerformanceMonitor] = None def get_performance_monitor() -> PerformanceMonitor: """获取性能监控器实例""" global _performance_monitor if _performance_monitor is None: _performance_monitor = PerformanceMonitor() return _performance_monitor