291 lines
11 KiB
Python
291 lines
11 KiB
Python
import logging
|
||
import tomllib
|
||
import os
|
||
import requests
|
||
import json
|
||
import time
|
||
from typing import Dict, List, Optional
|
||
from wcferry import WxMsg, Wcf
|
||
|
||
from robot_cmd.robot_command import Feature, PermissionStatus, GroupBotManager
|
||
|
||
|
||
class DifyChat:
|
||
def __init__(self, wcf: Wcf, gbm: GroupBotManager):
|
||
self.LOG = logging.getLogger(__name__)
|
||
self.wcf = wcf
|
||
self.gbm = gbm # 权限功能
|
||
|
||
# 加载配置文件
|
||
with open("dify/config.toml", "rb") as f:
|
||
plugin_config = tomllib.load(f)
|
||
|
||
config = plugin_config["Dify"]
|
||
|
||
# 基本配置
|
||
self.enable = config["enable"]
|
||
self.api_key = config["api-key"]
|
||
self.base_url = config["base-url"]
|
||
self.commands = config["commands"]
|
||
self.command_tip = config["command-tip"]
|
||
self.price = config["price"]
|
||
self.admin_ignore = config.get("admin_ignore", False)
|
||
self.whitelist_ignore = config.get("whitelist_ignore", False)
|
||
self.http_proxy = config.get("http-proxy", "")
|
||
|
||
# 会话上下文管理,格式: {group_id/wxid: [conversation_history]}
|
||
self.conversations: Dict[str, List[Dict]] = {}
|
||
|
||
# tokens 消耗统计,格式: {wxid: total_tokens}
|
||
self.token_usage: Dict[str, int] = {}
|
||
|
||
# 最大上下文长度
|
||
self.max_history_length = 10
|
||
|
||
# 会话过期时间(秒)
|
||
self.conversation_timeout = 3600 # 1小时
|
||
self.last_activity: Dict[str, float] = {}
|
||
|
||
self.LOG.info(f"[Dify聊天] 组件初始化完成,指令:{self.commands}")
|
||
|
||
def handle_message(self, message: WxMsg) -> None:
|
||
"""处理微信消息"""
|
||
if not self.enable:
|
||
return
|
||
|
||
content = str(message.content).strip()
|
||
parts = content.split(" ", 1)
|
||
command = parts[0]
|
||
|
||
# 检查是否是触发命令
|
||
if command not in self.commands:
|
||
return
|
||
|
||
# 如果没有查询内容,返回使用提示
|
||
if len(parts) < 2 or not parts[1].strip():
|
||
self.wcf.send_text(self.command_tip,
|
||
(message.roomid if message.from_group() else message.sender))
|
||
return
|
||
|
||
# 检查权限
|
||
if message.from_group() and self.gbm.get_group_permission(message.roomid, Feature.AI_CAPABILITY) == PermissionStatus.DISABLED:
|
||
return
|
||
|
||
# 获取查询内容
|
||
query = parts[1].strip()
|
||
|
||
# 获取会话ID(群聊使用群ID,私聊使用个人wxid)
|
||
session_id = message.roomid if message.from_group() else message.sender
|
||
|
||
# 获取用户ID
|
||
user_id = message.sender
|
||
|
||
# 检查是否需要扣除积分
|
||
if self.price > 0:
|
||
# 管理员和白名单检查逻辑
|
||
is_admin = False # 这里需要实现管理员检查逻辑
|
||
is_whitelist = False # 这里需要实现白名单检查逻辑
|
||
|
||
should_deduct = True
|
||
if (self.admin_ignore and is_admin) or (self.whitelist_ignore and is_whitelist):
|
||
should_deduct = False
|
||
|
||
if should_deduct:
|
||
# 这里需要实现积分扣除逻辑
|
||
# 如果积分不足,返回提示
|
||
pass
|
||
|
||
try:
|
||
# 调用Dify API获取回复
|
||
response = self.chat_with_dify(session_id, user_id, query)
|
||
|
||
# 发送回复
|
||
if response:
|
||
self.wcf.send_text(response,
|
||
(message.roomid if message.from_group() else message.sender),
|
||
message.sender if message.from_group() else "")
|
||
except Exception as e:
|
||
self.LOG.error(f"Dify聊天出错:{e}")
|
||
self.wcf.send_text(f"-----Bot-----\n❌请求出错:{str(e)}",
|
||
(message.roomid if message.from_group() else message.sender),
|
||
message.sender if message.from_group() else "")
|
||
|
||
def chat_with_dify(self, session_id: str, user_id: str, query: str) -> Optional[str]:
|
||
"""
|
||
与Dify API交互获取回复
|
||
|
||
Args:
|
||
session_id: 会话ID(群聊为群ID,私聊为个人wxid)
|
||
user_id: 用户wxid
|
||
query: 用户查询内容
|
||
|
||
Returns:
|
||
API返回的回复内容
|
||
"""
|
||
# 清理过期会话
|
||
self._cleanup_expired_conversations()
|
||
|
||
# 更新最后活动时间
|
||
self.last_activity[session_id] = time.time()
|
||
|
||
# 初始化会话历史
|
||
if session_id not in self.conversations:
|
||
self.conversations[session_id] = []
|
||
|
||
# 准备请求头
|
||
headers = {
|
||
"Authorization": f"Bearer {self.api_key}",
|
||
"Content-Type": "application/json",
|
||
"Accept": "text/event-stream" # 指定接受事件流
|
||
}
|
||
|
||
# 准备请求数据
|
||
data = {
|
||
"query": query,
|
||
"sys.files": [],
|
||
"user": user_id,
|
||
"response_mode": "streaming", # 使用流式响应
|
||
"conversation_id": session_id # 使用会话ID保持上下文
|
||
}
|
||
|
||
# 添加历史记录
|
||
if self.conversations[session_id]:
|
||
data["conversation_history"] = self.conversations[session_id]
|
||
|
||
# 设置代理
|
||
proxies = None
|
||
if self.http_proxy:
|
||
proxies = {
|
||
"http": self.http_proxy,
|
||
"https": self.http_proxy
|
||
}
|
||
|
||
# 发送请求
|
||
url = f"{self.base_url}/chat-messages"
|
||
|
||
try:
|
||
# 使用流式请求
|
||
with requests.post(url, headers=headers, json=data, proxies=proxies, stream=True) as response:
|
||
if response.status_code != 200:
|
||
self.LOG.error(f"Dify API请求失败: {response.status_code} {response.text}")
|
||
return f"请求失败,状态码: {response.status_code}"
|
||
|
||
answer = ""
|
||
total_tokens = 0
|
||
|
||
# 处理流式响应
|
||
for line in response.iter_lines():
|
||
if not line:
|
||
continue
|
||
|
||
# 解析事件流数据
|
||
line_text = line.decode('utf-8')
|
||
if not line_text.startswith('data: '):
|
||
continue
|
||
|
||
# 提取JSON数据部分
|
||
json_str = line_text[6:] # 去掉 "data: " 前缀
|
||
|
||
try:
|
||
event_data = json.loads(json_str)
|
||
event_type = event_data.get("event")
|
||
|
||
# 处理不同类型的事件
|
||
if event_type == "workflow_finished":
|
||
# 工作流完成事件,可以获取总token数
|
||
data_obj = event_data.get("data", {})
|
||
total_tokens = data_obj.get("total_tokens", 0)
|
||
|
||
elif event_type == "node_finished":
|
||
# 节点完成事件,可以获取节点输出和token使用情况
|
||
data_obj = event_data.get("data", {})
|
||
outputs = data_obj.get("outputs", {})
|
||
|
||
# 从outputs中提取回答内容
|
||
if outputs and isinstance(outputs, dict):
|
||
for key, value in outputs.items():
|
||
if isinstance(value, str) and value.strip():
|
||
answer += value
|
||
|
||
# 获取token使用情况
|
||
execution_metadata = data_obj.get("execution_metadata", {})
|
||
if "total_tokens" in execution_metadata:
|
||
total_tokens = execution_metadata.get("total_tokens", 0)
|
||
|
||
except json.JSONDecodeError:
|
||
self.LOG.error(f"解析事件流数据失败: {line_text}")
|
||
continue
|
||
|
||
# 更新会话历史
|
||
self.conversations[session_id].append({
|
||
"role": "user",
|
||
"content": query
|
||
})
|
||
|
||
self.conversations[session_id].append({
|
||
"role": "assistant",
|
||
"content": answer
|
||
})
|
||
|
||
# 限制会话历史长度
|
||
if len(self.conversations[session_id]) > self.max_history_length * 2:
|
||
self.conversations[session_id] = self.conversations[session_id][-self.max_history_length * 2:]
|
||
|
||
# 统计token使用情况
|
||
if total_tokens > 0:
|
||
if user_id in self.token_usage:
|
||
self.token_usage[user_id] += total_tokens
|
||
else:
|
||
self.token_usage[user_id] = total_tokens
|
||
|
||
self.LOG.info(f"用户 {user_id} 本次消耗 {total_tokens} tokens,累计 {self.token_usage[user_id]} tokens")
|
||
|
||
return answer
|
||
|
||
except Exception as e:
|
||
self.LOG.error(f"处理Dify流式响应时出错: {str(e)}")
|
||
return f"处理响应时出错: {str(e)}"
|
||
|
||
def _cleanup_expired_conversations(self) -> None:
|
||
"""清理过期的会话"""
|
||
current_time = time.time()
|
||
expired_sessions = []
|
||
|
||
for session_id, last_time in self.last_activity.items():
|
||
if current_time - last_time > self.conversation_timeout:
|
||
expired_sessions.append(session_id)
|
||
|
||
for session_id in expired_sessions:
|
||
if session_id in self.conversations:
|
||
del self.conversations[session_id]
|
||
del self.last_activity[session_id]
|
||
|
||
if expired_sessions:
|
||
self.LOG.info(f"已清理 {len(expired_sessions)} 个过期会话")
|
||
|
||
def get_token_usage_report(self) -> str:
|
||
"""获取token使用情况报告"""
|
||
if not self.token_usage:
|
||
return "暂无token使用记录"
|
||
|
||
report = "Token使用情况统计:\n"
|
||
for user_id, tokens in sorted(self.token_usage.items(), key=lambda x: x[1], reverse=True):
|
||
report += f"用户 {user_id}: {tokens} tokens\n"
|
||
|
||
total = sum(self.token_usage.values())
|
||
report += f"\n总计: {total} tokens"
|
||
return report
|
||
|
||
def reset_conversation(self, session_id: str) -> bool:
|
||
"""重置指定会话的上下文"""
|
||
if session_id in self.conversations:
|
||
del self.conversations[session_id]
|
||
if session_id in self.last_activity:
|
||
del self.last_activity[session_id]
|
||
return True
|
||
return False
|
||
|
||
def reset_all_conversations(self) -> None:
|
||
"""重置所有会话上下文"""
|
||
self.conversations.clear()
|
||
self.last_activity.clear() |