696 lines
28 KiB
Python
696 lines
28 KiB
Python
import os
|
||
import cv2
|
||
import requests
|
||
import json
|
||
import time
|
||
import re # 添加re模块导入
|
||
import asyncio
|
||
import base64
|
||
import html
|
||
import xml.etree.ElementTree as ET
|
||
from typing import Dict, Any, List, Optional, Tuple
|
||
|
||
from pathlib import Path
|
||
|
||
from base.plugin_common.message_plugin_interface import MessagePluginInterface
|
||
from base.plugin_common.plugin_interface import PluginStatus
|
||
from utils.decorator.plugin_decorators import plugin_stats_decorator
|
||
from utils.decorator.rate_limit_decorator import group_feature_rate_limit
|
||
from utils.markdown_to_image import convert_md_str_to_image
|
||
from utils.revoke.message_auto_revoke import MessageAutoRevoke
|
||
from utils.robot_cmd.robot_command import Feature, PermissionStatus, GroupBotManager
|
||
from utils.decorator.points_decorator import plugin_points_cost
|
||
from utils.media_downloader import MediaDownloader
|
||
from utils.string_utils import remove_reasoning_content, remove_trailing_content, remove_grok_render_tags
|
||
from wechat_ipad import WechatAPIClient
|
||
from wechat_ipad.models.message import MessageType
|
||
import aiohttp
|
||
|
||
# 常见的图片和视频文件扩展名
|
||
IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp'}
|
||
VIDEO_EXTENSIONS = {'.mp4', '.avi', '.mov', '.wmv', '.flv', '.mkv', '.webm'}
|
||
|
||
|
||
class DifyPlugin(MessagePluginInterface):
|
||
"""Dify AI聊天插件"""
|
||
|
||
# 功能权限常量
|
||
FEATURE_KEY = "AI_CAPABILITY"
|
||
FEATURE_DESCRIPTION = "🤖 AI对话 [ai, 聊天, AI] 用法:ai 如何写一个机器人?"
|
||
|
||
@property
def name(self) -> str:
    """Return the plugin name."""
    plugin_name = "Dify聊天"
    return plugin_name
|
||
|
||
@property
def version(self) -> str:
    """Return the plugin version string."""
    plugin_version = "1.0.0"
    return plugin_version
|
||
|
||
@property
def description(self) -> str:
    """Return a short description of what the plugin provides."""
    summary = "提供基于Dify的AI聊天功能"
    return summary
|
||
|
||
@property
def author(self) -> str:
    """Return the plugin author."""
    plugin_author = "liu.wei"
    return plugin_author
|
||
|
||
@property
def command_prefix(self) -> Optional[str]:
    """Return the command prefix.

    The empty string means no prefix is required: commands are matched
    directly against the first word of the message.
    """
    no_prefix = ""
    return no_prefix
|
||
|
||
@property
def commands(self) -> List[str]:
    """Return the trigger commands stored in ``self._commands``."""
    configured = self._commands
    return configured
|
||
|
||
@property
def feature_key(self) -> Optional[str]:
    """Return the permission key of this feature (``FEATURE_KEY``)."""
    key = self.FEATURE_KEY
    return key
|
||
|
||
@property
def feature_description(self) -> Optional[str]:
    """Return the help text of this feature (``FEATURE_DESCRIPTION``)."""
    text = self.FEATURE_DESCRIPTION
    return text
|
||
|
||
def __init__(self):
    """Create the in-memory session state and register the feature."""
    super().__init__()

    # Limits applied to the per-session history kept in memory.
    self.max_history_length = 10        # maximum context length
    self.conversation_timeout = 3600    # session expiry in seconds (1 hour)

    # Conversation context, keyed by group id or wxid:
    # {group_id/wxid: [conversation_history]}
    self.conversations: Dict[str, List[Dict]] = {}
    # Timestamp of the last request per session key.
    self.last_activity: Dict[str, float] = {}
    # Accumulated token consumption: {wxid: total_tokens}
    self.token_usage: Dict[str, int] = {}

    # Register the feature permission last, once the state above exists.
    self.feature = self.register_feature()
