add multimodal quote handling for xiaoniu bot

This commit is contained in:
liuwei
2026-04-07 13:51:15 +08:00
parent 61edbbe987
commit 7c12738967
3 changed files with 184 additions and 9 deletions

View File

@@ -21,6 +21,7 @@ class ContextBuilder:
flow_state: str,
reply_mode: str,
vector_memories: List[Dict],
quote_context: Dict | None = None,
) -> Dict:
recent_lines = []
for item in recent_messages[-self.recent_context_size:]:
@@ -43,6 +44,7 @@ class ContextBuilder:
"memory_prompt": self._build_member_memory_prompt(member_context),
"vector_memory_prompt": self._build_vector_memory_prompt(vector_memories),
"group_profile_prompt": self._build_group_profile_prompt(group_profile or {}),
"quote_prompt": self._build_quote_prompt(quote_context or {}),
"current_message": f"{sender_name}: {content}",
}
@@ -116,3 +118,21 @@ class ContextBuilder:
str(style_profile.get("expressiveness_style", "") or "").strip(),
]
).strip(" /")
@staticmethod
def _build_quote_prompt(quote_context: Dict) -> str:
if not quote_context:
return ""
quote_type = quote_context.get("quote_type_label", "引用消息")
quote_sender = quote_context.get("quote_sender_name", "") or "未知成员"
quote_body = quote_context.get("quote_body", "") or ""
title = quote_context.get("title", "") or ""
lines = [
f"用户这次是在引用消息后发言。",
f"引用类型:{quote_type}",
f"被引用发送者:{quote_sender}",
f"图片附件:已附带原图" if quote_context.get("has_image_attachment") else "",
f"引用标题:{title}" if title else "",
f"被引用内容:{quote_body}" if quote_body else "",
]
return "\n".join([line for line in lines if line])

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
import json
from typing import Dict, List
from typing import Dict, List, Optional
import requests
@@ -20,23 +20,35 @@ class LLMClient:
self.stream = bool(self.config.get("stream", True))
self.last_error = ""
def chat(self, system_prompt: str, user_prompt: str, user_id: str) -> str:
def chat(
self,
system_prompt: str,
user_prompt: str,
user_id: str,
image_urls: Optional[List[str]] = None,
) -> str:
self.last_error = ""
if not self.base_url:
self.last_error = "empty_base_url"
return ""
if self.provider == "openai_compatible":
return self._chat_openai_compatible(system_prompt, user_prompt, user_id)
return self._chat_openai_compatible(system_prompt, user_prompt, user_id, image_urls or [])
self.last_error = f"unsupported_provider:{self.provider}"
return ""
def _chat_openai_compatible(self, system_prompt: str, user_prompt: str, user_id: str) -> str:
def _chat_openai_compatible(
self,
system_prompt: str,
user_prompt: str,
user_id: str,
image_urls: List[str],
) -> str:
if not self.model:
return ""
payload = {
"model": self.model,
"messages": self._build_messages(system_prompt, user_prompt),
"messages": self._build_messages(system_prompt, user_prompt, image_urls),
"temperature": self.temperature,
"max_tokens": self.max_tokens,
"user": user_id,
@@ -107,10 +119,19 @@ class LLMClient:
return ""
@staticmethod
def _build_messages(system_prompt: str, user_prompt: str) -> List[Dict[str, str]]:
def _build_messages(system_prompt: str, user_prompt: str, image_urls: List[str]) -> List[Dict]:
user_content: str | List[Dict[str, object]]
if image_urls:
content_parts: List[Dict[str, object]] = [{"type": "text", "text": user_prompt}]
for image_url in image_urls:
if image_url:
content_parts.append({"type": "image_url", "image_url": {"url": image_url}})
user_content = content_parts
else:
user_content = user_prompt
return [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
{"role": "user", "content": user_content},
]
@staticmethod

View File

@@ -1,5 +1,8 @@
from __future__ import annotations
import base64
import html
import imghdr
import re
import time
import xml.etree.ElementTree as ET
@@ -137,6 +140,7 @@ class AIAutoResponsePlugin(MessagePluginInterface):
sender = message.get("sender", "")
bot: WechatAPIClient = message.get("bot")
content = self._normalize_content(message)
quote_context = self._parse_quote_context(message.get("full_wx_msg"), room_id)
sender_name = self._get_sender_name(room_id, sender)
group_name = self._get_group_name(room_id, message)
group_memory_profile = self.group_memory_service.build_group_memory_profile(room_id, group_name)
@@ -153,6 +157,7 @@ class AIAutoResponsePlugin(MessagePluginInterface):
sharpness_style=group_profile.get("sharpness_style", ""),
is_at=message.get("is_at", False),
content_preview=self._preview(content),
quote_type=quote_context.get("quote_type_label", ""),
msg_type=str(message.get("type")),
)
@@ -236,6 +241,7 @@ class AIAutoResponsePlugin(MessagePluginInterface):
vector_memories = []
if self.vector_memory.should_search(reply_mode, trigger.trigger_type, memory_hints.get("returning_member_state", "")):
vector_memories = self.vector_memory.search(content, room_id, sender)
image_urls = await self._prepare_quote_image_inputs(bot, quote_context)
self._log_event(
"context",
room_id=room_id,
@@ -246,6 +252,7 @@ class AIAutoResponsePlugin(MessagePluginInterface):
reply_mode=reply_mode,
recent_message_count=len(recent_messages),
vector_hit_count=len(vector_memories),
image_input_count=len(image_urls),
)
context = self.context_builder.build(
@@ -260,11 +267,19 @@ class AIAutoResponsePlugin(MessagePluginInterface):
flow_state=flow_state.state,
reply_mode=reply_mode,
vector_memories=vector_memories,
quote_context=quote_context | {"has_image_attachment": bool(image_urls)},
)
system_prompt = self.persona_engine.build_system_prompt(group_profile)
user_prompt = self._build_user_prompt(context, memory_hints)
response = self._sanitize_response(self.llm_client.chat(system_prompt, user_prompt, user_id=f"{room_id}:{sender}"))
response = self._sanitize_response(
self.llm_client.chat(
system_prompt,
user_prompt,
user_id=f"{room_id}:{sender}",
image_urls=image_urls,
)
)
if not response:
self._log_event(
"model_empty",
@@ -361,6 +376,7 @@ class AIAutoResponsePlugin(MessagePluginInterface):
return (
f"当前群聊消息:\n{recent_text}\n\n"
f"当前发言:{context.get('current_message', '')}\n"
f"引用补充:\n{context.get('quote_prompt', '') or ''}\n"
f"触发类型:{context.get('trigger_type', 'none')}\n"
f"回复模式:{context.get('reply_mode', 'social_short')}\n"
f"当前心流状态:{context.get('flow_state', 'idle')}\n"
@@ -561,6 +577,7 @@ class AIAutoResponsePlugin(MessagePluginInterface):
f"[XIAONIU] RECV room={room} user={sender_name}/{sender} "
f"at={self._yn(data.get('is_at'))} "
f"style={self._style_mark(data.get('humor_style', ''), data.get('sharpness_style', ''))} "
f"quote={data.get('quote_type', '-') or '-'} "
f"msg={data.get('content_preview', '')}"
).strip()
@@ -597,7 +614,8 @@ class AIAutoResponsePlugin(MessagePluginInterface):
f"mode={data.get('reply_mode', '')} "
f"acc={data.get('acceptance_state', '-') or '-'} "
f"recent={data.get('recent_message_count', 0)} "
f"vector={data.get('vector_hit_count', 0)}"
f"vector={data.get('vector_hit_count', 0)} "
f"img={data.get('image_input_count', 0)}"
).strip()
if event == "model_empty":
@@ -644,3 +662,119 @@ class AIAutoResponsePlugin(MessagePluginInterface):
humor = "humor" if "中等" in str(humor_style) or "偏上" in str(humor_style) else "plain"
sharp = "sharp" if "毒舌" in str(sharpness_style) or "嘴欠" in str(sharpness_style) else "soft"
return f"{humor}/{sharp}"
def _parse_quote_context(self, full_msg: Any, room_id: str) -> Dict[str, str]:
if not full_msg or not getattr(full_msg, "content", None):
return {}
xml_content = getattr(full_msg.content, "xml_content", "") or ""
if not xml_content:
return {}
try:
root = ET.fromstring(xml_content)
except ET.ParseError:
return {}
appmsg = root.find(".//appmsg")
if appmsg is None or appmsg.findtext("type", "").strip() != "57":
return {}
refer = appmsg.find("refermsg")
if refer is None:
return {}
title = html.unescape(appmsg.findtext("title", "") or "").strip()
quote_sender_name = html.unescape(refer.findtext("displayname", "") or "").strip()
if not quote_sender_name:
quote_sender = html.unescape(refer.findtext("chatusr", "") or "").strip()
quote_sender_name = self._get_sender_name(room_id, quote_sender) if quote_sender else "未知成员"
ref_type = int(refer.findtext("type", "0") or 0)
ref_content = html.unescape(refer.findtext("content", "") or "").strip()
quote_type_label = self._quote_type_label(ref_type)
quote_body = self._build_quote_body(ref_type, ref_content, title)
return {
"title": title,
"quote_sender_name": quote_sender_name,
"quote_type_label": quote_type_label,
"quote_body": quote_body,
"raw_ref_content": ref_content,
}
@staticmethod
def _quote_type_label(ref_type: int) -> str:
mapping = {
MessageType.TEXT.value: "引用文本",
MessageType.IMAGE.value: "引用图片",
MessageType.VIDEO.value: "引用视频",
MessageType.APP.value: "引用应用消息",
MessageType.EMOTICON.value: "引用表情",
}
return mapping.get(ref_type, f"引用消息[{ref_type}]")
@staticmethod
def _build_quote_body(ref_type: int, ref_content: str, title: str) -> str:
if ref_type == MessageType.TEXT.value:
return ref_content[:220].strip()
if ref_type == MessageType.IMAGE.value:
details = []
if title:
details.append(f"当前追问文案:{title}")
if ref_content:
details.append("被引用的是一张图片")
return "".join(details) or "被引用的是一张图片"
if title:
return title[:220].strip()
return ref_content[:220].strip()
async def _prepare_quote_image_inputs(self, bot: WechatAPIClient, quote_context: Dict[str, str]) -> List[str]:
if not quote_context or quote_context.get("quote_type_label") != "引用图片":
return []
ref_content = quote_context.get("raw_ref_content", "") or ""
image_info = self._extract_quote_image_info(ref_content)
if not image_info:
return []
try:
base64_str = await bot.download_image(
aeskey=image_info["aeskey"],
cdnmidimgurl=image_info["url"],
)
except Exception as exc:
self._log_event("quote_image_fail", reason=f"download:{exc}")
return []
data_url = self._build_image_data_url(base64_str)
if not data_url:
self._log_event("quote_image_fail", reason="invalid_base64")
return []
return [data_url]
@staticmethod
def _extract_quote_image_info(ref_content: str) -> Dict[str, str]:
if not ref_content:
return {}
aeskey_match = re.search(r'aeskey="([^"]+)"', ref_content)
if not aeskey_match:
return {}
url_match = re.search(r'cdnmidimgurl="([^"]+)"', ref_content)
if not url_match:
url_match = re.search(r'cdnbigimgurl="([^"]+)"', ref_content)
if not url_match:
url_match = re.search(r'cdnthumburl="([^"]+)"', ref_content)
if not url_match:
return {}
return {
"aeskey": aeskey_match.group(1),
"url": url_match.group(1),
}
@staticmethod
def _build_image_data_url(base64_str: str) -> str:
raw_base64 = str(base64_str or "").strip()
if not raw_base64:
return ""
if "," in raw_base64 and raw_base64.startswith("data:"):
raw_base64 = raw_base64.split(",", 1)[1]
try:
image_bytes = base64.b64decode(raw_base64)
except Exception:
return ""
image_type = imghdr.what(None, h=image_bytes) or "jpeg"
return f"data:image/{image_type};base64,{raw_base64}"