import datetime
import time
import uuid
from decimal import Decimal
from typing import Dict, List, Optional, Tuple

from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy import JSON, BigInteger, Boolean, Column, Integer, Numeric, String, Text

from open_webui.internal.db import Base, get_db

# Seconds in one day; used for duration and remaining-days arithmetic.
_SECONDS_PER_DAY = 86400


def _next_month_start_timestamp() -> int:
    """Return the Unix timestamp of 00:00 on the first day of next month.

    The date comes from ``datetime.date.today()`` and the timestamp is built
    from a naive ``datetime``, so it is interpreted in the server's local
    timezone. This value is used as the end of the current usage period.
    """
    today = datetime.date.today()
    if today.month == 12:
        next_month = today.replace(year=today.year + 1, month=1, day=1)
    else:
        next_month = today.replace(month=today.month + 1, day=1)
    return int(datetime.datetime.combine(next_month, datetime.time.min).timestamp())


####################
# Subscription Plan DB Schema
####################


class SubscriptionPlan(Base):
    """ORM row for a subscription plan (table ``subscription_plan``)."""

    __tablename__ = "subscription_plan"

    id = Column(String, primary_key=True)
    name = Column(String, nullable=False)
    description = Column(Text, nullable=True)
    price = Column(Numeric(precision=10, scale=2), default=0)
    # NULL means no limit is configured.
    monthly_message_limit = Column(Integer, nullable=True)
    allowed_models = Column(JSON, nullable=True)
    # Plans are listed in descending priority order.
    priority = Column(Integer, default=0)
    is_default = Column(Boolean, default=False)
    is_active = Column(Boolean, default=True)
    created_at = Column(BigInteger)
    updated_at = Column(BigInteger)


class UserSubscription(Base):
    """ORM row linking a user to a plan (table ``user_subscription``).

    ``user_id`` is unique, so each user has at most one subscription row.
    """

    __tablename__ = "user_subscription"

    id = Column(String, primary_key=True)
    user_id = Column(String, index=True, nullable=False, unique=True)
    plan_id = Column(String, nullable=False)
    started_at = Column(BigInteger, nullable=False)
    # NULL means the subscription does not expire.
    expires_at = Column(BigInteger, nullable=True)
    current_period_start = Column(BigInteger)
    current_period_end = Column(BigInteger)
    # Messages consumed in the current period; reset by reset_period().
    messages_used = Column(Integer, default=0)
    created_at = Column(BigInteger)
    updated_at = Column(BigInteger)


class SubscriptionUsageLog(Base):
    """ORM row recording one usage event (table ``subscription_usage_log``)."""

    __tablename__ = "subscription_usage_log"

    id = Column(String, primary_key=True)
    user_id = Column(String, index=True, nullable=False)
    subscription_id = Column(String, nullable=False)
    plan_id = Column(String, nullable=False)
    model_id = Column(String, nullable=True)
    chat_id = Column(String, nullable=True)
    message_id = Column(String, nullable=True)
    created_at = Column(BigInteger, index=True)


class RedemptionCode(Base):
    """ORM row for a redeemable code (table ``redemption_code``)."""

    __tablename__ = "redemption_code"

    code = Column(String, primary_key=True)
    purpose = Column(String, index=True)
    # "duration" codes grant ``duration_days``; any other type sets the
    # subscription expiry to ``upgrade_expires_at`` (see redeem_code).
    redemption_type = Column(String, default="duration")
    plan_id = Column(String, nullable=True)
    duration_days = Column(Integer, nullable=True)
    upgrade_expires_at = Column(BigInteger, nullable=True)
    # Set to the redeeming user when the code is consumed; NULL means unused.
    user_id = Column(String, index=True, nullable=True)
    created_at = Column(BigInteger, index=True)
    # Time after which the code can no longer be redeemed (NULL = never).
    expired_at = Column(BigInteger, index=True, nullable=True)
    # Time the code was redeemed.
    received_at = Column(BigInteger, index=True, nullable=True)


####################
# Pydantic Models
####################


