970 lines
42 KiB
Python
970 lines
42 KiB
Python
import os
|
||
import time
|
||
import json
|
||
import logging
|
||
import shutil
|
||
import datetime
|
||
from typing import Dict, Any, List, Tuple, Optional
|
||
import threading
|
||
import traceback
|
||
|
||
from wcferry import Wcf
|
||
|
||
from db.connection import DBConnectionManager
|
||
from plugin_common.plugin_interface import PluginStatus
|
||
|
||
try:
|
||
import numpy as np
|
||
import cv2
|
||
from deepface import DeepFace
|
||
from sklearn.cluster import DBSCAN
|
||
except ImportError as e:
|
||
raise ImportError(f"缺少必要的依赖库: {e}。请安装 requirements.txt 中的依赖: pip install -r requirements.txt")
|
||
|
||
from plugin_common.message_plugin_interface import MessagePluginInterface
|
||
from utils.robot_cmd.robot_command import Feature, PermissionStatus, GroupBotManager
|
||
from utils.decorator.plugin_decorators import plugin_stats_decorator
|
||
from db.kid_photo_redis import KidPhotoRedisDB
|
||
|
||
|
||
class FaceAnalyzer:
    """Face analyzer: detects faces in an image and extracts face embeddings.

    Detection uses ``DeepFace.extract_faces`` with the RetinaFace backend and
    embeddings use ``DeepFace.represent`` with the ArcFace model. Age
    estimation is no longer performed: every accepted face is processed.
    """

    # Detections below this confidence are discarded.
    MIN_CONFIDENCE = 0.8
    # Face boxes must be strictly wider and taller than this many pixels
    # (small detections are usually icons or stickers).
    MIN_FACE_SIZE = 60
    # Accepted width / height ratio range of a face box.
    MIN_ASPECT_RATIO = 0.7
    MAX_ASPECT_RATIO = 1.5
    # Accepted range of (face box area / image area).
    MIN_AREA_RATIO = 0.01
    MAX_AREA_RATIO = 0.9

    def __init__(self, kid_age_threshold=14):
        # Kept for backward compatibility; no longer used.
        self.kid_age_threshold = kid_age_threshold
        self.logger = logging.getLogger("Plugin.KidPhotoExtractor.FaceAnalyzer")

    def detect_faces(self, image_path):
        """Detect all plausible faces in an image.

        Args:
            image_path: path of the image file.

        Returns:
            List of face dicts (as produced by ``DeepFace.extract_faces``) that
            pass the confidence / size / aspect-ratio / area filters. Returns an
            empty list when the file is missing or unreadable, or detection fails.
        """
        try:
            if not os.path.exists(image_path):
                self.logger.error(f"图片文件不存在: {image_path}")
                return []

            try:
                img = cv2.imread(image_path)
                if img is None:
                    self.logger.error(f"无法读取图片: {image_path}")
                    return []
                # Image info helps when debugging detection problems.
                self.logger.info(f"图片尺寸: {img.shape}, 类型: {img.dtype}")
            except Exception as e:
                self.logger.error(f"读取图片失败: {image_path}, 错误: {e}")
                return []

            # enforce_detection=False: an image without faces must not raise.
            faces = DeepFace.extract_faces(
                img_path=image_path,
                enforce_detection=False,
                detector_backend='retinaface',
                align=True,
            )

            img_height, img_width = img.shape[:2]
            valid_faces = [
                face for face in faces
                if self._is_plausible_face(face, img_width, img_height)
            ]

            self.logger.info(f"在图片 {image_path} 中检测到 {len(faces)} 个人脸,有效人脸 {len(valid_faces)} 个")
            return valid_faces
        except Exception as e:
            self.logger.error(f"人脸检测失败: {image_path}, 错误: {e}")
            self.logger.error(traceback.format_exc())
            return []

    def _is_plausible_face(self, face, img_width, img_height):
        """Return True if a detection looks like a real face rather than a false positive.

        A face is rejected (and the reason logged) when its confidence is too
        low, it has no ``facial_area``, its box is too small, its aspect ratio
        is implausible, or it covers too little / too much of the image.
        """
        confidence = face.get('confidence') or 0
        if confidence < self.MIN_CONFIDENCE:
            self.logger.info(f"过滤掉低置信度人脸: {confidence}")
            return False

        area = face.get('facial_area')
        if not area:
            return False

        face_width, face_height = area['w'], area['h']
        if not (face_width > self.MIN_FACE_SIZE and face_height > self.MIN_FACE_SIZE):
            self.logger.info(f"过滤掉过小的人脸: {face_width}x{face_height}")
            return False

        aspect_ratio = face_width / face_height
        if not (self.MIN_ASPECT_RATIO <= aspect_ratio <= self.MAX_ASPECT_RATIO):
            self.logger.info(f"过滤掉不合理宽高比的人脸: {aspect_ratio}")
            return False

        face_area_ratio = (face_width * face_height) / (img_width * img_height)
        if not (self.MIN_AREA_RATIO <= face_area_ratio <= self.MAX_AREA_RATIO):
            self.logger.info(f"过滤掉不合理区域比例的人脸: {face_area_ratio}")
            return False

        return True

    @staticmethod
    def _clamp_region(face_area, img_shape):
        """Clip a face box to the image bounds.

        Args:
            face_area: dict with keys 'x', 'y', 'w', 'h' (pixels).
            img_shape: image shape; the first two entries are (height, width).

        Returns:
            ``(x0, y0, x1, y1)`` slice bounds inside the image, or None if the
            clipped box is empty. Clipping matters because negative
            coordinates would otherwise wrap around in numpy slicing.
        """
        img_h, img_w = img_shape[:2]
        x, y = int(face_area['x']), int(face_area['y'])
        w, h = int(face_area['w']), int(face_area['h'])
        x0, y0 = max(0, x), max(0, y)
        x1, y1 = min(img_w, x + w), min(img_h, y + h)
        if x1 <= x0 or y1 <= y0:
            return None
        return x0, y0, x1, y1

    @staticmethod
    def _make_temp_path():
        """Create an empty temporary ``.temp.jpg`` file and return its path.

        Crops are written to the system temp directory rather than next to the
        source photo. This means a crashed run cannot leave crop files inside
        the photo directory, where a later scan would treat them as photos.
        """
        import tempfile

        fd, path = tempfile.mkstemp(prefix="kid_face_", suffix=".temp.jpg")
        os.close(fd)
        return path

    @staticmethod
    def _discard(path):
        """Best-effort removal of a temporary file."""
        try:
            if path and os.path.exists(path):
                os.remove(path)
        except OSError:
            pass

    def _verify_face(self, crop_path, image_path):
        """Re-run strict detection on a cropped face; return True if a face is found."""
        try:
            detection = DeepFace.extract_faces(
                img_path=crop_path,
                enforce_detection=True,  # raises when no face is found
                detector_backend='retinaface',
            )
        except Exception as e:
            # Strict detection failing means the crop is probably not a face.
            self.logger.warning(f"人脸二次验证失败: {e}")
            return False
        if not detection:
            self.logger.info(f"二次验证失败,可能不是真实人脸: {image_path}")
            return False
        return True

    @staticmethod
    def _unwrap_embedding(embedding_result):
        """Extract the raw embedding from ``DeepFace.represent`` output.

        Handles a list of dicts (first entry's 'embedding'), a flat numeric
        vector, a dict with 'embedding', or returns the value unchanged.
        """
        if isinstance(embedding_result, list) and embedding_result:
            first = embedding_result[0]
            if isinstance(first, dict) and 'embedding' in first:
                return first['embedding']
            if isinstance(first, (int, float)):
                # Already a flat vector: keep all of it, not just the first value.
                return embedding_result
            return first
        if isinstance(embedding_result, dict) and 'embedding' in embedding_result:
            return embedding_result['embedding']
        return embedding_result

    def analyze_face(self, image_path, face_area=None):
        """Extract an ArcFace embedding for one face.

        Args:
            image_path: path of the source image.
            face_area: optional dict with 'x', 'y', 'w', 'h'. When given, the
                region is cropped, written to a temporary file and re-verified
                as a face before the embedding is computed.

        Returns:
            ``(face_info, temp_path)``. ``face_info`` is
            ``{'embedding': list[float], 'is_kid': True}`` or None on failure.
            ``temp_path`` is the temporary crop file that the caller must
            delete, or None if no crop file remains.
        """
        temp_path = None
        try:
            if not os.path.exists(image_path):
                self.logger.error(f"图片文件不存在: {image_path}")
                return None, None

            try:
                img = cv2.imread(image_path)
                if img is None:
                    self.logger.error(f"无法读取图片: {image_path}")
                    return None, None

                if face_area:
                    bounds = self._clamp_region(face_area, img.shape)
                    if bounds is None:
                        self.logger.warning(f"人脸区域超出图片范围: {image_path}, {face_area}")
                        return None, None
                    x0, y0, x1, y1 = bounds
                    temp_path = self._make_temp_path()
                    if not cv2.imwrite(temp_path, img[y0:y1, x0:x1]):
                        self.logger.error(f"保存临时裁剪图片失败: {temp_path}")
                        return None, temp_path

                    # Verify with strict detection instead of DeepFace.verify.
                    if not self._verify_face(temp_path, image_path):
                        self._discard(temp_path)
                        return None, None

                    image_path = temp_path
            except Exception as e:
                self.logger.error(f"读取图片失败: {image_path}, 错误: {e}")
                # Hand any crop file back so the caller can clean it up.
                return None, temp_path

            # ArcFace is used for its robustness across age groups.
            embedding_result = DeepFace.represent(
                img_path=image_path,
                model_name='ArcFace',
                enforce_detection=True,  # make sure the input really is a face
                detector_backend='retinaface',
            )

            embedding = self._unwrap_embedding(embedding_result)
            if embedding is None:
                self.logger.error(f"未能提取到人脸特征向量: {image_path}")
                return None, temp_path

            try:
                embedding = [float(x) for x in embedding]
            except (TypeError, ValueError):
                self.logger.error(f"无法将嵌入向量转换为浮点数列表: {image_path}")
                return None, temp_path

            self.logger.info(f"成功提取人脸特征向量: {image_path}")
            return {
                'embedding': embedding,
                'is_kid': True,  # every face is processed; no age check
            }, temp_path

        except Exception as e:
            self.logger.error(f"人脸分析失败: {image_path}, 错误: {e}")
            self.logger.error(traceback.format_exc())
            return None, temp_path

    def is_kid(self, face_info):
        """Return True for any non-empty ``face_info`` (age filtering was removed)."""
        if not face_info:
            return False
        return True
