Files
abot/plugins/dify/main.py

696 lines
28 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import os
import cv2
import requests
import json
import time
import re # 添加re模块导入
import asyncio
import base64
import html
import xml.etree.ElementTree as ET
from typing import Dict, Any, List, Optional, Tuple
from pathlib import Path
from base.plugin_common.message_plugin_interface import MessagePluginInterface
from base.plugin_common.plugin_interface import PluginStatus
from utils.decorator.plugin_decorators import plugin_stats_decorator
from utils.decorator.rate_limit_decorator import group_feature_rate_limit
from utils.markdown_to_image import convert_md_str_to_image
from utils.revoke.message_auto_revoke import MessageAutoRevoke
from utils.robot_cmd.robot_command import Feature, PermissionStatus, GroupBotManager
from utils.decorator.points_decorator import plugin_points_cost
from utils.media_downloader import MediaDownloader
from utils.string_utils import remove_reasoning_content, remove_trailing_content, remove_grok_render_tags
from wechat_ipad import WechatAPIClient
from wechat_ipad.models.message import MessageType
import aiohttp
# Common image and video file extensions (lower-case, including the dot).
# check_file_type() compares a lower-cased file suffix against these sets.
IMAGE_EXTENSIONS = {
    ".jpg",
    ".jpeg",
    ".png",
    ".gif",
    ".bmp",
    ".tiff",
    ".webp",
}
VIDEO_EXTENSIONS = {
    ".mp4",
    ".avi",
    ".mov",
    ".wmv",
    ".flv",
    ".mkv",
    ".webm",
}
class DifyPlugin(MessagePluginInterface):
    """Chat plugin that forwards user questions to a Dify workflow.

    A message is handled when it starts with one of the configured commands,
    or when the bot is @-mentioned inside a group chat. Quoted text replaces
    the question; a quoted image is downloaded, uploaded to Dify and attached
    to the request. The workflow answer is sent back as text, as a rendered
    image (long text) or as a media file (image/video results).
    """

    # Feature-permission constants used by register_feature().
    FEATURE_KEY = "AI_CAPABILITY"
    FEATURE_DESCRIPTION = "🤖 AI对话 [ai, 聊天, AI] 用法ai 如何写一个机器人?"

    # Total timeout (seconds) applied to every HTTP call made to Dify.
    _REQUEST_TIMEOUT_SECONDS = 40
    # Text replies longer than this many characters are rendered to an image.
    _TEXT_TO_IMAGE_THRESHOLD = 1500
    # Matches "@nickname" followed by whitespace (WeChat uses U+2005) OR by the
    # end of the text. The end-of-text alternative matters because message
    # content is stripped before matching, which removes a trailing U+2005.
    _MENTION_PATTERN = re.compile(r"@.*?(?:[\u2005\s]|$)")
    # Only plain hex digests are accepted as file-name suffixes.
    _SAFE_MD5_PATTERN = re.compile(r"[0-9a-fA-F]{1,64}")

    @property
    def name(self) -> str:
        """Display name of the plugin."""
        return "Dify聊天"

    @property
    def version(self) -> str:
        """Plugin version string."""
        return "1.0.0"

    @property
    def description(self) -> str:
        """Short description of the plugin."""
        return "提供基于Dify的AI聊天功能"

    @property
    def author(self) -> str:
        """Plugin author."""
        return "liu.wei"

    @property
    def command_prefix(self) -> Optional[str]:
        """Command prefix; empty because commands are matched directly."""
        return ""

    @property
    def commands(self) -> List[str]:
        """Trigger commands loaded from configuration by initialize()."""
        return self._commands

    @property
    def feature_key(self) -> Optional[str]:
        """Permission key of this feature."""
        return self.FEATURE_KEY

    @property
    def feature_description(self) -> Optional[str]:
        """Human-readable description of this feature."""
        return self.FEATURE_DESCRIPTION

    def __init__(self):
        """Create empty per-session state and register the feature permission."""
        super().__init__()
        # Conversation context, keyed by session id (group id or user wxid).
        self.conversations: Dict[str, List[Dict]] = {}
        # Accumulated token consumption, keyed by user wxid.
        self.token_usage: Dict[str, int] = {}
        # Maximum number of question/answer rounds kept per session.
        self.max_history_length = 10
        # Idle time in seconds after which a session is discarded (1 hour).
        self.conversation_timeout = 3600
        # Timestamp of the last request, keyed by session id.
        self.last_activity: Dict[str, float] = {}
        # Register the feature permission.
        self.feature = self.register_feature()

    def initialize(self, context: Dict[str, Any]) -> bool:
        """Load the "Dify" configuration section and keep context objects.

        Args:
            context: Plugin context; "event_system" and "gbm" are read from it.

        Returns:
            Always True.
        """
        self.LOG.debug(f"正在初始化 {self.name} 插件...")
        self.event_system = context.get("event_system")
        self.gbm = context.get("gbm")

        dify_config = self._config.get("Dify", {})
        self._commands = dify_config.get("commands", ["ai", "dify", "聊天", "AI"])
        self.command_format = dify_config.get("command-tip", "聊天 请求内容")
        self.enable = dify_config.get("enable", True)
        self.api_key = dify_config.get("api-key", "")
        self.base_url = dify_config.get("base-url", "")
        self.price = dify_config.get("price", 0)
        self.admin_ignore = dify_config.get("admin_ignore", False)
        self.whitelist_ignore = dify_config.get("whitelist_ignore", False)
        self.http_proxy = dify_config.get("http-proxy", "")
        self.LOG.debug(f"[{self.name}] 插件初始化完成,指令:{self._commands}")
        return True

    def start(self) -> bool:
        """Mark the plugin as running."""
        self.LOG.debug(f"[{self.name}] 插件已启动")
        self.status = PluginStatus.RUNNING
        return True

    def stop(self) -> bool:
        """Mark the plugin as stopped."""
        self.LOG.info(f"[{self.name}] 插件已停止")
        self.status = PluginStatus.STOPPED
        return True

    def can_process(self, message: Dict[str, Any]) -> bool:
        """Return True when the message should be handled by this plugin.

        A message qualifies when the plugin is enabled and either its first
        space-separated word is a configured command, or the bot was
        @-mentioned in a group chat (``is_at`` set and ``roomid`` non-empty).
        """
        if not self.enable:
            return False
        content = str(message.get("content", "")).strip()
        if content.split(" ")[0] in self._commands:
            return True
        # @-mentions are only handled inside group chats.
        return bool(message.get("is_at", False) and message.get("roomid", ""))

    @plugin_stats_decorator(plugin_name="Dify聊天")
    @group_feature_rate_limit(max_per_minute=5, feature_key=FEATURE_KEY)
    async def process_message(self, message: Dict[str, Any]) -> Tuple[bool, Optional[str]]:
        """Answer a chat message through Dify.

        Args:
            message: Message dict. Reads "content", "sender", "roomid",
                "is_at", "gbm", "bot", "revoke" and "full_wx_msg".

        Returns:
            ``(success, info)`` where ``info`` is a status or error text.
        """
        content = str(message.get("content", "")).strip()
        self.LOG.debug(f"插件执行: {self.name}{content}")
        sender = message.get("sender")
        roomid = message.get("roomid", "")
        gbm: GroupBotManager = message.get("gbm")
        bot: WechatAPIClient = message.get("bot")
        revoke: MessageAutoRevoke = message.get("revoke")
        # Replies go to the group in group chats, otherwise to the sender.
        target = roomid if roomid else sender
        # Only mention the sender inside groups (same rule for every send).
        at_target = sender if roomid else ""
        self.LOG.debug(
            f"消息上下文: sender={sender}, roomid={roomid}, target={target}, is_at={message.get('is_at', False)}")

        # Permission check (group chats only).
        if roomid and gbm.get_group_permission(target, self.feature) == PermissionStatus.DISABLED:
            return False, "没有权限"

        if message.get("is_at", False) and roomid:
            # @-mention: the question is whatever remains without the mentions.
            query = self._strip_mentions(content)
            if not query:
                return False, "没有提供问题内容"
        else:
            # Command message: "<command> <question>".
            parts = content.split(" ", 1)
            if len(parts) < 2 or not parts[1].strip():
                await bot.send_text_message(target, f"{self.command_format}", at_target)
                return False, "命令格式错误"
            query = parts[1].strip()
        self.LOG.debug(f"解析请求: query_len={len(query)} query_preview={query[:120]}")

        try:
            dify_files: List[Dict[str, Any]] = []
            quote_payload = self._parse_quote_payload(message.get("full_wx_msg"))
            if quote_payload:
                query, dify_files, quote_error = await self._resolve_quote(
                    bot, sender, quote_payload, query)
                if quote_error:
                    return False, quote_error

            success, response = await self._chat_with_dify(target, sender, query, dify_files)
            if not success:
                return False, response

            # Strip advertisement/markup noise (file paths are left untouched).
            response = self._clean_reply_text(response)
            self.LOG.debug(f"处理后的响应: {response}")
            if response:
                return await self._send_response(bot, target, sender, response, roomid)

            await self._notify_failure(bot, revoke, target, at_target, "❌未能获取到回复,请稍后再试")
            return False, "未获取到回复"
        except Exception as e:
            # Top-level boundary: log with traceback and tell the user.
            self.LOG.exception(f"处理Dify聊天请求出错: {e}")
            await self._notify_failure(bot, revoke, target, at_target, "❌DIFY响应失败")
            return False, f"处理出错: {e}"

    def _strip_mentions(self, text: str) -> str:
        """Remove every "@nickname" mention from *text* and trim the result."""
        return self._MENTION_PATTERN.sub("", text).strip()

    @staticmethod
    def _clean_reply_text(text: Optional[str]) -> Optional[str]:
        """Apply the text-cleaning pipeline unless *text* is empty or a file path."""
        if not text or os.path.isfile(text):
            return text
        text = remove_reasoning_content(text)
        text = remove_trailing_content(text)
        text = remove_grok_render_tags(text)
        # Collapse runs of three or more newlines into a single blank line.
        return re.sub(r'\n{3,}', '\n\n', text).strip()

    async def _notify_failure(self, bot: WechatAPIClient, revoke: Optional[MessageAutoRevoke],
                              target: str, at: str, text: str, revoke_after: int = 5) -> None:
        """Best-effort failure notice that is queued for automatic revocation.

        Never raises: a failure to deliver the notice is only logged so the
        caller can still return its own error result.
        """
        try:
            client_msg_id, create_time, new_msg_id = await bot.send_text_message(target, text, at)
            if revoke is not None:
                # NOTE(review): the last argument is presumably a delay; confirm the unit.
                revoke.add_message_to_revoke(target, client_msg_id, create_time, new_msg_id, revoke_after)
        except Exception as e:
            self.LOG.error(f"发送失败提示时出错: {e}")

    async def _resolve_quote(self, bot: WechatAPIClient, sender: str, quote_payload: Dict[str, Any],
                             query: str) -> Tuple[str, List[Dict[str, Any]], Optional[str]]:
        """Fold a quoted message into the request.

        Returns:
            ``(query, dify_files, error)``. ``error`` is None on success,
            otherwise the failure text to hand back to the caller.
        """
        ref_type = quote_payload.get("ref_type", 0)
        ref_content = quote_payload.get("ref_content", "").strip()
        title = quote_payload.get("title", "").strip()
        self.LOG.debug(
            f"检测到引用消息: type={ref_type}, title_preview={title[:80]}, ref_len={len(ref_content)}")

        if ref_type == MessageType.TEXT.value:
            # Quoted text becomes the question itself.
            if ref_content:
                self.LOG.debug("使用引用文本作为问题")
                query = ref_content
            return query, [], None

        if ref_type == MessageType.IMAGE.value:
            # The title carries what the user typed alongside the quote.
            title_query = self._strip_mentions(title) if title else ""
            if title_query:
                query = title_query
            self.LOG.debug("检测到引用图片,开始下载与上传")
            image_path = await self._download_quote_image(bot, ref_content)
            if not image_path:
                return query, [], "图片引用解析失败"
            upload_id = await self._upload_file_to_dify(image_path, sender)
            if not upload_id:
                return query, [], "图片上传失败"
            self.LOG.debug(f"图片上传完成: upload_id={upload_id}")
            attachment = {
                "type": "image",
                "transfer_method": "local_file",
                "upload_file_id": upload_id
            }
            return query, [attachment], None

        if ref_type == MessageType.VIDEO.value:
            return query, [], "暂不支持视频引用"
        if ref_type in (MessageType.EMOTICON.value, MessageType.EMOJI.value):
            return query, [], "暂不支持表情引用"
        # Any other quoted type is ignored and the original question is used.
        return query, [], None

    async def _send_response(self, bot: WechatAPIClient, target: str, sender: str,
                             response: str, roomid: str) -> Tuple[bool, str]:
        """Deliver *response* as a media file, a rendered image or plain text.

        Args:
            bot: Client used to send messages.
            target: Receiver (group id or user wxid).
            sender: wxid of the user who asked; mentioned only in group chats.
            response: Either reply text or the path of a local media file.
            roomid: Group id, empty for private chats.

        Returns:
            ``(success, info)``.
        """
        try:
            response = self._clean_reply_text(response)
            if os.path.isfile(response):
                # A local file path: send it as image or video.
                file_type = self.check_file_type(response)
                if file_type == 1:
                    await bot.send_image_message(target, Path(response))
                elif file_type == 2:
                    first_frame = await self._get_first_frame(
                        response, f"dify_frame_{int(time.time())}.jpg")
                    if not first_frame:
                        # No thumbnail could be produced for the video.
                        return False, "获取媒资失败"
                    await bot.send_video_message(target, Path(response), Path(first_frame))
                else:
                    return False, "获取媒资失败"
            elif len(response) > self._TEXT_TO_IMAGE_THRESHOLD:
                # Long text is rendered to an image instead of being sent as text.
                output_image = f"dify_output_{int(time.time())}.png"
                respath = await convert_md_str_to_image(response, output_image)
                await bot.send_image_message(target, Path(respath))
            else:
                await bot.send_text_message(target, response, sender if roomid else "")
            return True, "发送成功"
        except Exception as e:
            self.LOG.error(f"发送响应消息时出错: {e}")
            return False, f"发送响应失败: {e}"

    def _parse_quote_payload(self, full_msg: Any) -> Optional[Dict[str, Any]]:
        """Extract the quote ("refermsg") part of an appmsg of type 57.

        Returns:
            A dict with "title", "ref_type" and "ref_content" (both texts
            HTML-unescaped), or None when the message carries no usable quote.
        """
        if not full_msg or not getattr(full_msg, "content", None):
            return None
        xml_content = getattr(full_msg.content, "xml_content", "")
        if not xml_content:
            return None
        try:
            root = ET.fromstring(xml_content)
        except ET.ParseError:
            return None

        appmsg = root.find(".//appmsg")
        if appmsg is None or appmsg.findtext("type", "").strip() != "57":
            return None
        refer = appmsg.find("refermsg")
        if refer is None:
            return None

        title = appmsg.findtext("title", "") or ""
        try:
            ref_type = int(refer.findtext("type", "0") or 0)
        except ValueError:
            # A malformed type must not abort the whole request.
            self.LOG.debug("引用消息类型无法解析")
            return None
        ref_content = refer.findtext("content", "") or ""
        self.LOG.debug(
            f"引用解析成功: type={ref_type}, title_len={len(title)}, content_len={len(ref_content)}")
        return {
            "title": html.unescape(title),
            "ref_type": ref_type,
            "ref_content": html.unescape(ref_content)
        }

    def _extract_quote_image_info(self, ref_content: str) -> Optional[Dict[str, str]]:
        """Pull the AES key, CDN url and md5 out of a quoted image payload.

        The url is taken from the first attribute present among
        ``cdnmidimgurl``, ``cdnbigimgurl`` and ``cdnthumburl``.

        Returns:
            Dict with "aeskey", "url" and "md5" ("" when absent), or None when
            the AES key or every url attribute is missing.
        """
        if not ref_content:
            return None
        aeskey_match = re.search(r'aeskey="([^"]+)"', ref_content)
        if not aeskey_match:
            return None

        url_match = None
        for attribute in ("cdnmidimgurl", "cdnbigimgurl", "cdnthumburl"):
            url_match = re.search(attribute + r'="([^"]+)"', ref_content)
            if url_match:
                break
        if not url_match:
            return None

        md5_match = re.search(r'md5="([^"]+)"', ref_content)
        return {
            "aeskey": aeskey_match.group(1),
            "url": url_match.group(1),
            "md5": md5_match.group(1) if md5_match else ""
        }

    async def _download_quote_image(self, bot: WechatAPIClient, ref_content: str) -> Optional[str]:
        """Download a quoted image and store it in the project "temp" directory.

        Returns:
            Path of the saved file, or None when any step fails.
        """
        image_info = self._extract_quote_image_info(ref_content)
        if not image_info:
            return None
        self.LOG.debug(
            f"准备下载引用图片: url_len={len(image_info['url'])}, aeskey_prefix={image_info['aeskey'][:6]}")
        try:
            base64_str = await bot.download_image(
                aeskey=image_info["aeskey"],
                cdnmidimgurl=image_info["url"]
            )
        except Exception as e:
            self.LOG.error(f"下载引用图片失败: {e}")
            return None
        if not base64_str:
            return None
        try:
            image_data = base64.b64decode(base64_str)
        except (ValueError, TypeError) as e:
            self.LOG.error(f"解码引用图片失败: {e}")
            return None

        # The md5 comes from the incoming message; only use it in a file name
        # when it is a plain hex digest, so it cannot escape the temp directory.
        md5 = image_info.get("md5", "")
        suffix = md5 if self._SAFE_MD5_PATTERN.fullmatch(md5) else str(int(time.time()))
        temp_dir = Path(__file__).resolve().parents[2] / "temp"
        file_path = temp_dir / f"dify_quote_{suffix}.jpg"
        try:
            temp_dir.mkdir(parents=True, exist_ok=True)
            file_path.write_bytes(image_data)
        except OSError as e:
            self.LOG.error(f"保存引用图片失败: {e}")
            return None
        self.LOG.debug(f"引用图片已保存: {file_path}")
        return str(file_path)

    async def _upload_file_to_dify(self, file_path: str, user_id: str) -> Optional[str]:
        """Upload a local file to ``{base_url}/files/upload``.

        Args:
            file_path: Path of the file to upload.
            user_id: Value sent as the "user" form field.

        Returns:
            The uploaded file id, or None on any failure.
        """
        if not file_path or not os.path.isfile(file_path):
            return None
        self.LOG.debug(f"开始上传文件到Dify: {file_path}")
        url = f"{self.base_url}/files/upload"
        headers = {"Authorization": f"Bearer {self.api_key}"}
        proxy = self.http_proxy if self.http_proxy else None
        timeout = aiohttp.ClientTimeout(total=self._REQUEST_TIMEOUT_SECONDS)
        try:
            async with aiohttp.ClientSession() as session:
                with open(file_path, "rb") as f:
                    form = aiohttp.FormData()
                    form.add_field("file", f, filename=os.path.basename(file_path))
                    form.add_field("user", user_id)
                    # "async with" releases the connection on every exit path.
                    async with session.post(url, headers=headers, data=form,
                                            proxy=proxy, timeout=timeout) as response:
                        status = response.status
                        if status not in (200, 201):
                            error_text = await response.text()
                            self.LOG.error(f"Dify上传失败: {status} {error_text}")
                            return None
                        resp_data = await response.json()

            if not isinstance(resp_data, dict):
                return None
            self.LOG.debug(f"Dify上传成功: status={status}, keys={list(resp_data.keys())}")
            if resp_data.get("id"):
                return resp_data.get("id")
            # Some responses nest the id under "data".
            data = resp_data.get("data", {})
            if isinstance(data, dict):
                return data.get("id") or data.get("file_id")
            return None
        except Exception as e:
            self.LOG.error(f"上传文件到Dify失败: {e}")
            return None

    @staticmethod
    def _collect_output_strings(outputs: Dict[str, Any]) -> str:
        """Concatenate every non-blank string found in *outputs*.

        Looks at top-level values, values of nested dicts, items of nested
        lists and values of dicts inside those lists, in iteration order.
        """
        def non_blank(values):
            return [v for v in values if isinstance(v, str) and v.strip()]

        parts: List[str] = []
        for value in outputs.values():
            if isinstance(value, str):
                parts.extend(non_blank([value]))
            elif isinstance(value, dict):
                parts.extend(non_blank(value.values()))
            elif isinstance(value, list):
                for item in value:
                    if isinstance(item, str):
                        parts.extend(non_blank([item]))
                    elif isinstance(item, dict):
                        parts.extend(non_blank(item.values()))
        return "".join(parts)

    async def _chat_with_dify(self, session_id: str, user_id: str, query: str,
                              files: Optional[List[Dict[str, Any]]] = None) -> Tuple[bool, Optional[str]]:
        """Run the Dify workflow for *query* and return its answer.

        Args:
            session_id: Session id (group id in groups, user wxid otherwise).
            user_id: wxid of the asking user.
            query: The question text.
            files: Optional Dify file descriptors to attach.

        Returns:
            ``(True, answer)`` on success, where ``answer`` is reply text, the
            path of a downloaded media file, or "" when nothing was produced;
            ``(False, error_text)`` on failure.
        """
        # Drop expired sessions, then mark this one as active.
        self._cleanup_expired_conversations()
        self.last_activity[session_id] = time.time()
        history = self.conversations.setdefault(session_id, [])

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
            # NOTE(review): the request uses blocking mode and the reply is
            # parsed as JSON; confirm whether this Accept value is intended.
            "Accept": "text/event-stream"
        }

        # Only the user's earlier questions are passed as textual history.
        history_text = "".join(
            f"用户: {msg['content']}\n" for msg in history if msg["role"] == "user"
        ).strip()

        inputs_params = {
            "query": query,
            "conversation_id": session_id
        }
        if history_text:
            inputs_params["history"] = history_text
        if files is None:
            files = []
        self.LOG.debug(f"Dify请求准备: files={len(files)}")

        data = {
            "files": files,
            "user": user_id,
            "inputs": inputs_params,
            "response_mode": "blocking"
        }
        # The structured history is sent as well when there is any.
        if history:
            data["conversation_history"] = history

        proxy = self.http_proxy if self.http_proxy else None
        timeout = aiohttp.ClientTimeout(total=self._REQUEST_TIMEOUT_SECONDS)
        url = f"{self.base_url}/workflows/run"
        self.LOG.info(f"发送请求到Dify API: {url}")
        self.LOG.info(f"请求数据: {json.dumps(data, ensure_ascii=False)}")
        try:
            async with aiohttp.ClientSession() as session:
                # "async with" releases the connection on every exit path.
                async with session.post(url, headers=headers, json=data,
                                        proxy=proxy, timeout=timeout) as response:
                    if response.status != 200:
                        error_text = await response.text()
                        self.LOG.error(f"Dify API请求失败: {response.status} {error_text}")
                        return False, f"请求失败,状态码: {response.status}"
                    response_data = await response.json()
            self.LOG.info(f"收到Dify API响应: {json.dumps(response_data, ensure_ascii=False)}")

            answer = ""
            total_tokens = 0
            # "data" and "outputs" may be present but null.
            result_data = response_data.get("data") or {}
            outputs = result_data.get("outputs") or {}
            if outputs:
                if "result" in outputs and "type" in outputs:
                    # Media result: download it and answer with the local path.
                    if outputs["type"] in ("image", "video"):
                        media_path = await MediaDownloader().download_media(outputs["result"])
                        # A failed download yields no answer instead of None.
                        answer = media_path or ""
                elif "text" in outputs and isinstance(outputs["text"], str):
                    answer = outputs["text"]
                else:
                    # Legacy fallback: gather every string in the outputs.
                    answer = self._collect_output_strings(outputs)
                total_tokens = result_data.get("total_tokens") or 0

            answer = self._clean_reply_text(answer)

            # Record this round; re-fetch the list in case the session was
            # reset while the request was in flight.
            history = self.conversations.setdefault(session_id, [])
            history.append({"role": "user", "content": query})
            history.append({"role": "assistant", "content": answer})
            max_entries = self.max_history_length * 2
            if len(history) > max_entries:
                self.conversations[session_id] = history[-max_entries:]

            if total_tokens > 0:
                self.token_usage[user_id] = self.token_usage.get(user_id, 0) + total_tokens
                self.LOG.info(
                    f"用户 {user_id} 本次消耗 {total_tokens} tokens累计 {self.token_usage[user_id]} tokens")
            return True, answer
        except Exception as e:
            # exception() keeps the traceback; str(e) is empty for timeouts.
            self.LOG.exception(f"处理Dify响应时出错: {str(e)}")
            return False, "处理响应时出错"

    def _cleanup_expired_conversations(self) -> None:
        """Discard sessions idle for longer than ``conversation_timeout``."""
        now = time.time()
        expired_sessions = [
            session_id
            for session_id, last_time in self.last_activity.items()
            if now - last_time > self.conversation_timeout
        ]
        for session_id in expired_sessions:
            self.conversations.pop(session_id, None)
            del self.last_activity[session_id]
        if expired_sessions:
            self.LOG.info(f"已清理 {len(expired_sessions)} 个过期会话")

    def get_token_usage_report(self) -> str:
        """Return a per-user token usage report, highest consumer first."""
        if not self.token_usage:
            return "暂无token使用记录"
        ranked = sorted(self.token_usage.items(), key=lambda item: item[1], reverse=True)
        lines = [f"用户 {user_id}: {tokens} tokens\n" for user_id, tokens in ranked]
        total = sum(self.token_usage.values())
        return "Token使用情况统计\n" + "".join(lines) + f"\n总计: {total} tokens"

    def reset_conversation(self, session_id: str) -> bool:
        """Forget the context of one session.

        Returns:
            True when the session had stored history, False otherwise.
        """
        if session_id not in self.conversations:
            return False
        del self.conversations[session_id]
        self.last_activity.pop(session_id, None)
        return True

    def reset_all_conversations(self) -> None:
        """Forget the context of every session."""
        self.conversations.clear()
        self.last_activity.clear()

    def check_file_type(self, file_path) -> int:
        """Classify an existing file by its extension.

        Returns:
            1 for an image, 2 for a video, 0 for anything else or when
            *file_path* is not an existing file.
        """
        if not os.path.isfile(file_path):
            return 0
        file_extension = os.path.splitext(file_path)[1].lower()
        if file_extension in IMAGE_EXTENSIONS:
            return 1
        if file_extension in VIDEO_EXTENSIONS:
            return 2
        return 0

    async def _get_first_frame(self, video_path: str, output_path: str) -> Optional[str]:
        """Save the first frame of a video as an image.

        The OpenCV work runs in a worker thread via ``asyncio.to_thread``.

        Args:
            video_path: Path of the video file.
            output_path: Path where the frame image is written.

        Returns:
            Absolute path of the written image, or None on failure.
        """
        def extract_frame() -> Optional[str]:
            cap = cv2.VideoCapture(video_path)
            try:
                if not cap.isOpened():
                    self.LOG.error(f"无法打开视频: {video_path}")
                    return None
                ret, frame = cap.read()
                if not ret:
                    self.LOG.error("无法读取视频帧")
                    return None
                try:
                    written = cv2.imwrite(output_path, frame)
                except Exception as e:
                    self.LOG.error(f"保存首帧图片失败: {e}")
                    return None
                if not written:
                    # imwrite reports most failures through its return value.
                    self.LOG.error(f"保存首帧图片失败: {output_path}")
                    return None
                self.LOG.info(f"首帧已保存为: {output_path}")
                return os.path.abspath(output_path)
            finally:
                # Release the capture handle on every exit path.
                cap.release()

        try:
            self.LOG.info(f"开始提取视频首帧: {video_path}")
            return await asyncio.to_thread(extract_frame)
        except Exception as e:
            self.LOG.error(f"提取视频首帧时出错: {e}")
            return None