3659 lines
158 KiB
Python
3659 lines
158 KiB
Python
"""
|
||
AI 聊天插件
|
||
|
||
支持自定义模型、API 和人设
|
||
支持 Redis 存储对话历史和限流
|
||
"""
|
||
|
||
import asyncio
|
||
import tomllib
|
||
import aiohttp
|
||
import sqlite3
|
||
from pathlib import Path
|
||
from datetime import datetime
|
||
from loguru import logger
|
||
from utils.plugin_base import PluginBase
|
||
from utils.decorators import on_text_message, on_quote_message, on_image_message, on_emoji_message
|
||
from utils.redis_cache import get_cache
|
||
import xml.etree.ElementTree as ET
|
||
import base64
|
||
import uuid
|
||
|
||
# 可选导入代理支持
|
||
try:
|
||
from aiohttp_socks import ProxyConnector
|
||
PROXY_SUPPORT = True
|
||
except ImportError:
|
||
PROXY_SUPPORT = False
|
||
logger.warning("aiohttp_socks 未安装,代理功能将不可用")
|
||
|
||
|
||
class AIChat(PluginBase):
|
||
"""AI 聊天插件"""
|
||
|
||
# 插件元数据
|
||
description = "AI 聊天插件,支持自定义模型和人设"
|
||
author = "ShiHao"
|
||
version = "1.0.0"
|
||
|
||
def __init__(self):
    """Create the plugin with empty state.

    Only in-memory placeholders are set up here; ``async_init`` fills in
    the configuration, the history directory and the database path.
    """
    super().__init__()

    # Filled in later by async_init().
    self.config = None
    self.history_dir = None  # directory holding history records
    self.persistent_memory_db = None  # path of the persistent-memory database
    self.system_prompt = ""

    # Per-conversation state.
    self.memory = dict()  # conversation memory: {chat_id: [messages]}
    self.history_locks = dict()  # one lock per conversation

    # Image-description pipeline.
    self.image_desc_workers = list()  # worker tasks
    self.image_desc_queue = asyncio.Queue()  # queue of image-description jobs
|
||
|
||
async def async_init(self):
    """Asynchronously initialise the plugin.

    Loads ``config.toml`` and the persona prompt from the plugin directory,
    logs the proxy endpoint when a proxy is enabled, creates the history
    directory when history is enabled, starts two image-description worker
    tasks and initialises the persistent-memory database.
    """
    plugin_root = Path(__file__).parent

    # Configuration: tomllib needs a binary file handle.
    with open(plugin_root / "config.toml", "rb") as config_file:
        self.config = tomllib.load(config_file)

    # Persona prompt, with a built-in default when the file is missing.
    prompt_file = self.config["prompt"]["system_prompt_file"]
    persona_path = plugin_root / "prompts" / prompt_file
    if not persona_path.exists():
        logger.warning(f"人设文件不存在: {prompt_file},使用默认人设")
        self.system_prompt = "你是一个友好的 AI 助手。"
    else:
        with open(persona_path, "r", encoding="utf-8") as persona_file:
            self.system_prompt = persona_file.read().strip()
        logger.success(f"已加载人设: {prompt_file}")

    # Proxy: only reported here, the connector is built per request elsewhere.
    proxy_settings = self.config.get("proxy", {})
    if proxy_settings.get("enabled", False):
        scheme = proxy_settings.get("type", "socks5")
        host = proxy_settings.get("host", "127.0.0.1")
        port = proxy_settings.get("port", 7890)
        logger.info(f"AI 聊天插件已启用代理: {scheme}://{host}:{port}")

    # History directory (enabled unless configured otherwise).
    history_settings = self.config.get("history", {})
    if history_settings.get("enabled", True):
        self.history_dir = plugin_root / history_settings.get("history_dir", "history")
        self.history_dir.mkdir(exist_ok=True)
        logger.info(f"历史记录目录: {self.history_dir}")

    # Two concurrent image-description workers.
    for _ in range(2):
        self.image_desc_workers.append(
            asyncio.create_task(self._image_desc_worker())
        )
    logger.info("已启动 2 个图片描述工作协程")

    # Persistent-memory database.
    self._init_persistent_memory_db()

    logger.info(f"AI 聊天插件已加载,模型: {self.config['api']['model']}")
|
||
|
||
def _init_persistent_memory_db(self):
    """Create the persistent-memory SQLite database and its schema.

    The database lives at ``data/persistent_memory.db`` next to this file;
    its path is stored in ``self.persistent_memory_db``. The ``memories``
    table and the ``idx_chat_id`` index are created only when missing, so
    calling this repeatedly is harmless.

    Raises:
        sqlite3.Error: if the database cannot be opened or the schema
            statements fail.
    """
    db_dir = Path(__file__).parent / "data"
    db_dir.mkdir(exist_ok=True)
    self.persistent_memory_db = db_dir / "persistent_memory.db"

    conn = sqlite3.connect(self.persistent_memory_db)
    try:
        conn.execute("""
            CREATE TABLE IF NOT EXISTS memories (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                chat_id TEXT NOT NULL,
                chat_type TEXT NOT NULL,
                user_wxid TEXT NOT NULL,
                user_nickname TEXT,
                content TEXT NOT NULL,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """)
        conn.execute("CREATE INDEX IF NOT EXISTS idx_chat_id ON memories(chat_id)")
        conn.commit()
    finally:
        # Close even when a statement fails so the handle is never leaked.
        conn.close()
    logger.info(f"持久记忆数据库已初始化: {self.persistent_memory_db}")
|
||
|
||
def _add_persistent_memory(self, chat_id: str, chat_type: str, user_wxid: str,
|
||
user_nickname: str, content: str) -> int:
|
||
"""添加持久记忆,返回记忆ID"""
|
||
conn = sqlite3.connect(self.persistent_memory_db)
|
||
cursor = conn.cursor()
|
||
cursor.execute("""
|
||
INSERT INTO memories (chat_id, chat_type, user_wxid, user_nickname, content)
|
||
VALUES (?, ?, ?, ?, ?)
|
||
""", (chat_id, chat_type, user_wxid, user_nickname, content))
|
||
memory_id = cursor.lastrowid
|
||
conn.commit()
|
||
conn.close()
|
||
return memory_id
|
||
|
||
def _get_persistent_memories(self, chat_id: str) -> list:
|
||
"""获取指定会话的所有持久记忆"""
|
||
conn = sqlite3.connect(self.persistent_memory_db)
|
||
cursor = conn.cursor()
|
||
cursor.execute("""
|
||
SELECT id, user_nickname, content, created_at
|
||
FROM memories
|
||
WHERE chat_id = ?
|
||
ORDER BY created_at ASC
|
||
""", (chat_id,))
|
||
rows = cursor.fetchall()
|
||
conn.close()
|
||
return [{"id": r[0], "nickname": r[1], "content": r[2], "time": r[3]} for r in rows]
|
||
|
||
def _delete_persistent_memory(self, chat_id: str, memory_id: int) -> bool:
|
||
"""删除指定的持久记忆"""
|
||
conn = sqlite3.connect(self.persistent_memory_db)
|
||
cursor = conn.cursor()
|
||
cursor.execute("DELETE FROM memories WHERE id = ? AND chat_id = ?", (memory_id, chat_id))
|
||
deleted = cursor.rowcount > 0
|
||
conn.commit()
|
||
conn.close()
|
||
return deleted
|
||
|
||
def _clear_persistent_memories(self, chat_id: str) -> int:
|
||
"""清空指定会话的所有持久记忆,返回删除数量"""
|
||
conn = sqlite3.connect(self.persistent_memory_db)
|
||
cursor = conn.cursor()
|
||
cursor.execute("DELETE FROM memories WHERE chat_id = ?", (chat_id,))
|
||
deleted_count = cursor.rowcount
|
||
conn.commit()
|
||
conn.close()
|
||
return deleted_count
|
||
|
||
def _get_chat_id(self, from_wxid: str, sender_wxid: str = None, is_group: bool = False) -> str:
|
||
"""获取会话ID"""
|
||
if is_group:
|
||
# 群聊使用 "群ID:用户ID" 组合,确保每个用户有独立的对话记忆
|
||
user_wxid = sender_wxid or from_wxid
|
||
return f"{from_wxid}:{user_wxid}"
|
||
else:
|
||
return sender_wxid or from_wxid # 私聊使用用户ID
|
||
|
||
async def _get_user_nickname(self, bot, from_wxid: str, user_wxid: str, is_group: bool) -> str:
|
||
"""
|
||
获取用户昵称,优先使用 Redis 缓存
|
||
|
||
Args:
|
||
bot: WechatHookClient 实例
|
||
from_wxid: 消息来源(群聊ID或私聊用户ID)
|
||
user_wxid: 用户wxid
|
||
is_group: 是否群聊
|
||
|
||
Returns:
|
||
用户昵称
|
||
"""
|
||
if not is_group:
|
||
return ""
|
||
|
||
nickname = ""
|
||
|
||
# 1. 优先从 Redis 缓存获取
|
||
redis_cache = get_cache()
|
||
if redis_cache and redis_cache.enabled:
|
||
cached_info = redis_cache.get_user_basic_info(from_wxid, user_wxid)
|
||
if cached_info and cached_info.get("nickname"):
|
||
logger.debug(f"[缓存命中] 用户昵称: {user_wxid} -> {cached_info['nickname']}")
|
||
return cached_info["nickname"]
|
||
|
||
# 2. 缓存未命中,调用 API 获取
|
||
try:
|
||
user_info = await bot.get_user_info_in_chatroom(from_wxid, user_wxid)
|
||
if user_info and user_info.get("nickName", {}).get("string"):
|
||
nickname = user_info["nickName"]["string"]
|
||
# 存入缓存
|
||
if redis_cache and redis_cache.enabled:
|
||
redis_cache.set_user_info(from_wxid, user_wxid, user_info)
|
||
logger.debug(f"[已缓存] 用户昵称: {user_wxid} -> {nickname}")
|
||
return nickname
|
||
except Exception as e:
|
||
logger.warning(f"API获取用户昵称失败: {e}")
|
||
|
||
# 3. 从 MessageLogger 数据库查询
|
||
if not nickname:
|
||
try:
|
||
from plugins.MessageLogger.main import MessageLogger
|
||
msg_logger = MessageLogger.get_instance()
|
||
if msg_logger:
|
||
with msg_logger.get_db_connection() as conn:
|
||
with conn.cursor() as cursor:
|
||
cursor.execute(
|
||
"SELECT nickname FROM messages WHERE sender_wxid = %s AND nickname != '' ORDER BY create_time DESC LIMIT 1",
|
||
(user_wxid,)
|
||
)
|
||
result = cursor.fetchone()
|
||
if result:
|
||
nickname = result[0]
|
||
except Exception as e:
|
||
logger.debug(f"从数据库获取昵称失败: {e}")
|
||
|
||
# 4. 最后降级使用 wxid
|
||
if not nickname:
|
||
nickname = user_wxid or "未知用户"
|
||
|
||
return nickname
|
||
|
||
def _check_rate_limit(self, user_wxid: str) -> tuple:
|
||
"""
|
||
检查用户是否超过限流
|
||
|
||
Args:
|
||
user_wxid: 用户wxid
|
||
|
||
Returns:
|
||
(是否允许, 剩余次数, 重置时间秒数)
|
||
"""
|
||
rate_limit_config = self.config.get("rate_limit", {})
|
||
if not rate_limit_config.get("enabled", True):
|
||
return (True, 999, 0)
|
||
|
||
redis_cache = get_cache()
|
||
if not redis_cache or not redis_cache.enabled:
|
||
return (True, 999, 0) # Redis 不可用时不限流
|
||
|
||
limit = rate_limit_config.get("ai_chat_limit", 20)
|
||
window = rate_limit_config.get("ai_chat_window", 60)
|
||
|
||
return redis_cache.check_rate_limit(user_wxid, limit, window, "ai_chat")
|
||
|
||
def _add_to_memory(self, chat_id: str, role: str, content, image_base64: str = None):
|
||
"""
|
||
添加消息到记忆
|
||
|
||
Args:
|
||
chat_id: 会话ID
|
||
role: 角色 (user/assistant)
|
||
content: 消息内容(可以是字符串或列表)
|
||
image_base64: 可选的图片base64数据
|
||
"""
|
||
if not self.config.get("memory", {}).get("enabled", False):
|
||
return
|
||
|
||
# 如果有图片,构建多模态内容
|
||
if image_base64:
|
||
message_content = [
|
||
{"type": "text", "text": content if isinstance(content, str) else ""},
|
||
{"type": "image_url", "image_url": {"url": image_base64}}
|
||
]
|
||
else:
|
||
message_content = content
|
||
|
||
# 优先使用 Redis 存储
|
||
redis_config = self.config.get("redis", {})
|
||
if redis_config.get("use_redis_history", True):
|
||
redis_cache = get_cache()
|
||
if redis_cache and redis_cache.enabled:
|
||
ttl = redis_config.get("chat_history_ttl", 86400)
|
||
redis_cache.add_chat_message(chat_id, role, message_content, ttl=ttl)
|
||
# 裁剪历史
|
||
max_messages = self.config["memory"]["max_messages"]
|
||
redis_cache.trim_chat_history(chat_id, max_messages)
|
||
return
|
||
|
||
# 降级到内存存储
|
||
if chat_id not in self.memory:
|
||
self.memory[chat_id] = []
|
||
|
||
self.memory[chat_id].append({"role": role, "content": message_content})
|
||
|
||
# 限制记忆长度
|
||
max_messages = self.config["memory"]["max_messages"]
|
||
if len(self.memory[chat_id]) > max_messages:
|
||
self.memory[chat_id] = self.memory[chat_id][-max_messages:]
|
||
|
||
def _get_memory_messages(self, chat_id: str) -> list:
|
||
"""获取记忆中的消息"""
|
||
if not self.config.get("memory", {}).get("enabled", False):
|
||
return []
|
||
|
||
# 优先从 Redis 获取
|
||
redis_config = self.config.get("redis", {})
|
||
if redis_config.get("use_redis_history", True):
|
||
redis_cache = get_cache()
|
||
if redis_cache and redis_cache.enabled:
|
||
max_messages = self.config["memory"]["max_messages"]
|
||
return redis_cache.get_chat_history(chat_id, max_messages)
|
||
|
||
# 降级到内存
|
||
return self.memory.get(chat_id, [])
|
||
|
||
def _clear_memory(self, chat_id: str):
|
||
"""清空指定会话的记忆"""
|
||
# 清空 Redis
|
||
redis_config = self.config.get("redis", {})
|
||
if redis_config.get("use_redis_history", True):
|
||
redis_cache = get_cache()
|
||
if redis_cache and redis_cache.enabled:
|
||
redis_cache.clear_chat_history(chat_id)
|
||
|
||
# 同时清空内存
|
||
if chat_id in self.memory:
|
||
del self.memory[chat_id]
|
||
|
||
async def _download_and_encode_image(self, bot, cdnurl: str, aeskey: str) -> str:
    """Download a CDN image and return it as a base64 ``data:`` URI.

    A Redis media cache is consulted first. On a miss the image is fetched
    through ``bot.cdn_download`` into a temporary file (``file_type=2``
    first, then ``file_type=1``), encoded, cached for 300 seconds and the
    temporary file is removed again, whatever the outcome.

    Args:
        bot: Client exposing ``cdn_download``.
        cdnurl: CDN URL of the image.
        aeskey: AES key passed to ``cdn_download``.

    Returns:
        ``"data:image/jpeg;base64,..."`` on success, or an empty string when
        the download fails, the file never appears or stays empty, or any
        error occurs (errors are logged, not raised).
    """
    save_path = None
    try:
        # 1. Redis cache first.
        from utils.redis_cache import RedisCache

        media_key = None
        redis_cache = get_cache()
        if redis_cache and redis_cache.enabled:
            media_key = RedisCache.generate_media_key(cdnurl, aeskey)
            if media_key:
                cached_data = redis_cache.get_cached_media(media_key, "image")
                if cached_data:
                    logger.debug(f"[缓存命中] 图片从 Redis 获取: {media_key[:20]}...")
                    return cached_data

        # 2. Cache miss: download to a uniquely named temporary file.
        logger.debug("[缓存未命中] 开始下载图片...")
        temp_dir = Path(__file__).parent / "temp"
        temp_dir.mkdir(exist_ok=True)
        save_path = (temp_dir / f"temp_{uuid.uuid4().hex[:8]}.jpg").resolve()

        success = await bot.cdn_download(cdnurl, aeskey, str(save_path), file_type=2)
        if not success:
            success = await bot.cdn_download(cdnurl, aeskey, str(save_path), file_type=1)
        if not success:
            return ""

        def file_ready() -> bool:
            return save_path.exists() and save_path.stat().st_size > 0

        # The download may report success before the file is fully written:
        # poll for up to 10 seconds.
        for _ in range(20):
            if file_ready():
                break
            await asyncio.sleep(0.5)

        # A missing or still-empty file is a failure; encoding zero bytes
        # would produce (and cache) a useless but truthy data URI.
        if not file_ready():
            logger.warning("图片下载后文件不存在或为空")
            return ""

        image_data = base64.b64encode(save_path.read_bytes()).decode()
        base64_result = f"data:image/jpeg;base64,{image_data}"

        # 3. Store in Redis for later requests.
        if redis_cache and redis_cache.enabled and media_key:
            redis_cache.cache_media(media_key, base64_result, "image", ttl=300)
            logger.debug(f"[已缓存] 图片缓存到 Redis: {media_key[:20]}...")

        return base64_result
    except Exception as e:
        logger.error(f"下载图片失败: {e}")
        return ""
    finally:
        # Always remove the temporary file, including on failure paths.
        if save_path is not None:
            try:
                save_path.unlink(missing_ok=True)
            except OSError as e:
                logger.debug(f"清理临时图片失败: {e}")
