import time
from typing import Optional

from fastapi import HTTPException

from open_webui.models.subscriptions import (
    SubscriptionPlans,
    UserSubscriptions,
    SubscriptionUsageLogs,
)


def _require_subscription(subscription):
    """Return *subscription*, or raise HTTP 500 if it is missing."""
    # NOTE(review): assumes the UserSubscriptions helpers may hand back None
    # on failure -- confirm against the model layer. If they never do, this
    # guard is a no-op; if they do, it replaces an opaque AttributeError with
    # an explicit error response.
    if not subscription:
        raise HTTPException(
            status_code=500,
            detail="Subscription could not be loaded",
        )
    return subscription


async def check_subscription_access(
    user_id: str,
    model_id: str,
    chat_id: Optional[str] = None,
    message_id: Optional[str] = None,
) -> None:
    """
    Check if user has access to use the specified model.

    Before checking, the user's subscription is brought up to date:
    a missing subscription is initialized as a free one, an expired
    subscription is downgraded to free (this does NOT raise by itself),
    and the billing period is reset once its end has been reached.

    ``chat_id`` and ``message_id`` are accepted but not used by this check.

    Raises HTTPException if:
    1. The subscription or its plan cannot be loaded (500)
    2. Model is not in the plan's allowed list (403)
    3. Monthly message limit reached (403)
    """
    now = int(time.time())

    # get or initialize subscription
    subscription = UserSubscriptions.get_by_user_id(user_id)
    if not subscription:
        subscription = _require_subscription(
            UserSubscriptions.init_free_subscription(user_id)
        )

    # check if expired and downgrade
    if subscription.expires_at and subscription.expires_at < now:
        subscription = _require_subscription(
            UserSubscriptions.downgrade_to_free(user_id)
        )

    # check/reset billing period
    if now >= subscription.current_period_end:
        subscription = _require_subscription(
            UserSubscriptions.reset_period(subscription.id)
        )

    # get plan info
    plan = SubscriptionPlans.get_plan_by_id(subscription.plan_id)
    if not plan:
        raise HTTPException(
            status_code=500,
            detail="Subscription plan not found",
        )

    # check model access
    # NOTE(review): an empty allowed_models list is treated as "no
    # restriction" here (truthiness test), whereas
    # filter_models_by_subscription treats an empty list as "no models
    # allowed" (it only short-circuits on None). Confirm which is intended.
    if plan.allowed_models and model_id not in plan.allowed_models:
        raise HTTPException(
            status_code=403,
            detail=(
                f"Your {plan.name} plan does not include access to this model. "
                "Please upgrade your subscription."
            ),
        )

    # check usage limit (a limit of None means unlimited)
    if plan.monthly_message_limit is not None:
        if subscription.messages_used >= plan.monthly_message_limit:
            raise HTTPException(
                status_code=403,
                detail=(
                    "You have reached your monthly message limit "
                    f"({plan.monthly_message_limit}). "
                    "Please wait until next month or upgrade your subscription."
                ),
            )


def record_usage(
    user_id: str,
    model_id: Optional[str] = None,
    chat_id: Optional[str] = None,
    message_id: Optional[str] = None,
) -> None:
    """
    Record a message usage for the user.
    Should be called after a successful chat completion.

    Does nothing when the user has no subscription record. Otherwise the
    subscription's usage counter is incremented and a usage log entry is
    inserted with the given model, chat and message identifiers.
    """
    subscription = UserSubscriptions.get_by_user_id(user_id)
    if not subscription:
        return

    # increment usage counter
    UserSubscriptions.increment_usage(subscription.id)

    # log usage
    SubscriptionUsageLogs.insert(
        user_id=user_id,
        subscription_id=subscription.id,
        plan_id=subscription.plan_id,
        model_id=model_id,
        chat_id=chat_id,
        message_id=message_id,
    )


def get_user_allowed_models(user_id: str) -> Optional[list[str]]:
    """
    Get the list of models allowed for a user based on their subscription.

    A user without a subscription record is evaluated against the default
    plan. Returns None if all models are allowed; None is also returned
    when the relevant plan cannot be found.
    """
    subscription = UserSubscriptions.get_by_user_id(user_id)
    if not subscription:
        default_plan = SubscriptionPlans.get_default_plan()
        return default_plan.allowed_models if default_plan else None

    plan = SubscriptionPlans.get_plan_by_id(subscription.plan_id)
    return plan.allowed_models if plan else None


def filter_models_by_subscription(user_id: str, models: list) -> list:
    """
    Filter a list of models based on user's subscription.

    Each model is looked up by its ``"id"`` key via ``.get``. When the
    user's allowed list is None the input list is returned unchanged;
    otherwise only models whose id is in the allowed list are kept
    (so an empty allowed list yields an empty result).
    """
    allowed_models = get_user_allowed_models(user_id)
    if allowed_models is None:
        return models

    # Build the lookup set once instead of scanning the list per model.
    allowed_ids = set(allowed_models)
    return [m for m in models if m.get("id") in allowed_ids]