WechatHookBot/plugins/AutoReply/main.py
"""
AutoReply 插件 - 基于双LLM架构的智能自动回复
使用小模型判断是否需要回复通过后触发AIChat插件生成回复
"""
import json
import time
import asyncio
import tomllib
import aiohttp
from pathlib import Path
from datetime import datetime, date
from dataclasses import dataclass, field
from typing import Dict, Optional
from loguru import logger
from utils.plugin_base import PluginBase
from utils.decorators import on_text_message
try:
    from aiohttp_socks import ProxyConnector
    PROXY_SUPPORT = True
except ImportError:
    PROXY_SUPPORT = False


@dataclass
class JudgeResult:
    """Result of the judge-model evaluation."""
    relevance: float = 0.0
    willingness: float = 0.0
    social: float = 0.0
    timing: float = 0.0
    continuity: float = 0.0
    reasoning: str = ""
    should_reply: bool = False
    overall_score: float = 0.0


@dataclass
class ChatState:
    """Per-group chat state."""
    energy: float = 1.0
    last_reply_time: float = 0.0
    last_reset_date: str = ""
    total_messages: int = 0
    total_replies: int = 0


class AutoReply(PluginBase):
    """Intelligent auto-reply plugin."""

    description = "基于双LLM架构的智能自动回复插件"
    author = "ShiHao"
    version = "1.1.0"

    def __init__(self):
        super().__init__()
        self.config = None
        self.chat_states: Dict[str, ChatState] = {}
        self.weights = {}
        self.last_judge_time: Dict[str, float] = {}
        self.judging: Dict[str, bool] = {}
        self.bot_wxid: str = ""
        self.bot_nickname: str = ""

    async def async_init(self):
        """Asynchronous initialization."""
        try:
            config_path = Path(__file__).parent / "config.toml"
            with open(config_path, "rb") as f:
                self.config = tomllib.load(f)

            # Load judge-dimension weights
            self.weights = {
                "relevance": self.config["weights"]["relevance"],
                "willingness": self.config["weights"]["willingness"],
                "social": self.config["weights"]["social"],
                "timing": self.config["weights"]["timing"],
                "continuity": self.config["weights"]["continuity"]
            }

            # Normalize the weights if they do not sum to 1
            weight_sum = sum(self.weights.values())
            if abs(weight_sum - 1.0) > 1e-6:
                logger.warning(f"[AutoReply] 判断权重和不为1,当前和为{weight_sum},已自动归一化")
                self.weights = {k: v / weight_sum for k, v in self.weights.items()}

            # Load bot identity (wxid / nickname)
            self._load_bot_info()
            logger.success(f"[AutoReply] 插件已加载,判断模型: {self.config['basic']['judge_model']}")
            logger.info(f"[AutoReply] 回复阈值: {self.config['basic']['reply_threshold']}, 最小间隔: {self.config['rate_limit']['min_interval']}")
        except Exception as e:
            logger.error(f"[AutoReply] 初始化失败: {e}")
            self.config = None

    def _load_bot_info(self):
        """Load the bot's wxid and nickname from the main config."""
        try:
            with open("main_config.toml", "rb") as f:
                main_config = tomllib.load(f)
            self.bot_wxid = main_config.get("Bot", {}).get("wxid", "")
            self.bot_nickname = main_config.get("Bot", {}).get("nickname", "")
        except Exception as e:
            logger.warning(f"[AutoReply] 加载机器人信息失败: {e}")

    def _normalize_chat_id(self, chat_id: str) -> str:
        """Convert a group ID into the filesystem-safe name used by history files."""
        return (chat_id or "").replace("@", "_").replace(":", "_")

    def _is_chat_allowed(self, chat_id: str) -> bool:
        """Whitelist check."""
        whitelist_config = self.config.get("whitelist", {})
        if not whitelist_config.get("enabled", False):
            return True
        chat_list = whitelist_config.get("chat_list", [])
        safe_id = self._normalize_chat_id(chat_id)
        # Check both the raw ID and the normalized ID
        return chat_id in chat_list or safe_id in chat_list

    def _is_at_bot(self, message: dict) -> bool:
        """Check whether the bot was @-mentioned."""
        # Prefer the explicit 'Ats' list
        at_list = message.get('Ats', [])
        if at_list:
            # Is the bot's wxid in the @ list?
            if self.bot_wxid and self.bot_wxid in at_list:
                return True
        # Fallback: does the content contain "@<bot nickname>"?
        content = message.get('Content', '')
        if self.bot_nickname and f"@{self.bot_nickname}" in content:
            return True
        return False

    def _is_bot_message(self, message: dict) -> bool:
        """Check whether the message was sent by the bot itself."""
        sender_wxid = message.get('SenderWxid', '')
        return sender_wxid == self.bot_wxid if self.bot_wxid else False

    @on_text_message(priority=90)
    async def handle_message(self, bot, message: dict):
        """Handle an incoming text message."""
        try:
            # Skip if the plugin is disabled
            if not self.config or not self.config["basic"]["enabled"]:
                return True
            # Only handle group messages
            if not message.get('IsGroup', False):
                return True

            from_wxid = message.get('FromWxid', '')      # group chat ID
            sender_wxid = message.get('SenderWxid', '')  # sender ID
            content = message.get('Content', '').strip()

            # Skip empty messages
            if not content:
                return True
            # Skip the bot's own messages
            if self._is_bot_message(message):
                return True
            # Whitelist check
            if not self._is_chat_allowed(from_wxid):
                return True
            # Skip messages that @ the bot; let AIChat handle them directly
            if self._is_at_bot(message):
                logger.debug(f"[AutoReply] 跳过@消息,交由AIChat处理")
                return True

            chat_id = self._normalize_chat_id(from_wxid)
            current_time = time.time()

            # Rate limit: skip if a judgement is already in progress for this chat
            if self.judging.get(chat_id, False):
                logger.debug(f"[AutoReply] 群聊 {from_wxid[:15]}... 正在判断中,跳过")
                return True
            # Rate limit: enforce the minimum interval between judgements
            min_interval = self.config.get("rate_limit", {}).get("min_interval", 10)
            last_time = self.last_judge_time.get(chat_id, 0)
            if current_time - last_time < min_interval:
                logger.debug(f"[AutoReply] 距离上次判断仅 {current_time - last_time:.1f}秒,跳过")
                return True

            # Mark this chat as being judged
            self.judging[chat_id] = True
            self.last_judge_time[chat_id] = current_time
            try:
                # Ask the small judge model
                judge_result = await self._judge_with_small_model(from_wxid, content)
                if judge_result.should_reply:
                    logger.info(f"[AutoReply] 触发回复 | 群:{from_wxid[:15]}... | 评分:{judge_result.overall_score:.2f} | {judge_result.reasoning[:30]}")
                    # Update state
                    self._update_state(chat_id, replied=True)
                    # Set the trigger flag so AIChat generates the reply
                    message['_auto_reply_triggered'] = True
                else:
                    logger.debug(f"[AutoReply] 不触发 | 群:{from_wxid[:15]}... | 评分:{judge_result.overall_score:.2f}")
                    self._update_state(chat_id, replied=False)
            finally:
                # Clear the in-progress flag
                self.judging[chat_id] = False
            return True
        except Exception as e:
            logger.error(f"[AutoReply] 处理异常: {e}")
            import traceback
            logger.error(traceback.format_exc())
            # Clear the in-progress flag
            if 'chat_id' in locals():
                self.judging[chat_id] = False
            return True

    async def _judge_with_small_model(self, from_wxid: str, content: str) -> JudgeResult:
        """Use the small judge model to decide whether to reply."""
        chat_id = self._normalize_chat_id(from_wxid)
        chat_state = self._get_chat_state(chat_id)

        # Gather recent conversation context
        recent_messages = await self._get_recent_messages(chat_id)
        last_bot_reply = await self._get_last_bot_reply(chat_id)

        # Build the judge prompt
        reasoning_part = ',\n "reasoning": "简短分析原因(20字内)"' if self.config["judge"]["include_reasoning"] else ""
        judge_prompt = f"""你是群聊机器人的决策系统,判断是否应该主动回复。

## 当前状态
- 精力: {chat_state.energy:.1f}/1.0
- 上次发言: {self._get_minutes_since_last_reply(chat_id)}分钟前

## 最近对话
{recent_messages}

## 上次机器人回复
{last_bot_reply or "暂无"}

## 待判断消息
{content}

## 评估维度(0-10分)
1. relevance: 内容是否有趣、值得回复
2. willingness: 基于精力的回复意愿
3. social: 回复是否社交适宜
4. timing: 时机是否恰当
5. continuity: 与上次回复的关联度

回复阈值: {self.config['basic']['reply_threshold']}
仅返回JSON:
{{
 "relevance": 分数,
 "willingness": 分数,
 "social": 分数,
 "timing": 分数,
 "continuity": 分数{reasoning_part}
}}"""
        # Call the judge API, retrying on malformed JSON
        max_retries = self.config["judge"].get("max_retries", 2)
        for attempt in range(max_retries + 1):
            try:
                result = await self._call_judge_api(judge_prompt)

                # Parse the JSON, stripping any markdown code-fence wrapper first
                content_text = result.strip()
                if content_text.startswith("```"):
                    content_text = content_text.split("```")[1]
                    if content_text.startswith("json"):
                        content_text = content_text[4:]
                    content_text = content_text.strip()
                judge_data = json.loads(content_text)

                # Weighted overall score, scaled to 0-1
                overall_score = (
                    judge_data.get("relevance", 0) * self.weights["relevance"] +
                    judge_data.get("willingness", 0) * self.weights["willingness"] +
                    judge_data.get("social", 0) * self.weights["social"] +
                    judge_data.get("timing", 0) * self.weights["timing"] +
                    judge_data.get("continuity", 0) * self.weights["continuity"]
                ) / 10.0
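                # Worked example (assuming equal weights of 0.2 each, not necessarily
                # the configured values): scores 8, 6, 7, 7, 5 give
                # (8+6+7+7+5) * 0.2 / 10 = 0.66, which clears a reply_threshold of 0.6.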
                should_reply = overall_score >= self.config["basic"]["reply_threshold"]
                return JudgeResult(
                    relevance=judge_data.get("relevance", 0),
                    willingness=judge_data.get("willingness", 0),
                    social=judge_data.get("social", 0),
                    timing=judge_data.get("timing", 0),
                    continuity=judge_data.get("continuity", 0),
                    reasoning=judge_data.get("reasoning", ""),
                    should_reply=should_reply,
                    overall_score=overall_score
                )
            except json.JSONDecodeError as e:
                logger.warning(f"[AutoReply] JSON解析失败 (尝试 {attempt + 1}/{max_retries + 1}): {e}")
                if attempt == max_retries:
                    return JudgeResult(should_reply=False, reasoning="JSON解析失败")
            except Exception as e:
                logger.error(f"[AutoReply] 判断异常: {e}")
                return JudgeResult(should_reply=False, reasoning=f"异常: {str(e)}")
        return JudgeResult(should_reply=False, reasoning="重试失败")

    async def _call_judge_api(self, prompt: str) -> str:
        """Call the judge model's chat-completion API and return the raw text content."""
        api_url = self.config["basic"]["judge_api_url"]
        api_key = self.config["basic"]["judge_api_key"]
        model = self.config["basic"]["judge_model"]

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}"
        }
        payload = {
            "model": model,
            "messages": [
                {"role": "system", "content": "你是群聊回复决策系统。严格按JSON格式返回,不要输出其他内容。"},
                {"role": "user", "content": prompt}
            ],
            "temperature": 0.5,
            "max_tokens": 200
        }

        # Optional proxy support (requires aiohttp_socks)
        connector = None
        proxy_config = self.config.get("proxy", {})
        if proxy_config.get("enabled", False) and PROXY_SUPPORT:
            proxy_type = proxy_config.get("type", "http")
            proxy_host = proxy_config.get("host", "127.0.0.1")
            proxy_port = proxy_config.get("port", 7890)
            proxy_url = f"{proxy_type}://{proxy_host}:{proxy_port}"
            connector = ProxyConnector.from_url(proxy_url)

        timeout = aiohttp.ClientTimeout(total=30)
        async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session:
            async with session.post(api_url, headers=headers, json=payload) as response:
                if response.status != 200:
                    error_text = await response.text()
                    raise Exception(f"API调用失败: {response.status}, {error_text[:100]}")
                result = await response.json()
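                # Assumes an OpenAI-compatible response shape, e.g.:
                #   {"choices": [{"message": {"role": "assistant", "content": "{...}"}}], ...}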
                return result["choices"][0]["message"]["content"]

    async def _get_history(self, chat_id: str) -> list:
        """Fetch group chat history (prefer Redis, fall back to the file store)."""
        try:
            from utils.plugin_manager import PluginManager
            aichat_plugin = PluginManager().plugins.get("AIChat")
            if not aichat_plugin:
                return []

            # Prefer Redis (kept consistent with AIChat)
            try:
                from utils.redis_cache import get_cache
                redis_cache = get_cache()
                if redis_cache and redis_cache.enabled:
                    aichat_config = aichat_plugin.config or {}
                    redis_config = aichat_config.get("redis", {})
                    if redis_config.get("use_redis_history", True):
                        max_history = aichat_config.get("history", {}).get("max_history", 100)
                        history = redis_cache.get_group_history(chat_id, max_history)
                        if history:
                            return history
            except Exception as e:
                logger.debug(f"[AutoReply] Redis 获取历史失败: {e}")

            # Fall back to file storage
            if hasattr(aichat_plugin, 'history_dir') and aichat_plugin.history_dir:
                history_file = aichat_plugin.history_dir / f"{chat_id}.json"
                if history_file.exists():
                    with open(history_file, "r", encoding="utf-8") as f:
                        return json.load(f)
        except Exception as e:
            logger.debug(f"[AutoReply] 获取历史失败: {e}")
        return []

    async def _get_recent_messages(self, chat_id: str) -> str:
        """Format the most recent messages for the judge prompt."""
        try:
            history = await self._get_history(chat_id)
            if not history:
                return "暂无对话历史"
            count = self.config.get('context', {}).get('messages_count', 5)
            recent = history[-count:] if len(history) > count else history
            messages = []
            for record in recent:
                nickname = record.get('nickname', '未知')
                content = record.get('content', '')
                # Cap the length of each message
                if len(content) > 100:
                    content = content[:100] + "..."
                messages.append(f"{nickname}: {content}")
            return "\n".join(messages) if messages else "暂无对话历史"
        except Exception as e:
            logger.debug(f"[AutoReply] 获取消息历史失败: {e}")
            return "暂无对话历史"

    async def _get_last_bot_reply(self, chat_id: str) -> Optional[str]:
        """Return the bot's most recent reply in this chat, if any."""
        try:
            history = await self._get_history(chat_id)
            if not history:
                return None
            # Scan backwards for the last message sent by the bot
            for record in reversed(history):
                if record.get('nickname') == self.bot_nickname:
                    content = record.get('content', '')
                    if len(content) > 100:
                        content = content[:100] + "..."
                    return content
        except Exception as e:
            logger.debug(f"[AutoReply] 获取上次回复失败: {e}")
        return None

    def _get_chat_state(self, chat_id: str) -> ChatState:
        """Get (and lazily create) the state for a group chat."""
        if chat_id not in self.chat_states:
            self.chat_states[chat_id] = ChatState()
        state = self.chat_states[chat_id]
        today = date.today().isoformat()
        # Recover some energy once per day
        if state.last_reset_date != today:
            state.last_reset_date = today
            state.energy = min(1.0, state.energy + 0.2)
        return state

    def _get_minutes_since_last_reply(self, chat_id: str) -> int:
        """Minutes since the bot last replied in this chat (999 if never)."""
        state = self._get_chat_state(chat_id)
        if state.last_reply_time == 0:
            return 999
        return int((time.time() - state.last_reply_time) / 60)

    def _update_state(self, chat_id: str, replied: bool):
        """Update group chat state after a judgement."""
        state = self._get_chat_state(chat_id)
        state.total_messages += 1
        if replied:
            state.last_reply_time = time.time()
            state.total_replies += 1
            # Replying costs energy
            decay = self.config.get("energy", {}).get("decay_rate", 0.1)
            state.energy = max(0.1, state.energy - decay)
        else:
            # Staying quiet slowly recovers energy
            recovery = self.config.get("energy", {}).get("recovery_rate", 0.02)
            state.energy = min(1.0, state.energy + recovery)
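

# Example config.toml sketch for this plugin. Illustrative only: the key names are
# taken from the lookups above, but the values shown here are assumptions, not
# shipped defaults.
#
#   [basic]
#   enabled = true
#   judge_model = "gpt-4o-mini"          # any OpenAI-compatible chat model
#   judge_api_url = "https://api.example.com/v1/chat/completions"
#   judge_api_key = "sk-..."
#   reply_threshold = 0.6                # overall_score needed to trigger a reply
#
#   [weights]                            # should sum to 1.0 (auto-normalized otherwise)
#   relevance = 0.3
#   willingness = 0.2
#   social = 0.2
#   timing = 0.15
#   continuity = 0.15
#
#   [rate_limit]
#   min_interval = 10                    # seconds between judgements per chat
#
#   [whitelist]
#   enabled = false
#   chat_list = []
#
#   [judge]
#   include_reasoning = true
#   max_retries = 2
#
#   [proxy]
#   enabled = false
#   type = "http"
#   host = "127.0.0.1"
#   port = 7890
#
#   [energy]
#   decay_rate = 0.1
#   recovery_rate = 0.02
#
#   [context]
#   messages_count = 5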