feat:优化屎山

This commit is contained in:
2025-12-12 18:35:39 +08:00
parent 6156064f56
commit c1983172af
10 changed files with 1257 additions and 582 deletions

Binary file not shown.

File diff suppressed because it is too large Load Diff

Binary file not shown.

View File

@@ -115,6 +115,19 @@ class AutoReply(PluginBase):
logger.error(f"[AutoReply] 初始化失败: {e}") logger.error(f"[AutoReply] 初始化失败: {e}")
self.config = None self.config = None
async def on_disable(self):
"""插件禁用时调用,清理后台判断任务"""
await super().on_disable()
if self.pending_tasks:
for task in self.pending_tasks.values():
task.cancel()
await asyncio.gather(*self.pending_tasks.values(), return_exceptions=True)
self.pending_tasks.clear()
self.judging.clear()
logger.info("[AutoReply] 已清理后台判断任务")
def _load_bot_info(self): def _load_bot_info(self):
"""加载机器人信息""" """加载机器人信息"""
try: try:
@@ -294,6 +307,8 @@ class AutoReply(PluginBase):
# 直接调用 AIChat 生成回复(基于最新上下文) # 直接调用 AIChat 生成回复(基于最新上下文)
await self._trigger_ai_reply(bot, pending.from_wxid) await self._trigger_ai_reply(bot, pending.from_wxid)
except asyncio.CancelledError:
raise
except Exception as e: except Exception as e:
logger.error(f"[AutoReply] 后台判断异常: {e}") logger.error(f"[AutoReply] 后台判断异常: {e}")
import traceback import traceback
@@ -316,8 +331,7 @@ class AutoReply(PluginBase):
return return
# 获取最新的历史记录作为上下文 # 获取最新的历史记录作为上下文
chat_id = self._normalize_chat_id(from_wxid) recent_context = await self._get_recent_context_for_reply(from_wxid)
recent_context = await self._get_recent_context_for_reply(chat_id)
if not recent_context: if not recent_context:
logger.warning("[AutoReply] 无法获取上下文") logger.warning("[AutoReply] 无法获取上下文")
@@ -342,10 +356,10 @@ class AutoReply(PluginBase):
import traceback import traceback
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
async def _get_recent_context_for_reply(self, chat_id: str) -> str: async def _get_recent_context_for_reply(self, group_id: str) -> str:
"""获取最近的上下文用于生成回复""" """获取最近的上下文用于生成回复"""
try: try:
history = await self._get_history(chat_id) history = await self._get_history(group_id)
if not history: if not history:
return "" return ""
@@ -353,35 +367,10 @@ class AutoReply(PluginBase):
count = self.config.get('context', {}).get('messages_count', 5) count = self.config.get('context', {}).get('messages_count', 5)
recent = history[-count:] if len(history) > count else history recent = history[-count:] if len(history) > count else history
# 构建上下文摘要 # 自动回复触发不再把最后一条用户消息再次发给 AI
context_lines = [] # 避免在上下文里出现“同一句话重复两遍”的错觉。
for record in recent: # AIChat 会读取完整历史 recent_history。
nickname = record.get('nickname', '未知') return "(自动回复触发)请基于最近群聊内容,自然地回复一句,不要复述提示本身。"
content = record.get('content', '')
if isinstance(content, list):
# 多模态内容,提取文本
for item in content:
if item.get('type') == 'text':
content = item.get('text', '')
break
else:
content = '[图片]'
if len(content) > 50:
content = content[:50] + "..."
context_lines.append(f"{nickname}: {content}")
# 返回最后一条消息作为触发内容AIChat 会读取完整历史)
if recent:
last = recent[-1]
last_content = last.get('content', '')
if isinstance(last_content, list):
for item in last_content:
if item.get('type') == 'text':
return item.get('text', '')
return '[图片]'
return last_content
return ""
except Exception as e: except Exception as e:
logger.error(f"[AutoReply] 获取上下文失败: {e}") logger.error(f"[AutoReply] 获取上下文失败: {e}")
@@ -389,12 +378,13 @@ class AutoReply(PluginBase):
async def _judge_with_small_model(self, from_wxid: str, content: str) -> JudgeResult: async def _judge_with_small_model(self, from_wxid: str, content: str) -> JudgeResult:
"""使用小模型判断是否需要回复""" """使用小模型判断是否需要回复"""
chat_id = self._normalize_chat_id(from_wxid) group_id = from_wxid
chat_state = self._get_chat_state(chat_id) state_id = self._normalize_chat_id(group_id)
chat_state = self._get_chat_state(state_id)
# 获取最近消息历史 # 获取最近消息历史
recent_messages = await self._get_recent_messages(chat_id) recent_messages = await self._get_recent_messages(group_id)
last_bot_reply = await self._get_last_bot_reply(chat_id) last_bot_reply = await self._get_last_bot_reply(group_id)
# 构建判断提示词 # 构建判断提示词
reasoning_part = ',\n "reasoning": "简短分析原因(20字内)"' if self.config["judge"]["include_reasoning"] else "" reasoning_part = ',\n "reasoning": "简短分析原因(20字内)"' if self.config["judge"]["include_reasoning"] else ""
@@ -403,7 +393,7 @@ class AutoReply(PluginBase):
## 当前状态 ## 当前状态
- 精力: {chat_state.energy:.1f}/1.0 - 精力: {chat_state.energy:.1f}/1.0
- 上次发言: {self._get_minutes_since_last_reply(chat_id)}分钟前 - 上次发言: {self._get_minutes_since_last_reply(state_id)}分钟前
## 最近对话 ## 最近对话
{recent_messages} {recent_messages}
@@ -531,6 +521,13 @@ class AutoReply(PluginBase):
if not aichat_plugin: if not aichat_plugin:
return [] return []
# 优先使用 AIChat 的统一 ContextStore
if hasattr(aichat_plugin, "store") and aichat_plugin.store:
try:
return await aichat_plugin.store.load_group_history(chat_id)
except Exception as e:
logger.debug(f"[AutoReply] ContextStore 获取历史失败: {e}")
# 优先使用 Redis与 AIChat 保持一致) # 优先使用 Redis与 AIChat 保持一致)
try: try:
from utils.redis_cache import get_cache from utils.redis_cache import get_cache
@@ -548,7 +545,8 @@ class AutoReply(PluginBase):
# 降级到文件存储 # 降级到文件存储
if hasattr(aichat_plugin, 'history_dir') and aichat_plugin.history_dir: if hasattr(aichat_plugin, 'history_dir') and aichat_plugin.history_dir:
history_file = aichat_plugin.history_dir / f"{chat_id}.json" safe_id = (chat_id or "").replace("@", "_").replace(":", "_")
history_file = aichat_plugin.history_dir / f"{safe_id}.json"
if history_file.exists(): if history_file.exists():
with open(history_file, "r", encoding="utf-8") as f: with open(history_file, "r", encoding="utf-8") as f:
return json.load(f) return json.load(f)
@@ -558,10 +556,10 @@ class AutoReply(PluginBase):
return [] return []
async def _get_recent_messages(self, chat_id: str) -> str: async def _get_recent_messages(self, group_id: str) -> str:
"""获取最近消息历史""" """获取最近消息历史(群聊)"""
try: try:
history = await self._get_history(chat_id) history = await self._get_history(group_id)
if not history: if not history:
return "暂无对话历史" return "暂无对话历史"
@@ -572,6 +570,12 @@ class AutoReply(PluginBase):
for record in recent: for record in recent:
nickname = record.get('nickname', '未知') nickname = record.get('nickname', '未知')
content = record.get('content', '') content = record.get('content', '')
if isinstance(content, list):
text_parts = []
for item in content:
if item.get("type") == "text":
text_parts.append(item.get("text", ""))
content = "".join(text_parts).strip() or "[图片]"
# 限制单条消息长度 # 限制单条消息长度
if len(content) > 100: if len(content) > 100:
content = content[:100] + "..." content = content[:100] + "..."
@@ -584,17 +588,23 @@ class AutoReply(PluginBase):
return "暂无对话历史" return "暂无对话历史"
async def _get_last_bot_reply(self, chat_id: str) -> Optional[str]: async def _get_last_bot_reply(self, group_id: str) -> Optional[str]:
"""获取上次机器人回复""" """获取上次机器人回复(群聊)"""
try: try:
history = await self._get_history(chat_id) history = await self._get_history(group_id)
if not history: if not history:
return None return None
# 从后往前查找机器人回复 # 从后往前查找机器人回复
for record in reversed(history): for record in reversed(history):
if record.get('nickname') == self.bot_nickname: if record.get('role') == 'assistant' or record.get('nickname') == self.bot_nickname:
content = record.get('content', '') content = record.get('content', '')
if isinstance(content, list):
text_parts = []
for item in content:
if item.get("type") == "text":
text_parts.append(item.get("text", ""))
content = "".join(text_parts).strip() or "[图片]"
if len(content) > 100: if len(content) > 100:
content = content[:100] + "..." content = content[:100] + "..."
return content return content

