"""Data-access helpers for video generation tasks and the rows they depend on."""

from datetime import datetime, timezone

from sqlalchemy import or_, select
from sqlalchemy.orm import Query, Session

from app.models.entities import (
    MediaAsset,
    PricingRule,
    ProviderAccount,
    ProviderModel,
    VideoGenerationTask,
    VideoModel,
    VideoModelSupplierBinding,
    VideoTaskEvent,
)


class VideoTasksRepository:
    """Query helpers for video models, pricing, supplier routing, assets and tasks.

    Every method reads through the injected ``Session``. None of them add,
    modify, or commit objects, so transaction control stays with the caller.
    """

    def __init__(self, db: Session) -> None:
        """Store the SQLAlchemy session that all queries are issued on."""
        self.db = db

    def get_video_model(self, model_id: int) -> VideoModel | None:
        """Return the ``VideoModel`` with primary key ``model_id``, or ``None``."""
        return self.db.scalar(select(VideoModel).where(VideoModel.id == model_id))

    def get_active_pricing(self, video_model_id: int) -> PricingRule | None:
        """Return the newest pricing rule currently in effect for a video model.

        A rule qualifies when all of the following hold:

        - ``status == 1``;
        - ``effective_at`` is not after now;
        - ``expired_at`` is unset or still after now.

        If several rules qualify, the one with the highest ``version_no``
        wins, and ``id`` breaks ties.

        Returns:
            The winning ``PricingRule``, or ``None`` if no rule qualifies.
        """
        # Naive UTC "now", the same value the deprecated datetime.utcnow() gave.
        # NOTE(review): assumes effective_at/expired_at are stored as naive
        # UTC datetimes; confirm against the column definitions.
        now = datetime.now(timezone.utc).replace(tzinfo=None)
        return self.db.scalar(
            select(PricingRule)
            .where(
                PricingRule.video_model_id == video_model_id,
                PricingRule.status == 1,
                PricingRule.effective_at <= now,
                or_(PricingRule.expired_at.is_(None), PricingRule.expired_at > now),
            )
            .order_by(PricingRule.version_no.desc(), PricingRule.id.desc())
            # scalar() only uses the first row, so have the database stop there.
            .limit(1)
        )

    def get_bindings(self, video_model_id: int) -> list[VideoModelSupplierBinding]:
        """Return the supplier bindings with ``status == 1`` for a video model.

        Rows are ordered by ``is_primary`` descending, so primary bindings
        come first (assuming a 0/1 flag). Then they are ordered by
        ``routing_priority`` ascending, with ``id`` ascending as the final
        tie-breaker.
        """
        return (
            self.db.query(VideoModelSupplierBinding)
            .filter(
                VideoModelSupplierBinding.video_model_id == video_model_id,
                VideoModelSupplierBinding.status == 1,
            )
            .order_by(
                VideoModelSupplierBinding.is_primary.desc(),
                VideoModelSupplierBinding.routing_priority.asc(),
                VideoModelSupplierBinding.id.asc(),
            )
            .all()
        )

    def get_provider_model(self, provider_model_id: int) -> ProviderModel | None:
        """Return the ``ProviderModel`` with the given primary key, or ``None``."""
        return self.db.scalar(
            select(ProviderModel).where(ProviderModel.id == provider_model_id)
        )

    def get_provider_account(self, provider_account_id: int) -> ProviderAccount | None:
        """Return the ``ProviderAccount`` with the given primary key, or ``None``."""
        return self.db.scalar(
            select(ProviderAccount).where(ProviderAccount.id == provider_account_id)
        )

    def list_assets(self, user_id: int, asset_ids: list[int]) -> list[MediaAsset]:
        """Return the ``"active"`` media assets owned by ``user_id`` among ``asset_ids``.

        IDs that do not exist, belong to another user, or are not ``"active"``
        are silently left out. Callers that need every requested asset must
        compare the result against ``asset_ids``. No ORDER BY is applied, so
        the result order is not guaranteed.

        An empty ``asset_ids`` returns ``[]`` without touching the database.
        """
        if not asset_ids:
            return []
        return (
            self.db.query(MediaAsset)
            .filter(
                MediaAsset.user_id == user_id,
                MediaAsset.id.in_(asset_ids),
                MediaAsset.status == "active",
            )
            .all()
        )

    def list_tasks(self, user_id: int) -> Query:
        """Build the query for a user's visible tasks, highest ``id`` first.

        The query is returned unexecuted, presumably so callers can paginate
        or chain further clauses. Only tasks with ``user_visible == 1`` are
        included.
        """
        return (
            self.db.query(VideoGenerationTask)
            .filter(
                VideoGenerationTask.user_id == user_id,
                VideoGenerationTask.user_visible == 1,
            )
            .order_by(VideoGenerationTask.id.desc())
        )

    def get_task(self, user_id: int, task_no: str) -> VideoGenerationTask | None:
        """Return the task with ``task_no`` owned by ``user_id``, or ``None``.

        NOTE(review): unlike ``list_tasks``, this does not filter on
        ``user_visible``. Confirm whether hidden tasks should be reachable by
        task number.
        """
        return self.db.scalar(
            select(VideoGenerationTask).where(
                VideoGenerationTask.user_id == user_id,
                VideoGenerationTask.task_no == task_no,
            )
        )

    def get_task_by_id(self, task_id: int) -> VideoGenerationTask | None:
        """Return the task with primary key ``task_id``, or ``None``.

        Not scoped to any user, so it presumably serves internal or worker
        code paths. Confirm that it is not exposed to end users.
        """
        return self.db.scalar(
            select(VideoGenerationTask).where(VideoGenerationTask.id == task_id)
        )

    def task_events(self, task_id: int) -> Query:
        """Build the query for a task's events in ascending ``id`` order.

        The query is returned unexecuted so callers can chain or execute it.
        """
        return (
            self.db.query(VideoTaskEvent)
            .filter(VideoTaskEvent.video_task_id == task_id)
            .order_by(VideoTaskEvent.id.asc())
        )