Files
abot/plugins/dify/main.py

447 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import asyncio
import json
import os
import re
import threading
import time
from pathlib import Path
from typing import Dict, Any, List, Optional, Tuple

import requests

from base.plugin_common.message_plugin_interface import MessagePluginInterface
from base.plugin_common.plugin_interface import PluginStatus
from utils.decorator.plugin_decorators import plugin_stats_decorator
from utils.decorator.points_decorator import plugin_points_cost
from utils.markdown_to_image import convert_md_str_to_image
from utils.media_downloader import MediaDownloader
from utils.revoke.message_auto_revoke import MessageAutoRevoke
from utils.robot_cmd.robot_command import Feature, PermissionStatus, GroupBotManager
from utils.string_utils import remove_trailing_content
from wechat_ipad import WechatAPIClient


class DifyPlugin(MessagePluginInterface):
    """Dify AI chat plugin.

    Sends user questions to a Dify workflow (``POST {base-url}/workflows/run``)
    and delivers the answer back to the chat. A message is handled when its
    first word is one of the configured commands (``<command> <query>``), or
    when it @-mentions the bot inside a group chat. A short in-memory history
    is kept per session (group id for groups, sender wxid for private chats)
    and sent along with every request.
    """

    # Text answers longer than this many characters are rendered to an image.
    long_text_threshold = 200

    # One "@name" mention terminated by U+2005 or any other whitespace.
    # NOTE(review): U+2005 is presumably the separator WeChat appends after a
    # mention -- confirm. (The previous pattern put "|" inside the character
    # class, where it matched a literal pipe.)
    _AT_MENTION_RE = re.compile(r"@.*?[\u2005\s]")

    @property
    def name(self) -> str:
        return "Dify聊天"

    @property
    def version(self) -> str:
        return "1.0.0"

    @property
    def description(self) -> str:
        return "提供基于Dify的AI聊天功能"

    @property
    def author(self) -> str:
        return "Trae AI"

    @property
    def command_prefix(self) -> Optional[str]:
        # No prefix: the command word itself is matched directly.
        return ""

    @property
    def commands(self) -> List[str]:
        return self._commands

    def __init__(self):
        super().__init__()
        # Per-session history: {group_id or wxid: [{"role": ..., "content": ...}, ...]}
        self.conversations: Dict[str, List[Dict]] = {}
        # Accumulated token usage: {wxid: total_tokens}
        self.token_usage: Dict[str, int] = {}
        # Maximum number of question/answer rounds kept per session.
        self.max_history_length = 10
        # Idle time (seconds) after which a session's history is dropped.
        self.conversation_timeout = 3600  # 1 hour
        # Last time each session was used: {session_id: unix timestamp}
        self.last_activity: Dict[str, float] = {}
        # Dify requests run in executor threads (see _answer_query), so the
        # dicts above are only touched while holding this lock. Re-entrant
        # because cleanup is called while the lock is already held.
        self._state_lock = threading.RLock()
        # Defaults until initialize() loads the real configuration, so that
        # `commands` / `can_process` do not raise AttributeError before then.
        # The plugin stays disabled until configured.
        self._commands: List[str] = ["ai", "dify", "聊天", "AI"]
        self.command_format = "聊天 请求内容"
        self.enable = False
        self.api_key = ""
        self.base_url = ""
        self.price = 0
        self.admin_ignore = False
        self.whitelist_ignore = False
        self.http_proxy = ""
        self.event_system = None
        self.gbm = None

    def initialize(self, context: Dict[str, Any]) -> bool:
        """Store context objects and read the ``Dify`` section of the plugin config.

        Args:
            context: Plugin context; ``event_system`` and ``gbm`` are kept.

        Returns:
            Always ``True``.
        """
        self.LOG.info(f"正在初始化 {self.name} 插件...")
        self.event_system = context.get("event_system")
        self.gbm = context.get("gbm")
        dify_config = self._config.get("Dify", {})
        self._commands = dify_config.get("commands", ["ai", "dify", "聊天", "AI"])
        self.command_format = dify_config.get("command-tip", "聊天 请求内容")
        self.enable = dify_config.get("enable", True)
        self.api_key = dify_config.get("api-key", "")
        # Drop a trailing "/" so f"{base_url}/workflows/run" never contains "//".
        self.base_url = (dify_config.get("base-url") or "").rstrip("/")
        self.price = dify_config.get("price", 0)
        self.admin_ignore = dify_config.get("admin_ignore", False)
        self.whitelist_ignore = dify_config.get("whitelist_ignore", False)
        self.http_proxy = dify_config.get("http-proxy", "")
        self.LOG.info(f"[{self.name}] 插件初始化完成,指令:{self._commands}")
        return True

    def start(self) -> bool:
        """Mark the plugin as running."""
        self.LOG.info(f"[{self.name}] 插件已启动")
        self.status = PluginStatus.RUNNING
        return True

    def stop(self) -> bool:
        """Mark the plugin as stopped."""
        self.LOG.info(f"[{self.name}] 插件已停止")
        self.status = PluginStatus.STOPPED
        return True

    def can_process(self, message: Dict[str, Any]) -> bool:
        """Return True for command messages, or for group messages that @ the bot."""
        if not self.enable:
            return False
        content = str(message.get("content", "")).strip()
        # Split on any whitespace (incl. full-width spaces and newlines).
        parts = content.split(None, 1)
        if parts and parts[0] in self._commands:
            return True
        # Only @-mentions inside group chats (roomid set) are handled.
        return bool(message.get("is_at", False) and message.get("roomid", ""))

    @plugin_stats_decorator(plugin_name="Dify聊天")
    @plugin_points_cost(2, "AI聊天消耗积分", Feature.AI_CAPABILITY)
    async def process_message(self, message: Dict[str, Any]) -> Tuple[bool, Optional[str]]:
        """Answer an @-mention or a ``<command> <query>`` message via Dify.

        Args:
            message: Incoming message; the keys ``content``, ``sender``,
                ``roomid``, ``is_at``, ``gbm``, ``bot`` and ``revoke`` are read.

        Returns:
            ``(success, status_text)``.
        """
        content = str(message.get("content", "")).strip()
        self.LOG.debug(f"插件执行: {self.name}{content}")
        sender = message.get("sender")
        roomid = message.get("roomid", "")
        gbm: GroupBotManager = message.get("gbm")
        bot: WechatAPIClient = message.get("bot")
        revoke: MessageAutoRevoke = message.get("revoke")
        # Groups: reply in the room and @ the sender. Private chats: reply to the sender.
        target = roomid if roomid else sender
        at_user = sender if roomid else ""

        # The AI feature can be disabled per group; checked before replying at all.
        if roomid and gbm.get_group_permission(roomid, Feature.AI_CAPABILITY) == PermissionStatus.DISABLED:
            return False, "没有权限"

        if message.get("is_at", False) and roomid:
            # Remove the @mention(s) and keep only the question text.
            query = self._AT_MENTION_RE.sub("", content).strip()
            if not query:
                await bot.send_at_message(roomid, "请在@我的同时提供问题内容", [sender])
                return False, "没有提供问题内容"
        else:
            # Same whitespace splitting as can_process.
            parts = content.split(None, 1)
            if len(parts) < 2 or not parts[1].strip():
                await bot.send_text_message(target, f"{self.command_format}", sender)
                return False, "命令格式错误"
            query = parts[1].strip()
            client_msg_id, create_time, new_msg_id = await bot.send_text_message(
                target, "⏳AI 正在加油,请稍候… 😊", at_user)
            revoke.add_message_to_revoke(target, client_msg_id, create_time, new_msg_id, 3)

        return await self._answer_query(bot, revoke, target, sender, at_user, query)

    async def _answer_query(
        self,
        bot: WechatAPIClient,
        revoke: MessageAutoRevoke,
        target: str,
        sender: str,
        at_user: str,
        query: str,
    ) -> Tuple[bool, str]:
        """Query Dify for ``query`` (session ``target``) and deliver the answer."""
        try:
            # requests.post blocks (timeout=40); run it in the default executor
            # so the event loop keeps serving other messages meanwhile.
            loop = asyncio.get_running_loop()
            answer, is_media = await loop.run_in_executor(
                None, self._chat_with_dify_detailed, target, sender, query)
            if answer and not is_media:
                # Strip trailing advertising from text answers only; media
                # answers are local file paths.
                answer = remove_trailing_content(answer)
            # Only a file produced by the media downloader is ever sent as a
            # file -- never an arbitrary path that happens to appear in AI text.
            if not answer or (is_media and not os.path.isfile(answer)):
                await self._notify_failure(bot, revoke, target, at_user)
                return False, "未获取到回复"
            await self._deliver_answer(bot, target, sender, answer, is_media)
            return True, "发送成功"
        except Exception as e:
            self.LOG.exception(f"处理Dify聊天请求出错: {e}")
            await self._notify_failure(bot, revoke, target, at_user)
            return False, f"处理出错: {e}"

    async def _deliver_answer(
        self, bot: WechatAPIClient, target: str, sender: str, answer: str, is_media: bool
    ) -> None:
        """Send a media file, a long text rendered as an image, or a short @-text."""
        if is_media:
            # NOTE(review): downloaded videos also go through send_image_message,
            # as before -- confirm the client accepts video files here.
            await bot.send_image_message(target, Path(answer))
        elif len(answer) > self.long_text_threshold:
            # NOTE(review): the output filename is fixed, so concurrent requests
            # may overwrite each other's rendered image.
            image_path = await convert_md_str_to_image(answer, "dify_output.jpg")
            await bot.send_image_message(target, Path(image_path))
        else:
            await bot.send_at_message(target, answer, [sender])

    async def _notify_failure(
        self, bot: WechatAPIClient, revoke: MessageAutoRevoke, target: str, at_user: str
    ) -> None:
        """Post the generic "no answer" notice and schedule it for auto-revoke (delay 5)."""
        client_msg_id, create_time, new_msg_id = await bot.send_text_message(
            target, "❌未能获取到回复,请稍后再试", at_user)
        revoke.add_message_to_revoke(target, client_msg_id, create_time, new_msg_id, 5)

    def _chat_with_dify(self, session_id: str, user_id: str, query: str) -> Optional[str]:
        """Return only the answer of ``_chat_with_dify_detailed`` (``None`` on failure).

        Args:
            session_id: Session key (group id for groups, wxid for private chats).
            user_id: Sender wxid, forwarded to Dify as ``user``.
            query: The user's question.
        """
        answer, _ = self._chat_with_dify_detailed(session_id, user_id, query)
        return answer

    def _chat_with_dify_detailed(
        self, session_id: str, user_id: str, query: str
    ) -> Tuple[Optional[str], bool]:
        """Run the Dify workflow for ``query`` and update history / token stats.

        Runs in a worker thread; shared state is only touched under ``_state_lock``.

        Args:
            session_id: Session key (group id for groups, wxid for private chats).
            user_id: Sender wxid, forwarded to Dify as ``user``.
            query: The user's question.

        Returns:
            ``(answer, is_media)``. ``answer`` is the reply text, or the local
            path of a downloaded media file when ``is_media`` is True. It is
            ``None`` when the request or response handling failed, and ``""``
            when the workflow produced no usable output.
        """
        with self._state_lock:
            self._cleanup_expired_conversations()
            self.last_activity[session_id] = time.time()
            # Copy so the payload is not affected by concurrent updates.
            history = list(self.conversations.setdefault(session_id, []))

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
            # NOTE(review): an event-stream Accept header is sent although
            # response_mode is "blocking"; kept unchanged.
            "Accept": "text/event-stream",
        }
        # Flatten history into "用户: ..." / "助手: ..." lines for the workflow input.
        history_text = "\n".join(
            f"{'用户' if msg['role'] == 'user' else '助手'}: {msg['content']}"
            for msg in history
        ).strip()
        inputs_params = {
            "query": query,
            "conversation_id": session_id,
        }
        if history_text:
            inputs_params["history"] = history_text
        data = {
            "sys.files": [],
            "user": user_id,
            "inputs": inputs_params,
            "response_mode": "blocking",
        }
        if history:
            data["conversation_history"] = history

        proxies = None
        if self.http_proxy:
            proxies = {"http": self.http_proxy, "https": self.http_proxy}

        url = f"{self.base_url}/workflows/run"
        self.LOG.info(f"发送请求到Dify API: {url}")
        self.LOG.info(f"请求数据: {json.dumps(data, ensure_ascii=False)}")
        try:
            response = requests.post(url, headers=headers, json=data, proxies=proxies, timeout=40)
            if response.status_code != 200:
                self.LOG.error(f"Dify API请求失败: {response.status_code} {response.text}")
                return None, False
            response_data = response.json()
            self.LOG.info(f"收到Dify API响应: {json.dumps(response_data, ensure_ascii=False)}")
            # "data"/"outputs"/"total_tokens" may be missing or null.
            result = response_data.get("data") or {}
            answer, is_media, history_entry = self._extract_answer(result.get("outputs") or {})
            total_tokens = int(result.get("total_tokens") or 0)
        except Exception as e:
            self.LOG.exception(f"处理Dify响应时出错: {str(e)}")
            return None, False

        cumulative = 0
        with self._state_lock:
            if answer:
                session = self.conversations.setdefault(session_id, [])
                session.append({"role": "user", "content": query})
                session.append({"role": "assistant", "content": history_entry})
                # Keep only the latest max_history_length rounds (2 messages each).
                limit = self.max_history_length * 2
                if len(session) > limit:
                    session[:] = session[-limit:] if limit > 0 else []
            if total_tokens > 0:
                cumulative = self.token_usage.get(user_id, 0) + total_tokens
                self.token_usage[user_id] = cumulative
        if total_tokens > 0:
            self.LOG.info(f"用户 {user_id} 本次消耗 {total_tokens} tokens累计 {cumulative} tokens")
        return answer, is_media

    def _extract_answer(self, outputs: Dict[str, Any]) -> Tuple[str, bool, str]:
        """Turn workflow ``outputs`` into ``(answer, is_media, history_entry)``.

        * ``{"type": "image"|"video", "result": url}``: the URL is downloaded and
          ``answer`` is the local path; the URL itself goes into history.
        * ``{"type": other, "result": str}``: the result is used as text.
        * ``{"text": str}``: that text.
        * Otherwise, all non-blank strings found in ``outputs`` are concatenated.
        """
        if not isinstance(outputs, dict) or not outputs:
            return "", False, ""
        if "result" in outputs and "type" in outputs:
            result = outputs["result"]
            if outputs["type"] in ("image", "video"):
                path = MediaDownloader().download_media(result)
                return (str(path) if path else ""), True, str(result)
            text = result if isinstance(result, str) else ""
            return text, False, text
        if isinstance(outputs.get("text"), str):
            text = outputs["text"]
            return text, False, text
        text = self._collect_text(outputs)
        return text, False, text

    @staticmethod
    def _collect_text(outputs: Dict[str, Any]) -> str:
        """Concatenate non-blank strings from values, nested dicts and lists (legacy format)."""
        pieces: List[str] = []
        for value in outputs.values():
            if isinstance(value, str):
                candidates = [value]
            elif isinstance(value, dict):
                candidates = list(value.values())
            elif isinstance(value, list):
                candidates = []
                for item in value:
                    if isinstance(item, dict):
                        candidates.extend(item.values())
                    else:
                        candidates.append(item)
            else:
                continue
            pieces.extend(c for c in candidates if isinstance(c, str) and c.strip())
        return "".join(pieces)

    def _cleanup_expired_conversations(self) -> None:
        """Drop history for sessions idle longer than ``conversation_timeout`` seconds."""
        now = time.time()
        with self._state_lock:
            expired = [
                sid for sid, last in self.last_activity.items()
                if now - last > self.conversation_timeout
            ]
            for sid in expired:
                self.conversations.pop(sid, None)
                self.last_activity.pop(sid, None)
        if expired:
            self.LOG.info(f"已清理 {len(expired)} 个过期会话")

    def get_token_usage_report(self) -> str:
        """Return a per-user token usage report (highest first) plus the total."""
        with self._state_lock:
            usage = dict(self.token_usage)
        if not usage:
            return "暂无token使用记录"
        lines = ["Token使用情况统计"]
        lines.extend(
            f"用户 {uid}: {tokens} tokens"
            for uid, tokens in sorted(usage.items(), key=lambda kv: kv[1], reverse=True)
        )
        lines.append("")
        lines.append(f"总计: {sum(usage.values())} tokens")
        return "\n".join(lines)

    def reset_conversation(self, session_id: str) -> bool:
        """Forget one session's history; return True if it existed."""
        with self._state_lock:
            existed = session_id in self.conversations
            if existed:
                del self.conversations[session_id]
                self.last_activity.pop(session_id, None)
        return existed

    def reset_all_conversations(self) -> None:
        """Forget the history of every session."""
        with self._state_lock:
            self.conversations.clear()
            self.last_activity.clear()