feat:优化屎山
This commit is contained in:
Binary file not shown.
File diff suppressed because it is too large
Load Diff
Binary file not shown.
@@ -115,6 +115,19 @@ class AutoReply(PluginBase):
|
||||
logger.error(f"[AutoReply] 初始化失败: {e}")
|
||||
self.config = None
|
||||
|
||||
async def on_disable(self):
|
||||
"""插件禁用时调用,清理后台判断任务"""
|
||||
await super().on_disable()
|
||||
|
||||
if self.pending_tasks:
|
||||
for task in self.pending_tasks.values():
|
||||
task.cancel()
|
||||
await asyncio.gather(*self.pending_tasks.values(), return_exceptions=True)
|
||||
self.pending_tasks.clear()
|
||||
|
||||
self.judging.clear()
|
||||
logger.info("[AutoReply] 已清理后台判断任务")
|
||||
|
||||
def _load_bot_info(self):
|
||||
"""加载机器人信息"""
|
||||
try:
|
||||
@@ -294,6 +307,8 @@ class AutoReply(PluginBase):
|
||||
# 直接调用 AIChat 生成回复(基于最新上下文)
|
||||
await self._trigger_ai_reply(bot, pending.from_wxid)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[AutoReply] 后台判断异常: {e}")
|
||||
import traceback
|
||||
@@ -316,8 +331,7 @@ class AutoReply(PluginBase):
|
||||
return
|
||||
|
||||
# 获取最新的历史记录作为上下文
|
||||
chat_id = self._normalize_chat_id(from_wxid)
|
||||
recent_context = await self._get_recent_context_for_reply(chat_id)
|
||||
recent_context = await self._get_recent_context_for_reply(from_wxid)
|
||||
|
||||
if not recent_context:
|
||||
logger.warning("[AutoReply] 无法获取上下文")
|
||||
@@ -342,10 +356,10 @@ class AutoReply(PluginBase):
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
async def _get_recent_context_for_reply(self, chat_id: str) -> str:
|
||||
async def _get_recent_context_for_reply(self, group_id: str) -> str:
|
||||
"""获取最近的上下文用于生成回复"""
|
||||
try:
|
||||
history = await self._get_history(chat_id)
|
||||
history = await self._get_history(group_id)
|
||||
if not history:
|
||||
return ""
|
||||
|
||||
@@ -353,35 +367,10 @@ class AutoReply(PluginBase):
|
||||
count = self.config.get('context', {}).get('messages_count', 5)
|
||||
recent = history[-count:] if len(history) > count else history
|
||||
|
||||
# 构建上下文摘要
|
||||
context_lines = []
|
||||
for record in recent:
|
||||
nickname = record.get('nickname', '未知')
|
||||
content = record.get('content', '')
|
||||
if isinstance(content, list):
|
||||
# 多模态内容,提取文本
|
||||
for item in content:
|
||||
if item.get('type') == 'text':
|
||||
content = item.get('text', '')
|
||||
break
|
||||
else:
|
||||
content = '[图片]'
|
||||
if len(content) > 50:
|
||||
content = content[:50] + "..."
|
||||
context_lines.append(f"{nickname}: {content}")
|
||||
|
||||
# 返回最后一条消息作为触发内容(AIChat 会读取完整历史)
|
||||
if recent:
|
||||
last = recent[-1]
|
||||
last_content = last.get('content', '')
|
||||
if isinstance(last_content, list):
|
||||
for item in last_content:
|
||||
if item.get('type') == 'text':
|
||||
return item.get('text', '')
|
||||
return '[图片]'
|
||||
return last_content
|
||||
|
||||
return ""
|
||||
# 自动回复触发不再把最后一条用户消息再次发给 AI,
|
||||
# 避免在上下文里出现“同一句话重复两遍”的错觉。
|
||||
# AIChat 会读取完整历史 recent_history。
|
||||
return "(自动回复触发)请基于最近群聊内容,自然地回复一句,不要复述提示本身。"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[AutoReply] 获取上下文失败: {e}")
|
||||
@@ -389,12 +378,13 @@ class AutoReply(PluginBase):
|
||||
|
||||
async def _judge_with_small_model(self, from_wxid: str, content: str) -> JudgeResult:
|
||||
"""使用小模型判断是否需要回复"""
|
||||
chat_id = self._normalize_chat_id(from_wxid)
|
||||
chat_state = self._get_chat_state(chat_id)
|
||||
group_id = from_wxid
|
||||
state_id = self._normalize_chat_id(group_id)
|
||||
chat_state = self._get_chat_state(state_id)
|
||||
|
||||
# 获取最近消息历史
|
||||
recent_messages = await self._get_recent_messages(chat_id)
|
||||
last_bot_reply = await self._get_last_bot_reply(chat_id)
|
||||
recent_messages = await self._get_recent_messages(group_id)
|
||||
last_bot_reply = await self._get_last_bot_reply(group_id)
|
||||
|
||||
# 构建判断提示词
|
||||
reasoning_part = ',\n "reasoning": "简短分析原因(20字内)"' if self.config["judge"]["include_reasoning"] else ""
|
||||
@@ -403,7 +393,7 @@ class AutoReply(PluginBase):
|
||||
|
||||
## 当前状态
|
||||
- 精力: {chat_state.energy:.1f}/1.0
|
||||
- 上次发言: {self._get_minutes_since_last_reply(chat_id)}分钟前
|
||||
- 上次发言: {self._get_minutes_since_last_reply(state_id)}分钟前
|
||||
|
||||
## 最近对话
|
||||
{recent_messages}
|
||||
@@ -531,6 +521,13 @@ class AutoReply(PluginBase):
|
||||
if not aichat_plugin:
|
||||
return []
|
||||
|
||||
# 优先使用 AIChat 的统一 ContextStore
|
||||
if hasattr(aichat_plugin, "store") and aichat_plugin.store:
|
||||
try:
|
||||
return await aichat_plugin.store.load_group_history(chat_id)
|
||||
except Exception as e:
|
||||
logger.debug(f"[AutoReply] ContextStore 获取历史失败: {e}")
|
||||
|
||||
# 优先使用 Redis(与 AIChat 保持一致)
|
||||
try:
|
||||
from utils.redis_cache import get_cache
|
||||
@@ -548,7 +545,8 @@ class AutoReply(PluginBase):
|
||||
|
||||
# 降级到文件存储
|
||||
if hasattr(aichat_plugin, 'history_dir') and aichat_plugin.history_dir:
|
||||
history_file = aichat_plugin.history_dir / f"{chat_id}.json"
|
||||
safe_id = (chat_id or "").replace("@", "_").replace(":", "_")
|
||||
history_file = aichat_plugin.history_dir / f"{safe_id}.json"
|
||||
if history_file.exists():
|
||||
with open(history_file, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
@@ -558,10 +556,10 @@ class AutoReply(PluginBase):
|
||||
|
||||
return []
|
||||
|
||||
async def _get_recent_messages(self, chat_id: str) -> str:
|
||||
"""获取最近消息历史"""
|
||||
async def _get_recent_messages(self, group_id: str) -> str:
|
||||
"""获取最近消息历史(群聊)"""
|
||||
try:
|
||||
history = await self._get_history(chat_id)
|
||||
history = await self._get_history(group_id)
|
||||
if not history:
|
||||
return "暂无对话历史"
|
||||
|
||||
@@ -572,6 +570,12 @@ class AutoReply(PluginBase):
|
||||
for record in recent:
|
||||
nickname = record.get('nickname', '未知')
|
||||
content = record.get('content', '')
|
||||
if isinstance(content, list):
|
||||
text_parts = []
|
||||
for item in content:
|
||||
if item.get("type") == "text":
|
||||
text_parts.append(item.get("text", ""))
|
||||
content = "".join(text_parts).strip() or "[图片]"
|
||||
# 限制单条消息长度
|
||||
if len(content) > 100:
|
||||
content = content[:100] + "..."
|
||||
@@ -584,17 +588,23 @@ class AutoReply(PluginBase):
|
||||
|
||||
return "暂无对话历史"
|
||||
|
||||
async def _get_last_bot_reply(self, chat_id: str) -> Optional[str]:
|
||||
"""获取上次机器人回复"""
|
||||
async def _get_last_bot_reply(self, group_id: str) -> Optional[str]:
|
||||
"""获取上次机器人回复(群聊)"""
|
||||
try:
|
||||
history = await self._get_history(chat_id)
|
||||
history = await self._get_history(group_id)
|
||||
if not history:
|
||||
return None
|
||||
|
||||
# 从后往前查找机器人回复
|
||||
for record in reversed(history):
|
||||
if record.get('nickname') == self.bot_nickname:
|
||||
if record.get('role') == 'assistant' or record.get('nickname') == self.bot_nickname:
|
||||
content = record.get('content', '')
|
||||
if isinstance(content, list):
|
||||
text_parts = []
|
||||
for item in content:
|
||||
if item.get("type") == "text":
|
||||
text_parts.append(item.get("text", ""))
|
||||
content = "".join(text_parts).strip() or "[图片]"
|
||||
if len(content) > 100:
|
||||
content = content[:100] + "..."
|
||||
return content
|
||||
|
||||
@@ -1606,9 +1606,8 @@ class SignInPlugin(PluginBase):
|
||||
return {"success": True, "message": f"城市注册请求已处理: {city}"}
|
||||
|
||||
else:
|
||||
return {"success": False, "message": "未知的工具名称"}
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"LLM工具执行失败: {e}")
|
||||
return {"success": False, "message": f"执行失败: {str(e)}"}
|
||||
|
||||
|
||||
@@ -322,7 +322,7 @@ class WeatherPlugin(PluginBase):
|
||||
"""执行LLM工具调用,供AIChat插件调用"""
|
||||
try:
|
||||
if tool_name != "query_weather":
|
||||
return {"success": False, "message": "未知的工具名称"}
|
||||
return None
|
||||
|
||||
# 从 arguments 中获取用户信息
|
||||
user_wxid = arguments.get("user_wxid", from_wxid)
|
||||
|
||||
@@ -345,7 +345,7 @@ class ZImageTurbo(PluginBase):
|
||||
async def execute_llm_tool(self, tool_name: str, arguments: dict, bot: WechatHookClient, from_wxid: str) -> dict:
|
||||
"""执行LLM工具调用,供AIChat插件调用"""
|
||||
if tool_name != "generate_image":
|
||||
return {"success": False, "message": "未知的工具名称"}
|
||||
return None
|
||||
|
||||
try:
|
||||
prompt = arguments.get("prompt", "")
|
||||
|
||||
470
utils/context_store.py
Normal file
470
utils/context_store.py
Normal file
@@ -0,0 +1,470 @@
|
||||
"""
|
||||
上下文/存储统一封装
|
||||
|
||||
提供统一的会话上下文读写接口:
|
||||
- 私聊/单人会话 memory(优先 Redis,降级内存)
|
||||
- 群聊 history(优先 Redis,降级文件)
|
||||
- 持久记忆 sqlite
|
||||
|
||||
AIChat 只需要通过本模块读写消息,不再关心介质细节。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import sqlite3
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from utils.redis_cache import get_cache
|
||||
|
||||
|
||||
def _safe_chat_id(chat_id: str) -> str:
|
||||
return (chat_id or "").replace("@", "_").replace(":", "_")
|
||||
|
||||
|
||||
def _extract_text_from_multimodal(content: Any) -> str:
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts = []
|
||||
for item in content:
|
||||
if isinstance(item, dict) and item.get("type") == "text":
|
||||
parts.append(item.get("text", ""))
|
||||
return "".join(parts).strip()
|
||||
return str(content)
|
||||
|
||||
|
||||
@dataclass
|
||||
class HistoryRecord:
|
||||
role: str = "user"
|
||||
nickname: str = ""
|
||||
content: Any = ""
|
||||
timestamp: Any = None
|
||||
wxid: Optional[str] = None
|
||||
id: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def from_raw(cls, raw: Dict[str, Any]) -> "HistoryRecord":
|
||||
role = raw.get("role") or "user"
|
||||
nickname = raw.get("nickname") or raw.get("SenderNickname") or ""
|
||||
content = raw.get("content") if "content" in raw else raw.get("Content", "")
|
||||
ts = raw.get("timestamp") or raw.get("time") or raw.get("CreateTime")
|
||||
wxid = raw.get("wxid") or raw.get("SenderWxid")
|
||||
rid = raw.get("id") or raw.get("msgid")
|
||||
return cls(role=role, nickname=nickname, content=content, timestamp=ts, wxid=wxid, id=rid)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
d = {
|
||||
"role": self.role or "user",
|
||||
"nickname": self.nickname,
|
||||
"content": self.content,
|
||||
}
|
||||
if self.timestamp is not None:
|
||||
d["timestamp"] = self.timestamp
|
||||
if self.wxid:
|
||||
d["wxid"] = self.wxid
|
||||
if self.id:
|
||||
d["id"] = self.id
|
||||
return d
|
||||
|
||||
|
||||
class ContextStore:
|
||||
"""
|
||||
统一上下文存储。
|
||||
|
||||
Args:
|
||||
config: AIChat 配置 dict
|
||||
history_dir: 历史文件目录(群聊降级)
|
||||
memory_fallback: AIChat 内存 dict(私聊降级)
|
||||
history_locks: AIChat locks dict(文件写入)
|
||||
persistent_db_path: sqlite 文件路径
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Dict[str, Any],
|
||||
history_dir: Optional[Path],
|
||||
memory_fallback: Dict[str, List[Dict[str, Any]]],
|
||||
history_locks: Dict[str, asyncio.Lock],
|
||||
persistent_db_path: Optional[Path],
|
||||
):
|
||||
self.config = config or {}
|
||||
self.history_dir = history_dir
|
||||
self.memory_fallback = memory_fallback
|
||||
self.history_locks = history_locks
|
||||
self.persistent_db_path = persistent_db_path
|
||||
|
||||
# ------------------ 私聊 memory ------------------
|
||||
|
||||
def _use_redis_for_memory(self) -> bool:
|
||||
redis_config = self.config.get("redis", {})
|
||||
if not redis_config.get("use_redis_history", True):
|
||||
return False
|
||||
redis_cache = get_cache()
|
||||
return bool(redis_cache and redis_cache.enabled)
|
||||
|
||||
def add_private_message(
|
||||
self,
|
||||
chat_id: str,
|
||||
role: str,
|
||||
content: Any,
|
||||
*,
|
||||
image_base64: str = None,
|
||||
nickname: str = "",
|
||||
sender_wxid: str = None,
|
||||
) -> None:
|
||||
if not self.config.get("memory", {}).get("enabled", False):
|
||||
return
|
||||
|
||||
if image_base64:
|
||||
message_content = [
|
||||
{"type": "text", "text": _extract_text_from_multimodal(content)},
|
||||
{"type": "image_url", "image_url": {"url": image_base64}},
|
||||
]
|
||||
else:
|
||||
message_content = content
|
||||
|
||||
redis_config = self.config.get("redis", {})
|
||||
if self._use_redis_for_memory():
|
||||
redis_cache = get_cache()
|
||||
ttl = redis_config.get("chat_history_ttl", 86400)
|
||||
try:
|
||||
redis_cache.add_chat_message(
|
||||
chat_id,
|
||||
role,
|
||||
message_content,
|
||||
nickname=nickname,
|
||||
sender_wxid=sender_wxid,
|
||||
ttl=ttl,
|
||||
)
|
||||
max_messages = self.config.get("memory", {}).get("max_messages", 20)
|
||||
redis_cache.trim_chat_history(chat_id, max_messages)
|
||||
return
|
||||
except Exception as e:
|
||||
logger.debug(f"[ContextStore] Redis private history 写入失败: {e}")
|
||||
|
||||
if chat_id not in self.memory_fallback:
|
||||
self.memory_fallback[chat_id] = []
|
||||
self.memory_fallback[chat_id].append({"role": role, "content": message_content})
|
||||
max_messages = self.config.get("memory", {}).get("max_messages", 20)
|
||||
if len(self.memory_fallback[chat_id]) > max_messages:
|
||||
self.memory_fallback[chat_id] = self.memory_fallback[chat_id][-max_messages:]
|
||||
|
||||
def get_private_messages(self, chat_id: str) -> List[Dict[str, Any]]:
|
||||
if not self.config.get("memory", {}).get("enabled", False):
|
||||
return []
|
||||
|
||||
if self._use_redis_for_memory():
|
||||
redis_cache = get_cache()
|
||||
max_messages = self.config.get("memory", {}).get("max_messages", 20)
|
||||
try:
|
||||
history = redis_cache.get_chat_history(chat_id, max_messages)
|
||||
return [HistoryRecord.from_raw(h).to_dict() for h in history]
|
||||
except Exception as e:
|
||||
logger.debug(f"[ContextStore] Redis private history 读取失败: {e}")
|
||||
|
||||
return self.memory_fallback.get(chat_id, [])
|
||||
|
||||
def clear_private_messages(self, chat_id: str) -> None:
|
||||
if self._use_redis_for_memory():
|
||||
redis_cache = get_cache()
|
||||
try:
|
||||
redis_cache.clear_chat_history(chat_id)
|
||||
except Exception:
|
||||
pass
|
||||
self.memory_fallback.pop(chat_id, None)
|
||||
|
||||
# ------------------ 群聊 history ------------------
|
||||
|
||||
def _use_redis_for_group_history(self) -> bool:
|
||||
redis_config = self.config.get("redis", {})
|
||||
if not redis_config.get("use_redis_history", True):
|
||||
return False
|
||||
redis_cache = get_cache()
|
||||
return bool(redis_cache and redis_cache.enabled)
|
||||
|
||||
def _get_history_file(self, chat_id: str) -> Optional[Path]:
|
||||
if not self.history_dir:
|
||||
return None
|
||||
return self.history_dir / f"{_safe_chat_id(chat_id)}.json"
|
||||
|
||||
def _get_history_lock(self, chat_id: str) -> asyncio.Lock:
|
||||
lock = self.history_locks.get(chat_id)
|
||||
if lock is None:
|
||||
lock = asyncio.Lock()
|
||||
self.history_locks[chat_id] = lock
|
||||
return lock
|
||||
|
||||
def _read_history_file(self, history_file: Path) -> List[Dict[str, Any]]:
|
||||
try:
|
||||
with open(history_file, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except FileNotFoundError:
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error(f"读取历史记录失败: {history_file}, {e}")
|
||||
return []
|
||||
|
||||
def _write_history_file(self, history_file: Path, history: List[Dict[str, Any]]) -> None:
|
||||
history_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
temp_file = Path(str(history_file) + ".tmp")
|
||||
with open(temp_file, "w", encoding="utf-8") as f:
|
||||
json.dump(history, f, ensure_ascii=False, indent=2)
|
||||
temp_file.replace(history_file)
|
||||
|
||||
async def load_group_history(self, chat_id: str) -> List[Dict[str, Any]]:
|
||||
if not self.config.get("history", {}).get("enabled", True):
|
||||
return []
|
||||
|
||||
if self._use_redis_for_group_history():
|
||||
redis_cache = get_cache()
|
||||
max_history = self.config.get("history", {}).get("max_history", 100)
|
||||
try:
|
||||
history = redis_cache.get_group_history(chat_id, max_history)
|
||||
return [HistoryRecord.from_raw(h).to_dict() for h in history]
|
||||
except Exception as e:
|
||||
logger.debug(f"[ContextStore] Redis group history 读取失败: {e}")
|
||||
|
||||
history_file = self._get_history_file(chat_id)
|
||||
if not history_file:
|
||||
return []
|
||||
|
||||
lock = self._get_history_lock(chat_id)
|
||||
async with lock:
|
||||
raw_history = self._read_history_file(history_file)
|
||||
return [HistoryRecord.from_raw(h).to_dict() for h in raw_history]
|
||||
|
||||
async def add_group_message(
|
||||
self,
|
||||
chat_id: str,
|
||||
nickname: str,
|
||||
content: Any,
|
||||
*,
|
||||
record_id: str = None,
|
||||
image_base64: str = None,
|
||||
role: str = "user",
|
||||
sender_wxid: str = None,
|
||||
) -> None:
|
||||
if not self.config.get("history", {}).get("enabled", True):
|
||||
return
|
||||
|
||||
if image_base64:
|
||||
message_content = [
|
||||
{"type": "text", "text": _extract_text_from_multimodal(content)},
|
||||
{"type": "image_url", "image_url": {"url": image_base64}},
|
||||
]
|
||||
else:
|
||||
message_content = content
|
||||
|
||||
if self._use_redis_for_group_history():
|
||||
redis_cache = get_cache()
|
||||
redis_config = self.config.get("redis", {})
|
||||
ttl = redis_config.get("group_history_ttl", 172800)
|
||||
try:
|
||||
redis_cache.add_group_message(
|
||||
chat_id,
|
||||
nickname,
|
||||
message_content,
|
||||
record_id=record_id,
|
||||
role=role,
|
||||
sender_wxid=sender_wxid,
|
||||
ttl=ttl,
|
||||
)
|
||||
max_history = self.config.get("history", {}).get("max_history", 100)
|
||||
redis_cache.trim_group_history(chat_id, max_history)
|
||||
return
|
||||
except Exception as e:
|
||||
logger.debug(f"[ContextStore] Redis group history 写入失败: {e}")
|
||||
|
||||
history_file = self._get_history_file(chat_id)
|
||||
if not history_file:
|
||||
return
|
||||
|
||||
lock = self._get_history_lock(chat_id)
|
||||
async with lock:
|
||||
history = self._read_history_file(history_file)
|
||||
record = HistoryRecord(
|
||||
role=role or "user",
|
||||
nickname=nickname,
|
||||
content=message_content,
|
||||
timestamp=datetime.now().isoformat(),
|
||||
wxid=sender_wxid,
|
||||
id=record_id,
|
||||
)
|
||||
history.append(record.to_dict())
|
||||
max_history = self.config.get("history", {}).get("max_history", 100)
|
||||
if len(history) > max_history:
|
||||
history = history[-max_history:]
|
||||
self._write_history_file(history_file, history)
|
||||
|
||||
async def update_group_message_by_id(self, chat_id: str, record_id: str, new_content: Any) -> None:
|
||||
if not self.config.get("history", {}).get("enabled", True):
|
||||
return
|
||||
|
||||
if self._use_redis_for_group_history():
|
||||
redis_cache = get_cache()
|
||||
try:
|
||||
redis_cache.update_group_message_by_id(chat_id, record_id, new_content)
|
||||
return
|
||||
except Exception as e:
|
||||
logger.debug(f"[ContextStore] Redis group history 更新失败: {e}")
|
||||
|
||||
history_file = self._get_history_file(chat_id)
|
||||
if not history_file:
|
||||
return
|
||||
|
||||
lock = self._get_history_lock(chat_id)
|
||||
async with lock:
|
||||
history = self._read_history_file(history_file)
|
||||
for rec in history:
|
||||
if rec.get("id") == record_id:
|
||||
rec["content"] = new_content
|
||||
break
|
||||
max_history = self.config.get("history", {}).get("max_history", 100)
|
||||
if len(history) > max_history:
|
||||
history = history[-max_history:]
|
||||
self._write_history_file(history_file, history)
|
||||
|
||||
# ------------------ 持久记忆 sqlite ------------------
|
||||
|
||||
def init_persistent_memory_db(self) -> Optional[Path]:
|
||||
if not self.persistent_db_path:
|
||||
return None
|
||||
|
||||
self.persistent_db_path.parent.mkdir(exist_ok=True, parents=True)
|
||||
conn = sqlite3.connect(self.persistent_db_path)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS memories (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
chat_id TEXT NOT NULL,
|
||||
chat_type TEXT NOT NULL,
|
||||
user_wxid TEXT NOT NULL,
|
||||
user_nickname TEXT,
|
||||
content TEXT NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
"""
|
||||
)
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_chat_id ON memories(chat_id)")
|
||||
conn.commit()
|
||||
conn.close()
|
||||
logger.info(f"持久记忆数据库已初始化: {self.persistent_db_path}")
|
||||
return self.persistent_db_path
|
||||
|
||||
def add_persistent_memory(
|
||||
self,
|
||||
chat_id: str,
|
||||
chat_type: str,
|
||||
user_wxid: str,
|
||||
user_nickname: str,
|
||||
content: str,
|
||||
) -> int:
|
||||
if not self.persistent_db_path:
|
||||
return -1
|
||||
|
||||
conn = sqlite3.connect(self.persistent_db_path)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO memories (chat_id, chat_type, user_wxid, user_nickname, content)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
""",
|
||||
(chat_id, chat_type, user_wxid, user_nickname, content),
|
||||
)
|
||||
memory_id = cursor.lastrowid
|
||||
conn.commit()
|
||||
conn.close()
|
||||
return memory_id
|
||||
|
||||
def get_persistent_memories(self, chat_id: str) -> List[Dict[str, Any]]:
|
||||
if not self.persistent_db_path:
|
||||
return []
|
||||
|
||||
conn = sqlite3.connect(self.persistent_db_path)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT id, user_nickname, content, created_at
|
||||
FROM memories
|
||||
WHERE chat_id = ?
|
||||
ORDER BY created_at ASC
|
||||
""",
|
||||
(chat_id,),
|
||||
)
|
||||
rows = cursor.fetchall()
|
||||
conn.close()
|
||||
return [
|
||||
{"id": r[0], "nickname": r[1], "content": r[2], "time": r[3]}
|
||||
for r in rows
|
||||
]
|
||||
|
||||
def delete_persistent_memory(self, chat_id: str, memory_id: int) -> bool:
|
||||
if not self.persistent_db_path:
|
||||
return False
|
||||
|
||||
conn = sqlite3.connect(self.persistent_db_path)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"DELETE FROM memories WHERE id = ? AND chat_id = ?",
|
||||
(memory_id, chat_id),
|
||||
)
|
||||
deleted = cursor.rowcount > 0
|
||||
conn.commit()
|
||||
conn.close()
|
||||
return deleted
|
||||
|
||||
def clear_persistent_memories(self, chat_id: str) -> int:
|
||||
if not self.persistent_db_path:
|
||||
return 0
|
||||
|
||||
conn = sqlite3.connect(self.persistent_db_path)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("DELETE FROM memories WHERE chat_id = ?", (chat_id,))
|
||||
deleted_count = cursor.rowcount
|
||||
conn.commit()
|
||||
conn.close()
|
||||
return deleted_count
|
||||
|
||||
# ------------------ 旧数据扫描/清理 ------------------
|
||||
|
||||
def find_legacy_group_history_keys(self) -> List[str]:
|
||||
"""
|
||||
发现旧版本使用 safe_id 写入的 group_history key。
|
||||
|
||||
Returns:
|
||||
legacy_keys 列表(不删除)
|
||||
"""
|
||||
redis_cache = get_cache()
|
||||
if not redis_cache or not redis_cache.enabled:
|
||||
return []
|
||||
|
||||
try:
|
||||
keys = redis_cache.client.keys("group_history:*")
|
||||
legacy = []
|
||||
for k in keys or []:
|
||||
# 新 key 一般包含 @chatroom;旧 safe_id 不包含 @
|
||||
if "@chatroom" not in k and "_" in k:
|
||||
legacy.append(k)
|
||||
return legacy
|
||||
except Exception as e:
|
||||
logger.debug(f"[ContextStore] 扫描 legacy group_history keys 失败: {e}")
|
||||
return []
|
||||
|
||||
def delete_legacy_group_history_keys(self, legacy_keys: List[str]) -> int:
|
||||
"""删除给定 legacy key 列表"""
|
||||
redis_cache = get_cache()
|
||||
if not redis_cache or not redis_cache.enabled or not legacy_keys:
|
||||
return 0
|
||||
try:
|
||||
return redis_cache.client.delete(*legacy_keys)
|
||||
except Exception as e:
|
||||
logger.debug(f"[ContextStore] 删除 legacy group_history keys 失败: {e}")
|
||||
return 0
|
||||
183
utils/llm_tooling.py
Normal file
183
utils/llm_tooling.py
Normal file
@@ -0,0 +1,183 @@
|
||||
"""
|
||||
LLM 工具体系公共模块
|
||||
|
||||
统一工具收集、参数校验与执行结果结构,供 AIChat 等插件使用。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolResult:
|
||||
"""统一的工具执行结果结构"""
|
||||
|
||||
success: bool = True
|
||||
message: str = ""
|
||||
need_ai_reply: bool = False
|
||||
already_sent: bool = False
|
||||
send_result_text: bool = False
|
||||
no_reply: bool = False
|
||||
save_to_memory: bool = False
|
||||
|
||||
@classmethod
|
||||
def from_raw(cls, raw: Any) -> Optional["ToolResult"]:
|
||||
if raw is None:
|
||||
return None
|
||||
|
||||
if not isinstance(raw, dict):
|
||||
return cls(success=True, message=str(raw))
|
||||
|
||||
msg = raw.get("message", "")
|
||||
if not isinstance(msg, str):
|
||||
try:
|
||||
msg = json.dumps(msg, ensure_ascii=False)
|
||||
except Exception:
|
||||
msg = str(msg)
|
||||
|
||||
return cls(
|
||||
success=bool(raw.get("success", True)),
|
||||
message=msg,
|
||||
need_ai_reply=bool(raw.get("need_ai_reply", False)),
|
||||
already_sent=bool(raw.get("already_sent", False)),
|
||||
send_result_text=bool(raw.get("send_result_text", False)),
|
||||
no_reply=bool(raw.get("no_reply", False)),
|
||||
save_to_memory=bool(raw.get("save_to_memory", False)),
|
||||
)
|
||||
|
||||
|
||||
def collect_tools_with_plugins(
|
||||
tools_config: Dict[str, Any],
|
||||
plugins: Dict[str, Any],
|
||||
) -> Dict[str, Tuple[str, Dict[str, Any]]]:
|
||||
"""
|
||||
收集所有插件的 LLM 工具,并保留来源插件名。
|
||||
|
||||
Args:
|
||||
tools_config: AIChat 配置中的 [tools] 节
|
||||
plugins: PluginManager().plugins 映射
|
||||
|
||||
Returns:
|
||||
{tool_name: (plugin_name, tool_dict)}
|
||||
"""
|
||||
tools_by_name: Dict[str, Tuple[str, Dict[str, Any]]] = {}
|
||||
|
||||
mode = tools_config.get("mode", "all")
|
||||
whitelist = set(tools_config.get("whitelist", []))
|
||||
blacklist = set(tools_config.get("blacklist", []))
|
||||
|
||||
for plugin_name, plugin in plugins.items():
|
||||
if not hasattr(plugin, "get_llm_tools"):
|
||||
continue
|
||||
|
||||
plugin_tools = plugin.get_llm_tools() or []
|
||||
for tool in plugin_tools:
|
||||
tool_name = tool.get("function", {}).get("name", "")
|
||||
if not tool_name:
|
||||
continue
|
||||
|
||||
if mode == "whitelist" and tool_name not in whitelist:
|
||||
continue
|
||||
if mode == "blacklist" and tool_name in blacklist:
|
||||
logger.debug(f"[黑名单] 禁用工具: {tool_name}")
|
||||
continue
|
||||
|
||||
if tool_name in tools_by_name:
|
||||
logger.warning(f"重复工具名 {tool_name} 来自 {plugin_name},已忽略")
|
||||
continue
|
||||
|
||||
tools_by_name[tool_name] = (plugin_name, tool)
|
||||
if mode == "whitelist":
|
||||
logger.debug(f"[白名单] 启用工具: {tool_name}")
|
||||
|
||||
return tools_by_name
|
||||
|
||||
|
||||
def collect_tools(
|
||||
tools_config: Dict[str, Any],
|
||||
plugins: Dict[str, Any],
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""仅返回工具定义列表"""
|
||||
return [item[1] for item in collect_tools_with_plugins(tools_config, plugins).values()]
|
||||
|
||||
|
||||
def get_tool_schema_map(
|
||||
tools_map: Dict[str, Tuple[str, Dict[str, Any]]],
|
||||
) -> Dict[str, Dict[str, Any]]:
|
||||
"""构建工具名到参数 schema 的映射"""
|
||||
schema_map: Dict[str, Dict[str, Any]] = {}
|
||||
for name, (_plugin_name, tool) in tools_map.items():
|
||||
fn = tool.get("function", {})
|
||||
schema_map[name] = fn.get("parameters", {}) or {}
|
||||
return schema_map
|
||||
|
||||
|
||||
def validate_tool_arguments(
|
||||
tool_name: str,
|
||||
arguments: Dict[str, Any],
|
||||
schema: Optional[Dict[str, Any]],
|
||||
) -> Tuple[bool, str, Dict[str, Any]]:
|
||||
"""
|
||||
轻量校验并补全默认参数。
|
||||
|
||||
Returns:
|
||||
(ok, error_message, new_arguments)
|
||||
"""
|
||||
if not schema:
|
||||
return True, "", arguments
|
||||
|
||||
props = schema.get("properties", {}) or {}
|
||||
required = schema.get("required", []) or []
|
||||
|
||||
# 应用默认值
|
||||
for key, prop in props.items():
|
||||
if key not in arguments and isinstance(prop, dict) and "default" in prop:
|
||||
arguments[key] = prop["default"]
|
||||
|
||||
missing = []
|
||||
for key in required:
|
||||
if key not in arguments or arguments[key] in (None, "", []):
|
||||
missing.append(key)
|
||||
|
||||
if missing:
|
||||
return False, f"缺少参数: {', '.join(missing)}", arguments
|
||||
|
||||
# 枚举与基础类型校验
|
||||
for key, prop in props.items():
|
||||
if key not in arguments or not isinstance(prop, dict):
|
||||
continue
|
||||
|
||||
value = arguments[key]
|
||||
|
||||
if "enum" in prop and value not in prop["enum"]:
|
||||
return False, f"参数 {key} 必须是 {prop['enum']}", arguments
|
||||
|
||||
expected_type = prop.get("type")
|
||||
if expected_type == "integer":
|
||||
try:
|
||||
arguments[key] = int(value)
|
||||
except Exception:
|
||||
return False, f"参数 {key} 应为整数", arguments
|
||||
elif expected_type == "number":
|
||||
try:
|
||||
arguments[key] = float(value)
|
||||
except Exception:
|
||||
return False, f"参数 {key} 应为数字", arguments
|
||||
elif expected_type == "boolean":
|
||||
if isinstance(value, bool):
|
||||
continue
|
||||
if isinstance(value, str) and value.lower() in ("true", "false", "1", "0"):
|
||||
arguments[key] = value.lower() in ("true", "1")
|
||||
else:
|
||||
return False, f"参数 {key} 应为布尔值", arguments
|
||||
elif expected_type == "string":
|
||||
if not isinstance(value, str):
|
||||
arguments[key] = str(value)
|
||||
|
||||
return True, "", arguments
|
||||
|
||||
@@ -322,7 +322,18 @@ class RedisCache:
|
||||
logger.error(f"获取对话历史失败: {chat_id}, {e}")
|
||||
return []
|
||||
|
||||
def add_chat_message(self, chat_id: str, role: str, content, ttl: int = 86400) -> bool:
|
||||
def add_chat_message(
|
||||
self,
|
||||
chat_id: str,
|
||||
role: str,
|
||||
content,
|
||||
ttl: int = 86400,
|
||||
*,
|
||||
nickname: str = None,
|
||||
sender_wxid: str = None,
|
||||
record_id: str = None,
|
||||
timestamp: float = None,
|
||||
) -> bool:
|
||||
"""
|
||||
添加消息到对话历史
|
||||
|
||||
@@ -331,6 +342,10 @@ class RedisCache:
|
||||
role: 角色 (user/assistant)
|
||||
content: 消息内容(字符串或列表)
|
||||
ttl: 过期时间(秒),默认24小时
|
||||
nickname: 可选昵称(用于统一 schema)
|
||||
sender_wxid: 可选发送者 wxid
|
||||
record_id: 可选记录 ID
|
||||
timestamp: 可选时间戳
|
||||
|
||||
Returns:
|
||||
是否添加成功
|
||||
@@ -340,7 +355,18 @@ class RedisCache:
|
||||
|
||||
try:
|
||||
key = self._make_key("chat_history", chat_id)
|
||||
message = {"role": role, "content": content}
|
||||
import time as _time
|
||||
message = {
|
||||
"role": role or "user",
|
||||
"content": content,
|
||||
}
|
||||
if nickname:
|
||||
message["nickname"] = nickname
|
||||
if sender_wxid:
|
||||
message["wxid"] = sender_wxid
|
||||
if record_id:
|
||||
message["id"] = record_id
|
||||
message["timestamp"] = timestamp or _time.time()
|
||||
self.client.rpush(key, json.dumps(message, ensure_ascii=False))
|
||||
self.client.expire(key, ttl)
|
||||
return True
|
||||
@@ -416,8 +442,17 @@ class RedisCache:
|
||||
logger.error(f"获取群聊历史失败: {group_id}, {e}")
|
||||
return []
|
||||
|
||||
def add_group_message(self, group_id: str, nickname: str, content,
|
||||
record_id: str = None, ttl: int = 86400) -> bool:
|
||||
def add_group_message(
|
||||
self,
|
||||
group_id: str,
|
||||
nickname: str,
|
||||
content,
|
||||
record_id: str = None,
|
||||
*,
|
||||
role: str = "user",
|
||||
sender_wxid: str = None,
|
||||
ttl: int = 86400,
|
||||
) -> bool:
|
||||
"""
|
||||
添加消息到群聊历史
|
||||
|
||||
@@ -426,6 +461,8 @@ class RedisCache:
|
||||
nickname: 发送者昵称
|
||||
content: 消息内容
|
||||
record_id: 可选的记录ID,用于后续更新
|
||||
role: 角色 (user/assistant),默认 user
|
||||
sender_wxid: 可选的发送者 wxid
|
||||
ttl: 过期时间(秒),默认24小时
|
||||
|
||||
Returns:
|
||||
@@ -438,10 +475,13 @@ class RedisCache:
|
||||
import time
|
||||
key = self._make_key("group_history", group_id)
|
||||
message = {
|
||||
"role": role or "user",
|
||||
"nickname": nickname,
|
||||
"content": content,
|
||||
"timestamp": time.time()
|
||||
}
|
||||
if sender_wxid:
|
||||
message["wxid"] = sender_wxid
|
||||
if record_id:
|
||||
message["id"] = record_id
|
||||
|
||||
|
||||
Reference in New Issue
Block a user