""" LLM 客户端抽象层 提供统一的 LLM API 调用接口: - 支持 OpenAI 兼容 API - 自动重试和错误处理 - 流式/非流式响应 - 代理支持 - Token 估算 使用示例: from utils.llm_client import LLMClient, LLMConfig config = LLMConfig( api_base="https://api.openai.com/v1", api_key="sk-xxx", model="gpt-4", ) client = LLMClient(config) response = await client.chat_completion( messages=[{"role": "user", "content": "Hello"}], tools=[...], ) """ from __future__ import annotations import asyncio import json import time from dataclasses import dataclass, field from typing import Any, AsyncGenerator, Dict, List, Optional, Union import aiohttp from loguru import logger # 可选代理支持 try: from aiohttp_socks import ProxyConnector PROXY_SUPPORT = True except ImportError: PROXY_SUPPORT = False @dataclass class LLMConfig: """LLM 配置""" api_base: str = "https://api.openai.com/v1" api_key: str = "" model: str = "gpt-4" temperature: float = 0.7 max_tokens: int = 4096 timeout: int = 120 max_retries: int = 3 retry_delay: float = 1.0 # 代理配置 proxy_enabled: bool = False proxy_type: str = "socks5" proxy_host: str = "127.0.0.1" proxy_port: int = 7890 # 额外参数 extra_params: Dict[str, Any] = field(default_factory=dict) @classmethod def from_dict(cls, config: Dict[str, Any]) -> "LLMConfig": """从配置字典创建""" api_config = config.get("api", {}) proxy_config = config.get("proxy", {}) return cls( api_base=api_config.get("base_url", "https://api.openai.com/v1"), api_key=api_config.get("api_key", ""), model=api_config.get("model", "gpt-4"), temperature=api_config.get("temperature", 0.7), max_tokens=api_config.get("max_tokens", 4096), timeout=api_config.get("timeout", 120), max_retries=api_config.get("max_retries", 3), retry_delay=api_config.get("retry_delay", 1.0), proxy_enabled=proxy_config.get("enabled", False), proxy_type=proxy_config.get("type", "socks5"), proxy_host=proxy_config.get("host", "127.0.0.1"), proxy_port=proxy_config.get("port", 7890), ) @dataclass class LLMResponse: """LLM 响应""" content: str = "" tool_calls: List[Dict[str, Any]] = field(default_factory=list) finish_reason: str = "" usage: Dict[str, int] = field(default_factory=dict) raw_response: Dict[str, Any] = field(default_factory=dict) error: Optional[str] = None @property def has_tool_calls(self) -> bool: return len(self.tool_calls) > 0 @property def success(self) -> bool: return self.error is None class LLMClient: """ LLM 客户端 提供统一的 API 调用接口,支持: - OpenAI 兼容 API - 自动重试 - 代理 - 流式响应 """ def __init__(self, config: LLMConfig): self.config = config self._session: Optional[aiohttp.ClientSession] = None async def _get_session(self) -> aiohttp.ClientSession: """获取或创建 HTTP 会话""" if self._session is None or self._session.closed: connector = None # 配置代理 if self.config.proxy_enabled and PROXY_SUPPORT: proxy_url = ( f"{self.config.proxy_type}://" f"{self.config.proxy_host}:{self.config.proxy_port}" ) connector = ProxyConnector.from_url(proxy_url) logger.debug(f"[LLMClient] 使用代理: {proxy_url}") timeout = aiohttp.ClientTimeout(total=self.config.timeout) self._session = aiohttp.ClientSession( connector=connector, timeout=timeout, ) return self._session async def close(self): """关闭会话""" if self._session and not self._session.closed: await self._session.close() self._session = None def _build_headers(self) -> Dict[str, str]: """构建请求头""" return { "Content-Type": "application/json", "Authorization": f"Bearer {self.config.api_key}", } def _build_payload( self, messages: List[Dict[str, Any]], tools: Optional[List[Dict[str, Any]]] = None, stream: bool = False, **kwargs, ) -> Dict[str, Any]: """构建请求体""" payload = { "model": self.config.model, 
"messages": messages, "temperature": self.config.temperature, "max_tokens": self.config.max_tokens, "stream": stream, } if tools: payload["tools"] = tools payload["tool_choice"] = kwargs.get("tool_choice", "auto") # 合并额外参数 payload.update(self.config.extra_params) payload.update(kwargs) return payload async def chat_completion( self, messages: List[Dict[str, Any]], tools: Optional[List[Dict[str, Any]]] = None, **kwargs, ) -> LLMResponse: """ 非流式聊天补全 Args: messages: 消息列表 tools: 工具列表(可选) **kwargs: 额外参数 Returns: LLMResponse 对象 """ session = await self._get_session() url = f"{self.config.api_base}/chat/completions" headers = self._build_headers() payload = self._build_payload(messages, tools, stream=False, **kwargs) last_error = None for attempt in range(self.config.max_retries): try: async with session.post(url, headers=headers, json=payload) as resp: if resp.status == 200: data = await resp.json() return self._parse_response(data) error_text = await resp.text() last_error = f"HTTP {resp.status}: {error_text[:200]}" logger.warning(f"[LLMClient] 请求失败 (尝试 {attempt + 1}): {last_error}") # 不可重试的错误 if resp.status in [400, 401, 403]: break except asyncio.TimeoutError: last_error = f"请求超时 ({self.config.timeout}s)" logger.warning(f"[LLMClient] {last_error} (尝试 {attempt + 1})") except Exception as e: last_error = str(e) logger.warning(f"[LLMClient] 请求异常 (尝试 {attempt + 1}): {e}") # 重试延迟 if attempt < self.config.max_retries - 1: await asyncio.sleep(self.config.retry_delay * (attempt + 1)) return LLMResponse(error=last_error) async def chat_completion_stream( self, messages: List[Dict[str, Any]], tools: Optional[List[Dict[str, Any]]] = None, **kwargs, ) -> AsyncGenerator[str, None]: """ 流式聊天补全 Args: messages: 消息列表 tools: 工具列表(可选) **kwargs: 额外参数 Yields: 文本片段 """ session = await self._get_session() url = f"{self.config.api_base}/chat/completions" headers = self._build_headers() payload = self._build_payload(messages, tools, stream=True, **kwargs) try: async with session.post(url, headers=headers, json=payload) as resp: if resp.status != 200: error_text = await resp.text() logger.error(f"[LLMClient] 流式请求失败: HTTP {resp.status}") return async for line in resp.content: line = line.decode("utf-8").strip() if not line or not line.startswith("data: "): continue data_str = line[6:] if data_str == "[DONE]": break try: data = json.loads(data_str) delta = data.get("choices", [{}])[0].get("delta", {}) content = delta.get("content", "") if content: yield content except json.JSONDecodeError: continue except Exception as e: logger.error(f"[LLMClient] 流式请求异常: {e}") def _parse_response(self, data: Dict[str, Any]) -> LLMResponse: """解析 API 响应""" try: choice = data.get("choices", [{}])[0] message = choice.get("message", {}) content = message.get("content", "") or "" tool_calls = message.get("tool_calls", []) finish_reason = choice.get("finish_reason", "") usage = data.get("usage", {}) # 标准化 tool_calls parsed_tool_calls = [] for tc in tool_calls: parsed_tool_calls.append({ "id": tc.get("id", ""), "type": tc.get("type", "function"), "function": { "name": tc.get("function", {}).get("name", ""), "arguments": tc.get("function", {}).get("arguments", "{}"), } }) return LLMResponse( content=content, tool_calls=parsed_tool_calls, finish_reason=finish_reason, usage=usage, raw_response=data, ) except Exception as e: logger.error(f"[LLMClient] 解析响应失败: {e}") return LLMResponse(error=f"解析响应失败: {e}") # ==================== Token 估算 ==================== @staticmethod def estimate_tokens(text: str) -> int: """ 估算文本的 token 数量 使用简化规则: - 英文约 4 

    def _parse_response(self, data: Dict[str, Any]) -> LLMResponse:
        """Parse an API response."""
        try:
            choice = data.get("choices", [{}])[0]
            message = choice.get("message", {})
            content = message.get("content", "") or ""
            tool_calls = message.get("tool_calls", [])
            finish_reason = choice.get("finish_reason", "")
            usage = data.get("usage", {})
            # Normalize tool_calls
            parsed_tool_calls = []
            for tc in tool_calls:
                parsed_tool_calls.append({
                    "id": tc.get("id", ""),
                    "type": tc.get("type", "function"),
                    "function": {
                        "name": tc.get("function", {}).get("name", ""),
                        "arguments": tc.get("function", {}).get("arguments", "{}"),
                    }
                })
            return LLMResponse(
                content=content,
                tool_calls=parsed_tool_calls,
                finish_reason=finish_reason,
                usage=usage,
                raw_response=data,
            )
        except Exception as e:
            logger.error(f"[LLMClient] Failed to parse response: {e}")
            return LLMResponse(error=f"Failed to parse response: {e}")

    # ==================== Token estimation ====================

    @staticmethod
    def estimate_tokens(text: str) -> int:
        """
        Estimate the token count of a text.

        Uses simplified heuristics:
        - English: ~4 characters = 1 token
        - Chinese: ~1.5 characters = 1 token
        """
        if not text:
            return 0
        chinese_chars = sum(1 for c in text if '\u4e00' <= c <= '\u9fff')
        other_chars = len(text) - chinese_chars
        chinese_tokens = chinese_chars / 1.5
        other_tokens = other_chars / 4
        return int(chinese_tokens + other_tokens)

    @staticmethod
    def estimate_message_tokens(message: Dict[str, Any]) -> int:
        """Estimate the token count of a single message."""
        content = message.get("content", "")
        if isinstance(content, str):
            return LLMClient.estimate_tokens(content) + 4  # role and framing overhead
        if isinstance(content, list):
            total = 4
            for item in content:
                if isinstance(item, dict):
                    if item.get("type") == "text":
                        total += LLMClient.estimate_tokens(item.get("text", ""))
                    elif item.get("type") == "image_url":
                        total += 85  # fixed per-image overhead
            return total
        return 4

    @staticmethod
    def estimate_messages_tokens(messages: List[Dict[str, Any]]) -> int:
        """Estimate the total token count of a message list."""
        return sum(LLMClient.estimate_message_tokens(m) for m in messages)


# ==================== Convenience functions ====================

_default_client: Optional[LLMClient] = None


def get_llm_client(config: Optional[LLMConfig] = None) -> LLMClient:
    """Get the default LLM client."""
    global _default_client
    if config:
        _default_client = LLMClient(config)
    if _default_client is None:
        raise ValueError("LLM client not initialized; pass a config first")
    return _default_client


# ==================== Exports ====================

__all__ = [
    "LLMConfig",
    "LLMResponse",
    "LLMClient",
    "get_llm_client",
]
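
if __name__ == "__main__":
    # Minimal usage sketch. Assumes an OpenAI-compatible endpoint is reachable
    # and that the placeholder key below is replaced with a real one.
    async def _demo() -> None:
        config = LLMConfig(
            api_base="https://api.openai.com/v1",
            api_key="sk-xxx",  # placeholder
            model="gpt-4",
        )
        client = get_llm_client(config)
        try:
            messages = [{"role": "user", "content": "Hello"}]
            print(f"~{LLMClient.estimate_messages_tokens(messages)} prompt tokens")
            response = await client.chat_completion(messages=messages)
            print(response.content if response.success else f"Error: {response.error}")
        finally:
            await client.close()

    asyncio.run(_demo())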