|
||
|
||
def initialize(self, context: Dict[str, Any]) -> bool:
    """Initialise the plugin from the ``Dify`` section of ``self._config``.

    Args:
        context: Mapping supplied by the caller; ``event_system`` and
            ``gbm`` are read from it and stored on the instance.

    Returns:
        Always ``True``.
    """
    self.LOG.debug(f"正在初始化 {self.name} 插件...")

    # Keep the context objects for later use.
    self.event_system = context.get("event_system")
    self.gbm = context.get("gbm")

    # A section that is present but null must behave like a missing one.
    dify_config = self._config.get("Dify", {}) or {}
    self._commands = dify_config.get("commands", ["ai", "dify", "聊天", "AI"])
    self.command_format = dify_config.get("command-tip", "聊天 请求内容")
    self.enable = dify_config.get("enable", True)
    self.api_key = dify_config.get("api-key", "")
    # Endpoints are built as f"{base_url}/...", so a trailing slash in the
    # configuration would otherwise produce "//files/upload" style URLs.
    self.base_url = str(dify_config.get("base-url", "") or "").rstrip("/")
    self.price = dify_config.get("price", 0)
    self.admin_ignore = dify_config.get("admin_ignore", False)
    self.whitelist_ignore = dify_config.get("whitelist_ignore", False)
    self.http_proxy = dify_config.get("http-proxy", "")

    self.LOG.debug(f"[{self.name}] 插件初始化完成,指令:{self._commands}")
    return True
|
||
|
||
def start(self) -> bool:
    """Mark the plugin as running.

    Returns:
        Always ``True``.
    """
    started_message = f"[{self.name}] 插件已启动"
    self.LOG.debug(started_message)
    self.status = PluginStatus.RUNNING
    return True
|
||
|
||
def stop(self) -> bool:
    """Mark the plugin as stopped.

    Returns:
        Always ``True``.
    """
    stopped_message = f"[{self.name}] 插件已停止"
    self.LOG.info(stopped_message)
    self.status = PluginStatus.STOPPED
    return True
|
||
|
||
def can_process(self, message: Dict[str, Any]) -> bool:
    """Decide whether this plugin should handle *message*.

    A message is accepted when the plugin is enabled and either its first
    space-separated word is one of the configured commands, or it is an
    @-mention that carries a ``roomid`` (i.e. it was sent in a group chat).
    """
    if not self.enable:
        return False

    text = str(message.get("content", "")).strip()
    first_word = text.split(" ")[0]
    if first_word in self._commands:
        # Triggered by an explicit command.
        return True

    # Otherwise only @-mentions inside a group chat are handled.
    mentioned_in_group = message.get("is_at", False) and message.get("roomid", "")
    return bool(mentioned_in_group)
|
||
|
||
def _strip_mentions(self, text: str) -> str:
    """Remove ``@nickname`` tokens, and the separator following them, from *text*.

    A mention ends at the first whitespace character or U+2005 (the
    separator placed after a nickname). The separator is optional so that a
    mention at the very end of an already-stripped string is removed too.
    """
    return re.sub(r"@[^\u2005\s]*[\u2005\s]?", "", text).strip()

