"""
MCP client wrappers.

Implements Model Context Protocol (MCP) clients over two transports:
- stdio: talk to a child process through its stdin/stdout
- http: talk to a server over HTTP, with plain JSON or SSE responses

Provides:
- server connection management (connect / disconnect)
- tool discovery and execution
- JSON-RPC 2.0 messaging
"""

import asyncio
import json
import os
import sys
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Literal
from abc import ABC, abstractmethod

from loguru import logger

# aiohttp is optional: only the HTTP transport needs it.
try:
    import aiohttp
    AIOHTTP_AVAILABLE = True
except ImportError:
    AIOHTTP_AVAILABLE = False


# Longest single stdout/stderr line accepted from a stdio server. asyncio's
# default stream limit (64 KiB) is easily exceeded by one JSON-RPC message
# carrying a large tool result, and readline() raises when that happens.
_STDIO_LINE_LIMIT = 16 * 1024 * 1024

# Per-request timeout (seconds) used by the stdio client when none is given.
_DEFAULT_REQUEST_TIMEOUT = 60.0


def _qualified_tool_name(name: str, prefix: str = "") -> str:
    """Return the tool name exposed to the model.

    The optional ``prefix`` is joined with ``_`` and every ``-`` / ``.`` is
    replaced by ``_``. Schema generation and the manager's routing table both
    use this helper so that the two always agree.
    """
    full_name = f"{prefix}_{name}" if prefix else name
    return full_name.replace("-", "_").replace(".", "_")


def _jsonrpc_error_text(error: Any) -> Any:
    """Extract a printable message from a JSON-RPC ``error`` member."""
    if isinstance(error, dict):
        return error.get("message", "未知错误")
    # Non-conforming servers sometimes send a bare string instead of an object.
    return str(error) if error else "未知错误"


def _jsonrpc_to_result(message: Any) -> Dict[str, Any]:
    """Map a decoded JSON-RPC response to ``{"error": ...}`` or ``{"result": ...}``."""
    if not isinstance(message, dict):
        return {"error": "无效的 JSON-RPC 响应"}
    # Treat an explicit ``"error": null`` as "no error".
    if message.get("error") is not None:
        return {"error": _jsonrpc_error_text(message["error"])}
    return {"result": message.get("result")}


def _content_to_texts(content: Any) -> List[str]:
    """Render a tool-result ``content`` list as display strings.

    Text items contribute their text; images and resources are replaced by
    short placeholders. Items of an unknown shape are skipped.
    """
    texts: List[str] = []
    if not isinstance(content, list):
        return texts
    for item in content:
        if isinstance(item, str):
            texts.append(item)
            continue
        if not isinstance(item, dict):
            continue
        kind = item.get("type")
        if kind == "text":
            texts.append(item.get("text", ""))
        elif kind == "image":
            texts.append(f"[图片: {item.get('mimeType', 'image')}]")
        elif kind == "resource":
            # Embedded resources nest the URI under "resource"; a flat "uri"
            # key is still honoured for servers that send it that way.
            uri = item.get("uri")
            nested = item.get("resource")
            if not uri and isinstance(nested, dict):
                uri = nested.get("uri")
            texts.append(f"[资源: {uri or ''}]")
        elif "text" in item:
            texts.append(item.get("text", ""))
    return texts


def _parse_sse_body(text: str, request_id: int) -> Dict[str, Any]:
    """Pick the JSON-RPC response for ``request_id`` out of an SSE body.

    Each ``data:`` line is decoded on its own. Events that are not responses
    (no ``result`` / ``error`` member, e.g. notifications) are skipped. The
    response whose ``id`` matches wins; failing that, the first response seen
    is used.
    """
    fallback: Optional[Dict[str, Any]] = None
    for line in text.split("\n"):
        if not line.startswith("data:"):
            continue
        data_str = line[5:].strip()
        if not data_str:
            continue
        try:
            data = json.loads(data_str)
        except json.JSONDecodeError:
            continue
        if not isinstance(data, dict):
            continue
        if "result" not in data and "error" not in data:
            continue
        if data.get("id") == request_id:
            return _jsonrpc_to_result(data)
        if fallback is None:
            fallback = data
    if fallback is not None:
        return _jsonrpc_to_result(fallback)
    return {"error": "无法解析 SSE 响应"}


