feat: 优化整体项目
This commit is contained in:
658
utils/bot_utils.py
Normal file
658
utils/bot_utils.py
Normal file
@@ -0,0 +1,658 @@
|
||||
"""
|
||||
机器人核心工具模块
|
||||
|
||||
包含:
|
||||
- 优先级消息队列
|
||||
- 自适应熔断器
|
||||
- 请求重试机制
|
||||
- 配置热更新
|
||||
- 性能监控
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
import heapq
|
||||
import os
|
||||
import tomllib
|
||||
import functools
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Callable, Any, Tuple
|
||||
from dataclasses import dataclass, field
|
||||
from loguru import logger
|
||||
import aiohttp
|
||||
|
||||
|
||||
# ==================== 优先级消息队列 ====================
|
||||
|
||||
# 消息优先级定义
|
||||
class MessagePriority:
    """Priority levels for inbound messages (larger value = more urgent)."""

    CRITICAL = 100  # system messages, login info
    HIGH = 80       # admin commands
    NORMAL = 50     # messages that @-mention the bot
    LOW = 20        # ordinary group chatter
|
||||
|
||||
# Message types that receive an elevated default priority.  Any type not
# listed here falls back to MessagePriority.NORMAL (see PriorityMessage).
PRIORITY_MESSAGE_TYPES = {
    11025: MessagePriority.CRITICAL,  # login information
    11058: MessagePriority.CRITICAL,  # system message
    11098: MessagePriority.HIGH,      # group member joined
    11099: MessagePriority.HIGH,      # group member left
    11100: MessagePriority.HIGH,      # group info changed
    11056: MessagePriority.HIGH,      # friend request
}
|
||||
|
||||
|
||||
@dataclass(order=True)
class PriorityMessage:
    """Priority wrapper for queued messages.

    heapq is a min-heap, so the logical priority is stored negated:
    a higher logical priority becomes a smaller stored value and is
    popped first.  ``timestamp`` participates in ordering as a
    tie-breaker so that equal-priority messages are delivered FIFO
    (the original excluded it, making same-priority order arbitrary).
    """
    priority: int
    timestamp: float = field(compare=True)  # secondary sort key: FIFO within a priority
    msg_type: int = field(compare=False)
    data: dict = field(compare=False)

    def __init__(self, msg_type: int, data: dict, priority: int = None):
        # Fix: test against None instead of truthiness — the original
        # `priority or ...` silently discarded an explicit priority of 0.
        if priority is None:
            priority = PRIORITY_MESSAGE_TYPES.get(msg_type, MessagePriority.NORMAL)
        self.priority = -priority
        self.timestamp = time.time()
        self.msg_type = msg_type
        self.data = data
|
||||
|
||||
|
||||
class PriorityMessageQueue:
    """asyncio priority queue for (msg_type, data) pairs.

    Highest-priority message is returned first.  ``maxsize`` is advisory:
    put() never blocks or rejects; callers are expected to check full()
    and shed load via drop_lowest_priority().
    """

    def __init__(self, maxsize: int = 1000):
        self.maxsize = maxsize
        self._heap: List[PriorityMessage] = []
        self._lock = asyncio.Lock()
        # Set while the heap holds at least one message; get() waits on it.
        self._not_empty = asyncio.Event()
        self._unfinished_tasks = 0
        # Set while no tasks are outstanding; join() waits on it.
        self._finished = asyncio.Event()
        self._finished.set()

    def qsize(self) -> int:
        """Number of queued messages."""
        return len(self._heap)

    def empty(self) -> bool:
        """True when the queue holds no messages."""
        return len(self._heap) == 0

    def full(self) -> bool:
        """True when qsize() has reached maxsize (advisory only)."""
        return len(self._heap) >= self.maxsize

    async def put(self, msg_type: int, data: dict, priority: int = None):
        """Enqueue a message; priority defaults from the message type table."""
        async with self._lock:
            msg = PriorityMessage(msg_type, data, priority)
            heapq.heappush(self._heap, msg)
            self._unfinished_tasks += 1
            self._finished.clear()
            self._not_empty.set()

    async def get(self) -> Tuple[int, dict]:
        """Wait for and return the highest-priority (msg_type, data)."""
        while True:
            async with self._lock:
                if self._heap:
                    msg = heapq.heappop(self._heap)
                    if not self._heap:
                        self._not_empty.clear()
                    return (msg.msg_type, msg.data)

            # Heap was empty: wait until a producer signals a new message.
            await self._not_empty.wait()

    def get_nowait(self) -> Tuple[int, dict]:
        """Return the highest-priority message or raise asyncio.QueueEmpty."""
        if not self._heap:
            raise asyncio.QueueEmpty()
        msg = heapq.heappop(self._heap)
        if not self._heap:
            self._not_empty.clear()
        return (msg.msg_type, msg.data)

    def task_done(self):
        """Mark one previously-gotten message as fully processed."""
        self._unfinished_tasks -= 1
        if self._unfinished_tasks == 0:
            self._finished.set()

    async def join(self):
        """Block until every enqueued message has been task_done()'d."""
        await self._finished.wait()

    def drop_lowest_priority(self) -> bool:
        """Drop the lowest-priority message to shed load.

        Returns:
            True if a message was dropped, False if the queue was empty.
        """
        if not self._heap:
            return False

        # Stored priorities are negated, so the LARGEST stored value is
        # the logically lowest priority.
        drop_idx = 0
        for i, msg in enumerate(self._heap):
            if msg.priority > self._heap[drop_idx].priority:
                drop_idx = i

        self._heap.pop(drop_idx)
        heapq.heapify(self._heap)

        # Fix: keep the events consistent with the new state.  Previously a
        # drop could leave _not_empty set on an empty heap (get() then spins
        # in a busy loop) and never re-set _finished when the outstanding
        # task count reached 0 (join() would hang forever).
        if not self._heap:
            self._not_empty.clear()
        self._unfinished_tasks -= 1
        if self._unfinished_tasks <= 0:
            self._finished.set()
        return True
