Files
abot/db/message_storage.py

594 lines
22 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
from datetime import datetime
from typing import Dict, List, Optional
from db.base import BaseDBOperator
from db.connection import DBConnectionManager
from wechat_ipad.models.message import WxMessage
class MessageStorageDB(BaseDBOperator):
    """Database operations for storing and querying chat messages."""

    def __init__(self, db_manager: DBConnectionManager):
        """Pass the connection manager straight through to the base operator."""
        super().__init__(db_manager)
def archive_message(self, msg: WxMessage) -> bool:
    """Archive one message as a new row in the ``messages`` table.

    The row is stamped with the local wall-clock time at the moment of the
    call, not with any time carried by the message itself.

    Args:
        msg: Message to store.

    Returns:
        Whatever ``execute_update`` returns for the INSERT.
    """
    archived_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    insert_sql = (
        "\n"
        "INSERT INTO messages (group_id, timestamp, sender, content, message_type, attachment_url, message_id, message_xml, message_thumb)\n"
        "VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)\n"
    )
    # One value per column, in the column order of the INSERT above.
    row = (
        msg.roomid,                      # group_id
        archived_at,                     # timestamp
        msg.sender,                      # sender
        str(msg.content.clean_content),  # content
        msg.msg_type.value,              # message_type
        str(msg.content.xml_content),    # attachment_url
        msg.msg_id,                      # message_id
        msg.msg_source,                  # message_xml
        "",                              # message_thumb (filled in later, if at all)
    )
    return self.execute_update(insert_sql, row)
def get_recent_messages(self, group_id: str, hours_ago: int = 8, min_content_length: int = 6) -> List[Dict]:
    """Return a group's recent type-1/type-49 messages, oldest first.

    Rows are kept only when the content is longer than
    ``min_content_length`` bytes (``LENGTH``), shorter than 300 characters
    (``CHAR_LENGTH``) and does not start with ``/``.

    Args:
        group_id: Group whose messages are wanted.
        hours_ago: Size of the look-back window, in hours, counted from NOW().
        min_content_length: Exclusive lower bound on the content byte length.

    Returns:
        Rows with ``timestamp``, ``sender``, ``content`` and
        ``message_type``; an empty list when nothing matches.
    """
    # The LIKE wildcard is written as ``%%`` because this statement is run
    # with bound parameters: with a %-interpolating driver a bare ``%`` is
    # read as a format directive. The other queries in this class already
    # escape it the same way.
    sql = """
        SELECT timestamp, sender, content, message_type
        FROM messages
        WHERE timestamp >= DATE_SUB(NOW(), INTERVAL %s HOUR)
          AND message_type IN (1, 49)
          AND group_id = %s
          AND LENGTH(content) > %s
          AND CHAR_LENGTH(content) < 300
          AND content NOT LIKE '/%%'
        ORDER BY timestamp ASC
    """
    # Explicit ordering: without it the row order is up to the server, which
    # makes any later truncation of the list arbitrary.
    params = (hours_ago, group_id, min_content_length)
    return self.execute_query(sql, params) or []
def get_latest_image_message(self, group_id: str, before_timestamp: str = "", hours_ago: int = 8) -> Optional[Dict]:
    """Return the newest type-3 message of a group that has a non-empty ``image_path``.

    Args:
        group_id: Group to search.
        before_timestamp: When non-empty, only rows with
            ``timestamp <= before_timestamp`` are considered.
        hours_ago: Size of the look-back window, in hours, counted from NOW().

    Returns:
        The result of ``execute_query(..., fetch_one=True)`` for the query.
    """
    fragments = [
        "\n"
        "SELECT timestamp, sender, content, message_type, image_path\n"
        "FROM messages\n"
        "WHERE timestamp >= DATE_SUB(NOW(), INTERVAL %s HOUR)\n"
        "AND group_id = %s\n"
        "AND message_type = 3\n"
        "AND image_path IS NOT NULL\n"
        "AND image_path <> ''\n"
    ]
    bound: List = [hours_ago, group_id]

    # Optional upper bound on the message time.
    if before_timestamp:
        fragments.append(" AND timestamp <= %s")
        bound.append(before_timestamp)

    fragments.append(" ORDER BY timestamp DESC LIMIT 1")
    return self.execute_query("".join(fragments), tuple(bound), fetch_one=True)