@dataclass
class MCPServerConfig:
    """Configuration of one MCP server."""
    name: str

    # Transport type: "stdio" or "http".
    transport: Literal["stdio", "http"] = "stdio"

    # stdio transport settings.
    command: str = ""
    args: List[str] = field(default_factory=list)
    working_dir: str = ""

    # http transport settings.
    url: str = ""
    headers: Dict[str, str] = field(default_factory=dict)

    # Settings shared by both transports.
    env: Dict[str, str] = field(default_factory=dict)
    enabled: bool = True
    tool_prefix: str = ""

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "MCPServerConfig":
        """Build a config from a plain dict.

        The transport is read from ``type`` (falling back to ``transport``,
        then ``"stdio"``) and forced to ``"http"`` whenever a non-empty
        ``url`` is present. Container values are copied, and ``None`` is
        treated as empty, so the config never aliases the caller's data.
        """
        transport = data.get("type", data.get("transport", "stdio"))
        if data.get("url"):
            transport = "http"

        return cls(
            name=data.get("name", ""),
            transport=transport,
            command=data.get("command", ""),
            args=list(data.get("args") or []),
            working_dir=data.get("working_dir", ""),
            url=data.get("url", ""),
            headers=dict(data.get("headers") or {}),
            env=dict(data.get("env") or {}),
            enabled=data.get("enabled", True),
            tool_prefix=data.get("tool_prefix", ""),
        )


@dataclass
class MCPTool:
    """Definition of one tool advertised by an MCP server."""
    name: str
    description: str
    input_schema: Dict[str, Any]
    server_name: str  # name of the server that owns the tool

    def to_openai_schema(self, prefix: str = "") -> Dict[str, Any]:
        """Return this tool as an OpenAI-compatible function schema.

        Args:
            prefix: optional prefix joined to the tool name with ``_``.
        """
        return {
            "type": "function",
            "function": {
                "name": _qualified_tool_name(self.name, prefix),
                "description": self.description or f"MCP tool: {self.name}",
                "parameters": self.input_schema or {"type": "object", "properties": {}}
            }
        }


class _MCPClientBase(ABC):
    """Transport-independent part of an MCP client.

    Holds the state common to both transports and implements the handshake
    and tool discovery on top of the transport's ``_send_request`` and
    ``_send_notification``.
    """

    def __init__(self, config: MCPServerConfig, start_timeout: float = 30.0):
        self.config = config
        self.start_timeout = start_timeout
        self.tools: Dict[str, MCPTool] = {}
        self._request_id = 0
        self._initialized = False
        self._lock = asyncio.Lock()

    @property
    def name(self) -> str:
        return self.config.name

    @property
    @abstractmethod
    def is_connected(self) -> bool:
        """Whether the client can currently serve tool calls."""

    @abstractmethod
    async def connect(self) -> bool:
        """Connect to the server; return whether it succeeded."""

    @abstractmethod
    async def disconnect(self):
        """Drop the connection and forget the discovered tools."""

    @abstractmethod
    async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Dict[str, Any]:
        """Invoke a tool by its original (unprefixed) name."""

    @abstractmethod
    async def _send_request(self, method: str, params: Dict[str, Any]) -> Dict[str, Any]:
        """Send a JSON-RPC request; return ``{"result": ...}`` or ``{"error": ...}``."""

    @abstractmethod
    async def _send_notification(self, method: str, params: Dict[str, Any]):
        """Send a JSON-RPC notification (no response expected)."""

    async def _initialize(self) -> bool:
        """Send ``initialize`` followed by the ``initialized`` notification."""
        result = await self._send_request("initialize", {
            "protocolVersion": "2024-11-05",
            "capabilities": {"tools": {}},
            "clientInfo": {
                "name": "WechatHookBot-MCPManager",
                "version": "1.0.0"
            }
        })

        if "error" in result:
            logger.error(f"[MCP] {self.name} initialize 失败: {result['error']}")
            return False

        await self._send_notification("notifications/initialized", {})
        return True

    async def _list_tools(self):
        """Fetch the server's tool list into ``self.tools``.

        Follows ``nextCursor`` pagination. If the first page fails,
        ``self.tools`` is left untouched; if a later page fails, the pages
        already fetched are kept.
        """
        discovered: Dict[str, MCPTool] = {}
        seen_cursors = set()
        cursor = None
        first_page = True

        while True:
            params = {"cursor": cursor} if cursor else {}
            result = await self._send_request("tools/list", params)

            if "error" in result:
                logger.warning(f"[MCP] {self.name} 获取工具列表失败: {result['error']}")
                if first_page:
                    return
                break
            first_page = False

            payload = result.get("result")
            if not isinstance(payload, dict):
                payload = {}

            for tool_data in payload.get("tools") or []:
                if not isinstance(tool_data, dict):
                    continue
                tool = MCPTool(
                    name=tool_data.get("name", ""),
                    description=tool_data.get("description", ""),
                    input_schema=tool_data.get("inputSchema", {}),
                    server_name=self.name
                )
                discovered[tool.name] = tool
                logger.debug(f"[MCP] {self.name} 发现工具: {tool.name}")

            cursor = payload.get("nextCursor")
            # Stop on the last page, or if the server hands back a cursor it
            # already gave us (which would otherwise loop forever).
            if not isinstance(cursor, str) or not cursor or cursor in seen_cursors:
                break
            seen_cursors.add(cursor)

        self.tools.clear()
        self.tools.update(discovered)