|
||||
|
||||
|
||||
# ==================== 自适应熔断器 ====================
|
||||
|
||||
class AdaptiveCircuitBreaker:
    """Circuit breaker whose recovery window grows on repeated failures."""

    # breaker states
    STATE_CLOSED = "closed"        # normal operation, traffic flows
    STATE_OPEN = "open"            # tripped, traffic rejected
    STATE_HALF_OPEN = "half_open"  # probing whether the backend recovered

    def __init__(
        self,
        failure_threshold: int = 5,
        success_threshold: int = 3,
        initial_recovery_time: float = 5.0,
        max_recovery_time: float = 300.0,
        recovery_multiplier: float = 2.0
    ):
        """Configure the breaker.

        Args:
            failure_threshold: consecutive failures that trip the breaker.
            success_threshold: consecutive successes that close it again.
            initial_recovery_time: first wait before probing, in seconds.
            max_recovery_time: upper bound on the wait, in seconds.
            recovery_multiplier: growth factor applied on each failed probe.
        """
        # tunables
        self.failure_threshold = failure_threshold
        self.success_threshold = success_threshold
        self.initial_recovery_time = initial_recovery_time
        self.max_recovery_time = max_recovery_time
        self.recovery_multiplier = recovery_multiplier

        # live state
        self.state = self.STATE_CLOSED
        self.failure_count = 0
        self.success_count = 0
        self.last_failure_time = 0
        self.current_recovery_time = initial_recovery_time

        # lifetime counters
        self.total_failures = 0
        self.total_successes = 0
        self.total_rejections = 0

    def is_open(self) -> bool:
        """Return True when the current request should be rejected."""
        if self.state != self.STATE_OPEN:
            # closed and half-open both admit traffic
            return False

        elapsed = time.time() - self.last_failure_time
        if elapsed < self.current_recovery_time:
            return True

        # waited long enough: switch to half-open and let a probe through
        self.state = self.STATE_HALF_OPEN
        self.success_count = 0
        logger.info(f"熔断器进入半开状态,尝试恢复(等待了 {elapsed:.1f}s)")
        return False

    def record_success(self):
        """Record a successful request, possibly closing the breaker."""
        self.total_successes += 1

        if self.state != self.STATE_HALF_OPEN:
            # outside the probing phase a success simply clears the streak
            self.failure_count = 0
            return

        self.success_count += 1
        if self.success_count < self.success_threshold:
            return

        # enough consecutive probe successes: fully close and reset backoff
        self.state = self.STATE_CLOSED
        self.failure_count = 0
        self.success_count = 0
        self.current_recovery_time = self.initial_recovery_time
        logger.success(f"熔断器已恢复正常(连续成功 {self.success_threshold} 次)")

    def record_failure(self):
        """Record a failed request, possibly (re)opening the breaker."""
        self.total_failures += 1
        self.failure_count += 1
        self.last_failure_time = time.time()

        if self.state == self.STATE_HALF_OPEN:
            # probe failed: re-open and back off harder
            self.state = self.STATE_OPEN
            self.success_count = 0
            self.current_recovery_time = min(
                self.current_recovery_time * self.recovery_multiplier,
                self.max_recovery_time
            )
            logger.warning(f"熔断器重新开启,下次恢复等待 {self.current_recovery_time:.1f}s")
        elif self.state == self.STATE_CLOSED and self.failure_count >= self.failure_threshold:
            self.state = self.STATE_OPEN
            logger.warning(f"熔断器开启,连续失败 {self.failure_count} 次")

    def record_rejection(self):
        """Count a request that was turned away while the breaker was open."""
        self.total_rejections += 1

    def get_stats(self) -> dict:
        """Return a snapshot of breaker state and lifetime counters."""
        return {
            "state": self.state,
            "failure_count": self.failure_count,
            "success_count": self.success_count,
            "current_recovery_time": self.current_recovery_time,
            "total_failures": self.total_failures,
            "total_successes": self.total_successes,
            "total_rejections": self.total_rejections
        }
|
||||
|
||||
|
||||
# ==================== 请求重试机制 ====================
|
||||
|
||||
class RetryConfig:
    """Retry policy: attempt count, exponential-backoff bounds, retryable errors."""

    def __init__(
        self,
        max_retries: int = 3,
        initial_delay: float = 1.0,
        max_delay: float = 30.0,
        exponential_base: float = 2.0,
        retryable_exceptions: tuple = (
            aiohttp.ClientError,
            asyncio.TimeoutError,
            ConnectionError,
        )
    ):
        # number of retries performed after the initial attempt
        self.max_retries = max_retries
        # backoff schedule: initial_delay * exponential_base ** attempt, capped at max_delay
        self.initial_delay = initial_delay
        self.max_delay = max_delay
        self.exponential_base = exponential_base
        # exception classes treated as transient and therefore retried
        self.retryable_exceptions = retryable_exceptions
|
||||
|
||||
|
||||
def retry_async(config: RetryConfig = None):
    """Decorator adding retry-with-exponential-backoff to an async callable.

    Only exceptions listed in ``config.retryable_exceptions`` trigger a
    retry; after the final attempt the last exception propagates.
    """
    if config is None:
        config = RetryConfig()

    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            last_exception = None
            attempt = 0

            while attempt <= config.max_retries:
                try:
                    return await func(*args, **kwargs)
                except config.retryable_exceptions as e:
                    last_exception = e

                    if attempt == config.max_retries:
                        logger.error(f"重试 {config.max_retries} 次后仍然失败: {func.__name__}")
                        raise

                    # exponential backoff, capped at max_delay
                    delay = min(
                        config.initial_delay * (config.exponential_base ** attempt),
                        config.max_delay
                    )

                    logger.warning(
                        f"请求失败,{delay:.1f}s 后重试 "
                        f"(第 {attempt + 1}/{config.max_retries} 次): {e}"
                    )
                    await asyncio.sleep(delay)
                attempt += 1

            # defensive: unreachable in practice (the last attempt re-raises)
            raise last_exception
        return wrapper
    return decorator
|
||||
|
||||
|
||||
async def request_with_retry(
    session: aiohttp.ClientSession,
    method: str,
    url: str,
    max_retries: int = 3,
    **kwargs
) -> aiohttp.ClientResponse:
    """Issue one HTTP request, retrying transient failures with exponential backoff."""
    config = RetryConfig(max_retries=max_retries)
    last_exception = None
    attempt = 0

    while attempt <= config.max_retries:
        try:
            return await session.request(method, url, **kwargs)
        except config.retryable_exceptions as e:
            last_exception = e

            if attempt == config.max_retries:
                raise

            # exponential backoff, capped at max_delay
            delay = min(
                config.initial_delay * (config.exponential_base ** attempt),
                config.max_delay
            )

            logger.warning(f"HTTP 请求失败,{delay:.1f}s 后重试: {e}")
            await asyncio.sleep(delay)
        attempt += 1

    # defensive: unreachable in practice (the last attempt re-raises)
    raise last_exception
