"""Subscription plans, user subscriptions, usage logs and redemption codes.

Contains the SQLAlchemy schema, the Pydantic models used by the API layer,
and the table helper singletons (``SubscriptionPlans``, ``UserSubscriptions``,
``SubscriptionUsageLogs``, ``RedemptionCodes``).
"""

import datetime
import json
import time
import uuid
from decimal import Decimal
from typing import Dict, List, Optional, Tuple

from pydantic import BaseModel, ConfigDict, Field, field_validator
from sqlalchemy import JSON, BigInteger, Boolean, Column, Integer, Numeric, String, Text

from open_webui.internal.db import Base, get_db

####################
# Subscription Plan DB Schema
####################


class SubscriptionPlan(Base):
    """A subscription plan with per-model monthly usage limits."""

    __tablename__ = "subscription_plan"

    id = Column(String, primary_key=True)
    name = Column(String, nullable=False)
    description = Column(Text, nullable=True)
    price = Column(Numeric(precision=10, scale=2), default=0)

    # per-model monthly limits: {"model_id": limit, ...}
    # -1 = unlimited, 0 = not allowed, positive = monthly limit
    model_limits = Column(JSON, nullable=True)
    # default limit for models not in model_limits
    # -1 = unlimited, 0 = not allowed, positive = monthly limit
    default_model_limit = Column(Integer, default=0)

    # deprecated fields (kept for backward compatibility)
    monthly_message_limit = Column(Integer, nullable=True)
    allowed_models = Column(JSON, nullable=True)

    # plans are listed in descending priority order (see get_all_plans)
    priority = Column(Integer, default=0)
    # the active default plan is what new / downgraded users are put on
    is_default = Column(Boolean, default=False)
    is_active = Column(Boolean, default=True)
    # Unix timestamps in seconds
    created_at = Column(BigInteger)
    updated_at = Column(BigInteger)


class UserSubscription(Base):
    """A user's current subscription and usage counters for the active period."""

    __tablename__ = "user_subscription"

    id = Column(String, primary_key=True)
    # unique: each user has at most one subscription row
    user_id = Column(String, index=True, nullable=False, unique=True)
    plan_id = Column(String, nullable=False)
    started_at = Column(BigInteger, nullable=False)
    # None = no expiry set (e.g. new default-plan or downgraded subscriptions)
    expires_at = Column(BigInteger, nullable=True)
    # usage-counting window; reset_period() starts a new one
    current_period_start = Column(BigInteger)
    current_period_end = Column(BigInteger)

    # per-model usage tracking: {"model_id": count, ...}
    model_usage = Column(JSON, default={})
    # deprecated field (kept for backward compatibility)
    messages_used = Column(Integer, default=0)

    created_at = Column(BigInteger)
    updated_at = Column(BigInteger)


class SubscriptionUsageLog(Base):
    """Log of subscription usage events, optionally tied to a model/chat/message."""

    __tablename__ = "subscription_usage_log"

    id = Column(String, primary_key=True)
    user_id = Column(String, index=True, nullable=False)
    subscription_id = Column(String, nullable=False)
    plan_id = Column(String, nullable=False)
    model_id = Column(String, nullable=True)
    chat_id = Column(String, nullable=True)
    message_id = Column(String, nullable=True)
    created_at = Column(BigInteger, index=True)


class RedemptionCode(Base):
    """A single-use code that moves the redeeming user onto a plan."""

    __tablename__ = "redemption_code"

    code = Column(String, primary_key=True)
    purpose = Column(String, index=True)
    # "duration" grants duration_days; any other value grants upgrade_days
    redemption_type = Column(String, default="duration")
    plan_id = Column(String, nullable=True)
    duration_days = Column(Integer, nullable=True)
    upgrade_days = Column(Integer, nullable=True)
    # redeeming user; None while the code is still unused
    user_id = Column(String, index=True, nullable=True)
    created_at = Column(BigInteger, index=True)
    # the code cannot be redeemed after this timestamp (None = no expiry)
    expired_at = Column(BigInteger, index=True, nullable=True)
    # when the code was redeemed
    received_at = Column(BigInteger, index=True, nullable=True)


