为Dota2批量生图脚本增加4线程并发生成

- 引入线程池并将图片生成任务改为并发执行
- 新增并发线程参数,默认使用4个线程提升批量生成效率
- 为日志输出与清单写入增加线程锁,避免并发场景下内容冲突

This commit is contained in:
liuwei
2026-04-29 16:23:59 +08:00
parent 77a49dfb45
commit 28dc9da852

View File

@@ -15,13 +15,15 @@ from __future__ import annotations
import argparse
import base64
from concurrent.futures import ThreadPoolExecutor, as_completed
import json
import os
import re
import sys
import threading
import time
from pathlib import Path
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple
import requests
import yaml
@@ -39,6 +41,9 @@ DEFAULT_OUTPUT_DIR = Path("temp") / "dota2_douyin_images"
# Fix the language distribution of the 4 images generated per hero here, so it
# does not have to be specified manually on every run.
DEFAULT_LANGUAGE_VARIANTS = ["zh", "zh", "ja", "ja"]
# The default concurrency is fixed at 4 to satisfy the request to "run with
# 4 threads".
DEFAULT_MAX_WORKERS = 4
def parse_args() -> argparse.Namespace:
"""解析命令行参数。"""
@@ -105,6 +110,12 @@ def parse_args() -> argparse.Namespace:
default="",
help="只生成英雄名中包含该关键字的英雄,便于单独补图。",
)
parser.add_argument(
"--max-workers",
type=int,
default=DEFAULT_MAX_WORKERS,
help="并发线程数,默认 4。",
)
parser.add_argument(
"--force",
action="store_true",
@@ -246,7 +257,7 @@ def build_consistent_prompt(hero: Dict[str, str], image_index: int) -> str:
请为短视频封面创作一张高完成度竖版插画,主体是 Dota2 英雄 {hero_name_cn}{hero_name_en})。
核心要求:
1. 角色设定明确为 Dota2 的骷髅王风格体系下的“至宝级华丽皮肤质感”,但角色身份必须是 {hero_name_cn} 本人,不要画成别的英雄。
1. 角色设定明确为 Dota2 的风格体系下的“至宝级华丽皮肤质感”,但角色身份必须是 {hero_name_cn} 本人,不要画成别的英雄。
2. 画面整体要强烈偏向 JOJO 气质:夸张肌肉与体块、强烈明暗对比、戏剧化姿势、锐利线条、张力十足的漫画分镜感、厚重阴影、速度线、压迫感构图。
3. 需要比普通日漫更偏 JOJO 风,风格统一、成熟、硬朗、华丽,视觉冲击力强。
4. 画面左下角固定放一个“能力雷达图”,用日式游戏 UI 风格表现,半透明发光面板,结构清晰。
@@ -346,6 +357,160 @@ def ensure_output_dir(output_dir: Path) -> None:
output_dir.mkdir(parents=True, exist_ok=True)
def build_generation_tasks(
    heroes: List[Dict[str, str]],
    output_dir: Path,
    count_per_hero: int,
) -> List[Dict[str, Any]]:
    """
    Expand every (hero, image number) pair into a flat list of task dicts.

    Why this is done up front:
    1. A flat task list can be consumed directly by a thread pool.
    2. Output directory, file path and prompt are computed here, so a worker
       only has to execute the task.
    3. Task order is stable (heroes in input order, image numbers ascending),
       which keeps later logs easier to follow.

    Side effect: the per-hero output directory is created via
    ``ensure_output_dir`` while the list is being built.

    Args:
        heroes: Hero records; ``english_name`` and ``localized_name`` are read
            here, and the whole record is passed to ``build_consistent_prompt``.
        output_dir: Root directory under which per-hero directories live.
        count_per_hero: Number of images (tasks) to create for each hero.

    Returns:
        One dict per image with the keys ``hero``, ``hero_index``,
        ``total_heroes``, ``hero_slug``, ``hero_dir``, ``image_index``,
        ``image_path`` and ``prompt``.
    """
    hero_total = len(heroes)
    pending: List[Dict[str, Any]] = []

    for position, hero in enumerate(heroes, 1):
        english_part = hero["english_name"].lower().replace(" ", "_")
        slug = sanitize_filename(english_part)
        localized_part = sanitize_filename(hero["localized_name"])
        target_dir = output_dir / f"{slug}_{localized_part}"
        ensure_output_dir(target_dir)

        # One task per image number; numbering is 1-based and zero-padded to
        # two digits in the file name.
        pending.extend(
            {
                "hero": hero,
                "hero_index": position,
                "total_heroes": hero_total,
                "hero_slug": slug,
                "hero_dir": target_dir,
                "image_index": number,
                "image_path": target_dir / f"{slug}_{number:02d}.png",
                "prompt": build_consistent_prompt(hero, number),
            }
            for number in range(1, count_per_hero + 1)
        )

    return pending
