948 lines
37 KiB
Python
948 lines
37 KiB
Python
"""
|
||
Grok 视频生成插件
|
||
|
||
用户引用图片并发送 /视频 提示词 来生成视频
|
||
支持队列系统和积分制
|
||
"""
|
||
|
||
import asyncio
import html
import json
import re
import tomllib
import uuid
import xml.etree.ElementTree as ET
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Dict, Optional

import httpx
import pymysql
from loguru import logger
from minio import Minio
from minio.error import S3Error

from utils.decorators import on_text_message
from utils.image_processor import ImageProcessor, MediaConfig
from utils.plugin_base import PluginBase
|
||
|
||
|
||
# Decorator factory for quote (reply) message handlers.
def on_quote_message(priority=50):
    """Mark a handler as consuming quote (reply) messages.

    The returned decorator tags the handler in place -- it is not wrapped --
    with ``_event_type = 'quote_message'`` and ``_priority`` set to
    ``priority`` clamped into the range 0..99, then returns it. How the tags
    are consumed is up to the plugin framework (presumably the same way as
    the tags set by ``utils.decorators``).
    """

    def mark(handler):
        # The event name must be 'quote_message' (an earlier version got it wrong).
        handler._event_type = 'quote_message'
        handler._priority = min(max(priority, 0), 99)
        return handler

    return mark
|
||
|
||
|
||
@dataclass
class VideoTask:
    """One queued video-generation request.

    Fields:
        user_wxid: wxid of the user who issued the request (the one charged
            when the points system is enabled).
        from_wxid: conversation (private chat or group) that gets the replies.
        prompt: text prompt for the video.
        cdnurl / aeskey: CDN download info of a quoted WeChat image; both are
            empty for text-to-video.
        is_group: whether ``from_wxid`` is a group chat.
        timestamp: time the task object was created.
        image_base64: optional reference image supplied directly; when set it
            is used instead of downloading via ``cdnurl``/``aeskey``.
        video_config: optional ``video_config`` dict forwarded to the API.
    """

    user_wxid: str
    from_wxid: str
    prompt: str
    cdnurl: str = ""
    aeskey: str = ""
    is_group: bool = False
    # default_factory: a plain ``= datetime.now()`` default is evaluated once,
    # at class-definition time, so every task would share the import time.
    timestamp: datetime = field(default_factory=datetime.now)
    image_base64: str = ""
    video_config: Optional[dict] = None
|
||
|
||
|
||
class GrokVideo(PluginBase):
    """Grok video-generation plugin.

    Turns ``/视频 <prompt>`` commands (plain text, or a reply quoting an
    image) and LLM tool calls into queued Grok video jobs, optionally
    charging points per job.
    """

    # Plugin metadata (user-facing strings; presumably read by the plugin loader).
    description = "使用 Grok AI 根据提示词生成视频(可选图片,支持队列和积分系统)"
    author = "ShiHao"
    version = "2.0.0"
|
||
|
||
def __init__(self):
    """Create the plugin with empty runtime state.

    Everything that depends on ``config.toml`` (API URL, MinIO client,
    image processor, task queue, worker tasks) is created in ``async_init``.
    The attributes are declared here with placeholder values, so code that
    touches them before ``async_init`` has run sees ``None``/``""``/``[]``
    instead of raising ``AttributeError``.
    """
    super().__init__()
    self.config: Optional[dict] = None
    self.api_url: str = ""
    self.task_queue: Optional[asyncio.Queue] = None
    # Meant for tracking in-flight tasks; nothing in this file populates it.
    self.processing_tasks: Dict[str, VideoTask] = {}
    self.worker_task: Optional[asyncio.Task] = None
    self.worker_tasks: list = []
    self.minio_client = None
    self.minio_bucket: str = ""
    self._image_processor = None
|
||
|
||
async def async_init(self):
    """Load ``config.toml`` (next to this file) and start the queue workers.

    Sets up:
      * the chat-completions URL from ``[api].server_url``;
      * the MinIO client;
      * the image processor;
      * the bounded task queue (``[queue].max_queue_size``, default 10);
      * ``[queue].max_concurrent`` worker tasks (default 1, at least 1).

    Raises:
        OSError / tomllib.TOMLDecodeError: if the config file cannot be read.
        KeyError: if ``[api].server_url`` is missing.
    """
    config_path = Path(__file__).parent / "config.toml"
    with open(config_path, "rb") as f:
        self.config = tomllib.load(f)

    self.api_url = f"{self.config['api']['server_url'].rstrip('/')}/v1/chat/completions"

    # MinIO settings come from an optional [minio] section. The fallbacks
    # are the values that used to be hard-coded, so existing deployments
    # keep working unchanged.
    # NOTE(review): the fallback credentials are a secret committed to
    # source control -- move them into config.toml and drop these defaults.
    minio_conf = self.config.get("minio", {})
    self.minio_endpoint = minio_conf.get("endpoint", "115.190.113.141:19000")
    self.minio_secure = bool(minio_conf.get("secure", False))
    self.minio_client = Minio(
        self.minio_endpoint,
        access_key=minio_conf.get("access_key", "admin"),
        secret_key=minio_conf.get("secret_key", "80012029Lz"),
        secure=self.minio_secure,
    )
    self.minio_bucket = minio_conf.get("bucket", "wechat")

    temp_dir = Path(__file__).parent / "temp"
    temp_dir.mkdir(exist_ok=True)
    self._image_processor = ImageProcessor(MediaConfig(), temp_dir=temp_dir)

    queue_conf = self.config.get("queue", {})
    max_queue_size = queue_conf.get("max_queue_size", 10)
    self.task_queue = asyncio.Queue(maxsize=max_queue_size)

    # max_concurrent used to be logged but ignored: exactly one worker was
    # always started. Start the configured number of workers (at least one).
    try:
        max_concurrent = max(1, int(queue_conf.get("max_concurrent", 1)))
    except (TypeError, ValueError):
        max_concurrent = 1
    self.worker_tasks = [
        asyncio.create_task(self._queue_worker()) for _ in range(max_concurrent)
    ]
    # Keep the single-handle attribute pointing at the first worker.
    self.worker_task = self.worker_tasks[0]

    points_conf = self.config.get("points", {})
    logger.success("Grok 视频生成插件已加载")
    logger.info(f"API: {self.api_url}")
    logger.info(f"队列配置: 最大并发={max_concurrent}, 最大队列长度={max_queue_size}")
    logger.info(f"积分系统: {'启用' if points_conf.get('enabled') else '禁用'}")
    if points_conf.get("enabled"):
        # Same default as _enqueue_video_task uses; a missing "cost" key
        # used to crash initialisation here.
        logger.info(f"每次生成消耗: {points_conf.get('cost', 50)} 积分")
