feat:修复一些BUG

This commit is contained in:
2026-01-08 09:49:01 +08:00
parent 820861752b
commit 472b1a0d5e
33 changed files with 1643 additions and 742 deletions

View File

@@ -2831,10 +2831,64 @@ class AIChat(PluginBase):
refer_type = int(refer_type_elem.text) if refer_type_elem is not None and refer_type_elem.text else 0
logger.debug(f"被引用消息类型: {refer_type}")
# 纯文本消息不需要处理type=1
# 纯文本消息type=1:如果@了机器人,转发给 AI 处理
if refer_type == 1:
logger.debug("引用的是纯文本消息,跳过")
return True
if self._should_reply_quote(message, title_text):
# 获取被引用的文本内容
refer_content_elem = refermsg.find("content")
refer_text = refer_content_elem.text.strip() if refer_content_elem is not None and refer_content_elem.text else ""
# 获取被引用者昵称
refer_displayname = refermsg.find("displayname")
refer_nickname = refer_displayname.text if refer_displayname is not None and refer_displayname.text else "某人"
# 组合消息:引用内容 + 用户评论
# title_text 格式如 "@瑞依 评价下",需要去掉 @昵称 部分
import tomllib
with open("main_config.toml", "rb") as f:
main_config = tomllib.load(f)
bot_nickname = main_config.get("Bot", {}).get("nickname", "")
user_comment = title_text
if bot_nickname:
# 移除 @机器人昵称(可能有空格分隔)
user_comment = user_comment.replace(f"@{bot_nickname}", "").strip()
# 构造给 AI 的消息
combined_message = f"[引用 {refer_nickname} 的消息:{refer_text}]\n{user_comment}"
logger.info(f"引用纯文本消息,转发给 AI: {combined_message[:80]}...")
# 调用 AI 处理
nickname = await self._get_user_display_label(bot, from_wxid, user_wxid, is_group)
chat_id = from_wxid if is_group else user_wxid
# 保存用户消息到群组历史记录
history_enabled = bool(self.store) and self.config.get("history", {}).get("enabled", True)
sync_bot_messages = self.config.get("history", {}).get("sync_bot_messages", True)
if is_group and history_enabled:
history_chat_id = self._get_group_history_chat_id(from_wxid, user_wxid)
await self._add_to_history(history_chat_id, nickname, combined_message, sender_wxid=user_wxid)
ai_response = await self._call_ai_api(
combined_message,
bot=bot,
from_wxid=from_wxid,
chat_id=chat_id,
nickname=nickname
)
if ai_response:
final_response = self._sanitize_llm_output(ai_response)
await bot.send_text(from_wxid, final_response)
# 保存 AI 回复到群组历史记录
if is_group and history_enabled and sync_bot_messages:
bot_nickname_display = main_config.get("Bot", {}).get("nickname", "AI")
await self._add_to_history(history_chat_id, bot_nickname_display, final_response, role="assistant")
return False
else:
logger.debug("引用的是纯文本消息且未@机器人,跳过")
return True
# 只处理图片(3)、视频(43)、应用消息(49含聊天记录)
if refer_type not in [3, 43, 49]:
@@ -3553,40 +3607,50 @@ class AIChat(PluginBase):
def _should_reply_quote(self, message: dict, title_text: str) -> bool:
"""判断是否应该回复引用消息"""
is_group = message.get("IsGroup", False)
# 检查群聊/私聊开关
if is_group and not self.config["behavior"]["reply_group"]:
return False
if not is_group and not self.config["behavior"]["reply_private"]:
return False
trigger_mode = self.config["behavior"]["trigger_mode"]
# all模式回复所有消息
if trigger_mode == "all":
return True
# mention模式检查是否@了机器人
if trigger_mode == "mention":
if is_group:
# 方式1检查 Ats 字段(普通消息格式)
ats = message.get("Ats", [])
if not ats:
return False
import tomllib
with open("main_config.toml", "rb") as f:
main_config = tomllib.load(f)
bot_wxid = main_config.get("Bot", {}).get("wxid", "")
return bot_wxid and bot_wxid in ats
bot_nickname = main_config.get("Bot", {}).get("nickname", "")
# 检查 Ats 列表
if bot_wxid and bot_wxid in ats:
return True
# 方式2检查标题中是否包含 @机器人昵称(引用消息格式)
# 引用消息的 @ 信息在 title 中,如 "@瑞依 评价下"
if bot_nickname and f"@{bot_nickname}" in title_text:
logger.debug(f"引用消息标题中检测到 @{bot_nickname}")
return True
return False
else:
return True
# keyword模式检查关键词
if trigger_mode == "keyword":
keywords = self.config["behavior"]["keywords"]
return any(kw in title_text for kw in keywords)
return False
async def _call_ai_api_with_image(

View File

@@ -70,7 +70,7 @@ class GrokVideo(PluginBase):
# 初始化MinIO客户端
self.minio_client = Minio(
"101.201.65.129:19000",
"115.190.113.141:19000",
access_key="admin",
secret_key="80012029Lz",
secure=False
@@ -173,7 +173,7 @@ class GrokVideo(PluginBase):
)
# 返回访问URL
url = f"http://101.201.65.129:19000/{self.minio_bucket}/{object_name}"
url = f"http://115.190.113.141:19000/{self.minio_bucket}/{object_name}"
logger.info(f"视频上传成功: {url}")
return url
@@ -293,7 +293,15 @@ class GrokVideo(PluginBase):
# 解析 XML 获取标题和引用消息
try:
root = ET.fromstring(content)
xml_content = content.lstrip("\ufeff")
if ":\n" in xml_content:
xml_start = xml_content.find("<?xml")
if xml_start == -1:
xml_start = xml_content.find("<msg")
if xml_start > 0:
xml_content = xml_content[xml_start:]
root = ET.fromstring(xml_content)
title = root.find(".//title")
if title is None or not title.text:
return
@@ -338,7 +346,13 @@ class GrokVideo(PluginBase):
# 解码 HTML 实体
import html
refer_xml = html.unescape(refer_content.text)
refer_xml = html.unescape(refer_content.text).lstrip("\ufeff")
if ":\n" in refer_xml:
xml_start = refer_xml.find("<?xml")
if xml_start == -1:
xml_start = refer_xml.find("<msg")
if xml_start > 0:
refer_xml = refer_xml[xml_start:]
refer_root = ET.fromstring(refer_xml)
# 提取图片信息

View File

@@ -1 +0,0 @@
# 即梦AI绘图插件

Binary file not shown.

Before

Width:  |  Height:  |  Size: 943 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 816 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 893 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 758 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 753 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 860 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 896 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.0 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.1 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 960 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.1 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 903 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 920 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 962 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 882 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 967 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 965 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 993 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 905 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 933 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 916 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 877 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 874 KiB

View File

@@ -1,372 +0,0 @@
"""
即梦AI绘图插件
支持命令触发和LLM工具调用
"""
import asyncio
import tomllib
import aiohttp
import uuid
from pathlib import Path
from datetime import datetime
from typing import List, Optional
from loguru import logger
from utils.plugin_base import PluginBase
from utils.decorators import on_text_message
from WechatHook import WechatHookClient
class TokenState:
"""Token轮询状态管理"""
def __init__(self):
self.token_index = 0
self._lock = asyncio.Lock()
async def get_next_token(self, tokens: List[str]) -> str:
"""获取下一个可用的token"""
async with self._lock:
if not tokens:
raise ValueError("Token列表为空")
return tokens[self.token_index % len(tokens)]
async def rotate(self, tokens: List[str]):
"""轮换到下一个token"""
async with self._lock:
if tokens:
self.token_index = (self.token_index + 1) % len(tokens)
class JimengAI(PluginBase):
"""即梦AI绘图插件"""
description = "即梦AI绘图插件 - 支持AI绘图和LLM工具调用"
author = "ShiHao"
version = "1.0.0"
def __init__(self):
super().__init__()
self.config = None
self.token_state = TokenState()
self.images_dir = None
async def async_init(self):
"""异步初始化"""
config_path = Path(__file__).parent / "config.toml"
with open(config_path, "rb") as f:
self.config = tomllib.load(f)
# 创建图片目录
self.images_dir = Path(__file__).parent / "images"
self.images_dir.mkdir(exist_ok=True)
logger.success(f"即梦AI插件初始化完成配置了 {len(self.config['api']['tokens'])} 个token")
async def generate_image(self, prompt: str, **kwargs) -> List[str]:
"""
生成图像
Args:
prompt: 提示词
**kwargs: 其他参数model, width, height, sample_strength, negative_prompt
Returns:
图片本地路径列表
"""
api_config = self.config["api"]
gen_config = self.config["generation"]
model = kwargs.get("model", gen_config["default_model"])
width = kwargs.get("width", gen_config["default_width"])
height = kwargs.get("height", gen_config["default_height"])
sample_strength = kwargs.get("sample_strength", gen_config["default_sample_strength"])
negative_prompt = kwargs.get("negative_prompt", gen_config["default_negative_prompt"])
# 参数验证
sample_strength = max(0.0, min(1.0, sample_strength))
width = max(64, min(2048, width))
height = max(64, min(2048, height))
tokens = api_config["tokens"]
max_retry = gen_config["max_retry_attempts"]
# 尝试每个token
for token_attempt in range(len(tokens)):
current_token = await self.token_state.get_next_token(tokens)
for attempt in range(max_retry):
if attempt > 0:
await asyncio.sleep(min(2 ** attempt, 10))
try:
url = f"{api_config['base_url'].rstrip('/')}/v1/chat/completions"
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {current_token}"
}
payload = {
"model": model,
"messages": [{"role": "user", "content": prompt}],
"prompt": prompt,
"negativePrompt": negative_prompt,
"width": width,
"height": height,
"sample_strength": sample_strength
}
logger.info(f"即梦AI请求: {model}, 尺寸: {width}x{height}, 提示词: {prompt[:50]}...")
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=api_config["timeout"])) as session:
async with session.post(url, headers=headers, json=payload) as response:
if response.status == 200:
data = await response.json()
logger.debug(f"API返回数据: {data}")
if "error" in data:
logger.error(f"API错误: {data['error']}")
continue
# 提取图片URL
image_paths = await self._extract_images(data)
if image_paths:
logger.success(f"成功生成 {len(image_paths)} 张图像")
return image_paths
else:
logger.warning(f"未找到图像数据API返回: {str(data)[:200]}")
continue
elif response.status == 401:
logger.warning("Token认证失败尝试下一个token")
break
elif response.status == 429:
logger.warning("请求频率限制,等待后重试")
await asyncio.sleep(5)
continue
else:
error_text = await response.text()
logger.error(f"API请求失败: {response.status}, {error_text[:200]}")
continue
except asyncio.TimeoutError:
logger.warning(f"请求超时,重试中... ({attempt + 1}/{max_retry})")
continue
except Exception as e:
logger.error(f"请求异常: {e}")
continue
# 当前token失败轮换
await self.token_state.rotate(tokens)
logger.error("所有token都失败了")
return []
async def _extract_images(self, data: dict) -> List[str]:
"""从API响应中提取图片"""
import re
image_paths = []
# 格式1: OpenAI格式的choices
if "choices" in data and data["choices"]:
for choice in data["choices"]:
if "message" in choice and "content" in choice["message"]:
content = choice["message"]["content"]
if "https://" in content:
urls = re.findall(r'https://[^\s\)]+', content)
for url in urls:
path = await self._download_image(url)
if path:
image_paths.append(path)
# 格式2: data数组
elif "data" in data:
data_list = data["data"] if isinstance(data["data"], list) else [data["data"]]
for item in data_list:
if isinstance(item, str) and item.startswith("http"):
path = await self._download_image(item)
if path:
image_paths.append(path)
elif isinstance(item, dict) and "url" in item:
path = await self._download_image(item["url"])
if path:
image_paths.append(path)
# 格式3: images数组
elif "images" in data:
images_list = data["images"] if isinstance(data["images"], list) else [data["images"]]
for item in images_list:
if isinstance(item, str) and item.startswith("http"):
path = await self._download_image(item)
if path:
image_paths.append(path)
elif isinstance(item, dict) and "url" in item:
path = await self._download_image(item["url"])
if path:
image_paths.append(path)
# 格式4: 单个URL
elif "url" in data:
path = await self._download_image(data["url"])
if path:
image_paths.append(path)
return image_paths
async def _download_image(self, url: str) -> Optional[str]:
"""下载图片到本地"""
try:
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=30)) as session:
async with session.get(url) as response:
if response.status == 200:
content = await response.read()
# 生成文件名
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
uid = uuid.uuid4().hex[:8]
file_path = self.images_dir / f"jimeng_{ts}_{uid}.jpg"
# 保存文件
with open(file_path, "wb") as f:
f.write(content)
logger.info(f"图片下载成功: {file_path}")
return str(file_path)
except Exception as e:
logger.error(f"下载图片失败: {e}")
return None
@on_text_message(priority=70)
async def handle_message(self, bot: WechatHookClient, message: dict):
"""处理文本消息"""
if not self.config["behavior"]["enable_command"]:
return True
content = message.get("Content", "").strip()
from_wxid = message.get("FromWxid", "")
is_group = message.get("IsGroup", False)
# 检查群聊/私聊开关
if is_group and not self.config["behavior"]["enable_group"]:
return True
if not is_group and not self.config["behavior"]["enable_private"]:
return True
# 检查是否是绘图命令(精确匹配命令+空格+提示词)
keywords = self.config["behavior"]["command_keywords"]
matched_keyword = None
for keyword in keywords:
if content.startswith(keyword + " "):
matched_keyword = keyword
break
if not matched_keyword:
return True
# 提取提示词
prompt = content[len(matched_keyword):].strip()
if not prompt:
await bot.send_text(from_wxid, "❌ 请提供绘图提示词\n用法: /绘图 <提示词>")
return False
logger.info(f"收到绘图请求: {prompt[:50]}...")
# 发送处理中提示
await bot.send_text(from_wxid, "🎨 正在为您生成图像,请稍候...")
try:
# 生成图像
image_paths = await self.generate_image(prompt)
if image_paths:
# 直接发送图片
await bot.send_image(from_wxid, image_paths[0])
logger.success(f"绘图成功,已发送图片")
else:
await bot.send_text(from_wxid, "❌ 图像生成失败,请稍后重试")
except Exception as e:
logger.error(f"绘图处理失败: {e}")
await bot.send_text(from_wxid, f"❌ 处理失败: {str(e)}")
return False
def get_llm_tools(self) -> List[dict]:
"""
返回LLM工具定义
供AIChat插件调用
"""
if not self.config["llm_tool"]["enabled"]:
return []
return [{
"type": "function",
"function": {
"name": self.config["llm_tool"]["tool_name"],
"description": self.config["llm_tool"]["tool_description"],
"parameters": {
"type": "object",
"properties": {
"prompt": {
"type": "string",
"description": "图像生成提示词,描述想要生成的图像内容"
},
"width": {
"type": "integer",
"description": "图像宽度64-2048默认1024",
"default": 1024
},
"height": {
"type": "integer",
"description": "图像高度64-2048默认1024",
"default": 1024
}
},
"required": ["prompt"]
}
}
}]
async def execute_llm_tool(self, tool_name: str, arguments: dict, bot: WechatHookClient, from_wxid: str) -> dict:
"""
执行LLM工具调用
供AIChat插件调用
Returns:
{"success": bool, "message": str, "images": List[str]}
"""
expected_tool_name = self.config["llm_tool"]["tool_name"]
logger.info(f"JimengAI工具检查: 收到={tool_name}, 期望={expected_tool_name}")
if tool_name != expected_tool_name:
return None # 不是本插件的工具返回None让其他插件处理
try:
prompt = arguments.get("prompt")
if not prompt:
return {"success": False, "message": "缺少提示词参数"}
logger.info(f"LLM工具调用绘图: {prompt[:50]}...")
# 生成图像(使用配置的默认尺寸)
gen_config = self.config["generation"]
image_paths = await self.generate_image(
prompt=prompt,
width=arguments.get("width", gen_config["default_width"]),
height=arguments.get("height", gen_config["default_height"])
)
if image_paths:
# 直接发送图片
await bot.send_image(from_wxid, image_paths[0])
return {
"success": True,
"message": "已生成并发送图像",
"images": [image_paths[0]]
}
else:
return {"success": False, "message": "图像生成失败"}
except Exception as e:
logger.error(f"LLM工具执行失败: {e}")
return {"success": False, "message": f"执行失败: {str(e)}"}