View File

@@ -1606,9 +1606,8 @@ class SignInPlugin(PluginBase):
return {"success": True, "message": f"城市注册请求已处理: {city}"} return {"success": True, "message": f"城市注册请求已处理: {city}"}
else: else:
return {"success": False, "message": "未知的工具名称"} return None
except Exception as e: except Exception as e:
logger.error(f"LLM工具执行失败: {e}") logger.error(f"LLM工具执行失败: {e}")
return {"success": False, "message": f"执行失败: {str(e)}"} return {"success": False, "message": f"执行失败: {str(e)}"}

View File

@@ -322,7 +322,7 @@ class WeatherPlugin(PluginBase):
"""执行LLM工具调用供AIChat插件调用""" """执行LLM工具调用供AIChat插件调用"""
try: try:
if tool_name != "query_weather": if tool_name != "query_weather":
return {"success": False, "message": "未知的工具名称"} return None
# 从 arguments 中获取用户信息 # 从 arguments 中获取用户信息
user_wxid = arguments.get("user_wxid", from_wxid) user_wxid = arguments.get("user_wxid", from_wxid)

View File

@@ -345,7 +345,7 @@ class ZImageTurbo(PluginBase):
async def execute_llm_tool(self, tool_name: str, arguments: dict, bot: WechatHookClient, from_wxid: str) -> dict: async def execute_llm_tool(self, tool_name: str, arguments: dict, bot: WechatHookClient, from_wxid: str) -> dict:
"""执行LLM工具调用供AIChat插件调用""" """执行LLM工具调用供AIChat插件调用"""
if tool_name != "generate_image": if tool_name != "generate_image":
return {"success": False, "message": "未知的工具名称"} return None
try: try:
prompt = arguments.get("prompt", "") prompt = arguments.get("prompt", "")

