refactor ai_auto_response into xiaoniu group bot
This commit is contained in:
166
plugins/ai_auto_response/llm_client.py
Normal file
166
plugins/ai_auto_response/llm_client.py
Normal file
@@ -0,0 +1,166 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Dict, List
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
class LLMClient:
|
||||
def __init__(self, config: Dict):
|
||||
self.config = config or {}
|
||||
self.provider = self.config.get("provider", "openai_compatible")
|
||||
self.base_url = str(self.config.get("api_base_url", "")).rstrip("/")
|
||||
self.endpoint = str(self.config.get("endpoint", "chat/completions")).lstrip("/")
|
||||
self.api_key = self.config.get("api_key", "")
|
||||
self.model = self.config.get("model", "")
|
||||
self.timeout_seconds = int(self.config.get("timeout_seconds", 45))
|
||||
self.temperature = float(self.config.get("temperature", 0.7))
|
||||
self.max_tokens = int(self.config.get("max_tokens", 500))
|
||||
self.stream = bool(self.config.get("stream", True))
|
||||
self.last_error = ""
|
||||
|
||||
def chat(self, system_prompt: str, user_prompt: str, user_id: str) -> str:
|
||||
self.last_error = ""
|
||||
if not self.base_url:
|
||||
self.last_error = "empty_base_url"
|
||||
return ""
|
||||
if self.provider == "openai_compatible":
|
||||
return self._chat_openai_compatible(system_prompt, user_prompt, user_id)
|
||||
self.last_error = f"unsupported_provider:{self.provider}"
|
||||
return ""
|
||||
|
||||
def _chat_openai_compatible(self, system_prompt: str, user_prompt: str, user_id: str) -> str:
|
||||
if not self.model:
|
||||
return ""
|
||||
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"messages": self._build_messages(system_prompt, user_prompt),
|
||||
"temperature": self.temperature,
|
||||
"max_tokens": self.max_tokens,
|
||||
"user": user_id,
|
||||
}
|
||||
if self.stream:
|
||||
payload["stream"] = True
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
if self.api_key:
|
||||
headers["Authorization"] = f"Bearer {self.api_key}"
|
||||
|
||||
try:
|
||||
if self.stream:
|
||||
return self._chat_streaming(payload, headers)
|
||||
response = requests.post(
|
||||
f"{self.base_url}/{self.endpoint}",
|
||||
json=payload,
|
||||
headers=headers,
|
||||
timeout=self.timeout_seconds,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
text = self._extract_text(data)
|
||||
if text:
|
||||
return text
|
||||
self.last_error = f"empty_model_output:{self.model}"
|
||||
return ""
|
||||
except Exception as exc:
|
||||
self.last_error = f"request_failed:{exc}"
|
||||
return ""
|
||||
|
||||
def _chat_streaming(self, payload: Dict, headers: Dict[str, str]) -> str:
|
||||
chunks: List[str] = []
|
||||
with requests.post(
|
||||
f"{self.base_url}/{self.endpoint}",
|
||||
json=payload,
|
||||
headers=headers,
|
||||
timeout=self.timeout_seconds,
|
||||
stream=True,
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
buffer = b""
|
||||
for part in response.iter_content(chunk_size=None):
|
||||
if not part:
|
||||
continue
|
||||
buffer += part
|
||||
while b"\n\n" in buffer:
|
||||
event, buffer = buffer.split(b"\n\n", 1)
|
||||
try:
|
||||
event_text = event.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
buffer = event + b"\n\n" + buffer
|
||||
break
|
||||
text_piece, done = self._parse_sse_event(event_text)
|
||||
if text_piece:
|
||||
chunks.append(text_piece)
|
||||
if done:
|
||||
final_text = "".join(chunks).strip()
|
||||
if final_text:
|
||||
return final_text
|
||||
self.last_error = f"empty_stream_output:{self.model}"
|
||||
return ""
|
||||
final_text = "".join(chunks).strip()
|
||||
if final_text:
|
||||
return final_text
|
||||
self.last_error = f"empty_stream_output:{self.model}"
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _build_messages(system_prompt: str, user_prompt: str) -> List[Dict[str, str]]:
|
||||
return [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _extract_text(data: Dict) -> str:
|
||||
choices = data.get("choices") or []
|
||||
if choices:
|
||||
message = choices[0].get("message", {}) or {}
|
||||
content = message.get("content")
|
||||
if isinstance(content, str) and content.strip():
|
||||
return content.strip()
|
||||
if isinstance(content, list):
|
||||
parts = []
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
text = item.get("text") or item.get("content")
|
||||
if isinstance(text, str) and text.strip():
|
||||
parts.append(text.strip())
|
||||
if parts:
|
||||
return "\n".join(parts).strip()
|
||||
for key in ("reasoning_content", "text", "output_text"):
|
||||
value = message.get(key)
|
||||
if isinstance(value, str) and value.strip():
|
||||
return value.strip()
|
||||
for key in ("output_text", "text", "answer", "response"):
|
||||
value = data.get(key)
|
||||
if isinstance(value, str) and value.strip():
|
||||
return value.strip()
|
||||
return ""
|
||||
|
||||
@classmethod
|
||||
def _parse_sse_event(cls, event_text: str) -> tuple[str, bool]:
|
||||
lines = [line.strip() for line in event_text.splitlines() if line.strip()]
|
||||
data_lines = [line[5:].strip() for line in lines if line.startswith("data:")]
|
||||
if not data_lines:
|
||||
return "", False
|
||||
data = "\n".join(data_lines)
|
||||
if data == "[DONE]":
|
||||
return "", True
|
||||
obj = json.loads(data)
|
||||
choice = (obj.get("choices") or [{}])[0]
|
||||
delta = choice.get("delta") or {}
|
||||
content = delta.get("content")
|
||||
if isinstance(content, str):
|
||||
return content, False
|
||||
if isinstance(content, list):
|
||||
parts = []
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
text = item.get("text") or item.get("content")
|
||||
if isinstance(text, str):
|
||||
parts.append(text)
|
||||
return "".join(parts), False
|
||||
return "", False
|
||||
Reference in New Issue
Block a user