feat: 优化整体项目

This commit is contained in:
2025-12-05 18:06:13 +08:00
parent b4df26f61d
commit 7d3ef70093
13 changed files with 2661 additions and 305 deletions

View File

@@ -2,6 +2,7 @@
AI 聊天插件
支持自定义模型、API 和人设
支持 Redis 存储对话历史和限流
"""
import asyncio
@@ -12,6 +13,7 @@ from datetime import datetime
from loguru import logger
from utils.plugin_base import PluginBase
from utils.decorators import on_text_message, on_quote_message, on_image_message, on_emoji_message
from utils.redis_cache import get_cache
import xml.etree.ElementTree as ET
import base64
import uuid
@@ -95,6 +97,92 @@ class AIChat(PluginBase):
else:
return sender_wxid or from_wxid # 私聊使用用户ID
async def _get_user_nickname(self, bot, from_wxid: str, user_wxid: str, is_group: bool) -> str:
"""
获取用户昵称,优先使用 Redis 缓存
Args:
bot: WechatHookClient 实例
from_wxid: 消息来源群聊ID或私聊用户ID
user_wxid: 用户wxid
is_group: 是否群聊
Returns:
用户昵称
"""
if not is_group:
return ""
nickname = ""
# 1. 优先从 Redis 缓存获取
redis_cache = get_cache()
if redis_cache and redis_cache.enabled:
cached_info = redis_cache.get_user_basic_info(from_wxid, user_wxid)
if cached_info and cached_info.get("nickname"):
logger.debug(f"[缓存命中] 用户昵称: {user_wxid} -> {cached_info['nickname']}")
return cached_info["nickname"]
# 2. 缓存未命中,调用 API 获取
try:
user_info = await bot.get_user_info_in_chatroom(from_wxid, user_wxid)
if user_info and user_info.get("nickName", {}).get("string"):
nickname = user_info["nickName"]["string"]
# 存入缓存
if redis_cache and redis_cache.enabled:
redis_cache.set_user_info(from_wxid, user_wxid, user_info)
logger.debug(f"[已缓存] 用户昵称: {user_wxid} -> {nickname}")
return nickname
except Exception as e:
logger.warning(f"API获取用户昵称失败: {e}")
# 3. 从 MessageLogger 数据库查询
if not nickname:
try:
from plugins.MessageLogger.main import MessageLogger
msg_logger = MessageLogger.get_instance()
if msg_logger:
with msg_logger.get_db_connection() as conn:
with conn.cursor() as cursor:
cursor.execute(
"SELECT nickname FROM messages WHERE sender_wxid = %s AND nickname != '' ORDER BY create_time DESC LIMIT 1",
(user_wxid,)
)
result = cursor.fetchone()
if result:
nickname = result[0]
except Exception as e:
logger.debug(f"从数据库获取昵称失败: {e}")
# 4. 最后降级使用 wxid
if not nickname:
nickname = user_wxid or "未知用户"
return nickname
def _check_rate_limit(self, user_wxid: str) -> tuple:
"""
检查用户是否超过限流
Args:
user_wxid: 用户wxid
Returns:
(是否允许, 剩余次数, 重置时间秒数)
"""
rate_limit_config = self.config.get("rate_limit", {})
if not rate_limit_config.get("enabled", True):
return (True, 999, 0)
redis_cache = get_cache()
if not redis_cache or not redis_cache.enabled:
return (True, 999, 0) # Redis 不可用时不限流
limit = rate_limit_config.get("ai_chat_limit", 20)
window = rate_limit_config.get("ai_chat_window", 60)
return redis_cache.check_rate_limit(user_wxid, limit, window, "ai_chat")
def _add_to_memory(self, chat_id: str, role: str, content, image_base64: str = None):
"""
添加消息到记忆
@@ -108,9 +196,6 @@ class AIChat(PluginBase):
if not self.config.get("memory", {}).get("enabled", False):
return
if chat_id not in self.memory:
self.memory[chat_id] = []
# 如果有图片,构建多模态内容
if image_base64:
message_content = [
@@ -120,6 +205,22 @@ class AIChat(PluginBase):
else:
message_content = content
# 优先使用 Redis 存储
redis_config = self.config.get("redis", {})
if redis_config.get("use_redis_history", True):
redis_cache = get_cache()
if redis_cache and redis_cache.enabled:
ttl = redis_config.get("chat_history_ttl", 86400)
redis_cache.add_chat_message(chat_id, role, message_content, ttl=ttl)
# 裁剪历史
max_messages = self.config["memory"]["max_messages"]
redis_cache.trim_chat_history(chat_id, max_messages)
return
# 降级到内存存储
if chat_id not in self.memory:
self.memory[chat_id] = []
self.memory[chat_id].append({"role": role, "content": message_content})
# 限制记忆长度
@@ -131,16 +232,47 @@ class AIChat(PluginBase):
"""获取记忆中的消息"""
if not self.config.get("memory", {}).get("enabled", False):
return []
# 优先从 Redis 获取
redis_config = self.config.get("redis", {})
if redis_config.get("use_redis_history", True):
redis_cache = get_cache()
if redis_cache and redis_cache.enabled:
max_messages = self.config["memory"]["max_messages"]
return redis_cache.get_chat_history(chat_id, max_messages)
# 降级到内存
return self.memory.get(chat_id, [])
def _clear_memory(self, chat_id: str):
"""清空指定会话的记忆"""
# 清空 Redis
redis_config = self.config.get("redis", {})
if redis_config.get("use_redis_history", True):
redis_cache = get_cache()
if redis_cache and redis_cache.enabled:
redis_cache.clear_chat_history(chat_id)
# 同时清空内存
if chat_id in self.memory:
del self.memory[chat_id]
async def _download_and_encode_image(self, bot, cdnurl: str, aeskey: str) -> str:
"""下载图片并转换为base64"""
"""下载图片并转换为base64,优先从缓存获取"""
try:
# 1. 优先从 Redis 缓存获取
from utils.redis_cache import RedisCache
redis_cache = get_cache()
if redis_cache and redis_cache.enabled:
media_key = RedisCache.generate_media_key(cdnurl, aeskey)
if media_key:
cached_data = redis_cache.get_cached_media(media_key, "image")
if cached_data:
logger.debug(f"[缓存命中] 图片从 Redis 获取: {media_key[:20]}...")
return cached_data
# 2. 缓存未命中,下载图片
logger.debug(f"[缓存未命中] 开始下载图片...")
temp_dir = Path(__file__).parent / "temp"
temp_dir.mkdir(exist_ok=True)
@@ -168,74 +300,114 @@ class AIChat(PluginBase):
with open(save_path, "rb") as f:
image_data = base64.b64encode(f.read()).decode()
base64_result = f"data:image/jpeg;base64,{image_data}"
# 3. 缓存到 Redis供后续使用
if redis_cache and redis_cache.enabled and media_key:
redis_cache.cache_media(media_key, base64_result, "image", ttl=300)
logger.debug(f"[已缓存] 图片缓存到 Redis: {media_key[:20]}...")
try:
Path(save_path).unlink()
except:
pass
return f"data:image/jpeg;base64,{image_data}"
return base64_result
except Exception as e:
logger.error(f"下载图片失败: {e}")
return ""
async def _download_emoji_and_encode(self, cdn_url: str) -> str:
"""下载表情包并转换为base64HTTP 直接下载"""
try:
# 替换 HTML 实体
cdn_url = cdn_url.replace("&", "&")
async def _download_emoji_and_encode(self, cdn_url: str, max_retries: int = 3) -> str:
"""下载表情包并转换为base64HTTP 直接下载,带重试机制),优先从缓存获取"""
# 替换 HTML 实体
cdn_url = cdn_url.replace("&", "&")
temp_dir = Path(__file__).parent / "temp"
temp_dir.mkdir(exist_ok=True)
# 1. 优先从 Redis 缓存获取
from utils.redis_cache import RedisCache
redis_cache = get_cache()
media_key = RedisCache.generate_media_key(cdnurl=cdn_url)
if redis_cache and redis_cache.enabled and media_key:
cached_data = redis_cache.get_cached_media(media_key, "emoji")
if cached_data:
logger.debug(f"[缓存命中] 表情包从 Redis 获取: {media_key[:20]}...")
return cached_data
filename = f"temp_{uuid.uuid4().hex[:8]}.gif"
save_path = temp_dir / filename
# 2. 缓存未命中,下载表情包
logger.debug(f"[缓存未命中] 开始下载表情包...")
temp_dir = Path(__file__).parent / "temp"
temp_dir.mkdir(exist_ok=True)
# 使用 aiohttp 下载
timeout = aiohttp.ClientTimeout(total=30)
filename = f"temp_{uuid.uuid4().hex[:8]}.gif"
save_path = temp_dir / filename
# 配置代理
connector = None
proxy_config = self.config.get("proxy", {})
if proxy_config.get("enabled", False):
proxy_type = proxy_config.get("type", "socks5").upper()
proxy_host = proxy_config.get("host", "127.0.0.1")
proxy_port = proxy_config.get("port", 7890)
proxy_username = proxy_config.get("username")
proxy_password = proxy_config.get("password")
last_error = None
if proxy_username and proxy_password:
proxy_url = f"{proxy_type}://{proxy_username}:{proxy_password}@{proxy_host}:{proxy_port}"
else:
proxy_url = f"{proxy_type}://{proxy_host}:{proxy_port}"
for attempt in range(max_retries):
try:
# 使用 aiohttp 下载,每次重试增加超时时间
timeout = aiohttp.ClientTimeout(total=30 + attempt * 15)
if PROXY_SUPPORT:
try:
connector = ProxyConnector.from_url(proxy_url)
except:
connector = None
# 配置代理
connector = None
proxy_config = self.config.get("proxy", {})
if proxy_config.get("enabled", False):
proxy_type = proxy_config.get("type", "socks5").upper()
proxy_host = proxy_config.get("host", "127.0.0.1")
proxy_port = proxy_config.get("port", 7890)
proxy_username = proxy_config.get("username")
proxy_password = proxy_config.get("password")
async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
async with session.get(cdn_url) as response:
if response.status == 200:
content = await response.read()
with open(save_path, "wb") as f:
f.write(content)
if proxy_username and proxy_password:
proxy_url = f"{proxy_type}://{proxy_username}:{proxy_password}@{proxy_host}:{proxy_port}"
else:
proxy_url = f"{proxy_type}://{proxy_host}:{proxy_port}"
# 编码为 base64
image_data = base64.b64encode(content).decode()
# 删除临时文件
if PROXY_SUPPORT:
try:
save_path.unlink()
connector = ProxyConnector.from_url(proxy_url)
except:
pass
connector = None
return f"data:image/gif;base64,{image_data}"
async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
async with session.get(cdn_url) as response:
if response.status == 200:
content = await response.read()
return ""
except Exception as e:
logger.error(f"下载表情包失败: {e}")
return ""
if len(content) == 0:
logger.warning(f"表情包下载内容为空,重试 {attempt + 1}/{max_retries}")
continue
# 编码为 base64
image_data = base64.b64encode(content).decode()
logger.debug(f"表情包下载成功,大小: {len(content)} 字节")
base64_result = f"data:image/gif;base64,{image_data}"
# 3. 缓存到 Redis供后续使用
if redis_cache and redis_cache.enabled and media_key:
redis_cache.cache_media(media_key, base64_result, "emoji", ttl=300)
logger.debug(f"[已缓存] 表情包缓存到 Redis: {media_key[:20]}...")
return base64_result
else:
logger.warning(f"表情包下载失败,状态码: {response.status},重试 {attempt + 1}/{max_retries}")
except asyncio.TimeoutError:
last_error = "请求超时"
logger.warning(f"表情包下载超时,重试 {attempt + 1}/{max_retries}")
except aiohttp.ClientError as e:
last_error = str(e)
logger.warning(f"表情包下载网络错误: {e},重试 {attempt + 1}/{max_retries}")
except Exception as e:
last_error = str(e)
logger.warning(f"表情包下载异常: {e},重试 {attempt + 1}/{max_retries}")
# 重试前等待(指数退避)
if attempt < max_retries - 1:
await asyncio.sleep(1 * (attempt + 1))
logger.error(f"表情包下载失败,已重试 {max_retries} 次: {last_error}")
return ""
async def _generate_image_description(self, image_base64: str, prompt: str, config: dict) -> str:
"""
@@ -479,37 +651,8 @@ class AIChat(PluginBase):
# 检查是否应该回复
should_reply = self._should_reply(message, content, bot_wxid)
# 获取用户昵称(用于历史记录)
nickname = ""
if is_group:
try:
user_info = await bot.get_user_info_in_chatroom(from_wxid, user_wxid)
if user_info and user_info.get("nickName", {}).get("string"):
nickname = user_info["nickName"]["string"]
except:
pass
# 如果获取昵称失败,从 MessageLogger 数据库查询
if not nickname:
from plugins.MessageLogger.main import MessageLogger
msg_logger = MessageLogger.get_instance()
if msg_logger:
try:
with msg_logger.get_db_connection() as conn:
with conn.cursor() as cursor:
cursor.execute(
"SELECT nickname FROM messages WHERE sender_wxid = %s AND nickname != '' ORDER BY create_time DESC LIMIT 1",
(user_wxid,)
)
result = cursor.fetchone()
if result:
nickname = result[0]
except:
pass
# 最后降级使用 wxid
if not nickname:
nickname = user_wxid or sender_wxid or "未知用户"
# 获取用户昵称(用于历史记录)- 使用缓存优化
nickname = await self._get_user_nickname(bot, from_wxid, user_wxid, is_group)
# 保存到群组历史记录(所有消息都保存,不管是否回复)
if is_group:
@@ -519,6 +662,16 @@ class AIChat(PluginBase):
if not should_reply:
return
# 限流检查(仅在需要回复时检查)
allowed, remaining, reset_time = self._check_rate_limit(user_wxid)
if not allowed:
rate_limit_config = self.config.get("rate_limit", {})
msg = rate_limit_config.get("rate_limit_message", "⚠️ 消息太频繁了,请 {seconds} 秒后再试~")
msg = msg.format(seconds=reset_time)
await bot.send_text(from_wxid, msg)
logger.warning(f"用户 {user_wxid} 触发限流,{reset_time}秒后重置")
return False
# 提取实际消息内容(去除@
actual_content = self._extract_content(message, content)
if not actual_content:
@@ -1004,8 +1157,23 @@ class AIChat(PluginBase):
json.dump(history, f, ensure_ascii=False, indent=2)
temp_file.replace(history_file)
def _use_redis_for_group_history(self) -> bool:
"""检查是否使用 Redis 存储群聊历史"""
redis_config = self.config.get("redis", {})
if not redis_config.get("use_redis_history", True):
return False
redis_cache = get_cache()
return redis_cache and redis_cache.enabled
async def _load_history(self, chat_id: str) -> list:
"""异步读取群聊历史, 用锁避免与写入冲突"""
"""异步读取群聊历史, 优先使用 Redis"""
# 优先使用 Redis
if self._use_redis_for_group_history():
redis_cache = get_cache()
max_history = self.config.get("history", {}).get("max_history", 100)
return redis_cache.get_group_history(chat_id, max_history)
# 降级到文件存储
history_file = self._get_history_file(chat_id)
if not history_file:
return []
@@ -1015,6 +1183,10 @@ class AIChat(PluginBase):
async def _save_history(self, chat_id: str, history: list):
"""异步写入群聊历史, 包含长度截断"""
# Redis 模式下不需要单独保存add_group_message 已经处理
if self._use_redis_for_group_history():
return
history_file = self._get_history_file(chat_id)
if not history_file:
return
@@ -1040,6 +1212,27 @@ class AIChat(PluginBase):
if not self.config.get("history", {}).get("enabled", True):
return
# 构建消息内容
if image_base64:
message_content = [
{"type": "text", "text": content},
{"type": "image_url", "image_url": {"url": image_base64}}
]
else:
message_content = content
# 优先使用 Redis
if self._use_redis_for_group_history():
redis_cache = get_cache()
redis_config = self.config.get("redis", {})
ttl = redis_config.get("group_history_ttl", 172800)
redis_cache.add_group_message(chat_id, nickname, message_content, ttl=ttl)
# 裁剪历史
max_history = self.config.get("history", {}).get("max_history", 100)
redis_cache.trim_group_history(chat_id, max_history)
return
# 降级到文件存储
history_file = self._get_history_file(chat_id)
if not history_file:
return
@@ -1050,17 +1243,10 @@ class AIChat(PluginBase):
message_record = {
"nickname": nickname,
"timestamp": datetime.now().isoformat()
"timestamp": datetime.now().isoformat(),
"content": message_content
}
if image_base64:
message_record["content"] = [
{"type": "text", "text": content},
{"type": "image_url", "image_url": {"url": image_base64}}
]
else:
message_record["content"] = content
history.append(message_record)
max_history = self.config.get("history", {}).get("max_history", 100)
if len(history) > max_history:
@@ -1073,6 +1259,18 @@ class AIChat(PluginBase):
if not self.config.get("history", {}).get("enabled", True):
return
# 优先使用 Redis
if self._use_redis_for_group_history():
redis_cache = get_cache()
redis_config = self.config.get("redis", {})
ttl = redis_config.get("group_history_ttl", 172800)
redis_cache.add_group_message(chat_id, nickname, content, record_id=record_id, ttl=ttl)
# 裁剪历史
max_history = self.config.get("history", {}).get("max_history", 100)
redis_cache.trim_group_history(chat_id, max_history)
return
# 降级到文件存储
history_file = self._get_history_file(chat_id)
if not history_file:
return
@@ -1097,6 +1295,13 @@ class AIChat(PluginBase):
if not self.config.get("history", {}).get("enabled", True):
return
# 优先使用 Redis
if self._use_redis_for_group_history():
redis_cache = get_cache()
redis_cache.update_group_message_by_id(chat_id, record_id, new_content)
return
# 降级到文件存储
history_file = self._get_history_file(chat_id)
if not history_file:
return
@@ -1204,18 +1409,20 @@ class AIChat(PluginBase):
return True
logger.info(f"AI处理引用图片消息: {title_text[:50]}...")
# 获取用户昵称
nickname = ""
if is_group:
try:
user_info = await bot.get_user_info_in_chatroom(from_wxid, user_wxid)
if user_info and user_info.get("nickName", {}).get("string"):
nickname = user_info["nickName"]["string"]
logger.info(f"获取到用户昵称: {nickname}")
except Exception as e:
logger.error(f"获取用户昵称失败: {e}")
# 限流检查
allowed, remaining, reset_time = self._check_rate_limit(user_wxid)
if not allowed:
rate_limit_config = self.config.get("rate_limit", {})
msg = rate_limit_config.get("rate_limit_message", "⚠️ 消息太频繁了,请 {seconds} 秒后再试~")
msg = msg.format(seconds=reset_time)
await bot.send_text(from_wxid, msg)
logger.warning(f"用户 {user_wxid} 触发限流,{reset_time}秒后重置")
return False
# 获取用户昵称 - 使用缓存优化
nickname = await self._get_user_nickname(bot, from_wxid, user_wxid, is_group)
# 下载并编码图片
logger.info(f"开始下载图片: {cdnbigimgurl[:50]}...")
image_base64 = await self._download_and_encode_image(bot, cdnbigimgurl, aeskey)
@@ -1627,34 +1834,8 @@ class AIChat(PluginBase):
if not is_emoji and not aeskey:
return True
# 获取用户昵称
nickname = ""
try:
user_info = await bot.get_user_info_in_chatroom(from_wxid, user_wxid)
if user_info and user_info.get("nickName", {}).get("string"):
nickname = user_info["nickName"]["string"]
except:
pass
if not nickname:
from plugins.MessageLogger.main import MessageLogger
msg_logger = MessageLogger.get_instance()
if msg_logger:
try:
with msg_logger.get_db_connection() as conn:
with conn.cursor() as cursor:
cursor.execute(
"SELECT nickname FROM messages WHERE sender_wxid = %s AND nickname != '' ORDER BY create_time DESC LIMIT 1",
(user_wxid,)
)
result = cursor.fetchone()
if result:
nickname = result[0]
except:
pass
if not nickname:
nickname = user_wxid or sender_wxid or "未知用户"
# 获取用户昵称 - 使用缓存优化
nickname = await self._get_user_nickname(bot, from_wxid, user_wxid, is_group)
# 立即插入占位符到 history
placeholder_id = str(uuid.uuid4())