|
||
|
||
|
||
class FaceGrouper:
    """Face grouper: clusters face embeddings into per-person groups with DBSCAN."""

    def __init__(self, eps=0.4, min_samples=4):
        # DBSCAN neighbourhood radius (Euclidean distance between embeddings).
        # NOTE(review): embeddings are clustered without L2 normalisation;
        # confirm that 0.4 matches the embedding model's scale.
        self.eps = eps
        # Minimum number of samples in a neighbourhood to form a core point.
        self.min_samples = min_samples
        self.logger = logging.getLogger("Plugin.KidPhotoExtractor.FaceGrouper")

    @staticmethod
    def _to_vector(emb):
        """Unwrap one embedding from the shapes ``DeepFace.represent()`` may produce.

        Accepts a dict with an 'embedding' key, or a list whose first item is
        such a dict. Any other value is assumed to already be a plain vector
        and is returned unchanged.
        """
        if isinstance(emb, dict) and 'embedding' in emb:
            return emb['embedding']
        if isinstance(emb, list) and emb and isinstance(emb[0], dict) and 'embedding' in emb[0]:
            return emb[0]['embedding']
        return emb

    @staticmethod
    def _assign_noise(embeddings, labels):
        """Resolve DBSCAN noise labels (-1) into concrete group labels.

        Each noise point joins the cluster with the smallest mean Euclidean
        distance to it. Distances are measured against the clusters DBSCAN
        found, so the result does not depend on the order in which noise points
        are visited. If DBSCAN found no cluster at all, every point gets its own
        label rather than all being lumped into one -1 group.

        Args:
            embeddings: 2-D array, one row per face.
            labels: 1-D array of DBSCAN labels aligned with ``embeddings``.

        Returns:
            A new numpy array of labels without -1 entries. The input labels
            are not modified.
        """
        original = np.asarray(labels)
        resolved = original.copy()
        noise_idx = np.flatnonzero(original == -1)
        if noise_idx.size == 0:
            return resolved

        cluster_ids = sorted(set(original.tolist()) - {-1})
        if not cluster_ids:
            resolved[noise_idx] = np.arange(noise_idx.size)
            return resolved

        members = {cid: embeddings[original == cid] for cid in cluster_ids}
        for idx in noise_idx:
            point = embeddings[idx]
            resolved[idx] = min(
                cluster_ids,
                key=lambda cid: np.linalg.norm(members[cid] - point, axis=1).mean(),
            )
        return resolved

    def cluster_faces(self, face_embeddings):
        """Cluster face embeddings and return one group label per face.

        Args:
            face_embeddings: list of embeddings, one per face. Each is either a
                plain vector or one of the ``DeepFace.represent`` output shapes.

        Returns:
            ``list[int]`` of labels aligned with the input. An empty input
            yields ``[]`` and a single face yields ``[0]``. On any failure every
            face is put into group 0.
        """
        if not face_embeddings:
            return []
        if len(face_embeddings) < 2:
            return [0] * len(face_embeddings)

        try:
            embeddings_array = np.array(
                [self._to_vector(emb) for emb in face_embeddings], dtype=np.float64
            )

            if not np.isfinite(embeddings_array).all():
                self.logger.error("特征向量包含无效值(NaN或Inf)")
                embeddings_array = np.nan_to_num(embeddings_array)

            clustering = DBSCAN(
                eps=self.eps, min_samples=self.min_samples, metric='euclidean'
            ).fit(embeddings_array)

            return self._assign_noise(embeddings_array, clustering.labels_).tolist()
        except MemoryError as e:
            self.logger.error(f"聚类过程内存不足: {e}")
            return [0] * len(face_embeddings)  # on failure, put all faces in one group
        except Exception as e:
            self.logger.error(f"人脸聚类失败: {e}")
            self.logger.error(traceback.format_exc())
            return [0] * len(face_embeddings)  # on failure, put all faces in one group
