# (File-viewer metadata captured with this listing: 200 lines, 7.5 KiB, Python.)
from __future__ import annotations

import json
import time
from typing import Dict, Iterable, Iterator, List, Optional

import requests


class LLMClient:
    """Thin client for chat-completion style LLM HTTP APIs.

    Only the ``"openai_compatible"`` provider is implemented: a JSON payload is
    POSTed to ``{api_base_url}/{endpoint}`` and the reply text is read either
    from a regular JSON body or from a Server-Sent-Events stream.

    Network, HTTP and response-parsing failures are caught inside :meth:`chat`;
    it then returns ``""`` and leaves a short machine-readable reason in
    :attr:`last_error`.
    """

    # 4xx statuses that can succeed on a later attempt (request timeout, rate
    # limiting). Any other 4xx means the request itself was rejected, so
    # retrying it only adds latency.
    _RETRYABLE_CLIENT_STATUSES = frozenset({408, 429})

    def __init__(self, config: Optional[Dict]):
        """Read connection and sampling settings from ``config``.

        Recognised keys (default in parentheses): ``provider``
        ("openai_compatible"), ``api_base_url`` (""), ``endpoint``
        ("chat/completions"), ``api_key`` (""), ``model`` (""),
        ``timeout_seconds`` (45), ``temperature`` (0.7), ``max_tokens`` (500),
        ``stream`` (True), ``max_retries`` (3 total attempts) and
        ``retry_delay_seconds`` (1.0). ``None`` is treated as an empty config.
        """
        self.config = config or {}
        self.provider = self.config.get("provider", "openai_compatible")
        # Normalised so the URL can always be joined as f"{base}/{endpoint}".
        self.base_url = str(self.config.get("api_base_url", "")).rstrip("/")
        self.endpoint = str(self.config.get("endpoint", "chat/completions")).lstrip("/")
        self.api_key = self.config.get("api_key", "")
        self.model = self.config.get("model", "")
        self.timeout_seconds = int(self.config.get("timeout_seconds", 45))
        self.temperature = float(self.config.get("temperature", 0.7))
        self.max_tokens = int(self.config.get("max_tokens", 500))
        self.stream = bool(self.config.get("stream", True))
        # Total attempts including the first one, never below 1. An explicit 0
        # is clamped to a single attempt instead of silently becoming 3.
        self.max_retries = max(int(self._setting("max_retries", 3)), 1)
        # Base back-off between attempts: 0 disables waiting; negatives are
        # clamped because time.sleep() raises ValueError for them.
        self.retry_delay_seconds = max(float(self._setting("retry_delay_seconds", 1.0)), 0.0)
        # Reason for the most recent empty result; "" after a success.
        self.last_error = ""

    def _setting(self, key: str, default):
        """Return ``config[key]``, or ``default`` when missing, ``None`` or ``""``."""
        value = self.config.get(key)
        return default if value is None or value == "" else value

    def chat(
        self,
        system_prompt: str,
        user_prompt: str,
        user_id: str,
        image_urls: Optional[List[str]] = None,
    ) -> str:
        """Send one system + user exchange and return the model's reply text.

        Args:
            system_prompt: Content of the system message.
            user_prompt: Text of the user message.
            user_id: Forwarded as the payload's ``user`` field.
            image_urls: Optional image URLs attached to the user message.

        Returns:
            The stripped reply text, or ``""`` on failure, in which case
            :attr:`last_error` describes why.
        """
        self.last_error = ""
        if not self.base_url:
            self.last_error = "empty_base_url"
            return ""
        if self.provider == "openai_compatible":
            return self._chat_openai_compatible(system_prompt, user_prompt, user_id, image_urls or [])
        self.last_error = f"unsupported_provider:{self.provider}"
        return ""

    def _chat_openai_compatible(
        self,
        system_prompt: str,
        user_prompt: str,
        user_id: str,
        image_urls: List[str],
    ) -> str:
        """Build the request and run it with retries; return reply text or ``""``.

        An attempt is repeated (with a linearly growing delay) when it raises
        or yields empty text, except for HTTP 4xx errors that a retry cannot
        fix, which stop the loop immediately.
        """
        if not self.model:
            self.last_error = "empty_model"
            return ""

        payload = {
            "model": self.model,
            "messages": self._build_messages(system_prompt, user_prompt, image_urls),
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "user": user_id,
        }
        if self.stream:
            payload["stream"] = True
        headers = {"Content-Type": "application/json"}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"

        send = self._chat_streaming if self.stream else self._chat_non_streaming
        for attempt in range(1, self.max_retries + 1):
            try:
                text = send(payload, headers)
            except requests.HTTPError as exc:
                self.last_error = f"request_failed:attempt_{attempt}:{exc}"
                status = getattr(exc.response, "status_code", None)
                if not self._is_retryable_status(status):
                    break
            except Exception as exc:  # best-effort boundary: record and retry
                self.last_error = f"request_failed:attempt_{attempt}:{exc}"
            else:
                if text:
                    # Drop any error left behind by an earlier failed attempt.
                    self.last_error = ""
                    return text
            if attempt < self.max_retries:
                time.sleep(self.retry_delay_seconds * attempt)
        return ""

    @classmethod
    def _is_retryable_status(cls, status: Optional[int]) -> bool:
        """Return False for 4xx statuses a retry cannot fix, True otherwise."""
        if status is None or not 400 <= status < 500:
            return True
        return status in cls._RETRYABLE_CLIENT_STATUSES

    def _endpoint_url(self) -> str:
        """Full URL of the configured chat endpoint."""
        return f"{self.base_url}/{self.endpoint}"

    def _chat_non_streaming(self, payload: Dict, headers: Dict[str, str]) -> str:
        """POST ``payload`` and extract the reply from the JSON response body.

        Raises whatever ``requests`` raises (including ``HTTPError`` from
        ``raise_for_status``) or ``response.json()`` raises on a bad body.
        """
        response = requests.post(
            self._endpoint_url(),
            json=payload,
            headers=headers,
            timeout=self.timeout_seconds,
        )
        response.raise_for_status()
        text = self._extract_text(response.json())
        if text:
            return text
        self.last_error = f"empty_model_output:{self.model}"
        return ""

    def _chat_streaming(self, payload: Dict, headers: Dict[str, str]) -> str:
        """POST ``payload`` with ``stream=True`` and join the SSE delta texts.

        Reading stops at the ``[DONE]`` sentinel or when the server closes the
        stream; the response is closed on exit from the ``with`` block.
        """
        chunks: List[str] = []
        with requests.post(
            self._endpoint_url(),
            json=payload,
            headers=headers,
            timeout=self.timeout_seconds,
            stream=True,
        ) as response:
            response.raise_for_status()
            for event_text in self._iter_sse_events(response.iter_content(chunk_size=None)):
                text_piece, done = self._parse_sse_event(event_text)
                if text_piece:
                    chunks.append(text_piece)
                if done:
                    break

        final_text = "".join(chunks).strip()
        if final_text:
            return final_text
        self.last_error = f"empty_stream_output:{self.model}"
        return ""

    @staticmethod
    def _iter_sse_events(raw_chunks: Iterable[bytes]) -> Iterator[str]:
        """Group a raw SSE byte stream into event blocks.

        Lines may end in LF or CRLF; a blank line ends an event, and each
        event is yielded as its non-blank lines joined with ``"\\n"``. A final
        event that the server did not terminate with a blank line is still
        yielded. Only complete lines are decoded, so a multi-byte UTF-8
        character split across network chunks is reassembled; bytes that are
        not valid UTF-8 become U+FFFD instead of stalling the stream.
        """
        pending = b""
        lines: List[str] = []
        for part in raw_chunks:
            if not part:
                continue
            pending += part
            # Everything before the last b"\n" is a complete line.
            *complete, pending = pending.split(b"\n")
            for raw_line in complete:
                line = raw_line.rstrip(b"\r").decode("utf-8", errors="replace")
                if line:
                    lines.append(line)
                elif lines:
                    yield "\n".join(lines)
                    lines = []
        tail = pending.rstrip(b"\r")
        if tail:
            lines.append(tail.decode("utf-8", errors="replace"))
        if lines:
            yield "\n".join(lines)

    @staticmethod
    def _build_messages(system_prompt: str, user_prompt: str, image_urls: List[str]) -> List[Dict]:
        """Return the ``[system, user]`` message list for the request payload.

        With image URLs, the user content becomes a list of parts: the text
        first, then one ``image_url`` part per non-empty URL. Otherwise it is
        the plain prompt string.
        """
        user_content: str | List[Dict[str, object]]
        if image_urls:
            content_parts: List[Dict[str, object]] = [{"type": "text", "text": user_prompt}]
            for image_url in image_urls:
                if image_url:
                    content_parts.append({"type": "image_url", "image_url": {"url": image_url}})
            user_content = content_parts
        else:
            user_content = user_prompt
        return [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_content},
        ]

    @staticmethod
    def _extract_text(data: Dict) -> str:
        """Pull the reply text out of a non-streaming response body.

        Tries, in order:
        1. ``choices[0].message.content`` (a string, or a list of parts with
           ``text``/``content``).
        2. That message's ``reasoning_content``/``text``/``output_text``.
        3. Top-level ``output_text``/``text``/``answer``/``response``.

        Returns the first non-blank match, stripped, or ``""``.
        """
        choices = data.get("choices") or []
        if choices:
            message = choices[0].get("message", {}) or {}
            content = message.get("content")
            if isinstance(content, str) and content.strip():
                return content.strip()
            if isinstance(content, list):
                parts = []
                for item in content:
                    if isinstance(item, dict):
                        text = item.get("text") or item.get("content")
                        if isinstance(text, str) and text.strip():
                            parts.append(text.strip())
                if parts:
                    return "\n".join(parts).strip()
            for key in ("reasoning_content", "text", "output_text"):
                value = message.get(key)
                if isinstance(value, str) and value.strip():
                    return value.strip()
        for key in ("output_text", "text", "answer", "response"):
            value = data.get(key)
            if isinstance(value, str) and value.strip():
                return value.strip()
        return ""

    @classmethod
    def _parse_sse_event(cls, event_text: str) -> tuple[str, bool]:
        """Decode one SSE event into ``(text_piece, done)``.

        Only ``data:`` lines are considered.

        - Comment/keep-alive events, empty ``data`` payloads and non-object
          JSON payloads yield ``("", False)``.
        - The ``[DONE]`` sentinel yields ``("", True)``.
        - Otherwise the text comes from ``choices[0].delta.content``, either
          a string or a list of parts with ``text``/``content`` keys.

        Raises:
            json.JSONDecodeError: if a non-empty data payload is not JSON.
        """
        lines = [line.strip() for line in event_text.splitlines() if line.strip()]
        data_lines = [line[5:].strip() for line in lines if line.startswith("data:")]
        data = "\n".join(data_lines).strip()
        if not data:
            return "", False
        if data == "[DONE]":
            return "", True
        obj = json.loads(data)
        if not isinstance(obj, dict):
            return "", False
        choice = (obj.get("choices") or [{}])[0] or {}
        delta = choice.get("delta") or {}
        content = delta.get("content")
        if isinstance(content, str):
            return content, False
        if isinstance(content, list):
            parts = []
            for item in content:
                if isinstance(item, dict):
                    text = item.get("text") or item.get("content")
                    if isinstance(text, str):
                        parts.append(text)
            return "".join(parts), False
        return "", False