393 lines
12 KiB
Python
393 lines
12 KiB
Python
"""
|
|
LLM 客户端抽象层
|
|
|
|
提供统一的 LLM API 调用接口:
|
|
- 支持 OpenAI 兼容 API
|
|
- 自动重试和错误处理
|
|
- 流式/非流式响应
|
|
- 代理支持
|
|
- Token 估算
|
|
|
|
使用示例:
|
|
from utils.llm_client import LLMClient, LLMConfig
|
|
|
|
config = LLMConfig(
|
|
api_base="https://api.openai.com/v1",
|
|
api_key="sk-xxx",
|
|
model="gpt-4",
|
|
)
|
|
client = LLMClient(config)
|
|
|
|
response = await client.chat_completion(
|
|
messages=[{"role": "user", "content": "Hello"}],
|
|
tools=[...],
|
|
)
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import time
|
|
from dataclasses import dataclass, field
|
|
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
|
|
|
|
import aiohttp
|
|
from loguru import logger
|
|
|
|
# 可选代理支持
|
|
try:
|
|
from aiohttp_socks import ProxyConnector
|
|
PROXY_SUPPORT = True
|
|
except ImportError:
|
|
PROXY_SUPPORT = False
|
|
|
|
|
|
@dataclass
class LLMConfig:
    """Connection, sampling, retry and proxy settings for an LLM client."""

    # Base URL of the OpenAI-compatible API; the client appends "/chat/completions".
    api_base: str = "https://api.openai.com/v1"
    api_key: str = ""
    model: str = "gpt-4"
    temperature: float = 0.7
    max_tokens: int = 4096
    # Total per-request timeout, in seconds.
    timeout: int = 120
    max_retries: int = 3
    # Base delay between retries, in seconds.
    retry_delay: float = 1.0

    # Proxy settings
    proxy_enabled: bool = False
    proxy_type: str = "socks5"
    proxy_host: str = "127.0.0.1"
    proxy_port: int = 7890

    # Extra fields merged into every request payload.
    extra_params: Dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_dict(cls, config: Dict[str, Any]) -> "LLMConfig":
        """Build a config from a nested settings dictionary.

        Expected layout (every key is optional)::

            {
                "api": {"base_url", "api_key", "model", "temperature",
                        "max_tokens", "timeout", "max_retries",
                        "retry_delay", "extra_params"},
                "proxy": {"enabled", "type", "host", "port"},
            }

        Missing keys fall back to the dataclass defaults. A missing or
        ``None`` section (for example an empty YAML section) is treated
        as empty instead of raising ``AttributeError``.

        Args:
            config: Settings dictionary; ``None`` is treated as ``{}``.

        Returns:
            A new ``LLMConfig`` instance.
        """
        config = config or {}
        # ``or {}`` also covers keys that are present but set to None.
        api_config = config.get("api") or {}
        proxy_config = config.get("proxy") or {}
        # Single source of truth for defaults: the field declarations above.
        defaults = cls()

        return cls(
            api_base=api_config.get("base_url", defaults.api_base),
            api_key=api_config.get("api_key", defaults.api_key),
            model=api_config.get("model", defaults.model),
            temperature=api_config.get("temperature", defaults.temperature),
            max_tokens=api_config.get("max_tokens", defaults.max_tokens),
            timeout=api_config.get("timeout", defaults.timeout),
            max_retries=api_config.get("max_retries", defaults.max_retries),
            retry_delay=api_config.get("retry_delay", defaults.retry_delay),
            proxy_enabled=proxy_config.get("enabled", defaults.proxy_enabled),
            proxy_type=proxy_config.get("type", defaults.proxy_type),
            proxy_host=proxy_config.get("host", defaults.proxy_host),
            proxy_port=proxy_config.get("port", defaults.proxy_port),
            # Copy so later mutation of the source dict does not leak in.
            extra_params=dict(api_config.get("extra_params") or {}),
        )
|
|
|
|
|
|
@dataclass
class LLMResponse:
    """Outcome of one non-streaming chat-completion request.

    ``error`` is ``None`` for a successful call and holds a failure
    description otherwise; see :attr:`success`.
    """

    # Assistant text ("" when the model produced none).
    content: str = ""
    # Normalized tool calls requested by the model.
    tool_calls: List[Dict[str, Any]] = field(default_factory=list)
    finish_reason: str = ""
    # Token usage statistics as reported by the API.
    usage: Dict[str, int] = field(default_factory=dict)
    # The unmodified JSON body, kept for debugging / advanced use.
    raw_response: Dict[str, Any] = field(default_factory=dict)
    error: Optional[str] = None

    @property
    def has_tool_calls(self) -> bool:
        """Whether at least one tool call is present."""
        return len(self.tool_calls) >= 1

    @property
    def success(self) -> bool:
        """Whether no error was recorded for this response."""
        failed = self.error is not None
        return not failed
|
|
|
|
|
|
class LLMClient:
    """
    Async client for OpenAI-compatible chat-completion APIs.

    Provides a single calling interface with:
    - OpenAI-compatible ``/chat/completions`` requests
    - automatic retries with a linearly increasing delay
    - optional proxy support (requires ``aiohttp_socks``)
    - streaming responses

    The ``aiohttp`` session is created lazily and reused across calls.
    Release it with :meth:`close`, or use the client as an async context
    manager: ``async with LLMClient(config) as client: ...``.
    """

    # Status codes for which retrying cannot help (bad request / auth errors).
    _NON_RETRYABLE_STATUS = frozenset({400, 401, 403})

    def __init__(self, config: LLMConfig):
        self.config = config
        self._session: Optional[aiohttp.ClientSession] = None

    async def __aenter__(self) -> "LLMClient":
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        await self.close()

    async def _get_session(self) -> aiohttp.ClientSession:
        """Return the cached HTTP session, creating a new one if none is open."""
        if self._session is None or self._session.closed:
            connector = None

            if self.config.proxy_enabled:
                if PROXY_SUPPORT:
                    proxy_url = (
                        f"{self.config.proxy_type}://"
                        f"{self.config.proxy_host}:{self.config.proxy_port}"
                    )
                    connector = ProxyConnector.from_url(proxy_url)
                    logger.debug(f"[LLMClient] 使用代理: {proxy_url}")
                else:
                    # Don't silently ignore the setting: requests would go out
                    # directly, which the user did not ask for.
                    logger.warning("[LLMClient] 已启用代理但未安装 aiohttp_socks,将直接连接")

            timeout = aiohttp.ClientTimeout(total=self.config.timeout)
            self._session = aiohttp.ClientSession(
                connector=connector,
                timeout=timeout,
            )

        return self._session

    async def close(self) -> None:
        """Close the HTTP session if one is open; safe to call repeatedly."""
        if self._session and not self._session.closed:
            await self._session.close()
        self._session = None

    def _chat_url(self) -> str:
        """Return the chat-completions endpoint URL.

        A trailing slash on ``api_base`` is tolerated, so ``".../v1/"`` does
        not produce ``".../v1//chat/completions"``.
        """
        return f"{self.config.api_base.rstrip('/')}/chat/completions"

    def _build_headers(self) -> Dict[str, str]:
        """Build the JSON + bearer-token request headers."""
        return {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.config.api_key}",
        }

    def _build_payload(
        self,
        messages: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]] = None,
        stream: bool = False,
        **kwargs,
    ) -> Dict[str, Any]:
        """Build the request body.

        Precedence (later wins): config defaults, then
        ``config.extra_params``, then ``kwargs``. ``tool_choice`` defaults to
        ``"auto"`` when tools are given.
        """
        payload = {
            "model": self.config.model,
            "messages": messages,
            "temperature": self.config.temperature,
            "max_tokens": self.config.max_tokens,
            "stream": stream,
        }

        if tools:
            payload["tools"] = tools
            payload["tool_choice"] = kwargs.get("tool_choice", "auto")

        # Merge extra parameters; per-call kwargs override config extras.
        payload.update(self.config.extra_params)
        payload.update(kwargs)

        return payload

    async def chat_completion(
        self,
        messages: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]] = None,
        **kwargs,
    ) -> LLMResponse:
        """
        Non-streaming chat completion.

        Args:
            messages: Chat messages in OpenAI format.
            tools: Optional tool definitions.
            **kwargs: Extra request fields; they override config values.

        Returns:
            An ``LLMResponse``. HTTP errors, timeouts and request exceptions
            are reported through ``LLMResponse.error`` rather than raised.
        """
        session = await self._get_session()
        url = self._chat_url()
        headers = self._build_headers()
        payload = self._build_payload(messages, tools, stream=False, **kwargs)

        # Always make at least one request. With max_retries <= 0 the loop
        # would otherwise be skipped and an error-free empty response returned
        # as if the call had succeeded.
        attempts = max(1, self.config.max_retries)
        last_error: Optional[str] = None

        for attempt in range(attempts):
            try:
                async with session.post(url, headers=headers, json=payload) as resp:
                    if resp.status == 200:
                        # content_type=None: some compatible gateways return JSON
                        # with a non-JSON Content-Type header.
                        data = await resp.json(content_type=None)
                        return self._parse_response(data)

                    error_text = await resp.text()
                    last_error = f"HTTP {resp.status}: {error_text[:200]}"
                    logger.warning(f"[LLMClient] 请求失败 (尝试 {attempt + 1}): {last_error}")

                    if resp.status in self._NON_RETRYABLE_STATUS:
                        break

            except asyncio.TimeoutError:
                last_error = f"请求超时 ({self.config.timeout}s)"
                logger.warning(f"[LLMClient] {last_error} (尝试 {attempt + 1})")

            except Exception as e:
                last_error = str(e)
                logger.warning(f"[LLMClient] 请求异常 (尝试 {attempt + 1}): {e}")

            # Linear back-off; no sleep after the final attempt.
            if attempt < attempts - 1:
                await asyncio.sleep(self.config.retry_delay * (attempt + 1))

        return LLMResponse(error=last_error)

    async def chat_completion_stream(
        self,
        messages: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]] = None,
        **kwargs,
    ) -> AsyncGenerator[str, None]:
        """
        Streaming chat completion.

        Args:
            messages: Chat messages in OpenAI format.
            tools: Optional tool definitions.
            **kwargs: Extra request fields; they override config values.

        Yields:
            Non-empty text fragments as they arrive.

        A non-200 status or a request exception is logged and ends the
        stream early; it is not raised to the caller.
        """
        session = await self._get_session()
        url = self._chat_url()
        headers = self._build_headers()
        payload = self._build_payload(messages, tools, stream=True, **kwargs)

        try:
            async with session.post(url, headers=headers, json=payload) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    logger.error(
                        f"[LLMClient] 流式请求失败: HTTP {resp.status}: {error_text[:200]}"
                    )
                    return

                async for raw_line in resp.content:
                    line = raw_line.decode("utf-8", errors="replace").strip()
                    # SSE allows "data:" with or without a following space.
                    if not line.startswith("data:"):
                        continue

                    data_str = line[len("data:"):].strip()
                    if data_str == "[DONE]":
                        break

                    content = self._extract_stream_content(data_str)
                    if content:
                        yield content

        except Exception as e:
            logger.error(f"[LLMClient] 流式请求异常: {e}")

    @staticmethod
    def _extract_stream_content(data_str: str) -> str:
        """Return the text delta carried by one SSE ``data:`` payload.

        Returns ``""`` for payloads that are not JSON objects or carry no
        text. A chunk with an empty ``choices`` list (some providers send
        these, e.g. for usage reporting) yields ``""`` instead of raising
        ``IndexError`` and aborting the whole stream.
        """
        try:
            data = json.loads(data_str)
        except json.JSONDecodeError:
            return ""
        if not isinstance(data, dict):
            return ""

        choices = data.get("choices") or [{}]
        first = choices[0] if isinstance(choices[0], dict) else {}
        delta = first.get("delta") or {}
        return delta.get("content") or ""

    @staticmethod
    def _normalize_tool_call(tc: Dict[str, Any]) -> Dict[str, Any]:
        """Map one raw tool call onto a fixed ``{id, type, function}`` shape.

        ``null``/missing fields get defaults: ``""`` for id and name,
        ``"function"`` for type, ``"{}"`` for arguments.
        """
        function = tc.get("function") or {}
        return {
            "id": tc.get("id") or "",
            "type": tc.get("type") or "function",
            "function": {
                "name": function.get("name") or "",
                "arguments": function.get("arguments") or "{}",
            },
        }

    def _parse_response(self, data: Dict[str, Any]) -> LLMResponse:
        """Convert a chat-completion JSON body into an ``LLMResponse``.

        ``null`` values for ``tool_calls``, ``usage`` and ``finish_reason``
        (returned by some OpenAI-compatible servers) are treated as absent.
        A body that cannot be interpreted (e.g. an empty ``choices`` list)
        yields an ``LLMResponse`` with ``error`` set.
        """
        try:
            choice = data.get("choices", [{}])[0]
            message = choice.get("message", {})

            content = message.get("content") or ""
            tool_calls = message.get("tool_calls") or []
            finish_reason = choice.get("finish_reason") or ""
            usage = data.get("usage") or {}

            return LLMResponse(
                content=content,
                tool_calls=[self._normalize_tool_call(tc) for tc in tool_calls],
                finish_reason=finish_reason,
                usage=usage,
                raw_response=data,
            )

        except Exception as e:
            logger.error(f"[LLMClient] 解析响应失败: {e}")
            return LLMResponse(error=f"解析响应失败: {e}")

    # ==================== Token estimation ====================

    @staticmethod
    def estimate_tokens(text: str) -> int:
        """
        Roughly estimate the token count of ``text``.

        Heuristic:
        - about 4 non-CJK characters per token
        - about 1.5 CJK characters (U+4E00..U+9FFF) per token
        """
        if not text:
            return 0

        chinese_chars = sum(1 for c in text if '\u4e00' <= c <= '\u9fff')
        other_chars = len(text) - chinese_chars

        return int(chinese_chars / 1.5 + other_chars / 4)

    @staticmethod
    def estimate_message_tokens(message: Dict[str, Any]) -> int:
        """Estimate the tokens of one message, including a fixed overhead of 4.

        String content is estimated via :meth:`estimate_tokens`. List content
        counts ``text`` parts by length and each ``image_url`` part as 85
        tokens. Any other content (e.g. ``None``) counts only the overhead.
        """
        content = message.get("content", "")

        if isinstance(content, str):
            return LLMClient.estimate_tokens(content) + 4  # role etc. overhead

        if isinstance(content, list):
            total = 4
            for item in content:
                if isinstance(item, dict):
                    if item.get("type") == "text":
                        total += LLMClient.estimate_tokens(item.get("text", ""))
                    elif item.get("type") == "image_url":
                        total += 85  # flat cost per image
            return total

        return 4

    @staticmethod
    def estimate_messages_tokens(messages: List[Dict[str, Any]]) -> int:
        """Estimate the total tokens of a list of messages."""
        return sum(LLMClient.estimate_message_tokens(m) for m in messages)
|
|
|
|
|
|
# ==================== 便捷函数 ====================
|
|
|
|
# Module-level default client; assigned by get_llm_client(config) and
# returned by later get_llm_client() calls.
_default_client: Optional[LLMClient] = None
|
|
|
|
|
|
def get_llm_client(config: Optional[LLMConfig] = None) -> LLMClient:
    """Return the module-wide default LLM client.

    Passing ``config`` builds a fresh ``LLMClient`` from it and makes that
    the default. The previously stored client, if any, is replaced without
    being closed. Calling without ``config`` returns the stored client.

    Raises:
        ValueError: if no client has been created yet and ``config`` is
            not given.
    """
    global _default_client

    if config:
        _default_client = LLMClient(config)
    elif _default_client is None:
        raise ValueError("LLM 客户端未初始化,请先传入配置")

    return _default_client
|
|
|
|
|
|
# ==================== 导出 ====================
|
|
|
|
# Public API exported by ``from ... import *``.
__all__ = [
    'LLMConfig',
    'LLMResponse',
    'LLMClient',
    'get_llm_client',
]
|