feat:mcp
This commit is contained in:
@@ -1548,7 +1548,14 @@ class AIChat(PluginBase):
|
||||
if content == clear_command:
|
||||
chat_id = self._get_chat_id(from_wxid, user_wxid, is_group)
|
||||
self._clear_memory(chat_id)
|
||||
await bot.send_text(from_wxid, "✅ 已清空当前会话的记忆")
|
||||
|
||||
# 如果是群聊,还需要清空群聊历史
|
||||
if is_group and self.store:
|
||||
history_chat_id = self._get_group_history_chat_id(from_wxid, user_wxid)
|
||||
await self.store.clear_group_history(history_chat_id)
|
||||
await bot.send_text(from_wxid, "✅ 已清空当前群聊的记忆和历史记录")
|
||||
else:
|
||||
await bot.send_text(from_wxid, "✅ 已清空当前会话的记忆")
|
||||
return False
|
||||
|
||||
# 检查是否是上下文统计指令
|
||||
|
||||
10
plugins/MCPManager/__init__.py
Normal file
10
plugins/MCPManager/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
"""
|
||||
MCPManager 插件
|
||||
|
||||
管理 MCP (Model Context Protocol) 服务器,
|
||||
自动将 MCP 工具注册到 ToolRegistry 供 AI 调用。
|
||||
"""
|
||||
|
||||
from .main import MCPManagerPlugin
|
||||
|
||||
__all__ = ["MCPManagerPlugin"]
|
||||
269
plugins/MCPManager/main.py
Normal file
269
plugins/MCPManager/main.py
Normal file
@@ -0,0 +1,269 @@
|
||||
"""
|
||||
MCP 管理插件
|
||||
|
||||
管理 MCP (Model Context Protocol) 服务器,将 MCP 工具自动注册到 ToolRegistry,
|
||||
使 AI 可以调用各种 MCP 服务器提供的工具。
|
||||
|
||||
功能:
|
||||
- 自动连接配置的 MCP 服务器
|
||||
- 将 MCP 工具转换为 OpenAI 格式并注册
|
||||
- 支持热重载(禁用/启用插件时自动管理连接)
|
||||
- 提供管理命令(查看状态、重连等)
|
||||
"""
|
||||
|
||||
import tomllib
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
from loguru import logger
|
||||
|
||||
from utils.plugin_base import PluginBase
|
||||
from utils.decorators import on_text_message
|
||||
from utils.tool_registry import get_tool_registry
|
||||
|
||||
from .mcp_client import MCPManager, MCPServerConfig
|
||||
|
||||
|
||||
class MCPManagerPlugin(PluginBase):
|
||||
"""MCP 管理插件"""
|
||||
|
||||
description = "MCP 服务器管理,自动注册 MCP 工具到 AI"
|
||||
author = "ShiHao"
|
||||
version = "1.0.0"
|
||||
|
||||
# 高优先级加载,确保在其他插件之前注册工具
|
||||
load_priority = 90
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.config = None
|
||||
self.mcp_manager: MCPManager = None
|
||||
self._registered_tools: List[str] = [] # 已注册的工具名列表
|
||||
|
||||
async def async_init(self):
|
||||
"""插件异步初始化"""
|
||||
# 读取配置
|
||||
config_path = Path(__file__).parent / "config.toml"
|
||||
with open(config_path, "rb") as f:
|
||||
self.config = tomllib.load(f)
|
||||
|
||||
mcp_config = self.config.get("mcp", {})
|
||||
|
||||
if not mcp_config.get("enabled", True):
|
||||
logger.info("MCPManager: MCP 功能已禁用")
|
||||
return
|
||||
|
||||
# 初始化 MCP 管理器
|
||||
self.mcp_manager = MCPManager(
|
||||
tool_timeout=mcp_config.get("tool_timeout", 60),
|
||||
server_start_timeout=mcp_config.get("server_start_timeout", 30)
|
||||
)
|
||||
|
||||
# 连接所有配置的服务器
|
||||
servers = mcp_config.get("servers", [])
|
||||
if not servers:
|
||||
logger.info("MCPManager: 未配置任何 MCP 服务器")
|
||||
return
|
||||
|
||||
if mcp_config.get("auto_connect", True):
|
||||
await self._connect_all_servers(servers)
|
||||
|
||||
logger.success(f"MCPManager 插件已加载,已连接 {len(self.mcp_manager.clients)} 个 MCP 服务器")
|
||||
|
||||
async def _connect_all_servers(self, servers: List[Dict]):
|
||||
"""连接所有配置的服务器"""
|
||||
for server_data in servers:
|
||||
try:
|
||||
config = MCPServerConfig.from_dict(server_data)
|
||||
success = await self.mcp_manager.add_server(config)
|
||||
|
||||
if success:
|
||||
# 注册工具到 ToolRegistry
|
||||
await self._register_server_tools(config.name)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"MCPManager: 连接服务器 {server_data.get('name', '未知')} 失败: {e}")
|
||||
|
||||
async def _register_server_tools(self, server_name: str):
|
||||
"""将服务器的工具注册到 ToolRegistry"""
|
||||
client = self.mcp_manager.clients.get(server_name)
|
||||
if not client:
|
||||
return
|
||||
|
||||
registry = get_tool_registry()
|
||||
prefix = client.config.tool_prefix
|
||||
|
||||
for tool in client.tools.values():
|
||||
schema = tool.to_openai_schema(prefix)
|
||||
tool_name = schema["function"]["name"]
|
||||
|
||||
# 创建工具执行器
|
||||
async def executor(name: str, arguments: Dict, bot, from_wxid: str, _tn=tool_name) -> Dict:
|
||||
return await self.mcp_manager.call_tool(_tn, arguments)
|
||||
|
||||
# 注册到 ToolRegistry
|
||||
success = registry.register(
|
||||
name=tool_name,
|
||||
plugin_name=f"MCP:{server_name}",
|
||||
schema=schema,
|
||||
executor=executor,
|
||||
timeout=self.mcp_manager.tool_timeout,
|
||||
priority=40 # 比普通插件优先级稍低
|
||||
)
|
||||
|
||||
if success:
|
||||
self._registered_tools.append(tool_name)
|
||||
logger.debug(f"MCPManager: 注册工具 {tool_name}")
|
||||
|
||||
logger.info(f"MCPManager: 从 {server_name} 注册了 {len(client.tools)} 个工具")
|
||||
|
||||
async def on_disable(self):
|
||||
"""插件禁用时清理"""
|
||||
await super().on_disable()
|
||||
|
||||
# 注销所有已注册的工具
|
||||
registry = get_tool_registry()
|
||||
for tool_name in self._registered_tools:
|
||||
registry.unregister(tool_name)
|
||||
self._registered_tools.clear()
|
||||
|
||||
# 关闭所有 MCP 服务器连接
|
||||
if self.mcp_manager:
|
||||
await self.mcp_manager.shutdown()
|
||||
self.mcp_manager = None
|
||||
|
||||
logger.info("MCPManager: 已清理所有 MCP 连接和工具注册")
|
||||
|
||||
# ==================== 管理命令 ====================
|
||||
|
||||
@on_text_message(priority=80)
|
||||
async def handle_admin_commands(self, bot, message: dict):
|
||||
"""处理管理命令"""
|
||||
content = message.get("Content", "").strip()
|
||||
|
||||
# 只响应管理员的命令
|
||||
# TODO: 从主配置读取管理员列表
|
||||
if not content.startswith("/mcp"):
|
||||
return True
|
||||
|
||||
parts = content.split()
|
||||
if len(parts) < 2:
|
||||
return True
|
||||
|
||||
cmd = parts[1].lower()
|
||||
reply_to = message.get("FromWxid", "")
|
||||
|
||||
if cmd == "status":
|
||||
await self._cmd_status(bot, reply_to)
|
||||
return False
|
||||
|
||||
elif cmd == "list":
|
||||
await self._cmd_list_tools(bot, reply_to)
|
||||
return False
|
||||
|
||||
elif cmd == "reload":
|
||||
await self._cmd_reload(bot, reply_to)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def _cmd_status(self, bot, reply_to: str):
|
||||
"""查看 MCP 服务器状态"""
|
||||
if not self.mcp_manager:
|
||||
await bot.send_text(reply_to, "MCP 功能未启用")
|
||||
return
|
||||
|
||||
servers = self.mcp_manager.list_servers()
|
||||
if not servers:
|
||||
await bot.send_text(reply_to, "没有已连接的 MCP 服务器")
|
||||
return
|
||||
|
||||
lines = ["📡 MCP 服务器状态:"]
|
||||
for s in servers:
|
||||
status = "✅" if s["connected"] else "❌"
|
||||
lines.append(f"{status} {s['name']}: {s['tools_count']} 个工具")
|
||||
|
||||
await bot.send_text(reply_to, "\n".join(lines))
|
||||
|
||||
async def _cmd_list_tools(self, bot, reply_to: str):
|
||||
"""列出所有 MCP 工具"""
|
||||
if not self.mcp_manager:
|
||||
await bot.send_text(reply_to, "MCP 功能未启用")
|
||||
return
|
||||
|
||||
servers = self.mcp_manager.list_servers()
|
||||
if not servers:
|
||||
await bot.send_text(reply_to, "没有已连接的 MCP 服务器")
|
||||
return
|
||||
|
||||
lines = ["🔧 MCP 工具列表:"]
|
||||
for s in servers:
|
||||
if s["tools"]:
|
||||
lines.append(f"\n【{s['name']}】")
|
||||
for tool in s["tools"][:10]: # 限制显示数量
|
||||
lines.append(f" • {tool}")
|
||||
if len(s["tools"]) > 10:
|
||||
lines.append(f" ... 还有 {len(s['tools']) - 10} 个")
|
||||
|
||||
await bot.send_text(reply_to, "\n".join(lines))
|
||||
|
||||
async def _cmd_reload(self, bot, reply_to: str):
|
||||
"""重新加载 MCP 服务器"""
|
||||
await bot.send_text(reply_to, "正在重新加载 MCP 服务器...")
|
||||
|
||||
# 清理现有连接
|
||||
if self.mcp_manager:
|
||||
registry = get_tool_registry()
|
||||
for tool_name in self._registered_tools:
|
||||
registry.unregister(tool_name)
|
||||
self._registered_tools.clear()
|
||||
await self.mcp_manager.shutdown()
|
||||
|
||||
# 重新读取配置
|
||||
config_path = Path(__file__).parent / "config.toml"
|
||||
with open(config_path, "rb") as f:
|
||||
self.config = tomllib.load(f)
|
||||
|
||||
mcp_config = self.config.get("mcp", {})
|
||||
|
||||
# 重新初始化
|
||||
self.mcp_manager = MCPManager(
|
||||
tool_timeout=mcp_config.get("tool_timeout", 60),
|
||||
server_start_timeout=mcp_config.get("server_start_timeout", 30)
|
||||
)
|
||||
|
||||
servers = mcp_config.get("servers", [])
|
||||
await self._connect_all_servers(servers)
|
||||
|
||||
await bot.send_text(
|
||||
reply_to,
|
||||
f"MCP 重新加载完成,已连接 {len(self.mcp_manager.clients)} 个服务器"
|
||||
)
|
||||
|
||||
# ==================== LLM 工具接口(备用) ====================
|
||||
|
||||
def get_llm_tools(self) -> List[Dict]:
|
||||
"""
|
||||
返回 MCP 工具列表(备用接口)
|
||||
|
||||
注意:工具已通过 ToolRegistry 注册,此方法仅供参考
|
||||
"""
|
||||
if not self.mcp_manager:
|
||||
return []
|
||||
return self.mcp_manager.get_all_tools()
|
||||
|
||||
async def execute_llm_tool(
|
||||
self,
|
||||
tool_name: str,
|
||||
arguments: Dict[str, Any],
|
||||
bot,
|
||||
from_wxid: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
执行 MCP 工具(备用接口)
|
||||
|
||||
注意:工具已通过 ToolRegistry 注册,此方法仅供备用
|
||||
"""
|
||||
if not self.mcp_manager:
|
||||
return {"success": False, "error": "MCP 功能未启用"}
|
||||
|
||||
return await self.mcp_manager.call_tool(tool_name, arguments)
|
||||
821
plugins/MCPManager/mcp_client.py
Normal file
821
plugins/MCPManager/mcp_client.py
Normal file
@@ -0,0 +1,821 @@
|
||||
"""
|
||||
MCP 客户端封装
|
||||
|
||||
实现 MCP (Model Context Protocol) 客户端,支持两种传输方式:
|
||||
- stdio: 通过子进程 stdin/stdout 通信
|
||||
- http: 通过 HTTP/SSE 通信
|
||||
|
||||
支持:
|
||||
- 服务器生命周期管理(启动、停止、重启)
|
||||
- 工具发现和执行
|
||||
- JSON-RPC 2.0 通信协议
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Literal
|
||||
from abc import ABC, abstractmethod
|
||||
from loguru import logger
|
||||
|
||||
# 可选导入 aiohttp
|
||||
try:
|
||||
import aiohttp
|
||||
AIOHTTP_AVAILABLE = True
|
||||
except ImportError:
|
||||
AIOHTTP_AVAILABLE = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class MCPServerConfig:
|
||||
"""MCP 服务器配置"""
|
||||
name: str
|
||||
# 传输类型: stdio 或 http
|
||||
transport: Literal["stdio", "http"] = "stdio"
|
||||
|
||||
# stdio 类型配置
|
||||
command: str = ""
|
||||
args: List[str] = field(default_factory=list)
|
||||
working_dir: str = ""
|
||||
|
||||
# http 类型配置
|
||||
url: str = ""
|
||||
headers: Dict[str, str] = field(default_factory=dict)
|
||||
|
||||
# 通用配置
|
||||
env: Dict[str, str] = field(default_factory=dict)
|
||||
enabled: bool = True
|
||||
tool_prefix: str = ""
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "MCPServerConfig":
|
||||
# 自动检测传输类型
|
||||
transport = data.get("type", data.get("transport", "stdio"))
|
||||
if data.get("url"):
|
||||
transport = "http"
|
||||
|
||||
return cls(
|
||||
name=data.get("name", ""),
|
||||
transport=transport,
|
||||
command=data.get("command", ""),
|
||||
args=data.get("args", []),
|
||||
working_dir=data.get("working_dir", ""),
|
||||
url=data.get("url", ""),
|
||||
headers=data.get("headers", {}),
|
||||
env=data.get("env", {}),
|
||||
enabled=data.get("enabled", True),
|
||||
tool_prefix=data.get("tool_prefix", ""),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MCPTool:
|
||||
"""MCP 工具定义"""
|
||||
name: str
|
||||
description: str
|
||||
input_schema: Dict[str, Any]
|
||||
server_name: str # 所属服务器名称
|
||||
|
||||
def to_openai_schema(self, prefix: str = "") -> Dict[str, Any]:
|
||||
"""转换为 OpenAI 兼容的工具 schema"""
|
||||
tool_name = f"{prefix}_{self.name}" if prefix else self.name
|
||||
# 清理工具名中的非法字符
|
||||
tool_name = tool_name.replace("-", "_").replace(".", "_")
|
||||
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_name,
|
||||
"description": self.description or f"MCP tool: {self.name}",
|
||||
"parameters": self.input_schema or {"type": "object", "properties": {}}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class MCPClient:
|
||||
"""
|
||||
MCP 客户端(stdio 类型)
|
||||
|
||||
管理与单个 MCP 服务器的连接和通信(通过子进程 stdio)
|
||||
"""
|
||||
|
||||
def __init__(self, config: MCPServerConfig, start_timeout: float = 30.0):
|
||||
self.config = config
|
||||
self.start_timeout = start_timeout
|
||||
self.process: Optional[asyncio.subprocess.Process] = None
|
||||
self.tools: Dict[str, MCPTool] = {}
|
||||
self._request_id = 0
|
||||
self._pending_requests: Dict[int, asyncio.Future] = {}
|
||||
self._read_task: Optional[asyncio.Task] = None
|
||||
self._initialized = False
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self.config.name
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self.process is not None and self.process.returncode is None
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""
|
||||
连接到 MCP 服务器
|
||||
|
||||
Returns:
|
||||
是否连接成功
|
||||
"""
|
||||
if self.is_connected:
|
||||
return True
|
||||
|
||||
async with self._lock:
|
||||
try:
|
||||
# 准备环境变量
|
||||
env = os.environ.copy()
|
||||
env.update(self.config.env)
|
||||
|
||||
# 准备工作目录
|
||||
cwd = self.config.working_dir if self.config.working_dir else None
|
||||
|
||||
# 构建完整命令
|
||||
cmd = [self.config.command] + self.config.args
|
||||
logger.info(f"[MCP] 启动服务器 {self.name}: {' '.join(cmd)}")
|
||||
|
||||
# 启动子进程
|
||||
self.process = await asyncio.create_subprocess_exec(
|
||||
*cmd,
|
||||
stdin=asyncio.subprocess.PIPE,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
env=env,
|
||||
cwd=cwd,
|
||||
# Windows 特殊处理
|
||||
creationflags=subprocess_creation_flags() if sys.platform == "win32" else 0,
|
||||
)
|
||||
|
||||
# 启动读取协程
|
||||
self._read_task = asyncio.create_task(self._read_loop())
|
||||
|
||||
# 发送 initialize 请求
|
||||
init_result = await asyncio.wait_for(
|
||||
self._initialize(),
|
||||
timeout=self.start_timeout
|
||||
)
|
||||
|
||||
if not init_result:
|
||||
await self.disconnect()
|
||||
return False
|
||||
|
||||
# 获取工具列表
|
||||
await self._list_tools()
|
||||
|
||||
self._initialized = True
|
||||
logger.success(f"[MCP] 服务器 {self.name} 已连接,发现 {len(self.tools)} 个工具")
|
||||
return True
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"[MCP] 服务器 {self.name} 启动超时")
|
||||
await self.disconnect()
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"[MCP] 服务器 {self.name} 连接失败: {e}")
|
||||
await self.disconnect()
|
||||
return False
|
||||
|
||||
async def disconnect(self):
|
||||
"""断开与 MCP 服务器的连接"""
|
||||
if self._read_task:
|
||||
self._read_task.cancel()
|
||||
try:
|
||||
await self._read_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._read_task = None
|
||||
|
||||
if self.process:
|
||||
try:
|
||||
self.process.terminate()
|
||||
await asyncio.wait_for(self.process.wait(), timeout=5.0)
|
||||
except asyncio.TimeoutError:
|
||||
self.process.kill()
|
||||
except Exception:
|
||||
pass
|
||||
self.process = None
|
||||
|
||||
self._initialized = False
|
||||
self._pending_requests.clear()
|
||||
self.tools.clear()
|
||||
logger.info(f"[MCP] 服务器 {self.name} 已断开")
|
||||
|
||||
async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
调用 MCP 工具
|
||||
|
||||
Args:
|
||||
tool_name: 工具名(不含前缀)
|
||||
arguments: 工具参数
|
||||
|
||||
Returns:
|
||||
工具执行结果
|
||||
"""
|
||||
if not self.is_connected:
|
||||
return {"success": False, "error": "MCP 服务器未连接"}
|
||||
|
||||
# 发送 tools/call 请求
|
||||
result = await self._send_request("tools/call", {
|
||||
"name": tool_name,
|
||||
"arguments": arguments
|
||||
})
|
||||
|
||||
if "error" in result:
|
||||
return {"success": False, "error": result["error"]}
|
||||
|
||||
# 解析结果
|
||||
content = result.get("result", {}).get("content", [])
|
||||
if content:
|
||||
# 提取文本内容
|
||||
texts = []
|
||||
for item in content:
|
||||
if item.get("type") == "text":
|
||||
texts.append(item.get("text", ""))
|
||||
elif item.get("type") == "image":
|
||||
texts.append(f"[图片: {item.get('mimeType', 'image')}]")
|
||||
elif item.get("type") == "resource":
|
||||
texts.append(f"[资源: {item.get('uri', '')}]")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "\n".join(texts) if texts else "执行成功",
|
||||
"data": content
|
||||
}
|
||||
|
||||
return {"success": True, "message": "执行成功", "data": result.get("result")}
|
||||
|
||||
async def _initialize(self) -> bool:
|
||||
"""发送 initialize 请求"""
|
||||
result = await self._send_request("initialize", {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {
|
||||
"tools": {}
|
||||
},
|
||||
"clientInfo": {
|
||||
"name": "WechatHookBot-MCPManager",
|
||||
"version": "1.0.0"
|
||||
}
|
||||
})
|
||||
|
||||
if "error" in result:
|
||||
logger.error(f"[MCP] {self.name} initialize 失败: {result['error']}")
|
||||
return False
|
||||
|
||||
# 发送 initialized 通知
|
||||
await self._send_notification("notifications/initialized", {})
|
||||
return True
|
||||
|
||||
async def _list_tools(self):
|
||||
"""获取工具列表"""
|
||||
result = await self._send_request("tools/list", {})
|
||||
|
||||
if "error" in result:
|
||||
logger.warning(f"[MCP] {self.name} 获取工具列表失败: {result['error']}")
|
||||
return
|
||||
|
||||
tools_data = result.get("result", {}).get("tools", [])
|
||||
self.tools.clear()
|
||||
|
||||
for tool_data in tools_data:
|
||||
tool = MCPTool(
|
||||
name=tool_data.get("name", ""),
|
||||
description=tool_data.get("description", ""),
|
||||
input_schema=tool_data.get("inputSchema", {}),
|
||||
server_name=self.name
|
||||
)
|
||||
self.tools[tool.name] = tool
|
||||
logger.debug(f"[MCP] {self.name} 发现工具: {tool.name}")
|
||||
|
||||
async def _send_request(self, method: str, params: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""发送 JSON-RPC 请求并等待响应"""
|
||||
if not self.process or not self.process.stdin:
|
||||
return {"error": "进程未启动"}
|
||||
|
||||
self._request_id += 1
|
||||
request_id = self._request_id
|
||||
|
||||
request = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_id,
|
||||
"method": method,
|
||||
"params": params
|
||||
}
|
||||
|
||||
# 创建 Future 等待响应
|
||||
future: asyncio.Future = asyncio.get_event_loop().create_future()
|
||||
self._pending_requests[request_id] = future
|
||||
|
||||
try:
|
||||
# 发送请求
|
||||
request_json = json.dumps(request) + "\n"
|
||||
self.process.stdin.write(request_json.encode())
|
||||
await self.process.stdin.drain()
|
||||
|
||||
# 等待响应
|
||||
result = await asyncio.wait_for(future, timeout=60.0)
|
||||
return result
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
self._pending_requests.pop(request_id, None)
|
||||
return {"error": f"请求超时: {method}"}
|
||||
except Exception as e:
|
||||
self._pending_requests.pop(request_id, None)
|
||||
return {"error": str(e)}
|
||||
|
||||
async def _send_notification(self, method: str, params: Dict[str, Any]):
|
||||
"""发送 JSON-RPC 通知(无需响应)"""
|
||||
if not self.process or not self.process.stdin:
|
||||
return
|
||||
|
||||
notification = {
|
||||
"jsonrpc": "2.0",
|
||||
"method": method,
|
||||
"params": params
|
||||
}
|
||||
|
||||
try:
|
||||
notification_json = json.dumps(notification) + "\n"
|
||||
self.process.stdin.write(notification_json.encode())
|
||||
await self.process.stdin.drain()
|
||||
except Exception as e:
|
||||
logger.warning(f"[MCP] {self.name} 发送通知失败: {e}")
|
||||
|
||||
async def _read_loop(self):
|
||||
"""读取服务器响应的协程"""
|
||||
if not self.process or not self.process.stdout:
|
||||
return
|
||||
|
||||
try:
|
||||
while True:
|
||||
line = await self.process.stdout.readline()
|
||||
if not line:
|
||||
break
|
||||
|
||||
try:
|
||||
message = json.loads(line.decode().strip())
|
||||
await self._handle_message(message)
|
||||
except json.JSONDecodeError:
|
||||
# 可能是服务器的日志输出,忽略
|
||||
pass
|
||||
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[MCP] {self.name} 读取错误: {e}")
|
||||
|
||||
async def _handle_message(self, message: Dict[str, Any]):
|
||||
"""处理收到的消息"""
|
||||
# 检查是否是响应
|
||||
if "id" in message:
|
||||
request_id = message["id"]
|
||||
future = self._pending_requests.pop(request_id, None)
|
||||
if future and not future.done():
|
||||
if "error" in message:
|
||||
future.set_result({"error": message["error"].get("message", "未知错误")})
|
||||
else:
|
||||
future.set_result({"result": message.get("result")})
|
||||
|
||||
# 处理通知(如 log, progress 等)
|
||||
elif "method" in message:
|
||||
method = message["method"]
|
||||
if method == "notifications/message":
|
||||
# 服务器日志
|
||||
params = message.get("params", {})
|
||||
level = params.get("level", "info")
|
||||
data = params.get("data", "")
|
||||
logger.debug(f"[MCP] {self.name} [{level}]: {data}")
|
||||
|
||||
|
||||
def subprocess_creation_flags() -> int:
|
||||
"""获取 Windows 子进程创建标志"""
|
||||
if sys.platform == "win32":
|
||||
import subprocess
|
||||
return subprocess.CREATE_NO_WINDOW
|
||||
return 0
|
||||
|
||||
|
||||
class MCPManager:
|
||||
"""
|
||||
MCP 管理器
|
||||
|
||||
管理多个 MCP 服务器的连接和工具注册
|
||||
"""
|
||||
|
||||
def __init__(self, tool_timeout: float = 60.0, server_start_timeout: float = 30.0):
|
||||
self.tool_timeout = tool_timeout
|
||||
self.server_start_timeout = server_start_timeout
|
||||
self.clients: Dict[str, MCPClient] = {}
|
||||
self._tool_to_server: Dict[str, str] = {} # 工具名 -> 服务器名
|
||||
self._tool_original_name: Dict[str, str] = {} # 带前缀工具名 -> 原始工具名
|
||||
|
||||
async def add_server(self, config: MCPServerConfig) -> bool:
|
||||
"""
|
||||
添加并连接 MCP 服务器
|
||||
|
||||
Args:
|
||||
config: 服务器配置
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
if not config.enabled:
|
||||
logger.info(f"[MCP] 服务器 {config.name} 已禁用,跳过")
|
||||
return False
|
||||
|
||||
if config.name in self.clients:
|
||||
logger.warning(f"[MCP] 服务器 {config.name} 已存在")
|
||||
return False
|
||||
|
||||
# 使用工厂函数创建合适的客户端
|
||||
client = create_mcp_client(config, self.server_start_timeout)
|
||||
success = await client.connect()
|
||||
|
||||
if success:
|
||||
self.clients[config.name] = client
|
||||
|
||||
# 记录工具映射
|
||||
for tool_name, tool in client.tools.items():
|
||||
prefixed_name = f"{config.tool_prefix}_{tool_name}" if config.tool_prefix else tool_name
|
||||
prefixed_name = prefixed_name.replace("-", "_").replace(".", "_")
|
||||
self._tool_to_server[prefixed_name] = config.name
|
||||
self._tool_original_name[prefixed_name] = tool_name
|
||||
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def remove_server(self, name: str):
|
||||
"""移除 MCP 服务器"""
|
||||
client = self.clients.pop(name, None)
|
||||
if client:
|
||||
await client.disconnect()
|
||||
|
||||
# 清理工具映射
|
||||
to_remove = [k for k, v in self._tool_to_server.items() if v == name]
|
||||
for k in to_remove:
|
||||
self._tool_to_server.pop(k, None)
|
||||
self._tool_original_name.pop(k, None)
|
||||
|
||||
async def shutdown(self):
|
||||
"""关闭所有服务器"""
|
||||
for name in list(self.clients.keys()):
|
||||
await self.remove_server(name)
|
||||
|
||||
def get_all_tools(self) -> List[Dict[str, Any]]:
|
||||
"""获取所有工具的 OpenAI schema"""
|
||||
tools = []
|
||||
for client in self.clients.values():
|
||||
prefix = client.config.tool_prefix
|
||||
for tool in client.tools.values():
|
||||
tools.append(tool.to_openai_schema(prefix))
|
||||
return tools
|
||||
|
||||
async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
调用工具
|
||||
|
||||
Args:
|
||||
tool_name: 工具名(可能带前缀)
|
||||
arguments: 工具参数
|
||||
|
||||
Returns:
|
||||
执行结果
|
||||
"""
|
||||
server_name = self._tool_to_server.get(tool_name)
|
||||
if not server_name:
|
||||
return {"success": False, "error": f"工具 {tool_name} 不存在"}
|
||||
|
||||
client = self.clients.get(server_name)
|
||||
if not client:
|
||||
return {"success": False, "error": f"服务器 {server_name} 未连接"}
|
||||
|
||||
# 获取原始工具名(去掉前缀)
|
||||
original_name = self._tool_original_name.get(tool_name, tool_name)
|
||||
|
||||
return await client.call_tool(original_name, arguments)
|
||||
|
||||
def list_servers(self) -> List[Dict[str, Any]]:
|
||||
"""列出所有服务器状态"""
|
||||
return [
|
||||
{
|
||||
"name": name,
|
||||
"connected": client.is_connected,
|
||||
"tools_count": len(client.tools),
|
||||
"tools": list(client.tools.keys())
|
||||
}
|
||||
for name, client in self.clients.items()
|
||||
]
|
||||
|
||||
|
||||
class MCPHttpClient:
|
||||
"""
|
||||
MCP HTTP 客户端
|
||||
|
||||
通过 HTTP 与 MCP 服务器通信(如智谱 AI 的 MCP 服务)
|
||||
"""
|
||||
|
||||
def __init__(self, config: MCPServerConfig, start_timeout: float = 30.0):
|
||||
self.config = config
|
||||
self.start_timeout = start_timeout
|
||||
self.tools: Dict[str, MCPTool] = {}
|
||||
self._request_id = 0
|
||||
self._initialized = False
|
||||
self._session: Optional[aiohttp.ClientSession] = None
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self.config.name
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self._initialized and self._session is not None and not self._session.closed
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""连接到 HTTP MCP 服务器"""
|
||||
if not AIOHTTP_AVAILABLE:
|
||||
logger.error(f"[MCP] HTTP 客户端需要 aiohttp,请安装: pip install aiohttp")
|
||||
return False
|
||||
|
||||
if self.is_connected:
|
||||
return True
|
||||
|
||||
async with self._lock:
|
||||
try:
|
||||
# 创建 HTTP 会话
|
||||
timeout = aiohttp.ClientTimeout(total=self.start_timeout)
|
||||
self._session = aiohttp.ClientSession(timeout=timeout)
|
||||
|
||||
logger.info(f"[MCP] 连接 HTTP 服务器 {self.name}: {self.config.url}")
|
||||
|
||||
# 尝试 initialize,如果失败则跳过
|
||||
init_result = await self._initialize()
|
||||
if not init_result:
|
||||
logger.warning(f"[MCP] {self.name} initialize 失败,尝试直接获取工具列表")
|
||||
|
||||
# 获取工具列表
|
||||
await self._list_tools()
|
||||
|
||||
# 只要有工具就算连接成功
|
||||
if self.tools:
|
||||
self._initialized = True
|
||||
logger.success(f"[MCP] HTTP 服务器 {self.name} 已连接,发现 {len(self.tools)} 个工具")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"[MCP] HTTP 服务器 {self.name} 未发现任何工具")
|
||||
await self.disconnect()
|
||||
return False
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"[MCP] HTTP 服务器 {self.name} 连接超时")
|
||||
await self.disconnect()
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"[MCP] HTTP 服务器 {self.name} 连接失败: {e}")
|
||||
await self.disconnect()
|
||||
return False
|
||||
|
||||
async def disconnect(self):
|
||||
"""断开连接"""
|
||||
if self._session and not self._session.closed:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
self._initialized = False
|
||||
self.tools.clear()
|
||||
logger.info(f"[MCP] HTTP 服务器 {self.name} 已断开")
|
||||
|
||||
async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""调用 MCP 工具"""
|
||||
if not self.is_connected:
|
||||
return {"success": False, "error": "MCP 服务器未连接"}
|
||||
|
||||
logger.info(f"[MCP] {self.name} 调用工具: {tool_name}, 参数: {arguments}")
|
||||
|
||||
result = await self._send_request("tools/call", {
|
||||
"name": tool_name,
|
||||
"arguments": arguments
|
||||
})
|
||||
|
||||
logger.debug(f"[MCP] {self.name} 工具原始结果: {str(result)[:500]}")
|
||||
|
||||
if "error" in result:
|
||||
return {"success": False, "error": result["error"]}
|
||||
|
||||
# 解析结果
|
||||
tool_result = result.get("result", {})
|
||||
content = tool_result.get("content", []) if isinstance(tool_result, dict) else []
|
||||
|
||||
# 检查是否有 isError 标志
|
||||
is_error = tool_result.get("isError", False) if isinstance(tool_result, dict) else False
|
||||
|
||||
# 如果 result 本身就是内容列表
|
||||
if isinstance(tool_result, list):
|
||||
content = tool_result
|
||||
|
||||
if content:
|
||||
texts = []
|
||||
for item in content:
|
||||
if isinstance(item, str):
|
||||
texts.append(item)
|
||||
elif isinstance(item, dict):
|
||||
if item.get("type") == "text":
|
||||
texts.append(item.get("text", ""))
|
||||
elif item.get("type") == "image":
|
||||
texts.append(f"[图片: {item.get('mimeType', 'image')}]")
|
||||
elif item.get("type") == "resource":
|
||||
texts.append(f"[资源: {item.get('uri', '')}]")
|
||||
elif "text" in item:
|
||||
texts.append(item.get("text", ""))
|
||||
|
||||
message = "\n".join(texts) if texts else "执行成功"
|
||||
|
||||
# 检查消息内容是否包含错误信息
|
||||
if is_error or "error" in message.lower()[:50]:
|
||||
logger.warning(f"[MCP] {self.name} 工具返回错误: {message[:200]}")
|
||||
return {"success": False, "error": message}
|
||||
|
||||
logger.info(f"[MCP] {self.name} 工具返回: {message[:200]}...")
|
||||
return {
|
||||
"success": True,
|
||||
"message": message,
|
||||
"data": content
|
||||
}
|
||||
|
||||
# 如果没有 content,尝试直接使用 result
|
||||
if tool_result:
|
||||
message = str(tool_result) if not isinstance(tool_result, dict) else json.dumps(tool_result, ensure_ascii=False)
|
||||
logger.info(f"[MCP] {self.name} 工具返回(原始): {message[:200]}...")
|
||||
return {"success": True, "message": message, "data": tool_result}
|
||||
|
||||
return {"success": True, "message": "执行成功", "data": result.get("result")}
|
||||
|
||||
async def _initialize(self) -> bool:
|
||||
"""发送 initialize 请求"""
|
||||
result = await self._send_request("initialize", {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {"tools": {}},
|
||||
"clientInfo": {
|
||||
"name": "WechatHookBot-MCPManager",
|
||||
"version": "1.0.0"
|
||||
}
|
||||
})
|
||||
|
||||
if "error" in result:
|
||||
logger.error(f"[MCP] {self.name} initialize 失败: {result['error']}")
|
||||
return False
|
||||
|
||||
# 发送 initialized 通知
|
||||
await self._send_notification("notifications/initialized", {})
|
||||
return True
|
||||
|
||||
async def _list_tools(self):
|
||||
"""获取工具列表"""
|
||||
result = await self._send_request("tools/list", {})
|
||||
|
||||
if "error" in result:
|
||||
logger.warning(f"[MCP] {self.name} 获取工具列表失败: {result['error']}")
|
||||
return
|
||||
|
||||
tools_data = result.get("result", {}).get("tools", [])
|
||||
self.tools.clear()
|
||||
|
||||
for tool_data in tools_data:
|
||||
tool = MCPTool(
|
||||
name=tool_data.get("name", ""),
|
||||
description=tool_data.get("description", ""),
|
||||
input_schema=tool_data.get("inputSchema", {}),
|
||||
server_name=self.name
|
||||
)
|
||||
self.tools[tool.name] = tool
|
||||
logger.debug(f"[MCP] {self.name} 发现工具: {tool.name}")
|
||||
|
||||
async def _send_request(self, method: str, params: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""发送 JSON-RPC HTTP 请求"""
|
||||
if not self._session:
|
||||
return {"error": "会话未创建"}
|
||||
|
||||
self._request_id += 1
|
||||
request_id = self._request_id
|
||||
|
||||
request = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_id,
|
||||
"method": method,
|
||||
"params": params
|
||||
}
|
||||
|
||||
# 构建请求头
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json, text/event-stream",
|
||||
}
|
||||
|
||||
# 处理 Authorization header
|
||||
for key, value in self.config.headers.items():
|
||||
if key.lower() == "authorization":
|
||||
# 如果没有 Bearer 前缀,添加它
|
||||
if not value.startswith("Bearer "):
|
||||
headers["Authorization"] = f"Bearer {value}"
|
||||
else:
|
||||
headers["Authorization"] = value
|
||||
else:
|
||||
headers[key] = value
|
||||
|
||||
logger.debug(f"[MCP] {self.name} 发送请求: {method}, params: {params}")
|
||||
|
||||
try:
|
||||
async with self._session.post(
|
||||
self.config.url,
|
||||
json=request,
|
||||
headers=headers
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
text = await response.text()
|
||||
logger.error(f"[MCP] {self.name} HTTP {response.status}: {text[:500]}")
|
||||
return {"error": f"HTTP {response.status}: {text[:200]}"}
|
||||
|
||||
# 检查响应类型
|
||||
content_type = response.headers.get("Content-Type", "")
|
||||
|
||||
if "text/event-stream" in content_type:
|
||||
# SSE 响应,读取第一个事件
|
||||
text = await response.text()
|
||||
# 解析 SSE 格式
|
||||
for line in text.split("\n"):
|
||||
if line.startswith("data:"):
|
||||
data_str = line[5:].strip()
|
||||
if data_str:
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
if "error" in data:
|
||||
return {"error": data["error"].get("message", "未知错误")}
|
||||
return {"result": data.get("result")}
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
return {"error": "无法解析 SSE 响应"}
|
||||
else:
|
||||
# JSON 响应
|
||||
data = await response.json()
|
||||
|
||||
if "error" in data:
|
||||
return {"error": data["error"].get("message", "未知错误")}
|
||||
|
||||
return {"result": data.get("result")}
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
return {"error": f"请求超时: {method}"}
|
||||
except Exception as e:
|
||||
logger.error(f"[MCP] {self.name} 请求异常: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
async def _send_notification(self, method: str, params: Dict[str, Any]):
|
||||
"""发送 JSON-RPC 通知"""
|
||||
if not self._session:
|
||||
return
|
||||
|
||||
notification = {
|
||||
"jsonrpc": "2.0",
|
||||
"method": method,
|
||||
"params": params
|
||||
}
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
**self.config.headers
|
||||
}
|
||||
|
||||
try:
|
||||
async with self._session.post(
|
||||
self.config.url,
|
||||
json=notification,
|
||||
headers=headers
|
||||
) as response:
|
||||
pass # 通知不需要响应
|
||||
except Exception as e:
|
||||
logger.warning(f"[MCP] {self.name} 发送通知失败: {e}")
|
||||
|
||||
|
||||
def create_mcp_client(config: MCPServerConfig, start_timeout: float = 30.0):
|
||||
"""
|
||||
工厂函数:根据配置创建合适的 MCP 客户端
|
||||
|
||||
Args:
|
||||
config: 服务器配置
|
||||
start_timeout: 启动超时
|
||||
|
||||
Returns:
|
||||
MCPClient 或 MCPHttpClient 实例
|
||||
"""
|
||||
if config.transport == "http" or config.url:
|
||||
return MCPHttpClient(config, start_timeout)
|
||||
else:
|
||||
return MCPClient(config, start_timeout)
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -180,6 +180,33 @@ class ContextStore:
|
||||
pass
|
||||
self.memory_fallback.pop(chat_id, None)
|
||||
|
||||
async def clear_group_history(self, chat_id: str) -> None:
|
||||
"""
|
||||
清空群聊历史记录
|
||||
|
||||
同时清除 Redis 和本地文件中的群聊历史
|
||||
"""
|
||||
# 清除 Redis 中的群聊历史
|
||||
if self._use_redis_for_group_history():
|
||||
redis_cache = get_cache()
|
||||
try:
|
||||
key = f"group_history:{_safe_chat_id(chat_id)}"
|
||||
redis_cache.delete(key)
|
||||
logger.debug(f"[ContextStore] 已清除 Redis 群聊历史: {chat_id}")
|
||||
except Exception as e:
|
||||
logger.debug(f"[ContextStore] 清除 Redis 群聊历史失败: {e}")
|
||||
|
||||
# 清除本地文件中的群聊历史
|
||||
history_file = self._get_history_file(chat_id)
|
||||
if history_file and history_file.exists():
|
||||
lock = self._get_history_lock(chat_id)
|
||||
async with lock:
|
||||
try:
|
||||
history_file.unlink()
|
||||
logger.debug(f"[ContextStore] 已清除本地群聊历史文件: {history_file}")
|
||||
except Exception as e:
|
||||
logger.debug(f"[ContextStore] 清除本地群聊历史文件失败: {e}")
|
||||
|
||||
# ------------------ 群聊 history ------------------
|
||||
|
||||
def _use_redis_for_group_history(self) -> bool:
|
||||
|
||||
Reference in New Issue
Block a user