113 lines
4.6 KiB
Python
113 lines
4.6 KiB
Python
from sqlalchemy.orm import Session
|
|
|
|
from app.common.errors.app_error import NotFoundAppError
|
|
from app.models.entities import ProviderAccount, ProviderModel
|
|
from app.modules.providers.repository import ProvidersRepository
|
|
|
|
|
|
class ProvidersService:
    """Application service for provider accounts and provider models.

    Wraps a SQLAlchemy ``Session`` and a ``ProvidersRepository``. Every
    mutating method commits the session and returns the affected row
    serialized to a camelCase ``dict``.
    """

    # Payload attributes copied verbatim onto same-named ProviderAccount columns.
    _ACCOUNT_PLAIN_FIELDS = (
        "provider_code",
        "provider_name",
        "api_format",
        "base_url",
        "timeout_seconds",
        "max_retries",
        "status",
        "remark",
    )

    # (payload attribute, ProviderAccount column) pairs for secret fields.
    # NOTE(review): values are stored exactly as received; nothing in this
    # class encrypts them despite the ``*_encrypted`` column names. Confirm
    # whether encryption happens in the payload schema or the model layer.
    _ACCOUNT_SECRET_FIELDS = (
        ("api_key", "api_key_encrypted"),
        ("api_secret", "api_secret_encrypted"),
        ("webhook_secret", "webhook_secret_encrypted"),
    )

    def __init__(self, db: Session) -> None:
        self.db = db
        self.repository = ProvidersRepository(db)

    def list_accounts(self) -> list[dict]:
        """Return every provider account, serialized via ``serialize_account``."""
        return [self.serialize_account(item) for item in self.repository.list_accounts().all()]

    def create_account(self, payload) -> dict:
        """Insert a provider account built from *payload* and return it serialized."""
        item = ProviderAccount(**self._account_fields(payload))
        self.db.add(item)
        self.db.commit()
        # Reload the committed row so serialization sees the stored values.
        self.db.refresh(item)
        return self.serialize_account(item)

    def update_account(self, account_id: int, payload) -> dict:
        """Overwrite the account *account_id* with the values in *payload*.

        Secret fields (``api_key``, ``api_secret``, ``webhook_secret``) whose
        payload value is ``None`` are left unchanged. ``serialize_account``
        never returns secrets, so a client cannot resend them on edit. Pass a
        non-``None`` value (e.g. an empty string) to overwrite a secret.

        Raises:
            NotFoundAppError: code 60001 if no account has that id.
        """
        item = self.repository.get_account(account_id)
        if item is None:
            raise NotFoundAppError("provider account not found", code=60001)
        for column, value in self._account_fields(payload, skip_none_secrets=True).items():
            setattr(item, column, value)
        self.db.commit()
        return self.serialize_account(item)

    def list_models(self) -> list[dict]:
        """Return every provider model, each annotated with its account's name."""
        # Load accounts once up front instead of one lookup per model.
        accounts = {item.id: item for item in self.repository.list_accounts().all()}
        return [self.serialize_model(item, accounts) for item in self.repository.list_models().all()]

    def create_model(self, payload) -> dict:
        """Insert a provider model from ``payload.model_dump()`` and return it serialized."""
        item = ProviderModel(**payload.model_dump())
        self.db.add(item)
        self.db.commit()
        self.db.refresh(item)
        return self._serialize_model_with_account(item)

    def update_model(self, model_id: int, payload) -> dict:
        """Copy every key of ``payload.model_dump()`` onto the model *model_id*.

        NOTE(review): every dumped key is applied via ``setattr``; this assumes
        the payload schema only exposes writable model columns. Confirm.

        Raises:
            NotFoundAppError: code 60002 if no model has that id.
        """
        item = self.repository.get_model(model_id)
        if item is None:
            raise NotFoundAppError("provider model not found", code=60002)
        for key, value in payload.model_dump().items():
            setattr(item, key, value)
        self.db.commit()
        return self._serialize_model_with_account(item)

    @classmethod
    def _account_fields(cls, payload, *, skip_none_secrets: bool = False) -> dict:
        """Map an account payload onto ``ProviderAccount`` column values.

        When *skip_none_secrets* is true, secret columns whose payload value
        is ``None`` are omitted so the currently stored secret is preserved.
        """
        fields = {name: getattr(payload, name) for name in cls._ACCOUNT_PLAIN_FIELDS}
        for attr, column in cls._ACCOUNT_SECRET_FIELDS:
            value = getattr(payload, attr)
            if skip_none_secrets and value is None:
                continue
            fields[column] = value
        return fields

    def _serialize_model_with_account(self, item: ProviderModel) -> dict:
        """Serialize *item*, looking up its owning account for ``providerName``."""
        account = self.repository.get_account(item.provider_account_id)
        accounts = {account.id: account} if account is not None else {}
        return self.serialize_model(item, accounts)

    @staticmethod
    def serialize_account(item: ProviderAccount) -> dict:
        """Return the public camelCase view of *item*; secret columns are never included.

        ``updatedAt`` is an ISO-8601 string, or ``None`` when the row has no
        ``updated_at`` value.
        """
        updated_at = item.updated_at
        return {
            "id": item.id,
            "providerCode": item.provider_code,
            "providerName": item.provider_name,
            "apiFormat": item.api_format,
            "baseUrl": item.base_url,
            "timeoutSeconds": item.timeout_seconds,
            "maxRetries": item.max_retries,
            "status": item.status,
            "remark": item.remark,
            "updatedAt": updated_at.isoformat() if updated_at is not None else None,
        }

    @staticmethod
    def serialize_model(item: ProviderModel, accounts: dict[int, ProviderAccount]) -> dict:
        """Return the camelCase view of *item*.

        *accounts* maps account id to account. ``providerName`` is ``""``
        when the model's account is not in the mapping.
        """
        account = accounts.get(item.provider_account_id)
        return {
            "id": item.id,
            "providerAccountId": item.provider_account_id,
            "providerName": account.provider_name if account else "",
            "modelCode": item.model_code,
            "modelName": item.model_name,
            "requestContentType": item.request_content_type,
            "supportsTextToVideo": item.supports_text_to_video,
            "supportsImageToVideo": item.supports_image_to_video,
            "supportsVideoReference": item.supports_video_reference,
            "supportsAudioReference": item.supports_audio_reference,
            "supportsGenerateAudio": item.supports_generate_audio,
            "supportsWebhook": item.supports_webhook,
            "minDuration": item.min_duration,
            "maxDuration": item.max_duration,
            "defaultRatio": item.default_ratio,
            "defaultResolution": item.default_resolution,
            "status": item.status,
        }
|
|
|