Commit 2025-12-31 17:47:39 +08:00
38 changed files with 4435 additions and 1343 deletions

utils/llm_client.py (new file, 392 lines)

@@ -0,0 +1,392 @@
"""
LLM 客户端抽象层
提供统一的 LLM API 调用接口:
- 支持 OpenAI 兼容 API
- 自动重试和错误处理
- 流式/非流式响应
- 代理支持
- Token 估算
使用示例:
from utils.llm_client import LLMClient, LLMConfig
config = LLMConfig(
api_base="https://api.openai.com/v1",
api_key="sk-xxx",
model="gpt-4",
)
client = LLMClient(config)
response = await client.chat_completion(
messages=[{"role": "user", "content": "Hello"}],
tools=[...],
)
"""
from __future__ import annotations
import asyncio
import json
from dataclasses import dataclass, field
from typing import Any, AsyncGenerator, Dict, List, Optional
import aiohttp
from loguru import logger
# Optional proxy support (requires the aiohttp-socks package)
try:
from aiohttp_socks import ProxyConnector
PROXY_SUPPORT = True
except ImportError:
PROXY_SUPPORT = False
@dataclass
class LLMConfig:
"""LLM 配置"""
api_base: str = "https://api.openai.com/v1"
api_key: str = ""
model: str = "gpt-4"
temperature: float = 0.7
max_tokens: int = 4096
timeout: int = 120
max_retries: int = 3
retry_delay: float = 1.0
# Proxy settings
proxy_enabled: bool = False
proxy_type: str = "socks5"
proxy_host: str = "127.0.0.1"
proxy_port: int = 7890
# Extra request parameters
extra_params: Dict[str, Any] = field(default_factory=dict)
@classmethod
def from_dict(cls, config: Dict[str, Any]) -> "LLMConfig":
"""从配置字典创建"""
api_config = config.get("api", {})
proxy_config = config.get("proxy", {})
return cls(
api_base=api_config.get("base_url", "https://api.openai.com/v1"),
api_key=api_config.get("api_key", ""),
model=api_config.get("model", "gpt-4"),
temperature=api_config.get("temperature", 0.7),
max_tokens=api_config.get("max_tokens", 4096),
timeout=api_config.get("timeout", 120),
max_retries=api_config.get("max_retries", 3),
retry_delay=api_config.get("retry_delay", 1.0),
proxy_enabled=proxy_config.get("enabled", False),
proxy_type=proxy_config.get("type", "socks5"),
proxy_host=proxy_config.get("host", "127.0.0.1"),
proxy_port=proxy_config.get("port", 7890),
)
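# A minimal sketch of the nested dict shape from_dict expects; the keys
# mirror the .get() calls above, and missing keys fall back to the
# dataclass defaults. Values here are placeholders:
#
#     config = LLMConfig.from_dict({
#         "api": {
#             "base_url": "https://api.openai.com/v1",
#             "api_key": "sk-xxx",
#             "model": "gpt-4",
#         },
#         "proxy": {"enabled": True, "type": "socks5", "port": 7890},
#     })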
@dataclass
class LLMResponse:
"""LLM 响应"""
content: str = ""
tool_calls: List[Dict[str, Any]] = field(default_factory=list)
finish_reason: str = ""
usage: Dict[str, int] = field(default_factory=dict)
raw_response: Dict[str, Any] = field(default_factory=dict)
error: Optional[str] = None
@property
def has_tool_calls(self) -> bool:
return len(self.tool_calls) > 0
@property
def success(self) -> bool:
return self.error is None
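# A hedged usage sketch of how calling code might branch on an LLMResponse.
# The field names match the dataclass above; handle_error and dispatch_tool
# are hypothetical helpers, not part of this module:
#
#     resp = await client.chat_completion(messages)
#     if not resp.success:
#         handle_error(resp.error)
#     elif resp.has_tool_calls:
#         for tc in resp.tool_calls:
#             dispatch_tool(tc["function"]["name"], tc["function"]["arguments"])
#     else:
#         print(resp.content)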
class LLMClient:
"""
LLM 客户端
提供统一的 API 调用接口,支持:
- OpenAI 兼容 API
- 自动重试
- 代理
- 流式响应
"""
def __init__(self, config: LLMConfig):
self.config = config
self._session: Optional[aiohttp.ClientSession] = None
async def _get_session(self) -> aiohttp.ClientSession:
"""获取或创建 HTTP 会话"""
if self._session is None or self._session.closed:
connector = None
# Configure the proxy connector
if self.config.proxy_enabled and PROXY_SUPPORT:
proxy_url = (
f"{self.config.proxy_type}://"
f"{self.config.proxy_host}:{self.config.proxy_port}"
)
connector = ProxyConnector.from_url(proxy_url)
logger.debug(f"[LLMClient] using proxy: {proxy_url}")
elif self.config.proxy_enabled:
logger.warning("[LLMClient] proxy enabled but aiohttp-socks is not installed; connecting directly")
timeout = aiohttp.ClientTimeout(total=self.config.timeout)
self._session = aiohttp.ClientSession(
connector=connector,
timeout=timeout,
)
return self._session
async def close(self):
"""关闭会话"""
if self._session and not self._session.closed:
await self._session.close()
self._session = None
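# Lifecycle sketch: the session is created lazily on the first request and
# must be closed explicitly, since this class does not implement
# __aenter__/__aexit__:
#
#     client = LLMClient(config)
#     try:
#         ...  # make one or more calls
#     finally:
#         await client.close()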
def _build_headers(self) -> Dict[str, str]:
"""构建请求头"""
return {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.config.api_key}",
}
def _build_payload(
self,
messages: List[Dict[str, Any]],
tools: Optional[List[Dict[str, Any]]] = None,
stream: bool = False,
**kwargs,
) -> Dict[str, Any]:
"""构建请求体"""
payload = {
"model": self.config.model,
"messages": messages,
"temperature": self.config.temperature,
"max_tokens": self.config.max_tokens,
"stream": stream,
}
if tools:
payload["tools"] = tools
payload["tool_choice"] = kwargs.get("tool_choice", "auto")
# Merge extra parameters; per-call kwargs override config extras
payload.update(self.config.extra_params)
payload.update(kwargs)
return payload
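# For reference, a payload built with tools enabled looks roughly like this
# (values come from the config; "tool_choice" defaults to "auto"):
#
#     {
#         "model": "gpt-4",
#         "messages": [...],
#         "temperature": 0.7,
#         "max_tokens": 4096,
#         "stream": False,
#         "tools": [...],
#         "tool_choice": "auto",
#     }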
async def chat_completion(
self,
messages: List[Dict[str, Any]],
tools: Optional[List[Dict[str, Any]]] = None,
**kwargs,
) -> LLMResponse:
"""
非流式聊天补全
Args:
messages: 消息列表
tools: 工具列表(可选)
**kwargs: 额外参数
Returns:
LLMResponse 对象
"""
session = await self._get_session()
url = f"{self.config.api_base}/chat/completions"
headers = self._build_headers()
payload = self._build_payload(messages, tools, stream=False, **kwargs)
last_error = None
for attempt in range(self.config.max_retries):
try:
async with session.post(url, headers=headers, json=payload) as resp:
if resp.status == 200:
data = await resp.json()
return self._parse_response(data)
error_text = await resp.text()
last_error = f"HTTP {resp.status}: {error_text[:200]}"
logger.warning(f"[LLMClient] request failed (attempt {attempt + 1}): {last_error}")
# Client errors are not retryable
if resp.status in [400, 401, 403]:
break
except asyncio.TimeoutError:
last_error = f"请求超时 ({self.config.timeout}s)"
logger.warning(f"[LLMClient] {last_error} (尝试 {attempt + 1})")
except Exception as e:
last_error = str(e)
logger.warning(f"[LLMClient] 请求异常 (尝试 {attempt + 1}): {e}")
# 重试延迟
if attempt < self.config.max_retries - 1:
await asyncio.sleep(self.config.retry_delay * (attempt + 1))
return LLMResponse(error=last_error)
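# Minimal non-streaming call sketch (assumes an initialized client and a
# reachable OpenAI-compatible endpoint; retries and backoff happen inside):
#
#     resp = await client.chat_completion(
#         messages=[{"role": "user", "content": "Hello"}],
#     )
#     if resp.success:
#         print(resp.content, resp.usage.get("total_tokens"))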
async def chat_completion_stream(
self,
messages: List[Dict[str, Any]],
tools: Optional[List[Dict[str, Any]]] = None,
**kwargs,
) -> AsyncGenerator[str, None]:
"""
流式聊天补全
Args:
messages: 消息列表
tools: 工具列表(可选)
**kwargs: 额外参数
Yields:
文本片段
"""
session = await self._get_session()
url = f"{self.config.api_base}/chat/completions"
headers = self._build_headers()
payload = self._build_payload(messages, tools, stream=True, **kwargs)
try:
async with session.post(url, headers=headers, json=payload) as resp:
if resp.status != 200:
error_text = await resp.text()
logger.error(f"[LLMClient] streaming request failed: HTTP {resp.status}: {error_text[:200]}")
return
async for line in resp.content:
line = line.decode("utf-8").strip()
if not line or not line.startswith("data: "):
continue
data_str = line[6:]
if data_str == "[DONE]":
break
try:
data = json.loads(data_str)
delta = (data.get("choices") or [{}])[0].get("delta", {})
content = delta.get("content", "")
if content:
yield content
except json.JSONDecodeError:
continue
except Exception as e:
logger.error(f"[LLMClient] 流式请求异常: {e}")
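# Streaming consumption sketch; chunks arrive as plain text deltas
# (tool-call deltas are intentionally not surfaced by this generator):
#
#     async for chunk in client.chat_completion_stream(messages):
#         print(chunk, end="", flush=True)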
def _parse_response(self, data: Dict[str, Any]) -> LLMResponse:
"""解析 API 响应"""
try:
choice = (data.get("choices") or [{}])[0]
message = choice.get("message", {})
content = message.get("content", "") or ""
tool_calls = message.get("tool_calls", [])
finish_reason = choice.get("finish_reason", "")
usage = data.get("usage", {})
# Normalize tool_calls into a consistent shape
parsed_tool_calls = []
for tc in tool_calls:
parsed_tool_calls.append({
"id": tc.get("id", ""),
"type": tc.get("type", "function"),
"function": {
"name": tc.get("function", {}).get("name", ""),
"arguments": tc.get("function", {}).get("arguments", "{}"),
}
})
return LLMResponse(
content=content,
tool_calls=parsed_tool_calls,
finish_reason=finish_reason,
usage=usage,
raw_response=data,
)
except Exception as e:
logger.error(f"[LLMClient] 解析响应失败: {e}")
return LLMResponse(error=f"解析响应失败: {e}")
# ==================== Token estimation ====================
@staticmethod
def estimate_tokens(text: str) -> int:
"""
估算文本的 token 数量
使用简化规则:
- 英文约 4 字符 = 1 token
- 中文约 1.5 字符 = 1 token
"""
if not text:
return 0
chinese_chars = sum(1 for c in text if '\u4e00' <= c <= '\u9fff')
other_chars = len(text) - chinese_chars
chinese_tokens = chinese_chars / 1.5
other_tokens = other_chars / 4
return int(chinese_tokens + other_tokens)
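# Worked example of the heuristic above: "hello" is 5 non-CJK characters,
# 5 / 4 = 1.25 -> 1 token; "你好世界" is 4 CJK characters, 4 / 1.5 ≈ 2.67
# -> 2 tokens. A rough budgeting estimate, not a tokenizer-accurate count.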
@staticmethod
def estimate_message_tokens(message: Dict[str, Any]) -> int:
"""估算单条消息的 token 数量"""
content = message.get("content", "")
if isinstance(content, str):
return LLMClient.estimate_tokens(content) + 4  # overhead for role etc.
if isinstance(content, list):
total = 4
for item in content:
if isinstance(item, dict):
if item.get("type") == "text":
total += LLMClient.estimate_tokens(item.get("text", ""))
elif item.get("type") == "image_url":
total += 85  # fixed per-image overhead
return total
return 4
@staticmethod
def estimate_messages_tokens(messages: List[Dict[str, Any]]) -> int:
"""估算消息列表的总 token 数量"""
return sum(LLMClient.estimate_message_tokens(m) for m in messages)
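# Context-budget sketch (the 8000-token threshold and the crude tail
# truncation are illustrative assumptions, not part of this module):
#
#     if LLMClient.estimate_messages_tokens(history) > 8000:
#         history = history[-10:]  # keep only the most recent messages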
# ==================== Convenience functions ====================
_default_client: Optional[LLMClient] = None
def get_llm_client(config: Optional[LLMConfig] = None) -> LLMClient:
"""获取默认 LLM 客户端"""
global _default_client
if config:
_default_client = LLMClient(config)
if _default_client is None:
raise ValueError("LLM client not initialized; pass a config first")
return _default_client
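# Initialization sketch for the module-level default client; the first call
# with a config sets it, later calls without one reuse it (values are
# placeholders):
#
#     get_llm_client(LLMConfig(api_key="sk-xxx", model="gpt-4"))
#     client = get_llm_client()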
# ==================== Exports ====================
__all__ = [
'LLMConfig',
'LLMResponse',
'LLMClient',
'get_llm_client',
]
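# A self-contained smoke test, guarded so that importing this module stays
# side-effect free. The endpoint and key below are placeholders; point them
# at any OpenAI-compatible server before running.
if __name__ == "__main__":
    async def _demo() -> None:
        config = LLMConfig(
            api_base="https://api.openai.com/v1",  # placeholder endpoint
            api_key="sk-xxx",                      # placeholder key
            model="gpt-4",
        )
        client = LLMClient(config)
        try:
            resp = await client.chat_completion(
                messages=[{"role": "user", "content": "Say hi in one word."}],
            )
            print(resp.content if resp.success else f"error: {resp.error}")
        finally:
            await client.close()

    asyncio.run(_demo())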