feat: add subscription plan system, remove the points-based credit system
Some checks failed
Create and publish Docker images with specific build args / build-main-image (linux/amd64, ubuntu-latest) (push) Has been cancelled
Create and publish Docker images with specific build args / build-main-image (linux/arm64, ubuntu-24.04-arm) (push) Has been cancelled
Create and publish Docker images with specific build args / build-cuda-image (linux/amd64, ubuntu-latest) (push) Has been cancelled
Create and publish Docker images with specific build args / build-cuda-image (linux/arm64, ubuntu-24.04-arm) (push) Has been cancelled
Create and publish Docker images with specific build args / build-cuda126-image (linux/amd64, ubuntu-latest) (push) Has been cancelled
Create and publish Docker images with specific build args / build-cuda126-image (linux/arm64, ubuntu-24.04-arm) (push) Has been cancelled
Create and publish Docker images with specific build args / build-ollama-image (linux/amd64, ubuntu-latest) (push) Has been cancelled
Create and publish Docker images with specific build args / build-ollama-image (linux/arm64, ubuntu-24.04-arm) (push) Has been cancelled
Create and publish Docker images with specific build args / build-slim-image (linux/amd64, ubuntu-latest) (push) Has been cancelled
Create and publish Docker images with specific build args / build-slim-image (linux/arm64, ubuntu-24.04-arm) (push) Has been cancelled
Create and publish Docker images with specific build args / merge-main-images (push) Has been cancelled
Create and publish Docker images with specific build args / merge-cuda-images (push) Has been cancelled
Create and publish Docker images with specific build args / merge-cuda126-images (push) Has been cancelled
Create and publish Docker images with specific build args / merge-ollama-images (push) Has been cancelled
Create and publish Docker images with specific build args / merge-slim-images (push) Has been cancelled
Python CI / Format Backend (3.11.x) (push) Has been cancelled
Python CI / Format Backend (3.12.x) (push) Has been cancelled
Frontend Build / Format & Build Frontend (push) Has been cancelled
Frontend Build / Frontend Unit Tests (push) Has been cancelled
Close inactive issues / close-issues (push) Has been cancelled
backend/open_webui/retrieval/loaders/datalab_marker.py (new file, 278 lines)
@@ -0,0 +1,278 @@
import os
import time
import requests
import logging
import json
from typing import List, Optional
from langchain_core.documents import Document
from fastapi import HTTPException, status

log = logging.getLogger(__name__)


class DatalabMarkerLoader:
    def __init__(
        self,
        file_path: str,
        api_key: str,
        api_base_url: str,
        additional_config: Optional[str] = None,
        use_llm: bool = False,
        skip_cache: bool = False,
        force_ocr: bool = False,
        paginate: bool = False,
        strip_existing_ocr: bool = False,
        disable_image_extraction: bool = False,
        format_lines: bool = False,
        output_format: str = None,
    ):
        self.file_path = file_path
        self.api_key = api_key
        self.api_base_url = api_base_url
        self.additional_config = additional_config
        self.use_llm = use_llm
        self.skip_cache = skip_cache
        self.force_ocr = force_ocr
        self.paginate = paginate
        self.strip_existing_ocr = strip_existing_ocr
        self.disable_image_extraction = disable_image_extraction
        self.format_lines = format_lines
        self.output_format = output_format

    def _get_mime_type(self, filename: str) -> str:
        ext = filename.rsplit(".", 1)[-1].lower()
        mime_map = {
            "pdf": "application/pdf",
            "xls": "application/vnd.ms-excel",
            "xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
            "ods": "application/vnd.oasis.opendocument.spreadsheet",
            "doc": "application/msword",
            "docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
            "odt": "application/vnd.oasis.opendocument.text",
            "ppt": "application/vnd.ms-powerpoint",
            "pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation",
            "odp": "application/vnd.oasis.opendocument.presentation",
            "html": "text/html",
            "epub": "application/epub+zip",
            "png": "image/png",
            "jpeg": "image/jpeg",
            "jpg": "image/jpeg",
            "webp": "image/webp",
            "gif": "image/gif",
            "tiff": "image/tiff",
        }
        return mime_map.get(ext, "application/octet-stream")

    def check_marker_request_status(self, request_id: str) -> dict:
        url = f"{self.api_base_url}/{request_id}"
        headers = {"X-Api-Key": self.api_key}
        try:
            response = requests.get(url, headers=headers)
            response.raise_for_status()
            result = response.json()
            log.info(f"Marker API status check for request {request_id}: {result}")
            return result
        except requests.HTTPError as e:
            log.error(f"Error checking Marker request status: {e}")
            raise HTTPException(
                status.HTTP_502_BAD_GATEWAY,
                detail=f"Failed to check Marker request: {e}",
            )
        except ValueError as e:
            log.error(f"Invalid JSON checking Marker request: {e}")
            raise HTTPException(
                status.HTTP_502_BAD_GATEWAY, detail=f"Invalid JSON: {e}"
            )

    def load(self) -> List[Document]:
        filename = os.path.basename(self.file_path)
        mime_type = self._get_mime_type(filename)
        headers = {"X-Api-Key": self.api_key}

        form_data = {
            "use_llm": str(self.use_llm).lower(),
            "skip_cache": str(self.skip_cache).lower(),
            "force_ocr": str(self.force_ocr).lower(),
            "paginate": str(self.paginate).lower(),
            "strip_existing_ocr": str(self.strip_existing_ocr).lower(),
            "disable_image_extraction": str(self.disable_image_extraction).lower(),
            "format_lines": str(self.format_lines).lower(),
            "output_format": self.output_format,
        }

        if self.additional_config and self.additional_config.strip():
            form_data["additional_config"] = self.additional_config

        log.info(
            f"Datalab Marker POST request parameters: {{'filename': '{filename}', 'mime_type': '{mime_type}', **{form_data}}}"
        )

        try:
            with open(self.file_path, "rb") as f:
                files = {"file": (filename, f, mime_type)}
                response = requests.post(
                    f"{self.api_base_url}",
                    data=form_data,
                    files=files,
                    headers=headers,
                )
                response.raise_for_status()
                result = response.json()
        except FileNotFoundError:
            raise HTTPException(
                status.HTTP_404_NOT_FOUND, detail=f"File not found: {self.file_path}"
            )
        except requests.HTTPError as e:
            raise HTTPException(
                status.HTTP_400_BAD_REQUEST,
                detail=f"Datalab Marker request failed: {e}",
            )
        except ValueError as e:
            raise HTTPException(
                status.HTTP_502_BAD_GATEWAY, detail=f"Invalid JSON response: {e}"
            )
        except Exception as e:
            raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))

        if not result.get("success"):
            raise HTTPException(
                status.HTTP_400_BAD_REQUEST,
                detail=f"Datalab Marker request failed: {result.get('error', 'Unknown error')}",
            )

        check_url = result.get("request_check_url")
        request_id = result.get("request_id")

        # Check if this is a direct response (self-hosted) or polling response (DataLab)
        if check_url:
            # DataLab polling pattern
            for _ in range(300):  # Up to 10 minutes
                time.sleep(2)
                try:
                    poll_response = requests.get(check_url, headers=headers)
                    poll_response.raise_for_status()
                    poll_result = poll_response.json()
                except (requests.HTTPError, ValueError) as e:
                    raw_body = poll_response.text
                    log.error(f"Polling error: {e}, response body: {raw_body}")
                    raise HTTPException(
                        status.HTTP_502_BAD_GATEWAY, detail=f"Polling failed: {e}"
                    )

                status_val = poll_result.get("status")
                success_val = poll_result.get("success")

                if status_val == "complete":
                    summary = {
                        k: poll_result.get(k)
                        for k in (
                            "status",
                            "output_format",
                            "success",
                            "error",
                            "page_count",
                            "total_cost",
                        )
                    }
                    log.info(
                        f"Marker processing completed successfully: {json.dumps(summary, indent=2)}"
                    )
                    break

                if status_val == "failed" or success_val is False:
                    log.error(
                        f"Marker poll failed full response: {json.dumps(poll_result, indent=2)}"
                    )
                    error_msg = (
                        poll_result.get("error")
                        or "Marker returned failure without error message"
                    )
                    raise HTTPException(
                        status.HTTP_400_BAD_REQUEST,
                        detail=f"Marker processing failed: {error_msg}",
                    )
            else:
                raise HTTPException(
                    status.HTTP_504_GATEWAY_TIMEOUT,
                    detail="Marker processing timed out",
                )

            if not poll_result.get("success", False):
                error_msg = poll_result.get("error") or "Unknown processing error"
                raise HTTPException(
                    status.HTTP_400_BAD_REQUEST,
                    detail=f"Final processing failed: {error_msg}",
                )

            # DataLab format - content in format-specific fields
            content_key = self.output_format.lower()
            raw_content = poll_result.get(content_key)
            final_result = poll_result
        else:
            # Self-hosted direct response - content in "output" field
            if "output" in result:
                log.info("Self-hosted Marker returned direct response without polling")
                raw_content = result.get("output")
                final_result = result
            else:
                available_fields = (
                    list(result.keys())
                    if isinstance(result, dict)
                    else "non-dict response"
                )
                raise HTTPException(
                    status.HTTP_502_BAD_GATEWAY,
                    detail=f"Custom Marker endpoint returned success but no 'output' field found. Available fields: {available_fields}. Expected either 'request_check_url' for polling or 'output' field for direct response.",
                )

        if self.output_format.lower() == "json":
            full_text = json.dumps(raw_content, indent=2)
        elif self.output_format.lower() in {"markdown", "html"}:
            full_text = str(raw_content).strip()
        else:
            raise HTTPException(
                status.HTTP_400_BAD_REQUEST,
                detail=f"Unsupported output format: {self.output_format}",
            )

        if not full_text:
            raise HTTPException(
                status.HTTP_400_BAD_REQUEST,
                detail="Marker returned empty content",
            )

        marker_output_dir = os.path.join("/app/backend/data/uploads", "marker_output")
        os.makedirs(marker_output_dir, exist_ok=True)

        file_ext_map = {"markdown": "md", "json": "json", "html": "html"}
        file_ext = file_ext_map.get(self.output_format.lower(), "txt")
        output_filename = f"{os.path.splitext(filename)[0]}.{file_ext}"
        output_path = os.path.join(marker_output_dir, output_filename)

        try:
            with open(output_path, "w", encoding="utf-8") as f:
                f.write(full_text)
            log.info(f"Saved Marker output to: {output_path}")
        except Exception as e:
            log.warning(f"Failed to write marker output to disk: {e}")

        metadata = {
            "source": filename,
            "output_format": final_result.get("output_format", self.output_format),
            "page_count": final_result.get("page_count", 0),
            "processed_with_llm": self.use_llm,
            "request_id": request_id or "",
        }

        images = final_result.get("images", {})
        if images:
            metadata["image_count"] = len(images)
            metadata["images"] = json.dumps(list(images.keys()))

        for k, v in metadata.items():
            if isinstance(v, (dict, list)):
                metadata[k] = json.dumps(v)
            elif v is None:
                metadata[k] = ""

        return [Document(page_content=full_text, metadata=metadata)]
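For orientation, a minimal usage sketch of the loader above. The file path, API key, and output format below are placeholder values chosen for illustration; they are not part of this commit, and a real output format must be one of "markdown", "json", or "html" since the default of None would fail at output_format.lower().

# Usage sketch for DatalabMarkerLoader (all values are placeholders).
from open_webui.retrieval.loaders.datalab_marker import DatalabMarkerLoader

loader = DatalabMarkerLoader(
    file_path="/tmp/example.pdf",                         # placeholder path
    api_key="YOUR_DATALAB_API_KEY",                       # placeholder key
    api_base_url="https://www.datalab.to/api/v1/marker",  # default used by Loader below
    output_format="markdown",
)
docs = loader.load()  # returns a single-element list of langchain Document
print(docs[0].metadata.get("page_count"), len(docs[0].page_content))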
backend/open_webui/retrieval/loaders/external_document.py (new file, 89 lines)
@@ -0,0 +1,89 @@
import requests
import logging, os
from typing import Iterator, List, Union
from urllib.parse import quote

from langchain_core.document_loaders import BaseLoader
from langchain_core.documents import Document
from open_webui.utils.headers import include_user_info_headers

log = logging.getLogger(__name__)


class ExternalDocumentLoader(BaseLoader):
    def __init__(
        self,
        file_path,
        url: str,
        api_key: str,
        mime_type=None,
        user=None,
        **kwargs,
    ) -> None:
        self.url = url
        self.api_key = api_key

        self.file_path = file_path
        self.mime_type = mime_type

        self.user = user

    def load(self) -> List[Document]:
        with open(self.file_path, "rb") as f:
            data = f.read()

        headers = {}
        if self.mime_type is not None:
            headers["Content-Type"] = self.mime_type

        if self.api_key is not None:
            headers["Authorization"] = f"Bearer {self.api_key}"

        try:
            headers["X-Filename"] = quote(os.path.basename(self.file_path))
        except:
            pass

        if self.user is not None:
            headers = include_user_info_headers(headers, self.user)

        url = self.url
        if url.endswith("/"):
            url = url[:-1]

        try:
            response = requests.put(f"{url}/process", data=data, headers=headers)
        except Exception as e:
            log.error(f"Error connecting to endpoint: {e}")
            raise Exception(f"Error connecting to endpoint: {e}")

        if response.ok:
            response_data = response.json()
            if response_data:
                if isinstance(response_data, dict):
                    return [
                        Document(
                            page_content=response_data.get("page_content"),
                            metadata=response_data.get("metadata"),
                        )
                    ]
                elif isinstance(response_data, list):
                    documents = []
                    for document in response_data:
                        documents.append(
                            Document(
                                page_content=document.get("page_content"),
                                metadata=document.get("metadata"),
                            )
                        )
                    return documents
                else:
                    raise Exception("Error loading document: Unable to parse content")
            else:
                raise Exception("Error loading document: No content returned")
        else:
            raise Exception(
                f"Error loading document: {response.status_code} {response.text}"
            )
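Roughly, this loader PUTs the raw file bytes to {url}/process and expects the external service to answer with a JSON object shaped like {"page_content": "...", "metadata": {...}}, or a list of such objects. A usage sketch, with a placeholder endpoint and key for illustration only:

# Usage sketch for ExternalDocumentLoader (endpoint and key are placeholders).
from open_webui.retrieval.loaders.external_document import ExternalDocumentLoader

loader = ExternalDocumentLoader(
    file_path="/tmp/example.pdf",          # placeholder path
    url="https://extractor.example.com",   # placeholder; loader calls PUT {url}/process
    api_key="YOUR_EXTRACTOR_KEY",          # sent as "Authorization: Bearer ..."
    mime_type="application/pdf",
)
docs = loader.load()  # one Document per object returned by the endpoint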
backend/open_webui/retrieval/loaders/external_web.py (new file, 51 lines)
@@ -0,0 +1,51 @@
import requests
import logging
from typing import Iterator, List, Union

from langchain_core.document_loaders import BaseLoader
from langchain_core.documents import Document

log = logging.getLogger(__name__)


class ExternalWebLoader(BaseLoader):
    def __init__(
        self,
        web_paths: Union[str, List[str]],
        external_url: str,
        external_api_key: str,
        continue_on_failure: bool = True,
        **kwargs,
    ) -> None:
        self.external_url = external_url
        self.external_api_key = external_api_key
        self.urls = web_paths if isinstance(web_paths, list) else [web_paths]
        self.continue_on_failure = continue_on_failure

    def lazy_load(self) -> Iterator[Document]:
        batch_size = 20
        for i in range(0, len(self.urls), batch_size):
            urls = self.urls[i : i + batch_size]
            try:
                response = requests.post(
                    self.external_url,
                    headers={
                        "User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) External Web Loader",
                        "Authorization": f"Bearer {self.external_api_key}",
                    },
                    json={
                        "urls": urls,
                    },
                )
                response.raise_for_status()
                results = response.json()
                for result in results:
                    yield Document(
                        page_content=result.get("page_content", ""),
                        metadata=result.get("metadata", {}),
                    )
            except Exception as e:
                if self.continue_on_failure:
                    log.error(f"Error extracting content from batch {urls}: {e}")
                else:
                    raise e
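A short usage sketch of the web loader above; the URLs, endpoint, and key are placeholders for illustration. The loader posts URLs to the external service in batches of 20 and yields one Document per returned item.

# Usage sketch for ExternalWebLoader (all values are placeholders).
from open_webui.retrieval.loaders.external_web import ExternalWebLoader

loader = ExternalWebLoader(
    web_paths=["https://example.com/a", "https://example.com/b"],
    external_url="https://webloader.example.com/extract",  # placeholder endpoint
    external_api_key="YOUR_WEB_LOADER_KEY",
)
for doc in loader.lazy_load():
    print(doc.metadata, len(doc.page_content))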
backend/open_webui/retrieval/loaders/main.py (new file, 403 lines)
@@ -0,0 +1,403 @@
import requests
import logging
import ftfy
import sys
import json

from azure.identity import DefaultAzureCredential
from langchain_community.document_loaders import (
    AzureAIDocumentIntelligenceLoader,
    BSHTMLLoader,
    CSVLoader,
    Docx2txtLoader,
    OutlookMessageLoader,
    PyPDFLoader,
    TextLoader,
    UnstructuredEPubLoader,
    UnstructuredExcelLoader,
    UnstructuredODTLoader,
    UnstructuredPowerPointLoader,
    UnstructuredRSTLoader,
    UnstructuredXMLLoader,
    YoutubeLoader,
)
from langchain_core.documents import Document

from open_webui.retrieval.loaders.external_document import ExternalDocumentLoader

from open_webui.retrieval.loaders.mistral import MistralLoader
from open_webui.retrieval.loaders.datalab_marker import DatalabMarkerLoader
from open_webui.retrieval.loaders.mineru import MinerULoader


from open_webui.env import GLOBAL_LOG_LEVEL

logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)

known_source_ext = [
    "go",
    "py",
    "java",
    "sh",
    "bat",
    "ps1",
    "cmd",
    "js",
    "ts",
    "css",
    "cpp",
    "hpp",
    "h",
    "c",
    "cs",
    "sql",
    "log",
    "ini",
    "pl",
    "pm",
    "r",
    "dart",
    "dockerfile",
    "env",
    "php",
    "hs",
    "hsc",
    "lua",
    "nginxconf",
    "conf",
    "m",
    "mm",
    "plsql",
    "perl",
    "rb",
    "rs",
    "db2",
    "scala",
    "bash",
    "swift",
    "vue",
    "svelte",
    "ex",
    "exs",
    "erl",
    "tsx",
    "jsx",
    "hs",
    "lhs",
    "json",
]


class TikaLoader:
    def __init__(self, url, file_path, mime_type=None, extract_images=None):
        self.url = url
        self.file_path = file_path
        self.mime_type = mime_type

        self.extract_images = extract_images

    def load(self) -> list[Document]:
        with open(self.file_path, "rb") as f:
            data = f.read()

        if self.mime_type is not None:
            headers = {"Content-Type": self.mime_type}
        else:
            headers = {}

        if self.extract_images == True:
            headers["X-Tika-PDFextractInlineImages"] = "true"

        endpoint = self.url
        if not endpoint.endswith("/"):
            endpoint += "/"
        endpoint += "tika/text"

        r = requests.put(endpoint, data=data, headers=headers)

        if r.ok:
            raw_metadata = r.json()
            text = raw_metadata.get("X-TIKA:content", "<No text content found>").strip()

            if "Content-Type" in raw_metadata:
                headers["Content-Type"] = raw_metadata["Content-Type"]

            log.debug("Tika extracted text: %s", text)

            return [Document(page_content=text, metadata=headers)]
        else:
            raise Exception(f"Error calling Tika: {r.reason}")


class DoclingLoader:
    def __init__(self, url, api_key=None, file_path=None, mime_type=None, params=None):
        self.url = url.rstrip("/")
        self.api_key = api_key
        self.file_path = file_path
        self.mime_type = mime_type

        self.params = params or {}

    def load(self) -> list[Document]:
        with open(self.file_path, "rb") as f:
            headers = {}
            if self.api_key:
                headers["X-Api-Key"] = f"Bearer {self.api_key}"

            r = requests.post(
                f"{self.url}/v1/convert/file",
                files={
                    "files": (
                        self.file_path,
                        f,
                        self.mime_type or "application/octet-stream",
                    )
                },
                data={
                    "image_export_mode": "placeholder",
                    **self.params,
                },
                headers=headers,
            )
        if r.ok:
            result = r.json()
            document_data = result.get("document", {})
            text = document_data.get("md_content", "<No text content found>")

            metadata = {"Content-Type": self.mime_type} if self.mime_type else {}

            log.debug("Docling extracted text: %s", text)
            return [Document(page_content=text, metadata=metadata)]
        else:
            error_msg = f"Error calling Docling API: {r.reason}"
            if r.text:
                try:
                    error_data = r.json()
                    if "detail" in error_data:
                        error_msg += f" - {error_data['detail']}"
                except Exception:
                    error_msg += f" - {r.text}"
            raise Exception(f"Error calling Docling: {error_msg}")


class Loader:
    def __init__(self, engine: str = "", **kwargs):
        self.engine = engine
        self.user = kwargs.get("user", None)
        self.kwargs = kwargs

    def load(
        self, filename: str, file_content_type: str, file_path: str
    ) -> list[Document]:
        loader = self._get_loader(filename, file_content_type, file_path)
        docs = loader.load()

        return [
            Document(
                page_content=ftfy.fix_text(doc.page_content), metadata=doc.metadata
            )
            for doc in docs
        ]

    def _is_text_file(self, file_ext: str, file_content_type: str) -> bool:
        return file_ext in known_source_ext or (
            file_content_type
            and file_content_type.find("text/") >= 0
            # Avoid text/html files being detected as text
            and not file_content_type.find("html") >= 0
        )

    def _get_loader(self, filename: str, file_content_type: str, file_path: str):
        file_ext = filename.split(".")[-1].lower()

        if (
            self.engine == "external"
            and self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_URL")
            and self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_API_KEY")
        ):
            loader = ExternalDocumentLoader(
                file_path=file_path,
                url=self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_URL"),
                api_key=self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_API_KEY"),
                mime_type=file_content_type,
                user=self.user,
            )
        elif self.engine == "tika" and self.kwargs.get("TIKA_SERVER_URL"):
            if self._is_text_file(file_ext, file_content_type):
                loader = TextLoader(file_path, autodetect_encoding=True)
            else:
                loader = TikaLoader(
                    url=self.kwargs.get("TIKA_SERVER_URL"),
                    file_path=file_path,
                    extract_images=self.kwargs.get("PDF_EXTRACT_IMAGES"),
                )
        elif (
            self.engine == "datalab_marker"
            and self.kwargs.get("DATALAB_MARKER_API_KEY")
            and file_ext
            in [
                "pdf",
                "xls",
                "xlsx",
                "ods",
                "doc",
                "docx",
                "odt",
                "ppt",
                "pptx",
                "odp",
                "html",
                "epub",
                "png",
                "jpeg",
                "jpg",
                "webp",
                "gif",
                "tiff",
            ]
        ):
            api_base_url = self.kwargs.get("DATALAB_MARKER_API_BASE_URL", "")
            if not api_base_url or api_base_url.strip() == "":
                api_base_url = "https://www.datalab.to/api/v1/marker"  # https://github.com/open-webui/open-webui/pull/16867#issuecomment-3218424349

            loader = DatalabMarkerLoader(
                file_path=file_path,
                api_key=self.kwargs["DATALAB_MARKER_API_KEY"],
                api_base_url=api_base_url,
                additional_config=self.kwargs.get("DATALAB_MARKER_ADDITIONAL_CONFIG"),
                use_llm=self.kwargs.get("DATALAB_MARKER_USE_LLM", False),
                skip_cache=self.kwargs.get("DATALAB_MARKER_SKIP_CACHE", False),
                force_ocr=self.kwargs.get("DATALAB_MARKER_FORCE_OCR", False),
                paginate=self.kwargs.get("DATALAB_MARKER_PAGINATE", False),
                strip_existing_ocr=self.kwargs.get(
                    "DATALAB_MARKER_STRIP_EXISTING_OCR", False
                ),
                disable_image_extraction=self.kwargs.get(
                    "DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION", False
                ),
                format_lines=self.kwargs.get("DATALAB_MARKER_FORMAT_LINES", False),
                output_format=self.kwargs.get(
                    "DATALAB_MARKER_OUTPUT_FORMAT", "markdown"
                ),
            )
        elif self.engine == "docling" and self.kwargs.get("DOCLING_SERVER_URL"):
            if self._is_text_file(file_ext, file_content_type):
                loader = TextLoader(file_path, autodetect_encoding=True)
            else:
                # Build params for DoclingLoader
                params = self.kwargs.get("DOCLING_PARAMS", {})
                if not isinstance(params, dict):
                    try:
                        params = json.loads(params)
                    except json.JSONDecodeError:
                        log.error("Invalid DOCLING_PARAMS format, expected JSON object")
                        params = {}

                loader = DoclingLoader(
                    url=self.kwargs.get("DOCLING_SERVER_URL"),
                    api_key=self.kwargs.get("DOCLING_API_KEY", None),
                    file_path=file_path,
                    mime_type=file_content_type,
                    params=params,
                )
        elif (
            self.engine == "document_intelligence"
            and self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT") != ""
            and (
                file_ext in ["pdf", "docx", "ppt", "pptx"]
                or file_content_type
                in [
                    "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
                    "application/vnd.ms-powerpoint",
                    "application/vnd.openxmlformats-officedocument.presentationml.presentation",
                ]
            )
        ):
            if self.kwargs.get("DOCUMENT_INTELLIGENCE_KEY") != "":
                loader = AzureAIDocumentIntelligenceLoader(
                    file_path=file_path,
                    api_endpoint=self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT"),
                    api_key=self.kwargs.get("DOCUMENT_INTELLIGENCE_KEY"),
                    api_model=self.kwargs.get("DOCUMENT_INTELLIGENCE_MODEL"),
                )
            else:
                loader = AzureAIDocumentIntelligenceLoader(
                    file_path=file_path,
                    api_endpoint=self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT"),
                    azure_credential=DefaultAzureCredential(),
                    api_model=self.kwargs.get("DOCUMENT_INTELLIGENCE_MODEL"),
                )
        elif self.engine == "mineru" and file_ext in [
            "pdf"
        ]:  # MinerU currently only supports PDF

            mineru_timeout = self.kwargs.get("MINERU_API_TIMEOUT", 300)
            if mineru_timeout:
                try:
                    mineru_timeout = int(mineru_timeout)
                except ValueError:
                    mineru_timeout = 300

            loader = MinerULoader(
                file_path=file_path,
                api_mode=self.kwargs.get("MINERU_API_MODE", "local"),
                api_url=self.kwargs.get("MINERU_API_URL", "http://localhost:8000"),
                api_key=self.kwargs.get("MINERU_API_KEY", ""),
                params=self.kwargs.get("MINERU_PARAMS", {}),
                timeout=mineru_timeout,
            )
        elif (
            self.engine == "mistral_ocr"
            and self.kwargs.get("MISTRAL_OCR_API_KEY") != ""
            and file_ext
            in ["pdf"]  # Mistral OCR currently only supports PDF and images
        ):
            loader = MistralLoader(
                base_url=self.kwargs.get("MISTRAL_OCR_API_BASE_URL"),
                api_key=self.kwargs.get("MISTRAL_OCR_API_KEY"),
                file_path=file_path,
            )
        else:
            if file_ext == "pdf":
                loader = PyPDFLoader(
                    file_path, extract_images=self.kwargs.get("PDF_EXTRACT_IMAGES")
                )
            elif file_ext == "csv":
                loader = CSVLoader(file_path, autodetect_encoding=True)
            elif file_ext == "rst":
                loader = UnstructuredRSTLoader(file_path, mode="elements")
            elif file_ext == "xml":
                loader = UnstructuredXMLLoader(file_path)
            elif file_ext in ["htm", "html"]:
                loader = BSHTMLLoader(file_path, open_encoding="unicode_escape")
            elif file_ext == "md":
                loader = TextLoader(file_path, autodetect_encoding=True)
            elif file_content_type == "application/epub+zip":
                loader = UnstructuredEPubLoader(file_path)
            elif (
                file_content_type
                == "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
                or file_ext == "docx"
            ):
                loader = Docx2txtLoader(file_path)
            elif file_content_type in [
                "application/vnd.ms-excel",
                "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
            ] or file_ext in ["xls", "xlsx"]:
                loader = UnstructuredExcelLoader(file_path)
            elif file_content_type in [
                "application/vnd.ms-powerpoint",
                "application/vnd.openxmlformats-officedocument.presentationml.presentation",
            ] or file_ext in ["ppt", "pptx"]:
                loader = UnstructuredPowerPointLoader(file_path)
            elif file_ext == "msg":
                loader = OutlookMessageLoader(file_path)
            elif file_ext == "odt":
                loader = UnstructuredODTLoader(file_path)
            elif self._is_text_file(file_ext, file_content_type):
                loader = TextLoader(file_path, autodetect_encoding=True)
            else:
                loader = TextLoader(file_path, autodetect_encoding=True)

        return loader
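A usage sketch of the Loader facade above, showing how the engine name and its settings are passed through kwargs; the Tika URL and file paths are placeholders, not values from this commit:

# Usage sketch for Loader (server URL and paths are placeholders).
from open_webui.retrieval.loaders.main import Loader

loader = Loader(
    engine="tika",                            # "", "external", "tika", "datalab_marker", ...
    TIKA_SERVER_URL="http://localhost:9998",  # placeholder Tika server
    PDF_EXTRACT_IMAGES=False,
)
docs = loader.load("report.pdf", "application/pdf", "/tmp/report.pdf")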
backend/open_webui/retrieval/loaders/mineru.py (new file, 524 lines)
@@ -0,0 +1,524 @@
import os
import time
import requests
import logging
import tempfile
import zipfile
from typing import List, Optional
from langchain_core.documents import Document
from fastapi import HTTPException, status

log = logging.getLogger(__name__)


class MinerULoader:
    """
    MinerU document parser loader supporting both Cloud API and Local API modes.

    Cloud API: Uses MinerU managed service with async task-based processing
    Local API: Uses self-hosted MinerU API with synchronous processing
    """

    def __init__(
        self,
        file_path: str,
        api_mode: str = "local",
        api_url: str = "http://localhost:8000",
        api_key: str = "",
        params: dict = None,
        timeout: Optional[int] = 300,
    ):
        self.file_path = file_path
        self.api_mode = api_mode.lower()
        self.api_url = api_url.rstrip("/")
        self.api_key = api_key
        self.timeout = timeout

        # Parse params dict with defaults
        self.params = params or {}
        self.enable_ocr = params.get("enable_ocr", False)
        self.enable_formula = params.get("enable_formula", True)
        self.enable_table = params.get("enable_table", True)
        self.language = params.get("language", "en")
        self.model_version = params.get("model_version", "pipeline")

        self.page_ranges = self.params.pop("page_ranges", "")

        # Validate API mode
        if self.api_mode not in ["local", "cloud"]:
            raise ValueError(
                f"Invalid API mode: {self.api_mode}. Must be 'local' or 'cloud'"
            )

        # Validate Cloud API requirements
        if self.api_mode == "cloud" and not self.api_key:
            raise ValueError("API key is required for Cloud API mode")

    def load(self) -> List[Document]:
        """
        Main entry point for loading and parsing the document.
        Routes to Cloud or Local API based on api_mode.
        """
        try:
            if self.api_mode == "cloud":
                return self._load_cloud_api()
            else:
                return self._load_local_api()
        except Exception as e:
            log.error(f"Error loading document with MinerU: {e}")
            raise

    def _load_local_api(self) -> List[Document]:
        """
        Load document using Local API (synchronous).
        Posts file to /file_parse endpoint and gets immediate response.
        """
        log.info(f"Using MinerU Local API at {self.api_url}")

        filename = os.path.basename(self.file_path)

        # Build form data for Local API
        form_data = {
            **self.params,
            "return_md": "true",
        }

        # Page ranges (Local API uses start_page_id and end_page_id)
        if self.page_ranges:
            # For simplicity, if page_ranges is specified, log a warning
            # Full page range parsing would require parsing the string
            log.warning(
                f"Page ranges '{self.page_ranges}' specified but Local API uses different format. "
                "Consider using start_page_id/end_page_id parameters if needed."
            )

        try:
            with open(self.file_path, "rb") as f:
                files = {"files": (filename, f, "application/octet-stream")}

                log.info(f"Sending file to MinerU Local API: {filename}")
                log.debug(f"Local API parameters: {form_data}")

                response = requests.post(
                    f"{self.api_url}/file_parse",
                    data=form_data,
                    files=files,
                    timeout=self.timeout,
                )
                response.raise_for_status()

        except FileNotFoundError:
            raise HTTPException(
                status.HTTP_404_NOT_FOUND, detail=f"File not found: {self.file_path}"
            )
        except requests.Timeout:
            raise HTTPException(
                status.HTTP_504_GATEWAY_TIMEOUT,
                detail="MinerU Local API request timed out",
            )
        except requests.HTTPError as e:
            error_detail = f"MinerU Local API request failed: {e}"
            if e.response is not None:
                try:
                    error_data = e.response.json()
                    error_detail += f" - {error_data}"
                except:
                    error_detail += f" - {e.response.text}"
            raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=error_detail)
        except Exception as e:
            raise HTTPException(
                status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"Error calling MinerU Local API: {str(e)}",
            )

        # Parse response
        try:
            result = response.json()
        except ValueError as e:
            raise HTTPException(
                status.HTTP_502_BAD_GATEWAY,
                detail=f"Invalid JSON response from MinerU Local API: {e}",
            )

        # Extract markdown content from response
        if "results" not in result:
            raise HTTPException(
                status.HTTP_502_BAD_GATEWAY,
                detail="MinerU Local API response missing 'results' field",
            )

        results = result["results"]
        if not results:
            raise HTTPException(
                status.HTTP_400_BAD_REQUEST,
                detail="MinerU returned empty results",
            )

        # Get the first (and typically only) result
        file_result = list(results.values())[0]
        markdown_content = file_result.get("md_content", "")

        if not markdown_content:
            raise HTTPException(
                status.HTTP_400_BAD_REQUEST,
                detail="MinerU returned empty markdown content",
            )

        log.info(f"Successfully parsed document with MinerU Local API: {filename}")

        # Create metadata
        metadata = {
            "source": filename,
            "api_mode": "local",
            "backend": result.get("backend", "unknown"),
            "version": result.get("version", "unknown"),
        }

        return [Document(page_content=markdown_content, metadata=metadata)]

    def _load_cloud_api(self) -> List[Document]:
        """
        Load document using Cloud API (asynchronous).
        Uses batch upload endpoint to avoid need for public file URLs.
        """
        log.info(f"Using MinerU Cloud API at {self.api_url}")

        filename = os.path.basename(self.file_path)

        # Step 1: Request presigned upload URL
        batch_id, upload_url = self._request_upload_url(filename)

        # Step 2: Upload file to presigned URL
        self._upload_to_presigned_url(upload_url)

        # Step 3: Poll for results
        result = self._poll_batch_status(batch_id, filename)

        # Step 4: Download and extract markdown from ZIP
        markdown_content = self._download_and_extract_zip(
            result["full_zip_url"], filename
        )

        log.info(f"Successfully parsed document with MinerU Cloud API: {filename}")

        # Create metadata
        metadata = {
            "source": filename,
            "api_mode": "cloud",
            "batch_id": batch_id,
        }

        return [Document(page_content=markdown_content, metadata=metadata)]

    def _request_upload_url(self, filename: str) -> tuple:
        """
        Request presigned upload URL from Cloud API.
        Returns (batch_id, upload_url).
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

        # Build request body
        request_body = {
            **self.params,
            "files": [
                {
                    "name": filename,
                    "is_ocr": self.enable_ocr,
                }
            ],
        }

        # Add page ranges if specified
        if self.page_ranges:
            request_body["files"][0]["page_ranges"] = self.page_ranges

        log.info(f"Requesting upload URL for: {filename}")
        log.debug(f"Cloud API request body: {request_body}")

        try:
            response = requests.post(
                f"{self.api_url}/file-urls/batch",
                headers=headers,
                json=request_body,
                timeout=30,
            )
            response.raise_for_status()
        except requests.HTTPError as e:
            error_detail = f"Failed to request upload URL: {e}"
            if e.response is not None:
                try:
                    error_data = e.response.json()
                    error_detail += f" - {error_data.get('msg', error_data)}"
                except:
                    error_detail += f" - {e.response.text}"
            raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=error_detail)
        except Exception as e:
            raise HTTPException(
                status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"Error requesting upload URL: {str(e)}",
            )

        try:
            result = response.json()
        except ValueError as e:
            raise HTTPException(
                status.HTTP_502_BAD_GATEWAY,
                detail=f"Invalid JSON response: {e}",
            )

        # Check for API error response
        if result.get("code") != 0:
            raise HTTPException(
                status.HTTP_400_BAD_REQUEST,
                detail=f"MinerU Cloud API error: {result.get('msg', 'Unknown error')}",
            )

        data = result.get("data", {})
        batch_id = data.get("batch_id")
        file_urls = data.get("file_urls", [])

        if not batch_id or not file_urls:
            raise HTTPException(
                status.HTTP_502_BAD_GATEWAY,
                detail="MinerU Cloud API response missing batch_id or file_urls",
            )

        upload_url = file_urls[0]
        log.info(f"Received upload URL for batch: {batch_id}")

        return batch_id, upload_url

    def _upload_to_presigned_url(self, upload_url: str) -> None:
        """
        Upload file to presigned URL (no authentication needed).
        """
        log.info(f"Uploading file to presigned URL")

        try:
            with open(self.file_path, "rb") as f:
                response = requests.put(
                    upload_url,
                    data=f,
                    timeout=self.timeout,
                )
                response.raise_for_status()
        except FileNotFoundError:
            raise HTTPException(
                status.HTTP_404_NOT_FOUND, detail=f"File not found: {self.file_path}"
            )
        except requests.Timeout:
            raise HTTPException(
                status.HTTP_504_GATEWAY_TIMEOUT,
                detail="File upload to presigned URL timed out",
            )
        except requests.HTTPError as e:
            raise HTTPException(
                status.HTTP_400_BAD_REQUEST,
                detail=f"Failed to upload file to presigned URL: {e}",
            )
        except Exception as e:
            raise HTTPException(
                status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"Error uploading file: {str(e)}",
            )

        log.info("File uploaded successfully")

    def _poll_batch_status(self, batch_id: str, filename: str) -> dict:
        """
        Poll batch status until completion.
        Returns the result dict for the file.
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
        }

        max_iterations = 300  # 10 minutes max (2 seconds per iteration)
        poll_interval = 2  # seconds

        log.info(f"Polling batch status: {batch_id}")

        for iteration in range(max_iterations):
            try:
                response = requests.get(
                    f"{self.api_url}/extract-results/batch/{batch_id}",
                    headers=headers,
                    timeout=30,
                )
                response.raise_for_status()
            except requests.HTTPError as e:
                error_detail = f"Failed to poll batch status: {e}"
                if e.response is not None:
                    try:
                        error_data = e.response.json()
                        error_detail += f" - {error_data.get('msg', error_data)}"
                    except:
                        error_detail += f" - {e.response.text}"
                raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=error_detail)
            except Exception as e:
                raise HTTPException(
                    status.HTTP_500_INTERNAL_SERVER_ERROR,
                    detail=f"Error polling batch status: {str(e)}",
                )

            try:
                result = response.json()
            except ValueError as e:
                raise HTTPException(
                    status.HTTP_502_BAD_GATEWAY,
                    detail=f"Invalid JSON response while polling: {e}",
                )

            # Check for API error response
            if result.get("code") != 0:
                raise HTTPException(
                    status.HTTP_400_BAD_REQUEST,
                    detail=f"MinerU Cloud API error: {result.get('msg', 'Unknown error')}",
                )

            data = result.get("data", {})
            extract_result = data.get("extract_result", [])

            # Find our file in the batch results
            file_result = None
            for item in extract_result:
                if item.get("file_name") == filename:
                    file_result = item
                    break

            if not file_result:
                raise HTTPException(
                    status.HTTP_502_BAD_GATEWAY,
                    detail=f"File {filename} not found in batch results",
                )

            state = file_result.get("state")

            if state == "done":
                log.info(f"Processing complete for {filename}")
                return file_result
            elif state == "failed":
                error_msg = file_result.get("err_msg", "Unknown error")
                raise HTTPException(
                    status.HTTP_400_BAD_REQUEST,
                    detail=f"MinerU processing failed: {error_msg}",
                )
            elif state in ["waiting-file", "pending", "running", "converting"]:
                # Still processing
                if iteration % 10 == 0:  # Log every 20 seconds
                    log.info(
                        f"Processing status: {state} (iteration {iteration + 1}/{max_iterations})"
                    )
                time.sleep(poll_interval)
            else:
                log.warning(f"Unknown state: {state}")
                time.sleep(poll_interval)

        # Timeout
        raise HTTPException(
            status.HTTP_504_GATEWAY_TIMEOUT,
            detail="MinerU processing timed out after 10 minutes",
        )

    def _download_and_extract_zip(self, zip_url: str, filename: str) -> str:
        """
        Download ZIP file from CDN and extract markdown content.
        Returns the markdown content as a string.
        """
        log.info(f"Downloading results from: {zip_url}")

        try:
            response = requests.get(zip_url, timeout=60)
            response.raise_for_status()
        except requests.HTTPError as e:
            raise HTTPException(
                status.HTTP_400_BAD_REQUEST,
                detail=f"Failed to download results ZIP: {e}",
            )
        except Exception as e:
            raise HTTPException(
                status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"Error downloading results: {str(e)}",
            )

        # Save ZIP to temporary file and extract
        try:
            with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp_zip:
                tmp_zip.write(response.content)
                tmp_zip_path = tmp_zip.name

            with tempfile.TemporaryDirectory() as tmp_dir:
                # Extract ZIP
                with zipfile.ZipFile(tmp_zip_path, "r") as zip_ref:
                    zip_ref.extractall(tmp_dir)

                # Find markdown file - search recursively for any .md file
                markdown_content = None
                found_md_path = None

                # First, list all files in the ZIP for debugging
                all_files = []
                for root, dirs, files in os.walk(tmp_dir):
                    for file in files:
                        full_path = os.path.join(root, file)
                        all_files.append(full_path)
                        # Look for any .md file
                        if file.endswith(".md"):
                            found_md_path = full_path
                            log.info(f"Found markdown file at: {full_path}")
                            try:
                                with open(full_path, "r", encoding="utf-8") as f:
                                    markdown_content = f.read()
                                    if (
                                        markdown_content
                                    ):  # Use the first non-empty markdown file
                                        break
                            except Exception as e:
                                log.warning(f"Failed to read {full_path}: {e}")
                    if markdown_content:
                        break

                if markdown_content is None:
                    log.error(f"Available files in ZIP: {all_files}")
                    # Try to provide more helpful error message
                    md_files = [f for f in all_files if f.endswith(".md")]
                    if md_files:
                        error_msg = (
                            f"Found .md files but couldn't read them: {md_files}"
                        )
                    else:
                        error_msg = (
                            f"No .md files found in ZIP. Available files: {all_files}"
                        )
                    raise HTTPException(
                        status.HTTP_502_BAD_GATEWAY,
                        detail=error_msg,
                    )

            # Clean up temporary ZIP file
            os.unlink(tmp_zip_path)

        except zipfile.BadZipFile as e:
            raise HTTPException(
                status.HTTP_502_BAD_GATEWAY,
                detail=f"Invalid ZIP file received: {e}",
            )
        except Exception as e:
            raise HTTPException(
                status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"Error extracting ZIP: {str(e)}",
            )

        if not markdown_content:
            raise HTTPException(
                status.HTTP_400_BAD_REQUEST,
                detail="Extracted markdown content is empty",
            )

        log.info(
            f"Successfully extracted markdown content ({len(markdown_content)} characters)"
        )
        return markdown_content
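A usage sketch of the MinerU loader above in both modes. All paths, URLs, and keys below are placeholders; the cloud base URL in particular is an assumption for illustration, not a value taken from this commit. Note that params must be a dict, since the constructor reads defaults from it directly.

# Usage sketch for MinerULoader (all values are placeholders).
from open_webui.retrieval.loaders.mineru import MinerULoader

# Local (self-hosted) mode: synchronous POST to {api_url}/file_parse
local_loader = MinerULoader(
    file_path="/tmp/example.pdf",           # placeholder path
    api_mode="local",
    api_url="http://localhost:8000",
    params={"language": "en"},
)

# Cloud mode: presigned upload, polling, then ZIP download; requires an API key
cloud_loader = MinerULoader(
    file_path="/tmp/example.pdf",
    api_mode="cloud",
    api_url="https://mineru.example/api",   # placeholder cloud base URL (assumption)
    api_key="YOUR_MINERU_KEY",
    params={"enable_ocr": True},
)

docs = local_loader.load()  # returns [Document(page_content=<markdown>, metadata=...)]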
backend/open_webui/retrieval/loaders/mistral.py (new file, 769 lines, diff truncated)
@@ -0,0 +1,769 @@
|
||||
import requests
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from typing import List, Dict, Any
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from open_webui.env import GLOBAL_LOG_LEVEL
|
||||
|
||||
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MistralLoader:
|
||||
"""
|
||||
Enhanced Mistral OCR loader with both sync and async support.
|
||||
Loads documents by processing them through the Mistral OCR API.
|
||||
|
||||
Performance Optimizations:
|
||||
- Differentiated timeouts for different operations
|
||||
- Intelligent retry logic with exponential backoff
|
||||
- Memory-efficient file streaming for large files
|
||||
- Connection pooling and keepalive optimization
|
||||
- Semaphore-based concurrency control for batch processing
|
||||
- Enhanced error handling with retryable error classification
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str,
|
||||
api_key: str,
|
||||
file_path: str,
|
||||
timeout: int = 300, # 5 minutes default
|
||||
max_retries: int = 3,
|
||||
enable_debug_logging: bool = False,
|
||||
):
|
||||
"""
|
||||
Initializes the loader with enhanced features.
|
||||
|
||||
Args:
|
||||
api_key: Your Mistral API key.
|
||||
file_path: The local path to the PDF file to process.
|
||||
timeout: Request timeout in seconds.
|
||||
max_retries: Maximum number of retry attempts.
|
||||
enable_debug_logging: Enable detailed debug logs.
|
||||
"""
|
||||
if not api_key:
|
||||
raise ValueError("API key cannot be empty.")
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"File not found at {file_path}")
|
||||
|
||||
self.base_url = (
|
||||
base_url.rstrip("/") if base_url else "https://api.mistral.ai/v1"
|
||||
)
|
||||
self.api_key = api_key
|
||||
self.file_path = file_path
|
||||
self.timeout = timeout
|
||||
self.max_retries = max_retries
|
||||
self.debug = enable_debug_logging
|
||||
|
||||
# PERFORMANCE OPTIMIZATION: Differentiated timeouts for different operations
|
||||
# This prevents long-running OCR operations from affecting quick operations
|
||||
# and improves user experience by failing fast on operations that should be quick
|
||||
self.upload_timeout = min(
|
||||
timeout, 120
|
||||
) # Cap upload at 2 minutes - prevents hanging on large files
|
||||
self.url_timeout = (
|
||||
30 # URL requests should be fast - fail quickly if API is slow
|
||||
)
|
||||
self.ocr_timeout = (
|
||||
timeout # OCR can take the full timeout - this is the heavy operation
|
||||
)
|
||||
self.cleanup_timeout = (
|
||||
30 # Cleanup should be quick - don't hang on file deletion
|
||||
)
|
||||
|
||||
# PERFORMANCE OPTIMIZATION: Pre-compute file info to avoid repeated filesystem calls
|
||||
# This avoids multiple os.path.basename() and os.path.getsize() calls during processing
|
||||
self.file_name = os.path.basename(file_path)
|
||||
self.file_size = os.path.getsize(file_path)
|
||||
|
||||
# ENHANCEMENT: Added User-Agent for better API tracking and debugging
|
||||
self.headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"User-Agent": "OpenWebUI-MistralLoader/2.0", # Helps API provider track usage
|
||||
}
|
||||
|
||||
def _debug_log(self, message: str, *args) -> None:
|
||||
"""
|
||||
PERFORMANCE OPTIMIZATION: Conditional debug logging for performance.
|
||||
|
||||
Only processes debug messages when debug mode is enabled, avoiding
|
||||
string formatting overhead in production environments.
|
||||
"""
|
||||
if self.debug:
|
||||
log.debug(message, *args)
|
||||
|
||||
def _handle_response(self, response: requests.Response) -> Dict[str, Any]:
|
||||
"""Checks response status and returns JSON content."""
|
||||
try:
|
||||
response.raise_for_status() # Raises HTTPError for bad responses (4xx or 5xx)
|
||||
# Handle potential empty responses for certain successful requests (e.g., DELETE)
|
||||
if response.status_code == 204 or not response.content:
|
||||
return {} # Return empty dict if no content
|
||||
return response.json()
|
||||
except requests.exceptions.HTTPError as http_err:
|
||||
log.error(f"HTTP error occurred: {http_err} - Response: {response.text}")
|
||||
raise
|
||||
except requests.exceptions.RequestException as req_err:
|
||||
log.error(f"Request exception occurred: {req_err}")
|
||||
raise
|
||||
except ValueError as json_err: # Includes JSONDecodeError
|
||||
log.error(f"JSON decode error: {json_err} - Response: {response.text}")
|
||||
raise # Re-raise after logging
|
||||
|
||||
async def _handle_response_async(
|
||||
self, response: aiohttp.ClientResponse
|
||||
) -> Dict[str, Any]:
|
||||
"""Async version of response handling with better error info."""
|
||||
try:
|
||||
response.raise_for_status()
|
||||
|
||||
# Check content type
|
||||
content_type = response.headers.get("content-type", "")
|
||||
if "application/json" not in content_type:
|
||||
if response.status == 204:
|
||||
return {}
|
||||
text = await response.text()
|
||||
raise ValueError(
|
||||
f"Unexpected content type: {content_type}, body: {text[:200]}..."
|
||||
)
|
||||
|
||||
return await response.json()
|
||||
|
||||
except aiohttp.ClientResponseError as e:
|
||||
error_text = await response.text() if response else "No response"
|
||||
log.error(f"HTTP {e.status}: {e.message} - Response: {error_text[:500]}")
|
||||
raise
|
||||
except aiohttp.ClientError as e:
|
||||
log.error(f"Client error: {e}")
|
||||
raise
|
||||
except Exception as e:
|
||||
log.error(f"Unexpected error processing response: {e}")
|
||||
raise
|
||||
|
||||
def _is_retryable_error(self, error: Exception) -> bool:
|
||||
"""
|
||||
ENHANCEMENT: Intelligent error classification for retry logic.
|
||||
|
||||
Determines if an error is retryable based on its type and status code.
|
||||
This prevents wasting time retrying errors that will never succeed
|
||||
(like authentication errors) while ensuring transient errors are retried.
|
||||
|
||||
Retryable errors:
|
||||
- Network connection errors (temporary network issues)
|
||||
- Timeouts (server might be temporarily overloaded)
|
||||
- Server errors (5xx status codes - server-side issues)
|
||||
- Rate limiting (429 status - temporary throttling)
|
||||
|
||||
Non-retryable errors:
|
||||
- Authentication errors (401, 403 - won't fix with retry)
|
||||
- Bad request errors (400 - malformed request)
|
||||
- Not found errors (404 - resource doesn't exist)
|
||||
"""
|
||||
if isinstance(error, requests.exceptions.ConnectionError):
|
||||
return True # Network issues are usually temporary
|
||||
if isinstance(error, requests.exceptions.Timeout):
|
||||
return True # Timeouts might resolve on retry
|
||||
if isinstance(error, requests.exceptions.HTTPError):
|
||||
# Only retry on server errors (5xx) or rate limits (429)
|
||||
if hasattr(error, "response") and error.response is not None:
|
||||
status_code = error.response.status_code
|
||||
return status_code >= 500 or status_code == 429
|
||||
return False
|
||||
if isinstance(
|
||||
error, (aiohttp.ClientConnectionError, aiohttp.ServerTimeoutError)
|
||||
):
|
||||
return True # Async network/timeout errors are retryable
|
||||
if isinstance(error, aiohttp.ClientResponseError):
|
||||
return error.status >= 500 or error.status == 429
|
||||
return False # All other errors are non-retryable
|
||||
|
||||

    def _retry_request_sync(self, request_func, *args, **kwargs):
        """
        ENHANCEMENT: Synchronous retry logic with intelligent error classification.

        Uses exponential backoff with a small fixed offset between attempts.
        The wait time increases exponentially but is capped at 30 seconds to
        prevent excessive delays. Only retries errors that are likely to succeed
        on subsequent attempts.
        """
        for attempt in range(self.max_retries):
            try:
                return request_func(*args, **kwargs)
            except Exception as e:
                if attempt == self.max_retries - 1 or not self._is_retryable_error(e):
                    raise

                # PERFORMANCE OPTIMIZATION: Exponential backoff with cap
                # Prevents overwhelming the server while ensuring reasonable retry delays
                wait_time = min((2**attempt) + 0.5, 30)  # Cap at 30 seconds
                log.warning(
                    f"Retryable error (attempt {attempt + 1}/{self.max_retries}): {e}. "
                    f"Retrying in {wait_time}s..."
                )
                time.sleep(wait_time)
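
For reference, the backoff formula above produces the following waits between attempts (a small worked example; five attempts assumed):

# min((2 ** attempt) + 0.5, 30) for attempt = 0..4
# -> 1.5s, 2.5s, 4.5s, 8.5s, 16.5s (and capped at 30s from attempt 5 onward)
for attempt in range(5):
    print(attempt, min((2**attempt) + 0.5, 30))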

    async def _retry_request_async(self, request_func, *args, **kwargs):
        """
        ENHANCEMENT: Async retry logic with intelligent error classification.

        Async version of retry logic that doesn't block the event loop during
        wait periods. Uses the same exponential backoff strategy as sync version.
        """
        for attempt in range(self.max_retries):
            try:
                return await request_func(*args, **kwargs)
            except Exception as e:
                if attempt == self.max_retries - 1 or not self._is_retryable_error(e):
                    raise

                # PERFORMANCE OPTIMIZATION: Non-blocking exponential backoff
                wait_time = min((2**attempt) + 0.5, 30)  # Cap at 30 seconds
                log.warning(
                    f"Retryable error (attempt {attempt + 1}/{self.max_retries}): {e}. "
                    f"Retrying in {wait_time}s..."
                )
                await asyncio.sleep(wait_time)  # Non-blocking wait
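
Both retry helpers take the target callable plus its arguments, so any request function can be wrapped without extra closures. A hedged sketch of a call site (the coroutine below is illustrative, not part of the module):

import aiohttp

async def fetch_json(session, url, headers, timeout):
    async with session.get(
        url, headers=headers, timeout=aiohttp.ClientTimeout(total=timeout)
    ) as resp:
        return await resp.json()

# Inside an async method of the loader it would be awaited through the retry wrapper:
#     data = await self._retry_request_async(fetch_json, session, url, self.headers, self.url_timeout)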

    def _upload_file(self) -> str:
        """
        PERFORMANCE OPTIMIZATION: Enhanced file upload with streaming consideration.

        Uploads the file to Mistral for OCR processing (sync version).
        Uses context manager for file handling to ensure proper resource cleanup.
        Although streaming is not enabled for this endpoint, the file is opened
        in a context manager to minimize memory usage duration.
        """
        log.info("Uploading file to Mistral API")
        url = f"{self.base_url}/files"

        def upload_request():
            # MEMORY OPTIMIZATION: Use context manager to minimize file handle lifetime
            # This ensures the file is closed immediately after reading, reducing memory usage
            with open(self.file_path, "rb") as f:
                files = {"file": (self.file_name, f, "application/pdf")}
                data = {"purpose": "ocr"}

                # NOTE: stream=False is required for this endpoint
                # The Mistral API doesn't support chunked uploads for this endpoint
                response = requests.post(
                    url,
                    headers=self.headers,
                    files=files,
                    data=data,
                    timeout=self.upload_timeout,  # Use specialized upload timeout
                    stream=False,  # Keep as False for this endpoint
                )

                return self._handle_response(response)

        try:
            response_data = self._retry_request_sync(upload_request)
            file_id = response_data.get("id")
            if not file_id:
                raise ValueError("File ID not found in upload response.")
            log.info(f"File uploaded successfully. File ID: {file_id}")
            return file_id
        except Exception as e:
            log.error(f"Failed to upload file: {e}")
            raise

    async def _upload_file_async(self, session: aiohttp.ClientSession) -> str:
        """Async file upload with streaming for better memory efficiency."""
        url = f"{self.base_url}/files"

        async def upload_request():
            # Create multipart writer for streaming upload
            writer = aiohttp.MultipartWriter("form-data")

            # Add purpose field
            purpose_part = writer.append("ocr")
            purpose_part.set_content_disposition("form-data", name="purpose")

            # Add file part with streaming
            file_part = writer.append_payload(
                aiohttp.streams.FilePayload(
                    self.file_path,
                    filename=self.file_name,
                    content_type="application/pdf",
                )
            )
            file_part.set_content_disposition(
                "form-data", name="file", filename=self.file_name
            )

            self._debug_log(
                f"Uploading file: {self.file_name} ({self.file_size:,} bytes)"
            )

            async with session.post(
                url,
                data=writer,
                headers=self.headers,
                timeout=aiohttp.ClientTimeout(total=self.upload_timeout),
            ) as response:
                return await self._handle_response_async(response)

        response_data = await self._retry_request_async(upload_request)

        file_id = response_data.get("id")
        if not file_id:
            raise ValueError("File ID not found in upload response.")

        log.info(f"File uploaded successfully. File ID: {file_id}")
        return file_id
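
The MultipartWriter/FilePayload combination above streams the upload body. For comparison, roughly the same request can be expressed with aiohttp.FormData and an open file handle; this is only a sketch of an equivalent call, not what the module uses:

import aiohttp

async def upload_with_formdata(session, url, headers, file_path, file_name, timeout):
    form = aiohttp.FormData()
    form.add_field("purpose", "ocr")
    with open(file_path, "rb") as f:
        # FormData reads the file while the request is being sent,
        # so the handle must stay open for the duration of the POST.
        form.add_field("file", f, filename=file_name, content_type="application/pdf")
        async with session.post(
            url, data=form, headers=headers, timeout=aiohttp.ClientTimeout(total=timeout)
        ) as response:
            return await response.json()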

    def _get_signed_url(self, file_id: str) -> str:
        """Retrieves a temporary signed URL for the uploaded file (sync version)."""
        log.info(f"Getting signed URL for file ID: {file_id}")
        url = f"{self.base_url}/files/{file_id}/url"
        params = {"expiry": 1}
        signed_url_headers = {**self.headers, "Accept": "application/json"}

        def url_request():
            response = requests.get(
                url, headers=signed_url_headers, params=params, timeout=self.url_timeout
            )
            return self._handle_response(response)

        try:
            response_data = self._retry_request_sync(url_request)
            signed_url = response_data.get("url")
            if not signed_url:
                raise ValueError("Signed URL not found in response.")
            log.info("Signed URL received.")
            return signed_url
        except Exception as e:
            log.error(f"Failed to get signed URL: {e}")
            raise

    async def _get_signed_url_async(
        self, session: aiohttp.ClientSession, file_id: str
    ) -> str:
        """Async signed URL retrieval."""
        url = f"{self.base_url}/files/{file_id}/url"
        params = {"expiry": 1}

        headers = {**self.headers, "Accept": "application/json"}

        async def url_request():
            self._debug_log(f"Getting signed URL for file ID: {file_id}")
            async with session.get(
                url,
                headers=headers,
                params=params,
                timeout=aiohttp.ClientTimeout(total=self.url_timeout),
            ) as response:
                return await self._handle_response_async(response)

        response_data = await self._retry_request_async(url_request)

        signed_url = response_data.get("url")
        if not signed_url:
            raise ValueError("Signed URL not found in response.")

        self._debug_log("Signed URL received successfully")
        return signed_url

    def _process_ocr(self, signed_url: str) -> Dict[str, Any]:
        """Sends the signed URL to the OCR endpoint for processing (sync version)."""
        log.info("Processing OCR via Mistral API")
        url = f"{self.base_url}/ocr"
        ocr_headers = {
            **self.headers,
            "Content-Type": "application/json",
            "Accept": "application/json",
        }
        payload = {
            "model": "mistral-ocr-latest",
            "document": {
                "type": "document_url",
                "document_url": signed_url,
            },
            "include_image_base64": False,
        }

        def ocr_request():
            response = requests.post(
                url, headers=ocr_headers, json=payload, timeout=self.ocr_timeout
            )
            return self._handle_response(response)

        try:
            ocr_response = self._retry_request_sync(ocr_request)
            log.info("OCR processing done.")
            self._debug_log("OCR response: %s", ocr_response)
            return ocr_response
        except Exception as e:
            log.error(f"Failed during OCR processing: {e}")
            raise

    async def _process_ocr_async(
        self, session: aiohttp.ClientSession, signed_url: str
    ) -> Dict[str, Any]:
        """Async OCR processing with timing metrics."""
        url = f"{self.base_url}/ocr"

        headers = {
            **self.headers,
            "Content-Type": "application/json",
            "Accept": "application/json",
        }

        payload = {
            "model": "mistral-ocr-latest",
            "document": {
                "type": "document_url",
                "document_url": signed_url,
            },
            "include_image_base64": False,
        }

        async def ocr_request():
            log.info("Starting OCR processing via Mistral API")
            start_time = time.time()

            async with session.post(
                url,
                json=payload,
                headers=headers,
                timeout=aiohttp.ClientTimeout(total=self.ocr_timeout),
            ) as response:
                ocr_response = await self._handle_response_async(response)

                processing_time = time.time() - start_time
                log.info(f"OCR processing completed in {processing_time:.2f}s")

                return ocr_response

        return await self._retry_request_async(ocr_request)
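
Both OCR variants send the same JSON payload, and the only response fields the loader relies on later (in _process_results) are pages[*].markdown and pages[*].index. A trimmed illustration of the two shapes, with made-up values:

# Request body as built above, with a placeholder signed URL.
payload = {
    "model": "mistral-ocr-latest",
    "document": {"type": "document_url", "document_url": "https://example.invalid/signed"},
    "include_image_base64": False,
}

# Minimal response shape consumed downstream.
example_ocr_response = {
    "pages": [
        {"index": 0, "markdown": "# Page one\n..."},
        {"index": 1, "markdown": "Text of the second page ..."},
    ]
}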

    def _delete_file(self, file_id: str) -> None:
        """Deletes the file from Mistral storage (sync version)."""
        log.info(f"Deleting uploaded file ID: {file_id}")
        url = f"{self.base_url}/files/{file_id}"

        try:
            response = requests.delete(
                url, headers=self.headers, timeout=self.cleanup_timeout
            )
            delete_response = self._handle_response(response)
            log.info(f"File deleted successfully: {delete_response}")
        except Exception as e:
            # Log error but don't necessarily halt execution if deletion fails
            log.error(f"Failed to delete file ID {file_id}: {e}")

    async def _delete_file_async(
        self, session: aiohttp.ClientSession, file_id: str
    ) -> None:
        """Async file deletion with error tolerance."""
        try:

            async def delete_request():
                self._debug_log(f"Deleting file ID: {file_id}")
                async with session.delete(
                    url=f"{self.base_url}/files/{file_id}",
                    headers=self.headers,
                    timeout=aiohttp.ClientTimeout(
                        total=self.cleanup_timeout
                    ),  # Shorter timeout for cleanup
                ) as response:
                    return await self._handle_response_async(response)

            await self._retry_request_async(delete_request)
            self._debug_log(f"File {file_id} deleted successfully")

        except Exception as e:
            # Don't fail the entire process if cleanup fails
            log.warning(f"Failed to delete file ID {file_id}: {e}")

    @asynccontextmanager
    async def _get_session(self):
        """Context manager for HTTP session with optimized settings."""
        connector = aiohttp.TCPConnector(
            limit=20,  # Increased total connection limit for better throughput
            limit_per_host=10,  # Increased per-host limit for API endpoints
            ttl_dns_cache=600,  # Longer DNS cache TTL (10 minutes)
            use_dns_cache=True,
            keepalive_timeout=60,  # Increased keepalive for connection reuse
            enable_cleanup_closed=True,
            force_close=False,  # Allow connection reuse
            resolver=aiohttp.AsyncResolver(),  # Use async DNS resolver
        )

        timeout = aiohttp.ClientTimeout(
            total=self.timeout,
            connect=30,  # Connection timeout
            sock_read=60,  # Socket read timeout
        )

        async with aiohttp.ClientSession(
            connector=connector,
            timeout=timeout,
            headers={"User-Agent": "OpenWebUI-MistralLoader/2.0"},
            raise_for_status=False,  # We handle status codes manually
            trust_env=True,
        ) as session:
            yield session
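
The session factory is re-entered wherever a session is needed, so connections are pooled for the duration of one workflow and released as soon as the `async with` block exits. A sketch of the consuming pattern (mirroring load_async further down; `loader` is assumed to be an already constructed MistralLoader):

async def run_ocr(loader):
    async with loader._get_session() as session:
        file_id = await loader._upload_file_async(session)
        signed_url = await loader._get_signed_url_async(session, file_id)
        return await loader._process_ocr_async(session, signed_url)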

    def _process_results(self, ocr_response: Dict[str, Any]) -> List[Document]:
        """Process OCR results into Document objects with enhanced metadata and memory efficiency."""
        pages_data = ocr_response.get("pages")
        if not pages_data:
            log.warning("No pages found in OCR response.")
            return [
                Document(
                    page_content="No text content found",
                    metadata={"error": "no_pages", "file_name": self.file_name},
                )
            ]

        documents = []
        total_pages = len(pages_data)
        skipped_pages = 0

        # Process pages in a memory-efficient way
        for page_data in pages_data:
            page_content = page_data.get("markdown")
            page_index = page_data.get("index")  # API uses 0-based index

            if page_content is None or page_index is None:
                skipped_pages += 1
                self._debug_log(
                    f"Skipping page due to missing 'markdown' or 'index'. Data keys: {list(page_data.keys())}"
                )
                continue

            # Clean up content efficiently with early exit for empty content
            if isinstance(page_content, str):
                cleaned_content = page_content.strip()
            else:
                cleaned_content = str(page_content).strip()

            if not cleaned_content:
                skipped_pages += 1
                self._debug_log(f"Skipping empty page {page_index}")
                continue

            # Create document with optimized metadata
            documents.append(
                Document(
                    page_content=cleaned_content,
                    metadata={
                        "page": page_index,  # 0-based index from API
                        "page_label": page_index + 1,  # 1-based label for convenience
                        "total_pages": total_pages,
                        "file_name": self.file_name,
                        "file_size": self.file_size,
                        "processing_engine": "mistral-ocr",
                        "content_length": len(cleaned_content),
                    },
                )
            )

        if skipped_pages > 0:
            log.info(
                f"Processed {len(documents)} pages, skipped {skipped_pages} empty/invalid pages"
            )

        if not documents:
            # Case where pages existed but none had valid markdown/index
            log.warning(
                "OCR response contained pages, but none had valid content/index."
            )
            return [
                Document(
                    page_content="No valid text content found in document",
                    metadata={
                        "error": "no_valid_pages",
                        "total_pages": total_pages,
                        "file_name": self.file_name,
                    },
                )
            ]

        return documents
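
Given the example response shape shown earlier, the method produces one Document per non-empty page, with a 0-based `page` and a 1-based `page_label` in the metadata. A hedged sketch of the mapping (`loader` assumed to be a constructed instance):

docs = loader._process_results(
    {"pages": [{"index": 0, "markdown": "# Page one"}, {"index": 1, "markdown": "   "}]}
)
# -> a single Document with page_content="# Page one" and metadata containing
#    {"page": 0, "page_label": 1, "total_pages": 2, ...}; the whitespace-only
#    second page is counted as skipped.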

    def load(self) -> List[Document]:
        """
        Executes the full OCR workflow: upload, get URL, process OCR, delete file.
        Synchronous version for backward compatibility.

        Returns:
            A list of Document objects, one for each page processed.
        """
        file_id = None
        start_time = time.time()

        try:
            # 1. Upload file
            file_id = self._upload_file()

            # 2. Get Signed URL
            signed_url = self._get_signed_url(file_id)

            # 3. Process OCR
            ocr_response = self._process_ocr(signed_url)

            # 4. Process results
            documents = self._process_results(ocr_response)

            total_time = time.time() - start_time
            log.info(
                f"Sync OCR workflow completed in {total_time:.2f}s, produced {len(documents)} documents"
            )

            return documents

        except Exception as e:
            total_time = time.time() - start_time
            log.error(
                f"An error occurred during the loading process after {total_time:.2f}s: {e}"
            )
            # Return an error document on failure
            return [
                Document(
                    page_content=f"Error during processing: {e}",
                    metadata={
                        "error": "processing_failed",
                        "file_name": self.file_name,
                    },
                )
            ]
        finally:
            # 5. Delete file (attempt even if prior steps failed after upload)
            if file_id:
                try:
                    self._delete_file(file_id)
                except Exception as del_e:
                    # Log deletion error, but don't overwrite original error if one occurred
                    log.error(
                        f"Cleanup error: Could not delete file ID {file_id}. Reason: {del_e}"
                    )
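
End-to-end synchronous usage then reduces to constructing a loader and calling load(); the keyword arguments below are assumptions based on the attributes referenced in this class, since the actual __init__ signature is defined earlier in the file:

# Hypothetical construction; adjust arguments to the real __init__.
loader = MistralLoader(api_key="...", file_path="/tmp/report.pdf")
docs = loader.load()
for doc in docs:
    print(doc.metadata.get("page_label"), len(doc.page_content))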

    async def load_async(self) -> List[Document]:
        """
        Asynchronous OCR workflow execution with optimized performance.

        Returns:
            A list of Document objects, one for each page processed.
        """
        file_id = None
        start_time = time.time()

        try:
            async with self._get_session() as session:
                # 1. Upload file with streaming
                file_id = await self._upload_file_async(session)

                # 2. Get signed URL
                signed_url = await self._get_signed_url_async(session, file_id)

                # 3. Process OCR
                ocr_response = await self._process_ocr_async(session, signed_url)

                # 4. Process results
                documents = self._process_results(ocr_response)

                total_time = time.time() - start_time
                log.info(
                    f"Async OCR workflow completed in {total_time:.2f}s, produced {len(documents)} documents"
                )

                return documents

        except Exception as e:
            total_time = time.time() - start_time
            log.error(f"Async OCR workflow failed after {total_time:.2f}s: {e}")
            return [
                Document(
                    page_content=f"Error during OCR processing: {e}",
                    metadata={
                        "error": "processing_failed",
                        "file_name": self.file_name,
                    },
                )
            ]
        finally:
            # 5. Cleanup - always attempt file deletion
            if file_id:
                try:
                    async with self._get_session() as session:
                        await self._delete_file_async(session, file_id)
                except Exception as cleanup_error:
                    log.error(f"Cleanup failed for file ID {file_id}: {cleanup_error}")

    @staticmethod
    async def load_multiple_async(
        loaders: List["MistralLoader"],
        max_concurrent: int = 5,  # Limit concurrent requests
    ) -> List[List[Document]]:
        """
        Process multiple files concurrently with controlled concurrency.

        Args:
            loaders: List of MistralLoader instances
            max_concurrent: Maximum number of concurrent requests

        Returns:
            List of document lists, one for each loader
        """
        if not loaders:
            return []

        log.info(
            f"Starting concurrent processing of {len(loaders)} files with max {max_concurrent} concurrent"
        )
        start_time = time.time()

        # Use semaphore to control concurrency
        semaphore = asyncio.Semaphore(max_concurrent)

        async def process_with_semaphore(loader: "MistralLoader") -> List[Document]:
            async with semaphore:
                return await loader.load_async()

        # Process all files with controlled concurrency
        tasks = [process_with_semaphore(loader) for loader in loaders]
        results = await asyncio.gather(*tasks, return_exceptions=True)

        # Handle any exceptions in results
        processed_results = []
        for i, result in enumerate(results):
            if isinstance(result, Exception):
                log.error(f"File {i} failed: {result}")
                processed_results.append(
                    [
                        Document(
                            page_content=f"Error processing file: {result}",
                            metadata={
                                "error": "batch_processing_failed",
                                "file_index": i,
                            },
                        )
                    ]
                )
            else:
                processed_results.append(result)

        # MONITORING: Log comprehensive batch processing statistics
        total_time = time.time() - start_time
        total_docs = sum(len(docs) for docs in processed_results)
        success_count = sum(
            1 for result in results if not isinstance(result, Exception)
        )
        failure_count = len(results) - success_count

        log.info(
            f"Batch processing completed in {total_time:.2f}s: "
            f"{success_count} files succeeded, {failure_count} files failed, "
            f"produced {total_docs} total documents"
        )

        return processed_results
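
The static batch helper composes the per-file async workflow under a semaphore, so a caller only has to build the loader instances and hand them over. A sketch (constructor arguments again assumed, as above):

import asyncio

loaders = [
    MistralLoader(api_key="...", file_path=path)
    for path in ("invoice.pdf", "contract.pdf")
]
results = asyncio.run(MistralLoader.load_multiple_async(loaders, max_concurrent=2))
for docs in results:
    print(len(docs), "documents")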

backend/open_webui/retrieval/loaders/tavily.py (new file, 91 lines)
@@ -0,0 +1,91 @@
import requests
import logging
from typing import Iterator, List, Literal, Union

from langchain_core.document_loaders import BaseLoader
from langchain_core.documents import Document

log = logging.getLogger(__name__)


class TavilyLoader(BaseLoader):
    """Extract web page content from URLs using Tavily Extract API.

    This is a LangChain document loader that uses Tavily's Extract API to
    retrieve content from web pages and return it as Document objects.

    Args:
        urls: URL or list of URLs to extract content from.
        api_key: The Tavily API key.
        extract_depth: Depth of extraction, either "basic" or "advanced".
        continue_on_failure: Whether to continue if extraction of a URL fails.
    """

    def __init__(
        self,
        urls: Union[str, List[str]],
        api_key: str,
        extract_depth: Literal["basic", "advanced"] = "basic",
        continue_on_failure: bool = True,
    ) -> None:
        """Initialize Tavily Extract client.

        Args:
            urls: URL or list of URLs to extract content from.
            api_key: The Tavily API key.
            extract_depth: Depth of extraction, either "basic" or "advanced".
                Advanced extraction retrieves more data, including tables and
                embedded content, with higher success but may increase latency.
                Basic costs 1 credit per 5 successful URL extractions,
                advanced costs 2 credits per 5 successful URL extractions.
            continue_on_failure: Whether to continue if extraction of a URL fails.
        """
        if not urls:
            raise ValueError("At least one URL must be provided.")

        self.api_key = api_key
        self.urls = urls if isinstance(urls, list) else [urls]
        self.extract_depth = extract_depth
        self.continue_on_failure = continue_on_failure
        self.api_url = "https://api.tavily.com/extract"

    def lazy_load(self) -> Iterator[Document]:
        """Extract and yield documents from the URLs using Tavily Extract API."""
        batch_size = 20
        for i in range(0, len(self.urls), batch_size):
            batch_urls = self.urls[i : i + batch_size]
            try:
                headers = {
                    "Content-Type": "application/json",
                    "Authorization": f"Bearer {self.api_key}",
                }
                # Use string for single URL, array for multiple URLs
                urls_param = batch_urls[0] if len(batch_urls) == 1 else batch_urls
                payload = {"urls": urls_param, "extract_depth": self.extract_depth}
                # Make the API call
                response = requests.post(self.api_url, headers=headers, json=payload)
                response.raise_for_status()
                response_data = response.json()
                # Process successful results
                for result in response_data.get("results", []):
                    url = result.get("url", "")
                    content = result.get("raw_content", "")
                    if not content:
                        log.warning(f"No content extracted from {url}")
                        continue
                    # Add URLs as metadata
                    metadata = {"source": url}
                    yield Document(
                        page_content=content,
                        metadata=metadata,
                    )
                for failed in response_data.get("failed_results", []):
                    url = failed.get("url", "")
                    error = failed.get("error", "Unknown error")
                    log.error(f"Failed to extract content from {url}: {error}")
            except Exception as e:
                if self.continue_on_failure:
                    log.error(f"Error extracting content from batch {batch_urls}: {e}")
                else:
                    raise e
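
Typical use mirrors other LangChain loaders: lazy_load() yields one Document per successfully extracted URL (with the source URL in metadata), and the load() method inherited from BaseLoader collects them into a list:

loader = TavilyLoader(
    urls=["https://example.com/article"],
    api_key="...",
    extract_depth="basic",
)
for doc in loader.lazy_load():
    print(doc.metadata["source"], len(doc.page_content))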

backend/open_webui/retrieval/loaders/youtube.py (new file, 164 lines)
@@ -0,0 +1,164 @@
import logging
from xml.etree.ElementTree import ParseError

from typing import Any, Dict, Generator, List, Optional, Sequence, Union
from urllib.parse import parse_qs, urlparse
from langchain_core.documents import Document

log = logging.getLogger(__name__)

ALLOWED_SCHEMES = {"http", "https"}
ALLOWED_NETLOCS = {
    "youtu.be",
    "m.youtube.com",
    "youtube.com",
    "www.youtube.com",
    "www.youtube-nocookie.com",
    "vid.plus",
}


def _parse_video_id(url: str) -> Optional[str]:
    """Parse a YouTube URL and return the video ID if valid, otherwise None."""
    parsed_url = urlparse(url)

    if parsed_url.scheme not in ALLOWED_SCHEMES:
        return None

    if parsed_url.netloc not in ALLOWED_NETLOCS:
        return None

    path = parsed_url.path

    if path.endswith("/watch"):
        query = parsed_url.query
        parsed_query = parse_qs(query)
        if "v" in parsed_query:
            ids = parsed_query["v"]
            video_id = ids if isinstance(ids, str) else ids[0]
        else:
            return None
    else:
        path = parsed_url.path.lstrip("/")
        video_id = path.split("/")[-1]

    if len(video_id) != 11:  # Video IDs are 11 characters long
        return None

    return video_id


class YoutubeLoader:
    """Load `YouTube` video transcripts."""

    def __init__(
        self,
        video_id: str,
        language: Union[str, Sequence[str]] = "en",
        proxy_url: Optional[str] = None,
    ):
        """Initialize with YouTube video ID."""
        _video_id = _parse_video_id(video_id)
        self.video_id = _video_id if _video_id is not None else video_id
        self._metadata = {"source": video_id}
        self.proxy_url = proxy_url

        # Ensure language is a list
        if isinstance(language, str):
            self.language = [language]
        else:
            self.language = list(language)

        # Add English as fallback if not already in the list
        if "en" not in self.language:
            self.language.append("en")

    def load(self) -> List[Document]:
        """Load YouTube transcripts into `Document` objects."""
        try:
            from youtube_transcript_api import (
                NoTranscriptFound,
                TranscriptsDisabled,
                YouTubeTranscriptApi,
            )
            from youtube_transcript_api.proxies import GenericProxyConfig
        except ImportError:
            raise ImportError(
                'Could not import "youtube_transcript_api" Python package. '
                "Please install it with `pip install youtube-transcript-api`."
            )

        if self.proxy_url:
            youtube_proxies = GenericProxyConfig(
                http_url=self.proxy_url, https_url=self.proxy_url
            )
            log.debug(f"Using proxy URL: {self.proxy_url[:14]}...")
        else:
            youtube_proxies = None

        transcript_api = YouTubeTranscriptApi(proxy_config=youtube_proxies)
        try:
            transcript_list = transcript_api.list(self.video_id)
        except Exception as e:
            log.exception("Loading YouTube transcript failed")
            return []

        # Try each language in order of priority
        for lang in self.language:
            try:
                transcript = transcript_list.find_transcript([lang])
                if transcript.is_generated:
                    log.debug(f"Found generated transcript for language '{lang}'")
                    try:
                        transcript = transcript_list.find_manually_created_transcript(
                            [lang]
                        )
                        log.debug(f"Found manual transcript for language '{lang}'")
                    except NoTranscriptFound:
                        log.debug(
                            f"No manual transcript found for language '{lang}', using generated"
                        )
                        pass

                log.debug(f"Found transcript for language '{lang}'")
                try:
                    transcript_pieces: List[Dict[str, Any]] = transcript.fetch()
                except ParseError:
                    log.debug(f"Empty or invalid transcript for language '{lang}'")
                    continue

                if not transcript_pieces:
                    log.debug(f"Empty transcript for language '{lang}'")
                    continue

                transcript_text = " ".join(
                    map(
                        lambda transcript_piece: (
                            transcript_piece.text.strip(" ")
                            if hasattr(transcript_piece, "text")
                            else ""
                        ),
                        transcript_pieces,
                    )
                )
                return [Document(page_content=transcript_text, metadata=self._metadata)]
            except NoTranscriptFound:
                log.debug(f"No transcript found for language '{lang}'")
                continue
            except Exception as e:
                log.info(f"Error finding transcript for language '{lang}'")
                raise e

        # If we get here, all languages failed
        languages_tried = ", ".join(self.language)
        log.warning(
            f"No transcript found for any of the specified languages: {languages_tried}. Verify if the video has transcripts, add more languages if needed."
        )
        raise NoTranscriptFound(self.video_id, self.language, list(transcript_list))

    async def aload(self) -> List[Document]:
        """Asynchronously load YouTube transcripts into `Document` objects."""
        import asyncio

        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(None, self.load)
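
The loader accepts either a bare 11-character video ID or a full YouTube URL (the ID is parsed out when the URL form is recognized); a minimal usage sketch:

loader = YoutubeLoader("https://www.youtube.com/watch?v=dQw4w9WgXcQ", language=["de"])
docs = loader.load()  # one Document with the joined transcript text, or [] / an exception on failure
if docs:
    print(docs[0].metadata["source"], len(docs[0].page_content))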