|
||||
|
||||
|
||||
# ==================== 配置热更新 ====================
|
||||
|
||||
class ConfigWatcher:
    """Polls a TOML config file and notifies callbacks when it changes."""

    def __init__(self, config_path: str, check_interval: float = 5.0):
        """
        Initialise the watcher.  Polling does not begin until start().

        Args:
            config_path: Path to the TOML configuration file.
            check_interval: Polling interval in seconds.
        """
        self.config_path = Path(config_path)
        self.check_interval = check_interval
        # mtime of the file at the last successful load; 0 = never loaded
        self.last_mtime = 0
        # subscribers invoked with the new config dict after a reload
        self.callbacks: List[Callable[[dict], Any]] = []
        self.current_config: dict = {}
        self._running = False
        self._task: Optional[asyncio.Task] = None

    def register_callback(self, callback: Callable[[dict], Any]):
        """Register a callback (sync or async) called with the new config dict."""
        self.callbacks.append(callback)

    def unregister_callback(self, callback: Callable[[dict], Any]):
        """Remove a previously registered callback (no-op if absent)."""
        if callback in self.callbacks:
            self.callbacks.remove(callback)

    def _load_config(self) -> dict:
        """Parse the TOML file; return {} on any parse or I/O error."""
        try:
            with open(self.config_path, "rb") as f:
                return tomllib.load(f)
        except Exception as e:
            logger.error(f"加载配置文件失败: {e}")
            return {}

    def get_config(self) -> dict:
        """Return the most recently loaded configuration dict."""
        return self.current_config

    async def start(self):
        """Do an initial load (if the file exists) and start polling. Idempotent."""
        if self._running:
            return

        self._running = True

        # initial load so get_config() is populated before the first poll
        if self.config_path.exists():
            self.last_mtime = os.path.getmtime(self.config_path)
            self.current_config = self._load_config()

        self._task = asyncio.create_task(self._watch_loop())
        logger.info(f"配置监听器已启动: {self.config_path}")

    async def stop(self):
        """Cancel the polling task and wait for it to finish."""
        self._running = False
        if self._task:
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                pass
        logger.info("配置监听器已停止")

    async def _watch_loop(self):
        """Poll the file's mtime; on change, reload and notify subscribers."""
        while self._running:
            try:
                await asyncio.sleep(self.check_interval)

                if not self.config_path.exists():
                    continue

                mtime = os.path.getmtime(self.config_path)
                if mtime > self.last_mtime:
                    logger.info("检测到配置文件变化,重新加载...")
                    new_config = self._load_config()

                    # {} means the reload failed; keep the previous config
                    if new_config:
                        # NOTE(review): old_config is assigned but never used —
                        # candidate for removal or for passing to callbacks.
                        old_config = self.current_config
                        self.current_config = new_config
                        self.last_mtime = mtime

                        # notify every subscriber; one failing callback must
                        # not prevent the others from running
                        for callback in self.callbacks:
                            try:
                                if asyncio.iscoroutinefunction(callback):
                                    await callback(new_config)
                                else:
                                    callback(new_config)
                            except Exception as e:
                                logger.error(f"配置更新回调执行失败: {e}")

                        logger.success("配置已热更新")

            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"配置监听异常: {e}")
|
||||
|
||||
|
||||
# ==================== 性能监控 ====================
|
||||
|
||||
class PerformanceMonitor:
    """Collects runtime metrics: message throughput, latency, queue depth and plugin timings."""

    def __init__(self):
        self.start_time = time.time()

        # message counters
        self.message_received = 0
        self.message_processed = 0
        self.message_failed = 0
        self.message_dropped = 0

        # recent per-message processing latencies, in seconds (bounded)
        self.processing_times: List[float] = []
        self.max_processing_times = 1000  # keep only the newest 1000 samples

        # per-plugin execution statistics, keyed by plugin name
        self.plugin_stats: Dict[str, dict] = {}

        # (timestamp, size) samples of the inbound queue (bounded)
        self.queue_size_history: List[Tuple[float, int]] = []
        self.max_queue_history = 100

        # last snapshot pushed from the circuit breaker
        self.circuit_breaker_stats: dict = {}

    def record_message_received(self):
        """Count one inbound message."""
        self.message_received += 1

    def record_message_processed(self, processing_time: float):
        """Count one processed message and record its latency (seconds)."""
        self.message_processed += 1
        self.processing_times.append(processing_time)

        # keep only the newest samples
        overflow = len(self.processing_times) - self.max_processing_times
        if overflow > 0:
            self.processing_times = self.processing_times[overflow:]

    def record_message_failed(self):
        """Count one message whose processing raised."""
        self.message_failed += 1

    def record_message_dropped(self):
        """Count one message shed before processing."""
        self.message_dropped += 1

    def record_queue_size(self, size: int):
        """Sample the current queue depth."""
        self.queue_size_history.append((time.time(), size))

        overflow = len(self.queue_size_history) - self.max_queue_history
        if overflow > 0:
            self.queue_size_history = self.queue_size_history[overflow:]

    def record_plugin_execution(self, plugin_name: str, execution_time: float, success: bool):
        """Record one plugin invocation: duration (seconds) and outcome."""
        entry = self.plugin_stats.setdefault(plugin_name, {
            "total_calls": 0,
            "success_calls": 0,
            "failed_calls": 0,
            "total_time": 0,
            "max_time": 0,
            "recent_times": []
        })

        entry["total_calls"] += 1
        entry["total_time"] += execution_time
        entry["max_time"] = max(entry["max_time"], execution_time)
        entry["recent_times"].append(execution_time)

        if len(entry["recent_times"]) > 100:
            entry["recent_times"] = entry["recent_times"][-100:]

        entry["success_calls" if success else "failed_calls"] += 1

    def update_circuit_breaker_stats(self, stats: dict):
        """Store the latest circuit-breaker snapshot for reporting."""
        self.circuit_breaker_stats = stats

    def get_stats(self) -> dict:
        """Return the full metrics snapshot as a nested dict."""
        elapsed = time.time() - self.start_time
        samples = self.processing_times

        avg_processing_time = sum(samples) / len(samples) if samples else 0
        processing_rate = self.message_processed / elapsed if elapsed > 0 else 0

        # success rate over messages that finished (processed or failed)
        finished = self.message_processed + self.message_failed
        success_rate = self.message_processed / finished if finished > 0 else 1.0

        return {
            "uptime_seconds": elapsed,
            "uptime_formatted": self._format_uptime(elapsed),
            "messages": {
                "received": self.message_received,
                "processed": self.message_processed,
                "failed": self.message_failed,
                "dropped": self.message_dropped,
                "success_rate": f"{success_rate * 100:.1f}%",
                "processing_rate": f"{processing_rate:.2f}/s"
            },
            "processing_time": {
                "average_ms": f"{avg_processing_time * 1000:.1f}",
                "max_ms": f"{max(samples) * 1000:.1f}" if samples else "0",
                "min_ms": f"{min(samples) * 1000:.1f}" if samples else "0"
            },
            "queue": {
                "current_size": self.queue_size_history[-1][1] if self.queue_size_history else 0,
                "max_size": max(s[1] for s in self.queue_size_history) if self.queue_size_history else 0
            },
            "circuit_breaker": self.circuit_breaker_stats,
            "plugins": self._get_plugin_summary()
        }

    def _get_plugin_summary(self) -> List[dict]:
        """Summarise plugin stats, sorted by average execution time (desc)."""
        summary = []
        for name, st in self.plugin_stats.items():
            calls = st["total_calls"]
            avg_time = st["total_time"] / calls if calls > 0 else 0
            summary.append({
                "name": name,
                "calls": calls,
                "success_rate": f"{st['success_calls'] / calls * 100:.1f}%" if calls > 0 else "N/A",
                "avg_time_ms": f"{avg_time * 1000:.1f}",
                "max_time_ms": f"{st['max_time'] * 1000:.1f}"
            })

        summary.sort(key=lambda item: float(item["avg_time_ms"]), reverse=True)
        return summary

    def _format_uptime(self, seconds: float) -> str:
        """Render an uptime in seconds as a short human-readable string."""
        days, rem = divmod(int(seconds), 86400)
        hours, rem = divmod(rem, 3600)
        minutes, secs = divmod(rem, 60)

        if days > 0:
            return f"{days}天 {hours}小时 {minutes}分钟"
        if hours > 0:
            return f"{hours}小时 {minutes}分钟"
        if minutes > 0:
            return f"{minutes}分钟 {secs}秒"
        return f"{secs}秒"

    def print_stats(self):
        """Dump a human-readable metrics report to the log."""
        stats = self.get_stats()
        sep = "=" * 50

        logger.info(sep)
        logger.info("性能监控报告")
        logger.info(sep)
        logger.info(f"运行时间: {stats['uptime_formatted']}")
        logger.info(f"消息统计: 收到 {stats['messages']['received']}, "
                    f"处理 {stats['messages']['processed']}, "
                    f"失败 {stats['messages']['failed']}, "
                    f"丢弃 {stats['messages']['dropped']}")
        logger.info(f"成功率: {stats['messages']['success_rate']}, "
                    f"处理速率: {stats['messages']['processing_rate']}")
        logger.info(f"平均处理时间: {stats['processing_time']['average_ms']}ms")
        logger.info(f"队列大小: {stats['queue']['current_size']}")
        logger.info(f"熔断器状态: {stats['circuit_breaker'].get('state', 'N/A')}")

        if stats['plugins']:
            logger.info("插件耗时排行:")
            for i, p in enumerate(stats['plugins'][:5], 1):
                logger.info(f" {i}. {p['name']}: {p['avg_time_ms']}ms (调用 {p['calls']} 次)")

        logger.info(sep)