|
||
|
||
def get_db_connection(self):
    """Open a new pymysql connection configured by the ``[database]`` section.

    Required keys: host, port, user, password, database, charset. A missing
    key raises ``KeyError``. The connection is opened with
    ``autocommit=True``. The caller owns the connection; the methods of this
    class use it as a context manager.
    """
    section = self.config["database"]
    connect_kwargs = {
        name: section[name]
        for name in ("host", "port", "user", "password", "database", "charset")
    }
    connect_kwargs["autocommit"] = True
    return pymysql.connect(**connect_kwargs)
|
||
|
||
def get_user_points(self, wxid: str) -> int:
    """Return the ``points`` value stored in ``user_signin`` for ``wxid``.

    Returns 0 when the user has no row, and also when the query fails
    (the error is logged, not raised).
    """
    query = "SELECT points FROM user_signin WHERE wxid = %s"
    try:
        with self.get_db_connection() as conn, conn.cursor() as cursor:
            cursor.execute(query, (wxid,))
            row = cursor.fetchone()
            return row[0] if row else 0
    except Exception as e:
        logger.error(f"获取用户积分失败: {e}")
        return 0
|
||
|
||
def deduct_points(self, wxid: str, points: int) -> bool:
    """Subtract ``points`` from ``wxid``'s balance if the balance covers it.

    The balance check and the subtraction are a single conditional UPDATE
    (``... WHERE wxid = %s AND points >= %s``). The previous
    SELECT-then-UPDATE let two concurrent requests both pass the check and
    push the balance below zero.

    Args:
        wxid: user whose ``user_signin`` row is charged.
        points: amount to subtract. Expected to be positive: a 0 amount
            changes no row and may therefore be reported as a failure,
            depending on the driver's affected-rows mode.

    Returns:
        True if the balance was reduced. False if the user has no row, the
        balance is insufficient, or the database call failed (logged).
    """
    sql = (
        "UPDATE user_signin SET points = points - %s "
        "WHERE wxid = %s AND points >= %s"
    )
    try:
        with self.get_db_connection() as conn:
            with conn.cursor() as cursor:
                # pymysql's execute() returns the number of affected rows.
                affected = cursor.execute(sql, (points, wxid, points))
    except Exception as e:
        logger.error(f"扣除用户积分失败: {e}")
        return False

    if not affected:
        return False
    logger.info(f"用户 {wxid} 扣除 {points} 积分")
    return True
|
||
|
||
def is_admin(self, wxid: str) -> bool:
    """Return True if ``wxid`` is listed in ``[points].admins``.

    Admins are not charged points (see ``_enqueue_video_task``).
    """
    points_section = self.config.get("points", {})
    return wxid in points_section.get("admins", [])
|
||
|
||
async def upload_video_to_minio(self, local_file: str, original_filename: str = "") -> str:
    """Upload a local video to MinIO and return its public URL.

    The object is stored as ``videos/<YYYYmmdd>/<name>_<uuid><ext>``, where
    ``<name>`` is the sanitised stem of ``original_filename`` (or
    ``grok_video`` when it is empty) and ``<ext>`` is ``local_file``'s suffix.

    Args:
        local_file: path of the file to upload.
        original_filename: optional file name whose stem is reused.

    Returns:
        ``<scheme>://<endpoint>/<bucket>/<object>``, or ``""`` if the upload
        fails for any reason. The upload is best-effort: the caller still
        delivers the video when this returns ``""``.
    """
    try:
        file_ext = Path(local_file).suffix
        unique_id = uuid.uuid4().hex

        if original_filename:
            # Keep only word characters, '-', '_' and '.' from the original stem.
            safe_stem = re.sub(r'[^\w\-_\.]', '_', Path(original_filename).stem)
            filename = f"{safe_stem}_{unique_id}{file_ext}"
        else:
            filename = f"grok_video_{unique_id}{file_ext}"

        object_name = f"videos/{datetime.now().strftime('%Y%m%d')}/{filename}"

        # fput_object is blocking; run it in a worker thread.
        await asyncio.to_thread(
            self.minio_client.fput_object,
            self.minio_bucket,
            object_name,
            local_file,
        )

        # Follow the configured endpoint when async_init recorded one; the
        # fallbacks reproduce the previously hard-coded public URL.
        endpoint = getattr(self, "minio_endpoint", "115.190.113.141:19000")
        scheme = "https" if getattr(self, "minio_secure", False) else "http"
        url = f"{scheme}://{endpoint}/{self.minio_bucket}/{object_name}"
        logger.info(f"视频上传成功: {url}")
        return url

    except Exception as e:
        # Catch everything, not only S3Error: connection failures raise other
        # exception types and previously escaped into _process_video_task,
        # which then never sent the already generated video.
        logger.error(f"上传视频到MinIO失败: {e}")
        return ""
