feat: initialize aivideo project
This commit is contained in:
343
backend/app/modules/video_tasks/service.py
Normal file
343
backend/app/modules/video_tasks/service.py
Normal file
@@ -0,0 +1,343 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.common.errors.app_error import BusinessAppError, NotFoundAppError
|
||||
from app.common.utils.id_gen import new_order_no
|
||||
from app.core.providers import build_adapter
|
||||
from app.core.storage import storage_service
|
||||
from app.models.entities import MediaAsset, VideoGenerationTask, VideoTaskEvent
|
||||
from app.modules.video_tasks.repository import VideoTasksRepository
|
||||
from app.modules.wallets.service import WalletService
|
||||
|
||||
|
||||
class VideoTasksService:
|
||||
FINAL_STATUSES = {"succeeded", "failed", "cancelled", "timed_out"}
|
||||
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
self.repository = VideoTasksRepository(db)
|
||||
self.wallet_service = WalletService(db)
|
||||
|
||||
def create_task(self, user_id: int, payload) -> dict:
|
||||
video_model = self.repository.get_video_model(payload.videoModelId)
|
||||
if not video_model or video_model.status != 1:
|
||||
raise BusinessAppError("video model unavailable", code=50002)
|
||||
pricing = self.repository.get_active_pricing(video_model.id)
|
||||
if not pricing:
|
||||
raise BusinessAppError("pricing rule unavailable", code=50003)
|
||||
binding = self._select_binding(video_model.id)
|
||||
provider_model = self.repository.get_provider_model(binding.provider_model_id)
|
||||
provider_account = self.repository.get_provider_account(provider_model.provider_account_id)
|
||||
estimated_points = max(
|
||||
pricing.minimum_points,
|
||||
pricing.points_per_second * payload.durationSeconds,
|
||||
)
|
||||
normalized = self._build_normalized_payload(user_id, payload, provider_model.id, provider_account.id)
|
||||
task = VideoGenerationTask(
|
||||
task_no=new_order_no("vt"),
|
||||
user_id=user_id,
|
||||
video_model_id=video_model.id,
|
||||
provider_account_id=provider_account.id,
|
||||
provider_model_id=provider_model.id,
|
||||
provider_binding_id=binding.id,
|
||||
pricing_rule_id=pricing.id,
|
||||
task_status="queued",
|
||||
generation_mode=self._infer_generation_mode(payload),
|
||||
prompt_text=payload.prompt,
|
||||
request_payload=normalized,
|
||||
duration_seconds=payload.durationSeconds,
|
||||
ratio=payload.ratio,
|
||||
resolution=payload.resolution,
|
||||
generate_audio=payload.generateAudio,
|
||||
estimated_points=estimated_points,
|
||||
frozen_points=estimated_points,
|
||||
)
|
||||
self.db.add(task)
|
||||
self.db.flush()
|
||||
self.wallet_service.freeze_points(
|
||||
user_id,
|
||||
estimated_points,
|
||||
related_type="video_task",
|
||||
related_id=task.id,
|
||||
remark=f"freeze for {task.task_no}",
|
||||
)
|
||||
self._add_event(task.id, "created", "task created")
|
||||
self._submit_task(task, provider_account, provider_model)
|
||||
self.db.commit()
|
||||
return {
|
||||
"taskNo": task.task_no,
|
||||
"taskStatus": task.task_status,
|
||||
"estimatedPoints": task.estimated_points,
|
||||
"frozenPoints": task.frozen_points,
|
||||
}
|
||||
|
||||
def list_tasks(self, user_id: int) -> list[dict]:
|
||||
tasks = self.repository.list_tasks(user_id).limit(100).all()
|
||||
for task in tasks:
|
||||
self._refresh_task_progress(task)
|
||||
self.db.commit()
|
||||
return [self.serialize_task_summary(task) for task in tasks if task.user_visible]
|
||||
|
||||
def get_task_detail(self, user_id: int, task_no: str) -> dict:
|
||||
task = self.repository.get_task(user_id, task_no)
|
||||
if not task:
|
||||
raise NotFoundAppError("task not found", code=50006)
|
||||
self._refresh_task_progress(task)
|
||||
self.db.commit()
|
||||
return self.serialize_task_detail(task)
|
||||
|
||||
def retry_task(self, user_id: int, task_no: str) -> dict:
|
||||
task = self.repository.get_task(user_id, task_no)
|
||||
if not task:
|
||||
raise NotFoundAppError("task not found", code=50006)
|
||||
payload = task.request_payload
|
||||
class RetryPayload:
|
||||
videoModelId = payload["videoModelId"]
|
||||
prompt = payload["prompt"]
|
||||
durationSeconds = payload["durationSeconds"]
|
||||
resolution = payload["resolution"]
|
||||
ratio = payload["ratio"]
|
||||
generateAudio = payload["generateAudio"]
|
||||
referenceImageAssetIds = payload.get("referenceImageAssetIds", [])
|
||||
referenceVideoAssetIds = payload.get("referenceVideoAssetIds", [])
|
||||
referenceAudioAssetIds = payload.get("referenceAudioAssetIds", [])
|
||||
|
||||
return self.create_task(user_id, RetryPayload)
|
||||
|
||||
def cancel_task(self, user_id: int, task_no: str) -> dict:
|
||||
task = self.repository.get_task(user_id, task_no)
|
||||
if not task:
|
||||
raise NotFoundAppError("task not found", code=50006)
|
||||
if task.task_status in self.FINAL_STATUSES:
|
||||
return {"taskNo": task.task_no, "taskStatus": task.task_status}
|
||||
task.task_status = "cancelled"
|
||||
task.finished_at = datetime.utcnow()
|
||||
self.wallet_service.release_frozen_points(
|
||||
user_id,
|
||||
task.frozen_points,
|
||||
related_type="video_task",
|
||||
related_id=task.id,
|
||||
remark=f"cancel {task.task_no}",
|
||||
)
|
||||
self._add_event(task.id, "cancelled", "task cancelled")
|
||||
self.db.commit()
|
||||
return {"taskNo": task.task_no, "taskStatus": task.task_status}
|
||||
|
||||
def delete_task(self, user_id: int, task_no: str) -> dict:
|
||||
task = self.repository.get_task(user_id, task_no)
|
||||
if not task:
|
||||
raise NotFoundAppError("task not found", code=50006)
|
||||
task.user_visible = False
|
||||
task.user_deleted_at = datetime.utcnow()
|
||||
self.db.commit()
|
||||
return {"taskNo": task.task_no, "deleted": True}
|
||||
|
||||
def admin_list_tasks(self) -> list[dict]:
|
||||
tasks = self.db.query(VideoGenerationTask).order_by(VideoGenerationTask.id.desc()).limit(200).all()
|
||||
for task in tasks:
|
||||
self._refresh_task_progress(task)
|
||||
self.db.commit()
|
||||
return [self.serialize_task_summary(task) for task in tasks]
|
||||
|
||||
def admin_get_task(self, task_id: int) -> dict:
|
||||
task = self.repository.get_task_by_id(task_id)
|
||||
if not task:
|
||||
raise NotFoundAppError("task not found", code=50006)
|
||||
self._refresh_task_progress(task)
|
||||
self.db.commit()
|
||||
return self.serialize_task_detail(task)
|
||||
|
||||
def admin_retry_task(self, task_id: int) -> dict:
|
||||
task = self.repository.get_task_by_id(task_id)
|
||||
if not task:
|
||||
raise NotFoundAppError("task not found", code=50006)
|
||||
return self.retry_task(task.user_id, task.task_no)
|
||||
|
||||
def admin_refund_task(self, task_id: int) -> dict:
|
||||
task = self.repository.get_task_by_id(task_id)
|
||||
if not task:
|
||||
raise NotFoundAppError("task not found", code=50006)
|
||||
tx = self.wallet_service.add_points(
|
||||
task.user_id,
|
||||
task.final_points or task.frozen_points,
|
||||
biz_type="refund",
|
||||
related_type="video_task",
|
||||
related_id=task.id,
|
||||
remark=f"manual refund for {task.task_no}",
|
||||
operator_type="admin",
|
||||
)
|
||||
self._add_event(task.id, "refund", "manual refund")
|
||||
self.db.commit()
|
||||
return {"taskNo": task.task_no, "refundTransactionNo": tx.transaction_no}
|
||||
|
||||
def _select_binding(self, video_model_id: int):
|
||||
bindings = self.repository.get_bindings(video_model_id)
|
||||
if not bindings:
|
||||
raise BusinessAppError("no available provider", code=50003)
|
||||
return bindings[0]
|
||||
|
||||
def _build_normalized_payload(self, user_id: int, payload, provider_model_id: int, provider_account_id: int) -> dict:
|
||||
image_assets = self.repository.list_assets(user_id, payload.referenceImageAssetIds)
|
||||
video_assets = self.repository.list_assets(user_id, payload.referenceVideoAssetIds)
|
||||
audio_assets = self.repository.list_assets(user_id, payload.referenceAudioAssetIds)
|
||||
return {
|
||||
"videoModelId": payload.videoModelId,
|
||||
"providerModelId": provider_model_id,
|
||||
"providerAccountId": provider_account_id,
|
||||
"prompt": payload.prompt,
|
||||
"durationSeconds": payload.durationSeconds,
|
||||
"resolution": payload.resolution,
|
||||
"ratio": payload.ratio,
|
||||
"generateAudio": payload.generateAudio,
|
||||
"referenceImageAssetIds": payload.referenceImageAssetIds,
|
||||
"referenceVideoAssetIds": payload.referenceVideoAssetIds,
|
||||
"referenceAudioAssetIds": payload.referenceAudioAssetIds,
|
||||
"referenceImages": [{"assetId": item.id, "url": item.public_url} for item in image_assets],
|
||||
"referenceVideos": [{"assetId": item.id, "url": item.public_url} for item in video_assets],
|
||||
"referenceAudios": [{"assetId": item.id, "url": item.public_url} for item in audio_assets],
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _infer_generation_mode(payload) -> str:
|
||||
if payload.referenceImageAssetIds or payload.referenceVideoAssetIds or payload.referenceAudioAssetIds:
|
||||
return "multimodal"
|
||||
return "text_to_video"
|
||||
|
||||
def _submit_task(self, task: VideoGenerationTask, provider_account, provider_model) -> None:
|
||||
adapter = build_adapter(provider_account, provider_model)
|
||||
result = adapter.submit_task(task.request_payload)
|
||||
task.external_task_id = result["externalTaskId"]
|
||||
task.task_status = result["normalizedStatus"]
|
||||
task.submitted_at = datetime.utcnow()
|
||||
task.response_payload = result
|
||||
self._add_event(task.id, "submitted", "task submitted to provider")
|
||||
|
||||
def _refresh_task_progress(self, task: VideoGenerationTask) -> None:
|
||||
if task.task_status in self.FINAL_STATUSES:
|
||||
return
|
||||
provider_model = self.repository.get_provider_model(task.provider_model_id)
|
||||
provider_account = self.repository.get_provider_account(task.provider_account_id)
|
||||
adapter = build_adapter(provider_account, provider_model)
|
||||
result = adapter.query_task(task)
|
||||
task.response_payload = result
|
||||
status = result["normalizedStatus"]
|
||||
if status == "succeeded" and task.result_asset_id is None:
|
||||
content = adapter.download_result(task)
|
||||
stored = storage_service.save_bytes(
|
||||
content,
|
||||
filename=f"{task.task_no}.mp4",
|
||||
folder="generated/videos",
|
||||
)
|
||||
asset = MediaAsset(
|
||||
asset_no=new_order_no("asset"),
|
||||
user_id=task.user_id,
|
||||
media_type="video",
|
||||
source_type="generated",
|
||||
original_filename=f"{task.task_no}.mp4",
|
||||
mime_type="video/mp4",
|
||||
file_ext=".mp4",
|
||||
file_size=stored["file_size"],
|
||||
storage_provider="local",
|
||||
storage_bucket="local",
|
||||
storage_key=stored["storage_key"],
|
||||
public_url=stored["public_url"],
|
||||
sha256=stored["sha256"],
|
||||
status="active",
|
||||
)
|
||||
self.db.add(asset)
|
||||
self.db.flush()
|
||||
task.result_asset_id = asset.id
|
||||
task.final_points = task.estimated_points
|
||||
task.task_status = "succeeded"
|
||||
task.finished_at = datetime.utcnow()
|
||||
self.wallet_service.consume_frozen_points(
|
||||
task.user_id,
|
||||
task.frozen_points,
|
||||
related_type="video_task",
|
||||
related_id=task.id,
|
||||
remark=f"consume for {task.task_no}",
|
||||
)
|
||||
self.wallet_service.try_issue_invite_reward(task.user_id, task.id, task.final_points)
|
||||
self._add_event(task.id, "succeeded", "task succeeded")
|
||||
elif status == "failed":
|
||||
task.task_status = "failed"
|
||||
task.fail_reason = result.get("rawResponse", {}).get("message", "provider failed")
|
||||
task.finished_at = datetime.utcnow()
|
||||
self.wallet_service.release_frozen_points(
|
||||
task.user_id,
|
||||
task.frozen_points,
|
||||
related_type="video_task",
|
||||
related_id=task.id,
|
||||
remark=f"release for {task.task_no}",
|
||||
)
|
||||
self._add_event(task.id, "failed", "task failed")
|
||||
else:
|
||||
task.task_status = status
|
||||
|
||||
def _add_event(self, task_id: int, event_type: str, message: str) -> None:
|
||||
self.db.add(
|
||||
VideoTaskEvent(
|
||||
video_task_id=task_id,
|
||||
event_type=event_type,
|
||||
event_message=message,
|
||||
payload=None,
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
)
|
||||
|
||||
def serialize_task_summary(self, task: VideoGenerationTask) -> dict:
|
||||
return {
|
||||
"id": task.id,
|
||||
"taskNo": task.task_no,
|
||||
"taskStatus": task.task_status,
|
||||
"durationSeconds": task.duration_seconds,
|
||||
"estimatedPoints": task.estimated_points,
|
||||
"finalPoints": task.final_points,
|
||||
"resultVideoUrl": self._result_url(task),
|
||||
"failReason": task.fail_reason,
|
||||
"createdAt": task.created_at.isoformat(),
|
||||
"finishedAt": task.finished_at.isoformat() if task.finished_at else None,
|
||||
}
|
||||
|
||||
def serialize_task_detail(self, task: VideoGenerationTask) -> dict:
|
||||
provider_account = self.repository.get_provider_account(task.provider_account_id)
|
||||
provider_model = self.repository.get_provider_model(task.provider_model_id)
|
||||
video_model = self.repository.get_video_model(task.video_model_id)
|
||||
events = self.repository.task_events(task.id).all()
|
||||
return {
|
||||
**self.serialize_task_summary(task),
|
||||
"videoModel": {
|
||||
"id": task.video_model_id,
|
||||
"name": video_model.model_name if video_model else "",
|
||||
},
|
||||
"provider": {
|
||||
"providerCode": provider_account.provider_code if provider_account else "",
|
||||
"providerName": provider_account.provider_name if provider_account else "",
|
||||
"modelCode": provider_model.model_code if provider_model else "",
|
||||
"modelName": provider_model.model_name if provider_model else "",
|
||||
},
|
||||
"ratio": task.ratio,
|
||||
"resolution": task.resolution,
|
||||
"prompt": task.prompt_text or "",
|
||||
"requestPayload": task.request_payload,
|
||||
"responsePayload": task.response_payload,
|
||||
"events": [
|
||||
{
|
||||
"eventType": item.event_type,
|
||||
"eventMessage": item.event_message,
|
||||
"createdAt": item.created_at.isoformat(),
|
||||
}
|
||||
for item in events
|
||||
],
|
||||
}
|
||||
|
||||
def _result_url(self, task: VideoGenerationTask) -> str:
|
||||
if not task.result_asset_id:
|
||||
return ""
|
||||
asset = self.db.scalar(select(MediaAsset).where(MediaAsset.id == task.result_asset_id))
|
||||
return asset.public_url if asset else ""
|
||||
Reference in New Issue
Block a user