""" Redis 缓存工具类 用于缓存用户信息等数据,减少 API 调用 """ import json from typing import Optional, Dict, Any from loguru import logger try: import redis REDIS_AVAILABLE = True except ImportError: REDIS_AVAILABLE = False logger.warning("redis 库未安装,缓存功能将不可用") class RedisCache: """Redis 缓存管理器""" _instance = None def __new__(cls, *args, **kwargs): """单例模式""" if cls._instance is None: cls._instance = super().__new__(cls) cls._instance._initialized = False return cls._instance def __init__(self, config: Dict = None): """ 初始化 Redis 连接 Args: config: Redis 配置字典,包含 host, port, password, db 等 """ if self._initialized: return self.client = None self.enabled = False self.default_ttl = 3600 # 默认过期时间 1 小时 if not REDIS_AVAILABLE: logger.warning("Redis 库未安装,缓存功能禁用") self._initialized = True return if config: self.connect(config) self._initialized = True def connect(self, config: Dict) -> bool: """ 连接 Redis Args: config: Redis 配置 Returns: 是否连接成功 """ if not REDIS_AVAILABLE: return False try: self.client = redis.Redis( host=config.get("host", "localhost"), port=config.get("port", 6379), password=config.get("password", None), db=config.get("db", 0), decode_responses=True, socket_timeout=5, socket_connect_timeout=5 ) # 测试连接 self.client.ping() self.enabled = True self.default_ttl = config.get("ttl", 3600) logger.success(f"Redis 连接成功: {config.get('host')}:{config.get('port')}") return True except Exception as e: logger.error(f"Redis 连接失败: {e}") self.client = None self.enabled = False return False def _make_key(self, prefix: str, *args) -> str: """ 生成缓存 key Args: prefix: key 前缀 *args: key 组成部分 Returns: 完整的 key """ parts = [prefix] + [str(arg) for arg in args] return ":".join(parts) def get(self, key: str) -> Optional[Any]: """ 获取缓存值 Args: key: 缓存 key Returns: 缓存的值,不存在返回 None """ if not self.enabled or not self.client: return None try: value = self.client.get(key) if value: return json.loads(value) return None except Exception as e: logger.error(f"Redis GET 失败: {key}, {e}") return None def set(self, key: str, value: Any, ttl: int = None) -> bool: """ 设置缓存值 Args: key: 缓存 key value: 要缓存的值 ttl: 过期时间(秒),默认使用 default_ttl Returns: 是否设置成功 """ if not self.enabled or not self.client: return False try: ttl = ttl or self.default_ttl self.client.setex(key, ttl, json.dumps(value, ensure_ascii=False)) return True except Exception as e: logger.error(f"Redis SET 失败: {key}, {e}") return False def delete(self, key: str) -> bool: """ 删除缓存 Args: key: 缓存 key Returns: 是否删除成功 """ if not self.enabled or not self.client: return False try: self.client.delete(key) return True except Exception as e: logger.error(f"Redis DELETE 失败: {key}, {e}") return False def delete_pattern(self, pattern: str) -> int: """ 删除匹配模式的所有 key Args: pattern: key 模式,如 "user_info:*" Returns: 删除的 key 数量 """ if not self.enabled or not self.client: return 0 try: keys = self.client.keys(pattern) if keys: return self.client.delete(*keys) return 0 except Exception as e: logger.error(f"Redis DELETE PATTERN 失败: {pattern}, {e}") return 0 # ==================== 用户信息缓存专用方法 ==================== def get_user_info(self, chatroom_id: str, user_wxid: str) -> Optional[Dict]: """ 获取缓存的用户信息 Args: chatroom_id: 群聊 ID user_wxid: 用户 wxid Returns: 用户信息字典,不存在返回 None """ key = self._make_key("user_info", chatroom_id, user_wxid) return self.get(key) def set_user_info(self, chatroom_id: str, user_wxid: str, user_info: Dict, ttl: int = None) -> bool: """ 缓存用户信息 Args: chatroom_id: 群聊 ID user_wxid: 用户 wxid user_info: 用户信息字典 ttl: 过期时间(秒) Returns: 是否缓存成功 """ key = self._make_key("user_info", chatroom_id, user_wxid) return self.set(key, user_info, ttl) def get_user_basic_info(self, chatroom_id: str, user_wxid: str) -> Optional[Dict]: """ 获取缓存的用户基本信息(昵称和头像) Args: chatroom_id: 群聊 ID user_wxid: 用户 wxid Returns: 包含 nickname 和 avatar_url 的字典 """ user_info = self.get_user_info(chatroom_id, user_wxid) if user_info: # 提取基本信息 nickname = "" if isinstance(user_info.get("nickName"), dict): nickname = user_info.get("nickName", {}).get("string", "") else: nickname = user_info.get("nickName", "") avatar_url = user_info.get("bigHeadImgUrl", "") if nickname or avatar_url: return { "nickname": nickname, "avatar_url": avatar_url } return None def clear_user_cache(self, chatroom_id: str = None, user_wxid: str = None) -> int: """ 清除用户信息缓存 Args: chatroom_id: 群聊 ID,为空则清除所有群 user_wxid: 用户 wxid,为空则清除该群所有用户 Returns: 清除的缓存数量 """ if chatroom_id and user_wxid: key = self._make_key("user_info", chatroom_id, user_wxid) return 1 if self.delete(key) else 0 elif chatroom_id: pattern = self._make_key("user_info", chatroom_id, "*") return self.delete_pattern(pattern) else: return self.delete_pattern("user_info:*") def get_cache_stats(self) -> Dict: """ 获取缓存统计信息 Returns: 统计信息字典 """ if not self.enabled or not self.client: return {"enabled": False} try: info = self.client.info("memory") user_keys = len(self.client.keys("user_info:*")) chat_keys = len(self.client.keys("chat_history:*")) return { "enabled": True, "used_memory": info.get("used_memory_human", "unknown"), "user_info_count": user_keys, "chat_history_count": chat_keys } except Exception as e: logger.error(f"获取缓存统计失败: {e}") return {"enabled": True, "error": str(e)} # ==================== 对话历史缓存专用方法 ==================== def get_chat_history(self, chat_id: str, max_messages: int = 100) -> list: """ 获取对话历史 Args: chat_id: 会话ID(私聊为用户wxid,群聊为 群ID:用户ID) max_messages: 最大返回消息数 Returns: 消息列表 """ if not self.enabled or not self.client: return [] try: key = self._make_key("chat_history", chat_id) # 使用 LRANGE 获取最近的消息(列表尾部是最新的) data = self.client.lrange(key, -max_messages, -1) return [json.loads(item) for item in data] except Exception as e: logger.error(f"获取对话历史失败: {chat_id}, {e}") return [] def add_chat_message(self, chat_id: str, role: str, content, ttl: int = 86400) -> bool: """ 添加消息到对话历史 Args: chat_id: 会话ID role: 角色 (user/assistant) content: 消息内容(字符串或列表) ttl: 过期时间(秒),默认24小时 Returns: 是否添加成功 """ if not self.enabled or not self.client: return False try: key = self._make_key("chat_history", chat_id) message = {"role": role, "content": content} self.client.rpush(key, json.dumps(message, ensure_ascii=False)) self.client.expire(key, ttl) return True except Exception as e: logger.error(f"添加对话消息失败: {chat_id}, {e}") return False def trim_chat_history(self, chat_id: str, max_messages: int = 100) -> bool: """ 裁剪对话历史,保留最近的N条消息 Args: chat_id: 会话ID max_messages: 保留的最大消息数 Returns: 是否成功 """ if not self.enabled or not self.client: return False try: key = self._make_key("chat_history", chat_id) # 保留最后 max_messages 条 self.client.ltrim(key, -max_messages, -1) return True except Exception as e: logger.error(f"裁剪对话历史失败: {chat_id}, {e}") return False def clear_chat_history(self, chat_id: str) -> bool: """ 清空指定会话的对话历史 Args: chat_id: 会话ID Returns: 是否成功 """ if not self.enabled or not self.client: return False try: key = self._make_key("chat_history", chat_id) self.client.delete(key) return True except Exception as e: logger.error(f"清空对话历史失败: {chat_id}, {e}") return False # ==================== 群聊历史记录专用方法 ==================== def get_group_history(self, group_id: str, max_messages: int = 100) -> list: """ 获取群聊历史记录 Args: group_id: 群聊ID max_messages: 最大返回消息数 Returns: 消息列表,每条包含 nickname, content, timestamp """ if not self.enabled or not self.client: return [] try: key = self._make_key("group_history", group_id) data = self.client.lrange(key, -max_messages, -1) return [json.loads(item) for item in data] except Exception as e: logger.error(f"获取群聊历史失败: {group_id}, {e}") return [] def add_group_message(self, group_id: str, nickname: str, content, record_id: str = None, ttl: int = 86400) -> bool: """ 添加消息到群聊历史 Args: group_id: 群聊ID nickname: 发送者昵称 content: 消息内容 record_id: 可选的记录ID,用于后续更新 ttl: 过期时间(秒),默认24小时 Returns: 是否添加成功 """ if not self.enabled or not self.client: return False try: import time key = self._make_key("group_history", group_id) message = { "nickname": nickname, "content": content, "timestamp": time.time() } if record_id: message["id"] = record_id self.client.rpush(key, json.dumps(message, ensure_ascii=False)) self.client.expire(key, ttl) return True except Exception as e: logger.error(f"添加群聊消息失败: {group_id}, {e}") return False def update_group_message_by_id(self, group_id: str, record_id: str, new_content) -> bool: """ 根据ID更新群聊历史中的消息 Args: group_id: 群聊ID record_id: 记录ID new_content: 新内容 Returns: 是否更新成功 """ if not self.enabled or not self.client: return False try: key = self._make_key("group_history", group_id) # 获取所有消息 data = self.client.lrange(key, 0, -1) for i, item in enumerate(data): msg = json.loads(item) if msg.get("id") == record_id: msg["content"] = new_content self.client.lset(key, i, json.dumps(msg, ensure_ascii=False)) return True return False except Exception as e: logger.error(f"更新群聊消息失败: {group_id}, {record_id}, {e}") return False def trim_group_history(self, group_id: str, max_messages: int = 100) -> bool: """ 裁剪群聊历史,保留最近的N条消息 Args: group_id: 群聊ID max_messages: 保留的最大消息数 Returns: 是否成功 """ if not self.enabled or not self.client: return False try: key = self._make_key("group_history", group_id) self.client.ltrim(key, -max_messages, -1) return True except Exception as e: logger.error(f"裁剪群聊历史失败: {group_id}, {e}") return False # ==================== 限流专用方法 ==================== def check_rate_limit(self, identifier: str, limit: int = 10, window: int = 60, limit_type: str = "message") -> tuple: """ 检查是否超过限流 使用滑动窗口算法 Args: identifier: 标识符(如用户wxid、群ID等) limit: 时间窗口内最大请求数 window: 时间窗口(秒) limit_type: 限流类型(message/ai_chat/image_gen等) Returns: (是否允许, 剩余次数, 重置时间秒数) """ if not self.enabled or not self.client: return (True, limit, 0) # Redis 不可用时不限流 try: import time key = self._make_key("rate_limit", limit_type, identifier) now = time.time() window_start = now - window # 使用 pipeline 提高性能 pipe = self.client.pipeline() # 移除过期的记录 pipe.zremrangebyscore(key, 0, window_start) # 获取当前窗口内的请求数 pipe.zcard(key) # 添加当前请求 pipe.zadd(key, {str(now): now}) # 设置过期时间 pipe.expire(key, window) results = pipe.execute() current_count = results[1] # zcard 的结果 if current_count >= limit: # 获取最早的记录时间,计算重置时间 oldest = self.client.zrange(key, 0, 0, withscores=True) if oldest: reset_time = int(oldest[0][1] + window - now) else: reset_time = window return (False, 0, max(reset_time, 1)) remaining = limit - current_count - 1 return (True, remaining, 0) except Exception as e: logger.error(f"限流检查失败: {identifier}, {e}") return (True, limit, 0) # 出错时不限流 def get_rate_limit_status(self, identifier: str, limit: int = 10, window: int = 60, limit_type: str = "message") -> Dict: """ 获取限流状态(不增加计数) Args: identifier: 标识符 limit: 时间窗口内最大请求数 window: 时间窗口(秒) limit_type: 限流类型 Returns: 状态字典 """ if not self.enabled or not self.client: return {"enabled": False, "current": 0, "limit": limit, "remaining": limit} try: import time key = self._make_key("rate_limit", limit_type, identifier) now = time.time() window_start = now - window # 移除过期记录并获取当前数量 self.client.zremrangebyscore(key, 0, window_start) current = self.client.zcard(key) return { "enabled": True, "current": current, "limit": limit, "remaining": max(0, limit - current), "window": window } except Exception as e: logger.error(f"获取限流状态失败: {identifier}, {e}") return {"enabled": False, "error": str(e)} def reset_rate_limit(self, identifier: str, limit_type: str = "message") -> bool: """ 重置限流计数 Args: identifier: 标识符 limit_type: 限流类型 Returns: 是否成功 """ if not self.enabled or not self.client: return False try: key = self._make_key("rate_limit", limit_type, identifier) self.client.delete(key) return True except Exception as e: logger.error(f"重置限流失败: {identifier}, {e}") return False # ==================== 媒体缓存专用方法 ==================== def cache_media(self, media_key: str, base64_data: str, media_type: str = "image", ttl: int = 300) -> bool: """ 缓存媒体文件的 base64 数据 Args: media_key: 媒体唯一标识(如 cdnurl 的 hash 或 aeskey) base64_data: base64 编码的媒体数据 media_type: 媒体类型(image/emoji/video) ttl: 过期时间(秒),默认5分钟 Returns: 是否缓存成功 """ if not self.enabled or not self.client: return False try: key = self._make_key("media_cache", media_type, media_key) # 直接存储 base64 字符串,不再 json 序列化 self.client.setex(key, ttl, base64_data) logger.debug(f"媒体已缓存: {media_type}/{media_key[:20]}..., TTL={ttl}s") return True except Exception as e: logger.error(f"缓存媒体失败: {media_key}, {e}") return False def get_cached_media(self, media_key: str, media_type: str = "image") -> Optional[str]: """ 获取缓存的媒体 base64 数据 Args: media_key: 媒体唯一标识 media_type: 媒体类型 Returns: base64 数据,不存在返回 None """ if not self.enabled or not self.client: return None try: key = self._make_key("media_cache", media_type, media_key) data = self.client.get(key) if data: logger.debug(f"媒体缓存命中: {media_type}/{media_key[:20]}...") return data return None except Exception as e: logger.error(f"获取媒体缓存失败: {media_key}, {e}") return None def delete_cached_media(self, media_key: str, media_type: str = "image") -> bool: """ 删除缓存的媒体 Args: media_key: 媒体唯一标识 media_type: 媒体类型 Returns: 是否删除成功 """ if not self.enabled or not self.client: return False try: key = self._make_key("media_cache", media_type, media_key) self.client.delete(key) return True except Exception as e: logger.error(f"删除媒体缓存失败: {media_key}, {e}") return False @staticmethod def generate_media_key(cdnurl: str = "", aeskey: str = "") -> str: """ 根据 CDN URL 或 AES Key 生成媒体缓存 key Args: cdnurl: CDN URL aeskey: AES Key Returns: 缓存 key """ import hashlib # 优先使用 aeskey(更短更稳定),否则使用 cdnurl 的 hash if aeskey: return aeskey[:32] # 取前32位作为 key elif cdnurl: return hashlib.md5(cdnurl.encode()).hexdigest() return "" def get_cache() -> Optional[RedisCache]: """ 获取全局缓存实例 返回 RedisCache 单例实例。如果还没有初始化,返回一个未连接的实例。 建议在 MessageLogger 初始化后再调用此函数。 """ return RedisCache._instance def init_cache(config: Dict) -> RedisCache: """ 初始化全局缓存实例 Args: config: Redis 配置 Returns: 缓存实例 """ global _cache_instance _cache_instance = RedisCache(config) return _cache_instance