|
||
|
||
async def _queue_worker(self):
    """Consume ``(bot, task)`` items from ``self.task_queue`` forever.

    Each item is handed to ``_process_video_task``. ``task_done()`` is
    called in a ``finally`` block, so it runs exactly once per dequeued item
    even when processing raises. Previously an exception skipped it, which
    would make ``Queue.join()`` wait forever. Errors are logged and the
    worker pauses for one second before taking the next item. Cancellation
    (``asyncio.CancelledError``, a ``BaseException``) is not caught and
    stops the worker.
    """
    logger.info("视频生成队列工作线程已启动")

    while True:
        item = await self.task_queue.get()
        try:
            bot, task = item
            await self._process_video_task(bot, task)
        except Exception as e:
            logger.error(f"队列工作线程错误: {e}")
            await asyncio.sleep(1)
        finally:
            self.task_queue.task_done()
|
||
|
||
async def _process_video_task(self, bot, task: VideoTask):
    """Execute one queued video task and report the result to ``task.from_wxid``.

    Flow:
      1. Resolve the reference image: use ``task.image_base64``, or else
         download the quoted image via ``task.cdnurl``/``task.aeskey``.
      2. Call Grok. If an image-to-video call fails with an
         ``upstream_error``/``500``-looking error, retry once as
         text-to-video.
      3. Download the video and upload a copy to MinIO (best effort).
      4. Send the video to the chat and record it with MessageLogger.
      5. Send a success message, with points info when the points system
         is enabled.

    Failures are reported to the chat and logged rather than raised; only
    an error while sending that failure report can propagate. The
    downloaded file is deleted in every case once processing ends.
    Previously it was kept whenever sending or a later step failed, which
    leaked disk space.

    NOTE(review): points charged in ``_enqueue_video_task`` are not refunded
    on any failure path here -- confirm that is the intended policy.
    """
    logger.info(f"开始处理视频任务: user={task.user_wxid}, prompt={task.prompt}")

    video_path = ""
    try:
        image_base64 = (task.image_base64 or "").strip()
        if not image_base64 and task.cdnurl and task.aeskey:
            image_base64 = await self._download_and_encode_image(bot, task.cdnurl, task.aeskey)
            if not image_base64:
                await bot.send_text(task.from_wxid, "❌ 图片下载失败,请稍后重试")
                return

        try:
            video_url = await self._call_grok_api(
                task.prompt,
                image_base64=image_base64,
                video_config=task.video_config,
            )
        except Exception as e:
            # An image-carrying call that fails with an upstream/500-looking
            # error is retried once as plain text-to-video.
            if image_base64 and ("upstream_error" in str(e) or "500" in str(e)):
                logger.warning(f"携带图片参数调用失败,自动降级文生视频重试: {e}")
                video_url = await self._call_grok_api(
                    task.prompt,
                    image_base64="",
                    video_config=task.video_config,
                )
            else:
                raise

        if not video_url:
            await bot.send_text(task.from_wxid, "❌ 未从接口响应中获取到视频地址")
            return

        video_path = await self._download_video(video_url)
        if not video_path:
            await bot.send_text(task.from_wxid, "❌ 视频下载失败,请稍后重试")
            return

        logger.info(f"准备上传视频到 MinIO: {video_path}")
        minio_url = await self.upload_video_to_minio(video_path, Path(video_path).name)

        logger.info(f"准备发送视频到微信: {video_path}")
        video_sent = await bot.send_media(task.from_wxid, video_path, media_type="video")
        if not video_sent:
            await bot.send_text(task.from_wxid, "❌ 视频发送失败,请稍后重试")
            logger.error(f"视频发送失败: {video_path}")
            return

        # Recording the message is optional; never let it fail the task.
        try:
            from plugins.MessageLogger.main import MessageLogger
            message_logger = MessageLogger.get_instance()
            if message_logger and minio_url:
                await message_logger.save_bot_message(
                    task.from_wxid,
                    f"[视频] {task.prompt}",
                    "video",
                    minio_url,
                )
                logger.info(f"视频消息已写入 MessageLogger: {minio_url}")
        except Exception as e:
            logger.warning(f"记录视频消息到 MessageLogger 失败: {e}")

        points_config = self.config.get("points", {})
        if not points_config.get("enabled", False):
            success_msg = "✅ 视频生成成功"
        elif self.is_admin(task.user_wxid):
            success_msg = "✅ 视频生成成功\n🎟️ 管理员免费使用"
        else:
            cost = points_config.get("cost", 50)
            # Blocking pymysql call: keep it off the event loop.
            remaining_points = await asyncio.to_thread(self.get_user_points, task.user_wxid)
            success_msg = f"✅ 视频生成成功\n💰 本次消费:{cost} 积分\n💎 剩余积分:{remaining_points}"
        await bot.send_text(task.from_wxid, success_msg)

    except Exception as e:
        logger.error(f"处理视频任务失败: {e}")
        await bot.send_text(task.from_wxid, f"❌ 视频生成失败: {str(e)}")

    finally:
        if video_path:
            try:
                Path(video_path).unlink(missing_ok=True)
                logger.info(f"已清理本地视频文件: {video_path}")
            except Exception as e:
                logger.warning(f"清理本地视频文件失败: {e}")
