feat: 优化整体项目
This commit is contained in:
@@ -126,6 +126,59 @@ def add_callback_handler(callback_handler):
|
|||||||
logger.debug(f"注册断开回调: {name}")
|
logger.debug(f"注册断开回调: {name}")
|
||||||
|
|
||||||
|
|
||||||
|
def remove_callback_handler(callback_handler):
|
||||||
|
"""
|
||||||
|
移除回调处理器实例(修复内存泄漏)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
callback_handler: 包含回调方法的对象
|
||||||
|
"""
|
||||||
|
global _GLOBAL_CONNECT_CALLBACK_LIST, _GLOBAL_RECV_CALLBACK_LIST, _GLOBAL_CLOSE_CALLBACK_LIST
|
||||||
|
|
||||||
|
# 移除属于该处理器的所有回调
|
||||||
|
_GLOBAL_CONNECT_CALLBACK_LIST = [
|
||||||
|
f for f in _GLOBAL_CONNECT_CALLBACK_LIST
|
||||||
|
if not (hasattr(f, '__self__') and f.__self__ is callback_handler)
|
||||||
|
]
|
||||||
|
_GLOBAL_RECV_CALLBACK_LIST = [
|
||||||
|
f for f in _GLOBAL_RECV_CALLBACK_LIST
|
||||||
|
if not (hasattr(f, '__self__') and f.__self__ is callback_handler)
|
||||||
|
]
|
||||||
|
_GLOBAL_CLOSE_CALLBACK_LIST = [
|
||||||
|
f for f in _GLOBAL_CLOSE_CALLBACK_LIST
|
||||||
|
if not (hasattr(f, '__self__') and f.__self__ is callback_handler)
|
||||||
|
]
|
||||||
|
|
||||||
|
logger.debug(f"已移除回调处理器: {type(callback_handler).__name__}")
|
||||||
|
|
||||||
|
|
||||||
|
def clear_all_callbacks():
|
||||||
|
"""
|
||||||
|
清空所有回调(用于关闭时清理)
|
||||||
|
"""
|
||||||
|
global _GLOBAL_CONNECT_CALLBACK_LIST, _GLOBAL_RECV_CALLBACK_LIST, _GLOBAL_CLOSE_CALLBACK_LIST
|
||||||
|
|
||||||
|
_GLOBAL_CONNECT_CALLBACK_LIST.clear()
|
||||||
|
_GLOBAL_RECV_CALLBACK_LIST.clear()
|
||||||
|
_GLOBAL_CLOSE_CALLBACK_LIST.clear()
|
||||||
|
|
||||||
|
logger.debug("已清空所有回调")
|
||||||
|
|
||||||
|
|
||||||
|
def get_callback_count() -> dict:
|
||||||
|
"""
|
||||||
|
获取回调数量统计
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
回调数量字典
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"connect": len(_GLOBAL_CONNECT_CALLBACK_LIST),
|
||||||
|
"recv": len(_GLOBAL_RECV_CALLBACK_LIST),
|
||||||
|
"close": len(_GLOBAL_CLOSE_CALLBACK_LIST)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@WINFUNCTYPE(None, ctypes.c_void_p)
|
@WINFUNCTYPE(None, ctypes.c_void_p)
|
||||||
def wechat_connect_callback(client_id):
|
def wechat_connect_callback(client_id):
|
||||||
"""
|
"""
|
||||||
|
|||||||
336
bot.py
336
bot.py
@@ -2,10 +2,19 @@
|
|||||||
WechatHookBot - 主入口
|
WechatHookBot - 主入口
|
||||||
|
|
||||||
基于个微大客户版 Hook API 的微信机器人框架
|
基于个微大客户版 Hook API 的微信机器人框架
|
||||||
|
|
||||||
|
优化功能:
|
||||||
|
- 优先级消息队列
|
||||||
|
- 自适应熔断器
|
||||||
|
- 配置热更新
|
||||||
|
- 性能监控
|
||||||
|
- 优雅关闭
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import signal
|
||||||
import sys
|
import sys
|
||||||
|
import time
|
||||||
import tomllib
|
import tomllib
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
@@ -13,6 +22,8 @@ from loguru import logger
|
|||||||
from WechatHook import NoveLoader, WechatHookClient
|
from WechatHook import NoveLoader, WechatHookClient
|
||||||
from WechatHook.callbacks import (
|
from WechatHook.callbacks import (
|
||||||
add_callback_handler,
|
add_callback_handler,
|
||||||
|
remove_callback_handler,
|
||||||
|
clear_all_callbacks,
|
||||||
wechat_connect_callback,
|
wechat_connect_callback,
|
||||||
wechat_recv_callback,
|
wechat_recv_callback,
|
||||||
wechat_close_callback,
|
wechat_close_callback,
|
||||||
@@ -23,7 +34,15 @@ from WechatHook.callbacks import (
|
|||||||
from utils.hookbot import HookBot
|
from utils.hookbot import HookBot
|
||||||
from utils.plugin_manager import PluginManager
|
from utils.plugin_manager import PluginManager
|
||||||
from utils.decorators import scheduler
|
from utils.decorators import scheduler
|
||||||
# from database import KeyvalDB, MessageDB # 不需要数据库
|
from utils.bot_utils import (
|
||||||
|
PriorityMessageQueue,
|
||||||
|
MessagePriority,
|
||||||
|
PRIORITY_MESSAGE_TYPES,
|
||||||
|
AdaptiveCircuitBreaker,
|
||||||
|
ConfigWatcher,
|
||||||
|
PerformanceMonitor,
|
||||||
|
get_performance_monitor
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class BotService:
|
class BotService:
|
||||||
@@ -37,17 +56,24 @@ class BotService:
|
|||||||
self.process_id = None # 微信进程 ID
|
self.process_id = None # 微信进程 ID
|
||||||
self.socket_client_id = None # Socket 客户端 ID
|
self.socket_client_id = None # Socket 客户端 ID
|
||||||
self.is_running = False
|
self.is_running = False
|
||||||
|
self.is_shutting_down = False # 是否正在关闭
|
||||||
self.event_loop = None # 事件循环引用
|
self.event_loop = None # 事件循环引用
|
||||||
|
|
||||||
# 消息队列和性能控制
|
# 消息队列和性能控制
|
||||||
self.message_queue = None
|
self.message_queue: PriorityMessageQueue = None # 优先级消息队列
|
||||||
self.queue_config = {}
|
self.queue_config = {}
|
||||||
self.concurrency_config = {}
|
self.concurrency_config = {}
|
||||||
self.consumer_tasks = []
|
self.consumer_tasks = []
|
||||||
self.processing_semaphore = None
|
self.processing_semaphore = None
|
||||||
self.circuit_breaker_failures = 0
|
|
||||||
self.circuit_breaker_open = False
|
# 自适应熔断器
|
||||||
self.circuit_breaker_last_failure = 0
|
self.circuit_breaker: AdaptiveCircuitBreaker = None
|
||||||
|
|
||||||
|
# 配置热更新
|
||||||
|
self.config_watcher: ConfigWatcher = None
|
||||||
|
|
||||||
|
# 性能监控
|
||||||
|
self.performance_monitor: PerformanceMonitor = None
|
||||||
|
|
||||||
@CONNECT_CALLBACK(in_class=True)
|
@CONNECT_CALLBACK(in_class=True)
|
||||||
def on_connect(self, client_id):
|
def on_connect(self, client_id):
|
||||||
@@ -85,40 +111,55 @@ class BotService:
|
|||||||
logger.error(f"消息入队失败: {e}")
|
logger.error(f"消息入队失败: {e}")
|
||||||
|
|
||||||
async def _enqueue_message(self, msg_type, data):
|
async def _enqueue_message(self, msg_type, data):
|
||||||
"""将消息加入队列"""
|
"""将消息加入优先级队列"""
|
||||||
try:
|
try:
|
||||||
|
# 记录收到消息
|
||||||
|
if self.performance_monitor:
|
||||||
|
self.performance_monitor.record_message_received()
|
||||||
|
|
||||||
# 检查队列是否已满
|
# 检查队列是否已满
|
||||||
if self.message_queue.qsize() >= self.queue_config.get("max_size", 1000):
|
if self.message_queue.full():
|
||||||
overflow_strategy = self.queue_config.get("overflow_strategy", "drop_oldest")
|
overflow_strategy = self.queue_config.get("overflow_strategy", "drop_oldest")
|
||||||
|
|
||||||
if overflow_strategy == "drop_oldest":
|
if overflow_strategy == "drop_oldest":
|
||||||
# 丢弃最旧的消息
|
# 丢弃优先级最低的消息
|
||||||
try:
|
if self.message_queue.drop_lowest_priority():
|
||||||
self.message_queue.get_nowait()
|
logger.warning("队列已满,丢弃优先级最低的消息")
|
||||||
logger.warning("队列已满,丢弃最旧消息")
|
if self.performance_monitor:
|
||||||
except asyncio.QueueEmpty:
|
self.performance_monitor.record_message_dropped()
|
||||||
pass
|
|
||||||
elif overflow_strategy == "sampling":
|
elif overflow_strategy == "sampling":
|
||||||
# 采样处理,随机丢弃
|
# 采样处理,随机丢弃(但高优先级消息不丢弃)
|
||||||
import random
|
import random
|
||||||
if random.random() < 0.5: # 50% 概率丢弃
|
priority = PRIORITY_MESSAGE_TYPES.get(msg_type, MessagePriority.NORMAL)
|
||||||
|
if priority < MessagePriority.HIGH and random.random() < 0.5:
|
||||||
logger.debug("队列压力大,采样丢弃消息")
|
logger.debug("队列压力大,采样丢弃消息")
|
||||||
|
if self.performance_monitor:
|
||||||
|
self.performance_monitor.record_message_dropped()
|
||||||
return
|
return
|
||||||
else: # degrade
|
else: # degrade
|
||||||
logger.warning("队列已满,降级处理")
|
# 降级处理(但高优先级消息不丢弃)
|
||||||
return
|
priority = PRIORITY_MESSAGE_TYPES.get(msg_type, MessagePriority.NORMAL)
|
||||||
|
if priority < MessagePriority.HIGH:
|
||||||
|
logger.warning("队列已满,降级处理")
|
||||||
|
if self.performance_monitor:
|
||||||
|
self.performance_monitor.record_message_dropped()
|
||||||
|
return
|
||||||
|
|
||||||
# 将消息放入队列
|
# 将消息放入优先级队列
|
||||||
await self.message_queue.put((msg_type, data))
|
await self.message_queue.put(msg_type, data)
|
||||||
|
|
||||||
|
# 记录队列大小
|
||||||
|
if self.performance_monitor:
|
||||||
|
self.performance_monitor.record_queue_size(self.message_queue.qsize())
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"消息入队异常: {e}")
|
logger.error(f"消息入队异常: {e}")
|
||||||
|
|
||||||
async def _message_consumer(self, consumer_id: int):
|
async def _message_consumer(self, consumer_id: int):
|
||||||
"""消息消费者协程"""
|
"""消息消费者协程 - 纯队列串行模式,避免并发触发风控"""
|
||||||
logger.info(f"消息消费者 {consumer_id} 已启动")
|
logger.info(f"消息消费者 {consumer_id} 已启动(串行模式)")
|
||||||
|
|
||||||
while self.is_running:
|
while self.is_running and not self.is_shutting_down:
|
||||||
try:
|
try:
|
||||||
# 从队列获取消息,设置超时避免无限等待
|
# 从队列获取消息,设置超时避免无限等待
|
||||||
msg_type, data = await asyncio.wait_for(
|
msg_type, data = await asyncio.wait_for(
|
||||||
@@ -127,76 +168,68 @@ class BotService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 检查熔断器状态
|
# 检查熔断器状态
|
||||||
if self._check_circuit_breaker():
|
if self.circuit_breaker and self.circuit_breaker.is_open():
|
||||||
logger.debug("熔断器开启,跳过消息处理")
|
logger.debug("熔断器开启,跳过消息处理")
|
||||||
|
self.circuit_breaker.record_rejection()
|
||||||
|
self.message_queue.task_done()
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 创建并发任务,不等待完成
|
# 串行处理:等待当前消息处理完成后再处理下一条
|
||||||
timeout = self.concurrency_config.get("plugin_task_timeout_seconds", 5)
|
timeout = self.concurrency_config.get("plugin_task_timeout_seconds", 720)
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
# 使用信号量控制并发数量
|
try:
|
||||||
async def process_with_semaphore():
|
await asyncio.wait_for(
|
||||||
async with self.processing_semaphore:
|
self.hookbot.process_message(msg_type, data),
|
||||||
try:
|
timeout=timeout
|
||||||
await asyncio.wait_for(
|
)
|
||||||
self.hookbot.process_message(msg_type, data),
|
# 记录成功
|
||||||
timeout=timeout
|
processing_time = time.time() - start_time
|
||||||
)
|
if self.circuit_breaker:
|
||||||
self._reset_circuit_breaker()
|
self.circuit_breaker.record_success()
|
||||||
except asyncio.TimeoutError:
|
if self.performance_monitor:
|
||||||
logger.warning(f"消息处理超时 (>{timeout}s): type={msg_type}")
|
self.performance_monitor.record_message_processed(processing_time)
|
||||||
self._record_circuit_breaker_failure()
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"消息处理异常: {e}")
|
|
||||||
self._record_circuit_breaker_failure()
|
|
||||||
|
|
||||||
# 创建任务但不等待,实现真正并发
|
except asyncio.TimeoutError:
|
||||||
asyncio.create_task(process_with_semaphore())
|
logger.warning(f"消息处理超时 (>{timeout}s): type={msg_type}")
|
||||||
|
if self.circuit_breaker:
|
||||||
|
self.circuit_breaker.record_failure()
|
||||||
|
if self.performance_monitor:
|
||||||
|
self.performance_monitor.record_message_failed()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"消息处理异常: {e}")
|
||||||
|
if self.circuit_breaker:
|
||||||
|
self.circuit_breaker.record_failure()
|
||||||
|
if self.performance_monitor:
|
||||||
|
self.performance_monitor.record_message_failed()
|
||||||
|
|
||||||
# 标记任务完成
|
# 标记任务完成
|
||||||
self.message_queue.task_done()
|
self.message_queue.task_done()
|
||||||
|
|
||||||
|
# 更新熔断器统计
|
||||||
|
if self.performance_monitor and self.circuit_breaker:
|
||||||
|
self.performance_monitor.update_circuit_breaker_stats(
|
||||||
|
self.circuit_breaker.get_stats()
|
||||||
|
)
|
||||||
|
|
||||||
|
# 消息间隔,避免发送太快触发风控
|
||||||
|
message_interval = self.concurrency_config.get("message_interval_ms", 100)
|
||||||
|
if message_interval > 0:
|
||||||
|
await asyncio.sleep(message_interval / 1000.0)
|
||||||
|
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
# 队列为空,继续等待
|
# 队列为空,继续等待
|
||||||
continue
|
continue
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
# 任务被取消,退出循环
|
||||||
|
logger.info(f"消费者 {consumer_id} 收到取消信号")
|
||||||
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"消费者 {consumer_id} 异常: {e}")
|
logger.error(f"消费者 {consumer_id} 异常: {e}")
|
||||||
await asyncio.sleep(0.1) # 短暂休息避免忙等
|
await asyncio.sleep(0.1) # 短暂休息避免忙等
|
||||||
|
|
||||||
def _check_circuit_breaker(self) -> bool:
|
logger.info(f"消费者 {consumer_id} 已退出")
|
||||||
"""检查熔断器状态"""
|
|
||||||
if not self.concurrency_config.get("enable_circuit_breaker", True):
|
|
||||||
return False
|
|
||||||
|
|
||||||
if self.circuit_breaker_open:
|
|
||||||
# 检查是否可以尝试恢复
|
|
||||||
import time
|
|
||||||
if time.time() - self.circuit_breaker_last_failure > 30: # 30秒后尝试恢复
|
|
||||||
self.circuit_breaker_open = False
|
|
||||||
self.circuit_breaker_failures = 0
|
|
||||||
logger.info("熔断器尝试恢复")
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _record_circuit_breaker_failure(self):
|
|
||||||
"""记录熔断器失败"""
|
|
||||||
if not self.concurrency_config.get("enable_circuit_breaker", True):
|
|
||||||
return
|
|
||||||
|
|
||||||
self.circuit_breaker_failures += 1
|
|
||||||
threshold = self.concurrency_config.get("circuit_breaker_threshold", 5)
|
|
||||||
|
|
||||||
if self.circuit_breaker_failures >= threshold:
|
|
||||||
import time
|
|
||||||
self.circuit_breaker_open = True
|
|
||||||
self.circuit_breaker_last_failure = time.time()
|
|
||||||
logger.warning(f"熔断器开启,连续失败 {self.circuit_breaker_failures} 次")
|
|
||||||
|
|
||||||
def _reset_circuit_breaker(self):
|
|
||||||
"""重置熔断器"""
|
|
||||||
if self.circuit_breaker_failures > 0:
|
|
||||||
self.circuit_breaker_failures = 0
|
|
||||||
|
|
||||||
@CLOSE_CALLBACK(in_class=True)
|
@CLOSE_CALLBACK(in_class=True)
|
||||||
def on_close(self, client_id):
|
def on_close(self, client_id):
|
||||||
@@ -236,16 +269,36 @@ class BotService:
|
|||||||
self.queue_config = config.get("Queue", {})
|
self.queue_config = config.get("Queue", {})
|
||||||
self.concurrency_config = config.get("Concurrency", {})
|
self.concurrency_config = config.get("Concurrency", {})
|
||||||
|
|
||||||
# 创建消息队列
|
# 创建优先级消息队列
|
||||||
queue_size = self.queue_config.get("max_size", 1000)
|
queue_size = self.queue_config.get("max_size", 1000)
|
||||||
self.message_queue = asyncio.Queue(maxsize=queue_size)
|
self.message_queue = PriorityMessageQueue(maxsize=queue_size)
|
||||||
logger.info(f"消息队列已创建,容量: {queue_size}")
|
logger.info(f"优先级消息队列已创建,容量: {queue_size}")
|
||||||
|
|
||||||
# 创建并发控制信号量
|
# 创建并发控制信号量
|
||||||
max_concurrency = self.concurrency_config.get("plugin_max_concurrency", 8)
|
max_concurrency = self.concurrency_config.get("plugin_max_concurrency", 8)
|
||||||
self.processing_semaphore = asyncio.Semaphore(max_concurrency)
|
self.processing_semaphore = asyncio.Semaphore(max_concurrency)
|
||||||
logger.info(f"并发控制已设置,最大并发: {max_concurrency}")
|
logger.info(f"并发控制已设置,最大并发: {max_concurrency}")
|
||||||
|
|
||||||
|
# 创建自适应熔断器
|
||||||
|
if self.concurrency_config.get("enable_circuit_breaker", True):
|
||||||
|
self.circuit_breaker = AdaptiveCircuitBreaker(
|
||||||
|
failure_threshold=self.concurrency_config.get("circuit_breaker_threshold", 10),
|
||||||
|
success_threshold=3,
|
||||||
|
initial_recovery_time=5.0,
|
||||||
|
max_recovery_time=300.0
|
||||||
|
)
|
||||||
|
logger.info("自适应熔断器已创建")
|
||||||
|
|
||||||
|
# 创建性能监控器
|
||||||
|
self.performance_monitor = get_performance_monitor()
|
||||||
|
logger.info("性能监控器已创建")
|
||||||
|
|
||||||
|
# 创建配置热更新监听器
|
||||||
|
self.config_watcher = ConfigWatcher("main_config.toml", check_interval=5.0)
|
||||||
|
self.config_watcher.register_callback(self._on_config_update)
|
||||||
|
await self.config_watcher.start()
|
||||||
|
logger.info("配置热更新监听器已启动")
|
||||||
|
|
||||||
# 不需要数据库(简化版本)
|
# 不需要数据库(简化版本)
|
||||||
|
|
||||||
# 获取 DLL 路径
|
# 获取 DLL 路径
|
||||||
@@ -340,6 +393,26 @@ class BotService:
|
|||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def _on_config_update(self, new_config: dict):
|
||||||
|
"""配置热更新回调"""
|
||||||
|
logger.info("正在应用新配置...")
|
||||||
|
|
||||||
|
# 更新队列配置
|
||||||
|
self.queue_config = new_config.get("Queue", self.queue_config)
|
||||||
|
|
||||||
|
# 更新并发配置
|
||||||
|
old_concurrency = self.concurrency_config
|
||||||
|
self.concurrency_config = new_config.get("Concurrency", self.concurrency_config)
|
||||||
|
|
||||||
|
# 更新熔断器配置
|
||||||
|
if self.circuit_breaker:
|
||||||
|
new_threshold = self.concurrency_config.get("circuit_breaker_threshold", 10)
|
||||||
|
if new_threshold != old_concurrency.get("circuit_breaker_threshold", 10):
|
||||||
|
self.circuit_breaker.failure_threshold = new_threshold
|
||||||
|
logger.info(f"熔断器阈值已更新: {new_threshold}")
|
||||||
|
|
||||||
|
logger.success("配置热更新完成")
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
"""运行机器人"""
|
"""运行机器人"""
|
||||||
if not await self.initialize():
|
if not await self.initialize():
|
||||||
@@ -347,6 +420,15 @@ class BotService:
|
|||||||
|
|
||||||
self.is_running = True
|
self.is_running = True
|
||||||
|
|
||||||
|
# 启动定期性能报告
|
||||||
|
async def periodic_stats():
|
||||||
|
while self.is_running:
|
||||||
|
await asyncio.sleep(300) # 每5分钟输出一次
|
||||||
|
if self.performance_monitor and self.is_running:
|
||||||
|
self.performance_monitor.print_stats()
|
||||||
|
|
||||||
|
stats_task = asyncio.create_task(periodic_stats())
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logger.info("机器人正在运行,按 Ctrl+C 停止...")
|
logger.info("机器人正在运行,按 Ctrl+C 停止...")
|
||||||
while self.is_running:
|
while self.is_running:
|
||||||
@@ -354,44 +436,96 @@ class BotService:
|
|||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
logger.info("收到停止信号...")
|
logger.info("收到停止信号...")
|
||||||
finally:
|
finally:
|
||||||
|
stats_task.cancel()
|
||||||
await self.stop()
|
await self.stop()
|
||||||
|
|
||||||
async def stop(self):
|
async def stop(self):
|
||||||
"""停止机器人"""
|
"""优雅关闭机器人"""
|
||||||
logger.info("正在停止机器人...")
|
if self.is_shutting_down:
|
||||||
self.is_running = False
|
return
|
||||||
|
self.is_shutting_down = True
|
||||||
|
|
||||||
# 停止消息消费者
|
logger.info("=" * 60)
|
||||||
|
logger.info("正在优雅关闭机器人...")
|
||||||
|
logger.info("=" * 60)
|
||||||
|
|
||||||
|
# 1. 停止接收新消息
|
||||||
|
self.is_running = False
|
||||||
|
logger.info("[1/7] 停止接收新消息")
|
||||||
|
|
||||||
|
# 2. 等待队列中的消息处理完成(带超时)
|
||||||
|
if self.message_queue and not self.message_queue.empty():
|
||||||
|
queue_size = self.message_queue.qsize()
|
||||||
|
logger.info(f"[2/7] 等待队列中 {queue_size} 条消息处理完成...")
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(
|
||||||
|
self.message_queue.join(),
|
||||||
|
timeout=30
|
||||||
|
)
|
||||||
|
logger.info("[2/7] 队列消息已全部处理完成")
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.warning("[2/7] 队列消息未在 30 秒内处理完成,强制清空")
|
||||||
|
# 清空剩余消息
|
||||||
|
while not self.message_queue.empty():
|
||||||
|
try:
|
||||||
|
self.message_queue.get_nowait()
|
||||||
|
self.message_queue.task_done()
|
||||||
|
except:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
logger.info("[2/7] 队列为空,无需等待")
|
||||||
|
|
||||||
|
# 3. 停止消息消费者
|
||||||
if self.consumer_tasks:
|
if self.consumer_tasks:
|
||||||
logger.info("正在停止消息消费者...")
|
logger.info(f"[3/7] 停止 {len(self.consumer_tasks)} 个消息消费者...")
|
||||||
for task in self.consumer_tasks:
|
for task in self.consumer_tasks:
|
||||||
task.cancel()
|
task.cancel()
|
||||||
|
await asyncio.gather(*self.consumer_tasks, return_exceptions=True)
|
||||||
# 等待所有消费者任务完成
|
|
||||||
if self.consumer_tasks:
|
|
||||||
await asyncio.gather(*self.consumer_tasks, return_exceptions=True)
|
|
||||||
self.consumer_tasks.clear()
|
self.consumer_tasks.clear()
|
||||||
logger.info("消息消费者已停止")
|
logger.info("[3/7] 消息消费者已停止")
|
||||||
|
else:
|
||||||
|
logger.info("[3/7] 无消费者需要停止")
|
||||||
|
|
||||||
# 清空消息队列
|
# 4. 停止配置监听器
|
||||||
if self.message_queue:
|
if self.config_watcher:
|
||||||
while not self.message_queue.empty():
|
logger.info("[4/7] 停止配置监听器...")
|
||||||
try:
|
await self.config_watcher.stop()
|
||||||
self.message_queue.get_nowait()
|
logger.info("[4/7] 配置监听器已停止")
|
||||||
self.message_queue.task_done()
|
else:
|
||||||
except asyncio.QueueEmpty:
|
logger.info("[4/7] 无配置监听器")
|
||||||
break
|
|
||||||
logger.info("消息队列已清空")
|
|
||||||
|
|
||||||
# 停止定时任务
|
# 5. 卸载插件
|
||||||
|
if self.plugin_manager:
|
||||||
|
logger.info("[5/7] 卸载插件...")
|
||||||
|
await self.plugin_manager.unload_plugins()
|
||||||
|
logger.info("[5/7] 插件已卸载")
|
||||||
|
else:
|
||||||
|
logger.info("[5/7] 无插件需要卸载")
|
||||||
|
|
||||||
|
# 6. 停止定时任务
|
||||||
if scheduler.running:
|
if scheduler.running:
|
||||||
|
logger.info("[6/7] 停止定时任务...")
|
||||||
scheduler.shutdown()
|
scheduler.shutdown()
|
||||||
|
logger.info("[6/7] 定时任务已停止")
|
||||||
|
else:
|
||||||
|
logger.info("[6/7] 定时任务未运行")
|
||||||
|
|
||||||
|
# 7. 清理回调和销毁微信连接
|
||||||
|
logger.info("[7/7] 清理资源...")
|
||||||
|
remove_callback_handler(self)
|
||||||
|
clear_all_callbacks()
|
||||||
|
|
||||||
# 销毁微信连接
|
|
||||||
if self.loader:
|
if self.loader:
|
||||||
self.loader.DestroyWeChat()
|
self.loader.DestroyWeChat()
|
||||||
|
|
||||||
logger.success("机器人已停止")
|
# 输出最终性能报告
|
||||||
|
if self.performance_monitor:
|
||||||
|
logger.info("最终性能报告:")
|
||||||
|
self.performance_monitor.print_stats()
|
||||||
|
|
||||||
|
logger.success("=" * 60)
|
||||||
|
logger.success("机器人已优雅关闭")
|
||||||
|
logger.success("=" * 60)
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
AI 聊天插件
|
AI 聊天插件
|
||||||
|
|
||||||
支持自定义模型、API 和人设
|
支持自定义模型、API 和人设
|
||||||
|
支持 Redis 存储对话历史和限流
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
@@ -12,6 +13,7 @@ from datetime import datetime
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
from utils.plugin_base import PluginBase
|
from utils.plugin_base import PluginBase
|
||||||
from utils.decorators import on_text_message, on_quote_message, on_image_message, on_emoji_message
|
from utils.decorators import on_text_message, on_quote_message, on_image_message, on_emoji_message
|
||||||
|
from utils.redis_cache import get_cache
|
||||||
import xml.etree.ElementTree as ET
|
import xml.etree.ElementTree as ET
|
||||||
import base64
|
import base64
|
||||||
import uuid
|
import uuid
|
||||||
@@ -95,6 +97,92 @@ class AIChat(PluginBase):
|
|||||||
else:
|
else:
|
||||||
return sender_wxid or from_wxid # 私聊使用用户ID
|
return sender_wxid or from_wxid # 私聊使用用户ID
|
||||||
|
|
||||||
|
async def _get_user_nickname(self, bot, from_wxid: str, user_wxid: str, is_group: bool) -> str:
|
||||||
|
"""
|
||||||
|
获取用户昵称,优先使用 Redis 缓存
|
||||||
|
|
||||||
|
Args:
|
||||||
|
bot: WechatHookClient 实例
|
||||||
|
from_wxid: 消息来源(群聊ID或私聊用户ID)
|
||||||
|
user_wxid: 用户wxid
|
||||||
|
is_group: 是否群聊
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
用户昵称
|
||||||
|
"""
|
||||||
|
if not is_group:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
nickname = ""
|
||||||
|
|
||||||
|
# 1. 优先从 Redis 缓存获取
|
||||||
|
redis_cache = get_cache()
|
||||||
|
if redis_cache and redis_cache.enabled:
|
||||||
|
cached_info = redis_cache.get_user_basic_info(from_wxid, user_wxid)
|
||||||
|
if cached_info and cached_info.get("nickname"):
|
||||||
|
logger.debug(f"[缓存命中] 用户昵称: {user_wxid} -> {cached_info['nickname']}")
|
||||||
|
return cached_info["nickname"]
|
||||||
|
|
||||||
|
# 2. 缓存未命中,调用 API 获取
|
||||||
|
try:
|
||||||
|
user_info = await bot.get_user_info_in_chatroom(from_wxid, user_wxid)
|
||||||
|
if user_info and user_info.get("nickName", {}).get("string"):
|
||||||
|
nickname = user_info["nickName"]["string"]
|
||||||
|
# 存入缓存
|
||||||
|
if redis_cache and redis_cache.enabled:
|
||||||
|
redis_cache.set_user_info(from_wxid, user_wxid, user_info)
|
||||||
|
logger.debug(f"[已缓存] 用户昵称: {user_wxid} -> {nickname}")
|
||||||
|
return nickname
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"API获取用户昵称失败: {e}")
|
||||||
|
|
||||||
|
# 3. 从 MessageLogger 数据库查询
|
||||||
|
if not nickname:
|
||||||
|
try:
|
||||||
|
from plugins.MessageLogger.main import MessageLogger
|
||||||
|
msg_logger = MessageLogger.get_instance()
|
||||||
|
if msg_logger:
|
||||||
|
with msg_logger.get_db_connection() as conn:
|
||||||
|
with conn.cursor() as cursor:
|
||||||
|
cursor.execute(
|
||||||
|
"SELECT nickname FROM messages WHERE sender_wxid = %s AND nickname != '' ORDER BY create_time DESC LIMIT 1",
|
||||||
|
(user_wxid,)
|
||||||
|
)
|
||||||
|
result = cursor.fetchone()
|
||||||
|
if result:
|
||||||
|
nickname = result[0]
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"从数据库获取昵称失败: {e}")
|
||||||
|
|
||||||
|
# 4. 最后降级使用 wxid
|
||||||
|
if not nickname:
|
||||||
|
nickname = user_wxid or "未知用户"
|
||||||
|
|
||||||
|
return nickname
|
||||||
|
|
||||||
|
def _check_rate_limit(self, user_wxid: str) -> tuple:
|
||||||
|
"""
|
||||||
|
检查用户是否超过限流
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_wxid: 用户wxid
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(是否允许, 剩余次数, 重置时间秒数)
|
||||||
|
"""
|
||||||
|
rate_limit_config = self.config.get("rate_limit", {})
|
||||||
|
if not rate_limit_config.get("enabled", True):
|
||||||
|
return (True, 999, 0)
|
||||||
|
|
||||||
|
redis_cache = get_cache()
|
||||||
|
if not redis_cache or not redis_cache.enabled:
|
||||||
|
return (True, 999, 0) # Redis 不可用时不限流
|
||||||
|
|
||||||
|
limit = rate_limit_config.get("ai_chat_limit", 20)
|
||||||
|
window = rate_limit_config.get("ai_chat_window", 60)
|
||||||
|
|
||||||
|
return redis_cache.check_rate_limit(user_wxid, limit, window, "ai_chat")
|
||||||
|
|
||||||
def _add_to_memory(self, chat_id: str, role: str, content, image_base64: str = None):
|
def _add_to_memory(self, chat_id: str, role: str, content, image_base64: str = None):
|
||||||
"""
|
"""
|
||||||
添加消息到记忆
|
添加消息到记忆
|
||||||
@@ -108,9 +196,6 @@ class AIChat(PluginBase):
|
|||||||
if not self.config.get("memory", {}).get("enabled", False):
|
if not self.config.get("memory", {}).get("enabled", False):
|
||||||
return
|
return
|
||||||
|
|
||||||
if chat_id not in self.memory:
|
|
||||||
self.memory[chat_id] = []
|
|
||||||
|
|
||||||
# 如果有图片,构建多模态内容
|
# 如果有图片,构建多模态内容
|
||||||
if image_base64:
|
if image_base64:
|
||||||
message_content = [
|
message_content = [
|
||||||
@@ -120,6 +205,22 @@ class AIChat(PluginBase):
|
|||||||
else:
|
else:
|
||||||
message_content = content
|
message_content = content
|
||||||
|
|
||||||
|
# 优先使用 Redis 存储
|
||||||
|
redis_config = self.config.get("redis", {})
|
||||||
|
if redis_config.get("use_redis_history", True):
|
||||||
|
redis_cache = get_cache()
|
||||||
|
if redis_cache and redis_cache.enabled:
|
||||||
|
ttl = redis_config.get("chat_history_ttl", 86400)
|
||||||
|
redis_cache.add_chat_message(chat_id, role, message_content, ttl=ttl)
|
||||||
|
# 裁剪历史
|
||||||
|
max_messages = self.config["memory"]["max_messages"]
|
||||||
|
redis_cache.trim_chat_history(chat_id, max_messages)
|
||||||
|
return
|
||||||
|
|
||||||
|
# 降级到内存存储
|
||||||
|
if chat_id not in self.memory:
|
||||||
|
self.memory[chat_id] = []
|
||||||
|
|
||||||
self.memory[chat_id].append({"role": role, "content": message_content})
|
self.memory[chat_id].append({"role": role, "content": message_content})
|
||||||
|
|
||||||
# 限制记忆长度
|
# 限制记忆长度
|
||||||
@@ -131,16 +232,47 @@ class AIChat(PluginBase):
|
|||||||
"""获取记忆中的消息"""
|
"""获取记忆中的消息"""
|
||||||
if not self.config.get("memory", {}).get("enabled", False):
|
if not self.config.get("memory", {}).get("enabled", False):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
# 优先从 Redis 获取
|
||||||
|
redis_config = self.config.get("redis", {})
|
||||||
|
if redis_config.get("use_redis_history", True):
|
||||||
|
redis_cache = get_cache()
|
||||||
|
if redis_cache and redis_cache.enabled:
|
||||||
|
max_messages = self.config["memory"]["max_messages"]
|
||||||
|
return redis_cache.get_chat_history(chat_id, max_messages)
|
||||||
|
|
||||||
|
# 降级到内存
|
||||||
return self.memory.get(chat_id, [])
|
return self.memory.get(chat_id, [])
|
||||||
|
|
||||||
def _clear_memory(self, chat_id: str):
|
def _clear_memory(self, chat_id: str):
|
||||||
"""清空指定会话的记忆"""
|
"""清空指定会话的记忆"""
|
||||||
|
# 清空 Redis
|
||||||
|
redis_config = self.config.get("redis", {})
|
||||||
|
if redis_config.get("use_redis_history", True):
|
||||||
|
redis_cache = get_cache()
|
||||||
|
if redis_cache and redis_cache.enabled:
|
||||||
|
redis_cache.clear_chat_history(chat_id)
|
||||||
|
|
||||||
|
# 同时清空内存
|
||||||
if chat_id in self.memory:
|
if chat_id in self.memory:
|
||||||
del self.memory[chat_id]
|
del self.memory[chat_id]
|
||||||
|
|
||||||
async def _download_and_encode_image(self, bot, cdnurl: str, aeskey: str) -> str:
|
async def _download_and_encode_image(self, bot, cdnurl: str, aeskey: str) -> str:
|
||||||
"""下载图片并转换为base64"""
|
"""下载图片并转换为base64,优先从缓存获取"""
|
||||||
try:
|
try:
|
||||||
|
# 1. 优先从 Redis 缓存获取
|
||||||
|
from utils.redis_cache import RedisCache
|
||||||
|
redis_cache = get_cache()
|
||||||
|
if redis_cache and redis_cache.enabled:
|
||||||
|
media_key = RedisCache.generate_media_key(cdnurl, aeskey)
|
||||||
|
if media_key:
|
||||||
|
cached_data = redis_cache.get_cached_media(media_key, "image")
|
||||||
|
if cached_data:
|
||||||
|
logger.debug(f"[缓存命中] 图片从 Redis 获取: {media_key[:20]}...")
|
||||||
|
return cached_data
|
||||||
|
|
||||||
|
# 2. 缓存未命中,下载图片
|
||||||
|
logger.debug(f"[缓存未命中] 开始下载图片...")
|
||||||
temp_dir = Path(__file__).parent / "temp"
|
temp_dir = Path(__file__).parent / "temp"
|
||||||
temp_dir.mkdir(exist_ok=True)
|
temp_dir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
@@ -168,74 +300,114 @@ class AIChat(PluginBase):
|
|||||||
with open(save_path, "rb") as f:
|
with open(save_path, "rb") as f:
|
||||||
image_data = base64.b64encode(f.read()).decode()
|
image_data = base64.b64encode(f.read()).decode()
|
||||||
|
|
||||||
|
base64_result = f"data:image/jpeg;base64,{image_data}"
|
||||||
|
|
||||||
|
# 3. 缓存到 Redis(供后续使用)
|
||||||
|
if redis_cache and redis_cache.enabled and media_key:
|
||||||
|
redis_cache.cache_media(media_key, base64_result, "image", ttl=300)
|
||||||
|
logger.debug(f"[已缓存] 图片缓存到 Redis: {media_key[:20]}...")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
Path(save_path).unlink()
|
Path(save_path).unlink()
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return f"data:image/jpeg;base64,{image_data}"
|
return base64_result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"下载图片失败: {e}")
|
logger.error(f"下载图片失败: {e}")
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
async def _download_emoji_and_encode(self, cdn_url: str) -> str:
|
async def _download_emoji_and_encode(self, cdn_url: str, max_retries: int = 3) -> str:
|
||||||
"""下载表情包并转换为base64(HTTP 直接下载)"""
|
"""下载表情包并转换为base64(HTTP 直接下载,带重试机制),优先从缓存获取"""
|
||||||
try:
|
# 替换 HTML 实体
|
||||||
# 替换 HTML 实体
|
cdn_url = cdn_url.replace("&", "&")
|
||||||
cdn_url = cdn_url.replace("&", "&")
|
|
||||||
|
|
||||||
temp_dir = Path(__file__).parent / "temp"
|
# 1. 优先从 Redis 缓存获取
|
||||||
temp_dir.mkdir(exist_ok=True)
|
from utils.redis_cache import RedisCache
|
||||||
|
redis_cache = get_cache()
|
||||||
|
media_key = RedisCache.generate_media_key(cdnurl=cdn_url)
|
||||||
|
if redis_cache and redis_cache.enabled and media_key:
|
||||||
|
cached_data = redis_cache.get_cached_media(media_key, "emoji")
|
||||||
|
if cached_data:
|
||||||
|
logger.debug(f"[缓存命中] 表情包从 Redis 获取: {media_key[:20]}...")
|
||||||
|
return cached_data
|
||||||
|
|
||||||
filename = f"temp_{uuid.uuid4().hex[:8]}.gif"
|
# 2. 缓存未命中,下载表情包
|
||||||
save_path = temp_dir / filename
|
logger.debug(f"[缓存未命中] 开始下载表情包...")
|
||||||
|
temp_dir = Path(__file__).parent / "temp"
|
||||||
|
temp_dir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
# 使用 aiohttp 下载
|
filename = f"temp_{uuid.uuid4().hex[:8]}.gif"
|
||||||
timeout = aiohttp.ClientTimeout(total=30)
|
save_path = temp_dir / filename
|
||||||
|
|
||||||
# 配置代理
|
last_error = None
|
||||||
connector = None
|
|
||||||
proxy_config = self.config.get("proxy", {})
|
|
||||||
if proxy_config.get("enabled", False):
|
|
||||||
proxy_type = proxy_config.get("type", "socks5").upper()
|
|
||||||
proxy_host = proxy_config.get("host", "127.0.0.1")
|
|
||||||
proxy_port = proxy_config.get("port", 7890)
|
|
||||||
proxy_username = proxy_config.get("username")
|
|
||||||
proxy_password = proxy_config.get("password")
|
|
||||||
|
|
||||||
if proxy_username and proxy_password:
|
for attempt in range(max_retries):
|
||||||
proxy_url = f"{proxy_type}://{proxy_username}:{proxy_password}@{proxy_host}:{proxy_port}"
|
try:
|
||||||
else:
|
# 使用 aiohttp 下载,每次重试增加超时时间
|
||||||
proxy_url = f"{proxy_type}://{proxy_host}:{proxy_port}"
|
timeout = aiohttp.ClientTimeout(total=30 + attempt * 15)
|
||||||
|
|
||||||
if PROXY_SUPPORT:
|
# 配置代理
|
||||||
try:
|
connector = None
|
||||||
connector = ProxyConnector.from_url(proxy_url)
|
proxy_config = self.config.get("proxy", {})
|
||||||
except:
|
if proxy_config.get("enabled", False):
|
||||||
connector = None
|
proxy_type = proxy_config.get("type", "socks5").upper()
|
||||||
|
proxy_host = proxy_config.get("host", "127.0.0.1")
|
||||||
|
proxy_port = proxy_config.get("port", 7890)
|
||||||
|
proxy_username = proxy_config.get("username")
|
||||||
|
proxy_password = proxy_config.get("password")
|
||||||
|
|
||||||
async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
|
if proxy_username and proxy_password:
|
||||||
async with session.get(cdn_url) as response:
|
proxy_url = f"{proxy_type}://{proxy_username}:{proxy_password}@{proxy_host}:{proxy_port}"
|
||||||
if response.status == 200:
|
else:
|
||||||
content = await response.read()
|
proxy_url = f"{proxy_type}://{proxy_host}:{proxy_port}"
|
||||||
with open(save_path, "wb") as f:
|
|
||||||
f.write(content)
|
|
||||||
|
|
||||||
# 编码为 base64
|
if PROXY_SUPPORT:
|
||||||
image_data = base64.b64encode(content).decode()
|
|
||||||
|
|
||||||
# 删除临时文件
|
|
||||||
try:
|
try:
|
||||||
save_path.unlink()
|
connector = ProxyConnector.from_url(proxy_url)
|
||||||
except:
|
except:
|
||||||
pass
|
connector = None
|
||||||
|
|
||||||
return f"data:image/gif;base64,{image_data}"
|
async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
|
||||||
|
async with session.get(cdn_url) as response:
|
||||||
|
if response.status == 200:
|
||||||
|
content = await response.read()
|
||||||
|
|
||||||
return ""
|
if len(content) == 0:
|
||||||
except Exception as e:
|
logger.warning(f"表情包下载内容为空,重试 {attempt + 1}/{max_retries}")
|
||||||
logger.error(f"下载表情包失败: {e}")
|
continue
|
||||||
return ""
|
|
||||||
|
# 编码为 base64
|
||||||
|
image_data = base64.b64encode(content).decode()
|
||||||
|
|
||||||
|
logger.debug(f"表情包下载成功,大小: {len(content)} 字节")
|
||||||
|
base64_result = f"data:image/gif;base64,{image_data}"
|
||||||
|
|
||||||
|
# 3. 缓存到 Redis(供后续使用)
|
||||||
|
if redis_cache and redis_cache.enabled and media_key:
|
||||||
|
redis_cache.cache_media(media_key, base64_result, "emoji", ttl=300)
|
||||||
|
logger.debug(f"[已缓存] 表情包缓存到 Redis: {media_key[:20]}...")
|
||||||
|
|
||||||
|
return base64_result
|
||||||
|
else:
|
||||||
|
logger.warning(f"表情包下载失败,状态码: {response.status},重试 {attempt + 1}/{max_retries}")
|
||||||
|
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
last_error = "请求超时"
|
||||||
|
logger.warning(f"表情包下载超时,重试 {attempt + 1}/{max_retries}")
|
||||||
|
except aiohttp.ClientError as e:
|
||||||
|
last_error = str(e)
|
||||||
|
logger.warning(f"表情包下载网络错误: {e},重试 {attempt + 1}/{max_retries}")
|
||||||
|
except Exception as e:
|
||||||
|
last_error = str(e)
|
||||||
|
logger.warning(f"表情包下载异常: {e},重试 {attempt + 1}/{max_retries}")
|
||||||
|
|
||||||
|
# 重试前等待(指数退避)
|
||||||
|
if attempt < max_retries - 1:
|
||||||
|
await asyncio.sleep(1 * (attempt + 1))
|
||||||
|
|
||||||
|
logger.error(f"表情包下载失败,已重试 {max_retries} 次: {last_error}")
|
||||||
|
return ""
|
||||||
|
|
||||||
async def _generate_image_description(self, image_base64: str, prompt: str, config: dict) -> str:
|
async def _generate_image_description(self, image_base64: str, prompt: str, config: dict) -> str:
|
||||||
"""
|
"""
|
||||||
@@ -479,37 +651,8 @@ class AIChat(PluginBase):
|
|||||||
# 检查是否应该回复
|
# 检查是否应该回复
|
||||||
should_reply = self._should_reply(message, content, bot_wxid)
|
should_reply = self._should_reply(message, content, bot_wxid)
|
||||||
|
|
||||||
# 获取用户昵称(用于历史记录)
|
# 获取用户昵称(用于历史记录)- 使用缓存优化
|
||||||
nickname = ""
|
nickname = await self._get_user_nickname(bot, from_wxid, user_wxid, is_group)
|
||||||
if is_group:
|
|
||||||
try:
|
|
||||||
user_info = await bot.get_user_info_in_chatroom(from_wxid, user_wxid)
|
|
||||||
if user_info and user_info.get("nickName", {}).get("string"):
|
|
||||||
nickname = user_info["nickName"]["string"]
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# 如果获取昵称失败,从 MessageLogger 数据库查询
|
|
||||||
if not nickname:
|
|
||||||
from plugins.MessageLogger.main import MessageLogger
|
|
||||||
msg_logger = MessageLogger.get_instance()
|
|
||||||
if msg_logger:
|
|
||||||
try:
|
|
||||||
with msg_logger.get_db_connection() as conn:
|
|
||||||
with conn.cursor() as cursor:
|
|
||||||
cursor.execute(
|
|
||||||
"SELECT nickname FROM messages WHERE sender_wxid = %s AND nickname != '' ORDER BY create_time DESC LIMIT 1",
|
|
||||||
(user_wxid,)
|
|
||||||
)
|
|
||||||
result = cursor.fetchone()
|
|
||||||
if result:
|
|
||||||
nickname = result[0]
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# 最后降级使用 wxid
|
|
||||||
if not nickname:
|
|
||||||
nickname = user_wxid or sender_wxid or "未知用户"
|
|
||||||
|
|
||||||
# 保存到群组历史记录(所有消息都保存,不管是否回复)
|
# 保存到群组历史记录(所有消息都保存,不管是否回复)
|
||||||
if is_group:
|
if is_group:
|
||||||
@@ -519,6 +662,16 @@ class AIChat(PluginBase):
|
|||||||
if not should_reply:
|
if not should_reply:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# 限流检查(仅在需要回复时检查)
|
||||||
|
allowed, remaining, reset_time = self._check_rate_limit(user_wxid)
|
||||||
|
if not allowed:
|
||||||
|
rate_limit_config = self.config.get("rate_limit", {})
|
||||||
|
msg = rate_limit_config.get("rate_limit_message", "⚠️ 消息太频繁了,请 {seconds} 秒后再试~")
|
||||||
|
msg = msg.format(seconds=reset_time)
|
||||||
|
await bot.send_text(from_wxid, msg)
|
||||||
|
logger.warning(f"用户 {user_wxid} 触发限流,{reset_time}秒后重置")
|
||||||
|
return False
|
||||||
|
|
||||||
# 提取实际消息内容(去除@)
|
# 提取实际消息内容(去除@)
|
||||||
actual_content = self._extract_content(message, content)
|
actual_content = self._extract_content(message, content)
|
||||||
if not actual_content:
|
if not actual_content:
|
||||||
@@ -1004,8 +1157,23 @@ class AIChat(PluginBase):
|
|||||||
json.dump(history, f, ensure_ascii=False, indent=2)
|
json.dump(history, f, ensure_ascii=False, indent=2)
|
||||||
temp_file.replace(history_file)
|
temp_file.replace(history_file)
|
||||||
|
|
||||||
|
def _use_redis_for_group_history(self) -> bool:
|
||||||
|
"""检查是否使用 Redis 存储群聊历史"""
|
||||||
|
redis_config = self.config.get("redis", {})
|
||||||
|
if not redis_config.get("use_redis_history", True):
|
||||||
|
return False
|
||||||
|
redis_cache = get_cache()
|
||||||
|
return redis_cache and redis_cache.enabled
|
||||||
|
|
||||||
async def _load_history(self, chat_id: str) -> list:
|
async def _load_history(self, chat_id: str) -> list:
|
||||||
"""异步读取群聊历史, 用锁避免与写入冲突"""
|
"""异步读取群聊历史, 优先使用 Redis"""
|
||||||
|
# 优先使用 Redis
|
||||||
|
if self._use_redis_for_group_history():
|
||||||
|
redis_cache = get_cache()
|
||||||
|
max_history = self.config.get("history", {}).get("max_history", 100)
|
||||||
|
return redis_cache.get_group_history(chat_id, max_history)
|
||||||
|
|
||||||
|
# 降级到文件存储
|
||||||
history_file = self._get_history_file(chat_id)
|
history_file = self._get_history_file(chat_id)
|
||||||
if not history_file:
|
if not history_file:
|
||||||
return []
|
return []
|
||||||
@@ -1015,6 +1183,10 @@ class AIChat(PluginBase):
|
|||||||
|
|
||||||
async def _save_history(self, chat_id: str, history: list):
|
async def _save_history(self, chat_id: str, history: list):
|
||||||
"""异步写入群聊历史, 包含长度截断"""
|
"""异步写入群聊历史, 包含长度截断"""
|
||||||
|
# Redis 模式下不需要单独保存,add_group_message 已经处理
|
||||||
|
if self._use_redis_for_group_history():
|
||||||
|
return
|
||||||
|
|
||||||
history_file = self._get_history_file(chat_id)
|
history_file = self._get_history_file(chat_id)
|
||||||
if not history_file:
|
if not history_file:
|
||||||
return
|
return
|
||||||
@@ -1040,6 +1212,27 @@ class AIChat(PluginBase):
|
|||||||
if not self.config.get("history", {}).get("enabled", True):
|
if not self.config.get("history", {}).get("enabled", True):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# 构建消息内容
|
||||||
|
if image_base64:
|
||||||
|
message_content = [
|
||||||
|
{"type": "text", "text": content},
|
||||||
|
{"type": "image_url", "image_url": {"url": image_base64}}
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
message_content = content
|
||||||
|
|
||||||
|
# 优先使用 Redis
|
||||||
|
if self._use_redis_for_group_history():
|
||||||
|
redis_cache = get_cache()
|
||||||
|
redis_config = self.config.get("redis", {})
|
||||||
|
ttl = redis_config.get("group_history_ttl", 172800)
|
||||||
|
redis_cache.add_group_message(chat_id, nickname, message_content, ttl=ttl)
|
||||||
|
# 裁剪历史
|
||||||
|
max_history = self.config.get("history", {}).get("max_history", 100)
|
||||||
|
redis_cache.trim_group_history(chat_id, max_history)
|
||||||
|
return
|
||||||
|
|
||||||
|
# 降级到文件存储
|
||||||
history_file = self._get_history_file(chat_id)
|
history_file = self._get_history_file(chat_id)
|
||||||
if not history_file:
|
if not history_file:
|
||||||
return
|
return
|
||||||
@@ -1050,17 +1243,10 @@ class AIChat(PluginBase):
|
|||||||
|
|
||||||
message_record = {
|
message_record = {
|
||||||
"nickname": nickname,
|
"nickname": nickname,
|
||||||
"timestamp": datetime.now().isoformat()
|
"timestamp": datetime.now().isoformat(),
|
||||||
|
"content": message_content
|
||||||
}
|
}
|
||||||
|
|
||||||
if image_base64:
|
|
||||||
message_record["content"] = [
|
|
||||||
{"type": "text", "text": content},
|
|
||||||
{"type": "image_url", "image_url": {"url": image_base64}}
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
message_record["content"] = content
|
|
||||||
|
|
||||||
history.append(message_record)
|
history.append(message_record)
|
||||||
max_history = self.config.get("history", {}).get("max_history", 100)
|
max_history = self.config.get("history", {}).get("max_history", 100)
|
||||||
if len(history) > max_history:
|
if len(history) > max_history:
|
||||||
@@ -1073,6 +1259,18 @@ class AIChat(PluginBase):
|
|||||||
if not self.config.get("history", {}).get("enabled", True):
|
if not self.config.get("history", {}).get("enabled", True):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# 优先使用 Redis
|
||||||
|
if self._use_redis_for_group_history():
|
||||||
|
redis_cache = get_cache()
|
||||||
|
redis_config = self.config.get("redis", {})
|
||||||
|
ttl = redis_config.get("group_history_ttl", 172800)
|
||||||
|
redis_cache.add_group_message(chat_id, nickname, content, record_id=record_id, ttl=ttl)
|
||||||
|
# 裁剪历史
|
||||||
|
max_history = self.config.get("history", {}).get("max_history", 100)
|
||||||
|
redis_cache.trim_group_history(chat_id, max_history)
|
||||||
|
return
|
||||||
|
|
||||||
|
# 降级到文件存储
|
||||||
history_file = self._get_history_file(chat_id)
|
history_file = self._get_history_file(chat_id)
|
||||||
if not history_file:
|
if not history_file:
|
||||||
return
|
return
|
||||||
@@ -1097,6 +1295,13 @@ class AIChat(PluginBase):
|
|||||||
if not self.config.get("history", {}).get("enabled", True):
|
if not self.config.get("history", {}).get("enabled", True):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# 优先使用 Redis
|
||||||
|
if self._use_redis_for_group_history():
|
||||||
|
redis_cache = get_cache()
|
||||||
|
redis_cache.update_group_message_by_id(chat_id, record_id, new_content)
|
||||||
|
return
|
||||||
|
|
||||||
|
# 降级到文件存储
|
||||||
history_file = self._get_history_file(chat_id)
|
history_file = self._get_history_file(chat_id)
|
||||||
if not history_file:
|
if not history_file:
|
||||||
return
|
return
|
||||||
@@ -1205,16 +1410,18 @@ class AIChat(PluginBase):
|
|||||||
|
|
||||||
logger.info(f"AI处理引用图片消息: {title_text[:50]}...")
|
logger.info(f"AI处理引用图片消息: {title_text[:50]}...")
|
||||||
|
|
||||||
# 获取用户昵称
|
# 限流检查
|
||||||
nickname = ""
|
allowed, remaining, reset_time = self._check_rate_limit(user_wxid)
|
||||||
if is_group:
|
if not allowed:
|
||||||
try:
|
rate_limit_config = self.config.get("rate_limit", {})
|
||||||
user_info = await bot.get_user_info_in_chatroom(from_wxid, user_wxid)
|
msg = rate_limit_config.get("rate_limit_message", "⚠️ 消息太频繁了,请 {seconds} 秒后再试~")
|
||||||
if user_info and user_info.get("nickName", {}).get("string"):
|
msg = msg.format(seconds=reset_time)
|
||||||
nickname = user_info["nickName"]["string"]
|
await bot.send_text(from_wxid, msg)
|
||||||
logger.info(f"获取到用户昵称: {nickname}")
|
logger.warning(f"用户 {user_wxid} 触发限流,{reset_time}秒后重置")
|
||||||
except Exception as e:
|
return False
|
||||||
logger.error(f"获取用户昵称失败: {e}")
|
|
||||||
|
# 获取用户昵称 - 使用缓存优化
|
||||||
|
nickname = await self._get_user_nickname(bot, from_wxid, user_wxid, is_group)
|
||||||
|
|
||||||
# 下载并编码图片
|
# 下载并编码图片
|
||||||
logger.info(f"开始下载图片: {cdnbigimgurl[:50]}...")
|
logger.info(f"开始下载图片: {cdnbigimgurl[:50]}...")
|
||||||
@@ -1627,34 +1834,8 @@ class AIChat(PluginBase):
|
|||||||
if not is_emoji and not aeskey:
|
if not is_emoji and not aeskey:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# 获取用户昵称
|
# 获取用户昵称 - 使用缓存优化
|
||||||
nickname = ""
|
nickname = await self._get_user_nickname(bot, from_wxid, user_wxid, is_group)
|
||||||
try:
|
|
||||||
user_info = await bot.get_user_info_in_chatroom(from_wxid, user_wxid)
|
|
||||||
if user_info and user_info.get("nickName", {}).get("string"):
|
|
||||||
nickname = user_info["nickName"]["string"]
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if not nickname:
|
|
||||||
from plugins.MessageLogger.main import MessageLogger
|
|
||||||
msg_logger = MessageLogger.get_instance()
|
|
||||||
if msg_logger:
|
|
||||||
try:
|
|
||||||
with msg_logger.get_db_connection() as conn:
|
|
||||||
with conn.cursor() as cursor:
|
|
||||||
cursor.execute(
|
|
||||||
"SELECT nickname FROM messages WHERE sender_wxid = %s AND nickname != '' ORDER BY create_time DESC LIMIT 1",
|
|
||||||
(user_wxid,)
|
|
||||||
)
|
|
||||||
result = cursor.fetchone()
|
|
||||||
if result:
|
|
||||||
nickname = result[0]
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if not nickname:
|
|
||||||
nickname = user_wxid or sender_wxid or "未知用户"
|
|
||||||
|
|
||||||
# 立即插入占位符到 history
|
# 立即插入占位符到 history
|
||||||
placeholder_id = str(uuid.uuid4())
|
placeholder_id = str(uuid.uuid4())
|
||||||
|
|||||||
@@ -2,8 +2,17 @@
|
|||||||
插件管理插件
|
插件管理插件
|
||||||
|
|
||||||
提供插件的热重载、启用、禁用等管理功能
|
提供插件的热重载、启用、禁用等管理功能
|
||||||
|
支持的指令:
|
||||||
|
/插件列表 - 查看所有插件状态
|
||||||
|
/重载插件 <名称> - 重载指定插件
|
||||||
|
/重载所有插件 - 重载所有插件
|
||||||
|
/启用插件 <名称> - 启用指定插件
|
||||||
|
/禁用插件 <名称> - 禁用指定插件
|
||||||
|
/刷新插件 - 扫描并发现新插件
|
||||||
|
/插件帮助 - 显示帮助信息
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
import tomllib
|
import tomllib
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
@@ -18,7 +27,10 @@ class ManagePlugin(PluginBase):
|
|||||||
# 插件元数据
|
# 插件元数据
|
||||||
description = "插件管理,支持热重载、启用、禁用"
|
description = "插件管理,支持热重载、启用、禁用"
|
||||||
author = "ShiHao"
|
author = "ShiHao"
|
||||||
version = "1.0.0"
|
version = "2.0.0"
|
||||||
|
|
||||||
|
# 最高加载优先级,确保最先加载
|
||||||
|
load_priority = 100
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -34,37 +46,72 @@ class ManagePlugin(PluginBase):
|
|||||||
self.admins = main_config.get("Bot", {}).get("admins", [])
|
self.admins = main_config.get("Bot", {}).get("admins", [])
|
||||||
logger.info(f"插件管理插件已加载,管理员: {self.admins}")
|
logger.info(f"插件管理插件已加载,管理员: {self.admins}")
|
||||||
|
|
||||||
@on_text_message()
|
def _check_admin(self, message: dict) -> bool:
|
||||||
|
"""检查是否是管理员"""
|
||||||
|
sender_wxid = message.get("SenderWxid", "")
|
||||||
|
from_wxid = message.get("FromWxid", "")
|
||||||
|
is_group = message.get("IsGroup", False)
|
||||||
|
|
||||||
|
# 私聊时 sender_wxid 可能为空,使用 from_wxid
|
||||||
|
user_wxid = sender_wxid if is_group else from_wxid
|
||||||
|
|
||||||
|
return user_wxid in self.admins
|
||||||
|
|
||||||
|
@on_text_message(priority=99)
|
||||||
async def handle_command(self, bot, message: dict):
|
async def handle_command(self, bot, message: dict):
|
||||||
"""处理管理命令"""
|
"""处理管理命令"""
|
||||||
content = message.get("Content", "").strip()
|
content = message.get("Content", "").strip()
|
||||||
from_wxid = message.get("FromWxid", "")
|
from_wxid = message.get("FromWxid", "")
|
||||||
sender_wxid = message.get("SenderWxid", "")
|
|
||||||
|
|
||||||
logger.debug(f"ManagePlugin: content={content}, from={from_wxid}, sender={sender_wxid}, admins={self.admins}")
|
|
||||||
|
|
||||||
# 检查权限
|
# 检查权限
|
||||||
if not self.admins or sender_wxid not in self.admins:
|
if not self._check_admin(message):
|
||||||
return
|
return True # 继续传递给其他插件
|
||||||
|
|
||||||
|
# 插件帮助
|
||||||
|
if content == "/插件帮助" or content == "/plugin help":
|
||||||
|
await self._show_help(bot, from_wxid)
|
||||||
|
return False
|
||||||
|
|
||||||
# 插件列表
|
# 插件列表
|
||||||
if content == "/插件列表" or content == "/plugins":
|
elif content == "/插件列表" or content == "/plugins":
|
||||||
await self._list_plugins(bot, from_wxid)
|
await self._list_plugins(bot, from_wxid)
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 重载所有插件
|
||||||
|
elif content == "/重载所有插件" or content == "/reload all":
|
||||||
|
await self._reload_all_plugins(bot, from_wxid)
|
||||||
|
return False
|
||||||
|
|
||||||
# 重载插件
|
# 重载插件
|
||||||
elif content.startswith("/重载插件 ") or content.startswith("/reload "):
|
elif content.startswith("/重载插件 ") or content.startswith("/reload "):
|
||||||
plugin_name = content.split(maxsplit=1)[1].strip()
|
plugin_name = content.split(maxsplit=1)[1].strip()
|
||||||
await self._reload_plugin(bot, from_wxid, plugin_name)
|
await self._reload_plugin(bot, from_wxid, plugin_name)
|
||||||
|
return False
|
||||||
|
|
||||||
# 启用插件
|
# 启用插件
|
||||||
elif content.startswith("/启用插件 ") or content.startswith("/enable "):
|
elif content.startswith("/启用插件 ") or content.startswith("/enable "):
|
||||||
plugin_name = content.split(maxsplit=1)[1].strip()
|
plugin_name = content.split(maxsplit=1)[1].strip()
|
||||||
await self._enable_plugin(bot, from_wxid, plugin_name)
|
await self._enable_plugin(bot, from_wxid, plugin_name)
|
||||||
|
return False
|
||||||
|
|
||||||
# 禁用插件
|
# 禁用插件
|
||||||
elif content.startswith("/禁用插件 ") or content.startswith("/disable "):
|
elif content.startswith("/禁用插件 ") or content.startswith("/disable "):
|
||||||
plugin_name = content.split(maxsplit=1)[1].strip()
|
plugin_name = content.split(maxsplit=1)[1].strip()
|
||||||
await self._disable_plugin(bot, from_wxid, plugin_name)
|
await self._disable_plugin(bot, from_wxid, plugin_name)
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 刷新插件(发现新插件)
|
||||||
|
elif content == "/刷新插件" or content == "/refresh":
|
||||||
|
await self._refresh_plugins(bot, from_wxid)
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 加载新插件(从目录加载全新插件)
|
||||||
|
elif content.startswith("/加载插件 ") or content.startswith("/load "):
|
||||||
|
plugin_name = content.split(maxsplit=1)[1].strip()
|
||||||
|
await self._load_new_plugin(bot, from_wxid, plugin_name)
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True # 不是管理命令,继续传递
|
||||||
|
|
||||||
async def _list_plugins(self, bot, to_wxid: str):
|
async def _list_plugins(self, bot, to_wxid: str):
|
||||||
"""列出所有插件"""
|
"""列出所有插件"""
|
||||||
@@ -159,3 +206,200 @@ class ManagePlugin(PluginBase):
|
|||||||
logger.info(f"插件 {plugin_name} 已被禁用")
|
logger.info(f"插件 {plugin_name} 已被禁用")
|
||||||
else:
|
else:
|
||||||
await bot.send_text(to_wxid, f"❌ 插件 {plugin_name} 禁用失败")
|
await bot.send_text(to_wxid, f"❌ 插件 {plugin_name} 禁用失败")
|
||||||
|
|
||||||
|
async def _show_help(self, bot, to_wxid: str):
|
||||||
|
"""显示帮助信息"""
|
||||||
|
help_text = """📦 插件管理帮助
|
||||||
|
|
||||||
|
/插件列表 - 查看所有插件状态
|
||||||
|
/插件帮助 - 显示此帮助信息
|
||||||
|
|
||||||
|
/加载插件 <名称> - 加载新插件(无需重启)
|
||||||
|
/重载插件 <名称> - 热重载指定插件
|
||||||
|
/重载所有插件 - 热重载所有插件
|
||||||
|
|
||||||
|
/启用插件 <名称> - 启用已禁用的插件
|
||||||
|
/禁用插件 <名称> - 禁用指定插件
|
||||||
|
|
||||||
|
/刷新插件 - 扫描发现新插件
|
||||||
|
|
||||||
|
示例:
|
||||||
|
/加载插件 NewPlugin
|
||||||
|
/重载插件 AIChat
|
||||||
|
/禁用插件 Weather"""
|
||||||
|
await bot.send_text(to_wxid, help_text)
|
||||||
|
|
||||||
|
async def _reload_all_plugins(self, bot, to_wxid: str):
|
||||||
|
"""重载所有插件"""
|
||||||
|
pm = PluginManager()
|
||||||
|
|
||||||
|
await bot.send_text(to_wxid, "⏳ 正在重载所有插件...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 清理插件相关的模块缓存
|
||||||
|
modules_to_remove = [
|
||||||
|
name for name in sys.modules.keys()
|
||||||
|
if name.startswith('plugins.') and 'ManagePlugin' not in name
|
||||||
|
]
|
||||||
|
for module_name in modules_to_remove:
|
||||||
|
del sys.modules[module_name]
|
||||||
|
|
||||||
|
# 重载所有插件
|
||||||
|
reloaded = await pm.reload_plugins()
|
||||||
|
|
||||||
|
if reloaded:
|
||||||
|
await bot.send_text(
|
||||||
|
to_wxid,
|
||||||
|
f"✅ 重载完成\n已加载 {len(reloaded)} 个插件:\n" +
|
||||||
|
"\n".join(f" • {name}" for name in reloaded)
|
||||||
|
)
|
||||||
|
logger.success(f"已重载 {len(reloaded)} 个插件")
|
||||||
|
else:
|
||||||
|
await bot.send_text(to_wxid, "⚠️ 没有插件被重载")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"重载所有插件失败: {e}")
|
||||||
|
await bot.send_text(to_wxid, f"❌ 重载失败: {e}")
|
||||||
|
|
||||||
|
async def _refresh_plugins(self, bot, to_wxid: str):
|
||||||
|
"""刷新插件列表,发现新插件"""
|
||||||
|
pm = PluginManager()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 记录刷新前的插件数量
|
||||||
|
old_count = len(pm.plugin_info)
|
||||||
|
|
||||||
|
# 刷新插件列表
|
||||||
|
await pm.refresh_plugins()
|
||||||
|
|
||||||
|
# 计算新发现的插件
|
||||||
|
new_count = len(pm.plugin_info)
|
||||||
|
new_plugins = new_count - old_count
|
||||||
|
|
||||||
|
if new_plugins > 0:
|
||||||
|
# 获取新发现的插件名称
|
||||||
|
new_plugin_names = [
|
||||||
|
info["name"] for info in pm.plugin_info.values()
|
||||||
|
if not info.get("enabled", False)
|
||||||
|
][-new_plugins:]
|
||||||
|
|
||||||
|
await bot.send_text(
|
||||||
|
to_wxid,
|
||||||
|
f"✅ 发现 {new_plugins} 个新插件:\n" +
|
||||||
|
"\n".join(f" • {name}" for name in new_plugin_names) +
|
||||||
|
"\n\n使用 /启用插件 <名称> 来启用"
|
||||||
|
)
|
||||||
|
logger.info(f"发现 {new_plugins} 个新插件: {new_plugin_names}")
|
||||||
|
else:
|
||||||
|
await bot.send_text(to_wxid, "ℹ️ 没有发现新插件")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"刷新插件失败: {e}")
|
||||||
|
await bot.send_text(to_wxid, f"❌ 刷新失败: {e}")
|
||||||
|
|
||||||
|
async def _load_new_plugin(self, bot, to_wxid: str, plugin_name: str):
|
||||||
|
"""加载全新的插件(支持插件类名或目录名)"""
|
||||||
|
import os
|
||||||
|
import importlib
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
pm = PluginManager()
|
||||||
|
|
||||||
|
# 检查是否已加载
|
||||||
|
if plugin_name in pm.plugins:
|
||||||
|
await bot.send_text(to_wxid, f"ℹ️ 插件 {plugin_name} 已经加载,如需重载请使用 /重载插件")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 尝试查找插件
|
||||||
|
found = False
|
||||||
|
plugin_class = None
|
||||||
|
actual_plugin_name = None
|
||||||
|
|
||||||
|
for dirname in os.listdir("plugins"):
|
||||||
|
dirpath = f"plugins/{dirname}"
|
||||||
|
if not os.path.isdir(dirpath) or not os.path.exists(f"{dirpath}/main.py"):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 支持通过目录名或类名查找
|
||||||
|
if dirname == plugin_name or dirname.lower() == plugin_name.lower():
|
||||||
|
# 通过目录名匹配
|
||||||
|
module_name = f"plugins.{dirname}.main"
|
||||||
|
|
||||||
|
# 清理旧的模块缓存
|
||||||
|
if module_name in sys.modules:
|
||||||
|
del sys.modules[module_name]
|
||||||
|
|
||||||
|
# 导入模块
|
||||||
|
module = importlib.import_module(module_name)
|
||||||
|
|
||||||
|
# 查找插件类
|
||||||
|
for name, obj in inspect.getmembers(module):
|
||||||
|
if (inspect.isclass(obj) and
|
||||||
|
issubclass(obj, PluginBase) and
|
||||||
|
obj != PluginBase):
|
||||||
|
plugin_class = obj
|
||||||
|
actual_plugin_name = obj.__name__
|
||||||
|
found = True
|
||||||
|
break
|
||||||
|
|
||||||
|
if found:
|
||||||
|
break
|
||||||
|
|
||||||
|
# 尝试通过类名匹配
|
||||||
|
try:
|
||||||
|
module_name = f"plugins.{dirname}.main"
|
||||||
|
if module_name in sys.modules:
|
||||||
|
del sys.modules[module_name]
|
||||||
|
|
||||||
|
module = importlib.import_module(module_name)
|
||||||
|
|
||||||
|
for name, obj in inspect.getmembers(module):
|
||||||
|
if (inspect.isclass(obj) and
|
||||||
|
issubclass(obj, PluginBase) and
|
||||||
|
obj != PluginBase and
|
||||||
|
(obj.__name__ == plugin_name or obj.__name__.lower() == plugin_name.lower())):
|
||||||
|
plugin_class = obj
|
||||||
|
actual_plugin_name = obj.__name__
|
||||||
|
found = True
|
||||||
|
break
|
||||||
|
|
||||||
|
if found:
|
||||||
|
break
|
||||||
|
except:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not found or not plugin_class:
|
||||||
|
await bot.send_text(
|
||||||
|
to_wxid,
|
||||||
|
f"❌ 未找到插件 {plugin_name}\n"
|
||||||
|
f"请确认:\n"
|
||||||
|
f"1. plugins/{plugin_name}/main.py 存在\n"
|
||||||
|
f"2. main.py 中有继承 PluginBase 的类"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# 检查是否已加载(用实际类名再检查一次)
|
||||||
|
if actual_plugin_name in pm.plugins:
|
||||||
|
await bot.send_text(to_wxid, f"ℹ️ 插件 {actual_plugin_name} 已经加载")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 加载插件
|
||||||
|
success = await pm._load_plugin_class(plugin_class)
|
||||||
|
|
||||||
|
if success:
|
||||||
|
await bot.send_text(
|
||||||
|
to_wxid,
|
||||||
|
f"✅ 插件加载成功\n"
|
||||||
|
f"名称: {actual_plugin_name}\n"
|
||||||
|
f"版本: {plugin_class.version}\n"
|
||||||
|
f"作者: {plugin_class.author}"
|
||||||
|
)
|
||||||
|
logger.success(f"新插件 {actual_plugin_name} 已加载")
|
||||||
|
else:
|
||||||
|
await bot.send_text(to_wxid, f"❌ 插件 {actual_plugin_name} 加载失败")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
import traceback
|
||||||
|
logger.error(f"加载新插件失败: {e}\n{traceback.format_exc()}")
|
||||||
|
await bot.send_text(to_wxid, f"❌ 加载失败: {e}")
|
||||||
|
|||||||
100
plugins/Menu/main.py
Normal file
100
plugins/Menu/main.py
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
"""
|
||||||
|
菜单插件
|
||||||
|
|
||||||
|
用户发送 /菜单、/帮助 等指令时,按顺序发送菜单图片
|
||||||
|
图片命名格式:menu1.png、menu2.png、menu3.png ...
|
||||||
|
"""
|
||||||
|
|
||||||
|
import re
|
||||||
|
import asyncio
|
||||||
|
from pathlib import Path
|
||||||
|
from loguru import logger
|
||||||
|
from utils.plugin_base import PluginBase
|
||||||
|
from utils.decorators import on_text_message
|
||||||
|
|
||||||
|
|
||||||
|
class Menu(PluginBase):
|
||||||
|
"""菜单插件"""
|
||||||
|
|
||||||
|
# 插件元数据
|
||||||
|
description = "菜单插件,发送帮助图片"
|
||||||
|
author = "ShiHao"
|
||||||
|
version = "1.0.0"
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.menu_dir = None
|
||||||
|
self.trigger_commands = ["/菜单", "/帮助", "/help", "/menu"]
|
||||||
|
self.send_interval = 0.5 # 发送间隔(秒),避免发送过快
|
||||||
|
|
||||||
|
async def async_init(self):
|
||||||
|
"""插件异步初始化"""
|
||||||
|
# 设置菜单图片目录
|
||||||
|
self.menu_dir = Path(__file__).parent / "images"
|
||||||
|
self.menu_dir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
# 扫描现有图片
|
||||||
|
images = self._get_menu_images()
|
||||||
|
logger.info(f"菜单插件已加载,图片目录: {self.menu_dir},找到 {len(images)} 张菜单图片")
|
||||||
|
|
||||||
|
def _get_menu_images(self) -> list:
|
||||||
|
"""
|
||||||
|
获取所有符合命名规范的菜单图片,按序号排序
|
||||||
|
|
||||||
|
命名格式:menu1.png、menu2.jpg、menu3.jpeg 等
|
||||||
|
"""
|
||||||
|
if not self.menu_dir or not self.menu_dir.exists():
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 匹配 menu + 数字 + 图片扩展名
|
||||||
|
pattern = re.compile(r'^menu(\d+)\.(png|jpg|jpeg|gif|bmp)$', re.IGNORECASE)
|
||||||
|
|
||||||
|
images = []
|
||||||
|
for file in self.menu_dir.iterdir():
|
||||||
|
if file.is_file():
|
||||||
|
match = pattern.match(file.name)
|
||||||
|
if match:
|
||||||
|
seq_num = int(match.group(1))
|
||||||
|
images.append((seq_num, file))
|
||||||
|
|
||||||
|
# 按序号排序
|
||||||
|
images.sort(key=lambda x: x[0])
|
||||||
|
|
||||||
|
# 只返回文件路径
|
||||||
|
return [img[1] for img in images]
|
||||||
|
|
||||||
|
@on_text_message(priority=60)
|
||||||
|
async def handle_menu_command(self, bot, message: dict):
|
||||||
|
"""处理菜单指令"""
|
||||||
|
content = message.get("Content", "").strip()
|
||||||
|
from_wxid = message.get("FromWxid", "")
|
||||||
|
|
||||||
|
# 检查是否是菜单指令
|
||||||
|
if content not in self.trigger_commands:
|
||||||
|
return True # 继续传递给其他插件
|
||||||
|
|
||||||
|
logger.info(f"收到菜单指令: {content}, from: {from_wxid}")
|
||||||
|
|
||||||
|
# 获取菜单图片
|
||||||
|
images = self._get_menu_images()
|
||||||
|
|
||||||
|
if not images:
|
||||||
|
await bot.send_text(from_wxid, "暂无菜单图片,请联系管理员添加")
|
||||||
|
logger.warning(f"菜单图片目录为空: {self.menu_dir}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 按顺序发送图片
|
||||||
|
for i, image_path in enumerate(images):
|
||||||
|
try:
|
||||||
|
await bot.send_image(from_wxid, str(image_path))
|
||||||
|
logger.debug(f"已发送菜单图片 {i+1}/{len(images)}: {image_path.name}")
|
||||||
|
|
||||||
|
# 发送间隔,避免发送过快
|
||||||
|
if i < len(images) - 1:
|
||||||
|
await asyncio.sleep(self.send_interval)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"发送菜单图片失败: {image_path}, 错误: {e}")
|
||||||
|
|
||||||
|
logger.success(f"菜单图片发送完成,共 {len(images)} 张")
|
||||||
|
return False # 阻止继续传递
|
||||||
@@ -18,6 +18,7 @@ from utils.decorators import (
|
|||||||
on_file_message,
|
on_file_message,
|
||||||
on_emoji_message
|
on_emoji_message
|
||||||
)
|
)
|
||||||
|
from utils.redis_cache import RedisCache, get_cache
|
||||||
import pymysql
|
import pymysql
|
||||||
from WechatHook import WechatHookClient
|
from WechatHook import WechatHookClient
|
||||||
from minio import Minio
|
from minio import Minio
|
||||||
@@ -39,6 +40,7 @@ class MessageLogger(PluginBase):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = None
|
self.config = None
|
||||||
self.db_config = None
|
self.db_config = None
|
||||||
|
self.redis_cache = None # Redis 缓存实例
|
||||||
|
|
||||||
# 创建独立的日志记录器
|
# 创建独立的日志记录器
|
||||||
self._setup_logger()
|
self._setup_logger()
|
||||||
@@ -83,9 +85,22 @@ class MessageLogger(PluginBase):
|
|||||||
|
|
||||||
self.db_config = self.config["database"]
|
self.db_config = self.config["database"]
|
||||||
|
|
||||||
|
# 初始化 Redis 缓存
|
||||||
|
redis_config = self.config.get("redis", {})
|
||||||
|
if redis_config.get("enabled", False):
|
||||||
|
self.log.info("正在初始化 Redis 缓存...")
|
||||||
|
self.redis_cache = RedisCache(redis_config)
|
||||||
|
if self.redis_cache.enabled:
|
||||||
|
self.log.success(f"Redis 缓存初始化成功,TTL={redis_config.get('ttl', 3600)}秒")
|
||||||
|
else:
|
||||||
|
self.log.warning("Redis 缓存初始化失败,将不使用缓存")
|
||||||
|
self.redis_cache = None
|
||||||
|
else:
|
||||||
|
self.log.info("Redis 缓存未启用")
|
||||||
|
|
||||||
# 初始化 MinIO 客户端
|
# 初始化 MinIO 客户端
|
||||||
self.minio_client = Minio(
|
self.minio_client = Minio(
|
||||||
"101.201.65.129:19000",
|
"115.190.113.141:19000",
|
||||||
access_key="admin",
|
access_key="admin",
|
||||||
secret_key="80012029Lz",
|
secret_key="80012029Lz",
|
||||||
secure=False
|
secure=False
|
||||||
@@ -216,7 +231,7 @@ class MessageLogger(PluginBase):
|
|||||||
return ("", "", "", "", "0")
|
return ("", "", "", "", "0")
|
||||||
|
|
||||||
async def download_image_and_upload(self, bot, cdnurl: str, aeskey: str) -> str:
|
async def download_image_and_upload(self, bot, cdnurl: str, aeskey: str) -> str:
|
||||||
"""下载图片并上传到 MinIO"""
|
"""下载图片并上传到 MinIO,同时缓存 base64 供其他插件使用"""
|
||||||
try:
|
try:
|
||||||
temp_file = Path(__file__).parent / f"temp_{uuid.uuid4().hex}.jpg"
|
temp_file = Path(__file__).parent / f"temp_{uuid.uuid4().hex}.jpg"
|
||||||
success = await bot.cdn_download(cdnurl, aeskey, str(temp_file), file_type=2)
|
success = await bot.cdn_download(cdnurl, aeskey, str(temp_file), file_type=2)
|
||||||
@@ -225,12 +240,26 @@ class MessageLogger(PluginBase):
|
|||||||
|
|
||||||
# 等待文件下载完成
|
# 等待文件下载完成
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import base64
|
||||||
for _ in range(50):
|
for _ in range(50):
|
||||||
if temp_file.exists() and temp_file.stat().st_size > 0:
|
if temp_file.exists() and temp_file.stat().st_size > 0:
|
||||||
break
|
break
|
||||||
await asyncio.sleep(0.1)
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
if temp_file.exists() and temp_file.stat().st_size > 0:
|
if temp_file.exists() and temp_file.stat().st_size > 0:
|
||||||
|
# 读取文件并缓存 base64(供 AIChat 等插件使用)
|
||||||
|
with open(temp_file, "rb") as f:
|
||||||
|
image_data = f.read()
|
||||||
|
base64_data = f"data:image/jpeg;base64,{base64.b64encode(image_data).decode()}"
|
||||||
|
|
||||||
|
# 缓存到 Redis(5分钟过期)
|
||||||
|
redis_cache = get_cache()
|
||||||
|
if redis_cache and redis_cache.enabled:
|
||||||
|
media_key = RedisCache.generate_media_key(cdnurl, aeskey)
|
||||||
|
if media_key:
|
||||||
|
redis_cache.cache_media(media_key, base64_data, "image", ttl=300)
|
||||||
|
self.log.debug(f"图片已缓存到 Redis: {media_key[:20]}...")
|
||||||
|
|
||||||
media_url = await self.upload_file_to_minio(str(temp_file), "images")
|
media_url = await self.upload_file_to_minio(str(temp_file), "images")
|
||||||
temp_file.unlink()
|
temp_file.unlink()
|
||||||
return media_url
|
return media_url
|
||||||
@@ -326,13 +355,25 @@ class MessageLogger(PluginBase):
|
|||||||
return ""
|
return ""
|
||||||
|
|
||||||
async def download_and_upload(self, url: str, file_type: str, ext: str) -> str:
|
async def download_and_upload(self, url: str, file_type: str, ext: str) -> str:
|
||||||
"""下载文件并上传到 MinIO"""
|
"""下载文件并上传到 MinIO,同时缓存 base64 供其他插件使用"""
|
||||||
try:
|
try:
|
||||||
|
import base64
|
||||||
# 下载文件
|
# 下载文件
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
async with session.get(url, timeout=aiohttp.ClientTimeout(total=10)) as resp:
|
async with session.get(url, timeout=aiohttp.ClientTimeout(total=10)) as resp:
|
||||||
if resp.status == 200:
|
if resp.status == 200:
|
||||||
data = await resp.read()
|
data = await resp.read()
|
||||||
|
|
||||||
|
# 缓存表情包 base64(供 AIChat 等插件使用)
|
||||||
|
if file_type == "emojis" and data:
|
||||||
|
redis_cache = get_cache()
|
||||||
|
if redis_cache and redis_cache.enabled:
|
||||||
|
base64_data = f"data:image/gif;base64,{base64.b64encode(data).decode()}"
|
||||||
|
media_key = RedisCache.generate_media_key(cdnurl=url)
|
||||||
|
if media_key:
|
||||||
|
redis_cache.cache_media(media_key, base64_data, "emoji", ttl=300)
|
||||||
|
self.log.debug(f"表情包已缓存到 Redis: {media_key[:20]}...")
|
||||||
|
|
||||||
# 保存到临时文件
|
# 保存到临时文件
|
||||||
temp_file = Path(__file__).parent / f"temp_{uuid.uuid4().hex}{ext}"
|
temp_file = Path(__file__).parent / f"temp_{uuid.uuid4().hex}{ext}"
|
||||||
temp_file.write_bytes(data)
|
temp_file.write_bytes(data)
|
||||||
@@ -374,7 +415,7 @@ class MessageLogger(PluginBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 返回访问 URL
|
# 返回访问 URL
|
||||||
url = f"http://101.201.65.129:19000/{self.minio_bucket}/{object_name}"
|
url = f"http://115.190.113.141:19000/{self.minio_bucket}/{object_name}"
|
||||||
self.log.debug(f"文件上传成功: {url}")
|
self.log.debug(f"文件上传成功: {url}")
|
||||||
return url
|
return url
|
||||||
|
|
||||||
@@ -405,29 +446,45 @@ class MessageLogger(PluginBase):
|
|||||||
avatar_url = ""
|
avatar_url = ""
|
||||||
|
|
||||||
if is_group and self.config["behavior"]["fetch_avatar"]:
|
if is_group and self.config["behavior"]["fetch_avatar"]:
|
||||||
try:
|
cache_hit = False
|
||||||
self.log.info(f"尝试获取用户信息: from_wxid={from_wxid}, sender_wxid={sender_wxid}")
|
|
||||||
user_info = await bot.get_user_info_in_chatroom(from_wxid, sender_wxid)
|
|
||||||
self.log.info(f"获取到用户信息: {user_info}")
|
|
||||||
|
|
||||||
if user_info:
|
# 1. 先尝试从 Redis 缓存获取
|
||||||
# 处理不同的数据结构
|
if self.redis_cache and self.redis_cache.enabled:
|
||||||
if isinstance(user_info.get("nickName"), dict):
|
cached_info = self.redis_cache.get_user_basic_info(from_wxid, sender_wxid)
|
||||||
nickname = user_info.get("nickName", {}).get("string", "")
|
if cached_info:
|
||||||
|
nickname = cached_info.get("nickname", "")
|
||||||
|
avatar_url = cached_info.get("avatar_url", "")
|
||||||
|
if nickname and avatar_url:
|
||||||
|
cache_hit = True
|
||||||
|
self.log.debug(f"[缓存命中] {sender_wxid}: {nickname}")
|
||||||
|
|
||||||
|
# 2. 缓存未命中,调用 API 获取
|
||||||
|
if not cache_hit:
|
||||||
|
try:
|
||||||
|
self.log.info(f"[缓存未命中] 调用API获取用户信息: {sender_wxid}")
|
||||||
|
user_info = await bot.get_user_info_in_chatroom(from_wxid, sender_wxid)
|
||||||
|
|
||||||
|
if user_info:
|
||||||
|
# 处理不同的数据结构
|
||||||
|
if isinstance(user_info.get("nickName"), dict):
|
||||||
|
nickname = user_info.get("nickName", {}).get("string", "")
|
||||||
|
else:
|
||||||
|
nickname = user_info.get("nickName", "")
|
||||||
|
|
||||||
|
avatar_url = user_info.get("bigHeadImgUrl", "")
|
||||||
|
self.log.info(f"API获取成功: nickname={nickname}, avatar_url={avatar_url[:50] if avatar_url else ''}...")
|
||||||
|
|
||||||
|
# 3. 将用户信息存入 Redis 缓存
|
||||||
|
if self.redis_cache and self.redis_cache.enabled and nickname:
|
||||||
|
self.redis_cache.set_user_info(from_wxid, sender_wxid, user_info)
|
||||||
|
self.log.debug(f"[已缓存] {sender_wxid}: {nickname}")
|
||||||
else:
|
else:
|
||||||
nickname = user_info.get("nickName", "")
|
self.log.warning(f"用户信息为空: {sender_wxid}")
|
||||||
|
|
||||||
avatar_url = user_info.get("bigHeadImgUrl", "")
|
except Exception as e:
|
||||||
self.log.info(f"解析用户信息: nickname={nickname}, avatar_url={avatar_url[:50]}...")
|
self.log.error(f"获取用户信息失败: {e}")
|
||||||
else:
|
|
||||||
self.log.warning(f"用户信息为空: {sender_wxid}")
|
|
||||||
|
|
||||||
except Exception as e:
|
# 4. 如果仍然没有获取到,从历史记录中查找
|
||||||
self.log.error(f"获取用户信息失败: {e}")
|
|
||||||
import traceback
|
|
||||||
self.log.error(f"详细错误: {traceback.format_exc()}")
|
|
||||||
|
|
||||||
# 如果获取失败,从历史记录中查找
|
|
||||||
if not nickname or not avatar_url:
|
if not nickname or not avatar_url:
|
||||||
self.log.info(f"尝试从历史记录获取用户信息: {sender_wxid}")
|
self.log.info(f"尝试从历史记录获取用户信息: {sender_wxid}")
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ class PerformanceMonitor(PluginBase):
|
|||||||
if sender_wxid not in admins:
|
if sender_wxid not in admins:
|
||||||
return
|
return
|
||||||
|
|
||||||
if content in ["/性能", "/stats", "/状态"]:
|
if content in ["/性能", "/stats", "/状态", "/性能报告"]:
|
||||||
stats_msg = await self._get_performance_stats(bot)
|
stats_msg = await self._get_performance_stats(bot)
|
||||||
await bot.send_text(from_wxid, stats_msg)
|
await bot.send_text(from_wxid, stats_msg)
|
||||||
return False # 阻止其他插件处理
|
return False # 阻止其他插件处理
|
||||||
@@ -66,6 +66,16 @@ class PerformanceMonitor(PluginBase):
|
|||||||
async def _get_performance_stats(self, bot) -> str:
|
async def _get_performance_stats(self, bot) -> str:
|
||||||
"""获取性能统计信息"""
|
"""获取性能统计信息"""
|
||||||
try:
|
try:
|
||||||
|
# 尝试使用新的性能监控器
|
||||||
|
try:
|
||||||
|
from utils.bot_utils import get_performance_monitor
|
||||||
|
monitor = get_performance_monitor()
|
||||||
|
if monitor and monitor.message_received > 0:
|
||||||
|
return self._format_new_stats(monitor)
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 降级到旧的统计方式
|
||||||
stats = await self._get_performance_data()
|
stats = await self._get_performance_data()
|
||||||
|
|
||||||
# 格式化统计信息
|
# 格式化统计信息
|
||||||
@@ -99,6 +109,52 @@ class PerformanceMonitor(PluginBase):
|
|||||||
logger.error(f"获取性能统计失败: {e}")
|
logger.error(f"获取性能统计失败: {e}")
|
||||||
return f"❌ 获取性能统计失败: {str(e)}"
|
return f"❌ 获取性能统计失败: {str(e)}"
|
||||||
|
|
||||||
|
def _format_new_stats(self, monitor) -> str:
|
||||||
|
"""格式化新性能监控器的统计信息"""
|
||||||
|
stats = monitor.get_stats()
|
||||||
|
|
||||||
|
# 基础信息
|
||||||
|
msg = f"""📊 系统性能报告
|
||||||
|
|
||||||
|
🕐 运行时间: {stats['uptime_formatted']}
|
||||||
|
|
||||||
|
📨 消息统计:
|
||||||
|
• 收到: {stats['messages']['received']}
|
||||||
|
• 处理: {stats['messages']['processed']}
|
||||||
|
• 失败: {stats['messages']['failed']}
|
||||||
|
• 丢弃: {stats['messages']['dropped']}
|
||||||
|
• 成功率: {stats['messages']['success_rate']}
|
||||||
|
• 处理速率: {stats['messages']['processing_rate']}
|
||||||
|
|
||||||
|
⚡ 处理性能:
|
||||||
|
• 平均耗时: {stats['processing_time']['average_ms']}ms
|
||||||
|
• 最大耗时: {stats['processing_time']['max_ms']}ms
|
||||||
|
• 最小耗时: {stats['processing_time']['min_ms']}ms
|
||||||
|
|
||||||
|
📦 队列状态:
|
||||||
|
• 当前大小: {stats['queue']['current_size']}
|
||||||
|
• 历史最大: {stats['queue']['max_size']}"""
|
||||||
|
|
||||||
|
# 熔断器状态
|
||||||
|
cb = stats.get('circuit_breaker', {})
|
||||||
|
if cb:
|
||||||
|
state_icon = {'closed': '🟢', 'open': '🔴', 'half_open': '🟡'}.get(cb.get('state', ''), '⚪')
|
||||||
|
msg += f"""
|
||||||
|
|
||||||
|
🔌 熔断器:
|
||||||
|
• 状态: {state_icon} {cb.get('state', 'N/A')}
|
||||||
|
• 失败计数: {cb.get('failure_count', 0)}
|
||||||
|
• 恢复时间: {cb.get('current_recovery_time', 0):.0f}s"""
|
||||||
|
|
||||||
|
# 插件耗时排行
|
||||||
|
plugins = stats.get('plugins', [])
|
||||||
|
if plugins:
|
||||||
|
msg += "\n\n🔧 插件耗时排行:"
|
||||||
|
for i, p in enumerate(plugins[:5], 1):
|
||||||
|
msg += f"\n {i}. {p['name']}: {p['avg_time_ms']}ms ({p['calls']}次)"
|
||||||
|
|
||||||
|
return msg
|
||||||
|
|
||||||
async def _get_performance_data(self) -> dict:
|
async def _get_performance_data(self) -> dict:
|
||||||
"""获取性能数据"""
|
"""获取性能数据"""
|
||||||
# 系统资源(简化版本,不依赖psutil)
|
# 系统资源(简化版本,不依赖psutil)
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ import pymysql
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
from utils.plugin_base import PluginBase
|
from utils.plugin_base import PluginBase
|
||||||
from utils.decorators import on_text_message
|
from utils.decorators import on_text_message
|
||||||
|
from utils.redis_cache import get_cache
|
||||||
from WechatHook import WechatHookClient
|
from WechatHook import WechatHookClient
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -151,11 +152,20 @@ class SignInPlugin(PluginBase):
|
|||||||
|
|
||||||
async def get_user_nickname_from_group(self, client: WechatHookClient,
|
async def get_user_nickname_from_group(self, client: WechatHookClient,
|
||||||
group_wxid: str, user_wxid: str) -> str:
|
group_wxid: str, user_wxid: str) -> str:
|
||||||
"""从群聊中获取用户昵称"""
|
"""从群聊中获取用户昵称(优先使用缓存)"""
|
||||||
try:
|
try:
|
||||||
logger.debug(f"尝试获取用户 {user_wxid} 在群 {group_wxid} 中的昵称")
|
# 动态获取缓存实例(由 MessageLogger 初始化)
|
||||||
|
redis_cache = get_cache()
|
||||||
|
|
||||||
# 使用11174 API获取单个用户的详细信息
|
# 1. 先尝试从 Redis 缓存获取
|
||||||
|
if redis_cache and redis_cache.enabled:
|
||||||
|
cached_info = redis_cache.get_user_basic_info(group_wxid, user_wxid)
|
||||||
|
if cached_info and cached_info.get("nickname"):
|
||||||
|
logger.debug(f"[缓存命中] {user_wxid}: {cached_info['nickname']}")
|
||||||
|
return cached_info["nickname"]
|
||||||
|
|
||||||
|
# 2. 缓存未命中,调用 API 获取
|
||||||
|
logger.debug(f"[缓存未命中] 调用API获取用户昵称: {user_wxid}")
|
||||||
user_info = await client.get_user_info_in_chatroom(group_wxid, user_wxid)
|
user_info = await client.get_user_info_in_chatroom(group_wxid, user_wxid)
|
||||||
|
|
||||||
if user_info:
|
if user_info:
|
||||||
@@ -163,7 +173,11 @@ class SignInPlugin(PluginBase):
|
|||||||
nickname = user_info.get("nickName", {}).get("string", "")
|
nickname = user_info.get("nickName", {}).get("string", "")
|
||||||
|
|
||||||
if nickname:
|
if nickname:
|
||||||
logger.success(f"获取到用户昵称: {user_wxid} -> {nickname}")
|
logger.success(f"API获取用户昵称成功: {user_wxid} -> {nickname}")
|
||||||
|
# 3. 将用户信息存入缓存
|
||||||
|
if redis_cache and redis_cache.enabled:
|
||||||
|
redis_cache.set_user_info(group_wxid, user_wxid, user_info)
|
||||||
|
logger.debug(f"[已缓存] {user_wxid}: {nickname}")
|
||||||
return nickname
|
return nickname
|
||||||
else:
|
else:
|
||||||
logger.warning(f"用户 {user_wxid} 的昵称字段为空")
|
logger.warning(f"用户 {user_wxid} 的昵称字段为空")
|
||||||
@@ -770,13 +784,24 @@ class SignInPlugin(PluginBase):
|
|||||||
current_points = updated_user["points"] if updated_user else total_earned
|
current_points = updated_user["points"] if updated_user else total_earned
|
||||||
updated_user["points"] = current_points
|
updated_user["points"] = current_points
|
||||||
|
|
||||||
# 尝试获取用户头像
|
# 尝试获取用户头像(优先使用缓存)
|
||||||
avatar_url = None
|
avatar_url = None
|
||||||
if is_group:
|
if is_group:
|
||||||
try:
|
try:
|
||||||
user_detail = await client.get_user_info_in_chatroom(from_wxid, user_wxid)
|
redis_cache = get_cache()
|
||||||
if user_detail:
|
# 先从缓存获取
|
||||||
avatar_url = user_detail.get("bigHeadImgUrl", "")
|
if redis_cache and redis_cache.enabled:
|
||||||
|
cached_info = redis_cache.get_user_basic_info(from_wxid, user_wxid)
|
||||||
|
if cached_info:
|
||||||
|
avatar_url = cached_info.get("avatar_url", "")
|
||||||
|
# 缓存未命中则调用 API
|
||||||
|
if not avatar_url:
|
||||||
|
user_detail = await client.get_user_info_in_chatroom(from_wxid, user_wxid)
|
||||||
|
if user_detail:
|
||||||
|
avatar_url = user_detail.get("bigHeadImgUrl", "")
|
||||||
|
# 存入缓存
|
||||||
|
if redis_cache and redis_cache.enabled:
|
||||||
|
redis_cache.set_user_info(from_wxid, user_wxid, user_detail)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"获取用户头像失败: {e}")
|
logger.warning(f"获取用户头像失败: {e}")
|
||||||
|
|
||||||
@@ -864,13 +889,24 @@ class SignInPlugin(PluginBase):
|
|||||||
self.update_user_nickname(user_wxid, nickname)
|
self.update_user_nickname(user_wxid, nickname)
|
||||||
user_info["nickname"] = nickname
|
user_info["nickname"] = nickname
|
||||||
|
|
||||||
# 尝试获取用户头像
|
# 尝试获取用户头像(优先使用缓存)
|
||||||
avatar_url = None
|
avatar_url = None
|
||||||
if is_group:
|
if is_group:
|
||||||
try:
|
try:
|
||||||
user_detail = await client.get_user_info_in_chatroom(from_wxid, user_wxid)
|
redis_cache = get_cache()
|
||||||
if user_detail:
|
# 先从缓存获取
|
||||||
avatar_url = user_detail.get("bigHeadImgUrl", "")
|
if redis_cache and redis_cache.enabled:
|
||||||
|
cached_info = redis_cache.get_user_basic_info(from_wxid, user_wxid)
|
||||||
|
if cached_info:
|
||||||
|
avatar_url = cached_info.get("avatar_url", "")
|
||||||
|
# 缓存未命中则调用 API
|
||||||
|
if not avatar_url:
|
||||||
|
user_detail = await client.get_user_info_in_chatroom(from_wxid, user_wxid)
|
||||||
|
if user_detail:
|
||||||
|
avatar_url = user_detail.get("bigHeadImgUrl", "")
|
||||||
|
# 存入缓存
|
||||||
|
if redis_cache and redis_cache.enabled:
|
||||||
|
redis_cache.set_user_info(from_wxid, user_wxid, user_detail)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"获取用户头像失败: {e}")
|
logger.warning(f"获取用户头像失败: {e}")
|
||||||
|
|
||||||
|
|||||||
@@ -3,3 +3,4 @@ APScheduler==3.11.0
|
|||||||
aiohttp==3.9.1
|
aiohttp==3.9.1
|
||||||
Pillow>=10.0.0
|
Pillow>=10.0.0
|
||||||
aiohttp-socks>=0.8.0
|
aiohttp-socks>=0.8.0
|
||||||
|
redis>=5.0.0
|
||||||
|
|||||||
658
utils/bot_utils.py
Normal file
658
utils/bot_utils.py
Normal file
@@ -0,0 +1,658 @@
|
|||||||
|
"""
|
||||||
|
机器人核心工具模块
|
||||||
|
|
||||||
|
包含:
|
||||||
|
- 优先级消息队列
|
||||||
|
- 自适应熔断器
|
||||||
|
- 请求重试机制
|
||||||
|
- 配置热更新
|
||||||
|
- 性能监控
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
import heapq
|
||||||
|
import os
|
||||||
|
import tomllib
|
||||||
|
import functools
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional, Callable, Any, Tuple
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from loguru import logger
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 优先级消息队列 ====================
|
||||||
|
|
||||||
|
# 消息优先级定义
|
||||||
|
class MessagePriority:
|
||||||
|
"""消息优先级常量"""
|
||||||
|
CRITICAL = 100 # 系统消息、登录信息
|
||||||
|
HIGH = 80 # 管理员命令
|
||||||
|
NORMAL = 50 # @机器人的消息
|
||||||
|
LOW = 20 # 普通群消息
|
||||||
|
|
||||||
|
# 高优先级消息类型
|
||||||
|
PRIORITY_MESSAGE_TYPES = {
|
||||||
|
11025: MessagePriority.CRITICAL, # 登录信息
|
||||||
|
11058: MessagePriority.CRITICAL, # 系统消息
|
||||||
|
11098: MessagePriority.HIGH, # 群成员加入
|
||||||
|
11099: MessagePriority.HIGH, # 群成员退出
|
||||||
|
11100: MessagePriority.HIGH, # 群信息变更
|
||||||
|
11056: MessagePriority.HIGH, # 好友请求
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(order=True)
|
||||||
|
class PriorityMessage:
|
||||||
|
"""优先级消息包装"""
|
||||||
|
priority: int
|
||||||
|
timestamp: float = field(compare=False)
|
||||||
|
msg_type: int = field(compare=False)
|
||||||
|
data: dict = field(compare=False)
|
||||||
|
|
||||||
|
def __init__(self, msg_type: int, data: dict, priority: int = None):
|
||||||
|
# 优先级越高,数值越小(因为heapq是最小堆)
|
||||||
|
self.priority = -(priority or PRIORITY_MESSAGE_TYPES.get(msg_type, MessagePriority.NORMAL))
|
||||||
|
self.timestamp = time.time()
|
||||||
|
self.msg_type = msg_type
|
||||||
|
self.data = data
|
||||||
|
|
||||||
|
|
||||||
|
class PriorityMessageQueue:
|
||||||
|
"""优先级消息队列"""
|
||||||
|
|
||||||
|
def __init__(self, maxsize: int = 1000):
|
||||||
|
self.maxsize = maxsize
|
||||||
|
self._heap: List[PriorityMessage] = []
|
||||||
|
self._lock = asyncio.Lock()
|
||||||
|
self._not_empty = asyncio.Event()
|
||||||
|
self._unfinished_tasks = 0
|
||||||
|
self._finished = asyncio.Event()
|
||||||
|
self._finished.set()
|
||||||
|
|
||||||
|
def qsize(self) -> int:
|
||||||
|
"""返回队列大小"""
|
||||||
|
return len(self._heap)
|
||||||
|
|
||||||
|
def empty(self) -> bool:
|
||||||
|
"""队列是否为空"""
|
||||||
|
return len(self._heap) == 0
|
||||||
|
|
||||||
|
def full(self) -> bool:
|
||||||
|
"""队列是否已满"""
|
||||||
|
return len(self._heap) >= self.maxsize
|
||||||
|
|
||||||
|
async def put(self, msg_type: int, data: dict, priority: int = None):
|
||||||
|
"""添加消息到队列"""
|
||||||
|
async with self._lock:
|
||||||
|
msg = PriorityMessage(msg_type, data, priority)
|
||||||
|
heapq.heappush(self._heap, msg)
|
||||||
|
self._unfinished_tasks += 1
|
||||||
|
self._finished.clear()
|
||||||
|
self._not_empty.set()
|
||||||
|
|
||||||
|
async def get(self) -> Tuple[int, dict]:
|
||||||
|
"""获取优先级最高的消息"""
|
||||||
|
while True:
|
||||||
|
async with self._lock:
|
||||||
|
if self._heap:
|
||||||
|
msg = heapq.heappop(self._heap)
|
||||||
|
if not self._heap:
|
||||||
|
self._not_empty.clear()
|
||||||
|
return (msg.msg_type, msg.data)
|
||||||
|
|
||||||
|
# 等待新消息
|
||||||
|
await self._not_empty.wait()
|
||||||
|
|
||||||
|
def get_nowait(self) -> Tuple[int, dict]:
|
||||||
|
"""非阻塞获取消息"""
|
||||||
|
if not self._heap:
|
||||||
|
raise asyncio.QueueEmpty()
|
||||||
|
msg = heapq.heappop(self._heap)
|
||||||
|
if not self._heap:
|
||||||
|
self._not_empty.clear()
|
||||||
|
return (msg.msg_type, msg.data)
|
||||||
|
|
||||||
|
def task_done(self):
|
||||||
|
"""标记任务完成"""
|
||||||
|
self._unfinished_tasks -= 1
|
||||||
|
if self._unfinished_tasks == 0:
|
||||||
|
self._finished.set()
|
||||||
|
|
||||||
|
async def join(self):
|
||||||
|
"""等待所有任务完成"""
|
||||||
|
await self._finished.wait()
|
||||||
|
|
||||||
|
def drop_lowest_priority(self) -> bool:
|
||||||
|
"""丢弃优先级最低的消息"""
|
||||||
|
if not self._heap:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 找到优先级最低的消息(priority值最大,因为是负数所以最小)
|
||||||
|
min_idx = 0
|
||||||
|
for i, msg in enumerate(self._heap):
|
||||||
|
if msg.priority > self._heap[min_idx].priority:
|
||||||
|
min_idx = i
|
||||||
|
|
||||||
|
# 删除该消息
|
||||||
|
self._heap.pop(min_idx)
|
||||||
|
heapq.heapify(self._heap)
|
||||||
|
self._unfinished_tasks -= 1
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 自适应熔断器 ====================
|
||||||
|
|
||||||
|
class AdaptiveCircuitBreaker:
|
||||||
|
"""自适应熔断器"""
|
||||||
|
|
||||||
|
# 熔断器状态
|
||||||
|
STATE_CLOSED = "closed" # 正常状态
|
||||||
|
STATE_OPEN = "open" # 熔断状态
|
||||||
|
STATE_HALF_OPEN = "half_open" # 半开状态(尝试恢复)
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
failure_threshold: int = 5,
|
||||||
|
success_threshold: int = 3,
|
||||||
|
initial_recovery_time: float = 5.0,
|
||||||
|
max_recovery_time: float = 300.0,
|
||||||
|
recovery_multiplier: float = 2.0
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
初始化熔断器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
failure_threshold: 触发熔断的连续失败次数
|
||||||
|
success_threshold: 恢复正常的连续成功次数
|
||||||
|
initial_recovery_time: 初始恢复等待时间(秒)
|
||||||
|
max_recovery_time: 最大恢复等待时间(秒)
|
||||||
|
recovery_multiplier: 恢复时间增长倍数
|
||||||
|
"""
|
||||||
|
self.failure_threshold = failure_threshold
|
||||||
|
self.success_threshold = success_threshold
|
||||||
|
self.initial_recovery_time = initial_recovery_time
|
||||||
|
self.max_recovery_time = max_recovery_time
|
||||||
|
self.recovery_multiplier = recovery_multiplier
|
||||||
|
|
||||||
|
# 状态
|
||||||
|
self.state = self.STATE_CLOSED
|
||||||
|
self.failure_count = 0
|
||||||
|
self.success_count = 0
|
||||||
|
self.last_failure_time = 0
|
||||||
|
self.current_recovery_time = initial_recovery_time
|
||||||
|
|
||||||
|
# 统计
|
||||||
|
self.total_failures = 0
|
||||||
|
self.total_successes = 0
|
||||||
|
self.total_rejections = 0
|
||||||
|
|
||||||
|
def is_open(self) -> bool:
|
||||||
|
"""检查熔断器是否开启(是否应该拒绝请求)"""
|
||||||
|
if self.state == self.STATE_CLOSED:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if self.state == self.STATE_OPEN:
|
||||||
|
# 检查是否可以尝试恢复
|
||||||
|
elapsed = time.time() - self.last_failure_time
|
||||||
|
if elapsed >= self.current_recovery_time:
|
||||||
|
self.state = self.STATE_HALF_OPEN
|
||||||
|
self.success_count = 0
|
||||||
|
logger.info(f"熔断器进入半开状态,尝试恢复(等待了 {elapsed:.1f}s)")
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
# 半开状态,允许请求通过
|
||||||
|
return False
|
||||||
|
|
||||||
|
def record_success(self):
|
||||||
|
"""记录成功"""
|
||||||
|
self.total_successes += 1
|
||||||
|
|
||||||
|
if self.state == self.STATE_HALF_OPEN:
|
||||||
|
self.success_count += 1
|
||||||
|
if self.success_count >= self.success_threshold:
|
||||||
|
# 恢复正常
|
||||||
|
self.state = self.STATE_CLOSED
|
||||||
|
self.failure_count = 0
|
||||||
|
self.success_count = 0
|
||||||
|
self.current_recovery_time = self.initial_recovery_time
|
||||||
|
logger.success(f"熔断器已恢复正常(连续成功 {self.success_threshold} 次)")
|
||||||
|
else:
|
||||||
|
# 正常状态,重置失败计数
|
||||||
|
self.failure_count = 0
|
||||||
|
|
||||||
|
def record_failure(self):
|
||||||
|
"""记录失败"""
|
||||||
|
self.total_failures += 1
|
||||||
|
self.failure_count += 1
|
||||||
|
self.last_failure_time = time.time()
|
||||||
|
|
||||||
|
if self.state == self.STATE_HALF_OPEN:
|
||||||
|
# 半开状态下失败,重新熔断
|
||||||
|
self.state = self.STATE_OPEN
|
||||||
|
self.success_count = 0
|
||||||
|
# 增加恢复时间
|
||||||
|
self.current_recovery_time = min(
|
||||||
|
self.current_recovery_time * self.recovery_multiplier,
|
||||||
|
self.max_recovery_time
|
||||||
|
)
|
||||||
|
logger.warning(f"熔断器重新开启,下次恢复等待 {self.current_recovery_time:.1f}s")
|
||||||
|
|
||||||
|
elif self.state == self.STATE_CLOSED:
|
||||||
|
if self.failure_count >= self.failure_threshold:
|
||||||
|
self.state = self.STATE_OPEN
|
||||||
|
logger.warning(f"熔断器开启,连续失败 {self.failure_count} 次")
|
||||||
|
|
||||||
|
def record_rejection(self):
|
||||||
|
"""记录被拒绝的请求"""
|
||||||
|
self.total_rejections += 1
|
||||||
|
|
||||||
|
def get_stats(self) -> dict:
|
||||||
|
"""获取统计信息"""
|
||||||
|
return {
|
||||||
|
"state": self.state,
|
||||||
|
"failure_count": self.failure_count,
|
||||||
|
"success_count": self.success_count,
|
||||||
|
"current_recovery_time": self.current_recovery_time,
|
||||||
|
"total_failures": self.total_failures,
|
||||||
|
"total_successes": self.total_successes,
|
||||||
|
"total_rejections": self.total_rejections
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 请求重试机制 ====================
|
||||||
|
|
||||||
|
class RetryConfig:
|
||||||
|
"""重试配置"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
max_retries: int = 3,
|
||||||
|
initial_delay: float = 1.0,
|
||||||
|
max_delay: float = 30.0,
|
||||||
|
exponential_base: float = 2.0,
|
||||||
|
retryable_exceptions: tuple = (
|
||||||
|
aiohttp.ClientError,
|
||||||
|
asyncio.TimeoutError,
|
||||||
|
ConnectionError,
|
||||||
|
)
|
||||||
|
):
|
||||||
|
self.max_retries = max_retries
|
||||||
|
self.initial_delay = initial_delay
|
||||||
|
self.max_delay = max_delay
|
||||||
|
self.exponential_base = exponential_base
|
||||||
|
self.retryable_exceptions = retryable_exceptions
|
||||||
|
|
||||||
|
|
||||||
|
def retry_async(config: RetryConfig = None):
|
||||||
|
"""异步重试装饰器"""
|
||||||
|
if config is None:
|
||||||
|
config = RetryConfig()
|
||||||
|
|
||||||
|
def decorator(func):
|
||||||
|
@functools.wraps(func)
|
||||||
|
async def wrapper(*args, **kwargs):
|
||||||
|
last_exception = None
|
||||||
|
|
||||||
|
for attempt in range(config.max_retries + 1):
|
||||||
|
try:
|
||||||
|
return await func(*args, **kwargs)
|
||||||
|
except config.retryable_exceptions as e:
|
||||||
|
last_exception = e
|
||||||
|
|
||||||
|
if attempt == config.max_retries:
|
||||||
|
logger.error(f"重试 {config.max_retries} 次后仍然失败: {func.__name__}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
# 计算延迟时间(指数退避)
|
||||||
|
delay = min(
|
||||||
|
config.initial_delay * (config.exponential_base ** attempt),
|
||||||
|
config.max_delay
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.warning(
|
||||||
|
f"请求失败,{delay:.1f}s 后重试 "
|
||||||
|
f"(第 {attempt + 1}/{config.max_retries} 次): {e}"
|
||||||
|
)
|
||||||
|
await asyncio.sleep(delay)
|
||||||
|
|
||||||
|
raise last_exception
|
||||||
|
return wrapper
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
async def request_with_retry(
|
||||||
|
session: aiohttp.ClientSession,
|
||||||
|
method: str,
|
||||||
|
url: str,
|
||||||
|
max_retries: int = 3,
|
||||||
|
**kwargs
|
||||||
|
) -> aiohttp.ClientResponse:
|
||||||
|
"""带重试的 HTTP 请求"""
|
||||||
|
config = RetryConfig(max_retries=max_retries)
|
||||||
|
last_exception = None
|
||||||
|
|
||||||
|
for attempt in range(config.max_retries + 1):
|
||||||
|
try:
|
||||||
|
response = await session.request(method, url, **kwargs)
|
||||||
|
return response
|
||||||
|
except config.retryable_exceptions as e:
|
||||||
|
last_exception = e
|
||||||
|
|
||||||
|
if attempt == config.max_retries:
|
||||||
|
raise
|
||||||
|
|
||||||
|
delay = min(
|
||||||
|
config.initial_delay * (config.exponential_base ** attempt),
|
||||||
|
config.max_delay
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.warning(f"HTTP 请求失败,{delay:.1f}s 后重试: {e}")
|
||||||
|
await asyncio.sleep(delay)
|
||||||
|
|
||||||
|
raise last_exception
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 配置热更新 ====================
|
||||||
|
|
||||||
|
class ConfigWatcher:
|
||||||
|
"""配置文件监听器"""
|
||||||
|
|
||||||
|
def __init__(self, config_path: str, check_interval: float = 5.0):
|
||||||
|
"""
|
||||||
|
初始化配置监听器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config_path: 配置文件路径
|
||||||
|
check_interval: 检查间隔(秒)
|
||||||
|
"""
|
||||||
|
self.config_path = Path(config_path)
|
||||||
|
self.check_interval = check_interval
|
||||||
|
self.last_mtime = 0
|
||||||
|
self.callbacks: List[Callable[[dict], Any]] = []
|
||||||
|
self.current_config: dict = {}
|
||||||
|
self._running = False
|
||||||
|
self._task: Optional[asyncio.Task] = None
|
||||||
|
|
||||||
|
def register_callback(self, callback: Callable[[dict], Any]):
|
||||||
|
"""注册配置更新回调"""
|
||||||
|
self.callbacks.append(callback)
|
||||||
|
|
||||||
|
def unregister_callback(self, callback: Callable[[dict], Any]):
|
||||||
|
"""取消注册回调"""
|
||||||
|
if callback in self.callbacks:
|
||||||
|
self.callbacks.remove(callback)
|
||||||
|
|
||||||
|
def _load_config(self) -> dict:
|
||||||
|
"""加载配置文件"""
|
||||||
|
try:
|
||||||
|
with open(self.config_path, "rb") as f:
|
||||||
|
return tomllib.load(f)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"加载配置文件失败: {e}")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def get_config(self) -> dict:
|
||||||
|
"""获取当前配置"""
|
||||||
|
return self.current_config
|
||||||
|
|
||||||
|
async def start(self):
|
||||||
|
"""启动配置监听"""
|
||||||
|
if self._running:
|
||||||
|
return
|
||||||
|
|
||||||
|
self._running = True
|
||||||
|
|
||||||
|
# 初始加载
|
||||||
|
if self.config_path.exists():
|
||||||
|
self.last_mtime = os.path.getmtime(self.config_path)
|
||||||
|
self.current_config = self._load_config()
|
||||||
|
|
||||||
|
self._task = asyncio.create_task(self._watch_loop())
|
||||||
|
logger.info(f"配置监听器已启动: {self.config_path}")
|
||||||
|
|
||||||
|
async def stop(self):
|
||||||
|
"""停止配置监听"""
|
||||||
|
self._running = False
|
||||||
|
if self._task:
|
||||||
|
self._task.cancel()
|
||||||
|
try:
|
||||||
|
await self._task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
logger.info("配置监听器已停止")
|
||||||
|
|
||||||
|
async def _watch_loop(self):
|
||||||
|
"""监听循环"""
|
||||||
|
while self._running:
|
||||||
|
try:
|
||||||
|
await asyncio.sleep(self.check_interval)
|
||||||
|
|
||||||
|
if not self.config_path.exists():
|
||||||
|
continue
|
||||||
|
|
||||||
|
mtime = os.path.getmtime(self.config_path)
|
||||||
|
if mtime > self.last_mtime:
|
||||||
|
logger.info("检测到配置文件变化,重新加载...")
|
||||||
|
new_config = self._load_config()
|
||||||
|
|
||||||
|
if new_config:
|
||||||
|
old_config = self.current_config
|
||||||
|
self.current_config = new_config
|
||||||
|
self.last_mtime = mtime
|
||||||
|
|
||||||
|
# 通知所有回调
|
||||||
|
for callback in self.callbacks:
|
||||||
|
try:
|
||||||
|
if asyncio.iscoroutinefunction(callback):
|
||||||
|
await callback(new_config)
|
||||||
|
else:
|
||||||
|
callback(new_config)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"配置更新回调执行失败: {e}")
|
||||||
|
|
||||||
|
logger.success("配置已热更新")
|
||||||
|
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"配置监听异常: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 性能监控 ====================
|
||||||
|
|
||||||
|
class PerformanceMonitor:
|
||||||
|
"""性能监控器"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.start_time = time.time()
|
||||||
|
|
||||||
|
# 消息统计
|
||||||
|
self.message_received = 0
|
||||||
|
self.message_processed = 0
|
||||||
|
self.message_failed = 0
|
||||||
|
self.message_dropped = 0
|
||||||
|
|
||||||
|
# 处理时间统计
|
||||||
|
self.processing_times: List[float] = []
|
||||||
|
self.max_processing_times = 1000 # 保留最近1000条记录
|
||||||
|
|
||||||
|
# 插件统计
|
||||||
|
self.plugin_stats: Dict[str, dict] = {}
|
||||||
|
|
||||||
|
# 队列统计
|
||||||
|
self.queue_size_history: List[Tuple[float, int]] = []
|
||||||
|
self.max_queue_history = 100
|
||||||
|
|
||||||
|
# 熔断器统计
|
||||||
|
self.circuit_breaker_stats: dict = {}
|
||||||
|
|
||||||
|
def record_message_received(self):
|
||||||
|
"""记录收到消息"""
|
||||||
|
self.message_received += 1
|
||||||
|
|
||||||
|
def record_message_processed(self, processing_time: float):
|
||||||
|
"""记录消息处理完成"""
|
||||||
|
self.message_processed += 1
|
||||||
|
self.processing_times.append(processing_time)
|
||||||
|
|
||||||
|
# 限制历史记录数量
|
||||||
|
if len(self.processing_times) > self.max_processing_times:
|
||||||
|
self.processing_times = self.processing_times[-self.max_processing_times:]
|
||||||
|
|
||||||
|
def record_message_failed(self):
|
||||||
|
"""记录消息处理失败"""
|
||||||
|
self.message_failed += 1
|
||||||
|
|
||||||
|
def record_message_dropped(self):
|
||||||
|
"""记录消息被丢弃"""
|
||||||
|
self.message_dropped += 1
|
||||||
|
|
||||||
|
def record_queue_size(self, size: int):
|
||||||
|
"""记录队列大小"""
|
||||||
|
self.queue_size_history.append((time.time(), size))
|
||||||
|
|
||||||
|
if len(self.queue_size_history) > self.max_queue_history:
|
||||||
|
self.queue_size_history = self.queue_size_history[-self.max_queue_history:]
|
||||||
|
|
||||||
|
def record_plugin_execution(self, plugin_name: str, execution_time: float, success: bool):
|
||||||
|
"""记录插件执行"""
|
||||||
|
if plugin_name not in self.plugin_stats:
|
||||||
|
self.plugin_stats[plugin_name] = {
|
||||||
|
"total_calls": 0,
|
||||||
|
"success_calls": 0,
|
||||||
|
"failed_calls": 0,
|
||||||
|
"total_time": 0,
|
||||||
|
"max_time": 0,
|
||||||
|
"recent_times": []
|
||||||
|
}
|
||||||
|
|
||||||
|
stats = self.plugin_stats[plugin_name]
|
||||||
|
stats["total_calls"] += 1
|
||||||
|
stats["total_time"] += execution_time
|
||||||
|
stats["max_time"] = max(stats["max_time"], execution_time)
|
||||||
|
stats["recent_times"].append(execution_time)
|
||||||
|
|
||||||
|
if len(stats["recent_times"]) > 100:
|
||||||
|
stats["recent_times"] = stats["recent_times"][-100:]
|
||||||
|
|
||||||
|
if success:
|
||||||
|
stats["success_calls"] += 1
|
||||||
|
else:
|
||||||
|
stats["failed_calls"] += 1
|
||||||
|
|
||||||
|
def update_circuit_breaker_stats(self, stats: dict):
|
||||||
|
"""更新熔断器统计"""
|
||||||
|
self.circuit_breaker_stats = stats
|
||||||
|
|
||||||
|
def get_stats(self) -> dict:
|
||||||
|
"""获取完整统计信息"""
|
||||||
|
uptime = time.time() - self.start_time
|
||||||
|
|
||||||
|
# 计算平均处理时间
|
||||||
|
avg_processing_time = 0
|
||||||
|
if self.processing_times:
|
||||||
|
avg_processing_time = sum(self.processing_times) / len(self.processing_times)
|
||||||
|
|
||||||
|
# 计算处理速率
|
||||||
|
processing_rate = self.message_processed / uptime if uptime > 0 else 0
|
||||||
|
|
||||||
|
# 计算成功率
|
||||||
|
total = self.message_processed + self.message_failed
|
||||||
|
success_rate = self.message_processed / total if total > 0 else 1.0
|
||||||
|
|
||||||
|
return {
|
||||||
|
"uptime_seconds": uptime,
|
||||||
|
"uptime_formatted": self._format_uptime(uptime),
|
||||||
|
"messages": {
|
||||||
|
"received": self.message_received,
|
||||||
|
"processed": self.message_processed,
|
||||||
|
"failed": self.message_failed,
|
||||||
|
"dropped": self.message_dropped,
|
||||||
|
"success_rate": f"{success_rate * 100:.1f}%",
|
||||||
|
"processing_rate": f"{processing_rate:.2f}/s"
|
||||||
|
},
|
||||||
|
"processing_time": {
|
||||||
|
"average_ms": f"{avg_processing_time * 1000:.1f}",
|
||||||
|
"max_ms": f"{max(self.processing_times) * 1000:.1f}" if self.processing_times else "0",
|
||||||
|
"min_ms": f"{min(self.processing_times) * 1000:.1f}" if self.processing_times else "0"
|
||||||
|
},
|
||||||
|
"queue": {
|
||||||
|
"current_size": self.queue_size_history[-1][1] if self.queue_size_history else 0,
|
||||||
|
"max_size": max(s[1] for s in self.queue_size_history) if self.queue_size_history else 0
|
||||||
|
},
|
||||||
|
"circuit_breaker": self.circuit_breaker_stats,
|
||||||
|
"plugins": self._get_plugin_summary()
|
||||||
|
}
|
||||||
|
|
||||||
|
def _get_plugin_summary(self) -> List[dict]:
|
||||||
|
"""获取插件统计摘要"""
|
||||||
|
summary = []
|
||||||
|
for name, stats in self.plugin_stats.items():
|
||||||
|
avg_time = stats["total_time"] / stats["total_calls"] if stats["total_calls"] > 0 else 0
|
||||||
|
summary.append({
|
||||||
|
"name": name,
|
||||||
|
"calls": stats["total_calls"],
|
||||||
|
"success_rate": f"{stats['success_calls'] / stats['total_calls'] * 100:.1f}%" if stats["total_calls"] > 0 else "N/A",
|
||||||
|
"avg_time_ms": f"{avg_time * 1000:.1f}",
|
||||||
|
"max_time_ms": f"{stats['max_time'] * 1000:.1f}"
|
||||||
|
})
|
||||||
|
|
||||||
|
# 按平均时间排序
|
||||||
|
summary.sort(key=lambda x: float(x["avg_time_ms"]), reverse=True)
|
||||||
|
return summary
|
||||||
|
|
||||||
|
def _format_uptime(self, seconds: float) -> str:
|
||||||
|
"""格式化运行时间"""
|
||||||
|
days = int(seconds // 86400)
|
||||||
|
hours = int((seconds % 86400) // 3600)
|
||||||
|
minutes = int((seconds % 3600) // 60)
|
||||||
|
secs = int(seconds % 60)
|
||||||
|
|
||||||
|
if days > 0:
|
||||||
|
return f"{days}天 {hours}小时 {minutes}分钟"
|
||||||
|
elif hours > 0:
|
||||||
|
return f"{hours}小时 {minutes}分钟"
|
||||||
|
elif minutes > 0:
|
||||||
|
return f"{minutes}分钟 {secs}秒"
|
||||||
|
else:
|
||||||
|
return f"{secs}秒"
|
||||||
|
|
||||||
|
def print_stats(self):
|
||||||
|
"""打印统计信息到日志"""
|
||||||
|
stats = self.get_stats()
|
||||||
|
|
||||||
|
logger.info("=" * 50)
|
||||||
|
logger.info("性能监控报告")
|
||||||
|
logger.info("=" * 50)
|
||||||
|
logger.info(f"运行时间: {stats['uptime_formatted']}")
|
||||||
|
logger.info(f"消息统计: 收到 {stats['messages']['received']}, "
|
||||||
|
f"处理 {stats['messages']['processed']}, "
|
||||||
|
f"失败 {stats['messages']['failed']}, "
|
||||||
|
f"丢弃 {stats['messages']['dropped']}")
|
||||||
|
logger.info(f"成功率: {stats['messages']['success_rate']}, "
|
||||||
|
f"处理速率: {stats['messages']['processing_rate']}")
|
||||||
|
logger.info(f"平均处理时间: {stats['processing_time']['average_ms']}ms")
|
||||||
|
logger.info(f"队列大小: {stats['queue']['current_size']}")
|
||||||
|
logger.info(f"熔断器状态: {stats['circuit_breaker'].get('state', 'N/A')}")
|
||||||
|
|
||||||
|
if stats['plugins']:
|
||||||
|
logger.info("插件耗时排行:")
|
||||||
|
for i, p in enumerate(stats['plugins'][:5], 1):
|
||||||
|
logger.info(f" {i}. {p['name']}: {p['avg_time_ms']}ms (调用 {p['calls']} 次)")
|
||||||
|
|
||||||
|
logger.info("=" * 50)
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 全局实例 ====================
|
||||||
|
|
||||||
|
# 性能监控器单例
|
||||||
|
_performance_monitor: Optional[PerformanceMonitor] = None
|
||||||
|
|
||||||
|
def get_performance_monitor() -> PerformanceMonitor:
|
||||||
|
"""获取性能监控器实例"""
|
||||||
|
global _performance_monitor
|
||||||
|
if _performance_monitor is None:
|
||||||
|
_performance_monitor = PerformanceMonitor()
|
||||||
|
return _performance_monitor
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
from abc import ABC
|
from abc import ABC
|
||||||
|
from typing import List
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
@@ -13,6 +14,14 @@ class PluginBase(ABC):
|
|||||||
author: str = "未知"
|
author: str = "未知"
|
||||||
version: str = "1.0.0"
|
version: str = "1.0.0"
|
||||||
|
|
||||||
|
# 插件依赖(填写依赖的插件类名列表)
|
||||||
|
# 例如: dependencies = ["MessageLogger", "AIChat"]
|
||||||
|
dependencies: List[str] = []
|
||||||
|
|
||||||
|
# 加载优先级(数值越大越先加载,默认50)
|
||||||
|
# 基础插件设置高优先级,依赖其他插件的设置低优先级
|
||||||
|
load_priority: int = 50
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.enabled = False
|
self.enabled = False
|
||||||
self._scheduled_jobs = set()
|
self._scheduled_jobs = set()
|
||||||
|
|||||||
@@ -117,24 +117,107 @@ class PluginManager(metaclass=Singleton):
|
|||||||
if not found:
|
if not found:
|
||||||
logger.warning(f"未找到插件类 {plugin_name}")
|
logger.warning(f"未找到插件类 {plugin_name}")
|
||||||
|
|
||||||
|
def _resolve_load_order(self, plugin_classes: List[Type[PluginBase]]) -> List[Type[PluginBase]]:
|
||||||
|
"""
|
||||||
|
解析插件加载顺序(拓扑排序 + 优先级排序)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plugin_classes: 插件类列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
按依赖关系和优先级排序后的插件类列表
|
||||||
|
"""
|
||||||
|
# 构建插件名到类的映射
|
||||||
|
name_to_class = {cls.__name__: cls for cls in plugin_classes}
|
||||||
|
|
||||||
|
# 构建依赖图
|
||||||
|
dependencies = {}
|
||||||
|
for cls in plugin_classes:
|
||||||
|
deps = getattr(cls, 'dependencies', [])
|
||||||
|
dependencies[cls.__name__] = [d for d in deps if d in name_to_class]
|
||||||
|
|
||||||
|
# 拓扑排序
|
||||||
|
sorted_names = []
|
||||||
|
visited = set()
|
||||||
|
temp_visited = set()
|
||||||
|
|
||||||
|
def visit(name: str):
|
||||||
|
if name in temp_visited:
|
||||||
|
# 检测到循环依赖
|
||||||
|
logger.warning(f"检测到循环依赖: {name}")
|
||||||
|
return
|
||||||
|
if name in visited:
|
||||||
|
return
|
||||||
|
|
||||||
|
temp_visited.add(name)
|
||||||
|
|
||||||
|
# 先访问依赖
|
||||||
|
for dep in dependencies.get(name, []):
|
||||||
|
visit(dep)
|
||||||
|
|
||||||
|
temp_visited.remove(name)
|
||||||
|
visited.add(name)
|
||||||
|
sorted_names.append(name)
|
||||||
|
|
||||||
|
# 按优先级排序后再进行拓扑排序
|
||||||
|
priority_sorted = sorted(
|
||||||
|
plugin_classes,
|
||||||
|
key=lambda cls: getattr(cls, 'load_priority', 50),
|
||||||
|
reverse=True
|
||||||
|
)
|
||||||
|
|
||||||
|
for cls in priority_sorted:
|
||||||
|
if cls.__name__ not in visited:
|
||||||
|
visit(cls.__name__)
|
||||||
|
|
||||||
|
# 返回排序后的类列表
|
||||||
|
return [name_to_class[name] for name in sorted_names if name in name_to_class]
|
||||||
|
|
||||||
async def load_plugins(self, load_disabled: bool = True) -> Union[List[str], bool]:
|
async def load_plugins(self, load_disabled: bool = True) -> Union[List[str], bool]:
|
||||||
|
"""加载所有插件(按依赖顺序)"""
|
||||||
loaded_plugins = []
|
loaded_plugins = []
|
||||||
|
|
||||||
|
# 第一步:收集所有插件类
|
||||||
|
all_plugin_classes = []
|
||||||
|
plugin_disabled_map = {}
|
||||||
|
|
||||||
for dirname in os.listdir("plugins"):
|
for dirname in os.listdir("plugins"):
|
||||||
if os.path.isdir(f"plugins/{dirname}") and os.path.exists(f"plugins/{dirname}/main.py"):
|
if os.path.isdir(f"plugins/{dirname}") and os.path.exists(f"plugins/{dirname}/main.py"):
|
||||||
try:
|
try:
|
||||||
module = importlib.import_module(f"plugins.{dirname}.main")
|
module = importlib.import_module(f"plugins.{dirname}.main")
|
||||||
for name, obj in inspect.getmembers(module):
|
for name, obj in inspect.getmembers(module):
|
||||||
if inspect.isclass(obj) and issubclass(obj, PluginBase) and obj != PluginBase:
|
if inspect.isclass(obj) and issubclass(obj, PluginBase) and obj != PluginBase:
|
||||||
|
all_plugin_classes.append(obj)
|
||||||
|
|
||||||
|
# 记录是否禁用
|
||||||
is_disabled = False
|
is_disabled = False
|
||||||
if not load_disabled:
|
if not load_disabled:
|
||||||
is_disabled = obj.__name__ in self.excluded_plugins or dirname in self.excluded_plugins
|
is_disabled = obj.__name__ in self.excluded_plugins or dirname in self.excluded_plugins
|
||||||
|
plugin_disabled_map[obj.__name__] = is_disabled
|
||||||
if await self._load_plugin_class(obj, is_disabled=is_disabled):
|
|
||||||
loaded_plugins.append(obj.__name__)
|
|
||||||
except:
|
except:
|
||||||
logger.error(f"加载 {dirname} 时发生错误: {traceback.format_exc()}")
|
logger.error(f"加载 {dirname} 时发生错误: {traceback.format_exc()}")
|
||||||
|
|
||||||
|
# 第二步:按依赖顺序排序
|
||||||
|
sorted_classes = self._resolve_load_order(all_plugin_classes)
|
||||||
|
logger.info(f"插件加载顺序: {[cls.__name__ for cls in sorted_classes]}")
|
||||||
|
|
||||||
|
# 第三步:按顺序加载插件
|
||||||
|
for plugin_class in sorted_classes:
|
||||||
|
plugin_name = plugin_class.__name__
|
||||||
|
is_disabled = plugin_disabled_map.get(plugin_name, False)
|
||||||
|
|
||||||
|
# 检查依赖是否已加载
|
||||||
|
deps = getattr(plugin_class, 'dependencies', [])
|
||||||
|
deps_satisfied = all(dep in self.plugins for dep in deps)
|
||||||
|
|
||||||
|
if not deps_satisfied and not is_disabled:
|
||||||
|
missing_deps = [dep for dep in deps if dep not in self.plugins]
|
||||||
|
logger.warning(f"插件 {plugin_name} 的依赖未满足: {missing_deps},跳过加载")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if await self._load_plugin_class(plugin_class, is_disabled=is_disabled):
|
||||||
|
loaded_plugins.append(plugin_name)
|
||||||
|
|
||||||
return loaded_plugins
|
return loaded_plugins
|
||||||
|
|
||||||
async def unload_plugin(self, plugin_name: str) -> bool:
|
async def unload_plugin(self, plugin_name: str) -> bool:
|
||||||
|
|||||||
744
utils/redis_cache.py
Normal file
744
utils/redis_cache.py
Normal file
@@ -0,0 +1,744 @@
|
|||||||
|
"""
|
||||||
|
Redis 缓存工具类
|
||||||
|
|
||||||
|
用于缓存用户信息等数据,减少 API 调用
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Optional, Dict, Any
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
try:
|
||||||
|
import redis
|
||||||
|
REDIS_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
REDIS_AVAILABLE = False
|
||||||
|
logger.warning("redis 库未安装,缓存功能将不可用")
|
||||||
|
|
||||||
|
|
||||||
|
class RedisCache:
|
||||||
|
"""Redis 缓存管理器"""
|
||||||
|
|
||||||
|
_instance = None
|
||||||
|
|
||||||
|
def __new__(cls, *args, **kwargs):
|
||||||
|
"""单例模式"""
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = super().__new__(cls)
|
||||||
|
cls._instance._initialized = False
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
def __init__(self, config: Dict = None):
|
||||||
|
"""
|
||||||
|
初始化 Redis 连接
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Redis 配置字典,包含 host, port, password, db 等
|
||||||
|
"""
|
||||||
|
if self._initialized:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.client = None
|
||||||
|
self.enabled = False
|
||||||
|
self.default_ttl = 3600 # 默认过期时间 1 小时
|
||||||
|
|
||||||
|
if not REDIS_AVAILABLE:
|
||||||
|
logger.warning("Redis 库未安装,缓存功能禁用")
|
||||||
|
self._initialized = True
|
||||||
|
return
|
||||||
|
|
||||||
|
if config:
|
||||||
|
self.connect(config)
|
||||||
|
|
||||||
|
self._initialized = True
|
||||||
|
|
||||||
|
def connect(self, config: Dict) -> bool:
|
||||||
|
"""
|
||||||
|
连接 Redis
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Redis 配置
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否连接成功
|
||||||
|
"""
|
||||||
|
if not REDIS_AVAILABLE:
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.client = redis.Redis(
|
||||||
|
host=config.get("host", "localhost"),
|
||||||
|
port=config.get("port", 6379),
|
||||||
|
password=config.get("password", None),
|
||||||
|
db=config.get("db", 0),
|
||||||
|
decode_responses=True,
|
||||||
|
socket_timeout=5,
|
||||||
|
socket_connect_timeout=5
|
||||||
|
)
|
||||||
|
|
||||||
|
# 测试连接
|
||||||
|
self.client.ping()
|
||||||
|
self.enabled = True
|
||||||
|
self.default_ttl = config.get("ttl", 3600)
|
||||||
|
|
||||||
|
logger.success(f"Redis 连接成功: {config.get('host')}:{config.get('port')}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Redis 连接失败: {e}")
|
||||||
|
self.client = None
|
||||||
|
self.enabled = False
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _make_key(self, prefix: str, *args) -> str:
|
||||||
|
"""
|
||||||
|
生成缓存 key
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prefix: key 前缀
|
||||||
|
*args: key 组成部分
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
完整的 key
|
||||||
|
"""
|
||||||
|
parts = [prefix] + [str(arg) for arg in args]
|
||||||
|
return ":".join(parts)
|
||||||
|
|
||||||
|
def get(self, key: str) -> Optional[Any]:
|
||||||
|
"""
|
||||||
|
获取缓存值
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: 缓存 key
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
缓存的值,不存在返回 None
|
||||||
|
"""
|
||||||
|
if not self.enabled or not self.client:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
value = self.client.get(key)
|
||||||
|
if value:
|
||||||
|
return json.loads(value)
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Redis GET 失败: {key}, {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def set(self, key: str, value: Any, ttl: int = None) -> bool:
|
||||||
|
"""
|
||||||
|
设置缓存值
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: 缓存 key
|
||||||
|
value: 要缓存的值
|
||||||
|
ttl: 过期时间(秒),默认使用 default_ttl
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否设置成功
|
||||||
|
"""
|
||||||
|
if not self.enabled or not self.client:
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
ttl = ttl or self.default_ttl
|
||||||
|
self.client.setex(key, ttl, json.dumps(value, ensure_ascii=False))
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Redis SET 失败: {key}, {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def delete(self, key: str) -> bool:
|
||||||
|
"""
|
||||||
|
删除缓存
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: 缓存 key
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否删除成功
|
||||||
|
"""
|
||||||
|
if not self.enabled or not self.client:
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.client.delete(key)
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Redis DELETE 失败: {key}, {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def delete_pattern(self, pattern: str) -> int:
|
||||||
|
"""
|
||||||
|
删除匹配模式的所有 key
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pattern: key 模式,如 "user_info:*"
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
删除的 key 数量
|
||||||
|
"""
|
||||||
|
if not self.enabled or not self.client:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
keys = self.client.keys(pattern)
|
||||||
|
if keys:
|
||||||
|
return self.client.delete(*keys)
|
||||||
|
return 0
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Redis DELETE PATTERN 失败: {pattern}, {e}")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
# ==================== 用户信息缓存专用方法 ====================
|
||||||
|
|
||||||
|
def get_user_info(self, chatroom_id: str, user_wxid: str) -> Optional[Dict]:
|
||||||
|
"""
|
||||||
|
获取缓存的用户信息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chatroom_id: 群聊 ID
|
||||||
|
user_wxid: 用户 wxid
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
用户信息字典,不存在返回 None
|
||||||
|
"""
|
||||||
|
key = self._make_key("user_info", chatroom_id, user_wxid)
|
||||||
|
return self.get(key)
|
||||||
|
|
||||||
|
def set_user_info(self, chatroom_id: str, user_wxid: str, user_info: Dict, ttl: int = None) -> bool:
|
||||||
|
"""
|
||||||
|
缓存用户信息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chatroom_id: 群聊 ID
|
||||||
|
user_wxid: 用户 wxid
|
||||||
|
user_info: 用户信息字典
|
||||||
|
ttl: 过期时间(秒)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否缓存成功
|
||||||
|
"""
|
||||||
|
key = self._make_key("user_info", chatroom_id, user_wxid)
|
||||||
|
return self.set(key, user_info, ttl)
|
||||||
|
|
||||||
|
def get_user_basic_info(self, chatroom_id: str, user_wxid: str) -> Optional[Dict]:
|
||||||
|
"""
|
||||||
|
获取缓存的用户基本信息(昵称和头像)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chatroom_id: 群聊 ID
|
||||||
|
user_wxid: 用户 wxid
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
包含 nickname 和 avatar_url 的字典
|
||||||
|
"""
|
||||||
|
user_info = self.get_user_info(chatroom_id, user_wxid)
|
||||||
|
if user_info:
|
||||||
|
# 提取基本信息
|
||||||
|
nickname = ""
|
||||||
|
if isinstance(user_info.get("nickName"), dict):
|
||||||
|
nickname = user_info.get("nickName", {}).get("string", "")
|
||||||
|
else:
|
||||||
|
nickname = user_info.get("nickName", "")
|
||||||
|
|
||||||
|
avatar_url = user_info.get("bigHeadImgUrl", "")
|
||||||
|
|
||||||
|
if nickname or avatar_url:
|
||||||
|
return {
|
||||||
|
"nickname": nickname,
|
||||||
|
"avatar_url": avatar_url
|
||||||
|
}
|
||||||
|
return None
|
||||||
|
|
||||||
|
def clear_user_cache(self, chatroom_id: str = None, user_wxid: str = None) -> int:
|
||||||
|
"""
|
||||||
|
清除用户信息缓存
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chatroom_id: 群聊 ID,为空则清除所有群
|
||||||
|
user_wxid: 用户 wxid,为空则清除该群所有用户
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
清除的缓存数量
|
||||||
|
"""
|
||||||
|
if chatroom_id and user_wxid:
|
||||||
|
key = self._make_key("user_info", chatroom_id, user_wxid)
|
||||||
|
return 1 if self.delete(key) else 0
|
||||||
|
elif chatroom_id:
|
||||||
|
pattern = self._make_key("user_info", chatroom_id, "*")
|
||||||
|
return self.delete_pattern(pattern)
|
||||||
|
else:
|
||||||
|
return self.delete_pattern("user_info:*")
|
||||||
|
|
||||||
|
def get_cache_stats(self) -> Dict:
|
||||||
|
"""
|
||||||
|
获取缓存统计信息
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
统计信息字典
|
||||||
|
"""
|
||||||
|
if not self.enabled or not self.client:
|
||||||
|
return {"enabled": False}
|
||||||
|
|
||||||
|
try:
|
||||||
|
info = self.client.info("memory")
|
||||||
|
user_keys = len(self.client.keys("user_info:*"))
|
||||||
|
chat_keys = len(self.client.keys("chat_history:*"))
|
||||||
|
|
||||||
|
return {
|
||||||
|
"enabled": True,
|
||||||
|
"used_memory": info.get("used_memory_human", "unknown"),
|
||||||
|
"user_info_count": user_keys,
|
||||||
|
"chat_history_count": chat_keys
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取缓存统计失败: {e}")
|
||||||
|
return {"enabled": True, "error": str(e)}
|
||||||
|
|
||||||
|
# ==================== 对话历史缓存专用方法 ====================
|
||||||
|
|
||||||
|
def get_chat_history(self, chat_id: str, max_messages: int = 100) -> list:
|
||||||
|
"""
|
||||||
|
获取对话历史
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chat_id: 会话ID(私聊为用户wxid,群聊为 群ID:用户ID)
|
||||||
|
max_messages: 最大返回消息数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
消息列表
|
||||||
|
"""
|
||||||
|
if not self.enabled or not self.client:
|
||||||
|
return []
|
||||||
|
|
||||||
|
try:
|
||||||
|
key = self._make_key("chat_history", chat_id)
|
||||||
|
# 使用 LRANGE 获取最近的消息(列表尾部是最新的)
|
||||||
|
data = self.client.lrange(key, -max_messages, -1)
|
||||||
|
return [json.loads(item) for item in data]
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取对话历史失败: {chat_id}, {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def add_chat_message(self, chat_id: str, role: str, content, ttl: int = 86400) -> bool:
|
||||||
|
"""
|
||||||
|
添加消息到对话历史
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chat_id: 会话ID
|
||||||
|
role: 角色 (user/assistant)
|
||||||
|
content: 消息内容(字符串或列表)
|
||||||
|
ttl: 过期时间(秒),默认24小时
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否添加成功
|
||||||
|
"""
|
||||||
|
if not self.enabled or not self.client:
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
key = self._make_key("chat_history", chat_id)
|
||||||
|
message = {"role": role, "content": content}
|
||||||
|
self.client.rpush(key, json.dumps(message, ensure_ascii=False))
|
||||||
|
self.client.expire(key, ttl)
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"添加对话消息失败: {chat_id}, {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def trim_chat_history(self, chat_id: str, max_messages: int = 100) -> bool:
|
||||||
|
"""
|
||||||
|
裁剪对话历史,保留最近的N条消息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chat_id: 会话ID
|
||||||
|
max_messages: 保留的最大消息数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否成功
|
||||||
|
"""
|
||||||
|
if not self.enabled or not self.client:
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
key = self._make_key("chat_history", chat_id)
|
||||||
|
# 保留最后 max_messages 条
|
||||||
|
self.client.ltrim(key, -max_messages, -1)
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"裁剪对话历史失败: {chat_id}, {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def clear_chat_history(self, chat_id: str) -> bool:
|
||||||
|
"""
|
||||||
|
清空指定会话的对话历史
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chat_id: 会话ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否成功
|
||||||
|
"""
|
||||||
|
if not self.enabled or not self.client:
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
key = self._make_key("chat_history", chat_id)
|
||||||
|
self.client.delete(key)
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"清空对话历史失败: {chat_id}, {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# ==================== 群聊历史记录专用方法 ====================
|
||||||
|
|
||||||
|
def get_group_history(self, group_id: str, max_messages: int = 100) -> list:
|
||||||
|
"""
|
||||||
|
获取群聊历史记录
|
||||||
|
|
||||||
|
Args:
|
||||||
|
group_id: 群聊ID
|
||||||
|
max_messages: 最大返回消息数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
消息列表,每条包含 nickname, content, timestamp
|
||||||
|
"""
|
||||||
|
if not self.enabled or not self.client:
|
||||||
|
return []
|
||||||
|
|
||||||
|
try:
|
||||||
|
key = self._make_key("group_history", group_id)
|
||||||
|
data = self.client.lrange(key, -max_messages, -1)
|
||||||
|
return [json.loads(item) for item in data]
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取群聊历史失败: {group_id}, {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def add_group_message(self, group_id: str, nickname: str, content,
|
||||||
|
record_id: str = None, ttl: int = 86400) -> bool:
|
||||||
|
"""
|
||||||
|
添加消息到群聊历史
|
||||||
|
|
||||||
|
Args:
|
||||||
|
group_id: 群聊ID
|
||||||
|
nickname: 发送者昵称
|
||||||
|
content: 消息内容
|
||||||
|
record_id: 可选的记录ID,用于后续更新
|
||||||
|
ttl: 过期时间(秒),默认24小时
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否添加成功
|
||||||
|
"""
|
||||||
|
if not self.enabled or not self.client:
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
import time
|
||||||
|
key = self._make_key("group_history", group_id)
|
||||||
|
message = {
|
||||||
|
"nickname": nickname,
|
||||||
|
"content": content,
|
||||||
|
"timestamp": time.time()
|
||||||
|
}
|
||||||
|
if record_id:
|
||||||
|
message["id"] = record_id
|
||||||
|
|
||||||
|
self.client.rpush(key, json.dumps(message, ensure_ascii=False))
|
||||||
|
self.client.expire(key, ttl)
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"添加群聊消息失败: {group_id}, {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def update_group_message_by_id(self, group_id: str, record_id: str, new_content) -> bool:
|
||||||
|
"""
|
||||||
|
根据ID更新群聊历史中的消息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
group_id: 群聊ID
|
||||||
|
record_id: 记录ID
|
||||||
|
new_content: 新内容
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否更新成功
|
||||||
|
"""
|
||||||
|
if not self.enabled or not self.client:
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
key = self._make_key("group_history", group_id)
|
||||||
|
# 获取所有消息
|
||||||
|
data = self.client.lrange(key, 0, -1)
|
||||||
|
|
||||||
|
for i, item in enumerate(data):
|
||||||
|
msg = json.loads(item)
|
||||||
|
if msg.get("id") == record_id:
|
||||||
|
msg["content"] = new_content
|
||||||
|
self.client.lset(key, i, json.dumps(msg, ensure_ascii=False))
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"更新群聊消息失败: {group_id}, {record_id}, {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def trim_group_history(self, group_id: str, max_messages: int = 100) -> bool:
|
||||||
|
"""
|
||||||
|
裁剪群聊历史,保留最近的N条消息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
group_id: 群聊ID
|
||||||
|
max_messages: 保留的最大消息数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否成功
|
||||||
|
"""
|
||||||
|
if not self.enabled or not self.client:
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
key = self._make_key("group_history", group_id)
|
||||||
|
self.client.ltrim(key, -max_messages, -1)
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"裁剪群聊历史失败: {group_id}, {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# ==================== 限流专用方法 ====================
|
||||||
|
|
||||||
|
def check_rate_limit(self, identifier: str, limit: int = 10,
|
||||||
|
window: int = 60, limit_type: str = "message") -> tuple:
|
||||||
|
"""
|
||||||
|
检查是否超过限流
|
||||||
|
|
||||||
|
使用滑动窗口算法
|
||||||
|
|
||||||
|
Args:
|
||||||
|
identifier: 标识符(如用户wxid、群ID等)
|
||||||
|
limit: 时间窗口内最大请求数
|
||||||
|
window: 时间窗口(秒)
|
||||||
|
limit_type: 限流类型(message/ai_chat/image_gen等)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(是否允许, 剩余次数, 重置时间秒数)
|
||||||
|
"""
|
||||||
|
if not self.enabled or not self.client:
|
||||||
|
return (True, limit, 0) # Redis 不可用时不限流
|
||||||
|
|
||||||
|
try:
|
||||||
|
import time
|
||||||
|
key = self._make_key("rate_limit", limit_type, identifier)
|
||||||
|
now = time.time()
|
||||||
|
window_start = now - window
|
||||||
|
|
||||||
|
# 使用 pipeline 提高性能
|
||||||
|
pipe = self.client.pipeline()
|
||||||
|
|
||||||
|
# 移除过期的记录
|
||||||
|
pipe.zremrangebyscore(key, 0, window_start)
|
||||||
|
# 获取当前窗口内的请求数
|
||||||
|
pipe.zcard(key)
|
||||||
|
# 添加当前请求
|
||||||
|
pipe.zadd(key, {str(now): now})
|
||||||
|
# 设置过期时间
|
||||||
|
pipe.expire(key, window)
|
||||||
|
|
||||||
|
results = pipe.execute()
|
||||||
|
current_count = results[1] # zcard 的结果
|
||||||
|
|
||||||
|
if current_count >= limit:
|
||||||
|
# 获取最早的记录时间,计算重置时间
|
||||||
|
oldest = self.client.zrange(key, 0, 0, withscores=True)
|
||||||
|
if oldest:
|
||||||
|
reset_time = int(oldest[0][1] + window - now)
|
||||||
|
else:
|
||||||
|
reset_time = window
|
||||||
|
return (False, 0, max(reset_time, 1))
|
||||||
|
|
||||||
|
remaining = limit - current_count - 1
|
||||||
|
return (True, remaining, 0)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"限流检查失败: {identifier}, {e}")
|
||||||
|
return (True, limit, 0) # 出错时不限流
|
||||||
|
|
||||||
|
def get_rate_limit_status(self, identifier: str, limit: int = 10,
|
||||||
|
window: int = 60, limit_type: str = "message") -> Dict:
|
||||||
|
"""
|
||||||
|
获取限流状态(不增加计数)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
identifier: 标识符
|
||||||
|
limit: 时间窗口内最大请求数
|
||||||
|
window: 时间窗口(秒)
|
||||||
|
limit_type: 限流类型
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
状态字典
|
||||||
|
"""
|
||||||
|
if not self.enabled or not self.client:
|
||||||
|
return {"enabled": False, "current": 0, "limit": limit, "remaining": limit}
|
||||||
|
|
||||||
|
try:
|
||||||
|
import time
|
||||||
|
key = self._make_key("rate_limit", limit_type, identifier)
|
||||||
|
now = time.time()
|
||||||
|
window_start = now - window
|
||||||
|
|
||||||
|
# 移除过期记录并获取当前数量
|
||||||
|
self.client.zremrangebyscore(key, 0, window_start)
|
||||||
|
current = self.client.zcard(key)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"enabled": True,
|
||||||
|
"current": current,
|
||||||
|
"limit": limit,
|
||||||
|
"remaining": max(0, limit - current),
|
||||||
|
"window": window
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取限流状态失败: {identifier}, {e}")
|
||||||
|
return {"enabled": False, "error": str(e)}
|
||||||
|
|
||||||
|
def reset_rate_limit(self, identifier: str, limit_type: str = "message") -> bool:
|
||||||
|
"""
|
||||||
|
重置限流计数
|
||||||
|
|
||||||
|
Args:
|
||||||
|
identifier: 标识符
|
||||||
|
limit_type: 限流类型
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否成功
|
||||||
|
"""
|
||||||
|
if not self.enabled or not self.client:
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
key = self._make_key("rate_limit", limit_type, identifier)
|
||||||
|
self.client.delete(key)
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"重置限流失败: {identifier}, {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# ==================== 媒体缓存专用方法 ====================
|
||||||
|
|
||||||
|
def cache_media(self, media_key: str, base64_data: str, media_type: str = "image", ttl: int = 300) -> bool:
|
||||||
|
"""
|
||||||
|
缓存媒体文件的 base64 数据
|
||||||
|
|
||||||
|
Args:
|
||||||
|
media_key: 媒体唯一标识(如 cdnurl 的 hash 或 aeskey)
|
||||||
|
base64_data: base64 编码的媒体数据
|
||||||
|
media_type: 媒体类型(image/emoji/video)
|
||||||
|
ttl: 过期时间(秒),默认5分钟
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否缓存成功
|
||||||
|
"""
|
||||||
|
if not self.enabled or not self.client:
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
key = self._make_key("media_cache", media_type, media_key)
|
||||||
|
# 直接存储 base64 字符串,不再 json 序列化
|
||||||
|
self.client.setex(key, ttl, base64_data)
|
||||||
|
logger.debug(f"媒体已缓存: {media_type}/{media_key[:20]}..., TTL={ttl}s")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"缓存媒体失败: {media_key}, {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_cached_media(self, media_key: str, media_type: str = "image") -> Optional[str]:
|
||||||
|
"""
|
||||||
|
获取缓存的媒体 base64 数据
|
||||||
|
|
||||||
|
Args:
|
||||||
|
media_key: 媒体唯一标识
|
||||||
|
media_type: 媒体类型
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
base64 数据,不存在返回 None
|
||||||
|
"""
|
||||||
|
if not self.enabled or not self.client:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
key = self._make_key("media_cache", media_type, media_key)
|
||||||
|
data = self.client.get(key)
|
||||||
|
if data:
|
||||||
|
logger.debug(f"媒体缓存命中: {media_type}/{media_key[:20]}...")
|
||||||
|
return data
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取媒体缓存失败: {media_key}, {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def delete_cached_media(self, media_key: str, media_type: str = "image") -> bool:
|
||||||
|
"""
|
||||||
|
删除缓存的媒体
|
||||||
|
|
||||||
|
Args:
|
||||||
|
media_key: 媒体唯一标识
|
||||||
|
media_type: 媒体类型
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否删除成功
|
||||||
|
"""
|
||||||
|
if not self.enabled or not self.client:
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
key = self._make_key("media_cache", media_type, media_key)
|
||||||
|
self.client.delete(key)
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"删除媒体缓存失败: {media_key}, {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def generate_media_key(cdnurl: str = "", aeskey: str = "") -> str:
|
||||||
|
"""
|
||||||
|
根据 CDN URL 或 AES Key 生成媒体缓存 key
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cdnurl: CDN URL
|
||||||
|
aeskey: AES Key
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
缓存 key
|
||||||
|
"""
|
||||||
|
import hashlib
|
||||||
|
# 优先使用 aeskey(更短更稳定),否则使用 cdnurl 的 hash
|
||||||
|
if aeskey:
|
||||||
|
return aeskey[:32] # 取前32位作为 key
|
||||||
|
elif cdnurl:
|
||||||
|
return hashlib.md5(cdnurl.encode()).hexdigest()
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def get_cache() -> Optional[RedisCache]:
|
||||||
|
"""
|
||||||
|
获取全局缓存实例
|
||||||
|
|
||||||
|
返回 RedisCache 单例实例。如果还没有初始化,返回一个未连接的实例。
|
||||||
|
建议在 MessageLogger 初始化后再调用此函数。
|
||||||
|
"""
|
||||||
|
return RedisCache._instance
|
||||||
|
|
||||||
|
|
||||||
|
def init_cache(config: Dict) -> RedisCache:
|
||||||
|
"""
|
||||||
|
初始化全局缓存实例
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Redis 配置
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
缓存实例
|
||||||
|
"""
|
||||||
|
global _cache_instance
|
||||||
|
_cache_instance = RedisCache(config)
|
||||||
|
return _cache_instance
|
||||||
Reference in New Issue
Block a user