470
utils/context_store.py Normal file
View File

@@ -0,0 +1,470 @@
"""
上下文/存储统一封装
提供统一的会话上下文读写接口:
- 私聊/单人会话 memory优先 Redis降级内存
- 群聊 history优先 Redis降级文件
- 持久记忆 sqlite
AIChat 只需要通过本模块读写消息,不再关心介质细节。
"""
from __future__ import annotations
import asyncio
import json
import sqlite3
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional
from loguru import logger
from utils.redis_cache import get_cache
def _safe_chat_id(chat_id: str) -> str:
return (chat_id or "").replace("@", "_").replace(":", "_")
def _extract_text_from_multimodal(content: Any) -> str:
if isinstance(content, str):
return content
if isinstance(content, list):
parts = []
for item in content:
if isinstance(item, dict) and item.get("type") == "text":
parts.append(item.get("text", ""))
return "".join(parts).strip()
return str(content)
@dataclass
class HistoryRecord:
    """Normalized chat-history record shared by the Redis and file backends."""

    role: str = "user"
    nickname: str = ""
    content: Any = ""
    timestamp: Any = None
    wxid: Optional[str] = None
    id: Optional[str] = None

    @classmethod
    def from_raw(cls, raw: Dict[str, Any]) -> "HistoryRecord":
        """Build a record from a raw dict, accepting both new and legacy key names."""
        # Legacy records used capitalized keys (Content/SenderNickname/...).
        if "content" in raw:
            body = raw.get("content")
        else:
            body = raw.get("Content", "")
        return cls(
            role=raw.get("role") or "user",
            nickname=raw.get("nickname") or raw.get("SenderNickname") or "",
            content=body,
            timestamp=raw.get("timestamp") or raw.get("time") or raw.get("CreateTime"),
            wxid=raw.get("wxid") or raw.get("SenderWxid"),
            id=raw.get("id") or raw.get("msgid"),
        )

    def to_dict(self) -> Dict[str, Any]:
        """Serialize, omitting unset optional fields to keep stored payloads small."""
        payload: Dict[str, Any] = {
            "role": self.role or "user",
            "nickname": self.nickname,
            "content": self.content,
        }
        if self.timestamp is not None:
            payload["timestamp"] = self.timestamp
        if self.wxid:
            payload["wxid"] = self.wxid
        if self.id:
            payload["id"] = self.id
        return payload
