"""
Message queue module.

Provides a high-performance priority message queue that supports several
overflow strategies:
- drop_oldest: discard the oldest message
- drop_lowest: discard the lowest-priority message
- sampling: discard messages according to a sampling rate
- reject: refuse the new message
"""

import asyncio
import heapq
import random
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple

from loguru import logger


# ==================== Message priority constants ====================

class MessagePriority:
    """Named priority levels for messages; a larger number is more urgent."""

    CRITICAL = 100  # system messages, login information
    HIGH = 80       # administrator commands, group membership changes
    NORMAL = 50     # @bot messages (the default level)
    LOW = 20        # ordinary group messages


# ==================== Overflow strategies ====================

class OverflowStrategy(Enum):
    """Policies that decide what happens when a message arrives at a full queue."""

    DROP_OLDEST = "drop_oldest"  # discard the oldest queued message
    DROP_LOWEST = "drop_lowest"  # discard the lowest-priority queued message
    SAMPLING = "sampling"        # keep or discard according to a sampling rate
    REJECT = "reject"            # refuse the new message


# ==================== Priority message ====================

@dataclass(order=True)
class PriorityMessage:
    """
    A message wrapped with the ordering key used by the heap.

    Instances order by ``(priority, timestamp)`` only; ``msg_type`` and
    ``data`` are excluded from comparisons. ``priority`` holds the NEGATED
    caller-supplied priority so that Python's min-heap pops the most urgent
    message first, and among equal priorities the earliest-created message
    sorts first.
    """

    priority: int = field(compare=True)
    timestamp: float = field(compare=True)
    msg_type: int = field(compare=False)
    data: Dict[str, Any] = field(compare=False)

    def __init__(self, msg_type: int, data: Dict[str, Any], priority: Optional[int] = None):
        """
        Build a message stamped with the current wall-clock time.

        Args:
            msg_type: Message type code.
            data: Message payload.
            priority: Urgency, larger is more urgent; ``None`` selects
                ``MessagePriority.NORMAL``.
        """
        # Higher priority means a larger number, but heapq is a min-heap,
        # so the value is stored negated.
        self.priority = -(priority if priority is not None else MessagePriority.NORMAL)
        self.timestamp = time.time()
        self.msg_type = msg_type
        self.data = data


# ==================== Priority message queue ====================

