Files
openwebui/backend/open_webui/models/subscriptions.py
shihao 16263710d9
Some checks failed
Create and publish Docker images with specific build args / build-main-image (linux/amd64, ubuntu-latest) (push) Has been cancelled
Create and publish Docker images with specific build args / build-main-image (linux/arm64, ubuntu-24.04-arm) (push) Has been cancelled
Create and publish Docker images with specific build args / build-cuda-image (linux/amd64, ubuntu-latest) (push) Has been cancelled
Create and publish Docker images with specific build args / build-cuda-image (linux/arm64, ubuntu-24.04-arm) (push) Has been cancelled
Create and publish Docker images with specific build args / build-cuda126-image (linux/amd64, ubuntu-latest) (push) Has been cancelled
Create and publish Docker images with specific build args / build-cuda126-image (linux/arm64, ubuntu-24.04-arm) (push) Has been cancelled
Create and publish Docker images with specific build args / build-ollama-image (linux/amd64, ubuntu-latest) (push) Has been cancelled
Create and publish Docker images with specific build args / build-ollama-image (linux/arm64, ubuntu-24.04-arm) (push) Has been cancelled
Create and publish Docker images with specific build args / build-slim-image (linux/amd64, ubuntu-latest) (push) Has been cancelled
Create and publish Docker images with specific build args / build-slim-image (linux/arm64, ubuntu-24.04-arm) (push) Has been cancelled
Create and publish Docker images with specific build args / merge-main-images (push) Has been cancelled
Create and publish Docker images with specific build args / merge-cuda-images (push) Has been cancelled
Create and publish Docker images with specific build args / merge-cuda126-images (push) Has been cancelled
Create and publish Docker images with specific build args / merge-ollama-images (push) Has been cancelled
Create and publish Docker images with specific build args / merge-slim-images (push) Has been cancelled
Python CI / Format Backend (3.11.x) (push) Has been cancelled
Python CI / Format Backend (3.12.x) (push) Has been cancelled
Frontend Build / Format & Build Frontend (push) Has been cancelled
Frontend Build / Frontend Unit Tests (push) Has been cancelled
Close inactive issues / close-issues (push) Has been cancelled
feat:新增套餐系统,删除积分制
2026-01-09 17:30:15 +08:00

670 lines
23 KiB
Python

import time
import uuid
from typing import List, Optional, Tuple
from decimal import Decimal
from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy import JSON, BigInteger, Boolean, Column, Integer, Numeric, String, Text
from open_webui.internal.db import Base, get_db
####################
# Subscription Plan DB Schema
####################
class SubscriptionPlan(Base):
    """SQLAlchemy model for a subscription plan (table ``subscription_plan``)."""

    __tablename__ = "subscription_plan"
    # Primary key supplied by the caller (SubscriptionPlanForm.id), e.g. "free".
    id = Column(String, primary_key=True)
    name = Column(String, nullable=False)
    description = Column(Text, nullable=True)
    # Fixed-point amount: up to 10 digits, 2 of them after the decimal point.
    price = Column(Numeric(precision=10, scale=2), default=0)
    # Message quota compared against UserSubscription.messages_used. When NULL
    # or 0, get_with_plan reports messages_remaining as None.
    monthly_message_limit = Column(Integer, nullable=True)
    # JSON list of model ids. NOTE(review): the meaning of NULL vs [] is
    # enforced outside this module — confirm.
    allowed_models = Column(JSON, nullable=True)
    # get_all_plans lists plans in descending priority order.
    priority = Column(Integer, default=0)
    # create_plan/update_plan clear this flag on other plans when setting it,
    # so normally at most one plan is the default.
    is_default = Column(Boolean, default=False)
    is_active = Column(Boolean, default=True)
    # Unix timestamps in seconds.
    created_at = Column(BigInteger)
    updated_at = Column(BigInteger)
class UserSubscription(Base):
    """SQLAlchemy model tying a user to their current plan (table ``user_subscription``)."""

    __tablename__ = "user_subscription"
    # uuid4 hex string generated by UserSubscriptionModel.
    id = Column(String, primary_key=True)
    # Unique: a user has at most one subscription row.
    user_id = Column(String, index=True, nullable=False, unique=True)
    # Id of a SubscriptionPlan row (no foreign-key constraint is declared).
    plan_id = Column(String, nullable=False)
    # All timestamps below are Unix seconds.
    started_at = Column(BigInteger, nullable=False)
    # NULL = no expiry date (the default/free plan is stored this way).
    expires_at = Column(BigInteger, nullable=True)
    # Usage window; reset_period moves it forward and zeroes messages_used.
    current_period_start = Column(BigInteger)
    current_period_end = Column(BigInteger)
    messages_used = Column(Integer, default=0)
    created_at = Column(BigInteger)
    updated_at = Column(BigInteger)
