refactor ai_auto_response plugin architecture
This commit is contained in:
17
plugins/ai_auto_response/memory/__init__.py
Normal file
17
plugins/ai_auto_response/memory/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .group_facts import GroupFactsService
|
||||
from .group_memory import GroupMemoryCoordinator
|
||||
from .group_memory_profile import GroupMemoryService
|
||||
from .memory_ranker import MemoryRanker
|
||||
from .social_memory import SocialMemoryService
|
||||
from ..profile.group_profile import GroupProfileResolver
|
||||
|
||||
# Names exported by ``from <package> import *``; each one is bound by the
# imports at the top of this module.
__all__ = [
    "GroupFactsService",
    "GroupMemoryCoordinator",
    "GroupMemoryService",
    "GroupProfileResolver",
    "MemoryRanker",
    "SocialMemoryService",
]
|
||||
127
plugins/ai_auto_response/memory/group_facts.py
Normal file
127
plugins/ai_auto_response/memory/group_facts.py
Normal file
@@ -0,0 +1,127 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from collections import Counter, defaultdict
|
||||
from typing import Dict, List
|
||||
|
||||
|
||||
class GroupFactsService:
|
||||
DOMAIN_KEYWORDS = {
|
||||
"openclaw": ["openclaw", "claw", "节点", "工作流", "编排", "接入", "agent"],
|
||||
"robotics": ["机器人", "bot", "插件", "自动化", "微信", "框架", "消息"],
|
||||
"infra": ["部署", "docker", "服务器", "日志", "接口", "配置", "报错", "超时"],
|
||||
"dota": ["dota", "dota2", "刀塔", "英雄", "对线", "团战", "战绩", "版本"],
|
||||
"casual": ["吃饭", "睡觉", "上班", "下班", "摸鱼", "乐", "吐槽", "闲聊"],
|
||||
}
|
||||
ANSWER_WORDS = ["先", "然后", "试试", "看下", "排查", "配置", "日志", "原因", "改成", "部署", "重启"]
|
||||
JOKE_WORDS = ["笑死", "逆天", "离谱", "绷不住", "抽象", "节目效果", "蚌", "乐"]
|
||||
|
||||
def __init__(self, config: Dict | None = None):
|
||||
self.config = config or {}
|
||||
|
||||
def build_group_facts(
|
||||
self,
|
||||
*,
|
||||
room_id: str,
|
||||
recent_messages: List[Dict],
|
||||
name_map: Dict[str, str] | None = None,
|
||||
) -> Dict:
|
||||
name_map = name_map or {}
|
||||
window_size = max(int(self.config.get("group_fact_window_size", 80) or 80), 20)
|
||||
window = list(recent_messages or [])[-window_size:]
|
||||
if not window:
|
||||
return {"items": [], "prompt": ""}
|
||||
|
||||
topic_counter: Counter[str] = Counter()
|
||||
role_counter: Counter[str] = Counter()
|
||||
joke_counter: Counter[str] = Counter()
|
||||
co_occurrence: defaultdict[str, int] = defaultdict(int)
|
||||
|
||||
for item in window:
|
||||
sender = str(item.get("sender", "") or "")
|
||||
sender_name = str(item.get("sender_name") or name_map.get(sender) or sender or "未知成员")
|
||||
content = str(item.get("content") or item.get("message") or "").strip().lower()
|
||||
if not content:
|
||||
continue
|
||||
|
||||
for domain, keywords in self.DOMAIN_KEYWORDS.items():
|
||||
hits = sum(1 for keyword in keywords if keyword and keyword.lower() in content)
|
||||
if hits:
|
||||
topic_counter[domain] += hits
|
||||
|
||||
if self._looks_like_answer(content):
|
||||
role_counter[sender_name] += 1
|
||||
|
||||
for word in self.JOKE_WORDS:
|
||||
if word in content:
|
||||
joke_counter[word] += 1
|
||||
|
||||
mentions = self._extract_member_mentions(content, name_map)
|
||||
for target in mentions:
|
||||
key = f"{sender_name}->{target}"
|
||||
co_occurrence[key] += 1
|
||||
|
||||
items: List[Dict] = []
|
||||
for domain, count in topic_counter.most_common(3):
|
||||
items.append({
|
||||
"fact_type": "group_theme",
|
||||
"summary": f"群里最近长期反复出现 {domain} 相关话题",
|
||||
"weight": min(count, 6),
|
||||
})
|
||||
for member, count in role_counter.most_common(2):
|
||||
if count >= 2:
|
||||
items.append({
|
||||
"fact_type": "group_role",
|
||||
"summary": f"{member} 最近更像答疑位或方案位",
|
||||
"weight": min(count, 5),
|
||||
})
|
||||
for pair, count in sorted(co_occurrence.items(), key=lambda item: item[1], reverse=True)[:2]:
|
||||
if count >= 2:
|
||||
items.append({
|
||||
"fact_type": "social_link",
|
||||
"summary": f"{pair.replace('->', ' 更常接 ')} 的话",
|
||||
"weight": min(count, 4),
|
||||
})
|
||||
for joke, count in joke_counter.most_common(2):
|
||||
if count >= 2:
|
||||
items.append({
|
||||
"fact_type": "group_joke",
|
||||
"summary": f"群里最近常用“{joke}”这类轻吐槽",
|
||||
"weight": min(count, 4),
|
||||
})
|
||||
|
||||
prompt = self._build_prompt(room_id, items)
|
||||
return {
|
||||
"items": items,
|
||||
"prompt": prompt,
|
||||
}
|
||||
|
||||
def _build_prompt(self, room_id: str, items: List[Dict]) -> str:
|
||||
if not items:
|
||||
return ""
|
||||
lines = [f"下面是群 {room_id} 最近沉淀出的轻量群事实,只在相关时参考。"]
|
||||
for item in items[:6]:
|
||||
lines.append(
|
||||
f"- [{item.get('fact_type', 'fact')}] {item.get('summary', '')}; weight={item.get('weight', 1)}"
|
||||
)
|
||||
return "\n".join(lines)
|
||||
|
||||
@classmethod
|
||||
def _looks_like_answer(cls, content: str) -> bool:
|
||||
if len(content) >= 18:
|
||||
return True
|
||||
return any(word in content for word in cls.ANSWER_WORDS)
|
||||
|
||||
@staticmethod
|
||||
def _extract_member_mentions(content: str, name_map: Dict[str, str]) -> List[str]:
|
||||
if not name_map:
|
||||
return []
|
||||
hits: List[str] = []
|
||||
normalized = re.sub(r"\s+", "", content)
|
||||
for _, name in list(name_map.items())[:120]:
|
||||
short_name = str(name or "").strip()
|
||||
if len(short_name) < 2:
|
||||
continue
|
||||
if short_name in normalized and short_name not in hits:
|
||||
hits.append(short_name)
|
||||
return hits[:3]
|
||||
182
plugins/ai_auto_response/memory/group_memory.py
Normal file
182
plugins/ai_auto_response/memory/group_memory.py
Normal file
@@ -0,0 +1,182 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import re
|
||||
import time
|
||||
from typing import Callable, Dict, List
|
||||
|
||||
from .group_facts import GroupFactsService
|
||||
from .group_memory_profile import GroupMemoryService
|
||||
from .social_memory import SocialMemoryService
|
||||
from ..profile.group_profile import GroupProfileResolver
|
||||
from .vector_memory import VectorMemoryStore
|
||||
|
||||
|
||||
class GroupMemoryCoordinator:
    """Gather group-level memory context and sync snapshots to vector memory."""

    def __init__(
        self,
        *,
        group_memory_service: GroupMemoryService,
        group_profile_resolver: GroupProfileResolver,
        social_memory_service: SocialMemoryService,
        group_facts_service: GroupFactsService,
        vector_memory: VectorMemoryStore,
        memory_config: Dict | None = None,
    ):
        """Keep references to the collaborating services and the config mapping."""
        self.group_memory_service = group_memory_service
        self.group_profile_resolver = group_profile_resolver
        self.social_memory_service = social_memory_service
        self.group_facts_service = group_facts_service
        self.vector_memory = vector_memory
        self.memory_config = memory_config or {}
        # room_id -> version hash of the last snapshot that was upserted
        # successfully; used to skip unchanged snapshots.
        self._synced_social_snapshot_versions: Dict[str, str] = {}
        self._synced_group_fact_versions: Dict[str, str] = {}

    def build(
        self,
        *,
        room_id: str,
        group_name: str,
        sender: str,
        current_content: str,
        recent_messages: List[Dict],
        name_map: Dict[str, str],
    ) -> Dict:
        """Call each service once and return their results in one dict.

        Returns:
            Mapping with keys ``group_memory_profile``, ``group_profile``,
            ``social_context`` and ``group_facts``.
        """
        group_memory_profile = self.group_memory_service.build_group_memory_profile(room_id, group_name)
        group_profile = self.group_profile_resolver.resolve(room_id, group_name, group_memory_profile)
        social_context = self.social_memory_service.build_social_context(
            room_id=room_id,
            sender=sender,
            current_content=current_content,
            recent_messages=recent_messages,
            name_map=name_map,
        )
        group_facts = self.group_facts_service.build_group_facts(
            room_id=room_id,
            recent_messages=recent_messages,
            name_map=name_map,
        )
        return {
            "group_memory_profile": group_memory_profile,
            "group_profile": group_profile,
            "social_context": social_context,
            "group_facts": group_facts,
        }

    def sync_snapshots(
        self,
        *,
        room_id: str,
        social_context: Dict,
        group_facts: Dict,
        log_event: Callable[..., None],
    ) -> None:
        """Upsert the social and group-fact snapshots for ``room_id`` if they changed."""
        self._sync_social_snapshot(room_id, social_context, log_event)
        self._sync_group_fact_snapshot(room_id, group_facts, log_event)

    def _sync_social_snapshot(self, room_id: str, social_context: Dict, log_event: Callable[..., None]) -> None:
        """Upsert the social snapshot unless disabled, empty or unchanged."""
        if not bool(self.memory_config.get("enable_social_snapshot", True)):
            return
        items = (social_context or {}).get("items", []) or []
        snapshot_text = self._build_social_snapshot_text(items)
        if not snapshot_text or not items:
            return
        version = self._snapshot_version(snapshot_text)
        if self._synced_social_snapshot_versions.get(room_id) == version:
            return
        topic_tags: List[str] = []
        for item in items[:3]:
            # ``topic_tags`` may be present but None; treat that as no tags.
            for tag in (item.get("topic_tags") or [])[:3]:
                if tag and tag not in topic_tags:
                    topic_tags.append(tag)
        ok = self._upsert_snapshot(
            room_id=room_id,
            memory_id=f"group_social:{room_id}:{version}",
            memory_type="group_social_snapshot",
            source_id=f"{room_id}:social",
            snapshot_text=snapshot_text,
            topic_tags=topic_tags[:6],
            log_event=log_event,
        )
        if ok:
            self._synced_social_snapshot_versions[room_id] = version

    def _sync_group_fact_snapshot(self, room_id: str, group_facts: Dict, log_event: Callable[..., None]) -> None:
        """Upsert the group-fact snapshot unless disabled, empty or unchanged."""
        if not bool(self.memory_config.get("enable_group_fact_snapshot", True)):
            return
        items = (group_facts or {}).get("items", []) or []
        snapshot_text = self._build_group_fact_snapshot_text(items)
        if not snapshot_text or not items:
            return
        version = self._snapshot_version(snapshot_text)
        if self._synced_group_fact_versions.get(room_id) == version:
            return
        topic_tags: List[str] = []
        for item in items[:4]:
            summary = str(item.get("summary", "") or "")
            # Tags are runs of 2-12 word/CJK characters taken from the summary.
            tokens = re.findall(r"[A-Za-z0-9_\-\u4e00-\u9fff]{2,12}", summary)
            for tag in tokens[:4]:
                if tag and tag not in topic_tags:
                    topic_tags.append(tag)
        ok = self._upsert_snapshot(
            room_id=room_id,
            memory_id=f"group_fact:{room_id}:{version}",
            memory_type="group_fact_snapshot",
            source_id=f"{room_id}:facts",
            snapshot_text=snapshot_text,
            topic_tags=topic_tags[:8],
            log_event=log_event,
        )
        if ok:
            self._synced_group_fact_versions[room_id] = version

    @staticmethod
    def _snapshot_version(snapshot_text: str) -> str:
        """Return a 16-hex-character content fingerprint of the snapshot text."""
        return hashlib.md5(snapshot_text.encode("utf-8")).hexdigest()[:16]

    def _upsert_snapshot(
        self,
        *,
        room_id: str,
        memory_id: str,
        memory_type: str,
        source_id: str,
        snapshot_text: str,
        topic_tags: List[str],
        log_event: Callable[..., None],
    ):
        """Upsert one snapshot, log the outcome and return the store's result."""
        payload = {
            "chatroom_id": room_id,
            "memory_type": memory_type,
            "source_id": source_id,
            "summary_text": snapshot_text[:500],
            "topic_tags": topic_tags,
            "created_at": time.strftime("%Y-%m-%d %H:%M:%S"),
        }
        ok = self.vector_memory.upsert_memory(memory_id, snapshot_text, payload)
        log_event(
            "memory_upsert",
            room_id=room_id,
            sender="group",
            memory_type=memory_type,
            ok=ok,
            error=self.vector_memory.last_error,
        )
        return ok

    @staticmethod
    def build_debug_summary(rank_debug: Dict | None) -> str:
        """Join the first debug entry of each ranking category into one line.

        Categories are read in the order vector, social, facts, member and
        prefixed with ``v``, ``s``, ``f``, ``m``. Empty categories are skipped.
        """
        debug = rank_debug or {}
        parts = []
        for key, prefix in (("vector", "v"), ("social", "s"), ("facts", "f"), ("member", "m")):
            items = debug.get(key, []) or []
            if not items:
                continue
            parts.append(f"{prefix}[{items[0]}]")
        return " ".join(parts[:4])

    @staticmethod
    def _build_social_snapshot_text(items: List[Dict]) -> str:
        """Render up to four social items as text; ``""`` when there are none."""
        if not items:
            return ""
        lines = ["群关系快照:"]
        for item in items[:4]:
            # A None ``topic_tags`` value falls back to the generic label.
            tags = "、".join((item.get("topic_tags") or [])[:3]) or "泛互动"
            lines.append(
                f"- {item.get('target_name', '某成员')} | {item.get('relation_type', 'frequent_turn_taking')} | "
                f"strength={item.get('strength', 0.0)} | topics={tags}"
            )
        return "\n".join(lines)

    @staticmethod
    def _build_group_fact_snapshot_text(items: List[Dict]) -> str:
        """Render up to six fact items as text; ``""`` when there are none."""
        if not items:
            return ""
        lines = ["群事实快照:"]
        for item in items[:6]:
            lines.append(
                f"- [{item.get('fact_type', 'fact')}] {item.get('summary', '')} | weight={item.get('weight', 1)}"
            )
        return "\n".join(lines)
