diff --git a/plugins/AIChat/main.py b/plugins/AIChat/main.py index 4378c17..c16da5f 100644 --- a/plugins/AIChat/main.py +++ b/plugins/AIChat/main.py @@ -2831,10 +2831,64 @@ class AIChat(PluginBase): refer_type = int(refer_type_elem.text) if refer_type_elem is not None and refer_type_elem.text else 0 logger.debug(f"被引用消息类型: {refer_type}") - # 纯文本消息不需要处理(type=1) + # 纯文本消息(type=1):如果@了机器人,转发给 AI 处理 if refer_type == 1: - logger.debug("引用的是纯文本消息,跳过") - return True + if self._should_reply_quote(message, title_text): + # 获取被引用的文本内容 + refer_content_elem = refermsg.find("content") + refer_text = refer_content_elem.text.strip() if refer_content_elem is not None and refer_content_elem.text else "" + + # 获取被引用者昵称 + refer_displayname = refermsg.find("displayname") + refer_nickname = refer_displayname.text if refer_displayname is not None and refer_displayname.text else "某人" + + # 组合消息:引用内容 + 用户评论 + # title_text 格式如 "@瑞依 评价下",需要去掉 @昵称 部分 + import tomllib + with open("main_config.toml", "rb") as f: + main_config = tomllib.load(f) + bot_nickname = main_config.get("Bot", {}).get("nickname", "") + + user_comment = title_text + if bot_nickname: + # 移除 @机器人昵称(可能有空格分隔) + user_comment = user_comment.replace(f"@{bot_nickname}", "").strip() + + # 构造给 AI 的消息 + combined_message = f"[引用 {refer_nickname} 的消息:{refer_text}]\n{user_comment}" + logger.info(f"引用纯文本消息,转发给 AI: {combined_message[:80]}...") + + # 调用 AI 处理 + nickname = await self._get_user_display_label(bot, from_wxid, user_wxid, is_group) + chat_id = from_wxid if is_group else user_wxid + + # 保存用户消息到群组历史记录 + history_enabled = bool(self.store) and self.config.get("history", {}).get("enabled", True) + sync_bot_messages = self.config.get("history", {}).get("sync_bot_messages", True) + if is_group and history_enabled: + history_chat_id = self._get_group_history_chat_id(from_wxid, user_wxid) + await self._add_to_history(history_chat_id, nickname, combined_message, sender_wxid=user_wxid) + + ai_response = await self._call_ai_api( + combined_message, + bot=bot, + from_wxid=from_wxid, + chat_id=chat_id, + nickname=nickname + ) + + if ai_response: + final_response = self._sanitize_llm_output(ai_response) + await bot.send_text(from_wxid, final_response) + + # 保存 AI 回复到群组历史记录 + if is_group and history_enabled and sync_bot_messages: + bot_nickname_display = main_config.get("Bot", {}).get("nickname", "AI") + await self._add_to_history(history_chat_id, bot_nickname_display, final_response, role="assistant") + return False + else: + logger.debug("引用的是纯文本消息且未@机器人,跳过") + return True # 只处理图片(3)、视频(43)、应用消息(49,含聊天记录) if refer_type not in [3, 43, 49]: @@ -3553,40 +3607,50 @@ class AIChat(PluginBase): def _should_reply_quote(self, message: dict, title_text: str) -> bool: """判断是否应该回复引用消息""" is_group = message.get("IsGroup", False) - + # 检查群聊/私聊开关 if is_group and not self.config["behavior"]["reply_group"]: return False if not is_group and not self.config["behavior"]["reply_private"]: return False - + trigger_mode = self.config["behavior"]["trigger_mode"] - + # all模式:回复所有消息 if trigger_mode == "all": return True - + # mention模式:检查是否@了机器人 if trigger_mode == "mention": if is_group: + # 方式1:检查 Ats 字段(普通消息格式) ats = message.get("Ats", []) - if not ats: - return False - + import tomllib with open("main_config.toml", "rb") as f: main_config = tomllib.load(f) bot_wxid = main_config.get("Bot", {}).get("wxid", "") - - return bot_wxid and bot_wxid in ats + bot_nickname = main_config.get("Bot", {}).get("nickname", "") + + # 检查 Ats 列表 + if bot_wxid and bot_wxid in ats: + return True + + # 方式2:检查标题中是否包含 @机器人昵称(引用消息格式) + # 引用消息的 @ 信息在 title 中,如 "@瑞依 评价下" + if bot_nickname and f"@{bot_nickname}" in title_text: + logger.debug(f"引用消息标题中检测到 @{bot_nickname}") + return True + + return False else: return True - + # keyword模式:检查关键词 if trigger_mode == "keyword": keywords = self.config["behavior"]["keywords"] return any(kw in title_text for kw in keywords) - + return False async def _call_ai_api_with_image( diff --git a/plugins/GrokVideo/main.py b/plugins/GrokVideo/main.py index fde9a8d..7c693ce 100644 --- a/plugins/GrokVideo/main.py +++ b/plugins/GrokVideo/main.py @@ -70,7 +70,7 @@ class GrokVideo(PluginBase): # 初始化MinIO客户端 self.minio_client = Minio( - "101.201.65.129:19000", + "115.190.113.141:19000", access_key="admin", secret_key="80012029Lz", secure=False @@ -173,7 +173,7 @@ class GrokVideo(PluginBase): ) # 返回访问URL - url = f"http://101.201.65.129:19000/{self.minio_bucket}/{object_name}" + url = f"http://115.190.113.141:19000/{self.minio_bucket}/{object_name}" logger.info(f"视频上传成功: {url}") return url @@ -293,7 +293,15 @@ class GrokVideo(PluginBase): # 解析 XML 获取标题和引用消息 try: - root = ET.fromstring(content) + xml_content = content.lstrip("\ufeff") + if ":\n" in xml_content: + xml_start = xml_content.find(" 0: + xml_content = xml_content[xml_start:] + + root = ET.fromstring(xml_content) title = root.find(".//title") if title is None or not title.text: return @@ -338,7 +346,13 @@ class GrokVideo(PluginBase): # 解码 HTML 实体 import html - refer_xml = html.unescape(refer_content.text) + refer_xml = html.unescape(refer_content.text).lstrip("\ufeff") + if ":\n" in refer_xml: + xml_start = refer_xml.find(" 0: + refer_xml = refer_xml[xml_start:] refer_root = ET.fromstring(refer_xml) # 提取图片信息 diff --git a/plugins/JimengAI/__init__.py b/plugins/JimengAI/__init__.py deleted file mode 100644 index f64393c..0000000 --- a/plugins/JimengAI/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# 即梦AI绘图插件 diff --git a/plugins/JimengAI/images/jimeng_20251114_114822_f9403e78.jpg b/plugins/JimengAI/images/jimeng_20251114_114822_f9403e78.jpg deleted file mode 100644 index a6ec7bf..0000000 Binary files a/plugins/JimengAI/images/jimeng_20251114_114822_f9403e78.jpg and /dev/null differ diff --git a/plugins/JimengAI/images/jimeng_20251114_114823_b311fa36.jpg b/plugins/JimengAI/images/jimeng_20251114_114823_b311fa36.jpg deleted file mode 100644 index 2a33903..0000000 Binary files a/plugins/JimengAI/images/jimeng_20251114_114823_b311fa36.jpg and /dev/null differ diff --git a/plugins/JimengAI/images/jimeng_20251114_114824_82c1e7a2.jpg b/plugins/JimengAI/images/jimeng_20251114_114824_82c1e7a2.jpg deleted file mode 100644 index d20929c..0000000 Binary files a/plugins/JimengAI/images/jimeng_20251114_114824_82c1e7a2.jpg and /dev/null differ diff --git a/plugins/JimengAI/images/jimeng_20251114_115142_ec504bc6.jpg b/plugins/JimengAI/images/jimeng_20251114_115142_ec504bc6.jpg deleted file mode 100644 index caeba6f..0000000 Binary files a/plugins/JimengAI/images/jimeng_20251114_115142_ec504bc6.jpg and /dev/null differ diff --git a/plugins/JimengAI/images/jimeng_20251114_115143_5ece5e06.jpg b/plugins/JimengAI/images/jimeng_20251114_115143_5ece5e06.jpg deleted file mode 100644 index 2413988..0000000 Binary files a/plugins/JimengAI/images/jimeng_20251114_115143_5ece5e06.jpg and /dev/null differ diff --git a/plugins/JimengAI/images/jimeng_20251114_115144_975d96d1.jpg b/plugins/JimengAI/images/jimeng_20251114_115144_975d96d1.jpg deleted file mode 100644 index e181af7..0000000 Binary files a/plugins/JimengAI/images/jimeng_20251114_115144_975d96d1.jpg and /dev/null differ diff --git a/plugins/JimengAI/images/jimeng_20251114_115146_8c7f21ae.jpg b/plugins/JimengAI/images/jimeng_20251114_115146_8c7f21ae.jpg deleted file mode 100644 index 519360a..0000000 Binary files a/plugins/JimengAI/images/jimeng_20251114_115146_8c7f21ae.jpg and /dev/null differ diff --git a/plugins/JimengAI/images/jimeng_20251114_120116_0adfd7ff.jpg b/plugins/JimengAI/images/jimeng_20251114_120116_0adfd7ff.jpg deleted file mode 100644 index c8cc5ef..0000000 Binary files a/plugins/JimengAI/images/jimeng_20251114_120116_0adfd7ff.jpg and /dev/null differ diff --git a/plugins/JimengAI/images/jimeng_20251114_120117_b78cb9de.jpg b/plugins/JimengAI/images/jimeng_20251114_120117_b78cb9de.jpg deleted file mode 100644 index 6615db5..0000000 Binary files a/plugins/JimengAI/images/jimeng_20251114_120117_b78cb9de.jpg and /dev/null differ diff --git a/plugins/JimengAI/images/jimeng_20251114_120119_24df1268.jpg b/plugins/JimengAI/images/jimeng_20251114_120119_24df1268.jpg deleted file mode 100644 index 2d24476..0000000 Binary files a/plugins/JimengAI/images/jimeng_20251114_120119_24df1268.jpg and /dev/null differ diff --git a/plugins/JimengAI/images/jimeng_20251114_120120_0865c643.jpg b/plugins/JimengAI/images/jimeng_20251114_120120_0865c643.jpg deleted file mode 100644 index 6ae8944..0000000 Binary files a/plugins/JimengAI/images/jimeng_20251114_120120_0865c643.jpg and /dev/null differ diff --git a/plugins/JimengAI/images/jimeng_20251114_120159_5d794eaf.jpg b/plugins/JimengAI/images/jimeng_20251114_120159_5d794eaf.jpg deleted file mode 100644 index 3e62c40..0000000 Binary files a/plugins/JimengAI/images/jimeng_20251114_120159_5d794eaf.jpg and /dev/null differ diff --git a/plugins/JimengAI/images/jimeng_20251114_120202_327e6bd3.jpg b/plugins/JimengAI/images/jimeng_20251114_120202_327e6bd3.jpg deleted file mode 100644 index 2ca533d..0000000 Binary files a/plugins/JimengAI/images/jimeng_20251114_120202_327e6bd3.jpg and /dev/null differ diff --git a/plugins/JimengAI/images/jimeng_20251114_120203_6b969b37.jpg b/plugins/JimengAI/images/jimeng_20251114_120203_6b969b37.jpg deleted file mode 100644 index 6c09f8c..0000000 Binary files a/plugins/JimengAI/images/jimeng_20251114_120203_6b969b37.jpg and /dev/null differ diff --git a/plugins/JimengAI/images/jimeng_20251114_120204_cef46a95.jpg b/plugins/JimengAI/images/jimeng_20251114_120204_cef46a95.jpg deleted file mode 100644 index 68e9f97..0000000 Binary files a/plugins/JimengAI/images/jimeng_20251114_120204_cef46a95.jpg and /dev/null differ diff --git a/plugins/JimengAI/images/jimeng_20251114_120637_9ad90d59.jpg b/plugins/JimengAI/images/jimeng_20251114_120637_9ad90d59.jpg deleted file mode 100644 index 042ddb0..0000000 Binary files a/plugins/JimengAI/images/jimeng_20251114_120637_9ad90d59.jpg and /dev/null differ diff --git a/plugins/JimengAI/images/jimeng_20251114_120638_817a6232.jpg b/plugins/JimengAI/images/jimeng_20251114_120638_817a6232.jpg deleted file mode 100644 index af49635..0000000 Binary files a/plugins/JimengAI/images/jimeng_20251114_120638_817a6232.jpg and /dev/null differ diff --git a/plugins/JimengAI/images/jimeng_20251114_120639_2561d9da.jpg b/plugins/JimengAI/images/jimeng_20251114_120639_2561d9da.jpg deleted file mode 100644 index b168413..0000000 Binary files a/plugins/JimengAI/images/jimeng_20251114_120639_2561d9da.jpg and /dev/null differ diff --git a/plugins/JimengAI/images/jimeng_20251114_121455_eb449f8d.jpg b/plugins/JimengAI/images/jimeng_20251114_121455_eb449f8d.jpg deleted file mode 100644 index 94f8f03..0000000 Binary files a/plugins/JimengAI/images/jimeng_20251114_121455_eb449f8d.jpg and /dev/null differ diff --git a/plugins/JimengAI/images/jimeng_20251114_121456_c6c95b59.jpg b/plugins/JimengAI/images/jimeng_20251114_121456_c6c95b59.jpg deleted file mode 100644 index be626fa..0000000 Binary files a/plugins/JimengAI/images/jimeng_20251114_121456_c6c95b59.jpg and /dev/null differ diff --git a/plugins/JimengAI/images/jimeng_20251114_121457_b9d54377.jpg b/plugins/JimengAI/images/jimeng_20251114_121457_b9d54377.jpg deleted file mode 100644 index 5bfd313..0000000 Binary files a/plugins/JimengAI/images/jimeng_20251114_121457_b9d54377.jpg and /dev/null differ diff --git a/plugins/JimengAI/images/jimeng_20251114_195847_094489b2.jpg b/plugins/JimengAI/images/jimeng_20251114_195847_094489b2.jpg deleted file mode 100644 index 3b170c9..0000000 Binary files a/plugins/JimengAI/images/jimeng_20251114_195847_094489b2.jpg and /dev/null differ diff --git a/plugins/JimengAI/images/jimeng_20251114_195848_0aa28cf8.jpg b/plugins/JimengAI/images/jimeng_20251114_195848_0aa28cf8.jpg deleted file mode 100644 index 0b85c63..0000000 Binary files a/plugins/JimengAI/images/jimeng_20251114_195848_0aa28cf8.jpg and /dev/null differ diff --git a/plugins/JimengAI/main.py b/plugins/JimengAI/main.py deleted file mode 100644 index 9b1b7a6..0000000 --- a/plugins/JimengAI/main.py +++ /dev/null @@ -1,372 +0,0 @@ -""" -即梦AI绘图插件 - -支持命令触发和LLM工具调用 -""" - -import asyncio -import tomllib -import aiohttp -import uuid -from pathlib import Path -from datetime import datetime -from typing import List, Optional -from loguru import logger -from utils.plugin_base import PluginBase -from utils.decorators import on_text_message -from WechatHook import WechatHookClient - - -class TokenState: - """Token轮询状态管理""" - def __init__(self): - self.token_index = 0 - self._lock = asyncio.Lock() - - async def get_next_token(self, tokens: List[str]) -> str: - """获取下一个可用的token""" - async with self._lock: - if not tokens: - raise ValueError("Token列表为空") - return tokens[self.token_index % len(tokens)] - - async def rotate(self, tokens: List[str]): - """轮换到下一个token""" - async with self._lock: - if tokens: - self.token_index = (self.token_index + 1) % len(tokens) - - -class JimengAI(PluginBase): - """即梦AI绘图插件""" - - description = "即梦AI绘图插件 - 支持AI绘图和LLM工具调用" - author = "ShiHao" - version = "1.0.0" - - def __init__(self): - super().__init__() - self.config = None - self.token_state = TokenState() - self.images_dir = None - - async def async_init(self): - """异步初始化""" - config_path = Path(__file__).parent / "config.toml" - with open(config_path, "rb") as f: - self.config = tomllib.load(f) - - # 创建图片目录 - self.images_dir = Path(__file__).parent / "images" - self.images_dir.mkdir(exist_ok=True) - - logger.success(f"即梦AI插件初始化完成,配置了 {len(self.config['api']['tokens'])} 个token") - - async def generate_image(self, prompt: str, **kwargs) -> List[str]: - """ - 生成图像 - - Args: - prompt: 提示词 - **kwargs: 其他参数(model, width, height, sample_strength, negative_prompt) - - Returns: - 图片本地路径列表 - """ - api_config = self.config["api"] - gen_config = self.config["generation"] - - model = kwargs.get("model", gen_config["default_model"]) - width = kwargs.get("width", gen_config["default_width"]) - height = kwargs.get("height", gen_config["default_height"]) - sample_strength = kwargs.get("sample_strength", gen_config["default_sample_strength"]) - negative_prompt = kwargs.get("negative_prompt", gen_config["default_negative_prompt"]) - - # 参数验证 - sample_strength = max(0.0, min(1.0, sample_strength)) - width = max(64, min(2048, width)) - height = max(64, min(2048, height)) - - tokens = api_config["tokens"] - max_retry = gen_config["max_retry_attempts"] - - # 尝试每个token - for token_attempt in range(len(tokens)): - current_token = await self.token_state.get_next_token(tokens) - - for attempt in range(max_retry): - if attempt > 0: - await asyncio.sleep(min(2 ** attempt, 10)) - - try: - url = f"{api_config['base_url'].rstrip('/')}/v1/chat/completions" - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {current_token}" - } - - payload = { - "model": model, - "messages": [{"role": "user", "content": prompt}], - "prompt": prompt, - "negativePrompt": negative_prompt, - "width": width, - "height": height, - "sample_strength": sample_strength - } - - logger.info(f"即梦AI请求: {model}, 尺寸: {width}x{height}, 提示词: {prompt[:50]}...") - - async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=api_config["timeout"])) as session: - async with session.post(url, headers=headers, json=payload) as response: - if response.status == 200: - data = await response.json() - logger.debug(f"API返回数据: {data}") - - if "error" in data: - logger.error(f"API错误: {data['error']}") - continue - - # 提取图片URL - image_paths = await self._extract_images(data) - - if image_paths: - logger.success(f"成功生成 {len(image_paths)} 张图像") - return image_paths - else: - logger.warning(f"未找到图像数据,API返回: {str(data)[:200]}") - continue - - elif response.status == 401: - logger.warning("Token认证失败,尝试下一个token") - break - elif response.status == 429: - logger.warning("请求频率限制,等待后重试") - await asyncio.sleep(5) - continue - else: - error_text = await response.text() - logger.error(f"API请求失败: {response.status}, {error_text[:200]}") - continue - - except asyncio.TimeoutError: - logger.warning(f"请求超时,重试中... ({attempt + 1}/{max_retry})") - continue - except Exception as e: - logger.error(f"请求异常: {e}") - continue - - # 当前token失败,轮换 - await self.token_state.rotate(tokens) - - logger.error("所有token都失败了") - return [] - - async def _extract_images(self, data: dict) -> List[str]: - """从API响应中提取图片""" - import re - image_paths = [] - - # 格式1: OpenAI格式的choices - if "choices" in data and data["choices"]: - for choice in data["choices"]: - if "message" in choice and "content" in choice["message"]: - content = choice["message"]["content"] - if "https://" in content: - urls = re.findall(r'https://[^\s\)]+', content) - for url in urls: - path = await self._download_image(url) - if path: - image_paths.append(path) - - # 格式2: data数组 - elif "data" in data: - data_list = data["data"] if isinstance(data["data"], list) else [data["data"]] - for item in data_list: - if isinstance(item, str) and item.startswith("http"): - path = await self._download_image(item) - if path: - image_paths.append(path) - elif isinstance(item, dict) and "url" in item: - path = await self._download_image(item["url"]) - if path: - image_paths.append(path) - - # 格式3: images数组 - elif "images" in data: - images_list = data["images"] if isinstance(data["images"], list) else [data["images"]] - for item in images_list: - if isinstance(item, str) and item.startswith("http"): - path = await self._download_image(item) - if path: - image_paths.append(path) - elif isinstance(item, dict) and "url" in item: - path = await self._download_image(item["url"]) - if path: - image_paths.append(path) - - # 格式4: 单个URL - elif "url" in data: - path = await self._download_image(data["url"]) - if path: - image_paths.append(path) - - return image_paths - - async def _download_image(self, url: str) -> Optional[str]: - """下载图片到本地""" - try: - async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=30)) as session: - async with session.get(url) as response: - if response.status == 200: - content = await response.read() - - # 生成文件名 - ts = datetime.now().strftime("%Y%m%d_%H%M%S") - uid = uuid.uuid4().hex[:8] - file_path = self.images_dir / f"jimeng_{ts}_{uid}.jpg" - - # 保存文件 - with open(file_path, "wb") as f: - f.write(content) - - logger.info(f"图片下载成功: {file_path}") - return str(file_path) - except Exception as e: - logger.error(f"下载图片失败: {e}") - return None - - @on_text_message(priority=70) - async def handle_message(self, bot: WechatHookClient, message: dict): - """处理文本消息""" - if not self.config["behavior"]["enable_command"]: - return True - - content = message.get("Content", "").strip() - from_wxid = message.get("FromWxid", "") - is_group = message.get("IsGroup", False) - - # 检查群聊/私聊开关 - if is_group and not self.config["behavior"]["enable_group"]: - return True - if not is_group and not self.config["behavior"]["enable_private"]: - return True - - # 检查是否是绘图命令(精确匹配命令+空格+提示词) - keywords = self.config["behavior"]["command_keywords"] - matched_keyword = None - for keyword in keywords: - if content.startswith(keyword + " "): - matched_keyword = keyword - break - - if not matched_keyword: - return True - - # 提取提示词 - prompt = content[len(matched_keyword):].strip() - if not prompt: - await bot.send_text(from_wxid, "❌ 请提供绘图提示词\n用法: /绘图 <提示词>") - return False - - logger.info(f"收到绘图请求: {prompt[:50]}...") - - # 发送处理中提示 - await bot.send_text(from_wxid, "🎨 正在为您生成图像,请稍候...") - - try: - # 生成图像 - image_paths = await self.generate_image(prompt) - - if image_paths: - # 直接发送图片 - await bot.send_image(from_wxid, image_paths[0]) - logger.success(f"绘图成功,已发送图片") - else: - await bot.send_text(from_wxid, "❌ 图像生成失败,请稍后重试") - - except Exception as e: - logger.error(f"绘图处理失败: {e}") - await bot.send_text(from_wxid, f"❌ 处理失败: {str(e)}") - - return False - - def get_llm_tools(self) -> List[dict]: - """ - 返回LLM工具定义 - 供AIChat插件调用 - """ - if not self.config["llm_tool"]["enabled"]: - return [] - - return [{ - "type": "function", - "function": { - "name": self.config["llm_tool"]["tool_name"], - "description": self.config["llm_tool"]["tool_description"], - "parameters": { - "type": "object", - "properties": { - "prompt": { - "type": "string", - "description": "图像生成提示词,描述想要生成的图像内容" - }, - "width": { - "type": "integer", - "description": "图像宽度(64-2048),默认1024", - "default": 1024 - }, - "height": { - "type": "integer", - "description": "图像高度(64-2048),默认1024", - "default": 1024 - } - }, - "required": ["prompt"] - } - } - }] - - async def execute_llm_tool(self, tool_name: str, arguments: dict, bot: WechatHookClient, from_wxid: str) -> dict: - """ - 执行LLM工具调用 - 供AIChat插件调用 - - Returns: - {"success": bool, "message": str, "images": List[str]} - """ - expected_tool_name = self.config["llm_tool"]["tool_name"] - logger.info(f"JimengAI工具检查: 收到={tool_name}, 期望={expected_tool_name}") - - if tool_name != expected_tool_name: - return None # 不是本插件的工具,返回None让其他插件处理 - - try: - prompt = arguments.get("prompt") - if not prompt: - return {"success": False, "message": "缺少提示词参数"} - - logger.info(f"LLM工具调用绘图: {prompt[:50]}...") - - # 生成图像(使用配置的默认尺寸) - gen_config = self.config["generation"] - image_paths = await self.generate_image( - prompt=prompt, - width=arguments.get("width", gen_config["default_width"]), - height=arguments.get("height", gen_config["default_height"]) - ) - - if image_paths: - # 直接发送图片 - await bot.send_image(from_wxid, image_paths[0]) - return { - "success": True, - "message": "已生成并发送图像", - "images": [image_paths[0]] - } - else: - return {"success": False, "message": "图像生成失败"} - - except Exception as e: - logger.error(f"LLM工具执行失败: {e}") - return {"success": False, "message": f"执行失败: {str(e)}"} diff --git a/plugins/Kiira2AI/__init__.py b/plugins/Kiira2AI/__init__.py deleted file mode 100644 index 21751ea..0000000 --- a/plugins/Kiira2AI/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Kiira2 AI绘图插件""" diff --git a/plugins/Kiira2AI/main.py b/plugins/Kiira2AI/main.py deleted file mode 100644 index 283c518..0000000 --- a/plugins/Kiira2AI/main.py +++ /dev/null @@ -1,350 +0,0 @@ -""" -Kiira2 AI绘图插件 - -支持命令触发和LLM工具调用 -""" - -import asyncio -import tomllib -import httpx -import uuid -from pathlib import Path -from datetime import datetime -from typing import List, Optional -from loguru import logger -from utils.plugin_base import PluginBase -from utils.decorators import on_text_message -from WechatHook import WechatHookClient - - -class TokenState: - """Token轮询状态管理""" - def __init__(self): - self.token_index = 0 - self._lock = asyncio.Lock() - - async def get_next_token(self, tokens: List[str]) -> str: - """获取下一个可用的token""" - async with self._lock: - if not tokens: - raise ValueError("Token列表为空") - return tokens[self.token_index % len(tokens)] - - async def rotate(self, tokens: List[str]): - """轮换到下一个token""" - async with self._lock: - if tokens: - self.token_index = (self.token_index + 1) % len(tokens) - - -class Kiira2AI(PluginBase): - """Kiira2 AI绘图插件""" - - description = "Kiira2 AI绘图插件 - 支持AI绘图和LLM工具调用" - author = "ShiHao" - version = "1.0.0" - - def __init__(self): - super().__init__() - self.config = None - self.token_state = TokenState() - self.images_dir = None - - async def async_init(self): - """异步初始化""" - config_path = Path(__file__).parent / "config.toml" - with open(config_path, "rb") as f: - self.config = tomllib.load(f) - - # 创建图片目录 - self.images_dir = Path(__file__).parent / "images" - self.images_dir.mkdir(exist_ok=True) - - logger.success(f"Kiira2 AI插件初始化完成,配置了 {len(self.config['api']['tokens'])} 个token") - - async def generate_image(self, prompt: str, **kwargs) -> List[str]: - """ - 生成图像 - - Args: - prompt: 提示词 - **kwargs: 其他参数(model) - - Returns: - 图片本地路径列表 - """ - api_config = self.config["api"] - gen_config = self.config["generation"] - - model = kwargs.get("model", gen_config["default_model"]) - tokens = api_config["tokens"] - max_retry = gen_config["max_retry_attempts"] - - # 尝试每个token - for token_attempt in range(len(tokens)): - current_token = await self.token_state.get_next_token(tokens) - - for attempt in range(max_retry): - if attempt > 0: - await asyncio.sleep(min(2 ** attempt, 10)) - - try: - url = f"{api_config['base_url'].rstrip('/')}/v1/chat/completions" - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {current_token}" - } - - payload = { - "model": model, - "messages": [{"role": "user", "content": prompt}], - "stream": False - } - - logger.info(f"Kiira2 AI请求: {model}, 提示词: {prompt[:50]}...") - - timeout = httpx.Timeout(connect=10.0, read=api_config["timeout"], write=10.0, pool=10.0) - - # 配置代理 - proxy = None - proxy_config = self.config.get("proxy", {}) - if proxy_config.get("enabled", False): - proxy_type = proxy_config.get("type", "socks5") - proxy_host = proxy_config.get("host", "127.0.0.1") - proxy_port = proxy_config.get("port", 7890) - proxy = f"{proxy_type}://{proxy_host}:{proxy_port}" - logger.info(f"使用代理: {proxy}") - - async with httpx.AsyncClient(timeout=timeout, proxy=proxy) as client: - response = await client.post(url, json=payload, headers=headers) - - if response.status_code == 200: - data = response.json() - logger.debug(f"API返回数据: {data}") - - if "error" in data: - logger.error(f"API错误: {data['error']}") - continue - - # 检查是否返回空content(图片还在生成中) - if "choices" in data and data["choices"]: - message = data["choices"][0].get("message", {}) - content = message.get("content", "") - video_url = message.get("video_url") - - # 如果content为空且没有video_url,说明还在生成,等待后重试 - if not content and not video_url: - wait_time = min(10 + attempt * 5, 30) - logger.info(f"图片生成中,等待 {wait_time} 秒后重试...") - await asyncio.sleep(wait_time) - continue - - # 提取图片URL - image_paths = await self._extract_images(data) - - if image_paths: - logger.success(f"成功生成 {len(image_paths)} 张图像") - return image_paths - else: - logger.warning(f"未找到图像数据,API返回: {str(data)[:500]}") - continue - - elif response.status_code == 401: - logger.warning("Token认证失败,尝试下一个token") - break - elif response.status_code == 429: - logger.warning("请求频率限制,等待后重试") - await asyncio.sleep(5) - continue - else: - error_text = response.text - logger.error(f"API请求失败: {response.status_code}, {error_text[:200]}") - continue - - except asyncio.TimeoutError: - logger.warning(f"请求超时,重试中... ({attempt + 1}/{max_retry})") - continue - except Exception as e: - logger.error(f"请求异常: {e}") - continue - - # 当前token失败,轮换 - await self.token_state.rotate(tokens) - - logger.error("所有token都失败了") - return [] - - async def _extract_images(self, data: dict) -> List[str]: - """从API响应中提取图片(只提取图片,忽略文字)""" - import re - image_paths = [] - - # OpenAI格式的choices - if "choices" in data and data["choices"]: - for choice in data["choices"]: - message = choice.get("message", {}) - - # 检查video_url字段(实际包含图片URL) - if "video_url" in message: - video_url = message["video_url"] - if isinstance(video_url, list) and video_url: - url = video_url[0] - if isinstance(url, str) and url.startswith("http"): - path = await self._download_image(url) - if path: - image_paths.append(path) - - # 检查content字段 - if "content" in message and not image_paths: - content = message["content"] - if content and "http" in content: - urls = re.findall(r'https?://[^\s\)\]"]+', content) - for url in urls: - path = await self._download_image(url) - if path: - image_paths.append(path) - - return image_paths - - async def _download_image(self, url: str) -> Optional[str]: - """下载图片到本地""" - try: - timeout = httpx.Timeout(connect=10.0, read=30.0, write=10.0, pool=10.0) - - # 配置代理 - proxy = None - proxy_config = self.config.get("proxy", {}) - if proxy_config.get("enabled", False): - proxy_type = proxy_config.get("type", "socks5") - proxy_host = proxy_config.get("host", "127.0.0.1") - proxy_port = proxy_config.get("port", 7890) - proxy = f"{proxy_type}://{proxy_host}:{proxy_port}" - - async with httpx.AsyncClient(timeout=timeout, proxy=proxy) as client: - response = await client.get(url) - response.raise_for_status() - - # 生成文件名 - ts = datetime.now().strftime("%Y%m%d_%H%M%S") - uid = uuid.uuid4().hex[:8] - file_path = self.images_dir / f"kiira2_{ts}_{uid}.jpg" - - # 保存文件 - with open(file_path, "wb") as f: - f.write(response.content) - - logger.info(f"图片下载成功: {file_path}") - return str(file_path) - except Exception as e: - logger.error(f"下载图片失败: {e}") - return None - - @on_text_message(priority=70) - async def handle_message(self, bot: WechatHookClient, message: dict): - """处理文本消息""" - if not self.config["behavior"]["enable_command"]: - return True - - content = message.get("Content", "").strip() - from_wxid = message.get("FromWxid", "") - is_group = message.get("IsGroup", False) - - # 检查群聊/私聊开关 - if is_group and not self.config["behavior"]["enable_group"]: - return True - if not is_group and not self.config["behavior"]["enable_private"]: - return True - - # 检查是否是绘图命令 - keywords = self.config["behavior"]["command_keywords"] - matched_keyword = None - for keyword in keywords: - if content.startswith(keyword + " "): - matched_keyword = keyword - break - - if not matched_keyword: - return True - - # 提取提示词 - prompt = content[len(matched_keyword):].strip() - if not prompt: - await bot.send_text(from_wxid, "❌ 请提供绘图提示词\n用法: /画画 <提示词>") - return False - - logger.info(f"收到绘图请求: {prompt[:50]}...") - - # 发送处理中提示 - await bot.send_text(from_wxid, "🎨 正在为您生成图像,请稍候...") - - try: - # 生成图像 - image_paths = await self.generate_image(prompt) - - if image_paths: - # 直接发送图片 - await bot.send_image(from_wxid, image_paths[0]) - logger.success(f"绘图成功,已发送图片") - else: - await bot.send_text(from_wxid, "❌ 图像生成失败,请稍后重试") - - except Exception as e: - logger.error(f"绘图处理失败: {e}") - await bot.send_text(from_wxid, f"❌ 处理失败: {str(e)}") - - return False - - def get_llm_tools(self) -> List[dict]: - """返回LLM工具定义""" - if not self.config["llm_tool"]["enabled"]: - return [] - - return [{ - "type": "function", - "function": { - "name": self.config["llm_tool"]["tool_name"], - "description": self.config["llm_tool"]["tool_description"], - "parameters": { - "type": "object", - "properties": { - "prompt": { - "type": "string", - "description": "图像生成提示词,描述想要生成的图像内容" - } - }, - "required": ["prompt"] - } - } - }] - - async def execute_llm_tool(self, tool_name: str, arguments: dict, bot: WechatHookClient, from_wxid: str) -> dict: - """执行LLM工具调用""" - expected_tool_name = self.config["llm_tool"]["tool_name"] - - if tool_name != expected_tool_name: - return None - - try: - prompt = arguments.get("prompt") - if not prompt: - return {"success": False, "message": "缺少提示词参数"} - - logger.info(f"LLM工具调用绘图: {prompt[:50]}...") - - # 生成图像 - image_paths = await self.generate_image(prompt=prompt) - - if image_paths: - # 直接发送图片 - await bot.send_image(from_wxid, image_paths[0]) - return { - "success": True, - "message": "已生成并发送图像", - "images": [image_paths[0]] - } - else: - return {"success": False, "message": "图像生成失败"} - - except Exception as e: - logger.error(f"LLM工具执行失败: {e}") - return {"success": False, "message": f"执行失败: {str(e)}"} diff --git a/plugins/TravelPlanner/__init__.py b/plugins/TravelPlanner/__init__.py new file mode 100644 index 0000000..4e1d0f5 --- /dev/null +++ b/plugins/TravelPlanner/__init__.py @@ -0,0 +1,3 @@ +from .main import TravelPlanner + +__all__ = ["TravelPlanner"] diff --git a/plugins/TravelPlanner/amap_client.py b/plugins/TravelPlanner/amap_client.py new file mode 100644 index 0000000..6f8c515 --- /dev/null +++ b/plugins/TravelPlanner/amap_client.py @@ -0,0 +1,860 @@ +""" +高德地图 API 客户端封装 + +提供以下功能: +- 地理编码:地址 → 坐标 +- 逆地理编码:坐标 → 地址 +- 行政区域查询:获取城市 adcode +- 天气查询:实况/预报天气 +- POI 搜索:关键字搜索、周边搜索 +- 路径规划:驾车、公交、步行、骑行 +""" + +from __future__ import annotations + +import hashlib +import aiohttp +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Literal +from loguru import logger + + +@dataclass +class AmapConfig: + """高德 API 配置""" + api_key: str + secret: str = "" # 安全密钥,用于数字签名 + timeout: int = 30 + + +class AmapClient: + """高德地图 API 客户端""" + + BASE_URL = "https://restapi.amap.com" + + def __init__(self, config: AmapConfig): + self.config = config + self._session: Optional[aiohttp.ClientSession] = None + + @staticmethod + def _safe_int(value, default: int = 0) -> int: + """安全地将值转换为整数,处理列表、None、空字符串等情况""" + if value is None: + return default + if isinstance(value, list): + return default + if isinstance(value, (int, float)): + return int(value) + if isinstance(value, str): + if not value.strip(): + return default + try: + return int(float(value)) + except (ValueError, TypeError): + return default + return default + + @staticmethod + def _safe_float(value, default: float = 0.0) -> float: + """安全地将值转换为浮点数""" + if value is None: + return default + if isinstance(value, list): + return default + if isinstance(value, (int, float)): + return float(value) + if isinstance(value, str): + if not value.strip(): + return default + try: + return float(value) + except (ValueError, TypeError): + return default + return default + + @staticmethod + def _safe_str(value, default: str = "") -> str: + """安全地将值转换为字符串,处理列表等情况""" + if value is None: + return default + if isinstance(value, list): + return default + return str(value) + + async def _get_session(self) -> aiohttp.ClientSession: + """获取或创建 HTTP 会话""" + if self._session is None or self._session.closed: + timeout = aiohttp.ClientTimeout(total=self.config.timeout) + self._session = aiohttp.ClientSession(timeout=timeout) + return self._session + + async def close(self): + """关闭 HTTP 会话""" + if self._session and not self._session.closed: + await self._session.close() + + def _generate_signature(self, params: Dict[str, Any]) -> str: + """ + 生成数字签名 + + 算法: + 1. 将请求参数按参数名升序排序 + 2. 按 key=value 格式拼接,用 & 连接 + 3. 最后拼接上私钥(secret) + 4. 对整个字符串进行 MD5 加密 + + Args: + params: 请求参数(不含 sig) + + Returns: + MD5 签名字符串 + """ + # 按参数名升序排序 + sorted_params = sorted(params.items(), key=lambda x: x[0]) + # 拼接成 key=value&key=value 格式 + param_str = "&".join(f"{k}={v}" for k, v in sorted_params) + # 拼接私钥 + sign_str = param_str + self.config.secret + # MD5 加密 + return hashlib.md5(sign_str.encode('utf-8')).hexdigest() + + async def _request(self, endpoint: str, params: Dict[str, Any]) -> Dict[str, Any]: + """ + 发送 API 请求 + + Args: + endpoint: API 端点路径 + params: 请求参数 + + Returns: + API 响应数据 + """ + params["key"] = self.config.api_key + params["output"] = "JSON" + + # 如果配置了安全密钥,生成数字签名 + if self.config.secret: + params["sig"] = self._generate_signature(params) + + url = f"{self.BASE_URL}{endpoint}" + session = await self._get_session() + + try: + async with session.get(url, params=params) as response: + data = await response.json() + + # 检查 API 状态 + status = data.get("status", "0") + if status != "1": + info = data.get("info", "未知错误") + infocode = data.get("infocode", "") + logger.warning(f"高德 API 错误: {info} (code: {infocode})") + return {"success": False, "error": info, "code": infocode} + + return {"success": True, "data": data} + + except aiohttp.ClientError as e: + logger.error(f"高德 API 请求失败: {e}") + return {"success": False, "error": str(e)} + except Exception as e: + logger.error(f"高德 API 未知错误: {e}") + return {"success": False, "error": str(e)} + + # ==================== 地理编码 ==================== + + async def geocode(self, address: str, city: str = None) -> Dict[str, Any]: + """ + 地理编码:将地址转换为坐标 + + Args: + address: 结构化地址,如 "北京市朝阳区阜通东大街6号" + city: 指定城市(可选) + + Returns: + { + "success": True, + "location": "116.480881,39.989410", + "adcode": "110105", + "city": "北京市", + "district": "朝阳区", + "level": "门址" + } + """ + params = {"address": address} + if city: + params["city"] = city + + result = await self._request("/v3/geocode/geo", params) + + if not result["success"]: + return result + + geocodes = result["data"].get("geocodes", []) + if not geocodes: + return {"success": False, "error": "未找到该地址"} + + geo = geocodes[0] + return { + "success": True, + "location": geo.get("location", ""), + "adcode": geo.get("adcode", ""), + "province": geo.get("province", ""), + "city": geo.get("city", ""), + "district": geo.get("district", ""), + "level": geo.get("level", ""), + "formatted_address": geo.get("formatted_address", address) + } + + async def reverse_geocode( + self, + location: str, + radius: int = 1000, + extensions: str = "base" + ) -> Dict[str, Any]: + """ + 逆地理编码:将坐标转换为地址 + + Args: + location: 经纬度坐标,格式 "lng,lat" + radius: 搜索半径(米),0-3000 + extensions: base 或 all + + Returns: + 地址信息 + """ + params = { + "location": location, + "radius": min(radius, 3000), + "extensions": extensions + } + + result = await self._request("/v3/geocode/regeo", params) + + if not result["success"]: + return result + + regeocode = result["data"].get("regeocode", {}) + address_component = regeocode.get("addressComponent", {}) + + return { + "success": True, + "formatted_address": regeocode.get("formatted_address", ""), + "province": address_component.get("province", ""), + "city": address_component.get("city", ""), + "district": address_component.get("district", ""), + "adcode": address_component.get("adcode", ""), + "township": address_component.get("township", ""), + "pois": regeocode.get("pois", []) if extensions == "all" else [] + } + + # ==================== 行政区域查询 ==================== + + async def get_district( + self, + keywords: str = None, + subdistrict: int = 1 + ) -> Dict[str, Any]: + """ + 行政区域查询 + + Args: + keywords: 查询关键字(城市名、adcode 等) + subdistrict: 返回下级行政区级数(0-3) + + Returns: + 行政区域信息,包含 adcode、citycode 等 + """ + params = {"subdistrict": subdistrict} + if keywords: + params["keywords"] = keywords + + result = await self._request("/v3/config/district", params) + + if not result["success"]: + return result + + districts = result["data"].get("districts", []) + if not districts: + return {"success": False, "error": "未找到该行政区域"} + + district = districts[0] + return { + "success": True, + "name": district.get("name", ""), + "adcode": district.get("adcode", ""), + "citycode": district.get("citycode", ""), + "center": district.get("center", ""), + "level": district.get("level", ""), + "districts": district.get("districts", []) + } + + # ==================== 天气查询 ==================== + + async def get_weather( + self, + city: str, + extensions: Literal["base", "all"] = "all" + ) -> Dict[str, Any]: + """ + 天气查询 + + Args: + city: 城市 adcode(如 110000)或城市名 + extensions: base=实况天气,all=预报天气(未来4天) + + Returns: + 天气信息 + """ + # 如果传入的是城市名,先获取 adcode + if not city.isdigit(): + district_result = await self.get_district(city) + if not district_result["success"]: + return {"success": False, "error": f"无法识别城市: {city}"} + city = district_result["adcode"] + + params = { + "city": city, + "extensions": extensions + } + + result = await self._request("/v3/weather/weatherInfo", params) + + if not result["success"]: + return result + + data = result["data"] + + if extensions == "base": + # 实况天气 + lives = data.get("lives", []) + if not lives: + return {"success": False, "error": "未获取到天气数据"} + + live = lives[0] + return { + "success": True, + "type": "live", + "city": live.get("city", ""), + "weather": live.get("weather", ""), + "temperature": live.get("temperature", ""), + "winddirection": live.get("winddirection", ""), + "windpower": live.get("windpower", ""), + "humidity": live.get("humidity", ""), + "reporttime": live.get("reporttime", "") + } + else: + # 预报天气 + forecasts = data.get("forecasts", []) + if not forecasts: + return {"success": False, "error": "未获取到天气预报数据"} + + forecast = forecasts[0] + casts = forecast.get("casts", []) + + return { + "success": True, + "type": "forecast", + "city": forecast.get("city", ""), + "province": forecast.get("province", ""), + "reporttime": forecast.get("reporttime", ""), + "forecasts": [ + { + "date": cast.get("date", ""), + "week": cast.get("week", ""), + "dayweather": cast.get("dayweather", ""), + "nightweather": cast.get("nightweather", ""), + "daytemp": cast.get("daytemp", ""), + "nighttemp": cast.get("nighttemp", ""), + "daywind": cast.get("daywind", ""), + "nightwind": cast.get("nightwind", ""), + "daypower": cast.get("daypower", ""), + "nightpower": cast.get("nightpower", "") + } + for cast in casts + ] + } + + # ==================== POI 搜索 ==================== + + async def search_poi( + self, + keywords: str = None, + types: str = None, + city: str = None, + citylimit: bool = True, + offset: int = 20, + page: int = 1, + extensions: str = "all" + ) -> Dict[str, Any]: + """ + 关键字搜索 POI + + Args: + keywords: 查询关键字 + types: POI 类型代码,多个用 | 分隔 + city: 城市名或 adcode + citylimit: 是否仅返回指定城市 + offset: 每页数量(建议不超过25) + page: 页码 + extensions: base 或 all + + Returns: + POI 列表 + """ + params = { + "offset": min(offset, 25), + "page": page, + "extensions": extensions + } + + if keywords: + params["keywords"] = keywords + if types: + params["types"] = types + if city: + params["city"] = city + params["citylimit"] = "true" if citylimit else "false" + + result = await self._request("/v3/place/text", params) + + if not result["success"]: + return result + + pois = result["data"].get("pois", []) + count = self._safe_int(result["data"].get("count", 0)) + + return { + "success": True, + "count": count, + "pois": [self._format_poi(poi) for poi in pois] + } + + async def search_around( + self, + location: str, + keywords: str = None, + types: str = None, + radius: int = 3000, + offset: int = 20, + page: int = 1, + extensions: str = "all" + ) -> Dict[str, Any]: + """ + 周边搜索 POI + + Args: + location: 中心点坐标,格式 "lng,lat" + keywords: 查询关键字 + types: POI 类型代码 + radius: 搜索半径(米),0-50000 + offset: 每页数量 + page: 页码 + extensions: base 或 all + + Returns: + POI 列表 + """ + params = { + "location": location, + "radius": min(radius, 50000), + "offset": min(offset, 25), + "page": page, + "extensions": extensions, + "sortrule": "distance" + } + + if keywords: + params["keywords"] = keywords + if types: + params["types"] = types + + result = await self._request("/v3/place/around", params) + + if not result["success"]: + return result + + pois = result["data"].get("pois", []) + count = self._safe_int(result["data"].get("count", 0)) + + return { + "success": True, + "count": count, + "pois": [self._format_poi(poi) for poi in pois] + } + + def _format_poi(self, poi: Dict[str, Any]) -> Dict[str, Any]: + """格式化 POI 数据""" + biz_ext = poi.get("biz_ext", {}) or {} + return { + "id": poi.get("id", ""), + "name": poi.get("name", ""), + "type": poi.get("type", ""), + "address": poi.get("address", ""), + "location": poi.get("location", ""), + "tel": poi.get("tel", ""), + "distance": poi.get("distance", ""), + "pname": poi.get("pname", ""), + "cityname": poi.get("cityname", ""), + "adname": poi.get("adname", ""), + "rating": biz_ext.get("rating", ""), + "cost": biz_ext.get("cost", "") + } + + # ==================== 路径规划 ==================== + + async def route_driving( + self, + origin: str, + destination: str, + strategy: int = 10, + waypoints: str = None, + extensions: str = "base" + ) -> Dict[str, Any]: + """ + 驾车路径规划 + + Args: + origin: 起点坐标 "lng,lat" + destination: 终点坐标 "lng,lat" + strategy: 驾车策略(10=躲避拥堵,13=不走高速,14=避免收费) + waypoints: 途经点,多个用 ; 分隔 + extensions: base 或 all + + Returns: + 路径规划结果 + """ + params = { + "origin": origin, + "destination": destination, + "strategy": strategy, + "extensions": extensions + } + if waypoints: + params["waypoints"] = waypoints + + result = await self._request("/v3/direction/driving", params) + + if not result["success"]: + return result + + route = result["data"].get("route", {}) + paths = route.get("paths", []) + + if not paths: + return {"success": False, "error": "未找到驾车路线"} + + path = paths[0] + return { + "success": True, + "mode": "driving", + "origin": route.get("origin", ""), + "destination": route.get("destination", ""), + "distance": self._safe_int(path.get("distance", 0)), + "duration": self._safe_int(path.get("duration", 0)), + "tolls": self._safe_float(path.get("tolls", 0)), + "toll_distance": self._safe_int(path.get("toll_distance", 0)), + "traffic_lights": self._safe_int(path.get("traffic_lights", 0)), + "taxi_cost": self._safe_str(route.get("taxi_cost", "")), + "strategy": path.get("strategy", ""), + "steps": self._format_driving_steps(path.get("steps", [])) + } + + async def route_transit( + self, + origin: str, + destination: str, + city: str, + cityd: str = None, + strategy: int = 0, + extensions: str = "all" + ) -> Dict[str, Any]: + """ + 公交路径规划(含火车、地铁) + + Args: + origin: 起点坐标 "lng,lat" + destination: 终点坐标 "lng,lat" + city: 起点城市 + cityd: 终点城市(跨城时必填) + strategy: 0=最快,1=最省钱,2=最少换乘,3=最少步行 + extensions: base 或 all + + Returns: + 公交路径规划结果 + """ + params = { + "origin": origin, + "destination": destination, + "city": city, + "strategy": strategy, + "extensions": extensions + } + if cityd: + params["cityd"] = cityd + + result = await self._request("/v3/direction/transit/integrated", params) + + if not result["success"]: + return result + + route = result["data"].get("route", {}) + transits = route.get("transits", []) + + if not transits: + return {"success": False, "error": "未找到公交路线"} + + # 返回前3个方案 + formatted_transits = [] + for transit in transits[:3]: + segments = transit.get("segments", []) + formatted_segments = [] + + for seg in segments: + # 步行段 + walking = seg.get("walking", {}) + if walking and walking.get("distance"): + formatted_segments.append({ + "type": "walking", + "distance": self._safe_int(walking.get("distance", 0)), + "duration": self._safe_int(walking.get("duration", 0)) + }) + + # 公交/地铁段 + bus_info = seg.get("bus", {}) + buslines = bus_info.get("buslines", []) + if buslines: + line = buslines[0] + formatted_segments.append({ + "type": "bus", + "name": self._safe_str(line.get("name", "")), + "departure_stop": self._safe_str(line.get("departure_stop", {}).get("name", "")), + "arrival_stop": self._safe_str(line.get("arrival_stop", {}).get("name", "")), + "via_num": self._safe_int(line.get("via_num", 0)), + "distance": self._safe_int(line.get("distance", 0)), + "duration": self._safe_int(line.get("duration", 0)) + }) + + # 火车段 + railway = seg.get("railway", {}) + if railway and railway.get("name"): + formatted_segments.append({ + "type": "railway", + "name": self._safe_str(railway.get("name", "")), + "trip": self._safe_str(railway.get("trip", "")), + "departure_stop": self._safe_str(railway.get("departure_stop", {}).get("name", "")), + "arrival_stop": self._safe_str(railway.get("arrival_stop", {}).get("name", "")), + "departure_time": self._safe_str(railway.get("departure_stop", {}).get("time", "")), + "arrival_time": self._safe_str(railway.get("arrival_stop", {}).get("time", "")), + "distance": self._safe_int(railway.get("distance", 0)), + "time": self._safe_str(railway.get("time", "")) + }) + + formatted_transits.append({ + "cost": self._safe_str(transit.get("cost", "")), + "duration": self._safe_int(transit.get("duration", 0)), + "walking_distance": self._safe_int(transit.get("walking_distance", 0)), + "segments": formatted_segments + }) + + return { + "success": True, + "mode": "transit", + "origin": route.get("origin", ""), + "destination": route.get("destination", ""), + "distance": self._safe_int(route.get("distance", 0)), + "taxi_cost": self._safe_str(route.get("taxi_cost", "")), + "transits": formatted_transits + } + + async def route_walking( + self, + origin: str, + destination: str + ) -> Dict[str, Any]: + """ + 步行路径规划 + + Args: + origin: 起点坐标 "lng,lat" + destination: 终点坐标 "lng,lat" + + Returns: + 步行路径规划结果 + """ + params = { + "origin": origin, + "destination": destination + } + + result = await self._request("/v3/direction/walking", params) + + if not result["success"]: + return result + + route = result["data"].get("route", {}) + paths = route.get("paths", []) + + if not paths: + return {"success": False, "error": "未找到步行路线"} + + path = paths[0] + return { + "success": True, + "mode": "walking", + "origin": route.get("origin", ""), + "destination": route.get("destination", ""), + "distance": self._safe_int(path.get("distance", 0)), + "duration": self._safe_int(path.get("duration", 0)) + } + + async def route_bicycling( + self, + origin: str, + destination: str + ) -> Dict[str, Any]: + """ + 骑行路径规划 + + Args: + origin: 起点坐标 "lng,lat" + destination: 终点坐标 "lng,lat" + + Returns: + 骑行路径规划结果 + """ + params = { + "origin": origin, + "destination": destination + } + + # 骑行用 v4 接口 + result = await self._request("/v4/direction/bicycling", params) + + if not result["success"]: + return result + + data = result["data"].get("data", {}) + paths = data.get("paths", []) + + if not paths: + return {"success": False, "error": "未找到骑行路线"} + + path = paths[0] + return { + "success": True, + "mode": "bicycling", + "origin": data.get("origin", ""), + "destination": data.get("destination", ""), + "distance": self._safe_int(path.get("distance", 0)), + "duration": self._safe_int(path.get("duration", 0)) + } + + def _format_driving_steps(self, steps: List[Dict]) -> List[Dict]: + """格式化驾车步骤""" + return [ + { + "instruction": step.get("instruction", ""), + "road": step.get("road", ""), + "distance": self._safe_int(step.get("distance", 0)), + "duration": self._safe_int(step.get("duration", 0)), + "orientation": step.get("orientation", "") + } + for step in steps[:10] # 只返回前10步 + ] + + # ==================== 距离测量 ==================== + + async def get_distance( + self, + origins: str, + destination: str, + type: int = 1 + ) -> Dict[str, Any]: + """ + 距离测量 + + Args: + origins: 起点坐标,多个用 | 分隔 + destination: 终点坐标 + type: 0=直线距离,1=驾车距离,3=步行距离 + + Returns: + 距离信息 + """ + params = { + "origins": origins, + "destination": destination, + "type": type + } + + result = await self._request("/v3/distance", params) + + if not result["success"]: + return result + + results = result["data"].get("results", []) + if not results: + return {"success": False, "error": "无法计算距离"} + + return { + "success": True, + "results": [ + { + "origin_id": r.get("origin_id", ""), + "distance": self._safe_int(r.get("distance", 0)), + "duration": self._safe_int(r.get("duration", 0)) + } + for r in results + ] + } + + # ==================== 输入提示 ==================== + + async def input_tips( + self, + keywords: str, + city: str = None, + citylimit: bool = False, + datatype: str = "all" + ) -> Dict[str, Any]: + """ + 输入提示 + + Args: + keywords: 查询关键字 + city: 城市名或 adcode + citylimit: 是否仅返回指定城市 + datatype: all/poi/bus/busline + + Returns: + 提示列表 + """ + params = { + "keywords": keywords, + "datatype": datatype + } + if city: + params["city"] = city + params["citylimit"] = "true" if citylimit else "false" + + result = await self._request("/v3/assistant/inputtips", params) + + if not result["success"]: + return result + + tips = result["data"].get("tips", []) + return { + "success": True, + "tips": [ + { + "id": tip.get("id", ""), + "name": tip.get("name", ""), + "district": tip.get("district", ""), + "adcode": tip.get("adcode", ""), + "location": tip.get("location", ""), + "address": tip.get("address", "") + } + for tip in tips + if tip.get("location") # 过滤无坐标的结果 + ] + } diff --git a/plugins/TravelPlanner/main.py b/plugins/TravelPlanner/main.py new file mode 100644 index 0000000..a6af8fb --- /dev/null +++ b/plugins/TravelPlanner/main.py @@ -0,0 +1,609 @@ +""" +旅行规划插件 + +基于高德地图 API,提供以下功能: +- 地点搜索与地理编码 +- 天气查询(实况 + 4天预报) +- 景点/酒店/餐厅搜索 +- 路径规划(驾车/公交/步行) +- 周边搜索 + +支持 LLM 函数调用,可与 AIChat 插件配合使用。 +""" + +import tomllib +from pathlib import Path +from typing import Any, Dict, List +from loguru import logger + +from utils.plugin_base import PluginBase +from .amap_client import AmapClient, AmapConfig + + +class TravelPlanner(PluginBase): + """旅行规划插件""" + + description = "旅行规划助手,支持天气查询、景点搜索、路线规划" + author = "ShiHao" + version = "1.0.0" + + def __init__(self): + super().__init__() + self.config = None + self.amap: AmapClient = None + + async def async_init(self): + """插件异步初始化""" + # 读取配置 + config_path = Path(__file__).parent / "config.toml" + with open(config_path, "rb") as f: + self.config = tomllib.load(f) + + # 初始化高德 API 客户端 + amap_config = self.config.get("amap", {}) + api_key = amap_config.get("api_key", "") + secret = amap_config.get("secret", "") + + if not api_key: + logger.warning("TravelPlanner: 未配置高德 API Key,请在 config.toml 中设置") + else: + self.amap = AmapClient(AmapConfig( + api_key=api_key, + secret=secret, + timeout=amap_config.get("timeout", 30) + )) + if secret: + logger.success(f"TravelPlanner 插件已加载,API Key: {api_key[:8]}...(已启用数字签名)") + else: + logger.success(f"TravelPlanner 插件已加载,API Key: {api_key[:8]}...(未配置安全密钥)") + + async def on_disable(self): + """插件禁用时关闭连接""" + await super().on_disable() + if self.amap: + await self.amap.close() + logger.info("TravelPlanner: 已关闭高德 API 连接") + + # ==================== LLM 工具定义 ==================== + + def get_llm_tools(self) -> List[Dict]: + """返回 LLM 可调用的工具列表""" + return [ + { + "type": "function", + "function": { + "name": "search_location", + "description": "【旅行工具】将地名转换为坐标和行政区划信息。仅当用户明确询问某个地点的位置信息时使用。", + "parameters": { + "type": "object", + "properties": { + "address": { + "type": "string", + "description": "地址或地名,如:北京市、西湖、故宫" + }, + "city": { + "type": "string", + "description": "所在城市,可选。填写可提高搜索精度" + } + }, + "required": ["address"] + } + } + }, + { + "type": "function", + "function": { + "name": "query_weather", + "description": "【旅行工具】查询城市天气预报。仅当用户明确询问某城市的天气情况时使用,如'北京天气怎么样'、'杭州明天会下雨吗'。", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "城市名称,如:北京、杭州、上海" + }, + "forecast": { + "type": "boolean", + "description": "是否查询预报天气。true=未来4天预报,false=当前实况" + } + }, + "required": ["city"] + } + } + }, + { + "type": "function", + "function": { + "name": "search_poi", + "description": "【旅行工具】搜索地点(景点、酒店、餐厅等)。仅当用户明确要求查找某城市的景点、酒店、餐厅等时使用。", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "搜索城市,如:杭州、北京" + }, + "keyword": { + "type": "string", + "description": "搜索关键词,如:西湖、希尔顿酒店、火锅" + }, + "category": { + "type": "string", + "enum": ["景点", "酒店", "餐厅", "购物", "交通"], + "description": "POI 类别。不填则搜索所有类别" + }, + "limit": { + "type": "integer", + "description": "返回结果数量,默认10,最大20" + } + }, + "required": ["city"] + } + } + }, + { + "type": "function", + "function": { + "name": "search_nearby", + "description": "【旅行工具】搜索某地点周边的设施。仅当用户明确要求查找某地点附近的餐厅、酒店等时使用,如'西湖附近有什么好吃的'。", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "中心地点名称,如:西湖、故宫" + }, + "city": { + "type": "string", + "description": "所在城市" + }, + "keyword": { + "type": "string", + "description": "搜索关键词" + }, + "category": { + "type": "string", + "enum": ["景点", "酒店", "餐厅", "购物", "交通"], + "description": "POI 类别" + }, + "radius": { + "type": "integer", + "description": "搜索半径(米),默认3000,最大50000" + } + }, + "required": ["location", "city"] + } + } + }, + { + "type": "function", + "function": { + "name": "plan_route", + "description": "【旅行工具】规划两地之间的出行路线。仅当用户明确要求规划从A到B的路线时使用,如'从北京到杭州怎么走'、'上海到苏州的高铁'。", + "parameters": { + "type": "object", + "properties": { + "origin": { + "type": "string", + "description": "起点地名,如:北京、上海虹桥站" + }, + "destination": { + "type": "string", + "description": "终点地名,如:杭州、西湖" + }, + "origin_city": { + "type": "string", + "description": "起点所在城市" + }, + "destination_city": { + "type": "string", + "description": "终点所在城市(跨城时必填)" + }, + "mode": { + "type": "string", + "enum": ["driving", "transit", "walking"], + "description": "出行方式:driving=驾车,transit=公交/高铁,walking=步行。默认 transit" + } + }, + "required": ["origin", "destination", "origin_city"] + } + } + }, + { + "type": "function", + "function": { + "name": "get_travel_info", + "description": "【旅行工具】获取目的地城市的旅行信息(天气、景点、交通)。仅当用户明确表示要去某城市旅游并询问相关信息时使用,如'我想去杭州玩,帮我看看'、'北京旅游攻略'。", + "parameters": { + "type": "object", + "properties": { + "destination": { + "type": "string", + "description": "目的地城市,如:杭州、成都" + }, + "origin": { + "type": "string", + "description": "出发城市,如:北京、上海。填写后会规划交通路线" + } + }, + "required": ["destination"] + } + } + } + ] + + async def execute_llm_tool( + self, + tool_name: str, + arguments: Dict[str, Any], + bot, + from_wxid: str + ) -> Dict[str, Any]: + """执行 LLM 工具调用""" + + if not self.amap: + return {"success": False, "message": "高德 API 未配置,请联系管理员设置 API Key"} + + try: + if tool_name == "search_location": + return await self._tool_search_location(arguments) + elif tool_name == "query_weather": + return await self._tool_query_weather(arguments) + elif tool_name == "search_poi": + return await self._tool_search_poi(arguments) + elif tool_name == "search_nearby": + return await self._tool_search_nearby(arguments) + elif tool_name == "plan_route": + return await self._tool_plan_route(arguments) + elif tool_name == "get_travel_info": + return await self._tool_get_travel_info(arguments) + else: + return {"success": False, "message": f"未知工具: {tool_name}"} + + except Exception as e: + logger.error(f"TravelPlanner 工具执行失败: {tool_name}, 错误: {e}") + return {"success": False, "message": f"工具执行失败: {str(e)}"} + + # ==================== 工具实现 ==================== + + async def _tool_search_location(self, args: Dict) -> Dict: + """地点搜索工具""" + address = args.get("address", "") + city = args.get("city") + + result = await self.amap.geocode(address, city) + + if not result["success"]: + return {"success": False, "message": result.get("error", "地点搜索失败")} + + return { + "success": True, + "message": f"已找到地点:{result['formatted_address']}", + "data": { + "name": address, + "formatted_address": result["formatted_address"], + "location": result["location"], + "province": result["province"], + "city": result["city"], + "district": result["district"], + "adcode": result["adcode"] + } + } + + async def _tool_query_weather(self, args: Dict) -> Dict: + """天气查询工具""" + city = args.get("city", "") + forecast = args.get("forecast", True) + + extensions = "all" if forecast else "base" + result = await self.amap.get_weather(city, extensions) + + if not result["success"]: + return {"success": False, "message": result.get("error", "天气查询失败")} + + if result["type"] == "live": + return { + "success": True, + "message": f"{result['city']}当前天气:{result['weather']},{result['temperature']}℃", + "data": { + "city": result["city"], + "weather": result["weather"], + "temperature": result["temperature"], + "humidity": result["humidity"], + "wind": f"{result['winddirection']}风 {result['windpower']}级", + "reporttime": result["reporttime"] + } + } + else: + forecasts = result["forecasts"] + weather_text = "\n".join([ + f"- {f['date']} 星期{self._weekday_cn(f['week'])}:白天{f['dayweather']} {f['daytemp']}℃,夜间{f['nightweather']} {f['nighttemp']}℃" + for f in forecasts + ]) + + return { + "success": True, + "message": f"{result['city']}未来天气预报:\n{weather_text}", + "data": { + "city": result["city"], + "province": result["province"], + "forecasts": forecasts, + "reporttime": result["reporttime"] + } + } + + async def _tool_search_poi(self, args: Dict) -> Dict: + """POI 搜索工具""" + city = args.get("city", "") + keyword = args.get("keyword") + category = args.get("category") + limit = min(args.get("limit", 10), 20) + + # 获取 POI 类型代码 + types = None + if category: + poi_types = self.config.get("poi_types", {}) + types = poi_types.get(category) + + result = await self.amap.search_poi( + keywords=keyword, + types=types, + city=city, + citylimit=True, + offset=limit + ) + + if not result["success"]: + return {"success": False, "message": result.get("error", "搜索失败")} + + pois = result["pois"] + if not pois: + return {"success": False, "message": f"在{city}未找到相关地点"} + + # 格式化输出 + poi_list = [] + for i, poi in enumerate(pois, 1): + info = f"{i}. {poi['name']}" + if poi.get("address"): + info += f" - {poi['address']}" + if poi.get("rating"): + info += f" ⭐{poi['rating']}" + if poi.get("cost"): + info += f" 人均¥{poi['cost']}" + poi_list.append(info) + + return { + "success": True, + "message": f"在{city}找到{len(pois)}个结果:\n" + "\n".join(poi_list), + "data": { + "city": city, + "category": category or "全部", + "count": len(pois), + "pois": pois + } + } + + async def _tool_search_nearby(self, args: Dict) -> Dict: + """周边搜索工具""" + location_name = args.get("location", "") + city = args.get("city", "") + keyword = args.get("keyword") + category = args.get("category") + radius = min(args.get("radius", 3000), 50000) + + # 先获取中心点坐标 + geo_result = await self.amap.geocode(location_name, city) + if not geo_result["success"]: + return {"success": False, "message": f"无法定位 {location_name}"} + + location = geo_result["location"] + + # 获取 POI 类型代码 + types = None + if category: + poi_types = self.config.get("poi_types", {}) + types = poi_types.get(category) + + result = await self.amap.search_around( + location=location, + keywords=keyword, + types=types, + radius=radius, + offset=10 + ) + + if not result["success"]: + return {"success": False, "message": result.get("error", "周边搜索失败")} + + pois = result["pois"] + if not pois: + return {"success": False, "message": f"在{location_name}周边未找到相关地点"} + + # 格式化输出 + poi_list = [] + for i, poi in enumerate(pois, 1): + info = f"{i}. {poi['name']}" + if poi.get("distance"): + info += f" ({poi['distance']}米)" + if poi.get("rating"): + info += f" ⭐{poi['rating']}" + poi_list.append(info) + + return { + "success": True, + "message": f"{location_name}周边{radius}米内找到{len(pois)}个结果:\n" + "\n".join(poi_list), + "data": { + "center": location_name, + "radius": radius, + "category": category or "全部", + "count": len(pois), + "pois": pois + } + } + + async def _tool_plan_route(self, args: Dict) -> Dict: + """路线规划工具""" + origin = args.get("origin", "") + destination = args.get("destination", "") + origin_city = args.get("origin_city", "") + destination_city = args.get("destination_city", origin_city) + mode = args.get("mode", "transit") + + # 获取起终点坐标 + origin_geo = await self.amap.geocode(origin, origin_city) + if not origin_geo["success"]: + return {"success": False, "message": f"无法定位起点:{origin}"} + + dest_geo = await self.amap.geocode(destination, destination_city) + if not dest_geo["success"]: + return {"success": False, "message": f"无法定位终点:{destination}"} + + origin_loc = origin_geo["location"] + dest_loc = dest_geo["location"] + + # 根据模式规划路线 + if mode == "driving": + result = await self.amap.route_driving(origin_loc, dest_loc) + if not result["success"]: + return {"success": False, "message": result.get("error", "驾车路线规划失败")} + + distance_km = result["distance"] / 1000 + duration_h = result["duration"] / 3600 + + msg = f"🚗 驾车路线:{origin} → {destination}\n" + msg += f"距离:{distance_km:.1f}公里,预计{self._format_duration(result['duration'])}\n" + if result["tolls"]: + msg += f"收费:约{result['tolls']}元\n" + if result["taxi_cost"]: + msg += f"打车费用:约{result['taxi_cost']}元" + + return { + "success": True, + "message": msg, + "data": result + } + + elif mode == "transit": + result = await self.amap.route_transit( + origin_loc, dest_loc, + city=origin_city, + cityd=destination_city if destination_city != origin_city else None + ) + if not result["success"]: + return {"success": False, "message": result.get("error", "公交路线规划失败")} + + msg = f"🚄 公交/高铁路线:{origin} → {destination}\n" + + for i, transit in enumerate(result["transits"][:2], 1): + msg += f"\n方案{i}:{self._format_duration(transit['duration'])}" + if transit.get("cost"): + msg += f",约{transit['cost']}元" + msg += "\n" + + for seg in transit["segments"]: + if seg["type"] == "walking" and seg["distance"] > 100: + msg += f" 🚶 步行{seg['distance']}米\n" + elif seg["type"] == "bus": + msg += f" 🚌 {seg['name']}:{seg['departure_stop']} → {seg['arrival_stop']}({seg['via_num']}站)\n" + elif seg["type"] == "railway": + msg += f" 🚄 {seg['trip']} {seg['name']}:{seg['departure_stop']} {seg.get('departure_time', '')} → {seg['arrival_stop']} {seg.get('arrival_time', '')}\n" + + return { + "success": True, + "message": msg.strip(), + "data": result + } + + elif mode == "walking": + result = await self.amap.route_walking(origin_loc, dest_loc) + if not result["success"]: + return {"success": False, "message": result.get("error", "步行路线规划失败")} + + return { + "success": True, + "message": f"🚶 步行路线:{origin} → {destination}\n距离:{result['distance']}米,预计{self._format_duration(result['duration'])}", + "data": result + } + + return {"success": False, "message": f"不支持的出行方式:{mode}"} + + async def _tool_get_travel_info(self, args: Dict) -> Dict: + """一键获取旅行信息""" + destination = args.get("destination", "") + origin = args.get("origin") + + info = {"destination": destination} + msg_parts = [f"📍 {destination} 旅行信息\n"] + + # 1. 查询天气 + weather_result = await self.amap.get_weather(destination, "all") + if weather_result["success"]: + info["weather"] = weather_result + msg_parts.append("🌤️ 天气预报:") + for f in weather_result["forecasts"][:3]: + msg_parts.append(f" {f['date']} {f['dayweather']} {f['nighttemp']}~{f['daytemp']}℃") + + # 2. 搜索热门景点 + poi_result = await self.amap.search_poi( + types="110000", # 景点 + city=destination, + citylimit=True, + offset=5 + ) + if poi_result["success"] and poi_result["pois"]: + info["attractions"] = poi_result["pois"] + msg_parts.append("\n🏞️ 热门景点:") + for poi in poi_result["pois"][:5]: + rating = f" ⭐{poi['rating']}" if poi.get("rating") else "" + msg_parts.append(f" • {poi['name']}{rating}") + + # 3. 规划交通路线(如果提供了出发地) + if origin: + origin_geo = await self.amap.geocode(origin) + dest_geo = await self.amap.geocode(destination) + + if origin_geo["success"] and dest_geo["success"]: + route_result = await self.amap.route_transit( + origin_geo["location"], + dest_geo["location"], + city=origin_geo.get("city", origin), + cityd=dest_geo.get("city", destination) + ) + + if route_result["success"] and route_result["transits"]: + info["route"] = route_result + transit = route_result["transits"][0] + msg_parts.append(f"\n🚄 从{origin}出发:") + msg_parts.append(f" 预计{self._format_duration(transit['duration'])}") + + # 显示主要交通工具 + for seg in transit["segments"]: + if seg["type"] == "railway": + msg_parts.append(f" {seg['trip']}:{seg['departure_stop']} → {seg['arrival_stop']}") + break + + return { + "success": True, + "message": "\n".join(msg_parts), + "data": info + } + + # ==================== 辅助方法 ==================== + + def _weekday_cn(self, week: str) -> str: + """星期数字转中文""" + mapping = {"1": "一", "2": "二", "3": "三", "4": "四", "5": "五", "6": "六", "7": "日"} + return mapping.get(str(week), week) + + def _format_duration(self, seconds: int) -> str: + """格式化时长""" + if seconds < 60: + return f"{seconds}秒" + elif seconds < 3600: + return f"{seconds // 60}分钟" + else: + hours = seconds // 3600 + minutes = (seconds % 3600) // 60 + if minutes: + return f"{hours}小时{minutes}分钟" + return f"{hours}小时" diff --git a/utils/image_processor.py b/utils/image_processor.py index 1aae4cc..432cde5 100644 --- a/utils/image_processor.py +++ b/utils/image_processor.py @@ -28,6 +28,7 @@ from __future__ import annotations import asyncio import base64 +import io import json import uuid from dataclasses import dataclass, field @@ -37,6 +38,14 @@ from typing import Any, Dict, Optional, TYPE_CHECKING import aiohttp from loguru import logger +# 图片处理支持 +try: + from PIL import Image + PIL_AVAILABLE = True +except ImportError: + PIL_AVAILABLE = False + logger.warning("[ImageProcessor] Pillow 未安装,GIF 转换功能不可用") + # 可选代理支持 try: from aiohttp_socks import ProxyConnector @@ -433,6 +442,60 @@ class ImageProcessor: logger.error(traceback.format_exc()) return "" + def _convert_gif_to_png(self, image_base64: str) -> str: + """ + 将 GIF 图片转换为 PNG(提取第一帧) + + Args: + image_base64: GIF 图片的 base64 数据(带 data URI 前缀) + + Returns: + PNG 图片的 base64 数据(带 data URI 前缀),失败返回原数据 + """ + if not PIL_AVAILABLE: + logger.warning("[ImageProcessor] Pillow 未安装,无法转换 GIF") + return image_base64 + + try: + # 提取 base64 数据部分 + if "," in image_base64: + base64_data = image_base64.split(",", 1)[1] + else: + base64_data = image_base64 + + # 解码 base64 + gif_bytes = base64.b64decode(base64_data) + + # 使用 Pillow 打开 GIF 并提取第一帧 + img = Image.open(io.BytesIO(gif_bytes)) + + # 转换为 RGB 模式(去除透明通道) + if img.mode in ('RGBA', 'LA', 'P'): + # 创建白色背景 + background = Image.new('RGB', img.size, (255, 255, 255)) + if img.mode == 'P': + img = img.convert('RGBA') + background.paste(img, mask=img.split()[-1] if img.mode == 'RGBA' else None) + img = background + elif img.mode != 'RGB': + img = img.convert('RGB') + + # 保存为 PNG + output = io.BytesIO() + img.save(output, format='PNG', optimize=True) + png_bytes = output.getvalue() + + # 编码为 base64 + png_base64 = base64.b64encode(png_bytes).decode() + result = f"data:image/png;base64,{png_base64}" + + logger.debug(f"[ImageProcessor] GIF 已转换为 PNG,原大小: {len(gif_bytes)} 字节,新大小: {len(png_bytes)} 字节") + return result + + except Exception as e: + logger.error(f"[ImageProcessor] GIF 转换失败: {e}") + return image_base64 + async def generate_description( self, image_base64: str, @@ -450,6 +513,11 @@ class ImageProcessor: Returns: 图片描述文本,失败返回空字符串 """ + # 检测并转换 GIF 格式(大多数视觉 API 不支持 GIF) + if image_base64.startswith("data:image/gif"): + logger.debug("[ImageProcessor] 检测到 GIF 格式,转换为 PNG...") + image_base64 = self._convert_gif_to_png(image_base64) + description_model = model or self.config.model messages = [ @@ -612,6 +680,13 @@ class ImageProcessor: logger.error(f"[ImageProcessor] 视频 API 错误: {resp.status}, {error_text[:300]}") return "" + # 检查响应类型是否为 JSON + content_type = resp.headers.get('Content-Type', '') + if 'application/json' not in content_type: + error_text = await resp.text() + logger.error(f"[ImageProcessor] 视频 API 返回非 JSON 响应: Content-Type={content_type}, Body={error_text[:500]}") + return "" + result = await resp.json() # 检查安全过滤