|
||||
|
||||
|
||||
# ==================== 全局实例 ====================
|
||||
|
||||
# Lazily-created, process-wide PerformanceMonitor singleton.
_performance_monitor: Optional[PerformanceMonitor] = None


def get_performance_monitor() -> PerformanceMonitor:
    """Return the shared PerformanceMonitor, creating it on first call."""
    global _performance_monitor
    if _performance_monitor is None:
        _performance_monitor = PerformanceMonitor()
    return _performance_monitor
|
||||
@@ -1,4 +1,5 @@
|
||||
from abc import ABC
|
||||
from typing import List
|
||||
|
||||
from loguru import logger
|
||||
|
||||
@@ -13,6 +14,14 @@ class PluginBase(ABC):
|
||||
author: str = "未知"
|
||||
version: str = "1.0.0"
|
||||
|
||||
# 插件依赖(填写依赖的插件类名列表)
|
||||
# 例如: dependencies = ["MessageLogger", "AIChat"]
|
||||
dependencies: List[str] = []
|
||||
|
||||
# 加载优先级(数值越大越先加载,默认50)
|
||||
# 基础插件设置高优先级,依赖其他插件的设置低优先级
|
||||
load_priority: int = 50
|
||||
|
||||
    def __init__(self):
        # Plugins start disabled; the plugin manager enables them on load.
        self.enabled = False
        # Ids of scheduled jobs owned by this plugin — presumably used to
        # cancel them on unload; TODO confirm against the scheduler code.
        self._scheduled_jobs = set()
|
||||
|
||||
@@ -117,24 +117,107 @@ class PluginManager(metaclass=Singleton):
|
||||
if not found:
|
||||
logger.warning(f"未找到插件类 {plugin_name}")
|
||||
|
||||
    def _resolve_load_order(self, plugin_classes: List[Type[PluginBase]]) -> List[Type[PluginBase]]:
        """
        Resolve plugin load order via a DFS topological sort, seeded in
        descending ``load_priority`` order.

        Args:
            plugin_classes: Plugin classes to order.

        Returns:
            Classes ordered so dependencies come before their dependents;
            among independent plugins, higher load_priority loads first.
        """
        # map class name -> class; used to ignore unknown dependency names
        name_to_class = {cls.__name__: cls for cls in plugin_classes}

        # adjacency list: plugin name -> names of its (known) dependencies
        dependencies = {}
        for cls in plugin_classes:
            deps = getattr(cls, 'dependencies', [])
            dependencies[cls.__name__] = [d for d in deps if d in name_to_class]

        # DFS topological sort; temp_visited marks the current DFS path so
        # a back-edge (cycle) can be detected
        sorted_names = []
        visited = set()
        temp_visited = set()

        def visit(name: str):
            if name in temp_visited:
                # cycle detected: log and skip this edge; the node is still
                # emitted by the ancestor call that first entered it
                logger.warning(f"检测到循环依赖: {name}")
                return
            if name in visited:
                return

            temp_visited.add(name)

            # emit dependencies before the dependent
            for dep in dependencies.get(name, []):
                visit(dep)

            temp_visited.remove(name)
            visited.add(name)
            sorted_names.append(name)

        # seed the DFS in priority order so independent plugins come out
        # sorted by load_priority (higher first)
        priority_sorted = sorted(
            plugin_classes,
            key=lambda cls: getattr(cls, 'load_priority', 50),
            reverse=True
        )

        for cls in priority_sorted:
            if cls.__name__ not in visited:
                visit(cls.__name__)

        # translate the name order back into classes
        return [name_to_class[name] for name in sorted_names if name in name_to_class]