@plugin_stats_decorator(plugin_name="Dify聊天")
@group_feature_rate_limit(max_per_minute=5, feature_key=FEATURE_KEY)
async def process_message(self, message: Dict[str, Any]) -> Tuple[bool, Optional[str]]:
    """Handle one incoming message and reply with the Dify answer.

    The question is taken from the text after an @-mention (group chats) or
    after the command word. When the message quotes another message, quoted
    text replaces the question and a quoted image is downloaded and uploaded
    to Dify as an attachment.

    Args:
        message: Message dict; the keys read here are ``content``,
            ``sender``, ``roomid``, ``is_at``, ``gbm``, ``bot``, ``revoke``
            and ``full_wx_msg``.

    Returns:
        ``(success, detail)`` where *detail* is a short status or error text.
    """
    content = str(message.get("content", "")).strip()
    self.LOG.debug(f"插件执行: {self.name}:{content}")
    sender = message.get("sender")
    roomid = message.get("roomid", "")
    gbm: GroupBotManager = message.get("gbm")
    bot: WechatAPIClient = message.get("bot")
    revoke: MessageAutoRevoke = message.get("revoke")

    # Reply target: the group id in a group chat, otherwise the sender's wxid.
    target = roomid if roomid else sender
    self.LOG.debug(
        f"消息上下文: sender={sender}, roomid={roomid}, target={target}, is_at={message.get('is_at', False)}")

    # Permission check (group chats only).
    if roomid and gbm.get_group_permission(target, self.feature) == PermissionStatus.DISABLED:
        return False, "没有权限"

    if message.get("is_at", False) and roomid:
        # @-mention in a group: everything except the mentions is the question.
        query = self._strip_mentions(content)
        if not query:
            return False, "没有提供问题内容"
    else:
        # Command message: "<command> <question>".
        parts = content.split(" ", 1)
        if len(parts) < 2 or not parts[1].strip():
            await bot.send_text_message(target, f"{self.command_format}", sender)
            return False, "命令格式错误"
        query = parts[1].strip()

    self.LOG.debug(f"解析请求: query_len={len(query)} query_preview={query[:120]}")

    try:
        dify_files = []
        quote_payload = self._parse_quote_payload(message.get("full_wx_msg"))
        if quote_payload:
            ref_type = quote_payload.get("ref_type", 0)
            ref_content = quote_payload.get("ref_content", "").strip()
            title = quote_payload.get("title", "").strip()
            self.LOG.debug(
                f"检测到引用消息: type={ref_type}, title_preview={title[:80]}, ref_len={len(ref_content)}")

            if ref_type == MessageType.TEXT.value:
                if ref_content:
                    self.LOG.debug("使用引用文本作为问题")
                    query = ref_content
            elif ref_type == MessageType.IMAGE.value:
                if title:
                    # The quote title carries the user's own text.
                    title_query = self._strip_mentions(title)
                    if title_query:
                        query = title_query
                self.LOG.debug("检测到引用图片,开始下载与上传")

                image_path = await self._download_quote_image(bot, ref_content)
                if not image_path:
                    return False, "图片引用解析失败"

                upload_id = await self._upload_file_to_dify(image_path, sender)
                if not upload_id:
                    return False, "图片上传失败"
                self.LOG.debug(f"图片上传完成: upload_id={upload_id}")

                dify_files.append({
                    "type": "image",
                    "transfer_method": "local_file",
                    "upload_file_id": upload_id
                })
            elif ref_type == MessageType.VIDEO.value:
                return False, "暂不支持视频引用"
            elif ref_type in (MessageType.EMOTICON.value, MessageType.EMOJI.value):
                return False, "暂不支持表情引用"

        success, response = await self._chat_with_dify(target, sender, query, dify_files)
        if not success:
            return False, response

        # Strip promotional trailers, but never rewrite a media file path
        # (same guard as in _send_response / _chat_with_dify).
        if response and not os.path.isfile(response):
            response = remove_trailing_content(response)
            response = remove_grok_render_tags(response)
        self.LOG.debug(f"处理后的响应: {response}")

        if response:
            return await self._send_response(bot, target, sender, response, roomid)

        client_msg_id, create_time, new_msg_id = await bot.send_text_message(
            target, "❌未能获取到回复,请稍后再试", sender if roomid else "")
        revoke.add_message_to_revoke(target, client_msg_id, create_time, new_msg_id, 5)
        return False, "未获取到回复"

    except Exception as e:
        self.LOG.exception(f"处理Dify聊天请求出错: {e}")
        try:
            client_msg_id, create_time, new_msg_id = await bot.send_text_message(
                target, "❌DIFY响应失败", sender if roomid else "")
            revoke.add_message_to_revoke(target, client_msg_id, create_time, new_msg_id, 5)
        except Exception as notify_error:
            # The notice is best effort; the original error is what matters.
            self.LOG.error(f"发送失败提示时出错: {notify_error}")
        return False, f"处理出错: {e}"
