From 7d3ef7009320bb10bafb1b422da9f0622ca1e556 Mon Sep 17 00:00:00 2001 From: shihao <3127647737@qq.com> Date: Fri, 5 Dec 2025 18:06:13 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E4=BC=98=E5=8C=96=E6=95=B4=E4=BD=93?= =?UTF-8?q?=E9=A1=B9=E7=9B=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- WechatHook/callbacks.py | 53 ++ bot.py | 362 +++++++++----- plugins/AIChat/main.py | 445 ++++++++++++----- plugins/ManagePlugin/main.py | 260 +++++++++- plugins/Menu/main.py | 100 ++++ plugins/MessageLogger/main.py | 103 +++- plugins/PerformanceMonitor/main.py | 64 ++- plugins/SignIn/main.py | 78 ++- requirements.txt | 1 + utils/bot_utils.py | 658 +++++++++++++++++++++++++ utils/plugin_base.py | 9 + utils/plugin_manager.py | 89 +++- utils/redis_cache.py | 744 +++++++++++++++++++++++++++++ 13 files changed, 2661 insertions(+), 305 deletions(-) create mode 100644 plugins/Menu/main.py create mode 100644 utils/bot_utils.py create mode 100644 utils/redis_cache.py diff --git a/WechatHook/callbacks.py b/WechatHook/callbacks.py index e60fe08..e07f2f1 100644 --- a/WechatHook/callbacks.py +++ b/WechatHook/callbacks.py @@ -126,6 +126,59 @@ def add_callback_handler(callback_handler): logger.debug(f"注册断开回调: {name}") +def remove_callback_handler(callback_handler): + """ + 移除回调处理器实例(修复内存泄漏) + + Args: + callback_handler: 包含回调方法的对象 + """ + global _GLOBAL_CONNECT_CALLBACK_LIST, _GLOBAL_RECV_CALLBACK_LIST, _GLOBAL_CLOSE_CALLBACK_LIST + + # 移除属于该处理器的所有回调 + _GLOBAL_CONNECT_CALLBACK_LIST = [ + f for f in _GLOBAL_CONNECT_CALLBACK_LIST + if not (hasattr(f, '__self__') and f.__self__ is callback_handler) + ] + _GLOBAL_RECV_CALLBACK_LIST = [ + f for f in _GLOBAL_RECV_CALLBACK_LIST + if not (hasattr(f, '__self__') and f.__self__ is callback_handler) + ] + _GLOBAL_CLOSE_CALLBACK_LIST = [ + f for f in _GLOBAL_CLOSE_CALLBACK_LIST + if not (hasattr(f, '__self__') and f.__self__ is callback_handler) + ] + + logger.debug(f"已移除回调处理器: {type(callback_handler).__name__}") + + +def clear_all_callbacks(): + """ + 清空所有回调(用于关闭时清理) + """ + global _GLOBAL_CONNECT_CALLBACK_LIST, _GLOBAL_RECV_CALLBACK_LIST, _GLOBAL_CLOSE_CALLBACK_LIST + + _GLOBAL_CONNECT_CALLBACK_LIST.clear() + _GLOBAL_RECV_CALLBACK_LIST.clear() + _GLOBAL_CLOSE_CALLBACK_LIST.clear() + + logger.debug("已清空所有回调") + + +def get_callback_count() -> dict: + """ + 获取回调数量统计 + + Returns: + 回调数量字典 + """ + return { + "connect": len(_GLOBAL_CONNECT_CALLBACK_LIST), + "recv": len(_GLOBAL_RECV_CALLBACK_LIST), + "close": len(_GLOBAL_CLOSE_CALLBACK_LIST) + } + + @WINFUNCTYPE(None, ctypes.c_void_p) def wechat_connect_callback(client_id): """ diff --git a/bot.py b/bot.py index eb9ff0f..83f4e4b 100644 --- a/bot.py +++ b/bot.py @@ -2,10 +2,19 @@ WechatHookBot - 主入口 基于个微大客户版 Hook API 的微信机器人框架 + +优化功能: +- 优先级消息队列 +- 自适应熔断器 +- 配置热更新 +- 性能监控 +- 优雅关闭 """ import asyncio +import signal import sys +import time import tomllib from pathlib import Path from loguru import logger @@ -13,6 +22,8 @@ from loguru import logger from WechatHook import NoveLoader, WechatHookClient from WechatHook.callbacks import ( add_callback_handler, + remove_callback_handler, + clear_all_callbacks, wechat_connect_callback, wechat_recv_callback, wechat_close_callback, @@ -23,7 +34,15 @@ from WechatHook.callbacks import ( from utils.hookbot import HookBot from utils.plugin_manager import PluginManager from utils.decorators import scheduler -# from database import KeyvalDB, MessageDB # 不需要数据库 +from utils.bot_utils import ( + PriorityMessageQueue, + MessagePriority, + PRIORITY_MESSAGE_TYPES, + AdaptiveCircuitBreaker, + ConfigWatcher, + PerformanceMonitor, + get_performance_monitor +) class BotService: @@ -37,17 +56,24 @@ class BotService: self.process_id = None # 微信进程 ID self.socket_client_id = None # Socket 客户端 ID self.is_running = False + self.is_shutting_down = False # 是否正在关闭 self.event_loop = None # 事件循环引用 - + # 消息队列和性能控制 - self.message_queue = None + self.message_queue: PriorityMessageQueue = None # 优先级消息队列 self.queue_config = {} self.concurrency_config = {} self.consumer_tasks = [] self.processing_semaphore = None - self.circuit_breaker_failures = 0 - self.circuit_breaker_open = False - self.circuit_breaker_last_failure = 0 + + # 自适应熔断器 + self.circuit_breaker: AdaptiveCircuitBreaker = None + + # 配置热更新 + self.config_watcher: ConfigWatcher = None + + # 性能监控 + self.performance_monitor: PerformanceMonitor = None @CONNECT_CALLBACK(in_class=True) def on_connect(self, client_id): @@ -85,118 +111,125 @@ class BotService: logger.error(f"消息入队失败: {e}") async def _enqueue_message(self, msg_type, data): - """将消息加入队列""" + """将消息加入优先级队列""" try: + # 记录收到消息 + if self.performance_monitor: + self.performance_monitor.record_message_received() + # 检查队列是否已满 - if self.message_queue.qsize() >= self.queue_config.get("max_size", 1000): + if self.message_queue.full(): overflow_strategy = self.queue_config.get("overflow_strategy", "drop_oldest") - + if overflow_strategy == "drop_oldest": - # 丢弃最旧的消息 - try: - self.message_queue.get_nowait() - logger.warning("队列已满,丢弃最旧消息") - except asyncio.QueueEmpty: - pass + # 丢弃优先级最低的消息 + if self.message_queue.drop_lowest_priority(): + logger.warning("队列已满,丢弃优先级最低的消息") + if self.performance_monitor: + self.performance_monitor.record_message_dropped() elif overflow_strategy == "sampling": - # 采样处理,随机丢弃 + # 采样处理,随机丢弃(但高优先级消息不丢弃) import random - if random.random() < 0.5: # 50% 概率丢弃 + priority = PRIORITY_MESSAGE_TYPES.get(msg_type, MessagePriority.NORMAL) + if priority < MessagePriority.HIGH and random.random() < 0.5: logger.debug("队列压力大,采样丢弃消息") + if self.performance_monitor: + self.performance_monitor.record_message_dropped() return else: # degrade - logger.warning("队列已满,降级处理") - return + # 降级处理(但高优先级消息不丢弃) + priority = PRIORITY_MESSAGE_TYPES.get(msg_type, MessagePriority.NORMAL) + if priority < MessagePriority.HIGH: + logger.warning("队列已满,降级处理") + if self.performance_monitor: + self.performance_monitor.record_message_dropped() + return + + # 将消息放入优先级队列 + await self.message_queue.put(msg_type, data) + + # 记录队列大小 + if self.performance_monitor: + self.performance_monitor.record_queue_size(self.message_queue.qsize()) - # 将消息放入队列 - await self.message_queue.put((msg_type, data)) - except Exception as e: logger.error(f"消息入队异常: {e}") async def _message_consumer(self, consumer_id: int): - """消息消费者协程""" - logger.info(f"消息消费者 {consumer_id} 已启动") - - while self.is_running: + """消息消费者协程 - 纯队列串行模式,避免并发触发风控""" + logger.info(f"消息消费者 {consumer_id} 已启动(串行模式)") + + while self.is_running and not self.is_shutting_down: try: # 从队列获取消息,设置超时避免无限等待 msg_type, data = await asyncio.wait_for( - self.message_queue.get(), + self.message_queue.get(), timeout=1.0 ) - + # 检查熔断器状态 - if self._check_circuit_breaker(): + if self.circuit_breaker and self.circuit_breaker.is_open(): logger.debug("熔断器开启,跳过消息处理") + self.circuit_breaker.record_rejection() + self.message_queue.task_done() continue - - # 创建并发任务,不等待完成 - timeout = self.concurrency_config.get("plugin_task_timeout_seconds", 5) - - # 使用信号量控制并发数量 - async def process_with_semaphore(): - async with self.processing_semaphore: - try: - await asyncio.wait_for( - self.hookbot.process_message(msg_type, data), - timeout=timeout - ) - self._reset_circuit_breaker() - except asyncio.TimeoutError: - logger.warning(f"消息处理超时 (>{timeout}s): type={msg_type}") - self._record_circuit_breaker_failure() - except Exception as e: - logger.error(f"消息处理异常: {e}") - self._record_circuit_breaker_failure() - - # 创建任务但不等待,实现真正并发 - asyncio.create_task(process_with_semaphore()) - + + # 串行处理:等待当前消息处理完成后再处理下一条 + timeout = self.concurrency_config.get("plugin_task_timeout_seconds", 720) + start_time = time.time() + + try: + await asyncio.wait_for( + self.hookbot.process_message(msg_type, data), + timeout=timeout + ) + # 记录成功 + processing_time = time.time() - start_time + if self.circuit_breaker: + self.circuit_breaker.record_success() + if self.performance_monitor: + self.performance_monitor.record_message_processed(processing_time) + + except asyncio.TimeoutError: + logger.warning(f"消息处理超时 (>{timeout}s): type={msg_type}") + if self.circuit_breaker: + self.circuit_breaker.record_failure() + if self.performance_monitor: + self.performance_monitor.record_message_failed() + + except Exception as e: + logger.error(f"消息处理异常: {e}") + if self.circuit_breaker: + self.circuit_breaker.record_failure() + if self.performance_monitor: + self.performance_monitor.record_message_failed() + # 标记任务完成 self.message_queue.task_done() - + + # 更新熔断器统计 + if self.performance_monitor and self.circuit_breaker: + self.performance_monitor.update_circuit_breaker_stats( + self.circuit_breaker.get_stats() + ) + + # 消息间隔,避免发送太快触发风控 + message_interval = self.concurrency_config.get("message_interval_ms", 100) + if message_interval > 0: + await asyncio.sleep(message_interval / 1000.0) + except asyncio.TimeoutError: # 队列为空,继续等待 continue + except asyncio.CancelledError: + # 任务被取消,退出循环 + logger.info(f"消费者 {consumer_id} 收到取消信号") + break except Exception as e: logger.error(f"消费者 {consumer_id} 异常: {e}") await asyncio.sleep(0.1) # 短暂休息避免忙等 - def _check_circuit_breaker(self) -> bool: - """检查熔断器状态""" - if not self.concurrency_config.get("enable_circuit_breaker", True): - return False - - if self.circuit_breaker_open: - # 检查是否可以尝试恢复 - import time - if time.time() - self.circuit_breaker_last_failure > 30: # 30秒后尝试恢复 - self.circuit_breaker_open = False - self.circuit_breaker_failures = 0 - logger.info("熔断器尝试恢复") - return False - return True - return False - - def _record_circuit_breaker_failure(self): - """记录熔断器失败""" - if not self.concurrency_config.get("enable_circuit_breaker", True): - return - - self.circuit_breaker_failures += 1 - threshold = self.concurrency_config.get("circuit_breaker_threshold", 5) - - if self.circuit_breaker_failures >= threshold: - import time - self.circuit_breaker_open = True - self.circuit_breaker_last_failure = time.time() - logger.warning(f"熔断器开启,连续失败 {self.circuit_breaker_failures} 次") - - def _reset_circuit_breaker(self): - """重置熔断器""" - if self.circuit_breaker_failures > 0: - self.circuit_breaker_failures = 0 + logger.info(f"消费者 {consumer_id} 已退出") @CLOSE_CALLBACK(in_class=True) def on_close(self, client_id): @@ -235,17 +268,37 @@ class BotService: # 初始化性能配置 self.queue_config = config.get("Queue", {}) self.concurrency_config = config.get("Concurrency", {}) - - # 创建消息队列 + + # 创建优先级消息队列 queue_size = self.queue_config.get("max_size", 1000) - self.message_queue = asyncio.Queue(maxsize=queue_size) - logger.info(f"消息队列已创建,容量: {queue_size}") - + self.message_queue = PriorityMessageQueue(maxsize=queue_size) + logger.info(f"优先级消息队列已创建,容量: {queue_size}") + # 创建并发控制信号量 max_concurrency = self.concurrency_config.get("plugin_max_concurrency", 8) self.processing_semaphore = asyncio.Semaphore(max_concurrency) logger.info(f"并发控制已设置,最大并发: {max_concurrency}") + # 创建自适应熔断器 + if self.concurrency_config.get("enable_circuit_breaker", True): + self.circuit_breaker = AdaptiveCircuitBreaker( + failure_threshold=self.concurrency_config.get("circuit_breaker_threshold", 10), + success_threshold=3, + initial_recovery_time=5.0, + max_recovery_time=300.0 + ) + logger.info("自适应熔断器已创建") + + # 创建性能监控器 + self.performance_monitor = get_performance_monitor() + logger.info("性能监控器已创建") + + # 创建配置热更新监听器 + self.config_watcher = ConfigWatcher("main_config.toml", check_interval=5.0) + self.config_watcher.register_callback(self._on_config_update) + await self.config_watcher.start() + logger.info("配置热更新监听器已启动") + # 不需要数据库(简化版本) # 获取 DLL 路径 @@ -340,6 +393,26 @@ class BotService: return True + def _on_config_update(self, new_config: dict): + """配置热更新回调""" + logger.info("正在应用新配置...") + + # 更新队列配置 + self.queue_config = new_config.get("Queue", self.queue_config) + + # 更新并发配置 + old_concurrency = self.concurrency_config + self.concurrency_config = new_config.get("Concurrency", self.concurrency_config) + + # 更新熔断器配置 + if self.circuit_breaker: + new_threshold = self.concurrency_config.get("circuit_breaker_threshold", 10) + if new_threshold != old_concurrency.get("circuit_breaker_threshold", 10): + self.circuit_breaker.failure_threshold = new_threshold + logger.info(f"熔断器阈值已更新: {new_threshold}") + + logger.success("配置热更新完成") + async def run(self): """运行机器人""" if not await self.initialize(): @@ -347,6 +420,15 @@ class BotService: self.is_running = True + # 启动定期性能报告 + async def periodic_stats(): + while self.is_running: + await asyncio.sleep(300) # 每5分钟输出一次 + if self.performance_monitor and self.is_running: + self.performance_monitor.print_stats() + + stats_task = asyncio.create_task(periodic_stats()) + try: logger.info("机器人正在运行,按 Ctrl+C 停止...") while self.is_running: @@ -354,44 +436,96 @@ class BotService: except KeyboardInterrupt: logger.info("收到停止信号...") finally: + stats_task.cancel() await self.stop() async def stop(self): - """停止机器人""" - logger.info("正在停止机器人...") - self.is_running = False + """优雅关闭机器人""" + if self.is_shutting_down: + return + self.is_shutting_down = True - # 停止消息消费者 + logger.info("=" * 60) + logger.info("正在优雅关闭机器人...") + logger.info("=" * 60) + + # 1. 停止接收新消息 + self.is_running = False + logger.info("[1/7] 停止接收新消息") + + # 2. 等待队列中的消息处理完成(带超时) + if self.message_queue and not self.message_queue.empty(): + queue_size = self.message_queue.qsize() + logger.info(f"[2/7] 等待队列中 {queue_size} 条消息处理完成...") + try: + await asyncio.wait_for( + self.message_queue.join(), + timeout=30 + ) + logger.info("[2/7] 队列消息已全部处理完成") + except asyncio.TimeoutError: + logger.warning("[2/7] 队列消息未在 30 秒内处理完成,强制清空") + # 清空剩余消息 + while not self.message_queue.empty(): + try: + self.message_queue.get_nowait() + self.message_queue.task_done() + except: + break + else: + logger.info("[2/7] 队列为空,无需等待") + + # 3. 停止消息消费者 if self.consumer_tasks: - logger.info("正在停止消息消费者...") + logger.info(f"[3/7] 停止 {len(self.consumer_tasks)} 个消息消费者...") for task in self.consumer_tasks: task.cancel() - - # 等待所有消费者任务完成 - if self.consumer_tasks: - await asyncio.gather(*self.consumer_tasks, return_exceptions=True) + await asyncio.gather(*self.consumer_tasks, return_exceptions=True) self.consumer_tasks.clear() - logger.info("消息消费者已停止") + logger.info("[3/7] 消息消费者已停止") + else: + logger.info("[3/7] 无消费者需要停止") - # 清空消息队列 - if self.message_queue: - while not self.message_queue.empty(): - try: - self.message_queue.get_nowait() - self.message_queue.task_done() - except asyncio.QueueEmpty: - break - logger.info("消息队列已清空") + # 4. 停止配置监听器 + if self.config_watcher: + logger.info("[4/7] 停止配置监听器...") + await self.config_watcher.stop() + logger.info("[4/7] 配置监听器已停止") + else: + logger.info("[4/7] 无配置监听器") - # 停止定时任务 + # 5. 卸载插件 + if self.plugin_manager: + logger.info("[5/7] 卸载插件...") + await self.plugin_manager.unload_plugins() + logger.info("[5/7] 插件已卸载") + else: + logger.info("[5/7] 无插件需要卸载") + + # 6. 停止定时任务 if scheduler.running: + logger.info("[6/7] 停止定时任务...") scheduler.shutdown() + logger.info("[6/7] 定时任务已停止") + else: + logger.info("[6/7] 定时任务未运行") + + # 7. 清理回调和销毁微信连接 + logger.info("[7/7] 清理资源...") + remove_callback_handler(self) + clear_all_callbacks() - # 销毁微信连接 if self.loader: self.loader.DestroyWeChat() - logger.success("机器人已停止") + # 输出最终性能报告 + if self.performance_monitor: + logger.info("最终性能报告:") + self.performance_monitor.print_stats() + + logger.success("=" * 60) + logger.success("机器人已优雅关闭") + logger.success("=" * 60) async def main(): diff --git a/plugins/AIChat/main.py b/plugins/AIChat/main.py index a3ff25c..7a10fce 100644 --- a/plugins/AIChat/main.py +++ b/plugins/AIChat/main.py @@ -2,6 +2,7 @@ AI 聊天插件 支持自定义模型、API 和人设 +支持 Redis 存储对话历史和限流 """ import asyncio @@ -12,6 +13,7 @@ from datetime import datetime from loguru import logger from utils.plugin_base import PluginBase from utils.decorators import on_text_message, on_quote_message, on_image_message, on_emoji_message +from utils.redis_cache import get_cache import xml.etree.ElementTree as ET import base64 import uuid @@ -95,6 +97,92 @@ class AIChat(PluginBase): else: return sender_wxid or from_wxid # 私聊使用用户ID + async def _get_user_nickname(self, bot, from_wxid: str, user_wxid: str, is_group: bool) -> str: + """ + 获取用户昵称,优先使用 Redis 缓存 + + Args: + bot: WechatHookClient 实例 + from_wxid: 消息来源(群聊ID或私聊用户ID) + user_wxid: 用户wxid + is_group: 是否群聊 + + Returns: + 用户昵称 + """ + if not is_group: + return "" + + nickname = "" + + # 1. 优先从 Redis 缓存获取 + redis_cache = get_cache() + if redis_cache and redis_cache.enabled: + cached_info = redis_cache.get_user_basic_info(from_wxid, user_wxid) + if cached_info and cached_info.get("nickname"): + logger.debug(f"[缓存命中] 用户昵称: {user_wxid} -> {cached_info['nickname']}") + return cached_info["nickname"] + + # 2. 缓存未命中,调用 API 获取 + try: + user_info = await bot.get_user_info_in_chatroom(from_wxid, user_wxid) + if user_info and user_info.get("nickName", {}).get("string"): + nickname = user_info["nickName"]["string"] + # 存入缓存 + if redis_cache and redis_cache.enabled: + redis_cache.set_user_info(from_wxid, user_wxid, user_info) + logger.debug(f"[已缓存] 用户昵称: {user_wxid} -> {nickname}") + return nickname + except Exception as e: + logger.warning(f"API获取用户昵称失败: {e}") + + # 3. 从 MessageLogger 数据库查询 + if not nickname: + try: + from plugins.MessageLogger.main import MessageLogger + msg_logger = MessageLogger.get_instance() + if msg_logger: + with msg_logger.get_db_connection() as conn: + with conn.cursor() as cursor: + cursor.execute( + "SELECT nickname FROM messages WHERE sender_wxid = %s AND nickname != '' ORDER BY create_time DESC LIMIT 1", + (user_wxid,) + ) + result = cursor.fetchone() + if result: + nickname = result[0] + except Exception as e: + logger.debug(f"从数据库获取昵称失败: {e}") + + # 4. 最后降级使用 wxid + if not nickname: + nickname = user_wxid or "未知用户" + + return nickname + + def _check_rate_limit(self, user_wxid: str) -> tuple: + """ + 检查用户是否超过限流 + + Args: + user_wxid: 用户wxid + + Returns: + (是否允许, 剩余次数, 重置时间秒数) + """ + rate_limit_config = self.config.get("rate_limit", {}) + if not rate_limit_config.get("enabled", True): + return (True, 999, 0) + + redis_cache = get_cache() + if not redis_cache or not redis_cache.enabled: + return (True, 999, 0) # Redis 不可用时不限流 + + limit = rate_limit_config.get("ai_chat_limit", 20) + window = rate_limit_config.get("ai_chat_window", 60) + + return redis_cache.check_rate_limit(user_wxid, limit, window, "ai_chat") + def _add_to_memory(self, chat_id: str, role: str, content, image_base64: str = None): """ 添加消息到记忆 @@ -108,9 +196,6 @@ class AIChat(PluginBase): if not self.config.get("memory", {}).get("enabled", False): return - if chat_id not in self.memory: - self.memory[chat_id] = [] - # 如果有图片,构建多模态内容 if image_base64: message_content = [ @@ -120,6 +205,22 @@ class AIChat(PluginBase): else: message_content = content + # 优先使用 Redis 存储 + redis_config = self.config.get("redis", {}) + if redis_config.get("use_redis_history", True): + redis_cache = get_cache() + if redis_cache and redis_cache.enabled: + ttl = redis_config.get("chat_history_ttl", 86400) + redis_cache.add_chat_message(chat_id, role, message_content, ttl=ttl) + # 裁剪历史 + max_messages = self.config["memory"]["max_messages"] + redis_cache.trim_chat_history(chat_id, max_messages) + return + + # 降级到内存存储 + if chat_id not in self.memory: + self.memory[chat_id] = [] + self.memory[chat_id].append({"role": role, "content": message_content}) # 限制记忆长度 @@ -131,16 +232,47 @@ class AIChat(PluginBase): """获取记忆中的消息""" if not self.config.get("memory", {}).get("enabled", False): return [] + + # 优先从 Redis 获取 + redis_config = self.config.get("redis", {}) + if redis_config.get("use_redis_history", True): + redis_cache = get_cache() + if redis_cache and redis_cache.enabled: + max_messages = self.config["memory"]["max_messages"] + return redis_cache.get_chat_history(chat_id, max_messages) + + # 降级到内存 return self.memory.get(chat_id, []) def _clear_memory(self, chat_id: str): """清空指定会话的记忆""" + # 清空 Redis + redis_config = self.config.get("redis", {}) + if redis_config.get("use_redis_history", True): + redis_cache = get_cache() + if redis_cache and redis_cache.enabled: + redis_cache.clear_chat_history(chat_id) + + # 同时清空内存 if chat_id in self.memory: del self.memory[chat_id] async def _download_and_encode_image(self, bot, cdnurl: str, aeskey: str) -> str: - """下载图片并转换为base64""" + """下载图片并转换为base64,优先从缓存获取""" try: + # 1. 优先从 Redis 缓存获取 + from utils.redis_cache import RedisCache + redis_cache = get_cache() + if redis_cache and redis_cache.enabled: + media_key = RedisCache.generate_media_key(cdnurl, aeskey) + if media_key: + cached_data = redis_cache.get_cached_media(media_key, "image") + if cached_data: + logger.debug(f"[缓存命中] 图片从 Redis 获取: {media_key[:20]}...") + return cached_data + + # 2. 缓存未命中,下载图片 + logger.debug(f"[缓存未命中] 开始下载图片...") temp_dir = Path(__file__).parent / "temp" temp_dir.mkdir(exist_ok=True) @@ -168,74 +300,114 @@ class AIChat(PluginBase): with open(save_path, "rb") as f: image_data = base64.b64encode(f.read()).decode() + base64_result = f"data:image/jpeg;base64,{image_data}" + + # 3. 缓存到 Redis(供后续使用) + if redis_cache and redis_cache.enabled and media_key: + redis_cache.cache_media(media_key, base64_result, "image", ttl=300) + logger.debug(f"[已缓存] 图片缓存到 Redis: {media_key[:20]}...") + try: Path(save_path).unlink() except: pass - return f"data:image/jpeg;base64,{image_data}" + return base64_result except Exception as e: logger.error(f"下载图片失败: {e}") return "" - async def _download_emoji_and_encode(self, cdn_url: str) -> str: - """下载表情包并转换为base64(HTTP 直接下载)""" - try: - # 替换 HTML 实体 - cdn_url = cdn_url.replace("&", "&") + async def _download_emoji_and_encode(self, cdn_url: str, max_retries: int = 3) -> str: + """下载表情包并转换为base64(HTTP 直接下载,带重试机制),优先从缓存获取""" + # 替换 HTML 实体 + cdn_url = cdn_url.replace("&", "&") - temp_dir = Path(__file__).parent / "temp" - temp_dir.mkdir(exist_ok=True) + # 1. 优先从 Redis 缓存获取 + from utils.redis_cache import RedisCache + redis_cache = get_cache() + media_key = RedisCache.generate_media_key(cdnurl=cdn_url) + if redis_cache and redis_cache.enabled and media_key: + cached_data = redis_cache.get_cached_media(media_key, "emoji") + if cached_data: + logger.debug(f"[缓存命中] 表情包从 Redis 获取: {media_key[:20]}...") + return cached_data - filename = f"temp_{uuid.uuid4().hex[:8]}.gif" - save_path = temp_dir / filename + # 2. 缓存未命中,下载表情包 + logger.debug(f"[缓存未命中] 开始下载表情包...") + temp_dir = Path(__file__).parent / "temp" + temp_dir.mkdir(exist_ok=True) - # 使用 aiohttp 下载 - timeout = aiohttp.ClientTimeout(total=30) + filename = f"temp_{uuid.uuid4().hex[:8]}.gif" + save_path = temp_dir / filename - # 配置代理 - connector = None - proxy_config = self.config.get("proxy", {}) - if proxy_config.get("enabled", False): - proxy_type = proxy_config.get("type", "socks5").upper() - proxy_host = proxy_config.get("host", "127.0.0.1") - proxy_port = proxy_config.get("port", 7890) - proxy_username = proxy_config.get("username") - proxy_password = proxy_config.get("password") + last_error = None - if proxy_username and proxy_password: - proxy_url = f"{proxy_type}://{proxy_username}:{proxy_password}@{proxy_host}:{proxy_port}" - else: - proxy_url = f"{proxy_type}://{proxy_host}:{proxy_port}" + for attempt in range(max_retries): + try: + # 使用 aiohttp 下载,每次重试增加超时时间 + timeout = aiohttp.ClientTimeout(total=30 + attempt * 15) - if PROXY_SUPPORT: - try: - connector = ProxyConnector.from_url(proxy_url) - except: - connector = None + # 配置代理 + connector = None + proxy_config = self.config.get("proxy", {}) + if proxy_config.get("enabled", False): + proxy_type = proxy_config.get("type", "socks5").upper() + proxy_host = proxy_config.get("host", "127.0.0.1") + proxy_port = proxy_config.get("port", 7890) + proxy_username = proxy_config.get("username") + proxy_password = proxy_config.get("password") - async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session: - async with session.get(cdn_url) as response: - if response.status == 200: - content = await response.read() - with open(save_path, "wb") as f: - f.write(content) + if proxy_username and proxy_password: + proxy_url = f"{proxy_type}://{proxy_username}:{proxy_password}@{proxy_host}:{proxy_port}" + else: + proxy_url = f"{proxy_type}://{proxy_host}:{proxy_port}" - # 编码为 base64 - image_data = base64.b64encode(content).decode() - - # 删除临时文件 + if PROXY_SUPPORT: try: - save_path.unlink() + connector = ProxyConnector.from_url(proxy_url) except: - pass + connector = None - return f"data:image/gif;base64,{image_data}" + async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session: + async with session.get(cdn_url) as response: + if response.status == 200: + content = await response.read() - return "" - except Exception as e: - logger.error(f"下载表情包失败: {e}") - return "" + if len(content) == 0: + logger.warning(f"表情包下载内容为空,重试 {attempt + 1}/{max_retries}") + continue + + # 编码为 base64 + image_data = base64.b64encode(content).decode() + + logger.debug(f"表情包下载成功,大小: {len(content)} 字节") + base64_result = f"data:image/gif;base64,{image_data}" + + # 3. 缓存到 Redis(供后续使用) + if redis_cache and redis_cache.enabled and media_key: + redis_cache.cache_media(media_key, base64_result, "emoji", ttl=300) + logger.debug(f"[已缓存] 表情包缓存到 Redis: {media_key[:20]}...") + + return base64_result + else: + logger.warning(f"表情包下载失败,状态码: {response.status},重试 {attempt + 1}/{max_retries}") + + except asyncio.TimeoutError: + last_error = "请求超时" + logger.warning(f"表情包下载超时,重试 {attempt + 1}/{max_retries}") + except aiohttp.ClientError as e: + last_error = str(e) + logger.warning(f"表情包下载网络错误: {e},重试 {attempt + 1}/{max_retries}") + except Exception as e: + last_error = str(e) + logger.warning(f"表情包下载异常: {e},重试 {attempt + 1}/{max_retries}") + + # 重试前等待(指数退避) + if attempt < max_retries - 1: + await asyncio.sleep(1 * (attempt + 1)) + + logger.error(f"表情包下载失败,已重试 {max_retries} 次: {last_error}") + return "" async def _generate_image_description(self, image_base64: str, prompt: str, config: dict) -> str: """ @@ -479,37 +651,8 @@ class AIChat(PluginBase): # 检查是否应该回复 should_reply = self._should_reply(message, content, bot_wxid) - # 获取用户昵称(用于历史记录) - nickname = "" - if is_group: - try: - user_info = await bot.get_user_info_in_chatroom(from_wxid, user_wxid) - if user_info and user_info.get("nickName", {}).get("string"): - nickname = user_info["nickName"]["string"] - except: - pass - - # 如果获取昵称失败,从 MessageLogger 数据库查询 - if not nickname: - from plugins.MessageLogger.main import MessageLogger - msg_logger = MessageLogger.get_instance() - if msg_logger: - try: - with msg_logger.get_db_connection() as conn: - with conn.cursor() as cursor: - cursor.execute( - "SELECT nickname FROM messages WHERE sender_wxid = %s AND nickname != '' ORDER BY create_time DESC LIMIT 1", - (user_wxid,) - ) - result = cursor.fetchone() - if result: - nickname = result[0] - except: - pass - - # 最后降级使用 wxid - if not nickname: - nickname = user_wxid or sender_wxid or "未知用户" + # 获取用户昵称(用于历史记录)- 使用缓存优化 + nickname = await self._get_user_nickname(bot, from_wxid, user_wxid, is_group) # 保存到群组历史记录(所有消息都保存,不管是否回复) if is_group: @@ -519,6 +662,16 @@ class AIChat(PluginBase): if not should_reply: return + # 限流检查(仅在需要回复时检查) + allowed, remaining, reset_time = self._check_rate_limit(user_wxid) + if not allowed: + rate_limit_config = self.config.get("rate_limit", {}) + msg = rate_limit_config.get("rate_limit_message", "⚠️ 消息太频繁了,请 {seconds} 秒后再试~") + msg = msg.format(seconds=reset_time) + await bot.send_text(from_wxid, msg) + logger.warning(f"用户 {user_wxid} 触发限流,{reset_time}秒后重置") + return False + # 提取实际消息内容(去除@) actual_content = self._extract_content(message, content) if not actual_content: @@ -1004,8 +1157,23 @@ class AIChat(PluginBase): json.dump(history, f, ensure_ascii=False, indent=2) temp_file.replace(history_file) + def _use_redis_for_group_history(self) -> bool: + """检查是否使用 Redis 存储群聊历史""" + redis_config = self.config.get("redis", {}) + if not redis_config.get("use_redis_history", True): + return False + redis_cache = get_cache() + return redis_cache and redis_cache.enabled + async def _load_history(self, chat_id: str) -> list: - """异步读取群聊历史, 用锁避免与写入冲突""" + """异步读取群聊历史, 优先使用 Redis""" + # 优先使用 Redis + if self._use_redis_for_group_history(): + redis_cache = get_cache() + max_history = self.config.get("history", {}).get("max_history", 100) + return redis_cache.get_group_history(chat_id, max_history) + + # 降级到文件存储 history_file = self._get_history_file(chat_id) if not history_file: return [] @@ -1015,6 +1183,10 @@ class AIChat(PluginBase): async def _save_history(self, chat_id: str, history: list): """异步写入群聊历史, 包含长度截断""" + # Redis 模式下不需要单独保存,add_group_message 已经处理 + if self._use_redis_for_group_history(): + return + history_file = self._get_history_file(chat_id) if not history_file: return @@ -1040,6 +1212,27 @@ class AIChat(PluginBase): if not self.config.get("history", {}).get("enabled", True): return + # 构建消息内容 + if image_base64: + message_content = [ + {"type": "text", "text": content}, + {"type": "image_url", "image_url": {"url": image_base64}} + ] + else: + message_content = content + + # 优先使用 Redis + if self._use_redis_for_group_history(): + redis_cache = get_cache() + redis_config = self.config.get("redis", {}) + ttl = redis_config.get("group_history_ttl", 172800) + redis_cache.add_group_message(chat_id, nickname, message_content, ttl=ttl) + # 裁剪历史 + max_history = self.config.get("history", {}).get("max_history", 100) + redis_cache.trim_group_history(chat_id, max_history) + return + + # 降级到文件存储 history_file = self._get_history_file(chat_id) if not history_file: return @@ -1050,17 +1243,10 @@ class AIChat(PluginBase): message_record = { "nickname": nickname, - "timestamp": datetime.now().isoformat() + "timestamp": datetime.now().isoformat(), + "content": message_content } - if image_base64: - message_record["content"] = [ - {"type": "text", "text": content}, - {"type": "image_url", "image_url": {"url": image_base64}} - ] - else: - message_record["content"] = content - history.append(message_record) max_history = self.config.get("history", {}).get("max_history", 100) if len(history) > max_history: @@ -1073,6 +1259,18 @@ class AIChat(PluginBase): if not self.config.get("history", {}).get("enabled", True): return + # 优先使用 Redis + if self._use_redis_for_group_history(): + redis_cache = get_cache() + redis_config = self.config.get("redis", {}) + ttl = redis_config.get("group_history_ttl", 172800) + redis_cache.add_group_message(chat_id, nickname, content, record_id=record_id, ttl=ttl) + # 裁剪历史 + max_history = self.config.get("history", {}).get("max_history", 100) + redis_cache.trim_group_history(chat_id, max_history) + return + + # 降级到文件存储 history_file = self._get_history_file(chat_id) if not history_file: return @@ -1097,6 +1295,13 @@ class AIChat(PluginBase): if not self.config.get("history", {}).get("enabled", True): return + # 优先使用 Redis + if self._use_redis_for_group_history(): + redis_cache = get_cache() + redis_cache.update_group_message_by_id(chat_id, record_id, new_content) + return + + # 降级到文件存储 history_file = self._get_history_file(chat_id) if not history_file: return @@ -1204,18 +1409,20 @@ class AIChat(PluginBase): return True logger.info(f"AI处理引用图片消息: {title_text[:50]}...") - - # 获取用户昵称 - nickname = "" - if is_group: - try: - user_info = await bot.get_user_info_in_chatroom(from_wxid, user_wxid) - if user_info and user_info.get("nickName", {}).get("string"): - nickname = user_info["nickName"]["string"] - logger.info(f"获取到用户昵称: {nickname}") - except Exception as e: - logger.error(f"获取用户昵称失败: {e}") - + + # 限流检查 + allowed, remaining, reset_time = self._check_rate_limit(user_wxid) + if not allowed: + rate_limit_config = self.config.get("rate_limit", {}) + msg = rate_limit_config.get("rate_limit_message", "⚠️ 消息太频繁了,请 {seconds} 秒后再试~") + msg = msg.format(seconds=reset_time) + await bot.send_text(from_wxid, msg) + logger.warning(f"用户 {user_wxid} 触发限流,{reset_time}秒后重置") + return False + + # 获取用户昵称 - 使用缓存优化 + nickname = await self._get_user_nickname(bot, from_wxid, user_wxid, is_group) + # 下载并编码图片 logger.info(f"开始下载图片: {cdnbigimgurl[:50]}...") image_base64 = await self._download_and_encode_image(bot, cdnbigimgurl, aeskey) @@ -1627,34 +1834,8 @@ class AIChat(PluginBase): if not is_emoji and not aeskey: return True - # 获取用户昵称 - nickname = "" - try: - user_info = await bot.get_user_info_in_chatroom(from_wxid, user_wxid) - if user_info and user_info.get("nickName", {}).get("string"): - nickname = user_info["nickName"]["string"] - except: - pass - - if not nickname: - from plugins.MessageLogger.main import MessageLogger - msg_logger = MessageLogger.get_instance() - if msg_logger: - try: - with msg_logger.get_db_connection() as conn: - with conn.cursor() as cursor: - cursor.execute( - "SELECT nickname FROM messages WHERE sender_wxid = %s AND nickname != '' ORDER BY create_time DESC LIMIT 1", - (user_wxid,) - ) - result = cursor.fetchone() - if result: - nickname = result[0] - except: - pass - - if not nickname: - nickname = user_wxid or sender_wxid or "未知用户" + # 获取用户昵称 - 使用缓存优化 + nickname = await self._get_user_nickname(bot, from_wxid, user_wxid, is_group) # 立即插入占位符到 history placeholder_id = str(uuid.uuid4()) diff --git a/plugins/ManagePlugin/main.py b/plugins/ManagePlugin/main.py index a616e64..1fb78fa 100644 --- a/plugins/ManagePlugin/main.py +++ b/plugins/ManagePlugin/main.py @@ -2,8 +2,17 @@ 插件管理插件 提供插件的热重载、启用、禁用等管理功能 +支持的指令: + /插件列表 - 查看所有插件状态 + /重载插件 <名称> - 重载指定插件 + /重载所有插件 - 重载所有插件 + /启用插件 <名称> - 启用指定插件 + /禁用插件 <名称> - 禁用指定插件 + /刷新插件 - 扫描并发现新插件 + /插件帮助 - 显示帮助信息 """ +import sys import tomllib from pathlib import Path from loguru import logger @@ -18,7 +27,10 @@ class ManagePlugin(PluginBase): # 插件元数据 description = "插件管理,支持热重载、启用、禁用" author = "ShiHao" - version = "1.0.0" + version = "2.0.0" + + # 最高加载优先级,确保最先加载 + load_priority = 100 def __init__(self): super().__init__() @@ -34,37 +46,72 @@ class ManagePlugin(PluginBase): self.admins = main_config.get("Bot", {}).get("admins", []) logger.info(f"插件管理插件已加载,管理员: {self.admins}") - @on_text_message() + def _check_admin(self, message: dict) -> bool: + """检查是否是管理员""" + sender_wxid = message.get("SenderWxid", "") + from_wxid = message.get("FromWxid", "") + is_group = message.get("IsGroup", False) + + # 私聊时 sender_wxid 可能为空,使用 from_wxid + user_wxid = sender_wxid if is_group else from_wxid + + return user_wxid in self.admins + + @on_text_message(priority=99) async def handle_command(self, bot, message: dict): """处理管理命令""" content = message.get("Content", "").strip() from_wxid = message.get("FromWxid", "") - sender_wxid = message.get("SenderWxid", "") - - logger.debug(f"ManagePlugin: content={content}, from={from_wxid}, sender={sender_wxid}, admins={self.admins}") # 检查权限 - if not self.admins or sender_wxid not in self.admins: - return + if not self._check_admin(message): + return True # 继续传递给其他插件 + + # 插件帮助 + if content == "/插件帮助" or content == "/plugin help": + await self._show_help(bot, from_wxid) + return False # 插件列表 - if content == "/插件列表" or content == "/plugins": + elif content == "/插件列表" or content == "/plugins": await self._list_plugins(bot, from_wxid) + return False + + # 重载所有插件 + elif content == "/重载所有插件" or content == "/reload all": + await self._reload_all_plugins(bot, from_wxid) + return False # 重载插件 elif content.startswith("/重载插件 ") or content.startswith("/reload "): plugin_name = content.split(maxsplit=1)[1].strip() await self._reload_plugin(bot, from_wxid, plugin_name) + return False # 启用插件 elif content.startswith("/启用插件 ") or content.startswith("/enable "): plugin_name = content.split(maxsplit=1)[1].strip() await self._enable_plugin(bot, from_wxid, plugin_name) + return False # 禁用插件 elif content.startswith("/禁用插件 ") or content.startswith("/disable "): plugin_name = content.split(maxsplit=1)[1].strip() await self._disable_plugin(bot, from_wxid, plugin_name) + return False + + # 刷新插件(发现新插件) + elif content == "/刷新插件" or content == "/refresh": + await self._refresh_plugins(bot, from_wxid) + return False + + # 加载新插件(从目录加载全新插件) + elif content.startswith("/加载插件 ") or content.startswith("/load "): + plugin_name = content.split(maxsplit=1)[1].strip() + await self._load_new_plugin(bot, from_wxid, plugin_name) + return False + + return True # 不是管理命令,继续传递 async def _list_plugins(self, bot, to_wxid: str): """列出所有插件""" @@ -159,3 +206,200 @@ class ManagePlugin(PluginBase): logger.info(f"插件 {plugin_name} 已被禁用") else: await bot.send_text(to_wxid, f"❌ 插件 {plugin_name} 禁用失败") + + async def _show_help(self, bot, to_wxid: str): + """显示帮助信息""" + help_text = """📦 插件管理帮助 + +/插件列表 - 查看所有插件状态 +/插件帮助 - 显示此帮助信息 + +/加载插件 <名称> - 加载新插件(无需重启) +/重载插件 <名称> - 热重载指定插件 +/重载所有插件 - 热重载所有插件 + +/启用插件 <名称> - 启用已禁用的插件 +/禁用插件 <名称> - 禁用指定插件 + +/刷新插件 - 扫描发现新插件 + +示例: +/加载插件 NewPlugin +/重载插件 AIChat +/禁用插件 Weather""" + await bot.send_text(to_wxid, help_text) + + async def _reload_all_plugins(self, bot, to_wxid: str): + """重载所有插件""" + pm = PluginManager() + + await bot.send_text(to_wxid, "⏳ 正在重载所有插件...") + + try: + # 清理插件相关的模块缓存 + modules_to_remove = [ + name for name in sys.modules.keys() + if name.startswith('plugins.') and 'ManagePlugin' not in name + ] + for module_name in modules_to_remove: + del sys.modules[module_name] + + # 重载所有插件 + reloaded = await pm.reload_plugins() + + if reloaded: + await bot.send_text( + to_wxid, + f"✅ 重载完成\n已加载 {len(reloaded)} 个插件:\n" + + "\n".join(f" • {name}" for name in reloaded) + ) + logger.success(f"已重载 {len(reloaded)} 个插件") + else: + await bot.send_text(to_wxid, "⚠️ 没有插件被重载") + + except Exception as e: + logger.error(f"重载所有插件失败: {e}") + await bot.send_text(to_wxid, f"❌ 重载失败: {e}") + + async def _refresh_plugins(self, bot, to_wxid: str): + """刷新插件列表,发现新插件""" + pm = PluginManager() + + try: + # 记录刷新前的插件数量 + old_count = len(pm.plugin_info) + + # 刷新插件列表 + await pm.refresh_plugins() + + # 计算新发现的插件 + new_count = len(pm.plugin_info) + new_plugins = new_count - old_count + + if new_plugins > 0: + # 获取新发现的插件名称 + new_plugin_names = [ + info["name"] for info in pm.plugin_info.values() + if not info.get("enabled", False) + ][-new_plugins:] + + await bot.send_text( + to_wxid, + f"✅ 发现 {new_plugins} 个新插件:\n" + + "\n".join(f" • {name}" for name in new_plugin_names) + + "\n\n使用 /启用插件 <名称> 来启用" + ) + logger.info(f"发现 {new_plugins} 个新插件: {new_plugin_names}") + else: + await bot.send_text(to_wxid, "ℹ️ 没有发现新插件") + + except Exception as e: + logger.error(f"刷新插件失败: {e}") + await bot.send_text(to_wxid, f"❌ 刷新失败: {e}") + + async def _load_new_plugin(self, bot, to_wxid: str, plugin_name: str): + """加载全新的插件(支持插件类名或目录名)""" + import os + import importlib + import inspect + + pm = PluginManager() + + # 检查是否已加载 + if plugin_name in pm.plugins: + await bot.send_text(to_wxid, f"ℹ️ 插件 {plugin_name} 已经加载,如需重载请使用 /重载插件") + return + + try: + # 尝试查找插件 + found = False + plugin_class = None + actual_plugin_name = None + + for dirname in os.listdir("plugins"): + dirpath = f"plugins/{dirname}" + if not os.path.isdir(dirpath) or not os.path.exists(f"{dirpath}/main.py"): + continue + + # 支持通过目录名或类名查找 + if dirname == plugin_name or dirname.lower() == plugin_name.lower(): + # 通过目录名匹配 + module_name = f"plugins.{dirname}.main" + + # 清理旧的模块缓存 + if module_name in sys.modules: + del sys.modules[module_name] + + # 导入模块 + module = importlib.import_module(module_name) + + # 查找插件类 + for name, obj in inspect.getmembers(module): + if (inspect.isclass(obj) and + issubclass(obj, PluginBase) and + obj != PluginBase): + plugin_class = obj + actual_plugin_name = obj.__name__ + found = True + break + + if found: + break + + # 尝试通过类名匹配 + try: + module_name = f"plugins.{dirname}.main" + if module_name in sys.modules: + del sys.modules[module_name] + + module = importlib.import_module(module_name) + + for name, obj in inspect.getmembers(module): + if (inspect.isclass(obj) and + issubclass(obj, PluginBase) and + obj != PluginBase and + (obj.__name__ == plugin_name or obj.__name__.lower() == plugin_name.lower())): + plugin_class = obj + actual_plugin_name = obj.__name__ + found = True + break + + if found: + break + except: + continue + + if not found or not plugin_class: + await bot.send_text( + to_wxid, + f"❌ 未找到插件 {plugin_name}\n" + f"请确认:\n" + f"1. plugins/{plugin_name}/main.py 存在\n" + f"2. main.py 中有继承 PluginBase 的类" + ) + return + + # 检查是否已加载(用实际类名再检查一次) + if actual_plugin_name in pm.plugins: + await bot.send_text(to_wxid, f"ℹ️ 插件 {actual_plugin_name} 已经加载") + return + + # 加载插件 + success = await pm._load_plugin_class(plugin_class) + + if success: + await bot.send_text( + to_wxid, + f"✅ 插件加载成功\n" + f"名称: {actual_plugin_name}\n" + f"版本: {plugin_class.version}\n" + f"作者: {plugin_class.author}" + ) + logger.success(f"新插件 {actual_plugin_name} 已加载") + else: + await bot.send_text(to_wxid, f"❌ 插件 {actual_plugin_name} 加载失败") + + except Exception as e: + import traceback + logger.error(f"加载新插件失败: {e}\n{traceback.format_exc()}") + await bot.send_text(to_wxid, f"❌ 加载失败: {e}") diff --git a/plugins/Menu/main.py b/plugins/Menu/main.py new file mode 100644 index 0000000..d3b4bdc --- /dev/null +++ b/plugins/Menu/main.py @@ -0,0 +1,100 @@ +""" +菜单插件 + +用户发送 /菜单、/帮助 等指令时,按顺序发送菜单图片 +图片命名格式:menu1.png、menu2.png、menu3.png ... +""" + +import re +import asyncio +from pathlib import Path +from loguru import logger +from utils.plugin_base import PluginBase +from utils.decorators import on_text_message + + +class Menu(PluginBase): + """菜单插件""" + + # 插件元数据 + description = "菜单插件,发送帮助图片" + author = "ShiHao" + version = "1.0.0" + + def __init__(self): + super().__init__() + self.menu_dir = None + self.trigger_commands = ["/菜单", "/帮助", "/help", "/menu"] + self.send_interval = 0.5 # 发送间隔(秒),避免发送过快 + + async def async_init(self): + """插件异步初始化""" + # 设置菜单图片目录 + self.menu_dir = Path(__file__).parent / "images" + self.menu_dir.mkdir(exist_ok=True) + + # 扫描现有图片 + images = self._get_menu_images() + logger.info(f"菜单插件已加载,图片目录: {self.menu_dir},找到 {len(images)} 张菜单图片") + + def _get_menu_images(self) -> list: + """ + 获取所有符合命名规范的菜单图片,按序号排序 + + 命名格式:menu1.png、menu2.jpg、menu3.jpeg 等 + """ + if not self.menu_dir or not self.menu_dir.exists(): + return [] + + # 匹配 menu + 数字 + 图片扩展名 + pattern = re.compile(r'^menu(\d+)\.(png|jpg|jpeg|gif|bmp)$', re.IGNORECASE) + + images = [] + for file in self.menu_dir.iterdir(): + if file.is_file(): + match = pattern.match(file.name) + if match: + seq_num = int(match.group(1)) + images.append((seq_num, file)) + + # 按序号排序 + images.sort(key=lambda x: x[0]) + + # 只返回文件路径 + return [img[1] for img in images] + + @on_text_message(priority=60) + async def handle_menu_command(self, bot, message: dict): + """处理菜单指令""" + content = message.get("Content", "").strip() + from_wxid = message.get("FromWxid", "") + + # 检查是否是菜单指令 + if content not in self.trigger_commands: + return True # 继续传递给其他插件 + + logger.info(f"收到菜单指令: {content}, from: {from_wxid}") + + # 获取菜单图片 + images = self._get_menu_images() + + if not images: + await bot.send_text(from_wxid, "暂无菜单图片,请联系管理员添加") + logger.warning(f"菜单图片目录为空: {self.menu_dir}") + return False + + # 按顺序发送图片 + for i, image_path in enumerate(images): + try: + await bot.send_image(from_wxid, str(image_path)) + logger.debug(f"已发送菜单图片 {i+1}/{len(images)}: {image_path.name}") + + # 发送间隔,避免发送过快 + if i < len(images) - 1: + await asyncio.sleep(self.send_interval) + + except Exception as e: + logger.error(f"发送菜单图片失败: {image_path}, 错误: {e}") + + logger.success(f"菜单图片发送完成,共 {len(images)} 张") + return False # 阻止继续传递 diff --git a/plugins/MessageLogger/main.py b/plugins/MessageLogger/main.py index c1fdb90..e72bca7 100644 --- a/plugins/MessageLogger/main.py +++ b/plugins/MessageLogger/main.py @@ -18,6 +18,7 @@ from utils.decorators import ( on_file_message, on_emoji_message ) +from utils.redis_cache import RedisCache, get_cache import pymysql from WechatHook import WechatHookClient from minio import Minio @@ -39,6 +40,7 @@ class MessageLogger(PluginBase): super().__init__() self.config = None self.db_config = None + self.redis_cache = None # Redis 缓存实例 # 创建独立的日志记录器 self._setup_logger() @@ -83,9 +85,22 @@ class MessageLogger(PluginBase): self.db_config = self.config["database"] + # 初始化 Redis 缓存 + redis_config = self.config.get("redis", {}) + if redis_config.get("enabled", False): + self.log.info("正在初始化 Redis 缓存...") + self.redis_cache = RedisCache(redis_config) + if self.redis_cache.enabled: + self.log.success(f"Redis 缓存初始化成功,TTL={redis_config.get('ttl', 3600)}秒") + else: + self.log.warning("Redis 缓存初始化失败,将不使用缓存") + self.redis_cache = None + else: + self.log.info("Redis 缓存未启用") + # 初始化 MinIO 客户端 self.minio_client = Minio( - "101.201.65.129:19000", + "115.190.113.141:19000", access_key="admin", secret_key="80012029Lz", secure=False @@ -216,7 +231,7 @@ class MessageLogger(PluginBase): return ("", "", "", "", "0") async def download_image_and_upload(self, bot, cdnurl: str, aeskey: str) -> str: - """下载图片并上传到 MinIO""" + """下载图片并上传到 MinIO,同时缓存 base64 供其他插件使用""" try: temp_file = Path(__file__).parent / f"temp_{uuid.uuid4().hex}.jpg" success = await bot.cdn_download(cdnurl, aeskey, str(temp_file), file_type=2) @@ -225,12 +240,26 @@ class MessageLogger(PluginBase): # 等待文件下载完成 import asyncio + import base64 for _ in range(50): if temp_file.exists() and temp_file.stat().st_size > 0: break await asyncio.sleep(0.1) if temp_file.exists() and temp_file.stat().st_size > 0: + # 读取文件并缓存 base64(供 AIChat 等插件使用) + with open(temp_file, "rb") as f: + image_data = f.read() + base64_data = f"data:image/jpeg;base64,{base64.b64encode(image_data).decode()}" + + # 缓存到 Redis(5分钟过期) + redis_cache = get_cache() + if redis_cache and redis_cache.enabled: + media_key = RedisCache.generate_media_key(cdnurl, aeskey) + if media_key: + redis_cache.cache_media(media_key, base64_data, "image", ttl=300) + self.log.debug(f"图片已缓存到 Redis: {media_key[:20]}...") + media_url = await self.upload_file_to_minio(str(temp_file), "images") temp_file.unlink() return media_url @@ -326,13 +355,25 @@ class MessageLogger(PluginBase): return "" async def download_and_upload(self, url: str, file_type: str, ext: str) -> str: - """下载文件并上传到 MinIO""" + """下载文件并上传到 MinIO,同时缓存 base64 供其他插件使用""" try: + import base64 # 下载文件 async with aiohttp.ClientSession() as session: async with session.get(url, timeout=aiohttp.ClientTimeout(total=10)) as resp: if resp.status == 200: data = await resp.read() + + # 缓存表情包 base64(供 AIChat 等插件使用) + if file_type == "emojis" and data: + redis_cache = get_cache() + if redis_cache and redis_cache.enabled: + base64_data = f"data:image/gif;base64,{base64.b64encode(data).decode()}" + media_key = RedisCache.generate_media_key(cdnurl=url) + if media_key: + redis_cache.cache_media(media_key, base64_data, "emoji", ttl=300) + self.log.debug(f"表情包已缓存到 Redis: {media_key[:20]}...") + # 保存到临时文件 temp_file = Path(__file__).parent / f"temp_{uuid.uuid4().hex}{ext}" temp_file.write_bytes(data) @@ -374,7 +415,7 @@ class MessageLogger(PluginBase): ) # 返回访问 URL - url = f"http://101.201.65.129:19000/{self.minio_bucket}/{object_name}" + url = f"http://115.190.113.141:19000/{self.minio_bucket}/{object_name}" self.log.debug(f"文件上传成功: {url}") return url @@ -405,29 +446,45 @@ class MessageLogger(PluginBase): avatar_url = "" if is_group and self.config["behavior"]["fetch_avatar"]: - try: - self.log.info(f"尝试获取用户信息: from_wxid={from_wxid}, sender_wxid={sender_wxid}") - user_info = await bot.get_user_info_in_chatroom(from_wxid, sender_wxid) - self.log.info(f"获取到用户信息: {user_info}") + cache_hit = False - if user_info: - # 处理不同的数据结构 - if isinstance(user_info.get("nickName"), dict): - nickname = user_info.get("nickName", {}).get("string", "") + # 1. 先尝试从 Redis 缓存获取 + if self.redis_cache and self.redis_cache.enabled: + cached_info = self.redis_cache.get_user_basic_info(from_wxid, sender_wxid) + if cached_info: + nickname = cached_info.get("nickname", "") + avatar_url = cached_info.get("avatar_url", "") + if nickname and avatar_url: + cache_hit = True + self.log.debug(f"[缓存命中] {sender_wxid}: {nickname}") + + # 2. 缓存未命中,调用 API 获取 + if not cache_hit: + try: + self.log.info(f"[缓存未命中] 调用API获取用户信息: {sender_wxid}") + user_info = await bot.get_user_info_in_chatroom(from_wxid, sender_wxid) + + if user_info: + # 处理不同的数据结构 + if isinstance(user_info.get("nickName"), dict): + nickname = user_info.get("nickName", {}).get("string", "") + else: + nickname = user_info.get("nickName", "") + + avatar_url = user_info.get("bigHeadImgUrl", "") + self.log.info(f"API获取成功: nickname={nickname}, avatar_url={avatar_url[:50] if avatar_url else ''}...") + + # 3. 将用户信息存入 Redis 缓存 + if self.redis_cache and self.redis_cache.enabled and nickname: + self.redis_cache.set_user_info(from_wxid, sender_wxid, user_info) + self.log.debug(f"[已缓存] {sender_wxid}: {nickname}") else: - nickname = user_info.get("nickName", "") + self.log.warning(f"用户信息为空: {sender_wxid}") - avatar_url = user_info.get("bigHeadImgUrl", "") - self.log.info(f"解析用户信息: nickname={nickname}, avatar_url={avatar_url[:50]}...") - else: - self.log.warning(f"用户信息为空: {sender_wxid}") + except Exception as e: + self.log.error(f"获取用户信息失败: {e}") - except Exception as e: - self.log.error(f"获取用户信息失败: {e}") - import traceback - self.log.error(f"详细错误: {traceback.format_exc()}") - - # 如果获取失败,从历史记录中查找 + # 4. 如果仍然没有获取到,从历史记录中查找 if not nickname or not avatar_url: self.log.info(f"尝试从历史记录获取用户信息: {sender_wxid}") try: diff --git a/plugins/PerformanceMonitor/main.py b/plugins/PerformanceMonitor/main.py index 90ac629..8b814ff 100644 --- a/plugins/PerformanceMonitor/main.py +++ b/plugins/PerformanceMonitor/main.py @@ -46,7 +46,7 @@ class PerformanceMonitor(PluginBase): if sender_wxid not in admins: return - if content in ["/性能", "/stats", "/状态"]: + if content in ["/性能", "/stats", "/状态", "/性能报告"]: stats_msg = await self._get_performance_stats(bot) await bot.send_text(from_wxid, stats_msg) return False # 阻止其他插件处理 @@ -66,11 +66,21 @@ class PerformanceMonitor(PluginBase): async def _get_performance_stats(self, bot) -> str: """获取性能统计信息""" try: + # 尝试使用新的性能监控器 + try: + from utils.bot_utils import get_performance_monitor + monitor = get_performance_monitor() + if monitor and monitor.message_received > 0: + return self._format_new_stats(monitor) + except ImportError: + pass + + # 降级到旧的统计方式 stats = await self._get_performance_data() - + # 格式化统计信息 uptime_hours = (time.time() - self.start_time) / 3600 - + msg = f"""📊 系统性能统计 🕐 运行时间: {uptime_hours:.1f} 小时 @@ -94,11 +104,57 @@ class PerformanceMonitor(PluginBase): • 过滤模式: {stats['ignore_mode']}""" return msg - + except Exception as e: logger.error(f"获取性能统计失败: {e}") return f"❌ 获取性能统计失败: {str(e)}" + def _format_new_stats(self, monitor) -> str: + """格式化新性能监控器的统计信息""" + stats = monitor.get_stats() + + # 基础信息 + msg = f"""📊 系统性能报告 + +🕐 运行时间: {stats['uptime_formatted']} + +📨 消息统计: + • 收到: {stats['messages']['received']} + • 处理: {stats['messages']['processed']} + • 失败: {stats['messages']['failed']} + • 丢弃: {stats['messages']['dropped']} + • 成功率: {stats['messages']['success_rate']} + • 处理速率: {stats['messages']['processing_rate']} + +⚡ 处理性能: + • 平均耗时: {stats['processing_time']['average_ms']}ms + • 最大耗时: {stats['processing_time']['max_ms']}ms + • 最小耗时: {stats['processing_time']['min_ms']}ms + +📦 队列状态: + • 当前大小: {stats['queue']['current_size']} + • 历史最大: {stats['queue']['max_size']}""" + + # 熔断器状态 + cb = stats.get('circuit_breaker', {}) + if cb: + state_icon = {'closed': '🟢', 'open': '🔴', 'half_open': '🟡'}.get(cb.get('state', ''), '⚪') + msg += f""" + +🔌 熔断器: + • 状态: {state_icon} {cb.get('state', 'N/A')} + • 失败计数: {cb.get('failure_count', 0)} + • 恢复时间: {cb.get('current_recovery_time', 0):.0f}s""" + + # 插件耗时排行 + plugins = stats.get('plugins', []) + if plugins: + msg += "\n\n🔧 插件耗时排行:" + for i, p in enumerate(plugins[:5], 1): + msg += f"\n {i}. {p['name']}: {p['avg_time_ms']}ms ({p['calls']}次)" + + return msg + async def _get_performance_data(self) -> dict: """获取性能数据""" # 系统资源(简化版本,不依赖psutil) diff --git a/plugins/SignIn/main.py b/plugins/SignIn/main.py index 358dd7b..3dea87a 100644 --- a/plugins/SignIn/main.py +++ b/plugins/SignIn/main.py @@ -19,6 +19,7 @@ import pymysql from loguru import logger from utils.plugin_base import PluginBase from utils.decorators import on_text_message +from utils.redis_cache import get_cache from WechatHook import WechatHookClient try: @@ -49,14 +50,14 @@ class SignInPlugin(PluginBase): self.config = tomllib.load(f) self.db_config = self.config["database"] - + # 创建临时文件夹 self.temp_dir = Path(__file__).parent / "temp" self.temp_dir.mkdir(exist_ok=True) - + # 图片文件夹 self.images_dir = Path(__file__).parent / "images" - + logger.success("签到插件初始化完成") def get_db_connection(self): @@ -149,29 +150,42 @@ class SignInPlugin(PluginBase): logger.error(f"更新用户昵称失败: {e}") return False - async def get_user_nickname_from_group(self, client: WechatHookClient, + async def get_user_nickname_from_group(self, client: WechatHookClient, group_wxid: str, user_wxid: str) -> str: - """从群聊中获取用户昵称""" + """从群聊中获取用户昵称(优先使用缓存)""" try: - logger.debug(f"尝试获取用户 {user_wxid} 在群 {group_wxid} 中的昵称") - - # 使用11174 API获取单个用户的详细信息 + # 动态获取缓存实例(由 MessageLogger 初始化) + redis_cache = get_cache() + + # 1. 先尝试从 Redis 缓存获取 + if redis_cache and redis_cache.enabled: + cached_info = redis_cache.get_user_basic_info(group_wxid, user_wxid) + if cached_info and cached_info.get("nickname"): + logger.debug(f"[缓存命中] {user_wxid}: {cached_info['nickname']}") + return cached_info["nickname"] + + # 2. 缓存未命中,调用 API 获取 + logger.debug(f"[缓存未命中] 调用API获取用户昵称: {user_wxid}") user_info = await client.get_user_info_in_chatroom(group_wxid, user_wxid) - + if user_info: # 从返回的详细信息中提取昵称 nickname = user_info.get("nickName", {}).get("string", "") - + if nickname: - logger.success(f"获取到用户昵称: {user_wxid} -> {nickname}") + logger.success(f"API获取用户昵称成功: {user_wxid} -> {nickname}") + # 3. 将用户信息存入缓存 + if redis_cache and redis_cache.enabled: + redis_cache.set_user_info(group_wxid, user_wxid, user_info) + logger.debug(f"[已缓存] {user_wxid}: {nickname}") return nickname else: logger.warning(f"用户 {user_wxid} 的昵称字段为空") else: logger.warning(f"未找到用户 {user_wxid} 在群 {group_wxid} 中的信息") - + return "" - + except Exception as e: logger.error(f"获取群成员昵称失败: {e}") return "" @@ -770,13 +784,24 @@ class SignInPlugin(PluginBase): current_points = updated_user["points"] if updated_user else total_earned updated_user["points"] = current_points - # 尝试获取用户头像 + # 尝试获取用户头像(优先使用缓存) avatar_url = None if is_group: try: - user_detail = await client.get_user_info_in_chatroom(from_wxid, user_wxid) - if user_detail: - avatar_url = user_detail.get("bigHeadImgUrl", "") + redis_cache = get_cache() + # 先从缓存获取 + if redis_cache and redis_cache.enabled: + cached_info = redis_cache.get_user_basic_info(from_wxid, user_wxid) + if cached_info: + avatar_url = cached_info.get("avatar_url", "") + # 缓存未命中则调用 API + if not avatar_url: + user_detail = await client.get_user_info_in_chatroom(from_wxid, user_wxid) + if user_detail: + avatar_url = user_detail.get("bigHeadImgUrl", "") + # 存入缓存 + if redis_cache and redis_cache.enabled: + redis_cache.set_user_info(from_wxid, user_wxid, user_detail) except Exception as e: logger.warning(f"获取用户头像失败: {e}") @@ -864,13 +889,24 @@ class SignInPlugin(PluginBase): self.update_user_nickname(user_wxid, nickname) user_info["nickname"] = nickname - # 尝试获取用户头像 + # 尝试获取用户头像(优先使用缓存) avatar_url = None if is_group: try: - user_detail = await client.get_user_info_in_chatroom(from_wxid, user_wxid) - if user_detail: - avatar_url = user_detail.get("bigHeadImgUrl", "") + redis_cache = get_cache() + # 先从缓存获取 + if redis_cache and redis_cache.enabled: + cached_info = redis_cache.get_user_basic_info(from_wxid, user_wxid) + if cached_info: + avatar_url = cached_info.get("avatar_url", "") + # 缓存未命中则调用 API + if not avatar_url: + user_detail = await client.get_user_info_in_chatroom(from_wxid, user_wxid) + if user_detail: + avatar_url = user_detail.get("bigHeadImgUrl", "") + # 存入缓存 + if redis_cache and redis_cache.enabled: + redis_cache.set_user_info(from_wxid, user_wxid, user_detail) except Exception as e: logger.warning(f"获取用户头像失败: {e}") diff --git a/requirements.txt b/requirements.txt index 548cbda..32a16c8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,3 +3,4 @@ APScheduler==3.11.0 aiohttp==3.9.1 Pillow>=10.0.0 aiohttp-socks>=0.8.0 +redis>=5.0.0 diff --git a/utils/bot_utils.py b/utils/bot_utils.py new file mode 100644 index 0000000..e7b2812 --- /dev/null +++ b/utils/bot_utils.py @@ -0,0 +1,658 @@ +""" +机器人核心工具模块 + +包含: +- 优先级消息队列 +- 自适应熔断器 +- 请求重试机制 +- 配置热更新 +- 性能监控 +""" + +import asyncio +import time +import heapq +import os +import tomllib +import functools +from pathlib import Path +from typing import Dict, List, Optional, Callable, Any, Tuple +from dataclasses import dataclass, field +from loguru import logger +import aiohttp + + +# ==================== 优先级消息队列 ==================== + +# 消息优先级定义 +class MessagePriority: + """消息优先级常量""" + CRITICAL = 100 # 系统消息、登录信息 + HIGH = 80 # 管理员命令 + NORMAL = 50 # @机器人的消息 + LOW = 20 # 普通群消息 + +# 高优先级消息类型 +PRIORITY_MESSAGE_TYPES = { + 11025: MessagePriority.CRITICAL, # 登录信息 + 11058: MessagePriority.CRITICAL, # 系统消息 + 11098: MessagePriority.HIGH, # 群成员加入 + 11099: MessagePriority.HIGH, # 群成员退出 + 11100: MessagePriority.HIGH, # 群信息变更 + 11056: MessagePriority.HIGH, # 好友请求 +} + + +@dataclass(order=True) +class PriorityMessage: + """优先级消息包装""" + priority: int + timestamp: float = field(compare=False) + msg_type: int = field(compare=False) + data: dict = field(compare=False) + + def __init__(self, msg_type: int, data: dict, priority: int = None): + # 优先级越高,数值越小(因为heapq是最小堆) + self.priority = -(priority or PRIORITY_MESSAGE_TYPES.get(msg_type, MessagePriority.NORMAL)) + self.timestamp = time.time() + self.msg_type = msg_type + self.data = data + + +class PriorityMessageQueue: + """优先级消息队列""" + + def __init__(self, maxsize: int = 1000): + self.maxsize = maxsize + self._heap: List[PriorityMessage] = [] + self._lock = asyncio.Lock() + self._not_empty = asyncio.Event() + self._unfinished_tasks = 0 + self._finished = asyncio.Event() + self._finished.set() + + def qsize(self) -> int: + """返回队列大小""" + return len(self._heap) + + def empty(self) -> bool: + """队列是否为空""" + return len(self._heap) == 0 + + def full(self) -> bool: + """队列是否已满""" + return len(self._heap) >= self.maxsize + + async def put(self, msg_type: int, data: dict, priority: int = None): + """添加消息到队列""" + async with self._lock: + msg = PriorityMessage(msg_type, data, priority) + heapq.heappush(self._heap, msg) + self._unfinished_tasks += 1 + self._finished.clear() + self._not_empty.set() + + async def get(self) -> Tuple[int, dict]: + """获取优先级最高的消息""" + while True: + async with self._lock: + if self._heap: + msg = heapq.heappop(self._heap) + if not self._heap: + self._not_empty.clear() + return (msg.msg_type, msg.data) + + # 等待新消息 + await self._not_empty.wait() + + def get_nowait(self) -> Tuple[int, dict]: + """非阻塞获取消息""" + if not self._heap: + raise asyncio.QueueEmpty() + msg = heapq.heappop(self._heap) + if not self._heap: + self._not_empty.clear() + return (msg.msg_type, msg.data) + + def task_done(self): + """标记任务完成""" + self._unfinished_tasks -= 1 + if self._unfinished_tasks == 0: + self._finished.set() + + async def join(self): + """等待所有任务完成""" + await self._finished.wait() + + def drop_lowest_priority(self) -> bool: + """丢弃优先级最低的消息""" + if not self._heap: + return False + + # 找到优先级最低的消息(priority值最大,因为是负数所以最小) + min_idx = 0 + for i, msg in enumerate(self._heap): + if msg.priority > self._heap[min_idx].priority: + min_idx = i + + # 删除该消息 + self._heap.pop(min_idx) + heapq.heapify(self._heap) + self._unfinished_tasks -= 1 + return True + + +# ==================== 自适应熔断器 ==================== + +class AdaptiveCircuitBreaker: + """自适应熔断器""" + + # 熔断器状态 + STATE_CLOSED = "closed" # 正常状态 + STATE_OPEN = "open" # 熔断状态 + STATE_HALF_OPEN = "half_open" # 半开状态(尝试恢复) + + def __init__( + self, + failure_threshold: int = 5, + success_threshold: int = 3, + initial_recovery_time: float = 5.0, + max_recovery_time: float = 300.0, + recovery_multiplier: float = 2.0 + ): + """ + 初始化熔断器 + + Args: + failure_threshold: 触发熔断的连续失败次数 + success_threshold: 恢复正常的连续成功次数 + initial_recovery_time: 初始恢复等待时间(秒) + max_recovery_time: 最大恢复等待时间(秒) + recovery_multiplier: 恢复时间增长倍数 + """ + self.failure_threshold = failure_threshold + self.success_threshold = success_threshold + self.initial_recovery_time = initial_recovery_time + self.max_recovery_time = max_recovery_time + self.recovery_multiplier = recovery_multiplier + + # 状态 + self.state = self.STATE_CLOSED + self.failure_count = 0 + self.success_count = 0 + self.last_failure_time = 0 + self.current_recovery_time = initial_recovery_time + + # 统计 + self.total_failures = 0 + self.total_successes = 0 + self.total_rejections = 0 + + def is_open(self) -> bool: + """检查熔断器是否开启(是否应该拒绝请求)""" + if self.state == self.STATE_CLOSED: + return False + + if self.state == self.STATE_OPEN: + # 检查是否可以尝试恢复 + elapsed = time.time() - self.last_failure_time + if elapsed >= self.current_recovery_time: + self.state = self.STATE_HALF_OPEN + self.success_count = 0 + logger.info(f"熔断器进入半开状态,尝试恢复(等待了 {elapsed:.1f}s)") + return False + return True + + # 半开状态,允许请求通过 + return False + + def record_success(self): + """记录成功""" + self.total_successes += 1 + + if self.state == self.STATE_HALF_OPEN: + self.success_count += 1 + if self.success_count >= self.success_threshold: + # 恢复正常 + self.state = self.STATE_CLOSED + self.failure_count = 0 + self.success_count = 0 + self.current_recovery_time = self.initial_recovery_time + logger.success(f"熔断器已恢复正常(连续成功 {self.success_threshold} 次)") + else: + # 正常状态,重置失败计数 + self.failure_count = 0 + + def record_failure(self): + """记录失败""" + self.total_failures += 1 + self.failure_count += 1 + self.last_failure_time = time.time() + + if self.state == self.STATE_HALF_OPEN: + # 半开状态下失败,重新熔断 + self.state = self.STATE_OPEN + self.success_count = 0 + # 增加恢复时间 + self.current_recovery_time = min( + self.current_recovery_time * self.recovery_multiplier, + self.max_recovery_time + ) + logger.warning(f"熔断器重新开启,下次恢复等待 {self.current_recovery_time:.1f}s") + + elif self.state == self.STATE_CLOSED: + if self.failure_count >= self.failure_threshold: + self.state = self.STATE_OPEN + logger.warning(f"熔断器开启,连续失败 {self.failure_count} 次") + + def record_rejection(self): + """记录被拒绝的请求""" + self.total_rejections += 1 + + def get_stats(self) -> dict: + """获取统计信息""" + return { + "state": self.state, + "failure_count": self.failure_count, + "success_count": self.success_count, + "current_recovery_time": self.current_recovery_time, + "total_failures": self.total_failures, + "total_successes": self.total_successes, + "total_rejections": self.total_rejections + } + + +# ==================== 请求重试机制 ==================== + +class RetryConfig: + """重试配置""" + def __init__( + self, + max_retries: int = 3, + initial_delay: float = 1.0, + max_delay: float = 30.0, + exponential_base: float = 2.0, + retryable_exceptions: tuple = ( + aiohttp.ClientError, + asyncio.TimeoutError, + ConnectionError, + ) + ): + self.max_retries = max_retries + self.initial_delay = initial_delay + self.max_delay = max_delay + self.exponential_base = exponential_base + self.retryable_exceptions = retryable_exceptions + + +def retry_async(config: RetryConfig = None): + """异步重试装饰器""" + if config is None: + config = RetryConfig() + + def decorator(func): + @functools.wraps(func) + async def wrapper(*args, **kwargs): + last_exception = None + + for attempt in range(config.max_retries + 1): + try: + return await func(*args, **kwargs) + except config.retryable_exceptions as e: + last_exception = e + + if attempt == config.max_retries: + logger.error(f"重试 {config.max_retries} 次后仍然失败: {func.__name__}") + raise + + # 计算延迟时间(指数退避) + delay = min( + config.initial_delay * (config.exponential_base ** attempt), + config.max_delay + ) + + logger.warning( + f"请求失败,{delay:.1f}s 后重试 " + f"(第 {attempt + 1}/{config.max_retries} 次): {e}" + ) + await asyncio.sleep(delay) + + raise last_exception + return wrapper + return decorator + + +async def request_with_retry( + session: aiohttp.ClientSession, + method: str, + url: str, + max_retries: int = 3, + **kwargs +) -> aiohttp.ClientResponse: + """带重试的 HTTP 请求""" + config = RetryConfig(max_retries=max_retries) + last_exception = None + + for attempt in range(config.max_retries + 1): + try: + response = await session.request(method, url, **kwargs) + return response + except config.retryable_exceptions as e: + last_exception = e + + if attempt == config.max_retries: + raise + + delay = min( + config.initial_delay * (config.exponential_base ** attempt), + config.max_delay + ) + + logger.warning(f"HTTP 请求失败,{delay:.1f}s 后重试: {e}") + await asyncio.sleep(delay) + + raise last_exception + + +# ==================== 配置热更新 ==================== + +class ConfigWatcher: + """配置文件监听器""" + + def __init__(self, config_path: str, check_interval: float = 5.0): + """ + 初始化配置监听器 + + Args: + config_path: 配置文件路径 + check_interval: 检查间隔(秒) + """ + self.config_path = Path(config_path) + self.check_interval = check_interval + self.last_mtime = 0 + self.callbacks: List[Callable[[dict], Any]] = [] + self.current_config: dict = {} + self._running = False + self._task: Optional[asyncio.Task] = None + + def register_callback(self, callback: Callable[[dict], Any]): + """注册配置更新回调""" + self.callbacks.append(callback) + + def unregister_callback(self, callback: Callable[[dict], Any]): + """取消注册回调""" + if callback in self.callbacks: + self.callbacks.remove(callback) + + def _load_config(self) -> dict: + """加载配置文件""" + try: + with open(self.config_path, "rb") as f: + return tomllib.load(f) + except Exception as e: + logger.error(f"加载配置文件失败: {e}") + return {} + + def get_config(self) -> dict: + """获取当前配置""" + return self.current_config + + async def start(self): + """启动配置监听""" + if self._running: + return + + self._running = True + + # 初始加载 + if self.config_path.exists(): + self.last_mtime = os.path.getmtime(self.config_path) + self.current_config = self._load_config() + + self._task = asyncio.create_task(self._watch_loop()) + logger.info(f"配置监听器已启动: {self.config_path}") + + async def stop(self): + """停止配置监听""" + self._running = False + if self._task: + self._task.cancel() + try: + await self._task + except asyncio.CancelledError: + pass + logger.info("配置监听器已停止") + + async def _watch_loop(self): + """监听循环""" + while self._running: + try: + await asyncio.sleep(self.check_interval) + + if not self.config_path.exists(): + continue + + mtime = os.path.getmtime(self.config_path) + if mtime > self.last_mtime: + logger.info("检测到配置文件变化,重新加载...") + new_config = self._load_config() + + if new_config: + old_config = self.current_config + self.current_config = new_config + self.last_mtime = mtime + + # 通知所有回调 + for callback in self.callbacks: + try: + if asyncio.iscoroutinefunction(callback): + await callback(new_config) + else: + callback(new_config) + except Exception as e: + logger.error(f"配置更新回调执行失败: {e}") + + logger.success("配置已热更新") + + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"配置监听异常: {e}") + + +# ==================== 性能监控 ==================== + +class PerformanceMonitor: + """性能监控器""" + + def __init__(self): + self.start_time = time.time() + + # 消息统计 + self.message_received = 0 + self.message_processed = 0 + self.message_failed = 0 + self.message_dropped = 0 + + # 处理时间统计 + self.processing_times: List[float] = [] + self.max_processing_times = 1000 # 保留最近1000条记录 + + # 插件统计 + self.plugin_stats: Dict[str, dict] = {} + + # 队列统计 + self.queue_size_history: List[Tuple[float, int]] = [] + self.max_queue_history = 100 + + # 熔断器统计 + self.circuit_breaker_stats: dict = {} + + def record_message_received(self): + """记录收到消息""" + self.message_received += 1 + + def record_message_processed(self, processing_time: float): + """记录消息处理完成""" + self.message_processed += 1 + self.processing_times.append(processing_time) + + # 限制历史记录数量 + if len(self.processing_times) > self.max_processing_times: + self.processing_times = self.processing_times[-self.max_processing_times:] + + def record_message_failed(self): + """记录消息处理失败""" + self.message_failed += 1 + + def record_message_dropped(self): + """记录消息被丢弃""" + self.message_dropped += 1 + + def record_queue_size(self, size: int): + """记录队列大小""" + self.queue_size_history.append((time.time(), size)) + + if len(self.queue_size_history) > self.max_queue_history: + self.queue_size_history = self.queue_size_history[-self.max_queue_history:] + + def record_plugin_execution(self, plugin_name: str, execution_time: float, success: bool): + """记录插件执行""" + if plugin_name not in self.plugin_stats: + self.plugin_stats[plugin_name] = { + "total_calls": 0, + "success_calls": 0, + "failed_calls": 0, + "total_time": 0, + "max_time": 0, + "recent_times": [] + } + + stats = self.plugin_stats[plugin_name] + stats["total_calls"] += 1 + stats["total_time"] += execution_time + stats["max_time"] = max(stats["max_time"], execution_time) + stats["recent_times"].append(execution_time) + + if len(stats["recent_times"]) > 100: + stats["recent_times"] = stats["recent_times"][-100:] + + if success: + stats["success_calls"] += 1 + else: + stats["failed_calls"] += 1 + + def update_circuit_breaker_stats(self, stats: dict): + """更新熔断器统计""" + self.circuit_breaker_stats = stats + + def get_stats(self) -> dict: + """获取完整统计信息""" + uptime = time.time() - self.start_time + + # 计算平均处理时间 + avg_processing_time = 0 + if self.processing_times: + avg_processing_time = sum(self.processing_times) / len(self.processing_times) + + # 计算处理速率 + processing_rate = self.message_processed / uptime if uptime > 0 else 0 + + # 计算成功率 + total = self.message_processed + self.message_failed + success_rate = self.message_processed / total if total > 0 else 1.0 + + return { + "uptime_seconds": uptime, + "uptime_formatted": self._format_uptime(uptime), + "messages": { + "received": self.message_received, + "processed": self.message_processed, + "failed": self.message_failed, + "dropped": self.message_dropped, + "success_rate": f"{success_rate * 100:.1f}%", + "processing_rate": f"{processing_rate:.2f}/s" + }, + "processing_time": { + "average_ms": f"{avg_processing_time * 1000:.1f}", + "max_ms": f"{max(self.processing_times) * 1000:.1f}" if self.processing_times else "0", + "min_ms": f"{min(self.processing_times) * 1000:.1f}" if self.processing_times else "0" + }, + "queue": { + "current_size": self.queue_size_history[-1][1] if self.queue_size_history else 0, + "max_size": max(s[1] for s in self.queue_size_history) if self.queue_size_history else 0 + }, + "circuit_breaker": self.circuit_breaker_stats, + "plugins": self._get_plugin_summary() + } + + def _get_plugin_summary(self) -> List[dict]: + """获取插件统计摘要""" + summary = [] + for name, stats in self.plugin_stats.items(): + avg_time = stats["total_time"] / stats["total_calls"] if stats["total_calls"] > 0 else 0 + summary.append({ + "name": name, + "calls": stats["total_calls"], + "success_rate": f"{stats['success_calls'] / stats['total_calls'] * 100:.1f}%" if stats["total_calls"] > 0 else "N/A", + "avg_time_ms": f"{avg_time * 1000:.1f}", + "max_time_ms": f"{stats['max_time'] * 1000:.1f}" + }) + + # 按平均时间排序 + summary.sort(key=lambda x: float(x["avg_time_ms"]), reverse=True) + return summary + + def _format_uptime(self, seconds: float) -> str: + """格式化运行时间""" + days = int(seconds // 86400) + hours = int((seconds % 86400) // 3600) + minutes = int((seconds % 3600) // 60) + secs = int(seconds % 60) + + if days > 0: + return f"{days}天 {hours}小时 {minutes}分钟" + elif hours > 0: + return f"{hours}小时 {minutes}分钟" + elif minutes > 0: + return f"{minutes}分钟 {secs}秒" + else: + return f"{secs}秒" + + def print_stats(self): + """打印统计信息到日志""" + stats = self.get_stats() + + logger.info("=" * 50) + logger.info("性能监控报告") + logger.info("=" * 50) + logger.info(f"运行时间: {stats['uptime_formatted']}") + logger.info(f"消息统计: 收到 {stats['messages']['received']}, " + f"处理 {stats['messages']['processed']}, " + f"失败 {stats['messages']['failed']}, " + f"丢弃 {stats['messages']['dropped']}") + logger.info(f"成功率: {stats['messages']['success_rate']}, " + f"处理速率: {stats['messages']['processing_rate']}") + logger.info(f"平均处理时间: {stats['processing_time']['average_ms']}ms") + logger.info(f"队列大小: {stats['queue']['current_size']}") + logger.info(f"熔断器状态: {stats['circuit_breaker'].get('state', 'N/A')}") + + if stats['plugins']: + logger.info("插件耗时排行:") + for i, p in enumerate(stats['plugins'][:5], 1): + logger.info(f" {i}. {p['name']}: {p['avg_time_ms']}ms (调用 {p['calls']} 次)") + + logger.info("=" * 50) + + +# ==================== 全局实例 ==================== + +# 性能监控器单例 +_performance_monitor: Optional[PerformanceMonitor] = None + +def get_performance_monitor() -> PerformanceMonitor: + """获取性能监控器实例""" + global _performance_monitor + if _performance_monitor is None: + _performance_monitor = PerformanceMonitor() + return _performance_monitor diff --git a/utils/plugin_base.py b/utils/plugin_base.py index 7798609..4295913 100644 --- a/utils/plugin_base.py +++ b/utils/plugin_base.py @@ -1,4 +1,5 @@ from abc import ABC +from typing import List from loguru import logger @@ -13,6 +14,14 @@ class PluginBase(ABC): author: str = "未知" version: str = "1.0.0" + # 插件依赖(填写依赖的插件类名列表) + # 例如: dependencies = ["MessageLogger", "AIChat"] + dependencies: List[str] = [] + + # 加载优先级(数值越大越先加载,默认50) + # 基础插件设置高优先级,依赖其他插件的设置低优先级 + load_priority: int = 50 + def __init__(self): self.enabled = False self._scheduled_jobs = set() diff --git a/utils/plugin_manager.py b/utils/plugin_manager.py index a016d2b..1838e6d 100644 --- a/utils/plugin_manager.py +++ b/utils/plugin_manager.py @@ -117,24 +117,107 @@ class PluginManager(metaclass=Singleton): if not found: logger.warning(f"未找到插件类 {plugin_name}") + def _resolve_load_order(self, plugin_classes: List[Type[PluginBase]]) -> List[Type[PluginBase]]: + """ + 解析插件加载顺序(拓扑排序 + 优先级排序) + + Args: + plugin_classes: 插件类列表 + + Returns: + 按依赖关系和优先级排序后的插件类列表 + """ + # 构建插件名到类的映射 + name_to_class = {cls.__name__: cls for cls in plugin_classes} + + # 构建依赖图 + dependencies = {} + for cls in plugin_classes: + deps = getattr(cls, 'dependencies', []) + dependencies[cls.__name__] = [d for d in deps if d in name_to_class] + + # 拓扑排序 + sorted_names = [] + visited = set() + temp_visited = set() + + def visit(name: str): + if name in temp_visited: + # 检测到循环依赖 + logger.warning(f"检测到循环依赖: {name}") + return + if name in visited: + return + + temp_visited.add(name) + + # 先访问依赖 + for dep in dependencies.get(name, []): + visit(dep) + + temp_visited.remove(name) + visited.add(name) + sorted_names.append(name) + + # 按优先级排序后再进行拓扑排序 + priority_sorted = sorted( + plugin_classes, + key=lambda cls: getattr(cls, 'load_priority', 50), + reverse=True + ) + + for cls in priority_sorted: + if cls.__name__ not in visited: + visit(cls.__name__) + + # 返回排序后的类列表 + return [name_to_class[name] for name in sorted_names if name in name_to_class] + async def load_plugins(self, load_disabled: bool = True) -> Union[List[str], bool]: + """加载所有插件(按依赖顺序)""" loaded_plugins = [] + # 第一步:收集所有插件类 + all_plugin_classes = [] + plugin_disabled_map = {} + for dirname in os.listdir("plugins"): if os.path.isdir(f"plugins/{dirname}") and os.path.exists(f"plugins/{dirname}/main.py"): try: module = importlib.import_module(f"plugins.{dirname}.main") for name, obj in inspect.getmembers(module): if inspect.isclass(obj) and issubclass(obj, PluginBase) and obj != PluginBase: + all_plugin_classes.append(obj) + + # 记录是否禁用 is_disabled = False if not load_disabled: is_disabled = obj.__name__ in self.excluded_plugins or dirname in self.excluded_plugins - - if await self._load_plugin_class(obj, is_disabled=is_disabled): - loaded_plugins.append(obj.__name__) + plugin_disabled_map[obj.__name__] = is_disabled except: logger.error(f"加载 {dirname} 时发生错误: {traceback.format_exc()}") + # 第二步:按依赖顺序排序 + sorted_classes = self._resolve_load_order(all_plugin_classes) + logger.info(f"插件加载顺序: {[cls.__name__ for cls in sorted_classes]}") + + # 第三步:按顺序加载插件 + for plugin_class in sorted_classes: + plugin_name = plugin_class.__name__ + is_disabled = plugin_disabled_map.get(plugin_name, False) + + # 检查依赖是否已加载 + deps = getattr(plugin_class, 'dependencies', []) + deps_satisfied = all(dep in self.plugins for dep in deps) + + if not deps_satisfied and not is_disabled: + missing_deps = [dep for dep in deps if dep not in self.plugins] + logger.warning(f"插件 {plugin_name} 的依赖未满足: {missing_deps},跳过加载") + continue + + if await self._load_plugin_class(plugin_class, is_disabled=is_disabled): + loaded_plugins.append(plugin_name) + return loaded_plugins async def unload_plugin(self, plugin_name: str) -> bool: diff --git a/utils/redis_cache.py b/utils/redis_cache.py new file mode 100644 index 0000000..2615eb5 --- /dev/null +++ b/utils/redis_cache.py @@ -0,0 +1,744 @@ +""" +Redis 缓存工具类 + +用于缓存用户信息等数据,减少 API 调用 +""" + +import json +from typing import Optional, Dict, Any +from loguru import logger + +try: + import redis + REDIS_AVAILABLE = True +except ImportError: + REDIS_AVAILABLE = False + logger.warning("redis 库未安装,缓存功能将不可用") + + +class RedisCache: + """Redis 缓存管理器""" + + _instance = None + + def __new__(cls, *args, **kwargs): + """单例模式""" + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._initialized = False + return cls._instance + + def __init__(self, config: Dict = None): + """ + 初始化 Redis 连接 + + Args: + config: Redis 配置字典,包含 host, port, password, db 等 + """ + if self._initialized: + return + + self.client = None + self.enabled = False + self.default_ttl = 3600 # 默认过期时间 1 小时 + + if not REDIS_AVAILABLE: + logger.warning("Redis 库未安装,缓存功能禁用") + self._initialized = True + return + + if config: + self.connect(config) + + self._initialized = True + + def connect(self, config: Dict) -> bool: + """ + 连接 Redis + + Args: + config: Redis 配置 + + Returns: + 是否连接成功 + """ + if not REDIS_AVAILABLE: + return False + + try: + self.client = redis.Redis( + host=config.get("host", "localhost"), + port=config.get("port", 6379), + password=config.get("password", None), + db=config.get("db", 0), + decode_responses=True, + socket_timeout=5, + socket_connect_timeout=5 + ) + + # 测试连接 + self.client.ping() + self.enabled = True + self.default_ttl = config.get("ttl", 3600) + + logger.success(f"Redis 连接成功: {config.get('host')}:{config.get('port')}") + return True + + except Exception as e: + logger.error(f"Redis 连接失败: {e}") + self.client = None + self.enabled = False + return False + + def _make_key(self, prefix: str, *args) -> str: + """ + 生成缓存 key + + Args: + prefix: key 前缀 + *args: key 组成部分 + + Returns: + 完整的 key + """ + parts = [prefix] + [str(arg) for arg in args] + return ":".join(parts) + + def get(self, key: str) -> Optional[Any]: + """ + 获取缓存值 + + Args: + key: 缓存 key + + Returns: + 缓存的值,不存在返回 None + """ + if not self.enabled or not self.client: + return None + + try: + value = self.client.get(key) + if value: + return json.loads(value) + return None + except Exception as e: + logger.error(f"Redis GET 失败: {key}, {e}") + return None + + def set(self, key: str, value: Any, ttl: int = None) -> bool: + """ + 设置缓存值 + + Args: + key: 缓存 key + value: 要缓存的值 + ttl: 过期时间(秒),默认使用 default_ttl + + Returns: + 是否设置成功 + """ + if not self.enabled or not self.client: + return False + + try: + ttl = ttl or self.default_ttl + self.client.setex(key, ttl, json.dumps(value, ensure_ascii=False)) + return True + except Exception as e: + logger.error(f"Redis SET 失败: {key}, {e}") + return False + + def delete(self, key: str) -> bool: + """ + 删除缓存 + + Args: + key: 缓存 key + + Returns: + 是否删除成功 + """ + if not self.enabled or not self.client: + return False + + try: + self.client.delete(key) + return True + except Exception as e: + logger.error(f"Redis DELETE 失败: {key}, {e}") + return False + + def delete_pattern(self, pattern: str) -> int: + """ + 删除匹配模式的所有 key + + Args: + pattern: key 模式,如 "user_info:*" + + Returns: + 删除的 key 数量 + """ + if not self.enabled or not self.client: + return 0 + + try: + keys = self.client.keys(pattern) + if keys: + return self.client.delete(*keys) + return 0 + except Exception as e: + logger.error(f"Redis DELETE PATTERN 失败: {pattern}, {e}") + return 0 + + # ==================== 用户信息缓存专用方法 ==================== + + def get_user_info(self, chatroom_id: str, user_wxid: str) -> Optional[Dict]: + """ + 获取缓存的用户信息 + + Args: + chatroom_id: 群聊 ID + user_wxid: 用户 wxid + + Returns: + 用户信息字典,不存在返回 None + """ + key = self._make_key("user_info", chatroom_id, user_wxid) + return self.get(key) + + def set_user_info(self, chatroom_id: str, user_wxid: str, user_info: Dict, ttl: int = None) -> bool: + """ + 缓存用户信息 + + Args: + chatroom_id: 群聊 ID + user_wxid: 用户 wxid + user_info: 用户信息字典 + ttl: 过期时间(秒) + + Returns: + 是否缓存成功 + """ + key = self._make_key("user_info", chatroom_id, user_wxid) + return self.set(key, user_info, ttl) + + def get_user_basic_info(self, chatroom_id: str, user_wxid: str) -> Optional[Dict]: + """ + 获取缓存的用户基本信息(昵称和头像) + + Args: + chatroom_id: 群聊 ID + user_wxid: 用户 wxid + + Returns: + 包含 nickname 和 avatar_url 的字典 + """ + user_info = self.get_user_info(chatroom_id, user_wxid) + if user_info: + # 提取基本信息 + nickname = "" + if isinstance(user_info.get("nickName"), dict): + nickname = user_info.get("nickName", {}).get("string", "") + else: + nickname = user_info.get("nickName", "") + + avatar_url = user_info.get("bigHeadImgUrl", "") + + if nickname or avatar_url: + return { + "nickname": nickname, + "avatar_url": avatar_url + } + return None + + def clear_user_cache(self, chatroom_id: str = None, user_wxid: str = None) -> int: + """ + 清除用户信息缓存 + + Args: + chatroom_id: 群聊 ID,为空则清除所有群 + user_wxid: 用户 wxid,为空则清除该群所有用户 + + Returns: + 清除的缓存数量 + """ + if chatroom_id and user_wxid: + key = self._make_key("user_info", chatroom_id, user_wxid) + return 1 if self.delete(key) else 0 + elif chatroom_id: + pattern = self._make_key("user_info", chatroom_id, "*") + return self.delete_pattern(pattern) + else: + return self.delete_pattern("user_info:*") + + def get_cache_stats(self) -> Dict: + """ + 获取缓存统计信息 + + Returns: + 统计信息字典 + """ + if not self.enabled or not self.client: + return {"enabled": False} + + try: + info = self.client.info("memory") + user_keys = len(self.client.keys("user_info:*")) + chat_keys = len(self.client.keys("chat_history:*")) + + return { + "enabled": True, + "used_memory": info.get("used_memory_human", "unknown"), + "user_info_count": user_keys, + "chat_history_count": chat_keys + } + except Exception as e: + logger.error(f"获取缓存统计失败: {e}") + return {"enabled": True, "error": str(e)} + + # ==================== 对话历史缓存专用方法 ==================== + + def get_chat_history(self, chat_id: str, max_messages: int = 100) -> list: + """ + 获取对话历史 + + Args: + chat_id: 会话ID(私聊为用户wxid,群聊为 群ID:用户ID) + max_messages: 最大返回消息数 + + Returns: + 消息列表 + """ + if not self.enabled or not self.client: + return [] + + try: + key = self._make_key("chat_history", chat_id) + # 使用 LRANGE 获取最近的消息(列表尾部是最新的) + data = self.client.lrange(key, -max_messages, -1) + return [json.loads(item) for item in data] + except Exception as e: + logger.error(f"获取对话历史失败: {chat_id}, {e}") + return [] + + def add_chat_message(self, chat_id: str, role: str, content, ttl: int = 86400) -> bool: + """ + 添加消息到对话历史 + + Args: + chat_id: 会话ID + role: 角色 (user/assistant) + content: 消息内容(字符串或列表) + ttl: 过期时间(秒),默认24小时 + + Returns: + 是否添加成功 + """ + if not self.enabled or not self.client: + return False + + try: + key = self._make_key("chat_history", chat_id) + message = {"role": role, "content": content} + self.client.rpush(key, json.dumps(message, ensure_ascii=False)) + self.client.expire(key, ttl) + return True + except Exception as e: + logger.error(f"添加对话消息失败: {chat_id}, {e}") + return False + + def trim_chat_history(self, chat_id: str, max_messages: int = 100) -> bool: + """ + 裁剪对话历史,保留最近的N条消息 + + Args: + chat_id: 会话ID + max_messages: 保留的最大消息数 + + Returns: + 是否成功 + """ + if not self.enabled or not self.client: + return False + + try: + key = self._make_key("chat_history", chat_id) + # 保留最后 max_messages 条 + self.client.ltrim(key, -max_messages, -1) + return True + except Exception as e: + logger.error(f"裁剪对话历史失败: {chat_id}, {e}") + return False + + def clear_chat_history(self, chat_id: str) -> bool: + """ + 清空指定会话的对话历史 + + Args: + chat_id: 会话ID + + Returns: + 是否成功 + """ + if not self.enabled or not self.client: + return False + + try: + key = self._make_key("chat_history", chat_id) + self.client.delete(key) + return True + except Exception as e: + logger.error(f"清空对话历史失败: {chat_id}, {e}") + return False + + # ==================== 群聊历史记录专用方法 ==================== + + def get_group_history(self, group_id: str, max_messages: int = 100) -> list: + """ + 获取群聊历史记录 + + Args: + group_id: 群聊ID + max_messages: 最大返回消息数 + + Returns: + 消息列表,每条包含 nickname, content, timestamp + """ + if not self.enabled or not self.client: + return [] + + try: + key = self._make_key("group_history", group_id) + data = self.client.lrange(key, -max_messages, -1) + return [json.loads(item) for item in data] + except Exception as e: + logger.error(f"获取群聊历史失败: {group_id}, {e}") + return [] + + def add_group_message(self, group_id: str, nickname: str, content, + record_id: str = None, ttl: int = 86400) -> bool: + """ + 添加消息到群聊历史 + + Args: + group_id: 群聊ID + nickname: 发送者昵称 + content: 消息内容 + record_id: 可选的记录ID,用于后续更新 + ttl: 过期时间(秒),默认24小时 + + Returns: + 是否添加成功 + """ + if not self.enabled or not self.client: + return False + + try: + import time + key = self._make_key("group_history", group_id) + message = { + "nickname": nickname, + "content": content, + "timestamp": time.time() + } + if record_id: + message["id"] = record_id + + self.client.rpush(key, json.dumps(message, ensure_ascii=False)) + self.client.expire(key, ttl) + return True + except Exception as e: + logger.error(f"添加群聊消息失败: {group_id}, {e}") + return False + + def update_group_message_by_id(self, group_id: str, record_id: str, new_content) -> bool: + """ + 根据ID更新群聊历史中的消息 + + Args: + group_id: 群聊ID + record_id: 记录ID + new_content: 新内容 + + Returns: + 是否更新成功 + """ + if not self.enabled or not self.client: + return False + + try: + key = self._make_key("group_history", group_id) + # 获取所有消息 + data = self.client.lrange(key, 0, -1) + + for i, item in enumerate(data): + msg = json.loads(item) + if msg.get("id") == record_id: + msg["content"] = new_content + self.client.lset(key, i, json.dumps(msg, ensure_ascii=False)) + return True + return False + except Exception as e: + logger.error(f"更新群聊消息失败: {group_id}, {record_id}, {e}") + return False + + def trim_group_history(self, group_id: str, max_messages: int = 100) -> bool: + """ + 裁剪群聊历史,保留最近的N条消息 + + Args: + group_id: 群聊ID + max_messages: 保留的最大消息数 + + Returns: + 是否成功 + """ + if not self.enabled or not self.client: + return False + + try: + key = self._make_key("group_history", group_id) + self.client.ltrim(key, -max_messages, -1) + return True + except Exception as e: + logger.error(f"裁剪群聊历史失败: {group_id}, {e}") + return False + + # ==================== 限流专用方法 ==================== + + def check_rate_limit(self, identifier: str, limit: int = 10, + window: int = 60, limit_type: str = "message") -> tuple: + """ + 检查是否超过限流 + + 使用滑动窗口算法 + + Args: + identifier: 标识符(如用户wxid、群ID等) + limit: 时间窗口内最大请求数 + window: 时间窗口(秒) + limit_type: 限流类型(message/ai_chat/image_gen等) + + Returns: + (是否允许, 剩余次数, 重置时间秒数) + """ + if not self.enabled or not self.client: + return (True, limit, 0) # Redis 不可用时不限流 + + try: + import time + key = self._make_key("rate_limit", limit_type, identifier) + now = time.time() + window_start = now - window + + # 使用 pipeline 提高性能 + pipe = self.client.pipeline() + + # 移除过期的记录 + pipe.zremrangebyscore(key, 0, window_start) + # 获取当前窗口内的请求数 + pipe.zcard(key) + # 添加当前请求 + pipe.zadd(key, {str(now): now}) + # 设置过期时间 + pipe.expire(key, window) + + results = pipe.execute() + current_count = results[1] # zcard 的结果 + + if current_count >= limit: + # 获取最早的记录时间,计算重置时间 + oldest = self.client.zrange(key, 0, 0, withscores=True) + if oldest: + reset_time = int(oldest[0][1] + window - now) + else: + reset_time = window + return (False, 0, max(reset_time, 1)) + + remaining = limit - current_count - 1 + return (True, remaining, 0) + + except Exception as e: + logger.error(f"限流检查失败: {identifier}, {e}") + return (True, limit, 0) # 出错时不限流 + + def get_rate_limit_status(self, identifier: str, limit: int = 10, + window: int = 60, limit_type: str = "message") -> Dict: + """ + 获取限流状态(不增加计数) + + Args: + identifier: 标识符 + limit: 时间窗口内最大请求数 + window: 时间窗口(秒) + limit_type: 限流类型 + + Returns: + 状态字典 + """ + if not self.enabled or not self.client: + return {"enabled": False, "current": 0, "limit": limit, "remaining": limit} + + try: + import time + key = self._make_key("rate_limit", limit_type, identifier) + now = time.time() + window_start = now - window + + # 移除过期记录并获取当前数量 + self.client.zremrangebyscore(key, 0, window_start) + current = self.client.zcard(key) + + return { + "enabled": True, + "current": current, + "limit": limit, + "remaining": max(0, limit - current), + "window": window + } + except Exception as e: + logger.error(f"获取限流状态失败: {identifier}, {e}") + return {"enabled": False, "error": str(e)} + + def reset_rate_limit(self, identifier: str, limit_type: str = "message") -> bool: + """ + 重置限流计数 + + Args: + identifier: 标识符 + limit_type: 限流类型 + + Returns: + 是否成功 + """ + if not self.enabled or not self.client: + return False + + try: + key = self._make_key("rate_limit", limit_type, identifier) + self.client.delete(key) + return True + except Exception as e: + logger.error(f"重置限流失败: {identifier}, {e}") + return False + + # ==================== 媒体缓存专用方法 ==================== + + def cache_media(self, media_key: str, base64_data: str, media_type: str = "image", ttl: int = 300) -> bool: + """ + 缓存媒体文件的 base64 数据 + + Args: + media_key: 媒体唯一标识(如 cdnurl 的 hash 或 aeskey) + base64_data: base64 编码的媒体数据 + media_type: 媒体类型(image/emoji/video) + ttl: 过期时间(秒),默认5分钟 + + Returns: + 是否缓存成功 + """ + if not self.enabled or not self.client: + return False + + try: + key = self._make_key("media_cache", media_type, media_key) + # 直接存储 base64 字符串,不再 json 序列化 + self.client.setex(key, ttl, base64_data) + logger.debug(f"媒体已缓存: {media_type}/{media_key[:20]}..., TTL={ttl}s") + return True + except Exception as e: + logger.error(f"缓存媒体失败: {media_key}, {e}") + return False + + def get_cached_media(self, media_key: str, media_type: str = "image") -> Optional[str]: + """ + 获取缓存的媒体 base64 数据 + + Args: + media_key: 媒体唯一标识 + media_type: 媒体类型 + + Returns: + base64 数据,不存在返回 None + """ + if not self.enabled or not self.client: + return None + + try: + key = self._make_key("media_cache", media_type, media_key) + data = self.client.get(key) + if data: + logger.debug(f"媒体缓存命中: {media_type}/{media_key[:20]}...") + return data + return None + except Exception as e: + logger.error(f"获取媒体缓存失败: {media_key}, {e}") + return None + + def delete_cached_media(self, media_key: str, media_type: str = "image") -> bool: + """ + 删除缓存的媒体 + + Args: + media_key: 媒体唯一标识 + media_type: 媒体类型 + + Returns: + 是否删除成功 + """ + if not self.enabled or not self.client: + return False + + try: + key = self._make_key("media_cache", media_type, media_key) + self.client.delete(key) + return True + except Exception as e: + logger.error(f"删除媒体缓存失败: {media_key}, {e}") + return False + + @staticmethod + def generate_media_key(cdnurl: str = "", aeskey: str = "") -> str: + """ + 根据 CDN URL 或 AES Key 生成媒体缓存 key + + Args: + cdnurl: CDN URL + aeskey: AES Key + + Returns: + 缓存 key + """ + import hashlib + # 优先使用 aeskey(更短更稳定),否则使用 cdnurl 的 hash + if aeskey: + return aeskey[:32] # 取前32位作为 key + elif cdnurl: + return hashlib.md5(cdnurl.encode()).hexdigest() + return "" + + +def get_cache() -> Optional[RedisCache]: + """ + 获取全局缓存实例 + + 返回 RedisCache 单例实例。如果还没有初始化,返回一个未连接的实例。 + 建议在 MessageLogger 初始化后再调用此函数。 + """ + return RedisCache._instance + + +def init_cache(config: Dict) -> RedisCache: + """ + 初始化全局缓存实例 + + Args: + config: Redis 配置 + + Returns: + 缓存实例 + """ + global _cache_instance + _cache_instance = RedisCache(config) + return _cache_instance