# abot/utils/ai/unified_llm.py
#
# NOTE: the code-hosting UI flagged this file as containing Unicode characters
# that might be confused with other characters. If that is not intentional,
# reveal and review them before editing.
from __future__ import annotations
import base64
import binascii
import json
import mimetypes
import time
from collections import deque
from threading import Lock
from typing import Any, Dict, List, Optional, Tuple
from urllib.parse import urlparse
import requests
from requests import HTTPError
from loguru import logger
from utils.ai.llm_registry import LLMRegistry
from utils.trace_context import get_current_trace_id, format_trace_prefix
class UnifiedLLMClient:
    """Unified LLM client for OpenAI-compatible endpoints and Dify.

    The client is configured from a plain dict (see ``_normalize_config`` for
    the accepted aliases).  ``generate`` is the single entry point; ``chat``
    and ``run`` are thin convenience wrappers around it.  Failures never
    raise out of ``generate``: the method returns ``None`` and stores a short
    reason in ``self.last_error``.
    """

    # Runtime observability snapshot:
    # 1. only a small window of the most recent calls is kept, so it cannot
    #    grow without bound;
    # 2. it lives on the shared client class, so every user of this client
    #    contributes to (and benefits from) the same window;
    # 3. it stores lightweight health indicators, not business details.
    _runtime_metrics = deque(maxlen=50)
    _runtime_lock = Lock()

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """Normalize *config* and cache the resulting settings as attributes."""
        self.LOG = logger
        self.raw_config = config or {}
        self.config = self._normalize_config(self.raw_config)
        self.enabled = bool(self.config.get("enabled", True))
        self.provider = str(self.config.get("provider", "openai_compatible")).strip().lower()
        self.base_url = str(self.config.get("base_url", "")).rstrip("/")
        self.endpoint = str(self.config.get("endpoint", "")).lstrip("/")
        self.api_key = str(self.config.get("api_key", "")).strip()
        self.model = str(self.config.get("model", "")).strip()
        self.timeout_seconds = int(self.config.get("timeout_seconds", 60))
        self.timeout = self.timeout_seconds
        self.temperature = float(self.config.get("temperature", 0.7))
        self.max_tokens = int(self.config.get("max_tokens", 1024))
        self.stream = bool(self.config.get("stream", False))
        # At least one attempt is always made; 0/None fall back to 3.
        self.max_retries = max(int(self.config.get("max_retries", 3) or 3), 1)
        self.retry_delay_seconds = float(self.config.get("retry_delay_seconds", 1.0) or 1.0)
        self.mode = str(self.config.get("mode", "chat")).strip().lower()
        self.response_mode = str(self.config.get("response_mode", "blocking")).strip().lower()
        self.workflow_output_key = str(self.config.get("workflow_output_key", "text")).strip()
        self.default_system_prompt = str(self.config.get("system_prompt", "")).strip()
        self.last_error = ""

    # ------------------------------------------------------------------
    # Runtime metrics
    # ------------------------------------------------------------------
    @classmethod
    def _record_runtime_metric(
        cls,
        *,
        provider: str,
        backend: str,
        scene: str,
        model: str,
        trace_id: str,
        success: bool,
        latency_ms: float,
        error: str = "",
    ) -> None:
        """Append the outcome of one LLM call to the shared metrics window.

        The error text is truncated to 300 characters and the latency is
        rounded to two decimals.
        """
        with cls._runtime_lock:
            cls._runtime_metrics.append({
                "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
                "provider": str(provider or "").strip(),
                "backend": str(backend or "").strip(),
                "scene": str(scene or "").strip(),
                "model": str(model or "").strip(),
                "trace_id": str(trace_id or "").strip(),
                "success": bool(success),
                "latency_ms": round(float(latency_ms or 0.0), 2),
                "error": str(error or "").strip()[:300],
            })

    @classmethod
    def get_runtime_snapshot(cls) -> Dict[str, Any]:
        """Return aggregate statistics over the recent-call window.

        Returns:
            A dict with ``window_size``, ``total_calls``, ``success_calls``,
            ``failed_calls``, ``success_rate`` (percent), ``avg_latency_ms``,
            ``last_call`` (the newest record or ``{}``) and ``last_error``
            (the newest non-empty error of a failed call, or ``""``).
        """
        with cls._runtime_lock:
            rows = list(cls._runtime_metrics)
        total_calls = len(rows)
        success_calls = sum(1 for item in rows if item.get("success"))
        failed_calls = total_calls - success_calls
        avg_latency_ms = round(
            sum(float(item.get("latency_ms") or 0.0) for item in rows) / total_calls,
            2,
        ) if total_calls else 0.0
        last_call = rows[-1] if rows else {}
        last_error = ""
        for item in reversed(rows):
            if not item.get("success") and item.get("error"):
                last_error = str(item.get("error") or "").strip()
                break
        return {
            "window_size": cls._runtime_metrics.maxlen,
            "total_calls": total_calls,
            "success_calls": success_calls,
            "failed_calls": failed_calls,
            "success_rate": round((success_calls / total_calls) * 100, 2) if total_calls else 0.0,
            "avg_latency_ms": avg_latency_ms,
            "last_call": last_call,
            "last_error": last_error,
        }

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def is_available(self) -> bool:
        """Return True when the client is enabled and minimally configured.

        OpenAI-compatible providers need base URL, endpoint and model; Dify
        needs base URL, endpoint and API key.  Unknown providers are never
        available.
        """
        if not self.enabled:
            return False
        if self.provider == "openai_compatible":
            return bool(self.base_url and self.endpoint and self.model)
        if self.provider == "dify":
            return bool(self.base_url and self.endpoint and self.api_key)
        return False

    def chat(
        self,
        system_prompt: str,
        user_prompt: str,
        user_id: str,
        image_urls: Optional[List[str]] = None,
    ) -> str:
        """Run one chat turn and return only the generated text ("" on failure)."""
        result = self.generate(
            system_prompt=system_prompt,
            user_prompt=user_prompt,
            user=user_id,
            image_urls=image_urls or [],
        )
        return (result or {}).get("text", "") or ""

    def run(
        self,
        prompt: str,
        user: str,
        inputs: Optional[Dict[str, Any]] = None,
        tag: str = "",
        files: Optional[List[Dict[str, Any]]] = None,
    ) -> Optional[Dict[str, Any]]:
        """Dify-style entry point that also works for OpenAI-compatible providers.

        For non-Dify providers the configured default system prompt is used
        and, when *prompt* is empty, *inputs* is serialized to JSON and sent
        as the user prompt.
        """
        if self.provider == "dify":
            return self.generate(prompt=prompt, user=user, inputs=inputs or {}, tag=tag, files=files or [])
        effective_prompt = prompt or self._stringify_inputs(inputs or {})
        return self.generate(
            system_prompt=self.default_system_prompt,
            user_prompt=effective_prompt,
            user=user,
            inputs=inputs or {},
            tag=tag,
            files=files or [],
        )

    def upload_dify_file(
        self,
        *,
        user: str,
        file_bytes: bytes,
        filename: str,
        mime_type: str = "",
    ) -> Optional[Dict[str, Any]]:
        """Upload a file to ``{base_url}/files/upload`` and return the JSON reply.

        Returns:
            The response payload when it contains an ``id``; otherwise
            ``None`` with the reason stored in ``self.last_error``.  The
            upload is retried up to ``self.max_retries`` times with a
            linearly growing delay.
        """
        self.last_error = ""
        if self.provider != "dify":
            self.last_error = "upload_not_supported_for_provider"
            return None
        if not self.base_url or not self.api_key or not user or not file_bytes or not filename:
            self.last_error = "upload_missing_required_fields"
            return None
        upload_url = f"{self.base_url}/files/upload"
        headers = {"Authorization": self._build_auth_header(self.api_key)}
        detected_mime = mime_type or mimetypes.guess_type(filename)[0] or "application/octet-stream"
        files = {
            "file": (filename, file_bytes, detected_mime),
        }
        data = {"user": user}
        for attempt in range(1, self.max_retries + 1):
            try:
                response = requests.post(
                    upload_url, headers=headers, files=files, data=data, timeout=self.timeout_seconds
                )
                response.raise_for_status()
                payload = response.json() or {}
                if payload.get("id"):
                    return payload
                self.last_error = "upload_missing_file_id"
            except Exception as exc:
                # Best-effort: any failure is recorded and retried.
                self.last_error = f"upload_failed:attempt_{attempt}:{exc}"
            if attempt < self.max_retries:
                time.sleep(self.retry_delay_seconds * attempt)
        return None

    @staticmethod
    def build_dify_file_ref(
        *,
        file_type: str = "image",
        remote_url: str = "",
        upload_file_id: str = "",
    ) -> Dict[str, Any]:
        """Build a Dify ``files`` entry; an uploaded file id wins over a remote URL.

        Returns ``{}`` when neither an upload id nor a URL is given.
        """
        if upload_file_id:
            return {
                "type": file_type,
                "transfer_method": "local_file",
                "upload_file_id": upload_file_id,
            }
        if remote_url:
            return {
                "type": file_type,
                "transfer_method": "remote_url",
                "url": remote_url,
            }
        return {}

    @staticmethod
    def decode_data_url(data_url: str) -> Tuple[bytes, str]:
        """Decode a base64 ``data:`` URL into ``(bytes, mime_type)``.

        Returns ``(b"", "")`` when the input is not a data URL and
        ``(b"", mime_type)`` when the payload is not valid base64.
        """
        raw = str(data_url or "").strip()
        if not raw.startswith("data:") or "," not in raw:
            return b"", ""
        header, encoded = raw.split(",", 1)
        mime_type = header[5:].split(";", 1)[0].strip()
        try:
            return base64.b64decode(encoded), mime_type
        except (ValueError, binascii.Error):
            return b"", mime_type

    def generate(
        self,
        prompt: str = "",
        user: str = "",
        inputs: Optional[Dict[str, Any]] = None,
        tag: str = "",
        system_prompt: str = "",
        user_prompt: str = "",
        image_urls: Optional[List[str]] = None,
        files: Optional[List[Dict[str, Any]]] = None,
    ) -> Optional[Dict[str, Any]]:
        """Dispatch one generation request to the configured provider.

        Returns:
            ``{"text", "usage", "raw"}`` on success, otherwise ``None`` with
            the reason in ``self.last_error``.  Every call (successful or
            not) is recorded in the runtime metrics window and logged at
            debug level.
        """
        started_at = time.monotonic()
        self.last_error = ""
        result: Optional[Dict[str, Any]] = None
        current_trace_id = get_current_trace_id()
        if not self.is_available():
            self.last_error = "client_unavailable"
        elif self.provider == "dify":
            result = self._generate_dify(
                prompt=prompt,
                user=user,
                inputs=inputs or {},
                tag=tag,
                files=files or [],
            )
        elif self.provider == "openai_compatible":
            result = self._generate_openai(
                system_prompt=system_prompt,
                user_prompt=user_prompt or prompt,
                user=user,
                image_urls=image_urls or [],
            )
        else:
            self.last_error = f"unsupported_provider:{self.provider}"

        # Record the runtime snapshot once at the common exit so providers do
        # not each need their own instrumentation.
        usage = (result or {}).get("usage", {}) if isinstance(result, dict) else {}
        latency_ms = 0.0
        if isinstance(usage, dict) and usage.get("latency") not in (None, ""):
            # A provider-reported latency is treated as seconds.
            try:
                latency_ms = float(usage.get("latency")) * 1000
            except Exception:
                latency_ms = 0.0
        if latency_ms <= 0:
            latency_ms = (time.monotonic() - started_at) * 1000
        success = bool(result and result.get("text"))
        self._record_runtime_metric(
            provider=self.provider,
            backend=str(self.config.get("backend", "") or ""),
            scene=str(self.config.get("scene", "") or ""),
            model=self.model or str(self.mode or ""),
            trace_id=current_trace_id,
            success=success,
            latency_ms=latency_ms,
            error=self.last_error,
        )
        # One lightweight trace log at the common exit makes it easy to link
        # "message -> AI call".
        self.LOG.debug(
            f"{format_trace_prefix(current_trace_id)}LLM调用结束 "
            f"provider={self.provider} backend={self.config.get('backend', '') or '-'} "
            f"scene={self.config.get('scene', '') or '-'} "
            f"success={success} latency_ms={round(latency_ms, 2)} "
            f"error={self.last_error or '-'}"
        )
        return result

    # ------------------------------------------------------------------
    # Provider implementations
    # ------------------------------------------------------------------
    def _generate_openai(
        self,
        system_prompt: str,
        user_prompt: str,
        user: str,
        image_urls: List[str],
    ) -> Optional[Dict[str, Any]]:
        """Call an OpenAI-compatible chat endpoint with retries.

        Empty model output is treated like a failure and retried.
        """
        payload = {
            "model": self.model,
            "messages": self._build_messages(system_prompt or self.default_system_prompt, user_prompt, image_urls),
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "user": user,
            "stream": self.stream,
        }
        headers = {"Content-Type": "application/json"}
        if self.api_key:
            headers["Authorization"] = self._build_auth_header(self.api_key)
        url = f"{self.base_url}/{self.endpoint}"
        for attempt in range(1, self.max_retries + 1):
            try:
                if self.stream:
                    text, raw = self._request_openai_stream(url, payload, headers)
                else:
                    text, raw = self._request_openai_json(url, payload, headers)
                if text:
                    return {
                        "text": text,
                        "usage": self._extract_openai_usage(raw),
                        "raw": raw,
                    }
                self.last_error = f"empty_model_output:{self.model}"
            except Exception as exc:
                self.last_error = f"request_failed:attempt_{attempt}:{exc}"
            if attempt < self.max_retries:
                time.sleep(self.retry_delay_seconds * attempt)
        return None

    def _generate_dify(
        self,
        prompt: str,
        user: str,
        inputs: Dict[str, Any],
        tag: str,
        files: List[Dict[str, Any]],
    ) -> Optional[Dict[str, Any]]:
        """Call a Dify app (workflow / completion / chat mode) with retries."""
        headers = {
            "Authorization": self._build_auth_header(self.api_key),
            "Content-Type": "application/json",
        }
        payload_inputs = dict(inputs or {})
        payload: Dict[str, Any] = {
            "inputs": payload_inputs,
            "response_mode": self.response_mode,
            "user": user,
            "files": files,
        }
        if self.mode == "workflow":
            # Workflows have no top-level query; pass the prompt as an input
            # unless the caller already supplied one.
            if prompt and "query" not in payload_inputs:
                payload_inputs["query"] = prompt
        else:
            payload["query"] = prompt
            if self.mode != "completion":
                # Chat mode: always start a fresh conversation.
                payload["conversation_id"] = ""
        url = f"{self.base_url}/{self.endpoint}"
        for attempt in range(1, self.max_retries + 1):
            try:
                if self.response_mode == "streaming":
                    parsed = self._request_dify_stream(url, payload, headers, tag)
                else:
                    response = requests.post(url, headers=headers, json=payload, timeout=self.timeout_seconds)
                    response.raise_for_status()
                    parsed = self._parse_dify_response(response.json())
                if parsed and parsed.get("text"):
                    return parsed
                self.last_error = f"empty_model_output:{self.mode}"
            except HTTPError as exc:
                # Diagnostics: the exception text of an HTTP error only holds
                # the status code and URL, so append a slice of the response
                # body to make input field/type problems quick to locate.
                response_text = ""
                response_obj = getattr(exc, "response", None)
                if response_obj is not None:
                    try:
                        response_text = str(response_obj.text or "").strip()
                    except Exception:
                        response_text = ""
                response_text = response_text[:500] if response_text else ""
                self.last_error = (
                    f"request_failed:attempt_{attempt}:{exc}"
                    + (f" | response_body={response_text}" if response_text else "")
                )
                self.LOG.warning(f"[UnifiedLLMClient] Dify 请求失败: tag={tag}, attempt={attempt}, error={self.last_error}")
            except Exception as exc:
                self.last_error = f"request_failed:attempt_{attempt}:{exc}"
                self.LOG.warning(f"[UnifiedLLMClient] Dify 请求失败: tag={tag}, attempt={attempt}, error={exc}")
            if attempt < self.max_retries:
                time.sleep(self.retry_delay_seconds * attempt)
        return None

    # ------------------------------------------------------------------
    # Transport helpers
    # ------------------------------------------------------------------
    def _request_openai_json(
        self, url: str, payload: Dict[str, Any], headers: Dict[str, str]
    ) -> Tuple[str, Dict[str, Any]]:
        """POST a non-streaming request; return ``(text, raw_json)``."""
        response = requests.post(url, json=payload, headers=headers, timeout=self.timeout_seconds)
        response.raise_for_status()
        data = response.json()
        return self._extract_openai_text(data), data

    def _request_openai_stream(
        self,
        url: str,
        payload: Dict[str, Any],
        headers: Dict[str, str],
    ) -> Tuple[str, Dict[str, Any]]:
        """POST a streaming request; return ``(text, {"stream_text": text})``."""
        with requests.post(url, json=payload, headers=headers, timeout=self.timeout_seconds, stream=True) as response:
            response.raise_for_status()
            text = self._collect_openai_stream_text(response.iter_content(chunk_size=None))
        return text, {"stream_text": text}

    @classmethod
    def _collect_openai_stream_text(cls, parts: Any) -> str:
        """Assemble the text carried by an OpenAI-style SSE byte stream.

        Args:
            parts: iterable of ``bytes`` chunks in arrival order; chunk
                boundaries may fall anywhere, including inside an event.

        Returns:
            The concatenated delta text, stripped.

        Raises:
            ValueError: a complete event carries data that is not valid JSON
                (``json.JSONDecodeError``), which fails the request.
        """
        chunks: List[str] = []
        buffer = b""
        finished = False
        for part in parts:
            if not part:
                continue
            # Normalizing on the whole buffer also repairs a CRLF that was
            # split across two network chunks.
            buffer = (buffer + part).replace(b"\r\n", b"\n")
            while b"\n\n" in buffer:
                event, buffer = buffer.split(b"\n\n", 1)
                # A blank-line-delimited event is complete, so undecodable
                # bytes are genuinely invalid; replace them instead of
                # re-queueing the event (which would stall the stream).
                text_piece, done = cls._parse_openai_sse_event(event.decode("utf-8", errors="replace"))
                if text_piece:
                    chunks.append(text_piece)
                if done:
                    finished = True
                    break
            if finished:
                # [DONE] ends the stream: stop reading further chunks.
                break
        if not finished and buffer.strip():
            # Final event without a trailing blank line.  It may be a
            # truncated fragment, so parsing is best-effort here.
            try:
                text_piece, _ = cls._parse_openai_sse_event(buffer.decode("utf-8", errors="replace"))
            except ValueError:
                text_piece = ""
            if text_piece:
                chunks.append(text_piece)
        return "".join(chunks).strip()

    def _request_dify_stream(
        self,
        url: str,
        payload: Dict[str, Any],
        headers: Dict[str, str],
        tag: str,
    ) -> Optional[Dict[str, Any]]:
        """POST a streaming Dify request and reduce the events to one result."""
        with requests.post(url, headers=headers, json=payload, timeout=self.timeout_seconds, stream=True) as response:
            response.raise_for_status()
            # Read raw bytes and decode as UTF-8 ourselves instead of relying
            # on the encoding guessed from the response headers.
            return self._consume_dify_stream(response.iter_lines(), tag)

    def _consume_dify_stream(self, lines: Any, tag: str) -> Optional[Dict[str, Any]]:
        """Reduce Dify SSE lines to ``{"text", "usage", "raw"}`` or ``None``.

        Args:
            lines: iterable of SSE lines (``bytes`` or ``str``).
            tag: caller-supplied label, only used in the warning log.

        The payload of the last ``workflow_finished`` / ``message_end`` event
        is preferred; when it carries no text, the streamed fragments are
        joined instead (keeping the usage of the final event when present).
        """
        event_name = ""
        text_fragments: List[str] = []
        final_payload = None
        for raw_line in lines:
            if not raw_line:
                continue
            if isinstance(raw_line, bytes):
                line = raw_line.decode("utf-8", errors="replace").strip()
            else:
                line = str(raw_line).strip()
            if not line:
                continue
            if line.startswith("event:"):
                event_name = line[6:].strip()
                continue
            if not line.startswith("data:"):
                continue
            data_text = line[5:].strip()
            if not data_text or data_text == "[DONE]":
                continue
            try:
                chunk = json.loads(data_text)
            except ValueError:
                continue
            if not isinstance(chunk, dict):
                # Only JSON objects are meaningful events.
                continue
            candidate_text = self._extract_dify_stream_text(chunk)
            if candidate_text:
                text_fragments.append(candidate_text)
            chunk_event = str(chunk.get("event") or event_name or "").strip()
            if chunk_event in {"workflow_finished", "message_end"}:
                final_payload = chunk
        parsed = self._parse_dify_response(final_payload) if final_payload else None
        if parsed and parsed.get("text"):
            return parsed
        text = "".join(text_fragments).strip()
        if text:
            usage = (parsed or {}).get("usage") or {}
            return {"text": text, "usage": usage, "raw": final_payload or {}}
        self.LOG.warning(f"[UnifiedLLMClient] Dify 流式响应未产出有效内容: tag={tag}")
        return None

    # ------------------------------------------------------------------
    # Payload builders / parsers
    # ------------------------------------------------------------------
    @staticmethod
    def _build_messages(system_prompt: str, user_prompt: str, image_urls: List[str]) -> List[Dict[str, Any]]:
        """Build the ``messages`` list; images switch the user content to parts."""
        if image_urls:
            user_content: Any = [{"type": "text", "text": user_prompt}]
            user_content.extend(
                {"type": "image_url", "image_url": {"url": image_url}}
                for image_url in image_urls
                if image_url
            )
        else:
            user_content = user_prompt
        messages: List[Dict[str, Any]] = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": user_content})
        return messages

    @staticmethod
    def _extract_openai_text(data: Dict[str, Any]) -> str:
        """Pull the answer text out of a non-streaming completion response.

        Lookup order: first choice's ``message.content`` (string or list of
        parts), then ``reasoning_content`` / ``text`` / ``output_text`` on
        the message, then ``output_text`` / ``text`` / ``answer`` /
        ``response`` at the top level.  Returns ``""`` when nothing is found.
        """
        if not isinstance(data, dict):
            return ""
        choices = data.get("choices") or []
        if isinstance(choices, list) and choices:
            first = choices[0] if isinstance(choices[0], dict) else {}
            message = first.get("message", {}) or {}
            if not isinstance(message, dict):
                message = {}
            content = message.get("content")
            if isinstance(content, str) and content.strip():
                return content.strip()
            if isinstance(content, list):
                parts = []
                for item in content:
                    if isinstance(item, dict):
                        text = item.get("text") or item.get("content")
                        if isinstance(text, str) and text.strip():
                            parts.append(text.strip())
                if parts:
                    return "\n".join(parts).strip()
            for key in ("reasoning_content", "text", "output_text"):
                value = message.get(key)
                if isinstance(value, str) and value.strip():
                    return value.strip()
        for key in ("output_text", "text", "answer", "response"):
            value = data.get(key)
            if isinstance(value, str) and value.strip():
                return value.strip()
        return ""

    @classmethod
    def _parse_openai_sse_event(cls, event_text: str) -> Tuple[str, bool]:
        """Parse one SSE event into ``(delta_text, is_done)``.

        Events without ``data:`` lines (comments, keep-alives) yield
        ``("", False)``; the ``[DONE]`` sentinel yields ``("", True)``.

        Raises:
            ValueError: the data is not valid JSON.
        """
        lines = [line.strip() for line in event_text.splitlines() if line.strip()]
        data_lines = [line[5:].strip() for line in lines if line.startswith("data:")]
        if not data_lines:
            return "", False
        data = "\n".join(data_lines)
        if data == "[DONE]":
            return "", True
        obj = json.loads(data)
        if not isinstance(obj, dict):
            return "", False
        choice = (obj.get("choices") or [{}])[0]
        delta = (choice.get("delta") if isinstance(choice, dict) else None) or {}
        content = delta.get("content") if isinstance(delta, dict) else None
        if isinstance(content, str):
            return content, False
        if isinstance(content, list):
            parts = []
            for item in content:
                if isinstance(item, dict):
                    text = item.get("text") or item.get("content")
                    if isinstance(text, str):
                        parts.append(text)
            return "".join(parts), False
        return "", False

    def _parse_dify_response(self, data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Parse a blocking (or final streaming) Dify payload for the current mode."""
        if self.mode == "workflow":
            return self._parse_dify_workflow_response(data)
        body = data if isinstance(data, dict) else {}
        answer = str(body.get("answer", "") or "").strip()
        metadata = body.get("metadata") or {}
        usage = (metadata.get("usage", {}) if isinstance(metadata, dict) else {}) or {}
        return {"text": answer, "usage": usage, "raw": data}

    def _parse_dify_workflow_response(self, data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Extract text and usage from a workflow run payload.

        The configured ``workflow_output_key`` is tried first, then a few
        common keys, then the first output value that stringifies to
        something non-empty.
        """
        payload = (data.get("data") if isinstance(data, dict) else None) or {}
        if not isinstance(payload, dict):
            payload = {}
        outputs = payload.get("outputs", {}) or {}
        if not isinstance(outputs, dict):
            outputs = {}
        text = ""
        for key in filter(None, [self.workflow_output_key, "text", "answer", "result_json", "result"]):
            if outputs.get(key) is not None:
                text = self._stringify_output(outputs.get(key))
                if text:
                    break
        if not text:
            for value in outputs.values():
                text = self._stringify_output(value)
                if text:
                    break
        usage = {
            "total_tokens": payload.get("total_tokens"),
            "latency": payload.get("elapsed_time"),
        }
        return {"text": text.strip(), "usage": usage, "raw": data}

    def _extract_dify_stream_text(self, chunk: Dict[str, Any]) -> str:
        """Return the text fragment carried by one streamed Dify event.

        String fragments are returned verbatim (not stripped) so that spaces
        between streamed tokens survive the final join.
        """
        if not isinstance(chunk, dict):
            return ""
        payload = chunk.get("data") if isinstance(chunk.get("data"), dict) else {}
        outputs = payload.get("outputs") if isinstance(payload.get("outputs"), dict) else {}
        for key in filter(None, [self.workflow_output_key, "text", "answer", "result_json", "result"]):
            if outputs.get(key) is not None:
                return self._stringify_fragment(outputs.get(key))
        for key in ("text", "answer"):
            if chunk.get(key) is not None:
                return self._stringify_fragment(chunk.get(key))
        return ""

    @classmethod
    def _stringify_fragment(cls, value: Any) -> str:
        """Like ``_stringify_output`` but keeps whitespace of string fragments."""
        if isinstance(value, str):
            return value
        return cls._stringify_output(value)

    @staticmethod
    def _extract_openai_usage(data: Dict[str, Any]) -> Dict[str, Any]:
        """Return the ``usage`` object of a response, or ``{}`` when absent."""
        return data.get("usage", {}) or {}

    @staticmethod
    def _stringify_output(value: Any) -> str:
        """Convert an output value to text (dicts/lists become JSON)."""
        if value is None:
            return ""
        if isinstance(value, str):
            return value.strip()
        if isinstance(value, (dict, list)):
            return json.dumps(value, ensure_ascii=False)
        return str(value).strip()

    # ------------------------------------------------------------------
    # Configuration
    # ------------------------------------------------------------------
    @classmethod
    def _normalize_config(cls, config: Dict[str, Any]) -> Dict[str, Any]:
        """Resolve *config* through ``LLMRegistry`` and fill in canonical keys.

        Accepted aliases: ``enable`` for ``enabled``; ``api_url`` / ``url``
        (split into base URL and endpoint); ``api_base_url`` for
        ``base_url``; ``api-key`` / ``authorization`` for ``api_key``;
        ``request_timeout_seconds`` / ``request_timeout`` for
        ``timeout_seconds``.
        """
        # Work on a copy so normalization never mutates the dict returned by
        # the registry (which may be the caller's own config object).
        normalized = dict(LLMRegistry.resolve(config or {}) or {})
        normalized["enabled"] = bool(
            normalized.get("enabled", normalized.get("enable", True))
        )
        if not normalized.get("provider"):
            normalized["provider"] = cls._guess_provider(normalized)
        parsed_url = cls._split_url(
            normalized.get("api_url")
            or normalized.get("url")
        )
        base_url = (
            normalized.get("base_url")
            or normalized.get("api_base_url")
            or parsed_url[0]
            or ""
        )
        endpoint = (
            normalized.get("endpoint")
            or parsed_url[1]
            or ""
        )
        normalized["base_url"] = str(base_url).rstrip("/")
        normalized["endpoint"] = str(endpoint).lstrip("/")
        normalized["api_key"] = (
            normalized.get("api_key")
            or normalized.get("api-key")
            or normalized.get("authorization")
            or ""
        )
        normalized["timeout_seconds"] = int(
            normalized.get("timeout_seconds")
            or normalized.get("request_timeout_seconds")
            or normalized.get("request_timeout")
            or 60
        )
        normalized["max_retries"] = cls._resolve_max_retries(normalized)
        retry_delay = normalized.get("retry_delay_seconds")
        normalized["retry_delay_seconds"] = float(1.0 if retry_delay is None else retry_delay)
        normalized["response_mode"] = normalized.get("response_mode") or "blocking"
        normalized["workflow_output_key"] = normalized.get("workflow_output_key", "text")
        # Compare case-insensitively: __init__ lowercases the provider too.
        if str(normalized["provider"]).strip().lower() == "dify":
            if not normalized["endpoint"]:
                normalized["endpoint"] = cls._guess_dify_endpoint(normalized)
        elif not normalized["endpoint"]:
            normalized["endpoint"] = "chat/completions"
        return normalized

    @staticmethod
    def _resolve_max_retries(config: Dict[str, Any]) -> int:
        """Return the number of attempts implied by *config*.

        An explicit ``max_retries`` wins; otherwise a non-empty
        ``retry_delays_seconds`` list implies ``len(list) + 1`` attempts;
        otherwise the default is 3.
        """
        explicit = config.get("max_retries")
        if explicit is not None:
            return int(explicit)
        delays = config.get("retry_delays_seconds")
        if isinstance(delays, (list, tuple)) and delays:
            return len(delays) + 1
        return 3

    @staticmethod
    def _guess_provider(config: Dict[str, Any]) -> str:
        """Guess ``"dify"`` vs ``"openai_compatible"`` from URL, key and mode."""
        api_key = str(
            config.get("api_key")
            or config.get("api-key")
            or config.get("authorization")
            or ""
        ).strip()
        url = str(config.get("api_url") or config.get("url") or config.get("endpoint") or "").lower()
        mode = str(config.get("mode", "")).lower()
        if "workflows/run" in url or "chat-messages" in url or "completion-messages" in url:
            return "dify"
        if api_key.startswith("app-") or mode in {"workflow", "completion"}:
            return "dify"
        return "openai_compatible"

    @staticmethod
    def _guess_dify_endpoint(config: Dict[str, Any]) -> str:
        """Return the default Dify endpoint path for the configured mode."""
        mode = str(config.get("mode", "chat")).strip().lower()
        if mode == "workflow":
            return "workflows/run"
        if mode == "completion":
            return "completion-messages"
        return "chat-messages"

    @staticmethod
    def _split_url(url: Optional[str]) -> Tuple[str, str]:
        """Split a URL into ``(scheme://host, path)``.

        A value without scheme or host is treated as a bare endpoint and
        returned as ``("", value)``.  Query string and fragment are dropped.
        """
        if not url:
            return "", ""
        parsed = urlparse(str(url))
        if not parsed.scheme or not parsed.netloc:
            return "", str(url)
        base = f"{parsed.scheme}://{parsed.netloc}"
        return base, parsed.path.lstrip("/")

    @staticmethod
    def _build_auth_header(value: str) -> str:
        """Return an ``Authorization`` value, adding ``Bearer `` if missing."""
        token = str(value or "").strip()
        if not token:
            return ""
        if token.lower().startswith("bearer "):
            return token
        return f"Bearer {token}"

    @staticmethod
    def _stringify_inputs(inputs: Dict[str, Any]) -> str:
        """Serialize *inputs* to JSON, falling back to ``str()`` if that fails."""
        if not inputs:
            return ""
        try:
            return json.dumps(inputs, ensure_ascii=False)
        except Exception:
            return str(inputs)