659 lines
22 KiB
Python
659 lines
22 KiB
Python
"""
|
||
机器人核心工具模块
|
||
|
||
包含:
|
||
- 优先级消息队列
|
||
- 自适应熔断器
|
||
- 请求重试机制
|
||
- 配置热更新
|
||
- 性能监控
|
||
"""
|
||
|
||
import asyncio
import functools
import heapq
import inspect
import itertools
import os
import time
import tomllib
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional, Callable, Any, Tuple

import aiohttp
from loguru import logger
|
||
|
||
|
||
# ==================== 优先级消息队列 ====================
|
||
|
||
# 消息优先级定义
|
||
class MessagePriority:
    """Message priority levels; a larger value is served earlier by the queue."""

    CRITICAL = 100  # system messages, login info
    HIGH = 80  # admin commands
    NORMAL = 50  # messages that @-mention the bot
    LOW = 20  # ordinary group messages


# Message types that get a priority other than the NORMAL default.
PRIORITY_MESSAGE_TYPES = {
    11025: MessagePriority.CRITICAL,  # login info
    11058: MessagePriority.CRITICAL,  # system message
    11098: MessagePriority.HIGH,  # group member joined
    11099: MessagePriority.HIGH,  # group member left
    11100: MessagePriority.HIGH,  # group info changed
    11056: MessagePriority.HIGH,  # friend request
}


@dataclass(order=True)
class PriorityMessage:
    """A message wrapper that orders correctly inside a ``heapq`` min-heap.

    Comparison uses ``(priority, sequence)`` only:

    * ``priority`` holds the *negated* priority level, so higher levels sort
      first in the min-heap.
    * ``sequence`` is a process-wide, monotonically increasing insertion
      counter, so messages with equal priority come out in FIFO order
      (``heapq`` on its own is not stable).
    """

    priority: int
    sequence: int
    timestamp: float = field(compare=False)
    msg_type: int = field(compare=False)
    data: dict = field(compare=False)

    # Shared insertion counter. It has no annotation, so dataclass does not
    # treat it as a field.
    _counter = itertools.count()

    def __init__(self, msg_type: int, data: dict, priority: Optional[int] = None):
        """
        Args:
            msg_type: Message type code, used to look up a default priority
                in PRIORITY_MESSAGE_TYPES (NORMAL if absent).
            data: Message payload.
            priority: Explicit priority level that overrides the per-type
                default. ``0`` is honoured rather than treated as "unset".
        """
        if priority is None:
            priority = PRIORITY_MESSAGE_TYPES.get(msg_type, MessagePriority.NORMAL)
        # Negate: higher priority must give the smaller min-heap key.
        self.priority = -priority
        self.sequence = next(PriorityMessage._counter)
        self.timestamp = time.time()
        self.msg_type = msg_type
        self.data = data
