from __future__ import annotations

from datetime import datetime
from pathlib import Path
from types import SimpleNamespace

from sqlalchemy import select
from sqlalchemy.orm import Session

from app.common.errors.app_error import BusinessAppError, NotFoundAppError
from app.common.utils.id_gen import new_order_no
from app.core.providers import build_adapter
from app.core.storage import storage_service
from app.models.entities import MediaAsset, VideoGenerationTask, VideoTaskEvent
from app.modules.video_tasks.repository import VideoTasksRepository
from app.modules.wallets.service import WalletService


class VideoTasksService:
    """Create, track, cancel, retry and refund video generation tasks.

    The service owns the unit of work: public methods that change state call
    ``self.db.commit()`` themselves before returning.
    """

    # Statuses after which a task is never polled again.
    FINAL_STATUSES = {"succeeded", "failed", "cancelled", "timed_out"}

    # Final statuses that produce no result video; the frozen points must be
    # handed back when a task ends in one of these.
    _UNSUCCESSFUL_FINAL_STATUSES = FINAL_STATUSES - {"succeeded"}

    def __init__(self, db: Session) -> None:
        """Bind the service, its repository and the wallet service to *db*."""
        self.db = db
        self.repository = VideoTasksRepository(db)
        self.wallet_service = WalletService(db)

    # ------------------------------------------------------------------
    # User-facing operations
    # ------------------------------------------------------------------

    def create_task(self, user_id: int, payload) -> dict:
        """Create a task, freeze the estimated points and submit it.

        Args:
            user_id: Owner of the new task.
            payload: Object exposing ``videoModelId``, ``prompt``,
                ``durationSeconds``, ``resolution``, ``ratio``,
                ``generateAudio`` and the three ``reference*AssetIds``
                attributes.

        Returns:
            A dict with ``taskNo``, ``taskStatus``, ``estimatedPoints`` and
            ``frozenPoints``.

        Raises:
            BusinessAppError: If the video model is missing or disabled
                (50002), or if no pricing rule or usable provider binding
                exists (50003).
        """
        video_model = self.repository.get_video_model(payload.videoModelId)
        if not video_model or video_model.status != 1:
            raise BusinessAppError("video model unavailable", code=50002)

        pricing = self.repository.get_active_pricing(video_model.id)
        if not pricing:
            raise BusinessAppError("pricing rule unavailable", code=50003)

        binding = self._select_binding(video_model.id)
        provider_model = self.repository.get_provider_model(
            binding.provider_model_id
        )
        # A binding that points at a missing provider row is unusable; report
        # it the same way as "no binding" instead of failing on None.
        if not provider_model:
            raise BusinessAppError("no available provider", code=50003)
        provider_account = self.repository.get_provider_account(
            provider_model.provider_account_id
        )
        if not provider_account:
            raise BusinessAppError("no available provider", code=50003)

        estimated_points = max(
            pricing.minimum_points,
            pricing.points_per_second * payload.durationSeconds,
        )
        normalized = self._build_normalized_payload(
            user_id, payload, provider_model.id, provider_account.id
        )

        task = VideoGenerationTask(
            task_no=new_order_no("vt"),
            user_id=user_id,
            video_model_id=video_model.id,
            provider_account_id=provider_account.id,
            provider_model_id=provider_model.id,
            provider_binding_id=binding.id,
            pricing_rule_id=pricing.id,
            task_status="queued",
            generation_mode=self._infer_generation_mode(payload),
            prompt_text=payload.prompt,
            request_payload=normalized,
            duration_seconds=payload.durationSeconds,
            ratio=payload.ratio,
            resolution=payload.resolution,
            generate_audio=payload.generateAudio,
            estimated_points=estimated_points,
            frozen_points=estimated_points,
        )
        self.db.add(task)
        # Flush so task.id exists before it is referenced by the wallet
        # entry and the event rows.
        self.db.flush()

        self.wallet_service.freeze_points(
            user_id,
            estimated_points,
            related_type="video_task",
            related_id=task.id,
            remark=f"freeze for {task.task_no}",
        )
        self._add_event(task.id, "created", "task created")
        self._submit_task(task, provider_account, provider_model)
        self.db.commit()

        return {
            "taskNo": task.task_no,
            "taskStatus": task.task_status,
            "estimatedPoints": task.estimated_points,
            "frozenPoints": task.frozen_points,
        }

    def list_tasks(self, user_id: int) -> list[dict]:
        """Refresh and return summaries of the user's visible tasks.

        At most 100 tasks are loaded; hidden tasks are refreshed as well but
        left out of the returned list.
        """
        tasks = self.repository.list_tasks(user_id).limit(100).all()
        for task in tasks:
            self._refresh_task_progress(task)
        self.db.commit()
        return [
            self.serialize_task_summary(task)
            for task in tasks
            if task.user_visible
        ]

    def get_task_detail(self, user_id: int, task_no: str) -> dict:
        """Refresh one of the user's tasks and return its detail view.

        Raises:
            NotFoundAppError: If the user has no task with *task_no*.
        """
        task = self._require_task(user_id, task_no)
        self._refresh_task_progress(task)
        self.db.commit()
        return self.serialize_task_detail(task)

    def retry_task(self, user_id: int, task_no: str) -> dict:
        """Create a new task from the stored request of an existing one.

        Returns:
            The ``create_task`` result for the newly created task.

        Raises:
            NotFoundAppError: If the user has no task with *task_no*.
        """
        task = self._require_task(user_id, task_no)
        stored = task.request_payload
        # create_task only reads attributes, so a plain namespace stands in
        # for the original request object.
        retry_payload = SimpleNamespace(
            videoModelId=stored["videoModelId"],
            prompt=stored["prompt"],
            durationSeconds=stored["durationSeconds"],
            resolution=stored["resolution"],
            ratio=stored["ratio"],
            generateAudio=stored["generateAudio"],
            referenceImageAssetIds=stored.get("referenceImageAssetIds", []),
            referenceVideoAssetIds=stored.get("referenceVideoAssetIds", []),
            referenceAudioAssetIds=stored.get("referenceAudioAssetIds", []),
        )
        return self.create_task(user_id, retry_payload)

    def cancel_task(self, user_id: int, task_no: str) -> dict:
        """Cancel a task that is not yet final and release its frozen points.

        A task already in a final status is returned unchanged.

        Raises:
            NotFoundAppError: If the user has no task with *task_no*.
        """
        task = self._require_task(user_id, task_no)
        if task.task_status not in self.FINAL_STATUSES:
            task.task_status = "cancelled"
            task.finished_at = datetime.utcnow()
            self.wallet_service.release_frozen_points(
                user_id,
                task.frozen_points,
                related_type="video_task",
                related_id=task.id,
                remark=f"cancel {task.task_no}",
            )
            self._add_event(task.id, "cancelled", "task cancelled")
            self.db.commit()
        return {"taskNo": task.task_no, "taskStatus": task.task_status}

    def delete_task(self, user_id: int, task_no: str) -> dict:
        """Hide a task from its owner; the row itself is kept.

        Raises:
            NotFoundAppError: If the user has no task with *task_no*.
        """
        task = self._require_task(user_id, task_no)
        task.user_visible = False
        task.user_deleted_at = datetime.utcnow()
        self.db.commit()
        return {"taskNo": task.task_no, "deleted": True}

    # ------------------------------------------------------------------
    # Admin operations
    # ------------------------------------------------------------------

    def admin_list_tasks(self) -> list[dict]:
        """Refresh and return summaries of the 200 newest tasks of all users."""
        tasks = (
            self.db.query(VideoGenerationTask)
            .order_by(VideoGenerationTask.id.desc())
            .limit(200)
            .all()
        )
        for task in tasks:
            self._refresh_task_progress(task)
        self.db.commit()
        return [self.serialize_task_summary(task) for task in tasks]

    def admin_get_task(self, task_id: int) -> dict:
        """Refresh a task looked up by primary key and return its detail view.

        Raises:
            NotFoundAppError: If no task has id *task_id*.
        """
        task = self._require_task_by_id(task_id)
        self._refresh_task_progress(task)
        self.db.commit()
        return self.serialize_task_detail(task)

    def admin_retry_task(self, task_id: int) -> dict:
        """Retry a task on behalf of its owner.

        Raises:
            NotFoundAppError: If no task has id *task_id*.
        """
        task = self._require_task_by_id(task_id)
        return self.retry_task(task.user_id, task.task_no)

    def admin_refund_task(self, task_id: int) -> dict:
        """Credit the task's points back to its owner as a manual refund.

        The amount is ``final_points`` when that is set and non-zero,
        otherwise ``frozen_points``.

        Raises:
            NotFoundAppError: If no task has id *task_id*.
        """
        task = self._require_task_by_id(task_id)
        # NOTE(review): nothing here prevents a second refund for the same
        # task, or a refund of points that were already released - confirm
        # that this is intended for a manual admin action.
        tx = self.wallet_service.add_points(
            task.user_id,
            task.final_points or task.frozen_points,
            biz_type="refund",
            related_type="video_task",
            related_id=task.id,
            remark=f"manual refund for {task.task_no}",
            operator_type="admin",
        )
        self._add_event(task.id, "refund", "manual refund")
        self.db.commit()
        return {
            "taskNo": task.task_no,
            "refundTransactionNo": tx.transaction_no,
        }

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _require_task(self, user_id: int, task_no: str) -> VideoGenerationTask:
        """Return the user's task or raise ``NotFoundAppError`` (50006)."""
        task = self.repository.get_task(user_id, task_no)
        if not task:
            raise NotFoundAppError("task not found", code=50006)
        return task

    def _require_task_by_id(self, task_id: int) -> VideoGenerationTask:
        """Return the task with *task_id* or raise ``NotFoundAppError`` (50006)."""
        task = self.repository.get_task_by_id(task_id)
        if not task:
            raise NotFoundAppError("task not found", code=50006)
        return task

    def _select_binding(self, video_model_id: int):
        """Return the first binding the repository lists for the model.

        Raises:
            BusinessAppError: If the repository returns no bindings (50003).
        """
        bindings = self.repository.get_bindings(video_model_id)
        if not bindings:
            raise BusinessAppError("no available provider", code=50003)
        return bindings[0]

    def _build_normalized_payload(
        self,
        user_id: int,
        payload,
        provider_model_id: int,
        provider_account_id: int,
    ) -> dict:
        """Build the request dict that is stored on the task and submitted.

        Reference asset ids are copied as given; the ``reference*`` lists
        hold ``assetId``/``url`` pairs for the assets the repository returns
        for this user.
        """
        image_assets = self.repository.list_assets(
            user_id, payload.referenceImageAssetIds
        )
        video_assets = self.repository.list_assets(
            user_id, payload.referenceVideoAssetIds
        )
        audio_assets = self.repository.list_assets(
            user_id, payload.referenceAudioAssetIds
        )

        def describe(assets) -> list[dict]:
            return [
                {"assetId": item.id, "url": item.public_url} for item in assets
            ]

        return {
            "videoModelId": payload.videoModelId,
            "providerModelId": provider_model_id,
            "providerAccountId": provider_account_id,
            "prompt": payload.prompt,
            "durationSeconds": payload.durationSeconds,
            "resolution": payload.resolution,
            "ratio": payload.ratio,
            "generateAudio": payload.generateAudio,
            "referenceImageAssetIds": payload.referenceImageAssetIds,
            "referenceVideoAssetIds": payload.referenceVideoAssetIds,
            "referenceAudioAssetIds": payload.referenceAudioAssetIds,
            "referenceImages": describe(image_assets),
            "referenceVideos": describe(video_assets),
            "referenceAudios": describe(audio_assets),
        }

    @staticmethod
    def _infer_generation_mode(payload) -> str:
        """Return ``"multimodal"`` if any reference asset id is given.

        Otherwise return ``"text_to_video"``.
        """
        has_references = (
            payload.referenceImageAssetIds
            or payload.referenceVideoAssetIds
            or payload.referenceAudioAssetIds
        )
        return "multimodal" if has_references else "text_to_video"

    def _submit_task(
        self, task: VideoGenerationTask, provider_account, provider_model
    ) -> None:
        """Submit the task to the provider and record the provider's answer.

        Does not commit; the caller owns the transaction.
        """
        adapter = build_adapter(provider_account, provider_model)
        result = adapter.submit_task(task.request_payload)
        task.external_task_id = result["externalTaskId"]
        task.task_status = result["normalizedStatus"]
        task.submitted_at = datetime.utcnow()
        task.response_payload = result
        self._add_event(task.id, "submitted", "task submitted to provider")

    def _refresh_task_progress(self, task: VideoGenerationTask) -> None:
        """Poll the provider for a non-final task and apply the new status.

        On success the result is downloaded and stored and the frozen points
        are consumed. On any other final status the frozen points are
        released. Other statuses are just copied onto the task. Does not
        commit; the caller owns the transaction.
        """
        if task.task_status in self.FINAL_STATUSES:
            return

        provider_model = self.repository.get_provider_model(
            task.provider_model_id
        )
        provider_account = self.repository.get_provider_account(
            task.provider_account_id
        )
        adapter = build_adapter(provider_account, provider_model)
        result = adapter.query_task(task)
        task.response_payload = result
        status = result["normalizedStatus"]

        if status == "succeeded" and task.result_asset_id is None:
            self._finish_with_result(task, adapter)
        elif status in self._UNSUCCESSFUL_FINAL_STATUSES:
            # Covers provider-side "cancelled" and "timed_out" as well as
            # "failed": once a final status is stored this method returns
            # early, so this is the only chance to release the points.
            self._finish_without_result(task, status, result)
        else:
            task.task_status = status

    def _finish_with_result(self, task: VideoGenerationTask, adapter) -> None:
        """Store the generated video and settle the task as succeeded."""
        filename = f"{task.task_no}.mp4"
        content = adapter.download_result(task)
        stored = storage_service.save_bytes(
            content,
            filename=filename,
            folder="generated/videos",
        )
        asset = MediaAsset(
            asset_no=new_order_no("asset"),
            user_id=task.user_id,
            media_type="video",
            source_type="generated",
            original_filename=filename,
            mime_type="video/mp4",
            file_ext=".mp4",
            file_size=stored["file_size"],
            storage_provider="local",
            storage_bucket="local",
            storage_key=stored["storage_key"],
            public_url=stored["public_url"],
            sha256=stored["sha256"],
            status="active",
        )
        self.db.add(asset)
        # Flush so asset.id exists before it is linked to the task.
        self.db.flush()

        task.result_asset_id = asset.id
        task.final_points = task.estimated_points
        task.task_status = "succeeded"
        task.finished_at = datetime.utcnow()
        self.wallet_service.consume_frozen_points(
            task.user_id,
            task.frozen_points,
            related_type="video_task",
            related_id=task.id,
            remark=f"consume for {task.task_no}",
        )
        self.wallet_service.try_issue_invite_reward(
            task.user_id, task.id, task.final_points
        )
        self._add_event(task.id, "succeeded", "task succeeded")

    def _finish_without_result(
        self, task: VideoGenerationTask, status: str, result: dict
    ) -> None:
        """Settle a task that ended as *status* and release its frozen points."""
        if status == "failed":
            default_reason = "provider failed"
        else:
            default_reason = f"provider reported {status}"
        # rawResponse may be absent, null or not a mapping at all.
        raw_response = result.get("rawResponse")
        if isinstance(raw_response, dict):
            reason = raw_response.get("message", default_reason)
        else:
            reason = default_reason

        task.task_status = status
        task.fail_reason = reason
        task.finished_at = datetime.utcnow()
        self.wallet_service.release_frozen_points(
            task.user_id,
            task.frozen_points,
            related_type="video_task",
            related_id=task.id,
            remark=f"release for {task.task_no}",
        )
        self._add_event(task.id, status, f"task {status.replace('_', ' ')}")

    def _add_event(self, task_id: int, event_type: str, message: str) -> None:
        """Add a ``VideoTaskEvent`` row for the task to the session."""
        event = VideoTaskEvent(
            video_task_id=task_id,
            event_type=event_type,
            event_message=message,
            payload=None,
            created_at=datetime.utcnow(),
        )
        self.db.add(event)

    # ------------------------------------------------------------------
    # Serialization
    # ------------------------------------------------------------------

    def serialize_task_summary(self, task: VideoGenerationTask) -> dict:
        """Return the list-view representation of a task."""
        finished_at = task.finished_at.isoformat() if task.finished_at else None
        return {
            "id": task.id,
            "taskNo": task.task_no,
            "taskStatus": task.task_status,
            "durationSeconds": task.duration_seconds,
            "estimatedPoints": task.estimated_points,
            "finalPoints": task.final_points,
            "resultVideoUrl": self._result_url(task),
            "failReason": task.fail_reason,
            "createdAt": task.created_at.isoformat(),
            "finishedAt": finished_at,
        }

    def serialize_task_detail(self, task: VideoGenerationTask) -> dict:
        """Return the summary plus model, provider, payloads and events.

        Names of a missing video model, provider account or provider model
        are returned as empty strings.
        """
        provider_account = self.repository.get_provider_account(
            task.provider_account_id
        )
        provider_model = self.repository.get_provider_model(
            task.provider_model_id
        )
        video_model = self.repository.get_video_model(task.video_model_id)
        events = self.repository.task_events(task.id).all()

        return {
            **self.serialize_task_summary(task),
            "videoModel": {
                "id": task.video_model_id,
                "name": video_model.model_name if video_model else "",
            },
            "provider": {
                "providerCode": (
                    provider_account.provider_code if provider_account else ""
                ),
                "providerName": (
                    provider_account.provider_name if provider_account else ""
                ),
                "modelCode": provider_model.model_code if provider_model else "",
                "modelName": provider_model.model_name if provider_model else "",
            },
            "ratio": task.ratio,
            "resolution": task.resolution,
            "prompt": task.prompt_text or "",
            "requestPayload": task.request_payload,
            "responsePayload": task.response_payload,
            "events": [
                {
                    "eventType": item.event_type,
                    "eventMessage": item.event_message,
                    "createdAt": item.created_at.isoformat(),
                }
                for item in events
            ],
        }

    def _result_url(self, task: VideoGenerationTask) -> str:
        """Return the public URL of the result asset, or ``""`` if none."""
        if not task.result_asset_id:
            return ""
        asset = self.db.scalar(
            select(MediaAsset).where(MediaAsset.id == task.result_asset_id)
        )
        return asset.public_url if asset else ""