class MCPClient(_MCPClientBase):
    """
    MCP client for the stdio transport.

    Manages the connection to a single MCP server running as a child process
    and exchanges newline-delimited JSON-RPC messages over its stdin/stdout.
    """

    def __init__(
        self,
        config: MCPServerConfig,
        start_timeout: float = 30.0,
        request_timeout: Optional[float] = None,
    ):
        """
        Args:
            config: server configuration.
            start_timeout: seconds allowed for the ``initialize`` handshake.
            request_timeout: seconds to wait for each response; defaults to
                60 seconds.
        """
        super().__init__(config, start_timeout)
        self.request_timeout = (
            _DEFAULT_REQUEST_TIMEOUT if request_timeout is None else request_timeout
        )
        self.process: Optional[asyncio.subprocess.Process] = None
        self._pending_requests: Dict[int, asyncio.Future] = {}
        self._read_task: Optional[asyncio.Task] = None
        self._stderr_task: Optional[asyncio.Task] = None

    @property
    def is_connected(self) -> bool:
        return self.process is not None and self.process.returncode is None

    async def connect(self) -> bool:
        """
        Start the server process and perform the MCP handshake.

        Returns:
            Whether the connection succeeded.
        """
        async with self._lock:
            # Checked under the lock so that concurrent callers wait for the
            # in-flight attempt instead of racing it.
            if self.is_connected:
                return True

            try:
                if self.process is not None:
                    # A previous server process has exited: release its tasks
                    # and pending requests before starting a new one.
                    await self._teardown()

                env = os.environ.copy()
                env.update(self.config.env)

                cwd = self.config.working_dir if self.config.working_dir else None

                cmd = [self.config.command] + self.config.args

                logger.info(f"[MCP] 启动服务器 {self.name}: {' '.join(cmd)}")

                self.process = await asyncio.create_subprocess_exec(
                    *cmd,
                    stdin=asyncio.subprocess.PIPE,
                    stdout=asyncio.subprocess.PIPE,
                    stderr=asyncio.subprocess.PIPE,
                    env=env,
                    cwd=cwd,
                    limit=_STDIO_LINE_LIMIT,
                    # Non-zero only on Windows (hides the console window).
                    creationflags=subprocess_creation_flags(),
                )

                self._read_task = asyncio.create_task(self._read_loop())
                # stderr must be consumed continuously: once its pipe buffer
                # fills up, the server blocks on its next write.
                self._stderr_task = asyncio.create_task(self._drain_stderr())

                init_result = await asyncio.wait_for(
                    self._initialize(),
                    timeout=self.start_timeout
                )

                if not init_result:
                    await self.disconnect()
                    return False

                await self._list_tools()

                self._initialized = True
                logger.success(f"[MCP] 服务器 {self.name} 已连接,发现 {len(self.tools)} 个工具")
                return True

            except asyncio.TimeoutError:
                logger.error(f"[MCP] 服务器 {self.name} 启动超时")
                await self.disconnect()
                return False
            except Exception as e:
                logger.error(f"[MCP] 服务器 {self.name} 连接失败: {e}")
                await self.disconnect()
                return False

    async def disconnect(self):
        """Stop the server process and release all connection state."""
        await self._teardown()
        logger.info(f"[MCP] 服务器 {self.name} 已断开")

    async def _teardown(self):
        """Cancel the reader tasks, stop the process, and fail pending requests."""
        for task in (self._read_task, self._stderr_task):
            if task is None:
                continue
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
        self._read_task = None
        self._stderr_task = None

        proc, self.process = self.process, None
        if proc is not None and proc.returncode is None:
            try:
                proc.terminate()
                await asyncio.wait_for(proc.wait(), timeout=5.0)
            except asyncio.TimeoutError:
                # Graceful shutdown took too long: force-kill and reap the
                # process so it does not linger as a zombie.
                try:
                    proc.kill()
                    await asyncio.wait_for(proc.wait(), timeout=5.0)
                except (ProcessLookupError, asyncio.TimeoutError):
                    pass
            except ProcessLookupError:
                pass  # the process exited on its own in the meantime
            except Exception as e:
                logger.debug(f"[MCP] 服务器 {self.name} 终止进程时出错: {e}")

        self._initialized = False
        # Wake up callers still waiting for a response instead of letting
        # them run into their timeout.
        self._fail_pending("MCP 服务器已断开")
        self.tools.clear()

    def _fail_pending(self, reason: str):
        """Resolve every outstanding request with an error result."""
        pending, self._pending_requests = self._pending_requests, {}
        for future in pending.values():
            if not future.done():
                future.set_result({"error": reason})

    async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Dict[str, Any]:
        """
        Invoke an MCP tool.

        Args:
            tool_name: tool name without prefix.
            arguments: tool arguments.

        Returns:
            ``{"success": True, "message": ..., "data": ...}`` on success,
            ``{"success": False, "error": ...}`` on failure, including when
            the tool result carries ``isError``.
        """
        if not self.is_connected:
            return {"success": False, "error": "MCP 服务器未连接"}

        result = await self._send_request("tools/call", {
            "name": tool_name,
            "arguments": arguments
        })

        if "error" in result:
            return {"success": False, "error": result["error"]}

        payload = result.get("result")
        if isinstance(payload, dict):
            content = payload.get("content") or []
            is_error = bool(payload.get("isError", False))
        else:
            content, is_error = [], False

        texts = _content_to_texts(content)

        if is_error:
            return {"success": False, "error": "\n".join(texts) or "工具执行失败"}

        if content:
            return {
                "success": True,
                "message": "\n".join(texts) if texts else "执行成功",
                "data": content
            }

        return {"success": True, "message": "执行成功", "data": payload}

    async def _write_message(self, payload: Dict[str, Any]):
        """Write one JSON-RPC message as a newline-terminated line to stdin."""
        stdin = self.process.stdin
        stdin.write((json.dumps(payload) + "\n").encode())
        await stdin.drain()

    async def _send_request(self, method: str, params: Dict[str, Any]) -> Dict[str, Any]:
        """Send a JSON-RPC request and wait for the matching response."""
        if not self.process or not self.process.stdin:
            return {"error": "进程未启动"}

        self._request_id += 1
        request_id = self._request_id

        request = {
            "jsonrpc": "2.0",
            "id": request_id,
            "method": method,
            "params": params
        }

        # The read loop resolves this future when the response arrives.
        future: asyncio.Future = asyncio.get_running_loop().create_future()
        self._pending_requests[request_id] = future

        try:
            await self._write_message(request)
            return await asyncio.wait_for(future, timeout=self.request_timeout)
        except asyncio.TimeoutError:
            return {"error": f"请求超时: {method}"}
        except Exception as e:
            return {"error": str(e)}
        finally:
            # Also runs on cancellation, so an abandoned request never leaves
            # a stale entry behind.
            self._pending_requests.pop(request_id, None)

    async def _send_notification(self, method: str, params: Dict[str, Any]):
        """Send a JSON-RPC notification (no response expected)."""
        if not self.process or not self.process.stdin:
            return

        notification = {
            "jsonrpc": "2.0",
            "method": method,
            "params": params
        }

        try:
            await self._write_message(notification)
        except Exception as e:
            logger.warning(f"[MCP] {self.name} 发送通知失败: {e}")

    async def _read_loop(self):
        """Read newline-delimited JSON messages from stdout until EOF."""
        if not self.process or not self.process.stdout:
            return

        stdout = self.process.stdout
        try:
            while True:
                try:
                    line = await stdout.readline()
                except ValueError as e:
                    # Line longer than the stream limit: drop it, keep reading.
                    logger.warning(f"[MCP] {self.name} 读取错误: {e}")
                    continue

                if not line:
                    break

                try:
                    message = json.loads(line.decode().strip())
                except ValueError:
                    # Covers both invalid JSON and undecodable bytes, which
                    # is most likely plain log output from the server.
                    continue

                try:
                    await self._handle_message(message)
                except Exception as e:
                    # One malformed message must not stop the reader.
                    logger.warning(f"[MCP] {self.name} 处理消息失败: {e}")

        except asyncio.CancelledError:
            raise
        except Exception as e:
            logger.error(f"[MCP] {self.name} 读取错误: {e}")
        finally:
            # Nothing can answer outstanding requests once the reader stops.
            self._fail_pending("MCP 服务器连接已关闭")

    async def _drain_stderr(self):
        """Forward the server's stderr to the debug log."""
        if not self.process or not self.process.stderr:
            return

        stderr = self.process.stderr
        try:
            while True:
                try:
                    line = await stderr.readline()
                except ValueError:
                    continue  # over-long line, skip it
                if not line:
                    break
                text = line.decode(errors="replace").rstrip()
                if text:
                    logger.debug(f"[MCP] {self.name} stderr: {text}")
        except asyncio.CancelledError:
            raise
        except Exception as e:
            logger.debug(f"[MCP] {self.name} stderr 读取错误: {e}")

    async def _handle_message(self, message: Dict[str, Any]):
        """Dispatch one decoded message received from the server."""
        if not isinstance(message, dict):
            return

        method = message.get("method")

        if method is None:
            # No "method": this is a response to one of our requests.
            request_id = message.get("id")
            if isinstance(request_id, (int, str)):
                future = self._pending_requests.pop(request_id, None)
                if future and not future.done():
                    future.set_result(_jsonrpc_to_result(message))
            return

        if "id" in message:
            # "method" plus "id": a request from the server. It must not be
            # mistaken for a response, because its id could coincide with one
            # of our own pending request ids.
            await self._reply_to_server_request(message)
            return

        # Notification (log, progress, ...).
        if method == "notifications/message":
            params = message.get("params") or {}
            level = params.get("level", "info")
            data = params.get("data", "")
            logger.debug(f"[MCP] {self.name} [{level}]: {data}")

    async def _reply_to_server_request(self, message: Dict[str, Any]):
        """Answer a server request: ``ping`` succeeds, anything else is rejected."""
        if not self.process or not self.process.stdin:
            return

        reply: Dict[str, Any] = {"jsonrpc": "2.0", "id": message.get("id")}
        method = message.get("method")
        if method == "ping":
            reply["result"] = {}
        else:
            # -32601 is the JSON-RPC "method not found" code.
            reply["error"] = {"code": -32601, "message": f"Method not found: {method}"}

        try:
            await self._write_message(reply)
        except Exception as e:
            logger.warning(f"[MCP] {self.name} 发送通知失败: {e}")