class SubscriptionUsageLog(Base):
    """SQLAlchemy model for one recorded usage event (table ``subscription_usage_log``)."""

    __tablename__ = "subscription_usage_log"
    # uuid4 hex string generated by SubscriptionUsageLogModel.
    id = Column(String, primary_key=True)
    user_id = Column(String, index=True, nullable=False)
    # Subscription and plan ids passed to SubscriptionUsageLogsTable.insert.
    subscription_id = Column(String, nullable=False)
    plan_id = Column(String, nullable=False)
    # Optional request context.
    model_id = Column(String, nullable=True)
    chat_id = Column(String, nullable=True)
    message_id = Column(String, nullable=True)
    # Unix seconds; indexed because get_by_time_range filters on it.
    created_at = Column(BigInteger, index=True)
class RedemptionCode(Base):
    """SQLAlchemy model for a single-use code that grants a plan (table ``redemption_code``)."""

    __tablename__ = "redemption_code"
    code = Column(String, primary_key=True)
    # Free-text label; get_codes searches it by substring.
    purpose = Column(String, index=True)
    # "duration": redeem_code extends the subscription by duration_days.
    # Any other value: redeem_code sets the expiry to upgrade_expires_at.
    redemption_type = Column(String, default="duration")
    # Plan granted on redemption.
    plan_id = Column(String, nullable=True)
    duration_days = Column(Integer, nullable=True)
    # Unix seconds.
    upgrade_expires_at = Column(BigInteger, nullable=True)
    # Redeeming user; NULL while the code is still unused.
    user_id = Column(String, index=True, nullable=True)
    created_at = Column(BigInteger, index=True)
    # The code cannot be redeemed after this Unix timestamp (NULL = no deadline).
    expired_at = Column(BigInteger, index=True, nullable=True)
    # Unix timestamp of redemption.
    received_at = Column(BigInteger, index=True, nullable=True)
####################
# Pydantic Models
####################
class SubscriptionPlanModel(BaseModel):
    """Pydantic representation of a ``subscription_plan`` row.

    ``from_attributes=True`` lets ``model_validate`` read directly from the ORM
    object returned by SQLAlchemy.
    """

    model_config = ConfigDict(from_attributes=True)
    id: str
    name: str
    description: Optional[str] = None
    price: Decimal = Field(default_factory=lambda: Decimal("0"))
    # None or 0: get_with_plan reports messages_remaining as None.
    monthly_message_limit: Optional[int] = None
    allowed_models: Optional[List[str]] = None
    priority: int = 0
    is_default: bool = False
    is_active: bool = True
    # Unix timestamps (seconds); default to the current time when omitted.
    created_at: int = Field(default_factory=lambda: int(time.time()))
    updated_at: int = Field(default_factory=lambda: int(time.time()))
class SubscriptionPlanForm(BaseModel):
    """Request body for creating a plan (consumed by SubscriptionPlansTable.create_plan)."""

    # Becomes the plan's primary key, so it must not match an existing plan id.
    id: str
    name: str
    description: Optional[str] = None
    # Accepted as a float; create_plan converts it with Decimal(str(price)).
    price: float = 0
    monthly_message_limit: Optional[int] = None
    allowed_models: Optional[List[str]] = None
    priority: int = 0
    # When True, create_plan clears is_default on every other plan.
    is_default: bool = False
    is_active: bool = True
class UpdateSubscriptionPlanForm(BaseModel):
    """Partial-update body for a plan.

    Every field is optional; update_plan applies only the fields the client
    explicitly set (``model_dump(exclude_unset=True)``).
    """

    name: Optional[str] = None
    description: Optional[str] = None
    # Converted to Decimal by update_plan.
    price: Optional[float] = None
    monthly_message_limit: Optional[int] = None
    allowed_models: Optional[List[str]] = None
    priority: Optional[int] = None
    is_default: Optional[bool] = None
    is_active: Optional[bool] = None
