feat:新增意图路由

This commit is contained in:
2025-12-30 10:52:34 +08:00
parent 9b6173be76
commit b44d1589d1

View File

@@ -53,6 +53,7 @@ class AIChat(PluginBase):
self._chatroom_member_cache = {} # {chatroom_id: (ts, {wxid: display_name})}
self._chatroom_member_cache_locks = {} # {chatroom_id: asyncio.Lock}
self._chatroom_member_cache_ttl_seconds = 3600 # 群名片缓存1小时减少协议 API 调用
self._intent_cache = {} # {normalized_text: (ts, intent)}
async def async_init(self):
"""插件异步初始化"""
@@ -1320,6 +1321,271 @@ class AIChat(PluginBase):
text = text.rsplit(marker, 1)[-1].strip()
return text
def _looks_like_info_query(self, text: str) -> bool:
t = str(text or "").strip().lower()
if not t:
return False
# 太短的消息不值得额外走一轮分类
if len(t) < 6:
return False
# 疑问/求评价/求推荐类
if any(x in t for x in ("?", "")):
return True
if re.search(r"(什么|咋|怎么|如何|为啥|为什么|哪|哪里|哪个|多少|推荐|值不值得|值得吗|好不好|靠谱吗|评价|口碑|怎么样|如何评价|近况|最新|最近)", t):
return True
# 明确的实体/对象询问(公会/游戏/公司/项目等)
if re.search(r"(公会|战队|服务器|区服|游戏|公司|品牌|产品|软件|插件|项目|平台|up主|主播|作者|电影|电视剧|小说)", t) and len(t) >= 8:
return True
return False
def _intent_to_allowed_tool_names(self, intent: str) -> set[str]:
intent = str(intent or "").strip().lower()
mapping = {
"search": {"tavily_web_search", "web_search"},
"draw": {
"nano_ai_image_generation",
"flow2_ai_image_generation",
"jimeng_ai_image_generation",
"kiira2_ai_image_generation",
"generate_image",
},
"weather": {"query_weather"},
"register_city": {"register_city"},
"signin": {"user_signin"},
"profile": {"check_profile"},
"news": {"get_daily_news"},
"music": {"search_music"},
"playlet": {"search_playlet"},
"kfc": {"get_kfc"},
"fabing": {"get_fabing"},
"chat": set(),
}
return set(mapping.get(intent, set()))
def _should_run_intent_router(self, intent_text: str) -> bool:
cfg = (self.config or {}).get("intent_router", {})
if not cfg.get("enabled", False):
return False
mode = str(cfg.get("mode", "hybrid") or "hybrid").strip().lower()
if mode == "always":
return True
# hybrid只在像“信息查询/求评价”的问题时触发,避免闲聊额外增加延迟
return self._looks_like_info_query(intent_text)
def _extract_intent_from_llm_output(self, content: str) -> str:
raw = str(content or "").strip()
if not raw:
return ""
# 尝试直接 JSON
try:
obj = json.loads(raw)
if isinstance(obj, dict):
intent = obj.get("intent") or obj.get("class_name") or ""
return str(intent or "").strip().lower()
except Exception:
pass
# 尝试截取 JSON 片段
try:
m = re.search(r"\{[\s\S]*\}", raw)
if m:
obj = json.loads(m.group(0))
if isinstance(obj, dict):
intent = obj.get("intent") or obj.get("class_name") or ""
return str(intent or "").strip().lower()
except Exception:
pass
# 兜底:纯文本标签
token = re.sub(r"[^a-zA-Z_]+", "", raw).lower()
return token
async def _classify_intent_with_llm(self, intent_text: str) -> str:
cfg = (self.config or {}).get("intent_router", {})
api_cfg = (self.config or {}).get("api", {})
text = str(intent_text or "").strip()
if not text:
return "chat"
cache_ttl = cfg.get("cache_ttl_seconds", 120)
try:
cache_ttl = float(cache_ttl or 0)
except Exception:
cache_ttl = 0
cache_key = re.sub(r"\s+", " ", text).strip().lower()
if cache_ttl > 0:
cached = self._intent_cache.get(cache_key)
if cached:
ts, intent = cached
if (time.time() - float(ts or 0)) <= cache_ttl and intent:
return str(intent).strip().lower()
model = cfg.get("model") or api_cfg.get("model")
url = api_cfg.get("url")
api_key = api_cfg.get("api_key")
timeout_s = cfg.get("timeout", 15)
try:
timeout_s = float(timeout_s or 15)
except Exception:
timeout_s = 15
system_prompt = (
"你是一个聊天机器人“意图分类器”,只负责判断用户当前这句话最可能的意图。\n"
"只允许返回 JSON格式固定{\"intent\":\"<label>\"},不要输出任何多余文字。\n"
"label 只能是以下之一:\n"
"- chat普通聊天/主观闲聊,不需要工具。\n"
"- search需要联网检索/查证事实/口碑/背景/最新信息(游戏/公会/公司/插件/项目等)。\n"
"- draw用户要生成/绘制图片。\n"
"- weather用户要查询天气/气温/预报/空气质量。\n"
"- register_city用户要注册/设置/修改默认城市。\n"
"- signin用户要签到。\n"
"- profile用户要查积分/个人信息/资料。\n"
"- news用户要每日新闻/早报。\n"
"- music用户要搜歌/点歌。\n"
"- playlet用户要搜短剧。\n"
"- kfc用户要疯狂星期四/KFC 文案。\n"
"- fabing用户明确要“发病文学/发病文/发病语录”。\n"
)
payload = {
"model": model,
"messages": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": text},
],
"temperature": 0,
"max_tokens": 80,
"stream": False,
}
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {api_key}",
}
connector = None
proxy_config = (self.config or {}).get("proxy", {})
if proxy_config.get("enabled", False):
proxy_type = proxy_config.get("type", "socks5").upper()
proxy_host = proxy_config.get("host", "127.0.0.1")
proxy_port = proxy_config.get("port", 7890)
proxy_username = proxy_config.get("username")
proxy_password = proxy_config.get("password")
if proxy_username and proxy_password:
proxy_url = f"{proxy_type}://{proxy_username}:{proxy_password}@{proxy_host}:{proxy_port}"
else:
proxy_url = f"{proxy_type}://{proxy_host}:{proxy_port}"
if PROXY_SUPPORT:
try:
connector = ProxyConnector.from_url(proxy_url)
except Exception:
connector = None
timeout = aiohttp.ClientTimeout(total=timeout_s)
try:
async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
async with session.post(url, json=payload, headers=headers) as resp:
if resp.status != 200:
return "chat"
data = await resp.json()
except Exception:
return "chat"
content = (
(((data or {}).get("choices") or [{}])[0].get("message") or {}).get("content")
if isinstance(data, dict)
else ""
)
intent = self._extract_intent_from_llm_output(content)
valid = {
"chat",
"search",
"draw",
"weather",
"register_city",
"signin",
"profile",
"news",
"music",
"playlet",
"kfc",
"fabing",
}
if intent not in valid:
intent = "chat"
if cache_ttl > 0:
try:
self._intent_cache[cache_key] = (time.time(), intent)
if len(self._intent_cache) > 500:
# 简单裁剪:清空最旧的 100 条(不做复杂 LRU
items = sorted(self._intent_cache.items(), key=lambda kv: kv[1][0])[:100]
for k, _v in items:
self._intent_cache.pop(k, None)
except Exception:
pass
return intent
async def _select_tools_for_message_async(self, tools: list, *, user_message: str, tool_query: str | None = None) -> list:
"""
工具选择(支持意图路由):
1) 先走现有 smart_select 规则(快)
2) 规则未命中且像信息查询时,走一次轻量 LLM 意图分类(慢但更准)
"""
tools_config = (self.config or {}).get("tools", {})
if not tools_config.get("smart_select", False):
return tools
selected = self._select_tools_for_message(tools, user_message=user_message, tool_query=tool_query)
if selected:
return selected
intent_text = self._extract_tool_intent_text(user_message, tool_query=tool_query)
if not intent_text:
return []
if not self._should_run_intent_router(intent_text):
return []
intent = await self._classify_intent_with_llm(intent_text)
allow = self._intent_to_allowed_tool_names(intent)
if not allow:
return []
# 对“容易误触”的工具再做一次本地硬约束,避免分类器误判导致执行敏感动作
t = intent_text.lower()
if "register_city" in allow and not re.search(r"(注册|设置|更新|更换|修改|绑定|默认).{0,6}城市|城市.{0,6}(注册|设置|更新|更换|修改|绑定|默认)", t):
allow.discard("register_city")
if "user_signin" in allow and not re.search(r"(用户签到|签到|签个到)", t):
allow.discard("user_signin")
if "check_profile" in allow and not re.search(r"(个人信息|我的信息|我的积分|查积分|积分多少|连续签到|连签|我的资料)", t):
allow.discard("check_profile")
if "get_fabing" in allow and not re.search(r"(发病文学|犯病文学|发病文|犯病文|发病语录|犯病语录|发病一下|犯病一下)", t):
allow.discard("get_fabing")
if not allow:
return []
result = []
for tool in tools or []:
name = tool.get("function", {}).get("name", "")
if name and name in allow:
result.append(tool)
return result
def _select_tools_for_message(self, tools: list, *, user_message: str, tool_query: str | None = None) -> list:
tools_config = (self.config or {}).get("tools", {})
if not tools_config.get("smart_select", False):
@@ -1965,7 +2231,7 @@ class AIChat(PluginBase):
# 收集工具
all_tools = self._collect_tools()
tools = self._select_tools_for_message(all_tools, user_message=user_message, tool_query=tool_query)
tools = await self._select_tools_for_message_async(all_tools, user_message=user_message, tool_query=tool_query)
logger.info(f"收集到 {len(all_tools)} 个工具函数,本次启用 {len(tools)}")
if tools:
tool_names = [t["function"]["name"] for t in tools]
@@ -3790,7 +4056,7 @@ class AIChat(PluginBase):
"""调用AI API带图片"""
api_config = self.config["api"]
all_tools = self._collect_tools()
tools = self._select_tools_for_message(all_tools, user_message=user_message, tool_query=tool_query)
tools = await self._select_tools_for_message_async(all_tools, user_message=user_message, tool_query=tool_query)
logger.info(f"[图片] 收集到 {len(all_tools)} 个工具函数,本次启用 {len(tools)}")
if tools:
tool_names = [t["function"]["name"] for t in tools]