||
|
||
|
||
class PriorityMessageQueue:
    """An asyncio message queue that always hands out the highest-priority message.

    Messages are kept as ``PriorityMessage`` objects in a binary heap, and
    :meth:`get` pops the smallest heap key. That key is the highest priority,
    given how ``PriorityMessage`` negates priorities.

    ``maxsize`` is only reported through :meth:`full`. :meth:`put` never
    blocks or rejects, so callers bound the queue themselves, e.g. with
    :meth:`drop_lowest_priority`. :meth:`task_done` / :meth:`join` mirror
    ``asyncio.Queue``. A dropped message also counts as finished.
    """

    def __init__(self, maxsize: int = 1000):
        self.maxsize = maxsize
        self._heap: List[PriorityMessage] = []
        self._lock = asyncio.Lock()
        # Set while the heap may be non-empty; get() waits on it.
        self._not_empty = asyncio.Event()
        # Messages put but not yet acknowledged via task_done() or dropped.
        self._unfinished_tasks = 0
        # Set whenever _unfinished_tasks is zero; join() waits on it.
        self._finished = asyncio.Event()
        self._finished.set()

    def qsize(self) -> int:
        """Return the number of queued messages."""
        return len(self._heap)

    def empty(self) -> bool:
        """Return True if no messages are queued."""
        return len(self._heap) == 0

    def full(self) -> bool:
        """Return True if the queue holds at least ``maxsize`` messages."""
        return len(self._heap) >= self.maxsize

    async def put(self, msg_type: int, data: dict, priority: Optional[int] = None):
        """Enqueue a message. This never blocks, even when :meth:`full` is True."""
        async with self._lock:
            heapq.heappush(self._heap, PriorityMessage(msg_type, data, priority))
            self._unfinished_tasks += 1
            self._finished.clear()
            self._not_empty.set()

    async def get(self) -> Tuple[int, dict]:
        """Wait for and return ``(msg_type, data)`` of the highest-priority message."""
        while True:
            async with self._lock:
                if self._heap:
                    return self._pop()
                # Clear under the lock: a put() can only set the event after we
                # release, so no wake-up is lost. A stale "set" flag left behind
                # on an empty heap would otherwise make wait() return at once and
                # spin this loop forever without yielding.
                self._not_empty.clear()
            await self._not_empty.wait()

    def get_nowait(self) -> Tuple[int, dict]:
        """Return the highest-priority message or raise ``asyncio.QueueEmpty``."""
        if not self._heap:
            raise asyncio.QueueEmpty()
        return self._pop()

    def task_done(self):
        """Acknowledge that a message obtained from get()/get_nowait() is processed."""
        self._mark_done()

    async def join(self):
        """Wait until every enqueued message has been acknowledged or dropped."""
        await self._finished.wait()

    def drop_lowest_priority(self) -> bool:
        """Discard one lowest-priority message.

        Returns:
            True if a message was dropped, False if the queue was empty.
        """
        if not self._heap:
            return False

        # Priorities are stored negated, so the lowest priority has the largest
        # value. max() returns the first maximal index, i.e. the same element a
        # left-to-right scan with a strict ">" would pick.
        victim = max(range(len(self._heap)), key=lambda i: self._heap[i].priority)
        self._heap.pop(victim)
        heapq.heapify(self._heap)

        if not self._heap:
            self._not_empty.clear()
        # A dropped message will never reach task_done(), so finish it here.
        self._mark_done()
        return True

    def _pop(self) -> Tuple[int, dict]:
        """Pop the top message; the caller guarantees the heap is non-empty."""
        msg = heapq.heappop(self._heap)
        if not self._heap:
            self._not_empty.clear()
        return (msg.msg_type, msg.data)

    def _mark_done(self):
        """Decrement the unfinished counter and release join() when it reaches zero."""
        self._unfinished_tasks -= 1
        if self._unfinished_tasks == 0:
            self._finished.set()
