"""
AutoReply plugin - smart auto-reply built on a dual-LLM architecture.

A small "judge" model scores incoming group messages; when the weighted
score passes the configured threshold, the AIChat plugin is triggered to
generate the actual reply.
"""
|
||
|
||
import json
|
||
import time
|
||
import tomllib
|
||
import aiohttp
|
||
from pathlib import Path
|
||
from datetime import datetime, date
|
||
from dataclasses import dataclass
|
||
from typing import Dict
|
||
from loguru import logger
|
||
from utils.plugin_base import PluginBase
|
||
from utils.decorators import on_text_message, schedule
|
||
|
||
# Optional SOCKS/HTTP proxy support. When aiohttp_socks is not installed,
# PROXY_SUPPORT stays False and proxying is skipped at request time.
try:
    from aiohttp_socks import ProxyConnector
    PROXY_SUPPORT = True
except ImportError:
    PROXY_SUPPORT = False
|
||
|
||
|
||
@dataclass
class JudgeResult:
    """Result of the small-model "should we reply?" judgement.

    The five dimension scores are on a 0-10 scale as returned by the judge
    model; ``overall_score`` is their weighted average divided by 10, i.e.
    normalised to the 0.0-1.0 range compared against the reply threshold.
    """
    relevance: float = 0.0      # content relevance score (0-10)
    willingness: float = 0.0    # reply-willingness score (0-10)
    social: float = 0.0         # social appropriateness score (0-10)
    timing: float = 0.0         # timing score (0-10)
    continuity: float = 0.0     # conversation continuity score (0-10)
    reasoning: str = ""         # optional free-text explanation from the model
    should_reply: bool = False  # final decision: overall_score >= threshold
    overall_score: float = 0.0  # weighted composite score, 0.0-1.0
|
||
|
||
|
||
@dataclass
class ChatState:
    """Per-group-chat state tracked across messages."""
    energy: float = 1.0           # reply "energy" budget, kept within [0.1, 1.0]
    last_reply_time: float = 0.0  # epoch seconds of the bot's last reply (0 = never)
    last_reset_date: str = ""     # ISO date of the last daily energy top-up
    total_messages: int = 0       # messages observed in this chat
    total_replies: int = 0        # replies the bot made in this chat
|
||
|
||
|
||
class AutoReply(PluginBase):
    """Smart auto-reply plugin built on a dual-LLM architecture.

    A small judge model scores each candidate group message; when the
    weighted score passes the configured threshold, the message is flagged
    so the AIChat plugin generates the actual reply.
    """

    description = "基于双LLM架构的智能自动回复插件"
    author = "ShiHao"
    version = "1.0.0"

    def __init__(self):
        super().__init__()
        self.config = None
        self.chat_states: Dict[str, ChatState] = {}
        self.weights = {}
        self.last_judge_time: Dict[str, float] = {}  # last judgement timestamp per chat
        self.judging: Dict[str, bool] = {}  # whether a judgement is currently in flight per chat
        self.last_history_size: Dict[str, int] = {}  # last observed history length per chat
        self.pending_judge: Dict[str, bool] = {}  # whether a chat has a pending judgement
        self.whitelist_normalized = set()  # normalised whitelist IDs (matching history file names)
|
||
|
||
    async def async_init(self):
        """Load config.toml, normalise the judge weights and pre-compute the whitelist."""
        config_path = Path(__file__).parent / "config.toml"
        with open(config_path, "rb") as f:
            self.config = tomllib.load(f)

        # Load the per-dimension judgement weights.
        self.weights = {
            "relevance": self.config["weights"]["relevance"],
            "willingness": self.config["weights"]["willingness"],
            "social": self.config["weights"]["social"],
            "timing": self.config["weights"]["timing"],
            "continuity": self.config["weights"]["continuity"]
        }

        # The weights must sum to 1; renormalise (and warn) when they do not.
        weight_sum = sum(self.weights.values())
        if abs(weight_sum - 1.0) > 1e-6:
            logger.warning(f"判断权重和不为1,当前和为{weight_sum},已自动归一化")
            self.weights = {k: v / weight_sum for k, v in self.weights.items()}

        # Pre-normalise whitelist IDs so they match history file names.
        self.whitelist_normalized = {
            self._normalize_chat_id(cid) for cid in self.config.get("whitelist", {}).get("chat_list", [])
        }

        logger.info(f"AutoReply 插件已加载,判断模型: {self.config['basic']['judge_model']}")
        logger.info(f"AutoReply 配置: enabled={self.config['basic']['enabled']}, priority=90")
        logger.info(f"AutoReply 监听模式: 每{self.config.get('rate_limit', {}).get('check_interval', 5)}秒检查history变化")
        logger.warning("⚠️ AutoReply插件已启动,等待消息...")