####################
# Pydantic Models
####################


def _parse_json_field(v):
    """Decode ``v`` as JSON if it is a string; return None if decoding fails.

    Non-string values are returned unchanged. This covers databases whose JSON
    columns come back as raw strings.
    """
    if isinstance(v, str):
        try:
            return json.loads(v)
        except json.JSONDecodeError:
            return None
    return v


class SubscriptionPlanModel(BaseModel):
    """API/domain representation of a :class:`SubscriptionPlan` row."""

    model_config = ConfigDict(from_attributes=True)

    id: str
    name: str
    description: Optional[str] = None
    price: Decimal = Field(default_factory=lambda: Decimal("0"))

    # per-model monthly limits
    model_limits: Optional[Dict[str, int]] = None
    default_model_limit: int = 0

    # deprecated fields
    monthly_message_limit: Optional[int] = None
    allowed_models: Optional[List[str]] = None

    priority: int = 0
    is_default: bool = False
    is_active: bool = True
    created_at: int = Field(default_factory=lambda: int(time.time()))
    updated_at: int = Field(default_factory=lambda: int(time.time()))

    @field_validator('model_limits', 'allowed_models', mode='before')
    @classmethod
    def parse_json_fields(cls, v):
        """Accept JSON-encoded strings for the JSON columns."""
        return _parse_json_field(v)

    def get_model_limit(self, model_id: str) -> int:
        """
        Get the usage limit for a specific model.

        Falls back to ``default_model_limit`` for models not in ``model_limits``.

        Returns: -1 = unlimited, 0 = not allowed, positive = monthly limit
        """
        if self.model_limits and model_id in self.model_limits:
            return self.model_limits[model_id]
        return self.default_model_limit


class SubscriptionPlanForm(BaseModel):
    """Payload for creating a plan."""

    id: str
    name: str
    description: Optional[str] = None
    price: float = 0
    model_limits: Optional[Dict[str, int]] = None
    default_model_limit: int = 0
    priority: int = 0
    is_default: bool = False
    is_active: bool = True


class UpdateSubscriptionPlanForm(BaseModel):
    """Partial update payload for a plan; only fields that are set are applied."""

    name: Optional[str] = None
    description: Optional[str] = None
    price: Optional[float] = None
    model_limits: Optional[Dict[str, int]] = None
    default_model_limit: Optional[int] = None
    priority: Optional[int] = None
    is_default: Optional[bool] = None
    is_active: Optional[bool] = None


class UserSubscriptionModel(BaseModel):
    """API/domain representation of a :class:`UserSubscription` row."""

    model_config = ConfigDict(from_attributes=True)

    id: str = Field(default_factory=lambda: uuid.uuid4().hex)
    user_id: str
    plan_id: str
    started_at: int
    expires_at: Optional[int] = None
    current_period_start: int
    current_period_end: int

    # per-model usage tracking
    model_usage: Dict[str, int] = Field(default_factory=dict)

    # deprecated field
    messages_used: int = 0

    created_at: int = Field(default_factory=lambda: int(time.time()))
    updated_at: int = Field(default_factory=lambda: int(time.time()))

    @field_validator('model_usage', mode='before')
    @classmethod
    def parse_model_usage(cls, v):
        """Decode JSON strings and map empty / undecodable values to ``{}``."""
        parsed = _parse_json_field(v)
        return parsed if parsed else {}

    def get_model_usage(self, model_id: str) -> int:
        """Get usage count for a specific model (0 if never used)."""
        return self.model_usage.get(model_id, 0)


class ModelUsageInfo(BaseModel):
    """Usage info for a single model."""

    model_id: str
    limit: int  # -1 = unlimited, 0 = not allowed
    used: int
    remaining: Optional[int] = None  # None if unlimited


class UserSubscriptionWithPlanModel(UserSubscriptionModel):
    """A subscription enriched with its plan and derived fields for display."""

    plan: Optional[SubscriptionPlanModel] = None
    days_remaining: Optional[int] = None
    # per-model usage summary
    model_usage_info: Optional[List[ModelUsageInfo]] = None
    # deprecated
    messages_remaining: Optional[int] = None
    # admin view fields
    username: Optional[str] = None


class SubscriptionUsageLogModel(BaseModel):
    """API/domain representation of a :class:`SubscriptionUsageLog` row."""

    model_config = ConfigDict(from_attributes=True)

    id: str = Field(default_factory=lambda: uuid.uuid4().hex)
    user_id: str
    subscription_id: str
    plan_id: str
    model_id: Optional[str] = None
    chat_id: Optional[str] = None
    message_id: Optional[str] = None
    created_at: int = Field(default_factory=lambda: int(time.time()))


class RedemptionCodeModel(BaseModel):
    """API/domain representation of a :class:`RedemptionCode` row.

    ``extra="allow"`` lets callers attach additional attributes; those are
    not columns and are dropped when persisting (see ``insert_codes``).
    """

    model_config = ConfigDict(from_attributes=True, extra="allow")

    code: str
    purpose: str
    redemption_type: str = "duration"
    plan_id: Optional[str] = None
    duration_days: Optional[int] = None
    upgrade_days: Optional[int] = None
    user_id: Optional[str] = None
    created_at: int
    expired_at: Optional[int] = None
    received_at: Optional[int] = None


class CreateRedemptionCodeForm(BaseModel):
    """Payload for generating ``count`` redemption codes at once."""

    purpose: str = Field(min_length=1, max_length=255)
    count: int = Field(ge=1, le=1000)
    redemption_type: str = "duration"
    plan_id: str
    duration_days: Optional[int] = Field(default=None, ge=1)
    upgrade_days: Optional[int] = Field(default=None, ge=1)
    expired_at: Optional[int] = Field(default=None, gt=0)


class UpdateRedemptionCodeForm(BaseModel):
    """Partial update payload for a redemption code."""

    purpose: Optional[str] = Field(None, min_length=1, max_length=255)
    plan_id: Optional[str] = None
    duration_days: Optional[int] = Field(None, ge=1)
    upgrade_days: Optional[int] = Field(None, ge=1)
    expired_at: Optional[int] = Field(None, gt=0)


class UpdateUserSubscriptionForm(BaseModel):
    """Partial admin update payload for a user's subscription."""

    plan_id: Optional[str] = None
    expires_at: Optional[int] = None
    messages_used: Optional[int] = None


####################
# Table Classes
####################


def _next_period_end() -> int:
    """Return the Unix timestamp of local midnight on the 1st of next month.

    Used as ``current_period_end`` for new and reset usage periods. The naive
    ``datetime.timestamp()`` call interprets the value in local time.
    """
    today = datetime.date.today()
    if today.month == 12:
        next_month = today.replace(year=today.year + 1, month=1, day=1)
    else:
        next_month = today.replace(month=today.month + 1, day=1)
    return int(datetime.datetime.combine(next_month, datetime.time.min).timestamp())