class ContextStore:
    """Unified conversation-context storage.

    Routes reads/writes to the best available backend:
    - private-chat "memory": Redis preferred, in-memory dict fallback
    - group "history": Redis preferred, JSON file fallback
    - persistent memories: sqlite

    AIChat reads/writes messages only through this class and never touches
    the underlying medium directly.

    Args:
        config: AIChat config dict.
        history_dir: directory for group-history fallback JSON files.
        memory_fallback: AIChat in-memory dict (private-chat fallback).
        history_locks: per-chat asyncio locks guarding file writes.
        persistent_db_path: sqlite file path for persistent memories.
    """

    def __init__(
        self,
        config: Dict[str, Any],
        history_dir: Optional[Path],
        memory_fallback: Dict[str, List[Dict[str, Any]]],
        history_locks: Dict[str, asyncio.Lock],
        persistent_db_path: Optional[Path],
    ):
        self.config = config or {}
        self.history_dir = history_dir
        self.memory_fallback = memory_fallback
        self.history_locks = history_locks
        self.persistent_db_path = persistent_db_path

    # ------------------ private-chat memory ------------------

    def _use_redis_for_memory(self) -> bool:
        """True when Redis history is enabled in config and the cache is usable."""
        redis_config = self.config.get("redis", {})
        if not redis_config.get("use_redis_history", True):
            return False
        redis_cache = get_cache()
        return bool(redis_cache and redis_cache.enabled)

    def add_private_message(
        self,
        chat_id: str,
        role: str,
        content: Any,
        *,
        image_base64: Optional[str] = None,
        nickname: str = "",
        sender_wxid: Optional[str] = None,
    ) -> None:
        """Append a private-chat message; no-op when [memory] is disabled.

        When ``image_base64`` is given, the stored content becomes a
        multimodal list (text part + image_url part).
        """
        if not self.config.get("memory", {}).get("enabled", False):
            return
        if image_base64:
            message_content = [
                {"type": "text", "text": _extract_text_from_multimodal(content)},
                {"type": "image_url", "image_url": {"url": image_base64}},
            ]
        else:
            message_content = content
        redis_config = self.config.get("redis", {})
        if self._use_redis_for_memory():
            redis_cache = get_cache()
            ttl = redis_config.get("chat_history_ttl", 86400)
            try:
                redis_cache.add_chat_message(
                    chat_id,
                    role,
                    message_content,
                    nickname=nickname,
                    sender_wxid=sender_wxid,
                    ttl=ttl,
                )
                max_messages = self.config.get("memory", {}).get("max_messages", 20)
                redis_cache.trim_chat_history(chat_id, max_messages)
                return
            except Exception as e:
                # Redis failed: fall through to the in-memory fallback below.
                logger.debug(f"[ContextStore] Redis private history 写入失败: {e}")
        if chat_id not in self.memory_fallback:
            self.memory_fallback[chat_id] = []
        self.memory_fallback[chat_id].append({"role": role, "content": message_content})
        max_messages = self.config.get("memory", {}).get("max_messages", 20)
        if len(self.memory_fallback[chat_id]) > max_messages:
            self.memory_fallback[chat_id] = self.memory_fallback[chat_id][-max_messages:]

    def get_private_messages(self, chat_id: str) -> List[Dict[str, Any]]:
        """Return normalized private-chat history; [] when [memory] is disabled."""
        if not self.config.get("memory", {}).get("enabled", False):
            return []
        if self._use_redis_for_memory():
            redis_cache = get_cache()
            max_messages = self.config.get("memory", {}).get("max_messages", 20)
            try:
                history = redis_cache.get_chat_history(chat_id, max_messages)
                return [HistoryRecord.from_raw(h).to_dict() for h in history]
            except Exception as e:
                logger.debug(f"[ContextStore] Redis private history 读取失败: {e}")
        return self.memory_fallback.get(chat_id, [])

    def clear_private_messages(self, chat_id: str) -> None:
        """Drop private-chat history from Redis (best effort) and from memory."""
        if self._use_redis_for_memory():
            redis_cache = get_cache()
            try:
                redis_cache.clear_chat_history(chat_id)
            except Exception:
                # Best effort: a stale Redis copy is acceptable here.
                pass
        self.memory_fallback.pop(chat_id, None)

    # ------------------ group history ------------------

    def _use_redis_for_group_history(self) -> bool:
        """True when Redis history is enabled in config and the cache is usable."""
        redis_config = self.config.get("redis", {})
        if not redis_config.get("use_redis_history", True):
            return False
        redis_cache = get_cache()
        return bool(redis_cache and redis_cache.enabled)

    def _get_history_file(self, chat_id: str) -> Optional[Path]:
        """Fallback JSON file for *chat_id*, or None when no history_dir is set."""
        if not self.history_dir:
            return None
        return self.history_dir / f"{_safe_chat_id(chat_id)}.json"

    def _get_history_lock(self, chat_id: str) -> asyncio.Lock:
        """Return (lazily creating) the per-chat lock guarding file access."""
        lock = self.history_locks.get(chat_id)
        if lock is None:
            lock = asyncio.Lock()
            self.history_locks[chat_id] = lock
        return lock

    def _read_history_file(self, history_file: Path) -> List[Dict[str, Any]]:
        """Load a history JSON file; missing or unreadable files yield []."""
        try:
            with open(history_file, "r", encoding="utf-8") as f:
                return json.load(f)
        except FileNotFoundError:
            return []
        except Exception as e:
            logger.error(f"读取历史记录失败: {history_file}, {e}")
            return []

    def _write_history_file(self, history_file: Path, history: List[Dict[str, Any]]) -> None:
        """Atomically replace the history file via a temp-file rename."""
        history_file.parent.mkdir(parents=True, exist_ok=True)
        temp_file = Path(str(history_file) + ".tmp")
        with open(temp_file, "w", encoding="utf-8") as f:
            json.dump(history, f, ensure_ascii=False, indent=2)
        temp_file.replace(history_file)

    async def load_group_history(self, chat_id: str) -> List[Dict[str, Any]]:
        """Return normalized group history; [] when [history] is disabled."""
        if not self.config.get("history", {}).get("enabled", True):
            return []
        if self._use_redis_for_group_history():
            redis_cache = get_cache()
            max_history = self.config.get("history", {}).get("max_history", 100)
            try:
                history = redis_cache.get_group_history(chat_id, max_history)
                return [HistoryRecord.from_raw(h).to_dict() for h in history]
            except Exception as e:
                logger.debug(f"[ContextStore] Redis group history 读取失败: {e}")
        history_file = self._get_history_file(chat_id)
        if not history_file:
            return []
        lock = self._get_history_lock(chat_id)
        async with lock:
            raw_history = self._read_history_file(history_file)
            return [HistoryRecord.from_raw(h).to_dict() for h in raw_history]

    async def add_group_message(
        self,
        chat_id: str,
        nickname: str,
        content: Any,
        *,
        record_id: Optional[str] = None,
        image_base64: Optional[str] = None,
        role: str = "user",
        sender_wxid: Optional[str] = None,
    ) -> None:
        """Append a group message (Redis preferred, file fallback); no-op when disabled."""
        if not self.config.get("history", {}).get("enabled", True):
            return
        if image_base64:
            message_content = [
                {"type": "text", "text": _extract_text_from_multimodal(content)},
                {"type": "image_url", "image_url": {"url": image_base64}},
            ]
        else:
            message_content = content
        if self._use_redis_for_group_history():
            redis_cache = get_cache()
            redis_config = self.config.get("redis", {})
            ttl = redis_config.get("group_history_ttl", 172800)
            try:
                redis_cache.add_group_message(
                    chat_id,
                    nickname,
                    message_content,
                    record_id=record_id,
                    role=role,
                    sender_wxid=sender_wxid,
                    ttl=ttl,
                )
                max_history = self.config.get("history", {}).get("max_history", 100)
                redis_cache.trim_group_history(chat_id, max_history)
                return
            except Exception as e:
                # Redis failed: fall through to the file fallback below.
                logger.debug(f"[ContextStore] Redis group history 写入失败: {e}")
        history_file = self._get_history_file(chat_id)
        if not history_file:
            return
        lock = self._get_history_lock(chat_id)
        async with lock:
            history = self._read_history_file(history_file)
            record = HistoryRecord(
                role=role or "user",
                nickname=nickname,
                content=message_content,
                timestamp=datetime.now().isoformat(),
                wxid=sender_wxid,
                id=record_id,
            )
            history.append(record.to_dict())
            max_history = self.config.get("history", {}).get("max_history", 100)
            if len(history) > max_history:
                history = history[-max_history:]
            self._write_history_file(history_file, history)

    async def update_group_message_by_id(self, chat_id: str, record_id: str, new_content: Any) -> None:
        """Replace the content of the first group record whose id matches *record_id*."""
        if not self.config.get("history", {}).get("enabled", True):
            return
        if self._use_redis_for_group_history():
            redis_cache = get_cache()
            try:
                redis_cache.update_group_message_by_id(chat_id, record_id, new_content)
                return
            except Exception as e:
                logger.debug(f"[ContextStore] Redis group history 更新失败: {e}")
        history_file = self._get_history_file(chat_id)
        if not history_file:
            return
        lock = self._get_history_lock(chat_id)
        async with lock:
            history = self._read_history_file(history_file)
            for rec in history:
                if rec.get("id") == record_id:
                    rec["content"] = new_content
                    break
            max_history = self.config.get("history", {}).get("max_history", 100)
            if len(history) > max_history:
                history = history[-max_history:]
            # Rewrite even when no record matched, preserving the trim invariant.
            self._write_history_file(history_file, history)

    # ------------------ persistent memories (sqlite) ------------------

    def init_persistent_memory_db(self) -> Optional[Path]:
        """Create the sqlite schema if needed; return the db path (None if disabled)."""
        if not self.persistent_db_path:
            return None
        self.persistent_db_path.parent.mkdir(exist_ok=True, parents=True)
        conn = sqlite3.connect(self.persistent_db_path)
        try:
            cursor = conn.cursor()
            cursor.execute(
                """
                CREATE TABLE IF NOT EXISTS memories (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    chat_id TEXT NOT NULL,
                    chat_type TEXT NOT NULL,
                    user_wxid TEXT NOT NULL,
                    user_nickname TEXT,
                    content TEXT NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
                """
            )
            cursor.execute("CREATE INDEX IF NOT EXISTS idx_chat_id ON memories(chat_id)")
            conn.commit()
        finally:
            # Always release the connection, even if DDL fails.
            conn.close()
        logger.info(f"持久记忆数据库已初始化: {self.persistent_db_path}")
        return self.persistent_db_path

    def add_persistent_memory(
        self,
        chat_id: str,
        chat_type: str,
        user_wxid: str,
        user_nickname: str,
        content: str,
    ) -> int:
        """Insert a memory row; return its row id, or -1 when sqlite is disabled."""
        if not self.persistent_db_path:
            return -1
        conn = sqlite3.connect(self.persistent_db_path)
        try:
            cursor = conn.cursor()
            cursor.execute(
                """
                INSERT INTO memories (chat_id, chat_type, user_wxid, user_nickname, content)
                VALUES (?, ?, ?, ?, ?)
                """,
                (chat_id, chat_type, user_wxid, user_nickname, content),
            )
            memory_id = cursor.lastrowid
            conn.commit()
        finally:
            conn.close()
        return memory_id

    def get_persistent_memories(self, chat_id: str) -> List[Dict[str, Any]]:
        """Return all memories for *chat_id*, oldest first."""
        if not self.persistent_db_path:
            return []
        conn = sqlite3.connect(self.persistent_db_path)
        try:
            cursor = conn.cursor()
            cursor.execute(
                """
                SELECT id, user_nickname, content, created_at
                FROM memories
                WHERE chat_id = ?
                ORDER BY created_at ASC
                """,
                (chat_id,),
            )
            rows = cursor.fetchall()
        finally:
            conn.close()
        return [
            {"id": r[0], "nickname": r[1], "content": r[2], "time": r[3]}
            for r in rows
        ]

    def delete_persistent_memory(self, chat_id: str, memory_id: int) -> bool:
        """Delete one memory scoped to *chat_id*; True when a row was removed."""
        if not self.persistent_db_path:
            return False
        conn = sqlite3.connect(self.persistent_db_path)
        try:
            cursor = conn.cursor()
            cursor.execute(
                "DELETE FROM memories WHERE id = ? AND chat_id = ?",
                (memory_id, chat_id),
            )
            deleted = cursor.rowcount > 0
            conn.commit()
        finally:
            conn.close()
        return deleted

    def clear_persistent_memories(self, chat_id: str) -> int:
        """Delete all memories for *chat_id*; return the number of rows removed."""
        if not self.persistent_db_path:
            return 0
        conn = sqlite3.connect(self.persistent_db_path)
        try:
            cursor = conn.cursor()
            cursor.execute("DELETE FROM memories WHERE chat_id = ?", (chat_id,))
            deleted_count = cursor.rowcount
            conn.commit()
        finally:
            conn.close()
        return deleted_count

    # ------------------ legacy-data scan/cleanup ------------------

    def find_legacy_group_history_keys(self) -> List[str]:
        """Discover group_history keys written by the old safe_id scheme.

        Returns:
            The legacy key list (nothing is deleted here).
        """
        redis_cache = get_cache()
        if not redis_cache or not redis_cache.enabled:
            return []
        try:
            # NOTE(review): assumes the Redis client decodes responses to str;
            # with bytes keys the membership checks below never match — confirm.
            keys = redis_cache.client.keys("group_history:*")
            legacy = []
            for k in keys or []:
                # New keys normally contain @chatroom; legacy safe_ids replaced '@'.
                if "@chatroom" not in k and "_" in k:
                    legacy.append(k)
            return legacy
        except Exception as e:
            logger.debug(f"[ContextStore] 扫描 legacy group_history keys 失败: {e}")
            return []

    def delete_legacy_group_history_keys(self, legacy_keys: List[str]) -> int:
        """Delete the given legacy keys; return the count removed (0 on failure)."""
        redis_cache = get_cache()
        if not redis_cache or not redis_cache.enabled or not legacy_keys:
            return 0
        try:
            return redis_cache.client.delete(*legacy_keys)
        except Exception as e:
            logger.debug(f"[ContextStore] 删除 legacy group_history keys 失败: {e}")
            return 0