|
||
|
||
def _normalize_chat_id(self, chat_id: str) -> str:
|
||
"""将群ID转成history文件使用的安全文件名"""
|
||
return (chat_id or "").replace("@", "_").replace(":", "_")
|
||
|
||
def _is_chat_allowed(self, raw_chat_id: str) -> bool:
|
||
"""白名单判断,兼容原始ID与归一化ID"""
|
||
if not self.config["whitelist"]["enabled"]:
|
||
return True
|
||
safe_id = self._normalize_chat_id(raw_chat_id)
|
||
return raw_chat_id in self.config["whitelist"]["chat_list"] or safe_id in self.whitelist_normalized
|
||
|
||
@schedule('interval', seconds=5)
|
||
async def check_history_changes(self, *args, **kwargs):
|
||
"""定时检查history文件变化"""
|
||
if not self.config["basic"]["enabled"]:
|
||
logger.debug("[AutoReply] 插件未启用,跳过检查")
|
||
return
|
||
|
||
# 检查是否启用监听模式
|
||
if not self.config.get("rate_limit", {}).get("monitor_mode", True):
|
||
logger.debug("[AutoReply] 监听模式未启用,跳过检查")
|
||
return
|
||
|
||
try:
|
||
# 获取AIChat插件的history目录
|
||
from utils.plugin_manager import PluginManager
|
||
plugin_manager = PluginManager() # 单例模式,直接实例化
|
||
aichat_plugin = plugin_manager.plugins.get("AIChat")
|
||
|
||
if not aichat_plugin:
|
||
logger.debug("[AutoReply] 未找到AIChat插件")
|
||
return
|
||
|
||
if not hasattr(aichat_plugin, 'history_dir'):
|
||
logger.debug("[AutoReply] AIChat插件没有history_dir属性")
|
||
return
|
||
|
||
history_dir = aichat_plugin.history_dir
|
||
if not history_dir.exists():
|
||
logger.debug(f"[AutoReply] History目录不存在: {history_dir}")
|
||
return
|
||
|
||
logger.debug(f"[AutoReply] 开始检查history目录: {history_dir}")
|
||
|
||
# 遍历所有history文件
|
||
for history_file in history_dir.glob("*.json"):
|
||
chat_id = history_file.stem # 文件名就是chat_id
|
||
|
||
# 检查白名单
|
||
if self.config["whitelist"]["enabled"]:
|
||
if chat_id not in self.whitelist_normalized:
|
||
continue
|
||
|
||
try:
|
||
with open(history_file, "r", encoding="utf-8") as f:
|
||
history = json.load(f)
|
||
|
||
current_size = len(history)
|
||
last_size = self.last_history_size.get(chat_id, 0)
|
||
|
||
# 如果有新消息
|
||
if current_size > last_size:
|
||
# 获取新增的消息
|
||
new_messages = history[last_size:]
|
||
|
||
# 检查新消息中是否有非机器人的消息
|
||
with open("main_config.toml", "rb") as f:
|
||
main_config = tomllib.load(f)
|
||
bot_nickname = main_config.get("Bot", {}).get("nickname", "机器人")
|
||
|
||
has_user_message = any(
|
||
msg.get('nickname') != bot_nickname
|
||
for msg in new_messages
|
||
)
|
||
|
||
if has_user_message:
|
||
logger.debug(f"[AutoReply] 检测到群聊 {chat_id[:20]}... 有新消息")
|
||
# 标记为待判断
|
||
self.pending_judge[chat_id] = True
|
||
|
||
# 更新记录的大小
|
||
self.last_history_size[chat_id] = current_size
|
||
|
||
except Exception as e:
|
||
logger.debug(f"读取history文件失败: {history_file.name}, {e}")
|
||
continue
|
||
|
||
except Exception as e:
|
||
logger.error(f"检查history变化失败: {e}")
|
||
|
||
@on_text_message(priority=90) # 高优先级,在AIChat之前执行
|
||
async def handle_message(self, bot, message: dict):
|
||
"""处理消息"""
|
||
try:
|
||
logger.debug(f"[AutoReply] 收到消息,开始处理")
|
||
|
||
# 检查是否启用
|
||
if not self.config["basic"]["enabled"]:
|
||
logger.debug("AutoReply插件未启用,跳过处理")
|
||
return True
|
||
|
||
# 只处理群聊消息
|
||
is_group = message.get('IsGroup', False)
|
||
if not is_group:
|
||
logger.debug("AutoReply只处理群聊消息,跳过私聊")
|
||
return True
|
||
|
||
# 群聊消息:FromWxid是群ID,SenderWxid是发送者ID
|
||
from_wxid = message.get('FromWxid') # 群聊ID
|
||
sender_wxid = message.get('SenderWxid') # 发送者ID
|
||
chat_id = self._normalize_chat_id(from_wxid) # 归一化ID,匹配history文件名
|
||
content = (message.get('msg') or message.get('Content', '')).strip()
|
||
|
||
# 跳过空消息
|
||
if not content:
|
||
logger.debug("AutoReply跳过空消息")
|
||
return True
|
||
|
||
# 检查白名单(使用from_wxid作为群聊ID)
|
||
if not self._is_chat_allowed(from_wxid):
|
||
logger.debug(f"AutoReply白名单模式,群聊 {from_wxid[:20]}... 不在白名单中")
|
||
return True
|
||
|
||
# 跳过已被@的消息(让AIChat正常处理)
|
||
if self._is_at_bot(message):
|
||
logger.debug("AutoReply跳过@消息,交由AIChat处理")
|
||
return True
|
||
|
||
# 监听模式:只在检测到待判断标记时才判断
|
||
monitor_mode = self.config.get("rate_limit", {}).get("monitor_mode", True)
|
||
if monitor_mode:
|
||
if not self.pending_judge.get(chat_id, False):
|
||
logger.debug(f"AutoReply监听模式,群聊 {from_wxid[:20]}... 无待判断标记")
|
||
return True
|
||
# 清除待判断标记
|
||
self.pending_judge[chat_id] = False
|
||
|
||
# 频率限制:检查是否正在判断中
|
||
if self.config.get("rate_limit", {}).get("skip_if_judging", True):
|
||
if self.judging.get(chat_id, False):
|
||
logger.debug(f"AutoReply跳过消息,群聊 {from_wxid[:20]}... 正在判断中")
|
||
return True
|
||
|
||
# 频率限制:检查距离上次判断的时间间隔
|
||
min_interval = self.config.get("rate_limit", {}).get("min_interval", 10)
|
||
last_time = self.last_judge_time.get(chat_id, 0)
|
||
current_time = time.time()
|
||
if current_time - last_time < min_interval:
|
||
logger.debug(f"AutoReply跳过消息,距离上次判断仅 {current_time - last_time:.1f}秒")
|
||
# 监听模式下,如果时间间隔不够,重新标记为待判断
|
||
if monitor_mode:
|
||
self.pending_judge[chat_id] = True
|
||
return True
|
||
|
||
logger.info(f"AutoReply开始判断消息: {content[:30]}...")
|
||
|
||
# 标记正在判断中
|
||
self.judging[chat_id] = True
|
||
self.last_judge_time[chat_id] = current_time
|
||
|
||
# 使用小模型判断是否需要回复
|
||
judge_result = await self._judge_with_small_model(bot, message)
|
||
|
||
# 清除判断中标记
|
||
self.judging[chat_id] = False
|
||
|
||
if judge_result.should_reply:
|
||
logger.info(f"🔥 AutoReply触发 | {from_wxid[:20]}... | 评分:{judge_result.overall_score:.2f} | {judge_result.reasoning[:50]}")
|
||
|
||
# 更新状态
|
||
self._update_active_state(chat_id, judge_result)
|
||
|
||
# 修改消息,让AIChat认为需要回复
|
||
message['_auto_reply_triggered'] = True
|
||
|
||
return True # 继续传递给AIChat
|
||
else:
|
||
logger.debug(f"AutoReply不触发 | {from_wxid[:20]}... | 评分:{judge_result.overall_score:.2f}")
|
||
self._update_passive_state(chat_id, judge_result)
|
||
return True
|
||
|
||
except Exception as e:
|
||
logger.error(f"AutoReply处理异常: {e}")
|
||
import traceback
|
||
logger.error(traceback.format_exc())
|
||
# 异常时也要清除判断中标记
|
||
if 'chat_id' in locals():
|
||
self.judging[chat_id] = False
|
||
elif 'from_wxid' in locals():
|
||
self.judging[self._normalize_chat_id(from_wxid)] = False
|
||
return True
|
||
|
||
def _is_at_bot(self, message: dict) -> bool:
|
||
"""检查是否@了机器人"""
|
||
content = message.get('Content', '')
|
||
# 规范化后的消息使用 Ats 字段
|
||
at_list = message.get('Ats', [])
|
||
# 检查是否有@列表或内容中包含@标记
|
||
return len(at_list) > 0 or '@' in content or '@' in content
|
||
|
||
    async def _judge_with_small_model(self, bot, message: dict) -> JudgeResult:
        """Ask the small judge model whether the bot should reply to *message*.

        Builds a scoring prompt from the chat state and recent history, calls
        the judge API (retrying on malformed JSON), and converts the returned
        per-dimension scores into a weighted :class:`JudgeResult`. Never
        raises: all failures produce ``should_reply=False``.
        """
        # Normalised message: FromWxid is the room ID, SenderWxid the sender,
        # Content the text body.
        from_wxid = message.get('FromWxid')  # group chat ID
        chat_id = self._normalize_chat_id(from_wxid)
        content = message.get('Content', '')
        sender_wxid = message.get('SenderWxid', '')  # NOTE(review): currently unused

        # Current per-chat state (energy, last reply time, ...).
        chat_state = self._get_chat_state(chat_id)

        # Recent conversation context for the prompt.
        recent_messages = await self._get_recent_messages(chat_id)
        last_bot_reply = await self._get_last_bot_reply(chat_id)

        # Optionally ask the model to include a "reasoning" field in its JSON.
        reasoning_part = ""
        if self.config["judge"]["include_reasoning"]:
            reasoning_part = ',\n "reasoning": "详细分析原因"'

        judge_prompt = f"""你是群聊机器人的决策系统,判断是否应该主动回复。

## 当前群聊情况
- 群聊ID: {from_wxid}
- 精力水平: {chat_state.energy:.1f}/1.0
- 上次发言: {self._get_minutes_since_last_reply(chat_id)}分钟前

## 最近{self.config['context']['messages_count']}条对话
{recent_messages}

## 上次机器人回复
{last_bot_reply if last_bot_reply else "暂无"}

## 待判断消息
内容: {content}
时间: {datetime.now().strftime('%H:%M:%S')}

## 评估要求
从以下5个维度评估(0-10分):
1. **内容相关度**(0-10):消息是否有趣、有价值、适合回复
2. **回复意愿**(0-10):基于当前精力水平的回复意愿
3. **社交适宜性**(0-10):在当前群聊氛围下回复是否合适
4. **时机恰当性**(0-10):回复时机是否恰当
5. **对话连贯性**(0-10):当前消息与上次回复的关联程度

**回复阈值**: {self.config['basic']['reply_threshold']}

请以JSON格式回复:
{{
"relevance": 分数,
"willingness": 分数,
"social": 分数,
"timing": 分数,
"continuity": 分数{reasoning_part}
}}

**注意:你的回复必须是完整的JSON对象,不要包含任何其他内容!**"""

        # Call the judge API, retrying on malformed JSON responses.
        max_retries = self.config["judge"]["max_retries"] + 1
        for attempt in range(max_retries):
            try:
                result = await self._call_judge_api(judge_prompt)

                # Strip Markdown code fences the model may wrap around the JSON.
                content_text = result.strip()
                if content_text.startswith("```json"):
                    content_text = content_text.replace("```json", "").replace("```", "").strip()
                elif content_text.startswith("```"):
                    content_text = content_text.replace("```", "").strip()

                judge_data = json.loads(content_text)

                # Weighted composite score, scaled from 0-10 down to 0-1.
                overall_score = (
                    judge_data.get("relevance", 0) * self.weights["relevance"] +
                    judge_data.get("willingness", 0) * self.weights["willingness"] +
                    judge_data.get("social", 0) * self.weights["social"] +
                    judge_data.get("timing", 0) * self.weights["timing"] +
                    judge_data.get("continuity", 0) * self.weights["continuity"]
                ) / 10.0

                should_reply = overall_score >= self.config["basic"]["reply_threshold"]

                return JudgeResult(
                    relevance=judge_data.get("relevance", 0),
                    willingness=judge_data.get("willingness", 0),
                    social=judge_data.get("social", 0),
                    timing=judge_data.get("timing", 0),
                    continuity=judge_data.get("continuity", 0),
                    reasoning=judge_data.get("reasoning", "") if self.config["judge"]["include_reasoning"] else "",
                    should_reply=should_reply,
                    overall_score=overall_score
                )

            except json.JSONDecodeError as e:
                # Malformed JSON: retry until attempts are exhausted.
                logger.warning(f"小模型返回JSON解析失败 (尝试 {attempt + 1}/{max_retries}): {str(e)}")
                if attempt == max_retries - 1:
                    return JudgeResult(should_reply=False, reasoning="JSON解析失败")
                continue
            except Exception as e:
                # Any other failure (network, API) aborts immediately.
                logger.error(f"小模型判断异常: {e}")
                return JudgeResult(should_reply=False, reasoning=f"异常: {str(e)}")

        return JudgeResult(should_reply=False, reasoning="重试失败")
