feat:精简
Some checks failed
Create and publish Docker images with specific build args / build-main-image (linux/amd64, ubuntu-latest) (push) Has been cancelled
Create and publish Docker images with specific build args / build-main-image (linux/arm64, ubuntu-24.04-arm) (push) Has been cancelled
Create and publish Docker images with specific build args / build-cuda-image (linux/amd64, ubuntu-latest) (push) Has been cancelled
Create and publish Docker images with specific build args / build-cuda-image (linux/arm64, ubuntu-24.04-arm) (push) Has been cancelled
Create and publish Docker images with specific build args / build-cuda126-image (linux/amd64, ubuntu-latest) (push) Has been cancelled
Create and publish Docker images with specific build args / build-cuda126-image (linux/arm64, ubuntu-24.04-arm) (push) Has been cancelled
Create and publish Docker images with specific build args / build-ollama-image (linux/amd64, ubuntu-latest) (push) Has been cancelled
Create and publish Docker images with specific build args / build-ollama-image (linux/arm64, ubuntu-24.04-arm) (push) Has been cancelled
Create and publish Docker images with specific build args / build-slim-image (linux/amd64, ubuntu-latest) (push) Has been cancelled
Create and publish Docker images with specific build args / build-slim-image (linux/arm64, ubuntu-24.04-arm) (push) Has been cancelled
Python CI / Format Backend (3.11.x) (push) Has been cancelled
Python CI / Format Backend (3.12.x) (push) Has been cancelled
Frontend Build / Format & Build Frontend (push) Has been cancelled
Frontend Build / Frontend Unit Tests (push) Has been cancelled
Create and publish Docker images with specific build args / merge-main-images (push) Has been cancelled
Create and publish Docker images with specific build args / merge-cuda-images (push) Has been cancelled
Create and publish Docker images with specific build args / merge-cuda126-images (push) Has been cancelled
Create and publish Docker images with specific build args / merge-ollama-images (push) Has been cancelled
Create and publish Docker images with specific build args / merge-slim-images (push) Has been cancelled
Close inactive issues / close-issues (push) Has been cancelled

This commit is contained in:
2026-01-16 18:34:38 +08:00
parent 16263710d9
commit 11fcec9387
137 changed files with 68993 additions and 6435 deletions

View File

