diff --git a/plugins/dify/main.py b/plugins/dify/main.py
index f56d36f..dba55fb 100644
--- a/plugins/dify/main.py
+++ b/plugins/dify/main.py
@@ -3,6 +3,7 @@ import cv2
 import requests
 import json
 import time
+import binascii
 import re # 添加re模块导入
 import asyncio
 import base64
@@ -258,6 +259,15 @@ class DifyPlugin(MessagePluginInterface):
                             response: str, roomid: str) -> Tuple[bool, str]:
         """发送响应消息的辅助方法"""
         try:
+            # Detect "inline images" first:
+            # 1. a Dify workflow may return a data URL or bare base64 image content directly;
+            # 2. if such a response continued through text cleanup or md2image, it would be wrongly wrapped into a "screenshot of text";
+            # 3. so extract the image bytes up front and send them directly, bypassing the markdown-to-image path.
+            inline_image_bytes = self._extract_inline_image_bytes(response)
+            if inline_image_bytes:
+                await bot.send_image_message(target, inline_image_bytes)
+                return True, "发送成功"
+
             if response and not os.path.isfile(response):
                 response = remove_reasoning_content(response)
                 response = remove_trailing_content(response)
@@ -506,7 +516,9 @@ class DifyPlugin(MessagePluginInterface):
                     media_path = await downloader.download_media(outputs["result"])
                     answer = media_path
 
-            if answer and not os.path.isfile(answer):
+            # If the workflow directly returned an inline image (data URL / base64),
+            # do not apply text trimming here; it could easily corrupt the base64 content.
+            if answer and not os.path.isfile(answer) and not self._extract_inline_image_bytes(answer):
                 answer = remove_reasoning_content(answer)
                 answer = remove_trailing_content(answer)
                 answer = remove_grok_render_tags(answer)
@@ -543,6 +555,93 @@ class DifyPlugin(MessagePluginInterface):
             self.LOG.error(f"处理Dify响应时出错: {str(e)}")
             return False, f"处理响应时出错"
 
+    def _extract_inline_image_bytes(self, response_text: Any) -> bytes:
+        """Best-effort extraction of inline image bytes from content returned by Dify.
+
+        Supported input forms:
+        1. a data URL such as `data:image/png;base64,...`;
+        2. a bare base64-encoded image string;
+        3. a JSON string with a `b64_json` / `image_base64` / `base64` / `data` / `image_data` field; returns b"" if no image is found.
+        """
+        normalized_text = str(response_text or "").strip()
+        if not normalized_text:
+            return b""
+
+        # Handle the most standard form first: a data URL.
+        image_bytes, mime_type = UnifiedLLMClient.decode_data_url(normalized_text)
+        if image_bytes and str(mime_type or "").startswith("image/"):
+            return image_bytes
+
+        # Next, try parsing the string as JSON, for workflows that wrap the base64 in a structured field.
+        # Only well-known image field names are picked, so ordinary JSON text is not mistaken for an image.
+        json_candidate = self._extract_base64_from_json_text(normalized_text)
+        if json_candidate:
+            image_bytes = self._decode_base64_image_bytes(json_candidate)
+            if image_bytes:
+                return image_bytes
+
+        # Finally, try decoding the whole string as a bare base64 image.
+        return self._decode_base64_image_bytes(normalized_text)
+
+    def _extract_base64_from_json_text(self, response_text: str) -> str:
+        """Return the first non-empty string among known image keys of a top-level JSON object, else ""."""
+        try:
+            parsed = json.loads(response_text)
+        except Exception:
+            return ""
+
+        if isinstance(parsed, dict):
+            for key in ("b64_json", "image_base64", "base64", "data", "image_data"):
+                value = parsed.get(key)
+                if isinstance(value, str) and value.strip():
+                    return value.strip()
+        return ""
+
+    def _decode_base64_image_bytes(self, candidate: str) -> bytes:
+        """Decode a candidate string into image bytes; return them only when an image header is confirmed, else b""."""
+        normalized = str(candidate or "").strip()
+        if not normalized:
+            return b""
+
+        # Some gateways return fields that carry a data URL prefix; tolerate that here as well.
+        image_bytes, mime_type = UnifiedLLMClient.decode_data_url(normalized)
+        if image_bytes and str(mime_type or "").startswith("image/"):
+            return image_bytes
+
+        # Validate bare base64 conservatively:
+        # 1. ordinary text that is too short is not attempted;
+        # 2. the character set must conform to base64;
+        # 3. the result counts as an image only if the decoded bytes match a common image file header.
+        if len(normalized) < 16:
+            return b""
+        if not re.fullmatch(r"[A-Za-z0-9+/=\r\n]+", normalized):
+            return b""
+
+        compact_base64 = re.sub(r"\s+", "", normalized)
+        try:
+            decoded_bytes = base64.b64decode(compact_base64, validate=False)
+        except (ValueError, binascii.Error):
+            return b""
+
+        if self._is_supported_image_bytes(decoded_bytes):
+            return decoded_bytes
+        return b""
+
+    @staticmethod
+    def _is_supported_image_bytes(image_bytes: bytes) -> bool:
+        """Return whether the bytes (at least 12) begin with a PNG, JPEG, GIF, BMP, WebP or TIFF file header."""
+        if not image_bytes or len(image_bytes) < 12:
+            return False
+
+        return any((
+            image_bytes.startswith(b"\x89PNG\r\n\x1a\n"),
+            image_bytes.startswith(b"\xff\xd8\xff"),
+            image_bytes.startswith((b"GIF87a", b"GIF89a")),
+            image_bytes.startswith(b"BM"),
+            image_bytes.startswith(b"RIFF") and image_bytes[8:12] == b"WEBP",
+            image_bytes.startswith((b"II*\x00", b"MM\x00*")),
+        ))
+
     def _cleanup_expired_conversations(self) -> None:
         """清理过期的会话"""
         current_time = time.time()