# Files
# abot/utils/ai/unified_llm.py
#
# 617 lines
# 23 KiB
# Python

from __future__ import annotations
import base64
import binascii
import json
import mimetypes
import time
from typing import Any, Dict, List, Optional, Tuple
from urllib.parse import urlparse
import requests
from loguru import logger
from utils.ai.llm_registry import LLMRegistry
class UnifiedLLMClient:
    """Unified LLM client for OpenAI-compatible chat endpoints and Dify apps.

    The constructor normalizes a loosely-shaped config dict (see
    ``_normalize_config``) and copies the resulting values onto attributes.

    Request and parsing failures that happen inside the retry loops are
    caught: the public entry points then return ``None`` (or ``""`` for
    ``chat``) and leave a short reason string in ``self.last_error``.
    Successful calls return a dict with the keys ``text``, ``usage`` and
    ``raw``.
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """Normalize *config* and cache the resolved settings as attributes."""
        self.LOG = logger
        self.raw_config = config or {}
        self.config = self._normalize_config(self.raw_config)
        self.enabled = bool(self.config.get("enabled", True))
        self.provider = str(self.config.get("provider", "openai_compatible")).strip().lower()
        self.base_url = str(self.config.get("base_url", "")).rstrip("/")
        self.endpoint = str(self.config.get("endpoint", "")).lstrip("/")
        self.api_key = str(self.config.get("api_key", "")).strip()
        self.model = str(self.config.get("model", "")).strip()
        self.timeout_seconds = int(self.config.get("timeout_seconds", 60))
        # Alias of timeout_seconds kept as a separate attribute.
        self.timeout = self.timeout_seconds
        self.temperature = float(self.config.get("temperature", 0.7))
        self.max_tokens = int(self.config.get("max_tokens", 1024))
        self.stream = bool(self.config.get("stream", False))
        # A falsy value (0 / None) falls back to 3; at least one attempt is made.
        self.max_retries = max(int(self.config.get("max_retries", 3) or 3), 1)
        # A falsy delay (0 / None) falls back to one second.
        self.retry_delay_seconds = float(self.config.get("retry_delay_seconds", 1.0) or 1.0)
        self.mode = str(self.config.get("mode", "chat")).strip().lower()
        self.response_mode = str(self.config.get("response_mode", "blocking")).strip().lower()
        self.workflow_output_key = str(self.config.get("workflow_output_key", "text")).strip()
        self.default_system_prompt = str(self.config.get("system_prompt", "")).strip()
        self.last_error = ""

    def is_available(self) -> bool:
        """Return True when the client is enabled and has the fields its provider needs.

        ``openai_compatible`` needs base URL, endpoint and model; ``dify``
        needs base URL, endpoint and API key. Any other provider is
        reported as unavailable.
        """
        if not self.enabled:
            return False
        if self.provider == "openai_compatible":
            return bool(self.base_url and self.endpoint and self.model)
        if self.provider == "dify":
            return bool(self.base_url and self.endpoint and self.api_key)
        return False

    def chat(
        self,
        system_prompt: str,
        user_prompt: str,
        user_id: str,
        image_urls: Optional[List[str]] = None,
    ) -> str:
        """Run one generation and return only its text (``""`` on failure)."""
        result = self.generate(
            system_prompt=system_prompt,
            user_prompt=user_prompt,
            user=user_id,
            image_urls=image_urls or [],
        )
        return (result or {}).get("text", "") or ""

    def run(
        self,
        prompt: str,
        user: str,
        inputs: Optional[Dict[str, Any]] = None,
        tag: str = "",
        files: Optional[List[Dict[str, Any]]] = None,
    ) -> Optional[Dict[str, Any]]:
        """Run a prompt with optional structured inputs.

        For Dify the arguments are forwarded unchanged. For other providers
        the default system prompt is used and, when *prompt* is empty, the
        JSON-serialized *inputs* become the user prompt.
        """
        if self.provider == "dify":
            return self.generate(prompt=prompt, user=user, inputs=inputs or {}, tag=tag, files=files or [])
        effective_prompt = prompt or self._stringify_inputs(inputs or {})
        return self.generate(
            system_prompt=self.default_system_prompt,
            user_prompt=effective_prompt,
            user=user,
            inputs=inputs or {},
            tag=tag,
            files=files or [],
        )

    def upload_dify_file(
        self,
        *,
        user: str,
        file_bytes: bytes,
        filename: str,
        mime_type: str = "",
    ) -> Optional[Dict[str, Any]]:
        """Upload a file to ``{base_url}/files/upload`` and return the JSON reply.

        Returns the response payload when it contains an ``id``; otherwise
        ``None`` with ``self.last_error`` describing the reason. Only valid
        for the ``dify`` provider. Attempts are retried up to
        ``max_retries`` times with a linearly growing delay.
        """
        self.last_error = ""
        if self.provider != "dify":
            self.last_error = "upload_not_supported_for_provider"
            return None
        if not self.base_url or not self.api_key or not user or not file_bytes or not filename:
            self.last_error = "upload_missing_required_fields"
            return None

        upload_url = f"{self.base_url}/files/upload"
        headers = {"Authorization": self._build_auth_header(self.api_key)}
        # Explicit MIME type wins, then a guess from the filename, then a generic binary type.
        detected_mime = mime_type or mimetypes.guess_type(filename)[0] or "application/octet-stream"
        multipart = {"file": (filename, file_bytes, detected_mime)}
        form = {"user": user}

        for attempt in range(1, self.max_retries + 1):
            try:
                response = requests.post(
                    upload_url,
                    headers=headers,
                    files=multipart,
                    data=form,
                    timeout=self.timeout_seconds,
                )
                response.raise_for_status()
                payload = response.json() or {}
                if payload.get("id"):
                    return payload
                self.last_error = "upload_missing_file_id"
            except Exception as exc:
                self.last_error = f"upload_failed:attempt_{attempt}:{exc}"
            if attempt < self.max_retries:
                time.sleep(self.retry_delay_seconds * attempt)
        return None

    @staticmethod
    def build_dify_file_ref(
        *,
        file_type: str = "image",
        remote_url: str = "",
        upload_file_id: str = "",
    ) -> Dict[str, Any]:
        """Build a Dify file reference dict.

        An uploaded file id takes precedence over a remote URL; with
        neither, an empty dict is returned.
        """
        if upload_file_id:
            return {
                "type": file_type,
                "transfer_method": "local_file",
                "upload_file_id": upload_file_id,
            }
        if remote_url:
            return {
                "type": file_type,
                "transfer_method": "remote_url",
                "url": remote_url,
            }
        return {}

    @staticmethod
    def decode_data_url(data_url: str) -> Tuple[bytes, str]:
        """Decode a ``data:`` URL into ``(payload_bytes, mime_type)``.

        Returns ``(b"", "")`` when the input is not a data URL, and
        ``(b"", mime_type)`` when a base64 payload cannot be decoded.
        Payloads without the ``;base64`` marker are percent-decoded instead
        of being pushed through the base64 decoder.
        """
        raw = str(data_url or "").strip()
        if not raw.startswith("data:") or "," not in raw:
            return b"", ""
        header, encoded = raw.split(",", 1)
        params = header[5:].split(";")
        mime_type = params[0].strip()
        is_base64 = any(param.strip().lower() == "base64" for param in params[1:])
        if not is_base64:
            # Local import: the module only imports ``urlparse`` at file level.
            from urllib.parse import unquote_to_bytes

            return unquote_to_bytes(encoded), mime_type
        try:
            return base64.b64decode(encoded), mime_type
        except (ValueError, binascii.Error):
            return b"", mime_type

    def generate(
        self,
        prompt: str = "",
        user: str = "",
        inputs: Optional[Dict[str, Any]] = None,
        tag: str = "",
        system_prompt: str = "",
        user_prompt: str = "",
        image_urls: Optional[List[str]] = None,
        files: Optional[List[Dict[str, Any]]] = None,
    ) -> Optional[Dict[str, Any]]:
        """Dispatch one generation request to the configured provider.

        Dify uses *prompt*, *inputs*, *tag* and *files*; the
        OpenAI-compatible path uses *system_prompt*, *user_prompt* (falling
        back to *prompt*) and *image_urls*. Returns ``None`` and sets
        ``self.last_error`` when the client is unavailable, the provider is
        unknown, or every attempt failed.
        """
        self.last_error = ""
        if not self.is_available():
            self.last_error = "client_unavailable"
            return None
        if self.provider == "dify":
            return self._generate_dify(
                prompt=prompt,
                user=user,
                inputs=inputs or {},
                tag=tag,
                files=files or [],
            )
        if self.provider == "openai_compatible":
            return self._generate_openai(
                system_prompt=system_prompt,
                user_prompt=user_prompt or prompt,
                user=user,
                image_urls=image_urls or [],
            )
        self.last_error = f"unsupported_provider:{self.provider}"
        return None

    def _generate_openai(
        self,
        system_prompt: str,
        user_prompt: str,
        user: str,
        image_urls: List[str],
    ) -> Optional[Dict[str, Any]]:
        """POST a chat request to the OpenAI-compatible endpoint, with retries."""
        payload = {
            "model": self.model,
            "messages": self._build_messages(system_prompt or self.default_system_prompt, user_prompt, image_urls),
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "user": user,
            "stream": self.stream,
        }
        headers = {"Content-Type": "application/json"}
        if self.api_key:
            headers["Authorization"] = self._build_auth_header(self.api_key)
        url = f"{self.base_url}/{self.endpoint}"

        for attempt in range(1, self.max_retries + 1):
            try:
                if self.stream:
                    text, raw = self._request_openai_stream(url, payload, headers)
                else:
                    text, raw = self._request_openai_json(url, payload, headers)
                if text:
                    return {
                        "text": text,
                        "usage": self._extract_openai_usage(raw),
                        "raw": raw,
                    }
                # An empty answer is treated as a failed attempt and retried.
                self.last_error = f"empty_model_output:{self.model}"
            except Exception as exc:
                self.last_error = f"request_failed:attempt_{attempt}:{exc}"
            if attempt < self.max_retries:
                time.sleep(self.retry_delay_seconds * attempt)
        return None

    def _generate_dify(
        self,
        prompt: str,
        user: str,
        inputs: Dict[str, Any],
        tag: str,
        files: List[Dict[str, Any]],
    ) -> Optional[Dict[str, Any]]:
        """POST a request to the Dify endpoint for the configured mode, with retries.

        The payload shape depends on ``self.mode``: ``workflow`` carries the
        prompt inside ``inputs["query"]`` (unless already set), while
        ``completion`` and the default chat mode send a top-level ``query``.
        """
        headers = {
            "Authorization": self._build_auth_header(self.api_key),
            "Content-Type": "application/json",
        }
        # Copy so the caller's dict is never mutated.
        payload_inputs = dict(inputs or {})
        if self.mode == "workflow":
            if prompt and "query" not in payload_inputs:
                payload_inputs["query"] = prompt
            payload = {
                "inputs": payload_inputs,
                "response_mode": self.response_mode,
                "user": user,
                "files": files,
            }
        elif self.mode == "completion":
            payload = {
                "inputs": payload_inputs,
                "query": prompt,
                "response_mode": self.response_mode,
                "user": user,
                "files": files,
            }
        else:
            payload = {
                "inputs": payload_inputs,
                "query": prompt,
                "response_mode": self.response_mode,
                "conversation_id": "",
                "user": user,
                "files": files,
            }
        url = f"{self.base_url}/{self.endpoint}"

        for attempt in range(1, self.max_retries + 1):
            try:
                if self.response_mode == "streaming":
                    parsed = self._request_dify_stream(url, payload, headers, tag)
                else:
                    response = requests.post(url, headers=headers, json=payload, timeout=self.timeout_seconds)
                    response.raise_for_status()
                    parsed = self._parse_dify_response(response.json())
                if parsed and parsed.get("text"):
                    return parsed
                self.last_error = f"empty_model_output:{self.mode}"
            except Exception as exc:
                self.last_error = f"request_failed:attempt_{attempt}:{exc}"
                self.LOG.warning(f"[UnifiedLLMClient] Dify 请求失败: tag={tag}, attempt={attempt}, error={exc}")
            if attempt < self.max_retries:
                time.sleep(self.retry_delay_seconds * attempt)
        return None

    def _request_openai_json(self, url: str, payload: Dict[str, Any], headers: Dict[str, str]) -> Tuple[str, Dict[str, Any]]:
        """Send a blocking request and return ``(text, decoded_json)``."""
        response = requests.post(url, json=payload, headers=headers, timeout=self.timeout_seconds)
        response.raise_for_status()
        data = response.json()
        return self._extract_openai_text(data), data

    def _request_openai_stream(
        self,
        url: str,
        payload: Dict[str, Any],
        headers: Dict[str, str],
    ) -> Tuple[str, Dict[str, Any]]:
        """Send a streaming request and return ``(text, {"stream_text": text})``."""
        with requests.post(url, json=payload, headers=headers, timeout=self.timeout_seconds, stream=True) as response:
            response.raise_for_status()
            text = self._collect_openai_stream_text(response.iter_content(chunk_size=None))
        return text, {"stream_text": text}

    @classmethod
    def _collect_openai_stream_text(cls, byte_chunks: Any) -> str:
        """Concatenate the text deltas found in an OpenAI-style SSE byte stream.

        *byte_chunks* is any iterable of ``bytes``. Reading stops at the
        ``[DONE]`` sentinel. Events whose data is not valid JSON are
        skipped rather than aborting the whole response, mirroring the
        Dify stream parser.
        """
        pieces: List[str] = []
        for event_text in cls._iter_sse_events(byte_chunks):
            try:
                piece, done = cls._parse_openai_sse_event(event_text)
            except ValueError:
                # json.JSONDecodeError is a ValueError: malformed/truncated event.
                continue
            if piece:
                pieces.append(piece)
            if done:
                break
        return "".join(pieces).strip()

    @staticmethod
    def _iter_sse_events(byte_chunks: Any):
        """Yield decoded SSE events from an iterable of raw byte chunks.

        Events are separated by a blank line. CRLF line endings are
        normalized to LF on the accumulated buffer, so a CR/LF pair split
        across two chunks is still handled. Bytes are decoded as UTF-8 with
        replacement so one undecodable event cannot stall the parser, and
        a final event lacking the trailing blank line is still yielded.
        """
        buffer = b""
        for part in byte_chunks:
            if not part:
                continue
            buffer = (buffer + part).replace(b"\r\n", b"\n")
            while b"\n\n" in buffer:
                event, buffer = buffer.split(b"\n\n", 1)
                if event.strip():
                    yield event.decode("utf-8", errors="replace")
        tail = buffer.strip()
        if tail:
            yield tail.decode("utf-8", errors="replace")

    def _request_dify_stream(
        self,
        url: str,
        payload: Dict[str, Any],
        headers: Dict[str, str],
        tag: str,
    ) -> Optional[Dict[str, Any]]:
        """Send a streaming Dify request and return the parsed result or ``None``."""
        with requests.post(url, headers=headers, json=payload, timeout=self.timeout_seconds, stream=True) as response:
            response.raise_for_status()
            # Lines are read as raw bytes and decoded as UTF-8 by the consumer
            # instead of relying on the response's header-derived encoding.
            return self._consume_dify_stream(response.iter_lines(), tag)

    def _consume_dify_stream(self, lines: Any, tag: str) -> Optional[Dict[str, Any]]:
        """Parse Dify SSE lines into a result dict.

        *lines* is an iterable of ``bytes`` or ``str`` lines. The chunk of
        the last ``workflow_finished`` / ``message_end`` event is parsed
        first; when it carries no text, the concatenated streamed
        fragments are returned instead. Returns ``None`` (after logging a
        warning) when no text was produced at all.
        """
        event_name = ""
        text_fragments: List[str] = []
        final_payload = None
        for raw_line in lines:
            if raw_line is None:
                continue
            if isinstance(raw_line, (bytes, bytearray)):
                line = bytes(raw_line).decode("utf-8", errors="replace").strip()
            else:
                line = str(raw_line).strip()
            if not line:
                continue
            if line.startswith("event:"):
                event_name = line[6:].strip()
                continue
            if not line.startswith("data:"):
                continue
            data_text = line[5:].strip()
            if not data_text or data_text == "[DONE]":
                continue
            try:
                chunk = json.loads(data_text)
            except ValueError:
                # Malformed data line: skip it and keep reading.
                continue
            if not isinstance(chunk, dict):
                continue
            candidate_text = self._extract_dify_stream_text(chunk)
            if candidate_text:
                text_fragments.append(candidate_text)
            chunk_event = str(chunk.get("event") or event_name or "").strip()
            if chunk_event in {"workflow_finished", "message_end"}:
                final_payload = chunk

        if final_payload:
            parsed = self._parse_dify_response(final_payload)
            if parsed and parsed.get("text"):
                return parsed
        text = "".join(fragment for fragment in text_fragments if fragment).strip()
        if text:
            return {"text": text, "usage": {}, "raw": final_payload or {}}
        self.LOG.warning(f"[UnifiedLLMClient] Dify 流式响应未产出有效内容: tag={tag}")
        return None

    @staticmethod
    def _build_messages(system_prompt: str, user_prompt: str, image_urls: List[str]) -> List[Dict[str, Any]]:
        """Build the ``messages`` list for a chat-completions request.

        With images, the user content is a list of parts (one text part
        followed by one ``image_url`` part per non-empty URL); otherwise it
        is the plain prompt string. The system message is omitted when
        *system_prompt* is empty.
        """
        if image_urls:
            user_content: Any = [{"type": "text", "text": user_prompt}]
            user_content.extend(
                {"type": "image_url", "image_url": {"url": image_url}}
                for image_url in image_urls
                if image_url
            )
        else:
            user_content = user_prompt

        messages: List[Dict[str, Any]] = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": user_content})
        return messages

    @staticmethod
    def _extract_openai_text(data: Dict[str, Any]) -> str:
        """Extract the answer text from a non-streaming response body.

        Lookup order: the first choice's ``message.content`` (string, or a
        list of parts joined with newlines), then the message keys
        ``reasoning_content`` / ``text`` / ``output_text``, then the
        top-level keys ``output_text`` / ``text`` / ``answer`` /
        ``response``. Returns ``""`` when nothing usable is found.
        """
        choices = data.get("choices") or []
        first_choice = choices[0] if choices else None
        if isinstance(first_choice, dict):
            message = first_choice.get("message", {}) or {}
            content = message.get("content")
            if isinstance(content, str) and content.strip():
                return content.strip()
            if isinstance(content, list):
                parts = []
                for item in content:
                    if isinstance(item, dict):
                        text = item.get("text") or item.get("content")
                        if isinstance(text, str) and text.strip():
                            parts.append(text.strip())
                if parts:
                    return "\n".join(parts).strip()
            for key in ("reasoning_content", "text", "output_text"):
                value = message.get(key)
                if isinstance(value, str) and value.strip():
                    return value.strip()
        for key in ("output_text", "text", "answer", "response"):
            value = data.get(key)
            if isinstance(value, str) and value.strip():
                return value.strip()
        return ""

    @classmethod
    def _parse_openai_sse_event(cls, event_text: str) -> Tuple[str, bool]:
        """Parse one SSE event into ``(text_delta, done)``.

        ``done`` is True only for the ``[DONE]`` sentinel. Events without
        ``data:`` lines, or whose JSON is not an object with a usable
        delta, yield ``("", False)``.

        Raises:
            ValueError: when the data payload is not valid JSON
                (``json.JSONDecodeError``).
        """
        lines = [line.strip() for line in event_text.splitlines() if line.strip()]
        data_lines = [line[5:].strip() for line in lines if line.startswith("data:")]
        if not data_lines:
            return "", False
        data = "\n".join(data_lines)
        if data == "[DONE]":
            return "", True
        obj = json.loads(data)
        if not isinstance(obj, dict):
            return "", False
        # An empty ``choices`` list is treated like a choice with no delta.
        choice = (obj.get("choices") or [{}])[0]
        if not isinstance(choice, dict):
            return "", False
        delta = choice.get("delta") or {}
        content = delta.get("content")
        if isinstance(content, str):
            return content, False
        if isinstance(content, list):
            parts = []
            for item in content:
                if isinstance(item, dict):
                    text = item.get("text") or item.get("content")
                    if isinstance(text, str):
                        parts.append(text)
            return "".join(parts), False
        return "", False

    def _parse_dify_response(self, data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Parse a Dify response body into ``{"text", "usage", "raw"}``.

        Workflow mode delegates to ``_parse_dify_workflow_response``; other
        modes read ``answer`` and ``metadata.usage``.
        """
        if self.mode == "workflow":
            return self._parse_dify_workflow_response(data)
        answer = str(data.get("answer", "") or "").strip()
        usage = (data.get("metadata") or {}).get("usage", {}) or {}
        return {"text": answer, "usage": usage, "raw": data}

    def _parse_dify_workflow_response(self, data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Parse a Dify workflow response.

        The text is taken from ``data.outputs`` using the configured output
        key first, then ``text`` / ``answer`` / ``result_json`` /
        ``result``, and finally the first output value that stringifies to
        something non-empty.
        """
        payload = (data or {}).get("data", {}) or {}
        outputs = payload.get("outputs", {}) or {}
        text = ""
        for key in filter(None, [self.workflow_output_key, "text", "answer", "result_json", "result"]):
            if outputs.get(key) is not None:
                text = self._stringify_output(outputs.get(key))
                if text:
                    break
        if not text:
            for value in outputs.values():
                text = self._stringify_output(value)
                if text:
                    break
        usage = {
            "total_tokens": payload.get("total_tokens"),
            "latency": payload.get("elapsed_time"),
        }
        return {"text": text.strip(), "usage": usage, "raw": data}

    def _extract_dify_stream_text(self, chunk: Dict[str, Any]) -> str:
        """Return the text carried by one streamed Dify chunk (``""`` if none).

        ``data.outputs`` is searched first (same key order as the workflow
        parser), then the top-level ``text`` / ``answer`` keys. String
        fragments are returned verbatim so whitespace between streamed
        deltas survives concatenation; the caller strips the joined text.
        """
        if not isinstance(chunk, dict):
            return ""
        payload = (chunk.get("data") or {}) if isinstance(chunk.get("data"), dict) else {}
        outputs = payload.get("outputs", {}) if isinstance(payload.get("outputs"), dict) else {}
        for key in filter(None, [self.workflow_output_key, "text", "answer", "result_json", "result"]):
            if outputs.get(key) is not None:
                return self._stringify_fragment(outputs.get(key))
        for key in ("text", "answer"):
            if chunk.get(key) is not None:
                return self._stringify_fragment(chunk.get(key))
        return ""

    @classmethod
    def _stringify_fragment(cls, value: Any) -> str:
        """Stringify a streamed value, leaving plain strings untouched."""
        if isinstance(value, str):
            return value
        return cls._stringify_output(value)

    @staticmethod
    def _extract_openai_usage(data: Dict[str, Any]) -> Dict[str, Any]:
        """Return the ``usage`` mapping of a response, or ``{}`` when absent/empty."""
        return data.get("usage", {}) or {}

    @staticmethod
    def _stringify_output(value: Any) -> str:
        """Convert an output value to text.

        ``None`` becomes ``""``, strings are stripped, dicts and lists are
        JSON-encoded without ASCII escaping, anything else goes through
        ``str()`` and is stripped.
        """
        if value is None:
            return ""
        if isinstance(value, str):
            return value.strip()
        if isinstance(value, (dict, list)):
            return json.dumps(value, ensure_ascii=False)
        return str(value).strip()

    @classmethod
    def _normalize_config(cls, config: Dict[str, Any]) -> Dict[str, Any]:
        """Resolve *config* and fill in canonical keys with defaults.

        Accepts several aliases: ``enable`` for ``enabled``; ``api_url`` /
        ``url`` (split into base URL and endpoint) or ``api_base_url`` for
        the location; ``api-key`` / ``authorization`` for the key;
        ``request_timeout_seconds`` / ``request_timeout`` for the timeout.
        A missing provider is guessed, and a missing endpoint defaults per
        provider and mode.
        """
        # NOTE(review): LLMRegistry.resolve is defined elsewhere; it is assumed
        # to return a mutable dict of settings - confirm whether it copies the input.
        normalized = LLMRegistry.resolve(config or {})
        normalized["enabled"] = bool(
            normalized.get("enabled", normalized.get("enable", True))
        )
        if not normalized.get("provider"):
            normalized["provider"] = cls._guess_provider(normalized)

        parsed_url = cls._split_url(
            normalized.get("api_url")
            or normalized.get("url")
        )
        base_url = (
            normalized.get("base_url")
            or normalized.get("api_base_url")
            or parsed_url[0]
            or ""
        )
        endpoint = (
            normalized.get("endpoint")
            or parsed_url[1]
            or ""
        )
        normalized["base_url"] = str(base_url).rstrip("/")
        normalized["endpoint"] = str(endpoint).lstrip("/")
        normalized["api_key"] = (
            normalized.get("api_key")
            or normalized.get("api-key")
            or normalized.get("authorization")
            or ""
        )
        normalized["timeout_seconds"] = int(
            normalized.get("timeout_seconds")
            or normalized.get("request_timeout_seconds")
            or normalized.get("request_timeout")
            or 60
        )
        normalized["max_retries"] = cls._resolve_max_retries(normalized)
        retry_delay = normalized.get("retry_delay_seconds")
        normalized["retry_delay_seconds"] = float(1.0 if retry_delay is None else retry_delay)
        normalized["response_mode"] = normalized.get("response_mode", "blocking")
        normalized["workflow_output_key"] = normalized.get("workflow_output_key", "text")

        # Compare case-insensitively, matching how __init__ reads the provider.
        provider = str(normalized["provider"]).strip().lower()
        if not normalized["endpoint"]:
            if provider == "dify":
                normalized["endpoint"] = cls._guess_dify_endpoint(normalized)
            else:
                normalized["endpoint"] = "chat/completions"
        return normalized

    @staticmethod
    def _resolve_max_retries(config: Dict[str, Any]) -> int:
        """Return the configured attempt count.

        An explicit ``max_retries`` wins. Otherwise a non-empty
        ``retry_delays_seconds`` list/tuple implies one attempt per delay
        plus the initial one, and the fallback is 3.
        """
        explicit = config.get("max_retries")
        if explicit is not None:
            return int(explicit)
        delays = config.get("retry_delays_seconds")
        if isinstance(delays, (list, tuple)) and delays:
            return len(delays) + 1
        return 3

    @staticmethod
    def _guess_provider(config: Dict[str, Any]) -> str:
        """Guess the provider from the URL, API key prefix and mode.

        Dify is chosen when the URL contains a Dify route, the key starts
        with ``app-``, or the mode is ``workflow`` / ``completion``;
        otherwise ``openai_compatible``.
        """
        api_key = str(
            config.get("api_key")
            or config.get("api-key")
            or config.get("authorization")
            or ""
        ).strip()
        url = str(config.get("api_url") or config.get("url") or config.get("endpoint") or "").lower()
        mode = str(config.get("mode", "")).lower()
        if "workflows/run" in url or "chat-messages" in url or "completion-messages" in url:
            return "dify"
        if api_key.startswith("app-") or mode in {"workflow", "completion"}:
            return "dify"
        return "openai_compatible"

    @staticmethod
    def _guess_dify_endpoint(config: Dict[str, Any]) -> str:
        """Return the default Dify route for the configured mode."""
        mode = str(config.get("mode", "chat")).strip().lower()
        if mode == "workflow":
            return "workflows/run"
        if mode == "completion":
            return "completion-messages"
        return "chat-messages"

    @staticmethod
    def _split_url(url: Optional[str]) -> Tuple[str, str]:
        """Split a full URL into ``(scheme://host, path[?query])``.

        A value without scheme or host is returned unchanged as the
        endpoint part with an empty base. The query string is kept on the
        endpoint so parameters embedded in a configured URL are not lost.
        """
        if not url:
            return "", ""
        parsed = urlparse(str(url))
        if not parsed.scheme or not parsed.netloc:
            return "", str(url)
        base = f"{parsed.scheme}://{parsed.netloc}"
        endpoint = parsed.path.lstrip("/")
        if parsed.query:
            endpoint = f"{endpoint}?{parsed.query}"
        return base, endpoint

    @staticmethod
    def _build_auth_header(value: str) -> str:
        """Return *value* as a ``Bearer`` header value (``""`` when empty).

        A value that already starts with ``bearer `` (any case) is passed
        through unchanged.
        """
        token = str(value or "").strip()
        if not token:
            return ""
        if token.lower().startswith("bearer "):
            return token
        return f"Bearer {token}"

    @staticmethod
    def _stringify_inputs(inputs: Dict[str, Any]) -> str:
        """JSON-encode *inputs*, falling back to ``str()`` if that fails."""
        if not inputs:
            return ""
        try:
            return json.dumps(inputs, ensure_ascii=False)
        except Exception:
            return str(inputs)