|
||
|
||
async def _download_emoji_and_encode(self, cdn_url: str, max_retries: int = 3) -> str:
    """Download an emoji over HTTP and return it as a base64 ``data:`` URI.

    A Redis media cache is consulted first. On a miss the emoji is fetched
    with aiohttp (through the configured proxy when one is enabled and
    ``aiohttp_socks`` is installed), retried up to ``max_retries`` times,
    and the result is cached for 300 seconds.

    Args:
        cdn_url: Emoji URL; ``&amp;`` entities in it are decoded to ``&``.
        max_retries: Maximum number of download attempts.

    Returns:
        ``"data:image/gif;base64,..."`` on success, or an empty string when
        every attempt failed.
    """
    # Undo HTML-entity escaping of the ampersands in the URL.
    cdn_url = cdn_url.replace("&amp;", "&")

    # 1. Redis cache first.
    from utils.redis_cache import RedisCache

    redis_cache = get_cache()
    media_key = RedisCache.generate_media_key(cdnurl=cdn_url)
    use_cache = bool(redis_cache and redis_cache.enabled and media_key)
    if use_cache:
        cached_data = redis_cache.get_cached_media(media_key, "emoji")
        if cached_data:
            logger.debug(f"[缓存命中] 表情包从 Redis 获取: {media_key[:20]}...")
            return cached_data

    # 2. Cache miss: download straight into memory.
    logger.debug("[缓存未命中] 开始下载表情包...")

    # The proxy URL does not change between attempts, so build it once.
    proxy_url = None
    proxy_config = self.config.get("proxy", {})
    if PROXY_SUPPORT and proxy_config.get("enabled", False):
        proxy_type = proxy_config.get("type", "socks5").upper()
        proxy_host = proxy_config.get("host", "127.0.0.1")
        proxy_port = proxy_config.get("port", 7890)
        proxy_username = proxy_config.get("username")
        proxy_password = proxy_config.get("password")
        if proxy_username and proxy_password:
            proxy_url = f"{proxy_type}://{proxy_username}:{proxy_password}@{proxy_host}:{proxy_port}"
        else:
            proxy_url = f"{proxy_type}://{proxy_host}:{proxy_port}"

    last_error = None

    for attempt in range(max_retries):
        try:
            # Each retry gets a longer overall timeout.
            timeout = aiohttp.ClientTimeout(total=30 + attempt * 15)

            # A connector belongs to one session, so create it per attempt.
            connector = None
            if proxy_url:
                try:
                    connector = ProxyConnector.from_url(proxy_url)
                except Exception as e:
                    logger.warning(f"代理配置失败,将直连: {e}")

            async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
                async with session.get(cdn_url) as response:
                    if response.status != 200:
                        last_error = f"HTTP {response.status}"
                        logger.warning(f"表情包下载失败,状态码: {response.status},重试 {attempt + 1}/{max_retries}")
                    else:
                        content = await response.read()
                        if not content:
                            last_error = "响应内容为空"
                            logger.warning(f"表情包下载内容为空,重试 {attempt + 1}/{max_retries}")
                        else:
                            image_data = base64.b64encode(content).decode()
                            logger.debug(f"表情包下载成功,大小: {len(content)} 字节")
                            base64_result = f"data:image/gif;base64,{image_data}"

                            # 3. Store in Redis for later requests.
                            if use_cache:
                                redis_cache.cache_media(media_key, base64_result, "emoji", ttl=300)
                                logger.debug(f"[已缓存] 表情包缓存到 Redis: {media_key[:20]}...")

                            return base64_result

        except asyncio.TimeoutError:
            last_error = "请求超时"
            logger.warning(f"表情包下载超时,重试 {attempt + 1}/{max_retries}")
        except aiohttp.ClientError as e:
            last_error = str(e)
            logger.warning(f"表情包下载网络错误: {e},重试 {attempt + 1}/{max_retries}")
        except Exception as e:
            last_error = str(e)
            logger.warning(f"表情包下载异常: {e},重试 {attempt + 1}/{max_retries}")

        # Linearly growing pause (1s, 2s, ...) before the next attempt.
        if attempt < max_retries - 1:
            await asyncio.sleep(1 * (attempt + 1))

    logger.error(f"表情包下载失败,已重试 {max_retries} 次: {last_error}")
    return ""
|
||
|
||
async def _generate_image_description(self, image_base64: str, prompt: str, config: dict) -> str:
    """Ask the Gemini streaming API to describe an image.

    Args:
        image_base64: Image as raw base64 or as a ``data:<mime>;base64,...``
            URI; raw base64 is sent as ``image/jpeg``.
        prompt: Instruction text sent together with the image.
        config: Image-description settings; ``model`` overrides the main
            API model and ``max_tokens`` (default 1000) caps the output.

    Returns:
        The description text, or an empty string on any failure (errors are
        logged, not raised).
    """
    import json
    try:
        api_config = self.config["api"]
        description_model = config.get("model", api_config["model"])
        api_url = api_config.get("gemini_url", "https://api.functen.cn/v1beta/models")

        # Split a data URI into MIME type and payload.
        image_data = image_base64
        mime_type = "image/jpeg"
        if image_data.startswith("data:"):
            mime_type = image_data.split(";")[0].split(":")[1]
            image_data = image_data.split(",", 1)[1]

        # Gemini-format request.
        full_url = f"{api_url}/{description_model}:streamGenerateContent?alt=sse"

        payload = {
            "contents": [
                {
                    "parts": [
                        {"text": prompt},
                        {
                            "inline_data": {
                                "mime_type": mime_type,
                                "data": image_data
                            }
                        }
                    ]
                }
            ],
            "generationConfig": {
                "maxOutputTokens": config.get("max_tokens", 1000)
            }
        }

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_config['api_key']}"
        }

        timeout = aiohttp.ClientTimeout(total=api_config.get("timeout", 120))

        # Optional proxy; falls back to a direct connection on failure.
        connector = None
        proxy_config = self.config.get("proxy", {})
        if proxy_config.get("enabled", False) and PROXY_SUPPORT:
            proxy_type = proxy_config.get("type", "socks5").upper()
            proxy_host = proxy_config.get("host", "127.0.0.1")
            proxy_port = proxy_config.get("port", 7890)
            proxy_username = proxy_config.get("username")
            proxy_password = proxy_config.get("password")

            if proxy_username and proxy_password:
                proxy_url = f"{proxy_type}://{proxy_username}:{proxy_password}@{proxy_host}:{proxy_port}"
            else:
                proxy_url = f"{proxy_type}://{proxy_host}:{proxy_port}"

            try:
                connector = ProxyConnector.from_url(proxy_url)
            except Exception as e:
                logger.warning(f"代理配置失败,将直连: {e}")

        async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
            async with session.post(full_url, json=payload, headers=headers) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    logger.error(f"图片描述 API 返回错误: {resp.status}, {error_text[:200]}")
                    return ""

                # Consume the SSE stream; only "data: " lines carry payload.
                description = ""
                async for line in resp.content:
                    line = line.decode('utf-8').strip()
                    if not line or not line.startswith("data: "):
                        continue

                    try:
                        data = json.loads(line[6:])
                        candidates = data.get("candidates", [])
                        if candidates:
                            # "content" may be missing or null on some chunks.
                            parts = (candidates[0].get("content") or {}).get("parts") or []
                            for part in parts:
                                if "text" in part:
                                    description += part["text"]
                    except (ValueError, AttributeError, TypeError, LookupError) as e:
                        # A malformed chunk is skipped, the rest of the
                        # stream is still used.
                        logger.debug(f"跳过无法解析的图片描述数据块: {e}")

                logger.debug(f"图片描述生成成功: {description[:100]}...")
                return description.strip()

    except Exception as e:
        logger.error(f"生成图片描述失败: {e}")
        import traceback
        logger.error(f"详细错误: {traceback.format_exc()}")
        return ""
|
||
|
||
def _collect_tools(self):
    """Gather the LLM tools of every plugin, filtered by the ``tools`` config.

    ``tools.mode`` selects the filter: ``"whitelist"`` keeps only tools
    named in ``tools.whitelist``, ``"blacklist"`` drops tools named in
    ``tools.blacklist``, and any other value (default ``"all"``) keeps
    everything.

    Returns:
        The list of tool definitions that passed the filter.
    """
    from utils.plugin_manager import PluginManager

    settings = self.config.get("tools", {})
    mode = settings.get("mode", "all")
    allowed = set(settings.get("whitelist", []))
    blocked = set(settings.get("blacklist", []))

    selected = []
    for plugin in PluginManager().plugins.values():
        if not hasattr(plugin, 'get_llm_tools'):
            continue

        for tool in plugin.get_llm_tools() or ():
            name = tool.get("function", {}).get("name", "")

            if mode == "whitelist":
                if name not in allowed:
                    continue
                selected.append(tool)
                logger.debug(f"[白名单] 启用工具: {name}")
            elif mode == "blacklist":
                if name in blocked:
                    logger.debug(f"[黑名单] 禁用工具: {name}")
                    continue
                selected.append(tool)
            else:
                # "all": no filtering.
                selected.append(tool)

    return selected
|
||
|
||
# ==================== Gemini API 格式转换方法 ====================
|
||
|
||
def _convert_tools_to_gemini(self, openai_tools: list) -> list:
|
||
"""
|
||
将 OpenAI 格式的工具定义转换为 Gemini 格式
|
||
|
||
OpenAI: [{"type": "function", "function": {"name": ..., "parameters": ...}}]
|
||
Gemini: [{"function_declarations": [{"name": ..., "parameters": ...}]}]
|
||
"""
|
||
if not openai_tools:
|
||
return []
|
||
|
||
function_declarations = []
|
||
for tool in openai_tools:
|
||
if tool.get("type") == "function":
|
||
func = tool.get("function", {})
|
||
function_declarations.append({
|
||
"name": func.get("name", ""),
|
||
"description": func.get("description", ""),
|
||
"parameters": func.get("parameters", {})
|
||
})
|
||
|
||
if function_declarations:
|
||
return [{"function_declarations": function_declarations}]
|
||
return []
|
||
|
||
def _build_gemini_contents(self, system_content: str, history_messages: list,
|
||
current_message: dict, is_group: bool = False) -> list:
|
||
"""
|
||
构建 Gemini API 的 contents 格式
|
||
|
||
Args:
|
||
system_content: 系统提示词(包含人设、时间、持久记忆等)
|
||
history_messages: 历史消息列表
|
||
current_message: 当前用户消息 {"text": str, "media": optional}
|
||
is_group: 是否群聊
|
||
|
||
Returns:
|
||
Gemini contents 格式的列表
|
||
"""
|
||
contents = []
|
||
|
||
# Gemini 没有 system role,将系统提示放在第一条 user 消息中
|
||
# 然后用一条简短的 model 回复来"确认"
|
||
system_parts = [{"text": f"[系统指令]\n{system_content}\n\n请按照以上指令进行对话。"}]
|
||
contents.append({"role": "user", "parts": system_parts})
|
||
contents.append({"role": "model", "parts": [{"text": "好的,我会按照指令进行对话。"}]})
|
||
|
||
# 添加历史消息
|
||
for msg in history_messages:
|
||
gemini_msg = self._convert_message_to_gemini(msg, is_group)
|
||
if gemini_msg:
|
||
contents.append(gemini_msg)
|
||
|
||
# 添加当前用户消息
|
||
current_parts = []
|
||
if current_message.get("text"):
|
||
current_parts.append({"text": current_message["text"]})
|
||
|
||
# 添加媒体内容(图片/视频)
|
||
if current_message.get("image_base64"):
|
||
image_data = current_message["image_base64"]
|
||
# 去除 data:image/xxx;base64, 前缀
|
||
if image_data.startswith("data:"):
|
||
mime_type = image_data.split(";")[0].split(":")[1]
|
||
image_data = image_data.split(",", 1)[1]
|
||
else:
|
||
mime_type = "image/jpeg"
|
||
current_parts.append({
|
||
"inline_data": {
|
||
"mime_type": mime_type,
|
||
"data": image_data
|
||
}
|
||
})
|
||
|
||
if current_message.get("video_base64"):
|
||
video_data = current_message["video_base64"]
|
||
# 去除 data:video/xxx;base64, 前缀
|
||
if video_data.startswith("data:"):
|
||
video_data = video_data.split(",", 1)[1]
|
||
current_parts.append({
|
||
"inline_data": {
|
||
"mime_type": "video/mp4",
|
||
"data": video_data
|
||
}
|
||
})
|
||
|
||
if current_parts:
|
||
contents.append({"role": "user", "parts": current_parts})
|
||
|
||
return contents
|
||
|
||
def _convert_message_to_gemini(self, msg: dict, is_group: bool = False) -> dict:
|
||
"""
|
||
将单条历史消息转换为 Gemini 格式
|
||
|
||
支持的输入格式:
|
||
1. 群聊历史: {"nickname": str, "content": str|list}
|
||
2. 私聊记忆: {"role": "user"|"assistant", "content": str|list}
|
||
"""
|
||
parts = []
|
||
|
||
# 群聊历史格式
|
||
if "nickname" in msg:
|
||
nickname = msg.get("nickname", "")
|
||
content = msg.get("content", "")
|
||
|
||
if isinstance(content, list):
|
||
# 多模态内容
|
||
for item in content:
|
||
if item.get("type") == "text":
|
||
text = item.get("text", "")
|
||
parts.append({"text": f"[{nickname}] {text}" if nickname else text})
|
||
elif item.get("type") == "image_url":
|
||
image_url = item.get("image_url", {}).get("url", "")
|
||
if image_url.startswith("data:"):
|
||
mime_type = image_url.split(";")[0].split(":")[1]
|
||
image_data = image_url.split(",", 1)[1]
|
||
parts.append({
|
||
"inline_data": {
|
||
"mime_type": mime_type,
|
||
"data": image_data
|
||
}
|
||
})
|
||
else:
|
||
# 纯文本
|
||
parts.append({"text": f"[{nickname}] {content}" if nickname else content})
|
||
|
||
# 群聊历史都作为 user 消息(因为是多人对话记录)
|
||
return {"role": "user", "parts": parts} if parts else None
|
||
|
||
# 私聊记忆格式
|
||
elif "role" in msg:
|
||
role = msg.get("role", "user")
|
||
content = msg.get("content", "")
|
||
|
||
# 转换角色名
|
||
gemini_role = "model" if role == "assistant" else "user"
|
||
|
||
if isinstance(content, list):
|
||
for item in content:
|
||
if item.get("type") == "text":
|
||
parts.append({"text": item.get("text", "")})
|
||
elif item.get("type") == "image_url":
|
||
image_url = item.get("image_url", {}).get("url", "")
|
||
if image_url.startswith("data:"):
|
||
mime_type = image_url.split(";")[0].split(":")[1]
|
||
image_data = image_url.split(",", 1)[1]
|
||
parts.append({
|
||
"inline_data": {
|
||
"mime_type": mime_type,
|
||
"data": image_data
|
||
}
|
||
})
|
||
else:
|
||
parts.append({"text": content})
|
||
|
||
return {"role": gemini_role, "parts": parts} if parts else None
|
||
|
||
return None
|
||
|
||
def _parse_gemini_tool_calls(self, response_parts: list) -> list:
|
||
"""
|
||
从 Gemini 响应中解析工具调用
|
||
|
||
Gemini 格式: {"functionCall": {"name": "...", "args": {...}}}
|
||
转换为内部格式: {"id": "...", "function": {"name": "...", "arguments": "..."}}
|
||
"""
|
||
import json
|
||
tool_calls = []
|
||
for i, part in enumerate(response_parts):
|
||
if "functionCall" in part:
|
||
func_call = part["functionCall"]
|
||
tool_calls.append({
|
||
"id": f"call_{uuid.uuid4().hex[:8]}",
|
||
"type": "function",
|
||
"function": {
|
||
"name": func_call.get("name", ""),
|
||
"arguments": json.dumps(func_call.get("args", {}), ensure_ascii=False)
|
||
}
|
||
})
|
||
return tool_calls
|
||
|
||
def _build_tool_response_contents(self, contents: list, tool_calls: list,
|
||
tool_results: list) -> list:
|
||
"""
|
||
构建包含工具调用结果的 contents,用于继续对话
|
||
|
||
Args:
|
||
contents: 原始 contents
|
||
tool_calls: 工具调用列表
|
||
tool_results: 工具执行结果列表
|
||
"""
|
||
import json
|
||
new_contents = contents.copy()
|
||
|
||
# 添加 model 的工具调用响应
|
||
function_call_parts = []
|
||
for tc in tool_calls:
|
||
function_call_parts.append({
|
||
"functionCall": {
|
||
"name": tc["function"]["name"],
|
||
"args": json.loads(tc["function"]["arguments"])
|
||
}
|
||
})
|
||
if function_call_parts:
|
||
new_contents.append({"role": "model", "parts": function_call_parts})
|
||
|
||
# 添加工具执行结果
|
||
function_response_parts = []
|
||
for i, result in enumerate(tool_results):
|
||
tool_name = tool_calls[i]["function"]["name"] if i < len(tool_calls) else "unknown"
|
||
function_response_parts.append({
|
||
"functionResponse": {
|
||
"name": tool_name,
|
||
"response": {"result": result.get("message", str(result))}
|
||
}
|
||
})
|
||
if function_response_parts:
|
||
new_contents.append({"role": "user", "parts": function_response_parts})
|
||
|
||
return new_contents
|
||
|
||
# ==================== 统一的 Gemini API 调用 ====================
|
||
|
||
async def _call_gemini_api(self, contents: list, tools: list = None,
                           bot=None, from_wxid: str = None,
                           chat_id: str = None, nickname: str = "",
                           user_wxid: str = None, is_group: bool = False) -> tuple:
    """Send one streaming request to the Gemini API and collect the reply.

    The endpoint is ``api.gemini_url`` from the config, or ``api.url`` with
    ``/v1/chat/completions`` replaced by ``/v1beta/models``. When a tool
    call shows up in the stream and ``bot``/``from_wxid`` are given, the
    text received so far is sent to the chat immediately (once per call).

    Args:
        contents: Conversation in Gemini format.
        tools: Tool definitions in Gemini format (optional).
        bot: Client exposing ``send_text`` (optional).
        from_wxid: Message source, used as the ``send_text`` target.
        chat_id: Conversation id (not used by this method).
        nickname: User nickname (not used by this method).
        user_wxid: wxid of the user; stored on ``self`` for tool execution.
        is_group: Whether this is a group chat; stored on ``self`` as well.

    Returns:
        ``(response_text, tool_calls)``: the stripped response text and the
        parsed tool-call list.

    Raises:
        Exception: when the API answers with a non-200 status.
        aiohttp.ClientError: on network failure (logged, then re-raised).
        asyncio.TimeoutError: on timeout (logged, then re-raised).
    """
    import json

    api_config = self.config["api"]
    model = api_config["model"]
    api_url = api_config.get("gemini_url", api_config.get("url", "").replace("/v1/chat/completions", "/v1beta/models"))
    api_key = api_config["api_key"]

    # Full endpoint URL.
    full_url = f"{api_url}/{model}:streamGenerateContent?alt=sse"

    # Request body.
    payload = {
        "contents": contents,
        "generationConfig": {
            "maxOutputTokens": api_config.get("max_tokens", 8192)
        }
    }

    if tools:
        payload["tools"] = tools

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}"
    }

    timeout = aiohttp.ClientTimeout(total=api_config.get("timeout", 120))

    # Optional proxy; falls back to a direct connection on failure.
    connector = None
    proxy_config = self.config.get("proxy", {})
    if proxy_config.get("enabled", False) and PROXY_SUPPORT:
        proxy_type = proxy_config.get("type", "socks5").upper()
        proxy_host = proxy_config.get("host", "127.0.0.1")
        proxy_port = proxy_config.get("port", 7890)
        proxy_username = proxy_config.get("username")
        proxy_password = proxy_config.get("password")

        if proxy_username and proxy_password:
            proxy_url = f"{proxy_type}://{proxy_username}:{proxy_password}@{proxy_host}:{proxy_port}"
        else:
            proxy_url = f"{proxy_type}://{proxy_host}:{proxy_port}"

        try:
            connector = ProxyConnector.from_url(proxy_url)
            logger.debug(f"[Gemini] 使用代理: {proxy_type}://{proxy_host}:{proxy_port}")
        except Exception as e:
            logger.warning(f"[Gemini] 代理配置失败: {e}")

    # Remember who is talking so tool execution can read it.
    # NOTE(review): this is per-instance state; overlapping calls for
    # different users would overwrite each other - confirm callers serialise.
    self._current_user_wxid = user_wxid
    self._current_is_group = is_group

    try:
        async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
            logger.debug(f"[Gemini] 发送流式请求: {full_url}")
            async with session.post(full_url, json=payload, headers=headers) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    logger.error(f"[Gemini] API 错误: {resp.status}, {error_text[:500]}")
                    raise Exception(f"Gemini API 错误 {resp.status}: {error_text[:200]}")

                # Consume the SSE stream; only "data: " lines carry payload.
                full_text = ""
                all_parts = []
                tool_call_hint_sent = False

                async for line in resp.content:
                    line = line.decode('utf-8').strip()
                    if not line or not line.startswith("data: "):
                        continue

                    try:
                        data = json.loads(line[6:])
                    except json.JSONDecodeError:
                        continue

                    candidates = data.get("candidates", [])
                    if not candidates:
                        continue

                    # A candidate may arrive without "content"; treat that
                    # as "no parts" instead of failing on None.
                    content = candidates[0].get("content") or {}
                    parts = content.get("parts") or []

                    for part in parts:
                        all_parts.append(part)

                        # Accumulate text.
                        if "text" in part:
                            full_text += part["text"]

                        # On the first tool call, flush the text received so far.
                        if "functionCall" in part:
                            if not tool_call_hint_sent and bot and from_wxid:
                                tool_call_hint_sent = True
                                if full_text.strip():
                                    logger.info(f"[Gemini] 检测到工具调用,先发送文本: {full_text[:30]}...")
                                    await bot.send_text(from_wxid, full_text.strip())

                # Turn functionCall parts into the internal tool-call format.
                tool_calls = self._parse_gemini_tool_calls(all_parts)

                logger.info(f"[Gemini] 响应完成, 文本长度: {len(full_text)}, 工具调用: {len(tool_calls)}")

                return full_text.strip(), tool_calls

    except aiohttp.ClientError as e:
        logger.error(f"[Gemini] 网络请求失败: {e}")
        raise
    except asyncio.TimeoutError:
        logger.error("[Gemini] 请求超时")
        raise