class UserSubscriptionModel(BaseModel):
    """Pydantic representation of a ``user_subscription`` row."""

    model_config = ConfigDict(from_attributes=True)
    # A fresh uuid4 hex string when not supplied.
    id: str = Field(default_factory=lambda: uuid.uuid4().hex)
    user_id: str
    plan_id: str
    # All timestamps are Unix seconds.
    started_at: int
    # None = no expiry date.
    expires_at: Optional[int] = None
    current_period_start: int
    current_period_end: int
    messages_used: int = 0
    created_at: int = Field(default_factory=lambda: int(time.time()))
    updated_at: int = Field(default_factory=lambda: int(time.time()))
class UserSubscriptionWithPlanModel(UserSubscriptionModel):
    """A subscription together with its plan and derived usage figures.

    Built by UserSubscriptionsTable; the extra fields are computed on read and
    are not stored in the database.
    """

    # None if the referenced plan could not be found.
    plan: Optional[SubscriptionPlanModel] = None
    # max(0, limit - messages_used); None when there is no plan or its limit is None/0.
    messages_remaining: Optional[int] = None
    # Whole days until expires_at (floor division, clamped at 0); None without expiry.
    days_remaining: Optional[int] = None
class SubscriptionUsageLogModel(BaseModel):
    """Pydantic representation of a ``subscription_usage_log`` row."""

    model_config = ConfigDict(from_attributes=True)
    # A fresh uuid4 hex string when not supplied.
    id: str = Field(default_factory=lambda: uuid.uuid4().hex)
    user_id: str
    subscription_id: str
    plan_id: str
    model_id: Optional[str] = None
    chat_id: Optional[str] = None
    message_id: Optional[str] = None
    # Unix seconds; defaults to the current time.
    created_at: int = Field(default_factory=lambda: int(time.time()))
class RedemptionCodeModel(BaseModel):
    """Pydantic representation of a ``redemption_code`` row."""

    # extra="allow" keeps unrecognised keys (from dict input) instead of
    # dropping them. NOTE(review): presumably so callers can attach extra
    # display fields — confirm.
    model_config = ConfigDict(from_attributes=True, extra="allow")
    code: str
    purpose: str
    # "duration" -> extend by duration_days; otherwise expire at upgrade_expires_at.
    redemption_type: str = "duration"
    plan_id: Optional[str] = None
    duration_days: Optional[int] = None
    upgrade_expires_at: Optional[int] = None
    # Redeeming user; None while unused.
    user_id: Optional[str] = None
    # Unix timestamps (seconds).
    created_at: int
    expired_at: Optional[int] = None
    received_at: Optional[int] = None
class CreateRedemptionCodeForm(BaseModel):
    """Request body for generating redemption codes."""

    purpose: str = Field(min_length=1, max_length=255)
    # Presumably the number of codes to generate in one request (1-1000) — confirm with caller.
    count: int = Field(ge=1, le=1000)
    # NOTE(review): nothing here requires duration_days for "duration" codes or
    # upgrade_expires_at for other types — presumably validated by the caller.
    redemption_type: str = "duration"
    plan_id: str
    duration_days: Optional[int] = Field(default=None, ge=1)
    # Unix timestamps (seconds); must be positive when given.
    upgrade_expires_at: Optional[int] = Field(default=None, gt=0)
    expired_at: Optional[int] = Field(default=None, gt=0)
class UpdateRedemptionCodeForm(BaseModel):
    """Partial-update body for a redemption code.

    update_code applies only the fields the client explicitly set
    (``model_dump(exclude_unset=True)``).
    """

    purpose: Optional[str] = Field(None, min_length=1, max_length=255)
    plan_id: Optional[str] = None
    duration_days: Optional[int] = Field(None, ge=1)
    # Unix timestamps (seconds); must be positive when given.
    upgrade_expires_at: Optional[int] = Field(None, gt=0)
    expired_at: Optional[int] = Field(None, gt=0)
class UpdateUserSubscriptionForm(BaseModel):
    """Admin body for partially updating a user's subscription.

    update_subscription applies only the fields the client explicitly set
    (``model_dump(exclude_unset=True)``).
    """

    plan_id: Optional[str] = None
    # Unix timestamp (seconds); None = no expiry date.
    expires_at: Optional[int] = None
    # Usage counter for the current period. A negative value would make
    # messages_remaining exceed the plan's limit, so it is rejected.
    messages_used: Optional[int] = Field(default=None, ge=0)
