Files

386 lines
15 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
ZImageTurbo AI绘图插件
基于 Z-Image-Turbo API 的图像生成插件
支持命令触发: /z绘图 xxx 或 /Z绘图 xxx
支持在提示词中指定尺寸: 512x512, 768x768, 1024x1024, 1024x768, 768x1024, 1280x720, 720x1280
"""
import asyncio
import re
import json
import tomllib
import httpx
import uuid
from pathlib import Path
from datetime import datetime
from typing import Optional
from loguru import logger
from utils.plugin_base import PluginBase
from utils.decorators import on_text_message
from WechatHook import WechatHookClient
class ZImageTurbo(PluginBase):
    """ZImageTurbo AI image-generation plugin.

    Handles text messages that start with one of ``behavior.command_keywords``
    (e.g. ``/z绘图 <prompt>``) and the ``generate_image`` LLM tool offered to the
    AIChat plugin. The prompt is sent to the chat-completions style endpoint
    configured under ``[api]``; the image URL found in the reply is downloaded
    into ``images/`` next to this file and then sent back to the chat.
    """

    description = "ZImageTurbo AI绘图插件 - 基于 Z-Image-Turbo API"
    author = "ShiHao"
    version = "1.0.0"

    # An explicit "<width>x<height>" size token inside a prompt, e.g. "1024x768".
    _SIZE_PATTERN = re.compile(r"\d+x\d+")
    # Markdown image syntax ![alt](http...); group 1 captures the URL.
    _MD_IMAGE_PATTERN = re.compile(r'!\[.*?\]\((https?://[^\s\)]+)\)')
    # Fallback: any bare http(s) URL, stopping at whitespace, brackets or quotes.
    _URL_PATTERN = re.compile(r'https?://[^\s\)\]"\']+')

    def __init__(self):
        super().__init__()
        self.config = None  # parsed config.toml; populated by async_init()
        self.images_dir = None  # directory for downloaded images; set by async_init()

    async def async_init(self):
        """Load ``config.toml`` from this plugin's directory and create ``images/``."""
        config_path = Path(__file__).parent / "config.toml"
        with open(config_path, "rb") as f:
            self.config = tomllib.load(f)

        # Downloaded images are written here before being sent to the chat.
        self.images_dir = Path(__file__).parent / "images"
        self.images_dir.mkdir(exist_ok=True)

        logger.success("[ZImageTurbo] 插件初始化完成")

    async def generate_image(self, prompt: str) -> Optional[str]:
        """Request an image for ``prompt`` and download it locally.

        Each attempt sends one API request (streaming or not, per
        ``generation.stream``), extracts an image URL from the reply and
        downloads it. Failed attempts are retried up to
        ``generation.max_retry_attempts`` times (at least one attempt is always
        made), waiting ``min(2 ** attempt, 10)`` seconds between attempts.

        Args:
            prompt: Prompt text; may contain a size token such as ``1024x768``.

        Returns:
            Local path of the downloaded image, or ``None`` if every attempt failed.
        """
        api_config = self.config["api"]
        gen_config = self.config["generation"]
        # Guarantee at least one attempt even if the config says 0 (or less).
        max_retry = max(1, int(gen_config["max_retry_attempts"]))
        use_stream = gen_config.get("stream", True)

        for attempt in range(max_retry):
            if attempt > 0:
                # Exponential back-off, capped at 10 seconds.
                wait_time = min(2 ** attempt, 10)
                logger.info(f"[ZImageTurbo] 等待 {wait_time} 秒后重试...")
                await asyncio.sleep(wait_time)

            try:
                url = api_config["url"]
                headers = {
                    "Content-Type": "application/json",
                    "Authorization": f"Bearer {api_config['token']}",
                }
                payload = {
                    "model": api_config["model"],
                    "messages": [{"role": "user", "content": prompt}],
                    "stream": use_stream,
                }
                logger.info(f"[ZImageTurbo] 请求: prompt={prompt[:50]}..., stream={use_stream}")

                # Generation is slow, so only the read timeout is configurable.
                timeout = httpx.Timeout(
                    connect=10.0,
                    read=float(api_config["timeout"]),
                    write=10.0,
                    pool=10.0,
                )
                proxy = await self._get_proxy()

                async with httpx.AsyncClient(timeout=timeout, proxy=proxy) as client:
                    if use_stream:
                        image_url = await self._handle_stream_response(client, url, payload, headers)
                    else:
                        image_url = await self._handle_normal_response(client, url, payload, headers)

                if not image_url:
                    logger.warning(f"[ZImageTurbo] 未获取到图片URL,重试中... ({attempt + 1}/{max_retry})")
                    continue

                image_path = await self._download_image(image_url)
                if image_path:
                    logger.success("[ZImageTurbo] 图像生成成功")
                    return image_path
                logger.warning(f"[ZImageTurbo] 图片下载失败,重试中... ({attempt + 1}/{max_retry})")

            except httpx.ReadTimeout:
                logger.warning(f"[ZImageTurbo] 读取超时,重试中... ({attempt + 1}/{max_retry})")
            except (httpx.TimeoutException, asyncio.TimeoutError):
                # Connect/write/pool timeouts are expected transient failures too;
                # don't treat them as unexpected errors with a full traceback.
                logger.warning(f"[ZImageTurbo] 请求超时,重试中... ({attempt + 1}/{max_retry})")
            except Exception as e:
                # logger.exception appends the traceback to the error record.
                logger.exception(f"[ZImageTurbo] 请求异常: {type(e).__name__}: {e}")

        logger.error("[ZImageTurbo] 图像生成失败,已达最大重试次数")
        return None

    async def _handle_stream_response(self, client: httpx.AsyncClient, url: str, payload: dict, headers: dict) -> Optional[str]:
        """POST ``payload`` as a streaming request and collect the ``delta`` text.

        Reads ``data:`` lines until ``[DONE]`` (or end of stream), concatenates
        ``choices[0].delta.content`` from each JSON chunk and extracts the image
        URL from the result. Returns ``None`` on a non-200 status.
        """
        parts = []
        async with client.stream("POST", url, json=payload, headers=headers) as response:
            logger.debug(f"[ZImageTurbo] 响应状态码: {response.status_code}")
            if response.status_code != 200:
                # aread() returns bytes; decode so the log shows readable text.
                error_text = (await response.aread()).decode("utf-8", errors="replace")
                logger.error(f"[ZImageTurbo] API请求失败: {response.status_code}, {error_text[:200]}")
                return None

            async for line in response.aiter_lines():
                # SSE permits "data:" with or without a following space.
                if not line.startswith("data:"):
                    continue
                data_str = line[len("data:"):].strip()
                if data_str == "[DONE]":
                    logger.debug("[ZImageTurbo] 收到 [DONE] 标记")
                    break

                try:
                    data = json.loads(data_str)
                    choices = data.get("choices") or []
                    if choices:
                        content = (choices[0].get("delta") or {}).get("content")
                        if isinstance(content, str) and content:
                            parts.append(content)
                except (ValueError, TypeError, AttributeError) as e:
                    # A malformed chunk is skipped; the rest of the stream may still be usable.
                    logger.warning(f"[ZImageTurbo] 解析响应数据失败: {e}")
                    continue

        return self._extract_image_url("".join(parts))

    async def _handle_normal_response(self, client: httpx.AsyncClient, url: str, payload: dict, headers: dict) -> Optional[str]:
        """POST ``payload`` as a regular request and extract the image URL.

        Reads ``choices[0].message.content`` from the JSON body. Returns ``None``
        on a non-200 status or when the body has no usable ``choices``.
        """
        response = await client.post(url, json=payload, headers=headers)
        if response.status_code != 200:
            logger.error(f"[ZImageTurbo] API请求失败: {response.status_code}, {response.text[:200]}")
            return None

        result = response.json()
        logger.debug(f"[ZImageTurbo] API返回: {json.dumps(result, ensure_ascii=False)[:200]}")

        choices = result.get("choices") if isinstance(result, dict) else None
        if not choices:
            return None
        message = choices[0].get("message") or {}
        return self._extract_image_url(message.get("content") or "")

    def _extract_image_url(self, content: str) -> Optional[str]:
        """Return the first image URL found in ``content``, or ``None``.

        A markdown image (``![alt](url)``) takes precedence; otherwise the
        first bare http(s) URL in the text is used.
        """
        if not content:
            logger.warning("[ZImageTurbo] 响应内容为空")
            return None

        logger.debug(f"[ZImageTurbo] 提取URL内容: {content[:200]}")

        md_match = self._MD_IMAGE_PATTERN.search(content)
        if md_match:
            url = md_match.group(1)
            logger.info(f"[ZImageTurbo] 提取到图片URL: {url}")
            return url

        url_match = self._URL_PATTERN.search(content)
        if url_match:
            url = url_match.group(0)
            logger.info(f"[ZImageTurbo] 提取到图片URL: {url}")
            return url

        logger.warning(f"[ZImageTurbo] 未找到图片URL内容: {content}")
        return None

    async def _get_proxy(self) -> Optional[str]:
        """Build a proxy URL from the sibling AIChat plugin's ``config.toml``.

        Returns ``"<type>://<host>:<port>"`` (defaults: socks5, 127.0.0.1, 7890)
        when ``[proxy].enabled`` is true. Returns ``None`` when the file is
        missing, the proxy is disabled, or the file cannot be read.
        """
        try:
            aichat_config_path = Path(__file__).parent.parent / "AIChat" / "config.toml"
            if aichat_config_path.exists():
                with open(aichat_config_path, "rb") as f:
                    aichat_config = tomllib.load(f)
                proxy_config = aichat_config.get("proxy", {})
                if proxy_config.get("enabled", False):
                    proxy_type = proxy_config.get("type", "socks5")
                    proxy_host = proxy_config.get("host", "127.0.0.1")
                    proxy_port = proxy_config.get("port", 7890)
                    proxy = f"{proxy_type}://{proxy_host}:{proxy_port}"
                    logger.debug(f"[ZImageTurbo] 使用代理: {proxy}")
                    return proxy
        except Exception as e:
            # Best effort: a broken AIChat config just means "no proxy".
            logger.warning(f"[ZImageTurbo] 读取代理配置失败: {e}")
        return None

    async def _download_image(self, url: str) -> Optional[str]:
        """Download ``url`` into ``images_dir`` and return the saved path.

        The file name is ``zimg_<YYYYmmdd_HHMMSS>_<8 hex chars>.png``. Returns
        ``None`` on any download or write error.
        """
        # NOTE(review): the extension is always .png regardless of the actual
        # image format returned by the server.
        try:
            timeout = httpx.Timeout(connect=10.0, read=60.0, write=10.0, pool=10.0)
            proxy = await self._get_proxy()
            async with httpx.AsyncClient(timeout=timeout, proxy=proxy) as client:
                response = await client.get(url)
                response.raise_for_status()

            ts = datetime.now().strftime("%Y%m%d_%H%M%S")
            uid = uuid.uuid4().hex[:8]
            file_path = self.images_dir / f"zimg_{ts}_{uid}.png"

            with open(file_path, "wb") as f:
                f.write(response.content)

            logger.info(f"[ZImageTurbo] 图片下载成功: {file_path}")
            return str(file_path)
        except Exception as e:
            logger.error(f"[ZImageTurbo] 下载图片失败: {e}")
            return None

    @on_text_message(priority=70)
    async def handle_message(self, bot: WechatHookClient, message: dict):
        """Handle a drawing command contained in a text message.

        A command is a configured keyword either alone or followed by
        whitespace and the prompt. If the prompt has no ``WxH`` size token,
        ``generation.default_size`` is appended.

        Returns ``True`` when the message is not handled here (feature or chat
        type disabled, no keyword) and ``False`` once it has been handled —
        presumably the framework stops propagation on ``False``; confirm
        against ``utils.decorators``.
        """
        behavior = self.config["behavior"]
        if not behavior["enable_command"]:
            return True

        # "Content" may be missing or None; treat both as an empty message.
        content = (message.get("Content") or "").strip()
        from_wxid = message.get("FromWxid", "")
        is_group = message.get("IsGroup", False)

        # Respect the per-chat-type switches.
        if is_group and not behavior["enable_group"]:
            return True
        if not is_group and not behavior["enable_private"]:
            return True

        # Keyword must be the whole message or be followed by any whitespace
        # (space, newline, tab) so "/z绘图xyz" is not mistaken for a command.
        matched_keyword = next(
            (
                keyword
                for keyword in behavior["command_keywords"]
                if content == keyword
                or (content.startswith(keyword) and content[len(keyword):][:1].isspace())
            ),
            None,
        )
        if not matched_keyword:
            return True

        prompt = content[len(matched_keyword):].strip()
        if not prompt:
            await bot.send_text(
                from_wxid,
                "请提供绘图提示词\n"
                "用法: /z绘图 <提示词>\n"
                "示例: /z绘图 a cute cat 1024x768\n"
                "支持尺寸: 512x512, 768x768, 1024x1024, 1024x768, 768x1024, 1280x720, 720x1280"
            )
            return False

        # Append the default size only when the user did not give one.
        if not self._SIZE_PATTERN.search(prompt):
            default_size = self.config["generation"]["default_size"]
            prompt = f"{prompt} {default_size}"

        logger.info(f"[ZImageTurbo] 收到绘图请求: {prompt}")

        if behavior.get("send_waiting_message", True):
            await bot.send_text(from_wxid, "正在生成图像请稍候约需100-200秒...")

        try:
            image_path = await self.generate_image(prompt)
            if image_path:
                await bot.send_image(from_wxid, image_path)
                logger.success("[ZImageTurbo] 绘图成功,已发送图片")
            else:
                await bot.send_text(from_wxid, "图像生成失败,请稍后重试")
        except Exception as e:
            logger.exception(f"[ZImageTurbo] 绘图处理失败: {e}")
            await bot.send_text(from_wxid, f"处理失败: {str(e)}")

        return False

    def get_llm_tools(self):
        """Return the tool definitions this plugin exposes to the AIChat plugin."""
        return [{
            "type": "function",
            "function": {
                "name": "generate_image",
                "description": "仅当用户明确要求生成图片/画图/出图/创作图像时调用;不要在闲聊中触发。",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "prompt": {
                            "type": "string",
                            "description": "图像生成的提示词,描述想要生成的图像内容。建议使用英文以获得更好的效果。"
                        },
                        "size": {
                            "type": "string",
                            "description": "图像尺寸,可选值: 512x512, 768x768, 1024x1024, 1024x768, 768x1024, 1280x720, 720x1280",
                            "enum": ["512x512", "768x768", "1024x1024", "1024x768", "768x1024", "1280x720", "720x1280"]
                        }
                    },
                    "required": ["prompt"]
                }
            }
        }]

    async def execute_llm_tool(self, tool_name: str, arguments: dict, bot: WechatHookClient, from_wxid: str) -> Optional[dict]:
        """Run an LLM tool call coming from the AIChat plugin.

        Returns ``None`` for tools this plugin does not own. Otherwise it
        returns a dict with ``success`` and ``message`` keys, plus
        ``no_reply: True`` after the image has been sent.

        Size handling matches ``handle_message``:
        - an explicit ``size`` argument is appended unless the prompt already
          contains that exact text;
        - with no ``size``, ``generation.default_size`` is appended only if the
          prompt has no ``WxH`` token of its own.
        """
        if tool_name != "generate_image":
            return None

        try:
            prompt = str(arguments.get("prompt") or "").strip()
            size = arguments.get("size")

            if not prompt:
                return {"success": False, "message": "缺少图像描述提示词"}

            if size:
                if size not in prompt:
                    prompt = f"{prompt} {size}"
            elif not self._SIZE_PATTERN.search(prompt):
                prompt = f"{prompt} {self.config['generation']['default_size']}"

            logger.info(f"[ZImageTurbo] LLM工具调用: prompt={prompt}")

            image_path = await self.generate_image(prompt)
            if image_path:
                await bot.send_image(from_wxid, image_path)
                return {
                    "success": True,
                    "message": "图像已生成并发送",
                    "no_reply": True  # image already sent; the AI need not reply
                }
            return {
                "success": False,
                "message": "图像生成失败,请稍后重试"
            }
        except Exception as e:
            logger.exception(f"[ZImageTurbo] LLM工具执行失败: {e}")
            return {
                "success": False,
                "message": f"执行失败: {str(e)}"
            }