|
||
|
||
async def _send_response(self, bot: WechatAPIClient, target: str, sender: str,
                         response: str, roomid: str) -> Tuple[bool, str]:
    """Send *response* to *target* as an image, a video or text.

    If *response* is the path of an existing file it is sent as media
    according to ``check_file_type``. Otherwise it is cleaned up and sent as
    text, or rendered to an image when longer than 1500 characters.

    Args:
        bot: Client used to send the message.
        target: Receiver (group id or wxid).
        sender: Passed through to ``bot.send_text_message``.
        response: Text to send, or the path of a local media file.
        roomid: Not used by this method.

    Returns:
        ``(success, detail)``.
    """
    try:
        if response and not os.path.isfile(response):
            response = remove_reasoning_content(response)
            response = remove_trailing_content(response)
            response = remove_grok_render_tags(response)
            response = re.sub(r'\n{3,}', '\n\n', response).strip()

        if os.path.isfile(response):
            # Local file: send it as media.
            file_type = self.check_file_type(response)
            if file_type == 1:
                await bot.send_image_message(target, Path(response))
            elif file_type == 2:
                first_frame = await self._get_first_frame(response, f"dify_frame_{int(time.time())}.jpg")
                if not first_frame:
                    # _get_first_frame returns None on failure; fail with a
                    # clear message instead of Path(None) raising TypeError.
                    self.LOG.error(f"无法提取视频首帧: {response}")
                    return False, "获取视频封面失败"
                await bot.send_video_message(target, Path(response), Path(first_frame))
            else:
                return False, "获取媒资失败"
        elif len(response) > 1500:
            # Plain text longer than 1500 characters is rendered to an image.
            output_image = f"dify_output_{int(time.time())}.png"
            respath = await convert_md_str_to_image(response, output_image)
            await bot.send_image_message(target, Path(respath))
        else:
            await bot.send_text_message(target, response, sender)
        return True, "发送成功"
    except Exception as e:
        self.LOG.error(f"发送响应消息时出错: {e}")
        return False, f"发送响应失败: {e}"
|
||
|
||
def _parse_quote_payload(self, full_msg: Any) -> Optional[Dict[str, Any]]:
    """Extract the quoted-message data from a quote ("refer") message.

    Args:
        full_msg: Object whose ``content.xml_content`` attribute holds the
            message XML.

    Returns:
        ``{"title", "ref_type", "ref_content"}`` with HTML entities
        unescaped, or ``None`` when the XML is missing or malformed, the
        ``appmsg`` type is not ``57``, there is no ``refermsg`` element, or
        the quoted type is not an integer.
    """
    if not full_msg or not getattr(full_msg, "content", None):
        return None

    xml_content = getattr(full_msg.content, "xml_content", "")
    if not xml_content:
        return None

    try:
        root = ET.fromstring(xml_content)
    except ET.ParseError:
        return None

    appmsg = root.find(".//appmsg")
    if appmsg is None:
        return None

    # Only appmsg type 57 is treated as a quote message.
    if appmsg.findtext("type", "").strip() != "57":
        return None

    title = appmsg.findtext("title", "") or ""
    refer = appmsg.find("refermsg")
    if refer is None:
        return None

    raw_type = refer.findtext("type", "0") or "0"
    try:
        ref_type = int(raw_type)
    except ValueError:
        # A non-numeric type must not abort the whole request.
        self.LOG.warning(f"引用消息类型无法解析: {raw_type!r}")
        return None
    ref_content = refer.findtext("content", "") or ""

    self.LOG.debug(
        f"引用解析成功: type={ref_type}, title_len={len(title)}, content_len={len(ref_content)}")

    return {
        "title": html.unescape(title),
        "ref_type": ref_type,
        "ref_content": html.unescape(ref_content)
    }
|
||
|
||
def _extract_quote_image_info(self, ref_content: str) -> Optional[Dict[str, str]]:
    """Pull the image download parameters out of a quoted image payload.

    Returns:
        ``{"aeskey", "url", "md5"}``, where ``url`` is the first attribute
        present among ``cdnmidimgurl``, ``cdnbigimgurl`` and ``cdnthumburl``
        and ``md5`` is ``""`` when absent; ``None`` when *ref_content* is
        empty or lacks the AES key or every URL attribute.
    """
    if not ref_content:
        return None

    def attribute(attr_name: str) -> Optional[str]:
        # Value of attr_name="..." in the payload, or None when not present.
        found = re.search(attr_name + r'="([^"]+)"', ref_content)
        return found.group(1) if found else None

    aes_key = attribute("aeskey")
    if aes_key is None:
        return None

    # Preference order: mid-size, then big, then thumbnail.
    candidates = (attribute(key) for key in ("cdnmidimgurl", "cdnbigimgurl", "cdnthumburl"))
    image_url = next((value for value in candidates if value is not None), None)
    if image_url is None:
        return None

    digest = attribute("md5")
    return {
        "aeskey": aes_key,
        "url": image_url,
        "md5": "" if digest is None else digest
    }