|
||
|
||
async def _handle_gemini_response(self, response_text: str, tool_calls: list,
|
||
contents: list, tools: list,
|
||
bot, from_wxid: str, chat_id: str,
|
||
nickname: str, user_wxid: str, is_group: bool):
|
||
"""
|
||
处理 Gemini API 响应,包括工具调用
|
||
|
||
Args:
|
||
response_text: AI 响应文本
|
||
tool_calls: 工具调用列表
|
||
contents: 原始 contents(用于工具调用后继续对话)
|
||
tools: 工具定义
|
||
bot, from_wxid, chat_id, nickname, user_wxid, is_group: 上下文信息
|
||
"""
|
||
if tool_calls:
|
||
# 有工具调用,异步执行
|
||
logger.info(f"[Gemini] 启动异步工具执行,共 {len(tool_calls)} 个工具")
|
||
asyncio.create_task(
|
||
self._execute_gemini_tools_async(
|
||
tool_calls, contents, tools,
|
||
bot, from_wxid, chat_id, nickname, user_wxid, is_group
|
||
)
|
||
)
|
||
return None # 工具调用异步处理
|
||
|
||
return response_text
|
||
|
||
async def _execute_gemini_tools_async(self, tool_calls: list, contents: list, tools: list,
                                      bot, from_wxid: str, chat_id: str,
                                      nickname: str, user_wxid: str, is_group: bool):
    """Run the tool calls requested by Gemini, one after another.

    For every call the arguments are decoded (invalid JSON becomes ``{}``)
    and the tool is executed through ``_execute_tool_and_get_result``.
    Successful results flagged ``need_ai_reply`` are collected and, if there
    are any, the conversation is continued with all tool results.
    Other successful results are sent to the chat only when they carry a
    message, are not ``already_sent`` and set ``send_result_text``. Failed
    results with a message are reported to the chat.

    Any unexpected error is logged and reported to the chat instead of
    being raised.
    """
    import json

    try:
        logger.info(f"[Gemini] 开始执行 {len(tool_calls)} 个工具")

        # Results the model must comment on, and all results in call order.
        need_ai_reply_results = []
        tool_results = []

        for tool_call in tool_calls:
            function_name = tool_call["function"]["name"]
            try:
                arguments = json.loads(tool_call["function"]["arguments"])
            except (TypeError, ValueError):
                # Unparsable arguments are treated as "no arguments".
                arguments = {}

            logger.info(f"[Gemini] 执行工具: {function_name}, 参数: {arguments}")

            result = await self._execute_tool_and_get_result(function_name, arguments, bot, from_wxid)
            tool_results.append(result)

            if result and result.get("success"):
                logger.success(f"[Gemini] 工具 {function_name} 执行成功")

                if result.get("need_ai_reply"):
                    # The model has to continue with this result.
                    need_ai_reply_results.append({
                        "tool_call": tool_call,
                        "result": result
                    })
                elif not result.get("already_sent") and result.get("message"):
                    if result.get("send_result_text"):
                        await bot.send_text(from_wxid, result["message"])
            else:
                logger.warning(f"[Gemini] 工具 {function_name} 执行失败: {result}")
                if result and result.get("message"):
                    await bot.send_text(from_wxid, f"❌ {result['message']}")

        # Continue the conversation when at least one result needs a reply.
        if need_ai_reply_results:
            await self._continue_gemini_with_tool_results(
                contents, tools, tool_calls, tool_results,
                bot, from_wxid, chat_id, nickname, user_wxid, is_group
            )

        logger.info("[Gemini] 所有工具执行完成")

    except Exception as e:
        logger.error(f"[Gemini] 工具执行异常: {e}")
        import traceback
        logger.error(traceback.format_exc())
        try:
            await bot.send_text(from_wxid, "❌ 工具执行出错")
        except Exception as send_error:
            # Best effort only; task cancellation is deliberately not
            # swallowed here.
            logger.debug(f"[Gemini] 发送错误提示失败: {send_error}")
|
||
|
||
async def _continue_gemini_with_tool_results(self, contents: list, tools: list,
                                             tool_calls: list, tool_results: list,
                                             bot, from_wxid: str, chat_id: str,
                                             nickname: str, user_wxid: str, is_group: bool):
    """
    Continue the Gemini conversation after tools have produced results.

    Builds new contents that include the tool results, calls the API again
    without tools (so the model cannot start another tool round), sends the
    reply to the chat and records it in memory and, for groups, in the group
    history. Errors are logged and not re-raised.
    """
    try:
        followup_contents = self._build_tool_response_contents(contents, tool_calls, tool_results)

        # tools=None: no tool definitions on the follow-up call, to avoid loops.
        reply, _ = await self._call_gemini_api(
            followup_contents, tools=None,
            bot=bot, from_wxid=from_wxid, chat_id=chat_id,
            nickname=nickname, user_wxid=user_wxid, is_group=is_group
        )

        if not reply:
            return

        await bot.send_text(from_wxid, reply)
        logger.success(f"[Gemini] 工具回传后 AI 回复: {reply[:50]}...")

        # Persist the reply.
        if chat_id:
            self._add_to_memory(chat_id, "assistant", reply)

        if is_group:
            import tomllib
            with open("main_config.toml", "rb") as cfg_file:
                bot_section = tomllib.load(cfg_file).get("Bot", {})
            await self._add_to_history(from_wxid, bot_section.get("nickname", "机器人"), reply)

    except Exception as e:
        logger.error(f"[Gemini] 工具回传后继续对话失败: {e}")
        import traceback
        logger.error(traceback.format_exc())
|
||
|
||
async def _process_with_gemini(self, text: str = "", image_base64: str = None,
                               video_base64: str = None, bot=None,
                               from_wxid: str = None, chat_id: str = None,
                               nickname: str = "", user_wxid: str = None,
                               is_group: bool = False) -> str:
    """
    Unified entry point for handling a message with Gemini.

    Supports plain text, image + text and video + text.

    Args:
        text: The user's message text.
        image_base64: Optional base64 image.
        video_base64: Optional base64 video.
        bot, from_wxid, chat_id, nickname, user_wxid, is_group: Conversation
            context.

    Returns:
        The model's reply text, or whatever ``_handle_gemini_response``
        returns when the model requested tool calls (``None`` means the
        tools are being handled asynchronously).

    Raises:
        Exception: The error from the last attempt once retries are used up.
    """
    # 1. System prompt.
    system_content = self._build_system_content(nickname, from_wxid, user_wxid, is_group)

    # 2. Prior context: group history for groups, per-chat memory otherwise.
    prior_messages = []
    if is_group and from_wxid:
        group_history = await self._load_history(from_wxid)
        context_cap = self.config.get("history", {}).get("max_context", 50)
        if len(group_history) > context_cap:
            prior_messages = group_history[-context_cap:]
        else:
            prior_messages = group_history
    elif chat_id:
        remembered = self._get_memory_messages(chat_id)
        if remembered and len(remembered) > 1:
            # Drop the last entry: it is the message being processed right now.
            prior_messages = remembered[:-1]

    # 3. Current message; group messages are prefixed with the sender's nickname.
    if is_group and nickname:
        current_message = {"text": f"[{nickname}] {text}"}
    else:
        current_message = {"text": text}
    if image_base64:
        current_message["image_base64"] = image_base64
    if video_base64:
        current_message["video_base64"] = video_base64

    # 4. Gemini contents.
    contents = self._build_gemini_contents(system_content, prior_messages, current_message, is_group)

    # 5. Collect tools and convert them to the Gemini format.
    openai_tools = self._collect_tools()
    gemini_tools = self._convert_tools_to_gemini(openai_tools)
    if gemini_tools:
        logger.info(f"[Gemini] 已加载 {len(openai_tools)} 个工具")

    # 6. Call the API, retrying on errors and on empty replies.
    max_retries = self.config.get("api", {}).get("max_retries", 2)

    for attempt in range(max_retries + 1):
        retries_left = attempt < max_retries
        try:
            response_text, tool_calls = await self._call_gemini_api(
                contents=contents,
                tools=gemini_tools if gemini_tools else None,
                bot=bot,
                from_wxid=from_wxid,
                chat_id=chat_id,
                nickname=nickname,
                user_wxid=user_wxid,
                is_group=is_group
            )

            if tool_calls:
                # None from the handler means the tools run asynchronously.
                return await self._handle_gemini_response(
                    response_text, tool_calls, contents, gemini_tools,
                    bot, from_wxid, chat_id, nickname, user_wxid, is_group
                )

            if response_text or not retries_left:
                return response_text

            logger.warning(f"[Gemini] 返回空内容,重试 {attempt + 1}/{max_retries}")
            await asyncio.sleep(1)

        except Exception as e:
            if not retries_left:
                raise
            logger.warning(f"[Gemini] API 调用失败,重试 {attempt + 1}/{max_retries}: {e}")
            await asyncio.sleep(1)

    return ""
|
||
|
||
def _build_system_content(self, nickname: str, from_wxid: str,
|
||
user_wxid: str, is_group: bool) -> str:
|
||
"""构建系统提示词(包含人设、时间、持久记忆等)"""
|
||
system_content = self.system_prompt
|
||
|
||
# 添加当前时间
|
||
current_time = datetime.now()
|
||
weekday_map = {
|
||
0: "星期一", 1: "星期二", 2: "星期三", 3: "星期四",
|
||
4: "星期五", 5: "星期六", 6: "星期日"
|
||
}
|
||
weekday = weekday_map[current_time.weekday()]
|
||
time_str = current_time.strftime(f"%Y年%m月%d日 %H:%M:%S {weekday}")
|
||
system_content += f"\n\n当前时间:{time_str}"
|
||
|
||
if nickname:
|
||
system_content += f"\n当前对话用户的昵称是:{nickname}"
|
||
|
||
# 加载持久记忆
|
||
memory_chat_id = from_wxid if is_group else user_wxid
|
||
if memory_chat_id:
|
||
persistent_memories = self._get_persistent_memories(memory_chat_id)
|
||
if persistent_memories:
|
||
system_content += "\n\n【持久记忆】以下是用户要求你记住的重要信息:\n"
|
||
for m in persistent_memories:
|
||
mem_time = m['time'][:10] if m['time'] else ""
|
||
system_content += f"- [{mem_time}] {m['nickname']}: {m['content']}\n"
|
||
|
||
return system_content
|
||
|
||
# ==================== 结束 Gemini API 方法 ====================
|
||
|
||
async def _handle_list_prompts(self, bot, from_wxid: str):
|
||
"""处理人设列表指令"""
|
||
try:
|
||
prompts_dir = Path(__file__).parent / "prompts"
|
||
|
||
# 获取所有 .txt 文件
|
||
if not prompts_dir.exists():
|
||
await bot.send_text(from_wxid, "❌ prompts 目录不存在")
|
||
return
|
||
|
||
txt_files = sorted(prompts_dir.glob("*.txt"))
|
||
|
||
if not txt_files:
|
||
await bot.send_text(from_wxid, "❌ 没有找到任何人设文件")
|
||
return
|
||
|
||
# 构建列表消息
|
||
current_file = self.config["prompt"]["system_prompt_file"]
|
||
msg = "📋 可用人设列表:\n\n"
|
||
|
||
for i, file_path in enumerate(txt_files, 1):
|
||
filename = file_path.name
|
||
# 标记当前使用的人设
|
||
if filename == current_file:
|
||
msg += f"{i}. {filename} ✅\n"
|
||
else:
|
||
msg += f"{i}. {filename}\n"
|
||
|
||
msg += f"\n💡 使用方法:/切人设 文件名.txt"
|
||
|
||
await bot.send_text(from_wxid, msg)
|
||
logger.info("已发送人设列表")
|
||
|
||
except Exception as e:
|
||
logger.error(f"获取人设列表失败: {e}")
|
||
await bot.send_text(from_wxid, f"❌ 获取人设列表失败: {str(e)}")
|
||
|
||
def _estimate_tokens(self, text: str) -> int:
|
||
"""
|
||
估算文本的 token 数量
|
||
|
||
简单估算规则:
|
||
- 中文:约 1.5 字符 = 1 token
|
||
- 英文:约 4 字符 = 1 token
|
||
- 混合文本取平均
|
||
"""
|
||
if not text:
|
||
return 0
|
||
|
||
# 统计中文字符数
|
||
chinese_chars = sum(1 for c in text if '\u4e00' <= c <= '\u9fff')
|
||
# 其他字符数
|
||
other_chars = len(text) - chinese_chars
|
||
|
||
# 估算 token 数
|
||
chinese_tokens = chinese_chars / 1.5
|
||
other_tokens = other_chars / 4
|
||
|
||
return int(chinese_tokens + other_tokens)
|
||
|
||
def _estimate_message_tokens(self, message: dict) -> int:
|
||
"""估算单条消息的 token 数"""
|
||
content = message.get("content", "")
|
||
|
||
if isinstance(content, str):
|
||
return self._estimate_tokens(content)
|
||
elif isinstance(content, list):
|
||
# 多模态消息
|
||
total = 0
|
||
for item in content:
|
||
if item.get("type") == "text":
|
||
total += self._estimate_tokens(item.get("text", ""))
|
||
elif item.get("type") == "image_url":
|
||
# 图片按 85 token 估算(OpenAI 低分辨率图片)
|
||
total += 85
|
||
return total
|
||
return 0
|
||
|
||
async def _handle_context_stats(self, bot, from_wxid: str, user_wxid: str, is_group: bool):
    """
    Handle the context-statistics command.

    Estimates the tokens used by the persona prompt, the persistent memories
    and the conversation context (group history or private memory), then
    sends a summary including usage against ``api.context_limit``.
    """
    try:
        chat_id = self._get_chat_id(from_wxid, user_wxid, is_group)

        # Persistent memories are keyed by the group id in groups, else by the user id.
        memory_chat_id = from_wxid if is_group else user_wxid
        persistent_memories = self._get_persistent_memories(memory_chat_id) if memory_chat_id else []
        persistent_tokens = 0
        if persistent_memories:
            persistent_tokens += self._estimate_tokens("【持久记忆】以下是用户要求你记住的重要信息:\n")
            for mem in persistent_memories:
                mem_day = mem['time'][:10] if mem['time'] else ""
                persistent_tokens += self._estimate_tokens(f"- [{mem_day}] {mem['nickname']}: {mem['content']}\n")

        context_tokens = 0
        if is_group:
            # Groups use the history mechanism.
            history = await self._load_history(from_wxid)
            max_context = self.config.get("history", {}).get("max_context", 50)

            # The slice of history that would be sent to the model.
            visible = history[-max_context:] if len(history) > max_context else history

            for entry in visible:
                body = entry.get("content", "")
                speaker = entry.get("nickname", "")

                if not isinstance(body, list):
                    context_tokens += self._estimate_tokens(f"[{speaker}] {body}")
                    continue

                # Multimodal entry.
                for part in body:
                    part_type = part.get("type")
                    if part_type == "text":
                        context_tokens += self._estimate_tokens(f"[{speaker}] {part.get('text', '')}")
                    elif part_type == "image_url":
                        context_tokens += 85

            header = (
                f"📊 群聊上下文统计\n\n"
                f"💬 历史总条数: {len(history)}\n"
                f"📤 AI可见条数: {len(visible)}/{max_context}\n"
            )
        else:
            # Private chats use the memory mechanism.
            memory_messages = self._get_memory_messages(chat_id)
            max_messages = self.config.get("memory", {}).get("max_messages", 20)

            for entry in memory_messages:
                context_tokens += self._estimate_message_tokens(entry)

            header = (
                f"📊 私聊上下文统计\n\n"
                f"💬 记忆条数: {len(memory_messages)}/{max_messages}\n"
            )

        # Totals and usage, shared by both chat types.
        system_tokens = self._estimate_tokens(self.system_prompt)
        total_tokens = system_tokens + persistent_tokens + context_tokens

        context_limit = self.config.get("api", {}).get("context_limit", 200000)
        usage_percent = (total_tokens / context_limit) * 100
        remaining_tokens = context_limit - total_tokens

        report = header + (
            f"🤖 人设 Token: ~{system_tokens}\n"
            f"📌 持久记忆: {len(persistent_memories)} 条 (~{persistent_tokens} token)\n"
            f"📝 上下文 Token: ~{context_tokens}\n"
            f"📦 总计 Token: ~{total_tokens}\n"
            f"📈 使用率: {usage_percent:.1f}% (剩余 ~{remaining_tokens:,})\n"
            f"\n💡 /清空记忆 清空上下文 | /记忆列表 查看持久记忆"
        )

        await bot.send_text(from_wxid, report)
        logger.info(f"已发送上下文统计: {chat_id}")

    except Exception as e:
        logger.error(f"获取上下文统计失败: {e}")
        await bot.send_text(from_wxid, f"❌ 获取上下文统计失败: {str(e)}")
