Files

344 lines
15 KiB
Python

from __future__ import annotations

from datetime import datetime
from pathlib import Path
from types import SimpleNamespace

from sqlalchemy import select
from sqlalchemy.orm import Session

from app.common.errors.app_error import BusinessAppError, NotFoundAppError
from app.common.utils.id_gen import new_order_no
from app.core.providers import build_adapter
from app.core.storage import storage_service
from app.models.entities import MediaAsset, VideoGenerationTask, VideoTaskEvent
from app.modules.video_tasks.repository import VideoTasksRepository
from app.modules.wallets.service import WalletService
class VideoTasksService:
    """Service layer for video-generation tasks.

    Covers task creation (validate model/pricing/provider binding, freeze the
    estimated points, submit to the provider), provider polling with wallet
    settlement (consume on success, release on failure/cancel/timeout), and the
    user/admin list, detail, retry, cancel, delete and refund operations.
    Public methods that mutate state commit the session themselves.
    """

    # Once a task reaches one of these statuses it is no longer polled.
    FINAL_STATUSES = {"succeeded", "failed", "cancelled", "timed_out"}

    def __init__(self, db: Session) -> None:
        self.db = db
        self.repository = VideoTasksRepository(db)
        self.wallet_service = WalletService(db)

    def create_task(self, user_id: int, payload) -> dict:
        """Create a task, freeze its estimated cost, submit it and commit.

        Args:
            user_id: Owner of the task.
            payload: Request object exposing (as attributes) videoModelId,
                prompt, durationSeconds, resolution, ratio, generateAudio and
                the referenceImage/Video/AudioAssetIds lists.

        Returns:
            Dict with taskNo, taskStatus, estimatedPoints and frozenPoints.

        Raises:
            BusinessAppError: model missing or disabled (50002); no active
                pricing rule or no usable provider binding (50003).
        """
        video_model = self.repository.get_video_model(payload.videoModelId)
        if not video_model or video_model.status != 1:
            raise BusinessAppError("video model unavailable", code=50002)
        pricing = self.repository.get_active_pricing(video_model.id)
        if not pricing:
            raise BusinessAppError("pricing rule unavailable", code=50003)
        binding = self._select_binding(video_model.id)
        provider_model = self.repository.get_provider_model(binding.provider_model_id)
        provider_account = (
            self.repository.get_provider_account(provider_model.provider_account_id)
            if provider_model
            else None
        )
        if not provider_model or not provider_account:
            # A binding pointing at a removed model/account would otherwise
            # surface as an AttributeError further down.
            raise BusinessAppError("no available provider", code=50003)
        # Per-second price, but never below the rule's minimum charge.
        estimated_points = max(
            pricing.minimum_points,
            pricing.points_per_second * payload.durationSeconds,
        )
        normalized = self._build_normalized_payload(user_id, payload, provider_model.id, provider_account.id)
        task = VideoGenerationTask(
            task_no=new_order_no("vt"),
            user_id=user_id,
            video_model_id=video_model.id,
            provider_account_id=provider_account.id,
            provider_model_id=provider_model.id,
            provider_binding_id=binding.id,
            pricing_rule_id=pricing.id,
            task_status="queued",
            generation_mode=self._infer_generation_mode(payload),
            prompt_text=payload.prompt,
            request_payload=normalized,
            duration_seconds=payload.durationSeconds,
            ratio=payload.ratio,
            resolution=payload.resolution,
            generate_audio=payload.generateAudio,
            estimated_points=estimated_points,
            frozen_points=estimated_points,
        )
        self.db.add(task)
        # Flush so task.id exists for the wallet record and the events.
        self.db.flush()
        self.wallet_service.freeze_points(
            user_id,
            estimated_points,
            related_type="video_task",
            related_id=task.id,
            remark=f"freeze for {task.task_no}",
        )
        self._add_event(task.id, "created", "task created")
        self._submit_task(task, provider_account, provider_model)
        self.db.commit()
        return {
            "taskNo": task.task_no,
            "taskStatus": task.task_status,
            "estimatedPoints": task.estimated_points,
            "frozenPoints": task.frozen_points,
        }

    def list_tasks(self, user_id: int) -> list[dict]:
        """Return summaries of the user's visible tasks (first 100 fetched).

        All fetched tasks are refreshed against their provider before the
        ``user_visible`` filter is applied, so hidden tasks still get settled.
        """
        tasks = self.repository.list_tasks(user_id).limit(100).all()
        for task in tasks:
            self._refresh_task_progress(task)
        self.db.commit()
        return [self.serialize_task_summary(task) for task in tasks if task.user_visible]

    def get_task_detail(self, user_id: int, task_no: str) -> dict:
        """Refresh and return the detail view of one of the user's tasks.

        Raises:
            NotFoundAppError: no such task for this user (50006).
        """
        task = self.repository.get_task(user_id, task_no)
        if not task:
            raise NotFoundAppError("task not found", code=50006)
        self._refresh_task_progress(task)
        self.db.commit()
        return self.serialize_task_detail(task)

    def retry_task(self, user_id: int, task_no: str) -> dict:
        """Create (and submit) a brand-new task from an existing task's request.

        The original task is left unchanged; the new task is priced and
        charged independently via create_task.

        Raises:
            NotFoundAppError: no such task for this user (50006).
        """
        task = self.repository.get_task(user_id, task_no)
        if not task:
            raise NotFoundAppError("task not found", code=50006)
        return self.create_task(user_id, self._build_retry_payload(task.request_payload))

    @staticmethod
    def _build_retry_payload(payload: dict) -> SimpleNamespace:
        """Rebuild a create_task-compatible request object from a stored payload.

        ``payload`` is the dict produced by _build_normalized_payload. Missing
        or ``None`` reference-id lists become empty lists, and lists are copied
        so the new task does not share objects with the old task's JSON column.
        """
        return SimpleNamespace(
            videoModelId=payload["videoModelId"],
            prompt=payload["prompt"],
            durationSeconds=payload["durationSeconds"],
            resolution=payload["resolution"],
            ratio=payload["ratio"],
            generateAudio=payload["generateAudio"],
            referenceImageAssetIds=list(payload.get("referenceImageAssetIds") or []),
            referenceVideoAssetIds=list(payload.get("referenceVideoAssetIds") or []),
            referenceAudioAssetIds=list(payload.get("referenceAudioAssetIds") or []),
        )

    def cancel_task(self, user_id: int, task_no: str) -> dict:
        """Cancel a non-final task and release its frozen points.

        Tasks already in a final status are returned unchanged.
        NOTE(review): only the local record is cancelled; nothing here asks the
        provider to stop the job.

        Raises:
            NotFoundAppError: no such task for this user (50006).
        """
        task = self.repository.get_task(user_id, task_no)
        if not task:
            raise NotFoundAppError("task not found", code=50006)
        if task.task_status in self.FINAL_STATUSES:
            return {"taskNo": task.task_no, "taskStatus": task.task_status}
        task.task_status = "cancelled"
        task.finished_at = datetime.utcnow()
        self.wallet_service.release_frozen_points(
            user_id,
            task.frozen_points,
            related_type="video_task",
            related_id=task.id,
            remark=f"cancel {task.task_no}",
        )
        self._add_event(task.id, "cancelled", "task cancelled")
        self.db.commit()
        return {"taskNo": task.task_no, "taskStatus": task.task_status}

    def delete_task(self, user_id: int, task_no: str) -> dict:
        """Soft-delete: hide the task from the user; the row itself is kept.

        No wallet operation is performed here.

        Raises:
            NotFoundAppError: no such task for this user (50006).
        """
        task = self.repository.get_task(user_id, task_no)
        if not task:
            raise NotFoundAppError("task not found", code=50006)
        task.user_visible = False
        task.user_deleted_at = datetime.utcnow()
        self.db.commit()
        return {"taskNo": task.task_no, "deleted": True}

    def admin_list_tasks(self) -> list[dict]:
        """Refresh and summarize the 200 highest-id tasks across all users."""
        tasks = self.db.query(VideoGenerationTask).order_by(VideoGenerationTask.id.desc()).limit(200).all()
        for task in tasks:
            self._refresh_task_progress(task)
        self.db.commit()
        return [self.serialize_task_summary(task) for task in tasks]

    def admin_get_task(self, task_id: int) -> dict:
        """Refresh and return the detail view of any task by primary key.

        Raises:
            NotFoundAppError: no such task (50006).
        """
        task = self.repository.get_task_by_id(task_id)
        if not task:
            raise NotFoundAppError("task not found", code=50006)
        self._refresh_task_progress(task)
        self.db.commit()
        return self.serialize_task_detail(task)

    def admin_retry_task(self, task_id: int) -> dict:
        """Retry a task on behalf of its owner (see retry_task).

        Raises:
            NotFoundAppError: no such task (50006).
        """
        task = self.repository.get_task_by_id(task_id)
        if not task:
            raise NotFoundAppError("task not found", code=50006)
        return self.retry_task(task.user_id, task.task_no)

    def admin_refund_task(self, task_id: int) -> dict:
        """Manually refund the points charged for a succeeded task, once.

        Returns:
            Dict with taskNo and the wallet refundTransactionNo.

        Raises:
            NotFoundAppError: no such task (50006).
            BusinessAppError: the task was never charged (not succeeded) or
                has already been refunded.
        """
        task = self.repository.get_task_by_id(task_id)
        if not task:
            raise NotFoundAppError("task not found", code=50006)
        # Only succeeded tasks consumed points. Failed/cancelled/timed-out
        # tasks had their frozen points released already, and in-flight tasks
        # still hold them frozen -- refunding those would credit points twice.
        # NOTE(review): 50007 is a newly allocated code; confirm against the
        # project's error-code catalogue.
        if task.task_status != "succeeded":
            raise BusinessAppError("task is not refundable", code=50007)
        events = self.repository.task_events(task.id).all()
        if any(event.event_type == "refund" for event in events):
            raise BusinessAppError("task already refunded", code=50007)
        tx = self.wallet_service.add_points(
            task.user_id,
            task.final_points or task.frozen_points,
            biz_type="refund",
            related_type="video_task",
            related_id=task.id,
            remark=f"manual refund for {task.task_no}",
            operator_type="admin",
        )
        self._add_event(task.id, "refund", "manual refund")
        self.db.commit()
        return {"taskNo": task.task_no, "refundTransactionNo": tx.transaction_no}

    def _select_binding(self, video_model_id: int):
        """Return the first provider binding the repository yields for the model.

        Raises:
            BusinessAppError: the model has no bindings (50003).
        """
        bindings = self.repository.get_bindings(video_model_id)
        if not bindings:
            raise BusinessAppError("no available provider", code=50003)
        return bindings[0]

    def _build_normalized_payload(self, user_id: int, payload, provider_model_id: int, provider_account_id: int) -> dict:
        """Build the request dict stored on the task and sent to the provider.

        Reference asset ids are resolved through repository.list_assets scoped
        to ``user_id``; both the raw ids and the resolved {assetId, url} entries
        are stored.
        """
        image_assets = self.repository.list_assets(user_id, payload.referenceImageAssetIds)
        video_assets = self.repository.list_assets(user_id, payload.referenceVideoAssetIds)
        audio_assets = self.repository.list_assets(user_id, payload.referenceAudioAssetIds)
        return {
            "videoModelId": payload.videoModelId,
            "providerModelId": provider_model_id,
            "providerAccountId": provider_account_id,
            "prompt": payload.prompt,
            "durationSeconds": payload.durationSeconds,
            "resolution": payload.resolution,
            "ratio": payload.ratio,
            "generateAudio": payload.generateAudio,
            "referenceImageAssetIds": payload.referenceImageAssetIds,
            "referenceVideoAssetIds": payload.referenceVideoAssetIds,
            "referenceAudioAssetIds": payload.referenceAudioAssetIds,
            "referenceImages": [{"assetId": item.id, "url": item.public_url} for item in image_assets],
            "referenceVideos": [{"assetId": item.id, "url": item.public_url} for item in video_assets],
            "referenceAudios": [{"assetId": item.id, "url": item.public_url} for item in audio_assets],
        }

    @staticmethod
    def _infer_generation_mode(payload) -> str:
        """Return "multimodal" if any reference ids are given, else "text_to_video"."""
        if payload.referenceImageAssetIds or payload.referenceVideoAssetIds or payload.referenceAudioAssetIds:
            return "multimodal"
        return "text_to_video"

    def _submit_task(self, task: VideoGenerationTask, provider_account, provider_model) -> None:
        """Submit the task to its provider and apply the initial status."""
        adapter = build_adapter(provider_account, provider_model)
        result = adapter.submit_task(task.request_payload)
        task.external_task_id = result["externalTaskId"]
        task.submitted_at = datetime.utcnow()
        task.response_payload = result
        self._add_event(task.id, "submitted", "task submitted to provider")
        # Use the same settlement logic as polling: a provider reporting a
        # final status right away must still release/consume the frozen points,
        # because _refresh_task_progress never revisits final tasks.
        self._apply_provider_status(task, adapter, result)

    def _refresh_task_progress(self, task: VideoGenerationTask) -> None:
        """Poll the provider for a non-final task and apply the reported status.

        Final tasks return immediately without contacting the provider.
        """
        if task.task_status in self.FINAL_STATUSES:
            return
        provider_model = self.repository.get_provider_model(task.provider_model_id)
        provider_account = self.repository.get_provider_account(task.provider_account_id)
        adapter = build_adapter(provider_account, provider_model)
        result = adapter.query_task(task)
        task.response_payload = result
        self._apply_provider_status(task, adapter, result)

    def _apply_provider_status(self, task: VideoGenerationTask, adapter, result: dict) -> None:
        """Move the task to ``result["normalizedStatus"]`` and settle the wallet.

        * "succeeded" (no result asset yet): store the video, consume points.
        * any other final status: record it and release the frozen points.
        * otherwise: just record the status.
        """
        status = result["normalizedStatus"]
        if status == "succeeded" and task.result_asset_id is None:
            self._settle_success(task, adapter)
        elif status in self.FINAL_STATUSES and status != "succeeded":
            self._settle_failure(task, status, result)
        else:
            task.task_status = status

    def _settle_success(self, task: VideoGenerationTask, adapter) -> None:
        """Store the provider's video as a MediaAsset and consume frozen points."""
        content = adapter.download_result(task)
        stored = storage_service.save_bytes(
            content,
            filename=f"{task.task_no}.mp4",
            folder="generated/videos",
        )
        asset = MediaAsset(
            asset_no=new_order_no("asset"),
            user_id=task.user_id,
            media_type="video",
            source_type="generated",
            original_filename=f"{task.task_no}.mp4",
            mime_type="video/mp4",
            file_ext=".mp4",
            file_size=stored["file_size"],
            storage_provider="local",
            storage_bucket="local",
            storage_key=stored["storage_key"],
            public_url=stored["public_url"],
            sha256=stored["sha256"],
            status="active",
        )
        self.db.add(asset)
        # Flush so asset.id is available to link from the task.
        self.db.flush()
        task.result_asset_id = asset.id
        task.final_points = task.estimated_points
        task.task_status = "succeeded"
        task.finished_at = datetime.utcnow()
        self.wallet_service.consume_frozen_points(
            task.user_id,
            task.frozen_points,
            related_type="video_task",
            related_id=task.id,
            remark=f"consume for {task.task_no}",
        )
        self.wallet_service.try_issue_invite_reward(task.user_id, task.id, task.final_points)
        self._add_event(task.id, "succeeded", "task succeeded")

    def _settle_failure(self, task: VideoGenerationTask, status: str, result: dict) -> None:
        """Mark the task with a non-success final status and release its points."""
        task.task_status = status
        task.fail_reason = self._extract_fail_reason(result, status)
        task.finished_at = datetime.utcnow()
        self.wallet_service.release_frozen_points(
            task.user_id,
            task.frozen_points,
            related_type="video_task",
            related_id=task.id,
            remark=f"release for {task.task_no}",
        )
        self._add_event(task.id, status, f"task {status}")

    @staticmethod
    def _extract_fail_reason(result: dict, status: str = "failed") -> str:
        """Return the provider's message from ``result["rawResponse"]``.

        Falls back to "provider <status>" when rawResponse is missing, not a
        dict (e.g. None), or carries no non-empty message.
        """
        raw = result.get("rawResponse")
        message = raw.get("message") if isinstance(raw, dict) else None
        return message or f"provider {status}"

    def _add_event(self, task_id: int, event_type: str, message: str) -> None:
        """Stage a VideoTaskEvent row for the task (committed by the caller)."""
        self.db.add(
            VideoTaskEvent(
                video_task_id=task_id,
                event_type=event_type,
                event_message=message,
                payload=None,
                created_at=datetime.utcnow(),
            )
        )

    def serialize_task_summary(self, task: VideoGenerationTask) -> dict:
        """Serialize the list-view fields of a task (timestamps as ISO strings)."""
        return {
            "id": task.id,
            "taskNo": task.task_no,
            "taskStatus": task.task_status,
            "durationSeconds": task.duration_seconds,
            "estimatedPoints": task.estimated_points,
            "finalPoints": task.final_points,
            "resultVideoUrl": self._result_url(task),
            "failReason": task.fail_reason,
            "createdAt": task.created_at.isoformat() if task.created_at else None,
            "finishedAt": task.finished_at.isoformat() if task.finished_at else None,
        }

    def serialize_task_detail(self, task: VideoGenerationTask) -> dict:
        """Serialize the summary plus model/provider info, payloads and events.

        Missing model/provider rows are rendered as empty strings.
        """
        provider_account = self.repository.get_provider_account(task.provider_account_id)
        provider_model = self.repository.get_provider_model(task.provider_model_id)
        video_model = self.repository.get_video_model(task.video_model_id)
        events = self.repository.task_events(task.id).all()
        return {
            **self.serialize_task_summary(task),
            "videoModel": {
                "id": task.video_model_id,
                "name": video_model.model_name if video_model else "",
            },
            "provider": {
                "providerCode": provider_account.provider_code if provider_account else "",
                "providerName": provider_account.provider_name if provider_account else "",
                "modelCode": provider_model.model_code if provider_model else "",
                "modelName": provider_model.model_name if provider_model else "",
            },
            "ratio": task.ratio,
            "resolution": task.resolution,
            "prompt": task.prompt_text or "",
            "requestPayload": task.request_payload,
            "responsePayload": task.response_payload,
            "events": [
                {
                    "eventType": item.event_type,
                    "eventMessage": item.event_message,
                    "createdAt": item.created_at.isoformat(),
                }
                for item in events
            ],
        }

    def _result_url(self, task: VideoGenerationTask) -> str:
        """Return the public URL of the task's result asset, or "" if none."""
        if not task.result_asset_id:
            return ""
        asset = self.db.scalar(select(MediaAsset).where(MediaAsset.id == task.result_asset_id))
        return asset.public_url if asset else ""