Files
abot/dify/dify_chat.py
2025-03-13 15:38:48 +08:00

245 lines
8.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import logging
import tomllib
import os
import requests
import json
import time
from typing import Dict, List, Optional
from wcferry import WxMsg, Wcf
from robot_cmd.robot_command import Feature, PermissionStatus, GroupBotManager
class DifyChat:
def __init__(self, wcf: Wcf, gbm: GroupBotManager):
self.LOG = logging.getLogger(__name__)
self.wcf = wcf
self.gbm = gbm # 权限功能
# 加载配置文件
with open("dify/config.toml", "rb") as f:
plugin_config = tomllib.load(f)
config = plugin_config["Dify"]
# 基本配置
self.enable = config["enable"]
self.api_key = config["api-key"]
self.base_url = config["base-url"]
self.commands = config["commands"]
self.command_tip = config["command-tip"]
self.price = config["price"]
self.admin_ignore = config.get("admin_ignore", False)
self.whitelist_ignore = config.get("whitelist_ignore", False)
self.http_proxy = config.get("http-proxy", "")
# 会话上下文管理,格式: {group_id/wxid: [conversation_history]}
self.conversations: Dict[str, List[Dict]] = {}
# tokens 消耗统计,格式: {wxid: total_tokens}
self.token_usage: Dict[str, int] = {}
# 最大上下文长度
self.max_history_length = 10
# 会话过期时间(秒)
self.conversation_timeout = 3600 # 1小时
self.last_activity: Dict[str, float] = {}
self.LOG.info(f"[Dify聊天] 组件初始化完成,指令:{self.commands}")
def handle_message(self, message: WxMsg) -> None:
"""处理微信消息"""
if not self.enable:
return
content = str(message.content).strip()
parts = content.split(" ", 1)
command = parts[0]
# 检查是否是触发命令
if command not in self.commands:
return
# 如果没有查询内容,返回使用提示
if len(parts) < 2 or not parts[1].strip():
self.wcf.send_text(self.command_tip,
(message.roomid if message.from_group() else message.sender))
return
# 检查权限
if message.from_group() and self.gbm.get_group_permission(message.roomid, Feature.AI_CAPABILITY) == PermissionStatus.DISABLED:
return
# 获取查询内容
query = parts[1].strip()
# 获取会话ID群聊使用群ID私聊使用个人wxid
session_id = message.roomid if message.from_group() else message.sender
# 获取用户ID
user_id = message.sender
# 检查是否需要扣除积分
if self.price > 0:
# 管理员和白名单检查逻辑
is_admin = False # 这里需要实现管理员检查逻辑
is_whitelist = False # 这里需要实现白名单检查逻辑
should_deduct = True
if (self.admin_ignore and is_admin) or (self.whitelist_ignore and is_whitelist):
should_deduct = False
if should_deduct:
# 这里需要实现积分扣除逻辑
# 如果积分不足,返回提示
pass
try:
# 调用Dify API获取回复
response = self.chat_with_dify(session_id, user_id, query)
# 发送回复
if response:
self.wcf.send_text(response,
(message.roomid if message.from_group() else message.sender),
message.sender if message.from_group() else "")
except Exception as e:
self.LOG.error(f"Dify聊天出错{e}")
self.wcf.send_text(f"-----Bot-----\n❌请求出错:{str(e)}",
(message.roomid if message.from_group() else message.sender),
message.sender if message.from_group() else "")
def chat_with_dify(self, session_id: str, user_id: str, query: str) -> Optional[str]:
"""
与Dify API交互获取回复
Args:
session_id: 会话ID群聊为群ID私聊为个人wxid
user_id: 用户wxid
query: 用户查询内容
Returns:
API返回的回复内容
"""
# 清理过期会话
self._cleanup_expired_conversations()
# 更新最后活动时间
self.last_activity[session_id] = time.time()
# 初始化会话历史
if session_id not in self.conversations:
self.conversations[session_id] = []
# 准备请求头
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
# 准备请求数据
data = {
"query": query,
"sys.files": [],
"sys.user_id": user_id,
"response_mode": "streaming", # 使用流式响应
"conversation_id": session_id # 使用会话ID保持上下文
}
# 添加历史记录
if self.conversations[session_id]:
data["conversation_history"] = self.conversations[session_id]
# 设置代理
proxies = None
if self.http_proxy:
proxies = {
"http": self.http_proxy,
"https": self.http_proxy
}
# 发送请求
url = f"{self.base_url}/chat-messages"
response = requests.post(url, headers=headers, json=data, proxies=proxies)
if response.status_code != 200:
self.LOG.error(f"Dify API请求失败: {response.status_code} {response.text}")
return f"请求失败,状态码: {response.status_code}"
# 解析响应
result = response.json()
# 提取回复内容
answer = result.get("answer", "")
# 更新会话历史
self.conversations[session_id].append({
"role": "user",
"content": query
})
self.conversations[session_id].append({
"role": "assistant",
"content": answer
})
# 限制会话历史长度
if len(self.conversations[session_id]) > self.max_history_length * 2:
self.conversations[session_id] = self.conversations[session_id][-self.max_history_length * 2:]
# 统计token使用情况
if "usage" in result and "total_tokens" in result["usage"]:
total_tokens = result["usage"]["total_tokens"]
if user_id in self.token_usage:
self.token_usage[user_id] += total_tokens
else:
self.token_usage[user_id] = total_tokens
self.LOG.info(f"用户 {user_id} 本次消耗 {total_tokens} tokens累计 {self.token_usage[user_id]} tokens")
return answer
def _cleanup_expired_conversations(self) -> None:
"""清理过期的会话"""
current_time = time.time()
expired_sessions = []
for session_id, last_time in self.last_activity.items():
if current_time - last_time > self.conversation_timeout:
expired_sessions.append(session_id)
for session_id in expired_sessions:
if session_id in self.conversations:
del self.conversations[session_id]
del self.last_activity[session_id]
if expired_sessions:
self.LOG.info(f"已清理 {len(expired_sessions)} 个过期会话")
def get_token_usage_report(self) -> str:
"""获取token使用情况报告"""
if not self.token_usage:
return "暂无token使用记录"
report = "Token使用情况统计\n"
for user_id, tokens in sorted(self.token_usage.items(), key=lambda x: x[1], reverse=True):
report += f"用户 {user_id}: {tokens} tokens\n"
total = sum(self.token_usage.values())
report += f"\n总计: {total} tokens"
return report
def reset_conversation(self, session_id: str) -> bool:
"""重置指定会话的上下文"""
if session_id in self.conversations:
del self.conversations[session_id]
if session_id in self.last_activity:
del self.last_activity[session_id]
return True
return False
def reset_all_conversations(self) -> None:
"""重置所有会话上下文"""
self.conversations.clear()
self.last_activity.clear()