|
||||
155
plugins/ai_auto_response/memory/group_memory_profile.py
Normal file
155
plugins/ai_auto_response/memory/group_memory_profile.py
Normal file
@@ -0,0 +1,155 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import Counter
|
||||
from typing import Dict, List
|
||||
|
||||
from db.message_storage import MessageStorageDB
|
||||
from db.message_summary_db import MessageSummaryDBOperator
|
||||
|
||||
|
||||
class GroupMemoryService:
    """Build a keyword-based memory profile for a group from stored messages."""

    # Domain -> keywords; a keyword found in a text counts as one domain hit.
    DOMAIN_KEYWORDS = {
        "openclaw": ["openclaw", "claw", "工作流", "节点", "编排", "接入", "联调"],
        "robotics": ["机器人", "bot", "微信机器人", "插件", "自动化", "消息路由", "部署", "接口"],
        "dota": ["dota", "dota2", "刀塔", "英雄", "出装", "对线", "团战", "版本"],
        "tech": ["python", "docker", "redis", "mysql", "服务器", "报错", "脚本", "网络", "接口"],
        "casual": ["吃饭", "睡觉", "上班", "下班", "周末", "唠嗑", "闲聊"],
    }
    # Keyword lists feeding the four style counters in ``_infer_style_profile``.
    HUMOR_KEYWORDS = ["哈哈", "笑死", "乐", "蚌", "绷不住", "离谱", "逆天", "节目效果", "抽象", "乐子"]
    SHARPNESS_KEYWORDS = ["菜", "蠢", "逆天", "离谱", "抽象", "别搞", "别整", "你这", "搁这", "典"]
    RELAXED_KEYWORDS = ["随便", "行吧", "都行", "慢慢来", "不急", "摸鱼", "唠", "水群", "先这样"]
    SERIOUS_KEYWORDS = ["报错", "排查", "日志", "配置", "部署", "接口", "重现", "修复", "方案", "联调"]

    def __init__(self, db_manager, config: Dict):
        """Store the config and create the message / summary DB accessors."""
        self.config = config or {}
        self.message_db = MessageStorageDB(db_manager)
        self.summary_db = MessageSummaryDBOperator(db_manager)

    def build_group_memory_profile(self, room_id: str, group_name: str = "") -> Dict:
        """Return the inferred domain, focus topics and style profile of a group.

        Recent messages and the latest stored summary are both scanned for
        keywords. Every hit found in the summary counts double.

        Returns:
            Dict with keys ``room_id``, ``group_name``, ``inferred_domain``
            (``"general"`` when nothing matched), ``focus_topics`` (up to six
            keywords), ``message_sample_count``, ``summary_text`` and
            ``style_profile``.
        """
        recent_messages = self.message_db.get_messages_for_summary(
            room_id, hours_ago=48, min_messages=20, max_hours=168, max_results=300
        ) or []
        summary_text = self._load_recent_summary_text(room_id)
        topic_counter: Counter = Counter()
        domain_counter: Counter = Counter()
        humor_hits = 0
        sharpness_hits = 0
        relaxed_hits = 0
        serious_hits = 0
        short_message_count = 0
        message_count = 0

        for item in recent_messages:
            content = str(item.get("content", "") or "").lower()
            if not content:
                continue
            message_count += 1
            if len(content) <= 8:
                short_message_count += 1
            self._tally_domains(content, 1, domain_counter, topic_counter)
            humor_hits += self._count_hits(content, self.HUMOR_KEYWORDS)
            sharpness_hits += self._count_hits(content, self.SHARPNESS_KEYWORDS)
            relaxed_hits += self._count_hits(content, self.RELAXED_KEYWORDS)
            serious_hits += self._count_hits(content, self.SERIOUS_KEYWORDS)

        summary_lower = summary_text.lower()
        self._tally_domains(summary_lower, 2, domain_counter, topic_counter)
        humor_hits += self._count_hits(summary_lower, self.HUMOR_KEYWORDS) * 2
        sharpness_hits += self._count_hits(summary_lower, self.SHARPNESS_KEYWORDS) * 2
        relaxed_hits += self._count_hits(summary_lower, self.RELAXED_KEYWORDS) * 2
        serious_hits += self._count_hits(summary_lower, self.SERIOUS_KEYWORDS) * 2

        inferred_domain = domain_counter.most_common(1)[0][0] if domain_counter else "general"
        focus_topics = [item for item, _ in topic_counter.most_common(6)]
        style_profile = self._infer_style_profile(
            humor_hits=humor_hits,
            sharpness_hits=sharpness_hits,
            relaxed_hits=relaxed_hits,
            serious_hits=serious_hits,
            short_message_ratio=(short_message_count / message_count) if message_count else 0.0,
        )
        return {
            "room_id": room_id,
            "group_name": group_name,
            "inferred_domain": inferred_domain,
            "focus_topics": focus_topics,
            "message_sample_count": len(recent_messages),
            "summary_text": summary_text,
            "style_profile": style_profile,
        }

    def _tally_domains(self, text: str, weight: int, domain_counter: Counter, topic_counter: Counter) -> None:
        """Add ``weight`` per matched keyword to the domain and topic counters."""
        for domain, keywords in self.DOMAIN_KEYWORDS.items():
            matched = [keyword for keyword in keywords if keyword and keyword.lower() in text]
            if not matched:
                continue
            domain_counter[domain] += len(matched) * weight
            for keyword in matched:
                topic_counter[keyword] += weight

    @staticmethod
    def _count_hits(text: str, keywords: List[str]) -> int:
        """Count how many distinct entries of ``keywords`` occur in ``text``."""
        return sum(1 for keyword in keywords if keyword and keyword.lower() in text)

    @staticmethod
    def _infer_style_profile(
        *,
        humor_hits: int,
        sharpness_hits: int,
        relaxed_hits: int,
        serious_hits: int,
        short_message_ratio: float,
    ) -> Dict:
        """Map the hit counters onto four textual style labels via fixed thresholds."""
        humor_style = "轻微"
        if humor_hits >= 18:
            humor_style = "中等偏上,能接梗"
        elif humor_hits >= 8:
            humor_style = "中等,可以带一点冷幽默"

        sharpness_style = "轻微嘴硬,不刻薄"
        if sharpness_hits >= 15:
            sharpness_style = "允许轻微毒舌,但别上头"
        elif sharpness_hits >= 7:
            sharpness_style = "允许轻微嘴欠,但别刺人"

        interaction_tone = "自然群友感"
        if serious_hits >= max(relaxed_hits + 4, 10):
            interaction_tone = "偏认真,问题导向"
        elif relaxed_hits >= serious_hits + 4:
            interaction_tone = "偏松弛,像熟人闲聊"

        expressiveness_style = "克制"
        if short_message_ratio >= 0.58 or relaxed_hits >= serious_hits + 4:
            expressiveness_style = "松弛一点,像随口接话"
        elif serious_hits >= 12:
            expressiveness_style = "短句,偏干货"

        return {
            "interaction_tone": interaction_tone,
            "humor_style": humor_style,
            "sharpness_style": sharpness_style,
            "expressiveness_style": expressiveness_style,
        }

    def _load_recent_summary_text(self, room_id: str) -> str:
        """Return the text of the most recent daily or manual summary, or ``""``.

        The newest row of each summary type is fetched, then the candidate
        with the latest ``(period_end, update_time)`` wins.
        """
        candidates: List[Dict] = []
        for summary_type in ("daily", "manual"):
            sql = """
                SELECT *
                FROM t_message_summary
                WHERE chatroom_id = %s AND summary_type = %s
                ORDER BY period_end DESC, update_time DESC
                LIMIT 1
            """
            rows = self.summary_db.execute_query(sql, (room_id, summary_type)) or []
            candidates.extend(rows)
        if not candidates:
            return ""
        # Missing/NULL dates must sort as the oldest value. str(None) would
        # be "None", which compares greater than any digit-leading date.
        candidates.sort(
            key=lambda item: (str(item.get("period_end") or ""), str(item.get("update_time") or "")),
            reverse=True,
        )
        return str(candidates[0].get("summary_text", "") or "").strip()