183
utils/llm_tooling.py Normal file
View File

@@ -0,0 +1,183 @@
"""
LLM 工具体系公共模块
统一工具收集、参数校验与执行结果结构,供 AIChat 等插件使用。
"""
from __future__ import annotations
import json
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
from loguru import logger
@dataclass
class ToolResult:
    """Normalized result of an LLM tool invocation."""

    success: bool = True
    message: str = ""
    need_ai_reply: bool = False
    already_sent: bool = False
    send_result_text: bool = False
    no_reply: bool = False
    save_to_memory: bool = False

    @classmethod
    def from_raw(cls, raw: Any) -> Optional["ToolResult"]:
        """Coerce a plugin's raw return value into a ToolResult.

        None stays None; non-dict values become a successful result whose
        message is the stringified value; dict messages are JSON-encoded.
        """
        if raw is None:
            return None
        if not isinstance(raw, dict):
            return cls(success=True, message=str(raw))
        message = raw.get("message", "")
        if not isinstance(message, str):
            try:
                message = json.dumps(message, ensure_ascii=False)
            except Exception:
                message = str(message)
        flags = {
            name: bool(raw.get(name, False))
            for name in (
                "need_ai_reply",
                "already_sent",
                "send_result_text",
                "no_reply",
                "save_to_memory",
            )
        }
        return cls(success=bool(raw.get("success", True)), message=message, **flags)
