1742 lines
74 KiB
Python
1742 lines
74 KiB
Python
"""
|
||
AI 聊天插件
|
||
|
||
支持自定义模型、API 和人设
|
||
"""
|
||
|
||
import asyncio
|
||
import tomllib
|
||
import aiohttp
|
||
from pathlib import Path
|
||
from datetime import datetime
|
||
from loguru import logger
|
||
from utils.plugin_base import PluginBase
|
||
from utils.decorators import on_text_message, on_quote_message, on_image_message, on_emoji_message
|
||
import xml.etree.ElementTree as ET
|
||
import base64
|
||
import uuid
|
||
|
||
# 可选导入代理支持
|
||
try:
|
||
from aiohttp_socks import ProxyConnector
|
||
PROXY_SUPPORT = True
|
||
except ImportError:
|
||
PROXY_SUPPORT = False
|
||
logger.warning("aiohttp_socks 未安装,代理功能将不可用")
|
||
|
||
|
||
class AIChat(PluginBase):
|
||
"""AI 聊天插件"""
|
||
|
||
# 插件元数据
|
||
description = "AI 聊天插件,支持自定义模型和人设"
|
||
author = "ShiHao"
|
||
version = "1.0.0"
|
||
|
||
def __init__(self):
    """Create empty plugin state; configuration is loaded later by ``async_init``."""
    super().__init__()
    # Filled in from config.toml / prompts/ by async_init.
    self.config = None
    self.system_prompt = ""
    # Private-chat memory: chat_id -> list of {"role", "content"} messages.
    self.memory = {}
    # Directory of per-group JSON history files; stays None when history is disabled.
    self.history_dir = None
    # chat_id -> asyncio.Lock serialising access to that chat's history file.
    self.history_locks = {}
    # Pending image-description jobs and the worker tasks that consume them.
    self.image_desc_queue = asyncio.Queue()
    self.image_desc_workers = []
|
||
|
||
async def async_init(self):
    """Load configuration and persona, prepare history storage and start workers.

    Steps:
      * read ``config.toml`` next to this file into ``self.config``;
      * load the persona named by ``[prompt].system_prompt_file`` from
        ``prompts/`` (falling back to a built-in default persona);
      * log the proxy settings when ``[proxy].enabled`` is set;
      * when ``[history].enabled`` (default True), create the history
        directory ``[history].history_dir`` (default ``"history"``);
      * start the image-description worker tasks.
    """
    plugin_dir = Path(__file__).parent

    with open(plugin_dir / "config.toml", "rb") as f:
        self.config = tomllib.load(f)

    # Persona
    prompt_file = self.config["prompt"]["system_prompt_file"]
    prompt_path = plugin_dir / "prompts" / prompt_file
    if prompt_path.exists():
        with open(prompt_path, "r", encoding="utf-8") as f:
            self.system_prompt = f.read().strip()
        logger.success(f"已加载人设: {prompt_file}")
    else:
        logger.warning(f"人设文件不存在: {prompt_file},使用默认人设")
        self.system_prompt = "你是一个友好的 AI 助手。"

    # Proxy (only logged here; connectors are built per request)
    proxy_config = self.config.get("proxy", {})
    if proxy_config.get("enabled", False):
        proxy_type = proxy_config.get("type", "socks5")
        proxy_host = proxy_config.get("host", "127.0.0.1")
        proxy_port = proxy_config.get("port", 7890)
        logger.info(f"AI 聊天插件已启用代理: {proxy_type}://{proxy_host}:{proxy_port}")

    # History directory
    history_config = self.config.get("history", {})
    if history_config.get("enabled", True):
        self.history_dir = plugin_dir / history_config.get("history_dir", "history")
        # parents=True so a nested setting such as "data/history" also works.
        self.history_dir.mkdir(parents=True, exist_ok=True)
        logger.info(f"历史记录目录: {self.history_dir}")

    # Image-description workers
    worker_count = 2
    for _ in range(worker_count):
        self.image_desc_workers.append(asyncio.create_task(self._image_desc_worker()))
    logger.info(f"已启动 {worker_count} 个图片描述工作协程")

    logger.info(f"AI 聊天插件已加载,模型: {self.config['api']['model']}")
|
||
|
||
def _get_chat_id(self, from_wxid: str, sender_wxid: str = None, is_group: bool = False) -> str:
|
||
"""获取会话ID"""
|
||
if is_group:
|
||
# 群聊使用 "群ID:用户ID" 组合,确保每个用户有独立的对话记忆
|
||
user_wxid = sender_wxid or from_wxid
|
||
return f"{from_wxid}:{user_wxid}"
|
||
else:
|
||
return sender_wxid or from_wxid # 私聊使用用户ID
|
||
|
||
def _add_to_memory(self, chat_id: str, role: str, content, image_base64: str = None):
|
||
"""
|
||
添加消息到记忆
|
||
|
||
Args:
|
||
chat_id: 会话ID
|
||
role: 角色 (user/assistant)
|
||
content: 消息内容(可以是字符串或列表)
|
||
image_base64: 可选的图片base64数据
|
||
"""
|
||
if not self.config.get("memory", {}).get("enabled", False):
|
||
return
|
||
|
||
if chat_id not in self.memory:
|
||
self.memory[chat_id] = []
|
||
|
||
# 如果有图片,构建多模态内容
|
||
if image_base64:
|
||
message_content = [
|
||
{"type": "text", "text": content if isinstance(content, str) else ""},
|
||
{"type": "image_url", "image_url": {"url": image_base64}}
|
||
]
|
||
else:
|
||
message_content = content
|
||
|
||
self.memory[chat_id].append({"role": role, "content": message_content})
|
||
|
||
# 限制记忆长度
|
||
max_messages = self.config["memory"]["max_messages"]
|
||
if len(self.memory[chat_id]) > max_messages:
|
||
self.memory[chat_id] = self.memory[chat_id][-max_messages:]
|
||
|
||
def _get_memory_messages(self, chat_id: str) -> list:
|
||
"""获取记忆中的消息"""
|
||
if not self.config.get("memory", {}).get("enabled", False):
|
||
return []
|
||
return self.memory.get(chat_id, [])
|
||
|
||
def _clear_memory(self, chat_id: str):
|
||
"""清空指定会话的记忆"""
|
||
if chat_id in self.memory:
|
||
del self.memory[chat_id]
|
||
|
||
async def _download_and_encode_image(self, bot, cdnurl: str, aeskey: str) -> str:
    """Download a CDN image through the bot and return it as a base64 ``data:`` URL.

    ``bot.cdn_download`` is tried with ``file_type=2`` and, if that fails,
    with ``file_type=1``. The file is then polled for up to ~10 s (20 x 0.5 s)
    until it exists and is non-empty. The temporary file under ``temp/`` is
    removed on every exit path, including failures.

    Returns:
        ``"data:image/jpeg;base64,..."`` on success; ``""`` when the download
        fails, the file never appears or stays empty, or any error occurs
        (logged).
    """
    temp_dir = Path(__file__).parent / "temp"
    temp_file = temp_dir / f"temp_{uuid.uuid4().hex[:8]}.jpg"
    try:
        temp_dir.mkdir(exist_ok=True)
        save_path = str(temp_file.resolve())

        success = await bot.cdn_download(cdnurl, aeskey, save_path, file_type=2)
        if not success:
            success = await bot.cdn_download(cdnurl, aeskey, save_path, file_type=1)
        if not success:
            return ""

        def has_data() -> bool:
            return temp_file.is_file() and temp_file.stat().st_size > 0

        # The download may still be flushing to disk; wait up to 10 s.
        for _ in range(20):
            if has_data():
                break
            await asyncio.sleep(0.5)

        if not has_data():
            # Missing or still empty: encoding it would yield an unusable empty data URL.
            return ""

        image_data = base64.b64encode(temp_file.read_bytes()).decode()
        return f"data:image/jpeg;base64,{image_data}"
    except Exception as e:
        logger.error(f"下载图片失败: {e}")
        return ""
    finally:
        # Best-effort cleanup; a failure here must not mask the result.
        try:
            temp_file.unlink(missing_ok=True)
        except OSError as e:
            logger.debug(f"删除临时图片失败: {e}")