def get_message_by_message_id(self, message_id: int | str) -> Optional[Dict]:
"""根据 message_id 获取单条消息"""
sql = """
SELECT id, group_id, timestamp, sender, content, message_type,
attachment_url, message_id, message_xml, message_thumb, image_path
FROM messages
WHERE message_id = %s
ORDER BY id DESC
LIMIT 1
"""
return self.execute_query(sql, (message_id,), fetch_one=True)
def get_image_message_by_md5(self, md5: str) -> Optional[Dict]:
    """Find the newest type-3 message whose ``attachment_url`` mentions the given md5.

    The match is a substring search for ``md5="<value>"`` inside the
    ``attachment_url`` column.

    Args:
        md5: Digest text to look for.

    Returns:
        The result of ``execute_query(..., fetch_one=True)`` for the query.
    """
    search_sql = (
        "\n"
        "SELECT id, group_id, timestamp, sender, content, message_type,\n"
        "attachment_url, message_id, message_xml, message_thumb, image_path\n"
        "FROM messages\n"
        "WHERE message_type = 3\n"
        "AND attachment_url IS NOT NULL\n"
        "AND attachment_url <> ''\n"
        "AND attachment_url LIKE %s\n"
        "ORDER BY id DESC\n"
        "LIMIT 1\n"
    )
    # Surround the attribute with wildcards so it can sit anywhere in the column.
    like_pattern = '%md5="{}"%'.format(md5)
    return self.execute_query(search_sql, (like_pattern,), fetch_one=True)
def get_member_recent_messages(self, group_id: str, wxid: str, days: int = 30,
                               limit: int = 200, include_today: bool = True) -> List[Dict]:
    """Return the latest messages of one group member in chronological order.

    The query picks the newest ``limit`` rows (type 1 or 49, 2-500
    characters, not starting with ``/``) and the result is then flipped so
    the oldest of them comes first.

    Args:
        group_id: Group the member belongs to.
        wxid: Sender id of the member.
        days: Size of the look-back window in days.
        limit: Maximum number of rows to fetch.
        include_today: When true the window is counted back from NOW();
            otherwise it is counted back from CURDATE() and rows from the
            current date are excluded.

    Returns:
        Rows with ``timestamp``, ``sender``, ``content`` and
        ``message_type``; an empty list when nothing matches.
    """
    select_part = (
        "\n"
        "SELECT timestamp, sender, content, message_type\n"
        "FROM messages\n"
    )
    # Only the time window differs between the two variants.
    if include_today:
        window_part = "WHERE timestamp >= DATE_SUB(NOW(), INTERVAL %s DAY)\n"
    else:
        window_part = (
            "WHERE timestamp >= DATE_SUB(CURDATE(), INTERVAL %s DAY)\n"
            "AND timestamp < CURDATE()\n"
        )
    filter_part = (
        "AND group_id = %s\n"
        "AND sender = %s\n"
        "AND message_type IN (1, 49)\n"
        "AND CHAR_LENGTH(content) BETWEEN 2 AND 500\n"
        "AND content NOT LIKE '/%%'\n"
        "ORDER BY timestamp DESC\n"
        "LIMIT %s\n"
    )

    newest_first = self.execute_query(
        select_part + window_part + filter_part,
        (days, group_id, wxid, limit),
    ) or []

    # Fetched newest-first so LIMIT keeps the latest rows; hand back oldest-first.
    chronological = list(newest_first)
    chronological.reverse()
    return chronological
def get_member_messages_since(self, group_id: str, wxid: str, since_time, limit: int = 200) -> List[Dict]:
    """Return a member's messages strictly newer than ``since_time``, oldest first.

    Args:
        group_id: Group the member belongs to.
        wxid: Sender id of the member.
        since_time: Exclusive lower bound. A ``datetime`` is rendered as
            ``YYYY-MM-DD HH:MM:SS``; any other value is bound unchanged.
        limit: Maximum number of rows to fetch.

    Returns:
        Rows with ``timestamp``, ``sender``, ``content`` and
        ``message_type``; an empty list when nothing matches.
    """
    lower_bound = since_time
    if isinstance(lower_bound, datetime):
        lower_bound = lower_bound.strftime("%Y-%m-%d %H:%M:%S")

    query = (
        "\n"
        "SELECT timestamp, sender, content, message_type\n"
        "FROM messages\n"
        "WHERE timestamp > %s\n"
        "AND group_id = %s\n"
        "AND sender = %s\n"
        "AND message_type IN (1, 49)\n"
        "AND CHAR_LENGTH(content) BETWEEN 2 AND 500\n"
        "AND content NOT LIKE '/%%'\n"
        "ORDER BY timestamp ASC\n"
        "LIMIT %s\n"
    )
    found = self.execute_query(query, (lower_bound, group_id, wxid, limit))
    return found or []