def collect_tools_with_plugins(
    tools_config: Dict[str, Any],
    plugins: Dict[str, Any],
) -> Dict[str, Tuple[str, Dict[str, Any]]]:
    """Collect LLM tools from every plugin, remembering which plugin owns each.

    Args:
        tools_config: the [tools] section of the AIChat config.
        plugins: PluginManager().plugins mapping.

    Returns:
        {tool_name: (plugin_name, tool_dict)}
    """
    mode = tools_config.get("mode", "all")
    whitelist = set(tools_config.get("whitelist", []))
    blacklist = set(tools_config.get("blacklist", []))
    collected: Dict[str, Tuple[str, Dict[str, Any]]] = {}
    for plugin_name, plugin in plugins.items():
        # Plugins without get_llm_tools simply expose no tools.
        if not hasattr(plugin, "get_llm_tools"):
            continue
        for tool in plugin.get_llm_tools() or []:
            tool_name = tool.get("function", {}).get("name", "")
            if not tool_name:
                continue
            if mode == "whitelist" and tool_name not in whitelist:
                continue
            if mode == "blacklist" and tool_name in blacklist:
                logger.debug(f"[黑名单] 禁用工具: {tool_name}")
                continue
            # First plugin to register a name wins; later duplicates are dropped.
            if tool_name in collected:
                logger.warning(f"重复工具名 {tool_name} 来自 {plugin_name},已忽略")
                continue
            collected[tool_name] = (plugin_name, tool)
            if mode == "whitelist":
                logger.debug(f"[白名单] 启用工具: {tool_name}")
    return collected