|
||
|
||
async def _download_quote_image(self, bot: WechatAPIClient, ref_content: str) -> Optional[str]:
    """Download a quoted image and store it under the ``temp`` directory.

    The download parameters come from ``_extract_quote_image_info``; the
    base64 payload returned by ``bot.download_image`` is decoded and written
    to ``<two levels above this file>/temp/dify_quote_<suffix>.jpg``.

    Returns:
        The path of the saved file, or ``None`` if any step fails.
    """
    image_info = self._extract_quote_image_info(ref_content)
    if not image_info:
        return None
    self.LOG.debug(
        f"准备下载引用图片: url_len={len(image_info['url'])}, aeskey_prefix={image_info['aeskey'][:6]}")

    try:
        base64_str = await bot.download_image(
            aeskey=image_info["aeskey"],
            cdnmidimgurl=image_info["url"]
        )
    except Exception as e:
        self.LOG.error(f"下载引用图片失败: {e}")
        return None

    if not base64_str:
        return None

    try:
        image_data = base64.b64decode(base64_str)
    except Exception as e:
        self.LOG.error(f"解码引用图片失败: {e}")
        return None

    # The md5 attribute comes from message XML, i.e. untrusted input. Only a
    # plain alphanumeric value may become part of the file name; anything
    # else (for example "../x") falls back to a timestamp.
    md5 = image_info.get("md5") or ""
    suffix = md5 if re.fullmatch(r"[0-9A-Za-z]{1,64}", md5) else str(int(time.time()))

    temp_dir = Path(__file__).resolve().parents[2] / "temp"
    file_path = temp_dir / f"dify_quote_{suffix}.jpg"

    try:
        os.makedirs(temp_dir, exist_ok=True)
        with open(file_path, "wb") as f:
            f.write(image_data)
    except Exception as e:
        self.LOG.error(f"保存引用图片失败: {e}")
        return None

    self.LOG.debug(f"引用图片已保存: {file_path}")
    return str(file_path)
|
||
|
||
async def _upload_file_to_dify(self, file_path: str, user_id: str) -> Optional[str]:
    """Upload a local file to ``{base_url}/files/upload``.

    Args:
        file_path: Path of the file to upload.
        user_id: Value sent in the ``user`` form field.

    Returns:
        The id found in the JSON reply (top-level ``id``, else ``data.id``
        or ``data.file_id``), or ``None`` when the file does not exist, the
        request fails or no id is present.
    """
    if not file_path or not os.path.isfile(file_path):
        return None

    self.LOG.debug(f"开始上传文件到Dify: {file_path}")
    url = f"{self.base_url}/files/upload"
    headers = {"Authorization": f"Bearer {self.api_key}"}
    proxy = self.http_proxy if self.http_proxy else None

    try:
        timeout = aiohttp.ClientTimeout(total=40)
        async with aiohttp.ClientSession(timeout=timeout) as session:
            # The file must stay open until the request body has been sent.
            with open(file_path, "rb") as f:
                form = aiohttp.FormData()
                form.add_field("file", f, filename=os.path.basename(file_path))
                form.add_field("user", user_id)
                # "async with" releases the connection on every exit path.
                async with session.post(url, headers=headers, data=form, proxy=proxy) as response:
                    if response.status not in (200, 201):
                        error_text = await response.text()
                        self.LOG.error(f"Dify上传失败: {response.status} {error_text}")
                        return None
                    status = response.status
                    resp_data = await response.json()

        if not isinstance(resp_data, dict):
            # Checked before .keys() is used: the body may be any JSON value.
            self.LOG.error(f"Dify上传响应格式异常: status={status}")
            return None

        self.LOG.debug(f"Dify上传成功: status={status}, keys={list(resp_data.keys())}")
        if resp_data.get("id"):
            return resp_data.get("id")
        data = resp_data.get("data", {})
        if isinstance(data, dict):
            return data.get("id") or data.get("file_id")
        return None

    except Exception as e:
        self.LOG.error(f"上传文件到Dify失败: {e}")
        return None
|
||
|
||
def _collect_output_text(self, outputs: Dict[str, Any]) -> str:
    """Concatenate every non-blank string in *outputs*, one nesting level deep.

    Strings are taken from the top-level values, from the values of nested
    dicts, and from list items (strings, or the values of dict items).
    """
    pieces: List[str] = []

    def keep(value: Any) -> None:
        if isinstance(value, str) and value.strip():
            pieces.append(value)

    for value in outputs.values():
        if isinstance(value, dict):
            for sub_value in value.values():
                keep(sub_value)
        elif isinstance(value, list):
            for item in value:
                if isinstance(item, dict):
                    for item_value in item.values():
                        keep(item_value)
                else:
                    keep(item)
        else:
            keep(value)
    return "".join(pieces)