View File

@@ -1 +0,0 @@
"""Kiira2 AI绘图插件"""

View File

@@ -1,350 +0,0 @@
"""
Kiira2 AI绘图插件
支持命令触发和LLM工具调用
"""
import asyncio
import tomllib
import httpx
import uuid
from pathlib import Path
from datetime import datetime
from typing import List, Optional
from loguru import logger
from utils.plugin_base import PluginBase
from utils.decorators import on_text_message
from WechatHook import WechatHookClient
class TokenState:
"""Token轮询状态管理"""
def __init__(self):
self.token_index = 0
self._lock = asyncio.Lock()
async def get_next_token(self, tokens: List[str]) -> str:
"""获取下一个可用的token"""
async with self._lock:
if not tokens:
raise ValueError("Token列表为空")
return tokens[self.token_index % len(tokens)]
async def rotate(self, tokens: List[str]):
"""轮换到下一个token"""
async with self._lock:
if tokens:
self.token_index = (self.token_index + 1) % len(tokens)
class Kiira2AI(PluginBase):
"""Kiira2 AI绘图插件"""
description = "Kiira2 AI绘图插件 - 支持AI绘图和LLM工具调用"
author = "ShiHao"
version = "1.0.0"
def __init__(self):
super().__init__()
self.config = None
self.token_state = TokenState()
self.images_dir = None
async def async_init(self):
"""异步初始化"""
config_path = Path(__file__).parent / "config.toml"
with open(config_path, "rb") as f:
self.config = tomllib.load(f)
# 创建图片目录
self.images_dir = Path(__file__).parent / "images"
self.images_dir.mkdir(exist_ok=True)
logger.success(f"Kiira2 AI插件初始化完成配置了 {len(self.config['api']['tokens'])} 个token")
async def generate_image(self, prompt: str, **kwargs) -> List[str]:
"""
生成图像
Args:
prompt: 提示词
**kwargs: 其他参数model
Returns:
图片本地路径列表
"""
api_config = self.config["api"]
gen_config = self.config["generation"]
model = kwargs.get("model", gen_config["default_model"])
tokens = api_config["tokens"]
max_retry = gen_config["max_retry_attempts"]
# 尝试每个token
for token_attempt in range(len(tokens)):
current_token = await self.token_state.get_next_token(tokens)
for attempt in range(max_retry):
if attempt > 0:
await asyncio.sleep(min(2 ** attempt, 10))
try:
url = f"{api_config['base_url'].rstrip('/')}/v1/chat/completions"
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {current_token}"
}
payload = {
"model": model,
"messages": [{"role": "user", "content": prompt}],
"stream": False
}
logger.info(f"Kiira2 AI请求: {model}, 提示词: {prompt[:50]}...")
timeout = httpx.Timeout(connect=10.0, read=api_config["timeout"], write=10.0, pool=10.0)
# 配置代理
proxy = None
proxy_config = self.config.get("proxy", {})
if proxy_config.get("enabled", False):
proxy_type = proxy_config.get("type", "socks5")
proxy_host = proxy_config.get("host", "127.0.0.1")
proxy_port = proxy_config.get("port", 7890)
proxy = f"{proxy_type}://{proxy_host}:{proxy_port}"
logger.info(f"使用代理: {proxy}")
async with httpx.AsyncClient(timeout=timeout, proxy=proxy) as client:
response = await client.post(url, json=payload, headers=headers)
if response.status_code == 200:
data = response.json()
logger.debug(f"API返回数据: {data}")
if "error" in data:
logger.error(f"API错误: {data['error']}")
continue
# 检查是否返回空content图片还在生成中
if "choices" in data and data["choices"]:
message = data["choices"][0].get("message", {})
content = message.get("content", "")
video_url = message.get("video_url")
# 如果content为空且没有video_url说明还在生成等待后重试
if not content and not video_url:
wait_time = min(10 + attempt * 5, 30)
logger.info(f"图片生成中,等待 {wait_time} 秒后重试...")
await asyncio.sleep(wait_time)
continue
# 提取图片URL
image_paths = await self._extract_images(data)
if image_paths:
logger.success(f"成功生成 {len(image_paths)} 张图像")
return image_paths
else:
logger.warning(f"未找到图像数据API返回: {str(data)[:500]}")
continue
elif response.status_code == 401:
logger.warning("Token认证失败尝试下一个token")
break
elif response.status_code == 429:
logger.warning("请求频率限制,等待后重试")
await asyncio.sleep(5)
continue
else:
error_text = response.text
logger.error(f"API请求失败: {response.status_code}, {error_text[:200]}")
continue
except asyncio.TimeoutError:
logger.warning(f"请求超时,重试中... ({attempt + 1}/{max_retry})")
continue
except Exception as e:
logger.error(f"请求异常: {e}")
continue
# 当前token失败轮换
await self.token_state.rotate(tokens)
logger.error("所有token都失败了")
return []
async def _extract_images(self, data: dict) -> List[str]:
"""从API响应中提取图片只提取图片忽略文字"""
import re
image_paths = []
# OpenAI格式的choices
if "choices" in data and data["choices"]:
for choice in data["choices"]:
message = choice.get("message", {})
# 检查video_url字段实际包含图片URL
if "video_url" in message:
video_url = message["video_url"]
if isinstance(video_url, list) and video_url:
url = video_url[0]
if isinstance(url, str) and url.startswith("http"):
path = await self._download_image(url)
if path:
image_paths.append(path)
# 检查content字段
if "content" in message and not image_paths:
content = message["content"]
if content and "http" in content:
urls = re.findall(r'https?://[^\s\)\]"]+', content)
for url in urls:
path = await self._download_image(url)
if path:
image_paths.append(path)
return image_paths
async def _download_image(self, url: str) -> Optional[str]:
"""下载图片到本地"""
try:
timeout = httpx.Timeout(connect=10.0, read=30.0, write=10.0, pool=10.0)
# 配置代理
proxy = None
proxy_config = self.config.get("proxy", {})
if proxy_config.get("enabled", False):
proxy_type = proxy_config.get("type", "socks5")
proxy_host = proxy_config.get("host", "127.0.0.1")
proxy_port = proxy_config.get("port", 7890)
proxy = f"{proxy_type}://{proxy_host}:{proxy_port}"
async with httpx.AsyncClient(timeout=timeout, proxy=proxy) as client:
response = await client.get(url)
response.raise_for_status()
# 生成文件名
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
uid = uuid.uuid4().hex[:8]
file_path = self.images_dir / f"kiira2_{ts}_{uid}.jpg"
# 保存文件
with open(file_path, "wb") as f:
f.write(response.content)
logger.info(f"图片下载成功: {file_path}")
return str(file_path)
except Exception as e:
logger.error(f"下载图片失败: {e}")
return None
@on_text_message(priority=70)
async def handle_message(self, bot: WechatHookClient, message: dict):
"""处理文本消息"""
if not self.config["behavior"]["enable_command"]:
return True
content = message.get("Content", "").strip()
from_wxid = message.get("FromWxid", "")
is_group = message.get("IsGroup", False)
# 检查群聊/私聊开关
if is_group and not self.config["behavior"]["enable_group"]:
return True
if not is_group and not self.config["behavior"]["enable_private"]:
return True
# 检查是否是绘图命令
keywords = self.config["behavior"]["command_keywords"]
matched_keyword = None
for keyword in keywords:
if content.startswith(keyword + " "):
matched_keyword = keyword
break
if not matched_keyword:
return True
# 提取提示词
prompt = content[len(matched_keyword):].strip()
if not prompt:
await bot.send_text(from_wxid, "❌ 请提供绘图提示词\n用法: /画画 <提示词>")
return False
logger.info(f"收到绘图请求: {prompt[:50]}...")
# 发送处理中提示
await bot.send_text(from_wxid, "🎨 正在为您生成图像,请稍候...")
try:
# 生成图像
image_paths = await self.generate_image(prompt)
if image_paths:
# 直接发送图片
await bot.send_image(from_wxid, image_paths[0])
logger.success(f"绘图成功,已发送图片")
else:
await bot.send_text(from_wxid, "❌ 图像生成失败,请稍后重试")
except Exception as e:
logger.error(f"绘图处理失败: {e}")
await bot.send_text(from_wxid, f"❌ 处理失败: {str(e)}")
return False
def get_llm_tools(self) -> List[dict]:
"""返回LLM工具定义"""
if not self.config["llm_tool"]["enabled"]:
return []
return [{
"type": "function",
"function": {
"name": self.config["llm_tool"]["tool_name"],
"description": self.config["llm_tool"]["tool_description"],
"parameters": {
"type": "object",
"properties": {
"prompt": {
"type": "string",
"description": "图像生成提示词,描述想要生成的图像内容"
}
},
"required": ["prompt"]
}
}
}]
async def execute_llm_tool(self, tool_name: str, arguments: dict, bot: WechatHookClient, from_wxid: str) -> dict:
"""执行LLM工具调用"""
expected_tool_name = self.config["llm_tool"]["tool_name"]
if tool_name != expected_tool_name:
return None
try:
prompt = arguments.get("prompt")
if not prompt:
return {"success": False, "message": "缺少提示词参数"}
logger.info(f"LLM工具调用绘图: {prompt[:50]}...")
# 生成图像
image_paths = await self.generate_image(prompt=prompt)
if image_paths:
# 直接发送图片
await bot.send_image(from_wxid, image_paths[0])
return {
"success": True,
"message": "已生成并发送图像",
"images": [image_paths[0]]
}
else:
return {"success": False, "message": "图像生成失败"}
except Exception as e:
logger.error(f"LLM工具执行失败: {e}")
return {"success": False, "message": f"执行失败: {str(e)}"}

