diff --git a/plugins/AIChat.zip b/plugins/AIChat.zip index be81073..183a623 100644 Binary files a/plugins/AIChat.zip and b/plugins/AIChat.zip differ diff --git a/plugins/AIChat/main.py b/plugins/AIChat/main.py index 5986cdd..ba7322a 100644 --- a/plugins/AIChat/main.py +++ b/plugins/AIChat/main.py @@ -8,13 +8,14 @@ AI 聊天插件 import asyncio import tomllib import aiohttp -import sqlite3 +import json from pathlib import Path from datetime import datetime from loguru import logger from utils.plugin_base import PluginBase from utils.decorators import on_text_message, on_quote_message, on_image_message, on_emoji_message from utils.redis_cache import get_cache +from utils.llm_tooling import ToolResult, collect_tools_with_plugins, collect_tools, get_tool_schema_map, validate_tool_arguments import xml.etree.ElementTree as ET import base64 import uuid @@ -46,6 +47,7 @@ class AIChat(PluginBase): self.image_desc_queue = asyncio.Queue() # 图片描述任务队列 self.image_desc_workers = [] # 工作协程列表 self.persistent_memory_db = None # 持久记忆数据库路径 + self.store = None # ContextStore 实例(统一存储) async def async_init(self): """插件异步初始化""" @@ -88,82 +90,68 @@ class AIChat(PluginBase): self.image_desc_workers.append(worker) logger.info("已启动 2 个图片描述工作协程") - # 初始化持久记忆数据库 - self._init_persistent_memory_db() - - logger.info(f"AI 聊天插件已加载,模型: {self.config['api']['model']}") - - def _init_persistent_memory_db(self): - """初始化持久记忆数据库""" + # 初始化持久记忆数据库与统一存储 + from utils.context_store import ContextStore db_dir = Path(__file__).parent / "data" db_dir.mkdir(exist_ok=True) self.persistent_memory_db = db_dir / "persistent_memory.db" + self.store = ContextStore( + self.config, + self.history_dir, + self.memory, + self.history_locks, + self.persistent_memory_db, + ) + self.store.init_persistent_memory_db() - conn = sqlite3.connect(self.persistent_memory_db) - cursor = conn.cursor() - cursor.execute(""" - CREATE TABLE IF NOT EXISTS memories ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - chat_id TEXT NOT NULL, - chat_type TEXT NOT NULL, - user_wxid TEXT NOT NULL, - user_nickname TEXT, - content TEXT NOT NULL, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP - ) - """) - cursor.execute("CREATE INDEX IF NOT EXISTS idx_chat_id ON memories(chat_id)") - conn.commit() - conn.close() - logger.info(f"持久记忆数据库已初始化: {self.persistent_memory_db}") + logger.info(f"AI 聊天插件已加载,模型: {self.config['api']['model']}") + + async def on_disable(self): + """插件禁用时调用,清理后台任务和队列""" + await super().on_disable() + + # 取消图片描述工作协程,避免重载后叠加 + if self.image_desc_workers: + for worker in self.image_desc_workers: + worker.cancel() + await asyncio.gather(*self.image_desc_workers, return_exceptions=True) + self.image_desc_workers.clear() + + # 清空图片描述队列 + try: + while self.image_desc_queue and not self.image_desc_queue.empty(): + self.image_desc_queue.get_nowait() + self.image_desc_queue.task_done() + except Exception: + pass + self.image_desc_queue = asyncio.Queue() + + logger.info("AIChat 已清理后台图片描述任务") def _add_persistent_memory(self, chat_id: str, chat_type: str, user_wxid: str, user_nickname: str, content: str) -> int: - """添加持久记忆,返回记忆ID""" - conn = sqlite3.connect(self.persistent_memory_db) - cursor = conn.cursor() - cursor.execute(""" - INSERT INTO memories (chat_id, chat_type, user_wxid, user_nickname, content) - VALUES (?, ?, ?, ?, ?) - """, (chat_id, chat_type, user_wxid, user_nickname, content)) - memory_id = cursor.lastrowid - conn.commit() - conn.close() - return memory_id + """添加持久记忆,返回记忆ID(委托 ContextStore)""" + if not self.store: + return -1 + return self.store.add_persistent_memory(chat_id, chat_type, user_wxid, user_nickname, content) def _get_persistent_memories(self, chat_id: str) -> list: - """获取指定会话的所有持久记忆""" - conn = sqlite3.connect(self.persistent_memory_db) - cursor = conn.cursor() - cursor.execute(""" - SELECT id, user_nickname, content, created_at - FROM memories - WHERE chat_id = ? - ORDER BY created_at ASC - """, (chat_id,)) - rows = cursor.fetchall() - conn.close() - return [{"id": r[0], "nickname": r[1], "content": r[2], "time": r[3]} for r in rows] + """获取指定会话的所有持久记忆(委托 ContextStore)""" + if not self.store: + return [] + return self.store.get_persistent_memories(chat_id) def _delete_persistent_memory(self, chat_id: str, memory_id: int) -> bool: - """删除指定的持久记忆""" - conn = sqlite3.connect(self.persistent_memory_db) - cursor = conn.cursor() - cursor.execute("DELETE FROM memories WHERE id = ? AND chat_id = ?", (memory_id, chat_id)) - deleted = cursor.rowcount > 0 - conn.commit() - conn.close() - return deleted + """删除指定的持久记忆(委托 ContextStore)""" + if not self.store: + return False + return self.store.delete_persistent_memory(chat_id, memory_id) def _clear_persistent_memories(self, chat_id: str) -> int: - """清空指定会话的所有持久记忆,返回删除数量""" - conn = sqlite3.connect(self.persistent_memory_db) - cursor = conn.cursor() - cursor.execute("DELETE FROM memories WHERE chat_id = ?", (chat_id,)) - deleted_count = cursor.rowcount - conn.commit() - conn.close() - return deleted_count + """清空指定会话的所有持久记忆(委托 ContextStore)""" + if not self.store: + return 0 + return self.store.clear_persistent_memories(chat_id) def _get_chat_id(self, from_wxid: str, sender_wxid: str = None, is_group: bool = False) -> str: """获取会话ID""" @@ -270,69 +258,21 @@ class AIChat(PluginBase): content: 消息内容(可以是字符串或列表) image_base64: 可选的图片base64数据 """ - if not self.config.get("memory", {}).get("enabled", False): + if not self.store: return - - # 如果有图片,构建多模态内容 - if image_base64: - message_content = [ - {"type": "text", "text": content if isinstance(content, str) else ""}, - {"type": "image_url", "image_url": {"url": image_base64}} - ] - else: - message_content = content - - # 优先使用 Redis 存储 - redis_config = self.config.get("redis", {}) - if redis_config.get("use_redis_history", True): - redis_cache = get_cache() - if redis_cache and redis_cache.enabled: - ttl = redis_config.get("chat_history_ttl", 86400) - redis_cache.add_chat_message(chat_id, role, message_content, ttl=ttl) - # 裁剪历史 - max_messages = self.config["memory"]["max_messages"] - redis_cache.trim_chat_history(chat_id, max_messages) - return - - # 降级到内存存储 - if chat_id not in self.memory: - self.memory[chat_id] = [] - - self.memory[chat_id].append({"role": role, "content": message_content}) - - # 限制记忆长度 - max_messages = self.config["memory"]["max_messages"] - if len(self.memory[chat_id]) > max_messages: - self.memory[chat_id] = self.memory[chat_id][-max_messages:] + self.store.add_private_message(chat_id, role, content, image_base64=image_base64) def _get_memory_messages(self, chat_id: str) -> list: """获取记忆中的消息""" - if not self.config.get("memory", {}).get("enabled", False): + if not self.store: return [] - - # 优先从 Redis 获取 - redis_config = self.config.get("redis", {}) - if redis_config.get("use_redis_history", True): - redis_cache = get_cache() - if redis_cache and redis_cache.enabled: - max_messages = self.config["memory"]["max_messages"] - return redis_cache.get_chat_history(chat_id, max_messages) - - # 降级到内存 - return self.memory.get(chat_id, []) + return self.store.get_private_messages(chat_id) def _clear_memory(self, chat_id: str): """清空指定会话的记忆""" - # 清空 Redis - redis_config = self.config.get("redis", {}) - if redis_config.get("use_redis_history", True): - redis_cache = get_cache() - if redis_cache and redis_cache.enabled: - redis_cache.clear_chat_history(chat_id) - - # 同时清空内存 - if chat_id in self.memory: - del self.memory[chat_id] + if not self.store: + return + self.store.clear_private_messages(chat_id) async def _download_and_encode_image(self, bot, cdnurl: str, aeskey: str) -> str: """下载图片并转换为base64,优先从缓存获取""" @@ -498,127 +438,129 @@ class AIChat(PluginBase): Returns: 图片描述文本,失败返回空字符串 """ - try: - api_config = self.config["api"] - description_model = config.get("model", api_config["model"]) + api_config = self.config["api"] + description_model = config.get("model", api_config["model"]) - # 构建消息 - messages = [ - { - "role": "user", - "content": [ - {"type": "text", "text": prompt}, - {"type": "image_url", "image_url": {"url": image_base64}} - ] - } - ] - - payload = { - "model": description_model, - "messages": messages, - "max_tokens": config.get("max_tokens", 1000), - "stream": True + # 构建消息 + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": prompt}, + {"type": "image_url", "image_url": {"url": image_base64}} + ] } + ] - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {api_config['api_key']}" - } + payload = { + "model": description_model, + "messages": messages, + "max_tokens": config.get("max_tokens", 1000), + "stream": True + } - timeout = aiohttp.ClientTimeout(total=api_config["timeout"]) + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {api_config['api_key']}" + } - # 配置代理 - connector = None - proxy_config = self.config.get("proxy", {}) - if proxy_config.get("enabled", False): - proxy_type = proxy_config.get("type", "socks5").upper() - proxy_host = proxy_config.get("host", "127.0.0.1") - proxy_port = proxy_config.get("port", 7890) - proxy_username = proxy_config.get("username") - proxy_password = proxy_config.get("password") + max_retries = int(config.get("retries", 2)) + last_error = None - if proxy_username and proxy_password: - proxy_url = f"{proxy_type}://{proxy_username}:{proxy_password}@{proxy_host}:{proxy_port}" - else: - proxy_url = f"{proxy_type}://{proxy_host}:{proxy_port}" + for attempt in range(max_retries + 1): + try: + timeout = aiohttp.ClientTimeout(total=api_config["timeout"]) - if PROXY_SUPPORT: - try: - connector = ProxyConnector.from_url(proxy_url) - except Exception as e: - logger.warning(f"代理配置失败,将直连: {e}") - connector = None + # 配置代理(每次重试单独构造 connector) + connector = None + proxy_config = self.config.get("proxy", {}) + if proxy_config.get("enabled", False): + proxy_type = proxy_config.get("type", "socks5").upper() + proxy_host = proxy_config.get("host", "127.0.0.1") + proxy_port = proxy_config.get("port", 7890) + proxy_username = proxy_config.get("username") + proxy_password = proxy_config.get("password") - async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session: - async with session.post( - api_config["url"], - json=payload, - headers=headers - ) as resp: - if resp.status != 200: - error_text = await resp.text() - logger.error(f"图片描述 API 返回错误: {resp.status}, {error_text[:200]}") - return "" + if proxy_username and proxy_password: + proxy_url = f"{proxy_type}://{proxy_username}:{proxy_password}@{proxy_host}:{proxy_port}" + else: + proxy_url = f"{proxy_type}://{proxy_host}:{proxy_port}" - # 流式接收响应 - import json - description = "" - async for line in resp.content: - line = line.decode('utf-8').strip() - if not line or line == "data: [DONE]": - continue + if PROXY_SUPPORT: + try: + connector = ProxyConnector.from_url(proxy_url) + except Exception as e: + logger.warning(f"代理配置失败,将直连: {e}") + connector = None - if line.startswith("data: "): - try: - data = json.loads(line[6:]) - delta = data.get("choices", [{}])[0].get("delta", {}) - content = delta.get("content", "") - if content: - description += content - except: - pass + async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session: + async with session.post( + api_config["url"], + json=payload, + headers=headers + ) as resp: + if resp.status != 200: + error_text = await resp.text() + raise Exception(f"图片描述 API 返回错误: {resp.status}, {error_text[:200]}") - logger.debug(f"图片描述生成成功: {description}") - return description.strip() + # 流式接收响应 + description = "" + async for line in resp.content: + line = line.decode('utf-8').strip() + if not line or line == "data: [DONE]": + continue - except Exception as e: - logger.error(f"生成图片描述失败: {e}") - import traceback - logger.error(f"详细错误: {traceback.format_exc()}") - return "" + if line.startswith("data: "): + try: + data = json.loads(line[6:]) + delta = data.get("choices", [{}])[0].get("delta", {}) + content = delta.get("content", "") + if content: + description += content + except Exception: + pass + + logger.debug(f"图片描述生成成功: {description}") + return description.strip() + + except asyncio.CancelledError: + raise + except (aiohttp.ClientError, asyncio.TimeoutError) as e: + last_error = str(e) + if attempt < max_retries: + logger.warning(f"图片描述网络错误: {e},重试 {attempt + 1}/{max_retries}") + await asyncio.sleep(1 * (attempt + 1)) + continue + except Exception as e: + last_error = str(e) + if attempt < max_retries: + logger.warning(f"图片描述生成异常: {e},重试 {attempt + 1}/{max_retries}") + await asyncio.sleep(1 * (attempt + 1)) + continue + + logger.error(f"生成图片描述失败,已重试 {max_retries + 1} 次: {last_error}") + return "" + + def _collect_tools_with_plugins(self) -> dict: + """收集所有插件的 LLM 工具,并保留来源插件名""" + from utils.plugin_manager import PluginManager + tools_config = self.config.get("tools", {}) + return collect_tools_with_plugins(tools_config, PluginManager().plugins) def _collect_tools(self): """收集所有插件的LLM工具(支持白名单/黑名单过滤)""" from utils.plugin_manager import PluginManager - tools = [] - - # 获取工具过滤配置 tools_config = self.config.get("tools", {}) - mode = tools_config.get("mode", "all") - whitelist = set(tools_config.get("whitelist", [])) - blacklist = set(tools_config.get("blacklist", [])) + return collect_tools(tools_config, PluginManager().plugins) - for plugin in PluginManager().plugins.values(): - if hasattr(plugin, 'get_llm_tools'): - plugin_tools = plugin.get_llm_tools() - if plugin_tools: - for tool in plugin_tools: - tool_name = tool.get("function", {}).get("name", "") + def _get_tool_schema_map(self, tools_map: dict | None = None) -> dict: + """构建工具名到参数 schema 的映射""" + tools_map = tools_map or self._collect_tools_with_plugins() + return get_tool_schema_map(tools_map) - # 根据模式过滤 - if mode == "whitelist": - if tool_name in whitelist: - tools.append(tool) - logger.debug(f"[白名单] 启用工具: {tool_name}") - elif mode == "blacklist": - if tool_name not in blacklist: - tools.append(tool) - else: - logger.debug(f"[黑名单] 禁用工具: {tool_name}") - else: # all - tools.append(tool) - - return tools + def _validate_tool_arguments(self, tool_name: str, arguments: dict, schema: dict) -> tuple: + """轻量校验并补全默认参数""" + return validate_tool_arguments(tool_name, arguments, schema) async def _handle_list_prompts(self, bot, from_wxid: str): """处理人设列表指令""" @@ -698,6 +640,55 @@ class AIChat(PluginBase): return total return 0 + def _extract_text_from_multimodal(self, content) -> str: + """从多模态 content 中提取文本,模型不支持时用于降级""" + if isinstance(content, list): + texts = [item.get("text", "") for item in content if item.get("type") == "text"] + text = "".join(texts).strip() + return text or "[图片]" + if content is None: + return "" + return str(content) + + def _append_group_history_messages(self, messages: list, recent_history: list): + """将群聊历史按 role 追加到 LLM messages""" + for msg in recent_history: + role = msg.get("role") or "user" + msg_nickname = msg.get("nickname", "") + msg_content = msg.get("content", "") + + # 机器人历史回复 + if role == "assistant": + if isinstance(msg_content, list): + msg_content = self._extract_text_from_multimodal(msg_content) + messages.append({ + "role": "assistant", + "content": msg_content + }) + continue + + # 用户历史消息 + if isinstance(msg_content, list): + content_with_nickname = [] + for item in msg_content: + if item.get("type") == "text": + content_with_nickname.append({ + "type": "text", + "text": f"[{msg_nickname}] {item.get('text', '')}" + }) + else: + content_with_nickname.append(item) + + messages.append({ + "role": "user", + "content": content_with_nickname + }) + else: + messages.append({ + "role": "user", + "content": f"[{msg_nickname}] {msg_content}" + }) + async def _handle_context_stats(self, bot, from_wxid: str, user_wxid: str, is_group: bool): """处理上下文统计指令""" try: @@ -867,6 +858,34 @@ class AIChat(PluginBase): await self._handle_context_stats(bot, from_wxid, user_wxid, is_group) return False + # 旧群历史 key 扫描/清理(仅管理员) + if content in ("/旧群历史", "/legacy_history"): + if user_wxid in admins and self.store: + legacy_keys = self.store.find_legacy_group_history_keys() + if legacy_keys: + await bot.send_text( + from_wxid, + f"⚠️ 检测到 {len(legacy_keys)} 个旧版群历史 key(safe_id 写入)。\n" + f"如需清理请发送 /清理旧群历史", + ) + else: + await bot.send_text(from_wxid, "✅ 未发现旧版群历史 key") + else: + await bot.send_text(from_wxid, "❌ 仅管理员可执行该指令") + return False + + if content in ("/清理旧群历史", "/clean_legacy_history"): + if user_wxid in admins and self.store: + legacy_keys = self.store.find_legacy_group_history_keys() + deleted = self.store.delete_legacy_group_history_keys(legacy_keys) + await bot.send_text( + from_wxid, + f"✅ 已清理旧版群历史 key: {deleted} 个", + ) + else: + await bot.send_text(from_wxid, "❌ 仅管理员可执行该指令") + return False + # 检查是否是记忆状态指令(仅管理员) if content == "/记忆状态": if user_wxid in admins: @@ -953,8 +972,9 @@ class AIChat(PluginBase): nickname = await self._get_user_nickname(bot, from_wxid, user_wxid, is_group) # 保存到群组历史记录(所有消息都保存,不管是否回复) - if is_group: - await self._add_to_history(from_wxid, nickname, content) + # 但如果是 AutoReply 触发的,跳过保存(消息已经在正常流程中保存过了) + if is_group and not message.get('_auto_reply_triggered'): + await self._add_to_history(from_wxid, nickname, content, sender_wxid=user_wxid) # 如果不需要回复,直接返回 if not should_reply: @@ -980,7 +1000,9 @@ class AIChat(PluginBase): try: # 获取会话ID并添加用户消息到记忆 chat_id = self._get_chat_id(from_wxid, user_wxid, is_group) - self._add_to_memory(chat_id, "user", actual_content) + # 如果是 AutoReply 触发的,不重复添加用户消息(已在正常流程中添加) + if not message.get('_auto_reply_triggered'): + self._add_to_memory(chat_id, "user", actual_content) # 调用 AI API(带重试机制) max_retries = self.config.get("api", {}).get("max_retries", 2) @@ -1025,7 +1047,7 @@ class AIChat(PluginBase): with open("main_config.toml", "rb") as f: main_config = tomllib.load(f) bot_nickname = main_config.get("Bot", {}).get("nickname", "机器人") - await self._add_to_history(from_wxid, bot_nickname, response) + await self._add_to_history(from_wxid, bot_nickname, response, role="assistant") logger.success(f"AI 回复成功: {response[:50]}...") else: logger.info("AI 回复为空或已通过其他方式发送(如聊天记录)") @@ -1159,36 +1181,8 @@ class AIChat(PluginBase): # 取最近的 N 条消息作为上下文 recent_history = history[-max_context:] if len(history) > max_context else history - # 转换为 AI 消息格式 - for msg in recent_history: - msg_nickname = msg.get("nickname", "") - msg_content = msg.get("content", "") - - # 检查是否是多模态内容(包含图片) - if isinstance(msg_content, list): - # 多模态内容:添加昵称前缀到文本部分 - content_with_nickname = [] - for item in msg_content: - if item.get("type") == "text": - # 在文本前添加昵称 - content_with_nickname.append({ - "type": "text", - "text": f"[{msg_nickname}] {item.get('text', '')}" - }) - else: - # 图片等其他类型直接保留 - content_with_nickname.append(item) - - messages.append({ - "role": "user", - "content": content_with_nickname - }) - else: - # 纯文本内容:简单格式 [昵称] 内容 - messages.append({ - "role": "user", - "content": f"[{msg_nickname}] {msg_content}" - }) + # 转换为 AI 消息格式(按 role) + self._append_group_history_messages(messages, recent_history) else: # 私聊使用原有的 memory 机制 if chat_id: @@ -1199,10 +1193,6 @@ class AIChat(PluginBase): # 添加当前用户消息 messages.append({"role": "user", "content": f"[{nickname}] {user_message}" if is_group and nickname else user_message}) - # 保存用户信息供工具调用使用 - self._current_user_wxid = user_wxid - self._current_is_group = is_group - payload = { "model": api_config["model"], "messages": messages, @@ -1346,7 +1336,7 @@ class AIChat(PluginBase): asyncio.create_task( self._execute_tools_async( tool_calls_data, bot, from_wxid, chat_id, - nickname, is_group, messages + user_wxid, nickname, is_group, messages ) ) # 返回 None 表示工具调用已异步处理,不需要重试 @@ -1369,238 +1359,142 @@ class AIChat(PluginBase): raise Exception(f"API 响应格式错误: {e}") - def _get_history_file(self, chat_id: str) -> Path: - """获取群聊历史记录文件路径""" - if not self.history_dir: - return None - safe_name = chat_id.replace("@", "_").replace(":", "_") - return self.history_dir / f"{safe_name}.json" - - def _get_history_lock(self, chat_id: str) -> asyncio.Lock: - """获取指定会话的锁, 每个会话一把""" - lock = self.history_locks.get(chat_id) - if lock is None: - lock = asyncio.Lock() - self.history_locks[chat_id] = lock - return lock - - def _read_history_file(self, history_file: Path) -> list: - try: - import json - with open(history_file, "r", encoding="utf-8") as f: - return json.load(f) - except FileNotFoundError: - return [] - except Exception as e: - logger.error(f"读取历史记录失败: {e}") - return [] - - def _write_history_file(self, history_file: Path, history: list): - import json - history_file.parent.mkdir(parents=True, exist_ok=True) - temp_file = Path(str(history_file) + ".tmp") - with open(temp_file, "w", encoding="utf-8") as f: - json.dump(history, f, ensure_ascii=False, indent=2) - temp_file.replace(history_file) - - def _use_redis_for_group_history(self) -> bool: - """检查是否使用 Redis 存储群聊历史""" - redis_config = self.config.get("redis", {}) - if not redis_config.get("use_redis_history", True): - return False - redis_cache = get_cache() - return redis_cache and redis_cache.enabled - async def _load_history(self, chat_id: str) -> list: - """异步读取群聊历史, 优先使用 Redis""" - # 优先使用 Redis - if self._use_redis_for_group_history(): - redis_cache = get_cache() - max_history = self.config.get("history", {}).get("max_history", 100) - return redis_cache.get_group_history(chat_id, max_history) - - # 降级到文件存储 - history_file = self._get_history_file(chat_id) - if not history_file: + """异步读取群聊历史(委托 ContextStore)""" + if not self.store: return [] - lock = self._get_history_lock(chat_id) - async with lock: - return self._read_history_file(history_file) + return await self.store.load_group_history(chat_id) - async def _save_history(self, chat_id: str, history: list): - """异步写入群聊历史, 包含长度截断""" - # Redis 模式下不需要单独保存,add_group_message 已经处理 - if self._use_redis_for_group_history(): + async def _add_to_history( + self, + chat_id: str, + nickname: str, + content: str, + image_base64: str = None, + *, + role: str = "user", + sender_wxid: str = None, + ): + """将消息存入群聊历史(委托 ContextStore)""" + if not self.store: return + await self.store.add_group_message( + chat_id, + nickname, + content, + image_base64=image_base64, + role=role, + sender_wxid=sender_wxid, + ) - history_file = self._get_history_file(chat_id) - if not history_file: + async def _add_to_history_with_id( + self, + chat_id: str, + nickname: str, + content: str, + record_id: str, + *, + role: str = "user", + sender_wxid: str = None, + ): + """带ID的历史追加, 便于后续更新(委托 ContextStore)""" + if not self.store: return - - max_history = self.config.get("history", {}).get("max_history", 100) - if len(history) > max_history: - history = history[-max_history:] - - lock = self._get_history_lock(chat_id) - async with lock: - self._write_history_file(history_file, history) - - async def _add_to_history(self, chat_id: str, nickname: str, content: str, image_base64: str = None): - """ - 将消息存入群聊历史 - - Args: - chat_id: 群聊ID - nickname: 用户昵称 - content: 消息内容 - image_base64: 可选的图片base64 - """ - if not self.config.get("history", {}).get("enabled", True): - return - - # 构建消息内容 - if image_base64: - message_content = [ - {"type": "text", "text": content}, - {"type": "image_url", "image_url": {"url": image_base64}} - ] - else: - message_content = content - - # 优先使用 Redis - if self._use_redis_for_group_history(): - redis_cache = get_cache() - redis_config = self.config.get("redis", {}) - ttl = redis_config.get("group_history_ttl", 172800) - redis_cache.add_group_message(chat_id, nickname, message_content, ttl=ttl) - # 裁剪历史 - max_history = self.config.get("history", {}).get("max_history", 100) - redis_cache.trim_group_history(chat_id, max_history) - return - - # 降级到文件存储 - history_file = self._get_history_file(chat_id) - if not history_file: - return - - lock = self._get_history_lock(chat_id) - async with lock: - history = self._read_history_file(history_file) - - message_record = { - "nickname": nickname, - "timestamp": datetime.now().isoformat(), - "content": message_content - } - - history.append(message_record) - max_history = self.config.get("history", {}).get("max_history", 100) - if len(history) > max_history: - history = history[-max_history:] - - self._write_history_file(history_file, history) - - async def _add_to_history_with_id(self, chat_id: str, nickname: str, content: str, record_id: str): - """带ID的历史追加, 便于后续更新""" - if not self.config.get("history", {}).get("enabled", True): - return - - # 优先使用 Redis - if self._use_redis_for_group_history(): - redis_cache = get_cache() - redis_config = self.config.get("redis", {}) - ttl = redis_config.get("group_history_ttl", 172800) - redis_cache.add_group_message(chat_id, nickname, content, record_id=record_id, ttl=ttl) - # 裁剪历史 - max_history = self.config.get("history", {}).get("max_history", 100) - redis_cache.trim_group_history(chat_id, max_history) - return - - # 降级到文件存储 - history_file = self._get_history_file(chat_id) - if not history_file: - return - - lock = self._get_history_lock(chat_id) - async with lock: - history = self._read_history_file(history_file) - message_record = { - "id": record_id, - "nickname": nickname, - "timestamp": datetime.now().isoformat(), - "content": content - } - history.append(message_record) - max_history = self.config.get("history", {}).get("max_history", 100) - if len(history) > max_history: - history = history[-max_history:] - self._write_history_file(history_file, history) + await self.store.add_group_message( + chat_id, + nickname, + content, + record_id=record_id, + role=role, + sender_wxid=sender_wxid, + ) async def _update_history_by_id(self, chat_id: str, record_id: str, new_content: str): - """根据ID更新历史记录""" - if not self.config.get("history", {}).get("enabled", True): + """根据ID更新历史记录(委托 ContextStore)""" + if not self.store: return - - # 优先使用 Redis - if self._use_redis_for_group_history(): - redis_cache = get_cache() - redis_cache.update_group_message_by_id(chat_id, record_id, new_content) - return - - # 降级到文件存储 - history_file = self._get_history_file(chat_id) - if not history_file: - return - - lock = self._get_history_lock(chat_id) - async with lock: - history = self._read_history_file(history_file) - for record in history: - if record.get("id") == record_id: - record["content"] = new_content - break - max_history = self.config.get("history", {}).get("max_history", 100) - if len(history) > max_history: - history = history[-max_history:] - self._write_history_file(history_file, history) + await self.store.update_group_message_by_id(chat_id, record_id, new_content) - async def _execute_tool_and_get_result(self, tool_name: str, arguments: dict, bot, from_wxid: str): + async def _execute_tool_and_get_result( + self, + tool_name: str, + arguments: dict, + bot, + from_wxid: str, + user_wxid: str = None, + is_group: bool = False, + tools_map: dict | None = None, + ): """执行工具调用并返回结果""" from utils.plugin_manager import PluginManager # 添加用户信息到 arguments - arguments["user_wxid"] = getattr(self, "_current_user_wxid", from_wxid) - arguments["is_group"] = getattr(self, "_current_is_group", False) + arguments["user_wxid"] = user_wxid or from_wxid + arguments["is_group"] = bool(is_group) logger.info(f"开始执行工具: {tool_name}") plugins = PluginManager().plugins logger.info(f"检查 {len(plugins)} 个插件") - for plugin_name, plugin in plugins.items(): - logger.debug(f"检查插件: {plugin_name}, 有execute_llm_tool: {hasattr(plugin, 'execute_llm_tool')}") - if hasattr(plugin, 'execute_llm_tool'): + async def _normalize_result(raw, plugin_name: str): + if raw is None: + return None + + if not isinstance(raw, dict): + raw = {"success": True, "message": str(raw)} + else: + raw.setdefault("success", True) + + if raw.get("success"): + logger.success(f"工具执行成功: {tool_name} ({plugin_name})") + else: + logger.warning(f"工具执行失败: {tool_name} ({plugin_name})") + return raw + + # 先尝试直达目标插件(来自 get_llm_tools 的映射) + if tools_map and tool_name in tools_map: + target_plugin_name, _tool_def = tools_map[tool_name] + target_plugin = plugins.get(target_plugin_name) + if target_plugin and hasattr(target_plugin, "execute_llm_tool"): try: - logger.info(f"调用 {plugin_name}.execute_llm_tool") - result = await plugin.execute_llm_tool(tool_name, arguments, bot, from_wxid) - logger.info(f"{plugin_name} 返回: {result}") - if result is not None: - if result.get("success"): - logger.success(f"工具执行成功: {tool_name}") - return result - else: - logger.debug(f"{plugin_name} 不处理此工具,继续检查下一个插件") + logger.info(f"直达调用 {target_plugin_name}.execute_llm_tool") + result = await target_plugin.execute_llm_tool(tool_name, arguments, bot, from_wxid) + logger.info(f"{target_plugin_name} 返回: {result}") + normalized = await _normalize_result(result, target_plugin_name) + if normalized is not None: + return normalized except Exception as e: - logger.error(f"工具执行异常 ({plugin_name}): {tool_name}, {e}") + logger.error(f"工具执行异常 ({target_plugin_name}): {tool_name}, {e}") import traceback logger.error(f"详细错误: {traceback.format_exc()}") + else: + logger.warning(f"工具 {tool_name} 期望插件 {target_plugin_name} 不存在或不支持 execute_llm_tool,回退全量扫描") + + # 回退:遍历所有插件 + for plugin_name, plugin in plugins.items(): + logger.debug(f"检查插件: {plugin_name}, 有execute_llm_tool: {hasattr(plugin, 'execute_llm_tool')}") + if not hasattr(plugin, "execute_llm_tool"): + continue + + try: + logger.info(f"调用 {plugin_name}.execute_llm_tool") + result = await plugin.execute_llm_tool(tool_name, arguments, bot, from_wxid) + logger.info(f"{plugin_name} 返回: {result}") + normalized = await _normalize_result(result, plugin_name) + if normalized is not None: + return normalized + except Exception as e: + logger.error(f"工具执行异常 ({plugin_name}): {tool_name}, {e}") + import traceback + logger.error(f"详细错误: {traceback.format_exc()}") logger.warning(f"未找到工具: {tool_name}") return {"success": False, "message": f"未找到工具: {tool_name}"} async def _execute_tools_async(self, tool_calls_data: list, bot, from_wxid: str, - chat_id: str, nickname: str, is_group: bool, + chat_id: str, user_wxid: str, nickname: str, is_group: bool, messages: list): """ 异步执行工具调用(不阻塞主流程) @@ -1608,14 +1502,14 @@ class AIChat(PluginBase): AI 已经先回复用户,这里异步执行工具,完成后发送结果 支持 need_ai_reply 标记:工具结果回传给 AI 继续对话(保留上下文和人设) """ - import json - try: logger.info(f"开始异步执行 {len(tool_calls_data)} 个工具调用") # 并行执行所有工具 tasks = [] tool_info_list = [] # 保存工具信息用于后续处理 + tools_map = self._collect_tools_with_plugins() + schema_map = self._get_tool_schema_map(tools_map) for tool_call in tool_calls_data: function_name = tool_call.get("function", {}).get("name", "") @@ -1627,13 +1521,31 @@ class AIChat(PluginBase): try: arguments = json.loads(arguments_str) - except: + except Exception: arguments = {} + schema = schema_map.get(function_name) + ok, err, arguments = self._validate_tool_arguments(function_name, arguments, schema) + if not ok: + logger.warning(f"[异步] 工具 {function_name} 参数校验失败: {err}") + try: + await bot.send_text(from_wxid, f"❌ 工具 {function_name} 参数错误: {err}") + except Exception: + pass + continue + logger.info(f"[异步] 准备执行工具: {function_name}, 参数: {arguments}") # 创建异步任务 - task = self._execute_tool_and_get_result(function_name, arguments, bot, from_wxid) + task = self._execute_tool_and_get_result( + function_name, + arguments, + bot, + from_wxid, + user_wxid=user_wxid, + is_group=is_group, + tools_map=tools_map, + ) tasks.append(task) tool_info_list.append({ "tool_call_id": tool_call_id, @@ -1656,38 +1568,45 @@ class AIChat(PluginBase): if isinstance(result, Exception): logger.error(f"[异步] 工具 {function_name} 执行异常: {result}") - # 发送错误提示 - await bot.send_text(from_wxid, f"❌ {function_name} 执行失败") + try: + await bot.send_text(from_wxid, f"❌ {function_name} 执行失败: {result}") + except Exception: + pass continue - if result and result.get("success"): + tool_result = ToolResult.from_raw(result) + if not tool_result: + continue + + if tool_result.success: logger.success(f"[异步] 工具 {function_name} 执行成功") - - # 检查是否需要 AI 基于工具结果继续回复 - if result.get("need_ai_reply"): - need_ai_reply_results.append({ - "tool_call_id": tool_call_id, - "function_name": function_name, - "result": result.get("message", "") - }) - continue # 不直接发送,等待 AI 处理 - - # 如果工具没有自己发送内容,且有消息需要发送 - if not result.get("already_sent") and result.get("message"): - # 某些工具可能需要发送结果消息 - msg = result.get("message", "") - if msg and not result.get("no_reply"): - # 检查是否需要发送文本结果 - if result.get("send_result_text"): - await bot.send_text(from_wxid, msg) - - # 保存工具结果到记忆(可选) - if result.get("save_to_memory") and chat_id: - self._add_to_memory(chat_id, "assistant", f"[工具 {function_name} 结果]: {result.get('message', '')}") else: - logger.warning(f"[异步] 工具 {function_name} 执行失败: {result}") - if result and result.get("message"): - await bot.send_text(from_wxid, f"❌ {result.get('message')}") + logger.warning(f"[异步] 工具 {function_name} 执行失败") + + # 需要 AI 继续处理的结果 + if tool_result.need_ai_reply: + need_ai_reply_results.append({ + "tool_call_id": tool_call_id, + "function_name": function_name, + "result": tool_result.message + }) + continue + + # 工具成功且需要回文本时发送 + if tool_result.success and not tool_result.already_sent and tool_result.message and not tool_result.no_reply: + if tool_result.send_result_text: + await bot.send_text(from_wxid, tool_result.message) + + # 工具失败默认回一条错误提示 + if not tool_result.success and tool_result.message and not tool_result.no_reply: + try: + await bot.send_text(from_wxid, f"❌ {tool_result.message}") + except Exception: + pass + + # 保存工具结果到记忆(可选) + if tool_result.save_to_memory and chat_id: + self._add_to_memory(chat_id, "assistant", f"[工具 {function_name} 结果]: {tool_result.message}") # 如果有需要 AI 回复的工具结果,调用 AI 继续对话 if need_ai_reply_results: @@ -1833,21 +1752,21 @@ class AIChat(PluginBase): pass async def _execute_tools_async_with_image(self, tool_calls_data: list, bot, from_wxid: str, - chat_id: str, nickname: str, is_group: bool, + chat_id: str, user_wxid: str, nickname: str, is_group: bool, messages: list, image_base64: str): """ 异步执行工具调用(带图片参数,用于图生图等场景) AI 已经先回复用户,这里异步执行工具,完成后发送结果 """ - import json - try: logger.info(f"[异步-图片] 开始执行 {len(tool_calls_data)} 个工具调用") # 并行执行所有工具 tasks = [] tool_info_list = [] + tools_map = self._collect_tools_with_plugins() + schema_map = self._get_tool_schema_map(tools_map) for tool_call in tool_calls_data: function_name = tool_call.get("function", {}).get("name", "") @@ -1859,7 +1778,7 @@ class AIChat(PluginBase): try: arguments = json.loads(arguments_str) - except: + except Exception: arguments = {} # 如果是图生图工具,添加图片 base64 @@ -1867,9 +1786,27 @@ class AIChat(PluginBase): arguments["image_base64"] = image_base64 logger.info(f"[异步-图片] 图生图工具,已添加图片数据") + schema = schema_map.get(function_name) + ok, err, arguments = self._validate_tool_arguments(function_name, arguments, schema) + if not ok: + logger.warning(f"[异步-图片] 工具 {function_name} 参数校验失败: {err}") + try: + await bot.send_text(from_wxid, f"❌ 工具 {function_name} 参数错误: {err}") + except Exception: + pass + continue + logger.info(f"[异步-图片] 准备执行工具: {function_name}") - task = self._execute_tool_and_get_result(function_name, arguments, bot, from_wxid) + task = self._execute_tool_and_get_result( + function_name, + arguments, + bot, + from_wxid, + user_wxid=user_wxid, + is_group=is_group, + tools_map=tools_map, + ) tasks.append(task) tool_info_list.append({ "tool_call_id": tool_call_id, @@ -1887,23 +1824,33 @@ class AIChat(PluginBase): if isinstance(result, Exception): logger.error(f"[异步-图片] 工具 {function_name} 执行异常: {result}") - await bot.send_text(from_wxid, f"❌ {function_name} 执行失败") + try: + await bot.send_text(from_wxid, f"❌ {function_name} 执行失败: {result}") + except Exception: + pass continue - if result and result.get("success"): + tool_result = ToolResult.from_raw(result) + if not tool_result: + continue + + if tool_result.success: logger.success(f"[异步-图片] 工具 {function_name} 执行成功") - - if not result.get("already_sent") and result.get("message"): - msg = result.get("message", "") - if msg and not result.get("no_reply") and result.get("send_result_text"): - await bot.send_text(from_wxid, msg) - - if result.get("save_to_memory") and chat_id: - self._add_to_memory(chat_id, "assistant", f"[工具 {function_name} 结果]: {result.get('message', '')}") else: - logger.warning(f"[异步-图片] 工具 {function_name} 执行失败: {result}") - if result and result.get("message"): - await bot.send_text(from_wxid, f"❌ {result.get('message')}") + logger.warning(f"[异步-图片] 工具 {function_name} 执行失败") + + if tool_result.success and not tool_result.already_sent and tool_result.message and not tool_result.no_reply: + if tool_result.send_result_text: + await bot.send_text(from_wxid, tool_result.message) + + if not tool_result.success and tool_result.message and not tool_result.no_reply: + try: + await bot.send_text(from_wxid, f"❌ {tool_result.message}") + except Exception: + pass + + if tool_result.save_to_memory and chat_id: + self._add_to_memory(chat_id, "assistant", f"[工具 {function_name} 结果]: {tool_result.message}") logger.info(f"[异步-图片] 所有工具执行完成") @@ -2125,7 +2072,7 @@ class AIChat(PluginBase): with open("main_config.toml", "rb") as f: main_config = tomllib.load(f) bot_nickname = main_config.get("Bot", {}).get("nickname", "机器人") - await self._add_to_history(from_wxid, bot_nickname, response) + await self._add_to_history(from_wxid, bot_nickname, response, role="assistant") logger.success(f"AI回复成功: {response[:50]}...") return False @@ -2243,7 +2190,7 @@ class AIChat(PluginBase): with open("main_config.toml", "rb") as f: main_config = tomllib.load(f) bot_nickname = main_config.get("Bot", {}).get("nickname", "机器人") - await self._add_to_history(from_wxid, bot_nickname, response) + await self._add_to_history(from_wxid, bot_nickname, response, role="assistant") logger.success(f"[聊天记录] AI 回复成功: {response[:50]}...") else: await bot.send_text(from_wxid, "❌ AI 回复生成失败") @@ -2318,7 +2265,7 @@ class AIChat(PluginBase): await self._add_to_history(from_wxid, nickname, f"[发送了一个视频] {user_question}") # 调用主AI生成回复(使用现有的 _call_ai_api 方法,继承完整上下文) - response = await self._call_ai_api(combined_message, chat_id, from_wxid, is_group, nickname) + response = await self._call_ai_api(combined_message, bot, from_wxid, chat_id, nickname, user_wxid, is_group) if response: await bot.send_text(from_wxid, response) @@ -2329,7 +2276,7 @@ class AIChat(PluginBase): with open("main_config.toml", "rb") as f: main_config = tomllib.load(f) bot_nickname = main_config.get("Bot", {}).get("nickname", "机器人") - await self._add_to_history(from_wxid, bot_nickname, response) + await self._add_to_history(from_wxid, bot_nickname, response, role="assistant") logger.success(f"[视频识别] 主AI回复成功: {response[:50]}...") else: await bot.send_text(from_wxid, "❌ AI 回复生成失败") @@ -2882,27 +2829,40 @@ class AIChat(PluginBase): if nickname: system_content += f"\n当前对话用户的昵称是:{nickname}" + # 加载持久记忆(与文本模式一致) + memory_chat_id = from_wxid if is_group else user_wxid + if memory_chat_id: + persistent_memories = self._get_persistent_memories(memory_chat_id) + if persistent_memories: + system_content += "\n\n【持久记忆】以下是用户要求你记住的重要信息:\n" + for m in persistent_memories: + mem_time = m['time'][:10] if m['time'] else "" + system_content += f"- [{mem_time}] {m['nickname']}: {m['content']}\n" + messages = [{"role": "system", "content": system_content}] - # 添加历史记忆 - if chat_id: - memory_messages = self._get_memory_messages(chat_id) - if memory_messages and len(memory_messages) > 1: - messages.extend(memory_messages[:-1]) + # 添加历史上下文 + if is_group and from_wxid: + history = await self._load_history(from_wxid) + max_context = self.config.get("history", {}).get("max_context", 50) + recent_history = history[-max_context:] if len(history) > max_context else history + self._append_group_history_messages(messages, recent_history) + else: + if chat_id: + memory_messages = self._get_memory_messages(chat_id) + if memory_messages and len(memory_messages) > 1: + messages.extend(memory_messages[:-1]) # 添加当前用户消息(带图片) + text_value = f"[{nickname}] {user_message}" if is_group and nickname else user_message messages.append({ "role": "user", "content": [ - {"type": "text", "text": user_message}, + {"type": "text", "text": text_value}, {"type": "image_url", "image_url": {"url": image_base64}} ] }) - # 保存用户信息供工具调用使用 - self._current_user_wxid = user_wxid - self._current_is_group = is_group - payload = { "model": api_config["model"], "messages": messages, @@ -3031,7 +2991,7 @@ class AIChat(PluginBase): asyncio.create_task( self._execute_tools_async_with_image( tool_calls_data, bot, from_wxid, chat_id, - nickname, is_group, messages, image_base64 + user_wxid, nickname, is_group, messages, image_base64 ) ) return "" @@ -3211,14 +3171,25 @@ class AIChat(PluginBase): while True: try: task = await self.image_desc_queue.get() + except asyncio.CancelledError: + logger.info("图片描述工作协程收到取消信号,退出") + break + + try: await self._generate_and_update_image_description( task["bot"], task["from_wxid"], task["nickname"], task["cdnbigimgurl"], task["aeskey"], task["is_emoji"], task["placeholder_id"], task["config"] ) - self.image_desc_queue.task_done() + except asyncio.CancelledError: + raise except Exception as e: logger.error(f"图片描述工作协程异常: {e}") + finally: + try: + self.image_desc_queue.task_done() + except ValueError: + pass async def _generate_and_update_image_description(self, bot, from_wxid: str, nickname: str, cdnbigimgurl: str, aeskey: str, is_emoji: bool, @@ -3247,6 +3218,8 @@ class AIChat(PluginBase): await self._update_history_by_id(from_wxid, placeholder_id, "[图片]") logger.warning(f"图片描述生成失败") + except asyncio.CancelledError: + raise except Exception as e: logger.error(f"异步生成图片描述失败: {e}") await self._update_history_by_id(from_wxid, placeholder_id, "[图片]") diff --git a/plugins/AIChat_Gemini.zip b/plugins/AIChat_Gemini.zip deleted file mode 100644 index 5ddb892..0000000 Binary files a/plugins/AIChat_Gemini.zip and /dev/null differ diff --git a/plugins/AutoReply/main.py b/plugins/AutoReply/main.py index e5a9de9..ac7b7d2 100644 --- a/plugins/AutoReply/main.py +++ b/plugins/AutoReply/main.py @@ -115,6 +115,19 @@ class AutoReply(PluginBase): logger.error(f"[AutoReply] 初始化失败: {e}") self.config = None + async def on_disable(self): + """插件禁用时调用,清理后台判断任务""" + await super().on_disable() + + if self.pending_tasks: + for task in self.pending_tasks.values(): + task.cancel() + await asyncio.gather(*self.pending_tasks.values(), return_exceptions=True) + self.pending_tasks.clear() + + self.judging.clear() + logger.info("[AutoReply] 已清理后台判断任务") + def _load_bot_info(self): """加载机器人信息""" try: @@ -294,6 +307,8 @@ class AutoReply(PluginBase): # 直接调用 AIChat 生成回复(基于最新上下文) await self._trigger_ai_reply(bot, pending.from_wxid) + except asyncio.CancelledError: + raise except Exception as e: logger.error(f"[AutoReply] 后台判断异常: {e}") import traceback @@ -316,8 +331,7 @@ class AutoReply(PluginBase): return # 获取最新的历史记录作为上下文 - chat_id = self._normalize_chat_id(from_wxid) - recent_context = await self._get_recent_context_for_reply(chat_id) + recent_context = await self._get_recent_context_for_reply(from_wxid) if not recent_context: logger.warning("[AutoReply] 无法获取上下文") @@ -342,10 +356,10 @@ class AutoReply(PluginBase): import traceback logger.error(traceback.format_exc()) - async def _get_recent_context_for_reply(self, chat_id: str) -> str: + async def _get_recent_context_for_reply(self, group_id: str) -> str: """获取最近的上下文用于生成回复""" try: - history = await self._get_history(chat_id) + history = await self._get_history(group_id) if not history: return "" @@ -353,35 +367,10 @@ class AutoReply(PluginBase): count = self.config.get('context', {}).get('messages_count', 5) recent = history[-count:] if len(history) > count else history - # 构建上下文摘要 - context_lines = [] - for record in recent: - nickname = record.get('nickname', '未知') - content = record.get('content', '') - if isinstance(content, list): - # 多模态内容,提取文本 - for item in content: - if item.get('type') == 'text': - content = item.get('text', '') - break - else: - content = '[图片]' - if len(content) > 50: - content = content[:50] + "..." - context_lines.append(f"{nickname}: {content}") - - # 返回最后一条消息作为触发内容(AIChat 会读取完整历史) - if recent: - last = recent[-1] - last_content = last.get('content', '') - if isinstance(last_content, list): - for item in last_content: - if item.get('type') == 'text': - return item.get('text', '') - return '[图片]' - return last_content - - return "" + # 自动回复触发不再把最后一条用户消息再次发给 AI, + # 避免在上下文里出现“同一句话重复两遍”的错觉。 + # AIChat 会读取完整历史 recent_history。 + return "(自动回复触发)请基于最近群聊内容,自然地回复一句,不要复述提示本身。" except Exception as e: logger.error(f"[AutoReply] 获取上下文失败: {e}") @@ -389,12 +378,13 @@ class AutoReply(PluginBase): async def _judge_with_small_model(self, from_wxid: str, content: str) -> JudgeResult: """使用小模型判断是否需要回复""" - chat_id = self._normalize_chat_id(from_wxid) - chat_state = self._get_chat_state(chat_id) + group_id = from_wxid + state_id = self._normalize_chat_id(group_id) + chat_state = self._get_chat_state(state_id) # 获取最近消息历史 - recent_messages = await self._get_recent_messages(chat_id) - last_bot_reply = await self._get_last_bot_reply(chat_id) + recent_messages = await self._get_recent_messages(group_id) + last_bot_reply = await self._get_last_bot_reply(group_id) # 构建判断提示词 reasoning_part = ',\n "reasoning": "简短分析原因(20字内)"' if self.config["judge"]["include_reasoning"] else "" @@ -403,7 +393,7 @@ class AutoReply(PluginBase): ## 当前状态 - 精力: {chat_state.energy:.1f}/1.0 -- 上次发言: {self._get_minutes_since_last_reply(chat_id)}分钟前 +- 上次发言: {self._get_minutes_since_last_reply(state_id)}分钟前 ## 最近对话 {recent_messages} @@ -531,6 +521,13 @@ class AutoReply(PluginBase): if not aichat_plugin: return [] + # 优先使用 AIChat 的统一 ContextStore + if hasattr(aichat_plugin, "store") and aichat_plugin.store: + try: + return await aichat_plugin.store.load_group_history(chat_id) + except Exception as e: + logger.debug(f"[AutoReply] ContextStore 获取历史失败: {e}") + # 优先使用 Redis(与 AIChat 保持一致) try: from utils.redis_cache import get_cache @@ -548,7 +545,8 @@ class AutoReply(PluginBase): # 降级到文件存储 if hasattr(aichat_plugin, 'history_dir') and aichat_plugin.history_dir: - history_file = aichat_plugin.history_dir / f"{chat_id}.json" + safe_id = (chat_id or "").replace("@", "_").replace(":", "_") + history_file = aichat_plugin.history_dir / f"{safe_id}.json" if history_file.exists(): with open(history_file, "r", encoding="utf-8") as f: return json.load(f) @@ -558,10 +556,10 @@ class AutoReply(PluginBase): return [] - async def _get_recent_messages(self, chat_id: str) -> str: - """获取最近消息历史""" + async def _get_recent_messages(self, group_id: str) -> str: + """获取最近消息历史(群聊)""" try: - history = await self._get_history(chat_id) + history = await self._get_history(group_id) if not history: return "暂无对话历史" @@ -572,6 +570,12 @@ class AutoReply(PluginBase): for record in recent: nickname = record.get('nickname', '未知') content = record.get('content', '') + if isinstance(content, list): + text_parts = [] + for item in content: + if item.get("type") == "text": + text_parts.append(item.get("text", "")) + content = "".join(text_parts).strip() or "[图片]" # 限制单条消息长度 if len(content) > 100: content = content[:100] + "..." @@ -584,17 +588,23 @@ class AutoReply(PluginBase): return "暂无对话历史" - async def _get_last_bot_reply(self, chat_id: str) -> Optional[str]: - """获取上次机器人回复""" + async def _get_last_bot_reply(self, group_id: str) -> Optional[str]: + """获取上次机器人回复(群聊)""" try: - history = await self._get_history(chat_id) + history = await self._get_history(group_id) if not history: return None # 从后往前查找机器人回复 for record in reversed(history): - if record.get('nickname') == self.bot_nickname: + if record.get('role') == 'assistant' or record.get('nickname') == self.bot_nickname: content = record.get('content', '') + if isinstance(content, list): + text_parts = [] + for item in content: + if item.get("type") == "text": + text_parts.append(item.get("text", "")) + content = "".join(text_parts).strip() or "[图片]" if len(content) > 100: content = content[:100] + "..." return content diff --git a/plugins/SignInPlugin/main.py b/plugins/SignInPlugin/main.py index 6010334..187601d 100644 --- a/plugins/SignInPlugin/main.py +++ b/plugins/SignInPlugin/main.py @@ -1606,9 +1606,8 @@ class SignInPlugin(PluginBase): return {"success": True, "message": f"城市注册请求已处理: {city}"} else: - return {"success": False, "message": "未知的工具名称"} + return None except Exception as e: logger.error(f"LLM工具执行失败: {e}") return {"success": False, "message": f"执行失败: {str(e)}"} - diff --git a/plugins/Weather/main.py b/plugins/Weather/main.py index e79c905..53d87ce 100644 --- a/plugins/Weather/main.py +++ b/plugins/Weather/main.py @@ -322,7 +322,7 @@ class WeatherPlugin(PluginBase): """执行LLM工具调用,供AIChat插件调用""" try: if tool_name != "query_weather": - return {"success": False, "message": "未知的工具名称"} + return None # 从 arguments 中获取用户信息 user_wxid = arguments.get("user_wxid", from_wxid) diff --git a/plugins/ZImageTurbo/main.py b/plugins/ZImageTurbo/main.py index a367202..41e17e5 100644 --- a/plugins/ZImageTurbo/main.py +++ b/plugins/ZImageTurbo/main.py @@ -345,7 +345,7 @@ class ZImageTurbo(PluginBase): async def execute_llm_tool(self, tool_name: str, arguments: dict, bot: WechatHookClient, from_wxid: str) -> dict: """执行LLM工具调用,供AIChat插件调用""" if tool_name != "generate_image": - return {"success": False, "message": "未知的工具名称"} + return None try: prompt = arguments.get("prompt", "") diff --git a/utils/context_store.py b/utils/context_store.py new file mode 100644 index 0000000..f00e484 --- /dev/null +++ b/utils/context_store.py @@ -0,0 +1,470 @@ +""" +上下文/存储统一封装 + +提供统一的会话上下文读写接口: +- 私聊/单人会话 memory(优先 Redis,降级内存) +- 群聊 history(优先 Redis,降级文件) +- 持久记忆 sqlite + +AIChat 只需要通过本模块读写消息,不再关心介质细节。 +""" + +from __future__ import annotations + +import asyncio +import json +import sqlite3 +from dataclasses import dataclass +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, List, Optional + +from loguru import logger + +from utils.redis_cache import get_cache + + +def _safe_chat_id(chat_id: str) -> str: + return (chat_id or "").replace("@", "_").replace(":", "_") + + +def _extract_text_from_multimodal(content: Any) -> str: + if isinstance(content, str): + return content + if isinstance(content, list): + parts = [] + for item in content: + if isinstance(item, dict) and item.get("type") == "text": + parts.append(item.get("text", "")) + return "".join(parts).strip() + return str(content) + + +@dataclass +class HistoryRecord: + role: str = "user" + nickname: str = "" + content: Any = "" + timestamp: Any = None + wxid: Optional[str] = None + id: Optional[str] = None + + @classmethod + def from_raw(cls, raw: Dict[str, Any]) -> "HistoryRecord": + role = raw.get("role") or "user" + nickname = raw.get("nickname") or raw.get("SenderNickname") or "" + content = raw.get("content") if "content" in raw else raw.get("Content", "") + ts = raw.get("timestamp") or raw.get("time") or raw.get("CreateTime") + wxid = raw.get("wxid") or raw.get("SenderWxid") + rid = raw.get("id") or raw.get("msgid") + return cls(role=role, nickname=nickname, content=content, timestamp=ts, wxid=wxid, id=rid) + + def to_dict(self) -> Dict[str, Any]: + d = { + "role": self.role or "user", + "nickname": self.nickname, + "content": self.content, + } + if self.timestamp is not None: + d["timestamp"] = self.timestamp + if self.wxid: + d["wxid"] = self.wxid + if self.id: + d["id"] = self.id + return d + + +class ContextStore: + """ + 统一上下文存储。 + + Args: + config: AIChat 配置 dict + history_dir: 历史文件目录(群聊降级) + memory_fallback: AIChat 内存 dict(私聊降级) + history_locks: AIChat locks dict(文件写入) + persistent_db_path: sqlite 文件路径 + """ + + def __init__( + self, + config: Dict[str, Any], + history_dir: Optional[Path], + memory_fallback: Dict[str, List[Dict[str, Any]]], + history_locks: Dict[str, asyncio.Lock], + persistent_db_path: Optional[Path], + ): + self.config = config or {} + self.history_dir = history_dir + self.memory_fallback = memory_fallback + self.history_locks = history_locks + self.persistent_db_path = persistent_db_path + + # ------------------ 私聊 memory ------------------ + + def _use_redis_for_memory(self) -> bool: + redis_config = self.config.get("redis", {}) + if not redis_config.get("use_redis_history", True): + return False + redis_cache = get_cache() + return bool(redis_cache and redis_cache.enabled) + + def add_private_message( + self, + chat_id: str, + role: str, + content: Any, + *, + image_base64: str = None, + nickname: str = "", + sender_wxid: str = None, + ) -> None: + if not self.config.get("memory", {}).get("enabled", False): + return + + if image_base64: + message_content = [ + {"type": "text", "text": _extract_text_from_multimodal(content)}, + {"type": "image_url", "image_url": {"url": image_base64}}, + ] + else: + message_content = content + + redis_config = self.config.get("redis", {}) + if self._use_redis_for_memory(): + redis_cache = get_cache() + ttl = redis_config.get("chat_history_ttl", 86400) + try: + redis_cache.add_chat_message( + chat_id, + role, + message_content, + nickname=nickname, + sender_wxid=sender_wxid, + ttl=ttl, + ) + max_messages = self.config.get("memory", {}).get("max_messages", 20) + redis_cache.trim_chat_history(chat_id, max_messages) + return + except Exception as e: + logger.debug(f"[ContextStore] Redis private history 写入失败: {e}") + + if chat_id not in self.memory_fallback: + self.memory_fallback[chat_id] = [] + self.memory_fallback[chat_id].append({"role": role, "content": message_content}) + max_messages = self.config.get("memory", {}).get("max_messages", 20) + if len(self.memory_fallback[chat_id]) > max_messages: + self.memory_fallback[chat_id] = self.memory_fallback[chat_id][-max_messages:] + + def get_private_messages(self, chat_id: str) -> List[Dict[str, Any]]: + if not self.config.get("memory", {}).get("enabled", False): + return [] + + if self._use_redis_for_memory(): + redis_cache = get_cache() + max_messages = self.config.get("memory", {}).get("max_messages", 20) + try: + history = redis_cache.get_chat_history(chat_id, max_messages) + return [HistoryRecord.from_raw(h).to_dict() for h in history] + except Exception as e: + logger.debug(f"[ContextStore] Redis private history 读取失败: {e}") + + return self.memory_fallback.get(chat_id, []) + + def clear_private_messages(self, chat_id: str) -> None: + if self._use_redis_for_memory(): + redis_cache = get_cache() + try: + redis_cache.clear_chat_history(chat_id) + except Exception: + pass + self.memory_fallback.pop(chat_id, None) + + # ------------------ 群聊 history ------------------ + + def _use_redis_for_group_history(self) -> bool: + redis_config = self.config.get("redis", {}) + if not redis_config.get("use_redis_history", True): + return False + redis_cache = get_cache() + return bool(redis_cache and redis_cache.enabled) + + def _get_history_file(self, chat_id: str) -> Optional[Path]: + if not self.history_dir: + return None + return self.history_dir / f"{_safe_chat_id(chat_id)}.json" + + def _get_history_lock(self, chat_id: str) -> asyncio.Lock: + lock = self.history_locks.get(chat_id) + if lock is None: + lock = asyncio.Lock() + self.history_locks[chat_id] = lock + return lock + + def _read_history_file(self, history_file: Path) -> List[Dict[str, Any]]: + try: + with open(history_file, "r", encoding="utf-8") as f: + return json.load(f) + except FileNotFoundError: + return [] + except Exception as e: + logger.error(f"读取历史记录失败: {history_file}, {e}") + return [] + + def _write_history_file(self, history_file: Path, history: List[Dict[str, Any]]) -> None: + history_file.parent.mkdir(parents=True, exist_ok=True) + temp_file = Path(str(history_file) + ".tmp") + with open(temp_file, "w", encoding="utf-8") as f: + json.dump(history, f, ensure_ascii=False, indent=2) + temp_file.replace(history_file) + + async def load_group_history(self, chat_id: str) -> List[Dict[str, Any]]: + if not self.config.get("history", {}).get("enabled", True): + return [] + + if self._use_redis_for_group_history(): + redis_cache = get_cache() + max_history = self.config.get("history", {}).get("max_history", 100) + try: + history = redis_cache.get_group_history(chat_id, max_history) + return [HistoryRecord.from_raw(h).to_dict() for h in history] + except Exception as e: + logger.debug(f"[ContextStore] Redis group history 读取失败: {e}") + + history_file = self._get_history_file(chat_id) + if not history_file: + return [] + + lock = self._get_history_lock(chat_id) + async with lock: + raw_history = self._read_history_file(history_file) + return [HistoryRecord.from_raw(h).to_dict() for h in raw_history] + + async def add_group_message( + self, + chat_id: str, + nickname: str, + content: Any, + *, + record_id: str = None, + image_base64: str = None, + role: str = "user", + sender_wxid: str = None, + ) -> None: + if not self.config.get("history", {}).get("enabled", True): + return + + if image_base64: + message_content = [ + {"type": "text", "text": _extract_text_from_multimodal(content)}, + {"type": "image_url", "image_url": {"url": image_base64}}, + ] + else: + message_content = content + + if self._use_redis_for_group_history(): + redis_cache = get_cache() + redis_config = self.config.get("redis", {}) + ttl = redis_config.get("group_history_ttl", 172800) + try: + redis_cache.add_group_message( + chat_id, + nickname, + message_content, + record_id=record_id, + role=role, + sender_wxid=sender_wxid, + ttl=ttl, + ) + max_history = self.config.get("history", {}).get("max_history", 100) + redis_cache.trim_group_history(chat_id, max_history) + return + except Exception as e: + logger.debug(f"[ContextStore] Redis group history 写入失败: {e}") + + history_file = self._get_history_file(chat_id) + if not history_file: + return + + lock = self._get_history_lock(chat_id) + async with lock: + history = self._read_history_file(history_file) + record = HistoryRecord( + role=role or "user", + nickname=nickname, + content=message_content, + timestamp=datetime.now().isoformat(), + wxid=sender_wxid, + id=record_id, + ) + history.append(record.to_dict()) + max_history = self.config.get("history", {}).get("max_history", 100) + if len(history) > max_history: + history = history[-max_history:] + self._write_history_file(history_file, history) + + async def update_group_message_by_id(self, chat_id: str, record_id: str, new_content: Any) -> None: + if not self.config.get("history", {}).get("enabled", True): + return + + if self._use_redis_for_group_history(): + redis_cache = get_cache() + try: + redis_cache.update_group_message_by_id(chat_id, record_id, new_content) + return + except Exception as e: + logger.debug(f"[ContextStore] Redis group history 更新失败: {e}") + + history_file = self._get_history_file(chat_id) + if not history_file: + return + + lock = self._get_history_lock(chat_id) + async with lock: + history = self._read_history_file(history_file) + for rec in history: + if rec.get("id") == record_id: + rec["content"] = new_content + break + max_history = self.config.get("history", {}).get("max_history", 100) + if len(history) > max_history: + history = history[-max_history:] + self._write_history_file(history_file, history) + + # ------------------ 持久记忆 sqlite ------------------ + + def init_persistent_memory_db(self) -> Optional[Path]: + if not self.persistent_db_path: + return None + + self.persistent_db_path.parent.mkdir(exist_ok=True, parents=True) + conn = sqlite3.connect(self.persistent_db_path) + cursor = conn.cursor() + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS memories ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + chat_id TEXT NOT NULL, + chat_type TEXT NOT NULL, + user_wxid TEXT NOT NULL, + user_nickname TEXT, + content TEXT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """ + ) + cursor.execute("CREATE INDEX IF NOT EXISTS idx_chat_id ON memories(chat_id)") + conn.commit() + conn.close() + logger.info(f"持久记忆数据库已初始化: {self.persistent_db_path}") + return self.persistent_db_path + + def add_persistent_memory( + self, + chat_id: str, + chat_type: str, + user_wxid: str, + user_nickname: str, + content: str, + ) -> int: + if not self.persistent_db_path: + return -1 + + conn = sqlite3.connect(self.persistent_db_path) + cursor = conn.cursor() + cursor.execute( + """ + INSERT INTO memories (chat_id, chat_type, user_wxid, user_nickname, content) + VALUES (?, ?, ?, ?, ?) + """, + (chat_id, chat_type, user_wxid, user_nickname, content), + ) + memory_id = cursor.lastrowid + conn.commit() + conn.close() + return memory_id + + def get_persistent_memories(self, chat_id: str) -> List[Dict[str, Any]]: + if not self.persistent_db_path: + return [] + + conn = sqlite3.connect(self.persistent_db_path) + cursor = conn.cursor() + cursor.execute( + """ + SELECT id, user_nickname, content, created_at + FROM memories + WHERE chat_id = ? + ORDER BY created_at ASC + """, + (chat_id,), + ) + rows = cursor.fetchall() + conn.close() + return [ + {"id": r[0], "nickname": r[1], "content": r[2], "time": r[3]} + for r in rows + ] + + def delete_persistent_memory(self, chat_id: str, memory_id: int) -> bool: + if not self.persistent_db_path: + return False + + conn = sqlite3.connect(self.persistent_db_path) + cursor = conn.cursor() + cursor.execute( + "DELETE FROM memories WHERE id = ? AND chat_id = ?", + (memory_id, chat_id), + ) + deleted = cursor.rowcount > 0 + conn.commit() + conn.close() + return deleted + + def clear_persistent_memories(self, chat_id: str) -> int: + if not self.persistent_db_path: + return 0 + + conn = sqlite3.connect(self.persistent_db_path) + cursor = conn.cursor() + cursor.execute("DELETE FROM memories WHERE chat_id = ?", (chat_id,)) + deleted_count = cursor.rowcount + conn.commit() + conn.close() + return deleted_count + + # ------------------ 旧数据扫描/清理 ------------------ + + def find_legacy_group_history_keys(self) -> List[str]: + """ + 发现旧版本使用 safe_id 写入的 group_history key。 + + Returns: + legacy_keys 列表(不删除) + """ + redis_cache = get_cache() + if not redis_cache or not redis_cache.enabled: + return [] + + try: + keys = redis_cache.client.keys("group_history:*") + legacy = [] + for k in keys or []: + # 新 key 一般包含 @chatroom;旧 safe_id 不包含 @ + if "@chatroom" not in k and "_" in k: + legacy.append(k) + return legacy + except Exception as e: + logger.debug(f"[ContextStore] 扫描 legacy group_history keys 失败: {e}") + return [] + + def delete_legacy_group_history_keys(self, legacy_keys: List[str]) -> int: + """删除给定 legacy key 列表""" + redis_cache = get_cache() + if not redis_cache or not redis_cache.enabled or not legacy_keys: + return 0 + try: + return redis_cache.client.delete(*legacy_keys) + except Exception as e: + logger.debug(f"[ContextStore] 删除 legacy group_history keys 失败: {e}") + return 0 diff --git a/utils/llm_tooling.py b/utils/llm_tooling.py new file mode 100644 index 0000000..a8d845c --- /dev/null +++ b/utils/llm_tooling.py @@ -0,0 +1,183 @@ +""" +LLM 工具体系公共模块 + +统一工具收集、参数校验与执行结果结构,供 AIChat 等插件使用。 +""" + +from __future__ import annotations + +import json +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple + +from loguru import logger + + +@dataclass +class ToolResult: + """统一的工具执行结果结构""" + + success: bool = True + message: str = "" + need_ai_reply: bool = False + already_sent: bool = False + send_result_text: bool = False + no_reply: bool = False + save_to_memory: bool = False + + @classmethod + def from_raw(cls, raw: Any) -> Optional["ToolResult"]: + if raw is None: + return None + + if not isinstance(raw, dict): + return cls(success=True, message=str(raw)) + + msg = raw.get("message", "") + if not isinstance(msg, str): + try: + msg = json.dumps(msg, ensure_ascii=False) + except Exception: + msg = str(msg) + + return cls( + success=bool(raw.get("success", True)), + message=msg, + need_ai_reply=bool(raw.get("need_ai_reply", False)), + already_sent=bool(raw.get("already_sent", False)), + send_result_text=bool(raw.get("send_result_text", False)), + no_reply=bool(raw.get("no_reply", False)), + save_to_memory=bool(raw.get("save_to_memory", False)), + ) + + +def collect_tools_with_plugins( + tools_config: Dict[str, Any], + plugins: Dict[str, Any], +) -> Dict[str, Tuple[str, Dict[str, Any]]]: + """ + 收集所有插件的 LLM 工具,并保留来源插件名。 + + Args: + tools_config: AIChat 配置中的 [tools] 节 + plugins: PluginManager().plugins 映射 + + Returns: + {tool_name: (plugin_name, tool_dict)} + """ + tools_by_name: Dict[str, Tuple[str, Dict[str, Any]]] = {} + + mode = tools_config.get("mode", "all") + whitelist = set(tools_config.get("whitelist", [])) + blacklist = set(tools_config.get("blacklist", [])) + + for plugin_name, plugin in plugins.items(): + if not hasattr(plugin, "get_llm_tools"): + continue + + plugin_tools = plugin.get_llm_tools() or [] + for tool in plugin_tools: + tool_name = tool.get("function", {}).get("name", "") + if not tool_name: + continue + + if mode == "whitelist" and tool_name not in whitelist: + continue + if mode == "blacklist" and tool_name in blacklist: + logger.debug(f"[黑名单] 禁用工具: {tool_name}") + continue + + if tool_name in tools_by_name: + logger.warning(f"重复工具名 {tool_name} 来自 {plugin_name},已忽略") + continue + + tools_by_name[tool_name] = (plugin_name, tool) + if mode == "whitelist": + logger.debug(f"[白名单] 启用工具: {tool_name}") + + return tools_by_name + + +def collect_tools( + tools_config: Dict[str, Any], + plugins: Dict[str, Any], +) -> List[Dict[str, Any]]: + """仅返回工具定义列表""" + return [item[1] for item in collect_tools_with_plugins(tools_config, plugins).values()] + + +def get_tool_schema_map( + tools_map: Dict[str, Tuple[str, Dict[str, Any]]], +) -> Dict[str, Dict[str, Any]]: + """构建工具名到参数 schema 的映射""" + schema_map: Dict[str, Dict[str, Any]] = {} + for name, (_plugin_name, tool) in tools_map.items(): + fn = tool.get("function", {}) + schema_map[name] = fn.get("parameters", {}) or {} + return schema_map + + +def validate_tool_arguments( + tool_name: str, + arguments: Dict[str, Any], + schema: Optional[Dict[str, Any]], +) -> Tuple[bool, str, Dict[str, Any]]: + """ + 轻量校验并补全默认参数。 + + Returns: + (ok, error_message, new_arguments) + """ + if not schema: + return True, "", arguments + + props = schema.get("properties", {}) or {} + required = schema.get("required", []) or [] + + # 应用默认值 + for key, prop in props.items(): + if key not in arguments and isinstance(prop, dict) and "default" in prop: + arguments[key] = prop["default"] + + missing = [] + for key in required: + if key not in arguments or arguments[key] in (None, "", []): + missing.append(key) + + if missing: + return False, f"缺少参数: {', '.join(missing)}", arguments + + # 枚举与基础类型校验 + for key, prop in props.items(): + if key not in arguments or not isinstance(prop, dict): + continue + + value = arguments[key] + + if "enum" in prop and value not in prop["enum"]: + return False, f"参数 {key} 必须是 {prop['enum']}", arguments + + expected_type = prop.get("type") + if expected_type == "integer": + try: + arguments[key] = int(value) + except Exception: + return False, f"参数 {key} 应为整数", arguments + elif expected_type == "number": + try: + arguments[key] = float(value) + except Exception: + return False, f"参数 {key} 应为数字", arguments + elif expected_type == "boolean": + if isinstance(value, bool): + continue + if isinstance(value, str) and value.lower() in ("true", "false", "1", "0"): + arguments[key] = value.lower() in ("true", "1") + else: + return False, f"参数 {key} 应为布尔值", arguments + elif expected_type == "string": + if not isinstance(value, str): + arguments[key] = str(value) + + return True, "", arguments + diff --git a/utils/redis_cache.py b/utils/redis_cache.py index 2615eb5..04a849b 100644 --- a/utils/redis_cache.py +++ b/utils/redis_cache.py @@ -322,7 +322,18 @@ class RedisCache: logger.error(f"获取对话历史失败: {chat_id}, {e}") return [] - def add_chat_message(self, chat_id: str, role: str, content, ttl: int = 86400) -> bool: + def add_chat_message( + self, + chat_id: str, + role: str, + content, + ttl: int = 86400, + *, + nickname: str = None, + sender_wxid: str = None, + record_id: str = None, + timestamp: float = None, + ) -> bool: """ 添加消息到对话历史 @@ -331,6 +342,10 @@ class RedisCache: role: 角色 (user/assistant) content: 消息内容(字符串或列表) ttl: 过期时间(秒),默认24小时 + nickname: 可选昵称(用于统一 schema) + sender_wxid: 可选发送者 wxid + record_id: 可选记录 ID + timestamp: 可选时间戳 Returns: 是否添加成功 @@ -340,7 +355,18 @@ class RedisCache: try: key = self._make_key("chat_history", chat_id) - message = {"role": role, "content": content} + import time as _time + message = { + "role": role or "user", + "content": content, + } + if nickname: + message["nickname"] = nickname + if sender_wxid: + message["wxid"] = sender_wxid + if record_id: + message["id"] = record_id + message["timestamp"] = timestamp or _time.time() self.client.rpush(key, json.dumps(message, ensure_ascii=False)) self.client.expire(key, ttl) return True @@ -416,8 +442,17 @@ class RedisCache: logger.error(f"获取群聊历史失败: {group_id}, {e}") return [] - def add_group_message(self, group_id: str, nickname: str, content, - record_id: str = None, ttl: int = 86400) -> bool: + def add_group_message( + self, + group_id: str, + nickname: str, + content, + record_id: str = None, + *, + role: str = "user", + sender_wxid: str = None, + ttl: int = 86400, + ) -> bool: """ 添加消息到群聊历史 @@ -426,6 +461,8 @@ class RedisCache: nickname: 发送者昵称 content: 消息内容 record_id: 可选的记录ID,用于后续更新 + role: 角色 (user/assistant),默认 user + sender_wxid: 可选的发送者 wxid ttl: 过期时间(秒),默认24小时 Returns: @@ -438,10 +475,13 @@ class RedisCache: import time key = self._make_key("group_history", group_id) message = { + "role": role or "user", "nickname": nickname, "content": content, "timestamp": time.time() } + if sender_wxid: + message["wxid"] = sender_wxid if record_id: message["id"] = record_id