Files

620 lines
24 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
TavilySearch 联网搜索插件
基于 Tavily API 的联网搜索功能,仅作为 LLM Tool 供 AIChat 调用
支持多 API Key 轮询,搜索结果返回给 AIChat 的 AI 处理(带上下文和人设)
"""
import tomllib
import aiohttp
import uuid
import asyncio
import re
from pathlib import Path
from typing import List, Optional
from loguru import logger
from utils.plugin_base import PluginBase
class TavilySearch(PluginBase):
    """Tavily web-search plugin, exposed solely as an LLM tool.

    Performs Tavily API searches with multi-key round-robin rotation;
    search results are handed back to AIChat's model (which applies
    conversation context and persona) to compose the final reply.
    """

    description = "Tavily 联网搜索 - 支持多 Key 轮询的搜索工具"
    author = "Assistant"
    version = "1.0.0"

    def __init__(self):
        super().__init__()
        # Parsed config.toml contents; stays None until async_init succeeds.
        self.config = None
        # Pool of Tavily API keys plus the round-robin cursor into it.
        self.api_keys = []
        self.current_key_index = 0
        # Directory for temporarily-downloaded search images.
        self.temp_dir: Optional[Path] = None
async def async_init(self):
    """Asynchronous setup: read config.toml, create the temp dir, load API keys."""
    try:
        plugin_dir = Path(__file__).parent
        config_path = plugin_dir / "config.toml"
        if not config_path.exists():
            logger.error(f"TavilySearch 配置文件不存在: {config_path}")
            return
        with open(config_path, "rb") as f:
            self.config = tomllib.load(f)
        self.temp_dir = plugin_dir / "temp"
        self.temp_dir.mkdir(exist_ok=True)
        self.api_keys = self._load_api_keys()
        if self.api_keys:
            logger.success(f"TavilySearch 已加载,共 {len(self.api_keys)} 个 API Key")
        else:
            logger.warning("TavilySearch: 未配置有效的 API Key")
    except Exception as e:
        logger.error(f"TavilySearch 初始化失败: {e}")
        # Leave the plugin disabled rather than half-configured.
        self.config = None
def _load_api_keys(self) -> List[str]:
"""从配置加载 API Keys兼容 api_key / api_keys"""
if not self.config:
return []
tavily_config = self.config.get("tavily", {})
keys: List[str] = []
raw_keys = tavily_config.get("api_keys", [])
if isinstance(raw_keys, str):
keys.extend([k.strip() for k in raw_keys.replace("\n", ",").split(",")])
elif isinstance(raw_keys, list):
keys.extend([str(k).strip() for k in raw_keys])
single_key = str(tavily_config.get("api_key", "")).strip()
if single_key:
keys.append(single_key)
cleaned = []
seen = set()
for k in keys:
if not k or k.startswith("#"):
continue
if k in seen:
continue
seen.add(k)
cleaned.append(k)
return cleaned
def _get_next_api_key(self) -> str:
"""轮询获取下一个 API Key"""
if not self.api_keys:
return ""
key = self.api_keys[self.current_key_index]
self.current_key_index = (self.current_key_index + 1) % len(self.api_keys)
return key
def _clean_query_text(self, text: str) -> str:
"""清洗查询文本"""
cleaned = str(text or "").strip()
if not cleaned:
return ""
cleaned = cleaned.replace("【当前消息】", "").strip()
cleaned = re.sub(r"^(?:@\S+\s*)+", "", cleaned)
cleaned = re.sub(
r"^(?:请|帮我|麻烦|请帮我)?(?:搜索|搜|查|查询|检索|搜一下|查一下|搜索下|搜下)\s*",
"",
cleaned,
)
return cleaned.strip()
def _extract_topic_hint(self, query: str) -> str:
"""提取主题前缀,用于补全后续子问题上下文"""
text = self._clean_query_text(query)
if not text:
return ""
first_part = text
for sep in ("", "以及", "并且", "还有", "同时", "", ",", "", ";", ""):
idx = first_part.find(sep)
if idx > 0:
first_part = first_part[:idx].strip()
break
match = re.match(r"^(.{2,40}?)(?:的|是|有哪些|包括|改动|更新|介绍|详情|内容|情况)", first_part)
topic_hint = match.group(1).strip() if match else ""
if not topic_hint and len(first_part) <= 40:
topic_hint = first_part
topic_hint = re.sub(r"(是什么|有哪些|有啥|是什么样).*$", "", topic_hint).strip()
return topic_hint
def _split_multi_queries(self, query: str, tavily_config: dict) -> List[str]:
"""将复合问题拆分为多个子查询"""
split_debug_log = bool(tavily_config.get("split_debug_log", False))
raw = self._clean_query_text(query)
if not raw:
return []
if split_debug_log:
logger.info(f"[TavilySplit] 原始查询: {query}")
logger.info(f"[TavilySplit] 清洗后查询: {raw}")
max_sub_queries = int(tavily_config.get("max_sub_queries", 4) or 4)
split_min_chars = int(tavily_config.get("split_min_chars", 6) or 6)
prepend_context = bool(tavily_config.get("prepend_context_for_sub_query", True))
normalized = raw
normalized = re.sub(r"(另外|此外|同时|并且|还有|以及|然后|再者|顺便)", "", normalized)
normalized = re.sub(r"[;。!?!?\n\r]+", "", normalized)
parts = [
p.strip(" ,、|")
for p in normalized.split("")
if p.strip(" ,、|")
]
if split_debug_log:
logger.info(f"[TavilySplit] 初步拆分片段: {parts}")
if len(parts) == 1:
single = parts[0]
if "" in single and len(single) >= split_min_chars * 2:
candidate = re.split(r"\s*和\s*", single, maxsplit=1)
if len(candidate) == 2:
left = candidate[0].strip()
right = candidate[1].strip()
if len(left) >= split_min_chars and len(right) >= split_min_chars:
parts = [left, right]
if split_debug_log:
logger.info(f"[TavilySplit] 通过“和”二次拆分: {parts}")
# 语义拆分兜底:即使没有明显连接词,也尽量把“版本改动 + 英雄技能介绍”拆开
if len(parts) == 1:
single = parts[0].strip()
change_keywords = ("改动", "更新", "变更", "调整", "改版", "平衡")
hero_keywords = ("新英雄", "英雄", "技能", "机制", "天赋", "介绍", "详解")
change_pos = min([single.find(k) for k in change_keywords if k in single] or [-1])
hero_pos = min([single.find(k) for k in hero_keywords if k in single] or [-1])
if change_pos >= 0 and hero_pos >= 0 and hero_pos > change_pos:
left = single[:hero_pos].strip(" ,、")
right = single[hero_pos:].strip(" ,、")
if len(left) >= split_min_chars and len(right) >= split_min_chars:
topic_hint = self._extract_topic_hint(left or single)
if topic_hint and topic_hint not in right:
right = f"{topic_hint} {right}".strip()
parts = [left, right]
if split_debug_log:
logger.info(f"[TavilySplit] 语义兜底拆分: {parts}")
deduped: List[str] = []
seen = set()
for p in parts:
if len(p) < split_min_chars:
continue
if p in seen:
continue
seen.add(p)
deduped.append(p)
parts = deduped[:max_sub_queries] if deduped else [raw]
if split_debug_log:
logger.info(f"[TavilySplit] 去重截断后: {parts}")
if prepend_context and len(parts) > 1:
topic_hint = self._extract_topic_hint(parts[0] or raw)
if topic_hint:
with_context: List[str] = []
for idx, p in enumerate(parts):
item = p
if idx > 0 and topic_hint not in item:
item = f"{topic_hint} {item}".strip()
with_context.append(item)
parts = with_context
if split_debug_log:
logger.info(f"[TavilySplit] 主题前缀: {topic_hint}")
logger.info(f"[TavilySplit] 前缀补全后: {parts}")
if split_debug_log:
logger.info(f"[TavilySplit] 最终子查询({len(parts)}): {parts}")
return parts
def _truncate_text(self, text: str, max_chars: int) -> str:
"""按字符数截断文本"""
content = str(text or "").strip()
if max_chars <= 0 or len(content) <= max_chars:
return content
return content[:max_chars].rstrip() + "..."
async def _search_tavily(self, query: str) -> Optional[dict]:
    """Call the Tavily search API for *query*.

    Rotates through the configured API keys (bounded by
    ``max_key_attempts``), retrying only on 401/403/429 responses.
    Returns the decoded JSON payload on success, or None on any other
    failure. NOTE(review): assumes ``self.config`` is set (raises
    TypeError/KeyError if async_init has not run) — confirm callers
    always check first.
    """
    tavily_config = self.config["tavily"]
    proxy_config = self.config.get("proxy", {})
    proxy = None
    if proxy_config.get("enabled", False):
        proxy_type = proxy_config.get("type", "http")
        proxy_host = proxy_config.get("host", "127.0.0.1")
        proxy_port = proxy_config.get("port", 7890)
        proxy = f"{proxy_type}://{proxy_host}:{proxy_port}"
    try:
        import ssl
        timeout = aiohttp.ClientTimeout(total=30)
        # SSL configuration
        ssl_config = self.config.get("ssl", {})
        ssl_verify = ssl_config.get("verify", True)
        connector = None
        if not ssl_verify:
            # Skip SSL certificate verification (opt-in via config).
            ssl_context = ssl.create_default_context()
            ssl_context.check_hostname = False
            ssl_context.verify_mode = ssl.CERT_NONE
            connector = aiohttp.TCPConnector(ssl=ssl_context)
        if not self.api_keys:
            logger.error("没有可用的 Tavily API Key")
            return None
        # Never attempt more times than there are keys in the pool.
        max_attempts = min(len(self.api_keys), tavily_config.get("max_key_attempts", len(self.api_keys)))
        async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
            for attempt in range(max_attempts):
                api_key = self._get_next_api_key()
                if not api_key:
                    logger.error("没有可用的 Tavily API Key")
                    return None
                payload = {
                    "api_key": api_key,
                    "query": query,
                    "search_depth": tavily_config.get("search_depth", "basic"),
                    "max_results": tavily_config.get("max_results", 5),
                    # Raw page content is requested if either flag asks for it.
                    "include_raw_content": (
                        tavily_config.get("include_raw_content", False)
                        or tavily_config.get("use_raw_content_in_result", False)
                    ),
                    "include_images": tavily_config.get("include_images", False),
                }
                async with session.post(
                    "https://api.tavily.com/search",
                    json=payload,
                    proxy=proxy
                ) as resp:
                    if resp.status == 200:
                        result = await resp.json()
                        logger.info(f"Tavily 搜索成功: {query[:30]}...")
                        logger.info(f"Tavily 原始返回: {result}")
                        return result
                    error_text = await resp.text()
                    logger.warning(
                        f"Tavily API 错误: {resp.status}, 尝试 key {attempt + 1}/{max_attempts}, "
                        f"body={error_text[:200]}"
                    )
                    # Auth/quota errors: rotate to the next key; anything else is fatal.
                    if resp.status in {401, 403, 429}:
                        continue
                    return None
            # All attempts exhausted on retryable errors -> fall through to implicit None.
    except Exception as e:
        logger.error(f"Tavily 搜索失败: {e}")
        return None
def _extract_image_urls(self, results: dict) -> List[str]:
"""从搜索结果中提取图片 URL"""
if not results:
return []
images = results.get("images", [])
urls: List[str] = []
for item in images:
if isinstance(item, str):
url = item.strip()
elif isinstance(item, dict):
url = (item.get("url") or item.get("image") or item.get("src") or "").strip()
else:
url = ""
if url:
urls.append(url)
return urls
async def _download_image_with_session(
    self,
    session: aiohttp.ClientSession,
    url: str,
    proxy: Optional[str],
    max_retries: int = 1
) -> Optional[str]:
    """Download *url* into the plugin temp dir, reusing *session*.

    Retries up to *max_retries* additional times with a linear backoff
    (0.5s, 1.0s, ...), both on non-200 responses and on exceptions.
    Returns the saved file path as a string, or None when the temp dir
    is missing or every attempt fails.
    """
    if not self.temp_dir:
        return None
    for attempt in range(max_retries + 1):
        try:
            async with session.get(url, proxy=proxy) as resp:
                if resp.status != 200:
                    if attempt >= max_retries:
                        return None
                    await asyncio.sleep(0.5 * (attempt + 1))
                    continue
                content = await resp.read()
                # Keep a known image extension; default to .jpg otherwise.
                ext = Path(url).suffix.lower()
                if ext not in {".jpg", ".jpeg", ".png", ".webp"}:
                    ext = ".jpg"
                # Random filename to avoid collisions between downloads.
                filename = f"tavily_{uuid.uuid4().hex}{ext}"
                save_path = self.temp_dir / filename
                with open(save_path, "wb") as f:
                    f.write(content)
                return str(save_path)
        except Exception as e:
            if attempt < max_retries:
                await asyncio.sleep(0.5 * (attempt + 1))
                continue
            # Final attempt failed: log and give up.
            logger.warning(f"下载图片失败: {url} -> {e}")
    return None
async def _download_image(self, url: str) -> Optional[str]:
    """Download a single image to the temp dir (legacy entry point).

    Builds a one-off aiohttp session from the configured proxy/SSL
    settings and delegates to _download_image_with_session.
    """
    if not self.temp_dir:
        return None
    try:
        import ssl
        cfg = self.config or {}
        proxy_config = cfg.get("proxy", {})
        proxy = None
        if proxy_config.get("enabled", False):
            proxy = "{}://{}:{}".format(
                proxy_config.get("type", "http"),
                proxy_config.get("host", "127.0.0.1"),
                proxy_config.get("port", 7890),
            )
        ssl_context = None
        if not cfg.get("ssl", {}).get("verify", True):
            # Verification disabled via config: build a permissive context.
            ssl_context = ssl.create_default_context()
            ssl_context.check_hostname = False
            ssl_context.verify_mode = ssl.CERT_NONE
        connector = aiohttp.TCPConnector(ssl=ssl_context) if ssl_context else None
        timeout = aiohttp.ClientTimeout(total=30)
        async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
            return await self._download_image_with_session(session, url, proxy, max_retries=1)
    except Exception as e:
        logger.warning(f"下载图片失败: {url} -> {e}")
        return None
def _format_search_results(
self,
results: dict,
*,
include_raw_content: bool = False,
raw_content_max_chars: int = 1800,
section_title: Optional[str] = None,
) -> str:
"""格式化搜索结果供 AI 处理"""
if not results or "results" not in results:
if section_title:
return f"{section_title}\n未找到相关搜索结果"
return "未找到相关搜索结果"
formatted = []
if section_title:
formatted.append(section_title)
for i, item in enumerate(results["results"], 1):
title = item.get("title", "无标题")
content = item.get("content", "")
url = item.get("url", "")
block = [
f"【结果 {i}",
f"标题: {title}",
f"内容: {content}",
f"来源: {url}",
]
if include_raw_content:
raw_content = self._truncate_text(item.get("raw_content", ""), raw_content_max_chars)
if raw_content:
block.append(f"原文摘录: {raw_content}")
formatted.append("\n".join(block) + "\n")
return "\n".join(formatted)
def get_llm_tools(self) -> List[dict]:
"""返回 LLM 工具定义"""
if not self.config or not self.config["behavior"]["enabled"]:
return []
return [
{
"type": "function",
"function": {
"name": "tavily_web_search",
"description": (
"执行联网检索并返回可引用的信息来源。"
"仅在用户明确要求查资料、最新信息、权威来源或需要事实核实时调用;"
"可直接回答的问题不要触发该工具。"
),
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "检索问题或关键词。应简洁、明确,避免口语噪声。"
}
},
"required": ["query"],
"additionalProperties": False
}
}
}
]
async def execute_llm_tool(self, tool_name: str, arguments: dict, bot, from_wxid: str) -> dict:
"""
执行 LLM 工具调用
只负责搜索,返回结果给 AIChat 的 AI 处理(带上下文和人设)
"""
if tool_name != "tavily_web_search":
return None
if not self.config or not self.config["behavior"]["enabled"]:
return {"success": False, "message": "TavilySearch 插件未启用"}
if not self.api_keys:
return {"success": False, "message": "未配置 Tavily API Key"}
query = arguments.get("query", "")
if not query:
return {"success": False, "message": "搜索关键词不能为空"}
tavily_config = self.config.get("tavily", {})
multi_query_split = bool(tavily_config.get("multi_query_split", True))
use_raw_content_in_result = bool(tavily_config.get("use_raw_content_in_result", False))
raw_content_max_chars = int(tavily_config.get("raw_content_max_chars", 1800) or 1800)
try:
logger.info(f"开始 Tavily 搜索: {query}")
split_debug_log = bool(tavily_config.get("split_debug_log", False))
if multi_query_split:
sub_queries = self._split_multi_queries(query, tavily_config)
else:
cleaned_query = self._clean_query_text(query)
sub_queries = [cleaned_query] if cleaned_query else [str(query).strip()]
if not sub_queries:
return {"success": False, "message": "搜索关键词不能为空"}
if split_debug_log:
logger.info(f"Tavily 子问题拆分完成,共 {len(sub_queries)} 个: {sub_queries}")
else:
logger.info(f"Tavily 子问题拆分完成,共 {len(sub_queries)}")
search_batches = []
failed_queries = []
for sub_query in sub_queries:
result = await self._search_tavily(sub_query)
if result:
search_batches.append((sub_query, result))
else:
failed_queries.append(sub_query)
if not search_batches:
return {"success": False, "message": "搜索失败,请稍后重试"}
# 发送搜索图片(若开启 include_images
if tavily_config.get("include_images", False):
image_urls = []
for _sub_query, sub_result in search_batches:
image_urls.extend(self._extract_image_urls(sub_result))
if image_urls:
image_urls = list(dict.fromkeys(image_urls))
max_images = int(tavily_config.get("max_images", 3) or 3)
download_concurrency = int(tavily_config.get("image_download_concurrency", 3) or 3)
download_retries = int(tavily_config.get("image_download_retries", 1) or 1)
download_timeout = int(tavily_config.get("image_download_timeout", 30) or 30)
import ssl
timeout = aiohttp.ClientTimeout(total=download_timeout)
proxy_config = self.config.get("proxy", {}) if self.config else {}
proxy = None
if proxy_config.get("enabled", False):
proxy_type = proxy_config.get("type", "http")
proxy_host = proxy_config.get("host", "127.0.0.1")
proxy_port = proxy_config.get("port", 7890)
proxy = f"{proxy_type}://{proxy_host}:{proxy_port}"
ssl_config = self.config.get("ssl", {}) if self.config else {}
ssl_verify = ssl_config.get("verify", True)
ssl_context = None
if not ssl_verify:
ssl_context = ssl.create_default_context()
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
connector = aiohttp.TCPConnector(ssl=ssl_context) if ssl_context else None
semaphore = asyncio.Semaphore(max(1, download_concurrency))
async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
async def fetch_image(url: str) -> Optional[str]:
async with semaphore:
return await self._download_image_with_session(
session,
url,
proxy,
max_retries=download_retries
)
tasks = [fetch_image(url) for url in image_urls[:max_images]]
results = await asyncio.gather(*tasks, return_exceptions=True)
sent = 0
for result in results:
if sent >= max_images:
break
if isinstance(result, str) and result:
await bot.send_image(from_wxid, result)
sent += 1
# 格式化搜索结果
if len(search_batches) == 1:
formatted_results = self._format_search_results(
search_batches[0][1],
include_raw_content=use_raw_content_in_result,
raw_content_max_chars=raw_content_max_chars,
)
else:
sections = []
for idx, (sub_query, sub_result) in enumerate(search_batches, 1):
sections.append(
self._format_search_results(
sub_result,
include_raw_content=use_raw_content_in_result,
raw_content_max_chars=raw_content_max_chars,
section_title=f"【子问题 {idx}{sub_query}",
)
)
formatted_results = "\n\n".join(sections)
if failed_queries:
failed_text = "\n".join([f"- {q}" for q in failed_queries])
formatted_results = (
f"{formatted_results}\n\n"
f"【未检索成功的子问题】\n{failed_text}"
)
logger.success(f"Tavily 搜索完成: {query[:30]}...")
# 返回搜索结果,标记需要 AI 继续处理
return {
"success": True,
"message": formatted_results,
"need_ai_reply": True # 标记需要 AI 基于此结果继续回复
}
except Exception as e:
logger.error(f"Tavily 搜索执行失败: {e}")
import traceback
logger.error(traceback.format_exc())
return {"success": False, "message": f"搜索失败: {str(e)}"}