from __future__ import annotations from datetime import datetime, timedelta from decimal import Decimal from sqlalchemy import select from sqlalchemy.orm import Session from app.common.db.session import SessionLocal, engine from app.common.security.password import hash_password from app.common.utils.id_gen import new_invite_code, new_public_id from app.models import Base from app.models.entities import ( AdminUser, GrowthRewardRule, PaymentChannel, PricingRule, ProviderAccount, ProviderModel, RechargePlan, RedeemCode, SystemConfig, VideoModel, VideoModelSupplierBinding, ) def init_database() -> None: Base.metadata.create_all(bind=engine) with SessionLocal() as db: seed_defaults(db) def seed_defaults(db: Session) -> None: from app.common.config.settings import get_settings settings = get_settings() admin = db.scalar(select(AdminUser).where(AdminUser.username == settings.admin_username)) if not admin: admin = AdminUser( username=settings.admin_username, password_hash=hash_password(settings.admin_password), nickname=settings.admin_nickname, is_super_admin=True, status=1, ) db.add(admin) for rule_type, trigger, points in [ ("signup_reward", "on_register", 300), ("invite_reward", "on_first_consume", 500), ]: rule = db.scalar(select(GrowthRewardRule).where(GrowthRewardRule.rule_type == rule_type)) if not rule: db.add( GrowthRewardRule( rule_type=rule_type, enabled=True, reward_points=points, trigger_condition=trigger, min_consume_points=settings.invite_reward_min_consume_points, remark=rule_type, ) ) for channel_code, channel_name, provider_type, sort_order in [ ("alipay", "支付宝", "manual", 10), ("wechat_pay", "微信支付", "manual", 20), ]: channel = db.scalar( select(PaymentChannel).where(PaymentChannel.channel_code == channel_code) ) if not channel: db.add( PaymentChannel( channel_code=channel_code, channel_name=channel_name, provider_type=provider_type, status=1, sort_order=sort_order, ) ) if not db.scalar(select(RechargePlan.id)): db.add_all( [ RechargePlan( name="体验包", pay_amount=Decimal("29.90"), point_ratio=100, give_points=2990, bonus_points=200, sort_order=10, status=1, ), RechargePlan( name="标准包", pay_amount=Decimal("99.00"), point_ratio=100, give_points=9900, bonus_points=1200, sort_order=20, status=1, ), RechargePlan( name="专业包", pay_amount=Decimal("299.00"), point_ratio=100, give_points=29900, bonus_points=4500, sort_order=30, status=1, ), ] ) if not db.scalar(select(ProviderAccount.id)): openai_account = ProviderAccount( provider_code="openai-mock", provider_name="OpenAI Mock", api_format="openai_official_video", base_url="mock://openai", api_key_encrypted="mock", timeout_seconds=60, max_retries=3, status=1, ) seedance_account = ProviderAccount( provider_code="seedance-mock", provider_name="Seedance Mock", api_format="seedance_video_generation", base_url="mock://seedance", api_key_encrypted="mock", timeout_seconds=60, max_retries=3, status=1, ) db.add_all([openai_account, seedance_account]) db.flush() openai_model = ProviderModel( provider_account_id=openai_account.id, model_code="sora-2", model_name="Sora 2", request_content_type="multipart/form-data", supports_text_to_video=True, supports_image_to_video=True, supports_generate_audio=True, supports_webhook=True, min_duration=4, max_duration=12, default_ratio="16:9", default_resolution="1280x720", status=1, ) seedance_model = ProviderModel( provider_account_id=seedance_account.id, model_code="seedance", model_name="Seedance", request_content_type="application/json", supports_text_to_video=True, supports_image_to_video=True, supports_generate_audio=True, supports_webhook=False, min_duration=4, max_duration=12, default_ratio="16:9", default_resolution="1280x720", status=1, ) db.add_all([openai_model, seedance_model]) db.flush() standard_model = VideoModel( model_key="standard-video", model_name="标准视频", frontend_title="标准视频", frontend_description="平衡质量与速度,适合大多数日常创作。", default_duration_seconds=8, default_ratio="16:9", default_resolution="1280x720", sort_order=10, status=1, ) fast_model = VideoModel( model_key="fast-video", model_name="高速视频", frontend_title="高速视频", frontend_description="更快返回结果,适合灵感验证与批量尝试。", default_duration_seconds=6, default_ratio="16:9", default_resolution="1280x720", sort_order=20, status=1, ) db.add_all([standard_model, fast_model]) db.flush() db.add_all( [ VideoModelSupplierBinding( video_model_id=standard_model.id, provider_model_id=openai_model.id, routing_priority=10, is_primary=True, status=1, ), VideoModelSupplierBinding( video_model_id=fast_model.id, provider_model_id=seedance_model.id, routing_priority=10, is_primary=True, status=1, ), ] ) db.add_all( [ PricingRule( rule_name="标准视频默认价格", video_model_id=standard_model.id, points_per_second=120, minimum_points=500, effective_at=datetime.utcnow() - timedelta(days=1), version_no=1, status=1, ), PricingRule( rule_name="高速视频默认价格", video_model_id=fast_model.id, points_per_second=90, minimum_points=400, effective_at=datetime.utcnow() - timedelta(days=1), version_no=1, status=1, ), ] ) default_configs = { "site.title": ("AIVideo", "site"), "site.notice": ("欢迎体验 AIVideo 本地开发版。", "site"), "reward.signup.enabled": ("1", "reward"), "reward.signup.points": ("300", "reward"), "reward.invite.enabled": ("1", "reward"), "reward.invite.points": ("500", "reward"), "invite.code.enabled": ("1", "invite"), "task.default_poll_interval_seconds": ("5", "task"), } for key, (value, group_name) in default_configs.items(): if not db.scalar(select(SystemConfig).where(SystemConfig.config_key == key)): db.add( SystemConfig( config_key=key, config_value=value, value_type="string", group_name=group_name, is_public=1 if key.startswith("site.") else 0, ) ) if not db.scalar(select(RedeemCode.id)): db.add_all( [ RedeemCode( batch_no="WELCOME", redeem_code="SPRING-2026-ABCD-1234", points=1000, status="unused", ), RedeemCode( batch_no="WELCOME", redeem_code="SPRING-2026-EFGH-5678", points=1500, status="unused", ), ] ) db.commit()