Files
WechatHookBot/utils/bot_utils.py
2025-12-05 18:06:13 +08:00

659 lines
22 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
机器人核心工具模块
包含:
- 优先级消息队列
- 自适应熔断器
- 请求重试机制
- 配置热更新
- 性能监控
"""
import asyncio
import functools
import heapq
import inspect
import itertools
import os
import time
import tomllib
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional, Callable, Any, Tuple

import aiohttp
from loguru import logger
# ==================== 优先级消息队列 ====================
# 消息优先级定义
class MessagePriority:
    """Named priority levels for incoming messages.

    A larger number means a more urgent message.
    """

    LOW = 20        # ordinary group messages
    NORMAL = 50     # messages that @-mention the bot
    HIGH = 80       # administrator commands
    CRITICAL = 100  # system messages and login information
# High-priority message types: maps a message type code to the priority
# level it is given when the caller does not pass one explicitly.
PRIORITY_MESSAGE_TYPES = {
    11025: MessagePriority.CRITICAL,  # login information
    11058: MessagePriority.CRITICAL,  # system message
    11098: MessagePriority.HIGH,  # group member joined
    11099: MessagePriority.HIGH,  # group member left
    11100: MessagePriority.HIGH,  # group information changed
    11056: MessagePriority.HIGH,  # friend request
}
@dataclass(order=True)
class PriorityMessage:
    """Heap entry that wraps one message together with its scheduling key.

    Instances are ordered by ``(priority, sequence)``:

    * ``priority`` stores the *negated* priority level, because ``heapq`` is a
      min-heap and the most urgent message must compare smallest.
    * ``sequence`` is a per-process arrival counter. It breaks ties so that
      messages of equal priority leave the heap in the order they were
      created (``heapq`` alone is not stable).

    ``timestamp``, ``msg_type`` and ``data`` do not take part in comparisons.
    Because ``__init__`` is written by hand, ``@dataclass`` does not generate
    one; it only supplies ``__repr__``, ``__eq__`` and the ordering methods.
    """

    priority: int
    sequence: int
    timestamp: float = field(compare=False)
    msg_type: int = field(compare=False)
    data: dict = field(compare=False)

    # Shared arrival counter. Deliberately left unannotated so that
    # @dataclass does not treat it as an instance field.
    _arrival_counter = itertools.count()

    def __init__(self, msg_type: int, data: dict, priority: Optional[int] = None):
        """Build a heap entry.

        Args:
            msg_type: Numeric message type code.
            data: Message payload.
            priority: Explicit priority level (larger is more urgent). When
                ``None``, the level is looked up in ``PRIORITY_MESSAGE_TYPES``
                and falls back to ``MessagePriority.NORMAL``.
        """
        # Compare against None rather than relying on truthiness, so an
        # explicit priority of 0 is honoured instead of being replaced by
        # the default.
        if priority is None:
            priority = PRIORITY_MESSAGE_TYPES.get(msg_type, MessagePriority.NORMAL)
        # Higher priority -> smaller stored value, because heapq is a min-heap.
        self.priority = -priority
        self.sequence = next(self._arrival_counter)
        self.timestamp = time.time()
        self.msg_type = msg_type
        self.data = data
class PriorityMessageQueue:
    """Asyncio priority queue backed by a ``heapq`` list of ``PriorityMessage``.

    ``maxsize`` is advisory: ``put`` does not enforce it, only ``full()``
    consults it. Two events track the queue state:

    * ``_not_empty`` is set while the heap holds at least one message.
    * ``_finished`` is set while no message added by ``put`` is still waiting
      for ``task_done`` (or was dropped).
    """

    def __init__(self, maxsize: int = 1000):
        """Create an empty queue.

        Args:
            maxsize: Size at which ``full()`` starts returning ``True``.
        """
        self.maxsize = maxsize
        self._heap: List[PriorityMessage] = []
        self._lock = asyncio.Lock()
        self._not_empty = asyncio.Event()
        self._unfinished_tasks = 0
        self._finished = asyncio.Event()
        self._finished.set()

    def qsize(self) -> int:
        """Return the number of queued messages."""
        return len(self._heap)

    def empty(self) -> bool:
        """Return ``True`` when no message is queued."""
        return not self._heap

    def full(self) -> bool:
        """Return ``True`` when the queue holds ``maxsize`` or more messages."""
        return len(self._heap) >= self.maxsize

    async def put(self, msg_type: int, data: dict, priority: Optional[int] = None):
        """Queue a message; never blocks on ``maxsize``.

        Args:
            msg_type: Numeric message type code.
            data: Message payload.
            priority: Optional explicit priority, forwarded to
                ``PriorityMessage``.
        """
        async with self._lock:
            heapq.heappush(self._heap, PriorityMessage(msg_type, data, priority))
            self._unfinished_tasks += 1
            self._finished.clear()
            self._not_empty.set()

    async def get(self) -> Tuple[int, dict]:
        """Wait for and return the most urgent message as ``(msg_type, data)``."""
        while True:
            async with self._lock:
                if self._heap:
                    return self._pop_message()
            # Heap was empty: sleep until a producer signals a new message.
            await self._not_empty.wait()

    def get_nowait(self) -> Tuple[int, dict]:
        """Return the most urgent message without waiting.

        Raises:
            asyncio.QueueEmpty: If the queue holds no message.
        """
        if not self._heap:
            raise asyncio.QueueEmpty()
        return self._pop_message()

    def task_done(self):
        """Mark one previously fetched message as fully processed."""
        self._settle_one_task()

    async def join(self):
        """Wait until every queued message has been processed or dropped."""
        await self._finished.wait()

    def drop_lowest_priority(self) -> bool:
        """Discard the least urgent queued message.

        Returns:
            ``True`` if a message was dropped, ``False`` if the queue was
            already empty.
        """
        if not self._heap:
            return False
        # The stored priority is negated, so the least urgent message has the
        # numerically largest value. max() keeps the first such index on ties.
        victim = max(range(len(self._heap)), key=lambda i: self._heap[i].priority)
        self._heap.pop(victim)
        heapq.heapify(self._heap)
        # Keep both events consistent with the new state. Otherwise get()
        # would spin on a set-but-stale _not_empty event and join() would
        # never wake up after the last message was dropped.
        if not self._heap:
            self._not_empty.clear()
        self._settle_one_task()
        return True

    def _pop_message(self) -> Tuple[int, dict]:
        """Pop the heap root and clear ``_not_empty`` if that emptied the heap."""
        msg = heapq.heappop(self._heap)
        if not self._heap:
            self._not_empty.clear()
        return (msg.msg_type, msg.data)

    def _settle_one_task(self) -> None:
        """Decrement the unfinished counter and release ``join`` at zero."""
        # Clamp at zero so a surplus call cannot push the counter negative
        # and leave later join() calls waiting forever.
        self._unfinished_tasks = max(0, self._unfinished_tasks - 1)
        if self._unfinished_tasks == 0:
            self._finished.set()
# ==================== 自适应熔断器 ====================
class AdaptiveCircuitBreaker:
    """Three-state circuit breaker whose recovery wait grows on failed trials.

    States:
        closed:    requests pass; consecutive failures are counted.
        open:      ``is_open()`` reports ``True`` until ``current_recovery_time``
                   seconds have elapsed since the last recorded failure.
        half_open: requests pass again as a trial. Enough consecutive
                   successes close the breaker; one failure re-opens it and
                   multiplies the recovery wait, capped at
                   ``max_recovery_time``.
    """

    STATE_CLOSED = "closed"        # normal operation
    STATE_OPEN = "open"            # tripped
    STATE_HALF_OPEN = "half_open"  # probing for recovery

    def __init__(
        self,
        failure_threshold: int = 5,
        success_threshold: int = 3,
        initial_recovery_time: float = 5.0,
        max_recovery_time: float = 300.0,
        recovery_multiplier: float = 2.0
    ):
        """Set up a closed breaker.

        Args:
            failure_threshold: Consecutive failures that trip the breaker.
            success_threshold: Consecutive half-open successes needed to
                close it again.
            initial_recovery_time: Initial wait, in seconds, before a
                recovery attempt.
            max_recovery_time: Upper bound, in seconds, for the recovery wait.
            recovery_multiplier: Factor applied to the recovery wait each
                time a half-open trial fails.
        """
        # Tunables
        self.failure_threshold = failure_threshold
        self.success_threshold = success_threshold
        self.initial_recovery_time = initial_recovery_time
        self.max_recovery_time = max_recovery_time
        self.recovery_multiplier = recovery_multiplier

        # Live state
        self.state = self.STATE_CLOSED
        self.failure_count = 0
        self.success_count = 0
        self.last_failure_time = 0
        self.current_recovery_time = initial_recovery_time

        # Lifetime counters
        self.total_failures = 0
        self.total_successes = 0
        self.total_rejections = 0

    def is_open(self) -> bool:
        """Tell whether requests should currently be rejected.

        Side effect: an open breaker whose recovery wait has elapsed is
        switched to half-open, and the call returns ``False``.
        """
        if self.state != self.STATE_OPEN:
            # Closed and half-open both let requests through.
            return False

        elapsed = time.time() - self.last_failure_time
        if elapsed < self.current_recovery_time:
            return True

        # Waited long enough: start a recovery trial.
        self.state = self.STATE_HALF_OPEN
        self.success_count = 0
        logger.info(f"熔断器进入半开状态,尝试恢复(等待了 {elapsed:.1f}s")
        return False

    def record_success(self):
        """Register one successful request."""
        self.total_successes += 1

        if self.state != self.STATE_HALF_OPEN:
            # Outside a recovery trial a success just resets the failure streak.
            self.failure_count = 0
            return

        self.success_count += 1
        if self.success_count < self.success_threshold:
            return

        # Trial passed: close the breaker and reset the backoff.
        self.state = self.STATE_CLOSED
        self.failure_count = 0
        self.success_count = 0
        self.current_recovery_time = self.initial_recovery_time
        logger.success(f"熔断器已恢复正常(连续成功 {self.success_threshold} 次)")

    def record_failure(self):
        """Register one failed request and trip or re-trip the breaker if due."""
        self.total_failures += 1
        self.failure_count += 1
        self.last_failure_time = time.time()

        if self.state == self.STATE_HALF_OPEN:
            # A failed trial re-opens the breaker with a longer wait.
            self.state = self.STATE_OPEN
            self.success_count = 0
            self.current_recovery_time = min(
                self.current_recovery_time * self.recovery_multiplier,
                self.max_recovery_time
            )
            logger.warning(f"熔断器重新开启,下次恢复等待 {self.current_recovery_time:.1f}s")
            return

        if self.state == self.STATE_CLOSED and self.failure_count >= self.failure_threshold:
            self.state = self.STATE_OPEN
            logger.warning(f"熔断器开启,连续失败 {self.failure_count}")

    def record_rejection(self):
        """Count one request that was rejected by the caller."""
        self.total_rejections += 1

    def get_stats(self) -> dict:
        """Return a snapshot of the current state and lifetime counters."""
        reported = (
            "state",
            "failure_count",
            "success_count",
            "current_recovery_time",
            "total_failures",
            "total_successes",
            "total_rejections",
        )
        return {name: getattr(self, name) for name in reported}
# ==================== 请求重试机制 ====================
class RetryConfig:
    """Settings consumed by the retry helpers in this module.

    Attributes:
        max_retries: Retries allowed after the first attempt.
        initial_delay: Delay, in seconds, before the first retry.
        max_delay: Upper bound, in seconds, for any single delay.
        exponential_base: Growth factor; the delay before retry ``n``
            (0-based) is ``initial_delay * exponential_base ** n``, capped at
            ``max_delay``.
        retryable_exceptions: Exception types that trigger a retry; anything
            else propagates immediately.
    """

    def __init__(
        self,
        max_retries: int = 3,
        initial_delay: float = 1.0,
        max_delay: float = 30.0,
        exponential_base: float = 2.0,
        retryable_exceptions: tuple = (
            aiohttp.ClientError,
            asyncio.TimeoutError,
            ConnectionError,
        )
    ):
        # What to retry on, and how often
        self.retryable_exceptions = retryable_exceptions
        self.max_retries = max_retries
        # Backoff schedule
        self.exponential_base = exponential_base
        self.initial_delay = initial_delay
        self.max_delay = max_delay
def retry_async(config: Optional[RetryConfig] = None):
    """Build a decorator that retries a coroutine function with backoff.

    Args:
        config: Retry settings; a default ``RetryConfig()`` is used when
            omitted.

    Returns:
        A decorator. The wrapped coroutine is awaited up to
        ``config.max_retries + 1`` times. Only exceptions listed in
        ``config.retryable_exceptions`` trigger a retry; once the retries are
        exhausted the last such exception is re-raised. Any other exception
        propagates immediately.
    """
    if config is None:
        config = RetryConfig()

    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            # Treat a negative max_retries as 0 so the wrapped coroutine is
            # always attempted at least once. Previously the loop body never
            # ran and the function ended in ``raise None`` (a TypeError).
            max_retries = max(0, config.max_retries)
            attempt = 0
            while True:
                try:
                    return await func(*args, **kwargs)
                except config.retryable_exceptions as e:
                    if attempt >= max_retries:
                        logger.error(f"重试 {max_retries} 次后仍然失败: {func.__name__}")
                        raise
                    # Exponential backoff, capped at max_delay.
                    delay = min(
                        config.initial_delay * (config.exponential_base ** attempt),
                        config.max_delay
                    )
                    logger.warning(
                        f"请求失败,{delay:.1f}s 后重试 "
                        f"(第 {attempt + 1}/{max_retries} 次): {e}"
                    )
                    await asyncio.sleep(delay)
                attempt += 1

        return wrapper

    return decorator
async def request_with_retry(
    session: aiohttp.ClientSession,
    method: str,
    url: str,
    max_retries: int = 3,
    **kwargs
) -> aiohttp.ClientResponse:
    """Issue an HTTP request, retrying with exponential backoff.

    Args:
        session: Session used to send the request.
        method: HTTP method name passed to ``session.request``.
        url: Target URL.
        max_retries: Retries allowed after the first attempt; negative values
            are treated as 0.
        **kwargs: Forwarded unchanged to ``session.request``.

    Returns:
        The object returned by ``session.request``; it is handed back as-is,
        without being read or closed here.

    Raises:
        Any of ``RetryConfig().retryable_exceptions`` once the retries are
        exhausted; other exceptions propagate immediately.
    """
    config = RetryConfig(max_retries=max_retries)
    # Always make at least one attempt. With a negative value the old loop
    # never ran and the function ended in ``raise None`` (a TypeError).
    attempts_allowed = max(0, config.max_retries)
    attempt = 0
    while True:
        try:
            return await session.request(method, url, **kwargs)
        except config.retryable_exceptions as e:
            if attempt >= attempts_allowed:
                raise
            # Exponential backoff, capped at max_delay.
            delay = min(
                config.initial_delay * (config.exponential_base ** attempt),
                config.max_delay
            )
            logger.warning(f"HTTP 请求失败,{delay:.1f}s 后重试: {e}")
            await asyncio.sleep(delay)
        attempt += 1
# ==================== 配置热更新 ====================
class ConfigWatcher:
"""配置文件监听器"""
def __init__(self, config_path: str, check_interval: float = 5.0):
"""
初始化配置监听器
Args:
config_path: 配置文件路径
check_interval: 检查间隔(秒)
"""
self.config_path = Path(config_path)
self.check_interval = check_interval
self.last_mtime = 0
self.callbacks: List[Callable[[dict], Any]] = []
self.current_config: dict = {}
self._running = False
self._task: Optional[asyncio.Task] = None
def register_callback(self, callback: Callable[[dict], Any]):
"""注册配置更新回调"""
self.callbacks.append(callback)
def unregister_callback(self, callback: Callable[[dict], Any]):
"""取消注册回调"""
if callback in self.callbacks:
self.callbacks.remove(callback)
def _load_config(self) -> dict:
"""加载配置文件"""
try:
with open(self.config_path, "rb") as f:
return tomllib.load(f)
except Exception as e:
logger.error(f"加载配置文件失败: {e}")
return {}
def get_config(self) -> dict:
"""获取当前配置"""
return self.current_config
async def start(self):
"""启动配置监听"""
if self._running:
return
self._running = True
# 初始加载
if self.config_path.exists():
self.last_mtime = os.path.getmtime(self.config_path)
self.current_config = self._load_config()
self._task = asyncio.create_task(self._watch_loop())
logger.info(f"配置监听器已启动: {self.config_path}")
async def stop(self):
"""停止配置监听"""
self._running = False
if self._task:
self._task.cancel()
try:
await self._task
except asyncio.CancelledError:
pass
logger.info("配置监听器已停止")
async def _watch_loop(self):
"""监听循环"""
while self._running:
try:
await asyncio.sleep(self.check_interval)
if not self.config_path.exists():
continue
mtime = os.path.getmtime(self.config_path)
if mtime > self.last_mtime:
logger.info("检测到配置文件变化,重新加载...")
new_config = self._load_config()
if new_config:
old_config = self.current_config
self.current_config = new_config
self.last_mtime = mtime
# 通知所有回调
for callback in self.callbacks:
try:
if asyncio.iscoroutinefunction(callback):
await callback(new_config)
else:
callback(new_config)
except Exception as e:
logger.error(f"配置更新回调执行失败: {e}")
logger.success("配置已热更新")
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"配置监听异常: {e}")
# ==================== 性能监控 ====================
class PerformanceMonitor:
    """In-memory collector for message, plugin and queue statistics.

    Counters grow for the lifetime of the instance. The per-message
    processing times, the per-plugin recent times and the queue-size samples
    are bounded histories that keep only the most recent entries.
    """

    def __init__(self):
        self.start_time = time.time()

        # Message counters
        self.message_received = 0
        self.message_processed = 0
        self.message_failed = 0
        self.message_dropped = 0

        # Processing-time history (seconds); keeps the latest 1000 entries
        self.processing_times: List[float] = []
        self.max_processing_times = 1000

        # Per-plugin statistics, keyed by plugin name
        self.plugin_stats: Dict[str, dict] = {}

        # Queue-size samples as (timestamp, size); keeps the latest 100
        self.queue_size_history: List[Tuple[float, int]] = []
        self.max_queue_history = 100

        # Latest circuit-breaker snapshot handed in by the caller
        self.circuit_breaker_stats: dict = {}

    @staticmethod
    def _trim_history(history: list, limit: int) -> None:
        """Drop the oldest entries in place so at most ``limit`` remain."""
        overflow = len(history) - limit
        if overflow > 0:
            del history[:overflow]

    def record_message_received(self):
        """Count one received message."""
        self.message_received += 1

    def record_message_processed(self, processing_time: float):
        """Count one processed message and remember how long it took.

        Args:
            processing_time: Processing duration in seconds.
        """
        self.message_processed += 1
        self.processing_times.append(processing_time)
        self._trim_history(self.processing_times, self.max_processing_times)

    def record_message_failed(self):
        """Count one message whose processing failed."""
        self.message_failed += 1

    def record_message_dropped(self):
        """Count one message that was dropped."""
        self.message_dropped += 1

    def record_queue_size(self, size: int):
        """Store a timestamped queue-size sample."""
        self.queue_size_history.append((time.time(), size))
        self._trim_history(self.queue_size_history, self.max_queue_history)

    def record_plugin_execution(self, plugin_name: str, execution_time: float, success: bool):
        """Record one plugin run.

        Args:
            plugin_name: Key under which the plugin's statistics are kept.
            execution_time: Run duration in seconds.
            success: Whether the run succeeded.
        """
        stats = self.plugin_stats.setdefault(plugin_name, {
            "total_calls": 0,
            "success_calls": 0,
            "failed_calls": 0,
            "total_time": 0,
            "max_time": 0,
            "recent_times": []
        })

        stats["total_calls"] += 1
        stats["total_time"] += execution_time
        stats["max_time"] = max(stats["max_time"], execution_time)
        stats["recent_times"].append(execution_time)
        self._trim_history(stats["recent_times"], 100)

        stats["success_calls" if success else "failed_calls"] += 1

    def update_circuit_breaker_stats(self, stats: dict):
        """Replace the stored circuit-breaker snapshot."""
        self.circuit_breaker_stats = stats

    def get_stats(self) -> dict:
        """Return the full statistics report as a nested dict.

        Rates and times are pre-formatted strings; counters are plain ints.
        """
        uptime = time.time() - self.start_time
        times = self.processing_times

        if times:
            avg_processing_time = sum(times) / len(times)
            max_ms = f"{max(times) * 1000:.1f}"
            min_ms = f"{min(times) * 1000:.1f}"
        else:
            avg_processing_time = 0
            max_ms = "0"
            min_ms = "0"

        processing_rate = self.message_processed / uptime if uptime > 0 else 0

        # With nothing attempted yet the success rate is reported as 100%.
        total = self.message_processed + self.message_failed
        success_rate = self.message_processed / total if total > 0 else 1.0

        samples = self.queue_size_history
        return {
            "uptime_seconds": uptime,
            "uptime_formatted": self._format_uptime(uptime),
            "messages": {
                "received": self.message_received,
                "processed": self.message_processed,
                "failed": self.message_failed,
                "dropped": self.message_dropped,
                "success_rate": f"{success_rate * 100:.1f}%",
                "processing_rate": f"{processing_rate:.2f}/s"
            },
            "processing_time": {
                "average_ms": f"{avg_processing_time * 1000:.1f}",
                "max_ms": max_ms,
                "min_ms": min_ms
            },
            "queue": {
                "current_size": samples[-1][1] if samples else 0,
                "max_size": max(s[1] for s in samples) if samples else 0
            },
            "circuit_breaker": self.circuit_breaker_stats,
            "plugins": self._get_plugin_summary()
        }

    def _get_plugin_summary(self) -> List[dict]:
        """Return one summary dict per plugin, slowest average first."""
        ranked = []
        for name, stats in self.plugin_stats.items():
            calls = stats["total_calls"]
            avg_time = stats["total_time"] / calls if calls > 0 else 0
            success_rate = f"{stats['success_calls'] / calls * 100:.1f}%" if calls > 0 else "N/A"
            ranked.append((avg_time, {
                "name": name,
                "calls": calls,
                "success_rate": success_rate,
                "avg_time_ms": f"{avg_time * 1000:.1f}",
                "max_time_ms": f"{stats['max_time'] * 1000:.1f}"
            }))

        # Sort on the numeric average rather than parsing the formatted
        # string back into a float.
        ranked.sort(key=lambda pair: pair[0], reverse=True)
        return [entry for _, entry in ranked]

    def _format_uptime(self, seconds: float) -> str:
        """Format a duration in seconds using its two or three largest units."""
        days = int(seconds // 86400)
        hours = int((seconds % 86400) // 3600)
        minutes = int((seconds % 3600) // 60)
        secs = int(seconds % 60)

        # Every number carries its own unit. Without one, the day and hour
        # counts ran together (1 day 2 hours rendered as "12小时").
        if days > 0:
            return f"{days}天 {hours}小时 {minutes}分钟"
        if hours > 0:
            return f"{hours}小时 {minutes}分钟"
        if minutes > 0:
            return f"{minutes}分钟 {secs}秒"
        return f"{secs}秒"

    def print_stats(self):
        """Write the statistics report to the log."""
        stats = self.get_stats()
        messages = stats['messages']
        separator = "=" * 50

        logger.info(separator)
        logger.info("性能监控报告")
        logger.info(separator)
        logger.info(f"运行时间: {stats['uptime_formatted']}")
        logger.info(f"消息统计: 收到 {messages['received']}, "
                    f"处理 {messages['processed']}, "
                    f"失败 {messages['failed']}, "
                    f"丢弃 {messages['dropped']}")
        logger.info(f"成功率: {messages['success_rate']}, "
                    f"处理速率: {messages['processing_rate']}")
        logger.info(f"平均处理时间: {stats['processing_time']['average_ms']}ms")
        logger.info(f"队列大小: {stats['queue']['current_size']}")
        logger.info(f"熔断器状态: {stats['circuit_breaker'].get('state', 'N/A')}")

        if stats['plugins']:
            logger.info("插件耗时排行:")
            # Only the five slowest plugins are listed.
            for i, p in enumerate(stats['plugins'][:5], 1):
                logger.info(f" {i}. {p['name']}: {p['avg_time_ms']}ms (调用 {p['calls']} 次)")

        logger.info(separator)
# ==================== 全局实例 ====================
# Module-wide PerformanceMonitor instance, created on first request
_performance_monitor: Optional[PerformanceMonitor] = None


def get_performance_monitor() -> PerformanceMonitor:
    """Return the shared PerformanceMonitor, creating it on the first call."""
    global _performance_monitor
    monitor = _performance_monitor
    if monitor is None:
        monitor = _performance_monitor = PerformanceMonitor()
    return monitor