541 lines
20 KiB
Python
541 lines
20 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
import time
|
|
from typing import Any, Dict, List, Optional, Tuple
|
|
from urllib.parse import urlparse
|
|
|
|
import requests
|
|
from loguru import logger
|
|
|
|
from utils.ai.llm_registry import LLMRegistry
|
|
|
|
|
|
class UnifiedLLMClient:
    """Unified LLM client for OpenAI-compatible chat APIs and Dify apps.

    The raw config is normalized once in ``__init__`` (via ``_normalize_config``,
    which first passes it through ``LLMRegistry.resolve``) and copied onto plain
    attributes. Transport/HTTP errors are caught inside the retry loops; the
    public entry points then return ``None`` (``""`` for :meth:`chat`) and leave
    a short reason string in ``last_error``.
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """Build a client from a (possibly partial or aliased) config mapping.

        Args:
            config: Raw settings; ``None`` is treated as an empty mapping.
                Aliases such as ``api_url``/``url``, ``api-key``/``authorization``
                and ``request_timeout`` are resolved by ``_normalize_config``.
        """
        self.LOG = logger
        self.raw_config = config or {}
        self.config = self._normalize_config(self.raw_config)

        self.enabled = bool(self.config.get("enabled", True))
        self.provider = str(self.config.get("provider", "openai_compatible")).strip().lower()
        self.base_url = str(self.config.get("base_url", "")).rstrip("/")
        self.endpoint = str(self.config.get("endpoint", "")).lstrip("/")
        self.api_key = str(self.config.get("api_key", "")).strip()
        self.model = str(self.config.get("model", "")).strip()
        self.timeout_seconds = int(self.config.get("timeout_seconds", 60))
        # Alias kept for callers that read ``timeout``.
        self.timeout = self.timeout_seconds
        self.temperature = float(self.config.get("temperature", 0.7))
        self.max_tokens = int(self.config.get("max_tokens", 1024))
        self.stream = bool(self.config.get("stream", False))

        # Total attempts per request (first try included), at least 1. Only a
        # missing/None value falls back to the default; an explicit 0 used to be
        # promoted to 3 by ``0 or 3`` and now means "try once".
        max_retries = self.config.get("max_retries")
        self.max_retries = max(int(3 if max_retries is None else max_retries), 1)
        # Base of the linear back-off. An explicit 0 disables the pause (it used
        # to be turned into 1 second by ``0 or 1.0``).
        retry_delay = self.config.get("retry_delay_seconds")
        self.retry_delay_seconds = float(1.0 if retry_delay is None else retry_delay)
        # Optional explicit per-retry schedule; overrides the linear back-off.
        self.retry_delays_seconds = self._parse_delays(self.config.get("retry_delays_seconds"))

        self.mode = str(self.config.get("mode", "chat")).strip().lower()
        self.response_mode = str(self.config.get("response_mode", "blocking")).strip().lower()
        self.workflow_output_key = str(self.config.get("workflow_output_key", "text")).strip()
        self.default_system_prompt = str(self.config.get("system_prompt", "")).strip()
        self.last_error = ""

    def is_available(self) -> bool:
        """Return True when the client is enabled and has what its provider needs.

        OpenAI-compatible needs ``base_url``, ``endpoint`` and ``model``; Dify
        needs ``base_url``, ``endpoint`` and ``api_key``. Unknown providers are
        never available.
        """
        if not self.enabled:
            return False
        if self.provider == "openai_compatible":
            return bool(self.base_url and self.endpoint and self.model)
        if self.provider == "dify":
            return bool(self.base_url and self.endpoint and self.api_key)
        return False

    def chat(
        self,
        system_prompt: str,
        user_prompt: str,
        user_id: str,
        image_urls: Optional[List[str]] = None,
    ) -> str:
        """Send one chat turn and return only the generated text.

        Args:
            system_prompt: System message (used by the OpenAI-compatible
                provider; empty falls back to the configured ``system_prompt``).
            user_prompt: User message text (also used as the Dify query).
            user_id: Caller identifier forwarded as ``user``.
            image_urls: Optional image URLs attached to the user message
                (OpenAI-compatible provider only).

        Returns:
            The generated text, or ``""`` on failure (see ``last_error``).
        """
        result = self.generate(
            system_prompt=system_prompt,
            user_prompt=user_prompt,
            user=user_id,
            image_urls=image_urls or [],
        )
        return (result or {}).get("text", "") or ""

    def run(
        self,
        prompt: str,
        user: str,
        inputs: Optional[Dict[str, Any]] = None,
        tag: str = "",
    ) -> Optional[Dict[str, Any]]:
        """Run a prompt/inputs pair against the configured provider.

        For Dify the prompt and inputs are forwarded as-is. For the
        OpenAI-compatible provider the configured system prompt is used and, if
        ``prompt`` is empty, ``inputs`` serialized as JSON becomes the user
        message.

        Returns:
            ``{"text", "usage", "raw"}`` on success, otherwise ``None``.
        """
        if self.provider == "dify":
            return self.generate(prompt=prompt, user=user, inputs=inputs or {}, tag=tag)

        effective_prompt = prompt or self._stringify_inputs(inputs or {})
        return self.generate(
            system_prompt=self.default_system_prompt,
            user_prompt=effective_prompt,
            user=user,
            inputs=inputs or {},
            tag=tag,
        )

    def generate(
        self,
        prompt: str = "",
        user: str = "",
        inputs: Optional[Dict[str, Any]] = None,
        tag: str = "",
        system_prompt: str = "",
        user_prompt: str = "",
        image_urls: Optional[List[str]] = None,
        files: Optional[List[Dict[str, Any]]] = None,
    ) -> Optional[Dict[str, Any]]:
        """Dispatch a request to the configured provider.

        ``prompt`` and ``user_prompt`` are interchangeable for both providers
        (``prompt`` wins for Dify, ``user_prompt`` wins for OpenAI-compatible).
        Previously the Dify branch ignored ``user_prompt``, so :meth:`chat`
        sent an empty query to Dify.

        Returns:
            ``{"text", "usage", "raw"}`` on success, otherwise ``None`` with
            ``last_error`` set (``client_unavailable``,
            ``unsupported_provider:<name>``, ``empty_model_output:<...>`` or
            ``request_failed:attempt_<n>:<error>``).
        """
        self.last_error = ""
        if not self.is_available():
            self.last_error = "client_unavailable"
            return None

        if self.provider == "dify":
            return self._generate_dify(
                prompt=prompt or user_prompt,
                user=user,
                inputs=inputs or {},
                tag=tag,
                files=files or [],
            )
        if self.provider == "openai_compatible":
            return self._generate_openai(
                system_prompt=system_prompt,
                user_prompt=user_prompt or prompt,
                user=user,
                image_urls=image_urls or [],
            )

        self.last_error = f"unsupported_provider:{self.provider}"
        return None

    def _retry_delay(self, attempt: int) -> float:
        """Seconds to wait after failed attempt number ``attempt`` (1-based).

        Uses ``retry_delays_seconds`` when configured (the last entry repeats),
        otherwise ``retry_delay_seconds * attempt``. Never negative, because
        ``time.sleep`` raises ``ValueError`` on negative input.
        """
        if self.retry_delays_seconds:
            index = min(attempt, len(self.retry_delays_seconds)) - 1
            delay = self.retry_delays_seconds[index]
        else:
            delay = self.retry_delay_seconds * attempt
        return max(float(delay), 0.0)

    def _generate_openai(
        self,
        system_prompt: str,
        user_prompt: str,
        user: str,
        image_urls: List[str],
    ) -> Optional[Dict[str, Any]]:
        """Call an OpenAI-compatible chat endpoint with retries.

        Empty model output and request errors both count as a failed attempt.
        """
        payload = {
            "model": self.model,
            "messages": self._build_messages(system_prompt or self.default_system_prompt, user_prompt, image_urls),
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "user": user,
            "stream": self.stream,
        }
        headers = {"Content-Type": "application/json"}
        if self.api_key:
            headers["Authorization"] = self._build_auth_header(self.api_key)

        url = f"{self.base_url}/{self.endpoint}"
        for attempt in range(1, self.max_retries + 1):
            try:
                if self.stream:
                    text, raw = self._request_openai_stream(url, payload, headers)
                else:
                    text, raw = self._request_openai_json(url, payload, headers)
                if text:
                    return {
                        "text": text,
                        "usage": self._extract_openai_usage(raw),
                        "raw": raw,
                    }
                self.last_error = f"empty_model_output:{self.model}"
            except Exception as exc:  # best-effort boundary: record, log, retry
                self.last_error = f"request_failed:attempt_{attempt}:{exc}"
                self.LOG.warning(
                    f"[UnifiedLLMClient] OpenAI-compatible 请求失败: model={self.model}, attempt={attempt}, error={exc}"
                )
            if attempt < self.max_retries:
                time.sleep(self._retry_delay(attempt))
        return None

    def _build_dify_payload(
        self,
        prompt: str,
        user: str,
        inputs: Dict[str, Any],
        files: List[Dict[str, Any]],
    ) -> Dict[str, Any]:
        """Build the Dify request body for the configured ``mode``.

        Workflow mode has no top-level ``query``: the prompt is injected as
        ``inputs["query"]`` unless the caller already set it. Chat mode (any mode
        other than ``workflow``/``completion``) starts a new conversation via an
        empty ``conversation_id``.
        """
        payload_inputs = dict(inputs or {})
        if self.mode == "workflow" and prompt and "query" not in payload_inputs:
            payload_inputs["query"] = prompt

        payload: Dict[str, Any] = {"inputs": payload_inputs}
        if self.mode != "workflow":
            payload["query"] = prompt
        payload["response_mode"] = self.response_mode
        if self.mode not in ("workflow", "completion"):
            payload["conversation_id"] = ""
        payload["user"] = user
        payload["files"] = files
        return payload

    def _generate_dify(
        self,
        prompt: str,
        user: str,
        inputs: Dict[str, Any],
        tag: str,
        files: List[Dict[str, Any]],
    ) -> Optional[Dict[str, Any]]:
        """Call a Dify app endpoint (blocking or streaming) with retries."""
        headers = {
            "Authorization": self._build_auth_header(self.api_key),
            "Content-Type": "application/json",
        }
        payload = self._build_dify_payload(prompt, user, inputs, files)

        url = f"{self.base_url}/{self.endpoint}"
        for attempt in range(1, self.max_retries + 1):
            try:
                if self.response_mode == "streaming":
                    parsed = self._request_dify_stream(url, payload, headers, tag)
                else:
                    response = requests.post(url, headers=headers, json=payload, timeout=self.timeout_seconds)
                    response.raise_for_status()
                    parsed = self._parse_dify_response(response.json())
                if parsed and parsed.get("text"):
                    return parsed
                self.last_error = f"empty_model_output:{self.mode}"
            except Exception as exc:  # best-effort boundary: record, log, retry
                self.last_error = f"request_failed:attempt_{attempt}:{exc}"
                self.LOG.warning(f"[UnifiedLLMClient] Dify 请求失败: tag={tag}, attempt={attempt}, error={exc}")
            if attempt < self.max_retries:
                time.sleep(self._retry_delay(attempt))
        return None

    def _request_openai_json(self, url: str, payload: Dict[str, Any], headers: Dict[str, str]) -> Tuple[str, Dict[str, Any]]:
        """POST a non-streaming request and return ``(text, decoded_json)``."""
        response = requests.post(url, json=payload, headers=headers, timeout=self.timeout_seconds)
        response.raise_for_status()
        data = response.json()
        return self._extract_openai_text(data), data

    def _request_openai_stream(
        self,
        url: str,
        payload: Dict[str, Any],
        headers: Dict[str, str],
    ) -> Tuple[str, Dict[str, Any]]:
        """POST a streaming request and return ``(text, {"stream_text": text})``."""
        with requests.post(url, json=payload, headers=headers, timeout=self.timeout_seconds, stream=True) as response:
            response.raise_for_status()
            text = self._collect_openai_sse_text(response.iter_content(chunk_size=None))
        return text, {"stream_text": text}

    @classmethod
    def _collect_openai_sse_text(cls, parts: Any) -> str:
        """Assemble streamed text from raw OpenAI-style SSE byte chunks.

        Args:
            parts: Iterable of ``bytes`` chunks (e.g. ``Response.iter_content``);
                chunk boundaries may fall anywhere inside an event.

        Returns:
            The concatenated ``delta.content`` text, stripped. Reading stops at
            the ``data: [DONE]`` sentinel.

        Raises:
            ValueError: a complete event carries malformed JSON (the caller
                treats this as a failed attempt).
        """
        pieces: List[str] = []
        buffer = b""
        for part in parts:
            if not part:
                continue
            # Normalize CRLF so "\r\n\r\n" separators are recognized; a "\r"
            # left at the end of the previous chunk pairs up here.
            buffer = (buffer + part).replace(b"\r\n", b"\n")
            while b"\n\n" in buffer:
                event, buffer = buffer.split(b"\n\n", 1)
                # "\n" never occurs inside a UTF-8 multi-byte sequence, so a
                # complete event is decodable; "replace" only guards against
                # genuinely invalid bytes instead of stalling the parser.
                text, done = cls._parse_openai_sse_event(event.decode("utf-8", errors="replace"))
                if text:
                    pieces.append(text)
                if done:
                    return "".join(pieces).strip()
        if buffer.strip():
            # Some servers omit the blank line after the final event.
            try:
                text, _ = cls._parse_openai_sse_event(buffer.decode("utf-8", errors="replace"))
            except ValueError:
                # Truncated tail: keep what was already received.
                text = ""
            pieces.append(text)
        return "".join(pieces).strip()

    def _request_dify_stream(
        self,
        url: str,
        payload: Dict[str, Any],
        headers: Dict[str, str],
        tag: str,
    ) -> Optional[Dict[str, Any]]:
        """POST a Dify streaming request and fold its SSE events into a result."""
        with requests.post(url, headers=headers, json=payload, timeout=self.timeout_seconds, stream=True) as response:
            response.raise_for_status()
            # Raw bytes: decoding is done explicitly as UTF-8 in the collector,
            # independent of whatever charset requests guesses.
            parsed = self._collect_dify_stream(response.iter_lines())
        if parsed is None:
            self.LOG.warning(f"[UnifiedLLMClient] Dify 流式响应未产出有效内容: tag={tag}")
        return parsed

    def _collect_dify_stream(self, lines: Any) -> Optional[Dict[str, Any]]:
        """Fold Dify SSE lines into a ``{"text", "usage", "raw"}`` result.

        Args:
            lines: Iterable of SSE lines (``bytes`` or ``str``), e.g. from
                ``Response.iter_lines()``.

        Returns:
            The parsed terminal event (``workflow_finished``/``message_end``)
            when it carries text, otherwise the concatenated streamed
            fragments (with the terminal event's usage), or ``None``.

        Raises:
            RuntimeError: the stream contains an ``error`` event.
        """
        event_name = ""
        fragments: List[str] = []
        final_payload: Optional[Dict[str, Any]] = None

        for raw_line in lines:
            if raw_line is None:
                continue
            # str(bytes) would give "b'...'", so decode bytes explicitly.
            if isinstance(raw_line, bytes):
                line = raw_line.decode("utf-8", errors="replace").strip()
            else:
                line = str(raw_line).strip()
            if not line:
                # A blank line ends an SSE event, including its event name.
                event_name = ""
                continue
            if line.startswith("event:"):
                event_name = line[6:].strip()
                continue
            if not line.startswith("data:"):
                continue

            data_text = line[5:].strip()
            if not data_text or data_text == "[DONE]":
                continue
            try:
                chunk = json.loads(data_text)
            except ValueError:
                continue
            if not isinstance(chunk, dict):
                continue

            chunk_event = str(chunk.get("event") or event_name or "").strip()
            if chunk_event == "error":
                reason = chunk.get("message") or chunk.get("code") or "unknown"
                raise RuntimeError(f"dify_stream_error:{reason}")
            if chunk_event == "message_replace":
                # The event carries a replacement for the whole answer so far.
                fragments = [self._raw_fragment(chunk.get("answer"))]
                continue

            fragment = self._extract_dify_stream_text(chunk, chunk_event)
            if fragment:
                fragments.append(fragment)
            if chunk_event in {"workflow_finished", "message_end"}:
                final_payload = chunk

        parsed = self._parse_dify_response(final_payload) if final_payload else None
        if parsed and parsed.get("text"):
            return parsed

        text = "".join(fragments).strip()
        if text:
            usage = (parsed or {}).get("usage") or {}
            return {"text": text, "usage": usage, "raw": final_payload or {}}
        return None

    @staticmethod
    def _build_messages(system_prompt: str, user_prompt: str, image_urls: List[str]) -> List[Dict[str, Any]]:
        """Build the ``messages`` list, adding image parts when URLs are given."""
        user_content: str | List[Dict[str, Any]]
        if image_urls:
            content_parts: List[Dict[str, Any]] = [{"type": "text", "text": user_prompt}]
            for image_url in image_urls:
                if image_url:
                    content_parts.append({"type": "image_url", "image_url": {"url": image_url}})
            user_content = content_parts
        else:
            user_content = user_prompt

        messages: List[Dict[str, Any]] = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": user_content})
        return messages

    @staticmethod
    def _extract_openai_text(data: Dict[str, Any]) -> str:
        """Pull the answer text out of a non-streaming OpenAI-style response.

        Lookup order: ``choices[0].message.content`` (string or list of parts),
        then ``message`` keys ``reasoning_content``/``text``/``output_text``,
        then legacy ``choices[0].text``, then top-level
        ``output_text``/``text``/``answer``/``response``.
        """
        choices = data.get("choices") or []
        first = choices[0] if isinstance(choices, list) and choices else None
        if isinstance(first, dict):
            message = first.get("message") or {}
            if not isinstance(message, dict):
                message = {}
            content = message.get("content")
            if isinstance(content, str) and content.strip():
                return content.strip()
            if isinstance(content, list):
                parts = []
                for item in content:
                    if isinstance(item, dict):
                        text = item.get("text") or item.get("content")
                        if isinstance(text, str) and text.strip():
                            parts.append(text.strip())
                if parts:
                    return "\n".join(parts).strip()
            for key in ("reasoning_content", "text", "output_text"):
                value = message.get(key)
                if isinstance(value, str) and value.strip():
                    return value.strip()
            # Legacy /completions responses put the text directly on the choice.
            choice_text = first.get("text")
            if isinstance(choice_text, str) and choice_text.strip():
                return choice_text.strip()

        for key in ("output_text", "text", "answer", "response"):
            value = data.get(key)
            if isinstance(value, str) and value.strip():
                return value.strip()
        return ""

    @classmethod
    def _parse_openai_sse_event(cls, event_text: str) -> Tuple[str, bool]:
        """Parse one SSE event into ``(delta_text, is_done)``.

        Events without ``data:`` lines yield ``("", False)``. Malformed JSON
        raises ``ValueError`` (``json.JSONDecodeError``).
        """
        lines = [line.strip() for line in event_text.splitlines() if line.strip()]
        data_lines = [line[5:].strip() for line in lines if line.startswith("data:")]
        if not data_lines:
            return "", False
        data = "\n".join(data_lines)
        if data == "[DONE]":
            return "", True

        obj = json.loads(data)
        choices = obj.get("choices") if isinstance(obj, dict) else None
        choice = choices[0] if isinstance(choices, list) and choices else {}
        delta = (choice.get("delta") if isinstance(choice, dict) else None) or {}
        content = delta.get("content") if isinstance(delta, dict) else None
        if isinstance(content, str):
            return content, False
        if isinstance(content, list):
            parts = []
            for item in content:
                if isinstance(item, dict):
                    text = item.get("text") or item.get("content")
                    if isinstance(text, str):
                        parts.append(text)
            return "".join(parts), False
        return "", False

    def _parse_dify_response(self, data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Normalize a Dify blocking response (or terminal stream event)."""
        if self.mode == "workflow":
            return self._parse_dify_workflow_response(data)
        data = data or {}
        answer = str(data.get("answer", "") or "").strip()
        usage = (data.get("metadata") or {}).get("usage", {}) or {}
        return {"text": answer, "usage": usage, "raw": data}

    def _output_keys(self) -> List[str]:
        """Workflow output keys to try, most specific first (empty keys skipped)."""
        return [key for key in (self.workflow_output_key, "text", "answer", "result_json", "result") if key]

    def _parse_dify_workflow_response(self, data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Pick the answer text out of a workflow ``data.outputs`` mapping.

        Tries the preferred keys first, then any non-empty output value.
        """
        payload = (data or {}).get("data", {}) or {}
        if not isinstance(payload, dict):
            payload = {}
        outputs = payload.get("outputs", {}) or {}
        if not isinstance(outputs, dict):
            outputs = {}

        text = ""
        for key in self._output_keys():
            if outputs.get(key) is not None:
                text = self._stringify_output(outputs.get(key))
                if text:
                    break
        if not text:
            for value in outputs.values():
                text = self._stringify_output(value)
                if text:
                    break

        usage = {
            "total_tokens": payload.get("total_tokens"),
            "latency": payload.get("elapsed_time"),
        }
        return {"text": text.strip(), "usage": usage, "raw": data}

    def _extract_dify_stream_text(self, chunk: Dict[str, Any], event_name: str = "") -> str:
        """Return the incremental text carried by one Dify stream chunk.

        Only ``message``/``agent_message`` (``answer``) and ``text_chunk``
        (``data.text``) carry incremental text. Other tagged events (e.g.
        ``node_finished``, ``workflow_finished``) return ``""``; appending their
        outputs used to duplicate the answer. Untagged chunks keep the legacy
        heuristic. Fragments are returned unstripped so inter-token spaces
        survive.
        """
        if not isinstance(chunk, dict):
            return ""
        event = str(chunk.get("event") or event_name or "").strip()
        if event in ("message", "agent_message"):
            return self._raw_fragment(chunk.get("answer"))
        if event == "text_chunk":
            data = chunk.get("data")
            return self._raw_fragment(data.get("text")) if isinstance(data, dict) else ""
        if event:
            return ""

        data = chunk.get("data")
        outputs = data.get("outputs") if isinstance(data, dict) else None
        if isinstance(outputs, dict):
            for key in self._output_keys():
                if outputs.get(key) is not None:
                    return self._raw_fragment(outputs.get(key))
        for key in ("text", "answer"):
            if chunk.get(key) is not None:
                return self._raw_fragment(chunk.get(key))
        return ""

    @classmethod
    def _raw_fragment(cls, value: Any) -> str:
        """Return a streamed fragment verbatim; non-strings go through ``_stringify_output``."""
        if isinstance(value, str):
            return value
        return cls._stringify_output(value)

    @staticmethod
    def _extract_openai_usage(data: Dict[str, Any]) -> Dict[str, Any]:
        """Return ``data["usage"]`` when present and truthy, else ``{}``."""
        return data.get("usage") or {}

    @staticmethod
    def _stringify_output(value: Any) -> str:
        """Render an output value as stripped text (dicts/lists as JSON)."""
        if value is None:
            return ""
        if isinstance(value, str):
            return value.strip()
        if isinstance(value, (dict, list)):
            return json.dumps(value, ensure_ascii=False)
        return str(value).strip()

    @staticmethod
    def _parse_delays(value: Any) -> List[float]:
        """Return the numeric entries of a ``retry_delays_seconds`` list.

        Non-list values yield ``[]``; entries that are not numbers are skipped.
        """
        if not isinstance(value, (list, tuple)):
            return []
        delays: List[float] = []
        for item in value:
            try:
                delays.append(float(item))
            except (TypeError, ValueError):
                continue
        return delays

    @staticmethod
    def _resolve_max_retries(config: Dict[str, Any]) -> int:
        """Resolve the total attempt count from a config mapping.

        An explicit ``max_retries`` wins. Otherwise a non-empty
        ``retry_delays_seconds`` list gives ``len(list) + 1`` (one try per
        delay plus the first). The fallback is 3. The old expression
        ``len(...) + 1 or 3`` could never reach 3, so it defaulted to 1.
        """
        explicit = config.get("max_retries")
        if explicit is not None:
            return int(explicit)
        delays = config.get("retry_delays_seconds")
        if isinstance(delays, (list, tuple)) and delays:
            return len(delays) + 1
        return 3

    @classmethod
    def _normalize_config(cls, config: Dict[str, Any]) -> Dict[str, Any]:
        """Resolve aliases and fill defaults so ``__init__`` can read plain keys."""
        # Copy in case ``LLMRegistry.resolve`` returns the caller's dict or a
        # shared registry entry; the assignments below must not leak back.
        normalized: Dict[str, Any] = dict(LLMRegistry.resolve(config or {}) or {})
        normalized["enabled"] = bool(normalized.get("enabled", normalized.get("enable", True)))

        if not normalized.get("provider"):
            normalized["provider"] = cls._guess_provider(normalized)

        url_base, url_path = cls._split_url(normalized.get("api_url") or normalized.get("url"))
        base_url = normalized.get("base_url") or normalized.get("api_base_url") or url_base or ""
        endpoint = normalized.get("endpoint") or url_path or ""
        normalized["base_url"] = str(base_url).rstrip("/")
        normalized["endpoint"] = str(endpoint).lstrip("/")

        normalized["api_key"] = (
            normalized.get("api_key")
            or normalized.get("api-key")
            or normalized.get("authorization")
            or ""
        )
        normalized["timeout_seconds"] = int(
            normalized.get("timeout_seconds")
            or normalized.get("request_timeout_seconds")
            or normalized.get("request_timeout")
            or 60
        )
        normalized["max_retries"] = cls._resolve_max_retries(normalized)
        retry_delay = normalized.get("retry_delay_seconds")
        normalized["retry_delay_seconds"] = float(1.0 if retry_delay is None else retry_delay)
        normalized["response_mode"] = normalized.get("response_mode") or "blocking"
        normalized["workflow_output_key"] = normalized.get("workflow_output_key") or "text"

        if not normalized["endpoint"]:
            # Compare case-insensitively, matching how __init__ reads ``provider``.
            provider = str(normalized["provider"]).strip().lower()
            if provider == "dify":
                normalized["endpoint"] = cls._guess_dify_endpoint(normalized)
            else:
                normalized["endpoint"] = "chat/completions"

        return normalized

    @staticmethod
    def _guess_provider(config: Dict[str, Any]) -> str:
        """Infer the provider from Dify-style URLs, ``app-`` keys or Dify-only modes."""
        api_key = str(
            config.get("api_key")
            or config.get("api-key")
            or config.get("authorization")
            or ""
        ).strip()
        url = str(config.get("api_url") or config.get("url") or config.get("endpoint") or "").lower()
        mode = str(config.get("mode", "")).lower()
        if "workflows/run" in url or "chat-messages" in url or "completion-messages" in url:
            return "dify"
        if api_key.startswith("app-") or mode in {"workflow", "completion"}:
            return "dify"
        return "openai_compatible"

    @staticmethod
    def _guess_dify_endpoint(config: Dict[str, Any]) -> str:
        """Return the default Dify endpoint path for the configured ``mode``."""
        mode = str(config.get("mode", "chat")).strip().lower()
        if mode == "workflow":
            return "workflows/run"
        if mode == "completion":
            return "completion-messages"
        return "chat-messages"

    @staticmethod
    def _split_url(url: Optional[str]) -> Tuple[str, str]:
        """Split a full URL into ``(scheme://host, path)``.

        A value without scheme/host is returned unchanged as the path part.
        """
        if not url:
            return "", ""
        parsed = urlparse(str(url))
        if not parsed.scheme or not parsed.netloc:
            return "", str(url)
        base = f"{parsed.scheme}://{parsed.netloc}"
        return base, parsed.path.lstrip("/")

    @staticmethod
    def _build_auth_header(value: str) -> str:
        """Return a ``Bearer`` Authorization value (kept as-is if already prefixed)."""
        token = str(value or "").strip()
        if not token:
            return ""
        if token.lower().startswith("bearer "):
            return token
        return f"Bearer {token}"

    @staticmethod
    def _stringify_inputs(inputs: Dict[str, Any]) -> str:
        """Serialize ``inputs`` as JSON, falling back to ``str()`` if not serializable."""
        if not inputs:
            return ""
        try:
            return json.dumps(inputs, ensure_ascii=False)
        except (TypeError, ValueError):
            return str(inputs)