feat:精简
Some checks failed
Create and publish Docker images with specific build args / build-main-image (linux/amd64, ubuntu-latest) (push) Has been cancelled
Create and publish Docker images with specific build args / build-main-image (linux/arm64, ubuntu-24.04-arm) (push) Has been cancelled
Create and publish Docker images with specific build args / build-cuda-image (linux/amd64, ubuntu-latest) (push) Has been cancelled
Create and publish Docker images with specific build args / build-cuda-image (linux/arm64, ubuntu-24.04-arm) (push) Has been cancelled
Create and publish Docker images with specific build args / build-cuda126-image (linux/amd64, ubuntu-latest) (push) Has been cancelled
Create and publish Docker images with specific build args / build-cuda126-image (linux/arm64, ubuntu-24.04-arm) (push) Has been cancelled
Create and publish Docker images with specific build args / build-ollama-image (linux/amd64, ubuntu-latest) (push) Has been cancelled
Create and publish Docker images with specific build args / build-ollama-image (linux/arm64, ubuntu-24.04-arm) (push) Has been cancelled
Create and publish Docker images with specific build args / build-slim-image (linux/amd64, ubuntu-latest) (push) Has been cancelled
Create and publish Docker images with specific build args / build-slim-image (linux/arm64, ubuntu-24.04-arm) (push) Has been cancelled
Python CI / Format Backend (3.11.x) (push) Has been cancelled
Python CI / Format Backend (3.12.x) (push) Has been cancelled
Frontend Build / Format & Build Frontend (push) Has been cancelled
Frontend Build / Frontend Unit Tests (push) Has been cancelled
Create and publish Docker images with specific build args / merge-main-images (push) Has been cancelled
Create and publish Docker images with specific build args / merge-cuda-images (push) Has been cancelled
Create and publish Docker images with specific build args / merge-cuda126-images (push) Has been cancelled
Create and publish Docker images with specific build args / merge-ollama-images (push) Has been cancelled
Create and publish Docker images with specific build args / merge-slim-images (push) Has been cancelled
Close inactive issues / close-issues (push) Has been cancelled
Some checks failed
Create and publish Docker images with specific build args / build-main-image (linux/amd64, ubuntu-latest) (push) Has been cancelled
Create and publish Docker images with specific build args / build-main-image (linux/arm64, ubuntu-24.04-arm) (push) Has been cancelled
Create and publish Docker images with specific build args / build-cuda-image (linux/amd64, ubuntu-latest) (push) Has been cancelled
Create and publish Docker images with specific build args / build-cuda-image (linux/arm64, ubuntu-24.04-arm) (push) Has been cancelled
Create and publish Docker images with specific build args / build-cuda126-image (linux/amd64, ubuntu-latest) (push) Has been cancelled
Create and publish Docker images with specific build args / build-cuda126-image (linux/arm64, ubuntu-24.04-arm) (push) Has been cancelled
Create and publish Docker images with specific build args / build-ollama-image (linux/amd64, ubuntu-latest) (push) Has been cancelled
Create and publish Docker images with specific build args / build-ollama-image (linux/arm64, ubuntu-24.04-arm) (push) Has been cancelled
Create and publish Docker images with specific build args / build-slim-image (linux/amd64, ubuntu-latest) (push) Has been cancelled
Create and publish Docker images with specific build args / build-slim-image (linux/arm64, ubuntu-24.04-arm) (push) Has been cancelled
Python CI / Format Backend (3.11.x) (push) Has been cancelled
Python CI / Format Backend (3.12.x) (push) Has been cancelled
Frontend Build / Format & Build Frontend (push) Has been cancelled
Frontend Build / Frontend Unit Tests (push) Has been cancelled
Create and publish Docker images with specific build args / merge-main-images (push) Has been cancelled
Create and publish Docker images with specific build args / merge-cuda-images (push) Has been cancelled
Create and publish Docker images with specific build args / merge-cuda126-images (push) Has been cancelled
Create and publish Docker images with specific build args / merge-ollama-images (push) Has been cancelled
Create and publish Docker images with specific build args / merge-slim-images (push) Has been cancelled
Close inactive issues / close-issues (push) Has been cancelled
This commit is contained in:
@@ -22,7 +22,7 @@ from typing import Optional, Union, List, Dict
|
||||
|
||||
from opentelemetry import trace
|
||||
|
||||
from open_webui.config import WEBUI_URL
|
||||
from open_webui.config import WEBUI_URL, EMAIL_VERIFY_TEMPLATE, EMAIL_CODE_TEMPLATE
|
||||
from open_webui.utils.smtp import send_email
|
||||
|
||||
from open_webui.utils.access_control import has_permission
|
||||
@@ -47,6 +47,7 @@ from open_webui.env import (
|
||||
REDIS_SENTINEL_PORT,
|
||||
WEBUI_NAME,
|
||||
REDIS_CLUSTER,
|
||||
BASE_DIR,
|
||||
)
|
||||
|
||||
from fastapi import BackgroundTasks, Depends, HTTPException, Request, Response, status
|
||||
@@ -81,76 +82,176 @@ def verify_signature(payload: str, signature: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def override_static(path: str, content: str):
|
||||
def override_static(path: str, content: str) -> bool:
|
||||
"""Download content from URL and save to path.
|
||||
|
||||
Returns True if successful, False otherwise.
|
||||
Does not raise exceptions to prevent app startup failure.
|
||||
"""
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||
|
||||
r = requests.get(content, stream=True)
|
||||
with open(path, "wb") as f:
|
||||
r.raw.decode_content = True
|
||||
shutil.copyfileobj(r.raw, f)
|
||||
try:
|
||||
# try with SSL verification first, then without if it fails
|
||||
try:
|
||||
r = requests.get(content, stream=True, timeout=30)
|
||||
except requests.exceptions.SSLError:
|
||||
log.warning(f"SSL error downloading {content}, retrying without verification")
|
||||
r = requests.get(content, stream=True, timeout=30, verify=False)
|
||||
|
||||
r.raise_for_status()
|
||||
with open(path, "wb") as f:
|
||||
r.raw.decode_content = True
|
||||
shutil.copyfileobj(r.raw, f)
|
||||
log.info(f"Successfully downloaded {content} to {path}")
|
||||
return True
|
||||
except Exception as e:
|
||||
log.error(f"Failed to download {content}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def apply_branding(app):
|
||||
"""Apply branding settings from database config or environment variables."""
|
||||
custom_png = ""
|
||||
custom_svg = ""
|
||||
custom_ico = ""
|
||||
custom_dark_png = ""
|
||||
organization_name = "OpenWebui"
|
||||
custom_name = ""
|
||||
|
||||
if hasattr(app, "state") and hasattr(app.state, "config"):
|
||||
# Get config values - must access _state directly to get fresh PersistentConfig values
|
||||
def get_config_value(attr_name, default=""):
|
||||
# Access _state directly to get the PersistentConfig object
|
||||
config_obj = app.state.config._state.get(attr_name)
|
||||
if config_obj is None:
|
||||
return default
|
||||
# If it's a PersistentConfig object, get its value
|
||||
if hasattr(config_obj, "value"):
|
||||
result = config_obj.value or default
|
||||
log.info(f"Branding: {attr_name}.value = {result}")
|
||||
return result
|
||||
return config_obj or default
|
||||
|
||||
custom_png = get_config_value("BRANDING_FAVICON_PNG") or os.getenv("CUSTOM_PNG", "")
|
||||
custom_svg = get_config_value("BRANDING_FAVICON_SVG") or os.getenv("CUSTOM_SVG", "")
|
||||
custom_ico = get_config_value("BRANDING_FAVICON_ICO") or os.getenv("CUSTOM_ICO", "")
|
||||
custom_dark_png = get_config_value("BRANDING_FAVICON_DARK_PNG") or os.getenv("CUSTOM_DARK_PNG", "")
|
||||
organization_name = get_config_value("BRANDING_ORGANIZATION_NAME") or os.getenv("ORGANIZATION_NAME", "OpenWebui")
|
||||
custom_name = get_config_value("BRANDING_CUSTOM_NAME") or os.getenv("CUSTOM_NAME", "")
|
||||
else:
|
||||
custom_png = os.getenv("CUSTOM_PNG", "")
|
||||
custom_svg = os.getenv("CUSTOM_SVG", "")
|
||||
custom_ico = os.getenv("CUSTOM_ICO", "")
|
||||
custom_dark_png = os.getenv("CUSTOM_DARK_PNG", "")
|
||||
organization_name = os.getenv("ORGANIZATION_NAME", "OpenWebui")
|
||||
custom_name = os.getenv("CUSTOM_NAME", "")
|
||||
|
||||
log.info(f"Branding: custom_name={custom_name}, custom_png={custom_png}")
|
||||
|
||||
# Frontend static/static directory (for dev mode)
|
||||
# SvelteKit maps static/ folder to root, so /static/favicon.png = static/static/favicon.png
|
||||
frontend_static_dir = BASE_DIR / "static" / "static"
|
||||
|
||||
# Build resources mapping for both backend and frontend static dirs
|
||||
resources = []
|
||||
|
||||
# files to override
|
||||
branding_files = [
|
||||
("logo.png", custom_png),
|
||||
("favicon.png", custom_png),
|
||||
("favicon.svg", custom_svg),
|
||||
("favicon-96x96.png", custom_png),
|
||||
("apple-touch-icon.png", custom_png),
|
||||
("web-app-manifest-192x192.png", custom_png),
|
||||
("web-app-manifest-512x512.png", custom_png),
|
||||
("splash.png", custom_png),
|
||||
("favicon.ico", custom_ico),
|
||||
("favicon-dark.png", custom_dark_png),
|
||||
("splash-dark.png", custom_dark_png),
|
||||
]
|
||||
|
||||
for filename, url in branding_files:
|
||||
if url:
|
||||
# Add backend static dir path
|
||||
resources.append((os.path.join(STATIC_DIR, filename), url))
|
||||
# Add frontend static/static dir path (for dev mode)
|
||||
resources.append((os.path.join(frontend_static_dir, filename), url))
|
||||
|
||||
try:
|
||||
for path, url in resources:
|
||||
if url:
|
||||
log.info(f"Branding: Downloading {url} to {path}")
|
||||
override_static(path, url)
|
||||
|
||||
# set metadata
|
||||
setattr(app.state, "LICENSE_METADATA", {
|
||||
"type": "enterprise",
|
||||
"organization_name": organization_name,
|
||||
})
|
||||
log.info(f"Branding: Set LICENSE_METADATA organization_name={organization_name}")
|
||||
|
||||
# set custom name if provided
|
||||
if custom_name:
|
||||
setattr(app.state, "WEBUI_NAME", custom_name)
|
||||
log.info(f"Branding: Set WEBUI_NAME={custom_name}")
|
||||
else:
|
||||
log.info(f"Branding: custom_name is empty, WEBUI_NAME not changed")
|
||||
|
||||
return True
|
||||
except Exception as ex:
|
||||
log.exception(f"Branding: Uncaught Exception: {ex}")
|
||||
return False
|
||||
|
||||
|
||||
def get_license_data(app, key):
|
||||
# get branding config from app.state.config if available, fallback to env vars
|
||||
custom_png = ""
|
||||
custom_svg = ""
|
||||
custom_ico = ""
|
||||
custom_dark_png = ""
|
||||
organization_name = "OpenWebui"
|
||||
|
||||
if hasattr(app, "state") and hasattr(app.state, "config"):
|
||||
custom_png = getattr(app.state.config, "BRANDING_FAVICON_PNG", "") or os.getenv("CUSTOM_PNG", "")
|
||||
custom_svg = getattr(app.state.config, "BRANDING_FAVICON_SVG", "") or os.getenv("CUSTOM_SVG", "")
|
||||
custom_ico = getattr(app.state.config, "BRANDING_FAVICON_ICO", "") or os.getenv("CUSTOM_ICO", "")
|
||||
custom_dark_png = getattr(app.state.config, "BRANDING_FAVICON_DARK_PNG", "") or os.getenv("CUSTOM_DARK_PNG", "")
|
||||
organization_name = getattr(app.state.config, "BRANDING_ORGANIZATION_NAME", "") or os.getenv("ORGANIZATION_NAME", "OpenWebui")
|
||||
else:
|
||||
custom_png = os.getenv("CUSTOM_PNG", "")
|
||||
custom_svg = os.getenv("CUSTOM_SVG", "")
|
||||
custom_ico = os.getenv("CUSTOM_ICO", "")
|
||||
custom_dark_png = os.getenv("CUSTOM_DARK_PNG", "")
|
||||
organization_name = os.getenv("ORGANIZATION_NAME", "OpenWebui")
|
||||
|
||||
payload = {
|
||||
"resources": {
|
||||
os.path.join(STATIC_DIR, "logo.png"): os.getenv("CUSTOM_PNG", ""),
|
||||
os.path.join(STATIC_DIR, "favicon.png"): os.getenv("CUSTOM_PNG", ""),
|
||||
os.path.join(STATIC_DIR, "favicon.svg"): os.getenv("CUSTOM_SVG", ""),
|
||||
os.path.join(STATIC_DIR, "favicon-96x96.png"): os.getenv("CUSTOM_PNG", ""),
|
||||
os.path.join(STATIC_DIR, "apple-touch-icon.png"): os.getenv(
|
||||
"CUSTOM_PNG", ""
|
||||
),
|
||||
os.path.join(STATIC_DIR, "web-app-manifest-192x192.png"): os.getenv(
|
||||
"CUSTOM_PNG", ""
|
||||
),
|
||||
os.path.join(STATIC_DIR, "web-app-manifest-512x512.png"): os.getenv(
|
||||
"CUSTOM_PNG", ""
|
||||
),
|
||||
os.path.join(STATIC_DIR, "splash.png"): os.getenv("CUSTOM_PNG", ""),
|
||||
os.path.join(STATIC_DIR, "favicon.ico"): os.getenv("CUSTOM_ICO", ""),
|
||||
os.path.join(STATIC_DIR, "favicon-dark.png"): os.getenv(
|
||||
"CUSTOM_DARK_PNG", ""
|
||||
),
|
||||
os.path.join(STATIC_DIR, "splash-dark.png"): os.getenv(
|
||||
"CUSTOM_DARK_PNG", ""
|
||||
),
|
||||
os.path.join(FRONTEND_BUILD_DIR, "favicon.png"): os.getenv(
|
||||
"CUSTOM_PNG", ""
|
||||
),
|
||||
os.path.join(FRONTEND_BUILD_DIR, "static/favicon.png"): os.getenv(
|
||||
"CUSTOM_PNG", ""
|
||||
),
|
||||
os.path.join(FRONTEND_BUILD_DIR, "static/favicon.svg"): os.getenv(
|
||||
"CUSTOM_SVG", ""
|
||||
),
|
||||
os.path.join(FRONTEND_BUILD_DIR, "static/favicon-96x96.png"): os.getenv(
|
||||
"CUSTOM_PNG", ""
|
||||
),
|
||||
os.path.join(FRONTEND_BUILD_DIR, "static/apple-touch-icon.png"): os.getenv(
|
||||
"CUSTOM_PNG", ""
|
||||
),
|
||||
os.path.join(
|
||||
FRONTEND_BUILD_DIR, "static/web-app-manifest-192x192.png"
|
||||
): os.getenv("CUSTOM_PNG", ""),
|
||||
os.path.join(
|
||||
FRONTEND_BUILD_DIR, "static/web-app-manifest-512x512.png"
|
||||
): os.getenv("CUSTOM_PNG", ""),
|
||||
os.path.join(FRONTEND_BUILD_DIR, "static/splash.png"): os.getenv(
|
||||
"CUSTOM_PNG", ""
|
||||
),
|
||||
os.path.join(FRONTEND_BUILD_DIR, "static/favicon.ico"): os.getenv(
|
||||
"CUSTOM_ICO", ""
|
||||
),
|
||||
os.path.join(FRONTEND_BUILD_DIR, "static/favicon-dark.png"): os.getenv(
|
||||
"CUSTOM_DARK_PNG", ""
|
||||
),
|
||||
os.path.join(FRONTEND_BUILD_DIR, "static/splash-dark.png"): os.getenv(
|
||||
"CUSTOM_DARK_PNG", ""
|
||||
),
|
||||
os.path.join(STATIC_DIR, "logo.png"): custom_png,
|
||||
os.path.join(STATIC_DIR, "favicon.png"): custom_png,
|
||||
os.path.join(STATIC_DIR, "favicon.svg"): custom_svg,
|
||||
os.path.join(STATIC_DIR, "favicon-96x96.png"): custom_png,
|
||||
os.path.join(STATIC_DIR, "apple-touch-icon.png"): custom_png,
|
||||
os.path.join(STATIC_DIR, "web-app-manifest-192x192.png"): custom_png,
|
||||
os.path.join(STATIC_DIR, "web-app-manifest-512x512.png"): custom_png,
|
||||
os.path.join(STATIC_DIR, "splash.png"): custom_png,
|
||||
os.path.join(STATIC_DIR, "favicon.ico"): custom_ico,
|
||||
os.path.join(STATIC_DIR, "favicon-dark.png"): custom_dark_png,
|
||||
os.path.join(STATIC_DIR, "splash-dark.png"): custom_dark_png,
|
||||
os.path.join(FRONTEND_BUILD_DIR, "favicon.png"): custom_png,
|
||||
os.path.join(FRONTEND_BUILD_DIR, "static/favicon.png"): custom_png,
|
||||
os.path.join(FRONTEND_BUILD_DIR, "static/favicon.svg"): custom_svg,
|
||||
os.path.join(FRONTEND_BUILD_DIR, "static/favicon-96x96.png"): custom_png,
|
||||
os.path.join(FRONTEND_BUILD_DIR, "static/apple-touch-icon.png"): custom_png,
|
||||
os.path.join(FRONTEND_BUILD_DIR, "static/web-app-manifest-192x192.png"): custom_png,
|
||||
os.path.join(FRONTEND_BUILD_DIR, "static/web-app-manifest-512x512.png"): custom_png,
|
||||
os.path.join(FRONTEND_BUILD_DIR, "static/splash.png"): custom_png,
|
||||
os.path.join(FRONTEND_BUILD_DIR, "static/favicon.ico"): custom_ico,
|
||||
os.path.join(FRONTEND_BUILD_DIR, "static/favicon-dark.png"): custom_dark_png,
|
||||
os.path.join(FRONTEND_BUILD_DIR, "static/splash-dark.png"): custom_dark_png,
|
||||
},
|
||||
"metadata": {
|
||||
"type": "enterprise",
|
||||
"organization_name": os.getenv("ORGANIZATION_NAME", "OpenWebui"),
|
||||
"organization_name": organization_name,
|
||||
},
|
||||
}
|
||||
try:
|
||||
@@ -611,11 +712,17 @@ def send_verify_email(email: str):
|
||||
code = f"{uuid.uuid4().hex}{uuid.uuid1().hex}"
|
||||
redis.set(name=get_email_code_key(code=code), value=email, ex=timedelta(days=1))
|
||||
link = f"{WEBUI_URL.value.rstrip('/')}/api/v1/auths/signup_verify/{code}"
|
||||
|
||||
# use template from config
|
||||
template = EMAIL_VERIFY_TEMPLATE.value or verify_email_template
|
||||
|
||||
# use replace instead of % formatting to avoid issues with % in CSS
|
||||
body = template.replace("%(title)s", f"{WEBUI_NAME} Email Verify").replace("%(link)s", link)
|
||||
|
||||
send_email(
|
||||
receiver=email,
|
||||
subject=f"{WEBUI_NAME} Email Verify",
|
||||
body=verify_email_template
|
||||
% {"title": f"{WEBUI_NAME} Email Verify", "link": link},
|
||||
body=body,
|
||||
)
|
||||
|
||||
|
||||
@@ -628,3 +735,78 @@ def verify_email_by_code(code: str) -> str:
|
||||
redis_cluster=REDIS_CLUSTER,
|
||||
)
|
||||
return redis.get(name=get_email_code_key(code=code))
|
||||
|
||||
|
||||
# email verification code template
|
||||
email_code_template = """<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<style>
|
||||
body { font-family: Arial, sans-serif; background-color: #f5f5f5; margin: 0; padding: 20px; }
|
||||
.container { max-width: 500px; margin: 0 auto; background: white; border-radius: 10px; padding: 30px; }
|
||||
.code { font-size: 32px; font-weight: bold; color: #3b82f6; letter-spacing: 8px; text-align: center; padding: 20px; background: #f0f9ff; border-radius: 8px; margin: 20px 0; }
|
||||
.note { color: #666; font-size: 14px; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h2>%(title)s</h2>
|
||||
<p>您的验证码是:</p>
|
||||
<div class="code">%(code)s</div>
|
||||
<p class="note">验证码有效期为10分钟,请勿泄露给他人。</p>
|
||||
</div>
|
||||
</body>
|
||||
</html>"""
|
||||
|
||||
|
||||
def get_signup_code_key(email: str) -> str:
|
||||
return f"signup_code:{email}"
|
||||
|
||||
|
||||
def send_signup_email_code(email: str) -> str:
|
||||
"""Send a 6-digit verification code to email for signup."""
|
||||
import random
|
||||
|
||||
redis = get_redis_connection(
|
||||
redis_url=REDIS_URL,
|
||||
redis_sentinels=get_sentinels_from_env(
|
||||
REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT
|
||||
),
|
||||
redis_cluster=REDIS_CLUSTER,
|
||||
)
|
||||
|
||||
# generate 6-digit code
|
||||
code = str(random.randint(100000, 999999))
|
||||
|
||||
# store in redis with 10 min expiration
|
||||
redis.set(name=get_signup_code_key(email.lower()), value=code, ex=timedelta(minutes=10))
|
||||
|
||||
# use template from config
|
||||
template = EMAIL_CODE_TEMPLATE.value or email_code_template
|
||||
|
||||
send_email(
|
||||
receiver=email,
|
||||
subject=f"{WEBUI_NAME} 注册验证码",
|
||||
body=template % {"title": f"{WEBUI_NAME} 注册验证码", "code": code},
|
||||
)
|
||||
|
||||
return code
|
||||
|
||||
|
||||
def verify_signup_email_code(email: str, code: str) -> bool:
|
||||
"""Verify the signup email code."""
|
||||
redis = get_redis_connection(
|
||||
redis_url=REDIS_URL,
|
||||
redis_sentinels=get_sentinels_from_env(
|
||||
REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT
|
||||
),
|
||||
redis_cluster=REDIS_CLUSTER,
|
||||
)
|
||||
|
||||
stored_code = redis.get(name=get_signup_code_key(email.lower()))
|
||||
if stored_code and stored_code == code:
|
||||
# delete code after successful verification
|
||||
redis.delete(get_signup_code_key(email.lower()))
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -39,8 +39,6 @@ from open_webui.routers.pipelines import (
|
||||
|
||||
from open_webui.models.functions import Functions
|
||||
from open_webui.models.models import Models
|
||||
from open_webui.utils.credit.usage import CreditDeduct
|
||||
from open_webui.utils.credit.utils import check_credit_by_user_id
|
||||
|
||||
from open_webui.utils.plugin import (
|
||||
load_function_module_by_id,
|
||||
@@ -169,8 +167,6 @@ async def generate_chat_completion(
|
||||
user: Any,
|
||||
bypass_filter: bool = False,
|
||||
):
|
||||
check_credit_by_user_id(user_id=user.id, form_data=form_data)
|
||||
|
||||
log.debug(f"generate_chat_completion: {form_data}")
|
||||
if BYPASS_MODEL_ACCESS_CONTROL:
|
||||
bypass_filter = True
|
||||
@@ -283,15 +279,8 @@ async def generate_chat_completion(
|
||||
background=response.background,
|
||||
)
|
||||
else:
|
||||
with CreditDeduct(
|
||||
user=user,
|
||||
model_id=model_id,
|
||||
body=payload,
|
||||
is_stream=False,
|
||||
) as credit_deduct:
|
||||
response = convert_response_ollama_to_openai(response)
|
||||
credit_deduct.run(response)
|
||||
return credit_deduct.add_usage_to_resp(response)
|
||||
response = convert_response_ollama_to_openai(response)
|
||||
return response
|
||||
else:
|
||||
return await generate_openai_chat_completion(
|
||||
request=request,
|
||||
|
||||
@@ -1,2 +0,0 @@
|
||||
# credit module - deprecated, kept for backward compatibility
|
||||
# this module is replaced by the subscription system
|
||||
@@ -1,25 +0,0 @@
|
||||
# credit usage module - deprecated, kept for backward compatibility
|
||||
# this module is replaced by the subscription system
|
||||
|
||||
|
||||
class CreditDeduct:
|
||||
"""
|
||||
Deprecated credit deduction class.
|
||||
Kept for backward compatibility - does nothing.
|
||||
The subscription system now handles usage tracking.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
pass
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
pass
|
||||
@@ -1,18 +0,0 @@
|
||||
# credit utils module - deprecated, kept for backward compatibility
|
||||
# this module is replaced by the subscription system
|
||||
|
||||
|
||||
def is_free_request(*args, **kwargs) -> bool:
|
||||
"""
|
||||
Deprecated function.
|
||||
Returns True to allow all requests (subscription system handles access control).
|
||||
"""
|
||||
return True
|
||||
|
||||
|
||||
def check_credit_by_user_id(*args, **kwargs):
|
||||
"""
|
||||
Deprecated function.
|
||||
Does nothing - subscription system now handles access control.
|
||||
"""
|
||||
pass
|
||||
@@ -57,6 +57,7 @@ from open_webui.routers.pipelines import (
|
||||
from open_webui.routers.memories import query_memory, QueryMemoryForm
|
||||
|
||||
from open_webui.utils.webhook import post_webhook
|
||||
from open_webui.utils.subscription.check import record_usage
|
||||
from open_webui.utils.files import (
|
||||
convert_markdown_base64_images,
|
||||
get_file_url_from_base64,
|
||||
@@ -1964,6 +1965,14 @@ async def process_chat_response(
|
||||
},
|
||||
)
|
||||
|
||||
# record subscription usage on success
|
||||
record_usage(
|
||||
user_id=user.id,
|
||||
model_id=form_data.get("model"),
|
||||
chat_id=metadata["chat_id"],
|
||||
message_id=metadata["message_id"],
|
||||
)
|
||||
|
||||
# Send a webhook notification if the user is not active
|
||||
if not Users.is_user_active(user.id):
|
||||
webhook_url = Users.get_user_webhook_url_by_id(user.id)
|
||||
@@ -3269,6 +3278,14 @@ async def process_chat_response(
|
||||
},
|
||||
)
|
||||
|
||||
# record subscription usage on success
|
||||
record_usage(
|
||||
user_id=user.id,
|
||||
model_id=model_id,
|
||||
chat_id=metadata["chat_id"],
|
||||
message_id=metadata["message_id"],
|
||||
)
|
||||
|
||||
# Send a webhook notification if the user is not active
|
||||
if not Users.is_user_active(user.id):
|
||||
webhook_url = Users.get_user_webhook_url_by_id(user.id)
|
||||
|
||||
@@ -665,24 +665,12 @@ def stream_chunks_handler(
|
||||
:return: An async generator that yields the stream data.
|
||||
"""
|
||||
|
||||
from open_webui.utils.credit.usage import CreditDeduct
|
||||
|
||||
max_buffer_size = CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE
|
||||
if max_buffer_size is None or max_buffer_size <= 0:
|
||||
|
||||
async def consumer_content(stream: aiohttp.StreamReader):
|
||||
with CreditDeduct(
|
||||
user=user,
|
||||
model_id=model_id,
|
||||
body=form_data,
|
||||
is_stream=True,
|
||||
) as credit_deduct:
|
||||
# change to avoid multi \n\n cause message lose
|
||||
async for chunk in stream:
|
||||
credit_deduct.run(response=chunk)
|
||||
yield chunk
|
||||
|
||||
yield credit_deduct.usage_message
|
||||
async for chunk in stream:
|
||||
yield chunk
|
||||
|
||||
return consumer_content(stream)
|
||||
|
||||
@@ -690,65 +678,53 @@ def stream_chunks_handler(
|
||||
buffer = b""
|
||||
skip_mode = False
|
||||
|
||||
with CreditDeduct(
|
||||
user=user,
|
||||
model_id=model_id,
|
||||
body=form_data,
|
||||
is_stream=True,
|
||||
) as credit_deduct:
|
||||
# change to avoid multi \n\n cause message lose
|
||||
async for data in stream:
|
||||
async for data in stream:
|
||||
|
||||
if not data:
|
||||
continue
|
||||
if not data:
|
||||
continue
|
||||
|
||||
credit_deduct.run(response=data)
|
||||
# In skip_mode, if buffer already exceeds the limit, clear it (it's part of an oversized line)
|
||||
if skip_mode and len(buffer) > max_buffer_size:
|
||||
buffer = b""
|
||||
|
||||
# In skip_mode, if buffer already exceeds the limit, clear it (it's part of an oversized line)
|
||||
if skip_mode and len(buffer) > max_buffer_size:
|
||||
buffer = b""
|
||||
lines = (buffer + data).split(b"\n")
|
||||
|
||||
lines = (buffer + data).split(b"\n")
|
||||
# Process complete lines (except the last possibly incomplete fragment)
|
||||
for i in range(len(lines) - 1):
|
||||
line = lines[i]
|
||||
|
||||
# Process complete lines (except the last possibly incomplete fragment)
|
||||
for i in range(len(lines) - 1):
|
||||
line = lines[i]
|
||||
|
||||
if skip_mode:
|
||||
# Skip mode: check if current line is small enough to exit skip mode
|
||||
if len(line) <= max_buffer_size:
|
||||
skip_mode = False
|
||||
yield line
|
||||
else:
|
||||
yield b"data: {}"
|
||||
yield b"\n"
|
||||
if skip_mode:
|
||||
# Skip mode: check if current line is small enough to exit skip mode
|
||||
if len(line) <= max_buffer_size:
|
||||
skip_mode = False
|
||||
yield line
|
||||
else:
|
||||
# Normal mode: check if line exceeds limit
|
||||
if len(line) > max_buffer_size:
|
||||
skip_mode = True
|
||||
yield b"data: {}"
|
||||
yield b"\n"
|
||||
log.info(f"Skip mode triggered, line size: {len(line)}")
|
||||
else:
|
||||
yield line
|
||||
yield b"\n"
|
||||
yield b"data: {}"
|
||||
yield b"\n"
|
||||
else:
|
||||
# Normal mode: check if line exceeds limit
|
||||
if len(line) > max_buffer_size:
|
||||
skip_mode = True
|
||||
yield b"data: {}"
|
||||
yield b"\n"
|
||||
log.info(f"Skip mode triggered, line size: {len(line)}")
|
||||
else:
|
||||
yield line
|
||||
yield b"\n"
|
||||
|
||||
# Save the last incomplete fragment
|
||||
buffer = lines[-1]
|
||||
# Save the last incomplete fragment
|
||||
buffer = lines[-1]
|
||||
|
||||
# Check if buffer exceeds limit
|
||||
if not skip_mode and len(buffer) > max_buffer_size:
|
||||
skip_mode = True
|
||||
log.info(f"Skip mode triggered, buffer size: {len(buffer)}")
|
||||
# Clear oversized buffer to prevent unlimited growth
|
||||
buffer = b""
|
||||
# Check if buffer exceeds limit
|
||||
if not skip_mode and len(buffer) > max_buffer_size:
|
||||
skip_mode = True
|
||||
log.info(f"Skip mode triggered, buffer size: {len(buffer)}")
|
||||
# Clear oversized buffer to prevent unlimited growth
|
||||
buffer = b""
|
||||
|
||||
# Process remaining buffer data
|
||||
if buffer and not skip_mode:
|
||||
credit_deduct.run(response=buffer)
|
||||
yield buffer
|
||||
yield b"\n"
|
||||
|
||||
yield credit_deduct.usage_message
|
||||
# Process remaining buffer data
|
||||
if buffer and not skip_mode:
|
||||
yield buffer
|
||||
yield b"\n"
|
||||
|
||||
return yield_safe_stream_chunks()
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import json
|
||||
from uuid import uuid4
|
||||
|
||||
from open_webui.utils.credit.usage import CreditDeduct
|
||||
from open_webui.utils.misc import (
|
||||
openai_chat_chunk_message_template,
|
||||
openai_chat_completion_message_template,
|
||||
@@ -104,13 +103,7 @@ def convert_response_ollama_to_openai(ollama_response: dict) -> dict:
|
||||
async def convert_streaming_response_ollama_to_openai(
|
||||
user, model_id, form_data, ollama_streaming_response
|
||||
):
|
||||
with CreditDeduct(
|
||||
user=user,
|
||||
model_id=model_id,
|
||||
body=form_data,
|
||||
is_stream=True,
|
||||
) as credit_deduct:
|
||||
async for data in ollama_streaming_response.body_iterator:
|
||||
async for data in ollama_streaming_response.body_iterator:
|
||||
data = json.loads(data)
|
||||
|
||||
model = data.get("model", "ollama")
|
||||
@@ -133,11 +126,8 @@ async def convert_streaming_response_ollama_to_openai(
|
||||
)
|
||||
|
||||
line = f"data: {json.dumps(data)}\n\n"
|
||||
credit_deduct.run(line)
|
||||
yield line
|
||||
|
||||
yield credit_deduct.usage_message
|
||||
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ from open_webui.models.subscriptions import (
|
||||
)
|
||||
|
||||
|
||||
async def check_subscription_access(
|
||||
def check_subscription_access(
|
||||
user_id: str,
|
||||
model_id: str,
|
||||
chat_id: Optional[str] = None,
|
||||
@@ -21,8 +21,8 @@ async def check_subscription_access(
|
||||
|
||||
Raises HTTPException if:
|
||||
1. User's subscription has expired (auto-downgrades to Free)
|
||||
2. Monthly message limit reached
|
||||
3. Model is not in allowed list
|
||||
2. Model usage limit reached for this billing period
|
||||
3. Model is not allowed (limit = 0)
|
||||
"""
|
||||
now = int(time.time())
|
||||
|
||||
@@ -47,19 +47,23 @@ async def check_subscription_access(
|
||||
detail="Subscription plan not found"
|
||||
)
|
||||
|
||||
# check model access
|
||||
if plan.allowed_models and model_id not in plan.allowed_models:
|
||||
# get model limit for this plan
|
||||
model_limit = plan.get_model_limit(model_id)
|
||||
|
||||
# check if model is not allowed (limit = 0)
|
||||
if model_limit == 0:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"Your {plan.name} plan does not include access to this model. Please upgrade your subscription."
|
||||
detail=f"您的 {plan.name} 套餐不支持使用此模型。请升级套餐以获取访问权限。"
|
||||
)
|
||||
|
||||
# check usage limit
|
||||
if plan.monthly_message_limit is not None:
|
||||
if subscription.messages_used >= plan.monthly_message_limit:
|
||||
# check usage limit (skip if unlimited = -1)
|
||||
if model_limit > 0:
|
||||
current_usage = subscription.get_model_usage(model_id)
|
||||
if current_usage >= model_limit:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"You have reached your monthly message limit ({plan.monthly_message_limit}). Please wait until next month or upgrade your subscription."
|
||||
detail=f"您本月已使用 {current_usage}/{model_limit} 次此模型。请等待下月刷新或升级套餐。"
|
||||
)
|
||||
|
||||
|
||||
@@ -77,8 +81,11 @@ def record_usage(
|
||||
if not subscription:
|
||||
return
|
||||
|
||||
# increment usage counter
|
||||
UserSubscriptions.increment_usage(subscription.id)
|
||||
# increment per-model usage counter
|
||||
if model_id:
|
||||
UserSubscriptions.increment_model_usage(subscription.id, model_id)
|
||||
else:
|
||||
UserSubscriptions.increment_usage(subscription.id)
|
||||
|
||||
# log usage
|
||||
SubscriptionUsageLogs.insert(
|
||||
@@ -91,27 +98,100 @@ def record_usage(
|
||||
)
|
||||
|
||||
|
||||
def get_user_model_limit(user_id: str, model_id: str) -> int:
|
||||
"""
|
||||
Get the usage limit for a specific model for a user.
|
||||
Returns: -1 = unlimited, 0 = not allowed, positive = monthly limit
|
||||
"""
|
||||
subscription = UserSubscriptions.get_by_user_id(user_id)
|
||||
if not subscription:
|
||||
default_plan = SubscriptionPlans.get_default_plan()
|
||||
if default_plan:
|
||||
return default_plan.get_model_limit(model_id)
|
||||
return -1 # allow by default if no plan
|
||||
|
||||
plan = SubscriptionPlans.get_plan_by_id(subscription.plan_id)
|
||||
if not plan:
|
||||
return -1
|
||||
|
||||
return plan.get_model_limit(model_id)
|
||||
|
||||
|
||||
def get_user_model_remaining(user_id: str, model_id: str) -> Optional[int]:
|
||||
"""
|
||||
Get the remaining usage count for a specific model for a user.
|
||||
Returns: None if unlimited, 0 if not allowed or exhausted, positive if remaining
|
||||
"""
|
||||
subscription = UserSubscriptions.get_by_user_id(user_id)
|
||||
if not subscription:
|
||||
default_plan = SubscriptionPlans.get_default_plan()
|
||||
if default_plan:
|
||||
limit = default_plan.get_model_limit(model_id)
|
||||
return None if limit == -1 else limit
|
||||
return None
|
||||
|
||||
plan = SubscriptionPlans.get_plan_by_id(subscription.plan_id)
|
||||
if not plan:
|
||||
return None
|
||||
|
||||
limit = plan.get_model_limit(model_id)
|
||||
if limit == -1:
|
||||
return None
|
||||
if limit == 0:
|
||||
return 0
|
||||
|
||||
used = subscription.get_model_usage(model_id)
|
||||
return max(0, limit - used)
|
||||
|
||||
|
||||
def get_user_allowed_models(user_id: str) -> Optional[list[str]]:
|
||||
"""
|
||||
Deprecated: Use get_user_model_limit instead.
|
||||
Get the list of models allowed for a user based on their subscription.
|
||||
Returns None if all models are allowed.
|
||||
"""
|
||||
subscription = UserSubscriptions.get_by_user_id(user_id)
|
||||
if not subscription:
|
||||
default_plan = SubscriptionPlans.get_default_plan()
|
||||
return default_plan.allowed_models if default_plan else None
|
||||
if default_plan and default_plan.model_limits:
|
||||
# return models with limit != 0
|
||||
return [m for m, l in default_plan.model_limits.items() if l != 0]
|
||||
return None
|
||||
|
||||
plan = SubscriptionPlans.get_plan_by_id(subscription.plan_id)
|
||||
return plan.allowed_models if plan else None
|
||||
if plan and plan.model_limits:
|
||||
return [m for m, l in plan.model_limits.items() if l != 0]
|
||||
return None
|
||||
|
||||
|
||||
def filter_models_by_subscription(user_id: str, models: list) -> list:
|
||||
"""
|
||||
Filter a list of models based on user's subscription.
|
||||
Models with limit = 0 are filtered out.
|
||||
"""
|
||||
allowed_models = get_user_allowed_models(user_id)
|
||||
subscription = UserSubscriptions.get_by_user_id(user_id)
|
||||
plan = None
|
||||
|
||||
if allowed_models is None:
|
||||
if subscription:
|
||||
plan = SubscriptionPlans.get_plan_by_id(subscription.plan_id)
|
||||
else:
|
||||
plan = SubscriptionPlans.get_default_plan()
|
||||
|
||||
if not plan:
|
||||
return models
|
||||
|
||||
return [m for m in models if m.get("id") in allowed_models]
|
||||
# if no model_limits defined, use default_model_limit
|
||||
if not plan.model_limits:
|
||||
if plan.default_model_limit == 0:
|
||||
return [] # no models allowed
|
||||
return models # all models allowed
|
||||
|
||||
# filter out models with limit = 0
|
||||
result = []
|
||||
for m in models:
|
||||
model_id = m.get("id", "")
|
||||
limit = plan.get_model_limit(model_id)
|
||||
if limit != 0:
|
||||
result.append(m)
|
||||
|
||||
return result
|
||||
|
||||
Reference in New Issue
Block a user