为Dota2批量生图脚本增加4线程并发生成\n\n- 引入线程池并将图片生成任务改为并发执行\n- 新增并发线程参数,默认使用4个线程提升批量生成效率\n- 为日志输出与清单写入增加线程锁,避免并发场景下内容冲突
This commit is contained in:
@@ -15,13 +15,15 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import base64
|
import base64
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
|
import threading
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
import yaml
|
import yaml
|
||||||
@@ -39,6 +41,9 @@ DEFAULT_OUTPUT_DIR = Path("temp") / "dota2_douyin_images"
|
|||||||
# 这里固定每个英雄的 4 张图的语言分布,避免每次运行时还要手动指定。
|
# 这里固定每个英雄的 4 张图的语言分布,避免每次运行时还要手动指定。
|
||||||
DEFAULT_LANGUAGE_VARIANTS = ["zh", "zh", "ja", "ja"]
|
DEFAULT_LANGUAGE_VARIANTS = ["zh", "zh", "ja", "ja"]
|
||||||
|
|
||||||
|
# 这里把默认并发数固定为 4,满足你“开 4 个线程跑”的诉求。
|
||||||
|
DEFAULT_MAX_WORKERS = 4
|
||||||
|
|
||||||
|
|
||||||
def parse_args() -> argparse.Namespace:
|
def parse_args() -> argparse.Namespace:
|
||||||
"""解析命令行参数。"""
|
"""解析命令行参数。"""
|
||||||
@@ -105,6 +110,12 @@ def parse_args() -> argparse.Namespace:
|
|||||||
default="",
|
default="",
|
||||||
help="只生成英雄名中包含该关键字的英雄,便于单独补图。",
|
help="只生成英雄名中包含该关键字的英雄,便于单独补图。",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-workers",
|
||||||
|
type=int,
|
||||||
|
default=DEFAULT_MAX_WORKERS,
|
||||||
|
help="并发线程数,默认 4。",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--force",
|
"--force",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
@@ -246,7 +257,7 @@ def build_consistent_prompt(hero: Dict[str, str], image_index: int) -> str:
|
|||||||
请为短视频封面创作一张高完成度竖版插画,主体是 Dota2 英雄 {hero_name_cn}({hero_name_en})。
|
请为短视频封面创作一张高完成度竖版插画,主体是 Dota2 英雄 {hero_name_cn}({hero_name_en})。
|
||||||
|
|
||||||
核心要求:
|
核心要求:
|
||||||
1. 角色设定明确为 Dota2 的骷髅王风格体系下的“至宝级华丽皮肤质感”,但角色身份必须是 {hero_name_cn} 本人,不要画成别的英雄。
|
1. 角色设定明确为 Dota2 的风格体系下的“至宝级华丽皮肤质感”,但角色身份必须是 {hero_name_cn} 本人,不要画成别的英雄。
|
||||||
2. 画面整体要强烈偏向 JOJO 气质:夸张肌肉与体块、强烈明暗对比、戏剧化姿势、锐利线条、张力十足的漫画分镜感、厚重阴影、速度线、压迫感构图。
|
2. 画面整体要强烈偏向 JOJO 气质:夸张肌肉与体块、强烈明暗对比、戏剧化姿势、锐利线条、张力十足的漫画分镜感、厚重阴影、速度线、压迫感构图。
|
||||||
3. 需要比普通日漫更偏 JOJO 风,风格统一、成熟、硬朗、华丽,视觉冲击力强。
|
3. 需要比普通日漫更偏 JOJO 风,风格统一、成熟、硬朗、华丽,视觉冲击力强。
|
||||||
4. 画面左下角固定放一个“能力雷达图”,用日式游戏 UI 风格表现,半透明发光面板,结构清晰。
|
4. 画面左下角固定放一个“能力雷达图”,用日式游戏 UI 风格表现,半透明发光面板,结构清晰。
|
||||||
@@ -346,6 +357,160 @@ def ensure_output_dir(output_dir: Path) -> None:
|
|||||||
output_dir.mkdir(parents=True, exist_ok=True)
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
def build_generation_tasks(
|
||||||
|
heroes: List[Dict[str, str]],
|
||||||
|
output_dir: Path,
|
||||||
|
count_per_hero: int,
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
预先展开所有生图任务。
|
||||||
|
|
||||||
|
这样做的目的:
|
||||||
|
1. 先把“英雄 x 第几张图”拍平成统一任务列表,便于线程池直接消费;
|
||||||
|
2. 任务对象中提前算好输出目录、文件名、提示词,线程里只负责执行;
|
||||||
|
3. 任务顺序保持稳定,后续日志更容易排查。
|
||||||
|
"""
|
||||||
|
tasks: List[Dict[str, Any]] = []
|
||||||
|
total_heroes = len(heroes)
|
||||||
|
|
||||||
|
for hero_index, hero in enumerate(heroes, start=1):
|
||||||
|
hero_slug = sanitize_filename(hero["english_name"].lower().replace(" ", "_"))
|
||||||
|
hero_dir = output_dir / f"{hero_slug}_{sanitize_filename(hero['localized_name'])}"
|
||||||
|
ensure_output_dir(hero_dir)
|
||||||
|
|
||||||
|
for image_index in range(1, count_per_hero + 1):
|
||||||
|
file_name = f"{hero_slug}_{image_index:02d}.png"
|
||||||
|
image_path = hero_dir / file_name
|
||||||
|
tasks.append(
|
||||||
|
{
|
||||||
|
"hero": hero,
|
||||||
|
"hero_index": hero_index,
|
||||||
|
"total_heroes": total_heroes,
|
||||||
|
"hero_slug": hero_slug,
|
||||||
|
"hero_dir": hero_dir,
|
||||||
|
"image_index": image_index,
|
||||||
|
"image_path": image_path,
|
||||||
|
"prompt": build_consistent_prompt(hero, image_index),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return tasks
|
||||||
|
|
||||||
|
|
||||||
|
def run_single_generation_task(
|
||||||
|
task: Dict[str, Any],
|
||||||
|
request_url: str,
|
||||||
|
api_key: str,
|
||||||
|
model: str,
|
||||||
|
image_size: str,
|
||||||
|
image_quality: str,
|
||||||
|
timeout_seconds: int,
|
||||||
|
max_retries: int,
|
||||||
|
delay_seconds: float,
|
||||||
|
force: bool,
|
||||||
|
manifest_path: Path,
|
||||||
|
manifest_lock: threading.Lock,
|
||||||
|
print_lock: threading.Lock,
|
||||||
|
) -> Tuple[str, Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
在线程池中执行单个图片生成任务。
|
||||||
|
|
||||||
|
返回值约定:
|
||||||
|
1. status 为 success / skipped / failed 三种之一;
|
||||||
|
2. payload 会带上日志和清单记录所需的数据,主线程只负责汇总结果;
|
||||||
|
3. manifest 写入放在线程内完成,但通过锁保证同一时刻只有一个线程落盘。
|
||||||
|
"""
|
||||||
|
hero = task["hero"]
|
||||||
|
image_index = task["image_index"]
|
||||||
|
image_path: Path = task["image_path"]
|
||||||
|
prompt = task["prompt"]
|
||||||
|
|
||||||
|
with print_lock:
|
||||||
|
print(
|
||||||
|
f"\n[{task['hero_index']}/{task['total_heroes']}] "
|
||||||
|
f"处理英雄: {hero['localized_name']} ({hero['english_name']}) "
|
||||||
|
f"- 第 {image_index} 张"
|
||||||
|
)
|
||||||
|
|
||||||
|
if image_path.exists() and not force:
|
||||||
|
with print_lock:
|
||||||
|
print(f" - 已存在,跳过: {image_path.name}")
|
||||||
|
return "skipped", {
|
||||||
|
"hero_id": hero["hero_id"],
|
||||||
|
"localized_name": hero["localized_name"],
|
||||||
|
"english_name": hero["english_name"],
|
||||||
|
"image_index": image_index,
|
||||||
|
"image_path": str(image_path.as_posix()),
|
||||||
|
}
|
||||||
|
|
||||||
|
last_error: Optional[str] = None
|
||||||
|
for retry_index in range(1, max_retries + 1):
|
||||||
|
try:
|
||||||
|
with print_lock:
|
||||||
|
print(f" - 生成第 {image_index} 张,尝试 {retry_index}/{max_retries}")
|
||||||
|
|
||||||
|
image_bytes = generate_one_image(
|
||||||
|
request_url=request_url,
|
||||||
|
api_key=api_key,
|
||||||
|
model=model,
|
||||||
|
prompt=prompt,
|
||||||
|
image_size=image_size,
|
||||||
|
image_quality=image_quality,
|
||||||
|
timeout_seconds=timeout_seconds,
|
||||||
|
)
|
||||||
|
|
||||||
|
with image_path.open("wb") as file_obj:
|
||||||
|
file_obj.write(image_bytes)
|
||||||
|
|
||||||
|
manifest_row = {
|
||||||
|
"hero_id": hero["hero_id"],
|
||||||
|
"localized_name": hero["localized_name"],
|
||||||
|
"english_name": hero["english_name"],
|
||||||
|
"image_index": image_index,
|
||||||
|
"image_path": str(image_path.as_posix()),
|
||||||
|
"size": image_size,
|
||||||
|
"quality": image_quality,
|
||||||
|
"model": model,
|
||||||
|
"request_url": request_url,
|
||||||
|
"generated_at": time.strftime("%Y-%m-%d %H:%M:%S"),
|
||||||
|
"prompt": prompt,
|
||||||
|
}
|
||||||
|
|
||||||
|
# 这里用锁保护清单写入,避免多个线程同时写 JSONL 时内容互相穿插。
|
||||||
|
with manifest_lock:
|
||||||
|
append_manifest_row(manifest_path, manifest_row)
|
||||||
|
|
||||||
|
with print_lock:
|
||||||
|
print(f" - 生成成功: {image_path.name}")
|
||||||
|
|
||||||
|
time.sleep(delay_seconds)
|
||||||
|
return "success", manifest_row
|
||||||
|
except Exception as exc:
|
||||||
|
last_error = str(exc)
|
||||||
|
with print_lock:
|
||||||
|
print(f" - 生成失败: {last_error}")
|
||||||
|
if retry_index < max_retries:
|
||||||
|
# 这里做一个简短退避,降低临时网络波动或网关限流的影响。
|
||||||
|
time.sleep(min(5, retry_index * 2))
|
||||||
|
|
||||||
|
failed_row = {
|
||||||
|
"hero_id": hero["hero_id"],
|
||||||
|
"localized_name": hero["localized_name"],
|
||||||
|
"english_name": hero["english_name"],
|
||||||
|
"image_index": image_index,
|
||||||
|
"image_path": str(image_path.as_posix()),
|
||||||
|
"size": image_size,
|
||||||
|
"quality": image_quality,
|
||||||
|
"model": model,
|
||||||
|
"request_url": request_url,
|
||||||
|
"generated_at": time.strftime("%Y-%m-%d %H:%M:%S"),
|
||||||
|
"error": last_error or "未知错误",
|
||||||
|
}
|
||||||
|
with manifest_lock:
|
||||||
|
append_manifest_row(manifest_path, failed_row)
|
||||||
|
return "failed", failed_row
|
||||||
|
|
||||||
|
|
||||||
def main() -> int:
|
def main() -> int:
|
||||||
"""脚本入口。"""
|
"""脚本入口。"""
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
@@ -385,93 +550,49 @@ def main() -> int:
|
|||||||
print(f"共准备生成 {len(heroes)} 个英雄,每个英雄 {args.count_per_hero} 张。")
|
print(f"共准备生成 {len(heroes)} 个英雄,每个英雄 {args.count_per_hero} 张。")
|
||||||
print(f"图片接口: {request_url}")
|
print(f"图片接口: {request_url}")
|
||||||
print(f"输出目录: {output_dir.resolve()}")
|
print(f"输出目录: {output_dir.resolve()}")
|
||||||
|
print(f"并发线程数: {args.max_workers}")
|
||||||
|
|
||||||
total_success = 0
|
total_success = 0
|
||||||
total_skipped = 0
|
total_skipped = 0
|
||||||
total_failed = 0
|
total_failed = 0
|
||||||
|
manifest_lock = threading.Lock()
|
||||||
|
print_lock = threading.Lock()
|
||||||
|
tasks = build_generation_tasks(
|
||||||
|
heroes=heroes,
|
||||||
|
output_dir=output_dir,
|
||||||
|
count_per_hero=args.count_per_hero,
|
||||||
|
)
|
||||||
|
|
||||||
for hero_index, hero in enumerate(heroes, start=1):
|
# 这里将所有任务交给线程池统一调度,让脚本能够同时发起 4 个图片请求。
|
||||||
hero_slug = sanitize_filename(hero["english_name"].lower().replace(" ", "_"))
|
with ThreadPoolExecutor(max_workers=max(1, int(args.max_workers))) as executor:
|
||||||
hero_dir = output_dir / f"{hero_slug}_{sanitize_filename(hero['localized_name'])}"
|
future_to_task = {
|
||||||
ensure_output_dir(hero_dir)
|
executor.submit(
|
||||||
|
run_single_generation_task,
|
||||||
|
task,
|
||||||
|
request_url,
|
||||||
|
backend["api_key"],
|
||||||
|
backend["model"],
|
||||||
|
args.size,
|
||||||
|
args.quality,
|
||||||
|
timeout_seconds,
|
||||||
|
args.max_retries,
|
||||||
|
args.delay,
|
||||||
|
args.force,
|
||||||
|
manifest_path,
|
||||||
|
manifest_lock,
|
||||||
|
print_lock,
|
||||||
|
): task
|
||||||
|
for task in tasks
|
||||||
|
}
|
||||||
|
|
||||||
print(f"\n[{hero_index}/{len(heroes)}] 开始处理英雄: {hero['localized_name']} ({hero['english_name']})")
|
for future in as_completed(future_to_task):
|
||||||
|
status, _ = future.result()
|
||||||
for image_index in range(1, args.count_per_hero + 1):
|
if status == "success":
|
||||||
file_name = f"{hero_slug}_{image_index:02d}.png"
|
total_success += 1
|
||||||
image_path = hero_dir / file_name
|
elif status == "skipped":
|
||||||
|
|
||||||
if image_path.exists() and not args.force:
|
|
||||||
total_skipped += 1
|
total_skipped += 1
|
||||||
print(f" - 已存在,跳过: {image_path.name}")
|
else:
|
||||||
continue
|
total_failed += 1
|
||||||
|
|
||||||
prompt = build_consistent_prompt(hero, image_index)
|
|
||||||
last_error: Optional[str] = None
|
|
||||||
|
|
||||||
for retry_index in range(1, args.max_retries + 1):
|
|
||||||
try:
|
|
||||||
print(f" - 生成第 {image_index} 张,尝试 {retry_index}/{args.max_retries}")
|
|
||||||
image_bytes = generate_one_image(
|
|
||||||
request_url=request_url,
|
|
||||||
api_key=backend["api_key"],
|
|
||||||
model=backend["model"],
|
|
||||||
prompt=prompt,
|
|
||||||
image_size=args.size,
|
|
||||||
image_quality=args.quality,
|
|
||||||
timeout_seconds=timeout_seconds,
|
|
||||||
)
|
|
||||||
|
|
||||||
with image_path.open("wb") as file_obj:
|
|
||||||
file_obj.write(image_bytes)
|
|
||||||
|
|
||||||
append_manifest_row(
|
|
||||||
manifest_path,
|
|
||||||
{
|
|
||||||
"hero_id": hero["hero_id"],
|
|
||||||
"localized_name": hero["localized_name"],
|
|
||||||
"english_name": hero["english_name"],
|
|
||||||
"image_index": image_index,
|
|
||||||
"image_path": str(image_path.as_posix()),
|
|
||||||
"size": args.size,
|
|
||||||
"quality": args.quality,
|
|
||||||
"model": backend["model"],
|
|
||||||
"request_url": request_url,
|
|
||||||
"generated_at": time.strftime("%Y-%m-%d %H:%M:%S"),
|
|
||||||
"prompt": prompt,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
total_success += 1
|
|
||||||
print(f" - 生成成功: {image_path.name}")
|
|
||||||
time.sleep(args.delay)
|
|
||||||
break
|
|
||||||
except Exception as exc:
|
|
||||||
last_error = str(exc)
|
|
||||||
print(f" - 生成失败: {last_error}")
|
|
||||||
if retry_index < args.max_retries:
|
|
||||||
# 这里做一个简短退避,降低临时网络波动或网关限流的影响。
|
|
||||||
time.sleep(min(5, retry_index * 2))
|
|
||||||
else:
|
|
||||||
total_failed += 1
|
|
||||||
|
|
||||||
if last_error and (not image_path.exists()):
|
|
||||||
append_manifest_row(
|
|
||||||
manifest_path,
|
|
||||||
{
|
|
||||||
"hero_id": hero["hero_id"],
|
|
||||||
"localized_name": hero["localized_name"],
|
|
||||||
"english_name": hero["english_name"],
|
|
||||||
"image_index": image_index,
|
|
||||||
"image_path": str(image_path.as_posix()),
|
|
||||||
"size": args.size,
|
|
||||||
"quality": args.quality,
|
|
||||||
"model": backend["model"],
|
|
||||||
"request_url": request_url,
|
|
||||||
"generated_at": time.strftime("%Y-%m-%d %H:%M:%S"),
|
|
||||||
"error": last_error,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
print("\n生成完成。")
|
print("\n生成完成。")
|
||||||
print(f"成功: {total_success}")
|
print(f"成功: {total_success}")
|
||||||
|
|||||||
Reference in New Issue
Block a user