class SubscriptionPlanModel(BaseModel):
    """API/read model mirroring ``SubscriptionPlan``."""

    model_config = ConfigDict(from_attributes=True)

    id: str
    name: str
    description: Optional[str] = None
    price: Decimal = Field(default_factory=lambda: Decimal("0"))
    monthly_message_limit: Optional[int] = None
    allowed_models: Optional[List[str]] = None
    priority: int = 0
    is_default: bool = False
    is_active: bool = True
    created_at: int = Field(default_factory=lambda: int(time.time()))
    updated_at: int = Field(default_factory=lambda: int(time.time()))


class SubscriptionPlanForm(BaseModel):
    """Input form for creating a plan."""

    id: str
    name: str
    description: Optional[str] = None
    price: float = 0
    monthly_message_limit: Optional[int] = None
    allowed_models: Optional[List[str]] = None
    priority: int = 0
    is_default: bool = False
    is_active: bool = True


class UpdateSubscriptionPlanForm(BaseModel):
    """Partial-update form for a plan; only explicitly set fields are applied."""

    name: Optional[str] = None
    description: Optional[str] = None
    price: Optional[float] = None
    monthly_message_limit: Optional[int] = None
    allowed_models: Optional[List[str]] = None
    priority: Optional[int] = None
    is_default: Optional[bool] = None
    is_active: Optional[bool] = None


class UserSubscriptionModel(BaseModel):
    """API/read model mirroring ``UserSubscription``."""

    model_config = ConfigDict(from_attributes=True)

    id: str = Field(default_factory=lambda: uuid.uuid4().hex)
    user_id: str
    plan_id: str
    started_at: int
    expires_at: Optional[int] = None
    current_period_start: int
    current_period_end: int
    messages_used: int = 0
    created_at: int = Field(default_factory=lambda: int(time.time()))
    updated_at: int = Field(default_factory=lambda: int(time.time()))


class UserSubscriptionWithPlanModel(UserSubscriptionModel):
    """Subscription enriched with its plan and derived remaining quotas."""

    plan: Optional[SubscriptionPlanModel] = None
    messages_remaining: Optional[int] = None
    days_remaining: Optional[int] = None


class SubscriptionUsageLogModel(BaseModel):
    """API/read model mirroring ``SubscriptionUsageLog``."""

    model_config = ConfigDict(from_attributes=True)

    id: str = Field(default_factory=lambda: uuid.uuid4().hex)
    user_id: str
    subscription_id: str
    plan_id: str
    model_id: Optional[str] = None
    chat_id: Optional[str] = None
    message_id: Optional[str] = None
    created_at: int = Field(default_factory=lambda: int(time.time()))


class RedemptionCodeModel(BaseModel):
    """API/read model mirroring ``RedemptionCode`` (extra attributes allowed)."""

    model_config = ConfigDict(from_attributes=True, extra="allow")

    code: str
    purpose: str
    redemption_type: str = "duration"
    plan_id: Optional[str] = None
    duration_days: Optional[int] = None
    upgrade_expires_at: Optional[int] = None
    user_id: Optional[str] = None
    created_at: int
    expired_at: Optional[int] = None
    received_at: Optional[int] = None


class CreateRedemptionCodeForm(BaseModel):
    """Input form for generating ``count`` redemption codes."""

    purpose: str = Field(min_length=1, max_length=255)
    count: int = Field(ge=1, le=1000)
    redemption_type: str = "duration"
    plan_id: str
    duration_days: Optional[int] = Field(default=None, ge=1)
    upgrade_expires_at: Optional[int] = Field(default=None, gt=0)
    expired_at: Optional[int] = Field(default=None, gt=0)


class UpdateRedemptionCodeForm(BaseModel):
    """Partial-update form for a redemption code."""

    purpose: Optional[str] = Field(None, min_length=1, max_length=255)
    plan_id: Optional[str] = None
    duration_days: Optional[int] = Field(None, ge=1)
    upgrade_expires_at: Optional[int] = Field(None, gt=0)
    expired_at: Optional[int] = Field(None, gt=0)


class UpdateUserSubscriptionForm(BaseModel):
    """Partial-update form for a user's subscription (admin use)."""

    plan_id: Optional[str] = None
    expires_at: Optional[int] = None
    messages_used: Optional[int] = None