def collect_tools(
    tools_config: Dict[str, Any],
    plugins: Dict[str, Any],
) -> List[Dict[str, Any]]:
    """Return only the tool definitions, discarding the owning-plugin names."""
    tools_map = collect_tools_with_plugins(tools_config, plugins)
    return [tool for _plugin_name, tool in tools_map.values()]
def get_tool_schema_map(
    tools_map: Dict[str, Tuple[str, Dict[str, Any]]],
) -> Dict[str, Dict[str, Any]]:
    """Map each tool name to its parameter JSON schema ({} when absent)."""
    return {
        name: (tool.get("function", {}).get("parameters", {}) or {})
        for name, (_plugin_name, tool) in tools_map.items()
    }
def validate_tool_arguments(
    tool_name: str,
    arguments: Dict[str, Any],
    schema: Optional[Dict[str, Any]],
) -> Tuple[bool, str, Dict[str, Any]]:
    """Lightweight validation with default filling and basic type coercion.

    Fix over the original: the caller's ``arguments`` dict is no longer
    mutated in place (previously defaults/coercions leaked into it even on a
    failed validation); a patched copy is returned instead.

    Args:
        tool_name: tool name (kept for interface parity; not used in checks).
        arguments: raw arguments from the LLM tool call.
        schema: JSON-schema-like dict with "properties"/"required", or None.

    Returns:
        (ok, error_message, new_arguments)
    """
    if not schema:
        return True, "", arguments
    # Work on a copy so a failed validation leaves the caller's dict untouched.
    args = dict(arguments)
    props = schema.get("properties", {}) or {}
    required = schema.get("required", []) or []
    # Fill declared defaults for missing keys.
    for key, prop in props.items():
        if key not in args and isinstance(prop, dict) and "default" in prop:
            args[key] = prop["default"]
    missing = [key for key in required if key not in args or args[key] in (None, "", [])]
    if missing:
        return False, f"缺少参数: {', '.join(missing)}", args
    # Enum membership and basic type checks/coercion.
    for key, prop in props.items():
        if key not in args or not isinstance(prop, dict):
            continue
        value = args[key]
        if "enum" in prop and value not in prop["enum"]:
            return False, f"参数 {key} 必须是 {prop['enum']}", args
        expected_type = prop.get("type")
        if expected_type == "integer":
            try:
                args[key] = int(value)
            except Exception:
                return False, f"参数 {key} 应为整数", args
        elif expected_type == "number":
            try:
                args[key] = float(value)
            except Exception:
                return False, f"参数 {key} 应为数字", args
        elif expected_type == "boolean":
            if isinstance(value, bool):
                continue
            if isinstance(value, str) and value.lower() in ("true", "false", "1", "0"):
                args[key] = value.lower() in ("true", "1")
            else:
                return False, f"参数 {key} 应为布尔值", args
        elif expected_type == "string":
            if not isinstance(value, str):
                args[key] = str(value)
    return True, "", args

