"""
NanoImage AI drawing plugin.

Supports OpenAI-compatible image-generation APIs; the base URL, model ID and
API key are user-configurable.
Can be triggered by a chat command or invoked as an LLM tool.
"""

import asyncio
import base64
import json
import re
import tomllib
import uuid
from datetime import datetime
from pathlib import Path
from typing import List, Optional

import httpx
from loguru import logger

from utils.plugin_base import PluginBase
from utils.decorators import on_text_message
from WechatHook import WechatHookClient


class NanoImage(PluginBase):
    """NanoImage AI drawing plugin.

    Sends the prompt to an OpenAI-compatible ``/v1/chat/completions`` endpoint
    (``stream: true``), extracts the image URL from the returned text,
    downloads it into ``images/`` next to this file and sends it to the chat.
    Triggered either by a text command or by an LLM tool call.
    """

    description = "NanoImage AI绘图插件 - 支持 OpenAI 格式的绘图 API"
    author = "ShiHao"
    version = "1.0.0"

    # Markdown image syntax ``![alt](url)``. Preferred when present, because a
    # bare-URL scan could hit an unrelated link that appears earlier in the text.
    _MARKDOWN_IMAGE_RE = re.compile(r"!\[[^\]]*\]\((https?://[^\s)]+)\)")
    # Bare http(s) URL, ending at whitespace, closing brackets, quotes or angle brackets.
    _BARE_URL_RE = re.compile(r"(https?://[^\s)\]\"'<>]+)")
    # Sentence punctuation (ASCII and full-width) that often ends up glued to a URL.
    _URL_TRAILING_PUNCT = ".,;:!?。，；：！？、）】」"

    def __init__(self):
        super().__init__()
        self.config = None  # dict parsed from config.toml; populated by async_init()
        self.images_dir = None  # Path of the local image directory; set by async_init()

    async def async_init(self):
        """Load ``config.toml`` and create the ``images/`` directory beside this file."""
        config_path = Path(__file__).parent / "config.toml"
        with open(config_path, "rb") as f:
            self.config = tomllib.load(f)

        self.images_dir = Path(__file__).parent / "images"
        self.images_dir.mkdir(exist_ok=True)

        logger.success(f"NanoImage AI插件初始化完成,模型: {self.config['api']['model']}")

    async def generate_image(self, prompt: str) -> List[str]:
        """Generate an image for *prompt*, retrying on failure.

        Args:
            prompt: text sent as the single user message of the chat request.

        Returns:
            A one-element list holding the local path of the downloaded image,
            or an empty list when every attempt failed. An HTTP 401 aborts
            immediately, because retrying cannot fix bad credentials.
        """
        api_config = self.config["api"]
        max_retry = self.config["generation"]["max_retry_attempts"]

        for attempt in range(max_retry):
            if attempt > 0:
                # Exponential backoff between attempts: 2 s, 4 s, 8 s, then capped at 10 s.
                await asyncio.sleep(min(2 ** attempt, 10))

            try:
                url = f"{api_config['base_url'].rstrip('/')}/v1/chat/completions"
                headers = {
                    "Content-Type": "application/json",
                    "Authorization": f"Bearer {api_config['api_key']}",
                }
                payload = {
                    "model": api_config["model"],
                    "messages": [{"role": "user", "content": prompt}],
                    "stream": True,
                }

                logger.info(f"NanoImage请求: {api_config['model']}, 提示词长度: {len(prompt)} 字符")
                logger.debug(f"完整提示词: {prompt}")

                # Generation can be slow, so only the read timeout follows the
                # config (capped at 600 s); connect/write/pool stay short.
                timeout = httpx.Timeout(
                    connect=10.0,
                    read=min(api_config["timeout"], 600),
                    write=10.0,
                    pool=10.0,
                )
                proxy = await self._get_proxy()

                async with httpx.AsyncClient(timeout=timeout, proxy=proxy) as client:
                    async with client.stream("POST", url, json=payload, headers=headers) as response:
                        logger.debug(f"收到响应状态码: {response.status_code}")
                        if response.status_code == 401:
                            logger.error("API Key 认证失败")
                            return []
                        if response.status_code != 200:
                            error_body = await response.aread()
                            error_text = error_body.decode("utf-8", errors="replace")
                            logger.error(f"API请求失败: {response.status_code}, {error_text[:200]}")
                            continue
                        full_content = await self._collect_stream_text(response)

                # Extract only after the whole stream has arrived: a URL is
                # usually split across several deltas, so scanning individual
                # chunks can yield a truncated link.
                image_url = self._extract_image_url(full_content)
                if not image_url:
                    logger.error(f"未能提取到图片URL,完整响应: {full_content[:500]}")
                    continue
                logger.info(f"提取到图片URL: {image_url}")

                image_path = await self._download_image(image_url)
                if image_path:
                    logger.success("成功生成图像")
                    return [image_path]
                logger.warning(f"图片下载失败,将重试 ({attempt + 1}/{max_retry})")

            except httpx.ReadTimeout:
                logger.warning(f"读取超时,重试中... ({attempt + 1}/{max_retry})")
            except (asyncio.TimeoutError, httpx.TimeoutException):
                # httpx raises its own TimeoutException subclasses (connect/pool/write),
                # not asyncio.TimeoutError.
                logger.warning(f"请求超时,重试中... ({attempt + 1}/{max_retry})")
            except Exception as e:
                logger.exception(f"请求异常: {type(e).__name__}: {e}")

        logger.error("图像生成失败")
        return []

    async def _collect_stream_text(self, response: httpx.Response) -> str:
        """Return the concatenated text of an OpenAI-style chat-completion response.

        Reads a server-sent-event stream (``data: {...}`` lines, terminated by
        ``data: [DONE]``). If the body contains no ``data:`` line at all (some
        providers ignore ``stream: true``), it is parsed as a single plain JSON
        completion instead. Malformed events are logged and skipped.
        """
        pieces: List[str] = []
        plain_lines: List[str] = []
        saw_event = False

        async for line in response.aiter_lines():
            if not line.startswith("data:"):
                plain_lines.append(line)
                continue
            saw_event = True
            # SSE allows an optional space after "data:".
            data_str = line[5:].strip()
            if data_str == "[DONE]":
                break
            if not data_str:
                continue
            try:
                event = json.loads(data_str)
            except json.JSONDecodeError as e:
                logger.warning(f"解析响应数据失败: {e}")
                continue
            pieces.append(self._choice_text(event))

        if not saw_event and plain_lines:
            body = "\n".join(plain_lines)
            try:
                pieces.append(self._choice_text(json.loads(body)))
            except json.JSONDecodeError:
                logger.warning(f"解析响应数据失败: {body[:200]}")
        return "".join(pieces)

    @staticmethod
    def _choice_text(event) -> str:
        """Return the text of ``choices[0]`` in a parsed completion chunk, or ``""``.

        Reads ``delta.content`` (streaming chunks) and falls back to
        ``message.content`` (non-streaming bodies). A missing or non-string
        piece yields ``""``.
        """
        if not isinstance(event, dict):
            return ""
        choices = event.get("choices")
        if not isinstance(choices, list) or not choices or not isinstance(choices[0], dict):
            return ""
        choice = choices[0]
        body = choice.get("delta") or choice.get("message")
        if not isinstance(body, dict):
            return ""
        content = body.get("content")
        return content if isinstance(content, str) else ""

    @classmethod
    def _extract_image_url(cls, text: str) -> Optional[str]:
        """Return the image URL found in *text*, or ``None``.

        A Markdown image link wins over a bare URL. Trailing sentence
        punctuation is stripped from the result.
        """
        if not text:
            return None
        match = cls._MARKDOWN_IMAGE_RE.search(text) or cls._BARE_URL_RE.search(text)
        if match is None:
            return None
        return match.group(1).rstrip(cls._URL_TRAILING_PUNCT) or None

    async def _get_proxy(self) -> Optional[str]:
        """Return the proxy URL configured in the AIChat plugin, or ``None``.

        Reads ``../AIChat/config.toml`` on every call. A proxy is returned only
        when its ``[proxy]`` table has ``enabled = true``; type/host/port default
        to ``socks5``/``127.0.0.1``/``7890``. Read/parse errors are logged and
        treated as "no proxy".
        """
        try:
            aichat_config_path = Path(__file__).parent.parent / "AIChat" / "config.toml"
            if aichat_config_path.exists():
                with open(aichat_config_path, "rb") as f:
                    aichat_config = tomllib.load(f)

                proxy_config = aichat_config.get("proxy", {})
                if proxy_config.get("enabled", False):
                    proxy_type = proxy_config.get("type", "socks5")
                    proxy_host = proxy_config.get("host", "127.0.0.1")
                    proxy_port = proxy_config.get("port", 7890)
                    proxy = f"{proxy_type}://{proxy_host}:{proxy_port}"
                    logger.debug(f"使用代理: {proxy}")
                    return proxy
        except Exception as e:
            logger.warning(f"读取代理配置失败: {e}")

        return None

    async def _download_image(self, url: str) -> Optional[str]:
        """Download *url* into ``self.images_dir`` and return the local file path.

        Returns ``None`` after logging in any of these cases, so that
        ``generate_image`` can retry:
        - a network or HTTP error;
        - an empty body;
        - a text/JSON document served instead of an image.
        """
        try:
            timeout = httpx.Timeout(connect=10.0, read=30.0, write=10.0, pool=10.0)
            proxy = await self._get_proxy()

            # Image hosts often redirect to a CDN or signed URL. httpx does not
            # follow redirects by default, and raise_for_status rejects 3xx.
            async with httpx.AsyncClient(timeout=timeout, proxy=proxy, follow_redirects=True) as client:
                response = await client.get(url)
                response.raise_for_status()

            content_type = response.headers.get("content-type", "").lower()
            if content_type.startswith(("text/", "application/json")):
                logger.error(f"下载图片失败: 返回的不是图片 ({content_type})")
                return None
            if not response.content:
                logger.error("下载图片失败: 响应内容为空")
                return None

            ts = datetime.now().strftime("%Y%m%d_%H%M%S")
            uid = uuid.uuid4().hex[:8]
            # NOTE: always named .jpg (unchanged behaviour), whatever the real format is.
            file_path = self.images_dir / f"nano_{ts}_{uid}.jpg"
            file_path.write_bytes(response.content)

            logger.info(f"图片下载成功: {file_path}")
            return str(file_path)
        except Exception as e:
            logger.error(f"下载图片失败: {e}")
            return None

    @on_text_message(priority=70)
    async def handle_message(self, bot: WechatHookClient, message: dict):
        """Handle a text message that starts with one of the configured command keywords.

        Returns True when the message is not for this plugin and False once it
        has been handled. Presumably False tells the dispatcher to stop
        propagating the message; confirm against the decorator/dispatcher.
        """
        behavior = self.config["behavior"]
        if not behavior["enable_command"]:
            return True

        # "Content" may be missing or null; treat both as an empty message.
        content = (message.get("Content") or "").strip()
        from_wxid = message.get("FromWxid", "")
        is_group = message.get("IsGroup", False)

        # Per-chat-type switches.
        if is_group and not behavior["enable_group"]:
            return True
        if not is_group and not behavior["enable_private"]:
            return True

        # A keyword matches when it is the whole message or is followed by a space.
        matched_keyword = next(
            (
                keyword
                for keyword in behavior["command_keywords"]
                if content == keyword or content.startswith(keyword + " ")
            ),
            None,
        )
        if not matched_keyword:
            return True

        prompt = content[len(matched_keyword):].strip()
        if not prompt:
            await bot.send_text(from_wxid, f"❌ 请提供绘图提示词\n用法: {matched_keyword} <提示词>")
            return False

        logger.info(f"收到绘图请求: {prompt[:50]}...")

        try:
            image_paths = await self.generate_image(prompt)
            if image_paths:
                await bot.send_image(from_wxid, image_paths[0])
                logger.success("绘图成功,已发送图片")
            else:
                await bot.send_text(from_wxid, "❌ 图像生成失败,请稍后重试")
        except Exception as e:
            logger.exception(f"绘图处理失败: {e}")
            await bot.send_text(from_wxid, f"❌ 处理失败: {str(e)}")

        return False

    def get_llm_tools(self) -> List[dict]:
        """Return the OpenAI function-tool definition, or ``[]`` when the tool is disabled."""
        if not self.config["llm_tool"]["enabled"]:
            return []

        return [{
            "type": "function",
            "function": {
                "name": self.config["llm_tool"]["tool_name"],
                "description": self.config["llm_tool"]["tool_description"],
                "parameters": {
                    "type": "object",
                    "properties": {
                        "prompt": {
                            "type": "string",
                            "description": "图像生成提示词,描述想要生成的图像内容"
                        }
                    },
                    "required": ["prompt"]
                }
            }
        }]

    async def execute_llm_tool(self, tool_name: str, arguments: dict, bot: WechatHookClient, from_wxid: str) -> Optional[dict]:
        """Run the drawing tool when *tool_name* is this plugin's tool.

        Returns ``None`` for other tool names. Otherwise returns a result dict
        with ``success`` and ``message``; on success it also includes
        ``images``, the list of sent local paths. The image is sent to
        *from_wxid* before returning.
        """
        if tool_name != self.config["llm_tool"]["tool_name"]:
            return None

        try:
            prompt = arguments.get("prompt")
            # The LLM may pass a non-string or whitespace-only value.
            if not isinstance(prompt, str) or not prompt.strip():
                return {"success": False, "message": "缺少提示词参数"}
            prompt = prompt.strip()

            logger.info(f"LLM工具调用绘图: {prompt[:50]}...")

            image_paths = await self.generate_image(prompt)
            if not image_paths:
                return {"success": False, "message": "图像生成失败"}

            await bot.send_image(from_wxid, image_paths[0])
            return {
                "success": True,
                "message": "已生成并发送图像",
                "images": [image_paths[0]]
            }
        except Exception as e:
            logger.error(f"LLM工具执行失败: {e}")
            return {"success": False, "message": f"执行失败: {str(e)}"}