Files
abot/plugins/dify/main.py
2025-03-19 17:37:02 +08:00

345 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import logging
import os
import requests
import json
import time
from typing import Dict, Any, List, Optional, Tuple
from wcferry import Wcf
from plugin_common.message_plugin_interface import MessagePluginInterface
from plugin_common.plugin_interface import PluginStatus
from plugins.stats_collector.decorators import plugin_stats_decorator
from robot_cmd.robot_command import Feature, PermissionStatus, GroupBotManager
class DifyPlugin(MessagePluginInterface):
    """Chat plugin that forwards user queries to a Dify workflow endpoint.

    A message is handled when its first space-separated word is one of the
    configured command words; the remainder of the message is sent as the
    query.  Per-session history (keyed by group id for group chats, by the
    sender's wxid for private chats) and per-user token totals are kept in
    memory only and are lost when the process exits.
    """

    # Command words used when the configuration does not provide any.
    _DEFAULT_COMMANDS = ("ai", "dify", "聊天", "AI")
    # Default HTTP timeout in seconds; overridable via the "timeout" key of
    # the "Dify" config section.
    _DEFAULT_TIMEOUT = 120

    @property
    def name(self) -> str:
        """Display name of the plugin."""
        return "Dify聊天"

    @property
    def version(self) -> str:
        """Plugin version string."""
        return "1.0.0"

    @property
    def description(self) -> str:
        """Short description of what the plugin provides."""
        return "提供基于Dify的AI聊天功能"

    @property
    def author(self) -> str:
        """Plugin author."""
        return "Trae AI"

    @property
    def command_prefix(self) -> Optional[str]:
        """Command prefix; empty because commands are matched directly."""
        return ""

    @property
    def commands(self) -> List[str]:
        """Command words that trigger this plugin."""
        return self._commands

    def __init__(self):
        """Create the in-memory session/usage stores and safe defaults."""
        super().__init__()
        # Conversation context: {group_id or wxid: [history entries]}
        self.conversations: Dict[str, List[Dict]] = {}
        # Token consumption: {wxid: total_tokens}
        self.token_usage: Dict[str, int] = {}
        # Maximum number of user/assistant exchanges kept per session
        self.max_history_length = 10
        # Session expiry in seconds (1 hour)
        self.conversation_timeout = 3600
        # Timestamp of the last request per session, used for expiry
        self.last_activity: Dict[str, float] = {}
        # Defaults so that `commands` / `can_process` do not raise
        # AttributeError when used before initialize() has run.
        self.LOG = logging.getLogger(f"Plugin.{self.name}")
        self._commands: List[str] = list(self._DEFAULT_COMMANDS)
        self.enable = False
        self.request_timeout = self._DEFAULT_TIMEOUT

    def initialize(self, context: Dict[str, Any]) -> bool:
        """Store the runtime context and read the "Dify" config section.

        Args:
            context: Mapping supplied by the host; the keys "wcf",
                "event_system", "message_util" and "gbm" are read from it.

        Returns:
            Always True.
        """
        self.LOG = logging.getLogger(f"Plugin.{self.name}")
        self.LOG.info(f"正在初始化 {self.name} 插件...")

        # Keep references to the host-provided objects
        self.wcf = context.get("wcf")
        self.event_system = context.get("event_system")
        self.message_util = context.get("message_util")
        self.gbm = context.get("gbm")

        # Read settings; `or {}` tolerates a section that is present but null
        dify_config = self._config.get("Dify") or {}
        self._commands = dify_config.get("commands", list(self._DEFAULT_COMMANDS))
        self.command_format = dify_config.get("command-tip", "聊天 请求内容")
        self.enable = dify_config.get("enable", True)
        self.api_key = dify_config.get("api-key", "")
        self.base_url = dify_config.get("base-url", "")
        self.price = dify_config.get("price", 0)
        self.admin_ignore = dify_config.get("admin_ignore", False)
        self.whitelist_ignore = dify_config.get("whitelist_ignore", False)
        self.http_proxy = dify_config.get("http-proxy", "")
        self.request_timeout = dify_config.get("timeout", self._DEFAULT_TIMEOUT)

        self.LOG.info(f"[{self.name}] 插件初始化完成,指令:{self._commands}")
        return True

    def start(self) -> bool:
        """Mark the plugin as running. Always returns True."""
        self.LOG.info(f"[{self.name}] 插件已启动")
        self.status = PluginStatus.RUNNING
        return True

    def stop(self) -> bool:
        """Mark the plugin as stopped. Always returns True."""
        self.LOG.info(f"[{self.name}] 插件已停止")
        self.status = PluginStatus.STOPPED
        return True

    def can_process(self, message: Dict[str, Any]) -> bool:
        """Return True when the plugin is enabled and the first word of the
        message content is one of the configured command words."""
        if not self.enable:
            return False
        content = str(message.get("content", "")).strip()
        return content.split(" ")[0] in self._commands

    @plugin_stats_decorator(plugin_name="Dify聊天")
    def process_message(self, message: Dict[str, Any]) -> Tuple[bool, Optional[str]]:
        """Handle one chat command and reply through ``wcf.send_text``.

        Args:
            message: Incoming message mapping. Reads "content", "sender",
                "roomid", "wcf" and "gbm".

        Returns:
            ``(handled, note)``. ``handled`` is False only when the group has
            the AI capability disabled; every other path, including API
            errors, returns True with a short status note.
        """
        content = str(message.get("content", "")).strip()
        self.LOG.info(f"插件执行: {self.name}{content}")
        parts = content.split(" ", 1)
        sender = message.get("sender")
        roomid = message.get("roomid", "")
        # Prefer the handles carried by the message; fall back to the ones
        # stored by initialize() when the message does not include them.
        wcf: Wcf = message.get("wcf") or getattr(self, "wcf", None)
        gbm: GroupBotManager = message.get("gbm") or getattr(self, "gbm", None)

        # Group chats reply into the room and pass the sender as the third
        # send_text argument; private chats reply to the sender with an empty
        # third argument. Every reply below uses the same pair.
        receiver = roomid if roomid else sender
        mention = sender if roomid else ""

        # Command must be followed by a non-blank query
        if len(parts) < 2 or not parts[1].strip():
            wcf.send_text(f"{self.command_format}", receiver, mention)
            return True, "命令格式错误"

        # Group-level permission check
        if roomid and gbm.get_group_permission(roomid, Feature.AI_CAPABILITY) == PermissionStatus.DISABLED:
            return False, "没有权限"

        query = parts[1].strip()
        # Session id: group id in group chats, personal wxid in private chats
        session_id = receiver
        user_id = sender

        if self.price > 0:
            # TODO: implement the real admin / whitelist lookups; both flags
            # are placeholders, so nobody is currently exempt.
            is_admin = False
            is_whitelist = False
            exempt = (self.admin_ignore and is_admin) or (self.whitelist_ignore and is_whitelist)
            if not exempt:
                # TODO: deduct points here and reply with a hint when the
                # balance is insufficient. Nothing is charged at the moment.
                pass

        try:
            response = self._chat_with_dify(session_id, user_id, query)
            if response:
                wcf.send_text(response, receiver, mention)
                return True, "发送成功"
            wcf.send_text("❌未能获取到回复,请稍后再试", receiver, mention)
            return True, "未获取到回复"
        except Exception as e:
            # Top-level boundary: report the failure to the user rather than
            # letting it propagate to the host.
            self.LOG.error(f"处理Dify聊天请求出错: {e}")
            wcf.send_text(f"❌请求出错:{str(e)}", receiver, mention)
            return True, f"处理出错: {e}"

    @staticmethod
    def _non_blank_strings(values) -> List[str]:
        """Return the items of *values* that are strings with non-whitespace content."""
        return [v for v in values if isinstance(v, str) and v.strip()]

    @classmethod
    def _extract_answer(cls, outputs: Any) -> str:
        """Build the reply text from a workflow ``outputs`` mapping.

        A string-valued "text" entry is returned as-is. Otherwise the
        non-blank strings found directly in the mapping, one level inside
        nested dicts, and inside lists (including one level of dicts within
        those lists) are concatenated in iteration order.
        """
        if not isinstance(outputs, dict):
            return ""
        text = outputs.get("text")
        if isinstance(text, str):
            return text

        fragments: List[str] = []
        for value in outputs.values():
            if isinstance(value, dict):
                fragments.extend(cls._non_blank_strings(value.values()))
            elif isinstance(value, list):
                for item in value:
                    if isinstance(item, dict):
                        fragments.extend(cls._non_blank_strings(item.values()))
                    else:
                        fragments.extend(cls._non_blank_strings([item]))
            else:
                fragments.extend(cls._non_blank_strings([value]))
        return "".join(fragments)

    def _chat_with_dify(self, session_id: str, user_id: str, query: str) -> Optional[str]:
        """Send *query* to the Dify workflow endpoint and return the reply.

        Args:
            session_id: Group id for group chats, personal wxid otherwise.
            user_id: wxid of the requesting user.
            query: Text to send.

        Returns:
            The extracted answer (possibly empty), or a human-readable error
            string when the HTTP status is not 200 or any exception occurs
            while sending the request or handling the response.
        """
        # Drop stale sessions, then mark this one as active
        self._cleanup_expired_conversations()
        self.last_activity[session_id] = time.time()
        history = self.conversations.setdefault(session_id, [])

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
            # NOTE(review): an event-stream Accept header is sent although
            # response_mode is "blocking" and the body is parsed as JSON —
            # confirm against the Dify API whether this header is needed.
            "Accept": "text/event-stream",
        }
        data = {
            "sys.files": [],
            "user": user_id,
            "inputs": {
                "query": query,
                "conversation_id": session_id,
            },
            "response_mode": "blocking",
        }
        # NOTE(review): whether the workflow endpoint honours this field is
        # not visible from this file — confirm against the workflow definition.
        if history:
            data["conversation_history"] = history

        proxies = None
        if self.http_proxy:
            proxies = {"http": self.http_proxy, "https": self.http_proxy}

        # Strip a trailing slash so a configured "https://host/v1/" does not
        # produce "//workflows/run".
        url = f"{self.base_url.rstrip('/')}/workflows/run"
        self.LOG.info(f"发送请求到Dify API: {url}")
        self.LOG.info(f"请求数据: {json.dumps(data, ensure_ascii=False)}")

        try:
            # Without a timeout a stalled server would block the handler forever.
            response = requests.post(
                url,
                headers=headers,
                json=data,
                proxies=proxies,
                timeout=self.request_timeout,
            )
            if response.status_code != 200:
                self.LOG.error(f"Dify API请求失败: {response.status_code} {response.text}")
                return f"请求失败,状态码: {response.status_code}"

            response_data = response.json()
            self.LOG.info(f"收到Dify API响应: {json.dumps(response_data, ensure_ascii=False)}")

            # `or` guards against explicit nulls in the JSON payload
            result = response_data.get("data") or {}
            answer = self._extract_answer(result.get("outputs") or {})
            total_tokens = result.get("total_tokens") or 0

            # Record the exchange only when there is an answer, so empty
            # assistant turns do not pollute later requests.
            if answer:
                history.append({"role": "user", "content": query})
                history.append({"role": "assistant", "content": answer})
                # Keep at most max_history_length exchanges (two entries each)
                limit = self.max_history_length * 2
                if len(history) > limit:
                    del history[:-limit]

            if total_tokens > 0:
                self.token_usage[user_id] = self.token_usage.get(user_id, 0) + total_tokens
                self.LOG.info(
                    f"用户 {user_id} 本次消耗 {total_tokens} tokens累计 {self.token_usage[user_id]} tokens")
            return answer
        except Exception as e:
            # Boundary handler: callers expect a string, not an exception.
            self.LOG.error(f"处理Dify响应时出错: {str(e)}")
            return f"处理响应时出错: {str(e)}"

    def _cleanup_expired_conversations(self) -> None:
        """Remove sessions whose last activity is older than the timeout."""
        current_time = time.time()
        # Collect first: the dict must not change size while being iterated.
        expired_sessions = [
            session_id
            for session_id, last_time in self.last_activity.items()
            if current_time - last_time > self.conversation_timeout
        ]
        for session_id in expired_sessions:
            self.conversations.pop(session_id, None)
            del self.last_activity[session_id]
        if expired_sessions:
            self.LOG.info(f"已清理 {len(expired_sessions)} 个过期会话")

    def get_token_usage_report(self) -> str:
        """Return a text report of token usage per user, highest first,
        followed by the grand total; a fixed notice when nothing is recorded."""
        if not self.token_usage:
            return "暂无token使用记录"
        ranked = sorted(self.token_usage.items(), key=lambda x: x[1], reverse=True)
        lines = [f"用户 {user_id}: {tokens} tokens\n" for user_id, tokens in ranked]
        total = sum(self.token_usage.values())
        return "Token使用情况统计\n" + "".join(lines) + f"\n总计: {total} tokens"

    def reset_conversation(self, session_id: str) -> bool:
        """Forget the context of one session.

        Returns:
            True when a conversation existed for *session_id*, else False.
            The activity timestamp is cleared in either case.
        """
        existed = session_id in self.conversations
        self.conversations.pop(session_id, None)
        self.last_activity.pop(session_id, None)
        return existed

    def reset_all_conversations(self) -> None:
        """Forget the context and activity timestamps of every session."""
        self.conversations.clear()
        self.last_activity.clear()