From 4016c1e6eba30ac79084ef9a31f2caed07fa3ab1 Mon Sep 17 00:00:00 2001 From: shihao <3127647737@qq.com> Date: Thu, 8 Jan 2026 18:46:14 +0800 Subject: [PATCH] feat:mcp --- plugins/AIChat/main.py | 9 +- plugins/MCPManager/__init__.py | 10 + plugins/MCPManager/main.py | 269 ++++ plugins/MCPManager/mcp_client.py | 821 ++++++++++++ plugins/TravelPlanner/amap_client.py | 1735 +++++++++++++------------- plugins/TravelPlanner/main.py | 1414 ++++++++++++--------- utils/context_store.py | 27 + 7 files changed, 2815 insertions(+), 1470 deletions(-) create mode 100644 plugins/MCPManager/__init__.py create mode 100644 plugins/MCPManager/main.py create mode 100644 plugins/MCPManager/mcp_client.py diff --git a/plugins/AIChat/main.py b/plugins/AIChat/main.py index c16da5f..0fef153 100644 --- a/plugins/AIChat/main.py +++ b/plugins/AIChat/main.py @@ -1548,7 +1548,14 @@ class AIChat(PluginBase): if content == clear_command: chat_id = self._get_chat_id(from_wxid, user_wxid, is_group) self._clear_memory(chat_id) - await bot.send_text(from_wxid, "✅ 已清空当前会话的记忆") + + # 如果是群聊,还需要清空群聊历史 + if is_group and self.store: + history_chat_id = self._get_group_history_chat_id(from_wxid, user_wxid) + await self.store.clear_group_history(history_chat_id) + await bot.send_text(from_wxid, "✅ 已清空当前群聊的记忆和历史记录") + else: + await bot.send_text(from_wxid, "✅ 已清空当前会话的记忆") return False # 检查是否是上下文统计指令 diff --git a/plugins/MCPManager/__init__.py b/plugins/MCPManager/__init__.py new file mode 100644 index 0000000..185ebc1 --- /dev/null +++ b/plugins/MCPManager/__init__.py @@ -0,0 +1,10 @@ +""" +MCPManager 插件 + +管理 MCP (Model Context Protocol) 服务器, +自动将 MCP 工具注册到 ToolRegistry 供 AI 调用。 +""" + +from .main import MCPManagerPlugin + +__all__ = ["MCPManagerPlugin"] diff --git a/plugins/MCPManager/main.py b/plugins/MCPManager/main.py new file mode 100644 index 0000000..12fc503 --- /dev/null +++ b/plugins/MCPManager/main.py @@ -0,0 +1,269 @@ +""" +MCP 管理插件 + +管理 MCP (Model Context Protocol) 服务器,将 MCP 工具自动注册到 ToolRegistry, +使 AI 可以调用各种 MCP 服务器提供的工具。 + +功能: +- 自动连接配置的 MCP 服务器 +- 将 MCP 工具转换为 OpenAI 格式并注册 +- 支持热重载(禁用/启用插件时自动管理连接) +- 提供管理命令(查看状态、重连等) +""" + +import tomllib +from pathlib import Path +from typing import Any, Dict, List +from loguru import logger + +from utils.plugin_base import PluginBase +from utils.decorators import on_text_message +from utils.tool_registry import get_tool_registry + +from .mcp_client import MCPManager, MCPServerConfig + + +class MCPManagerPlugin(PluginBase): + """MCP 管理插件""" + + description = "MCP 服务器管理,自动注册 MCP 工具到 AI" + author = "ShiHao" + version = "1.0.0" + + # 高优先级加载,确保在其他插件之前注册工具 + load_priority = 90 + + def __init__(self): + super().__init__() + self.config = None + self.mcp_manager: MCPManager = None + self._registered_tools: List[str] = [] # 已注册的工具名列表 + + async def async_init(self): + """插件异步初始化""" + # 读取配置 + config_path = Path(__file__).parent / "config.toml" + with open(config_path, "rb") as f: + self.config = tomllib.load(f) + + mcp_config = self.config.get("mcp", {}) + + if not mcp_config.get("enabled", True): + logger.info("MCPManager: MCP 功能已禁用") + return + + # 初始化 MCP 管理器 + self.mcp_manager = MCPManager( + tool_timeout=mcp_config.get("tool_timeout", 60), + server_start_timeout=mcp_config.get("server_start_timeout", 30) + ) + + # 连接所有配置的服务器 + servers = mcp_config.get("servers", []) + if not servers: + logger.info("MCPManager: 未配置任何 MCP 服务器") + return + + if mcp_config.get("auto_connect", True): + await self._connect_all_servers(servers) + + logger.success(f"MCPManager 插件已加载,已连接 {len(self.mcp_manager.clients)} 个 MCP 服务器") + + async def _connect_all_servers(self, servers: List[Dict]): + """连接所有配置的服务器""" + for server_data in servers: + try: + config = MCPServerConfig.from_dict(server_data) + success = await self.mcp_manager.add_server(config) + + if success: + # 注册工具到 ToolRegistry + await self._register_server_tools(config.name) + + except Exception as e: + logger.error(f"MCPManager: 连接服务器 {server_data.get('name', '未知')} 失败: {e}") + + async def _register_server_tools(self, server_name: str): + """将服务器的工具注册到 ToolRegistry""" + client = self.mcp_manager.clients.get(server_name) + if not client: + return + + registry = get_tool_registry() + prefix = client.config.tool_prefix + + for tool in client.tools.values(): + schema = tool.to_openai_schema(prefix) + tool_name = schema["function"]["name"] + + # 创建工具执行器 + async def executor(name: str, arguments: Dict, bot, from_wxid: str, _tn=tool_name) -> Dict: + return await self.mcp_manager.call_tool(_tn, arguments) + + # 注册到 ToolRegistry + success = registry.register( + name=tool_name, + plugin_name=f"MCP:{server_name}", + schema=schema, + executor=executor, + timeout=self.mcp_manager.tool_timeout, + priority=40 # 比普通插件优先级稍低 + ) + + if success: + self._registered_tools.append(tool_name) + logger.debug(f"MCPManager: 注册工具 {tool_name}") + + logger.info(f"MCPManager: 从 {server_name} 注册了 {len(client.tools)} 个工具") + + async def on_disable(self): + """插件禁用时清理""" + await super().on_disable() + + # 注销所有已注册的工具 + registry = get_tool_registry() + for tool_name in self._registered_tools: + registry.unregister(tool_name) + self._registered_tools.clear() + + # 关闭所有 MCP 服务器连接 + if self.mcp_manager: + await self.mcp_manager.shutdown() + self.mcp_manager = None + + logger.info("MCPManager: 已清理所有 MCP 连接和工具注册") + + # ==================== 管理命令 ==================== + + @on_text_message(priority=80) + async def handle_admin_commands(self, bot, message: dict): + """处理管理命令""" + content = message.get("Content", "").strip() + + # 只响应管理员的命令 + # TODO: 从主配置读取管理员列表 + if not content.startswith("/mcp"): + return True + + parts = content.split() + if len(parts) < 2: + return True + + cmd = parts[1].lower() + reply_to = message.get("FromWxid", "") + + if cmd == "status": + await self._cmd_status(bot, reply_to) + return False + + elif cmd == "list": + await self._cmd_list_tools(bot, reply_to) + return False + + elif cmd == "reload": + await self._cmd_reload(bot, reply_to) + return False + + return True + + async def _cmd_status(self, bot, reply_to: str): + """查看 MCP 服务器状态""" + if not self.mcp_manager: + await bot.send_text(reply_to, "MCP 功能未启用") + return + + servers = self.mcp_manager.list_servers() + if not servers: + await bot.send_text(reply_to, "没有已连接的 MCP 服务器") + return + + lines = ["📡 MCP 服务器状态:"] + for s in servers: + status = "✅" if s["connected"] else "❌" + lines.append(f"{status} {s['name']}: {s['tools_count']} 个工具") + + await bot.send_text(reply_to, "\n".join(lines)) + + async def _cmd_list_tools(self, bot, reply_to: str): + """列出所有 MCP 工具""" + if not self.mcp_manager: + await bot.send_text(reply_to, "MCP 功能未启用") + return + + servers = self.mcp_manager.list_servers() + if not servers: + await bot.send_text(reply_to, "没有已连接的 MCP 服务器") + return + + lines = ["🔧 MCP 工具列表:"] + for s in servers: + if s["tools"]: + lines.append(f"\n【{s['name']}】") + for tool in s["tools"][:10]: # 限制显示数量 + lines.append(f" • {tool}") + if len(s["tools"]) > 10: + lines.append(f" ... 还有 {len(s['tools']) - 10} 个") + + await bot.send_text(reply_to, "\n".join(lines)) + + async def _cmd_reload(self, bot, reply_to: str): + """重新加载 MCP 服务器""" + await bot.send_text(reply_to, "正在重新加载 MCP 服务器...") + + # 清理现有连接 + if self.mcp_manager: + registry = get_tool_registry() + for tool_name in self._registered_tools: + registry.unregister(tool_name) + self._registered_tools.clear() + await self.mcp_manager.shutdown() + + # 重新读取配置 + config_path = Path(__file__).parent / "config.toml" + with open(config_path, "rb") as f: + self.config = tomllib.load(f) + + mcp_config = self.config.get("mcp", {}) + + # 重新初始化 + self.mcp_manager = MCPManager( + tool_timeout=mcp_config.get("tool_timeout", 60), + server_start_timeout=mcp_config.get("server_start_timeout", 30) + ) + + servers = mcp_config.get("servers", []) + await self._connect_all_servers(servers) + + await bot.send_text( + reply_to, + f"MCP 重新加载完成,已连接 {len(self.mcp_manager.clients)} 个服务器" + ) + + # ==================== LLM 工具接口(备用) ==================== + + def get_llm_tools(self) -> List[Dict]: + """ + 返回 MCP 工具列表(备用接口) + + 注意:工具已通过 ToolRegistry 注册,此方法仅供参考 + """ + if not self.mcp_manager: + return [] + return self.mcp_manager.get_all_tools() + + async def execute_llm_tool( + self, + tool_name: str, + arguments: Dict[str, Any], + bot, + from_wxid: str + ) -> Dict[str, Any]: + """ + 执行 MCP 工具(备用接口) + + 注意:工具已通过 ToolRegistry 注册,此方法仅供备用 + """ + if not self.mcp_manager: + return {"success": False, "error": "MCP 功能未启用"} + + return await self.mcp_manager.call_tool(tool_name, arguments) diff --git a/plugins/MCPManager/mcp_client.py b/plugins/MCPManager/mcp_client.py new file mode 100644 index 0000000..17b4d74 --- /dev/null +++ b/plugins/MCPManager/mcp_client.py @@ -0,0 +1,821 @@ +""" +MCP 客户端封装 + +实现 MCP (Model Context Protocol) 客户端,支持两种传输方式: +- stdio: 通过子进程 stdin/stdout 通信 +- http: 通过 HTTP/SSE 通信 + +支持: +- 服务器生命周期管理(启动、停止、重启) +- 工具发现和执行 +- JSON-RPC 2.0 通信协议 +""" + +import asyncio +import json +import os +import sys +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Literal +from abc import ABC, abstractmethod +from loguru import logger + +# 可选导入 aiohttp +try: + import aiohttp + AIOHTTP_AVAILABLE = True +except ImportError: + AIOHTTP_AVAILABLE = False + + +@dataclass +class MCPServerConfig: + """MCP 服务器配置""" + name: str + # 传输类型: stdio 或 http + transport: Literal["stdio", "http"] = "stdio" + + # stdio 类型配置 + command: str = "" + args: List[str] = field(default_factory=list) + working_dir: str = "" + + # http 类型配置 + url: str = "" + headers: Dict[str, str] = field(default_factory=dict) + + # 通用配置 + env: Dict[str, str] = field(default_factory=dict) + enabled: bool = True + tool_prefix: str = "" + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "MCPServerConfig": + # 自动检测传输类型 + transport = data.get("type", data.get("transport", "stdio")) + if data.get("url"): + transport = "http" + + return cls( + name=data.get("name", ""), + transport=transport, + command=data.get("command", ""), + args=data.get("args", []), + working_dir=data.get("working_dir", ""), + url=data.get("url", ""), + headers=data.get("headers", {}), + env=data.get("env", {}), + enabled=data.get("enabled", True), + tool_prefix=data.get("tool_prefix", ""), + ) + + +@dataclass +class MCPTool: + """MCP 工具定义""" + name: str + description: str + input_schema: Dict[str, Any] + server_name: str # 所属服务器名称 + + def to_openai_schema(self, prefix: str = "") -> Dict[str, Any]: + """转换为 OpenAI 兼容的工具 schema""" + tool_name = f"{prefix}_{self.name}" if prefix else self.name + # 清理工具名中的非法字符 + tool_name = tool_name.replace("-", "_").replace(".", "_") + + return { + "type": "function", + "function": { + "name": tool_name, + "description": self.description or f"MCP tool: {self.name}", + "parameters": self.input_schema or {"type": "object", "properties": {}} + } + } + + +class MCPClient: + """ + MCP 客户端(stdio 类型) + + 管理与单个 MCP 服务器的连接和通信(通过子进程 stdio) + """ + + def __init__(self, config: MCPServerConfig, start_timeout: float = 30.0): + self.config = config + self.start_timeout = start_timeout + self.process: Optional[asyncio.subprocess.Process] = None + self.tools: Dict[str, MCPTool] = {} + self._request_id = 0 + self._pending_requests: Dict[int, asyncio.Future] = {} + self._read_task: Optional[asyncio.Task] = None + self._initialized = False + self._lock = asyncio.Lock() + + @property + def name(self) -> str: + return self.config.name + + @property + def is_connected(self) -> bool: + return self.process is not None and self.process.returncode is None + + async def connect(self) -> bool: + """ + 连接到 MCP 服务器 + + Returns: + 是否连接成功 + """ + if self.is_connected: + return True + + async with self._lock: + try: + # 准备环境变量 + env = os.environ.copy() + env.update(self.config.env) + + # 准备工作目录 + cwd = self.config.working_dir if self.config.working_dir else None + + # 构建完整命令 + cmd = [self.config.command] + self.config.args + logger.info(f"[MCP] 启动服务器 {self.name}: {' '.join(cmd)}") + + # 启动子进程 + self.process = await asyncio.create_subprocess_exec( + *cmd, + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + env=env, + cwd=cwd, + # Windows 特殊处理 + creationflags=subprocess_creation_flags() if sys.platform == "win32" else 0, + ) + + # 启动读取协程 + self._read_task = asyncio.create_task(self._read_loop()) + + # 发送 initialize 请求 + init_result = await asyncio.wait_for( + self._initialize(), + timeout=self.start_timeout + ) + + if not init_result: + await self.disconnect() + return False + + # 获取工具列表 + await self._list_tools() + + self._initialized = True + logger.success(f"[MCP] 服务器 {self.name} 已连接,发现 {len(self.tools)} 个工具") + return True + + except asyncio.TimeoutError: + logger.error(f"[MCP] 服务器 {self.name} 启动超时") + await self.disconnect() + return False + except Exception as e: + logger.error(f"[MCP] 服务器 {self.name} 连接失败: {e}") + await self.disconnect() + return False + + async def disconnect(self): + """断开与 MCP 服务器的连接""" + if self._read_task: + self._read_task.cancel() + try: + await self._read_task + except asyncio.CancelledError: + pass + self._read_task = None + + if self.process: + try: + self.process.terminate() + await asyncio.wait_for(self.process.wait(), timeout=5.0) + except asyncio.TimeoutError: + self.process.kill() + except Exception: + pass + self.process = None + + self._initialized = False + self._pending_requests.clear() + self.tools.clear() + logger.info(f"[MCP] 服务器 {self.name} 已断开") + + async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Dict[str, Any]: + """ + 调用 MCP 工具 + + Args: + tool_name: 工具名(不含前缀) + arguments: 工具参数 + + Returns: + 工具执行结果 + """ + if not self.is_connected: + return {"success": False, "error": "MCP 服务器未连接"} + + # 发送 tools/call 请求 + result = await self._send_request("tools/call", { + "name": tool_name, + "arguments": arguments + }) + + if "error" in result: + return {"success": False, "error": result["error"]} + + # 解析结果 + content = result.get("result", {}).get("content", []) + if content: + # 提取文本内容 + texts = [] + for item in content: + if item.get("type") == "text": + texts.append(item.get("text", "")) + elif item.get("type") == "image": + texts.append(f"[图片: {item.get('mimeType', 'image')}]") + elif item.get("type") == "resource": + texts.append(f"[资源: {item.get('uri', '')}]") + + return { + "success": True, + "message": "\n".join(texts) if texts else "执行成功", + "data": content + } + + return {"success": True, "message": "执行成功", "data": result.get("result")} + + async def _initialize(self) -> bool: + """发送 initialize 请求""" + result = await self._send_request("initialize", { + "protocolVersion": "2024-11-05", + "capabilities": { + "tools": {} + }, + "clientInfo": { + "name": "WechatHookBot-MCPManager", + "version": "1.0.0" + } + }) + + if "error" in result: + logger.error(f"[MCP] {self.name} initialize 失败: {result['error']}") + return False + + # 发送 initialized 通知 + await self._send_notification("notifications/initialized", {}) + return True + + async def _list_tools(self): + """获取工具列表""" + result = await self._send_request("tools/list", {}) + + if "error" in result: + logger.warning(f"[MCP] {self.name} 获取工具列表失败: {result['error']}") + return + + tools_data = result.get("result", {}).get("tools", []) + self.tools.clear() + + for tool_data in tools_data: + tool = MCPTool( + name=tool_data.get("name", ""), + description=tool_data.get("description", ""), + input_schema=tool_data.get("inputSchema", {}), + server_name=self.name + ) + self.tools[tool.name] = tool + logger.debug(f"[MCP] {self.name} 发现工具: {tool.name}") + + async def _send_request(self, method: str, params: Dict[str, Any]) -> Dict[str, Any]: + """发送 JSON-RPC 请求并等待响应""" + if not self.process or not self.process.stdin: + return {"error": "进程未启动"} + + self._request_id += 1 + request_id = self._request_id + + request = { + "jsonrpc": "2.0", + "id": request_id, + "method": method, + "params": params + } + + # 创建 Future 等待响应 + future: asyncio.Future = asyncio.get_event_loop().create_future() + self._pending_requests[request_id] = future + + try: + # 发送请求 + request_json = json.dumps(request) + "\n" + self.process.stdin.write(request_json.encode()) + await self.process.stdin.drain() + + # 等待响应 + result = await asyncio.wait_for(future, timeout=60.0) + return result + + except asyncio.TimeoutError: + self._pending_requests.pop(request_id, None) + return {"error": f"请求超时: {method}"} + except Exception as e: + self._pending_requests.pop(request_id, None) + return {"error": str(e)} + + async def _send_notification(self, method: str, params: Dict[str, Any]): + """发送 JSON-RPC 通知(无需响应)""" + if not self.process or not self.process.stdin: + return + + notification = { + "jsonrpc": "2.0", + "method": method, + "params": params + } + + try: + notification_json = json.dumps(notification) + "\n" + self.process.stdin.write(notification_json.encode()) + await self.process.stdin.drain() + except Exception as e: + logger.warning(f"[MCP] {self.name} 发送通知失败: {e}") + + async def _read_loop(self): + """读取服务器响应的协程""" + if not self.process or not self.process.stdout: + return + + try: + while True: + line = await self.process.stdout.readline() + if not line: + break + + try: + message = json.loads(line.decode().strip()) + await self._handle_message(message) + except json.JSONDecodeError: + # 可能是服务器的日志输出,忽略 + pass + + except asyncio.CancelledError: + raise + except Exception as e: + logger.error(f"[MCP] {self.name} 读取错误: {e}") + + async def _handle_message(self, message: Dict[str, Any]): + """处理收到的消息""" + # 检查是否是响应 + if "id" in message: + request_id = message["id"] + future = self._pending_requests.pop(request_id, None) + if future and not future.done(): + if "error" in message: + future.set_result({"error": message["error"].get("message", "未知错误")}) + else: + future.set_result({"result": message.get("result")}) + + # 处理通知(如 log, progress 等) + elif "method" in message: + method = message["method"] + if method == "notifications/message": + # 服务器日志 + params = message.get("params", {}) + level = params.get("level", "info") + data = params.get("data", "") + logger.debug(f"[MCP] {self.name} [{level}]: {data}") + + +def subprocess_creation_flags() -> int: + """获取 Windows 子进程创建标志""" + if sys.platform == "win32": + import subprocess + return subprocess.CREATE_NO_WINDOW + return 0 + + +class MCPManager: + """ + MCP 管理器 + + 管理多个 MCP 服务器的连接和工具注册 + """ + + def __init__(self, tool_timeout: float = 60.0, server_start_timeout: float = 30.0): + self.tool_timeout = tool_timeout + self.server_start_timeout = server_start_timeout + self.clients: Dict[str, MCPClient] = {} + self._tool_to_server: Dict[str, str] = {} # 工具名 -> 服务器名 + self._tool_original_name: Dict[str, str] = {} # 带前缀工具名 -> 原始工具名 + + async def add_server(self, config: MCPServerConfig) -> bool: + """ + 添加并连接 MCP 服务器 + + Args: + config: 服务器配置 + + Returns: + 是否成功 + """ + if not config.enabled: + logger.info(f"[MCP] 服务器 {config.name} 已禁用,跳过") + return False + + if config.name in self.clients: + logger.warning(f"[MCP] 服务器 {config.name} 已存在") + return False + + # 使用工厂函数创建合适的客户端 + client = create_mcp_client(config, self.server_start_timeout) + success = await client.connect() + + if success: + self.clients[config.name] = client + + # 记录工具映射 + for tool_name, tool in client.tools.items(): + prefixed_name = f"{config.tool_prefix}_{tool_name}" if config.tool_prefix else tool_name + prefixed_name = prefixed_name.replace("-", "_").replace(".", "_") + self._tool_to_server[prefixed_name] = config.name + self._tool_original_name[prefixed_name] = tool_name + + return True + + return False + + async def remove_server(self, name: str): + """移除 MCP 服务器""" + client = self.clients.pop(name, None) + if client: + await client.disconnect() + + # 清理工具映射 + to_remove = [k for k, v in self._tool_to_server.items() if v == name] + for k in to_remove: + self._tool_to_server.pop(k, None) + self._tool_original_name.pop(k, None) + + async def shutdown(self): + """关闭所有服务器""" + for name in list(self.clients.keys()): + await self.remove_server(name) + + def get_all_tools(self) -> List[Dict[str, Any]]: + """获取所有工具的 OpenAI schema""" + tools = [] + for client in self.clients.values(): + prefix = client.config.tool_prefix + for tool in client.tools.values(): + tools.append(tool.to_openai_schema(prefix)) + return tools + + async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Dict[str, Any]: + """ + 调用工具 + + Args: + tool_name: 工具名(可能带前缀) + arguments: 工具参数 + + Returns: + 执行结果 + """ + server_name = self._tool_to_server.get(tool_name) + if not server_name: + return {"success": False, "error": f"工具 {tool_name} 不存在"} + + client = self.clients.get(server_name) + if not client: + return {"success": False, "error": f"服务器 {server_name} 未连接"} + + # 获取原始工具名(去掉前缀) + original_name = self._tool_original_name.get(tool_name, tool_name) + + return await client.call_tool(original_name, arguments) + + def list_servers(self) -> List[Dict[str, Any]]: + """列出所有服务器状态""" + return [ + { + "name": name, + "connected": client.is_connected, + "tools_count": len(client.tools), + "tools": list(client.tools.keys()) + } + for name, client in self.clients.items() + ] + + +class MCPHttpClient: + """ + MCP HTTP 客户端 + + 通过 HTTP 与 MCP 服务器通信(如智谱 AI 的 MCP 服务) + """ + + def __init__(self, config: MCPServerConfig, start_timeout: float = 30.0): + self.config = config + self.start_timeout = start_timeout + self.tools: Dict[str, MCPTool] = {} + self._request_id = 0 + self._initialized = False + self._session: Optional[aiohttp.ClientSession] = None + self._lock = asyncio.Lock() + + @property + def name(self) -> str: + return self.config.name + + @property + def is_connected(self) -> bool: + return self._initialized and self._session is not None and not self._session.closed + + async def connect(self) -> bool: + """连接到 HTTP MCP 服务器""" + if not AIOHTTP_AVAILABLE: + logger.error(f"[MCP] HTTP 客户端需要 aiohttp,请安装: pip install aiohttp") + return False + + if self.is_connected: + return True + + async with self._lock: + try: + # 创建 HTTP 会话 + timeout = aiohttp.ClientTimeout(total=self.start_timeout) + self._session = aiohttp.ClientSession(timeout=timeout) + + logger.info(f"[MCP] 连接 HTTP 服务器 {self.name}: {self.config.url}") + + # 尝试 initialize,如果失败则跳过 + init_result = await self._initialize() + if not init_result: + logger.warning(f"[MCP] {self.name} initialize 失败,尝试直接获取工具列表") + + # 获取工具列表 + await self._list_tools() + + # 只要有工具就算连接成功 + if self.tools: + self._initialized = True + logger.success(f"[MCP] HTTP 服务器 {self.name} 已连接,发现 {len(self.tools)} 个工具") + return True + else: + logger.error(f"[MCP] HTTP 服务器 {self.name} 未发现任何工具") + await self.disconnect() + return False + + except asyncio.TimeoutError: + logger.error(f"[MCP] HTTP 服务器 {self.name} 连接超时") + await self.disconnect() + return False + except Exception as e: + logger.error(f"[MCP] HTTP 服务器 {self.name} 连接失败: {e}") + await self.disconnect() + return False + + async def disconnect(self): + """断开连接""" + if self._session and not self._session.closed: + await self._session.close() + self._session = None + self._initialized = False + self.tools.clear() + logger.info(f"[MCP] HTTP 服务器 {self.name} 已断开") + + async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Dict[str, Any]: + """调用 MCP 工具""" + if not self.is_connected: + return {"success": False, "error": "MCP 服务器未连接"} + + logger.info(f"[MCP] {self.name} 调用工具: {tool_name}, 参数: {arguments}") + + result = await self._send_request("tools/call", { + "name": tool_name, + "arguments": arguments + }) + + logger.debug(f"[MCP] {self.name} 工具原始结果: {str(result)[:500]}") + + if "error" in result: + return {"success": False, "error": result["error"]} + + # 解析结果 + tool_result = result.get("result", {}) + content = tool_result.get("content", []) if isinstance(tool_result, dict) else [] + + # 检查是否有 isError 标志 + is_error = tool_result.get("isError", False) if isinstance(tool_result, dict) else False + + # 如果 result 本身就是内容列表 + if isinstance(tool_result, list): + content = tool_result + + if content: + texts = [] + for item in content: + if isinstance(item, str): + texts.append(item) + elif isinstance(item, dict): + if item.get("type") == "text": + texts.append(item.get("text", "")) + elif item.get("type") == "image": + texts.append(f"[图片: {item.get('mimeType', 'image')}]") + elif item.get("type") == "resource": + texts.append(f"[资源: {item.get('uri', '')}]") + elif "text" in item: + texts.append(item.get("text", "")) + + message = "\n".join(texts) if texts else "执行成功" + + # 检查消息内容是否包含错误信息 + if is_error or "error" in message.lower()[:50]: + logger.warning(f"[MCP] {self.name} 工具返回错误: {message[:200]}") + return {"success": False, "error": message} + + logger.info(f"[MCP] {self.name} 工具返回: {message[:200]}...") + return { + "success": True, + "message": message, + "data": content + } + + # 如果没有 content,尝试直接使用 result + if tool_result: + message = str(tool_result) if not isinstance(tool_result, dict) else json.dumps(tool_result, ensure_ascii=False) + logger.info(f"[MCP] {self.name} 工具返回(原始): {message[:200]}...") + return {"success": True, "message": message, "data": tool_result} + + return {"success": True, "message": "执行成功", "data": result.get("result")} + + async def _initialize(self) -> bool: + """发送 initialize 请求""" + result = await self._send_request("initialize", { + "protocolVersion": "2024-11-05", + "capabilities": {"tools": {}}, + "clientInfo": { + "name": "WechatHookBot-MCPManager", + "version": "1.0.0" + } + }) + + if "error" in result: + logger.error(f"[MCP] {self.name} initialize 失败: {result['error']}") + return False + + # 发送 initialized 通知 + await self._send_notification("notifications/initialized", {}) + return True + + async def _list_tools(self): + """获取工具列表""" + result = await self._send_request("tools/list", {}) + + if "error" in result: + logger.warning(f"[MCP] {self.name} 获取工具列表失败: {result['error']}") + return + + tools_data = result.get("result", {}).get("tools", []) + self.tools.clear() + + for tool_data in tools_data: + tool = MCPTool( + name=tool_data.get("name", ""), + description=tool_data.get("description", ""), + input_schema=tool_data.get("inputSchema", {}), + server_name=self.name + ) + self.tools[tool.name] = tool + logger.debug(f"[MCP] {self.name} 发现工具: {tool.name}") + + async def _send_request(self, method: str, params: Dict[str, Any]) -> Dict[str, Any]: + """发送 JSON-RPC HTTP 请求""" + if not self._session: + return {"error": "会话未创建"} + + self._request_id += 1 + request_id = self._request_id + + request = { + "jsonrpc": "2.0", + "id": request_id, + "method": method, + "params": params + } + + # 构建请求头 + headers = { + "Content-Type": "application/json", + "Accept": "application/json, text/event-stream", + } + + # 处理 Authorization header + for key, value in self.config.headers.items(): + if key.lower() == "authorization": + # 如果没有 Bearer 前缀,添加它 + if not value.startswith("Bearer "): + headers["Authorization"] = f"Bearer {value}" + else: + headers["Authorization"] = value + else: + headers[key] = value + + logger.debug(f"[MCP] {self.name} 发送请求: {method}, params: {params}") + + try: + async with self._session.post( + self.config.url, + json=request, + headers=headers + ) as response: + if response.status != 200: + text = await response.text() + logger.error(f"[MCP] {self.name} HTTP {response.status}: {text[:500]}") + return {"error": f"HTTP {response.status}: {text[:200]}"} + + # 检查响应类型 + content_type = response.headers.get("Content-Type", "") + + if "text/event-stream" in content_type: + # SSE 响应,读取第一个事件 + text = await response.text() + # 解析 SSE 格式 + for line in text.split("\n"): + if line.startswith("data:"): + data_str = line[5:].strip() + if data_str: + try: + data = json.loads(data_str) + if "error" in data: + return {"error": data["error"].get("message", "未知错误")} + return {"result": data.get("result")} + except json.JSONDecodeError: + pass + return {"error": "无法解析 SSE 响应"} + else: + # JSON 响应 + data = await response.json() + + if "error" in data: + return {"error": data["error"].get("message", "未知错误")} + + return {"result": data.get("result")} + + except asyncio.TimeoutError: + return {"error": f"请求超时: {method}"} + except Exception as e: + logger.error(f"[MCP] {self.name} 请求异常: {e}") + return {"error": str(e)} + + async def _send_notification(self, method: str, params: Dict[str, Any]): + """发送 JSON-RPC 通知""" + if not self._session: + return + + notification = { + "jsonrpc": "2.0", + "method": method, + "params": params + } + + headers = { + "Content-Type": "application/json", + **self.config.headers + } + + try: + async with self._session.post( + self.config.url, + json=notification, + headers=headers + ) as response: + pass # 通知不需要响应 + except Exception as e: + logger.warning(f"[MCP] {self.name} 发送通知失败: {e}") + + +def create_mcp_client(config: MCPServerConfig, start_timeout: float = 30.0): + """ + 工厂函数:根据配置创建合适的 MCP 客户端 + + Args: + config: 服务器配置 + start_timeout: 启动超时 + + Returns: + MCPClient 或 MCPHttpClient 实例 + """ + if config.transport == "http" or config.url: + return MCPHttpClient(config, start_timeout) + else: + return MCPClient(config, start_timeout) diff --git a/plugins/TravelPlanner/amap_client.py b/plugins/TravelPlanner/amap_client.py index 6f8c515..5964341 100644 --- a/plugins/TravelPlanner/amap_client.py +++ b/plugins/TravelPlanner/amap_client.py @@ -1,860 +1,875 @@ -""" -高德地图 API 客户端封装 - -提供以下功能: -- 地理编码:地址 → 坐标 -- 逆地理编码:坐标 → 地址 -- 行政区域查询:获取城市 adcode -- 天气查询:实况/预报天气 -- POI 搜索:关键字搜索、周边搜索 -- 路径规划:驾车、公交、步行、骑行 -""" - -from __future__ import annotations - -import hashlib -import aiohttp -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Literal -from loguru import logger - - -@dataclass -class AmapConfig: - """高德 API 配置""" - api_key: str - secret: str = "" # 安全密钥,用于数字签名 - timeout: int = 30 - - -class AmapClient: - """高德地图 API 客户端""" - - BASE_URL = "https://restapi.amap.com" - - def __init__(self, config: AmapConfig): - self.config = config - self._session: Optional[aiohttp.ClientSession] = None - - @staticmethod - def _safe_int(value, default: int = 0) -> int: - """安全地将值转换为整数,处理列表、None、空字符串等情况""" - if value is None: - return default - if isinstance(value, list): - return default - if isinstance(value, (int, float)): - return int(value) - if isinstance(value, str): - if not value.strip(): - return default - try: - return int(float(value)) - except (ValueError, TypeError): - return default - return default - - @staticmethod - def _safe_float(value, default: float = 0.0) -> float: - """安全地将值转换为浮点数""" - if value is None: - return default - if isinstance(value, list): - return default - if isinstance(value, (int, float)): - return float(value) - if isinstance(value, str): - if not value.strip(): - return default - try: - return float(value) - except (ValueError, TypeError): - return default - return default - - @staticmethod - def _safe_str(value, default: str = "") -> str: - """安全地将值转换为字符串,处理列表等情况""" - if value is None: - return default - if isinstance(value, list): - return default - return str(value) - - async def _get_session(self) -> aiohttp.ClientSession: - """获取或创建 HTTP 会话""" - if self._session is None or self._session.closed: - timeout = aiohttp.ClientTimeout(total=self.config.timeout) - self._session = aiohttp.ClientSession(timeout=timeout) - return self._session - - async def close(self): - """关闭 HTTP 会话""" - if self._session and not self._session.closed: - await self._session.close() - - def _generate_signature(self, params: Dict[str, Any]) -> str: - """ - 生成数字签名 - - 算法: - 1. 将请求参数按参数名升序排序 - 2. 按 key=value 格式拼接,用 & 连接 - 3. 最后拼接上私钥(secret) - 4. 对整个字符串进行 MD5 加密 - - Args: - params: 请求参数(不含 sig) - - Returns: - MD5 签名字符串 - """ - # 按参数名升序排序 - sorted_params = sorted(params.items(), key=lambda x: x[0]) - # 拼接成 key=value&key=value 格式 - param_str = "&".join(f"{k}={v}" for k, v in sorted_params) - # 拼接私钥 - sign_str = param_str + self.config.secret - # MD5 加密 - return hashlib.md5(sign_str.encode('utf-8')).hexdigest() - - async def _request(self, endpoint: str, params: Dict[str, Any]) -> Dict[str, Any]: - """ - 发送 API 请求 - - Args: - endpoint: API 端点路径 - params: 请求参数 - - Returns: - API 响应数据 - """ - params["key"] = self.config.api_key - params["output"] = "JSON" - - # 如果配置了安全密钥,生成数字签名 - if self.config.secret: - params["sig"] = self._generate_signature(params) - - url = f"{self.BASE_URL}{endpoint}" - session = await self._get_session() - - try: - async with session.get(url, params=params) as response: - data = await response.json() - - # 检查 API 状态 - status = data.get("status", "0") - if status != "1": - info = data.get("info", "未知错误") - infocode = data.get("infocode", "") - logger.warning(f"高德 API 错误: {info} (code: {infocode})") - return {"success": False, "error": info, "code": infocode} - - return {"success": True, "data": data} - - except aiohttp.ClientError as e: - logger.error(f"高德 API 请求失败: {e}") - return {"success": False, "error": str(e)} - except Exception as e: - logger.error(f"高德 API 未知错误: {e}") - return {"success": False, "error": str(e)} - - # ==================== 地理编码 ==================== - - async def geocode(self, address: str, city: str = None) -> Dict[str, Any]: - """ - 地理编码:将地址转换为坐标 - - Args: - address: 结构化地址,如 "北京市朝阳区阜通东大街6号" - city: 指定城市(可选) - - Returns: - { - "success": True, - "location": "116.480881,39.989410", - "adcode": "110105", - "city": "北京市", - "district": "朝阳区", - "level": "门址" - } - """ - params = {"address": address} - if city: - params["city"] = city - - result = await self._request("/v3/geocode/geo", params) - - if not result["success"]: - return result - - geocodes = result["data"].get("geocodes", []) - if not geocodes: - return {"success": False, "error": "未找到该地址"} - - geo = geocodes[0] - return { - "success": True, - "location": geo.get("location", ""), - "adcode": geo.get("adcode", ""), - "province": geo.get("province", ""), - "city": geo.get("city", ""), - "district": geo.get("district", ""), - "level": geo.get("level", ""), - "formatted_address": geo.get("formatted_address", address) - } - - async def reverse_geocode( - self, - location: str, - radius: int = 1000, - extensions: str = "base" - ) -> Dict[str, Any]: - """ - 逆地理编码:将坐标转换为地址 - - Args: - location: 经纬度坐标,格式 "lng,lat" - radius: 搜索半径(米),0-3000 - extensions: base 或 all - - Returns: - 地址信息 - """ - params = { - "location": location, - "radius": min(radius, 3000), - "extensions": extensions - } - - result = await self._request("/v3/geocode/regeo", params) - - if not result["success"]: - return result - - regeocode = result["data"].get("regeocode", {}) - address_component = regeocode.get("addressComponent", {}) - - return { - "success": True, - "formatted_address": regeocode.get("formatted_address", ""), - "province": address_component.get("province", ""), - "city": address_component.get("city", ""), - "district": address_component.get("district", ""), - "adcode": address_component.get("adcode", ""), - "township": address_component.get("township", ""), - "pois": regeocode.get("pois", []) if extensions == "all" else [] - } - - # ==================== 行政区域查询 ==================== - - async def get_district( - self, - keywords: str = None, - subdistrict: int = 1 - ) -> Dict[str, Any]: - """ - 行政区域查询 - - Args: - keywords: 查询关键字(城市名、adcode 等) - subdistrict: 返回下级行政区级数(0-3) - - Returns: - 行政区域信息,包含 adcode、citycode 等 - """ - params = {"subdistrict": subdistrict} - if keywords: - params["keywords"] = keywords - - result = await self._request("/v3/config/district", params) - - if not result["success"]: - return result - - districts = result["data"].get("districts", []) - if not districts: - return {"success": False, "error": "未找到该行政区域"} - - district = districts[0] - return { - "success": True, - "name": district.get("name", ""), - "adcode": district.get("adcode", ""), - "citycode": district.get("citycode", ""), - "center": district.get("center", ""), - "level": district.get("level", ""), - "districts": district.get("districts", []) - } - - # ==================== 天气查询 ==================== - - async def get_weather( - self, - city: str, - extensions: Literal["base", "all"] = "all" - ) -> Dict[str, Any]: - """ - 天气查询 - - Args: - city: 城市 adcode(如 110000)或城市名 - extensions: base=实况天气,all=预报天气(未来4天) - - Returns: - 天气信息 - """ - # 如果传入的是城市名,先获取 adcode - if not city.isdigit(): - district_result = await self.get_district(city) - if not district_result["success"]: - return {"success": False, "error": f"无法识别城市: {city}"} - city = district_result["adcode"] - - params = { - "city": city, - "extensions": extensions - } - - result = await self._request("/v3/weather/weatherInfo", params) - - if not result["success"]: - return result - - data = result["data"] - - if extensions == "base": - # 实况天气 - lives = data.get("lives", []) - if not lives: - return {"success": False, "error": "未获取到天气数据"} - - live = lives[0] - return { - "success": True, - "type": "live", - "city": live.get("city", ""), - "weather": live.get("weather", ""), - "temperature": live.get("temperature", ""), - "winddirection": live.get("winddirection", ""), - "windpower": live.get("windpower", ""), - "humidity": live.get("humidity", ""), - "reporttime": live.get("reporttime", "") - } - else: - # 预报天气 - forecasts = data.get("forecasts", []) - if not forecasts: - return {"success": False, "error": "未获取到天气预报数据"} - - forecast = forecasts[0] - casts = forecast.get("casts", []) - - return { - "success": True, - "type": "forecast", - "city": forecast.get("city", ""), - "province": forecast.get("province", ""), - "reporttime": forecast.get("reporttime", ""), - "forecasts": [ - { - "date": cast.get("date", ""), - "week": cast.get("week", ""), - "dayweather": cast.get("dayweather", ""), - "nightweather": cast.get("nightweather", ""), - "daytemp": cast.get("daytemp", ""), - "nighttemp": cast.get("nighttemp", ""), - "daywind": cast.get("daywind", ""), - "nightwind": cast.get("nightwind", ""), - "daypower": cast.get("daypower", ""), - "nightpower": cast.get("nightpower", "") - } - for cast in casts - ] - } - - # ==================== POI 搜索 ==================== - - async def search_poi( - self, - keywords: str = None, - types: str = None, - city: str = None, - citylimit: bool = True, - offset: int = 20, - page: int = 1, - extensions: str = "all" - ) -> Dict[str, Any]: - """ - 关键字搜索 POI - - Args: - keywords: 查询关键字 - types: POI 类型代码,多个用 | 分隔 - city: 城市名或 adcode - citylimit: 是否仅返回指定城市 - offset: 每页数量(建议不超过25) - page: 页码 - extensions: base 或 all - - Returns: - POI 列表 - """ - params = { - "offset": min(offset, 25), - "page": page, - "extensions": extensions - } - - if keywords: - params["keywords"] = keywords - if types: - params["types"] = types - if city: - params["city"] = city - params["citylimit"] = "true" if citylimit else "false" - - result = await self._request("/v3/place/text", params) - - if not result["success"]: - return result - - pois = result["data"].get("pois", []) - count = self._safe_int(result["data"].get("count", 0)) - - return { - "success": True, - "count": count, - "pois": [self._format_poi(poi) for poi in pois] - } - - async def search_around( - self, - location: str, - keywords: str = None, - types: str = None, - radius: int = 3000, - offset: int = 20, - page: int = 1, - extensions: str = "all" - ) -> Dict[str, Any]: - """ - 周边搜索 POI - - Args: - location: 中心点坐标,格式 "lng,lat" - keywords: 查询关键字 - types: POI 类型代码 - radius: 搜索半径(米),0-50000 - offset: 每页数量 - page: 页码 - extensions: base 或 all - - Returns: - POI 列表 - """ - params = { - "location": location, - "radius": min(radius, 50000), - "offset": min(offset, 25), - "page": page, - "extensions": extensions, - "sortrule": "distance" - } - - if keywords: - params["keywords"] = keywords - if types: - params["types"] = types - - result = await self._request("/v3/place/around", params) - - if not result["success"]: - return result - - pois = result["data"].get("pois", []) - count = self._safe_int(result["data"].get("count", 0)) - - return { - "success": True, - "count": count, - "pois": [self._format_poi(poi) for poi in pois] - } - - def _format_poi(self, poi: Dict[str, Any]) -> Dict[str, Any]: - """格式化 POI 数据""" - biz_ext = poi.get("biz_ext", {}) or {} - return { - "id": poi.get("id", ""), - "name": poi.get("name", ""), - "type": poi.get("type", ""), - "address": poi.get("address", ""), - "location": poi.get("location", ""), - "tel": poi.get("tel", ""), - "distance": poi.get("distance", ""), - "pname": poi.get("pname", ""), - "cityname": poi.get("cityname", ""), - "adname": poi.get("adname", ""), - "rating": biz_ext.get("rating", ""), - "cost": biz_ext.get("cost", "") - } - - # ==================== 路径规划 ==================== - - async def route_driving( - self, - origin: str, - destination: str, - strategy: int = 10, - waypoints: str = None, - extensions: str = "base" - ) -> Dict[str, Any]: - """ - 驾车路径规划 - - Args: - origin: 起点坐标 "lng,lat" - destination: 终点坐标 "lng,lat" - strategy: 驾车策略(10=躲避拥堵,13=不走高速,14=避免收费) - waypoints: 途经点,多个用 ; 分隔 - extensions: base 或 all - - Returns: - 路径规划结果 - """ - params = { - "origin": origin, - "destination": destination, - "strategy": strategy, - "extensions": extensions - } - if waypoints: - params["waypoints"] = waypoints - - result = await self._request("/v3/direction/driving", params) - - if not result["success"]: - return result - - route = result["data"].get("route", {}) - paths = route.get("paths", []) - - if not paths: - return {"success": False, "error": "未找到驾车路线"} - - path = paths[0] - return { - "success": True, - "mode": "driving", - "origin": route.get("origin", ""), - "destination": route.get("destination", ""), - "distance": self._safe_int(path.get("distance", 0)), - "duration": self._safe_int(path.get("duration", 0)), - "tolls": self._safe_float(path.get("tolls", 0)), - "toll_distance": self._safe_int(path.get("toll_distance", 0)), - "traffic_lights": self._safe_int(path.get("traffic_lights", 0)), - "taxi_cost": self._safe_str(route.get("taxi_cost", "")), - "strategy": path.get("strategy", ""), - "steps": self._format_driving_steps(path.get("steps", [])) - } - - async def route_transit( - self, - origin: str, - destination: str, - city: str, - cityd: str = None, - strategy: int = 0, - extensions: str = "all" - ) -> Dict[str, Any]: - """ - 公交路径规划(含火车、地铁) - - Args: - origin: 起点坐标 "lng,lat" - destination: 终点坐标 "lng,lat" - city: 起点城市 - cityd: 终点城市(跨城时必填) - strategy: 0=最快,1=最省钱,2=最少换乘,3=最少步行 - extensions: base 或 all - - Returns: - 公交路径规划结果 - """ - params = { - "origin": origin, - "destination": destination, - "city": city, - "strategy": strategy, - "extensions": extensions - } - if cityd: - params["cityd"] = cityd - - result = await self._request("/v3/direction/transit/integrated", params) - - if not result["success"]: - return result - - route = result["data"].get("route", {}) - transits = route.get("transits", []) - - if not transits: - return {"success": False, "error": "未找到公交路线"} - - # 返回前3个方案 - formatted_transits = [] - for transit in transits[:3]: - segments = transit.get("segments", []) - formatted_segments = [] - - for seg in segments: - # 步行段 - walking = seg.get("walking", {}) - if walking and walking.get("distance"): - formatted_segments.append({ - "type": "walking", - "distance": self._safe_int(walking.get("distance", 0)), - "duration": self._safe_int(walking.get("duration", 0)) - }) - - # 公交/地铁段 - bus_info = seg.get("bus", {}) - buslines = bus_info.get("buslines", []) - if buslines: - line = buslines[0] - formatted_segments.append({ - "type": "bus", - "name": self._safe_str(line.get("name", "")), - "departure_stop": self._safe_str(line.get("departure_stop", {}).get("name", "")), - "arrival_stop": self._safe_str(line.get("arrival_stop", {}).get("name", "")), - "via_num": self._safe_int(line.get("via_num", 0)), - "distance": self._safe_int(line.get("distance", 0)), - "duration": self._safe_int(line.get("duration", 0)) - }) - - # 火车段 - railway = seg.get("railway", {}) - if railway and railway.get("name"): - formatted_segments.append({ - "type": "railway", - "name": self._safe_str(railway.get("name", "")), - "trip": self._safe_str(railway.get("trip", "")), - "departure_stop": self._safe_str(railway.get("departure_stop", {}).get("name", "")), - "arrival_stop": self._safe_str(railway.get("arrival_stop", {}).get("name", "")), - "departure_time": self._safe_str(railway.get("departure_stop", {}).get("time", "")), - "arrival_time": self._safe_str(railway.get("arrival_stop", {}).get("time", "")), - "distance": self._safe_int(railway.get("distance", 0)), - "time": self._safe_str(railway.get("time", "")) - }) - - formatted_transits.append({ - "cost": self._safe_str(transit.get("cost", "")), - "duration": self._safe_int(transit.get("duration", 0)), - "walking_distance": self._safe_int(transit.get("walking_distance", 0)), - "segments": formatted_segments - }) - - return { - "success": True, - "mode": "transit", - "origin": route.get("origin", ""), - "destination": route.get("destination", ""), - "distance": self._safe_int(route.get("distance", 0)), - "taxi_cost": self._safe_str(route.get("taxi_cost", "")), - "transits": formatted_transits - } - - async def route_walking( - self, - origin: str, - destination: str - ) -> Dict[str, Any]: - """ - 步行路径规划 - - Args: - origin: 起点坐标 "lng,lat" - destination: 终点坐标 "lng,lat" - - Returns: - 步行路径规划结果 - """ - params = { - "origin": origin, - "destination": destination - } - - result = await self._request("/v3/direction/walking", params) - - if not result["success"]: - return result - - route = result["data"].get("route", {}) - paths = route.get("paths", []) - - if not paths: - return {"success": False, "error": "未找到步行路线"} - - path = paths[0] - return { - "success": True, - "mode": "walking", - "origin": route.get("origin", ""), - "destination": route.get("destination", ""), - "distance": self._safe_int(path.get("distance", 0)), - "duration": self._safe_int(path.get("duration", 0)) - } - - async def route_bicycling( - self, - origin: str, - destination: str - ) -> Dict[str, Any]: - """ - 骑行路径规划 - - Args: - origin: 起点坐标 "lng,lat" - destination: 终点坐标 "lng,lat" - - Returns: - 骑行路径规划结果 - """ - params = { - "origin": origin, - "destination": destination - } - - # 骑行用 v4 接口 - result = await self._request("/v4/direction/bicycling", params) - - if not result["success"]: - return result - - data = result["data"].get("data", {}) - paths = data.get("paths", []) - - if not paths: - return {"success": False, "error": "未找到骑行路线"} - - path = paths[0] - return { - "success": True, - "mode": "bicycling", - "origin": data.get("origin", ""), - "destination": data.get("destination", ""), - "distance": self._safe_int(path.get("distance", 0)), - "duration": self._safe_int(path.get("duration", 0)) - } - - def _format_driving_steps(self, steps: List[Dict]) -> List[Dict]: - """格式化驾车步骤""" - return [ - { - "instruction": step.get("instruction", ""), - "road": step.get("road", ""), - "distance": self._safe_int(step.get("distance", 0)), - "duration": self._safe_int(step.get("duration", 0)), - "orientation": step.get("orientation", "") - } - for step in steps[:10] # 只返回前10步 - ] - - # ==================== 距离测量 ==================== - - async def get_distance( - self, - origins: str, - destination: str, - type: int = 1 - ) -> Dict[str, Any]: - """ - 距离测量 - - Args: - origins: 起点坐标,多个用 | 分隔 - destination: 终点坐标 - type: 0=直线距离,1=驾车距离,3=步行距离 - - Returns: - 距离信息 - """ - params = { - "origins": origins, - "destination": destination, - "type": type - } - - result = await self._request("/v3/distance", params) - - if not result["success"]: - return result - - results = result["data"].get("results", []) - if not results: - return {"success": False, "error": "无法计算距离"} - - return { - "success": True, - "results": [ - { - "origin_id": r.get("origin_id", ""), - "distance": self._safe_int(r.get("distance", 0)), - "duration": self._safe_int(r.get("duration", 0)) - } - for r in results - ] - } - - # ==================== 输入提示 ==================== - - async def input_tips( - self, - keywords: str, - city: str = None, - citylimit: bool = False, - datatype: str = "all" - ) -> Dict[str, Any]: - """ - 输入提示 - - Args: - keywords: 查询关键字 - city: 城市名或 adcode - citylimit: 是否仅返回指定城市 - datatype: all/poi/bus/busline - - Returns: - 提示列表 - """ - params = { - "keywords": keywords, - "datatype": datatype - } - if city: - params["city"] = city - params["citylimit"] = "true" if citylimit else "false" - - result = await self._request("/v3/assistant/inputtips", params) - - if not result["success"]: - return result - - tips = result["data"].get("tips", []) - return { - "success": True, - "tips": [ - { - "id": tip.get("id", ""), - "name": tip.get("name", ""), - "district": tip.get("district", ""), - "adcode": tip.get("adcode", ""), - "location": tip.get("location", ""), - "address": tip.get("address", "") - } - for tip in tips - if tip.get("location") # 过滤无坐标的结果 - ] - } +""" +高德地图 API 客户端封装 + +提供以下功能: +- 地理编码:地址 → 坐标 +- 逆地理编码:坐标 → 地址 +- 行政区域查询:获取城市 adcode +- 天气查询:实况/预报天气 +- POI 搜索:关键字搜索、周边搜索 +- 路径规划:驾车、公交、步行、骑行 +""" + +from __future__ import annotations + +import hashlib +import aiohttp +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Literal +from loguru import logger + + +@dataclass +class AmapConfig: + """高德 API 配置""" + api_key: str + secret: str = "" # 安全密钥,用于数字签名 + timeout: int = 30 + + +class AmapClient: + """高德地图 API 客户端""" + + BASE_URL = "https://restapi.amap.com" + + def __init__(self, config: AmapConfig): + self.config = config + self._session: Optional[aiohttp.ClientSession] = None + + @staticmethod + def _safe_int(value, default: int = 0) -> int: + """安全地将值转换为整数,处理列表、None、空字符串等情况""" + if value is None: + return default + if isinstance(value, list): + return default + if isinstance(value, (int, float)): + return int(value) + if isinstance(value, str): + if not value.strip(): + return default + try: + return int(float(value)) + except (ValueError, TypeError): + return default + return default + + @staticmethod + def _safe_float(value, default: float = 0.0) -> float: + """安全地将值转换为浮点数""" + if value is None: + return default + if isinstance(value, list): + return default + if isinstance(value, (int, float)): + return float(value) + if isinstance(value, str): + if not value.strip(): + return default + try: + return float(value) + except (ValueError, TypeError): + return default + return default + + @staticmethod + def _safe_str(value, default: str = "") -> str: + """安全地将值转换为字符串,处理列表等情况""" + if value is None: + return default + if isinstance(value, list): + return default + return str(value) + + async def _get_session(self) -> aiohttp.ClientSession: + """获取或创建 HTTP 会话""" + if self._session is None or self._session.closed: + timeout = aiohttp.ClientTimeout(total=self.config.timeout) + self._session = aiohttp.ClientSession(timeout=timeout) + return self._session + + async def close(self): + """关闭 HTTP 会话""" + if self._session and not self._session.closed: + await self._session.close() + + def _generate_signature(self, params: Dict[str, Any]) -> str: + """ + 生成数字签名 + + 算法: + 1. 将请求参数按参数名升序排序 + 2. 按 key=value 格式拼接,用 & 连接 + 3. 最后拼接上私钥(secret) + 4. 对整个字符串进行 MD5 加密 + + Args: + params: 请求参数(不含 sig) + + Returns: + MD5 签名字符串 + """ + # 按参数名升序排序 + sorted_params = sorted(params.items(), key=lambda x: x[0]) + # 拼接成 key=value&key=value 格式 + param_str = "&".join(f"{k}={v}" for k, v in sorted_params) + # 拼接私钥 + sign_str = param_str + self.config.secret + # MD5 加密 + return hashlib.md5(sign_str.encode('utf-8')).hexdigest() + + async def _request(self, endpoint: str, params: Dict[str, Any]) -> Dict[str, Any]: + """ + 发送 API 请求 + + Args: + endpoint: API 端点路径 + params: 请求参数 + + Returns: + API 响应数据 + """ + params["key"] = self.config.api_key + params["output"] = "JSON" + + # 如果配置了安全密钥,生成数字签名 + if self.config.secret: + params["sig"] = self._generate_signature(params) + + url = f"{self.BASE_URL}{endpoint}" + session = await self._get_session() + + try: + async with session.get(url, params=params) as response: + data = await response.json() + + # 检查 API 状态 + status = data.get("status", "0") + if status != "1": + info = data.get("info", "未知错误") + infocode = data.get("infocode", "") + logger.warning(f"高德 API 错误: {info} (code: {infocode})") + return {"success": False, "error": info, "code": infocode} + + return {"success": True, "data": data} + + except aiohttp.ClientError as e: + logger.error(f"高德 API 请求失败: {e}") + return {"success": False, "error": str(e)} + except Exception as e: + logger.error(f"高德 API 未知错误: {e}") + return {"success": False, "error": str(e)} + + # ==================== 地理编码 ==================== + + async def geocode(self, address: str, city: str = None) -> Dict[str, Any]: + """ + 地理编码:将地址转换为坐标 + + Args: + address: 结构化地址,如 "北京市朝阳区阜通东大街6号" + city: 指定城市(可选) + + Returns: + { + "success": True, + "location": "116.480881,39.989410", + "adcode": "110105", + "city": "北京市", + "district": "朝阳区", + "level": "门址" + } + """ + params = {"address": address} + if city: + params["city"] = city + + result = await self._request("/v3/geocode/geo", params) + + if not result["success"]: + return result + + geocodes = result["data"].get("geocodes", []) + if not geocodes: + return {"success": False, "error": "未找到该地址"} + + geo = geocodes[0] + # 处理高德 API 返回空列表的情况(如直辖市) + city_val = geo.get("city", "") + if isinstance(city_val, list): + city_val = "" + province_val = geo.get("province", "") + if isinstance(province_val, list): + province_val = "" + district_val = geo.get("district", "") + if isinstance(district_val, list): + district_val = "" + + # 如果城市为空,使用省份(直辖市情况) + if not city_val and province_val: + city_val = province_val + + return { + "success": True, + "location": geo.get("location", ""), + "adcode": self._safe_str(geo.get("adcode", "")), + "province": province_val, + "city": city_val, + "district": district_val, + "level": self._safe_str(geo.get("level", "")), + "formatted_address": self._safe_str(geo.get("formatted_address", address)) + } + + async def reverse_geocode( + self, + location: str, + radius: int = 1000, + extensions: str = "base" + ) -> Dict[str, Any]: + """ + 逆地理编码:将坐标转换为地址 + + Args: + location: 经纬度坐标,格式 "lng,lat" + radius: 搜索半径(米),0-3000 + extensions: base 或 all + + Returns: + 地址信息 + """ + params = { + "location": location, + "radius": min(radius, 3000), + "extensions": extensions + } + + result = await self._request("/v3/geocode/regeo", params) + + if not result["success"]: + return result + + regeocode = result["data"].get("regeocode", {}) + address_component = regeocode.get("addressComponent", {}) + + return { + "success": True, + "formatted_address": regeocode.get("formatted_address", ""), + "province": address_component.get("province", ""), + "city": address_component.get("city", ""), + "district": address_component.get("district", ""), + "adcode": address_component.get("adcode", ""), + "township": address_component.get("township", ""), + "pois": regeocode.get("pois", []) if extensions == "all" else [] + } + + # ==================== 行政区域查询 ==================== + + async def get_district( + self, + keywords: str = None, + subdistrict: int = 1 + ) -> Dict[str, Any]: + """ + 行政区域查询 + + Args: + keywords: 查询关键字(城市名、adcode 等) + subdistrict: 返回下级行政区级数(0-3) + + Returns: + 行政区域信息,包含 adcode、citycode 等 + """ + params = {"subdistrict": subdistrict} + if keywords: + params["keywords"] = keywords + + result = await self._request("/v3/config/district", params) + + if not result["success"]: + return result + + districts = result["data"].get("districts", []) + if not districts: + return {"success": False, "error": "未找到该行政区域"} + + district = districts[0] + return { + "success": True, + "name": district.get("name", ""), + "adcode": district.get("adcode", ""), + "citycode": district.get("citycode", ""), + "center": district.get("center", ""), + "level": district.get("level", ""), + "districts": district.get("districts", []) + } + + # ==================== 天气查询 ==================== + + async def get_weather( + self, + city: str, + extensions: Literal["base", "all"] = "all" + ) -> Dict[str, Any]: + """ + 天气查询 + + Args: + city: 城市 adcode(如 110000)或城市名 + extensions: base=实况天气,all=预报天气(未来4天) + + Returns: + 天气信息 + """ + # 如果传入的是城市名,先获取 adcode + if not city.isdigit(): + district_result = await self.get_district(city) + if not district_result["success"]: + return {"success": False, "error": f"无法识别城市: {city}"} + city = district_result["adcode"] + + params = { + "city": city, + "extensions": extensions + } + + result = await self._request("/v3/weather/weatherInfo", params) + + if not result["success"]: + return result + + data = result["data"] + + if extensions == "base": + # 实况天气 + lives = data.get("lives", []) + if not lives: + return {"success": False, "error": "未获取到天气数据"} + + live = lives[0] + return { + "success": True, + "type": "live", + "city": live.get("city", ""), + "weather": live.get("weather", ""), + "temperature": live.get("temperature", ""), + "winddirection": live.get("winddirection", ""), + "windpower": live.get("windpower", ""), + "humidity": live.get("humidity", ""), + "reporttime": live.get("reporttime", "") + } + else: + # 预报天气 + forecasts = data.get("forecasts", []) + if not forecasts: + return {"success": False, "error": "未获取到天气预报数据"} + + forecast = forecasts[0] + casts = forecast.get("casts", []) + + return { + "success": True, + "type": "forecast", + "city": forecast.get("city", ""), + "province": forecast.get("province", ""), + "reporttime": forecast.get("reporttime", ""), + "forecasts": [ + { + "date": cast.get("date", ""), + "week": cast.get("week", ""), + "dayweather": cast.get("dayweather", ""), + "nightweather": cast.get("nightweather", ""), + "daytemp": cast.get("daytemp", ""), + "nighttemp": cast.get("nighttemp", ""), + "daywind": cast.get("daywind", ""), + "nightwind": cast.get("nightwind", ""), + "daypower": cast.get("daypower", ""), + "nightpower": cast.get("nightpower", "") + } + for cast in casts + ] + } + + # ==================== POI 搜索 ==================== + + async def search_poi( + self, + keywords: str = None, + types: str = None, + city: str = None, + citylimit: bool = True, + offset: int = 20, + page: int = 1, + extensions: str = "all" + ) -> Dict[str, Any]: + """ + 关键字搜索 POI + + Args: + keywords: 查询关键字 + types: POI 类型代码,多个用 | 分隔 + city: 城市名或 adcode + citylimit: 是否仅返回指定城市 + offset: 每页数量(建议不超过25) + page: 页码 + extensions: base 或 all + + Returns: + POI 列表 + """ + params = { + "offset": min(offset, 25), + "page": page, + "extensions": extensions + } + + if keywords: + params["keywords"] = keywords + if types: + params["types"] = types + if city: + params["city"] = city + params["citylimit"] = "true" if citylimit else "false" + + result = await self._request("/v3/place/text", params) + + if not result["success"]: + return result + + pois = result["data"].get("pois", []) + count = self._safe_int(result["data"].get("count", 0)) + + return { + "success": True, + "count": count, + "pois": [self._format_poi(poi) for poi in pois] + } + + async def search_around( + self, + location: str, + keywords: str = None, + types: str = None, + radius: int = 3000, + offset: int = 20, + page: int = 1, + extensions: str = "all" + ) -> Dict[str, Any]: + """ + 周边搜索 POI + + Args: + location: 中心点坐标,格式 "lng,lat" + keywords: 查询关键字 + types: POI 类型代码 + radius: 搜索半径(米),0-50000 + offset: 每页数量 + page: 页码 + extensions: base 或 all + + Returns: + POI 列表 + """ + params = { + "location": location, + "radius": min(radius, 50000), + "offset": min(offset, 25), + "page": page, + "extensions": extensions, + "sortrule": "distance" + } + + if keywords: + params["keywords"] = keywords + if types: + params["types"] = types + + result = await self._request("/v3/place/around", params) + + if not result["success"]: + return result + + pois = result["data"].get("pois", []) + count = self._safe_int(result["data"].get("count", 0)) + + return { + "success": True, + "count": count, + "pois": [self._format_poi(poi) for poi in pois] + } + + def _format_poi(self, poi: Dict[str, Any]) -> Dict[str, Any]: + """格式化 POI 数据""" + biz_ext = poi.get("biz_ext", {}) or {} + return { + "id": poi.get("id", ""), + "name": poi.get("name", ""), + "type": poi.get("type", ""), + "address": poi.get("address", ""), + "location": poi.get("location", ""), + "tel": poi.get("tel", ""), + "distance": poi.get("distance", ""), + "pname": poi.get("pname", ""), + "cityname": poi.get("cityname", ""), + "adname": poi.get("adname", ""), + "rating": biz_ext.get("rating", ""), + "cost": biz_ext.get("cost", "") + } + + # ==================== 路径规划 ==================== + + async def route_driving( + self, + origin: str, + destination: str, + strategy: int = 10, + waypoints: str = None, + extensions: str = "base" + ) -> Dict[str, Any]: + """ + 驾车路径规划 + + Args: + origin: 起点坐标 "lng,lat" + destination: 终点坐标 "lng,lat" + strategy: 驾车策略(10=躲避拥堵,13=不走高速,14=避免收费) + waypoints: 途经点,多个用 ; 分隔 + extensions: base 或 all + + Returns: + 路径规划结果 + """ + params = { + "origin": origin, + "destination": destination, + "strategy": strategy, + "extensions": extensions + } + if waypoints: + params["waypoints"] = waypoints + + result = await self._request("/v3/direction/driving", params) + + if not result["success"]: + return result + + route = result["data"].get("route", {}) + paths = route.get("paths", []) + + if not paths: + return {"success": False, "error": "未找到驾车路线"} + + path = paths[0] + return { + "success": True, + "mode": "driving", + "origin": route.get("origin", ""), + "destination": route.get("destination", ""), + "distance": self._safe_int(path.get("distance", 0)), + "duration": self._safe_int(path.get("duration", 0)), + "tolls": self._safe_float(path.get("tolls", 0)), + "toll_distance": self._safe_int(path.get("toll_distance", 0)), + "traffic_lights": self._safe_int(path.get("traffic_lights", 0)), + "taxi_cost": self._safe_str(route.get("taxi_cost", "")), + "strategy": path.get("strategy", ""), + "steps": self._format_driving_steps(path.get("steps", [])) + } + + async def route_transit( + self, + origin: str, + destination: str, + city: str, + cityd: str = None, + strategy: int = 0, + extensions: str = "all" + ) -> Dict[str, Any]: + """ + 公交路径规划(含火车、地铁) + + Args: + origin: 起点坐标 "lng,lat" + destination: 终点坐标 "lng,lat" + city: 起点城市 + cityd: 终点城市(跨城时必填) + strategy: 0=最快,1=最省钱,2=最少换乘,3=最少步行 + extensions: base 或 all + + Returns: + 公交路径规划结果 + """ + params = { + "origin": origin, + "destination": destination, + "city": city, + "strategy": strategy, + "extensions": extensions + } + if cityd: + params["cityd"] = cityd + + result = await self._request("/v3/direction/transit/integrated", params) + + if not result["success"]: + return result + + route = result["data"].get("route", {}) + transits = route.get("transits", []) + + if not transits: + return {"success": False, "error": "未找到公交路线"} + + # 返回前3个方案 + formatted_transits = [] + for transit in transits[:3]: + segments = transit.get("segments", []) + formatted_segments = [] + + for seg in segments: + # 步行段 + walking = seg.get("walking", {}) + if walking and walking.get("distance"): + formatted_segments.append({ + "type": "walking", + "distance": self._safe_int(walking.get("distance", 0)), + "duration": self._safe_int(walking.get("duration", 0)) + }) + + # 公交/地铁段 + bus_info = seg.get("bus", {}) + buslines = bus_info.get("buslines", []) + if buslines: + line = buslines[0] + formatted_segments.append({ + "type": "bus", + "name": self._safe_str(line.get("name", "")), + "departure_stop": self._safe_str(line.get("departure_stop", {}).get("name", "")), + "arrival_stop": self._safe_str(line.get("arrival_stop", {}).get("name", "")), + "via_num": self._safe_int(line.get("via_num", 0)), + "distance": self._safe_int(line.get("distance", 0)), + "duration": self._safe_int(line.get("duration", 0)) + }) + + # 火车段 + railway = seg.get("railway", {}) + if railway and railway.get("name"): + formatted_segments.append({ + "type": "railway", + "name": self._safe_str(railway.get("name", "")), + "trip": self._safe_str(railway.get("trip", "")), + "departure_stop": self._safe_str(railway.get("departure_stop", {}).get("name", "")), + "arrival_stop": self._safe_str(railway.get("arrival_stop", {}).get("name", "")), + "departure_time": self._safe_str(railway.get("departure_stop", {}).get("time", "")), + "arrival_time": self._safe_str(railway.get("arrival_stop", {}).get("time", "")), + "distance": self._safe_int(railway.get("distance", 0)), + "time": self._safe_str(railway.get("time", "")) + }) + + formatted_transits.append({ + "cost": self._safe_str(transit.get("cost", "")), + "duration": self._safe_int(transit.get("duration", 0)), + "walking_distance": self._safe_int(transit.get("walking_distance", 0)), + "segments": formatted_segments + }) + + return { + "success": True, + "mode": "transit", + "origin": route.get("origin", ""), + "destination": route.get("destination", ""), + "distance": self._safe_int(route.get("distance", 0)), + "taxi_cost": self._safe_str(route.get("taxi_cost", "")), + "transits": formatted_transits + } + + async def route_walking( + self, + origin: str, + destination: str + ) -> Dict[str, Any]: + """ + 步行路径规划 + + Args: + origin: 起点坐标 "lng,lat" + destination: 终点坐标 "lng,lat" + + Returns: + 步行路径规划结果 + """ + params = { + "origin": origin, + "destination": destination + } + + result = await self._request("/v3/direction/walking", params) + + if not result["success"]: + return result + + route = result["data"].get("route", {}) + paths = route.get("paths", []) + + if not paths: + return {"success": False, "error": "未找到步行路线"} + + path = paths[0] + return { + "success": True, + "mode": "walking", + "origin": route.get("origin", ""), + "destination": route.get("destination", ""), + "distance": self._safe_int(path.get("distance", 0)), + "duration": self._safe_int(path.get("duration", 0)) + } + + async def route_bicycling( + self, + origin: str, + destination: str + ) -> Dict[str, Any]: + """ + 骑行路径规划 + + Args: + origin: 起点坐标 "lng,lat" + destination: 终点坐标 "lng,lat" + + Returns: + 骑行路径规划结果 + """ + params = { + "origin": origin, + "destination": destination + } + + # 骑行用 v4 接口 + result = await self._request("/v4/direction/bicycling", params) + + if not result["success"]: + return result + + data = result["data"].get("data", {}) + paths = data.get("paths", []) + + if not paths: + return {"success": False, "error": "未找到骑行路线"} + + path = paths[0] + return { + "success": True, + "mode": "bicycling", + "origin": data.get("origin", ""), + "destination": data.get("destination", ""), + "distance": self._safe_int(path.get("distance", 0)), + "duration": self._safe_int(path.get("duration", 0)) + } + + def _format_driving_steps(self, steps: List[Dict]) -> List[Dict]: + """格式化驾车步骤""" + return [ + { + "instruction": step.get("instruction", ""), + "road": step.get("road", ""), + "distance": self._safe_int(step.get("distance", 0)), + "duration": self._safe_int(step.get("duration", 0)), + "orientation": step.get("orientation", "") + } + for step in steps[:10] # 只返回前10步 + ] + + # ==================== 距离测量 ==================== + + async def get_distance( + self, + origins: str, + destination: str, + type: int = 1 + ) -> Dict[str, Any]: + """ + 距离测量 + + Args: + origins: 起点坐标,多个用 | 分隔 + destination: 终点坐标 + type: 0=直线距离,1=驾车距离,3=步行距离 + + Returns: + 距离信息 + """ + params = { + "origins": origins, + "destination": destination, + "type": type + } + + result = await self._request("/v3/distance", params) + + if not result["success"]: + return result + + results = result["data"].get("results", []) + if not results: + return {"success": False, "error": "无法计算距离"} + + return { + "success": True, + "results": [ + { + "origin_id": r.get("origin_id", ""), + "distance": self._safe_int(r.get("distance", 0)), + "duration": self._safe_int(r.get("duration", 0)) + } + for r in results + ] + } + + # ==================== 输入提示 ==================== + + async def input_tips( + self, + keywords: str, + city: str = None, + citylimit: bool = False, + datatype: str = "all" + ) -> Dict[str, Any]: + """ + 输入提示 + + Args: + keywords: 查询关键字 + city: 城市名或 adcode + citylimit: 是否仅返回指定城市 + datatype: all/poi/bus/busline + + Returns: + 提示列表 + """ + params = { + "keywords": keywords, + "datatype": datatype + } + if city: + params["city"] = city + params["citylimit"] = "true" if citylimit else "false" + + result = await self._request("/v3/assistant/inputtips", params) + + if not result["success"]: + return result + + tips = result["data"].get("tips", []) + return { + "success": True, + "tips": [ + { + "id": tip.get("id", ""), + "name": tip.get("name", ""), + "district": tip.get("district", ""), + "adcode": tip.get("adcode", ""), + "location": tip.get("location", ""), + "address": tip.get("address", "") + } + for tip in tips + if tip.get("location") # 过滤无坐标的结果 + ] + } diff --git a/plugins/TravelPlanner/main.py b/plugins/TravelPlanner/main.py index a6af8fb..4becf73 100644 --- a/plugins/TravelPlanner/main.py +++ b/plugins/TravelPlanner/main.py @@ -1,609 +1,805 @@ -""" -旅行规划插件 - -基于高德地图 API,提供以下功能: -- 地点搜索与地理编码 -- 天气查询(实况 + 4天预报) -- 景点/酒店/餐厅搜索 -- 路径规划(驾车/公交/步行) -- 周边搜索 - -支持 LLM 函数调用,可与 AIChat 插件配合使用。 -""" - -import tomllib -from pathlib import Path -from typing import Any, Dict, List -from loguru import logger - -from utils.plugin_base import PluginBase -from .amap_client import AmapClient, AmapConfig - - -class TravelPlanner(PluginBase): - """旅行规划插件""" - - description = "旅行规划助手,支持天气查询、景点搜索、路线规划" - author = "ShiHao" - version = "1.0.0" - - def __init__(self): - super().__init__() - self.config = None - self.amap: AmapClient = None - - async def async_init(self): - """插件异步初始化""" - # 读取配置 - config_path = Path(__file__).parent / "config.toml" - with open(config_path, "rb") as f: - self.config = tomllib.load(f) - - # 初始化高德 API 客户端 - amap_config = self.config.get("amap", {}) - api_key = amap_config.get("api_key", "") - secret = amap_config.get("secret", "") - - if not api_key: - logger.warning("TravelPlanner: 未配置高德 API Key,请在 config.toml 中设置") - else: - self.amap = AmapClient(AmapConfig( - api_key=api_key, - secret=secret, - timeout=amap_config.get("timeout", 30) - )) - if secret: - logger.success(f"TravelPlanner 插件已加载,API Key: {api_key[:8]}...(已启用数字签名)") - else: - logger.success(f"TravelPlanner 插件已加载,API Key: {api_key[:8]}...(未配置安全密钥)") - - async def on_disable(self): - """插件禁用时关闭连接""" - await super().on_disable() - if self.amap: - await self.amap.close() - logger.info("TravelPlanner: 已关闭高德 API 连接") - - # ==================== LLM 工具定义 ==================== - - def get_llm_tools(self) -> List[Dict]: - """返回 LLM 可调用的工具列表""" - return [ - { - "type": "function", - "function": { - "name": "search_location", - "description": "【旅行工具】将地名转换为坐标和行政区划信息。仅当用户明确询问某个地点的位置信息时使用。", - "parameters": { - "type": "object", - "properties": { - "address": { - "type": "string", - "description": "地址或地名,如:北京市、西湖、故宫" - }, - "city": { - "type": "string", - "description": "所在城市,可选。填写可提高搜索精度" - } - }, - "required": ["address"] - } - } - }, - { - "type": "function", - "function": { - "name": "query_weather", - "description": "【旅行工具】查询城市天气预报。仅当用户明确询问某城市的天气情况时使用,如'北京天气怎么样'、'杭州明天会下雨吗'。", - "parameters": { - "type": "object", - "properties": { - "city": { - "type": "string", - "description": "城市名称,如:北京、杭州、上海" - }, - "forecast": { - "type": "boolean", - "description": "是否查询预报天气。true=未来4天预报,false=当前实况" - } - }, - "required": ["city"] - } - } - }, - { - "type": "function", - "function": { - "name": "search_poi", - "description": "【旅行工具】搜索地点(景点、酒店、餐厅等)。仅当用户明确要求查找某城市的景点、酒店、餐厅等时使用。", - "parameters": { - "type": "object", - "properties": { - "city": { - "type": "string", - "description": "搜索城市,如:杭州、北京" - }, - "keyword": { - "type": "string", - "description": "搜索关键词,如:西湖、希尔顿酒店、火锅" - }, - "category": { - "type": "string", - "enum": ["景点", "酒店", "餐厅", "购物", "交通"], - "description": "POI 类别。不填则搜索所有类别" - }, - "limit": { - "type": "integer", - "description": "返回结果数量,默认10,最大20" - } - }, - "required": ["city"] - } - } - }, - { - "type": "function", - "function": { - "name": "search_nearby", - "description": "【旅行工具】搜索某地点周边的设施。仅当用户明确要求查找某地点附近的餐厅、酒店等时使用,如'西湖附近有什么好吃的'。", - "parameters": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "中心地点名称,如:西湖、故宫" - }, - "city": { - "type": "string", - "description": "所在城市" - }, - "keyword": { - "type": "string", - "description": "搜索关键词" - }, - "category": { - "type": "string", - "enum": ["景点", "酒店", "餐厅", "购物", "交通"], - "description": "POI 类别" - }, - "radius": { - "type": "integer", - "description": "搜索半径(米),默认3000,最大50000" - } - }, - "required": ["location", "city"] - } - } - }, - { - "type": "function", - "function": { - "name": "plan_route", - "description": "【旅行工具】规划两地之间的出行路线。仅当用户明确要求规划从A到B的路线时使用,如'从北京到杭州怎么走'、'上海到苏州的高铁'。", - "parameters": { - "type": "object", - "properties": { - "origin": { - "type": "string", - "description": "起点地名,如:北京、上海虹桥站" - }, - "destination": { - "type": "string", - "description": "终点地名,如:杭州、西湖" - }, - "origin_city": { - "type": "string", - "description": "起点所在城市" - }, - "destination_city": { - "type": "string", - "description": "终点所在城市(跨城时必填)" - }, - "mode": { - "type": "string", - "enum": ["driving", "transit", "walking"], - "description": "出行方式:driving=驾车,transit=公交/高铁,walking=步行。默认 transit" - } - }, - "required": ["origin", "destination", "origin_city"] - } - } - }, - { - "type": "function", - "function": { - "name": "get_travel_info", - "description": "【旅行工具】获取目的地城市的旅行信息(天气、景点、交通)。仅当用户明确表示要去某城市旅游并询问相关信息时使用,如'我想去杭州玩,帮我看看'、'北京旅游攻略'。", - "parameters": { - "type": "object", - "properties": { - "destination": { - "type": "string", - "description": "目的地城市,如:杭州、成都" - }, - "origin": { - "type": "string", - "description": "出发城市,如:北京、上海。填写后会规划交通路线" - } - }, - "required": ["destination"] - } - } - } - ] - - async def execute_llm_tool( - self, - tool_name: str, - arguments: Dict[str, Any], - bot, - from_wxid: str - ) -> Dict[str, Any]: - """执行 LLM 工具调用""" - - if not self.amap: - return {"success": False, "message": "高德 API 未配置,请联系管理员设置 API Key"} - - try: - if tool_name == "search_location": - return await self._tool_search_location(arguments) - elif tool_name == "query_weather": - return await self._tool_query_weather(arguments) - elif tool_name == "search_poi": - return await self._tool_search_poi(arguments) - elif tool_name == "search_nearby": - return await self._tool_search_nearby(arguments) - elif tool_name == "plan_route": - return await self._tool_plan_route(arguments) - elif tool_name == "get_travel_info": - return await self._tool_get_travel_info(arguments) - else: - return {"success": False, "message": f"未知工具: {tool_name}"} - - except Exception as e: - logger.error(f"TravelPlanner 工具执行失败: {tool_name}, 错误: {e}") - return {"success": False, "message": f"工具执行失败: {str(e)}"} - - # ==================== 工具实现 ==================== - - async def _tool_search_location(self, args: Dict) -> Dict: - """地点搜索工具""" - address = args.get("address", "") - city = args.get("city") - - result = await self.amap.geocode(address, city) - - if not result["success"]: - return {"success": False, "message": result.get("error", "地点搜索失败")} - - return { - "success": True, - "message": f"已找到地点:{result['formatted_address']}", - "data": { - "name": address, - "formatted_address": result["formatted_address"], - "location": result["location"], - "province": result["province"], - "city": result["city"], - "district": result["district"], - "adcode": result["adcode"] - } - } - - async def _tool_query_weather(self, args: Dict) -> Dict: - """天气查询工具""" - city = args.get("city", "") - forecast = args.get("forecast", True) - - extensions = "all" if forecast else "base" - result = await self.amap.get_weather(city, extensions) - - if not result["success"]: - return {"success": False, "message": result.get("error", "天气查询失败")} - - if result["type"] == "live": - return { - "success": True, - "message": f"{result['city']}当前天气:{result['weather']},{result['temperature']}℃", - "data": { - "city": result["city"], - "weather": result["weather"], - "temperature": result["temperature"], - "humidity": result["humidity"], - "wind": f"{result['winddirection']}风 {result['windpower']}级", - "reporttime": result["reporttime"] - } - } - else: - forecasts = result["forecasts"] - weather_text = "\n".join([ - f"- {f['date']} 星期{self._weekday_cn(f['week'])}:白天{f['dayweather']} {f['daytemp']}℃,夜间{f['nightweather']} {f['nighttemp']}℃" - for f in forecasts - ]) - - return { - "success": True, - "message": f"{result['city']}未来天气预报:\n{weather_text}", - "data": { - "city": result["city"], - "province": result["province"], - "forecasts": forecasts, - "reporttime": result["reporttime"] - } - } - - async def _tool_search_poi(self, args: Dict) -> Dict: - """POI 搜索工具""" - city = args.get("city", "") - keyword = args.get("keyword") - category = args.get("category") - limit = min(args.get("limit", 10), 20) - - # 获取 POI 类型代码 - types = None - if category: - poi_types = self.config.get("poi_types", {}) - types = poi_types.get(category) - - result = await self.amap.search_poi( - keywords=keyword, - types=types, - city=city, - citylimit=True, - offset=limit - ) - - if not result["success"]: - return {"success": False, "message": result.get("error", "搜索失败")} - - pois = result["pois"] - if not pois: - return {"success": False, "message": f"在{city}未找到相关地点"} - - # 格式化输出 - poi_list = [] - for i, poi in enumerate(pois, 1): - info = f"{i}. {poi['name']}" - if poi.get("address"): - info += f" - {poi['address']}" - if poi.get("rating"): - info += f" ⭐{poi['rating']}" - if poi.get("cost"): - info += f" 人均¥{poi['cost']}" - poi_list.append(info) - - return { - "success": True, - "message": f"在{city}找到{len(pois)}个结果:\n" + "\n".join(poi_list), - "data": { - "city": city, - "category": category or "全部", - "count": len(pois), - "pois": pois - } - } - - async def _tool_search_nearby(self, args: Dict) -> Dict: - """周边搜索工具""" - location_name = args.get("location", "") - city = args.get("city", "") - keyword = args.get("keyword") - category = args.get("category") - radius = min(args.get("radius", 3000), 50000) - - # 先获取中心点坐标 - geo_result = await self.amap.geocode(location_name, city) - if not geo_result["success"]: - return {"success": False, "message": f"无法定位 {location_name}"} - - location = geo_result["location"] - - # 获取 POI 类型代码 - types = None - if category: - poi_types = self.config.get("poi_types", {}) - types = poi_types.get(category) - - result = await self.amap.search_around( - location=location, - keywords=keyword, - types=types, - radius=radius, - offset=10 - ) - - if not result["success"]: - return {"success": False, "message": result.get("error", "周边搜索失败")} - - pois = result["pois"] - if not pois: - return {"success": False, "message": f"在{location_name}周边未找到相关地点"} - - # 格式化输出 - poi_list = [] - for i, poi in enumerate(pois, 1): - info = f"{i}. {poi['name']}" - if poi.get("distance"): - info += f" ({poi['distance']}米)" - if poi.get("rating"): - info += f" ⭐{poi['rating']}" - poi_list.append(info) - - return { - "success": True, - "message": f"{location_name}周边{radius}米内找到{len(pois)}个结果:\n" + "\n".join(poi_list), - "data": { - "center": location_name, - "radius": radius, - "category": category or "全部", - "count": len(pois), - "pois": pois - } - } - - async def _tool_plan_route(self, args: Dict) -> Dict: - """路线规划工具""" - origin = args.get("origin", "") - destination = args.get("destination", "") - origin_city = args.get("origin_city", "") - destination_city = args.get("destination_city", origin_city) - mode = args.get("mode", "transit") - - # 获取起终点坐标 - origin_geo = await self.amap.geocode(origin, origin_city) - if not origin_geo["success"]: - return {"success": False, "message": f"无法定位起点:{origin}"} - - dest_geo = await self.amap.geocode(destination, destination_city) - if not dest_geo["success"]: - return {"success": False, "message": f"无法定位终点:{destination}"} - - origin_loc = origin_geo["location"] - dest_loc = dest_geo["location"] - - # 根据模式规划路线 - if mode == "driving": - result = await self.amap.route_driving(origin_loc, dest_loc) - if not result["success"]: - return {"success": False, "message": result.get("error", "驾车路线规划失败")} - - distance_km = result["distance"] / 1000 - duration_h = result["duration"] / 3600 - - msg = f"🚗 驾车路线:{origin} → {destination}\n" - msg += f"距离:{distance_km:.1f}公里,预计{self._format_duration(result['duration'])}\n" - if result["tolls"]: - msg += f"收费:约{result['tolls']}元\n" - if result["taxi_cost"]: - msg += f"打车费用:约{result['taxi_cost']}元" - - return { - "success": True, - "message": msg, - "data": result - } - - elif mode == "transit": - result = await self.amap.route_transit( - origin_loc, dest_loc, - city=origin_city, - cityd=destination_city if destination_city != origin_city else None - ) - if not result["success"]: - return {"success": False, "message": result.get("error", "公交路线规划失败")} - - msg = f"🚄 公交/高铁路线:{origin} → {destination}\n" - - for i, transit in enumerate(result["transits"][:2], 1): - msg += f"\n方案{i}:{self._format_duration(transit['duration'])}" - if transit.get("cost"): - msg += f",约{transit['cost']}元" - msg += "\n" - - for seg in transit["segments"]: - if seg["type"] == "walking" and seg["distance"] > 100: - msg += f" 🚶 步行{seg['distance']}米\n" - elif seg["type"] == "bus": - msg += f" 🚌 {seg['name']}:{seg['departure_stop']} → {seg['arrival_stop']}({seg['via_num']}站)\n" - elif seg["type"] == "railway": - msg += f" 🚄 {seg['trip']} {seg['name']}:{seg['departure_stop']} {seg.get('departure_time', '')} → {seg['arrival_stop']} {seg.get('arrival_time', '')}\n" - - return { - "success": True, - "message": msg.strip(), - "data": result - } - - elif mode == "walking": - result = await self.amap.route_walking(origin_loc, dest_loc) - if not result["success"]: - return {"success": False, "message": result.get("error", "步行路线规划失败")} - - return { - "success": True, - "message": f"🚶 步行路线:{origin} → {destination}\n距离:{result['distance']}米,预计{self._format_duration(result['duration'])}", - "data": result - } - - return {"success": False, "message": f"不支持的出行方式:{mode}"} - - async def _tool_get_travel_info(self, args: Dict) -> Dict: - """一键获取旅行信息""" - destination = args.get("destination", "") - origin = args.get("origin") - - info = {"destination": destination} - msg_parts = [f"📍 {destination} 旅行信息\n"] - - # 1. 查询天气 - weather_result = await self.amap.get_weather(destination, "all") - if weather_result["success"]: - info["weather"] = weather_result - msg_parts.append("🌤️ 天气预报:") - for f in weather_result["forecasts"][:3]: - msg_parts.append(f" {f['date']} {f['dayweather']} {f['nighttemp']}~{f['daytemp']}℃") - - # 2. 搜索热门景点 - poi_result = await self.amap.search_poi( - types="110000", # 景点 - city=destination, - citylimit=True, - offset=5 - ) - if poi_result["success"] and poi_result["pois"]: - info["attractions"] = poi_result["pois"] - msg_parts.append("\n🏞️ 热门景点:") - for poi in poi_result["pois"][:5]: - rating = f" ⭐{poi['rating']}" if poi.get("rating") else "" - msg_parts.append(f" • {poi['name']}{rating}") - - # 3. 规划交通路线(如果提供了出发地) - if origin: - origin_geo = await self.amap.geocode(origin) - dest_geo = await self.amap.geocode(destination) - - if origin_geo["success"] and dest_geo["success"]: - route_result = await self.amap.route_transit( - origin_geo["location"], - dest_geo["location"], - city=origin_geo.get("city", origin), - cityd=dest_geo.get("city", destination) - ) - - if route_result["success"] and route_result["transits"]: - info["route"] = route_result - transit = route_result["transits"][0] - msg_parts.append(f"\n🚄 从{origin}出发:") - msg_parts.append(f" 预计{self._format_duration(transit['duration'])}") - - # 显示主要交通工具 - for seg in transit["segments"]: - if seg["type"] == "railway": - msg_parts.append(f" {seg['trip']}:{seg['departure_stop']} → {seg['arrival_stop']}") - break - - return { - "success": True, - "message": "\n".join(msg_parts), - "data": info - } - - # ==================== 辅助方法 ==================== - - def _weekday_cn(self, week: str) -> str: - """星期数字转中文""" - mapping = {"1": "一", "2": "二", "3": "三", "4": "四", "5": "五", "6": "六", "7": "日"} - return mapping.get(str(week), week) - - def _format_duration(self, seconds: int) -> str: - """格式化时长""" - if seconds < 60: - return f"{seconds}秒" - elif seconds < 3600: - return f"{seconds // 60}分钟" - else: - hours = seconds // 3600 - minutes = (seconds % 3600) // 60 - if minutes: - return f"{hours}小时{minutes}分钟" - return f"{hours}小时" +""" +旅行规划插件 + +基于高德地图 API,提供以下功能: +- 地点搜索与地理编码 +- 天气查询(实况 + 4天预报) +- 景点/酒店/餐厅搜索 +- 路径规划(驾车/公交/步行) +- 周边搜索 + +支持 LLM 函数调用,可与 AIChat 插件配合使用。 +""" + +import asyncio +import tomllib +from pathlib import Path +from typing import Any, Dict, List +from loguru import logger + +from utils.plugin_base import PluginBase +from .amap_client import AmapClient, AmapConfig + + +class TravelPlanner(PluginBase): + """旅行规划插件""" + + description = "旅行规划助手,支持天气查询、景点搜索、路线规划" + author = "ShiHao" + version = "1.0.0" + + def __init__(self): + super().__init__() + self.config = None + self.amap: AmapClient = None + + async def async_init(self): + """插件异步初始化""" + # 读取配置 + config_path = Path(__file__).parent / "config.toml" + with open(config_path, "rb") as f: + self.config = tomllib.load(f) + + # 初始化高德 API 客户端 + amap_config = self.config.get("amap", {}) + api_key = amap_config.get("api_key", "") + secret = amap_config.get("secret", "") + + if not api_key: + logger.warning("TravelPlanner: 未配置高德 API Key,请在 config.toml 中设置") + else: + self.amap = AmapClient(AmapConfig( + api_key=api_key, + secret=secret, + timeout=amap_config.get("timeout", 30) + )) + if secret: + logger.success(f"TravelPlanner 插件已加载,API Key: {api_key[:8]}...(已启用数字签名)") + else: + logger.success(f"TravelPlanner 插件已加载,API Key: {api_key[:8]}...(未配置安全密钥)") + + async def on_disable(self): + """插件禁用时关闭连接""" + await super().on_disable() + if self.amap: + await self.amap.close() + logger.info("TravelPlanner: 已关闭高德 API 连接") + + # ==================== LLM 工具定义 ==================== + + def get_llm_tools(self) -> List[Dict]: + """返回 LLM 可调用的工具列表""" + return [ + { + "type": "function", + "function": { + "name": "search_location", + "description": "【旅行工具】将地名转换为坐标和行政区划信息。仅当用户明确询问某个地点的位置信息时使用。", + "parameters": { + "type": "object", + "properties": { + "address": { + "type": "string", + "description": "地址或地名,如:北京市、西湖、故宫" + }, + "city": { + "type": "string", + "description": "所在城市,可选。填写可提高搜索精度" + } + }, + "required": ["address"] + } + } + }, + { + "type": "function", + "function": { + "name": "query_weather", + "description": "【旅行工具】查询城市天气预报。仅当用户明确询问某城市的天气情况时使用,如'北京天气怎么样'、'杭州明天会下雨吗'。", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "城市名称,如:北京、杭州、上海" + }, + "forecast": { + "type": "boolean", + "description": "是否查询预报天气。true=未来4天预报,false=当前实况" + } + }, + "required": ["city"] + } + } + }, + { + "type": "function", + "function": { + "name": "search_poi", + "description": "【旅行工具】搜索地点(景点、酒店、餐厅等)。仅当用户明确要求查找某城市的景点、酒店、餐厅等时使用。", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "搜索城市,如:杭州、北京" + }, + "keyword": { + "type": "string", + "description": "搜索关键词,如:西湖、希尔顿酒店、火锅" + }, + "category": { + "type": "string", + "enum": ["景点", "酒店", "餐厅", "购物", "交通"], + "description": "POI 类别。不填则搜索所有类别" + }, + "limit": { + "type": "integer", + "description": "返回结果数量,默认10,最大20" + } + }, + "required": ["city"] + } + } + }, + { + "type": "function", + "function": { + "name": "search_nearby", + "description": "【旅行工具】搜索某地点周边的设施。仅当用户明确要求查找某地点附近的餐厅、酒店等时使用,如'西湖附近有什么好吃的'。", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "中心地点名称,如:西湖、故宫" + }, + "city": { + "type": "string", + "description": "所在城市" + }, + "keyword": { + "type": "string", + "description": "搜索关键词" + }, + "category": { + "type": "string", + "enum": ["景点", "酒店", "餐厅", "购物", "交通"], + "description": "POI 类别" + }, + "radius": { + "type": "integer", + "description": "搜索半径(米),默认3000,最大50000" + } + }, + "required": ["location", "city"] + } + } + }, + { + "type": "function", + "function": { + "name": "plan_route", + "description": "【旅行工具】规划两地之间的出行路线。仅当用户明确要求规划从A到B的路线时使用,如'从北京到杭州怎么走'、'上海到苏州的高铁'。", + "parameters": { + "type": "object", + "properties": { + "origin": { + "type": "string", + "description": "起点地名,如:北京、上海虹桥站" + }, + "destination": { + "type": "string", + "description": "终点地名,如:杭州、西湖" + }, + "origin_city": { + "type": "string", + "description": "起点所在城市" + }, + "destination_city": { + "type": "string", + "description": "终点所在城市(跨城时必填)" + }, + "mode": { + "type": "string", + "enum": ["driving", "transit", "walking"], + "description": "出行方式:driving=驾车,transit=公交/高铁,walking=步行。默认 transit" + } + }, + "required": ["origin", "destination", "origin_city"] + } + } + }, + { + "type": "function", + "function": { + "name": "get_travel_info", + "description": "【旅行工具】获取目的地城市的旅行信息(天气、景点、交通)。仅当用户明确表示要去某城市旅游并询问相关信息时使用,如'我想去杭州玩,帮我看看'、'北京旅游攻略'。", + "parameters": { + "type": "object", + "properties": { + "destination": { + "type": "string", + "description": "目的地城市,如:杭州、成都" + }, + "origin": { + "type": "string", + "description": "出发城市,如:北京、上海。填写后会规划交通路线" + } + }, + "required": ["destination"] + } + } + }, + { + "type": "function", + "function": { + "name": "plan_detailed_trip", + "description": "【必须调用】详细行程规划工具。当用户提到'规划行程'、'安排旅行'、'去XX旅游'、'帮我规划'、'我想去XX玩'时,必须调用此工具获取实时的交通、酒店、景点信息。此工具会返回:1.从用户家到火车站的详细路线(地铁几号线、哪站上哪站下)2.高铁车次和时刻 3.酒店推荐 4.景点推荐 5.餐厅推荐 6.天气预报。", + "parameters": { + "type": "object", + "properties": { + "origin_city": { + "type": "string", + "description": "出发城市,如:合肥、上海、北京" + }, + "origin_address": { + "type": "string", + "description": "用户的具体出发地址,如:合肥市蜀山区xxx小区、上海市浦东新区xxx路。如果用户没提供具体地址,填写城市名即可" + }, + "destination": { + "type": "string", + "description": "目的地城市,如:北京、杭州、成都" + }, + "days": { + "type": "integer", + "description": "旅行天数,默认2天" + }, + "departure_time": { + "type": "string", + "description": "出发时间偏好,如:周六早上、明天下午" + }, + "preferences": { + "type": "string", + "description": "旅行偏好,如:喜欢历史文化、想吃美食、带小孩" + } + }, + "required": ["origin_city", "destination"] + } + } + } + ] + + async def execute_llm_tool( + self, + tool_name: str, + arguments: Dict[str, Any], + bot, + from_wxid: str + ) -> Dict[str, Any]: + """执行 LLM 工具调用""" + + if not self.amap: + return {"success": False, "message": "高德 API 未配置,请联系管理员设置 API Key"} + + try: + if tool_name == "search_location": + return await self._tool_search_location(arguments) + elif tool_name == "query_weather": + return await self._tool_query_weather(arguments) + elif tool_name == "search_poi": + return await self._tool_search_poi(arguments) + elif tool_name == "search_nearby": + return await self._tool_search_nearby(arguments) + elif tool_name == "plan_route": + return await self._tool_plan_route(arguments) + elif tool_name == "get_travel_info": + return await self._tool_get_travel_info(arguments) + elif tool_name == "plan_detailed_trip": + return await self._tool_plan_detailed_trip(arguments) + else: + return {"success": False, "message": f"未知工具: {tool_name}"} + + except Exception as e: + logger.error(f"TravelPlanner 工具执行失败: {tool_name}, 错误: {e}") + return {"success": False, "message": f"工具执行失败: {str(e)}"} + + # ==================== 工具实现 ==================== + + async def _tool_search_location(self, args: Dict) -> Dict: + """地点搜索工具""" + address = args.get("address", "") + city = args.get("city") + + result = await self.amap.geocode(address, city) + + if not result["success"]: + return {"success": False, "message": result.get("error", "地点搜索失败")} + + return { + "success": True, + "message": f"已找到地点:{result['formatted_address']}", + "data": { + "name": address, + "formatted_address": result["formatted_address"], + "location": result["location"], + "province": result["province"], + "city": result["city"], + "district": result["district"], + "adcode": result["adcode"] + } + } + + async def _tool_query_weather(self, args: Dict) -> Dict: + """天气查询工具""" + city = args.get("city", "") + forecast = args.get("forecast", True) + + extensions = "all" if forecast else "base" + result = await self.amap.get_weather(city, extensions) + + if not result["success"]: + return {"success": False, "message": result.get("error", "天气查询失败")} + + if result["type"] == "live": + return { + "success": True, + "message": f"{result['city']}当前天气:{result['weather']},{result['temperature']}℃", + "data": { + "city": result["city"], + "weather": result["weather"], + "temperature": result["temperature"], + "humidity": result["humidity"], + "wind": f"{result['winddirection']}风 {result['windpower']}级", + "reporttime": result["reporttime"] + } + } + else: + forecasts = result["forecasts"] + weather_text = "\n".join([ + f"- {f['date']} 星期{self._weekday_cn(f['week'])}:白天{f['dayweather']} {f['daytemp']}℃,夜间{f['nightweather']} {f['nighttemp']}℃" + for f in forecasts + ]) + + return { + "success": True, + "message": f"{result['city']}未来天气预报:\n{weather_text}", + "data": { + "city": result["city"], + "province": result["province"], + "forecasts": forecasts, + "reporttime": result["reporttime"] + } + } + + async def _tool_search_poi(self, args: Dict) -> Dict: + """POI 搜索工具""" + city = args.get("city", "") + keyword = args.get("keyword") + category = args.get("category") + limit = min(args.get("limit", 10), 20) + + # 获取 POI 类型代码 + types = None + if category: + poi_types = self.config.get("poi_types", {}) + types = poi_types.get(category) + + result = await self.amap.search_poi( + keywords=keyword, + types=types, + city=city, + citylimit=True, + offset=limit + ) + + if not result["success"]: + return {"success": False, "message": result.get("error", "搜索失败")} + + pois = result["pois"] + if not pois: + return {"success": False, "message": f"在{city}未找到相关地点"} + + # 格式化输出 + poi_list = [] + for i, poi in enumerate(pois, 1): + info = f"{i}. {poi['name']}" + if poi.get("address"): + info += f" - {poi['address']}" + if poi.get("rating"): + info += f" ⭐{poi['rating']}" + if poi.get("cost"): + info += f" 人均¥{poi['cost']}" + poi_list.append(info) + + return { + "success": True, + "message": f"在{city}找到{len(pois)}个结果:\n" + "\n".join(poi_list), + "data": { + "city": city, + "category": category or "全部", + "count": len(pois), + "pois": pois + } + } + + async def _tool_search_nearby(self, args: Dict) -> Dict: + """周边搜索工具""" + location_name = args.get("location", "") + city = args.get("city", "") + keyword = args.get("keyword") + category = args.get("category") + radius = min(args.get("radius", 3000), 50000) + + # 先获取中心点坐标 + geo_result = await self.amap.geocode(location_name, city) + if not geo_result["success"]: + return {"success": False, "message": f"无法定位 {location_name}"} + + location = geo_result["location"] + + # 获取 POI 类型代码 + types = None + if category: + poi_types = self.config.get("poi_types", {}) + types = poi_types.get(category) + + result = await self.amap.search_around( + location=location, + keywords=keyword, + types=types, + radius=radius, + offset=10 + ) + + if not result["success"]: + return {"success": False, "message": result.get("error", "周边搜索失败")} + + pois = result["pois"] + if not pois: + return {"success": False, "message": f"在{location_name}周边未找到相关地点"} + + # 格式化输出 + poi_list = [] + for i, poi in enumerate(pois, 1): + info = f"{i}. {poi['name']}" + if poi.get("distance"): + info += f" ({poi['distance']}米)" + if poi.get("rating"): + info += f" ⭐{poi['rating']}" + poi_list.append(info) + + return { + "success": True, + "message": f"{location_name}周边{radius}米内找到{len(pois)}个结果:\n" + "\n".join(poi_list), + "data": { + "center": location_name, + "radius": radius, + "category": category or "全部", + "count": len(pois), + "pois": pois + } + } + + async def _tool_plan_route(self, args: Dict) -> Dict: + """路线规划工具""" + origin = args.get("origin", "") + destination = args.get("destination", "") + origin_city = args.get("origin_city", "") + destination_city = args.get("destination_city", origin_city) + mode = args.get("mode", "transit") + + # 获取起终点坐标 + origin_geo = await self.amap.geocode(origin, origin_city) + if not origin_geo["success"]: + return {"success": False, "message": f"无法定位起点:{origin}"} + + dest_geo = await self.amap.geocode(destination, destination_city) + if not dest_geo["success"]: + return {"success": False, "message": f"无法定位终点:{destination}"} + + origin_loc = origin_geo["location"] + dest_loc = dest_geo["location"] + + # 根据模式规划路线 + if mode == "driving": + result = await self.amap.route_driving(origin_loc, dest_loc) + if not result["success"]: + return {"success": False, "message": result.get("error", "驾车路线规划失败")} + + distance_km = result["distance"] / 1000 + duration_h = result["duration"] / 3600 + + msg = f"🚗 驾车路线:{origin} → {destination}\n" + msg += f"距离:{distance_km:.1f}公里,预计{self._format_duration(result['duration'])}\n" + if result["tolls"]: + msg += f"收费:约{result['tolls']}元\n" + if result["taxi_cost"]: + msg += f"打车费用:约{result['taxi_cost']}元" + + return { + "success": True, + "message": msg, + "data": result + } + + elif mode == "transit": + result = await self.amap.route_transit( + origin_loc, dest_loc, + city=origin_city, + cityd=destination_city if destination_city != origin_city else None + ) + if not result["success"]: + return {"success": False, "message": result.get("error", "公交路线规划失败")} + + msg = f"🚄 公交/高铁路线:{origin} → {destination}\n" + + for i, transit in enumerate(result["transits"][:2], 1): + msg += f"\n方案{i}:{self._format_duration(transit['duration'])}" + if transit.get("cost"): + msg += f",约{transit['cost']}元" + msg += "\n" + + for seg in transit["segments"]: + if seg["type"] == "walking" and seg["distance"] > 100: + msg += f" 🚶 步行{seg['distance']}米\n" + elif seg["type"] == "bus": + msg += f" 🚌 {seg['name']}:{seg['departure_stop']} → {seg['arrival_stop']}({seg['via_num']}站)\n" + elif seg["type"] == "railway": + msg += f" 🚄 {seg['trip']} {seg['name']}:{seg['departure_stop']} {seg.get('departure_time', '')} → {seg['arrival_stop']} {seg.get('arrival_time', '')}\n" + + return { + "success": True, + "message": msg.strip(), + "data": result + } + + elif mode == "walking": + result = await self.amap.route_walking(origin_loc, dest_loc) + if not result["success"]: + return {"success": False, "message": result.get("error", "步行路线规划失败")} + + return { + "success": True, + "message": f"🚶 步行路线:{origin} → {destination}\n距离:{result['distance']}米,预计{self._format_duration(result['duration'])}", + "data": result + } + + return {"success": False, "message": f"不支持的出行方式:{mode}"} + + async def _tool_get_travel_info(self, args: Dict) -> Dict: + """一键获取旅行信息""" + destination = args.get("destination", "") + origin = args.get("origin") + + info = {"destination": destination} + msg_parts = [f"📍 {destination} 旅行信息\n"] + + # 1. 查询天气 + weather_result = await self.amap.get_weather(destination, "all") + if weather_result["success"]: + info["weather"] = weather_result + msg_parts.append("🌤️ 天气预报:") + for f in weather_result["forecasts"][:3]: + msg_parts.append(f" {f['date']} {f['dayweather']} {f['nighttemp']}~{f['daytemp']}℃") + + # 2. 搜索热门景点 + poi_result = await self.amap.search_poi( + types="110000", # 景点 + city=destination, + citylimit=True, + offset=5 + ) + if poi_result["success"] and poi_result["pois"]: + info["attractions"] = poi_result["pois"] + msg_parts.append("\n🏞️ 热门景点:") + for poi in poi_result["pois"][:5]: + rating = f" ⭐{poi['rating']}" if poi.get("rating") else "" + msg_parts.append(f" • {poi['name']}{rating}") + + # 3. 规划交通路线(如果提供了出发地) + if origin: + origin_geo = await self.amap.geocode(origin) + dest_geo = await self.amap.geocode(destination) + + if origin_geo["success"] and dest_geo["success"]: + route_result = await self.amap.route_transit( + origin_geo["location"], + dest_geo["location"], + city=origin_geo.get("city", origin), + cityd=dest_geo.get("city", destination) + ) + + if route_result["success"] and route_result["transits"]: + info["route"] = route_result + transit = route_result["transits"][0] + msg_parts.append(f"\n🚄 从{origin}出发:") + msg_parts.append(f" 预计{self._format_duration(transit['duration'])}") + + # 显示主要交通工具 + for seg in transit["segments"]: + if seg["type"] == "railway": + msg_parts.append(f" {seg['trip']}:{seg['departure_stop']} → {seg['arrival_stop']}") + break + + return { + "success": True, + "message": "\n".join(msg_parts), + "data": info, + "need_ai_reply": True # 让 AI 根据这些信息生成详细的行程规划 + } + + async def _tool_plan_detailed_trip(self, args: Dict) -> Dict: + """ + 详细行程规划工具(优化版:并行 API 调用) + """ + origin_city = args.get("origin_city", "") + origin_address = args.get("origin_address", "") or origin_city + destination = args.get("destination", "") + days = args.get("days", 2) + departure_time = args.get("departure_time", "") + preferences = args.get("preferences", "") + + info = { + "origin_city": origin_city, + "origin_address": origin_address, + "destination": destination, + "days": days + } + + # ========== 第1步:并行获取基础信息 ========== + user_geo_task = self.amap.geocode(origin_address, origin_city) + dest_geo_task = self.amap.geocode(destination) + weather_task = self.amap.get_weather(destination, "all") + + user_geo, dest_geo, weather_result = await asyncio.gather( + user_geo_task, dest_geo_task, weather_task, + return_exceptions=True + ) + + # 处理地理编码结果 + if isinstance(user_geo, Exception) or not user_geo.get("success"): + user_geo = await self.amap.geocode(origin_city) + if isinstance(dest_geo, Exception) or not dest_geo.get("success"): + return {"success": False, "message": f"无法定位目的地:{destination}"} + if not user_geo.get("success"): + return {"success": False, "message": f"无法定位出发地:{origin_address}"} + + user_loc = user_geo["location"] + dest_loc = dest_geo["location"] + origin_city_name = user_geo.get("city") or origin_city + dest_city_name = dest_geo.get("city") or destination + + # ========== 第2步:并行搜索火车站和目的地信息 ========== + origin_stations_task = self.amap.search_poi(types="150200", city=origin_city, citylimit=True, offset=3) + dest_stations_task = self.amap.search_poi(types="150200", city=destination, citylimit=True, offset=2) + hotels_task = self.amap.search_poi(types="100100|100101", city=destination, citylimit=True, offset=5) + attractions_task = self.amap.search_poi(types="110000", city=destination, citylimit=True, offset=6) + food_task = self.amap.search_poi(types="050000", city=destination, citylimit=True, offset=5) + + results = await asyncio.gather( + origin_stations_task, dest_stations_task, hotels_task, attractions_task, food_task, + return_exceptions=True + ) + origin_stations, dest_stations, hotels, attractions, food = results + + # ========== 第3步:规划到火车站的路线(只规划1个最近的) ========== + best_station = None + best_route = None + + if not isinstance(origin_stations, Exception) and origin_stations.get("success") and origin_stations.get("pois"): + station = origin_stations["pois"][0] # 只取第一个火车站 + try: + route = await self.amap.route_transit(user_loc, station["location"], city=origin_city_name) + if route.get("success") and route.get("transits"): + best_station = station + best_route = route["transits"][0] + except Exception as e: + logger.warning(f"规划到火车站路线失败: {e}") + + # ========== 第4步:规划城际交通 ========== + transit_info = None + if best_station and not isinstance(dest_stations, Exception) and dest_stations.get("success") and dest_stations.get("pois"): + try: + dest_station = dest_stations["pois"][0] + transit = await self.amap.route_transit( + best_station["location"], dest_station["location"], + city=origin_city_name, cityd=dest_city_name + ) + if transit.get("success") and transit.get("transits"): + transit_info = transit["transits"][0] + except Exception as e: + logger.warning(f"城际交通规划失败: {e}") + + # ========== 组装输出 ========== + sections = [] + sections.append(f"📋 {origin_address} → {destination} {days}天行程\n") + sections.append(f"📍 出发地:{user_geo.get('formatted_address', origin_address)}\n") + + # 天气 + if not isinstance(weather_result, Exception) and weather_result.get("success"): + sections.append("【天气预报】") + for f in weather_result.get("forecasts", [])[:3]: + sections.append(f" {f['date']}:{f['dayweather']} {f['nighttemp']}~{f['daytemp']}℃") + sections.append("") + + # 到火车站 + if best_station and best_route: + sections.append("【从您家到火车站】") + sections.append(f" 🚉 {best_station['name']}") + sections.append(f" ⏱️ 预计:{self._format_duration(best_route['duration'])}") + for seg in best_route.get("segments", []): + if seg["type"] == "bus": + line = seg["name"] + icon = "🚇" if "地铁" in line or "号线" in line else "🚌" + sections.append(f" {icon} {line}:{seg['departure_stop']} → {seg['arrival_stop']}({seg['via_num']}站)") + sections.append("") + + # 城际交通 + if transit_info: + sections.append("【城际高铁/火车】") + sections.append(f" ⏱️ 全程约{self._format_duration(transit_info['duration'])},费用约{transit_info.get('cost', '未知')}元") + for seg in transit_info.get("segments", []): + if seg["type"] == "railway": + sections.append(f" 🚄 {seg['trip']} {seg['name']}") + sections.append(f" {seg['departure_stop']} → {seg['arrival_stop']}") + sections.append("") + + # 酒店 + if not isinstance(hotels, Exception) and hotels.get("success") and hotels.get("pois"): + sections.append("【酒店推荐】") + for i, h in enumerate(hotels["pois"][:4], 1): + rating = f"⭐{h['rating']}" if h.get("rating") else "" + cost = f"¥{h['cost']}/晚" if h.get("cost") else "" + sections.append(f" {i}. {h['name']} {rating} {cost}") + sections.append("") + + # 景点 + if not isinstance(attractions, Exception) and attractions.get("success") and attractions.get("pois"): + sections.append("【热门景点】") + for i, p in enumerate(attractions["pois"][:5], 1): + rating = f"⭐{p['rating']}" if p.get("rating") else "" + sections.append(f" {i}. {p['name']} {rating}") + sections.append("") + + # 美食 + if not isinstance(food, Exception) and food.get("success") and food.get("pois"): + sections.append("【美食推荐】") + for i, p in enumerate(food["pois"][:4], 1): + cost = f"人均¥{p['cost']}" if p.get("cost") else "" + sections.append(f" {i}. {p['name']} {cost}") + sections.append("") + + # 提示 + sections.append(f"📌 请根据以上信息为用户安排{days}天行程") + if departure_time: + sections.append(f" 出发时间偏好:{departure_time}") + if preferences: + sections.append(f" 用户偏好:{preferences}") + + return { + "success": True, + "message": "\n".join(sections), + "data": info, + "need_ai_reply": True + } + + # ==================== 辅助方法 ==================== + + def _weekday_cn(self, week: str) -> str: + """星期数字转中文""" + mapping = {"1": "一", "2": "二", "3": "三", "4": "四", "5": "五", "6": "六", "7": "日"} + return mapping.get(str(week), week) + + def _format_duration(self, seconds: int) -> str: + """格式化时长""" + if seconds < 60: + return f"{seconds}秒" + elif seconds < 3600: + return f"{seconds // 60}分钟" + else: + hours = seconds // 3600 + minutes = (seconds % 3600) // 60 + if minutes: + return f"{hours}小时{minutes}分钟" + return f"{hours}小时" diff --git a/utils/context_store.py b/utils/context_store.py index f00e484..857b94c 100644 --- a/utils/context_store.py +++ b/utils/context_store.py @@ -180,6 +180,33 @@ class ContextStore: pass self.memory_fallback.pop(chat_id, None) + async def clear_group_history(self, chat_id: str) -> None: + """ + 清空群聊历史记录 + + 同时清除 Redis 和本地文件中的群聊历史 + """ + # 清除 Redis 中的群聊历史 + if self._use_redis_for_group_history(): + redis_cache = get_cache() + try: + key = f"group_history:{_safe_chat_id(chat_id)}" + redis_cache.delete(key) + logger.debug(f"[ContextStore] 已清除 Redis 群聊历史: {chat_id}") + except Exception as e: + logger.debug(f"[ContextStore] 清除 Redis 群聊历史失败: {e}") + + # 清除本地文件中的群聊历史 + history_file = self._get_history_file(chat_id) + if history_file and history_file.exists(): + lock = self._get_history_lock(chat_id) + async with lock: + try: + history_file.unlink() + logger.debug(f"[ContextStore] 已清除本地群聊历史文件: {history_file}") + except Exception as e: + logger.debug(f"[ContextStore] 清除本地群聊历史文件失败: {e}") + # ------------------ 群聊 history ------------------ def _use_redis_for_group_history(self) -> bool: