feat: initialize aivideo project
This commit is contained in:
50
backend/app/modules/admins/repository.py
Normal file
50
backend/app/modules/admins/repository.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.entities import (
|
||||
AdminUser,
|
||||
InviteRelation,
|
||||
RechargeOrder,
|
||||
User,
|
||||
VideoGenerationTask,
|
||||
)
|
||||
|
||||
|
||||
class AdminsRepository:
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
|
||||
def get_admin_by_username(self, username: str) -> AdminUser | None:
|
||||
return self.db.scalar(select(AdminUser).where(AdminUser.username == username))
|
||||
|
||||
def list_users(self):
|
||||
return self.db.query(User).order_by(User.id.desc())
|
||||
|
||||
def get_user(self, user_id: int) -> User | None:
|
||||
return self.db.scalar(select(User).where(User.id == user_id))
|
||||
|
||||
def count_users(self) -> int:
|
||||
return self.db.query(func.count(User.id)).scalar() or 0
|
||||
|
||||
def count_paid_orders(self) -> int:
|
||||
return (
|
||||
self.db.query(func.count(RechargeOrder.id))
|
||||
.filter(RechargeOrder.status == "paid")
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
def count_tasks(self) -> int:
|
||||
return self.db.query(func.count(VideoGenerationTask.id)).scalar() or 0
|
||||
|
||||
def count_success_tasks(self) -> int:
|
||||
return (
|
||||
self.db.query(func.count(VideoGenerationTask.id))
|
||||
.filter(VideoGenerationTask.task_status == "succeeded")
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
def invite_relations(self):
|
||||
return self.db.query(InviteRelation).order_by(InviteRelation.id.desc())
|
||||
|
||||
86
backend/app/modules/admins/router.py
Normal file
86
backend/app/modules/admins/router.py
Normal file
@@ -0,0 +1,86 @@
|
||||
from fastapi import APIRouter, Depends, Response
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.common.db.session import get_db
|
||||
from app.common.responses.api_response import success_response
|
||||
from app.common.security.deps import get_current_admin, require_admin_permission
|
||||
from app.modules.admins.schema import AdminLoginRequest, ManualAdjustRequest
|
||||
from app.modules.admins.service import AdminsService
|
||||
|
||||
|
||||
auth_router = APIRouter(prefix="/api/v1/admin/auth", tags=["admin-auth"])
|
||||
router = APIRouter(prefix="/api/v1/admin", tags=["admin"])
|
||||
|
||||
|
||||
@auth_router.post("/login")
|
||||
def admin_login(
|
||||
payload: AdminLoginRequest,
|
||||
response: Response,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return {"code": 0, "message": "ok", "data": AdminsService(db).login(payload, response)}
|
||||
|
||||
|
||||
@auth_router.post("/logout")
|
||||
def admin_logout(response: Response, db: Session = Depends(get_db)):
|
||||
AdminsService(db).logout(response)
|
||||
return {"code": 0, "message": "ok", "data": {"success": True}}
|
||||
|
||||
|
||||
@auth_router.get("/me")
|
||||
def admin_me(admin=Depends(get_current_admin)):
|
||||
return success_response(AdminsService.serialize_admin(admin))
|
||||
|
||||
|
||||
@router.get("/dashboard")
|
||||
def dashboard(
|
||||
_=Depends(require_admin_permission()),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return success_response(AdminsService(db).dashboard())
|
||||
|
||||
|
||||
@router.get("/users")
|
||||
def list_users(
|
||||
_=Depends(require_admin_permission()),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return success_response(AdminsService(db).list_users())
|
||||
|
||||
|
||||
@router.get("/users/{user_id}")
|
||||
def get_user_detail(
|
||||
user_id: int,
|
||||
_=Depends(require_admin_permission()),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return success_response(AdminsService(db).get_user_detail(user_id))
|
||||
|
||||
|
||||
@router.post("/users/{user_id}/wallet-adjust")
|
||||
def manual_adjust_wallet(
|
||||
user_id: int,
|
||||
payload: ManualAdjustRequest,
|
||||
_=Depends(require_admin_permission()),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return success_response(
|
||||
AdminsService(db).manual_adjust_wallet(user_id, payload.amount_points, payload.reason)
|
||||
)
|
||||
|
||||
|
||||
@router.get("/users/{user_id}/invite-relations")
|
||||
def user_invite_relations(
|
||||
user_id: int,
|
||||
_=Depends(require_admin_permission()),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return success_response(AdminsService(db).user_invite_relations(user_id))
|
||||
|
||||
|
||||
@router.get("/invite-relations")
|
||||
def list_invite_relations(
|
||||
_=Depends(require_admin_permission()),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return success_response(AdminsService(db).list_invite_relations())
|
||||
12
backend/app/modules/admins/schema.py
Normal file
12
backend/app/modules/admins/schema.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class AdminLoginRequest(BaseModel):
|
||||
username: str
|
||||
password: str
|
||||
|
||||
|
||||
class ManualAdjustRequest(BaseModel):
|
||||
amount_points: int = Field(alias="amountPoints")
|
||||
reason: str
|
||||
|
||||
134
backend/app/modules/admins/service.py
Normal file
134
backend/app/modules/admins/service.py
Normal file
@@ -0,0 +1,134 @@
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import Response
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.common.errors.app_error import AuthenticationError, NotFoundAppError
|
||||
from app.common.security.jwt import (
|
||||
clear_auth_cookies,
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
set_auth_cookies,
|
||||
)
|
||||
from app.common.security.password import verify_password
|
||||
from app.models.entities import InviteRelation
|
||||
from app.modules.admins.repository import AdminsRepository
|
||||
from app.modules.wallets.service import WalletService
|
||||
|
||||
|
||||
class AdminsService:
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
self.repository = AdminsRepository(db)
|
||||
self.wallet_service = WalletService(db)
|
||||
|
||||
def login(self, payload, response: Response) -> dict:
|
||||
admin = self.repository.get_admin_by_username(payload.username)
|
||||
if not admin or not verify_password(payload.password, admin.password_hash):
|
||||
raise AuthenticationError("invalid admin credentials")
|
||||
admin.last_login_at = datetime.utcnow()
|
||||
access_token = create_access_token(admin.username, scope="admin")
|
||||
refresh_token = create_refresh_token(admin.username, scope="admin")
|
||||
set_auth_cookies(response, access_token, refresh_token, prefix="admin")
|
||||
self.db.commit()
|
||||
return self.serialize_admin(admin)
|
||||
|
||||
def logout(self, response: Response) -> None:
|
||||
clear_auth_cookies(response, prefix="admin")
|
||||
|
||||
@staticmethod
|
||||
def serialize_admin(admin) -> dict:
|
||||
return {
|
||||
"id": admin.id,
|
||||
"username": admin.username,
|
||||
"nickname": admin.nickname,
|
||||
"isSuperAdmin": admin.is_super_admin,
|
||||
}
|
||||
|
||||
def dashboard(self) -> dict:
|
||||
total_tasks = self.repository.count_tasks()
|
||||
success_tasks = self.repository.count_success_tasks()
|
||||
return {
|
||||
"users": self.repository.count_users(),
|
||||
"paidOrders": self.repository.count_paid_orders(),
|
||||
"tasks": total_tasks,
|
||||
"successRate": round(success_tasks / total_tasks * 100, 2) if total_tasks else 0,
|
||||
}
|
||||
|
||||
def list_users(self) -> list[dict]:
|
||||
rows = self.repository.list_users().limit(200).all()
|
||||
return [
|
||||
{
|
||||
"id": item.id,
|
||||
"publicId": item.public_id,
|
||||
"username": item.username or "",
|
||||
"nickname": item.nickname,
|
||||
"email": item.email or "",
|
||||
"status": item.status,
|
||||
"createdAt": item.created_at.isoformat(),
|
||||
}
|
||||
for item in rows
|
||||
]
|
||||
|
||||
def get_user_detail(self, user_id: int) -> dict:
|
||||
user = self.repository.get_user(user_id)
|
||||
if not user:
|
||||
raise NotFoundAppError("user not found", code=10020)
|
||||
wallet = self.wallet_service.get_wallet_summary(user.id)
|
||||
return {
|
||||
"id": user.id,
|
||||
"publicId": user.public_id,
|
||||
"username": user.username or "",
|
||||
"nickname": user.nickname,
|
||||
"email": user.email or "",
|
||||
"status": user.status,
|
||||
"wallet": wallet,
|
||||
}
|
||||
|
||||
def manual_adjust_wallet(self, user_id: int, amount_points: int, reason: str) -> dict:
|
||||
user = self.repository.get_user(user_id)
|
||||
if not user:
|
||||
raise NotFoundAppError("user not found", code=10020)
|
||||
tx = self.wallet_service.add_points(
|
||||
user.id,
|
||||
amount_points,
|
||||
biz_type="manual_adjust",
|
||||
related_type="user",
|
||||
related_id=user.id,
|
||||
remark=reason,
|
||||
operator_type="admin",
|
||||
)
|
||||
self.db.commit()
|
||||
return {"transactionNo": tx.transaction_no, "amountPoints": amount_points}
|
||||
|
||||
def user_invite_relations(self, user_id: int) -> list[dict]:
|
||||
rows = self.repository.invite_relations().filter(
|
||||
(InviteRelation.inviter_user_id == user_id) | (InviteRelation.invitee_user_id == user_id)
|
||||
).all()
|
||||
return [
|
||||
{
|
||||
"id": item.id,
|
||||
"inviterUserId": item.inviter_user_id,
|
||||
"inviteeUserId": item.invitee_user_id,
|
||||
"rewardStatus": item.reward_status,
|
||||
"rewardPoints": item.reward_points,
|
||||
"createdAt": item.created_at.isoformat(),
|
||||
}
|
||||
for item in rows
|
||||
]
|
||||
|
||||
def list_invite_relations(self) -> list[dict]:
|
||||
rows = self.repository.invite_relations().limit(200).all()
|
||||
return [
|
||||
{
|
||||
"id": item.id,
|
||||
"inviterUserId": item.inviter_user_id,
|
||||
"inviteeUserId": item.invitee_user_id,
|
||||
"rewardStatus": item.reward_status,
|
||||
"rewardPoints": item.reward_points,
|
||||
"rewardedAt": item.rewarded_at.isoformat() if item.rewarded_at else None,
|
||||
"createdAt": item.created_at.isoformat(),
|
||||
}
|
||||
for item in rows
|
||||
]
|
||||
|
||||
26
backend/app/modules/assets/repository.py
Normal file
26
backend/app/modules/assets/repository.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.entities import MediaAsset
|
||||
|
||||
|
||||
class AssetsRepository:
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
|
||||
def list_assets(self, user_id: int):
|
||||
return (
|
||||
self.db.query(MediaAsset)
|
||||
.filter(MediaAsset.user_id == user_id, MediaAsset.status == "active")
|
||||
.order_by(MediaAsset.id.desc())
|
||||
)
|
||||
|
||||
def get_asset(self, user_id: int, asset_id: int) -> MediaAsset | None:
|
||||
return self.db.scalar(
|
||||
select(MediaAsset).where(
|
||||
MediaAsset.id == asset_id,
|
||||
MediaAsset.user_id == user_id,
|
||||
MediaAsset.status == "active",
|
||||
)
|
||||
)
|
||||
|
||||
44
backend/app/modules/assets/router.py
Normal file
44
backend/app/modules/assets/router.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from fastapi import APIRouter, Depends, File, Form, UploadFile
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.common.db.session import get_db
|
||||
from app.common.responses.api_response import success_response
|
||||
from app.common.security.deps import get_current_user
|
||||
from app.models.entities import User
|
||||
from app.modules.assets.service import AssetsService
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/v1/assets", tags=["assets"])
|
||||
|
||||
|
||||
@router.post("/upload-token")
|
||||
def create_upload_token(mediaType: str = "image", db: Session = Depends(get_db)):
|
||||
return success_response(AssetsService(db).create_upload_token(mediaType))
|
||||
|
||||
|
||||
@router.post("")
|
||||
def upload_asset(
|
||||
file: UploadFile = File(...),
|
||||
mediaType: str = Form("image"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return success_response(AssetsService(db).save_asset(current_user.id, file, mediaType))
|
||||
|
||||
|
||||
@router.get("")
|
||||
def list_assets(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return success_response(AssetsService(db).list_assets(current_user.id))
|
||||
|
||||
|
||||
@router.delete("/{asset_id}")
|
||||
def delete_asset(
|
||||
asset_id: int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return success_response(AssetsService(db).delete_asset(current_user.id, asset_id))
|
||||
|
||||
6
backend/app/modules/assets/schema.py
Normal file
6
backend/app/modules/assets/schema.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class UploadTokenRequest(BaseModel):
|
||||
media_type: str = "image"
|
||||
|
||||
72
backend/app/modules/assets/service.py
Normal file
72
backend/app/modules/assets/service.py
Normal file
@@ -0,0 +1,72 @@
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import UploadFile
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.common.errors.app_error import NotFoundAppError
|
||||
from app.common.utils.id_gen import new_order_no
|
||||
from app.core.storage import storage_service
|
||||
from app.models.entities import MediaAsset
|
||||
from app.modules.assets.repository import AssetsRepository
|
||||
|
||||
|
||||
class AssetsService:
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
self.repository = AssetsRepository(db)
|
||||
|
||||
def create_upload_token(self, media_type: str) -> dict:
|
||||
return {
|
||||
"uploadToken": new_order_no("upload"),
|
||||
"mediaType": media_type,
|
||||
"uploadMode": "multipart",
|
||||
}
|
||||
|
||||
def save_asset(self, user_id: int, file: UploadFile, media_type: str) -> dict:
|
||||
stored = storage_service.save_upload(file, folder=f"uploads/{media_type}")
|
||||
asset = MediaAsset(
|
||||
asset_no=new_order_no("asset"),
|
||||
user_id=user_id,
|
||||
media_type=media_type,
|
||||
source_type="upload",
|
||||
original_filename=file.filename or "upload.bin",
|
||||
mime_type=file.content_type or "application/octet-stream",
|
||||
file_ext=Path(file.filename or "").suffix,
|
||||
file_size=stored["file_size"],
|
||||
storage_provider="local",
|
||||
storage_bucket="local",
|
||||
storage_key=stored["storage_key"],
|
||||
public_url=stored["public_url"],
|
||||
sha256=stored["sha256"],
|
||||
status="active",
|
||||
)
|
||||
self.db.add(asset)
|
||||
self.db.commit()
|
||||
self.db.refresh(asset)
|
||||
return self.serialize(asset)
|
||||
|
||||
def list_assets(self, user_id: int) -> list[dict]:
|
||||
return [self.serialize(item) for item in self.repository.list_assets(user_id).all()]
|
||||
|
||||
def delete_asset(self, user_id: int, asset_id: int) -> dict:
|
||||
asset = self.repository.get_asset(user_id, asset_id)
|
||||
if not asset:
|
||||
raise NotFoundAppError("asset not found", code=40003)
|
||||
asset.status = "deleted"
|
||||
asset.deleted_at = datetime.utcnow()
|
||||
self.db.commit()
|
||||
return {"assetId": asset_id, "deleted": True}
|
||||
|
||||
@staticmethod
|
||||
def serialize(asset: MediaAsset) -> dict:
|
||||
return {
|
||||
"id": asset.id,
|
||||
"assetNo": asset.asset_no,
|
||||
"mediaType": asset.media_type,
|
||||
"originalFilename": asset.original_filename,
|
||||
"fileSize": asset.file_size,
|
||||
"publicUrl": asset.public_url,
|
||||
"createdAt": asset.created_at.isoformat(),
|
||||
}
|
||||
|
||||
24
backend/app/modules/auth/repository.py
Normal file
24
backend/app/modules/auth/repository.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from sqlalchemy import or_, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.entities import User
|
||||
|
||||
|
||||
class AuthRepository:
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
|
||||
def get_user_by_account(self, account: str) -> User | None:
|
||||
return self.db.scalar(
|
||||
select(User).where(
|
||||
or_(
|
||||
User.email == account,
|
||||
User.mobile == account,
|
||||
User.username == account,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
def get_user_by_public_id(self, public_id: str) -> User | None:
|
||||
return self.db.scalar(select(User).where(User.public_id == public_id))
|
||||
|
||||
55
backend/app/modules/auth/router.py
Normal file
55
backend/app/modules/auth/router.py
Normal file
@@ -0,0 +1,55 @@
|
||||
from fastapi import APIRouter, Cookie, Depends, Request, Response
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.common.db.session import get_db
|
||||
from app.common.responses.api_response import success_response
|
||||
from app.common.security.deps import get_current_user
|
||||
from app.models.entities import User
|
||||
from app.modules.auth.schema import LoginRequest, RegisterRequest
|
||||
from app.modules.auth.service import AuthService
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/v1/auth", tags=["auth"])
|
||||
|
||||
|
||||
@router.post("/register")
|
||||
def register(
|
||||
payload: RegisterRequest,
|
||||
request: Request,
|
||||
response: Response,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
data = AuthService(db).register(payload, request, response)
|
||||
return {"code": 0, "message": "ok", "data": data}
|
||||
|
||||
|
||||
@router.post("/login")
|
||||
def login(
|
||||
payload: LoginRequest,
|
||||
request: Request,
|
||||
response: Response,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
data = AuthService(db).login(payload, request, response)
|
||||
return {"code": 0, "message": "ok", "data": data}
|
||||
|
||||
|
||||
@router.post("/refresh")
|
||||
def refresh(
|
||||
response: Response,
|
||||
db: Session = Depends(get_db),
|
||||
user_refresh_token: str | None = Cookie(default=None),
|
||||
):
|
||||
data = AuthService(db).refresh(user_refresh_token, response)
|
||||
return {"code": 0, "message": "ok", "data": data}
|
||||
|
||||
|
||||
@router.post("/logout")
|
||||
def logout(response: Response, db: Session = Depends(get_db)):
|
||||
AuthService(db).logout(response)
|
||||
return {"code": 0, "message": "ok", "data": {"success": True}}
|
||||
|
||||
|
||||
@router.get("/me")
|
||||
def me(current_user: User = Depends(get_current_user)):
|
||||
return success_response(AuthService.serialize_user(current_user))
|
||||
13
backend/app/modules/auth/schema.py
Normal file
13
backend/app/modules/auth/schema.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from pydantic import BaseModel, EmailStr, Field
|
||||
|
||||
|
||||
class RegisterRequest(BaseModel):
|
||||
account: EmailStr
|
||||
password: str = Field(min_length=8, max_length=64)
|
||||
invite_code: str | None = None
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
account: str
|
||||
password: str = Field(min_length=8, max_length=64)
|
||||
|
||||
123
backend/app/modules/auth/service.py
Normal file
123
backend/app/modules/auth/service.py
Normal file
@@ -0,0 +1,123 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import Request, Response
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.common.errors.app_error import AuthenticationError, ConflictAppError
|
||||
from app.common.security.jwt import (
|
||||
clear_auth_cookies,
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
decode_refresh_token,
|
||||
set_auth_cookies,
|
||||
)
|
||||
from app.common.security.password import hash_password, verify_password
|
||||
from app.common.utils.id_gen import new_public_id
|
||||
from app.models.entities import InviteCode, InviteRelation, User, Wallet
|
||||
from app.modules.auth.repository import AuthRepository
|
||||
from app.modules.wallets.service import WalletService
|
||||
|
||||
|
||||
class AuthService:
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
self.repository = AuthRepository(db)
|
||||
self.wallet_service = WalletService(db)
|
||||
|
||||
def register(self, payload, request: Request, response: Response) -> dict:
|
||||
if self.repository.get_user_by_account(payload.account):
|
||||
raise ConflictAppError("account already exists", code=10010)
|
||||
|
||||
user = User(
|
||||
public_id=new_public_id("usr"),
|
||||
email=payload.account,
|
||||
password_hash=hash_password(payload.password),
|
||||
nickname=payload.account.split("@")[0],
|
||||
status=1,
|
||||
register_ip=request.client.host if request.client else "",
|
||||
last_login_ip=request.client.host if request.client else "",
|
||||
last_login_at=datetime.utcnow(),
|
||||
)
|
||||
self.db.add(user)
|
||||
self.db.flush()
|
||||
self.db.add(Wallet(user_id=user.id))
|
||||
self.db.flush()
|
||||
self._bind_invite_relation(user.id, payload.invite_code, request)
|
||||
self.wallet_service.try_issue_signup_reward(user.id)
|
||||
self.db.commit()
|
||||
self.db.refresh(user)
|
||||
self._issue_tokens(user.public_id, response)
|
||||
return self.serialize_user(user)
|
||||
|
||||
def login(self, payload, request: Request, response: Response) -> dict:
|
||||
user = self.repository.get_user_by_account(payload.account)
|
||||
if not user or not verify_password(payload.password, user.password_hash):
|
||||
raise AuthenticationError("invalid credentials")
|
||||
user.last_login_at = datetime.utcnow()
|
||||
user.last_login_ip = request.client.host if request.client else ""
|
||||
self.db.commit()
|
||||
self._issue_tokens(user.public_id, response)
|
||||
return self.serialize_user(user)
|
||||
|
||||
def refresh(self, refresh_token: str | None, response: Response) -> dict:
|
||||
if not refresh_token:
|
||||
raise AuthenticationError()
|
||||
try:
|
||||
payload = decode_refresh_token(refresh_token)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
raise AuthenticationError() from exc
|
||||
if payload.get("scope") != "user":
|
||||
raise AuthenticationError()
|
||||
user = self.repository.get_user_by_public_id(payload["sub"])
|
||||
if not user:
|
||||
raise AuthenticationError()
|
||||
self._issue_tokens(user.public_id, response)
|
||||
return self.serialize_user(user)
|
||||
|
||||
def logout(self, response: Response) -> None:
|
||||
clear_auth_cookies(response, prefix="user")
|
||||
|
||||
@staticmethod
|
||||
def serialize_user(user: User) -> dict:
|
||||
return {
|
||||
"publicId": user.public_id,
|
||||
"username": user.username or "",
|
||||
"nickname": user.nickname,
|
||||
"avatarUrl": user.avatar_url,
|
||||
"email": user.email or "",
|
||||
"mobile": user.mobile or "",
|
||||
"status": user.status,
|
||||
}
|
||||
|
||||
def _issue_tokens(self, public_id: str, response: Response) -> None:
|
||||
access_token = create_access_token(public_id, scope="user")
|
||||
refresh_token = create_refresh_token(public_id, scope="user")
|
||||
set_auth_cookies(response, access_token, refresh_token, prefix="user")
|
||||
|
||||
def _bind_invite_relation(
|
||||
self,
|
||||
invitee_user_id: int,
|
||||
invite_code_value: str | None,
|
||||
request: Request,
|
||||
) -> None:
|
||||
if not invite_code_value:
|
||||
return
|
||||
invite_code = self.db.query(InviteCode).filter(
|
||||
InviteCode.invite_code == invite_code_value,
|
||||
InviteCode.status == 1,
|
||||
).first()
|
||||
if not invite_code:
|
||||
return
|
||||
self.db.add(
|
||||
InviteRelation(
|
||||
inviter_user_id=invite_code.user_id,
|
||||
invitee_user_id=invitee_user_id,
|
||||
invite_code_id=invite_code.id,
|
||||
reward_status="pending",
|
||||
reward_points=0,
|
||||
register_ip=request.client.host if request.client else "",
|
||||
)
|
||||
)
|
||||
invite_code.used_count += 1
|
||||
15
backend/app/modules/growth_rules/repository.py
Normal file
15
backend/app/modules/growth_rules/repository.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.entities import GrowthRewardRule
|
||||
|
||||
|
||||
class GrowthRulesRepository:
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
|
||||
def get_rule(self, rule_type: str) -> GrowthRewardRule | None:
|
||||
return self.db.scalar(
|
||||
select(GrowthRewardRule).where(GrowthRewardRule.rule_type == rule_type)
|
||||
)
|
||||
|
||||
38
backend/app/modules/growth_rules/router.py
Normal file
38
backend/app/modules/growth_rules/router.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.common.db.session import get_db
|
||||
from app.common.responses.api_response import success_response
|
||||
from app.common.security.deps import require_admin_permission
|
||||
from app.modules.growth_rules.schema import GrowthRulePayload
|
||||
from app.modules.growth_rules.service import GrowthRulesService
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/v1/admin/growth-rules", tags=["admin-growth-rules"])
|
||||
|
||||
|
||||
@router.get("")
|
||||
def get_growth_rules(
|
||||
_=Depends(require_admin_permission()),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return success_response(GrowthRulesService(db).get_rules())
|
||||
|
||||
|
||||
@router.put("/signup")
|
||||
def update_signup_rule(
|
||||
payload: GrowthRulePayload,
|
||||
_=Depends(require_admin_permission()),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return success_response(GrowthRulesService(db).update_signup_rule(payload))
|
||||
|
||||
|
||||
@router.put("/invite")
|
||||
def update_invite_rule(
|
||||
payload: GrowthRulePayload,
|
||||
_=Depends(require_admin_permission()),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return success_response(GrowthRulesService(db).update_invite_rule(payload))
|
||||
|
||||
9
backend/app/modules/growth_rules/schema.py
Normal file
9
backend/app/modules/growth_rules/schema.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class GrowthRulePayload(BaseModel):
|
||||
enabled: bool
|
||||
reward_points: int
|
||||
min_consume_points: int = 0
|
||||
remark: str = ""
|
||||
|
||||
44
backend/app/modules/growth_rules/service.py
Normal file
44
backend/app/modules/growth_rules/service.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.common.errors.app_error import NotFoundAppError
|
||||
from app.modules.growth_rules.repository import GrowthRulesRepository
|
||||
|
||||
|
||||
class GrowthRulesService:
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
self.repository = GrowthRulesRepository(db)
|
||||
|
||||
def get_rules(self) -> dict:
|
||||
signup = self.repository.get_rule("signup_reward")
|
||||
invite = self.repository.get_rule("invite_reward")
|
||||
return {
|
||||
"signupRewardEnabled": bool(signup.enabled) if signup else False,
|
||||
"signupRewardPoints": signup.reward_points if signup else 0,
|
||||
"inviteRewardEnabled": bool(invite.enabled) if invite else False,
|
||||
"inviteRewardPoints": invite.reward_points if invite else 0,
|
||||
"inviteRewardTrigger": invite.trigger_condition if invite else "on_first_consume",
|
||||
"inviteRewardMinConsumePoints": invite.min_consume_points if invite else 0,
|
||||
}
|
||||
|
||||
def update_signup_rule(self, payload) -> dict:
|
||||
rule = self.repository.get_rule("signup_reward")
|
||||
if not rule:
|
||||
raise NotFoundAppError("signup rule not found", code=70010)
|
||||
rule.enabled = payload.enabled
|
||||
rule.reward_points = payload.reward_points
|
||||
rule.remark = payload.remark
|
||||
self.db.commit()
|
||||
return self.get_rules()
|
||||
|
||||
def update_invite_rule(self, payload) -> dict:
|
||||
rule = self.repository.get_rule("invite_reward")
|
||||
if not rule:
|
||||
raise NotFoundAppError("invite rule not found", code=70011)
|
||||
rule.enabled = payload.enabled
|
||||
rule.reward_points = payload.reward_points
|
||||
rule.min_consume_points = payload.min_consume_points
|
||||
rule.remark = payload.remark
|
||||
self.db.commit()
|
||||
return self.get_rules()
|
||||
|
||||
35
backend/app/modules/invites/repository.py
Normal file
35
backend/app/modules/invites/repository.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.entities import InviteCode, InviteRelation, User
|
||||
|
||||
|
||||
class InviteRepository:
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
|
||||
def get_default_code(self, user_id: int) -> InviteCode | None:
|
||||
return self.db.scalar(
|
||||
select(InviteCode).where(
|
||||
InviteCode.user_id == user_id,
|
||||
InviteCode.is_default.is_(True),
|
||||
)
|
||||
)
|
||||
|
||||
def get_code(self, code_value: str) -> InviteCode | None:
|
||||
return self.db.scalar(select(InviteCode).where(InviteCode.invite_code == code_value))
|
||||
|
||||
def inviter_relations(self, user_id: int) -> list[InviteRelation]:
|
||||
return (
|
||||
self.db.query(InviteRelation)
|
||||
.filter(InviteRelation.inviter_user_id == user_id)
|
||||
.order_by(InviteRelation.id.desc())
|
||||
.all()
|
||||
)
|
||||
|
||||
def users_by_ids(self, user_ids: list[int]) -> dict[int, User]:
|
||||
if not user_ids:
|
||||
return {}
|
||||
rows = self.db.scalars(select(User).where(User.id.in_(user_ids))).all()
|
||||
return {row.id: row for row in rows}
|
||||
|
||||
43
backend/app/modules/invites/router.py
Normal file
43
backend/app/modules/invites/router.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.common.db.session import get_db
|
||||
from app.common.responses.api_response import success_response
|
||||
from app.common.security.deps import get_current_user
|
||||
from app.models.entities import User
|
||||
from app.modules.invites.service import InviteService
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/v1/invite", tags=["invite"])
|
||||
|
||||
|
||||
@router.get("")
|
||||
def get_invite_summary(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return success_response(InviteService(db).get_invite_summary(current_user.id))
|
||||
|
||||
|
||||
@router.post("/codes")
|
||||
def create_invite_code(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return success_response(InviteService(db).create_invite_code(current_user.id))
|
||||
|
||||
|
||||
@router.get("/relations")
|
||||
def list_relations(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return success_response(InviteService(db).list_relations(current_user.id))
|
||||
|
||||
|
||||
@router.get("/rewards")
|
||||
def list_rewards(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return success_response(InviteService(db).list_rewards(current_user.id))
|
||||
6
backend/app/modules/invites/schema.py
Normal file
6
backend/app/modules/invites/schema.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class CreateInviteCodeRequest(BaseModel):
|
||||
regenerate: bool = False
|
||||
|
||||
75
backend/app/modules/invites/service.py
Normal file
75
backend/app/modules/invites/service.py
Normal file
@@ -0,0 +1,75 @@
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.common.utils.id_gen import new_invite_code
|
||||
from app.models.entities import InviteCode, InviteRelation
|
||||
from app.modules.invites.repository import InviteRepository
|
||||
|
||||
|
||||
class InviteService:
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
self.repository = InviteRepository(db)
|
||||
|
||||
def get_invite_summary(self, user_id: int) -> dict:
|
||||
code = self.repository.get_default_code(user_id)
|
||||
if not code:
|
||||
code = self._create_default_code(user_id)
|
||||
relations = self.repository.inviter_relations(user_id)
|
||||
rewarded = [item for item in relations if item.reward_status == "rewarded"]
|
||||
return {
|
||||
"inviteCode": code.invite_code,
|
||||
"inviteLink": code.invite_link,
|
||||
"invitedUsers": len(relations),
|
||||
"rewardedUsers": len(rewarded),
|
||||
"rewardedPoints": sum(item.reward_points for item in rewarded),
|
||||
}
|
||||
|
||||
def create_invite_code(self, user_id: int) -> dict:
|
||||
code = self.repository.get_default_code(user_id)
|
||||
if code:
|
||||
return {"inviteCode": code.invite_code, "inviteLink": code.invite_link}
|
||||
code = self._create_default_code(user_id)
|
||||
return {"inviteCode": code.invite_code, "inviteLink": code.invite_link}
|
||||
|
||||
def list_relations(self, user_id: int) -> list[dict]:
|
||||
relations = self.repository.inviter_relations(user_id)
|
||||
users = self.repository.users_by_ids([item.invitee_user_id for item in relations])
|
||||
return [
|
||||
{
|
||||
"inviteeUserId": item.invitee_user_id,
|
||||
"inviteeNickname": users.get(item.invitee_user_id).nickname if users.get(item.invitee_user_id) else "",
|
||||
"rewardStatus": item.reward_status,
|
||||
"rewardPoints": item.reward_points,
|
||||
"createdAt": item.created_at.isoformat(),
|
||||
"rewardedAt": item.rewarded_at.isoformat() if item.rewarded_at else None,
|
||||
}
|
||||
for item in relations
|
||||
]
|
||||
|
||||
def list_rewards(self, user_id: int) -> list[dict]:
|
||||
relations = self.repository.inviter_relations(user_id)
|
||||
return [
|
||||
{
|
||||
"inviteeUserId": item.invitee_user_id,
|
||||
"rewardStatus": item.reward_status,
|
||||
"rewardPoints": item.reward_points,
|
||||
"rewardedAt": item.rewarded_at.isoformat() if item.rewarded_at else None,
|
||||
}
|
||||
for item in relations
|
||||
if item.reward_points > 0
|
||||
]
|
||||
|
||||
def _create_default_code(self, user_id: int) -> InviteCode:
|
||||
code_value = new_invite_code()
|
||||
code = InviteCode(
|
||||
user_id=user_id,
|
||||
invite_code=code_value,
|
||||
invite_link=f"http://localhost:3000/register?inviteCode={code_value}",
|
||||
status=1,
|
||||
is_default=True,
|
||||
)
|
||||
self.db.add(code)
|
||||
self.db.commit()
|
||||
self.db.refresh(code)
|
||||
return code
|
||||
|
||||
15
backend/app/modules/payments/repository.py
Normal file
15
backend/app/modules/payments/repository.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.entities import RechargeOrder
|
||||
|
||||
|
||||
class PaymentsRepository:
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
|
||||
def get_order_by_id(self, order_id: int) -> RechargeOrder | None:
|
||||
return self.db.scalar(select(RechargeOrder).where(RechargeOrder.id == order_id))
|
||||
|
||||
def list_orders(self):
|
||||
return self.db.query(RechargeOrder).order_by(RechargeOrder.id.desc())
|
||||
50
backend/app/modules/payments/router.py
Normal file
50
backend/app/modules/payments/router.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.common.db.session import get_db
|
||||
from app.common.responses.api_response import success_response
|
||||
from app.common.security.deps import require_admin_permission
|
||||
from app.modules.payments.service import PaymentsService
|
||||
from app.modules.wallets.service import WalletService
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/v1/payments", tags=["payments"])
|
||||
|
||||
|
||||
@router.post("/mock-notify")
|
||||
def mock_notify(order_no: str, db: Session = Depends(get_db)):
|
||||
return success_response(WalletService(db).handle_mock_payment(order_no))
|
||||
|
||||
|
||||
@router.get("/mock-pay")
|
||||
def mock_pay(orderNo: str, db: Session = Depends(get_db)):
|
||||
return success_response(WalletService(db).handle_mock_payment(orderNo))
|
||||
|
||||
|
||||
admin_router = APIRouter(prefix="/api/v1/admin/recharge-orders", tags=["admin-payments"])
|
||||
|
||||
|
||||
@admin_router.get("")
|
||||
def list_orders(
|
||||
_=Depends(require_admin_permission()),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return success_response(PaymentsService(db).list_orders())
|
||||
|
||||
|
||||
@admin_router.get("/{order_id}")
|
||||
def get_order_detail(
|
||||
order_id: int,
|
||||
_=Depends(require_admin_permission()),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return success_response(PaymentsService(db).get_order_detail(order_id))
|
||||
|
||||
|
||||
@admin_router.post("/{order_id}/repair")
|
||||
def repair_order(
|
||||
order_id: int,
|
||||
_=Depends(require_admin_permission()),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return success_response(PaymentsService(db).repair_order(order_id))
|
||||
10
backend/app/modules/payments/schema.py
Normal file
10
backend/app/modules/payments/schema.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class RepairOrderRequest(BaseModel):
|
||||
remark: str = "manual repair"
|
||||
|
||||
|
||||
class MockPaymentNotifyRequest(BaseModel):
|
||||
order_no: str = Field(alias="orderNo")
|
||||
|
||||
42
backend/app/modules/payments/service.py
Normal file
42
backend/app/modules/payments/service.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.common.errors.app_error import NotFoundAppError
|
||||
from app.modules.payments.repository import PaymentsRepository
|
||||
from app.modules.wallets.service import WalletService
|
||||
|
||||
|
||||
class PaymentsService:
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
self.repository = PaymentsRepository(db)
|
||||
self.wallet_service = WalletService(db)
|
||||
|
||||
def repair_order(self, order_id: int) -> dict:
|
||||
order = self.repository.get_order_by_id(order_id)
|
||||
if not order:
|
||||
raise NotFoundAppError("order not found", code=30001)
|
||||
return self.wallet_service.handle_mock_payment(order.order_no)
|
||||
|
||||
def list_orders(self) -> list[dict]:
|
||||
rows = self.repository.list_orders().limit(200).all()
|
||||
return [self.serialize_order(item) for item in rows]
|
||||
|
||||
def get_order_detail(self, order_id: int) -> dict:
|
||||
order = self.repository.get_order_by_id(order_id)
|
||||
if not order:
|
||||
raise NotFoundAppError("order not found", code=30001)
|
||||
return self.serialize_order(order)
|
||||
|
||||
@staticmethod
|
||||
def serialize_order(order) -> dict:
|
||||
return {
|
||||
"id": order.id,
|
||||
"orderNo": order.order_no,
|
||||
"userId": order.user_id,
|
||||
"payAmount": f"{order.pay_amount:.2f}",
|
||||
"arrivalPoints": order.arrival_points,
|
||||
"status": order.status,
|
||||
"paymentChannelCode": order.payment_channel_code,
|
||||
"paidAt": order.paid_at.isoformat() if order.paid_at else None,
|
||||
"createdAt": order.created_at.isoformat(),
|
||||
}
|
||||
16
backend/app/modules/pricing/repository.py
Normal file
16
backend/app/modules/pricing/repository.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.entities import PricingRule
|
||||
|
||||
|
||||
class PricingRepository:
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
|
||||
def list_rules(self):
|
||||
return self.db.query(PricingRule).order_by(PricingRule.id.desc())
|
||||
|
||||
def get_rule(self, rule_id: int) -> PricingRule | None:
|
||||
return self.db.scalar(select(PricingRule).where(PricingRule.id == rule_id))
|
||||
|
||||
48
backend/app/modules/pricing/router.py
Normal file
48
backend/app/modules/pricing/router.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.common.db.session import get_db
|
||||
from app.common.responses.api_response import success_response
|
||||
from app.common.security.deps import require_admin_permission
|
||||
from app.modules.pricing.schema import PricingRulePayload
|
||||
from app.modules.pricing.service import PricingService
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/v1/admin/pricing-rules", tags=["admin-pricing"])
|
||||
|
||||
|
||||
@router.get("")
|
||||
def list_pricing_rules(
|
||||
_=Depends(require_admin_permission()),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return success_response(PricingService(db).list_rules())
|
||||
|
||||
|
||||
@router.post("")
|
||||
def create_pricing_rule(
|
||||
payload: PricingRulePayload,
|
||||
_=Depends(require_admin_permission()),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return success_response(PricingService(db).create_rule(payload))
|
||||
|
||||
|
||||
@router.put("/{rule_id}")
|
||||
def update_pricing_rule(
|
||||
rule_id: int,
|
||||
payload: PricingRulePayload,
|
||||
_=Depends(require_admin_permission()),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return success_response(PricingService(db).update_rule(rule_id, payload))
|
||||
|
||||
|
||||
@router.post("/{rule_id}/publish")
|
||||
def publish_pricing_rule(
|
||||
rule_id: int,
|
||||
_=Depends(require_admin_permission()),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return success_response(PricingService(db).publish_rule(rule_id))
|
||||
|
||||
15
backend/app/modules/pricing/schema.py
Normal file
15
backend/app/modules/pricing/schema.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class PricingRulePayload(BaseModel):
|
||||
rule_name: str
|
||||
video_model_id: int
|
||||
points_per_second: int
|
||||
minimum_points: int
|
||||
effective_at: datetime
|
||||
expired_at: datetime | None = None
|
||||
version_no: int = 1
|
||||
status: int = 1
|
||||
|
||||
53
backend/app/modules/pricing/service.py
Normal file
53
backend/app/modules/pricing/service.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.common.errors.app_error import NotFoundAppError
|
||||
from app.models.entities import PricingRule
|
||||
from app.modules.pricing.repository import PricingRepository
|
||||
|
||||
|
||||
class PricingService:
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
self.repository = PricingRepository(db)
|
||||
|
||||
def list_rules(self) -> list[dict]:
|
||||
return [self.serialize(item) for item in self.repository.list_rules().all()]
|
||||
|
||||
def create_rule(self, payload) -> dict:
|
||||
item = PricingRule(**payload.model_dump())
|
||||
self.db.add(item)
|
||||
self.db.commit()
|
||||
self.db.refresh(item)
|
||||
return self.serialize(item)
|
||||
|
||||
def update_rule(self, rule_id: int, payload) -> dict:
|
||||
item = self.repository.get_rule(rule_id)
|
||||
if not item:
|
||||
raise NotFoundAppError("pricing rule not found", code=60004)
|
||||
for key, value in payload.model_dump().items():
|
||||
setattr(item, key, value)
|
||||
self.db.commit()
|
||||
return self.serialize(item)
|
||||
|
||||
def publish_rule(self, rule_id: int) -> dict:
|
||||
item = self.repository.get_rule(rule_id)
|
||||
if not item:
|
||||
raise NotFoundAppError("pricing rule not found", code=60004)
|
||||
item.status = 1
|
||||
self.db.commit()
|
||||
return self.serialize(item)
|
||||
|
||||
@staticmethod
|
||||
def serialize(item: PricingRule) -> dict:
|
||||
return {
|
||||
"id": item.id,
|
||||
"ruleName": item.rule_name,
|
||||
"videoModelId": item.video_model_id,
|
||||
"pointsPerSecond": item.points_per_second,
|
||||
"minimumPoints": item.minimum_points,
|
||||
"effectiveAt": item.effective_at.isoformat(),
|
||||
"expiredAt": item.expired_at.isoformat() if item.expired_at else None,
|
||||
"versionNo": item.version_no,
|
||||
"status": item.status,
|
||||
}
|
||||
|
||||
22
backend/app/modules/providers/repository.py
Normal file
22
backend/app/modules/providers/repository.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.entities import ProviderAccount, ProviderModel
|
||||
|
||||
|
||||
class ProvidersRepository:
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
|
||||
def list_accounts(self):
|
||||
return self.db.query(ProviderAccount).order_by(ProviderAccount.id.desc())
|
||||
|
||||
def get_account(self, account_id: int) -> ProviderAccount | None:
|
||||
return self.db.scalar(select(ProviderAccount).where(ProviderAccount.id == account_id))
|
||||
|
||||
def list_models(self):
|
||||
return self.db.query(ProviderModel).order_by(ProviderModel.id.desc())
|
||||
|
||||
def get_model(self, model_id: int) -> ProviderModel | None:
|
||||
return self.db.scalar(select(ProviderModel).where(ProviderModel.id == model_id))
|
||||
|
||||
66
backend/app/modules/providers/router.py
Normal file
66
backend/app/modules/providers/router.py
Normal file
@@ -0,0 +1,66 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.common.db.session import get_db
|
||||
from app.common.responses.api_response import success_response
|
||||
from app.common.security.deps import require_admin_permission
|
||||
from app.modules.providers.schema import ProviderAccountPayload, ProviderModelPayload
|
||||
from app.modules.providers.service import ProvidersService
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/v1/admin", tags=["admin-providers"])
|
||||
|
||||
|
||||
@router.get("/provider-accounts")
|
||||
def list_provider_accounts(
|
||||
_=Depends(require_admin_permission()),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return success_response(ProvidersService(db).list_accounts())
|
||||
|
||||
|
||||
@router.post("/provider-accounts")
|
||||
def create_provider_account(
|
||||
payload: ProviderAccountPayload,
|
||||
_=Depends(require_admin_permission()),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return success_response(ProvidersService(db).create_account(payload))
|
||||
|
||||
|
||||
@router.put("/provider-accounts/{account_id}")
|
||||
def update_provider_account(
|
||||
account_id: int,
|
||||
payload: ProviderAccountPayload,
|
||||
_=Depends(require_admin_permission()),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return success_response(ProvidersService(db).update_account(account_id, payload))
|
||||
|
||||
|
||||
@router.get("/provider-models")
|
||||
def list_provider_models(
|
||||
_=Depends(require_admin_permission()),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return success_response(ProvidersService(db).list_models())
|
||||
|
||||
|
||||
@router.post("/provider-models")
|
||||
def create_provider_model(
|
||||
payload: ProviderModelPayload,
|
||||
_=Depends(require_admin_permission()),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return success_response(ProvidersService(db).create_model(payload))
|
||||
|
||||
|
||||
@router.put("/provider-models/{model_id}")
|
||||
def update_provider_model(
|
||||
model_id: int,
|
||||
payload: ProviderModelPayload,
|
||||
_=Depends(require_admin_permission()),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return success_response(ProvidersService(db).update_model(model_id, payload))
|
||||
|
||||
35
backend/app/modules/providers/schema.py
Normal file
35
backend/app/modules/providers/schema.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ProviderAccountPayload(BaseModel):
|
||||
provider_code: str
|
||||
provider_name: str
|
||||
api_format: str
|
||||
base_url: str
|
||||
api_key: str
|
||||
api_secret: str | None = ""
|
||||
webhook_secret: str | None = ""
|
||||
timeout_seconds: int = 120
|
||||
max_retries: int = 3
|
||||
status: int = 1
|
||||
remark: str = ""
|
||||
|
||||
|
||||
class ProviderModelPayload(BaseModel):
|
||||
provider_account_id: int
|
||||
model_code: str
|
||||
model_name: str
|
||||
request_content_type: str = "application/json"
|
||||
supports_text_to_video: bool = True
|
||||
supports_image_to_video: bool = False
|
||||
supports_video_reference: bool = False
|
||||
supports_audio_reference: bool = False
|
||||
supports_generate_audio: bool = False
|
||||
supports_remix: bool = False
|
||||
supports_webhook: bool = False
|
||||
min_duration: int = 4
|
||||
max_duration: int = 12
|
||||
default_ratio: str = "16:9"
|
||||
default_resolution: str = "1280x720"
|
||||
status: int = 1
|
||||
|
||||
112
backend/app/modules/providers/service.py
Normal file
112
backend/app/modules/providers/service.py
Normal file
@@ -0,0 +1,112 @@
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.common.errors.app_error import NotFoundAppError
|
||||
from app.models.entities import ProviderAccount, ProviderModel
|
||||
from app.modules.providers.repository import ProvidersRepository
|
||||
|
||||
|
||||
class ProvidersService:
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
self.repository = ProvidersRepository(db)
|
||||
|
||||
def list_accounts(self) -> list[dict]:
|
||||
return [self.serialize_account(item) for item in self.repository.list_accounts().all()]
|
||||
|
||||
def create_account(self, payload) -> dict:
|
||||
item = ProviderAccount(
|
||||
provider_code=payload.provider_code,
|
||||
provider_name=payload.provider_name,
|
||||
api_format=payload.api_format,
|
||||
base_url=payload.base_url,
|
||||
api_key_encrypted=payload.api_key,
|
||||
api_secret_encrypted=payload.api_secret,
|
||||
webhook_secret_encrypted=payload.webhook_secret,
|
||||
timeout_seconds=payload.timeout_seconds,
|
||||
max_retries=payload.max_retries,
|
||||
status=payload.status,
|
||||
remark=payload.remark,
|
||||
)
|
||||
self.db.add(item)
|
||||
self.db.commit()
|
||||
self.db.refresh(item)
|
||||
return self.serialize_account(item)
|
||||
|
||||
def update_account(self, account_id: int, payload) -> dict:
|
||||
item = self.repository.get_account(account_id)
|
||||
if not item:
|
||||
raise NotFoundAppError("provider account not found", code=60001)
|
||||
item.provider_code = payload.provider_code
|
||||
item.provider_name = payload.provider_name
|
||||
item.api_format = payload.api_format
|
||||
item.base_url = payload.base_url
|
||||
item.api_key_encrypted = payload.api_key
|
||||
item.api_secret_encrypted = payload.api_secret
|
||||
item.webhook_secret_encrypted = payload.webhook_secret
|
||||
item.timeout_seconds = payload.timeout_seconds
|
||||
item.max_retries = payload.max_retries
|
||||
item.status = payload.status
|
||||
item.remark = payload.remark
|
||||
self.db.commit()
|
||||
return self.serialize_account(item)
|
||||
|
||||
def list_models(self) -> list[dict]:
|
||||
accounts = {item.id: item for item in self.repository.list_accounts().all()}
|
||||
return [self.serialize_model(item, accounts) for item in self.repository.list_models().all()]
|
||||
|
||||
def create_model(self, payload) -> dict:
|
||||
item = ProviderModel(**payload.model_dump())
|
||||
self.db.add(item)
|
||||
self.db.commit()
|
||||
self.db.refresh(item)
|
||||
account = self.repository.get_account(item.provider_account_id)
|
||||
return self.serialize_model(item, {account.id: account} if account else {})
|
||||
|
||||
def update_model(self, model_id: int, payload) -> dict:
|
||||
item = self.repository.get_model(model_id)
|
||||
if not item:
|
||||
raise NotFoundAppError("provider model not found", code=60002)
|
||||
for key, value in payload.model_dump().items():
|
||||
setattr(item, key, value)
|
||||
self.db.commit()
|
||||
account = self.repository.get_account(item.provider_account_id)
|
||||
return self.serialize_model(item, {account.id: account} if account else {})
|
||||
|
||||
@staticmethod
|
||||
def serialize_account(item: ProviderAccount) -> dict:
|
||||
return {
|
||||
"id": item.id,
|
||||
"providerCode": item.provider_code,
|
||||
"providerName": item.provider_name,
|
||||
"apiFormat": item.api_format,
|
||||
"baseUrl": item.base_url,
|
||||
"timeoutSeconds": item.timeout_seconds,
|
||||
"maxRetries": item.max_retries,
|
||||
"status": item.status,
|
||||
"remark": item.remark,
|
||||
"updatedAt": item.updated_at.isoformat(),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def serialize_model(item: ProviderModel, accounts: dict[int, ProviderAccount]) -> dict:
|
||||
account = accounts.get(item.provider_account_id)
|
||||
return {
|
||||
"id": item.id,
|
||||
"providerAccountId": item.provider_account_id,
|
||||
"providerName": account.provider_name if account else "",
|
||||
"modelCode": item.model_code,
|
||||
"modelName": item.model_name,
|
||||
"requestContentType": item.request_content_type,
|
||||
"supportsTextToVideo": item.supports_text_to_video,
|
||||
"supportsImageToVideo": item.supports_image_to_video,
|
||||
"supportsVideoReference": item.supports_video_reference,
|
||||
"supportsAudioReference": item.supports_audio_reference,
|
||||
"supportsGenerateAudio": item.supports_generate_audio,
|
||||
"supportsWebhook": item.supports_webhook,
|
||||
"minDuration": item.min_duration,
|
||||
"maxDuration": item.max_duration,
|
||||
"defaultRatio": item.default_ratio,
|
||||
"defaultResolution": item.default_resolution,
|
||||
"status": item.status,
|
||||
}
|
||||
|
||||
25
backend/app/modules/system/repository.py
Normal file
25
backend/app/modules/system/repository.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.entities import CallbackLog, RedeemCode, SystemConfig
|
||||
|
||||
|
||||
class SystemRepository:
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
|
||||
def list_configs(self):
|
||||
return self.db.query(SystemConfig).order_by(SystemConfig.group_name.asc(), SystemConfig.id.asc())
|
||||
|
||||
def get_config(self, config_key: str) -> SystemConfig | None:
|
||||
return self.db.scalar(select(SystemConfig).where(SystemConfig.config_key == config_key))
|
||||
|
||||
def list_redeem_codes(self):
|
||||
return self.db.query(RedeemCode).order_by(RedeemCode.id.desc())
|
||||
|
||||
def get_redeem_code(self, redeem_id: int) -> RedeemCode | None:
|
||||
return self.db.scalar(select(RedeemCode).where(RedeemCode.id == redeem_id))
|
||||
|
||||
def list_callback_logs(self):
|
||||
return self.db.query(CallbackLog).order_by(CallbackLog.id.desc())
|
||||
|
||||
72
backend/app/modules/system/router.py
Normal file
72
backend/app/modules/system/router.py
Normal file
@@ -0,0 +1,72 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.common.db.session import get_db
|
||||
from app.common.responses.api_response import success_response
|
||||
from app.common.security.deps import require_admin_permission
|
||||
from app.modules.system.schema import RedeemBatchCreatePayload, SystemConfigItemPayload
|
||||
from app.modules.system.service import SystemService
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/v1/admin", tags=["admin-system"])
|
||||
|
||||
|
||||
@router.get("/system-config")
|
||||
def list_system_configs(
|
||||
_=Depends(require_admin_permission()),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return success_response(SystemService(db).list_configs())
|
||||
|
||||
|
||||
@router.put("/system-config")
|
||||
def upsert_system_config(
|
||||
payload: SystemConfigItemPayload,
|
||||
_=Depends(require_admin_permission()),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return success_response(SystemService(db).upsert_config(payload))
|
||||
|
||||
|
||||
@router.get("/redeem-codes")
|
||||
def list_redeem_codes(
|
||||
_=Depends(require_admin_permission()),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return success_response(SystemService(db).list_redeem_codes())
|
||||
|
||||
|
||||
@router.post("/redeem-codes/batch-create")
|
||||
def batch_create_redeem_codes(
|
||||
payload: RedeemBatchCreatePayload,
|
||||
_=Depends(require_admin_permission()),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return success_response(SystemService(db).batch_create_redeem_codes(payload))
|
||||
|
||||
|
||||
@router.post("/redeem-codes/import")
|
||||
def import_redeem_codes(
|
||||
payload: RedeemBatchCreatePayload,
|
||||
_=Depends(require_admin_permission()),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return success_response(SystemService(db).batch_create_redeem_codes(payload))
|
||||
|
||||
|
||||
@router.put("/redeem-codes/{redeem_id}/disable")
|
||||
def disable_redeem_code(
|
||||
redeem_id: int,
|
||||
_=Depends(require_admin_permission()),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return success_response(SystemService(db).disable_redeem_code(redeem_id))
|
||||
|
||||
|
||||
@router.get("/callback-logs")
|
||||
def list_callback_logs(
|
||||
_=Depends(require_admin_permission()),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return success_response(SystemService(db).list_callback_logs())
|
||||
|
||||
18
backend/app/modules/system/schema.py
Normal file
18
backend/app/modules/system/schema.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class SystemConfigItemPayload(BaseModel):
|
||||
config_key: str
|
||||
config_value: str
|
||||
value_type: str = "string"
|
||||
group_name: str = "default"
|
||||
description: str = ""
|
||||
is_public: bool = False
|
||||
|
||||
|
||||
class RedeemBatchCreatePayload(BaseModel):
|
||||
batch_no: str
|
||||
points: int
|
||||
quantity: int
|
||||
remark: str = ""
|
||||
|
||||
94
backend/app/modules/system/service.py
Normal file
94
backend/app/modules/system/service.py
Normal file
@@ -0,0 +1,94 @@
|
||||
from app.common.errors.app_error import NotFoundAppError
|
||||
from app.common.utils.id_gen import new_invite_code
|
||||
from app.models.entities import RedeemCode, SystemConfig
|
||||
from app.modules.system.repository import SystemRepository
|
||||
|
||||
|
||||
class SystemService:
|
||||
def __init__(self, db) -> None:
|
||||
self.db = db
|
||||
self.repository = SystemRepository(db)
|
||||
|
||||
def list_configs(self) -> list[dict]:
|
||||
return [
|
||||
{
|
||||
"configKey": item.config_key,
|
||||
"configValue": item.config_value,
|
||||
"valueType": item.value_type,
|
||||
"groupName": item.group_name,
|
||||
"description": item.description,
|
||||
"isPublic": item.is_public,
|
||||
}
|
||||
for item in self.repository.list_configs().all()
|
||||
]
|
||||
|
||||
def upsert_config(self, payload) -> dict:
|
||||
item = self.repository.get_config(payload.config_key)
|
||||
if not item:
|
||||
item = SystemConfig(**payload.model_dump())
|
||||
self.db.add(item)
|
||||
else:
|
||||
for key, value in payload.model_dump().items():
|
||||
setattr(item, key, value)
|
||||
self.db.commit()
|
||||
return {
|
||||
"configKey": item.config_key,
|
||||
"configValue": item.config_value,
|
||||
"groupName": item.group_name,
|
||||
}
|
||||
|
||||
def list_redeem_codes(self) -> list[dict]:
|
||||
return [
|
||||
{
|
||||
"id": item.id,
|
||||
"batchNo": item.batch_no,
|
||||
"redeemCode": item.redeem_code,
|
||||
"points": item.points,
|
||||
"status": item.status,
|
||||
"usedByUserId": item.used_by_user_id,
|
||||
"usedAt": item.used_at.isoformat() if item.used_at else None,
|
||||
}
|
||||
for item in self.repository.list_redeem_codes().all()
|
||||
]
|
||||
|
||||
def batch_create_redeem_codes(self, payload) -> list[dict]:
|
||||
created = []
|
||||
for _ in range(payload.quantity):
|
||||
item = RedeemCode(
|
||||
batch_no=payload.batch_no,
|
||||
redeem_code=f"{payload.batch_no}-{new_invite_code(4)}-{new_invite_code(4)}",
|
||||
points=payload.points,
|
||||
status="unused",
|
||||
remark=payload.remark,
|
||||
)
|
||||
self.db.add(item)
|
||||
created.append(item)
|
||||
self.db.commit()
|
||||
return self.list_redeem_codes()[: payload.quantity]
|
||||
|
||||
def disable_redeem_code(self, redeem_id: int) -> dict:
|
||||
item = self.repository.get_redeem_code(redeem_id)
|
||||
if not item:
|
||||
raise NotFoundAppError("redeem code not found", code=70020)
|
||||
item.status = "disabled"
|
||||
self.db.commit()
|
||||
return {
|
||||
"id": item.id,
|
||||
"status": item.status,
|
||||
}
|
||||
|
||||
def list_callback_logs(self) -> list[dict]:
|
||||
return [
|
||||
{
|
||||
"id": item.id,
|
||||
"sourceType": item.source_type,
|
||||
"sourceCode": item.source_code,
|
||||
"relatedNo": item.related_no,
|
||||
"verifyStatus": item.verify_status,
|
||||
"processStatus": item.process_status,
|
||||
"errorMessage": item.error_message,
|
||||
"createdAt": item.created_at.isoformat(),
|
||||
}
|
||||
for item in self.repository.list_callback_logs().all()
|
||||
]
|
||||
|
||||
16
backend/app/modules/users/repository.py
Normal file
16
backend/app/modules/users/repository.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.entities import User
|
||||
|
||||
|
||||
class UsersRepository:
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
|
||||
def get_by_id(self, user_id: int) -> User | None:
|
||||
return self.db.scalar(select(User).where(User.id == user_id))
|
||||
|
||||
def get_by_username(self, username: str) -> User | None:
|
||||
return self.db.scalar(select(User).where(User.username == username))
|
||||
|
||||
37
backend/app/modules/users/router.py
Normal file
37
backend/app/modules/users/router.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.common.db.session import get_db
|
||||
from app.common.responses.api_response import success_response
|
||||
from app.common.security.deps import get_current_user
|
||||
from app.models.entities import User
|
||||
from app.modules.users.schema import UpdateProfileRequest
|
||||
from app.modules.users.service import UsersService
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/v1/profile", tags=["profile"])
|
||||
|
||||
|
||||
@router.get("")
|
||||
def get_profile(current_user: User = Depends(get_current_user)):
|
||||
return success_response(
|
||||
{
|
||||
"publicId": current_user.public_id,
|
||||
"username": current_user.username or "",
|
||||
"nickname": current_user.nickname,
|
||||
"avatarUrl": current_user.avatar_url,
|
||||
"email": current_user.email or "",
|
||||
"mobile": current_user.mobile or "",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@router.put("")
|
||||
def update_profile(
|
||||
payload: UpdateProfileRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
data = UsersService(db).update_profile(current_user, payload)
|
||||
return success_response(data)
|
||||
|
||||
8
backend/app/modules/users/schema.py
Normal file
8
backend/app/modules/users/schema.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class UpdateProfileRequest(BaseModel):
|
||||
username: str | None = Field(default=None, min_length=3, max_length=32)
|
||||
nickname: str | None = Field(default=None, min_length=1, max_length=32)
|
||||
avatar_url: str | None = Field(default=None, max_length=500)
|
||||
|
||||
31
backend/app/modules/users/service.py
Normal file
31
backend/app/modules/users/service.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.common.errors.app_error import ConflictAppError
|
||||
from app.models.entities import User
|
||||
from app.modules.users.repository import UsersRepository
|
||||
|
||||
|
||||
class UsersService:
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
self.repository = UsersRepository(db)
|
||||
|
||||
def update_profile(self, user: User, payload) -> dict:
|
||||
if payload.username and payload.username != user.username:
|
||||
existing = self.repository.get_by_username(payload.username)
|
||||
if existing and existing.id != user.id:
|
||||
raise ConflictAppError("username already exists", code=10011)
|
||||
user.username = payload.username
|
||||
if payload.nickname is not None:
|
||||
user.nickname = payload.nickname
|
||||
if payload.avatar_url is not None:
|
||||
user.avatar_url = payload.avatar_url
|
||||
self.db.commit()
|
||||
return {
|
||||
"publicId": user.public_id,
|
||||
"username": user.username or "",
|
||||
"nickname": user.nickname,
|
||||
"avatarUrl": user.avatar_url,
|
||||
"email": user.email or "",
|
||||
}
|
||||
|
||||
49
backend/app/modules/video_models/repository.py
Normal file
49
backend/app/modules/video_models/repository.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import and_, or_, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.entities import PricingRule, ProviderModel, VideoModel, VideoModelSupplierBinding
|
||||
|
||||
|
||||
class VideoModelsRepository:
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
|
||||
def list_video_models(self):
|
||||
return self.db.query(VideoModel).order_by(VideoModel.sort_order.asc(), VideoModel.id.asc())
|
||||
|
||||
def get_video_model(self, model_id: int) -> VideoModel | None:
|
||||
return self.db.scalar(select(VideoModel).where(VideoModel.id == model_id))
|
||||
|
||||
def list_bindings(self):
|
||||
return (
|
||||
self.db.query(VideoModelSupplierBinding)
|
||||
.order_by(
|
||||
VideoModelSupplierBinding.video_model_id.asc(),
|
||||
VideoModelSupplierBinding.routing_priority.asc(),
|
||||
)
|
||||
)
|
||||
|
||||
def get_binding(self, binding_id: int) -> VideoModelSupplierBinding | None:
|
||||
return self.db.scalar(
|
||||
select(VideoModelSupplierBinding).where(VideoModelSupplierBinding.id == binding_id)
|
||||
)
|
||||
|
||||
def active_pricing_rule(self, video_model_id: int) -> PricingRule | None:
|
||||
now = datetime.utcnow()
|
||||
return self.db.scalar(
|
||||
select(PricingRule)
|
||||
.where(
|
||||
PricingRule.video_model_id == video_model_id,
|
||||
PricingRule.status == 1,
|
||||
PricingRule.effective_at <= now,
|
||||
or_(PricingRule.expired_at.is_(None), PricingRule.expired_at > now),
|
||||
)
|
||||
.order_by(PricingRule.version_no.desc(), PricingRule.id.desc())
|
||||
)
|
||||
|
||||
def provider_models(self) -> dict[int, ProviderModel]:
|
||||
rows = self.db.scalars(select(ProviderModel)).all()
|
||||
return {row.id: row for row in rows}
|
||||
|
||||
71
backend/app/modules/video_models/router.py
Normal file
71
backend/app/modules/video_models/router.py
Normal file
@@ -0,0 +1,71 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.common.db.session import get_db
|
||||
from app.common.responses.api_response import success_response
|
||||
from app.common.security.deps import require_admin_permission
|
||||
from app.modules.video_models.schema import BindingPayload, VideoModelPayload
|
||||
from app.modules.video_models.service import VideoModelsService
|
||||
|
||||
|
||||
router = APIRouter(tags=["video-models"])
|
||||
|
||||
|
||||
@router.get("/api/v1/video-models")
|
||||
def list_public_video_models(db: Session = Depends(get_db)):
|
||||
return success_response(VideoModelsService(db).list_public_models())
|
||||
|
||||
|
||||
@router.get("/api/v1/admin/video-models")
|
||||
def list_admin_video_models(
|
||||
_=Depends(require_admin_permission()),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return success_response(VideoModelsService(db).list_admin_models())
|
||||
|
||||
|
||||
@router.post("/api/v1/admin/video-models")
|
||||
def create_video_model(
|
||||
payload: VideoModelPayload,
|
||||
_=Depends(require_admin_permission()),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return success_response(VideoModelsService(db).create_model(payload))
|
||||
|
||||
|
||||
@router.put("/api/v1/admin/video-models/{model_id}")
|
||||
def update_video_model(
|
||||
model_id: int,
|
||||
payload: VideoModelPayload,
|
||||
_=Depends(require_admin_permission()),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return success_response(VideoModelsService(db).update_model(model_id, payload))
|
||||
|
||||
|
||||
@router.get("/api/v1/admin/video-model-bindings")
|
||||
def list_bindings(
|
||||
_=Depends(require_admin_permission()),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return success_response(VideoModelsService(db).list_bindings())
|
||||
|
||||
|
||||
@router.post("/api/v1/admin/video-model-bindings")
|
||||
def create_binding(
|
||||
payload: BindingPayload,
|
||||
_=Depends(require_admin_permission()),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return success_response(VideoModelsService(db).create_binding(payload))
|
||||
|
||||
|
||||
@router.put("/api/v1/admin/video-model-bindings/{binding_id}")
|
||||
def update_binding(
|
||||
binding_id: int,
|
||||
payload: BindingPayload,
|
||||
_=Depends(require_admin_permission()),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return success_response(VideoModelsService(db).update_binding(binding_id, payload))
|
||||
|
||||
23
backend/app/modules/video_models/schema.py
Normal file
23
backend/app/modules/video_models/schema.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class VideoModelPayload(BaseModel):
|
||||
model_key: str
|
||||
model_name: str
|
||||
frontend_title: str
|
||||
frontend_description: str = ""
|
||||
default_duration_seconds: int = 8
|
||||
default_ratio: str = "16:9"
|
||||
default_resolution: str = "1280x720"
|
||||
status: int = 1
|
||||
sort_order: int = 0
|
||||
|
||||
|
||||
class BindingPayload(BaseModel):
|
||||
video_model_id: int
|
||||
provider_model_id: int
|
||||
routing_priority: int = 10
|
||||
is_primary: bool = False
|
||||
status: int = 1
|
||||
timeout_seconds_override: int | None = None
|
||||
|
||||
113
backend/app/modules/video_models/service.py
Normal file
113
backend/app/modules/video_models/service.py
Normal file
@@ -0,0 +1,113 @@
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.common.errors.app_error import NotFoundAppError
|
||||
from app.models.entities import ProviderModel, VideoModel, VideoModelSupplierBinding
|
||||
from app.modules.video_models.repository import VideoModelsRepository
|
||||
|
||||
|
||||
class VideoModelsService:
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
self.repository = VideoModelsRepository(db)
|
||||
|
||||
def list_public_models(self) -> list[dict]:
|
||||
items = []
|
||||
for item in self.repository.list_video_models().filter(VideoModel.status == 1).all():
|
||||
pricing = self.repository.active_pricing_rule(item.id)
|
||||
items.append(self.serialize_model(item, pricing))
|
||||
return items
|
||||
|
||||
def list_admin_models(self) -> list[dict]:
|
||||
return [
|
||||
self.serialize_model(item, self.repository.active_pricing_rule(item.id))
|
||||
for item in self.repository.list_video_models().all()
|
||||
]
|
||||
|
||||
def create_model(self, payload) -> dict:
|
||||
item = VideoModel(**payload.model_dump())
|
||||
self.db.add(item)
|
||||
self.db.commit()
|
||||
self.db.refresh(item)
|
||||
return self.serialize_model(item, None)
|
||||
|
||||
def update_model(self, model_id: int, payload) -> dict:
|
||||
item = self.repository.get_video_model(model_id)
|
||||
if not item:
|
||||
raise NotFoundAppError("video model not found", code=50001)
|
||||
for key, value in payload.model_dump().items():
|
||||
setattr(item, key, value)
|
||||
self.db.commit()
|
||||
return self.serialize_model(item, self.repository.active_pricing_rule(item.id))
|
||||
|
||||
def list_bindings(self) -> list[dict]:
|
||||
provider_models = self.repository.provider_models()
|
||||
video_models = {
|
||||
item.id: item for item in self.repository.list_video_models().all()
|
||||
}
|
||||
return [
|
||||
self.serialize_binding(item, provider_models, video_models)
|
||||
for item in self.repository.list_bindings().all()
|
||||
]
|
||||
|
||||
def create_binding(self, payload) -> dict:
|
||||
item = VideoModelSupplierBinding(**payload.model_dump())
|
||||
self.db.add(item)
|
||||
self.db.commit()
|
||||
self.db.refresh(item)
|
||||
return self._serialize_binding_single(item)
|
||||
|
||||
def update_binding(self, binding_id: int, payload) -> dict:
|
||||
item = self.repository.get_binding(binding_id)
|
||||
if not item:
|
||||
raise NotFoundAppError("binding not found", code=60003)
|
||||
for key, value in payload.model_dump().items():
|
||||
setattr(item, key, value)
|
||||
self.db.commit()
|
||||
return self._serialize_binding_single(item)
|
||||
|
||||
@staticmethod
|
||||
def serialize_model(item: VideoModel, pricing) -> dict:
|
||||
return {
|
||||
"id": item.id,
|
||||
"modelKey": item.model_key,
|
||||
"modelName": item.model_name,
|
||||
"frontendTitle": item.frontend_title,
|
||||
"frontendDescription": item.frontend_description,
|
||||
"defaultDurationSeconds": item.default_duration_seconds,
|
||||
"defaultRatio": item.default_ratio,
|
||||
"defaultResolution": item.default_resolution,
|
||||
"status": item.status,
|
||||
"sortOrder": item.sort_order,
|
||||
"pricing": {
|
||||
"pointsPerSecond": pricing.points_per_second if pricing else 0,
|
||||
"minimumPoints": pricing.minimum_points if pricing else 0,
|
||||
},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def serialize_binding(
|
||||
item: VideoModelSupplierBinding,
|
||||
provider_models: dict[int, ProviderModel],
|
||||
video_models: dict[int, VideoModel],
|
||||
) -> dict:
|
||||
provider_model = provider_models.get(item.provider_model_id)
|
||||
video_model = video_models.get(item.video_model_id)
|
||||
return {
|
||||
"id": item.id,
|
||||
"videoModelId": item.video_model_id,
|
||||
"videoModelName": video_model.model_name if video_model else "",
|
||||
"providerModelId": item.provider_model_id,
|
||||
"providerModelName": provider_model.model_name if provider_model else "",
|
||||
"routingPriority": item.routing_priority,
|
||||
"isPrimary": item.is_primary,
|
||||
"status": item.status,
|
||||
"timeoutSecondsOverride": item.timeout_seconds_override,
|
||||
}
|
||||
|
||||
def _serialize_binding_single(self, item: VideoModelSupplierBinding) -> dict:
|
||||
provider_models = self.repository.provider_models()
|
||||
video_models = {
|
||||
row.id: row for row in self.repository.list_video_models().all()
|
||||
}
|
||||
return self.serialize_binding(item, provider_models, video_models)
|
||||
|
||||
95
backend/app/modules/video_tasks/repository.py
Normal file
95
backend/app/modules/video_tasks/repository.py
Normal file
@@ -0,0 +1,95 @@
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import or_, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.entities import (
|
||||
MediaAsset,
|
||||
PricingRule,
|
||||
ProviderAccount,
|
||||
ProviderModel,
|
||||
VideoGenerationTask,
|
||||
VideoModel,
|
||||
VideoModelSupplierBinding,
|
||||
VideoTaskEvent,
|
||||
)
|
||||
|
||||
|
||||
class VideoTasksRepository:
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
|
||||
def get_video_model(self, model_id: int) -> VideoModel | None:
|
||||
return self.db.scalar(select(VideoModel).where(VideoModel.id == model_id))
|
||||
|
||||
def get_active_pricing(self, video_model_id: int) -> PricingRule | None:
|
||||
now = datetime.utcnow()
|
||||
return self.db.scalar(
|
||||
select(PricingRule)
|
||||
.where(
|
||||
PricingRule.video_model_id == video_model_id,
|
||||
PricingRule.status == 1,
|
||||
PricingRule.effective_at <= now,
|
||||
or_(PricingRule.expired_at.is_(None), PricingRule.expired_at > now),
|
||||
)
|
||||
.order_by(PricingRule.version_no.desc(), PricingRule.id.desc())
|
||||
)
|
||||
|
||||
def get_bindings(self, video_model_id: int) -> list[VideoModelSupplierBinding]:
|
||||
return (
|
||||
self.db.query(VideoModelSupplierBinding)
|
||||
.filter(
|
||||
VideoModelSupplierBinding.video_model_id == video_model_id,
|
||||
VideoModelSupplierBinding.status == 1,
|
||||
)
|
||||
.order_by(
|
||||
VideoModelSupplierBinding.is_primary.desc(),
|
||||
VideoModelSupplierBinding.routing_priority.asc(),
|
||||
VideoModelSupplierBinding.id.asc(),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
def get_provider_model(self, provider_model_id: int) -> ProviderModel | None:
|
||||
return self.db.scalar(select(ProviderModel).where(ProviderModel.id == provider_model_id))
|
||||
|
||||
def get_provider_account(self, provider_account_id: int) -> ProviderAccount | None:
|
||||
return self.db.scalar(select(ProviderAccount).where(ProviderAccount.id == provider_account_id))
|
||||
|
||||
def list_assets(self, user_id: int, asset_ids: list[int]) -> list[MediaAsset]:
|
||||
if not asset_ids:
|
||||
return []
|
||||
return (
|
||||
self.db.query(MediaAsset)
|
||||
.filter(
|
||||
MediaAsset.user_id == user_id,
|
||||
MediaAsset.id.in_(asset_ids),
|
||||
MediaAsset.status == "active",
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
def list_tasks(self, user_id: int):
|
||||
return (
|
||||
self.db.query(VideoGenerationTask)
|
||||
.filter(VideoGenerationTask.user_id == user_id, VideoGenerationTask.user_visible == 1)
|
||||
.order_by(VideoGenerationTask.id.desc())
|
||||
)
|
||||
|
||||
def get_task(self, user_id: int, task_no: str) -> VideoGenerationTask | None:
|
||||
return self.db.scalar(
|
||||
select(VideoGenerationTask).where(
|
||||
VideoGenerationTask.user_id == user_id,
|
||||
VideoGenerationTask.task_no == task_no,
|
||||
)
|
||||
)
|
||||
|
||||
def get_task_by_id(self, task_id: int) -> VideoGenerationTask | None:
|
||||
return self.db.scalar(select(VideoGenerationTask).where(VideoGenerationTask.id == task_id))
|
||||
|
||||
def task_events(self, task_id: int):
|
||||
return (
|
||||
self.db.query(VideoTaskEvent)
|
||||
.filter(VideoTaskEvent.video_task_id == task_id)
|
||||
.order_by(VideoTaskEvent.id.asc())
|
||||
)
|
||||
104
backend/app/modules/video_tasks/router.py
Normal file
104
backend/app/modules/video_tasks/router.py
Normal file
@@ -0,0 +1,104 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.common.db.session import get_db
|
||||
from app.common.responses.api_response import success_response
|
||||
from app.common.security.deps import get_current_user, require_admin_permission
|
||||
from app.models.entities import User
|
||||
from app.modules.video_tasks.schema import CreateVideoTaskRequest
|
||||
from app.modules.video_tasks.service import VideoTasksService
|
||||
|
||||
|
||||
router = APIRouter(tags=["video-tasks"])
|
||||
|
||||
|
||||
@router.post("/api/v1/video-tasks")
|
||||
def create_task(
|
||||
payload: CreateVideoTaskRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return success_response(VideoTasksService(db).create_task(current_user.id, payload))
|
||||
|
||||
|
||||
@router.get("/api/v1/video-tasks")
|
||||
def list_tasks(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return success_response(VideoTasksService(db).list_tasks(current_user.id))
|
||||
|
||||
|
||||
@router.get("/api/v1/video-tasks/{task_no}")
|
||||
def get_task_detail(
|
||||
task_no: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return success_response(VideoTasksService(db).get_task_detail(current_user.id, task_no))
|
||||
|
||||
|
||||
@router.post("/api/v1/video-tasks/{task_no}/retry")
|
||||
def retry_task(
|
||||
task_no: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return success_response(VideoTasksService(db).retry_task(current_user.id, task_no))
|
||||
|
||||
|
||||
@router.post("/api/v1/video-tasks/{task_no}/cancel")
|
||||
def cancel_task(
|
||||
task_no: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return success_response(VideoTasksService(db).cancel_task(current_user.id, task_no))
|
||||
|
||||
|
||||
@router.delete("/api/v1/video-tasks/{task_no}")
|
||||
def delete_task(
|
||||
task_no: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return success_response(VideoTasksService(db).delete_task(current_user.id, task_no))
|
||||
|
||||
|
||||
admin_router = APIRouter(prefix="/api/v1/admin/video-tasks", tags=["admin-video-tasks"])
|
||||
|
||||
|
||||
@admin_router.get("")
|
||||
def admin_list_tasks(
|
||||
_=Depends(require_admin_permission()),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return success_response(VideoTasksService(db).admin_list_tasks())
|
||||
|
||||
|
||||
@admin_router.get("/{task_id}")
|
||||
def admin_get_task(
|
||||
task_id: int,
|
||||
_=Depends(require_admin_permission()),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return success_response(VideoTasksService(db).admin_get_task(task_id))
|
||||
|
||||
|
||||
@admin_router.post("/{task_id}/retry")
|
||||
def admin_retry_task(
|
||||
task_id: int,
|
||||
_=Depends(require_admin_permission()),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return success_response(VideoTasksService(db).admin_retry_task(task_id))
|
||||
|
||||
|
||||
@admin_router.post("/{task_id}/refund")
|
||||
def admin_refund_task(
|
||||
task_id: int,
|
||||
_=Depends(require_admin_permission()),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return success_response(VideoTasksService(db).admin_refund_task(task_id))
|
||||
|
||||
13
backend/app/modules/video_tasks/schema.py
Normal file
13
backend/app/modules/video_tasks/schema.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class CreateVideoTaskRequest(BaseModel):
|
||||
videoModelId: int
|
||||
prompt: str = Field(min_length=1, max_length=4000)
|
||||
durationSeconds: int = Field(ge=4, le=15)
|
||||
resolution: str = "1280x720"
|
||||
ratio: str = "16:9"
|
||||
generateAudio: bool = False
|
||||
referenceImageAssetIds: list[int] = Field(default_factory=list)
|
||||
referenceVideoAssetIds: list[int] = Field(default_factory=list)
|
||||
referenceAudioAssetIds: list[int] = Field(default_factory=list)
|
||||
343
backend/app/modules/video_tasks/service.py
Normal file
343
backend/app/modules/video_tasks/service.py
Normal file
@@ -0,0 +1,343 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.common.errors.app_error import BusinessAppError, NotFoundAppError
|
||||
from app.common.utils.id_gen import new_order_no
|
||||
from app.core.providers import build_adapter
|
||||
from app.core.storage import storage_service
|
||||
from app.models.entities import MediaAsset, VideoGenerationTask, VideoTaskEvent
|
||||
from app.modules.video_tasks.repository import VideoTasksRepository
|
||||
from app.modules.wallets.service import WalletService
|
||||
|
||||
|
||||
class VideoTasksService:
|
||||
FINAL_STATUSES = {"succeeded", "failed", "cancelled", "timed_out"}
|
||||
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
self.repository = VideoTasksRepository(db)
|
||||
self.wallet_service = WalletService(db)
|
||||
|
||||
def create_task(self, user_id: int, payload) -> dict:
|
||||
video_model = self.repository.get_video_model(payload.videoModelId)
|
||||
if not video_model or video_model.status != 1:
|
||||
raise BusinessAppError("video model unavailable", code=50002)
|
||||
pricing = self.repository.get_active_pricing(video_model.id)
|
||||
if not pricing:
|
||||
raise BusinessAppError("pricing rule unavailable", code=50003)
|
||||
binding = self._select_binding(video_model.id)
|
||||
provider_model = self.repository.get_provider_model(binding.provider_model_id)
|
||||
provider_account = self.repository.get_provider_account(provider_model.provider_account_id)
|
||||
estimated_points = max(
|
||||
pricing.minimum_points,
|
||||
pricing.points_per_second * payload.durationSeconds,
|
||||
)
|
||||
normalized = self._build_normalized_payload(user_id, payload, provider_model.id, provider_account.id)
|
||||
task = VideoGenerationTask(
|
||||
task_no=new_order_no("vt"),
|
||||
user_id=user_id,
|
||||
video_model_id=video_model.id,
|
||||
provider_account_id=provider_account.id,
|
||||
provider_model_id=provider_model.id,
|
||||
provider_binding_id=binding.id,
|
||||
pricing_rule_id=pricing.id,
|
||||
task_status="queued",
|
||||
generation_mode=self._infer_generation_mode(payload),
|
||||
prompt_text=payload.prompt,
|
||||
request_payload=normalized,
|
||||
duration_seconds=payload.durationSeconds,
|
||||
ratio=payload.ratio,
|
||||
resolution=payload.resolution,
|
||||
generate_audio=payload.generateAudio,
|
||||
estimated_points=estimated_points,
|
||||
frozen_points=estimated_points,
|
||||
)
|
||||
self.db.add(task)
|
||||
self.db.flush()
|
||||
self.wallet_service.freeze_points(
|
||||
user_id,
|
||||
estimated_points,
|
||||
related_type="video_task",
|
||||
related_id=task.id,
|
||||
remark=f"freeze for {task.task_no}",
|
||||
)
|
||||
self._add_event(task.id, "created", "task created")
|
||||
self._submit_task(task, provider_account, provider_model)
|
||||
self.db.commit()
|
||||
return {
|
||||
"taskNo": task.task_no,
|
||||
"taskStatus": task.task_status,
|
||||
"estimatedPoints": task.estimated_points,
|
||||
"frozenPoints": task.frozen_points,
|
||||
}
|
||||
|
||||
def list_tasks(self, user_id: int) -> list[dict]:
|
||||
tasks = self.repository.list_tasks(user_id).limit(100).all()
|
||||
for task in tasks:
|
||||
self._refresh_task_progress(task)
|
||||
self.db.commit()
|
||||
return [self.serialize_task_summary(task) for task in tasks if task.user_visible]
|
||||
|
||||
def get_task_detail(self, user_id: int, task_no: str) -> dict:
|
||||
task = self.repository.get_task(user_id, task_no)
|
||||
if not task:
|
||||
raise NotFoundAppError("task not found", code=50006)
|
||||
self._refresh_task_progress(task)
|
||||
self.db.commit()
|
||||
return self.serialize_task_detail(task)
|
||||
|
||||
def retry_task(self, user_id: int, task_no: str) -> dict:
|
||||
task = self.repository.get_task(user_id, task_no)
|
||||
if not task:
|
||||
raise NotFoundAppError("task not found", code=50006)
|
||||
payload = task.request_payload
|
||||
class RetryPayload:
|
||||
videoModelId = payload["videoModelId"]
|
||||
prompt = payload["prompt"]
|
||||
durationSeconds = payload["durationSeconds"]
|
||||
resolution = payload["resolution"]
|
||||
ratio = payload["ratio"]
|
||||
generateAudio = payload["generateAudio"]
|
||||
referenceImageAssetIds = payload.get("referenceImageAssetIds", [])
|
||||
referenceVideoAssetIds = payload.get("referenceVideoAssetIds", [])
|
||||
referenceAudioAssetIds = payload.get("referenceAudioAssetIds", [])
|
||||
|
||||
return self.create_task(user_id, RetryPayload)
|
||||
|
||||
def cancel_task(self, user_id: int, task_no: str) -> dict:
|
||||
task = self.repository.get_task(user_id, task_no)
|
||||
if not task:
|
||||
raise NotFoundAppError("task not found", code=50006)
|
||||
if task.task_status in self.FINAL_STATUSES:
|
||||
return {"taskNo": task.task_no, "taskStatus": task.task_status}
|
||||
task.task_status = "cancelled"
|
||||
task.finished_at = datetime.utcnow()
|
||||
self.wallet_service.release_frozen_points(
|
||||
user_id,
|
||||
task.frozen_points,
|
||||
related_type="video_task",
|
||||
related_id=task.id,
|
||||
remark=f"cancel {task.task_no}",
|
||||
)
|
||||
self._add_event(task.id, "cancelled", "task cancelled")
|
||||
self.db.commit()
|
||||
return {"taskNo": task.task_no, "taskStatus": task.task_status}
|
||||
|
||||
def delete_task(self, user_id: int, task_no: str) -> dict:
|
||||
task = self.repository.get_task(user_id, task_no)
|
||||
if not task:
|
||||
raise NotFoundAppError("task not found", code=50006)
|
||||
task.user_visible = False
|
||||
task.user_deleted_at = datetime.utcnow()
|
||||
self.db.commit()
|
||||
return {"taskNo": task.task_no, "deleted": True}
|
||||
|
||||
def admin_list_tasks(self) -> list[dict]:
|
||||
tasks = self.db.query(VideoGenerationTask).order_by(VideoGenerationTask.id.desc()).limit(200).all()
|
||||
for task in tasks:
|
||||
self._refresh_task_progress(task)
|
||||
self.db.commit()
|
||||
return [self.serialize_task_summary(task) for task in tasks]
|
||||
|
||||
def admin_get_task(self, task_id: int) -> dict:
|
||||
task = self.repository.get_task_by_id(task_id)
|
||||
if not task:
|
||||
raise NotFoundAppError("task not found", code=50006)
|
||||
self._refresh_task_progress(task)
|
||||
self.db.commit()
|
||||
return self.serialize_task_detail(task)
|
||||
|
||||
def admin_retry_task(self, task_id: int) -> dict:
|
||||
task = self.repository.get_task_by_id(task_id)
|
||||
if not task:
|
||||
raise NotFoundAppError("task not found", code=50006)
|
||||
return self.retry_task(task.user_id, task.task_no)
|
||||
|
||||
def admin_refund_task(self, task_id: int) -> dict:
|
||||
task = self.repository.get_task_by_id(task_id)
|
||||
if not task:
|
||||
raise NotFoundAppError("task not found", code=50006)
|
||||
tx = self.wallet_service.add_points(
|
||||
task.user_id,
|
||||
task.final_points or task.frozen_points,
|
||||
biz_type="refund",
|
||||
related_type="video_task",
|
||||
related_id=task.id,
|
||||
remark=f"manual refund for {task.task_no}",
|
||||
operator_type="admin",
|
||||
)
|
||||
self._add_event(task.id, "refund", "manual refund")
|
||||
self.db.commit()
|
||||
return {"taskNo": task.task_no, "refundTransactionNo": tx.transaction_no}
|
||||
|
||||
def _select_binding(self, video_model_id: int):
|
||||
bindings = self.repository.get_bindings(video_model_id)
|
||||
if not bindings:
|
||||
raise BusinessAppError("no available provider", code=50003)
|
||||
return bindings[0]
|
||||
|
||||
def _build_normalized_payload(self, user_id: int, payload, provider_model_id: int, provider_account_id: int) -> dict:
|
||||
image_assets = self.repository.list_assets(user_id, payload.referenceImageAssetIds)
|
||||
video_assets = self.repository.list_assets(user_id, payload.referenceVideoAssetIds)
|
||||
audio_assets = self.repository.list_assets(user_id, payload.referenceAudioAssetIds)
|
||||
return {
|
||||
"videoModelId": payload.videoModelId,
|
||||
"providerModelId": provider_model_id,
|
||||
"providerAccountId": provider_account_id,
|
||||
"prompt": payload.prompt,
|
||||
"durationSeconds": payload.durationSeconds,
|
||||
"resolution": payload.resolution,
|
||||
"ratio": payload.ratio,
|
||||
"generateAudio": payload.generateAudio,
|
||||
"referenceImageAssetIds": payload.referenceImageAssetIds,
|
||||
"referenceVideoAssetIds": payload.referenceVideoAssetIds,
|
||||
"referenceAudioAssetIds": payload.referenceAudioAssetIds,
|
||||
"referenceImages": [{"assetId": item.id, "url": item.public_url} for item in image_assets],
|
||||
"referenceVideos": [{"assetId": item.id, "url": item.public_url} for item in video_assets],
|
||||
"referenceAudios": [{"assetId": item.id, "url": item.public_url} for item in audio_assets],
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _infer_generation_mode(payload) -> str:
|
||||
if payload.referenceImageAssetIds or payload.referenceVideoAssetIds or payload.referenceAudioAssetIds:
|
||||
return "multimodal"
|
||||
return "text_to_video"
|
||||
|
||||
def _submit_task(self, task: VideoGenerationTask, provider_account, provider_model) -> None:
|
||||
adapter = build_adapter(provider_account, provider_model)
|
||||
result = adapter.submit_task(task.request_payload)
|
||||
task.external_task_id = result["externalTaskId"]
|
||||
task.task_status = result["normalizedStatus"]
|
||||
task.submitted_at = datetime.utcnow()
|
||||
task.response_payload = result
|
||||
self._add_event(task.id, "submitted", "task submitted to provider")
|
||||
|
||||
def _refresh_task_progress(self, task: VideoGenerationTask) -> None:
|
||||
if task.task_status in self.FINAL_STATUSES:
|
||||
return
|
||||
provider_model = self.repository.get_provider_model(task.provider_model_id)
|
||||
provider_account = self.repository.get_provider_account(task.provider_account_id)
|
||||
adapter = build_adapter(provider_account, provider_model)
|
||||
result = adapter.query_task(task)
|
||||
task.response_payload = result
|
||||
status = result["normalizedStatus"]
|
||||
if status == "succeeded" and task.result_asset_id is None:
|
||||
content = adapter.download_result(task)
|
||||
stored = storage_service.save_bytes(
|
||||
content,
|
||||
filename=f"{task.task_no}.mp4",
|
||||
folder="generated/videos",
|
||||
)
|
||||
asset = MediaAsset(
|
||||
asset_no=new_order_no("asset"),
|
||||
user_id=task.user_id,
|
||||
media_type="video",
|
||||
source_type="generated",
|
||||
original_filename=f"{task.task_no}.mp4",
|
||||
mime_type="video/mp4",
|
||||
file_ext=".mp4",
|
||||
file_size=stored["file_size"],
|
||||
storage_provider="local",
|
||||
storage_bucket="local",
|
||||
storage_key=stored["storage_key"],
|
||||
public_url=stored["public_url"],
|
||||
sha256=stored["sha256"],
|
||||
status="active",
|
||||
)
|
||||
self.db.add(asset)
|
||||
self.db.flush()
|
||||
task.result_asset_id = asset.id
|
||||
task.final_points = task.estimated_points
|
||||
task.task_status = "succeeded"
|
||||
task.finished_at = datetime.utcnow()
|
||||
self.wallet_service.consume_frozen_points(
|
||||
task.user_id,
|
||||
task.frozen_points,
|
||||
related_type="video_task",
|
||||
related_id=task.id,
|
||||
remark=f"consume for {task.task_no}",
|
||||
)
|
||||
self.wallet_service.try_issue_invite_reward(task.user_id, task.id, task.final_points)
|
||||
self._add_event(task.id, "succeeded", "task succeeded")
|
||||
elif status == "failed":
|
||||
task.task_status = "failed"
|
||||
task.fail_reason = result.get("rawResponse", {}).get("message", "provider failed")
|
||||
task.finished_at = datetime.utcnow()
|
||||
self.wallet_service.release_frozen_points(
|
||||
task.user_id,
|
||||
task.frozen_points,
|
||||
related_type="video_task",
|
||||
related_id=task.id,
|
||||
remark=f"release for {task.task_no}",
|
||||
)
|
||||
self._add_event(task.id, "failed", "task failed")
|
||||
else:
|
||||
task.task_status = status
|
||||
|
||||
def _add_event(self, task_id: int, event_type: str, message: str) -> None:
|
||||
self.db.add(
|
||||
VideoTaskEvent(
|
||||
video_task_id=task_id,
|
||||
event_type=event_type,
|
||||
event_message=message,
|
||||
payload=None,
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
)
|
||||
|
||||
def serialize_task_summary(self, task: VideoGenerationTask) -> dict:
|
||||
return {
|
||||
"id": task.id,
|
||||
"taskNo": task.task_no,
|
||||
"taskStatus": task.task_status,
|
||||
"durationSeconds": task.duration_seconds,
|
||||
"estimatedPoints": task.estimated_points,
|
||||
"finalPoints": task.final_points,
|
||||
"resultVideoUrl": self._result_url(task),
|
||||
"failReason": task.fail_reason,
|
||||
"createdAt": task.created_at.isoformat(),
|
||||
"finishedAt": task.finished_at.isoformat() if task.finished_at else None,
|
||||
}
|
||||
|
||||
def serialize_task_detail(self, task: VideoGenerationTask) -> dict:
|
||||
provider_account = self.repository.get_provider_account(task.provider_account_id)
|
||||
provider_model = self.repository.get_provider_model(task.provider_model_id)
|
||||
video_model = self.repository.get_video_model(task.video_model_id)
|
||||
events = self.repository.task_events(task.id).all()
|
||||
return {
|
||||
**self.serialize_task_summary(task),
|
||||
"videoModel": {
|
||||
"id": task.video_model_id,
|
||||
"name": video_model.model_name if video_model else "",
|
||||
},
|
||||
"provider": {
|
||||
"providerCode": provider_account.provider_code if provider_account else "",
|
||||
"providerName": provider_account.provider_name if provider_account else "",
|
||||
"modelCode": provider_model.model_code if provider_model else "",
|
||||
"modelName": provider_model.model_name if provider_model else "",
|
||||
},
|
||||
"ratio": task.ratio,
|
||||
"resolution": task.resolution,
|
||||
"prompt": task.prompt_text or "",
|
||||
"requestPayload": task.request_payload,
|
||||
"responsePayload": task.response_payload,
|
||||
"events": [
|
||||
{
|
||||
"eventType": item.event_type,
|
||||
"eventMessage": item.event_message,
|
||||
"createdAt": item.created_at.isoformat(),
|
||||
}
|
||||
for item in events
|
||||
],
|
||||
}
|
||||
|
||||
def _result_url(self, task: VideoGenerationTask) -> str:
|
||||
if not task.result_asset_id:
|
||||
return ""
|
||||
asset = self.db.scalar(select(MediaAsset).where(MediaAsset.id == task.result_asset_id))
|
||||
return asset.public_url if asset else ""
|
||||
43
backend/app/modules/wallets/repository.py
Normal file
43
backend/app/modules/wallets/repository.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.entities import RechargeOrder, RechargePlan, RedeemCode, Wallet, WalletTransaction
|
||||
|
||||
|
||||
class WalletRepository:
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
|
||||
def lock_wallet(self, user_id: int) -> Wallet:
|
||||
return self.db.execute(
|
||||
select(Wallet).where(Wallet.user_id == user_id).with_for_update()
|
||||
).scalar_one()
|
||||
|
||||
def wallet_transactions(self, user_id: int):
|
||||
return self.db.query(WalletTransaction).filter(WalletTransaction.user_id == user_id)
|
||||
|
||||
def recharge_orders(self, user_id: int):
|
||||
return self.db.query(RechargeOrder).filter(RechargeOrder.user_id == user_id)
|
||||
|
||||
def get_recharge_plan(self, plan_id: int) -> RechargePlan | None:
|
||||
return self.db.scalar(select(RechargePlan).where(RechargePlan.id == plan_id))
|
||||
|
||||
def get_order_by_no(self, order_no: str) -> RechargeOrder | None:
|
||||
return self.db.scalar(select(RechargeOrder).where(RechargeOrder.order_no == order_no))
|
||||
|
||||
def lock_redeem_code(self, redeem_code: str) -> RedeemCode | None:
|
||||
return self.db.execute(
|
||||
select(RedeemCode)
|
||||
.where(RedeemCode.redeem_code == redeem_code)
|
||||
.with_for_update()
|
||||
).scalar_one_or_none()
|
||||
|
||||
def latest_transactions_count(self, user_id: int) -> int:
|
||||
return (
|
||||
self.db.query(func.count(WalletTransaction.id))
|
||||
.filter(WalletTransaction.user_id == user_id)
|
||||
.scalar()
|
||||
)
|
||||
|
||||
70
backend/app/modules/wallets/router.py
Normal file
70
backend/app/modules/wallets/router.py
Normal file
@@ -0,0 +1,70 @@
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.common.db.session import get_db
|
||||
from app.common.responses.api_response import success_response
|
||||
from app.common.security.deps import get_current_user
|
||||
from app.models.entities import User
|
||||
from app.modules.wallets.schema import CreateRechargeOrderRequest, ExchangeRedeemCodeRequest
|
||||
from app.modules.wallets.service import WalletService
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/v1/wallet", tags=["wallet"])
|
||||
|
||||
|
||||
@router.get("")
|
||||
def get_wallet(current_user: User = Depends(get_current_user), db: Session = Depends(get_db)):
|
||||
return success_response(WalletService(db).get_wallet_summary(current_user.id))
|
||||
|
||||
|
||||
@router.get("/transactions")
|
||||
def list_transactions(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return success_response(WalletService(db).list_transactions(current_user.id))
|
||||
|
||||
|
||||
@router.post("/recharge-orders")
|
||||
def create_recharge_order(
|
||||
payload: CreateRechargeOrderRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return success_response(WalletService(db).create_recharge_order(current_user.id, payload))
|
||||
|
||||
|
||||
@router.get("/recharge-options")
|
||||
def get_recharge_options(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return success_response(WalletService(db).recharge_options())
|
||||
|
||||
|
||||
@router.get("/recharge-orders")
|
||||
def list_recharge_orders(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return success_response(WalletService(db).list_recharge_orders(current_user.id))
|
||||
|
||||
|
||||
@router.post("/redeem-codes/exchange")
|
||||
def exchange_redeem_code(
|
||||
payload: ExchangeRedeemCodeRequest,
|
||||
request: Request,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return success_response(
|
||||
WalletService(db).exchange_redeem_code(current_user.id, payload, request)
|
||||
)
|
||||
|
||||
|
||||
@router.get("/redeem-records")
|
||||
def list_redeem_records(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return success_response(WalletService(db).list_redeem_records(current_user.id))
|
||||
22
backend/app/modules/wallets/schema.py
Normal file
22
backend/app/modules/wallets/schema.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from decimal import Decimal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class CreateRechargeOrderRequest(BaseModel):
|
||||
rechargePlanId: int
|
||||
paymentChannelCode: str
|
||||
|
||||
|
||||
class ExchangeRedeemCodeRequest(BaseModel):
|
||||
redeemCode: str
|
||||
|
||||
|
||||
class WalletAdjustRequest(BaseModel):
|
||||
amount_points: int = Field(alias="amountPoints")
|
||||
reason: str
|
||||
|
||||
|
||||
class MockPayRequest(BaseModel):
|
||||
orderNo: str
|
||||
paidAmount: Decimal | None = Field(default=None)
|
||||
430
backend/app/modules/wallets/service.py
Normal file
430
backend/app/modules/wallets/service.py
Normal file
@@ -0,0 +1,430 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.common.config.settings import get_settings
|
||||
from app.common.errors.app_error import BusinessAppError, NotFoundAppError
|
||||
from app.common.utils.id_gen import new_order_no
|
||||
from app.models.entities import (
|
||||
GrowthRewardRule,
|
||||
InviteRelation,
|
||||
PaymentChannel,
|
||||
RechargeOrder,
|
||||
RechargePlan,
|
||||
RedeemCode,
|
||||
Wallet,
|
||||
WalletTransaction,
|
||||
)
|
||||
from app.modules.wallets.repository import WalletRepository
|
||||
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
class WalletService:
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
self.repository = WalletRepository(db)
|
||||
|
||||
def get_wallet_summary(self, user_id: int) -> dict:
|
||||
wallet = self.repository.lock_wallet(user_id)
|
||||
return {
|
||||
"balancePoints": wallet.balance_points,
|
||||
"frozenPoints": wallet.frozen_points,
|
||||
"availablePoints": wallet.balance_points - wallet.frozen_points,
|
||||
"pointExchangeRatio": settings.point_exchange_ratio,
|
||||
}
|
||||
|
||||
def list_transactions(self, user_id: int) -> list[dict]:
|
||||
records = (
|
||||
self.repository.wallet_transactions(user_id)
|
||||
.order_by(WalletTransaction.id.desc())
|
||||
.limit(100)
|
||||
.all()
|
||||
)
|
||||
return [
|
||||
{
|
||||
"transactionNo": item.transaction_no,
|
||||
"bizType": item.biz_type,
|
||||
"direction": item.direction,
|
||||
"amountPoints": item.amount_points,
|
||||
"remark": item.remark,
|
||||
"createdAt": item.created_at.isoformat(),
|
||||
}
|
||||
for item in records
|
||||
]
|
||||
|
||||
def list_recharge_orders(self, user_id: int) -> list[dict]:
|
||||
records = (
|
||||
self.repository.recharge_orders(user_id)
|
||||
.order_by(RechargeOrder.id.desc())
|
||||
.limit(100)
|
||||
.all()
|
||||
)
|
||||
return [
|
||||
{
|
||||
"orderNo": item.order_no,
|
||||
"payAmount": f"{item.pay_amount:.2f}",
|
||||
"arrivalPoints": item.arrival_points,
|
||||
"status": item.status,
|
||||
"paymentChannelCode": item.payment_channel_code,
|
||||
"paidAt": item.paid_at.isoformat() if item.paid_at else None,
|
||||
"createdAt": item.created_at.isoformat(),
|
||||
}
|
||||
for item in records
|
||||
]
|
||||
|
||||
def recharge_options(self) -> dict:
|
||||
plans = (
|
||||
self.db.query(RechargePlan)
|
||||
.filter(RechargePlan.status == 1)
|
||||
.order_by(RechargePlan.sort_order.asc(), RechargePlan.id.asc())
|
||||
.all()
|
||||
)
|
||||
channels = (
|
||||
self.db.query(PaymentChannel)
|
||||
.filter(PaymentChannel.status == 1)
|
||||
.order_by(PaymentChannel.sort_order.asc(), PaymentChannel.id.asc())
|
||||
.all()
|
||||
)
|
||||
return {
|
||||
"plans": [
|
||||
{
|
||||
"id": item.id,
|
||||
"name": item.name,
|
||||
"payAmount": f"{item.pay_amount:.2f}",
|
||||
"arrivalPoints": item.give_points + item.bonus_points,
|
||||
"bonusPoints": item.bonus_points,
|
||||
}
|
||||
for item in plans
|
||||
],
|
||||
"channels": [
|
||||
{
|
||||
"id": item.id,
|
||||
"channelCode": item.channel_code,
|
||||
"channelName": item.channel_name,
|
||||
}
|
||||
for item in channels
|
||||
],
|
||||
}
|
||||
|
||||
def list_redeem_records(self, user_id: int) -> list[dict]:
|
||||
records = (
|
||||
self.db.query(RedeemCode)
|
||||
.filter(RedeemCode.used_by_user_id == user_id)
|
||||
.order_by(RedeemCode.id.desc())
|
||||
.all()
|
||||
)
|
||||
return [
|
||||
{
|
||||
"redeemCode": item.redeem_code,
|
||||
"points": item.points,
|
||||
"usedAt": item.used_at.isoformat() if item.used_at else None,
|
||||
}
|
||||
for item in records
|
||||
]
|
||||
|
||||
def create_recharge_order(self, user_id: int, payload) -> dict:
|
||||
plan = self.repository.get_recharge_plan(payload.rechargePlanId)
|
||||
if not plan or plan.status != 1:
|
||||
raise NotFoundAppError("recharge plan not found", code=30001)
|
||||
channel = self.db.scalar(
|
||||
select(PaymentChannel).where(
|
||||
PaymentChannel.channel_code == payload.paymentChannelCode
|
||||
)
|
||||
)
|
||||
if not channel or channel.status != 1:
|
||||
raise NotFoundAppError("payment channel not found", code=30002)
|
||||
arrival_points = plan.give_points + plan.bonus_points
|
||||
order = RechargeOrder(
|
||||
order_no=new_order_no("rc"),
|
||||
user_id=user_id,
|
||||
recharge_plan_id=plan.id,
|
||||
payment_channel_id=channel.id,
|
||||
payment_channel_code=channel.channel_code,
|
||||
pay_amount=plan.pay_amount,
|
||||
point_ratio_snapshot=plan.point_ratio,
|
||||
give_points=plan.give_points,
|
||||
bonus_points=plan.bonus_points,
|
||||
arrival_points=arrival_points,
|
||||
status="pending",
|
||||
)
|
||||
self.db.add(order)
|
||||
self.db.commit()
|
||||
return {
|
||||
"orderNo": order.order_no,
|
||||
"payAmount": f"{order.pay_amount:.2f}",
|
||||
"arrivalPoints": order.arrival_points,
|
||||
"payUrl": f"/api/v1/payments/mock-pay?orderNo={order.order_no}",
|
||||
}
|
||||
|
||||
def handle_mock_payment(self, order_no: str) -> dict:
|
||||
order = self.repository.get_order_by_no(order_no)
|
||||
if not order:
|
||||
raise NotFoundAppError("order not found", code=30001)
|
||||
if order.status == "paid":
|
||||
return {"orderNo": order.order_no, "status": order.status, "idempotent": True}
|
||||
|
||||
wallet = self.repository.lock_wallet(order.user_id)
|
||||
before_balance = wallet.balance_points
|
||||
wallet.balance_points += order.arrival_points
|
||||
wallet.total_recharged_points += order.arrival_points
|
||||
order.status = "paid"
|
||||
order.paid_at = datetime.utcnow()
|
||||
self.db.add(
|
||||
WalletTransaction(
|
||||
transaction_no=new_order_no("wt"),
|
||||
user_id=order.user_id,
|
||||
wallet_id=wallet.id,
|
||||
biz_type="recharge",
|
||||
direction="in",
|
||||
amount_points=order.arrival_points,
|
||||
balance_before_points=before_balance,
|
||||
balance_after_points=wallet.balance_points,
|
||||
frozen_before_points=wallet.frozen_points,
|
||||
frozen_after_points=wallet.frozen_points,
|
||||
related_type="recharge_order",
|
||||
related_id=order.id,
|
||||
remark=f"recharge order {order.order_no}",
|
||||
operator_type="system",
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
)
|
||||
self.db.commit()
|
||||
return {
|
||||
"orderNo": order.order_no,
|
||||
"status": order.status,
|
||||
"arrivalPoints": order.arrival_points,
|
||||
"idempotent": False,
|
||||
}
|
||||
|
||||
def exchange_redeem_code(self, user_id: int, payload, request) -> dict:
|
||||
redeem_code = self.repository.lock_redeem_code(payload.redeemCode)
|
||||
if not redeem_code:
|
||||
raise BusinessAppError("redeem code not found", code=20004)
|
||||
if redeem_code.status == "used":
|
||||
raise BusinessAppError("redeem code already used", code=20005)
|
||||
if redeem_code.status in {"expired", "disabled"}:
|
||||
raise BusinessAppError("redeem code unavailable", code=20006)
|
||||
if redeem_code.expired_at and redeem_code.expired_at < datetime.utcnow():
|
||||
redeem_code.status = "expired"
|
||||
self.db.commit()
|
||||
raise BusinessAppError("redeem code expired", code=20006)
|
||||
|
||||
wallet = self.repository.lock_wallet(user_id)
|
||||
before_balance = wallet.balance_points
|
||||
wallet.balance_points += redeem_code.points
|
||||
tx = WalletTransaction(
|
||||
transaction_no=new_order_no("wt"),
|
||||
user_id=user_id,
|
||||
wallet_id=wallet.id,
|
||||
biz_type="redeem_code",
|
||||
direction="in",
|
||||
amount_points=redeem_code.points,
|
||||
balance_before_points=before_balance,
|
||||
balance_after_points=wallet.balance_points,
|
||||
frozen_before_points=wallet.frozen_points,
|
||||
frozen_after_points=wallet.frozen_points,
|
||||
related_type="redeem_code",
|
||||
related_id=redeem_code.id,
|
||||
remark=f"redeem {redeem_code.redeem_code}",
|
||||
operator_type="user",
|
||||
operator_id=user_id,
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
self.db.add(tx)
|
||||
self.db.flush()
|
||||
redeem_code.status = "used"
|
||||
redeem_code.used_by_user_id = user_id
|
||||
redeem_code.wallet_transaction_id = tx.id
|
||||
redeem_code.used_at = datetime.utcnow()
|
||||
redeem_code.used_ip = request.client.host if request.client else ""
|
||||
redeem_code.used_user_agent = request.headers.get("user-agent", "")
|
||||
self.db.commit()
|
||||
return {
|
||||
"redeemCode": redeem_code.redeem_code,
|
||||
"points": redeem_code.points,
|
||||
"walletBalance": wallet.balance_points,
|
||||
}
|
||||
|
||||
def add_points(
|
||||
self,
|
||||
user_id: int,
|
||||
amount_points: int,
|
||||
*,
|
||||
biz_type: str,
|
||||
related_type: str,
|
||||
related_id: int | None,
|
||||
remark: str,
|
||||
operator_type: str = "system",
|
||||
operator_id: int | None = None,
|
||||
) -> WalletTransaction:
|
||||
wallet = self.repository.lock_wallet(user_id)
|
||||
before_balance = wallet.balance_points
|
||||
wallet.balance_points += amount_points
|
||||
if biz_type == "recharge":
|
||||
wallet.total_recharged_points += amount_points
|
||||
if biz_type in {"refund", "unfreeze"}:
|
||||
wallet.total_refunded_points += amount_points
|
||||
tx = WalletTransaction(
|
||||
transaction_no=new_order_no("wt"),
|
||||
user_id=user_id,
|
||||
wallet_id=wallet.id,
|
||||
biz_type=biz_type,
|
||||
direction="in",
|
||||
amount_points=amount_points,
|
||||
balance_before_points=before_balance,
|
||||
balance_after_points=wallet.balance_points,
|
||||
frozen_before_points=wallet.frozen_points,
|
||||
frozen_after_points=wallet.frozen_points,
|
||||
related_type=related_type,
|
||||
related_id=related_id,
|
||||
remark=remark,
|
||||
operator_type=operator_type,
|
||||
operator_id=operator_id,
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
self.db.add(tx)
|
||||
self.db.flush()
|
||||
return tx
|
||||
|
||||
def freeze_points(self, user_id: int, amount_points: int, *, related_type: str, related_id: int | None, remark: str) -> None:
|
||||
wallet = self.repository.lock_wallet(user_id)
|
||||
available_points = wallet.balance_points - wallet.frozen_points
|
||||
if available_points < amount_points:
|
||||
raise BusinessAppError("insufficient balance", code=20001)
|
||||
balance_before = wallet.balance_points
|
||||
frozen_before = wallet.frozen_points
|
||||
wallet.frozen_points += amount_points
|
||||
self.db.add(
|
||||
WalletTransaction(
|
||||
transaction_no=new_order_no("wt"),
|
||||
user_id=user_id,
|
||||
wallet_id=wallet.id,
|
||||
biz_type="freeze",
|
||||
direction="freeze",
|
||||
amount_points=amount_points,
|
||||
balance_before_points=balance_before,
|
||||
balance_after_points=wallet.balance_points,
|
||||
frozen_before_points=frozen_before,
|
||||
frozen_after_points=wallet.frozen_points,
|
||||
related_type=related_type,
|
||||
related_id=related_id,
|
||||
remark=remark,
|
||||
operator_type="system",
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
)
|
||||
|
||||
def consume_frozen_points(self, user_id: int, amount_points: int, *, related_type: str, related_id: int | None, remark: str) -> None:
|
||||
wallet = self.repository.lock_wallet(user_id)
|
||||
if wallet.frozen_points < amount_points:
|
||||
raise BusinessAppError("frozen points not enough", code=20003)
|
||||
balance_before = wallet.balance_points
|
||||
frozen_before = wallet.frozen_points
|
||||
wallet.balance_points -= amount_points
|
||||
wallet.frozen_points -= amount_points
|
||||
wallet.total_consumed_points += amount_points
|
||||
self.db.add(
|
||||
WalletTransaction(
|
||||
transaction_no=new_order_no("wt"),
|
||||
user_id=user_id,
|
||||
wallet_id=wallet.id,
|
||||
biz_type="consume",
|
||||
direction="out",
|
||||
amount_points=amount_points,
|
||||
balance_before_points=balance_before,
|
||||
balance_after_points=wallet.balance_points,
|
||||
frozen_before_points=frozen_before,
|
||||
frozen_after_points=wallet.frozen_points,
|
||||
related_type=related_type,
|
||||
related_id=related_id,
|
||||
remark=remark,
|
||||
operator_type="system",
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
)
|
||||
|
||||
def release_frozen_points(self, user_id: int, amount_points: int, *, related_type: str, related_id: int | None, remark: str) -> None:
|
||||
wallet = self.repository.lock_wallet(user_id)
|
||||
if wallet.frozen_points < amount_points:
|
||||
amount_points = wallet.frozen_points
|
||||
balance_before = wallet.balance_points
|
||||
frozen_before = wallet.frozen_points
|
||||
wallet.frozen_points -= amount_points
|
||||
self.db.add(
|
||||
WalletTransaction(
|
||||
transaction_no=new_order_no("wt"),
|
||||
user_id=user_id,
|
||||
wallet_id=wallet.id,
|
||||
biz_type="unfreeze",
|
||||
direction="unfreeze",
|
||||
amount_points=amount_points,
|
||||
balance_before_points=balance_before,
|
||||
balance_after_points=wallet.balance_points,
|
||||
frozen_before_points=frozen_before,
|
||||
frozen_after_points=wallet.frozen_points,
|
||||
related_type=related_type,
|
||||
related_id=related_id,
|
||||
remark=remark,
|
||||
operator_type="system",
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
)
|
||||
|
||||
def try_issue_signup_reward(self, user_id: int) -> None:
|
||||
rule = self.db.scalar(
|
||||
select(GrowthRewardRule).where(GrowthRewardRule.rule_type == "signup_reward")
|
||||
)
|
||||
if not rule or not rule.enabled or rule.reward_points <= 0:
|
||||
return
|
||||
exists = self.db.scalar(
|
||||
select(WalletTransaction).where(
|
||||
WalletTransaction.user_id == user_id,
|
||||
WalletTransaction.biz_type == "signup_reward",
|
||||
)
|
||||
)
|
||||
if exists:
|
||||
return
|
||||
self.add_points(
|
||||
user_id,
|
||||
rule.reward_points,
|
||||
biz_type="signup_reward",
|
||||
related_type="growth_rule",
|
||||
related_id=rule.id,
|
||||
remark="signup reward",
|
||||
)
|
||||
|
||||
def try_issue_invite_reward(self, user_id: int, task_id: int, final_points: int) -> None:
|
||||
relation = self.db.scalar(
|
||||
select(InviteRelation).where(InviteRelation.invitee_user_id == user_id)
|
||||
)
|
||||
if not relation or relation.reward_status == "rewarded":
|
||||
return
|
||||
rule = self.db.scalar(
|
||||
select(GrowthRewardRule).where(GrowthRewardRule.rule_type == "invite_reward")
|
||||
)
|
||||
if (
|
||||
not rule
|
||||
or not rule.enabled
|
||||
or final_points <= 0
|
||||
or final_points < rule.min_consume_points
|
||||
):
|
||||
return
|
||||
tx = self.add_points(
|
||||
relation.inviter_user_id,
|
||||
rule.reward_points,
|
||||
biz_type="invite_reward",
|
||||
related_type="invite_relation",
|
||||
related_id=relation.id,
|
||||
remark="invite reward",
|
||||
)
|
||||
relation.reward_status = "rewarded"
|
||||
relation.reward_points = rule.reward_points
|
||||
relation.first_consumed_task_id = task_id
|
||||
relation.first_consumed_at = datetime.utcnow()
|
||||
relation.rewarded_at = datetime.utcnow()
|
||||
relation.reward_wallet_transaction_id = tx.id
|
||||
Reference in New Issue
Block a user