|
||||
|
||||
async def load_plugins(self, load_disabled: bool = True) -> Union[List[str], bool]:
|
||||
"""加载所有插件(按依赖顺序)"""
|
||||
loaded_plugins = []
|
||||
|
||||
# 第一步:收集所有插件类
|
||||
all_plugin_classes = []
|
||||
plugin_disabled_map = {}
|
||||
|
||||
for dirname in os.listdir("plugins"):
|
||||
if os.path.isdir(f"plugins/{dirname}") and os.path.exists(f"plugins/{dirname}/main.py"):
|
||||
try:
|
||||
module = importlib.import_module(f"plugins.{dirname}.main")
|
||||
for name, obj in inspect.getmembers(module):
|
||||
if inspect.isclass(obj) and issubclass(obj, PluginBase) and obj != PluginBase:
|
||||
all_plugin_classes.append(obj)
|
||||
|
||||
# 记录是否禁用
|
||||
is_disabled = False
|
||||
if not load_disabled:
|
||||
is_disabled = obj.__name__ in self.excluded_plugins or dirname in self.excluded_plugins
|
||||
|
||||
if await self._load_plugin_class(obj, is_disabled=is_disabled):
|
||||
loaded_plugins.append(obj.__name__)
|
||||
plugin_disabled_map[obj.__name__] = is_disabled
|
||||
except:
|
||||
logger.error(f"加载 {dirname} 时发生错误: {traceback.format_exc()}")
|
||||
|
||||
# 第二步:按依赖顺序排序
|
||||
sorted_classes = self._resolve_load_order(all_plugin_classes)
|
||||
logger.info(f"插件加载顺序: {[cls.__name__ for cls in sorted_classes]}")
|
||||
|
||||
# 第三步:按顺序加载插件
|
||||
for plugin_class in sorted_classes:
|
||||
plugin_name = plugin_class.__name__
|
||||
is_disabled = plugin_disabled_map.get(plugin_name, False)
|
||||
|
||||
# 检查依赖是否已加载
|
||||
deps = getattr(plugin_class, 'dependencies', [])
|
||||
deps_satisfied = all(dep in self.plugins for dep in deps)
|
||||
|
||||
if not deps_satisfied and not is_disabled:
|
||||
missing_deps = [dep for dep in deps if dep not in self.plugins]
|
||||
logger.warning(f"插件 {plugin_name} 的依赖未满足: {missing_deps},跳过加载")
|
||||
continue
|
||||
|
||||
if await self._load_plugin_class(plugin_class, is_disabled=is_disabled):
|
||||
loaded_plugins.append(plugin_name)
|
||||
|
||||
return loaded_plugins
|
||||
|
||||
async def unload_plugin(self, plugin_name: str) -> bool:
|
||||
|
||||
744
utils/redis_cache.py
Normal file
744
utils/redis_cache.py
Normal file
@@ -0,0 +1,744 @@
|
||||
"""
|
||||
Redis 缓存工具类
|
||||
|
||||
用于缓存用户信息等数据,减少 API 调用
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Dict, Any
|
||||
from loguru import logger
|
||||
|
||||
try:
|
||||
import redis
|
||||
REDIS_AVAILABLE = True
|
||||
except ImportError:
|
||||
REDIS_AVAILABLE = False
|
||||
logger.warning("redis 库未安装,缓存功能将不可用")
|
||||
|
||||
|
||||
class RedisCache:
|
||||
"""Redis 缓存管理器"""
|
||||
|
||||
_instance = None
|
||||
|
||||
    def __new__(cls, *args, **kwargs):
        """Create or return the process-wide singleton instance."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            # __init__ runs on every RedisCache(...) call; this flag makes
            # it a no-op after the first initialisation (see __init__).
            cls._instance._initialized = False
        return cls._instance