async def _chat_with_dify(self, session_id: str, user_id: str, query: str,
                          files: Optional[List[Dict[str, Any]]] = None) -> Tuple[bool, Optional[str]]:
    """Run the Dify workflow for *query* and return its answer.

    Args:
        session_id: Session key (group id in a group chat, wxid otherwise).
        user_id: wxid of the user; sent as ``user`` and used for token stats.
        query: The user's question.
        files: Optional file descriptors passed through as ``files``.

    Returns:
        ``(True, answer)`` on success, where *answer* is the reply text, the
        local path of a downloaded image/video, or ``""`` when nothing
        usable was returned; ``(False, error_message)`` on failure.
    """
    # Drop expired sessions, then record activity for this one.
    self._cleanup_expired_conversations()
    self.last_activity[session_id] = time.time()
    history = self.conversations.setdefault(session_id, [])

    headers = {
        "Authorization": f"Bearer {self.api_key}",
        "Content-Type": "application/json",
        # NOTE(review): the request uses blocking mode and the reply is read
        # as JSON; this header is kept unchanged from the existing request.
        "Accept": "text/event-stream"
    }

    # Only the user's earlier questions go into the textual history.
    history_text = "\n".join(
        f"用户: {msg['content']}" for msg in history if msg["role"] == "user"
    ).strip()

    inputs_params = {
        "query": query,
        "conversation_id": session_id
    }
    if history_text:
        inputs_params["history"] = history_text

    if files is None:
        files = []
    self.LOG.debug(f"Dify请求准备: files={len(files)}")

    data = {
        "files": files,
        "user": user_id,
        "inputs": inputs_params,
        "response_mode": "blocking"
    }
    if history:
        data["conversation_history"] = history

    proxy = self.http_proxy if self.http_proxy else None
    url = f"{self.base_url}/workflows/run"

    self.LOG.info(f"发送请求到Dify API: {url}")
    self.LOG.info(f"请求数据: {json.dumps(data, ensure_ascii=False)}")

    try:
        timeout = aiohttp.ClientTimeout(total=40)
        async with aiohttp.ClientSession(timeout=timeout) as session:
            # "async with" releases the connection on every exit path.
            async with session.post(url, headers=headers, json=data, proxy=proxy) as response:
                if response.status != 200:
                    error_text = await response.text()
                    self.LOG.error(f"Dify API请求失败: {response.status} {error_text}")
                    return False, f"请求失败,状态码: {response.status}"
                response_data = await response.json()

        self.LOG.info(f"收到Dify API响应: {json.dumps(response_data, ensure_ascii=False)}")

        # "data", "outputs" and "total_tokens" may each be present but null.
        result = response_data.get("data") or {}
        outputs = result.get("outputs") or {}
        total_tokens = result.get("total_tokens") or 0

        answer = ""
        if outputs:
            if "result" in outputs and "type" in outputs:
                # Media result: download it and answer with the local path.
                if outputs["type"] in ("image", "video"):
                    media_path = await MediaDownloader().download_media(outputs["result"])
                    # Treat a failed download like "no answer", not None.
                    answer = media_path or ""
            elif "text" in outputs and isinstance(outputs["text"], str):
                answer = outputs["text"]
            else:
                # Legacy layout: gather whatever strings the outputs contain.
                answer = self._collect_output_text(outputs)

        if answer and not os.path.isfile(answer):
            answer = remove_reasoning_content(answer)
            answer = remove_trailing_content(answer)
            answer = remove_grok_render_tags(answer)
            answer = re.sub(r'\n{3,}', '\n\n', answer).strip()

        # Look the history up again: the session may have been reset or
        # expired while the request was in flight.
        history = self.conversations.setdefault(session_id, [])
        history.append({"role": "user", "content": query})
        history.append({"role": "assistant", "content": answer})

        # Keep at most max_history_length question/answer pairs.
        max_entries = self.max_history_length * 2
        if len(history) > max_entries:
            self.conversations[session_id] = history[-max_entries:]

        if total_tokens > 0:
            self.token_usage[user_id] = self.token_usage.get(user_id, 0) + total_tokens
            self.LOG.info(
                f"用户 {user_id} 本次消耗 {total_tokens} tokens,累计 {self.token_usage[user_id]} tokens")

        return True, answer

    except Exception as e:
        # exception() keeps the traceback that error() would drop.
        self.LOG.exception(f"处理Dify响应时出错: {str(e)}")
        return False, "处理响应时出错"
