88 lines
3.8 KiB
Python
88 lines
3.8 KiB
Python
from __future__ import annotations
|
||
|
||
import re
|
||
from dataclasses import dataclass, field
|
||
from typing import Dict, List
|
||
|
||
|
||
# Regexes that mark a message as a question. TriggerRouter._is_question runs
# re.search with each of these over the stripped message, case-insensitively,
# so a pattern matches anywhere unless it is anchored with "$".
QUESTION_PATTERNS = [
    r"\?$",  # ends with an ASCII question mark
    r"?$",  # ends with a full-width (CJK) question mark
    r"怎么",  # "how"
    r"如何",  # "how" (formal)
    r"咋弄",  # "how do I do it" (colloquial)
    r"为啥",  # "why" (colloquial)
    r"为什么",  # "why"
    r"有人知道",  # "does anyone know"
    r"谁知道",  # "who knows"
    r"能不能",  # "can / could (you)"
    r"可以吗",  # "is that OK?"
    r"报错",  # "(it) reports an error"
    r"怎么解决",  # "how to fix" (already implied by 怎么 above)
]

# Regexes that mark a message as calling on the bot socially; searched the
# same way, against the lower-cased message.
SOCIAL_PATTERNS = [
    r"小牛",  # presumably the bot's nickname -- confirm
    r"在吗",  # "are you there?"
    r"出来",  # "come out"
    r"帮忙看",  # "help take a look"
    r"看看",  # "take a look"
]
|
||
|
||
|
||
@dataclass
class TriggerResult:
    """Outcome of routing a single chat message.

    Every field has a default, so a bare ``TriggerResult()`` represents the
    "nothing matched" state; a router fills fields in as detectors fire.
    """

    # Name of the trigger that currently claims the message ("none" if unset).
    trigger_type: str = "none"
    # Weight of the trigger(s) seen so far; 0.0 while nothing has fired.
    priority: float = 0.0

    # Per-detector flags.
    is_question: bool = False
    is_followup: bool = False
    is_social_call: bool = False
    is_returning_member: bool = False

    # Whether the bot should reply to this message.
    should_respond: bool = False
    # Matched topic keyword, or "" when no topic matched.
    topic: str = ""
    # Labels of every detector that matched, in detection order; a fresh list
    # per instance (default_factory), never shared between results.
    reasons: List[str] = field(default_factory=list)
