Some checks failed
Create and publish Docker images with specific build args / build-main-image (linux/amd64, ubuntu-latest) (push) Has been cancelled
Create and publish Docker images with specific build args / build-main-image (linux/arm64, ubuntu-24.04-arm) (push) Has been cancelled
Create and publish Docker images with specific build args / build-cuda-image (linux/amd64, ubuntu-latest) (push) Has been cancelled
Create and publish Docker images with specific build args / build-cuda-image (linux/arm64, ubuntu-24.04-arm) (push) Has been cancelled
Create and publish Docker images with specific build args / build-cuda126-image (linux/amd64, ubuntu-latest) (push) Has been cancelled
Create and publish Docker images with specific build args / build-cuda126-image (linux/arm64, ubuntu-24.04-arm) (push) Has been cancelled
Create and publish Docker images with specific build args / build-ollama-image (linux/amd64, ubuntu-latest) (push) Has been cancelled
Create and publish Docker images with specific build args / build-ollama-image (linux/arm64, ubuntu-24.04-arm) (push) Has been cancelled
Create and publish Docker images with specific build args / build-slim-image (linux/amd64, ubuntu-latest) (push) Has been cancelled
Create and publish Docker images with specific build args / build-slim-image (linux/arm64, ubuntu-24.04-arm) (push) Has been cancelled
Python CI / Format Backend (3.11.x) (push) Has been cancelled
Python CI / Format Backend (3.12.x) (push) Has been cancelled
Frontend Build / Format & Build Frontend (push) Has been cancelled
Frontend Build / Frontend Unit Tests (push) Has been cancelled
Create and publish Docker images with specific build args / merge-main-images (push) Has been cancelled
Create and publish Docker images with specific build args / merge-cuda-images (push) Has been cancelled
Create and publish Docker images with specific build args / merge-cuda126-images (push) Has been cancelled
Create and publish Docker images with specific build args / merge-ollama-images (push) Has been cancelled
Create and publish Docker images with specific build args / merge-slim-images (push) Has been cancelled
Close inactive issues / close-issues (push) Has been cancelled
781 lines
27 KiB
Python
781 lines
27 KiB
Python
import json
|
|
import time
|
|
import uuid
|
|
from typing import Dict, List, Optional, Tuple
|
|
|
|
from decimal import Decimal
|
|
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
|
from sqlalchemy import JSON, BigInteger, Boolean, Column, Integer, Numeric, String, Text
|
|
|
|
from open_webui.internal.db import Base, get_db
|
|
|
|
|
|
####################
|
|
# Subscription Plan DB Schema
|
|
####################
|
|
|
|
|
|
class SubscriptionPlan(Base):
    """SQLAlchemy table for subscription plans (``subscription_plan``).

    Timestamps are Unix seconds, matching the ``int(time.time())`` values written
    by SubscriptionPlansTable.
    """

    __tablename__ = "subscription_plan"

    id = Column(String, primary_key=True)
    name = Column(String, nullable=False)
    description = Column(Text, nullable=True)
    # fixed-point, two decimal places
    price = Column(Numeric(precision=10, scale=2), default=0)

    # per-model monthly limits: {"model_id": limit, ...}
    # -1 = unlimited, 0 = not allowed, positive = monthly limit
    model_limits = Column(JSON, nullable=True)

    # default limit for models not in model_limits
    # -1 = unlimited, 0 = not allowed, positive = monthly limit
    default_model_limit = Column(Integer, default=0)

    # deprecated fields (kept for backward compatibility)
    monthly_message_limit = Column(Integer, nullable=True)
    allowed_models = Column(JSON, nullable=True)

    # SubscriptionPlansTable.get_all_plans lists higher priority first
    priority = Column(Integer, default=0)
    # create_plan/update_plan clear this flag on other plans when setting it
    is_default = Column(Boolean, default=False)
    # inactive plans are hidden by get_all_plans (by default) and get_default_plan
    is_active = Column(Boolean, default=True)

    created_at = Column(BigInteger)
    updated_at = Column(BigInteger)
|
|
|
|
|
|
class UserSubscription(Base):
    """SQLAlchemy table for per-user subscriptions (``user_subscription``).

    ``user_id`` is unique, so each user has at most one row. Timestamps are
    Unix seconds.
    """

    __tablename__ = "user_subscription"

    id = Column(String, primary_key=True)
    user_id = Column(String, index=True, nullable=False, unique=True)
    plan_id = Column(String, nullable=False)
    started_at = Column(BigInteger, nullable=False)
    # None is stored for subscriptions without an expiry
    # (e.g. init_free_subscription, downgrade_to_free)
    expires_at = Column(BigInteger, nullable=True)

    # bounds of the current usage period; moved by UserSubscriptionsTable.reset_period
    current_period_start = Column(BigInteger)
    current_period_end = Column(BigInteger)

    # per-model usage tracking: {"model_id": count, ...}
    # A callable default gives every insert its own empty dict instead of one
    # dict object shared by all rows that fall back to the default.
    model_usage = Column(JSON, default=dict)

    # deprecated field (kept for backward compatibility)
    messages_used = Column(Integer, default=0)

    created_at = Column(BigInteger)
    updated_at = Column(BigInteger)
