Merge branch 'main' of https://gitea.functen.cn/shihao/WechatHookBot
utils/llm_client.py (new file, 392 lines)
@@ -0,0 +1,392 @@
"""
|
||||
LLM 客户端抽象层
|
||||
|
||||
提供统一的 LLM API 调用接口:
|
||||
- 支持 OpenAI 兼容 API
|
||||
- 自动重试和错误处理
|
||||
- 流式/非流式响应
|
||||
- 代理支持
|
||||
- Token 估算
|
||||
|
||||
使用示例:
|
||||
from utils.llm_client import LLMClient, LLMConfig
|
||||
|
||||
config = LLMConfig(
|
||||
api_base="https://api.openai.com/v1",
|
||||
api_key="sk-xxx",
|
||||
model="gpt-4",
|
||||
)
|
||||
client = LLMClient(config)
|
||||
|
||||
response = await client.chat_completion(
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
tools=[...],
|
||||
)
|
||||
"""
|
||||
|
||||
from __future__ import annotations

import asyncio
import json
from dataclasses import dataclass, field
from typing import Any, AsyncGenerator, Dict, List, Optional

import aiohttp
from loguru import logger

# Optional proxy support
try:
    from aiohttp_socks import ProxyConnector
    PROXY_SUPPORT = True
except ImportError:
    PROXY_SUPPORT = False


@dataclass
class LLMConfig:
    """LLM configuration."""
    api_base: str = "https://api.openai.com/v1"
    api_key: str = ""
    model: str = "gpt-4"
    temperature: float = 0.7
    max_tokens: int = 4096
    timeout: int = 120
    max_retries: int = 3
    retry_delay: float = 1.0

    # Proxy settings
    proxy_enabled: bool = False
    proxy_type: str = "socks5"
    proxy_host: str = "127.0.0.1"
    proxy_port: int = 7890

    # Extra parameters
    extra_params: Dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_dict(cls, config: Dict[str, Any]) -> "LLMConfig":
        """Create an LLMConfig from a config dict."""
        api_config = config.get("api", {})
        proxy_config = config.get("proxy", {})

        return cls(
            api_base=api_config.get("base_url", "https://api.openai.com/v1"),
            api_key=api_config.get("api_key", ""),
            model=api_config.get("model", "gpt-4"),
            temperature=api_config.get("temperature", 0.7),
            max_tokens=api_config.get("max_tokens", 4096),
            timeout=api_config.get("timeout", 120),
            max_retries=api_config.get("max_retries", 3),
            retry_delay=api_config.get("retry_delay", 1.0),
            proxy_enabled=proxy_config.get("enabled", False),
            proxy_type=proxy_config.get("type", "socks5"),
            proxy_host=proxy_config.get("host", "127.0.0.1"),
            proxy_port=proxy_config.get("port", 7890),
        )
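
# A minimal sketch of the config-dict shape that LLMConfig.from_dict() reads;
# the keys mirror the lookups above, and the values are placeholders, not
# real credentials:
#
#     config = LLMConfig.from_dict({
#         "api": {"base_url": "https://api.openai.com/v1", "api_key": "sk-xxx", "model": "gpt-4"},
#         "proxy": {"enabled": True, "type": "socks5", "host": "127.0.0.1", "port": 7890},
#     })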


@dataclass
class LLMResponse:
    """LLM response."""
    content: str = ""
    tool_calls: List[Dict[str, Any]] = field(default_factory=list)
    finish_reason: str = ""
    usage: Dict[str, int] = field(default_factory=dict)
    raw_response: Dict[str, Any] = field(default_factory=dict)
    error: Optional[str] = None

    @property
    def has_tool_calls(self) -> bool:
        return len(self.tool_calls) > 0

    @property
    def success(self) -> bool:
        return self.error is None
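
# Sketch of how a caller might branch on the flags above; resp is a
# hypothetical LLMResponse value, and dispatch() is illustrative only:
#
#     if not resp.success:
#         logger.error(resp.error)
#     elif resp.has_tool_calls:
#         dispatch(resp.tool_calls)
#     else:
#         print(resp.content)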


class LLMClient:
    """
    LLM client

    Provides a unified API calling interface, supporting:
    - OpenAI-compatible APIs
    - Automatic retries
    - Proxies
    - Streaming responses
    """

    def __init__(self, config: LLMConfig):
        self.config = config
        self._session: Optional[aiohttp.ClientSession] = None

    async def _get_session(self) -> aiohttp.ClientSession:
        """Get or create the HTTP session."""
        if self._session is None or self._session.closed:
            connector = None

            # Configure the proxy
            if self.config.proxy_enabled and PROXY_SUPPORT:
                proxy_url = (
                    f"{self.config.proxy_type}://"
                    f"{self.config.proxy_host}:{self.config.proxy_port}"
                )
                connector = ProxyConnector.from_url(proxy_url)
                logger.debug(f"[LLMClient] Using proxy: {proxy_url}")

            timeout = aiohttp.ClientTimeout(total=self.config.timeout)
            self._session = aiohttp.ClientSession(
                connector=connector,
                timeout=timeout,
            )

        return self._session

    async def close(self):
        """Close the session."""
        if self._session and not self._session.closed:
            await self._session.close()
            self._session = None
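
    # Sketch of the intended lifecycle, assuming the caller owns the client;
    # close() should run even if the request raises:
    #
    #     client = LLMClient(config)
    #     try:
    #         resp = await client.chat_completion(messages=[...])
    #     finally:
    #         await client.close()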

    def _build_headers(self) -> Dict[str, str]:
        """Build request headers."""
        return {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.config.api_key}",
        }

    def _build_payload(
        self,
        messages: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]] = None,
        stream: bool = False,
        **kwargs,
    ) -> Dict[str, Any]:
        """Build the request payload."""
        payload = {
            "model": self.config.model,
            "messages": messages,
            "temperature": self.config.temperature,
            "max_tokens": self.config.max_tokens,
            "stream": stream,
        }

        if tools:
            payload["tools"] = tools
            payload["tool_choice"] = kwargs.get("tool_choice", "auto")

        # Merge extra parameters
        payload.update(self.config.extra_params)
        payload.update(kwargs)

        return payload
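
    # For reference, with the defaults above the payload serializes to roughly:
    #
    #     {"model": "gpt-4", "messages": [...], "temperature": 0.7,
    #      "max_tokens": 4096, "stream": false,
    #      "tools": [...], "tool_choice": "auto"}
    #
    # where the "tools"/"tool_choice" keys appear only when tools are given.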

    async def chat_completion(
        self,
        messages: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]] = None,
        **kwargs,
    ) -> LLMResponse:
        """
        Non-streaming chat completion.

        Args:
            messages: list of messages
            tools: list of tools (optional)
            **kwargs: extra parameters

        Returns:
            An LLMResponse object.
        """
        session = await self._get_session()
        url = f"{self.config.api_base}/chat/completions"
        headers = self._build_headers()
        payload = self._build_payload(messages, tools, stream=False, **kwargs)

        last_error = None

        for attempt in range(self.config.max_retries):
            try:
                async with session.post(url, headers=headers, json=payload) as resp:
                    if resp.status == 200:
                        data = await resp.json()
                        return self._parse_response(data)

                    error_text = await resp.text()
                    last_error = f"HTTP {resp.status}: {error_text[:200]}"
                    logger.warning(f"[LLMClient] Request failed (attempt {attempt + 1}): {last_error}")

                    # Non-retryable errors
                    if resp.status in [400, 401, 403]:
                        break

            except asyncio.TimeoutError:
                last_error = f"Request timed out ({self.config.timeout}s)"
                logger.warning(f"[LLMClient] {last_error} (attempt {attempt + 1})")

            except Exception as e:
                last_error = str(e)
                logger.warning(f"[LLMClient] Request error (attempt {attempt + 1}): {e}")

            # Retry delay
            if attempt < self.config.max_retries - 1:
                await asyncio.sleep(self.config.retry_delay * (attempt + 1))

        return LLMResponse(error=last_error)
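
    # Worked example of the linear backoff above with the defaults
    # (max_retries=3, retry_delay=1.0): attempt 1 fails -> sleep 1.0s,
    # attempt 2 fails -> sleep 2.0s, attempt 3 fails -> return
    # LLMResponse(error=last_error) with no further sleep.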

    async def chat_completion_stream(
        self,
        messages: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]] = None,
        **kwargs,
    ) -> AsyncGenerator[str, None]:
        """
        Streaming chat completion.

        Args:
            messages: list of messages
            tools: list of tools (optional)
            **kwargs: extra parameters

        Yields:
            Text fragments.
        """
        session = await self._get_session()
        url = f"{self.config.api_base}/chat/completions"
        headers = self._build_headers()
        payload = self._build_payload(messages, tools, stream=True, **kwargs)

        try:
            async with session.post(url, headers=headers, json=payload) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    logger.error(f"[LLMClient] Streaming request failed: HTTP {resp.status}: {error_text[:200]}")
                    return

                async for line in resp.content:
                    line = line.decode("utf-8").strip()
                    if not line or not line.startswith("data: "):
                        continue

                    data_str = line[6:]
                    if data_str == "[DONE]":
                        break

                    try:
                        data = json.loads(data_str)
                        delta = data.get("choices", [{}])[0].get("delta", {})
                        content = delta.get("content", "")
                        if content:
                            yield content
                    except json.JSONDecodeError:
                        continue

        except Exception as e:
            logger.error(f"[LLMClient] Streaming request error: {e}")
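
    # Sketch of consuming the generator above; each yielded chunk is plain
    # text with the SSE "data: " framing already stripped:
    #
    #     async for chunk in client.chat_completion_stream(messages=[...]):
    #         print(chunk, end="", flush=True)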

    def _parse_response(self, data: Dict[str, Any]) -> LLMResponse:
        """Parse the API response."""
        try:
            choice = data.get("choices", [{}])[0]
            message = choice.get("message", {})

            content = message.get("content", "") or ""
            tool_calls = message.get("tool_calls", []) or []
            finish_reason = choice.get("finish_reason", "")
            usage = data.get("usage", {})

            # Normalize tool_calls
            parsed_tool_calls = []
            for tc in tool_calls:
                parsed_tool_calls.append({
                    "id": tc.get("id", ""),
                    "type": tc.get("type", "function"),
                    "function": {
                        "name": tc.get("function", {}).get("name", ""),
                        "arguments": tc.get("function", {}).get("arguments", "{}"),
                    }
                })

            return LLMResponse(
                content=content,
                tool_calls=parsed_tool_calls,
                finish_reason=finish_reason,
                usage=usage,
                raw_response=data,
            )

        except Exception as e:
            logger.error(f"[LLMClient] Failed to parse response: {e}")
            return LLMResponse(error=f"Failed to parse response: {e}")
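
    # Each normalized tool call above keeps the OpenAI-style shape:
    #
    #     {"id": "...", "type": "function",
    #      "function": {"name": "...", "arguments": "{...}"}}
    #
    # Note that "arguments" stays a JSON *string*; callers are expected to
    # json.loads() it themselves.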

    # ==================== Token estimation ====================

    @staticmethod
    def estimate_tokens(text: str) -> int:
        """
        Estimate the token count of a text.

        Uses simplified rules:
        - English: roughly 4 characters = 1 token
        - Chinese: roughly 1.5 characters = 1 token
        """
        if not text:
            return 0

        chinese_chars = sum(1 for c in text if '\u4e00' <= c <= '\u9fff')
        other_chars = len(text) - chinese_chars

        chinese_tokens = chinese_chars / 1.5
        other_tokens = other_chars / 4

        return int(chinese_tokens + other_tokens)
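
    # Worked example of the rules above: "你好 world" has 2 Chinese chars and
    # 6 other chars (the space plus "world"), so 2/1.5 + 6/4 = 1.33 + 1.5
    # = 2.83, truncated to 2 tokens.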

    @staticmethod
    def estimate_message_tokens(message: Dict[str, Any]) -> int:
        """Estimate the token count of a single message."""
        content = message.get("content", "")

        if isinstance(content, str):
            return LLMClient.estimate_tokens(content) + 4  # overhead for role, etc.

        if isinstance(content, list):
            total = 4
            for item in content:
                if isinstance(item, dict):
                    if item.get("type") == "text":
                        total += LLMClient.estimate_tokens(item.get("text", ""))
                    elif item.get("type") == "image_url":
                        total += 85  # fixed overhead per image
            return total

        return 4

    @staticmethod
    def estimate_messages_tokens(messages: List[Dict[str, Any]]) -> int:
        """Estimate the total token count of a list of messages."""
        return sum(LLMClient.estimate_message_tokens(m) for m in messages)
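
    # Example combining the estimators: [{"role": "user", "content": "Hello"}]
    # gives estimate_tokens("Hello") + 4 = int(5/4) + 4 = 1 + 4 = 5 tokens.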


# ==================== Convenience functions ====================

_default_client: Optional[LLMClient] = None


def get_llm_client(config: Optional[LLMConfig] = None) -> LLMClient:
    """Get the default LLM client."""
    global _default_client
    if config:
        _default_client = LLMClient(config)
    if _default_client is None:
        raise ValueError("LLM client not initialized; pass a config first")
    return _default_client
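
# Sketch of the intended module-level singleton usage: initialize once with a
# config, then retrieve the same instance elsewhere without one:
#
#     get_llm_client(LLMConfig(api_key="sk-xxx"))  # first call stores the client
#     client = get_llm_client()                    # later calls reuse it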


# ==================== Exports ====================

__all__ = [
    'LLMConfig',
    'LLMResponse',
    'LLMClient',
    'get_llm_client',
]