def get_member_active_dates(self, group_id: str, wxid: str, days: int = 365) -> List[Dict]:
    """List the days on which a member posted, with per-day counts and time bounds.

    Each row carries ``message_date``, ``msg_count``,
    ``first_message_time`` and ``last_message_time``. The three temporal
    fields are converted to strings in place: ``datetime`` values are
    formatted (date-only for ``message_date``), any other truthy value goes
    through ``str()``, and falsy values are left untouched.

    Args:
        group_id: Group the member belongs to.
        wxid: Sender id of the member.
        days: Size of the look-back window in days, counted from NOW().

    Returns:
        One row per active day, ordered by date ascending; an empty list
        when nothing matches.
    """
    query = (
        "\n"
        "SELECT\n"
        "DATE(timestamp) AS message_date,\n"
        "COUNT(*) AS msg_count,\n"
        "MIN(timestamp) AS first_message_time,\n"
        "MAX(timestamp) AS last_message_time\n"
        "FROM messages\n"
        "WHERE timestamp >= DATE_SUB(NOW(), INTERVAL %s DAY)\n"
        "AND group_id = %s\n"
        "AND sender = %s\n"
        "AND message_type IN (1, 49)\n"
        "AND CHAR_LENGTH(content) BETWEEN 2 AND 500\n"
        "AND content NOT LIKE '/%%'\n"
        "GROUP BY DATE(timestamp)\n"
        "ORDER BY message_date ASC\n"
    )
    day_rows = self.execute_query(query, (days, group_id, wxid)) or []

    # Column -> strftime pattern used when the driver hands back a datetime.
    patterns = {
        "message_date": "%Y-%m-%d",
        "first_message_time": "%Y-%m-%d %H:%M:%S",
        "last_message_time": "%Y-%m-%d %H:%M:%S",
    }
    for day_row in day_rows:
        for column, pattern in patterns.items():
            raw = day_row.get(column)
            if isinstance(raw, datetime):
                day_row[column] = raw.strftime(pattern)
            elif raw:
                day_row[column] = str(raw)
    return day_rows
def get_member_messages_on_date(self, group_id: str, wxid: str, target_date: str, limit: int = 120) -> List[Dict]:
    """Return one member's messages for a single calendar day, oldest first.

    Args:
        group_id: Group the member belongs to.
        wxid: Sender id of the member.
        target_date: Day to match against ``DATE(timestamp)``.
        limit: Maximum number of rows to fetch.

    Returns:
        Rows with ``timestamp``, ``sender``, ``content`` and
        ``message_type``; an empty list when nothing matches.
    """
    query = (
        "\n"
        "SELECT timestamp, sender, content, message_type\n"
        "FROM messages\n"
        "WHERE DATE(timestamp) = %s\n"
        "AND group_id = %s\n"
        "AND sender = %s\n"
        "AND message_type IN (1, 49)\n"
        "AND CHAR_LENGTH(content) BETWEEN 2 AND 500\n"
        "AND content NOT LIKE '/%%'\n"
        "ORDER BY timestamp ASC\n"
        "LIMIT %s\n"
    )
    bound = (target_date, group_id, wxid, limit)
    day_messages = self.execute_query(query, bound)
    return day_messages or []
def get_member_messages_for_group_date(self, group_id: str, target_date: str, limit: int = 5000) -> List[Dict]:
    """Return every attributed type-1/type-49 message of a group for one day, oldest first.

    Rows with a NULL or empty ``sender`` are skipped.

    Args:
        group_id: Group to read.
        target_date: Day to match against ``DATE(timestamp)``.
        limit: Maximum number of rows to fetch.

    Returns:
        Rows with ``timestamp``, ``sender``, ``content`` and
        ``message_type``; an empty list when nothing matches.
    """
    query = (
        "\n"
        "SELECT timestamp, sender, content, message_type\n"
        "FROM messages\n"
        "WHERE DATE(timestamp) = %s\n"
        "AND group_id = %s\n"
        "AND sender IS NOT NULL\n"
        "AND sender <> ''\n"
        "AND message_type IN (1, 49)\n"
        "AND CHAR_LENGTH(content) BETWEEN 2 AND 500\n"
        "AND content NOT LIKE '/%%'\n"
        "ORDER BY timestamp ASC\n"
        "LIMIT %s\n"
    )
    bound = (target_date, group_id, limit)
    day_messages = self.execute_query(query, bound)
    return day_messages or []