|
||
|
||
async def _download_emoji_and_encode(self, cdn_url: str) -> str:
    """Download an emoji image over HTTP(S) and return it as a base64 ``data:`` URL.

    Args:
        cdn_url: Emoji URL; ``&amp;`` entities (as they appear in message
            XML) are turned back into ``&`` before the request.

    Returns:
        ``data:<mime>;base64,...`` where ``<mime>`` is the response's
        ``image/*`` Content-Type, falling back to ``image/gif``; ``""`` on a
        non-200 status, an empty body, or any error (logged).

    The request uses a 30 s total timeout and the ``[proxy]`` settings when
    they are enabled and ``aiohttp_socks`` is installed.
    """
    from urllib.parse import quote

    try:
        # Undo HTML/XML escaping of query-string separators.
        cdn_url = cdn_url.replace("&amp;", "&")

        connector = None
        proxy_config = self.config.get("proxy", {})
        if proxy_config.get("enabled", False) and PROXY_SUPPORT:
            proxy_type = proxy_config.get("type", "socks5").upper()
            proxy_host = proxy_config.get("host", "127.0.0.1")
            proxy_port = proxy_config.get("port", 7890)
            proxy_username = proxy_config.get("username")
            proxy_password = proxy_config.get("password")
            if proxy_username and proxy_password:
                # Percent-encode credentials so "@", ":" etc. cannot corrupt the URL.
                credentials = f"{quote(str(proxy_username), safe='')}:{quote(str(proxy_password), safe='')}@"
            else:
                credentials = ""
            try:
                connector = ProxyConnector.from_url(f"{proxy_type}://{credentials}{proxy_host}:{proxy_port}")
            except Exception as e:
                logger.warning(f"代理配置失败,将直连: {e}")
                connector = None

        timeout = aiohttp.ClientTimeout(total=30)
        async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
            async with session.get(cdn_url) as response:
                if response.status != 200:
                    logger.warning(f"下载表情包失败: HTTP {response.status}")
                    return ""
                content = await response.read()
                content_type = response.content_type or ""

        if not content:
            return ""
        mime = content_type if content_type.startswith("image/") else "image/gif"
        return f"data:{mime};base64,{base64.b64encode(content).decode()}"
    except Exception as e:
        logger.error(f"下载表情包失败: {e}")
        return ""
|
||
|
||
async def _generate_image_description(self, image_base64: str, prompt: str, config: dict) -> str:
    """Ask the chat model to describe an image.

    Args:
        image_base64: Image URL (typically a ``data:`` URL) sent as an
            ``image_url`` content part.
        prompt: Instruction text sent alongside the image.
        config: Image-description settings; ``model`` (default
            ``[api].model``) and ``max_tokens`` (default 1000) are read.

    Returns:
        The streamed description stripped of surrounding whitespace, or ``""``
        on an HTTP error or any exception (both logged).
    """
    import json
    import traceback
    from urllib.parse import quote

    try:
        api_config = self.config["api"]
        payload = {
            "model": config.get("model", api_config["model"]),
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        {"type": "image_url", "image_url": {"url": image_base64}},
                    ],
                }
            ],
            "max_tokens": config.get("max_tokens", 1000),
            "stream": True,
        }
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_config['api_key']}",
        }
        timeout = aiohttp.ClientTimeout(total=api_config["timeout"])

        connector = None
        proxy_config = self.config.get("proxy", {})
        if proxy_config.get("enabled", False) and PROXY_SUPPORT:
            proxy_type = proxy_config.get("type", "socks5").upper()
            proxy_host = proxy_config.get("host", "127.0.0.1")
            proxy_port = proxy_config.get("port", 7890)
            proxy_username = proxy_config.get("username")
            proxy_password = proxy_config.get("password")
            if proxy_username and proxy_password:
                # Percent-encode credentials so "@", ":" etc. cannot corrupt the URL.
                credentials = f"{quote(str(proxy_username), safe='')}:{quote(str(proxy_password), safe='')}@"
            else:
                credentials = ""
            try:
                connector = ProxyConnector.from_url(f"{proxy_type}://{credentials}{proxy_host}:{proxy_port}")
            except Exception as e:
                logger.warning(f"代理配置失败,将直连: {e}")
                connector = None

        async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
            async with session.post(api_config["url"], json=payload, headers=headers) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    logger.error(f"图片描述 API 返回错误: {resp.status}, {error_text[:200]}")
                    return ""

                # Server-sent events: one "data: {json}" line per chunk.
                chunks = []
                async for raw_line in resp.content:
                    line = raw_line.decode("utf-8", errors="replace").strip()
                    if not line.startswith("data: ") or line == "data: [DONE]":
                        continue
                    try:
                        data = json.loads(line[6:])
                    except json.JSONDecodeError:
                        continue
                    # Chunks may carry no choices (e.g. usage-only chunks); skip them.
                    choices = data.get("choices") if isinstance(data, dict) else None
                    if not isinstance(choices, list) or not choices or not isinstance(choices[0], dict):
                        continue
                    delta = choices[0].get("delta")
                    content = delta.get("content") if isinstance(delta, dict) else None
                    if isinstance(content, str) and content:
                        chunks.append(content)

        description = "".join(chunks)
        logger.debug(f"图片描述生成成功: {description}")
        return description.strip()
    except Exception as e:
        logger.error(f"生成图片描述失败: {e}")
        logger.error(f"详细错误: {traceback.format_exc()}")
        return ""
|
||
|
||
def _collect_tools(self):
    """Gather LLM tool (function-calling) schemas from all loaded plugins.

    Every plugin exposing a ``get_llm_tools()`` callable is asked for its
    tools; falsy results are skipped. A plugin whose ``get_llm_tools`` raises
    is logged and skipped, so one faulty plugin cannot break every AI request.

    Returns:
        A flat list of tool definitions in plugin iteration order.
    """
    from utils.plugin_manager import PluginManager

    tools = []
    for plugin in PluginManager().plugins.values():
        get_tools = getattr(plugin, "get_llm_tools", None)
        if get_tools is None:
            continue
        try:
            plugin_tools = get_tools()
        except Exception as e:
            logger.warning(f"获取插件工具失败 {type(plugin).__name__}: {e}")
            continue
        if plugin_tools:
            tools.extend(plugin_tools)
    return tools
|
||
|
||
async def _handle_list_prompts(self, bot, from_wxid: str):
    """Send the numbered list of ``prompts/*.txt`` persona files to ``from_wxid``.

    The persona currently configured in ``[prompt].system_prompt_file`` is
    marked with ✅. A missing directory, an empty directory, or any error is
    reported back to the chat instead.
    """
    try:
        prompts_dir = Path(__file__).parent / "prompts"
        if not prompts_dir.exists():
            await bot.send_text(from_wxid, "❌ prompts 目录不存在")
            return

        names = [path.name for path in sorted(prompts_dir.glob("*.txt"))]
        if not names:
            await bot.send_text(from_wxid, "❌ 没有找到任何人设文件")
            return

        active = self.config["prompt"]["system_prompt_file"]
        entries = [
            f"{idx}. {name} ✅\n" if name == active else f"{idx}. {name}\n"
            for idx, name in enumerate(names, start=1)
        ]
        msg = "📋 可用人设列表:\n\n" + "".join(entries) + "\n💡 使用方法:/切人设 文件名.txt"

        await bot.send_text(from_wxid, msg)
        logger.info("已发送人设列表")
    except Exception as e:
        logger.error(f"获取人设列表失败: {e}")
        await bot.send_text(from_wxid, f"❌ 获取人设列表失败: {str(e)}")
