add multimodal quote handling for xiaoniu bot
This commit is contained in:
@@ -21,6 +21,7 @@ class ContextBuilder:
|
||||
flow_state: str,
|
||||
reply_mode: str,
|
||||
vector_memories: List[Dict],
|
||||
quote_context: Dict | None = None,
|
||||
) -> Dict:
|
||||
recent_lines = []
|
||||
for item in recent_messages[-self.recent_context_size:]:
|
||||
@@ -43,6 +44,7 @@ class ContextBuilder:
|
||||
"memory_prompt": self._build_member_memory_prompt(member_context),
|
||||
"vector_memory_prompt": self._build_vector_memory_prompt(vector_memories),
|
||||
"group_profile_prompt": self._build_group_profile_prompt(group_profile or {}),
|
||||
"quote_prompt": self._build_quote_prompt(quote_context or {}),
|
||||
"current_message": f"{sender_name}: {content}",
|
||||
}
|
||||
|
||||
@@ -116,3 +118,21 @@ class ContextBuilder:
|
||||
str(style_profile.get("expressiveness_style", "") or "").strip(),
|
||||
]
|
||||
).strip(" /")
|
||||
|
||||
@staticmethod
|
||||
def _build_quote_prompt(quote_context: Dict) -> str:
|
||||
if not quote_context:
|
||||
return ""
|
||||
quote_type = quote_context.get("quote_type_label", "引用消息")
|
||||
quote_sender = quote_context.get("quote_sender_name", "") or "未知成员"
|
||||
quote_body = quote_context.get("quote_body", "") or ""
|
||||
title = quote_context.get("title", "") or ""
|
||||
lines = [
|
||||
f"用户这次是在引用消息后发言。",
|
||||
f"引用类型:{quote_type}",
|
||||
f"被引用发送者:{quote_sender}",
|
||||
f"图片附件:已附带原图" if quote_context.get("has_image_attachment") else "",
|
||||
f"引用标题:{title}" if title else "",
|
||||
f"被引用内容:{quote_body}" if quote_body else "",
|
||||
]
|
||||
return "\n".join([line for line in lines if line])
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Dict, List
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import requests
|
||||
|
||||
@@ -20,23 +20,35 @@ class LLMClient:
|
||||
self.stream = bool(self.config.get("stream", True))
|
||||
self.last_error = ""
|
||||
|
||||
def chat(
    self,
    system_prompt: str,
    user_prompt: str,
    user_id: str,
    image_urls: Optional[List[str]] = None,
) -> str:
    """Send one chat request through the configured provider.

    Args:
        system_prompt: Text for the system message.
        user_prompt: Text for the user message.
        user_id: Identifier forwarded to the provider call.
        image_urls: Optional image URLs to attach to the user message;
            ``None`` is forwarded as an empty list.

    Returns:
        Whatever the provider-specific call returns, or ``""`` when the
        request cannot be dispatched. In that case ``self.last_error``
        holds the reason (``"empty_base_url"`` or
        ``"unsupported_provider:<name>"``).
    """
    # Reset the error marker so a stale failure is never reported for this call.
    self.last_error = ""
    if not self.base_url:
        self.last_error = "empty_base_url"
        return ""
    if self.provider == "openai_compatible":
        return self._chat_openai_compatible(system_prompt, user_prompt, user_id, image_urls or [])
    self.last_error = f"unsupported_provider:{self.provider}"
    return ""
