Files
abot/utils/wechat/message_to_db.py
2026-01-06 16:11:10 +08:00

631 lines
26 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import asyncio
import time
from datetime import datetime, timedelta
import xml.etree.ElementTree as ET
import concurrent.futures # 添加线程池支持
import os
import base64
import imghdr
import re
from threading import Lock
from typing import Dict
from db.connection import DBConnectionManager
from db.levels_db import LevelsDBOperator
from db.message_storage import MessageStorageDB
# 导入积分系统
from db.points_db import PointsDBOperator, PointSource
from utils.wechat.contact_manager import ContactManager
from wechat_ipad import WechatAPIClient
from wechat_ipad.models.message import WxMessage, MessageType
from loguru import logger
logging = logger
class MessageStorage:
def __init__(self, client: WechatAPIClient = None):
# 获取数据库连接管理器的单例
self.db_manager = DBConnectionManager.get_instance()
self.message_db = MessageStorageDB(self.db_manager)
self.points_db = PointsDBOperator(self.db_manager)
# 初始化本地缓存字典,使用 group_id 作为键
self.local_membercounts = {}
self.local_members = {}
# 创建线程池,用于异步存储消息
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=5)
# 用于跟踪异步任务的列表
self.pending_tasks = []
# 保存WCF实例用于图片处理
self.client = client
# 图片处理相关初始化
self.image_executor = concurrent.futures.ThreadPoolExecutor(max_workers=8) # 专用于图片处理的线程池
self.image_tasks = []
# 图片任务在途控制
self._image_task_inflight = 0
self._image_task_lock = Lock()
self.MAX_IMAGE_TASKS = 50 # 可调20~100 之间
# 事件循环(只创建一次,替代 asyncio.run
self._image_loop = asyncio.new_event_loop()
# 正则(替代 XML 解析)
self._aeskey_re = re.compile(r'aeskey="(.*?)"')
self._cdn_re = re.compile(r'cdnthumburl="(.*?)"')
# 修改为项目根目录下的 static/images
self.image_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "static", "images")
# 确保图片存储目录存在
if not os.path.exists(self.image_dir):
os.makedirs(self.image_dir, exist_ok=True)
logger.info(f"图片存储目录: {self.image_dir}")
def process_message(self, message: WxMessage):
# 示例message字符串
current_date = datetime.now().strftime('%Y-%m-%d')
# 生成Redis key
key = f"{message.roomid}:{message.sender}:{current_date}:count"
# 获取 Redis 连接
redis_conn = self.db_manager.get_redis_connection()
# 使用Redis哈希或字符串增加发言次数
redis_conn.hincrby(key, 'count', 1) # 这里使用哈希但也可以考虑用字符串的INCR操作
# 设置时效为48小时
redis_conn.expire(key, 86400 * 2)
# 或者使用字符串r.incr(key) # 如果只存储一个整数值,字符串类型可能更简单
def archive_message(self, msg: WxMessage):
"""异步存档消息,防止堵塞主线程"""
# 提交任务到线程池
future = self.executor.submit(self._archive_message_task, msg)
# 可选:添加回调函数处理完成后的操作
future.add_done_callback(self._archive_callback)
# 将任务添加到待处理列表
self.pending_tasks.append(future)
# 清理已完成的任务
self._cleanup_completed_tasks()
def _archive_message_task(self, msg: WxMessage):
"""实际执行消息存档的任务函数"""
try:
# 使用 MessageStorageDB 类存档消息
result = self.message_db.archive_message(msg)
return {
'success': result,
'roomid': msg.roomid,
'sender': msg.sender,
'content': str(msg.content.clean_content), # 添加消息内容
'message_id': msg.msg_id # 添加消息ID
}
except Exception as e:
logger.error(f"存档消息出错: {e}")
return {
'success': False,
'roomid': msg.roomid,
'sender': msg.sender,
'content': str(msg.content.clean_content), # 添加消息内容
'message_id': msg.msg_id, # 添加消息ID
'error': str(e)
}
def process_image(self, msg: WxMessage):
"""图片消息已通过 archive_message 存入数据库,不再实时处理
改为定时任务批量处理,减少对主流程的影响和数据库锁竞争
"""
# 图片消息已经通过 archive_message 存入数据库
# 定时任务会定期扫描并处理未下载的图片
logger.debug(f"图片消息已记录,等待定时任务处理: msg_id={msg.msg_id}, roomid={msg.roomid}")
return True
async def _process_image_from_db(self, db_record: Dict) -> Dict:
"""从数据库记录处理图片(用于定时任务,异步版本)
Args:
db_record: 数据库记录,包含 message_id, group_id, attachment_url 等
Returns:
处理结果字典
"""
message_id = db_record.get('message_id')
group_id = db_record.get('group_id', 'unknown')
xml_content = db_record.get('attachment_url', '')
if not self.client or not message_id or not xml_content:
return {
'success': False,
'message_id': message_id,
'error': "缺少必要参数"
}
try:
# ===== 1. 正则提取参数(替代 XML=====
aeskey_match = self._aeskey_re.search(xml_content)
cdn_match = self._cdn_re.search(xml_content)
if not aeskey_match or not cdn_match:
return {
'success': False,
'message_id': message_id,
'error': "XML 中未找到图片参数"
}
aeskey = aeskey_match.group(1)
cdnthumburl = cdn_match.group(1)
# ===== 2. 下载图片(异步方式,直接 await=====
try:
base64_str = await self.client.download_image(
aeskey=aeskey,
cdnmidimgurl=cdnthumburl
)
except Exception as e:
logger.error(f"图片下载失败 message_id={message_id}: {e}")
return {
'success': False,
'message_id': message_id,
'error': f"图片下载失败: {str(e)}"
}
if not base64_str:
return {
'success': False,
'message_id': message_id,
'error': "图片下载失败:返回为空"
}
# ===== 3. base64 解码 =====
try:
data = base64.b64decode(base64_str)
except Exception as e:
logger.error(f"图片解码失败 message_id={message_id}: {e}")
return {
'success': False,
'message_id': message_id,
'error': f"图片解码失败: {str(e)}"
}
# ===== 4. 构建路径 =====
room_id = group_id or "unknown"
group_dir = os.path.join(self.image_dir, room_id)
os.makedirs(group_dir, exist_ok=True)
# 微信图片默认 jpg
file_name = f"{message_id}.jpg"
file_path = os.path.join(group_dir, file_name)
# ===== 5. 写文件 =====
skipped = False
if os.path.isfile(file_path):
skipped = True
logger.debug(f"图片文件已存在,跳过保存: {room_id}-{file_name}")
else:
with open(file_path, "wb") as f:
f.write(data)
# ===== 6. 更新数据库(串行更新,避免锁竞争)=====
if not skipped:
web_path = f"/static/images/{room_id}/{file_name}"
success = self.message_db.update_message_image_file_path(message_id, web_path)
if success:
logger.debug(f"图片处理成功: message_id={message_id}, path={web_path}")
else:
logger.warning(f"图片路径更新失败: message_id={message_id}")
return {
'success': False,
'message_id': message_id,
'error': "数据库更新失败"
}
return {
"success": True,
"message_id": message_id,
"roomid": room_id,
"file_path": f"/static/images/{room_id}/{file_name}" if not skipped else None,
"skipped": skipped
}
except Exception as e:
logger.exception(f"处理图片出错 message_id={message_id}")
return {
'success': False,
'message_id': message_id,
'error': f"处理出错: {str(e)}"
}
async def process_pending_images(self, minutes_ago: int = 10, batch_size: int = 20):
"""定时任务:批量处理未下载的图片消息(串行处理,避免锁竞争)
Args:
minutes_ago: 处理最近多少分钟的消息默认10分钟
batch_size: 每次处理多少条默认20条
"""
if not self.client:
logger.warning("微信客户端未初始化,跳过图片处理")
return
try:
# 查询未处理的图片消息
pending_messages = self.message_db.get_pending_image_messages(minutes_ago, batch_size)
if not pending_messages:
logger.debug(f"未发现待处理的图片消息(最近{minutes_ago}分钟)")
return
logger.info(f"开始处理 {len(pending_messages)} 条待处理图片消息")
success_count = 0
fail_count = 0
# 串行处理,避免并发更新数据库导致锁竞争
for msg_record in pending_messages:
result = await self._process_image_from_db(msg_record)
if result.get('success'):
success_count += 1
else:
fail_count += 1
error = result.get('error', '未知错误')
logger.warning(f"图片处理失败 message_id={result.get('message_id')}: {error}")
logger.info(f"图片处理完成: 成功={success_count}, 失败={fail_count}, 总计={len(pending_messages)}")
except Exception as e:
logger.exception(f"定时处理图片任务出错: {e}")
def _process_image_done(self, future):
"""任务完成统一回调(极轻量)"""
try:
result = future.result()
self._process_image_callback(result)
except Exception as e:
logger.error(f"处理图片回调时出错: {e}")
finally:
# ⚠️ 无论成功失败,都必须释放在途计数
with self._image_task_lock:
self._image_task_inflight -= 1
def _process_image_callback(self, result):
if result['success']:
skipped_info = " (已存在)" if result.get('skipped') else ""
logger.info(
f"图片处理成功{skipped_info}: "
f"{result['roomid']}:{result['sender']}:{result['message_id']}"
)
else:
logger.error(
f"图片处理失败: "
f"{result.get('roomid', '')}:"
f"{result.get('sender', '')}:"
f"{result.get('message_id', '')} - "
f"{result.get('error', '未知错误')}"
)
def _archive_callback(self, future):
"""处理异步存档任务完成后的回调"""
try:
result = future.result()
if result['success']:
# 修改日志输出,包含消息内容
compressed = result['content'].replace('\n', '').replace('\r', '')
logger.info(f"archive_success: {result['roomid']}:{result['sender']}: {compressed}")
else:
error_msg = result.get('error', '未知错误')
logger.error(f"archive_fail: {result['roomid']}:{result['sender']} - {error_msg}")
except Exception as e:
logger.error(f"处理存档回调时出错: {e}")
def _cleanup_completed_tasks(self):
"""清理已完成的任务,防止内存泄漏"""
# 只有当任务数量超过阈值时才进行清理,减少频繁操作
if len(self.pending_tasks) > 20:
# 过滤出已完成的任务
completed_tasks = [task for task in self.pending_tasks if task.done()]
# 从待处理列表中移除已完成的任务
for task in completed_tasks:
self.pending_tasks.remove(task)
# 如果待处理任务过多,记录警告日志
if len(self.pending_tasks) > 100:
logger.warning(f"待处理的存档任务数量过多: {len(self.pending_tasks)}")
# 只有当图片任务数量超过阈值时才进行清理
if len(self.image_tasks) > 10:
# 清理已完成的图片处理任务
completed_image_tasks = [task for task in self.image_tasks if task.done()]
for task in completed_image_tasks:
self.image_tasks.remove(task)
# 如果待处理任务过多,记录警告日志
if len(self.image_tasks) > 50:
logger.warning(f"待处理的图片处理任务数量过多: {len(self.image_tasks)}")
def write_to_db(self):
"""从Redis读取发言统计数据并写入数据库"""
# 获取Redis连接
redis_conn = self.db_manager.get_redis_connection()
# 获取当前日期的前一天
yesterday = (datetime.now() - timedelta(days=1)).strftime('%Y-%m-%d')
# 遍历Redis中所有与昨天日期相关的key并写入数据库
for key_item in redis_conn.keys(f"*:*:{yesterday}:count"):
# 检查key是否为字节类型如果是则解码
key = key_item.decode('utf-8') if isinstance(key_item, bytes) else key_item
parts = key.split(':')
group_id, wx_id, _date = parts[0], parts[1], parts[2] # _date应该是yesterday
# 获取计数值
count_bytes = redis_conn.hget(key, 'count')
count = int(count_bytes) if isinstance(count_bytes, bytes) else int(count_bytes) if count_bytes else 0
# 使用MessageStorageDB插入数据
try:
result = self.message_db.insert_speech_count(group_id, wx_id, yesterday, count)
if result:
logging.info(f"成功写入发言统计: {group_id}, {wx_id}, {yesterday}, {count}")
else:
logging.error(f"写入发言统计失败: {group_id}, {wx_id}, {yesterday}, {count}")
try:
levels_db = LevelsDBOperator(self.db_manager)
delta = int(0.5 * min(count, 10))
if delta > 0:
levels_db.add_exp(wx_id, group_id, delta, "speech_count")
except Exception as e2:
logging.error(f"写入等级经验失败: {group_id}, {wx_id}, {yesterday}, {count} - {e2}")
except Exception as e:
logging.error(f"写入发言统计出错: {e}")
async def generate_and_send_ranking(self, groupId, allContacts: dict):
"""生成并发送群聊发言排名,并根据排名发放积分奖励"""
try:
yesterday = (datetime.now() - timedelta(days=1)).strftime('%Y-%m-%d')
# 使用数据库操作类获取排名数据
results = self.message_db.get_speech_ranking(yesterday, groupId, limit=20)
if not results:
logging.info(f"没有找到 {yesterday} 的群聊 {groupId} 发言记录")
return False, f"📊 {yesterday} 没有发言记录"
# 格式化输出字符串添加emoji和美化格式
ranking_str = f"🏆 {yesterday} 发言排行榜 🏆\n"
con = ContactManager.get_instance()
# 为不同名次添加不同的奖杯和样式,并发放积分
for rank, result in enumerate(results, start=1):
username = result['wx_id']
speech_count = result['speech_count']
display_name = con.get_group_name(groupId, username) or username
# display_name = await self.client.get_chatroom_nickname(username, groupId)
if isinstance(display_name, str):
display_name = display_name
else:
display_name = ','.join(display_name)
# 根据排名发放不同数量的积分
reward_points = 0
if rank == 1:
reward_points = 30
ranking_str += f"🥇🐲 {rank}.{display_name}: {speech_count}次 🔥 +{reward_points}积分\n"
elif rank == 2:
reward_points = 20
ranking_str += f"🥈 {rank}.{display_name}: {speech_count}次 ✨ +{reward_points}积分\n"
elif rank == 3:
reward_points = 10
ranking_str += f"🥉 {rank}.{display_name}: {speech_count}次 👏 +{reward_points}积分\n"
elif rank <= 10:
reward_points = 5
ranking_str += f"🌟 {rank}.{display_name}: {speech_count}次 +{reward_points}积分\n"
else:
reward_points = 3
ranking_str += f"👍 {rank}.{display_name}: {speech_count}次 +{reward_points}积分\n"
# 发放积分奖励
if reward_points > 0:
success, _ = self.points_db.add_points(
username,
groupId,
reward_points,
PointSource.OTHER,
f"{yesterday}发言排行第{rank}名奖励"
)
if not success:
logging.error(f"发放积分失败: {username}, {groupId}, {reward_points}")
logging.info(f"成功生成 {yesterday} 的群聊 {groupId} 发言排名并发放积分")
return True, ranking_str
except Exception as e:
logging.error(f"生成发言排名出错: {e}")
return False, f"❌ 生成发言排名出错: {e}"
def get_messages(self, group_id, all_contacts: dict):
try:
# 获取 Redis 连接
redis_conn = self.db_manager.get_redis_connection()
# 获取 redis 中的上次总结时间,本次从上次开始算,若没有,则从 8 小时之前开始计算
key = f"{group_id}:summary_time"
last_summary_time = redis_conn.get(key)
logger.info(f"上次总结时间: {last_summary_time}")
current_time = datetime.now()
current_date = current_time.strftime('%Y-%m-%d %H:%M:%S')
if not last_summary_time:
# 获取当前时间并计算 8 小时前的时间
eight_hours_ago = current_time - timedelta(hours=8)
last_summary_time = eight_hours_ago.strftime('%Y-%m-%d %H:%M:%S')
else:
# 如果 Redis 返回值为字节类型,转换为字符串
if isinstance(last_summary_time, bytes):
last_summary_time = last_summary_time.decode('utf-8')
# 检查 redis 中的时间与当前时间差是否小于 3 小时
last_summary_time_obj = datetime.strptime(last_summary_time, '%Y-%m-%d %H:%M:%S')
time_diff = current_time - last_summary_time_obj
if time_diff < timedelta(hours=3):
# 小于 3 小时,取 8 小时前
last_summary_time = (current_time - timedelta(hours=8)).strftime('%Y-%m-%d %H:%M:%S')
elif time_diff > timedelta(days=1):
# 大于 24 小时,取 10 小时前
last_summary_time = (current_time - timedelta(hours=24)).strftime('%Y-%m-%d %H:%M:%S')
# 更新 Redis 存储的当前时间
redis_conn.set(key, current_date)
# 使用智能查询方法(自动调整时间范围,确保有足够的消息)
messages = self.message_db.get_messages_for_summary(
group_id,
hours_ago=8, # 默认8小时
min_messages=50, # 最少需要50条消息
max_hours=48, # 最多查询48小时
max_results=5000 # 最多返回5000条之前是500
)
# 使用优化后的格式化方法
result_str = self._format_messages_optimized(messages, all_contacts)
logger.info(f"获取到 {len(messages)} 条消息,格式化后长度: {len(result_str)}")
return result_str
except Exception as e:
logger.error(f"获取消息出错: {e}")
return ""
def _format_messages_optimized(self, messages: list, all_contacts: dict) -> str:
"""优化的消息格式化方法,减少冗余
格式示例:
【01-05】
【08:30】
张三消息1
消息2
李四消息3
【01-06】
【09:00】
王五消息4
"""
if not messages:
return ""
from collections import defaultdict
import xml.etree.ElementTree as ET
# 按日期和时间分组
# 结构: {date_key: {time_key: {sender_name: [contents]}}}
time_groups = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
dates_included = set() # 记录出现的日期
for msg in messages:
timestamp, sender, content, message_type = msg['timestamp'], msg['sender'], msg['content'], msg['message_type']
# 处理应用消息
try:
if message_type == 49:
root = ET.fromstring(content)
title_elem = root.find('.//title')
if title_elem is not None:
content = title_elem.text
except Exception as e:
logger.error(f"解析消息类型49出错: {e}")
# 解析时间并按小时分组
try:
dt = datetime.strptime(str(timestamp), '%Y-%m-%d %H:%M:%S')
date_key = dt.strftime('%Y-%m-%d') # 完整日期
time_key = dt.strftime('%H:%M') # 只有时分
display_date = dt.strftime('%m-%d') # 显示的日期(月-日)
# 获取发送者名称
sender_name = all_contacts.get(sender, sender)
# 添加到分组date_key -> time_key -> sender_name -> [contents]
time_groups[date_key][time_key][sender_name].append(content)
dates_included.add(display_date)
except Exception as e:
logger.warning(f"解析时间戳失败: {timestamp}, 错误: {e}")
continue
# 构建结果字符串
result_lines = []
# 按日期排序
for date_key in sorted(time_groups.keys()):
# 添加日期标题(月-日格式)
display_date = datetime.strptime(date_key, '%Y-%m-%d').strftime('%m-%d')
result_lines.append(f"\n{display_date}")
# 获取该日期的所有时间段
time_slots = time_groups[date_key]
# 按时间排序
for time_key in sorted(time_slots.keys()):
# 添加时间标题
result_lines.append(f"{time_key}")
# 获取该时间段的所有发言者
senders = time_slots[time_key]
# 按发送者组织消息
for sender_name, contents in senders.items():
# 如果一个人有多条消息,缩进显示
for idx, content in enumerate(contents):
if idx == 0:
# 第一条消息显示发送者名
result_lines.append(f"{sender_name}{content}")
else:
# 后续消息缩进
result_lines.append(f" {content}")
return "\n".join(result_lines)
def get_messages_by_date(self, group_id: str, all_contacts: dict, days: int = 1) -> str:
"""按天获取消息(用于按天总结)
Args:
group_id: 群组ID
all_contacts: 联系人字典
days: 获取最近几天的消息默认1天昨天+今天)
Returns:
格式化后的消息字符串
"""
try:
current_time = datetime.now()
# 计算日期范围
if days == 1:
# 昨天全天 + 今天到目前为止
yesterday = (current_time - timedelta(days=1)).strftime('%Y-%m-%d')
today = current_time.strftime('%Y-%m-%d')
start_date = yesterday
end_date = today
else:
# 获取最近N天
start_date = (current_time - timedelta(days=days)).strftime('%Y-%m-%d')
end_date = current_time.strftime('%Y-%m-%d')
# 使用新的按日期查询方法
messages = self.message_db.get_messages_by_date_range(
group_id,
start_date=start_date,
end_date=end_date,
max_results=5000 # 增加到5000条
)
# 使用优化后的格式化方法
result_str = self._format_messages_optimized(messages, all_contacts)
logger.info(f"按天查询获取到 {len(messages)} 条消息({start_date}{end_date}")
return result_str
except Exception as e:
logger.error(f"按天获取消息出错: {e}")
return ""