class PriorityMessageQueue:
    """
    Bounded priority queue for messages, intended for one asyncio event loop.

    Features:
    - heap-backed ordering: most urgent first, then earliest-created first
    - several overflow strategies (see ``OverflowStrategy``)
    - coroutine-safe through ``asyncio.Lock``; asyncio primitives are NOT
      thread-safe, so do not share an instance between OS threads
    - unfinished-task counting with ``task_done`` / ``join``
    """

    def __init__(
        self,
        maxsize: int = 1000,
        overflow_strategy: str = "drop_oldest",
        sampling_rate: float = 0.5,
    ):
        """
        Initialise the queue.

        Args:
            maxsize: Maximum number of queued messages.
            overflow_strategy: One of drop_oldest, drop_lowest, sampling,
                reject (an ``OverflowStrategy`` member is accepted too).
            sampling_rate: Probability, clamped to 0.0-1.0, that the sampling
                strategy admits a new message when the queue is full.

        Raises:
            ValueError: If ``overflow_strategy`` is not a known strategy.
        """
        self.maxsize = maxsize
        self.overflow_strategy = OverflowStrategy(overflow_strategy)
        self.sampling_rate = max(0.0, min(1.0, sampling_rate))

        self._heap: List[PriorityMessage] = []
        self._lock = asyncio.Lock()
        self._not_empty = asyncio.Event()
        self._unfinished_tasks = 0
        self._finished = asyncio.Event()
        self._finished.set()  # nothing is outstanding yet, so join() returns at once

        # Statistics
        self._total_put = 0
        self._total_dropped = 0
        self._total_rejected = 0

    def qsize(self) -> int:
        """Return the number of queued messages."""
        return len(self._heap)

    def empty(self) -> bool:
        """Return True when no message is queued."""
        return len(self._heap) == 0

    def full(self) -> bool:
        """Return True when the queue holds ``maxsize`` messages or more."""
        return len(self._heap) >= self.maxsize

    async def put(
        self,
        msg_type: int,
        data: Dict[str, Any],
        priority: Optional[int] = None,
    ) -> bool:
        """
        Add a message to the queue.

        When the queue is full the configured overflow strategy decides
        whether room is made for the message or the message is refused.

        Args:
            msg_type: Message type code.
            data: Message payload.
            priority: Urgency, larger is more urgent; ``None`` means
                ``MessagePriority.NORMAL``.

        Returns:
            True if the message was queued, False if it was refused.
        """
        async with self._lock:
            self._total_put += 1

            if self.full():
                # Resolve the default so the overflow logic can compare the
                # newcomer against what is already queued.
                effective = MessagePriority.NORMAL if priority is None else priority
                if not self._handle_overflow(effective):
                    self._total_rejected += 1
                    return False

            heapq.heappush(self._heap, PriorityMessage(msg_type, data, priority))
            self._unfinished_tasks += 1
            self._finished.clear()
            self._not_empty.set()
            return True

    def _handle_overflow(self, incoming_priority: Optional[int] = None) -> bool:
        """
        Apply the overflow strategy to a full queue.

        Args:
            incoming_priority: Priority of the message about to be queued
                (larger is more urgent). Only the drop_lowest strategy uses
                it; ``None`` means "unknown" and always evicts a queued
                message.

        Returns:
            True if the new message may be queued, False if it is refused.
        """
        strategy = self.overflow_strategy

        if strategy == OverflowStrategy.REJECT:
            logger.warning("队列已满,拒绝新消息")
            return False

        if strategy == OverflowStrategy.DROP_OLDEST:
            self._evict(self._oldest_index())
            return True

        if strategy == OverflowStrategy.DROP_LOWEST:
            victim = self._lowest_priority_index()
            if (
                victim is not None
                and incoming_priority is not None
                and -incoming_priority > self._heap[victim].priority
            ):
                # The newcomer is strictly less urgent than everything queued,
                # so it is itself the lowest-priority message: refuse it
                # instead of evicting something more important.
                self._total_dropped += 1
                return False
            self._evict(victim)
            return True

        if strategy == OverflowStrategy.SAMPLING:
            if random.random() < self.sampling_rate:
                # Admit the new message and make room by dropping the oldest.
                self._evict(self._oldest_index())
                return True
            self._total_dropped += 1
            return False

        return False

    def _oldest_index(self) -> Optional[int]:
        """Return the heap index of the earliest-created message, or None if empty."""
        if not self._heap:
            return None
        return min(range(len(self._heap)), key=lambda i: self._heap[i].timestamp)

    def _lowest_priority_index(self) -> Optional[int]:
        """Return the heap index of the least urgent message (oldest on ties), or None."""
        if not self._heap:
            return None
        # Stored priorities are negated, so the largest stored value is the
        # least urgent; the negated timestamp makes the oldest win a tie.
        return max(
            range(len(self._heap)),
            key=lambda i: (self._heap[i].priority, -self._heap[i].timestamp),
        )

    def _evict(self, index: Optional[int]) -> None:
        """Remove the message at ``index`` and update counters; no-op when ``index`` is None."""
        if index is None:
            return
        del self._heap[index]
        heapq.heapify(self._heap)  # removing an arbitrary slot breaks the heap invariant
        self._total_dropped += 1
        # The evicted message will never reach task_done(), so stop counting it.
        self._unfinished_tasks = max(0, self._unfinished_tasks - 1)

    def _pop_next(self) -> Tuple[int, Dict[str, Any]]:
        """Pop the most urgent message and keep the not-empty event in sync."""
        msg = heapq.heappop(self._heap)
        if not self._heap:
            self._not_empty.clear()
        return (msg.msg_type, msg.data)

    async def get(self, timeout: Optional[float] = None) -> Tuple[int, Dict[str, Any]]:
        """
        Remove and return the most urgent message, waiting if none is queued.

        Args:
            timeout: Maximum time to wait in seconds; ``None`` waits forever.

        Returns:
            A ``(msg_type, data)`` tuple.

        Raises:
            asyncio.TimeoutError: If no message arrived within ``timeout``.
        """
        # A monotonic deadline is immune to wall-clock adjustments.
        deadline = None if timeout is None else time.monotonic() + timeout

        while True:
            async with self._lock:
                if self._heap:
                    return self._pop_next()

            if deadline is None:
                await self._not_empty.wait()
                continue

            remaining = deadline - time.monotonic()
            if remaining <= 0:
                raise asyncio.TimeoutError("Queue get timeout")
            try:
                await asyncio.wait_for(self._not_empty.wait(), timeout=remaining)
            except asyncio.TimeoutError:
                raise asyncio.TimeoutError("Queue get timeout") from None

    def get_nowait(self) -> Tuple[int, Dict[str, Any]]:
        """
        Remove and return the most urgent message without waiting.

        Raises:
            asyncio.QueueEmpty: If no message is queued.
        """
        if not self._heap:
            raise asyncio.QueueEmpty()
        return self._pop_next()

    def task_done(self):
        """Mark one previously fetched message as processed (never goes below zero)."""
        self._unfinished_tasks = max(0, self._unfinished_tasks - 1)
        if self._unfinished_tasks == 0:
            self._finished.set()

    async def join(self):
        """Wait until every queued message has been marked done."""
        await self._finished.wait()

    def clear(self):
        """Discard all queued messages and reset the unfinished-task counter."""
        self._heap.clear()
        self._not_empty.clear()
        self._unfinished_tasks = 0
        self._finished.set()

    def get_stats(self) -> Dict[str, Any]:
        """Return a snapshot of queue size, limits and counters."""
        return {
            "current_size": len(self._heap),
            "max_size": self.maxsize,
            "total_put": self._total_put,
            "total_dropped": self._total_dropped,
            "total_rejected": self._total_rejected,
            "unfinished_tasks": self._unfinished_tasks,
            "overflow_strategy": self.overflow_strategy.value,
            # max(..., 1) avoids division by zero for a zero-sized queue.
            "utilization": len(self._heap) / max(self.maxsize, 1),
        }

    @classmethod
    def from_config(cls, queue_config: Dict[str, Any]) -> "PriorityMessageQueue":
        """
        Create a queue from a configuration mapping.

        Args:
            queue_config: Queue configuration section; reads the optional
                keys ``max_size``, ``overflow_strategy`` and ``sampling_rate``.

        Returns:
            A new ``PriorityMessageQueue`` instance.
        """
        return cls(
            maxsize=queue_config.get("max_size", 1000),
            overflow_strategy=queue_config.get("overflow_strategy", "drop_oldest"),
            sampling_rate=queue_config.get("sampling_rate", 0.5),
        )


# ==================== Exports ====================

# Public API of this module.
__all__ = [
    'MessagePriority',
    'OverflowStrategy',
    'PriorityMessage',
    'PriorityMessageQueue',
]