"""
Jimeng AI image-generation plugin.

Supports two entry points: chat commands (keyword + prompt) and LLM tool calls.
"""

import asyncio
import re
import tomllib
import uuid
from datetime import datetime
from pathlib import Path
from typing import List, Optional

import aiohttp
from loguru import logger

from utils.plugin_base import PluginBase
from utils.decorators import on_text_message
from WechatHook import WechatHookClient

# Matches an https URL up to the next whitespace or closing parenthesis, so that
# markdown image syntax such as ![alt](https://...) yields a clean URL.
_URL_PATTERN = re.compile(r'https://[^\s\)]+')


class TokenState:
    """Track which API token in a list is currently selected."""

    def __init__(self):
        self.token_index = 0
        self._lock = asyncio.Lock()

    async def get_next_token(self, tokens: List[str]) -> str:
        """Return the currently selected token without advancing the index.

        Args:
            tokens: Candidate tokens.

        Returns:
            The token at the current index (modulo the list length).

        Raises:
            ValueError: If ``tokens`` is empty.
        """
        async with self._lock:
            if not tokens:
                raise ValueError("Token列表为空")
            return tokens[self.token_index % len(tokens)]

    async def rotate(self, tokens: List[str]):
        """Advance the index to the next token; a no-op for an empty list."""
        async with self._lock:
            if tokens:
                self.token_index = (self.token_index + 1) % len(tokens)