|
||
|
||
async def _call_judge_api(self, prompt: str) -> str:
|
||
"""调用判断模型API"""
|
||
api_url = self.config["basic"]["judge_api_url"]
|
||
api_key = self.config["basic"]["judge_api_key"]
|
||
model = self.config["basic"]["judge_model"]
|
||
|
||
headers = {
|
||
"Content-Type": "application/json",
|
||
"Authorization": f"Bearer {api_key}"
|
||
}
|
||
|
||
payload = {
|
||
"model": model,
|
||
"messages": [
|
||
{"role": "system", "content": "你是一个专业的群聊回复决策系统。你必须严格按照JSON格式返回结果。"},
|
||
{"role": "user", "content": prompt}
|
||
],
|
||
"temperature": 0.7
|
||
}
|
||
|
||
# 配置代理
|
||
connector = None
|
||
if self.config["proxy"]["enabled"] and PROXY_SUPPORT:
|
||
proxy_type = self.config["proxy"]["type"]
|
||
proxy_host = self.config["proxy"]["host"]
|
||
proxy_port = self.config["proxy"]["port"]
|
||
proxy_url = f"{proxy_type}://{proxy_host}:{proxy_port}"
|
||
connector = ProxyConnector.from_url(proxy_url)
|
||
|
||
async with aiohttp.ClientSession(connector=connector) as session:
|
||
async with session.post(api_url, headers=headers, json=payload, timeout=aiohttp.ClientTimeout(total=30)) as response:
|
||
if response.status != 200:
|
||
raise Exception(f"API调用失败: {response.status}")
|
||
|
||
result = await response.json()
|
||
return result["choices"][0]["message"]["content"]
|
||
|
||
async def _get_recent_messages(self, chat_id: str) -> str:
|
||
"""获取最近消息历史"""
|
||
try:
|
||
# 尝试从AIChat插件获取历史记录
|
||
from utils.plugin_manager import PluginManager
|
||
plugin_manager = PluginManager() # 单例模式,直接实例化
|
||
aichat_plugin = plugin_manager.plugins.get("AIChat")
|
||
|
||
if aichat_plugin and hasattr(aichat_plugin, 'history_dir'):
|
||
history_file = aichat_plugin.history_dir / f"{chat_id}.json"
|
||
if history_file.exists():
|
||
with open(history_file, "r", encoding="utf-8") as f:
|
||
history = json.load(f)
|
||
|
||
# 获取最近N条
|
||
recent = history[-self.config['context']['messages_count']:]
|
||
messages = []
|
||
for record in recent:
|
||
nickname = record.get('nickname', '未知')
|
||
content = record.get('content', '')
|
||
messages.append(f"{nickname}: {content}")
|
||
|
||
return "\n".join(messages) if messages else "暂无对话历史"
|
||
except Exception as e:
|
||
logger.debug(f"获取消息历史失败: {e}")
|
||
|
||
return "暂无对话历史"
|
||
|
||
    async def _get_last_bot_reply(self, chat_id: str) -> str | None:
        """Return the bot's most recent reply in this chat, or None when absent."""
        try:
            from utils.plugin_manager import PluginManager
            plugin_manager = PluginManager()  # singleton; instantiating returns the shared instance
            aichat_plugin = plugin_manager.plugins.get("AIChat")

            if aichat_plugin and hasattr(aichat_plugin, 'history_dir'):
                history_file = aichat_plugin.history_dir / f"{chat_id}.json"
                if history_file.exists():
                    with open(history_file, "r", encoding="utf-8") as f:
                        history = json.load(f)

                    # The bot's own entries are identified by its configured nickname.
                    with open("main_config.toml", "rb") as f:
                        main_config = tomllib.load(f)
                    bot_nickname = main_config.get("Bot", {}).get("nickname", "机器人")

                    # Scan backwards for the newest bot-authored entry.
                    for record in reversed(history):
                        if record.get('nickname') == bot_nickname:
                            return record.get('content', '')
        except Exception as e:
            logger.debug(f"获取上次回复失败: {e}")

        return None
