284 lines
9.3 KiB
Python
284 lines
9.3 KiB
Python
# -*- coding: utf-8 -*-
|
||
|
||
from datetime import datetime
|
||
from typing import Dict, List, Optional
|
||
|
||
from db.base import BaseDBOperator
|
||
from db.connection import DBConnectionManager
|
||
from wechat_ipad.models.message import WxMessage
|
||
|
||
|
||
class MessageStorageDB(BaseDBOperator):
    """Database operations for storing and querying chat messages.

    All statements go through ``execute_query`` / ``execute_update`` inherited
    from ``BaseDBOperator`` and pass values as ``%s`` placeholders. Some MySQL
    drivers (e.g. PyMySQL, MySQLdb) substitute placeholders with Python
    ``%``-formatting, which rejects a bare ``%`` in the SQL text. For that
    reason LIKE patterns and DATE_FORMAT masks are passed as parameters here
    rather than written inline.
    """

    def __init__(self, db_manager: DBConnectionManager):
        super().__init__(db_manager)

    @staticmethod
    def _escape_like(text: str) -> str:
        """Escape LIKE wildcards in *text* so it matches literally.

        Backslash is MySQL's default LIKE escape character.
        """
        return text.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")

    def archive_message(self, msg: WxMessage) -> bool:
        """Archive a message into the ``messages`` table.

        The stored ``timestamp`` is the local wall-clock time at insert time,
        not a time taken from ``msg``.

        Args:
            msg: The message to store.

        Returns:
            The result of ``execute_update``.
        """
        now_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        sql = """
            INSERT INTO messages (group_id, timestamp, sender, content, message_type,
                                  attachment_url, message_id, message_xml, message_thumb)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
        """
        # NOTE(review): the pairing is kept as in the original code:
        # ``content.xml_content`` is written to ``attachment_url`` and
        # ``msg_source`` to ``message_xml``. Confirm with the readers of these
        # columns (e.g. get_pending_image_messages) before changing it.
        params = (
            msg.roomid,                        # group_id
            now_time,                          # timestamp
            msg.sender,                        # sender
            str(msg.content.clean_content),    # content
            msg.msg_type.value,                # message_type
            str(msg.content.xml_content),      # attachment_url
            msg.msg_id,                        # message_id
            msg.msg_source,                    # message_xml
            "",                                # message_thumb (set by update_message_image_path)
        )
        return self.execute_update(sql, params)

    def get_recent_messages(self, group_id: str, hours_ago: int = 8, min_content_length: int = 6) -> List[Dict]:
        """Fetch recent messages of a group, filtered for analysis.

        Keeps messages of type 1 or 49 from the last ``hours_ago`` hours whose
        content is longer than ``min_content_length`` (LENGTH(), i.e. bytes),
        shorter than 300 characters (CHAR_LENGTH()), and not starting with "/".

        Args:
            group_id: Group whose messages are fetched.
            hours_ago: Look-back window in hours.
            min_content_length: Exclusive lower bound on LENGTH(content).

        Returns:
            Rows with ``timestamp``, ``sender``, ``content`` and
            ``message_type``; an empty list when nothing matches.
        """
        # NOTE(review): the lower bound uses LENGTH (bytes) while the upper bound
        # uses CHAR_LENGTH (characters); for multi-byte text these differ.
        # Kept as-is — confirm which unit is intended.
        sql = """
            SELECT timestamp, sender, content, message_type
            FROM messages
            WHERE timestamp >= DATE_SUB(NOW(), INTERVAL %s HOUR)
              AND message_type IN (1, 49)
              AND group_id = %s
              AND LENGTH(content) > %s
              AND CHAR_LENGTH(content) < 300
              AND content NOT LIKE %s
        """
        # The "/%" pattern is a parameter so the SQL text contains no bare "%"
        # (a literal '/%' makes %-formatting drivers raise ValueError).
        params = (hours_ago, group_id, min_content_length, "/%")
        return self.execute_query(sql, params) or []

    def get_message_count_by_date(self, date: str) -> List[Dict]:
        """Count messages per (group_id, sender) on a single day.

        Args:
            date: The day to count, compared against DATE(timestamp)
                (presumably "YYYY-MM-DD").

        Returns:
            Rows with ``group_id``, ``sender`` and ``count``; empty list if none.
        """
        sql = """
            SELECT group_id, sender, COUNT(*) AS count
            FROM messages
            WHERE DATE(timestamp) = %s
            GROUP BY group_id, sender
        """
        return self.execute_query(sql, (date,)) or []

    def get_speech_ranking(self, date: str, group_id: str, limit: int = 20) -> List[Dict]:
        """Return the speaker ranking of a group for one day.

        Args:
            date: Day to rank, matched against ``speech_counts.date``.
            group_id: Group to rank.
            limit: Maximum number of rows returned.

        Returns:
            Rows with ``wx_id`` and ``speech_count``, highest count first;
            an empty list when nothing matches.
        """
        # MAX() makes the GROUP BY valid under ONLY_FULL_GROUP_BY. When there
        # is a single row per (group_id, wx_id, date) — which the upsert in
        # insert_speech_count suggests, TODO confirm the unique key — it equals
        # the plain ``count`` column and keeps the column's type.
        sql = """
            SELECT wx_id, MAX(count) AS speech_count
            FROM speech_counts
            WHERE date = %s
              AND group_id = %s
            GROUP BY wx_id
            ORDER BY speech_count DESC
            LIMIT %s
        """
        params = (date, group_id, limit)
        return self.execute_query(sql, params) or []

    def insert_speech_count(self, group_id: str, wx_id: str, date: str, count: int) -> bool:
        """Insert or overwrite a per-day speech count.

        On a duplicate key the existing row's ``count`` is replaced, not added to.

        Args:
            group_id: Group ID.
            wx_id: WeChat user ID.
            date: Day, formatted as YYYY-MM-DD.
            count: Number of messages sent.

        Returns:
            The result of ``execute_update``.
        """
        sql = """
            INSERT INTO speech_counts (group_id, wx_id, date, count)
            VALUES (%s, %s, %s, %s)
            ON DUPLICATE KEY UPDATE count = VALUES(count)
        """
        params = (group_id, wx_id, date, count)
        return self.execute_update(sql, params)

    def get_message_trend(self, group_id: str, days: int = 7) -> List[Dict]:
        """Return daily message counts for a group.

        The window starts at midnight ``days`` days before today and includes
        today, so up to ``days + 1`` date buckets can be returned.

        Args:
            group_id: Group ID.
            days: How many days back to look (default 7).

        Returns:
            Rows with ``date`` and ``message_count``, ordered by date.
        """
        sql = """
            SELECT
                DATE(timestamp) AS date,
                COUNT(*) AS message_count
            FROM messages
            WHERE group_id = %s
              AND timestamp >= DATE_SUB(CURDATE(), INTERVAL %s DAY)
            GROUP BY DATE(timestamp)
            ORDER BY date
        """
        return self.execute_query(sql, (group_id, days)) or []

    def get_messages_by_filter(self, group_id=None, start_date=None, end_date=None,
                               search_text=None, page=1, page_size=20) -> Dict:
        """Filter messages with pagination and substring search.

        Args:
            group_id: Optional group ID.
            start_date: Optional inclusive start day (YYYY-MM-DD).
            end_date: Optional inclusive end day (YYYY-MM-DD).
            search_text: Optional text that must appear in ``content``. It is
                matched literally; ``%`` and ``_`` are not treated as wildcards.
            page: 1-based page number; values below 1 are treated as 1.
            page_size: Rows per page; values below 1 are treated as 1.

        Returns:
            Dict with ``total``, ``page``, ``page_size``, ``total_pages`` and
            ``messages`` (the rows of the requested page).
        """
        # Clamp paging inputs: page < 1 would give a negative OFFSET (SQL error)
        # and page_size < 1 would divide by zero when computing total_pages.
        page = max(page, 1)
        page_size = max(page_size, 1)

        # Build one WHERE clause shared by the count and data queries. Only the
        # fixed fragments below are interpolated; all values go through params.
        conditions = ["1=1"]
        params = []
        if group_id:
            conditions.append("group_id = %s")
            params.append(group_id)
        if start_date:
            conditions.append("DATE(timestamp) >= %s")
            params.append(start_date)
        if end_date:
            conditions.append("DATE(timestamp) <= %s")
            params.append(end_date)
        if search_text:
            conditions.append("content LIKE %s")
            params.append(f"%{self._escape_like(str(search_text))}%")
        where_sql = " AND ".join(conditions)

        sql_count = f"SELECT COUNT(*) AS total FROM messages WHERE {where_sql}"
        sql_data = f"""
            SELECT id, group_id, timestamp, sender, content, message_type,
                   attachment_url, message_id, message_xml, message_thumb, image_path
            FROM messages
            WHERE {where_sql}
            ORDER BY timestamp DESC
            LIMIT %s OFFSET %s
        """

        offset = (page - 1) * page_size
        data_params = [*params, page_size, offset]

        count_result = self.execute_query(sql_count, params)
        total = count_result[0]['total'] if count_result else 0
        messages = self.execute_query(sql_data, data_params) or []

        return {
            'total': total,
            'page': page,
            'page_size': page_size,
            'total_pages': (total + page_size - 1) // page_size,
            'messages': messages,
        }

    def update_message_image_path(self, message_id, image_base64str):
        """Store base64 image data in the ``message_thumb`` column of a message.

        Despite the method name, this writes image content (base64), not a
        file path; see update_message_image_file_path for ``image_path``.

        Args:
            message_id: The message's ``message_id``.
            image_base64str: Base64-encoded image content.

        Returns:
            The result of ``execute_update``, or False if it raised.
        """
        try:
            sql = """
                UPDATE messages
                SET message_thumb = %s
                WHERE message_id = %s
            """
            params = (image_base64str, message_id)
            return self.execute_update(sql, params)
        except Exception as e:
            # Best-effort: report and signal failure instead of propagating.
            print(f"更新消息图片路径出错: {e}")
            return False

    def update_message_image_file_path(self, message_id, image_path):
        """Set ``image_path`` for a message, only if it is still NULL.

        Existing paths are never overwritten (the UPDATE is guarded by
        ``image_path IS NULL``).

        Args:
            message_id: The message's ``message_id``.
            image_path: Path to store.

        Returns:
            The result of ``execute_update``, or False if it raised.
        """
        try:
            sql = """
                UPDATE messages
                SET image_path = %s
                WHERE message_id = %s
                  AND image_path IS NULL
            """
            params = (image_path, message_id)
            return self.execute_update(sql, params)
        except Exception as e:
            # Best-effort: report and signal failure instead of propagating.
            print(f"更新消息图片文件路径出错: {e}")
            return False

    def get_hourly_message_trend(self, group_id: Optional[str] = None, days: int = 1) -> List[Dict]:
        """Return message counts bucketed by hour.

        Args:
            group_id: Group ID; when falsy, all groups are counted together.
            days: Look-back window in days from now (default 1).

        Returns:
            Rows with ``hour_slot`` (e.g. "2024-01-31 13:00") and
            ``message_count``, ordered by hour.
        """
        # The DATE_FORMAT mask is a parameter so the SQL text holds no bare "%"
        # (inline '%Y-...' makes %-formatting drivers raise ValueError).
        sql = """
            SELECT
                DATE_FORMAT(timestamp, %s) AS hour_slot,
                COUNT(*) AS message_count
            FROM messages
            WHERE timestamp >= DATE_SUB(NOW(), INTERVAL %s DAY)
        """
        params = ["%Y-%m-%d %H:00", days]

        if group_id:
            sql += " AND group_id = %s "
            params.append(group_id)

        sql += """
            GROUP BY hour_slot
            ORDER BY hour_slot
        """
        return self.execute_query(sql, tuple(params)) or []

    def get_pending_image_messages(self, minutes_ago: int = 10, limit: int = 50) -> List[Dict]:
        """Fetch recent image messages whose image has not been processed yet.

        Selects messages with ``message_type = '3'``, ``image_path IS NULL`` and
        a non-empty ``message_xml`` from the last ``minutes_ago`` minutes,
        oldest first.

        Args:
            minutes_ago: Look-back window in minutes (default 10).
            limit: Maximum number of rows returned (default 50).

        Returns:
            Rows with ``message_id``, ``group_id``, ``message_xml``,
            ``timestamp`` and ``attachment_url``; empty list if none.
        """
        sql = """
            SELECT message_id, group_id, message_xml, timestamp, attachment_url
            FROM messages
            WHERE message_type = '3'
              AND image_path IS NULL
              AND timestamp >= DATE_SUB(NOW(), INTERVAL %s MINUTE)
              AND message_xml IS NOT NULL
              AND message_xml != ''
            ORDER BY timestamp ASC
            LIMIT %s
        """
        params = (minutes_ago, limit)
        return self.execute_query(sql, params) or []
|