feat:优化屎山
This commit is contained in:
470
utils/context_store.py
Normal file
470
utils/context_store.py
Normal file
@@ -0,0 +1,470 @@
|
||||
"""
|
||||
上下文/存储统一封装
|
||||
|
||||
提供统一的会话上下文读写接口:
|
||||
- 私聊/单人会话 memory(优先 Redis,降级内存)
|
||||
- 群聊 history(优先 Redis,降级文件)
|
||||
- 持久记忆 sqlite
|
||||
|
||||
AIChat 只需要通过本模块读写消息,不再关心介质细节。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import sqlite3
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from utils.redis_cache import get_cache
|
||||
|
||||
|
||||
def _safe_chat_id(chat_id: str) -> str:
|
||||
return (chat_id or "").replace("@", "_").replace(":", "_")
|
||||
|
||||
|
||||
def _extract_text_from_multimodal(content: Any) -> str:
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts = []
|
||||
for item in content:
|
||||
if isinstance(item, dict) and item.get("type") == "text":
|
||||
parts.append(item.get("text", ""))
|
||||
return "".join(parts).strip()
|
||||
return str(content)
|
||||
|
||||
|
||||
@dataclass
class HistoryRecord:
    """Normalized shape of a single chat-history entry.

    ``from_raw`` accepts dicts that use either the lower-case keys written by
    this module or the capitalized / alternative keys it also recognizes
    (``SenderNickname``, ``Content``, ``time``, ``CreateTime``,
    ``SenderWxid``, ``msgid``); ``to_dict`` always emits the lower-case keys.
    """

    role: str = "user"
    nickname: str = ""
    content: Any = ""
    timestamp: Any = None
    wxid: Optional[str] = None
    id: Optional[str] = None

    @classmethod
    def from_raw(cls, raw: Dict[str, Any]) -> "HistoryRecord":
        """Build a record from ``raw``, tolerating the alternative key names."""

        def pick(*keys: str) -> Any:
            # Same result as ``raw.get(k1) or raw.get(k2) or ...``: the first
            # truthy value, otherwise whatever the last key holds.
            value = None
            for key in keys:
                value = raw.get(key)
                if value:
                    break
            return value

        # An explicit "content" key wins even when its value is falsy.
        if "content" in raw:
            body = raw["content"]
        else:
            body = raw.get("Content", "")

        return cls(
            role=pick("role") or "user",
            nickname=pick("nickname", "SenderNickname") or "",
            content=body,
            timestamp=pick("timestamp", "time", "CreateTime"),
            wxid=pick("wxid", "SenderWxid"),
            id=pick("id", "msgid"),
        )

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict, omitting the optional fields that are unset."""
        payload: Dict[str, Any] = {
            "role": self.role or "user",
            "nickname": self.nickname,
            "content": self.content,
        }
        # timestamp is kept whenever it is not None (0 is a valid value) ...
        if self.timestamp is not None:
            payload["timestamp"] = self.timestamp
        # ... while wxid / id are only kept when truthy.
        for key, value in (("wxid", self.wxid), ("id", self.id)):
            if value:
                payload[key] = value
        return payload