View File

@@ -0,0 +1,3 @@
from .main import TravelPlanner
__all__ = ["TravelPlanner"]

View File

@@ -0,0 +1,860 @@
"""
高德地图 API 客户端封装
提供以下功能:
- 地理编码:地址 → 坐标
- 逆地理编码:坐标 → 地址
- 行政区域查询:获取城市 adcode
- 天气查询:实况/预报天气
- POI 搜索:关键字搜索、周边搜索
- 路径规划:驾车、公交、步行、骑行
"""
from __future__ import annotations
import hashlib
import aiohttp
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Literal
from loguru import logger
@dataclass
class AmapConfig:
"""高德 API 配置"""
api_key: str
secret: str = "" # 安全密钥,用于数字签名
timeout: int = 30
class AmapClient:
"""高德地图 API 客户端"""
BASE_URL = "https://restapi.amap.com"
def __init__(self, config: AmapConfig):
self.config = config
self._session: Optional[aiohttp.ClientSession] = None
@staticmethod
def _safe_int(value, default: int = 0) -> int:
"""安全地将值转换为整数处理列表、None、空字符串等情况"""
if value is None:
return default
if isinstance(value, list):
return default
if isinstance(value, (int, float)):
return int(value)
if isinstance(value, str):
if not value.strip():
return default
try:
return int(float(value))
except (ValueError, TypeError):
return default
return default
@staticmethod
def _safe_float(value, default: float = 0.0) -> float:
"""安全地将值转换为浮点数"""
if value is None:
return default
if isinstance(value, list):
return default
if isinstance(value, (int, float)):
return float(value)
if isinstance(value, str):
if not value.strip():
return default
try:
return float(value)
except (ValueError, TypeError):
return default
return default
@staticmethod
def _safe_str(value, default: str = "") -> str:
"""安全地将值转换为字符串,处理列表等情况"""
if value is None:
return default
if isinstance(value, list):
return default
return str(value)
async def _get_session(self) -> aiohttp.ClientSession:
"""获取或创建 HTTP 会话"""
if self._session is None or self._session.closed:
timeout = aiohttp.ClientTimeout(total=self.config.timeout)
self._session = aiohttp.ClientSession(timeout=timeout)
return self._session
async def close(self):
"""关闭 HTTP 会话"""
if self._session and not self._session.closed:
await self._session.close()
def _generate_signature(self, params: Dict[str, Any]) -> str:
"""
生成数字签名
算法:
1. 将请求参数按参数名升序排序
2. 按 key=value 格式拼接,用 & 连接
3. 最后拼接上私钥secret
4. 对整个字符串进行 MD5 加密
Args:
params: 请求参数(不含 sig
Returns:
MD5 签名字符串
"""
# 按参数名升序排序
sorted_params = sorted(params.items(), key=lambda x: x[0])
# 拼接成 key=value&key=value 格式
param_str = "&".join(f"{k}={v}" for k, v in sorted_params)
# 拼接私钥
sign_str = param_str + self.config.secret
# MD5 加密
return hashlib.md5(sign_str.encode('utf-8')).hexdigest()
async def _request(self, endpoint: str, params: Dict[str, Any]) -> Dict[str, Any]:
"""
发送 API 请求
Args:
endpoint: API 端点路径
params: 请求参数
Returns:
API 响应数据
"""
params["key"] = self.config.api_key
params["output"] = "JSON"
# 如果配置了安全密钥,生成数字签名
if self.config.secret:
params["sig"] = self._generate_signature(params)
url = f"{self.BASE_URL}{endpoint}"
session = await self._get_session()
try:
async with session.get(url, params=params) as response:
data = await response.json()
# 检查 API 状态
status = data.get("status", "0")
if status != "1":
info = data.get("info", "未知错误")
infocode = data.get("infocode", "")
logger.warning(f"高德 API 错误: {info} (code: {infocode})")
return {"success": False, "error": info, "code": infocode}
return {"success": True, "data": data}
except aiohttp.ClientError as e:
logger.error(f"高德 API 请求失败: {e}")
return {"success": False, "error": str(e)}
except Exception as e:
logger.error(f"高德 API 未知错误: {e}")
return {"success": False, "error": str(e)}
# ==================== 地理编码 ====================
async def geocode(self, address: str, city: str = None) -> Dict[str, Any]:
"""
地理编码:将地址转换为坐标
Args:
address: 结构化地址,如 "北京市朝阳区阜通东大街6号"
city: 指定城市(可选)
Returns:
{
"success": True,
"location": "116.480881,39.989410",
"adcode": "110105",
"city": "北京市",
"district": "朝阳区",
"level": "门址"
}
"""
params = {"address": address}
if city:
params["city"] = city
result = await self._request("/v3/geocode/geo", params)
if not result["success"]:
return result
geocodes = result["data"].get("geocodes", [])
if not geocodes:
return {"success": False, "error": "未找到该地址"}
geo = geocodes[0]
return {
"success": True,
"location": geo.get("location", ""),
"adcode": geo.get("adcode", ""),
"province": geo.get("province", ""),
"city": geo.get("city", ""),
"district": geo.get("district", ""),
"level": geo.get("level", ""),
"formatted_address": geo.get("formatted_address", address)
}
async def reverse_geocode(
self,
location: str,
radius: int = 1000,
extensions: str = "base"
) -> Dict[str, Any]:
"""
逆地理编码:将坐标转换为地址
Args:
location: 经纬度坐标,格式 "lng,lat"
radius: 搜索半径0-3000
extensions: base 或 all
Returns:
地址信息
"""
params = {
"location": location,
"radius": min(radius, 3000),
"extensions": extensions
}
result = await self._request("/v3/geocode/regeo", params)
if not result["success"]:
return result
regeocode = result["data"].get("regeocode", {})
address_component = regeocode.get("addressComponent", {})
return {
"success": True,
"formatted_address": regeocode.get("formatted_address", ""),
"province": address_component.get("province", ""),
"city": address_component.get("city", ""),
"district": address_component.get("district", ""),
"adcode": address_component.get("adcode", ""),
"township": address_component.get("township", ""),
"pois": regeocode.get("pois", []) if extensions == "all" else []
}
# ==================== 行政区域查询 ====================
async def get_district(
self,
keywords: str = None,
subdistrict: int = 1
) -> Dict[str, Any]:
"""
行政区域查询
Args:
keywords: 查询关键字城市名、adcode 等)
subdistrict: 返回下级行政区级数0-3
Returns:
行政区域信息,包含 adcode、citycode 等
"""
params = {"subdistrict": subdistrict}
if keywords:
params["keywords"] = keywords
result = await self._request("/v3/config/district", params)
if not result["success"]:
return result
districts = result["data"].get("districts", [])
if not districts:
return {"success": False, "error": "未找到该行政区域"}
district = districts[0]
return {
"success": True,
"name": district.get("name", ""),
"adcode": district.get("adcode", ""),
"citycode": district.get("citycode", ""),
"center": district.get("center", ""),
"level": district.get("level", ""),
"districts": district.get("districts", [])
}
# ==================== 天气查询 ====================
async def get_weather(
self,
city: str,
extensions: Literal["base", "all"] = "all"
) -> Dict[str, Any]:
"""
天气查询
Args:
city: 城市 adcode如 110000或城市名
extensions: base=实况天气all=预报天气未来4天
Returns:
天气信息
"""
# 如果传入的是城市名,先获取 adcode
if not city.isdigit():
district_result = await self.get_district(city)
if not district_result["success"]:
return {"success": False, "error": f"无法识别城市: {city}"}
city = district_result["adcode"]
params = {
"city": city,
"extensions": extensions
}
result = await self._request("/v3/weather/weatherInfo", params)
if not result["success"]:
return result
data = result["data"]
if extensions == "base":
# 实况天气
lives = data.get("lives", [])
if not lives:
return {"success": False, "error": "未获取到天气数据"}
live = lives[0]
return {
"success": True,
"type": "live",
"city": live.get("city", ""),
"weather": live.get("weather", ""),
"temperature": live.get("temperature", ""),
"winddirection": live.get("winddirection", ""),
"windpower": live.get("windpower", ""),
"humidity": live.get("humidity", ""),
"reporttime": live.get("reporttime", "")
}
else:
# 预报天气
forecasts = data.get("forecasts", [])
if not forecasts:
return {"success": False, "error": "未获取到天气预报数据"}
forecast = forecasts[0]
casts = forecast.get("casts", [])
return {
"success": True,
"type": "forecast",
"city": forecast.get("city", ""),
"province": forecast.get("province", ""),
"reporttime": forecast.get("reporttime", ""),
"forecasts": [
{
"date": cast.get("date", ""),
"week": cast.get("week", ""),
"dayweather": cast.get("dayweather", ""),
"nightweather": cast.get("nightweather", ""),
"daytemp": cast.get("daytemp", ""),
"nighttemp": cast.get("nighttemp", ""),
"daywind": cast.get("daywind", ""),
"nightwind": cast.get("nightwind", ""),
"daypower": cast.get("daypower", ""),
"nightpower": cast.get("nightpower", "")
}
for cast in casts
]
}
# ==================== POI 搜索 ====================
async def search_poi(
self,
keywords: str = None,
types: str = None,
city: str = None,
citylimit: bool = True,
offset: int = 20,
page: int = 1,
extensions: str = "all"
) -> Dict[str, Any]:
"""
关键字搜索 POI
Args:
keywords: 查询关键字
types: POI 类型代码,多个用 | 分隔
city: 城市名或 adcode
citylimit: 是否仅返回指定城市
offset: 每页数量建议不超过25
page: 页码
extensions: base 或 all
Returns:
POI 列表
"""
params = {
"offset": min(offset, 25),
"page": page,
"extensions": extensions
}
if keywords:
params["keywords"] = keywords
if types:
params["types"] = types
if city:
params["city"] = city
params["citylimit"] = "true" if citylimit else "false"
result = await self._request("/v3/place/text", params)
if not result["success"]:
return result
pois = result["data"].get("pois", [])
count = self._safe_int(result["data"].get("count", 0))
return {
"success": True,
"count": count,
"pois": [self._format_poi(poi) for poi in pois]
}
async def search_around(
self,
location: str,
keywords: str = None,
types: str = None,
radius: int = 3000,
offset: int = 20,
page: int = 1,
extensions: str = "all"
) -> Dict[str, Any]:
"""
周边搜索 POI
Args:
location: 中心点坐标,格式 "lng,lat"
keywords: 查询关键字
types: POI 类型代码
radius: 搜索半径0-50000
offset: 每页数量
page: 页码
extensions: base 或 all
Returns:
POI 列表
"""
params = {
"location": location,
"radius": min(radius, 50000),
"offset": min(offset, 25),
"page": page,
"extensions": extensions,
"sortrule": "distance"
}
if keywords:
params["keywords"] = keywords
if types:
params["types"] = types
result = await self._request("/v3/place/around", params)
if not result["success"]:
return result
pois = result["data"].get("pois", [])
count = self._safe_int(result["data"].get("count", 0))
return {
"success": True,
"count": count,
"pois": [self._format_poi(poi) for poi in pois]
}
def _format_poi(self, poi: Dict[str, Any]) -> Dict[str, Any]:
"""格式化 POI 数据"""
biz_ext = poi.get("biz_ext", {}) or {}
return {
"id": poi.get("id", ""),
"name": poi.get("name", ""),
"type": poi.get("type", ""),
"address": poi.get("address", ""),
"location": poi.get("location", ""),
"tel": poi.get("tel", ""),
"distance": poi.get("distance", ""),
"pname": poi.get("pname", ""),
"cityname": poi.get("cityname", ""),
"adname": poi.get("adname", ""),
"rating": biz_ext.get("rating", ""),
"cost": biz_ext.get("cost", "")
}
# ==================== 路径规划 ====================
async def route_driving(
self,
origin: str,
destination: str,
strategy: int = 10,
waypoints: str = None,
extensions: str = "base"
) -> Dict[str, Any]:
"""
驾车路径规划
Args:
origin: 起点坐标 "lng,lat"
destination: 终点坐标 "lng,lat"
strategy: 驾车策略10=躲避拥堵13=不走高速14=避免收费)
waypoints: 途经点,多个用 ; 分隔
extensions: base 或 all
Returns:
路径规划结果
"""
params = {
"origin": origin,
"destination": destination,
"strategy": strategy,
"extensions": extensions
}
if waypoints:
params["waypoints"] = waypoints
result = await self._request("/v3/direction/driving", params)
if not result["success"]:
return result
route = result["data"].get("route", {})
paths = route.get("paths", [])
if not paths:
return {"success": False, "error": "未找到驾车路线"}
path = paths[0]
return {
"success": True,
"mode": "driving",
"origin": route.get("origin", ""),
"destination": route.get("destination", ""),
"distance": self._safe_int(path.get("distance", 0)),
"duration": self._safe_int(path.get("duration", 0)),
"tolls": self._safe_float(path.get("tolls", 0)),
"toll_distance": self._safe_int(path.get("toll_distance", 0)),
"traffic_lights": self._safe_int(path.get("traffic_lights", 0)),
"taxi_cost": self._safe_str(route.get("taxi_cost", "")),
"strategy": path.get("strategy", ""),
"steps": self._format_driving_steps(path.get("steps", []))
}
async def route_transit(
self,
origin: str,
destination: str,
city: str,
cityd: str = None,
strategy: int = 0,
extensions: str = "all"
) -> Dict[str, Any]:
"""
公交路径规划(含火车、地铁)
Args:
origin: 起点坐标 "lng,lat"
destination: 终点坐标 "lng,lat"
city: 起点城市
cityd: 终点城市(跨城时必填)
strategy: 0=最快1=最省钱2=最少换乘3=最少步行
extensions: base 或 all
Returns:
公交路径规划结果
"""
params = {
"origin": origin,
"destination": destination,
"city": city,
"strategy": strategy,
"extensions": extensions
}
if cityd:
params["cityd"] = cityd
result = await self._request("/v3/direction/transit/integrated", params)
if not result["success"]:
return result
route = result["data"].get("route", {})
transits = route.get("transits", [])
if not transits:
return {"success": False, "error": "未找到公交路线"}
# 返回前3个方案
formatted_transits = []
for transit in transits[:3]:
segments = transit.get("segments", [])
formatted_segments = []
for seg in segments:
# 步行段
walking = seg.get("walking", {})
if walking and walking.get("distance"):
formatted_segments.append({
"type": "walking",
"distance": self._safe_int(walking.get("distance", 0)),
"duration": self._safe_int(walking.get("duration", 0))
})
# 公交/地铁段
bus_info = seg.get("bus", {})
buslines = bus_info.get("buslines", [])
if buslines:
line = buslines[0]
formatted_segments.append({
"type": "bus",
"name": self._safe_str(line.get("name", "")),
"departure_stop": self._safe_str(line.get("departure_stop", {}).get("name", "")),
"arrival_stop": self._safe_str(line.get("arrival_stop", {}).get("name", "")),
"via_num": self._safe_int(line.get("via_num", 0)),
"distance": self._safe_int(line.get("distance", 0)),
"duration": self._safe_int(line.get("duration", 0))
})
# 火车段
railway = seg.get("railway", {})
if railway and railway.get("name"):
formatted_segments.append({
"type": "railway",
"name": self._safe_str(railway.get("name", "")),
"trip": self._safe_str(railway.get("trip", "")),
"departure_stop": self._safe_str(railway.get("departure_stop", {}).get("name", "")),
"arrival_stop": self._safe_str(railway.get("arrival_stop", {}).get("name", "")),
"departure_time": self._safe_str(railway.get("departure_stop", {}).get("time", "")),
"arrival_time": self._safe_str(railway.get("arrival_stop", {}).get("time", "")),
"distance": self._safe_int(railway.get("distance", 0)),
"time": self._safe_str(railway.get("time", ""))
})
formatted_transits.append({
"cost": self._safe_str(transit.get("cost", "")),
"duration": self._safe_int(transit.get("duration", 0)),
"walking_distance": self._safe_int(transit.get("walking_distance", 0)),
"segments": formatted_segments
})
return {
"success": True,
"mode": "transit",
"origin": route.get("origin", ""),
"destination": route.get("destination", ""),
"distance": self._safe_int(route.get("distance", 0)),
"taxi_cost": self._safe_str(route.get("taxi_cost", "")),
"transits": formatted_transits
}
async def route_walking(
self,
origin: str,
destination: str
) -> Dict[str, Any]:
"""
步行路径规划
Args:
origin: 起点坐标 "lng,lat"
destination: 终点坐标 "lng,lat"
Returns:
步行路径规划结果
"""
params = {
"origin": origin,
"destination": destination
}
result = await self._request("/v3/direction/walking", params)
if not result["success"]:
return result
route = result["data"].get("route", {})
paths = route.get("paths", [])
if not paths:
return {"success": False, "error": "未找到步行路线"}
path = paths[0]
return {
"success": True,
"mode": "walking",
"origin": route.get("origin", ""),
"destination": route.get("destination", ""),
"distance": self._safe_int(path.get("distance", 0)),
"duration": self._safe_int(path.get("duration", 0))
}
async def route_bicycling(
self,
origin: str,
destination: str
) -> Dict[str, Any]:
"""
骑行路径规划
Args:
origin: 起点坐标 "lng,lat"
destination: 终点坐标 "lng,lat"
Returns:
骑行路径规划结果
"""
params = {
"origin": origin,
"destination": destination
}
# 骑行用 v4 接口
result = await self._request("/v4/direction/bicycling", params)
if not result["success"]:
return result
data = result["data"].get("data", {})
paths = data.get("paths", [])
if not paths:
return {"success": False, "error": "未找到骑行路线"}
path = paths[0]
return {
"success": True,
"mode": "bicycling",
"origin": data.get("origin", ""),
"destination": data.get("destination", ""),
"distance": self._safe_int(path.get("distance", 0)),
"duration": self._safe_int(path.get("duration", 0))
}
def _format_driving_steps(self, steps: List[Dict]) -> List[Dict]:
"""格式化驾车步骤"""
return [
{
"instruction": step.get("instruction", ""),
"road": step.get("road", ""),
"distance": self._safe_int(step.get("distance", 0)),
"duration": self._safe_int(step.get("duration", 0)),
"orientation": step.get("orientation", "")
}
for step in steps[:10] # 只返回前10步
]
# ==================== 距离测量 ====================
async def get_distance(
self,
origins: str,
destination: str,
type: int = 1
) -> Dict[str, Any]:
"""
距离测量
Args:
origins: 起点坐标,多个用 | 分隔
destination: 终点坐标
type: 0=直线距离1=驾车距离3=步行距离
Returns:
距离信息
"""
params = {
"origins": origins,
"destination": destination,
"type": type
}
result = await self._request("/v3/distance", params)
if not result["success"]:
return result
results = result["data"].get("results", [])
if not results:
return {"success": False, "error": "无法计算距离"}
return {
"success": True,
"results": [
{
"origin_id": r.get("origin_id", ""),
"distance": self._safe_int(r.get("distance", 0)),
"duration": self._safe_int(r.get("duration", 0))
}
for r in results
]
}
# ==================== 输入提示 ====================
async def input_tips(
self,
keywords: str,
city: str = None,
citylimit: bool = False,
datatype: str = "all"
) -> Dict[str, Any]:
"""
输入提示
Args:
keywords: 查询关键字
city: 城市名或 adcode
citylimit: 是否仅返回指定城市
datatype: all/poi/bus/busline
Returns:
提示列表
"""
params = {
"keywords": keywords,
"datatype": datatype
}
if city:
params["city"] = city
params["citylimit"] = "true" if citylimit else "false"
result = await self._request("/v3/assistant/inputtips", params)
if not result["success"]:
return result
tips = result["data"].get("tips", [])
return {
"success": True,
"tips": [
{
"id": tip.get("id", ""),
"name": tip.get("name", ""),
"district": tip.get("district", ""),
"adcode": tip.get("adcode", ""),
"location": tip.get("location", ""),
"address": tip.get("address", "")
}
for tip in tips
if tip.get("location") # 过滤无坐标的结果
]
}