def get_message_count_by_date(self, date: str) -> List[Dict]:
    """Count the messages of one day, grouped by group and sender.

    Args:
        date: Day to match against ``DATE(timestamp)``.

    Returns:
        Rows with ``group_id``, ``sender`` and ``count``; an empty list
        when nothing matches.
    """
    query = (
        "\n"
        "SELECT group_id, sender, COUNT(*) as count\n"
        "FROM messages\n"
        "WHERE DATE(timestamp) = %s\n"
        "GROUP BY group_id, sender\n"
    )
    tallies = self.execute_query(query, (date,))
    return tallies or []
def get_speech_ranking(self, date: str, group_id: str, limit: int = 20) -> List[Dict]:
    """Return the top speakers of a group for one day from ``speech_counts``.

    Args:
        date: Day to match against the ``date`` column.
        group_id: Group to rank.
        limit: Maximum number of rows to fetch.

    Returns:
        Rows with ``wx_id`` and ``speech_count``, highest count first; an
        empty list when nothing matches.
    """
    ranking_sql = (
        "\n"
        "SELECT wx_id, count AS speech_count\n"
        "FROM speech_counts\n"
        "WHERE date = %s\n"
        "AND group_id = %s\n"
        "GROUP BY wx_id\n"
        "ORDER BY count DESC\n"
        "LIMIT %s\n"
    )
    ranking = self.execute_query(ranking_sql, (date, group_id, limit))
    if not ranking:
        return []
    return ranking
def insert_speech_count(self, group_id: str, wx_id: str, date: str, count: int) -> bool:
    """Insert a speech-count row, overwriting ``count`` if the key already exists.

    Args:
        group_id: Group id.
        wx_id: WeChat id of the speaker.
        date: Day of the tally, formatted ``YYYY-MM-DD``.
        count: Number of messages to record.

    Returns:
        Whatever ``execute_update`` returns for the statement.
    """
    upsert_sql = (
        "\n"
        "INSERT INTO speech_counts (group_id, wx_id, date, count)\n"
        "VALUES (%s, %s, %s, %s)\n"
        "ON DUPLICATE KEY UPDATE count = VALUES(count)\n"
    )
    return self.execute_update(upsert_sql, (group_id, wx_id, date, count))
def get_message_trend(self, group_id: str, days: int = 7) -> List[Dict]:
    """Return per-day message counts for a group.

    Args:
        group_id: Group to read.
        days: Size of the look-back window in days, counted from CURDATE().

    Returns:
        Rows with ``date`` and ``message_count``, ordered by date; an empty
        list when nothing matches.
    """
    trend_sql = (
        "\n"
        "SELECT\n"
        "DATE(timestamp) as date,\n"
        "COUNT(*) as message_count\n"
        "FROM messages\n"
        "WHERE group_id = %s\n"
        "AND timestamp >= DATE_SUB(CURDATE(), INTERVAL %s DAY)\n"
        "GROUP BY DATE(timestamp)\n"
        "ORDER BY date\n"
    )
    daily_counts = self.execute_query(trend_sql, (group_id, days))
    return daily_counts or []
