146 lines
3.8 KiB
Python
146 lines
3.8 KiB
Python
"""
|
||
消息去重器模块
|
||
|
||
防止同一条消息被重复处理(某些环境下回调会重复触发)
|
||
"""
|
||
|
||
import asyncio
|
||
import time
|
||
from typing import Any, Dict, Optional
|
||
|
||
from loguru import logger
|
||
|
||
|
||
class MessageDeduplicator:
    """Filter out messages that were already handled recently.

    Deduplication uses a time-based sliding window:

    - the IDs of recently processed messages are remembered;
    - a message whose ID reappears within the TTL is reported as a duplicate;
    - expired records are purged automatically and the cache size is capped
      to bound memory usage.

    Timestamps are taken from ``time.monotonic()`` so that wall-clock
    adjustments (NTP corrections, manual changes) can neither keep entries
    alive indefinitely nor expire them early.
    """

    def __init__(
        self,
        ttl_seconds: float = 30.0,
        max_size: int = 5000,
    ):
        """Initialise the deduplicator.

        Args:
            ttl_seconds: Lifetime of a message ID in seconds. Zero (or a
                negative value, which is clamped to zero) disables
                deduplication.
            max_size: Maximum number of cached entries. Zero (or a negative
                value, which is clamped to zero) disables the size cap.
        """
        self.ttl_seconds = max(float(ttl_seconds), 0.0)
        self.max_size = max(int(max_size), 0)
        # key -> monotonic timestamp of the sighting that opened the window.
        # Dict insertion order doubles as age order (oldest entry first).
        self._cache: Dict[str, float] = {}
        self._lock = asyncio.Lock()

    @staticmethod
    def extract_msg_id(data: Dict[str, Any]) -> str:
        """Extract the message ID from a raw message payload.

        The keys ``msgid``, ``msg_id``, ``MsgId`` and ``id`` are probed in
        that order; the first truthy value wins.

        Args:
            data: Raw message payload.

        Returns:
            The message ID as a string, or an empty string when none of the
            probed keys holds a truthy value.
        """
        for key in ("msgid", "msg_id", "MsgId", "id"):
            value = data.get(key)
            # Truthiness is deliberate: empty/zero placeholders are treated
            # as "no ID" so unrelated messages are never merged together.
            if value:
                return str(value)
        return ""

    async def is_duplicate(self, data: Dict[str, Any]) -> bool:
        """Report whether a message was already seen within the TTL window.

        A message that is not a duplicate is recorded as a side effect, after
        which expired entries are purged and the size cap is enforced.

        Args:
            data: Raw message payload.

        Returns:
            True when the message is a duplicate, False when it is new (or
            when deduplication is disabled or the payload carries no ID).
        """
        if self.ttl_seconds <= 0:
            return False

        msg_id = self.extract_msg_id(data)
        if not msg_id:
            # Without an ID we cannot tell messages apart; never filter.
            return False

        key = f"msgid:{msg_id}"

        async with self._lock:
            # Read the clock only once the lock is held: together with the
            # monotonic source this guarantees that insertion order matches
            # timestamp order, which _cleanup_expired depends on.
            now = time.monotonic()

            last_seen = self._cache.get(key)
            if last_seen is not None and (now - last_seen) < self.ttl_seconds:
                return True

            # Drop any stale record first so the fresh one lands at the end
            # of the dict and the oldest-first ordering is preserved.
            self._cache.pop(key, None)
            self._cache[key] = now

            self._cleanup_expired(now)
            self._limit_size()

        return False

    def _cleanup_expired(self, now: float) -> None:
        """Remove expired entries from the front of the cache.

        Must be called with the lock held.

        Args:
            now: Current ``time.monotonic()`` reading.
        """
        while self._cache:
            oldest_key = next(iter(self._cache))
            # Same freshness test as is_duplicate, so an entry is purged
            # exactly when it would no longer count as a duplicate.
            if (now - self._cache[oldest_key]) < self.ttl_seconds:
                break
            del self._cache[oldest_key]

    def _limit_size(self) -> None:
        """Evict the oldest entries until the cache fits within ``max_size``.

        Must be called with the lock held. A ``max_size`` of zero means the
        cache is unbounded.
        """
        if self.max_size <= 0:
            return
        while len(self._cache) > self.max_size:
            oldest_key = next(iter(self._cache))
            del self._cache[oldest_key]

    def clear(self) -> None:
        """Forget every recorded message ID."""
        self._cache.clear()

    def get_stats(self) -> Dict[str, Any]:
        """Return the current cache size together with the configured limits."""
        return {
            "cached_count": len(self._cache),
            "ttl_seconds": self.ttl_seconds,
            "max_size": self.max_size,
        }

    @classmethod
    def from_config(cls, perf_config: Optional[Dict[str, Any]]) -> "MessageDeduplicator":
        """Build a deduplicator from the performance configuration section.

        Missing keys, keys explicitly set to ``None`` and a missing section
        all fall back to the defaults (30 seconds, 5000 entries).

        Args:
            perf_config: Mapping that may contain ``dedup_ttl_seconds`` and
                ``dedup_max_size``.

        Returns:
            A configured MessageDeduplicator instance.
        """
        config = perf_config or {}
        ttl = config.get("dedup_ttl_seconds")
        size = config.get("dedup_max_size")
        return cls(
            ttl_seconds=30 if ttl is None else ttl,
            max_size=5000 if size is None else size,
        )
|
||
|
||
|
||
# ==================== Exports ====================

__all__ = ["MessageDeduplicator"]
|