View File

@@ -0,0 +1,609 @@
"""
旅行规划插件
基于高德地图 API提供以下功能
- 地点搜索与地理编码
- 天气查询(实况 + 4天预报
- 景点/酒店/餐厅搜索
- 路径规划(驾车/公交/步行)
- 周边搜索
支持 LLM 函数调用,可与 AIChat 插件配合使用。
"""
import tomllib
from pathlib import Path
from typing import Any, Dict, List
from loguru import logger
from utils.plugin_base import PluginBase
from .amap_client import AmapClient, AmapConfig
class TravelPlanner(PluginBase):
"""旅行规划插件"""
description = "旅行规划助手,支持天气查询、景点搜索、路线规划"
author = "ShiHao"
version = "1.0.0"
def __init__(self):
super().__init__()
self.config = None
self.amap: AmapClient = None
async def async_init(self):
"""插件异步初始化"""
# 读取配置
config_path = Path(__file__).parent / "config.toml"
with open(config_path, "rb") as f:
self.config = tomllib.load(f)
# 初始化高德 API 客户端
amap_config = self.config.get("amap", {})
api_key = amap_config.get("api_key", "")
secret = amap_config.get("secret", "")
if not api_key:
logger.warning("TravelPlanner: 未配置高德 API Key请在 config.toml 中设置")
else:
self.amap = AmapClient(AmapConfig(
api_key=api_key,
secret=secret,
timeout=amap_config.get("timeout", 30)
))
if secret:
logger.success(f"TravelPlanner 插件已加载API Key: {api_key[:8]}...(已启用数字签名)")
else:
logger.success(f"TravelPlanner 插件已加载API Key: {api_key[:8]}...(未配置安全密钥)")
async def on_disable(self):
"""插件禁用时关闭连接"""
await super().on_disable()
if self.amap:
await self.amap.close()
logger.info("TravelPlanner: 已关闭高德 API 连接")
# ==================== LLM 工具定义 ====================
def get_llm_tools(self) -> List[Dict]:
"""返回 LLM 可调用的工具列表"""
return [
{
"type": "function",
"function": {
"name": "search_location",
"description": "【旅行工具】将地名转换为坐标和行政区划信息。仅当用户明确询问某个地点的位置信息时使用。",
"parameters": {
"type": "object",
"properties": {
"address": {
"type": "string",
"description": "地址或地名,如:北京市、西湖、故宫"
},
"city": {
"type": "string",
"description": "所在城市,可选。填写可提高搜索精度"
}
},
"required": ["address"]
}
}
},
{
"type": "function",
"function": {
"name": "query_weather",
"description": "【旅行工具】查询城市天气预报。仅当用户明确询问某城市的天气情况时使用,如'北京天气怎么样''杭州明天会下雨吗'",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "城市名称,如:北京、杭州、上海"
},
"forecast": {
"type": "boolean",
"description": "是否查询预报天气。true=未来4天预报false=当前实况"
}
},
"required": ["city"]
}
}
},
{
"type": "function",
"function": {
"name": "search_poi",
"description": "【旅行工具】搜索地点(景点、酒店、餐厅等)。仅当用户明确要求查找某城市的景点、酒店、餐厅等时使用。",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "搜索城市,如:杭州、北京"
},
"keyword": {
"type": "string",
"description": "搜索关键词,如:西湖、希尔顿酒店、火锅"
},
"category": {
"type": "string",
"enum": ["景点", "酒店", "餐厅", "购物", "交通"],
"description": "POI 类别。不填则搜索所有类别"
},
"limit": {
"type": "integer",
"description": "返回结果数量默认10最大20"
}
},
"required": ["city"]
}
}
},
{
"type": "function",
"function": {
"name": "search_nearby",
"description": "【旅行工具】搜索某地点周边的设施。仅当用户明确要求查找某地点附近的餐厅、酒店等时使用,如'西湖附近有什么好吃的'",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "中心地点名称,如:西湖、故宫"
},
"city": {
"type": "string",
"description": "所在城市"
},
"keyword": {
"type": "string",
"description": "搜索关键词"
},
"category": {
"type": "string",
"enum": ["景点", "酒店", "餐厅", "购物", "交通"],
"description": "POI 类别"
},
"radius": {
"type": "integer",
"description": "搜索半径默认3000最大50000"
}
},
"required": ["location", "city"]
}
}
},
{
"type": "function",
"function": {
"name": "plan_route",
"description": "【旅行工具】规划两地之间的出行路线。仅当用户明确要求规划从A到B的路线时使用'从北京到杭州怎么走''上海到苏州的高铁'",
"parameters": {
"type": "object",
"properties": {
"origin": {
"type": "string",
"description": "起点地名,如:北京、上海虹桥站"
},
"destination": {
"type": "string",
"description": "终点地名,如:杭州、西湖"
},
"origin_city": {
"type": "string",
"description": "起点所在城市"
},
"destination_city": {
"type": "string",
"description": "终点所在城市(跨城时必填)"
},
"mode": {
"type": "string",
"enum": ["driving", "transit", "walking"],
"description": "出行方式driving=驾车transit=公交/高铁walking=步行。默认 transit"
}
},
"required": ["origin", "destination", "origin_city"]
}
}
},
{
"type": "function",
"function": {
"name": "get_travel_info",
"description": "【旅行工具】获取目的地城市的旅行信息(天气、景点、交通)。仅当用户明确表示要去某城市旅游并询问相关信息时使用,如'我想去杭州玩,帮我看看''北京旅游攻略'",
"parameters": {
"type": "object",
"properties": {
"destination": {
"type": "string",
"description": "目的地城市,如:杭州、成都"
},
"origin": {
"type": "string",
"description": "出发城市,如:北京、上海。填写后会规划交通路线"
}
},
"required": ["destination"]
}
}
}
]
async def execute_llm_tool(
self,
tool_name: str,
arguments: Dict[str, Any],
bot,
from_wxid: str
) -> Dict[str, Any]:
"""执行 LLM 工具调用"""
if not self.amap:
return {"success": False, "message": "高德 API 未配置,请联系管理员设置 API Key"}
try:
if tool_name == "search_location":
return await self._tool_search_location(arguments)
elif tool_name == "query_weather":
return await self._tool_query_weather(arguments)
elif tool_name == "search_poi":
return await self._tool_search_poi(arguments)
elif tool_name == "search_nearby":
return await self._tool_search_nearby(arguments)
elif tool_name == "plan_route":
return await self._tool_plan_route(arguments)
elif tool_name == "get_travel_info":
return await self._tool_get_travel_info(arguments)
else:
return {"success": False, "message": f"未知工具: {tool_name}"}
except Exception as e:
logger.error(f"TravelPlanner 工具执行失败: {tool_name}, 错误: {e}")
return {"success": False, "message": f"工具执行失败: {str(e)}"}
# ==================== 工具实现 ====================
async def _tool_search_location(self, args: Dict) -> Dict:
"""地点搜索工具"""
address = args.get("address", "")
city = args.get("city")
result = await self.amap.geocode(address, city)
if not result["success"]:
return {"success": False, "message": result.get("error", "地点搜索失败")}
return {
"success": True,
"message": f"已找到地点:{result['formatted_address']}",
"data": {
"name": address,
"formatted_address": result["formatted_address"],
"location": result["location"],
"province": result["province"],
"city": result["city"],
"district": result["district"],
"adcode": result["adcode"]
}
}
async def _tool_query_weather(self, args: Dict) -> Dict:
"""天气查询工具"""
city = args.get("city", "")
forecast = args.get("forecast", True)
extensions = "all" if forecast else "base"
result = await self.amap.get_weather(city, extensions)
if not result["success"]:
return {"success": False, "message": result.get("error", "天气查询失败")}
if result["type"] == "live":
return {
"success": True,
"message": f"{result['city']}当前天气:{result['weather']}{result['temperature']}",
"data": {
"city": result["city"],
"weather": result["weather"],
"temperature": result["temperature"],
"humidity": result["humidity"],
"wind": f"{result['winddirection']}{result['windpower']}",
"reporttime": result["reporttime"]
}
}
else:
forecasts = result["forecasts"]
weather_text = "\n".join([
f"- {f['date']} 星期{self._weekday_cn(f['week'])}:白天{f['dayweather']} {f['daytemp']}℃,夜间{f['nightweather']} {f['nighttemp']}"
for f in forecasts
])
return {
"success": True,
"message": f"{result['city']}未来天气预报:\n{weather_text}",
"data": {
"city": result["city"],
"province": result["province"],
"forecasts": forecasts,
"reporttime": result["reporttime"]
}
}
async def _tool_search_poi(self, args: Dict) -> Dict:
"""POI 搜索工具"""
city = args.get("city", "")
keyword = args.get("keyword")
category = args.get("category")
limit = min(args.get("limit", 10), 20)
# 获取 POI 类型代码
types = None
if category:
poi_types = self.config.get("poi_types", {})
types = poi_types.get(category)
result = await self.amap.search_poi(
keywords=keyword,
types=types,
city=city,
citylimit=True,
offset=limit
)
if not result["success"]:
return {"success": False, "message": result.get("error", "搜索失败")}
pois = result["pois"]
if not pois:
return {"success": False, "message": f"{city}未找到相关地点"}
# 格式化输出
poi_list = []
for i, poi in enumerate(pois, 1):
info = f"{i}. {poi['name']}"
if poi.get("address"):
info += f" - {poi['address']}"
if poi.get("rating"):
info += f"{poi['rating']}"
if poi.get("cost"):
info += f" 人均¥{poi['cost']}"
poi_list.append(info)
return {
"success": True,
"message": f"{city}找到{len(pois)}个结果:\n" + "\n".join(poi_list),
"data": {
"city": city,
"category": category or "全部",
"count": len(pois),
"pois": pois
}
}
async def _tool_search_nearby(self, args: Dict) -> Dict:
"""周边搜索工具"""
location_name = args.get("location", "")
city = args.get("city", "")
keyword = args.get("keyword")
category = args.get("category")
radius = min(args.get("radius", 3000), 50000)
# 先获取中心点坐标
geo_result = await self.amap.geocode(location_name, city)
if not geo_result["success"]:
return {"success": False, "message": f"无法定位 {location_name}"}
location = geo_result["location"]
# 获取 POI 类型代码
types = None
if category:
poi_types = self.config.get("poi_types", {})
types = poi_types.get(category)
result = await self.amap.search_around(
location=location,
keywords=keyword,
types=types,
radius=radius,
offset=10
)
if not result["success"]:
return {"success": False, "message": result.get("error", "周边搜索失败")}
pois = result["pois"]
if not pois:
return {"success": False, "message": f"{location_name}周边未找到相关地点"}
# 格式化输出
poi_list = []
for i, poi in enumerate(pois, 1):
info = f"{i}. {poi['name']}"
if poi.get("distance"):
info += f" ({poi['distance']}米)"
if poi.get("rating"):
info += f"{poi['rating']}"
poi_list.append(info)
return {
"success": True,
"message": f"{location_name}周边{radius}米内找到{len(pois)}个结果:\n" + "\n".join(poi_list),
"data": {
"center": location_name,
"radius": radius,
"category": category or "全部",
"count": len(pois),
"pois": pois
}
}
async def _tool_plan_route(self, args: Dict) -> Dict:
"""路线规划工具"""
origin = args.get("origin", "")
destination = args.get("destination", "")
origin_city = args.get("origin_city", "")
destination_city = args.get("destination_city", origin_city)
mode = args.get("mode", "transit")
# 获取起终点坐标
origin_geo = await self.amap.geocode(origin, origin_city)
if not origin_geo["success"]:
return {"success": False, "message": f"无法定位起点:{origin}"}
dest_geo = await self.amap.geocode(destination, destination_city)
if not dest_geo["success"]:
return {"success": False, "message": f"无法定位终点:{destination}"}
origin_loc = origin_geo["location"]
dest_loc = dest_geo["location"]
# 根据模式规划路线
if mode == "driving":
result = await self.amap.route_driving(origin_loc, dest_loc)
if not result["success"]:
return {"success": False, "message": result.get("error", "驾车路线规划失败")}
distance_km = result["distance"] / 1000
duration_h = result["duration"] / 3600
msg = f"🚗 驾车路线:{origin}{destination}\n"
msg += f"距离:{distance_km:.1f}公里,预计{self._format_duration(result['duration'])}\n"
if result["tolls"]:
msg += f"收费:约{result['tolls']}\n"
if result["taxi_cost"]:
msg += f"打车费用:约{result['taxi_cost']}"
return {
"success": True,
"message": msg,
"data": result
}
elif mode == "transit":
result = await self.amap.route_transit(
origin_loc, dest_loc,
city=origin_city,
cityd=destination_city if destination_city != origin_city else None
)
if not result["success"]:
return {"success": False, "message": result.get("error", "公交路线规划失败")}
msg = f"🚄 公交/高铁路线:{origin}{destination}\n"
for i, transit in enumerate(result["transits"][:2], 1):
msg += f"\n方案{i}{self._format_duration(transit['duration'])}"
if transit.get("cost"):
msg += f",约{transit['cost']}"
msg += "\n"
for seg in transit["segments"]:
if seg["type"] == "walking" and seg["distance"] > 100:
msg += f" 🚶 步行{seg['distance']}\n"
elif seg["type"] == "bus":
msg += f" 🚌 {seg['name']}{seg['departure_stop']}{seg['arrival_stop']}{seg['via_num']}站)\n"
elif seg["type"] == "railway":
msg += f" 🚄 {seg['trip']} {seg['name']}{seg['departure_stop']} {seg.get('departure_time', '')}{seg['arrival_stop']} {seg.get('arrival_time', '')}\n"
return {
"success": True,
"message": msg.strip(),
"data": result
}
elif mode == "walking":
result = await self.amap.route_walking(origin_loc, dest_loc)
if not result["success"]:
return {"success": False, "message": result.get("error", "步行路线规划失败")}
return {
"success": True,
"message": f"🚶 步行路线:{origin}{destination}\n距离:{result['distance']}米,预计{self._format_duration(result['duration'])}",
"data": result
}
return {"success": False, "message": f"不支持的出行方式:{mode}"}
async def _tool_get_travel_info(self, args: Dict) -> Dict:
"""一键获取旅行信息"""
destination = args.get("destination", "")
origin = args.get("origin")
info = {"destination": destination}
msg_parts = [f"📍 {destination} 旅行信息\n"]
# 1. 查询天气
weather_result = await self.amap.get_weather(destination, "all")
if weather_result["success"]:
info["weather"] = weather_result
msg_parts.append("🌤️ 天气预报:")
for f in weather_result["forecasts"][:3]:
msg_parts.append(f" {f['date']} {f['dayweather']} {f['nighttemp']}~{f['daytemp']}")
# 2. 搜索热门景点
poi_result = await self.amap.search_poi(
types="110000", # 景点
city=destination,
citylimit=True,
offset=5
)
if poi_result["success"] and poi_result["pois"]:
info["attractions"] = poi_result["pois"]
msg_parts.append("\n🏞️ 热门景点:")
for poi in poi_result["pois"][:5]:
rating = f"{poi['rating']}" if poi.get("rating") else ""
msg_parts.append(f"{poi['name']}{rating}")
# 3. 规划交通路线(如果提供了出发地)
if origin:
origin_geo = await self.amap.geocode(origin)
dest_geo = await self.amap.geocode(destination)
if origin_geo["success"] and dest_geo["success"]:
route_result = await self.amap.route_transit(
origin_geo["location"],
dest_geo["location"],
city=origin_geo.get("city", origin),
cityd=dest_geo.get("city", destination)
)
if route_result["success"] and route_result["transits"]:
info["route"] = route_result
transit = route_result["transits"][0]
msg_parts.append(f"\n🚄 从{origin}出发:")
msg_parts.append(f" 预计{self._format_duration(transit['duration'])}")
# 显示主要交通工具
for seg in transit["segments"]:
if seg["type"] == "railway":
msg_parts.append(f" {seg['trip']}{seg['departure_stop']}{seg['arrival_stop']}")
break
return {
"success": True,
"message": "\n".join(msg_parts),
"data": info
}
# ==================== 辅助方法 ====================
def _weekday_cn(self, week: str) -> str:
"""星期数字转中文"""
mapping = {"1": "", "2": "", "3": "", "4": "", "5": "", "6": "", "7": ""}
return mapping.get(str(week), week)
def _format_duration(self, seconds: int) -> str:
"""格式化时长"""
if seconds < 60:
return f"{seconds}"
elif seconds < 3600:
return f"{seconds // 60}分钟"
else:
hours = seconds // 3600
minutes = (seconds % 3600) // 60
if minutes:
return f"{hours}小时{minutes}分钟"
return f"{hours}小时"