from sqlalchemy import delete, func, or_, select
from sqlalchemy.orm import Session

from app.common.errors.app_error import BusinessAppError, NotFoundAppError
from app.models.entities import (
    ProviderAccount,
    ProviderModel,
    VideoGenerationTask,
    VideoModelSupplierBinding,
)
from app.modules.providers.repository import ProvidersRepository


class ProvidersService:
    """CRUD operations for provider accounts and the provider models they own.

    Every mutating method commits the session passed to the constructor.
    """

    def __init__(self, db: Session) -> None:
        self.db = db
        self.repository = ProvidersRepository(db)

    # ------------------------------------------------------------------
    # Provider accounts
    # ------------------------------------------------------------------

    def list_accounts(self) -> list[dict]:
        """Return every provider account in its serialized (camelCase) form."""
        return [self.serialize_account(item) for item in self.repository.list_accounts().all()]

    @staticmethod
    def _account_fields(payload) -> dict:
        """Map an account payload onto ProviderAccount attribute values.

        Shared by ``create_account`` and ``update_account`` so the set of
        persisted fields cannot drift between the two.
        """
        # NOTE(review): the payload secrets are stored exactly as received in the
        # ``*_encrypted`` columns; nothing in this module encrypts them. Confirm
        # that encryption happens elsewhere (e.g. in the column type).
        return {
            "provider_code": payload.provider_code,
            "provider_name": payload.provider_name,
            "api_format": payload.api_format,
            "base_url": payload.base_url,
            "api_key_encrypted": payload.api_key,
            "api_secret_encrypted": payload.api_secret,
            "webhook_secret_encrypted": payload.webhook_secret,
            "timeout_seconds": payload.timeout_seconds,
            "max_retries": payload.max_retries,
            "status": payload.status,
            "remark": payload.remark,
        }

    def create_account(self, payload) -> dict:
        """Persist a new provider account and return its serialized form."""
        item = ProviderAccount(**self._account_fields(payload))
        self.db.add(item)
        self.db.commit()
        self.db.refresh(item)
        return self.serialize_account(item)

    def update_account(self, account_id: int, payload) -> dict:
        """Overwrite every editable field of an account with ``payload``.

        Raises:
            NotFoundAppError: code 60001 if the account does not exist.
        """
        item = self.repository.get_account(account_id)
        if not item:
            raise NotFoundAppError("provider account not found", code=60001)
        # NOTE(review): the secret columns are overwritten unconditionally, even
        # with None/empty values. serialize_account never returns secrets, so a
        # client that edits an account must resend them. Confirm this is intended.
        for key, value in self._account_fields(payload).items():
            setattr(item, key, value)
        self.db.commit()
        return self.serialize_account(item)

    def delete_account(self, account_id: int) -> dict:
        """Hard-delete an account together with its models and their bindings.

        Raises:
            NotFoundAppError: code 60001 if the account does not exist.
            BusinessAppError: code 60011 if any generation task references the
                account or one of its models.
        """
        item = self.repository.get_account(account_id)
        if not item:
            raise NotFoundAppError("provider account not found", code=60001)

        model_ids = list(
            self.db.scalars(
                select(ProviderModel.id).where(ProviderModel.provider_account_id == account_id)
            )
        )

        # A task blocks deletion if it points at the account directly or at any
        # of the account's models.
        task_conditions = [VideoGenerationTask.provider_account_id == account_id]
        if model_ids:
            # Avoid emitting an empty IN () clause when the account has no models.
            task_conditions.append(VideoGenerationTask.provider_model_id.in_(model_ids))
        task_count = self.db.scalar(
            select(func.count())
            .select_from(VideoGenerationTask)
            .where(or_(*task_conditions))
        )
        if task_count:
            # Message: "This provider already has task records and cannot be
            # deleted directly; please disable it first."
            raise BusinessAppError(
                "该供应商已有任务记录,不能直接删除,请先停用。",
                code=60011,
            )

        # Remove dependants before the parent: bindings -> models -> account.
        if model_ids:
            self.db.execute(
                delete(VideoModelSupplierBinding).where(
                    VideoModelSupplierBinding.provider_model_id.in_(model_ids)
                )
            )
        self.db.execute(
            delete(ProviderModel).where(ProviderModel.provider_account_id == account_id)
        )
        self.db.delete(item)
        self.db.commit()
        return {"id": account_id, "deleted": True}

    # ------------------------------------------------------------------
    # Provider models
    # ------------------------------------------------------------------

    def list_models(self) -> list[dict]:
        """Return every provider model, each annotated with its account's name."""
        # Load all accounts once so serialization needs no per-model lookup.
        accounts = {item.id: item for item in self.repository.list_accounts().all()}
        return [self.serialize_model(item, accounts) for item in self.repository.list_models().all()]

    def _serialize_model_with_account(self, item: ProviderModel) -> dict:
        """Serialize ``item`` after fetching its owning account, which may be missing."""
        account = self.repository.get_account(item.provider_account_id)
        return self.serialize_model(item, {account.id: account} if account else {})

    def create_model(self, payload) -> dict:
        """Persist a new provider model built from ``payload.model_dump()``."""
        # Assumes the payload's field names match ProviderModel attributes.
        item = ProviderModel(**payload.model_dump())
        self.db.add(item)
        self.db.commit()
        self.db.refresh(item)
        return self._serialize_model_with_account(item)

    def update_model(self, model_id: int, payload) -> dict:
        """Copy every field of ``payload.model_dump()`` onto an existing model.

        Raises:
            NotFoundAppError: code 60002 if the model does not exist.
        """
        item = self.repository.get_model(model_id)
        if not item:
            raise NotFoundAppError("provider model not found", code=60002)
        for key, value in payload.model_dump().items():
            setattr(item, key, value)
        self.db.commit()
        return self._serialize_model_with_account(item)

    def delete_model(self, model_id: int) -> dict:
        """Hard-delete a provider model and its supplier bindings.

        Raises:
            NotFoundAppError: code 60002 if the model does not exist.
            BusinessAppError: code 60012 if any generation task references it.
        """
        item = self.repository.get_model(model_id)
        if not item:
            raise NotFoundAppError("provider model not found", code=60002)
        task_count = self.db.scalar(
            select(func.count())
            .select_from(VideoGenerationTask)
            .where(VideoGenerationTask.provider_model_id == model_id)
        )
        if task_count:
            # Message: "This model already has task records and cannot be
            # deleted directly; please disable it first."
            raise BusinessAppError(
                "该模型已有任务记录,不能直接删除,请先停用。",
                code=60012,
            )
        # Remove bindings first so no row still references the model.
        self.db.execute(
            delete(VideoModelSupplierBinding).where(
                VideoModelSupplierBinding.provider_model_id == model_id
            )
        )
        self.db.delete(item)
        self.db.commit()
        return {"id": model_id, "deleted": True}

    # ------------------------------------------------------------------
    # Serialization
    # ------------------------------------------------------------------

    @staticmethod
    def serialize_account(item: ProviderAccount) -> dict:
        """Return the public (camelCase) view of an account.

        The api key, api secret and webhook secret are not included. A missing
        ``updated_at`` is reported as ``None`` instead of raising.
        """
        updated_at = item.updated_at
        return {
            "id": item.id,
            "providerCode": item.provider_code,
            "providerName": item.provider_name,
            "apiFormat": item.api_format,
            "baseUrl": item.base_url,
            "timeoutSeconds": item.timeout_seconds,
            "maxRetries": item.max_retries,
            "status": item.status,
            "remark": item.remark,
            "updatedAt": updated_at.isoformat() if updated_at is not None else None,
        }

    @staticmethod
    def serialize_model(item: ProviderModel, accounts: dict[int, ProviderAccount]) -> dict:
        """Return the public (camelCase) view of a model.

        ``accounts`` maps account id to account. ``providerName`` is "" when the
        owning account is not in the mapping.
        """
        account = accounts.get(item.provider_account_id)
        return {
            "id": item.id,
            "providerAccountId": item.provider_account_id,
            "providerName": account.provider_name if account else "",
            "modelCode": item.model_code,
            "modelName": item.model_name,
            "requestContentType": item.request_content_type,
            "supportsTextToVideo": item.supports_text_to_video,
            "supportsImageToVideo": item.supports_image_to_video,
            "supportsVideoReference": item.supports_video_reference,
            "supportsAudioReference": item.supports_audio_reference,
            "supportsGenerateAudio": item.supports_generate_audio,
            "supportsWebhook": item.supports_webhook,
            "minDuration": item.min_duration,
            "maxDuration": item.max_duration,
            "defaultRatio": item.default_ratio,
            "defaultResolution": item.default_resolution,
            "status": item.status,
        }