|
||
|
||
def _get_chat_state(self, chat_id: str) -> ChatState:
|
||
"""获取群聊状态"""
|
||
if chat_id not in self.chat_states:
|
||
self.chat_states[chat_id] = ChatState()
|
||
|
||
today = date.today().isoformat()
|
||
state = self.chat_states[chat_id]
|
||
|
||
if state.last_reset_date != today:
|
||
state.last_reset_date = today
|
||
state.energy = min(1.0, state.energy + 0.2)
|
||
|
||
return state
|
||
|
||
def _get_minutes_since_last_reply(self, chat_id: str) -> int:
|
||
"""获取距离上次回复的分钟数"""
|
||
chat_state = self._get_chat_state(chat_id)
|
||
if chat_state.last_reply_time == 0:
|
||
return 999
|
||
return int((time.time() - chat_state.last_reply_time) / 60)
|
||
|
||
def _update_active_state(self, chat_id: str, judge_result: JudgeResult):
|
||
"""更新主动回复状态"""
|
||
chat_state = self._get_chat_state(chat_id)
|
||
chat_state.last_reply_time = time.time()
|
||
chat_state.total_replies += 1
|
||
chat_state.total_messages += 1
|
||
chat_state.energy = max(0.1, chat_state.energy - self.config["energy"]["decay_rate"])
|
||
|
||
def _update_passive_state(self, chat_id: str, judge_result: JudgeResult):
|
||
"""更新被动状态"""
|
||
chat_state = self._get_chat_state(chat_id)
|
||
chat_state.total_messages += 1
|
||
chat_state.energy = min(1.0, chat_state.energy + self.config["energy"]["recovery_rate"])
|