|
||
|
||
async def _handle_switch_prompt(self, bot, from_wxid: str, content: str):
|
||
"""处理切换人设指令"""
|
||
try:
|
||
# 提取文件名
|
||
parts = content.split(maxsplit=1)
|
||
if len(parts) < 2:
|
||
await bot.send_text(from_wxid, "❌ 请指定人设文件名\n格式:/切人设 文件名.txt")
|
||
return
|
||
|
||
filename = parts[1].strip()
|
||
|
||
# 检查文件是否存在
|
||
prompt_path = Path(__file__).parent / "prompts" / filename
|
||
if not prompt_path.exists():
|
||
await bot.send_text(from_wxid, f"❌ 人设文件不存在: {filename}")
|
||
return
|
||
|
||
# 读取新人设
|
||
with open(prompt_path, "r", encoding="utf-8") as f:
|
||
new_prompt = f.read().strip()
|
||
|
||
# 更新人设
|
||
self.system_prompt = new_prompt
|
||
self.config["prompt"]["system_prompt_file"] = filename
|
||
|
||
await bot.send_text(from_wxid, f"✅ 已切换人设: {filename}")
|
||
logger.success(f"管理员切换人设: {filename}")
|
||
|
||
except Exception as e:
|
||
logger.error(f"切换人设失败: {e}")
|
||
await bot.send_text(from_wxid, f"❌ 切换人设失败: {str(e)}")
|
||
|
||
@on_text_message(priority=80)
async def handle_message(self, bot, message: dict):
    """
    Handle an incoming text message.

    Built-in slash commands are processed first; each of them returns
    ``False``. Any other message is recorded in the group history (for
    groups) and, if ``_should_reply`` says so and the rate limit allows,
    answered through ``_process_with_gemini``.
    """
    content = message.get("Content", "").strip()
    from_wxid = message.get("FromWxid", "")
    sender_wxid = message.get("SenderWxid", "")
    is_group = message.get("IsGroup", False)

    # The actual sender: the group member in groups, the peer in private chats.
    user_wxid = sender_wxid if is_group else from_wxid
    # Persistent memories are keyed by the group id in groups, else by the user id.
    memory_chat_id = from_wxid if is_group else user_wxid

    # Bot wxid and admin list come from the main config file.
    import tomllib
    with open("main_config.toml", "rb") as cfg_file:
        main_config = tomllib.load(cfg_file)
    bot_section = main_config.get("Bot", {})
    bot_wxid = bot_section.get("wxid", "")
    admins = bot_section.get("admins", [])
    is_admin = user_wxid in admins

    # --- Persona list (exact match) ---
    if content == "/人设列表":
        await self._handle_list_prompts(bot, from_wxid)
        return False

    # --- Switch persona (prefix match, admin only) ---
    if content.startswith(("/切人设 ", "/切换人设 ")):
        if is_admin:
            await self._handle_switch_prompt(bot, from_wxid, content)
        else:
            await bot.send_text(from_wxid, "❌ 仅管理员可以切换人设")
        return False

    # --- Clear conversation memory ---
    if content == self.config.get("memory", {}).get("clear_command", "/清空记忆"):
        chat_id = self._get_chat_id(from_wxid, user_wxid, is_group)
        self._clear_memory(chat_id)
        await bot.send_text(from_wxid, "✅ 已清空当前会话的记忆")
        return False

    # --- Context statistics ---
    if content in ("/context", "/上下文"):
        await self._handle_context_stats(bot, from_wxid, user_wxid, is_group)
        return False

    # --- Memory status (admin only) ---
    if content == "/记忆状态":
        if not is_admin:
            await bot.send_text(from_wxid, "❌ 仅管理员可以查看记忆状态")
        elif is_group:
            history = await self._load_history(from_wxid)
            max_context = self.config.get("history", {}).get("max_context", 50)
            visible_count = min(len(history), max_context)
            status = f"📊 群聊记忆: {len(history)} 条\n" + f"💬 AI可见: 最近 {visible_count} 条"
            await bot.send_text(from_wxid, status)
        else:
            chat_id = self._get_chat_id(from_wxid, user_wxid, is_group)
            remembered = self._get_memory_messages(chat_id)
            await bot.send_text(from_wxid, f"📊 私聊记忆: {len(remembered)} 条")
        return False

    # --- Persistent memory: record ---
    if content.startswith("/记录 "):
        memory_content = content.removeprefix("/记录 ").strip()
        if memory_content:
            nickname = await self._get_user_nickname(bot, from_wxid, user_wxid, is_group)
            chat_type = "group" if is_group else "private"
            memory_id = self._add_persistent_memory(
                memory_chat_id, chat_type, user_wxid, nickname, memory_content
            )
            await bot.send_text(from_wxid, f"✅ 已记录到持久记忆 (ID: {memory_id})")
            logger.info(f"添加持久记忆: {memory_chat_id} - {memory_content[:30]}...")
        else:
            await bot.send_text(from_wxid, "❌ 请输入要记录的内容\n格式:/记录 要记住的内容")
        return False

    # --- Persistent memory: list (available to everyone) ---
    if content in ("/记忆列表", "/持久记忆"):
        memories = self._get_persistent_memories(memory_chat_id)
        if memories:
            chunks = [f"📋 持久记忆列表 (共 {len(memories)} 条)\n\n"]
            for mem in memories:
                time_str = mem['time'][:16] if mem['time'] else "未知"
                if len(mem['content']) > 30:
                    preview = mem['content'][:30] + "..."
                else:
                    preview = mem['content']
                chunks.append(f"[{mem['id']}] {mem['nickname']}: {preview}\n 📅 {time_str}\n")
            chunks.append("\n💡 删除记忆:/删除记忆 ID (管理员)")
            listing = "".join(chunks)
        else:
            listing = "📋 暂无持久记忆"
        await bot.send_text(from_wxid, listing)
        return False

    # --- Persistent memory: delete one (admin only) ---
    if content.startswith("/删除记忆 "):
        if not is_admin:
            await bot.send_text(from_wxid, "❌ 仅管理员可以删除持久记忆")
            return False
        try:
            memory_id = int(content.removeprefix("/删除记忆 ").strip())
            if self._delete_persistent_memory(memory_chat_id, memory_id):
                await bot.send_text(from_wxid, f"✅ 已删除记忆 ID: {memory_id}")
            else:
                await bot.send_text(from_wxid, f"❌ 未找到记忆 ID: {memory_id}")
        except ValueError:
            await bot.send_text(from_wxid, "❌ 请输入有效的记忆ID\n格式:/删除记忆 ID")
        return False

    # --- Persistent memory: clear all (admin only) ---
    if content == "/清空持久记忆":
        if is_admin:
            deleted_count = self._clear_persistent_memories(memory_chat_id)
            await bot.send_text(from_wxid, f"✅ 已清空 {deleted_count} 条持久记忆")
        else:
            await bot.send_text(from_wxid, "❌ 仅管理员可以清空持久记忆")
        return False

    # --- Regular message ---
    should_reply = self._should_reply(message, content, bot_wxid)

    # Sender nickname, used for the history record.
    nickname = await self._get_user_nickname(bot, from_wxid, user_wxid, is_group)

    # Every group message is recorded, whether or not the bot replies.
    if is_group:
        await self._add_to_history(from_wxid, nickname, content)

    if not should_reply:
        return

    # Rate limiting applies only when a reply is due.
    allowed, remaining, reset_time = self._check_rate_limit(user_wxid)
    if not allowed:
        template = self.config.get("rate_limit", {}).get(
            "rate_limit_message", "⚠️ 消息太频繁了,请 {seconds} 秒后再试~"
        )
        await bot.send_text(from_wxid, template.format(seconds=reset_time))
        logger.warning(f"用户 {user_wxid} 触发限流,{reset_time}秒后重置")
        return False

    # Strip the @-mention to get the actual message text.
    actual_content = self._extract_content(message, content)
    if not actual_content:
        return

    logger.info(f"AI 处理消息: {actual_content[:50]}...")

    try:
        # Record the user's message in the session memory.
        chat_id = self._get_chat_id(from_wxid, user_wxid, is_group)
        self._add_to_memory(chat_id, "user", actual_content)

        response = await self._process_with_gemini(
            text=actual_content,
            bot=bot,
            from_wxid=from_wxid,
            chat_id=chat_id,
            nickname=nickname,
            user_wxid=user_wxid,
            is_group=is_group
        )

        # None or empty means the reply was handled elsewhere (e.g. tool calls).
        if not response:
            logger.info("AI 回复为空或已通过其他方式发送(如工具调用)")
        else:
            await bot.send_text(from_wxid, response)
            self._add_to_memory(chat_id, "assistant", response)
            # Record the bot's reply in the group history as well.
            if is_group:
                with open("main_config.toml", "rb") as cfg_file:
                    main_config = tomllib.load(cfg_file)
                bot_nickname = main_config.get("Bot", {}).get("nickname", "机器人")
                await self._add_to_history(from_wxid, bot_nickname, response)
            logger.success(f"AI 回复成功: {response[:50]}...")

    except Exception as e:
        import traceback
        error_detail = traceback.format_exc()
        logger.error(f"AI 处理失败: {type(e).__name__}: {str(e)}")
        logger.error(f"详细错误:\n{error_detail}")
        await bot.send_text(from_wxid, "抱歉,我遇到了一些问题,请稍后再试。")
|
||
|
||
def _should_reply(self, message: dict, content: str, bot_wxid: str = None) -> bool:
|
||
"""判断是否应该回复"""
|
||
# 检查是否由AutoReply插件触发
|
||
if message.get('_auto_reply_triggered'):
|
||
return True
|
||
|
||
is_group = message.get("IsGroup", False)
|
||
|
||
# 检查群聊/私聊开关
|
||
if is_group and not self.config["behavior"].get("reply_group", True):
|
||
return False
|
||
if not is_group and not self.config["behavior"].get("reply_private", True):
|
||
return False
|
||
|
||
trigger_mode = self.config["behavior"].get("trigger_mode", "mention")
|
||
|
||
# all 模式:回复所有消息
|
||
if trigger_mode == "all":
|
||
return True
|
||
|
||
# mention 模式:检查是否@了机器人
|
||
if trigger_mode == "mention":
|
||
if is_group:
|
||
ats = message.get("Ats", [])
|
||
|
||
# 如果没有 bot_wxid,从配置文件读取
|
||
if not bot_wxid:
|
||
import tomllib
|
||
with open("main_config.toml", "rb") as f:
|
||
main_config = tomllib.load(f)
|
||
bot_wxid = main_config.get("Bot", {}).get("wxid", "")
|
||
bot_nickname = main_config.get("Bot", {}).get("nickname", "")
|
||
else:
|
||
# 也需要读取昵称用于备用检测
|
||
import tomllib
|
||
with open("main_config.toml", "rb") as f:
|
||
main_config = tomllib.load(f)
|
||
bot_nickname = main_config.get("Bot", {}).get("nickname", "")
|
||
|
||
# 方式1:检查 @ 列表中是否包含机器人的 wxid
|
||
if ats and bot_wxid and bot_wxid in ats:
|
||
return True
|
||
|
||
# 方式2:备用检测 - 从消息内容中检查是否包含 @机器人昵称
|
||
# (当 API 没有返回 at_user_list 时使用)
|
||
if bot_nickname and f"@{bot_nickname}" in content:
|
||
logger.debug(f"通过内容检测到 @{bot_nickname},触发回复")
|
||
return True
|
||
|
||
return False
|
||
else:
|
||
# 私聊直接回复
|
||
return True
|
||
|
||
# keyword 模式:检查关键词
|
||
if trigger_mode == "keyword":
|
||
keywords = self.config["behavior"]["keywords"]
|
||
return any(kw in content for kw in keywords)
|
||
|
||
return False
|
||
|
||
def _extract_content(self, message: dict, content: str) -> str:
|
||
"""提取实际消息内容(去除@等)"""
|
||
is_group = message.get("IsGroup", False)
|
||
|
||
if is_group:
|
||
# 群聊消息,去除@部分
|
||
# 格式通常是 "@昵称 消息内容"
|
||
parts = content.split(maxsplit=1)
|
||
if len(parts) > 1 and parts[0].startswith("@"):
|
||
return parts[1].strip()
|
||
return content.strip()
|
||
|
||
return content.strip()
|
||
|
||
async def _call_ai_api(self, user_message: str, bot=None, from_wxid: str = None, chat_id: str = None, nickname: str = "", user_wxid: str = None, is_group: bool = False) -> str:
    """
    Call the OpenAI-style chat API with streaming enabled.

    Args:
        user_message: The current user message.
        bot: Bot client; used to send text that precedes a tool call.
        from_wxid: Chat the message came from (group id for groups).
        chat_id: Session id used for private-chat memory.
        nickname: Sender nickname; prefixed to group messages.
        user_wxid: Sender wxid.
        is_group: Whether the message is from a group chat.

    Returns:
        The stripped reply text; a fixed apology string when the model
        printed a malformed tool call as text; or ``None`` when tool calls
        were handed to a background task.

    Raises:
        Exception: On a non-200 status, network error, timeout or missing
            field. Wrapped errors keep the original as ``__cause__``.
    """
    import json

    api_config = self.config["api"]

    # Collect tools.
    tools = self._collect_tools()
    logger.info(f"收集到 {len(tools)} 个工具函数")
    if tools:
        tool_names = [t["function"]["name"] for t in tools]
        logger.info(f"工具列表: {tool_names}")

    # System prompt (persona, time, nickname, persistent memories) comes from
    # the shared builder rather than a duplicated inline copy.
    system_content = self._build_system_content(nickname, from_wxid, user_wxid, is_group)
    messages = [{"role": "system", "content": system_content}]

    if is_group and from_wxid:
        # Group chats: context comes from the stored group history.
        history = await self._load_history(from_wxid)
        max_context = self.config.get("history", {}).get("max_context", 50)

        # Most recent N messages.
        recent_history = history[-max_context:] if len(history) > max_context else history

        for msg in recent_history:
            msg_nickname = msg.get("nickname", "")
            msg_content = msg.get("content", "")

            if isinstance(msg_content, list):
                # Multimodal content: prefix the nickname to text parts and
                # keep other parts (images, ...) untouched.
                content_with_nickname = []
                for item in msg_content:
                    if item.get("type") == "text":
                        content_with_nickname.append({
                            "type": "text",
                            "text": f"[{msg_nickname}] {item.get('text', '')}"
                        })
                    else:
                        content_with_nickname.append(item)

                messages.append({
                    "role": "user",
                    "content": content_with_nickname
                })
            else:
                # Plain text: "[nickname] content".
                messages.append({
                    "role": "user",
                    "content": f"[{msg_nickname}] {msg_content}"
                })
    elif chat_id:
        # Private chats: use the memory mechanism; the last entry is the
        # current message, which is appended separately below.
        memory_messages = self._get_memory_messages(chat_id)
        if memory_messages and len(memory_messages) > 1:
            messages.extend(memory_messages[:-1])

    # Current user message.
    messages.append({"role": "user", "content": f"[{nickname}] {user_message}" if is_group and nickname else user_message})

    # Stored for use by tool calls.
    self._current_user_wxid = user_wxid
    self._current_is_group = is_group

    payload = {
        "model": api_config["model"],
        "messages": messages,
        "max_tokens": api_config.get("max_tokens", 4096)  # avoid truncated replies
    }

    if tools:
        payload["tools"] = tools
        logger.debug(f"已将 {len(tools)} 个工具添加到请求中")

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_config['api_key']}"
    }

    timeout = aiohttp.ClientTimeout(total=api_config["timeout"])

    # Optional proxy.
    connector = None
    proxy_config = self.config.get("proxy", {})
    if proxy_config.get("enabled", False):
        proxy_type = proxy_config.get("type", "socks5").upper()
        proxy_host = proxy_config.get("host", "127.0.0.1")
        proxy_port = proxy_config.get("port", 7890)
        proxy_username = proxy_config.get("username")
        proxy_password = proxy_config.get("password")

        # Build the proxy URL.
        if proxy_username and proxy_password:
            proxy_url = f"{proxy_type}://{proxy_username}:{proxy_password}@{proxy_host}:{proxy_port}"
        else:
            proxy_url = f"{proxy_type}://{proxy_host}:{proxy_port}"

        if PROXY_SUPPORT:
            try:
                connector = ProxyConnector.from_url(proxy_url)
                # Credentials are deliberately left out of the log line.
                logger.debug(f"使用代理: {proxy_type}://{proxy_host}:{proxy_port}")
            except Exception as e:
                logger.warning(f"代理配置失败,将直连: {e}")
                connector = None
        else:
            logger.warning("代理功能不可用(aiohttp_socks 未安装),将直连")
            connector = None

    # Enable streaming.
    payload["stream"] = True

    try:
        async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
            logger.debug(f"发送流式 API 请求: {api_config['url']}")
            async with session.post(
                api_config["url"],
                json=payload,
                headers=headers
            ) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    logger.error(f"API 返回错误状态码: {resp.status}, 响应: {error_text}")
                    raise Exception(f"API 错误 {resp.status}: {error_text}")

                # Consume the stream.
                full_content = ""
                tool_calls_dict = {}  # tool calls assembled incrementally: {index: tool_call}
                tool_call_hint_sent = False  # whether the pre-tool text was already handled

                async for line in resp.content:
                    line = line.decode('utf-8').strip()
                    if not line or line == "data: [DONE]":
                        continue

                    if not line.startswith("data: "):
                        continue

                    try:
                        data = json.loads(line[6:])
                        choices = data.get("choices", [])
                        if not choices:
                            continue

                        delta = choices[0].get("delta", {})

                        # Accumulate text.
                        content = delta.get("content", "")
                        if content:
                            full_content += content

                        # Accumulate tool calls.
                        if delta.get("tool_calls"):
                            # On the first tool-call chunk, flush any text the
                            # model produced so far to the chat.
                            if not tool_call_hint_sent and bot and from_wxid:
                                tool_call_hint_sent = True
                                if full_content and full_content.strip():
                                    logger.info(f"[流式] 检测到工具调用,先发送已有文本: {full_content[:30]}...")
                                    await bot.send_text(from_wxid, full_content.strip())
                                else:
                                    # No text from the model: send nothing.
                                    logger.info("[流式] 检测到工具调用,AI 未输出文本")

                            for tool_call_delta in delta["tool_calls"]:
                                index = tool_call_delta.get("index", 0)

                                # First chunk for this index: create the slot.
                                if index not in tool_calls_dict:
                                    tool_calls_dict[index] = {
                                        "id": "",
                                        "type": "function",
                                        "function": {
                                            "name": "",
                                            "arguments": ""
                                        }
                                    }

                                if "id" in tool_call_delta:
                                    tool_calls_dict[index]["id"] = tool_call_delta["id"]

                                if "type" in tool_call_delta:
                                    tool_calls_dict[index]["type"] = tool_call_delta["type"]

                                # Name and arguments arrive as string fragments.
                                if "function" in tool_call_delta:
                                    func_delta = tool_call_delta["function"]
                                    if "name" in func_delta:
                                        tool_calls_dict[index]["function"]["name"] += func_delta["name"]
                                    if "arguments" in func_delta:
                                        tool_calls_dict[index]["function"]["arguments"] += func_delta["arguments"]
                    except Exception as e:
                        # A chunk that fails to parse is skipped.
                        logger.debug(f"解析流式数据失败: {e}")

                # Ordered list of assembled tool calls.
                tool_calls_data = [tool_calls_dict[i] for i in sorted(tool_calls_dict.keys())] if tool_calls_dict else []

                logger.info(f"流式 API 响应完成, 内容长度: {len(full_content)}, 工具调用数: {len(tool_calls_data)}")

                if tool_calls_data:
                    # Any pre-tool text was already sent while streaming.
                    logger.info(f"启动异步工具执行,共 {len(tool_calls_data)} 个工具")
                    task = asyncio.create_task(
                        self._execute_tools_async(
                            tool_calls_data, bot, from_wxid, chat_id,
                            nickname, is_group, messages
                        )
                    )
                    # The event loop only holds weak references to tasks; keep
                    # a strong one so the task is not garbage-collected mid-run.
                    pending = getattr(self, "_background_tasks", None)
                    if pending is None:
                        pending = self._background_tasks = set()
                    pending.add(task)
                    task.add_done_callback(pending.discard)
                    # None: tool calls are handled asynchronously, no retry needed.
                    return None

                # The model printed a tool call as text instead of calling it.
                # Parentheses spell out the existing `or`/`and` precedence.
                if "<tool_code>" in full_content or (
                    "print(" in full_content and "flow2_ai_image_generation" in full_content
                ):
                    logger.warning("检测到模型输出了错误的工具调用格式,拦截并返回提示")
                    return "抱歉,我遇到了一些技术问题,请重新描述一下你的需求~"

                return full_content.strip()
    except aiohttp.ClientError as e:
        logger.error(f"网络请求失败: {type(e).__name__}: {str(e)}")
        raise Exception(f"网络请求失败: {str(e)}") from e
    except asyncio.TimeoutError as e:
        logger.error(f"API 请求超时 (timeout={api_config['timeout']}s)")
        raise Exception(f"API 请求超时") from e
    except KeyError as e:
        logger.error(f"API 响应格式错误,缺少字段: {e}")
        raise Exception(f"API 响应格式错误: {e}") from e