|
||||
|
||||
def __init__(self, config: Dict = None):
|
||||
"""
|
||||
初始化 Redis 连接
|
||||
|
||||
Args:
|
||||
config: Redis 配置字典,包含 host, port, password, db 等
|
||||
"""
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
self.client = None
|
||||
self.enabled = False
|
||||
self.default_ttl = 3600 # 默认过期时间 1 小时
|
||||
|
||||
if not REDIS_AVAILABLE:
|
||||
logger.warning("Redis 库未安装,缓存功能禁用")
|
||||
self._initialized = True
|
||||
return
|
||||
|
||||
if config:
|
||||
self.connect(config)
|
||||
|
||||
self._initialized = True
|
||||
|
||||
def connect(self, config: Dict) -> bool:
|
||||
"""
|
||||
连接 Redis
|
||||
|
||||
Args:
|
||||
config: Redis 配置
|
||||
|
||||
Returns:
|
||||
是否连接成功
|
||||
"""
|
||||
if not REDIS_AVAILABLE:
|
||||
return False
|
||||
|
||||
try:
|
||||
self.client = redis.Redis(
|
||||
host=config.get("host", "localhost"),
|
||||
port=config.get("port", 6379),
|
||||
password=config.get("password", None),
|
||||
db=config.get("db", 0),
|
||||
decode_responses=True,
|
||||
socket_timeout=5,
|
||||
socket_connect_timeout=5
|
||||
)
|
||||
|
||||
# 测试连接
|
||||
self.client.ping()
|
||||
self.enabled = True
|
||||
self.default_ttl = config.get("ttl", 3600)
|
||||
|
||||
logger.success(f"Redis 连接成功: {config.get('host')}:{config.get('port')}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Redis 连接失败: {e}")
|
||||
self.client = None
|
||||
self.enabled = False
|
||||
return False
|
||||
|
||||
def _make_key(self, prefix: str, *args) -> str:
|
||||
"""
|
||||
生成缓存 key
|
||||
|
||||
Args:
|
||||
prefix: key 前缀
|
||||
*args: key 组成部分
|
||||
|
||||
Returns:
|
||||
完整的 key
|
||||
"""
|
||||
parts = [prefix] + [str(arg) for arg in args]
|
||||
return ":".join(parts)
|
||||
|
||||
def get(self, key: str) -> Optional[Any]:
|
||||
"""
|
||||
获取缓存值
|
||||
|
||||
Args:
|
||||
key: 缓存 key
|
||||
|
||||
Returns:
|
||||
缓存的值,不存在返回 None
|
||||
"""
|
||||
if not self.enabled or not self.client:
|
||||
return None
|
||||
|
||||
try:
|
||||
value = self.client.get(key)
|
||||
if value:
|
||||
return json.loads(value)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Redis GET 失败: {key}, {e}")
|
||||
return None
|
||||
|
||||
def set(self, key: str, value: Any, ttl: int = None) -> bool:
|
||||
"""
|
||||
设置缓存值
|
||||
|
||||
Args:
|
||||
key: 缓存 key
|
||||
value: 要缓存的值
|
||||
ttl: 过期时间(秒),默认使用 default_ttl
|
||||
|
||||
Returns:
|
||||
是否设置成功
|
||||
"""
|
||||
if not self.enabled or not self.client:
|
||||
return False
|
||||
|
||||
try:
|
||||
ttl = ttl or self.default_ttl
|
||||
self.client.setex(key, ttl, json.dumps(value, ensure_ascii=False))
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Redis SET 失败: {key}, {e}")
|
||||
return False
|
||||
|
||||
def delete(self, key: str) -> bool:
|
||||
"""
|
||||
删除缓存
|
||||
|
||||
Args:
|
||||
key: 缓存 key
|
||||
|
||||
Returns:
|
||||
是否删除成功
|
||||
"""
|
||||
if not self.enabled or not self.client:
|
||||
return False
|
||||
|
||||
try:
|
||||
self.client.delete(key)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Redis DELETE 失败: {key}, {e}")
|
||||
return False
|
||||
|
||||
def delete_pattern(self, pattern: str) -> int:
|
||||
"""
|
||||
删除匹配模式的所有 key
|
||||
|
||||
Args:
|
||||
pattern: key 模式,如 "user_info:*"
|
||||
|
||||
Returns:
|
||||
删除的 key 数量
|
||||
"""
|
||||
if not self.enabled or not self.client:
|
||||
return 0
|
||||
|
||||
try:
|
||||
keys = self.client.keys(pattern)
|
||||
if keys:
|
||||
return self.client.delete(*keys)
|
||||
return 0
|
||||
except Exception as e:
|
||||
logger.error(f"Redis DELETE PATTERN 失败: {pattern}, {e}")
|
||||
return 0
|
||||
|
||||
# ==================== 用户信息缓存专用方法 ====================
|
||||
|
||||
def get_user_info(self, chatroom_id: str, user_wxid: str) -> Optional[Dict]:
|
||||
"""
|
||||
获取缓存的用户信息
|
||||
|
||||
Args:
|
||||
chatroom_id: 群聊 ID
|
||||
user_wxid: 用户 wxid
|
||||
|
||||
Returns:
|
||||
用户信息字典,不存在返回 None
|
||||
"""
|
||||
key = self._make_key("user_info", chatroom_id, user_wxid)
|
||||
return self.get(key)
|
||||
|
||||
def set_user_info(self, chatroom_id: str, user_wxid: str, user_info: Dict, ttl: int = None) -> bool:
|
||||
"""
|
||||
缓存用户信息
|
||||
|
||||
Args:
|
||||
chatroom_id: 群聊 ID
|
||||
user_wxid: 用户 wxid
|
||||
user_info: 用户信息字典
|
||||
ttl: 过期时间(秒)
|
||||
|
||||
Returns:
|
||||
是否缓存成功
|
||||
"""
|
||||
key = self._make_key("user_info", chatroom_id, user_wxid)
|
||||
return self.set(key, user_info, ttl)
|
||||
|
||||
def get_user_basic_info(self, chatroom_id: str, user_wxid: str) -> Optional[Dict]:
|
||||
"""
|
||||
获取缓存的用户基本信息(昵称和头像)
|
||||
|
||||
Args:
|
||||
chatroom_id: 群聊 ID
|
||||
user_wxid: 用户 wxid
|
||||
|
||||
Returns:
|
||||
包含 nickname 和 avatar_url 的字典
|
||||
"""
|
||||
user_info = self.get_user_info(chatroom_id, user_wxid)
|
||||
if user_info:
|
||||
# 提取基本信息
|
||||
nickname = ""
|
||||
if isinstance(user_info.get("nickName"), dict):
|
||||
nickname = user_info.get("nickName", {}).get("string", "")
|
||||
else:
|
||||
nickname = user_info.get("nickName", "")
|
||||
|
||||
avatar_url = user_info.get("bigHeadImgUrl", "")
|
||||
|
||||
if nickname or avatar_url:
|
||||
return {
|
||||
"nickname": nickname,
|
||||
"avatar_url": avatar_url
|
||||
}
|
||||
return None
|
||||
|
||||
def clear_user_cache(self, chatroom_id: str = None, user_wxid: str = None) -> int:
|
||||
"""
|
||||
清除用户信息缓存
|
||||
|
||||
Args:
|
||||
chatroom_id: 群聊 ID,为空则清除所有群
|
||||
user_wxid: 用户 wxid,为空则清除该群所有用户
|
||||
|
||||
Returns:
|
||||
清除的缓存数量
|
||||
"""
|
||||
if chatroom_id and user_wxid:
|
||||
key = self._make_key("user_info", chatroom_id, user_wxid)
|
||||
return 1 if self.delete(key) else 0
|
||||
elif chatroom_id:
|
||||
pattern = self._make_key("user_info", chatroom_id, "*")
|
||||
return self.delete_pattern(pattern)
|
||||
else:
|
||||
return self.delete_pattern("user_info:*")
|
||||
|
||||
def get_cache_stats(self) -> Dict:
|
||||
"""
|
||||
获取缓存统计信息
|
||||
|
||||
Returns:
|
||||
统计信息字典
|
||||
"""
|
||||
if not self.enabled or not self.client:
|
||||
return {"enabled": False}
|
||||
|
||||
try:
|
||||
info = self.client.info("memory")
|
||||
user_keys = len(self.client.keys("user_info:*"))
|
||||
chat_keys = len(self.client.keys("chat_history:*"))
|
||||
|
||||
return {
|
||||
"enabled": True,
|
||||
"used_memory": info.get("used_memory_human", "unknown"),
|
||||
"user_info_count": user_keys,
|
||||
"chat_history_count": chat_keys
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"获取缓存统计失败: {e}")
|
||||
return {"enabled": True, "error": str(e)}
|
||||
|
||||
# ==================== 对话历史缓存专用方法 ====================
|
||||
|
||||
def get_chat_history(self, chat_id: str, max_messages: int = 100) -> list:
|
||||
"""
|
||||
获取对话历史
|
||||
|
||||
Args:
|
||||
chat_id: 会话ID(私聊为用户wxid,群聊为 群ID:用户ID)
|
||||
max_messages: 最大返回消息数
|
||||
|
||||
Returns:
|
||||
消息列表
|
||||
"""
|
||||
if not self.enabled or not self.client:
|
||||
return []
|
||||
|
||||
try:
|
||||
key = self._make_key("chat_history", chat_id)
|
||||
# 使用 LRANGE 获取最近的消息(列表尾部是最新的)
|
||||
data = self.client.lrange(key, -max_messages, -1)
|
||||
return [json.loads(item) for item in data]
|
||||
except Exception as e:
|
||||
logger.error(f"获取对话历史失败: {chat_id}, {e}")
|
||||
return []
|
||||
|
||||
def add_chat_message(self, chat_id: str, role: str, content, ttl: int = 86400) -> bool:
    """Append a message to a session's conversation history.

    Args:
        chat_id: Session ID.
        role: Message role ("user" or "assistant").
        content: Message content (string or list).
        ttl: Expiry in seconds; defaults to 24 hours.

    Returns:
        True on success, False when disabled or on error.
    """
    if not self.enabled or not self.client:
        return False

    try:
        key = self._make_key("chat_history", chat_id)
        message = {"role": role, "content": content}
        # Improvement: pipeline the push + expire so both commands travel
        # in a single network round trip instead of two sequential calls.
        pipe = self.client.pipeline()
        pipe.rpush(key, json.dumps(message, ensure_ascii=False))
        pipe.expire(key, ttl)
        pipe.execute()
        return True
    except Exception as e:
        logger.error(f"添加对话消息失败: {chat_id}, {e}")
        return False
