Files
WechatHookBot/plugins/AIChat/main.py
2025-12-05 18:06:13 +08:00

1923 lines
82 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
AI 聊天插件
支持自定义模型、API 和人设
支持 Redis 存储对话历史和限流
"""
import asyncio
import tomllib
import aiohttp
from pathlib import Path
from datetime import datetime
from loguru import logger
from utils.plugin_base import PluginBase
from utils.decorators import on_text_message, on_quote_message, on_image_message, on_emoji_message
from utils.redis_cache import get_cache
import xml.etree.ElementTree as ET
import base64
import uuid
# 可选导入代理支持
try:
from aiohttp_socks import ProxyConnector
PROXY_SUPPORT = True
except ImportError:
PROXY_SUPPORT = False
logger.warning("aiohttp_socks 未安装,代理功能将不可用")
class AIChat(PluginBase):
"""AI 聊天插件"""
# 插件元数据
description = "AI 聊天插件,支持自定义模型和人设"
author = "ShiHao"
version = "1.0.0"
def __init__(self):
    """Initialise per-instance state; configuration itself is loaded in ``async_init``."""
    super().__init__()
    # Parsed config.toml contents (set by async_init).
    self.config = None
    # Active persona / system prompt text.
    self.system_prompt = ""
    # In-process fallback conversation memory: {chat_id: [{"role": ..., "content": ...}]}
    self.memory = {}
    # Directory of per-group JSON history files; stays None when history is disabled.
    self.history_dir = None
    # One asyncio.Lock per chat_id guarding that chat's history file.
    self.history_locks = {}
    # Pending image-description jobs, consumed by the worker tasks.
    self.image_desc_queue = asyncio.Queue()
    # Worker tasks started by async_init.
    self.image_desc_workers = []
    # Sender context of the request currently handled by _call_ai_api, saved there
    # for tool calls. Initialised here so a read before the first AI call does not
    # raise AttributeError.
    self._current_user_wxid = None
    self._current_is_group = False
async def async_init(self):
    """Asynchronously initialise the plugin.

    Steps:
    - Load ``config.toml`` from the plugin directory.
    - Load the persona file ``prompts/<prompt.system_prompt_file>``, falling back
      to a built-in default persona when the file is missing.
    - Log the proxy settings.
    - Create the history directory when history is enabled.
    - Start two image-description worker tasks.
    """
    plugin_dir = Path(__file__).parent

    # Plugin configuration.
    with open(plugin_dir / "config.toml", "rb") as f:
        self.config = tomllib.load(f)

    # Persona (system prompt).
    prompt_file = self.config["prompt"]["system_prompt_file"]
    prompt_path = plugin_dir / "prompts" / prompt_file
    if prompt_path.exists():
        with open(prompt_path, "r", encoding="utf-8") as f:
            self.system_prompt = f.read().strip()
        logger.success(f"已加载人设: {prompt_file}")
    else:
        logger.warning(f"人设文件不存在: {prompt_file},使用默认人设")
        self.system_prompt = "你是一个友好的 AI 助手。"

    # Proxy settings are only reported here; connectors are built per request.
    proxy_config = self.config.get("proxy", {})
    if proxy_config.get("enabled", False):
        proxy_type = proxy_config.get("type", "socks5")
        proxy_host = proxy_config.get("host", "127.0.0.1")
        proxy_port = proxy_config.get("port", 7890)
        logger.info(f"AI 聊天插件已启用代理: {proxy_type}://{proxy_host}:{proxy_port}")

    # History directory. parents=True so a nested "history_dir" setting
    # (e.g. "data/history") works instead of raising FileNotFoundError.
    history_config = self.config.get("history", {})
    if history_config.get("enabled", True):
        history_dir_name = history_config.get("history_dir", "history")
        self.history_dir = plugin_dir / history_dir_name
        self.history_dir.mkdir(parents=True, exist_ok=True)
        logger.info(f"历史记录目录: {self.history_dir}")

    # Image-description workers (concurrency of 2).
    for _ in range(2):
        worker = asyncio.create_task(self._image_desc_worker())
        self.image_desc_workers.append(worker)
    logger.info("已启动 2 个图片描述工作协程")

    logger.info(f"AI 聊天插件已加载,模型: {self.config['api']['model']}")
def _get_chat_id(self, from_wxid: str, sender_wxid: str = None, is_group: bool = False) -> str:
    """Return the key under which a conversation's memory is stored.

    In groups the key is "<group_id>:<user_id>" so every member keeps an
    independent memory; in private chats it is just the user id. An empty or
    missing ``sender_wxid`` falls back to ``from_wxid``.
    """
    member = sender_wxid or from_wxid
    return f"{from_wxid}:{member}" if is_group else member
async def _get_user_nickname(self, bot, from_wxid: str, user_wxid: str, is_group: bool) -> str:
    """Resolve a group member's nickname.

    Lookup order:
    1. Redis user-info cache.
    2. ``bot.get_user_info_in_chatroom``; a hit is written back to Redis.
    3. The newest non-empty nickname recorded by the MessageLogger plugin's DB.
    4. The wxid itself, or "未知用户" when the wxid is empty.

    Args:
        bot: WechatHookClient instance.
        from_wxid: Group id (or private-chat user id).
        user_wxid: Member wxid.
        is_group: Whether this is a group chat; private chats always yield "".

    Returns:
        The nickname ("" for private chats).
    """
    if not is_group:
        return ""

    redis_cache = get_cache()

    # Step 1: Redis cache.
    if redis_cache and redis_cache.enabled:
        basic_info = redis_cache.get_user_basic_info(from_wxid, user_wxid)
        cached_name = basic_info.get("nickname") if basic_info else None
        if cached_name:
            logger.debug(f"[缓存命中] 用户昵称: {user_wxid} -> {cached_name}")
            return cached_name

    # Step 2: bot API; remember the result in Redis.
    resolved = ""
    try:
        member_info = await bot.get_user_info_in_chatroom(from_wxid, user_wxid)
        if member_info and member_info.get("nickName", {}).get("string"):
            resolved = member_info["nickName"]["string"]
            if redis_cache and redis_cache.enabled:
                redis_cache.set_user_info(from_wxid, user_wxid, member_info)
                logger.debug(f"[已缓存] 用户昵称: {user_wxid} -> {resolved}")
            return resolved
    except Exception as e:
        logger.warning(f"API获取用户昵称失败: {e}")

    # Step 3: MessageLogger database (best effort).
    if not resolved:
        try:
            from plugins.MessageLogger.main import MessageLogger
            msg_logger = MessageLogger.get_instance()
            if msg_logger:
                with msg_logger.get_db_connection() as conn, conn.cursor() as cursor:
                    cursor.execute(
                        "SELECT nickname FROM messages WHERE sender_wxid = %s AND nickname != '' ORDER BY create_time DESC LIMIT 1",
                        (user_wxid,)
                    )
                    row = cursor.fetchone()
                    if row:
                        resolved = row[0]
        except Exception as e:
            logger.debug(f"从数据库获取昵称失败: {e}")

    # Step 4: fall back to the wxid.
    return resolved or user_wxid or "未知用户"
def _check_rate_limit(self, user_wxid: str) -> tuple:
    """Check whether ``user_wxid`` may trigger another AI reply.

    Limiting is skipped (unlimited result) when ``rate_limit.enabled`` is false
    or Redis is unavailable. Otherwise the Redis counter "ai_chat" is checked
    with ``ai_chat_limit`` (default 20) requests per ``ai_chat_window``
    (default 60) seconds.

    Returns:
        (allowed, remaining, seconds_until_reset)
    """
    unlimited = (True, 999, 0)
    settings = self.config.get("rate_limit", {})
    if not settings.get("enabled", True):
        return unlimited

    cache = get_cache()
    if not (cache and cache.enabled):
        # Without Redis there is nowhere to keep counters.
        return unlimited

    return cache.check_rate_limit(
        user_wxid,
        settings.get("ai_chat_limit", 20),
        settings.get("ai_chat_window", 60),
        "ai_chat",
    )
def _add_to_memory(self, chat_id: str, role: str, content, image_base64: str = None):
    """Append a message to a conversation's memory (no-op when memory is disabled).

    Storage:
    - When ``redis.use_redis_history`` is on and Redis is available, the message
      goes to Redis (TTL ``redis.chat_history_ttl``, default 86400 s) and the
      history is trimmed to ``memory.max_messages``.
    - Otherwise it goes to the in-process ``self.memory`` dict, trimmed to the
      same limit.

    Args:
        chat_id: Conversation key.
        role: "user" or "assistant".
        content: Message text (or a pre-built content list).
        image_base64: Optional image; turns the entry into a multimodal
            [text, image_url] list, with empty text when ``content`` is not a str.
    """
    memory_settings = self.config.get("memory", {})
    if not memory_settings.get("enabled", False):
        return

    if not image_base64:
        entry_content = content
    else:
        text_part = content if isinstance(content, str) else ""
        entry_content = [
            {"type": "text", "text": text_part},
            {"type": "image_url", "image_url": {"url": image_base64}},
        ]

    redis_settings = self.config.get("redis", {})
    cache = get_cache() if redis_settings.get("use_redis_history", True) else None
    if cache and cache.enabled:
        cache.add_chat_message(
            chat_id, role, entry_content, ttl=redis_settings.get("chat_history_ttl", 86400)
        )
        cache.trim_chat_history(chat_id, memory_settings["max_messages"])
        return

    # In-process fallback.
    entries = self.memory.setdefault(chat_id, [])
    entries.append({"role": role, "content": entry_content})
    cap = memory_settings["max_messages"]
    if len(entries) > cap:
        self.memory[chat_id] = entries[-cap:]
def _get_memory_messages(self, chat_id: str) -> list:
    """Return the stored messages of a conversation.

    Returns [] when memory is disabled. Uses Redis (up to
    ``memory.max_messages`` entries) when ``redis.use_redis_history`` is on and
    Redis is available; otherwise the in-process list, which is returned by
    reference.
    """
    if self.config.get("memory", {}).get("enabled", False):
        use_redis = self.config.get("redis", {}).get("use_redis_history", True)
        cache = get_cache() if use_redis else None
        if cache and cache.enabled:
            return cache.get_chat_history(chat_id, self.config["memory"]["max_messages"])
        return self.memory.get(chat_id, [])
    return []
def _clear_memory(self, chat_id: str):
    """Forget a conversation: clear its Redis history (when in use) and its in-process entry."""
    if self.config.get("redis", {}).get("use_redis_history", True):
        cache = get_cache()
        if cache and cache.enabled:
            cache.clear_chat_history(chat_id)
    # Always drop the local copy as well.
    self.memory.pop(chat_id, None)
async def _download_and_encode_image(self, bot, cdnurl: str, aeskey: str) -> str:
    """Download a CDN image and return it as a ``data:image/jpeg;base64,...`` URL.

    Cache behaviour:
    - A copy cached in Redis (keyed by cdnurl + aeskey) is returned directly.
    - Otherwise the image is fetched with ``bot.cdn_download`` (file_type=2,
      then file_type=1) into a unique temp file, base64-encoded, and cached in
      Redis for 300 s.

    The temp file is always removed, on success and on failure.

    Returns:
        The data URL, or "" when the download fails, the file never appears or
        stays empty, or any error occurs.
    """
    import os
    from contextlib import suppress

    save_path = None
    try:
        from utils.redis_cache import RedisCache

        # 1. Redis cache.
        redis_cache = get_cache()
        media_key = None
        if redis_cache and redis_cache.enabled:
            media_key = RedisCache.generate_media_key(cdnurl, aeskey)
            if media_key:
                cached_data = redis_cache.get_cached_media(media_key, "image")
                if cached_data:
                    logger.debug(f"[缓存命中] 图片从 Redis 获取: {media_key[:20]}...")
                    return cached_data

        # 2. Cache miss: download into a uniquely named temp file.
        logger.debug(f"[缓存未命中] 开始下载图片...")
        temp_dir = Path(__file__).parent / "temp"
        temp_dir.mkdir(exist_ok=True)
        save_path = str((temp_dir / f"temp_{uuid.uuid4().hex[:8]}.jpg").resolve())

        success = await bot.cdn_download(cdnurl, aeskey, save_path, file_type=2)
        if not success:
            success = await bot.cdn_download(cdnurl, aeskey, save_path, file_type=1)
        if not success:
            return ""

        # The file may still be being written; poll up to ~10 s (20 x 0.5 s).
        for _ in range(20):
            if os.path.exists(save_path) and os.path.getsize(save_path) > 0:
                break
            await asyncio.sleep(0.5)
        # A missing or still-empty file would encode to an empty data URL.
        if not (os.path.exists(save_path) and os.path.getsize(save_path) > 0):
            return ""

        with open(save_path, "rb") as f:
            image_data = base64.b64encode(f.read()).decode()
        base64_result = f"data:image/jpeg;base64,{image_data}"

        # 3. Cache for later requests.
        if redis_cache and redis_cache.enabled and media_key:
            redis_cache.cache_media(media_key, base64_result, "image", ttl=300)
            logger.debug(f"[已缓存] 图片缓存到 Redis: {media_key[:20]}...")

        return base64_result
    except Exception as e:
        logger.error(f"下载图片失败: {e}")
        return ""
    finally:
        # Remove the temp file on every path (success, failure, exception).
        if save_path:
            with suppress(OSError):
                Path(save_path).unlink(missing_ok=True)
async def _download_emoji_and_encode(self, cdn_url: str, max_retries: int = 3) -> str:
    """Download an emoji/sticker over HTTP and return it as a ``data:image/gif;base64,...`` URL.

    A copy cached in Redis is returned directly. Otherwise the URL is fetched
    with aiohttp, through the configured proxy when enabled and aiohttp_socks is
    installed.

    Retry behaviour:
    - Up to ``max_retries`` attempts.
    - The total timeout grows by 15 s per attempt (30 s, 45 s, ...).
    - A linearly growing pause (1 s, 2 s, ...) follows every failed attempt.
    - A successful download is cached in Redis for 300 s.

    Args:
        cdn_url: Emoji URL, possibly still containing the XML escape "&amp;".
        max_retries: Maximum number of download attempts.

    Returns:
        The data URL, or "" when every attempt failed.
    """
    # URLs taken from message XML may still carry the escaped "&amp;" separator.
    cdn_url = cdn_url.replace("&amp;", "&")

    # 1. Redis cache.
    from utils.redis_cache import RedisCache
    redis_cache = get_cache()
    media_key = RedisCache.generate_media_key(cdnurl=cdn_url)
    if redis_cache and redis_cache.enabled and media_key:
        cached_data = redis_cache.get_cached_media(media_key, "emoji")
        if cached_data:
            logger.debug(f"[缓存命中] 表情包从 Redis 获取: {media_key[:20]}...")
            return cached_data

    # 2. Cache miss: download. The body is encoded from memory, so no temp file is needed.
    logger.debug(f"[缓存未命中] 开始下载表情包...")

    # The proxy URL is identical for every attempt; build it once.
    proxy_url = None
    proxy_config = self.config.get("proxy", {})
    if proxy_config.get("enabled", False):
        proxy_type = proxy_config.get("type", "socks5").upper()
        proxy_host = proxy_config.get("host", "127.0.0.1")
        proxy_port = proxy_config.get("port", 7890)
        proxy_username = proxy_config.get("username")
        proxy_password = proxy_config.get("password")
        if proxy_username and proxy_password:
            proxy_url = f"{proxy_type}://{proxy_username}:{proxy_password}@{proxy_host}:{proxy_port}"
        else:
            proxy_url = f"{proxy_type}://{proxy_host}:{proxy_port}"

    last_error = None
    for attempt in range(max_retries):
        try:
            timeout = aiohttp.ClientTimeout(total=30 + attempt * 15)
            # A session closes the connector it owns, so build a fresh one per attempt.
            connector = None
            if proxy_url and PROXY_SUPPORT:
                try:
                    connector = ProxyConnector.from_url(proxy_url)
                except Exception as e:
                    logger.warning(f"代理配置失败,将直连: {e}")
                    connector = None

            async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
                async with session.get(cdn_url) as response:
                    if response.status != 200:
                        last_error = f"状态码 {response.status}"
                        logger.warning(f"表情包下载失败,状态码: {response.status},重试 {attempt + 1}/{max_retries}")
                    else:
                        content = await response.read()
                        if not content:
                            last_error = "内容为空"
                            logger.warning(f"表情包下载内容为空,重试 {attempt + 1}/{max_retries}")
                        else:
                            image_data = base64.b64encode(content).decode()
                            logger.debug(f"表情包下载成功,大小: {len(content)} 字节")
                            base64_result = f"data:image/gif;base64,{image_data}"

                            # 3. Cache for later requests.
                            if redis_cache and redis_cache.enabled and media_key:
                                redis_cache.cache_media(media_key, base64_result, "emoji", ttl=300)
                                logger.debug(f"[已缓存] 表情包缓存到 Redis: {media_key[:20]}...")
                            return base64_result
        except asyncio.TimeoutError:
            last_error = "请求超时"
            logger.warning(f"表情包下载超时,重试 {attempt + 1}/{max_retries}")
        except aiohttp.ClientError as e:
            last_error = str(e)
            logger.warning(f"表情包下载网络错误: {e},重试 {attempt + 1}/{max_retries}")
        except Exception as e:
            last_error = str(e)
            logger.warning(f"表情包下载异常: {e},重试 {attempt + 1}/{max_retries}")

        # Linear backoff before the next attempt; applies to every failure kind,
        # including empty and non-200 responses.
        if attempt < max_retries - 1:
            await asyncio.sleep(1 * (attempt + 1))

    logger.error(f"表情包下载失败,已重试 {max_retries} 次: {last_error}")
    return ""
async def _generate_image_description(self, image_base64: str, prompt: str, config: dict) -> str:
    """Ask the chat-completions API to describe an image.

    Sends ``prompt`` and the image as one multimodal user message (streaming
    enabled) to ``self.config["api"]["url"]``, and concatenates the
    ``choices[0].delta.content`` chunks of the server-sent-event stream.

    Args:
        image_base64: Image data URL / base64 string, sent as ``image_url.url``.
        prompt: Instruction text sent with the image.
        config: Image-description settings. Reads ``model`` (defaults to the main
            API model) and ``max_tokens`` (default 1000).

    Returns:
        The stripped description, or "" on an HTTP error or any other failure.
    """
    import json

    try:
        api_config = self.config["api"]
        payload = {
            "model": config.get("model", api_config["model"]),
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        {"type": "image_url", "image_url": {"url": image_base64}}
                    ]
                }
            ],
            "max_tokens": config.get("max_tokens", 1000),
            "stream": True
        }
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_config['api_key']}"
        }
        timeout = aiohttp.ClientTimeout(total=api_config["timeout"])

        # Optional proxy; fall back to a direct connection if it cannot be built.
        connector = None
        proxy_config = self.config.get("proxy", {})
        if proxy_config.get("enabled", False):
            proxy_type = proxy_config.get("type", "socks5").upper()
            proxy_host = proxy_config.get("host", "127.0.0.1")
            proxy_port = proxy_config.get("port", 7890)
            proxy_username = proxy_config.get("username")
            proxy_password = proxy_config.get("password")
            if proxy_username and proxy_password:
                proxy_url = f"{proxy_type}://{proxy_username}:{proxy_password}@{proxy_host}:{proxy_port}"
            else:
                proxy_url = f"{proxy_type}://{proxy_host}:{proxy_port}"
            if PROXY_SUPPORT:
                try:
                    connector = ProxyConnector.from_url(proxy_url)
                except Exception as e:
                    logger.warning(f"代理配置失败,将直连: {e}")
                    connector = None

        parts = []
        async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
            async with session.post(
                api_config["url"],
                json=payload,
                headers=headers
            ) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    logger.error(f"图片描述 API 返回错误: {resp.status}, {error_text[:200]}")
                    return ""

                async for raw_line in resp.content:
                    # errors="replace": one malformed byte must not discard the whole description.
                    line = raw_line.decode("utf-8", errors="replace").strip()
                    # SSE payload lines start with "data:"; the space after it is optional.
                    if not line.startswith("data:"):
                        continue
                    data_str = line[5:].strip()
                    if not data_str or data_str == "[DONE]":
                        continue
                    try:
                        chunk = json.loads(data_str)
                        choices = chunk.get("choices") or [{}]
                        delta = choices[0].get("delta") or {}
                        text = delta.get("content")
                    except (ValueError, AttributeError, IndexError, TypeError):
                        # Skip non-JSON or unexpectedly shaped events (keep-alives, extensions).
                        continue
                    if text:
                        parts.append(text)

        description = "".join(parts)
        logger.debug(f"图片描述生成成功: {description}")
        return description.strip()
    except Exception as e:
        logger.error(f"生成图片描述失败: {e}")
        import traceback
        logger.error(f"详细错误: {traceback.format_exc()}")
        return ""
def _collect_tools(self):
    """Gather the LLM tool (function-calling) definitions exposed by loaded plugins.

    Every plugin in ``PluginManager().plugins`` that provides a callable
    ``get_llm_tools`` contributes the list it returns. A plugin whose
    ``get_llm_tools()`` raises is logged and skipped, so a single faulty plugin
    cannot abort every AI request.

    Returns:
        A flat list of tool definitions, in plugin iteration order.
    """
    from utils.plugin_manager import PluginManager

    tools = []
    for plugin in PluginManager().plugins.values():
        get_tools = getattr(plugin, "get_llm_tools", None)
        if not callable(get_tools):
            continue
        try:
            plugin_tools = get_tools()
        except Exception as e:
            logger.warning(f"收集插件工具失败 {type(plugin).__name__}: {e}")
            continue
        if plugin_tools:
            tools.extend(plugin_tools)
    return tools
async def _handle_list_prompts(self, bot, from_wxid: str):
    """Reply to "/人设列表" with the ``*.txt`` persona files in the plugin's ``prompts/`` dir.

    The persona named in ``prompt.system_prompt_file`` is marked as current.
    Errors are logged and reported back to the chat.
    """
    try:
        prompts_dir = Path(__file__).parent / "prompts"
        if not prompts_dir.exists():
            await bot.send_text(from_wxid, "❌ prompts 目录不存在")
            return

        txt_files = sorted(prompts_dir.glob("*.txt"))
        if not txt_files:
            await bot.send_text(from_wxid, "❌ 没有找到任何人设文件")
            return

        current_file = self.config["prompt"]["system_prompt_file"]
        lines = ["📋 可用人设列表:\n\n"]
        for i, file_path in enumerate(txt_files, 1):
            filename = file_path.name
            # Mark the active persona so it can be told apart from the others.
            marker = " [当前]" if filename == current_file else ""
            lines.append(f"{i}. {filename}{marker}\n")
        lines.append("\n💡 使用方法:/切人设 文件名.txt")

        await bot.send_text(from_wxid, "".join(lines))
        logger.info("已发送人设列表")
    except Exception as e:
        logger.error(f"获取人设列表失败: {e}")
        await bot.send_text(from_wxid, f"❌ 获取人设列表失败: {str(e)}")
async def _handle_switch_prompt(self, bot, from_wxid: str, content: str):
    """Handle "/切人设 <file>" / "/切换人设 <file>": switch to another persona file.

    Validation:
    - The argument must be a bare file name of a regular file directly inside
      the plugin's ``prompts/`` directory.
    - Names containing path components (e.g. "../config.toml") are rejected, so
      the command cannot load arbitrary files into the system prompt.

    On success the in-memory system prompt and ``prompt.system_prompt_file``
    are updated; config.toml on disk is not modified.
    """
    try:
        parts = content.split(maxsplit=1)
        if len(parts) < 2:
            await bot.send_text(from_wxid, "❌ 请指定人设文件名\n格式:/切人设 文件名.txt")
            return

        filename = parts[1].strip()
        prompt_path = Path(__file__).parent / "prompts" / filename
        # Reject path components (traversal) and non-files such as directories.
        if Path(filename).name != filename or not prompt_path.is_file():
            await bot.send_text(from_wxid, f"❌ 人设文件不存在: {filename}")
            return

        new_prompt = prompt_path.read_text(encoding="utf-8").strip()

        self.system_prompt = new_prompt
        self.config["prompt"]["system_prompt_file"] = filename

        await bot.send_text(from_wxid, f"✅ 已切换人设: {filename}")
        logger.success(f"管理员切换人设: {filename}")
    except Exception as e:
        logger.error(f"切换人设失败: {e}")
        await bot.send_text(from_wxid, f"❌ 切换人设失败: {str(e)}")
@on_text_message(priority=80)
async def handle_message(self, bot, message: dict):
    """Handle an incoming text message.

    Processing order:
    1. Commands: "/人设列表", "/切人设"/"/切换人设" (admins), the configured
       clear-memory command, and "/记忆状态" (admins).
    2. Group messages are appended to the group history whether or not the bot
       replies.
    3. If ``_should_reply`` agrees and the sender is within the rate limit, the
       AI is called, and its answer is sent and recorded in memory/history.

    Returns:
        False after a command or a rate-limit rejection, None otherwise.
        NOTE(review): presumably False stops propagation to other handlers —
        confirm against utils.decorators.
    """
    content = message.get("Content", "").strip()
    from_wxid = message.get("FromWxid", "")
    sender_wxid = message.get("SenderWxid", "")
    is_group = message.get("IsGroup", False)

    # The actual author: the member in a group, the peer in a private chat.
    user_wxid = sender_wxid if is_group else from_wxid

    # Bot identity and admins come from the main config. It is read once per
    # message and reused later for the bot's nickname instead of re-reading the file.
    with open("main_config.toml", "rb") as f:
        main_config = tomllib.load(f)
    bot_config = main_config.get("Bot", {})
    bot_wxid = bot_config.get("wxid", "")
    admins = bot_config.get("admins", [])

    # Command: list personas (exact match).
    if content == "/人设列表":
        await self._handle_list_prompts(bot, from_wxid)
        return False

    # Command: switch persona (prefix match, admins only).
    if content.startswith(("/切人设 ", "/切换人设 ")):
        if user_wxid in admins:
            await self._handle_switch_prompt(bot, from_wxid, content)
        else:
            await bot.send_text(from_wxid, "❌ 仅管理员可以切换人设")
        return False

    # Command: clear this conversation's memory.
    clear_command = self.config.get("memory", {}).get("clear_command", "/清空记忆")
    if content == clear_command:
        chat_id = self._get_chat_id(from_wxid, user_wxid, is_group)
        self._clear_memory(chat_id)
        await bot.send_text(from_wxid, "✅ 已清空当前会话的记忆")
        return False

    # Command: memory statistics (admins only).
    if content == "/记忆状态":
        if user_wxid in admins:
            if is_group:
                history = await self._load_history(from_wxid)
                max_context = self.config.get("history", {}).get("max_context", 50)
                context_count = min(len(history), max_context)
                msg = f"📊 群聊记忆: {len(history)}\n"
                msg += f"💬 AI可见: 最近 {context_count}"
                await bot.send_text(from_wxid, msg)
            else:
                chat_id = self._get_chat_id(from_wxid, user_wxid, is_group)
                memory = self._get_memory_messages(chat_id)
                msg = f"📊 私聊记忆: {len(memory)}"
                await bot.send_text(from_wxid, msg)
        else:
            await bot.send_text(from_wxid, "❌ 仅管理员可以查看记忆状态")
        return False

    should_reply = self._should_reply(message, content, bot_wxid)

    # Nickname used for the history entry and the AI prompt.
    nickname = await self._get_user_nickname(bot, from_wxid, user_wxid, is_group)

    # Every group message is recorded, whether or not the bot answers it.
    if is_group:
        await self._add_to_history(from_wxid, nickname, content)

    if not should_reply:
        return

    # Rate limiting only applies to messages that would get a reply.
    allowed, _remaining, reset_time = self._check_rate_limit(user_wxid)
    if not allowed:
        rate_limit_config = self.config.get("rate_limit", {})
        msg = rate_limit_config.get("rate_limit_message", "⚠️ 消息太频繁了,请 {seconds} 秒后再试~")
        msg = msg.format(seconds=reset_time)
        await bot.send_text(from_wxid, msg)
        logger.warning(f"用户 {user_wxid} 触发限流,{reset_time}秒后重置")
        return False

    # Strip the leading @mention.
    actual_content = self._extract_content(message, content)
    if not actual_content:
        return

    logger.info(f"AI 处理消息: {actual_content[:50]}...")

    try:
        chat_id = self._get_chat_id(from_wxid, user_wxid, is_group)
        self._add_to_memory(chat_id, "user", actual_content)

        response = await self._call_ai_api(actual_content, bot, from_wxid, chat_id, nickname, user_wxid, is_group)

        # An empty response means the answer was already delivered another way
        # (e.g. as a chat record), so nothing more is sent.
        if response:
            await bot.send_text(from_wxid, response)
            self._add_to_memory(chat_id, "assistant", response)
            if is_group:
                bot_nickname = bot_config.get("nickname", "机器人")
                await self._add_to_history(from_wxid, bot_nickname, response)
            logger.success(f"AI 回复成功: {response[:50]}...")
        else:
            logger.info("AI 回复为空或已通过其他方式发送(如聊天记录)")
    except Exception as e:
        import traceback
        error_detail = traceback.format_exc()
        logger.error(f"AI 处理失败: {type(e).__name__}: {str(e)}")
        logger.error(f"详细错误:\n{error_detail}")
        await bot.send_text(from_wxid, "抱歉,我遇到了一些问题,请稍后再试。")
def _should_reply(self, message: dict, content: str, bot_wxid: str = None) -> bool:
    """Decide whether the bot should answer ``message``.

    Rules, in order:
    - Messages flagged ``_auto_reply_triggered`` (set by the AutoReply plugin)
      always get a reply.
    - ``behavior.reply_group`` / ``behavior.reply_private`` (default True) can
      disable group or private replies.
    - ``behavior.trigger_mode`` (default "mention") then decides:
      - "all": reply to everything.
      - "mention": in groups, only when the bot's wxid is in ``Ats``; private
        chats always get a reply.
      - "keyword": when any entry of ``behavior.keywords`` occurs in ``content``.
      - Any other mode: never reply.

    A missing ``behavior`` section or ``keywords`` list falls back to these
    defaults instead of raising KeyError.

    Args:
        message: Raw message dict.
        content: Stripped message text.
        bot_wxid: Bot wxid; read from main_config.toml when empty and needed.
    """
    if message.get('_auto_reply_triggered'):
        return True

    behavior = self.config.get("behavior", {})
    is_group = message.get("IsGroup", False)

    if is_group and not behavior.get("reply_group", True):
        return False
    if not is_group and not behavior.get("reply_private", True):
        return False

    trigger_mode = behavior.get("trigger_mode", "mention")

    if trigger_mode == "all":
        return True

    if trigger_mode == "mention":
        if not is_group:
            return True
        ats = message.get("Ats", [])
        if not ats:
            return False
        if not bot_wxid:
            import tomllib
            with open("main_config.toml", "rb") as f:
                main_config = tomllib.load(f)
            bot_wxid = main_config.get("Bot", {}).get("wxid", "")
        return bool(bot_wxid and bot_wxid in ats)

    if trigger_mode == "keyword":
        keywords = behavior.get("keywords", [])
        return any(kw in content for kw in keywords)

    return False
def _extract_content(self, message: dict, content: str) -> str:
    """Return the message text without a leading @mention.

    For group messages shaped like "@nickname text", the first
    whitespace-separated token is dropped when it starts with "@" and text
    follows it. Everything else is returned stripped.
    """
    if message.get("IsGroup", False):
        first, *rest = content.split(maxsplit=1) or [""]
        if rest and first.startswith("@"):
            return rest[0].strip()
    return content.strip()
async def _call_ai_api(self, user_message: str, bot=None, from_wxid: str = None, chat_id: str = None, nickname: str = "", user_wxid: str = None, is_group: bool = False) -> str:
    """Send the current message to the chat-completions API and return the reply text.

    Request contents:
    - A system message: the persona, the current local time, and the user's
      nickname when given.
    - Prior context:
      - group chats (``is_group`` and ``from_wxid``): the last
        ``history.max_context`` (default 50) group-history entries, each
        prefixed with "[nickname]";
      - otherwise: the stored memory for ``chat_id`` minus its last entry.
    - The current user message.

    The response is consumed as a stream, and tool definitions from
    ``_collect_tools()`` are attached. When the model requests tool calls:
    1. They are run via ``_execute_tool_and_get_result``.
    2. Their results are sent back in a second streamed request, which yields
       the final reply.

    Args:
        user_message: Current message text (without the @mention).
        bot: Client used to run tools and send chat records. Tools are reported
            as failed when it or ``from_wxid`` is missing.
        from_wxid: Group id or private-chat user id.
        chat_id: Memory key used for non-group context.
        nickname: Sender nickname, added to the system prompt and, in groups,
            used as the message prefix.
        user_wxid: Sender wxid, stored on ``self`` for tool execution.
        is_group: Whether the message came from a group chat.

    Returns:
        The reply text, or "" in three cases:
        - a tool asked for no reply;
        - the reply was already sent as a chat record;
        - tools already delivered content and no extra text is needed.

    Raises:
        Exception: on a non-200 first response, network errors, timeouts, or
            missing response/config keys.
    """
    api_config = self.config["api"]
    # Collect the tool definitions from all plugins.
    tools = self._collect_tools()
    logger.info(f"收集到 {len(tools)} 个工具函数")
    if tools:
        tool_names = [t["function"]["name"] for t in tools]
        logger.info(f"工具列表: {tool_names}")
    # Build the message list, starting with the system prompt.
    system_content = self.system_prompt
    # Append the current local time (with a Chinese weekday name).
    current_time = datetime.now()
    weekday_map = {
        0: "星期一", 1: "星期二", 2: "星期三", 3: "星期四",
        4: "星期五", 5: "星期六", 6: "星期日"
    }
    weekday = weekday_map[current_time.weekday()]
    # The weekday is embedded into the strftime format itself.
    # NOTE(review): non-ASCII text in strftime formats has been locale-sensitive
    # on some Windows setups — confirm on the deployment target.
    time_str = current_time.strftime(f"%Y年%m月%d日 %H:%M:%S {weekday}")
    system_content += f"\n\n当前时间:{time_str}"
    if nickname:
        system_content += f"\n当前对话用户的昵称是:{nickname}"
    messages = [{"role": "system", "content": system_content}]
    # Group chats: load context from the group history (Redis or JSON file).
    if is_group and from_wxid:
        history = await self._load_history(from_wxid)
        max_context = self.config.get("history", {}).get("max_context", 50)
        # Keep only the most recent max_context entries.
        recent_history = history[-max_context:] if len(history) > max_context else history
        # Convert each entry into an AI "user" message.
        for msg in recent_history:
            msg_nickname = msg.get("nickname", "")
            msg_content = msg.get("content", "")
            # Multimodal content (text + image parts).
            if isinstance(msg_content, list):
                # Prefix only the text parts with the speaker's nickname.
                content_with_nickname = []
                for item in msg_content:
                    if item.get("type") == "text":
                        content_with_nickname.append({
                            "type": "text",
                            "text": f"[{msg_nickname}] {item.get('text', '')}"
                        })
                    else:
                        # Images and other parts are kept unchanged.
                        content_with_nickname.append(item)
                messages.append({
                    "role": "user",
                    "content": content_with_nickname
                })
            else:
                # Plain text: "[nickname] content".
                messages.append({
                    "role": "user",
                    "content": f"[{msg_nickname}] {msg_content}"
                })
    else:
        # Private chats use the per-conversation memory.
        if chat_id:
            memory_messages = self._get_memory_messages(chat_id)
            # Drop the last entry: handle_message stores the current user
            # message before calling this, and it is appended separately below.
            if memory_messages and len(memory_messages) > 1:
                messages.extend(memory_messages[:-1])
    # The current user message (nickname-prefixed in groups).
    messages.append({"role": "user", "content": f"[{nickname}] {user_message}" if is_group and nickname else user_message})
    # Saved on the instance for use by tool calls.
    # NOTE(review): this is shared instance state; concurrent requests can
    # overwrite it before a tool reads it — confirm whether that matters.
    self._current_user_wxid = user_wxid
    self._current_is_group = is_group
    payload = {
        "model": api_config["model"],
        "messages": messages
    }
    if tools:
        payload["tools"] = tools
        logger.debug(f"已将 {len(tools)} 个工具添加到请求中")
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_config['api_key']}"
    }
    timeout = aiohttp.ClientTimeout(total=api_config["timeout"])
    # Optional proxy connector; falls back to a direct connection on failure.
    connector = None
    proxy_config = self.config.get("proxy", {})
    if proxy_config.get("enabled", False):
        proxy_type = proxy_config.get("type", "socks5").upper()
        proxy_host = proxy_config.get("host", "127.0.0.1")
        proxy_port = proxy_config.get("port", 7890)
        proxy_username = proxy_config.get("username")
        proxy_password = proxy_config.get("password")
        # Build the proxy URL, with credentials when both are configured.
        if proxy_username and proxy_password:
            proxy_url = f"{proxy_type}://{proxy_username}:{proxy_password}@{proxy_host}:{proxy_port}"
        else:
            proxy_url = f"{proxy_type}://{proxy_host}:{proxy_port}"
        if PROXY_SUPPORT:
            try:
                connector = ProxyConnector.from_url(proxy_url)
                logger.debug(f"使用代理: {proxy_type}://{proxy_host}:{proxy_port}")
            except Exception as e:
                logger.warning(f"代理配置失败,将直连: {e}")
                connector = None
        else:
            logger.warning("代理功能不可用aiohttp_socks 未安装),将直连")
            connector = None
    # Request a streamed (server-sent events) response.
    payload["stream"] = True
    try:
        # One session is reused for the optional second (post-tool) request.
        async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
            logger.debug(f"发送流式 API 请求: {api_config['url']}")
            async with session.post(
                api_config["url"],
                json=payload,
                headers=headers
            ) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    logger.error(f"API 返回错误状态码: {resp.status}, 响应: {error_text}")
                    # Not one of the exception types caught below, so it propagates as-is.
                    raise Exception(f"API 错误 {resp.status}: {error_text}")
                # Read the stream.
                import json
                full_content = ""
                tool_calls_dict = {}  # Tool calls assembled from deltas: {index: tool_call}
                async for line in resp.content:
                    line = line.decode('utf-8').strip()
                    if not line or line == "data: [DONE]":
                        continue
                    if line.startswith("data: "):
                        try:
                            data = json.loads(line[6:])
                            choices = data.get("choices", [])
                            if not choices:
                                continue
                            delta = choices[0].get("delta", {})
                            # Accumulate text content.
                            content = delta.get("content", "")
                            if content:
                                full_content += content
                            # Accumulate tool calls: fragments of the same call share
                            # an "index"; name/arguments arrive as string pieces.
                            if delta.get("tool_calls"):
                                for tool_call_delta in delta["tool_calls"]:
                                    index = tool_call_delta.get("index", 0)
                                    # Create the slot on first sight of this index.
                                    if index not in tool_calls_dict:
                                        tool_calls_dict[index] = {
                                            "id": "",
                                            "type": "function",
                                            "function": {
                                                "name": "",
                                                "arguments": ""
                                            }
                                        }
                                    # id and type overwrite.
                                    if "id" in tool_call_delta:
                                        tool_calls_dict[index]["id"] = tool_call_delta["id"]
                                    if "type" in tool_call_delta:
                                        tool_calls_dict[index]["type"] = tool_call_delta["type"]
                                    # name and arguments are concatenated piece by piece.
                                    if "function" in tool_call_delta:
                                        func_delta = tool_call_delta["function"]
                                        if "name" in func_delta:
                                            tool_calls_dict[index]["function"]["name"] += func_delta["name"]
                                        if "arguments" in func_delta:
                                            tool_calls_dict[index]["function"]["arguments"] += func_delta["arguments"]
                        except Exception as e:
                            # Best effort: one unparsable chunk must not abort the stream.
                            logger.debug(f"解析流式数据失败: {e}")
                            pass
                # Order the assembled tool calls by index.
                tool_calls_data = [tool_calls_dict[i] for i in sorted(tool_calls_dict.keys())] if tool_calls_dict else []
                logger.debug(f"流式 API 响应完成")
                # The model requested tool calls.
                if tool_calls_data:
                    # Results of every tool call.
                    tool_results = []
                    has_no_reply = False
                    chat_record_info = None
                    for tool_call in tool_calls_data:
                        function_name = tool_call.get("function", {}).get("name", "")
                        arguments_str = tool_call.get("function", {}).get("arguments", "{}")
                        tool_call_id = tool_call.get("id", "")
                        if not function_name:
                            continue
                        try:
                            arguments = json.loads(arguments_str)
                        except:
                            # Malformed argument JSON: call the tool with no arguments.
                            arguments = {}
                        logger.info(f"AI调用工具: {function_name}, 参数: {arguments}")
                        # Run the tool and wait for its result.
                        if bot and from_wxid:
                            result = await self._execute_tool_and_get_result(function_name, arguments, bot, from_wxid)
                            if result and result.get("no_reply"):
                                has_no_reply = True
                                logger.info(f"工具 {function_name} 要求不回复")
                            if result and result.get("send_as_chat_record"):
                                chat_record_info = {
                                    "title": result.get("chat_record_title", "AI 回复"),
                                    "bot": bot,
                                    "from_wxid": from_wxid
                                }
                                logger.info(f"工具 {function_name} 要求以聊天记录形式发送")
                            tool_results.append({
                                "tool_call_id": tool_call_id,
                                "role": "tool",
                                "name": function_name,
                                "content": result.get("message", "") if result else "工具执行失败"
                            })
                        else:
                            logger.error(f"工具调用跳过: bot={bot}, from_wxid={from_wxid}")
                            tool_results.append({
                                "tool_call_id": tool_call_id,
                                "role": "tool",
                                "name": function_name,
                                "content": "工具执行失败:缺少必要参数"
                            })
                    if has_no_reply:
                        logger.info("工具要求不回复,跳过 AI 回复")
                        return ""
                    # Feed the assistant's tool-call turn and the tool results
                    # back so the AI can produce the final reply.
                    messages.append({
                        "role": "assistant",
                        "content": full_content if full_content else None,
                        "tool_calls": tool_calls_data
                    })
                    messages.extend(tool_results)
                    # Heuristic: every tool's result text contains "已生成" or
                    # "已发送", i.e. the tools already delivered something themselves.
                    all_tools_sent_content = all(
                        result.get("content") and ("已生成" in result.get("content", "") or "已发送" in result.get("content", ""))
                        for result in tool_results
                    )
                    # Even when tools already sent content, the AI is still asked
                    # for a short follow-up reply.
                    logger.debug(f"工具执行完成,准备获取 AI 最终回复")
                    # Second streamed request for the final reply (same payload,
                    # including tools and stream=True).
                    payload["messages"] = messages
                    async with session.post(
                        api_config["url"],
                        json=payload,
                        headers=headers
                    ) as resp2:
                        if resp2.status != 200:
                            error_text = await resp2.text()
                            logger.error(f"API 返回错误: {resp2.status}, {error_text}")
                            # The second call failed; stay silent if tools already sent content.
                            if all_tools_sent_content:
                                logger.info("工具已发送内容,跳过 AI 回复")
                                return ""
                            # Otherwise return a generic confirmation.
                            return "✅ 已完成"
                        # Read the second stream (text content only).
                        ai_reply = ""
                        async for line in resp2.content:
                            line = line.decode('utf-8').strip()
                            if not line or line == "data: [DONE]":
                                continue
                            if line.startswith("data: "):
                                try:
                                    data = json.loads(line[6:])
                                    delta = data.get("choices", [{}])[0].get("delta", {})
                                    content = delta.get("content", "")
                                    if content:
                                        ai_reply += content
                                except:
                                    # Best effort: skip unparsable chunks.
                                    pass
                        # A tool asked for the reply to be sent as a chat record.
                        if chat_record_info and ai_reply:
                            await self._send_chat_records(
                                chat_record_info["bot"],
                                chat_record_info["from_wxid"],
                                chat_record_info["title"],
                                ai_reply
                            )
                            return ""
                        # No AI text, but the tools already sent content: send nothing more.
                        if not ai_reply.strip() and all_tools_sent_content:
                            logger.info("AI 无回复且工具已发送内容,不发送额外消息")
                            return ""
                        # The AI reply, or a short confirmation when it is empty.
                        return ai_reply.strip() if ai_reply.strip() else "✅ 完成"
                # Guard against the model writing a tool call as plain text.
                # Precedence: "<tool_code>" OR ("print(" AND "flow2_ai_image_generation").
                if "<tool_code>" in full_content or "print(" in full_content and "flow2_ai_image_generation" in full_content:
                    logger.warning("检测到模型输出了错误的工具调用格式,拦截并返回提示")
                    return "抱歉,我遇到了一些技术问题,请重新描述一下你的需求~"
                return full_content.strip()
    except aiohttp.ClientError as e:
        logger.error(f"网络请求失败: {type(e).__name__}: {str(e)}")
        raise Exception(f"网络请求失败: {str(e)}")
    except asyncio.TimeoutError:
        logger.error(f"API 请求超时 (timeout={api_config['timeout']}s)")
        raise Exception(f"API 请求超时")
    except KeyError as e:
        logger.error(f"API 响应格式错误,缺少字段: {e}")
        raise Exception(f"API 响应格式错误: {e}")
def _get_history_file(self, chat_id: str) -> Path:
"""获取群聊历史记录文件路径"""
if not self.history_dir:
return None
safe_name = chat_id.replace("@", "_").replace(":", "_")
return self.history_dir / f"{safe_name}.json"
def _get_history_lock(self, chat_id: str) -> asyncio.Lock:
    """Return the asyncio.Lock dedicated to ``chat_id``, creating it on first use."""
    existing = self.history_locks.get(chat_id)
    if existing is not None:
        return existing
    fresh = asyncio.Lock()
    self.history_locks[chat_id] = fresh
    return fresh
def _read_history_file(self, history_file: Path) -> list:
try:
import json
with open(history_file, "r", encoding="utf-8") as f:
return json.load(f)
except FileNotFoundError:
return []
except Exception as e:
logger.error(f"读取历史记录失败: {e}")
return []
def _write_history_file(self, history_file: Path, history: list):
    """Persist ``history`` as pretty-printed UTF-8 JSON via a temp-file swap.

    Steps:
    1. Create parent directories as needed.
    2. Write the data to a sibling "<name>.tmp" file.
    3. Replace the target with it (``Path.replace`` / ``os.replace``), so
       readers do not see a half-written history file.
    """
    import json

    history_file.parent.mkdir(parents=True, exist_ok=True)
    staging = history_file.with_name(history_file.name + ".tmp")
    with staging.open("w", encoding="utf-8") as handle:
        json.dump(history, handle, ensure_ascii=False, indent=2)
    staging.replace(history_file)
def _use_redis_for_group_history(self) -> bool:
    """Return True when group history should be stored in Redis.

    This requires ``redis.use_redis_history`` (default True) and an enabled
    Redis cache. The result is coerced to ``bool`` to match the annotation;
    previously it could be None or the cache object itself.
    """
    if not self.config.get("redis", {}).get("use_redis_history", True):
        return False
    redis_cache = get_cache()
    return bool(redis_cache and redis_cache.enabled)
async def _load_history(self, chat_id: str) -> list:
    """Read a group's history.

    - Redis in use: up to ``history.max_history`` (default 100) entries come
      from Redis.
    - Otherwise: the JSON history file is read under the chat's lock; [] when
      file history is disabled.
    """
    if not self._use_redis_for_group_history():
        history_file = self._get_history_file(chat_id)
        if not history_file:
            return []
        async with self._get_history_lock(chat_id):
            return self._read_history_file(history_file)

    cache = get_cache()
    limit = self.config.get("history", {}).get("max_history", 100)
    return cache.get_group_history(chat_id, limit)
async def _save_history(self, chat_id: str, history: list):
    """Write a group's full history to its JSON file, keeping the newest ``history.max_history`` entries.

    This is a no-op in Redis mode, where appends already persist and trim the
    history, and when file history is disabled.
    """
    if self._use_redis_for_group_history():
        return
    history_file = self._get_history_file(chat_id)
    if not history_file:
        return

    keep = self.config.get("history", {}).get("max_history", 100)
    trimmed = history[-keep:] if len(history) > keep else history
    async with self._get_history_lock(chat_id):
        self._write_history_file(history_file, trimmed)
async def _add_to_history(self, chat_id: str, nickname: str, content: str, image_base64: str = None):
    """Append one message to a group's history (no-op when history is disabled).

    Storage:
    - Redis mode: the entry is added with TTL ``redis.group_history_ttl``
      (default 172800 s), then the history is trimmed to ``history.max_history``
      (default 100).
    - Otherwise: a record {nickname, ISO timestamp, content} is appended to the
      JSON file under the chat's lock, keeping the newest ``max_history`` records.

    Args:
        chat_id: Group id.
        nickname: Speaker nickname.
        content: Message text.
        image_base64: Optional image; turns content into a [text, image_url] list.
    """
    history_settings = self.config.get("history", {})
    if not history_settings.get("enabled", True):
        return

    if image_base64:
        body = [
            {"type": "text", "text": content},
            {"type": "image_url", "image_url": {"url": image_base64}},
        ]
    else:
        body = content

    if self._use_redis_for_group_history():
        cache = get_cache()
        ttl = self.config.get("redis", {}).get("group_history_ttl", 172800)
        cache.add_group_message(chat_id, nickname, body, ttl=ttl)
        cache.trim_group_history(chat_id, history_settings.get("max_history", 100))
        return

    history_file = self._get_history_file(chat_id)
    if not history_file:
        return

    async with self._get_history_lock(chat_id):
        records = self._read_history_file(history_file)
        records.append({
            "nickname": nickname,
            "timestamp": datetime.now().isoformat(),
            "content": body,
        })
        cap = history_settings.get("max_history", 100)
        if len(records) > cap:
            records = records[-cap:]
        self._write_history_file(history_file, records)
async def _add_to_history_with_id(self, chat_id: str, nickname: str, content: str, record_id: str):
    """Append a history record that carries an explicit id, so it can be updated later.

    Does nothing when ``history.enabled`` is false. In Redis mode the message
    is stored with ``record_id`` and ``redis.group_history_ttl`` (default
    172800), then the list is trimmed to ``history.max_history``. Otherwise an
    ``{"id", "nickname", "timestamp", "content"}`` record is appended to the
    per-chat file under the chat's lock, keeping only the newest entries.

    Args:
        chat_id: Group chat id.
        nickname: Sender display name.
        content: Record content, e.g. a placeholder to be replaced later.
        record_id: Identifier later matched by ``_update_history_by_id``.
    """
    if not self.config.get("history", {}).get("enabled", True):
        return

    if self._use_redis_for_group_history():
        cache = get_cache()
        ttl = self.config.get("redis", {}).get("group_history_ttl", 172800)
        cache.add_group_message(chat_id, nickname, content, record_id=record_id, ttl=ttl)
        # Keep the Redis list bounded to the configured history length.
        cache.trim_group_history(chat_id, self.config.get("history", {}).get("max_history", 100))
        return

    path = self._get_history_file(chat_id)
    if not path:
        return
    async with self._get_history_lock(chat_id):
        records = self._read_history_file(path)
        records.append({
            "id": record_id,
            "nickname": nickname,
            "timestamp": datetime.now().isoformat(),
            "content": content,
        })
        limit = self.config.get("history", {}).get("max_history", 100)
        self._write_history_file(path, records[-limit:] if len(records) > limit else records)
async def _update_history_by_id(self, chat_id: str, record_id: str, new_content: str):
    """Replace the content of the first history record whose ``id`` equals ``record_id``.

    Does nothing when ``history.enabled`` is false. In Redis mode the update
    is delegated to ``update_group_message_by_id``. Otherwise the per-chat
    file is rewritten under the chat's lock; it is also re-truncated to
    ``history.max_history`` entries at that point.
    """
    if not self.config.get("history", {}).get("enabled", True):
        return

    if self._use_redis_for_group_history():
        get_cache().update_group_message_by_id(chat_id, record_id, new_content)
        return

    path = self._get_history_file(chat_id)
    if not path:
        return
    async with self._get_history_lock(chat_id):
        records = self._read_history_file(path)
        target = next((rec for rec in records if rec.get("id") == record_id), None)
        if target is not None:
            target["content"] = new_content
        limit = self.config.get("history", {}).get("max_history", 100)
        self._write_history_file(path, records[-limit:] if len(records) > limit else records)
async def _execute_tool_and_get_result(self, tool_name: str, arguments: dict, bot, from_wxid: str):
    """Dispatch an LLM tool call to the first plugin that reports success.

    ``arguments`` is mutated in place. ``user_wxid`` is set from
    ``self._current_user_wxid`` (falling back to ``from_wxid``) and
    ``is_group`` from ``self._current_is_group`` (falling back to False).
    Every loaded plugin that has an ``execute_llm_tool`` attribute is asked
    in iteration order.

    Returns:
        The first result dict whose ``success`` is truthy. If none succeeds,
        ``{"success": False, "message": "未找到工具: <tool_name>"}``. Plugin
        exceptions are logged with a traceback and the next plugin is tried.
    """
    from utils.plugin_manager import PluginManager
    import traceback

    arguments["user_wxid"] = getattr(self, "_current_user_wxid", from_wxid)
    arguments["is_group"] = getattr(self, "_current_is_group", False)
    logger.info(f"开始执行工具: {tool_name}")

    plugins = PluginManager().plugins
    logger.info(f"检查 {len(plugins)} 个插件")

    for plugin_name, candidate in plugins.items():
        can_execute = hasattr(candidate, 'execute_llm_tool')
        logger.debug(f"检查插件: {plugin_name}, 有execute_llm_tool: {can_execute}")
        if not can_execute:
            continue
        try:
            logger.info(f"调用 {plugin_name}.execute_llm_tool")
            outcome = await candidate.execute_llm_tool(tool_name, arguments, bot, from_wxid)
            logger.info(f"{plugin_name} 返回: {outcome}")
            if outcome is None:
                continue
            if outcome.get("success"):
                logger.success(f"工具执行成功: {tool_name}")
                return outcome
            logger.debug(f"{plugin_name} 不处理此工具,继续检查下一个插件")
        except Exception as exc:
            logger.error(f"工具执行异常 ({plugin_name}): {tool_name}, {exc}")
            logger.error(f"详细错误: {traceback.format_exc()}")

    logger.warning(f"未找到工具: {tool_name}")
    return {"success": False, "message": f"未找到工具: {tool_name}"}
@on_quote_message(priority=79)
async def handle_quote_message(self, bot, message: dict):
    """Handle a quoted (reply-to) message whose quoted item is an image.

    Flow:
      1. Parse the quote XML.
      2. Check the reply conditions against the quote title.
      3. Extract the quoted image's CDN URL and AES key.
      4. Rate-limit the sender.
      5. Download the image.
      6. Record the turn in memory (and group history).
      7. Call the image-capable API and send the reply.

    Returns:
        ``True`` when this handler does not handle the message (no title,
        reply conditions not met, quoted item is not an image, incomplete
        image info, or an unexpected error), so other handlers may run.
        ``False`` once it has handled it (replied, rate-limited, or the image
        download failed).
    """
    import html

    content = message.get("Content", "").strip()
    from_wxid = message.get("FromWxid", "")
    sender_wxid = message.get("SenderWxid", "")
    is_group = message.get("IsGroup", False)
    # In groups the speaker is SenderWxid; in private chats it is FromWxid.
    user_wxid = sender_wxid if is_group else from_wxid

    try:
        # Outer XML: <title> is the reply text, <refermsg> is the quoted message.
        root = ET.fromstring(content)
        title = root.find(".//title")
        if title is None or not title.text:
            logger.debug("引用消息没有标题,跳过")
            return True

        title_text = title.text.strip()
        logger.info(f"收到引用消息,标题: {title_text[:50]}...")

        if not self._should_reply_quote(message, title_text):
            logger.debug("引用消息不满足回复条件")
            return True

        refermsg = root.find(".//refermsg")
        if refermsg is None:
            logger.debug("引用消息中没有 refermsg 节点")
            return True

        refer_content = refermsg.find("content")
        if refer_content is None or not refer_content.text:
            logger.debug("引用消息中没有 content")
            return True

        # The quoted content is HTML-escaped XML for media messages, but plain
        # text when a text message is quoted. Content that cannot be parsed
        # therefore means "not an image", not a processing failure.
        refer_xml = html.unescape(refer_content.text)
        try:
            refer_root = ET.fromstring(refer_xml)
        except ET.ParseError:
            logger.debug("引用的消息不是图片")
            return True

        img = refer_root.find(".//img")
        if img is None:
            logger.debug("引用的消息不是图片")
            return True

        cdnbigimgurl = img.get("cdnbigimgurl", "")
        aeskey = img.get("aeskey", "")
        if not cdnbigimgurl or not aeskey:
            logger.warning(f"图片信息不完整: cdnurl={bool(cdnbigimgurl)}, aeskey={bool(aeskey)}")
            return True

        logger.info(f"AI处理引用图片消息: {title_text[:50]}...")

        # Per-sender rate limit; on rejection notify the user and consume the message.
        allowed, remaining, reset_time = self._check_rate_limit(user_wxid)
        if not allowed:
            rate_limit_config = self.config.get("rate_limit", {})
            msg = rate_limit_config.get("rate_limit_message", "⚠️ 消息太频繁了,请 {seconds} 秒后再试~")
            msg = msg.format(seconds=reset_time)
            await bot.send_text(from_wxid, msg)
            logger.warning(f"用户 {user_wxid} 触发限流,{reset_time}秒后重置")
            return False

        nickname = await self._get_user_nickname(bot, from_wxid, user_wxid, is_group)

        logger.info(f"开始下载图片: {cdnbigimgurl[:50]}...")
        image_base64 = await self._download_and_encode_image(bot, cdnbigimgurl, aeskey)
        if not image_base64:
            logger.error("图片下载失败")
            await bot.send_text(from_wxid, "❌ 无法处理图片")
            return False
        logger.info("图片下载和编码成功")

        # Store the user turn (with the image) in memory and, for groups, in history.
        chat_id = self._get_chat_id(from_wxid, user_wxid, is_group)
        self._add_to_memory(chat_id, "user", title_text, image_base64=image_base64)
        if is_group:
            await self._add_to_history(from_wxid, nickname, title_text, image_base64=image_base64)

        response = await self._call_ai_api_with_image(title_text, image_base64, bot, from_wxid, chat_id, nickname, user_wxid, is_group)

        if response:
            await bot.send_text(from_wxid, response)
            self._add_to_memory(chat_id, "assistant", response)
            if is_group:
                # The bot's display name comes from main_config.toml [Bot].nickname.
                with open("main_config.toml", "rb") as f:
                    main_config = tomllib.load(f)
                bot_nickname = main_config.get("Bot", {}).get("nickname", "机器人")
                await self._add_to_history(from_wxid, bot_nickname, response)
            logger.success(f"AI回复成功: {response[:50]}...")

        return False

    except Exception as e:
        logger.error(f"处理引用消息失败: {e}")
        return True
def _should_reply_quote(self, message: dict, title_text: str) -> bool:
"""判断是否应该回复引用消息"""
is_group = message.get("IsGroup", False)
# 检查群聊/私聊开关
if is_group and not self.config["behavior"]["reply_group"]:
return False
if not is_group and not self.config["behavior"]["reply_private"]:
return False
trigger_mode = self.config["behavior"]["trigger_mode"]
# all模式回复所有消息
if trigger_mode == "all":
return True
# mention模式检查是否@了机器人
if trigger_mode == "mention":
if is_group:
ats = message.get("Ats", [])
if not ats:
return False
import tomllib
with open("main_config.toml", "rb") as f:
main_config = tomllib.load(f)
bot_wxid = main_config.get("Bot", {}).get("wxid", "")
return bot_wxid and bot_wxid in ats
else:
return True
# keyword模式检查关键词
if trigger_mode == "keyword":
keywords = self.config["behavior"]["keywords"]
return any(kw in title_text for kw in keywords)
return False
async def _call_ai_api_with_image(self, user_message: str, image_base64: str, bot=None, from_wxid: str = None, chat_id: str = None, nickname: str = "", user_wxid: str = None, is_group: bool = False) -> str:
    """Call the chat-completions API with a text + image user turn, streaming the reply.

    The system prompt is the persona plus the current local time and weekday,
    and the user's nickname when given. Chat memory is prepended except its
    last entry, because the caller has already stored the current user turn
    there and it is re-sent here together with the image. Streamed text
    deltas and incremental ``tool_calls`` are reassembled. Tool calls are
    executed through ``_execute_tool_and_get_result`` when ``bot`` and
    ``from_wxid`` are given.

    Returns:
        - The stripped reply text in the normal case.
        - ``""`` when every tool result reports that content was already
          generated or sent ("已生成" / "已发送").
        - ``None`` when tools were called and the model produced no text.
        - A canned apology when the model wrote a textual pseudo tool call
          instead of a real one.

    Raises:
        Exception: on a non-200 status or any transport error (logged, then
        re-raised).
    """
    import json
    from urllib.parse import quote

    api_config = self.config["api"]
    tools = self._collect_tools()

    # System prompt = persona + current local time/weekday (+ user nickname).
    system_content = self.system_prompt
    current_time = datetime.now()
    weekday_map = {
        0: "星期一", 1: "星期二", 2: "星期三", 3: "星期四",
        4: "星期五", 5: "星期六", 6: "星期日"
    }
    weekday = weekday_map[current_time.weekday()]
    time_str = current_time.strftime(f"%Y年%m月%d日 %H:%M:%S {weekday}")
    system_content += f"\n\n当前时间:{time_str}"
    if nickname:
        system_content += f"\n当前对话用户的昵称是:{nickname}"

    messages = [{"role": "system", "content": system_content}]

    # Prior memory, minus its last entry (the current turn, re-sent below with the image).
    if chat_id:
        memory_messages = self._get_memory_messages(chat_id)
        if memory_messages and len(memory_messages) > 1:
            messages.extend(memory_messages[:-1])

    messages.append({
        "role": "user",
        "content": [
            {"type": "text", "text": user_message},
            {"type": "image_url", "image_url": {"url": image_base64}}
        ]
    })

    # Context for _execute_tool_and_get_result, which reads these attributes.
    self._current_user_wxid = user_wxid
    self._current_is_group = is_group

    payload = {
        "model": api_config["model"],
        "messages": messages,
        "stream": True
    }
    if tools:
        payload["tools"] = tools

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_config['api_key']}"
    }
    timeout = aiohttp.ClientTimeout(total=api_config["timeout"])

    # Optional proxy; falls back to a direct connection on any problem.
    connector = None
    proxy_config = self.config.get("proxy", {})
    if proxy_config.get("enabled", False):
        proxy_type = proxy_config.get("type", "socks5").upper()
        proxy_host = proxy_config.get("host", "127.0.0.1")
        proxy_port = proxy_config.get("port", 7890)
        proxy_username = proxy_config.get("username")
        proxy_password = proxy_config.get("password")
        if proxy_username and proxy_password:
            # Percent-encode credentials so characters like '@' or ':' keep the URL valid.
            credentials = f"{quote(str(proxy_username), safe='')}:{quote(str(proxy_password), safe='')}"
            proxy_url = f"{proxy_type}://{credentials}@{proxy_host}:{proxy_port}"
        else:
            proxy_url = f"{proxy_type}://{proxy_host}:{proxy_port}"
        if PROXY_SUPPORT:
            try:
                connector = ProxyConnector.from_url(proxy_url)
                logger.debug(f"使用代理: {proxy_type}://{proxy_host}:{proxy_port}")
            except Exception as e:
                logger.warning(f"代理配置失败,将直连: {e}")
                connector = None
        else:
            logger.warning("代理功能不可用aiohttp_socks 未安装),将直连")
            connector = None

    try:
        async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
            async with session.post(api_config["url"], json=payload, headers=headers) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    logger.error(f"API返回错误状态码: {resp.status}, 响应: {error_text}")
                    raise Exception(f"API错误 {resp.status}: {error_text}")

                full_content = ""
                tool_calls_dict = {}  # stream index -> tool call assembled from deltas

                async for raw_line in resp.content:
                    line = raw_line.decode('utf-8').strip()
                    if not line or line == "data: [DONE]":
                        continue
                    if not line.startswith("data: "):
                        continue
                    try:
                        data = json.loads(line[6:])
                        choices = data.get("choices", [])
                        if not choices:
                            continue
                        delta = choices[0].get("delta") or {}
                        content = delta.get("content")
                        if content:
                            full_content += content

                        for tool_call_delta in delta.get("tool_calls") or []:
                            index = tool_call_delta.get("index")
                            if index is None:
                                index = 0
                            if index not in tool_calls_dict:
                                tool_calls_dict[index] = {
                                    "id": "",
                                    "type": "function",
                                    "function": {"name": "", "arguments": ""}
                                }
                            entry = tool_calls_dict[index]
                            # Providers may send explicit nulls/empties in later chunks;
                            # merge only real values so earlier data is not clobbered.
                            if tool_call_delta.get("id"):
                                entry["id"] = tool_call_delta["id"]
                            if tool_call_delta.get("type"):
                                entry["type"] = tool_call_delta["type"]
                            func_delta = tool_call_delta.get("function") or {}
                            entry["function"]["name"] += func_delta.get("name") or ""
                            entry["function"]["arguments"] += func_delta.get("arguments") or ""
                    except Exception as e:
                        # Best-effort: a malformed chunk is skipped, not fatal.
                        logger.debug(f"解析流式数据失败: {e}")

                tool_calls_data = [tool_calls_dict[i] for i in sorted(tool_calls_dict)]

                if tool_calls_data:
                    tool_results = []
                    for tool_call in tool_calls_data:
                        function_name = tool_call["function"]["name"]
                        arguments_str = tool_call["function"]["arguments"] or "{}"
                        tool_call_id = tool_call["id"]
                        if not function_name:
                            continue

                        try:
                            arguments = json.loads(arguments_str)
                        except (json.JSONDecodeError, TypeError):
                            arguments = {}
                        # Tools are called with keyword-style dicts; anything else is unusable.
                        if not isinstance(arguments, dict):
                            arguments = {}

                        # The image-to-image tool needs the quoted image itself.
                        if function_name == "flow2_ai_image_generation" and image_base64:
                            arguments["image_base64"] = image_base64
                            logger.info("AI调用图生图工具已添加图片数据")

                        # Do not dump the (potentially multi-MB) base64 image into the log.
                        loggable_args = {
                            key: ("<image_base64 omitted>" if key == "image_base64" else value)
                            for key, value in arguments.items()
                        }
                        logger.info(f"AI调用工具: {function_name}, 参数: {loggable_args}")

                        if bot and from_wxid:
                            result = await self._execute_tool_and_get_result(function_name, arguments, bot, from_wxid)
                            tool_results.append({
                                "tool_call_id": tool_call_id,
                                "role": "tool",
                                "name": function_name,
                                "content": str(result.get("message") or "") if result else "工具执行失败"
                            })

                    # Tools that already delivered output need no additional text reply.
                    if tool_results and all("已生成" in r["content"] or "已发送" in r["content"] for r in tool_results):
                        return ""
                    return full_content.strip() or None

                # The model sometimes writes a fake tool call as text; do not forward it.
                if "<tool_code>" in full_content or ("print(" in full_content and "flow2_ai_image_generation" in full_content):
                    logger.warning("检测到模型输出了错误的工具调用格式,拦截并返回提示")
                    return "抱歉,我遇到了一些技术问题,请重新描述一下你的需求~"

                return full_content.strip()

    except Exception as e:
        logger.error(f"调用AI API失败: {e}")
        raise
async def _send_chat_records(self, bot, from_wxid: str, title: str, content: str):
    """Send long text as a WeChat "chat records" card (appmsg type 19).

    ``content`` is split into parts of at most 800 characters, breaking on
    newlines where possible and hard-wrapping longer lines. Each part becomes
    one ``dataitem`` attributed to "AI助手". ``title`` is the card
    title/description and is XML-escaped before being embedded in the raw
    appmsg string. Errors are logged and swallowed.

    Args:
        bot: Client used to send the raw message (``_send_data_async``, type 11214).
        from_wxid: Target conversation; an ``@chatroom`` suffix marks a group.
        title: Card title.
        content: Text body.
    """
    try:
        import hashlib
        import time
        from xml.sax.saxutils import escape as xml_escape

        is_group = from_wxid.endswith("@chatroom")

        # Split content into parts no longer than max_length characters.
        max_length = 800
        if len(content) <= max_length:
            content_parts = [content]
        else:
            content_parts = []
            current_part = ""
            for line in content.split('\n'):
                # Hard-wrap lines that are themselves longer than the limit,
                # so no emitted part exceeds max_length.
                pieces = [line[i:i + max_length] for i in range(0, len(line), max_length)] or [""]
                for piece in pieces:
                    if current_part and len(current_part) + len(piece) + 1 > max_length:
                        if current_part.strip():
                            content_parts.append(current_part.strip())
                        current_part = ""
                    current_part += piece + '\n'
            if current_part.strip():
                content_parts.append(current_part.strip())
            if not content_parts:
                content_parts = [content.strip()]

        recordinfo = ET.Element("recordinfo")
        info_el = ET.SubElement(recordinfo, "info")
        info_el.text = title
        is_group_el = ET.SubElement(recordinfo, "isChatRoom")
        is_group_el.text = "1" if is_group else "0"
        datalist = ET.SubElement(recordinfo, "datalist")
        datalist.set("count", str(len(content_parts)))
        desc_el = ET.SubElement(recordinfo, "desc")
        desc_el.text = title
        fromscene_el = ET.SubElement(recordinfo, "fromscene")
        fromscene_el.text = "3"

        for i, part in enumerate(content_parts):
            di = ET.SubElement(datalist, "dataitem")
            di.set("datatype", "1")
            di.set("dataid", uuid.uuid4().hex)
            src_local_id = str((int(time.time() * 1000) % 90000) + 10000)
            new_msg_id = str(int(time.time() * 1000) + i)
            # Stagger creation times by one second so items keep their order.
            create_time = str(int(time.time()) - len(content_parts) + i)
            ET.SubElement(di, "srcMsgLocalid").text = src_local_id
            ET.SubElement(di, "sourcetime").text = time.strftime("%Y-%m-%d %H:%M", time.localtime(int(create_time)))
            ET.SubElement(di, "fromnewmsgid").text = new_msg_id
            ET.SubElement(di, "srcMsgCreateTime").text = create_time
            ET.SubElement(di, "sourcename").text = "AI助手"
            ET.SubElement(di, "sourceheadurl").text = ""
            ET.SubElement(di, "datatitle").text = part
            ET.SubElement(di, "datadesc").text = part
            ET.SubElement(di, "datafmt").text = "text"
            ET.SubElement(di, "ischatroom").text = "1" if is_group else "0"
            dataitemsource = ET.SubElement(di, "dataitemsource")
            ET.SubElement(dataitemsource, "hashusername").text = hashlib.sha256(from_wxid.encode("utf-8")).hexdigest()

        # ElementTree escapes '&', '<' and '>' in text, so "]]>" cannot end the CDATA early.
        record_xml = ET.tostring(recordinfo, encoding="unicode")

        # The title is interpolated into hand-built XML and must be escaped.
        safe_title = xml_escape(title)
        appmsg_parts = [
            "<appmsg appid=\"\" sdkver=\"0\">",
            f"<title>{safe_title}</title>",
            f"<des>{safe_title}</des>",
            "<type>19</type>",
            "<url>https://support.weixin.qq.com/cgi-bin/mmsupport-bin/readtemplate?t=page/favorite_record__w_unsupport</url>",
            "<appattach><cdnthumbaeskey></cdnthumbaeskey><aeskey></aeskey></appattach>",
            f"<recorditem><![CDATA[{record_xml}]]></recorditem>",
            "<percent>0</percent>",
            "</appmsg>"
        ]
        appmsg_xml = "".join(appmsg_parts)

        await bot._send_data_async(11214, {"to_wxid": from_wxid, "content": appmsg_xml})
        logger.success(f"已发送聊天记录: {title}")

    except Exception as e:
        logger.error(f"发送聊天记录失败: {e}")
async def _process_image_to_history(self, bot, message: dict, content: str) -> bool:
    """Record a group image or emoji in history and queue an async description.

    Only group messages are handled, and only when
    ``image_description.enabled`` (default True) is set. A
    "[图片: 处理中...]" placeholder record with a fresh UUID id is inserted
    immediately. A task dict is then put on ``self.image_desc_queue`` for the
    worker, which later replaces the placeholder.

    Returns:
        Always ``True``; errors are logged, never raised.
    """
    from_wxid = message.get("FromWxid", "")
    is_group = message.get("IsGroup", False)
    user_wxid = message.get("SenderWxid", "") if is_group else from_wxid

    if not is_group:
        return True
    desc_cfg = self.config.get("image_description", {})
    if not desc_cfg.get("enabled", True):
        return True

    try:
        root = ET.fromstring(content)
        # Image messages carry <img>; emoji messages carry <emoji>.
        media = root.find(".//img")
        if media is None:
            media = root.find(".//emoji")
        if media is None:
            return True

        is_emoji = media.tag == "emoji"
        url = media.get("cdnbigimgurl", "") or media.get("cdnurl", "")
        aeskey = media.get("aeskey", "")
        # A URL is always required; the AES key only for regular images.
        if not url or (not is_emoji and not aeskey):
            return True

        nickname = await self._get_user_nickname(bot, from_wxid, user_wxid, is_group)

        placeholder_id = str(uuid.uuid4())
        await self._add_to_history_with_id(from_wxid, nickname, "[图片: 处理中...]", placeholder_id)
        logger.info(f"已插入图片占位符: {placeholder_id}")

        # Queue the slow download + description work; do not block this handler.
        await self.image_desc_queue.put({
            "bot": bot,
            "from_wxid": from_wxid,
            "nickname": nickname,
            "cdnbigimgurl": url,
            "aeskey": aeskey,
            "is_emoji": is_emoji,
            "placeholder_id": placeholder_id,
            "config": desc_cfg,
        })
        logger.info(f"图片描述任务已加入队列,当前队列长度: {self.image_desc_queue.qsize()}")
    except Exception as e:
        logger.error(f"处理图片消息失败: {e}")
    return True
async def _image_desc_worker(self):
    """Consume image-description tasks from ``self.image_desc_queue`` forever.

    Each task dict is unpacked into ``_generate_and_update_image_description``.
    Processing errors are logged and the loop continues.
    ``task_done()`` is called for every dequeued task, including failed ones,
    so ``queue.join()`` cannot hang on a failing task.
    ``asyncio.CancelledError`` is a ``BaseException``, so ``except Exception``
    does not catch it and cancellation stops the worker.
    """
    while True:
        task = await self.image_desc_queue.get()
        try:
            await self._generate_and_update_image_description(
                task["bot"], task["from_wxid"], task["nickname"],
                task["cdnbigimgurl"], task["aeskey"], task["is_emoji"],
                task["placeholder_id"], task["config"]
            )
        except Exception as e:
            logger.error(f"图片描述工作协程异常: {e}")
        finally:
            self.image_desc_queue.task_done()
async def _generate_and_update_image_description(self, bot, from_wxid: str, nickname: str,
                                                 cdnbigimgurl: str, aeskey: str, is_emoji: bool,
                                                 placeholder_id: str, image_desc_config: dict):
    """Download an image or emoji, describe it, and replace its history placeholder.

    Emoji are fetched through ``_download_emoji_and_encode`` (URL only);
    images through ``_download_and_encode_image`` (URL + AES key). On success
    the placeholder becomes "[图片: <description>]". If the download or the
    description fails, or any error occurs, it becomes "[图片]".

    Args:
        image_desc_config: ``image_description`` config section; its
            ``prompt`` key (optional) overrides the default description prompt.
    """
    fallback = "[图片]"
    try:
        kind = '表情包' if is_emoji else '图片'
        if is_emoji:
            encoded = await self._download_emoji_and_encode(cdnbigimgurl)
        else:
            encoded = await self._download_and_encode_image(bot, cdnbigimgurl, aeskey)

        if not encoded:
            logger.warning(f"{kind}下载失败")
            await self._update_history_by_id(from_wxid, placeholder_id, fallback)
            return

        prompt = image_desc_config.get("prompt", "请用一句话简洁地描述这张图片的主要内容。")
        description = await self._generate_image_description(encoded, prompt, image_desc_config)

        if not description:
            await self._update_history_by_id(from_wxid, placeholder_id, fallback)
            logger.warning("图片描述生成失败")
            return

        await self._update_history_by_id(from_wxid, placeholder_id, f"[图片: {description}]")
        logger.success(f"已更新图片描述: {nickname} - {description[:30]}...")
    except Exception as e:
        logger.error(f"异步生成图片描述失败: {e}")
        await self._update_history_by_id(from_wxid, placeholder_id, fallback)
@on_image_message(priority=15)
async def handle_image_message(self, bot, message: dict):
    """Handle a directly sent image: record it in history without an AI reply.

    Delegates to ``_process_image_to_history`` with the message's ``Content``
    (default "") and returns whatever that method returns.
    """
    logger.info("AIChat: handle_image_message 被调用")
    return await self._process_image_to_history(bot, message, message.get("Content", ""))
@on_emoji_message(priority=15)
async def handle_emoji_message(self, bot, message: dict):
    """Handle an emoji/sticker message: record it in history without an AI reply.

    Delegates to ``_process_image_to_history`` with the message's ``Content``
    (default "") and returns whatever that method returns.
    """
    logger.info("AIChat: handle_emoji_message 被调用")
    return await self._process_image_to_history(bot, message, message.get("Content", ""))