|
||||
|
||||
def _chat_openai_compatible(self, system_prompt: str, user_prompt: str, user_id: str) -> str:
|
||||
def _chat_openai_compatible(
|
||||
self,
|
||||
system_prompt: str,
|
||||
user_prompt: str,
|
||||
user_id: str,
|
||||
image_urls: List[str],
|
||||
) -> str:
|
||||
if not self.model:
|
||||
return ""
|
||||
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"messages": self._build_messages(system_prompt, user_prompt),
|
||||
"messages": self._build_messages(system_prompt, user_prompt, image_urls),
|
||||
"temperature": self.temperature,
|
||||
"max_tokens": self.max_tokens,
|
||||
"user": user_id,
|
||||
@@ -107,10 +119,19 @@ class LLMClient:
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _build_messages(system_prompt: str, user_prompt: str) -> List[Dict[str, str]]:
|
||||
def _build_messages(system_prompt: str, user_prompt: str, image_urls: List[str]) -> List[Dict]:
|
||||
user_content: str | List[Dict[str, object]]
|
||||
if image_urls:
|
||||
content_parts: List[Dict[str, object]] = [{"type": "text", "text": user_prompt}]
|
||||
for image_url in image_urls:
|
||||
if image_url:
|
||||
content_parts.append({"type": "image_url", "image_url": {"url": image_url}})
|
||||
user_content = content_parts
|
||||
else:
|
||||
user_content = user_prompt
|
||||
return [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
{"role": "user", "content": user_content},
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import html
|
||||
import imghdr
|
||||
import re
|
||||
import time
|
||||
import xml.etree.ElementTree as ET
|
||||
@@ -137,6 +140,7 @@ class AIAutoResponsePlugin(MessagePluginInterface):
|
||||
sender = message.get("sender", "")
|
||||
bot: WechatAPIClient = message.get("bot")
|
||||
content = self._normalize_content(message)
|
||||
quote_context = self._parse_quote_context(message.get("full_wx_msg"), room_id)
|
||||
sender_name = self._get_sender_name(room_id, sender)
|
||||
group_name = self._get_group_name(room_id, message)
|
||||
group_memory_profile = self.group_memory_service.build_group_memory_profile(room_id, group_name)
|
||||
@@ -153,6 +157,7 @@ class AIAutoResponsePlugin(MessagePluginInterface):
|
||||
sharpness_style=group_profile.get("sharpness_style", ""),
|
||||
is_at=message.get("is_at", False),
|
||||
content_preview=self._preview(content),
|
||||
quote_type=quote_context.get("quote_type_label", ""),
|
||||
msg_type=str(message.get("type")),
|
||||
)
|
||||
|
||||
@@ -236,6 +241,7 @@ class AIAutoResponsePlugin(MessagePluginInterface):
|
||||
vector_memories = []
|
||||
if self.vector_memory.should_search(reply_mode, trigger.trigger_type, memory_hints.get("returning_member_state", "")):
|
||||
vector_memories = self.vector_memory.search(content, room_id, sender)
|
||||
image_urls = await self._prepare_quote_image_inputs(bot, quote_context)
|
||||
self._log_event(
|
||||
"context",
|
||||
room_id=room_id,
|
||||
@@ -246,6 +252,7 @@ class AIAutoResponsePlugin(MessagePluginInterface):
|
||||
reply_mode=reply_mode,
|
||||
recent_message_count=len(recent_messages),
|
||||
vector_hit_count=len(vector_memories),
|
||||
image_input_count=len(image_urls),
|
||||
)
|
||||
|
||||
context = self.context_builder.build(
|
||||
@@ -260,11 +267,19 @@ class AIAutoResponsePlugin(MessagePluginInterface):
|
||||
flow_state=flow_state.state,
|
||||
reply_mode=reply_mode,
|
||||
vector_memories=vector_memories,
|
||||
quote_context=quote_context | {"has_image_attachment": bool(image_urls)},
|
||||
)
|
||||
|
||||
system_prompt = self.persona_engine.build_system_prompt(group_profile)
|
||||
user_prompt = self._build_user_prompt(context, memory_hints)
|
||||
response = self._sanitize_response(self.llm_client.chat(system_prompt, user_prompt, user_id=f"{room_id}:{sender}"))
|
||||
response = self._sanitize_response(
|
||||
self.llm_client.chat(
|
||||
system_prompt,
|
||||
user_prompt,
|
||||
user_id=f"{room_id}:{sender}",
|
||||
image_urls=image_urls,
|
||||
)
|
||||
)
|
||||
if not response:
|
||||
self._log_event(
|
||||
"model_empty",
|
||||
@@ -361,6 +376,7 @@ class AIAutoResponsePlugin(MessagePluginInterface):
|
||||
return (
|
||||
f"当前群聊消息:\n{recent_text}\n\n"
|
||||
f"当前发言:{context.get('current_message', '')}\n"
|
||||
f"引用补充:\n{context.get('quote_prompt', '') or '无'}\n"
|
||||
f"触发类型:{context.get('trigger_type', 'none')}\n"
|
||||
f"回复模式:{context.get('reply_mode', 'social_short')}\n"
|
||||
f"当前心流状态:{context.get('flow_state', 'idle')}\n"
|
||||
@@ -561,6 +577,7 @@ class AIAutoResponsePlugin(MessagePluginInterface):
|
||||
f"[XIAONIU] RECV room={room} user={sender_name}/{sender} "
|
||||
f"at={self._yn(data.get('is_at'))} "
|
||||
f"style={self._style_mark(data.get('humor_style', ''), data.get('sharpness_style', ''))} "
|
||||
f"quote={data.get('quote_type', '-') or '-'} "
|
||||
f"msg={data.get('content_preview', '')}"
|
||||
).strip()
|
||||
|
||||
@@ -597,7 +614,8 @@ class AIAutoResponsePlugin(MessagePluginInterface):
|
||||
f"mode={data.get('reply_mode', '')} "
|
||||
f"acc={data.get('acceptance_state', '-') or '-'} "
|
||||
f"recent={data.get('recent_message_count', 0)} "
|
||||
f"vector={data.get('vector_hit_count', 0)}"
|
||||
f"vector={data.get('vector_hit_count', 0)} "
|
||||
f"img={data.get('image_input_count', 0)}"
|
||||
).strip()
|
||||
|
||||
if event == "model_empty":
|
||||
@@ -644,3 +662,119 @@ class AIAutoResponsePlugin(MessagePluginInterface):
|
||||
humor = "humor" if "中等" in str(humor_style) or "偏上" in str(humor_style) else "plain"
|
||||
sharp = "sharp" if "毒舌" in str(sharpness_style) or "嘴欠" in str(sharpness_style) else "soft"
|
||||
return f"{humor}/{sharp}"
|
||||
|
||||
def _parse_quote_context(self, full_msg: Any, room_id: str) -> Dict[str, str]:
|
||||
if not full_msg or not getattr(full_msg, "content", None):
|
||||
return {}
|
||||
xml_content = getattr(full_msg.content, "xml_content", "") or ""
|
||||
if not xml_content:
|
||||
return {}
|
||||
try:
|
||||
root = ET.fromstring(xml_content)
|
||||
except ET.ParseError:
|
||||
return {}
|
||||
|
||||
appmsg = root.find(".//appmsg")
|
||||
if appmsg is None or appmsg.findtext("type", "").strip() != "57":
|
||||
return {}
|
||||
|
||||
refer = appmsg.find("refermsg")
|
||||
if refer is None:
|
||||
return {}
|
||||
|
||||
title = html.unescape(appmsg.findtext("title", "") or "").strip()
|
||||
quote_sender_name = html.unescape(refer.findtext("displayname", "") or "").strip()
|
||||
if not quote_sender_name:
|
||||
quote_sender = html.unescape(refer.findtext("chatusr", "") or "").strip()
|
||||
quote_sender_name = self._get_sender_name(room_id, quote_sender) if quote_sender else "未知成员"
|
||||
ref_type = int(refer.findtext("type", "0") or 0)
|
||||
ref_content = html.unescape(refer.findtext("content", "") or "").strip()
|
||||
quote_type_label = self._quote_type_label(ref_type)
|
||||
quote_body = self._build_quote_body(ref_type, ref_content, title)
|
||||
return {
|
||||
"title": title,
|
||||
"quote_sender_name": quote_sender_name,
|
||||
"quote_type_label": quote_type_label,
|
||||
"quote_body": quote_body,
|
||||
"raw_ref_content": ref_content,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _quote_type_label(ref_type: int) -> str:
    """Return the display label for a quoted message's type code.

    Type codes without a dedicated label get a generic label that embeds
    the raw code.
    """
    known_labels = dict(
        (
            (MessageType.TEXT.value, "引用文本"),
            (MessageType.IMAGE.value, "引用图片"),
            (MessageType.VIDEO.value, "引用视频"),
            (MessageType.APP.value, "引用应用消息"),
            (MessageType.EMOTICON.value, "引用表情"),
        )
    )
    if ref_type in known_labels:
        return known_labels[ref_type]
    return f"引用消息[{ref_type}]"
|
||||
|
||||
@staticmethod
|
||||
def _build_quote_body(ref_type: int, ref_content: str, title: str) -> str:
    """Summarise the quoted message for inclusion in the prompt.

    Text quotes yield the quoted text. Image quotes yield a short
    description, since the raw content is not human-readable text. Any other
    type prefers the title and falls back to the quoted content. Textual
    summaries are cut to 220 characters and then stripped.
    """
    if ref_type == MessageType.TEXT.value:
        summary_source = ref_content
    elif ref_type == MessageType.IMAGE.value:
        image_note = "被引用的是一张图片"
        parts = [f"当前追问文案:{title}"] if title else []
        if ref_content:
            parts.append(image_note)
        # With neither a title nor content, still say that it was an image.
        return ";".join(parts) if parts else image_note
    else:
        summary_source = title or ref_content
    return summary_source[:220].strip()
|
||||
|
||||
async def _prepare_quote_image_inputs(self, bot: WechatAPIClient, quote_context: Dict[str, str]) -> List[str]:
    """Download the quoted image, if any, and return it as model image inputs.

    Args:
        bot: Client whose ``download_image`` coroutine fetches the image
            (its result is presumably a base64 string; confirm against the
            client API).
        quote_context: Parsed quote info; only contexts labelled as an image
            quote are processed, using their ``raw_ref_content``.

    Returns:
        A one-element list holding a data URL, or ``[]`` when the message
        does not quote an image or the image cannot be obtained. Failures
        are reported through ``self._log_event("quote_image_fail", ...)``.
    """
    is_image_quote = bool(quote_context) and quote_context.get("quote_type_label") == "引用图片"
    if not is_image_quote:
        return []

    image_info = self._extract_quote_image_info(quote_context.get("raw_ref_content", "") or "")
    if not image_info:
        return []

    try:
        encoded_image = await bot.download_image(
            aeskey=image_info["aeskey"],
            cdnmidimgurl=image_info["url"],
        )
    except Exception as exc:
        # Best effort: a failed download degrades to a text-only request.
        self._log_event("quote_image_fail", reason=f"download:{exc}")
        return []

    data_url = self._build_image_data_url(encoded_image)
    if data_url:
        return [data_url]
    self._log_event("quote_image_fail", reason="invalid_base64")
    return []
|
||||
|
||||
@staticmethod
|
||||
def _extract_quote_image_info(ref_content: str) -> Dict[str, str]:
|
||||
if not ref_content:
|
||||
return {}
|
||||
aeskey_match = re.search(r'aeskey="([^"]+)"', ref_content)
|
||||
if not aeskey_match:
|
||||
return {}
|
||||
url_match = re.search(r'cdnmidimgurl="([^"]+)"', ref_content)
|
||||
if not url_match:
|
||||
url_match = re.search(r'cdnbigimgurl="([^"]+)"', ref_content)
|
||||
if not url_match:
|
||||
url_match = re.search(r'cdnthumburl="([^"]+)"', ref_content)
|
||||
if not url_match:
|
||||
return {}
|
||||
return {
|
||||
"aeskey": aeskey_match.group(1),
|
||||
"url": url_match.group(1),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _build_image_data_url(base64_str: str) -> str:
|
||||
raw_base64 = str(base64_str or "").strip()
|
||||
if not raw_base64:
|
||||
return ""
|
||||
if "," in raw_base64 and raw_base64.startswith("data:"):
|
||||
raw_base64 = raw_base64.split(",", 1)[1]
|
||||
try:
|
||||
image_bytes = base64.b64decode(raw_base64)
|
||||
except Exception:
|
||||
return ""
|
||||
image_type = imghdr.what(None, h=image_bytes) or "jpeg"
|
||||
return f"data:image/{image_type};base64,{raw_base64}"
|
||||
|
||||
Reference in New Issue
Block a user