Files
WechatHookBot/utils/redis_cache.py
2025-12-05 18:06:13 +08:00

745 lines
22 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Redis 缓存工具类
用于缓存用户信息等数据,减少 API 调用
"""
import json
from typing import Optional, Dict, Any
from loguru import logger
try:
import redis
REDIS_AVAILABLE = True
except ImportError:
REDIS_AVAILABLE = False
logger.warning("redis 库未安装,缓存功能将不可用")
class RedisCache:
"""Redis 缓存管理器"""
_instance = None
def __new__(cls, *args, **kwargs):
"""单例模式"""
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self, config: Dict = None):
"""
初始化 Redis 连接
Args:
config: Redis 配置字典,包含 host, port, password, db 等
"""
if self._initialized:
return
self.client = None
self.enabled = False
self.default_ttl = 3600 # 默认过期时间 1 小时
if not REDIS_AVAILABLE:
logger.warning("Redis 库未安装,缓存功能禁用")
self._initialized = True
return
if config:
self.connect(config)
self._initialized = True
def connect(self, config: Dict) -> bool:
"""
连接 Redis
Args:
config: Redis 配置
Returns:
是否连接成功
"""
if not REDIS_AVAILABLE:
return False
try:
self.client = redis.Redis(
host=config.get("host", "localhost"),
port=config.get("port", 6379),
password=config.get("password", None),
db=config.get("db", 0),
decode_responses=True,
socket_timeout=5,
socket_connect_timeout=5
)
# 测试连接
self.client.ping()
self.enabled = True
self.default_ttl = config.get("ttl", 3600)
logger.success(f"Redis 连接成功: {config.get('host')}:{config.get('port')}")
return True
except Exception as e:
logger.error(f"Redis 连接失败: {e}")
self.client = None
self.enabled = False
return False
def _make_key(self, prefix: str, *args) -> str:
"""
生成缓存 key
Args:
prefix: key 前缀
*args: key 组成部分
Returns:
完整的 key
"""
parts = [prefix] + [str(arg) for arg in args]
return ":".join(parts)
def get(self, key: str) -> Optional[Any]:
"""
获取缓存值
Args:
key: 缓存 key
Returns:
缓存的值,不存在返回 None
"""
if not self.enabled or not self.client:
return None
try:
value = self.client.get(key)
if value:
return json.loads(value)
return None
except Exception as e:
logger.error(f"Redis GET 失败: {key}, {e}")
return None
def set(self, key: str, value: Any, ttl: int = None) -> bool:
"""
设置缓存值
Args:
key: 缓存 key
value: 要缓存的值
ttl: 过期时间(秒),默认使用 default_ttl
Returns:
是否设置成功
"""
if not self.enabled or not self.client:
return False
try:
ttl = ttl or self.default_ttl
self.client.setex(key, ttl, json.dumps(value, ensure_ascii=False))
return True
except Exception as e:
logger.error(f"Redis SET 失败: {key}, {e}")
return False
def delete(self, key: str) -> bool:
"""
删除缓存
Args:
key: 缓存 key
Returns:
是否删除成功
"""
if not self.enabled or not self.client:
return False
try:
self.client.delete(key)
return True
except Exception as e:
logger.error(f"Redis DELETE 失败: {key}, {e}")
return False
def delete_pattern(self, pattern: str) -> int:
"""
删除匹配模式的所有 key
Args:
pattern: key 模式,如 "user_info:*"
Returns:
删除的 key 数量
"""
if not self.enabled or not self.client:
return 0
try:
keys = self.client.keys(pattern)
if keys:
return self.client.delete(*keys)
return 0
except Exception as e:
logger.error(f"Redis DELETE PATTERN 失败: {pattern}, {e}")
return 0
# ==================== 用户信息缓存专用方法 ====================
def get_user_info(self, chatroom_id: str, user_wxid: str) -> Optional[Dict]:
"""
获取缓存的用户信息
Args:
chatroom_id: 群聊 ID
user_wxid: 用户 wxid
Returns:
用户信息字典,不存在返回 None
"""
key = self._make_key("user_info", chatroom_id, user_wxid)
return self.get(key)
def set_user_info(self, chatroom_id: str, user_wxid: str, user_info: Dict, ttl: int = None) -> bool:
"""
缓存用户信息
Args:
chatroom_id: 群聊 ID
user_wxid: 用户 wxid
user_info: 用户信息字典
ttl: 过期时间(秒)
Returns:
是否缓存成功
"""
key = self._make_key("user_info", chatroom_id, user_wxid)
return self.set(key, user_info, ttl)
def get_user_basic_info(self, chatroom_id: str, user_wxid: str) -> Optional[Dict]:
"""
获取缓存的用户基本信息(昵称和头像)
Args:
chatroom_id: 群聊 ID
user_wxid: 用户 wxid
Returns:
包含 nickname 和 avatar_url 的字典
"""
user_info = self.get_user_info(chatroom_id, user_wxid)
if user_info:
# 提取基本信息
nickname = ""
if isinstance(user_info.get("nickName"), dict):
nickname = user_info.get("nickName", {}).get("string", "")
else:
nickname = user_info.get("nickName", "")
avatar_url = user_info.get("bigHeadImgUrl", "")
if nickname or avatar_url:
return {
"nickname": nickname,
"avatar_url": avatar_url
}
return None
def clear_user_cache(self, chatroom_id: str = None, user_wxid: str = None) -> int:
"""
清除用户信息缓存
Args:
chatroom_id: 群聊 ID为空则清除所有群
user_wxid: 用户 wxid为空则清除该群所有用户
Returns:
清除的缓存数量
"""
if chatroom_id and user_wxid:
key = self._make_key("user_info", chatroom_id, user_wxid)
return 1 if self.delete(key) else 0
elif chatroom_id:
pattern = self._make_key("user_info", chatroom_id, "*")
return self.delete_pattern(pattern)
else:
return self.delete_pattern("user_info:*")
def get_cache_stats(self) -> Dict:
"""
获取缓存统计信息
Returns:
统计信息字典
"""
if not self.enabled or not self.client:
return {"enabled": False}
try:
info = self.client.info("memory")
user_keys = len(self.client.keys("user_info:*"))
chat_keys = len(self.client.keys("chat_history:*"))
return {
"enabled": True,
"used_memory": info.get("used_memory_human", "unknown"),
"user_info_count": user_keys,
"chat_history_count": chat_keys
}
except Exception as e:
logger.error(f"获取缓存统计失败: {e}")
return {"enabled": True, "error": str(e)}
# ==================== 对话历史缓存专用方法 ====================
def get_chat_history(self, chat_id: str, max_messages: int = 100) -> list:
"""
获取对话历史
Args:
chat_id: 会话ID私聊为用户wxid群聊为 群ID:用户ID
max_messages: 最大返回消息数
Returns:
消息列表
"""
if not self.enabled or not self.client:
return []
try:
key = self._make_key("chat_history", chat_id)
# 使用 LRANGE 获取最近的消息(列表尾部是最新的)
data = self.client.lrange(key, -max_messages, -1)
return [json.loads(item) for item in data]
except Exception as e:
logger.error(f"获取对话历史失败: {chat_id}, {e}")
return []
def add_chat_message(self, chat_id: str, role: str, content, ttl: int = 86400) -> bool:
"""
添加消息到对话历史
Args:
chat_id: 会话ID
role: 角色 (user/assistant)
content: 消息内容(字符串或列表)
ttl: 过期时间默认24小时
Returns:
是否添加成功
"""
if not self.enabled or not self.client:
return False
try:
key = self._make_key("chat_history", chat_id)
message = {"role": role, "content": content}
self.client.rpush(key, json.dumps(message, ensure_ascii=False))
self.client.expire(key, ttl)
return True
except Exception as e:
logger.error(f"添加对话消息失败: {chat_id}, {e}")
return False
def trim_chat_history(self, chat_id: str, max_messages: int = 100) -> bool:
"""
裁剪对话历史保留最近的N条消息
Args:
chat_id: 会话ID
max_messages: 保留的最大消息数
Returns:
是否成功
"""
if not self.enabled or not self.client:
return False
try:
key = self._make_key("chat_history", chat_id)
# 保留最后 max_messages 条
self.client.ltrim(key, -max_messages, -1)
return True
except Exception as e:
logger.error(f"裁剪对话历史失败: {chat_id}, {e}")
return False
def clear_chat_history(self, chat_id: str) -> bool:
"""
清空指定会话的对话历史
Args:
chat_id: 会话ID
Returns:
是否成功
"""
if not self.enabled or not self.client:
return False
try:
key = self._make_key("chat_history", chat_id)
self.client.delete(key)
return True
except Exception as e:
logger.error(f"清空对话历史失败: {chat_id}, {e}")
return False
# ==================== 群聊历史记录专用方法 ====================
def get_group_history(self, group_id: str, max_messages: int = 100) -> list:
"""
获取群聊历史记录
Args:
group_id: 群聊ID
max_messages: 最大返回消息数
Returns:
消息列表,每条包含 nickname, content, timestamp
"""
if not self.enabled or not self.client:
return []
try:
key = self._make_key("group_history", group_id)
data = self.client.lrange(key, -max_messages, -1)
return [json.loads(item) for item in data]
except Exception as e:
logger.error(f"获取群聊历史失败: {group_id}, {e}")
return []
def add_group_message(self, group_id: str, nickname: str, content,
record_id: str = None, ttl: int = 86400) -> bool:
"""
添加消息到群聊历史
Args:
group_id: 群聊ID
nickname: 发送者昵称
content: 消息内容
record_id: 可选的记录ID用于后续更新
ttl: 过期时间默认24小时
Returns:
是否添加成功
"""
if not self.enabled or not self.client:
return False
try:
import time
key = self._make_key("group_history", group_id)
message = {
"nickname": nickname,
"content": content,
"timestamp": time.time()
}
if record_id:
message["id"] = record_id
self.client.rpush(key, json.dumps(message, ensure_ascii=False))
self.client.expire(key, ttl)
return True
except Exception as e:
logger.error(f"添加群聊消息失败: {group_id}, {e}")
return False
def update_group_message_by_id(self, group_id: str, record_id: str, new_content) -> bool:
"""
根据ID更新群聊历史中的消息
Args:
group_id: 群聊ID
record_id: 记录ID
new_content: 新内容
Returns:
是否更新成功
"""
if not self.enabled or not self.client:
return False
try:
key = self._make_key("group_history", group_id)
# 获取所有消息
data = self.client.lrange(key, 0, -1)
for i, item in enumerate(data):
msg = json.loads(item)
if msg.get("id") == record_id:
msg["content"] = new_content
self.client.lset(key, i, json.dumps(msg, ensure_ascii=False))
return True
return False
except Exception as e:
logger.error(f"更新群聊消息失败: {group_id}, {record_id}, {e}")
return False
def trim_group_history(self, group_id: str, max_messages: int = 100) -> bool:
"""
裁剪群聊历史保留最近的N条消息
Args:
group_id: 群聊ID
max_messages: 保留的最大消息数
Returns:
是否成功
"""
if not self.enabled or not self.client:
return False
try:
key = self._make_key("group_history", group_id)
self.client.ltrim(key, -max_messages, -1)
return True
except Exception as e:
logger.error(f"裁剪群聊历史失败: {group_id}, {e}")
return False
# ==================== 限流专用方法 ====================
def check_rate_limit(self, identifier: str, limit: int = 10,
window: int = 60, limit_type: str = "message") -> tuple:
"""
检查是否超过限流
使用滑动窗口算法
Args:
identifier: 标识符如用户wxid、群ID等
limit: 时间窗口内最大请求数
window: 时间窗口(秒)
limit_type: 限流类型message/ai_chat/image_gen等
Returns:
(是否允许, 剩余次数, 重置时间秒数)
"""
if not self.enabled or not self.client:
return (True, limit, 0) # Redis 不可用时不限流
try:
import time
key = self._make_key("rate_limit", limit_type, identifier)
now = time.time()
window_start = now - window
# 使用 pipeline 提高性能
pipe = self.client.pipeline()
# 移除过期的记录
pipe.zremrangebyscore(key, 0, window_start)
# 获取当前窗口内的请求数
pipe.zcard(key)
# 添加当前请求
pipe.zadd(key, {str(now): now})
# 设置过期时间
pipe.expire(key, window)
results = pipe.execute()
current_count = results[1] # zcard 的结果
if current_count >= limit:
# 获取最早的记录时间,计算重置时间
oldest = self.client.zrange(key, 0, 0, withscores=True)
if oldest:
reset_time = int(oldest[0][1] + window - now)
else:
reset_time = window
return (False, 0, max(reset_time, 1))
remaining = limit - current_count - 1
return (True, remaining, 0)
except Exception as e:
logger.error(f"限流检查失败: {identifier}, {e}")
return (True, limit, 0) # 出错时不限流
def get_rate_limit_status(self, identifier: str, limit: int = 10,
window: int = 60, limit_type: str = "message") -> Dict:
"""
获取限流状态(不增加计数)
Args:
identifier: 标识符
limit: 时间窗口内最大请求数
window: 时间窗口(秒)
limit_type: 限流类型
Returns:
状态字典
"""
if not self.enabled or not self.client:
return {"enabled": False, "current": 0, "limit": limit, "remaining": limit}
try:
import time
key = self._make_key("rate_limit", limit_type, identifier)
now = time.time()
window_start = now - window
# 移除过期记录并获取当前数量
self.client.zremrangebyscore(key, 0, window_start)
current = self.client.zcard(key)
return {
"enabled": True,
"current": current,
"limit": limit,
"remaining": max(0, limit - current),
"window": window
}
except Exception as e:
logger.error(f"获取限流状态失败: {identifier}, {e}")
return {"enabled": False, "error": str(e)}
def reset_rate_limit(self, identifier: str, limit_type: str = "message") -> bool:
"""
重置限流计数
Args:
identifier: 标识符
limit_type: 限流类型
Returns:
是否成功
"""
if not self.enabled or not self.client:
return False
try:
key = self._make_key("rate_limit", limit_type, identifier)
self.client.delete(key)
return True
except Exception as e:
logger.error(f"重置限流失败: {identifier}, {e}")
return False
# ==================== 媒体缓存专用方法 ====================
def cache_media(self, media_key: str, base64_data: str, media_type: str = "image", ttl: int = 300) -> bool:
"""
缓存媒体文件的 base64 数据
Args:
media_key: 媒体唯一标识(如 cdnurl 的 hash 或 aeskey
base64_data: base64 编码的媒体数据
media_type: 媒体类型image/emoji/video
ttl: 过期时间默认5分钟
Returns:
是否缓存成功
"""
if not self.enabled or not self.client:
return False
try:
key = self._make_key("media_cache", media_type, media_key)
# 直接存储 base64 字符串,不再 json 序列化
self.client.setex(key, ttl, base64_data)
logger.debug(f"媒体已缓存: {media_type}/{media_key[:20]}..., TTL={ttl}s")
return True
except Exception as e:
logger.error(f"缓存媒体失败: {media_key}, {e}")
return False
def get_cached_media(self, media_key: str, media_type: str = "image") -> Optional[str]:
"""
获取缓存的媒体 base64 数据
Args:
media_key: 媒体唯一标识
media_type: 媒体类型
Returns:
base64 数据,不存在返回 None
"""
if not self.enabled or not self.client:
return None
try:
key = self._make_key("media_cache", media_type, media_key)
data = self.client.get(key)
if data:
logger.debug(f"媒体缓存命中: {media_type}/{media_key[:20]}...")
return data
return None
except Exception as e:
logger.error(f"获取媒体缓存失败: {media_key}, {e}")
return None
def delete_cached_media(self, media_key: str, media_type: str = "image") -> bool:
"""
删除缓存的媒体
Args:
media_key: 媒体唯一标识
media_type: 媒体类型
Returns:
是否删除成功
"""
if not self.enabled or not self.client:
return False
try:
key = self._make_key("media_cache", media_type, media_key)
self.client.delete(key)
return True
except Exception as e:
logger.error(f"删除媒体缓存失败: {media_key}, {e}")
return False
@staticmethod
def generate_media_key(cdnurl: str = "", aeskey: str = "") -> str:
"""
根据 CDN URL 或 AES Key 生成媒体缓存 key
Args:
cdnurl: CDN URL
aeskey: AES Key
Returns:
缓存 key
"""
import hashlib
# 优先使用 aeskey更短更稳定否则使用 cdnurl 的 hash
if aeskey:
return aeskey[:32] # 取前32位作为 key
elif cdnurl:
return hashlib.md5(cdnurl.encode()).hexdigest()
return ""
def get_cache() -> Optional[RedisCache]:
"""
获取全局缓存实例
返回 RedisCache 单例实例。如果还没有初始化,返回一个未连接的实例。
建议在 MessageLogger 初始化后再调用此函数。
"""
return RedisCache._instance
def init_cache(config: Dict) -> RedisCache:
"""
初始化全局缓存实例
Args:
config: Redis 配置
Returns:
缓存实例
"""
global _cache_instance
_cache_instance = RedisCache(config)
return _cache_instance