####################
# Table Classes
####################
class SubscriptionPlansTable:
    """Data-access helpers for the ``subscription_plan`` table.

    Each method opens its own session via ``get_db()`` and returns detached
    Pydantic models rather than ORM objects.
    """

    # Fields a partial update must never overwrite with NULL. ``name`` is NOT
    # NULL in the schema. A NULL price/priority/flag would make
    # SubscriptionPlanModel.model_validate fail for the row, which breaks
    # every plan listing. For description, monthly_message_limit and
    # allowed_models, None is a meaningful value and is written through.
    _NON_NULLABLE_UPDATE_FIELDS = frozenset(
        {"name", "price", "priority", "is_default", "is_active"}
    )

    def get_all_plans(self, include_inactive: bool = False) -> List[SubscriptionPlanModel]:
        """Return plans ordered by descending priority.

        :param include_inactive: when False (the default), only plans with
            ``is_active`` set are returned.
        """
        with get_db() as db:
            query = db.query(SubscriptionPlan)
            if not include_inactive:
                query = query.filter(SubscriptionPlan.is_active == True)  # noqa: E712
            plans = query.order_by(SubscriptionPlan.priority.desc()).all()
            return [SubscriptionPlanModel.model_validate(plan) for plan in plans]

    def get_plan_by_id(self, plan_id: str) -> Optional[SubscriptionPlanModel]:
        """Return the plan with ``plan_id``, or None if it does not exist."""
        with get_db() as db:
            plan = db.query(SubscriptionPlan).filter(SubscriptionPlan.id == plan_id).first()
            return SubscriptionPlanModel.model_validate(plan) if plan else None

    def get_default_plan(self) -> Optional[SubscriptionPlanModel]:
        """Return the active plan flagged ``is_default``, or None if there is none."""
        with get_db() as db:
            plan = db.query(SubscriptionPlan).filter(
                SubscriptionPlan.is_default == True,  # noqa: E712
                SubscriptionPlan.is_active == True,  # noqa: E712
            ).first()
            return SubscriptionPlanModel.model_validate(plan) if plan else None

    def create_plan(self, form_data: SubscriptionPlanForm) -> SubscriptionPlanModel:
        """Insert a new plan and return it.

        If ``form_data.is_default`` is set, every other plan loses its default
        flag. That happens in the same transaction as the insert, so a failed
        insert (e.g. a duplicate id) rolls it back instead of leaving the
        system with no default plan. NOTE(review): this assumes ``get_db()``
        discards uncommitted work when the block raises — confirm.
        """
        now = int(time.time())
        plan = SubscriptionPlanModel(
            id=form_data.id,
            name=form_data.name,
            description=form_data.description,
            price=Decimal(str(form_data.price)),
            monthly_message_limit=form_data.monthly_message_limit,
            allowed_models=form_data.allowed_models,
            priority=form_data.priority,
            is_default=form_data.is_default,
            is_active=form_data.is_active,
            created_at=now,
            updated_at=now,
        )
        with get_db() as db:
            if form_data.is_default:
                db.query(SubscriptionPlan).filter(
                    SubscriptionPlan.is_default == True  # noqa: E712
                ).update({"is_default": False})
            db.add(SubscriptionPlan(**plan.model_dump()))
            db.commit()
        return plan

    def update_plan(self, plan_id: str, form_data: UpdateSubscriptionPlanForm) -> Optional[SubscriptionPlanModel]:
        """Apply the explicitly-set fields of ``form_data`` to plan ``plan_id``.

        Explicit nulls for fields in ``_NON_NULLABLE_UPDATE_FIELDS`` are
        ignored. Setting ``is_default`` to True clears the flag on every other
        plan in the same transaction.

        :returns: the updated plan, or None if ``plan_id`` does not exist.
        """
        with get_db() as db:
            plan_query = db.query(SubscriptionPlan).filter(SubscriptionPlan.id == plan_id)
            if plan_query.first() is None:
                return None
            update_data = {
                key: value
                for key, value in form_data.model_dump(exclude_unset=True).items()
                if value is not None or key not in self._NON_NULLABLE_UPDATE_FIELDS
            }
            update_data["updated_at"] = int(time.time())
            if update_data.get("price") is not None:
                # Go through str() so the float's shortest repr is kept, not its binary expansion.
                update_data["price"] = Decimal(str(update_data["price"]))
            if update_data.get("is_default"):
                db.query(SubscriptionPlan).filter(
                    SubscriptionPlan.id != plan_id,
                    SubscriptionPlan.is_default == True,  # noqa: E712
                ).update({"is_default": False})
            plan_query.update(update_data)
            db.commit()
        return self.get_plan_by_id(plan_id)

    def delete_plan(self, plan_id: str) -> bool:
        """Delete plan ``plan_id``; return True if a row was removed.

        NOTE(review): UserSubscription rows that still reference this plan are
        left untouched.
        """
        with get_db() as db:
            result = db.query(SubscriptionPlan).filter(SubscriptionPlan.id == plan_id).delete()
            db.commit()
            return result > 0