|
||
|
||
async def _handle_switch_prompt(self, bot, from_wxid: str, content: str):
|
||
"""处理切换人设指令"""
|
||
try:
|
||
# 提取文件名
|
||
parts = content.split(maxsplit=1)
|
||
if len(parts) < 2:
|
||
await bot.send_text(from_wxid, "❌ 请指定人设文件名\n格式:/切人设 文件名.txt")
|
||
return
|
||
|
||
filename = parts[1].strip()
|
||
|
||
# 检查文件是否存在
|
||
prompt_path = Path(__file__).parent / "prompts" / filename
|
||
if not prompt_path.exists():
|
||
await bot.send_text(from_wxid, f"❌ 人设文件不存在: {filename}")
|
||
return
|
||
|
||
# 读取新人设
|
||
with open(prompt_path, "r", encoding="utf-8") as f:
|
||
new_prompt = f.read().strip()
|
||
|
||
# 更新人设
|
||
self.system_prompt = new_prompt
|
||
self.config["prompt"]["system_prompt_file"] = filename
|
||
|
||
await bot.send_text(from_wxid, f"✅ 已切换人设: {filename}")
|
||
logger.success(f"管理员切换人设: {filename}")
|
||
|
||
except Exception as e:
|
||
logger.error(f"切换人设失败: {e}")
|
||
await bot.send_text(from_wxid, f"❌ 切换人设失败: {str(e)}")
|
||
|
||
@on_text_message(priority=80)
async def handle_message(self, bot, message: dict):
    """Handle an incoming text message: commands, group history and AI replies.

    Commands (each returns ``False`` after being handled):
      * ``/人设列表`` - list persona files;
      * ``/切人设 <file>`` / ``/切换人设 <file>`` - switch persona (admins only);
      * ``[memory].clear_command`` (default ``/清空记忆``) - clear this chat's memory;
      * ``/记忆状态`` - show memory / history size (admins only).

    Every other group message is appended to the group's JSON history. When
    ``_should_reply`` allows it, the text (minus a leading ``@mention``) is
    sent to the AI; a non-empty reply is sent back, stored in memory and,
    in groups, appended to the history under the bot's nickname.

    Returns:
        ``False`` after handling a command, otherwise ``None``.
        NOTE(review): presumably ``False`` stops further plugin dispatch -
        confirm against utils.decorators.
    """
    content = message.get("Content", "").strip()
    from_wxid = message.get("FromWxid", "")
    sender_wxid = message.get("SenderWxid", "")
    is_group = message.get("IsGroup", False)

    # Author of the message; in groups FromWxid identifies the group itself.
    user_wxid = sender_wxid if is_group else from_wxid

    # Bot identity / admins; loaded once and reused for the reply below.
    with open("main_config.toml", "rb") as f:
        main_config = tomllib.load(f)
    bot_config = main_config.get("Bot", {})
    bot_wxid = bot_config.get("wxid", "")
    admins = bot_config.get("admins", [])

    if content == "/人设列表":
        await self._handle_list_prompts(bot, from_wxid)
        return False

    if content.startswith(("/切人设 ", "/切换人设 ")):
        if user_wxid in admins:
            await self._handle_switch_prompt(bot, from_wxid, content)
        else:
            await bot.send_text(from_wxid, "❌ 仅管理员可以切换人设")
        return False

    clear_command = self.config.get("memory", {}).get("clear_command", "/清空记忆")
    if content == clear_command:
        self._clear_memory(self._get_chat_id(from_wxid, user_wxid, is_group))
        await bot.send_text(from_wxid, "✅ 已清空当前会话的记忆")
        return False

    if content == "/记忆状态":
        if user_wxid not in admins:
            await bot.send_text(from_wxid, "❌ 仅管理员可以查看记忆状态")
        elif is_group:
            history = await self._load_history(from_wxid)
            max_context = self.config.get("history", {}).get("max_context", 50)
            context_count = min(len(history), max_context)
            msg = f"📊 群聊记忆: {len(history)} 条\n"
            msg += f"💬 AI可见: 最近 {context_count} 条"
            await bot.send_text(from_wxid, msg)
        else:
            memory = self._get_memory_messages(self._get_chat_id(from_wxid, user_wxid, is_group))
            await bot.send_text(from_wxid, f"📊 私聊记忆: {len(memory)} 条")
        return False

    should_reply = self._should_reply(message, content, bot_wxid)
    nickname = await self._resolve_nickname(bot, from_wxid, user_wxid, sender_wxid, is_group)

    # Every group message is recorded, whether or not the bot answers it.
    if is_group:
        await self._add_to_history(from_wxid, nickname, content)

    if not should_reply:
        return

    actual_content = self._extract_content(message, content)
    if not actual_content:
        return

    logger.info(f"AI 处理消息: {actual_content[:50]}...")

    try:
        chat_id = self._get_chat_id(from_wxid, user_wxid, is_group)
        self._add_to_memory(chat_id, "user", actual_content)

        response = await self._call_ai_api(actual_content, bot, from_wxid, chat_id, nickname, user_wxid, is_group)

        # An empty response means the answer was already delivered another way
        # (e.g. as a chat record) or a tool asked for no reply.
        if response:
            await bot.send_text(from_wxid, response)
            self._add_to_memory(chat_id, "assistant", response)
            if is_group:
                bot_nickname = bot_config.get("nickname", "机器人")
                await self._add_to_history(from_wxid, bot_nickname, response)
            logger.success(f"AI 回复成功: {response[:50]}...")
        else:
            logger.info("AI 回复为空或已通过其他方式发送(如聊天记录)")
    except Exception as e:
        import traceback

        error_detail = traceback.format_exc()
        logger.error(f"AI 处理失败: {type(e).__name__}: {str(e)}")
        logger.error(f"详细错误:\n{error_detail}")
        await bot.send_text(from_wxid, "抱歉,我遇到了一些问题,请稍后再试。")


async def _resolve_nickname(self, bot, from_wxid: str, user_wxid: str, sender_wxid: str, is_group: bool) -> str:
    """Best-effort display name of ``user_wxid`` for history and prompts.

    Tries, in order: the group member info from the bot (group chats only),
    the most recent non-empty nickname recorded by the MessageLogger plugin's
    database, and finally the wxid itself (or ``"未知用户"``). Lookup failures
    are logged at debug level and never propagate.
    """
    nickname = ""
    if is_group:
        try:
            user_info = await bot.get_user_info_in_chatroom(from_wxid, user_wxid)
            if user_info and user_info.get("nickName", {}).get("string"):
                nickname = user_info["nickName"]["string"]
        except Exception as e:
            logger.debug(f"获取群成员昵称失败: {e}")

    if not nickname:
        try:
            # Imported lazily and inside the try: a missing or broken
            # MessageLogger plugin must not abort message handling.
            from plugins.MessageLogger.main import MessageLogger

            msg_logger = MessageLogger.get_instance()
            if msg_logger:
                with msg_logger.get_db_connection() as conn:
                    with conn.cursor() as cursor:
                        cursor.execute(
                            "SELECT nickname FROM messages WHERE sender_wxid = %s AND nickname != '' ORDER BY create_time DESC LIMIT 1",
                            (user_wxid,)
                        )
                        result = cursor.fetchone()
                        if result:
                            nickname = result[0]
        except Exception as e:
            logger.debug(f"从数据库查询昵称失败: {e}")

    return nickname or user_wxid or sender_wxid or "未知用户"
