feat: clean up the legacy spaghetti code

2025-12-12 18:35:39 +08:00
parent 6156064f56
commit c1983172af
10 changed files with 1257 additions and 582 deletions

utils/context_store.py (new file, +470 lines)

@@ -0,0 +1,470 @@
"""
上下文/存储统一封装
提供统一的会话上下文读写接口:
- 私聊/单人会话 memory优先 Redis降级内存
- 群聊 history优先 Redis降级文件
- 持久记忆 sqlite
AIChat 只需要通过本模块读写消息,不再关心介质细节。
"""
from __future__ import annotations
import asyncio
import json
import sqlite3
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional
from loguru import logger
from utils.redis_cache import get_cache
def _safe_chat_id(chat_id: str) -> str:
return (chat_id or "").replace("@", "_").replace(":", "_")
def _extract_text_from_multimodal(content: Any) -> str:
if isinstance(content, str):
return content
if isinstance(content, list):
parts = []
for item in content:
if isinstance(item, dict) and item.get("type") == "text":
parts.append(item.get("text", ""))
return "".join(parts).strip()
return str(content)
@dataclass
class HistoryRecord:
role: str = "user"
nickname: str = ""
content: Any = ""
timestamp: Any = None
wxid: Optional[str] = None
id: Optional[str] = None
@classmethod
def from_raw(cls, raw: Dict[str, Any]) -> "HistoryRecord":
role = raw.get("role") or "user"
nickname = raw.get("nickname") or raw.get("SenderNickname") or ""
content = raw.get("content") if "content" in raw else raw.get("Content", "")
ts = raw.get("timestamp") or raw.get("time") or raw.get("CreateTime")
wxid = raw.get("wxid") or raw.get("SenderWxid")
rid = raw.get("id") or raw.get("msgid")
return cls(role=role, nickname=nickname, content=content, timestamp=ts, wxid=wxid, id=rid)
def to_dict(self) -> Dict[str, Any]:
d = {
"role": self.role or "user",
"nickname": self.nickname,
"content": self.content,
}
if self.timestamp is not None:
d["timestamp"] = self.timestamp
if self.wxid:
d["wxid"] = self.wxid
if self.id:
d["id"] = self.id
return d
class ContextStore:
"""
统一上下文存储。
Args:
config: AIChat 配置 dict
history_dir: 历史文件目录(群聊降级)
memory_fallback: AIChat 内存 dict私聊降级
history_locks: AIChat locks dict文件写入
persistent_db_path: sqlite 文件路径
"""
def __init__(
self,
config: Dict[str, Any],
history_dir: Optional[Path],
memory_fallback: Dict[str, List[Dict[str, Any]]],
history_locks: Dict[str, asyncio.Lock],
persistent_db_path: Optional[Path],
):
self.config = config or {}
self.history_dir = history_dir
self.memory_fallback = memory_fallback
self.history_locks = history_locks
self.persistent_db_path = persistent_db_path
# ------------------ Private-chat memory ------------------
def _use_redis_for_memory(self) -> bool:
redis_config = self.config.get("redis", {})
if not redis_config.get("use_redis_history", True):
return False
redis_cache = get_cache()
return bool(redis_cache and redis_cache.enabled)
def add_private_message(
self,
chat_id: str,
role: str,
content: Any,
*,
image_base64: str = None,
nickname: str = "",
sender_wxid: str = None,
) -> None:
if not self.config.get("memory", {}).get("enabled", False):
return
if image_base64:
message_content = [
{"type": "text", "text": _extract_text_from_multimodal(content)},
{"type": "image_url", "image_url": {"url": image_base64}},
]
else:
message_content = content
redis_config = self.config.get("redis", {})
if self._use_redis_for_memory():
redis_cache = get_cache()
ttl = redis_config.get("chat_history_ttl", 86400)
try:
redis_cache.add_chat_message(
chat_id,
role,
message_content,
nickname=nickname,
sender_wxid=sender_wxid,
ttl=ttl,
)
max_messages = self.config.get("memory", {}).get("max_messages", 20)
redis_cache.trim_chat_history(chat_id, max_messages)
return
except Exception as e:
logger.debug(f"[ContextStore] Failed to write private history to Redis: {e}")
if chat_id not in self.memory_fallback:
self.memory_fallback[chat_id] = []
self.memory_fallback[chat_id].append({"role": role, "content": message_content})
max_messages = self.config.get("memory", {}).get("max_messages", 20)
if len(self.memory_fallback[chat_id]) > max_messages:
self.memory_fallback[chat_id] = self.memory_fallback[chat_id][-max_messages:]
def get_private_messages(self, chat_id: str) -> List[Dict[str, Any]]:
if not self.config.get("memory", {}).get("enabled", False):
return []
if self._use_redis_for_memory():
redis_cache = get_cache()
max_messages = self.config.get("memory", {}).get("max_messages", 20)
try:
history = redis_cache.get_chat_history(chat_id, max_messages)
return [HistoryRecord.from_raw(h).to_dict() for h in history]
except Exception as e:
logger.debug(f"[ContextStore] Failed to read private history from Redis: {e}")
return self.memory_fallback.get(chat_id, [])
def clear_private_messages(self, chat_id: str) -> None:
if self._use_redis_for_memory():
redis_cache = get_cache()
try:
redis_cache.clear_chat_history(chat_id)
except Exception:
pass
self.memory_fallback.pop(chat_id, None)
# ------------------ Group-chat history ------------------
def _use_redis_for_group_history(self) -> bool:
redis_config = self.config.get("redis", {})
if not redis_config.get("use_redis_history", True):
return False
redis_cache = get_cache()
return bool(redis_cache and redis_cache.enabled)
def _get_history_file(self, chat_id: str) -> Optional[Path]:
if not self.history_dir:
return None
return self.history_dir / f"{_safe_chat_id(chat_id)}.json"
def _get_history_lock(self, chat_id: str) -> asyncio.Lock:
lock = self.history_locks.get(chat_id)
if lock is None:
lock = asyncio.Lock()
self.history_locks[chat_id] = lock
return lock
def _read_history_file(self, history_file: Path) -> List[Dict[str, Any]]:
try:
with open(history_file, "r", encoding="utf-8") as f:
return json.load(f)
except FileNotFoundError:
return []
except Exception as e:
logger.error(f"Failed to read history file: {history_file}, {e}")
return []
def _write_history_file(self, history_file: Path, history: List[Dict[str, Any]]) -> None:
history_file.parent.mkdir(parents=True, exist_ok=True)
temp_file = Path(str(history_file) + ".tmp")
with open(temp_file, "w", encoding="utf-8") as f:
json.dump(history, f, ensure_ascii=False, indent=2)
temp_file.replace(history_file)
async def load_group_history(self, chat_id: str) -> List[Dict[str, Any]]:
if not self.config.get("history", {}).get("enabled", True):
return []
if self._use_redis_for_group_history():
redis_cache = get_cache()
max_history = self.config.get("history", {}).get("max_history", 100)
try:
history = redis_cache.get_group_history(chat_id, max_history)
return [HistoryRecord.from_raw(h).to_dict() for h in history]
except Exception as e:
logger.debug(f"[ContextStore] Failed to read group history from Redis: {e}")
history_file = self._get_history_file(chat_id)
if not history_file:
return []
lock = self._get_history_lock(chat_id)
async with lock:
raw_history = self._read_history_file(history_file)
return [HistoryRecord.from_raw(h).to_dict() for h in raw_history]
async def add_group_message(
self,
chat_id: str,
nickname: str,
content: Any,
*,
record_id: str = None,
image_base64: str = None,
role: str = "user",
sender_wxid: str = None,
) -> None:
if not self.config.get("history", {}).get("enabled", True):
return
if image_base64:
message_content = [
{"type": "text", "text": _extract_text_from_multimodal(content)},
{"type": "image_url", "image_url": {"url": image_base64}},
]
else:
message_content = content
if self._use_redis_for_group_history():
redis_cache = get_cache()
redis_config = self.config.get("redis", {})
ttl = redis_config.get("group_history_ttl", 172800)
try:
redis_cache.add_group_message(
chat_id,
nickname,
message_content,
record_id=record_id,
role=role,
sender_wxid=sender_wxid,
ttl=ttl,
)
max_history = self.config.get("history", {}).get("max_history", 100)
redis_cache.trim_group_history(chat_id, max_history)
return
except Exception as e:
logger.debug(f"[ContextStore] Failed to write group history to Redis: {e}")
history_file = self._get_history_file(chat_id)
if not history_file:
return
lock = self._get_history_lock(chat_id)
async with lock:
history = self._read_history_file(history_file)
record = HistoryRecord(
role=role or "user",
nickname=nickname,
content=message_content,
timestamp=datetime.now().isoformat(),
wxid=sender_wxid,
id=record_id,
)
history.append(record.to_dict())
max_history = self.config.get("history", {}).get("max_history", 100)
if len(history) > max_history:
history = history[-max_history:]
self._write_history_file(history_file, history)
async def update_group_message_by_id(self, chat_id: str, record_id: str, new_content: Any) -> None:
if not self.config.get("history", {}).get("enabled", True):
return
if self._use_redis_for_group_history():
redis_cache = get_cache()
try:
redis_cache.update_group_message_by_id(chat_id, record_id, new_content)
return
except Exception as e:
logger.debug(f"[ContextStore] Failed to update group history in Redis: {e}")
history_file = self._get_history_file(chat_id)
if not history_file:
return
lock = self._get_history_lock(chat_id)
async with lock:
history = self._read_history_file(history_file)
for rec in history:
if rec.get("id") == record_id:
rec["content"] = new_content
break
max_history = self.config.get("history", {}).get("max_history", 100)
if len(history) > max_history:
history = history[-max_history:]
self._write_history_file(history_file, history)
# ------------------ Persistent memory (SQLite) ------------------
def init_persistent_memory_db(self) -> Optional[Path]:
if not self.persistent_db_path:
return None
self.persistent_db_path.parent.mkdir(exist_ok=True, parents=True)
conn = sqlite3.connect(self.persistent_db_path)
cursor = conn.cursor()
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS memories (
id INTEGER PRIMARY KEY AUTOINCREMENT,
chat_id TEXT NOT NULL,
chat_type TEXT NOT NULL,
user_wxid TEXT NOT NULL,
user_nickname TEXT,
content TEXT NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
"""
)
cursor.execute("CREATE INDEX IF NOT EXISTS idx_chat_id ON memories(chat_id)")
conn.commit()
conn.close()
logger.info(f"Persistent memory database initialized: {self.persistent_db_path}")
return self.persistent_db_path
def add_persistent_memory(
self,
chat_id: str,
chat_type: str,
user_wxid: str,
user_nickname: str,
content: str,
) -> int:
if not self.persistent_db_path:
return -1
conn = sqlite3.connect(self.persistent_db_path)
cursor = conn.cursor()
cursor.execute(
"""
INSERT INTO memories (chat_id, chat_type, user_wxid, user_nickname, content)
VALUES (?, ?, ?, ?, ?)
""",
(chat_id, chat_type, user_wxid, user_nickname, content),
)
memory_id = cursor.lastrowid
conn.commit()
conn.close()
return memory_id
def get_persistent_memories(self, chat_id: str) -> List[Dict[str, Any]]:
if not self.persistent_db_path:
return []
conn = sqlite3.connect(self.persistent_db_path)
cursor = conn.cursor()
cursor.execute(
"""
SELECT id, user_nickname, content, created_at
FROM memories
WHERE chat_id = ?
ORDER BY created_at ASC
""",
(chat_id,),
)
rows = cursor.fetchall()
conn.close()
return [
{"id": r[0], "nickname": r[1], "content": r[2], "time": r[3]}
for r in rows
]
def delete_persistent_memory(self, chat_id: str, memory_id: int) -> bool:
if not self.persistent_db_path:
return False
conn = sqlite3.connect(self.persistent_db_path)
cursor = conn.cursor()
cursor.execute(
"DELETE FROM memories WHERE id = ? AND chat_id = ?",
(memory_id, chat_id),
)
deleted = cursor.rowcount > 0
conn.commit()
conn.close()
return deleted
def clear_persistent_memories(self, chat_id: str) -> int:
if not self.persistent_db_path:
return 0
conn = sqlite3.connect(self.persistent_db_path)
cursor = conn.cursor()
cursor.execute("DELETE FROM memories WHERE chat_id = ?", (chat_id,))
deleted_count = cursor.rowcount
conn.commit()
conn.close()
return deleted_count
# ------------------ Legacy data scan/cleanup ------------------
def find_legacy_group_history_keys(self) -> List[str]:
"""
发现旧版本使用 safe_id 写入的 group_history key。
Returns:
legacy_keys 列表(不删除)
"""
redis_cache = get_cache()
if not redis_cache or not redis_cache.enabled:
return []
try:
keys = redis_cache.client.keys("group_history:*")
legacy = []
for k in keys or []:
# New keys normally contain @chatroom; legacy safe_id keys do not contain @
if "@chatroom" not in k and "_" in k:
legacy.append(k)
return legacy
except Exception as e:
logger.debug(f"[ContextStore] Failed to scan legacy group_history keys: {e}")
return []
def delete_legacy_group_history_keys(self, legacy_keys: List[str]) -> int:
"""删除给定 legacy key 列表"""
redis_cache = get_cache()
if not redis_cache or not redis_cache.enabled or not legacy_keys:
return 0
try:
return redis_cache.client.delete(*legacy_keys)
except Exception as e:
logger.debug(f"[ContextStore] Failed to delete legacy group_history keys: {e}")
return 0
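
As a quick orientation, a minimal usage sketch of the new store follows. The wiring is illustrative only (paths, chat IDs, and config values are made up); the real AIChat plugin passes in its own config and fallback dicts.

# Illustrative only -- not part of this commit.
import asyncio
from pathlib import Path
from utils.context_store import ContextStore

config = {
    "memory": {"enabled": True, "max_messages": 20},
    "history": {"enabled": True, "max_history": 100},
    "redis": {"use_redis_history": True},
}
store = ContextStore(
    config=config,
    history_dir=Path("data/history"),        # group-chat file fallback
    memory_fallback={},                      # private-chat in-memory fallback
    history_locks={},
    persistent_db_path=Path("data/memories.db"),
)

async def demo():
    # Private chat: goes to Redis when available, otherwise memory_fallback
    store.add_private_message("wxid_user1", "user", "hello")
    print(store.get_private_messages("wxid_user1"))
    # Group chat: goes to Redis when available, otherwise a JSON file under history_dir
    await store.add_group_message("12345@chatroom", "Alice", "hi all", sender_wxid="wxid_user1")
    print(await store.load_group_history("12345@chatroom"))

asyncio.run(demo())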

utils/llm_tooling.py (new file, +183 lines)

@@ -0,0 +1,183 @@
"""
LLM 工具体系公共模块
统一工具收集、参数校验与执行结果结构,供 AIChat 等插件使用。
"""
from __future__ import annotations
import json
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
from loguru import logger
@dataclass
class ToolResult:
"""统一的工具执行结果结构"""
success: bool = True
message: str = ""
need_ai_reply: bool = False
already_sent: bool = False
send_result_text: bool = False
no_reply: bool = False
save_to_memory: bool = False
@classmethod
def from_raw(cls, raw: Any) -> Optional["ToolResult"]:
if raw is None:
return None
if not isinstance(raw, dict):
return cls(success=True, message=str(raw))
msg = raw.get("message", "")
if not isinstance(msg, str):
try:
msg = json.dumps(msg, ensure_ascii=False)
except Exception:
msg = str(msg)
return cls(
success=bool(raw.get("success", True)),
message=msg,
need_ai_reply=bool(raw.get("need_ai_reply", False)),
already_sent=bool(raw.get("already_sent", False)),
send_result_text=bool(raw.get("send_result_text", False)),
no_reply=bool(raw.get("no_reply", False)),
save_to_memory=bool(raw.get("save_to_memory", False)),
)
def collect_tools_with_plugins(
tools_config: Dict[str, Any],
plugins: Dict[str, Any],
) -> Dict[str, Tuple[str, Dict[str, Any]]]:
"""
收集所有插件的 LLM 工具,并保留来源插件名。
Args:
tools_config: AIChat 配置中的 [tools] 节
plugins: PluginManager().plugins 映射
Returns:
{tool_name: (plugin_name, tool_dict)}
"""
tools_by_name: Dict[str, Tuple[str, Dict[str, Any]]] = {}
mode = tools_config.get("mode", "all")
whitelist = set(tools_config.get("whitelist", []))
blacklist = set(tools_config.get("blacklist", []))
for plugin_name, plugin in plugins.items():
if not hasattr(plugin, "get_llm_tools"):
continue
plugin_tools = plugin.get_llm_tools() or []
for tool in plugin_tools:
tool_name = tool.get("function", {}).get("name", "")
if not tool_name:
continue
if mode == "whitelist" and tool_name not in whitelist:
continue
if mode == "blacklist" and tool_name in blacklist:
logger.debug(f"[blacklist] tool disabled: {tool_name}")
continue
if tool_name in tools_by_name:
logger.warning(f"Duplicate tool name {tool_name} from {plugin_name}, ignored")
continue
tools_by_name[tool_name] = (plugin_name, tool)
if mode == "whitelist":
logger.debug(f"[whitelist] tool enabled: {tool_name}")
return tools_by_name
def collect_tools(
tools_config: Dict[str, Any],
plugins: Dict[str, Any],
) -> List[Dict[str, Any]]:
"""仅返回工具定义列表"""
return [item[1] for item in collect_tools_with_plugins(tools_config, plugins).values()]
def get_tool_schema_map(
tools_map: Dict[str, Tuple[str, Dict[str, Any]]],
) -> Dict[str, Dict[str, Any]]:
"""构建工具名到参数 schema 的映射"""
schema_map: Dict[str, Dict[str, Any]] = {}
for name, (_plugin_name, tool) in tools_map.items():
fn = tool.get("function", {})
schema_map[name] = fn.get("parameters", {}) or {}
return schema_map
def validate_tool_arguments(
tool_name: str,
arguments: Dict[str, Any],
schema: Optional[Dict[str, Any]],
) -> Tuple[bool, str, Dict[str, Any]]:
"""
轻量校验并补全默认参数。
Returns:
(ok, error_message, new_arguments)
"""
if not schema:
return True, "", arguments
props = schema.get("properties", {}) or {}
required = schema.get("required", []) or []
# Apply default values
for key, prop in props.items():
if key not in arguments and isinstance(prop, dict) and "default" in prop:
arguments[key] = prop["default"]
missing = []
for key in required:
if key not in arguments or arguments[key] in (None, "", []):
missing.append(key)
if missing:
return False, f"Missing arguments: {', '.join(missing)}", arguments
# Enum and basic type validation
for key, prop in props.items():
if key not in arguments or not isinstance(prop, dict):
continue
value = arguments[key]
if "enum" in prop and value not in prop["enum"]:
return False, f"Argument {key} must be one of {prop['enum']}", arguments
expected_type = prop.get("type")
if expected_type == "integer":
try:
arguments[key] = int(value)
except Exception:
return False, f"Argument {key} must be an integer", arguments
elif expected_type == "number":
try:
arguments[key] = float(value)
except Exception:
return False, f"Argument {key} must be a number", arguments
elif expected_type == "boolean":
if isinstance(value, bool):
continue
if isinstance(value, str) and value.lower() in ("true", "false", "1", "0"):
arguments[key] = value.lower() in ("true", "1")
else:
return False, f"Argument {key} must be a boolean", arguments
elif expected_type == "string":
if not isinstance(value, str):
arguments[key] = str(value)
return True, "", arguments
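
For context, a rough end-to-end sketch of the helpers above; the WeatherPlugin and get_weather tool are hypothetical stand-ins for whatever plugins actually register.

# Illustrative only -- the plugin and tool below are hypothetical.
from utils.llm_tooling import (
    ToolResult,
    collect_tools_with_plugins,
    get_tool_schema_map,
    validate_tool_arguments,
)

class WeatherPlugin:
    def get_llm_tools(self):
        return [{
            "type": "function",
            "function": {
                "name": "get_weather",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "city": {"type": "string"},
                        "days": {"type": "integer", "default": 1},
                    },
                    "required": ["city"],
                },
            },
        }]

tools_config = {"mode": "whitelist", "whitelist": ["get_weather"]}
tools_map = collect_tools_with_plugins(tools_config, {"WeatherPlugin": WeatherPlugin()})
schema_map = get_tool_schema_map(tools_map)

ok, err, args = validate_tool_arguments("get_weather", {"city": "Beijing"}, schema_map["get_weather"])
# ok == True, err == "", args == {"city": "Beijing", "days": 1}  (default filled in)

result = ToolResult.from_raw({"success": True, "message": "sunny", "need_ai_reply": True})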

utils/redis_cache.py (modified)

@@ -322,7 +322,18 @@ class RedisCache:
logger.error(f"Failed to get chat history: {chat_id}, {e}")
return []
def add_chat_message(self, chat_id: str, role: str, content, ttl: int = 86400) -> bool:
def add_chat_message(
self,
chat_id: str,
role: str,
content,
ttl: int = 86400,
*,
nickname: str = None,
sender_wxid: str = None,
record_id: str = None,
timestamp: float = None,
) -> bool:
"""
Add a message to the chat history.
@@ -331,6 +342,10 @@ class RedisCache:
role: role (user/assistant)
content: message content (string or list)
ttl: expiry time, default 24 hours
nickname: optional nickname (for the unified schema)
sender_wxid: optional sender wxid
record_id: optional record ID
timestamp: optional timestamp
Returns:
Whether the message was added successfully
@@ -340,7 +355,18 @@ class RedisCache:
try:
key = self._make_key("chat_history", chat_id)
message = {"role": role, "content": content}
import time as _time
message = {
"role": role or "user",
"content": content,
}
if nickname:
message["nickname"] = nickname
if sender_wxid:
message["wxid"] = sender_wxid
if record_id:
message["id"] = record_id
message["timestamp"] = timestamp or _time.time()
self.client.rpush(key, json.dumps(message, ensure_ascii=False))
self.client.expire(key, ttl)
return True
@@ -416,8 +442,17 @@ class RedisCache:
logger.error(f"Failed to get group history: {group_id}, {e}")
return []
def add_group_message(self, group_id: str, nickname: str, content,
record_id: str = None, ttl: int = 86400) -> bool:
def add_group_message(
self,
group_id: str,
nickname: str,
content,
record_id: str = None,
*,
role: str = "user",
sender_wxid: str = None,
ttl: int = 86400,
) -> bool:
"""
Add a message to the group chat history.
@@ -426,6 +461,8 @@ class RedisCache:
nickname: sender nickname
content: message content
record_id: optional record ID (used for later updates)
role: role (user/assistant), default user
sender_wxid: optional sender wxid
ttl: expiry time, default 24 hours
Returns:
@@ -438,10 +475,13 @@ class RedisCache:
import time
key = self._make_key("group_history", group_id)
message = {
"role": role or "user",
"nickname": nickname,
"content": content,
"timestamp": time.time()
}
if sender_wxid:
message["wxid"] = sender_wxid
if record_id:
message["id"] = record_id