Files
abot/db/group_profile_snapshot_db.py

110 lines
4.6 KiB
Python

# -*- coding: utf-8 -*-
import json
from datetime import datetime
from typing import Dict, Optional
from db.base import BaseDBOperator
from db.connection import DBConnectionManager
class GroupProfileSnapshotDBOperator(BaseDBOperator):
    """Database operations for group profile snapshots.

    Snapshots live in ``t_group_profile_snapshot``. The table has a unique
    key on ``chatroom_id``, so each chat room holds at most one snapshot and
    saving again for the same room overwrites the previous one.
    """

    # Textual datetime format used when defaulting ``last_generated_at`` on
    # save and when rendering datetime columns on read.
    _DATETIME_FORMAT = "%Y-%m-%d %H:%M:%S"

    # Columns that ``_deserialize_row`` renders as formatted strings.
    _DATETIME_COLUMNS = (
        "source_summary_latest_at",
        "source_message_latest_at",
        "last_generated_at",
        "create_time",
        "update_time",
    )

    def __init__(self, db_manager: DBConnectionManager):
        """Initialise the operator and make sure the snapshot table exists.

        Args:
            db_manager: Connection manager handed to ``BaseDBOperator``.
        """
        super().__init__(db_manager)
        self._create_tables()

    def _create_tables(self):
        """Create the snapshot table if it does not exist yet.

        Best-effort: a failure is logged and swallowed so that constructing
        the operator never raises because of the DDL statement.
        """
        try:
            self.execute_update("""
                CREATE TABLE IF NOT EXISTS t_group_profile_snapshot (
                    id INT AUTO_INCREMENT PRIMARY KEY,
                    chatroom_id VARCHAR(64) NOT NULL COMMENT '群聊ID',
                    group_name VARCHAR(128) DEFAULT '' COMMENT '群名称',
                    profile_json LONGTEXT COMMENT '群画像快照JSON',
                    source_summary_latest_at DATETIME NULL COMMENT '构建时参考的最近群总结更新时间',
                    source_message_latest_at DATETIME NULL COMMENT '构建时参考的最近群消息时间',
                    source_summary_count INT NOT NULL DEFAULT 0 COMMENT '构建时参考的群总结条数',
                    source_message_sample_count INT NOT NULL DEFAULT 0 COMMENT '构建时参考的消息样本数',
                    last_generated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT '最后一次生成时间',
                    create_time DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间',
                    update_time DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间',
                    UNIQUE KEY idx_group_profile_snapshot (chatroom_id),
                    KEY idx_group_profile_generated_at (last_generated_at)
                ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='群画像快照表';
            """)
        except Exception as e:
            self.LOG.error(f"创建群画像快照表失败: {e}")

    def get_snapshot(self, chatroom_id: str) -> Optional[Dict]:
        """Fetch the stored snapshot for one chat room.

        Args:
            chatroom_id: Identifier of the chat room to look up.

        Returns:
            The row after ``_deserialize_row`` (decoded profile available
            under both ``profile_json`` and ``profile``, datetime columns as
            strings). ``None`` when the query raises; when no row matches,
            whatever falsy value ``execute_query`` returned is passed through.
        """
        try:
            sql = """
                SELECT *
                FROM t_group_profile_snapshot
                WHERE chatroom_id = %s
                LIMIT 1
            """
            # NOTE(review): assumes execute_query(fetch_one=True) returns a
            # dict-like row, since _deserialize_row calls row.get() -- confirm
            # against BaseDBOperator.
            row = self.execute_query(sql, (chatroom_id,), fetch_one=True)
            return self._deserialize_row(row)
        except Exception as e:
            self.LOG.error(f"获取群画像快照失败: {e}")
            return None

    def save_snapshot(self, snapshot: Dict) -> bool:
        """Insert the snapshot, or overwrite the existing one for its room.

        Args:
            snapshot: Mapping that may carry ``chatroom_id`` (required),
                ``group_name``, ``profile``, ``source_summary_latest_at``,
                ``source_message_latest_at``, ``source_summary_count``,
                ``source_message_sample_count`` and ``last_generated_at``.

        Returns:
            The result of ``execute_update`` on success; ``False`` when
            ``chatroom_id`` is missing/empty or when any error occurs.
        """
        try:
            chatroom_id = snapshot.get("chatroom_id")
            if not chatroom_id:
                # An empty id would be stored as a single shared '' row that
                # every later id-less save silently overwrites via the unique
                # key, so refuse it up front.
                self.LOG.error("保存群画像快照失败: chatroom_id 为空")
                return False

            # Store "{}" rather than the JSON literal "null" when the caller
            # passes profile=None, so reads always yield a mapping.
            profile = snapshot.get("profile")
            if profile is None:
                profile = {}

            data = {
                "chatroom_id": chatroom_id,
                "group_name": snapshot.get("group_name", ""),
                "profile_json": json.dumps(profile, ensure_ascii=False),
                "source_summary_latest_at": snapshot.get("source_summary_latest_at"),
                "source_message_latest_at": snapshot.get("source_message_latest_at"),
                "source_summary_count": int(snapshot.get("source_summary_count", 0) or 0),
                "source_message_sample_count": int(
                    snapshot.get("source_message_sample_count", 0) or 0
                ),
                # ``or`` (not a .get default) so that an explicit None/empty
                # value also falls back to "now"; the column is NOT NULL.
                "last_generated_at": snapshot.get("last_generated_at")
                or datetime.now().strftime(self._DATETIME_FORMAT),
            }

            # Column names come from the fixed keys above, never from caller
            # input, so interpolating them into the statement is safe; all
            # values are still bound through placeholders.
            fields = ", ".join(data)
            placeholders = ", ".join(["%s"] * len(data))
            # NOTE(review): VALUES(col) in ON DUPLICATE KEY UPDATE is
            # deprecated in recent MySQL releases but kept here for
            # compatibility with older servers -- confirm target version.
            update_clause = ", ".join(
                f"{key}=VALUES({key})" for key in data if key != "chatroom_id"
            )
            sql = f"""
                INSERT INTO t_group_profile_snapshot ({fields})
                VALUES ({placeholders})
                ON DUPLICATE KEY UPDATE {update_clause}
            """
            return self.execute_update(sql, tuple(data.values()))
        except Exception as e:
            self.LOG.error(f"保存群画像快照失败: {e}")
            return False

    @staticmethod
    def _deserialize_row(row: Optional[Dict]) -> Optional[Dict]:
        """Decode a raw snapshot row in place and return it.

        ``profile_json`` is parsed from JSON text; a missing, empty, invalid
        or JSON-``null`` payload becomes ``{}``. The decoded value is exposed
        under both ``profile_json`` and ``profile``. Datetime columns holding
        ``datetime`` objects are rendered as ``YYYY-MM-DD HH:MM:SS`` strings.

        Args:
            row: Row mapping as returned by the query layer, or a falsy value
                when nothing was found.

        Returns:
            The same (mutated) mapping, or the falsy input unchanged.
        """
        if not row:
            return row

        profile = {}
        raw_profile = row.get("profile_json")
        if raw_profile:
            try:
                profile = json.loads(raw_profile)
            except (TypeError, ValueError):
                # Corrupt or non-text payload: degrade to an empty profile.
                profile = {}
            if profile is None:
                # The stored text was the JSON literal "null".
                profile = {}
        row["profile_json"] = profile

        fmt = GroupProfileSnapshotDBOperator._DATETIME_FORMAT
        for key in GroupProfileSnapshotDBOperator._DATETIME_COLUMNS:
            value = row.get(key)
            if isinstance(value, datetime):
                row[key] = value.strftime(fmt)

        # Alias matching the "profile" key that save_snapshot accepts.
        row["profile"] = profile
        return row