|
||
|
||
|
||
def _get_history_file(self, chat_id: str) -> Path:
|
||
"""获取群聊历史记录文件路径"""
|
||
if not self.history_dir:
|
||
return None
|
||
safe_name = chat_id.replace("@", "_").replace(":", "_")
|
||
return self.history_dir / f"{safe_name}.json"
|
||
|
||
def _get_history_lock(self, chat_id: str) -> asyncio.Lock:
|
||
"""获取指定会话的锁, 每个会话一把"""
|
||
lock = self.history_locks.get(chat_id)
|
||
if lock is None:
|
||
lock = asyncio.Lock()
|
||
self.history_locks[chat_id] = lock
|
||
return lock
|
||
|
||
def _read_history_file(self, history_file: Path) -> list:
|
||
try:
|
||
import json
|
||
with open(history_file, "r", encoding="utf-8") as f:
|
||
return json.load(f)
|
||
except FileNotFoundError:
|
||
return []
|
||
except Exception as e:
|
||
logger.error(f"读取历史记录失败: {e}")
|
||
return []
|
||
|
||
def _write_history_file(self, history_file: Path, history: list):
|
||
import json
|
||
history_file.parent.mkdir(parents=True, exist_ok=True)
|
||
temp_file = Path(str(history_file) + ".tmp")
|
||
with open(temp_file, "w", encoding="utf-8") as f:
|
||
json.dump(history, f, ensure_ascii=False, indent=2)
|
||
temp_file.replace(history_file)
|
||
|
||
def _use_redis_for_group_history(self) -> bool:
|
||
"""检查是否使用 Redis 存储群聊历史"""
|
||
redis_config = self.config.get("redis", {})
|
||
if not redis_config.get("use_redis_history", True):
|
||
return False
|
||
redis_cache = get_cache()
|
||
return redis_cache and redis_cache.enabled
|
||
|
||
async def _load_history(self, chat_id: str) -> list:
|
||
"""异步读取群聊历史, 优先使用 Redis"""
|
||
# 优先使用 Redis
|
||
if self._use_redis_for_group_history():
|
||
redis_cache = get_cache()
|
||
max_history = self.config.get("history", {}).get("max_history", 100)
|
||
return redis_cache.get_group_history(chat_id, max_history)
|
||
|
||
# 降级到文件存储
|
||
history_file = self._get_history_file(chat_id)
|
||
if not history_file:
|
||
return []
|
||
lock = self._get_history_lock(chat_id)
|
||
async with lock:
|
||
return self._read_history_file(history_file)
|
||
|
||
async def _save_history(self, chat_id: str, history: list):
|
||
"""异步写入群聊历史, 包含长度截断"""
|
||
# Redis 模式下不需要单独保存,add_group_message 已经处理
|
||
if self._use_redis_for_group_history():
|
||
return
|
||
|
||
history_file = self._get_history_file(chat_id)
|
||
if not history_file:
|
||
return
|
||
|
||
max_history = self.config.get("history", {}).get("max_history", 100)
|
||
if len(history) > max_history:
|
||
history = history[-max_history:]
|
||
|
||
lock = self._get_history_lock(chat_id)
|
||
async with lock:
|
||
self._write_history_file(history_file, history)
|
||
|
||
async def _add_to_history(self, chat_id: str, nickname: str, content: str, image_base64: str = None):
|
||
"""
|
||
将消息存入群聊历史
|
||
|
||
Args:
|
||
chat_id: 群聊ID
|
||
nickname: 用户昵称
|
||
content: 消息内容
|
||
image_base64: 可选的图片base64
|
||
"""
|
||
if not self.config.get("history", {}).get("enabled", True):
|
||
return
|
||
|
||
# 构建消息内容
|
||
if image_base64:
|
||
message_content = [
|
||
{"type": "text", "text": content},
|
||
{"type": "image_url", "image_url": {"url": image_base64}}
|
||
]
|
||
else:
|
||
message_content = content
|
||
|
||
# 优先使用 Redis
|
||
if self._use_redis_for_group_history():
|
||
redis_cache = get_cache()
|
||
redis_config = self.config.get("redis", {})
|
||
ttl = redis_config.get("group_history_ttl", 172800)
|
||
redis_cache.add_group_message(chat_id, nickname, message_content, ttl=ttl)
|
||
# 裁剪历史
|
||
max_history = self.config.get("history", {}).get("max_history", 100)
|
||
redis_cache.trim_group_history(chat_id, max_history)
|
||
return
|
||
|
||
# 降级到文件存储
|
||
history_file = self._get_history_file(chat_id)
|
||
if not history_file:
|
||
return
|
||
|
||
lock = self._get_history_lock(chat_id)
|
||
async with lock:
|
||
history = self._read_history_file(history_file)
|
||
|
||
message_record = {
|
||
"nickname": nickname,
|
||
"timestamp": datetime.now().isoformat(),
|
||
"content": message_content
|
||
}
|
||
|
||
history.append(message_record)
|
||
max_history = self.config.get("history", {}).get("max_history", 100)
|
||
if len(history) > max_history:
|
||
history = history[-max_history:]
|
||
|
||
self._write_history_file(history_file, history)
|
||
|
||
async def _add_to_history_with_id(self, chat_id: str, nickname: str, content: str, record_id: str):
|
||
"""带ID的历史追加, 便于后续更新"""
|
||
if not self.config.get("history", {}).get("enabled", True):
|
||
return
|
||
|
||
# 优先使用 Redis
|
||
if self._use_redis_for_group_history():
|
||
redis_cache = get_cache()
|
||
redis_config = self.config.get("redis", {})
|
||
ttl = redis_config.get("group_history_ttl", 172800)
|
||
redis_cache.add_group_message(chat_id, nickname, content, record_id=record_id, ttl=ttl)
|
||
# 裁剪历史
|
||
max_history = self.config.get("history", {}).get("max_history", 100)
|
||
redis_cache.trim_group_history(chat_id, max_history)
|
||
return
|
||
|
||
# 降级到文件存储
|
||
history_file = self._get_history_file(chat_id)
|
||
if not history_file:
|
||
return
|
||
|
||
lock = self._get_history_lock(chat_id)
|
||
async with lock:
|
||
history = self._read_history_file(history_file)
|
||
message_record = {
|
||
"id": record_id,
|
||
"nickname": nickname,
|
||
"timestamp": datetime.now().isoformat(),
|
||
"content": content
|
||
}
|
||
history.append(message_record)
|
||
max_history = self.config.get("history", {}).get("max_history", 100)
|
||
if len(history) > max_history:
|
||
history = history[-max_history:]
|
||
self._write_history_file(history_file, history)
|
||
|
||
async def _update_history_by_id(self, chat_id: str, record_id: str, new_content: str):
|
||
"""根据ID更新历史记录"""
|
||
if not self.config.get("history", {}).get("enabled", True):
|
||
return
|
||
|
||
# 优先使用 Redis
|
||
if self._use_redis_for_group_history():
|
||
redis_cache = get_cache()
|
||
redis_cache.update_group_message_by_id(chat_id, record_id, new_content)
|
||
return
|
||
|
||
# 降级到文件存储
|
||
history_file = self._get_history_file(chat_id)
|
||
if not history_file:
|
||
return
|
||
|
||
lock = self._get_history_lock(chat_id)
|
||
async with lock:
|
||
history = self._read_history_file(history_file)
|
||
for record in history:
|
||
if record.get("id") == record_id:
|
||
record["content"] = new_content
|
||
break
|
||
max_history = self.config.get("history", {}).get("max_history", 100)
|
||
if len(history) > max_history:
|
||
history = history[-max_history:]
|
||
self._write_history_file(history_file, history)
|
||
|
||
|
||
async def _execute_tool_and_get_result(self, tool_name: str, arguments: dict, bot, from_wxid: str):
    """
    Dispatch one tool call to the loaded plugins and return its result.

    Every plugin exposing ``execute_llm_tool`` is asked in turn. A plugin
    that returns None is treated as not handling the tool.

    Args:
        tool_name: Name of the tool requested by the model.
        arguments: Parsed tool arguments. The caller's dict is not modified;
            a copy enriched with ``user_wxid`` and ``is_group`` is passed on.
        bot: Bot client forwarded to the plugin.
        from_wxid: Conversation the request came from.

    Returns:
        The first result whose ``success`` is truthy; otherwise the first
        failure result a plugin returned; otherwise a generic
        "tool not found" failure dict.
    """
    import traceback

    from utils.plugin_manager import PluginManager

    # Work on a copy so the caller's dict is not silently extended, and fall
    # back to an empty dict when the model produced non-object arguments.
    call_args = dict(arguments) if isinstance(arguments, dict) else {}
    call_args["user_wxid"] = getattr(self, "_current_user_wxid", from_wxid)
    call_args["is_group"] = getattr(self, "_current_is_group", False)

    logger.info(f"开始执行工具: {tool_name}")

    plugins = PluginManager().plugins
    logger.info(f"检查 {len(plugins)} 个插件")

    # NOTE(review): the source indentation was lost, so it is unclear whether
    # a non-None failure result used to end the search. Remembering it and
    # still preferring a later success is correct under either reading.
    first_failure = None

    for plugin_name, plugin in plugins.items():
        has_tool_api = hasattr(plugin, 'execute_llm_tool')
        logger.debug(f"检查插件: {plugin_name}, 有execute_llm_tool: {has_tool_api}")
        if not has_tool_api:
            continue

        try:
            logger.info(f"调用 {plugin_name}.execute_llm_tool")
            result = await plugin.execute_llm_tool(tool_name, call_args, bot, from_wxid)
            logger.info(f"{plugin_name} 返回: {result}")

            if result is None:
                logger.debug(f"{plugin_name} 不处理此工具,继续检查下一个插件")
                continue

            if result.get("success"):
                logger.success(f"工具执行成功: {tool_name}")
                return result

            if first_failure is None:
                first_failure = result
        except Exception as e:
            # One broken plugin must not prevent the others from being asked.
            logger.error(f"工具执行异常 ({plugin_name}): {tool_name}, {e}")
            logger.error(f"详细错误: {traceback.format_exc()}")

    if first_failure is not None:
        return first_failure

    logger.warning(f"未找到工具: {tool_name}")
    return {"success": False, "message": f"未找到工具: {tool_name}"}
|
||
async def _execute_tools_async(self, tool_calls_data: list, bot, from_wxid: str,
                               chat_id: str, nickname: str, is_group: bool,
                               messages: list):
    """
    Run tool calls in the background without blocking the main flow.

    The AI has already replied to the user; the tools are executed
    concurrently here and their results are delivered afterwards. Results
    flagged ``need_ai_reply`` are not sent directly but handed back to the AI
    via ``_continue_with_tool_results`` (keeping context and persona).

    Args:
        tool_calls_data: Tool call dicts with ``id`` and
            ``function.name`` / ``function.arguments`` (a JSON string).
        bot: Bot client used to send messages.
        from_wxid: Conversation to send results to.
        chat_id: Memory key; tool results are only saved when it is truthy.
        nickname: Passed through to the follow-up AI call.
        is_group: Passed through to the follow-up AI call.
        messages: Conversation messages passed through to the follow-up AI call.
    """
    import json
    import traceback

    try:
        logger.info(f"开始异步执行 {len(tool_calls_data)} 个工具调用")

        tasks = []
        tool_info_list = []  # (tool_call_id, function_name), parallel to tasks

        for tool_call in tool_calls_data:
            function_spec = tool_call.get("function", {})
            function_name = function_spec.get("name", "")
            arguments_str = function_spec.get("arguments", "{}")
            tool_call_id = tool_call.get("id", "")

            if not function_name:
                continue

            try:
                arguments = json.loads(arguments_str)
            except (TypeError, ValueError):
                # Malformed or non-string arguments: run the tool without any.
                arguments = {}
            if not isinstance(arguments, dict):
                # Valid JSON that is not an object (e.g. "null" or a list).
                arguments = {}

            logger.info(f"[异步] 准备执行工具: {function_name}, 参数: {arguments}")

            tasks.append(self._execute_tool_and_get_result(function_name, arguments, bot, from_wxid))
            tool_info_list.append((tool_call_id, function_name))

        if tasks:
            # Run all tools concurrently; exceptions come back as results.
            results = await asyncio.gather(*tasks, return_exceptions=True)

            # Results the AI should turn into a reply.
            need_ai_reply_results = []

            for (tool_call_id, function_name), result in zip(tool_info_list, results):
                if isinstance(result, Exception):
                    logger.error(f"[异步] 工具 {function_name} 执行异常: {result}")
                    await bot.send_text(from_wxid, f"❌ {function_name} 执行失败")
                    continue

                if not (result and result.get("success")):
                    logger.warning(f"[异步] 工具 {function_name} 执行失败: {result}")
                    if result and result.get("message"):
                        await bot.send_text(from_wxid, f"❌ {result.get('message')}")
                    continue

                logger.success(f"[异步] 工具 {function_name} 执行成功")

                if result.get("need_ai_reply"):
                    need_ai_reply_results.append({
                        "tool_call_id": tool_call_id,
                        "function_name": function_name,
                        "result": result.get("message", "")
                    })
                    continue  # not sent directly; the AI handles it

                # Send the text only when the tool did not send it itself,
                # replies are not suppressed, and the tool asked for it.
                msg = result.get("message")
                if (msg and not result.get("already_sent")
                        and not result.get("no_reply")
                        and result.get("send_result_text")):
                    await bot.send_text(from_wxid, msg)

                # Optionally keep the tool result in memory.
                if result.get("save_to_memory") and chat_id:
                    self._add_to_memory(chat_id, "assistant", f"[工具 {function_name} 结果]: {result.get('message', '')}")

            if need_ai_reply_results:
                await self._continue_with_tool_results(
                    need_ai_reply_results, bot, from_wxid, chat_id,
                    nickname, is_group, messages, tool_calls_data
                )

        logger.info("[异步] 所有工具执行完成")

    except Exception as e:
        logger.error(f"[异步] 工具执行总体异常: {e}")
        logger.error(f"详细错误: {traceback.format_exc()}")
        try:
            await bot.send_text(from_wxid, "❌ 工具执行过程中出现错误")
        except Exception as send_err:
            # Best effort only. ``Exception`` (not a bare except) so that task
            # cancellation is not swallowed while awaiting the send.
            logger.debug(f"[异步] 发送错误提示失败: {send_err}")