class JimengAI(PluginBase):
    """Jimeng AI drawing plugin (chat command and LLM tool integration)."""

    description = "即梦AI绘图插件 - 支持AI绘图和LLM工具调用"
    author = "ShiHao"
    version = "1.0.0"

    def __init__(self):
        super().__init__()
        self.config = None  # populated by async_init() from config.toml
        self.token_state = TokenState()
        self.images_dir = None  # populated by async_init()

    async def async_init(self):
        """Load ``config.toml`` next to this file and create the image directory."""
        plugin_dir = Path(__file__).parent
        with open(plugin_dir / "config.toml", "rb") as f:
            self.config = tomllib.load(f)

        # Downloaded images are stored in an "images" folder beside the plugin.
        self.images_dir = plugin_dir / "images"
        self.images_dir.mkdir(exist_ok=True)

        logger.success(f"即梦AI插件初始化完成,配置了 {len(self.config['api']['tokens'])} 个token")

    @staticmethod
    def _clamp(value, fallback, low, high, cast):
        """Convert ``value`` with ``cast`` and clamp it to ``[low, high]``.

        LLM tool calls may deliver numbers as strings, floats or ``None``; any
        value that cannot be converted falls back to ``fallback``.
        """
        try:
            number = cast(value)
        except (TypeError, ValueError, OverflowError):
            number = cast(fallback)
        return max(low, min(high, number))

    async def generate_image(self, prompt: str, **kwargs) -> List[str]:
        """Generate images for ``prompt`` and download them locally.

        Each configured token is tried in turn; for every token the request is
        retried up to ``generation.max_retry_attempts`` times with a capped
        exponential back-off. A 401 response skips straight to the next token.

        Args:
            prompt: Text prompt.
            **kwargs: Optional overrides: ``model``, ``width``, ``height``,
                ``sample_strength``, ``negative_prompt``. Missing or
                unconvertible numeric values fall back to the configured
                defaults; width/height are clamped to 64-2048 and
                sample_strength to 0.0-1.0.

        Returns:
            Local file paths of the downloaded images, or an empty list when
            every token failed.
        """
        api_config = self.config["api"]
        gen_config = self.config["generation"]

        model = kwargs.get("model", gen_config["default_model"])
        negative_prompt = kwargs.get("negative_prompt", gen_config["default_negative_prompt"])
        sample_strength = self._clamp(
            kwargs.get("sample_strength"), gen_config["default_sample_strength"], 0.0, 1.0, float
        )
        width = self._clamp(kwargs.get("width"), gen_config["default_width"], 64, 2048, int)
        height = self._clamp(kwargs.get("height"), gen_config["default_height"], 64, 2048, int)

        tokens = api_config["tokens"]
        max_retry = gen_config["max_retry_attempts"]

        # The URL, payload and timeout do not change between attempts.
        url = f"{api_config['base_url'].rstrip('/')}/v1/chat/completions"
        payload = {
            "model": model,
            "messages": [{"role": "user", "content": prompt}],
            "prompt": prompt,
            "negativePrompt": negative_prompt,
            "width": width,
            "height": height,
            "sample_strength": sample_strength
        }
        timeout = aiohttp.ClientTimeout(total=api_config["timeout"])

        # One session is reused for every attempt instead of one per request.
        async with aiohttp.ClientSession(timeout=timeout) as session:
            for _ in range(len(tokens)):
                current_token = await self.token_state.get_next_token(tokens)
                headers = {
                    "Content-Type": "application/json",
                    "Authorization": f"Bearer {current_token}"
                }

                for attempt in range(max_retry):
                    if attempt > 0:
                        # Exponential back-off, capped at 10 seconds.
                        await asyncio.sleep(min(2 ** attempt, 10))

                    logger.info(f"即梦AI请求: {model}, 尺寸: {width}x{height}, 提示词: {prompt[:50]}...")

                    try:
                        async with session.post(url, headers=headers, json=payload) as response:
                            if response.status == 401:
                                # Retrying with a rejected token is pointless.
                                logger.warning("Token认证失败,尝试下一个token")
                                break
                            if response.status == 429:
                                logger.warning("请求频率限制,等待后重试")
                                await asyncio.sleep(5)
                                continue
                            if response.status != 200:
                                error_text = await response.text()
                                logger.error(f"API请求失败: {response.status}, {error_text[:200]}")
                                continue
                            data = await response.json()

                        logger.debug(f"API返回数据: {data}")

                        if isinstance(data, dict) and "error" in data:
                            logger.error(f"API错误: {data['error']}")
                            continue

                        image_paths = await self._extract_images(data)
                        if image_paths:
                            logger.success(f"成功生成 {len(image_paths)} 张图像")
                            return image_paths
                        logger.warning(f"未找到图像数据,API返回: {str(data)[:200]}")

                    except asyncio.TimeoutError:
                        logger.warning(f"请求超时,重试中... ({attempt + 1}/{max_retry})")
                    except Exception as e:
                        logger.error(f"请求异常: {e}")

                # The current token failed; move on to the next one.
                await self.token_state.rotate(tokens)

        logger.error("所有token都失败了")
        return []

    @staticmethod
    def _collect_image_urls(data: dict) -> List[str]:
        """Return candidate image URLs from an API response, without duplicates.

        Response shapes are tried in order and the first one that yields any
        URL wins:
          1. OpenAI-style ``choices[].message.content`` containing https URLs
          2. ``data``   - a URL string, a ``{"url": ...}`` dict, or a list of either
          3. ``images`` - same shapes as ``data``
          4. top-level ``url``
        """
        candidates = []

        # Shape 1: URLs embedded in chat-completion message text.
        choices = data.get("choices")
        if isinstance(choices, list):
            for choice in choices:
                message = choice.get("message") if isinstance(choice, dict) else None
                content = message.get("content") if isinstance(message, dict) else None
                if isinstance(content, str):
                    candidates.extend(_URL_PATTERN.findall(content))

        # Shapes 2 and 3: "data" / "images" holding URLs or {"url": ...} items.
        for key in ("data", "images"):
            if candidates:
                break
            if key not in data:
                continue
            raw = data[key]
            for item in raw if isinstance(raw, list) else [raw]:
                if isinstance(item, str) and item.startswith("http"):
                    candidates.append(item)
                elif isinstance(item, dict) and "url" in item:
                    candidates.append(item["url"])

        # Shape 4: a single top-level URL.
        if not candidates and "url" in data:
            candidates.append(data["url"])

        # Keep only non-empty strings, de-duplicated in first-seen order.
        return list(dict.fromkeys(u for u in candidates if isinstance(u, str) and u))

    async def _extract_images(self, data: dict) -> List[str]:
        """Download every image referenced by an API response.

        Args:
            data: Decoded JSON response body.

        Returns:
            Local paths of the images that downloaded successfully; empty when
            the response is not a dict or contains no usable URL.
        """
        if not isinstance(data, dict):
            return []

        image_paths = []
        for url in self._collect_image_urls(data):
            path = await self._download_image(url)
            if path:
                image_paths.append(path)
        return image_paths

    async def _download_image(self, url: str) -> Optional[str]:
        """Download ``url`` into the image directory.

        Args:
            url: Image URL.

        Returns:
            The saved file path as a string, or ``None`` when the download
            failed (non-200 status, empty body, or any exception).
        """
        try:
            async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=30)) as session:
                async with session.get(url) as response:
                    if response.status != 200:
                        logger.warning(f"下载图片失败,HTTP状态码: {response.status}")
                        return None
                    content = await response.read()

            if not content:
                # An empty body would only produce a useless zero-byte file.
                logger.warning(f"下载图片失败,响应内容为空: {url}")
                return None

            # Timestamp plus a short random suffix keeps file names unique.
            ts = datetime.now().strftime("%Y%m%d_%H%M%S")
            uid = uuid.uuid4().hex[:8]
            file_path = self.images_dir / f"jimeng_{ts}_{uid}.jpg"

            # Write in a worker thread so disk I/O does not block the event loop.
            await asyncio.to_thread(file_path.write_bytes, content)

            logger.info(f"图片下载成功: {file_path}")
            return str(file_path)
        except Exception as e:
            logger.error(f"下载图片失败: {e}")

        return None

    @on_text_message(priority=70)
    async def handle_message(self, bot: WechatHookClient, message: dict):
        """Handle a text message and run the drawing command when it matches.

        A message matches when it starts with one of the configured command
        keywords followed by a space; the remainder is used as the prompt.

        Returns:
            ``True`` when the message is not handled by this plugin, ``False``
            once a drawing command has been processed.
            NOTE(review): presumably ``True`` lets the dispatcher continue to
            other plugins - confirm against the plugin framework.
        """
        behavior = self.config["behavior"]
        if not behavior["enable_command"]:
            return True

        content = (message.get("Content") or "").strip()
        from_wxid = message.get("FromWxid", "")
        is_group = message.get("IsGroup", False)

        # Respect the group-chat / private-chat switches.
        if is_group and not behavior["enable_group"]:
            return True
        if not is_group and not behavior["enable_private"]:
            return True

        # A command is "<keyword> <prompt>": the keyword must be followed by a space.
        matched_keyword = next(
            (kw for kw in behavior["command_keywords"] if content.startswith(kw + " ")),
            None,
        )
        if not matched_keyword:
            return True

        prompt = content[len(matched_keyword):].strip()
        if not prompt:
            # Defensive guard: content is already stripped and the match requires
            # a trailing space, so the prompt is normally non-empty here.
            await bot.send_text(from_wxid, "❌ 请提供绘图提示词\n用法: /绘图 <提示词>")
            return False

        logger.info(f"收到绘图请求: {prompt[:50]}...")

        # Tell the user that generation has started.
        await bot.send_text(from_wxid, "🎨 正在为您生成图像,请稍候...")

        try:
            image_paths = await self.generate_image(prompt)

            if image_paths:
                # Only the first generated image is sent.
                await bot.send_image(from_wxid, image_paths[0])
                logger.success("绘图成功,已发送图片")
            else:
                await bot.send_text(from_wxid, "❌ 图像生成失败,请稍后重试")

        except Exception as e:
            logger.error(f"绘图处理失败: {e}")
            await bot.send_text(from_wxid, f"❌ 处理失败: {str(e)}")

        return False

    def get_llm_tools(self) -> List[dict]:
        """Return the LLM tool definition exposed by this plugin.

        Returns:
            A one-element list with an OpenAI-style function schema, or an
            empty list when ``llm_tool.enabled`` is false in the config.
        """
        tool_config = self.config["llm_tool"]
        if not tool_config["enabled"]:
            return []

        return [{
            "type": "function",
            "function": {
                "name": tool_config["tool_name"],
                "description": tool_config["tool_description"],
                "parameters": {
                    "type": "object",
                    "properties": {
                        "prompt": {
                            "type": "string",
                            "description": "图像生成提示词,描述想要生成的图像内容"
                        },
                        "width": {
                            "type": "integer",
                            "description": "图像宽度(64-2048),默认1024",
                            "default": 1024
                        },
                        "height": {
                            "type": "integer",
                            "description": "图像高度(64-2048),默认1024",
                            "default": 1024
                        }
                    },
                    "required": ["prompt"]
                }
            }
        }]

    async def execute_llm_tool(self, tool_name: str, arguments: dict, bot: WechatHookClient, from_wxid: str) -> Optional[dict]:
        """Execute an LLM tool call addressed to this plugin.

        Args:
            tool_name: Name of the tool the LLM asked for.
            arguments: Tool arguments; ``prompt`` is required, ``width`` and
                ``height`` are optional.
            bot: Client used to send the generated image.
            from_wxid: Conversation the image is sent to.

        Returns:
            ``None`` when ``tool_name`` is not this plugin's tool, so another
            plugin can handle it. Otherwise
            ``{"success": bool, "message": str, "images": List[str]}``.
        """
        expected_tool_name = self.config["llm_tool"]["tool_name"]
        logger.info(f"JimengAI工具检查: 收到={tool_name}, 期望={expected_tool_name}")

        if tool_name != expected_tool_name:
            return None

        try:
            arguments = arguments or {}
            prompt = arguments.get("prompt")
            if not prompt:
                return {"success": False, "message": "缺少提示词参数", "images": []}

            logger.info(f"LLM工具调用绘图: {prompt[:50]}...")

            # Use the size requested by the LLM, falling back to the configured defaults.
            gen_config = self.config["generation"]
            image_paths = await self.generate_image(
                prompt=prompt,
                width=arguments.get("width", gen_config["default_width"]),
                height=arguments.get("height", gen_config["default_height"])
            )

            if not image_paths:
                return {"success": False, "message": "图像生成失败", "images": []}

            # Only the first generated image is sent and reported.
            await bot.send_image(from_wxid, image_paths[0])
            return {
                "success": True,
                "message": "已生成并发送图像",
                "images": [image_paths[0]]
            }

        except Exception as e:
            logger.error(f"LLM工具执行失败: {e}")
            return {"success": False, "message": f"执行失败: {str(e)}", "images": []}