####################
# Table Classes
####################


class SubscriptionPlansTable:
    """Data-access helpers for ``subscription_plan``."""

    def get_all_plans(self, include_inactive: bool = False) -> List[SubscriptionPlanModel]:
        """Return plans ordered by descending priority.

        Args:
            include_inactive: When False (default), only plans with
                ``is_active == True`` are returned.
        """
        with get_db() as db:
            query = db.query(SubscriptionPlan)
            if not include_inactive:
                query = query.filter(SubscriptionPlan.is_active == True)
            plans = query.order_by(SubscriptionPlan.priority.desc()).all()
            return [SubscriptionPlanModel.model_validate(plan) for plan in plans]

    def get_plan_by_id(self, plan_id: str) -> Optional[SubscriptionPlanModel]:
        """Return the plan with ``plan_id``, or None if it does not exist."""
        with get_db() as db:
            plan = db.query(SubscriptionPlan).filter(SubscriptionPlan.id == plan_id).first()
            return SubscriptionPlanModel.model_validate(plan) if plan else None

    def get_default_plan(self) -> Optional[SubscriptionPlanModel]:
        """Return the first plan that is both default and active, or None."""
        with get_db() as db:
            plan = (
                db.query(SubscriptionPlan)
                .filter(SubscriptionPlan.is_default == True, SubscriptionPlan.is_active == True)
                .first()
            )
            return SubscriptionPlanModel.model_validate(plan) if plan else None

    def create_plan(self, form_data: SubscriptionPlanForm) -> SubscriptionPlanModel:
        """Insert a new plan and return it.

        If the new plan is marked default, every other plan's ``is_default``
        flag is cleared in the same transaction as the insert, so a failed
        insert does not leave the system without a default plan.
        """
        now = int(time.time())
        plan = SubscriptionPlanModel(
            id=form_data.id,
            name=form_data.name,
            description=form_data.description,
            # str() first so the float's shortest repr is used, not its binary expansion.
            price=Decimal(str(form_data.price)),
            monthly_message_limit=form_data.monthly_message_limit,
            allowed_models=form_data.allowed_models,
            priority=form_data.priority,
            is_default=form_data.is_default,
            is_active=form_data.is_active,
            created_at=now,
            updated_at=now,
        )

        with get_db() as db:
            if form_data.is_default:
                db.query(SubscriptionPlan).filter(
                    SubscriptionPlan.is_default == True
                ).update({"is_default": False})
            db.add(SubscriptionPlan(**plan.model_dump()))
            db.commit()
        return plan

    def update_plan(
        self, plan_id: str, form_data: UpdateSubscriptionPlanForm
    ) -> Optional[SubscriptionPlanModel]:
        """Apply the explicitly set fields of ``form_data`` to a plan.

        Explicit ``None`` is honoured only for fields where NULL is
        meaningful (description, message limit, allowed models). For the
        other fields it is ignored, because NULL would either violate a NOT
        NULL column (``name``) or make the plan invisible to the boolean
        filters used elsewhere (``is_active``/``is_default``).

        Returns:
            The updated plan, or None if ``plan_id`` does not exist.
        """
        nullable_fields = {"description", "monthly_message_limit", "allowed_models"}

        with get_db() as db:
            plan = db.query(SubscriptionPlan).filter(SubscriptionPlan.id == plan_id).first()
            if not plan:
                return None

            update_data = {
                key: value
                for key, value in form_data.model_dump(exclude_unset=True).items()
                if value is not None or key in nullable_fields
            }
            update_data["updated_at"] = int(time.time())
            if update_data.get("price") is not None:
                update_data["price"] = Decimal(str(update_data["price"]))

            # Only one plan may be the default: clear the flag elsewhere.
            if update_data.get("is_default"):
                db.query(SubscriptionPlan).filter(
                    SubscriptionPlan.id != plan_id, SubscriptionPlan.is_default == True
                ).update({"is_default": False})

            db.query(SubscriptionPlan).filter(SubscriptionPlan.id == plan_id).update(update_data)
            db.commit()

        return self.get_plan_by_id(plan_id)

    def delete_plan(self, plan_id: str) -> bool:
        """Delete a plan; return True if a row was removed.

        NOTE(review): subscriptions still referencing ``plan_id`` are not
        touched and will resolve to ``plan=None`` afterwards.
        """
        with get_db() as db:
            result = db.query(SubscriptionPlan).filter(SubscriptionPlan.id == plan_id).delete()
            db.commit()
            return result > 0


