feat:修复一些BUG
@@ -2831,10 +2831,64 @@ class AIChat(PluginBase):
|
||||
refer_type = int(refer_type_elem.text) if refer_type_elem is not None and refer_type_elem.text else 0
|
||||
logger.debug(f"被引用消息类型: {refer_type}")
|
||||
|
||||
# 纯文本消息不需要处理(type=1)
|
||||
# 纯文本消息(type=1):如果@了机器人,转发给 AI 处理
|
||||
if refer_type == 1:
|
||||
logger.debug("引用的是纯文本消息,跳过")
|
||||
return True
|
||||
if self._should_reply_quote(message, title_text):
|
||||
# 获取被引用的文本内容
|
||||
refer_content_elem = refermsg.find("content")
|
||||
refer_text = refer_content_elem.text.strip() if refer_content_elem is not None and refer_content_elem.text else ""
|
||||
|
||||
# 获取被引用者昵称
|
||||
refer_displayname = refermsg.find("displayname")
|
||||
refer_nickname = refer_displayname.text if refer_displayname is not None and refer_displayname.text else "某人"
|
||||
|
||||
# 组合消息:引用内容 + 用户评论
|
||||
# title_text 格式如 "@瑞依 评价下",需要去掉 @昵称 部分
|
||||
import tomllib
|
||||
with open("main_config.toml", "rb") as f:
|
||||
main_config = tomllib.load(f)
|
||||
bot_nickname = main_config.get("Bot", {}).get("nickname", "")
|
||||
|
||||
user_comment = title_text
|
||||
if bot_nickname:
|
||||
# 移除 @机器人昵称(可能有空格分隔)
|
||||
user_comment = user_comment.replace(f"@{bot_nickname}", "").strip()
|
||||
|
||||
# 构造给 AI 的消息
|
||||
combined_message = f"[引用 {refer_nickname} 的消息:{refer_text}]\n{user_comment}"
|
||||
logger.info(f"引用纯文本消息,转发给 AI: {combined_message[:80]}...")
|
||||
|
||||
# 调用 AI 处理
|
||||
nickname = await self._get_user_display_label(bot, from_wxid, user_wxid, is_group)
|
||||
chat_id = from_wxid if is_group else user_wxid
|
||||
|
||||
# 保存用户消息到群组历史记录
|
||||
history_enabled = bool(self.store) and self.config.get("history", {}).get("enabled", True)
|
||||
sync_bot_messages = self.config.get("history", {}).get("sync_bot_messages", True)
|
||||
if is_group and history_enabled:
|
||||
history_chat_id = self._get_group_history_chat_id(from_wxid, user_wxid)
|
||||
await self._add_to_history(history_chat_id, nickname, combined_message, sender_wxid=user_wxid)
|
||||
|
||||
ai_response = await self._call_ai_api(
|
||||
combined_message,
|
||||
bot=bot,
|
||||
from_wxid=from_wxid,
|
||||
chat_id=chat_id,
|
||||
nickname=nickname
|
||||
)
|
||||
|
||||
if ai_response:
|
||||
final_response = self._sanitize_llm_output(ai_response)
|
||||
await bot.send_text(from_wxid, final_response)
|
||||
|
||||
# 保存 AI 回复到群组历史记录
|
||||
if is_group and history_enabled and sync_bot_messages:
|
||||
bot_nickname_display = main_config.get("Bot", {}).get("nickname", "AI")
|
||||
await self._add_to_history(history_chat_id, bot_nickname_display, final_response, role="assistant")
|
||||
return False
|
||||
else:
|
||||
logger.debug("引用的是纯文本消息且未@机器人,跳过")
|
||||
return True
|
||||
|
||||
# 只处理图片(3)、视频(43)、应用消息(49,含聊天记录)
|
||||
if refer_type not in [3, 43, 49]:
|
||||
@@ -3553,40 +3607,50 @@ class AIChat(PluginBase):
|
||||
def _should_reply_quote(self, message: dict, title_text: str) -> bool:
|
||||
"""判断是否应该回复引用消息"""
|
||||
is_group = message.get("IsGroup", False)
|
||||
|
||||
|
||||
# 检查群聊/私聊开关
|
||||
if is_group and not self.config["behavior"]["reply_group"]:
|
||||
return False
|
||||
if not is_group and not self.config["behavior"]["reply_private"]:
|
||||
return False
|
||||
|
||||
|
||||
trigger_mode = self.config["behavior"]["trigger_mode"]
|
||||
|
||||
|
||||
# all模式:回复所有消息
|
||||
if trigger_mode == "all":
|
||||
return True
|
||||
|
||||
|
||||
# mention模式:检查是否@了机器人
|
||||
if trigger_mode == "mention":
|
||||
if is_group:
|
||||
# 方式1:检查 Ats 字段(普通消息格式)
|
||||
ats = message.get("Ats", [])
|
||||
if not ats:
|
||||
return False
|
||||
|
||||
|
||||
import tomllib
|
||||
with open("main_config.toml", "rb") as f:
|
||||
main_config = tomllib.load(f)
|
||||
bot_wxid = main_config.get("Bot", {}).get("wxid", "")
|
||||
|
||||
return bot_wxid and bot_wxid in ats
|
||||
bot_nickname = main_config.get("Bot", {}).get("nickname", "")
|
||||
|
||||
# 检查 Ats 列表
|
||||
if bot_wxid and bot_wxid in ats:
|
||||
return True
|
||||
|
||||
# 方式2:检查标题中是否包含 @机器人昵称(引用消息格式)
|
||||
# 引用消息的 @ 信息在 title 中,如 "@瑞依 评价下"
|
||||
if bot_nickname and f"@{bot_nickname}" in title_text:
|
||||
logger.debug(f"引用消息标题中检测到 @{bot_nickname}")
|
||||
return True
|
||||
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
# keyword模式:检查关键词
|
||||
if trigger_mode == "keyword":
|
||||
keywords = self.config["behavior"]["keywords"]
|
||||
return any(kw in title_text for kw in keywords)
|
||||
|
||||
|
||||
return False
|
||||
|
||||
async def _call_ai_api_with_image(
|
||||
|
||||
@@ -70,7 +70,7 @@ class GrokVideo(PluginBase):
|
||||
|
||||
# 初始化MinIO客户端
|
||||
self.minio_client = Minio(
|
||||
"101.201.65.129:19000",
|
||||
"115.190.113.141:19000",
|
||||
access_key="admin",
|
||||
secret_key="80012029Lz",
|
||||
secure=False
|
||||
@@ -173,7 +173,7 @@ class GrokVideo(PluginBase):
|
||||
)
|
||||
|
||||
# 返回访问URL
|
||||
url = f"http://101.201.65.129:19000/{self.minio_bucket}/{object_name}"
|
||||
url = f"http://115.190.113.141:19000/{self.minio_bucket}/{object_name}"
|
||||
logger.info(f"视频上传成功: {url}")
|
||||
return url
|
||||
|
||||
@@ -293,7 +293,15 @@ class GrokVideo(PluginBase):
|
||||
|
||||
# 解析 XML 获取标题和引用消息
|
||||
try:
|
||||
root = ET.fromstring(content)
|
||||
xml_content = content.lstrip("\ufeff")
|
||||
if ":\n" in xml_content:
|
||||
xml_start = xml_content.find("<?xml")
|
||||
if xml_start == -1:
|
||||
xml_start = xml_content.find("<msg")
|
||||
if xml_start > 0:
|
||||
xml_content = xml_content[xml_start:]
|
||||
|
||||
root = ET.fromstring(xml_content)
|
||||
title = root.find(".//title")
|
||||
if title is None or not title.text:
|
||||
return
|
||||
@@ -338,7 +346,13 @@ class GrokVideo(PluginBase):
|
||||
|
||||
# 解码 HTML 实体
|
||||
import html
|
||||
refer_xml = html.unescape(refer_content.text)
|
||||
refer_xml = html.unescape(refer_content.text).lstrip("\ufeff")
|
||||
if ":\n" in refer_xml:
|
||||
xml_start = refer_xml.find("<?xml")
|
||||
if xml_start == -1:
|
||||
xml_start = refer_xml.find("<msg")
|
||||
if xml_start > 0:
|
||||
refer_xml = refer_xml[xml_start:]
|
||||
refer_root = ET.fromstring(refer_xml)
|
||||
|
||||
# 提取图片信息
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
# 即梦AI绘图插件
|
||||
|
Before Width: | Height: | Size: 943 KiB |
|
Before Width: | Height: | Size: 816 KiB |
|
Before Width: | Height: | Size: 893 KiB |
|
Before Width: | Height: | Size: 758 KiB |
|
Before Width: | Height: | Size: 753 KiB |
|
Before Width: | Height: | Size: 860 KiB |
|
Before Width: | Height: | Size: 896 KiB |
|
Before Width: | Height: | Size: 1.0 MiB |
|
Before Width: | Height: | Size: 1.1 MiB |
|
Before Width: | Height: | Size: 960 KiB |
|
Before Width: | Height: | Size: 1.1 MiB |
|
Before Width: | Height: | Size: 903 KiB |
|
Before Width: | Height: | Size: 920 KiB |
|
Before Width: | Height: | Size: 962 KiB |
|
Before Width: | Height: | Size: 882 KiB |
|
Before Width: | Height: | Size: 967 KiB |
|
Before Width: | Height: | Size: 965 KiB |
|
Before Width: | Height: | Size: 993 KiB |
|
Before Width: | Height: | Size: 905 KiB |
|
Before Width: | Height: | Size: 933 KiB |
|
Before Width: | Height: | Size: 916 KiB |
|
Before Width: | Height: | Size: 877 KiB |
|
Before Width: | Height: | Size: 874 KiB |
@@ -1,372 +0,0 @@
|
||||
"""
|
||||
即梦AI绘图插件
|
||||
|
||||
支持命令触发和LLM工具调用
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import tomllib
|
||||
import aiohttp
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
from loguru import logger
|
||||
from utils.plugin_base import PluginBase
|
||||
from utils.decorators import on_text_message
|
||||
from WechatHook import WechatHookClient
|
||||
|
||||
|
||||
class TokenState:
|
||||
"""Token轮询状态管理"""
|
||||
def __init__(self):
|
||||
self.token_index = 0
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def get_next_token(self, tokens: List[str]) -> str:
|
||||
"""获取下一个可用的token"""
|
||||
async with self._lock:
|
||||
if not tokens:
|
||||
raise ValueError("Token列表为空")
|
||||
return tokens[self.token_index % len(tokens)]
|
||||
|
||||
async def rotate(self, tokens: List[str]):
|
||||
"""轮换到下一个token"""
|
||||
async with self._lock:
|
||||
if tokens:
|
||||
self.token_index = (self.token_index + 1) % len(tokens)
|
||||
|
||||
|
||||
class JimengAI(PluginBase):
|
||||
"""即梦AI绘图插件"""
|
||||
|
||||
description = "即梦AI绘图插件 - 支持AI绘图和LLM工具调用"
|
||||
author = "ShiHao"
|
||||
version = "1.0.0"
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.config = None
|
||||
self.token_state = TokenState()
|
||||
self.images_dir = None
|
||||
|
||||
async def async_init(self):
|
||||
"""异步初始化"""
|
||||
config_path = Path(__file__).parent / "config.toml"
|
||||
with open(config_path, "rb") as f:
|
||||
self.config = tomllib.load(f)
|
||||
|
||||
# 创建图片目录
|
||||
self.images_dir = Path(__file__).parent / "images"
|
||||
self.images_dir.mkdir(exist_ok=True)
|
||||
|
||||
logger.success(f"即梦AI插件初始化完成,配置了 {len(self.config['api']['tokens'])} 个token")
|
||||
|
||||
async def generate_image(self, prompt: str, **kwargs) -> List[str]:
|
||||
"""
|
||||
生成图像
|
||||
|
||||
Args:
|
||||
prompt: 提示词
|
||||
**kwargs: 其他参数(model, width, height, sample_strength, negative_prompt)
|
||||
|
||||
Returns:
|
||||
图片本地路径列表
|
||||
"""
|
||||
api_config = self.config["api"]
|
||||
gen_config = self.config["generation"]
|
||||
|
||||
model = kwargs.get("model", gen_config["default_model"])
|
||||
width = kwargs.get("width", gen_config["default_width"])
|
||||
height = kwargs.get("height", gen_config["default_height"])
|
||||
sample_strength = kwargs.get("sample_strength", gen_config["default_sample_strength"])
|
||||
negative_prompt = kwargs.get("negative_prompt", gen_config["default_negative_prompt"])
|
||||
|
||||
# 参数验证
|
||||
sample_strength = max(0.0, min(1.0, sample_strength))
|
||||
width = max(64, min(2048, width))
|
||||
height = max(64, min(2048, height))
|
||||
|
||||
tokens = api_config["tokens"]
|
||||
max_retry = gen_config["max_retry_attempts"]
|
||||
|
||||
# 尝试每个token
|
||||
for token_attempt in range(len(tokens)):
|
||||
current_token = await self.token_state.get_next_token(tokens)
|
||||
|
||||
for attempt in range(max_retry):
|
||||
if attempt > 0:
|
||||
await asyncio.sleep(min(2 ** attempt, 10))
|
||||
|
||||
try:
|
||||
url = f"{api_config['base_url'].rstrip('/')}/v1/chat/completions"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {current_token}"
|
||||
}
|
||||
|
||||
payload = {
|
||||
"model": model,
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"prompt": prompt,
|
||||
"negativePrompt": negative_prompt,
|
||||
"width": width,
|
||||
"height": height,
|
||||
"sample_strength": sample_strength
|
||||
}
|
||||
|
||||
logger.info(f"即梦AI请求: {model}, 尺寸: {width}x{height}, 提示词: {prompt[:50]}...")
|
||||
|
||||
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=api_config["timeout"])) as session:
|
||||
async with session.post(url, headers=headers, json=payload) as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
logger.debug(f"API返回数据: {data}")
|
||||
|
||||
if "error" in data:
|
||||
logger.error(f"API错误: {data['error']}")
|
||||
continue
|
||||
|
||||
# 提取图片URL
|
||||
image_paths = await self._extract_images(data)
|
||||
|
||||
if image_paths:
|
||||
logger.success(f"成功生成 {len(image_paths)} 张图像")
|
||||
return image_paths
|
||||
else:
|
||||
logger.warning(f"未找到图像数据,API返回: {str(data)[:200]}")
|
||||
continue
|
||||
|
||||
elif response.status == 401:
|
||||
logger.warning("Token认证失败,尝试下一个token")
|
||||
break
|
||||
elif response.status == 429:
|
||||
logger.warning("请求频率限制,等待后重试")
|
||||
await asyncio.sleep(5)
|
||||
continue
|
||||
else:
|
||||
error_text = await response.text()
|
||||
logger.error(f"API请求失败: {response.status}, {error_text[:200]}")
|
||||
continue
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"请求超时,重试中... ({attempt + 1}/{max_retry})")
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"请求异常: {e}")
|
||||
continue
|
||||
|
||||
# 当前token失败,轮换
|
||||
await self.token_state.rotate(tokens)
|
||||
|
||||
logger.error("所有token都失败了")
|
||||
return []
|
||||
|
||||
async def _extract_images(self, data: dict) -> List[str]:
|
||||
"""从API响应中提取图片"""
|
||||
import re
|
||||
image_paths = []
|
||||
|
||||
# 格式1: OpenAI格式的choices
|
||||
if "choices" in data and data["choices"]:
|
||||
for choice in data["choices"]:
|
||||
if "message" in choice and "content" in choice["message"]:
|
||||
content = choice["message"]["content"]
|
||||
if "https://" in content:
|
||||
urls = re.findall(r'https://[^\s\)]+', content)
|
||||
for url in urls:
|
||||
path = await self._download_image(url)
|
||||
if path:
|
||||
image_paths.append(path)
|
||||
|
||||
# 格式2: data数组
|
||||
elif "data" in data:
|
||||
data_list = data["data"] if isinstance(data["data"], list) else [data["data"]]
|
||||
for item in data_list:
|
||||
if isinstance(item, str) and item.startswith("http"):
|
||||
path = await self._download_image(item)
|
||||
if path:
|
||||
image_paths.append(path)
|
||||
elif isinstance(item, dict) and "url" in item:
|
||||
path = await self._download_image(item["url"])
|
||||
if path:
|
||||
image_paths.append(path)
|
||||
|
||||
# 格式3: images数组
|
||||
elif "images" in data:
|
||||
images_list = data["images"] if isinstance(data["images"], list) else [data["images"]]
|
||||
for item in images_list:
|
||||
if isinstance(item, str) and item.startswith("http"):
|
||||
path = await self._download_image(item)
|
||||
if path:
|
||||
image_paths.append(path)
|
||||
elif isinstance(item, dict) and "url" in item:
|
||||
path = await self._download_image(item["url"])
|
||||
if path:
|
||||
image_paths.append(path)
|
||||
|
||||
# 格式4: 单个URL
|
||||
elif "url" in data:
|
||||
path = await self._download_image(data["url"])
|
||||
if path:
|
||||
image_paths.append(path)
|
||||
|
||||
return image_paths
|
||||
|
||||
async def _download_image(self, url: str) -> Optional[str]:
|
||||
"""下载图片到本地"""
|
||||
try:
|
||||
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=30)) as session:
|
||||
async with session.get(url) as response:
|
||||
if response.status == 200:
|
||||
content = await response.read()
|
||||
|
||||
# 生成文件名
|
||||
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
uid = uuid.uuid4().hex[:8]
|
||||
file_path = self.images_dir / f"jimeng_{ts}_{uid}.jpg"
|
||||
|
||||
# 保存文件
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(content)
|
||||
|
||||
logger.info(f"图片下载成功: {file_path}")
|
||||
return str(file_path)
|
||||
except Exception as e:
|
||||
logger.error(f"下载图片失败: {e}")
|
||||
return None
|
||||
|
||||
@on_text_message(priority=70)
|
||||
async def handle_message(self, bot: WechatHookClient, message: dict):
|
||||
"""处理文本消息"""
|
||||
if not self.config["behavior"]["enable_command"]:
|
||||
return True
|
||||
|
||||
content = message.get("Content", "").strip()
|
||||
from_wxid = message.get("FromWxid", "")
|
||||
is_group = message.get("IsGroup", False)
|
||||
|
||||
# 检查群聊/私聊开关
|
||||
if is_group and not self.config["behavior"]["enable_group"]:
|
||||
return True
|
||||
if not is_group and not self.config["behavior"]["enable_private"]:
|
||||
return True
|
||||
|
||||
# 检查是否是绘图命令(精确匹配命令+空格+提示词)
|
||||
keywords = self.config["behavior"]["command_keywords"]
|
||||
matched_keyword = None
|
||||
for keyword in keywords:
|
||||
if content.startswith(keyword + " "):
|
||||
matched_keyword = keyword
|
||||
break
|
||||
|
||||
if not matched_keyword:
|
||||
return True
|
||||
|
||||
# 提取提示词
|
||||
prompt = content[len(matched_keyword):].strip()
|
||||
if not prompt:
|
||||
await bot.send_text(from_wxid, "❌ 请提供绘图提示词\n用法: /绘图 <提示词>")
|
||||
return False
|
||||
|
||||
logger.info(f"收到绘图请求: {prompt[:50]}...")
|
||||
|
||||
# 发送处理中提示
|
||||
await bot.send_text(from_wxid, "🎨 正在为您生成图像,请稍候...")
|
||||
|
||||
try:
|
||||
# 生成图像
|
||||
image_paths = await self.generate_image(prompt)
|
||||
|
||||
if image_paths:
|
||||
# 直接发送图片
|
||||
await bot.send_image(from_wxid, image_paths[0])
|
||||
logger.success(f"绘图成功,已发送图片")
|
||||
else:
|
||||
await bot.send_text(from_wxid, "❌ 图像生成失败,请稍后重试")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"绘图处理失败: {e}")
|
||||
await bot.send_text(from_wxid, f"❌ 处理失败: {str(e)}")
|
||||
|
||||
return False
|
||||
|
||||
def get_llm_tools(self) -> List[dict]:
|
||||
"""
|
||||
返回LLM工具定义
|
||||
供AIChat插件调用
|
||||
"""
|
||||
if not self.config["llm_tool"]["enabled"]:
|
||||
return []
|
||||
|
||||
return [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": self.config["llm_tool"]["tool_name"],
|
||||
"description": self.config["llm_tool"]["tool_description"],
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"prompt": {
|
||||
"type": "string",
|
||||
"description": "图像生成提示词,描述想要生成的图像内容"
|
||||
},
|
||||
"width": {
|
||||
"type": "integer",
|
||||
"description": "图像宽度(64-2048),默认1024",
|
||||
"default": 1024
|
||||
},
|
||||
"height": {
|
||||
"type": "integer",
|
||||
"description": "图像高度(64-2048),默认1024",
|
||||
"default": 1024
|
||||
}
|
||||
},
|
||||
"required": ["prompt"]
|
||||
}
|
||||
}
|
||||
}]
|
||||
|
||||
async def execute_llm_tool(self, tool_name: str, arguments: dict, bot: WechatHookClient, from_wxid: str) -> dict:
|
||||
"""
|
||||
执行LLM工具调用
|
||||
供AIChat插件调用
|
||||
|
||||
Returns:
|
||||
{"success": bool, "message": str, "images": List[str]}
|
||||
"""
|
||||
expected_tool_name = self.config["llm_tool"]["tool_name"]
|
||||
logger.info(f"JimengAI工具检查: 收到={tool_name}, 期望={expected_tool_name}")
|
||||
|
||||
if tool_name != expected_tool_name:
|
||||
return None # 不是本插件的工具,返回None让其他插件处理
|
||||
|
||||
try:
|
||||
prompt = arguments.get("prompt")
|
||||
if not prompt:
|
||||
return {"success": False, "message": "缺少提示词参数"}
|
||||
|
||||
logger.info(f"LLM工具调用绘图: {prompt[:50]}...")
|
||||
|
||||
# 生成图像(使用配置的默认尺寸)
|
||||
gen_config = self.config["generation"]
|
||||
image_paths = await self.generate_image(
|
||||
prompt=prompt,
|
||||
width=arguments.get("width", gen_config["default_width"]),
|
||||
height=arguments.get("height", gen_config["default_height"])
|
||||
)
|
||||
|
||||
if image_paths:
|
||||
# 直接发送图片
|
||||
await bot.send_image(from_wxid, image_paths[0])
|
||||
return {
|
||||
"success": True,
|
||||
"message": "已生成并发送图像",
|
||||
"images": [image_paths[0]]
|
||||
}
|
||||
else:
|
||||
return {"success": False, "message": "图像生成失败"}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"LLM工具执行失败: {e}")
|
||||
return {"success": False, "message": f"执行失败: {str(e)}"}
|
||||
@@ -1 +0,0 @@
|
||||
"""Kiira2 AI绘图插件"""
|
||||
@@ -1,350 +0,0 @@
|
||||
"""
|
||||
Kiira2 AI绘图插件
|
||||
|
||||
支持命令触发和LLM工具调用
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import tomllib
|
||||
import httpx
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
from loguru import logger
|
||||
from utils.plugin_base import PluginBase
|
||||
from utils.decorators import on_text_message
|
||||
from WechatHook import WechatHookClient
|
||||
|
||||
|
||||
class TokenState:
|
||||
"""Token轮询状态管理"""
|
||||
def __init__(self):
|
||||
self.token_index = 0
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def get_next_token(self, tokens: List[str]) -> str:
|
||||
"""获取下一个可用的token"""
|
||||
async with self._lock:
|
||||
if not tokens:
|
||||
raise ValueError("Token列表为空")
|
||||
return tokens[self.token_index % len(tokens)]
|
||||
|
||||
async def rotate(self, tokens: List[str]):
|
||||
"""轮换到下一个token"""
|
||||
async with self._lock:
|
||||
if tokens:
|
||||
self.token_index = (self.token_index + 1) % len(tokens)
|
||||
|
||||
|
||||
class Kiira2AI(PluginBase):
|
||||
"""Kiira2 AI绘图插件"""
|
||||
|
||||
description = "Kiira2 AI绘图插件 - 支持AI绘图和LLM工具调用"
|
||||
author = "ShiHao"
|
||||
version = "1.0.0"
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.config = None
|
||||
self.token_state = TokenState()
|
||||
self.images_dir = None
|
||||
|
||||
async def async_init(self):
|
||||
"""异步初始化"""
|
||||
config_path = Path(__file__).parent / "config.toml"
|
||||
with open(config_path, "rb") as f:
|
||||
self.config = tomllib.load(f)
|
||||
|
||||
# 创建图片目录
|
||||
self.images_dir = Path(__file__).parent / "images"
|
||||
self.images_dir.mkdir(exist_ok=True)
|
||||
|
||||
logger.success(f"Kiira2 AI插件初始化完成,配置了 {len(self.config['api']['tokens'])} 个token")
|
||||
|
||||
async def generate_image(self, prompt: str, **kwargs) -> List[str]:
|
||||
"""
|
||||
生成图像
|
||||
|
||||
Args:
|
||||
prompt: 提示词
|
||||
**kwargs: 其他参数(model)
|
||||
|
||||
Returns:
|
||||
图片本地路径列表
|
||||
"""
|
||||
api_config = self.config["api"]
|
||||
gen_config = self.config["generation"]
|
||||
|
||||
model = kwargs.get("model", gen_config["default_model"])
|
||||
tokens = api_config["tokens"]
|
||||
max_retry = gen_config["max_retry_attempts"]
|
||||
|
||||
# 尝试每个token
|
||||
for token_attempt in range(len(tokens)):
|
||||
current_token = await self.token_state.get_next_token(tokens)
|
||||
|
||||
for attempt in range(max_retry):
|
||||
if attempt > 0:
|
||||
await asyncio.sleep(min(2 ** attempt, 10))
|
||||
|
||||
try:
|
||||
url = f"{api_config['base_url'].rstrip('/')}/v1/chat/completions"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {current_token}"
|
||||
}
|
||||
|
||||
payload = {
|
||||
"model": model,
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"stream": False
|
||||
}
|
||||
|
||||
logger.info(f"Kiira2 AI请求: {model}, 提示词: {prompt[:50]}...")
|
||||
|
||||
timeout = httpx.Timeout(connect=10.0, read=api_config["timeout"], write=10.0, pool=10.0)
|
||||
|
||||
# 配置代理
|
||||
proxy = None
|
||||
proxy_config = self.config.get("proxy", {})
|
||||
if proxy_config.get("enabled", False):
|
||||
proxy_type = proxy_config.get("type", "socks5")
|
||||
proxy_host = proxy_config.get("host", "127.0.0.1")
|
||||
proxy_port = proxy_config.get("port", 7890)
|
||||
proxy = f"{proxy_type}://{proxy_host}:{proxy_port}"
|
||||
logger.info(f"使用代理: {proxy}")
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout, proxy=proxy) as client:
|
||||
response = await client.post(url, json=payload, headers=headers)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
logger.debug(f"API返回数据: {data}")
|
||||
|
||||
if "error" in data:
|
||||
logger.error(f"API错误: {data['error']}")
|
||||
continue
|
||||
|
||||
# 检查是否返回空content(图片还在生成中)
|
||||
if "choices" in data and data["choices"]:
|
||||
message = data["choices"][0].get("message", {})
|
||||
content = message.get("content", "")
|
||||
video_url = message.get("video_url")
|
||||
|
||||
# 如果content为空且没有video_url,说明还在生成,等待后重试
|
||||
if not content and not video_url:
|
||||
wait_time = min(10 + attempt * 5, 30)
|
||||
logger.info(f"图片生成中,等待 {wait_time} 秒后重试...")
|
||||
await asyncio.sleep(wait_time)
|
||||
continue
|
||||
|
||||
# 提取图片URL
|
||||
image_paths = await self._extract_images(data)
|
||||
|
||||
if image_paths:
|
||||
logger.success(f"成功生成 {len(image_paths)} 张图像")
|
||||
return image_paths
|
||||
else:
|
||||
logger.warning(f"未找到图像数据,API返回: {str(data)[:500]}")
|
||||
continue
|
||||
|
||||
elif response.status_code == 401:
|
||||
logger.warning("Token认证失败,尝试下一个token")
|
||||
break
|
||||
elif response.status_code == 429:
|
||||
logger.warning("请求频率限制,等待后重试")
|
||||
await asyncio.sleep(5)
|
||||
continue
|
||||
else:
|
||||
error_text = response.text
|
||||
logger.error(f"API请求失败: {response.status_code}, {error_text[:200]}")
|
||||
continue
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"请求超时,重试中... ({attempt + 1}/{max_retry})")
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"请求异常: {e}")
|
||||
continue
|
||||
|
||||
# 当前token失败,轮换
|
||||
await self.token_state.rotate(tokens)
|
||||
|
||||
logger.error("所有token都失败了")
|
||||
return []
|
||||
|
||||
async def _extract_images(self, data: dict) -> List[str]:
|
||||
"""从API响应中提取图片(只提取图片,忽略文字)"""
|
||||
import re
|
||||
image_paths = []
|
||||
|
||||
# OpenAI格式的choices
|
||||
if "choices" in data and data["choices"]:
|
||||
for choice in data["choices"]:
|
||||
message = choice.get("message", {})
|
||||
|
||||
# 检查video_url字段(实际包含图片URL)
|
||||
if "video_url" in message:
|
||||
video_url = message["video_url"]
|
||||
if isinstance(video_url, list) and video_url:
|
||||
url = video_url[0]
|
||||
if isinstance(url, str) and url.startswith("http"):
|
||||
path = await self._download_image(url)
|
||||
if path:
|
||||
image_paths.append(path)
|
||||
|
||||
# 检查content字段
|
||||
if "content" in message and not image_paths:
|
||||
content = message["content"]
|
||||
if content and "http" in content:
|
||||
urls = re.findall(r'https?://[^\s\)\]"]+', content)
|
||||
for url in urls:
|
||||
path = await self._download_image(url)
|
||||
if path:
|
||||
image_paths.append(path)
|
||||
|
||||
return image_paths
|
||||
|
||||
async def _download_image(self, url: str) -> Optional[str]:
|
||||
"""下载图片到本地"""
|
||||
try:
|
||||
timeout = httpx.Timeout(connect=10.0, read=30.0, write=10.0, pool=10.0)
|
||||
|
||||
# 配置代理
|
||||
proxy = None
|
||||
proxy_config = self.config.get("proxy", {})
|
||||
if proxy_config.get("enabled", False):
|
||||
proxy_type = proxy_config.get("type", "socks5")
|
||||
proxy_host = proxy_config.get("host", "127.0.0.1")
|
||||
proxy_port = proxy_config.get("port", 7890)
|
||||
proxy = f"{proxy_type}://{proxy_host}:{proxy_port}"
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout, proxy=proxy) as client:
|
||||
response = await client.get(url)
|
||||
response.raise_for_status()
|
||||
|
||||
# 生成文件名
|
||||
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
uid = uuid.uuid4().hex[:8]
|
||||
file_path = self.images_dir / f"kiira2_{ts}_{uid}.jpg"
|
||||
|
||||
# 保存文件
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(response.content)
|
||||
|
||||
logger.info(f"图片下载成功: {file_path}")
|
||||
return str(file_path)
|
||||
except Exception as e:
|
||||
logger.error(f"下载图片失败: {e}")
|
||||
return None
|
||||
|
||||
@on_text_message(priority=70)
|
||||
async def handle_message(self, bot: WechatHookClient, message: dict):
|
||||
"""处理文本消息"""
|
||||
if not self.config["behavior"]["enable_command"]:
|
||||
return True
|
||||
|
||||
content = message.get("Content", "").strip()
|
||||
from_wxid = message.get("FromWxid", "")
|
||||
is_group = message.get("IsGroup", False)
|
||||
|
||||
# 检查群聊/私聊开关
|
||||
if is_group and not self.config["behavior"]["enable_group"]:
|
||||
return True
|
||||
if not is_group and not self.config["behavior"]["enable_private"]:
|
||||
return True
|
||||
|
||||
# 检查是否是绘图命令
|
||||
keywords = self.config["behavior"]["command_keywords"]
|
||||
matched_keyword = None
|
||||
for keyword in keywords:
|
||||
if content.startswith(keyword + " "):
|
||||
matched_keyword = keyword
|
||||
break
|
||||
|
||||
if not matched_keyword:
|
||||
return True
|
||||
|
||||
# 提取提示词
|
||||
prompt = content[len(matched_keyword):].strip()
|
||||
if not prompt:
|
||||
await bot.send_text(from_wxid, "❌ 请提供绘图提示词\n用法: /画画 <提示词>")
|
||||
return False
|
||||
|
||||
logger.info(f"收到绘图请求: {prompt[:50]}...")
|
||||
|
||||
# 发送处理中提示
|
||||
await bot.send_text(from_wxid, "🎨 正在为您生成图像,请稍候...")
|
||||
|
||||
try:
|
||||
# 生成图像
|
||||
image_paths = await self.generate_image(prompt)
|
||||
|
||||
if image_paths:
|
||||
# 直接发送图片
|
||||
await bot.send_image(from_wxid, image_paths[0])
|
||||
logger.success(f"绘图成功,已发送图片")
|
||||
else:
|
||||
await bot.send_text(from_wxid, "❌ 图像生成失败,请稍后重试")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"绘图处理失败: {e}")
|
||||
await bot.send_text(from_wxid, f"❌ 处理失败: {str(e)}")
|
||||
|
||||
return False
|
||||
|
||||
def get_llm_tools(self) -> List[dict]:
|
||||
"""返回LLM工具定义"""
|
||||
if not self.config["llm_tool"]["enabled"]:
|
||||
return []
|
||||
|
||||
return [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": self.config["llm_tool"]["tool_name"],
|
||||
"description": self.config["llm_tool"]["tool_description"],
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"prompt": {
|
||||
"type": "string",
|
||||
"description": "图像生成提示词,描述想要生成的图像内容"
|
||||
}
|
||||
},
|
||||
"required": ["prompt"]
|
||||
}
|
||||
}
|
||||
}]
|
||||
|
||||
async def execute_llm_tool(self, tool_name: str, arguments: dict, bot: WechatHookClient, from_wxid: str) -> dict:
|
||||
"""执行LLM工具调用"""
|
||||
expected_tool_name = self.config["llm_tool"]["tool_name"]
|
||||
|
||||
if tool_name != expected_tool_name:
|
||||
return None
|
||||
|
||||
try:
|
||||
prompt = arguments.get("prompt")
|
||||
if not prompt:
|
||||
return {"success": False, "message": "缺少提示词参数"}
|
||||
|
||||
logger.info(f"LLM工具调用绘图: {prompt[:50]}...")
|
||||
|
||||
# 生成图像
|
||||
image_paths = await self.generate_image(prompt=prompt)
|
||||
|
||||
if image_paths:
|
||||
# 直接发送图片
|
||||
await bot.send_image(from_wxid, image_paths[0])
|
||||
return {
|
||||
"success": True,
|
||||
"message": "已生成并发送图像",
|
||||
"images": [image_paths[0]]
|
||||
}
|
||||
else:
|
||||
return {"success": False, "message": "图像生成失败"}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"LLM工具执行失败: {e}")
|
||||
return {"success": False, "message": f"执行失败: {str(e)}"}
|
||||
3
plugins/TravelPlanner/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .main import TravelPlanner
|
||||
|
||||
__all__ = ["TravelPlanner"]
|
||||
860
plugins/TravelPlanner/amap_client.py
Normal file
@@ -0,0 +1,860 @@
|
||||
"""
|
||||
高德地图 API 客户端封装
|
||||
|
||||
提供以下功能:
|
||||
- 地理编码:地址 → 坐标
|
||||
- 逆地理编码:坐标 → 地址
|
||||
- 行政区域查询:获取城市 adcode
|
||||
- 天气查询:实况/预报天气
|
||||
- POI 搜索:关键字搜索、周边搜索
|
||||
- 路径规划:驾车、公交、步行、骑行
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import aiohttp
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Literal
|
||||
from loguru import logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class AmapConfig:
|
||||
"""高德 API 配置"""
|
||||
api_key: str
|
||||
secret: str = "" # 安全密钥,用于数字签名
|
||||
timeout: int = 30
|
||||
|
||||
|
||||
class AmapClient:
|
||||
"""高德地图 API 客户端"""
|
||||
|
||||
BASE_URL = "https://restapi.amap.com"
|
||||
|
||||
def __init__(self, config: AmapConfig):
|
||||
self.config = config
|
||||
self._session: Optional[aiohttp.ClientSession] = None
|
||||
|
||||
@staticmethod
|
||||
def _safe_int(value, default: int = 0) -> int:
|
||||
"""安全地将值转换为整数,处理列表、None、空字符串等情况"""
|
||||
if value is None:
|
||||
return default
|
||||
if isinstance(value, list):
|
||||
return default
|
||||
if isinstance(value, (int, float)):
|
||||
return int(value)
|
||||
if isinstance(value, str):
|
||||
if not value.strip():
|
||||
return default
|
||||
try:
|
||||
return int(float(value))
|
||||
except (ValueError, TypeError):
|
||||
return default
|
||||
return default
|
||||
|
||||
@staticmethod
|
||||
def _safe_float(value, default: float = 0.0) -> float:
|
||||
"""安全地将值转换为浮点数"""
|
||||
if value is None:
|
||||
return default
|
||||
if isinstance(value, list):
|
||||
return default
|
||||
if isinstance(value, (int, float)):
|
||||
return float(value)
|
||||
if isinstance(value, str):
|
||||
if not value.strip():
|
||||
return default
|
||||
try:
|
||||
return float(value)
|
||||
except (ValueError, TypeError):
|
||||
return default
|
||||
return default
|
||||
|
||||
@staticmethod
|
||||
def _safe_str(value, default: str = "") -> str:
|
||||
"""安全地将值转换为字符串,处理列表等情况"""
|
||||
if value is None:
|
||||
return default
|
||||
if isinstance(value, list):
|
||||
return default
|
||||
return str(value)
|
||||
|
||||
async def _get_session(self) -> aiohttp.ClientSession:
|
||||
"""获取或创建 HTTP 会话"""
|
||||
if self._session is None or self._session.closed:
|
||||
timeout = aiohttp.ClientTimeout(total=self.config.timeout)
|
||||
self._session = aiohttp.ClientSession(timeout=timeout)
|
||||
return self._session
|
||||
|
||||
async def close(self):
|
||||
"""关闭 HTTP 会话"""
|
||||
if self._session and not self._session.closed:
|
||||
await self._session.close()
|
||||
|
||||
def _generate_signature(self, params: Dict[str, Any]) -> str:
|
||||
"""
|
||||
生成数字签名
|
||||
|
||||
算法:
|
||||
1. 将请求参数按参数名升序排序
|
||||
2. 按 key=value 格式拼接,用 & 连接
|
||||
3. 最后拼接上私钥(secret)
|
||||
4. 对整个字符串进行 MD5 加密
|
||||
|
||||
Args:
|
||||
params: 请求参数(不含 sig)
|
||||
|
||||
Returns:
|
||||
MD5 签名字符串
|
||||
"""
|
||||
# 按参数名升序排序
|
||||
sorted_params = sorted(params.items(), key=lambda x: x[0])
|
||||
# 拼接成 key=value&key=value 格式
|
||||
param_str = "&".join(f"{k}={v}" for k, v in sorted_params)
|
||||
# 拼接私钥
|
||||
sign_str = param_str + self.config.secret
|
||||
# MD5 加密
|
||||
return hashlib.md5(sign_str.encode('utf-8')).hexdigest()
|
||||
|
||||
async def _request(self, endpoint: str, params: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
发送 API 请求
|
||||
|
||||
Args:
|
||||
endpoint: API 端点路径
|
||||
params: 请求参数
|
||||
|
||||
Returns:
|
||||
API 响应数据
|
||||
"""
|
||||
params["key"] = self.config.api_key
|
||||
params["output"] = "JSON"
|
||||
|
||||
# 如果配置了安全密钥,生成数字签名
|
||||
if self.config.secret:
|
||||
params["sig"] = self._generate_signature(params)
|
||||
|
||||
url = f"{self.BASE_URL}{endpoint}"
|
||||
session = await self._get_session()
|
||||
|
||||
try:
|
||||
async with session.get(url, params=params) as response:
|
||||
data = await response.json()
|
||||
|
||||
# 检查 API 状态
|
||||
status = data.get("status", "0")
|
||||
if status != "1":
|
||||
info = data.get("info", "未知错误")
|
||||
infocode = data.get("infocode", "")
|
||||
logger.warning(f"高德 API 错误: {info} (code: {infocode})")
|
||||
return {"success": False, "error": info, "code": infocode}
|
||||
|
||||
return {"success": True, "data": data}
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
logger.error(f"高德 API 请求失败: {e}")
|
||||
return {"success": False, "error": str(e)}
|
||||
except Exception as e:
|
||||
logger.error(f"高德 API 未知错误: {e}")
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
# ==================== 地理编码 ====================
|
||||
|
||||
async def geocode(self, address: str, city: str = None) -> Dict[str, Any]:
|
||||
"""
|
||||
地理编码:将地址转换为坐标
|
||||
|
||||
Args:
|
||||
address: 结构化地址,如 "北京市朝阳区阜通东大街6号"
|
||||
city: 指定城市(可选)
|
||||
|
||||
Returns:
|
||||
{
|
||||
"success": True,
|
||||
"location": "116.480881,39.989410",
|
||||
"adcode": "110105",
|
||||
"city": "北京市",
|
||||
"district": "朝阳区",
|
||||
"level": "门址"
|
||||
}
|
||||
"""
|
||||
params = {"address": address}
|
||||
if city:
|
||||
params["city"] = city
|
||||
|
||||
result = await self._request("/v3/geocode/geo", params)
|
||||
|
||||
if not result["success"]:
|
||||
return result
|
||||
|
||||
geocodes = result["data"].get("geocodes", [])
|
||||
if not geocodes:
|
||||
return {"success": False, "error": "未找到该地址"}
|
||||
|
||||
geo = geocodes[0]
|
||||
return {
|
||||
"success": True,
|
||||
"location": geo.get("location", ""),
|
||||
"adcode": geo.get("adcode", ""),
|
||||
"province": geo.get("province", ""),
|
||||
"city": geo.get("city", ""),
|
||||
"district": geo.get("district", ""),
|
||||
"level": geo.get("level", ""),
|
||||
"formatted_address": geo.get("formatted_address", address)
|
||||
}
|
||||
|
||||
async def reverse_geocode(
|
||||
self,
|
||||
location: str,
|
||||
radius: int = 1000,
|
||||
extensions: str = "base"
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
逆地理编码:将坐标转换为地址
|
||||
|
||||
Args:
|
||||
location: 经纬度坐标,格式 "lng,lat"
|
||||
radius: 搜索半径(米),0-3000
|
||||
extensions: base 或 all
|
||||
|
||||
Returns:
|
||||
地址信息
|
||||
"""
|
||||
params = {
|
||||
"location": location,
|
||||
"radius": min(radius, 3000),
|
||||
"extensions": extensions
|
||||
}
|
||||
|
||||
result = await self._request("/v3/geocode/regeo", params)
|
||||
|
||||
if not result["success"]:
|
||||
return result
|
||||
|
||||
regeocode = result["data"].get("regeocode", {})
|
||||
address_component = regeocode.get("addressComponent", {})
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"formatted_address": regeocode.get("formatted_address", ""),
|
||||
"province": address_component.get("province", ""),
|
||||
"city": address_component.get("city", ""),
|
||||
"district": address_component.get("district", ""),
|
||||
"adcode": address_component.get("adcode", ""),
|
||||
"township": address_component.get("township", ""),
|
||||
"pois": regeocode.get("pois", []) if extensions == "all" else []
|
||||
}
|
||||
|
||||
# ==================== 行政区域查询 ====================
|
||||
|
||||
async def get_district(
|
||||
self,
|
||||
keywords: str = None,
|
||||
subdistrict: int = 1
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
行政区域查询
|
||||
|
||||
Args:
|
||||
keywords: 查询关键字(城市名、adcode 等)
|
||||
subdistrict: 返回下级行政区级数(0-3)
|
||||
|
||||
Returns:
|
||||
行政区域信息,包含 adcode、citycode 等
|
||||
"""
|
||||
params = {"subdistrict": subdistrict}
|
||||
if keywords:
|
||||
params["keywords"] = keywords
|
||||
|
||||
result = await self._request("/v3/config/district", params)
|
||||
|
||||
if not result["success"]:
|
||||
return result
|
||||
|
||||
districts = result["data"].get("districts", [])
|
||||
if not districts:
|
||||
return {"success": False, "error": "未找到该行政区域"}
|
||||
|
||||
district = districts[0]
|
||||
return {
|
||||
"success": True,
|
||||
"name": district.get("name", ""),
|
||||
"adcode": district.get("adcode", ""),
|
||||
"citycode": district.get("citycode", ""),
|
||||
"center": district.get("center", ""),
|
||||
"level": district.get("level", ""),
|
||||
"districts": district.get("districts", [])
|
||||
}
|
||||
|
||||
# ==================== 天气查询 ====================
|
||||
|
||||
async def get_weather(
|
||||
self,
|
||||
city: str,
|
||||
extensions: Literal["base", "all"] = "all"
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
天气查询
|
||||
|
||||
Args:
|
||||
city: 城市 adcode(如 110000)或城市名
|
||||
extensions: base=实况天气,all=预报天气(未来4天)
|
||||
|
||||
Returns:
|
||||
天气信息
|
||||
"""
|
||||
# 如果传入的是城市名,先获取 adcode
|
||||
if not city.isdigit():
|
||||
district_result = await self.get_district(city)
|
||||
if not district_result["success"]:
|
||||
return {"success": False, "error": f"无法识别城市: {city}"}
|
||||
city = district_result["adcode"]
|
||||
|
||||
params = {
|
||||
"city": city,
|
||||
"extensions": extensions
|
||||
}
|
||||
|
||||
result = await self._request("/v3/weather/weatherInfo", params)
|
||||
|
||||
if not result["success"]:
|
||||
return result
|
||||
|
||||
data = result["data"]
|
||||
|
||||
if extensions == "base":
|
||||
# 实况天气
|
||||
lives = data.get("lives", [])
|
||||
if not lives:
|
||||
return {"success": False, "error": "未获取到天气数据"}
|
||||
|
||||
live = lives[0]
|
||||
return {
|
||||
"success": True,
|
||||
"type": "live",
|
||||
"city": live.get("city", ""),
|
||||
"weather": live.get("weather", ""),
|
||||
"temperature": live.get("temperature", ""),
|
||||
"winddirection": live.get("winddirection", ""),
|
||||
"windpower": live.get("windpower", ""),
|
||||
"humidity": live.get("humidity", ""),
|
||||
"reporttime": live.get("reporttime", "")
|
||||
}
|
||||
else:
|
||||
# 预报天气
|
||||
forecasts = data.get("forecasts", [])
|
||||
if not forecasts:
|
||||
return {"success": False, "error": "未获取到天气预报数据"}
|
||||
|
||||
forecast = forecasts[0]
|
||||
casts = forecast.get("casts", [])
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"type": "forecast",
|
||||
"city": forecast.get("city", ""),
|
||||
"province": forecast.get("province", ""),
|
||||
"reporttime": forecast.get("reporttime", ""),
|
||||
"forecasts": [
|
||||
{
|
||||
"date": cast.get("date", ""),
|
||||
"week": cast.get("week", ""),
|
||||
"dayweather": cast.get("dayweather", ""),
|
||||
"nightweather": cast.get("nightweather", ""),
|
||||
"daytemp": cast.get("daytemp", ""),
|
||||
"nighttemp": cast.get("nighttemp", ""),
|
||||
"daywind": cast.get("daywind", ""),
|
||||
"nightwind": cast.get("nightwind", ""),
|
||||
"daypower": cast.get("daypower", ""),
|
||||
"nightpower": cast.get("nightpower", "")
|
||||
}
|
||||
for cast in casts
|
||||
]
|
||||
}
|
||||
|
||||
# ==================== POI 搜索 ====================
|
||||
|
||||
async def search_poi(
|
||||
self,
|
||||
keywords: str = None,
|
||||
types: str = None,
|
||||
city: str = None,
|
||||
citylimit: bool = True,
|
||||
offset: int = 20,
|
||||
page: int = 1,
|
||||
extensions: str = "all"
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
关键字搜索 POI
|
||||
|
||||
Args:
|
||||
keywords: 查询关键字
|
||||
types: POI 类型代码,多个用 | 分隔
|
||||
city: 城市名或 adcode
|
||||
citylimit: 是否仅返回指定城市
|
||||
offset: 每页数量(建议不超过25)
|
||||
page: 页码
|
||||
extensions: base 或 all
|
||||
|
||||
Returns:
|
||||
POI 列表
|
||||
"""
|
||||
params = {
|
||||
"offset": min(offset, 25),
|
||||
"page": page,
|
||||
"extensions": extensions
|
||||
}
|
||||
|
||||
if keywords:
|
||||
params["keywords"] = keywords
|
||||
if types:
|
||||
params["types"] = types
|
||||
if city:
|
||||
params["city"] = city
|
||||
params["citylimit"] = "true" if citylimit else "false"
|
||||
|
||||
result = await self._request("/v3/place/text", params)
|
||||
|
||||
if not result["success"]:
|
||||
return result
|
||||
|
||||
pois = result["data"].get("pois", [])
|
||||
count = self._safe_int(result["data"].get("count", 0))
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"count": count,
|
||||
"pois": [self._format_poi(poi) for poi in pois]
|
||||
}
|
||||
|
||||
async def search_around(
|
||||
self,
|
||||
location: str,
|
||||
keywords: str = None,
|
||||
types: str = None,
|
||||
radius: int = 3000,
|
||||
offset: int = 20,
|
||||
page: int = 1,
|
||||
extensions: str = "all"
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
周边搜索 POI
|
||||
|
||||
Args:
|
||||
location: 中心点坐标,格式 "lng,lat"
|
||||
keywords: 查询关键字
|
||||
types: POI 类型代码
|
||||
radius: 搜索半径(米),0-50000
|
||||
offset: 每页数量
|
||||
page: 页码
|
||||
extensions: base 或 all
|
||||
|
||||
Returns:
|
||||
POI 列表
|
||||
"""
|
||||
params = {
|
||||
"location": location,
|
||||
"radius": min(radius, 50000),
|
||||
"offset": min(offset, 25),
|
||||
"page": page,
|
||||
"extensions": extensions,
|
||||
"sortrule": "distance"
|
||||
}
|
||||
|
||||
if keywords:
|
||||
params["keywords"] = keywords
|
||||
if types:
|
||||
params["types"] = types
|
||||
|
||||
result = await self._request("/v3/place/around", params)
|
||||
|
||||
if not result["success"]:
|
||||
return result
|
||||
|
||||
pois = result["data"].get("pois", [])
|
||||
count = self._safe_int(result["data"].get("count", 0))
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"count": count,
|
||||
"pois": [self._format_poi(poi) for poi in pois]
|
||||
}
|
||||
|
||||
def _format_poi(self, poi: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""格式化 POI 数据"""
|
||||
biz_ext = poi.get("biz_ext", {}) or {}
|
||||
return {
|
||||
"id": poi.get("id", ""),
|
||||
"name": poi.get("name", ""),
|
||||
"type": poi.get("type", ""),
|
||||
"address": poi.get("address", ""),
|
||||
"location": poi.get("location", ""),
|
||||
"tel": poi.get("tel", ""),
|
||||
"distance": poi.get("distance", ""),
|
||||
"pname": poi.get("pname", ""),
|
||||
"cityname": poi.get("cityname", ""),
|
||||
"adname": poi.get("adname", ""),
|
||||
"rating": biz_ext.get("rating", ""),
|
||||
"cost": biz_ext.get("cost", "")
|
||||
}
|
||||
|
||||
# ==================== 路径规划 ====================
|
||||
|
||||
async def route_driving(
|
||||
self,
|
||||
origin: str,
|
||||
destination: str,
|
||||
strategy: int = 10,
|
||||
waypoints: str = None,
|
||||
extensions: str = "base"
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
驾车路径规划
|
||||
|
||||
Args:
|
||||
origin: 起点坐标 "lng,lat"
|
||||
destination: 终点坐标 "lng,lat"
|
||||
strategy: 驾车策略(10=躲避拥堵,13=不走高速,14=避免收费)
|
||||
waypoints: 途经点,多个用 ; 分隔
|
||||
extensions: base 或 all
|
||||
|
||||
Returns:
|
||||
路径规划结果
|
||||
"""
|
||||
params = {
|
||||
"origin": origin,
|
||||
"destination": destination,
|
||||
"strategy": strategy,
|
||||
"extensions": extensions
|
||||
}
|
||||
if waypoints:
|
||||
params["waypoints"] = waypoints
|
||||
|
||||
result = await self._request("/v3/direction/driving", params)
|
||||
|
||||
if not result["success"]:
|
||||
return result
|
||||
|
||||
route = result["data"].get("route", {})
|
||||
paths = route.get("paths", [])
|
||||
|
||||
if not paths:
|
||||
return {"success": False, "error": "未找到驾车路线"}
|
||||
|
||||
path = paths[0]
|
||||
return {
|
||||
"success": True,
|
||||
"mode": "driving",
|
||||
"origin": route.get("origin", ""),
|
||||
"destination": route.get("destination", ""),
|
||||
"distance": self._safe_int(path.get("distance", 0)),
|
||||
"duration": self._safe_int(path.get("duration", 0)),
|
||||
"tolls": self._safe_float(path.get("tolls", 0)),
|
||||
"toll_distance": self._safe_int(path.get("toll_distance", 0)),
|
||||
"traffic_lights": self._safe_int(path.get("traffic_lights", 0)),
|
||||
"taxi_cost": self._safe_str(route.get("taxi_cost", "")),
|
||||
"strategy": path.get("strategy", ""),
|
||||
"steps": self._format_driving_steps(path.get("steps", []))
|
||||
}
|
||||
|
||||
async def route_transit(
|
||||
self,
|
||||
origin: str,
|
||||
destination: str,
|
||||
city: str,
|
||||
cityd: str = None,
|
||||
strategy: int = 0,
|
||||
extensions: str = "all"
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
公交路径规划(含火车、地铁)
|
||||
|
||||
Args:
|
||||
origin: 起点坐标 "lng,lat"
|
||||
destination: 终点坐标 "lng,lat"
|
||||
city: 起点城市
|
||||
cityd: 终点城市(跨城时必填)
|
||||
strategy: 0=最快,1=最省钱,2=最少换乘,3=最少步行
|
||||
extensions: base 或 all
|
||||
|
||||
Returns:
|
||||
公交路径规划结果
|
||||
"""
|
||||
params = {
|
||||
"origin": origin,
|
||||
"destination": destination,
|
||||
"city": city,
|
||||
"strategy": strategy,
|
||||
"extensions": extensions
|
||||
}
|
||||
if cityd:
|
||||
params["cityd"] = cityd
|
||||
|
||||
result = await self._request("/v3/direction/transit/integrated", params)
|
||||
|
||||
if not result["success"]:
|
||||
return result
|
||||
|
||||
route = result["data"].get("route", {})
|
||||
transits = route.get("transits", [])
|
||||
|
||||
if not transits:
|
||||
return {"success": False, "error": "未找到公交路线"}
|
||||
|
||||
# 返回前3个方案
|
||||
formatted_transits = []
|
||||
for transit in transits[:3]:
|
||||
segments = transit.get("segments", [])
|
||||
formatted_segments = []
|
||||
|
||||
for seg in segments:
|
||||
# 步行段
|
||||
walking = seg.get("walking", {})
|
||||
if walking and walking.get("distance"):
|
||||
formatted_segments.append({
|
||||
"type": "walking",
|
||||
"distance": self._safe_int(walking.get("distance", 0)),
|
||||
"duration": self._safe_int(walking.get("duration", 0))
|
||||
})
|
||||
|
||||
# 公交/地铁段
|
||||
bus_info = seg.get("bus", {})
|
||||
buslines = bus_info.get("buslines", [])
|
||||
if buslines:
|
||||
line = buslines[0]
|
||||
formatted_segments.append({
|
||||
"type": "bus",
|
||||
"name": self._safe_str(line.get("name", "")),
|
||||
"departure_stop": self._safe_str(line.get("departure_stop", {}).get("name", "")),
|
||||
"arrival_stop": self._safe_str(line.get("arrival_stop", {}).get("name", "")),
|
||||
"via_num": self._safe_int(line.get("via_num", 0)),
|
||||
"distance": self._safe_int(line.get("distance", 0)),
|
||||
"duration": self._safe_int(line.get("duration", 0))
|
||||
})
|
||||
|
||||
# 火车段
|
||||
railway = seg.get("railway", {})
|
||||
if railway and railway.get("name"):
|
||||
formatted_segments.append({
|
||||
"type": "railway",
|
||||
"name": self._safe_str(railway.get("name", "")),
|
||||
"trip": self._safe_str(railway.get("trip", "")),
|
||||
"departure_stop": self._safe_str(railway.get("departure_stop", {}).get("name", "")),
|
||||
"arrival_stop": self._safe_str(railway.get("arrival_stop", {}).get("name", "")),
|
||||
"departure_time": self._safe_str(railway.get("departure_stop", {}).get("time", "")),
|
||||
"arrival_time": self._safe_str(railway.get("arrival_stop", {}).get("time", "")),
|
||||
"distance": self._safe_int(railway.get("distance", 0)),
|
||||
"time": self._safe_str(railway.get("time", ""))
|
||||
})
|
||||
|
||||
formatted_transits.append({
|
||||
"cost": self._safe_str(transit.get("cost", "")),
|
||||
"duration": self._safe_int(transit.get("duration", 0)),
|
||||
"walking_distance": self._safe_int(transit.get("walking_distance", 0)),
|
||||
"segments": formatted_segments
|
||||
})
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"mode": "transit",
|
||||
"origin": route.get("origin", ""),
|
||||
"destination": route.get("destination", ""),
|
||||
"distance": self._safe_int(route.get("distance", 0)),
|
||||
"taxi_cost": self._safe_str(route.get("taxi_cost", "")),
|
||||
"transits": formatted_transits
|
||||
}
|
||||
|
||||
async def route_walking(
|
||||
self,
|
||||
origin: str,
|
||||
destination: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
步行路径规划
|
||||
|
||||
Args:
|
||||
origin: 起点坐标 "lng,lat"
|
||||
destination: 终点坐标 "lng,lat"
|
||||
|
||||
Returns:
|
||||
步行路径规划结果
|
||||
"""
|
||||
params = {
|
||||
"origin": origin,
|
||||
"destination": destination
|
||||
}
|
||||
|
||||
result = await self._request("/v3/direction/walking", params)
|
||||
|
||||
if not result["success"]:
|
||||
return result
|
||||
|
||||
route = result["data"].get("route", {})
|
||||
paths = route.get("paths", [])
|
||||
|
||||
if not paths:
|
||||
return {"success": False, "error": "未找到步行路线"}
|
||||
|
||||
path = paths[0]
|
||||
return {
|
||||
"success": True,
|
||||
"mode": "walking",
|
||||
"origin": route.get("origin", ""),
|
||||
"destination": route.get("destination", ""),
|
||||
"distance": self._safe_int(path.get("distance", 0)),
|
||||
"duration": self._safe_int(path.get("duration", 0))
|
||||
}
|
||||
|
||||
async def route_bicycling(
|
||||
self,
|
||||
origin: str,
|
||||
destination: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
骑行路径规划
|
||||
|
||||
Args:
|
||||
origin: 起点坐标 "lng,lat"
|
||||
destination: 终点坐标 "lng,lat"
|
||||
|
||||
Returns:
|
||||
骑行路径规划结果
|
||||
"""
|
||||
params = {
|
||||
"origin": origin,
|
||||
"destination": destination
|
||||
}
|
||||
|
||||
# 骑行用 v4 接口
|
||||
result = await self._request("/v4/direction/bicycling", params)
|
||||
|
||||
if not result["success"]:
|
||||
return result
|
||||
|
||||
data = result["data"].get("data", {})
|
||||
paths = data.get("paths", [])
|
||||
|
||||
if not paths:
|
||||
return {"success": False, "error": "未找到骑行路线"}
|
||||
|
||||
path = paths[0]
|
||||
return {
|
||||
"success": True,
|
||||
"mode": "bicycling",
|
||||
"origin": data.get("origin", ""),
|
||||
"destination": data.get("destination", ""),
|
||||
"distance": self._safe_int(path.get("distance", 0)),
|
||||
"duration": self._safe_int(path.get("duration", 0))
|
||||
}
|
||||
|
||||
def _format_driving_steps(self, steps: List[Dict]) -> List[Dict]:
|
||||
"""格式化驾车步骤"""
|
||||
return [
|
||||
{
|
||||
"instruction": step.get("instruction", ""),
|
||||
"road": step.get("road", ""),
|
||||
"distance": self._safe_int(step.get("distance", 0)),
|
||||
"duration": self._safe_int(step.get("duration", 0)),
|
||||
"orientation": step.get("orientation", "")
|
||||
}
|
||||
for step in steps[:10] # 只返回前10步
|
||||
]
|
||||
|
||||
# ==================== 距离测量 ====================
|
||||
|
||||
async def get_distance(
|
||||
self,
|
||||
origins: str,
|
||||
destination: str,
|
||||
type: int = 1
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
距离测量
|
||||
|
||||
Args:
|
||||
origins: 起点坐标,多个用 | 分隔
|
||||
destination: 终点坐标
|
||||
type: 0=直线距离,1=驾车距离,3=步行距离
|
||||
|
||||
Returns:
|
||||
距离信息
|
||||
"""
|
||||
params = {
|
||||
"origins": origins,
|
||||
"destination": destination,
|
||||
"type": type
|
||||
}
|
||||
|
||||
result = await self._request("/v3/distance", params)
|
||||
|
||||
if not result["success"]:
|
||||
return result
|
||||
|
||||
results = result["data"].get("results", [])
|
||||
if not results:
|
||||
return {"success": False, "error": "无法计算距离"}
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"results": [
|
||||
{
|
||||
"origin_id": r.get("origin_id", ""),
|
||||
"distance": self._safe_int(r.get("distance", 0)),
|
||||
"duration": self._safe_int(r.get("duration", 0))
|
||||
}
|
||||
for r in results
|
||||
]
|
||||
}
|
||||
|
||||
# ==================== 输入提示 ====================
|
||||
|
||||
async def input_tips(
|
||||
self,
|
||||
keywords: str,
|
||||
city: str = None,
|
||||
citylimit: bool = False,
|
||||
datatype: str = "all"
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
输入提示
|
||||
|
||||
Args:
|
||||
keywords: 查询关键字
|
||||
city: 城市名或 adcode
|
||||
citylimit: 是否仅返回指定城市
|
||||
datatype: all/poi/bus/busline
|
||||
|
||||
Returns:
|
||||
提示列表
|
||||
"""
|
||||
params = {
|
||||
"keywords": keywords,
|
||||
"datatype": datatype
|
||||
}
|
||||
if city:
|
||||
params["city"] = city
|
||||
params["citylimit"] = "true" if citylimit else "false"
|
||||
|
||||
result = await self._request("/v3/assistant/inputtips", params)
|
||||
|
||||
if not result["success"]:
|
||||
return result
|
||||
|
||||
tips = result["data"].get("tips", [])
|
||||
return {
|
||||
"success": True,
|
||||
"tips": [
|
||||
{
|
||||
"id": tip.get("id", ""),
|
||||
"name": tip.get("name", ""),
|
||||
"district": tip.get("district", ""),
|
||||
"adcode": tip.get("adcode", ""),
|
||||
"location": tip.get("location", ""),
|
||||
"address": tip.get("address", "")
|
||||
}
|
||||
for tip in tips
|
||||
if tip.get("location") # 过滤无坐标的结果
|
||||
]
|
||||
}
|
||||
609
plugins/TravelPlanner/main.py
Normal file
@@ -0,0 +1,609 @@
|
||||
"""
|
||||
旅行规划插件
|
||||
|
||||
基于高德地图 API,提供以下功能:
|
||||
- 地点搜索与地理编码
|
||||
- 天气查询(实况 + 4天预报)
|
||||
- 景点/酒店/餐厅搜索
|
||||
- 路径规划(驾车/公交/步行)
|
||||
- 周边搜索
|
||||
|
||||
支持 LLM 函数调用,可与 AIChat 插件配合使用。
|
||||
"""
|
||||
|
||||
import tomllib
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
from loguru import logger
|
||||
|
||||
from utils.plugin_base import PluginBase
|
||||
from .amap_client import AmapClient, AmapConfig
|
||||
|
||||
|
||||
class TravelPlanner(PluginBase):
|
||||
"""旅行规划插件"""
|
||||
|
||||
description = "旅行规划助手,支持天气查询、景点搜索、路线规划"
|
||||
author = "ShiHao"
|
||||
version = "1.0.0"
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.config = None
|
||||
self.amap: AmapClient = None
|
||||
|
||||
async def async_init(self):
|
||||
"""插件异步初始化"""
|
||||
# 读取配置
|
||||
config_path = Path(__file__).parent / "config.toml"
|
||||
with open(config_path, "rb") as f:
|
||||
self.config = tomllib.load(f)
|
||||
|
||||
# 初始化高德 API 客户端
|
||||
amap_config = self.config.get("amap", {})
|
||||
api_key = amap_config.get("api_key", "")
|
||||
secret = amap_config.get("secret", "")
|
||||
|
||||
if not api_key:
|
||||
logger.warning("TravelPlanner: 未配置高德 API Key,请在 config.toml 中设置")
|
||||
else:
|
||||
self.amap = AmapClient(AmapConfig(
|
||||
api_key=api_key,
|
||||
secret=secret,
|
||||
timeout=amap_config.get("timeout", 30)
|
||||
))
|
||||
if secret:
|
||||
logger.success(f"TravelPlanner 插件已加载,API Key: {api_key[:8]}...(已启用数字签名)")
|
||||
else:
|
||||
logger.success(f"TravelPlanner 插件已加载,API Key: {api_key[:8]}...(未配置安全密钥)")
|
||||
|
||||
async def on_disable(self):
|
||||
"""插件禁用时关闭连接"""
|
||||
await super().on_disable()
|
||||
if self.amap:
|
||||
await self.amap.close()
|
||||
logger.info("TravelPlanner: 已关闭高德 API 连接")
|
||||
|
||||
# ==================== LLM 工具定义 ====================
|
||||
|
||||
def get_llm_tools(self) -> List[Dict]:
|
||||
"""返回 LLM 可调用的工具列表"""
|
||||
return [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "search_location",
|
||||
"description": "【旅行工具】将地名转换为坐标和行政区划信息。仅当用户明确询问某个地点的位置信息时使用。",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"address": {
|
||||
"type": "string",
|
||||
"description": "地址或地名,如:北京市、西湖、故宫"
|
||||
},
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "所在城市,可选。填写可提高搜索精度"
|
||||
}
|
||||
},
|
||||
"required": ["address"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "query_weather",
|
||||
"description": "【旅行工具】查询城市天气预报。仅当用户明确询问某城市的天气情况时使用,如'北京天气怎么样'、'杭州明天会下雨吗'。",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "城市名称,如:北京、杭州、上海"
|
||||
},
|
||||
"forecast": {
|
||||
"type": "boolean",
|
||||
"description": "是否查询预报天气。true=未来4天预报,false=当前实况"
|
||||
}
|
||||
},
|
||||
"required": ["city"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "search_poi",
|
||||
"description": "【旅行工具】搜索地点(景点、酒店、餐厅等)。仅当用户明确要求查找某城市的景点、酒店、餐厅等时使用。",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "搜索城市,如:杭州、北京"
|
||||
},
|
||||
"keyword": {
|
||||
"type": "string",
|
||||
"description": "搜索关键词,如:西湖、希尔顿酒店、火锅"
|
||||
},
|
||||
"category": {
|
||||
"type": "string",
|
||||
"enum": ["景点", "酒店", "餐厅", "购物", "交通"],
|
||||
"description": "POI 类别。不填则搜索所有类别"
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "返回结果数量,默认10,最大20"
|
||||
}
|
||||
},
|
||||
"required": ["city"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "search_nearby",
|
||||
"description": "【旅行工具】搜索某地点周边的设施。仅当用户明确要求查找某地点附近的餐厅、酒店等时使用,如'西湖附近有什么好吃的'。",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "中心地点名称,如:西湖、故宫"
|
||||
},
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "所在城市"
|
||||
},
|
||||
"keyword": {
|
||||
"type": "string",
|
||||
"description": "搜索关键词"
|
||||
},
|
||||
"category": {
|
||||
"type": "string",
|
||||
"enum": ["景点", "酒店", "餐厅", "购物", "交通"],
|
||||
"description": "POI 类别"
|
||||
},
|
||||
"radius": {
|
||||
"type": "integer",
|
||||
"description": "搜索半径(米),默认3000,最大50000"
|
||||
}
|
||||
},
|
||||
"required": ["location", "city"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "plan_route",
|
||||
"description": "【旅行工具】规划两地之间的出行路线。仅当用户明确要求规划从A到B的路线时使用,如'从北京到杭州怎么走'、'上海到苏州的高铁'。",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"origin": {
|
||||
"type": "string",
|
||||
"description": "起点地名,如:北京、上海虹桥站"
|
||||
},
|
||||
"destination": {
|
||||
"type": "string",
|
||||
"description": "终点地名,如:杭州、西湖"
|
||||
},
|
||||
"origin_city": {
|
||||
"type": "string",
|
||||
"description": "起点所在城市"
|
||||
},
|
||||
"destination_city": {
|
||||
"type": "string",
|
||||
"description": "终点所在城市(跨城时必填)"
|
||||
},
|
||||
"mode": {
|
||||
"type": "string",
|
||||
"enum": ["driving", "transit", "walking"],
|
||||
"description": "出行方式:driving=驾车,transit=公交/高铁,walking=步行。默认 transit"
|
||||
}
|
||||
},
|
||||
"required": ["origin", "destination", "origin_city"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_travel_info",
|
||||
"description": "【旅行工具】获取目的地城市的旅行信息(天气、景点、交通)。仅当用户明确表示要去某城市旅游并询问相关信息时使用,如'我想去杭州玩,帮我看看'、'北京旅游攻略'。",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"destination": {
|
||||
"type": "string",
|
||||
"description": "目的地城市,如:杭州、成都"
|
||||
},
|
||||
"origin": {
|
||||
"type": "string",
|
||||
"description": "出发城市,如:北京、上海。填写后会规划交通路线"
|
||||
}
|
||||
},
|
||||
"required": ["destination"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
async def execute_llm_tool(
|
||||
self,
|
||||
tool_name: str,
|
||||
arguments: Dict[str, Any],
|
||||
bot,
|
||||
from_wxid: str
|
||||
) -> Dict[str, Any]:
|
||||
"""执行 LLM 工具调用"""
|
||||
|
||||
if not self.amap:
|
||||
return {"success": False, "message": "高德 API 未配置,请联系管理员设置 API Key"}
|
||||
|
||||
try:
|
||||
if tool_name == "search_location":
|
||||
return await self._tool_search_location(arguments)
|
||||
elif tool_name == "query_weather":
|
||||
return await self._tool_query_weather(arguments)
|
||||
elif tool_name == "search_poi":
|
||||
return await self._tool_search_poi(arguments)
|
||||
elif tool_name == "search_nearby":
|
||||
return await self._tool_search_nearby(arguments)
|
||||
elif tool_name == "plan_route":
|
||||
return await self._tool_plan_route(arguments)
|
||||
elif tool_name == "get_travel_info":
|
||||
return await self._tool_get_travel_info(arguments)
|
||||
else:
|
||||
return {"success": False, "message": f"未知工具: {tool_name}"}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"TravelPlanner 工具执行失败: {tool_name}, 错误: {e}")
|
||||
return {"success": False, "message": f"工具执行失败: {str(e)}"}
|
||||
|
||||
# ==================== 工具实现 ====================
|
||||
|
||||
async def _tool_search_location(self, args: Dict) -> Dict:
|
||||
"""地点搜索工具"""
|
||||
address = args.get("address", "")
|
||||
city = args.get("city")
|
||||
|
||||
result = await self.amap.geocode(address, city)
|
||||
|
||||
if not result["success"]:
|
||||
return {"success": False, "message": result.get("error", "地点搜索失败")}
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"已找到地点:{result['formatted_address']}",
|
||||
"data": {
|
||||
"name": address,
|
||||
"formatted_address": result["formatted_address"],
|
||||
"location": result["location"],
|
||||
"province": result["province"],
|
||||
"city": result["city"],
|
||||
"district": result["district"],
|
||||
"adcode": result["adcode"]
|
||||
}
|
||||
}
|
||||
|
||||
async def _tool_query_weather(self, args: Dict) -> Dict:
|
||||
"""天气查询工具"""
|
||||
city = args.get("city", "")
|
||||
forecast = args.get("forecast", True)
|
||||
|
||||
extensions = "all" if forecast else "base"
|
||||
result = await self.amap.get_weather(city, extensions)
|
||||
|
||||
if not result["success"]:
|
||||
return {"success": False, "message": result.get("error", "天气查询失败")}
|
||||
|
||||
if result["type"] == "live":
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"{result['city']}当前天气:{result['weather']},{result['temperature']}℃",
|
||||
"data": {
|
||||
"city": result["city"],
|
||||
"weather": result["weather"],
|
||||
"temperature": result["temperature"],
|
||||
"humidity": result["humidity"],
|
||||
"wind": f"{result['winddirection']}风 {result['windpower']}级",
|
||||
"reporttime": result["reporttime"]
|
||||
}
|
||||
}
|
||||
else:
|
||||
forecasts = result["forecasts"]
|
||||
weather_text = "\n".join([
|
||||
f"- {f['date']} 星期{self._weekday_cn(f['week'])}:白天{f['dayweather']} {f['daytemp']}℃,夜间{f['nightweather']} {f['nighttemp']}℃"
|
||||
for f in forecasts
|
||||
])
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"{result['city']}未来天气预报:\n{weather_text}",
|
||||
"data": {
|
||||
"city": result["city"],
|
||||
"province": result["province"],
|
||||
"forecasts": forecasts,
|
||||
"reporttime": result["reporttime"]
|
||||
}
|
||||
}
|
||||
|
||||
async def _tool_search_poi(self, args: Dict) -> Dict:
|
||||
"""POI 搜索工具"""
|
||||
city = args.get("city", "")
|
||||
keyword = args.get("keyword")
|
||||
category = args.get("category")
|
||||
limit = min(args.get("limit", 10), 20)
|
||||
|
||||
# 获取 POI 类型代码
|
||||
types = None
|
||||
if category:
|
||||
poi_types = self.config.get("poi_types", {})
|
||||
types = poi_types.get(category)
|
||||
|
||||
result = await self.amap.search_poi(
|
||||
keywords=keyword,
|
||||
types=types,
|
||||
city=city,
|
||||
citylimit=True,
|
||||
offset=limit
|
||||
)
|
||||
|
||||
if not result["success"]:
|
||||
return {"success": False, "message": result.get("error", "搜索失败")}
|
||||
|
||||
pois = result["pois"]
|
||||
if not pois:
|
||||
return {"success": False, "message": f"在{city}未找到相关地点"}
|
||||
|
||||
# 格式化输出
|
||||
poi_list = []
|
||||
for i, poi in enumerate(pois, 1):
|
||||
info = f"{i}. {poi['name']}"
|
||||
if poi.get("address"):
|
||||
info += f" - {poi['address']}"
|
||||
if poi.get("rating"):
|
||||
info += f" ⭐{poi['rating']}"
|
||||
if poi.get("cost"):
|
||||
info += f" 人均¥{poi['cost']}"
|
||||
poi_list.append(info)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"在{city}找到{len(pois)}个结果:\n" + "\n".join(poi_list),
|
||||
"data": {
|
||||
"city": city,
|
||||
"category": category or "全部",
|
||||
"count": len(pois),
|
||||
"pois": pois
|
||||
}
|
||||
}
|
||||
|
||||
async def _tool_search_nearby(self, args: Dict) -> Dict:
|
||||
"""周边搜索工具"""
|
||||
location_name = args.get("location", "")
|
||||
city = args.get("city", "")
|
||||
keyword = args.get("keyword")
|
||||
category = args.get("category")
|
||||
radius = min(args.get("radius", 3000), 50000)
|
||||
|
||||
# 先获取中心点坐标
|
||||
geo_result = await self.amap.geocode(location_name, city)
|
||||
if not geo_result["success"]:
|
||||
return {"success": False, "message": f"无法定位 {location_name}"}
|
||||
|
||||
location = geo_result["location"]
|
||||
|
||||
# 获取 POI 类型代码
|
||||
types = None
|
||||
if category:
|
||||
poi_types = self.config.get("poi_types", {})
|
||||
types = poi_types.get(category)
|
||||
|
||||
result = await self.amap.search_around(
|
||||
location=location,
|
||||
keywords=keyword,
|
||||
types=types,
|
||||
radius=radius,
|
||||
offset=10
|
||||
)
|
||||
|
||||
if not result["success"]:
|
||||
return {"success": False, "message": result.get("error", "周边搜索失败")}
|
||||
|
||||
pois = result["pois"]
|
||||
if not pois:
|
||||
return {"success": False, "message": f"在{location_name}周边未找到相关地点"}
|
||||
|
||||
# 格式化输出
|
||||
poi_list = []
|
||||
for i, poi in enumerate(pois, 1):
|
||||
info = f"{i}. {poi['name']}"
|
||||
if poi.get("distance"):
|
||||
info += f" ({poi['distance']}米)"
|
||||
if poi.get("rating"):
|
||||
info += f" ⭐{poi['rating']}"
|
||||
poi_list.append(info)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"{location_name}周边{radius}米内找到{len(pois)}个结果:\n" + "\n".join(poi_list),
|
||||
"data": {
|
||||
"center": location_name,
|
||||
"radius": radius,
|
||||
"category": category or "全部",
|
||||
"count": len(pois),
|
||||
"pois": pois
|
||||
}
|
||||
}
|
||||
|
||||
async def _tool_plan_route(self, args: Dict) -> Dict:
|
||||
"""路线规划工具"""
|
||||
origin = args.get("origin", "")
|
||||
destination = args.get("destination", "")
|
||||
origin_city = args.get("origin_city", "")
|
||||
destination_city = args.get("destination_city", origin_city)
|
||||
mode = args.get("mode", "transit")
|
||||
|
||||
# 获取起终点坐标
|
||||
origin_geo = await self.amap.geocode(origin, origin_city)
|
||||
if not origin_geo["success"]:
|
||||
return {"success": False, "message": f"无法定位起点:{origin}"}
|
||||
|
||||
dest_geo = await self.amap.geocode(destination, destination_city)
|
||||
if not dest_geo["success"]:
|
||||
return {"success": False, "message": f"无法定位终点:{destination}"}
|
||||
|
||||
origin_loc = origin_geo["location"]
|
||||
dest_loc = dest_geo["location"]
|
||||
|
||||
# 根据模式规划路线
|
||||
if mode == "driving":
|
||||
result = await self.amap.route_driving(origin_loc, dest_loc)
|
||||
if not result["success"]:
|
||||
return {"success": False, "message": result.get("error", "驾车路线规划失败")}
|
||||
|
||||
distance_km = result["distance"] / 1000
|
||||
duration_h = result["duration"] / 3600
|
||||
|
||||
msg = f"🚗 驾车路线:{origin} → {destination}\n"
|
||||
msg += f"距离:{distance_km:.1f}公里,预计{self._format_duration(result['duration'])}\n"
|
||||
if result["tolls"]:
|
||||
msg += f"收费:约{result['tolls']}元\n"
|
||||
if result["taxi_cost"]:
|
||||
msg += f"打车费用:约{result['taxi_cost']}元"
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": msg,
|
||||
"data": result
|
||||
}
|
||||
|
||||
elif mode == "transit":
|
||||
result = await self.amap.route_transit(
|
||||
origin_loc, dest_loc,
|
||||
city=origin_city,
|
||||
cityd=destination_city if destination_city != origin_city else None
|
||||
)
|
||||
if not result["success"]:
|
||||
return {"success": False, "message": result.get("error", "公交路线规划失败")}
|
||||
|
||||
msg = f"🚄 公交/高铁路线:{origin} → {destination}\n"
|
||||
|
||||
for i, transit in enumerate(result["transits"][:2], 1):
|
||||
msg += f"\n方案{i}:{self._format_duration(transit['duration'])}"
|
||||
if transit.get("cost"):
|
||||
msg += f",约{transit['cost']}元"
|
||||
msg += "\n"
|
||||
|
||||
for seg in transit["segments"]:
|
||||
if seg["type"] == "walking" and seg["distance"] > 100:
|
||||
msg += f" 🚶 步行{seg['distance']}米\n"
|
||||
elif seg["type"] == "bus":
|
||||
msg += f" 🚌 {seg['name']}:{seg['departure_stop']} → {seg['arrival_stop']}({seg['via_num']}站)\n"
|
||||
elif seg["type"] == "railway":
|
||||
msg += f" 🚄 {seg['trip']} {seg['name']}:{seg['departure_stop']} {seg.get('departure_time', '')} → {seg['arrival_stop']} {seg.get('arrival_time', '')}\n"
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": msg.strip(),
|
||||
"data": result
|
||||
}
|
||||
|
||||
elif mode == "walking":
|
||||
result = await self.amap.route_walking(origin_loc, dest_loc)
|
||||
if not result["success"]:
|
||||
return {"success": False, "message": result.get("error", "步行路线规划失败")}
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"🚶 步行路线:{origin} → {destination}\n距离:{result['distance']}米,预计{self._format_duration(result['duration'])}",
|
||||
"data": result
|
||||
}
|
||||
|
||||
return {"success": False, "message": f"不支持的出行方式:{mode}"}
|
||||
|
||||
async def _tool_get_travel_info(self, args: Dict) -> Dict:
|
||||
"""一键获取旅行信息"""
|
||||
destination = args.get("destination", "")
|
||||
origin = args.get("origin")
|
||||
|
||||
info = {"destination": destination}
|
||||
msg_parts = [f"📍 {destination} 旅行信息\n"]
|
||||
|
||||
# 1. 查询天气
|
||||
weather_result = await self.amap.get_weather(destination, "all")
|
||||
if weather_result["success"]:
|
||||
info["weather"] = weather_result
|
||||
msg_parts.append("🌤️ 天气预报:")
|
||||
for f in weather_result["forecasts"][:3]:
|
||||
msg_parts.append(f" {f['date']} {f['dayweather']} {f['nighttemp']}~{f['daytemp']}℃")
|
||||
|
||||
# 2. 搜索热门景点
|
||||
poi_result = await self.amap.search_poi(
|
||||
types="110000", # 景点
|
||||
city=destination,
|
||||
citylimit=True,
|
||||
offset=5
|
||||
)
|
||||
if poi_result["success"] and poi_result["pois"]:
|
||||
info["attractions"] = poi_result["pois"]
|
||||
msg_parts.append("\n🏞️ 热门景点:")
|
||||
for poi in poi_result["pois"][:5]:
|
||||
rating = f" ⭐{poi['rating']}" if poi.get("rating") else ""
|
||||
msg_parts.append(f" • {poi['name']}{rating}")
|
||||
|
||||
# 3. 规划交通路线(如果提供了出发地)
|
||||
if origin:
|
||||
origin_geo = await self.amap.geocode(origin)
|
||||
dest_geo = await self.amap.geocode(destination)
|
||||
|
||||
if origin_geo["success"] and dest_geo["success"]:
|
||||
route_result = await self.amap.route_transit(
|
||||
origin_geo["location"],
|
||||
dest_geo["location"],
|
||||
city=origin_geo.get("city", origin),
|
||||
cityd=dest_geo.get("city", destination)
|
||||
)
|
||||
|
||||
if route_result["success"] and route_result["transits"]:
|
||||
info["route"] = route_result
|
||||
transit = route_result["transits"][0]
|
||||
msg_parts.append(f"\n🚄 从{origin}出发:")
|
||||
msg_parts.append(f" 预计{self._format_duration(transit['duration'])}")
|
||||
|
||||
# 显示主要交通工具
|
||||
for seg in transit["segments"]:
|
||||
if seg["type"] == "railway":
|
||||
msg_parts.append(f" {seg['trip']}:{seg['departure_stop']} → {seg['arrival_stop']}")
|
||||
break
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "\n".join(msg_parts),
|
||||
"data": info
|
||||
}
|
||||
|
||||
# ==================== 辅助方法 ====================
|
||||
|
||||
def _weekday_cn(self, week: str) -> str:
|
||||
"""星期数字转中文"""
|
||||
mapping = {"1": "一", "2": "二", "3": "三", "4": "四", "5": "五", "6": "六", "7": "日"}
|
||||
return mapping.get(str(week), week)
|
||||
|
||||
def _format_duration(self, seconds: int) -> str:
|
||||
"""格式化时长"""
|
||||
if seconds < 60:
|
||||
return f"{seconds}秒"
|
||||
elif seconds < 3600:
|
||||
return f"{seconds // 60}分钟"
|
||||
else:
|
||||
hours = seconds // 3600
|
||||
minutes = (seconds % 3600) // 60
|
||||
if minutes:
|
||||
return f"{hours}小时{minutes}分钟"
|
||||
return f"{hours}小时"
|
||||