feat:新增意图路由
This commit is contained in:
@@ -53,6 +53,7 @@ class AIChat(PluginBase):
|
|||||||
self._chatroom_member_cache = {} # {chatroom_id: (ts, {wxid: display_name})}
|
self._chatroom_member_cache = {} # {chatroom_id: (ts, {wxid: display_name})}
|
||||||
self._chatroom_member_cache_locks = {} # {chatroom_id: asyncio.Lock}
|
self._chatroom_member_cache_locks = {} # {chatroom_id: asyncio.Lock}
|
||||||
self._chatroom_member_cache_ttl_seconds = 3600 # 群名片缓存1小时,减少协议 API 调用
|
self._chatroom_member_cache_ttl_seconds = 3600 # 群名片缓存1小时,减少协议 API 调用
|
||||||
|
self._intent_cache = {} # {normalized_text: (ts, intent)}
|
||||||
|
|
||||||
async def async_init(self):
|
async def async_init(self):
|
||||||
"""插件异步初始化"""
|
"""插件异步初始化"""
|
||||||
@@ -1320,6 +1321,271 @@ class AIChat(PluginBase):
|
|||||||
text = text.rsplit(marker, 1)[-1].strip()
|
text = text.rsplit(marker, 1)[-1].strip()
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
def _looks_like_info_query(self, text: str) -> bool:
|
||||||
|
t = str(text or "").strip().lower()
|
||||||
|
if not t:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 太短的消息不值得额外走一轮分类
|
||||||
|
if len(t) < 6:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 疑问/求评价/求推荐类
|
||||||
|
if any(x in t for x in ("?", "?")):
|
||||||
|
return True
|
||||||
|
if re.search(r"(什么|咋|怎么|如何|为啥|为什么|哪|哪里|哪个|多少|推荐|值不值得|值得吗|好不好|靠谱吗|评价|口碑|怎么样|如何评价|近况|最新|最近)", t):
|
||||||
|
return True
|
||||||
|
|
||||||
|
# 明确的实体/对象询问(公会/游戏/公司/项目等)
|
||||||
|
if re.search(r"(公会|战队|服务器|区服|游戏|公司|品牌|产品|软件|插件|项目|平台|up主|主播|作者|电影|电视剧|小说)", t) and len(t) >= 8:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _intent_to_allowed_tool_names(self, intent: str) -> set[str]:
|
||||||
|
intent = str(intent or "").strip().lower()
|
||||||
|
mapping = {
|
||||||
|
"search": {"tavily_web_search", "web_search"},
|
||||||
|
"draw": {
|
||||||
|
"nano_ai_image_generation",
|
||||||
|
"flow2_ai_image_generation",
|
||||||
|
"jimeng_ai_image_generation",
|
||||||
|
"kiira2_ai_image_generation",
|
||||||
|
"generate_image",
|
||||||
|
},
|
||||||
|
"weather": {"query_weather"},
|
||||||
|
"register_city": {"register_city"},
|
||||||
|
"signin": {"user_signin"},
|
||||||
|
"profile": {"check_profile"},
|
||||||
|
"news": {"get_daily_news"},
|
||||||
|
"music": {"search_music"},
|
||||||
|
"playlet": {"search_playlet"},
|
||||||
|
"kfc": {"get_kfc"},
|
||||||
|
"fabing": {"get_fabing"},
|
||||||
|
"chat": set(),
|
||||||
|
}
|
||||||
|
return set(mapping.get(intent, set()))
|
||||||
|
|
||||||
|
def _should_run_intent_router(self, intent_text: str) -> bool:
|
||||||
|
cfg = (self.config or {}).get("intent_router", {})
|
||||||
|
if not cfg.get("enabled", False):
|
||||||
|
return False
|
||||||
|
|
||||||
|
mode = str(cfg.get("mode", "hybrid") or "hybrid").strip().lower()
|
||||||
|
if mode == "always":
|
||||||
|
return True
|
||||||
|
|
||||||
|
# hybrid:只在像“信息查询/求评价”的问题时触发,避免闲聊额外增加延迟
|
||||||
|
return self._looks_like_info_query(intent_text)
|
||||||
|
|
||||||
|
def _extract_intent_from_llm_output(self, content: str) -> str:
|
||||||
|
raw = str(content or "").strip()
|
||||||
|
if not raw:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# 尝试直接 JSON
|
||||||
|
try:
|
||||||
|
obj = json.loads(raw)
|
||||||
|
if isinstance(obj, dict):
|
||||||
|
intent = obj.get("intent") or obj.get("class_name") or ""
|
||||||
|
return str(intent or "").strip().lower()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 尝试截取 JSON 片段
|
||||||
|
try:
|
||||||
|
m = re.search(r"\{[\s\S]*\}", raw)
|
||||||
|
if m:
|
||||||
|
obj = json.loads(m.group(0))
|
||||||
|
if isinstance(obj, dict):
|
||||||
|
intent = obj.get("intent") or obj.get("class_name") or ""
|
||||||
|
return str(intent or "").strip().lower()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 兜底:纯文本标签
|
||||||
|
token = re.sub(r"[^a-zA-Z_]+", "", raw).lower()
|
||||||
|
return token
|
||||||
|
|
||||||
|
async def _classify_intent_with_llm(self, intent_text: str) -> str:
|
||||||
|
cfg = (self.config or {}).get("intent_router", {})
|
||||||
|
api_cfg = (self.config or {}).get("api", {})
|
||||||
|
|
||||||
|
text = str(intent_text or "").strip()
|
||||||
|
if not text:
|
||||||
|
return "chat"
|
||||||
|
|
||||||
|
cache_ttl = cfg.get("cache_ttl_seconds", 120)
|
||||||
|
try:
|
||||||
|
cache_ttl = float(cache_ttl or 0)
|
||||||
|
except Exception:
|
||||||
|
cache_ttl = 0
|
||||||
|
|
||||||
|
cache_key = re.sub(r"\s+", " ", text).strip().lower()
|
||||||
|
if cache_ttl > 0:
|
||||||
|
cached = self._intent_cache.get(cache_key)
|
||||||
|
if cached:
|
||||||
|
ts, intent = cached
|
||||||
|
if (time.time() - float(ts or 0)) <= cache_ttl and intent:
|
||||||
|
return str(intent).strip().lower()
|
||||||
|
|
||||||
|
model = cfg.get("model") or api_cfg.get("model")
|
||||||
|
url = api_cfg.get("url")
|
||||||
|
api_key = api_cfg.get("api_key")
|
||||||
|
timeout_s = cfg.get("timeout", 15)
|
||||||
|
try:
|
||||||
|
timeout_s = float(timeout_s or 15)
|
||||||
|
except Exception:
|
||||||
|
timeout_s = 15
|
||||||
|
|
||||||
|
system_prompt = (
|
||||||
|
"你是一个聊天机器人“意图分类器”,只负责判断用户当前这句话最可能的意图。\n"
|
||||||
|
"只允许返回 JSON,格式固定:{\"intent\":\"<label>\"},不要输出任何多余文字。\n"
|
||||||
|
"label 只能是以下之一:\n"
|
||||||
|
"- chat:普通聊天/主观闲聊,不需要工具。\n"
|
||||||
|
"- search:需要联网检索/查证事实/口碑/背景/最新信息(游戏/公会/公司/插件/项目等)。\n"
|
||||||
|
"- draw:用户要生成/绘制图片。\n"
|
||||||
|
"- weather:用户要查询天气/气温/预报/空气质量。\n"
|
||||||
|
"- register_city:用户要注册/设置/修改默认城市。\n"
|
||||||
|
"- signin:用户要签到。\n"
|
||||||
|
"- profile:用户要查积分/个人信息/资料。\n"
|
||||||
|
"- news:用户要每日新闻/早报。\n"
|
||||||
|
"- music:用户要搜歌/点歌。\n"
|
||||||
|
"- playlet:用户要搜短剧。\n"
|
||||||
|
"- kfc:用户要疯狂星期四/KFC 文案。\n"
|
||||||
|
"- fabing:用户明确要“发病文学/发病文/发病语录”。\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"model": model,
|
||||||
|
"messages": [
|
||||||
|
{"role": "system", "content": system_prompt},
|
||||||
|
{"role": "user", "content": text},
|
||||||
|
],
|
||||||
|
"temperature": 0,
|
||||||
|
"max_tokens": 80,
|
||||||
|
"stream": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": f"Bearer {api_key}",
|
||||||
|
}
|
||||||
|
|
||||||
|
connector = None
|
||||||
|
proxy_config = (self.config or {}).get("proxy", {})
|
||||||
|
if proxy_config.get("enabled", False):
|
||||||
|
proxy_type = proxy_config.get("type", "socks5").upper()
|
||||||
|
proxy_host = proxy_config.get("host", "127.0.0.1")
|
||||||
|
proxy_port = proxy_config.get("port", 7890)
|
||||||
|
proxy_username = proxy_config.get("username")
|
||||||
|
proxy_password = proxy_config.get("password")
|
||||||
|
|
||||||
|
if proxy_username and proxy_password:
|
||||||
|
proxy_url = f"{proxy_type}://{proxy_username}:{proxy_password}@{proxy_host}:{proxy_port}"
|
||||||
|
else:
|
||||||
|
proxy_url = f"{proxy_type}://{proxy_host}:{proxy_port}"
|
||||||
|
|
||||||
|
if PROXY_SUPPORT:
|
||||||
|
try:
|
||||||
|
connector = ProxyConnector.from_url(proxy_url)
|
||||||
|
except Exception:
|
||||||
|
connector = None
|
||||||
|
|
||||||
|
timeout = aiohttp.ClientTimeout(total=timeout_s)
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
|
||||||
|
async with session.post(url, json=payload, headers=headers) as resp:
|
||||||
|
if resp.status != 200:
|
||||||
|
return "chat"
|
||||||
|
data = await resp.json()
|
||||||
|
except Exception:
|
||||||
|
return "chat"
|
||||||
|
|
||||||
|
content = (
|
||||||
|
(((data or {}).get("choices") or [{}])[0].get("message") or {}).get("content")
|
||||||
|
if isinstance(data, dict)
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
intent = self._extract_intent_from_llm_output(content)
|
||||||
|
|
||||||
|
valid = {
|
||||||
|
"chat",
|
||||||
|
"search",
|
||||||
|
"draw",
|
||||||
|
"weather",
|
||||||
|
"register_city",
|
||||||
|
"signin",
|
||||||
|
"profile",
|
||||||
|
"news",
|
||||||
|
"music",
|
||||||
|
"playlet",
|
||||||
|
"kfc",
|
||||||
|
"fabing",
|
||||||
|
}
|
||||||
|
if intent not in valid:
|
||||||
|
intent = "chat"
|
||||||
|
|
||||||
|
if cache_ttl > 0:
|
||||||
|
try:
|
||||||
|
self._intent_cache[cache_key] = (time.time(), intent)
|
||||||
|
if len(self._intent_cache) > 500:
|
||||||
|
# 简单裁剪:清空最旧的 100 条(不做复杂 LRU)
|
||||||
|
items = sorted(self._intent_cache.items(), key=lambda kv: kv[1][0])[:100]
|
||||||
|
for k, _v in items:
|
||||||
|
self._intent_cache.pop(k, None)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return intent
|
||||||
|
|
||||||
|
async def _select_tools_for_message_async(self, tools: list, *, user_message: str, tool_query: str | None = None) -> list:
|
||||||
|
"""
|
||||||
|
工具选择(支持意图路由):
|
||||||
|
1) 先走现有 smart_select 规则(快)
|
||||||
|
2) 规则未命中且像信息查询时,走一次轻量 LLM 意图分类(慢但更准)
|
||||||
|
"""
|
||||||
|
tools_config = (self.config or {}).get("tools", {})
|
||||||
|
if not tools_config.get("smart_select", False):
|
||||||
|
return tools
|
||||||
|
|
||||||
|
selected = self._select_tools_for_message(tools, user_message=user_message, tool_query=tool_query)
|
||||||
|
if selected:
|
||||||
|
return selected
|
||||||
|
|
||||||
|
intent_text = self._extract_tool_intent_text(user_message, tool_query=tool_query)
|
||||||
|
if not intent_text:
|
||||||
|
return []
|
||||||
|
|
||||||
|
if not self._should_run_intent_router(intent_text):
|
||||||
|
return []
|
||||||
|
|
||||||
|
intent = await self._classify_intent_with_llm(intent_text)
|
||||||
|
allow = self._intent_to_allowed_tool_names(intent)
|
||||||
|
if not allow:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 对“容易误触”的工具再做一次本地硬约束,避免分类器误判导致执行敏感动作
|
||||||
|
t = intent_text.lower()
|
||||||
|
if "register_city" in allow and not re.search(r"(注册|设置|更新|更换|修改|绑定|默认).{0,6}城市|城市.{0,6}(注册|设置|更新|更换|修改|绑定|默认)", t):
|
||||||
|
allow.discard("register_city")
|
||||||
|
if "user_signin" in allow and not re.search(r"(用户签到|签到|签个到)", t):
|
||||||
|
allow.discard("user_signin")
|
||||||
|
if "check_profile" in allow and not re.search(r"(个人信息|我的信息|我的积分|查积分|积分多少|连续签到|连签|我的资料)", t):
|
||||||
|
allow.discard("check_profile")
|
||||||
|
if "get_fabing" in allow and not re.search(r"(发病文学|犯病文学|发病文|犯病文|发病语录|犯病语录|发病一下|犯病一下)", t):
|
||||||
|
allow.discard("get_fabing")
|
||||||
|
|
||||||
|
if not allow:
|
||||||
|
return []
|
||||||
|
|
||||||
|
result = []
|
||||||
|
for tool in tools or []:
|
||||||
|
name = tool.get("function", {}).get("name", "")
|
||||||
|
if name and name in allow:
|
||||||
|
result.append(tool)
|
||||||
|
return result
|
||||||
|
|
||||||
def _select_tools_for_message(self, tools: list, *, user_message: str, tool_query: str | None = None) -> list:
|
def _select_tools_for_message(self, tools: list, *, user_message: str, tool_query: str | None = None) -> list:
|
||||||
tools_config = (self.config or {}).get("tools", {})
|
tools_config = (self.config or {}).get("tools", {})
|
||||||
if not tools_config.get("smart_select", False):
|
if not tools_config.get("smart_select", False):
|
||||||
@@ -1965,7 +2231,7 @@ class AIChat(PluginBase):
|
|||||||
|
|
||||||
# 收集工具
|
# 收集工具
|
||||||
all_tools = self._collect_tools()
|
all_tools = self._collect_tools()
|
||||||
tools = self._select_tools_for_message(all_tools, user_message=user_message, tool_query=tool_query)
|
tools = await self._select_tools_for_message_async(all_tools, user_message=user_message, tool_query=tool_query)
|
||||||
logger.info(f"收集到 {len(all_tools)} 个工具函数,本次启用 {len(tools)} 个")
|
logger.info(f"收集到 {len(all_tools)} 个工具函数,本次启用 {len(tools)} 个")
|
||||||
if tools:
|
if tools:
|
||||||
tool_names = [t["function"]["name"] for t in tools]
|
tool_names = [t["function"]["name"] for t in tools]
|
||||||
@@ -3790,7 +4056,7 @@ class AIChat(PluginBase):
|
|||||||
"""调用AI API(带图片)"""
|
"""调用AI API(带图片)"""
|
||||||
api_config = self.config["api"]
|
api_config = self.config["api"]
|
||||||
all_tools = self._collect_tools()
|
all_tools = self._collect_tools()
|
||||||
tools = self._select_tools_for_message(all_tools, user_message=user_message, tool_query=tool_query)
|
tools = await self._select_tools_for_message_async(all_tools, user_message=user_message, tool_query=tool_query)
|
||||||
logger.info(f"[图片] 收集到 {len(all_tools)} 个工具函数,本次启用 {len(tools)} 个")
|
logger.info(f"[图片] 收集到 {len(all_tools)} 个工具函数,本次启用 {len(tools)} 个")
|
||||||
if tools:
|
if tools:
|
||||||
tool_names = [t["function"]["name"] for t in tools]
|
tool_names = [t["function"]["name"] for t in tools]
|
||||||
|
|||||||
Reference in New Issue
Block a user