471 lines
16 KiB
Python
471 lines
16 KiB
Python
"""
|
||
上下文/存储统一封装
|
||
|
||
提供统一的会话上下文读写接口:
|
||
- 私聊/单人会话 memory(优先 Redis,降级内存)
|
||
- 群聊 history(优先 Redis,降级文件)
|
||
- 持久记忆 sqlite
|
||
|
||
AIChat 只需要通过本模块读写消息,不再关心介质细节。
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import json
|
||
import sqlite3
|
||
from dataclasses import dataclass
|
||
from datetime import datetime
|
||
from pathlib import Path
|
||
from typing import Any, Dict, List, Optional
|
||
|
||
from loguru import logger
|
||
|
||
from utils.redis_cache import get_cache
|
||
|
||
|
||
def _safe_chat_id(chat_id: str) -> str:
|
||
return (chat_id or "").replace("@", "_").replace(":", "_")
|
||
|
||
|
||
def _extract_text_from_multimodal(content: Any) -> str:
|
||
if isinstance(content, str):
|
||
return content
|
||
if isinstance(content, list):
|
||
parts = []
|
||
for item in content:
|
||
if isinstance(item, dict) and item.get("type") == "text":
|
||
parts.append(item.get("text", ""))
|
||
return "".join(parts).strip()
|
||
return str(content)
|
||
|
||
|
||
@dataclass
class HistoryRecord:
    """Normalized view of one chat-history entry.

    ``from_raw`` accepts both the current lowercase schema
    (role/nickname/content/timestamp/wxid/id) and the legacy CamelCase
    schema (SenderNickname/Content/CreateTime/SenderWxid/msgid);
    ``to_dict`` serializes back, omitting unset optional fields.
    """

    role: str = "user"
    nickname: str = ""
    content: Any = ""
    timestamp: Any = None
    wxid: Optional[str] = None
    id: Optional[str] = None

    @classmethod
    def from_raw(cls, raw: Dict[str, Any]) -> "HistoryRecord":
        """Build a record from a raw dict, tolerating legacy key names."""
        # "content" (if present) wins over legacy "Content", even when None.
        if "content" in raw:
            body = raw.get("content")
        else:
            body = raw.get("Content", "")
        return cls(
            role=raw.get("role") or "user",
            nickname=raw.get("nickname") or raw.get("SenderNickname") or "",
            content=body,
            timestamp=raw.get("timestamp") or raw.get("time") or raw.get("CreateTime"),
            wxid=raw.get("wxid") or raw.get("SenderWxid"),
            id=raw.get("id") or raw.get("msgid"),
        )

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict; timestamp/wxid/id appear only when set."""
        out: Dict[str, Any] = {
            "role": self.role or "user",
            "nickname": self.nickname,
            "content": self.content,
        }
        if self.timestamp is not None:
            out["timestamp"] = self.timestamp
        if self.wxid:
            out["wxid"] = self.wxid
        if self.id:
            out["id"] = self.id
        return out
