344 lines
15 KiB
Python
344 lines
15 KiB
Python
from __future__ import annotations
|
|
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
|
|
from sqlalchemy import select
|
|
from sqlalchemy.orm import Session
|
|
|
|
from app.common.errors.app_error import BusinessAppError, NotFoundAppError
|
|
from app.common.utils.id_gen import new_order_no
|
|
from app.core.providers import build_adapter
|
|
from app.core.storage import storage_service
|
|
from app.models.entities import MediaAsset, VideoGenerationTask, VideoTaskEvent
|
|
from app.modules.video_tasks.repository import VideoTasksRepository
|
|
from app.modules.wallets.service import WalletService
|
|
|
|
|
|
class VideoTasksService:
    """Application service for user- and admin-facing video generation tasks.

    Coordinates the task repository, the wallet (freezing, consuming, releasing
    and refunding points), the provider adapters returned by ``build_adapter``
    and ``storage_service``, which persists generated videos.
    """

    # Task statuses after which a task is never polled again.
    FINAL_STATUSES = {"succeeded", "failed", "cancelled", "timed_out"}

    def __init__(self, db: Session) -> None:
        self.db = db
        self.repository = VideoTasksRepository(db)
        self.wallet_service = WalletService(db)

    # ------------------------------------------------------------------ user API

    def create_task(self, user_id: int, payload) -> dict:
        """Create a task, freeze its estimated cost and submit it to a provider.

        ``payload`` is read by attribute: ``videoModelId``, ``prompt``,
        ``durationSeconds``, ``resolution``, ``ratio``, ``generateAudio`` and the
        ``referenceImageAssetIds`` / ``referenceVideoAssetIds`` /
        ``referenceAudioAssetIds`` sequences.

        Returns:
            ``{"taskNo", "taskStatus", "estimatedPoints", "frozenPoints"}``.

        Raises:
            BusinessAppError: video model missing or disabled (50002), no active
                pricing rule (50003), or no usable provider binding (50003).
        """
        video_model = self.repository.get_video_model(payload.videoModelId)
        if not video_model or video_model.status != 1:
            raise BusinessAppError("video model unavailable", code=50002)
        pricing = self.repository.get_active_pricing(video_model.id)
        if not pricing:
            raise BusinessAppError("pricing rule unavailable", code=50003)

        binding = self._select_binding(video_model.id)
        provider_model = self.repository.get_provider_model(binding.provider_model_id)
        provider_account = (
            self.repository.get_provider_account(provider_model.provider_account_id)
            if provider_model
            else None
        )
        if not provider_model or not provider_account:
            # A binding that points at a missing provider model/account would
            # otherwise fail below with an AttributeError.
            raise BusinessAppError("no available provider", code=50003)

        # Charge is per second of output, but never below the rule's minimum.
        estimated_points = max(
            pricing.minimum_points,
            pricing.points_per_second * payload.durationSeconds,
        )
        normalized = self._build_normalized_payload(
            user_id, payload, provider_model.id, provider_account.id
        )
        task = VideoGenerationTask(
            task_no=new_order_no("vt"),
            user_id=user_id,
            video_model_id=video_model.id,
            provider_account_id=provider_account.id,
            provider_model_id=provider_model.id,
            provider_binding_id=binding.id,
            pricing_rule_id=pricing.id,
            task_status="queued",
            generation_mode=self._infer_generation_mode(payload),
            prompt_text=payload.prompt,
            request_payload=normalized,
            duration_seconds=payload.durationSeconds,
            ratio=payload.ratio,
            resolution=payload.resolution,
            generate_audio=payload.generateAudio,
            estimated_points=estimated_points,
            frozen_points=estimated_points,
        )
        self.db.add(task)
        # Flush so task.id exists before it is used as the wallet related_id.
        self.db.flush()
        self.wallet_service.freeze_points(
            user_id,
            estimated_points,
            related_type="video_task",
            related_id=task.id,
            remark=f"freeze for {task.task_no}",
        )
        self._add_event(task.id, "created", "task created")
        self._submit_task(task, provider_account, provider_model)
        self.db.commit()
        return {
            "taskNo": task.task_no,
            "taskStatus": task.task_status,
            "estimatedPoints": task.estimated_points,
            "frozenPoints": task.frozen_points,
        }

    def list_tasks(self, user_id: int) -> list[dict]:
        """Refresh and summarise up to 100 of the user's tasks.

        Ordering comes from ``repository.list_tasks``. Every fetched task is
        refreshed, including ones the user has hidden, but only tasks with
        ``user_visible`` set are returned.
        """
        tasks = self.repository.list_tasks(user_id).limit(100).all()
        for task in tasks:
            self._refresh_task_progress(task)
        self.db.commit()
        return [self.serialize_task_summary(task) for task in tasks if task.user_visible]

    def get_task_detail(self, user_id: int, task_no: str) -> dict:
        """Refresh one of the user's tasks and return its detailed view.

        Raises:
            NotFoundAppError: if ``repository.get_task`` finds nothing (50006).
        """
        task = self.repository.get_task(user_id, task_no)
        if not task:
            raise NotFoundAppError("task not found", code=50006)
        self._refresh_task_progress(task)
        self.db.commit()
        return self.serialize_task_detail(task)

    def retry_task(self, user_id: int, task_no: str) -> dict:
        """Create a brand-new task from an existing task's stored request payload.

        The source task is left untouched. The new task gets its own point
        freeze via :meth:`create_task`.

        Raises:
            NotFoundAppError: if the source task does not exist (50006).
            KeyError: if a required key is missing from the stored payload.
        """
        from types import SimpleNamespace

        task = self.repository.get_task(user_id, task_no)
        if not task:
            raise NotFoundAppError("task not found", code=50006)
        stored = task.request_payload
        # NOTE(review): the source task's status is not checked, so retrying a
        # task that is still queued/running creates a second, separately charged
        # task. Confirm this is intended.
        retry_payload = SimpleNamespace(
            videoModelId=stored["videoModelId"],
            prompt=stored["prompt"],
            durationSeconds=stored["durationSeconds"],
            resolution=stored["resolution"],
            ratio=stored["ratio"],
            generateAudio=stored["generateAudio"],
            referenceImageAssetIds=stored.get("referenceImageAssetIds", []),
            referenceVideoAssetIds=stored.get("referenceVideoAssetIds", []),
            referenceAudioAssetIds=stored.get("referenceAudioAssetIds", []),
        )
        return self.create_task(user_id, retry_payload)

    def cancel_task(self, user_id: int, task_no: str) -> dict:
        """Cancel a non-final task locally and release its frozen points.

        Tasks already in a final status are returned unchanged. This method
        does not call the provider adapter.

        Raises:
            NotFoundAppError: if the task does not exist (50006).
        """
        task = self.repository.get_task(user_id, task_no)
        if not task:
            raise NotFoundAppError("task not found", code=50006)
        if task.task_status in self.FINAL_STATUSES:
            return {"taskNo": task.task_no, "taskStatus": task.task_status}
        task.task_status = "cancelled"
        task.finished_at = self._utcnow()
        self.wallet_service.release_frozen_points(
            user_id,
            task.frozen_points,
            related_type="video_task",
            related_id=task.id,
            remark=f"cancel {task.task_no}",
        )
        self._add_event(task.id, "cancelled", "task cancelled")
        self.db.commit()
        return {"taskNo": task.task_no, "taskStatus": task.task_status}

    def delete_task(self, user_id: int, task_no: str) -> dict:
        """Soft-delete a task: hide it from the user and stamp the deletion time.

        Raises:
            NotFoundAppError: if the task does not exist (50006).
        """
        task = self.repository.get_task(user_id, task_no)
        if not task:
            raise NotFoundAppError("task not found", code=50006)
        task.user_visible = False
        task.user_deleted_at = self._utcnow()
        self.db.commit()
        return {"taskNo": task.task_no, "deleted": True}

    # ----------------------------------------------------------------- admin API

    def admin_list_tasks(self) -> list[dict]:
        """Refresh and summarise the 200 most recent tasks (highest id first)."""
        tasks = (
            self.db.query(VideoGenerationTask)
            .order_by(VideoGenerationTask.id.desc())
            .limit(200)
            .all()
        )
        for task in tasks:
            self._refresh_task_progress(task)
        self.db.commit()
        return [self.serialize_task_summary(task) for task in tasks]

    def admin_get_task(self, task_id: int) -> dict:
        """Refresh a task by primary key and return its detailed view.

        Raises:
            NotFoundAppError: if the task does not exist (50006).
        """
        task = self.repository.get_task_by_id(task_id)
        if not task:
            raise NotFoundAppError("task not found", code=50006)
        self._refresh_task_progress(task)
        self.db.commit()
        return self.serialize_task_detail(task)

    def admin_retry_task(self, task_id: int) -> dict:
        """Retry a task on behalf of its owner (see :meth:`retry_task`).

        Raises:
            NotFoundAppError: if the task does not exist (50006).
        """
        task = self.repository.get_task_by_id(task_id)
        if not task:
            raise NotFoundAppError("task not found", code=50006)
        return self.retry_task(task.user_id, task.task_no)

    def admin_refund_task(self, task_id: int) -> dict:
        """Credit the task's charged (or, if unset, frozen) points back to the user.

        Raises:
            NotFoundAppError: if the task does not exist (50006).
        """
        task = self.repository.get_task_by_id(task_id)
        if not task:
            raise NotFoundAppError("task not found", code=50006)
        # NOTE(review): there is no guard against repeated refunds, or against
        # refunding failed/cancelled tasks whose frozen points were already
        # released. Confirm that add_points / operators handle this.
        tx = self.wallet_service.add_points(
            task.user_id,
            task.final_points or task.frozen_points,
            biz_type="refund",
            related_type="video_task",
            related_id=task.id,
            remark=f"manual refund for {task.task_no}",
            operator_type="admin",
        )
        self._add_event(task.id, "refund", "manual refund")
        self.db.commit()
        return {"taskNo": task.task_no, "refundTransactionNo": tx.transaction_no}

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a naive ``datetime``.

        Same value as the deprecated ``datetime.utcnow()``. It is kept naive to
        match the timestamps this service has always written.
        """
        from datetime import timezone

        return datetime.now(timezone.utc).replace(tzinfo=None)

    def _select_binding(self, video_model_id: int):
        """Return the first provider binding for the video model.

        The order is whatever ``repository.get_bindings`` yields (presumably by
        priority; confirm in the repository).

        Raises:
            BusinessAppError: if the model has no bindings (50003).
        """
        bindings = self.repository.get_bindings(video_model_id)
        if not bindings:
            raise BusinessAppError("no available provider", code=50003)
        return bindings[0]

    def _build_normalized_payload(
        self, user_id: int, payload, provider_model_id: int, provider_account_id: int
    ) -> dict:
        """Build the request dict stored on the task and sent to the provider.

        Reference asset ids are resolved through ``repository.list_assets`` with
        ``user_id`` (presumably limiting them to the user's own assets; verify).
        The resolved ``{"assetId", "url"}`` entries are included next to the raw
        ids.
        """
        image_assets = self.repository.list_assets(user_id, payload.referenceImageAssetIds)
        video_assets = self.repository.list_assets(user_id, payload.referenceVideoAssetIds)
        audio_assets = self.repository.list_assets(user_id, payload.referenceAudioAssetIds)
        return {
            "videoModelId": payload.videoModelId,
            "providerModelId": provider_model_id,
            "providerAccountId": provider_account_id,
            "prompt": payload.prompt,
            "durationSeconds": payload.durationSeconds,
            "resolution": payload.resolution,
            "ratio": payload.ratio,
            "generateAudio": payload.generateAudio,
            "referenceImageAssetIds": payload.referenceImageAssetIds,
            "referenceVideoAssetIds": payload.referenceVideoAssetIds,
            "referenceAudioAssetIds": payload.referenceAudioAssetIds,
            "referenceImages": [{"assetId": item.id, "url": item.public_url} for item in image_assets],
            "referenceVideos": [{"assetId": item.id, "url": item.public_url} for item in video_assets],
            "referenceAudios": [{"assetId": item.id, "url": item.public_url} for item in audio_assets],
        }

    @staticmethod
    def _infer_generation_mode(payload) -> str:
        """Return ``"multimodal"`` if any reference assets are given, else ``"text_to_video"``."""
        if payload.referenceImageAssetIds or payload.referenceVideoAssetIds or payload.referenceAudioAssetIds:
            return "multimodal"
        return "text_to_video"

    def _submit_task(self, task: VideoGenerationTask, provider_account, provider_model) -> None:
        """Submit ``task.request_payload`` to the provider and apply the reported status."""
        adapter = build_adapter(provider_account, provider_model)
        result = adapter.submit_task(task.request_payload)
        task.external_task_id = result["externalTaskId"]
        task.submitted_at = self._utcnow()
        task.response_payload = result
        self._add_event(task.id, "submitted", "task submitted to provider")
        # Route the submit-time status through the same settlement logic as
        # polling. Otherwise a synchronous terminal status would make the task
        # final while its points are still frozen (and never polled again).
        self._apply_provider_status(task, adapter, result)

    def _refresh_task_progress(self, task: VideoGenerationTask) -> None:
        """Poll the provider for a non-final task and apply the reported status."""
        if task.task_status in self.FINAL_STATUSES:
            return
        provider_model = self.repository.get_provider_model(task.provider_model_id)
        provider_account = self.repository.get_provider_account(task.provider_account_id)
        adapter = build_adapter(provider_account, provider_model)
        result = adapter.query_task(task)
        task.response_payload = result
        self._apply_provider_status(task, adapter, result)

    def _apply_provider_status(self, task, adapter, result: dict) -> None:
        """Move ``task`` to the provider-reported status, settling points when terminal.

        The status comes from ``result["normalizedStatus"]``:
        - ``succeeded``: store the video and consume the frozen points.
        - ``failed`` / ``cancelled`` / ``timed_out``: release the frozen points.
        - anything else: copy the status onto the task.
        """
        status = result["normalizedStatus"]
        if status == "succeeded" and task.result_asset_id is None:
            self._complete_task(task, adapter)
        elif status in self.FINAL_STATUSES and status != "succeeded":
            self._terminate_task(task, status, result)
        else:
            task.task_status = status

    def _complete_task(self, task, adapter) -> None:
        """Download and store the result video, mark success, consume frozen points."""
        content = adapter.download_result(task)
        stored = storage_service.save_bytes(
            content,
            filename=f"{task.task_no}.mp4",
            folder="generated/videos",
        )
        asset = MediaAsset(
            asset_no=new_order_no("asset"),
            user_id=task.user_id,
            media_type="video",
            source_type="generated",
            original_filename=f"{task.task_no}.mp4",
            mime_type="video/mp4",
            file_ext=".mp4",
            file_size=stored["file_size"],
            storage_provider="local",
            storage_bucket="local",
            storage_key=stored["storage_key"],
            public_url=stored["public_url"],
            sha256=stored["sha256"],
            status="active",
        )
        self.db.add(asset)
        # Flush so asset.id is assigned before linking it to the task.
        self.db.flush()
        task.result_asset_id = asset.id
        task.final_points = task.estimated_points
        task.task_status = "succeeded"
        task.finished_at = self._utcnow()
        self.wallet_service.consume_frozen_points(
            task.user_id,
            task.frozen_points,
            related_type="video_task",
            related_id=task.id,
            remark=f"consume for {task.task_no}",
        )
        self.wallet_service.try_issue_invite_reward(task.user_id, task.id, task.final_points)
        self._add_event(task.id, "succeeded", "task succeeded")

    def _terminate_task(self, task, status: str, result: dict) -> None:
        """Put the task in a non-success final ``status`` and release its frozen points."""
        raw = result.get("rawResponse")
        # rawResponse may be missing, None or not a dict; fall back to a default.
        message = raw.get("message") if isinstance(raw, dict) else None
        task.task_status = status
        task.fail_reason = message or f"provider {status}"
        task.finished_at = self._utcnow()
        self.wallet_service.release_frozen_points(
            task.user_id,
            task.frozen_points,
            related_type="video_task",
            related_id=task.id,
            remark=f"release for {task.task_no}",
        )
        self._add_event(task.id, status, f"task {status}")

    def _add_event(self, task_id: int, event_type: str, message: str) -> None:
        """Stage a ``VideoTaskEvent`` row in the session (committed by the caller)."""
        self.db.add(
            VideoTaskEvent(
                video_task_id=task_id,
                event_type=event_type,
                event_message=message,
                payload=None,
                created_at=self._utcnow(),
            )
        )

    # ------------------------------------------------------------ serialisation

    def serialize_task_summary(self, task: VideoGenerationTask) -> dict:
        """Return the list-view dict for a task (timestamps as ISO-8601 strings)."""
        return {
            "id": task.id,
            "taskNo": task.task_no,
            "taskStatus": task.task_status,
            "durationSeconds": task.duration_seconds,
            "estimatedPoints": task.estimated_points,
            "finalPoints": task.final_points,
            "resultVideoUrl": self._result_url(task),
            "failReason": task.fail_reason,
            "createdAt": task.created_at.isoformat(),
            "finishedAt": task.finished_at.isoformat() if task.finished_at else None,
        }

    def serialize_task_detail(self, task: VideoGenerationTask) -> dict:
        """Return the summary dict extended with model, provider, payloads and events.

        Missing provider/model rows are rendered as empty strings.
        """
        provider_account = self.repository.get_provider_account(task.provider_account_id)
        provider_model = self.repository.get_provider_model(task.provider_model_id)
        video_model = self.repository.get_video_model(task.video_model_id)
        events = self.repository.task_events(task.id).all()
        return {
            **self.serialize_task_summary(task),
            "videoModel": {
                "id": task.video_model_id,
                "name": video_model.model_name if video_model else "",
            },
            "provider": {
                "providerCode": provider_account.provider_code if provider_account else "",
                "providerName": provider_account.provider_name if provider_account else "",
                "modelCode": provider_model.model_code if provider_model else "",
                "modelName": provider_model.model_name if provider_model else "",
            },
            "ratio": task.ratio,
            "resolution": task.resolution,
            "prompt": task.prompt_text or "",
            "requestPayload": task.request_payload,
            "responsePayload": task.response_payload,
            "events": [
                {
                    "eventType": item.event_type,
                    "eventMessage": item.event_message,
                    "createdAt": item.created_at.isoformat(),
                }
                for item in events
            ],
        }

    def _result_url(self, task: VideoGenerationTask) -> str:
        """Return the public URL of the task's result asset, or ``""`` if none."""
        if not task.result_asset_id:
            return ""
        asset = self.db.scalar(select(MediaAsset).where(MediaAsset.id == task.result_asset_id))
        return asset.public_url if asset else ""
|