diff --git a/plugins/AIChat/main.py b/plugins/AIChat/main.py index 56f3bb2..1802860 100644 --- a/plugins/AIChat/main.py +++ b/plugins/AIChat/main.py @@ -53,6 +53,7 @@ class AIChat(PluginBase): self._chatroom_member_cache = {} # {chatroom_id: (ts, {wxid: display_name})} self._chatroom_member_cache_locks = {} # {chatroom_id: asyncio.Lock} self._chatroom_member_cache_ttl_seconds = 3600 # 群名片缓存1小时,减少协议 API 调用 + self._intent_cache = {} # {normalized_text: (ts, intent)} async def async_init(self): """插件异步初始化""" @@ -1320,6 +1321,271 @@ class AIChat(PluginBase): text = text.rsplit(marker, 1)[-1].strip() return text + def _looks_like_info_query(self, text: str) -> bool: + t = str(text or "").strip().lower() + if not t: + return False + + # 太短的消息不值得额外走一轮分类 + if len(t) < 6: + return False + + # 疑问/求评价/求推荐类 + if any(x in t for x in ("?", "?")): + return True + if re.search(r"(什么|咋|怎么|如何|为啥|为什么|哪|哪里|哪个|多少|推荐|值不值得|值得吗|好不好|靠谱吗|评价|口碑|怎么样|如何评价|近况|最新|最近)", t): + return True + + # 明确的实体/对象询问(公会/游戏/公司/项目等) + if re.search(r"(公会|战队|服务器|区服|游戏|公司|品牌|产品|软件|插件|项目|平台|up主|主播|作者|电影|电视剧|小说)", t) and len(t) >= 8: + return True + + return False + + def _intent_to_allowed_tool_names(self, intent: str) -> set[str]: + intent = str(intent or "").strip().lower() + mapping = { + "search": {"tavily_web_search", "web_search"}, + "draw": { + "nano_ai_image_generation", + "flow2_ai_image_generation", + "jimeng_ai_image_generation", + "kiira2_ai_image_generation", + "generate_image", + }, + "weather": {"query_weather"}, + "register_city": {"register_city"}, + "signin": {"user_signin"}, + "profile": {"check_profile"}, + "news": {"get_daily_news"}, + "music": {"search_music"}, + "playlet": {"search_playlet"}, + "kfc": {"get_kfc"}, + "fabing": {"get_fabing"}, + "chat": set(), + } + return set(mapping.get(intent, set())) + + def _should_run_intent_router(self, intent_text: str) -> bool: + cfg = (self.config or {}).get("intent_router", {}) + if not cfg.get("enabled", False): + return False + + mode = str(cfg.get("mode", "hybrid") or "hybrid").strip().lower() + if mode == "always": + return True + + # hybrid:只在像“信息查询/求评价”的问题时触发,避免闲聊额外增加延迟 + return self._looks_like_info_query(intent_text) + + def _extract_intent_from_llm_output(self, content: str) -> str: + raw = str(content or "").strip() + if not raw: + return "" + + # 尝试直接 JSON + try: + obj = json.loads(raw) + if isinstance(obj, dict): + intent = obj.get("intent") or obj.get("class_name") or "" + return str(intent or "").strip().lower() + except Exception: + pass + + # 尝试截取 JSON 片段 + try: + m = re.search(r"\{[\s\S]*\}", raw) + if m: + obj = json.loads(m.group(0)) + if isinstance(obj, dict): + intent = obj.get("intent") or obj.get("class_name") or "" + return str(intent or "").strip().lower() + except Exception: + pass + + # 兜底:纯文本标签 + token = re.sub(r"[^a-zA-Z_]+", "", raw).lower() + return token + + async def _classify_intent_with_llm(self, intent_text: str) -> str: + cfg = (self.config or {}).get("intent_router", {}) + api_cfg = (self.config or {}).get("api", {}) + + text = str(intent_text or "").strip() + if not text: + return "chat" + + cache_ttl = cfg.get("cache_ttl_seconds", 120) + try: + cache_ttl = float(cache_ttl or 0) + except Exception: + cache_ttl = 0 + + cache_key = re.sub(r"\s+", " ", text).strip().lower() + if cache_ttl > 0: + cached = self._intent_cache.get(cache_key) + if cached: + ts, intent = cached + if (time.time() - float(ts or 0)) <= cache_ttl and intent: + return str(intent).strip().lower() + + model = cfg.get("model") or api_cfg.get("model") + url = api_cfg.get("url") + api_key = api_cfg.get("api_key") + timeout_s = cfg.get("timeout", 15) + try: + timeout_s = float(timeout_s or 15) + except Exception: + timeout_s = 15 + + system_prompt = ( + "你是一个聊天机器人“意图分类器”,只负责判断用户当前这句话最可能的意图。\n" + "只允许返回 JSON,格式固定:{\"intent\":\"