"""
NanoImage AI绘图插件

支持 OpenAI 格式的绘图 API,用户可自定义 URL、模型 ID、密钥。
支持命令触发和 LLM 工具调用。
"""

import asyncio
import base64
import json
import os
import re
import tomllib
import traceback
import uuid
from datetime import datetime
from pathlib import Path
from typing import List, Optional
from urllib.parse import urlparse

import httpx
from loguru import logger

from utils.plugin_base import PluginBase
from utils.decorators import on_text_message
from WechatHook import WechatHookClient


class NanoImage(PluginBase):
    """NanoImage AI drawing plugin.

    Generates images through an OpenAI-compatible ``/v1/chat/completions``
    endpoint and edits images through an images-edits endpoint
    (``/v1/images/edits`` by default). Exposed both as a text command
    (``handle_message``) and as an LLM function tool (``get_llm_tools`` /
    ``execute_llm_tool``). Configuration is read from ``config.toml`` next to
    this file; results are saved under ``images/`` next to this file.
    """

    description = "NanoImage AI绘图插件 - 支持 OpenAI 格式的绘图 API"
    author = "ShiHao"
    version = "1.0.0"

    # First http(s) URL in free text. Stops at whitespace, quotes, closing
    # brackets (Markdown "![](...)"), angle brackets and full-width CJK
    # punctuation that models commonly place right after a URL.
    _URL_PATTERN = re.compile(r"""https?://[^\s)\]"'<>,。)】》]+""")
    # A complete "data:image/<subtype>;base64,<payload>" URI.
    _DATA_URI_PATTERN = re.compile(r"data:image/[\w.+-]+;base64,[A-Za-z0-9+/=]+")
    # (leading magic bytes, file extension, MIME type) used to identify images
    # whose format is not declared anywhere else.
    _IMAGE_SIGNATURES = (
        (b"\x89PNG\r\n\x1a\n", "png", "image/png"),
        (b"\xff\xd8\xff", "jpg", "image/jpeg"),
        (b"GIF87a", "gif", "image/gif"),
        (b"GIF89a", "gif", "image/gif"),
    )

    def __init__(self):
        super().__init__()
        self.config: Optional[dict] = None  # parsed config.toml, set in async_init()
        self.images_dir: Optional[Path] = None  # output directory, set in async_init()

    async def async_init(self):
        """Load ``config.toml`` and create the ``images/`` output directory."""
        config_path = Path(__file__).parent / "config.toml"
        logger.info(f"NanoImage 配置文件路径: {config_path}")
        with open(config_path, "rb") as f:
            self.config = tomllib.load(f)

        # Generated and downloaded images are stored next to the plugin.
        self.images_dir = Path(__file__).parent / "images"
        self.images_dir.mkdir(exist_ok=True)

        logger.success(f"NanoImage AI插件初始化完成,base_url: {self.config['api']['base_url']}, 模型: {self.config['api']['model']}")

    # ------------------------------------------------------------------
    # Image generation (chat completions)
    # ------------------------------------------------------------------

    async def generate_image(self, prompt: str) -> List[str]:
        """Generate an image for ``prompt`` via ``/v1/chat/completions``.

        Makes up to ``generation.max_retry_attempts`` attempts (at least one),
        sleeping ``min(2 ** attempt, 10)`` seconds before each retry. When a
        successful response carries no image and
        ``generation.simplify_prompt_on_fail`` is enabled (default), later
        attempts use a shortened prompt (see ``_simplify_prompt``).

        Args:
            prompt: Drawing prompt.

        Returns:
            A one-element list with the local image path, or ``[]`` on
            failure (including HTTP 401 and a SOCKS proxy without ``socksio``).
        """
        api_config = self.config["api"]
        gen_config = self.config["generation"]
        max_retry = max(int(gen_config.get("max_retry_attempts", 2)), 1)
        simplify_on_fail = gen_config.get("simplify_prompt_on_fail", True)
        simplify_max_tags = gen_config.get("simplify_max_tags", 8)
        simplify_max_chars = gen_config.get("simplify_max_chars", 140)
        simplified_prompt = ""

        for attempt in range(max_retry):
            if attempt > 0:
                await asyncio.sleep(min(2 ** attempt, 10))

            try:
                url = f"{api_config['base_url'].rstrip('/')}/v1/chat/completions"
                headers = {
                    "Content-Type": "application/json",
                    "Authorization": f"Bearer {api_config['api_key']}",
                }

                current_prompt = simplified_prompt if simplified_prompt and attempt > 0 else prompt
                # Explicit drawing instruction so chat-style models return an image.
                full_prompt = f"请生成一张图片,并在回复中包含图片URL:{current_prompt}"

                payload = {
                    "model": api_config["model"],
                    "messages": [{"role": "user", "content": full_prompt}],
                    "stream": api_config.get("stream", True),
                }
                modalities = api_config.get("modalities")
                if isinstance(modalities, list) and modalities:
                    payload["modalities"] = modalities

                logger.info(f"NanoImage请求: {url}, 模型: {api_config['model']}, 提示词长度: {len(current_prompt)} 字符")
                logger.debug(f"完整提示词: {current_prompt}")

                client_kwargs = await self._build_client_kwargs(self._build_timeout())
                if client_kwargs is None:
                    return []

                image_url: Optional[str] = None
                image_base64: Optional[str] = None
                full_content = ""
                content_type = ""
                async with httpx.AsyncClient(**client_kwargs) as client:
                    async with client.stream("POST", url, json=payload, headers=headers) as response:
                        status = response.status_code
                        logger.debug(f"收到响应状态码: {status}")

                        if status == 401:
                            logger.error("API Key 认证失败")
                            return []
                        if status != 200:
                            error_text = (await response.aread()).decode("utf-8", errors="replace")
                            logger.error(f"API请求失败: {status}, {error_text[:200]}")
                            continue

                        content_type = (response.headers.get("content-type") or "").lower()
                        logger.info(f"响应 Content-Type: {content_type}")

                        # Lenient detection: anything not explicitly JSON is read as a stream.
                        if "application/json" not in content_type:
                            logger.info("使用流式响应处理模式")
                            image_url, image_base64, full_content = await self._read_sse_response(response)
                        else:
                            # Some gateways return a complete JSON body even with stream=true.
                            image_url, image_base64 = await self._read_json_response(response)

                if not image_url and not image_base64:
                    # Only a truncated summary is logged to keep base64 out of the logs.
                    if full_content:
                        logger.error(f"未能提取到图片,完整响应(截断): {full_content[:500]}")
                    else:
                        logger.error(f"未能提取到图片(content-type={content_type or 'unknown'})")

                    if simplify_on_fail and not simplified_prompt:
                        simplified_prompt = self._simplify_prompt(prompt, simplify_max_tags, simplify_max_chars)
                        if simplified_prompt and simplified_prompt != prompt:
                            logger.warning(f"将使用简化提示词重试: {simplified_prompt}")
                    continue

                # Prefer inline base64 (no second round-trip), then fall back to the URL.
                if image_base64:
                    image_path = await self._save_base64_image(image_base64)
                    if image_path:
                        logger.success("成功生成图像 (base64)")
                        return [image_path]
                    if not image_url:
                        logger.warning(f"base64图片保存失败,将重试 ({attempt + 1}/{max_retry})")
                        continue

                image_path = await self._download_image(image_url)
                if image_path:
                    logger.success("成功生成图像")
                    return [image_path]
                logger.warning(f"图片下载失败,将重试 ({attempt + 1}/{max_retry})")

            except httpx.ReadTimeout:
                logger.warning(f"读取超时,重试中... ({attempt + 1}/{max_retry})")
            except (asyncio.TimeoutError, httpx.TimeoutException):
                # Connect / write / pool timeouts are retried as well.
                logger.warning(f"请求超时,重试中... ({attempt + 1}/{max_retry})")
            except Exception as e:
                logger.error(f"请求异常: {type(e).__name__}: {str(e)}")
                logger.error(f"异常详情:\n{traceback.format_exc()}")

        logger.error("图像生成失败")
        return []

    async def _read_sse_response(self, response: httpx.Response) -> tuple[Optional[str], Optional[str], str]:
        """Consume a streaming (SSE-style) chat-completions body.

        Structured image fields are taken from each ``data:`` event as it
        arrives. ``delta.content`` / ``delta.reasoning_content`` fragments are
        accumulated and scanned once at the end. Scanning each fragment on its
        own would capture a URL or data URI split across chunks half-way.

        If the body contains no ``data:`` line at all, it is re-parsed as a
        plain JSON document. Some gateways send a complete JSON body with a
        non-JSON content-type.

        Returns:
            ``(image_url, image_base64, text)``, where ``text`` is the
            accumulated content followed by the accumulated reasoning.
        """
        image_url: Optional[str] = None
        image_base64: Optional[str] = None
        content_parts: List[str] = []
        reasoning_parts: List[str] = []
        plain_lines: List[str] = []  # non-SSE lines, used only if no "data:" line shows up
        saw_data_line = False

        async for line in response.aiter_lines():
            if not line or line.startswith(("event:", ":", "id:", "retry:")):
                continue
            if not line.startswith("data:"):
                plain_lines.append(line)
                continue

            saw_data_line = True
            data_str = line[5:].lstrip()
            if data_str == "[DONE]":
                break

            try:
                data = json.loads(data_str)
            except ValueError as e:
                # Not JSON: scan the raw event text directly.
                if "http" in data_str or "data:image" in data_str:
                    if not image_url:
                        image_url = self._find_image_url(data_str)
                        if image_url:
                            logger.info(f"从 SSE 文本提取到图片URL: {image_url}")
                    if not image_base64:
                        image_base64 = self._find_data_uri(data_str)
                        if image_base64:
                            logger.info("从 SSE 文本提取到 base64 图片")
                else:
                    logger.warning(f"解析响应数据失败: {e}")
                continue

            if not isinstance(data, dict):
                continue

            url_candidate, b64_candidate = self._extract_image_from_payload(data, include_delta_text=False)
            if url_candidate and not image_url:
                image_url = url_candidate
                logger.info(f"从 SSE JSON 提取到图片URL: {image_url}")
            if b64_candidate and not image_base64:
                image_base64 = b64_candidate
                logger.info("从 SSE JSON 提取到 base64 图片")

            choices = data.get("choices")
            if isinstance(choices, list) and choices and isinstance(choices[0], dict):
                delta = choices[0].get("delta")
                if isinstance(delta, dict):
                    content = delta.get("content")
                    reasoning = delta.get("reasoning_content")
                    if isinstance(content, str) and content:
                        content_parts.append(content)
                    if isinstance(reasoning, str) and reasoning:
                        reasoning_parts.append(reasoning)

        if not saw_data_line and plain_lines and not image_url and not image_base64:
            body = "\n".join(plain_lines)
            try:
                data = json.loads(body)
            except ValueError:
                # Plain text body: let the text scan below handle it.
                content_parts.append(body)
            else:
                if isinstance(data, list):
                    data = {"data": data}
                image_url, image_base64 = self._extract_image_from_payload(data)
                if image_url:
                    logger.info(f"从 JSON payload 提取到图片URL: {image_url}")
                if image_base64:
                    logger.info("从 JSON payload 提取到 base64 图片")

        content_text = "".join(content_parts)
        reasoning_text = "".join(reasoning_parts)
        if not image_url and not image_base64:
            # Prefer the visible answer over the model's reasoning.
            for text in (content_text, reasoning_text):
                image_base64 = self._find_data_uri(text)
                image_url = self._find_image_url(text)
                if image_url or image_base64:
                    break
            if image_url:
                logger.info(f"从完整内容提取到图片URL: {image_url}")
            if image_base64:
                logger.info("从完整内容提取到 base64 图片")

        return image_url, image_base64, content_text + reasoning_text

    async def _read_json_response(self, response: httpx.Response) -> tuple[Optional[str], Optional[str]]:
        """Read a complete JSON response body and extract ``(image_url, image_base64)``.

        Returns ``(None, None)`` when the body is not valid JSON or holds no image.
        A top-level JSON list is treated like an images-API ``data`` array.
        """
        raw = await response.aread()
        try:
            data = json.loads(raw.decode("utf-8", errors="ignore"))
        except ValueError as e:
            logger.error(f"解析 JSON 响应失败: {type(e).__name__}: {e}")
            return None, None

        if isinstance(data, list):
            data = {"data": data}
        image_url, image_base64 = self._extract_image_from_payload(data)
        if image_url:
            logger.info(f"从 JSON payload 提取到图片URL: {image_url}")
        if image_base64:
            logger.info("从 JSON payload 提取到 base64 图片")
        return image_url, image_base64

    # ------------------------------------------------------------------
    # Image editing (images/edits)
    # ------------------------------------------------------------------

    def _decode_base64_image(self, image_base64: str) -> tuple[Optional[bytes], str, str]:
        """Decode a base64 image (data URI or bare base64).

        The format comes from the data-URI header when present, otherwise from
        the file signature, defaulting to PNG.

        Returns:
            ``(bytes, ext, mime)``. ``bytes`` is ``None`` when the input is
            empty, not decodable, or decodes to nothing.
        """
        if not isinstance(image_base64, str) or not image_base64.strip():
            return None, "png", "image/png"

        raw = image_base64.strip()
        payload = raw
        declared_mime: Optional[str] = None
        if raw.startswith("data:image"):
            header, sep, rest = raw.partition(",")
            if sep:
                payload = rest
                mime_match = re.search(r"^data:([^;]+);base64$", header, re.IGNORECASE)
                if mime_match:
                    declared_mime = mime_match.group(1).strip().lower()

        mime_type = declared_mime or "image/png"
        ext = self._ext_from_mime(mime_type, default="png")

        try:
            image_bytes = self._b64decode_lenient(payload)
        except (ValueError, TypeError) as e:
            logger.error(f"解析编辑图片 base64 失败: {e}")
            return None, ext, mime_type
        if not image_bytes:
            return None, ext, mime_type

        if declared_mime is None:
            sniffed = self._sniff_image_type(image_bytes)
            if sniffed:
                ext, mime_type = sniffed
        return image_bytes, ext, mime_type

    async def edit_image(self, prompt: str, image_base64: str) -> List[str]:
        """Edit a reference image with ``prompt`` via the images-edits endpoint.

        Sends a multipart request (an ``image`` file plus form fields). If the
        server answers 400/415/422 and ``edits.allow_json_fallback`` is true
        (the default), the same request is re-sent once as JSON, with the
        original base64 string in ``image``.

        The number of attempts is ``edits.max_retry_attempts``, falling back
        to ``generation.max_retry_attempts`` and then 2, with a minimum of 1.

        Args:
            prompt: Edit instruction.
            image_base64: Reference image as a data URI or bare base64.

        Returns:
            A one-element list with the local result path, or ``[]`` when
            editing is disabled, the reference image cannot be decoded, the
            API key is rejected (401), or all attempts fail.
        """
        edits_config = self.config.get("edits", {})
        if not edits_config.get("enabled", True):
            logger.warning("NanoImage 编辑能力已关闭,跳过 edits 调用")
            return []

        api_config = self.config["api"]
        gen_config = self.config.get("generation", {})
        max_retry = int(edits_config.get("max_retry_attempts", gen_config.get("max_retry_attempts", 2)))
        max_retry = max(max_retry, 1)

        endpoint = str(edits_config.get("endpoint", "/v1/images/edits") or "/v1/images/edits").strip()
        if not endpoint.startswith("/"):
            endpoint = f"/{endpoint}"
        url = f"{api_config['base_url'].rstrip('/')}{endpoint}"

        image_bytes, image_ext, image_mime = self._decode_base64_image(image_base64)
        if not image_bytes:
            logger.error("改图失败:引用图片解析失败")
            return []

        size = str(edits_config.get("size", "") or "").strip()
        response_format = str(edits_config.get("response_format", "") or "").strip()
        image_count = max(int(edits_config.get("n", 1) or 1), 1)
        allow_json_fallback = bool(edits_config.get("allow_json_fallback", True))
        # Optional fields are sent only when configured.
        optional_fields = {
            key: value
            for key, value in (("size", size), ("response_format", response_format))
            if value
        }

        for attempt in range(max_retry):
            if attempt > 0:
                await asyncio.sleep(min(2 ** attempt, 10))

            try:
                headers = {"Authorization": f"Bearer {api_config['api_key']}"}
                form_data = {
                    "model": api_config["model"],
                    "prompt": prompt,
                    "n": str(image_count),
                    **optional_fields,
                }
                json_payload = {
                    "model": api_config["model"],
                    "prompt": prompt,
                    "n": image_count,
                    "image": image_base64,
                    **optional_fields,
                }

                client_kwargs = await self._build_client_kwargs(self._build_timeout())
                if client_kwargs is None:
                    return []

                logger.info(f"NanoImage 改图请求: {url}, 模型: {api_config['model']}, 提示词长度: {len(prompt)}")

                async with httpx.AsyncClient(**client_kwargs) as client:
                    response = await client.post(
                        url,
                        data=form_data,
                        files={"image": (f"edit_source.{image_ext}", image_bytes, image_mime)},
                        headers=headers,
                    )
                    # Some gateways only accept JSON bodies on the edits endpoint.
                    if allow_json_fallback and response.status_code in (400, 415, 422):
                        logger.warning(f"edits multipart 失败({response.status_code}),尝试 JSON 回退")
                        response = await client.post(
                            url,
                            json=json_payload,
                            headers={**headers, "Content-Type": "application/json"},
                        )

                logger.debug(f"NanoImage edits 状态码: {response.status_code}")

                if response.status_code != 200:
                    err_text = (response.text or "")[:300]
                    logger.error(f"edits 请求失败: {response.status_code}, {err_text}")
                    if response.status_code == 401:
                        # A rejected key will not start working on retry.
                        return []
                    continue

                try:
                    payload = response.json()
                except ValueError as e:
                    logger.warning(f"edits 响应非 JSON: {e}")
                    payload = None
                if isinstance(payload, list):
                    payload = {"data": payload}

                image_url, result_base64 = self._extract_image_from_payload(payload)
                if not image_url and not result_base64:
                    # Last resort: scan the raw body text.
                    body_text = response.text or ""
                    image_url = self._find_image_url(body_text)
                    if not image_url:
                        result_base64 = self._find_data_uri(body_text)

                if result_base64:
                    image_path = await self._save_base64_image(result_base64)
                    if image_path:
                        logger.success("NanoImage 改图成功 (base64)")
                        return [image_path]

                if image_url:
                    image_path = await self._download_image(image_url)
                    if image_path:
                        logger.success("NanoImage 改图成功 (url)")
                        return [image_path]

                logger.error("edits 响应中未解析到图片结果")

            except httpx.ReadTimeout:
                logger.warning(f"改图读取超时,重试中... ({attempt + 1}/{max_retry})")
            except (asyncio.TimeoutError, httpx.TimeoutException):
                logger.warning(f"改图请求超时,重试中... ({attempt + 1}/{max_retry})")
            except Exception as e:
                logger.error(f"改图请求异常: {type(e).__name__}: {e}")

        logger.error("NanoImage 改图失败")
        return []

    # ------------------------------------------------------------------
    # Response parsing helpers
    # ------------------------------------------------------------------

    def _extract_image_from_payload(
        self, data: dict, include_delta_text: bool = True
    ) -> tuple[Optional[str], Optional[str]]:
        """Extract ``(image_url, image_base64)`` from a parsed response payload.

        Inspects, in order:

        1. ``data[0]`` (images-API style): ``b64_json`` / ``image_base64`` /
           ``url`` / ``image_url``.
        2. Every ``choices[*].delta`` and ``choices[*].message``:
           ``images[0]``, direct ``image_url`` / ``url``, ``b64_json`` /
           ``image_base64``, ``content`` (a string or a list of content parts),
           and ``tool_calls`` arguments.

        Structured fields found later override earlier ones. Free text fills a
        slot only while that slot is still empty. References starting with
        ``data:image`` are treated as base64; other strings are treated as URLs.

        Args:
            data: Decoded JSON object. Non-dict input yields ``(None, None)``.
            include_delta_text: When False, text in ``delta.content`` is not
                scanned. Streaming deltas hold fragments, so a URL split
                across chunks would otherwise be captured half-way. Streaming
                callers should accumulate the text and scan it once.

        Returns:
            ``(image_url, image_base64)``; either may be ``None``.
        """
        if not isinstance(data, dict):
            return None, None

        image_url: Optional[str] = None
        image_base64: Optional[str] = None

        def absorb(ref) -> None:
            """Record a URL / data-URI reference, overriding earlier finds."""
            nonlocal image_url, image_base64
            found_url, found_b64 = self._classify_image_ref(ref)
            if found_url:
                image_url = found_url
            if found_b64:
                image_base64 = found_b64

        def absorb_text(text) -> None:
            """Scan free text, filling only the slots that are still empty."""
            nonlocal image_url, image_base64
            if not image_base64:
                image_base64 = self._find_data_uri(text)
            if not image_url:
                image_url = self._find_image_url(text)

        items = data.get("data")
        if isinstance(items, list) and items and isinstance(items[0], dict):
            first = items[0]
            raw_b64 = first.get("b64_json") or first.get("image_base64")
            if isinstance(raw_b64, str) and raw_b64:
                image_base64 = raw_b64
            absorb(first.get("url") or first.get("image_url"))

        choices = data.get("choices")
        for choice in choices if isinstance(choices, list) else []:
            if not isinstance(choice, dict):
                continue

            for container_name in ("delta", "message"):
                container = choice.get(container_name)
                if not isinstance(container, dict):
                    continue
                scan_text = include_delta_text or container_name != "delta"

                images = container.get("images")
                if isinstance(images, list) and images:
                    img0 = images[0]
                    absorb((img0.get("image_url") or img0.get("url")) if isinstance(img0, dict) else img0)

                absorb(container.get("image_url") or container.get("url"))

                raw_b64 = container.get("b64_json") or container.get("image_base64")
                if isinstance(raw_b64, str) and raw_b64:
                    image_base64 = raw_b64

                content = container.get("content")
                if isinstance(content, str):
                    if scan_text and content:
                        absorb_text(content)
                elif isinstance(content, list):
                    # Multimodal content parts: {"type": "image_url", ...} / {"type": "text", ...}
                    for part in content:
                        if not isinstance(part, dict):
                            continue
                        ref = part.get("image_url") or part.get("url")
                        if ref:
                            absorb(ref)
                        elif scan_text and isinstance(part.get("text"), str):
                            absorb_text(part["text"])

                tool_calls = container.get("tool_calls")
                if isinstance(tool_calls, list):
                    for call in tool_calls:
                        args = self._parse_tool_arguments(call)
                        for key in ("url", "image_url", "image", "b64_json", "image_base64"):
                            val = args.get(key)
                            if isinstance(val, dict):
                                val = val.get("url")
                            if not isinstance(val, str) or not val:
                                continue
                            if val.startswith("data:image"):
                                image_base64 = val
                            elif val.startswith("http"):
                                image_url = val
                            elif key in ("b64_json", "image_base64"):
                                # Bare base64 under an explicitly base64-named key.
                                image_base64 = val

        return image_url, image_base64

    @staticmethod
    def _classify_image_ref(ref) -> tuple[Optional[str], Optional[str]]:
        """Classify an image reference as ``(url, None)`` or ``(None, data_uri)``.

        ``ref`` may be a string or a ``{"url": ...}`` dict. Anything else, or
        an empty string, yields ``(None, None)``.
        """
        if isinstance(ref, dict):
            ref = ref.get("url")
        if not isinstance(ref, str) or not ref:
            return None, None
        if ref.startswith("data:image"):
            return None, ref
        return ref, None

    @staticmethod
    def _parse_tool_arguments(call) -> dict:
        """Return a tool call's arguments as a dict.

        Returns ``{}`` when the arguments are absent or are not valid JSON
        (e.g. a partial streaming fragment).
        """
        if not isinstance(call, dict):
            return {}
        args = call.get("arguments")
        if not args:
            func = call.get("function")
            if isinstance(func, dict):
                args = func.get("arguments")
        if isinstance(args, str):
            try:
                args = json.loads(args)
            except ValueError:
                return {}
        return args if isinstance(args, dict) else {}

    @classmethod
    def _find_image_url(cls, text) -> Optional[str]:
        """Return the first http(s) URL in ``text`` (trailing punctuation stripped), or ``None``."""
        if not isinstance(text, str) or "http" not in text:
            return None
        match = cls._URL_PATTERN.search(text)
        if not match:
            return None
        return match.group(0).rstrip("'\".,;:!?") or None

    @classmethod
    def _find_data_uri(cls, text) -> Optional[str]:
        """Return the first complete ``data:image/...;base64,...`` URI in ``text``, or ``None``."""
        if not isinstance(text, str) or "data:image" not in text:
            return None
        match = cls._DATA_URI_PATTERN.search(text)
        return match.group(0) if match else None

    @classmethod
    def _sniff_image_type(cls, image_bytes: bytes) -> Optional[tuple[str, str]]:
        """Guess ``(ext, mime)`` from the image file signature; ``None`` if unrecognized."""
        for signature, ext, mime in cls._IMAGE_SIGNATURES:
            if image_bytes.startswith(signature):
                return ext, mime
        if image_bytes[:4] == b"RIFF" and image_bytes[8:12] == b"WEBP":
            return "webp", "image/webp"
        return None

    @staticmethod
    def _ext_from_mime(mime_type: str, default: str = "png") -> str:
        """Map a MIME type (or data-URI header) to a file extension, else ``default``."""
        mime_type = (mime_type or "").lower()
        if "jpeg" in mime_type or "jpg" in mime_type:
            return "jpg"
        for ext in ("webp", "gif", "png"):
            if ext in mime_type:
                return ext
        return default

    @staticmethod
    def _b64decode_lenient(payload: str) -> bytes:
        """Decode base64 text, ignoring whitespace and restoring missing ``=`` padding.

        Raises:
            ValueError: (``binascii.Error``) when the text is not valid base64.
        """
        compact = "".join(payload.split())
        compact += "=" * (-len(compact) % 4)
        return base64.b64decode(compact)

    def _simplify_prompt(self, prompt: str, max_tags: int, max_chars: int) -> str:
        """Shorten a prompt for a retry after a response without an image.

        Prompts made of comma-separated tags (ASCII ``,`` or full-width ``,``)
        keep their first ``max_tags`` non-empty tags, joined with ``", "``.
        Other prompts are truncated to ``max_chars`` characters.
        Non-string input returns ``""``.
        """
        if not isinstance(prompt, str):
            return ""
        tags = [tag.strip() for tag in re.split(r"[,,]", prompt) if tag.strip()]
        if len(tags) >= 2:
            return ", ".join(tags[:max_tags])
        return prompt.strip()[:max_chars]

    # ------------------------------------------------------------------
    # Networking and file helpers
    # ------------------------------------------------------------------

    def _build_timeout(self) -> httpx.Timeout:
        """Build the API timeout: read = ``api.timeout`` capped at 600 s, others 10 s."""
        max_timeout = min(self.config["api"]["timeout"], 600)
        return httpx.Timeout(connect=10.0, read=max_timeout, write=10.0, pool=10.0)

    async def _build_client_kwargs(self, timeout: httpx.Timeout) -> Optional[dict]:
        """Return ``httpx.AsyncClient`` kwargs (timeout, ``trust_env``, optional proxy).

        Returns ``None`` when a SOCKS proxy is configured but ``socksio`` is
        not installed, so callers can abort instead of failing on every retry.
        """
        proxy = await self._get_proxy()
        if proxy and not self._ensure_socksio(proxy):
            return None
        client_kwargs = {"timeout": timeout, "trust_env": True}
        if proxy:
            client_kwargs["proxy"] = proxy
        return client_kwargs

    async def _get_proxy(self) -> Optional[str]:
        """Return the proxy URL to use, or ``None``.

        First tries the ``[proxy]`` section of the sibling plugin config
        ``../AIChat/config.toml`` (used only when ``enabled`` is true). Then it
        falls back to the first non-empty of ``HTTPS_PROXY``, ``https_proxy``,
        ``HTTP_PROXY``, ``http_proxy``, ``ALL_PROXY``, ``all_proxy``.
        """
        try:
            aichat_config_path = Path(__file__).parent.parent / "AIChat" / "config.toml"
            if aichat_config_path.exists():
                with open(aichat_config_path, "rb") as f:
                    aichat_config = tomllib.load(f)

                proxy_config = aichat_config.get("proxy", {})
                if proxy_config.get("enabled", False):
                    proxy_type = proxy_config.get("type", "socks5")
                    proxy_host = proxy_config.get("host", "127.0.0.1")
                    proxy_port = proxy_config.get("port", 7890)
                    proxy = f"{proxy_type}://{proxy_host}:{proxy_port}"
                    logger.debug(f"使用代理: {proxy}")
                    return proxy
        except Exception as e:
            # Best effort: a broken AIChat config must not block image requests.
            logger.warning(f"读取代理配置失败: {e}")

        # Fall back to environment proxies (global / system proxy setups).
        for key in ("HTTPS_PROXY", "https_proxy", "HTTP_PROXY", "http_proxy", "ALL_PROXY", "all_proxy"):
            proxy = (os.environ.get(key) or "").strip()
            if proxy:
                logger.debug(f"使用环境变量代理: {key}={proxy}")
                return proxy
        return None

    def _new_image_path(self, ext: str) -> Path:
        """Return a unique ``nano_<YYYYmmdd_HHMMSS>_<8 hex>.<ext>`` path in ``images_dir``."""
        ts = datetime.now().strftime("%Y%m%d_%H%M%S")
        uid = uuid.uuid4().hex[:8]
        return self.images_dir / f"nano_{ts}_{uid}.{ext}"

    async def _save_base64_image(self, base64_data: str) -> Optional[str]:
        """Decode a base64 image and write it into ``images_dir``.

        Accepts a data URI or bare base64. The extension comes from the
        data-URI header, else from the file signature, else ``jpg``.

        Returns:
            The saved path as ``str``, or ``None`` on failure (including
            empty decoded data).
        """
        try:
            base64_data = base64_data.strip()
            if base64_data.startswith("data:image"):
                # Split "data:image/<fmt>;base64,<payload>".
                header, data = base64_data.split(",", 1)
                ext: Optional[str] = self._ext_from_mime(header, default="jpg")
            else:
                data = base64_data
                ext = None

            image_bytes = self._b64decode_lenient(data)
            if not image_bytes:
                logger.error("保存base64图片失败: 解码结果为空")
                return None
            if ext is None:
                sniffed = self._sniff_image_type(image_bytes)
                ext = sniffed[0] if sniffed else "jpg"

            file_path = self._new_image_path(ext)
            # Write off the event loop; images can be several MB.
            await asyncio.to_thread(file_path.write_bytes, image_bytes)

            logger.info(f"base64图片保存成功: {file_path}")
            return str(file_path)
        except Exception as e:
            logger.error(f"保存base64图片失败: {e}")
            logger.error(traceback.format_exc())
            return None

    async def _download_image(self, url: str) -> Optional[str]:
        """Download ``url`` into ``images_dir``.

        The extension comes from the file signature, then the
        ``Content-Type`` header, defaulting to ``jpg``. Empty bodies, and
        text/JSON bodies without an image signature (e.g. an HTML error page),
        are rejected.

        Returns:
            The saved path as ``str``, or ``None`` on any failure.
        """
        try:
            timeout = httpx.Timeout(connect=10.0, read=30.0, write=10.0, pool=10.0)
            client_kwargs = await self._build_client_kwargs(timeout)
            if client_kwargs is None:
                return None

            headers = {
                "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/122.0.0.0 Safari/537.36",
                "Accept": "image/avif,image/webp,image/apng,image/*,*/*;q=0.8",
                "Accept-Language": "zh-CN,zh;q=0.9",
            }

            async with httpx.AsyncClient(**client_kwargs) as client:
                response = await client.get(url, headers=headers, follow_redirects=True)
                response.raise_for_status()

            image_bytes = response.content
            content_type = (response.headers.get("content-type") or "").lower()
            sniffed = self._sniff_image_type(image_bytes)
            if sniffed:
                ext = sniffed[0]
            elif not image_bytes or content_type.startswith(("text/", "application/json")):
                logger.error(f"下载图片失败: 响应不是图片 (content-type={content_type or 'unknown'}, {len(image_bytes)} bytes)")
                return None
            else:
                ext = self._ext_from_mime(content_type, default="jpg")

            file_path = self._new_image_path(ext)
            await asyncio.to_thread(file_path.write_bytes, image_bytes)

            logger.info(f"图片下载成功: {file_path}")
            return str(file_path)
        except Exception as e:
            logger.error(f"下载图片失败: {type(e).__name__}: {e!r}")
            return None

    def _ensure_socksio(self, proxy_url: str) -> bool:
        """Return False (and log how to fix it) for a SOCKS proxy without ``socksio`` installed."""
        try:
            scheme = urlparse(proxy_url).scheme.lower()
        except Exception:
            return True

        if scheme.startswith("socks"):
            try:
                import socksio  # noqa: F401
            except Exception:
                logger.error("检测到 SOCKS 代理,但未安装 socksio。请执行: pip install socksio")
                return False
        return True

    # ------------------------------------------------------------------
    # Entry points: text command and LLM tool
    # ------------------------------------------------------------------

    @on_text_message(priority=70)
    async def handle_message(self, bot: WechatHookClient, message: dict):
        """Handle ``<keyword> <prompt>`` drawing commands in text messages.

        Returns ``True`` for messages this plugin ignores (commands disabled,
        chat type disabled, no keyword match), and ``False`` once it has
        replied. NOTE(review): presumably ``False`` stops further handlers
        from processing the message; confirm against the dispatcher.
        """
        behavior = self.config["behavior"]
        if not behavior["enable_command"]:
            return True

        content = (message.get("Content") or "").strip()
        from_wxid = message.get("FromWxid", "")
        is_group = message.get("IsGroup", False)

        # Respect the per-chat-type switches.
        if is_group and not behavior["enable_group"]:
            return True
        if not is_group and not behavior["enable_private"]:
            return True

        # A keyword matches on its own or when followed by a space.
        matched_keyword = next(
            (kw for kw in behavior["command_keywords"] if content == kw or content.startswith(kw + " ")),
            None,
        )
        if not matched_keyword:
            return True

        prompt = content[len(matched_keyword):].strip()
        if not prompt:
            await bot.send_text(from_wxid, f"❌ 请提供绘图提示词\n用法: {matched_keyword} <提示词>")
            return False

        logger.info(f"收到绘图请求: {prompt[:50]}...")

        try:
            image_paths = await self.generate_image(prompt)
            if image_paths:
                await bot.send_image(from_wxid, image_paths[0])
                logger.success("绘图成功,已发送图片")
            else:
                await bot.send_text(from_wxid, "❌ 图像生成失败,请稍后重试")
        except Exception as e:
            logger.error(f"绘图处理失败: {e}")
            await bot.send_text(from_wxid, f"❌ 处理失败: {str(e)}")

        return False

    def get_llm_tools(self) -> List[dict]:
        """Return the OpenAI function-tool schema for this plugin (``[]`` when disabled)."""
        tool_config = self.config["llm_tool"]
        if not tool_config["enabled"]:
            return []

        description = (
            str(tool_config.get("tool_description") or "").strip()
            or "AI绘图/改图工具:根据提示词生成图片;提供引用图片时按提示词编辑该图片"
        )
        return [{
            "type": "function",
            "function": {
                "name": tool_config["tool_name"],
                "description": description,
                "parameters": {
                    "type": "object",
                    "properties": {
                        "prompt": {
                            "type": "string",
                            "description": "绘图或改图的提示词,描述期望的画面内容",
                        },
                        "image_base64": {
                            "type": "string",
                            "description": "可选:引用图片(data:image/... 或纯 base64);提供时调用 /v1/images/edits 改图",
                        },
                    },
                    "required": ["prompt"],
                    "additionalProperties": False,
                },
            },
        }]

    async def execute_llm_tool(self, tool_name: str, arguments: dict, bot: WechatHookClient, from_wxid: str) -> Optional[dict]:
        """Run the drawing/editing tool for an LLM tool call and send the result image.

        With ``image_base64`` the reference image is edited. If editing fails
        and ``edits.fallback_to_generation`` is true, it falls back to plain
        generation. Without ``image_base64`` it generates from the prompt.

        Returns:
            ``None`` when ``tool_name`` is not the configured tool name.
            Otherwise a dict with ``success`` and ``message``, plus ``images``
            on success.
        """
        expected_tool_name = self.config["llm_tool"]["tool_name"]
        if tool_name != expected_tool_name:
            return None

        try:
            if not isinstance(arguments, dict):
                arguments = {}
            prompt = str(arguments.get("prompt") or "").strip()
            image_base64 = arguments.get("image_base64")
            image_base64 = image_base64.strip() if isinstance(image_base64, str) else ""

            if not prompt:
                return {"success": False, "message": "提示词不能为空"}

            if image_base64:
                logger.info(f"LLM工具改图请求: {prompt[:50]}...")
                image_paths = await self.edit_image(prompt, image_base64)
                if not image_paths and self.config.get("edits", {}).get("fallback_to_generation", False):
                    logger.warning("改图失败,回退到文生图模式")
                    image_paths = await self.generate_image(prompt)
            else:
                logger.info(f"LLM工具绘图请求: {prompt[:50]}...")
                image_paths = await self.generate_image(prompt)

            if image_paths:
                await bot.send_image(from_wxid, image_paths[0])
                return {
                    "success": True,
                    "message": "图片已生成并发送",
                    "images": [image_paths[0]],
                }
            return {"success": False, "message": "图像生成失败"}

        except Exception as e:
            logger.error(f"LLM工具执行失败: {e}")
            return {"success": False, "message": f"执行失败: {str(e)}"}