"""
TavilySearch web-search plugin.

Web search built on the Tavily API, exposed only as an LLM tool for AIChat to call.
Supports round-robin over multiple API keys; search results are handed back to
AIChat's AI for further processing (with conversation context and persona).
"""

import asyncio
import re
import ssl
import tomllib
import traceback
import uuid
from pathlib import Path
from typing import List, Optional
from urllib.parse import urlparse

import aiohttp
from loguru import logger

from utils.plugin_base import PluginBase


class TavilySearch(PluginBase):
    """Tavily web-search plugin, exposed only as an LLM tool."""

    description = "Tavily 联网搜索 - 支持多 Key 轮询的搜索工具"
    author = "Assistant"
    version = "1.0.0"

    def __init__(self):
        super().__init__()
        # Parsed config.toml; stays None until async_init succeeds.
        self.config: Optional[dict] = None
        # De-duplicated API keys, consumed round-robin via current_key_index.
        self.api_keys: List[str] = []
        self.current_key_index = 0
        # Directory downloaded images are written to.
        self.temp_dir: Optional[Path] = None

    async def async_init(self):
        """Load config.toml next to this file, create the temp dir and load API keys.

        Any failure is logged and leaves ``self.config`` as ``None``.
        """
        try:
            config_path = Path(__file__).parent / "config.toml"
            if not config_path.exists():
                logger.error(f"TavilySearch 配置文件不存在: {config_path}")
                return

            with open(config_path, "rb") as f:
                self.config = tomllib.load(f)

            self.temp_dir = Path(__file__).parent / "temp"
            self.temp_dir.mkdir(exist_ok=True)

            self.api_keys = self._load_api_keys()
            if not self.api_keys:
                logger.warning("TavilySearch: 未配置有效的 API Key")
            else:
                logger.success(f"TavilySearch 已加载,共 {len(self.api_keys)} 个 API Key")
        except Exception as e:
            logger.error(f"TavilySearch 初始化失败: {e}")
            self.config = None

    def _load_api_keys(self) -> List[str]:
        """Collect API keys from the ``tavily`` config section.

        Accepts ``api_keys`` (a list, or a comma/newline separated string) and a
        single ``api_key``. Blank entries, entries starting with ``#`` and
        duplicates are dropped; the original order is preserved.
        """
        if not self.config:
            return []

        tavily_config = self.config.get("tavily", {})
        keys: List[str] = []

        raw_keys = tavily_config.get("api_keys", [])
        if isinstance(raw_keys, str):
            keys.extend([k.strip() for k in raw_keys.replace("\n", ",").split(",")])
        elif isinstance(raw_keys, list):
            keys.extend([str(k).strip() for k in raw_keys])

        single_key = str(tavily_config.get("api_key", "")).strip()
        if single_key:
            keys.append(single_key)

        cleaned = []
        seen = set()
        for k in keys:
            if not k or k.startswith("#"):
                continue
            if k in seen:
                continue
            seen.add(k)
            cleaned.append(k)
        return cleaned

    def _get_next_api_key(self) -> str:
        """Return the next API key in round-robin order, or ``""`` if none are loaded."""
        if not self.api_keys:
            return ""
        key = self.api_keys[self.current_key_index]
        self.current_key_index = (self.current_key_index + 1) % len(self.api_keys)
        return key

    def _clean_query_text(self, text: str) -> str:
        """Strip chat noise from a query.

        Removes the ``【当前消息】`` marker, leading ``@mention`` tokens and a leading
        "please search ..." style verb phrase.
        """
        cleaned = str(text or "").strip()
        if not cleaned:
            return ""

        cleaned = cleaned.replace("【当前消息】", "").strip()
        cleaned = re.sub(r"^(?:@\S+\s*)+", "", cleaned)
        # Longer verbs must come before their own prefixes: regex alternation takes
        # the first alternative that matches, so listing "搜"/"查" before
        # "搜一下"/"查询" would leave fragments such as "一下" or "询" in the query.
        cleaned = re.sub(
            r"^(?:请|帮我|麻烦|请帮我)?(?:搜索下|搜一下|查一下|搜索|查询|检索|搜下|搜|查)\s*",
            "",
            cleaned,
        )
        return cleaned.strip()

    def _extract_topic_hint(self, query: str) -> str:
        """Extract a short topic prefix used to give later sub-queries context."""
        text = self._clean_query_text(query)
        if not text:
            return ""

        # Cut at the first separator (in tuple order) that occurs past position 0.
        first_part = text
        for sep in ("和", "以及", "并且", "还有", "同时", ",", ",", ";", ";", "。"):
            idx = first_part.find(sep)
            if idx > 0:
                first_part = first_part[:idx].strip()
                break

        match = re.match(
            r"^(.{2,40}?)(?:的|是|有哪些|包括|改动|更新|介绍|详情|内容|情况)",
            first_part,
        )
        topic_hint = match.group(1).strip() if match else ""
        # No marker word found: a short first clause is itself used as the topic.
        if not topic_hint and len(first_part) <= 40:
            topic_hint = first_part

        topic_hint = re.sub(r"(是什么|有哪些|有啥|是什么样).*$", "", topic_hint).strip()
        return topic_hint

    def _split_multi_queries(self, query: str, tavily_config: dict) -> List[str]:
        """Split a compound question into several sub-queries.

        Args:
            query: The raw query text.
            tavily_config: The ``tavily`` config section. Reads ``split_debug_log``,
                ``max_sub_queries``, ``split_min_chars`` and
                ``prepend_context_for_sub_query``.

        Returns:
            The list of sub-queries; empty when the cleaned query is empty.
        """
        split_debug_log = bool(tavily_config.get("split_debug_log", False))
        raw = self._clean_query_text(query)
        if not raw:
            return []

        if split_debug_log:
            logger.info(f"[TavilySplit] 原始查询: {query}")
            logger.info(f"[TavilySplit] 清洗后查询: {raw}")

        max_sub_queries = int(tavily_config.get("max_sub_queries", 4) or 4)
        split_min_chars = int(tavily_config.get("split_min_chars", 6) or 6)
        prepend_context = bool(tavily_config.get("prepend_context_for_sub_query", True))

        # Turn connective words and sentence punctuation into a "|" delimiter.
        normalized = raw
        normalized = re.sub(r"(另外|此外|同时|并且|还有|以及|然后|再者|顺便)", "|", normalized)
        normalized = re.sub(r"[;;。!?!?\n\r]+", "|", normalized)

        parts = [
            p.strip(" ,,、||")
            for p in normalized.split("|")
            if p.strip(" ,,、||")
        ]
        if split_debug_log:
            logger.info(f"[TavilySplit] 初步拆分片段: {parts}")

        # Second pass: split once on "和" when both halves are long enough.
        if len(parts) == 1:
            single = parts[0]
            if "和" in single and len(single) >= split_min_chars * 2:
                candidate = re.split(r"\s*和\s*", single, maxsplit=1)
                if len(candidate) == 2:
                    left = candidate[0].strip()
                    right = candidate[1].strip()
                    if len(left) >= split_min_chars and len(right) >= split_min_chars:
                        parts = [left, right]
                        if split_debug_log:
                            logger.info(f"[TavilySplit] 通过“和”二次拆分: {parts}")

        # Semantic fallback: even without an obvious connective, try to separate
        # "version changes" from "hero/skill introduction".
        if len(parts) == 1:
            single = parts[0].strip()
            change_keywords = ("改动", "更新", "变更", "调整", "改版", "平衡")
            hero_keywords = ("新英雄", "英雄", "技能", "机制", "天赋", "介绍", "详解")
            change_pos = min([single.find(k) for k in change_keywords if k in single] or [-1])
            hero_pos = min([single.find(k) for k in hero_keywords if k in single] or [-1])
            if change_pos >= 0 and hero_pos >= 0 and hero_pos > change_pos:
                left = single[:hero_pos].strip(" ,,、")
                right = single[hero_pos:].strip(" ,,、")
                if len(left) >= split_min_chars and len(right) >= split_min_chars:
                    topic_hint = self._extract_topic_hint(left or single)
                    if topic_hint and topic_hint not in right:
                        right = f"{topic_hint} {right}".strip()
                    parts = [left, right]
                    if split_debug_log:
                        logger.info(f"[TavilySplit] 语义兜底拆分: {parts}")

        # Drop too-short and duplicate fragments, then cap the count.
        deduped: List[str] = []
        seen = set()
        for p in parts:
            if len(p) < split_min_chars:
                continue
            if p in seen:
                continue
            seen.add(p)
            deduped.append(p)

        parts = deduped[:max_sub_queries] if deduped else [raw]
        if split_debug_log:
            logger.info(f"[TavilySplit] 去重截断后: {parts}")

        # Prefix later sub-queries with the topic of the first one.
        if prepend_context and len(parts) > 1:
            topic_hint = self._extract_topic_hint(parts[0] or raw)
            if topic_hint:
                with_context: List[str] = []
                for idx, p in enumerate(parts):
                    item = p
                    if idx > 0 and topic_hint not in item:
                        item = f"{topic_hint} {item}".strip()
                    with_context.append(item)
                parts = with_context
                if split_debug_log:
                    logger.info(f"[TavilySplit] 主题前缀: {topic_hint}")
                    logger.info(f"[TavilySplit] 前缀补全后: {parts}")

        if split_debug_log:
            logger.info(f"[TavilySplit] 最终子查询({len(parts)}): {parts}")
        return parts

    def _truncate_text(self, text: str, max_chars: int) -> str:
        """Truncate ``text`` to ``max_chars`` characters, appending ``...`` when cut.

        A non-positive ``max_chars`` disables truncation.
        """
        content = str(text or "").strip()
        if max_chars <= 0 or len(content) <= max_chars:
            return content
        return content[:max_chars].rstrip() + "..."

    def _build_proxy_url(self) -> Optional[str]:
        """Return the proxy URL from the ``proxy`` config section, or None if disabled."""
        proxy_config = self.config.get("proxy", {}) if self.config else {}
        if not proxy_config.get("enabled", False):
            return None
        proxy_type = proxy_config.get("type", "http")
        proxy_host = proxy_config.get("host", "127.0.0.1")
        proxy_port = proxy_config.get("port", 7890)
        return f"{proxy_type}://{proxy_host}:{proxy_port}"

    def _build_connector(self) -> Optional[aiohttp.TCPConnector]:
        """Return a connector that skips TLS verification when ``ssl.verify`` is false.

        Returns None (aiohttp's default connector) when verification is enabled.
        """
        ssl_config = self.config.get("ssl", {}) if self.config else {}
        if ssl_config.get("verify", True):
            return None
        ssl_context = ssl.create_default_context()
        ssl_context.check_hostname = False
        ssl_context.verify_mode = ssl.CERT_NONE
        return aiohttp.TCPConnector(ssl=ssl_context)

    async def _search_tavily(self, query: str) -> Optional[dict]:
        """Call the Tavily search API, rotating API keys on auth/rate-limit errors.

        Args:
            query: The search query.

        Returns:
            The decoded JSON response on HTTP 200, otherwise None.
        """
        # Check keys before creating any network resources so nothing is left unclosed.
        if not self.api_keys:
            logger.error("没有可用的 Tavily API Key")
            return None

        tavily_config = (self.config or {}).get("tavily", {})
        proxy = self._build_proxy_url()

        try:
            timeout = aiohttp.ClientTimeout(total=30)
            key_count = len(self.api_keys)
            # Always try at least one key, and never more than the number of keys.
            max_attempts = max(
                1, min(key_count, int(tavily_config.get("max_key_attempts", key_count)))
            )

            async with aiohttp.ClientSession(
                timeout=timeout, connector=self._build_connector()
            ) as session:
                for attempt in range(max_attempts):
                    api_key = self._get_next_api_key()
                    if not api_key:
                        logger.error("没有可用的 Tavily API Key")
                        return None

                    payload = {
                        "api_key": api_key,
                        "query": query,
                        "search_depth": tavily_config.get("search_depth", "basic"),
                        "max_results": tavily_config.get("max_results", 5),
                        "include_raw_content": (
                            tavily_config.get("include_raw_content", False)
                            or tavily_config.get("use_raw_content_in_result", False)
                        ),
                        "include_images": tavily_config.get("include_images", False),
                    }

                    async with session.post(
                        "https://api.tavily.com/search",
                        json=payload,
                        proxy=proxy
                    ) as resp:
                        if resp.status == 200:
                            result = await resp.json()
                            logger.info(f"Tavily 搜索成功: {query[:30]}...")
                            logger.info(f"Tavily 原始返回: {result}")
                            return result

                        error_text = await resp.text()
                        logger.warning(
                            f"Tavily API 错误: {resp.status}, 尝试 key {attempt + 1}/{max_attempts}, "
                            f"body={error_text[:200]}"
                        )
                        # Auth and rate-limit errors are key-specific: try the next key.
                        if resp.status in {401, 403, 429}:
                            continue
                        return None

            # Every attempted key was rejected.
            return None
        except Exception as e:
            logger.error(f"Tavily 搜索失败: {e}")
            return None

    def _extract_image_urls(self, results: dict) -> List[str]:
        """Extract image URLs from a search response.

        Each entry of ``images`` may be a plain string or a dict carrying the URL
        under ``url``, ``image`` or ``src``.
        """
        if not results:
            return []

        images = results.get("images") or []
        urls: List[str] = []
        for item in images:
            if isinstance(item, str):
                url = item.strip()
            elif isinstance(item, dict):
                url = (item.get("url") or item.get("image") or item.get("src") or "").strip()
            else:
                url = ""
            if url:
                urls.append(url)
        return urls

    async def _download_image_with_session(
        self,
        session: aiohttp.ClientSession,
        url: str,
        proxy: Optional[str],
        max_retries: int = 1
    ) -> Optional[str]:
        """Download an image into the temp directory, reusing ``session``.

        Args:
            session: An open aiohttp session.
            url: The image URL.
            proxy: Proxy URL passed to the request, or None.
            max_retries: Number of extra attempts after the first one.

        Returns:
            The saved file path, or None if every attempt failed.
        """
        if not self.temp_dir:
            return None

        for attempt in range(max_retries + 1):
            try:
                async with session.get(url, proxy=proxy) as resp:
                    if resp.status != 200:
                        if attempt >= max_retries:
                            return None
                        await asyncio.sleep(0.5 * (attempt + 1))
                        continue
                    content = await resp.read()

                # Take the extension from the URL path only, so a query string or
                # fragment (e.g. "a.png?x=1") does not hide a valid extension.
                ext = Path(urlparse(url).path).suffix.lower()
                if ext not in {".jpg", ".jpeg", ".png", ".webp"}:
                    ext = ".jpg"
                filename = f"tavily_{uuid.uuid4().hex}{ext}"
                save_path = self.temp_dir / filename
                with open(save_path, "wb") as f:
                    f.write(content)
                return str(save_path)
            except Exception as e:
                if attempt < max_retries:
                    await asyncio.sleep(0.5 * (attempt + 1))
                    continue
                logger.warning(f"下载图片失败: {url} -> {e}")
                return None

        # Only reachable when max_retries is negative (no attempt was made).
        return None

    async def _download_image(self, url: str) -> Optional[str]:
        """Download an image into the temp directory (kept for older callers).

        Opens its own session using the configured proxy and SSL settings.
        """
        if not self.temp_dir:
            return None

        try:
            timeout = aiohttp.ClientTimeout(total=30)
            proxy = self._build_proxy_url()
            async with aiohttp.ClientSession(
                timeout=timeout, connector=self._build_connector()
            ) as session:
                return await self._download_image_with_session(session, url, proxy, max_retries=1)
        except Exception as e:
            logger.warning(f"下载图片失败: {url} -> {e}")
            return None

    def _format_search_results(
        self,
        results: dict,
        *,
        include_raw_content: bool = False,
        raw_content_max_chars: int = 1800,
        section_title: Optional[str] = None,
    ) -> str:
        """Format a search response as text for the AI to process.

        Args:
            results: The decoded Tavily response.
            include_raw_content: Append a truncated ``raw_content`` excerpt per result.
            raw_content_max_chars: Character limit for that excerpt.
            section_title: Optional heading placed before the results.

        Returns:
            The formatted text; a "not found" message when there are no results.
        """
        # A missing, None or empty result list all mean "nothing found".
        items = (results or {}).get("results") or []
        if not items:
            if section_title:
                return f"{section_title}\n未找到相关搜索结果"
            return "未找到相关搜索结果"

        formatted = []
        if section_title:
            formatted.append(section_title)

        for i, item in enumerate(items, 1):
            title = item.get("title") or "无标题"
            content = item.get("content", "")
            url = item.get("url", "")
            block = [
                f"【结果 {i}】",
                f"标题: {title}",
                f"内容: {content}",
                f"来源: {url}",
            ]
            if include_raw_content:
                raw_content = self._truncate_text(item.get("raw_content", ""), raw_content_max_chars)
                if raw_content:
                    block.append(f"原文摘录: {raw_content}")
            formatted.append("\n".join(block) + "\n")

        return "\n".join(formatted)

    def get_llm_tools(self) -> List[dict]:
        """Return the LLM tool definitions; empty when the plugin is not enabled."""
        if not self.config or not self.config["behavior"]["enabled"]:
            return []

        return [
            {
                "type": "function",
                "function": {
                    "name": "tavily_web_search",
                    "description": (
                        "执行联网检索并返回可引用的信息来源。"
                        "仅在用户明确要求查资料、最新信息、权威来源或需要事实核实时调用;"
                        "可直接回答的问题不要触发该工具。"
                    ),
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "query": {
                                "type": "string",
                                "description": "检索问题或关键词。应简洁、明确,避免口语噪声。"
                            }
                        },
                        "required": ["query"],
                        "additionalProperties": False
                    }
                }
            }
        ]

    async def _send_result_images(
        self, bot, from_wxid: str, search_batches: list, tavily_config: dict
    ) -> None:
        """Download the images referenced by the search responses and send them.

        Args:
            bot: Object exposing an awaitable ``send_image(from_wxid, path)``.
            from_wxid: Recipient passed through to ``bot.send_image``.
            search_batches: ``(sub_query, response)`` pairs.
            tavily_config: The ``tavily`` config section. Reads ``max_images``,
                ``image_download_concurrency``, ``image_download_retries`` and
                ``image_download_timeout``.
        """
        image_urls: List[str] = []
        for _sub_query, sub_result in search_batches:
            image_urls.extend(self._extract_image_urls(sub_result))
        if not image_urls:
            return

        # Order-preserving de-duplication.
        image_urls = list(dict.fromkeys(image_urls))
        max_images = int(tavily_config.get("max_images", 3) or 3)
        download_concurrency = int(tavily_config.get("image_download_concurrency", 3) or 3)
        # An explicit 0 means "no retries"; only a missing value falls back to 1.
        raw_retries = tavily_config.get("image_download_retries", 1)
        download_retries = max(0, int(1 if raw_retries is None else raw_retries))
        download_timeout = int(tavily_config.get("image_download_timeout", 30) or 30)

        timeout = aiohttp.ClientTimeout(total=download_timeout)
        proxy = self._build_proxy_url()
        semaphore = asyncio.Semaphore(max(1, download_concurrency))

        async with aiohttp.ClientSession(
            timeout=timeout, connector=self._build_connector()
        ) as session:
            async def fetch_image(url: str) -> Optional[str]:
                async with semaphore:
                    return await self._download_image_with_session(
                        session, url, proxy, max_retries=download_retries
                    )

            downloads = await asyncio.gather(
                *(fetch_image(url) for url in image_urls[:max_images]),
                return_exceptions=True,
            )

        # Failed downloads come back as None or as an exception object; skip them.
        for local_path in downloads:
            if isinstance(local_path, str) and local_path:
                await bot.send_image(from_wxid, local_path)

    async def execute_llm_tool(
        self, tool_name: str, arguments: dict, bot, from_wxid: str
    ) -> Optional[dict]:
        """Execute the ``tavily_web_search`` tool call.

        Only performs the search; the result text is returned for AIChat's AI to
        process (with conversation context and persona).

        Args:
            tool_name: Name of the requested tool.
            arguments: Tool arguments; ``query`` is read from it.
            bot: Object used to send images via ``send_image``.
            from_wxid: Recipient id passed to ``bot.send_image``.

        Returns:
            None when ``tool_name`` is not this plugin's tool; otherwise a dict
            with ``success`` and ``message`` (plus ``need_ai_reply`` on success).
        """
        if tool_name != "tavily_web_search":
            return None

        if not self.config or not self.config["behavior"]["enabled"]:
            return {"success": False, "message": "TavilySearch 插件未启用"}

        if not self.api_keys:
            return {"success": False, "message": "未配置 Tavily API Key"}

        query = arguments.get("query", "")
        if not query:
            return {"success": False, "message": "搜索关键词不能为空"}

        tavily_config = self.config.get("tavily", {})
        multi_query_split = bool(tavily_config.get("multi_query_split", True))
        use_raw_content_in_result = bool(tavily_config.get("use_raw_content_in_result", False))
        raw_content_max_chars = int(tavily_config.get("raw_content_max_chars", 1800) or 1800)

        try:
            logger.info(f"开始 Tavily 搜索: {query}")
            split_debug_log = bool(tavily_config.get("split_debug_log", False))

            if multi_query_split:
                sub_queries = self._split_multi_queries(query, tavily_config)
            else:
                cleaned_query = self._clean_query_text(query)
                sub_queries = [cleaned_query] if cleaned_query else [str(query).strip()]

            if not sub_queries:
                return {"success": False, "message": "搜索关键词不能为空"}

            if split_debug_log:
                logger.info(f"Tavily 子问题拆分完成,共 {len(sub_queries)} 个: {sub_queries}")
            else:
                logger.info(f"Tavily 子问题拆分完成,共 {len(sub_queries)} 个")

            search_batches = []
            failed_queries = []
            for sub_query in sub_queries:
                result = await self._search_tavily(sub_query)
                if result:
                    search_batches.append((sub_query, result))
                else:
                    failed_queries.append(sub_query)

            if not search_batches:
                return {"success": False, "message": "搜索失败,请稍后重试"}

            # Send search images (when include_images is on). Images are auxiliary:
            # a download/send failure must not discard the text results.
            if tavily_config.get("include_images", False):
                try:
                    await self._send_result_images(bot, from_wxid, search_batches, tavily_config)
                except Exception as e:
                    logger.warning(f"Tavily 图片发送失败: {e}")

            # Format the search results.
            if len(search_batches) == 1:
                formatted_results = self._format_search_results(
                    search_batches[0][1],
                    include_raw_content=use_raw_content_in_result,
                    raw_content_max_chars=raw_content_max_chars,
                )
            else:
                sections = []
                for idx, (sub_query, sub_result) in enumerate(search_batches, 1):
                    sections.append(
                        self._format_search_results(
                            sub_result,
                            include_raw_content=use_raw_content_in_result,
                            raw_content_max_chars=raw_content_max_chars,
                            section_title=f"【子问题 {idx}】{sub_query}",
                        )
                    )
                formatted_results = "\n\n".join(sections)

            if failed_queries:
                failed_text = "\n".join([f"- {q}" for q in failed_queries])
                formatted_results = (
                    f"{formatted_results}\n\n"
                    f"【未检索成功的子问题】\n{failed_text}"
                )

            logger.success(f"Tavily 搜索完成: {query[:30]}...")

            # Return the results, flagged so the AI continues the reply from them.
            return {
                "success": True,
                "message": formatted_results,
                "need_ai_reply": True  # the AI should keep replying based on this result
            }

        except Exception as e:
            logger.error(f"Tavily 搜索执行失败: {e}")
            logger.error(traceback.format_exc())
            return {"success": False, "message": f"搜索失败: {str(e)}"}