import os
import time
import random
import asyncio
import threading
from typing import Any, Optional, List, Dict
from collections import deque
import logging

from db.connection import DBConnectionManager

logger = logging.getLogger(__name__)


class ImageCacheManager:
    """Index image files from a folder into a Redis set and serve random images.

    Every indexed image path is a member of one Redis set, so a random image
    can be picked with SRANDMEMBER. A small in-memory FIFO queue of pre-loaded
    image bytes lets callers get image data without reading the disk on every
    request; it is topped up from a background thread when it runs low.
    """

    IMAGE_KEY_PREFIX = "group:images:"
    LAST_UPDATE_TIME_KEY = "group:images:last_update_time"
    BATCH_SIZE = 500

    def __init__(self, image_folder: str, cache_size: int = 5):
        """
        Args:
            image_folder: Root directory scanned (recursively) for images.
            cache_size: Maximum number of pre-loaded images held in memory.
        """
        self.image_folder = image_folder
        self.redis = DBConnectionManager.get_instance().get_redis_connection()
        self.image_extensions = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp'}

        # In-memory cache of pre-loaded images. Entries are consumed from the
        # left and appended on the right (FIFO, not LRU); maxlen bounds memory.
        self.cache_size = cache_size
        self.image_bytes_cache = deque(maxlen=cache_size)
        # True while a refill is in progress (informational flag; the lock
        # below is what actually serializes refills).
        self.is_refilling = False
        # A plain bool check-then-set is not atomic, and refills run both in
        # the caller's thread and in a background thread, so use a real lock.
        self._refill_lock = threading.Lock()

    @property
    def _all_images_key(self) -> str:
        """Redis set key holding every indexed image path."""
        return self.IMAGE_KEY_PREFIX + "all"

    @staticmethod
    def _to_str(value):
        """Decode a Redis reply to str; replies are bytes unless the client decodes them."""
        return value.decode('utf-8') if isinstance(value, bytes) else value

    def _get_last_update_time(self) -> float:
        """Return the stored last-index timestamp, or 0.0 if absent or unparseable."""
        ts = self.redis.get(self.LAST_UPDATE_TIME_KEY)
        if ts:
            try:
                return float(ts)
            except (TypeError, ValueError) as e:
                logger.warning(f"解析最后更新时间失败: {e}")
        return 0.0

    def _set_last_update_time(self, ts: float) -> None:
        """Persist *ts* as the last successful index time."""
        self.redis.set(self.LAST_UPDATE_TIME_KEY, ts)

    def _latest_tree_mtime(self) -> float:
        """Return the newest mtime of image_folder and of every directory below it.

        A directory's mtime changes when entries are added, removed or renamed
        directly inside it, but not when that happens in a nested
        sub-directory. Checking only the root would therefore miss images
        added to sub-folders, even though the scan itself is recursive.

        Raises:
            OSError: if image_folder itself cannot be stat'ed.
        """
        latest = os.path.getmtime(self.image_folder)
        for root, dirs, _ in os.walk(self.image_folder):
            for name in dirs:
                try:
                    latest = max(latest, os.path.getmtime(os.path.join(root, name)))
                except OSError:
                    # Directory vanished between listing and stat; ignore it.
                    continue
        return latest

    def should_update_index(self) -> bool:
        """Return True if the folder tree changed since the last index (or on error)."""
        try:
            folder_mtime = self._latest_tree_mtime()
            last_ts = self._get_last_update_time()
            if folder_mtime <= last_ts:
                logger.info("图片目录未更新,无需重新索引")
                return False
            return True
        except Exception as e:
            logger.error(f"判断图片目录更新时间失败: {e}")
            return True  # on error, default to re-indexing

    def _scan_new_images(self, last_update_ts: float) -> List[str]:
        """Recursively collect readable image files changed after *last_update_ts*.

        A file counts as new when its mtime OR its ctime is newer than the
        threshold. mtime alone misses files copied or extracted with their
        original modification time preserved (e.g. ``cp -p``, ``rsync -t``,
        archive extraction). st_ctime (inode-change time on POSIX, creation
        time on Windows) is refreshed in those cases.
        """
        new_images = []
        for root, _, files in os.walk(self.image_folder):
            for file in files:
                try:
                    ext = os.path.splitext(file)[1].lower()
                    if ext not in self.image_extensions:
                        continue
                    full_path = os.path.join(root, file)
                    st = os.stat(full_path)
                    changed_at = max(st.st_mtime, st.st_ctime)
                    if changed_at > last_update_ts and os.access(full_path, os.R_OK):
                        new_images.append(full_path)
                except OSError as e:
                    logger.warning(f"处理文件时异常 {file}: {e}")
        return new_images

    def _redis_batch_write(self, keys_values: List[tuple]):
        """SADD every (key, value) pair in one pipelined round trip."""
        pipeline = self.redis.pipeline()
        for key, value in keys_values:
            pipeline.sadd(key, value)
        pipeline.execute()

    def _update_image_cache_sync(self) -> None:
        """Blocking implementation of update_image_cache (runs in a worker thread)."""
        if not self.should_update_index():
            return

        last_update_ts = self._get_last_update_time()
        # Capture the timestamp BEFORE scanning. Files created while the walk
        # is running then still count as new on the next run, instead of
        # falling into the gap between the scan and the timestamp write.
        scan_started_at = time.time()
        new_images = self._scan_new_images(last_update_ts)

        if not new_images:
            logger.info("无新增图片,无需更新缓存")
            # Still advance the timestamp so this folder state is not rescanned.
            self._set_last_update_time(scan_started_at)
            return

        logger.info(f"新增图片数量: {len(new_images)}, 开始写入 Redis 分批")
        batch_size = self.BATCH_SIZE
        # One fixed Redis set, so SRANDMEMBER can pick a random member.
        redis_key = self._all_images_key
        all_ok = True
        for i in range(0, len(new_images), batch_size):
            batch = new_images[i:i + batch_size]
            try:
                self._redis_batch_write([(redis_key, path) for path in batch])
                logger.info(f"写入 Redis 批次 {i // batch_size + 1} 成功,数量: {len(batch)}")
            except Exception as e:
                logger.error(f"Redis 写入失败: {e}")
                # Keep writing the remaining batches, but remember the failure.
                all_ok = False

        if all_ok:
            self._set_last_update_time(scan_started_at)
        else:
            # Not advancing the timestamp makes the next run rescan the failed
            # files; SADD is idempotent, so re-adding the others is harmless.
            logger.warning("部分批次写入 Redis 失败,本次不更新时间戳,下次将重新扫描")

    async def update_image_cache(self):
        """Update the Redis image index, writing new paths in batches.

        The directory walk and the Redis calls are blocking, so they run in
        the event loop's default executor instead of stalling the loop.
        """
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, self._update_image_cache_sync)

    def get_random_image(self) -> Optional[str]:
        """Return one random indexed image path, or None if none/on Redis error."""
        try:
            img = self.redis.srandmember(self._all_images_key)
            if img:
                return self._to_str(img)
        except Exception as e:
            logger.error(f"获取随机图片失败: {e}")
        return None

    def _load_image_bytes(self, image_path: str) -> Optional[bytes]:
        """Read an image file from disk; return None (and log) on failure."""
        try:
            with open(image_path, 'rb') as f:
                return f.read()
        except Exception as e:
            logger.error(f"读取图片文件失败 {image_path}: {e}")
            return None

    def _refill_cache(self, wait: bool = False) -> None:
        """Top the in-memory cache back up to cache_size images.

        Args:
            wait: If True, block until any refill already in progress finishes,
                then load only what is still missing. If False (the default),
                return immediately when another refill is running.
        """
        if not self._refill_lock.acquire(blocking=wait):
            return
        self.is_refilling = True
        try:
            # Only load as many as fit. Appending more to a full bounded deque
            # would silently evict images that were just loaded.
            missing = self.cache_size - len(self.image_bytes_cache)
            if missing <= 0:
                return

            logger.info("开始重新填充图片缓存...")
            try:
                # A positive count makes SRANDMEMBER return *distinct* members
                # in one round trip; fetch a couple extra in case reads fail.
                raw_paths = self.redis.srandmember(self._all_images_key, missing + 2) or []
            except Exception as e:
                logger.error(f"获取随机图片路径失败: {e}")
                return

            cached_paths = {item['path'] for item in self.image_bytes_cache}
            loaded_count = 0
            for raw in raw_paths:
                if loaded_count >= missing:
                    break
                path = self._to_str(raw)
                if not path or path in cached_paths:
                    continue
                image_bytes = self._load_image_bytes(path)
                if image_bytes:
                    self.image_bytes_cache.append({
                        'path': path,
                        'bytes': image_bytes
                    })
                    cached_paths.add(path)
                    loaded_count += 1

            logger.info(f"缓存填充完成,新增 {loaded_count} 张图片")
        except Exception as e:
            logger.error(f"填充缓存失败: {e}")
        finally:
            self.is_refilling = False
            self._refill_lock.release()

    def get_cached_image_bytes(self) -> Optional[Dict[str, Any]]:
        """Pop one pre-loaded image from the in-memory cache.

        Returns:
            ``{'path': str, 'bytes': bytes}``, or None if no image could be loaded.
        """
        if not self.image_bytes_cache:
            # Cache empty: refill synchronously. If a background refill is
            # already running, wait for it rather than giving up immediately.
            self._refill_cache(wait=True)

        try:
            cached_image = self.image_bytes_cache.popleft()
        except IndexError:
            # Still empty (nothing indexed / every read failed), or another
            # consumer thread took the last entry after the check above.
            return None

        # Running low: top up in a background thread so this call returns
        # without waiting on disk I/O.
        if len(self.image_bytes_cache) <= 1 and not self._refill_lock.locked():
            threading.Thread(target=self._refill_cache, daemon=True).start()

        return cached_image

    def get_cached_image_count(self) -> int:
        """Return the number of images currently held in the in-memory cache."""
        return len(self.image_bytes_cache)