|
||||
|
||||
def trim_chat_history(self, chat_id: str, max_messages: int = 100) -> bool:
    """Trim a session's history down to its most recent N messages.

    Args:
        chat_id: Session ID.
        max_messages: Number of newest messages to keep.

    Returns:
        True on success, False when disabled or on error.
    """
    if not self.enabled or not self.client:
        return False

    try:
        history_key = self._make_key("chat_history", chat_id)
        # Negative LTRIM bounds keep only the list tail (newest entries).
        self.client.ltrim(history_key, -max_messages, -1)
        return True
    except Exception as e:
        logger.error(f"裁剪对话历史失败: {chat_id}, {e}")
        return False
|
||||
|
||||
def clear_chat_history(self, chat_id: str) -> bool:
    """Delete the entire conversation history of one session.

    Args:
        chat_id: Session ID.

    Returns:
        True on success, False when disabled or on error.
    """
    if not self.enabled or not self.client:
        return False

    try:
        history_key = self._make_key("chat_history", chat_id)
        self.client.delete(history_key)
        return True
    except Exception as e:
        logger.error(f"清空对话历史失败: {chat_id}, {e}")
        return False
|
||||
|
||||
# ==================== 群聊历史记录专用方法 ====================
|
||||
|
||||
def get_group_history(self, group_id: str, max_messages: int = 100) -> list:
    """Fetch recent group-chat history.

    Args:
        group_id: Group chat ID.
        max_messages: Upper bound on the number of messages returned.

    Returns:
        List of message dicts, each carrying nickname, content, timestamp.
    """
    if not self.enabled or not self.client:
        return []

    try:
        history_key = self._make_key("group_history", group_id)
        # Newest entries live at the list tail; slice the last max_messages.
        raw_items = self.client.lrange(history_key, -max_messages, -1)
        parsed = []
        for raw in raw_items:
            parsed.append(json.loads(raw))
        return parsed
    except Exception as e:
        logger.error(f"获取群聊历史失败: {group_id}, {e}")
        return []
|
||||
|
||||
def add_group_message(self, group_id: str, nickname: str, content,
                      record_id: str = None, ttl: int = 86400) -> bool:
    """Append a message to a group's chat history.

    Args:
        group_id: Group chat ID.
        nickname: Sender's display name.
        content: Message content.
        record_id: Optional record ID, kept for later in-place updates.
        ttl: Expiry in seconds; defaults to 24 hours.

    Returns:
        True on success, False when disabled or on error.
    """
    if not self.enabled or not self.client:
        return False

    try:
        import time
        key = self._make_key("group_history", group_id)
        message = {
            "nickname": nickname,
            "content": content,
            "timestamp": time.time()
        }
        if record_id:
            message["id"] = record_id

        # Improvement: pipeline push + expire into a single round trip.
        pipe = self.client.pipeline()
        pipe.rpush(key, json.dumps(message, ensure_ascii=False))
        pipe.expire(key, ttl)
        pipe.execute()
        return True
    except Exception as e:
        logger.error(f"添加群聊消息失败: {group_id}, {e}")
        return False
|
||||
|
||||
def update_group_message_by_id(self, group_id: str, record_id: str, new_content) -> bool:
    """Update the content of one group-history message identified by ID.

    Args:
        group_id: Group chat ID.
        record_id: Record ID stored with the message.
        new_content: Replacement content.

    Returns:
        True when a matching record was found and updated, else False.
    """
    if not self.enabled or not self.client:
        return False

    try:
        history_key = self._make_key("group_history", group_id)
        # Linear scan over the whole list; acceptable because histories
        # are trimmed to a bounded size elsewhere.
        for index, raw in enumerate(self.client.lrange(history_key, 0, -1)):
            record = json.loads(raw)
            if record.get("id") != record_id:
                continue
            record["content"] = new_content
            self.client.lset(history_key, index, json.dumps(record, ensure_ascii=False))
            return True
        return False
    except Exception as e:
        logger.error(f"更新群聊消息失败: {group_id}, {record_id}, {e}")
        return False
|
||||
|
||||
def trim_group_history(self, group_id: str, max_messages: int = 100) -> bool:
    """Trim a group's history down to its most recent N messages.

    Args:
        group_id: Group chat ID.
        max_messages: Number of newest messages to keep.

    Returns:
        True on success, False when disabled or on error.
    """
    if not self.enabled or not self.client:
        return False

    try:
        history_key = self._make_key("group_history", group_id)
        # Negative LTRIM bounds keep only the list tail (newest entries).
        self.client.ltrim(history_key, -max_messages, -1)
        return True
    except Exception as e:
        logger.error(f"裁剪群聊历史失败: {group_id}, {e}")
        return False
|
||||
|
||||
# ==================== 限流专用方法 ====================
|
||||
|
||||
def check_rate_limit(self, identifier: str, limit: int = 10,
                     window: int = 60, limit_type: str = "message") -> tuple:
    """Sliding-window rate-limit check backed by a Redis sorted set.

    Args:
        identifier: Subject of the limit (user wxid, group ID, ...).
        limit: Maximum requests allowed inside the window.
        window: Window length in seconds.
        limit_type: Limit category (message/ai_chat/image_gen/...).

    Returns:
        Tuple of (allowed, remaining_quota, seconds_until_reset).
    """
    if not self.enabled or not self.client:
        return (True, limit, 0)  # fail open when Redis is unavailable

    try:
        import time
        key = self._make_key("rate_limit", limit_type, identifier)
        now = time.time()
        window_start = now - window

        # Prune expired entries and count the survivors in one round trip.
        pipe = self.client.pipeline()
        pipe.zremrangebyscore(key, 0, window_start)
        pipe.zcard(key)
        results = pipe.execute()
        current_count = results[1]  # zcard result

        if current_count >= limit:
            # Fix: do NOT record rejected attempts. The previous version
            # zadd-ed every request before checking the limit, so rejected
            # requests consumed quota and a client that kept retrying
            # extended its own block indefinitely.
            oldest = self.client.zrange(key, 0, 0, withscores=True)
            if oldest:
                reset_time = int(oldest[0][1] + window - now)
            else:
                reset_time = window
            return (False, 0, max(reset_time, 1))

        # Record the allowed request and refresh the key's TTL.
        pipe = self.client.pipeline()
        pipe.zadd(key, {str(now): now})
        pipe.expire(key, window)
        pipe.execute()

        remaining = limit - current_count - 1
        return (True, remaining, 0)

    except Exception as e:
        logger.error(f"限流检查失败: {identifier}, {e}")
        return (True, limit, 0)  # fail open on errors