class SubscriptionPlansTable:
    """Data-access helpers for :class:`SubscriptionPlan` rows."""

    def get_all_plans(self, include_inactive: bool = False) -> List[SubscriptionPlanModel]:
        """Return plans by descending priority; only active ones unless ``include_inactive``."""
        with get_db() as db:
            query = db.query(SubscriptionPlan)
            if not include_inactive:
                query = query.filter(SubscriptionPlan.is_active == True)
            plans = query.order_by(SubscriptionPlan.priority.desc()).all()
            return [SubscriptionPlanModel.model_validate(plan) for plan in plans]

    def get_plan_by_id(self, plan_id: str) -> Optional[SubscriptionPlanModel]:
        """Return the plan with ``plan_id`` (active or not), or None."""
        with get_db() as db:
            plan = db.query(SubscriptionPlan).filter(SubscriptionPlan.id == plan_id).first()
            return SubscriptionPlanModel.model_validate(plan) if plan else None

    def get_default_plan(self) -> Optional[SubscriptionPlanModel]:
        """Return the first plan that is both default and active, or None."""
        with get_db() as db:
            plan = db.query(SubscriptionPlan).filter(
                SubscriptionPlan.is_default == True,
                SubscriptionPlan.is_active == True,
            ).first()
            return SubscriptionPlanModel.model_validate(plan) if plan else None

    def create_plan(self, form_data: SubscriptionPlanForm) -> SubscriptionPlanModel:
        """Insert a new plan and return it.

        If ``form_data.is_default`` is set, every existing plan loses its
        default flag. The unset and the insert run in one transaction, so a
        failed insert (e.g. duplicate id) does not leave the system without a
        default plan.
        """
        now = int(time.time())
        plan = SubscriptionPlanModel(
            id=form_data.id,
            name=form_data.name,
            description=form_data.description,
            price=Decimal(str(form_data.price)),
            model_limits=form_data.model_limits,
            default_model_limit=form_data.default_model_limit,
            priority=form_data.priority,
            is_default=form_data.is_default,
            is_active=form_data.is_active,
            created_at=now,
            updated_at=now,
        )

        with get_db() as db:
            # if this is set as default, unset other defaults (same transaction)
            if form_data.is_default:
                db.query(SubscriptionPlan).filter(
                    SubscriptionPlan.is_default == True
                ).update({"is_default": False})
            db.add(SubscriptionPlan(**plan.model_dump()))
            db.commit()
        return plan

    def update_plan(
        self, plan_id: str, form_data: UpdateSubscriptionPlanForm
    ) -> Optional[SubscriptionPlanModel]:
        """Apply the fields explicitly set in ``form_data``; return the plan or None if missing."""
        with get_db() as db:
            plan = db.query(SubscriptionPlan).filter(SubscriptionPlan.id == plan_id).first()
            if not plan:
                return None

            update_data = form_data.model_dump(exclude_unset=True)
            update_data["updated_at"] = int(time.time())
            if update_data.get("price") is not None:
                # go through str() so the float's repr, not its binary value, is stored
                update_data["price"] = Decimal(str(update_data["price"]))

            # if this is set as default, unset other defaults
            if update_data.get("is_default"):
                db.query(SubscriptionPlan).filter(
                    SubscriptionPlan.id != plan_id,
                    SubscriptionPlan.is_default == True,
                ).update({"is_default": False})

            db.query(SubscriptionPlan).filter(SubscriptionPlan.id == plan_id).update(update_data)
            db.commit()
        return self.get_plan_by_id(plan_id)

    def delete_plan(self, plan_id: str) -> bool:
        """Delete the plan; return True if a row was removed."""
        with get_db() as db:
            result = db.query(SubscriptionPlan).filter(SubscriptionPlan.id == plan_id).delete()
            db.commit()
            return result > 0


SubscriptionPlans = SubscriptionPlansTable()