|
||
|
||
def _check_behavior_enabled(self, from_wxid: str, is_group: bool) -> bool:
|
||
if not self.config["behavior"].get("enabled", True):
|
||
return False
|
||
|
||
if not is_group:
|
||
return True
|
||
|
||
enabled_groups = self.config["behavior"].get("enabled_groups", [])
|
||
disabled_groups = self.config["behavior"].get("disabled_groups", [])
|
||
|
||
if from_wxid in disabled_groups:
|
||
return False
|
||
if enabled_groups and from_wxid not in enabled_groups:
|
||
return False
|
||
return True
|
||
|
||
def _build_video_config(self, *, from_tool: bool, aspect_ratio: str = "", preset: str = "") -> dict:
|
||
video_conf = self.config.get("video_generation", {})
|
||
|
||
length = int(video_conf.get("fixed_video_length", 6) or 6)
|
||
length = max(5, min(15, length))
|
||
|
||
resolution = str(video_conf.get("fixed_resolution", "SD") or "SD").upper()
|
||
if resolution not in {"SD", "HD"}:
|
||
resolution = "SD"
|
||
|
||
cfg = {
|
||
"video_length": length,
|
||
"resolution": resolution,
|
||
}
|
||
|
||
allowed_aspect_ratios = {"16:9", "9:16", "1:1", "2:3", "3:2"}
|
||
ar = (aspect_ratio or "").strip()
|
||
if ar in allowed_aspect_ratios:
|
||
cfg["aspect_ratio"] = ar
|
||
|
||
allowed_presets = {"fun", "normal", "spicy"}
|
||
default_preset = str(video_conf.get("command_preset", "normal") or "normal").strip().lower()
|
||
if default_preset not in allowed_presets:
|
||
default_preset = "normal"
|
||
|
||
if from_tool:
|
||
tool_preset = (preset or "").strip().lower()
|
||
if tool_preset in allowed_presets:
|
||
cfg["preset"] = tool_preset
|
||
else:
|
||
cfg["preset"] = default_preset
|
||
else:
|
||
cfg["preset"] = default_preset
|
||
|
||
return cfg
|
||
|
||
def _normalize_image_base64_input(self, raw: str) -> str:
|
||
"""清洗工具入参中的图片字段,避免无效值触发上游 500。"""
|
||
value = str(raw or "").strip()
|
||
if not value:
|
||
return ""
|
||
|
||
if value.lower() in {"null", "none", "nil", "undefined", "n/a", "无", "空"}:
|
||
return ""
|
||
|
||
if value.startswith("data:image/"):
|
||
return value
|
||
|
||
if value.startswith(("http://", "https://")):
|
||
return value
|
||
|
||
# 允许纯 base64,自动补为 data URL
|
||
compact = re.sub(r"\s+", "", value)
|
||
if len(compact) >= 256 and re.fullmatch(r"[A-Za-z0-9+/=]+", compact):
|
||
return f"data:image/png;base64,{compact}"
|
||
|
||
logger.warning("检测到无效 image_base64 参数,已忽略该字段")
|
||
return ""
|
||
|
||
def _refund_points(self, user_wxid: str, amount: int):
    """Add ``amount`` points back to ``user_wxid`` in ``user_signin``.

    Best effort: a non-positive ``amount`` is ignored, and database errors
    are logged instead of raised. No error is reported if the user has no
    row (the UPDATE simply matches nothing).
    """
    if amount <= 0:
        return

    statement = "UPDATE user_signin SET points = points + %s WHERE wxid = %s"
    try:
        with self.get_db_connection() as conn, conn.cursor() as cursor:
            cursor.execute(statement, (amount, user_wxid))
            logger.info(f"积分已回退: {user_wxid}, +{amount}")
    except Exception as e:
        logger.error(f"回退积分失败: {e}")
|
||
|
||
async def _enqueue_video_task(
    self,
    bot,
    *,
    user_wxid: str,
    from_wxid: str,
    prompt: str,
    is_group: bool,
    cdnurl: str = "",
    aeskey: str = "",
    image_base64: str = "",
    video_config: Optional[dict] = None,
) -> tuple[bool, str]:
    """Charge the user (when the points system is on) and queue a video job.

    All user-facing feedback is sent to ``from_wxid`` from here: queue
    full, insufficient points, charge failure, and the queued/position
    notice.

    Returns:
        ``(True, "任务已提交")`` when the task was queued, otherwise
        ``(False, reason)`` with reason "队列已满", "积分不足" or "扣分失败".
    """
    # Check capacity before charging, so users are not charged and then
    # immediately refunded when the queue is already full.
    if self.task_queue.full():
        await bot.send_text(from_wxid, f"❌ 当前队列已满({self.task_queue.qsize()}/{self.task_queue.maxsize}),请稍后再试")
        return False, "队列已满"

    deducted_cost = 0
    points_config = self.config.get("points", {})
    if points_config.get("enabled", False):
        if self.is_admin(user_wxid):
            logger.info(f"管理员用户 {user_wxid} 免费使用")
        else:
            cost = int(points_config.get("cost", 50) or 50)
            # pymysql calls block; run them in a worker thread so the event
            # loop keeps handling other messages meanwhile.
            current_points = await asyncio.to_thread(self.get_user_points, user_wxid)
            if current_points < cost:
                await bot.send_text(
                    from_wxid,
                    f"❌ 积分不足\n💎 当前积分:{current_points}\n💸 需要积分:{cost}\n\n请先签到或获取积分后再试~"
                )
                return False, "积分不足"

            if not await asyncio.to_thread(self.deduct_points, user_wxid, cost):
                await bot.send_text(from_wxid, "❌ 扣除积分失败,请稍后重试")
                return False, "扣分失败"

            deducted_cost = cost
            logger.info(f"用户 {user_wxid} 已扣除 {cost} 积分")

    task = VideoTask(
        user_wxid=user_wxid,
        from_wxid=from_wxid,
        prompt=prompt,
        cdnurl=cdnurl,
        aeskey=aeskey,
        is_group=is_group,
        timestamp=datetime.now(),
        image_base64=image_base64,
        video_config=video_config or {},
    )

    try:
        # put_nowait never blocks: the awaits above may have let other
        # requests fill the queue, in which case we refund and reject.
        self.task_queue.put_nowait((bot, task))
    except asyncio.QueueFull:
        if deducted_cost > 0:
            await asyncio.to_thread(self._refund_points, user_wxid, deducted_cost)
        await bot.send_text(from_wxid, f"❌ 当前队列已满({self.task_queue.qsize()}/{self.task_queue.maxsize}),请稍后再试")
        return False, "队列已满"

    queue_position = self.task_queue.qsize()
    logger.success(f"视频任务入队成功: user={user_wxid}, position={queue_position}")

    # The task is queued and will run. A failed confirmation must therefore
    # not refund the points; the previous version did, giving the video away.
    try:
        if queue_position == 1:
            await bot.send_text(from_wxid, "⏳ 任务已提交,正在为你生成视频,请稍候...")
        else:
            await bot.send_text(from_wxid, f"⏳ 任务已加入队列\n📍 当前排队位置:第 {queue_position} 位\n🚀 正在加速处理中...")
    except Exception as e:
        logger.warning(f"发送入队提示失败: {e}")

    return True, "任务已提交"
