feat: 持久记忆和代码优化、函数工具筛选
This commit is contained in:
385
plugins/ZImageTurbo/main.py
Normal file
385
plugins/ZImageTurbo/main.py
Normal file
@@ -0,0 +1,385 @@
|
||||
"""
|
||||
ZImageTurbo AI绘图插件
|
||||
|
||||
基于 Z-Image-Turbo API 的图像生成插件
|
||||
支持命令触发: /z绘图 xxx 或 /Z绘图 xxx
|
||||
支持在提示词中指定尺寸: 512x512, 768x768, 1024x1024, 1024x768, 768x1024, 1280x720, 720x1280
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
import json
|
||||
import tomllib
|
||||
import httpx
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from loguru import logger
|
||||
from utils.plugin_base import PluginBase
|
||||
from utils.decorators import on_text_message
|
||||
from WechatHook import WechatHookClient
|
||||
|
||||
|
||||
class ZImageTurbo(PluginBase):
|
||||
"""ZImageTurbo AI绘图插件"""
|
||||
|
||||
description = "ZImageTurbo AI绘图插件 - 基于 Z-Image-Turbo API"
|
||||
author = "ShiHao"
|
||||
version = "1.0.0"
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.config = None
|
||||
self.images_dir = None
|
||||
|
||||
async def async_init(self):
|
||||
"""异步初始化"""
|
||||
config_path = Path(__file__).parent / "config.toml"
|
||||
with open(config_path, "rb") as f:
|
||||
self.config = tomllib.load(f)
|
||||
|
||||
# 创建图片目录
|
||||
self.images_dir = Path(__file__).parent / "images"
|
||||
self.images_dir.mkdir(exist_ok=True)
|
||||
|
||||
logger.success("[ZImageTurbo] 插件初始化完成")
|
||||
|
||||
async def generate_image(self, prompt: str) -> Optional[str]:
|
||||
"""
|
||||
生成图像
|
||||
|
||||
Args:
|
||||
prompt: 提示词(可包含尺寸如 1024x768)
|
||||
|
||||
Returns:
|
||||
图片本地路径,失败返回 None
|
||||
"""
|
||||
api_config = self.config["api"]
|
||||
gen_config = self.config["generation"]
|
||||
max_retry = gen_config["max_retry_attempts"]
|
||||
use_stream = gen_config.get("stream", True)
|
||||
|
||||
for attempt in range(max_retry):
|
||||
if attempt > 0:
|
||||
wait_time = min(2 ** attempt, 10)
|
||||
logger.info(f"[ZImageTurbo] 等待 {wait_time} 秒后重试...")
|
||||
await asyncio.sleep(wait_time)
|
||||
|
||||
try:
|
||||
url = api_config["url"]
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {api_config['token']}"
|
||||
}
|
||||
|
||||
payload = {
|
||||
"model": api_config["model"],
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"stream": use_stream
|
||||
}
|
||||
|
||||
logger.info(f"[ZImageTurbo] 请求: prompt={prompt[:50]}..., stream={use_stream}")
|
||||
|
||||
# 设置超时
|
||||
timeout = httpx.Timeout(
|
||||
connect=10.0,
|
||||
read=float(api_config["timeout"]),
|
||||
write=10.0,
|
||||
pool=10.0
|
||||
)
|
||||
|
||||
# 获取代理配置
|
||||
proxy = await self._get_proxy()
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout, proxy=proxy) as client:
|
||||
if use_stream:
|
||||
# 流式响应处理
|
||||
image_url = await self._handle_stream_response(client, url, payload, headers)
|
||||
else:
|
||||
# 非流式响应处理
|
||||
image_url = await self._handle_normal_response(client, url, payload, headers)
|
||||
|
||||
if image_url:
|
||||
# 下载图片
|
||||
image_path = await self._download_image(image_url)
|
||||
if image_path:
|
||||
logger.success("[ZImageTurbo] 图像生成成功")
|
||||
return image_path
|
||||
else:
|
||||
logger.warning(f"[ZImageTurbo] 图片下载失败,重试中... ({attempt + 1}/{max_retry})")
|
||||
continue
|
||||
|
||||
except httpx.ReadTimeout:
|
||||
logger.warning(f"[ZImageTurbo] 读取超时,重试中... ({attempt + 1}/{max_retry})")
|
||||
continue
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"[ZImageTurbo] 请求超时,重试中... ({attempt + 1}/{max_retry})")
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"[ZImageTurbo] 请求异常: {type(e).__name__}: {e}")
|
||||
import traceback
|
||||
logger.error(f"[ZImageTurbo] 详细错误:\n{traceback.format_exc()}")
|
||||
continue
|
||||
|
||||
logger.error("[ZImageTurbo] 图像生成失败,已达最大重试次数")
|
||||
return None
|
||||
|
||||
async def _handle_stream_response(self, client: httpx.AsyncClient, url: str, payload: dict, headers: dict) -> Optional[str]:
|
||||
"""处理流式响应"""
|
||||
full_content = ""
|
||||
|
||||
async with client.stream("POST", url, json=payload, headers=headers) as response:
|
||||
logger.debug(f"[ZImageTurbo] 响应状态码: {response.status_code}")
|
||||
|
||||
if response.status_code != 200:
|
||||
error_text = await response.aread()
|
||||
logger.error(f"[ZImageTurbo] API请求失败: {response.status_code}, {error_text[:200]}")
|
||||
return None
|
||||
|
||||
async for line in response.aiter_lines():
|
||||
if not line.startswith("data: "):
|
||||
continue
|
||||
|
||||
data_str = line[6:]
|
||||
if data_str == "[DONE]":
|
||||
logger.debug("[ZImageTurbo] 收到 [DONE] 标记")
|
||||
break
|
||||
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
if "choices" in data and data["choices"]:
|
||||
delta = data["choices"][0].get("delta", {})
|
||||
content = delta.get("content", "")
|
||||
if content:
|
||||
full_content += content
|
||||
except Exception as e:
|
||||
logger.warning(f"[ZImageTurbo] 解析响应数据失败: {e}")
|
||||
continue
|
||||
|
||||
# 从内容中提取图片URL
|
||||
return self._extract_image_url(full_content)
|
||||
|
||||
async def _handle_normal_response(self, client: httpx.AsyncClient, url: str, payload: dict, headers: dict) -> Optional[str]:
|
||||
"""处理非流式响应"""
|
||||
response = await client.post(url, json=payload, headers=headers)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"[ZImageTurbo] API请求失败: {response.status_code}, {response.text[:200]}")
|
||||
return None
|
||||
|
||||
result = response.json()
|
||||
logger.debug(f"[ZImageTurbo] API返回: {json.dumps(result, ensure_ascii=False)[:200]}")
|
||||
|
||||
# 提取内容
|
||||
if "choices" in result and result["choices"]:
|
||||
content = result["choices"][0].get("message", {}).get("content", "")
|
||||
return self._extract_image_url(content)
|
||||
|
||||
return None
|
||||
|
||||
def _extract_image_url(self, content: str) -> Optional[str]:
|
||||
"""从 markdown 格式内容中提取图片URL"""
|
||||
if not content:
|
||||
logger.warning("[ZImageTurbo] 响应内容为空")
|
||||
return None
|
||||
|
||||
logger.debug(f"[ZImageTurbo] 提取URL,内容: {content[:200]}")
|
||||
|
||||
# 匹配 markdown 图片格式: 
|
||||
md_match = re.search(r'!\[.*?\]\((https?://[^\s\)]+)\)', content)
|
||||
if md_match:
|
||||
url = md_match.group(1)
|
||||
logger.info(f"[ZImageTurbo] 提取到图片URL: {url}")
|
||||
return url
|
||||
|
||||
# 直接匹配 URL
|
||||
url_match = re.search(r'https?://[^\s\)\]"\']+', content)
|
||||
if url_match:
|
||||
url = url_match.group(0).rstrip("'\"")
|
||||
logger.info(f"[ZImageTurbo] 提取到图片URL: {url}")
|
||||
return url
|
||||
|
||||
logger.warning(f"[ZImageTurbo] 未找到图片URL,内容: {content}")
|
||||
return None
|
||||
|
||||
async def _get_proxy(self) -> Optional[str]:
|
||||
"""获取代理配置(从 AIChat 插件读取)"""
|
||||
try:
|
||||
aichat_config_path = Path(__file__).parent.parent / "AIChat" / "config.toml"
|
||||
if aichat_config_path.exists():
|
||||
with open(aichat_config_path, "rb") as f:
|
||||
aichat_config = tomllib.load(f)
|
||||
|
||||
proxy_config = aichat_config.get("proxy", {})
|
||||
if proxy_config.get("enabled", False):
|
||||
proxy_type = proxy_config.get("type", "socks5")
|
||||
proxy_host = proxy_config.get("host", "127.0.0.1")
|
||||
proxy_port = proxy_config.get("port", 7890)
|
||||
proxy = f"{proxy_type}://{proxy_host}:{proxy_port}"
|
||||
logger.debug(f"[ZImageTurbo] 使用代理: {proxy}")
|
||||
return proxy
|
||||
except Exception as e:
|
||||
logger.warning(f"[ZImageTurbo] 读取代理配置失败: {e}")
|
||||
return None
|
||||
|
||||
async def _download_image(self, url: str) -> Optional[str]:
|
||||
"""下载图片到本地"""
|
||||
try:
|
||||
timeout = httpx.Timeout(connect=10.0, read=60.0, write=10.0, pool=10.0)
|
||||
proxy = await self._get_proxy()
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout, proxy=proxy) as client:
|
||||
response = await client.get(url)
|
||||
response.raise_for_status()
|
||||
|
||||
# 生成文件名
|
||||
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
uid = uuid.uuid4().hex[:8]
|
||||
file_path = self.images_dir / f"zimg_{ts}_{uid}.png"
|
||||
|
||||
# 保存文件
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(response.content)
|
||||
|
||||
logger.info(f"[ZImageTurbo] 图片下载成功: {file_path}")
|
||||
return str(file_path)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[ZImageTurbo] 下载图片失败: {e}")
|
||||
return None
|
||||
|
||||
@on_text_message(priority=70)
|
||||
async def handle_message(self, bot: WechatHookClient, message: dict):
|
||||
"""处理文本消息"""
|
||||
if not self.config["behavior"]["enable_command"]:
|
||||
return True
|
||||
|
||||
content = message.get("Content", "").strip()
|
||||
from_wxid = message.get("FromWxid", "")
|
||||
is_group = message.get("IsGroup", False)
|
||||
|
||||
# 检查群聊/私聊开关
|
||||
if is_group and not self.config["behavior"]["enable_group"]:
|
||||
return True
|
||||
if not is_group and not self.config["behavior"]["enable_private"]:
|
||||
return True
|
||||
|
||||
# 检查是否是绘图命令
|
||||
keywords = self.config["behavior"]["command_keywords"]
|
||||
matched_keyword = None
|
||||
for keyword in keywords:
|
||||
if content.startswith(keyword + " ") or content == keyword:
|
||||
matched_keyword = keyword
|
||||
break
|
||||
|
||||
if not matched_keyword:
|
||||
return True
|
||||
|
||||
# 提取提示词
|
||||
prompt = content[len(matched_keyword):].strip()
|
||||
|
||||
if not prompt:
|
||||
await bot.send_text(
|
||||
from_wxid,
|
||||
"请提供绘图提示词\n"
|
||||
"用法: /z绘图 <提示词>\n"
|
||||
"示例: /z绘图 a cute cat 1024x768\n"
|
||||
"支持尺寸: 512x512, 768x768, 1024x1024, 1024x768, 768x1024, 1280x720, 720x1280"
|
||||
)
|
||||
return False
|
||||
|
||||
# 如果提示词中没有尺寸,添加默认尺寸
|
||||
size_pattern = r'\d+x\d+'
|
||||
if not re.search(size_pattern, prompt):
|
||||
default_size = self.config["generation"]["default_size"]
|
||||
prompt = f"{prompt} {default_size}"
|
||||
|
||||
logger.info(f"[ZImageTurbo] 收到绘图请求: {prompt}")
|
||||
|
||||
# 发送等待提示
|
||||
if self.config["behavior"].get("send_waiting_message", True):
|
||||
await bot.send_text(from_wxid, "正在生成图像,请稍候(约需100-200秒)...")
|
||||
|
||||
try:
|
||||
# 生成图像
|
||||
image_path = await self.generate_image(prompt)
|
||||
|
||||
if image_path:
|
||||
await bot.send_image(from_wxid, image_path)
|
||||
logger.success("[ZImageTurbo] 绘图成功,已发送图片")
|
||||
else:
|
||||
await bot.send_text(from_wxid, "图像生成失败,请稍后重试")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[ZImageTurbo] 绘图处理失败: {e}")
|
||||
await bot.send_text(from_wxid, f"处理失败: {str(e)}")
|
||||
|
||||
return False
|
||||
|
||||
def get_llm_tools(self):
|
||||
"""返回LLM工具定义,供AIChat插件调用"""
|
||||
return [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "generate_image",
|
||||
"description": "使用AI生成图像。当用户要求画图、绘画、生成图片、创作图像时调用此工具。支持各种风格的图像生成。",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"prompt": {
|
||||
"type": "string",
|
||||
"description": "图像生成的提示词,描述想要生成的图像内容。建议使用英文以获得更好的效果。"
|
||||
},
|
||||
"size": {
|
||||
"type": "string",
|
||||
"description": "图像尺寸,可选值: 512x512, 768x768, 1024x1024, 1024x768, 768x1024, 1280x720, 720x1280",
|
||||
"enum": ["512x512", "768x768", "1024x1024", "1024x768", "768x1024", "1280x720", "720x1280"]
|
||||
}
|
||||
},
|
||||
"required": ["prompt"]
|
||||
}
|
||||
}
|
||||
}]
|
||||
|
||||
async def execute_llm_tool(self, tool_name: str, arguments: dict, bot: WechatHookClient, from_wxid: str) -> dict:
|
||||
"""执行LLM工具调用,供AIChat插件调用"""
|
||||
if tool_name != "generate_image":
|
||||
return {"success": False, "message": "未知的工具名称"}
|
||||
|
||||
try:
|
||||
prompt = arguments.get("prompt", "")
|
||||
size = arguments.get("size", self.config["generation"]["default_size"])
|
||||
|
||||
if not prompt:
|
||||
return {"success": False, "message": "缺少图像描述提示词"}
|
||||
|
||||
# 添加尺寸到提示词
|
||||
if size and size not in prompt:
|
||||
prompt = f"{prompt} {size}"
|
||||
|
||||
logger.info(f"[ZImageTurbo] LLM工具调用: prompt={prompt}")
|
||||
|
||||
# 生成图像
|
||||
image_path = await self.generate_image(prompt)
|
||||
|
||||
if image_path:
|
||||
# 发送图片
|
||||
await bot.send_image(from_wxid, image_path)
|
||||
return {
|
||||
"success": True,
|
||||
"message": "图像已生成并发送",
|
||||
"no_reply": True # 已发送图片,不需要AI再回复
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"message": "图像生成失败,请稍后重试"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[ZImageTurbo] LLM工具执行失败: {e}")
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"执行失败: {str(e)}"
|
||||
}
|
||||
Reference in New Issue
Block a user