|
||
|
||
def _should_reply(self, message: dict, content: str, bot_wxid: str = None) -> bool:
|
||
"""判断是否应该回复"""
|
||
# 检查是否由AutoReply插件触发
|
||
if message.get('_auto_reply_triggered'):
|
||
return True
|
||
|
||
is_group = message.get("IsGroup", False)
|
||
|
||
# 检查群聊/私聊开关
|
||
if is_group and not self.config["behavior"].get("reply_group", True):
|
||
return False
|
||
if not is_group and not self.config["behavior"].get("reply_private", True):
|
||
return False
|
||
|
||
trigger_mode = self.config["behavior"].get("trigger_mode", "mention")
|
||
|
||
# all 模式:回复所有消息
|
||
if trigger_mode == "all":
|
||
return True
|
||
|
||
# mention 模式:检查是否@了机器人
|
||
if trigger_mode == "mention":
|
||
if is_group:
|
||
ats = message.get("Ats", [])
|
||
# 检查是否@了机器人
|
||
if not ats:
|
||
return False
|
||
|
||
# 如果没有 bot_wxid,从配置文件读取
|
||
if not bot_wxid:
|
||
import tomllib
|
||
with open("main_config.toml", "rb") as f:
|
||
main_config = tomllib.load(f)
|
||
bot_wxid = main_config.get("Bot", {}).get("wxid", "")
|
||
|
||
# 检查 @ 列表中是否包含机器人的 wxid
|
||
if bot_wxid and bot_wxid in ats:
|
||
return True
|
||
|
||
return False
|
||
else:
|
||
# 私聊直接回复
|
||
return True
|
||
|
||
# keyword 模式:检查关键词
|
||
if trigger_mode == "keyword":
|
||
keywords = self.config["behavior"]["keywords"]
|
||
return any(kw in content for kw in keywords)
|
||
|
||
return False
|
||
|
||
def _extract_content(self, message: dict, content: str) -> str:
|
||
"""提取实际消息内容(去除@等)"""
|
||
is_group = message.get("IsGroup", False)
|
||
|
||
if is_group:
|
||
# 群聊消息,去除@部分
|
||
# 格式通常是 "@昵称 消息内容"
|
||
parts = content.split(maxsplit=1)
|
||
if len(parts) > 1 and parts[0].startswith("@"):
|
||
return parts[1].strip()
|
||
return content.strip()
|
||
|
||
return content.strip()
|
||
|
||
async def _call_ai_api(self, user_message: str, bot=None, from_wxid: str = None, chat_id: str = None, nickname: str = "", user_wxid: str = None, is_group: bool = False) -> str:
    """Send the conversation to the chat-completions API and return the reply text.

    The request contains:
      * a system message: the persona, the current local time and, if given,
        the caller's nickname;
      * context: for groups, the newest ``[history].max_context`` (default 50)
        records of the group history, each prefixed with ``[nickname]``; for
        private chats, this chat's memory minus its last entry (the message
        being answered, already added by the caller);
      * the current user message (prefixed with ``[nickname]`` in groups);
      * every plugin tool from ``_collect_tools``.

    The answer is streamed. If the model requests tool calls, each tool runs
    via ``_execute_tool_and_get_result`` and a second streamed request
    produces the final answer. A tool result with ``no_reply`` suppresses the
    answer; one with ``send_as_chat_record`` sends it via
    ``_send_chat_records`` instead of returning it.

    Returns:
        The reply text, or ``""`` when nothing should be sent as plain text.

    Raises:
        Exception: on a non-200 first response, network errors, timeouts, or
        missing fields (chained to the original exception).
    """
    import json
    from urllib.parse import quote

    api_config = self.config["api"]

    tools = self._collect_tools()
    logger.info(f"收集到 {len(tools)} 个工具函数")
    if tools:
        tool_names = [t["function"]["name"] for t in tools]
        logger.info(f"工具列表: {tool_names}")

    # ---- system prompt ---------------------------------------------------
    current_time = datetime.now()
    weekday = ("星期一", "星期二", "星期三", "星期四", "星期五", "星期六", "星期日")[current_time.weekday()]
    time_str = current_time.strftime(f"%Y年%m月%d日 %H:%M:%S {weekday}")
    system_content = self.system_prompt + f"\n\n当前时间:{time_str}"
    if nickname:
        system_content += f"\n当前对话用户的昵称是:{nickname}"
    messages = [{"role": "system", "content": system_content}]

    # ---- context -----------------------------------------------------------
    if is_group and from_wxid:
        # NOTE(review): handle_message records the incoming message in the group
        # history before calling this, so the newest record likely duplicates
        # the user message appended below - confirm and dedupe if so.
        history = await self._load_history(from_wxid)
        max_context = self.config.get("history", {}).get("max_context", 50)
        recent_history = history[-max_context:] if len(history) > max_context else history
        for record in recent_history:
            msg_nickname = record.get("nickname", "")
            msg_content = record.get("content", "")
            if isinstance(msg_content, list):
                # Multimodal record: prefix the nickname to text parts, keep images as-is.
                msg_content = [
                    {"type": "text", "text": f"[{msg_nickname}] {item.get('text', '')}"}
                    if item.get("type") == "text" else item
                    for item in msg_content
                ]
            else:
                msg_content = f"[{msg_nickname}] {msg_content}"
            messages.append({"role": "user", "content": msg_content})
    elif chat_id:
        memory_messages = self._get_memory_messages(chat_id)
        if len(memory_messages) > 1:
            messages.extend(memory_messages[:-1])

    current = f"[{nickname}] {user_message}" if is_group and nickname else user_message
    messages.append({"role": "user", "content": current})

    # Read back by _execute_tool_and_get_result. NOTE(review): this is shared
    # instance state, so overlapping requests can overwrite each other's values.
    self._current_user_wxid = user_wxid
    self._current_is_group = is_group

    payload = {"model": api_config["model"], "messages": messages}
    if tools:
        payload["tools"] = tools
        logger.debug(f"已将 {len(tools)} 个工具添加到请求中")

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_config['api_key']}",
    }
    timeout = aiohttp.ClientTimeout(total=api_config["timeout"])

    # ---- proxy -------------------------------------------------------------
    connector = None
    proxy_config = self.config.get("proxy", {})
    if proxy_config.get("enabled", False):
        proxy_type = proxy_config.get("type", "socks5").upper()
        proxy_host = proxy_config.get("host", "127.0.0.1")
        proxy_port = proxy_config.get("port", 7890)
        proxy_username = proxy_config.get("username")
        proxy_password = proxy_config.get("password")
        if proxy_username and proxy_password:
            # Percent-encode credentials so "@", ":" etc. cannot corrupt the URL.
            credentials = f"{quote(str(proxy_username), safe='')}:{quote(str(proxy_password), safe='')}@"
        else:
            credentials = ""
        proxy_url = f"{proxy_type}://{credentials}{proxy_host}:{proxy_port}"

        if PROXY_SUPPORT:
            try:
                connector = ProxyConnector.from_url(proxy_url)
                logger.debug(f"使用代理: {proxy_type}://{proxy_host}:{proxy_port}")
            except Exception as e:
                logger.warning(f"代理配置失败,将直连: {e}")
                connector = None
        else:
            logger.warning("代理功能不可用(aiohttp_socks 未安装),将直连")

    payload["stream"] = True

    def first_delta(data) -> dict:
        """Return ``choices[0].delta`` of a stream chunk, or {} when absent/malformed."""
        choices = data.get("choices") if isinstance(data, dict) else None
        if not isinstance(choices, list) or not choices or not isinstance(choices[0], dict):
            return {}
        delta = choices[0].get("delta")
        return delta if isinstance(delta, dict) else {}

    async def iter_stream(resp):
        """Yield the parsed JSON of every ``data:`` line of an SSE response."""
        async for raw_line in resp.content:
            line = raw_line.decode("utf-8", errors="replace").strip()
            if not line.startswith("data: ") or line == "data: [DONE]":
                continue
            try:
                data = json.loads(line[6:])
            except json.JSONDecodeError as e:
                logger.debug(f"解析流式数据失败: {e}")
                continue
            yield data

    try:
        async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
            logger.debug(f"发送流式 API 请求: {api_config['url']}")
            async with session.post(api_config["url"], json=payload, headers=headers) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    logger.error(f"API 返回错误状态码: {resp.status}, 响应: {error_text}")
                    raise Exception(f"API 错误 {resp.status}: {error_text}")

                full_content = ""
                tool_calls_dict = {}  # index -> tool call assembled from streamed deltas
                async for data in iter_stream(resp):
                    try:
                        delta = first_delta(data)
                        content = delta.get("content")
                        if content:
                            full_content += content
                        for tc_delta in delta.get("tool_calls") or []:
                            index = tc_delta.get("index") or 0
                            call = tool_calls_dict.setdefault(index, {
                                "id": "",
                                "type": "function",
                                "function": {"name": "", "arguments": ""},
                            })
                            # Providers may send explicit nulls for fields they do not update.
                            if tc_delta.get("id"):
                                call["id"] = tc_delta["id"]
                            if tc_delta.get("type"):
                                call["type"] = tc_delta["type"]
                            func_delta = tc_delta.get("function") or {}
                            if func_delta.get("name"):
                                call["function"]["name"] += func_delta["name"]
                            if func_delta.get("arguments"):
                                call["function"]["arguments"] += func_delta["arguments"]
                    except Exception as e:
                        logger.debug(f"解析流式数据失败: {e}")

            tool_calls_data = [tool_calls_dict[i] for i in sorted(tool_calls_dict)]
            logger.debug("流式 API 响应完成")

            if tool_calls_data:
                tool_results = []
                has_no_reply = False
                chat_record_info = None

                for tool_call in tool_calls_data:
                    function_name = tool_call.get("function", {}).get("name", "")
                    arguments_str = tool_call.get("function", {}).get("arguments", "{}")
                    tool_call_id = tool_call.get("id", "")
                    if not function_name:
                        continue

                    try:
                        arguments = json.loads(arguments_str)
                    except (json.JSONDecodeError, TypeError):
                        arguments = {}
                    if not isinstance(arguments, dict):
                        # The executor adds user info to the arguments, so it needs a dict.
                        arguments = {}

                    logger.info(f"AI调用工具: {function_name}, 参数: {arguments}")

                    if not (bot and from_wxid):
                        logger.error(f"工具调用跳过: bot={bot}, from_wxid={from_wxid}")
                        tool_results.append({
                            "tool_call_id": tool_call_id,
                            "role": "tool",
                            "name": function_name,
                            "content": "工具执行失败:缺少必要参数",
                        })
                        continue

                    result = await self._execute_tool_and_get_result(function_name, arguments, bot, from_wxid)

                    if result and result.get("no_reply"):
                        has_no_reply = True
                        logger.info(f"工具 {function_name} 要求不回复")

                    if result and result.get("send_as_chat_record"):
                        chat_record_info = {
                            "title": result.get("chat_record_title", "AI 回复"),
                            "bot": bot,
                            "from_wxid": from_wxid,
                        }
                        logger.info(f"工具 {function_name} 要求以聊天记录形式发送")

                    tool_results.append({
                        "tool_call_id": tool_call_id,
                        "role": "tool",
                        "name": function_name,
                        "content": result.get("message", "") if result else "工具执行失败",
                    })

                if has_no_reply:
                    logger.info("工具要求不回复,跳过 AI 回复")
                    return ""

                # Feed the tool results back so the model can write the final answer.
                messages.append({
                    "role": "assistant",
                    "content": full_content if full_content else None,
                    "tool_calls": tool_calls_data,
                })
                messages.extend(tool_results)

                # True when every tool reports it already produced/sent content itself.
                all_tools_sent_content = all(
                    r.get("content") and ("已生成" in r.get("content", "") or "已发送" in r.get("content", ""))
                    for r in tool_results
                )

                logger.debug("工具执行完成,准备获取 AI 最终回复")

                payload["messages"] = messages
                async with session.post(api_config["url"], json=payload, headers=headers) as resp2:
                    if resp2.status != 200:
                        error_text = await resp2.text()
                        logger.error(f"API 返回错误: {resp2.status}, {error_text}")
                        if all_tools_sent_content:
                            logger.info("工具已发送内容,跳过 AI 回复")
                            return ""
                        return "✅ 已完成"

                    ai_reply = ""
                    async for data in iter_stream(resp2):
                        content = first_delta(data).get("content")
                        if isinstance(content, str):
                            ai_reply += content

                if chat_record_info and ai_reply:
                    await self._send_chat_records(
                        chat_record_info["bot"],
                        chat_record_info["from_wxid"],
                        chat_record_info["title"],
                        ai_reply,
                    )
                    return ""

                if not ai_reply.strip() and all_tools_sent_content:
                    logger.info("AI 无回复且工具已发送内容,不发送额外消息")
                    return ""

                return ai_reply.strip() if ai_reply.strip() else "✅ 完成"

            # Some models print a pseudo tool call as text instead of using tool_calls.
            if "<tool_code>" in full_content or ("print(" in full_content and "flow2_ai_image_generation" in full_content):
                logger.warning("检测到模型输出了错误的工具调用格式,拦截并返回提示")
                return "抱歉,我遇到了一些技术问题,请重新描述一下你的需求~"

            return full_content.strip()

    except aiohttp.ClientError as e:
        logger.error(f"网络请求失败: {type(e).__name__}: {str(e)}")
        raise Exception(f"网络请求失败: {str(e)}") from e
    except asyncio.TimeoutError as e:
        logger.error(f"API 请求超时 (timeout={api_config['timeout']}s)")
        raise Exception("API 请求超时") from e
    except KeyError as e:
        logger.error(f"API 响应格式错误,缺少字段: {e}")
        raise Exception(f"API 响应格式错误: {e}") from e