|
||
|
||
def _cleanup_expired_conversations(self) -> None:
    """Forget sessions idle for longer than ``conversation_timeout`` seconds."""
    now = time.time()
    stale = [
        session_id
        for session_id, last_seen in self.last_activity.items()
        if now - last_seen > self.conversation_timeout
    ]

    for session_id in stale:
        # A session may have an activity stamp but no stored history.
        self.conversations.pop(session_id, None)
        del self.last_activity[session_id]

    if stale:
        self.LOG.info(f"已清理 {len(stale)} 个过期会话")
|
||
|
||
def get_token_usage_report(self) -> str:
    """Build a text report of token usage per user, highest usage first.

    Returns:
        A header line, one line per user and a grand total, or a fixed
        placeholder message when nothing has been recorded.
    """
    if not self.token_usage:
        return "暂无token使用记录"

    ranked = sorted(self.token_usage.items(), key=lambda x: x[1], reverse=True)
    lines = ["Token使用情况统计:"]
    lines.extend(f"用户 {user_id}: {tokens} tokens" for user_id, tokens in ranked)

    grand_total = sum(self.token_usage.values())
    # A blank line separates the per-user lines from the total.
    return "\n".join(lines) + f"\n\n总计: {grand_total} tokens"
|
||
|
||
def reset_conversation(self, session_id: str) -> bool:
    """Drop the stored context of one session.

    Returns:
        ``True`` if the session had stored history (its activity timestamp
        is removed as well), ``False`` otherwise; in that case nothing is
        changed.
    """
    if session_id not in self.conversations:
        return False

    del self.conversations[session_id]
    self.last_activity.pop(session_id, None)
    return True
|
||
|
||
def reset_all_conversations(self) -> None:
    """Drop the stored context and activity timestamps of every session."""
    # Cleared in place, so existing references to either dict see the change.
    for store in (self.conversations, self.last_activity):
        store.clear()
|
||
|
||
def check_file_type(self, file_path) -> int:
    """Classify an existing file by its extension.

    Returns:
        ``1`` if the lower-cased extension is in ``IMAGE_EXTENSIONS``,
        ``2`` if it is in ``VIDEO_EXTENSIONS``, and ``0`` otherwise or when
        *file_path* is not an existing file.
    """
    if not os.path.isfile(file_path):
        return 0

    extension = os.path.splitext(file_path)[1].lower()
    if extension in IMAGE_EXTENSIONS:
        return 1
    if extension in VIDEO_EXTENSIONS:
        return 2
    return 0
|
||
|
||
async def _get_first_frame(self, video_path: str, output_path: str) -> Optional[str]:
    """Extract the first frame of a video and save it as an image.

    The OpenCV calls run in a worker thread via ``asyncio.to_thread``.

    :param video_path: path of the video file
    :param output_path: path of the image to write
    :return: absolute path of the written image, or None on failure
    """

    def extract_frame() -> Optional[str]:
        cap = cv2.VideoCapture(video_path)
        try:
            if not cap.isOpened():
                self.LOG.error(f"无法打开视频: {video_path}")
                return None

            ret, frame = cap.read()
            if not ret:
                self.LOG.error("无法读取视频帧")
                return None

            try:
                written = cv2.imwrite(output_path, frame)
            except Exception as e:
                self.LOG.error(f"保存首帧图片失败: {e}")
                return None
            if not written:
                # imwrite reports most failures through its return value
                # rather than by raising.
                self.LOG.error(f"保存首帧图片失败: {output_path}")
                return None

            self.LOG.info(f"首帧已保存为: {output_path}")
            return os.path.abspath(output_path)
        finally:
            # Release the capture on every exit path.
            cap.release()

    try:
        self.LOG.info(f"开始提取视频首帧: {video_path}")
        return await asyncio.to_thread(extract_frame)
    except Exception as e:
        self.LOG.error(f"提取视频首帧时出错: {e}")
        return None
|