Files
WechatHookBot/plugins/AIChat/main.py
2025-12-11 13:52:19 +08:00

3063 lines
135 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
AI 聊天插件
支持自定义模型、API 和人设
支持 Redis 存储对话历史和限流
"""
import asyncio
import tomllib
import aiohttp
import sqlite3
from pathlib import Path
from datetime import datetime
from loguru import logger
from utils.plugin_base import PluginBase
from utils.decorators import on_text_message, on_quote_message, on_image_message, on_emoji_message
from utils.redis_cache import get_cache
import xml.etree.ElementTree as ET
import base64
import uuid
# Optional proxy support: aiohttp_socks supplies SOCKS/HTTP proxy connectors
# for outbound HTTP requests; if it is not installed the plugin degrades
# gracefully to direct connections (PROXY_SUPPORT gates all proxy code paths).
try:
    from aiohttp_socks import ProxyConnector
    PROXY_SUPPORT = True
except ImportError:
    PROXY_SUPPORT = False
    logger.warning("aiohttp_socks 未安装,代理功能将不可用")
class AIChat(PluginBase):
    """AI chat plugin: configurable model/API endpoint and persona,
    Redis-backed conversation history, and per-user rate limiting."""
    # Plugin metadata consumed by the plugin manager / listing UI.
    description = "AI 聊天插件,支持自定义模型和人设"
    author = "ShiHao"
    version = "1.0.0"
def __init__(self):
    """Set up empty runtime state; real initialization happens in async_init()."""
    super().__init__()
    self.config = None                        # parsed config.toml (loaded in async_init)
    self.system_prompt = ""                   # active persona text
    self.history_dir = None                   # directory holding group history files
    self.persistent_memory_db = None          # path to the sqlite persistent-memory DB
    self.memory = {}                          # in-process fallback store: {chat_id: [messages]}
    self.history_locks = {}                   # one lock per conversation history
    self.image_desc_queue = asyncio.Queue()   # queued image-description jobs
    self.image_desc_workers = []              # background worker tasks
async def async_init(self):
    """Asynchronous plugin initialization.

    Loads config.toml, reads the persona (system prompt) file, logs proxy
    settings, prepares the group-history directory, starts the image
    description workers and creates the persistent-memory database.
    """
    # Load plugin configuration (TOML, binary mode as tomllib requires).
    config_path = Path(__file__).parent / "config.toml"
    with open(config_path, "rb") as f:
        self.config = tomllib.load(f)
    # Load the persona file; fall back to a generic assistant persona.
    prompt_file = self.config["prompt"]["system_prompt_file"]
    prompt_path = Path(__file__).parent / "prompts" / prompt_file
    if prompt_path.exists():
        with open(prompt_path, "r", encoding="utf-8") as f:
            self.system_prompt = f.read().strip()
        logger.success(f"已加载人设: {prompt_file}")
    else:
        logger.warning(f"人设文件不存在: {prompt_file},使用默认人设")
        self.system_prompt = "你是一个友好的 AI 助手。"
    # Log proxy settings if enabled; the actual connector is built per request.
    proxy_config = self.config.get("proxy", {})
    if proxy_config.get("enabled", False):
        proxy_type = proxy_config.get("type", "socks5")
        proxy_host = proxy_config.get("host", "127.0.0.1")
        proxy_port = proxy_config.get("port", 7890)
        logger.info(f"AI 聊天插件已启用代理: {proxy_type}://{proxy_host}:{proxy_port}")
    # Prepare the directory where group chat history is persisted.
    history_config = self.config.get("history", {})
    if history_config.get("enabled", True):
        history_dir_name = history_config.get("history_dir", "history")
        self.history_dir = Path(__file__).parent / history_dir_name
        self.history_dir.mkdir(exist_ok=True)
        logger.info(f"历史记录目录: {self.history_dir}")
    # Start two background workers that drain the image-description queue.
    for i in range(2):
        worker = asyncio.create_task(self._image_desc_worker())
        self.image_desc_workers.append(worker)
    logger.info("已启动 2 个图片描述工作协程")
    # Create (or migrate) the sqlite database for persistent memories.
    self._init_persistent_memory_db()
    logger.info(f"AI 聊天插件已加载,模型: {self.config['api']['model']}")
def _init_persistent_memory_db(self):
    """Create the sqlite database (and schema) backing persistent memories."""
    data_dir = Path(__file__).parent / "data"
    data_dir.mkdir(exist_ok=True)
    self.persistent_memory_db = data_dir / "persistent_memory.db"
    conn = sqlite3.connect(self.persistent_memory_db)
    try:
        # Schema is idempotent: safe to run on every startup.
        conn.execute("""
        CREATE TABLE IF NOT EXISTS memories (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        chat_id TEXT NOT NULL,
        chat_type TEXT NOT NULL,
        user_wxid TEXT NOT NULL,
        user_nickname TEXT,
        content TEXT NOT NULL,
        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )
        """)
        conn.execute("CREATE INDEX IF NOT EXISTS idx_chat_id ON memories(chat_id)")
        conn.commit()
    finally:
        conn.close()
    logger.info(f"持久记忆数据库已初始化: {self.persistent_memory_db}")
def _add_persistent_memory(self, chat_id: str, chat_type: str, user_wxid: str,
user_nickname: str, content: str) -> int:
"""添加持久记忆返回记忆ID"""
conn = sqlite3.connect(self.persistent_memory_db)
cursor = conn.cursor()
cursor.execute("""
INSERT INTO memories (chat_id, chat_type, user_wxid, user_nickname, content)
VALUES (?, ?, ?, ?, ?)
""", (chat_id, chat_type, user_wxid, user_nickname, content))
memory_id = cursor.lastrowid
conn.commit()
conn.close()
return memory_id
def _get_persistent_memories(self, chat_id: str) -> list:
"""获取指定会话的所有持久记忆"""
conn = sqlite3.connect(self.persistent_memory_db)
cursor = conn.cursor()
cursor.execute("""
SELECT id, user_nickname, content, created_at
FROM memories
WHERE chat_id = ?
ORDER BY created_at ASC
""", (chat_id,))
rows = cursor.fetchall()
conn.close()
return [{"id": r[0], "nickname": r[1], "content": r[2], "time": r[3]} for r in rows]
def _delete_persistent_memory(self, chat_id: str, memory_id: int) -> bool:
"""删除指定的持久记忆"""
conn = sqlite3.connect(self.persistent_memory_db)
cursor = conn.cursor()
cursor.execute("DELETE FROM memories WHERE id = ? AND chat_id = ?", (memory_id, chat_id))
deleted = cursor.rowcount > 0
conn.commit()
conn.close()
return deleted
def _clear_persistent_memories(self, chat_id: str) -> int:
"""清空指定会话的所有持久记忆,返回删除数量"""
conn = sqlite3.connect(self.persistent_memory_db)
cursor = conn.cursor()
cursor.execute("DELETE FROM memories WHERE chat_id = ?", (chat_id,))
deleted_count = cursor.rowcount
conn.commit()
conn.close()
return deleted_count
def _get_chat_id(self, from_wxid: str, sender_wxid: str = None, is_group: bool = False) -> str:
"""获取会话ID"""
if is_group:
# 群聊使用 "群ID:用户ID" 组合,确保每个用户有独立的对话记忆
user_wxid = sender_wxid or from_wxid
return f"{from_wxid}:{user_wxid}"
else:
return sender_wxid or from_wxid # 私聊使用用户ID
async def _get_user_nickname(self, bot, from_wxid: str, user_wxid: str, is_group: bool) -> str:
    """Resolve a group member's display nickname, cheapest source first.

    Lookup order: Redis cache -> chatroom member API -> MessageLogger DB ->
    raw wxid fallback.  Private chats return "" (no nickname needed).

    Args:
        bot: WechatHookClient instance.
        from_wxid: Group id (or peer id for private chats).
        user_wxid: The member's wxid.
        is_group: Whether the message came from a group chat.

    Returns:
        The nickname, or "" for private chats.
    """
    if not is_group:
        return ""
    nickname = ""
    # 1. Redis cache first — avoids an API round-trip per message.
    redis_cache = get_cache()
    if redis_cache and redis_cache.enabled:
        cached_info = redis_cache.get_user_basic_info(from_wxid, user_wxid)
        if cached_info and cached_info.get("nickname"):
            logger.debug(f"[缓存命中] 用户昵称: {user_wxid} -> {cached_info['nickname']}")
            return cached_info["nickname"]
    # 2. Cache miss: ask the chatroom API, then populate the cache.
    try:
        user_info = await bot.get_user_info_in_chatroom(from_wxid, user_wxid)
        if user_info and user_info.get("nickName", {}).get("string"):
            nickname = user_info["nickName"]["string"]
            if redis_cache and redis_cache.enabled:
                redis_cache.set_user_info(from_wxid, user_wxid, user_info)
                logger.debug(f"[已缓存] 用户昵称: {user_wxid} -> {nickname}")
            return nickname
    except Exception as e:
        logger.warning(f"API获取用户昵称失败: {e}")
    # 3. Fall back to the MessageLogger plugin's database (latest non-empty
    #    nickname recorded for this sender).
    if not nickname:
        try:
            from plugins.MessageLogger.main import MessageLogger
            msg_logger = MessageLogger.get_instance()
            if msg_logger:
                with msg_logger.get_db_connection() as conn:
                    with conn.cursor() as cursor:
                        cursor.execute(
                            "SELECT nickname FROM messages WHERE sender_wxid = %s AND nickname != '' ORDER BY create_time DESC LIMIT 1",
                            (user_wxid,)
                        )
                        result = cursor.fetchone()
                        if result:
                            nickname = result[0]
        except Exception as e:
            logger.debug(f"从数据库获取昵称失败: {e}")
    # 4. Last resort: show the raw wxid (or a placeholder when even that is empty).
    if not nickname:
        nickname = user_wxid or "未知用户"
    return nickname
def _check_rate_limit(self, user_wxid: str) -> tuple:
"""
检查用户是否超过限流
Args:
user_wxid: 用户wxid
Returns:
(是否允许, 剩余次数, 重置时间秒数)
"""
rate_limit_config = self.config.get("rate_limit", {})
if not rate_limit_config.get("enabled", True):
return (True, 999, 0)
redis_cache = get_cache()
if not redis_cache or not redis_cache.enabled:
return (True, 999, 0) # Redis 不可用时不限流
limit = rate_limit_config.get("ai_chat_limit", 20)
window = rate_limit_config.get("ai_chat_window", 60)
return redis_cache.check_rate_limit(user_wxid, limit, window, "ai_chat")
def _add_to_memory(self, chat_id: str, role: str, content, image_base64: str = None):
"""
添加消息到记忆
Args:
chat_id: 会话ID
role: 角色 (user/assistant)
content: 消息内容(可以是字符串或列表)
image_base64: 可选的图片base64数据
"""
if not self.config.get("memory", {}).get("enabled", False):
return
# 如果有图片,构建多模态内容
if image_base64:
message_content = [
{"type": "text", "text": content if isinstance(content, str) else ""},
{"type": "image_url", "image_url": {"url": image_base64}}
]
else:
message_content = content
# 优先使用 Redis 存储
redis_config = self.config.get("redis", {})
if redis_config.get("use_redis_history", True):
redis_cache = get_cache()
if redis_cache and redis_cache.enabled:
ttl = redis_config.get("chat_history_ttl", 86400)
redis_cache.add_chat_message(chat_id, role, message_content, ttl=ttl)
# 裁剪历史
max_messages = self.config["memory"]["max_messages"]
redis_cache.trim_chat_history(chat_id, max_messages)
return
# 降级到内存存储
if chat_id not in self.memory:
self.memory[chat_id] = []
self.memory[chat_id].append({"role": role, "content": message_content})
# 限制记忆长度
max_messages = self.config["memory"]["max_messages"]
if len(self.memory[chat_id]) > max_messages:
self.memory[chat_id] = self.memory[chat_id][-max_messages:]
def _get_memory_messages(self, chat_id: str) -> list:
"""获取记忆中的消息"""
if not self.config.get("memory", {}).get("enabled", False):
return []
# 优先从 Redis 获取
redis_config = self.config.get("redis", {})
if redis_config.get("use_redis_history", True):
redis_cache = get_cache()
if redis_cache and redis_cache.enabled:
max_messages = self.config["memory"]["max_messages"]
return redis_cache.get_chat_history(chat_id, max_messages)
# 降级到内存
return self.memory.get(chat_id, [])
def _clear_memory(self, chat_id: str):
"""清空指定会话的记忆"""
# 清空 Redis
redis_config = self.config.get("redis", {})
if redis_config.get("use_redis_history", True):
redis_cache = get_cache()
if redis_cache and redis_cache.enabled:
redis_cache.clear_chat_history(chat_id)
# 同时清空内存
if chat_id in self.memory:
del self.memory[chat_id]
async def _download_and_encode_image(self, bot, cdnurl: str, aeskey: str) -> str:
    """Download a WeChat CDN image and return it as a JPEG data-URL string.

    Checks the Redis media cache first; on a miss downloads via the bot's
    CDN API (trying file_type 2, then 1), polls for the file to be written,
    base64-encodes it, caches the result for 5 minutes and removes the
    temp file.

    Returns:
        "data:image/jpeg;base64,..." or "" on any failure.

    Fixes vs. original: removed redundant local `import os` / `import
    asyncio` (asyncio is module-level), replaced the bare `except:` around
    the temp-file cleanup, and bail out with "" when the downloaded file is
    still empty after the wait loop (the original would happily encode a
    zero-byte file that existed but had no data yet).
    """
    try:
        # 1. Try the Redis media cache.
        from utils.redis_cache import RedisCache
        media_key = None
        redis_cache = get_cache()
        if redis_cache and redis_cache.enabled:
            media_key = RedisCache.generate_media_key(cdnurl, aeskey)
            if media_key:
                cached_data = redis_cache.get_cached_media(media_key, "image")
                if cached_data:
                    logger.debug(f"[缓存命中] 图片从 Redis 获取: {media_key[:20]}...")
                    return cached_data
        # 2. Cache miss: download through the hook's CDN API.
        logger.debug(f"[缓存未命中] 开始下载图片...")
        temp_dir = Path(__file__).parent / "temp"
        temp_dir.mkdir(exist_ok=True)
        save_path = temp_dir / f"temp_{uuid.uuid4().hex[:8]}.jpg"
        save_path_str = str(save_path.resolve())
        success = await bot.cdn_download(cdnurl, aeskey, save_path_str, file_type=2)
        if not success:
            success = await bot.cdn_download(cdnurl, aeskey, save_path_str, file_type=1)
        if not success:
            return ""
        # The hook writes the file asynchronously; poll up to ~10s for data.
        for _ in range(20):
            if save_path.exists() and save_path.stat().st_size > 0:
                break
            await asyncio.sleep(0.5)
        if not save_path.exists() or save_path.stat().st_size == 0:
            return ""
        image_data = base64.b64encode(save_path.read_bytes()).decode()
        base64_result = f"data:image/jpeg;base64,{image_data}"
        # 3. Cache the encoded result for subsequent lookups.
        if redis_cache and redis_cache.enabled and media_key:
            redis_cache.cache_media(media_key, base64_result, "image", ttl=300)
            logger.debug(f"[已缓存] 图片缓存到 Redis: {media_key[:20]}...")
        # Best-effort temp-file cleanup.
        save_path.unlink(missing_ok=True)
        return base64_result
    except Exception as e:
        logger.error(f"下载图片失败: {e}")
        return ""
async def _download_emoji_and_encode(self, cdn_url: str, max_retries: int = 3) -> str:
    """Download a sticker/emoji over plain HTTP and return a GIF data URL.

    Tries the Redis media cache first; otherwise downloads with retries
    (timeout grows 30s, 45s, 60s per attempt; linear backoff between
    attempts), caches the encoded result for 5 minutes.

    Args:
        cdn_url: Emoji CDN URL as extracted from message XML (may contain
            HTML entities).
        max_retries: Number of download attempts.

    Returns:
        "data:image/gif;base64,..." string, or "" after all retries fail.

    Fixes vs. original: the entity decode was the no-op
    `cdn_url.replace("&", "&")`; XML attributes escape "&" as "&amp;", so
    the correct decode is `replace("&amp;", "&")`.  Also removed the
    temp_dir/filename/save_path locals that were computed but never used
    (the download goes straight to memory), and narrowed the bare
    `except:` around proxy-connector creation.
    """
    # Decode the XML/HTML entity so aiohttp sees real query separators.
    cdn_url = cdn_url.replace("&amp;", "&")
    # 1. Try the Redis media cache.
    from utils.redis_cache import RedisCache
    redis_cache = get_cache()
    media_key = RedisCache.generate_media_key(cdnurl=cdn_url)
    if redis_cache and redis_cache.enabled and media_key:
        cached_data = redis_cache.get_cached_media(media_key, "emoji")
        if cached_data:
            logger.debug(f"[缓存命中] 表情包从 Redis 获取: {media_key[:20]}...")
            return cached_data
    # 2. Cache miss: download over HTTP with retries.
    logger.debug(f"[缓存未命中] 开始下载表情包...")
    last_error = None
    for attempt in range(max_retries):
        try:
            # Grow the timeout on each retry.
            timeout = aiohttp.ClientTimeout(total=30 + attempt * 15)
            # Optional SOCKS/HTTP proxy (requires aiohttp_socks).
            connector = None
            proxy_config = self.config.get("proxy", {})
            if proxy_config.get("enabled", False) and PROXY_SUPPORT:
                proxy_type = proxy_config.get("type", "socks5").upper()
                proxy_host = proxy_config.get("host", "127.0.0.1")
                proxy_port = proxy_config.get("port", 7890)
                proxy_username = proxy_config.get("username")
                proxy_password = proxy_config.get("password")
                if proxy_username and proxy_password:
                    proxy_url = f"{proxy_type}://{proxy_username}:{proxy_password}@{proxy_host}:{proxy_port}"
                else:
                    proxy_url = f"{proxy_type}://{proxy_host}:{proxy_port}"
                try:
                    connector = ProxyConnector.from_url(proxy_url)
                except Exception:
                    connector = None  # fall back to a direct connection
            async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
                async with session.get(cdn_url) as response:
                    if response.status == 200:
                        content = await response.read()
                        if len(content) == 0:
                            logger.warning(f"表情包下载内容为空,重试 {attempt + 1}/{max_retries}")
                            continue
                        image_data = base64.b64encode(content).decode()
                        logger.debug(f"表情包下载成功,大小: {len(content)} 字节")
                        base64_result = f"data:image/gif;base64,{image_data}"
                        # 3. Cache for later messages referencing the same emoji.
                        if redis_cache and redis_cache.enabled and media_key:
                            redis_cache.cache_media(media_key, base64_result, "emoji", ttl=300)
                            logger.debug(f"[已缓存] 表情包缓存到 Redis: {media_key[:20]}...")
                        return base64_result
                    logger.warning(f"表情包下载失败,状态码: {response.status},重试 {attempt + 1}/{max_retries}")
        except asyncio.TimeoutError:
            last_error = "请求超时"
            logger.warning(f"表情包下载超时,重试 {attempt + 1}/{max_retries}")
        except aiohttp.ClientError as e:
            last_error = str(e)
            logger.warning(f"表情包下载网络错误: {e},重试 {attempt + 1}/{max_retries}")
        except Exception as e:
            last_error = str(e)
            logger.warning(f"表情包下载异常: {e},重试 {attempt + 1}/{max_retries}")
        # Linear backoff before the next attempt.
        if attempt < max_retries - 1:
            await asyncio.sleep(1 * (attempt + 1))
    logger.error(f"表情包下载失败,已重试 {max_retries} 次: {last_error}")
    return ""
async def _generate_image_description(self, image_base64: str, prompt: str, config: dict) -> str:
    """Ask a vision-capable model for a textual description of an image.

    Sends a streaming chat-completion request (OpenAI-compatible SSE) and
    accumulates the delta text chunks.

    Args:
        image_base64: Image as a data-URL base64 string.
        prompt: Instruction for the description.
        config: Image-description sub-config ("model", "max_tokens").

    Returns:
        The description text, or "" on any failure.
    """
    try:
        api_config = self.config["api"]
        # Allow a dedicated (cheaper/vision) model; default to the chat model.
        description_model = config.get("model", api_config["model"])
        # Single-turn multimodal request: prompt text + the image.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {"type": "image_url", "image_url": {"url": image_base64}}
                ]
            }
        ]
        payload = {
            "model": description_model,
            "messages": messages,
            "max_tokens": config.get("max_tokens", 1000),
            "stream": True
        }
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_config['api_key']}"
        }
        timeout = aiohttp.ClientTimeout(total=api_config["timeout"])
        # Optional SOCKS/HTTP proxy (mirrors the emoji-download logic).
        connector = None
        proxy_config = self.config.get("proxy", {})
        if proxy_config.get("enabled", False):
            proxy_type = proxy_config.get("type", "socks5").upper()
            proxy_host = proxy_config.get("host", "127.0.0.1")
            proxy_port = proxy_config.get("port", 7890)
            proxy_username = proxy_config.get("username")
            proxy_password = proxy_config.get("password")
            if proxy_username and proxy_password:
                proxy_url = f"{proxy_type}://{proxy_username}:{proxy_password}@{proxy_host}:{proxy_port}"
            else:
                proxy_url = f"{proxy_type}://{proxy_host}:{proxy_port}"
            if PROXY_SUPPORT:
                try:
                    connector = ProxyConnector.from_url(proxy_url)
                except Exception as e:
                    logger.warning(f"代理配置失败,将直连: {e}")
                    connector = None
        async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
            async with session.post(
                api_config["url"],
                json=payload,
                headers=headers
            ) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    logger.error(f"图片描述 API 返回错误: {resp.status}, {error_text[:200]}")
                    return ""
                # Consume the SSE stream: each "data: {...}" line carries a delta.
                import json
                description = ""
                async for line in resp.content:
                    line = line.decode('utf-8').strip()
                    if not line or line == "data: [DONE]":
                        continue
                    if line.startswith("data: "):
                        try:
                            data = json.loads(line[6:])
                            delta = data.get("choices", [{}])[0].get("delta", {})
                            content = delta.get("content", "")
                            if content:
                                description += content
                        except:
                            # Best-effort: skip malformed/partial SSE lines.
                            pass
                logger.debug(f"图片描述生成成功: {description}")
                return description.strip()
    except Exception as e:
        logger.error(f"生成图片描述失败: {e}")
        import traceback
        logger.error(f"详细错误: {traceback.format_exc()}")
        return ""
def _collect_tools(self):
    """Gather LLM tool schemas from all loaded plugins.

    Honors the [tools] config: mode "whitelist" keeps only listed tool
    names, "blacklist" drops listed names, anything else keeps everything.
    """
    from utils.plugin_manager import PluginManager
    tools_config = self.config.get("tools", {})
    mode = tools_config.get("mode", "all")
    allowed = set(tools_config.get("whitelist", []))
    blocked = set(tools_config.get("blacklist", []))
    selected = []
    for plugin in PluginManager().plugins.values():
        getter = getattr(plugin, "get_llm_tools", None)
        if getter is None:
            continue
        for tool in (getter() or []):
            tool_name = tool.get("function", {}).get("name", "")
            if mode == "whitelist":
                if tool_name in allowed:
                    selected.append(tool)
                    logger.debug(f"[白名单] 启用工具: {tool_name}")
            elif mode == "blacklist":
                if tool_name in blocked:
                    logger.debug(f"[黑名单] 禁用工具: {tool_name}")
                else:
                    selected.append(tool)
            else:  # "all": no filtering
                selected.append(tool)
    return selected
async def _handle_list_prompts(self, bot, from_wxid: str):
    """Send the list of available persona (.txt) files, marking the active one.

    Fix: both branches of the listing previously emitted the literal
    placeholder "(unknown)" instead of interpolating the file name, so the
    list was useless; they now show the actual filename (with a marker on
    the persona currently in use).
    """
    try:
        prompts_dir = Path(__file__).parent / "prompts"
        if not prompts_dir.exists():
            await bot.send_text(from_wxid, "❌ prompts 目录不存在")
            return
        txt_files = sorted(prompts_dir.glob("*.txt"))
        if not txt_files:
            await bot.send_text(from_wxid, "❌ 没有找到任何人设文件")
            return
        current_file = self.config["prompt"]["system_prompt_file"]
        msg = "📋 可用人设列表:\n\n"
        for i, file_path in enumerate(txt_files, 1):
            filename = file_path.name
            # Mark the persona currently in use.
            if filename == current_file:
                msg += f"{i}. {filename} ✅(当前)\n"
            else:
                msg += f"{i}. {filename}\n"
        msg += f"\n💡 使用方法:/切人设 文件名.txt"
        await bot.send_text(from_wxid, msg)
        logger.info("已发送人设列表")
    except Exception as e:
        logger.error(f"获取人设列表失败: {e}")
        await bot.send_text(from_wxid, f"❌ 获取人设列表失败: {str(e)}")
def _estimate_tokens(self, text: str) -> int:
"""
估算文本的 token 数量
简单估算规则:
- 中文:约 1.5 字符 = 1 token
- 英文:约 4 字符 = 1 token
- 混合文本取平均
"""
if not text:
return 0
# 统计中文字符数
chinese_chars = sum(1 for c in text if '\u4e00' <= c <= '\u9fff')
# 其他字符数
other_chars = len(text) - chinese_chars
# 估算 token 数
chinese_tokens = chinese_chars / 1.5
other_tokens = other_chars / 4
return int(chinese_tokens + other_tokens)
def _estimate_message_tokens(self, message: dict) -> int:
"""估算单条消息的 token 数"""
content = message.get("content", "")
if isinstance(content, str):
return self._estimate_tokens(content)
elif isinstance(content, list):
# 多模态消息
total = 0
for item in content:
if item.get("type") == "text":
total += self._estimate_tokens(item.get("text", ""))
elif item.get("type") == "image_url":
# 图片按 85 token 估算OpenAI 低分辨率图片)
total += 85
return total
return 0
async def _handle_context_stats(self, bot, from_wxid: str, user_wxid: str, is_group: bool):
    """Reply with a token-usage report for the current conversation.

    Groups use the JSON history mechanism (last max_context entries);
    private chats use the memory store.  Both add the system-prompt and
    persistent-memory token estimates.
    """
    try:
        chat_id = self._get_chat_id(from_wxid, user_wxid, is_group)
        # Persistent memories are keyed by group id (groups) or user id (DMs).
        memory_chat_id = from_wxid if is_group else user_wxid
        persistent_memories = self._get_persistent_memories(memory_chat_id) if memory_chat_id else []
        persistent_tokens = 0
        if persistent_memories:
            # Mirror exactly how _call_ai_api renders memories into the prompt.
            persistent_tokens += self._estimate_tokens("【持久记忆】以下是用户要求你记住的重要信息:\n")
            for m in persistent_memories:
                mem_time = m['time'][:10] if m['time'] else ""
                persistent_tokens += self._estimate_tokens(f"- [{mem_time}] {m['nickname']}: {m['content']}\n")
        if is_group:
            # Group chats: shared JSON history, trimmed to max_context for the AI.
            history = await self._load_history(from_wxid)
            max_context = self.config.get("history", {}).get("max_context", 50)
            context_messages = history[-max_context:] if len(history) > max_context else history
            # Estimate tokens the same way messages are rendered ("[nick] text").
            context_tokens = 0
            for msg in context_messages:
                msg_content = msg.get("content", "")
                nickname = msg.get("nickname", "")
                if isinstance(msg_content, list):
                    # Multimodal entry: text parts + flat image estimate.
                    for item in msg_content:
                        if item.get("type") == "text":
                            context_tokens += self._estimate_tokens(f"[{nickname}] {item.get('text', '')}")
                        elif item.get("type") == "image_url":
                            context_tokens += 85
                else:
                    context_tokens += self._estimate_tokens(f"[{nickname}] {msg_content}")
            system_tokens = self._estimate_tokens(self.system_prompt)
            total_tokens = system_tokens + persistent_tokens + context_tokens
            # Usage relative to the model's configured context window.
            context_limit = self.config.get("api", {}).get("context_limit", 200000)
            usage_percent = (total_tokens / context_limit) * 100
            remaining_tokens = context_limit - total_tokens
            msg = f"📊 群聊上下文统计\n\n"
            msg += f"💬 历史总条数: {len(history)}\n"
            msg += f"📤 AI可见条数: {len(context_messages)}/{max_context}\n"
            msg += f"🤖 人设 Token: ~{system_tokens}\n"
            msg += f"📌 持久记忆: {len(persistent_memories)} 条 (~{persistent_tokens} token)\n"
            msg += f"📝 上下文 Token: ~{context_tokens}\n"
            msg += f"📦 总计 Token: ~{total_tokens}\n"
            msg += f"📈 使用率: {usage_percent:.1f}% (剩余 ~{remaining_tokens:,})\n"
            msg += f"\n💡 /清空记忆 清空上下文 | /记忆列表 查看持久记忆"
        else:
            # Private chats: per-user memory store.
            memory_messages = self._get_memory_messages(chat_id)
            max_messages = self.config.get("memory", {}).get("max_messages", 20)
            context_tokens = 0
            for msg in memory_messages:
                context_tokens += self._estimate_message_tokens(msg)
            system_tokens = self._estimate_tokens(self.system_prompt)
            total_tokens = system_tokens + persistent_tokens + context_tokens
            context_limit = self.config.get("api", {}).get("context_limit", 200000)
            usage_percent = (total_tokens / context_limit) * 100
            remaining_tokens = context_limit - total_tokens
            msg = f"📊 私聊上下文统计\n\n"
            msg += f"💬 记忆条数: {len(memory_messages)}/{max_messages}\n"
            msg += f"🤖 人设 Token: ~{system_tokens}\n"
            msg += f"📌 持久记忆: {len(persistent_memories)} 条 (~{persistent_tokens} token)\n"
            msg += f"📝 上下文 Token: ~{context_tokens}\n"
            msg += f"📦 总计 Token: ~{total_tokens}\n"
            msg += f"📈 使用率: {usage_percent:.1f}% (剩余 ~{remaining_tokens:,})\n"
            msg += f"\n💡 /清空记忆 清空上下文 | /记忆列表 查看持久记忆"
        await bot.send_text(from_wxid, msg)
        logger.info(f"已发送上下文统计: {chat_id}")
    except Exception as e:
        logger.error(f"获取上下文统计失败: {e}")
        await bot.send_text(from_wxid, f"❌ 获取上下文统计失败: {str(e)}")
async def _handle_switch_prompt(self, bot, from_wxid: str, content: str):
    """Switch the active persona to the given prompt file (admin command).

    Fix: the "file not found" reply, the success reply and the success log
    all contained the literal placeholder "(unknown)" where the filename
    should have been interpolated; they now include the actual file name.
    """
    try:
        # Command format: "/切人设 <filename.txt>".
        parts = content.split(maxsplit=1)
        if len(parts) < 2:
            await bot.send_text(from_wxid, "❌ 请指定人设文件名\n格式:/切人设 文件名.txt")
            return
        filename = parts[1].strip()
        prompt_path = Path(__file__).parent / "prompts" / filename
        if not prompt_path.exists():
            await bot.send_text(from_wxid, f"❌ 人设文件不存在: {filename}")
            return
        with open(prompt_path, "r", encoding="utf-8") as f:
            new_prompt = f.read().strip()
        # Swap the in-memory persona and remember the choice in config.
        self.system_prompt = new_prompt
        self.config["prompt"]["system_prompt_file"] = filename
        await bot.send_text(from_wxid, f"✅ 已切换人设: {filename}")
        logger.success(f"管理员切换人设: {filename}")
    except Exception as e:
        logger.error(f"切换人设失败: {e}")
        await bot.send_text(from_wxid, f"❌ 切换人设失败: {str(e)}")
@on_text_message(priority=80)
async def handle_message(self, bot, message: dict):
    """Entry point for incoming text messages.

    Dispatches slash commands (persona, memory, stats, persistent-memory
    management), records group history, applies trigger and rate-limit
    rules, then calls the AI and sends the reply.

    Returns:
        False to stop further plugin processing (commands, rate-limited),
        None otherwise.
    """
    content = message.get("Content", "").strip()
    from_wxid = message.get("FromWxid", "")
    sender_wxid = message.get("SenderWxid", "")
    is_group = message.get("IsGroup", False)
    # In groups the real speaker is SenderWxid; in DMs it's FromWxid.
    user_wxid = sender_wxid if is_group else from_wxid
    # Read the bot's own wxid and the admin list from the main config.
    # NOTE(review): tomllib is already imported at module scope; this local
    # import is redundant but harmless.
    import tomllib
    with open("main_config.toml", "rb") as f:
        main_config = tomllib.load(f)
    bot_wxid = main_config.get("Bot", {}).get("wxid", "")
    admins = main_config.get("Bot", {}).get("admins", [])
    # --- Slash commands (each swallows the message by returning False) ---
    # List available personas (exact match).
    if content == "/人设列表":
        await self._handle_list_prompts(bot, from_wxid)
        return False
    # Switch persona (admin only; prefix match).
    if content.startswith("/切人设 ") or content.startswith("/切换人设 "):
        if user_wxid in admins:
            await self._handle_switch_prompt(bot, from_wxid, content)
        else:
            await bot.send_text(from_wxid, "❌ 仅管理员可以切换人设")
        return False
    # Clear the current conversation memory (command text is configurable).
    clear_command = self.config.get("memory", {}).get("clear_command", "/清空记忆")
    if content == clear_command:
        chat_id = self._get_chat_id(from_wxid, user_wxid, is_group)
        self._clear_memory(chat_id)
        await bot.send_text(from_wxid, "✅ 已清空当前会话的记忆")
        return False
    # Context/token usage report.
    if content == "/context" or content == "/上下文":
        await self._handle_context_stats(bot, from_wxid, user_wxid, is_group)
        return False
    # Memory status summary (admin only).
    if content == "/记忆状态":
        if user_wxid in admins:
            if is_group:
                history = await self._load_history(from_wxid)
                max_context = self.config.get("history", {}).get("max_context", 50)
                context_count = min(len(history), max_context)
                msg = f"📊 群聊记忆: {len(history)}\n"
                msg += f"💬 AI可见: 最近 {context_count}"
                await bot.send_text(from_wxid, msg)
            else:
                chat_id = self._get_chat_id(from_wxid, user_wxid, is_group)
                memory = self._get_memory_messages(chat_id)
                msg = f"📊 私聊记忆: {len(memory)}"
                await bot.send_text(from_wxid, msg)
        else:
            await bot.send_text(from_wxid, "❌ 仅管理员可以查看记忆状态")
        return False
    # --- Persistent-memory commands ---
    # Record a persistent memory: "/记录 <content>" (prefix is 4 chars).
    if content.startswith("/记录 "):
        memory_content = content[4:].strip()
        if memory_content:
            nickname = await self._get_user_nickname(bot, from_wxid, user_wxid, is_group)
            # Memories are shared per group; per user for private chats.
            memory_chat_id = from_wxid if is_group else user_wxid
            chat_type = "group" if is_group else "private"
            memory_id = self._add_persistent_memory(
                memory_chat_id, chat_type, user_wxid, nickname, memory_content
            )
            await bot.send_text(from_wxid, f"✅ 已记录到持久记忆 (ID: {memory_id})")
            logger.info(f"添加持久记忆: {memory_chat_id} - {memory_content[:30]}...")
        else:
            await bot.send_text(from_wxid, "❌ 请输入要记录的内容\n格式:/记录 要记住的内容")
        return False
    # List persistent memories (available to everyone).
    if content == "/记忆列表" or content == "/持久记忆":
        memory_chat_id = from_wxid if is_group else user_wxid
        memories = self._get_persistent_memories(memory_chat_id)
        if memories:
            msg = f"📋 持久记忆列表 (共 {len(memories)} 条)\n\n"
            for m in memories:
                time_str = m['time'][:16] if m['time'] else "未知"
                content_preview = m['content'][:30] + "..." if len(m['content']) > 30 else m['content']
                msg += f"[{m['id']}] {m['nickname']}: {content_preview}\n 📅 {time_str}\n"
            msg += f"\n💡 删除记忆:/删除记忆 ID (管理员)"
        else:
            msg = "📋 暂无持久记忆"
        await bot.send_text(from_wxid, msg)
        return False
    # Delete a persistent memory by id (admin only; prefix is 6 chars).
    if content.startswith("/删除记忆 "):
        if user_wxid in admins:
            try:
                memory_id = int(content[6:].strip())
                memory_chat_id = from_wxid if is_group else user_wxid
                if self._delete_persistent_memory(memory_chat_id, memory_id):
                    await bot.send_text(from_wxid, f"✅ 已删除记忆 ID: {memory_id}")
                else:
                    await bot.send_text(from_wxid, f"❌ 未找到记忆 ID: {memory_id}")
            except ValueError:
                await bot.send_text(from_wxid, "❌ 请输入有效的记忆ID\n格式:/删除记忆 ID")
        else:
            await bot.send_text(from_wxid, "❌ 仅管理员可以删除持久记忆")
        return False
    # Wipe all persistent memories for this chat (admin only).
    if content == "/清空持久记忆":
        if user_wxid in admins:
            memory_chat_id = from_wxid if is_group else user_wxid
            deleted_count = self._clear_persistent_memories(memory_chat_id)
            await bot.send_text(from_wxid, f"✅ 已清空 {deleted_count} 条持久记忆")
        else:
            await bot.send_text(from_wxid, "❌ 仅管理员可以清空持久记忆")
        return False
    # --- Normal messages ---
    # Decide whether to answer (trigger mode / @ detection / keywords).
    should_reply = self._should_reply(message, content, bot_wxid)
    # Resolve the speaker's nickname (cached) for history entries.
    nickname = await self._get_user_nickname(bot, from_wxid, user_wxid, is_group)
    # All group messages are appended to history, replied-to or not, so the
    # AI sees the full conversation when it is eventually triggered.
    if is_group:
        await self._add_to_history(from_wxid, nickname, content)
    if not should_reply:
        return
    # Rate limiting (only checked when a reply is actually due).
    allowed, remaining, reset_time = self._check_rate_limit(user_wxid)
    if not allowed:
        rate_limit_config = self.config.get("rate_limit", {})
        msg = rate_limit_config.get("rate_limit_message", "⚠️ 消息太频繁了,请 {seconds} 秒后再试~")
        msg = msg.format(seconds=reset_time)
        await bot.send_text(from_wxid, msg)
        logger.warning(f"用户 {user_wxid} 触发限流,{reset_time}秒后重置")
        return False
    # Strip the leading @mention to get the actual question text.
    actual_content = self._extract_content(message, content)
    if not actual_content:
        return
    logger.info(f"AI 处理消息: {actual_content[:50]}...")
    try:
        # Record the user turn, then call the model with retries.
        chat_id = self._get_chat_id(from_wxid, user_wxid, is_group)
        self._add_to_memory(chat_id, "user", actual_content)
        max_retries = self.config.get("api", {}).get("max_retries", 2)
        response = None
        last_error = None
        for attempt in range(max_retries + 1):
            try:
                response = await self._call_ai_api(actual_content, bot, from_wxid, chat_id, nickname, user_wxid, is_group)
                # Return value semantics:
                #   None -> a tool call was dispatched asynchronously; do not retry
                #   ""   -> genuinely empty response; retry
                #   text -> normal response
                if response is None:
                    logger.info("AI 触发工具调用,已异步处理")
                    break
                if response == "" and attempt < max_retries:
                    logger.warning(f"AI 返回空内容,重试 {attempt + 1}/{max_retries}")
                    await asyncio.sleep(1)  # brief pause before retrying
                    continue
                break  # success, or retries exhausted
            except Exception as e:
                last_error = e
                if attempt < max_retries:
                    logger.warning(f"AI API 调用失败,重试 {attempt + 1}/{max_retries}: {e}")
                    await asyncio.sleep(1)
                else:
                    raise
        # Send the reply and record the assistant turn.  None/"" means the
        # answer was already delivered some other way (e.g. tool output).
        if response:
            await bot.send_text(from_wxid, response)
            self._add_to_memory(chat_id, "assistant", response)
            # Mirror the bot's own reply into the group history.
            if is_group:
                with open("main_config.toml", "rb") as f:
                    main_config = tomllib.load(f)
                bot_nickname = main_config.get("Bot", {}).get("nickname", "机器人")
                await self._add_to_history(from_wxid, bot_nickname, response)
            logger.success(f"AI 回复成功: {response[:50]}...")
        else:
            logger.info("AI 回复为空或已通过其他方式发送(如聊天记录)")
    except Exception as e:
        import traceback
        error_detail = traceback.format_exc()
        logger.error(f"AI 处理失败: {type(e).__name__}: {str(e)}")
        logger.error(f"详细错误:\n{error_detail}")
        await bot.send_text(from_wxid, "抱歉,我遇到了一些问题,请稍后再试。")
def _should_reply(self, message: dict, content: str, bot_wxid: str = None) -> bool:
"""判断是否应该回复"""
# 检查是否由AutoReply插件触发
if message.get('_auto_reply_triggered'):
return True
is_group = message.get("IsGroup", False)
# 检查群聊/私聊开关
if is_group and not self.config["behavior"].get("reply_group", True):
return False
if not is_group and not self.config["behavior"].get("reply_private", True):
return False
trigger_mode = self.config["behavior"].get("trigger_mode", "mention")
# all 模式:回复所有消息
if trigger_mode == "all":
return True
# mention 模式:检查是否@了机器人
if trigger_mode == "mention":
if is_group:
ats = message.get("Ats", [])
# 如果没有 bot_wxid从配置文件读取
if not bot_wxid:
import tomllib
with open("main_config.toml", "rb") as f:
main_config = tomllib.load(f)
bot_wxid = main_config.get("Bot", {}).get("wxid", "")
bot_nickname = main_config.get("Bot", {}).get("nickname", "")
else:
# 也需要读取昵称用于备用检测
import tomllib
with open("main_config.toml", "rb") as f:
main_config = tomllib.load(f)
bot_nickname = main_config.get("Bot", {}).get("nickname", "")
# 方式1检查 @ 列表中是否包含机器人的 wxid
if ats and bot_wxid and bot_wxid in ats:
return True
# 方式2备用检测 - 从消息内容中检查是否包含 @机器人昵称
# (当 API 没有返回 at_user_list 时使用)
if bot_nickname and f"@{bot_nickname}" in content:
logger.debug(f"通过内容检测到 @{bot_nickname},触发回复")
return True
return False
else:
# 私聊直接回复
return True
# keyword 模式:检查关键词
if trigger_mode == "keyword":
keywords = self.config["behavior"]["keywords"]
return any(kw in content for kw in keywords)
return False
def _extract_content(self, message: dict, content: str) -> str:
"""提取实际消息内容(去除@等)"""
is_group = message.get("IsGroup", False)
if is_group:
# 群聊消息,去除@部分
# 格式通常是 "@昵称 消息内容"
parts = content.split(maxsplit=1)
if len(parts) > 1 and parts[0].startswith("@"):
return parts[1].strip()
return content.strip()
return content.strip()
    async def _call_ai_api(self, user_message: str, bot=None, from_wxid: str = None, chat_id: str = None, nickname: str = "", user_wxid: str = None, is_group: bool = False) -> str:
        """Call the chat-completion API (streaming) and return the reply text.

        Builds the system prompt (persona + current time + persistent
        memories), attaches either group history (Redis/JSON) or private
        memory as context, then streams the response. If the model emits
        tool calls, they are dispatched asynchronously via
        _execute_tools_async and None is returned.

        Args:
            user_message: Current user message text.
            bot: Bot instance, used to flush partial text before tool calls.
            from_wxid: Chat id the reply goes to (group id or peer wxid).
            chat_id: Memory key for private-chat context.
            nickname: Sender nickname (prefixed to group messages).
            user_wxid: Sender wxid.
            is_group: Whether the message came from a group chat.

        Returns:
            Reply text, or None when tool calls were started asynchronously.

        Raises:
            Exception: On HTTP errors, timeouts or malformed responses.
        """
        api_config = self.config["api"]
        # Collect LLM tool definitions exposed by other plugins.
        tools = self._collect_tools()
        logger.info(f"收集到 {len(tools)} 个工具函数")
        if tools:
            tool_names = [t["function"]["name"] for t in tools]
            logger.info(f"工具列表: {tool_names}")
        # Build the system prompt: persona first.
        system_content = self.system_prompt
        # Append current date/time so the model can answer time questions.
        current_time = datetime.now()
        weekday_map = {
            0: "星期一", 1: "星期二", 2: "星期三", 3: "星期四",
            4: "星期五", 5: "星期六", 6: "星期日"
        }
        weekday = weekday_map[current_time.weekday()]
        time_str = current_time.strftime(f"%Y年%m月%d日 %H:%M:%S {weekday}")
        system_content += f"\n\n当前时间:{time_str}"
        if nickname:
            system_content += f"\n当前对话用户的昵称是:{nickname}"
        # Append persistent memories (facts the user asked the bot to remember).
        memory_chat_id = from_wxid if is_group else user_wxid
        if memory_chat_id:
            persistent_memories = self._get_persistent_memories(memory_chat_id)
            if persistent_memories:
                system_content += "\n\n【持久记忆】以下是用户要求你记住的重要信息:\n"
                for m in persistent_memories:
                    mem_time = m['time'][:10] if m['time'] else ""
                    system_content += f"- [{mem_time}] {m['nickname']}: {m['content']}\n"
        messages = [{"role": "system", "content": system_content}]
        # Group chats: load context from the JSON/Redis history store.
        if is_group and from_wxid:
            history = await self._load_history(from_wxid)
            max_context = self.config.get("history", {}).get("max_context", 50)
            # Only the most recent N records are used as context.
            recent_history = history[-max_context:] if len(history) > max_context else history
            # Convert stored records to chat-completion messages.
            for msg in recent_history:
                msg_nickname = msg.get("nickname", "")
                msg_content = msg.get("content", "")
                # Multimodal record (list of text/image parts)?
                if isinstance(msg_content, list):
                    # Prefix the nickname onto the text parts only.
                    content_with_nickname = []
                    for item in msg_content:
                        if item.get("type") == "text":
                            content_with_nickname.append({
                                "type": "text",
                                "text": f"[{msg_nickname}] {item.get('text', '')}"
                            })
                        else:
                            # Keep image parts (and other types) untouched.
                            content_with_nickname.append(item)
                    messages.append({
                        "role": "user",
                        "content": content_with_nickname
                    })
                else:
                    # Plain text: "[nickname] content".
                    messages.append({
                        "role": "user",
                        "content": f"[{msg_nickname}] {msg_content}"
                    })
        else:
            # Private chats use the in-process memory mechanism instead.
            if chat_id:
                memory_messages = self._get_memory_messages(chat_id)
                if memory_messages and len(memory_messages) > 1:
                    # Drop the last entry — presumably the current message,
                    # which is appended separately below. TODO confirm.
                    messages.extend(memory_messages[:-1])
        # Finally append the current user message.
        messages.append({"role": "user", "content": f"[{nickname}] {user_message}" if is_group and nickname else user_message})
        # Stash user context for tool execution (read back via getattr in
        # _execute_tool_and_get_result).
        self._current_user_wxid = user_wxid
        self._current_is_group = is_group
        payload = {
            "model": api_config["model"],
            "messages": messages,
            "max_tokens": api_config.get("max_tokens", 4096)  # avoid truncated replies
        }
        if tools:
            payload["tools"] = tools
            logger.debug(f"已将 {len(tools)} 个工具添加到请求中")
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_config['api_key']}"
        }
        timeout = aiohttp.ClientTimeout(total=api_config["timeout"])
        # Optional SOCKS/HTTP proxy via aiohttp_socks.
        connector = None
        proxy_config = self.config.get("proxy", {})
        if proxy_config.get("enabled", False):
            proxy_type = proxy_config.get("type", "socks5").upper()
            proxy_host = proxy_config.get("host", "127.0.0.1")
            proxy_port = proxy_config.get("port", 7890)
            proxy_username = proxy_config.get("username")
            proxy_password = proxy_config.get("password")
            # Build the proxy URL; include credentials only when both are set.
            if proxy_username and proxy_password:
                proxy_url = f"{proxy_type}://{proxy_username}:{proxy_password}@{proxy_host}:{proxy_port}"
            else:
                proxy_url = f"{proxy_type}://{proxy_host}:{proxy_port}"
            if PROXY_SUPPORT:
                try:
                    connector = ProxyConnector.from_url(proxy_url)
                    logger.debug(f"使用代理: {proxy_type}://{proxy_host}:{proxy_port}")
                except Exception as e:
                    logger.warning(f"代理配置失败,将直连: {e}")
                    connector = None
            else:
                logger.warning("代理功能不可用aiohttp_socks 未安装),将直连")
                connector = None
        # Always stream the response (enables early tool-call detection).
        payload["stream"] = True
        try:
            async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
                logger.debug(f"发送流式 API 请求: {api_config['url']}")
                async with session.post(
                    api_config["url"],
                    json=payload,
                    headers=headers
                ) as resp:
                    if resp.status != 200:
                        error_text = await resp.text()
                        logger.error(f"API 返回错误状态码: {resp.status}, 响应: {error_text}")
                        raise Exception(f"API 错误 {resp.status}: {error_text}")
                    # Consume the SSE stream line by line.
                    import json
                    full_content = ""
                    tool_calls_dict = {}  # assemble tool calls by stream index: {index: tool_call}
                    tool_call_hint_sent = False  # whether partial text was flushed on first tool-call delta
                    async for line in resp.content:
                        line = line.decode('utf-8').strip()
                        if not line or line == "data: [DONE]":
                            continue
                        if line.startswith("data: "):
                            try:
                                data = json.loads(line[6:])
                                choices = data.get("choices", [])
                                if not choices:
                                    continue
                                delta = choices[0].get("delta", {})
                                # Accumulate text deltas.
                                content = delta.get("content", "")
                                if content:
                                    full_content += content
                                # Assemble incremental tool-call fragments.
                                if delta.get("tool_calls"):
                                    # On the first tool-call fragment, flush any text
                                    # produced so far to the user immediately.
                                    if not tool_call_hint_sent and bot and from_wxid:
                                        tool_call_hint_sent = True
                                        if full_content and full_content.strip():
                                            logger.info(f"[流式] 检测到工具调用,先发送已有文本: {full_content[:30]}...")
                                            await bot.send_text(from_wxid, full_content.strip())
                                        else:
                                            # No text so far: stay silent, no default hint.
                                            logger.info("[流式] 检测到工具调用AI 未输出文本")
                                    for tool_call_delta in delta["tool_calls"]:
                                        index = tool_call_delta.get("index", 0)
                                        # First fragment for this index: init the slot.
                                        if index not in tool_calls_dict:
                                            tool_calls_dict[index] = {
                                                "id": "",
                                                "type": "function",
                                                "function": {
                                                    "name": "",
                                                    "arguments": ""
                                                }
                                            }
                                        # Merge id fragment.
                                        if "id" in tool_call_delta:
                                            tool_calls_dict[index]["id"] = tool_call_delta["id"]
                                        # Merge type fragment.
                                        if "type" in tool_call_delta:
                                            tool_calls_dict[index]["type"] = tool_call_delta["type"]
                                        # Merge function name/arguments fragments (concatenated).
                                        if "function" in tool_call_delta:
                                            func_delta = tool_call_delta["function"]
                                            if "name" in func_delta:
                                                tool_calls_dict[index]["function"]["name"] += func_delta["name"]
                                            if "arguments" in func_delta:
                                                tool_calls_dict[index]["function"]["arguments"] += func_delta["arguments"]
                            except Exception as e:
                                logger.debug(f"解析流式数据失败: {e}")
                                pass
                    # Flatten the per-index dict into an ordered list.
                    tool_calls_data = [tool_calls_dict[i] for i in sorted(tool_calls_dict.keys())] if tool_calls_dict else []
                    logger.info(f"流式 API 响应完成, 内容长度: {len(full_content)}, 工具调用数: {len(tool_calls_data)}")
                    # Tool calls detected: dispatch them in the background.
                    if tool_calls_data:
                        # Any partial text was already flushed during streaming.
                        logger.info(f"启动异步工具执行,共 {len(tool_calls_data)} 个工具")
                        asyncio.create_task(
                            self._execute_tools_async(
                                tool_calls_data, bot, from_wxid, chat_id,
                                nickname, is_group, messages
                            )
                        )
                        # None tells the caller the reply is handled asynchronously.
                        return None
                    # Guard against hallucinated tool-call markup in plain text.
                    # NOTE(review): due to and/or precedence this triggers on any
                    # "<tool_code>" OR ("print(" AND "flow2_ai_image_generation").
                    if "<tool_code>" in full_content or "print(" in full_content and "flow2_ai_image_generation" in full_content:
                        logger.warning("检测到模型输出了错误的工具调用格式,拦截并返回提示")
                        return "抱歉,我遇到了一些技术问题,请重新描述一下你的需求~"
                    return full_content.strip()
        except aiohttp.ClientError as e:
            logger.error(f"网络请求失败: {type(e).__name__}: {str(e)}")
            raise Exception(f"网络请求失败: {str(e)}")
        except asyncio.TimeoutError:
            logger.error(f"API 请求超时 (timeout={api_config['timeout']}s)")
            raise Exception(f"API 请求超时")
        except KeyError as e:
            logger.error(f"API 响应格式错误,缺少字段: {e}")
            raise Exception(f"API 响应格式错误: {e}")
def _get_history_file(self, chat_id: str) -> Path:
"""获取群聊历史记录文件路径"""
if not self.history_dir:
return None
safe_name = chat_id.replace("@", "_").replace(":", "_")
return self.history_dir / f"{safe_name}.json"
def _get_history_lock(self, chat_id: str) -> asyncio.Lock:
"""获取指定会话的锁, 每个会话一把"""
lock = self.history_locks.get(chat_id)
if lock is None:
lock = asyncio.Lock()
self.history_locks[chat_id] = lock
return lock
def _read_history_file(self, history_file: Path) -> list:
try:
import json
with open(history_file, "r", encoding="utf-8") as f:
return json.load(f)
except FileNotFoundError:
return []
except Exception as e:
logger.error(f"读取历史记录失败: {e}")
return []
def _write_history_file(self, history_file: Path, history: list):
import json
history_file.parent.mkdir(parents=True, exist_ok=True)
temp_file = Path(str(history_file) + ".tmp")
with open(temp_file, "w", encoding="utf-8") as f:
json.dump(history, f, ensure_ascii=False, indent=2)
temp_file.replace(history_file)
def _use_redis_for_group_history(self) -> bool:
"""检查是否使用 Redis 存储群聊历史"""
redis_config = self.config.get("redis", {})
if not redis_config.get("use_redis_history", True):
return False
redis_cache = get_cache()
return redis_cache and redis_cache.enabled
async def _load_history(self, chat_id: str) -> list:
"""异步读取群聊历史, 优先使用 Redis"""
# 优先使用 Redis
if self._use_redis_for_group_history():
redis_cache = get_cache()
max_history = self.config.get("history", {}).get("max_history", 100)
return redis_cache.get_group_history(chat_id, max_history)
# 降级到文件存储
history_file = self._get_history_file(chat_id)
if not history_file:
return []
lock = self._get_history_lock(chat_id)
async with lock:
return self._read_history_file(history_file)
async def _save_history(self, chat_id: str, history: list):
"""异步写入群聊历史, 包含长度截断"""
# Redis 模式下不需要单独保存add_group_message 已经处理
if self._use_redis_for_group_history():
return
history_file = self._get_history_file(chat_id)
if not history_file:
return
max_history = self.config.get("history", {}).get("max_history", 100)
if len(history) > max_history:
history = history[-max_history:]
lock = self._get_history_lock(chat_id)
async with lock:
self._write_history_file(history_file, history)
async def _add_to_history(self, chat_id: str, nickname: str, content: str, image_base64: str = None):
"""
将消息存入群聊历史
Args:
chat_id: 群聊ID
nickname: 用户昵称
content: 消息内容
image_base64: 可选的图片base64
"""
if not self.config.get("history", {}).get("enabled", True):
return
# 构建消息内容
if image_base64:
message_content = [
{"type": "text", "text": content},
{"type": "image_url", "image_url": {"url": image_base64}}
]
else:
message_content = content
# 优先使用 Redis
if self._use_redis_for_group_history():
redis_cache = get_cache()
redis_config = self.config.get("redis", {})
ttl = redis_config.get("group_history_ttl", 172800)
redis_cache.add_group_message(chat_id, nickname, message_content, ttl=ttl)
# 裁剪历史
max_history = self.config.get("history", {}).get("max_history", 100)
redis_cache.trim_group_history(chat_id, max_history)
return
# 降级到文件存储
history_file = self._get_history_file(chat_id)
if not history_file:
return
lock = self._get_history_lock(chat_id)
async with lock:
history = self._read_history_file(history_file)
message_record = {
"nickname": nickname,
"timestamp": datetime.now().isoformat(),
"content": message_content
}
history.append(message_record)
max_history = self.config.get("history", {}).get("max_history", 100)
if len(history) > max_history:
history = history[-max_history:]
self._write_history_file(history_file, history)
async def _add_to_history_with_id(self, chat_id: str, nickname: str, content: str, record_id: str):
"""带ID的历史追加, 便于后续更新"""
if not self.config.get("history", {}).get("enabled", True):
return
# 优先使用 Redis
if self._use_redis_for_group_history():
redis_cache = get_cache()
redis_config = self.config.get("redis", {})
ttl = redis_config.get("group_history_ttl", 172800)
redis_cache.add_group_message(chat_id, nickname, content, record_id=record_id, ttl=ttl)
# 裁剪历史
max_history = self.config.get("history", {}).get("max_history", 100)
redis_cache.trim_group_history(chat_id, max_history)
return
# 降级到文件存储
history_file = self._get_history_file(chat_id)
if not history_file:
return
lock = self._get_history_lock(chat_id)
async with lock:
history = self._read_history_file(history_file)
message_record = {
"id": record_id,
"nickname": nickname,
"timestamp": datetime.now().isoformat(),
"content": content
}
history.append(message_record)
max_history = self.config.get("history", {}).get("max_history", 100)
if len(history) > max_history:
history = history[-max_history:]
self._write_history_file(history_file, history)
async def _update_history_by_id(self, chat_id: str, record_id: str, new_content: str):
"""根据ID更新历史记录"""
if not self.config.get("history", {}).get("enabled", True):
return
# 优先使用 Redis
if self._use_redis_for_group_history():
redis_cache = get_cache()
redis_cache.update_group_message_by_id(chat_id, record_id, new_content)
return
# 降级到文件存储
history_file = self._get_history_file(chat_id)
if not history_file:
return
lock = self._get_history_lock(chat_id)
async with lock:
history = self._read_history_file(history_file)
for record in history:
if record.get("id") == record_id:
record["content"] = new_content
break
max_history = self.config.get("history", {}).get("max_history", 100)
if len(history) > max_history:
history = history[-max_history:]
self._write_history_file(history_file, history)
    async def _execute_tool_and_get_result(self, tool_name: str, arguments: dict, bot, from_wxid: str):
        """Dispatch a tool call to the first plugin that handles it successfully.

        Iterates over all plugins implementing execute_llm_tool and returns
        the first result whose success flag is True.

        Args:
            tool_name: Name of the tool function to execute.
            arguments: Tool arguments; user_wxid / is_group are injected here.
            bot: Bot instance passed through to the plugin.
            from_wxid: Chat id of the originating message.

        Returns:
            dict: The successful tool result, or {"success": False, ...}
            when no plugin handled the tool.
        """
        from utils.plugin_manager import PluginManager
        # Inject the current user context (set by _call_ai_api before the request).
        arguments["user_wxid"] = getattr(self, "_current_user_wxid", from_wxid)
        arguments["is_group"] = getattr(self, "_current_is_group", False)
        logger.info(f"开始执行工具: {tool_name}")
        plugins = PluginManager().plugins
        logger.info(f"检查 {len(plugins)} 个插件")
        for plugin_name, plugin in plugins.items():
            logger.debug(f"检查插件: {plugin_name}, 有execute_llm_tool: {hasattr(plugin, 'execute_llm_tool')}")
            if hasattr(plugin, 'execute_llm_tool'):
                try:
                    logger.info(f"调用 {plugin_name}.execute_llm_tool")
                    result = await plugin.execute_llm_tool(tool_name, arguments, bot, from_wxid)
                    logger.info(f"{plugin_name} 返回: {result}")
                    if result is not None:
                        if result.get("success"):
                            logger.success(f"工具执行成功: {tool_name}")
                            return result
                        else:
                            # NOTE(review): a success=False result is treated as
                            # "this plugin does not handle the tool" and dropped,
                            # so genuine failures from a handling plugin are not
                            # propagated — confirm this is intentional.
                            logger.debug(f"{plugin_name} 不处理此工具,继续检查下一个插件")
                except Exception as e:
                    logger.error(f"工具执行异常 ({plugin_name}): {tool_name}, {e}")
                    import traceback
                    logger.error(f"详细错误: {traceback.format_exc()}")
        # No plugin returned a successful result for this tool name.
        logger.warning(f"未找到工具: {tool_name}")
        return {"success": False, "message": f"未找到工具: {tool_name}"}
    async def _execute_tools_async(self, tool_calls_data: list, bot, from_wxid: str,
                                   chat_id: str, nickname: str, is_group: bool,
                                   messages: list):
        """
        Execute tool calls asynchronously without blocking the main flow.

        The AI has already replied to the user; tools run here in the
        background and send their results when finished.

        Supports the need_ai_reply flag: such tool results are fed back to
        the AI to continue the conversation (keeping context and persona).
        """
        import json
        try:
            logger.info(f"开始异步执行 {len(tool_calls_data)} 个工具调用")
            # Build one task per tool call so they can run in parallel.
            tasks = []
            tool_info_list = []  # bookkeeping for post-processing each result
            for tool_call in tool_calls_data:
                function_name = tool_call.get("function", {}).get("name", "")
                arguments_str = tool_call.get("function", {}).get("arguments", "{}")
                tool_call_id = tool_call.get("id", "")
                if not function_name:
                    continue
                try:
                    arguments = json.loads(arguments_str)
                except:
                    # Malformed JSON arguments from the model: run with none.
                    arguments = {}
                logger.info(f"[异步] 准备执行工具: {function_name}, 参数: {arguments}")
                # Create the coroutine; gathered below.
                task = self._execute_tool_and_get_result(function_name, arguments, bot, from_wxid)
                tasks.append(task)
                tool_info_list.append({
                    "tool_call_id": tool_call_id,
                    "function_name": function_name,
                    "arguments": arguments
                })
            # Run all tools concurrently; exceptions are returned, not raised.
            if tasks:
                results = await asyncio.gather(*tasks, return_exceptions=True)
                # Tool results that must be routed back through the AI.
                need_ai_reply_results = []
                # Handle each tool's result.
                for i, result in enumerate(results):
                    tool_info = tool_info_list[i]
                    function_name = tool_info["function_name"]
                    tool_call_id = tool_info["tool_call_id"]
                    if isinstance(result, Exception):
                        logger.error(f"[异步] 工具 {function_name} 执行异常: {result}")
                        # Report the failure to the chat.
                        await bot.send_text(from_wxid, f"{function_name} 执行失败")
                        continue
                    if result and result.get("success"):
                        logger.success(f"[异步] 工具 {function_name} 执行成功")
                        # Defer to the AI: it will phrase the final answer.
                        if result.get("need_ai_reply"):
                            need_ai_reply_results.append({
                                "tool_call_id": tool_call_id,
                                "function_name": function_name,
                                "result": result.get("message", "")
                            })
                            continue  # do not send directly; AI handles it
                        # Send the raw tool message only when the tool did not
                        # already send output itself and explicitly asks for it.
                        if not result.get("already_sent") and result.get("message"):
                            # Some tools want their result text relayed.
                            msg = result.get("message", "")
                            if msg and not result.get("no_reply"):
                                # Only relay when the tool opted in.
                                if result.get("send_result_text"):
                                    await bot.send_text(from_wxid, msg)
                        # Optionally persist the tool result into chat memory.
                        if result.get("save_to_memory") and chat_id:
                            self._add_to_memory(chat_id, "assistant", f"[工具 {function_name} 结果]: {result.get('message', '')}")
                    else:
                        logger.warning(f"[异步] 工具 {function_name} 执行失败: {result}")
                        if result and result.get("message"):
                            await bot.send_text(from_wxid, f"{result.get('message')}")
                # Feed deferred results back to the AI for a final reply.
                if need_ai_reply_results:
                    await self._continue_with_tool_results(
                        need_ai_reply_results, bot, from_wxid, chat_id,
                        nickname, is_group, messages, tool_calls_data
                    )
            logger.info(f"[异步] 所有工具执行完成")
        except Exception as e:
            logger.error(f"[异步] 工具执行总体异常: {e}")
            import traceback
            logger.error(f"详细错误: {traceback.format_exc()}")
            try:
                await bot.send_text(from_wxid, "❌ 工具执行过程中出现错误")
            except:
                pass
    async def _continue_with_tool_results(self, tool_results: list, bot, from_wxid: str,
                                          chat_id: str, nickname: str, is_group: bool,
                                          messages: list, tool_calls_data: list):
        """
        Continue the AI conversation based on tool results (keeping context
        and persona). Used for tools flagged need_ai_reply=True, e.g. web search.
        """
        import json
        try:
            logger.info(f"[工具回传] 开始基于 {len(tool_results)} 个工具结果继续对话")
            # Build the tool-call/tool-result message pair expected by the API.
            # 1. Append the assistant message that declared the tool calls.
            tool_calls_msg = []
            for tool_call in tool_calls_data:
                tool_call_id = tool_call.get("id", "")
                function_name = tool_call.get("function", {}).get("name", "")
                arguments_str = tool_call.get("function", {}).get("arguments", "{}")
                # Only include the calls whose results are being fed back.
                for tr in tool_results:
                    if tr["tool_call_id"] == tool_call_id:
                        tool_calls_msg.append({
                            "id": tool_call_id,
                            "type": "function",
                            "function": {
                                "name": function_name,
                                "arguments": arguments_str
                            }
                        })
                        break
            if tool_calls_msg:
                messages.append({
                    "role": "assistant",
                    "content": None,
                    "tool_calls": tool_calls_msg
                })
            # 2. Append one "tool" message per result.
            for tr in tool_results:
                messages.append({
                    "role": "tool",
                    "tool_call_id": tr["tool_call_id"],
                    "content": tr["result"]
                })
            # 3. Call the AI again WITHOUT the tools parameter so it cannot
            #    trigger another round of tool calls.
            api_config = self.config["api"]
            proxy_config = self.config.get("proxy", {})
            payload = {
                "model": api_config["model"],
                "messages": messages,
                "max_tokens": api_config.get("max_tokens", 4096),
                "stream": True
            }
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {api_config['api_key']}"
            }
            # Plain aiohttp proxy URL here (not aiohttp_socks as in _call_ai_api).
            proxy = None
            if proxy_config.get("enabled", False):
                proxy_type = proxy_config.get("type", "http")
                proxy_host = proxy_config.get("host", "127.0.0.1")
                proxy_port = proxy_config.get("port", 7890)
                proxy = f"{proxy_type}://{proxy_host}:{proxy_port}"
            timeout = aiohttp.ClientTimeout(total=api_config.get("timeout", 120))
            async with aiohttp.ClientSession(timeout=timeout) as session:
                async with session.post(
                    api_config["url"],
                    json=payload,
                    headers=headers,
                    proxy=proxy
                ) as resp:
                    if resp.status != 200:
                        error_text = await resp.text()
                        logger.error(f"[工具回传] AI API 错误: {resp.status}, {error_text}")
                        await bot.send_text(from_wxid, "❌ AI 处理搜索结果失败")
                        return
                    # Stream the SSE response and accumulate the text deltas.
                    full_content = ""
                    async for line in resp.content:
                        line = line.decode("utf-8").strip()
                        if not line or not line.startswith("data: "):
                            continue
                        if line == "data: [DONE]":
                            break
                        try:
                            data = json.loads(line[6:])
                            delta = data.get("choices", [{}])[0].get("delta", {})
                            content = delta.get("content", "")
                            if content:
                                full_content += content
                        except:
                            continue
                    # Send the AI's final reply to the chat.
                    if full_content.strip():
                        await bot.send_text(from_wxid, full_content.strip())
                        logger.success(f"[工具回传] AI 回复完成,长度: {len(full_content)}")
                        # Persist the reply into per-chat memory.
                        if chat_id:
                            self._add_to_memory(chat_id, "assistant", full_content.strip())
                    else:
                        logger.warning("[工具回传] AI 返回空内容")
        except Exception as e:
            logger.error(f"[工具回传] 继续对话失败: {e}")
            import traceback
            logger.error(f"详细错误: {traceback.format_exc()}")
            try:
                await bot.send_text(from_wxid, "❌ 处理搜索结果时出错")
            except:
                pass
    async def _execute_tools_async_with_image(self, tool_calls_data: list, bot, from_wxid: str,
                                              chat_id: str, nickname: str, is_group: bool,
                                              messages: list, image_base64: str):
        """
        Execute tool calls asynchronously with an image attached
        (e.g. image-to-image generation scenarios).

        The AI has already replied to the user; tools run here and send
        their results when done. For the flow2_ai_image_generation tool
        the quoted image is injected into the arguments as base64.
        """
        import json
        try:
            logger.info(f"[异步-图片] 开始执行 {len(tool_calls_data)} 个工具调用")
            # Build one task per tool call so they can run in parallel.
            tasks = []
            tool_info_list = []
            for tool_call in tool_calls_data:
                function_name = tool_call.get("function", {}).get("name", "")
                arguments_str = tool_call.get("function", {}).get("arguments", "{}")
                tool_call_id = tool_call.get("id", "")
                if not function_name:
                    continue
                try:
                    arguments = json.loads(arguments_str)
                except:
                    # Malformed JSON arguments from the model: run with none.
                    arguments = {}
                # Inject the image payload for the image-to-image tool only.
                if function_name == "flow2_ai_image_generation" and image_base64:
                    arguments["image_base64"] = image_base64
                    logger.info(f"[异步-图片] 图生图工具,已添加图片数据")
                logger.info(f"[异步-图片] 准备执行工具: {function_name}")
                task = self._execute_tool_and_get_result(function_name, arguments, bot, from_wxid)
                tasks.append(task)
                tool_info_list.append({
                    "tool_call_id": tool_call_id,
                    "function_name": function_name,
                    "arguments": arguments
                })
            # Run all tools concurrently; exceptions are returned, not raised.
            if tasks:
                results = await asyncio.gather(*tasks, return_exceptions=True)
                for i, result in enumerate(results):
                    tool_info = tool_info_list[i]
                    function_name = tool_info["function_name"]
                    if isinstance(result, Exception):
                        logger.error(f"[异步-图片] 工具 {function_name} 执行异常: {result}")
                        await bot.send_text(from_wxid, f"{function_name} 执行失败")
                        continue
                    if result and result.get("success"):
                        logger.success(f"[异步-图片] 工具 {function_name} 执行成功")
                        # Relay the tool's message only when it opted in and
                        # did not already send output itself.
                        if not result.get("already_sent") and result.get("message"):
                            msg = result.get("message", "")
                            if msg and not result.get("no_reply") and result.get("send_result_text"):
                                await bot.send_text(from_wxid, msg)
                        # Optionally persist the tool result into chat memory.
                        if result.get("save_to_memory") and chat_id:
                            self._add_to_memory(chat_id, "assistant", f"[工具 {function_name} 结果]: {result.get('message', '')}")
                    else:
                        logger.warning(f"[异步-图片] 工具 {function_name} 执行失败: {result}")
                        if result and result.get("message"):
                            await bot.send_text(from_wxid, f"{result.get('message')}")
            logger.info(f"[异步-图片] 所有工具执行完成")
        except Exception as e:
            logger.error(f"[异步-图片] 工具执行总体异常: {e}")
            import traceback
            logger.error(f"详细错误: {traceback.format_exc()}")
            try:
                await bot.send_text(from_wxid, "❌ 工具执行过程中出现错误")
            except:
                pass
    @on_quote_message(priority=79)
    async def handle_quote_message(self, bot, message: dict):
        """Handle quote/reply messages (quoted image/video, or the /记录 command).

        Returns:
            True to let other plugins continue processing this message,
            False to stop the processing chain.
        """
        content = message.get("Content", "").strip()
        from_wxid = message.get("FromWxid", "")
        sender_wxid = message.get("SenderWxid", "")
        is_group = message.get("IsGroup", False)
        # In a group the "user" is the sender; in private chat it is the peer.
        user_wxid = sender_wxid if is_group else from_wxid
        try:
            # Parse the quote XML: <title> carries the user's new text.
            root = ET.fromstring(content)
            title = root.find(".//title")
            if title is None or not title.text:
                logger.debug("引用消息没有标题,跳过")
                return True
            title_text = title.text.strip()
            logger.info(f"收到引用消息,标题: {title_text[:50]}...")
            # "/记录" command: save the quoted message into persistent memory.
            if title_text == "/记录" or title_text.startswith("/记录 "):
                # The quoted message lives under <refermsg>.
                refermsg = root.find(".//refermsg")
                if refermsg is not None:
                    # Display name of the quoted message's author.
                    refer_displayname = refermsg.find("displayname")
                    refer_nickname = refer_displayname.text if refer_displayname is not None and refer_displayname.text else "未知"
                    # Body of the quoted message.
                    refer_content_elem = refermsg.find("content")
                    if refer_content_elem is not None and refer_content_elem.text:
                        refer_text = refer_content_elem.text.strip()
                        # XML payloads (images etc.) are summarized, not stored raw.
                        if refer_text.startswith("<?xml") or refer_text.startswith("<"):
                            refer_text = f"[多媒体消息]"
                    else:
                        refer_text = "[空消息]"
                    # Memory entry: "<author>: <quoted text>".
                    memory_content = f"{refer_nickname}: {refer_text}"
                    # Optional extra note after "/记录 " is appended as a remark.
                    if title_text.startswith("/记录 "):
                        extra_note = title_text[4:].strip()
                        if extra_note:
                            memory_content += f" (备注: {extra_note})"
                    # Persist the memory entry.
                    nickname = await self._get_user_nickname(bot, from_wxid, user_wxid, is_group)
                    memory_chat_id = from_wxid if is_group else user_wxid
                    chat_type = "group" if is_group else "private"
                    memory_id = self._add_persistent_memory(
                        memory_chat_id, chat_type, user_wxid, nickname, memory_content
                    )
                    await bot.send_text(from_wxid, f"✅ 已记录到持久记忆 (ID: {memory_id})\n📝 {memory_content[:50]}...")
                    logger.info(f"通过引用添加持久记忆: {memory_chat_id} - {memory_content[:30]}...")
                else:
                    await bot.send_text(from_wxid, "❌ 无法获取被引用的消息")
                return False
            # Trigger check (mention / keyword etc.) for quote replies.
            if not self._should_reply_quote(message, title_text):
                logger.debug("引用消息不满足回复条件")
                return True
            # Locate the quoted message payload.
            refermsg = root.find(".//refermsg")
            if refermsg is None:
                logger.debug("引用消息中没有 refermsg 节点")
                return True
            refer_content = refermsg.find("content")
            if refer_content is None or not refer_content.text:
                logger.debug("引用消息中没有 content")
                return True
            # The quoted payload is HTML-escaped XML; unescape then parse.
            import html
            refer_xml = html.unescape(refer_content.text)
            refer_root = ET.fromstring(refer_xml)
            # Quoted image?
            img = refer_root.find(".//img")
            # Quoted video?
            video = refer_root.find(".//videomsg")
            if img is None and video is None:
                logger.debug("引用的消息不是图片或视频")
                return True
            # NOTE(review): this repeats the _should_reply_quote check from
            # above; the second call is redundant but harmless.
            if not self._should_reply_quote(message, title_text):
                logger.debug("引用消息不满足回复条件")
                return True
            # Per-user rate limiting.
            allowed, remaining, reset_time = self._check_rate_limit(user_wxid)
            if not allowed:
                rate_limit_config = self.config.get("rate_limit", {})
                msg = rate_limit_config.get("rate_limit_message", "⚠️ 消息太频繁了,请 {seconds} 秒后再试~")
                msg = msg.format(seconds=reset_time)
                await bot.send_text(from_wxid, msg)
                logger.warning(f"用户 {user_wxid} 触发限流,{reset_time}秒后重置")
                return False
            # Resolve the sender's display name (cached).
            nickname = await self._get_user_nickname(bot, from_wxid, user_wxid, is_group)
            chat_id = self._get_chat_id(from_wxid, user_wxid, is_group)
            # Quoted video: delegate to the two-AI video pipeline.
            if video is not None:
                return await self._handle_quote_video(
                    bot, video, title_text, from_wxid, user_wxid,
                    is_group, nickname, chat_id
                )
            # Quoted image: need both the CDN url and the AES key to download.
            cdnbigimgurl = img.get("cdnbigimgurl", "")
            aeskey = img.get("aeskey", "")
            if not cdnbigimgurl or not aeskey:
                logger.warning(f"图片信息不完整: cdnurl={bool(cdnbigimgurl)}, aeskey={bool(aeskey)}")
                return True
            logger.info(f"AI处理引用图片消息: {title_text[:50]}...")
            # Download and base64-encode the image.
            logger.info(f"开始下载图片: {cdnbigimgurl[:50]}...")
            image_base64 = await self._download_and_encode_image(bot, cdnbigimgurl, aeskey)
            if not image_base64:
                logger.error("图片下载失败")
                await bot.send_text(from_wxid, "❌ 无法处理图片")
                return False
            logger.info("图片下载和编码成功")
            # Record the user message (with image) in per-chat memory.
            self._add_to_memory(chat_id, "user", title_text, image_base64=image_base64)
            # Also record it in the group history (group chats only).
            if is_group:
                await self._add_to_history(from_wxid, nickname, title_text, image_base64=image_base64)
            # Ask the vision-capable AI endpoint.
            response = await self._call_ai_api_with_image(title_text, image_base64, bot, from_wxid, chat_id, nickname, user_wxid, is_group)
            if response:
                await bot.send_text(from_wxid, response)
                self._add_to_memory(chat_id, "assistant", response)
                # Record the bot's reply in the group history.
                if is_group:
                    import tomllib
                    with open("main_config.toml", "rb") as f:
                        main_config = tomllib.load(f)
                    bot_nickname = main_config.get("Bot", {}).get("nickname", "机器人")
                    await self._add_to_history(from_wxid, bot_nickname, response)
                logger.success(f"AI回复成功: {response[:50]}...")
            return False
        except Exception as e:
            logger.error(f"处理引用消息失败: {e}")
            return True
async def _handle_quote_video(self, bot, video_elem, title_text: str, from_wxid: str,
user_wxid: str, is_group: bool, nickname: str, chat_id: str):
"""处理引用的视频消息 - 双AI架构"""
try:
# 检查视频识别功能是否启用
video_config = self.config.get("video_recognition", {})
if not video_config.get("enabled", True):
logger.info("[视频识别] 功能未启用")
await bot.send_text(from_wxid, "❌ 视频识别功能未启用")
return False
# 提取视频 CDN 信息
cdnvideourl = video_elem.get("cdnvideourl", "")
aeskey = video_elem.get("aeskey", "")
# 如果主要的CDN信息为空尝试获取原始视频信息
if not cdnvideourl or not aeskey:
cdnvideourl = video_elem.get("cdnrawvideourl", "")
aeskey = video_elem.get("cdnrawvideoaeskey", "")
if not cdnvideourl or not aeskey:
logger.warning(f"[视频识别] 视频信息不完整: cdnurl={bool(cdnvideourl)}, aeskey={bool(aeskey)}")
await bot.send_text(from_wxid, "❌ 无法获取视频信息")
return False
logger.info(f"[视频识别] 处理引用视频: {title_text[:50]}...")
# 提示用户正在处理
await bot.send_text(from_wxid, "🎬 正在分析视频,请稍候...")
# 下载并编码视频
video_base64 = await self._download_and_encode_video(bot, cdnvideourl, aeskey)
if not video_base64:
logger.error("[视频识别] 视频下载失败")
await bot.send_text(from_wxid, "❌ 视频下载失败")
return False
logger.info("[视频识别] 视频下载和编码成功")
# ========== 第一步视频AI 分析视频内容 ==========
video_description = await self._analyze_video_content(video_base64, video_config)
if not video_description:
logger.error("[视频识别] 视频AI分析失败")
await bot.send_text(from_wxid, "❌ 视频分析失败")
return False
logger.info(f"[视频识别] 视频AI分析完成: {video_description[:100]}...")
# ========== 第二步主AI 基于视频描述生成回复 ==========
# 构造包含视频描述的用户消息
user_question = title_text.strip() if title_text.strip() else "这个视频讲了什么?"
combined_message = f"[用户发送了一个视频,以下是视频内容描述]\n{video_description}\n\n[用户的问题]\n{user_question}"
# 添加到记忆让主AI知道用户发了视频
self._add_to_memory(chat_id, "user", combined_message)
# 如果是群聊,添加到历史记录
if is_group:
await self._add_to_history(from_wxid, nickname, f"[发送了一个视频] {user_question}")
# 调用主AI生成回复使用现有的 _call_ai_api 方法,继承完整上下文)
response = await self._call_ai_api(combined_message, chat_id, from_wxid, is_group, nickname)
if response:
await bot.send_text(from_wxid, response)
self._add_to_memory(chat_id, "assistant", response)
# 保存机器人回复到历史记录
if is_group:
import tomllib
with open("main_config.toml", "rb") as f:
main_config = tomllib.load(f)
bot_nickname = main_config.get("Bot", {}).get("nickname", "机器人")
await self._add_to_history(from_wxid, bot_nickname, response)
logger.success(f"[视频识别] 主AI回复成功: {response[:50]}...")
else:
await bot.send_text(from_wxid, "❌ AI 回复生成失败")
return False
except Exception as e:
logger.error(f"[视频识别] 处理视频失败: {e}")
import traceback
logger.error(traceback.format_exc())
await bot.send_text(from_wxid, "❌ 视频处理出错")
return False
    async def _analyze_video_content(self, video_base64: str, video_config: dict) -> str:
        """Video AI: analyze a video and produce an objective content description.

        Calls a Gemini-style generateContent endpoint with the video inlined
        as base64.

        Args:
            video_base64: Video data, optionally as a "data:video/mp4;base64,..." URL.
            video_config: The [video_recognition] config section.

        Returns:
            str: The description text, or "" on any failure (HTTP error,
            safety filter, timeout, parse failure).
        """
        try:
            api_url = video_config.get("api_url", "https://api.functen.cn/v1beta/models")
            api_key = video_config.get("api_key", self.config["api"]["api_key"])
            model = video_config.get("model", "gemini-3-pro-preview")
            full_url = f"{api_url}/{model}:generateContent"
            # Strip a data-URL prefix ("data:video/mp4;base64,") if present.
            if video_base64.startswith("data:"):
                video_base64 = video_base64.split(",", 1)[1]
                logger.debug("[视频AI] 已去除 base64 前缀")
            # Dedicated analysis prompt: asks for an objective description only.
            analyze_prompt = """请详细分析这个视频的内容,包括:
1. 视频的主要场景和环境
2. 出现的人物/物体及其动作
3. 视频中的文字、对话或声音(如果有)
4. 视频的整体主题或要表达的内容
5. 任何值得注意的细节
请用客观、详细的方式描述,不要加入主观评价。"""
            payload = {
                "contents": [
                    {
                        "parts": [
                            {"text": analyze_prompt},
                            {
                                "inline_data": {
                                    "mime_type": "video/mp4",
                                    "data": video_base64
                                }
                            }
                        ]
                    }
                ],
                "generationConfig": {
                    "maxOutputTokens": video_config.get("max_tokens", 8192)
                }
            }
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {api_key}"
            }
            timeout = aiohttp.ClientTimeout(total=video_config.get("timeout", 360))
            logger.info(f"[视频AI] 开始分析视频...")
            async with aiohttp.ClientSession(timeout=timeout) as session:
                async with session.post(full_url, json=payload, headers=headers) as resp:
                    if resp.status != 200:
                        error_text = await resp.text()
                        logger.error(f"[视频AI] API 错误: {resp.status}, {error_text[:300]}")
                        return ""
                    result = await resp.json()
                    logger.info(f"[视频AI] API 响应 keys: {list(result.keys())}")
                    # Prompt-level safety filtering.
                    if "promptFeedback" in result:
                        feedback = result["promptFeedback"]
                        if feedback.get("blockReason"):
                            logger.warning(f"[视频AI] 内容被过滤: {feedback.get('blockReason')}")
                            return ""
                    # Extract the first text part from the candidates.
                    if "candidates" in result and result["candidates"]:
                        for candidate in result["candidates"]:
                            # Response-level safety filtering.
                            if candidate.get("finishReason") == "SAFETY":
                                logger.warning("[视频AI] 响应被安全过滤")
                                return ""
                            content = candidate.get("content", {})
                            for part in content.get("parts", []):
                                if "text" in part:
                                    text = part["text"]
                                    logger.info(f"[视频AI] 分析完成,长度: {len(text)}")
                                    return text
                    # Log token usage to help diagnose empty responses.
                    if "usageMetadata" in result:
                        usage = result["usageMetadata"]
                        logger.warning(f"[视频AI] 无响应Token: prompt={usage.get('promptTokenCount', 0)}")
                    logger.error(f"[视频AI] 没有有效响应: {str(result)[:300]}")
                    return ""
        except asyncio.TimeoutError:
            logger.error(f"[视频AI] 请求超时")
            return ""
        except Exception as e:
            logger.error(f"[视频AI] 分析失败: {e}")
            import traceback
            logger.error(traceback.format_exc())
            return ""
    async def _download_and_encode_video(self, bot, cdnurl: str, aeskey: str) -> str:
        """Download a video via CDN and return it as a base64 data URL.

        Results are cached in Redis (when available) with a 600s TTL.

        Returns:
            str: "data:video/mp4;base64,..." on success, or "" on download
            failure / missing file / size limit exceeded.
        """
        try:
            # Try the Redis media cache first.
            from utils.redis_cache import RedisCache
            redis_cache = get_cache()
            if redis_cache and redis_cache.enabled:
                media_key = RedisCache.generate_media_key(cdnurl, aeskey)
                if media_key:
                    cached_data = redis_cache.get_cached_media(media_key, "video")
                    if cached_data:
                        logger.debug(f"[视频识别] 从缓存获取视频: {media_key[:20]}...")
                        return cached_data
            # Cache miss: download to a temp file under the plugin directory.
            logger.info(f"[视频识别] 开始下载视频...")
            temp_dir = Path(__file__).parent / "temp"
            temp_dir.mkdir(exist_ok=True)
            filename = f"video_{uuid.uuid4().hex[:8]}.mp4"
            save_path = str((temp_dir / filename).resolve())
            # file_type=4 selects "video" for the CDN download API.
            success = await bot.cdn_download(cdnurl, aeskey, save_path, file_type=4)
            if not success:
                logger.error("[视频识别] CDN 下载失败")
                return ""
            # Poll until the file exists and is non-empty (up to ~15s).
            import os
            for _ in range(30):  # 30 * 0.5s = max 15s wait
                if os.path.exists(save_path) and os.path.getsize(save_path) > 0:
                    break
                await asyncio.sleep(0.5)
            if not os.path.exists(save_path):
                logger.error("[视频识别] 视频文件未生成")
                return ""
            file_size = os.path.getsize(save_path)
            logger.info(f"[视频识别] 视频下载完成,大小: {file_size / 1024 / 1024:.2f} MB")
            # Enforce the configured size limit (default 20 MB).
            video_config = self.config.get("video_recognition", {})
            max_size_mb = video_config.get("max_size_mb", 20)
            if file_size > max_size_mb * 1024 * 1024:
                logger.warning(f"[视频识别] 视频文件过大: {file_size / 1024 / 1024:.2f} MB > {max_size_mb} MB")
                try:
                    Path(save_path).unlink()
                except:
                    pass
                return ""
            # Encode the file as a data URL.
            with open(save_path, "rb") as f:
                video_data = base64.b64encode(f.read()).decode()
            video_base64 = f"data:video/mp4;base64,{video_data}"
            # Store in the Redis cache for later reuse.
            # NOTE(review): media_key is only bound when Redis was enabled at the
            # top of this function; the short-circuit below relies on that to
            # avoid a NameError when Redis is disabled.
            if redis_cache and redis_cache.enabled and media_key:
                redis_cache.cache_media(media_key, video_base64, "video", ttl=600)
                logger.debug(f"[视频识别] 视频已缓存: {media_key[:20]}...")
            # Remove the temp file (best effort).
            try:
                Path(save_path).unlink()
            except:
                pass
            return video_base64
        except Exception as e:
            logger.error(f"[视频识别] 下载视频失败: {e}")
            import traceback
            logger.error(traceback.format_exc())
            return ""
async def _call_ai_api_with_video(self, user_message: str, video_base64: str, bot=None,
                                  from_wxid: str = None, chat_id: str = None,
                                  nickname: str = "", user_wxid: str = None,
                                  is_group: bool = False) -> str:
    """Call the Gemini-native generateContent API with an inline video, carrying full chat context.

    Builds a single prompt from: persona + current time + persistent memories +
    recent history (group history from Redis/file, private history from memory),
    then sends it together with the base64 video. Falls back to
    `_call_ai_api_with_video_simple` (no context) when the model returns no text.

    Args:
        user_message: The user's question about the video (may be empty).
        video_base64: Video as a data URL or raw base64 (prefix stripped below).
        bot: Bot instance (unused here; kept for signature parity with siblings).
        from_wxid: Chat id the message came from (group id for group chats).
        chat_id: Key for the in-memory private-chat history.
        nickname: Display name of the current user, appended to the persona.
        user_wxid: Sender wxid; used as the memory key in private chats.
        is_group: Whether this is a group chat.

    Returns:
        The model's text answer, a policy-refusal string, or "" on failure.
    """
    try:
        video_config = self.config.get("video_recognition", {})
        # Video recognition has its own model/endpoint/key config; the API key
        # falls back to the main chat API key.
        video_model = video_config.get("model", "gemini-3-pro-preview")
        api_url = video_config.get("api_url", "https://api.functen.cn/v1beta/models")
        api_key = video_config.get("api_key", self.config["api"]["api_key"])
        # Gemini-style endpoint: {base}/{model}:generateContent
        full_url = f"{api_url}/{video_model}:generateContent"
        # Build the system portion of the prompt (mirrors _call_ai_api).
        system_content = self.system_prompt
        current_time = datetime.now()
        weekday_map = {
            0: "星期一", 1: "星期二", 2: "星期三", 3: "星期四",
            4: "星期五", 5: "星期六", 6: "星期日"
        }
        weekday = weekday_map[current_time.weekday()]
        time_str = current_time.strftime(f"%Y年%m月%d日 %H:%M:%S {weekday}")
        system_content += f"\n\n当前时间:{time_str}"
        if nickname:
            system_content += f"\n当前对话用户的昵称是:{nickname}"
        # Persistent memories are keyed by group id in groups, by user wxid in private chats.
        memory_chat_id = from_wxid if is_group else user_wxid
        if memory_chat_id:
            persistent_memories = self._get_persistent_memories(memory_chat_id)
            if persistent_memories:
                system_content += "\n\n【持久记忆】以下是用户要求你记住的重要信息:\n"
                for m in persistent_memories:
                    # Keep only the date part (YYYY-MM-DD) of the timestamp.
                    mem_time = m['time'][:10] if m['time'] else ""
                    system_content += f"- [{mem_time}] {m['nickname']}: {m['content']}\n"
        # Build the conversation-history portion of the prompt.
        history_context = ""
        if is_group and from_wxid:
            # Group chat: load persisted history (Redis/file) and keep the tail.
            history = await self._load_history(from_wxid)
            max_context = self.config.get("history", {}).get("max_context", 50)
            recent_history = history[-max_context:] if len(history) > max_context else history
            if recent_history:
                history_context = "\n\n【最近的群聊记录】\n"
                for msg in recent_history:
                    msg_nickname = msg.get("nickname", "")
                    msg_content = msg.get("content", "")
                    if isinstance(msg_content, list):
                        # Multimodal entry: use its first text part, or a
                        # "[图片]" placeholder if there is none (for-else).
                        for item in msg_content:
                            if item.get("type") == "text":
                                msg_content = item.get("text", "")
                                break
                        else:
                            msg_content = "[图片]"
                    # Cap each history line to keep the prompt bounded.
                    if len(str(msg_content)) > 200:
                        msg_content = str(msg_content)[:200] + "..."
                    history_context += f"[{msg_nickname}] {msg_content}\n"
        else:
            # Private chat: use the in-memory conversation buffer.
            if chat_id:
                memory_messages = self._get_memory_messages(chat_id)
                if memory_messages:
                    history_context = "\n\n【最近的对话记录】\n"
                    for msg in memory_messages[-20:]:  # last 20 turns
                        role = msg.get("role", "")
                        content = msg.get("content", "")
                        if isinstance(content, list):
                            # Same text-or-placeholder extraction as above.
                            for item in content:
                                if item.get("type") == "text":
                                    content = item.get("text", "")
                                    break
                            else:
                                content = "[图片]"
                        role_name = "用户" if role == "user" else ""
                        if len(str(content)) > 200:
                            content = str(content)[:200] + "..."
                        history_context += f"[{role_name}] {content}\n"
        # Strip a "data:video/mp4;base64," prefix if present — the API wants raw base64.
        if video_base64.startswith("data:"):
            video_base64 = video_base64.split(",", 1)[1]
        # Full prompt = persona + history + the current question.
        full_prompt = system_content + history_context + f"\n\n【当前】用户发送了一个视频并问:{user_message or '请描述这个视频的内容'}"
        # Gemini-native request body: one user turn with a text part and an inline video part.
        payload = {
            "contents": [
                {
                    "parts": [
                        {"text": full_prompt},
                        {
                            "inline_data": {
                                "mime_type": "video/mp4",
                                "data": video_base64
                            }
                        }
                    ]
                }
            ],
            "generationConfig": {
                "maxOutputTokens": video_config.get("max_tokens", 8192)
            }
        }
        # NOTE(review): Bearer auth suggests this endpoint is a Gemini-compatible
        # proxy rather than Google's native API (which uses x-goog-api-key) — confirm.
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}"
        }
        timeout = aiohttp.ClientTimeout(total=video_config.get("timeout", 360))
        # Optional SOCKS/HTTP proxy (requires aiohttp_socks).
        connector = None
        proxy_config = self.config.get("proxy", {})
        if proxy_config.get("enabled", False) and PROXY_SUPPORT:
            proxy_type = proxy_config.get("type", "socks5").upper()
            proxy_host = proxy_config.get("host", "127.0.0.1")
            proxy_port = proxy_config.get("port", 7890)
            proxy_url = f"{proxy_type}://{proxy_host}:{proxy_port}"
            try:
                connector = ProxyConnector.from_url(proxy_url)
            except Exception as e:
                # Proxy misconfiguration degrades to a direct connection.
                logger.warning(f"[视频识别] 代理配置失败: {e}")
        logger.info(f"[视频识别] 调用 Gemini API: {full_url}")
        logger.debug(f"[视频识别] 提示词长度: {len(full_prompt)} 字符")
        async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
            async with session.post(full_url, json=payload, headers=headers) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    logger.error(f"[视频识别] API 错误: {resp.status}, {error_text[:500]}")
                    return ""
                # Parse the Gemini response envelope.
                result = await resp.json()
                # Verbose key logging to aid debugging of odd responses.
                logger.info(f"[视频识别] API 响应 keys: {list(result.keys()) if isinstance(result, dict) else type(result)}")
                # Top-level error object.
                if "error" in result:
                    logger.error(f"[视频识别] API 返回错误: {result['error']}")
                    return ""
                # promptFeedback carries safety-filter verdicts on the *request*.
                if "promptFeedback" in result:
                    feedback = result["promptFeedback"]
                    block_reason = feedback.get("blockReason", "")
                    if block_reason:
                        logger.warning(f"[视频识别] 请求被阻止,原因: {block_reason}")
                        logger.warning(f"[视频识别] 安全评级: {feedback.get('safetyRatings', [])}")
                        return "抱歉,视频内容无法分析(内容策略限制)。"
                # Concatenate all text parts across all candidates.
                full_content = ""
                if "candidates" in result and result["candidates"]:
                    logger.info(f"[视频识别] candidates 数量: {len(result['candidates'])}")
                    for i, candidate in enumerate(result["candidates"]):
                        # finishReason == SAFETY means the *response* was filtered.
                        finish_reason = candidate.get("finishReason", "")
                        if finish_reason:
                            logger.info(f"[视频识别] candidate[{i}] finishReason: {finish_reason}")
                            if finish_reason == "SAFETY":
                                logger.warning(f"[视频识别] 内容被安全过滤: {candidate.get('safetyRatings', [])}")
                                return "抱歉,视频内容无法分析。"
                        content = candidate.get("content", {})
                        parts = content.get("parts", [])
                        logger.info(f"[视频识别] candidate[{i}] parts 数量: {len(parts)}")
                        for part in parts:
                            if "text" in part:
                                full_content += part["text"]
                else:
                    # No candidates at all — log the raw payload and token usage
                    # (an oversized context can produce this).
                    logger.error(f"[视频识别] 响应中没有 candidates: {str(result)[:500]}")
                    if "usageMetadata" in result:
                        usage = result["usageMetadata"]
                        logger.warning(f"[视频识别] Token 使用: prompt={usage.get('promptTokenCount', 0)}, total={usage.get('totalTokenCount', 0)}")
                logger.info(f"[视频识别] AI 响应完成,长度: {len(full_content)}")
                # Empty output: retry once without context (video_base64 is
                # already raw base64 here, which the simple path expects).
                if not full_content:
                    logger.info("[视频识别] 尝试简化请求重试...")
                    return await self._call_ai_api_with_video_simple(
                        user_message or "请描述这个视频的内容",
                        video_base64,
                        video_config
                    )
                return full_content.strip()
    except Exception as e:
        logger.error(f"[视频识别] API 调用失败: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return ""
async def _call_ai_api_with_video_simple(self, user_message: str, video_base64: str, video_config: dict) -> str:
    """Context-free fallback for video recognition.

    Sends only the user's question plus the inline video to the Gemini-style
    endpoint (no persona, no history). Used as a degraded retry when the
    full-context request produced no text.

    Returns the first text part found in the response, or "" on any failure.
    """
    try:
        base_url = video_config.get("api_url", "https://api.functen.cn/v1beta/models")
        key = video_config.get("api_key", self.config["api"]["api_key"])
        model_name = video_config.get("model", "gemini-3-pro-preview")
        endpoint = f"{base_url}/{model_name}:generateContent"
        # Minimal request: a single user turn with text + inline video.
        video_part = {
            "inline_data": {
                "mime_type": "video/mp4",
                "data": video_base64
            }
        }
        body = {
            "contents": [{"parts": [{"text": user_message}, video_part]}],
            "generationConfig": {
                "maxOutputTokens": video_config.get("max_tokens", 8192)
            }
        }
        req_headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {key}"
        }
        client_timeout = aiohttp.ClientTimeout(total=video_config.get("timeout", 360))
        logger.info(f"[视频识别-简化] 调用 API: {endpoint}")
        async with aiohttp.ClientSession(timeout=client_timeout) as session:
            async with session.post(endpoint, json=body, headers=req_headers) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    logger.error(f"[视频识别-简化] API 错误: {resp.status}, {error_text[:300]}")
                    return ""
                result = await resp.json()
                logger.info(f"[视频识别-简化] API 响应 keys: {list(result.keys())}")
                # Walk candidates -> content -> parts; return the first text part.
                for candidate in result.get("candidates") or []:
                    for part in candidate.get("content", {}).get("parts", []):
                        if "text" in part:
                            text = part["text"]
                            logger.info(f"[视频识别-简化] 成功,长度: {len(text)}")
                            return text
                logger.error(f"[视频识别-简化] 仍然没有 candidates: {str(result)[:300]}")
                return ""
    except Exception as e:
        logger.error(f"[视频识别-简化] 失败: {e}")
        return ""
def _should_reply_quote(self, message: dict, title_text: str) -> bool:
"""判断是否应该回复引用消息"""
is_group = message.get("IsGroup", False)
# 检查群聊/私聊开关
if is_group and not self.config["behavior"]["reply_group"]:
return False
if not is_group and not self.config["behavior"]["reply_private"]:
return False
trigger_mode = self.config["behavior"]["trigger_mode"]
# all模式回复所有消息
if trigger_mode == "all":
return True
# mention模式检查是否@了机器人
if trigger_mode == "mention":
if is_group:
ats = message.get("Ats", [])
if not ats:
return False
import tomllib
with open("main_config.toml", "rb") as f:
main_config = tomllib.load(f)
bot_wxid = main_config.get("Bot", {}).get("wxid", "")
return bot_wxid and bot_wxid in ats
else:
return True
# keyword模式检查关键词
if trigger_mode == "keyword":
keywords = self.config["behavior"]["keywords"]
return any(kw in title_text for kw in keywords)
return False
async def _call_ai_api_with_image(self, user_message: str, image_base64: str, bot=None, from_wxid: str = None, chat_id: str = None, nickname: str = "", user_wxid: str = None, is_group: bool = False) -> str:
    """Call the OpenAI-compatible chat API with an image attachment (streaming).

    Builds a system prompt (persona + time + nickname), prepends the chat
    memory, and sends the user's text plus the image as a multimodal message.
    The SSE stream is consumed incrementally; if the model emits tool calls,
    any already-streamed text is flushed to the chat and tool execution is
    handed off to a background task (the method then returns "").

    Args:
        user_message: The user's text accompanying the image.
        image_base64: Image as a data URL accepted by the image_url field.
        bot: Bot instance used to flush partial text when tools are invoked.
        from_wxid: Chat id to send interim text to.
        chat_id: Key for the in-memory conversation history.
        nickname: Current user's display name, appended to the persona.
        user_wxid: Sender wxid, stashed for tool handlers.
        is_group: Whether this is a group chat, stashed for tool handlers.

    Returns:
        The full model reply, or "" when tool execution was delegated.

    Raises:
        Exception: Re-raised on HTTP/transport errors (caller handles fallback).
    """
    api_config = self.config["api"]
    tools = self._collect_tools()
    # System prompt: persona + current timestamp (+ nickname if known).
    system_content = self.system_prompt
    current_time = datetime.now()
    weekday_map = {
        0: "星期一", 1: "星期二", 2: "星期三", 3: "星期四",
        4: "星期五", 5: "星期六", 6: "星期日"
    }
    weekday = weekday_map[current_time.weekday()]
    time_str = current_time.strftime(f"%Y年%m月%d日 %H:%M:%S {weekday}")
    system_content += f"\n\n当前时间:{time_str}"
    if nickname:
        system_content += f"\n当前对话用户的昵称是:{nickname}"
    messages = [{"role": "system", "content": system_content}]
    # Prepend conversation memory, excluding its final entry (replaced by
    # the multimodal message appended below).
    if chat_id:
        memory_messages = self._get_memory_messages(chat_id)
        if memory_messages and len(memory_messages) > 1:
            messages.extend(memory_messages[:-1])
    # Current user turn: text part + image part (OpenAI multimodal format).
    messages.append({
        "role": "user",
        "content": [
            {"type": "text", "text": user_message},
            {"type": "image_url", "image_url": {"url": image_base64}}
        ]
    })
    # Stash caller identity for tool handlers invoked later.
    self._current_user_wxid = user_wxid
    self._current_is_group = is_group
    payload = {
        "model": api_config["model"],
        "messages": messages,
        "stream": True,
        "max_tokens": api_config.get("max_tokens", 4096)  # avoid truncated replies
    }
    if tools:
        payload["tools"] = tools
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_config['api_key']}"
    }
    timeout = aiohttp.ClientTimeout(total=api_config["timeout"])
    # Optional proxy (with credentials) via aiohttp_socks; degrades to direct.
    connector = None
    proxy_config = self.config.get("proxy", {})
    if proxy_config.get("enabled", False):
        proxy_type = proxy_config.get("type", "socks5").upper()
        proxy_host = proxy_config.get("host", "127.0.0.1")
        proxy_port = proxy_config.get("port", 7890)
        proxy_username = proxy_config.get("username")
        proxy_password = proxy_config.get("password")
        if proxy_username and proxy_password:
            proxy_url = f"{proxy_type}://{proxy_username}:{proxy_password}@{proxy_host}:{proxy_port}"
        else:
            proxy_url = f"{proxy_type}://{proxy_host}:{proxy_port}"
        if PROXY_SUPPORT:
            try:
                connector = ProxyConnector.from_url(proxy_url)
                logger.debug(f"使用代理: {proxy_type}://{proxy_host}:{proxy_port}")
            except Exception as e:
                logger.warning(f"代理配置失败,将直连: {e}")
                connector = None
        else:
            logger.warning("代理功能不可用aiohttp_socks 未安装),将直连")
            connector = None
    try:
        async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
            async with session.post(api_config["url"], json=payload, headers=headers) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    logger.error(f"API返回错误状态码: {resp.status}, 响应: {error_text}")
                    raise Exception(f"API错误 {resp.status}: {error_text}")
                # Consume the SSE stream line by line.
                import json
                full_content = ""
                tool_calls_dict = {}  # assemble streamed tool-call deltas keyed by index
                tool_call_hint_sent = False  # flush-once guard for partial text
                async for line in resp.content:
                    line = line.decode('utf-8').strip()
                    if not line or line == "data: [DONE]":
                        continue
                    if line.startswith("data: "):
                        try:
                            data = json.loads(line[6:])
                            choices = data.get("choices", [])
                            if not choices:
                                continue
                            delta = choices[0].get("delta", {})
                            content = delta.get("content", "")
                            if content:
                                full_content += content
                            # Tool calls arrive as incremental deltas that
                            # must be stitched together by index.
                            if delta.get("tool_calls"):
                                # On the first tool-call delta, flush any text
                                # already streamed so the user sees it promptly.
                                if not tool_call_hint_sent and bot and from_wxid:
                                    tool_call_hint_sent = True
                                    if full_content and full_content.strip():
                                        logger.info(f"[流式-图片] 检测到工具调用,先发送已有文本")
                                        await bot.send_text(from_wxid, full_content.strip())
                                    else:
                                        logger.info("[流式-图片] 检测到工具调用AI 未输出文本")
                                for tool_call_delta in delta["tool_calls"]:
                                    index = tool_call_delta.get("index", 0)
                                    # First delta for this index: create the skeleton.
                                    if index not in tool_calls_dict:
                                        tool_calls_dict[index] = {
                                            "id": "",
                                            "type": "function",
                                            "function": {
                                                "name": "",
                                                "arguments": ""
                                            }
                                        }
                                    # id/type arrive whole; name/arguments accumulate.
                                    if "id" in tool_call_delta:
                                        tool_calls_dict[index]["id"] = tool_call_delta["id"]
                                    if "type" in tool_call_delta:
                                        tool_calls_dict[index]["type"] = tool_call_delta["type"]
                                    if "function" in tool_call_delta:
                                        func_delta = tool_call_delta["function"]
                                        if "name" in func_delta:
                                            tool_calls_dict[index]["function"]["name"] += func_delta["name"]
                                        if "arguments" in func_delta:
                                            tool_calls_dict[index]["function"]["arguments"] += func_delta["arguments"]
                        except Exception as e:
                            # Malformed SSE chunks are skipped, not fatal.
                            logger.debug(f"解析流式数据失败: {e}")
                            pass
                # Order assembled tool calls by their stream index.
                tool_calls_data = [tool_calls_dict[i] for i in sorted(tool_calls_dict.keys())] if tool_calls_dict else []
                # Delegate tool execution to a background task; interim text
                # was already flushed above.
                if tool_calls_data:
                    logger.info(f"[图片] 启动异步工具执行,共 {len(tool_calls_data)} 个工具")
                    asyncio.create_task(
                        self._execute_tools_async_with_image(
                            tool_calls_data, bot, from_wxid, chat_id,
                            nickname, is_group, messages, image_base64
                        )
                    )
                    return ""
                # Guard against the model emitting pseudo-tool-call text.
                # NOTE(review): precedence here is A or (B and C) — matches
                # "either a <tool_code> tag, or print( combined with the
                # flow2 function name"; parenthesize if that reading is wrong.
                if "<tool_code>" in full_content or "print(" in full_content and "flow2_ai_image_generation" in full_content:
                    logger.warning("检测到模型输出了错误的工具调用格式,拦截并返回提示")
                    return "抱歉,我遇到了一些技术问题,请重新描述一下你的需求~"
                return full_content.strip()
    except Exception as e:
        logger.error(f"调用AI API失败: {e}")
        raise
async def _send_chat_records(self, bot, from_wxid: str, title: str, content: str):
    """Send *content* to *from_wxid* as a WeChat "chat history" (type-19) app message.

    Long content is split into chunks of at most ``max_length`` characters,
    preferring line boundaries; each chunk becomes one <dataitem> inside the
    forwarded-record XML. Failures are logged, never raised.

    Args:
        bot: Bot instance used to send the raw app message.
        from_wxid: Target chat id (group ids end with "@chatroom").
        title: Title/description shown on the record card.
        content: Full text to embed in the record items.
    """
    try:
        import time
        import hashlib
        is_group = from_wxid.endswith("@chatroom")
        # Split content into chunks no longer than max_length.
        max_length = 800
        content_parts = []
        if len(content) <= max_length:
            content_parts = [content]
        else:
            current_part = ""
            for line in content.split('\n'):
                # Hard-split single lines longer than max_length so no chunk
                # can exceed the limit (the previous splitter could emit
                # oversized chunks when a long line followed buffered text,
                # or when a line exceeded 2 * max_length).
                while len(line) > max_length:
                    if current_part.strip():
                        content_parts.append(current_part.strip())
                        current_part = ""
                    content_parts.append(line[:max_length])
                    line = line[max_length:]
                if len(current_part) + len(line) + 1 > max_length:
                    if current_part.strip():
                        content_parts.append(current_part.strip())
                    current_part = line + '\n'
                else:
                    current_part += line + '\n'
            if current_part.strip():
                content_parts.append(current_part.strip())
        # Build the <recordinfo> payload for WeChat's favorite-record template.
        # uuid and ET come from the module-level imports.
        recordinfo = ET.Element("recordinfo")
        info_el = ET.SubElement(recordinfo, "info")
        info_el.text = title
        is_group_el = ET.SubElement(recordinfo, "isChatRoom")
        is_group_el.text = "1" if is_group else "0"
        datalist = ET.SubElement(recordinfo, "datalist")
        datalist.set("count", str(len(content_parts)))
        desc_el = ET.SubElement(recordinfo, "desc")
        desc_el.text = title
        fromscene_el = ET.SubElement(recordinfo, "fromscene")
        fromscene_el.text = "3"
        for i, part in enumerate(content_parts):
            di = ET.SubElement(datalist, "dataitem")
            di.set("datatype", "1")
            di.set("dataid", uuid.uuid4().hex)
            # Fabricate plausible message ids and staggered timestamps so the
            # record renders as a sequence of messages.
            src_local_id = str((int(time.time() * 1000) % 90000) + 10000)
            new_msg_id = str(int(time.time() * 1000) + i)
            create_time = str(int(time.time()) - len(content_parts) + i)
            ET.SubElement(di, "srcMsgLocalid").text = src_local_id
            ET.SubElement(di, "sourcetime").text = time.strftime("%Y-%m-%d %H:%M", time.localtime(int(create_time)))
            ET.SubElement(di, "fromnewmsgid").text = new_msg_id
            ET.SubElement(di, "srcMsgCreateTime").text = create_time
            ET.SubElement(di, "sourcename").text = "AI助手"
            ET.SubElement(di, "sourceheadurl").text = ""
            ET.SubElement(di, "datatitle").text = part
            ET.SubElement(di, "datadesc").text = part
            ET.SubElement(di, "datafmt").text = "text"
            ET.SubElement(di, "ischatroom").text = "1" if is_group else "0"
            dataitemsource = ET.SubElement(di, "dataitemsource")
            ET.SubElement(dataitemsource, "hashusername").text = hashlib.sha256(from_wxid.encode("utf-8")).hexdigest()
        record_xml = ET.tostring(recordinfo, encoding="unicode")
        appmsg_parts = [
            "<appmsg appid=\"\" sdkver=\"0\">",
            f"<title>{title}</title>",
            f"<des>{title}</des>",
            "<type>19</type>",
            "<url>https://support.weixin.qq.com/cgi-bin/mmsupport-bin/readtemplate?t=page/favorite_record__w_unsupport</url>",
            "<appattach><cdnthumbaeskey></cdnthumbaeskey><aeskey></aeskey></appattach>",
            f"<recorditem><![CDATA[{record_xml}]]></recorditem>",
            "<percent>0</percent>",
            "</appmsg>"
        ]
        appmsg_xml = "".join(appmsg_parts)
        # NOTE(review): 11214 appears to be the raw app-message send opcode —
        # confirm against the bot API.
        await bot._send_data_async(11214, {"to_wxid": from_wxid, "content": appmsg_xml})
        logger.success(f"已发送聊天记录: {title}")
    except Exception as e:
        logger.error(f"发送聊天记录失败: {e}")
async def _process_image_to_history(self, bot, message: dict, content: str) -> bool:
    """Queue an image/sticker for AI description and drop a placeholder into
    the group history.

    Shared helper behind the image and emoji hooks. Only group chats are
    recorded, and only when the image_description feature is enabled.
    Always returns True so the message keeps propagating to other plugins.
    """
    from_wxid = message.get("FromWxid", "")
    sender_wxid = message.get("SenderWxid", "")
    is_group = message.get("IsGroup", False)
    user_wxid = sender_wxid if is_group else from_wxid
    # Private chats are not recorded here.
    if not is_group:
        return True
    # Feature toggle.
    image_desc_config = self.config.get("image_description", {})
    if not image_desc_config.get("enabled", True):
        return True
    try:
        # Locate the <img> (picture) or <emoji> (sticker) node in the payload.
        root = ET.fromstring(content)
        node = root.find(".//img")
        if node is None:
            node = root.find(".//emoji")
        if node is None:
            return True
        is_emoji = node.tag == "emoji"
        cdn_url = node.get("cdnbigimgurl", "") or node.get("cdnurl", "")
        aes_key = node.get("aeskey", "")
        if not cdn_url:
            return True
        # Pictures require an AES key to decrypt; stickers do not.
        if not (is_emoji or aes_key):
            return True
        # Resolve the sender's display name (cached lookup).
        nickname = await self._get_user_nickname(bot, from_wxid, user_wxid, is_group)
        # Insert a placeholder right away; a background worker replaces it
        # with the generated description, addressed by this id.
        placeholder_id = str(uuid.uuid4())
        await self._add_to_history_with_id(from_wxid, nickname, "[图片: 处理中...]", placeholder_id)
        logger.info(f"已插入图片占位符: {placeholder_id}")
        # Enqueue without blocking the message pipeline.
        await self.image_desc_queue.put({
            "bot": bot,
            "from_wxid": from_wxid,
            "nickname": nickname,
            "cdnbigimgurl": cdn_url,
            "aeskey": aes_key,
            "is_emoji": is_emoji,
            "placeholder_id": placeholder_id,
            "config": image_desc_config,
        })
        logger.info(f"图片描述任务已加入队列,当前队列长度: {self.image_desc_queue.qsize()}")
    except Exception as e:
        logger.error(f"处理图片消息失败: {e}")
    return True
async def _image_desc_worker(self):
    """Background worker: consume image-description tasks from the queue.

    Runs forever; each task is handled by
    _generate_and_update_image_description. task_done() is now called in a
    finally block so it fires exactly once per retrieved task — the previous
    version skipped it when processing raised, which would permanently wedge
    any image_desc_queue.join() caller. asyncio.CancelledError is not caught,
    so the worker can still be cancelled cleanly.
    """
    while True:
        task = await self.image_desc_queue.get()
        try:
            await self._generate_and_update_image_description(
                task["bot"], task["from_wxid"], task["nickname"],
                task["cdnbigimgurl"], task["aeskey"], task["is_emoji"],
                task["placeholder_id"], task["config"]
            )
        except Exception as e:
            logger.error(f"图片描述工作协程异常: {e}")
        finally:
            self.image_desc_queue.task_done()
async def _generate_and_update_image_description(self, bot, from_wxid: str, nickname: str,
                                                 cdnbigimgurl: str, aeskey: str, is_emoji: bool,
                                                 placeholder_id: str, image_desc_config: dict):
    """Worker body: fetch the image/sticker, ask the AI for a short
    description, then replace the history placeholder with the result.

    Any failure (download, description, or unexpected exception) downgrades
    the placeholder to a plain "[图片]" marker instead of leaving it stuck.
    """
    try:
        # Stickers download straight from their CDN URL; pictures go through
        # the encrypted CDN path and need the bot + AES key.
        if is_emoji:
            image_base64 = await self._download_emoji_and_encode(cdnbigimgurl)
        else:
            image_base64 = await self._download_and_encode_image(bot, cdnbigimgurl, aeskey)
        if not image_base64:
            logger.warning(f"{'表情包' if is_emoji else '图片'}下载失败")
            await self._update_history_by_id(from_wxid, placeholder_id, "[图片]")
            return
        # Ask the AI for a one-line description using the configured prompt.
        prompt = image_desc_config.get("prompt", "请用一句话简洁地描述这张图片的主要内容。")
        description = await self._generate_image_description(image_base64, prompt, image_desc_config)
        if not description:
            await self._update_history_by_id(from_wxid, placeholder_id, "[图片]")
            logger.warning(f"图片描述生成失败")
            return
        await self._update_history_by_id(from_wxid, placeholder_id, f"[图片: {description}]")
        logger.success(f"已更新图片描述: {nickname} - {description[:30]}...")
    except Exception as e:
        logger.error(f"异步生成图片描述失败: {e}")
        await self._update_history_by_id(from_wxid, placeholder_id, "[图片]")
@on_image_message(priority=15)
async def handle_image_message(self, bot, message: dict):
    """Image-message hook: record an AI-generated description in the group
    history. Never triggers a chat reply; always lets the message propagate."""
    logger.info("AIChat: handle_image_message 被调用")
    xml_payload = message.get("Content", "")
    return await self._process_image_to_history(bot, message, xml_payload)
@on_emoji_message(priority=15)
async def handle_emoji_message(self, bot, message: dict):
    """Sticker-message hook: record an AI-generated description in the group
    history. Never triggers a chat reply; always lets the message propagate."""
    logger.info("AIChat: handle_emoji_message 被调用")
    xml_payload = message.get("Content", "")
    return await self._process_image_to_history(bot, message, xml_payload)