|
||
|
||
|
||
def _get_history_file(self, chat_id: str) -> Path:
|
||
"""获取群聊历史记录文件路径"""
|
||
if not self.history_dir:
|
||
return None
|
||
safe_name = chat_id.replace("@", "_").replace(":", "_")
|
||
return self.history_dir / f"{safe_name}.json"
|
||
|
||
def _get_history_lock(self, chat_id: str) -> asyncio.Lock:
|
||
"""获取指定会话的锁, 每个会话一把"""
|
||
lock = self.history_locks.get(chat_id)
|
||
if lock is None:
|
||
lock = asyncio.Lock()
|
||
self.history_locks[chat_id] = lock
|
||
return lock
|
||
|
||
def _read_history_file(self, history_file: Path) -> list:
|
||
try:
|
||
import json
|
||
with open(history_file, "r", encoding="utf-8") as f:
|
||
return json.load(f)
|
||
except FileNotFoundError:
|
||
return []
|
||
except Exception as e:
|
||
logger.error(f"读取历史记录失败: {e}")
|
||
return []
|
||
|
||
def _write_history_file(self, history_file: Path, history: list):
|
||
import json
|
||
history_file.parent.mkdir(parents=True, exist_ok=True)
|
||
temp_file = Path(str(history_file) + ".tmp")
|
||
with open(temp_file, "w", encoding="utf-8") as f:
|
||
json.dump(history, f, ensure_ascii=False, indent=2)
|
||
temp_file.replace(history_file)
|
||
|
||
async def _load_history(self, chat_id: str) -> list:
|
||
"""异步读取群聊历史, 用锁避免与写入冲突"""
|
||
history_file = self._get_history_file(chat_id)
|
||
if not history_file:
|
||
return []
|
||
lock = self._get_history_lock(chat_id)
|
||
async with lock:
|
||
return self._read_history_file(history_file)
|
||
|
||
async def _save_history(self, chat_id: str, history: list):
|
||
"""异步写入群聊历史, 包含长度截断"""
|
||
history_file = self._get_history_file(chat_id)
|
||
if not history_file:
|
||
return
|
||
|
||
max_history = self.config.get("history", {}).get("max_history", 100)
|
||
if len(history) > max_history:
|
||
history = history[-max_history:]
|
||
|
||
lock = self._get_history_lock(chat_id)
|
||
async with lock:
|
||
self._write_history_file(history_file, history)
|
||
|
||
async def _add_to_history(self, chat_id: str, nickname: str, content: str, image_base64: str = None):
|
||
"""
|
||
将消息存入群聊历史
|
||
|
||
Args:
|
||
chat_id: 群聊ID
|
||
nickname: 用户昵称
|
||
content: 消息内容
|
||
image_base64: 可选的图片base64
|
||
"""
|
||
if not self.config.get("history", {}).get("enabled", True):
|
||
return
|
||
|
||
history_file = self._get_history_file(chat_id)
|
||
if not history_file:
|
||
return
|
||
|
||
lock = self._get_history_lock(chat_id)
|
||
async with lock:
|
||
history = self._read_history_file(history_file)
|
||
|
||
message_record = {
|
||
"nickname": nickname,
|
||
"timestamp": datetime.now().isoformat()
|
||
}
|
||
|
||
if image_base64:
|
||
message_record["content"] = [
|
||
{"type": "text", "text": content},
|
||
{"type": "image_url", "image_url": {"url": image_base64}}
|
||
]
|
||
else:
|
||
message_record["content"] = content
|
||
|
||
history.append(message_record)
|
||
max_history = self.config.get("history", {}).get("max_history", 100)
|
||
if len(history) > max_history:
|
||
history = history[-max_history:]
|
||
|
||
self._write_history_file(history_file, history)
|
||
|
||
async def _add_to_history_with_id(self, chat_id: str, nickname: str, content: str, record_id: str):
|
||
"""带ID的历史追加, 便于后续更新"""
|
||
if not self.config.get("history", {}).get("enabled", True):
|
||
return
|
||
|
||
history_file = self._get_history_file(chat_id)
|
||
if not history_file:
|
||
return
|
||
|
||
lock = self._get_history_lock(chat_id)
|
||
async with lock:
|
||
history = self._read_history_file(history_file)
|
||
message_record = {
|
||
"id": record_id,
|
||
"nickname": nickname,
|
||
"timestamp": datetime.now().isoformat(),
|
||
"content": content
|
||
}
|
||
history.append(message_record)
|
||
max_history = self.config.get("history", {}).get("max_history", 100)
|
||
if len(history) > max_history:
|
||
history = history[-max_history:]
|
||
self._write_history_file(history_file, history)
|
||
|
||
async def _update_history_by_id(self, chat_id: str, record_id: str, new_content: str):
|
||
"""根据ID更新历史记录"""
|
||
if not self.config.get("history", {}).get("enabled", True):
|
||
return
|
||
|
||
history_file = self._get_history_file(chat_id)
|
||
if not history_file:
|
||
return
|
||
|
||
lock = self._get_history_lock(chat_id)
|
||
async with lock:
|
||
history = self._read_history_file(history_file)
|
||
for record in history:
|
||
if record.get("id") == record_id:
|
||
record["content"] = new_content
|
||
break
|
||
max_history = self.config.get("history", {}).get("max_history", 100)
|
||
if len(history) > max_history:
|
||
history = history[-max_history:]
|
||
self._write_history_file(history_file, history)
|
||
|
||
|
||
async def _execute_tool_and_get_result(self, tool_name: str, arguments: dict, bot, from_wxid: str):
    """Dispatch an LLM tool call to the first plugin that executes it successfully.

    Every loaded plugin exposing ``execute_llm_tool`` is tried in turn. A plugin
    returning None is treated as "not handled"; a dict with a truthy
    ``success`` ends the search and is returned. An exception raised by one
    plugin is logged with its traceback and the next plugin is tried.

    Args:
        tool_name: function name requested by the model.
        arguments: parsed tool arguments. The caller's dict is copied, not
            mutated; ``user_wxid`` and ``is_group`` are injected into the copy
            from ``self._current_user_wxid`` / ``self._current_is_group``
            (falling back to *from_wxid* / False). A non-dict value is
            treated as empty arguments.
        bot: bot client, passed through to the plugin.
        from_wxid: conversation id, passed through to the plugin.

    Returns:
        The successful plugin's result dict, or
        ``{"success": False, "message": "未找到工具: <tool_name>"}`` otherwise.
    """
    from utils.plugin_manager import PluginManager

    # Copy so the caller's dict is not mutated as a side effect.
    call_args = dict(arguments) if isinstance(arguments, dict) else {}
    call_args["user_wxid"] = getattr(self, "_current_user_wxid", from_wxid)
    call_args["is_group"] = getattr(self, "_current_is_group", False)

    logger.info(f"开始执行工具: {tool_name}")

    plugins = PluginManager().plugins
    logger.info(f"检查 {len(plugins)} 个插件")

    for plugin_name, plugin in plugins.items():
        executor = getattr(plugin, "execute_llm_tool", None)
        logger.debug(f"检查插件: {plugin_name}, 有execute_llm_tool: {executor is not None}")
        if executor is None:
            continue

        logger.info(f"调用 {plugin_name}.execute_llm_tool")
        try:
            result = await executor(tool_name, call_args, bot, from_wxid)
        except Exception as e:
            # logger.exception attaches the traceback automatically.
            logger.exception(f"工具执行异常 ({plugin_name}): {tool_name}, {e}")
            continue

        logger.info(f"{plugin_name} 返回: {result}")
        if result is None:
            continue
        if isinstance(result, dict) and result.get("success"):
            logger.success(f"工具执行成功: {tool_name}")
            return result
        logger.debug(f"{plugin_name} 不处理此工具,继续检查下一个插件")

    logger.warning(f"未找到工具: {tool_name}")
    return {"success": False, "message": f"未找到工具: {tool_name}"}