|
||
|
||
@on_text_message(priority=90)
async def handle_video_text_command(self, bot, message: dict):
    """Handle a plain-text ``/视频 <prompt>`` command as text-to-video.

    Returns True when the text is not a ``/视频`` command, or when the
    feature is disabled for this conversation. Returns False once the
    command has been handled here (including the "missing prompt" reply).
    True/False are presumably the framework's continue/stop-propagation
    signal -- confirm in ``utils.decorators``.
    """
    command = "/视频"
    text = (message.get("Content", "") or "").strip()
    if not text.startswith(command):
        return True

    chat_wxid = message.get("FromWxid", "")
    in_group = message.get("IsGroup", False)
    # In a group the requester (and payer) is the sender, not the group.
    requester_wxid = message.get("SenderWxid", "") if in_group else chat_wxid

    if not self._check_behavior_enabled(chat_wxid, in_group):
        return True

    prompt = text[len(command):].strip()
    if not prompt:
        await bot.send_text(chat_wxid, "❌ 请输入提示词,格式:/视频 提示词")
        return False

    await self._enqueue_video_task(
        bot,
        user_wxid=requester_wxid,
        from_wxid=chat_wxid,
        prompt=prompt,
        is_group=in_group,
        video_config=self._build_video_config(from_tool=False),
    )
    return False
|
||
|
||
@on_quote_message(priority=90)
async def handle_video_command(self, bot, message: dict):
    """Handle ``/视频 <prompt>`` sent as a reply (quote) message.

    The command is the ``<title>`` of the message XML. If the quoted
    message (``<refermsg><content>``) is an image carrying both
    ``cdnbigimgurl`` and ``aeskey``, the job is image-to-video; otherwise it
    is text-to-video.

    Returns True when the message is not handled here: not a ``/视频``
    reply, feature disabled for this chat, or unparsable message. This is
    the same convention as ``handle_video_text_command``; the old code
    returned ``None`` or ``False`` in these cases. Returns False once the
    command has been handled.
    """

    def _strip_sender_prefix(raw: str) -> str:
        # Drop a BOM and, when the text looks like "<sender>:\n<xml>", any
        # prefix before the first "<?xml" (or, failing that, "<msg").
        text = raw.lstrip("\ufeff")
        if ":\n" in text:
            start = text.find("<?xml")
            if start == -1:
                start = text.find("<msg")
            if start > 0:
                text = text[start:]
        return text

    content = (message.get("Content", "") or "").strip()
    from_wxid = message.get("FromWxid", "")
    sender_wxid = message.get("SenderWxid", "")
    is_group = message.get("IsGroup", False)
    user_wxid = sender_wxid if is_group else from_wxid

    try:
        root = ET.fromstring(_strip_sender_prefix(content))
        title = root.find(".//title")
        if title is None or not title.text:
            return True

        title_text = title.text.strip()
        if not title_text.startswith("/视频"):
            return True

        if not self._check_behavior_enabled(from_wxid, is_group):
            return True

        prompt = title_text[3:].strip()
        if not prompt:
            await bot.send_text(from_wxid, "❌ 请输入提示词,格式:/视频 提示词")
            return False

        cdnbigimgurl = ""
        aeskey = ""

        refermsg = root.find(".//refermsg")
        refer_content = refermsg.find("content") if refermsg is not None else None
        if refer_content is not None and refer_content.text:
            try:
                # The quoted message is HTML-escaped XML embedded as text.
                refer_root = ET.fromstring(_strip_sender_prefix(html.unescape(refer_content.text)))
                img = refer_root.find(".//img")
                if img is not None:
                    cdnbigimgurl = img.get("cdnbigimgurl", "")
                    aeskey = img.get("aeskey", "")
            except Exception as e:
                logger.warning(f"解析引用内容失败,降级为文生视频: {e}")

        # Downloading the quoted image needs both the CDN URL and the AES key.
        if cdnbigimgurl and aeskey:
            logger.info("检测到引用图片,按图生视频处理")
        else:
            cdnbigimgurl = ""
            aeskey = ""
            logger.info("未检测到可用引用图片,按文生视频处理")

        await self._enqueue_video_task(
            bot,
            user_wxid=user_wxid,
            from_wxid=from_wxid,
            prompt=prompt,
            is_group=is_group,
            cdnurl=cdnbigimgurl,
            aeskey=aeskey,
            video_config=self._build_video_config(from_tool=False),
        )

    except Exception as e:
        logger.error(f"解析引用消息失败: {e}")
        return True

    return False