|
||
async def _continue_with_tool_results(self, tool_results: list, bot, from_wxid: str,
                                      chat_id: str, nickname: str, is_group: bool,
                                      messages: list, tool_calls_data: list):
    """
    Ask the AI for a follow-up reply based on tool results (keeping context and persona).

    Used for tools that return ``need_ai_reply``, such as web search. The
    assistant tool-call message and the tool result messages are appended to
    ``messages`` (in place), then the chat API is called in streaming mode
    without a ``tools`` field so the model cannot call tools again.

    Args:
        tool_results: Dicts with ``tool_call_id``, ``function_name`` and ``result``.
        bot: Bot client used to send the reply.
        from_wxid: Conversation to send the reply to.
        chat_id: Memory key; the reply is saved only when it is truthy.
        nickname: Unused here.
        is_group: Unused here.
        messages: Conversation messages sent to the API; extended in place.
        tool_calls_data: The original tool calls issued by the model.
    """
    import json
    import traceback

    try:
        logger.info(f"[工具回传] 开始基于 {len(tool_results)} 个工具结果继续对话")

        # 1. Assistant message listing only the tool calls that need an AI reply.
        wanted_ids = {tr["tool_call_id"] for tr in tool_results}
        tool_calls_msg = [
            {
                "id": tool_call.get("id", ""),
                "type": "function",
                "function": {
                    "name": tool_call.get("function", {}).get("name", ""),
                    "arguments": tool_call.get("function", {}).get("arguments", "{}")
                }
            }
            for tool_call in tool_calls_data
            if tool_call.get("id", "") in wanted_ids
        ]

        if tool_calls_msg:
            messages.append({
                "role": "assistant",
                "content": None,
                "tool_calls": tool_calls_msg
            })

        # 2. One tool message per result.
        for tr in tool_results:
            messages.append({
                "role": "tool",
                "tool_call_id": tr["tool_call_id"],
                "content": tr["result"]
            })

        # 3. Call the AI again.
        api_config = self.config["api"]
        proxy_config = self.config.get("proxy", {})

        payload = {
            "model": api_config["model"],
            "messages": messages,
            "max_tokens": api_config.get("max_tokens", 4096),
            "stream": True
        }

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_config['api_key']}"
        }

        proxy = None
        connector = None
        if proxy_config.get("enabled", False):
            proxy_type = proxy_config.get("type", "http")
            proxy_host = proxy_config.get("host", "127.0.0.1")
            proxy_port = proxy_config.get("port", 7890)
            proxy_url = f"{proxy_type}://{proxy_host}:{proxy_port}"

            if str(proxy_type).lower().startswith("socks"):
                # aiohttp's ``proxy=`` argument only handles HTTP proxies, so
                # SOCKS goes through a ProxyConnector, as the video API call does.
                if PROXY_SUPPORT:
                    try:
                        connector = ProxyConnector.from_url(proxy_url)
                    except Exception as e:
                        logger.warning(f"[工具回传] 代理配置失败: {e}")
                else:
                    logger.warning("[工具回传] aiohttp_socks 未安装,无法使用 SOCKS 代理")
            else:
                proxy = proxy_url

        timeout = aiohttp.ClientTimeout(total=api_config.get("timeout", 120))

        async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
            async with session.post(
                api_config["url"],
                json=payload,
                headers=headers,
                proxy=proxy
            ) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    logger.error(f"[工具回传] AI API 错误: {resp.status}, {error_text}")
                    await bot.send_text(from_wxid, "❌ AI 处理搜索结果失败")
                    return

                # Read the streamed "data: {...}" lines and collect the text deltas.
                chunks = []
                async for raw_line in resp.content:
                    line = raw_line.decode("utf-8").strip()
                    if not line.startswith("data: "):
                        continue
                    if line == "data: [DONE]":
                        break

                    try:
                        data = json.loads(line[6:])
                        delta = data.get("choices", [{}])[0].get("delta", {})
                        content = delta.get("content", "")
                    except (ValueError, LookupError, AttributeError, TypeError):
                        # Skip chunks that are not the expected JSON shape.
                        continue
                    if content:
                        chunks.append(content)
                full_content = "".join(chunks)

                reply = full_content.strip()
                if reply:
                    await bot.send_text(from_wxid, reply)
                    logger.success(f"[工具回传] AI 回复完成,长度: {len(full_content)}")

                    if chat_id:
                        self._add_to_memory(chat_id, "assistant", reply)
                else:
                    logger.warning("[工具回传] AI 返回空内容")

    except Exception as e:
        logger.error(f"[工具回传] 继续对话失败: {e}")
        logger.error(f"详细错误: {traceback.format_exc()}")
        try:
            await bot.send_text(from_wxid, "❌ 处理搜索结果时出错")
        except Exception as send_err:
            # Best effort only. ``Exception`` (not a bare except) so that task
            # cancellation is not swallowed while awaiting the send.
            logger.debug(f"[工具回传] 发送错误提示失败: {send_err}")
|
||
async def _execute_tools_async_with_image(self, tool_calls_data: list, bot, from_wxid: str,
                                          chat_id: str, nickname: str, is_group: bool,
                                          messages: list, image_base64: str):
    """
    Run tool calls in the background, passing an image along (image-to-image use).

    The AI has already replied to the user; the tools are executed
    concurrently here and their results are delivered afterwards. For the
    ``flow2_ai_image_generation`` tool the image is added to the arguments as
    ``image_base64``.

    Args:
        tool_calls_data: Tool call dicts with ``id`` and
            ``function.name`` / ``function.arguments`` (a JSON string).
        bot: Bot client used to send messages.
        from_wxid: Conversation to send results to.
        chat_id: Memory key; tool results are only saved when it is truthy.
        nickname: Unused here.
        is_group: Unused here.
        messages: Unused here.
        image_base64: Image data handed to the image generation tool.
    """
    import json
    import traceback

    try:
        logger.info(f"[异步-图片] 开始执行 {len(tool_calls_data)} 个工具调用")

        tasks = []
        function_names = []  # parallel to tasks

        for tool_call in tool_calls_data:
            function_spec = tool_call.get("function", {})
            function_name = function_spec.get("name", "")
            arguments_str = function_spec.get("arguments", "{}")

            if not function_name:
                continue

            try:
                arguments = json.loads(arguments_str)
            except (TypeError, ValueError):
                # Malformed or non-string arguments: run the tool without any.
                arguments = {}
            if not isinstance(arguments, dict):
                # Valid JSON that is not an object (e.g. "null" or a list).
                arguments = {}

            # The image-to-image tool receives the image itself.
            if function_name == "flow2_ai_image_generation" and image_base64:
                arguments["image_base64"] = image_base64
                logger.info("[异步-图片] 图生图工具,已添加图片数据")

            logger.info(f"[异步-图片] 准备执行工具: {function_name}")

            tasks.append(self._execute_tool_and_get_result(function_name, arguments, bot, from_wxid))
            function_names.append(function_name)

        if tasks:
            # Run all tools concurrently; exceptions come back as results.
            results = await asyncio.gather(*tasks, return_exceptions=True)

            for function_name, result in zip(function_names, results):
                if isinstance(result, Exception):
                    logger.error(f"[异步-图片] 工具 {function_name} 执行异常: {result}")
                    await bot.send_text(from_wxid, f"❌ {function_name} 执行失败")
                    continue

                if not (result and result.get("success")):
                    logger.warning(f"[异步-图片] 工具 {function_name} 执行失败: {result}")
                    if result and result.get("message"):
                        await bot.send_text(from_wxid, f"❌ {result.get('message')}")
                    continue

                logger.success(f"[异步-图片] 工具 {function_name} 执行成功")

                # Send the text only when the tool did not send it itself,
                # replies are not suppressed, and the tool asked for it.
                msg = result.get("message")
                if (msg and not result.get("already_sent")
                        and not result.get("no_reply")
                        and result.get("send_result_text")):
                    await bot.send_text(from_wxid, msg)

                if result.get("save_to_memory") and chat_id:
                    self._add_to_memory(chat_id, "assistant", f"[工具 {function_name} 结果]: {result.get('message', '')}")

        logger.info("[异步-图片] 所有工具执行完成")

    except Exception as e:
        logger.error(f"[异步-图片] 工具执行总体异常: {e}")
        logger.error(f"详细错误: {traceback.format_exc()}")
        try:
            await bot.send_text(from_wxid, "❌ 工具执行过程中出现错误")
        except Exception as send_err:
            # Best effort only. ``Exception`` (not a bare except) so that task
            # cancellation is not swallowed while awaiting the send.
            logger.debug(f"[异步-图片] 发送错误提示失败: {send_err}")
|
||
@on_quote_message(priority=79)
async def handle_quote_message(self, bot, message: dict):
    """Handle a quote message (one quoting an image/video, or a "/记录" command).

    The message XML carries the user's text in ``<title>`` and the quoted
    message in ``<refermsg>``. Two cases are handled:

    * ``/记录`` (optionally followed by a note): the quoted text is stored in
      persistent memory.
    * A quoted image or video: the media is sent to the AI together with the
      title text and the answer is sent back.

    Returns:
        False when this handler dealt with the message, True when it skipped
        it or failed (presumably True lets other handlers run; confirm
        against the dispatcher).
    """
    import html

    content = message.get("Content", "").strip()
    from_wxid = message.get("FromWxid", "")
    sender_wxid = message.get("SenderWxid", "")
    is_group = message.get("IsGroup", False)

    user_wxid = sender_wxid if is_group else from_wxid

    try:
        # Parse the XML to get the title and the quoted message.
        root = ET.fromstring(content)
        title = root.find(".//title")
        if title is None or not title.text:
            logger.debug("引用消息没有标题,跳过")
            return True

        title_text = title.text.strip()
        logger.info(f"收到引用消息,标题: {title_text[:50]}...")

        # "/记录" command: store the quoted message in persistent memory.
        if title_text == "/记录" or title_text.startswith("/记录 "):
            refermsg = root.find(".//refermsg")
            if refermsg is not None:
                # Nickname of whoever sent the quoted message.
                refer_displayname = refermsg.find("displayname")
                refer_nickname = refer_displayname.text if refer_displayname is not None and refer_displayname.text else "未知"

                # Text of the quoted message.
                refer_content_elem = refermsg.find("content")
                if refer_content_elem is not None and refer_content_elem.text:
                    refer_text = refer_content_elem.text.strip()
                    # XML content (e.g. an image) is replaced by a placeholder.
                    if refer_text.startswith("<"):
                        refer_text = "[多媒体消息]"
                else:
                    refer_text = "[空消息]"

                # Memory entry: what the quoted person said.
                memory_content = f"{refer_nickname}: {refer_text}"

                # Anything after "/记录 " is appended as a note.
                if title_text.startswith("/记录 "):
                    extra_note = title_text[4:].strip()
                    if extra_note:
                        memory_content += f" (备注: {extra_note})"

                # Save to persistent memory.
                nickname = await self._get_user_nickname(bot, from_wxid, user_wxid, is_group)
                memory_chat_id = from_wxid if is_group else user_wxid
                chat_type = "group" if is_group else "private"
                memory_id = self._add_persistent_memory(
                    memory_chat_id, chat_type, user_wxid, nickname, memory_content
                )
                await bot.send_text(from_wxid, f"✅ 已记录到持久记忆 (ID: {memory_id})\n📝 {memory_content[:50]}...")
                logger.info(f"通过引用添加持久记忆: {memory_chat_id} - {memory_content[:30]}...")
            else:
                await bot.send_text(from_wxid, "❌ 无法获取被引用的消息")
            return False

        # Check whether a reply is wanted at all.
        if not self._should_reply_quote(message, title_text):
            logger.debug("引用消息不满足回复条件")
            return True

        # Locate the quoted message.
        refermsg = root.find(".//refermsg")
        if refermsg is None:
            logger.debug("引用消息中没有 refermsg 节点")
            return True

        refer_content = refermsg.find("content")
        if refer_content is None or not refer_content.text:
            logger.debug("引用消息中没有 content")
            return True

        # Decode HTML entities, then parse the quoted message as XML.
        refer_xml = html.unescape(refer_content.text)
        try:
            refer_root = ET.fromstring(refer_xml)
        except ET.ParseError:
            # A quoted plain-text message is not XML. That is an ordinary
            # "not an image or video" case, not a processing error.
            logger.debug("引用的消息不是图片或视频")
            return True

        # Look for image and video information.
        img = refer_root.find(".//img")
        video = refer_root.find(".//videomsg")

        if img is None and video is None:
            logger.debug("引用的消息不是图片或视频")
            return True

        # NOTE(review): this repeats the check made above. It is kept because
        # _should_reply_quote is not visible here and may not be idempotent.
        if not self._should_reply_quote(message, title_text):
            logger.debug("引用消息不满足回复条件")
            return True

        # Rate limiting.
        allowed, remaining, reset_time = self._check_rate_limit(user_wxid)
        if not allowed:
            rate_limit_config = self.config.get("rate_limit", {})
            msg = rate_limit_config.get("rate_limit_message", "⚠️ 消息太频繁了,请 {seconds} 秒后再试~")
            msg = msg.format(seconds=reset_time)
            await bot.send_text(from_wxid, msg)
            logger.warning(f"用户 {user_wxid} 触发限流,{reset_time}秒后重置")
            return False

        nickname = await self._get_user_nickname(bot, from_wxid, user_wxid, is_group)
        chat_id = self._get_chat_id(from_wxid, user_wxid, is_group)

        # Quoted video.
        if video is not None:
            return await self._handle_quote_video(
                bot, video, title_text, from_wxid, user_wxid,
                is_group, nickname, chat_id
            )

        # Quoted image.
        cdnbigimgurl = img.get("cdnbigimgurl", "")
        aeskey = img.get("aeskey", "")

        if not cdnbigimgurl or not aeskey:
            logger.warning(f"图片信息不完整: cdnurl={bool(cdnbigimgurl)}, aeskey={bool(aeskey)}")
            return True

        logger.info(f"AI处理引用图片消息: {title_text[:50]}...")

        # Download and encode the image.
        logger.info(f"开始下载图片: {cdnbigimgurl[:50]}...")
        image_base64 = await self._download_and_encode_image(bot, cdnbigimgurl, aeskey)
        if not image_base64:
            logger.error("图片下载失败")
            await bot.send_text(from_wxid, "❌ 无法处理图片")
            return False
        logger.info("图片下载和编码成功")

        # Remember the user message together with the image.
        self._add_to_memory(chat_id, "user", title_text, image_base64=image_base64)

        # In groups, also record it in the group history.
        if is_group:
            await self._add_to_history(from_wxid, nickname, title_text, image_base64=image_base64)

        # Let the unified Gemini path answer.
        response = await self._process_with_gemini(
            text=title_text,
            image_base64=image_base64,
            bot=bot,
            from_wxid=from_wxid,
            chat_id=chat_id,
            nickname=nickname,
            user_wxid=user_wxid,
            is_group=is_group
        )

        if response:
            await bot.send_text(from_wxid, response)
            self._add_to_memory(chat_id, "assistant", response)
            # Record the bot's reply in the group history under its configured nickname.
            if is_group:
                import tomllib
                with open("main_config.toml", "rb") as f:
                    main_config = tomllib.load(f)
                bot_nickname = main_config.get("Bot", {}).get("nickname", "机器人")
                await self._add_to_history(from_wxid, bot_nickname, response)
            logger.success(f"AI回复成功: {response[:50]}...")

        return False

    except Exception as e:
        logger.error(f"处理引用消息失败: {e}")
        return True
|
||
async def _handle_quote_video(self, bot, video_elem, title_text: str, from_wxid: str,
                              user_wxid: str, is_group: bool, nickname: str, chat_id: str):
    """Answer a question about a quoted video through the unified Gemini path.

    The video is downloaded via its CDN url/key (falling back to the raw-video
    attributes), the question is recorded in memory and group history, and
    the AI answer is sent back. Always returns False.
    """
    try:
        # CDN information, falling back to the raw video attributes.
        video_url = video_elem.get("cdnvideourl", "")
        video_key = video_elem.get("aeskey", "")
        if not (video_url and video_key):
            video_url = video_elem.get("cdnrawvideourl", "")
            video_key = video_elem.get("cdnrawvideoaeskey", "")

        if not (video_url and video_key):
            logger.warning(f"[视频] 视频信息不完整: cdnurl={bool(video_url)}, aeskey={bool(video_key)}")
            await bot.send_text(from_wxid, "❌ 无法获取视频信息")
            return False

        logger.info(f"[视频] 处理引用视频: {title_text[:50]}...")

        # Tell the user that work has started.
        await bot.send_text(from_wxid, "🎬 正在分析视频,请稍候...")

        # Download and encode the video.
        video_base64 = await self._download_and_encode_video(bot, video_url, video_key)
        if not video_base64:
            logger.error("[视频] 视频下载失败")
            await bot.send_text(from_wxid, "❌ 视频下载失败")
            return False

        logger.info("[视频] 视频下载和编码成功")

        # The user's question, with a default when the title is blank.
        user_question = title_text.strip() or "这个视频讲了什么?"
        user_entry = f"[发送了一个视频] {user_question}"

        self._add_to_memory(chat_id, "user", user_entry)
        if is_group:
            await self._add_to_history(from_wxid, nickname, user_entry)

        # The unified Gemini path handles the video directly.
        response = await self._process_with_gemini(
            text=user_question,
            video_base64=video_base64,
            bot=bot,
            from_wxid=from_wxid,
            chat_id=chat_id,
            nickname=nickname,
            user_wxid=user_wxid,
            is_group=is_group
        )

        if not response:
            await bot.send_text(from_wxid, "❌ AI 回复生成失败")
            return False

        await bot.send_text(from_wxid, response)
        self._add_to_memory(chat_id, "assistant", response)
        # Record the bot's reply in the group history under its configured nickname.
        if is_group:
            import tomllib
            with open("main_config.toml", "rb") as config_file:
                main_config = tomllib.load(config_file)
            bot_nickname = main_config.get("Bot", {}).get("nickname", "机器人")
            await self._add_to_history(from_wxid, bot_nickname, response)
        logger.success(f"[视频] AI回复成功: {response[:50]}...")
        return False

    except Exception as e:
        logger.error(f"[视频识别] 处理视频失败: {e}")
        import traceback
        logger.error(traceback.format_exc())
        await bot.send_text(from_wxid, "❌ 视频处理出错")
        return False