|
|
|
|
|
|
class SubscriptionUsageLog(Base):
    """SQLAlchemy table of individual usage events (``subscription_usage_log``).

    Rows are written by SubscriptionUsageLogsTable.insert and read back per user
    or by ``created_at`` range.
    """

    __tablename__ = "subscription_usage_log"

    id = Column(String, primary_key=True)
    user_id = Column(String, index=True, nullable=False)
    # subscription and plan the event was recorded against
    subscription_id = Column(String, nullable=False)
    plan_id = Column(String, nullable=False)

    # optional context of the event
    model_id = Column(String, nullable=True)
    chat_id = Column(String, nullable=True)
    message_id = Column(String, nullable=True)

    # Unix seconds; indexed for the time-range query in get_by_time_range
    created_at = Column(BigInteger, index=True)
|
|
|
|
|
|
class RedemptionCode(Base):
    """SQLAlchemy table of redeemable codes (``redemption_code``)."""

    __tablename__ = "redemption_code"

    # the code string itself is the primary key
    code = Column(String, primary_key=True)
    # free-text label; matched by substring in RedemptionCodesTable.get_codes
    purpose = Column(String, index=True)

    # "duration" codes apply duration_days; any other value applies upgrade_days
    # (see RedemptionCodesTable.redeem_code)
    redemption_type = Column(String, default="duration")
    plan_id = Column(String, nullable=True)
    duration_days = Column(Integer, nullable=True)
    upgrade_days = Column(Integer, nullable=True)

    # redeeming user; NULL while the code is still unused
    user_id = Column(String, index=True, nullable=True)
    created_at = Column(BigInteger, index=True)
    # the code is rejected once this timestamp has passed; NULL = never expires
    expired_at = Column(BigInteger, index=True, nullable=True)
    # timestamp of redemption
    received_at = Column(BigInteger, index=True, nullable=True)
|
|
|
|
|
|
####################
|
|
# Pydantic Models
|
|
####################
|
|
|
|
|
|
def _parse_json_field(v):
|
|
"""Parse JSON field from string if needed."""
|
|
if isinstance(v, str):
|
|
try:
|
|
return json.loads(v)
|
|
except json.JSONDecodeError:
|
|
return None
|
|
return v
|
|
|
|
|
|
class SubscriptionPlanModel(BaseModel):
    """Pydantic view of a ``SubscriptionPlan`` row.

    Built from ORM objects via ``from_attributes``; JSON columns that arrive as
    encoded strings are decoded before validation.
    """

    model_config = ConfigDict(from_attributes=True)

    id: str
    name: str
    description: Optional[str] = None
    price: Decimal = Field(default_factory=lambda: Decimal("0"))

    # per-model monthly limits: -1 = unlimited, 0 = not allowed, positive = monthly limit
    model_limits: Optional[Dict[str, int]] = None
    # limit used for any model without an entry in model_limits (same encoding)
    default_model_limit: int = 0

    # deprecated fields
    monthly_message_limit: Optional[int] = None
    allowed_models: Optional[List[str]] = None

    priority: int = 0
    is_default: bool = False
    is_active: bool = True
    created_at: int = Field(default_factory=lambda: int(time.time()))
    updated_at: int = Field(default_factory=lambda: int(time.time()))

    @field_validator('model_limits', 'allowed_models', mode='before')
    @classmethod
    def parse_json_fields(cls, v):
        # Runs before type validation so JSON text becomes a dict/list first.
        return _parse_json_field(v)

    def get_model_limit(self, model_id: str) -> int:
        """Return the monthly limit that applies to *model_id*.

        An explicit entry in ``model_limits`` takes precedence; otherwise
        ``default_model_limit`` applies. Encoding: -1 = unlimited,
        0 = not allowed, positive = monthly limit.
        """
        per_model = self.model_limits or {}
        return per_model.get(model_id, self.default_model_limit)
|
|
|
|
|
|
class SubscriptionPlanForm(BaseModel):
    """Input for SubscriptionPlansTable.create_plan."""

    id: str
    name: str
    description: Optional[str] = None
    # converted with Decimal(str(price)) before it is stored
    price: float = 0
    # limit encoding: -1 = unlimited, 0 = not allowed, positive = monthly limit
    model_limits: Optional[Dict[str, int]] = None
    default_model_limit: int = 0
    priority: int = 0
    # True makes create_plan clear the default flag on every other plan
    is_default: bool = False
    is_active: bool = True