|
||||
|
||||
def get_rate_limit_status(self, identifier: str, limit: int = 10,
                          window: int = 60, limit_type: str = "message") -> Dict:
    """Report the current rate-limit state without consuming a slot.

    Args:
        identifier: Subject of the limit (user wxid, group ID, ...).
        limit: Maximum requests allowed inside the window.
        window: Window length in seconds.
        limit_type: Limit category (message/ai_chat/image_gen/...).

    Returns:
        Status dict with current usage and remaining quota.
    """
    if not self.enabled or not self.client:
        return {"enabled": False, "current": 0, "limit": limit, "remaining": limit}

    try:
        import time
        rate_key = self._make_key("rate_limit", limit_type, identifier)
        now = time.time()

        # Drop entries that fell out of the sliding window, then count.
        self.client.zremrangebyscore(rate_key, 0, now - window)
        current = self.client.zcard(rate_key)

        return {
            "enabled": True,
            "current": current,
            "limit": limit,
            "remaining": max(0, limit - current),
            "window": window
        }
    except Exception as e:
        logger.error(f"获取限流状态失败: {identifier}, {e}")
        return {"enabled": False, "error": str(e)}
|
||||
|
||||
def reset_rate_limit(self, identifier: str, limit_type: str = "message") -> bool:
    """Reset the rate-limit counter for one identifier.

    Args:
        identifier: Subject of the limit.
        limit_type: Limit category.

    Returns:
        True on success, False when disabled or on error.
    """
    if not self.enabled or not self.client:
        return False

    try:
        rate_key = self._make_key("rate_limit", limit_type, identifier)
        self.client.delete(rate_key)
        return True
    except Exception as e:
        logger.error(f"重置限流失败: {identifier}, {e}")
        return False
|
||||
|
||||
# ==================== 媒体缓存专用方法 ====================
|
||||
|
||||
def cache_media(self, media_key: str, base64_data: str, media_type: str = "image", ttl: int = 300) -> bool:
    """Cache a media file's base64 payload.

    Args:
        media_key: Unique media identifier (e.g. cdnurl hash or aeskey).
        base64_data: Base64-encoded media bytes.
        media_type: Media category (image/emoji/video).
        ttl: Expiry in seconds; defaults to 5 minutes.

    Returns:
        True on success, False when disabled or on error.
    """
    if not self.enabled or not self.client:
        return False

    try:
        cache_key = self._make_key("media_cache", media_type, media_key)
        # Store the raw base64 string directly — no JSON wrapping needed.
        self.client.setex(cache_key, ttl, base64_data)
        logger.debug(f"媒体已缓存: {media_type}/{media_key[:20]}..., TTL={ttl}s")
        return True
    except Exception as e:
        logger.error(f"缓存媒体失败: {media_key}, {e}")
        return False
|
||||
|
||||
def get_cached_media(self, media_key: str, media_type: str = "image") -> Optional[str]:
    """Look up a cached media payload.

    Args:
        media_key: Unique media identifier.
        media_type: Media category.

    Returns:
        The base64 string, or None when absent, disabled, or on error.
    """
    if not self.enabled or not self.client:
        return None

    try:
        cache_key = self._make_key("media_cache", media_type, media_key)
        payload = self.client.get(cache_key)
        if not payload:
            return None
        logger.debug(f"媒体缓存命中: {media_type}/{media_key[:20]}...")
        return payload
    except Exception as e:
        logger.error(f"获取媒体缓存失败: {media_key}, {e}")
        return None
|
||||
|
||||
def delete_cached_media(self, media_key: str, media_type: str = "image") -> bool:
    """Delete a cached media entry.

    Args:
        media_key: Unique media identifier.
        media_type: Media category.

    Returns:
        True on success, False when disabled or on error.
    """
    if not self.enabled or not self.client:
        return False

    try:
        cache_key = self._make_key("media_cache", media_type, media_key)
        self.client.delete(cache_key)
        return True
    except Exception as e:
        logger.error(f"删除媒体缓存失败: {media_key}, {e}")
        return False
|
||||
|
||||
@staticmethod
|
||||
def generate_media_key(cdnurl: str = "", aeskey: str = "") -> str:
|
||||
"""
|
||||
根据 CDN URL 或 AES Key 生成媒体缓存 key
|
||||
|
||||
Args:
|
||||
cdnurl: CDN URL
|
||||
aeskey: AES Key
|
||||
|
||||
Returns:
|
||||
缓存 key
|
||||
"""
|
||||
import hashlib
|
||||
# 优先使用 aeskey(更短更稳定),否则使用 cdnurl 的 hash
|
||||
if aeskey:
|
||||
return aeskey[:32] # 取前32位作为 key
|
||||
elif cdnurl:
|
||||
return hashlib.md5(cdnurl.encode()).hexdigest()
|
||||
return ""
|
||||
|
||||
|
||||
def get_cache() -> Optional[RedisCache]:
    """Return the global cache instance.

    Returns the ``RedisCache`` singleton. If nothing has been initialized
    yet this may yield an unconnected instance (or None); callers are
    advised to invoke this only after MessageLogger initialization.
    """
    # NOTE(review): this reads RedisCache._instance, while init_cache()
    # stores into a module-level _cache_instance global instead — presumably
    # RedisCache.__init__ also registers itself on cls._instance; confirm,
    # otherwise instances created via init_cache() are never returned here.
    return RedisCache._instance
|
||||
|
||||
|
||||
def init_cache(config: Dict) -> RedisCache:
    """Initialize the global cache instance.

    Args:
        config: Redis configuration dict.

    Returns:
        The newly created cache instance.
    """
    # NOTE(review): stores the instance in the module global _cache_instance,
    # but get_cache() returns RedisCache._instance — verify the class
    # constructor registers the singleton on the class, or the two
    # accessors can diverge.
    global _cache_instance
    _cache_instance = RedisCache(config)
    return _cache_instance
|
||||
Reference in New Issue
Block a user