|
||
async def _analyze_video_content(self, video_base64: str, video_config: dict) -> str:
    """Ask the video model for an objective description of a video.

    Posts the video (base64, any ``data:`` prefix removed) with a fixed
    analysis prompt to ``{api_url}/{model}:generateContent`` and returns the
    first text part of the response. Returns "" on HTTP errors, blocked or
    safety-filtered responses, missing text, timeouts and other exceptions.
    """
    try:
        endpoint_base = video_config.get("api_url", "https://api.functen.cn/v1beta/models")
        token = video_config.get("api_key", self.config["api"]["api_key"])
        model_name = video_config.get("model", "gemini-3-pro-preview")

        request_url = f"{endpoint_base}/{model_name}:generateContent"

        # Drop a "data:video/mp4;base64," style prefix if present.
        if video_base64.startswith("data:"):
            video_base64 = video_base64.split(",", 1)[1]
            logger.debug("[视频AI] 已去除 base64 前缀")

        # Prompt dedicated to video analysis.
        # NOTE(review): assumes the prompt lines carry no leading whitespace;
        # the extracted source lost its indentation.
        analyze_prompt = (
            "请详细分析这个视频的内容,包括:\n"
            "1. 视频的主要场景和环境\n"
            "2. 出现的人物/物体及其动作\n"
            "3. 视频中的文字、对话或声音(如果有)\n"
            "4. 视频的整体主题或要表达的内容\n"
            "5. 任何值得注意的细节\n"
            "\n"
            "请用客观、详细的方式描述,不要加入主观评价。"
        )

        video_part = {
            "inline_data": {
                "mime_type": "video/mp4",
                "data": video_base64
            }
        }
        payload = {
            "contents": [
                {"parts": [{"text": analyze_prompt}, video_part]}
            ],
            "generationConfig": {
                "maxOutputTokens": video_config.get("max_tokens", 8192)
            }
        }

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {token}"
        }

        timeout = aiohttp.ClientTimeout(total=video_config.get("timeout", 360))

        logger.info("[视频AI] 开始分析视频...")

        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.post(request_url, json=payload, headers=headers) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    logger.error(f"[视频AI] API 错误: {resp.status}, {error_text[:300]}")
                    return ""

                result = await resp.json()
                logger.info(f"[视频AI] API 响应 keys: {list(result.keys())}")

                # Prompt-level safety filtering.
                block_reason = None
                if "promptFeedback" in result:
                    block_reason = result["promptFeedback"].get("blockReason")
                if block_reason:
                    logger.warning(f"[视频AI] 内容被过滤: {block_reason}")
                    return ""

                # Return the first text part found.
                candidates = result["candidates"] if "candidates" in result else None
                for candidate in candidates or []:
                    # Response-level safety filtering.
                    if candidate.get("finishReason") == "SAFETY":
                        logger.warning("[视频AI] 响应被安全过滤")
                        return ""

                    parts = candidate.get("content", {}).get("parts", [])
                    text = next((p["text"] for p in parts if "text" in p), None)
                    if text is not None:
                        logger.info(f"[视频AI] 分析完成,长度: {len(text)}")
                        return text

                # Nothing usable: log what is known about the failure.
                if "usageMetadata" in result:
                    usage = result["usageMetadata"]
                    logger.warning(f"[视频AI] 无响应,Token: prompt={usage.get('promptTokenCount', 0)}")

                logger.error(f"[视频AI] 没有有效响应: {str(result)[:300]}")
                return ""

    except asyncio.TimeoutError:
        logger.error("[视频AI] 请求超时")
        return ""
    except Exception as e:
        logger.error(f"[视频AI] 分析失败: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return ""
|
||
async def _download_and_encode_video(self, bot, cdnurl: str, aeskey: str) -> str:
    """Download a video from the CDN and return it as a base64 data URL.

    A Redis media cache is consulted first when enabled. Otherwise the video
    is downloaded into the plugin's ``temp`` directory, checked against
    ``video_recognition.max_size_mb`` (default 20), encoded and cached for
    600 seconds. The temporary file is always removed.

    Args:
        bot: Bot client providing ``cdn_download``.
        cdnurl: CDN url of the video.
        aeskey: AES key of the video.

    Returns:
        ``data:video/mp4;base64,...`` on success, "" on any failure (download
        failed, file missing or empty, file too large, unexpected error).
    """
    save_path = None
    try:
        # Try the cache first.
        from utils.redis_cache import RedisCache
        redis_cache = get_cache()
        cache_enabled = bool(redis_cache and redis_cache.enabled)
        media_key = None
        if cache_enabled:
            media_key = RedisCache.generate_media_key(cdnurl, aeskey)
            if media_key:
                cached_data = redis_cache.get_cached_media(media_key, "video")
                if cached_data:
                    logger.debug(f"[视频识别] 从缓存获取视频: {media_key[:20]}...")
                    return cached_data

        # Download the video.
        logger.info("[视频识别] 开始下载视频...")
        temp_dir = Path(__file__).parent / "temp"
        temp_dir.mkdir(exist_ok=True)

        filename = f"video_{uuid.uuid4().hex[:8]}.mp4"
        save_path = (temp_dir / filename).resolve()

        # file_type=4 means video.
        success = await bot.cdn_download(cdnurl, aeskey, str(save_path), file_type=4)
        if not success:
            logger.error("[视频识别] CDN 下载失败")
            return ""

        # Wait for the file to be written (at most 15 seconds).
        for _ in range(30):
            if save_path.exists() and save_path.stat().st_size > 0:
                break
            await asyncio.sleep(0.5)

        if not save_path.exists():
            logger.error("[视频识别] 视频文件未生成")
            return ""

        file_size = save_path.stat().st_size
        if file_size == 0:
            # The wait timed out with an empty file; encoding it would return
            # a truthy but useless "data:video/mp4;base64," payload.
            logger.error("[视频识别] 视频文件为空")
            return ""
        logger.info(f"[视频识别] 视频下载完成,大小: {file_size / 1024 / 1024:.2f} MB")

        # Enforce the size limit.
        video_config = self.config.get("video_recognition", {})
        max_size_mb = video_config.get("max_size_mb", 20)
        if file_size > max_size_mb * 1024 * 1024:
            logger.warning(f"[视频识别] 视频文件过大: {file_size / 1024 / 1024:.2f} MB > {max_size_mb} MB")
            return ""

        # Read and encode as base64.
        video_data = base64.b64encode(save_path.read_bytes()).decode()
        video_base64 = f"data:video/mp4;base64,{video_data}"

        # Cache in Redis.
        if cache_enabled and media_key:
            redis_cache.cache_media(media_key, video_base64, "video", ttl=600)
            logger.debug(f"[视频识别] 视频已缓存: {media_key[:20]}...")

        return video_base64

    except Exception as e:
        logger.error(f"[视频识别] 下载视频失败: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return ""

    finally:
        # Remove the temporary file on every path, including errors.
        if save_path is not None:
            try:
                save_path.unlink(missing_ok=True)
            except OSError as cleanup_err:
                logger.debug(f"[视频识别] 清理临时文件失败: {cleanup_err}")
|
||
async def _call_ai_api_with_video(self, user_message: str, video_base64: str, bot=None,
|
||
from_wxid: str = None, chat_id: str = None,
|
||
nickname: str = "", user_wxid: str = None,
|
||
is_group: bool = False) -> str:
|
||
"""调用 Gemini 原生 API(带视频)- 继承完整上下文"""
|
||
try:
|
||
video_config = self.config.get("video_recognition", {})
|
||
|
||
# 使用视频识别专用配置
|
||
video_model = video_config.get("model", "gemini-3-pro-preview")
|
||
api_url = video_config.get("api_url", "https://api.functen.cn/v1beta/models")
|
||
api_key = video_config.get("api_key", self.config["api"]["api_key"])
|
||
|
||
# 构建完整的 API URL
|
||
full_url = f"{api_url}/{video_model}:generateContent"
|
||
|
||
# 构建系统提示(与 _call_ai_api 保持一致)
|
||
system_content = self.system_prompt
|
||
current_time = datetime.now()
|
||
weekday_map = {
|
||
0: "星期一", 1: "星期二", 2: "星期三", 3: "星期四",
|
||
4: "星期五", 5: "星期六", 6: "星期日"
|
||
}
|
||
weekday = weekday_map[current_time.weekday()]
|
||
time_str = current_time.strftime(f"%Y年%m月%d日 %H:%M:%S {weekday}")
|
||
system_content += f"\n\n当前时间:{time_str}"
|
||
|
||
if nickname:
|
||
system_content += f"\n当前对话用户的昵称是:{nickname}"
|
||
|
||
# 加载持久记忆
|
||
memory_chat_id = from_wxid if is_group else user_wxid
|
||
if memory_chat_id:
|
||
persistent_memories = self._get_persistent_memories(memory_chat_id)
|
||
if persistent_memories:
|
||
system_content += "\n\n【持久记忆】以下是用户要求你记住的重要信息:\n"
|
||
for m in persistent_memories:
|
||
mem_time = m['time'][:10] if m['time'] else ""
|
||
system_content += f"- [{mem_time}] {m['nickname']}: {m['content']}\n"
|
||
|
||
# 构建历史上下文
|
||
history_context = ""
|
||
if is_group and from_wxid:
|
||
# 群聊:从 Redis/文件加载历史
|
||
history = await self._load_history(from_wxid)
|
||
max_context = self.config.get("history", {}).get("max_context", 50)
|
||
recent_history = history[-max_context:] if len(history) > max_context else history
|
||
|
||
if recent_history:
|
||
history_context = "\n\n【最近的群聊记录】\n"
|
||
for msg in recent_history:
|
||
msg_nickname = msg.get("nickname", "")
|
||
msg_content = msg.get("content", "")
|
||
if isinstance(msg_content, list):
|
||
# 多模态内容,提取文本
|
||
for item in msg_content:
|
||
if item.get("type") == "text":
|
||
msg_content = item.get("text", "")
|
||
break
|
||
else:
|
||
msg_content = "[图片]"
|
||
# 限制单条消息长度
|
||
if len(str(msg_content)) > 200:
|
||
msg_content = str(msg_content)[:200] + "..."
|
||
history_context += f"[{msg_nickname}] {msg_content}\n"
|
||
else:
|
||
# 私聊:从 memory 加载
|
||
if chat_id:
|
||
memory_messages = self._get_memory_messages(chat_id)
|
||
if memory_messages:
|
||
history_context = "\n\n【最近的对话记录】\n"
|
||
for msg in memory_messages[-20:]: # 最近20条
|
||
role = msg.get("role", "")
|
||
content = msg.get("content", "")
|
||
if isinstance(content, list):
|
||
for item in content:
|
||
if item.get("type") == "text":
|
||
content = item.get("text", "")
|
||
break
|
||
else:
|
||
content = "[图片]"
|
||
role_name = "用户" if role == "user" else "你"
|
||
if len(str(content)) > 200:
|
||
content = str(content)[:200] + "..."
|
||
history_context += f"[{role_name}] {content}\n"
|
||
|
||
# 从 data:video/mp4;base64,xxx 中提取纯 base64 数据
|
||
if video_base64.startswith("data:"):
|
||
video_base64 = video_base64.split(",", 1)[1]
|
||
|
||
# 构建完整提示(人设 + 历史 + 当前问题)
|
||
full_prompt = system_content + history_context + f"\n\n【当前】用户发送了一个视频并问:{user_message or '请描述这个视频的内容'}"
|
||
|
||
# 构建 Gemini 原生格式请求
|
||
payload = {
|
||
"contents": [
|
||
{
|
||
"parts": [
|
||
{"text": full_prompt},
|
||
{
|
||
"inline_data": {
|
||
"mime_type": "video/mp4",
|
||
"data": video_base64
|
||
}
|
||
}
|
||
]
|
||
}
|
||
],
|
||
"generationConfig": {
|
||
"maxOutputTokens": video_config.get("max_tokens", 8192)
|
||
}
|
||
}
|
||
|
||
headers = {
|
||
"Content-Type": "application/json",
|
||
"Authorization": f"Bearer {api_key}"
|
||
}
|
||
|
||
timeout = aiohttp.ClientTimeout(total=video_config.get("timeout", 360))
|
||
|
||
# 配置代理
|
||
connector = None
|
||
proxy_config = self.config.get("proxy", {})
|
||
if proxy_config.get("enabled", False) and PROXY_SUPPORT:
|
||
proxy_type = proxy_config.get("type", "socks5").upper()
|
||
proxy_host = proxy_config.get("host", "127.0.0.1")
|
||
proxy_port = proxy_config.get("port", 7890)
|
||
proxy_url = f"{proxy_type}://{proxy_host}:{proxy_port}"
|
||
try:
|
||
connector = ProxyConnector.from_url(proxy_url)
|
||
except Exception as e:
|
||
logger.warning(f"[视频识别] 代理配置失败: {e}")
|
||
|
||
logger.info(f"[视频识别] 调用 Gemini API: {full_url}")
|
||
logger.debug(f"[视频识别] 提示词长度: {len(full_prompt)} 字符")
|
||
|
||
async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
|
||
async with session.post(full_url, json=payload, headers=headers) as resp:
|
||
if resp.status != 200:
|
||
error_text = await resp.text()
|
||
logger.error(f"[视频识别] API 错误: {resp.status}, {error_text[:500]}")
|
||
return ""
|
||
|
||
# 解析 Gemini 响应格式
|
||
result = await resp.json()
|
||
# 详细记录响应(用于调试)
|
||
logger.info(f"[视频识别] API 响应 keys: {list(result.keys()) if isinstance(result, dict) else type(result)}")
|
||
|
||
# 检查是否有错误
|
||
if "error" in result:
|
||
logger.error(f"[视频识别] API 返回错误: {result['error']}")
|
||
return ""
|
||
|
||
# 检查 promptFeedback(安全过滤信息)
|
||
if "promptFeedback" in result:
|
||
feedback = result["promptFeedback"]
|
||
block_reason = feedback.get("blockReason", "")
|
||
if block_reason:
|
||
logger.warning(f"[视频识别] 请求被阻止,原因: {block_reason}")
|
||
logger.warning(f"[视频识别] 安全评级: {feedback.get('safetyRatings', [])}")
|
||
return "抱歉,视频内容无法分析(内容策略限制)。"
|
||
|
||
# 提取文本内容
|
||
full_content = ""
|
||
if "candidates" in result and result["candidates"]:
|
||
logger.info(f"[视频识别] candidates 数量: {len(result['candidates'])}")
|
||
for i, candidate in enumerate(result["candidates"]):
|
||
# 检查 finishReason
|
||
finish_reason = candidate.get("finishReason", "")
|
||
if finish_reason:
|
||
logger.info(f"[视频识别] candidate[{i}] finishReason: {finish_reason}")
|
||
if finish_reason == "SAFETY":
|
||
logger.warning(f"[视频识别] 内容被安全过滤: {candidate.get('safetyRatings', [])}")
|
||
return "抱歉,视频内容无法分析。"
|
||
|
||
content = candidate.get("content", {})
|
||
parts = content.get("parts", [])
|
||
logger.info(f"[视频识别] candidate[{i}] parts 数量: {len(parts)}")
|
||
for part in parts:
|
||
if "text" in part:
|
||
full_content += part["text"]
|
||
else:
|
||
# 没有 candidates,记录完整响应
|
||
logger.error(f"[视频识别] 响应中没有 candidates: {str(result)[:500]}")
|
||
# 可能是上下文太长导致,记录 token 使用情况
|
||
if "usageMetadata" in result:
|
||
usage = result["usageMetadata"]
|
||
logger.warning(f"[视频识别] Token 使用: prompt={usage.get('promptTokenCount', 0)}, total={usage.get('totalTokenCount', 0)}")
|
||
|
||
logger.info(f"[视频识别] AI 响应完成,长度: {len(full_content)}")
|
||
|
||
# 如果没有内容,尝试简化重试
|
||
if not full_content:
|
||
logger.info("[视频识别] 尝试简化请求重试...")
|
||
return await self._call_ai_api_with_video_simple(
|
||
user_message or "请描述这个视频的内容",
|
||
video_base64,
|
||
video_config
|
||
)
|
||
|
||
return full_content.strip()
|
||
|
||
except Exception as e:
|
||
logger.error(f"[视频识别] API 调用失败: {e}")
|
||
import traceback
|
||
logger.error(traceback.format_exc())
|
||
return ""
|
||
|
||
async def _call_ai_api_with_video_simple(self, user_message: str, video_base64: str, video_config: dict) -> str:
    """Call the video-recognition API with a stripped-down request.

    Fallback used when the full request (persona + history + video) produced
    no text: only the user's question and the inline video are sent.

    Args:
        user_message: Prompt text sent as the first part of the request.
        video_base64: Base64 video data, sent inline as ``video/mp4``. A
            leading ``data:...;base64,`` prefix is removed if present.
        video_config: Video settings; ``api_url``, ``api_key``, ``model``,
            ``max_tokens`` and ``timeout`` are read, each with a default.

    Returns:
        The joined text parts of the first candidate that contains any text,
        or ``""`` on an HTTP error, an unusable response or any exception.
    """
    try:
        api_url = video_config.get("api_url", "https://api.functen.cn/v1beta/models")
        api_key = video_config.get("api_key", self.config["api"]["api_key"])
        model = video_config.get("model", "gemini-3-pro-preview")
        full_url = f"{api_url}/{model}:generateContent"

        # The API expects bare base64; tolerate a data-URL being passed in.
        if video_base64.startswith("data:"):
            video_base64 = video_base64.partition(",")[2]

        # Minimal request: just the question and the video.
        payload = {
            "contents": [
                {
                    "parts": [
                        {"text": user_message},
                        {
                            "inline_data": {
                                "mime_type": "video/mp4",
                                "data": video_base64,
                            }
                        },
                    ]
                }
            ],
            "generationConfig": {
                "maxOutputTokens": video_config.get("max_tokens", 8192)
            },
        }
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}",
        }
        timeout = aiohttp.ClientTimeout(total=video_config.get("timeout", 360))

        # Use the same proxy settings as the primary video request; without
        # this the retry goes direct and fails wherever a proxy is required.
        connector = None
        proxy_config = self.config.get("proxy", {})
        if proxy_config.get("enabled", False) and PROXY_SUPPORT:
            proxy_type = proxy_config.get("type", "socks5").upper()
            proxy_host = proxy_config.get("host", "127.0.0.1")
            proxy_port = proxy_config.get("port", 7890)
            proxy_url = f"{proxy_type}://{proxy_host}:{proxy_port}"
            try:
                connector = ProxyConnector.from_url(proxy_url)
            except Exception as e:
                # Best effort: fall back to a direct connection.
                logger.warning(f"[视频识别-简化] 代理配置失败: {e}")

        logger.info(f"[视频识别-简化] 调用 API: {full_url}")

        async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
            async with session.post(full_url, json=payload, headers=headers) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    logger.error(f"[视频识别-简化] API 错误: {resp.status}, {error_text[:300]}")
                    return ""

                result = await resp.json()

        if not isinstance(result, dict):
            logger.error(f"[视频识别-简化] 仍然没有 candidates: {str(result)[:300]}")
            return ""

        logger.info(f"[视频识别-简化] API 响应 keys: {list(result.keys())}")

        # Return the first candidate that carries text, joining all of its
        # text parts so multi-part answers are not truncated.
        for candidate in result.get("candidates") or []:
            parts = candidate.get("content", {}).get("parts", [])
            text = "".join(part["text"] for part in parts if "text" in part)
            if text:
                logger.info(f"[视频识别-简化] 成功,长度: {len(text)}")
                return text

        logger.error(f"[视频识别-简化] 仍然没有 candidates: {str(result)[:300]}")
        return ""

    except Exception as e:
        logger.error(f"[视频识别-简化] 失败: {e}")
        return ""