def subprocess_creation_flags() -> int:
    """Return the subprocess creation flags: ``CREATE_NO_WINDOW`` on Windows, else 0."""
    if sys.platform == "win32":
        import subprocess
        return subprocess.CREATE_NO_WINDOW
    return 0


class MCPManager:
    """
    MCP manager.

    Owns the connections to several MCP servers and routes tool calls to the
    server that registered each tool.
    """

    def __init__(self, tool_timeout: float = 60.0, server_start_timeout: float = 30.0):
        """
        Args:
            tool_timeout: seconds allowed for each request once a server is
                connected; handed to every client as its request timeout.
            server_start_timeout: seconds allowed for a server's handshake.
        """
        self.tool_timeout = tool_timeout
        self.server_start_timeout = server_start_timeout
        self.clients: Dict[str, _MCPClientBase] = {}
        self._tool_to_server: Dict[str, str] = {}  # exposed tool name -> server name
        self._tool_original_name: Dict[str, str] = {}  # exposed tool name -> original tool name

    async def add_server(self, config: MCPServerConfig) -> bool:
        """
        Add an MCP server and connect to it.

        Args:
            config: server configuration.

        Returns:
            Whether the server was added. ``False`` when it is disabled,
            already registered, or the connection failed.
        """
        if not config.enabled:
            logger.info(f"[MCP] 服务器 {config.name} 已禁用,跳过")
            return False

        if config.name in self.clients:
            logger.warning(f"[MCP] 服务器 {config.name} 已存在")
            return False

        client = create_mcp_client(config, self.server_start_timeout, self.tool_timeout)
        if not await client.connect():
            return False

        self.clients[config.name] = client

        for tool_name in client.tools:
            exposed = _qualified_tool_name(tool_name, config.tool_prefix)
            owner = self._tool_to_server.get(exposed)
            if owner is not None:
                # The last registration wins; make the shadowing visible.
                logger.warning(
                    f"[MCP] 工具名冲突: {exposed} 已被服务器 {owner} 注册,"
                    f"将由 {config.name} 的 {tool_name} 覆盖"
                )
            self._tool_to_server[exposed] = config.name
            self._tool_original_name[exposed] = tool_name

        return True

    async def remove_server(self, name: str):
        """Disconnect a server and drop its tool routes."""
        client = self.clients.pop(name, None)
        if client:
            await client.disconnect()

        stale = [k for k, v in self._tool_to_server.items() if v == name]
        for key in stale:
            self._tool_to_server.pop(key, None)
            self._tool_original_name.pop(key, None)

    async def shutdown(self):
        """Disconnect every server."""
        for name in list(self.clients.keys()):
            await self.remove_server(name)

    def get_all_tools(self) -> List[Dict[str, Any]]:
        """Return the OpenAI schema of every tool of every server."""
        return [
            tool.to_openai_schema(client.config.tool_prefix)
            for client in self.clients.values()
            for tool in client.tools.values()
        ]

    async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Dict[str, Any]:
        """
        Invoke a tool.

        Args:
            tool_name: exposed tool name (possibly prefixed).
            arguments: tool arguments.

        Returns:
            The result dict produced by the owning client, or a failure dict
            when the tool or its server is unknown.
        """
        server_name = self._tool_to_server.get(tool_name)
        if not server_name:
            return {"success": False, "error": f"工具 {tool_name} 不存在"}

        client = self.clients.get(server_name)
        if not client:
            return {"success": False, "error": f"服务器 {server_name} 未连接"}

        # The server only knows the tool by its original, unprefixed name.
        original_name = self._tool_original_name.get(tool_name, tool_name)

        return await client.call_tool(original_name, arguments)

    def list_servers(self) -> List[Dict[str, Any]]:
        """Return the name, connection state, and tools of every server."""
        return [
            {
                "name": name,
                "connected": client.is_connected,
                "tools_count": len(client.tools),
                "tools": list(client.tools.keys())
            }
            for name, client in self.clients.items()
        ]