|
|
|
|
|
|
class UpdateSubscriptionPlanForm(BaseModel):
    """Partial update for a subscription plan.

    SubscriptionPlansTable.update_plan only applies fields the client actually
    sent (``model_dump(exclude_unset=True)``).
    """

    name: Optional[str] = None
    description: Optional[str] = None
    price: Optional[float] = None
    model_limits: Optional[Dict[str, int]] = None
    default_model_limit: Optional[int] = None
    priority: Optional[int] = None
    is_default: Optional[bool] = None
    is_active: Optional[bool] = None
|
|
|
|
|
|
class UserSubscriptionModel(BaseModel):
    """Pydantic view of a ``UserSubscription`` row (built via ``from_attributes``)."""

    model_config = ConfigDict(from_attributes=True)

    id: str = Field(default_factory=lambda: uuid.uuid4().hex)
    user_id: str
    plan_id: str
    started_at: int
    expires_at: Optional[int] = None
    current_period_start: int
    current_period_end: int

    # per-model usage tracking: {"model_id": count, ...}
    model_usage: Dict[str, int] = Field(default_factory=dict)

    # deprecated field
    messages_used: int = 0

    created_at: int = Field(default_factory=lambda: int(time.time()))
    updated_at: int = Field(default_factory=lambda: int(time.time()))

    @field_validator('model_usage', mode='before')
    @classmethod
    def parse_model_usage(cls, v):
        # Decode JSON text; None, empty, or undecodable input becomes an empty dict.
        return _parse_json_field(v) or {}

    def get_model_usage(self, model_id: str) -> int:
        """Return how often *model_id* was used in this period (0 if never)."""
        usage = self.model_usage
        return usage[model_id] if model_id in usage else 0
|
|
|
|
|
|
class ModelUsageInfo(BaseModel):
    """Usage summary for a single model within the current period.

    Built by UserSubscriptionsTable.get_with_plan for each model with an explicit
    plan limit.
    """
    model_id: str
    limit: int  # -1 = unlimited, 0 = not allowed, positive = monthly limit
    used: int
    remaining: Optional[int] = None  # None if unlimited; otherwise max(0, limit - used)
|
|
|
|
|
|
class UserSubscriptionWithPlanModel(UserSubscriptionModel):
    """A user subscription enriched with its plan and derived display fields."""

    plan: Optional[SubscriptionPlanModel] = None
    # whole days until expires_at (floored, never negative); None when there is no expiry
    days_remaining: Optional[int] = None

    # per-model usage summary (filled by UserSubscriptionsTable.get_with_plan)
    model_usage_info: Optional[List[ModelUsageInfo]] = None

    # deprecated
    # NOTE(review): not set by the table helpers in this module — confirm whether callers fill it
    messages_remaining: Optional[int] = None

    # admin view fields
    # NOTE(review): not set by the table helpers in this module; presumably filled by the admin API
    username: Optional[str] = None
|
|
|
|
|
|
class SubscriptionUsageLogModel(BaseModel):
    """Pydantic view of a ``SubscriptionUsageLog`` row.

    ``id`` and ``created_at`` are generated when not supplied.
    """

    model_config = ConfigDict(from_attributes=True)

    id: str = Field(default_factory=lambda: uuid.uuid4().hex)
    user_id: str
    subscription_id: str
    plan_id: str
    model_id: Optional[str] = None
    chat_id: Optional[str] = None
    message_id: Optional[str] = None
    # Unix seconds
    created_at: int = Field(default_factory=lambda: int(time.time()))
|
|
|
|
|
|
class RedemptionCodeModel(BaseModel):
    """Pydantic view of a ``RedemptionCode`` row.

    ``extra="allow"`` lets instances carry fields beyond those declared here.
    """

    model_config = ConfigDict(from_attributes=True, extra="allow")

    code: str
    purpose: str
    # "duration" uses duration_days; any other value uses upgrade_days on redemption
    redemption_type: str = "duration"
    plan_id: Optional[str] = None
    duration_days: Optional[int] = None
    upgrade_days: Optional[int] = None
    # redeeming user; None while unused
    user_id: Optional[str] = None
    created_at: int
    # None = never expires
    expired_at: Optional[int] = None
    # redemption timestamp
    received_at: Optional[int] = None
|
|
|
|
|
|
class CreateRedemptionCodeForm(BaseModel):
    """Input for generating a batch of redemption codes.

    NOTE(review): there is no cross-field check that the day count matching
    ``redemption_type`` is set. A code without it is redeemed with
    ``duration_days=None``, which UserSubscriptionsTable.upgrade_subscription
    turns into a subscription with no expiry — confirm that is intended.
    """

    purpose: str = Field(min_length=1, max_length=255)
    # presumably the number of codes to generate (1-1000) — confirm against the router
    count: int = Field(ge=1, le=1000)
    # free-form: redeem_code treats anything other than "duration" as an upgrade code
    redemption_type: str = "duration"
    plan_id: str
    duration_days: Optional[int] = Field(default=None, ge=1)
    upgrade_days: Optional[int] = Field(default=None, ge=1)
    # Unix seconds after which the code can no longer be redeemed
    expired_at: Optional[int] = Field(default=None, gt=0)