|
||
|
||
|
||
# ==================== 自适应熔断器 ====================
|
||
|
||
class AdaptiveCircuitBreaker:
    """Circuit breaker whose recovery wait grows after every failed recovery.

    States:
      * closed    - requests pass; ``failure_threshold`` consecutive failures
                    switch to open.
      * open      - :meth:`is_open` returns True until ``current_recovery_time``
                    seconds have passed since the last failure, then the
                    breaker becomes half-open.
      * half_open - requests pass; ``success_threshold`` successes close the
                    breaker and reset the wait, while a single failure re-opens
                    it and multiplies the wait by ``recovery_multiplier``
                    (capped at ``max_recovery_time``).
    """

    STATE_CLOSED = "closed"  # normal operation
    STATE_OPEN = "open"  # tripped: requests are rejected
    STATE_HALF_OPEN = "half_open"  # probing whether the backend recovered

    def __init__(
        self,
        failure_threshold: int = 5,
        success_threshold: int = 3,
        initial_recovery_time: float = 5.0,
        max_recovery_time: float = 300.0,
        recovery_multiplier: float = 2.0
    ):
        """
        Args:
            failure_threshold: Consecutive failures (while closed) that open the breaker.
            success_threshold: Successes (while half-open) needed to close it again.
            initial_recovery_time: First wait, in seconds, before trying to recover.
            max_recovery_time: Upper bound for the recovery wait, in seconds.
            recovery_multiplier: Growth factor of the wait after a failed recovery.
        """
        # Configuration
        self.failure_threshold = failure_threshold
        self.success_threshold = success_threshold
        self.initial_recovery_time = initial_recovery_time
        self.max_recovery_time = max_recovery_time
        self.recovery_multiplier = recovery_multiplier

        # Live state
        self.state = self.STATE_CLOSED
        self.failure_count = 0
        self.success_count = 0
        self.last_failure_time = 0
        self.current_recovery_time = initial_recovery_time

        # Lifetime counters
        self.total_failures = 0
        self.total_successes = 0
        self.total_rejections = 0

    def is_open(self) -> bool:
        """Return True if a request should be rejected right now.

        While open, this also performs the open -> half-open transition once
        the recovery wait has elapsed.
        """
        if self.state == self.STATE_OPEN:
            waited = time.time() - self.last_failure_time
            if waited < self.current_recovery_time:
                return True
            self.state = self.STATE_HALF_OPEN
            self.success_count = 0
            logger.info(f"熔断器进入半开状态,尝试恢复(等待了 {waited:.1f}s)")
        # Closed and half-open both let the request through.
        return False

    def record_success(self):
        """Register a successful request."""
        self.total_successes += 1

        if self.state != self.STATE_HALF_OPEN:
            # Any success breaks the current failure streak.
            self.failure_count = 0
            return

        self.success_count += 1
        if self.success_count < self.success_threshold:
            return

        # Enough probes succeeded: back to normal with the initial wait.
        self.state = self.STATE_CLOSED
        self.failure_count = 0
        self.success_count = 0
        self.current_recovery_time = self.initial_recovery_time
        logger.success(f"熔断器已恢复正常(连续成功 {self.success_threshold} 次)")

    def record_failure(self):
        """Register a failed request, possibly opening the breaker."""
        self.total_failures += 1
        self.failure_count += 1
        self.last_failure_time = time.time()

        if self.state == self.STATE_HALF_OPEN:
            # The recovery probe failed: re-open and back off further.
            self.state = self.STATE_OPEN
            self.success_count = 0
            grown = self.current_recovery_time * self.recovery_multiplier
            self.current_recovery_time = min(grown, self.max_recovery_time)
            logger.warning(f"熔断器重新开启,下次恢复等待 {self.current_recovery_time:.1f}s")
        elif self.state == self.STATE_CLOSED and self.failure_count >= self.failure_threshold:
            self.state = self.STATE_OPEN
            logger.warning(f"熔断器开启,连续失败 {self.failure_count} 次")

    def record_rejection(self):
        """Count a request that was rejected because the breaker was open."""
        self.total_rejections += 1

    def get_stats(self) -> dict:
        """Return a snapshot of the current state and lifetime counters."""
        return dict(
            state=self.state,
            failure_count=self.failure_count,
            success_count=self.success_count,
            current_recovery_time=self.current_recovery_time,
            total_failures=self.total_failures,
            total_successes=self.total_successes,
            total_rejections=self.total_rejections,
        )
|
||
|
||
|
||
# ==================== 请求重试机制 ====================
|
||
|
||
class RetryConfig:
    """Parameters for retrying with exponential back-off.

    Args:
        max_retries: Number of retries allowed after the first attempt.
        initial_delay: Delay in seconds before the first retry.
        max_delay: Upper bound in seconds for any single delay.
        exponential_base: Factor by which the delay grows per retry.
        retryable_exceptions: Exception types that trigger a retry; any other
            exception propagates immediately.
    """

    def __init__(
        self,
        max_retries: int = 3,
        initial_delay: float = 1.0,
        max_delay: float = 30.0,
        exponential_base: float = 2.0,
        retryable_exceptions: tuple = (
            aiohttp.ClientError,
            asyncio.TimeoutError,
            ConnectionError,
        )
    ):
        (
            self.max_retries,
            self.initial_delay,
            self.max_delay,
            self.exponential_base,
            self.retryable_exceptions,
        ) = (max_retries, initial_delay, max_delay, exponential_base, retryable_exceptions)
|
||
|
||
|
||
def retry_async(config: Optional["RetryConfig"] = None):
    """Decorator factory that retries an async function with exponential back-off.

    Usage::

        @retry_async()                   # default RetryConfig
        @retry_async(RetryConfig(...))   # custom policy
        @retry_async                     # bare form, same as retry_async()

    Only exceptions in ``config.retryable_exceptions`` are retried. The wait
    before retry ``k`` (0-based) is
    ``min(initial_delay * exponential_base ** k, max_delay)`` seconds. After
    ``max_retries`` retries the last exception is re-raised. A negative
    ``max_retries`` is treated as 0: one attempt, no retries.
    """
    if callable(config):
        # Bare ``@retry_async``: the decorated function arrived in place of the
        # config, so apply the default policy to it.
        return retry_async()(config)

    if config is None:
        config = RetryConfig()

    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            # Clamp so a negative setting still makes one attempt instead of
            # skipping the loop and ending in ``raise None``.
            max_retries = max(0, config.max_retries)

            for attempt in range(max_retries + 1):
                try:
                    return await func(*args, **kwargs)
                except config.retryable_exceptions as e:
                    if attempt == max_retries:
                        logger.error(f"重试 {max_retries} 次后仍然失败: {func.__name__}")
                        raise

                    # Exponential back-off, capped at max_delay.
                    delay = min(
                        config.initial_delay * (config.exponential_base ** attempt),
                        config.max_delay
                    )

                    logger.warning(
                        f"请求失败,{delay:.1f}s 后重试 "
                        f"(第 {attempt + 1}/{max_retries} 次): {e}"
                    )
                    await asyncio.sleep(delay)
            # Not reached: the final attempt either returns or re-raises.

        return wrapper

    return decorator
