feat: initialize aivideo project
This commit is contained in:
241
backend/app/core/providers.py
Normal file
241
backend/app/core/providers.py
Normal file
@@ -0,0 +1,241 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import subprocess
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
import httpx
|
||||
|
||||
from app.common.config.settings import get_settings
|
||||
from app.common.utils.id_gen import new_public_id
|
||||
from app.models.entities import ProviderAccount, ProviderModel, VideoGenerationTask
|
||||
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
class ProviderAdapter:
|
||||
def __init__(self, account: ProviderAccount, provider_model: ProviderModel) -> None:
|
||||
self.account = account
|
||||
self.provider_model = provider_model
|
||||
|
||||
@property
|
||||
def is_mock(self) -> bool:
|
||||
return self.account.base_url.startswith("mock://")
|
||||
|
||||
def submit_task(self, payload: dict) -> dict:
|
||||
if self.is_mock:
|
||||
return {
|
||||
"externalTaskId": new_public_id("ext"),
|
||||
"normalizedStatus": "submitted",
|
||||
"progress": 0,
|
||||
"rawResponse": {
|
||||
"mock": True,
|
||||
"apiFormat": self.account.api_format,
|
||||
"submittedPayload": payload,
|
||||
},
|
||||
}
|
||||
|
||||
if self.account.api_format == "openai_official_video":
|
||||
return self._submit_openai(payload)
|
||||
if self.account.api_format == "seedance_video_generation":
|
||||
return self._submit_seedance(payload)
|
||||
raise ValueError("unsupported provider format")
|
||||
|
||||
def query_task(self, task: VideoGenerationTask) -> dict:
|
||||
if self.is_mock:
|
||||
return self._query_mock(task)
|
||||
if self.account.api_format == "openai_official_video":
|
||||
return self._query_openai(task.external_task_id)
|
||||
if self.account.api_format == "seedance_video_generation":
|
||||
return self._query_seedance(task.external_task_id)
|
||||
raise ValueError("unsupported provider format")
|
||||
|
||||
def download_result(self, task: VideoGenerationTask) -> bytes:
|
||||
if self.is_mock:
|
||||
return self._download_mock(task)
|
||||
if self.account.api_format == "openai_official_video":
|
||||
return self._download_openai(task.external_task_id)
|
||||
if self.account.api_format == "seedance_video_generation":
|
||||
return self._download_seedance(task)
|
||||
raise ValueError("unsupported provider format")
|
||||
|
||||
def _submit_openai(self, payload: dict) -> dict:
|
||||
files = {
|
||||
"prompt": (None, payload["prompt"]),
|
||||
"model": (None, self.provider_model.model_code),
|
||||
"seconds": (None, str(payload["durationSeconds"])),
|
||||
"size": (None, payload["resolution"]),
|
||||
}
|
||||
response = httpx.post(
|
||||
f"{self.account.base_url.rstrip('/')}/v1/videos",
|
||||
headers={"Authorization": f"Bearer {self.account.api_key_encrypted}"},
|
||||
files=files,
|
||||
timeout=self.account.timeout_seconds,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return {
|
||||
"externalTaskId": data["id"],
|
||||
"normalizedStatus": self._normalize_status(data.get("status")),
|
||||
"progress": data.get("progress", 0),
|
||||
"rawResponse": data,
|
||||
}
|
||||
|
||||
def _submit_seedance(self, payload: dict) -> dict:
|
||||
content = [{"type": "text", "text": payload["prompt"]}]
|
||||
for item in payload.get("referenceImages", []):
|
||||
content.append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": item["url"]},
|
||||
"role": "reference_image",
|
||||
}
|
||||
)
|
||||
response = httpx.post(
|
||||
f"{self.account.base_url.rstrip('/')}/v1/video/generations",
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.account.api_key_encrypted}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={
|
||||
"model": self.provider_model.model_code,
|
||||
"content": content,
|
||||
"duration": payload["durationSeconds"],
|
||||
"ratio": payload["ratio"],
|
||||
"generate_audio": payload["generateAudio"],
|
||||
},
|
||||
timeout=self.account.timeout_seconds,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
external_id = data.get("id") or data.get("task_id")
|
||||
return {
|
||||
"externalTaskId": external_id,
|
||||
"normalizedStatus": self._normalize_status(data.get("status")),
|
||||
"progress": data.get("progress", 0),
|
||||
"rawResponse": data,
|
||||
}
|
||||
|
||||
def _query_openai(self, external_task_id: str) -> dict:
|
||||
response = httpx.get(
|
||||
f"{self.account.base_url.rstrip('/')}/v1/videos/{external_task_id}",
|
||||
headers={"Authorization": f"Bearer {self.account.api_key_encrypted}"},
|
||||
timeout=self.account.timeout_seconds,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return {
|
||||
"externalTaskId": external_task_id,
|
||||
"normalizedStatus": self._normalize_status(data.get("status")),
|
||||
"progress": data.get("progress", 0),
|
||||
"resultUrl": data.get("result_url", ""),
|
||||
"rawResponse": data,
|
||||
}
|
||||
|
||||
def _query_seedance(self, external_task_id: str) -> dict:
|
||||
response = httpx.get(
|
||||
f"{self.account.base_url.rstrip('/')}/v1/video/generations/{external_task_id}",
|
||||
headers={"Authorization": f"Bearer {self.account.api_key_encrypted}"},
|
||||
timeout=self.account.timeout_seconds,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
result_url = ""
|
||||
if isinstance(data.get("result"), dict):
|
||||
result_url = data["result"].get("video_url", "")
|
||||
return {
|
||||
"externalTaskId": external_task_id,
|
||||
"normalizedStatus": self._normalize_status(data.get("status")),
|
||||
"progress": data.get("progress", 0),
|
||||
"resultUrl": result_url,
|
||||
"rawResponse": data,
|
||||
}
|
||||
|
||||
def _download_openai(self, external_task_id: str) -> bytes:
|
||||
response = httpx.get(
|
||||
f"{self.account.base_url.rstrip('/')}/v1/videos/{external_task_id}/content",
|
||||
headers={"Authorization": f"Bearer {self.account.api_key_encrypted}"},
|
||||
timeout=self.account.timeout_seconds,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.content
|
||||
|
||||
def _download_seedance(self, task: VideoGenerationTask) -> bytes:
|
||||
payload = task.response_payload or {}
|
||||
result_url = payload.get("resultUrl")
|
||||
if not result_url and isinstance(payload.get("rawResponse"), dict):
|
||||
result_url = payload["rawResponse"].get("result", {}).get("video_url")
|
||||
if not result_url:
|
||||
raise ValueError("missing result url")
|
||||
response = httpx.get(result_url, timeout=self.account.timeout_seconds)
|
||||
response.raise_for_status()
|
||||
return response.content
|
||||
|
||||
def _query_mock(self, task: VideoGenerationTask) -> dict:
|
||||
started = task.submitted_at or task.created_at
|
||||
now = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
elapsed = max(0, int((now - started).total_seconds()))
|
||||
total = max(1, settings.mock_task_run_seconds)
|
||||
progress = min(100, int(elapsed / total * 100))
|
||||
if elapsed < 3:
|
||||
status = "submitted"
|
||||
elif elapsed < total:
|
||||
status = "running"
|
||||
else:
|
||||
status = "succeeded"
|
||||
progress = 100
|
||||
return {
|
||||
"externalTaskId": task.external_task_id,
|
||||
"normalizedStatus": status,
|
||||
"progress": progress,
|
||||
"resultUrl": "",
|
||||
"rawResponse": {
|
||||
"mock": True,
|
||||
"elapsedSeconds": elapsed,
|
||||
"status": status,
|
||||
"progress": progress,
|
||||
},
|
||||
}
|
||||
|
||||
def _download_mock(self, task: VideoGenerationTask) -> bytes:
|
||||
payload = task.request_payload or {}
|
||||
resolution = payload.get("resolution", "1280x720")
|
||||
width, height = resolution.split("x", 1)
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
target = Path(tmp_dir) / f"{task.task_no}.mp4"
|
||||
command = [
|
||||
"ffmpeg",
|
||||
"-y",
|
||||
"-f",
|
||||
"lavfi",
|
||||
"-i",
|
||||
f"color=c=#14213d:s={width}x{height}:d=3",
|
||||
"-pix_fmt",
|
||||
"yuv420p",
|
||||
str(target),
|
||||
]
|
||||
subprocess.run(command, check=True, capture_output=True)
|
||||
return target.read_bytes()
|
||||
|
||||
@staticmethod
|
||||
def _normalize_status(status: str | None) -> str:
|
||||
mapping = {
|
||||
"queued": "queued",
|
||||
"pending": "queued",
|
||||
"submitted": "submitted",
|
||||
"running": "running",
|
||||
"in_progress": "running",
|
||||
"completed": "succeeded",
|
||||
"succeeded": "succeeded",
|
||||
"failed": "failed",
|
||||
"error": "failed",
|
||||
"cancelled": "cancelled",
|
||||
"timed_out": "timed_out",
|
||||
}
|
||||
return mapping.get((status or "").lower(), "running")
|
||||
|
||||
|
||||
def build_adapter(account: ProviderAccount, provider_model: ProviderModel) -> ProviderAdapter:
|
||||
return ProviderAdapter(account, provider_model)
|
||||
Reference in New Issue
Block a user