|
||
|
||
async def _download_and_encode_image(self, bot, cdnurl: str, aeskey: str) -> str:
    """Fetch a WeChat CDN image through the shared ImageProcessor.

    Returns whatever ``ImageProcessor.download_image_by_cdn`` returns. The
    caller treats it as the image string and treats ``""`` as failure.
    Returns ``""`` when the processor has not been created yet (i.e. before
    ``async_init``).
    """
    processor = self._image_processor
    if processor:
        logger.info(f"正在下载图片: {cdnurl[:50]}...")
        return await processor.download_image_by_cdn(bot, cdnurl, aeskey)

    logger.warning("ImageProcessor 未初始化,无法下载图片")
    return ""
|
||
|
||
async def _call_grok_api(self, prompt: str, image_base64: str = "", video_config: Optional[dict] = None) -> str:
    """Request a video from the Grok chat-completions endpoint and return its URL.

    Args:
        prompt: text sent as the user message.
        image_base64: optional reference image (data URL, http(s) URL or
            bare base64). It is cleaned with ``_normalize_image_base64_input``
            and, when still non-empty, sent as an ``image_url`` content part.
        video_config: optional dict forwarded verbatim as ``video_config``.

    Returns:
        The video URL extracted from the response.

    Raises:
        Exception: if no API key is configured, the HTTP status is not 200
            (message includes the status and up to 300 chars of the body),
            the body is not JSON, or no video URL can be extracted.
        KeyError: if ``[api].model_id`` or ``[api].timeout`` is missing.
    """
    api_conf = self.config.get("api", {})
    # Missing and empty keys are both "not configured"; a missing key used
    # to surface to the user as a bare KeyError('api_key').
    api_key = api_conf.get("api_key")
    if not api_key:
        raise Exception("未配置 API Key")

    image_base64 = self._normalize_image_base64_input(image_base64)
    if image_base64:
        user_content = [
            {"type": "text", "text": prompt},
            {"type": "image_url", "image_url": {"url": image_base64}},
        ]
    else:
        user_content = prompt

    payload = {
        "model": api_conf["model_id"],
        "messages": [
            {
                "role": "user",
                "content": user_content,
            }
        ],
    }

    if isinstance(video_config, dict) and video_config:
        payload["video_config"] = video_config

    # Log a deep copy of the payload with image data replaced by a length
    # marker, so large base64 strings do not end up in the logs.
    payload_for_log = json.loads(json.dumps(payload, ensure_ascii=False))
    try:
        if isinstance(payload_for_log.get("messages"), list):
            for msg in payload_for_log["messages"]:
                content = msg.get("content") if isinstance(msg, dict) else None
                if isinstance(content, list):
                    for item in content:
                        if not isinstance(item, dict):
                            continue
                        if item.get("type") == "image_url":
                            image_url = (item.get("image_url") or {}).get("url", "")
                            if isinstance(image_url, str) and image_url:
                                item["image_url"]["url"] = f"<image_base64_or_url length={len(image_url)}>"
    except Exception:
        pass  # redaction is best-effort and must never block the request

    logger.debug(f"Grok 请求 payload: {json.dumps(payload_for_log, ensure_ascii=False)}")

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}"
    }

    timeout = httpx.Timeout(connect=10.0, read=api_conf["timeout"], write=10.0, pool=10.0)

    proxy = None
    proxy_config = self.config.get("proxy", {})
    if proxy_config.get("enabled", False):
        proxy_type = proxy_config.get("type", "socks5")
        proxy_host = proxy_config.get("host", "127.0.0.1")
        proxy_port = proxy_config.get("port", 7890)
        proxy = f"{proxy_type}://{proxy_host}:{proxy_port}"
        logger.info(f"使用代理请求: {proxy}")

    async with httpx.AsyncClient(timeout=timeout, proxy=proxy) as client:
        response = await client.post(self.api_url, json=payload, headers=headers)
        raw_text = response.text or ""

        logger.debug(
            f"Grok 响应状态: {response.status_code}, content-type: {response.headers.get('content-type', '')}, 长度: {len(raw_text)}"
        )
        logger.debug(f"Grok 原始响应全文: {raw_text}")

        if response.status_code != 200:
            err_text = raw_text[:300]
            raise Exception(f"API 请求失败: {response.status_code}, {err_text}")

        try:
            result = response.json()
        except Exception as e:
            raise Exception(f"API 返回非 JSON: {e}, body={raw_text[:500]}")

        try:
            message_content = result["choices"][0]["message"].get("content", "")
        except Exception:
            message_content = ""

        if isinstance(message_content, (dict, list)):
            content_text = json.dumps(message_content, ensure_ascii=False)
        else:
            content_text = str(message_content or "")

        # Look in the message text first, then the structured response, then
        # the whole response serialised as text.
        video_url = self._extract_video_url(content_text)
        if not video_url:
            video_url = self._extract_video_url_from_obj(result)
        if not video_url:
            video_url = self._extract_video_url(json.dumps(result, ensure_ascii=False))

        if not video_url:
            result_preview = json.dumps(result, ensure_ascii=False)[:800]
            logger.error(f"Grok 响应中未找到视频 URL,响应预览: {result_preview}")
            raise Exception("未从响应中提取到视频 URL")

        logger.info(f"提取到视频 URL: {video_url}")
        return video_url
