feat:新增意图路由
This commit is contained in:
@@ -53,6 +53,7 @@ class AIChat(PluginBase):
|
||||
self._chatroom_member_cache = {} # {chatroom_id: (ts, {wxid: display_name})}
|
||||
self._chatroom_member_cache_locks = {} # {chatroom_id: asyncio.Lock}
|
||||
self._chatroom_member_cache_ttl_seconds = 3600 # 群名片缓存1小时,减少协议 API 调用
|
||||
self._intent_cache = {} # {normalized_text: (ts, intent)}
|
||||
|
||||
async def async_init(self):
|
||||
"""插件异步初始化"""
|
||||
@@ -1320,6 +1321,271 @@ class AIChat(PluginBase):
|
||||
text = text.rsplit(marker, 1)[-1].strip()
|
||||
return text
|
||||
|
||||
def _looks_like_info_query(self, text: str) -> bool:
|
||||
t = str(text or "").strip().lower()
|
||||
if not t:
|
||||
return False
|
||||
|
||||
# 太短的消息不值得额外走一轮分类
|
||||
if len(t) < 6:
|
||||
return False
|
||||
|
||||
# 疑问/求评价/求推荐类
|
||||
if any(x in t for x in ("?", "?")):
|
||||
return True
|
||||
if re.search(r"(什么|咋|怎么|如何|为啥|为什么|哪|哪里|哪个|多少|推荐|值不值得|值得吗|好不好|靠谱吗|评价|口碑|怎么样|如何评价|近况|最新|最近)", t):
|
||||
return True
|
||||
|
||||
# 明确的实体/对象询问(公会/游戏/公司/项目等)
|
||||
if re.search(r"(公会|战队|服务器|区服|游戏|公司|品牌|产品|软件|插件|项目|平台|up主|主播|作者|电影|电视剧|小说)", t) and len(t) >= 8:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _intent_to_allowed_tool_names(self, intent: str) -> set[str]:
|
||||
intent = str(intent or "").strip().lower()
|
||||
mapping = {
|
||||
"search": {"tavily_web_search", "web_search"},
|
||||
"draw": {
|
||||
"nano_ai_image_generation",
|
||||
"flow2_ai_image_generation",
|
||||
"jimeng_ai_image_generation",
|
||||
"kiira2_ai_image_generation",
|
||||
"generate_image",
|
||||
},
|
||||
"weather": {"query_weather"},
|
||||
"register_city": {"register_city"},
|
||||
"signin": {"user_signin"},
|
||||
"profile": {"check_profile"},
|
||||
"news": {"get_daily_news"},
|
||||
"music": {"search_music"},
|
||||
"playlet": {"search_playlet"},
|
||||
"kfc": {"get_kfc"},
|
||||
"fabing": {"get_fabing"},
|
||||
"chat": set(),
|
||||
}
|
||||
return set(mapping.get(intent, set()))
|
||||
|
||||
def _should_run_intent_router(self, intent_text: str) -> bool:
|
||||
cfg = (self.config or {}).get("intent_router", {})
|
||||
if not cfg.get("enabled", False):
|
||||
return False
|
||||
|
||||
mode = str(cfg.get("mode", "hybrid") or "hybrid").strip().lower()
|
||||
if mode == "always":
|
||||
return True
|
||||
|
||||
# hybrid:只在像“信息查询/求评价”的问题时触发,避免闲聊额外增加延迟
|
||||
return self._looks_like_info_query(intent_text)
|
||||
|
||||
def _extract_intent_from_llm_output(self, content: str) -> str:
|
||||
raw = str(content or "").strip()
|
||||
if not raw:
|
||||
return ""
|
||||
|
||||
# 尝试直接 JSON
|
||||
try:
|
||||
obj = json.loads(raw)
|
||||
if isinstance(obj, dict):
|
||||
intent = obj.get("intent") or obj.get("class_name") or ""
|
||||
return str(intent or "").strip().lower()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 尝试截取 JSON 片段
|
||||
try:
|
||||
m = re.search(r"\{[\s\S]*\}", raw)
|
||||
if m:
|
||||
obj = json.loads(m.group(0))
|
||||
if isinstance(obj, dict):
|
||||
intent = obj.get("intent") or obj.get("class_name") or ""
|
||||
return str(intent or "").strip().lower()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 兜底:纯文本标签
|
||||
token = re.sub(r"[^a-zA-Z_]+", "", raw).lower()
|
||||
return token
|
||||
|
||||
async def _classify_intent_with_llm(self, intent_text: str) -> str:
|
||||
cfg = (self.config or {}).get("intent_router", {})
|
||||
api_cfg = (self.config or {}).get("api", {})
|
||||
|
||||
text = str(intent_text or "").strip()
|
||||
if not text:
|
||||
return "chat"
|
||||
|
||||
cache_ttl = cfg.get("cache_ttl_seconds", 120)
|
||||
try:
|
||||
cache_ttl = float(cache_ttl or 0)
|
||||
except Exception:
|
||||
cache_ttl = 0
|
||||
|
||||
cache_key = re.sub(r"\s+", " ", text).strip().lower()
|
||||
if cache_ttl > 0:
|
||||
cached = self._intent_cache.get(cache_key)
|
||||
if cached:
|
||||
ts, intent = cached
|
||||
if (time.time() - float(ts or 0)) <= cache_ttl and intent:
|
||||
return str(intent).strip().lower()
|
||||
|
||||
model = cfg.get("model") or api_cfg.get("model")
|
||||
url = api_cfg.get("url")
|
||||
api_key = api_cfg.get("api_key")
|
||||
timeout_s = cfg.get("timeout", 15)
|
||||
try:
|
||||
timeout_s = float(timeout_s or 15)
|
||||
except Exception:
|
||||
timeout_s = 15
|
||||
|
||||
system_prompt = (
|
||||
"你是一个聊天机器人“意图分类器”,只负责判断用户当前这句话最可能的意图。\n"
|
||||
"只允许返回 JSON,格式固定:{\"intent\":\"<label>\"},不要输出任何多余文字。\n"
|
||||
"label 只能是以下之一:\n"
|
||||
"- chat:普通聊天/主观闲聊,不需要工具。\n"
|
||||
"- search:需要联网检索/查证事实/口碑/背景/最新信息(游戏/公会/公司/插件/项目等)。\n"
|
||||
"- draw:用户要生成/绘制图片。\n"
|
||||
"- weather:用户要查询天气/气温/预报/空气质量。\n"
|
||||
"- register_city:用户要注册/设置/修改默认城市。\n"
|
||||
"- signin:用户要签到。\n"
|
||||
"- profile:用户要查积分/个人信息/资料。\n"
|
||||
"- news:用户要每日新闻/早报。\n"
|
||||
"- music:用户要搜歌/点歌。\n"
|
||||
"- playlet:用户要搜短剧。\n"
|
||||
"- kfc:用户要疯狂星期四/KFC 文案。\n"
|
||||
"- fabing:用户明确要“发病文学/发病文/发病语录”。\n"
|
||||
)
|
||||
|
||||
payload = {
|
||||
"model": model,
|
||||
"messages": [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": text},
|
||||
],
|
||||
"temperature": 0,
|
||||
"max_tokens": 80,
|
||||
"stream": False,
|
||||
}
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
}
|
||||
|
||||
connector = None
|
||||
proxy_config = (self.config or {}).get("proxy", {})
|
||||
if proxy_config.get("enabled", False):
|
||||
proxy_type = proxy_config.get("type", "socks5").upper()
|
||||
proxy_host = proxy_config.get("host", "127.0.0.1")
|
||||
proxy_port = proxy_config.get("port", 7890)
|
||||
proxy_username = proxy_config.get("username")
|
||||
proxy_password = proxy_config.get("password")
|
||||
|
||||
if proxy_username and proxy_password:
|
||||
proxy_url = f"{proxy_type}://{proxy_username}:{proxy_password}@{proxy_host}:{proxy_port}"
|
||||
else:
|
||||
proxy_url = f"{proxy_type}://{proxy_host}:{proxy_port}"
|
||||
|
||||
if PROXY_SUPPORT:
|
||||
try:
|
||||
connector = ProxyConnector.from_url(proxy_url)
|
||||
except Exception:
|
||||
connector = None
|
||||
|
||||
timeout = aiohttp.ClientTimeout(total=timeout_s)
|
||||
try:
|
||||
async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
|
||||
async with session.post(url, json=payload, headers=headers) as resp:
|
||||
if resp.status != 200:
|
||||
return "chat"
|
||||
data = await resp.json()
|
||||
except Exception:
|
||||
return "chat"
|
||||
|
||||
content = (
|
||||
(((data or {}).get("choices") or [{}])[0].get("message") or {}).get("content")
|
||||
if isinstance(data, dict)
|
||||
else ""
|
||||
)
|
||||
intent = self._extract_intent_from_llm_output(content)
|
||||
|
||||
valid = {
|
||||
"chat",
|
||||
"search",
|
||||
"draw",
|
||||
"weather",
|
||||
"register_city",
|
||||
"signin",
|
||||
"profile",
|
||||
"news",
|
||||
"music",
|
||||
"playlet",
|
||||
"kfc",
|
||||
"fabing",
|
||||
}
|
||||
if intent not in valid:
|
||||
intent = "chat"
|
||||
|
||||
if cache_ttl > 0:
|
||||
try:
|
||||
self._intent_cache[cache_key] = (time.time(), intent)
|
||||
if len(self._intent_cache) > 500:
|
||||
# 简单裁剪:清空最旧的 100 条(不做复杂 LRU)
|
||||
items = sorted(self._intent_cache.items(), key=lambda kv: kv[1][0])[:100]
|
||||
for k, _v in items:
|
||||
self._intent_cache.pop(k, None)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return intent
|
||||
|
||||
async def _select_tools_for_message_async(self, tools: list, *, user_message: str, tool_query: str | None = None) -> list:
|
||||
"""
|
||||
工具选择(支持意图路由):
|
||||
1) 先走现有 smart_select 规则(快)
|
||||
2) 规则未命中且像信息查询时,走一次轻量 LLM 意图分类(慢但更准)
|
||||
"""
|
||||
tools_config = (self.config or {}).get("tools", {})
|
||||
if not tools_config.get("smart_select", False):
|
||||
return tools
|
||||
|
||||
selected = self._select_tools_for_message(tools, user_message=user_message, tool_query=tool_query)
|
||||
if selected:
|
||||
return selected
|
||||
|
||||
intent_text = self._extract_tool_intent_text(user_message, tool_query=tool_query)
|
||||
if not intent_text:
|
||||
return []
|
||||
|
||||
if not self._should_run_intent_router(intent_text):
|
||||
return []
|
||||
|
||||
intent = await self._classify_intent_with_llm(intent_text)
|
||||
allow = self._intent_to_allowed_tool_names(intent)
|
||||
if not allow:
|
||||
return []
|
||||
|
||||
# 对“容易误触”的工具再做一次本地硬约束,避免分类器误判导致执行敏感动作
|
||||
t = intent_text.lower()
|
||||
if "register_city" in allow and not re.search(r"(注册|设置|更新|更换|修改|绑定|默认).{0,6}城市|城市.{0,6}(注册|设置|更新|更换|修改|绑定|默认)", t):
|
||||
allow.discard("register_city")
|
||||
if "user_signin" in allow and not re.search(r"(用户签到|签到|签个到)", t):
|
||||
allow.discard("user_signin")
|
||||
if "check_profile" in allow and not re.search(r"(个人信息|我的信息|我的积分|查积分|积分多少|连续签到|连签|我的资料)", t):
|
||||
allow.discard("check_profile")
|
||||
if "get_fabing" in allow and not re.search(r"(发病文学|犯病文学|发病文|犯病文|发病语录|犯病语录|发病一下|犯病一下)", t):
|
||||
allow.discard("get_fabing")
|
||||
|
||||
if not allow:
|
||||
return []
|
||||
|
||||
result = []
|
||||
for tool in tools or []:
|
||||
name = tool.get("function", {}).get("name", "")
|
||||
if name and name in allow:
|
||||
result.append(tool)
|
||||
return result
|
||||
|
||||
def _select_tools_for_message(self, tools: list, *, user_message: str, tool_query: str | None = None) -> list:
|
||||
tools_config = (self.config or {}).get("tools", {})
|
||||
if not tools_config.get("smart_select", False):
|
||||
@@ -1965,7 +2231,7 @@ class AIChat(PluginBase):
|
||||
|
||||
# 收集工具
|
||||
all_tools = self._collect_tools()
|
||||
tools = self._select_tools_for_message(all_tools, user_message=user_message, tool_query=tool_query)
|
||||
tools = await self._select_tools_for_message_async(all_tools, user_message=user_message, tool_query=tool_query)
|
||||
logger.info(f"收集到 {len(all_tools)} 个工具函数,本次启用 {len(tools)} 个")
|
||||
if tools:
|
||||
tool_names = [t["function"]["name"] for t in tools]
|
||||
@@ -3790,7 +4056,7 @@ class AIChat(PluginBase):
|
||||
"""调用AI API(带图片)"""
|
||||
api_config = self.config["api"]
|
||||
all_tools = self._collect_tools()
|
||||
tools = self._select_tools_for_message(all_tools, user_message=user_message, tool_query=tool_query)
|
||||
tools = await self._select_tools_for_message_async(all_tools, user_message=user_message, tool_query=tool_query)
|
||||
logger.info(f"[图片] 收集到 {len(all_tools)} 个工具函数,本次启用 {len(tools)} 个")
|
||||
if tools:
|
||||
tool_names = [t["function"]["name"] for t in tools]
|
||||
|
||||
Reference in New Issue
Block a user