"""
TavilySearch 联网搜索插件

基于 Tavily API 的联网搜索功能,仅作为 LLM Tool 供 AIChat 调用。
支持多 API Key 轮询,搜索结果返回给 AIChat 的 AI 处理(带上下文和人设)。
"""
import ssl
import tomllib
from pathlib import Path
from typing import List, Optional

import aiohttp
from loguru import logger

from utils.plugin_base import PluginBase


class TavilySearch(PluginBase):
    """Tavily web-search plugin, exposed only as an LLM tool.

    Registers a single ``tavily_web_search`` tool. When AIChat's model calls
    it, the plugin queries the Tavily search API and hands the formatted
    results back to AIChat, whose model writes the final reply (with its own
    context and persona). Configured API keys are used round-robin, and a key
    rejected for auth/quota reasons falls through to the next one.
    """

    description = "Tavily 联网搜索 - 支持多 Key 轮询的搜索工具"
    author = "Assistant"
    version = "1.0.0"

    # Tavily search endpoint.
    _API_URL = "https://api.tavily.com/search"
    # Statuses that point at the key itself (invalid / forbidden / rate-limited /
    # over quota), so retrying with another key can succeed. Any other non-200
    # response (e.g. a malformed request) would fail the same way with every key.
    # NOTE(review): 432/433 are Tavily's plan / pay-as-you-go limit codes --
    # confirm against the current Tavily API documentation.
    _KEY_RETRY_STATUSES = frozenset({401, 403, 429, 432, 433})

    def __init__(self):
        super().__init__()
        self.config = None  # parsed config.toml; None means "not loaded / disabled"
        self.api_keys = []  # usable Tavily API keys, handed out round-robin
        self.current_key_index = 0  # index of the key returned by the next rotation

    async def async_init(self):
        """Load ``config.toml`` next to this file and collect the usable API keys.

        Any failure is logged and leaves ``self.config`` as ``None``, which
        disables the tool (see ``get_llm_tools`` / ``execute_llm_tool``).
        """
        try:
            config_path = Path(__file__).parent / "config.toml"
            if not config_path.exists():
                logger.error(f"TavilySearch 配置文件不存在: {config_path}")
                return

            with open(config_path, "rb") as f:
                self.config = tomllib.load(f)

            raw_keys = self.config["tavily"]["api_keys"]
            # Accept a single key written as a plain string; iterating a str
            # would otherwise treat every character as a separate "key".
            if isinstance(raw_keys, str):
                raw_keys = [raw_keys]
            # Keep non-empty string entries; entries starting with "#" are
            # treated as disabled keys.
            candidates = (k.strip() for k in raw_keys if isinstance(k, str))
            self.api_keys = [k for k in candidates if k and not k.startswith("#")]

            if not self.api_keys:
                logger.warning("TavilySearch: 未配置有效的 API Key")
            else:
                logger.success(f"TavilySearch 已加载,共 {len(self.api_keys)} 个 API Key")

        except Exception as e:
            logger.error(f"TavilySearch 初始化失败: {e}")
            self.config = None

    def _tavily_enabled(self) -> bool:
        """Return True when the config loaded and ``[behavior].enabled`` is truthy.

        A missing ``[behavior]`` section counts as disabled instead of raising
        ``KeyError`` into the caller.
        """
        if not self.config:
            return False
        behavior = self.config.get("behavior") or {}
        return bool(behavior.get("enabled", False))

    def _get_next_api_key(self) -> str:
        """Return the next API key in round-robin order, or "" when none are configured."""
        if not self.api_keys:
            return ""
        key = self.api_keys[self.current_key_index]
        self.current_key_index = (self.current_key_index + 1) % len(self.api_keys)
        return key

    def _build_proxy_url(self) -> Optional[str]:
        """Build the proxy URL from the ``[proxy]`` section, or None when it is disabled."""
        proxy_config = self.config.get("proxy", {})
        if not proxy_config.get("enabled", False):
            return None
        # NOTE: aiohttp's ``proxy=`` argument only handles HTTP proxies
        # natively; a "socks5" type would need a third-party connector.
        proxy_type = proxy_config.get("type", "http")
        proxy_host = proxy_config.get("host", "127.0.0.1")
        proxy_port = proxy_config.get("port", 7890)
        return f"{proxy_type}://{proxy_host}:{proxy_port}"

    def _build_ssl_connector(self) -> Optional[aiohttp.TCPConnector]:
        """Return a connector that skips TLS verification when ``[ssl].verify`` is false.

        Returns None (aiohttp's default, verifying connector) otherwise.
        """
        ssl_config = self.config.get("ssl", {})
        if ssl_config.get("verify", True):
            return None
        # Verification was explicitly disabled in config: accept any certificate.
        ssl_context = ssl.create_default_context()
        ssl_context.check_hostname = False
        ssl_context.verify_mode = ssl.CERT_NONE
        return aiohttp.TCPConnector(ssl=ssl_context)

    async def _search_tavily(self, query: str) -> Optional[dict]:
        """Send *query* to the Tavily search API.

        Keys are taken round-robin. If Tavily rejects a key with one of
        ``_KEY_RETRY_STATUSES``, the next key is tried, at most once per
        configured key, within one HTTP session.

        Returns:
            The decoded JSON body on HTTP 200, otherwise None. Errors are
            logged, not raised.
        """
        if not self.api_keys:
            logger.error("没有可用的 Tavily API Key")
            return None

        tavily_config = self.config["tavily"]
        # Everything except the key is identical for every attempt.
        base_payload = {
            "query": query,
            "search_depth": tavily_config.get("search_depth", "basic"),
            "max_results": tavily_config.get("max_results", 5),
            "include_raw_content": tavily_config.get("include_raw_content", False),
            "include_images": tavily_config.get("include_images", False),
        }

        try:
            proxy = self._build_proxy_url()
            connector = self._build_ssl_connector()
            timeout = aiohttp.ClientTimeout(total=30)

            async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
                for _ in range(len(self.api_keys)):
                    payload = {"api_key": self._get_next_api_key(), **base_payload}
                    async with session.post(self._API_URL, json=payload, proxy=proxy) as resp:
                        if resp.status == 200:
                            result = await resp.json()
                            logger.info(f"Tavily 搜索成功: {query[:30]}...")
                            # The full response can be large; keep it out of INFO logs.
                            logger.debug(f"Tavily 原始返回: {result}")
                            return result

                        error_text = await resp.text()
                        logger.error(f"Tavily API 错误: {resp.status}, {error_text}")
                        if resp.status not in self._KEY_RETRY_STATUSES:
                            # Not a key problem: another key would fail the same way.
                            return None
            # Every key was rejected for key-specific reasons.
            return None

        except Exception as e:
            logger.error(f"Tavily 搜索失败: {e}")
            return None

    def _format_search_results(self, results: dict) -> str:
        """Render a Tavily response as numbered plain-text entries for the LLM.

        Returns "未找到相关搜索结果" when *results* carries no usable entries,
        including when the ``results`` list is present but empty.
        """
        items = results.get("results") if isinstance(results, dict) else None
        entries = [item for item in (items or []) if isinstance(item, dict)]
        if not entries:
            return "未找到相关搜索结果"

        formatted = []
        for i, item in enumerate(entries, 1):
            # ``or`` also covers keys that are present but null/empty.
            title = item.get("title") or "无标题"
            content = item.get("content") or ""
            url = item.get("url") or ""
            formatted.append(f"【结果 {i}】\n标题: {title}\n内容: {content}\n来源: {url}\n")

        return "\n".join(formatted)

    def get_llm_tools(self) -> List[dict]:
        """Return the OpenAI-style function-tool definitions, or [] when disabled."""
        if not self._tavily_enabled():
            return []

        return [
            {
                "type": "function",
                "function": {
                    "name": "tavily_web_search",
                    "description": "仅当用户明确要求“联网搜索/查资料/最新信息/来源/权威说法”或需要事实核实时调用;不要在闲聊中触发。",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "query": {
                                "type": "string",
                                "description": "搜索关键词或问题,建议使用简洁明确的搜索词"
                            }
                        },
                        "required": ["query"]
                    }
                }
            }
        ]

    async def execute_llm_tool(self, tool_name: str, arguments: dict, bot, from_wxid: str) -> Optional[dict]:
        """Execute an LLM tool call.

        This only performs the search. The formatted results go back to
        AIChat, whose model writes the reply with its own context and persona.

        Args:
            tool_name: Name of the tool the model called.
            arguments: Tool arguments from the model; ``query`` is read.
            bot: Unused here; part of the plugin tool-call interface.
            from_wxid: Unused here; part of the plugin tool-call interface.

        Returns:
            None if *tool_name* is not handled by this plugin. Otherwise a dict
            with ``success`` and ``message``; a successful search also sets
            ``need_ai_reply`` to True.
        """
        if tool_name != "tavily_web_search":
            return None

        if not self._tavily_enabled():
            return {"success": False, "message": "TavilySearch 插件未启用"}

        if not self.api_keys:
            return {"success": False, "message": "未配置 Tavily API Key"}

        # Model-supplied arguments are untrusted: tolerate a non-dict payload or a
        # non-string query, and ignore surrounding whitespace.
        raw_query = arguments.get("query", "") if isinstance(arguments, dict) else ""
        query = "" if raw_query is None else str(raw_query).strip()
        if not query:
            return {"success": False, "message": "搜索关键词不能为空"}

        try:
            logger.info(f"开始 Tavily 搜索: {query}")

            search_results = await self._search_tavily(query)
            if not search_results:
                return {"success": False, "message": "搜索失败,请稍后重试"}

            formatted_results = self._format_search_results(search_results)

            logger.success(f"Tavily 搜索完成: {query[:30]}...")

            return {
                "success": True,
                "message": formatted_results,
                "need_ai_reply": True  # ask AIChat's model to answer based on these results
            }

        except Exception as e:
            # logger.exception records the traceback along with the message.
            logger.exception(f"Tavily 搜索执行失败: {e}")
            return {"success": False, "message": f"搜索失败: {str(e)}"}