|
||
|
||
@on_quote_message(priority=79)
async def handle_quote_message(self, bot, message: dict):
    """Answer a quoted message whose quoted item is an image.

    Flow: parse the quote XML -> check reply rules on the title -> pull the
    quoted image's CDN url/aeskey -> download and base64-encode it -> store
    the turn in memory (and group history) -> ask the vision model -> send
    the reply.

    Returns:
        True to let other handlers continue (not applicable, or an error);
        False when this handler consumed the message (replied, or reported
        an image download failure).
    """
    import html

    raw_xml = message.get("Content", "").strip()
    from_wxid = message.get("FromWxid", "")
    sender_wxid = message.get("SenderWxid", "")
    is_group = message.get("IsGroup", False)
    user_wxid = sender_wxid if is_group else from_wxid

    try:
        root = ET.fromstring(raw_xml)

        title_node = root.find(".//title")
        if title_node is None or not title_node.text:
            logger.debug("引用消息没有标题,跳过")
            return True
        title_text = title_node.text.strip()
        logger.info(f"收到引用消息,标题: {title_text[:50]}...")

        if not self._should_reply_quote(message, title_text):
            logger.debug("引用消息不满足回复条件")
            return True

        refer_node = root.find(".//refermsg")
        if refer_node is None:
            logger.debug("引用消息中没有 refermsg 节点")
            return True

        refer_content = refer_node.find("content")
        if refer_content is None or not refer_content.text:
            logger.debug("引用消息中没有 content")
            return True

        # The quoted message's own XML is HTML-entity-escaped inside <content>.
        quoted_root = ET.fromstring(html.unescape(refer_content.text))
        img_node = quoted_root.find(".//img")
        if img_node is None:
            logger.debug("引用的消息不是图片")
            return True

        cdn_url = img_node.get("cdnbigimgurl", "")
        aes_key = img_node.get("aeskey", "")
        if not (cdn_url and aes_key):
            logger.warning(f"图片信息不完整: cdnurl={bool(cdn_url)}, aeskey={bool(aes_key)}")
            return True

        logger.info(f"AI处理引用图片消息: {title_text[:50]}...")

        # Nickname lookup is best-effort and only done in groups.
        nickname = ""
        if is_group:
            try:
                member = await bot.get_user_info_in_chatroom(from_wxid, user_wxid)
                if member and member.get("nickName", {}).get("string"):
                    nickname = member["nickName"]["string"]
                    logger.info(f"获取到用户昵称: {nickname}")
            except Exception as e:
                logger.error(f"获取用户昵称失败: {e}")

        logger.info(f"开始下载图片: {cdn_url[:50]}...")
        encoded_image = await self._download_and_encode_image(bot, cdn_url, aes_key)
        if not encoded_image:
            logger.error("图片下载失败")
            await bot.send_text(from_wxid, "❌ 无法处理图片")
            return False
        logger.info("图片下载和编码成功")

        chat_id = self._get_chat_id(from_wxid, user_wxid, is_group)
        self._add_to_memory(chat_id, "user", title_text, image_base64=encoded_image)
        if is_group:
            await self._add_to_history(from_wxid, nickname, title_text, image_base64=encoded_image)

        reply = await self._call_ai_api_with_image(
            title_text, encoded_image, bot, from_wxid, chat_id, nickname, user_wxid, is_group
        )

        if reply:
            await bot.send_text(from_wxid, reply)
            self._add_to_memory(chat_id, "assistant", reply)
            if is_group:
                # The bot's display name for history comes from main_config.toml.
                import tomllib
                with open("main_config.toml", "rb") as f:
                    main_config = tomllib.load(f)
                bot_nickname = main_config.get("Bot", {}).get("nickname", "机器人")
                await self._add_to_history(from_wxid, bot_nickname, reply)
            logger.success(f"AI回复成功: {reply[:50]}...")

        return False

    except Exception as e:
        logger.error(f"处理引用消息失败: {e}")
        return True
