Files
WeChatHookBot/plugins/GrokVideo/main.py

948 lines
37 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Grok 视频生成插件
用户引用图片并发送 /视频 提示词 来生成视频
支持队列系统和积分制
"""
import asyncio
import html
import json
import re
import tomllib
import uuid
import xml.etree.ElementTree as ET
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Dict, Optional

import httpx
import pymysql
from loguru import logger
from minio import Minio
from minio.error import S3Error

from utils.decorators import on_text_message
from utils.image_processor import ImageProcessor, MediaConfig
from utils.plugin_base import PluginBase
# Decorator for registering quote (reply) message handlers.
def on_quote_message(priority=50):
    """Mark a plugin method as a handler for quote (reply) messages.

    The returned decorator sets ``_event_type`` to ``"quote_message"`` and
    ``_priority`` to ``priority`` clamped into 0..99 on the function, then
    returns the very same function object.
    """
    def mark(handler):
        handler._event_type = 'quote_message'
        handler._priority = max(0, min(priority, 99))
        return handler

    return mark
@dataclass
class VideoTask:
    """A queued video generation request.

    Holds the requester, the conversation to reply to, the prompt, an
    optional quoted image (CDN URL + AES key, or an already-encoded
    ``image_base64``) and the ``video_config`` forwarded to the API.
    """

    user_wxid: str  # wxid of the requester (the account charged points)
    from_wxid: str  # conversation the replies are sent to
    prompt: str
    cdnurl: str = ""  # CDN URL of a quoted image, if any
    aeskey: str = ""  # AES key paired with ``cdnurl``
    is_group: bool = False
    # default_factory: a plain ``datetime.now()`` default is evaluated once
    # when the class is defined, so every task shared that same timestamp.
    timestamp: datetime = field(default_factory=datetime.now)
    image_base64: str = ""  # data URL, http(s) URL or raw base64 image
    video_config: Optional[dict] = None
class GrokVideo(PluginBase):
    """Grok video generation plugin.

    Users send ``/视频 <prompt>`` as plain text or as a quote of an image; the
    LLM can also call the tool described by :meth:`get_llm_tools`. Requests
    are charged points (when enabled), placed on a bounded ``asyncio.Queue``
    and processed by background worker coroutines that call the Grok
    chat-completions endpoint, download the resulting video, archive it to
    MinIO and send it back to the conversation.
    """

    description = "使用 Grok AI 根据提示词生成视频(可选图片,支持队列和积分系统)"
    author = "ShiHao"
    version = "2.0.0"

    # NOTE(review): security - these MinIO settings (including the secret key)
    # used to be hard-coded. They are now only fallbacks for a missing
    # [minio] section in config.toml; move them into config and rotate the key.
    _DEFAULT_MINIO_ENDPOINT = "115.190.113.141:19000"
    _DEFAULT_MINIO_ACCESS_KEY = "admin"
    _DEFAULT_MINIO_SECRET_KEY = "80012029Lz"
    _DEFAULT_MINIO_BUCKET = "wechat"

    def __init__(self):
        super().__init__()
        self.config = None
        self.api_url = ""
        self.task_queue: Optional[asyncio.Queue] = None
        # Tasks currently held by a worker, keyed by an internal per-task id.
        self.processing_tasks: Dict[str, VideoTask] = {}
        # ``worker_task`` is the first worker (kept for compatibility);
        # ``worker_tasks`` holds every worker started by async_init.
        self.worker_task = None
        self.worker_tasks: list = []
        self._max_concurrent = 1
        self.minio_client = None
        self.minio_bucket = ""
        self._minio_public_base = ""
        self._image_processor = None

    async def async_init(self):
        """Load config.toml, set up MinIO / image helpers and start the workers."""
        config_path = Path(__file__).parent / "config.toml"
        with open(config_path, "rb") as f:
            self.config = tomllib.load(f)

        self.api_url = f"{self.config['api']['server_url'].rstrip('/')}/v1/chat/completions"

        # MinIO client; uploaded URLs are only used for MessageLogger records.
        minio_conf = self.config.get("minio", {})
        endpoint = minio_conf.get("endpoint", self._DEFAULT_MINIO_ENDPOINT)
        secure = bool(minio_conf.get("secure", False))
        self.minio_client = Minio(
            endpoint,
            access_key=minio_conf.get("access_key", self._DEFAULT_MINIO_ACCESS_KEY),
            secret_key=minio_conf.get("secret_key", self._DEFAULT_MINIO_SECRET_KEY),
            secure=secure,
        )
        self.minio_bucket = minio_conf.get("bucket", self._DEFAULT_MINIO_BUCKET)
        scheme = "https" if secure else "http"
        self._minio_public_base = str(
            minio_conf.get("public_url", f"{scheme}://{endpoint}")
        ).rstrip("/")

        temp_dir = Path(__file__).parent / "temp"
        temp_dir.mkdir(exist_ok=True)
        self._image_processor = ImageProcessor(MediaConfig(), temp_dir=temp_dir)

        queue_conf = self.config.get("queue", {})
        max_queue_size = queue_conf.get("max_queue_size", 10)
        self.task_queue = asyncio.Queue(maxsize=max_queue_size)

        # One worker per concurrency slot (previously max_concurrent was only
        # logged and a single worker was started regardless).
        max_concurrent = max(1, int(queue_conf.get("max_concurrent", 1) or 1))
        self._max_concurrent = max_concurrent
        self.worker_tasks = [
            asyncio.create_task(self._queue_worker()) for _ in range(max_concurrent)
        ]
        self.worker_task = self.worker_tasks[0]

        logger.success(f"Grok 视频生成插件已加载")
        logger.info(f"API: {self.api_url}")
        logger.info(f"队列配置: 最大并发={max_concurrent}, 最大队列长度={max_queue_size}")
        logger.info(f"积分系统: {'启用' if self.config.get('points', {}).get('enabled') else '禁用'}")
        if self.config.get('points', {}).get('enabled'):
            logger.info(f"每次生成消耗: {self.config['points']['cost']} 积分")

    def get_db_connection(self):
        """Open a new autocommit PyMySQL connection from the [database] config."""
        db_config = self.config["database"]
        return pymysql.connect(
            host=db_config["host"],
            port=db_config["port"],
            user=db_config["user"],
            password=db_config["password"],
            database=db_config["database"],
            charset=db_config["charset"],
            autocommit=True
        )

    def get_user_points(self, wxid: str) -> int:
        """Return the user's points; 0 if the user has no row or the query fails."""
        try:
            with self.get_db_connection() as conn:
                with conn.cursor() as cursor:
                    sql = "SELECT points FROM user_signin WHERE wxid = %s"
                    cursor.execute(sql, (wxid,))
                    result = cursor.fetchone()
                    return result[0] if result else 0
        except Exception as e:
            logger.error(f"获取用户积分失败: {e}")
            return 0

    def deduct_points(self, wxid: str, points: int) -> bool:
        """Deduct ``points`` from ``wxid`` if the balance is sufficient.

        The balance check and the subtraction are a single conditional UPDATE,
        so two concurrent requests cannot both pass a stale check.

        Returns:
            True if a row was updated, False if the user is missing, the
            balance is too low, or the query fails. NOTE(review): assumes
            ``points > 0`` - with MySQL's default affected-rows semantics a
            zero deduction changes no row and reports False.
        """
        try:
            with self.get_db_connection() as conn:
                with conn.cursor() as cursor:
                    sql = (
                        "UPDATE user_signin SET points = points - %s "
                        "WHERE wxid = %s AND points >= %s"
                    )
                    # PyMySQL's execute() returns the number of affected rows.
                    affected = cursor.execute(sql, (points, wxid, points))
                    if not affected:
                        return False
                    logger.info(f"用户 {wxid} 扣除 {points} 积分")
                    return True
        except Exception as e:
            logger.error(f"扣除用户积分失败: {e}")
            return False

    def is_admin(self, wxid: str) -> bool:
        """Return True if ``wxid`` is listed in ``points.admins``."""
        admins = self.config.get("points", {}).get("admins", [])
        return wxid in admins

    async def upload_video_to_minio(self, local_file: str, original_filename: str = "") -> str:
        """Upload a local video file to MinIO and return its public URL.

        Best-effort: returns "" on any failure, because the URL is only used
        for MessageLogger and must not stop the video from being sent.
        """
        if not self.minio_client:
            return ""
        try:
            file_ext = Path(local_file).suffix
            unique_id = uuid.uuid4().hex
            if original_filename:
                # Keep the (sanitised) original stem so objects stay recognisable.
                original_name = re.sub(r'[^\w\-_\.]', '_', Path(original_filename).stem)
                filename = f"{original_name}_{unique_id}{file_ext}"
            else:
                filename = f"grok_video_{unique_id}{file_ext}"
            object_name = f"videos/{datetime.now().strftime('%Y%m%d')}/{filename}"

            # The MinIO SDK is synchronous; run it off the event loop.
            await asyncio.to_thread(
                self.minio_client.fput_object,
                self.minio_bucket,
                object_name,
                local_file
            )

            base = self._minio_public_base or f"http://{self._DEFAULT_MINIO_ENDPOINT}"
            url = f"{base}/{self.minio_bucket}/{object_name}"
            logger.info(f"视频上传成功: {url}")
            return url
        except Exception as e:
            # Covers S3Error as well as connection errors from the HTTP layer.
            logger.error(f"上传视频到MinIO失败: {e}")
            return ""

    async def _queue_worker(self):
        """Consume queued ``(bot, task, points_cost)`` items forever."""
        logger.info("视频生成队列工作线程已启动")
        while True:
            item = await self.task_queue.get()
            try:
                bot, task, points_cost = item
                # Registered with no await after get(), so position estimates
                # in _enqueue_video_task never see a half-claimed item.
                task_id = uuid.uuid4().hex
                self.processing_tasks[task_id] = task
                try:
                    await self._process_video_task(bot, task, points_cost=points_cost)
                finally:
                    self.processing_tasks.pop(task_id, None)
            except Exception as e:
                logger.error(f"队列工作线程错误: {e}")
                await asyncio.sleep(1)
            finally:
                # Always balance get(), otherwise Queue.join() would hang.
                self.task_queue.task_done()

    async def _process_video_task(self, bot, task: VideoTask, points_cost: int = 0):
        """Generate, download, archive and send one video.

        Args:
            bot: Bot client used for ``send_text`` / ``send_media``.
            task: The queued request.
            points_cost: Points charged at enqueue time. They are refunded
                unless the video was actually delivered.
        """
        logger.info(f"开始处理视频任务: user={task.user_wxid}, prompt={task.prompt}")
        delivered = False
        video_path = ""
        try:
            image_base64 = (task.image_base64 or "").strip()
            if not image_base64 and task.cdnurl and task.aeskey:
                image_base64 = await self._download_and_encode_image(bot, task.cdnurl, task.aeskey)
                if not image_base64:
                    await bot.send_text(task.from_wxid, "❌ 图片下载失败,请稍后重试")
                    return

            video_url = await self._generate_video_url(task, image_base64)
            if not video_url:
                await bot.send_text(task.from_wxid, "❌ 未从接口响应中获取到视频地址")
                return

            video_path = await self._download_video(video_url)
            if not video_path:
                await bot.send_text(task.from_wxid, "❌ 视频下载失败,请稍后重试")
                return

            logger.info(f"准备上传视频到 MinIO: {video_path}")
            minio_url = await self.upload_video_to_minio(video_path, Path(video_path).name)

            logger.info(f"准备发送视频到微信: {video_path}")
            video_sent = await bot.send_media(task.from_wxid, video_path, media_type="video")
            if not video_sent:
                await bot.send_text(task.from_wxid, "❌ 视频发送失败,请稍后重试")
                logger.error(f"视频发送失败: {video_path}")
                return
            delivered = True

            await self._log_bot_video(task, minio_url)
            await bot.send_text(task.from_wxid, await self._build_success_message(task, points_cost))
        except Exception as e:
            logger.error(f"处理视频任务失败: {e}")
            await bot.send_text(task.from_wxid, f"❌ 视频生成失败: {str(e)}")
        finally:
            if video_path:
                self._cleanup_local_file(video_path)
            if not delivered and points_cost > 0:
                await asyncio.to_thread(self._refund_points, task.user_wxid, points_cost)

    async def _generate_video_url(self, task: VideoTask, image_base64: str) -> str:
        """Call the API; on an upstream/500 error with an image, retry text-only."""
        try:
            return await self._call_grok_api(
                task.prompt,
                image_base64=image_base64,
                video_config=task.video_config,
            )
        except Exception as e:
            if image_base64 and ("upstream_error" in str(e) or "500" in str(e)):
                logger.warning(f"携带图片参数调用失败,自动降级文生视频重试: {e}")
                return await self._call_grok_api(
                    task.prompt,
                    image_base64="",
                    video_config=task.video_config,
                )
            raise

    async def _log_bot_video(self, task: VideoTask, minio_url: str) -> None:
        """Best-effort record of the sent video via the MessageLogger plugin."""
        try:
            from plugins.MessageLogger.main import MessageLogger
            message_logger = MessageLogger.get_instance()
            if message_logger and minio_url:
                await message_logger.save_bot_message(
                    task.from_wxid,
                    f"[视频] {task.prompt}",
                    "video",
                    minio_url
                )
                logger.info(f"视频消息已写入 MessageLogger: {minio_url}")
        except Exception as e:
            logger.warning(f"记录视频消息到 MessageLogger 失败: {e}")

    async def _build_success_message(self, task: VideoTask, points_cost: int) -> str:
        """Build the success reply, including points info when points are enabled."""
        points_config = self.config.get("points", {})
        if not points_config.get("enabled", False):
            return "✅ 视频生成成功"
        if self.is_admin(task.user_wxid):
            return "✅ 视频生成成功\n🎟️ 管理员免费使用"
        # Prefer the amount actually charged at enqueue time.
        cost = points_cost or points_config.get("cost", 50)
        remaining_points = await asyncio.to_thread(self.get_user_points, task.user_wxid)
        return f"✅ 视频生成成功\n💰 本次消费:{cost} 积分\n💎 剩余积分:{remaining_points}"

    @staticmethod
    def _cleanup_local_file(video_path: str) -> None:
        """Delete a downloaded video, logging (not raising) on failure."""
        try:
            Path(video_path).unlink()
            logger.info(f"已清理本地视频文件: {video_path}")
        except Exception as e:
            logger.warning(f"清理本地视频文件失败: {e}")

    def _check_behavior_enabled(self, from_wxid: str, is_group: bool) -> bool:
        """Apply the [behavior] switch and the group allow/deny lists."""
        behavior = self.config.get("behavior", {})
        if not behavior.get("enabled", True):
            return False
        if not is_group:
            return True
        enabled_groups = behavior.get("enabled_groups", [])
        disabled_groups = behavior.get("disabled_groups", [])
        if from_wxid in disabled_groups:
            return False
        # A non-empty allow list restricts the feature to those groups.
        if enabled_groups and from_wxid not in enabled_groups:
            return False
        return True

    def _build_video_config(self, *, from_tool: bool, aspect_ratio: str = "", preset: str = "") -> dict:
        """Build the ``video_config`` payload from config plus optional tool args.

        Length is clamped to 5..15, resolution is SD/HD, ``aspect_ratio`` is
        included only if whitelisted, and ``preset`` comes from the tool call
        (when valid) or the configured ``command_preset``.
        """
        video_conf = self.config.get("video_generation", {})
        try:
            length = int(video_conf.get("fixed_video_length", 6) or 6)
        except (TypeError, ValueError):
            length = 6
        length = max(5, min(15, length))
        resolution = str(video_conf.get("fixed_resolution", "SD") or "SD").upper()
        if resolution not in {"SD", "HD"}:
            resolution = "SD"
        cfg = {
            "video_length": length,
            "resolution": resolution,
        }
        allowed_aspect_ratios = {"16:9", "9:16", "1:1", "2:3", "3:2"}
        ar = (aspect_ratio or "").strip()
        if ar in allowed_aspect_ratios:
            cfg["aspect_ratio"] = ar
        allowed_presets = {"fun", "normal", "spicy"}
        default_preset = str(video_conf.get("command_preset", "normal") or "normal").strip().lower()
        if default_preset not in allowed_presets:
            default_preset = "normal"
        if from_tool:
            tool_preset = (preset or "").strip().lower()
            cfg["preset"] = tool_preset if tool_preset in allowed_presets else default_preset
        else:
            cfg["preset"] = default_preset
        return cfg

    def _normalize_image_base64_input(self, raw: str) -> str:
        """Sanitise an image argument so invalid values don't trigger an upstream 500.

        Returns data URLs and http(s) URLs unchanged, wraps long pure base64
        strings as a PNG data URL, and returns "" for anything else.
        """
        value = str(raw or "").strip()
        if not value:
            return ""
        if value.lower() in {"null", "none", "nil", "undefined", "n/a", "", ""}:
            return ""
        if value.startswith("data:image/"):
            return value
        if value.startswith(("http://", "https://")):
            return value
        # Accept bare base64 and turn it into a data URL.
        compact = re.sub(r"\s+", "", value)
        if len(compact) >= 256 and re.fullmatch(r"[A-Za-z0-9+/=]+", compact):
            return f"data:image/png;base64,{compact}"
        logger.warning("检测到无效 image_base64 参数,已忽略该字段")
        return ""

    def _refund_points(self, user_wxid: str, amount: int):
        """Add ``amount`` points back to ``user_wxid``; no-op for amount <= 0."""
        if amount <= 0:
            return
        try:
            with self.get_db_connection() as conn:
                with conn.cursor() as cursor:
                    sql = "UPDATE user_signin SET points = points + %s WHERE wxid = %s"
                    cursor.execute(sql, (amount, user_wxid))
            logger.info(f"积分已回退: {user_wxid}, +{amount}")
        except Exception as e:
            logger.error(f"回退积分失败: {e}")

    async def _enqueue_video_task(
        self,
        bot,
        *,
        user_wxid: str,
        from_wxid: str,
        prompt: str,
        is_group: bool,
        cdnurl: str = "",
        aeskey: str = "",
        image_base64: str = "",
        video_config: Optional[dict] = None,
    ) -> tuple[bool, str]:
        """Charge points (if enabled) and queue a video task.

        Once the task is on the queue, the charged points travel with it and
        the worker refunds them if the video is not delivered.

        Returns:
            ``(ok, message)`` where ``message`` is a short status string.
        """
        points_config = self.config.get("points", {})
        deducted_cost = 0
        if points_config.get("enabled", False):
            if not self.is_admin(user_wxid):
                cost = int(points_config.get("cost", 50) or 50)
                # PyMySQL is blocking; keep it off the event loop.
                current_points = await asyncio.to_thread(self.get_user_points, user_wxid)
                if current_points < cost:
                    await bot.send_text(
                        from_wxid,
                        f"❌ 积分不足\n💎 当前积分:{current_points}\n💸 需要积分:{cost}\n\n请先签到或获取积分后再试~"
                    )
                    return False, "积分不足"
                if not await asyncio.to_thread(self.deduct_points, user_wxid, cost):
                    await bot.send_text(from_wxid, "❌ 扣除积分失败,请稍后重试")
                    return False, "扣分失败"
                deducted_cost = cost
                logger.info(f"用户 {user_wxid} 已扣除 {cost} 积分")
            else:
                logger.info(f"管理员用户 {user_wxid} 免费使用")

        task = VideoTask(
            user_wxid=user_wxid,
            from_wxid=from_wxid,
            prompt=prompt,
            cdnurl=cdnurl,
            aeskey=aeskey,
            is_group=is_group,
            timestamp=datetime.now(),
            image_base64=image_base64,
            video_config=video_config or {},
        )

        try:
            # put_nowait: never block a message handler on a full queue; the
            # awaits above mean the queue may have filled up meanwhile.
            self.task_queue.put_nowait((bot, task, deducted_cost))
        except asyncio.QueueFull:
            await bot.send_text(from_wxid, f"❌ 当前队列已满({self.task_queue.qsize()}/{self.task_queue.maxsize}),请稍后再试")
            if deducted_cost > 0:
                await asyncio.to_thread(self._refund_points, user_wxid, deducted_cost)
            return False, "队列已满"
        except Exception as e:
            logger.error(f"视频任务入队失败: {e}")
            await bot.send_text(from_wxid, "❌ 任务提交失败,请稍后重试")
            if deducted_cost > 0:
                await asyncio.to_thread(self._refund_points, user_wxid, deducted_cost)
            return False, f"任务提交失败: {e}"

        # Items still on the queue that idle workers will pick up immediately
        # are not "waiting"; anything beyond that waits behind running tasks.
        idle_workers = max(0, self._max_concurrent - len(self.processing_tasks))
        queue_position = max(0, self.task_queue.qsize() - idle_workers)
        try:
            if queue_position == 0:
                await bot.send_text(from_wxid, "⏳ 任务已提交,正在为你生成视频,请稍候...")
            else:
                await bot.send_text(from_wxid, f"⏳ 任务已加入队列\n📍 当前排队位置:第 {queue_position}\n🚀 正在加速处理中...")
        except Exception as e:
            # The task is already queued and owns the charged points; refunding
            # here (as before) would let the video be generated for free.
            logger.warning(f"视频任务入队通知发送失败: {e}")
        logger.success(f"视频任务入队成功: user={user_wxid}, position={queue_position}")
        return True, "任务已提交"

    @on_text_message(priority=90)
    async def handle_video_text_command(self, bot, message: dict):
        """Handle a plain-text ``/视频 <prompt>`` command (text-to-video).

        Returns True to let other handlers see non-matching or disabled
        messages, False once the command has been consumed.
        """
        content = (message.get("Content", "") or "").strip()
        if not content.startswith("/视频"):
            return True
        from_wxid = message.get("FromWxid", "")
        sender_wxid = message.get("SenderWxid", "")
        is_group = message.get("IsGroup", False)
        user_wxid = sender_wxid if is_group else from_wxid
        if not self._check_behavior_enabled(from_wxid, is_group):
            return True
        prompt = content[3:].strip()  # len("/视频") == 3
        if not prompt:
            await bot.send_text(from_wxid, "❌ 请输入提示词,格式:/视频 提示词")
            return False
        video_config = self._build_video_config(from_tool=False)
        await self._enqueue_video_task(
            bot,
            user_wxid=user_wxid,
            from_wxid=from_wxid,
            prompt=prompt,
            is_group=is_group,
            video_config=video_config,
        )
        return False

    @on_quote_message(priority=90)
    async def handle_video_command(self, bot, message: dict):
        """Handle ``/视频 <prompt>`` sent as a quote (reply) message.

        If the quoted message contains an ``<img>`` with both ``cdnbigimgurl``
        and ``aeskey`` the task is image-to-video, otherwise text-to-video.
        NOTE(review): returns None for non-matching messages and False when
        the feature is disabled, unlike the text handler which returns True
        there - confirm the intended propagation semantics.
        """
        content = (message.get("Content", "") or "").strip()
        from_wxid = message.get("FromWxid", "")
        sender_wxid = message.get("SenderWxid", "")
        is_group = message.get("IsGroup", False)
        user_wxid = sender_wxid if is_group else from_wxid
        try:
            root = ET.fromstring(self._strip_xml_prefix(content))
            title = root.find(".//title")
            if title is None or not title.text:
                return
            title_text = title.text.strip()
            if not title_text.startswith("/视频"):
                return
            if not self._check_behavior_enabled(from_wxid, is_group):
                return False
            prompt = title_text[3:].strip()  # len("/视频") == 3
            if not prompt:
                await bot.send_text(from_wxid, "❌ 请输入提示词,格式:/视频 提示词")
                return False

            cdnbigimgurl, aeskey = self._extract_quoted_image(root)
            if cdnbigimgurl and aeskey:
                logger.info("检测到引用图片,按图生视频处理")
            else:
                cdnbigimgurl = ""
                aeskey = ""
                logger.info("未检测到可用引用图片,按文生视频处理")

            video_config = self._build_video_config(from_tool=False)
            await self._enqueue_video_task(
                bot,
                user_wxid=user_wxid,
                from_wxid=from_wxid,
                prompt=prompt,
                is_group=is_group,
                cdnurl=cdnbigimgurl,
                aeskey=aeskey,
                video_config=video_config,
            )
        except Exception as e:
            logger.error(f"解析引用消息失败: {e}")
            return
        return False

    @staticmethod
    def _strip_xml_prefix(text: str) -> str:
        """Drop a leading BOM and any text before the XML when a ``:\\n`` prefix is present."""
        text = text.lstrip("\ufeff")
        if ":\n" in text:
            xml_start = text.find("<?xml")
            if xml_start == -1:
                xml_start = text.find("<msg")
            if xml_start > 0:
                text = text[xml_start:]
        return text

    def _extract_quoted_image(self, root) -> tuple[str, str]:
        """Return ``(cdnbigimgurl, aeskey)`` of an image in ``refermsg``, or ``("", "")``."""
        cdnbigimgurl = ""
        aeskey = ""
        refermsg = root.find(".//refermsg")
        if refermsg is None:
            return cdnbigimgurl, aeskey
        refer_content = refermsg.find("content")
        if refer_content is None or not refer_content.text:
            return cdnbigimgurl, aeskey
        try:
            # The quoted message is HTML-escaped XML nested inside <content>.
            refer_root = ET.fromstring(self._strip_xml_prefix(html.unescape(refer_content.text)))
            img = refer_root.find(".//img")
            if img is not None:
                cdnbigimgurl = img.get("cdnbigimgurl", "")
                aeskey = img.get("aeskey", "")
        except Exception as e:
            logger.warning(f"解析引用内容失败,降级为文生视频: {e}")
        return cdnbigimgurl, aeskey

    async def _download_and_encode_image(self, bot, cdnurl: str, aeskey: str) -> str:
        """Download a CDN image via ImageProcessor and return it base64-encoded ("" if unavailable)."""
        if not self._image_processor:
            logger.warning("ImageProcessor 未初始化,无法下载图片")
            return ""
        logger.info(f"正在下载图片: {cdnurl[:50]}...")
        return await self._image_processor.download_image_by_cdn(bot, cdnurl, aeskey)

    def _build_proxy_url(self) -> Optional[str]:
        """Return ``type://host:port`` from the [proxy] config, or None when disabled."""
        proxy_config = self.config.get("proxy", {})
        if not proxy_config.get("enabled", False):
            return None
        proxy_type = proxy_config.get("type", "socks5")
        proxy_host = proxy_config.get("host", "127.0.0.1")
        proxy_port = proxy_config.get("port", 7890)
        return f"{proxy_type}://{proxy_host}:{proxy_port}"

    @staticmethod
    def _redact_payload_for_log(payload: dict) -> dict:
        """Deep-copy ``payload`` with image URLs replaced by a length placeholder."""
        payload_for_log = json.loads(json.dumps(payload, ensure_ascii=False))
        try:
            if isinstance(payload_for_log.get("messages"), list):
                for msg in payload_for_log["messages"]:
                    content = msg.get("content") if isinstance(msg, dict) else None
                    if not isinstance(content, list):
                        continue
                    for item in content:
                        if not isinstance(item, dict) or item.get("type") != "image_url":
                            continue
                        image_url = (item.get("image_url") or {}).get("url", "")
                        if isinstance(image_url, str) and image_url:
                            item["image_url"]["url"] = f"<image_base64_or_url length={len(image_url)}>"
        except Exception:
            # Redaction is best-effort and must never break the request.
            pass
        return payload_for_log

    async def _call_grok_api(self, prompt: str, image_base64: str = "", video_config: Optional[dict] = None) -> str:
        """Call the Grok chat-completions endpoint and return the video URL.

        Raises:
            Exception: if no API key is configured, the HTTP status is not
                200, the body is not JSON, or no video URL can be extracted.
        """
        api_key = self.config["api"].get("api_key", "")
        if not api_key:
            raise Exception("未配置 API Key")

        user_content = prompt
        image_base64 = self._normalize_image_base64_input(image_base64)
        if image_base64:
            # Multimodal message: text prompt plus reference image.
            user_content = [
                {"type": "text", "text": prompt},
                {"type": "image_url", "image_url": {"url": image_base64}}
            ]
        payload = {
            "model": self.config["api"]["model_id"],
            "messages": [
                {
                    "role": "user",
                    "content": user_content,
                }
            ]
        }
        if isinstance(video_config, dict) and video_config:
            payload["video_config"] = video_config

        payload_for_log = self._redact_payload_for_log(payload)
        logger.debug(f"Grok 请求 payload: {json.dumps(payload_for_log, ensure_ascii=False)}")

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}"
        }
        timeout = httpx.Timeout(connect=10.0, read=self.config["api"]["timeout"], write=10.0, pool=10.0)
        proxy = self._build_proxy_url()
        if proxy:
            logger.info(f"使用代理请求: {proxy}")

        async with httpx.AsyncClient(timeout=timeout, proxy=proxy) as client:
            response = await client.post(self.api_url, json=payload, headers=headers)
        raw_text = response.text or ""
        logger.debug(
            f"Grok 响应状态: {response.status_code}, content-type: {response.headers.get('content-type', '')}, 长度: {len(raw_text)}"
        )
        logger.debug(f"Grok 原始响应全文: {raw_text}")
        if response.status_code != 200:
            err_text = raw_text[:300]
            raise Exception(f"API 请求失败: {response.status_code}, {err_text}")
        try:
            result = response.json()
        except Exception as e:
            raise Exception(f"API 返回非 JSON: {e}, body={raw_text[:500]}")

        try:
            message_content = result["choices"][0]["message"].get("content", "")
        except Exception:
            message_content = ""
        if isinstance(message_content, (dict, list)):
            content_text = json.dumps(message_content, ensure_ascii=False)
        else:
            content_text = str(message_content or "")

        # Prefer the assistant message, then the structured result, then raw JSON text.
        video_url = self._extract_video_url(content_text)
        if not video_url:
            video_url = self._extract_video_url_from_obj(result)
        if not video_url:
            video_url = self._extract_video_url(json.dumps(result, ensure_ascii=False))
        if not video_url:
            result_preview = json.dumps(result, ensure_ascii=False)[:800]
            logger.error(f"Grok 响应中未找到视频 URL响应预览: {result_preview}")
            raise Exception("未从响应中提取到视频 URL")
        logger.info(f"提取到视频 URL: {video_url}")
        return video_url

    def get_llm_tools(self):
        """Return the OpenAI-style function tool definition, or [] when disabled."""
        llm_tool_conf = self.config.get("llm_tool", {})
        if not llm_tool_conf.get("enabled", True):
            return []
        tool_name = llm_tool_conf.get("tool_name", "grok_video_generation")
        tool_desc = llm_tool_conf.get(
            "tool_description",
            "Grok视频生成工具。可只传 prompt 文生视频,也可额外传 image_base64 图生视频。"
        )
        return [{
            "type": "function",
            "function": {
                "name": tool_name,
                "description": tool_desc,
                "parameters": {
                    "type": "object",
                    "properties": {
                        "prompt": {
                            "type": "string",
                            "description": "视频生成提示词"
                        },
                        "image_base64": {
                            "type": "string",
                            "description": "可选。参考图的 base64data:image/... 或纯 base64"
                        },
                        "aspect_ratio": {
                            "type": "string",
                            "enum": ["16:9", "9:16", "1:1", "2:3", "3:2"],
                            "description": "可选。视频画幅比例"
                        },
                        "preset": {
                            "type": "string",
                            "enum": ["fun", "normal", "spicy"],
                            "description": "可选。视频风格预设"
                        }
                    },
                    "required": ["prompt"],
                    "additionalProperties": False
                }
            }
        }]

    async def execute_llm_tool(self, tool_name: str, arguments: dict, bot, from_wxid: str) -> Optional[dict]:
        """Run the video tool for an LLM call; returns None for other tool names.

        NOTE(review): ``user_wxid`` (the account charged points) is taken from
        ``arguments`` when present - confirm the caller injects it rather than
        the model choosing it.
        """
        expected_tool = self.config.get("llm_tool", {}).get("tool_name", "grok_video_generation")
        if tool_name != expected_tool:
            return None
        prompt = str(arguments.get("prompt") or "").strip()
        if not prompt:
            return {
                "success": False,
                "message": "缺少视频提示词",
                "already_sent": False,
                "no_reply": False,
            }
        user_wxid = str(arguments.get("user_wxid") or from_wxid)
        is_group = bool(arguments.get("is_group", False))
        image_base64 = self._normalize_image_base64_input(arguments.get("image_base64"))
        aspect_ratio = str(arguments.get("aspect_ratio") or "").strip()
        preset = str(arguments.get("preset") or "").strip()
        if not self._check_behavior_enabled(from_wxid, is_group):
            return {
                "success": False,
                "message": "当前会话未开启视频生成功能",
                "already_sent": False,
                "no_reply": False,
            }
        video_cfg = self._build_video_config(from_tool=True, aspect_ratio=aspect_ratio, preset=preset)
        ok, msg = await self._enqueue_video_task(
            bot,
            user_wxid=user_wxid,
            from_wxid=from_wxid,
            prompt=prompt,
            is_group=is_group,
            image_base64=image_base64,
            video_config=video_cfg,
        )
        # User-facing messages were already sent by _enqueue_video_task.
        return {
            "success": ok,
            "message": msg,
            "already_sent": True,
            "no_reply": True,
        }

    def _extract_video_url_from_obj(self, data) -> str:
        """Find a video URL anywhere in a nested dict/list/str structure.

        In dicts, preferred keys holding a video-looking URL win, then
        preferred keys holding any http(s) URL, then a depth-first search of
        all values. Returns "" when nothing is found.
        """
        def is_http_url(value: str) -> bool:
            return isinstance(value, str) and value.lower().startswith(("http://", "https://"))

        def looks_like_video_url(value: str) -> bool:
            if not is_http_url(value):
                return False
            lower = value.lower()
            video_hints = (
                ".mp4", ".mov", ".m4v", ".webm", ".m3u8", ".mpd",
                "video", "mime=video", "content-type=video", "content_type=video",
            )
            return any(hint in lower for hint in video_hints)

        preferred_keys = (
            "video_url", "videoUrl", "video", "url", "src",
            "play_url", "playUrl", "download_url", "downloadUrl",
        )

        def walk(node):
            if isinstance(node, dict):
                for key in preferred_keys:
                    value = node.get(key)
                    if isinstance(value, str) and looks_like_video_url(value):
                        return value
                for key in preferred_keys:
                    value = node.get(key)
                    if isinstance(value, str) and is_http_url(value):
                        return value
                for value in node.values():
                    found = walk(value)
                    if found:
                        return found
            elif isinstance(node, list):
                for item in node:
                    found = walk(item)
                    if found:
                        return found
            elif isinstance(node, str):
                text = node.strip()
                if "<video" in text.lower():
                    extracted = self._extract_video_url(text)
                    if extracted:
                        return extracted
                if looks_like_video_url(text):
                    return text
                if is_http_url(text):
                    return text
            return ""

        return walk(data)

    def _extract_video_url(self, content: str) -> str:
        """Extract a video URL from response text.

        Tries, over several unescaped variants of ``content``: a ``<video
        src=...>`` tag, a JSON ``"src"`` field, a URL with a video extension,
        a markdown link, any URL (preferring video-looking ones), and finally
        parsing ``content`` as JSON. Returns "" when nothing matches.
        """
        content = str(content or "")

        def clean_url(url: str) -> str:
            cleaned = str(url or "").strip()
            cleaned = cleaned.replace("\\/", "/")
            cleaned = cleaned.strip(" \t\r\n\"'`")
            cleaned = cleaned.rstrip("\\")
            cleaned = cleaned.rstrip(">,.;)]}")
            return cleaned

        variants = [content]
        v_unescape = html.unescape(content)
        if v_unescape and v_unescape not in variants:
            variants.append(v_unescape)
        v_slash = content.replace("\\/", "/")
        if v_slash and v_slash not in variants:
            variants.append(v_slash)
        v_quote = content.replace('\\"', '"').replace("\\'", "'")
        if v_quote and v_quote not in variants:
            variants.append(v_quote)
        v_all = html.unescape(v_slash).replace('\\"', '"').replace("\\'", "'")
        if v_all and v_all not in variants:
            variants.append(v_all)

        # <video src=...>: double, single, backslash-escaped or no quotes.
        # (The quote class used to be ["\"], which never matched single quotes.)
        for variant in variants:
            match = re.search(
                r'<video[^>]*\bsrc\s*=\s*(?:\\?["\'])??([^"\'\s>]+)',
                variant,
                re.IGNORECASE,
            )
            if match:
                url = clean_url(match.group(1))
                if url.lower().startswith(("http://", "https://")):
                    return url

        # JSON "src" field.
        for variant in variants:
            match = re.search(r'"src"\s*:\s*"([^"]+)"', variant, re.IGNORECASE)
            if match:
                url = clean_url(match.group(1))
                if url.lower().startswith(("http://", "https://")):
                    return url

        # URLs with a common video extension.
        for variant in variants:
            match = re.search(
                r'(https?://[^\s<>"\')\]]+\.(?:mp4|mov|m4v|webm|m3u8|mpd)(?:\?[^\s<>"\')\]]*)?)',
                variant,
                re.IGNORECASE,
            )
            if match:
                return clean_url(match.group(1))

        # Markdown link target.
        for variant in variants:
            match = re.search(r'\((https?://[^\s)]+)\)', variant, re.IGNORECASE)
            if match:
                return clean_url(match.group(1))

        # Any URL, preferring ones that look like video.
        for variant in variants:
            candidates = re.findall(r'https?://[^\s<>"\')\]]+', variant, re.IGNORECASE)
            if candidates:
                cleaned = [clean_url(u) for u in candidates]
                for url in cleaned:
                    lower = url.lower()
                    if any(k in lower for k in (".mp4", ".mov", ".m4v", ".webm", ".m3u8", ".mpd", "video")):
                        return url
                return cleaned[0]

        # The content may be JSON text; search its structure.
        try:
            parsed = json.loads(content)
            from_obj = self._extract_video_url_from_obj(parsed)
            if from_obj:
                return from_obj
        except Exception:
            pass
        return ""

    async def _download_video(self, video_url: str) -> str:
        """Stream ``video_url`` into ``videos/`` and return the absolute file path.

        Raises httpx errors (e.g. a non-2xx status); a partially written file
        is removed before the error propagates.
        """
        videos_dir = Path(__file__).parent / "videos"
        videos_dir.mkdir(exist_ok=True)
        filename = f"grok_{datetime.now():%Y%m%d_%H%M%S}_{uuid.uuid4().hex[:8]}.mp4"
        file_path = videos_dir / filename
        timeout = httpx.Timeout(connect=10.0, read=300.0, write=10.0, pool=10.0)
        proxy = self._build_proxy_url()
        if proxy:
            logger.debug(f"下载视频使用代理: {proxy}")
        try:
            async with httpx.AsyncClient(timeout=timeout, proxy=proxy, follow_redirects=True) as client:
                # Stream to disk instead of buffering the whole video in memory.
                async with client.stream("GET", video_url) as response:
                    response.raise_for_status()
                    with open(file_path, "wb") as f:
                        async for chunk in response.aiter_bytes():
                            f.write(chunk)
        except BaseException:
            file_path.unlink(missing_ok=True)
            raise
        return str(file_path.resolve())