|
||
|
||
def get_llm_tools(self):
    """Describe the video tool in OpenAI function-calling format.

    Returns ``[]`` when ``[llm_tool].enabled`` is false. Otherwise returns a
    one-element list; the tool name and description can be overridden via
    ``tool_name`` / ``tool_description`` in ``[llm_tool]``. Only ``prompt``
    is required, and no extra properties are allowed.
    """
    tool_conf = self.config.get("llm_tool", {})
    if not tool_conf.get("enabled", True):
        return []

    properties = {
        "prompt": {
            "type": "string",
            "description": "视频生成提示词",
        },
        "image_base64": {
            "type": "string",
            "description": "可选。参考图的 base64(data:image/... 或纯 base64)",
        },
        "aspect_ratio": {
            "type": "string",
            "enum": ["16:9", "9:16", "1:1", "2:3", "3:2"],
            "description": "可选。视频画幅比例",
        },
        "preset": {
            "type": "string",
            "enum": ["fun", "normal", "spicy"],
            "description": "可选。视频风格预设",
        },
    }

    function_spec = {
        "name": tool_conf.get("tool_name", "grok_video_generation"),
        "description": tool_conf.get(
            "tool_description",
            "Grok视频生成工具。可只传 prompt 文生视频,也可额外传 image_base64 图生视频。",
        ),
        "parameters": {
            "type": "object",
            "properties": properties,
            "required": ["prompt"],
            "additionalProperties": False,
        },
    }

    return [{"type": "function", "function": function_spec}]
|
||
|
||
async def execute_llm_tool(self, tool_name: str, arguments: Optional[dict], bot, from_wxid: str) -> Optional[dict]:
    """Run the video tool for an LLM tool call.

    Args:
        tool_name: tool requested by the model. Anything other than the
            configured ``[llm_tool].tool_name`` (default
            ``grok_video_generation``) is ignored.
        arguments: decoded tool arguments. ``prompt`` is required;
            ``image_base64``, ``aspect_ratio``, ``preset``, ``user_wxid``
            and ``is_group`` are optional. ``None`` is treated as ``{}``.
        bot: bot client used to send chat messages.
        from_wxid: conversation the call originated from.

    Returns:
        ``None`` if the tool is not ours. Otherwise a dict with ``success``,
        ``message``, ``already_sent`` and ``no_reply``. Once the request
        reaches ``_enqueue_video_task`` (which messages the chat itself),
        ``already_sent``/``no_reply`` are True regardless of the outcome.
    """
    expected_tool = self.config.get("llm_tool", {}).get("tool_name", "grok_video_generation")
    if tool_name != expected_tool:
        return None

    # Treat a missing/None arguments payload like an empty one instead of
    # crashing with AttributeError on .get().
    arguments = arguments or {}

    prompt = str(arguments.get("prompt") or "").strip()
    if not prompt:
        return {
            "success": False,
            "message": "缺少视频提示词",
            "already_sent": False,
            "no_reply": False,
        }

    # NOTE(review): user_wxid / is_group are not part of the tool schema, so
    # presumably the calling framework injects them. If the model could set
    # them, a prompt could bill another user (or claim an admin's wxid) --
    # confirm where these values come from.
    user_wxid = str(arguments.get("user_wxid") or from_wxid)
    is_group = bool(arguments.get("is_group", False))
    image_base64 = self._normalize_image_base64_input(arguments.get("image_base64"))
    aspect_ratio = str(arguments.get("aspect_ratio") or "").strip()
    preset = str(arguments.get("preset") or "").strip()

    if not self._check_behavior_enabled(from_wxid, is_group):
        return {
            "success": False,
            "message": "当前会话未开启视频生成功能",
            "already_sent": False,
            "no_reply": False,
        }

    video_cfg = self._build_video_config(from_tool=True, aspect_ratio=aspect_ratio, preset=preset)

    ok, msg = await self._enqueue_video_task(
        bot,
        user_wxid=user_wxid,
        from_wxid=from_wxid,
        prompt=prompt,
        is_group=is_group,
        image_base64=image_base64,
        video_config=video_cfg,
    )

    return {
        "success": ok,
        "message": msg,
        "already_sent": True,
        "no_reply": True,
    }
