import time
from typing import Optional

from fastapi import HTTPException

from open_webui.models.subscriptions import (
    SubscriptionPlans,
    UserSubscriptions,
    SubscriptionUsageLogs,
)


def _period_ended(subscription, now: int) -> bool:
    """Return True if ``subscription``'s current billing period has ended at ``now``.

    Shared by the access check (which resets the period) and the read-only
    remaining-quota query, so both agree on when usage rolls over.
    """
    return now >= subscription.current_period_end


def _get_effective_plan(user_id: str) -> tuple:
    """Return ``(subscription, plan)`` for ``user_id`` without modifying anything.

    If the user has no subscription, ``subscription`` is None and ``plan`` is
    the default plan (which may itself be None). Otherwise ``plan`` is looked
    up by the subscription's ``plan_id`` and may be None if it is missing.
    """
    subscription = UserSubscriptions.get_by_user_id(user_id)
    if not subscription:
        return None, SubscriptionPlans.get_default_plan()
    return subscription, SubscriptionPlans.get_plan_by_id(subscription.plan_id)


def check_subscription_access(
    user_id: str,
    model_id: str,
    chat_id: Optional[str] = None,
    message_id: Optional[str] = None,
) -> None:
    """Ensure ``user_id`` is allowed to make one more request to ``model_id``.

    Before checking limits, this function updates the subscription through
    ``UserSubscriptions``:
      * a free subscription is initialized if the user has none;
      * an expired subscription (``expires_at`` in the past) is downgraded
        to free;
      * the billing period is reset once ``current_period_end`` is reached.

    Limit semantics (from the plan's ``get_model_limit``): 0 means the model
    is not allowed, a positive value is the per-period cap, and any negative
    value (conventionally -1) means unlimited.

    Args:
        user_id: User whose subscription is checked.
        model_id: Model the user wants to use.
        chat_id: Currently unused; accepted for call-site symmetry with
            ``record_usage``.
        message_id: Currently unused; see ``chat_id``.

    Raises:
        HTTPException: 500 if the subscription's plan cannot be found;
            403 if the model is not allowed on the plan, or if the usage for
            this period has reached the plan's positive limit.
    """
    now = int(time.time())

    # Get or lazily initialize the user's subscription.
    subscription = UserSubscriptions.get_by_user_id(user_id)
    if not subscription:
        subscription = UserSubscriptions.init_free_subscription(user_id)

    # An expired paid subscription falls back to the free plan.
    if subscription.expires_at and subscription.expires_at < now:
        subscription = UserSubscriptions.downgrade_to_free(user_id)

    # Roll the billing period over before reading usage counters.
    if _period_ended(subscription, now):
        subscription = UserSubscriptions.reset_period(subscription.id)

    plan = SubscriptionPlans.get_plan_by_id(subscription.plan_id)
    if not plan:
        raise HTTPException(
            status_code=500, detail="Subscription plan not found"
        )

    model_limit = plan.get_model_limit(model_id)

    # limit == 0: the model is not available on this plan at all.
    if model_limit == 0:
        raise HTTPException(
            status_code=403,
            detail=f"您的 {plan.name} 套餐不支持使用此模型。请升级套餐以获取访问权限。"
        )

    # Only positive limits are enforced; negative values mean unlimited.
    if model_limit > 0:
        current_usage = subscription.get_model_usage(model_id)
        if current_usage >= model_limit:
            raise HTTPException(
                status_code=403,
                detail=f"您本月已使用 {current_usage}/{model_limit} 次此模型。请等待下月刷新或升级套餐。"
            )


def record_usage(
    user_id: str,
    model_id: Optional[str] = None,
    chat_id: Optional[str] = None,
    message_id: Optional[str] = None,
) -> None:
    """Record one message of usage for the user.

    Should be called after a successful chat completion. If ``model_id`` is
    given, the per-model counter is incremented; otherwise the general usage
    counter is. A row is then inserted into ``SubscriptionUsageLogs``.

    This is a no-op when the user has no subscription. ``check_subscription_access``
    initializes one, so this normally only happens if that check was skipped.
    """
    subscription = UserSubscriptions.get_by_user_id(user_id)
    if not subscription:
        return

    if model_id:
        UserSubscriptions.increment_model_usage(subscription.id, model_id)
    else:
        UserSubscriptions.increment_usage(subscription.id)

    SubscriptionUsageLogs.insert(
        user_id=user_id,
        subscription_id=subscription.id,
        plan_id=subscription.plan_id,
        model_id=model_id,
        chat_id=chat_id,
        message_id=message_id,
    )


def get_user_model_limit(user_id: str, model_id: str) -> int:
    """Return the usage limit of ``model_id`` for ``user_id``.

    Uses the user's subscription plan, or the default plan if the user has
    no subscription. Read-only: does not initialize, downgrade or reset.

    Returns:
        -1 (or another negative value) = unlimited, 0 = not allowed,
        positive = per-period limit. Returns -1 when no plan can be resolved
        (allow by default).
    """
    _, plan = _get_effective_plan(user_id)
    if not plan:
        return -1  # allow by default if no plan
    return plan.get_model_limit(model_id)


def get_user_model_remaining(user_id: str, model_id: str) -> Optional[int]:
    """Return how many more uses of ``model_id`` ``user_id`` has this period.

    Read-only: it does not reset the period itself. If the current billing
    period has already ended, the recorded usage is treated as 0, because the
    next ``check_subscription_access`` call would reset it.

    NOTE(review): an expired subscription is not downgraded here, so its
    (paid) plan limits are still reported until the next access check.

    Returns:
        None if unlimited (any negative limit) or if no plan can be resolved;
        0 if the model is not allowed or the quota is exhausted; otherwise
        the positive number of remaining uses.
    """
    subscription, plan = _get_effective_plan(user_id)
    if not plan:
        return None

    limit = plan.get_model_limit(model_id)
    # Match check_subscription_access: only positive limits are enforced.
    if limit < 0:
        return None
    if limit == 0:
        return 0

    # No subscription yet (default plan), or period rolled over: nothing used.
    if not subscription or _period_ended(subscription, int(time.time())):
        return limit

    used = subscription.get_model_usage(model_id)
    return max(0, limit - used)


def get_user_allowed_models(user_id: str) -> Optional[list[str]]:
    """Deprecated: use ``get_user_model_limit`` instead.

    Return the ids listed in the effective plan's ``model_limits`` whose limit
    is not 0. Returns None (meaning "all models allowed") when no plan can be
    resolved or the plan defines no ``model_limits``. Models not listed in
    ``model_limits`` are never included, regardless of the plan's default.
    """
    _, plan = _get_effective_plan(user_id)
    if plan and plan.model_limits:
        return [
            model for model, limit in plan.model_limits.items() if limit != 0
        ]
    return None


def filter_models_by_subscription(user_id: str, models: list) -> list:
    """Filter ``models`` down to those the user's plan allows.

    ``models`` is assumed to be a list of dicts with an ``"id"`` key. Models
    whose limit is 0 are removed. If no plan can be resolved, ``models`` is
    returned unchanged (the same list object).
    """
    _, plan = _get_effective_plan(user_id)
    if not plan:
        return models

    # No per-model overrides: the plan-wide default decides for every model.
    if not plan.model_limits:
        if plan.default_model_limit == 0:
            return []  # no models allowed
        return models  # all models allowed

    return [m for m in models if plan.get_model_limit(m.get("id", "")) != 0]