|
||
|
||
def _should_reply_quote(self, message: dict, title_text: str) -> bool:
|
||
"""判断是否应该回复引用消息"""
|
||
is_group = message.get("IsGroup", False)
|
||
|
||
# 检查群聊/私聊开关
|
||
if is_group and not self.config["behavior"]["reply_group"]:
|
||
return False
|
||
if not is_group and not self.config["behavior"]["reply_private"]:
|
||
return False
|
||
|
||
trigger_mode = self.config["behavior"]["trigger_mode"]
|
||
|
||
# all模式:回复所有消息
|
||
if trigger_mode == "all":
|
||
return True
|
||
|
||
# mention模式:检查是否@了机器人
|
||
if trigger_mode == "mention":
|
||
if is_group:
|
||
ats = message.get("Ats", [])
|
||
if not ats:
|
||
return False
|
||
|
||
import tomllib
|
||
with open("main_config.toml", "rb") as f:
|
||
main_config = tomllib.load(f)
|
||
bot_wxid = main_config.get("Bot", {}).get("wxid", "")
|
||
|
||
return bot_wxid and bot_wxid in ats
|
||
else:
|
||
return True
|
||
|
||
# keyword模式:检查关键词
|
||
if trigger_mode == "keyword":
|
||
keywords = self.config["behavior"]["keywords"]
|
||
return any(kw in title_text for kw in keywords)
|
||
|
||
return False
|
||
|
||
async def _call_ai_api_with_image(self, user_message: str, image_base64: str, bot=None, from_wxid: str = None, chat_id: str = None, nickname: str = "", user_wxid: str = None, is_group: bool = False) -> str:
    """Call the chat API (streaming) with an image-bearing user turn.

    The request is built from: the system prompt plus the current time and
    the user's nickname, the chat memory except its last entry, and the
    current user turn (text + image_url). The SSE stream is read line by
    line; text deltas are concatenated and incremental tool-call deltas are
    assembled by index, then each tool is executed through the plugin system.

    Args:
        user_message: text sent alongside the image.
        image_base64: image payload put into the ``image_url`` part (and
            injected into ``flow2_ai_image_generation`` tool calls).
        bot, from_wxid: needed to execute tools; tools are skipped without them.
        chat_id: memory key; when falsy, no memory is included.
        nickname: added to the system prompt when non-empty.
        user_wxid, is_group: stored on ``self`` for tool execution.

    Returns:
        The stripped assistant text. When tools were called: "" if every
        executed tool's message contains "已生成" or "已发送", otherwise the
        stripped text or None if it is empty. A fixed apology if the model
        printed a textual (malformed) tool call.

    Raises:
        Exception: on a non-200 status or any transport error (logged first).
    """
    import json

    api_config = self.config["api"]
    tools = self._collect_tools()

    # System prompt + current local time (with Chinese weekday) + nickname.
    system_content = self.system_prompt
    current_time = datetime.now()
    weekday_map = {
        0: "星期一", 1: "星期二", 2: "星期三", 3: "星期四",
        4: "星期五", 5: "星期六", 6: "星期日"
    }
    weekday = weekday_map[current_time.weekday()]
    time_str = current_time.strftime(f"%Y年%m月%d日 %H:%M:%S {weekday}")
    system_content += f"\n\n当前时间:{time_str}"
    if nickname:
        system_content += f"\n当前对话用户的昵称是:{nickname}"

    messages = [{"role": "system", "content": system_content}]

    if chat_id:
        memory_messages = self._get_memory_messages(chat_id)
        # The last memory entry is dropped: callers such as handle_quote_message
        # add the current turn to memory before calling, and it is re-sent below
        # together with the image.
        if memory_messages and len(memory_messages) > 1:
            messages.extend(memory_messages[:-1])

    messages.append({
        "role": "user",
        "content": [
            {"type": "text", "text": user_message},
            {"type": "image_url", "image_url": {"url": image_base64}}
        ]
    })

    # NOTE(review): per-request state kept on the shared plugin instance;
    # concurrent requests may overwrite it before their tools run.
    self._current_user_wxid = user_wxid
    self._current_is_group = is_group

    payload = {
        "model": api_config["model"],
        "messages": messages,
        "stream": True
    }
    if tools:
        payload["tools"] = tools

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_config['api_key']}"
    }

    timeout = aiohttp.ClientTimeout(total=api_config["timeout"])

    # Optional SOCKS/HTTP proxy; any failure falls back to a direct connection.
    connector = None
    proxy_config = self.config.get("proxy", {})
    if proxy_config.get("enabled", False):
        proxy_type = proxy_config.get("type", "socks5").upper()
        proxy_host = proxy_config.get("host", "127.0.0.1")
        proxy_port = proxy_config.get("port", 7890)
        proxy_username = proxy_config.get("username")
        proxy_password = proxy_config.get("password")

        if proxy_username and proxy_password:
            proxy_url = f"{proxy_type}://{proxy_username}:{proxy_password}@{proxy_host}:{proxy_port}"
        else:
            proxy_url = f"{proxy_type}://{proxy_host}:{proxy_port}"

        if PROXY_SUPPORT:
            try:
                connector = ProxyConnector.from_url(proxy_url)
                logger.debug(f"使用代理: {proxy_type}://{proxy_host}:{proxy_port}")
            except Exception as e:
                logger.warning(f"代理配置失败,将直连: {e}")
                connector = None
        else:
            logger.warning("代理功能不可用(aiohttp_socks 未安装),将直连")
            connector = None

    try:
        async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
            async with session.post(api_config["url"], json=payload, headers=headers) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    logger.error(f"API返回错误状态码: {resp.status}, 响应: {error_text}")
                    raise Exception(f"API错误 {resp.status}: {error_text}")

                full_content = ""
                tool_calls_dict = {}  # {index: tool_call} assembled from deltas

                async for raw_line in resp.content:
                    line = raw_line.decode("utf-8").strip()
                    if not line.startswith("data: ") or line == "data: [DONE]":
                        continue

                    try:
                        data = json.loads(line[6:])
                        choices = data.get("choices") or []
                        if not choices:
                            continue

                        delta = choices[0].get("delta") or {}
                        content = delta.get("content")
                        if content:
                            full_content += content

                        for tc_delta in delta.get("tool_calls") or []:
                            index = tc_delta.get("index") or 0
                            entry = tool_calls_dict.setdefault(index, {
                                "id": "",
                                "type": "function",
                                "function": {"name": "", "arguments": ""},
                            })
                            # Some OpenAI-compatible providers may send null/empty
                            # fields in continuation chunks; only non-empty values
                            # update or extend the call (a None here used to abort
                            # the whole chunk, losing its argument fragment).
                            if tc_delta.get("id"):
                                entry["id"] = tc_delta["id"]
                            if tc_delta.get("type"):
                                entry["type"] = tc_delta["type"]
                            func_delta = tc_delta.get("function") or {}
                            if func_delta.get("name"):
                                entry["function"]["name"] += func_delta["name"]
                            if func_delta.get("arguments"):
                                entry["function"]["arguments"] += func_delta["arguments"]
                    except Exception as e:
                        logger.debug(f"解析流式数据失败: {e}")

                tool_calls_data = [tool_calls_dict[i] for i in sorted(tool_calls_dict)]

                if tool_calls_data:
                    tool_results = []

                    for tool_call in tool_calls_data:
                        function = tool_call.get("function", {})
                        function_name = function.get("name", "")
                        tool_call_id = tool_call.get("id", "")
                        if not function_name:
                            continue

                        try:
                            arguments = json.loads(function.get("arguments") or "{}")
                        except ValueError:
                            arguments = {}
                        # Tools expect a JSON object; anything else is treated as no args.
                        if not isinstance(arguments, dict):
                            arguments = {}

                        if function_name == "flow2_ai_image_generation" and image_base64:
                            arguments["image_base64"] = image_base64
                            logger.info("AI调用图生图工具,已添加图片数据")

                        # Keep the (potentially huge) base64 payload out of the log.
                        loggable_args = {k: v for k, v in arguments.items() if k != "image_base64"}
                        logger.info(f"AI调用工具: {function_name}, 参数: {loggable_args}")

                        if bot and from_wxid:
                            result = await self._execute_tool_and_get_result(function_name, arguments, bot, from_wxid)
                            tool_results.append({
                                "tool_call_id": tool_call_id,
                                "role": "tool",
                                "name": function_name,
                                "content": (result.get("message") or "") if result else "工具执行失败",
                            })

                    # Tools that already delivered content make a text reply redundant.
                    if tool_results and all(
                        "已生成" in r["content"] or "已发送" in r["content"] for r in tool_results
                    ):
                        return ""

                    return full_content.strip() or None

                # Intercept a tool call the model printed as text instead of calling it.
                if "<tool_code>" in full_content or (
                    "print(" in full_content and "flow2_ai_image_generation" in full_content
                ):
                    logger.warning("检测到模型输出了错误的工具调用格式,拦截并返回提示")
                    return "抱歉,我遇到了一些技术问题,请重新描述一下你的需求~"

                return full_content.strip()

    except Exception as e:
        logger.error(f"调用AI API失败: {e}")
        raise
|
||
|
||
async def _send_chat_records(self, bot, from_wxid: str, title: str, content: str):
    """Send *content* to *from_wxid* as a WeChat "chat history" card (appmsg type 19).

    Content longer than 800 characters is split into several record items:
    lines are packed greedily, and any single line longer than the limit is
    hard-wrapped, so no item exceeds 800 characters. The title is
    XML-escaped before being embedded in the appmsg envelope. Failures are
    logged, not raised.
    """
    try:
        import hashlib
        import time
        import uuid
        import xml.etree.ElementTree as ET
        from xml.sax.saxutils import escape

        is_group = from_wxid.endswith("@chatroom")

        max_length = 800
        if len(content) <= max_length:
            content_parts = [content]
        else:
            content_parts = []
            current_part = ""
            for line in content.split("\n"):
                # Hard-wrap over-long lines; the old code wrapped only once, so
                # lines longer than 2x the limit produced oversized items.
                pieces = [line[i:i + max_length] for i in range(0, len(line), max_length)] or [""]
                for piece in pieces:
                    if current_part and len(current_part) + len(piece) + 1 > max_length:
                        if current_part.strip():
                            content_parts.append(current_part.strip())
                        current_part = ""
                    current_part += piece + "\n"
            if current_part.strip():
                content_parts.append(current_part.strip())
            if not content_parts:
                content_parts = [content.strip()]

        recordinfo = ET.Element("recordinfo")
        ET.SubElement(recordinfo, "info").text = title
        ET.SubElement(recordinfo, "isChatRoom").text = "1" if is_group else "0"
        datalist = ET.SubElement(recordinfo, "datalist")
        datalist.set("count", str(len(content_parts)))
        ET.SubElement(recordinfo, "desc").text = title
        ET.SubElement(recordinfo, "fromscene").text = "3"

        for i, part in enumerate(content_parts):
            di = ET.SubElement(datalist, "dataitem")
            di.set("datatype", "1")
            di.set("dataid", uuid.uuid4().hex)

            src_local_id = str((int(time.time() * 1000) % 90000) + 10000)
            new_msg_id = str(int(time.time() * 1000) + i)
            # Spread item timestamps a second apart, ending at "now".
            create_time = str(int(time.time()) - len(content_parts) + i)

            ET.SubElement(di, "srcMsgLocalid").text = src_local_id
            ET.SubElement(di, "sourcetime").text = time.strftime("%Y-%m-%d %H:%M", time.localtime(int(create_time)))
            ET.SubElement(di, "fromnewmsgid").text = new_msg_id
            ET.SubElement(di, "srcMsgCreateTime").text = create_time
            ET.SubElement(di, "sourcename").text = "AI助手"
            ET.SubElement(di, "sourceheadurl").text = ""
            ET.SubElement(di, "datatitle").text = part
            ET.SubElement(di, "datadesc").text = part
            ET.SubElement(di, "datafmt").text = "text"
            ET.SubElement(di, "ischatroom").text = "1" if is_group else "0"

            dataitemsource = ET.SubElement(di, "dataitemsource")
            ET.SubElement(dataitemsource, "hashusername").text = hashlib.sha256(from_wxid.encode("utf-8")).hexdigest()

        # ElementTree escapes text itself, so record_xml is safe inside CDATA.
        record_xml = ET.tostring(recordinfo, encoding="unicode")

        # The title is interpolated into raw XML and must be escaped by hand.
        safe_title = escape(title)
        appmsg_parts = [
            "<appmsg appid=\"\" sdkver=\"0\">",
            f"<title>{safe_title}</title>",
            f"<des>{safe_title}</des>",
            "<type>19</type>",
            "<url>https://support.weixin.qq.com/cgi-bin/mmsupport-bin/readtemplate?t=page/favorite_record__w_unsupport</url>",
            "<appattach><cdnthumbaeskey></cdnthumbaeskey><aeskey></aeskey></appattach>",
            f"<recorditem><![CDATA[{record_xml}]]></recorditem>",
            "<percent>0</percent>",
            "</appmsg>"
        ]
        appmsg_xml = "".join(appmsg_parts)

        await bot._send_data_async(11214, {"to_wxid": from_wxid, "content": appmsg_xml})
        logger.success(f"已发送聊天记录: {title}")

    except Exception as e:
        logger.error(f"发送聊天记录失败: {e}")