def get_messages_by_filter(self, group_id=None, start_date=None, end_date=None,
                           search_text=None, page=1, page_size=20) -> Dict:
    """Filter messages with optional criteria, fuzzy search and pagination.

    Args:
        group_id: Optional group id to restrict to.
        start_date: Optional inclusive first day, ``YYYY-MM-DD``.
        end_date: Optional inclusive last day, ``YYYY-MM-DD``.
        search_text: Optional text matched as a substring of ``content``.
        page: 1-based page number; values below 1 are treated as 1.
        page_size: Rows per page; values below 1 are treated as 1.

    Returns:
        Dict with ``total`` (matching rows), ``page``, ``page_size``,
        ``total_pages`` and ``messages`` (rows of the requested page,
        newest first).
    """
    # Clamp paging inputs: a page below 1 would yield a negative OFFSET and
    # a page size of 0 would divide by zero in the page count below.
    page = max(int(page), 1)
    page_size = max(int(page_size), 1)

    like_pattern = f"%{search_text}%" if search_text else None

    # Build the WHERE clause once so the count query and the data query can
    # never drift apart.
    where_sql = " WHERE 1=1"
    params = []
    for clause, value in (
        ("group_id = %s", group_id),
        ("DATE(timestamp) >= %s", start_date),
        ("DATE(timestamp) <= %s", end_date),
        ("content LIKE %s", like_pattern),
    ):
        if value:
            where_sql += " AND " + clause
            params.append(value)

    sql_count = "SELECT COUNT(*) as total FROM messages" + where_sql
    sql_data = (
        "SELECT id, group_id, timestamp, sender, content, message_type,"
        " attachment_url, message_id, message_xml, message_thumb, image_path"
        " FROM messages"
        + where_sql
        + " ORDER BY timestamp DESC LIMIT %s OFFSET %s"
    )

    offset = (page - 1) * page_size
    data_params = params + [page_size, offset]

    count_result = self.execute_query(sql_count, params)
    total = count_result[0]['total'] if count_result else 0
    messages = self.execute_query(sql_data, data_params) or []

    return {
        'total': total,
        'page': page,
        'page_size': page_size,
        # Ceiling division: a partially filled last page still counts.
        'total_pages': (total + page_size - 1) // page_size,
        'messages': messages,
    }
def update_message_image_path(self, message_id, image_base64str):
    """Store an image payload in ``message_thumb`` for the given message.

    Despite the method name, the column written is ``message_thumb``, not
    ``image_path``.

    Args:
        message_id: Value matched against the ``message_id`` column.
        image_base64str: Base64 image content to store.

    Returns:
        The result of ``execute_update`` on success, ``False`` if the
        update raised an exception.
    """
    statement = (
        "\n"
        "UPDATE messages\n"
        "SET message_thumb = %s\n"
        "WHERE message_id = %s\n"
    )
    try:
        outcome = self.execute_update(statement, (image_base64str, message_id))
    except Exception as e:
        # Best-effort write: report the failure and signal it to the caller.
        print(f"更新消息图片路径出错: {e}")
        return False
    return outcome
def update_message_image_file_path(self, message_id, image_path):
    """Set ``image_path`` for a message, but only while it is still NULL.

    Args:
        message_id: Value matched against the ``message_id`` column.
        image_path: Path to record.

    Returns:
        The result of ``execute_update`` on success, ``False`` if the
        update raised an exception.
    """
    statement = (
        "\n"
        "UPDATE messages\n"
        "SET image_path = %s\n"
        "WHERE message_id = %s\n"
        "AND image_path IS NULL\n"
    )
    try:
        outcome = self.execute_update(statement, (image_path, message_id))
    except Exception as e:
        # Best-effort write: report the failure and signal it to the caller.
        print(f"更新消息图片文件路径出错: {e}")
        return False
    return outcome
def get_hourly_message_trend(self, group_id: Optional[str] = None, days: int = 1) -> List[Dict]:
    """Return per-hour message counts, for one group or for all groups.

    Args:
        group_id: Group to restrict to; when empty or ``None`` every group
            is included.
        days: Size of the look-back window in days, counted from NOW().

    Returns:
        Rows with ``hour_slot`` (``YYYY-MM-DD HH:00``) and
        ``message_count``, ordered by hour; an empty list when nothing
        matches.
    """
    # The DATE_FORMAT pattern is full of ``%`` signs. Keeping it inside the
    # SQL text next to ``%s`` placeholders breaks under %-style parameter
    # interpolation, so it is passed as a bound value instead.
    sql = """
        SELECT
            DATE_FORMAT(timestamp, %s) AS hour_slot,
            COUNT(*) AS message_count
        FROM messages
        WHERE timestamp >= DATE_SUB(NOW(), INTERVAL %s DAY)
    """
    params = ["%Y-%m-%d %H:00", days]

    # Optional restriction to a single group.
    if group_id:
        sql += " AND group_id = %s"
        params.append(group_id)

    sql += " GROUP BY hour_slot ORDER BY hour_slot"
    return self.execute_query(sql, tuple(params)) or []
