560 lines
20 KiB
Python
560 lines
20 KiB
Python
# -*- coding: utf-8 -*-
|
||
|
||
from datetime import datetime
|
||
from typing import Dict, List, Optional
|
||
|
||
from db.base import BaseDBOperator
|
||
from db.connection import DBConnectionManager
|
||
from wechat_ipad.models.message import WxMessage
|
||
|
||
|
||
class MessageStorageDB(BaseDBOperator):
    """Database operations for archiving and querying chat messages.

    Statements use ``%s`` placeholders and values are always passed separately to
    ``execute_query`` / ``execute_update`` (inherited from ``BaseDBOperator``).
    Literal percent signs never appear bare in SQL text: LIKE patterns use ``%%``
    and the DATE_FORMAT pattern is passed as a parameter. This keeps every
    statement valid whether or not the DB driver applies Python ``%``-formatting
    to the statement text (PyMySQL does, and would reject a bare ``%Y`` or ``%'``).
    """

    # strftime format for values written to / compared with the ``timestamp`` column.
    _TS_FORMAT = "%Y-%m-%d %H:%M:%S"

    # MySQL DATE_FORMAT pattern for hourly buckets. Sent as a query parameter so its
    # '%' characters never pass through placeholder substitution.
    _HOUR_SLOT_FORMAT = "%Y-%m-%d %H:00"

    # Columns returned by the single-message and paged lookups.
    _FULL_COLUMNS = (
        "id, group_id, timestamp, sender, content, message_type, "
        "attachment_url, message_id, message_xml, message_thumb, image_path"
    )

    # WHERE fragment shared by the per-member queries. It keeps:
    #   - message types 1 and 49 (presumably WeChat text / app messages - confirm);
    #   - content of 2..500 characters;
    #   - content not starting with '/'.
    # In LIKE, '%%' matches the same strings as '%', so the pattern means
    # "starts with /" whether or not the driver collapses '%%' to '%'.
    _MEMBER_TEXT_FILTER = """
              AND message_type IN (1, 49)
              AND CHAR_LENGTH(content) BETWEEN 2 AND 500
              AND content NOT LIKE '/%%'
    """

    # WHERE fragment shared by the summary queries. Its single %s placeholder is the
    # exclusive minimum LENGTH() of content (MySQL LENGTH counts bytes, not chars).
    _SUMMARY_TEXT_FILTER = """
              AND message_type IN (1, 49)
              AND LENGTH(content) > %s
              AND CHAR_LENGTH(content) < 300
              AND content NOT LIKE '/%%'
    """

    def __init__(self, db_manager: DBConnectionManager):
        super().__init__(db_manager)

    # ------------------------------------------------------------------ writes

    def archive_message(self, msg: WxMessage) -> bool:
        """Insert one message into the ``messages`` table.

        Column mapping used here:
        - ``attachment_url`` receives ``msg.content.xml_content``;
        - ``message_xml`` receives ``msg.msg_source``;
        - ``message_thumb`` is inserted empty (``update_message_image_path`` writes it).

        ``timestamp`` is the local wall-clock time at archive time, not a time
        taken from the message itself.

        Returns:
            The result of ``execute_update``.
        """
        sql = """
            INSERT INTO messages (group_id, timestamp, sender, content, message_type,
                                  attachment_url, message_id, message_xml, message_thumb)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
        """
        params = (
            msg.roomid,
            datetime.now().strftime(self._TS_FORMAT),
            msg.sender,
            str(msg.content.clean_content),
            msg.msg_type.value,
            str(msg.content.xml_content),
            msg.msg_id,
            msg.msg_source,
            "",
        )
        return self.execute_update(sql, params)

    def insert_speech_count(self, group_id: str, wx_id: str, date: str, count: int) -> bool:
        """Insert (or overwrite) a per-member speech count for one day.

        Args:
            group_id: Group ID.
            wx_id: Member's WeChat ID.
            date: Day in ``YYYY-MM-DD`` format.
            count: Number of messages to store.

        Returns:
            The result of ``execute_update``. On a duplicate key the existing
            row's ``count`` is replaced, not incremented.
        """
        sql = """
            INSERT INTO speech_counts (group_id, wx_id, date, count)
            VALUES (%s, %s, %s, %s)
            ON DUPLICATE KEY UPDATE count = VALUES(count)
        """
        return self.execute_update(sql, (group_id, wx_id, date, count))

    def update_message_image_path(self, message_id, image_base64str):
        """Store base64 image content in ``message_thumb`` for a message.

        Args:
            message_id: Message ID (matched against ``message_id``).
            image_base64str: Base64-encoded image content.

        Returns:
            bool: The ``execute_update`` result, or False if it raised.
        """
        sql = """
            UPDATE messages
            SET message_thumb = %s
            WHERE message_id = %s
        """
        try:
            return self.execute_update(sql, (image_base64str, message_id))
        except Exception as e:
            # Best-effort update: report and signal failure instead of raising.
            print(f"更新消息图片路径出错: {e}")
            return False

    def update_message_image_file_path(self, message_id, image_path):
        """Set ``image_path`` for a message, but only if it is still NULL.

        Args:
            message_id: Message ID (matched against ``message_id``).
            image_path: Path of the stored image file.

        Returns:
            The ``execute_update`` result, or False if it raised.
        """
        sql = """
            UPDATE messages
            SET image_path = %s
            WHERE message_id = %s
              AND image_path IS NULL
        """
        try:
            return self.execute_update(sql, (image_path, message_id))
        except Exception as e:
            # Best-effort update: report and signal failure instead of raising.
            print(f"更新消息图片文件路径出错: {e}")
            return False

    # ------------------------------------------------------ single lookups

    def get_latest_image_message(self, group_id: str, before_timestamp: str = "",
                                 hours_ago: int = 8) -> Optional[Dict]:
        """Return the group's newest image message that has a stored file.

        Only ``message_type = 3`` rows from the last ``hours_ago`` hours with a
        non-empty ``image_path`` are considered. When ``before_timestamp`` is
        given, rows newer than it are excluded.

        Returns:
            The ``execute_query(fetch_one=True)`` result; presumably None when no
            row matches.
        """
        sql = """
            SELECT timestamp, sender, content, message_type, image_path
            FROM messages
            WHERE timestamp >= DATE_SUB(NOW(), INTERVAL %s HOUR)
              AND group_id = %s
              AND message_type = 3
              AND image_path IS NOT NULL
              AND image_path <> ''
        """
        params: List = [hours_ago, group_id]
        if before_timestamp:
            sql += " AND timestamp <= %s"
            params.append(before_timestamp)
        sql += " ORDER BY timestamp DESC LIMIT 1"
        return self.execute_query(sql, tuple(params), fetch_one=True)

    def get_message_by_message_id(self, message_id: int | str) -> Optional[Dict]:
        """Return the most recently inserted row with the given ``message_id``."""
        sql = f"""
            SELECT {self._FULL_COLUMNS}
            FROM messages
            WHERE message_id = %s
            ORDER BY id DESC
            LIMIT 1
        """
        return self.execute_query(sql, (message_id,), fetch_one=True)

    def get_image_message_by_md5(self, md5: str) -> Optional[Dict]:
        """Find the newest image message whose ``attachment_url`` contains ``md5="<md5>"``."""
        sql = f"""
            SELECT {self._FULL_COLUMNS}
            FROM messages
            WHERE message_type = 3
              AND attachment_url IS NOT NULL
              AND attachment_url <> ''
              AND attachment_url LIKE %s
            ORDER BY id DESC
            LIMIT 1
        """
        return self.execute_query(sql, (f'%md5="{md5}"%',), fetch_one=True)

    # ------------------------------------------------------ recent / summary

    def get_recent_messages(self, group_id: str, hours_ago: int = 8,
                            min_content_length: int = 6) -> List[Dict]:
        """Return the group's summary-eligible messages from the last ``hours_ago`` hours.

        Rows are ordered oldest first, so callers that truncate the list get a
        deterministic chronological prefix.

        Args:
            group_id: Group ID.
            hours_ago: Size of the look-back window in hours.
            min_content_length: Content must be longer than this many bytes.
        """
        sql = f"""
            SELECT timestamp, sender, content, message_type
            FROM messages
            WHERE timestamp >= DATE_SUB(NOW(), INTERVAL %s HOUR)
              AND group_id = %s
              {self._SUMMARY_TEXT_FILTER}
            ORDER BY timestamp ASC
        """
        params = (hours_ago, group_id, min_content_length)
        return self.execute_query(sql, params) or []

    @staticmethod
    def _range_bound(value, op: str):
        """Build one bound of a time-range filter and its parameter.

        A ``datetime`` becomes an exact ``timestamp <op> %s`` bound. Any other
        value (a ``YYYY-MM-DD`` string or a ``date``) is compared against
        ``DATE(timestamp)``, which makes the bound cover whole days.
        """
        if isinstance(value, datetime):
            return f"timestamp {op} %s", value.strftime("%Y-%m-%d %H:%M:%S")
        return f"DATE(timestamp) {op} %s", str(value)

    def get_messages_by_date_range(self, group_id: str, start_time=None, end_time=None,
                                   min_content_length: int = 6,
                                   max_results: Optional[int] = None,
                                   *, start_date=None, end_date=None) -> List[Dict]:
        """Return the group's summary-eligible messages inside a time range.

        This single method replaces two same-named definitions; previously the
        later one silently shadowed the earlier one. Both calling styles work:

        - ``datetime`` bounds: exact, inclusive timestamp range
          (e.g. ``get_messages_by_date_range(gid, start_dt, end_dt)``).
        - ``YYYY-MM-DD`` strings or ``date`` bounds: whole days, inclusive
          (e.g. ``get_messages_by_date_range(gid, "2024-01-01")``).

        Args:
            group_id: Group ID.
            start_time: Range start (``datetime``, or date string / ``date``).
            end_time: Range end. Defaults to ``start_time``, i.e. a single day
                for date input.
            min_content_length: Content must be longer than this many bytes.
            max_results: Optional cap on the number of rows. None means no LIMIT.
            start_date: Keyword alias for ``start_time``.
            end_date: Keyword alias for ``end_time``.

        Returns:
            Matching rows ordered oldest first.

        Raises:
            TypeError: If neither ``start_time`` nor ``start_date`` is given.
        """
        if start_time is None:
            start_time = start_date
        if end_time is None:
            end_time = end_date
        if start_time is None:
            raise TypeError("get_messages_by_date_range() requires start_time (or start_date)")
        if end_time is None:
            end_time = start_time

        start_sql, start_param = self._range_bound(start_time, ">=")
        end_sql, end_param = self._range_bound(end_time, "<=")
        sql = f"""
            SELECT timestamp, sender, content, message_type
            FROM messages
            WHERE {start_sql}
              AND {end_sql}
              AND group_id = %s
              {self._SUMMARY_TEXT_FILTER}
            ORDER BY timestamp ASC
        """
        params = [start_param, end_param, group_id, min_content_length]
        if max_results is not None:
            sql += " LIMIT %s"
            params.append(max_results)
        return self.execute_query(sql, tuple(params)) or []

    def count_messages_by_date_range(self, group_id: str, start_time: datetime,
                                     end_time: datetime,
                                     min_content_length: int = 6) -> int:
        """Count the summary-eligible messages that ``get_messages_by_date_range`` would return.

        Bounds follow the same rules as ``get_messages_by_date_range``: a
        ``datetime`` is exact, a date string or ``date`` covers whole days.

        Returns:
            Number of matching messages; 0 if the query returned nothing.
        """
        start_sql, start_param = self._range_bound(start_time, ">=")
        end_sql, end_param = self._range_bound(end_time, "<=")
        sql = f"""
            SELECT COUNT(*) AS count
            FROM messages
            WHERE {start_sql}
              AND {end_sql}
              AND group_id = %s
              {self._SUMMARY_TEXT_FILTER}
        """
        result = self.execute_query(sql, (start_param, end_param, group_id, min_content_length))
        return result[0]['count'] if result else 0

    def get_messages_for_summary(self, group_id: str, hours_ago: int = 8,
                                 min_messages: int = 50,
                                 max_hours: int = 48,
                                 max_results: int = 5000) -> List[Dict]:
        """Fetch messages for summarisation, widening the window when it is too sparse.

        Starts with the last ``hours_ago`` hours. While fewer than
        ``min_messages`` rows are found, the window grows by 8 hours, but never
        beyond ``max_hours``.

        Args:
            group_id: Group ID.
            hours_ago: Initial look-back window in hours.
            min_messages: Desired minimum number of messages.
            max_hours: Hard upper limit for the look-back window.
            max_results: Maximum number of messages returned (oldest first).

        Returns:
            At most ``max_results`` messages.
        """
        current_hours = hours_ago
        messages = self.get_recent_messages(group_id, hours_ago=current_hours)
        while len(messages) < min_messages and current_hours < max_hours:
            current_hours = min(current_hours + 8, max_hours)
            messages = self.get_recent_messages(group_id, hours_ago=current_hours)
        return list(messages[:max_results])

    # ---------------------------------------------------------- per member

    def get_member_recent_messages(self, group_id: str, wxid: str, days: int = 30,
                                   limit: int = 200, include_today: bool = True) -> List[Dict]:
        """Return a member's newest ``limit`` text messages, in chronological order.

        Args:
            group_id: Group ID.
            wxid: Member's WeChat ID (matched against ``sender``).
            days: Look-back window in days.
            limit: Maximum number of rows.
            include_today: If True, the window is the rolling ``days`` days up to
                now. If False, it is whole calendar days from ``days`` days ago up
                to (not including) today's midnight.
        """
        if include_today:
            time_window = "timestamp >= DATE_SUB(NOW(), INTERVAL %s DAY)"
        else:
            time_window = ("timestamp >= DATE_SUB(CURDATE(), INTERVAL %s DAY) "
                           "AND timestamp < CURDATE()")
        sql = f"""
            SELECT timestamp, sender, content, message_type
            FROM messages
            WHERE {time_window}
              AND group_id = %s
              AND sender = %s
              {self._MEMBER_TEXT_FILTER}
            ORDER BY timestamp DESC
            LIMIT %s
        """
        rows = self.execute_query(sql, (days, group_id, wxid, limit)) or []
        # The query picks the newest rows; flip them back to oldest-first.
        return list(reversed(rows))

    def get_member_messages_since(self, group_id: str, wxid: str, since_time,
                                  limit: int = 200) -> List[Dict]:
        """Return a member's text messages strictly after ``since_time``, oldest first.

        ``since_time`` may be a ``datetime`` or an already formatted timestamp string.
        """
        if isinstance(since_time, datetime):
            since_time = since_time.strftime(self._TS_FORMAT)
        sql = f"""
            SELECT timestamp, sender, content, message_type
            FROM messages
            WHERE timestamp > %s
              AND group_id = %s
              AND sender = %s
              {self._MEMBER_TEXT_FILTER}
            ORDER BY timestamp ASC
            LIMIT %s
        """
        return self.execute_query(sql, (since_time, group_id, wxid, limit)) or []

    def get_member_active_dates(self, group_id: str, wxid: str, days: int = 365) -> List[Dict]:
        """Return one row per day on which the member sent text messages.

        Each row has ``message_date``, ``msg_count``, ``first_message_time`` and
        ``last_message_time``. Date/time values are converted to strings:
        ``YYYY-MM-DD`` for ``message_date`` when it is a ``datetime``, the full
        timestamp format for the time columns, and ``str()`` for any other
        truthy value.
        """
        sql = f"""
            SELECT
                DATE(timestamp) AS message_date,
                COUNT(*) AS msg_count,
                MIN(timestamp) AS first_message_time,
                MAX(timestamp) AS last_message_time
            FROM messages
            WHERE timestamp >= DATE_SUB(NOW(), INTERVAL %s DAY)
              AND group_id = %s
              AND sender = %s
              {self._MEMBER_TEXT_FILTER}
            GROUP BY DATE(timestamp)
            ORDER BY message_date ASC
        """
        rows = self.execute_query(sql, (days, group_id, wxid)) or []
        for row in rows:
            for key in ("message_date", "first_message_time", "last_message_time"):
                value = row.get(key)
                if isinstance(value, datetime):
                    fmt = "%Y-%m-%d" if key == "message_date" else self._TS_FORMAT
                    row[key] = value.strftime(fmt)
                elif value:
                    row[key] = str(value)
        return rows

    def get_member_messages_on_date(self, group_id: str, wxid: str, target_date: str,
                                    limit: int = 120) -> List[Dict]:
        """Return a member's text messages on ``target_date`` (``YYYY-MM-DD``), oldest first."""
        sql = f"""
            SELECT timestamp, sender, content, message_type
            FROM messages
            WHERE DATE(timestamp) = %s
              AND group_id = %s
              AND sender = %s
              {self._MEMBER_TEXT_FILTER}
            ORDER BY timestamp ASC
            LIMIT %s
        """
        return self.execute_query(sql, (target_date, group_id, wxid, limit)) or []

    def get_member_messages_for_group_date(self, group_id: str, target_date: str,
                                           limit: int = 5000) -> List[Dict]:
        """Return all members' text messages in a group on ``target_date``, oldest first.

        Rows with a NULL or empty ``sender`` are skipped.
        """
        sql = f"""
            SELECT timestamp, sender, content, message_type
            FROM messages
            WHERE DATE(timestamp) = %s
              AND group_id = %s
              AND sender IS NOT NULL
              AND sender <> ''
              {self._MEMBER_TEXT_FILTER}
            ORDER BY timestamp ASC
            LIMIT %s
        """
        return self.execute_query(sql, (target_date, group_id, limit)) or []

    # ------------------------------------------------------------ statistics

    def get_message_count_by_date(self, date: str) -> List[Dict]:
        """Return message counts per (group_id, sender) for one day (``YYYY-MM-DD``)."""
        sql = """
            SELECT group_id, sender, COUNT(*) AS count
            FROM messages
            WHERE DATE(timestamp) = %s
            GROUP BY group_id, sender
        """
        return self.execute_query(sql, (date,)) or []

    def get_speech_ranking(self, date: str, group_id: str, limit: int = 20) -> List[Dict]:
        """Return up to ``limit`` members of a group ranked by stored speech count for one day."""
        # NOTE(review): `count` is not aggregated under GROUP BY wx_id. This relies on
        # (group_id, wx_id, date) being unique in speech_counts (insert_speech_count's
        # ON DUPLICATE KEY suggests so) - confirm, otherwise ONLY_FULL_GROUP_BY may reject it.
        sql = """
            SELECT wx_id, count AS speech_count
            FROM speech_counts
            WHERE date = %s
              AND group_id = %s
            GROUP BY wx_id
            ORDER BY count DESC
            LIMIT %s
        """
        return self.execute_query(sql, (date, group_id, limit)) or []

    def get_message_trend(self, group_id: str, days: int = 7) -> List[Dict]:
        """Return per-day message counts for a group.

        Args:
            group_id: Group ID.
            days: Count days starting from midnight ``days`` days ago.

        Returns:
            Rows with ``date`` and ``message_count``, ordered by date.
        """
        sql = """
            SELECT
                DATE(timestamp) AS date,
                COUNT(*) AS message_count
            FROM messages
            WHERE group_id = %s
              AND timestamp >= DATE_SUB(CURDATE(), INTERVAL %s DAY)
            GROUP BY DATE(timestamp)
            ORDER BY date
        """
        return self.execute_query(sql, (group_id, days)) or []

    def get_hourly_message_trend(self, group_id: Optional[str] = None, days: int = 1) -> List[Dict]:
        """Return per-hour message counts over the last ``days`` days.

        Args:
            group_id: Restrict to one group. None or empty means all groups.
            days: Look-back window in days.

        Returns:
            Rows with ``hour_slot`` (``YYYY-MM-DD HH:00``) and ``message_count``,
            ordered by ``hour_slot``.
        """
        # The DATE_FORMAT pattern is bound as a parameter: written inline, its '%'
        # characters would be treated as format specifiers by drivers that %-format
        # the statement.
        sql = """
            SELECT
                DATE_FORMAT(timestamp, %s) AS hour_slot,
                COUNT(*) AS message_count
            FROM messages
            WHERE timestamp >= DATE_SUB(NOW(), INTERVAL %s DAY)
        """
        params: List = [self._HOUR_SLOT_FORMAT, days]
        if group_id:
            sql += " AND group_id = %s"
            params.append(group_id)
        sql += " GROUP BY hour_slot ORDER BY hour_slot"
        return self.execute_query(sql, tuple(params)) or []

    # ---------------------------------------------------------- admin / misc

    def get_messages_by_filter(self, group_id=None, start_date=None, end_date=None,
                               search_text=None, page=1, page_size=20) -> Dict:
        """Filter messages with optional fuzzy search and paging.

        Args:
            group_id: Optional group ID.
            start_date: Optional first day, ``YYYY-MM-DD``, inclusive.
            end_date: Optional last day, ``YYYY-MM-DD``, inclusive.
            search_text: Optional substring to look for in ``content``.
            page: 1-based page number. Values below 1 are treated as 1.
            page_size: Rows per page. Values below 1 are treated as 1.

        Returns:
            A dict with ``total``, ``page``, ``page_size``, ``total_pages`` and
            ``messages`` (newest first).
        """
        # Clamp paging input: page <= 0 used to yield a negative OFFSET, and
        # page_size == 0 a ZeroDivisionError.
        page = max(int(page), 1)
        page_size = max(int(page_size), 1)

        conditions: List[str] = []
        params: List = []
        if group_id:
            conditions.append("group_id = %s")
            params.append(group_id)
        if start_date:
            conditions.append("DATE(timestamp) >= %s")
            params.append(start_date)
        if end_date:
            conditions.append("DATE(timestamp) <= %s")
            params.append(end_date)
        if search_text:
            conditions.append("content LIKE %s")
            params.append(f"%{search_text}%")

        where = "WHERE 1=1" + "".join(f" AND {cond}" for cond in conditions)
        count_sql = f"SELECT COUNT(*) as total FROM messages {where}"
        data_sql = (f"SELECT {self._FULL_COLUMNS} FROM messages {where} "
                    "ORDER BY timestamp DESC LIMIT %s OFFSET %s")

        count_result = self.execute_query(count_sql, tuple(params))
        total = count_result[0]['total'] if count_result else 0

        offset = (page - 1) * page_size
        messages = self.execute_query(data_sql, tuple(params) + (page_size, offset)) or []

        return {
            'total': total,
            'page': page,
            'page_size': page_size,
            'total_pages': (total + page_size - 1) // page_size,
            'messages': messages,
        }

    def get_pending_image_messages(self, minutes_ago: int = 10, limit: int = 50) -> List[Dict]:
        """Return recent image messages whose ``image_path`` is still NULL.

        Args:
            minutes_ago: Look-back window in minutes.
            limit: Maximum number of rows, oldest first.

        Returns:
            Rows with ``message_id``, ``group_id``, ``message_xml``, ``timestamp``
            and ``attachment_url``. Only rows with a non-empty ``message_xml``
            are included.
        """
        sql = """
            SELECT message_id, group_id, message_xml, timestamp, attachment_url
            FROM messages
            WHERE message_type = '3'
              AND image_path IS NULL
              AND timestamp >= DATE_SUB(NOW(), INTERVAL %s MINUTE)
              AND message_xml IS NOT NULL
              AND message_xml != ''
            ORDER BY timestamp ASC
            LIMIT %s
        """
        return self.execute_query(sql, (minutes_ago, limit)) or []