Files
WechatHookBot/utils/image_processor.py
2026-01-08 09:49:01 +08:00

766 lines
28 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
图片/视频处理模块
提供媒体文件的下载、编码和描述生成:
- 图片下载与 base64 编码
- 表情包下载与编码
- 视频下载与编码
- AI 图片/视频描述生成
使用示例:
from utils.image_processor import ImageProcessor, MediaConfig
config = MediaConfig(
api_url="https://api.openai.com/v1/chat/completions",
api_key="sk-xxx",
model="gpt-4-vision-preview",
)
processor = ImageProcessor(config)
# 下载图片
image_base64 = await processor.download_image(bot, cdnurl, aeskey)
# 生成描述
description = await processor.generate_description(image_base64, "描述这张图片")
"""
from __future__ import annotations
import asyncio
import base64
import io
import json
import uuid
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Optional, TYPE_CHECKING
import aiohttp
from loguru import logger
# 图片处理支持
try:
from PIL import Image
PIL_AVAILABLE = True
except ImportError:
PIL_AVAILABLE = False
logger.warning("[ImageProcessor] Pillow 未安装GIF 转换功能不可用")
# 可选代理支持
try:
from aiohttp_socks import ProxyConnector
PROXY_SUPPORT = True
except ImportError:
PROXY_SUPPORT = False
if TYPE_CHECKING:
pass # bot 类型提示
@dataclass
class MediaConfig:
    """Media-processing settings: vision API, proxy, video analysis and temp directory."""

    # API settings
    api_url: str = "https://api.openai.com/v1/chat/completions"
    api_key: str = ""
    model: str = "gpt-4-vision-preview"
    timeout: int = 120  # total request timeout (seconds) for image descriptions
    max_tokens: int = 1000
    retries: int = 2  # retries after the first attempt for image descriptions
    # Proxy settings
    proxy_enabled: bool = False
    proxy_type: str = "socks5"
    proxy_host: str = "127.0.0.1"
    proxy_port: int = 7890
    proxy_username: str = ""
    proxy_password: str = ""
    # Video-analysis settings
    video_api_url: str = ""
    video_model: str = ""
    video_max_size_mb: int = 20  # downloads larger than this are rejected
    video_timeout: int = 360
    video_max_tokens: int = 8192
    # Temporary download directory (None lets the processor choose)
    temp_dir: Optional[Path] = None

    @classmethod
    def from_dict(cls, config: Dict[str, Any]) -> "MediaConfig":
        """
        Build a MediaConfig from a nested configuration dict.

        Recognised sections: ``api``, ``proxy``, ``image_description`` and
        ``video_recognition``. Missing keys fall back to the dataclass
        defaults. A missing section, or one whose value is ``None`` (e.g. an
        empty YAML/TOML table), is treated as empty. The image model is taken
        from ``image_description.model``, then ``api.model``.

        Args:
            config: The raw configuration mapping (``None`` is treated as ``{}``).

        Returns:
            A populated MediaConfig instance.
        """
        config = config or {}
        # `or {}` also covers sections that are present but set to None.
        api_config = config.get("api") or {}
        proxy_config = config.get("proxy") or {}
        image_desc_config = config.get("image_description") or {}
        video_config = config.get("video_recognition") or {}

        # The dataclass field defaults are the single source of truth.
        defaults = cls()

        return cls(
            api_url=api_config.get("url", defaults.api_url),
            api_key=api_config.get("api_key", defaults.api_key),
            model=image_desc_config.get("model", api_config.get("model", defaults.model)),
            timeout=api_config.get("timeout", defaults.timeout),
            max_tokens=image_desc_config.get("max_tokens", defaults.max_tokens),
            retries=image_desc_config.get("retries", defaults.retries),
            proxy_enabled=proxy_config.get("enabled", defaults.proxy_enabled),
            proxy_type=proxy_config.get("type", defaults.proxy_type),
            proxy_host=proxy_config.get("host", defaults.proxy_host),
            proxy_port=proxy_config.get("port", defaults.proxy_port),
            proxy_username=proxy_config.get("username", defaults.proxy_username),
            proxy_password=proxy_config.get("password", defaults.proxy_password),
            video_api_url=video_config.get("api_url", defaults.video_api_url),
            video_model=video_config.get("model", defaults.video_model),
            video_max_size_mb=video_config.get("max_size_mb", defaults.video_max_size_mb),
            video_timeout=video_config.get("timeout", defaults.video_timeout),
            video_max_tokens=video_config.get("max_tokens", defaults.video_max_tokens),
        )
@dataclass
class MediaResult:
    """Outcome of one media operation: status, base64 payload, description and error."""

    # Whether the operation succeeded.
    success: bool = field(default=False)
    # Encoded media as base64 text.
    data: str = field(default="")
    # AI-generated description, if any.
    description: str = field(default="")
    # Error message when the operation failed.
    error: Optional[str] = field(default=None)
    # Kind of media: "image", "emoji" or "video".
    media_type: str = field(default="image")
class ImageProcessor:
    """
    Image/video media processor.

    Offers one interface for:
    - downloading media and encoding it as base64 ``data:`` URIs
      (images/videos through the bot's CDN download, emoji over plain HTTP);
    - AI descriptions of images and AI analysis of videos;
    - optional Redis caching of encoded media through ``utils.redis_cache``.
    """

    # Prompt used by analyze_video() when the caller supplies none.
    _DEFAULT_VIDEO_PROMPT = (
        "请详细分析这个视频的内容,包括:\n"
        "1. 视频的主要场景和环境\n"
        "2. 出现的人物/物体及其动作\n"
        "3. 视频中的文字、对话或声音(如果有)\n"
        "4. 视频的整体主题或要表达的内容\n"
        "5. 任何值得注意的细节\n"
        "请用客观、详细的方式描述,不要加入主观评价。"
    )

    def __init__(self, config: MediaConfig, temp_dir: Optional[Path] = None):
        """
        Args:
            config: API, proxy and video settings.
            temp_dir: Directory for temporary downloads. Falls back to
                ``config.temp_dir`` and then ``./temp``; created if missing.
        """
        self.config = config
        # Path() also accepts a plain string from callers or config.
        self.temp_dir = Path(temp_dir or config.temp_dir or "temp")
        # parents=True so nested directories such as "data/tmp" work as well.
        self.temp_dir.mkdir(parents=True, exist_ok=True)

    # ------------------------------------------------------------------ helpers

    def _get_proxy_connector(self) -> Optional[Any]:
        """
        Build an aiohttp-socks connector from the proxy settings.

        Returns None when the proxy is disabled, ``aiohttp_socks`` is not
        installed, or the connector cannot be created. An aiohttp session
        closes the connector it owns when it exits, so request a fresh
        connector for every ``ClientSession``.
        """
        cfg = self.config
        if not cfg.proxy_enabled or not PROXY_SUPPORT:
            return None

        # URL schemes are case-insensitive; use the canonical lower-case form.
        scheme = cfg.proxy_type.lower()
        if cfg.proxy_username and cfg.proxy_password:
            auth = f"{cfg.proxy_username}:{cfg.proxy_password}@"
        else:
            auth = ""
        proxy_url = f"{scheme}://{auth}{cfg.proxy_host}:{cfg.proxy_port}"

        try:
            return ProxyConnector.from_url(proxy_url)
        except Exception as e:
            logger.warning(f"[ImageProcessor] 代理配置失败: {e}")
            return None

    @staticmethod
    def _strip_data_uri(value: str) -> str:
        """
        Return the payload after the comma of a ``data:...,<payload>`` URI.

        Strings that do not start with ``data:`` are returned unchanged. A
        ``data:`` URI without a comma yields ``""``.
        """
        if value.startswith("data:"):
            return value.partition(",")[2]
        return value

    @staticmethod
    def _encode_file(path) -> str:
        """Read *path* and return its contents as a base64 ASCII string."""
        return base64.b64encode(Path(path).read_bytes()).decode()

    @staticmethod
    def _remove_quietly(path) -> None:
        """Delete a temporary file on a best-effort basis; a missing file is not an error."""
        try:
            Path(path).unlink()
        except FileNotFoundError:
            pass
        except OSError as e:
            logger.debug(f"[ImageProcessor] 清理临时文件失败: {e}")

    @staticmethod
    async def _wait_for_file(path, attempts: int, interval: float = 0.5) -> bool:
        """
        Poll until *path* exists and is non-empty.

        The file is checked ``attempts`` times with ``interval`` seconds of
        sleep after each check, plus one final check. The wait exists because
        the downloaded file may appear after ``cdn_download`` returns
        (presumably written by another process — TODO confirm).

        Returns:
            True if the file exists with a size greater than 0, else False.
        """
        target = Path(path)

        def ready() -> bool:
            try:
                return target.stat().st_size > 0
            except OSError:
                return False

        for _ in range(attempts):
            if ready():
                return True
            await asyncio.sleep(interval)
        return ready()

    @staticmethod
    def _get_enabled_cache() -> Optional[Any]:
        """Return the shared Redis cache if it is importable and enabled, else None."""
        try:
            from utils.redis_cache import get_cache
            cache = get_cache()
        except Exception as e:
            logger.debug(f"[ImageProcessor] Redis 缓存不可用: {e}")
            return None
        return cache if cache and cache.enabled else None

    def _cache_key(self, use_cache: bool, cdnurl: str, aeskey: Optional[str] = None) -> Optional[str]:
        """
        Return the Redis media key for a download.

        Returns None when caching is off, the cache is unavailable, or no key
        can be generated. With ``aeskey=None`` the key is generated from the
        URL alone (``generate_media_key(cdnurl=...)``), as done for emoji.
        """
        if not use_cache or self._get_enabled_cache() is None:
            return None
        try:
            from utils.redis_cache import RedisCache
            if aeskey is None:
                key = RedisCache.generate_media_key(cdnurl=cdnurl)
            else:
                key = RedisCache.generate_media_key(cdnurl, aeskey)
        except Exception as e:
            logger.debug(f"[ImageProcessor] 生成缓存键失败: {e}")
            return None
        return key or None

    def _cache_lookup(self, media_key: Optional[str], media_type: str, label: str) -> str:
        """Return cached data for *media_key*, or "" on a miss or any cache error."""
        if not media_key:
            return ""
        cache = self._get_enabled_cache()
        if cache is None:
            return ""
        try:
            cached = cache.get_cached_media(media_key, media_type)
        except Exception as e:
            logger.debug(f"[ImageProcessor] 读取{label}缓存失败: {e}")
            return ""
        if not cached:
            return ""
        logger.debug(f"[ImageProcessor] {label}缓存命中: {media_key[:20]}...")
        return cached

    def _cache_store(self, media_key: Optional[str], data: str, media_type: str, ttl: int, label: str) -> None:
        """Store *data* under *media_key* with *ttl*; failures are only logged."""
        if not media_key:
            return
        cache = self._get_enabled_cache()
        if cache is None:
            return
        try:
            cache.cache_media(media_key, data, media_type, ttl=ttl)
            logger.debug(f"[ImageProcessor] {label}已缓存: {media_key[:20]}...")
        except Exception as e:
            logger.debug(f"[ImageProcessor] 缓存{label}失败: {e}")

    @staticmethod
    def _extract_stream_delta(raw_line: bytes) -> str:
        """
        Return the text delta carried by one streamed (SSE) chat-completion line.

        Accepts both ``data: {...}`` and ``data:{...}``. Returns "" for blank
        lines, non-data lines, ``[DONE]``, malformed JSON, or chunks without
        string content.
        """
        line = raw_line.decode("utf-8", errors="replace").strip()
        if not line.startswith("data:"):
            return ""
        chunk = line[len("data:"):].strip()
        if not chunk or chunk == "[DONE]":
            return ""
        try:
            data = json.loads(chunk)
            choices = data.get("choices") or [{}]
            content = (choices[0].get("delta") or {}).get("content")
        except (ValueError, TypeError, AttributeError, LookupError):
            return ""
        return content if isinstance(content, str) else ""

    @staticmethod
    def _extract_video_text(result: Dict[str, Any]) -> str:
        """
        Return the first text part of a ``generateContent`` JSON response.

        Returns "" when the prompt or a candidate was blocked by safety
        filtering, or when no text part is present.
        """
        if not isinstance(result, dict):
            logger.error("[ImageProcessor] 视频分析无有效响应")
            return ""

        feedback = result.get("promptFeedback") or {}
        if feedback.get("blockReason"):
            logger.warning(f"[ImageProcessor] 视频内容被过滤: {feedback.get('blockReason')}")
            return ""

        for candidate in result.get("candidates") or []:
            if candidate.get("finishReason") == "SAFETY":
                logger.warning("[ImageProcessor] 视频响应被安全过滤")
                return ""
            for part in (candidate.get("content") or {}).get("parts", []):
                if "text" in part:
                    text = part["text"]
                    logger.info(f"[ImageProcessor] 视频分析完成,长度: {len(text)}")
                    return text

        logger.error("[ImageProcessor] 视频分析无有效响应")
        return ""

    # ---------------------------------------------------------------- downloads

    async def download_image(
        self,
        bot,
        cdnurl: str,
        aeskey: str,
        use_cache: bool = True,
    ) -> str:
        """
        Download a CDN image through the bot client and return it base64-encoded.

        Args:
            bot: WechatHookClient instance (provides ``cdn_download``).
            cdnurl: CDN URL.
            aeskey: AES key of the CDN file.
            use_cache: Read/write the Redis media cache (TTL 300 s).

        Returns:
            A ``data:image/jpeg;base64,...`` URI, or "" on failure.
        """
        media_key = self._cache_key(use_cache, cdnurl, aeskey)
        cached = self._cache_lookup(media_key, "image", "图片")
        if cached:
            return cached

        logger.debug("[ImageProcessor] 开始下载图片...")
        save_path = self.temp_dir / f"temp_{uuid.uuid4().hex[:8]}.jpg"
        try:
            # Hand the bot an absolute path for the downloaded file.
            target = str(save_path.resolve())
            # Try the medium-size image first (file_type=2), then the original (file_type=1).
            success = await bot.cdn_download(cdnurl, aeskey, target, file_type=2)
            if not success:
                success = await bot.cdn_download(cdnurl, aeskey, target, file_type=1)
            if not success:
                logger.error("[ImageProcessor] CDN 下载失败")
                return ""

            # Wait up to ~10 s for a non-empty file; an empty file is not a valid image.
            if not await self._wait_for_file(save_path, attempts=20):
                logger.error("[ImageProcessor] 图片文件未生成")
                return ""

            base64_result = f"data:image/jpeg;base64,{self._encode_file(save_path)}"
            self._cache_store(media_key, base64_result, "image", 300, "图片")
            return base64_result
        except Exception as e:
            logger.error(f"[ImageProcessor] 下载图片失败: {e}")
            return ""
        finally:
            # Remove the temp file on every path, including errors and early returns.
            self._remove_quietly(save_path)

    async def download_emoji(
        self,
        cdn_url: str,
        max_retries: int = 3,
        use_cache: bool = True,
    ) -> str:
        """
        Download an emoji (sticker) over HTTP and return it base64-encoded.

        Args:
            cdn_url: Emoji CDN URL; ``&amp;`` entities are decoded to ``&``.
            max_retries: Total number of download attempts (at least one is made).
            use_cache: Read/write the Redis media cache (TTL 300 s).

        Returns:
            A ``data:image/gif;base64,...`` URI, or "" on failure.
        """
        # Decode the HTML entity so the URL's query string is valid.
        cdn_url = cdn_url.replace("&amp;", "&")

        media_key = self._cache_key(use_cache, cdn_url)
        cached = self._cache_lookup(media_key, "emoji", "表情包")
        if cached:
            return cached

        logger.debug("[ImageProcessor] 开始下载表情包...")
        attempts = max(1, max_retries)
        last_error = None

        for attempt in range(attempts):
            try:
                # Grow the timeout on each attempt.
                timeout = aiohttp.ClientTimeout(total=30 + attempt * 15)
                # The session closes the connector it owns, so every attempt needs a fresh one.
                connector = self._get_proxy_connector()
                async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
                    async with session.get(cdn_url) as response:
                        if response.status != 200:
                            last_error = f"状态码 {response.status}"
                            logger.warning(f"[ImageProcessor] 表情包下载失败,状态码: {response.status}")
                        else:
                            content = await response.read()
                            if content:
                                image_data = base64.b64encode(content).decode()
                                base64_result = f"data:image/gif;base64,{image_data}"
                                logger.debug(f"[ImageProcessor] 表情包下载成功,大小: {len(content)} 字节")
                                self._cache_store(media_key, base64_result, "emoji", 300, "表情包")
                                return base64_result
                            last_error = "表情包内容为空"
                            logger.warning(f"[ImageProcessor] 表情包内容为空,重试 {attempt + 1}/{attempts}")
            except asyncio.CancelledError:
                raise
            except asyncio.TimeoutError:
                last_error = "请求超时"
                logger.warning(f"[ImageProcessor] 表情包下载超时,重试 {attempt + 1}/{attempts}")
            except aiohttp.ClientError as e:
                last_error = str(e)
                logger.warning(f"[ImageProcessor] 表情包下载网络错误: {e}")
            except Exception as e:
                last_error = str(e)
                logger.warning(f"[ImageProcessor] 表情包下载异常: {e}")

            # Linear back-off between attempts (every failure kind, including empty bodies).
            if attempt < attempts - 1:
                await asyncio.sleep(1 * (attempt + 1))

        logger.error(f"[ImageProcessor] 表情包下载失败,已重试 {attempts} 次: {last_error}")
        return ""

    async def download_video(
        self,
        bot,
        cdnurl: str,
        aeskey: str,
        use_cache: bool = True,
    ) -> str:
        """
        Download a CDN video through the bot client and return it base64-encoded.

        Args:
            bot: WechatHookClient instance (provides ``cdn_download``).
            cdnurl: CDN URL.
            aeskey: AES key of the CDN file.
            use_cache: Read/write the Redis media cache (TTL 600 s).

        Returns:
            A ``data:video/mp4;base64,...`` URI, or "" on failure or when the
            file exceeds ``config.video_max_size_mb``.
        """
        media_key = self._cache_key(use_cache, cdnurl, aeskey)
        cached = self._cache_lookup(media_key, "video", "视频")
        if cached:
            return cached

        logger.info("[ImageProcessor] 开始下载视频...")
        save_path = self.temp_dir / f"video_{uuid.uuid4().hex[:8]}.mp4"
        try:
            # file_type=4 selects the video file.
            success = await bot.cdn_download(cdnurl, aeskey, str(save_path.resolve()), file_type=4)
            if not success:
                logger.error("[ImageProcessor] 视频 CDN 下载失败")
                return ""

            # Wait up to ~15 s for a non-empty file.
            if not await self._wait_for_file(save_path, attempts=30):
                logger.error("[ImageProcessor] 视频文件未生成")
                return ""

            file_size = save_path.stat().st_size
            size_mb = file_size / 1024 / 1024
            logger.info(f"[ImageProcessor] 视频下载完成,大小: {size_mb:.2f} MB")

            max_size_mb = self.config.video_max_size_mb
            if file_size > max_size_mb * 1024 * 1024:
                logger.warning(f"[ImageProcessor] 视频文件过大: {size_mb:.2f} MB > {max_size_mb} MB")
                return ""

            video_base64 = f"data:video/mp4;base64,{self._encode_file(save_path)}"
            self._cache_store(media_key, video_base64, "video", 600, "视频")
            return video_base64
        except Exception as e:
            logger.error(f"[ImageProcessor] 下载视频失败: {e}")
            import traceback
            logger.error(traceback.format_exc())
            return ""
        finally:
            # Remove the temp file on every path, including oversize rejections and errors.
            self._remove_quietly(save_path)

    # ----------------------------------------------------------- AI processing

    def _convert_gif_to_png(self, image_base64: str) -> str:
        """
        Convert a GIF (its first frame) to PNG, flattening transparency onto white.

        Args:
            image_base64: GIF data as a base64 ``data:`` URI (raw base64 is accepted too).

        Returns:
            A ``data:image/png;base64,...`` URI, or the input unchanged if
            Pillow is unavailable or conversion fails.
        """
        if not PIL_AVAILABLE:
            logger.warning("[ImageProcessor] Pillow 未安装,无法转换 GIF")
            return image_base64

        try:
            gif_bytes = base64.b64decode(self._strip_data_uri(image_base64))
            # Image.open exposes the first frame of an animated GIF.
            with Image.open(io.BytesIO(gif_bytes)) as img:
                if img.mode in ("RGBA", "LA", "P"):
                    # Go through RGBA so LA and palette transparency also
                    # provide an alpha mask, then composite onto white.
                    rgba = img.convert("RGBA")
                    frame = Image.new("RGB", rgba.size, (255, 255, 255))
                    frame.paste(rgba, mask=rgba.getchannel("A"))
                else:
                    frame = img.convert("RGB")

            output = io.BytesIO()
            frame.save(output, format="PNG", optimize=True)
            png_bytes = output.getvalue()

            logger.debug(
                f"[ImageProcessor] GIF 已转换为 PNG原大小: {len(gif_bytes)} 字节,新大小: {len(png_bytes)} 字节"
            )
            return f"data:image/png;base64,{base64.b64encode(png_bytes).decode()}"
        except Exception as e:
            logger.error(f"[ImageProcessor] GIF 转换失败: {e}")
            return image_base64

    async def generate_description(
        self,
        image_base64: str,
        prompt: str = "请用一句话简洁地描述这张图片的主要内容。",
        model: Optional[str] = None,
    ) -> str:
        """
        Describe an image with the configured chat-completions vision model.

        The request is streamed and the text deltas are concatenated. Network
        errors and non-200 responses are retried ``config.retries`` times with
        linear back-off.

        Args:
            image_base64: Image as a base64 ``data:`` URI.
            prompt: Instruction sent with the image.
            model: Model override (defaults to ``config.model``).

        Returns:
            The description text, or "" on failure.
        """
        # Most vision APIs reject GIF, so send the first frame as PNG instead.
        if image_base64.startswith("data:image/gif"):
            logger.debug("[ImageProcessor] 检测到 GIF 格式,转换为 PNG...")
            image_base64 = self._convert_gif_to_png(image_base64)

        payload = {
            "model": model or self.config.model,
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        {"type": "image_url", "image_url": {"url": image_base64}},
                    ],
                }
            ],
            "max_tokens": self.config.max_tokens,
            "stream": True,
        }
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.config.api_key}",
        }

        max_retries = max(0, self.config.retries)
        last_error = None
        for attempt in range(max_retries + 1):
            try:
                timeout = aiohttp.ClientTimeout(total=self.config.timeout)
                connector = self._get_proxy_connector()
                async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
                    async with session.post(self.config.api_url, json=payload, headers=headers) as resp:
                        if resp.status != 200:
                            error_text = await resp.text()
                            raise Exception(f"API 返回错误: {resp.status}, {error_text[:200]}")

                        # The response body is a stream of SSE lines.
                        parts = []
                        async for raw_line in resp.content:
                            piece = self._extract_stream_delta(raw_line)
                            if piece:
                                parts.append(piece)

                description = "".join(parts).strip()
                logger.debug(f"[ImageProcessor] 图片描述生成成功: {description[:50]}...")
                return description
            except asyncio.CancelledError:
                raise
            except (aiohttp.ClientError, asyncio.TimeoutError) as e:
                last_error = str(e)
                if attempt < max_retries:
                    logger.warning(f"[ImageProcessor] 图片描述网络错误: {e},重试 {attempt + 1}/{max_retries}")
                    await asyncio.sleep(1 * (attempt + 1))
            except Exception as e:
                last_error = str(e)
                if attempt < max_retries:
                    logger.warning(f"[ImageProcessor] 图片描述生成异常: {e},重试 {attempt + 1}/{max_retries}")
                    await asyncio.sleep(1 * (attempt + 1))

        logger.error(f"[ImageProcessor] 生成图片描述失败,已重试 {max_retries + 1} 次: {last_error}")
        return ""

    async def analyze_video(
        self,
        video_base64: str,
        prompt: Optional[str] = None,
    ) -> str:
        """
        Analyze a video with the configured video model (``:generateContent`` endpoint).

        Gateway errors (502/503/504), timeouts and aiohttp network errors are
        retried up to twice, 5 s apart. Other failures return immediately.

        Args:
            video_base64: MP4 data as a ``data:`` URI or raw base64.
            prompt: Analysis prompt (defaults to ``_DEFAULT_VIDEO_PROMPT``).

        Returns:
            The model's analysis text, or "" on misconfiguration, filtering or failure.
        """
        cfg = self.config
        if not cfg.video_api_url or not cfg.video_model:
            logger.error("[ImageProcessor] 视频分析配置不完整")
            return ""

        # Strip a "data:video/mp4;base64," prefix if present.
        video_data = self._strip_data_uri(video_base64)
        if not video_data:
            logger.error("[ImageProcessor] 视频数据为空")
            return ""

        # Tolerate a trailing slash on the configured base URL.
        full_url = f"{cfg.video_api_url.rstrip('/')}/{cfg.video_model}:generateContent"
        payload = {
            "contents": [
                {
                    "parts": [
                        {"text": prompt or self._DEFAULT_VIDEO_PROMPT},
                        {
                            "inline_data": {
                                "mime_type": "video/mp4",
                                "data": video_data,
                            }
                        },
                    ]
                }
            ],
            "generationConfig": {
                "maxOutputTokens": cfg.video_max_tokens,
            },
        }
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {cfg.api_key}",
        }

        timeout = aiohttp.ClientTimeout(total=cfg.video_timeout)
        max_retries = 2
        retry_delay = 5  # seconds between retries of transient failures

        for attempt in range(max_retries + 1):
            can_retry = attempt < max_retries
            try:
                retry_note = f" (重试 {attempt}/{max_retries})" if attempt > 0 else ""
                logger.info(f"[ImageProcessor] 开始分析视频...{retry_note}")
                async with aiohttp.ClientSession(timeout=timeout) as session:
                    async with session.post(full_url, json=payload, headers=headers) as resp:
                        if resp.status in (502, 503, 504):
                            logger.warning(f"[ImageProcessor] 视频 API 临时错误: {resp.status}")
                            if can_retry:
                                await asyncio.sleep(retry_delay)
                                continue
                            return ""

                        if resp.status != 200:
                            error_text = await resp.text()
                            logger.error(f"[ImageProcessor] 视频 API 错误: {resp.status}, {error_text[:300]}")
                            return ""

                        # Reject non-JSON bodies (e.g. an HTML error page from a proxy).
                        content_type = resp.headers.get("Content-Type", "")
                        if "application/json" not in content_type:
                            error_text = await resp.text()
                            logger.error(
                                f"[ImageProcessor] 视频 API 返回非 JSON 响应: Content-Type={content_type}, Body={error_text[:500]}"
                            )
                            return ""

                        result = await resp.json()

                return self._extract_video_text(result)
            except asyncio.CancelledError:
                raise
            except asyncio.TimeoutError:
                logger.warning(f"[ImageProcessor] 视频分析超时{', 将重试...' if can_retry else ''}")
                if can_retry:
                    await asyncio.sleep(retry_delay)
                    continue
                return ""
            except aiohttp.ClientError as e:
                logger.warning(f"[ImageProcessor] 视频分析网络错误: {e}{', 将重试...' if can_retry else ''}")
                if can_retry:
                    await asyncio.sleep(retry_delay)
                    continue
                return ""
            except Exception as e:
                logger.error(f"[ImageProcessor] 视频分析失败: {e}")
                import traceback
                logger.error(traceback.format_exc())
                return ""

        return ""
# ==================== Convenience functions ====================

# Module-wide default processor, set by get_image_processor(config) or init_image_processor().
_default_processor: Optional[ImageProcessor] = None


def get_image_processor(config: Optional[MediaConfig] = None) -> ImageProcessor:
    """
    Return the module-level default ImageProcessor.

    Args:
        config: When truthy, a new ImageProcessor is built from it and
            installed as the default before being returned.

    Returns:
        The current default processor.

    Raises:
        ValueError: No config was given and no default has been installed yet.
    """
    global _default_processor
    if config:
        _default_processor = ImageProcessor(config)
    elif _default_processor is None:
        raise ValueError("ImageProcessor 未初始化,请先传入配置")
    return _default_processor
def init_image_processor(config_dict: Dict[str, Any], temp_dir: Optional[Path] = None) -> ImageProcessor:
    """
    Build an ImageProcessor from a raw config dict and install it as the module default.

    Args:
        config_dict: Mapping accepted by ``MediaConfig.from_dict``.
        temp_dir: Optional temp-directory override; when truthy it is also
            stored on the resulting config.

    Returns:
        The new processor. Later ``get_image_processor()`` calls without a
        config return this same instance.
    """
    global _default_processor
    media_config = MediaConfig.from_dict(config_dict)
    if temp_dir:
        media_config.temp_dir = temp_dir
    _default_processor = ImageProcessor(media_config, temp_dir)
    return _default_processor
# ==================== Public API ====================
__all__ = [
    "MediaConfig",
    "MediaResult",
    "ImageProcessor",
    "get_image_processor",
    "init_image_processor",
]