feat:初版
This commit is contained in:
350
plugins/Kiira2AI/main.py
Normal file
350
plugins/Kiira2AI/main.py
Normal file
@@ -0,0 +1,350 @@
|
||||
"""
|
||||
Kiira2 AI绘图插件
|
||||
|
||||
支持命令触发和LLM工具调用
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import tomllib
|
||||
import httpx
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
from loguru import logger
|
||||
from utils.plugin_base import PluginBase
|
||||
from utils.decorators import on_text_message
|
||||
from WechatHook import WechatHookClient
|
||||
|
||||
|
||||
class TokenState:
|
||||
"""Token轮询状态管理"""
|
||||
def __init__(self):
|
||||
self.token_index = 0
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def get_next_token(self, tokens: List[str]) -> str:
|
||||
"""获取下一个可用的token"""
|
||||
async with self._lock:
|
||||
if not tokens:
|
||||
raise ValueError("Token列表为空")
|
||||
return tokens[self.token_index % len(tokens)]
|
||||
|
||||
async def rotate(self, tokens: List[str]):
|
||||
"""轮换到下一个token"""
|
||||
async with self._lock:
|
||||
if tokens:
|
||||
self.token_index = (self.token_index + 1) % len(tokens)
|
||||
|
||||
|
||||
class Kiira2AI(PluginBase):
|
||||
"""Kiira2 AI绘图插件"""
|
||||
|
||||
description = "Kiira2 AI绘图插件 - 支持AI绘图和LLM工具调用"
|
||||
author = "ShiHao"
|
||||
version = "1.0.0"
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.config = None
|
||||
self.token_state = TokenState()
|
||||
self.images_dir = None
|
||||
|
||||
async def async_init(self):
|
||||
"""异步初始化"""
|
||||
config_path = Path(__file__).parent / "config.toml"
|
||||
with open(config_path, "rb") as f:
|
||||
self.config = tomllib.load(f)
|
||||
|
||||
# 创建图片目录
|
||||
self.images_dir = Path(__file__).parent / "images"
|
||||
self.images_dir.mkdir(exist_ok=True)
|
||||
|
||||
logger.success(f"Kiira2 AI插件初始化完成,配置了 {len(self.config['api']['tokens'])} 个token")
|
||||
|
||||
async def generate_image(self, prompt: str, **kwargs) -> List[str]:
|
||||
"""
|
||||
生成图像
|
||||
|
||||
Args:
|
||||
prompt: 提示词
|
||||
**kwargs: 其他参数(model)
|
||||
|
||||
Returns:
|
||||
图片本地路径列表
|
||||
"""
|
||||
api_config = self.config["api"]
|
||||
gen_config = self.config["generation"]
|
||||
|
||||
model = kwargs.get("model", gen_config["default_model"])
|
||||
tokens = api_config["tokens"]
|
||||
max_retry = gen_config["max_retry_attempts"]
|
||||
|
||||
# 尝试每个token
|
||||
for token_attempt in range(len(tokens)):
|
||||
current_token = await self.token_state.get_next_token(tokens)
|
||||
|
||||
for attempt in range(max_retry):
|
||||
if attempt > 0:
|
||||
await asyncio.sleep(min(2 ** attempt, 10))
|
||||
|
||||
try:
|
||||
url = f"{api_config['base_url'].rstrip('/')}/v1/chat/completions"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {current_token}"
|
||||
}
|
||||
|
||||
payload = {
|
||||
"model": model,
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"stream": False
|
||||
}
|
||||
|
||||
logger.info(f"Kiira2 AI请求: {model}, 提示词: {prompt[:50]}...")
|
||||
|
||||
timeout = httpx.Timeout(connect=10.0, read=api_config["timeout"], write=10.0, pool=10.0)
|
||||
|
||||
# 配置代理
|
||||
proxy = None
|
||||
proxy_config = self.config.get("proxy", {})
|
||||
if proxy_config.get("enabled", False):
|
||||
proxy_type = proxy_config.get("type", "socks5")
|
||||
proxy_host = proxy_config.get("host", "127.0.0.1")
|
||||
proxy_port = proxy_config.get("port", 7890)
|
||||
proxy = f"{proxy_type}://{proxy_host}:{proxy_port}"
|
||||
logger.info(f"使用代理: {proxy}")
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout, proxy=proxy) as client:
|
||||
response = await client.post(url, json=payload, headers=headers)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
logger.debug(f"API返回数据: {data}")
|
||||
|
||||
if "error" in data:
|
||||
logger.error(f"API错误: {data['error']}")
|
||||
continue
|
||||
|
||||
# 检查是否返回空content(图片还在生成中)
|
||||
if "choices" in data and data["choices"]:
|
||||
message = data["choices"][0].get("message", {})
|
||||
content = message.get("content", "")
|
||||
video_url = message.get("video_url")
|
||||
|
||||
# 如果content为空且没有video_url,说明还在生成,等待后重试
|
||||
if not content and not video_url:
|
||||
wait_time = min(10 + attempt * 5, 30)
|
||||
logger.info(f"图片生成中,等待 {wait_time} 秒后重试...")
|
||||
await asyncio.sleep(wait_time)
|
||||
continue
|
||||
|
||||
# 提取图片URL
|
||||
image_paths = await self._extract_images(data)
|
||||
|
||||
if image_paths:
|
||||
logger.success(f"成功生成 {len(image_paths)} 张图像")
|
||||
return image_paths
|
||||
else:
|
||||
logger.warning(f"未找到图像数据,API返回: {str(data)[:500]}")
|
||||
continue
|
||||
|
||||
elif response.status_code == 401:
|
||||
logger.warning("Token认证失败,尝试下一个token")
|
||||
break
|
||||
elif response.status_code == 429:
|
||||
logger.warning("请求频率限制,等待后重试")
|
||||
await asyncio.sleep(5)
|
||||
continue
|
||||
else:
|
||||
error_text = response.text
|
||||
logger.error(f"API请求失败: {response.status_code}, {error_text[:200]}")
|
||||
continue
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"请求超时,重试中... ({attempt + 1}/{max_retry})")
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"请求异常: {e}")
|
||||
continue
|
||||
|
||||
# 当前token失败,轮换
|
||||
await self.token_state.rotate(tokens)
|
||||
|
||||
logger.error("所有token都失败了")
|
||||
return []
|
||||
|
||||
async def _extract_images(self, data: dict) -> List[str]:
|
||||
"""从API响应中提取图片(只提取图片,忽略文字)"""
|
||||
import re
|
||||
image_paths = []
|
||||
|
||||
# OpenAI格式的choices
|
||||
if "choices" in data and data["choices"]:
|
||||
for choice in data["choices"]:
|
||||
message = choice.get("message", {})
|
||||
|
||||
# 检查video_url字段(实际包含图片URL)
|
||||
if "video_url" in message:
|
||||
video_url = message["video_url"]
|
||||
if isinstance(video_url, list) and video_url:
|
||||
url = video_url[0]
|
||||
if isinstance(url, str) and url.startswith("http"):
|
||||
path = await self._download_image(url)
|
||||
if path:
|
||||
image_paths.append(path)
|
||||
|
||||
# 检查content字段
|
||||
if "content" in message and not image_paths:
|
||||
content = message["content"]
|
||||
if content and "http" in content:
|
||||
urls = re.findall(r'https?://[^\s\)\]"]+', content)
|
||||
for url in urls:
|
||||
path = await self._download_image(url)
|
||||
if path:
|
||||
image_paths.append(path)
|
||||
|
||||
return image_paths
|
||||
|
||||
async def _download_image(self, url: str) -> Optional[str]:
|
||||
"""下载图片到本地"""
|
||||
try:
|
||||
timeout = httpx.Timeout(connect=10.0, read=30.0, write=10.0, pool=10.0)
|
||||
|
||||
# 配置代理
|
||||
proxy = None
|
||||
proxy_config = self.config.get("proxy", {})
|
||||
if proxy_config.get("enabled", False):
|
||||
proxy_type = proxy_config.get("type", "socks5")
|
||||
proxy_host = proxy_config.get("host", "127.0.0.1")
|
||||
proxy_port = proxy_config.get("port", 7890)
|
||||
proxy = f"{proxy_type}://{proxy_host}:{proxy_port}"
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout, proxy=proxy) as client:
|
||||
response = await client.get(url)
|
||||
response.raise_for_status()
|
||||
|
||||
# 生成文件名
|
||||
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
uid = uuid.uuid4().hex[:8]
|
||||
file_path = self.images_dir / f"kiira2_{ts}_{uid}.jpg"
|
||||
|
||||
# 保存文件
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(response.content)
|
||||
|
||||
logger.info(f"图片下载成功: {file_path}")
|
||||
return str(file_path)
|
||||
except Exception as e:
|
||||
logger.error(f"下载图片失败: {e}")
|
||||
return None
|
||||
|
||||
@on_text_message(priority=70)
|
||||
async def handle_message(self, bot: WechatHookClient, message: dict):
|
||||
"""处理文本消息"""
|
||||
if not self.config["behavior"]["enable_command"]:
|
||||
return True
|
||||
|
||||
content = message.get("Content", "").strip()
|
||||
from_wxid = message.get("FromWxid", "")
|
||||
is_group = message.get("IsGroup", False)
|
||||
|
||||
# 检查群聊/私聊开关
|
||||
if is_group and not self.config["behavior"]["enable_group"]:
|
||||
return True
|
||||
if not is_group and not self.config["behavior"]["enable_private"]:
|
||||
return True
|
||||
|
||||
# 检查是否是绘图命令
|
||||
keywords = self.config["behavior"]["command_keywords"]
|
||||
matched_keyword = None
|
||||
for keyword in keywords:
|
||||
if content.startswith(keyword + " "):
|
||||
matched_keyword = keyword
|
||||
break
|
||||
|
||||
if not matched_keyword:
|
||||
return True
|
||||
|
||||
# 提取提示词
|
||||
prompt = content[len(matched_keyword):].strip()
|
||||
if not prompt:
|
||||
await bot.send_text(from_wxid, "❌ 请提供绘图提示词\n用法: /画画 <提示词>")
|
||||
return False
|
||||
|
||||
logger.info(f"收到绘图请求: {prompt[:50]}...")
|
||||
|
||||
# 发送处理中提示
|
||||
await bot.send_text(from_wxid, "🎨 正在为您生成图像,请稍候...")
|
||||
|
||||
try:
|
||||
# 生成图像
|
||||
image_paths = await self.generate_image(prompt)
|
||||
|
||||
if image_paths:
|
||||
# 直接发送图片
|
||||
await bot.send_image(from_wxid, image_paths[0])
|
||||
logger.success(f"绘图成功,已发送图片")
|
||||
else:
|
||||
await bot.send_text(from_wxid, "❌ 图像生成失败,请稍后重试")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"绘图处理失败: {e}")
|
||||
await bot.send_text(from_wxid, f"❌ 处理失败: {str(e)}")
|
||||
|
||||
return False
|
||||
|
||||
def get_llm_tools(self) -> List[dict]:
|
||||
"""返回LLM工具定义"""
|
||||
if not self.config["llm_tool"]["enabled"]:
|
||||
return []
|
||||
|
||||
return [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": self.config["llm_tool"]["tool_name"],
|
||||
"description": self.config["llm_tool"]["tool_description"],
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"prompt": {
|
||||
"type": "string",
|
||||
"description": "图像生成提示词,描述想要生成的图像内容"
|
||||
}
|
||||
},
|
||||
"required": ["prompt"]
|
||||
}
|
||||
}
|
||||
}]
|
||||
|
||||
async def execute_llm_tool(self, tool_name: str, arguments: dict, bot: WechatHookClient, from_wxid: str) -> dict:
|
||||
"""执行LLM工具调用"""
|
||||
expected_tool_name = self.config["llm_tool"]["tool_name"]
|
||||
|
||||
if tool_name != expected_tool_name:
|
||||
return None
|
||||
|
||||
try:
|
||||
prompt = arguments.get("prompt")
|
||||
if not prompt:
|
||||
return {"success": False, "message": "缺少提示词参数"}
|
||||
|
||||
logger.info(f"LLM工具调用绘图: {prompt[:50]}...")
|
||||
|
||||
# 生成图像
|
||||
image_paths = await self.generate_image(prompt=prompt)
|
||||
|
||||
if image_paths:
|
||||
# 直接发送图片
|
||||
await bot.send_image(from_wxid, image_paths[0])
|
||||
return {
|
||||
"success": True,
|
||||
"message": "已生成并发送图像",
|
||||
"images": [image_paths[0]]
|
||||
}
|
||||
else:
|
||||
return {"success": False, "message": "图像生成失败"}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"LLM工具执行失败: {e}")
|
||||
return {"success": False, "message": f"执行失败: {str(e)}"}
|
||||
Reference in New Issue
Block a user