|
||
|
||
|
||
class ContextStore:
    """
    Unified conversation-context store.

    Routes reads/writes across three media so callers (AIChat) never deal
    with storage details directly:

    - private-chat memory: Redis first, in-process dict fallback
    - group-chat history:  Redis first, JSON file fallback
    - persistent memories: sqlite

    Args:
        config: AIChat config dict
        history_dir: directory for group-history JSON files (file fallback)
        memory_fallback: AIChat in-memory dict (private-chat fallback)
        history_locks: AIChat per-chat locks dict (serializes file writes)
        persistent_db_path: sqlite file path
    """

    def __init__(
        self,
        config: Dict[str, Any],
        history_dir: Optional[Path],
        memory_fallback: Dict[str, List[Dict[str, Any]]],
        history_locks: Dict[str, asyncio.Lock],
        persistent_db_path: Optional[Path],
    ):
        self.config = config or {}
        self.history_dir = history_dir
        self.memory_fallback = memory_fallback
        self.history_locks = history_locks
        self.persistent_db_path = persistent_db_path

    # ------------------ shared helpers ------------------

    def _redis_history_available(self) -> bool:
        """True when config enables Redis history AND the cache client is usable.

        Shared by the private-memory and group-history paths (the two
        original checks were byte-identical duplicates).
        """
        if not self.config.get("redis", {}).get("use_redis_history", True):
            return False
        redis_cache = get_cache()
        return bool(redis_cache and redis_cache.enabled)

    @staticmethod
    def _build_content(content: Any, image_base64: Optional[str]) -> Any:
        """Wrap text plus an optional image into OpenAI multimodal parts.

        Without an image the content is returned unchanged.
        """
        if not image_base64:
            return content
        return [
            {"type": "text", "text": _extract_text_from_multimodal(content)},
            {"type": "image_url", "image_url": {"url": image_base64}},
        ]

    # ------------------ private-chat memory ------------------

    def _use_redis_for_memory(self) -> bool:
        return self._redis_history_available()

    def add_private_message(
        self,
        chat_id: str,
        role: str,
        content: Any,
        *,
        image_base64: Optional[str] = None,
        nickname: str = "",
        sender_wxid: Optional[str] = None,
    ) -> None:
        """Append one message to a private chat's short-term memory.

        Writes to Redis when available; on write failure (or when Redis is
        disabled) falls back to the in-process dict. Both media are trimmed
        to ``memory.max_messages``. No-op when memory is disabled.
        """
        if not self.config.get("memory", {}).get("enabled", False):
            return

        message_content = self._build_content(content, image_base64)

        if self._use_redis_for_memory():
            redis_cache = get_cache()
            ttl = self.config.get("redis", {}).get("chat_history_ttl", 86400)
            try:
                redis_cache.add_chat_message(
                    chat_id,
                    role,
                    message_content,
                    nickname=nickname,
                    sender_wxid=sender_wxid,
                    ttl=ttl,
                )
                max_messages = self.config.get("memory", {}).get("max_messages", 20)
                redis_cache.trim_chat_history(chat_id, max_messages)
                return
            except Exception as e:
                # Fall through to the in-memory store on any Redis failure.
                logger.debug(f"[ContextStore] Redis private history 写入失败: {e}")

        bucket = self.memory_fallback.setdefault(chat_id, [])
        bucket.append({"role": role, "content": message_content})
        max_messages = self.config.get("memory", {}).get("max_messages", 20)
        if len(bucket) > max_messages:
            self.memory_fallback[chat_id] = bucket[-max_messages:]

    def get_private_messages(self, chat_id: str) -> List[Dict[str, Any]]:
        """Return the private chat's recent messages as normalized dicts.

        Redis first; on error or when disabled, reads the in-process dict.
        Empty list when memory is disabled.
        """
        if not self.config.get("memory", {}).get("enabled", False):
            return []

        if self._use_redis_for_memory():
            redis_cache = get_cache()
            max_messages = self.config.get("memory", {}).get("max_messages", 20)
            try:
                history = redis_cache.get_chat_history(chat_id, max_messages)
                return [HistoryRecord.from_raw(h).to_dict() for h in history]
            except Exception as e:
                logger.debug(f"[ContextStore] Redis private history 读取失败: {e}")

        return self.memory_fallback.get(chat_id, [])

    def clear_private_messages(self, chat_id: str) -> None:
        """Drop the private chat's memory from Redis (best effort) and the fallback dict."""
        if self._use_redis_for_memory():
            redis_cache = get_cache()
            try:
                redis_cache.clear_chat_history(chat_id)
            except Exception as e:
                # Best effort: a stale Redis copy only lengthens context,
                # so log instead of failing (was a silent swallow).
                logger.debug(f"[ContextStore] Redis private history 清理失败: {e}")
        self.memory_fallback.pop(chat_id, None)

    # ------------------ group-chat history ------------------

    def _use_redis_for_group_history(self) -> bool:
        return self._redis_history_available()

    def _get_history_file(self, chat_id: str) -> Optional[Path]:
        """Path of the chat's JSON history file, or None when no dir is configured."""
        if not self.history_dir:
            return None
        return self.history_dir / f"{_safe_chat_id(chat_id)}.json"

    def _get_history_lock(self, chat_id: str) -> asyncio.Lock:
        """Per-chat lock serializing history-file reads/writes (created lazily)."""
        lock = self.history_locks.get(chat_id)
        if lock is None:
            lock = asyncio.Lock()
            self.history_locks[chat_id] = lock
        return lock

    def _read_history_file(self, history_file: Path) -> List[Dict[str, Any]]:
        """Load a JSON history file; missing file -> [], other errors -> [] + log."""
        try:
            with open(history_file, "r", encoding="utf-8") as f:
                return json.load(f)
        except FileNotFoundError:
            return []
        except Exception as e:
            logger.error(f"读取历史记录失败: {history_file}, {e}")
            return []

    def _write_history_file(self, history_file: Path, history: List[Dict[str, Any]]) -> None:
        """Atomically write history: dump to a temp file, then rename over the target."""
        history_file.parent.mkdir(parents=True, exist_ok=True)
        temp_file = history_file.with_name(history_file.name + ".tmp")
        with open(temp_file, "w", encoding="utf-8") as f:
            json.dump(history, f, ensure_ascii=False, indent=2)
        temp_file.replace(history_file)

    async def load_group_history(self, chat_id: str) -> List[Dict[str, Any]]:
        """Return the group's recent history as normalized dicts.

        Redis first; on error or when disabled, reads the JSON file under
        the per-chat lock. Empty when history is disabled or no dir is set.
        """
        if not self.config.get("history", {}).get("enabled", True):
            return []

        if self._use_redis_for_group_history():
            redis_cache = get_cache()
            max_history = self.config.get("history", {}).get("max_history", 100)
            try:
                history = redis_cache.get_group_history(chat_id, max_history)
                return [HistoryRecord.from_raw(h).to_dict() for h in history]
            except Exception as e:
                logger.debug(f"[ContextStore] Redis group history 读取失败: {e}")

        history_file = self._get_history_file(chat_id)
        if not history_file:
            return []

        async with self._get_history_lock(chat_id):
            raw_history = self._read_history_file(history_file)
            return [HistoryRecord.from_raw(h).to_dict() for h in raw_history]

    async def add_group_message(
        self,
        chat_id: str,
        nickname: str,
        content: Any,
        *,
        record_id: Optional[str] = None,
        image_base64: Optional[str] = None,
        role: str = "user",
        sender_wxid: Optional[str] = None,
    ) -> None:
        """Append one message to the group's history (Redis first, file fallback).

        Both media are trimmed to ``history.max_history``. No-op when
        history is disabled or (file path) history_dir is not configured.
        """
        if not self.config.get("history", {}).get("enabled", True):
            return

        message_content = self._build_content(content, image_base64)

        if self._use_redis_for_group_history():
            redis_cache = get_cache()
            ttl = self.config.get("redis", {}).get("group_history_ttl", 172800)
            try:
                redis_cache.add_group_message(
                    chat_id,
                    nickname,
                    message_content,
                    record_id=record_id,
                    role=role,
                    sender_wxid=sender_wxid,
                    ttl=ttl,
                )
                max_history = self.config.get("history", {}).get("max_history", 100)
                redis_cache.trim_group_history(chat_id, max_history)
                return
            except Exception as e:
                logger.debug(f"[ContextStore] Redis group history 写入失败: {e}")

        history_file = self._get_history_file(chat_id)
        if not history_file:
            return

        async with self._get_history_lock(chat_id):
            # NOTE(review): file I/O below is synchronous inside an async
            # method; histories are small so the stall looks acceptable —
            # confirm under load.
            history = self._read_history_file(history_file)
            history.append(
                HistoryRecord(
                    role=role or "user",
                    nickname=nickname,
                    content=message_content,
                    timestamp=datetime.now().isoformat(),
                    wxid=sender_wxid,
                    id=record_id,
                ).to_dict()
            )
            max_history = self.config.get("history", {}).get("max_history", 100)
            if len(history) > max_history:
                history = history[-max_history:]
            self._write_history_file(history_file, history)

    async def update_group_message_by_id(self, chat_id: str, record_id: str, new_content: Any) -> None:
        """Replace the content of the history record with the given id (first match only)."""
        if not self.config.get("history", {}).get("enabled", True):
            return

        if self._use_redis_for_group_history():
            redis_cache = get_cache()
            try:
                redis_cache.update_group_message_by_id(chat_id, record_id, new_content)
                return
            except Exception as e:
                logger.debug(f"[ContextStore] Redis group history 更新失败: {e}")

        history_file = self._get_history_file(chat_id)
        if not history_file:
            return

        async with self._get_history_lock(chat_id):
            history = self._read_history_file(history_file)
            for rec in history:
                if rec.get("id") == record_id:
                    rec["content"] = new_content
                    break
            max_history = self.config.get("history", {}).get("max_history", 100)
            if len(history) > max_history:
                history = history[-max_history:]
            self._write_history_file(history_file, history)

    # ------------------ persistent memories (sqlite) ------------------

    def init_persistent_memory_db(self) -> Optional[Path]:
        """Create the memories table and index if needed.

        Returns:
            The db path, or None when persistence is disabled.
        """
        if not self.persistent_db_path:
            return None

        self.persistent_db_path.parent.mkdir(exist_ok=True, parents=True)
        conn = sqlite3.connect(self.persistent_db_path)
        try:
            cursor = conn.cursor()
            cursor.execute(
                """
                CREATE TABLE IF NOT EXISTS memories (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    chat_id TEXT NOT NULL,
                    chat_type TEXT NOT NULL,
                    user_wxid TEXT NOT NULL,
                    user_nickname TEXT,
                    content TEXT NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
                """
            )
            cursor.execute("CREATE INDEX IF NOT EXISTS idx_chat_id ON memories(chat_id)")
            conn.commit()
        finally:
            # fix: the connection used to leak when DDL raised
            conn.close()
        logger.info(f"持久记忆数据库已初始化: {self.persistent_db_path}")
        return self.persistent_db_path

    def add_persistent_memory(
        self,
        chat_id: str,
        chat_type: str,
        user_wxid: str,
        user_nickname: str,
        content: str,
    ) -> int:
        """Insert one persistent memory row.

        Returns:
            The new row id, or -1 when persistence is disabled.
        """
        if not self.persistent_db_path:
            return -1

        conn = sqlite3.connect(self.persistent_db_path)
        try:
            cursor = conn.cursor()
            cursor.execute(
                """
                INSERT INTO memories (chat_id, chat_type, user_wxid, user_nickname, content)
                VALUES (?, ?, ?, ?, ?)
                """,
                (chat_id, chat_type, user_wxid, user_nickname, content),
            )
            memory_id = cursor.lastrowid
            conn.commit()
            return memory_id
        finally:
            # fix: the connection used to leak when the INSERT raised
            conn.close()

    def get_persistent_memories(self, chat_id: str) -> List[Dict[str, Any]]:
        """Return all memories of a chat, oldest first, as id/nickname/content/time dicts."""
        if not self.persistent_db_path:
            return []

        conn = sqlite3.connect(self.persistent_db_path)
        try:
            cursor = conn.cursor()
            cursor.execute(
                """
                SELECT id, user_nickname, content, created_at
                FROM memories
                WHERE chat_id = ?
                ORDER BY created_at ASC
                """,
                (chat_id,),
            )
            rows = cursor.fetchall()
        finally:
            # fix: the connection used to leak when the SELECT raised
            conn.close()
        return [
            {"id": r[0], "nickname": r[1], "content": r[2], "time": r[3]}
            for r in rows
        ]

    def delete_persistent_memory(self, chat_id: str, memory_id: int) -> bool:
        """Delete one memory row (scoped to chat_id); True when a row was removed."""
        if not self.persistent_db_path:
            return False

        conn = sqlite3.connect(self.persistent_db_path)
        try:
            cursor = conn.cursor()
            cursor.execute(
                "DELETE FROM memories WHERE id = ? AND chat_id = ?",
                (memory_id, chat_id),
            )
            deleted = cursor.rowcount > 0
            conn.commit()
            return deleted
        finally:
            # fix: the connection used to leak when the DELETE raised
            conn.close()

    def clear_persistent_memories(self, chat_id: str) -> int:
        """Delete all memory rows of a chat; returns the number removed."""
        if not self.persistent_db_path:
            return 0

        conn = sqlite3.connect(self.persistent_db_path)
        try:
            cursor = conn.cursor()
            cursor.execute("DELETE FROM memories WHERE chat_id = ?", (chat_id,))
            deleted_count = cursor.rowcount
            conn.commit()
            return deleted_count
        finally:
            # fix: the connection used to leak when the DELETE raised
            conn.close()

    # ------------------ legacy-data scan/cleanup ------------------

    def find_legacy_group_history_keys(self) -> List[str]:
        """
        Discover group_history keys written by old versions using safe_id.

        Returns:
            List of legacy keys (nothing is deleted).
        """
        redis_cache = get_cache()
        if not redis_cache or not redis_cache.enabled:
            return []

        try:
            keys = redis_cache.client.keys("group_history:*")
            legacy = []
            for k in keys or []:
                # fix: redis-py without decode_responses returns bytes;
                # the substring test then raised (and was silently swallowed).
                if isinstance(k, bytes):
                    k = k.decode("utf-8", errors="replace")
                # New keys normally contain @chatroom; legacy safe_id keys
                # replaced '@' with '_', so they contain '_' but no '@'.
                if "@chatroom" not in k and "_" in k:
                    legacy.append(k)
            return legacy
        except Exception as e:
            logger.debug(f"[ContextStore] 扫描 legacy group_history keys 失败: {e}")
            return []

    def delete_legacy_group_history_keys(self, legacy_keys: List[str]) -> int:
        """Delete the given legacy keys; returns the number deleted (0 on error/disabled)."""
        redis_cache = get_cache()
        if not redis_cache or not redis_cache.enabled or not legacy_keys:
            return 0
        try:
            return redis_cache.client.delete(*legacy_keys)
        except Exception as e:
            logger.debug(f"[ContextStore] 删除 legacy group_history keys 失败: {e}")
            return 0