class UserSubscriptionsTable:
    """Data-access helpers for :class:`UserSubscription` rows."""

    def get_by_user_id(self, user_id: str) -> Optional[UserSubscriptionModel]:
        """Return the user's subscription, or None if they have none."""
        with get_db() as db:
            sub = db.query(UserSubscription).filter(UserSubscription.user_id == user_id).first()
            return UserSubscriptionModel.model_validate(sub) if sub else None

    def get_with_plan(self, user_id: str) -> Optional[UserSubscriptionWithPlanModel]:
        """Return the user's subscription with its plan and derived usage info.

        ``days_remaining`` is whole days until ``expires_at`` (floored, never
        negative) or None when there is no expiry. ``model_usage_info`` covers
        only models listed in the plan's ``model_limits``, and is None when
        there are none.
        """
        sub = self.get_by_user_id(user_id)
        if not sub:
            return None

        plan = SubscriptionPlans.get_plan_by_id(sub.plan_id)
        now = int(time.time())
        days_remaining = None
        if sub.expires_at:
            days_remaining = max(0, (sub.expires_at - now) // 86400)

        # build model usage info
        model_usage_info = []
        if plan and plan.model_limits:
            for model_id, limit in plan.model_limits.items():
                used = sub.get_model_usage(model_id)
                remaining = None if limit == -1 else max(0, limit - used)
                model_usage_info.append(ModelUsageInfo(
                    model_id=model_id,
                    limit=limit,
                    used=used,
                    remaining=remaining,
                ))

        return UserSubscriptionWithPlanModel(
            **sub.model_dump(),
            plan=plan,
            days_remaining=days_remaining,
            model_usage_info=model_usage_info if model_usage_info else None,
        )

    def get_all_subscriptions(
        self, query: Optional[str] = None, offset: int = 0, limit: int = 30
    ) -> Tuple[int, List[UserSubscriptionWithPlanModel]]:
        """Page through subscriptions, most recently updated first.

        Args:
            query: optional substring filter on ``user_id``.
            offset, limit: pagination window.

        Returns:
            ``(total_matching, page)``; each page item carries its plan and
            ``days_remaining``.
        """
        with get_db() as db:
            q = db.query(UserSubscription).order_by(UserSubscription.updated_at.desc())
            if query:
                q = q.filter(UserSubscription.user_id.contains(query))
            total = q.count()
            subs = q.offset(offset).limit(limit).all()

            now = int(time.time())
            # many subscriptions share a plan: look each plan up only once per page
            plan_cache: Dict[str, Optional[SubscriptionPlanModel]] = {}
            results = []
            for sub in subs:
                sub_model = UserSubscriptionModel.model_validate(sub)
                if sub_model.plan_id not in plan_cache:
                    plan_cache[sub_model.plan_id] = SubscriptionPlans.get_plan_by_id(
                        sub_model.plan_id
                    )
                plan = plan_cache[sub_model.plan_id]

                days_remaining = None
                if sub_model.expires_at:
                    days_remaining = max(0, (sub_model.expires_at - now) // 86400)

                results.append(UserSubscriptionWithPlanModel(
                    **sub_model.model_dump(),
                    plan=plan,
                    days_remaining=days_remaining,
                ))
            return total, results

    def init_free_subscription(self, user_id: str) -> UserSubscriptionModel:
        """Create a subscription on the default plan for ``user_id``.

        If no active default plan exists, a "free" plan is created first, with
        every model unlimited. The usage period runs from now to the start of
        next month. Inserting for a user who already has a row violates the
        unique ``user_id`` constraint.
        """
        default_plan = SubscriptionPlans.get_default_plan()
        if not default_plan:
            # create a default free plan if not exists
            # (name: "Free plan"; description: "Free plan, unlimited use of basic models")
            default_plan = SubscriptionPlans.create_plan(SubscriptionPlanForm(
                id="free",
                name="免费版",
                description="免费套餐,基础模型无限使用",
                price=0,
                model_limits={},
                default_model_limit=-1,  # all models unlimited by default
                priority=0,
                is_default=True,
                is_active=True,
            ))

        now = int(time.time())
        sub = UserSubscriptionModel(
            user_id=user_id,
            plan_id=default_plan.id,
            started_at=now,
            expires_at=None,
            current_period_start=now,
            current_period_end=_next_period_end(),
            model_usage={},
            messages_used=0,
            created_at=now,
            updated_at=now,
        )

        with get_db() as db:
            db.add(UserSubscription(**sub.model_dump()))
            db.commit()
        return sub

    def get_or_create_subscription(self, user_id: str) -> UserSubscriptionModel:
        """Get existing subscription or create a new free subscription for user."""
        existing = self.get_by_user_id(user_id)
        if existing:
            return existing
        return self.init_free_subscription(user_id)

    def update_subscription(
        self, user_id: str, form_data: UpdateUserSubscriptionForm
    ) -> Optional[UserSubscriptionModel]:
        """Apply the fields explicitly set in ``form_data``; None if the user has no subscription."""
        with get_db() as db:
            sub = db.query(UserSubscription).filter(UserSubscription.user_id == user_id).first()
            if not sub:
                return None

            update_data = form_data.model_dump(exclude_unset=True)
            update_data["updated_at"] = int(time.time())

            db.query(UserSubscription).filter(UserSubscription.user_id == user_id).update(update_data)
            db.commit()
        return self.get_by_user_id(user_id)

    def upgrade_subscription(
        self,
        user_id: str,
        plan_id: str,
        duration_days: Optional[int] = None,
        expires_at: Optional[int] = None,
    ) -> UserSubscriptionModel:
        """Move ``user_id`` onto ``plan_id``, creating the subscription if needed.

        When ``duration_days`` is given, expiry is extended from the current
        expiry if it is still in the future, otherwise from now. In that case
        ``expires_at`` is ignored. Without ``duration_days``, ``expires_at``
        is used as-is (None = no expiry). Usage counters of an existing
        subscription are not reset.
        """
        now = int(time.time())

        # calculate new expiration
        new_expires_at = expires_at
        if duration_days:
            existing = self.get_by_user_id(user_id)
            if existing and existing.expires_at and existing.expires_at > now:
                # extend from current expiration
                new_expires_at = existing.expires_at + (duration_days * 86400)
            else:
                # start fresh
                new_expires_at = now + (duration_days * 86400)

        with get_db() as db:
            existing = db.query(UserSubscription).filter(UserSubscription.user_id == user_id).first()
            if existing:
                db.query(UserSubscription).filter(UserSubscription.user_id == user_id).update({
                    "plan_id": plan_id,
                    "expires_at": new_expires_at,
                    "updated_at": now,
                })
                db.commit()
            else:
                sub = UserSubscriptionModel(
                    user_id=user_id,
                    plan_id=plan_id,
                    started_at=now,
                    expires_at=new_expires_at,
                    current_period_start=now,
                    current_period_end=_next_period_end(),
                    messages_used=0,
                    created_at=now,
                    updated_at=now,
                )
                db.add(UserSubscription(**sub.model_dump()))
                db.commit()

        return self.get_by_user_id(user_id)

    def downgrade_to_free(self, user_id: str) -> UserSubscriptionModel:
        """Put ``user_id`` back on the default plan with no expiry.

        Raises:
            ValueError: if no active default plan is configured.
        """
        default_plan = SubscriptionPlans.get_default_plan()
        if not default_plan:
            raise ValueError("No default plan configured")

        now = int(time.time())
        with get_db() as db:
            db.query(UserSubscription).filter(UserSubscription.user_id == user_id).update({
                "plan_id": default_plan.id,
                "expires_at": None,
                "updated_at": now,
            })
            db.commit()
        return self.get_by_user_id(user_id)

    def reset_period(self, subscription_id: str) -> UserSubscriptionModel:
        """Start a new usage period (now to start of next month) and zero all counters.

        Raises a pydantic ValidationError if ``subscription_id`` does not exist,
        because None is passed to ``model_validate``.
        """
        now = int(time.time())
        period_end = _next_period_end()

        with get_db() as db:
            db.query(UserSubscription).filter(UserSubscription.id == subscription_id).update({
                "current_period_start": now,
                "current_period_end": period_end,
                "model_usage": {},
                "messages_used": 0,
                "updated_at": now,
            })
            db.commit()
            sub = db.query(UserSubscription).filter(UserSubscription.id == subscription_id).first()
            return UserSubscriptionModel.model_validate(sub)

    def increment_model_usage(self, subscription_id: str, model_id: str) -> None:
        """Increment usage counter for a specific model (and the legacy total).

        Silently does nothing if the subscription does not exist.
        NOTE: the per-model counter is read-modify-write, so concurrent
        increments on the same subscription can lose updates.
        """
        with get_db() as db:
            sub = db.query(UserSubscription).filter(UserSubscription.id == subscription_id).first()
            if not sub:
                return

            current_usage = sub.model_usage or {}
            if isinstance(current_usage, str):
                try:
                    current_usage = json.loads(current_usage)
                except json.JSONDecodeError:
                    current_usage = {}
            # copy, so the ORM-loaded JSON value is not mutated in place;
            # anything that is not a dict (e.g. JSON null) restarts from empty
            current_usage = dict(current_usage) if isinstance(current_usage, dict) else {}

            current_usage[model_id] = current_usage.get(model_id, 0) + 1

            db.query(UserSubscription).filter(UserSubscription.id == subscription_id).update({
                "model_usage": current_usage,
                # SQL-side increment for the legacy counter
                "messages_used": UserSubscription.messages_used + 1,
                "updated_at": int(time.time()),
            })
            db.commit()

    def increment_usage(self, subscription_id: str) -> None:
        """Deprecated: use increment_model_usage instead."""
        with get_db() as db:
            db.query(UserSubscription).filter(UserSubscription.id == subscription_id).update({
                "messages_used": UserSubscription.messages_used + 1,
                "updated_at": int(time.time()),
            })
            db.commit()


UserSubscriptions = UserSubscriptionsTable()


class SubscriptionUsageLogsTable:
    """Data-access helpers for :class:`SubscriptionUsageLog` rows."""

    def insert(
        self,
        user_id: str,
        subscription_id: str,
        plan_id: str,
        model_id: Optional[str] = None,
        chat_id: Optional[str] = None,
        message_id: Optional[str] = None,
    ) -> SubscriptionUsageLogModel:
        """Persist one usage log entry timestamped now and return it."""
        log = SubscriptionUsageLogModel(
            user_id=user_id,
            subscription_id=subscription_id,
            plan_id=plan_id,
            model_id=model_id,
            chat_id=chat_id,
            message_id=message_id,
        )
        with get_db() as db:
            db.add(SubscriptionUsageLog(**log.model_dump()))
            db.commit()
        return log

    def get_by_user_id(
        self, user_id: str, offset: int = 0, limit: int = 30
    ) -> Tuple[int, List[SubscriptionUsageLogModel]]:
        """Return ``(total, page)`` of a user's logs, newest first."""
        with get_db() as db:
            query = db.query(SubscriptionUsageLog).filter(
                SubscriptionUsageLog.user_id == user_id
            ).order_by(SubscriptionUsageLog.created_at.desc())
            total = query.count()
            logs = query.offset(offset).limit(limit).all()
            return total, [SubscriptionUsageLogModel.model_validate(log) for log in logs]

    def get_by_time_range(
        self, start_time: int, end_time: int, user_ids: Optional[List[str]] = None
    ) -> List[SubscriptionUsageLogModel]:
        """Return logs with ``start_time <= created_at < end_time``, oldest first.

        ``user_ids``, when non-empty, restricts the result to those users.
        """
        with get_db() as db:
            query = db.query(SubscriptionUsageLog).filter(
                SubscriptionUsageLog.created_at >= start_time,
                SubscriptionUsageLog.created_at < end_time,
            )
            if user_ids:
                query = query.filter(SubscriptionUsageLog.user_id.in_(user_ids))
            logs = query.order_by(SubscriptionUsageLog.created_at.asc()).all()
            return [SubscriptionUsageLogModel.model_validate(log) for log in logs]


SubscriptionUsageLogs = SubscriptionUsageLogsTable()


class RedemptionCodesTable:
    """Data-access helpers for :class:`RedemptionCode` rows."""

    def get_code(self, code: str) -> Optional[RedemptionCodeModel]:
        """Return the code row, or None if it does not exist."""
        with get_db() as db:
            redemption_code = db.query(RedemptionCode).filter(RedemptionCode.code == code).first()
            return RedemptionCodeModel.model_validate(redemption_code) if redemption_code else None

    def get_codes(
        self, keyword: Optional[str] = None, offset: int = 0, limit: int = 30
    ) -> Tuple[int, List[RedemptionCodeModel]]:
        """Return ``(total, page)`` of codes, newest first.

        ``keyword`` matches the exact code, a substring of purpose, or the
        exact plan id.
        """
        with get_db() as db:
            query = db.query(RedemptionCode).order_by(RedemptionCode.created_at.desc())
            if keyword:
                query = query.filter(
                    (RedemptionCode.code == keyword)
                    | (RedemptionCode.purpose.contains(keyword))
                    | (RedemptionCode.plan_id == keyword)
                )
            total = query.count()
            codes = query.offset(offset).limit(limit).all()
            return total, [RedemptionCodeModel.model_validate(code) for code in codes]

    def insert_codes(self, codes: List[RedemptionCodeModel]) -> None:
        """Bulk-insert codes in one transaction.

        Only declared model fields are persisted. ``RedemptionCodeModel``
        allows extra attributes, and those are not ORM columns, so passing
        them to ``RedemptionCode`` would raise a TypeError.
        """
        columns = set(RedemptionCodeModel.model_fields)
        with get_db() as db:
            db.add_all([RedemptionCode(**code.model_dump(include=columns)) for code in codes])
            db.commit()

    def update_code(
        self, code: str, form_data: UpdateRedemptionCodeForm
    ) -> Optional[RedemptionCodeModel]:
        """Apply the fields explicitly set in ``form_data``; None if the code is missing."""
        with get_db() as db:
            existing = db.query(RedemptionCode).filter(RedemptionCode.code == code).first()
            if not existing:
                return None
            update_data = form_data.model_dump(exclude_unset=True)
            db.query(RedemptionCode).filter(RedemptionCode.code == code).update(update_data)
            db.commit()
        return self.get_code(code)

    def delete_code(self, code: str) -> bool:
        """Delete the code; return True if a row was removed."""
        with get_db() as db:
            result = db.query(RedemptionCode).filter(RedemptionCode.code == code).delete()
            db.commit()
            return result > 0

    def redeem_code(self, code: str, user_id: str) -> UserSubscriptionModel:
        """Consume ``code`` for ``user_id`` and upgrade their subscription.

        Raises:
            HTTPException 404: code does not exist.
            HTTPException 400: code already used, expired, or its plan is
                missing.
        """
        from fastapi import HTTPException

        redemption_code = self.get_code(code)
        if not redemption_code:
            raise HTTPException(status_code=404, detail="Code not found")
        if redemption_code.user_id is not None:
            raise HTTPException(status_code=400, detail="Code already used")

        now = int(time.time())
        if redemption_code.expired_at and redemption_code.expired_at < now:
            raise HTTPException(status_code=400, detail="Code expired")

        # validate the plan before consuming the code; otherwise the code would
        # be burned and the upgrade would fail or point at a nonexistent plan
        if not redemption_code.plan_id or SubscriptionPlans.get_plan_by_id(
            redemption_code.plan_id
        ) is None:
            raise HTTPException(status_code=400, detail="Plan not found")

        # mark code as used; the "user_id IS NULL" guard makes the claim atomic,
        # so two concurrent redemptions of the same code cannot both succeed
        with get_db() as db:
            claimed = db.query(RedemptionCode).filter(
                RedemptionCode.code == code,
                RedemptionCode.user_id.is_(None),
            ).update({
                "user_id": user_id,
                "received_at": now,
            })
            db.commit()
        if not claimed:
            raise HTTPException(status_code=400, detail="Code already used")

        # upgrade user subscription
        if redemption_code.redemption_type == "duration":
            return UserSubscriptions.upgrade_subscription(
                user_id=user_id,
                plan_id=redemption_code.plan_id,
                duration_days=redemption_code.duration_days,
            )
        else:
            # upgrade type: expiry is extended by upgrade_days
            return UserSubscriptions.upgrade_subscription(
                user_id=user_id,
                plan_id=redemption_code.plan_id,
                duration_days=redemption_code.upgrade_days,
            )


RedemptionCodes = RedemptionCodesTable()