|
|
|
|
|
|
class UpdateRedemptionCodeForm(BaseModel):
    """Partial update for a redemption code.

    RedemptionCodesTable.update_code only applies fields the client actually
    sent (``model_dump(exclude_unset=True)``); ``redemption_type`` cannot be changed here.
    """

    purpose: Optional[str] = Field(None, min_length=1, max_length=255)
    plan_id: Optional[str] = None
    duration_days: Optional[int] = Field(None, ge=1)
    upgrade_days: Optional[int] = Field(None, ge=1)
    expired_at: Optional[int] = Field(None, gt=0)
|
|
|
|
|
|
class UpdateUserSubscriptionForm(BaseModel):
    """Partial update of a user's subscription.

    UserSubscriptionsTable.update_subscription only writes fields the client
    actually sent (``model_dump(exclude_unset=True)``).
    """

    plan_id: Optional[str] = None
    # sending an explicit None stores a subscription without expiry
    expires_at: Optional[int] = None
    # deprecated total message counter
    messages_used: Optional[int] = None
|
|
|
|
|
|
####################
|
|
# Table Classes
|
|
####################
|
|
|
|
|
|
class SubscriptionPlansTable:
    """Data-access helpers for ``SubscriptionPlan`` rows."""

    # UpdateSubscriptionPlanForm fields that SubscriptionPlanModel declares as
    # non-optional. An explicit ``null`` for one of them is dropped instead of
    # written, because the stored row could otherwise no longer be validated
    # (every later get_plan_by_id / get_all_plans call would fail on it).
    _REQUIRED_UPDATE_FIELDS = frozenset(
        {"name", "price", "default_model_limit", "priority", "is_default", "is_active"}
    )

    def get_all_plans(self, include_inactive: bool = False) -> List[SubscriptionPlanModel]:
        """Return plans ordered by descending ``priority``.

        Inactive plans are skipped unless *include_inactive* is True.
        """
        with get_db() as db:
            query = db.query(SubscriptionPlan)
            if not include_inactive:
                query = query.filter(SubscriptionPlan.is_active == True)  # noqa: E712
            plans = query.order_by(SubscriptionPlan.priority.desc()).all()
            return [SubscriptionPlanModel.model_validate(plan) for plan in plans]

    def get_plan_by_id(self, plan_id: str) -> Optional[SubscriptionPlanModel]:
        """Return the plan with *plan_id*, or None if it does not exist."""
        with get_db() as db:
            plan = db.query(SubscriptionPlan).filter(SubscriptionPlan.id == plan_id).first()
            return SubscriptionPlanModel.model_validate(plan) if plan else None

    def get_default_plan(self) -> Optional[SubscriptionPlanModel]:
        """Return the first plan that is both default and active, or None."""
        with get_db() as db:
            plan = db.query(SubscriptionPlan).filter(
                SubscriptionPlan.is_default == True,  # noqa: E712
                SubscriptionPlan.is_active == True,  # noqa: E712
            ).first()
            return SubscriptionPlanModel.model_validate(plan) if plan else None

    def create_plan(self, form_data: SubscriptionPlanForm) -> SubscriptionPlanModel:
        """Insert a new plan and return it.

        When ``form_data.is_default`` is set, the default flag is cleared on all
        other plans. That change and the insert are committed in one transaction,
        so a failed insert (e.g. a duplicate id) does not leave the system without
        a default plan.
        """
        now = int(time.time())
        plan = SubscriptionPlanModel(
            id=form_data.id,
            name=form_data.name,
            description=form_data.description,
            # go through str() so the float's repr, not its binary value, is stored
            price=Decimal(str(form_data.price)),
            model_limits=form_data.model_limits,
            default_model_limit=form_data.default_model_limit,
            priority=form_data.priority,
            is_default=form_data.is_default,
            is_active=form_data.is_active,
            created_at=now,
            updated_at=now,
        )

        with get_db() as db:
            if form_data.is_default:
                # only one plan should carry the default flag
                db.query(SubscriptionPlan).filter(
                    SubscriptionPlan.is_default == True  # noqa: E712
                ).update({"is_default": False})
            db.add(SubscriptionPlan(**plan.model_dump()))
            db.commit()

        return plan

    def update_plan(self, plan_id: str, form_data: UpdateSubscriptionPlanForm) -> Optional[SubscriptionPlanModel]:
        """Apply the fields the client sent to plan *plan_id*.

        Explicit nulls for required fields are ignored. Returns the updated plan,
        or None if *plan_id* does not exist.
        """
        update_data = {
            key: value
            for key, value in form_data.model_dump(exclude_unset=True).items()
            if value is not None or key not in self._REQUIRED_UPDATE_FIELDS
        }
        if "price" in update_data:
            update_data["price"] = Decimal(str(update_data["price"]))
        update_data["updated_at"] = int(time.time())

        with get_db() as db:
            plan = db.query(SubscriptionPlan).filter(SubscriptionPlan.id == plan_id).first()
            if not plan:
                return None

            # if this is set as default, unset other defaults
            if update_data.get("is_default"):
                db.query(SubscriptionPlan).filter(
                    SubscriptionPlan.id != plan_id,
                    SubscriptionPlan.is_default == True,  # noqa: E712
                ).update({"is_default": False})

            db.query(SubscriptionPlan).filter(SubscriptionPlan.id == plan_id).update(update_data)
            db.commit()

        return self.get_plan_by_id(plan_id)

    def delete_plan(self, plan_id: str) -> bool:
        """Delete plan *plan_id*; return True if a row was removed.

        Only the plan row is deleted; UserSubscription rows that reference
        *plan_id* are left untouched.
        """
        with get_db() as db:
            result = db.query(SubscriptionPlan).filter(SubscriptionPlan.id == plan_id).delete()
            db.commit()
            return result > 0
