Files
WechatHookBot/plugins/JimengAI/main.py
2025-12-03 15:48:44 +08:00

373 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
即梦AI绘图插件
支持命令触发和LLM工具调用
"""
import asyncio
import tomllib
import aiohttp
import uuid
from pathlib import Path
from datetime import datetime
from typing import List, Optional
from loguru import logger
from utils.plugin_base import PluginBase
from utils.decorators import on_text_message
from WechatHook import WechatHookClient
class TokenState:
    """Round-robin cursor over the configured API tokens.

    ``token_index`` is advanced (and wrapped) only by ``rotate``;
    ``get_next_token`` additionally applies a modulo, so a token list that is
    shorter than before is still indexed safely. Both operations run under a
    single ``asyncio.Lock`` so coroutines sharing this instance read and
    update the index one at a time.
    """

    def __init__(self):
        self.token_index = 0  # current cursor position into the token list
        self._lock = asyncio.Lock()  # serialises access to token_index

    async def get_next_token(self, tokens: List[str]) -> str:
        """Return the token under the cursor without advancing it.

        Raises:
            ValueError: ``tokens`` is empty.
        """
        async with self._lock:
            if not tokens:
                raise ValueError("Token列表为空")
            slot = self.token_index % len(tokens)
            return tokens[slot]

    async def rotate(self, tokens: List[str]):
        """Move the cursor to the next token, wrapping at ``len(tokens)``.

        Does nothing when ``tokens`` is empty.
        """
        async with self._lock:
            if not tokens:
                return
            self.token_index = (self.token_index + 1) % len(tokens)
class JimengAI(PluginBase):
    """Jimeng AI (即梦AI) image-generation plugin.

    Sends prompts to an OpenAI-compatible ``/v1/chat/completions`` endpoint
    configured in ``config.toml`` (next to this file), downloads the returned
    images into ``./images`` and sends the first one back to the chat. It is
    reachable two ways: a text command handled by ``handle_message`` and an
    LLM tool exposed through ``get_llm_tools`` / ``execute_llm_tool`` for the
    AIChat plugin.
    """

    description = "即梦AI绘图插件 - 支持AI绘图和LLM工具调用"
    author = "ShiHao"
    version = "1.0.0"

    # Accepted width/height range in pixels; out-of-range values are clamped.
    _MIN_SIZE = 64
    _MAX_SIZE = 2048
    # http(s) URL inside free text. Stops at whitespace, ")" / "]", quotes,
    # angle brackets and the full-width period/comma so Markdown links
    # ("![img](url)", "[url](url)") and CJK punctuation stay out of the URL.
    _URL_PATTERN = r"https?://[^\s)\]\"'<>\u3002\uff0c]+"

    def __init__(self):
        super().__init__()
        self.config = None  # parsed config.toml; populated by async_init()
        self.token_state = TokenState()  # round-robin cursor over api.tokens
        self.images_dir = None  # download directory; populated by async_init()

    async def async_init(self):
        """Load ``config.toml`` and make sure the image download directory exists."""
        config_path = Path(__file__).parent / "config.toml"
        with open(config_path, "rb") as f:
            self.config = tomllib.load(f)

        # Generated images are downloaded into ./images next to this file.
        self.images_dir = Path(__file__).parent / "images"
        self.images_dir.mkdir(parents=True, exist_ok=True)

        logger.success(f"即梦AI插件初始化完成配置了 {len(self.config['api']['tokens'])} 个token")

    @staticmethod
    def _to_number(value, default, cast):
        """Return ``cast(value)``; fall back to ``cast(default)`` if value is None or unconvertible."""
        if value is not None:
            try:
                return cast(value)
            except (TypeError, ValueError, OverflowError):
                logger.warning(f"参数值无效: {value!r},使用默认值 {default}")
        return cast(default)

    async def generate_image(self, prompt: str, **kwargs) -> List[str]:
        """Generate images for ``prompt`` and return their local file paths.

        Tries each configured token in turn, starting from the current
        ``token_state`` position. Each token gets up to
        ``generation.max_retry_attempts`` attempts with exponential backoff
        (capped at 10 s); HTTP 401 skips straight to the next token.

        Args:
            prompt: Text prompt.
            **kwargs: Optional overrides ``model``, ``width``, ``height``,
                ``sample_strength``, ``negative_prompt``. Missing or ``None``
                values fall back to the matching ``generation.default_*``
                setting. Numeric values are coerced (LLM tool calls may pass
                strings such as ``"1024"``) and clamped to 64-2048 px and
                0.0-1.0 respectively.

        Returns:
            Local paths of the downloaded images, or an empty list if every
            token and retry failed.
        """
        api_config = self.config["api"]
        gen_config = self.config["generation"]

        def option(key):
            # kwargs[key], or generation.default_<key> when missing or None.
            value = kwargs.get(key)
            return gen_config[f"default_{key}"] if value is None else value

        def as_int(value):
            # Accept "1024", 1024.0 and 1024 alike.
            return int(float(value))

        model = option("model")
        negative_prompt = option("negative_prompt")
        width = self._to_number(option("width"), gen_config["default_width"], as_int)
        height = self._to_number(option("height"), gen_config["default_height"], as_int)
        sample_strength = self._to_number(
            option("sample_strength"), gen_config["default_sample_strength"], float
        )

        # Parameter validation (clamping).
        sample_strength = max(0.0, min(1.0, sample_strength))
        width = max(self._MIN_SIZE, min(self._MAX_SIZE, width))
        height = max(self._MIN_SIZE, min(self._MAX_SIZE, height))

        tokens = api_config["tokens"]
        max_retry = gen_config["max_retry_attempts"]

        # URL and payload are identical for every token/attempt; build them once.
        url = f"{api_config['base_url'].rstrip('/')}/v1/chat/completions"
        payload = {
            "model": model,
            "messages": [{"role": "user", "content": prompt}],
            "prompt": prompt,
            "negativePrompt": negative_prompt,
            "width": width,
            "height": height,
            "sample_strength": sample_strength,
        }
        timeout = aiohttp.ClientTimeout(total=api_config["timeout"])

        # One session (connection pool) for the whole call instead of one per attempt.
        async with aiohttp.ClientSession(timeout=timeout) as session:
            for _ in range(len(tokens)):
                current_token = await self.token_state.get_next_token(tokens)
                headers = {
                    "Content-Type": "application/json",
                    "Authorization": f"Bearer {current_token}",
                }
                for attempt in range(max_retry):
                    if attempt > 0:
                        # Exponential backoff between retries, capped at 10 s.
                        await asyncio.sleep(min(2 ** attempt, 10))
                    logger.info(f"即梦AI请求: {model}, 尺寸: {width}x{height}, 提示词: {prompt[:50]}...")
                    try:
                        async with session.post(url, headers=headers, json=payload) as response:
                            status = response.status
                            # Read the body here so the response is released before any
                            # image downloads. content_type=None parses JSON even when
                            # the server sends a non-JSON Content-Type header.
                            if status == 200:
                                body = await response.json(content_type=None)
                            else:
                                body = await response.text()

                        if status == 401:
                            logger.warning("Token认证失败尝试下一个token")
                            break
                        if status == 429:
                            logger.warning("请求频率限制,等待后重试")
                            await asyncio.sleep(5)
                            continue
                        if status != 200:
                            logger.error(f"API请求失败: {status}, {body[:200]}")
                            continue

                        logger.debug(f"API返回数据: {body}")
                        if isinstance(body, dict) and "error" in body:
                            logger.error(f"API错误: {body['error']}")
                            continue

                        image_paths = await self._extract_images(body)
                        if image_paths:
                            logger.success(f"成功生成 {len(image_paths)} 张图像")
                            return image_paths
                        logger.warning(f"未找到图像数据API返回: {str(body)[:200]}")
                    except asyncio.TimeoutError:
                        logger.warning(f"请求超时,重试中... ({attempt + 1}/{max_retry})")
                    except Exception as e:
                        # Connection errors, undecodable JSON, etc.: log and retry.
                        logger.error(f"请求异常: {e}")

                # This token was rejected or exhausted its retries: rotate to the next one.
                await self.token_state.rotate(tokens)

        logger.error("所有token都失败了")
        return []

    async def _extract_images(self, data: dict) -> List[str]:
        """Download every image URL found in an API response; return the local paths.

        Recognised response shapes (checked in this order, first match wins):
        1. OpenAI-style ``choices[].message.content`` text containing URLs;
        2. ``data``: a URL string, a ``{"url": ...}`` dict, or a list of either;
        3. ``images``: same forms as ``data``;
        4. a top-level ``url`` string.
        Duplicate URLs are downloaded only once. Failed downloads are skipped.
        """
        if not isinstance(data, dict):
            return []

        if data.get("choices"):
            urls = self._urls_from_choices(data["choices"])
        elif "data" in data:
            urls = self._urls_from_items(data["data"])
        elif "images" in data:
            urls = self._urls_from_items(data["images"])
        elif isinstance(data.get("url"), str):
            urls = [data["url"]]
        else:
            urls = []

        image_paths = []
        # dict.fromkeys de-duplicates while keeping first-seen order.
        for url in dict.fromkeys(urls):
            path = await self._download_image(url)
            if path:
                image_paths.append(path)
        return image_paths

    @classmethod
    def _urls_from_choices(cls, choices) -> List[str]:
        """Collect URLs from the text of OpenAI-style ``choices[].message.content``."""
        import re

        if not isinstance(choices, list):
            return []
        urls = []
        for choice in choices:
            message = choice.get("message") if isinstance(choice, dict) else None
            # content is None for tool-call style replies; skip non-text content.
            content = message.get("content") if isinstance(message, dict) else None
            if isinstance(content, str):
                # Drop trailing sentence punctuation that directly follows a URL.
                urls.extend(u.rstrip(".,;:!?") for u in re.findall(cls._URL_PATTERN, content))
        return urls

    @staticmethod
    def _urls_from_items(items) -> List[str]:
        """Collect URLs from a URL string, a ``{"url": ...}`` dict, or a list of those."""
        urls = []
        for item in items if isinstance(items, list) else [items]:
            if isinstance(item, str) and item.startswith("http"):
                urls.append(item)
            elif isinstance(item, dict) and isinstance(item.get("url"), str):
                urls.append(item["url"])
        return urls

    async def _download_image(self, url: str) -> Optional[str]:
        """Download ``url`` into ``self.images_dir`` and return the local path.

        Returns None (after logging) on a non-200 response, an empty body, or
        any network/file error. The file is always saved with a ``.jpg``
        suffix, whatever the actual image format is.
        """
        try:
            async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=30)) as session:
                async with session.get(url) as response:
                    if response.status != 200:
                        logger.warning(f"下载图片失败: HTTP {response.status}, {url}")
                        return None
                    content = await response.read()

            if not content:
                logger.warning(f"下载图片失败: 内容为空, {url}")
                return None

            # Timestamp + random suffix keeps concurrent downloads from colliding.
            ts = datetime.now().strftime("%Y%m%d_%H%M%S")
            uid = uuid.uuid4().hex[:8]
            file_path = self.images_dir / f"jimeng_{ts}_{uid}.jpg"
            # Write in a worker thread so disk I/O does not block the event loop.
            await asyncio.to_thread(file_path.write_bytes, content)
            logger.info(f"图片下载成功: {file_path}")
            return str(file_path)
        except Exception as e:
            logger.error(f"下载图片失败: {e}")
            return None

    @on_text_message(priority=70)
    async def handle_message(self, bot: WechatHookClient, message: dict):
        """Handle "<keyword> <prompt>" draw commands from chat text messages.

        A message that is exactly a configured keyword gets a usage hint.

        Returns:
            True when the message is not handled here (feature disabled for
            this chat type, or not a draw command). False once this plugin
            has replied. Presumably True lets the framework keep dispatching
            to other plugins; confirm against ``utils.decorators``.
        """
        behavior = self.config["behavior"]
        if not behavior["enable_command"]:
            return True

        content = (message.get("Content") or "").strip()
        from_wxid = message.get("FromWxid", "")
        is_group = message.get("IsGroup", False)

        # Group / private chat switches.
        if is_group and not behavior["enable_group"]:
            return True
        if not is_group and not behavior["enable_private"]:
            return True

        # Match "<keyword> <prompt>" or the bare keyword. Content is already
        # stripped, so a bare keyword never matches "keyword + ' '".
        matched_keyword = next(
            (
                kw
                for kw in behavior["command_keywords"]
                if content == kw or content.startswith(kw + " ")
            ),
            None,
        )
        if matched_keyword is None:
            return True

        prompt = content[len(matched_keyword):].strip()
        if not prompt:
            await bot.send_text(from_wxid, f"❌ 请提供绘图提示词\n用法: {matched_keyword} <提示词>")
            return False

        logger.info(f"收到绘图请求: {prompt[:50]}...")
        # Let the user know generation has started; it can take a while.
        await bot.send_text(from_wxid, "🎨 正在为您生成图像,请稍候...")

        try:
            image_paths = await self.generate_image(prompt)
            if image_paths:
                # Only the first generated image is sent.
                await bot.send_image(from_wxid, image_paths[0])
                logger.success("绘图成功,已发送图片")
            else:
                await bot.send_text(from_wxid, "❌ 图像生成失败,请稍后重试")
        except Exception as e:
            logger.error(f"绘图处理失败: {e}")
            await bot.send_text(from_wxid, f"❌ 处理失败: {str(e)}")

        return False

    def get_llm_tools(self) -> List[dict]:
        """Return this plugin's tool definition (OpenAI-style ``{"type": "function"}``) for AIChat.

        Returns an empty list when ``llm_tool.enabled`` is false. The advertised
        width/height defaults come from ``[generation]`` so they match what
        ``execute_llm_tool`` actually uses when the LLM omits them.
        """
        tool_config = self.config["llm_tool"]
        if not tool_config["enabled"]:
            return []

        gen_config = self.config["generation"]
        size_range = f"{self._MIN_SIZE}-{self._MAX_SIZE}"
        default_width = gen_config["default_width"]
        default_height = gen_config["default_height"]

        return [{
            "type": "function",
            "function": {
                "name": tool_config["tool_name"],
                "description": tool_config["tool_description"],
                "parameters": {
                    "type": "object",
                    "properties": {
                        "prompt": {
                            "type": "string",
                            "description": "图像生成提示词,描述想要生成的图像内容",
                        },
                        "width": {
                            "type": "integer",
                            "description": f"图像宽度({size_range}),默认{default_width}",
                            "default": default_width,
                        },
                        "height": {
                            "type": "integer",
                            "description": f"图像高度({size_range}),默认{default_height}",
                            "default": default_height,
                        },
                    },
                    "required": ["prompt"],
                },
            },
        }]

    async def execute_llm_tool(self, tool_name: str, arguments: dict, bot: WechatHookClient, from_wxid: str) -> Optional[dict]:
        """Run the image tool on behalf of the AIChat plugin.

        Args:
            tool_name: Name of the tool the LLM invoked.
            arguments: Tool arguments. ``prompt`` is required; ``width`` and
                ``height`` are optional and default to the config values.
            bot: Client used to send the generated image.
            from_wxid: Chat that receives the image.

        Returns:
            None if ``tool_name`` is not this plugin's tool, so other plugins
            can handle it. Otherwise a dict with ``success`` and ``message``,
            plus ``images`` (a one-element list of local paths) on success.
        """
        expected_tool_name = self.config["llm_tool"]["tool_name"]
        # Called for every tool dispatch, so keep this at debug level.
        logger.debug(f"JimengAI工具检查: 收到={tool_name}, 期望={expected_tool_name}")
        if tool_name != expected_tool_name:
            return None  # Not this plugin's tool; let other plugins handle it.

        try:
            prompt = arguments.get("prompt")
            if not isinstance(prompt, str) or not prompt.strip():
                return {"success": False, "message": "缺少提示词参数"}

            logger.info(f"LLM工具调用绘图: {prompt[:50]}...")

            # Use the configured default size when the LLM does not specify one.
            gen_config = self.config["generation"]
            image_paths = await self.generate_image(
                prompt=prompt,
                width=arguments.get("width", gen_config["default_width"]),
                height=arguments.get("height", gen_config["default_height"]),
            )

            if not image_paths:
                return {"success": False, "message": "图像生成失败"}

            # Send the first image directly to the chat.
            await bot.send_image(from_wxid, image_paths[0])
            return {
                "success": True,
                "message": "已生成并发送图像",
                "images": [image_paths[0]],
            }
        except Exception as e:
            logger.error(f"LLM工具执行失败: {e}")
            return {"success": False, "message": f"执行失败: {str(e)}"}