WechatHookBot/utils/context_store.py

"""
上下文/存储统一封装
提供统一的会话上下文读写接口:
- 私聊/单人会话 memory优先 Redis降级内存
- 群聊 history优先 Redis降级文件
- 持久记忆 sqlite
AIChat 只需要通过本模块读写消息,不再关心介质细节。
"""
from __future__ import annotations

import asyncio
import json
import sqlite3
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional

from loguru import logger

from utils.redis_cache import get_cache


def _safe_chat_id(chat_id: str) -> str:
    """Replace characters that are unsafe in file names ("@", ":") with "_"."""
    return (chat_id or "").replace("@", "_").replace(":", "_")
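
# For example, a group id such as "12345678@chatroom" becomes "12345678_chatroom",
# which is the form used for the file-based history fallback below.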


def _extract_text_from_multimodal(content: Any) -> str:
    """Collapse a multimodal content value (list of typed parts) into its text parts."""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts = []
        for item in content:
            if isinstance(item, dict) and item.get("type") == "text":
                parts.append(item.get("text", ""))
        return "".join(parts).strip()
    return str(content)
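
# Illustrative example of the helper above (not executed): only the text parts
# are kept and joined, image parts are dropped.
#   _extract_text_from_multimodal([{"type": "text", "text": "what is this"},
#                                  {"type": "image_url", "image_url": {"url": "data:..."}}])
#   -> "what is this"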


@dataclass
class HistoryRecord:
    """One normalized history entry; from_raw() tolerates several raw key spellings."""

    role: str = "user"
    nickname: str = ""
    content: Any = ""
    timestamp: Any = None
    wxid: Optional[str] = None
    id: Optional[str] = None

    @classmethod
    def from_raw(cls, raw: Dict[str, Any]) -> "HistoryRecord":
        role = raw.get("role") or "user"
        nickname = raw.get("nickname") or raw.get("SenderNickname") or ""
        content = raw.get("content") if "content" in raw else raw.get("Content", "")
        ts = raw.get("timestamp") or raw.get("time") or raw.get("CreateTime")
        wxid = raw.get("wxid") or raw.get("SenderWxid")
        rid = raw.get("id") or raw.get("msgid")
        return cls(role=role, nickname=nickname, content=content, timestamp=ts, wxid=wxid, id=rid)

    def to_dict(self) -> Dict[str, Any]:
        d = {
            "role": self.role or "user",
            "nickname": self.nickname,
            "content": self.content,
        }
        if self.timestamp is not None:
            d["timestamp"] = self.timestamp
        if self.wxid:
            d["wxid"] = self.wxid
        if self.id:
            d["id"] = self.id
        return d
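

# Illustrative round trip (not executed), assuming a raw record with the alternate
# key spellings that from_raw() handles:
#   HistoryRecord.from_raw(
#       {"SenderNickname": "Alice", "Content": "hello", "CreateTime": 1700000000, "msgid": "42"}
#   ).to_dict()
#   -> {"role": "user", "nickname": "Alice", "content": "hello",
#       "timestamp": 1700000000, "id": "42"}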


class ContextStore:
    """
    Unified context store.

    Args:
        config: AIChat configuration dict.
        history_dir: Directory for history files (group-chat fallback).
        memory_fallback: AIChat in-memory dict (private-chat fallback).
        history_locks: AIChat dict of asyncio locks (used for file writes).
        persistent_db_path: Path to the SQLite database file.
    """

    def __init__(
        self,
        config: Dict[str, Any],
        history_dir: Optional[Path],
        memory_fallback: Dict[str, List[Dict[str, Any]]],
        history_locks: Dict[str, asyncio.Lock],
        persistent_db_path: Optional[Path],
    ):
        self.config = config or {}
        self.history_dir = history_dir
        self.memory_fallback = memory_fallback
        self.history_locks = history_locks
        self.persistent_db_path = persistent_db_path
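
    # A minimal construction sketch (illustrative only; the actual wiring lives in
    # AIChat). The config keys referenced below are the ones this class reads:
    # memory.enabled / memory.max_messages, history.enabled / history.max_history,
    # redis.use_redis_history / redis.chat_history_ttl / redis.group_history_ttl.
    #
    #   store = ContextStore(
    #       config={"memory": {"enabled": True, "max_messages": 20},
    #               "history": {"enabled": True, "max_history": 100},
    #               "redis": {"use_redis_history": True}},
    #       history_dir=Path("data/history"),             # hypothetical path
    #       memory_fallback={},
    #       history_locks={},
    #       persistent_db_path=Path("data/memories.db"),  # hypothetical path
    #   )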

    # ------------------ Private-chat memory ------------------
    def _use_redis_for_memory(self) -> bool:
        redis_config = self.config.get("redis", {})
        if not redis_config.get("use_redis_history", True):
            return False
        redis_cache = get_cache()
        return bool(redis_cache and redis_cache.enabled)

    def add_private_message(
        self,
        chat_id: str,
        role: str,
        content: Any,
        *,
        image_base64: Optional[str] = None,
        nickname: str = "",
        sender_wxid: Optional[str] = None,
    ) -> None:
        if not self.config.get("memory", {}).get("enabled", False):
            return
        if image_base64:
            message_content = [
                {"type": "text", "text": _extract_text_from_multimodal(content)},
                {"type": "image_url", "image_url": {"url": image_base64}},
            ]
        else:
            message_content = content
        redis_config = self.config.get("redis", {})
        if self._use_redis_for_memory():
            redis_cache = get_cache()
            ttl = redis_config.get("chat_history_ttl", 86400)
            try:
                redis_cache.add_chat_message(
                    chat_id,
                    role,
                    message_content,
                    nickname=nickname,
                    sender_wxid=sender_wxid,
                    ttl=ttl,
                )
                max_messages = self.config.get("memory", {}).get("max_messages", 20)
                redis_cache.trim_chat_history(chat_id, max_messages)
                return
            except Exception as e:
                logger.debug(f"[ContextStore] Redis private history write failed: {e}")
        # Fallback: in-process memory dict owned by AIChat.
        if chat_id not in self.memory_fallback:
            self.memory_fallback[chat_id] = []
        self.memory_fallback[chat_id].append({"role": role, "content": message_content})
        max_messages = self.config.get("memory", {}).get("max_messages", 20)
        if len(self.memory_fallback[chat_id]) > max_messages:
            self.memory_fallback[chat_id] = self.memory_fallback[chat_id][-max_messages:]

    def get_private_messages(self, chat_id: str) -> List[Dict[str, Any]]:
        if not self.config.get("memory", {}).get("enabled", False):
            return []
        if self._use_redis_for_memory():
            redis_cache = get_cache()
            max_messages = self.config.get("memory", {}).get("max_messages", 20)
            try:
                history = redis_cache.get_chat_history(chat_id, max_messages)
                return [HistoryRecord.from_raw(h).to_dict() for h in history]
            except Exception as e:
                logger.debug(f"[ContextStore] Redis private history read failed: {e}")
        return self.memory_fallback.get(chat_id, [])

    def clear_private_messages(self, chat_id: str) -> None:
        if self._use_redis_for_memory():
            redis_cache = get_cache()
            try:
                redis_cache.clear_chat_history(chat_id)
            except Exception:
                pass
        self.memory_fallback.pop(chat_id, None)
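
    # Illustrative private-chat round trip (not executed), assuming memory.enabled
    # is True in config; "wxid_friend" is a hypothetical chat id:
    #
    #   store.add_private_message("wxid_friend", "user", "hello", nickname="Alice")
    #   store.add_private_message("wxid_friend", "assistant", "hi there")
    #   store.get_private_messages("wxid_friend")
    #   # -> up to memory.max_messages dicts with "role"/"content" (plus "nickname",
    #   #    "timestamp", etc. when served from Redis)
    #   store.clear_private_messages("wxid_friend")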

    # ------------------ Group-chat history ------------------
    def _use_redis_for_group_history(self) -> bool:
        redis_config = self.config.get("redis", {})
        if not redis_config.get("use_redis_history", True):
            return False
        redis_cache = get_cache()
        return bool(redis_cache and redis_cache.enabled)

    def _get_history_file(self, chat_id: str) -> Optional[Path]:
        if not self.history_dir:
            return None
        return self.history_dir / f"{_safe_chat_id(chat_id)}.json"

    def _get_history_lock(self, chat_id: str) -> asyncio.Lock:
        lock = self.history_locks.get(chat_id)
        if lock is None:
            lock = asyncio.Lock()
            self.history_locks[chat_id] = lock
        return lock

    def _read_history_file(self, history_file: Path) -> List[Dict[str, Any]]:
        try:
            with open(history_file, "r", encoding="utf-8") as f:
                return json.load(f)
        except FileNotFoundError:
            return []
        except Exception as e:
            logger.error(f"Failed to read history file: {history_file}, {e}")
            return []

    def _write_history_file(self, history_file: Path, history: List[Dict[str, Any]]) -> None:
        history_file.parent.mkdir(parents=True, exist_ok=True)
        # Write to a temporary file and atomically replace the target so a crash
        # mid-write cannot leave a truncated JSON file behind.
        temp_file = Path(str(history_file) + ".tmp")
        with open(temp_file, "w", encoding="utf-8") as f:
            json.dump(history, f, ensure_ascii=False, indent=2)
        temp_file.replace(history_file)

    async def load_group_history(self, chat_id: str) -> List[Dict[str, Any]]:
        if not self.config.get("history", {}).get("enabled", True):
            return []
        if self._use_redis_for_group_history():
            redis_cache = get_cache()
            max_history = self.config.get("history", {}).get("max_history", 100)
            try:
                history = redis_cache.get_group_history(chat_id, max_history)
                return [HistoryRecord.from_raw(h).to_dict() for h in history]
            except Exception as e:
                logger.debug(f"[ContextStore] Redis group history read failed: {e}")
        history_file = self._get_history_file(chat_id)
        if not history_file:
            return []
        lock = self._get_history_lock(chat_id)
        async with lock:
            raw_history = self._read_history_file(history_file)
        return [HistoryRecord.from_raw(h).to_dict() for h in raw_history]

    async def add_group_message(
        self,
        chat_id: str,
        nickname: str,
        content: Any,
        *,
        record_id: Optional[str] = None,
        image_base64: Optional[str] = None,
        role: str = "user",
        sender_wxid: Optional[str] = None,
    ) -> None:
        if not self.config.get("history", {}).get("enabled", True):
            return
        if image_base64:
            message_content = [
                {"type": "text", "text": _extract_text_from_multimodal(content)},
                {"type": "image_url", "image_url": {"url": image_base64}},
            ]
        else:
            message_content = content
        if self._use_redis_for_group_history():
            redis_cache = get_cache()
            redis_config = self.config.get("redis", {})
            ttl = redis_config.get("group_history_ttl", 172800)
            try:
                redis_cache.add_group_message(
                    chat_id,
                    nickname,
                    message_content,
                    record_id=record_id,
                    role=role,
                    sender_wxid=sender_wxid,
                    ttl=ttl,
                )
                max_history = self.config.get("history", {}).get("max_history", 100)
                redis_cache.trim_group_history(chat_id, max_history)
                return
            except Exception as e:
                logger.debug(f"[ContextStore] Redis group history write failed: {e}")
        # Fallback: append to the per-chat JSON history file under its lock.
        history_file = self._get_history_file(chat_id)
        if not history_file:
            return
        lock = self._get_history_lock(chat_id)
        async with lock:
            history = self._read_history_file(history_file)
            record = HistoryRecord(
                role=role or "user",
                nickname=nickname,
                content=message_content,
                timestamp=datetime.now().isoformat(),
                wxid=sender_wxid,
                id=record_id,
            )
            history.append(record.to_dict())
            max_history = self.config.get("history", {}).get("max_history", 100)
            if len(history) > max_history:
                history = history[-max_history:]
            self._write_history_file(history_file, history)

    async def update_group_message_by_id(self, chat_id: str, record_id: str, new_content: Any) -> None:
        if not self.config.get("history", {}).get("enabled", True):
            return
        if self._use_redis_for_group_history():
            redis_cache = get_cache()
            try:
                redis_cache.update_group_message_by_id(chat_id, record_id, new_content)
                return
            except Exception as e:
                logger.debug(f"[ContextStore] Redis group history update failed: {e}")
        history_file = self._get_history_file(chat_id)
        if not history_file:
            return
        lock = self._get_history_lock(chat_id)
        async with lock:
            history = self._read_history_file(history_file)
            for rec in history:
                if rec.get("id") == record_id:
                    rec["content"] = new_content
                    break
            max_history = self.config.get("history", {}).get("max_history", 100)
            if len(history) > max_history:
                history = history[-max_history:]
            self._write_history_file(history_file, history)
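
    # Illustrative group-chat flow (not executed); "12345678@chatroom" and the
    # record id "msg-1" are hypothetical. update_group_message_by_id() allows a
    # placeholder entry to be replaced later, e.g. once an image has been described:
    #
    #   await store.add_group_message("12345678@chatroom", "Alice", "[image]",
    #                                 record_id="msg-1", sender_wxid="wxid_alice")
    #   await store.update_group_message_by_id("12345678@chatroom", "msg-1",
    #                                          "[image: a cat on a sofa]")
    #   history = await store.load_group_history("12345678@chatroom")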

    # ------------------ Persistent memory (SQLite) ------------------
    def init_persistent_memory_db(self) -> Optional[Path]:
        if not self.persistent_db_path:
            return None
        self.persistent_db_path.parent.mkdir(exist_ok=True, parents=True)
        conn = sqlite3.connect(self.persistent_db_path)
        cursor = conn.cursor()
        cursor.execute(
            """
            CREATE TABLE IF NOT EXISTS memories (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                chat_id TEXT NOT NULL,
                chat_type TEXT NOT NULL,
                user_wxid TEXT NOT NULL,
                user_nickname TEXT,
                content TEXT NOT NULL,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
            """
        )
        cursor.execute("CREATE INDEX IF NOT EXISTS idx_chat_id ON memories(chat_id)")
        conn.commit()
        conn.close()
        logger.info(f"Persistent memory database initialized: {self.persistent_db_path}")
        return self.persistent_db_path

    def add_persistent_memory(
        self,
        chat_id: str,
        chat_type: str,
        user_wxid: str,
        user_nickname: str,
        content: str,
    ) -> int:
        if not self.persistent_db_path:
            return -1
        conn = sqlite3.connect(self.persistent_db_path)
        cursor = conn.cursor()
        cursor.execute(
            """
            INSERT INTO memories (chat_id, chat_type, user_wxid, user_nickname, content)
            VALUES (?, ?, ?, ?, ?)
            """,
            (chat_id, chat_type, user_wxid, user_nickname, content),
        )
        memory_id = cursor.lastrowid
        conn.commit()
        conn.close()
        return memory_id

    def get_persistent_memories(self, chat_id: str) -> List[Dict[str, Any]]:
        if not self.persistent_db_path:
            return []
        conn = sqlite3.connect(self.persistent_db_path)
        cursor = conn.cursor()
        cursor.execute(
            """
            SELECT id, user_nickname, content, created_at
            FROM memories
            WHERE chat_id = ?
            ORDER BY created_at ASC
            """,
            (chat_id,),
        )
        rows = cursor.fetchall()
        conn.close()
        return [
            {"id": r[0], "nickname": r[1], "content": r[2], "time": r[3]}
            for r in rows
        ]

    def delete_persistent_memory(self, chat_id: str, memory_id: int) -> bool:
        if not self.persistent_db_path:
            return False
        conn = sqlite3.connect(self.persistent_db_path)
        cursor = conn.cursor()
        cursor.execute(
            "DELETE FROM memories WHERE id = ? AND chat_id = ?",
            (memory_id, chat_id),
        )
        deleted = cursor.rowcount > 0
        conn.commit()
        conn.close()
        return deleted

    def clear_persistent_memories(self, chat_id: str) -> int:
        if not self.persistent_db_path:
            return 0
        conn = sqlite3.connect(self.persistent_db_path)
        cursor = conn.cursor()
        cursor.execute("DELETE FROM memories WHERE chat_id = ?", (chat_id,))
        deleted_count = cursor.rowcount
        conn.commit()
        conn.close()
        return deleted_count
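
    # Illustrative persistent-memory usage (not executed); the chat and user ids
    # are hypothetical:
    #
    #   store.init_persistent_memory_db()
    #   mid = store.add_persistent_memory("12345678@chatroom", "group",
    #                                     "wxid_alice", "Alice", "Alice likes coffee")
    #   store.get_persistent_memories("12345678@chatroom")
    #   # -> [{"id": mid, "nickname": "Alice", "content": "Alice likes coffee", "time": ...}]
    #   store.delete_persistent_memory("12345678@chatroom", mid)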

    # ------------------ Legacy data scan / cleanup ------------------
    def find_legacy_group_history_keys(self) -> List[str]:
        """
        Find group_history keys that older versions wrote using the safe_id form.

        Returns:
            A list of legacy keys (nothing is deleted here).
        """
        redis_cache = get_cache()
        if not redis_cache or not redis_cache.enabled:
            return []
        try:
            keys = redis_cache.client.keys("group_history:*")
            legacy = []
            for k in keys or []:
                # New keys normally contain "@chatroom"; legacy safe_id keys do not contain "@".
                if "@chatroom" not in k and "_" in k:
                    legacy.append(k)
            return legacy
        except Exception as e:
            logger.debug(f"[ContextStore] Failed to scan legacy group_history keys: {e}")
            return []

    def delete_legacy_group_history_keys(self, legacy_keys: List[str]) -> int:
        """Delete the given list of legacy group_history keys."""
        redis_cache = get_cache()
        if not redis_cache or not redis_cache.enabled or not legacy_keys:
            return 0
        try:
            return redis_cache.client.delete(*legacy_keys)
        except Exception as e:
            logger.debug(f"[ContextStore] Failed to delete legacy group_history keys: {e}")
            return 0
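

# A minimal maintenance sketch (assumption: run manually, e.g. from a one-off
# script); it only chains the two legacy-key helpers defined above and is not
# called anywhere in this module.
def _example_cleanup_legacy_group_history(store: ContextStore) -> int:
    """Scan for legacy group_history keys and delete them, returning the count."""
    legacy_keys = store.find_legacy_group_history_keys()
    if not legacy_keys:
        logger.info("[ContextStore] No legacy group_history keys found")
        return 0
    deleted = store.delete_legacy_group_history_keys(legacy_keys)
    logger.info(f"[ContextStore] Deleted {deleted} legacy group_history keys")
    return deleted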