def get_pending_image_messages(self, minutes_ago: int = 10, limit: int = 50) -> List[Dict]:
    """Return recent type-3 messages whose ``image_path`` is still NULL.

    Only rows with a non-empty ``message_xml`` are returned, oldest first.

    Args:
        minutes_ago: Size of the look-back window in minutes, counted from NOW().
        limit: Maximum number of rows to fetch per call.

    Returns:
        Rows with ``message_id``, ``group_id``, ``message_xml``,
        ``timestamp`` and ``attachment_url``; an empty list when nothing
        matches.
    """
    pending_sql = (
        "\n"
        "SELECT message_id, group_id, message_xml, timestamp,attachment_url\n"
        "FROM messages\n"
        "WHERE message_type = '3'\n"
        "AND image_path IS NULL\n"
        "AND timestamp >= DATE_SUB(NOW(), INTERVAL %s MINUTE)\n"
        "AND message_xml IS NOT NULL\n"
        "AND message_xml != ''\n"
        "ORDER BY timestamp ASC\n"
        "LIMIT %s\n"
    )
    pending = self.execute_query(pending_sql, (minutes_ago, limit))
    return pending or []
def get_messages_by_date_range(self, group_id: str, start_date: str, end_date: Optional[str] = None,
                               min_content_length: int = 6, max_results: int = 5000) -> List[Dict]:
    """Return a group's type-1/type-49 messages between two calendar days, oldest first.

    NOTE(review): a second method with this same name is defined later in
    this class (taking ``start_time``/``end_time`` datetimes). In Python the
    later definition replaces this one, so this variant is not reachable
    through the class as the file stands - confirm which one callers expect.

    Args:
        group_id: Group to read.
        start_date: Inclusive first day, ``YYYY-MM-DD``.
        end_date: Inclusive last day, ``YYYY-MM-DD``; defaults to
            ``start_date`` (a single day).
        min_content_length: Exclusive lower bound on the content byte length.
        max_results: Maximum number of rows to fetch.

    Returns:
        Rows with ``timestamp``, ``sender``, ``content`` and
        ``message_type``; an empty list when nothing matches.
    """
    if end_date is None:
        end_date = start_date

    # The LIKE wildcard is written as ``%%`` because this statement is run
    # with bound parameters: with a %-interpolating driver a bare ``%`` is
    # read as a format directive. The member queries in this class already
    # escape it the same way.
    sql = """
        SELECT timestamp, sender, content, message_type
        FROM messages
        WHERE DATE(timestamp) >= %s
          AND DATE(timestamp) <= %s
          AND group_id = %s
          AND message_type IN (1, 49)
          AND LENGTH(content) > %s
          AND CHAR_LENGTH(content) < 300
          AND content NOT LIKE '/%%'
        ORDER BY timestamp ASC
        LIMIT %s
    """
    params = (start_date, end_date, group_id, min_content_length, max_results)
    return self.execute_query(sql, params) or []
def get_messages_for_summary(self, group_id: str, hours_ago: int = 8,
                             min_messages: int = 50,
                             max_hours: int = 48,
                             max_results: int = 5000) -> List[Dict]:
    """Fetch messages for a summary, widening the time window until there are enough.

    Starts with the last ``hours_ago`` hours. While fewer than
    ``min_messages`` rows come back, the window grows by 8 hours per step,
    never beyond ``max_hours``. The initial window is always queried as
    given, even if it already exceeds ``max_hours``.

    Args:
        group_id: Group to read.
        hours_ago: Initial look-back window in hours.
        min_messages: Row count below which the window is widened.
        max_hours: Upper limit for the widened window in hours.
        max_results: Maximum number of rows to return.

    Returns:
        At most ``max_results`` rows as returned by
        ``get_recent_messages``; an empty list when nothing matches.
    """
    step_hours = 8
    window = hours_ago
    messages = self.get_recent_messages(group_id, hours_ago=window) or []

    while len(messages) < min_messages and window < max_hours:
        # Clamp so the last step lands exactly on max_hours instead of
        # overshooting it (e.g. 10 -> 18 -> 20 for max_hours=20).
        window = min(window + step_hours, max_hours)
        messages = self.get_recent_messages(group_id, hours_ago=window) or []

    return messages[:max_results]
