204 lines
7.6 KiB
Python
204 lines
7.6 KiB
Python
import asyncio
import logging
import os
import random
import threading
import time
from collections import deque
from typing import Optional, List, Dict

from db.connection import DBConnectionManager
|
||
|
||
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class ImageCacheManager:
    """Index image files in a Redis set and hand out random images.

    Two layers are maintained:

    * A Redis set (``IMAGE_KEY_PREFIX + "all"``) holding the path of every
      indexed image, plus a key storing the timestamp of the last indexing
      run.
    * A small in-process buffer of pre-read image bytes that is refilled
      (in a background thread when it runs low) from random members of
      that Redis set.
    """

    IMAGE_KEY_PREFIX = "group:images:"
    LAST_UPDATE_TIME_KEY = "group:images:last_update_time"
    # Maximum number of SADD commands sent per pipeline round-trip.
    BATCH_SIZE = 500

    def __init__(self, image_folder: str, cache_size: int = 5):
        """Create a manager for *image_folder*.

        Args:
            image_folder: Root directory that is walked recursively for images.
            cache_size: Maximum number of images kept in the in-memory buffer.
        """
        self.image_folder = image_folder
        # NOTE(review): assumed to be a redis-py style client (get/set/
        # pipeline/sadd/srandmember are the only calls made on it).
        self.redis = DBConnectionManager.get_instance().get_redis_connection()
        self.image_extensions = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp'}

        # In-memory buffer state.
        self.cache_size = cache_size
        # Bounded FIFO: appending to a full deque drops the oldest entry.
        self.image_bytes_cache = deque(maxlen=cache_size)
        # Kept for observers; the lock below is what actually serialises refills.
        self.is_refilling = False
        # A plain check-then-set on a bool is not atomic across threads, so
        # the "only one refill at a time" guard is implemented with a lock.
        self._refill_lock = threading.Lock()

    @property
    def _all_images_key(self) -> str:
        """Redis key of the set that holds every indexed image path."""
        return self.IMAGE_KEY_PREFIX + "all"

    @staticmethod
    def _as_text(value):
        """Return *value* as ``str``, decoding UTF-8 if the client returned bytes."""
        return value.decode('utf-8') if isinstance(value, bytes) else value

    def _get_last_update_time(self) -> float:
        """Return the stored timestamp of the last indexing run, or 0.0.

        0.0 is returned both when the key is missing/empty and when the
        stored value cannot be parsed as a float.
        """
        ts = self.redis.get(self.LAST_UPDATE_TIME_KEY)
        if ts:
            try:
                return float(ts)
            except (TypeError, ValueError) as e:
                logger.warning(f"解析最后更新时间失败: {e}")
        return 0.0

    def _set_last_update_time(self, ts: float) -> None:
        """Store *ts* as the timestamp of the last indexing run."""
        self.redis.set(self.LAST_UPDATE_TIME_KEY, ts)

    def should_update_index(self) -> bool:
        """Tell whether the image folder changed since the last indexing run.

        Compares the mtime of ``image_folder`` itself with the stored
        timestamp. Any error while checking is logged and treated as
        "update needed".

        NOTE(review): only the top-level folder's mtime is consulted; on
        common filesystems that does not change when files are added inside
        nested sub-folders — confirm this is acceptable.
        """
        try:
            folder_mtime = os.path.getmtime(self.image_folder)
            last_ts = self._get_last_update_time()
        except Exception as e:
            logger.error(f"判断图片目录更新时间失败: {e}")
            return True  # On error, default to re-indexing.

        if folder_mtime <= last_ts:
            logger.info("图片目录未更新,无需重新索引")
            return False
        return True

    def _scan_new_images(self, last_update_ts: float) -> List[str]:
        """Walk the image folder and return paths of new, readable images.

        A file qualifies when its extension (case-insensitive) is in
        ``image_extensions``, its mtime is strictly greater than
        *last_update_ts*, and it is readable by the current process.
        Per-file errors are logged and the file is skipped.
        """
        new_images = []
        for root, _, files in os.walk(self.image_folder):
            for file in files:
                try:
                    _, ext = os.path.splitext(file)
                    if ext.lower() not in self.image_extensions:
                        continue
                    full_path = os.path.join(root, file)
                    # Only files modified after the previous run are collected.
                    if os.path.getmtime(full_path) <= last_update_ts:
                        continue
                    if os.access(full_path, os.R_OK):
                        new_images.append(full_path)
                except Exception as e:
                    logger.warning(f"处理文件时异常 {file}: {e}")
        return new_images

    def _redis_batch_write(self, keys_values: List[tuple]):
        """SADD every ``(key, value)`` pair through a single pipeline.

        Raises:
            Whatever the Redis client raises when the pipeline is executed.
        """
        pipeline = self.redis.pipeline()
        for key, value in keys_values:
            pipeline.sadd(key, value)
        pipeline.execute()

    async def update_image_cache(self):
        """Add newly found image paths to the Redis set, in batches.

        The last-update timestamp is advanced only when every batch was
        written successfully; otherwise it is left untouched so the next run
        rescans and retries (SADD is idempotent, so re-adding is harmless).
        The stored timestamp is the moment the scan *started*, so files that
        appear while the scan or the writes are in progress are picked up by
        the next run instead of being skipped.

        NOTE: although declared ``async``, the body performs blocking disk
        and Redis calls and never awaits.
        """
        if not self.should_update_index():
            return

        last_update_ts = self._get_last_update_time()
        # Captured before scanning: anything modified from here on must
        # still count as "new" next time.
        scan_started_at = time.time()
        new_images = self._scan_new_images(last_update_ts)

        if not new_images:
            logger.info("无新增图片,无需更新缓存")
            # Still advance the timestamp so the folder is not rescanned needlessly.
            self._set_last_update_time(scan_started_at)
            return

        logger.info(f"新增图片数量: {len(new_images)}, 开始写入 Redis 分批")

        total = len(new_images)
        batch_size = self.BATCH_SIZE
        # A single set holds all paths, which makes random selection cheap.
        redis_key = self._all_images_key

        failed_batches = 0
        for batch_no, start in enumerate(range(0, total, batch_size), start=1):
            batch = new_images[start:start + batch_size]
            kvs = [(redis_key, path) for path in batch]
            try:
                self._redis_batch_write(kvs)
            except Exception as e:
                # Keep going so the remaining batches still get a chance.
                logger.error(f"Redis 写入失败: {e}")
                failed_batches += 1
            else:
                logger.info(f"写入 Redis 批次 {batch_no} 成功,数量: {len(batch)}")

        if failed_batches:
            # Advancing the timestamp here would permanently hide the images
            # of the failed batches from every later scan.
            logger.warning(f"有 {failed_batches} 个批次写入失败,保留上次更新时间以便下次重试")
            return

        self._set_last_update_time(scan_started_at)

    def get_random_image(self) -> Optional[str]:
        """Return the path of one random indexed image, or None.

        None is returned when the set is empty or the Redis call fails
        (the failure is logged).
        """
        try:
            img = self.redis.srandmember(self._all_images_key)
            if img:
                # The client may return bytes; normalise to str.
                return self._as_text(img)
        except Exception as e:
            logger.error(f"获取随机图片失败: {e}")
        return None

    def _load_image_bytes(self, image_path: str) -> Optional[bytes]:
        """Read *image_path* from disk; return its bytes, or None on failure."""
        try:
            with open(image_path, 'rb') as f:
                return f.read()
        except Exception as e:
            logger.error(f"读取图片文件失败 {image_path}: {e}")
            return None

    def _refill_cache(self):
        """Load random images from disk into the in-memory buffer.

        Only one refill runs at a time; a call made while another refill is
        in progress returns immediately. At most ``cache_size`` images are
        appended per call. Every failure is logged, never raised.
        """
        # Non-blocking acquire makes "is a refill already running?" an atomic test.
        if not self._refill_lock.acquire(blocking=False):
            return

        self.is_refilling = True
        try:
            logger.info("开始重新填充图片缓存...")

            redis_key = self._all_images_key
            # Snapshot through list() first so a concurrent popleft() cannot
            # invalidate the iteration; the set also rejects duplicates among
            # the newly drawn candidates.
            seen_paths = {item['path'] for item in list(self.image_bytes_cache)}
            image_paths = []

            # Draw a few more candidates than needed in case some files fail to load.
            for _ in range(self.cache_size + 2):
                try:
                    img = self.redis.srandmember(redis_key)
                    if not img:
                        continue
                    path = self._as_text(img)
                except Exception as e:
                    logger.error(f"获取随机图片路径失败: {e}")
                    continue
                if path not in seen_paths:
                    seen_paths.add(path)
                    image_paths.append(path)

            # Read the candidates and append them to the buffer.
            loaded_count = 0
            for path in image_paths:
                if loaded_count >= self.cache_size:
                    break

                image_bytes = self._load_image_bytes(path)
                if image_bytes:
                    self.image_bytes_cache.append({
                        'path': path,
                        'bytes': image_bytes,
                    })
                    loaded_count += 1

            logger.info(f"缓存填充完成,新增 {loaded_count} 张图片")

        except Exception as e:
            logger.error(f"填充缓存失败: {e}")
        finally:
            self.is_refilling = False
            self._refill_lock.release()

    def get_cached_image_bytes(self) -> Optional[Dict[str, object]]:
        """Pop and return the oldest buffered image.

        Returns:
            ``{'path': str, 'bytes': bytes}``, or None when no image could
            be provided.
        """
        # An empty buffer is refilled synchronously before serving.
        if not self.image_bytes_cache:
            self._refill_cache()

        try:
            cached_image = self.image_bytes_cache.popleft()
        except IndexError:
            # Still empty after the refill attempt, or drained by another
            # consumer between the check and the pop.
            return None

        # Running low: top up in a background thread so the caller is not
        # blocked. Skip spawning a thread when a refill is already running.
        if len(self.image_bytes_cache) <= 1 and not self.is_refilling:
            threading.Thread(target=self._refill_cache, daemon=True).start()

        return cached_image

    def get_cached_image_count(self) -> int:
        """Return how many images are currently held in the in-memory buffer."""
        return len(self.image_bytes_cache)
|