Files
abot/plugins/xiuren_image/images_cache.py

204 lines
7.6 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import asyncio
import logging
import os
import random
import threading
import time
from collections import deque
from typing import Any, Dict, List, Optional

from db.connection import DBConnectionManager
logger = logging.getLogger(__name__)


class ImageCacheManager:
    """Index image files in a Redis set and serve random images from memory.

    Two layers are maintained:

    * A Redis set (``IMAGE_KEY_PREFIX + "all"``) holding the paths of every
      indexed image, updated incrementally by :meth:`update_image_cache`.
    * A small in-process FIFO queue of ``{'path': str, 'bytes': bytes}``
      entries, consumed by :meth:`get_cached_image_bytes` and topped up by
      :meth:`_refill_cache`.
    """

    IMAGE_KEY_PREFIX = "group:images:"
    LAST_UPDATE_TIME_KEY = "group:images:last_update_time"
    BATCH_SIZE = 500

    def __init__(self, image_folder: str, cache_size: int = 5):
        """Create a manager for *image_folder*.

        Args:
            image_folder: Root directory that is scanned (recursively) for images.
            cache_size: Maximum number of images kept in the in-memory queue.
        """
        self.image_folder = image_folder
        self.redis = DBConnectionManager.get_instance().get_redis_connection()
        self.image_extensions = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp'}
        # In-memory cache state.
        self.cache_size = cache_size
        # Bounded FIFO queue: entries are appended on the right and served
        # from the left (it is not an LRU cache).
        self.image_bytes_cache = deque(maxlen=cache_size)
        # True while a refill is running; kept as a readable status flag.
        self.is_refilling = False
        # The lock is what actually prevents concurrent refills, because
        # _refill_cache is also started from background threads.
        self._refill_lock = threading.Lock()

    @property
    def _all_images_key(self) -> str:
        """Redis key of the set that holds every indexed image path."""
        return self.IMAGE_KEY_PREFIX + "all"

    @staticmethod
    def _decode_member(member) -> str:
        """Return a Redis set member as ``str`` (the client may return bytes)."""
        return member.decode('utf-8') if isinstance(member, bytes) else member

    def _get_last_update_time(self) -> float:
        """Return the stored timestamp of the last index update, or 0.0 if absent/invalid."""
        ts = self.redis.get(self.LAST_UPDATE_TIME_KEY)
        if ts:
            try:
                return float(ts)
            except (TypeError, ValueError) as e:
                logger.warning(f"解析最后更新时间失败: {e}")
        return 0.0

    def _set_last_update_time(self, ts: float) -> None:
        """Persist *ts* as the timestamp of the last index update."""
        self.redis.set(self.LAST_UPDATE_TIME_KEY, ts)

    def should_update_index(self) -> bool:
        """Return True when the image tree changed since the last index update.

        A directory's mtime only changes when entries are added to or removed
        from that directory itself, so the root folder alone would miss files
        added to existing sub-folders. Every sub-directory is checked as well.
        Any error while checking is logged and treated as "update needed".
        """
        try:
            last_ts = self._get_last_update_time()
            if os.path.getmtime(self.image_folder) > last_ts:
                return True
            for root, dirs, _files in os.walk(self.image_folder):
                for name in dirs:
                    if os.path.getmtime(os.path.join(root, name)) > last_ts:
                        return True
        except Exception as e:
            logger.error(f"判断图片目录更新时间失败: {e}")
            return True  # On error, default to re-indexing.
        logger.info("图片目录未更新,无需重新索引")
        return False

    def _scan_new_images(self, last_update_ts: float) -> List[str]:
        """Walk the image folder and return paths of readable images newer than *last_update_ts*.

        Only files whose extension is in ``self.image_extensions`` and whose
        mtime is strictly greater than *last_update_ts* are returned; files
        copied in with an older, preserved mtime are therefore skipped.
        """
        new_images = []
        for root, _, files in os.walk(self.image_folder):
            for file in files:
                try:
                    _, ext = os.path.splitext(file)
                    if ext.lower() not in self.image_extensions:
                        continue
                    full_path = os.path.join(root, file)
                    # Only collect files modified after the previous update.
                    if os.path.getmtime(full_path) > last_update_ts and os.access(full_path, os.R_OK):
                        new_images.append(full_path)
                except OSError as e:
                    # A file may vanish or be unreadable mid-scan; skip it.
                    logger.warning(f"处理文件时异常 {file}: {e}")
        return new_images

    def _redis_batch_write(self, keys_values: List[tuple]):
        """Add each ``(key, value)`` pair to its Redis set in a single pipeline round trip."""
        pipeline = self.redis.pipeline()
        for key, value in keys_values:
            pipeline.sadd(key, value)
        pipeline.execute()

    async def update_image_cache(self):
        """Incrementally update the Redis image index without blocking the event loop.

        The directory walk and the Redis writes are blocking calls, so the
        work is delegated to the default executor.
        """
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, self._run_index_update)

    def _run_index_update(self) -> None:
        """Blocking implementation of :meth:`update_image_cache`.

        New images are written to Redis in batches of ``BATCH_SIZE``. The
        last-update timestamp is advanced only when every batch was written,
        so images from a failed batch are picked up again by the next run.
        """
        if not self.should_update_index():
            return
        last_update_ts = self._get_last_update_time()
        # Taken before scanning: files created while the scan is running get
        # an mtime later than this and are found by the next update.
        scan_started = time.time()
        new_images = self._scan_new_images(last_update_ts)
        if not new_images:
            logger.info("无新增图片,无需更新缓存")
            # Still advance the timestamp to avoid rescanning needlessly.
            self._set_last_update_time(scan_started)
            return
        logger.info(f"新增图片数量: {len(new_images)}, 开始写入 Redis 分批")
        batch_size = self.BATCH_SIZE
        # A single fixed set key makes picking a random member trivial.
        redis_key = self._all_images_key
        all_written = True
        for batch_no, start in enumerate(range(0, len(new_images), batch_size), start=1):
            batch = new_images[start:start + batch_size]
            try:
                self._redis_batch_write([(redis_key, path) for path in batch])
            except Exception as e:
                # Keep going with the remaining batches, but remember the failure.
                logger.error(f"Redis 写入失败: {e}")
                all_written = False
            else:
                logger.info(f"写入 Redis 批次 {batch_no} 成功,数量: {len(batch)}")
        if all_written:
            self._set_last_update_time(scan_started)
        else:
            logger.warning("部分批次写入 Redis 失败,保留上次更新时间以便下次重试")

    def get_random_image(self) -> Optional[str]:
        """Return the path of one random indexed image, or None if unavailable."""
        try:
            member = self.redis.srandmember(self._all_images_key)
            if member:
                return self._decode_member(member)
        except Exception as e:
            logger.error(f"获取随机图片失败: {e}")
        return None

    def _load_image_bytes(self, image_path: str) -> Optional[bytes]:
        """Read *image_path* from disk and return its bytes, or None on failure."""
        try:
            with open(image_path, 'rb') as f:
                return f.read()
        except (OSError, ValueError) as e:
            logger.error(f"读取图片文件失败 {image_path}: {e}")
            return None

    def _refill_cache(self, wait: bool = False):
        """Top up the in-memory queue with randomly chosen images.

        Args:
            wait: When False (default) return immediately if another refill
                is already running. When True, block until that refill has
                finished and then top up whatever is still missing.
        """
        if not self._refill_lock.acquire(blocking=wait):
            return
        self.is_refilling = True
        try:
            if len(self.image_bytes_cache) >= self.cache_size:
                return  # Nothing to do; a previous refill already filled it.
            logger.info("开始重新填充图片缓存...")
            redis_key = self._all_images_key
            # Snapshot the queue first: consumers may pop from it concurrently.
            seen = {entry['path'] for entry in list(self.image_bytes_cache)}
            candidates = []
            # Ask for a few more paths than needed in case some files fail to load.
            for _ in range(self.cache_size + 2):
                try:
                    member = self.redis.srandmember(redis_key)
                    if not member:
                        continue
                    path = self._decode_member(member)
                except Exception as e:
                    logger.error(f"获取随机图片路径失败: {e}")
                    continue
                # Skip paths already cached or already picked in this refill.
                if path not in seen:
                    seen.add(path)
                    candidates.append(path)
            loaded_count = 0
            for path in candidates:
                # Fill free slots only: appending to a full bounded deque
                # would silently evict entries that were never served.
                if len(self.image_bytes_cache) >= self.cache_size:
                    break
                image_bytes = self._load_image_bytes(path)
                if image_bytes:
                    self.image_bytes_cache.append({
                        'path': path,
                        'bytes': image_bytes
                    })
                    loaded_count += 1
            logger.info(f"缓存填充完成,新增 {loaded_count} 张图片")
        except Exception as e:
            logger.error(f"填充缓存失败: {e}")
        finally:
            self.is_refilling = False
            self._refill_lock.release()

    def get_cached_image_bytes(self) -> Optional[Dict[str, Any]]:
        """Pop and return the next cached image.

        Returns:
            ``{'path': str, 'bytes': bytes}``, or None when no image could be
            loaded.
        """
        # An empty cache is filled synchronously; if a background refill is
        # already in flight, wait for it instead of reporting "no image".
        if not self.image_bytes_cache:
            self._refill_cache(wait=True)
        try:
            cached_image = self.image_bytes_cache.popleft()
        except IndexError:
            return None
        # When at most one entry is left, refill in a background thread so
        # the caller is not blocked by disk reads.
        if len(self.image_bytes_cache) <= 1:
            threading.Thread(target=self._refill_cache, daemon=True).start()
        return cached_image

    def get_cached_image_count(self) -> int:
        """Return the number of images currently held in the in-memory queue."""
        return len(self.image_bytes_cache)