# Source: abot/plugins/ai_auto_response/main.py
# Snapshot: 2026-04-07 12:10:47 +08:00 (585 lines, 24 KiB, Python)
#
# NOTE: the hosting web UI flagged this file as containing ambiguous Unicode
# characters (characters that might be confused with other characters). If that
# is intentional it can be ignored; the web UI's Escape button reveals them.
from __future__ import annotations
import re
import time
import xml.etree.ElementTree as ET
from typing import Any, Dict, List, Optional, Tuple
from loguru import logger
from base.plugin_common.message_plugin_interface import MessagePluginInterface
from base.plugin_common.plugin_interface import PluginStatus
from utils.robot_cmd.robot_command import GroupBotManager, PermissionStatus
from utils.wechat.contact_manager import ContactManager
from wechat_ipad import WechatAPIClient
from wechat_ipad.models.message import MessageType
from .context_builder import ContextBuilder
from .flow_manager import FlowManager
from .group_memory import GroupMemoryService
from .group_profile import GroupProfileResolver
from .llm_client import LLMClient
from .memory_store import MemoryStore
from .persona_engine import PersonaEngine
from .response_planner import ResponsePlanner
from .triggers import TriggerRouter
from .vector_memory import VectorMemoryStore
class AIAutoResponsePlugin(MessagePluginInterface):
    """Human-like group-chat bot ("小牛") that decides when to reply and answers via an LLM.

    Each incoming group message is normalized, appended to a bounded in-memory
    history, routed through the trigger router and flow manager, gated by the
    response planner and a cooldown, and - when a reply is warranted - answered
    with a prompt assembled from the group profile, member memory and optional
    vector recall. Replies are written back to the flow manager, memory store
    and vector memory.
    """

    FEATURE_KEY = "AI_AUTO_RESPONSE"
    FEATURE_DESCRIPTION = "🐮 小牛拟人群聊BOT [群聊拟真、及时答疑、长期记忆]"

    # Character cap applied by _finalize_reply per reply mode; unknown modes use the default.
    _REPLY_CHAR_LIMITS: Dict[str, int] = {
        "social_short": 12,
        "qa_fast": 28,
        "qa_with_context": 36,
    }
    _DEFAULT_REPLY_CHAR_LIMIT = 24

    @property
    def name(self) -> str:
        return "小牛群聊BOT"

    @property
    def version(self) -> str:
        return "2.0.0"

    @property
    def description(self) -> str:
        return "拟人化群聊BOT支持心流、长期记忆和回归成员识别"

    @property
    def author(self) -> str:
        return "ABOT Team"

    @property
    def command_prefix(self) -> Optional[str]:
        # No command prefix: the plugin reacts to ordinary chat, not commands.
        return None

    @property
    def commands(self) -> List[str]:
        return []

    @property
    def feature_key(self) -> Optional[str]:
        return self.FEATURE_KEY

    @property
    def feature_description(self) -> Optional[str]:
        return self.FEATURE_DESCRIPTION

    def __init__(self):
        super().__init__()
        self.feature = self.register_feature()
        # room_id -> bounded list of normalized recent messages (see _append_group_message).
        self.group_messages: Dict[str, List[Dict]] = {}
        self.enable = True
        # room_id -> time.time() of the bot's last reply in that room.
        self.last_reply_at: Dict[str, float] = {}

    def initialize(self, context: Dict[str, Any]) -> bool:
        """Build every collaborator from the plugin configuration.

        Args:
            context: Host-supplied context; only ``db_manager`` is read here.

        Returns:
            Always True.
        """
        self.LOG = logger
        self.db_manager = context.get("db_manager")
        self.enable = bool(self._config.get("enable", True))
        group_profiles_config = self._config.get("group_profiles", {}) or {}
        cooldown_config = self._config.get("cooldown", {}) or {}
        mode_config = self._config.get("mode", {}) or {}
        memory_config = self._config.get("memory", {}) or {}

        self.persona_engine = PersonaEngine(self.get_plugin_path(), self._config.get("persona", {}))
        self.group_memory_service = GroupMemoryService(self.db_manager, group_profiles_config)
        self.group_profile_resolver = GroupProfileResolver(group_profiles_config)
        # Night-silence hours live in the cooldown section but are enforced by the flow manager.
        self.flow_manager = FlowManager({
            **(self._config.get("flow", {}) or {}),
            "night_silent_hours": cooldown_config.get("night_silent_hours", []),
        })

        # "topics" entries override "priority" entries with the same key.
        merged_trigger_config = dict(self._config.get("priority", {}) or {})
        merged_trigger_config.update(self._config.get("topics", {}) or {})
        self.trigger_router = TriggerRouter(merged_trigger_config)

        # "memory" entries override "mode" entries with the same key.
        merged_memory_config = dict(mode_config)
        merged_memory_config.update(memory_config)
        self.memory_store = MemoryStore(self.db_manager, merged_memory_config)
        self.vector_memory = VectorMemoryStore(memory_config)

        self.context_builder = ContextBuilder()
        self.response_planner = ResponsePlanner()
        self.llm_client = LLMClient(self._config.get("api", {}) or {})
        self.filters = self._config.get("filters", {}) or {}
        self.mode_config = mode_config
        self.cooldown_config = cooldown_config
        # "room:sender" -> marker of the last member snapshot pushed to vector memory.
        self._synced_member_context_versions: Dict[str, str] = {}
        self.log_debug = bool((self._config.get("logging", {}) or {}).get("debug", True))
        self.LOG.debug(f"[{self.name}] 初始化完成")
        return True

    def start(self) -> bool:
        self.status = PluginStatus.RUNNING
        return True

    def stop(self) -> bool:
        self.status = PluginStatus.STOPPED
        return True

    def can_process(self, message: Dict[str, Any]) -> bool:
        """Return True when *message* is a candidate for an automatic reply.

        The message must:
        - arrive while the plugin is enabled;
        - come from a group whose permission for this feature is not DISABLED;
        - be a TEXT or APP message not sent by the bot itself;
        - have non-empty normalized content that survives the filters;
        - not be an @-mention aimed at someone other than the bot.
        """
        if not self.enable:
            return False
        room_id = message.get("roomid", "")
        if not room_id:
            return False
        if GroupBotManager.get_group_permission(room_id, self.feature) == PermissionStatus.DISABLED:
            return False
        if message.get("type") not in (MessageType.TEXT, MessageType.APP):
            return False
        full_msg = message.get("full_wx_msg")
        if full_msg and full_msg.from_self():
            return False
        content = self._normalize_content(message)
        if not content or self._should_ignore(content):
            return False
        return not self._is_targeting_other_user(message)

    async def process_message(self, message: Dict[str, Any]) -> Tuple[bool, Optional[str]]:
        """Decide whether to answer a group message and, if so, send an LLM reply.

        Args:
            message: Incoming message dict. Reads ``roomid``, ``sender``, ``bot``,
                ``type``, ``content``, ``is_at``, ``timestamp`` and ``all_contacts``.

        Returns:
            ``(False, reason)`` on every path. *reason* is one of ``"skip"``,
            ``"cooldown"``, ``"empty_response"`` or ``"replied"``.
            NOTE(review): the flag is False even after replying, presumably so
            that other plugins still receive the message. Confirm this against
            MessagePluginInterface.
        """
        room_id = message.get("roomid", "")
        sender = message.get("sender", "")
        bot: WechatAPIClient = message.get("bot")
        content = self._normalize_content(message)
        sender_name = self._get_sender_name(room_id, sender)
        group_name = self._get_group_name(room_id, message)
        group_memory_profile = self.group_memory_service.build_group_memory_profile(room_id, group_name)
        group_profile = self.group_profile_resolver.resolve(room_id, group_name, group_memory_profile)
        self._log_event(
            "recv",
            room_id=room_id,
            sender=sender,
            sender_name=sender_name,
            group_mode=group_profile.get("mode", ""),
            knowledge_domain=group_profile.get("knowledge_domain", ""),
            memory_domain=group_profile.get("group_memory_domain", ""),
            humor_style=group_profile.get("humor_style", ""),
            sharpness_style=group_profile.get("sharpness_style", ""),
            is_at=message.get("is_at", False),
            content_preview=self._preview(content),
            msg_type=str(message.get("type")),
        )

        # Sample this BEFORE appending. Afterwards the buffer always holds at
        # least the current message, which used to make the persisted-history
        # fallback below unreachable.
        had_local_history = bool(self.group_messages.get(room_id))
        self._append_group_message(room_id, {
            "sender": sender,
            "sender_name": sender_name,
            "content": content,
            "timestamp": message.get("timestamp"),
        })

        memory_hints = self.memory_store.build_memory_hints(room_id, sender)
        self._sync_member_memory(room_id, sender, sender_name, memory_hints.get("member_context", {}))
        self._log_event(
            "memory",
            room_id=room_id,
            sender=sender,
            returning_state=memory_hints.get("returning_member_state", "") or "none",
            has_member_context=bool(memory_hints.get("member_context")),
            is_followup=memory_hints.get("is_followup", False),
            last_active_at=memory_hints.get("last_active_at", "") or "",
        )

        trigger = self.trigger_router.route(message | {"content": content}, memory_hints)
        trigger_info = trigger.__dict__
        flow_state = self.flow_manager.apply_message_event(room_id, {
            "is_at": message.get("is_at", False),
            "is_question": trigger.is_question,
            "is_followup": trigger.is_followup,
            "topic_hit": bool(trigger.topic),
            "topic": trigger.topic,
            "is_returning_member": trigger.is_returning_member,
            # NOTE(review): always True here; confirm this is what FlowManager expects.
            "message_after_bot": True,
        })
        self._log_event(
            "decision",
            room_id=room_id,
            sender=sender,
            trigger_type=trigger.trigger_type,
            priority=trigger.priority,
            reasons="|".join(trigger.reasons),
            flow_state=flow_state.state,
            flow_score=round(flow_state.score, 2),
            topic=trigger.topic or "",
        )

        allow_proactive = bool(self.mode_config.get("allow_proactive_reply", True))
        reply_mode = self.response_planner.choose_reply_mode(trigger_info, flow_state.state)
        if not self.response_planner.should_reply(trigger_info, flow_state.state, allow_proactive):
            self._log_event(
                "skip",
                room_id=room_id,
                sender=sender,
                reason="planner_skip",
                trigger_type=trigger.trigger_type,
                reply_mode=reply_mode,
                flow_state=flow_state.state,
            )
            return False, "skip"
        if not self._pass_cooldown(room_id, trigger_info):
            self._log_event(
                "skip",
                room_id=room_id,
                sender=sender,
                reason="cooldown",
                trigger_type=trigger.trigger_type,
                reply_mode=reply_mode,
            )
            return False, "cooldown"

        recent_messages = self._load_recent_messages(room_id, had_local_history)
        vector_memories = []
        if self.vector_memory.should_search(reply_mode, trigger.trigger_type, memory_hints.get("returning_member_state", "")):
            vector_memories = self.vector_memory.search(content, room_id, sender)
        self._log_event(
            "context",
            room_id=room_id,
            sender=sender,
            group_mode=group_profile.get("mode", ""),
            knowledge_domain=group_profile.get("knowledge_domain", ""),
            reply_mode=reply_mode,
            recent_message_count=len(recent_messages),
            vector_hit_count=len(vector_memories),
        )
        context = self.context_builder.build(
            room_id=room_id,
            group_profile=group_profile,
            sender=sender,
            sender_name=sender_name,
            content=content,
            recent_messages=recent_messages,
            member_context=memory_hints.get("member_context", {}),
            trigger=trigger_info,
            flow_state=flow_state.state,
            reply_mode=reply_mode,
            vector_memories=vector_memories,
        )
        system_prompt = self.persona_engine.build_system_prompt(group_profile)
        user_prompt = self._build_user_prompt(context, memory_hints)
        # NOTE(review): chat() is called synchronously inside this coroutine, so it
        # blocks the event loop for the whole LLM round-trip.
        response = self._sanitize_response(self.llm_client.chat(system_prompt, user_prompt, user_id=f"{room_id}:{sender}"))
        if not response:
            self._log_event(
                "model_empty",
                room_id=room_id,
                sender=sender,
                model=self.llm_client.model,
                last_error=self.llm_client.last_error,
                reply_mode=reply_mode,
            )
            return False, "empty_response"

        response = self._finalize_reply(response, reply_mode)
        if not response:
            # Clipping can reduce a reply to nothing (e.g. only trailing punctuation);
            # never send an empty message to the group.
            self._log_event(
                "skip",
                room_id=room_id,
                sender=sender,
                reason="empty_after_finalize",
                trigger_type=trigger.trigger_type,
                reply_mode=reply_mode,
            )
            return False, "empty_response"

        await bot.send_text_message(room_id, response, sender)
        self.last_reply_at[room_id] = time.time()
        self.flow_manager.note_bot_reply(room_id)
        self.memory_store.note_bot_reply(room_id, sender, trigger.topic)
        self._upsert_interaction_memory(room_id, sender, sender_name, content, response, trigger.trigger_type, trigger.topic)
        self._log_event(
            "sent",
            room_id=room_id,
            sender=sender,
            sender_name=sender_name,
            trigger_type=trigger.trigger_type,
            reply_mode=reply_mode,
            response_preview=self._preview(response),
            response_len=len(response),
        )
        return False, "replied"

    def _load_recent_messages(self, room_id: str, had_local_history: bool) -> List[Any]:
        """Return the recent room messages used as prompt context.

        If the room already had in-memory history before the current message,
        that buffer is used. Otherwise (e.g. the first message after a restart)
        the persisted history from the memory store is preferred, falling back
        to the in-memory buffer when the store returns nothing.
        """
        local_messages = self.group_messages.get(room_id) or []
        if had_local_history:
            return local_messages
        return self.memory_store.get_recent_messages(room_id) or local_messages

    def _append_group_message(self, room_id: str, message: Dict) -> None:
        """Append *message* to the room buffer, keeping only the newest entries.

        The cap is ``mode.recent_context_size`` (default 30). A non-positive cap
        empties the buffer instead of letting it grow without bound.
        """
        items = self.group_messages.setdefault(room_id, [])
        items.append(message)
        size = max(int(self.mode_config.get("recent_context_size", 30)), 0)
        overflow = len(items) - size
        if overflow > 0:
            # Slice by overflow count: items[-size:] would keep everything when size == 0.
            self.group_messages[room_id] = items[overflow:]

    def _normalize_content(self, message: Dict[str, Any]) -> str:
        """Return the plain text of a message.

        TEXT messages have @-mentions removed. APP messages are reduced to their
        XML ``<title>`` text, or ``"[应用消息]"`` when the title element is
        missing or the XML cannot be parsed. Other types return the stripped
        raw content.
        """
        msg_type = message.get("type")
        # `or ""` keeps a None content from becoming the literal string "None".
        content = str(message.get("content") or "").strip()
        if msg_type == MessageType.TEXT:
            return self._strip_at_prefix(content)
        if msg_type == MessageType.APP:
            try:
                root = ET.fromstring(content)
            except ET.ParseError:
                return "[应用消息]"
            title = root.find(".//title")
            if title is None:
                return "[应用消息]"
            return (title.text or "").strip()
        return content

    @staticmethod
    def _strip_at_prefix(content: str) -> str:
        """Remove every ``@name`` token that is followed by whitespace (including U+2005).

        Despite the name, mentions are removed anywhere in the text, not only at the start.
        """
        return re.sub(r"@.*?[\u2005\s]+", "", content).strip()

    def _should_ignore(self, content: str) -> bool:
        """Return True if the configured filters reject *content*.

        Content is rejected when it is shorter than ``min_text_length``, equals
        an ``ignore_exact`` entry, or starts with an ``ignore_prefixes`` entry.
        """
        if len(content) < int(self.filters.get("min_text_length", 1)):
            return True
        if content in set(self.filters.get("ignore_exact") or []):
            return True
        prefixes = tuple(self.filters.get("ignore_prefixes") or [])
        return bool(prefixes) and content.startswith(prefixes)

    def _is_targeting_other_user(self, message: Dict[str, Any]) -> bool:
        """Return True when the raw text contains ``@`` but the bot itself was not mentioned.

        NOTE(review): any ``@`` counts, so e.g. an e-mail address in the text
        also suppresses a reply.
        """
        if message.get("is_at", False):
            return False
        return "@" in str(message.get("content", "") or "")

    def _get_sender_name(self, room_id: str, sender: str) -> str:
        """Return the sender's group display name, or the sender id if none is available."""
        try:
            members = ContactManager.get_instance().get_group_members(room_id)
            # An empty display name is treated as missing.
            return members.get(sender) or sender
        except Exception:
            # Best-effort lookup: any contact-manager failure falls back to the raw id.
            return sender

    @staticmethod
    def _get_group_name(room_id: str, message: Dict[str, Any]) -> str:
        """Look up the room's name in ``message["all_contacts"]``, defaulting to the room id."""
        all_contacts = message.get("all_contacts", {}) or {}
        return str(all_contacts.get(room_id, room_id))

    def _pass_cooldown(self, room_id: str, trigger: Dict) -> bool:
        """Return True if enough time has passed since the bot last replied in *room_id*.

        Questions, follow-ups and @-triggers use the shorter
        ``same_user_followup_cooldown_sec`` (default 10). Everything else uses
        ``group_reply_cooldown_sec`` (default 45). Both are measured against the
        room's last reply, not a per-user timestamp.
        """
        elapsed = time.time() - self.last_reply_at.get(room_id, 0.0)
        is_direct = (
            trigger.get("is_question")
            or trigger.get("is_followup")
            or trigger.get("trigger_type") == "at_trigger"
        )
        if is_direct:
            return elapsed >= int(self.cooldown_config.get("same_user_followup_cooldown_sec", 10))
        return elapsed >= int(self.cooldown_config.get("group_reply_cooldown_sec", 45))

    def _build_user_prompt(self, context: Dict, memory_hints: Dict) -> str:
        """Render the user-side LLM prompt from the built context.

        Only the last 8 entries of ``context["recent_messages"]`` are used. They
        are joined with newlines, so ContextBuilder is assumed to emit them as
        strings.
        """
        recent_text = "\n".join(context.get("recent_messages", [])[-8:]) or "暂无"
        reply_mode = context.get("reply_mode", "social_short")
        length_rule = self._build_length_rule(reply_mode)
        return (
            f"当前群聊消息:\n{recent_text}\n\n"
            f"当前发言:{context.get('current_message', '')}\n"
            f"触发类型:{context.get('trigger_type', 'none')}\n"
            f"回复模式:{reply_mode}\n"
            f"当前心流状态:{context.get('flow_state', 'idle')}\n"
            f"当前群画像:\n{context.get('group_profile_prompt', '暂无')}\n\n"
            f"成员稳定记忆:\n{context.get('memory_prompt', '暂无')}\n\n"
            f"向量召回记忆:\n{context.get('vector_memory_prompt', '') or '暂无'}\n\n"
            f"补充信息:回归状态={memory_hints.get('returning_member_state', '') or 'none'}\n"
            f"要求:\n"
            f"1. 如果是明确问题,先给清楚答案。\n"
            f"2. 如果只是轻量接话,保持自然短句。\n"
            f"3. 不要暴露系统记忆来源。\n"
            f"4. 如果信息不足,不要硬编。\n"
            f"5. 输出最终可直接发到群里的内容,不要解释你的思路。\n"
            f"6. {length_rule}\n"
            f"7. 优先直接回应“当前发言”本身,不要被较早上下文带跑。\n"
            f"8. 成员记忆和向量召回只有在与当前问题直接相关时才允许使用,否则忽略。\n"
            f"9. 如果你不确定自己是否理解对了,就宁可不展开,只回很短。\n"
            f"10. 把这次回复当作真人聊天里的第一反应,先只给第一层结论,不要主动补第二层解释。\n"
            f"11. 如果一句话已经够了,就立刻停,不要为了完整而补充。\n"
            f"12. 回答时优先服从当前群画像里的知识域和回答风格,不要跨领域乱发挥。\n"
        )

    @staticmethod
    def _sanitize_response(response: str) -> str:
        """Trim the raw model output, squeeze 3+ blank lines to one, and cap it at 500 chars."""
        if not response:
            return ""
        response = re.sub(r"\n{3,}", "\n\n", response.strip())
        return response[:500].strip()

    def _finalize_reply(self, response: str, reply_mode: str) -> str:
        """Flatten *response* to one line and keep only its first sentence.

        The sentence is capped at a per-mode length from ``_REPLY_CHAR_LIMITS``
        (``_DEFAULT_REPLY_CHAR_LIMIT`` for unknown modes). May return "" if
        clipping leaves nothing.
        """
        text = (response or "").strip()
        if not text:
            return ""
        # \s also matches newlines, so this alone collapses the reply onto one line.
        text = re.sub(r"\s+", " ", text).strip()
        limit = self._REPLY_CHAR_LIMITS.get(reply_mode, self._DEFAULT_REPLY_CHAR_LIMIT)
        return self._take_first_sentence(text, limit).strip()

    @staticmethod
    def _build_length_rule(reply_mode: str) -> str:
        """Return the prompt's length instruction for *reply_mode*."""
        if reply_mode == "social_short":
            return "默认只回一句短话最好控制在2到8个字除非非常不自然。"
        if reply_mode == "qa_fast":
            return "尽量只回1句话总长度优先控制在28字内先给结论不要主动补解释。"
        if reply_mode == "qa_with_context":
            return "优先控制在1句话必要时最多2句总长度优先控制在36字内只给第一层答案。"
        return "尽量短,像群友临时接一句,不要长篇大论。"

    @staticmethod
    def _take_first_sentence(text: str, limit: int) -> str:
        """Return the first sentence of *text*, clipped to *limit* characters.

        Sentence terminators stay attached to the sentence. If the first piece
        is blank, the whole stripped text is used. A clipped result has trailing
        commas, enumeration commas, semicolons and colons removed.
        """
        parts = re.split(r"(?<=[。!?!?;])", text)
        first = parts[0].strip() if parts and parts[0].strip() else text.strip()
        if len(first) <= limit:
            return first
        return first[:limit].rstrip(",、;;:")

    def _sync_member_memory(self, room_id: str, sender: str, sender_name: str, member_context: Dict) -> None:
        """Push the member's profile snapshot into vector memory when it has changed.

        Change detection uses ``last_profiled_at``. When that is empty, a hash
        of the rendered profile text is used instead, so an unversioned snapshot
        is not re-upserted on every message. Placeholder profiles
        ("暂无稳定成员画像。") are never stored.
        """
        if not member_context:
            return
        version = str(member_context.get("last_profiled_at", ""))
        cache_key = f"{room_id}:{sender}"
        synced_marker = self._synced_member_context_versions.get(cache_key)
        if version and synced_marker == version:
            return
        text = self.context_builder._build_member_memory_prompt(member_context)
        if not text or text == "暂无稳定成员画像。":
            return
        sync_marker = version or f"text:{hash(text)}"
        if not version and synced_marker == sync_marker:
            return
        payload = {
            "chatroom_id": room_id,
            "wxid": sender,
            "display_name": sender_name,
            "memory_type": "member_context_snapshot",
            "source_id": cache_key,
            "last_active_at": member_context.get("last_profiled_at", ""),
            # `or []` guards against an explicit None, which would fail on slicing.
            "topic_tags": (member_context.get("topics_of_interest") or [])[:5],
            "summary_text": member_context.get("summary_text", ""),
        }
        ok = self.vector_memory.upsert_memory(f"member_context:{cache_key}:{version}", text, payload)
        self._log_event(
            "memory_upsert",
            room_id=room_id,
            sender=sender,
            memory_type="member_context_snapshot",
            ok=ok,
        )
        if ok:
            self._synced_member_context_versions[cache_key] = sync_marker

    def _upsert_interaction_memory(
        self,
        room_id: str,
        sender: str,
        sender_name: str,
        content: str,
        response: str,
        trigger_type: str,
        topic: str,
    ) -> None:
        """Store the message/reply exchange in vector memory as an interaction record.

        The record id is ``room:sender:<unix seconds>``.
        """
        now = time.time()
        text = f"{sender_name}说:{content}\n小牛回复:{response}"
        payload = {
            "chatroom_id": room_id,
            "wxid": sender,
            "display_name": sender_name,
            "memory_type": "interaction_memory",
            "topic_tags": [item for item in (topic, trigger_type) if item],
            "created_at": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(now)),
            "source_id": f"{room_id}:{sender}:{int(now)}",
            "summary_text": text[:500],
        }
        ok = self.vector_memory.upsert_memory(payload["source_id"], text, payload)
        self._log_event(
            "memory_upsert",
            room_id=room_id,
            sender=sender,
            memory_type="interaction_memory",
            ok=ok,
            trigger_type=trigger_type,
        )

    def _log_event(self, event: str, **kwargs: Any) -> None:
        """Log a one-line summary of *event* at INFO level when ``logging.debug`` is enabled."""
        if not self.log_debug:
            return
        self.LOG.info(self._build_log_summary(event, kwargs))

    @staticmethod
    def _preview(text: str, limit: int = 80) -> str:
        """Return *text* with newlines escaped, truncated to *limit* chars with a "..." suffix."""
        text = (text or "").replace("\n", "\\n").strip()
        if len(text) <= limit:
            return text
        return text[: limit - 3] + "..."

    def _build_log_summary(self, event: str, data: Dict[str, Any]) -> str:
        """Format a compact ``[XIAONIU] ...`` log line for a known event.

        Unknown events fall back to sorted ``key=value`` pairs, omitting None and
        empty values.
        """
        room = self._short_id(data.get("room_id", ""))
        sender_name = data.get("sender_name", "") or self._short_id(data.get("sender", ""))
        sender = self._short_id(data.get("sender", ""))
        if event == "recv":
            return (
                f"[XIAONIU] RECV room={room} user={sender_name}/{sender} "
                f"at={self._yn(data.get('is_at'))} "
                f"style={self._style_mark(data.get('humor_style', ''), data.get('sharpness_style', ''))} "
                f"msg={data.get('content_preview', '')}"
            ).strip()
        if event == "memory":
            return (
                f"[XIAONIU] MEMORY room={room} user={sender} "
                f"ctx={self._yn(data.get('has_member_context'))} "
                f"follow={self._yn(data.get('is_followup'))} "
                f"return={data.get('returning_state', 'none')}"
            ).strip()
        if event == "decision":
            return (
                f"[XIAONIU] DECIDE room={room} user={sender} "
                f"trigger={data.get('trigger_type', 'none')} "
                f"flow={data.get('flow_state', '')}:{data.get('flow_score', '')} "
                f"topic={data.get('topic', '-') or '-'} "
                f"reasons={data.get('reasons', '-') or '-'}"
            ).strip()
        if event == "skip":
            return (
                f"[XIAONIU] SKIP room={room} user={sender} "
                f"reason={data.get('reason', '')} "
                f"trigger={data.get('trigger_type', 'none')} "
                f"mode={data.get('reply_mode', '')}"
            ).strip()
        if event == "context":
            return (
                f"[XIAONIU] CTX room={room} user={sender} "
                f"mode={data.get('reply_mode', '')} "
                f"recent={data.get('recent_message_count', 0)} "
                f"vector={data.get('vector_hit_count', 0)}"
            ).strip()
        if event == "model_empty":
            return (
                f"[XIAONIU] MODEL_EMPTY room={room} user={sender} "
                f"model={data.get('model', '')} "
                f"mode={data.get('reply_mode', '')} "
                f"err={data.get('last_error', '')}"
            ).strip()
        if event == "sent":
            return (
                f"[XIAONIU] SENT room={room} user={sender_name}/{sender} "
                f"trigger={data.get('trigger_type', 'none')} "
                f"mode={data.get('reply_mode', '')} "
                f"len={data.get('response_len', 0)} "
                f"reply={data.get('response_preview', '')}"
            ).strip()
        if event == "memory_upsert":
            return (
                f"[XIAONIU] MEM_UPSERT room={room} user={sender} "
                f"type={data.get('memory_type', '')} "
                f"ok={self._yn(data.get('ok'))} "
                f"trigger={data.get('trigger_type', '-') or '-'}"
            ).strip()
        compact = " ".join(f"{key}={data[key]}" for key in sorted(data) if data.get(key) not in (None, ""))
        return f"[XIAONIU] {event.upper()} {compact}".strip()

    @staticmethod
    def _yn(value: Any) -> str:
        """Render a truthy/falsy value as "Y"/"N"."""
        return "Y" if bool(value) else "N"

    @staticmethod
    def _short_id(value: str) -> str:
        """Abbreviate ids longer than 10 characters to ``first4...last4``."""
        value = str(value or "")
        if len(value) <= 10:
            return value
        return value[:4] + "..." + value[-4:]

    @staticmethod
    def _style_mark(humor_style: str, sharpness_style: str) -> str:
        """Map the group's style descriptions to a short ``humor|plain/sharp|soft`` tag by keyword."""
        humor = "humor" if "中等" in str(humor_style) or "偏上" in str(humor_style) else "plain"
        sharp = "sharp" if "毒舌" in str(sharpness_style) or "嘴欠" in str(sharpness_style) else "soft"
        return f"{humor}/{sharp}"