|
||
|
||
|
||
class PhotoClassifier:
    """Photo classifier: creates per-person folders, copies photos and writes reports."""

    def __init__(self):
        self.logger = logging.getLogger("Plugin.KidPhotoExtractor.PhotoClassifier")

    def create_kid_folder(self, base_dir, kid_id):
        """Create (if needed) and return the folder ``<base_dir>/kid_<kid_id>``."""
        kid_folder = os.path.join(base_dir, f"kid_{kid_id}")
        os.makedirs(kid_folder, exist_ok=True)
        return kid_folder

    @staticmethod
    def _unique_dest_path(dest_folder, file_name):
        """Return a path in ``dest_folder`` for ``file_name`` that does not exist yet.

        On collision a ``_<unix timestamp>`` suffix is added to the stem. If
        that name is also taken (e.g. several collisions within the same
        second), ``_<n>`` is appended until a free name is found.
        """
        dest_path = os.path.join(dest_folder, file_name)
        if not os.path.exists(dest_path):
            return dest_path

        stem, ext = os.path.splitext(file_name)
        timestamp = int(time.time())
        candidate = os.path.join(dest_folder, f"{stem}_{timestamp}{ext}")
        counter = 1
        while os.path.exists(candidate):
            candidate = os.path.join(dest_folder, f"{stem}_{timestamp}_{counter}{ext}")
            counter += 1
        return candidate

    def copy_photo(self, src_path, dest_folder, new_name=None):
        """Copy a photo (with metadata, via ``shutil.copy2``) into ``dest_folder``.

        Args:
            src_path: source file path.
            dest_folder: existing destination directory.
            new_name: optional target file name; defaults to the source basename.
                On a name collision the chosen name keeps its stem and gets a
                unique suffix, so existing files are never overwritten.

        Returns:
            True on success, False if the source is missing or copying fails.
        """
        try:
            if not os.path.exists(src_path):
                self.logger.error(f"源文件不存在: {src_path}")
                return False

            file_name = new_name or os.path.basename(src_path)
            dest_path = self._unique_dest_path(dest_folder, file_name)
            shutil.copy2(src_path, dest_path)
            return True
        except Exception as e:
            self.logger.error(f"复制照片失败: {e}")
            return False

    def save_analysis_report(self, output_dir, report_data):
        """Write ``report_data`` as UTF-8 JSON to ``<output_dir>/analysis_report.json``.

        Returns:
            The report path, or None if writing fails.
        """
        try:
            report_path = os.path.join(output_dir, "analysis_report.json")
            with open(report_path, 'w', encoding='utf-8') as f:
                json.dump(report_data, f, ensure_ascii=False, indent=2)
            return report_path
        except Exception as e:
            self.logger.error(f"保存分析报告失败: {e}")
            return None