|
||||
|
||||
|
||||
class ContextStore:
    """Unified read/write facade for conversation context.

    Three kinds of data are handled:

    * private / single-user conversation memory: Redis when enabled,
      otherwise the in-process ``memory_fallback`` dict;
    * group-chat history: Redis when enabled, otherwise one JSON file per
      chat under ``history_dir``;
    * persistent memories: rows of the ``memories`` table in the SQLite
      database at ``persistent_db_path``.

    Args:
        config: AIChat configuration dict. The ``memory``, ``history`` and
            ``redis`` sections are read here.
        history_dir: Directory for group-history files (group fallback).
            ``None`` disables the file fallback.
        memory_fallback: AIChat's in-memory dict (private fallback); it is
            mutated in place.
        history_locks: AIChat's per-chat ``asyncio.Lock`` dict used to
            serialize history-file access; missing locks are added to it.
        persistent_db_path: SQLite file path. ``None`` disables persistent
            memory.
    """

    def __init__(
        self,
        config: Dict[str, Any],
        history_dir: Optional[Path],
        memory_fallback: Dict[str, List[Dict[str, Any]]],
        history_locks: Dict[str, asyncio.Lock],
        persistent_db_path: Optional[Path],
    ):
        self.config = config or {}
        self.history_dir = history_dir
        self.memory_fallback = memory_fallback
        self.history_locks = history_locks
        self.persistent_db_path = persistent_db_path

    # ------------------ shared helpers ------------------

    def _redis_enabled(self) -> bool:
        """Return True when Redis history is configured on and the cache is enabled."""
        if not self.config.get("redis", {}).get("use_redis_history", True):
            return False
        redis_cache = get_cache()
        return bool(redis_cache and redis_cache.enabled)

    def _max_private_messages(self) -> int:
        """Configured cap on stored private messages (``memory.max_messages``, default 20)."""
        return self.config.get("memory", {}).get("max_messages", 20)

    def _max_group_history(self) -> int:
        """Configured cap on stored group records (``history.max_history``, default 100)."""
        return self.config.get("history", {}).get("max_history", 100)

    @staticmethod
    def _build_message_content(content: Any, image_base64: Optional[str]) -> Any:
        """Return ``content`` unchanged, or a text + image_url part list when an image is given."""
        if not image_base64:
            return content
        return [
            {"type": "text", "text": _extract_text_from_multimodal(content)},
            {"type": "image_url", "image_url": {"url": image_base64}},
        ]

    # ------------------ private-chat memory ------------------

    def _use_redis_for_memory(self) -> bool:
        """Whether private memory should go through Redis."""
        return self._redis_enabled()

    def add_private_message(
        self,
        chat_id: str,
        role: str,
        content: Any,
        *,
        image_base64: Optional[str] = None,
        nickname: str = "",
        sender_wxid: Optional[str] = None,
    ) -> None:
        """Append one message to a private conversation.

        Does nothing unless ``memory.enabled`` is set in the config. The
        message is written to Redis when available; if Redis is unavailable
        or the write raises, it is appended to ``memory_fallback`` instead.
        Both stores are trimmed to ``memory.max_messages``.

        Note: the in-memory fallback only stores ``role`` and ``content``;
        ``nickname`` and ``sender_wxid`` are passed to Redis only.
        """
        if not self.config.get("memory", {}).get("enabled", False):
            return

        message_content = self._build_message_content(content, image_base64)
        max_messages = self._max_private_messages()

        if self._use_redis_for_memory():
            redis_cache = get_cache()
            ttl = self.config.get("redis", {}).get("chat_history_ttl", 86400)
            try:
                redis_cache.add_chat_message(
                    chat_id,
                    role,
                    message_content,
                    nickname=nickname,
                    sender_wxid=sender_wxid,
                    ttl=ttl,
                )
                redis_cache.trim_chat_history(chat_id, max_messages)
                return
            except Exception as e:
                # Best effort: fall through to the in-memory store.
                logger.debug(f"[ContextStore] Redis private history 写入失败: {e}")

        bucket = self.memory_fallback.setdefault(chat_id, [])
        bucket.append({"role": role, "content": message_content})
        if len(bucket) > max_messages:
            self.memory_fallback[chat_id] = bucket[-max_messages:]

    def get_private_messages(self, chat_id: str) -> List[Dict[str, Any]]:
        """Return the stored messages of a private conversation.

        Returns ``[]`` when ``memory.enabled`` is off. Redis results are
        normalized through ``HistoryRecord``; when Redis is unavailable or
        the read raises, the ``memory_fallback`` list itself (not a copy) is
        returned.
        """
        if not self.config.get("memory", {}).get("enabled", False):
            return []

        if self._use_redis_for_memory():
            redis_cache = get_cache()
            try:
                history = redis_cache.get_chat_history(chat_id, self._max_private_messages())
                return [HistoryRecord.from_raw(h).to_dict() for h in history]
            except Exception as e:
                logger.debug(f"[ContextStore] Redis private history 读取失败: {e}")

        return self.memory_fallback.get(chat_id, [])

    def clear_private_messages(self, chat_id: str) -> None:
        """Remove a private conversation from Redis (when enabled) and from the fallback dict."""
        if self._use_redis_for_memory():
            redis_cache = get_cache()
            try:
                redis_cache.clear_chat_history(chat_id)
            except Exception as e:
                # Still best effort, but no longer silent.
                logger.debug(f"[ContextStore] Redis private history 清除失败: {e}")
        self.memory_fallback.pop(chat_id, None)

    # ------------------ group-chat history ------------------

    def _use_redis_for_group_history(self) -> bool:
        """Whether group history should go through Redis."""
        return self._redis_enabled()

    def _get_history_file(self, chat_id: str) -> Optional[Path]:
        """Return the JSON history file for ``chat_id``, or ``None`` without a history dir."""
        if not self.history_dir:
            return None
        return self.history_dir / f"{_safe_chat_id(chat_id)}.json"

    def _get_history_lock(self, chat_id: str) -> asyncio.Lock:
        """Return the lock guarding ``chat_id``'s history file, creating it on first use."""
        lock = self.history_locks.get(chat_id)
        if lock is None:
            lock = asyncio.Lock()
            self.history_locks[chat_id] = lock
        return lock

    def _read_history_file(self, history_file: Path) -> List[Dict[str, Any]]:
        """Load the record list stored in ``history_file``.

        Returns ``[]`` when the file is missing, unreadable, or does not
        contain a JSON list. Entries that are not dicts are dropped, since
        every caller treats entries as records.
        """
        try:
            with open(history_file, "r", encoding="utf-8") as f:
                data = json.load(f)
        except FileNotFoundError:
            return []
        except Exception as e:
            logger.error(f"读取历史记录失败: {history_file}, {e}")
            return []

        if not isinstance(data, list):
            # Valid JSON but not a record list; callers append to / iterate it.
            logger.error(
                f"读取历史记录失败: {history_file}, unexpected top-level type {type(data).__name__}"
            )
            return []
        return [item for item in data if isinstance(item, dict)]

    def _write_history_file(self, history_file: Path, history: List[Dict[str, Any]]) -> None:
        """Write ``history`` to ``history_file`` via a ``.tmp`` sibling and a rename."""
        history_file.parent.mkdir(parents=True, exist_ok=True)
        temp_file = Path(str(history_file) + ".tmp")
        with open(temp_file, "w", encoding="utf-8") as f:
            json.dump(history, f, ensure_ascii=False, indent=2)
        # Swap the finished file in so readers never see a half-written one.
        temp_file.replace(history_file)

    async def load_group_history(self, chat_id: str) -> List[Dict[str, Any]]:
        """Return the normalized history of a group chat.

        Returns ``[]`` when ``history.enabled`` is off. Redis is tried first;
        if it is unavailable or the read raises, the per-chat JSON file is
        read under the chat's lock.
        """
        if not self.config.get("history", {}).get("enabled", True):
            return []

        if self._use_redis_for_group_history():
            redis_cache = get_cache()
            try:
                history = redis_cache.get_group_history(chat_id, self._max_group_history())
                return [HistoryRecord.from_raw(h).to_dict() for h in history]
            except Exception as e:
                logger.debug(f"[ContextStore] Redis group history 读取失败: {e}")

        history_file = self._get_history_file(chat_id)
        if not history_file:
            return []

        async with self._get_history_lock(chat_id):
            raw_history = self._read_history_file(history_file)
        return [HistoryRecord.from_raw(h).to_dict() for h in raw_history]

    async def add_group_message(
        self,
        chat_id: str,
        nickname: str,
        content: Any,
        *,
        record_id: Optional[str] = None,
        image_base64: Optional[str] = None,
        role: str = "user",
        sender_wxid: Optional[str] = None,
    ) -> None:
        """Append one message to a group chat's history.

        Does nothing when ``history.enabled`` is off. The message is written
        to Redis when available; if Redis is unavailable or the write raises,
        it is appended to the per-chat JSON file (timestamped with the local
        ``datetime.now()``). Both stores are trimmed to
        ``history.max_history``.
        """
        if not self.config.get("history", {}).get("enabled", True):
            return

        message_content = self._build_message_content(content, image_base64)
        max_history = self._max_group_history()

        if self._use_redis_for_group_history():
            redis_cache = get_cache()
            ttl = self.config.get("redis", {}).get("group_history_ttl", 172800)
            try:
                redis_cache.add_group_message(
                    chat_id,
                    nickname,
                    message_content,
                    record_id=record_id,
                    role=role,
                    sender_wxid=sender_wxid,
                    ttl=ttl,
                )
                redis_cache.trim_group_history(chat_id, max_history)
                return
            except Exception as e:
                # Best effort: fall through to the file store.
                logger.debug(f"[ContextStore] Redis group history 写入失败: {e}")

        history_file = self._get_history_file(chat_id)
        if not history_file:
            return

        async with self._get_history_lock(chat_id):
            history = self._read_history_file(history_file)
            record = HistoryRecord(
                role=role or "user",
                nickname=nickname,
                content=message_content,
                timestamp=datetime.now().isoformat(),
                wxid=sender_wxid,
                id=record_id,
            )
            history.append(record.to_dict())
            if len(history) > max_history:
                history = history[-max_history:]
            self._write_history_file(history_file, history)

    async def update_group_message_by_id(self, chat_id: str, record_id: str, new_content: Any) -> None:
        """Replace the content of the group record whose ``id`` equals ``record_id``.

        Does nothing when ``history.enabled`` is off. Redis is tried first;
        if it is unavailable or the update raises, the first matching record
        in the JSON file is updated. The file is only rewritten when a record
        was changed or the list had to be trimmed.
        """
        if not self.config.get("history", {}).get("enabled", True):
            return

        if self._use_redis_for_group_history():
            redis_cache = get_cache()
            try:
                redis_cache.update_group_message_by_id(chat_id, record_id, new_content)
                return
            except Exception as e:
                logger.debug(f"[ContextStore] Redis group history 更新失败: {e}")

        history_file = self._get_history_file(chat_id)
        if not history_file:
            return

        async with self._get_history_lock(chat_id):
            history = self._read_history_file(history_file)
            changed = False
            # Records are stored without an "id" key when their id is falsy,
            # so a falsy record_id must not be matched: ``None == None`` would
            # otherwise overwrite the first id-less record.
            if record_id:
                for rec in history:
                    if rec.get("id") == record_id:
                        rec["content"] = new_content
                        changed = True
                        break
            max_history = self._max_group_history()
            if len(history) > max_history:
                history = history[-max_history:]
                changed = True
            if changed:
                self._write_history_file(history_file, history)

    # ------------------ persistent memory (sqlite) ------------------

    def init_persistent_memory_db(self) -> Optional[Path]:
        """Create the ``memories`` table and its ``chat_id`` index if missing.

        Returns:
            The database path, or ``None`` when no path is configured.
        """
        if not self.persistent_db_path:
            return None

        self.persistent_db_path.parent.mkdir(exist_ok=True, parents=True)
        conn = sqlite3.connect(self.persistent_db_path)
        try:
            conn.execute(
                """
                CREATE TABLE IF NOT EXISTS memories (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    chat_id TEXT NOT NULL,
                    chat_type TEXT NOT NULL,
                    user_wxid TEXT NOT NULL,
                    user_nickname TEXT,
                    content TEXT NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
                """
            )
            conn.execute("CREATE INDEX IF NOT EXISTS idx_chat_id ON memories(chat_id)")
            conn.commit()
        finally:
            # Close even when a statement raises, so the handle is not leaked.
            conn.close()
        logger.info(f"持久记忆数据库已初始化: {self.persistent_db_path}")
        return self.persistent_db_path

    def add_persistent_memory(
        self,
        chat_id: str,
        chat_type: str,
        user_wxid: str,
        user_nickname: str,
        content: str,
    ) -> int:
        """Insert one memory row.

        Returns:
            The new row id, or ``-1`` when no database path is configured.

        Raises:
            sqlite3.Error: if the insert fails (e.g. the table was never
                initialized). The connection is closed before propagating.
        """
        if not self.persistent_db_path:
            return -1

        conn = sqlite3.connect(self.persistent_db_path)
        try:
            cursor = conn.execute(
                """
                INSERT INTO memories (chat_id, chat_type, user_wxid, user_nickname, content)
                VALUES (?, ?, ?, ?, ?)
                """,
                (chat_id, chat_type, user_wxid, user_nickname, content),
            )
            memory_id = cursor.lastrowid
            conn.commit()
        finally:
            conn.close()
        return memory_id

    def get_persistent_memories(self, chat_id: str) -> List[Dict[str, Any]]:
        """Return the memories of ``chat_id`` ordered by ``created_at`` ascending.

        Each item has the keys ``id``, ``nickname``, ``content`` and ``time``.
        Returns ``[]`` when no database path is configured.
        """
        if not self.persistent_db_path:
            return []

        conn = sqlite3.connect(self.persistent_db_path)
        try:
            rows = conn.execute(
                """
                SELECT id, user_nickname, content, created_at
                FROM memories
                WHERE chat_id = ?
                ORDER BY created_at ASC
                """,
                (chat_id,),
            ).fetchall()
        finally:
            conn.close()
        return [
            {"id": r[0], "nickname": r[1], "content": r[2], "time": r[3]}
            for r in rows
        ]

    def delete_persistent_memory(self, chat_id: str, memory_id: int) -> bool:
        """Delete one memory that belongs to ``chat_id``.

        Returns:
            True when a row was deleted; False otherwise or when no database
            path is configured.
        """
        if not self.persistent_db_path:
            return False

        conn = sqlite3.connect(self.persistent_db_path)
        try:
            cursor = conn.execute(
                "DELETE FROM memories WHERE id = ? AND chat_id = ?",
                (memory_id, chat_id),
            )
            deleted = cursor.rowcount > 0
            conn.commit()
        finally:
            conn.close()
        return deleted

    def clear_persistent_memories(self, chat_id: str) -> int:
        """Delete every memory of ``chat_id`` and return how many rows were removed."""
        if not self.persistent_db_path:
            return 0

        conn = sqlite3.connect(self.persistent_db_path)
        try:
            cursor = conn.execute("DELETE FROM memories WHERE chat_id = ?", (chat_id,))
            deleted_count = cursor.rowcount
            conn.commit()
        finally:
            conn.close()
        return deleted_count

    # ------------------ legacy data scan / cleanup ------------------

    def find_legacy_group_history_keys(self) -> List[str]:
        """Find ``group_history`` keys written by older versions using the safe id.

        A key is reported as legacy when the part after the
        ``group_history:`` prefix contains ``_`` and the key does not contain
        ``@chatroom``.

        Returns:
            The legacy keys (nothing is deleted). ``[]`` when Redis is
            disabled or the scan fails.
        """
        redis_cache = get_cache()
        if not redis_cache or not redis_cache.enabled:
            return []

        try:
            keys = redis_cache.client.keys("group_history:*")
            legacy = []
            for raw_key in keys or []:
                # The client may return bytes keys; ``str in bytes`` raises
                # TypeError. NOTE(review): assumes keys are UTF-8 - confirm.
                key = raw_key.decode("utf-8") if isinstance(raw_key, bytes) else raw_key
                # New keys generally contain "@chatroom"; old safe ids have
                # "@" replaced by "_". Test only the id part: the prefix
                # "group_history" itself always contains "_".
                chat_part = key.partition(":")[2]
                if "@chatroom" not in key and "_" in chat_part:
                    legacy.append(key)
            return legacy
        except Exception as e:
            logger.debug(f"[ContextStore] 扫描 legacy group_history keys 失败: {e}")
            return []

    def delete_legacy_group_history_keys(self, legacy_keys: List[str]) -> int:
        """Delete the given legacy keys and return the client's delete result.

        Returns ``0`` when Redis is disabled, the list is empty, or the
        delete fails.
        """
        redis_cache = get_cache()
        if not redis_cache or not redis_cache.enabled or not legacy_keys:
            return 0
        try:
            return redis_cache.client.delete(*legacy_keys)
        except Exception as e:
            logger.debug(f"[ContextStore] 删除 legacy group_history keys 失败: {e}")
            return 0
|
||||
Reference in New Issue
Block a user