View File

@@ -322,7 +322,18 @@ class RedisCache:
logger.error(f"获取对话历史失败: {chat_id}, {e}") logger.error(f"获取对话历史失败: {chat_id}, {e}")
return [] return []
def add_chat_message(self, chat_id: str, role: str, content, ttl: int = 86400) -> bool: def add_chat_message(
self,
chat_id: str,
role: str,
content,
ttl: int = 86400,
*,
nickname: str = None,
sender_wxid: str = None,
record_id: str = None,
timestamp: float = None,
) -> bool:
""" """
添加消息到对话历史 添加消息到对话历史
@@ -331,6 +342,10 @@ class RedisCache:
role: 角色 (user/assistant) role: 角色 (user/assistant)
content: 消息内容(字符串或列表) content: 消息内容(字符串或列表)
ttl: 过期时间默认24小时 ttl: 过期时间默认24小时
nickname: 可选昵称(用于统一 schema
sender_wxid: 可选发送者 wxid
record_id: 可选记录 ID
timestamp: 可选时间戳
Returns: Returns:
是否添加成功 是否添加成功
@@ -340,7 +355,18 @@ class RedisCache:
try: try:
key = self._make_key("chat_history", chat_id) key = self._make_key("chat_history", chat_id)
message = {"role": role, "content": content} import time as _time
message = {
"role": role or "user",
"content": content,
}
if nickname:
message["nickname"] = nickname
if sender_wxid:
message["wxid"] = sender_wxid
if record_id:
message["id"] = record_id
message["timestamp"] = timestamp or _time.time()
self.client.rpush(key, json.dumps(message, ensure_ascii=False)) self.client.rpush(key, json.dumps(message, ensure_ascii=False))
self.client.expire(key, ttl) self.client.expire(key, ttl)
return True return True
@@ -416,8 +442,17 @@ class RedisCache:
logger.error(f"获取群聊历史失败: {group_id}, {e}") logger.error(f"获取群聊历史失败: {group_id}, {e}")
return [] return []
def add_group_message(self, group_id: str, nickname: str, content, def add_group_message(
record_id: str = None, ttl: int = 86400) -> bool: self,
group_id: str,
nickname: str,
content,
record_id: str = None,
*,
role: str = "user",
sender_wxid: str = None,
ttl: int = 86400,
) -> bool:
""" """
添加消息到群聊历史 添加消息到群聊历史
@@ -426,6 +461,8 @@ class RedisCache:
nickname: 发送者昵称 nickname: 发送者昵称
content: 消息内容 content: 消息内容
record_id: 可选的记录ID用于后续更新 record_id: 可选的记录ID用于后续更新
role: 角色 (user/assistant),默认 user
sender_wxid: 可选的发送者 wxid
ttl: 过期时间默认24小时 ttl: 过期时间默认24小时
Returns: Returns:
@@ -438,10 +475,13 @@ class RedisCache:
import time import time
key = self._make_key("group_history", group_id) key = self._make_key("group_history", group_id)
message = { message = {
"role": role or "user",
"nickname": nickname, "nickname": nickname,
"content": content, "content": content,
"timestamp": time.time() "timestamp": time.time()
} }
if sender_wxid:
message["wxid"] = sender_wxid
if record_id: if record_id:
message["id"] = record_id message["id"] = record_id