|
||
|
||
|
||
class KidPhotoExtractorPlugin(MessagePluginInterface):
    """Kid photo extraction plugin.

    Handles ``#``-prefixed chat commands (names configurable under the
    ``KidPhotoExtractor.command`` config key). The commands start a photo
    analysis, view the last result, clear stored analysis data, and show
    analysis times.

    An analysis runs in a daemon thread. It detects faces in a group's images,
    clusters them, and copies each cluster's photos into ``person_<id>``
    folders. These live under an output directory created next to the source
    directory.
    """

    # File extensions treated as photos when scanning a directory.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp')
    # Suffix of temporary face crops; such files are never treated as photos.
    TEMP_SUFFIX = '.temp.jpg'

    @property
    def name(self) -> str:
        return "小朋友照片提取"

    @property
    def version(self) -> str:
        return "0.0.1"

    @property
    def description(self) -> str:
        return "未完成-提供小朋友照片提取和分类功能,基于人脸识别技术"

    @property
    def author(self) -> str:
        return "Trae AI"

    @property
    def command_prefix(self) -> Optional[str]:
        return "#"  # commands start with '#'

    @property
    def commands(self) -> List[str]:
        return self._commands

    def __init__(self):
        super().__init__()
        self._commands = []
        self.face_analyzer = None
        self.face_grouper = None
        self.photo_classifier = None
        self.analysis_tasks = {}  # group_key -> task state ("running", "start_time", ...)
        self.db_manager = None
        self.kid_photo_db = None

    def initialize(self, context: Dict[str, Any]) -> bool:
        """Read context and config, then build the analysis components.

        Returns:
            Always True. When the plugin is disabled in config, the face
            components are not created.
        """
        self.LOG = logging.getLogger(f"Plugin.{self.name}")
        self.LOG.info(f"正在初始化 {self.name} 插件...")

        # Keep references to context objects.
        self.wcf = context.get("wcf")
        self.event_system = context.get("event_system")
        self.message_util = context.get("message_util")
        self.db_manager = DBConnectionManager.get_instance()

        if self.db_manager:
            self.kid_photo_db = KidPhotoRedisDB(self.db_manager)
        else:
            self.LOG.warning("数据库管理器未提供,将无法使用Redis功能")

        plugin_config = self._config.get("KidPhotoExtractor", {})
        self._commands = plugin_config.get("command",
                                           ["开始分析照片", "查看照片分析", "清理照片分析",
                                            "照片分析时间"])
        self.command_format = plugin_config.get("command-format",
                                                "使用 #开始分析照片 [目录路径] 开始分析")
        self.enable = plugin_config.get("enable", True)

        # TODO: the feature is still incomplete; components are built only when enabled.
        if not self.enable:
            self.LOG.info(f"[{self.name}] 插件已禁用,跳过组件初始化")
            return True

        self.face_analyzer = FaceAnalyzer()
        self.face_grouper = FaceGrouper()
        self.photo_classifier = PhotoClassifier()

        self.LOG.info(f"[{self.name}] 插件初始化完成,指令:{self._commands}")
        return True

    def start(self) -> bool:
        """Mark the plugin as running."""
        self.LOG.info(f"[{self.name}] 插件已启动")
        self.status = PluginStatus.RUNNING
        return True

    def stop(self) -> bool:
        """Mark the plugin as stopped (running analysis threads are daemon threads)."""
        self.LOG.info(f"[{self.name}] 插件已停止")
        self.status = PluginStatus.STOPPED
        return True

    def can_process(self, message: Dict[str, Any]) -> bool:
        """Return True if the message is one of this plugin's prefixed commands."""
        if not self.enable:
            return False

        content = str(message.get("content", "")).strip()
        if not content.startswith(self.command_prefix):
            return False

        parts = content[len(self.command_prefix):].split()
        return bool(parts) and parts[0] in self._commands

    @plugin_stats_decorator(plugin_name="小朋友照片提取")
    def process_message(self, message: Dict[str, Any]) -> Tuple[bool, Optional[str]]:
        """Check group permission and dispatch the command to its handler."""
        content = str(message.get("content", "")).strip()
        self.LOG.info(f"插件执行: {self.name}:{content}")

        # split() (not split(" ")) so repeated spaces do not create empty arguments.
        command_parts = content[len(self.command_prefix):].strip().split()
        command = command_parts[0] if command_parts else ""

        sender = message.get("sender")
        roomid = message.get("roomid", "")
        wcf: Wcf = message.get("wcf")
        gbm: GroupBotManager = message.get("gbm")

        if roomid and gbm.get_group_permission(roomid, Feature.KID_PHOTO_EXTRACT) == PermissionStatus.DISABLED:
            return False, "没有权限"

        if command == "开始分析照片":
            return self._handle_start_analysis(command_parts, wcf, sender, roomid, gbm)
        elif command == "查看照片分析":
            return self._handle_view_analysis(wcf, sender, roomid, gbm)
        elif command == "清理照片分析":
            return self._handle_clean_analysis(wcf, sender, roomid, gbm)
        elif command == "照片分析时间":
            return self._handle_analysis_time(wcf, sender, roomid, gbm)
        else:
            wcf.send_text(f"❌未知命令!\n{self.command_format}",
                          (roomid if roomid else sender), sender)
            return True, "未知命令"

    # ---- helpers ---------------------------------------------------------

    @staticmethod
    def _project_root():
        """Directory three levels above this file, treated as the project root."""
        return os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

    def _images_base_dir(self):
        """Default image storage root (``<project root>/static/images``)."""
        return os.path.join(self._project_root(), "static", "images")

    def _is_task_running(self, group_key):
        """Return True if an analysis task for ``group_key`` is in progress."""
        return self.analysis_tasks.get(group_key, {}).get("running", False)

    def _iter_image_files(self, source_dir):
        """Yield image file paths under ``source_dir`` recursively, skipping temporary face crops."""
        for root, _, files in os.walk(source_dir):
            for file in files:
                lower = file.lower()
                if lower.endswith(self.IMAGE_EXTENSIONS) and not lower.endswith(self.TEMP_SUFFIX):
                    yield os.path.join(root, file)

    @staticmethod
    def _finalize_result(result, start_time):
        """Stamp ``end_time`` and ``duration`` (seconds since ``start_time``) onto ``result``."""
        result["end_time"] = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        result["duration"] = time.time() - start_time

    # ---- command handlers ------------------------------------------------

    def _handle_start_analysis(self, command_parts, wcf, sender, roomid, gbm):
        """Validate arguments and start a background analysis task.

        Usage: ``#开始分析照片 [全量] [目录路径]``. Without a path, a group chat
        analyses its default image directory; private chats must give a path.
        """
        target = roomid if roomid else sender
        group_key = roomid or sender

        if self._is_task_running(group_key):
            wcf.send_text("⚠️已有分析任务正在进行,请等待完成后再试", target, sender)
            return True, "任务已在进行"

        is_full = False
        if len(command_parts) > 1 and command_parts[1].lower() == "全量":
            is_full = True
            command_parts.pop(1)  # drop the "全量" flag so the rest is the path

        if len(command_parts) > 1:
            source_dir = " ".join(command_parts[1:])
            if not self._is_safe_path(source_dir):
                wcf.send_text("⚠️指定的路径不安全或包含非法字符", target, sender)
                return True, "路径不安全"
        elif roomid:
            # Same image layout as message_to_db.py: static/images/<roomid>
            source_dir = os.path.join(self._images_base_dir(), roomid)
        else:
            wcf.send_text("⚠️当前版本仅支持群聊图片分析", target, sender)
            return True, "不支持私聊"

        # Normalise so a trailing separator cannot make dirname() return the
        # source dir itself (which would place the output folder inside it).
        source_dir = os.path.normpath(source_dir)

        if not os.path.isdir(source_dir):
            wcf.send_text(f"❌目录不存在: {source_dir}", target, sender)
            return True, "目录不存在"

        if next(self._iter_image_files(source_dir), None) is None:
            wcf.send_text(f"❌目录中没有图片文件: {source_dir}", target, sender)
            return True, "没有图片文件"

        output_dir = os.path.join(os.path.dirname(source_dir), f"kid_photos_{group_key}")
        try:
            os.makedirs(output_dir, exist_ok=True)
        except Exception as e:
            wcf.send_text(f"❌创建输出目录失败: {str(e)}", target, sender)
            return True, "创建目录失败"

        analysis_type = "全量" if is_full else "增量"
        wcf.send_text(f"✅开始{analysis_type}分析照片,源目录: {source_dir}\n分析结果将保存到: {output_dir}", target,
                      sender)

        self.analysis_tasks[group_key] = {
            "running": True,
            "start_time": time.time(),
            "source_dir": source_dir,
            "output_dir": output_dir,
            "is_full": is_full,
        }

        thread = threading.Thread(
            target=self._run_analysis_task,
            args=(group_key, source_dir, output_dir, wcf, target, sender),
        )
        thread.daemon = True
        thread.start()

        return True, f"开始{analysis_type}分析"

    def _is_safe_path(self, path):
        """Return True if a user-supplied path is acceptable to analyse.

        The path must not contain shell-like or traversal tokens. After
        resolving symlinks and relative components, it must lie inside the
        project root. Containment is checked with ``os.path.commonpath`` rather
        than a string prefix, so e.g. ``/proj_evil`` is not accepted for
        ``/proj``.
        """
        suspicious_chars = ['..', '~', '`', '$', '|', ';', '&', '*', '>', '<', '"', "'"]
        if any(token in path for token in suspicious_chars):
            return False

        base_dir = os.path.realpath(self._project_root())
        resolved = os.path.realpath(path)
        try:
            return os.path.commonpath([base_dir, resolved]) == base_dir
        except ValueError:
            # e.g. paths on different drives on Windows
            return False

    def _filter_incremental(self, group_key, image_files, wcf, target, sender):
        """Keep only photos not processed before or modified since the last run.

        Reports the previous run's info and the remaining count to the chat.
        """
        processed_photos = self.kid_photo_db.get_processed_photos(group_key) or set()
        last_process_time = self.kid_photo_db.get_last_process_time(group_key)

        if last_process_time:
            wcf.send_text(
                f"📊上次处理时间: {datetime.datetime.fromtimestamp(last_process_time).strftime('%Y-%m-%d %H:%M:%S')}\n已处理照片数: {len(processed_photos)}",
                target, sender)
        else:
            wcf.send_text("⚠️未找到上次处理记录,将执行首次完整分析", target, sender)

        if not processed_photos:
            return image_files

        filtered = [
            img_path for img_path in image_files
            if img_path not in processed_photos
            or (last_process_time and os.path.getmtime(img_path) > last_process_time)
        ]
        wcf.send_text(f"📊本次需要处理的新增/修改照片数: {len(filtered)}", target, sender)
        return filtered

    def _copy_person_photos(self, group_key, output_dir, person_id, faces):
        """Copy each distinct source photo of one face group into ``person_<id>``.

        Returns:
            ``{"photo_count": int, "photos": [basename, ...]}`` for the report.
        """
        person_key = f"person_{person_id}"
        person_folder = os.path.join(output_dir, person_key)
        os.makedirs(person_folder, exist_ok=True)

        copied_photos = []
        seen_paths = set()  # a photo may contain several faces of the same person
        for image_path, _ in faces:
            if ".temp." in image_path or image_path in seen_paths:
                continue
            seen_paths.add(image_path)

            if self.photo_classifier.copy_photo(image_path, person_folder):
                copied_photos.append(os.path.basename(image_path))
                if self.kid_photo_db:
                    self.kid_photo_db.save_photo_mapping(group_key, person_key, image_path)

        return {"photo_count": len(copied_photos), "photos": copied_photos}

    def _run_analysis_task(self, group_key, source_dir, output_dir, wcf, target, sender):
        """Background job: detect faces, cluster them and copy photos per person.

        Progress and the final summary are sent to ``target`` via ``wcf``.
        Whatever happens, temporary face crops are removed and the task is
        marked as no longer running.
        """
        import gc

        start_time = time.time()
        temp_files = []  # temporary face crops returned by FaceAnalyzer

        try:
            is_full = self.analysis_tasks[group_key].get("is_full", False)
            self.LOG.info(f"开始{'全量' if is_full else '增量'}分析任务: {source_dir}")
            wcf.send_text("🔍正在分析照片,请稍候...", target, sender)

            result = {
                "start_time": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                "total_images": 0,
                "processed_images": 0,
                "total_faces": 0,
                "face_groups": 0,  # number of faces with a usable embedding
                "groups": 0,
                "persons": {},
                "is_full": is_full,
                "output_dir": output_dir,  # lets the view command find this run's output
            }

            image_files = list(self._iter_image_files(source_dir))
            result["total_images"] = len(image_files)
            if not image_files:
                wcf.send_text("⚠️未找到任何图片文件", target, sender)
                return

            if not is_full and self.kid_photo_db:
                image_files = self._filter_incremental(group_key, image_files, wcf, target, sender)
                if not image_files:
                    wcf.send_text("✅没有新增或修改的照片,无需分析", target, sender)
                    return

            last_progress_time = time.time()
            progress_interval = 10  # seconds between progress messages

            all_faces = []     # embeddings of every accepted face
            face_images = []   # source photo of each embedding
            face_regions = []  # facial_area of each embedding
            newly_processed_photos = []

            batch_size = 50  # photos per batch; progress is persisted after each batch
            total = len(image_files)

            for start_idx in range(0, total, batch_size):
                batch_images = image_files[start_idx:start_idx + batch_size]

                for offset, image_path in enumerate(batch_images):
                    overall_idx = start_idx + offset
                    try:
                        for face in self.face_analyzer.detect_faces(image_path):
                            face_region = face.get('facial_area', None)
                            face_info, temp_file = self.face_analyzer.analyze_face(image_path, face_region)
                            if temp_file:
                                temp_files.append(temp_file)

                            if face_info and face_info.get('embedding') is not None:
                                all_faces.append(face_info['embedding'])
                                face_images.append(image_path)
                                face_regions.append(face_region)
                                result["face_groups"] += 1

                            result["total_faces"] += 1

                        result["processed_images"] += 1
                        newly_processed_photos.append(image_path)

                        current_time = time.time()
                        if current_time - last_progress_time > progress_interval:
                            progress = (overall_idx + 1) / len(image_files) * 100
                            wcf.send_text(f"📊分析进度: {progress:.1f}% ({overall_idx + 1}/{len(image_files)})", target, sender)
                            last_progress_time = current_time
                    except Exception as e:
                        self.LOG.error(f"处理图片失败: {image_path}, 错误: {e}")
                        continue

                # Persist progress after every batch so a crash does not lose everything.
                if newly_processed_photos and self.kid_photo_db:
                    self.kid_photo_db.save_processed_photos(group_key, newly_processed_photos)
                    newly_processed_photos = []

                # Crops are no longer needed once their embeddings are computed.
                self._cleanup_temp_files(temp_files)
                temp_files.clear()
                gc.collect()

            # Every batch is flushed inside the loop, so recording the run time
            # must not depend on a pending batch.
            if self.kid_photo_db:
                self.kid_photo_db.save_last_process_time(group_key)

            if not all_faces:
                wcf.send_text("⚠️未检测到任何人脸", target, sender)
                self._finalize_result(result, start_time)
                self._save_analysis_result(group_key, result)
                return

            wcf.send_text("🧩正在对人脸进行分组...", target, sender)
            cluster_labels = self.face_grouper.cluster_faces(all_faces)

            face_groups = {}
            for label, image_path, region in zip(cluster_labels, face_images, face_regions):
                face_groups.setdefault(label, []).append((image_path, region))
            result["groups"] = len(face_groups)

            wcf.send_text(f"📁正在创建分类文件夹,共有{len(face_groups)}个人...", target, sender)
            for person_id, faces in face_groups.items():
                result["persons"][f"person_{person_id}"] = self._copy_person_photos(
                    group_key, output_dir, person_id, faces)

            # Stamp end time/duration before writing the report so it contains them.
            self._finalize_result(result, start_time)
            self.photo_classifier.save_analysis_report(output_dir, result)
            self._save_analysis_result(group_key, result)

            summary = self._generate_analysis_summary(result, output_dir)
            wcf.send_text(summary, target, sender)

        except MemoryError as e:
            self.LOG.error(f"分析任务内存不足: {e}")
            wcf.send_text("❌分析过程中内存不足,请减少照片数量或分批处理", target, sender)
        except Exception as e:
            self.LOG.error(f"分析任务出错: {e}")
            self.LOG.error(traceback.format_exc())
            wcf.send_text(f"❌分析过程中出错: {str(e)}", target, sender)
        finally:
            self._cleanup_temp_files(temp_files)
            task = self.analysis_tasks.get(group_key)
            if task is not None:
                task["running"] = False
            gc.collect()

    def _cleanup_temp_files(self, temp_files):
        """Delete every existing file in ``temp_files``; log failures and keep going."""
        for temp_file in temp_files:
            try:
                if os.path.exists(temp_file):
                    os.remove(temp_file)
                    self.LOG.info(f"已删除临时文件: {temp_file}")
            except Exception as e:
                self.LOG.error(f"删除临时文件失败: {temp_file}, 错误: {e}")

    def _save_analysis_result(self, group_key, result):
        """Store the last analysis time and the result via ``kid_photo_db``."""
        try:
            if self.kid_photo_db:
                self.kid_photo_db.save_last_analysis_time(group_key)
                self.kid_photo_db.save_analysis_result(group_key, result)
                self.LOG.info(f"已保存分析结果: {group_key}")
            else:
                self.LOG.error("数据库未初始化")
        except Exception as e:
            self.LOG.error(f"保存分析结果失败: {e}")

    def _get_last_analysis_result(self, group_key):
        """Return the stored last analysis result, or None if unavailable."""
        try:
            if self.kid_photo_db:
                return self.kid_photo_db.get_last_analysis_result(group_key)
            self.LOG.error("数据库未初始化")
            return None
        except Exception as e:
            self.LOG.error(f"获取分析结果失败: {e}")
            return None

    def _get_last_analysis_time(self, group_key):
        """Return the stored last analysis timestamp, or None if unavailable."""
        try:
            if self.kid_photo_db:
                return self.kid_photo_db.get_last_analysis_time(group_key)
            self.LOG.error("数据库未初始化")
            return None
        except Exception as e:
            self.LOG.error(f"获取分析时间失败: {e}")
            return None

    def _handle_clean_analysis(self, wcf, sender, roomid, gbm):
        """Clear stored analysis data unless a task is currently running."""
        target = roomid if roomid else sender
        group_key = roomid or sender

        if self._is_task_running(group_key):
            wcf.send_text("⚠️当前有分析任务正在进行,无法清理数据", target, sender)
            return True, "任务进行中"

        if not self.kid_photo_db:
            wcf.send_text("⚠️数据库未初始化,无法清理数据", target, sender)
            return True, "数据库未初始化"

        if self.kid_photo_db.clear_analysis_data(group_key):
            wcf.send_text("✅已清理所有照片分析数据", target, sender)
            return True, "清理成功"

        wcf.send_text("❌清理数据失败", target, sender)
        return True, "清理失败"

    def _handle_view_analysis(self, wcf, sender, roomid, gbm):
        """Send a summary of the last stored analysis result."""
        target = roomid if roomid else sender
        group_key = roomid or sender

        result = self._get_last_analysis_result(group_key)
        if not result:
            wcf.send_text("⚠️未找到分析结果,请先执行照片分析", target, sender)
            return True, "无分析结果"

        # Prefer the directory recorded by the run (it may have used a custom
        # source directory); older results fall back to the default location.
        output_dir = result.get("output_dir") or os.path.join(
            self._images_base_dir(), f"kid_photos_{group_key}")
        summary = self._generate_analysis_summary(result, output_dir)
        wcf.send_text(summary, target, sender)

        return True, "查看分析结果"

    def _handle_analysis_time(self, wcf, sender, roomid, gbm):
        """Send the last analysis time and last processing time, if any."""
        target = roomid if roomid else sender
        group_key = roomid or sender

        last_time = self._get_last_analysis_time(group_key)
        last_process_time = None
        if self.kid_photo_db:
            last_process_time = self.kid_photo_db.get_last_process_time(group_key)

        if not last_time and not last_process_time:
            wcf.send_text("⚠️未找到分析记录,请先执行照片分析", target, sender)
            return True, "无分析记录"

        time_info = "📊照片分析时间信息:\n"
        if last_time:
            time_info += f"最后分析时间: {datetime.datetime.fromtimestamp(last_time).strftime('%Y-%m-%d %H:%M:%S')}\n"
        if last_process_time:
            time_info += f"最后处理时间: {datetime.datetime.fromtimestamp(last_process_time).strftime('%Y-%m-%d %H:%M:%S')}"

        wcf.send_text(time_info, target, sender)
        return True, "查询分析时间"

    def _generate_analysis_summary(self, result, output_dir):
        """Build the human-readable summary text for an analysis result dict."""
        summary = "📊人脸照片分析结果:\n\n"

        summary += f"📷 总照片数: {result.get('total_images', 0)}\n"
        summary += f"👤 处理照片数: {result.get('processed_images', 0)}\n"
        summary += f"😊 检测到的人脸: {result.get('total_faces', 0)}\n"
        summary += f"👪 人脸分组: {result.get('groups', 0)}\n\n"

        persons = result.get('persons', {})
        if persons:
            summary += "🧒 人物照片统计:\n"
            for person_id, person_info in persons.items():
                summary += f" - {person_id}: {person_info.get('photo_count', 0)}张照片\n"

        start_time = result.get('start_time', '')
        end_time = result.get('end_time', '')
        # Stored results may come back with a non-numeric duration; don't crash on it.
        try:
            duration = float(result.get('duration', 0) or 0)
        except (TypeError, ValueError):
            duration = 0.0

        summary += f"\n⏱️ 开始时间: {start_time}\n"
        summary += f"⏱️ 结束时间: {end_time}\n"
        summary += f"⏱️ 耗时: {duration:.2f}秒\n"

        summary += f"\n📁 照片已保存到: {output_dir}\n"

        is_full = result.get('is_full', False)
        summary += f"📝 分析类型: {'全量' if is_full else '增量'}"

        return summary
|