SubscriptionPlans = SubscriptionPlansTable()


class UserSubscriptionsTable:
    """Data-access helpers for ``user_subscription``."""

    @staticmethod
    def _build_with_plan(
        sub: UserSubscriptionModel, plan: Optional[SubscriptionPlanModel], now: int
    ) -> UserSubscriptionWithPlanModel:
        """Attach ``plan`` to ``sub`` and compute remaining messages/days at ``now``."""
        messages_remaining = None
        # NOTE(review): a limit of 0 is falsy and is therefore reported like an
        # unset (None) limit; confirm this matches enforcement elsewhere.
        if plan and plan.monthly_message_limit:
            messages_remaining = max(0, plan.monthly_message_limit - sub.messages_used)

        days_remaining = None
        if sub.expires_at:
            days_remaining = max(0, (sub.expires_at - now) // _SECONDS_PER_DAY)

        return UserSubscriptionWithPlanModel(
            **sub.model_dump(),
            plan=plan,
            messages_remaining=messages_remaining,
            days_remaining=days_remaining,
        )

    def get_by_user_id(self, user_id: str) -> Optional[UserSubscriptionModel]:
        """Return the user's subscription, or None if they have none."""
        with get_db() as db:
            sub = db.query(UserSubscription).filter(UserSubscription.user_id == user_id).first()
            return UserSubscriptionModel.model_validate(sub) if sub else None

    def get_with_plan(self, user_id: str) -> Optional[UserSubscriptionWithPlanModel]:
        """Return the user's subscription enriched with plan and quotas, or None."""
        sub = self.get_by_user_id(user_id)
        if not sub:
            return None
        plan = SubscriptionPlans.get_plan_by_id(sub.plan_id)
        return self._build_with_plan(sub, plan, int(time.time()))

    def get_all_subscriptions(
        self, query: Optional[str] = None, offset: int = 0, limit: int = 30
    ) -> Tuple[int, List[UserSubscriptionWithPlanModel]]:
        """Return a page of subscriptions, most recently updated first.

        Args:
            query: Optional substring matched against ``user_id``.
            offset: Number of rows to skip.
            limit: Maximum number of rows to return.

        Returns:
            ``(total_matching, page)`` where each page item carries its plan
            and derived quotas.
        """
        with get_db() as db:
            q = db.query(UserSubscription).order_by(UserSubscription.updated_at.desc())
            if query:
                q = q.filter(UserSubscription.user_id.contains(query))
            total = q.count()
            subs = [
                UserSubscriptionModel.model_validate(sub)
                for sub in q.offset(offset).limit(limit).all()
            ]

        now = int(time.time())
        # Many subscriptions share a plan; look each plan up only once per call.
        plan_cache: Dict[str, Optional[SubscriptionPlanModel]] = {}
        results = []
        for sub in subs:
            if sub.plan_id not in plan_cache:
                plan_cache[sub.plan_id] = SubscriptionPlans.get_plan_by_id(sub.plan_id)
            results.append(self._build_with_plan(sub, plan_cache[sub.plan_id], now))
        return total, results

    def init_free_subscription(self, user_id: str) -> UserSubscriptionModel:
        """Create a subscription on the default plan for ``user_id``.

        If no active default plan exists, a plan with id ``"free"`` (100
        messages, default, active) is created first. The first usage period
        ends at the start of next month.
        """
        default_plan = SubscriptionPlans.get_default_plan()
        if not default_plan:
            default_plan = SubscriptionPlans.create_plan(
                SubscriptionPlanForm(
                    id="free",
                    name="Free",
                    description="Default free plan",
                    price=0,
                    monthly_message_limit=100,
                    allowed_models=[],
                    priority=0,
                    is_default=True,
                    is_active=True,
                )
            )

        now = int(time.time())
        sub = UserSubscriptionModel(
            user_id=user_id,
            plan_id=default_plan.id,
            started_at=now,
            expires_at=None,
            current_period_start=now,
            current_period_end=_next_month_start_timestamp(),
            messages_used=0,
            created_at=now,
            updated_at=now,
        )
        with get_db() as db:
            db.add(UserSubscription(**sub.model_dump()))
            db.commit()
        return sub

    def get_or_create_subscription(self, user_id: str) -> UserSubscriptionModel:
        """Get existing subscription or create a new free subscription for user."""
        existing = self.get_by_user_id(user_id)
        if existing:
            return existing
        return self.init_free_subscription(user_id)

    def update_subscription(
        self, user_id: str, form_data: UpdateUserSubscriptionForm
    ) -> Optional[UserSubscriptionModel]:
        """Apply the explicitly set fields of ``form_data`` to a user's subscription.

        Explicit ``None`` is honoured only for ``expires_at`` (meaning "no
        expiry"). For ``plan_id`` (NOT NULL column) and ``messages_used``
        (a counter that is incremented in SQL) it is ignored.

        Returns:
            The updated subscription, or None if the user has none.
        """
        nullable_fields = {"expires_at"}

        with get_db() as db:
            sub = db.query(UserSubscription).filter(UserSubscription.user_id == user_id).first()
            if not sub:
                return None

            update_data = {
                key: value
                for key, value in form_data.model_dump(exclude_unset=True).items()
                if value is not None or key in nullable_fields
            }
            update_data["updated_at"] = int(time.time())

            db.query(UserSubscription).filter(UserSubscription.user_id == user_id).update(update_data)
            db.commit()

        return self.get_by_user_id(user_id)

    def upgrade_subscription(
        self,
        user_id: str,
        plan_id: str,
        duration_days: Optional[int] = None,
        expires_at: Optional[int] = None,
    ) -> UserSubscriptionModel:
        """Move a user onto ``plan_id``, creating the subscription if needed.

        Expiry rules:
            * ``duration_days`` given: extend from the current expiry if it is
              still in the future, otherwise from now.
            * otherwise: use ``expires_at`` as-is (None = no expiry).

        Usage counters and the current period are left unchanged for an
        existing subscription.
        """
        now = int(time.time())

        new_expires_at = expires_at
        if duration_days:
            existing = self.get_by_user_id(user_id)
            if existing and existing.expires_at and existing.expires_at > now:
                new_expires_at = existing.expires_at + duration_days * _SECONDS_PER_DAY
            else:
                new_expires_at = now + duration_days * _SECONDS_PER_DAY

        with get_db() as db:
            existing_row = (
                db.query(UserSubscription).filter(UserSubscription.user_id == user_id).first()
            )
            if existing_row:
                db.query(UserSubscription).filter(UserSubscription.user_id == user_id).update(
                    {
                        "plan_id": plan_id,
                        "expires_at": new_expires_at,
                        "updated_at": now,
                    }
                )
            else:
                sub = UserSubscriptionModel(
                    user_id=user_id,
                    plan_id=plan_id,
                    started_at=now,
                    expires_at=new_expires_at,
                    current_period_start=now,
                    current_period_end=_next_month_start_timestamp(),
                    messages_used=0,
                    created_at=now,
                    updated_at=now,
                )
                db.add(UserSubscription(**sub.model_dump()))
            db.commit()

        return self.get_by_user_id(user_id)

    def downgrade_to_free(self, user_id: str) -> Optional[UserSubscriptionModel]:
        """Switch the user to the default plan with no expiry.

        Raises:
            ValueError: If no active default plan is configured.

        Returns:
            The updated subscription, or None if the user has no subscription
            row (no row is created in that case).
        """
        default_plan = SubscriptionPlans.get_default_plan()
        if not default_plan:
            raise ValueError("No default plan configured")

        now = int(time.time())
        with get_db() as db:
            db.query(UserSubscription).filter(UserSubscription.user_id == user_id).update(
                {
                    "plan_id": default_plan.id,
                    "expires_at": None,
                    "updated_at": now,
                }
            )
            db.commit()

        return self.get_by_user_id(user_id)

    def reset_period(self, subscription_id: str) -> UserSubscriptionModel:
        """Start a new usage period (now → start of next month) and zero usage.

        Raises:
            ValueError: If no subscription has ``subscription_id``.
        """
        now = int(time.time())
        period_end = _next_month_start_timestamp()

        with get_db() as db:
            db.query(UserSubscription).filter(UserSubscription.id == subscription_id).update(
                {
                    "current_period_start": now,
                    "current_period_end": period_end,
                    "messages_used": 0,
                    "updated_at": now,
                }
            )
            db.commit()
            sub = db.query(UserSubscription).filter(UserSubscription.id == subscription_id).first()
            if sub is None:
                raise ValueError(f"Subscription not found: {subscription_id}")
            return UserSubscriptionModel.model_validate(sub)

    def increment_usage(self, subscription_id: str) -> None:
        """Add one to ``messages_used``.

        The increment is an SQL expression evaluated by the database, not a
        Python read-modify-write.
        """
        with get_db() as db:
            db.query(UserSubscription).filter(UserSubscription.id == subscription_id).update(
                {
                    "messages_used": UserSubscription.messages_used + 1,
                    "updated_at": int(time.time()),
                }
            )
            db.commit()


UserSubscriptions = UserSubscriptionsTable()


class SubscriptionUsageLogsTable:
    """Data-access helpers for ``subscription_usage_log``."""

    def insert(
        self,
        user_id: str,
        subscription_id: str,
        plan_id: str,
        model_id: Optional[str] = None,
        chat_id: Optional[str] = None,
        message_id: Optional[str] = None,
    ) -> SubscriptionUsageLogModel:
        """Record one usage event and return it (id and timestamp auto-filled)."""
        log = SubscriptionUsageLogModel(
            user_id=user_id,
            subscription_id=subscription_id,
            plan_id=plan_id,
            model_id=model_id,
            chat_id=chat_id,
            message_id=message_id,
        )
        with get_db() as db:
            db.add(SubscriptionUsageLog(**log.model_dump()))
            db.commit()
        return log

    def get_by_user_id(
        self, user_id: str, offset: int = 0, limit: int = 30
    ) -> Tuple[int, List[SubscriptionUsageLogModel]]:
        """Return ``(total, page)`` of a user's usage logs, newest first."""
        with get_db() as db:
            query = (
                db.query(SubscriptionUsageLog)
                .filter(SubscriptionUsageLog.user_id == user_id)
                .order_by(SubscriptionUsageLog.created_at.desc())
            )
            total = query.count()
            logs = query.offset(offset).limit(limit).all()
            return total, [SubscriptionUsageLogModel.model_validate(log) for log in logs]

    def get_by_time_range(
        self, start_time: int, end_time: int, user_ids: Optional[List[str]] = None
    ) -> List[SubscriptionUsageLogModel]:
        """Return logs with ``start_time <= created_at < end_time``, oldest first.

        ``user_ids`` restricts the result to those users; None or an empty
        list applies no user filter.
        """
        with get_db() as db:
            query = db.query(SubscriptionUsageLog).filter(
                SubscriptionUsageLog.created_at >= start_time,
                SubscriptionUsageLog.created_at < end_time,
            )
            if user_ids:
                query = query.filter(SubscriptionUsageLog.user_id.in_(user_ids))
            logs = query.order_by(SubscriptionUsageLog.created_at.asc()).all()
            return [SubscriptionUsageLogModel.model_validate(log) for log in logs]


SubscriptionUsageLogs = SubscriptionUsageLogsTable()


class RedemptionCodesTable:
    """Data-access helpers for ``redemption_code``."""

    def get_code(self, code: str) -> Optional[RedemptionCodeModel]:
        """Return the code record, or None if it does not exist."""
        with get_db() as db:
            redemption_code = db.query(RedemptionCode).filter(RedemptionCode.code == code).first()
            return RedemptionCodeModel.model_validate(redemption_code) if redemption_code else None

    def get_codes(
        self, keyword: Optional[str] = None, offset: int = 0, limit: int = 30
    ) -> Tuple[int, List[RedemptionCodeModel]]:
        """Return ``(total, page)`` of codes, newest first.

        ``keyword`` matches the exact code, a substring of ``purpose``, or
        the exact ``plan_id``.
        """
        with get_db() as db:
            query = db.query(RedemptionCode).order_by(RedemptionCode.created_at.desc())
            if keyword:
                query = query.filter(
                    (RedemptionCode.code == keyword)
                    | (RedemptionCode.purpose.contains(keyword))
                    | (RedemptionCode.plan_id == keyword)
                )
            total = query.count()
            codes = query.offset(offset).limit(limit).all()
            return total, [RedemptionCodeModel.model_validate(code) for code in codes]

    def insert_codes(self, codes: List[RedemptionCodeModel]) -> None:
        """Bulk-insert codes in a single transaction.

        Only the model's declared fields are passed to the ORM. The model
        allows extra attributes (``extra="allow"``), and forwarding those
        would raise TypeError in the ORM constructor.
        """
        declared_fields = set(RedemptionCodeModel.model_fields)
        with get_db() as db:
            db.add_all(
                [RedemptionCode(**code.model_dump(include=declared_fields)) for code in codes]
            )
            db.commit()

    def update_code(
        self, code: str, form_data: UpdateRedemptionCodeForm
    ) -> Optional[RedemptionCodeModel]:
        """Apply the explicitly set fields of ``form_data``; None if code is unknown."""
        with get_db() as db:
            existing = db.query(RedemptionCode).filter(RedemptionCode.code == code).first()
            if not existing:
                return None

            update_data = form_data.model_dump(exclude_unset=True)
            db.query(RedemptionCode).filter(RedemptionCode.code == code).update(update_data)
            db.commit()

        return self.get_code(code)

    def delete_code(self, code: str) -> bool:
        """Delete a code; return True if a row was removed."""
        with get_db() as db:
            result = db.query(RedemptionCode).filter(RedemptionCode.code == code).delete()
            db.commit()
            return result > 0

    def redeem_code(self, code: str, user_id: str) -> UserSubscriptionModel:
        """Consume ``code`` for ``user_id`` and upgrade their subscription.

        The code is claimed with a conditional UPDATE (``user_id IS NULL``),
        so two concurrent redemptions of the same code cannot both succeed.
        The target plan is validated before the code is consumed, so a
        misconfigured code is not burned.

        Raises:
            HTTPException: 404 if the code does not exist; 400 if it is
                already used, expired, or not linked to an existing plan.
        """
        from fastapi import HTTPException

        redemption_code = self.get_code(code)
        if not redemption_code:
            raise HTTPException(status_code=404, detail="Code not found")
        if redemption_code.user_id is not None:
            raise HTTPException(status_code=400, detail="Code already used")

        now = int(time.time())
        if redemption_code.expired_at and redemption_code.expired_at < now:
            raise HTTPException(status_code=400, detail="Code expired")

        # user_subscription.plan_id is NOT NULL; a missing or deleted plan
        # would otherwise fail (or dangle) only after the code was consumed.
        if (
            not redemption_code.plan_id
            or SubscriptionPlans.get_plan_by_id(redemption_code.plan_id) is None
        ):
            raise HTTPException(status_code=400, detail="Code is not linked to a valid plan")

        # Atomically claim the code: only succeeds if it is still unused.
        with get_db() as db:
            claimed = (
                db.query(RedemptionCode)
                .filter(RedemptionCode.code == code, RedemptionCode.user_id.is_(None))
                .update({"user_id": user_id, "received_at": now})
            )
            db.commit()
        if not claimed:
            raise HTTPException(status_code=400, detail="Code already used")

        if redemption_code.redemption_type == "duration":
            return UserSubscriptions.upgrade_subscription(
                user_id=user_id,
                plan_id=redemption_code.plan_id,
                duration_days=redemption_code.duration_days,
            )
        return UserSubscriptions.upgrade_subscription(
            user_id=user_id,
            plan_id=redemption_code.plan_id,
            expires_at=redemption_code.upgrade_expires_at,
        )


RedemptionCodes = RedemptionCodesTable()