""" 上下文/存储统一封装 提供统一的会话上下文读写接口: - 私聊/单人会话 memory(优先 Redis,降级内存) - 群聊 history(优先 Redis,降级文件) - 持久记忆 sqlite AIChat 只需要通过本模块读写消息,不再关心介质细节。 """ from __future__ import annotations import asyncio import json import sqlite3 from dataclasses import dataclass from datetime import datetime from pathlib import Path from typing import Any, Dict, List, Optional from loguru import logger from utils.redis_cache import get_cache def _safe_chat_id(chat_id: str) -> str: return (chat_id or "").replace("@", "_").replace(":", "_") def _extract_text_from_multimodal(content: Any) -> str: if isinstance(content, str): return content if isinstance(content, list): parts = [] for item in content: if isinstance(item, dict) and item.get("type") == "text": parts.append(item.get("text", "")) return "".join(parts).strip() return str(content) @dataclass class HistoryRecord: role: str = "user" nickname: str = "" content: Any = "" timestamp: Any = None wxid: Optional[str] = None id: Optional[str] = None @classmethod def from_raw(cls, raw: Dict[str, Any]) -> "HistoryRecord": role = raw.get("role") or "user" nickname = raw.get("nickname") or raw.get("SenderNickname") or "" content = raw.get("content") if "content" in raw else raw.get("Content", "") ts = raw.get("timestamp") or raw.get("time") or raw.get("CreateTime") wxid = raw.get("wxid") or raw.get("SenderWxid") rid = raw.get("id") or raw.get("msgid") return cls(role=role, nickname=nickname, content=content, timestamp=ts, wxid=wxid, id=rid) def to_dict(self) -> Dict[str, Any]: d = { "role": self.role or "user", "nickname": self.nickname, "content": self.content, } if self.timestamp is not None: d["timestamp"] = self.timestamp if self.wxid: d["wxid"] = self.wxid if self.id: d["id"] = self.id return d class ContextStore: """ 统一上下文存储。 Args: config: AIChat 配置 dict history_dir: 历史文件目录(群聊降级) memory_fallback: AIChat 内存 dict(私聊降级) history_locks: AIChat locks dict(文件写入) persistent_db_path: sqlite 文件路径 """ def __init__( self, config: Dict[str, Any], history_dir: Optional[Path], memory_fallback: Dict[str, List[Dict[str, Any]]], history_locks: Dict[str, asyncio.Lock], persistent_db_path: Optional[Path], ): self.config = config or {} self.history_dir = history_dir self.memory_fallback = memory_fallback self.history_locks = history_locks self.persistent_db_path = persistent_db_path # ------------------ 私聊 memory ------------------ def _use_redis_for_memory(self) -> bool: redis_config = self.config.get("redis", {}) if not redis_config.get("use_redis_history", True): return False redis_cache = get_cache() return bool(redis_cache and redis_cache.enabled) def add_private_message( self, chat_id: str, role: str, content: Any, *, image_base64: str = None, nickname: str = "", sender_wxid: str = None, ) -> None: if not self.config.get("memory", {}).get("enabled", False): return if image_base64: message_content = [ {"type": "text", "text": _extract_text_from_multimodal(content)}, {"type": "image_url", "image_url": {"url": image_base64}}, ] else: message_content = content redis_config = self.config.get("redis", {}) if self._use_redis_for_memory(): redis_cache = get_cache() ttl = redis_config.get("chat_history_ttl", 86400) try: redis_cache.add_chat_message( chat_id, role, message_content, nickname=nickname, sender_wxid=sender_wxid, ttl=ttl, ) max_messages = self.config.get("memory", {}).get("max_messages", 20) redis_cache.trim_chat_history(chat_id, max_messages) return except Exception as e: logger.debug(f"[ContextStore] Redis private history 写入失败: {e}") if chat_id not in self.memory_fallback: self.memory_fallback[chat_id] = [] self.memory_fallback[chat_id].append({"role": role, "content": message_content}) max_messages = self.config.get("memory", {}).get("max_messages", 20) if len(self.memory_fallback[chat_id]) > max_messages: self.memory_fallback[chat_id] = self.memory_fallback[chat_id][-max_messages:] def get_private_messages(self, chat_id: str) -> List[Dict[str, Any]]: if not self.config.get("memory", {}).get("enabled", False): return [] if self._use_redis_for_memory(): redis_cache = get_cache() max_messages = self.config.get("memory", {}).get("max_messages", 20) try: history = redis_cache.get_chat_history(chat_id, max_messages) return [HistoryRecord.from_raw(h).to_dict() for h in history] except Exception as e: logger.debug(f"[ContextStore] Redis private history 读取失败: {e}") return self.memory_fallback.get(chat_id, []) def clear_private_messages(self, chat_id: str) -> None: if self._use_redis_for_memory(): redis_cache = get_cache() try: redis_cache.clear_chat_history(chat_id) except Exception: pass self.memory_fallback.pop(chat_id, None) # ------------------ 群聊 history ------------------ def _use_redis_for_group_history(self) -> bool: redis_config = self.config.get("redis", {}) if not redis_config.get("use_redis_history", True): return False redis_cache = get_cache() return bool(redis_cache and redis_cache.enabled) def _get_history_file(self, chat_id: str) -> Optional[Path]: if not self.history_dir: return None return self.history_dir / f"{_safe_chat_id(chat_id)}.json" def _get_history_lock(self, chat_id: str) -> asyncio.Lock: lock = self.history_locks.get(chat_id) if lock is None: lock = asyncio.Lock() self.history_locks[chat_id] = lock return lock def _read_history_file(self, history_file: Path) -> List[Dict[str, Any]]: try: with open(history_file, "r", encoding="utf-8") as f: return json.load(f) except FileNotFoundError: return [] except Exception as e: logger.error(f"读取历史记录失败: {history_file}, {e}") return [] def _write_history_file(self, history_file: Path, history: List[Dict[str, Any]]) -> None: history_file.parent.mkdir(parents=True, exist_ok=True) temp_file = Path(str(history_file) + ".tmp") with open(temp_file, "w", encoding="utf-8") as f: json.dump(history, f, ensure_ascii=False, indent=2) temp_file.replace(history_file) async def load_group_history(self, chat_id: str) -> List[Dict[str, Any]]: if not self.config.get("history", {}).get("enabled", True): return [] if self._use_redis_for_group_history(): redis_cache = get_cache() max_history = self.config.get("history", {}).get("max_history", 100) try: history = redis_cache.get_group_history(chat_id, max_history) return [HistoryRecord.from_raw(h).to_dict() for h in history] except Exception as e: logger.debug(f"[ContextStore] Redis group history 读取失败: {e}") history_file = self._get_history_file(chat_id) if not history_file: return [] lock = self._get_history_lock(chat_id) async with lock: raw_history = self._read_history_file(history_file) return [HistoryRecord.from_raw(h).to_dict() for h in raw_history] async def add_group_message( self, chat_id: str, nickname: str, content: Any, *, record_id: str = None, image_base64: str = None, role: str = "user", sender_wxid: str = None, ) -> None: if not self.config.get("history", {}).get("enabled", True): return if image_base64: message_content = [ {"type": "text", "text": _extract_text_from_multimodal(content)}, {"type": "image_url", "image_url": {"url": image_base64}}, ] else: message_content = content if self._use_redis_for_group_history(): redis_cache = get_cache() redis_config = self.config.get("redis", {}) ttl = redis_config.get("group_history_ttl", 172800) try: redis_cache.add_group_message( chat_id, nickname, message_content, record_id=record_id, role=role, sender_wxid=sender_wxid, ttl=ttl, ) max_history = self.config.get("history", {}).get("max_history", 100) redis_cache.trim_group_history(chat_id, max_history) return except Exception as e: logger.debug(f"[ContextStore] Redis group history 写入失败: {e}") history_file = self._get_history_file(chat_id) if not history_file: return lock = self._get_history_lock(chat_id) async with lock: history = self._read_history_file(history_file) record = HistoryRecord( role=role or "user", nickname=nickname, content=message_content, timestamp=datetime.now().isoformat(), wxid=sender_wxid, id=record_id, ) history.append(record.to_dict()) max_history = self.config.get("history", {}).get("max_history", 100) if len(history) > max_history: history = history[-max_history:] self._write_history_file(history_file, history) async def update_group_message_by_id(self, chat_id: str, record_id: str, new_content: Any) -> None: if not self.config.get("history", {}).get("enabled", True): return if self._use_redis_for_group_history(): redis_cache = get_cache() try: redis_cache.update_group_message_by_id(chat_id, record_id, new_content) return except Exception as e: logger.debug(f"[ContextStore] Redis group history 更新失败: {e}") history_file = self._get_history_file(chat_id) if not history_file: return lock = self._get_history_lock(chat_id) async with lock: history = self._read_history_file(history_file) for rec in history: if rec.get("id") == record_id: rec["content"] = new_content break max_history = self.config.get("history", {}).get("max_history", 100) if len(history) > max_history: history = history[-max_history:] self._write_history_file(history_file, history) # ------------------ 持久记忆 sqlite ------------------ def init_persistent_memory_db(self) -> Optional[Path]: if not self.persistent_db_path: return None self.persistent_db_path.parent.mkdir(exist_ok=True, parents=True) conn = sqlite3.connect(self.persistent_db_path) cursor = conn.cursor() cursor.execute( """ CREATE TABLE IF NOT EXISTS memories ( id INTEGER PRIMARY KEY AUTOINCREMENT, chat_id TEXT NOT NULL, chat_type TEXT NOT NULL, user_wxid TEXT NOT NULL, user_nickname TEXT, content TEXT NOT NULL, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ) """ ) cursor.execute("CREATE INDEX IF NOT EXISTS idx_chat_id ON memories(chat_id)") conn.commit() conn.close() logger.info(f"持久记忆数据库已初始化: {self.persistent_db_path}") return self.persistent_db_path def add_persistent_memory( self, chat_id: str, chat_type: str, user_wxid: str, user_nickname: str, content: str, ) -> int: if not self.persistent_db_path: return -1 conn = sqlite3.connect(self.persistent_db_path) cursor = conn.cursor() cursor.execute( """ INSERT INTO memories (chat_id, chat_type, user_wxid, user_nickname, content) VALUES (?, ?, ?, ?, ?) """, (chat_id, chat_type, user_wxid, user_nickname, content), ) memory_id = cursor.lastrowid conn.commit() conn.close() return memory_id def get_persistent_memories(self, chat_id: str) -> List[Dict[str, Any]]: if not self.persistent_db_path: return [] conn = sqlite3.connect(self.persistent_db_path) cursor = conn.cursor() cursor.execute( """ SELECT id, user_nickname, content, created_at FROM memories WHERE chat_id = ? ORDER BY created_at ASC """, (chat_id,), ) rows = cursor.fetchall() conn.close() return [ {"id": r[0], "nickname": r[1], "content": r[2], "time": r[3]} for r in rows ] def delete_persistent_memory(self, chat_id: str, memory_id: int) -> bool: if not self.persistent_db_path: return False conn = sqlite3.connect(self.persistent_db_path) cursor = conn.cursor() cursor.execute( "DELETE FROM memories WHERE id = ? AND chat_id = ?", (memory_id, chat_id), ) deleted = cursor.rowcount > 0 conn.commit() conn.close() return deleted def clear_persistent_memories(self, chat_id: str) -> int: if not self.persistent_db_path: return 0 conn = sqlite3.connect(self.persistent_db_path) cursor = conn.cursor() cursor.execute("DELETE FROM memories WHERE chat_id = ?", (chat_id,)) deleted_count = cursor.rowcount conn.commit() conn.close() return deleted_count # ------------------ 旧数据扫描/清理 ------------------ def find_legacy_group_history_keys(self) -> List[str]: """ 发现旧版本使用 safe_id 写入的 group_history key。 Returns: legacy_keys 列表(不删除) """ redis_cache = get_cache() if not redis_cache or not redis_cache.enabled: return [] try: keys = redis_cache.client.keys("group_history:*") legacy = [] for k in keys or []: # 新 key 一般包含 @chatroom;旧 safe_id 不包含 @ if "@chatroom" not in k and "_" in k: legacy.append(k) return legacy except Exception as e: logger.debug(f"[ContextStore] 扫描 legacy group_history keys 失败: {e}") return [] def delete_legacy_group_history_keys(self, legacy_keys: List[str]) -> int: """删除给定 legacy key 列表""" redis_cache = get_cache() if not redis_cache or not redis_cache.enabled or not legacy_keys: return 0 try: return redis_cache.client.delete(*legacy_keys) except Exception as e: logger.debug(f"[ContextStore] 删除 legacy group_history keys 失败: {e}") return 0