|
|
|
|
|
|
# Shared instance; UserSubscriptionsTable below looks plans up through it.
SubscriptionPlans = SubscriptionPlansTable()
|
|
|
|
|
|
class UserSubscriptionsTable:
    """Data-access helpers for ``UserSubscription`` rows (at most one per user)."""

    # UpdateUserSubscriptionForm fields that UserSubscriptionModel declares as
    # non-optional. An explicit ``null`` for one of them is dropped rather than
    # written, because the stored row could no longer be validated on read.
    # ``expires_at`` is deliberately absent: None there means "no expiry".
    _REQUIRED_UPDATE_FIELDS = frozenset({"plan_id", "messages_used"})

    @staticmethod
    def _next_period_end() -> int:
        """Return the timestamp of local midnight on the first day of next month."""
        import datetime

        today = datetime.date.today()
        if today.month == 12:
            next_month = today.replace(year=today.year + 1, month=1, day=1)
        else:
            next_month = today.replace(month=today.month + 1, day=1)
        return int(datetime.datetime.combine(next_month, datetime.time.min).timestamp())

    @staticmethod
    def _days_remaining(expires_at: Optional[int], now: int) -> Optional[int]:
        """Return whole days left until *expires_at* (floored, >= 0), or None if unset."""
        if not expires_at:
            return None
        return max(0, (expires_at - now) // 86400)

    def get_by_user_id(self, user_id: str) -> Optional[UserSubscriptionModel]:
        """Return the subscription of *user_id*, or None if the user has none."""
        with get_db() as db:
            sub = db.query(UserSubscription).filter(UserSubscription.user_id == user_id).first()
            return UserSubscriptionModel.model_validate(sub) if sub else None

    def get_with_plan(self, user_id: str) -> Optional[UserSubscriptionWithPlanModel]:
        """Return the user's subscription with its plan, days remaining and usage summary.

        ``model_usage_info`` covers only models that have an explicit entry in the
        plan's ``model_limits``; it is None when there are no such entries.
        """
        sub = self.get_by_user_id(user_id)
        if not sub:
            return None

        plan = SubscriptionPlans.get_plan_by_id(sub.plan_id)
        now = int(time.time())

        model_usage_info = []
        if plan and plan.model_limits:
            for model_id, limit in plan.model_limits.items():
                used = sub.get_model_usage(model_id)
                # -1 means unlimited, so there is no meaningful remaining count
                remaining = None if limit == -1 else max(0, limit - used)
                model_usage_info.append(ModelUsageInfo(
                    model_id=model_id,
                    limit=limit,
                    used=used,
                    remaining=remaining,
                ))

        return UserSubscriptionWithPlanModel(
            **sub.model_dump(),
            plan=plan,
            days_remaining=self._days_remaining(sub.expires_at, now),
            model_usage_info=model_usage_info or None,
        )

    def get_all_subscriptions(
        self,
        query: Optional[str] = None,
        offset: int = 0,
        limit: int = 30
    ) -> Tuple[int, List[UserSubscriptionWithPlanModel]]:
        """Return ``(total, page)`` of subscriptions, most recently updated first.

        *query* filters by substring of ``user_id``. Each distinct plan is fetched
        once per call, not once per subscription.
        """
        with get_db() as db:
            q = db.query(UserSubscription).order_by(UserSubscription.updated_at.desc())
            if query:
                q = q.filter(UserSubscription.user_id.contains(query))

            total = q.count()
            sub_models = [
                UserSubscriptionModel.model_validate(sub)
                for sub in q.offset(offset).limit(limit).all()
            ]

        now = int(time.time())
        plans: Dict[str, Optional[SubscriptionPlanModel]] = {}
        results = []
        for sub_model in sub_models:
            if sub_model.plan_id not in plans:
                plans[sub_model.plan_id] = SubscriptionPlans.get_plan_by_id(sub_model.plan_id)
            results.append(UserSubscriptionWithPlanModel(
                **sub_model.model_dump(),
                plan=plans[sub_model.plan_id],
                days_remaining=self._days_remaining(sub_model.expires_at, now),
            ))

        return total, results

    def init_free_subscription(self, user_id: str) -> UserSubscriptionModel:
        """Create and return a non-expiring subscription for *user_id* on the default plan.

        If no active default plan exists, an existing plan with id ``"free"`` is
        reused; only if that is missing too is a ``"free"`` plan created.
        """
        default_plan = SubscriptionPlans.get_default_plan()
        if not default_plan:
            # A "free" plan may still exist (e.g. deactivated or no longer default);
            # creating it again would fail on the duplicate primary key.
            default_plan = SubscriptionPlans.get_plan_by_id("free")
        if not default_plan:
            # create a default free plan if none exists
            default_plan = SubscriptionPlans.create_plan(SubscriptionPlanForm(
                id="free",
                name="免费版",
                description="免费套餐,基础模型无限使用",
                price=0,
                model_limits={},
                default_model_limit=-1,  # all models unlimited by default
                priority=0,
                is_default=True,
                is_active=True,
            ))

        now = int(time.time())
        sub = UserSubscriptionModel(
            user_id=user_id,
            plan_id=default_plan.id,
            started_at=now,
            expires_at=None,
            current_period_start=now,
            current_period_end=self._next_period_end(),
            model_usage={},
            messages_used=0,
            created_at=now,
            updated_at=now,
        )

        with get_db() as db:
            db.add(UserSubscription(**sub.model_dump()))
            db.commit()

        return sub

    def get_or_create_subscription(self, user_id: str) -> UserSubscriptionModel:
        """Get existing subscription or create a new free subscription for user."""
        existing = self.get_by_user_id(user_id)
        if existing:
            return existing
        return self.init_free_subscription(user_id)

    def update_subscription(
        self, user_id: str, form_data: UpdateUserSubscriptionForm
    ) -> Optional[UserSubscriptionModel]:
        """Apply the fields the client sent to the subscription of *user_id*.

        Explicit nulls for ``plan_id`` / ``messages_used`` are ignored. Returns the
        updated subscription, or None if the user has none.
        """
        update_data = {
            key: value
            for key, value in form_data.model_dump(exclude_unset=True).items()
            if value is not None or key not in self._REQUIRED_UPDATE_FIELDS
        }
        update_data["updated_at"] = int(time.time())

        with get_db() as db:
            sub = db.query(UserSubscription).filter(UserSubscription.user_id == user_id).first()
            if not sub:
                return None

            db.query(UserSubscription).filter(UserSubscription.user_id == user_id).update(update_data)
            db.commit()

        return self.get_by_user_id(user_id)

    def upgrade_subscription(
        self, user_id: str, plan_id: str, duration_days: Optional[int] = None, expires_at: Optional[int] = None
    ) -> UserSubscriptionModel:
        """Move *user_id* onto *plan_id*, creating the subscription if needed.

        With *duration_days*, the expiry is extended from the current expiry when
        it is still in the future, otherwise from now. Without it, *expires_at* is
        used as given (None = no expiry). Usage counters are not reset.
        """
        now = int(time.time())

        with get_db() as db:
            existing = db.query(UserSubscription).filter(UserSubscription.user_id == user_id).first()

            new_expires_at = expires_at
            if duration_days:
                if existing and existing.expires_at and existing.expires_at > now:
                    # extend from current expiration
                    new_expires_at = existing.expires_at + (duration_days * 86400)
                else:
                    # start fresh
                    new_expires_at = now + (duration_days * 86400)

            if existing:
                db.query(UserSubscription).filter(UserSubscription.user_id == user_id).update({
                    "plan_id": plan_id,
                    "expires_at": new_expires_at,
                    "updated_at": now,
                })
            else:
                sub = UserSubscriptionModel(
                    user_id=user_id,
                    plan_id=plan_id,
                    started_at=now,
                    expires_at=new_expires_at,
                    current_period_start=now,
                    current_period_end=self._next_period_end(),
                    messages_used=0,
                    created_at=now,
                    updated_at=now,
                )
                db.add(UserSubscription(**sub.model_dump()))
            db.commit()

        return self.get_by_user_id(user_id)

    def downgrade_to_free(self, user_id: str) -> UserSubscriptionModel:
        """Move *user_id* onto the default plan with no expiry.

        Raises:
            ValueError: if no active default plan is configured.

        Returns None if the user has no subscription row.
        """
        default_plan = SubscriptionPlans.get_default_plan()
        if not default_plan:
            raise ValueError("No default plan configured")

        now = int(time.time())
        with get_db() as db:
            db.query(UserSubscription).filter(UserSubscription.user_id == user_id).update({
                "plan_id": default_plan.id,
                "expires_at": None,
                "updated_at": now,
            })
            db.commit()

        return self.get_by_user_id(user_id)

    def reset_period(self, subscription_id: str) -> Optional[UserSubscriptionModel]:
        """Start a new usage period: zero all counters and move the period window.

        The period runs from now until local midnight on the first of next month.
        Returns the updated subscription, or None if *subscription_id* does not
        exist (instead of failing model validation on a missing row).
        """
        now = int(time.time())
        with get_db() as db:
            db.query(UserSubscription).filter(UserSubscription.id == subscription_id).update({
                "current_period_start": now,
                "current_period_end": self._next_period_end(),
                "model_usage": {},
                "messages_used": 0,
                "updated_at": now,
            })
            db.commit()

            sub = db.query(UserSubscription).filter(UserSubscription.id == subscription_id).first()
            return UserSubscriptionModel.model_validate(sub) if sub else None

    def increment_model_usage(self, subscription_id: str, model_id: str) -> None:
        """Increment the counter of *model_id* and the deprecated total by one.

        The row is read with ``SELECT ... FOR UPDATE`` so that, on databases with
        row locks, concurrent increments cannot overwrite each other's
        read-modify-write of the JSON ``model_usage`` column. SQLAlchemy omits the
        clause on SQLite. Missing subscriptions are ignored.
        """
        with get_db() as db:
            sub = (
                db.query(UserSubscription)
                .filter(UserSubscription.id == subscription_id)
                .with_for_update()
                .first()
            )
            if not sub:
                return

            # Build a fresh dict: the column may hold JSON text or a non-dict value,
            # and the ORM-loaded object should not be mutated in place.
            parsed = _parse_json_field(sub.model_usage)
            current_usage = dict(parsed) if isinstance(parsed, dict) else {}
            current_usage[model_id] = current_usage.get(model_id, 0) + 1

            db.query(UserSubscription).filter(UserSubscription.id == subscription_id).update({
                "model_usage": current_usage,
                # SQL-side increment, evaluated by the database
                "messages_used": UserSubscription.messages_used + 1,
                "updated_at": int(time.time()),
            })
            db.commit()

    def increment_usage(self, subscription_id: str) -> None:
        """Deprecated: use increment_model_usage instead."""
        with get_db() as db:
            db.query(UserSubscription).filter(UserSubscription.id == subscription_id).update({
                "messages_used": UserSubscription.messages_used + 1,
                "updated_at": int(time.time()),
            })
            db.commit()
|
|
|
|
|
|
# Shared instance; RedemptionCodesTable.redeem_code applies upgrades through it.
UserSubscriptions = UserSubscriptionsTable()
|
|
|
|
|
|
class SubscriptionUsageLogsTable:
    """Insert and query helpers for ``SubscriptionUsageLog`` rows."""

    def insert(
        self,
        user_id: str,
        subscription_id: str,
        plan_id: str,
        model_id: Optional[str] = None,
        chat_id: Optional[str] = None,
        message_id: Optional[str] = None,
    ) -> SubscriptionUsageLogModel:
        """Persist one usage event and return it (id/timestamp are generated)."""
        entry = SubscriptionUsageLogModel(
            user_id=user_id,
            subscription_id=subscription_id,
            plan_id=plan_id,
            model_id=model_id,
            chat_id=chat_id,
            message_id=message_id,
        )
        row = SubscriptionUsageLog(**entry.model_dump())

        with get_db() as db:
            db.add(row)
            db.commit()

        return entry

    def get_by_user_id(
        self, user_id: str, offset: int = 0, limit: int = 30
    ) -> Tuple[int, List[SubscriptionUsageLogModel]]:
        """Return ``(total, page)`` of a user's events, newest first."""
        with get_db() as db:
            newest_first = (
                db.query(SubscriptionUsageLog)
                .filter(SubscriptionUsageLog.user_id == user_id)
                .order_by(SubscriptionUsageLog.created_at.desc())
            )
            total = newest_first.count()
            page = newest_first.offset(offset).limit(limit).all()
            return total, [SubscriptionUsageLogModel.model_validate(row) for row in page]

    def get_by_time_range(
        self, start_time: int, end_time: int, user_ids: Optional[List[str]] = None
    ) -> List[SubscriptionUsageLogModel]:
        """Return events with ``start_time <= created_at < end_time``, oldest first.

        A non-empty *user_ids* restricts the result to those users; None or an
        empty list applies no user filter.
        """
        conditions = [
            SubscriptionUsageLog.created_at >= start_time,
            SubscriptionUsageLog.created_at < end_time,
        ]
        if user_ids:
            conditions.append(SubscriptionUsageLog.user_id.in_(user_ids))

        with get_db() as db:
            rows = (
                db.query(SubscriptionUsageLog)
                .filter(*conditions)
                .order_by(SubscriptionUsageLog.created_at.asc())
                .all()
            )
            return [SubscriptionUsageLogModel.model_validate(row) for row in rows]
|
|
|
|
|
|
# Shared instance of the usage-log table helpers.
SubscriptionUsageLogs = SubscriptionUsageLogsTable()
|
|
|
|
|
|
class RedemptionCodesTable:
    """Data-access helpers for ``RedemptionCode`` rows and code redemption."""

    def get_code(self, code: str) -> Optional[RedemptionCodeModel]:
        """Return the redemption code *code*, or None if it does not exist."""
        with get_db() as db:
            redemption_code = db.query(RedemptionCode).filter(RedemptionCode.code == code).first()
            return RedemptionCodeModel.model_validate(redemption_code) if redemption_code else None

    def get_codes(
        self, keyword: Optional[str] = None, offset: int = 0, limit: int = 30
    ) -> Tuple[int, List[RedemptionCodeModel]]:
        """Return ``(total, page)`` of codes, newest first.

        *keyword* matches an exact code, an exact plan id, or a substring of ``purpose``.
        """
        with get_db() as db:
            query = db.query(RedemptionCode).order_by(RedemptionCode.created_at.desc())

            if keyword:
                query = query.filter(
                    (RedemptionCode.code == keyword)
                    | (RedemptionCode.purpose.contains(keyword))
                    | (RedemptionCode.plan_id == keyword)
                )

            total = query.count()
            codes = query.offset(offset).limit(limit).all()
            return total, [RedemptionCodeModel.model_validate(code) for code in codes]

    def insert_codes(self, codes: List[RedemptionCodeModel]) -> None:
        """Insert all *codes* in a single commit."""
        # RedemptionCodeModel allows extra fields; pass only the declared ones so
        # the ORM constructor does not reject unknown keyword arguments.
        columns = set(RedemptionCodeModel.model_fields)
        with get_db() as db:
            db.add_all([RedemptionCode(**code.model_dump(include=columns)) for code in codes])
            db.commit()

    def update_code(self, code: str, form_data: UpdateRedemptionCodeForm) -> Optional[RedemptionCodeModel]:
        """Apply the fields the client sent to *code*.

        Returns the updated code, or None if *code* does not exist.
        """
        update_data = form_data.model_dump(exclude_unset=True)
        # ``purpose`` is required by RedemptionCodeModel; storing an explicit null
        # would leave a row that can no longer be read back.
        if "purpose" in update_data and update_data["purpose"] is None:
            del update_data["purpose"]

        with get_db() as db:
            query = db.query(RedemptionCode).filter(RedemptionCode.code == code)
            if not query.first():
                return None
            if update_data:
                query.update(update_data)
                db.commit()

        return self.get_code(code)

    def delete_code(self, code: str) -> bool:
        """Delete *code*; return True if a row was removed."""
        with get_db() as db:
            result = db.query(RedemptionCode).filter(RedemptionCode.code == code).delete()
            db.commit()
            return result > 0

    def redeem_code(self, code: str, user_id: str) -> UserSubscriptionModel:
        """Redeem *code* for *user_id* and apply its plan to the user's subscription.

        "duration" codes extend by ``duration_days``; any other type uses
        ``upgrade_days``.

        Raises:
            HTTPException: 404 if the code does not exist; 400 if it was already
                used, has expired, or references a plan that does not exist.
        """
        from fastapi import HTTPException

        redemption_code = self.get_code(code)
        if not redemption_code:
            raise HTTPException(status_code=404, detail="Code not found")

        if redemption_code.user_id is not None:
            raise HTTPException(status_code=400, detail="Code already used")

        now = int(time.time())
        if redemption_code.expired_at and redemption_code.expired_at < now:
            raise HTTPException(status_code=400, detail="Code expired")

        # Check the target plan before consuming the code, so a misconfigured
        # code is rejected instead of being spent on a failing upgrade.
        if not redemption_code.plan_id or not SubscriptionPlans.get_plan_by_id(redemption_code.plan_id):
            raise HTTPException(status_code=400, detail="Plan not found")

        # Claim the code atomically: the UPDATE only matches while the code is
        # still unredeemed, so two concurrent requests cannot both spend it.
        with get_db() as db:
            claimed = db.query(RedemptionCode).filter(
                RedemptionCode.code == code,
                RedemptionCode.user_id.is_(None),
            ).update({
                "user_id": user_id,
                "received_at": now,
            }, synchronize_session=False)
            db.commit()

        if claimed != 1:
            raise HTTPException(status_code=400, detail="Code already used")

        # NOTE(review): a code without the matching day count yields
        # duration_days=None, which upgrade_subscription treats as "no expiry".
        if redemption_code.redemption_type == "duration":
            days = redemption_code.duration_days
        else:
            days = redemption_code.upgrade_days

        try:
            return UserSubscriptions.upgrade_subscription(
                user_id=user_id,
                plan_id=redemption_code.plan_id,
                duration_days=days,
            )
        except Exception:
            # Release the claim so the code is not lost when the upgrade fails.
            with get_db() as db:
                db.query(RedemptionCode).filter(
                    RedemptionCode.code == code,
                    RedemptionCode.user_id == user_id,
                ).update({
                    "user_id": None,
                    "received_at": None,
                }, synchronize_session=False)
                db.commit()
            raise
|
|
|
|
|
|
# Shared instance of the redemption-code table helpers.
RedemptionCodes = RedemptionCodesTable()
|