|
||
|
||
async def _process_image_to_history(self, bot, message: dict, content: str) -> bool:
|
||
"""处理图片/表情包并保存描述到 history(通用方法)"""
|
||
from_wxid = message.get("FromWxid", "")
|
||
sender_wxid = message.get("SenderWxid", "")
|
||
is_group = message.get("IsGroup", False)
|
||
user_wxid = sender_wxid if is_group else from_wxid
|
||
|
||
# 只处理群聊
|
||
if not is_group:
|
||
return True
|
||
|
||
# 检查是否启用图片描述功能
|
||
image_desc_config = self.config.get("image_description", {})
|
||
if not image_desc_config.get("enabled", True):
|
||
return True
|
||
|
||
try:
|
||
# 解析XML获取图片信息
|
||
root = ET.fromstring(content)
|
||
|
||
# 尝试查找 <img> 标签(图片消息)或 <emoji> 标签(表情包)
|
||
img = root.find(".//img")
|
||
if img is None:
|
||
img = root.find(".//emoji")
|
||
|
||
if img is None:
|
||
return True
|
||
|
||
cdnbigimgurl = img.get("cdnbigimgurl", "") or img.get("cdnurl", "")
|
||
aeskey = img.get("aeskey", "")
|
||
|
||
# 检查是否是表情包(有 cdnurl 但可能没有 aeskey)
|
||
is_emoji = img.tag == "emoji"
|
||
|
||
if not cdnbigimgurl:
|
||
return True
|
||
|
||
# 图片消息需要 aeskey,表情包不需要
|
||
if not is_emoji and not aeskey:
|
||
return True
|
||
|
||
# 获取用户昵称
|
||
nickname = ""
|
||
try:
|
||
user_info = await bot.get_user_info_in_chatroom(from_wxid, user_wxid)
|
||
if user_info and user_info.get("nickName", {}).get("string"):
|
||
nickname = user_info["nickName"]["string"]
|
||
except:
|
||
pass
|
||
|
||
if not nickname:
|
||
from plugins.MessageLogger.main import MessageLogger
|
||
msg_logger = MessageLogger.get_instance()
|
||
if msg_logger:
|
||
try:
|
||
with msg_logger.get_db_connection() as conn:
|
||
with conn.cursor() as cursor:
|
||
cursor.execute(
|
||
"SELECT nickname FROM messages WHERE sender_wxid = %s AND nickname != '' ORDER BY create_time DESC LIMIT 1",
|
||
(user_wxid,)
|
||
)
|
||
result = cursor.fetchone()
|
||
if result:
|
||
nickname = result[0]
|
||
except:
|
||
pass
|
||
|
||
if not nickname:
|
||
nickname = user_wxid or sender_wxid or "未知用户"
|
||
|
||
# 立即插入占位符到 history
|
||
placeholder_id = str(uuid.uuid4())
|
||
await self._add_to_history_with_id(from_wxid, nickname, "[图片: 处理中...]", placeholder_id)
|
||
logger.info(f"已插入图片占位符: {placeholder_id}")
|
||
|
||
# 将任务加入队列(不阻塞)
|
||
task = {
|
||
"bot": bot,
|
||
"from_wxid": from_wxid,
|
||
"nickname": nickname,
|
||
"cdnbigimgurl": cdnbigimgurl,
|
||
"aeskey": aeskey,
|
||
"is_emoji": is_emoji,
|
||
"placeholder_id": placeholder_id,
|
||
"config": image_desc_config
|
||
}
|
||
await self.image_desc_queue.put(task)
|
||
logger.info(f"图片描述任务已加入队列,当前队列长度: {self.image_desc_queue.qsize()}")
|
||
|
||
return True
|
||
|
||
except Exception as e:
|
||
logger.error(f"处理图片消息失败: {e}")
|
||
return True
|
||
|
||
async def _image_desc_worker(self):
|
||
"""图片描述工作协程,从队列中取任务并处理"""
|
||
while True:
|
||
try:
|
||
task = await self.image_desc_queue.get()
|
||
await self._generate_and_update_image_description(
|
||
task["bot"], task["from_wxid"], task["nickname"],
|
||
task["cdnbigimgurl"], task["aeskey"], task["is_emoji"],
|
||
task["placeholder_id"], task["config"]
|
||
)
|
||
self.image_desc_queue.task_done()
|
||
except Exception as e:
|
||
logger.error(f"图片描述工作协程异常: {e}")
|
||
|
||
async def _generate_and_update_image_description(self, bot, from_wxid: str, nickname: str,
                                                 cdnbigimgurl: str, aeskey: str, is_emoji: bool,
                                                 placeholder_id: str, image_desc_config: dict):
    """Describe one queued image/emoji and rewrite its placeholder history record.

    Emojis are fetched with ``_download_emoji_and_encode``; regular images
    with ``_download_and_encode_image`` (which needs the aeskey). The
    description prompt comes from ``image_desc_config["prompt"]``. On any
    failure (download, empty description, or exception) the placeholder
    becomes "[图片]".
    """
    try:
        if is_emoji:
            encoded = await self._download_emoji_and_encode(cdnbigimgurl)
        else:
            encoded = await self._download_and_encode_image(bot, cdnbigimgurl, aeskey)

        if not encoded:
            kind = '表情包' if is_emoji else '图片'
            logger.warning(f"{kind}下载失败")
            await self._update_history_by_id(from_wxid, placeholder_id, "[图片]")
            return

        prompt = image_desc_config.get("prompt", "请用一句话简洁地描述这张图片的主要内容。")
        description = await self._generate_image_description(encoded, prompt, image_desc_config)

        if not description:
            await self._update_history_by_id(from_wxid, placeholder_id, "[图片]")
            logger.warning("图片描述生成失败")
            return

        await self._update_history_by_id(from_wxid, placeholder_id, f"[图片: {description}]")
        logger.success(f"已更新图片描述: {nickname} - {description[:30]}...")

    except Exception as e:
        logger.error(f"异步生成图片描述失败: {e}")
        await self._update_history_by_id(from_wxid, placeholder_id, "[图片]")
|
||
|
||
@on_image_message(priority=15)
async def handle_image_message(self, bot, message: dict):
    """Handle a directly sent image by queueing a description for group history.

    Never triggers an AI reply; the work is delegated to
    ``_process_image_to_history``, whose result is returned unchanged.
    """
    logger.info("AIChat: handle_image_message 被调用")
    xml_payload = message.get("Content", "")
    return await self._process_image_to_history(bot, message, xml_payload)
|
||
|
||
@on_emoji_message(priority=15)
async def handle_emoji_message(self, bot, message: dict):
    """Handle an emoji (sticker) message by queueing a description for group history.

    Never triggers an AI reply; shares ``_process_image_to_history`` with the
    image handler and returns its result unchanged.
    """
    logger.info("AIChat: handle_emoji_message 被调用")
    xml_payload = message.get("Content", "")
    return await self._process_image_to_history(bot, message, xml_payload)
|