|
||||
412
plugins/ai_auto_response/memory/memory_ranker.py
Normal file
412
plugins/ai_auto_response/memory/memory_ranker.py
Normal file
@@ -0,0 +1,412 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
|
||||
class MemoryRanker:
    """Score and trim memory candidates against the current message."""

    # Domain name -> hint tokens. ``_score_text`` adds ``domain_weight`` when
    # a text shares at least one token with the hints of the focus domain.
    DOMAIN_HINTS = {
        "openclaw": {"openclaw", "claw", "节点", "工作流", "编排", "agent"},
        "robotics": {"机器人", "bot", "插件", "自动化", "微信", "消息"},
        "infra": {"docker", "部署", "日志", "配置", "接口", "报错", "服务器"},
        "dota": {"dota", "dota2", "刀塔", "英雄", "团战", "版本", "战绩"},
    }
|
||||
|
||||
def __init__(self, config: Dict | None = None):
|
||||
self.config = config or {}
|
||||
self.max_vector_items = int(self.config.get("ranked_vector_items", 2) or 2)
|
||||
self.max_social_items = int(self.config.get("ranked_social_items", 2) or 2)
|
||||
self.max_group_fact_items = int(self.config.get("ranked_group_fact_items", 3) or 3)
|
||||
self.max_member_focus_items = int(self.config.get("ranked_member_focus_items", 4) or 4)
|
||||
self.domain_weight = float(self.config.get("memory_domain_weight", 2.5) or 2.5)
|
||||
self.relation_weight = float(self.config.get("memory_relation_weight", 2.0) or 2.0)
|
||||
self.freshness_weight = float(self.config.get("memory_freshness_weight", 1.5) or 1.5)
|
||||
self.trigger_weight = float(self.config.get("memory_trigger_weight", 1.2) or 1.2)
|
||||
|
||||
def rank(
|
||||
self,
|
||||
*,
|
||||
content: str,
|
||||
quote_context: Dict,
|
||||
group_profile: Dict,
|
||||
member_context: Dict,
|
||||
vector_memories: List[Dict],
|
||||
social_context: Dict,
|
||||
group_facts: Dict,
|
||||
trigger: Dict,
|
||||
) -> Dict:
|
||||
focus_text = " ".join(
|
||||
[
|
||||
str(content or ""),
|
||||
str((quote_context or {}).get("title", "") or ""),
|
||||
str((quote_context or {}).get("quote_body", "") or ""),
|
||||
]
|
||||
)
|
||||
focus_tokens = self._extract_tokens(focus_text)
|
||||
focus_domain = str(group_profile.get("knowledge_domain", "") or "").strip().lower()
|
||||
relation_targets = self._extract_relation_targets(content, quote_context)
|
||||
trigger_type = str((trigger or {}).get("trigger_type", "") or "")
|
||||
|
||||
ranked_vector_memories, vector_debug = self._rank_vector_memories(
|
||||
vector_memories, focus_tokens, focus_domain, relation_targets, trigger_type
|
||||
)
|
||||
ranked_social_context, social_debug = self._rank_social_context(
|
||||
social_context, focus_tokens, focus_domain, relation_targets, trigger_type
|
||||
)
|
||||
ranked_group_facts, fact_debug = self._rank_group_facts(
|
||||
group_facts, focus_tokens, focus_domain, relation_targets, trigger_type
|
||||
)
|
||||
member_memory_focus, member_debug = self._rank_member_memory(
|
||||
member_context, focus_tokens, focus_domain, relation_targets, trigger_type
|
||||
)
|
||||
|
||||
return {
|
||||
"vector_memories": ranked_vector_memories,
|
||||
"social_context": ranked_social_context,
|
||||
"group_facts": ranked_group_facts,
|
||||
"member_memory_focus": member_memory_focus,
|
||||
"debug": {
|
||||
"vector": vector_debug,
|
||||
"social": social_debug,
|
||||
"facts": fact_debug,
|
||||
"member": member_debug,
|
||||
},
|
||||
}
|
||||
|
||||
def _rank_vector_memories(
|
||||
self,
|
||||
items: List[Dict],
|
||||
focus_tokens: set[str],
|
||||
focus_domain: str,
|
||||
relation_targets: set[str],
|
||||
trigger_type: str,
|
||||
) -> Tuple[List[Dict], List[str]]:
|
||||
scored = []
|
||||
for item in items or []:
|
||||
text = " ".join(
|
||||
[
|
||||
str(item.get("content_summary", "") or ""),
|
||||
str(item.get("summary_text", "") or ""),
|
||||
str(item.get("text", "") or ""),
|
||||
" ".join(item.get("topic_tags", []) or []),
|
||||
]
|
||||
)
|
||||
score, reasons = self._score_text(
|
||||
text=text,
|
||||
focus_tokens=focus_tokens,
|
||||
focus_domain=focus_domain,
|
||||
relation_targets=relation_targets,
|
||||
trigger_type=trigger_type,
|
||||
freshness_hint=self._freshness_from_payload(item),
|
||||
relation_hint=" ".join(item.get("topic_tags", []) or []),
|
||||
)
|
||||
if score <= 0:
|
||||
continue
|
||||
scored.append((score, item, self._describe_vector_item(item, reasons, score)))
|
||||
scored.sort(key=lambda x: x[0], reverse=True)
|
||||
top = scored[: self.max_vector_items]
|
||||
return [item for _, item, _ in top], [debug for _, _, debug in top]
|
||||
|
||||
def _rank_social_context(
|
||||
self,
|
||||
social_context: Dict,
|
||||
focus_tokens: set[str],
|
||||
focus_domain: str,
|
||||
relation_targets: set[str],
|
||||
trigger_type: str,
|
||||
) -> Tuple[Dict, List[str]]:
|
||||
items = []
|
||||
for item in (social_context or {}).get("items", []) or []:
|
||||
text = " ".join(
|
||||
[
|
||||
str(item.get("target_name", "") or ""),
|
||||
str(item.get("relation_type", "") or ""),
|
||||
" ".join(item.get("topic_tags", []) or []),
|
||||
]
|
||||
)
|
||||
score, reasons = self._score_text(
|
||||
text=text,
|
||||
focus_tokens=focus_tokens,
|
||||
focus_domain=focus_domain,
|
||||
relation_targets=relation_targets,
|
||||
trigger_type=trigger_type,
|
||||
freshness_hint=float(item.get("strength", 0.0)),
|
||||
relation_hint=str(item.get("target_name", "") or ""),
|
||||
)
|
||||
strength_bonus = float(item.get("strength", 0.0)) * 1.5
|
||||
score += strength_bonus
|
||||
if score <= 0:
|
||||
continue
|
||||
items.append(
|
||||
(
|
||||
score,
|
||||
item,
|
||||
self._describe_social_item(item, reasons + ([f"strength={strength_bonus:.1f}"] if strength_bonus else []), score),
|
||||
)
|
||||
)
|
||||
items.sort(key=lambda x: x[0], reverse=True)
|
||||
top = items[: self.max_social_items]
|
||||
ranked_items = [item for _, item, _ in top]
|
||||
return (
|
||||
{
|
||||
"items": ranked_items,
|
||||
"prompt": self._build_ranked_social_prompt(ranked_items),
|
||||
},
|
||||
[debug for _, _, debug in top],
|
||||
)
|
||||
|
||||
def _rank_group_facts(
|
||||
self,
|
||||
group_facts: Dict,
|
||||
focus_tokens: set[str],
|
||||
focus_domain: str,
|
||||
relation_targets: set[str],
|
||||
trigger_type: str,
|
||||
) -> Tuple[Dict, List[str]]:
|
||||
items = []
|
||||
for item in (group_facts or {}).get("items", []) or []:
|
||||
text = str(item.get("summary", "") or "")
|
||||
score, reasons = self._score_text(
|
||||
text=text,
|
||||
focus_tokens=focus_tokens,
|
||||
focus_domain=focus_domain,
|
||||
relation_targets=relation_targets,
|
||||
trigger_type=trigger_type,
|
||||
freshness_hint=float(item.get("weight", 0.0)) / 4.0,
|
||||
relation_hint=text,
|
||||
)
|
||||
weight_bonus = float(item.get("weight", 0.0))
|
||||
score += weight_bonus
|
||||
if score <= 0:
|
||||
continue
|
||||
items.append(
|
||||
(
|
||||
score,
|
||||
item,
|
||||
self._describe_fact_item(item, reasons + ([f"weight={weight_bonus:.1f}"] if weight_bonus else []), score),
|
||||
)
|
||||
)
|
||||
items.sort(key=lambda x: x[0], reverse=True)
|
||||
top = items[: self.max_group_fact_items]
|
||||
ranked_items = [item for _, item, _ in top]
|
||||
return (
|
||||
{
|
||||
"items": ranked_items,
|
||||
"prompt": self._build_ranked_group_fact_prompt(ranked_items),
|
||||
},
|
||||
[debug for _, _, debug in top],
|
||||
)
|
||||
|
||||
def _rank_member_memory(
|
||||
self,
|
||||
member_context: Dict,
|
||||
focus_tokens: set[str],
|
||||
focus_domain: str,
|
||||
relation_targets: set[str],
|
||||
trigger_type: str,
|
||||
) -> Tuple[List[str], List[str]]:
|
||||
if not member_context:
|
||||
return [], []
|
||||
meta = member_context.get("meta", {}) or {}
|
||||
candidates = []
|
||||
|
||||
def push_items(values, label: str) -> None:
|
||||
for value in values or []:
|
||||
if isinstance(value, dict):
|
||||
text = str(
|
||||
value.get("name")
|
||||
or value.get("label")
|
||||
or value.get("value")
|
||||
or value.get("text")
|
||||
or ""
|
||||
).strip()
|
||||
else:
|
||||
text = str(value or "").strip()
|
||||
if not text:
|
||||
continue
|
||||
score, reasons = self._score_text(
|
||||
text=text,
|
||||
focus_tokens=focus_tokens,
|
||||
focus_domain=focus_domain,
|
||||
relation_targets=relation_targets,
|
||||
trigger_type=trigger_type,
|
||||
freshness_hint=1.0 if label in {"近期关注", "近期状态"} else 0.4,
|
||||
relation_hint=text,
|
||||
)
|
||||
if score <= 0:
|
||||
continue
|
||||
candidates.append((score, f"{label}:{text}", self._describe_member_item(label, text, reasons, score)))
|
||||
|
||||
push_items(member_context.get("topics_of_interest", []), "长期主题")
|
||||
push_items(member_context.get("recent_focus", []), "近期关注")
|
||||
push_items(meta.get("skill_profile", []), "技能侧重点")
|
||||
push_items(meta.get("problem_solving_profile", []), "处理问题方式")
|
||||
push_items(meta.get("reply_entry_profile", []), "有效接话点")
|
||||
push_items(meta.get("long_term_reply_preferences", []), "回复偏好")
|
||||
push_items(meta.get("recent_state", []), "近期状态")
|
||||
|
||||
unique_lines = []
|
||||
unique_debug = []
|
||||
for _, line, debug in sorted(candidates, key=lambda x: x[0], reverse=True):
|
||||
if line not in unique_lines:
|
||||
unique_lines.append(line)
|
||||
unique_debug.append(debug)
|
||||
return unique_lines[: self.max_member_focus_items], unique_debug[: self.max_member_focus_items]
|
||||
|
||||
def _build_ranked_social_prompt(self, items: List[Dict]) -> str:
|
||||
if not items:
|
||||
return ""
|
||||
lines = ["下面这些群关系只在当前这次话题明显相关时轻微利用。"]
|
||||
for item in items:
|
||||
tags = "、".join(item.get("topic_tags", [])[:3]) or "泛互动"
|
||||
lines.append(
|
||||
f"- {item.get('target_name', '某成员')}:{item.get('relation_type', 'frequent_turn_taking')};"
|
||||
f"强度={item.get('strength', 0.0)};"
|
||||
f"相关标签={tags}"
|
||||
)
|
||||
return "\n".join(lines)
|
||||
|
||||
def _build_ranked_group_fact_prompt(self, items: List[Dict]) -> str:
|
||||
if not items:
|
||||
return ""
|
||||
lines = ["下面这些群事实是按当前话题重排后的结果,只在相关时参考。"]
|
||||
for item in items:
|
||||
lines.append(
|
||||
f"- [{item.get('fact_type', 'fact')}] {item.get('summary', '')}; weight={item.get('weight', 1)}"
|
||||
)
|
||||
return "\n".join(lines)
|
||||
|
||||
def _score_text(
|
||||
self,
|
||||
*,
|
||||
text: str,
|
||||
focus_tokens: set[str],
|
||||
focus_domain: str,
|
||||
relation_targets: set[str],
|
||||
trigger_type: str,
|
||||
freshness_hint: float = 0.0,
|
||||
relation_hint: str = "",
|
||||
) -> Tuple[float, List[str]]:
|
||||
normalized = str(text or "").strip().lower()
|
||||
if not normalized:
|
||||
return 0.0, []
|
||||
text_tokens = self._extract_tokens(normalized)
|
||||
overlap = len(focus_tokens & text_tokens)
|
||||
score = overlap * 2.0
|
||||
reasons: List[str] = []
|
||||
if overlap:
|
||||
reasons.append(f"overlap={overlap}")
|
||||
if focus_domain and focus_domain in self.DOMAIN_HINTS:
|
||||
if self.DOMAIN_HINTS[focus_domain] & text_tokens:
|
||||
score += self.domain_weight
|
||||
reasons.append("domain")
|
||||
if relation_targets and any(target in (relation_hint or normalized) for target in relation_targets):
|
||||
score += self.relation_weight
|
||||
reasons.append("relation")
|
||||
score += max(freshness_hint, 0.0) * self.freshness_weight
|
||||
if freshness_hint > 0:
|
||||
reasons.append(f"fresh={freshness_hint:.1f}")
|
||||
trigger_bonus = self._trigger_bonus(trigger_type, normalized)
|
||||
score += trigger_bonus * self.trigger_weight
|
||||
if trigger_bonus > 0:
|
||||
reasons.append(f"trigger={trigger_type}")
|
||||
if not focus_tokens and normalized:
|
||||
score += 0.5
|
||||
reasons.append("fallback")
|
||||
return score, reasons
|
||||
|
||||
@staticmethod
|
||||
def _compact_reasons(reasons: List[str]) -> str:
|
||||
cleaned = []
|
||||
for reason in reasons:
|
||||
value = str(reason or "").strip()
|
||||
if value and value not in cleaned:
|
||||
cleaned.append(value)
|
||||
return "+".join(cleaned[:3]) or "-"
|
||||
|
||||
def _describe_vector_item(self, item: Dict, reasons: List[str], score: float) -> str:
|
||||
label = (
|
||||
str(item.get("memory_type", "") or "").strip()
|
||||
or str(item.get("source_id", "") or "").strip()
|
||||
or "vector"
|
||||
)
|
||||
return f"{label}:{score:.1f}@{self._compact_reasons(reasons)}"
|
||||
|
||||
def _describe_social_item(self, item: Dict, reasons: List[str], score: float) -> str:
|
||||
label = str(item.get("target_name", "") or "member").strip()
|
||||
relation_type = str(item.get("relation_type", "") or "").strip()
|
||||
if relation_type:
|
||||
label = f"{label}/{relation_type}"
|
||||
return f"{label}:{score:.1f}@{self._compact_reasons(reasons)}"
|
||||
|
||||
def _describe_fact_item(self, item: Dict, reasons: List[str], score: float) -> str:
|
||||
label = str(item.get("fact_type", "") or "fact").strip()
|
||||
return f"{label}:{score:.1f}@{self._compact_reasons(reasons)}"
|
||||
|
||||
def _describe_member_item(self, label: str, text: str, reasons: List[str], score: float) -> str:
|
||||
short_text = re.sub(r"\s+", "", str(text or ""))[:10]
|
||||
return f"{label}:{short_text}:{score:.1f}@{self._compact_reasons(reasons)}"
|
||||
|
||||
def _trigger_bonus(self, trigger_type: str, normalized: str) -> float:
|
||||
trigger_type = str(trigger_type or "")
|
||||
if trigger_type in {"at_trigger", "followup_trigger", "quote_followup_trigger"}:
|
||||
return 1.0
|
||||
if trigger_type == "question_trigger" and any(word in normalized for word in ["报错", "配置", "接口", "原因", "方案"]):
|
||||
return 1.0
|
||||
if trigger_type in {"social_trigger", "light_social_trigger"} and any(word in normalized for word in ["互动", "吐槽", "关系", "搭子"]):
|
||||
return 0.8
|
||||
return 0.0
|
||||
|
||||
def _freshness_from_payload(self, item: Dict) -> float:
|
||||
for key in ("created_at", "last_active_at"):
|
||||
value = str(item.get(key, "") or "").strip()
|
||||
if not value:
|
||||
continue
|
||||
parsed = self._parse_datetime(value)
|
||||
if not parsed:
|
||||
continue
|
||||
days = max((datetime.now() - parsed).days, 0)
|
||||
if days <= 1:
|
||||
return 1.0
|
||||
if days <= 7:
|
||||
return 0.7
|
||||
if days <= 30:
|
||||
return 0.4
|
||||
return 0.15
|
||||
return 0.0
|
||||
|
||||
@staticmethod
|
||||
def _extract_relation_targets(content: str, quote_context: Dict) -> set[str]:
|
||||
targets = set()
|
||||
quote_sender = str((quote_context or {}).get("quote_sender_name", "") or "").strip().lower()
|
||||
if quote_sender:
|
||||
targets.add(quote_sender)
|
||||
normalized = str(content or "").strip().lower()
|
||||
for match in re.findall(r"@?[\u4e00-\u9fffA-Za-z0-9_]{2,12}", normalized):
|
||||
targets.add(match.lower())
|
||||
return targets
|
||||
|
||||
@staticmethod
|
||||
def _parse_datetime(value: str) -> datetime | None:
|
||||
if not value:
|
||||
return None
|
||||
for fmt in ("%Y-%m-%d %H:%M:%S", "%Y-%m-%d"):
|
||||
try:
|
||||
return datetime.strptime(value, fmt)
|
||||
except ValueError:
|
||||
continue
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _extract_tokens(content: str) -> set[str]:
|
||||
text = str(content or "").lower()
|
||||
tokens = set(re.findall(r"[a-z0-9_\\-]{3,}", text))
|
||||
for keyword in [
|
||||
"openclaw", "qdrant", "ollama", "docker", "python", "api", "插件", "机器人",
|
||||
"日志", "配置", "报错", "部署", "图片", "记忆", "群聊", "dota", "战绩",
|
||||
"吃饭", "摸鱼", "项目", "接口", "模型",
|
||||
]:
|
||||
if keyword in text:
|
||||
tokens.add(keyword)
|
||||
return tokens
|
||||
93
plugins/ai_auto_response/memory/memory_store.py
Normal file
93
plugins/ai_auto_response/memory/memory_store.py
Normal file
@@ -0,0 +1,93 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from db.member_context_db import MemberContextDBOperator
|
||||
from db.message_storage import MessageStorageDB
|
||||
|
||||
|
||||
class MemoryStore:
    """Read-side access to message history, member context and follow-up state.

    Wraps ``MessageStorageDB`` and ``MemberContextDBOperator`` and keeps an
    in-memory record of when the bot last replied to each member so callers
    can tell whether a new message is a follow-up.
    """

    def __init__(self, db_manager, config: Dict):
        """Create the DB operators from ``db_manager`` and keep ``config``."""
        self.config = config or {}
        self.message_db = MessageStorageDB(db_manager)
        self.member_context_db = MemberContextDBOperator(db_manager)
        # "<room_id>:<wxid>" -> {"last_bot_reply_at": datetime, "topic": str}
        self.followup_sessions: Dict[str, Dict] = {}

    def get_recent_messages(self, room_id: str) -> List[Dict]:
        """Return the tail of the room's recent messages.

        Messages from the last ``active_context_hours`` hours (default 8)
        are fetched and trimmed to the last ``recent_context_size`` entries
        (default 30). A non-positive size yields an empty list.
        """
        hours = int(self.config.get("active_context_hours", 8))
        recent = self.message_db.get_recent_messages(room_id, hours_ago=hours, min_content_length=0) or []
        size = int(self.config.get("recent_context_size", 30))
        if size <= 0:
            # recent[-0:] is the whole list and a negative size would drop
            # leading items instead of limiting the tail, so short-circuit.
            return []
        return recent[-size:]

    def get_latest_image_message(self, room_id: str, before_timestamp: str = "") -> Optional[Dict]:
        """Return the newest image message in the active-context window.

        Delegates to the message DB, restricted to ``active_context_hours``
        and, when given, to messages before ``before_timestamp``.
        """
        hours = int(self.config.get("active_context_hours", 8))
        return self.message_db.get_latest_image_message(room_id, before_timestamp=before_timestamp, hours_ago=hours)

    def get_member_context(self, room_id: str, wxid: str) -> Optional[Dict]:
        """Return the stored member context, or ``None`` when the feature is off.

        The lookup is skipped when ``enable_member_context`` is falsy
        (it defaults to enabled).
        """
        if not self.config.get("enable_member_context", True):
            return None
        return self.member_context_db.get_member_context(room_id, wxid)

    def build_memory_hints(self, room_id: str, wxid: str) -> Dict:
        """Summarise what is known about a member's recent presence.

        Returns a dict with:

        * ``member_context``: stored context, or ``{}``;
        * ``last_active_at``: ``last_message_time`` of the last entry
          returned by ``get_member_active_dates`` (``""`` if none);
        * ``days_since_active``: whole days since that time, or ``None``
          when it is missing or unparseable;
        * ``returning_member_state``: ``"long_absent_member"`` /
          ``"returning_member"`` / ``""`` depending on the configured
          ``long_absent_member_days`` (30) and ``returning_member_days`` (7);
        * ``is_followup``: whether the bot replied to this member within
          the follow-up window.
        """
        lookback_days = int(self.config.get("memory_lookback_days", 180))
        returning_days = int(self.config.get("returning_member_days", 7))
        long_absent_days = int(self.config.get("long_absent_member_days", 30))
        active_dates = self.message_db.get_member_active_dates(room_id, wxid, days=lookback_days) or []
        member_context = self.get_member_context(room_id, wxid)

        last_active_at = ""
        returning_state = ""
        days_since_active = None
        if active_dates:
            # NOTE(review): assumes the DB returns dates oldest-first so the
            # last entry is the most recent one - confirm against the query.
            last_item = active_dates[-1]
            last_active_at = last_item.get("last_message_time") or ""
            parsed = self._parse_datetime(last_active_at)
            if parsed:
                days_since_active = max((datetime.now() - parsed).days, 0)
                if days_since_active >= long_absent_days:
                    returning_state = "long_absent_member"
                elif days_since_active >= returning_days:
                    returning_state = "returning_member"

        followup = self._get_followup_state(room_id, wxid)
        return {
            "member_context": member_context or {},
            "last_active_at": last_active_at,
            "days_since_active": days_since_active,
            "returning_member_state": returning_state,
            "is_followup": followup,
        }

    def note_bot_reply(self, room_id: str, wxid: str, topic: str = "") -> None:
        """Record that the bot just replied to ``wxid`` in ``room_id``."""
        key = self._followup_key(room_id, wxid)
        self.followup_sessions[key] = {
            "last_bot_reply_at": datetime.now(),
            "topic": topic,
        }

    def _get_followup_state(self, room_id: str, wxid: str) -> bool:
        """Return True if the bot replied to this member recently.

        "Recently" means within ``followup_session_window_sec`` seconds
        (default 300) of the last recorded reply.
        """
        key = self._followup_key(room_id, wxid)
        state = self.followup_sessions.get(key)
        if not state:
            return False
        timeout = int(self.config.get("followup_session_window_sec", 300))
        last_reply_at = state.get("last_bot_reply_at")
        if not last_reply_at:
            return False
        return (datetime.now() - last_reply_at).total_seconds() <= timeout

    @staticmethod
    def _followup_key(room_id: str, wxid: str) -> str:
        """Build the ``followup_sessions`` key for a room/member pair."""
        return f"{room_id}:{wxid}"

    @staticmethod
    def _parse_datetime(value: str) -> Optional[datetime]:
        """Parse ``YYYY-MM-DD HH:MM:SS`` or ``YYYY-MM-DD``; ``None`` otherwise."""
        if not value:
            return None
        for fmt in ("%Y-%m-%d %H:%M:%S", "%Y-%m-%d"):
            try:
                return datetime.strptime(value, fmt)
            except ValueError:
                continue
        return None