def _write_bytes_atomically(target_path: Path, data: bytes) -> None:
    """
    Write ``data`` to ``target_path`` through a sibling ``.part`` file.

    The bytes are first written to ``<name>.part`` in the same directory and
    then renamed over the final path, so the final path never holds a
    partially written file. If anything fails, the temporary file is removed
    (best effort) and the original exception is re-raised.
    """
    temp_path = target_path.with_name(f"{target_path.name}.part")
    try:
        with temp_path.open("wb") as file_obj:
            file_obj.write(data)
        temp_path.replace(target_path)
    except BaseException:
        # Cleanup only: a failure to delete the leftover must not mask the
        # real error, which is re-raised below.
        try:
            temp_path.unlink()
        except OSError:
            pass
        raise


def run_single_generation_task(
    task: Dict[str, Any],
    request_url: str,
    api_key: str,
    model: str,
    image_size: str,
    image_quality: str,
    timeout_seconds: int,
    max_retries: int,
    delay_seconds: float,
    force: bool,
    manifest_path: Path,
    manifest_lock: threading.Lock,
    print_lock: threading.Lock,
) -> Tuple[str, Dict[str, Any]]:
    """
    Execute one image generation task; intended to run inside a thread pool.

    Args:
        task: Task dict; the keys ``hero``, ``hero_index``, ``total_heroes``,
            ``image_index``, ``image_path`` and ``prompt`` are read.
        request_url: Image endpoint, forwarded to ``generate_one_image``.
        api_key: Credential forwarded to ``generate_one_image``.
        model: Model name forwarded to ``generate_one_image``.
        image_size: Size value forwarded to ``generate_one_image``.
        image_quality: Quality value forwarded to ``generate_one_image``.
        timeout_seconds: Timeout forwarded to ``generate_one_image``.
        max_retries: Maximum number of generation attempts.
        delay_seconds: Pause after a successful generation; values that are
            zero or negative mean "no pause".
        force: When true, regenerate even if the image file already exists.
        manifest_path: Manifest file passed to ``append_manifest_row``.
        manifest_lock: Held while appending to the manifest.
        print_lock: Held around every ``print`` call.

    Returns:
        ``(status, payload)`` where status is one of ``"success"``,
        ``"skipped"`` or ``"failed"``. The payload carries the data needed for
        logging / the manifest; the caller only has to aggregate results.
        Manifest rows are written here, under ``manifest_lock``, so only one
        thread appends at a time.
    """
    hero = task["hero"]
    image_index = task["image_index"]
    image_path: Path = task["image_path"]
    prompt = task["prompt"]

    with print_lock:
        print(
            f"\n[{task['hero_index']}/{task['total_heroes']}] "
            f"处理英雄: {hero['localized_name']} ({hero['english_name']}) "
            f"- 第 {image_index}"
        )

    if image_path.exists() and not force:
        with print_lock:
            print(f" - 已存在,跳过: {image_path.name}")
        return "skipped", {
            "hero_id": hero["hero_id"],
            "localized_name": hero["localized_name"],
            "english_name": hero["english_name"],
            "image_index": image_index,
            "image_path": str(image_path.as_posix()),
        }

    last_error: Optional[str] = None
    for retry_index in range(1, max_retries + 1):
        try:
            with print_lock:
                print(f" - 生成第 {image_index} 张,尝试 {retry_index}/{max_retries}")

            image_bytes = generate_one_image(
                request_url=request_url,
                api_key=api_key,
                model=model,
                prompt=prompt,
                image_size=image_size,
                image_quality=image_quality,
                timeout_seconds=timeout_seconds,
            )

            # Write via a temp file + rename: an interrupted write must not
            # leave a truncated image at the final path, because a later run
            # would treat it as "already exists" and skip it.
            _write_bytes_atomically(image_path, image_bytes)

            manifest_row = {
                "hero_id": hero["hero_id"],
                "localized_name": hero["localized_name"],
                "english_name": hero["english_name"],
                "image_index": image_index,
                "image_path": str(image_path.as_posix()),
                "size": image_size,
                "quality": image_quality,
                "model": model,
                "request_url": request_url,
                "generated_at": time.strftime("%Y-%m-%d %H:%M:%S"),
                "prompt": prompt,
            }

            # Guard the manifest append with a lock so rows written by
            # different threads do not interleave.
            with manifest_lock:
                append_manifest_row(manifest_path, manifest_row)

            with print_lock:
                print(f" - 生成成功: {image_path.name}")

            # time.sleep() raises ValueError for negative values; inside this
            # try block that would turn a successful generation into a retry.
            if delay_seconds > 0:
                time.sleep(delay_seconds)
            return "success", manifest_row
        except Exception as exc:
            last_error = str(exc)
            with print_lock:
                print(f" - 生成失败: {last_error}")
            if retry_index < max_retries:
                # Short back-off to soften transient network hiccups or
                # gateway rate limiting.
                time.sleep(min(5, retry_index * 2))

    failed_row = {
        "hero_id": hero["hero_id"],
        "localized_name": hero["localized_name"],
        "english_name": hero["english_name"],
        "image_index": image_index,
        "image_path": str(image_path.as_posix()),
        "size": image_size,
        "quality": image_quality,
        "model": model,
        "request_url": request_url,
        "generated_at": time.strftime("%Y-%m-%d %H:%M:%S"),
        "error": last_error or "未知错误",
    }
    with manifest_lock:
        append_manifest_row(manifest_path, failed_row)
    return "failed", failed_row