def get_messages_by_date_range(self, group_id: str, start_time: datetime, end_time: datetime) -> List[Dict]:
"""获取指定时间范围内的消息
Args:
group_id: 群组ID
start_time: 开始时间
end_time: 结束时间
Returns:
消息列表
"""
sql = """
SELECT timestamp, sender, content, message_type
FROM messages
WHERE timestamp >= %s
AND timestamp <= %s
AND message_type in (1, 49)
AND group_id = %s
AND length(content) > 6
AND CHAR_LENGTH(content) < 300
AND content NOT LIKE '/%'
ORDER BY timestamp ASC
"""
params = (start_time.strftime('%Y-%m-%d %H:%M:%S'),
end_time.strftime('%Y-%m-%d %H:%M:%S'),
group_id)
return self.execute_query(sql, params) or []
def count_messages_by_date_range(self, group_id: str, start_time: datetime, end_time: datetime) -> int:
    """Count a group's type-1/type-49 messages inside an inclusive time range.

    The same content filters as the message listing apply: content longer
    than 6 bytes, shorter than 300 characters and not starting with ``/``.

    Args:
        group_id: Group to read.
        start_time: Inclusive lower bound.
        end_time: Inclusive upper bound.

    Returns:
        Number of matching rows; 0 when the query yields no result.
    """
    # The LIKE wildcard is written as ``%%`` because this statement is run
    # with bound parameters: with a %-interpolating driver a bare ``%`` is
    # read as a format directive. The member queries in this class already
    # escape it the same way.
    sql = """
        SELECT COUNT(*) AS count
        FROM messages
        WHERE timestamp >= %s
          AND timestamp <= %s
          AND message_type IN (1, 49)
          AND group_id = %s
          AND LENGTH(content) > 6
          AND CHAR_LENGTH(content) < 300
          AND content NOT LIKE '/%%'
    """
    time_format = '%Y-%m-%d %H:%M:%S'
    params = (start_time.strftime(time_format), end_time.strftime(time_format), group_id)
    result = self.execute_query(sql, params)
    return result[0]['count'] if result else 0
def get_message_stats_by_date_range(self, group_id: str, start_time: datetime, end_time: datetime) -> Dict:
    """Summarise a group's traffic inside an inclusive time range.

    Only rows with a non-empty ``sender`` are counted. Besides the total
    and the number of distinct senders, messages are tallied per type
    bucket: text (1), image (3), video (43, 62), link (49) and emoji
    (47, 1048625, 1090519089).

    Args:
        group_id: Group to read.
        start_time: Inclusive lower bound.
        end_time: Inclusive upper bound.

    Returns:
        Dict with integer ``total_count``, ``participant_count``,
        ``text_count``, ``image_count``, ``video_count``, ``link_count``
        and ``emoji_count``; missing or NULL values become 0.
    """
    stats_sql = (
        "\n"
        "SELECT\n"
        "COUNT(*) AS total_count,\n"
        "COUNT(DISTINCT sender) AS participant_count,\n"
        "SUM(CASE WHEN message_type = 1 THEN 1 ELSE 0 END) AS text_count,\n"
        "SUM(CASE WHEN message_type = 3 THEN 1 ELSE 0 END) AS image_count,\n"
        "SUM(CASE WHEN message_type IN (43, 62) THEN 1 ELSE 0 END) AS video_count,\n"
        "SUM(CASE WHEN message_type = 49 THEN 1 ELSE 0 END) AS link_count,\n"
        "SUM(CASE WHEN message_type IN (47, 1048625, 1090519089) THEN 1 ELSE 0 END) AS emoji_count\n"
        "FROM messages\n"
        "WHERE timestamp >= %s\n"
        "AND timestamp <= %s\n"
        "AND group_id = %s\n"
        "AND sender IS NOT NULL\n"
        "AND sender <> ''\n"
    )
    lower = start_time.strftime('%Y-%m-%d %H:%M:%S')
    upper = end_time.strftime('%Y-%m-%d %H:%M:%S')
    row = self.execute_query(stats_sql, (lower, upper, group_id), fetch_one=True) or {}

    counters = (
        "total_count",
        "participant_count",
        "text_count",
        "image_count",
        "video_count",
        "link_count",
        "emoji_count",
    )
    # Aggregates can come back as NULL (or as non-int numerics); normalise to int.
    return {name: int(row.get(name) or 0) for name in counters}