|
||||
118
plugins/ai_auto_response/memory/social_memory.py
Normal file
118
plugins/ai_auto_response/memory/social_memory.py
Normal file
@@ -0,0 +1,118 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from collections import Counter, defaultdict
|
||||
from typing import Dict, List
|
||||
|
||||
from db.message_storage import MessageStorageDB
|
||||
|
||||
|
||||
class SocialMemoryService:
|
||||
def __init__(self, db_manager, config: Dict | None = None):
|
||||
self.config = config or {}
|
||||
self.message_db = MessageStorageDB(db_manager)
|
||||
self.lookback_hours = int(self.config.get("social_lookback_hours", 72) or 72)
|
||||
self.max_relation_items = int(self.config.get("max_relation_items", 4) or 4)
|
||||
self.cache_ttl_seconds = int(self.config.get("social_cache_ttl_seconds", 120) or 120)
|
||||
self._relation_cache: Dict[str, Dict] = {}
|
||||
|
||||
def build_social_context(
|
||||
self,
|
||||
room_id: str,
|
||||
sender: str,
|
||||
current_content: str,
|
||||
recent_messages: List[Dict],
|
||||
name_map: Dict[str, str] | None = None,
|
||||
) -> Dict:
|
||||
name_map = name_map or {}
|
||||
history = self._get_room_history(room_id)
|
||||
if not history:
|
||||
return {"items": [], "prompt": ""}
|
||||
relation_scores = defaultdict(float)
|
||||
shared_topics = defaultdict(Counter)
|
||||
previous_sender = ""
|
||||
for item in history:
|
||||
item_sender = str(item.get("sender", "") or "").strip()
|
||||
content = str(item.get("content", "") or "").strip()
|
||||
if not item_sender or not content:
|
||||
previous_sender = item_sender or previous_sender
|
||||
continue
|
||||
if previous_sender and previous_sender != item_sender:
|
||||
pair = (previous_sender, item_sender)
|
||||
relation_scores[pair] += 1.0
|
||||
for token in self._extract_tokens(content):
|
||||
shared_topics[pair][token] += 1
|
||||
previous_sender = item_sender
|
||||
|
||||
sender_links = []
|
||||
for (src, dst), score in relation_scores.items():
|
||||
if sender not in {src, dst}:
|
||||
continue
|
||||
other = dst if src == sender else src
|
||||
relation_type = "frequent_turn_taking"
|
||||
if score >= 8:
|
||||
relation_type = "stable_pairing"
|
||||
elif score >= 4:
|
||||
relation_type = "often_reply_to"
|
||||
topic_tags = [item for item, _ in shared_topics[(src, dst)].most_common(3)]
|
||||
sender_links.append({
|
||||
"target_wxid": other,
|
||||
"target_name": name_map.get(other, other),
|
||||
"relation_type": relation_type,
|
||||
"strength": round(min(score / 10.0, 1.0), 2),
|
||||
"topic_tags": topic_tags,
|
||||
})
|
||||
|
||||
sender_links.sort(key=lambda item: item.get("strength", 0.0), reverse=True)
|
||||
sender_links = sender_links[: self.max_relation_items]
|
||||
prompt = self._build_prompt(sender_links, current_content)
|
||||
return {
|
||||
"items": sender_links,
|
||||
"prompt": prompt,
|
||||
}
|
||||
|
||||
def _get_room_history(self, room_id: str) -> List[Dict]:
|
||||
now = time.time()
|
||||
cached = self._relation_cache.get(room_id)
|
||||
if cached and now - cached.get("ts", 0) <= self.cache_ttl_seconds:
|
||||
return cached.get("messages", []) or []
|
||||
history = self.message_db.get_messages_for_summary(
|
||||
room_id,
|
||||
hours_ago=self.lookback_hours,
|
||||
min_messages=20,
|
||||
max_hours=self.lookback_hours,
|
||||
max_results=300,
|
||||
) or []
|
||||
self._relation_cache[room_id] = {"ts": now, "messages": history}
|
||||
return history
|
||||
|
||||
@staticmethod
|
||||
def _build_prompt(items: List[Dict], current_content: str) -> str:
|
||||
if not items:
|
||||
return ""
|
||||
lines = [
|
||||
"群内关系记忆只可在当前话题明显相关时轻微利用,不要像在背档案。",
|
||||
]
|
||||
for item in items:
|
||||
tags = "、".join(item.get("topic_tags", [])[:3]) or "泛互动"
|
||||
lines.append(
|
||||
f"- 你与 {item.get('target_name', '某成员')} 的群内关系倾向:"
|
||||
f"{item.get('relation_type', 'frequent_turn_taking')},"
|
||||
f"强度={item.get('strength', 0.0)},"
|
||||
f"常见共现话题={tags}"
|
||||
)
|
||||
return "\n".join(lines)
|
||||
|
||||
@staticmethod
|
||||
def _extract_tokens(content: str) -> set[str]:
|
||||
import re
|
||||
|
||||
text = str(content or "").lower()
|
||||
tokens = set(re.findall(r"[a-z0-9_\\-]{3,}", text))
|
||||
for keyword in [
|
||||
"openclaw", "docker", "python", "qdrant", "ollama", "部署", "报错", "token",
|
||||
"机器人", "插件", "模型", "dota", "吃饭", "项目",
|
||||
]:
|
||||
if keyword in text:
|
||||
tokens.add(keyword)
|
||||
return tokens
|
||||
138
plugins/ai_auto_response/memory/vector_memory.py
Normal file
138
plugins/ai_auto_response/memory/vector_memory.py
Normal file
@@ -0,0 +1,138 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
from typing import Dict, List
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
class VectorMemoryStore:
    """Thin client for vector memory: Ollama embeddings stored in Qdrant.

    All network failures are swallowed on purpose (the feature is
    best-effort); the reason for the most recent failure is kept in
    ``last_error``.
    """

    def __init__(self, config: Dict):
        """Read connection settings from ``config``.

        Settings explicitly set to ``None`` are treated like missing ones,
        so a null URL becomes ``""`` (not the string ``"None"``) and a null
        ``vector_top_k`` / ``vector_min_score`` falls back to 5 / 0.65.
        """
        self.config = config or {}
        self.enabled = bool(self.config.get("enable_vector_memory"))
        self.qdrant_url = self._text_setting("qdrant_url").rstrip("/")
        self.collection = self.config.get("qdrant_collection") or ""
        self.ollama_base_url = self._text_setting("ollama_base_url").rstrip("/")
        self.embedding_model = self.config.get("embedding_model") or ""
        top_k = self.config.get("vector_top_k")
        self.top_k = 5 if top_k is None else int(top_k)
        min_score = self.config.get("vector_min_score")
        self.min_score = 0.65 if min_score is None else float(min_score)
        # Set once the collection has been seen or created in this process.
        self.collection_ready = False
        self.last_error = ""

    def _text_setting(self, key: str) -> str:
        """Return ``config[key]`` as a string, mapping missing/None to ``""``."""
        value = self.config.get(key)
        return "" if value is None else str(value)

    def should_search(self, reply_mode: str, trigger_type: str, returning_state: str) -> bool:
        """Return True if any non-empty argument is in ``vector_trigger_modes``."""
        modes = set(self.config.get("vector_trigger_modes") or [])
        return any(item in modes for item in [reply_mode, trigger_type, returning_state] if item)

    def search(self, query: str, room_id: str, wxid: str = "") -> List[Dict]:
        """Return payloads of the stored memories closest to ``query``.

        Results are limited to ``room_id``, to ``top_k`` hits and to scores
        of at least ``min_score``. Returns ``[]`` (with ``last_error`` set)
        when the store is disabled or misconfigured, embedding fails, or
        the Qdrant request fails.
        """
        self.last_error = ""
        if not self.enabled or not self.qdrant_url or not self.collection or not self.embedding_model:
            self.last_error = "vector_disabled_or_incomplete"
            return []
        embedding = self._embed(query)
        if not embedding:
            self.last_error = "embed_failed"
            return []
        # The result is intentionally not checked here: if the collection is
        # unavailable the search request below fails and is reported as
        # "search_failed:...".
        self._ensure_collection(len(embedding))
        must = [{"key": "chatroom_id", "match": {"value": room_id}}]
        payload = {
            "vector": embedding,
            "limit": self.top_k,
            "with_payload": True,
            "score_threshold": self.min_score,
            "filter": {"must": must},
        }
        if wxid:
            # NOTE(review): in Qdrant a `should` clause next to `must`
            # presumably requires at least one `should` condition to match,
            # which would restrict results to this member rather than merely
            # prefer them - confirm this is the intent.
            payload["filter"]["should"] = [{"key": "wxid", "match": {"value": wxid}}]
        try:
            response = requests.post(
                f"{self.qdrant_url}/collections/{self.collection}/points/search",
                json=payload,
                timeout=15,
            )
            response.raise_for_status()
            items = response.json().get("result", []) or []
            return [item.get("payload", {}) for item in items if item.get("payload")]
        except Exception as exc:
            self.last_error = f"search_failed:{exc}"
            return []

    def upsert_memory(self, memory_id: str, text: str, payload: Dict) -> bool:
        """Embed ``text`` and store it under a stable id derived from ``memory_id``.

        ``payload`` (``None`` is treated as empty) is stored alongside the
        vector with ``content_summary`` set to ``text``. Returns True on
        success; on failure returns False and sets ``last_error``.
        """
        self.last_error = ""
        if not self.enabled or not text or not self.qdrant_url or not self.collection or not self.embedding_model:
            self.last_error = "vector_disabled_or_incomplete"
            return False
        embedding = self._embed(text)
        if not embedding:
            self.last_error = "embed_failed"
            return False
        if not self._ensure_collection(len(embedding)):
            self.last_error = "ensure_collection_failed"
            return False
        point = {
            "points": [
                {
                    "id": self._stable_id(memory_id),
                    "vector": embedding,
                    # Copy so the caller's dict is not mutated; tolerate None.
                    "payload": {**(payload or {}), "content_summary": text},
                }
            ]
        }
        try:
            response = requests.put(
                f"{self.qdrant_url}/collections/{self.collection}/points",
                json=point,
                timeout=15,
            )
            response.raise_for_status()
            return True
        except Exception as exc:
            self.last_error = f"upsert_failed:{exc}"
            return False

    def _embed(self, query: str) -> List[float]:
        """Return the embedding for ``query`` from Ollama, or ``[]`` on any failure."""
        try:
            response = requests.post(
                f"{self.ollama_base_url}/api/embeddings",
                json={"model": self.embedding_model, "prompt": query},
                timeout=20,
            )
            response.raise_for_status()
            return response.json().get("embedding") or []
        except Exception:
            # Best-effort: callers report this as "embed_failed".
            return []

    def _ensure_collection(self, vector_size: int) -> bool:
        """Make sure the Qdrant collection exists, creating it if needed.

        A successful check or creation is remembered in
        ``collection_ready`` so later calls return immediately. New
        collections use cosine distance and ``vector_size`` dimensions.
        Returns False when the collection could neither be found nor
        created.
        """
        if self.collection_ready:
            return True
        try:
            response = requests.get(
                f"{self.qdrant_url}/collections/{self.collection}",
                timeout=10,
            )
            if response.status_code == 200:
                self.collection_ready = True
                return True
        except Exception:
            # Fall through and try to create the collection.
            pass

        try:
            response = requests.put(
                f"{self.qdrant_url}/collections/{self.collection}",
                json={
                    "vectors": {
                        "size": vector_size,
                        "distance": "Cosine",
                    }
                },
                timeout=15,
            )
            response.raise_for_status()
            self.collection_ready = True
            return True
        except Exception:
            return False

    @staticmethod
    def _stable_id(memory_id: str) -> int:
        """Map ``memory_id`` to a deterministic integer point id.

        Uses the first 15 hex digits (60 bits) of the MD5 digest; the hash
        is only an id derivation, not a security measure.
        """
        digest = hashlib.md5(memory_id.encode("utf-8")).hexdigest()[:15]
        return int(digest, 16)
|
||||
Reference in New Issue
Block a user