|
||
|
||
|
||
async def request_with_retry(
    session: aiohttp.ClientSession,
    method: str,
    url: str,
    max_retries: int = 3,
    **kwargs
) -> aiohttp.ClientResponse:
    """Run ``session.request(method, url, **kwargs)``, retrying transient errors.

    Back-off parameters and retryable exception types come from a default
    ``RetryConfig`` built with the given ``max_retries``. The response is
    returned as-is; releasing it is the caller's responsibility.
    """
    policy = RetryConfig(max_retries=max_retries)
    error = None

    for attempt in range(policy.max_retries + 1):
        try:
            return await session.request(method, url, **kwargs)
        except policy.retryable_exceptions as exc:
            error = exc
            if attempt == policy.max_retries:
                raise

            backoff = min(
                policy.initial_delay * (policy.exponential_base ** attempt),
                policy.max_delay
            )
            logger.warning(f"HTTP 请求失败,{backoff:.1f}s 后重试: {exc}")
            await asyncio.sleep(backoff)

    raise error
|
||
|
||
|
||
# ==================== 配置热更新 ====================
|
||
|
||
class ConfigWatcher:
|
||
"""配置文件监听器"""
|
||
|
||
def __init__(self, config_path: str, check_interval: float = 5.0):
|
||
"""
|
||
初始化配置监听器
|
||
|
||
Args:
|
||
config_path: 配置文件路径
|
||
check_interval: 检查间隔(秒)
|
||
"""
|
||
self.config_path = Path(config_path)
|
||
self.check_interval = check_interval
|
||
self.last_mtime = 0
|
||
self.callbacks: List[Callable[[dict], Any]] = []
|
||
self.current_config: dict = {}
|
||
self._running = False
|
||
self._task: Optional[asyncio.Task] = None
|
||
|
||
def register_callback(self, callback: Callable[[dict], Any]):
|
||
"""注册配置更新回调"""
|
||
self.callbacks.append(callback)
|
||
|
||
def unregister_callback(self, callback: Callable[[dict], Any]):
|
||
"""取消注册回调"""
|
||
if callback in self.callbacks:
|
||
self.callbacks.remove(callback)
|
||
|
||
def _load_config(self) -> dict:
|
||
"""加载配置文件"""
|
||
try:
|
||
with open(self.config_path, "rb") as f:
|
||
return tomllib.load(f)
|
||
except Exception as e:
|
||
logger.error(f"加载配置文件失败: {e}")
|
||
return {}
|
||
|
||
def get_config(self) -> dict:
|
||
"""获取当前配置"""
|
||
return self.current_config
|
||
|
||
async def start(self):
|
||
"""启动配置监听"""
|
||
if self._running:
|
||
return
|
||
|
||
self._running = True
|
||
|
||
# 初始加载
|
||
if self.config_path.exists():
|
||
self.last_mtime = os.path.getmtime(self.config_path)
|
||
self.current_config = self._load_config()
|
||
|
||
self._task = asyncio.create_task(self._watch_loop())
|
||
logger.info(f"配置监听器已启动: {self.config_path}")
|
||
|
||
async def stop(self):
|
||
"""停止配置监听"""
|
||
self._running = False
|
||
if self._task:
|
||
self._task.cancel()
|
||
try:
|
||
await self._task
|
||
except asyncio.CancelledError:
|
||
pass
|
||
logger.info("配置监听器已停止")
|
||
|
||
async def _watch_loop(self):
|
||
"""监听循环"""
|
||
while self._running:
|
||
try:
|
||
await asyncio.sleep(self.check_interval)
|
||
|
||
if not self.config_path.exists():
|
||
continue
|
||
|
||
mtime = os.path.getmtime(self.config_path)
|
||
if mtime > self.last_mtime:
|
||
logger.info("检测到配置文件变化,重新加载...")
|
||
new_config = self._load_config()
|
||
|
||
if new_config:
|
||
old_config = self.current_config
|
||
self.current_config = new_config
|
||
self.last_mtime = mtime
|
||
|
||
# 通知所有回调
|
||
for callback in self.callbacks:
|
||
try:
|
||
if asyncio.iscoroutinefunction(callback):
|
||
await callback(new_config)
|
||
else:
|
||
callback(new_config)
|
||
except Exception as e:
|
||
logger.error(f"配置更新回调执行失败: {e}")
|
||
|
||
logger.success("配置已热更新")
|
||
|
||
except asyncio.CancelledError:
|
||
break
|
||
except Exception as e:
|
||
logger.error(f"配置监听异常: {e}")
|
||
|
||
|
||
# ==================== 性能监控 ====================
|
||
|
||
class PerformanceMonitor:
    """In-memory counters and timing statistics for the bot.

    The ``record_*`` / ``update_*`` methods only update counters.
    :meth:`get_stats` builds a report dict from them and :meth:`print_stats`
    logs that report. Durations are expected in seconds; reports multiply by
    1000 to show milliseconds.
    """

    def __init__(self):
        self.start_time = time.time()

        # Message counters
        self.message_received = 0
        self.message_processed = 0
        self.message_failed = 0
        self.message_dropped = 0

        # Processing durations; only the newest max_processing_times are kept.
        self.processing_times: List[float] = []
        self.max_processing_times = 1000

        # Per-plugin aggregates keyed by plugin name.
        self.plugin_stats: Dict[str, dict] = {}

        # (timestamp, size) samples; only the newest max_queue_history are kept.
        self.queue_size_history: List[Tuple[float, int]] = []
        self.max_queue_history = 100

        # Last dict passed to update_circuit_breaker_stats().
        self.circuit_breaker_stats: dict = {}

    def record_message_received(self):
        """Count one received message."""
        self.message_received += 1

    def record_message_processed(self, processing_time: float):
        """Count one processed message and remember its duration."""
        self.message_processed += 1
        self.processing_times.append(processing_time)

        # Bound the history.
        if len(self.processing_times) > self.max_processing_times:
            self.processing_times = self.processing_times[-self.max_processing_times:]

    def record_message_failed(self):
        """Count one message whose processing failed."""
        self.message_failed += 1

    def record_message_dropped(self):
        """Count one message that was dropped."""
        self.message_dropped += 1

    def record_queue_size(self, size: int):
        """Record a timestamped queue-size sample."""
        self.queue_size_history.append((time.time(), size))

        if len(self.queue_size_history) > self.max_queue_history:
            self.queue_size_history = self.queue_size_history[-self.max_queue_history:]

    def record_plugin_execution(self, plugin_name: str, execution_time: float, success: bool):
        """Record one plugin invocation with its duration and outcome."""
        if plugin_name not in self.plugin_stats:
            self.plugin_stats[plugin_name] = {
                "total_calls": 0,
                "success_calls": 0,
                "failed_calls": 0,
                "total_time": 0,
                "max_time": 0,
                "recent_times": []
            }

        stats = self.plugin_stats[plugin_name]
        stats["total_calls"] += 1
        stats["total_time"] += execution_time
        stats["max_time"] = max(stats["max_time"], execution_time)
        stats["recent_times"].append(execution_time)

        # Keep only the newest 100 samples per plugin.
        if len(stats["recent_times"]) > 100:
            stats["recent_times"] = stats["recent_times"][-100:]

        if success:
            stats["success_calls"] += 1
        else:
            stats["failed_calls"] += 1

    def update_circuit_breaker_stats(self, stats: dict):
        """Store the latest circuit-breaker snapshot for reporting."""
        self.circuit_breaker_stats = stats

    @staticmethod
    def _to_ms(seconds: float) -> str:
        """Format a duration in seconds as milliseconds with one decimal."""
        return f"{seconds * 1000:.1f}"

    def get_stats(self) -> dict:
        """Build the full statistics report."""
        uptime = time.time() - self.start_time
        times = self.processing_times
        history = self.queue_size_history

        avg_processing_time = sum(times) / len(times) if times else 0
        processing_rate = self.message_processed / uptime if uptime > 0 else 0

        # Success rate over messages that finished, successfully or not.
        total = self.message_processed + self.message_failed
        success_rate = self.message_processed / total if total > 0 else 1.0

        return {
            "uptime_seconds": uptime,
            "uptime_formatted": self._format_uptime(uptime),
            "messages": {
                "received": self.message_received,
                "processed": self.message_processed,
                "failed": self.message_failed,
                "dropped": self.message_dropped,
                "success_rate": f"{success_rate * 100:.1f}%",
                "processing_rate": f"{processing_rate:.2f}/s"
            },
            # Use the same "0.0" format for all three fields when there is no data.
            "processing_time": {
                "average_ms": self._to_ms(avg_processing_time),
                "max_ms": self._to_ms(max(times) if times else 0),
                "min_ms": self._to_ms(min(times) if times else 0)
            },
            "queue": {
                "current_size": history[-1][1] if history else 0,
                "max_size": max(size for _, size in history) if history else 0
            },
            "circuit_breaker": self.circuit_breaker_stats,
            "plugins": self._get_plugin_summary()
        }

    def _get_plugin_summary(self) -> List[dict]:
        """Return per-plugin summaries, slowest average first."""
        rows = []
        for name, stats in self.plugin_stats.items():
            calls = stats["total_calls"]
            avg_time = stats["total_time"] / calls if calls > 0 else 0
            rows.append((avg_time, {
                "name": name,
                "calls": calls,
                "success_rate": f"{stats['success_calls'] / calls * 100:.1f}%" if calls > 0 else "N/A",
                "avg_time_ms": self._to_ms(avg_time),
                "max_time_ms": self._to_ms(stats["max_time"])
            }))

        # Sort on the unrounded average. Sorting on the formatted string would
        # treat e.g. 0.01 ms and 0.04 ms as equal ("0.0"). The sort is stable,
        # so exact ties keep insertion order.
        rows.sort(key=lambda row: row[0], reverse=True)
        return [entry for _, entry in rows]

    def _format_uptime(self, seconds: float) -> str:
        """Format a duration in seconds as a short human-readable string."""
        days = int(seconds // 86400)
        hours = int((seconds % 86400) // 3600)
        minutes = int((seconds % 3600) // 60)
        secs = int(seconds % 60)

        if days > 0:
            return f"{days}天 {hours}小时 {minutes}分钟"
        elif hours > 0:
            return f"{hours}小时 {minutes}分钟"
        elif minutes > 0:
            return f"{minutes}分钟 {secs}秒"
        else:
            return f"{secs}秒"

    def print_stats(self):
        """Write the statistics report to the log."""
        stats = self.get_stats()

        logger.info("=" * 50)
        logger.info("性能监控报告")
        logger.info("=" * 50)
        logger.info(f"运行时间: {stats['uptime_formatted']}")
        logger.info(f"消息统计: 收到 {stats['messages']['received']}, "
                    f"处理 {stats['messages']['processed']}, "
                    f"失败 {stats['messages']['failed']}, "
                    f"丢弃 {stats['messages']['dropped']}")
        logger.info(f"成功率: {stats['messages']['success_rate']}, "
                    f"处理速率: {stats['messages']['processing_rate']}")
        logger.info(f"平均处理时间: {stats['processing_time']['average_ms']}ms")
        logger.info(f"队列大小: {stats['queue']['current_size']}")
        logger.info(f"熔断器状态: {stats['circuit_breaker'].get('state', 'N/A')}")

        if stats['plugins']:
            logger.info("插件耗时排行:")
            for i, p in enumerate(stats['plugins'][:5], 1):
                logger.info(f" {i}. {p['name']}: {p['avg_time_ms']}ms (调用 {p['calls']} 次)")

        logger.info("=" * 50)
|
||
|
||
|
||
# ==================== 全局实例 ====================
|
||
|
||
# 性能监控器单例
|
||
_performance_monitor: Optional[PerformanceMonitor] = None


def get_performance_monitor() -> PerformanceMonitor:
    """Return the module-wide PerformanceMonitor, creating it on first call."""
    global _performance_monitor
    monitor = _performance_monitor
    if monitor is None:
        monitor = _performance_monitor = PerformanceMonitor()
    return monitor
|