@@ -22,7 +22,7 @@ from typing import Optional, Union, List, Dict
from opentelemetry import trace
from open_webui.config import WEBUI_URL
from open_webui.config import WEBUI_URL, EMAIL_VERIFY_TEMPLATE, EMAIL_CODE_TEMPLATE
from open_webui.utils.smtp import send_email
from open_webui.utils.access_control import has_permission
@@ -47,6 +47,7 @@ from open_webui.env import (
REDIS_SENTINEL_PORT,
WEBUI_NAME,
REDIS_CLUSTER,
BASE_DIR,
)
from fastapi import BackgroundTasks, Depends, HTTPException, Request, Response, status
@@ -81,76 +82,176 @@ def verify_signature(payload: str, signature: str) -> bool:
return False
def override_static(path: str, content: str):
def override_static(path: str, content: str) -> bool:
"""Download content from URL and save to path.
Returns True if successful, False otherwise.
Does not raise exceptions to prevent app startup failure.
"""
os.makedirs(os.path.dirname(path), exist_ok=True)
r = requests.get(content, stream=True)
with open(path, "wb") as f:
r.raw.decode_content = True
shutil.copyfileobj(r.raw, f)
try:
# try with SSL verification first, then without if it fails
try:
r = requests.get(content, stream=True, timeout=30)
except requests.exceptions.SSLError:
log.warning(f"SSL error downloading {content}, retrying without verification")
r = requests.get(content, stream=True, timeout=30, verify=False)
r.raise_for_status()
with open(path, "wb") as f:
r.raw.decode_content = True
shutil.copyfileobj(r.raw, f)
log.info(f"Successfully downloaded {content} to {path}")
return True
except Exception as e:
log.error(f"Failed to download {content}: {e}")
return False
def apply_branding(app):
"""Apply branding settings from database config or environment variables."""
custom_png = ""
custom_svg = ""
custom_ico = ""
custom_dark_png = ""
organization_name = "OpenWebui"
custom_name = ""
if hasattr(app, "state") and hasattr(app.state, "config"):
# Get config values - must access _state directly to get fresh PersistentConfig values
def get_config_value(attr_name, default=""):
# Access _state directly to get the PersistentConfig object
config_obj = app.state.config._state.get(attr_name)
if config_obj is None:
return default
# If it's a PersistentConfig object, get its value
if hasattr(config_obj, "value"):
result = config_obj.value or default
log.info(f"Branding: {attr_name}.value = {result}")
return result
return config_obj or default
custom_png = get_config_value("BRANDING_FAVICON_PNG") or os.getenv("CUSTOM_PNG", "")
custom_svg = get_config_value("BRANDING_FAVICON_SVG") or os.getenv("CUSTOM_SVG", "")
custom_ico = get_config_value("BRANDING_FAVICON_ICO") or os.getenv("CUSTOM_ICO", "")
custom_dark_png = get_config_value("BRANDING_FAVICON_DARK_PNG") or os.getenv("CUSTOM_DARK_PNG", "")
organization_name = get_config_value("BRANDING_ORGANIZATION_NAME") or os.getenv("ORGANIZATION_NAME", "OpenWebui")
custom_name = get_config_value("BRANDING_CUSTOM_NAME") or os.getenv("CUSTOM_NAME", "")
else:
custom_png = os.getenv("CUSTOM_PNG", "")
custom_svg = os.getenv("CUSTOM_SVG", "")
custom_ico = os.getenv("CUSTOM_ICO", "")
custom_dark_png = os.getenv("CUSTOM_DARK_PNG", "")
organization_name = os.getenv("ORGANIZATION_NAME", "OpenWebui")
custom_name = os.getenv("CUSTOM_NAME", "")
log.info(f"Branding: custom_name={custom_name}, custom_png={custom_png}")
# Frontend static/static directory (for dev mode)
# SvelteKit maps static/ folder to root, so /static/favicon.png = static/static/favicon.png
frontend_static_dir = BASE_DIR / "static" / "static"
# Build resources mapping for both backend and frontend static dirs
resources = []
# files to override
branding_files = [
("logo.png", custom_png),
("favicon.png", custom_png),
("favicon.svg", custom_svg),
("favicon-96x96.png", custom_png),
("apple-touch-icon.png", custom_png),
("web-app-manifest-192x192.png", custom_png),
("web-app-manifest-512x512.png", custom_png),
("splash.png", custom_png),
("favicon.ico", custom_ico),
("favicon-dark.png", custom_dark_png),
("splash-dark.png", custom_dark_png),
]
for filename, url in branding_files:
if url:
# Add backend static dir path
resources.append((os.path.join(STATIC_DIR, filename), url))
# Add frontend static/static dir path (for dev mode)
resources.append((os.path.join(frontend_static_dir, filename), url))
try:
for path, url in resources:
if url:
log.info(f"Branding: Downloading {url} to {path}")
override_static(path, url)
# set metadata
setattr(app.state, "LICENSE_METADATA", {
"type": "enterprise",
"organization_name": organization_name,
})
log.info(f"Branding: Set LICENSE_METADATA organization_name={organization_name}")
# set custom name if provided
if custom_name:
setattr(app.state, "WEBUI_NAME", custom_name)
log.info(f"Branding: Set WEBUI_NAME={custom_name}")
else:
log.info(f"Branding: custom_name is empty, WEBUI_NAME not changed")
return True
except Exception as ex:
log.exception(f"Branding: Uncaught Exception: {ex}")
return False
def get_license_data(app, key):
# get branding config from app.state.config if available, fallback to env vars
custom_png = ""
custom_svg = ""
custom_ico = ""
custom_dark_png = ""
organization_name = "OpenWebui"
if hasattr(app, "state") and hasattr(app.state, "config"):
custom_png = getattr(app.state.config, "BRANDING_FAVICON_PNG", "") or os.getenv("CUSTOM_PNG", "")
custom_svg = getattr(app.state.config, "BRANDING_FAVICON_SVG", "") or os.getenv("CUSTOM_SVG", "")
custom_ico = getattr(app.state.config, "BRANDING_FAVICON_ICO", "") or os.getenv("CUSTOM_ICO", "")
custom_dark_png = getattr(app.state.config, "BRANDING_FAVICON_DARK_PNG", "") or os.getenv("CUSTOM_DARK_PNG", "")
organization_name = getattr(app.state.config, "BRANDING_ORGANIZATION_NAME", "") or os.getenv("ORGANIZATION_NAME", "OpenWebui")
else:
custom_png = os.getenv("CUSTOM_PNG", "")
custom_svg = os.getenv("CUSTOM_SVG", "")
custom_ico = os.getenv("CUSTOM_ICO", "")
custom_dark_png = os.getenv("CUSTOM_DARK_PNG", "")
organization_name = os.getenv("ORGANIZATION_NAME", "OpenWebui")
payload = {
"resources": {
os.path.join(STATIC_DIR, "logo.png"): os.getenv("CUSTOM_PNG", ""),
os.path.join(STATIC_DIR, "favicon.png"): os.getenv("CUSTOM_PNG", ""),
os.path.join(STATIC_DIR, "favicon.svg"): os.getenv("CUSTOM_SVG", ""),
os.path.join(STATIC_DIR, "favicon-96x96.png"): os.getenv("CUSTOM_PNG", ""),
os.path.join(STATIC_DIR, "apple-touch-icon.png"): os.getenv(
"CUSTOM_PNG", ""
),
os.path.join(STATIC_DIR, "web-app-manifest-192x192.png"): os.getenv(
"CUSTOM_PNG", ""
),
os.path.join(STATIC_DIR, "web-app-manifest-512x512.png"): os.getenv(
"CUSTOM_PNG", ""
),
os.path.join(STATIC_DIR, "splash.png"): os.getenv("CUSTOM_PNG", ""),
os.path.join(STATIC_DIR, "favicon.ico"): os.getenv("CUSTOM_ICO", ""),
os.path.join(STATIC_DIR, "favicon-dark.png"): os.getenv(
"CUSTOM_DARK_PNG", ""
),
os.path.join(STATIC_DIR, "splash-dark.png"): os.getenv(
"CUSTOM_DARK_PNG", ""
),
os.path.join(FRONTEND_BUILD_DIR, "favicon.png"): os.getenv(
"CUSTOM_PNG", ""
),
os.path.join(FRONTEND_BUILD_DIR, "static/favicon.png"): os.getenv(
"CUSTOM_PNG", ""
),
os.path.join(FRONTEND_BUILD_DIR, "static/favicon.svg"): os.getenv(
"CUSTOM_SVG", ""
),
os.path.join(FRONTEND_BUILD_DIR, "static/favicon-96x96.png"): os.getenv(
"CUSTOM_PNG", ""
),
os.path.join(FRONTEND_BUILD_DIR, "static/apple-touch-icon.png"): os.getenv(
"CUSTOM_PNG", ""
),
os.path.join(
FRONTEND_BUILD_DIR, "static/web-app-manifest-192x192.png"
): os.getenv("CUSTOM_PNG", ""),
os.path.join(
FRONTEND_BUILD_DIR, "static/web-app-manifest-512x512.png"
): os.getenv("CUSTOM_PNG", ""),
os.path.join(FRONTEND_BUILD_DIR, "static/splash.png"): os.getenv(
"CUSTOM_PNG", ""
),
os.path.join(FRONTEND_BUILD_DIR, "static/favicon.ico"): os.getenv(
"CUSTOM_ICO", ""
),
os.path.join(FRONTEND_BUILD_DIR, "static/favicon-dark.png"): os.getenv(
"CUSTOM_DARK_PNG", ""
),
os.path.join(FRONTEND_BUILD_DIR, "static/splash-dark.png"): os.getenv(
"CUSTOM_DARK_PNG", ""
),
os.path.join(STATIC_DIR, "logo.png"): custom_png,
os.path.join(STATIC_DIR, "favicon.png"): custom_png,
os.path.join(STATIC_DIR, "favicon.svg"): custom_svg,
os.path.join(STATIC_DIR, "favicon-96x96.png"): custom_png,
os.path.join(STATIC_DIR, "apple-touch-icon.png"): custom_png,
os.path.join(STATIC_DIR, "web-app-manifest-192x192.png"): custom_png,
os.path.join(STATIC_DIR, "web-app-manifest-512x512.png"): custom_png,
os.path.join(STATIC_DIR, "splash.png"): custom_png,
os.path.join(STATIC_DIR, "favicon.ico"): custom_ico,
os.path.join(STATIC_DIR, "favicon-dark.png"): custom_dark_png,
os.path.join(STATIC_DIR, "splash-dark.png"): custom_dark_png,
os.path.join(FRONTEND_BUILD_DIR, "favicon.png"): custom_png,
os.path.join(FRONTEND_BUILD_DIR, "static/favicon.png"): custom_png,
os.path.join(FRONTEND_BUILD_DIR, "static/favicon.svg"): custom_svg,
os.path.join(FRONTEND_BUILD_DIR, "static/favicon-96x96.png"): custom_png,
os.path.join(FRONTEND_BUILD_DIR, "static/apple-touch-icon.png"): custom_png,
os.path.join(FRONTEND_BUILD_DIR, "static/web-app-manifest-192x192.png"): custom_png,
os.path.join(FRONTEND_BUILD_DIR, "static/web-app-manifest-512x512.png"): custom_png,
os.path.join(FRONTEND_BUILD_DIR, "static/splash.png"): custom_png,
os.path.join(FRONTEND_BUILD_DIR, "static/favicon.ico"): custom_ico,
os.path.join(FRONTEND_BUILD_DIR, "static/favicon-dark.png"): custom_dark_png,
os.path.join(FRONTEND_BUILD_DIR, "static/splash-dark.png"): custom_dark_png,
},
"metadata": {
"type": "enterprise",
"organization_name": os.getenv("ORGANIZATION_NAME", "OpenWebui"),
"organization_name": organization_name,
},
}
try:
@@ -611,11 +712,17 @@ def send_verify_email(email: str):
code = f"{uuid.uuid4().hex}{uuid.uuid1().hex}"
redis.set(name=get_email_code_key(code=code), value=email, ex=timedelta(days=1))
link = f"{WEBUI_URL.value.rstrip('/')}/api/v1/auths/signup_verify/{code}"
# use template from config
template = EMAIL_VERIFY_TEMPLATE.value or verify_email_template
# use replace instead of % formatting to avoid issues with % in CSS
body = template.replace("%(title)s", f"{WEBUI_NAME} Email Verify").replace("%(link)s", link)
send_email(
receiver=email,
subject=f"{WEBUI_NAME} Email Verify",
body=verify_email_template
% {"title": f"{WEBUI_NAME} Email Verify", "link": link},
body=body,
)
@@ -628,3 +735,78 @@ def verify_email_by_code(code: str) -> str:
redis_cluster=REDIS_CLUSTER,
)
return redis.get(name=get_email_code_key(code=code))
# email verification code template
email_code_template = """<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<style>
body { font-family: Arial, sans-serif; background-color: #f5f5f5; margin: 0; padding: 20px; }
.container { max-width: 500px; margin: 0 auto; background: white; border-radius: 10px; padding: 30px; }
.code { font-size: 32px; font-weight: bold; color: #3b82f6; letter-spacing: 8px; text-align: center; padding: 20px; background: #f0f9ff; border-radius: 8px; margin: 20px 0; }
.note { color: #666; font-size: 14px; }
</style>
</head>
<body>
<div class="container">
<h2>%(title)s</h2>
<p>您的验证码是:</p>
<div class="code">%(code)s</div>
<p class="note">验证码有效期为10分钟请勿泄露给他人。</p>
</div>
</body>
</html>"""
def get_signup_code_key(email: str) -> str:
return f"signup_code:{email}"
def send_signup_email_code(email: str) -> str:
"""Send a 6-digit verification code to email for signup."""
import random
redis = get_redis_connection(
redis_url=REDIS_URL,
redis_sentinels=get_sentinels_from_env(
REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT
),
redis_cluster=REDIS_CLUSTER,
)
# generate 6-digit code
code = str(random.randint(100000, 999999))
# store in redis with 10 min expiration
redis.set(name=get_signup_code_key(email.lower()), value=code, ex=timedelta(minutes=10))
# use template from config
template = EMAIL_CODE_TEMPLATE.value or email_code_template
send_email(
receiver=email,
subject=f"{WEBUI_NAME} 注册验证码",
body=template % {"title": f"{WEBUI_NAME} 注册验证码", "code": code},
)
return code
def verify_signup_email_code(email: str, code: str) -> bool:
"""Verify the signup email code."""
redis = get_redis_connection(
redis_url=REDIS_URL,
redis_sentinels=get_sentinels_from_env(
REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT
),
redis_cluster=REDIS_CLUSTER,
)
stored_code = redis.get(name=get_signup_code_key(email.lower()))
if stored_code and stored_code == code:
# delete code after successful verification
redis.delete(get_signup_code_key(email.lower()))
return True
return False

