import asyncio
import base64
import os
import time
import urllib.parse
import uuid
from pathlib import Path
from typing import Dict, Any, List, Optional, Tuple

import requests
from loguru import logger

from base.plugin_common.message_plugin_interface import MessagePluginInterface
from base.plugin_common.plugin_interface import PluginStatus
from utils.ai.llm_registry import LLMRegistry
from utils.decorator.plugin_decorators import plugin_stats_decorator
from utils.robot_cmd.robot_command import Feature, PermissionStatus, GroupBotManager
from utils.decorator.points_decorator import plugin_points_cost
from wechat_ipad import WechatAPIClient


class AIGenImagePlugin(MessagePluginInterface):
    """AI image generation plugin.

    Matches a drawing command at the start of a chat message, generates an
    image from the remaining text and sends it back to the chat. Images are
    requested from an OpenAI-compatible images endpoint when a base URL and
    API key are configured; otherwise the legacy direct-URL endpoint is used.
    """

    # Feature-permission constants.
    FEATURE_KEY = "AI_GEN_IMAGE"
    FEATURE_DESCRIPTION = "🎨 AI绘图功能 [AI绘图, 绘图, 画图, 生成图片]"

    # Substrings that mark a config key as sensitive when logging.
    _SECRET_KEY_MARKERS = ("key", "token", "secret", "password")

    @property
    def name(self) -> str:
        return "AI绘图"

    @property
    def version(self) -> str:
        return "1.1.0"

    @property
    def description(self) -> str:
        return "提供AI绘图功能,支持通过项目统一 LLM 配置路由到 OpenAI 兼容图片接口"

    @property
    def author(self) -> str:
        return "liu.wei"

    @property
    def command_prefix(self) -> Optional[str]:
        return ""  # No prefix: the command word is matched directly.

    @property
    def commands(self) -> List[str]:
        return self._commands

    @property
    def feature_key(self) -> Optional[str]:
        return self.FEATURE_KEY

    @property
    def feature_description(self) -> Optional[str]:
        return self.FEATURE_DESCRIPTION

    def __init__(self):
        super().__init__()
        self.feature = self.register_feature()

    def initialize(self, context: Dict[str, Any]) -> bool:
        """Load plugin configuration and prepare the temp directory.

        Args:
            context: Runtime context; only ``event_system`` is read here.

        Returns:
            ``True`` once initialization has completed.
        """
        self.LOG = logger
        self.LOG.debug(f"正在初始化 {self.name} 插件...")

        # Keep the context object handed over by the host.
        self.event_system = context.get("event_system")

        # Read the plugin section once; tolerate an explicit null section the
        # same way the nested ``llm`` section is tolerated below.
        plugin_config = self._config.get("AIGenImage", {}) or {}

        # Command and switch settings keep their historical keys.
        self._commands = plugin_config.get("command", ["AI绘图", "绘图", "画图", "生成图片"])
        self.command_format = plugin_config.get("command-format", "AI绘图 描述文字")
        self.enable = plugin_config.get("enable", True)

        # Legacy direct-URL setting, used when no gateway is configured.
        self.image_api_url = plugin_config.get(
            "image_api_url", "https://image.pollinations.ai/prompt/{prompt}"
        )

        # Basic image parameters.
        self.default_width = int(plugin_config.get("default_width", 1024))
        self.default_height = int(plugin_config.get("default_height", 1024))
        self.default_timeout = int(plugin_config.get("default_timeout", 300))

        # Image model settings:
        # 1. an explicitly configured plugin model wins;
        # 2. otherwise "gpt-image-1" is the default;
        # 3. the legacy endpoint keeps its own model field.
        self.default_model = str(plugin_config.get("default_model", "gpt-image-1")).strip()
        self.legacy_model = str(plugin_config.get("legacy_model", "turbo")).strip()
        self.image_quality = str(plugin_config.get("image_quality", "standard")).strip()
        self.image_size = str(
            plugin_config.get("image_size", f"{self.default_width}x{self.default_height}")
        ).strip()
        self.image_count = max(int(plugin_config.get("image_count", 1) or 1), 1)
        self.image_response_format = str(
            plugin_config.get("image_response_format", "b64_json")
        ).strip()

        # Connection/auth settings, resolved from the nested ``llm`` section,
        # then flat plugin keys, then environment variables.
        llm_config = plugin_config.get("llm", {}) or {}
        self.llm_scene = str(
            llm_config.get("scene") or plugin_config.get("llm_scene") or ""
        ).strip()
        self.image_api_base_url = str(
            llm_config.get("api_base_url")
            or llm_config.get("base_url")
            or plugin_config.get("image_api_base_url")
            or os.getenv("AIGENIMAGE_API_BASE_URL", "")
            or ""
        ).strip().rstrip("/")
        self.image_api_endpoint = str(
            llm_config.get("image_endpoint")
            or plugin_config.get("image_api_endpoint")
            or "images/generations"
        ).strip()
        self.image_provider = "openai_compatible"
        self.image_api_key = str(
            llm_config.get("api_key")
            or plugin_config.get("image_api_key")
            or os.getenv("AIGENIMAGE_API_KEY", "")
        ).strip()

        # When a scene is declared, values resolved from the global LLM
        # registry take precedence over the plugin-local ones.
        if self.llm_scene:
            resolved_llm_config = LLMRegistry.resolve({"scene": self.llm_scene}) or {}
            # The resolved config carries credentials; log a redacted copy.
            self.LOG.debug(
                f"[{self.name}] llm scene 解析结果: scene={self.llm_scene}, "
                f"config={self._redact_secrets(resolved_llm_config)}"
            )

            # Only gateway address, auth and timeout are reused; the image
            # endpoint stays as configured above.
            self.image_provider = str(
                resolved_llm_config.get("provider") or self.image_provider
            ).strip().lower()
            self.image_api_base_url = str(
                resolved_llm_config.get("api_base_url")
                or resolved_llm_config.get("base_url")
                or self.image_api_base_url
            ).strip().rstrip("/")
            self.image_api_key = str(
                resolved_llm_config.get("api_key") or self.image_api_key
            ).strip()
            self.default_timeout = int(
                resolved_llm_config.get("timeout_seconds")
                or resolved_llm_config.get("request_timeout")
                or self.default_timeout
            )

            # Inherit the scene's model only when the plugin did not set one.
            if not plugin_config.get("default_model"):
                self.default_model = str(
                    resolved_llm_config.get("model") or self.default_model
                ).strip()

        # Make sure the temp directory (three levels above this file) exists.
        self.temp_dir = os.path.join(
            os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))),
            'temp',
        )
        os.makedirs(self.temp_dir, exist_ok=True)

        self.LOG.debug(
            f"[{self.name}] 插件初始化完成,指令:{self._commands},"
            f"llm_scene={self.llm_scene or '-'},image_api_base_url={self.image_api_base_url or '-'},"
            f"image_api_endpoint={self.image_api_endpoint},provider={self.image_provider}"
        )
        return True

    def start(self) -> bool:
        """Mark the plugin as running."""
        self.LOG.debug(f"[{self.name}] 插件已启动")
        self.status = PluginStatus.RUNNING
        return True

    def stop(self) -> bool:
        """Mark the plugin as stopped."""
        self.LOG.info(f"[{self.name}] 插件已停止")
        self.status = PluginStatus.STOPPED
        return True

    def can_process(self, message: Dict[str, Any]) -> bool:
        """Return whether the first space-separated word is a known command."""
        if not self.enable:
            return False

        content = str(message.get("content", "")).strip()
        command = content.split(" ")[0]
        return command in self._commands

    @plugin_stats_decorator(plugin_name="AI绘图")
    @plugin_points_cost(20, "AI绘图消耗积分", FEATURE_KEY)
    async def process_message(self, message: Dict[str, Any]) -> Tuple[bool, Optional[str]]:
        """Handle a drawing command and reply with the generated image.

        Args:
            message: Incoming message; reads ``content``, ``sender``,
                ``roomid``, ``gbm`` and ``bot``.

        Returns:
            ``(success, detail)`` where ``detail`` is a short status text.
        """
        content = str(message.get("content", "")).strip()
        self.LOG.debug(f"插件执行: {self.name}:{content}")
        command = content.split(" ")[0]

        sender = message.get("sender")
        roomid = message.get("roomid", "")
        gbm: GroupBotManager = message.get("gbm")
        bot: WechatAPIClient = message.get("bot")

        # Replies go to the group when there is one, otherwise to the sender.
        target = roomid if roomid else sender

        # A bare command without a description is a format error.
        if len(content.split(" ")) == 1:
            await bot.send_text_message(target, f"❌命令格式错误!\n{self.command_format}", sender)
            return False, "命令格式错误"

        # Group-level permission check.
        if roomid and gbm.get_group_permission(roomid, self.feature) == PermissionStatus.DISABLED:
            return False, "没有权限"

        # Everything after the command word is the prompt.
        prompt = content[len(command):].strip()

        try:
            await bot.send_text_message(target, "🎨正在生成图片,请稍候...", sender)

            # _generate_image performs blocking HTTP calls (timeouts can be
            # minutes long); run it in the default executor so the event loop
            # keeps serving other coroutines meanwhile.
            loop = asyncio.get_running_loop()
            image_path = await loop.run_in_executor(None, self._generate_image, prompt)

            if not image_path or not os.path.exists(image_path):
                await bot.send_text_message(target, "❌生成图片失败,请重试", sender)
                return False, "生成图片失败"

            await bot.send_image_message(target, Path(image_path))
            return True, "发送成功"

        except Exception as e:
            self.LOG.error(f"处理AI绘图请求出错: {e}")
            await bot.send_text_message(target, f"❌生成图片出错: {str(e)}", sender)
            return False, f"处理出错: {e}"

    def _generate_image(self, prompt: str) -> str:
        """Generate an image and return its file path.

        Returns:
            The saved image path, or ``""`` when generation failed (the error
            is logged rather than raised).
        """
        try:
            # The gateway route needs the provider, a base URL and an API key.
            use_gateway = bool(
                self.image_provider == "openai_compatible"
                and self.image_api_base_url
                and self.image_api_key
            )
            self.LOG.info(
                f"正在生成图片,提示词: {prompt[:30]}...,"
                f"route={'llm' if use_gateway else 'legacy'}"
            )

            if use_gateway:
                return self._generate_image_via_openai_compatible(prompt)

            # No gateway configured: fall back to the legacy endpoint.
            return self._generate_image_via_legacy_pollinations(prompt)
        except Exception as e:
            self.LOG.error(f"生成图片出错: {e}")
            return ""

    def _generate_image_via_openai_compatible(self, prompt: str) -> str:
        """Generate an image through the OpenAI-compatible images endpoint.

        Raises:
            requests.HTTPError: On a non-success HTTP status.
            ValueError: When the response carries no usable image data.
        """
        headers = {
            "Content-Type": "application/json",
            "Authorization": self._build_auth_header(self.image_api_key),
        }

        # NOTE(review): some upstream models may reject `response_format` or
        # certain `quality` values — confirm against the configured gateway.
        payload = {
            "model": self.default_model,
            "prompt": prompt,
            "n": self.image_count,
            "size": self.image_size,
            "quality": self.image_quality,
            "response_format": self.image_response_format,
            # Fixed caller tag sent with every request.
            "user": "abot_ai_gen_image",
        }

        request_url = self._join_url(self.image_api_base_url, self.image_api_endpoint)
        response = requests.post(
            request_url, headers=headers, json=payload, timeout=self.default_timeout
        )
        response.raise_for_status()

        response_json = response.json() or {}
        image_bytes = self._extract_image_bytes_from_response(response_json)
        if not image_bytes:
            raise ValueError(f"图片接口未返回可用图片数据: {response_json}")

        image_path = self._save_image_bytes(image_bytes, "png")
        self.LOG.info(f"图片生成成功(OpenAI兼容接口),保存至: {image_path}")
        return image_path

    def _generate_image_via_legacy_pollinations(self, prompt: str) -> str:
        """Generate an image through the legacy direct-URL endpoint.

        Raises:
            requests.HTTPError: On a non-success HTTP status.
        """
        params = {
            "width": self.default_width,
            "height": self.default_height,
            "model": self.legacy_model,
            # Time-derived seed so repeated prompts get different seeds.
            "seed": int(time.time()) % 1000000,
            "nologo": "true",
        }

        encoded_prompt = urllib.parse.quote(prompt)
        url = self.image_api_url.format(prompt=encoded_prompt)

        response = requests.get(url, params=params, timeout=self.default_timeout)
        response.raise_for_status()

        image_path = self._save_image_bytes(response.content, "jpg")
        self.LOG.info(f"图片生成成功(旧版回退接口),保存至: {image_path}")
        return image_path

    def _extract_image_bytes_from_response(self, response_json: Dict[str, Any]) -> bytes:
        """Extract raw image bytes from the first item of ``data``.

        Base64 fields are preferred; otherwise a URL field is downloaded.
        ``data:`` URIs are decoded locally in either field.

        Returns:
            The image bytes, or ``b""`` when no usable field is present.
        """
        data_list = response_json.get("data") or []
        if not data_list:
            return b""

        first_item = data_list[0] or {}
        if not isinstance(first_item, dict):
            # Unexpected item shape: report "no image" instead of crashing.
            return b""

        b64_content = (
            first_item.get("b64_json")
            or first_item.get("image_base64")
            or first_item.get("base64")
            or ""
        )
        if b64_content:
            return self._decode_base64_image(b64_content)

        image_url = str(first_item.get("url") or first_item.get("image_url") or "").strip()
        if image_url.startswith("data:"):
            # Inline data URI: decodable locally, not fetchable over HTTP.
            header, _, _ = image_url.partition(",")
            if header.endswith(";base64"):
                return self._decode_base64_image(image_url)
            return b""
        if image_url:
            download_response = requests.get(image_url, timeout=self.default_timeout)
            download_response.raise_for_status()
            return download_response.content

        return b""

    def _save_image_bytes(self, image_bytes: bytes, extension: str) -> str:
        """Write the bytes to a uniquely named file in the temp directory.

        Returns:
            The path of the written file.
        """
        image_filename = f"ai_image_{uuid.uuid4().hex[:8]}.{extension}"
        image_path = os.path.join(self.temp_dir, image_filename)
        with open(image_path, 'wb') as file_obj:
            file_obj.write(image_bytes)
        return image_path

    @staticmethod
    def _decode_base64_image(encoded: str) -> bytes:
        """Decode base64 text, dropping a leading ``data:...,`` header if any."""
        text = encoded.strip() if isinstance(encoded, str) else encoded
        if isinstance(text, str) and text.startswith("data:") and "," in text:
            text = text.split(",", 1)[1]
        return base64.b64decode(text)

    @classmethod
    def _redact_secrets(cls, config: Any) -> Any:
        """Return a copy of ``config`` that is safe to log.

        Values whose key name contains a secret marker are replaced by
        ``"***"``; a config that cannot be viewed as a dict is summarized.
        """
        try:
            items = dict(config).items()
        except (TypeError, ValueError):
            return "<unprintable config>"
        return {
            key: (
                "***"
                if value and any(marker in str(key).lower() for marker in cls._SECRET_KEY_MARKERS)
                else value
            )
            for key, value in items
        }

    @staticmethod
    def _join_url(base_url: str, endpoint: str) -> str:
        """Join base URL and endpoint; an absolute endpoint is returned as is."""
        endpoint = str(endpoint or "").strip()
        if endpoint.startswith(("http://", "https://")):
            return endpoint
        return f"{str(base_url or '').rstrip('/')}/{endpoint.lstrip('/')}"

    @staticmethod
    def _build_auth_header(api_key: str) -> str:
        """Build the Authorization value, keeping an existing Bearer prefix."""
        normalized_api_key = str(api_key or "").strip()
        if normalized_api_key.lower().startswith("bearer "):
            return normalized_api_key
        return f"Bearer {normalized_api_key}"