SubscriptionPlans = SubscriptionPlansTable()
class UserSubscriptionsTable:
    """Data-access helpers for the ``user_subscription`` table.

    ``user_id`` is unique in the schema, so every per-user method addresses at
    most one row. Each method opens its own session via ``get_db()``.
    """

    @staticmethod
    def _next_period_end() -> int:
        """Return the Unix timestamp of midnight on the 1st of next month.

        Computed in the server's local time zone (``date.today()`` plus a
        naive ``datetime.timestamp()``).
        """
        import datetime

        today = datetime.date.today()
        if today.month == 12:
            next_month = today.replace(year=today.year + 1, month=1, day=1)
        else:
            next_month = today.replace(month=today.month + 1, day=1)
        return int(datetime.datetime.combine(next_month, datetime.time.min).timestamp())

    @staticmethod
    def _with_plan(
        sub: UserSubscriptionModel,
        plan: Optional[SubscriptionPlanModel],
        now: int,
    ) -> UserSubscriptionWithPlanModel:
        """Attach ``plan`` and the derived quota/expiry figures to ``sub``.

        ``messages_remaining`` is None when there is no plan or its
        ``monthly_message_limit`` is None/0. ``days_remaining`` is None when
        the subscription has no ``expires_at``. Both are clamped at 0.
        """
        messages_remaining = None
        if plan and plan.monthly_message_limit:
            messages_remaining = max(0, plan.monthly_message_limit - sub.messages_used)
        days_remaining = None
        if sub.expires_at:
            days_remaining = max(0, (sub.expires_at - now) // 86400)
        return UserSubscriptionWithPlanModel(
            **sub.model_dump(),
            plan=plan,
            messages_remaining=messages_remaining,
            days_remaining=days_remaining,
        )

    def get_by_user_id(self, user_id: str) -> Optional[UserSubscriptionModel]:
        """Return the user's subscription, or None if they have none."""
        with get_db() as db:
            sub = db.query(UserSubscription).filter(UserSubscription.user_id == user_id).first()
            return UserSubscriptionModel.model_validate(sub) if sub else None

    def get_with_plan(self, user_id: str) -> Optional[UserSubscriptionWithPlanModel]:
        """Return the user's subscription with plan details, or None if they have none."""
        sub = self.get_by_user_id(user_id)
        if not sub:
            return None
        plan = SubscriptionPlans.get_plan_by_id(sub.plan_id)
        return self._with_plan(sub, plan, int(time.time()))

    def get_all_subscriptions(
        self,
        query: Optional[str] = None,
        offset: int = 0,
        limit: int = 30
    ) -> Tuple[int, List[UserSubscriptionWithPlanModel]]:
        """Return ``(total, page)`` of subscriptions, most recently updated first.

        :param query: optional substring of ``user_id`` to filter on. LIKE
            wildcards (``%``, ``_``) in it are matched literally.
        :param offset: number of rows to skip.
        :param limit: maximum number of rows in the page.
        """
        with get_db() as db:
            q = db.query(UserSubscription).order_by(UserSubscription.updated_at.desc())
            if query:
                q = q.filter(UserSubscription.user_id.contains(query, autoescape=True))
            total = q.count()
            subs = [
                UserSubscriptionModel.model_validate(sub)
                for sub in q.offset(offset).limit(limit).all()
            ]

        # Look each distinct plan up once rather than once per row. Doing it
        # after the listing session has closed keeps only one session open.
        now = int(time.time())
        plans = {}
        results = []
        for sub in subs:
            if sub.plan_id not in plans:
                plans[sub.plan_id] = SubscriptionPlans.get_plan_by_id(sub.plan_id)
            results.append(self._with_plan(sub, plans[sub.plan_id], now))
        return total, results

    def init_free_subscription(self, user_id: str) -> UserSubscriptionModel:
        """Insert a subscription on the default plan for ``user_id`` and return it.

        If no active default plan exists, a plan with id "free"
        (monthly_message_limit=100) is created as the default first. The new
        subscription has no expiry, and its usage period runs from now until
        the start of next month.

        NOTE(review): this always inserts. Calling it for a user who already
        has a row violates the unique ``user_id`` constraint, so use
        get_or_create_subscription. If a non-default plan with id "free"
        already exists, create_plan fails on the duplicate primary key.
        """
        default_plan = SubscriptionPlans.get_default_plan()
        if not default_plan:
            default_plan = SubscriptionPlans.create_plan(SubscriptionPlanForm(
                id="free",
                name="Free",
                description="Default free plan",
                price=0,
                monthly_message_limit=100,
                allowed_models=[],
                priority=0,
                is_default=True,
                is_active=True,
            ))
        now = int(time.time())
        sub = UserSubscriptionModel(
            user_id=user_id,
            plan_id=default_plan.id,
            started_at=now,
            expires_at=None,
            current_period_start=now,
            current_period_end=self._next_period_end(),
            messages_used=0,
            created_at=now,
            updated_at=now,
        )
        with get_db() as db:
            db.add(UserSubscription(**sub.model_dump()))
            db.commit()
        return sub

    def get_or_create_subscription(self, user_id: str) -> UserSubscriptionModel:
        """Get existing subscription or create a new free subscription for user.

        NOTE(review): this is check-then-insert. Two concurrent first calls for
        the same user can both reach the insert, and one will hit the unique
        ``user_id`` constraint.
        """
        existing = self.get_by_user_id(user_id)
        if existing:
            return existing
        return self.init_free_subscription(user_id)

    def update_subscription(
        self, user_id: str, form_data: UpdateUserSubscriptionForm
    ) -> Optional[UserSubscriptionModel]:
        """Apply the explicitly-set fields of ``form_data`` to the user's subscription.

        An explicit null for ``plan_id`` (NOT NULL column) or ``messages_used``
        (which UserSubscriptionModel requires to be an int) is ignored.
        ``expires_at=None`` is written through, since None means no expiry.

        :returns: the updated subscription, or None if the user has none.
        """
        with get_db() as db:
            sub_query = db.query(UserSubscription).filter(UserSubscription.user_id == user_id)
            if sub_query.first() is None:
                return None
            update_data = form_data.model_dump(exclude_unset=True)
            for key in ("plan_id", "messages_used"):
                if key in update_data and update_data[key] is None:
                    del update_data[key]
            update_data["updated_at"] = int(time.time())
            sub_query.update(update_data)
            db.commit()
        return self.get_by_user_id(user_id)

    def upgrade_subscription(
        self, user_id: str, plan_id: str, duration_days: Optional[int] = None, expires_at: Optional[int] = None
    ) -> UserSubscriptionModel:
        """Move ``user_id`` onto ``plan_id``, creating the subscription if needed.

        :param duration_days: when truthy, the new expiry is that many days
            after the current expiry if it is still in the future, otherwise
            after now. It takes precedence over ``expires_at``.
        :param expires_at: explicit Unix expiry used when ``duration_days`` is
            not given; None means no expiry.
        :returns: the stored subscription after the change.

        For an existing subscription only ``plan_id``/``expires_at``/``updated_at``
        change; the usage counter and current period are kept.
        """
        now = int(time.time())
        with get_db() as db:
            existing = db.query(UserSubscription).filter(UserSubscription.user_id == user_id).first()

            new_expires_at = expires_at
            if duration_days:
                extension = duration_days * 86400
                if existing and existing.expires_at and existing.expires_at > now:
                    # Still active: stack the new time on top of what is left.
                    new_expires_at = existing.expires_at + extension
                else:
                    new_expires_at = now + extension

            if existing:
                db.query(UserSubscription).filter(UserSubscription.user_id == user_id).update({
                    "plan_id": plan_id,
                    "expires_at": new_expires_at,
                    "updated_at": now,
                })
            else:
                sub = UserSubscriptionModel(
                    user_id=user_id,
                    plan_id=plan_id,
                    started_at=now,
                    expires_at=new_expires_at,
                    current_period_start=now,
                    current_period_end=self._next_period_end(),
                    messages_used=0,
                    created_at=now,
                    updated_at=now,
                )
                db.add(UserSubscription(**sub.model_dump()))
            db.commit()
        return self.get_by_user_id(user_id)

    def downgrade_to_free(self, user_id: str) -> UserSubscriptionModel:
        """Put the user back on the default plan with no expiry.

        :raises ValueError: if no active default plan is configured.

        NOTE(review): if the user has no subscription row, nothing is updated
        and None is returned.
        """
        default_plan = SubscriptionPlans.get_default_plan()
        if not default_plan:
            raise ValueError("No default plan configured")
        now = int(time.time())
        with get_db() as db:
            db.query(UserSubscription).filter(UserSubscription.user_id == user_id).update({
                "plan_id": default_plan.id,
                "expires_at": None,
                "updated_at": now,
            })
            db.commit()
        return self.get_by_user_id(user_id)

    def reset_period(self, subscription_id: str) -> UserSubscriptionModel:
        """Start a new usage period (now until next month) and zero ``messages_used``.

        NOTE(review): if no subscription has this id, ``model_validate(None)``
        raises instead of returning None.
        """
        now = int(time.time())
        period_end = self._next_period_end()
        with get_db() as db:
            db.query(UserSubscription).filter(UserSubscription.id == subscription_id).update({
                "current_period_start": now,
                "current_period_end": period_end,
                "messages_used": 0,
                "updated_at": now,
            })
            db.commit()
            sub = db.query(UserSubscription).filter(UserSubscription.id == subscription_id).first()
            return UserSubscriptionModel.model_validate(sub)

    def increment_usage(self, subscription_id: str) -> None:
        """Add one to ``messages_used`` for the given subscription.

        The increment is expressed in SQL (``messages_used + 1``) in a single
        UPDATE, so concurrent increments are not lost to a read-modify-write.
        """
        with get_db() as db:
            db.query(UserSubscription).filter(UserSubscription.id == subscription_id).update({
                "messages_used": UserSubscription.messages_used + 1,
                "updated_at": int(time.time()),
            })
            db.commit()


UserSubscriptions = UserSubscriptionsTable()
class SubscriptionUsageLogsTable:
    """Append and query access for the ``subscription_usage_log`` table."""

    def insert(
        self,
        user_id: str,
        subscription_id: str,
        plan_id: str,
        model_id: Optional[str] = None,
        chat_id: Optional[str] = None,
        message_id: Optional[str] = None,
    ) -> SubscriptionUsageLogModel:
        """Record one usage event and return it (id and created_at auto-filled)."""
        entry = SubscriptionUsageLogModel(
            user_id=user_id,
            subscription_id=subscription_id,
            plan_id=plan_id,
            model_id=model_id,
            chat_id=chat_id,
            message_id=message_id,
        )
        row = SubscriptionUsageLog(**entry.model_dump())
        with get_db() as db:
            db.add(row)
            db.commit()
        return entry

    def get_by_user_id(
        self, user_id: str, offset: int = 0, limit: int = 30
    ) -> Tuple[int, List[SubscriptionUsageLogModel]]:
        """Return ``(total, page)`` of one user's usage events, newest first."""
        with get_db() as db:
            user_logs = (
                db.query(SubscriptionUsageLog)
                .filter(SubscriptionUsageLog.user_id == user_id)
                .order_by(SubscriptionUsageLog.created_at.desc())
            )
            matched = user_logs.count()
            page = user_logs.offset(offset).limit(limit).all()
            return matched, [SubscriptionUsageLogModel.model_validate(row) for row in page]

    def get_by_time_range(
        self, start_time: int, end_time: int, user_ids: Optional[List[str]] = None
    ) -> List[SubscriptionUsageLogModel]:
        """Return events with ``start_time <= created_at < end_time``, oldest first.

        When ``user_ids`` is non-empty, only those users' events are returned.
        """
        with get_db() as db:
            conditions = [
                SubscriptionUsageLog.created_at >= start_time,
                SubscriptionUsageLog.created_at < end_time,
            ]
            if user_ids:
                conditions.append(SubscriptionUsageLog.user_id.in_(user_ids))
            rows = (
                db.query(SubscriptionUsageLog)
                .filter(*conditions)
                .order_by(SubscriptionUsageLog.created_at.asc())
                .all()
            )
            return [SubscriptionUsageLogModel.model_validate(row) for row in rows]


SubscriptionUsageLogs = SubscriptionUsageLogsTable()
class RedemptionCodesTable:
    """Data-access helpers for the ``redemption_code`` table."""

    def get_code(self, code: str) -> Optional[RedemptionCodeModel]:
        """Return the record for ``code``, or None if it does not exist."""
        with get_db() as db:
            redemption_code = db.query(RedemptionCode).filter(RedemptionCode.code == code).first()
            return RedemptionCodeModel.model_validate(redemption_code) if redemption_code else None

    def get_codes(
        self, keyword: Optional[str] = None, offset: int = 0, limit: int = 30
    ) -> Tuple[int, List[RedemptionCodeModel]]:
        """Return ``(total, page)`` of codes, newest first.

        ``keyword`` matches the code or plan id exactly, or the purpose as a
        substring. LIKE wildcards (``%``, ``_``) in it are matched literally.
        """
        with get_db() as db:
            query = db.query(RedemptionCode).order_by(RedemptionCode.created_at.desc())
            if keyword:
                query = query.filter(
                    (RedemptionCode.code == keyword)
                    | (RedemptionCode.purpose.contains(keyword, autoescape=True))
                    | (RedemptionCode.plan_id == keyword)
                )
            total = query.count()
            codes = query.offset(offset).limit(limit).all()
            return total, [RedemptionCodeModel.model_validate(code) for code in codes]

    def insert_codes(self, codes: List[RedemptionCodeModel]) -> None:
        """Insert all ``codes`` in a single transaction."""
        # RedemptionCodeModel allows extra keys; dump only the declared fields
        # so an extra attribute cannot break the ORM constructor.
        columns = set(RedemptionCodeModel.model_fields)
        with get_db() as db:
            db.add_all([RedemptionCode(**code.model_dump(include=columns)) for code in codes])
            db.commit()

    def update_code(self, code: str, form_data: UpdateRedemptionCodeForm) -> Optional[RedemptionCodeModel]:
        """Apply the explicitly-set fields of ``form_data`` to ``code``.

        An explicit ``purpose: null`` is ignored, because RedemptionCodeModel
        requires ``purpose`` to be a string.

        :returns: the updated record, or None if the code does not exist.
        """
        with get_db() as db:
            existing = db.query(RedemptionCode).filter(RedemptionCode.code == code).first()
            if not existing:
                return None
            update_data = form_data.model_dump(exclude_unset=True)
            if "purpose" in update_data and update_data["purpose"] is None:
                del update_data["purpose"]
            # An empty body leaves nothing to write; skip the UPDATE entirely.
            if update_data:
                db.query(RedemptionCode).filter(RedemptionCode.code == code).update(update_data)
                db.commit()
        return self.get_code(code)

    def delete_code(self, code: str) -> bool:
        """Delete ``code``; return True if a row was removed."""
        with get_db() as db:
            result = db.query(RedemptionCode).filter(RedemptionCode.code == code).delete()
            db.commit()
            return result > 0

    def redeem_code(self, code: str, user_id: str) -> UserSubscriptionModel:
        """Redeem ``code`` for ``user_id`` and apply it to their subscription.

        "duration" codes extend the subscription by ``duration_days``. Any
        other type sets the expiry to ``upgrade_expires_at``.

        :raises HTTPException: 404 if the code does not exist. 400 if it was
            already used (including by a concurrent request), has expired, or
            is not linked to a plan.
        :returns: the user's subscription after the upgrade.

        NOTE(review): claiming the code and upgrading the subscription are
        separate transactions. If the upgrade fails, the code stays marked as
        used.
        """
        from fastapi import HTTPException

        redemption_code = self.get_code(code)
        if not redemption_code:
            raise HTTPException(status_code=404, detail="Code not found")
        if redemption_code.user_id is not None:
            raise HTTPException(status_code=400, detail="Code already used")
        now = int(time.time())
        if redemption_code.expired_at and redemption_code.expired_at < now:
            raise HTTPException(status_code=400, detail="Code expired")
        if not redemption_code.plan_id:
            # plan_id is nullable on redemption_code but NOT NULL on
            # user_subscription. Reject before the code is consumed.
            raise HTTPException(status_code=400, detail="Code has no plan")

        # Claim the code with a conditional UPDATE. Only one request can flip
        # user_id away from NULL; a concurrent loser matches 0 rows.
        with get_db() as db:
            claimed = (
                db.query(RedemptionCode)
                .filter(RedemptionCode.code == code, RedemptionCode.user_id.is_(None))
                .update({"user_id": user_id, "received_at": now}, synchronize_session=False)
            )
            db.commit()
        if not claimed:
            raise HTTPException(status_code=400, detail="Code already used")

        if redemption_code.redemption_type == "duration":
            return UserSubscriptions.upgrade_subscription(
                user_id=user_id,
                plan_id=redemption_code.plan_id,
                duration_days=redemption_code.duration_days,
            )
        else:
            return UserSubscriptions.upgrade_subscription(
                user_id=user_id,
                plan_id=redemption_code.plan_id,
                expires_at=redemption_code.upgrade_expires_at,
            )


RedemptionCodes = RedemptionCodesTable()