""" 消息队列模块 提供高性能的优先级消息队列,支持多种溢出策略: - drop_oldest: 丢弃最旧的消息 - drop_lowest: 丢弃优先级最低的消息 - sampling: 按采样率丢弃消息 - reject: 拒绝新消息 """ import asyncio import heapq import random import time from dataclasses import dataclass, field from enum import Enum from typing import Any, Dict, List, Optional, Tuple from loguru import logger # ==================== 消息优先级常量 ==================== class MessagePriority: """消息优先级常量""" CRITICAL = 100 # 系统消息、登录信息 HIGH = 80 # 管理员命令、群成员变动 NORMAL = 50 # @bot 消息(默认) LOW = 20 # 普通群消息 # ==================== 溢出策略 ==================== class OverflowStrategy(Enum): """队列溢出策略""" DROP_OLDEST = "drop_oldest" # 丢弃最旧的消息 DROP_LOWEST = "drop_lowest" # 丢弃优先级最低的消息 SAMPLING = "sampling" # 按采样率丢弃 REJECT = "reject" # 拒绝新消息 # ==================== 优先级消息 ==================== @dataclass(order=True) class PriorityMessage: """优先级消息""" priority: int = field(compare=True) timestamp: float = field(compare=True) msg_type: int = field(compare=False) data: Dict[str, Any] = field(compare=False) def __init__(self, msg_type: int, data: Dict[str, Any], priority: int = None): # 优先级越高,数值越大,但 heapq 是最小堆,所以取负数 self.priority = -(priority if priority is not None else MessagePriority.NORMAL) self.timestamp = time.time() self.msg_type = msg_type self.data = data # ==================== 优先级消息队列 ==================== class PriorityMessageQueue: """ 优先级消息队列 特性: - 基于堆的优先级队列 - 支持多种溢出策略 - 线程安全(使用 asyncio.Lock) - 支持任务计数和 join """ def __init__( self, maxsize: int = 1000, overflow_strategy: str = "drop_oldest", sampling_rate: float = 0.5, ): """ 初始化队列 Args: maxsize: 最大队列大小 overflow_strategy: 溢出策略 (drop_oldest, drop_lowest, sampling, reject) sampling_rate: 采样策略的保留率 (0.0-1.0) """ self.maxsize = maxsize self.overflow_strategy = OverflowStrategy(overflow_strategy) self.sampling_rate = max(0.0, min(1.0, sampling_rate)) self._heap: List[PriorityMessage] = [] self._lock = asyncio.Lock() self._not_empty = asyncio.Event() self._unfinished_tasks = 0 self._finished = asyncio.Event() self._finished.set() # 统计 self._total_put = 0 self._total_dropped = 0 self._total_rejected = 0 def qsize(self) -> int: """返回队列大小""" return len(self._heap) def empty(self) -> bool: """队列是否为空""" return len(self._heap) == 0 def full(self) -> bool: """队列是否已满""" return len(self._heap) >= self.maxsize async def put( self, msg_type: int, data: Dict[str, Any], priority: int = None, ) -> bool: """ 添加消息到队列 Args: msg_type: 消息类型 data: 消息数据 priority: 优先级(可选) Returns: 是否成功添加 """ async with self._lock: self._total_put += 1 # 处理队列满的情况 if self.full(): if not self._handle_overflow(): self._total_rejected += 1 return False msg = PriorityMessage(msg_type, data, priority) heapq.heappush(self._heap, msg) self._unfinished_tasks += 1 self._finished.clear() self._not_empty.set() return True def _handle_overflow(self) -> bool: """ 处理队列溢出 Returns: True 表示成功腾出空间,False 表示拒绝 """ if self.overflow_strategy == OverflowStrategy.REJECT: logger.warning("队列已满,拒绝新消息") return False if self.overflow_strategy == OverflowStrategy.DROP_OLDEST: # 找到最旧的消息(timestamp 最小) if self._heap: oldest_idx = 0 for i, msg in enumerate(self._heap): if msg.timestamp < self._heap[oldest_idx].timestamp: oldest_idx = i self._heap.pop(oldest_idx) heapq.heapify(self._heap) self._total_dropped += 1 self._unfinished_tasks = max(0, self._unfinished_tasks - 1) return True elif self.overflow_strategy == OverflowStrategy.DROP_LOWEST: # 找到优先级最低的消息(priority 值最大,因为是负数) if self._heap: lowest_idx = 0 for i, msg in enumerate(self._heap): if msg.priority > self._heap[lowest_idx].priority: lowest_idx = i self._heap.pop(lowest_idx) heapq.heapify(self._heap) self._total_dropped += 1 self._unfinished_tasks = max(0, self._unfinished_tasks - 1) return True elif self.overflow_strategy == OverflowStrategy.SAMPLING: # 按采样率决定是否接受 if random.random() < self.sampling_rate: # 接受新消息,丢弃最旧的 if self._heap: oldest_idx = 0 for i, msg in enumerate(self._heap): if msg.timestamp < self._heap[oldest_idx].timestamp: oldest_idx = i self._heap.pop(oldest_idx) heapq.heapify(self._heap) self._total_dropped += 1 self._unfinished_tasks = max(0, self._unfinished_tasks - 1) return True else: self._total_dropped += 1 return False return False async def get(self, timeout: float = None) -> Tuple[int, Dict[str, Any]]: """ 获取优先级最高的消息 Args: timeout: 超时时间(秒),None 表示无限等待 Returns: (msg_type, data) 元组 Raises: asyncio.TimeoutError: 超时 """ start_time = time.time() while True: async with self._lock: if self._heap: msg = heapq.heappop(self._heap) if not self._heap: self._not_empty.clear() return (msg.msg_type, msg.data) # 计算剩余超时时间 if timeout is not None: elapsed = time.time() - start_time remaining = timeout - elapsed if remaining <= 0: raise asyncio.TimeoutError("Queue get timeout") try: await asyncio.wait_for(self._not_empty.wait(), timeout=remaining) except asyncio.TimeoutError: raise asyncio.TimeoutError("Queue get timeout") else: await self._not_empty.wait() def get_nowait(self) -> Tuple[int, Dict[str, Any]]: """非阻塞获取消息""" if not self._heap: raise asyncio.QueueEmpty() msg = heapq.heappop(self._heap) if not self._heap: self._not_empty.clear() return (msg.msg_type, msg.data) def task_done(self): """标记任务完成""" self._unfinished_tasks = max(0, self._unfinished_tasks - 1) if self._unfinished_tasks == 0: self._finished.set() async def join(self): """等待所有任务完成""" await self._finished.wait() def clear(self): """清空队列""" self._heap.clear() self._not_empty.clear() self._unfinished_tasks = 0 self._finished.set() def get_stats(self) -> Dict[str, Any]: """获取队列统计信息""" return { "current_size": len(self._heap), "max_size": self.maxsize, "total_put": self._total_put, "total_dropped": self._total_dropped, "total_rejected": self._total_rejected, "unfinished_tasks": self._unfinished_tasks, "overflow_strategy": self.overflow_strategy.value, "utilization": len(self._heap) / max(self.maxsize, 1), } @classmethod def from_config(cls, queue_config: Dict[str, Any]) -> "PriorityMessageQueue": """ 从配置创建队列 Args: queue_config: Queue 配置节 Returns: PriorityMessageQueue 实例 """ return cls( maxsize=queue_config.get("max_size", 1000), overflow_strategy=queue_config.get("overflow_strategy", "drop_oldest"), sampling_rate=queue_config.get("sampling_rate", 0.5), ) # ==================== 导出 ==================== __all__ = [ 'MessagePriority', 'OverflowStrategy', 'PriorityMessage', 'PriorityMessageQueue', ]