# File-viewer header captured with this paste (not part of the module):
# Files · 113 lines · 4.6 KiB · Python

from sqlalchemy.orm import Session
from app.common.errors.app_error import NotFoundAppError
from app.models.entities import ProviderAccount, ProviderModel
from app.modules.providers.repository import ProvidersRepository
class ProvidersService:
    """Service layer for provider accounts and provider models.

    Reads go through ``ProvidersRepository``; writes go through the
    SQLAlchemy session passed to the constructor. Public methods return
    plain dicts with camelCase keys (see ``serialize_account`` and
    ``serialize_model``).
    """

    def __init__(self, db: Session) -> None:
        """Store the session and build a repository bound to it."""
        self.db = db
        self.repository = ProvidersRepository(db)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _commit(self) -> None:
        """Commit the session; on failure roll back and re-raise.

        Without the rollback a failed commit leaves the session in a
        failed transaction, and any further use of the same session
        raises until ``rollback()`` is called.
        """
        try:
            self.db.commit()
        except Exception:
            self.db.rollback()
            raise

    @staticmethod
    def _account_values(payload) -> dict[str, object]:
        """Map an account payload onto ``ProviderAccount`` attribute values.

        Shared by ``create_account`` and ``update_account`` so both write
        exactly the same set of fields, in the same order.
        """
        # NOTE(review): api_key / api_secret / webhook_secret are copied
        # as-is into the *_encrypted attributes; no encryption step is
        # visible in this class -- confirm it happens elsewhere.
        return {
            "provider_code": payload.provider_code,
            "provider_name": payload.provider_name,
            "api_format": payload.api_format,
            "base_url": payload.base_url,
            "api_key_encrypted": payload.api_key,
            "api_secret_encrypted": payload.api_secret,
            "webhook_secret_encrypted": payload.webhook_secret,
            "timeout_seconds": payload.timeout_seconds,
            "max_retries": payload.max_retries,
            "status": payload.status,
            "remark": payload.remark,
        }

    def _serialize_model_with_account(self, item: ProviderModel) -> dict:
        """Serialize one model, looking up its owning account by id."""
        account = self.repository.get_account(item.provider_account_id)
        # A missing account is tolerated: serialize_model then reports an
        # empty providerName.
        return self.serialize_model(item, {account.id: account} if account else {})

    # ------------------------------------------------------------------
    # Provider accounts
    # ------------------------------------------------------------------

    def list_accounts(self) -> list[dict]:
        """Return every provider account as a serialized dict."""
        return [self.serialize_account(item) for item in self.repository.list_accounts().all()]

    def create_account(self, payload) -> dict:
        """Create a provider account from ``payload`` and return it serialized."""
        item = ProviderAccount(**self._account_values(payload))
        self.db.add(item)
        self._commit()
        # Reload the row so database-populated attributes are available.
        self.db.refresh(item)
        return self.serialize_account(item)

    def update_account(self, account_id: int, payload) -> dict:
        """Overwrite all editable fields of an existing provider account.

        Raises:
            NotFoundAppError: (code 60001) when no account has ``account_id``.
        """
        item = self.repository.get_account(account_id)
        if not item:
            raise NotFoundAppError("provider account not found", code=60001)
        # NOTE(review): the three secrets are overwritten unconditionally,
        # while serialize_account never returns them -- confirm callers
        # always resend the secrets on update.
        for attr, value in self._account_values(payload).items():
            setattr(item, attr, value)
        self._commit()
        return self.serialize_account(item)

    # ------------------------------------------------------------------
    # Provider models
    # ------------------------------------------------------------------

    def list_models(self) -> list[dict]:
        """Return every provider model as a serialized dict.

        All accounts are loaded once and indexed by id so each model's
        ``providerName`` is resolved without a per-model lookup.
        """
        accounts = {item.id: item for item in self.repository.list_accounts().all()}
        return [self.serialize_model(item, accounts) for item in self.repository.list_models().all()]

    def create_model(self, payload) -> dict:
        """Create a provider model from ``payload.model_dump()`` and return it serialized."""
        item = ProviderModel(**payload.model_dump())
        self.db.add(item)
        self._commit()
        self.db.refresh(item)
        return self._serialize_model_with_account(item)

    def update_model(self, model_id: int, payload) -> dict:
        """Apply every key of ``payload.model_dump()`` to an existing model.

        Raises:
            NotFoundAppError: (code 60002) when no model has ``model_id``.
        """
        item = self.repository.get_model(model_id)
        if not item:
            raise NotFoundAppError("provider model not found", code=60002)
        for key, value in payload.model_dump().items():
            setattr(item, key, value)
        self._commit()
        return self._serialize_model_with_account(item)

    # ------------------------------------------------------------------
    # Serialization
    # ------------------------------------------------------------------

    @staticmethod
    def serialize_account(item: ProviderAccount) -> dict:
        """Convert an account into a camelCase dict.

        The key/secret attributes are deliberately not part of the output.
        ``updatedAt`` is an ISO-8601 string, or ``None`` when the account
        has no ``updated_at`` value.
        """
        updated_at = item.updated_at
        return {
            "id": item.id,
            "providerCode": item.provider_code,
            "providerName": item.provider_name,
            "apiFormat": item.api_format,
            "baseUrl": item.base_url,
            "timeoutSeconds": item.timeout_seconds,
            "maxRetries": item.max_retries,
            "status": item.status,
            "remark": item.remark,
            # Guard against a missing timestamp instead of raising
            # AttributeError on None.
            "updatedAt": updated_at.isoformat() if updated_at is not None else None,
        }

    @staticmethod
    def serialize_model(item: ProviderModel, accounts: dict[int, ProviderAccount]) -> dict:
        """Convert a model into a camelCase dict.

        ``accounts`` maps account id to account; it is used only to resolve
        ``providerName``, which is ``""`` when the model's account id is
        not in the mapping.
        """
        account = accounts.get(item.provider_account_id)
        return {
            "id": item.id,
            "providerAccountId": item.provider_account_id,
            "providerName": account.provider_name if account else "",
            "modelCode": item.model_code,
            "modelName": item.model_name,
            "requestContentType": item.request_content_type,
            "supportsTextToVideo": item.supports_text_to_video,
            "supportsImageToVideo": item.supports_image_to_video,
            "supportsVideoReference": item.supports_video_reference,
            "supportsAudioReference": item.supports_audio_reference,
            "supportsGenerateAudio": item.supports_generate_audio,
            "supportsWebhook": item.supports_webhook,
            "minDuration": item.min_duration,
            "maxDuration": item.max_duration,
            "defaultRatio": item.default_ratio,
            "defaultResolution": item.default_resolution,
            "status": item.status,
        }