feat: 优化整体项目

This commit is contained in:
2025-12-05 18:06:13 +08:00
parent b4df26f61d
commit 7d3ef70093
13 changed files with 2661 additions and 305 deletions

658
utils/bot_utils.py Normal file
View File

@@ -0,0 +1,658 @@
"""
机器人核心工具模块
包含:
- 优先级消息队列
- 自适应熔断器
- 请求重试机制
- 配置热更新
- 性能监控
"""
import asyncio
import time
import heapq
import os
import tomllib
import functools
from pathlib import Path
from typing import Dict, List, Optional, Callable, Any, Tuple
from dataclasses import dataclass, field
from loguru import logger
import aiohttp
# ==================== 优先级消息队列 ====================
# 消息优先级定义
class MessagePriority:
"""消息优先级常量"""
CRITICAL = 100 # 系统消息、登录信息
HIGH = 80 # 管理员命令
NORMAL = 50 # @机器人的消息
LOW = 20 # 普通群消息
# 高优先级消息类型
PRIORITY_MESSAGE_TYPES = {
11025: MessagePriority.CRITICAL, # 登录信息
11058: MessagePriority.CRITICAL, # 系统消息
11098: MessagePriority.HIGH, # 群成员加入
11099: MessagePriority.HIGH, # 群成员退出
11100: MessagePriority.HIGH, # 群信息变更
11056: MessagePriority.HIGH, # 好友请求
}
@dataclass(order=True)
class PriorityMessage:
"""优先级消息包装"""
priority: int
timestamp: float = field(compare=False)
msg_type: int = field(compare=False)
data: dict = field(compare=False)
def __init__(self, msg_type: int, data: dict, priority: int = None):
# 优先级越高数值越小因为heapq是最小堆
self.priority = -(priority or PRIORITY_MESSAGE_TYPES.get(msg_type, MessagePriority.NORMAL))
self.timestamp = time.time()
self.msg_type = msg_type
self.data = data
class PriorityMessageQueue:
"""优先级消息队列"""
def __init__(self, maxsize: int = 1000):
self.maxsize = maxsize
self._heap: List[PriorityMessage] = []
self._lock = asyncio.Lock()
self._not_empty = asyncio.Event()
self._unfinished_tasks = 0
self._finished = asyncio.Event()
self._finished.set()
def qsize(self) -> int:
"""返回队列大小"""
return len(self._heap)
def empty(self) -> bool:
"""队列是否为空"""
return len(self._heap) == 0
def full(self) -> bool:
"""队列是否已满"""
return len(self._heap) >= self.maxsize
async def put(self, msg_type: int, data: dict, priority: int = None):
"""添加消息到队列"""
async with self._lock:
msg = PriorityMessage(msg_type, data, priority)
heapq.heappush(self._heap, msg)
self._unfinished_tasks += 1
self._finished.clear()
self._not_empty.set()
async def get(self) -> Tuple[int, dict]:
"""获取优先级最高的消息"""
while True:
async with self._lock:
if self._heap:
msg = heapq.heappop(self._heap)
if not self._heap:
self._not_empty.clear()
return (msg.msg_type, msg.data)
# 等待新消息
await self._not_empty.wait()
def get_nowait(self) -> Tuple[int, dict]:
"""非阻塞获取消息"""
if not self._heap:
raise asyncio.QueueEmpty()
msg = heapq.heappop(self._heap)
if not self._heap:
self._not_empty.clear()
return (msg.msg_type, msg.data)
def task_done(self):
"""标记任务完成"""
self._unfinished_tasks -= 1
if self._unfinished_tasks == 0:
self._finished.set()
async def join(self):
"""等待所有任务完成"""
await self._finished.wait()
def drop_lowest_priority(self) -> bool:
"""丢弃优先级最低的消息"""
if not self._heap:
return False
# 找到优先级最低的消息priority值最大因为是负数所以最小
min_idx = 0
for i, msg in enumerate(self._heap):
if msg.priority > self._heap[min_idx].priority:
min_idx = i
# 删除该消息
self._heap.pop(min_idx)
heapq.heapify(self._heap)
self._unfinished_tasks -= 1
return True
# ==================== 自适应熔断器 ====================
class AdaptiveCircuitBreaker:
"""自适应熔断器"""
# 熔断器状态
STATE_CLOSED = "closed" # 正常状态
STATE_OPEN = "open" # 熔断状态
STATE_HALF_OPEN = "half_open" # 半开状态(尝试恢复)
def __init__(
self,
failure_threshold: int = 5,
success_threshold: int = 3,
initial_recovery_time: float = 5.0,
max_recovery_time: float = 300.0,
recovery_multiplier: float = 2.0
):
"""
初始化熔断器
Args:
failure_threshold: 触发熔断的连续失败次数
success_threshold: 恢复正常的连续成功次数
initial_recovery_time: 初始恢复等待时间(秒)
max_recovery_time: 最大恢复等待时间(秒)
recovery_multiplier: 恢复时间增长倍数
"""
self.failure_threshold = failure_threshold
self.success_threshold = success_threshold
self.initial_recovery_time = initial_recovery_time
self.max_recovery_time = max_recovery_time
self.recovery_multiplier = recovery_multiplier
# 状态
self.state = self.STATE_CLOSED
self.failure_count = 0
self.success_count = 0
self.last_failure_time = 0
self.current_recovery_time = initial_recovery_time
# 统计
self.total_failures = 0
self.total_successes = 0
self.total_rejections = 0
def is_open(self) -> bool:
"""检查熔断器是否开启(是否应该拒绝请求)"""
if self.state == self.STATE_CLOSED:
return False
if self.state == self.STATE_OPEN:
# 检查是否可以尝试恢复
elapsed = time.time() - self.last_failure_time
if elapsed >= self.current_recovery_time:
self.state = self.STATE_HALF_OPEN
self.success_count = 0
logger.info(f"熔断器进入半开状态,尝试恢复(等待了 {elapsed:.1f}s")
return False
return True
# 半开状态,允许请求通过
return False
def record_success(self):
"""记录成功"""
self.total_successes += 1
if self.state == self.STATE_HALF_OPEN:
self.success_count += 1
if self.success_count >= self.success_threshold:
# 恢复正常
self.state = self.STATE_CLOSED
self.failure_count = 0
self.success_count = 0
self.current_recovery_time = self.initial_recovery_time
logger.success(f"熔断器已恢复正常(连续成功 {self.success_threshold} 次)")
else:
# 正常状态,重置失败计数
self.failure_count = 0
def record_failure(self):
"""记录失败"""
self.total_failures += 1
self.failure_count += 1
self.last_failure_time = time.time()
if self.state == self.STATE_HALF_OPEN:
# 半开状态下失败,重新熔断
self.state = self.STATE_OPEN
self.success_count = 0
# 增加恢复时间
self.current_recovery_time = min(
self.current_recovery_time * self.recovery_multiplier,
self.max_recovery_time
)
logger.warning(f"熔断器重新开启,下次恢复等待 {self.current_recovery_time:.1f}s")
elif self.state == self.STATE_CLOSED:
if self.failure_count >= self.failure_threshold:
self.state = self.STATE_OPEN
logger.warning(f"熔断器开启,连续失败 {self.failure_count}")
def record_rejection(self):
"""记录被拒绝的请求"""
self.total_rejections += 1
def get_stats(self) -> dict:
"""获取统计信息"""
return {
"state": self.state,
"failure_count": self.failure_count,
"success_count": self.success_count,
"current_recovery_time": self.current_recovery_time,
"total_failures": self.total_failures,
"total_successes": self.total_successes,
"total_rejections": self.total_rejections
}
# ==================== 请求重试机制 ====================
class RetryConfig:
"""重试配置"""
def __init__(
self,
max_retries: int = 3,
initial_delay: float = 1.0,
max_delay: float = 30.0,
exponential_base: float = 2.0,
retryable_exceptions: tuple = (
aiohttp.ClientError,
asyncio.TimeoutError,
ConnectionError,
)
):
self.max_retries = max_retries
self.initial_delay = initial_delay
self.max_delay = max_delay
self.exponential_base = exponential_base
self.retryable_exceptions = retryable_exceptions
def retry_async(config: RetryConfig = None):
"""异步重试装饰器"""
if config is None:
config = RetryConfig()
def decorator(func):
@functools.wraps(func)
async def wrapper(*args, **kwargs):
last_exception = None
for attempt in range(config.max_retries + 1):
try:
return await func(*args, **kwargs)
except config.retryable_exceptions as e:
last_exception = e
if attempt == config.max_retries:
logger.error(f"重试 {config.max_retries} 次后仍然失败: {func.__name__}")
raise
# 计算延迟时间(指数退避)
delay = min(
config.initial_delay * (config.exponential_base ** attempt),
config.max_delay
)
logger.warning(
f"请求失败,{delay:.1f}s 后重试 "
f"(第 {attempt + 1}/{config.max_retries} 次): {e}"
)
await asyncio.sleep(delay)
raise last_exception
return wrapper
return decorator
async def request_with_retry(
session: aiohttp.ClientSession,
method: str,
url: str,
max_retries: int = 3,
**kwargs
) -> aiohttp.ClientResponse:
"""带重试的 HTTP 请求"""
config = RetryConfig(max_retries=max_retries)
last_exception = None
for attempt in range(config.max_retries + 1):
try:
response = await session.request(method, url, **kwargs)
return response
except config.retryable_exceptions as e:
last_exception = e
if attempt == config.max_retries:
raise
delay = min(
config.initial_delay * (config.exponential_base ** attempt),
config.max_delay
)
logger.warning(f"HTTP 请求失败,{delay:.1f}s 后重试: {e}")
await asyncio.sleep(delay)
raise last_exception
# ==================== 配置热更新 ====================
class ConfigWatcher:
"""配置文件监听器"""
def __init__(self, config_path: str, check_interval: float = 5.0):
"""
初始化配置监听器
Args:
config_path: 配置文件路径
check_interval: 检查间隔(秒)
"""
self.config_path = Path(config_path)
self.check_interval = check_interval
self.last_mtime = 0
self.callbacks: List[Callable[[dict], Any]] = []
self.current_config: dict = {}
self._running = False
self._task: Optional[asyncio.Task] = None
def register_callback(self, callback: Callable[[dict], Any]):
"""注册配置更新回调"""
self.callbacks.append(callback)
def unregister_callback(self, callback: Callable[[dict], Any]):
"""取消注册回调"""
if callback in self.callbacks:
self.callbacks.remove(callback)
def _load_config(self) -> dict:
"""加载配置文件"""
try:
with open(self.config_path, "rb") as f:
return tomllib.load(f)
except Exception as e:
logger.error(f"加载配置文件失败: {e}")
return {}
def get_config(self) -> dict:
"""获取当前配置"""
return self.current_config
async def start(self):
"""启动配置监听"""
if self._running:
return
self._running = True
# 初始加载
if self.config_path.exists():
self.last_mtime = os.path.getmtime(self.config_path)
self.current_config = self._load_config()
self._task = asyncio.create_task(self._watch_loop())
logger.info(f"配置监听器已启动: {self.config_path}")
async def stop(self):
"""停止配置监听"""
self._running = False
if self._task:
self._task.cancel()
try:
await self._task
except asyncio.CancelledError:
pass
logger.info("配置监听器已停止")
async def _watch_loop(self):
"""监听循环"""
while self._running:
try:
await asyncio.sleep(self.check_interval)
if not self.config_path.exists():
continue
mtime = os.path.getmtime(self.config_path)
if mtime > self.last_mtime:
logger.info("检测到配置文件变化,重新加载...")
new_config = self._load_config()
if new_config:
old_config = self.current_config
self.current_config = new_config
self.last_mtime = mtime
# 通知所有回调
for callback in self.callbacks:
try:
if asyncio.iscoroutinefunction(callback):
await callback(new_config)
else:
callback(new_config)
except Exception as e:
logger.error(f"配置更新回调执行失败: {e}")
logger.success("配置已热更新")
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"配置监听异常: {e}")
# ==================== 性能监控 ====================
class PerformanceMonitor:
"""性能监控器"""
def __init__(self):
self.start_time = time.time()
# 消息统计
self.message_received = 0
self.message_processed = 0
self.message_failed = 0
self.message_dropped = 0
# 处理时间统计
self.processing_times: List[float] = []
self.max_processing_times = 1000 # 保留最近1000条记录
# 插件统计
self.plugin_stats: Dict[str, dict] = {}
# 队列统计
self.queue_size_history: List[Tuple[float, int]] = []
self.max_queue_history = 100
# 熔断器统计
self.circuit_breaker_stats: dict = {}
def record_message_received(self):
"""记录收到消息"""
self.message_received += 1
def record_message_processed(self, processing_time: float):
"""记录消息处理完成"""
self.message_processed += 1
self.processing_times.append(processing_time)
# 限制历史记录数量
if len(self.processing_times) > self.max_processing_times:
self.processing_times = self.processing_times[-self.max_processing_times:]
def record_message_failed(self):
"""记录消息处理失败"""
self.message_failed += 1
def record_message_dropped(self):
"""记录消息被丢弃"""
self.message_dropped += 1
def record_queue_size(self, size: int):
"""记录队列大小"""
self.queue_size_history.append((time.time(), size))
if len(self.queue_size_history) > self.max_queue_history:
self.queue_size_history = self.queue_size_history[-self.max_queue_history:]
def record_plugin_execution(self, plugin_name: str, execution_time: float, success: bool):
"""记录插件执行"""
if plugin_name not in self.plugin_stats:
self.plugin_stats[plugin_name] = {
"total_calls": 0,
"success_calls": 0,
"failed_calls": 0,
"total_time": 0,
"max_time": 0,
"recent_times": []
}
stats = self.plugin_stats[plugin_name]
stats["total_calls"] += 1
stats["total_time"] += execution_time
stats["max_time"] = max(stats["max_time"], execution_time)
stats["recent_times"].append(execution_time)
if len(stats["recent_times"]) > 100:
stats["recent_times"] = stats["recent_times"][-100:]
if success:
stats["success_calls"] += 1
else:
stats["failed_calls"] += 1
def update_circuit_breaker_stats(self, stats: dict):
"""更新熔断器统计"""
self.circuit_breaker_stats = stats
def get_stats(self) -> dict:
"""获取完整统计信息"""
uptime = time.time() - self.start_time
# 计算平均处理时间
avg_processing_time = 0
if self.processing_times:
avg_processing_time = sum(self.processing_times) / len(self.processing_times)
# 计算处理速率
processing_rate = self.message_processed / uptime if uptime > 0 else 0
# 计算成功率
total = self.message_processed + self.message_failed
success_rate = self.message_processed / total if total > 0 else 1.0
return {
"uptime_seconds": uptime,
"uptime_formatted": self._format_uptime(uptime),
"messages": {
"received": self.message_received,
"processed": self.message_processed,
"failed": self.message_failed,
"dropped": self.message_dropped,
"success_rate": f"{success_rate * 100:.1f}%",
"processing_rate": f"{processing_rate:.2f}/s"
},
"processing_time": {
"average_ms": f"{avg_processing_time * 1000:.1f}",
"max_ms": f"{max(self.processing_times) * 1000:.1f}" if self.processing_times else "0",
"min_ms": f"{min(self.processing_times) * 1000:.1f}" if self.processing_times else "0"
},
"queue": {
"current_size": self.queue_size_history[-1][1] if self.queue_size_history else 0,
"max_size": max(s[1] for s in self.queue_size_history) if self.queue_size_history else 0
},
"circuit_breaker": self.circuit_breaker_stats,
"plugins": self._get_plugin_summary()
}
def _get_plugin_summary(self) -> List[dict]:
"""获取插件统计摘要"""
summary = []
for name, stats in self.plugin_stats.items():
avg_time = stats["total_time"] / stats["total_calls"] if stats["total_calls"] > 0 else 0
summary.append({
"name": name,
"calls": stats["total_calls"],
"success_rate": f"{stats['success_calls'] / stats['total_calls'] * 100:.1f}%" if stats["total_calls"] > 0 else "N/A",
"avg_time_ms": f"{avg_time * 1000:.1f}",
"max_time_ms": f"{stats['max_time'] * 1000:.1f}"
})
# 按平均时间排序
summary.sort(key=lambda x: float(x["avg_time_ms"]), reverse=True)
return summary
def _format_uptime(self, seconds: float) -> str:
"""格式化运行时间"""
days = int(seconds // 86400)
hours = int((seconds % 86400) // 3600)
minutes = int((seconds % 3600) // 60)
secs = int(seconds % 60)
if days > 0:
return f"{days}{hours}小时 {minutes}分钟"
elif hours > 0:
return f"{hours}小时 {minutes}分钟"
elif minutes > 0:
return f"{minutes}分钟 {secs}"
else:
return f"{secs}"
def print_stats(self):
"""打印统计信息到日志"""
stats = self.get_stats()
logger.info("=" * 50)
logger.info("性能监控报告")
logger.info("=" * 50)
logger.info(f"运行时间: {stats['uptime_formatted']}")
logger.info(f"消息统计: 收到 {stats['messages']['received']}, "
f"处理 {stats['messages']['processed']}, "
f"失败 {stats['messages']['failed']}, "
f"丢弃 {stats['messages']['dropped']}")
logger.info(f"成功率: {stats['messages']['success_rate']}, "
f"处理速率: {stats['messages']['processing_rate']}")
logger.info(f"平均处理时间: {stats['processing_time']['average_ms']}ms")
logger.info(f"队列大小: {stats['queue']['current_size']}")
logger.info(f"熔断器状态: {stats['circuit_breaker'].get('state', 'N/A')}")
if stats['plugins']:
logger.info("插件耗时排行:")
for i, p in enumerate(stats['plugins'][:5], 1):
logger.info(f" {i}. {p['name']}: {p['avg_time_ms']}ms (调用 {p['calls']} 次)")
logger.info("=" * 50)
# ==================== 全局实例 ====================
# 性能监控器单例
_performance_monitor: Optional[PerformanceMonitor] = None
def get_performance_monitor() -> PerformanceMonitor:
"""获取性能监控器实例"""
global _performance_monitor
if _performance_monitor is None:
_performance_monitor = PerformanceMonitor()
return _performance_monitor

View File

@@ -1,4 +1,5 @@
from abc import ABC
from typing import List
from loguru import logger
@@ -13,6 +14,14 @@ class PluginBase(ABC):
author: str = "未知"
version: str = "1.0.0"
# 插件依赖(填写依赖的插件类名列表)
# 例如: dependencies = ["MessageLogger", "AIChat"]
dependencies: List[str] = []
# 加载优先级数值越大越先加载默认50
# 基础插件设置高优先级,依赖其他插件的设置低优先级
load_priority: int = 50
def __init__(self):
self.enabled = False
self._scheduled_jobs = set()

View File

@@ -117,24 +117,107 @@ class PluginManager(metaclass=Singleton):
if not found:
logger.warning(f"未找到插件类 {plugin_name}")
def _resolve_load_order(self, plugin_classes: List[Type[PluginBase]]) -> List[Type[PluginBase]]:
"""
解析插件加载顺序(拓扑排序 + 优先级排序)
Args:
plugin_classes: 插件类列表
Returns:
按依赖关系和优先级排序后的插件类列表
"""
# 构建插件名到类的映射
name_to_class = {cls.__name__: cls for cls in plugin_classes}
# 构建依赖图
dependencies = {}
for cls in plugin_classes:
deps = getattr(cls, 'dependencies', [])
dependencies[cls.__name__] = [d for d in deps if d in name_to_class]
# 拓扑排序
sorted_names = []
visited = set()
temp_visited = set()
def visit(name: str):
if name in temp_visited:
# 检测到循环依赖
logger.warning(f"检测到循环依赖: {name}")
return
if name in visited:
return
temp_visited.add(name)
# 先访问依赖
for dep in dependencies.get(name, []):
visit(dep)
temp_visited.remove(name)
visited.add(name)
sorted_names.append(name)
# 按优先级排序后再进行拓扑排序
priority_sorted = sorted(
plugin_classes,
key=lambda cls: getattr(cls, 'load_priority', 50),
reverse=True
)
for cls in priority_sorted:
if cls.__name__ not in visited:
visit(cls.__name__)
# 返回排序后的类列表
return [name_to_class[name] for name in sorted_names if name in name_to_class]
async def load_plugins(self, load_disabled: bool = True) -> Union[List[str], bool]:
"""加载所有插件(按依赖顺序)"""
loaded_plugins = []
# 第一步:收集所有插件类
all_plugin_classes = []
plugin_disabled_map = {}
for dirname in os.listdir("plugins"):
if os.path.isdir(f"plugins/{dirname}") and os.path.exists(f"plugins/{dirname}/main.py"):
try:
module = importlib.import_module(f"plugins.{dirname}.main")
for name, obj in inspect.getmembers(module):
if inspect.isclass(obj) and issubclass(obj, PluginBase) and obj != PluginBase:
all_plugin_classes.append(obj)
# 记录是否禁用
is_disabled = False
if not load_disabled:
is_disabled = obj.__name__ in self.excluded_plugins or dirname in self.excluded_plugins
if await self._load_plugin_class(obj, is_disabled=is_disabled):
loaded_plugins.append(obj.__name__)
plugin_disabled_map[obj.__name__] = is_disabled
except:
logger.error(f"加载 {dirname} 时发生错误: {traceback.format_exc()}")
# 第二步:按依赖顺序排序
sorted_classes = self._resolve_load_order(all_plugin_classes)
logger.info(f"插件加载顺序: {[cls.__name__ for cls in sorted_classes]}")
# 第三步:按顺序加载插件
for plugin_class in sorted_classes:
plugin_name = plugin_class.__name__
is_disabled = plugin_disabled_map.get(plugin_name, False)
# 检查依赖是否已加载
deps = getattr(plugin_class, 'dependencies', [])
deps_satisfied = all(dep in self.plugins for dep in deps)
if not deps_satisfied and not is_disabled:
missing_deps = [dep for dep in deps if dep not in self.plugins]
logger.warning(f"插件 {plugin_name} 的依赖未满足: {missing_deps},跳过加载")
continue
if await self._load_plugin_class(plugin_class, is_disabled=is_disabled):
loaded_plugins.append(plugin_name)
return loaded_plugins
async def unload_plugin(self, plugin_name: str) -> bool:

744
utils/redis_cache.py Normal file
View File

@@ -0,0 +1,744 @@
"""
Redis 缓存工具类
用于缓存用户信息等数据,减少 API 调用
"""
import json
from typing import Optional, Dict, Any
from loguru import logger
try:
import redis
REDIS_AVAILABLE = True
except ImportError:
REDIS_AVAILABLE = False
logger.warning("redis 库未安装,缓存功能将不可用")
class RedisCache:
"""Redis 缓存管理器"""
_instance = None
def __new__(cls, *args, **kwargs):
"""单例模式"""
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self, config: Dict = None):
"""
初始化 Redis 连接
Args:
config: Redis 配置字典,包含 host, port, password, db 等
"""
if self._initialized:
return
self.client = None
self.enabled = False
self.default_ttl = 3600 # 默认过期时间 1 小时
if not REDIS_AVAILABLE:
logger.warning("Redis 库未安装,缓存功能禁用")
self._initialized = True
return
if config:
self.connect(config)
self._initialized = True
def connect(self, config: Dict) -> bool:
"""
连接 Redis
Args:
config: Redis 配置
Returns:
是否连接成功
"""
if not REDIS_AVAILABLE:
return False
try:
self.client = redis.Redis(
host=config.get("host", "localhost"),
port=config.get("port", 6379),
password=config.get("password", None),
db=config.get("db", 0),
decode_responses=True,
socket_timeout=5,
socket_connect_timeout=5
)
# 测试连接
self.client.ping()
self.enabled = True
self.default_ttl = config.get("ttl", 3600)
logger.success(f"Redis 连接成功: {config.get('host')}:{config.get('port')}")
return True
except Exception as e:
logger.error(f"Redis 连接失败: {e}")
self.client = None
self.enabled = False
return False
def _make_key(self, prefix: str, *args) -> str:
"""
生成缓存 key
Args:
prefix: key 前缀
*args: key 组成部分
Returns:
完整的 key
"""
parts = [prefix] + [str(arg) for arg in args]
return ":".join(parts)
def get(self, key: str) -> Optional[Any]:
"""
获取缓存值
Args:
key: 缓存 key
Returns:
缓存的值,不存在返回 None
"""
if not self.enabled or not self.client:
return None
try:
value = self.client.get(key)
if value:
return json.loads(value)
return None
except Exception as e:
logger.error(f"Redis GET 失败: {key}, {e}")
return None
def set(self, key: str, value: Any, ttl: int = None) -> bool:
"""
设置缓存值
Args:
key: 缓存 key
value: 要缓存的值
ttl: 过期时间(秒),默认使用 default_ttl
Returns:
是否设置成功
"""
if not self.enabled or not self.client:
return False
try:
ttl = ttl or self.default_ttl
self.client.setex(key, ttl, json.dumps(value, ensure_ascii=False))
return True
except Exception as e:
logger.error(f"Redis SET 失败: {key}, {e}")
return False
def delete(self, key: str) -> bool:
"""
删除缓存
Args:
key: 缓存 key
Returns:
是否删除成功
"""
if not self.enabled or not self.client:
return False
try:
self.client.delete(key)
return True
except Exception as e:
logger.error(f"Redis DELETE 失败: {key}, {e}")
return False
def delete_pattern(self, pattern: str) -> int:
"""
删除匹配模式的所有 key
Args:
pattern: key 模式,如 "user_info:*"
Returns:
删除的 key 数量
"""
if not self.enabled or not self.client:
return 0
try:
keys = self.client.keys(pattern)
if keys:
return self.client.delete(*keys)
return 0
except Exception as e:
logger.error(f"Redis DELETE PATTERN 失败: {pattern}, {e}")
return 0
# ==================== 用户信息缓存专用方法 ====================
def get_user_info(self, chatroom_id: str, user_wxid: str) -> Optional[Dict]:
"""
获取缓存的用户信息
Args:
chatroom_id: 群聊 ID
user_wxid: 用户 wxid
Returns:
用户信息字典,不存在返回 None
"""
key = self._make_key("user_info", chatroom_id, user_wxid)
return self.get(key)
def set_user_info(self, chatroom_id: str, user_wxid: str, user_info: Dict, ttl: int = None) -> bool:
"""
缓存用户信息
Args:
chatroom_id: 群聊 ID
user_wxid: 用户 wxid
user_info: 用户信息字典
ttl: 过期时间(秒)
Returns:
是否缓存成功
"""
key = self._make_key("user_info", chatroom_id, user_wxid)
return self.set(key, user_info, ttl)
def get_user_basic_info(self, chatroom_id: str, user_wxid: str) -> Optional[Dict]:
"""
获取缓存的用户基本信息(昵称和头像)
Args:
chatroom_id: 群聊 ID
user_wxid: 用户 wxid
Returns:
包含 nickname 和 avatar_url 的字典
"""
user_info = self.get_user_info(chatroom_id, user_wxid)
if user_info:
# 提取基本信息
nickname = ""
if isinstance(user_info.get("nickName"), dict):
nickname = user_info.get("nickName", {}).get("string", "")
else:
nickname = user_info.get("nickName", "")
avatar_url = user_info.get("bigHeadImgUrl", "")
if nickname or avatar_url:
return {
"nickname": nickname,
"avatar_url": avatar_url
}
return None
def clear_user_cache(self, chatroom_id: str = None, user_wxid: str = None) -> int:
"""
清除用户信息缓存
Args:
chatroom_id: 群聊 ID为空则清除所有群
user_wxid: 用户 wxid为空则清除该群所有用户
Returns:
清除的缓存数量
"""
if chatroom_id and user_wxid:
key = self._make_key("user_info", chatroom_id, user_wxid)
return 1 if self.delete(key) else 0
elif chatroom_id:
pattern = self._make_key("user_info", chatroom_id, "*")
return self.delete_pattern(pattern)
else:
return self.delete_pattern("user_info:*")
def get_cache_stats(self) -> Dict:
"""
获取缓存统计信息
Returns:
统计信息字典
"""
if not self.enabled or not self.client:
return {"enabled": False}
try:
info = self.client.info("memory")
user_keys = len(self.client.keys("user_info:*"))
chat_keys = len(self.client.keys("chat_history:*"))
return {
"enabled": True,
"used_memory": info.get("used_memory_human", "unknown"),
"user_info_count": user_keys,
"chat_history_count": chat_keys
}
except Exception as e:
logger.error(f"获取缓存统计失败: {e}")
return {"enabled": True, "error": str(e)}
# ==================== 对话历史缓存专用方法 ====================
def get_chat_history(self, chat_id: str, max_messages: int = 100) -> list:
"""
获取对话历史
Args:
chat_id: 会话ID私聊为用户wxid群聊为 群ID:用户ID
max_messages: 最大返回消息数
Returns:
消息列表
"""
if not self.enabled or not self.client:
return []
try:
key = self._make_key("chat_history", chat_id)
# 使用 LRANGE 获取最近的消息(列表尾部是最新的)
data = self.client.lrange(key, -max_messages, -1)
return [json.loads(item) for item in data]
except Exception as e:
logger.error(f"获取对话历史失败: {chat_id}, {e}")
return []
def add_chat_message(self, chat_id: str, role: str, content, ttl: int = 86400) -> bool:
"""
添加消息到对话历史
Args:
chat_id: 会话ID
role: 角色 (user/assistant)
content: 消息内容(字符串或列表)
ttl: 过期时间默认24小时
Returns:
是否添加成功
"""
if not self.enabled or not self.client:
return False
try:
key = self._make_key("chat_history", chat_id)
message = {"role": role, "content": content}
self.client.rpush(key, json.dumps(message, ensure_ascii=False))
self.client.expire(key, ttl)
return True
except Exception as e:
logger.error(f"添加对话消息失败: {chat_id}, {e}")
return False
def trim_chat_history(self, chat_id: str, max_messages: int = 100) -> bool:
"""
裁剪对话历史保留最近的N条消息
Args:
chat_id: 会话ID
max_messages: 保留的最大消息数
Returns:
是否成功
"""
if not self.enabled or not self.client:
return False
try:
key = self._make_key("chat_history", chat_id)
# 保留最后 max_messages 条
self.client.ltrim(key, -max_messages, -1)
return True
except Exception as e:
logger.error(f"裁剪对话历史失败: {chat_id}, {e}")
return False
def clear_chat_history(self, chat_id: str) -> bool:
"""
清空指定会话的对话历史
Args:
chat_id: 会话ID
Returns:
是否成功
"""
if not self.enabled or not self.client:
return False
try:
key = self._make_key("chat_history", chat_id)
self.client.delete(key)
return True
except Exception as e:
logger.error(f"清空对话历史失败: {chat_id}, {e}")
return False
# ==================== 群聊历史记录专用方法 ====================
def get_group_history(self, group_id: str, max_messages: int = 100) -> list:
"""
获取群聊历史记录
Args:
group_id: 群聊ID
max_messages: 最大返回消息数
Returns:
消息列表,每条包含 nickname, content, timestamp
"""
if not self.enabled or not self.client:
return []
try:
key = self._make_key("group_history", group_id)
data = self.client.lrange(key, -max_messages, -1)
return [json.loads(item) for item in data]
except Exception as e:
logger.error(f"获取群聊历史失败: {group_id}, {e}")
return []
def add_group_message(self, group_id: str, nickname: str, content,
record_id: str = None, ttl: int = 86400) -> bool:
"""
添加消息到群聊历史
Args:
group_id: 群聊ID
nickname: 发送者昵称
content: 消息内容
record_id: 可选的记录ID用于后续更新
ttl: 过期时间默认24小时
Returns:
是否添加成功
"""
if not self.enabled or not self.client:
return False
try:
import time
key = self._make_key("group_history", group_id)
message = {
"nickname": nickname,
"content": content,
"timestamp": time.time()
}
if record_id:
message["id"] = record_id
self.client.rpush(key, json.dumps(message, ensure_ascii=False))
self.client.expire(key, ttl)
return True
except Exception as e:
logger.error(f"添加群聊消息失败: {group_id}, {e}")
return False
def update_group_message_by_id(self, group_id: str, record_id: str, new_content) -> bool:
"""
根据ID更新群聊历史中的消息
Args:
group_id: 群聊ID
record_id: 记录ID
new_content: 新内容
Returns:
是否更新成功
"""
if not self.enabled or not self.client:
return False
try:
key = self._make_key("group_history", group_id)
# 获取所有消息
data = self.client.lrange(key, 0, -1)
for i, item in enumerate(data):
msg = json.loads(item)
if msg.get("id") == record_id:
msg["content"] = new_content
self.client.lset(key, i, json.dumps(msg, ensure_ascii=False))
return True
return False
except Exception as e:
logger.error(f"更新群聊消息失败: {group_id}, {record_id}, {e}")
return False
def trim_group_history(self, group_id: str, max_messages: int = 100) -> bool:
"""
裁剪群聊历史保留最近的N条消息
Args:
group_id: 群聊ID
max_messages: 保留的最大消息数
Returns:
是否成功
"""
if not self.enabled or not self.client:
return False
try:
key = self._make_key("group_history", group_id)
self.client.ltrim(key, -max_messages, -1)
return True
except Exception as e:
logger.error(f"裁剪群聊历史失败: {group_id}, {e}")
return False
# ==================== 限流专用方法 ====================
def check_rate_limit(self, identifier: str, limit: int = 10,
window: int = 60, limit_type: str = "message") -> tuple:
"""
检查是否超过限流
使用滑动窗口算法
Args:
identifier: 标识符如用户wxid、群ID等
limit: 时间窗口内最大请求数
window: 时间窗口(秒)
limit_type: 限流类型message/ai_chat/image_gen等
Returns:
(是否允许, 剩余次数, 重置时间秒数)
"""
if not self.enabled or not self.client:
return (True, limit, 0) # Redis 不可用时不限流
try:
import time
key = self._make_key("rate_limit", limit_type, identifier)
now = time.time()
window_start = now - window
# 使用 pipeline 提高性能
pipe = self.client.pipeline()
# 移除过期的记录
pipe.zremrangebyscore(key, 0, window_start)
# 获取当前窗口内的请求数
pipe.zcard(key)
# 添加当前请求
pipe.zadd(key, {str(now): now})
# 设置过期时间
pipe.expire(key, window)
results = pipe.execute()
current_count = results[1] # zcard 的结果
if current_count >= limit:
# 获取最早的记录时间,计算重置时间
oldest = self.client.zrange(key, 0, 0, withscores=True)
if oldest:
reset_time = int(oldest[0][1] + window - now)
else:
reset_time = window
return (False, 0, max(reset_time, 1))
remaining = limit - current_count - 1
return (True, remaining, 0)
except Exception as e:
logger.error(f"限流检查失败: {identifier}, {e}")
return (True, limit, 0) # 出错时不限流
def get_rate_limit_status(self, identifier: str, limit: int = 10,
window: int = 60, limit_type: str = "message") -> Dict:
"""
获取限流状态(不增加计数)
Args:
identifier: 标识符
limit: 时间窗口内最大请求数
window: 时间窗口(秒)
limit_type: 限流类型
Returns:
状态字典
"""
if not self.enabled or not self.client:
return {"enabled": False, "current": 0, "limit": limit, "remaining": limit}
try:
import time
key = self._make_key("rate_limit", limit_type, identifier)
now = time.time()
window_start = now - window
# 移除过期记录并获取当前数量
self.client.zremrangebyscore(key, 0, window_start)
current = self.client.zcard(key)
return {
"enabled": True,
"current": current,
"limit": limit,
"remaining": max(0, limit - current),
"window": window
}
except Exception as e:
logger.error(f"获取限流状态失败: {identifier}, {e}")
return {"enabled": False, "error": str(e)}
def reset_rate_limit(self, identifier: str, limit_type: str = "message") -> bool:
"""
重置限流计数
Args:
identifier: 标识符
limit_type: 限流类型
Returns:
是否成功
"""
if not self.enabled or not self.client:
return False
try:
key = self._make_key("rate_limit", limit_type, identifier)
self.client.delete(key)
return True
except Exception as e:
logger.error(f"重置限流失败: {identifier}, {e}")
return False
# ==================== 媒体缓存专用方法 ====================
def cache_media(self, media_key: str, base64_data: str, media_type: str = "image", ttl: int = 300) -> bool:
"""
缓存媒体文件的 base64 数据
Args:
media_key: 媒体唯一标识(如 cdnurl 的 hash 或 aeskey
base64_data: base64 编码的媒体数据
media_type: 媒体类型image/emoji/video
ttl: 过期时间默认5分钟
Returns:
是否缓存成功
"""
if not self.enabled or not self.client:
return False
try:
key = self._make_key("media_cache", media_type, media_key)
# 直接存储 base64 字符串,不再 json 序列化
self.client.setex(key, ttl, base64_data)
logger.debug(f"媒体已缓存: {media_type}/{media_key[:20]}..., TTL={ttl}s")
return True
except Exception as e:
logger.error(f"缓存媒体失败: {media_key}, {e}")
return False
def get_cached_media(self, media_key: str, media_type: str = "image") -> Optional[str]:
"""
获取缓存的媒体 base64 数据
Args:
media_key: 媒体唯一标识
media_type: 媒体类型
Returns:
base64 数据,不存在返回 None
"""
if not self.enabled or not self.client:
return None
try:
key = self._make_key("media_cache", media_type, media_key)
data = self.client.get(key)
if data:
logger.debug(f"媒体缓存命中: {media_type}/{media_key[:20]}...")
return data
return None
except Exception as e:
logger.error(f"获取媒体缓存失败: {media_key}, {e}")
return None
def delete_cached_media(self, media_key: str, media_type: str = "image") -> bool:
"""
删除缓存的媒体
Args:
media_key: 媒体唯一标识
media_type: 媒体类型
Returns:
是否删除成功
"""
if not self.enabled or not self.client:
return False
try:
key = self._make_key("media_cache", media_type, media_key)
self.client.delete(key)
return True
except Exception as e:
logger.error(f"删除媒体缓存失败: {media_key}, {e}")
return False
@staticmethod
def generate_media_key(cdnurl: str = "", aeskey: str = "") -> str:
"""
根据 CDN URL 或 AES Key 生成媒体缓存 key
Args:
cdnurl: CDN URL
aeskey: AES Key
Returns:
缓存 key
"""
import hashlib
# 优先使用 aeskey更短更稳定否则使用 cdnurl 的 hash
if aeskey:
return aeskey[:32] # 取前32位作为 key
elif cdnurl:
return hashlib.md5(cdnurl.encode()).hexdigest()
return ""
def get_cache() -> Optional[RedisCache]:
"""
获取全局缓存实例
返回 RedisCache 单例实例。如果还没有初始化,返回一个未连接的实例。
建议在 MessageLogger 初始化后再调用此函数。
"""
return RedisCache._instance
def init_cache(config: Dict) -> RedisCache:
"""
初始化全局缓存实例
Args:
config: Redis 配置
Returns:
缓存实例
"""
global _cache_instance
_cache_instance = RedisCache(config)
return _cache_instance