|
||
|
||
|
||
class TriggerRouter:
    """Classify an incoming chat message into a response trigger.

    Detectors run in a fixed order: at-mention, question, follow-up, topic
    keyword, social call, returning member. Each one that matches may raise
    ``priority`` to its configured weight (defaults: ``at_bot`` 1.0,
    ``explicit_question`` 0.95, ``followup`` 0.90, ``social_call`` 0.65,
    ``casual_topic`` 0.35). The naming rules for ``trigger_type`` are
    described in :meth:`route`.
    """

    def __init__(self, config: Dict):
        """Store the routing config.

        Args:
            config: Mapping with optional trigger weights (``at_bot``,
                ``explicit_question``, ``followup``, ``casual_topic``,
                ``social_call``) and an optional ``focus`` collection of topic
                keywords. ``None`` is treated as an empty mapping.
        """
        self.config = config or {}
        # `or []` also covers an explicit "focus": None in the config.
        focus = self.config.get("focus") or []
        # A bare string would otherwise be iterated character by character,
        # turning single letters into topic keywords.
        if isinstance(focus, str):
            focus = [focus]
        # Lower-cased once here so _detect_topic can do plain substring checks.
        self.topic_keywords = [str(item).lower() for item in focus]

    def route(self, message: Dict, memory_hints: Dict) -> TriggerResult:
        """Score ``message`` and decide whether the bot should respond.

        Args:
            message: Incoming message; reads ``content`` (the text) and
                ``is_at`` (truthy -> at-mention trigger).
            memory_hints: Conversation-state hints; reads ``is_followup`` and
                ``returning_member_state``. ``None`` is treated as ``{}``.

        Returns:
            A new TriggerResult. ``should_respond`` is set only by the
            at-mention, question and follow-up detectors. Topic, social-call
            and returning-member matches only raise ``priority`` and record
            flags/reasons. The question and follow-up detectors take over
            ``trigger_type`` when they outrank the current priority; the later
            detectors only name it when it is still ``"none"``.
        """
        memory_hints = memory_hints or {}
        raw_content = message.get("content")
        # dict.get's default only applies to a *missing* key; a present None
        # must not become the text "None".
        content = "" if raw_content is None else str(raw_content).strip()
        content_lower = content.lower()
        result = TriggerResult()

        if message.get("is_at"):
            result.trigger_type = "at_trigger"
            result.priority = self._weight("at_bot", 1.0)
            result.should_respond = True
            result.reasons.append("is_at")

        if self._is_question(content):
            question_weight = self._weight("explicit_question", 0.95)
            # Take over the trigger name only when outranking what fired earlier.
            if result.priority < question_weight:
                result.trigger_type = "question_trigger"
                result.priority = question_weight
            result.is_question = True
            result.should_respond = True
            result.reasons.append("question")

        if memory_hints.get("is_followup"):
            followup_weight = self._weight("followup", 0.90)
            if result.priority < followup_weight:
                result.trigger_type = "followup_trigger"
                result.priority = followup_weight
            result.is_followup = True
            result.should_respond = True
            result.reasons.append("followup")

        topic = self._detect_topic(content_lower)
        if topic:
            topic_weight = self._weight("casual_topic", 0.35)
            result.topic = topic
            result.reasons.append("topic")
            # Never renames a trigger that has already been set.
            if result.priority < topic_weight and result.trigger_type == "none":
                result.trigger_type = "topic_trigger"
            result.priority = max(result.priority, topic_weight)

        if self._is_social_call(content_lower):
            social_weight = self._weight("social_call", 0.65)
            # NOTE(review): unlike the question/follow-up branches, this keeps an
            # already-set trigger_type (e.g. "topic_trigger") even when it raises
            # the priority above that trigger's weight -- confirm this is intended.
            if result.priority < social_weight and result.trigger_type == "none":
                result.trigger_type = "social_trigger"
            result.priority = max(result.priority, social_weight)
            result.is_social_call = True
            result.reasons.append("social_call")

        returning_state = memory_hints.get("returning_member_state")
        if returning_state in {"returning_member", "long_absent_member"}:
            result.is_returning_member = True
            result.reasons.append(returning_state)
            if result.trigger_type == "none":
                result.trigger_type = "returning_member"
            # Returning members get at least the casual-topic weight.
            result.priority = max(result.priority, self._weight("casual_topic", 0.35))

        return result

    def _weight(self, key: str, default: float) -> float:
        """Return ``config[key]`` as a float, or ``default`` when the key is absent."""
        return float(self.config.get(key, default))

    def _is_question(self, content: str) -> bool:
        """Return True if any QUESTION_PATTERNS regex matches ``content`` (case-insensitive)."""
        return any(re.search(pattern, content, flags=re.IGNORECASE) for pattern in QUESTION_PATTERNS)

    def _is_social_call(self, content: str) -> bool:
        """Return True if any SOCIAL_PATTERNS regex matches ``content`` (case-insensitive)."""
        return any(re.search(pattern, content, flags=re.IGNORECASE) for pattern in SOCIAL_PATTERNS)

    def _detect_topic(self, content_lower: str) -> str:
        """Return the first configured focus keyword contained in ``content_lower``.

        Keywords are checked in config order; empty keywords are skipped.
        Returns "" when none matches.
        """
        for keyword in self.topic_keywords:
            if keyword and keyword in content_lower:
                return keyword
        return ""
|