为Dota2批量生图脚本增加4线程并发生成\n\n- 引入线程池并将图片生成任务改为并发执行\n- 新增并发线程参数,默认使用4个线程提升批量生成效率\n- 为日志输出与清单写入增加线程锁,避免并发场景下内容冲突
This commit is contained in:
@@ -15,13 +15,15 @@ from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import base64
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import requests
|
||||
import yaml
|
||||
@@ -39,6 +41,9 @@ DEFAULT_OUTPUT_DIR = Path("temp") / "dota2_douyin_images"
|
||||
# 这里固定每个英雄的 4 张图的语言分布,避免每次运行时还要手动指定。
|
||||
DEFAULT_LANGUAGE_VARIANTS = ["zh", "zh", "ja", "ja"]
|
||||
|
||||
# 这里把默认并发数固定为 4,满足你“开 4 个线程跑”的诉求。
|
||||
DEFAULT_MAX_WORKERS = 4
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
"""解析命令行参数。"""
|
||||
@@ -105,6 +110,12 @@ def parse_args() -> argparse.Namespace:
|
||||
default="",
|
||||
help="只生成英雄名中包含该关键字的英雄,便于单独补图。",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-workers",
|
||||
type=int,
|
||||
default=DEFAULT_MAX_WORKERS,
|
||||
help="并发线程数,默认 4。",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--force",
|
||||
action="store_true",
|
||||
@@ -246,7 +257,7 @@ def build_consistent_prompt(hero: Dict[str, str], image_index: int) -> str:
|
||||
请为短视频封面创作一张高完成度竖版插画,主体是 Dota2 英雄 {hero_name_cn}({hero_name_en})。
|
||||
|
||||
核心要求:
|
||||
1. 角色设定明确为 Dota2 的骷髅王风格体系下的“至宝级华丽皮肤质感”,但角色身份必须是 {hero_name_cn} 本人,不要画成别的英雄。
|
||||
1. 角色设定明确为 Dota2 的风格体系下的“至宝级华丽皮肤质感”,但角色身份必须是 {hero_name_cn} 本人,不要画成别的英雄。
|
||||
2. 画面整体要强烈偏向 JOJO 气质:夸张肌肉与体块、强烈明暗对比、戏剧化姿势、锐利线条、张力十足的漫画分镜感、厚重阴影、速度线、压迫感构图。
|
||||
3. 需要比普通日漫更偏 JOJO 风,风格统一、成熟、硬朗、华丽,视觉冲击力强。
|
||||
4. 画面左下角固定放一个“能力雷达图”,用日式游戏 UI 风格表现,半透明发光面板,结构清晰。
|
||||
@@ -346,6 +357,160 @@ def ensure_output_dir(output_dir: Path) -> None:
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def build_generation_tasks(
|
||||
heroes: List[Dict[str, str]],
|
||||
output_dir: Path,
|
||||
count_per_hero: int,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
预先展开所有生图任务。
|
||||
|
||||
这样做的目的:
|
||||
1. 先把“英雄 x 第几张图”拍平成统一任务列表,便于线程池直接消费;
|
||||
2. 任务对象中提前算好输出目录、文件名、提示词,线程里只负责执行;
|
||||
3. 任务顺序保持稳定,后续日志更容易排查。
|
||||
"""
|
||||
tasks: List[Dict[str, Any]] = []
|
||||
total_heroes = len(heroes)
|
||||
|
||||
for hero_index, hero in enumerate(heroes, start=1):
|
||||
hero_slug = sanitize_filename(hero["english_name"].lower().replace(" ", "_"))
|
||||
hero_dir = output_dir / f"{hero_slug}_{sanitize_filename(hero['localized_name'])}"
|
||||
ensure_output_dir(hero_dir)
|
||||
|
||||
for image_index in range(1, count_per_hero + 1):
|
||||
file_name = f"{hero_slug}_{image_index:02d}.png"
|
||||
image_path = hero_dir / file_name
|
||||
tasks.append(
|
||||
{
|
||||
"hero": hero,
|
||||
"hero_index": hero_index,
|
||||
"total_heroes": total_heroes,
|
||||
"hero_slug": hero_slug,
|
||||
"hero_dir": hero_dir,
|
||||
"image_index": image_index,
|
||||
"image_path": image_path,
|
||||
"prompt": build_consistent_prompt(hero, image_index),
|
||||
}
|
||||
)
|
||||
|
||||
return tasks
|
||||
|
||||
|
||||
def run_single_generation_task(
|
||||
task: Dict[str, Any],
|
||||
request_url: str,
|
||||
api_key: str,
|
||||
model: str,
|
||||
image_size: str,
|
||||
image_quality: str,
|
||||
timeout_seconds: int,
|
||||
max_retries: int,
|
||||
delay_seconds: float,
|
||||
force: bool,
|
||||
manifest_path: Path,
|
||||
manifest_lock: threading.Lock,
|
||||
print_lock: threading.Lock,
|
||||
) -> Tuple[str, Dict[str, Any]]:
|
||||
"""
|
||||
在线程池中执行单个图片生成任务。
|
||||
|
||||
返回值约定:
|
||||
1. status 为 success / skipped / failed 三种之一;
|
||||
2. payload 会带上日志和清单记录所需的数据,主线程只负责汇总结果;
|
||||
3. manifest 写入放在线程内完成,但通过锁保证同一时刻只有一个线程落盘。
|
||||
"""
|
||||
hero = task["hero"]
|
||||
image_index = task["image_index"]
|
||||
image_path: Path = task["image_path"]
|
||||
prompt = task["prompt"]
|
||||
|
||||
with print_lock:
|
||||
print(
|
||||
f"\n[{task['hero_index']}/{task['total_heroes']}] "
|
||||
f"处理英雄: {hero['localized_name']} ({hero['english_name']}) "
|
||||
f"- 第 {image_index} 张"
|
||||
)
|
||||
|
||||
if image_path.exists() and not force:
|
||||
with print_lock:
|
||||
print(f" - 已存在,跳过: {image_path.name}")
|
||||
return "skipped", {
|
||||
"hero_id": hero["hero_id"],
|
||||
"localized_name": hero["localized_name"],
|
||||
"english_name": hero["english_name"],
|
||||
"image_index": image_index,
|
||||
"image_path": str(image_path.as_posix()),
|
||||
}
|
||||
|
||||
last_error: Optional[str] = None
|
||||
for retry_index in range(1, max_retries + 1):
|
||||
try:
|
||||
with print_lock:
|
||||
print(f" - 生成第 {image_index} 张,尝试 {retry_index}/{max_retries}")
|
||||
|
||||
image_bytes = generate_one_image(
|
||||
request_url=request_url,
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
image_size=image_size,
|
||||
image_quality=image_quality,
|
||||
timeout_seconds=timeout_seconds,
|
||||
)
|
||||
|
||||
with image_path.open("wb") as file_obj:
|
||||
file_obj.write(image_bytes)
|
||||
|
||||
manifest_row = {
|
||||
"hero_id": hero["hero_id"],
|
||||
"localized_name": hero["localized_name"],
|
||||
"english_name": hero["english_name"],
|
||||
"image_index": image_index,
|
||||
"image_path": str(image_path.as_posix()),
|
||||
"size": image_size,
|
||||
"quality": image_quality,
|
||||
"model": model,
|
||||
"request_url": request_url,
|
||||
"generated_at": time.strftime("%Y-%m-%d %H:%M:%S"),
|
||||
"prompt": prompt,
|
||||
}
|
||||
|
||||
# 这里用锁保护清单写入,避免多个线程同时写 JSONL 时内容互相穿插。
|
||||
with manifest_lock:
|
||||
append_manifest_row(manifest_path, manifest_row)
|
||||
|
||||
with print_lock:
|
||||
print(f" - 生成成功: {image_path.name}")
|
||||
|
||||
time.sleep(delay_seconds)
|
||||
return "success", manifest_row
|
||||
except Exception as exc:
|
||||
last_error = str(exc)
|
||||
with print_lock:
|
||||
print(f" - 生成失败: {last_error}")
|
||||
if retry_index < max_retries:
|
||||
# 这里做一个简短退避,降低临时网络波动或网关限流的影响。
|
||||
time.sleep(min(5, retry_index * 2))
|
||||
|
||||
failed_row = {
|
||||
"hero_id": hero["hero_id"],
|
||||
"localized_name": hero["localized_name"],
|
||||
"english_name": hero["english_name"],
|
||||
"image_index": image_index,
|
||||
"image_path": str(image_path.as_posix()),
|
||||
"size": image_size,
|
||||
"quality": image_quality,
|
||||
"model": model,
|
||||
"request_url": request_url,
|
||||
"generated_at": time.strftime("%Y-%m-%d %H:%M:%S"),
|
||||
"error": last_error or "未知错误",
|
||||
}
|
||||
with manifest_lock:
|
||||
append_manifest_row(manifest_path, failed_row)
|
||||
return "failed", failed_row
|
||||
|
||||
|
||||
def main() -> int:
|
||||
"""脚本入口。"""
|
||||
args = parse_args()
|
||||
@@ -385,93 +550,49 @@ def main() -> int:
|
||||
print(f"共准备生成 {len(heroes)} 个英雄,每个英雄 {args.count_per_hero} 张。")
|
||||
print(f"图片接口: {request_url}")
|
||||
print(f"输出目录: {output_dir.resolve()}")
|
||||
print(f"并发线程数: {args.max_workers}")
|
||||
|
||||
total_success = 0
|
||||
total_skipped = 0
|
||||
total_failed = 0
|
||||
manifest_lock = threading.Lock()
|
||||
print_lock = threading.Lock()
|
||||
tasks = build_generation_tasks(
|
||||
heroes=heroes,
|
||||
output_dir=output_dir,
|
||||
count_per_hero=args.count_per_hero,
|
||||
)
|
||||
|
||||
for hero_index, hero in enumerate(heroes, start=1):
|
||||
hero_slug = sanitize_filename(hero["english_name"].lower().replace(" ", "_"))
|
||||
hero_dir = output_dir / f"{hero_slug}_{sanitize_filename(hero['localized_name'])}"
|
||||
ensure_output_dir(hero_dir)
|
||||
# 这里将所有任务交给线程池统一调度,让脚本能够同时发起 4 个图片请求。
|
||||
with ThreadPoolExecutor(max_workers=max(1, int(args.max_workers))) as executor:
|
||||
future_to_task = {
|
||||
executor.submit(
|
||||
run_single_generation_task,
|
||||
task,
|
||||
request_url,
|
||||
backend["api_key"],
|
||||
backend["model"],
|
||||
args.size,
|
||||
args.quality,
|
||||
timeout_seconds,
|
||||
args.max_retries,
|
||||
args.delay,
|
||||
args.force,
|
||||
manifest_path,
|
||||
manifest_lock,
|
||||
print_lock,
|
||||
): task
|
||||
for task in tasks
|
||||
}
|
||||
|
||||
print(f"\n[{hero_index}/{len(heroes)}] 开始处理英雄: {hero['localized_name']} ({hero['english_name']})")
|
||||
|
||||
for image_index in range(1, args.count_per_hero + 1):
|
||||
file_name = f"{hero_slug}_{image_index:02d}.png"
|
||||
image_path = hero_dir / file_name
|
||||
|
||||
if image_path.exists() and not args.force:
|
||||
for future in as_completed(future_to_task):
|
||||
status, _ = future.result()
|
||||
if status == "success":
|
||||
total_success += 1
|
||||
elif status == "skipped":
|
||||
total_skipped += 1
|
||||
print(f" - 已存在,跳过: {image_path.name}")
|
||||
continue
|
||||
|
||||
prompt = build_consistent_prompt(hero, image_index)
|
||||
last_error: Optional[str] = None
|
||||
|
||||
for retry_index in range(1, args.max_retries + 1):
|
||||
try:
|
||||
print(f" - 生成第 {image_index} 张,尝试 {retry_index}/{args.max_retries}")
|
||||
image_bytes = generate_one_image(
|
||||
request_url=request_url,
|
||||
api_key=backend["api_key"],
|
||||
model=backend["model"],
|
||||
prompt=prompt,
|
||||
image_size=args.size,
|
||||
image_quality=args.quality,
|
||||
timeout_seconds=timeout_seconds,
|
||||
)
|
||||
|
||||
with image_path.open("wb") as file_obj:
|
||||
file_obj.write(image_bytes)
|
||||
|
||||
append_manifest_row(
|
||||
manifest_path,
|
||||
{
|
||||
"hero_id": hero["hero_id"],
|
||||
"localized_name": hero["localized_name"],
|
||||
"english_name": hero["english_name"],
|
||||
"image_index": image_index,
|
||||
"image_path": str(image_path.as_posix()),
|
||||
"size": args.size,
|
||||
"quality": args.quality,
|
||||
"model": backend["model"],
|
||||
"request_url": request_url,
|
||||
"generated_at": time.strftime("%Y-%m-%d %H:%M:%S"),
|
||||
"prompt": prompt,
|
||||
},
|
||||
)
|
||||
|
||||
total_success += 1
|
||||
print(f" - 生成成功: {image_path.name}")
|
||||
time.sleep(args.delay)
|
||||
break
|
||||
except Exception as exc:
|
||||
last_error = str(exc)
|
||||
print(f" - 生成失败: {last_error}")
|
||||
if retry_index < args.max_retries:
|
||||
# 这里做一个简短退避,降低临时网络波动或网关限流的影响。
|
||||
time.sleep(min(5, retry_index * 2))
|
||||
else:
|
||||
total_failed += 1
|
||||
|
||||
if last_error and (not image_path.exists()):
|
||||
append_manifest_row(
|
||||
manifest_path,
|
||||
{
|
||||
"hero_id": hero["hero_id"],
|
||||
"localized_name": hero["localized_name"],
|
||||
"english_name": hero["english_name"],
|
||||
"image_index": image_index,
|
||||
"image_path": str(image_path.as_posix()),
|
||||
"size": args.size,
|
||||
"quality": args.quality,
|
||||
"model": backend["model"],
|
||||
"request_url": request_url,
|
||||
"generated_at": time.strftime("%Y-%m-%d %H:%M:%S"),
|
||||
"error": last_error,
|
||||
},
|
||||
)
|
||||
else:
|
||||
total_failed += 1
|
||||
|
||||
print("\n生成完成。")
|
||||
print(f"成功: {total_success}")
|
||||
|
||||
Reference in New Issue
Block a user