class MCPHttpClient(_MCPClientBase):
    """
    MCP client for the HTTP transport.

    Posts JSON-RPC messages to the configured URL and accepts either a JSON
    body or an SSE body in response.
    """

    def __init__(
        self,
        config: MCPServerConfig,
        start_timeout: float = 30.0,
        request_timeout: Optional[float] = None,
    ):
        """
        Args:
            config: server configuration.
            start_timeout: seconds allowed for each request during connect.
            request_timeout: seconds allowed for each request once connected;
                defaults to ``start_timeout``.
        """
        super().__init__(config, start_timeout)
        self.request_timeout = start_timeout if request_timeout is None else request_timeout
        self._session: Optional["aiohttp.ClientSession"] = None
        # Value of the Mcp-Session-Id response header, if the server sent one.
        self._session_id: Optional[str] = None

    @property
    def is_connected(self) -> bool:
        return self._initialized and self._session is not None and not self._session.closed

    async def connect(self) -> bool:
        """Open the HTTP session, run the handshake, and discover tools.

        A failed ``initialize`` is tolerated; the connection counts as
        successful as long as at least one tool is discovered.
        """
        if not AIOHTTP_AVAILABLE:
            logger.error("[MCP] HTTP 客户端需要 aiohttp,请安装: pip install aiohttp")
            return False

        async with self._lock:
            # Checked under the lock so that a concurrent connect() cannot
            # replace (and leak) a session that was just opened.
            if self.is_connected:
                return True

            try:
                timeout = aiohttp.ClientTimeout(total=self.start_timeout)
                self._session = aiohttp.ClientSession(timeout=timeout)

                logger.info(f"[MCP] 连接 HTTP 服务器 {self.name}: {self.config.url}")

                init_result = await self._initialize()
                if not init_result:
                    logger.warning(f"[MCP] {self.name} initialize 失败,尝试直接获取工具列表")

                await self._list_tools()

                if self.tools:
                    self._initialized = True
                    logger.success(f"[MCP] HTTP 服务器 {self.name} 已连接,发现 {len(self.tools)} 个工具")
                    return True

                logger.error(f"[MCP] HTTP 服务器 {self.name} 未发现任何工具")
                await self.disconnect()
                return False

            except asyncio.TimeoutError:
                logger.error(f"[MCP] HTTP 服务器 {self.name} 连接超时")
                await self.disconnect()
                return False
            except Exception as e:
                logger.error(f"[MCP] HTTP 服务器 {self.name} 连接失败: {e}")
                await self.disconnect()
                return False

    async def disconnect(self):
        """Close the HTTP session and forget the discovered tools."""
        if self._session and not self._session.closed:
            await self._session.close()
        self._session = None
        self._session_id = None
        self._initialized = False
        self.tools.clear()
        logger.info(f"[MCP] HTTP 服务器 {self.name} 已断开")

    async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Dict[str, Any]:
        """
        Invoke an MCP tool.

        Args:
            tool_name: tool name without prefix.
            arguments: tool arguments.

        Returns:
            ``{"success": True, "message": ..., "data": ...}`` on success,
            ``{"success": False, "error": ...}`` on failure.
        """
        if not self.is_connected:
            return {"success": False, "error": "MCP 服务器未连接"}

        logger.info(f"[MCP] {self.name} 调用工具: {tool_name}, 参数: {arguments}")

        result = await self._send_request("tools/call", {
            "name": tool_name,
            "arguments": arguments
        })

        logger.debug(f"[MCP] {self.name} 工具原始结果: {str(result)[:500]}")

        if "error" in result:
            return {"success": False, "error": result["error"]}

        tool_result = result.get("result", {})
        if isinstance(tool_result, dict):
            content = tool_result.get("content") or []
            is_error = bool(tool_result.get("isError", False))
        elif isinstance(tool_result, list):
            # Some servers return the content list itself as the result.
            content, is_error = tool_result, False
        else:
            content, is_error = [], False

        texts = _content_to_texts(content)

        if is_error:
            # Honoured even when the content is empty or has no text.
            message = "\n".join(texts) or "工具执行失败"
            logger.warning(f"[MCP] {self.name} 工具返回错误: {message[:200]}")
            return {"success": False, "error": message}

        if content:
            message = "\n".join(texts) if texts else "执行成功"

            # Heuristic: a reply that starts with an error text is reported
            # as a failure even without the isError flag.
            if "error" in message.lower()[:50]:
                logger.warning(f"[MCP] {self.name} 工具返回错误: {message[:200]}")
                return {"success": False, "error": message}

            logger.info(f"[MCP] {self.name} 工具返回: {message[:200]}...")
            return {
                "success": True,
                "message": message,
                "data": content
            }

        # No content: fall back to the raw result.
        if tool_result:
            if isinstance(tool_result, dict):
                message = json.dumps(tool_result, ensure_ascii=False)
            else:
                message = str(tool_result)
            logger.info(f"[MCP] {self.name} 工具返回(原始): {message[:200]}...")
            return {"success": True, "message": message, "data": tool_result}

        return {"success": True, "message": "执行成功", "data": result.get("result")}

    def _build_headers(self) -> Dict[str, str]:
        """Build the headers shared by requests and notifications."""
        headers = {
            "Content-Type": "application/json",
            "Accept": "application/json, text/event-stream",
        }

        for key, value in self.config.headers.items():
            if key.lower() == "authorization":
                # Add the "Bearer " prefix when the configured value lacks it.
                if not value.startswith("Bearer "):
                    headers["Authorization"] = f"Bearer {value}"
                else:
                    headers["Authorization"] = value
            else:
                headers[key] = value

        if self._session_id:
            # Echo the session id the server assigned to us.
            headers["Mcp-Session-Id"] = self._session_id

        return headers

    async def _send_request(self, method: str, params: Dict[str, Any]) -> Dict[str, Any]:
        """POST a JSON-RPC request and return ``{"result": ...}`` or ``{"error": ...}``."""
        if not self._session:
            return {"error": "会话未创建"}

        self._request_id += 1
        request_id = self._request_id

        request = {
            "jsonrpc": "2.0",
            "id": request_id,
            "method": method,
            "params": params
        }

        # Requests made while connecting are bounded by start_timeout; once
        # connected, each request gets request_timeout.
        budget = self.request_timeout if self._initialized else self.start_timeout

        logger.debug(f"[MCP] {self.name} 发送请求: {method}, params: {params}")

        try:
            async with self._session.post(
                self.config.url,
                json=request,
                headers=self._build_headers(),
                timeout=aiohttp.ClientTimeout(total=budget),
            ) as response:
                session_id = response.headers.get("Mcp-Session-Id")
                if session_id:
                    self._session_id = session_id

                if response.status != 200:
                    text = await response.text()
                    logger.error(f"[MCP] {self.name} HTTP {response.status}: {text[:500]}")
                    return {"error": f"HTTP {response.status}: {text[:200]}"}

                content_type = response.headers.get("Content-Type", "")

                if "text/event-stream" in content_type:
                    # The whole body is read before parsing, so a stream the
                    # server keeps open is only ended by the timeout.
                    text = await response.text()
                    return _parse_sse_body(text, request_id)

                data = await response.json()
                return _jsonrpc_to_result(data)

        except asyncio.TimeoutError:
            return {"error": f"请求超时: {method}"}
        except Exception as e:
            logger.error(f"[MCP] {self.name} 请求异常: {e}")
            return {"error": str(e)}

    async def _send_notification(self, method: str, params: Dict[str, Any]):
        """POST a JSON-RPC notification; the response body is ignored."""
        if not self._session:
            return

        notification = {
            "jsonrpc": "2.0",
            "method": method,
            "params": params
        }

        try:
            async with self._session.post(
                self.config.url,
                json=notification,
                headers=self._build_headers(),
            ):
                pass  # notifications carry no response
        except Exception as e:
            logger.warning(f"[MCP] {self.name} 发送通知失败: {e}")


def create_mcp_client(
    config: MCPServerConfig,
    start_timeout: float = 30.0,
    request_timeout: Optional[float] = None,
):
    """
    Factory: create the MCP client matching the configured transport.

    Args:
        config: server configuration.
        start_timeout: seconds allowed for the handshake.
        request_timeout: seconds allowed per request once connected; ``None``
            keeps each client's own default.

    Returns:
        An ``MCPHttpClient`` when the transport is ``"http"`` or a URL is
        configured, otherwise an ``MCPClient``.
    """
    if config.transport == "http" or config.url:
        return MCPHttpClient(config, start_timeout, request_timeout)
    return MCPClient(config, start_timeout, request_timeout)