|
||
|
||
def _should_reply_quote(self, message: dict, title_text: str) -> bool:
|
||
"""判断是否应该回复引用消息"""
|
||
is_group = message.get("IsGroup", False)
|
||
|
||
# 检查群聊/私聊开关
|
||
if is_group and not self.config["behavior"]["reply_group"]:
|
||
return False
|
||
if not is_group and not self.config["behavior"]["reply_private"]:
|
||
return False
|
||
|
||
trigger_mode = self.config["behavior"]["trigger_mode"]
|
||
|
||
# all模式:回复所有消息
|
||
if trigger_mode == "all":
|
||
return True
|
||
|
||
# mention模式:检查是否@了机器人
|
||
if trigger_mode == "mention":
|
||
if is_group:
|
||
ats = message.get("Ats", [])
|
||
if not ats:
|
||
return False
|
||
|
||
import tomllib
|
||
with open("main_config.toml", "rb") as f:
|
||
main_config = tomllib.load(f)
|
||
bot_wxid = main_config.get("Bot", {}).get("wxid", "")
|
||
|
||
return bot_wxid and bot_wxid in ats
|
||
else:
|
||
return True
|
||
|
||
# keyword模式:检查关键词
|
||
if trigger_mode == "keyword":
|
||
keywords = self.config["behavior"]["keywords"]
|
||
return any(kw in title_text for kw in keywords)
|
||
|
||
return False
|
||
|
||
async def _call_ai_api_with_image(self, user_message: str, image_base64: str, bot=None, from_wxid: str = None, chat_id: str = None, nickname: str = "", user_wxid: str = None, is_group: bool = False) -> str:
    """Call the chat-completions API with a text + image user message.

    Builds a system prompt (persona, current time, optional nickname), adds
    the stored conversation memory except its last entry, appends the current
    user message with the image, and consumes the streamed response.

    If the model requests tool calls, any text streamed so far is sent to the
    chat immediately (when ``bot`` and ``from_wxid`` are given), tool
    execution is started as a background task and ``""`` is returned.

    Args:
        user_message: Text part of the user message.
        image_base64: Image passed as the ``image_url`` of the user message.
        bot: Bot client used to send interim text; optional.
        from_wxid: Chat to send interim text to; optional.
        chat_id: Key used to look up conversation memory; optional.
        nickname: Nickname of the current user, added to the system prompt.
        user_wxid: Stored on ``self._current_user_wxid`` for tool calls.
        is_group: Stored on ``self._current_is_group`` for tool calls.

    Returns:
        The stripped reply text, ``""`` when tool execution was started, or a
        fixed apology when the model emitted a malformed tool-call format.

    Raises:
        Exception: On a non-200 API status or any request/stream failure
            (logged and re-raised).
    """
    import json

    api_config = self.config["api"]
    tools = self._collect_tools()

    # System prompt: persona + current time (+ nickname).
    system_content = self.system_prompt
    current_time = datetime.now()
    weekday_names = ("星期一", "星期二", "星期三", "星期四", "星期五", "星期六", "星期日")
    weekday = weekday_names[current_time.weekday()]
    time_str = current_time.strftime(f"%Y年%m月%d日 %H:%M:%S {weekday}")
    system_content += f"\n\n当前时间:{time_str}"
    if nickname:
        system_content += f"\n当前对话用户的昵称是:{nickname}"

    messages = [{"role": "system", "content": system_content}]

    # Conversation memory, minus its last entry (the current message is
    # appended below in its image-bearing form).
    if chat_id:
        memory_messages = self._get_memory_messages(chat_id)
        if memory_messages and len(memory_messages) > 1:
            messages.extend(memory_messages[:-1])

    messages.append({
        "role": "user",
        "content": [
            {"type": "text", "text": user_message},
            {"type": "image_url", "image_url": {"url": image_base64}},
        ],
    })

    # Remembered for tool calls triggered by this request.
    self._current_user_wxid = user_wxid
    self._current_is_group = is_group

    payload = {
        "model": api_config["model"],
        "messages": messages,
        "stream": True,
        "max_tokens": api_config.get("max_tokens", 4096),  # avoid truncated replies
    }
    if tools:
        payload["tools"] = tools

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_config['api_key']}",
    }
    timeout = aiohttp.ClientTimeout(total=api_config["timeout"])

    # Optional proxy; any problem falls back to a direct connection.
    connector = None
    proxy_config = self.config.get("proxy", {})
    if proxy_config.get("enabled", False):
        proxy_type = proxy_config.get("type", "socks5").upper()
        proxy_host = proxy_config.get("host", "127.0.0.1")
        proxy_port = proxy_config.get("port", 7890)
        proxy_username = proxy_config.get("username")
        proxy_password = proxy_config.get("password")

        if proxy_username and proxy_password:
            proxy_url = f"{proxy_type}://{proxy_username}:{proxy_password}@{proxy_host}:{proxy_port}"
        else:
            proxy_url = f"{proxy_type}://{proxy_host}:{proxy_port}"

        if PROXY_SUPPORT:
            try:
                connector = ProxyConnector.from_url(proxy_url)
                # Log without credentials.
                logger.debug(f"使用代理: {proxy_type}://{proxy_host}:{proxy_port}")
            except Exception as e:
                logger.warning(f"代理配置失败,将直连: {e}")
                connector = None
        else:
            logger.warning("代理功能不可用(aiohttp_socks 未安装),将直连")
            connector = None

    try:
        async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
            async with session.post(api_config["url"], json=payload, headers=headers) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    logger.error(f"API返回错误状态码: {resp.status}, 响应: {error_text}")
                    raise Exception(f"API错误 {resp.status}: {error_text}")

                full_content = ""
                tool_calls_dict = {}  # {index: tool_call}, assembled from deltas
                tool_call_hint_sent = False  # interim text already sent?

                # Server-sent events: one "data: {...}" JSON chunk per line.
                async for line in resp.content:
                    line = line.decode('utf-8').strip()
                    if not line or line == "data: [DONE]":
                        continue
                    if not line.startswith("data: "):
                        continue

                    try:
                        data = json.loads(line[6:])
                        choices = data.get("choices", [])
                        if not choices:
                            continue

                        delta = choices[0].get("delta", {})
                        content = delta.get("content", "")
                        if content:
                            full_content += content

                        if delta.get("tool_calls"):
                            # On the first tool call, flush text produced so far.
                            if not tool_call_hint_sent and bot and from_wxid:
                                tool_call_hint_sent = True
                                if full_content and full_content.strip():
                                    logger.info(f"[流式-图片] 检测到工具调用,先发送已有文本")
                                    await bot.send_text(from_wxid, full_content.strip())
                                else:
                                    logger.info("[流式-图片] 检测到工具调用,AI 未输出文本")

                            # Tool calls arrive incrementally; merge by index.
                            for tool_call_delta in delta["tool_calls"]:
                                index = tool_call_delta.get("index", 0)
                                if index not in tool_calls_dict:
                                    tool_calls_dict[index] = {
                                        "id": "",
                                        "type": "function",
                                        "function": {"name": "", "arguments": ""},
                                    }
                                entry = tool_calls_dict[index]

                                if "id" in tool_call_delta:
                                    entry["id"] = tool_call_delta["id"]
                                if "type" in tool_call_delta:
                                    entry["type"] = tool_call_delta["type"]
                                if "function" in tool_call_delta:
                                    func_delta = tool_call_delta["function"]
                                    if "name" in func_delta:
                                        entry["function"]["name"] += func_delta["name"]
                                    if "arguments" in func_delta:
                                        entry["function"]["arguments"] += func_delta["arguments"]
                    except Exception as e:
                        # A bad chunk must not abort the whole stream.
                        logger.debug(f"解析流式数据失败: {e}")

                tool_calls_data = [tool_calls_dict[i] for i in sorted(tool_calls_dict)]

                if tool_calls_data:
                    # Interim text was already sent during streaming.
                    logger.info(f"[图片] 启动异步工具执行,共 {len(tool_calls_data)} 个工具")
                    task = asyncio.create_task(
                        self._execute_tools_async_with_image(
                            tool_calls_data, bot, from_wxid, chat_id,
                            nickname, is_group, messages, image_base64
                        )
                    )
                    # The event loop keeps only weak references to tasks, so
                    # hold a strong one until the task finishes; otherwise it
                    # may be garbage-collected mid-execution.
                    background_tasks = getattr(self, "_background_tasks", None)
                    if background_tasks is None:
                        background_tasks = self._background_tasks = set()
                    background_tasks.add(task)
                    task.add_done_callback(background_tasks.discard)
                    return ""

                # Intercept tool calls written out as plain text.
                # (`and` binds tighter than `or`; parentheses make that explicit.)
                if "<tool_code>" in full_content or (
                    "print(" in full_content and "flow2_ai_image_generation" in full_content
                ):
                    logger.warning("检测到模型输出了错误的工具调用格式,拦截并返回提示")
                    return "抱歉,我遇到了一些技术问题,请重新描述一下你的需求~"

                return full_content.strip()

    except Exception as e:
        logger.error(f"调用AI API失败: {e}")
        raise
|
||
|
||
async def _send_chat_records(self, bot, from_wxid: str, title: str, content: str):
    """Send ``content`` to a chat as a "chat record" style app message.

    The text is split into parts of at most 800 characters (preferring line
    boundaries), each part becomes one ``dataitem`` of a ``recordinfo`` XML
    document, and that document is embedded in an ``appmsg`` of type 19 which
    is sent via ``bot._send_data_async(11214, ...)``.

    Args:
        bot: Bot client providing ``_send_data_async``.
        from_wxid: Target chat; a suffix of ``@chatroom`` marks a group.
        title: Title/description shown for the record.
        content: Text to send.

    Failures are logged and swallowed; nothing is returned.
    """
    try:
        import hashlib
        import time
        import uuid
        import xml.etree.ElementTree as ET
        from xml.sax.saxutils import escape

        is_group = from_wxid.endswith("@chatroom")
        chatroom_flag = "1" if is_group else "0"

        # --- split the content into parts of at most max_length chars ---
        max_length = 800
        if len(content) <= max_length:
            content_parts = [content]
        else:
            content_parts = []
            current_part = ""
            for line in content.split('\n'):
                # Flush the accumulated part if this line would overflow it.
                if current_part and len(current_part) + len(line) + 1 > max_length:
                    flushed = current_part.strip()
                    if flushed:  # never emit an empty record item
                        content_parts.append(flushed)
                    current_part = ""
                # A single line longer than the limit is cut repeatedly, so
                # no part can exceed max_length however long the line is.
                while len(line) + 1 > max_length:
                    content_parts.append(line[:max_length])
                    line = line[max_length:]
                current_part += line + '\n'

            if current_part.strip():
                content_parts.append(current_part.strip())

        # --- build the recordinfo document ---
        recordinfo = ET.Element("recordinfo")
        ET.SubElement(recordinfo, "info").text = title
        ET.SubElement(recordinfo, "isChatRoom").text = chatroom_flag
        datalist = ET.SubElement(recordinfo, "datalist")
        datalist.set("count", str(len(content_parts)))
        ET.SubElement(recordinfo, "desc").text = title
        ET.SubElement(recordinfo, "fromscene").text = "3"

        hashed_username = hashlib.sha256(from_wxid.encode("utf-8")).hexdigest()

        for i, part in enumerate(content_parts):
            di = ET.SubElement(datalist, "dataitem")
            di.set("datatype", "1")
            di.set("dataid", uuid.uuid4().hex)

            src_local_id = str((int(time.time() * 1000) % 90000) + 10000)
            new_msg_id = str(int(time.time() * 1000) + i)
            # One second apart, ending just before "now", to keep item order.
            create_time = str(int(time.time()) - len(content_parts) + i)

            ET.SubElement(di, "srcMsgLocalid").text = src_local_id
            ET.SubElement(di, "sourcetime").text = time.strftime("%Y-%m-%d %H:%M", time.localtime(int(create_time)))
            ET.SubElement(di, "fromnewmsgid").text = new_msg_id
            ET.SubElement(di, "srcMsgCreateTime").text = create_time
            ET.SubElement(di, "sourcename").text = "AI助手"
            ET.SubElement(di, "sourceheadurl").text = ""
            ET.SubElement(di, "datatitle").text = part
            ET.SubElement(di, "datadesc").text = part
            ET.SubElement(di, "datafmt").text = "text"
            ET.SubElement(di, "ischatroom").text = chatroom_flag

            dataitemsource = ET.SubElement(di, "dataitemsource")
            ET.SubElement(dataitemsource, "hashusername").text = hashed_username

        record_xml = ET.tostring(recordinfo, encoding="unicode")

        # --- wrap it in the appmsg envelope ---
        # The title is interpolated into hand-built XML, so it must be
        # escaped; a raw '<' or '&' would produce a malformed message.
        safe_title = escape(title)
        appmsg_xml = "".join([
            "<appmsg appid=\"\" sdkver=\"0\">",
            f"<title>{safe_title}</title>",
            f"<des>{safe_title}</des>",
            "<type>19</type>",
            "<url>https://support.weixin.qq.com/cgi-bin/mmsupport-bin/readtemplate?t=page/favorite_record__w_unsupport</url>",
            "<appattach><cdnthumbaeskey></cdnthumbaeskey><aeskey></aeskey></appattach>",
            f"<recorditem><![CDATA[{record_xml}]]></recorditem>",
            "<percent>0</percent>",
            "</appmsg>",
        ])

        await bot._send_data_async(11214, {"to_wxid": from_wxid, "content": appmsg_xml})
        logger.success(f"已发送聊天记录: {title}")

    except Exception as e:
        logger.error(f"发送聊天记录失败: {e}")
|
||
|
||
async def _process_image_to_history(self, bot, message: dict, content: str) -> bool:
    """Record an image/emoji message in the chat history and queue its description.

    Only group messages are handled, and only when the ``image_description``
    config section is enabled (the default). A placeholder entry is written to
    the history first; a task carrying the placeholder id is then put on
    ``self.image_desc_queue`` so the description can be filled in later.

    Args:
        bot: Bot client, passed along in the queued task.
        message: Incoming message dict (``FromWxid``, ``SenderWxid``,
            ``IsGroup`` are read).
        content: Message XML containing an ``<img>`` or ``<emoji>`` element.

    Returns:
        Always ``True``, including when the message is skipped or an error
        is logged.
    """
    chat_wxid = message.get("FromWxid", "")
    sender = message.get("SenderWxid", "")
    in_group = message.get("IsGroup", False)
    speaker_wxid = sender if in_group else chat_wxid

    # Private chats are not recorded here.
    if not in_group:
        return True

    # Feature switch.
    desc_settings = self.config.get("image_description", {})
    if not desc_settings.get("enabled", True):
        return True

    try:
        xml_root = ET.fromstring(content)

        # Image messages carry <img>, stickers carry <emoji>.
        for xpath in (".//img", ".//emoji"):
            node = xml_root.find(xpath)
            if node is not None:
                break
        else:
            return True

        media_url = node.get("cdnbigimgurl", "") or node.get("cdnurl", "")
        media_key = node.get("aeskey", "")
        sticker = node.tag == "emoji"

        # A URL is always required; an aeskey only for real images.
        if not media_url or (not sticker and not media_key):
            return True

        display_name = await self._get_user_nickname(bot, chat_wxid, speaker_wxid, in_group)

        # Reserve the slot in the history right away so ordering is kept.
        placeholder_id = str(uuid.uuid4())
        await self._add_to_history_with_id(chat_wxid, display_name, "[图片: 处理中...]", placeholder_id)
        logger.info(f"已插入图片占位符: {placeholder_id}")

        # Hand the actual work to the description workers.
        await self.image_desc_queue.put({
            "bot": bot,
            "from_wxid": chat_wxid,
            "nickname": display_name,
            "cdnbigimgurl": media_url,
            "aeskey": media_key,
            "is_emoji": sticker,
            "placeholder_id": placeholder_id,
            "config": desc_settings,
        })
        logger.info(f"图片描述任务已加入队列,当前队列长度: {self.image_desc_queue.qsize()}")

    except Exception as e:
        logger.error(f"处理图片消息失败: {e}")

    return True
|
||
|
||
async def _image_desc_worker(self):
    """Worker loop: take image-description tasks off the queue and run them.

    Runs until cancelled. Each task dict is unpacked into a call to
    ``_generate_and_update_image_description``. A failing task is logged and
    does not stop the worker.
    """
    while True:
        task = await self.image_desc_queue.get()
        try:
            await self._generate_and_update_image_description(
                task["bot"], task["from_wxid"], task["nickname"],
                task["cdnbigimgurl"], task["aeskey"], task["is_emoji"],
                task["placeholder_id"], task["config"]
            )
        except Exception as e:
            logger.error(f"图片描述工作协程异常: {e}")
        finally:
            # Always balance get() with task_done(), even when the task
            # failed; otherwise queue.join() would wait forever.
            self.image_desc_queue.task_done()
|
||
|
||
async def _generate_and_update_image_description(self, bot, from_wxid: str, nickname: str,
                                                 cdnbigimgurl: str, aeskey: str, is_emoji: bool,
                                                 placeholder_id: str, image_desc_config: dict):
    """Download an image/emoji, describe it with the AI, and update the history entry.

    The history entry identified by ``placeholder_id`` is rewritten to
    ``[图片: <description>]`` on success, or to a plain ``[图片]`` when the
    download fails, no description is produced, or an exception occurs.

    Args:
        bot: Bot client used to download regular images.
        from_wxid: Chat whose history entry is updated.
        nickname: Sender nickname, used only for logging.
        cdnbigimgurl: Media URL taken from the message XML.
        aeskey: Key passed to the image download; not used for emoji.
        is_emoji: Selects the emoji download path instead of the image one.
        placeholder_id: Id of the history entry to update.
        image_desc_config: Description settings; ``prompt`` is read here and
            the dict is passed on to the description call.
    """
    fallback_text = "[图片]"
    try:
        # Fetch and encode the media.
        encoded = (
            await self._download_emoji_and_encode(cdnbigimgurl)
            if is_emoji
            else await self._download_and_encode_image(bot, cdnbigimgurl, aeskey)
        )
        if not encoded:
            logger.warning(f"{'表情包' if is_emoji else '图片'}下载失败")
            await self._update_history_by_id(from_wxid, placeholder_id, fallback_text)
            return

        # Ask the model for a description.
        prompt_text = image_desc_config.get("prompt", "请用一句话简洁地描述这张图片的主要内容。")
        description = await self._generate_image_description(encoded, prompt_text, image_desc_config)

        if not description:
            await self._update_history_by_id(from_wxid, placeholder_id, fallback_text)
            logger.warning("图片描述生成失败")
            return

        await self._update_history_by_id(from_wxid, placeholder_id, f"[图片: {description}]")
        logger.success(f"已更新图片描述: {nickname} - {description[:30]}...")

    except Exception as e:
        logger.error(f"异步生成图片描述失败: {e}")
        await self._update_history_by_id(from_wxid, placeholder_id, fallback_text)
|
||
|
||
@on_image_message(priority=15)
async def handle_image_message(self, bot, message: dict):
    """Handle a directly sent image message.

    Delegates to ``_process_image_to_history`` with the message's ``Content``
    and returns its result; no AI reply is produced here.
    """
    logger.info("AIChat: handle_image_message 被调用")
    return await self._process_image_to_history(bot, message, message.get("Content", ""))
|
||
|
||
@on_emoji_message(priority=15)
async def handle_emoji_message(self, bot, message: dict):
    """Handle an emoji (sticker) message.

    Delegates to ``_process_image_to_history`` with the message's ``Content``
    and returns its result; no AI reply is produced here.
    """
    logger.info("AIChat: handle_emoji_message 被调用")
    return await self._process_image_to_history(bot, message, message.get("Content", ""))
|