Files
abot/plugins/ai_auto_response/llm_client.py
2026-04-07 13:51:15 +08:00

188 lines
7.0 KiB
Python

from __future__ import annotations

import json
from typing import Dict, Iterable, Iterator, List, Optional

import requests
class LLMClient:
    """Small HTTP client for OpenAI-compatible chat-completion endpoints.

    ``chat`` reports request and response failures by returning "" and storing
    a short machine-readable reason in ``last_error`` rather than raising.
    """

    def __init__(self, config: Optional[Dict]):
        """Read connection and sampling settings from ``config`` (may be None).

        Recognized keys and defaults: ``provider`` ("openai_compatible"),
        ``api_base_url`` (""), ``endpoint`` ("chat/completions"), ``api_key``
        (""), ``model`` (""), ``timeout_seconds`` (45), ``temperature`` (0.7),
        ``max_tokens`` (500), ``stream`` (True).
        """
        self.config = config or {}
        self.provider = self.config.get("provider", "openai_compatible")
        # Trim slashes so the URL can always be built as f"{base_url}/{endpoint}".
        self.base_url = str(self.config.get("api_base_url", "")).rstrip("/")
        self.endpoint = str(self.config.get("endpoint", "chat/completions")).lstrip("/")
        self.api_key = self.config.get("api_key", "")
        self.model = self.config.get("model", "")
        self.timeout_seconds = int(self.config.get("timeout_seconds", 45))
        self.temperature = float(self.config.get("temperature", 0.7))
        self.max_tokens = int(self.config.get("max_tokens", 500))
        self.stream = self._as_bool(self.config.get("stream", True))
        # Reason the most recent chat() call returned ""; reset at each call.
        self.last_error = ""

    @staticmethod
    def _as_bool(value: object) -> bool:
        """Coerce a config flag to bool, reading "false"/"0"/"no"/"off" as False.

        Plain ``bool(value)`` turns any non-empty string, including "false"
        coming from a text-based config, into True.
        """
        if isinstance(value, str):
            return value.strip().lower() not in ("", "0", "false", "no", "off")
        return bool(value)

    def chat(
        self,
        system_prompt: str,
        user_prompt: str,
        user_id: str,
        image_urls: Optional[List[str]] = None,
    ) -> str:
        """Send one system + user exchange and return the assistant reply.

        Args:
            system_prompt: Content of the system message.
            user_prompt: Text of the user message.
            user_id: Forwarded as the request's ``user`` field.
            image_urls: Optional image URLs attached to the user message.

        Returns:
            The stripped reply text, or "" on failure (reason in ``last_error``).
        """
        self.last_error = ""
        if not self.base_url:
            self.last_error = "empty_base_url"
            return ""
        if self.provider == "openai_compatible":
            return self._chat_openai_compatible(system_prompt, user_prompt, user_id, image_urls or [])
        self.last_error = f"unsupported_provider:{self.provider}"
        return ""

    def _chat_openai_compatible(
        self,
        system_prompt: str,
        user_prompt: str,
        user_id: str,
        image_urls: List[str],
    ) -> str:
        """POST a chat-completion request (streaming or not); return reply or ""."""
        if not self.model:
            # Record the cause so callers can tell this apart from an empty reply.
            self.last_error = "empty_model"
            return ""
        payload: Dict[str, object] = {
            "model": self.model,
            "messages": self._build_messages(system_prompt, user_prompt, image_urls),
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "user": user_id,
        }
        if self.stream:
            payload["stream"] = True
        headers = {
            "Content-Type": "application/json",
        }
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        try:
            if self.stream:
                return self._chat_streaming(payload, headers)
            response = requests.post(
                f"{self.base_url}/{self.endpoint}",
                json=payload,
                headers=headers,
                timeout=self.timeout_seconds,
            )
            response.raise_for_status()
            text = self._extract_text(response.json())
            if text:
                return text
            self.last_error = f"empty_model_output:{self.model}"
            return ""
        except Exception as exc:
            # API boundary: report any failure through last_error instead of raising.
            self.last_error = f"request_failed:{exc}"
            return ""

    def _chat_streaming(self, payload: Dict, headers: Dict[str, str]) -> str:
        """POST with ``stream=True`` and concatenate the streamed content deltas.

        Reading stops at the ``[DONE]`` sentinel or when the server closes the
        stream. A malformed event is skipped so text already received is kept;
        if no text arrives, ``last_error`` records the last parse problem or
        ``empty_stream_output``.
        """
        chunks: List[str] = []
        bad_event = ""
        # requests applies ``timeout`` to connect and to each socket read, not to
        # the total duration of the stream.
        with requests.post(
            f"{self.base_url}/{self.endpoint}",
            json=payload,
            headers=headers,
            timeout=self.timeout_seconds,
            stream=True,
        ) as response:
            response.raise_for_status()
            for event_text in self._iter_sse_events(response.iter_content(chunk_size=None)):
                try:
                    text_piece, done = self._parse_sse_event(event_text)
                except ValueError as exc:
                    bad_event = f"bad_stream_event:{exc}"
                    continue
                if text_piece:
                    chunks.append(text_piece)
                if done:
                    break
        final_text = "".join(chunks).strip()
        if final_text:
            return final_text
        self.last_error = bad_event or f"empty_stream_output:{self.model}"
        return ""

    @staticmethod
    def _iter_sse_events(parts: Iterable[bytes]) -> Iterator[str]:
        """Yield decoded SSE event blocks from an iterable of raw byte chunks.

        Events are separated by a blank line; CRLF line endings are normalized
        to LF first so "\\r\\n\\r\\n" separators are recognized too. A trailing
        event with no terminating blank line is still yielded at end of input.
        """
        buffer = b""
        for part in parts:
            if not part:
                continue
            # Normalizing after appending also joins a "\r" / "\n" pair that was
            # split across two chunks.
            buffer = (buffer + part).replace(b"\r\n", b"\n")
            while b"\n\n" in buffer:
                event, buffer = buffer.split(b"\n\n", 1)
                # "\n" never occurs inside a multi-byte UTF-8 sequence, so a
                # complete event only fails to decode on genuinely bad bytes;
                # replace them instead of stalling the stream.
                yield event.decode("utf-8", errors="replace")
        if buffer.strip():
            yield buffer.decode("utf-8", errors="replace")

    @staticmethod
    def _build_messages(system_prompt: str, user_prompt: str, image_urls: List[str]) -> List[Dict]:
        """Build the chat message list: a system message, then a user message.

        If at least one image URL is non-empty, the user content is a parts list
        (the text part first, then one ``image_url`` part per URL). Otherwise it
        is the plain ``user_prompt`` string, so no image-less parts list is sent.
        """
        valid_urls = [url for url in image_urls if url]
        user_content: str | List[Dict[str, object]]
        if valid_urls:
            user_content = [{"type": "text", "text": user_prompt}] + [
                {"type": "image_url", "image_url": {"url": url}} for url in valid_urls
            ]
        else:
            user_content = user_prompt
        return [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_content},
        ]

    @staticmethod
    def _extract_text(data: object) -> str:
        """Return the reply text from a non-streaming completion response, or "".

        Lookup order:

        1. ``choices[0].message.content``, either a string or a list of parts
           whose ``text``/``content`` values are joined with newlines.
        2. The message's ``reasoning_content``, ``text``, ``output_text``.
        3. The top-level ``output_text``, ``text``, ``answer``, ``response`` keys.

        Only non-blank strings count, and results are stripped. Unexpected
        shapes (a non-dict body, choice or message) yield "" instead of raising.
        """
        if not isinstance(data, dict):
            return ""
        choices = data.get("choices")
        first = choices[0] if isinstance(choices, list) and choices else None
        if isinstance(first, dict):
            message = first.get("message")
            if not isinstance(message, dict):
                message = {}
            content = message.get("content")
            if isinstance(content, str) and content.strip():
                return content.strip()
            if isinstance(content, list):
                parts = []
                for item in content:
                    if not isinstance(item, dict):
                        continue
                    text = item.get("text") or item.get("content")
                    if isinstance(text, str) and text.strip():
                        parts.append(text.strip())
                if parts:
                    return "\n".join(parts)
            # NOTE(review): reasoning_content is accepted as the reply when no
            # content is present; the streaming path does not do this.
            for key in ("reasoning_content", "text", "output_text"):
                value = message.get(key)
                if isinstance(value, str) and value.strip():
                    return value.strip()
        for key in ("output_text", "text", "answer", "response"):
            value = data.get(key)
            if isinstance(value, str) and value.strip():
                return value.strip()
        return ""

    @classmethod
    def _parse_sse_event(cls, event_text: str) -> tuple[str, bool]:
        """Parse one SSE event block into ``(text_delta, is_done)``.

        Only ``data:`` lines are read; comments such as ": keep-alive" and other
        fields are ignored. Multiple data lines are joined with newlines, and a
        payload of ``[DONE]`` marks the end of the stream. Payloads that are
        valid JSON but lack a dict choice or delta yield ``("", False)``.

        Raises:
            ValueError: if the data payload is not valid JSON.
        """
        lines = [line.strip() for line in event_text.splitlines() if line.strip()]
        data_lines = [line[5:].strip() for line in lines if line.startswith("data:")]
        if not data_lines:
            return "", False
        data = "\n".join(data_lines)
        if data == "[DONE]":
            return "", True
        obj = json.loads(data)
        if not isinstance(obj, dict):
            return "", False
        choices = obj.get("choices") or [{}]
        choice = choices[0] if isinstance(choices, list) else None
        if not isinstance(choice, dict):
            return "", False
        delta = choice.get("delta") or {}
        if not isinstance(delta, dict):
            return "", False
        # NOTE(review): unlike _extract_text, reasoning_content deltas are not
        # collected here; confirm this asymmetry is intended.
        content = delta.get("content")
        if isinstance(content, str):
            return content, False
        if isinstance(content, list):
            parts = []
            for item in content:
                if isinstance(item, dict):
                    text = item.get("text") or item.get("content")
                    if isinstance(text, str):
                        parts.append(text)
            return "".join(parts), False
        return "", False