feat: 优化整体项目

This commit is contained in:
2025-12-05 18:06:13 +08:00
parent b4df26f61d
commit 7d3ef70093
13 changed files with 2661 additions and 305 deletions

View File

@@ -126,6 +126,59 @@ def add_callback_handler(callback_handler):
logger.debug(f"注册断开回调: {name}")
def remove_callback_handler(callback_handler):
"""
移除回调处理器实例(修复内存泄漏)
Args:
callback_handler: 包含回调方法的对象
"""
global _GLOBAL_CONNECT_CALLBACK_LIST, _GLOBAL_RECV_CALLBACK_LIST, _GLOBAL_CLOSE_CALLBACK_LIST
# 移除属于该处理器的所有回调
_GLOBAL_CONNECT_CALLBACK_LIST = [
f for f in _GLOBAL_CONNECT_CALLBACK_LIST
if not (hasattr(f, '__self__') and f.__self__ is callback_handler)
]
_GLOBAL_RECV_CALLBACK_LIST = [
f for f in _GLOBAL_RECV_CALLBACK_LIST
if not (hasattr(f, '__self__') and f.__self__ is callback_handler)
]
_GLOBAL_CLOSE_CALLBACK_LIST = [
f for f in _GLOBAL_CLOSE_CALLBACK_LIST
if not (hasattr(f, '__self__') and f.__self__ is callback_handler)
]
logger.debug(f"已移除回调处理器: {type(callback_handler).__name__}")
def clear_all_callbacks():
"""
清空所有回调(用于关闭时清理)
"""
global _GLOBAL_CONNECT_CALLBACK_LIST, _GLOBAL_RECV_CALLBACK_LIST, _GLOBAL_CLOSE_CALLBACK_LIST
_GLOBAL_CONNECT_CALLBACK_LIST.clear()
_GLOBAL_RECV_CALLBACK_LIST.clear()
_GLOBAL_CLOSE_CALLBACK_LIST.clear()
logger.debug("已清空所有回调")
def get_callback_count() -> dict:
"""
获取回调数量统计
Returns:
回调数量字典
"""
return {
"connect": len(_GLOBAL_CONNECT_CALLBACK_LIST),
"recv": len(_GLOBAL_RECV_CALLBACK_LIST),
"close": len(_GLOBAL_CLOSE_CALLBACK_LIST)
}
@WINFUNCTYPE(None, ctypes.c_void_p)
def wechat_connect_callback(client_id):
"""

362
bot.py
View File

@@ -2,10 +2,19 @@
WechatHookBot - 主入口
基于个微大客户版 Hook API 的微信机器人框架
优化功能:
- 优先级消息队列
- 自适应熔断器
- 配置热更新
- 性能监控
- 优雅关闭
"""
import asyncio
import signal
import sys
import time
import tomllib
from pathlib import Path
from loguru import logger
@@ -13,6 +22,8 @@ from loguru import logger
from WechatHook import NoveLoader, WechatHookClient
from WechatHook.callbacks import (
add_callback_handler,
remove_callback_handler,
clear_all_callbacks,
wechat_connect_callback,
wechat_recv_callback,
wechat_close_callback,
@@ -23,7 +34,15 @@ from WechatHook.callbacks import (
from utils.hookbot import HookBot
from utils.plugin_manager import PluginManager
from utils.decorators import scheduler
# from database import KeyvalDB, MessageDB # 不需要数据库
from utils.bot_utils import (
PriorityMessageQueue,
MessagePriority,
PRIORITY_MESSAGE_TYPES,
AdaptiveCircuitBreaker,
ConfigWatcher,
PerformanceMonitor,
get_performance_monitor
)
class BotService:
@@ -37,17 +56,24 @@ class BotService:
self.process_id = None # 微信进程 ID
self.socket_client_id = None # Socket 客户端 ID
self.is_running = False
self.is_shutting_down = False # 是否正在关闭
self.event_loop = None # 事件循环引用
# 消息队列和性能控制
self.message_queue = None
self.message_queue: PriorityMessageQueue = None # 优先级消息队列
self.queue_config = {}
self.concurrency_config = {}
self.consumer_tasks = []
self.processing_semaphore = None
self.circuit_breaker_failures = 0
self.circuit_breaker_open = False
self.circuit_breaker_last_failure = 0
# 自适应熔断器
self.circuit_breaker: AdaptiveCircuitBreaker = None
# 配置热更新
self.config_watcher: ConfigWatcher = None
# 性能监控
self.performance_monitor: PerformanceMonitor = None
@CONNECT_CALLBACK(in_class=True)
def on_connect(self, client_id):
@@ -85,118 +111,125 @@ class BotService:
logger.error(f"消息入队失败: {e}")
async def _enqueue_message(self, msg_type, data):
"""将消息加入队列"""
"""将消息加入优先级队列"""
try:
# 记录收到消息
if self.performance_monitor:
self.performance_monitor.record_message_received()
# 检查队列是否已满
if self.message_queue.qsize() >= self.queue_config.get("max_size", 1000):
if self.message_queue.full():
overflow_strategy = self.queue_config.get("overflow_strategy", "drop_oldest")
if overflow_strategy == "drop_oldest":
# 丢弃最旧的消息
try:
self.message_queue.get_nowait()
logger.warning("队列已满,丢弃最旧消息")
except asyncio.QueueEmpty:
pass
# 丢弃优先级最低的消息
if self.message_queue.drop_lowest_priority():
logger.warning("队列已满,丢弃优先级最低的消息")
if self.performance_monitor:
self.performance_monitor.record_message_dropped()
elif overflow_strategy == "sampling":
# 采样处理,随机丢弃
# 采样处理,随机丢弃(但高优先级消息不丢弃)
import random
if random.random() < 0.5: # 50% 概率丢弃
priority = PRIORITY_MESSAGE_TYPES.get(msg_type, MessagePriority.NORMAL)
if priority < MessagePriority.HIGH and random.random() < 0.5:
logger.debug("队列压力大,采样丢弃消息")
if self.performance_monitor:
self.performance_monitor.record_message_dropped()
return
else: # degrade
logger.warning("队列已满,降级处理")
return
# 降级处理(但高优先级消息不丢弃)
priority = PRIORITY_MESSAGE_TYPES.get(msg_type, MessagePriority.NORMAL)
if priority < MessagePriority.HIGH:
logger.warning("队列已满,降级处理")
if self.performance_monitor:
self.performance_monitor.record_message_dropped()
return
# 将消息放入优先级队列
await self.message_queue.put(msg_type, data)
# 记录队列大小
if self.performance_monitor:
self.performance_monitor.record_queue_size(self.message_queue.qsize())
# 将消息放入队列
await self.message_queue.put((msg_type, data))
except Exception as e:
logger.error(f"消息入队异常: {e}")
async def _message_consumer(self, consumer_id: int):
"""消息消费者协程"""
logger.info(f"消息消费者 {consumer_id} 已启动")
while self.is_running:
"""消息消费者协程 - 纯队列串行模式,避免并发触发风控"""
logger.info(f"消息消费者 {consumer_id} 已启动(串行模式)")
while self.is_running and not self.is_shutting_down:
try:
# 从队列获取消息,设置超时避免无限等待
msg_type, data = await asyncio.wait_for(
self.message_queue.get(),
self.message_queue.get(),
timeout=1.0
)
# 检查熔断器状态
if self._check_circuit_breaker():
if self.circuit_breaker and self.circuit_breaker.is_open():
logger.debug("熔断器开启,跳过消息处理")
self.circuit_breaker.record_rejection()
self.message_queue.task_done()
continue
# 创建并发任务,不等待完成
timeout = self.concurrency_config.get("plugin_task_timeout_seconds", 5)
# 使用信号量控制并发数量
async def process_with_semaphore():
async with self.processing_semaphore:
try:
await asyncio.wait_for(
self.hookbot.process_message(msg_type, data),
timeout=timeout
)
self._reset_circuit_breaker()
except asyncio.TimeoutError:
logger.warning(f"消息处理超时 (>{timeout}s): type={msg_type}")
self._record_circuit_breaker_failure()
except Exception as e:
logger.error(f"消息处理异常: {e}")
self._record_circuit_breaker_failure()
# 创建任务但不等待,实现真正并发
asyncio.create_task(process_with_semaphore())
# 串行处理:等待当前消息处理完成后再处理下一条
timeout = self.concurrency_config.get("plugin_task_timeout_seconds", 720)
start_time = time.time()
try:
await asyncio.wait_for(
self.hookbot.process_message(msg_type, data),
timeout=timeout
)
# 记录成功
processing_time = time.time() - start_time
if self.circuit_breaker:
self.circuit_breaker.record_success()
if self.performance_monitor:
self.performance_monitor.record_message_processed(processing_time)
except asyncio.TimeoutError:
logger.warning(f"消息处理超时 (>{timeout}s): type={msg_type}")
if self.circuit_breaker:
self.circuit_breaker.record_failure()
if self.performance_monitor:
self.performance_monitor.record_message_failed()
except Exception as e:
logger.error(f"消息处理异常: {e}")
if self.circuit_breaker:
self.circuit_breaker.record_failure()
if self.performance_monitor:
self.performance_monitor.record_message_failed()
# 标记任务完成
self.message_queue.task_done()
# 更新熔断器统计
if self.performance_monitor and self.circuit_breaker:
self.performance_monitor.update_circuit_breaker_stats(
self.circuit_breaker.get_stats()
)
# 消息间隔,避免发送太快触发风控
message_interval = self.concurrency_config.get("message_interval_ms", 100)
if message_interval > 0:
await asyncio.sleep(message_interval / 1000.0)
except asyncio.TimeoutError:
# 队列为空,继续等待
continue
except asyncio.CancelledError:
# 任务被取消,退出循环
logger.info(f"消费者 {consumer_id} 收到取消信号")
break
except Exception as e:
logger.error(f"消费者 {consumer_id} 异常: {e}")
await asyncio.sleep(0.1) # 短暂休息避免忙等
def _check_circuit_breaker(self) -> bool:
"""检查熔断器状态"""
if not self.concurrency_config.get("enable_circuit_breaker", True):
return False
if self.circuit_breaker_open:
# 检查是否可以尝试恢复
import time
if time.time() - self.circuit_breaker_last_failure > 30: # 30秒后尝试恢复
self.circuit_breaker_open = False
self.circuit_breaker_failures = 0
logger.info("熔断器尝试恢复")
return False
return True
return False
def _record_circuit_breaker_failure(self):
"""记录熔断器失败"""
if not self.concurrency_config.get("enable_circuit_breaker", True):
return
self.circuit_breaker_failures += 1
threshold = self.concurrency_config.get("circuit_breaker_threshold", 5)
if self.circuit_breaker_failures >= threshold:
import time
self.circuit_breaker_open = True
self.circuit_breaker_last_failure = time.time()
logger.warning(f"熔断器开启,连续失败 {self.circuit_breaker_failures}")
def _reset_circuit_breaker(self):
"""重置熔断器"""
if self.circuit_breaker_failures > 0:
self.circuit_breaker_failures = 0
logger.info(f"消费者 {consumer_id} 已退出")
@CLOSE_CALLBACK(in_class=True)
def on_close(self, client_id):
@@ -235,17 +268,37 @@ class BotService:
# 初始化性能配置
self.queue_config = config.get("Queue", {})
self.concurrency_config = config.get("Concurrency", {})
# 创建消息队列
# 创建优先级消息队列
queue_size = self.queue_config.get("max_size", 1000)
self.message_queue = asyncio.Queue(maxsize=queue_size)
logger.info(f"消息队列已创建,容量: {queue_size}")
self.message_queue = PriorityMessageQueue(maxsize=queue_size)
logger.info(f"优先级消息队列已创建,容量: {queue_size}")
# 创建并发控制信号量
max_concurrency = self.concurrency_config.get("plugin_max_concurrency", 8)
self.processing_semaphore = asyncio.Semaphore(max_concurrency)
logger.info(f"并发控制已设置,最大并发: {max_concurrency}")
# 创建自适应熔断器
if self.concurrency_config.get("enable_circuit_breaker", True):
self.circuit_breaker = AdaptiveCircuitBreaker(
failure_threshold=self.concurrency_config.get("circuit_breaker_threshold", 10),
success_threshold=3,
initial_recovery_time=5.0,
max_recovery_time=300.0
)
logger.info("自适应熔断器已创建")
# 创建性能监控器
self.performance_monitor = get_performance_monitor()
logger.info("性能监控器已创建")
# 创建配置热更新监听器
self.config_watcher = ConfigWatcher("main_config.toml", check_interval=5.0)
self.config_watcher.register_callback(self._on_config_update)
await self.config_watcher.start()
logger.info("配置热更新监听器已启动")
# 不需要数据库(简化版本)
# 获取 DLL 路径
@@ -340,6 +393,26 @@ class BotService:
return True
def _on_config_update(self, new_config: dict):
"""配置热更新回调"""
logger.info("正在应用新配置...")
# 更新队列配置
self.queue_config = new_config.get("Queue", self.queue_config)
# 更新并发配置
old_concurrency = self.concurrency_config
self.concurrency_config = new_config.get("Concurrency", self.concurrency_config)
# 更新熔断器配置
if self.circuit_breaker:
new_threshold = self.concurrency_config.get("circuit_breaker_threshold", 10)
if new_threshold != old_concurrency.get("circuit_breaker_threshold", 10):
self.circuit_breaker.failure_threshold = new_threshold
logger.info(f"熔断器阈值已更新: {new_threshold}")
logger.success("配置热更新完成")
async def run(self):
"""运行机器人"""
if not await self.initialize():
@@ -347,6 +420,15 @@ class BotService:
self.is_running = True
# 启动定期性能报告
async def periodic_stats():
while self.is_running:
await asyncio.sleep(300) # 每5分钟输出一次
if self.performance_monitor and self.is_running:
self.performance_monitor.print_stats()
stats_task = asyncio.create_task(periodic_stats())
try:
logger.info("机器人正在运行,按 Ctrl+C 停止...")
while self.is_running:
@@ -354,44 +436,96 @@ class BotService:
except KeyboardInterrupt:
logger.info("收到停止信号...")
finally:
stats_task.cancel()
await self.stop()
async def stop(self):
"""停止机器人"""
logger.info("正在停止机器人...")
self.is_running = False
"""优雅关闭机器人"""
if self.is_shutting_down:
return
self.is_shutting_down = True
# 停止消息消费者
logger.info("=" * 60)
logger.info("正在优雅关闭机器人...")
logger.info("=" * 60)
# 1. 停止接收新消息
self.is_running = False
logger.info("[1/7] 停止接收新消息")
# 2. 等待队列中的消息处理完成(带超时)
if self.message_queue and not self.message_queue.empty():
queue_size = self.message_queue.qsize()
logger.info(f"[2/7] 等待队列中 {queue_size} 条消息处理完成...")
try:
await asyncio.wait_for(
self.message_queue.join(),
timeout=30
)
logger.info("[2/7] 队列消息已全部处理完成")
except asyncio.TimeoutError:
logger.warning("[2/7] 队列消息未在 30 秒内处理完成,强制清空")
# 清空剩余消息
while not self.message_queue.empty():
try:
self.message_queue.get_nowait()
self.message_queue.task_done()
except:
break
else:
logger.info("[2/7] 队列为空,无需等待")
# 3. 停止消息消费者
if self.consumer_tasks:
logger.info("正在停止消息消费者...")
logger.info(f"[3/7] 停止 {len(self.consumer_tasks)}消息消费者...")
for task in self.consumer_tasks:
task.cancel()
# 等待所有消费者任务完成
if self.consumer_tasks:
await asyncio.gather(*self.consumer_tasks, return_exceptions=True)
await asyncio.gather(*self.consumer_tasks, return_exceptions=True)
self.consumer_tasks.clear()
logger.info("消息消费者已停止")
logger.info("[3/7] 消息消费者已停止")
else:
logger.info("[3/7] 无消费者需要停止")
# 清空消息队列
if self.message_queue:
while not self.message_queue.empty():
try:
self.message_queue.get_nowait()
self.message_queue.task_done()
except asyncio.QueueEmpty:
break
logger.info("消息队列已清空")
# 4. 停止配置监听器
if self.config_watcher:
logger.info("[4/7] 停止配置监听器...")
await self.config_watcher.stop()
logger.info("[4/7] 配置监听器已停止")
else:
logger.info("[4/7] 无配置监听器")
# 停止定时任务
# 5. 卸载插件
if self.plugin_manager:
logger.info("[5/7] 卸载插件...")
await self.plugin_manager.unload_plugins()
logger.info("[5/7] 插件已卸载")
else:
logger.info("[5/7] 无插件需要卸载")
# 6. 停止定时任务
if scheduler.running:
logger.info("[6/7] 停止定时任务...")
scheduler.shutdown()
logger.info("[6/7] 定时任务已停止")
else:
logger.info("[6/7] 定时任务未运行")
# 7. 清理回调和销毁微信连接
logger.info("[7/7] 清理资源...")
remove_callback_handler(self)
clear_all_callbacks()
# 销毁微信连接
if self.loader:
self.loader.DestroyWeChat()
logger.success("机器人已停止")
# 输出最终性能报告
if self.performance_monitor:
logger.info("最终性能报告:")
self.performance_monitor.print_stats()
logger.success("=" * 60)
logger.success("机器人已优雅关闭")
logger.success("=" * 60)
async def main():

View File

@@ -2,6 +2,7 @@
AI 聊天插件
支持自定义模型、API 和人设
支持 Redis 存储对话历史和限流
"""
import asyncio
@@ -12,6 +13,7 @@ from datetime import datetime
from loguru import logger
from utils.plugin_base import PluginBase
from utils.decorators import on_text_message, on_quote_message, on_image_message, on_emoji_message
from utils.redis_cache import get_cache
import xml.etree.ElementTree as ET
import base64
import uuid
@@ -95,6 +97,92 @@ class AIChat(PluginBase):
else:
return sender_wxid or from_wxid # 私聊使用用户ID
async def _get_user_nickname(self, bot, from_wxid: str, user_wxid: str, is_group: bool) -> str:
"""
获取用户昵称,优先使用 Redis 缓存
Args:
bot: WechatHookClient 实例
from_wxid: 消息来源群聊ID或私聊用户ID
user_wxid: 用户wxid
is_group: 是否群聊
Returns:
用户昵称
"""
if not is_group:
return ""
nickname = ""
# 1. 优先从 Redis 缓存获取
redis_cache = get_cache()
if redis_cache and redis_cache.enabled:
cached_info = redis_cache.get_user_basic_info(from_wxid, user_wxid)
if cached_info and cached_info.get("nickname"):
logger.debug(f"[缓存命中] 用户昵称: {user_wxid} -> {cached_info['nickname']}")
return cached_info["nickname"]
# 2. 缓存未命中,调用 API 获取
try:
user_info = await bot.get_user_info_in_chatroom(from_wxid, user_wxid)
if user_info and user_info.get("nickName", {}).get("string"):
nickname = user_info["nickName"]["string"]
# 存入缓存
if redis_cache and redis_cache.enabled:
redis_cache.set_user_info(from_wxid, user_wxid, user_info)
logger.debug(f"[已缓存] 用户昵称: {user_wxid} -> {nickname}")
return nickname
except Exception as e:
logger.warning(f"API获取用户昵称失败: {e}")
# 3. 从 MessageLogger 数据库查询
if not nickname:
try:
from plugins.MessageLogger.main import MessageLogger
msg_logger = MessageLogger.get_instance()
if msg_logger:
with msg_logger.get_db_connection() as conn:
with conn.cursor() as cursor:
cursor.execute(
"SELECT nickname FROM messages WHERE sender_wxid = %s AND nickname != '' ORDER BY create_time DESC LIMIT 1",
(user_wxid,)
)
result = cursor.fetchone()
if result:
nickname = result[0]
except Exception as e:
logger.debug(f"从数据库获取昵称失败: {e}")
# 4. 最后降级使用 wxid
if not nickname:
nickname = user_wxid or "未知用户"
return nickname
def _check_rate_limit(self, user_wxid: str) -> tuple:
"""
检查用户是否超过限流
Args:
user_wxid: 用户wxid
Returns:
(是否允许, 剩余次数, 重置时间秒数)
"""
rate_limit_config = self.config.get("rate_limit", {})
if not rate_limit_config.get("enabled", True):
return (True, 999, 0)
redis_cache = get_cache()
if not redis_cache or not redis_cache.enabled:
return (True, 999, 0) # Redis 不可用时不限流
limit = rate_limit_config.get("ai_chat_limit", 20)
window = rate_limit_config.get("ai_chat_window", 60)
return redis_cache.check_rate_limit(user_wxid, limit, window, "ai_chat")
def _add_to_memory(self, chat_id: str, role: str, content, image_base64: str = None):
"""
添加消息到记忆
@@ -108,9 +196,6 @@ class AIChat(PluginBase):
if not self.config.get("memory", {}).get("enabled", False):
return
if chat_id not in self.memory:
self.memory[chat_id] = []
# 如果有图片,构建多模态内容
if image_base64:
message_content = [
@@ -120,6 +205,22 @@ class AIChat(PluginBase):
else:
message_content = content
# 优先使用 Redis 存储
redis_config = self.config.get("redis", {})
if redis_config.get("use_redis_history", True):
redis_cache = get_cache()
if redis_cache and redis_cache.enabled:
ttl = redis_config.get("chat_history_ttl", 86400)
redis_cache.add_chat_message(chat_id, role, message_content, ttl=ttl)
# 裁剪历史
max_messages = self.config["memory"]["max_messages"]
redis_cache.trim_chat_history(chat_id, max_messages)
return
# 降级到内存存储
if chat_id not in self.memory:
self.memory[chat_id] = []
self.memory[chat_id].append({"role": role, "content": message_content})
# 限制记忆长度
@@ -131,16 +232,47 @@ class AIChat(PluginBase):
"""获取记忆中的消息"""
if not self.config.get("memory", {}).get("enabled", False):
return []
# 优先从 Redis 获取
redis_config = self.config.get("redis", {})
if redis_config.get("use_redis_history", True):
redis_cache = get_cache()
if redis_cache and redis_cache.enabled:
max_messages = self.config["memory"]["max_messages"]
return redis_cache.get_chat_history(chat_id, max_messages)
# 降级到内存
return self.memory.get(chat_id, [])
def _clear_memory(self, chat_id: str):
"""清空指定会话的记忆"""
# 清空 Redis
redis_config = self.config.get("redis", {})
if redis_config.get("use_redis_history", True):
redis_cache = get_cache()
if redis_cache and redis_cache.enabled:
redis_cache.clear_chat_history(chat_id)
# 同时清空内存
if chat_id in self.memory:
del self.memory[chat_id]
async def _download_and_encode_image(self, bot, cdnurl: str, aeskey: str) -> str:
"""下载图片并转换为base64"""
"""下载图片并转换为base64,优先从缓存获取"""
try:
# 1. 优先从 Redis 缓存获取
from utils.redis_cache import RedisCache
redis_cache = get_cache()
if redis_cache and redis_cache.enabled:
media_key = RedisCache.generate_media_key(cdnurl, aeskey)
if media_key:
cached_data = redis_cache.get_cached_media(media_key, "image")
if cached_data:
logger.debug(f"[缓存命中] 图片从 Redis 获取: {media_key[:20]}...")
return cached_data
# 2. 缓存未命中,下载图片
logger.debug(f"[缓存未命中] 开始下载图片...")
temp_dir = Path(__file__).parent / "temp"
temp_dir.mkdir(exist_ok=True)
@@ -168,74 +300,114 @@ class AIChat(PluginBase):
with open(save_path, "rb") as f:
image_data = base64.b64encode(f.read()).decode()
base64_result = f"data:image/jpeg;base64,{image_data}"
# 3. 缓存到 Redis供后续使用
if redis_cache and redis_cache.enabled and media_key:
redis_cache.cache_media(media_key, base64_result, "image", ttl=300)
logger.debug(f"[已缓存] 图片缓存到 Redis: {media_key[:20]}...")
try:
Path(save_path).unlink()
except:
pass
return f"data:image/jpeg;base64,{image_data}"
return base64_result
except Exception as e:
logger.error(f"下载图片失败: {e}")
return ""
async def _download_emoji_and_encode(self, cdn_url: str) -> str:
"""下载表情包并转换为base64HTTP 直接下载"""
try:
# 替换 HTML 实体
cdn_url = cdn_url.replace("&amp;", "&")
async def _download_emoji_and_encode(self, cdn_url: str, max_retries: int = 3) -> str:
"""下载表情包并转换为base64HTTP 直接下载,带重试机制),优先从缓存获取"""
# 替换 HTML 实体
cdn_url = cdn_url.replace("&amp;", "&")
temp_dir = Path(__file__).parent / "temp"
temp_dir.mkdir(exist_ok=True)
# 1. 优先从 Redis 缓存获取
from utils.redis_cache import RedisCache
redis_cache = get_cache()
media_key = RedisCache.generate_media_key(cdnurl=cdn_url)
if redis_cache and redis_cache.enabled and media_key:
cached_data = redis_cache.get_cached_media(media_key, "emoji")
if cached_data:
logger.debug(f"[缓存命中] 表情包从 Redis 获取: {media_key[:20]}...")
return cached_data
filename = f"temp_{uuid.uuid4().hex[:8]}.gif"
save_path = temp_dir / filename
# 2. 缓存未命中,下载表情包
logger.debug(f"[缓存未命中] 开始下载表情包...")
temp_dir = Path(__file__).parent / "temp"
temp_dir.mkdir(exist_ok=True)
# 使用 aiohttp 下载
timeout = aiohttp.ClientTimeout(total=30)
filename = f"temp_{uuid.uuid4().hex[:8]}.gif"
save_path = temp_dir / filename
# 配置代理
connector = None
proxy_config = self.config.get("proxy", {})
if proxy_config.get("enabled", False):
proxy_type = proxy_config.get("type", "socks5").upper()
proxy_host = proxy_config.get("host", "127.0.0.1")
proxy_port = proxy_config.get("port", 7890)
proxy_username = proxy_config.get("username")
proxy_password = proxy_config.get("password")
last_error = None
if proxy_username and proxy_password:
proxy_url = f"{proxy_type}://{proxy_username}:{proxy_password}@{proxy_host}:{proxy_port}"
else:
proxy_url = f"{proxy_type}://{proxy_host}:{proxy_port}"
for attempt in range(max_retries):
try:
# 使用 aiohttp 下载,每次重试增加超时时间
timeout = aiohttp.ClientTimeout(total=30 + attempt * 15)
if PROXY_SUPPORT:
try:
connector = ProxyConnector.from_url(proxy_url)
except:
connector = None
# 配置代理
connector = None
proxy_config = self.config.get("proxy", {})
if proxy_config.get("enabled", False):
proxy_type = proxy_config.get("type", "socks5").upper()
proxy_host = proxy_config.get("host", "127.0.0.1")
proxy_port = proxy_config.get("port", 7890)
proxy_username = proxy_config.get("username")
proxy_password = proxy_config.get("password")
async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
async with session.get(cdn_url) as response:
if response.status == 200:
content = await response.read()
with open(save_path, "wb") as f:
f.write(content)
if proxy_username and proxy_password:
proxy_url = f"{proxy_type}://{proxy_username}:{proxy_password}@{proxy_host}:{proxy_port}"
else:
proxy_url = f"{proxy_type}://{proxy_host}:{proxy_port}"
# 编码为 base64
image_data = base64.b64encode(content).decode()
# 删除临时文件
if PROXY_SUPPORT:
try:
save_path.unlink()
connector = ProxyConnector.from_url(proxy_url)
except:
pass
connector = None
return f"data:image/gif;base64,{image_data}"
async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
async with session.get(cdn_url) as response:
if response.status == 200:
content = await response.read()
return ""
except Exception as e:
logger.error(f"下载表情包失败: {e}")
return ""
if len(content) == 0:
logger.warning(f"表情包下载内容为空,重试 {attempt + 1}/{max_retries}")
continue
# 编码为 base64
image_data = base64.b64encode(content).decode()
logger.debug(f"表情包下载成功,大小: {len(content)} 字节")
base64_result = f"data:image/gif;base64,{image_data}"
# 3. 缓存到 Redis供后续使用
if redis_cache and redis_cache.enabled and media_key:
redis_cache.cache_media(media_key, base64_result, "emoji", ttl=300)
logger.debug(f"[已缓存] 表情包缓存到 Redis: {media_key[:20]}...")
return base64_result
else:
logger.warning(f"表情包下载失败,状态码: {response.status},重试 {attempt + 1}/{max_retries}")
except asyncio.TimeoutError:
last_error = "请求超时"
logger.warning(f"表情包下载超时,重试 {attempt + 1}/{max_retries}")
except aiohttp.ClientError as e:
last_error = str(e)
logger.warning(f"表情包下载网络错误: {e},重试 {attempt + 1}/{max_retries}")
except Exception as e:
last_error = str(e)
logger.warning(f"表情包下载异常: {e},重试 {attempt + 1}/{max_retries}")
# 重试前等待(指数退避)
if attempt < max_retries - 1:
await asyncio.sleep(1 * (attempt + 1))
logger.error(f"表情包下载失败,已重试 {max_retries} 次: {last_error}")
return ""
async def _generate_image_description(self, image_base64: str, prompt: str, config: dict) -> str:
"""
@@ -479,37 +651,8 @@ class AIChat(PluginBase):
# 检查是否应该回复
should_reply = self._should_reply(message, content, bot_wxid)
# 获取用户昵称(用于历史记录)
nickname = ""
if is_group:
try:
user_info = await bot.get_user_info_in_chatroom(from_wxid, user_wxid)
if user_info and user_info.get("nickName", {}).get("string"):
nickname = user_info["nickName"]["string"]
except:
pass
# 如果获取昵称失败,从 MessageLogger 数据库查询
if not nickname:
from plugins.MessageLogger.main import MessageLogger
msg_logger = MessageLogger.get_instance()
if msg_logger:
try:
with msg_logger.get_db_connection() as conn:
with conn.cursor() as cursor:
cursor.execute(
"SELECT nickname FROM messages WHERE sender_wxid = %s AND nickname != '' ORDER BY create_time DESC LIMIT 1",
(user_wxid,)
)
result = cursor.fetchone()
if result:
nickname = result[0]
except:
pass
# 最后降级使用 wxid
if not nickname:
nickname = user_wxid or sender_wxid or "未知用户"
# 获取用户昵称(用于历史记录)- 使用缓存优化
nickname = await self._get_user_nickname(bot, from_wxid, user_wxid, is_group)
# 保存到群组历史记录(所有消息都保存,不管是否回复)
if is_group:
@@ -519,6 +662,16 @@ class AIChat(PluginBase):
if not should_reply:
return
# 限流检查(仅在需要回复时检查)
allowed, remaining, reset_time = self._check_rate_limit(user_wxid)
if not allowed:
rate_limit_config = self.config.get("rate_limit", {})
msg = rate_limit_config.get("rate_limit_message", "⚠️ 消息太频繁了,请 {seconds} 秒后再试~")
msg = msg.format(seconds=reset_time)
await bot.send_text(from_wxid, msg)
logger.warning(f"用户 {user_wxid} 触发限流,{reset_time}秒后重置")
return False
# 提取实际消息内容(去除@
actual_content = self._extract_content(message, content)
if not actual_content:
@@ -1004,8 +1157,23 @@ class AIChat(PluginBase):
json.dump(history, f, ensure_ascii=False, indent=2)
temp_file.replace(history_file)
def _use_redis_for_group_history(self) -> bool:
"""检查是否使用 Redis 存储群聊历史"""
redis_config = self.config.get("redis", {})
if not redis_config.get("use_redis_history", True):
return False
redis_cache = get_cache()
return redis_cache and redis_cache.enabled
async def _load_history(self, chat_id: str) -> list:
"""异步读取群聊历史, 用锁避免与写入冲突"""
"""异步读取群聊历史, 优先使用 Redis"""
# 优先使用 Redis
if self._use_redis_for_group_history():
redis_cache = get_cache()
max_history = self.config.get("history", {}).get("max_history", 100)
return redis_cache.get_group_history(chat_id, max_history)
# 降级到文件存储
history_file = self._get_history_file(chat_id)
if not history_file:
return []
@@ -1015,6 +1183,10 @@ class AIChat(PluginBase):
async def _save_history(self, chat_id: str, history: list):
"""异步写入群聊历史, 包含长度截断"""
# Redis 模式下不需要单独保存add_group_message 已经处理
if self._use_redis_for_group_history():
return
history_file = self._get_history_file(chat_id)
if not history_file:
return
@@ -1040,6 +1212,27 @@ class AIChat(PluginBase):
if not self.config.get("history", {}).get("enabled", True):
return
# 构建消息内容
if image_base64:
message_content = [
{"type": "text", "text": content},
{"type": "image_url", "image_url": {"url": image_base64}}
]
else:
message_content = content
# 优先使用 Redis
if self._use_redis_for_group_history():
redis_cache = get_cache()
redis_config = self.config.get("redis", {})
ttl = redis_config.get("group_history_ttl", 172800)
redis_cache.add_group_message(chat_id, nickname, message_content, ttl=ttl)
# 裁剪历史
max_history = self.config.get("history", {}).get("max_history", 100)
redis_cache.trim_group_history(chat_id, max_history)
return
# 降级到文件存储
history_file = self._get_history_file(chat_id)
if not history_file:
return
@@ -1050,17 +1243,10 @@ class AIChat(PluginBase):
message_record = {
"nickname": nickname,
"timestamp": datetime.now().isoformat()
"timestamp": datetime.now().isoformat(),
"content": message_content
}
if image_base64:
message_record["content"] = [
{"type": "text", "text": content},
{"type": "image_url", "image_url": {"url": image_base64}}
]
else:
message_record["content"] = content
history.append(message_record)
max_history = self.config.get("history", {}).get("max_history", 100)
if len(history) > max_history:
@@ -1073,6 +1259,18 @@ class AIChat(PluginBase):
if not self.config.get("history", {}).get("enabled", True):
return
# 优先使用 Redis
if self._use_redis_for_group_history():
redis_cache = get_cache()
redis_config = self.config.get("redis", {})
ttl = redis_config.get("group_history_ttl", 172800)
redis_cache.add_group_message(chat_id, nickname, content, record_id=record_id, ttl=ttl)
# 裁剪历史
max_history = self.config.get("history", {}).get("max_history", 100)
redis_cache.trim_group_history(chat_id, max_history)
return
# 降级到文件存储
history_file = self._get_history_file(chat_id)
if not history_file:
return
@@ -1097,6 +1295,13 @@ class AIChat(PluginBase):
if not self.config.get("history", {}).get("enabled", True):
return
# 优先使用 Redis
if self._use_redis_for_group_history():
redis_cache = get_cache()
redis_cache.update_group_message_by_id(chat_id, record_id, new_content)
return
# 降级到文件存储
history_file = self._get_history_file(chat_id)
if not history_file:
return
@@ -1204,18 +1409,20 @@ class AIChat(PluginBase):
return True
logger.info(f"AI处理引用图片消息: {title_text[:50]}...")
# 获取用户昵称
nickname = ""
if is_group:
try:
user_info = await bot.get_user_info_in_chatroom(from_wxid, user_wxid)
if user_info and user_info.get("nickName", {}).get("string"):
nickname = user_info["nickName"]["string"]
logger.info(f"获取到用户昵称: {nickname}")
except Exception as e:
logger.error(f"获取用户昵称失败: {e}")
# 限流检查
allowed, remaining, reset_time = self._check_rate_limit(user_wxid)
if not allowed:
rate_limit_config = self.config.get("rate_limit", {})
msg = rate_limit_config.get("rate_limit_message", "⚠️ 消息太频繁了,请 {seconds} 秒后再试~")
msg = msg.format(seconds=reset_time)
await bot.send_text(from_wxid, msg)
logger.warning(f"用户 {user_wxid} 触发限流,{reset_time}秒后重置")
return False
# 获取用户昵称 - 使用缓存优化
nickname = await self._get_user_nickname(bot, from_wxid, user_wxid, is_group)
# 下载并编码图片
logger.info(f"开始下载图片: {cdnbigimgurl[:50]}...")
image_base64 = await self._download_and_encode_image(bot, cdnbigimgurl, aeskey)
@@ -1627,34 +1834,8 @@ class AIChat(PluginBase):
if not is_emoji and not aeskey:
return True
# 获取用户昵称
nickname = ""
try:
user_info = await bot.get_user_info_in_chatroom(from_wxid, user_wxid)
if user_info and user_info.get("nickName", {}).get("string"):
nickname = user_info["nickName"]["string"]
except:
pass
if not nickname:
from plugins.MessageLogger.main import MessageLogger
msg_logger = MessageLogger.get_instance()
if msg_logger:
try:
with msg_logger.get_db_connection() as conn:
with conn.cursor() as cursor:
cursor.execute(
"SELECT nickname FROM messages WHERE sender_wxid = %s AND nickname != '' ORDER BY create_time DESC LIMIT 1",
(user_wxid,)
)
result = cursor.fetchone()
if result:
nickname = result[0]
except:
pass
if not nickname:
nickname = user_wxid or sender_wxid or "未知用户"
# 获取用户昵称 - 使用缓存优化
nickname = await self._get_user_nickname(bot, from_wxid, user_wxid, is_group)
# 立即插入占位符到 history
placeholder_id = str(uuid.uuid4())

View File

@@ -2,8 +2,17 @@
插件管理插件
提供插件的热重载、启用、禁用等管理功能
支持的指令:
/插件列表 - 查看所有插件状态
/重载插件 <名称> - 重载指定插件
/重载所有插件 - 重载所有插件
/启用插件 <名称> - 启用指定插件
/禁用插件 <名称> - 禁用指定插件
/刷新插件 - 扫描并发现新插件
/插件帮助 - 显示帮助信息
"""
import sys
import tomllib
from pathlib import Path
from loguru import logger
@@ -18,7 +27,10 @@ class ManagePlugin(PluginBase):
# 插件元数据
description = "插件管理,支持热重载、启用、禁用"
author = "ShiHao"
version = "1.0.0"
version = "2.0.0"
# 最高加载优先级,确保最先加载
load_priority = 100
def __init__(self):
super().__init__()
@@ -34,37 +46,72 @@ class ManagePlugin(PluginBase):
self.admins = main_config.get("Bot", {}).get("admins", [])
logger.info(f"插件管理插件已加载,管理员: {self.admins}")
@on_text_message()
def _check_admin(self, message: dict) -> bool:
"""检查是否是管理员"""
sender_wxid = message.get("SenderWxid", "")
from_wxid = message.get("FromWxid", "")
is_group = message.get("IsGroup", False)
# 私聊时 sender_wxid 可能为空,使用 from_wxid
user_wxid = sender_wxid if is_group else from_wxid
return user_wxid in self.admins
@on_text_message(priority=99)
async def handle_command(self, bot, message: dict):
"""处理管理命令"""
content = message.get("Content", "").strip()
from_wxid = message.get("FromWxid", "")
sender_wxid = message.get("SenderWxid", "")
logger.debug(f"ManagePlugin: content={content}, from={from_wxid}, sender={sender_wxid}, admins={self.admins}")
# 检查权限
if not self.admins or sender_wxid not in self.admins:
return
if not self._check_admin(message):
return True # 继续传递给其他插件
# 插件帮助
if content == "/插件帮助" or content == "/plugin help":
await self._show_help(bot, from_wxid)
return False
# 插件列表
if content == "/插件列表" or content == "/plugins":
elif content == "/插件列表" or content == "/plugins":
await self._list_plugins(bot, from_wxid)
return False
# 重载所有插件
elif content == "/重载所有插件" or content == "/reload all":
await self._reload_all_plugins(bot, from_wxid)
return False
# 重载插件
elif content.startswith("/重载插件 ") or content.startswith("/reload "):
plugin_name = content.split(maxsplit=1)[1].strip()
await self._reload_plugin(bot, from_wxid, plugin_name)
return False
# 启用插件
elif content.startswith("/启用插件 ") or content.startswith("/enable "):
plugin_name = content.split(maxsplit=1)[1].strip()
await self._enable_plugin(bot, from_wxid, plugin_name)
return False
# 禁用插件
elif content.startswith("/禁用插件 ") or content.startswith("/disable "):
plugin_name = content.split(maxsplit=1)[1].strip()
await self._disable_plugin(bot, from_wxid, plugin_name)
return False
# 刷新插件(发现新插件)
elif content == "/刷新插件" or content == "/refresh":
await self._refresh_plugins(bot, from_wxid)
return False
# 加载新插件(从目录加载全新插件)
elif content.startswith("/加载插件 ") or content.startswith("/load "):
plugin_name = content.split(maxsplit=1)[1].strip()
await self._load_new_plugin(bot, from_wxid, plugin_name)
return False
return True # 不是管理命令,继续传递
async def _list_plugins(self, bot, to_wxid: str):
"""列出所有插件"""
@@ -159,3 +206,200 @@ class ManagePlugin(PluginBase):
logger.info(f"插件 {plugin_name} 已被禁用")
else:
await bot.send_text(to_wxid, f"❌ 插件 {plugin_name} 禁用失败")
async def _show_help(self, bot, to_wxid: str):
"""显示帮助信息"""
help_text = """📦 插件管理帮助
/插件列表 - 查看所有插件状态
/插件帮助 - 显示此帮助信息
/加载插件 <名称> - 加载新插件(无需重启)
/重载插件 <名称> - 热重载指定插件
/重载所有插件 - 热重载所有插件
/启用插件 <名称> - 启用已禁用的插件
/禁用插件 <名称> - 禁用指定插件
/刷新插件 - 扫描发现新插件
示例:
/加载插件 NewPlugin
/重载插件 AIChat
/禁用插件 Weather"""
await bot.send_text(to_wxid, help_text)
async def _reload_all_plugins(self, bot, to_wxid: str):
"""重载所有插件"""
pm = PluginManager()
await bot.send_text(to_wxid, "⏳ 正在重载所有插件...")
try:
# 清理插件相关的模块缓存
modules_to_remove = [
name for name in sys.modules.keys()
if name.startswith('plugins.') and 'ManagePlugin' not in name
]
for module_name in modules_to_remove:
del sys.modules[module_name]
# 重载所有插件
reloaded = await pm.reload_plugins()
if reloaded:
await bot.send_text(
to_wxid,
f"✅ 重载完成\n已加载 {len(reloaded)} 个插件:\n" +
"\n".join(f"{name}" for name in reloaded)
)
logger.success(f"已重载 {len(reloaded)} 个插件")
else:
await bot.send_text(to_wxid, "⚠️ 没有插件被重载")
except Exception as e:
logger.error(f"重载所有插件失败: {e}")
await bot.send_text(to_wxid, f"❌ 重载失败: {e}")
async def _refresh_plugins(self, bot, to_wxid: str):
"""刷新插件列表,发现新插件"""
pm = PluginManager()
try:
# 记录刷新前的插件数量
old_count = len(pm.plugin_info)
# 刷新插件列表
await pm.refresh_plugins()
# 计算新发现的插件
new_count = len(pm.plugin_info)
new_plugins = new_count - old_count
if new_plugins > 0:
# 获取新发现的插件名称
new_plugin_names = [
info["name"] for info in pm.plugin_info.values()
if not info.get("enabled", False)
][-new_plugins:]
await bot.send_text(
to_wxid,
f"✅ 发现 {new_plugins} 个新插件:\n" +
"\n".join(f"{name}" for name in new_plugin_names) +
"\n\n使用 /启用插件 <名称> 来启用"
)
logger.info(f"发现 {new_plugins} 个新插件: {new_plugin_names}")
else:
await bot.send_text(to_wxid, " 没有发现新插件")
except Exception as e:
logger.error(f"刷新插件失败: {e}")
await bot.send_text(to_wxid, f"❌ 刷新失败: {e}")
async def _load_new_plugin(self, bot, to_wxid: str, plugin_name: str):
"""加载全新的插件(支持插件类名或目录名)"""
import os
import importlib
import inspect
pm = PluginManager()
# 检查是否已加载
if plugin_name in pm.plugins:
await bot.send_text(to_wxid, f" 插件 {plugin_name} 已经加载,如需重载请使用 /重载插件")
return
try:
# 尝试查找插件
found = False
plugin_class = None
actual_plugin_name = None
for dirname in os.listdir("plugins"):
dirpath = f"plugins/{dirname}"
if not os.path.isdir(dirpath) or not os.path.exists(f"{dirpath}/main.py"):
continue
# 支持通过目录名或类名查找
if dirname == plugin_name or dirname.lower() == plugin_name.lower():
# 通过目录名匹配
module_name = f"plugins.{dirname}.main"
# 清理旧的模块缓存
if module_name in sys.modules:
del sys.modules[module_name]
# 导入模块
module = importlib.import_module(module_name)
# 查找插件类
for name, obj in inspect.getmembers(module):
if (inspect.isclass(obj) and
issubclass(obj, PluginBase) and
obj != PluginBase):
plugin_class = obj
actual_plugin_name = obj.__name__
found = True
break
if found:
break
# 尝试通过类名匹配
try:
module_name = f"plugins.{dirname}.main"
if module_name in sys.modules:
del sys.modules[module_name]
module = importlib.import_module(module_name)
for name, obj in inspect.getmembers(module):
if (inspect.isclass(obj) and
issubclass(obj, PluginBase) and
obj != PluginBase and
(obj.__name__ == plugin_name or obj.__name__.lower() == plugin_name.lower())):
plugin_class = obj
actual_plugin_name = obj.__name__
found = True
break
if found:
break
except:
continue
if not found or not plugin_class:
await bot.send_text(
to_wxid,
f"❌ 未找到插件 {plugin_name}\n"
f"请确认:\n"
f"1. plugins/{plugin_name}/main.py 存在\n"
f"2. main.py 中有继承 PluginBase 的类"
)
return
# 检查是否已加载(用实际类名再检查一次)
if actual_plugin_name in pm.plugins:
await bot.send_text(to_wxid, f" 插件 {actual_plugin_name} 已经加载")
return
# 加载插件
success = await pm._load_plugin_class(plugin_class)
if success:
await bot.send_text(
to_wxid,
f"✅ 插件加载成功\n"
f"名称: {actual_plugin_name}\n"
f"版本: {plugin_class.version}\n"
f"作者: {plugin_class.author}"
)
logger.success(f"新插件 {actual_plugin_name} 已加载")
else:
await bot.send_text(to_wxid, f"❌ 插件 {actual_plugin_name} 加载失败")
except Exception as e:
import traceback
logger.error(f"加载新插件失败: {e}\n{traceback.format_exc()}")
await bot.send_text(to_wxid, f"❌ 加载失败: {e}")

100
plugins/Menu/main.py Normal file
View File

@@ -0,0 +1,100 @@
"""
菜单插件
用户发送 /菜单、/帮助 等指令时,按顺序发送菜单图片
图片命名格式menu1.png、menu2.png、menu3.png ...
"""
import re
import asyncio
from pathlib import Path
from loguru import logger
from utils.plugin_base import PluginBase
from utils.decorators import on_text_message
class Menu(PluginBase):
"""菜单插件"""
# 插件元数据
description = "菜单插件,发送帮助图片"
author = "ShiHao"
version = "1.0.0"
def __init__(self):
super().__init__()
self.menu_dir = None
self.trigger_commands = ["/菜单", "/帮助", "/help", "/menu"]
self.send_interval = 0.5 # 发送间隔(秒),避免发送过快
async def async_init(self):
"""插件异步初始化"""
# 设置菜单图片目录
self.menu_dir = Path(__file__).parent / "images"
self.menu_dir.mkdir(exist_ok=True)
# 扫描现有图片
images = self._get_menu_images()
logger.info(f"菜单插件已加载,图片目录: {self.menu_dir},找到 {len(images)} 张菜单图片")
def _get_menu_images(self) -> list:
"""
获取所有符合命名规范的菜单图片,按序号排序
命名格式menu1.png、menu2.jpg、menu3.jpeg 等
"""
if not self.menu_dir or not self.menu_dir.exists():
return []
# 匹配 menu + 数字 + 图片扩展名
pattern = re.compile(r'^menu(\d+)\.(png|jpg|jpeg|gif|bmp)$', re.IGNORECASE)
images = []
for file in self.menu_dir.iterdir():
if file.is_file():
match = pattern.match(file.name)
if match:
seq_num = int(match.group(1))
images.append((seq_num, file))
# 按序号排序
images.sort(key=lambda x: x[0])
# 只返回文件路径
return [img[1] for img in images]
@on_text_message(priority=60)
async def handle_menu_command(self, bot, message: dict):
"""处理菜单指令"""
content = message.get("Content", "").strip()
from_wxid = message.get("FromWxid", "")
# 检查是否是菜单指令
if content not in self.trigger_commands:
return True # 继续传递给其他插件
logger.info(f"收到菜单指令: {content}, from: {from_wxid}")
# 获取菜单图片
images = self._get_menu_images()
if not images:
await bot.send_text(from_wxid, "暂无菜单图片,请联系管理员添加")
logger.warning(f"菜单图片目录为空: {self.menu_dir}")
return False
# 按顺序发送图片
for i, image_path in enumerate(images):
try:
await bot.send_image(from_wxid, str(image_path))
logger.debug(f"已发送菜单图片 {i+1}/{len(images)}: {image_path.name}")
# 发送间隔,避免发送过快
if i < len(images) - 1:
await asyncio.sleep(self.send_interval)
except Exception as e:
logger.error(f"发送菜单图片失败: {image_path}, 错误: {e}")
logger.success(f"菜单图片发送完成,共 {len(images)}")
return False # 阻止继续传递

View File

@@ -18,6 +18,7 @@ from utils.decorators import (
on_file_message,
on_emoji_message
)
from utils.redis_cache import RedisCache, get_cache
import pymysql
from WechatHook import WechatHookClient
from minio import Minio
@@ -39,6 +40,7 @@ class MessageLogger(PluginBase):
super().__init__()
self.config = None
self.db_config = None
self.redis_cache = None # Redis 缓存实例
# 创建独立的日志记录器
self._setup_logger()
@@ -83,9 +85,22 @@ class MessageLogger(PluginBase):
self.db_config = self.config["database"]
# 初始化 Redis 缓存
redis_config = self.config.get("redis", {})
if redis_config.get("enabled", False):
self.log.info("正在初始化 Redis 缓存...")
self.redis_cache = RedisCache(redis_config)
if self.redis_cache.enabled:
self.log.success(f"Redis 缓存初始化成功TTL={redis_config.get('ttl', 3600)}")
else:
self.log.warning("Redis 缓存初始化失败,将不使用缓存")
self.redis_cache = None
else:
self.log.info("Redis 缓存未启用")
# 初始化 MinIO 客户端
self.minio_client = Minio(
"101.201.65.129:19000",
"115.190.113.141:19000",
access_key="admin",
secret_key="80012029Lz",
secure=False
@@ -216,7 +231,7 @@ class MessageLogger(PluginBase):
return ("", "", "", "", "0")
async def download_image_and_upload(self, bot, cdnurl: str, aeskey: str) -> str:
"""下载图片并上传到 MinIO"""
"""下载图片并上传到 MinIO,同时缓存 base64 供其他插件使用"""
try:
temp_file = Path(__file__).parent / f"temp_{uuid.uuid4().hex}.jpg"
success = await bot.cdn_download(cdnurl, aeskey, str(temp_file), file_type=2)
@@ -225,12 +240,26 @@ class MessageLogger(PluginBase):
# 等待文件下载完成
import asyncio
import base64
for _ in range(50):
if temp_file.exists() and temp_file.stat().st_size > 0:
break
await asyncio.sleep(0.1)
if temp_file.exists() and temp_file.stat().st_size > 0:
# 读取文件并缓存 base64供 AIChat 等插件使用)
with open(temp_file, "rb") as f:
image_data = f.read()
base64_data = f"data:image/jpeg;base64,{base64.b64encode(image_data).decode()}"
# 缓存到 Redis5分钟过期
redis_cache = get_cache()
if redis_cache and redis_cache.enabled:
media_key = RedisCache.generate_media_key(cdnurl, aeskey)
if media_key:
redis_cache.cache_media(media_key, base64_data, "image", ttl=300)
self.log.debug(f"图片已缓存到 Redis: {media_key[:20]}...")
media_url = await self.upload_file_to_minio(str(temp_file), "images")
temp_file.unlink()
return media_url
@@ -326,13 +355,25 @@ class MessageLogger(PluginBase):
return ""
async def download_and_upload(self, url: str, file_type: str, ext: str) -> str:
"""下载文件并上传到 MinIO"""
"""下载文件并上传到 MinIO,同时缓存 base64 供其他插件使用"""
try:
import base64
# 下载文件
async with aiohttp.ClientSession() as session:
async with session.get(url, timeout=aiohttp.ClientTimeout(total=10)) as resp:
if resp.status == 200:
data = await resp.read()
# 缓存表情包 base64供 AIChat 等插件使用)
if file_type == "emojis" and data:
redis_cache = get_cache()
if redis_cache and redis_cache.enabled:
base64_data = f"data:image/gif;base64,{base64.b64encode(data).decode()}"
media_key = RedisCache.generate_media_key(cdnurl=url)
if media_key:
redis_cache.cache_media(media_key, base64_data, "emoji", ttl=300)
self.log.debug(f"表情包已缓存到 Redis: {media_key[:20]}...")
# 保存到临时文件
temp_file = Path(__file__).parent / f"temp_{uuid.uuid4().hex}{ext}"
temp_file.write_bytes(data)
@@ -374,7 +415,7 @@ class MessageLogger(PluginBase):
)
# 返回访问 URL
url = f"http://101.201.65.129:19000/{self.minio_bucket}/{object_name}"
url = f"http://115.190.113.141:19000/{self.minio_bucket}/{object_name}"
self.log.debug(f"文件上传成功: {url}")
return url
@@ -405,29 +446,45 @@ class MessageLogger(PluginBase):
avatar_url = ""
if is_group and self.config["behavior"]["fetch_avatar"]:
try:
self.log.info(f"尝试获取用户信息: from_wxid={from_wxid}, sender_wxid={sender_wxid}")
user_info = await bot.get_user_info_in_chatroom(from_wxid, sender_wxid)
self.log.info(f"获取到用户信息: {user_info}")
cache_hit = False
if user_info:
# 处理不同的数据结构
if isinstance(user_info.get("nickName"), dict):
nickname = user_info.get("nickName", {}).get("string", "")
# 1. 先尝试从 Redis 缓存获取
if self.redis_cache and self.redis_cache.enabled:
cached_info = self.redis_cache.get_user_basic_info(from_wxid, sender_wxid)
if cached_info:
nickname = cached_info.get("nickname", "")
avatar_url = cached_info.get("avatar_url", "")
if nickname and avatar_url:
cache_hit = True
self.log.debug(f"[缓存命中] {sender_wxid}: {nickname}")
# 2. 缓存未命中,调用 API 获取
if not cache_hit:
try:
self.log.info(f"[缓存未命中] 调用API获取用户信息: {sender_wxid}")
user_info = await bot.get_user_info_in_chatroom(from_wxid, sender_wxid)
if user_info:
# 处理不同的数据结构
if isinstance(user_info.get("nickName"), dict):
nickname = user_info.get("nickName", {}).get("string", "")
else:
nickname = user_info.get("nickName", "")
avatar_url = user_info.get("bigHeadImgUrl", "")
self.log.info(f"API获取成功: nickname={nickname}, avatar_url={avatar_url[:50] if avatar_url else ''}...")
# 3. 将用户信息存入 Redis 缓存
if self.redis_cache and self.redis_cache.enabled and nickname:
self.redis_cache.set_user_info(from_wxid, sender_wxid, user_info)
self.log.debug(f"[已缓存] {sender_wxid}: {nickname}")
else:
nickname = user_info.get("nickName", "")
self.log.warning(f"用户信息为空: {sender_wxid}")
avatar_url = user_info.get("bigHeadImgUrl", "")
self.log.info(f"解析用户信息: nickname={nickname}, avatar_url={avatar_url[:50]}...")
else:
self.log.warning(f"用户信息为空: {sender_wxid}")
except Exception as e:
self.log.error(f"获取用户信息失败: {e}")
except Exception as e:
self.log.error(f"获取用户信息失败: {e}")
import traceback
self.log.error(f"详细错误: {traceback.format_exc()}")
# 如果获取失败,从历史记录中查找
# 4. 如果仍然没有获取到,从历史记录中查找
if not nickname or not avatar_url:
self.log.info(f"尝试从历史记录获取用户信息: {sender_wxid}")
try:

View File

@@ -46,7 +46,7 @@ class PerformanceMonitor(PluginBase):
if sender_wxid not in admins:
return
if content in ["/性能", "/stats", "/状态"]:
if content in ["/性能", "/stats", "/状态", "/性能报告"]:
stats_msg = await self._get_performance_stats(bot)
await bot.send_text(from_wxid, stats_msg)
return False # 阻止其他插件处理
@@ -66,11 +66,21 @@ class PerformanceMonitor(PluginBase):
async def _get_performance_stats(self, bot) -> str:
"""获取性能统计信息"""
try:
# 尝试使用新的性能监控器
try:
from utils.bot_utils import get_performance_monitor
monitor = get_performance_monitor()
if monitor and monitor.message_received > 0:
return self._format_new_stats(monitor)
except ImportError:
pass
# 降级到旧的统计方式
stats = await self._get_performance_data()
# 格式化统计信息
uptime_hours = (time.time() - self.start_time) / 3600
msg = f"""📊 系统性能统计
🕐 运行时间: {uptime_hours:.1f} 小时
@@ -94,11 +104,57 @@ class PerformanceMonitor(PluginBase):
• 过滤模式: {stats['ignore_mode']}"""
return msg
except Exception as e:
logger.error(f"获取性能统计失败: {e}")
return f"❌ 获取性能统计失败: {str(e)}"
def _format_new_stats(self, monitor) -> str:
"""格式化新性能监控器的统计信息"""
stats = monitor.get_stats()
# 基础信息
msg = f"""📊 系统性能报告
🕐 运行时间: {stats['uptime_formatted']}
📨 消息统计:
• 收到: {stats['messages']['received']}
• 处理: {stats['messages']['processed']}
• 失败: {stats['messages']['failed']}
• 丢弃: {stats['messages']['dropped']}
• 成功率: {stats['messages']['success_rate']}
• 处理速率: {stats['messages']['processing_rate']}
⚡ 处理性能:
• 平均耗时: {stats['processing_time']['average_ms']}ms
• 最大耗时: {stats['processing_time']['max_ms']}ms
• 最小耗时: {stats['processing_time']['min_ms']}ms
📦 队列状态:
• 当前大小: {stats['queue']['current_size']}
• 历史最大: {stats['queue']['max_size']}"""
# 熔断器状态
cb = stats.get('circuit_breaker', {})
if cb:
state_icon = {'closed': '🟢', 'open': '🔴', 'half_open': '🟡'}.get(cb.get('state', ''), '')
msg += f"""
🔌 熔断器:
• 状态: {state_icon} {cb.get('state', 'N/A')}
• 失败计数: {cb.get('failure_count', 0)}
• 恢复时间: {cb.get('current_recovery_time', 0):.0f}s"""
# 插件耗时排行
plugins = stats.get('plugins', [])
if plugins:
msg += "\n\n🔧 插件耗时排行:"
for i, p in enumerate(plugins[:5], 1):
msg += f"\n {i}. {p['name']}: {p['avg_time_ms']}ms ({p['calls']}次)"
return msg
async def _get_performance_data(self) -> dict:
"""获取性能数据"""
# 系统资源简化版本不依赖psutil

View File

@@ -19,6 +19,7 @@ import pymysql
from loguru import logger
from utils.plugin_base import PluginBase
from utils.decorators import on_text_message
from utils.redis_cache import get_cache
from WechatHook import WechatHookClient
try:
@@ -49,14 +50,14 @@ class SignInPlugin(PluginBase):
self.config = tomllib.load(f)
self.db_config = self.config["database"]
# 创建临时文件夹
self.temp_dir = Path(__file__).parent / "temp"
self.temp_dir.mkdir(exist_ok=True)
# 图片文件夹
self.images_dir = Path(__file__).parent / "images"
logger.success("签到插件初始化完成")
def get_db_connection(self):
@@ -149,29 +150,42 @@ class SignInPlugin(PluginBase):
logger.error(f"更新用户昵称失败: {e}")
return False
async def get_user_nickname_from_group(self, client: WechatHookClient,
async def get_user_nickname_from_group(self, client: WechatHookClient,
group_wxid: str, user_wxid: str) -> str:
"""从群聊中获取用户昵称"""
"""从群聊中获取用户昵称(优先使用缓存)"""
try:
logger.debug(f"尝试获取用户 {user_wxid} 在群 {group_wxid} 中的昵称")
# 使用11174 API获取单个用户的详细信息
# 动态获取缓存实例(由 MessageLogger 初始化)
redis_cache = get_cache()
# 1. 先尝试从 Redis 缓存获取
if redis_cache and redis_cache.enabled:
cached_info = redis_cache.get_user_basic_info(group_wxid, user_wxid)
if cached_info and cached_info.get("nickname"):
logger.debug(f"[缓存命中] {user_wxid}: {cached_info['nickname']}")
return cached_info["nickname"]
# 2. 缓存未命中,调用 API 获取
logger.debug(f"[缓存未命中] 调用API获取用户昵称: {user_wxid}")
user_info = await client.get_user_info_in_chatroom(group_wxid, user_wxid)
if user_info:
# 从返回的详细信息中提取昵称
nickname = user_info.get("nickName", {}).get("string", "")
if nickname:
logger.success(f"获取用户昵称: {user_wxid} -> {nickname}")
logger.success(f"API获取用户昵称成功: {user_wxid} -> {nickname}")
# 3. 将用户信息存入缓存
if redis_cache and redis_cache.enabled:
redis_cache.set_user_info(group_wxid, user_wxid, user_info)
logger.debug(f"[已缓存] {user_wxid}: {nickname}")
return nickname
else:
logger.warning(f"用户 {user_wxid} 的昵称字段为空")
else:
logger.warning(f"未找到用户 {user_wxid} 在群 {group_wxid} 中的信息")
return ""
except Exception as e:
logger.error(f"获取群成员昵称失败: {e}")
return ""
@@ -770,13 +784,24 @@ class SignInPlugin(PluginBase):
current_points = updated_user["points"] if updated_user else total_earned
updated_user["points"] = current_points
# 尝试获取用户头像
# 尝试获取用户头像(优先使用缓存)
avatar_url = None
if is_group:
try:
user_detail = await client.get_user_info_in_chatroom(from_wxid, user_wxid)
if user_detail:
avatar_url = user_detail.get("bigHeadImgUrl", "")
redis_cache = get_cache()
# 先从缓存获取
if redis_cache and redis_cache.enabled:
cached_info = redis_cache.get_user_basic_info(from_wxid, user_wxid)
if cached_info:
avatar_url = cached_info.get("avatar_url", "")
# 缓存未命中则调用 API
if not avatar_url:
user_detail = await client.get_user_info_in_chatroom(from_wxid, user_wxid)
if user_detail:
avatar_url = user_detail.get("bigHeadImgUrl", "")
# 存入缓存
if redis_cache and redis_cache.enabled:
redis_cache.set_user_info(from_wxid, user_wxid, user_detail)
except Exception as e:
logger.warning(f"获取用户头像失败: {e}")
@@ -864,13 +889,24 @@ class SignInPlugin(PluginBase):
self.update_user_nickname(user_wxid, nickname)
user_info["nickname"] = nickname
# 尝试获取用户头像
# 尝试获取用户头像(优先使用缓存)
avatar_url = None
if is_group:
try:
user_detail = await client.get_user_info_in_chatroom(from_wxid, user_wxid)
if user_detail:
avatar_url = user_detail.get("bigHeadImgUrl", "")
redis_cache = get_cache()
# 先从缓存获取
if redis_cache and redis_cache.enabled:
cached_info = redis_cache.get_user_basic_info(from_wxid, user_wxid)
if cached_info:
avatar_url = cached_info.get("avatar_url", "")
# 缓存未命中则调用 API
if not avatar_url:
user_detail = await client.get_user_info_in_chatroom(from_wxid, user_wxid)
if user_detail:
avatar_url = user_detail.get("bigHeadImgUrl", "")
# 存入缓存
if redis_cache and redis_cache.enabled:
redis_cache.set_user_info(from_wxid, user_wxid, user_detail)
except Exception as e:
logger.warning(f"获取用户头像失败: {e}")

View File

@@ -3,3 +3,4 @@ APScheduler==3.11.0
aiohttp==3.9.1
Pillow>=10.0.0
aiohttp-socks>=0.8.0
redis>=5.0.0

658
utils/bot_utils.py Normal file
View File

@@ -0,0 +1,658 @@
"""
机器人核心工具模块
包含:
- 优先级消息队列
- 自适应熔断器
- 请求重试机制
- 配置热更新
- 性能监控
"""
import asyncio
import time
import heapq
import os
import tomllib
import functools
from pathlib import Path
from typing import Dict, List, Optional, Callable, Any, Tuple
from dataclasses import dataclass, field
from loguru import logger
import aiohttp
# ==================== 优先级消息队列 ====================
# 消息优先级定义
class MessagePriority:
"""消息优先级常量"""
CRITICAL = 100 # 系统消息、登录信息
HIGH = 80 # 管理员命令
NORMAL = 50 # @机器人的消息
LOW = 20 # 普通群消息
# 高优先级消息类型
PRIORITY_MESSAGE_TYPES = {
11025: MessagePriority.CRITICAL, # 登录信息
11058: MessagePriority.CRITICAL, # 系统消息
11098: MessagePriority.HIGH, # 群成员加入
11099: MessagePriority.HIGH, # 群成员退出
11100: MessagePriority.HIGH, # 群信息变更
11056: MessagePriority.HIGH, # 好友请求
}
@dataclass(order=True)
class PriorityMessage:
"""优先级消息包装"""
priority: int
timestamp: float = field(compare=False)
msg_type: int = field(compare=False)
data: dict = field(compare=False)
def __init__(self, msg_type: int, data: dict, priority: int = None):
# 优先级越高数值越小因为heapq是最小堆
self.priority = -(priority or PRIORITY_MESSAGE_TYPES.get(msg_type, MessagePriority.NORMAL))
self.timestamp = time.time()
self.msg_type = msg_type
self.data = data
class PriorityMessageQueue:
"""优先级消息队列"""
def __init__(self, maxsize: int = 1000):
self.maxsize = maxsize
self._heap: List[PriorityMessage] = []
self._lock = asyncio.Lock()
self._not_empty = asyncio.Event()
self._unfinished_tasks = 0
self._finished = asyncio.Event()
self._finished.set()
def qsize(self) -> int:
"""返回队列大小"""
return len(self._heap)
def empty(self) -> bool:
"""队列是否为空"""
return len(self._heap) == 0
def full(self) -> bool:
"""队列是否已满"""
return len(self._heap) >= self.maxsize
async def put(self, msg_type: int, data: dict, priority: int = None):
"""添加消息到队列"""
async with self._lock:
msg = PriorityMessage(msg_type, data, priority)
heapq.heappush(self._heap, msg)
self._unfinished_tasks += 1
self._finished.clear()
self._not_empty.set()
async def get(self) -> Tuple[int, dict]:
"""获取优先级最高的消息"""
while True:
async with self._lock:
if self._heap:
msg = heapq.heappop(self._heap)
if not self._heap:
self._not_empty.clear()
return (msg.msg_type, msg.data)
# 等待新消息
await self._not_empty.wait()
def get_nowait(self) -> Tuple[int, dict]:
"""非阻塞获取消息"""
if not self._heap:
raise asyncio.QueueEmpty()
msg = heapq.heappop(self._heap)
if not self._heap:
self._not_empty.clear()
return (msg.msg_type, msg.data)
def task_done(self):
"""标记任务完成"""
self._unfinished_tasks -= 1
if self._unfinished_tasks == 0:
self._finished.set()
async def join(self):
"""等待所有任务完成"""
await self._finished.wait()
def drop_lowest_priority(self) -> bool:
"""丢弃优先级最低的消息"""
if not self._heap:
return False
# 找到优先级最低的消息priority值最大因为是负数所以最小
min_idx = 0
for i, msg in enumerate(self._heap):
if msg.priority > self._heap[min_idx].priority:
min_idx = i
# 删除该消息
self._heap.pop(min_idx)
heapq.heapify(self._heap)
self._unfinished_tasks -= 1
return True
# ==================== 自适应熔断器 ====================
class AdaptiveCircuitBreaker:
"""自适应熔断器"""
# 熔断器状态
STATE_CLOSED = "closed" # 正常状态
STATE_OPEN = "open" # 熔断状态
STATE_HALF_OPEN = "half_open" # 半开状态(尝试恢复)
def __init__(
self,
failure_threshold: int = 5,
success_threshold: int = 3,
initial_recovery_time: float = 5.0,
max_recovery_time: float = 300.0,
recovery_multiplier: float = 2.0
):
"""
初始化熔断器
Args:
failure_threshold: 触发熔断的连续失败次数
success_threshold: 恢复正常的连续成功次数
initial_recovery_time: 初始恢复等待时间(秒)
max_recovery_time: 最大恢复等待时间(秒)
recovery_multiplier: 恢复时间增长倍数
"""
self.failure_threshold = failure_threshold
self.success_threshold = success_threshold
self.initial_recovery_time = initial_recovery_time
self.max_recovery_time = max_recovery_time
self.recovery_multiplier = recovery_multiplier
# 状态
self.state = self.STATE_CLOSED
self.failure_count = 0
self.success_count = 0
self.last_failure_time = 0
self.current_recovery_time = initial_recovery_time
# 统计
self.total_failures = 0
self.total_successes = 0
self.total_rejections = 0
def is_open(self) -> bool:
"""检查熔断器是否开启(是否应该拒绝请求)"""
if self.state == self.STATE_CLOSED:
return False
if self.state == self.STATE_OPEN:
# 检查是否可以尝试恢复
elapsed = time.time() - self.last_failure_time
if elapsed >= self.current_recovery_time:
self.state = self.STATE_HALF_OPEN
self.success_count = 0
logger.info(f"熔断器进入半开状态,尝试恢复(等待了 {elapsed:.1f}s")
return False
return True
# 半开状态,允许请求通过
return False
def record_success(self):
"""记录成功"""
self.total_successes += 1
if self.state == self.STATE_HALF_OPEN:
self.success_count += 1
if self.success_count >= self.success_threshold:
# 恢复正常
self.state = self.STATE_CLOSED
self.failure_count = 0
self.success_count = 0
self.current_recovery_time = self.initial_recovery_time
logger.success(f"熔断器已恢复正常(连续成功 {self.success_threshold} 次)")
else:
# 正常状态,重置失败计数
self.failure_count = 0
def record_failure(self):
"""记录失败"""
self.total_failures += 1
self.failure_count += 1
self.last_failure_time = time.time()
if self.state == self.STATE_HALF_OPEN:
# 半开状态下失败,重新熔断
self.state = self.STATE_OPEN
self.success_count = 0
# 增加恢复时间
self.current_recovery_time = min(
self.current_recovery_time * self.recovery_multiplier,
self.max_recovery_time
)
logger.warning(f"熔断器重新开启,下次恢复等待 {self.current_recovery_time:.1f}s")
elif self.state == self.STATE_CLOSED:
if self.failure_count >= self.failure_threshold:
self.state = self.STATE_OPEN
logger.warning(f"熔断器开启,连续失败 {self.failure_count}")
def record_rejection(self):
"""记录被拒绝的请求"""
self.total_rejections += 1
def get_stats(self) -> dict:
"""获取统计信息"""
return {
"state": self.state,
"failure_count": self.failure_count,
"success_count": self.success_count,
"current_recovery_time": self.current_recovery_time,
"total_failures": self.total_failures,
"total_successes": self.total_successes,
"total_rejections": self.total_rejections
}
# ==================== 请求重试机制 ====================
class RetryConfig:
"""重试配置"""
def __init__(
self,
max_retries: int = 3,
initial_delay: float = 1.0,
max_delay: float = 30.0,
exponential_base: float = 2.0,
retryable_exceptions: tuple = (
aiohttp.ClientError,
asyncio.TimeoutError,
ConnectionError,
)
):
self.max_retries = max_retries
self.initial_delay = initial_delay
self.max_delay = max_delay
self.exponential_base = exponential_base
self.retryable_exceptions = retryable_exceptions
def retry_async(config: RetryConfig = None):
"""异步重试装饰器"""
if config is None:
config = RetryConfig()
def decorator(func):
@functools.wraps(func)
async def wrapper(*args, **kwargs):
last_exception = None
for attempt in range(config.max_retries + 1):
try:
return await func(*args, **kwargs)
except config.retryable_exceptions as e:
last_exception = e
if attempt == config.max_retries:
logger.error(f"重试 {config.max_retries} 次后仍然失败: {func.__name__}")
raise
# 计算延迟时间(指数退避)
delay = min(
config.initial_delay * (config.exponential_base ** attempt),
config.max_delay
)
logger.warning(
f"请求失败,{delay:.1f}s 后重试 "
f"(第 {attempt + 1}/{config.max_retries} 次): {e}"
)
await asyncio.sleep(delay)
raise last_exception
return wrapper
return decorator
async def request_with_retry(
session: aiohttp.ClientSession,
method: str,
url: str,
max_retries: int = 3,
**kwargs
) -> aiohttp.ClientResponse:
"""带重试的 HTTP 请求"""
config = RetryConfig(max_retries=max_retries)
last_exception = None
for attempt in range(config.max_retries + 1):
try:
response = await session.request(method, url, **kwargs)
return response
except config.retryable_exceptions as e:
last_exception = e
if attempt == config.max_retries:
raise
delay = min(
config.initial_delay * (config.exponential_base ** attempt),
config.max_delay
)
logger.warning(f"HTTP 请求失败,{delay:.1f}s 后重试: {e}")
await asyncio.sleep(delay)
raise last_exception
# ==================== 配置热更新 ====================
class ConfigWatcher:
"""配置文件监听器"""
def __init__(self, config_path: str, check_interval: float = 5.0):
"""
初始化配置监听器
Args:
config_path: 配置文件路径
check_interval: 检查间隔(秒)
"""
self.config_path = Path(config_path)
self.check_interval = check_interval
self.last_mtime = 0
self.callbacks: List[Callable[[dict], Any]] = []
self.current_config: dict = {}
self._running = False
self._task: Optional[asyncio.Task] = None
def register_callback(self, callback: Callable[[dict], Any]):
"""注册配置更新回调"""
self.callbacks.append(callback)
def unregister_callback(self, callback: Callable[[dict], Any]):
"""取消注册回调"""
if callback in self.callbacks:
self.callbacks.remove(callback)
def _load_config(self) -> dict:
"""加载配置文件"""
try:
with open(self.config_path, "rb") as f:
return tomllib.load(f)
except Exception as e:
logger.error(f"加载配置文件失败: {e}")
return {}
def get_config(self) -> dict:
"""获取当前配置"""
return self.current_config
async def start(self):
"""启动配置监听"""
if self._running:
return
self._running = True
# 初始加载
if self.config_path.exists():
self.last_mtime = os.path.getmtime(self.config_path)
self.current_config = self._load_config()
self._task = asyncio.create_task(self._watch_loop())
logger.info(f"配置监听器已启动: {self.config_path}")
async def stop(self):
"""停止配置监听"""
self._running = False
if self._task:
self._task.cancel()
try:
await self._task
except asyncio.CancelledError:
pass
logger.info("配置监听器已停止")
async def _watch_loop(self):
"""监听循环"""
while self._running:
try:
await asyncio.sleep(self.check_interval)
if not self.config_path.exists():
continue
mtime = os.path.getmtime(self.config_path)
if mtime > self.last_mtime:
logger.info("检测到配置文件变化,重新加载...")
new_config = self._load_config()
if new_config:
old_config = self.current_config
self.current_config = new_config
self.last_mtime = mtime
# 通知所有回调
for callback in self.callbacks:
try:
if asyncio.iscoroutinefunction(callback):
await callback(new_config)
else:
callback(new_config)
except Exception as e:
logger.error(f"配置更新回调执行失败: {e}")
logger.success("配置已热更新")
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"配置监听异常: {e}")
# ==================== 性能监控 ====================
class PerformanceMonitor:
"""性能监控器"""
def __init__(self):
self.start_time = time.time()
# 消息统计
self.message_received = 0
self.message_processed = 0
self.message_failed = 0
self.message_dropped = 0
# 处理时间统计
self.processing_times: List[float] = []
self.max_processing_times = 1000 # 保留最近1000条记录
# 插件统计
self.plugin_stats: Dict[str, dict] = {}
# 队列统计
self.queue_size_history: List[Tuple[float, int]] = []
self.max_queue_history = 100
# 熔断器统计
self.circuit_breaker_stats: dict = {}
def record_message_received(self):
"""记录收到消息"""
self.message_received += 1
def record_message_processed(self, processing_time: float):
"""记录消息处理完成"""
self.message_processed += 1
self.processing_times.append(processing_time)
# 限制历史记录数量
if len(self.processing_times) > self.max_processing_times:
self.processing_times = self.processing_times[-self.max_processing_times:]
def record_message_failed(self):
"""记录消息处理失败"""
self.message_failed += 1
def record_message_dropped(self):
"""记录消息被丢弃"""
self.message_dropped += 1
def record_queue_size(self, size: int):
"""记录队列大小"""
self.queue_size_history.append((time.time(), size))
if len(self.queue_size_history) > self.max_queue_history:
self.queue_size_history = self.queue_size_history[-self.max_queue_history:]
def record_plugin_execution(self, plugin_name: str, execution_time: float, success: bool):
"""记录插件执行"""
if plugin_name not in self.plugin_stats:
self.plugin_stats[plugin_name] = {
"total_calls": 0,
"success_calls": 0,
"failed_calls": 0,
"total_time": 0,
"max_time": 0,
"recent_times": []
}
stats = self.plugin_stats[plugin_name]
stats["total_calls"] += 1
stats["total_time"] += execution_time
stats["max_time"] = max(stats["max_time"], execution_time)
stats["recent_times"].append(execution_time)
if len(stats["recent_times"]) > 100:
stats["recent_times"] = stats["recent_times"][-100:]
if success:
stats["success_calls"] += 1
else:
stats["failed_calls"] += 1
def update_circuit_breaker_stats(self, stats: dict):
"""更新熔断器统计"""
self.circuit_breaker_stats = stats
def get_stats(self) -> dict:
"""获取完整统计信息"""
uptime = time.time() - self.start_time
# 计算平均处理时间
avg_processing_time = 0
if self.processing_times:
avg_processing_time = sum(self.processing_times) / len(self.processing_times)
# 计算处理速率
processing_rate = self.message_processed / uptime if uptime > 0 else 0
# 计算成功率
total = self.message_processed + self.message_failed
success_rate = self.message_processed / total if total > 0 else 1.0
return {
"uptime_seconds": uptime,
"uptime_formatted": self._format_uptime(uptime),
"messages": {
"received": self.message_received,
"processed": self.message_processed,
"failed": self.message_failed,
"dropped": self.message_dropped,
"success_rate": f"{success_rate * 100:.1f}%",
"processing_rate": f"{processing_rate:.2f}/s"
},
"processing_time": {
"average_ms": f"{avg_processing_time * 1000:.1f}",
"max_ms": f"{max(self.processing_times) * 1000:.1f}" if self.processing_times else "0",
"min_ms": f"{min(self.processing_times) * 1000:.1f}" if self.processing_times else "0"
},
"queue": {
"current_size": self.queue_size_history[-1][1] if self.queue_size_history else 0,
"max_size": max(s[1] for s in self.queue_size_history) if self.queue_size_history else 0
},
"circuit_breaker": self.circuit_breaker_stats,
"plugins": self._get_plugin_summary()
}
def _get_plugin_summary(self) -> List[dict]:
"""获取插件统计摘要"""
summary = []
for name, stats in self.plugin_stats.items():
avg_time = stats["total_time"] / stats["total_calls"] if stats["total_calls"] > 0 else 0
summary.append({
"name": name,
"calls": stats["total_calls"],
"success_rate": f"{stats['success_calls'] / stats['total_calls'] * 100:.1f}%" if stats["total_calls"] > 0 else "N/A",
"avg_time_ms": f"{avg_time * 1000:.1f}",
"max_time_ms": f"{stats['max_time'] * 1000:.1f}"
})
# 按平均时间排序
summary.sort(key=lambda x: float(x["avg_time_ms"]), reverse=True)
return summary
def _format_uptime(self, seconds: float) -> str:
"""格式化运行时间"""
days = int(seconds // 86400)
hours = int((seconds % 86400) // 3600)
minutes = int((seconds % 3600) // 60)
secs = int(seconds % 60)
if days > 0:
return f"{days}{hours}小时 {minutes}分钟"
elif hours > 0:
return f"{hours}小时 {minutes}分钟"
elif minutes > 0:
return f"{minutes}分钟 {secs}"
else:
return f"{secs}"
def print_stats(self):
"""打印统计信息到日志"""
stats = self.get_stats()
logger.info("=" * 50)
logger.info("性能监控报告")
logger.info("=" * 50)
logger.info(f"运行时间: {stats['uptime_formatted']}")
logger.info(f"消息统计: 收到 {stats['messages']['received']}, "
f"处理 {stats['messages']['processed']}, "
f"失败 {stats['messages']['failed']}, "
f"丢弃 {stats['messages']['dropped']}")
logger.info(f"成功率: {stats['messages']['success_rate']}, "
f"处理速率: {stats['messages']['processing_rate']}")
logger.info(f"平均处理时间: {stats['processing_time']['average_ms']}ms")
logger.info(f"队列大小: {stats['queue']['current_size']}")
logger.info(f"熔断器状态: {stats['circuit_breaker'].get('state', 'N/A')}")
if stats['plugins']:
logger.info("插件耗时排行:")
for i, p in enumerate(stats['plugins'][:5], 1):
logger.info(f" {i}. {p['name']}: {p['avg_time_ms']}ms (调用 {p['calls']} 次)")
logger.info("=" * 50)
# ==================== 全局实例 ====================
# 性能监控器单例
_performance_monitor: Optional[PerformanceMonitor] = None
def get_performance_monitor() -> PerformanceMonitor:
"""获取性能监控器实例"""
global _performance_monitor
if _performance_monitor is None:
_performance_monitor = PerformanceMonitor()
return _performance_monitor

View File

@@ -1,4 +1,5 @@
from abc import ABC
from typing import List
from loguru import logger
@@ -13,6 +14,14 @@ class PluginBase(ABC):
author: str = "未知"
version: str = "1.0.0"
# 插件依赖(填写依赖的插件类名列表)
# 例如: dependencies = ["MessageLogger", "AIChat"]
dependencies: List[str] = []
# 加载优先级数值越大越先加载默认50
# 基础插件设置高优先级,依赖其他插件的设置低优先级
load_priority: int = 50
def __init__(self):
self.enabled = False
self._scheduled_jobs = set()

View File

@@ -117,24 +117,107 @@ class PluginManager(metaclass=Singleton):
if not found:
logger.warning(f"未找到插件类 {plugin_name}")
def _resolve_load_order(self, plugin_classes: List[Type[PluginBase]]) -> List[Type[PluginBase]]:
"""
解析插件加载顺序(拓扑排序 + 优先级排序)
Args:
plugin_classes: 插件类列表
Returns:
按依赖关系和优先级排序后的插件类列表
"""
# 构建插件名到类的映射
name_to_class = {cls.__name__: cls for cls in plugin_classes}
# 构建依赖图
dependencies = {}
for cls in plugin_classes:
deps = getattr(cls, 'dependencies', [])
dependencies[cls.__name__] = [d for d in deps if d in name_to_class]
# 拓扑排序
sorted_names = []
visited = set()
temp_visited = set()
def visit(name: str):
if name in temp_visited:
# 检测到循环依赖
logger.warning(f"检测到循环依赖: {name}")
return
if name in visited:
return
temp_visited.add(name)
# 先访问依赖
for dep in dependencies.get(name, []):
visit(dep)
temp_visited.remove(name)
visited.add(name)
sorted_names.append(name)
# 按优先级排序后再进行拓扑排序
priority_sorted = sorted(
plugin_classes,
key=lambda cls: getattr(cls, 'load_priority', 50),
reverse=True
)
for cls in priority_sorted:
if cls.__name__ not in visited:
visit(cls.__name__)
# 返回排序后的类列表
return [name_to_class[name] for name in sorted_names if name in name_to_class]
async def load_plugins(self, load_disabled: bool = True) -> Union[List[str], bool]:
"""加载所有插件(按依赖顺序)"""
loaded_plugins = []
# 第一步:收集所有插件类
all_plugin_classes = []
plugin_disabled_map = {}
for dirname in os.listdir("plugins"):
if os.path.isdir(f"plugins/{dirname}") and os.path.exists(f"plugins/{dirname}/main.py"):
try:
module = importlib.import_module(f"plugins.{dirname}.main")
for name, obj in inspect.getmembers(module):
if inspect.isclass(obj) and issubclass(obj, PluginBase) and obj != PluginBase:
all_plugin_classes.append(obj)
# 记录是否禁用
is_disabled = False
if not load_disabled:
is_disabled = obj.__name__ in self.excluded_plugins or dirname in self.excluded_plugins
if await self._load_plugin_class(obj, is_disabled=is_disabled):
loaded_plugins.append(obj.__name__)
plugin_disabled_map[obj.__name__] = is_disabled
except:
logger.error(f"加载 {dirname} 时发生错误: {traceback.format_exc()}")
# 第二步:按依赖顺序排序
sorted_classes = self._resolve_load_order(all_plugin_classes)
logger.info(f"插件加载顺序: {[cls.__name__ for cls in sorted_classes]}")
# 第三步:按顺序加载插件
for plugin_class in sorted_classes:
plugin_name = plugin_class.__name__
is_disabled = plugin_disabled_map.get(plugin_name, False)
# 检查依赖是否已加载
deps = getattr(plugin_class, 'dependencies', [])
deps_satisfied = all(dep in self.plugins for dep in deps)
if not deps_satisfied and not is_disabled:
missing_deps = [dep for dep in deps if dep not in self.plugins]
logger.warning(f"插件 {plugin_name} 的依赖未满足: {missing_deps},跳过加载")
continue
if await self._load_plugin_class(plugin_class, is_disabled=is_disabled):
loaded_plugins.append(plugin_name)
return loaded_plugins
async def unload_plugin(self, plugin_name: str) -> bool:

744
utils/redis_cache.py Normal file
View File

@@ -0,0 +1,744 @@
"""
Redis 缓存工具类
用于缓存用户信息等数据,减少 API 调用
"""
import json
from typing import Optional, Dict, Any
from loguru import logger
try:
import redis
REDIS_AVAILABLE = True
except ImportError:
REDIS_AVAILABLE = False
logger.warning("redis 库未安装,缓存功能将不可用")
class RedisCache:
"""Redis 缓存管理器"""
_instance = None
def __new__(cls, *args, **kwargs):
"""单例模式"""
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self, config: Dict = None):
"""
初始化 Redis 连接
Args:
config: Redis 配置字典,包含 host, port, password, db 等
"""
if self._initialized:
return
self.client = None
self.enabled = False
self.default_ttl = 3600 # 默认过期时间 1 小时
if not REDIS_AVAILABLE:
logger.warning("Redis 库未安装,缓存功能禁用")
self._initialized = True
return
if config:
self.connect(config)
self._initialized = True
def connect(self, config: Dict) -> bool:
"""
连接 Redis
Args:
config: Redis 配置
Returns:
是否连接成功
"""
if not REDIS_AVAILABLE:
return False
try:
self.client = redis.Redis(
host=config.get("host", "localhost"),
port=config.get("port", 6379),
password=config.get("password", None),
db=config.get("db", 0),
decode_responses=True,
socket_timeout=5,
socket_connect_timeout=5
)
# 测试连接
self.client.ping()
self.enabled = True
self.default_ttl = config.get("ttl", 3600)
logger.success(f"Redis 连接成功: {config.get('host')}:{config.get('port')}")
return True
except Exception as e:
logger.error(f"Redis 连接失败: {e}")
self.client = None
self.enabled = False
return False
def _make_key(self, prefix: str, *args) -> str:
"""
生成缓存 key
Args:
prefix: key 前缀
*args: key 组成部分
Returns:
完整的 key
"""
parts = [prefix] + [str(arg) for arg in args]
return ":".join(parts)
def get(self, key: str) -> Optional[Any]:
"""
获取缓存值
Args:
key: 缓存 key
Returns:
缓存的值,不存在返回 None
"""
if not self.enabled or not self.client:
return None
try:
value = self.client.get(key)
if value:
return json.loads(value)
return None
except Exception as e:
logger.error(f"Redis GET 失败: {key}, {e}")
return None
def set(self, key: str, value: Any, ttl: int = None) -> bool:
"""
设置缓存值
Args:
key: 缓存 key
value: 要缓存的值
ttl: 过期时间(秒),默认使用 default_ttl
Returns:
是否设置成功
"""
if not self.enabled or not self.client:
return False
try:
ttl = ttl or self.default_ttl
self.client.setex(key, ttl, json.dumps(value, ensure_ascii=False))
return True
except Exception as e:
logger.error(f"Redis SET 失败: {key}, {e}")
return False
def delete(self, key: str) -> bool:
"""
删除缓存
Args:
key: 缓存 key
Returns:
是否删除成功
"""
if not self.enabled or not self.client:
return False
try:
self.client.delete(key)
return True
except Exception as e:
logger.error(f"Redis DELETE 失败: {key}, {e}")
return False
def delete_pattern(self, pattern: str) -> int:
"""
删除匹配模式的所有 key
Args:
pattern: key 模式,如 "user_info:*"
Returns:
删除的 key 数量
"""
if not self.enabled or not self.client:
return 0
try:
keys = self.client.keys(pattern)
if keys:
return self.client.delete(*keys)
return 0
except Exception as e:
logger.error(f"Redis DELETE PATTERN 失败: {pattern}, {e}")
return 0
# ==================== 用户信息缓存专用方法 ====================
def get_user_info(self, chatroom_id: str, user_wxid: str) -> Optional[Dict]:
"""
获取缓存的用户信息
Args:
chatroom_id: 群聊 ID
user_wxid: 用户 wxid
Returns:
用户信息字典,不存在返回 None
"""
key = self._make_key("user_info", chatroom_id, user_wxid)
return self.get(key)
def set_user_info(self, chatroom_id: str, user_wxid: str, user_info: Dict, ttl: int = None) -> bool:
"""
缓存用户信息
Args:
chatroom_id: 群聊 ID
user_wxid: 用户 wxid
user_info: 用户信息字典
ttl: 过期时间(秒)
Returns:
是否缓存成功
"""
key = self._make_key("user_info", chatroom_id, user_wxid)
return self.set(key, user_info, ttl)
def get_user_basic_info(self, chatroom_id: str, user_wxid: str) -> Optional[Dict]:
"""
获取缓存的用户基本信息(昵称和头像)
Args:
chatroom_id: 群聊 ID
user_wxid: 用户 wxid
Returns:
包含 nickname 和 avatar_url 的字典
"""
user_info = self.get_user_info(chatroom_id, user_wxid)
if user_info:
# 提取基本信息
nickname = ""
if isinstance(user_info.get("nickName"), dict):
nickname = user_info.get("nickName", {}).get("string", "")
else:
nickname = user_info.get("nickName", "")
avatar_url = user_info.get("bigHeadImgUrl", "")
if nickname or avatar_url:
return {
"nickname": nickname,
"avatar_url": avatar_url
}
return None
def clear_user_cache(self, chatroom_id: str = None, user_wxid: str = None) -> int:
"""
清除用户信息缓存
Args:
chatroom_id: 群聊 ID为空则清除所有群
user_wxid: 用户 wxid为空则清除该群所有用户
Returns:
清除的缓存数量
"""
if chatroom_id and user_wxid:
key = self._make_key("user_info", chatroom_id, user_wxid)
return 1 if self.delete(key) else 0
elif chatroom_id:
pattern = self._make_key("user_info", chatroom_id, "*")
return self.delete_pattern(pattern)
else:
return self.delete_pattern("user_info:*")
def get_cache_stats(self) -> Dict:
"""
获取缓存统计信息
Returns:
统计信息字典
"""
if not self.enabled or not self.client:
return {"enabled": False}
try:
info = self.client.info("memory")
user_keys = len(self.client.keys("user_info:*"))
chat_keys = len(self.client.keys("chat_history:*"))
return {
"enabled": True,
"used_memory": info.get("used_memory_human", "unknown"),
"user_info_count": user_keys,
"chat_history_count": chat_keys
}
except Exception as e:
logger.error(f"获取缓存统计失败: {e}")
return {"enabled": True, "error": str(e)}
# ==================== 对话历史缓存专用方法 ====================
def get_chat_history(self, chat_id: str, max_messages: int = 100) -> list:
"""
获取对话历史
Args:
chat_id: 会话ID私聊为用户wxid群聊为 群ID:用户ID
max_messages: 最大返回消息数
Returns:
消息列表
"""
if not self.enabled or not self.client:
return []
try:
key = self._make_key("chat_history", chat_id)
# 使用 LRANGE 获取最近的消息(列表尾部是最新的)
data = self.client.lrange(key, -max_messages, -1)
return [json.loads(item) for item in data]
except Exception as e:
logger.error(f"获取对话历史失败: {chat_id}, {e}")
return []
def add_chat_message(self, chat_id: str, role: str, content, ttl: int = 86400) -> bool:
"""
添加消息到对话历史
Args:
chat_id: 会话ID
role: 角色 (user/assistant)
content: 消息内容(字符串或列表)
ttl: 过期时间默认24小时
Returns:
是否添加成功
"""
if not self.enabled or not self.client:
return False
try:
key = self._make_key("chat_history", chat_id)
message = {"role": role, "content": content}
self.client.rpush(key, json.dumps(message, ensure_ascii=False))
self.client.expire(key, ttl)
return True
except Exception as e:
logger.error(f"添加对话消息失败: {chat_id}, {e}")
return False
def trim_chat_history(self, chat_id: str, max_messages: int = 100) -> bool:
"""
裁剪对话历史保留最近的N条消息
Args:
chat_id: 会话ID
max_messages: 保留的最大消息数
Returns:
是否成功
"""
if not self.enabled or not self.client:
return False
try:
key = self._make_key("chat_history", chat_id)
# 保留最后 max_messages 条
self.client.ltrim(key, -max_messages, -1)
return True
except Exception as e:
logger.error(f"裁剪对话历史失败: {chat_id}, {e}")
return False
def clear_chat_history(self, chat_id: str) -> bool:
"""
清空指定会话的对话历史
Args:
chat_id: 会话ID
Returns:
是否成功
"""
if not self.enabled or not self.client:
return False
try:
key = self._make_key("chat_history", chat_id)
self.client.delete(key)
return True
except Exception as e:
logger.error(f"清空对话历史失败: {chat_id}, {e}")
return False
# ==================== 群聊历史记录专用方法 ====================
def get_group_history(self, group_id: str, max_messages: int = 100) -> list:
"""
获取群聊历史记录
Args:
group_id: 群聊ID
max_messages: 最大返回消息数
Returns:
消息列表,每条包含 nickname, content, timestamp
"""
if not self.enabled or not self.client:
return []
try:
key = self._make_key("group_history", group_id)
data = self.client.lrange(key, -max_messages, -1)
return [json.loads(item) for item in data]
except Exception as e:
logger.error(f"获取群聊历史失败: {group_id}, {e}")
return []
def add_group_message(self, group_id: str, nickname: str, content,
record_id: str = None, ttl: int = 86400) -> bool:
"""
添加消息到群聊历史
Args:
group_id: 群聊ID
nickname: 发送者昵称
content: 消息内容
record_id: 可选的记录ID用于后续更新
ttl: 过期时间默认24小时
Returns:
是否添加成功
"""
if not self.enabled or not self.client:
return False
try:
import time
key = self._make_key("group_history", group_id)
message = {
"nickname": nickname,
"content": content,
"timestamp": time.time()
}
if record_id:
message["id"] = record_id
self.client.rpush(key, json.dumps(message, ensure_ascii=False))
self.client.expire(key, ttl)
return True
except Exception as e:
logger.error(f"添加群聊消息失败: {group_id}, {e}")
return False
def update_group_message_by_id(self, group_id: str, record_id: str, new_content) -> bool:
"""
根据ID更新群聊历史中的消息
Args:
group_id: 群聊ID
record_id: 记录ID
new_content: 新内容
Returns:
是否更新成功
"""
if not self.enabled or not self.client:
return False
try:
key = self._make_key("group_history", group_id)
# 获取所有消息
data = self.client.lrange(key, 0, -1)
for i, item in enumerate(data):
msg = json.loads(item)
if msg.get("id") == record_id:
msg["content"] = new_content
self.client.lset(key, i, json.dumps(msg, ensure_ascii=False))
return True
return False
except Exception as e:
logger.error(f"更新群聊消息失败: {group_id}, {record_id}, {e}")
return False
def trim_group_history(self, group_id: str, max_messages: int = 100) -> bool:
"""
裁剪群聊历史保留最近的N条消息
Args:
group_id: 群聊ID
max_messages: 保留的最大消息数
Returns:
是否成功
"""
if not self.enabled or not self.client:
return False
try:
key = self._make_key("group_history", group_id)
self.client.ltrim(key, -max_messages, -1)
return True
except Exception as e:
logger.error(f"裁剪群聊历史失败: {group_id}, {e}")
return False
# ==================== 限流专用方法 ====================
def check_rate_limit(self, identifier: str, limit: int = 10,
window: int = 60, limit_type: str = "message") -> tuple:
"""
检查是否超过限流
使用滑动窗口算法
Args:
identifier: 标识符如用户wxid、群ID等
limit: 时间窗口内最大请求数
window: 时间窗口(秒)
limit_type: 限流类型message/ai_chat/image_gen等
Returns:
(是否允许, 剩余次数, 重置时间秒数)
"""
if not self.enabled or not self.client:
return (True, limit, 0) # Redis 不可用时不限流
try:
import time
key = self._make_key("rate_limit", limit_type, identifier)
now = time.time()
window_start = now - window
# 使用 pipeline 提高性能
pipe = self.client.pipeline()
# 移除过期的记录
pipe.zremrangebyscore(key, 0, window_start)
# 获取当前窗口内的请求数
pipe.zcard(key)
# 添加当前请求
pipe.zadd(key, {str(now): now})
# 设置过期时间
pipe.expire(key, window)
results = pipe.execute()
current_count = results[1] # zcard 的结果
if current_count >= limit:
# 获取最早的记录时间,计算重置时间
oldest = self.client.zrange(key, 0, 0, withscores=True)
if oldest:
reset_time = int(oldest[0][1] + window - now)
else:
reset_time = window
return (False, 0, max(reset_time, 1))
remaining = limit - current_count - 1
return (True, remaining, 0)
except Exception as e:
logger.error(f"限流检查失败: {identifier}, {e}")
return (True, limit, 0) # 出错时不限流
def get_rate_limit_status(self, identifier: str, limit: int = 10,
window: int = 60, limit_type: str = "message") -> Dict:
"""
获取限流状态(不增加计数)
Args:
identifier: 标识符
limit: 时间窗口内最大请求数
window: 时间窗口(秒)
limit_type: 限流类型
Returns:
状态字典
"""
if not self.enabled or not self.client:
return {"enabled": False, "current": 0, "limit": limit, "remaining": limit}
try:
import time
key = self._make_key("rate_limit", limit_type, identifier)
now = time.time()
window_start = now - window
# 移除过期记录并获取当前数量
self.client.zremrangebyscore(key, 0, window_start)
current = self.client.zcard(key)
return {
"enabled": True,
"current": current,
"limit": limit,
"remaining": max(0, limit - current),
"window": window
}
except Exception as e:
logger.error(f"获取限流状态失败: {identifier}, {e}")
return {"enabled": False, "error": str(e)}
def reset_rate_limit(self, identifier: str, limit_type: str = "message") -> bool:
"""
重置限流计数
Args:
identifier: 标识符
limit_type: 限流类型
Returns:
是否成功
"""
if not self.enabled or not self.client:
return False
try:
key = self._make_key("rate_limit", limit_type, identifier)
self.client.delete(key)
return True
except Exception as e:
logger.error(f"重置限流失败: {identifier}, {e}")
return False
# ==================== 媒体缓存专用方法 ====================
def cache_media(self, media_key: str, base64_data: str, media_type: str = "image", ttl: int = 300) -> bool:
"""
缓存媒体文件的 base64 数据
Args:
media_key: 媒体唯一标识(如 cdnurl 的 hash 或 aeskey
base64_data: base64 编码的媒体数据
media_type: 媒体类型image/emoji/video
ttl: 过期时间默认5分钟
Returns:
是否缓存成功
"""
if not self.enabled or not self.client:
return False
try:
key = self._make_key("media_cache", media_type, media_key)
# 直接存储 base64 字符串,不再 json 序列化
self.client.setex(key, ttl, base64_data)
logger.debug(f"媒体已缓存: {media_type}/{media_key[:20]}..., TTL={ttl}s")
return True
except Exception as e:
logger.error(f"缓存媒体失败: {media_key}, {e}")
return False
def get_cached_media(self, media_key: str, media_type: str = "image") -> Optional[str]:
"""
获取缓存的媒体 base64 数据
Args:
media_key: 媒体唯一标识
media_type: 媒体类型
Returns:
base64 数据,不存在返回 None
"""
if not self.enabled or not self.client:
return None
try:
key = self._make_key("media_cache", media_type, media_key)
data = self.client.get(key)
if data:
logger.debug(f"媒体缓存命中: {media_type}/{media_key[:20]}...")
return data
return None
except Exception as e:
logger.error(f"获取媒体缓存失败: {media_key}, {e}")
return None
def delete_cached_media(self, media_key: str, media_type: str = "image") -> bool:
"""
删除缓存的媒体
Args:
media_key: 媒体唯一标识
media_type: 媒体类型
Returns:
是否删除成功
"""
if not self.enabled or not self.client:
return False
try:
key = self._make_key("media_cache", media_type, media_key)
self.client.delete(key)
return True
except Exception as e:
logger.error(f"删除媒体缓存失败: {media_key}, {e}")
return False
@staticmethod
def generate_media_key(cdnurl: str = "", aeskey: str = "") -> str:
"""
根据 CDN URL 或 AES Key 生成媒体缓存 key
Args:
cdnurl: CDN URL
aeskey: AES Key
Returns:
缓存 key
"""
import hashlib
# 优先使用 aeskey更短更稳定否则使用 cdnurl 的 hash
if aeskey:
return aeskey[:32] # 取前32位作为 key
elif cdnurl:
return hashlib.md5(cdnurl.encode()).hexdigest()
return ""
def get_cache() -> Optional[RedisCache]:
"""
获取全局缓存实例
返回 RedisCache 单例实例。如果还没有初始化,返回一个未连接的实例。
建议在 MessageLogger 初始化后再调用此函数。
"""
return RedisCache._instance
def init_cache(config: Dict) -> RedisCache:
"""
初始化全局缓存实例
Args:
config: Redis 配置
Returns:
缓存实例
"""
global _cache_instance
_cache_instance = RedisCache(config)
return _cache_instance