diff --git a/scripts/generate_dota2_douyin_images.py b/scripts/generate_dota2_douyin_images.py index 86103c5..f6b65f2 100644 --- a/scripts/generate_dota2_douyin_images.py +++ b/scripts/generate_dota2_douyin_images.py @@ -15,13 +15,15 @@ from __future__ import annotations import argparse import base64 +from concurrent.futures import ThreadPoolExecutor, as_completed import json import os import re import sys +import threading import time from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple import requests import yaml @@ -39,6 +41,9 @@ DEFAULT_OUTPUT_DIR = Path("temp") / "dota2_douyin_images" # 这里固定每个英雄的 4 张图的语言分布,避免每次运行时还要手动指定。 DEFAULT_LANGUAGE_VARIANTS = ["zh", "zh", "ja", "ja"] +# 这里把默认并发数固定为 4,满足你“开 4 个线程跑”的诉求。 +DEFAULT_MAX_WORKERS = 4 + def parse_args() -> argparse.Namespace: """解析命令行参数。""" @@ -105,6 +110,12 @@ def parse_args() -> argparse.Namespace: default="", help="只生成英雄名中包含该关键字的英雄,便于单独补图。", ) + parser.add_argument( + "--max-workers", + type=int, + default=DEFAULT_MAX_WORKERS, + help="并发线程数,默认 4。", + ) parser.add_argument( "--force", action="store_true", @@ -246,7 +257,7 @@ def build_consistent_prompt(hero: Dict[str, str], image_index: int) -> str: 请为短视频封面创作一张高完成度竖版插画,主体是 Dota2 英雄 {hero_name_cn}({hero_name_en})。 核心要求: -1. 角色设定明确为 Dota2 的骷髅王风格体系下的“至宝级华丽皮肤质感”,但角色身份必须是 {hero_name_cn} 本人,不要画成别的英雄。 +1. 角色设定明确为 Dota2 的风格体系下的“至宝级华丽皮肤质感”,但角色身份必须是 {hero_name_cn} 本人,不要画成别的英雄。 2. 画面整体要强烈偏向 JOJO 气质:夸张肌肉与体块、强烈明暗对比、戏剧化姿势、锐利线条、张力十足的漫画分镜感、厚重阴影、速度线、压迫感构图。 3. 需要比普通日漫更偏 JOJO 风,风格统一、成熟、硬朗、华丽,视觉冲击力强。 4. 画面左下角固定放一个“能力雷达图”,用日式游戏 UI 风格表现,半透明发光面板,结构清晰。 @@ -346,6 +357,160 @@ def ensure_output_dir(output_dir: Path) -> None: output_dir.mkdir(parents=True, exist_ok=True) +def build_generation_tasks( + heroes: List[Dict[str, str]], + output_dir: Path, + count_per_hero: int, +) -> List[Dict[str, Any]]: + """ + 预先展开所有生图任务。 + + 这样做的目的: + 1. 先把“英雄 x 第几张图”拍平成统一任务列表,便于线程池直接消费; + 2. 任务对象中提前算好输出目录、文件名、提示词,线程里只负责执行; + 3. 任务顺序保持稳定,后续日志更容易排查。 + """ + tasks: List[Dict[str, Any]] = [] + total_heroes = len(heroes) + + for hero_index, hero in enumerate(heroes, start=1): + hero_slug = sanitize_filename(hero["english_name"].lower().replace(" ", "_")) + hero_dir = output_dir / f"{hero_slug}_{sanitize_filename(hero['localized_name'])}" + ensure_output_dir(hero_dir) + + for image_index in range(1, count_per_hero + 1): + file_name = f"{hero_slug}_{image_index:02d}.png" + image_path = hero_dir / file_name + tasks.append( + { + "hero": hero, + "hero_index": hero_index, + "total_heroes": total_heroes, + "hero_slug": hero_slug, + "hero_dir": hero_dir, + "image_index": image_index, + "image_path": image_path, + "prompt": build_consistent_prompt(hero, image_index), + } + ) + + return tasks + + +def run_single_generation_task( + task: Dict[str, Any], + request_url: str, + api_key: str, + model: str, + image_size: str, + image_quality: str, + timeout_seconds: int, + max_retries: int, + delay_seconds: float, + force: bool, + manifest_path: Path, + manifest_lock: threading.Lock, + print_lock: threading.Lock, +) -> Tuple[str, Dict[str, Any]]: + """ + 在线程池中执行单个图片生成任务。 + + 返回值约定: + 1. status 为 success / skipped / failed 三种之一; + 2. payload 会带上日志和清单记录所需的数据,主线程只负责汇总结果; + 3. manifest 写入放在线程内完成,但通过锁保证同一时刻只有一个线程落盘。 + """ + hero = task["hero"] + image_index = task["image_index"] + image_path: Path = task["image_path"] + prompt = task["prompt"] + + with print_lock: + print( + f"\n[{task['hero_index']}/{task['total_heroes']}] " + f"处理英雄: {hero['localized_name']} ({hero['english_name']}) " + f"- 第 {image_index} 张" + ) + + if image_path.exists() and not force: + with print_lock: + print(f" - 已存在,跳过: {image_path.name}") + return "skipped", { + "hero_id": hero["hero_id"], + "localized_name": hero["localized_name"], + "english_name": hero["english_name"], + "image_index": image_index, + "image_path": str(image_path.as_posix()), + } + + last_error: Optional[str] = None + for retry_index in range(1, max_retries + 1): + try: + with print_lock: + print(f" - 生成第 {image_index} 张,尝试 {retry_index}/{max_retries}") + + image_bytes = generate_one_image( + request_url=request_url, + api_key=api_key, + model=model, + prompt=prompt, + image_size=image_size, + image_quality=image_quality, + timeout_seconds=timeout_seconds, + ) + + with image_path.open("wb") as file_obj: + file_obj.write(image_bytes) + + manifest_row = { + "hero_id": hero["hero_id"], + "localized_name": hero["localized_name"], + "english_name": hero["english_name"], + "image_index": image_index, + "image_path": str(image_path.as_posix()), + "size": image_size, + "quality": image_quality, + "model": model, + "request_url": request_url, + "generated_at": time.strftime("%Y-%m-%d %H:%M:%S"), + "prompt": prompt, + } + + # 这里用锁保护清单写入,避免多个线程同时写 JSONL 时内容互相穿插。 + with manifest_lock: + append_manifest_row(manifest_path, manifest_row) + + with print_lock: + print(f" - 生成成功: {image_path.name}") + + time.sleep(delay_seconds) + return "success", manifest_row + except Exception as exc: + last_error = str(exc) + with print_lock: + print(f" - 生成失败: {last_error}") + if retry_index < max_retries: + # 这里做一个简短退避,降低临时网络波动或网关限流的影响。 + time.sleep(min(5, retry_index * 2)) + + failed_row = { + "hero_id": hero["hero_id"], + "localized_name": hero["localized_name"], + "english_name": hero["english_name"], + "image_index": image_index, + "image_path": str(image_path.as_posix()), + "size": image_size, + "quality": image_quality, + "model": model, + "request_url": request_url, + "generated_at": time.strftime("%Y-%m-%d %H:%M:%S"), + "error": last_error or "未知错误", + } + with manifest_lock: + append_manifest_row(manifest_path, failed_row) + return "failed", failed_row + + def main() -> int: """脚本入口。""" args = parse_args() @@ -385,93 +550,49 @@ def main() -> int: print(f"共准备生成 {len(heroes)} 个英雄,每个英雄 {args.count_per_hero} 张。") print(f"图片接口: {request_url}") print(f"输出目录: {output_dir.resolve()}") + print(f"并发线程数: {args.max_workers}") total_success = 0 total_skipped = 0 total_failed = 0 + manifest_lock = threading.Lock() + print_lock = threading.Lock() + tasks = build_generation_tasks( + heroes=heroes, + output_dir=output_dir, + count_per_hero=args.count_per_hero, + ) - for hero_index, hero in enumerate(heroes, start=1): - hero_slug = sanitize_filename(hero["english_name"].lower().replace(" ", "_")) - hero_dir = output_dir / f"{hero_slug}_{sanitize_filename(hero['localized_name'])}" - ensure_output_dir(hero_dir) + # 这里将所有任务交给线程池统一调度,让脚本能够同时发起 4 个图片请求。 + with ThreadPoolExecutor(max_workers=max(1, int(args.max_workers))) as executor: + future_to_task = { + executor.submit( + run_single_generation_task, + task, + request_url, + backend["api_key"], + backend["model"], + args.size, + args.quality, + timeout_seconds, + args.max_retries, + args.delay, + args.force, + manifest_path, + manifest_lock, + print_lock, + ): task + for task in tasks + } - print(f"\n[{hero_index}/{len(heroes)}] 开始处理英雄: {hero['localized_name']} ({hero['english_name']})") - - for image_index in range(1, args.count_per_hero + 1): - file_name = f"{hero_slug}_{image_index:02d}.png" - image_path = hero_dir / file_name - - if image_path.exists() and not args.force: + for future in as_completed(future_to_task): + status, _ = future.result() + if status == "success": + total_success += 1 + elif status == "skipped": total_skipped += 1 - print(f" - 已存在,跳过: {image_path.name}") - continue - - prompt = build_consistent_prompt(hero, image_index) - last_error: Optional[str] = None - - for retry_index in range(1, args.max_retries + 1): - try: - print(f" - 生成第 {image_index} 张,尝试 {retry_index}/{args.max_retries}") - image_bytes = generate_one_image( - request_url=request_url, - api_key=backend["api_key"], - model=backend["model"], - prompt=prompt, - image_size=args.size, - image_quality=args.quality, - timeout_seconds=timeout_seconds, - ) - - with image_path.open("wb") as file_obj: - file_obj.write(image_bytes) - - append_manifest_row( - manifest_path, - { - "hero_id": hero["hero_id"], - "localized_name": hero["localized_name"], - "english_name": hero["english_name"], - "image_index": image_index, - "image_path": str(image_path.as_posix()), - "size": args.size, - "quality": args.quality, - "model": backend["model"], - "request_url": request_url, - "generated_at": time.strftime("%Y-%m-%d %H:%M:%S"), - "prompt": prompt, - }, - ) - - total_success += 1 - print(f" - 生成成功: {image_path.name}") - time.sleep(args.delay) - break - except Exception as exc: - last_error = str(exc) - print(f" - 生成失败: {last_error}") - if retry_index < args.max_retries: - # 这里做一个简短退避,降低临时网络波动或网关限流的影响。 - time.sleep(min(5, retry_index * 2)) - else: - total_failed += 1 - - if last_error and (not image_path.exists()): - append_manifest_row( - manifest_path, - { - "hero_id": hero["hero_id"], - "localized_name": hero["localized_name"], - "english_name": hero["english_name"], - "image_index": image_index, - "image_path": str(image_path.as_posix()), - "size": args.size, - "quality": args.quality, - "model": backend["model"], - "request_url": request_url, - "generated_at": time.strftime("%Y-%m-%d %H:%M:%S"), - "error": last_error, - }, - ) + else: + total_failed += 1 print("\n生成完成。") print(f"成功: {total_success}")