View File

@@ -39,8 +39,6 @@ from open_webui.routers.pipelines import (
from open_webui.models.functions import Functions
from open_webui.models.models import Models
from open_webui.utils.credit.usage import CreditDeduct
from open_webui.utils.credit.utils import check_credit_by_user_id
from open_webui.utils.plugin import (
load_function_module_by_id,
@@ -169,8 +167,6 @@ async def generate_chat_completion(
user: Any,
bypass_filter: bool = False,
):
check_credit_by_user_id(user_id=user.id, form_data=form_data)
log.debug(f"generate_chat_completion: {form_data}")
if BYPASS_MODEL_ACCESS_CONTROL:
bypass_filter = True
@@ -283,15 +279,8 @@ async def generate_chat_completion(
background=response.background,
)
else:
with CreditDeduct(
user=user,
model_id=model_id,
body=payload,
is_stream=False,
) as credit_deduct:
response = convert_response_ollama_to_openai(response)
credit_deduct.run(response)
return credit_deduct.add_usage_to_resp(response)
response = convert_response_ollama_to_openai(response)
return response
else:
return await generate_openai_chat_completion(
request=request,

View File

@@ -1,2 +0,0 @@
# credit module - deprecated, kept for backward compatibility
# this module is replaced by the subscription system

View File

@@ -1,25 +0,0 @@
# credit usage module - deprecated, kept for backward compatibility
# this module is replaced by the subscription system
class CreditDeduct:
"""
Deprecated credit deduction class.
Kept for backward compatibility - does nothing.
The subscription system now handles usage tracking.
"""
def __init__(self, *args, **kwargs):
pass
def __enter__(self):
return self
def __exit__(self, *args):
pass
async def __aenter__(self):
return self
async def __aexit__(self, *args):
pass

View File

@@ -1,18 +0,0 @@
# credit utils module - deprecated, kept for backward compatibility
# this module is replaced by the subscription system
def is_free_request(*args, **kwargs) -> bool:
"""
Deprecated function.
Returns True to allow all requests (subscription system handles access control).
"""
return True
def check_credit_by_user_id(*args, **kwargs):
"""
Deprecated function.
Does nothing - subscription system now handles access control.
"""
pass

View File

@@ -57,6 +57,7 @@ from open_webui.routers.pipelines import (
from open_webui.routers.memories import query_memory, QueryMemoryForm
from open_webui.utils.webhook import post_webhook
from open_webui.utils.subscription.check import record_usage
from open_webui.utils.files import (
convert_markdown_base64_images,
get_file_url_from_base64,
@@ -1964,6 +1965,14 @@ async def process_chat_response(
},
)
# record subscription usage on success
record_usage(
user_id=user.id,
model_id=form_data.get("model"),
chat_id=metadata["chat_id"],
message_id=metadata["message_id"],
)
# Send a webhook notification if the user is not active
if not Users.is_user_active(user.id):
webhook_url = Users.get_user_webhook_url_by_id(user.id)
@@ -3269,6 +3278,14 @@ async def process_chat_response(
},
)
# record subscription usage on success
record_usage(
user_id=user.id,
model_id=model_id,
chat_id=metadata["chat_id"],
message_id=metadata["message_id"],
)
# Send a webhook notification if the user is not active
if not Users.is_user_active(user.id):
webhook_url = Users.get_user_webhook_url_by_id(user.id)

View File

@@ -665,24 +665,12 @@ def stream_chunks_handler(
:return: An async generator that yields the stream data.
"""
from open_webui.utils.credit.usage import CreditDeduct
max_buffer_size = CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE
if max_buffer_size is None or max_buffer_size <= 0:
async def consumer_content(stream: aiohttp.StreamReader):
with CreditDeduct(
user=user,
model_id=model_id,
body=form_data,
is_stream=True,
) as credit_deduct:
# change to avoid multi \n\n cause message lose
async for chunk in stream:
credit_deduct.run(response=chunk)
yield chunk
yield credit_deduct.usage_message
async for chunk in stream:
yield chunk
return consumer_content(stream)
@@ -690,65 +678,53 @@ def stream_chunks_handler(
buffer = b""
skip_mode = False
with CreditDeduct(
user=user,
model_id=model_id,
body=form_data,
is_stream=True,
) as credit_deduct:
# change to avoid multi \n\n cause message lose
async for data in stream:
async for data in stream:
if not data:
continue
if not data:
continue
credit_deduct.run(response=data)
# In skip_mode, if buffer already exceeds the limit, clear it (it's part of an oversized line)
if skip_mode and len(buffer) > max_buffer_size:
buffer = b""
# In skip_mode, if buffer already exceeds the limit, clear it (it's part of an oversized line)
if skip_mode and len(buffer) > max_buffer_size:
buffer = b""
lines = (buffer + data).split(b"\n")
lines = (buffer + data).split(b"\n")
# Process complete lines (except the last possibly incomplete fragment)
for i in range(len(lines) - 1):
line = lines[i]
# Process complete lines (except the last possibly incomplete fragment)
for i in range(len(lines) - 1):
line = lines[i]
if skip_mode:
# Skip mode: check if current line is small enough to exit skip mode
if len(line) <= max_buffer_size:
skip_mode = False
yield line
else:
yield b"data: {}"
yield b"\n"
if skip_mode:
# Skip mode: check if current line is small enough to exit skip mode
if len(line) <= max_buffer_size:
skip_mode = False
yield line
else:
# Normal mode: check if line exceeds limit
if len(line) > max_buffer_size:
skip_mode = True
yield b"data: {}"
yield b"\n"
log.info(f"Skip mode triggered, line size: {len(line)}")
else:
yield line
yield b"\n"
yield b"data: {}"
yield b"\n"
else:
# Normal mode: check if line exceeds limit
if len(line) > max_buffer_size:
skip_mode = True
yield b"data: {}"
yield b"\n"
log.info(f"Skip mode triggered, line size: {len(line)}")
else:
yield line
yield b"\n"
# Save the last incomplete fragment
buffer = lines[-1]
# Save the last incomplete fragment
buffer = lines[-1]
# Check if buffer exceeds limit
if not skip_mode and len(buffer) > max_buffer_size:
skip_mode = True
log.info(f"Skip mode triggered, buffer size: {len(buffer)}")
# Clear oversized buffer to prevent unlimited growth
buffer = b""
# Check if buffer exceeds limit
if not skip_mode and len(buffer) > max_buffer_size:
skip_mode = True
log.info(f"Skip mode triggered, buffer size: {len(buffer)}")
# Clear oversized buffer to prevent unlimited growth
buffer = b""
# Process remaining buffer data
if buffer and not skip_mode:
credit_deduct.run(response=buffer)
yield buffer
yield b"\n"
yield credit_deduct.usage_message
# Process remaining buffer data
if buffer and not skip_mode:
yield buffer
yield b"\n"
return yield_safe_stream_chunks()

View File

@@ -1,7 +1,6 @@
import json
from uuid import uuid4
from open_webui.utils.credit.usage import CreditDeduct
from open_webui.utils.misc import (
openai_chat_chunk_message_template,
openai_chat_completion_message_template,
@@ -104,13 +103,7 @@ def convert_response_ollama_to_openai(ollama_response: dict) -> dict:
async def convert_streaming_response_ollama_to_openai(
user, model_id, form_data, ollama_streaming_response
):
with CreditDeduct(
user=user,
model_id=model_id,
body=form_data,
is_stream=True,
) as credit_deduct:
async for data in ollama_streaming_response.body_iterator:
async for data in ollama_streaming_response.body_iterator:
data = json.loads(data)
model = data.get("model", "ollama")
@@ -133,11 +126,8 @@ async def convert_streaming_response_ollama_to_openai(
)
line = f"data: {json.dumps(data)}\n\n"
credit_deduct.run(line)
yield line
yield credit_deduct.usage_message
yield "data: [DONE]\n\n"

View File

@@ -10,7 +10,7 @@ from open_webui.models.subscriptions import (
)
async def check_subscription_access(
def check_subscription_access(
user_id: str,
model_id: str,
chat_id: Optional[str] = None,
@@ -21,8 +21,8 @@ async def check_subscription_access(
Raises HTTPException if:
1. User's subscription has expired (auto-downgrades to Free)
2. Monthly message limit reached
3. Model is not in allowed list
2. Model usage limit reached for this billing period
3. Model is not allowed (limit = 0)
"""
now = int(time.time())
@@ -47,19 +47,23 @@ async def check_subscription_access(
detail="Subscription plan not found"
)
# check model access
if plan.allowed_models and model_id not in plan.allowed_models:
# get model limit for this plan
model_limit = plan.get_model_limit(model_id)
# check if model is not allowed (limit = 0)
if model_limit == 0:
raise HTTPException(
status_code=403,
detail=f"Your {plan.name} plan does not include access to this model. Please upgrade your subscription."
detail=f"您的 {plan.name} 套餐不支持使用此模型。请升级套餐以获取访问权限。"
)
# check usage limit
if plan.monthly_message_limit is not None:
if subscription.messages_used >= plan.monthly_message_limit:
# check usage limit (skip if unlimited = -1)
if model_limit > 0:
current_usage = subscription.get_model_usage(model_id)
if current_usage >= model_limit:
raise HTTPException(
status_code=403,
detail=f"You have reached your monthly message limit ({plan.monthly_message_limit}). Please wait until next month or upgrade your subscription."
detail=f"您本月已使用 {current_usage}/{model_limit} 次此模型。请等待下月刷新或升级套餐。"
)
@@ -77,8 +81,11 @@ def record_usage(
if not subscription:
return
# increment usage counter
UserSubscriptions.increment_usage(subscription.id)
# increment per-model usage counter
if model_id:
UserSubscriptions.increment_model_usage(subscription.id, model_id)
else:
UserSubscriptions.increment_usage(subscription.id)
# log usage
SubscriptionUsageLogs.insert(
@@ -91,27 +98,100 @@ def record_usage(
)
def get_user_model_limit(user_id: str, model_id: str) -> int:
"""
Get the usage limit for a specific model for a user.
Returns: -1 = unlimited, 0 = not allowed, positive = monthly limit
"""
subscription = UserSubscriptions.get_by_user_id(user_id)
if not subscription:
default_plan = SubscriptionPlans.get_default_plan()
if default_plan:
return default_plan.get_model_limit(model_id)
return -1 # allow by default if no plan
plan = SubscriptionPlans.get_plan_by_id(subscription.plan_id)
if not plan:
return -1
return plan.get_model_limit(model_id)
def get_user_model_remaining(user_id: str, model_id: str) -> Optional[int]:
"""
Get the remaining usage count for a specific model for a user.
Returns: None if unlimited, 0 if not allowed or exhausted, positive if remaining
"""
subscription = UserSubscriptions.get_by_user_id(user_id)
if not subscription:
default_plan = SubscriptionPlans.get_default_plan()
if default_plan:
limit = default_plan.get_model_limit(model_id)
return None if limit == -1 else limit
return None
plan = SubscriptionPlans.get_plan_by_id(subscription.plan_id)
if not plan:
return None
limit = plan.get_model_limit(model_id)
if limit == -1:
return None
if limit == 0:
return 0
used = subscription.get_model_usage(model_id)
return max(0, limit - used)
def get_user_allowed_models(user_id: str) -> Optional[list[str]]:
"""
Deprecated: Use get_user_model_limit instead.
Get the list of models allowed for a user based on their subscription.
Returns None if all models are allowed.
"""
subscription = UserSubscriptions.get_by_user_id(user_id)
if not subscription:
default_plan = SubscriptionPlans.get_default_plan()
return default_plan.allowed_models if default_plan else None
if default_plan and default_plan.model_limits:
# return models with limit != 0
return [m for m, l in default_plan.model_limits.items() if l != 0]
return None
plan = SubscriptionPlans.get_plan_by_id(subscription.plan_id)
return plan.allowed_models if plan else None
if plan and plan.model_limits:
return [m for m, l in plan.model_limits.items() if l != 0]
return None
def filter_models_by_subscription(user_id: str, models: list) -> list:
"""
Filter a list of models based on user's subscription.
Models with limit = 0 are filtered out.
"""
allowed_models = get_user_allowed_models(user_id)
subscription = UserSubscriptions.get_by_user_id(user_id)
plan = None
if allowed_models is None:
if subscription:
plan = SubscriptionPlans.get_plan_by_id(subscription.plan_id)
else:
plan = SubscriptionPlans.get_default_plan()
if not plan:
return models
return [m for m in models if m.get("id") in allowed_models]
# if no model_limits defined, use default_model_limit
if not plan.model_limits:
if plan.default_model_limit == 0:
return [] # no models allowed
return models # all models allowed
# filter out models with limit = 0
result = []
for m in models:
model_id = m.get("id", "")
limit = plan.get_model_limit(model_id)
if limit != 0:
result.append(m)
return result