|
||
|
||
|
||
|
||
def _extract_video_url_from_obj(self, data) -> str:
|
||
"""从嵌套对象中提取视频 URL。"""
|
||
|
||
def is_http_url(value: str) -> bool:
|
||
return isinstance(value, str) and value.lower().startswith(("http://", "https://"))
|
||
|
||
def looks_like_video_url(value: str) -> bool:
|
||
if not is_http_url(value):
|
||
return False
|
||
lower = value.lower()
|
||
video_hints = (
|
||
".mp4", ".mov", ".m4v", ".webm", ".m3u8", ".mpd",
|
||
"video", "mime=video", "content-type=video", "content_type=video",
|
||
)
|
||
return any(hint in lower for hint in video_hints)
|
||
|
||
preferred_keys = (
|
||
"video_url", "videoUrl", "video", "url", "src",
|
||
"play_url", "playUrl", "download_url", "downloadUrl",
|
||
)
|
||
|
||
def walk(node):
|
||
if isinstance(node, dict):
|
||
for key in preferred_keys:
|
||
value = node.get(key)
|
||
if isinstance(value, str) and looks_like_video_url(value):
|
||
return value
|
||
for key in preferred_keys:
|
||
value = node.get(key)
|
||
if isinstance(value, str) and is_http_url(value):
|
||
return value
|
||
for value in node.values():
|
||
found = walk(value)
|
||
if found:
|
||
return found
|
||
elif isinstance(node, list):
|
||
for item in node:
|
||
found = walk(item)
|
||
if found:
|
||
return found
|
||
elif isinstance(node, str):
|
||
text = node.strip()
|
||
if "<video" in text.lower():
|
||
extracted = self._extract_video_url(text)
|
||
if extracted:
|
||
return extracted
|
||
if looks_like_video_url(text):
|
||
return text
|
||
if is_http_url(text):
|
||
return text
|
||
return ""
|
||
|
||
return walk(data)
|
||
|
||
|
||
def _extract_video_url(self, content: str) -> str:
|
||
"""从响应内容中提取视频 URL"""
|
||
content = str(content or "")
|
||
|
||
def clean_url(url: str) -> str:
|
||
cleaned = str(url or "").strip()
|
||
cleaned = cleaned.replace("\\/", "/")
|
||
cleaned = cleaned.strip(" \t\r\n\"'`")
|
||
cleaned = cleaned.rstrip("\\")
|
||
cleaned = cleaned.rstrip(">,.;)]}")
|
||
return cleaned
|
||
|
||
variants = [content]
|
||
v_unescape = html.unescape(content)
|
||
if v_unescape and v_unescape not in variants:
|
||
variants.append(v_unescape)
|
||
v_slash = content.replace("\\/", "/")
|
||
if v_slash and v_slash not in variants:
|
||
variants.append(v_slash)
|
||
v_quote = content.replace('\\"', '"').replace("\\'", "'")
|
||
if v_quote and v_quote not in variants:
|
||
variants.append(v_quote)
|
||
v_all = html.unescape(v_slash).replace('\\"', '"').replace("\\'", "'")
|
||
if v_all and v_all not in variants:
|
||
variants.append(v_all)
|
||
|
||
# 尝试从 HTML video 标签提取(兼容转义引号与无引号)
|
||
for variant in variants:
|
||
match = re.search(
|
||
r'<video[^>]*\bsrc\s*=\s*(?:\\?["\"])??([^"\'\s>]+)',
|
||
variant,
|
||
re.IGNORECASE,
|
||
)
|
||
if match:
|
||
url = clean_url(match.group(1))
|
||
if url.lower().startswith(("http://", "https://")):
|
||
return url
|
||
|
||
# 尝试从 JSON src 字段提取
|
||
for variant in variants:
|
||
match = re.search(r'"src"\s*:\s*"([^"]+)"', variant, re.IGNORECASE)
|
||
if match:
|
||
url = clean_url(match.group(1))
|
||
if url.lower().startswith(("http://", "https://")):
|
||
return url
|
||
|
||
# 尝试提取常见视频扩展链接
|
||
for variant in variants:
|
||
match = re.search(
|
||
r'(https?://[^\s<>"\')\]]+\.(?:mp4|mov|m4v|webm|m3u8|mpd)(?:\?[^\s<>"\')\]]*)?)',
|
||
variant,
|
||
re.IGNORECASE,
|
||
)
|
||
if match:
|
||
return clean_url(match.group(1))
|
||
|
||
# 尝试提取 markdown 链接中的 URL
|
||
for variant in variants:
|
||
match = re.search(r'\((https?://[^\s)]+)\)', variant, re.IGNORECASE)
|
||
if match:
|
||
return clean_url(match.group(1))
|
||
|
||
# 尝试提取任意 URL,并优先选择像视频的链接
|
||
for variant in variants:
|
||
candidates = re.findall(r'https?://[^\s<>"\')\]]+', variant, re.IGNORECASE)
|
||
if candidates:
|
||
cleaned = [clean_url(u) for u in candidates]
|
||
for url in cleaned:
|
||
lower = url.lower()
|
||
if any(k in lower for k in (".mp4", ".mov", ".m4v", ".webm", ".m3u8", ".mpd", "video")):
|
||
return url
|
||
return cleaned[0]
|
||
|
||
# 内容是 JSON 文本时,尝试按对象结构提取
|
||
try:
|
||
parsed = json.loads(content)
|
||
from_obj = self._extract_video_url_from_obj(parsed)
|
||
if from_obj:
|
||
return from_obj
|
||
except Exception:
|
||
pass
|
||
|
||
return ""
|
||
|
||
async def _download_video(self, video_url: str) -> str:
    """Download ``video_url`` into ``<plugin dir>/videos/`` and return the absolute path.

    Files are named ``grok_<YYYYmmdd_HHMMSS>_<8 hex>.mp4``; the extension
    is always ``.mp4``, whatever the source format. Redirects are followed
    and the ``[proxy]`` config section is honoured. The body is streamed to
    disk in chunks instead of being held in memory as a whole, and a
    partially written file is removed if the download fails or is
    cancelled.

    Raises:
        httpx.HTTPStatusError: for a non-2xx final response.
        httpx.HTTPError / OSError: for network or file-system errors.
    """
    videos_dir = Path(__file__).parent / "videos"
    videos_dir.mkdir(exist_ok=True)

    filename = f"grok_{datetime.now():%Y%m%d_%H%M%S}_{uuid.uuid4().hex[:8]}.mp4"
    file_path = videos_dir / filename

    # httpx timeouts apply per operation; read=300s bounds each chunk read.
    timeout = httpx.Timeout(connect=10.0, read=300.0, write=10.0, pool=10.0)

    proxy = None
    proxy_config = self.config.get("proxy", {})
    if proxy_config.get("enabled", False):
        proxy_type = proxy_config.get("type", "socks5")
        proxy_host = proxy_config.get("host", "127.0.0.1")
        proxy_port = proxy_config.get("port", 7890)
        proxy = f"{proxy_type}://{proxy_host}:{proxy_port}"
        logger.debug(f"下载视频使用代理: {proxy}")

    try:
        async with httpx.AsyncClient(timeout=timeout, proxy=proxy, follow_redirects=True) as client:
            async with client.stream("GET", video_url) as response:
                response.raise_for_status()
                with open(file_path, "wb") as f:
                    async for chunk in response.aiter_bytes():
                        f.write(chunk)
    except BaseException:
        # Never leave a truncated video behind; re-raise for the caller.
        file_path.unlink(missing_ok=True)
        raise

    return str(file_path.resolve())
|