diff --git a/db/base.py b/db/base.py index 66d08b2..5211455 100644 --- a/db/base.py +++ b/db/base.py @@ -1,4 +1,6 @@ # -*- coding: utf-8 -*- +import time + from loguru import logger from typing import List, Dict, Any, Optional, Tuple, Union @@ -12,19 +14,62 @@ class BaseDBOperator: self.db_manager = db_manager self.LOG = logger + @staticmethod + def _compact_sql(sql: str) -> str: + """把 SQL 压成单行,便于日志里快速定位问题。""" + return " ".join(str(sql or "").split()) + + @classmethod + def _truncate_text(cls, value, max_length: int = 240) -> str: + """截断长文本,避免日志被超长 SQL 或参数刷屏。""" + text = str(value or "") + if len(text) <= max_length: + return text + return f"{text[:max_length]}..." + + def _log_sql_timing(self, operation: str, sql: str, params, elapsed_ms: float, affected_rows: Optional[int] = None) -> None: + """记录慢 SQL 日志。 + + 设计说明: + 1. 只在超过阈值时输出 warning,避免日常日志噪声过大; + 2. 统一输出压缩后的 SQL 与截断参数,便于线上排查具体慢点; + 3. 查询/更新/批量/事务都走同一入口,后续如果要接后台审计也更容易扩展。 + """ + if not self.db_manager.is_slow_query_log_enabled(): + return + + threshold_ms = self.db_manager.get_slow_query_threshold_ms() + if elapsed_ms < threshold_ms: + return + + affected_text = "" + if affected_rows is not None: + affected_text = f" affected_rows={affected_rows}" + self.LOG.warning( + f"检测到慢SQL operation={operation} cost_ms={round(elapsed_ms, 2)} threshold_ms={threshold_ms}" + f"{affected_text} sql={self._truncate_text(self._compact_sql(sql), 400)} " + f"params={self._truncate_text(params, 240)}" + ) + def execute_query(self, sql: str, params: Optional[tuple] = None, fetch_one: bool = False) -> Union[ List[Dict], Dict, None]: """执行查询SQL""" conn = self.db_manager.get_mysql_connection() + started_at = time.perf_counter() try: with conn.cursor(dictionary=True) as cursor: cursor.execute(sql, params or ()) + elapsed_ms = (time.perf_counter() - started_at) * 1000 if fetch_one: - return cursor.fetchone() - return cursor.fetchall() + result = cursor.fetchone() + self._log_sql_timing("query_one", sql, params, elapsed_ms, 1 if result else 0) + return result + result = cursor.fetchall() + self._log_sql_timing("query", sql, params, elapsed_ms, len(result or [])) + return result except Exception as e: self.LOG.error( - f"执行更新SQL出错: {e}, SQL: {sql}, 参数: {str(params)[:200] + '...' if len(str(params)) > 200 else params}" + f"执行查询SQL出错: {e}, SQL: {sql}, 参数: {str(params)[:200] + '...' if len(str(params)) > 200 else params}" ) return None finally: @@ -33,10 +78,13 @@ class BaseDBOperator: def execute_update(self, sql: str, params: Optional[tuple] = None) -> bool: """执行更新SQL""" conn = self.db_manager.get_mysql_connection() + started_at = time.perf_counter() try: with conn.cursor() as cursor: cursor.execute(sql, params or ()) + affected_rows = cursor.rowcount conn.commit() + self._log_sql_timing("update", sql, params, (time.perf_counter() - started_at) * 1000, affected_rows) return True except Exception as e: self.LOG.error( @@ -53,10 +101,19 @@ class BaseDBOperator: return True conn = self.db_manager.get_mysql_connection() + started_at = time.perf_counter() try: with conn.cursor() as cursor: cursor.executemany(sql, params_list) + affected_rows = cursor.rowcount conn.commit() + self._log_sql_timing( + "batch_update", + sql, + f"params_count={len(params_list)}", + (time.perf_counter() - started_at) * 1000, + affected_rows, + ) return True except Exception as e: self.LOG.error(f"批量执行SQL出错: {e}, SQL: {sql}, 参数数量: {len(params_list)}") @@ -71,11 +128,18 @@ class BaseDBOperator: return True conn = self.db_manager.get_mysql_connection() + started_at = time.perf_counter() try: with conn.cursor() as cursor: for sql, params in operations: cursor.execute(sql, params) conn.commit() + self._log_sql_timing( + "transaction", + f"{len(operations)} statements", + f"operations={len(operations)}", + (time.perf_counter() - started_at) * 1000, + ) return True except Exception as e: self.LOG.error(f"执行事务出错: {e}, 操作数量: {len(operations)}") diff --git a/db/connection.py b/db/connection.py index 17964ff..10bc57f 100644 --- a/db/connection.py +++ b/db/connection.py @@ -39,7 +39,13 @@ class DBConnectionManager: self.LOG = logger self.mysql_pool = None self.redis_pool = None - + # 保存原始配置快照,供慢 SQL 阈值、库名探测等公共能力复用: + # 1. BaseDBOperator 需要读取数据库名,去 information_schema 中检查索引; + # 2. 慢 SQL 记录需要统一读取阈值配置,而不是每个 DB Operator 各自硬编码; + # 3. 这里做浅拷贝即可,避免后续外部修改传入 dict 时影响内部状态。 + self.mysql_config = dict(mysql_config or {}) + self.redis_config = dict(redis_config or {}) + # 初始化MySQL连接池 if mysql_config: self.init_mysql_pool(mysql_config) @@ -58,6 +64,8 @@ class DBConnectionManager: if not config: self.LOG.warning("MySQL配置为空,跳过初始化") return + + self.mysql_config = dict(config or {}) # 准备连接池配置 pool_config = { @@ -90,6 +98,8 @@ class DBConnectionManager: if not config: self.LOG.warning("Redis配置为空,跳过初始化") return + + self.redis_config = dict(config or {}) self.redis_pool = redis.ConnectionPool( host=config.get('host', 'localhost'), @@ -117,6 +127,26 @@ class DBConnectionManager: raise Exception("MySQL连接池未初始化") return self.mysql_pool.get_connection() + + def get_mysql_database_name(self) -> str: + """返回当前 MySQL 目标库名。""" + return str(self.mysql_config.get('database', '') or '').strip() + + def get_slow_query_threshold_ms(self) -> int: + """读取慢 SQL 阈值,默认 500ms。""" + try: + threshold = int(self.mysql_config.get('slow_query_threshold_ms', 500) or 500) + return threshold if threshold > 0 else 500 + except (TypeError, ValueError): + return 500 + + def is_slow_query_log_enabled(self) -> bool: + """是否启用慢 SQL 日志。""" + raw_value = self.mysql_config.get('enable_slow_query_log', True) + if isinstance(raw_value, str): + normalized = raw_value.strip().lower() + return normalized not in {'0', 'false', 'off', 'no'} + return bool(raw_value) def get_redis_connection(self): """获取Redis连接 @@ -140,4 +170,4 @@ class DBConnectionManager: # 关闭Redis连接池 if self.redis_pool: self.redis_pool.disconnect() - self.redis_pool = None \ No newline at end of file + self.redis_pool = None diff --git a/db/message_storage.py b/db/message_storage.py index 9bbf852..ee62ea9 100644 --- a/db/message_storage.py +++ b/db/message_storage.py @@ -1,7 +1,8 @@ # -*- coding: utf-8 -*- -from datetime import datetime +from datetime import datetime, timedelta import json +from threading import Lock from typing import Dict, List, Optional from db.base import BaseDBOperator @@ -12,8 +13,103 @@ from wechat_ipad.models.message import WxMessage class MessageStorageDB(BaseDBOperator): """消息存储相关数据库操作""" + _performance_ready = False + _performance_lock = Lock() + def __init__(self, db_manager: DBConnectionManager): super().__init__(db_manager) + self._ensure_performance_primitives() + + @staticmethod + def _normalize_datetime_text(value) -> str: + """把日期/时间对象统一转成数据库可比较的标准字符串。""" + if isinstance(value, datetime): + return value.strftime("%Y-%m-%d %H:%M:%S") + return str(value or "").strip() + + @classmethod + def _build_day_time_range(cls, target_date: str) -> tuple[str, str]: + """把 `YYYY-MM-DD` 日期转换成 `[00:00:00, 次日00:00:00)` 时间范围。""" + start_dt = datetime.strptime(str(target_date or "").strip(), "%Y-%m-%d") + end_dt = start_dt.replace(hour=0, minute=0, second=0, microsecond=0) + next_day_dt = end_dt + timedelta(days=1) + return ( + end_dt.strftime("%Y-%m-%d 00:00:00"), + next_day_dt.strftime("%Y-%m-%d 00:00:00"), + ) + + @classmethod + def _build_day_bounds(cls, start_date: str, end_date: str) -> tuple[str, str]: + """把日期区间转换成适合索引命中的时间范围。""" + start_dt = datetime.strptime(str(start_date or "").strip(), "%Y-%m-%d") + end_dt = datetime.strptime(str(end_date or "").strip(), "%Y-%m-%d") + if end_dt < start_dt: + start_dt, end_dt = end_dt, start_dt + next_day_dt = end_dt + timedelta(days=1) + return ( + start_dt.strftime("%Y-%m-%d 00:00:00"), + next_day_dt.strftime("%Y-%m-%d 00:00:00"), + ) + + def _ensure_performance_primitives(self) -> None: + """确保消息存储相关的关键索引存在。 + + 设计说明: + 1. 这一步只补“高频查询明确受益”的索引,不做激进表结构重写; + 2. 使用 information_schema 做存在性检查,保证重复启动时仍然幂等; + 3. 只在进程内执行一次,避免每次 new MessageStorageDB 都重复打元数据查询。 + """ + if self.__class__._performance_ready: + return + + with self.__class__._performance_lock: + if self.__class__._performance_ready: + return + + self._ensure_index_exists( + table_name="messages", + index_name="idx_group_sender_timestamp", + create_sql="CREATE INDEX idx_group_sender_timestamp ON messages (group_id, sender, timestamp)", + ) + self._ensure_index_exists( + table_name="messages", + index_name="idx_group_type_timestamp", + create_sql="CREATE INDEX idx_group_type_timestamp ON messages (group_id, message_type, timestamp)", + ) + self._ensure_index_exists( + table_name="messages", + index_name="idx_media_pending_lookup", + create_sql="CREATE INDEX idx_media_pending_lookup ON messages (message_type, image_path, timestamp, group_id)", + ) + self.__class__._performance_ready = True + + def _ensure_index_exists(self, table_name: str, index_name: str, create_sql: str) -> None: + """按需补建单个索引。""" + database_name = self.db_manager.get_mysql_database_name() + if not database_name: + return + + existing = self.execute_query( + """ + SELECT 1 + FROM information_schema.statistics + WHERE table_schema = %s + AND table_name = %s + AND index_name = %s + LIMIT 1 + """, + (database_name, table_name, index_name), + fetch_one=True, + ) + if existing: + return + + # 索引补建属于“性能自愈”动作: + # 1. 不要求用户手工跑 migration,服务启动时可自动补齐; + # 2. 若线上库字段类型和预期不一致,失败后只记日志,不阻断主流程; + # 3. 这样先拿到可观测收益,再决定后续是否做更完整的 schema migration。 + if not self.execute_update(create_sql): + self.LOG.warning(f"消息表索引补建失败,请人工检查: table={table_name}, index={index_name}") def archive_message(self, msg: WxMessage) -> bool: """存档消息 @@ -252,10 +348,12 @@ class MessageStorageDB(BaseDBOperator): def get_member_messages_on_date(self, group_id: str, wxid: str, target_date: str, limit: int = 120) -> List[Dict]: """获取成员在某一天的消息""" + start_time, end_time = self._build_day_time_range(target_date) sql = """ SELECT timestamp, sender, content, message_type FROM messages - WHERE DATE(timestamp) = %s + WHERE timestamp >= %s + AND timestamp < %s AND group_id = %s AND sender = %s AND message_type IN (1, 49) @@ -264,14 +362,16 @@ class MessageStorageDB(BaseDBOperator): ORDER BY timestamp ASC LIMIT %s """ - return self.execute_query(sql, (target_date, group_id, wxid, limit)) or [] + return self.execute_query(sql, (start_time, end_time, group_id, wxid, limit)) or [] def get_member_messages_for_group_date(self, group_id: str, target_date: str, limit: int = 5000) -> List[Dict]: """获取群在某一天的全部文本消息""" + start_time, end_time = self._build_day_time_range(target_date) sql = """ SELECT timestamp, sender, content, message_type FROM messages - WHERE DATE(timestamp) = %s + WHERE timestamp >= %s + AND timestamp < %s AND group_id = %s AND sender IS NOT NULL AND sender <> '' @@ -281,7 +381,7 @@ class MessageStorageDB(BaseDBOperator): ORDER BY timestamp ASC LIMIT %s """ - return self.execute_query(sql, (target_date, group_id, limit)) or [] + return self.execute_query(sql, (start_time, end_time, group_id, limit)) or [] def get_recent_group_chat_messages(self, group_id: str, limit: int = 20) -> List[Dict]: """获取群聊最近消息""" @@ -315,13 +415,15 @@ class MessageStorageDB(BaseDBOperator): def get_message_count_by_date(self, date: str) -> List[Dict]: """获取指定日期的消息统计""" + start_time, end_time = self._build_day_time_range(date) sql = """ SELECT group_id, sender, COUNT(*) as count FROM messages - WHERE DATE(timestamp) = %s + WHERE timestamp >= %s + AND timestamp < %s GROUP BY group_id, sender """ - return self.execute_query(sql, (date,)) or [] + return self.execute_query(sql, (start_time, end_time)) or [] def get_speech_ranking(self, date: str, group_id: str, limit: int = 20) -> List[Dict]: """获取指定日期和群组的发言排名""" @@ -480,14 +582,19 @@ class MessageStorageDB(BaseDBOperator): params.append(group_id) if start_date: - sql_count += " AND DATE(timestamp) >= %s " - sql_data += " AND DATE(timestamp) >= %s " - params.append(start_date) + start_bound = f"{str(start_date).strip()} 00:00:00" + sql_count += " AND timestamp >= %s " + sql_data += " AND timestamp >= %s " + params.append(start_bound) if end_date: - sql_count += " AND DATE(timestamp) <= %s " - sql_data += " AND DATE(timestamp) <= %s " - params.append(end_date) + _, end_bound = self._build_day_bounds( + start_date or str(end_date).strip(), + str(end_date).strip(), + ) + sql_count += " AND timestamp < %s " + sql_data += " AND timestamp < %s " + params.append(end_bound) if search_text: sql_count += " AND content LIKE %s " @@ -665,8 +772,8 @@ class MessageStorageDB(BaseDBOperator): """ return self.execute_query(sql, (f'%md5="{md5}"%',), fetch_one=True) - def get_messages_by_date_range(self, group_id: str, start_date: str, end_date: str = None, - min_content_length: int = 6, max_results: int = 5000) -> List[Dict]: + def get_messages_by_calendar_range(self, group_id: str, start_date: str, end_date: str = None, + min_content_length: int = 6, max_results: int = 5000) -> List[Dict]: """按日期范围获取消息(支持按天总结) Args: @@ -682,11 +789,13 @@ class MessageStorageDB(BaseDBOperator): if end_date is None: end_date = start_date + start_time, end_time = self._build_day_bounds(start_date, end_date) + sql = """ SELECT timestamp, sender, content, message_type FROM messages - WHERE DATE(timestamp) >= %s - AND DATE(timestamp) <= %s + WHERE timestamp >= %s + AND timestamp < %s AND group_id = %s AND message_type IN (1, 49) AND LENGTH(content) > %s @@ -695,7 +804,7 @@ class MessageStorageDB(BaseDBOperator): ORDER BY timestamp ASC LIMIT %s """ - params = (start_date, end_date, group_id, min_content_length, max_results) + params = (start_time, end_time, group_id, min_content_length, max_results) return self.execute_query(sql, params) or [] def get_messages_for_summary(self, group_id: str, hours_ago: int = 8, @@ -749,8 +858,8 @@ class MessageStorageDB(BaseDBOperator): AND content NOT LIKE '/%' ORDER BY timestamp ASC """ - params = (start_time.strftime('%Y-%m-%d %H:%M:%S'), - end_time.strftime('%Y-%m-%d %H:%M:%S'), + params = (self._normalize_datetime_text(start_time), + self._normalize_datetime_text(end_time), group_id) return self.execute_query(sql, params) or [] @@ -776,8 +885,8 @@ class MessageStorageDB(BaseDBOperator): AND CHAR_LENGTH(content) < 300 AND content NOT LIKE '/%' """ - params = (start_time.strftime('%Y-%m-%d %H:%M:%S'), - end_time.strftime('%Y-%m-%d %H:%M:%S'), + params = (self._normalize_datetime_text(start_time), + self._normalize_datetime_text(end_time), group_id) result = self.execute_query(sql, params) return result[0]['count'] if result else 0 @@ -801,8 +910,8 @@ class MessageStorageDB(BaseDBOperator): AND sender <> '' """ params = ( - start_time.strftime('%Y-%m-%d %H:%M:%S'), - end_time.strftime('%Y-%m-%d %H:%M:%S'), + self._normalize_datetime_text(start_time), + self._normalize_datetime_text(end_time), group_id, ) result = self.execute_query(sql, params, fetch_one=True) or {} diff --git a/db/scripts/init.sql b/db/scripts/init.sql index 5db917a..909bb06 100644 --- a/db/scripts/init.sql +++ b/db/scripts/init.sql @@ -52,6 +52,12 @@ create or replace index idx_date_timestamp create or replace index idx_group_timestamp on message_archive.messages (group_id, timestamp); +create or replace index idx_group_sender_timestamp + on message_archive.messages (group_id, sender, timestamp); + +create or replace index idx_group_type_timestamp + on message_archive.messages (group_id, message_type, timestamp); + create or replace index idx_message_sender on message_archive.messages (sender); @@ -61,6 +67,9 @@ create or replace index idx_message_type create or replace index messages_message_id_index on message_archive.messages (message_id); +create or replace index idx_media_pending_lookup + on message_archive.messages (message_type, image_path, timestamp, group_id); + create or replace table message_archive.t_emoji_assets ( md5 varchar(64) not null comment '表情MD5' diff --git a/docs/工程优化与Feature清单.md b/docs/工程优化与Feature清单.md index e22e72d..835d5b4 100644 --- a/docs/工程优化与Feature清单.md +++ b/docs/工程优化与Feature清单.md @@ -497,6 +497,14 @@ - 提高高消息量场景下的吞吐与查询效率 +当前进展: + +- 第一阶段已完成:数据库公共层已增加慢 SQL 记录能力,可按 `db_config.slow_query_threshold_ms` 阈值输出慢查询日志 +- 第一阶段已完成:消息存储层启动时会自动补齐关键查询索引,优先覆盖群消息范围查询、成员消息回溯与待处理媒体扫描场景 +- 第一阶段已完成:多处按日期查询已改为时间范围查询,避免 `DATE(timestamp)` 直接作用在索引列上导致索引失效 +- 第一阶段已完成:已修正消息存储层重复定义的日期范围方法,避免按天汇总查询误走错误实现 +- 后续可继续补充统计报表快照表、Redis key 扫描替换方案、后台慢 SQL 看板与更多统计表索引治理 + 建议内容: - 梳理消息表与统计表索引 diff --git a/utils/wechat/message_to_db.py b/utils/wechat/message_to_db.py index 2ab01d6..a28160d 100644 --- a/utils/wechat/message_to_db.py +++ b/utils/wechat/message_to_db.py @@ -883,7 +883,7 @@ class MessageStorage: end_date = current_time.strftime('%Y-%m-%d') # 使用新的按日期查询方法 - messages = self.message_db.get_messages_by_date_range( + messages = self.message_db.get_messages_by_calendar_range( group_id, start_date=start_date, end_date=end_date,