def main() -> int:
"""脚本入口。"""
args = parse_args()
@@ -385,93 +550,49 @@ def main() -> int:
print(f"共准备生成 {len(heroes)} 个英雄,每个英雄 {args.count_per_hero} 张。")
print(f"图片接口: {request_url}")
print(f"输出目录: {output_dir.resolve()}")
print(f"并发线程数: {args.max_workers}")
total_success = 0
total_skipped = 0
total_failed = 0
manifest_lock = threading.Lock()
print_lock = threading.Lock()
tasks = build_generation_tasks(
heroes=heroes,
output_dir=output_dir,
count_per_hero=args.count_per_hero,
)
for hero_index, hero in enumerate(heroes, start=1):
hero_slug = sanitize_filename(hero["english_name"].lower().replace(" ", "_"))
hero_dir = output_dir / f"{hero_slug}_{sanitize_filename(hero['localized_name'])}"
ensure_output_dir(hero_dir)
# 这里将所有任务交给线程池统一调度,让脚本能够同时发起 4 个图片请求。
with ThreadPoolExecutor(max_workers=max(1, int(args.max_workers))) as executor:
future_to_task = {
executor.submit(
run_single_generation_task,
task,
request_url,
backend["api_key"],
backend["model"],
args.size,
args.quality,
timeout_seconds,
args.max_retries,
args.delay,
args.force,
manifest_path,
manifest_lock,
print_lock,
): task
for task in tasks
}
print(f"\n[{hero_index}/{len(heroes)}] 开始处理英雄: {hero['localized_name']} ({hero['english_name']})")
for image_index in range(1, args.count_per_hero + 1):
file_name = f"{hero_slug}_{image_index:02d}.png"
image_path = hero_dir / file_name
if image_path.exists() and not args.force:
for future in as_completed(future_to_task):
status, _ = future.result()
if status == "success":
total_success += 1
elif status == "skipped":
total_skipped += 1
print(f" - 已存在,跳过: {image_path.name}")
continue
prompt = build_consistent_prompt(hero, image_index)
last_error: Optional[str] = None
for retry_index in range(1, args.max_retries + 1):
try:
print(f" - 生成第 {image_index} 张,尝试 {retry_index}/{args.max_retries}")
image_bytes = generate_one_image(
request_url=request_url,
api_key=backend["api_key"],
model=backend["model"],
prompt=prompt,
image_size=args.size,
image_quality=args.quality,
timeout_seconds=timeout_seconds,
)
with image_path.open("wb") as file_obj:
file_obj.write(image_bytes)
append_manifest_row(
manifest_path,
{
"hero_id": hero["hero_id"],
"localized_name": hero["localized_name"],
"english_name": hero["english_name"],
"image_index": image_index,
"image_path": str(image_path.as_posix()),
"size": args.size,
"quality": args.quality,
"model": backend["model"],
"request_url": request_url,
"generated_at": time.strftime("%Y-%m-%d %H:%M:%S"),
"prompt": prompt,
},
)
total_success += 1
print(f" - 生成成功: {image_path.name}")
time.sleep(args.delay)
break
except Exception as exc:
last_error = str(exc)
print(f" - 生成失败: {last_error}")
if retry_index < args.max_retries:
# 这里做一个简短退避,降低临时网络波动或网关限流的影响。
time.sleep(min(5, retry_index * 2))
else:
total_failed += 1
if last_error and (not image_path.exists()):
append_manifest_row(
manifest_path,
{
"hero_id": hero["hero_id"],
"localized_name": hero["localized_name"],
"english_name": hero["english_name"],
"image_index": image_index,
"image_path": str(image_path.as_posix()),
"size": args.size,
"quality": args.quality,
"model": backend["model"],
"request_url": request_url,
"generated_at": time.strftime("%Y-%m-%d %H:%M:%S"),
"error": last_error,
},
)
else:
total_failed += 1
print("\n生成完成。")
print(f"成功: {total_success}")