283 lines
12 KiB
Python
283 lines
12 KiB
Python
# -*- coding: utf-8 -*-
|
||
from __future__ import annotations
|
||
|
||
import inspect
|
||
import asyncio
|
||
from datetime import datetime, timedelta
|
||
from typing import Any, Awaitable, Callable, Dict, List
|
||
|
||
from loguru import logger
|
||
|
||
from db.system_job_db import SystemJobDBOperator
|
||
from utils.decorator.async_job import async_job
|
||
|
||
|
||
def get_system_job_definitions(robot) -> List[Dict[str, Any]]:
    """Return the built-in system job definitions (job key -> business handler).

    This table only maintains the binding between a job key and the business
    function that implements it. Per the original note, schedule times and the
    enabled state are read from the ``t_system_jobs`` database table; the
    trigger fields listed here are the defaults that accompany each job.

    Args:
        robot: The robot instance whose methods/components back the handlers.

    Returns:
        A list of dicts with the keys ``job_key``, ``name``, ``description``,
        ``trigger_type``, ``trigger_config`` and ``handler``.
    """

    def _define(job_key, name, description, trigger_type, trigger_config, handler):
        # Keep the key order stable: job_key, name, description, trigger_*, handler.
        return {
            "job_key": job_key,
            "name": name,
            "description": description,
            "trigger_type": trigger_type,
            "trigger_config": trigger_config,
            "handler": handler,
        }

    definitions: List[Dict[str, Any]] = []

    # Daily flush of message counters.
    definitions.append(
        _define(
            "message_count_to_db",
            "消息计数入库",
            "每天 02:30 将 Redis 消息计数写入 SQLite",
            "at_times",
            {"time_list": ["02:30"]},
            robot.message_count_to_db,
        )
    )

    # Daily login health check, delegated to the provider.
    definitions.append(
        _define(
            "login_check",
            "登录状态巡检",
            "每天 14:43 执行登录二次校验",
            "at_times",
            {"time_list": ["14:43"]},
            _build_login_check_handler(robot),
        )
    )

    # Periodic retry of images/emojis that are still pending download.
    definitions.append(
        _define(
            "process_pending_images",
            "待下载图片补偿处理",
            "每 5 分钟处理一次待下载图片/表情,避免数据库锁竞争",
            "every_seconds",
            {"seconds": 300},
            _build_process_pending_images_handler(robot),
        )
    )

    # Hourly incremental contact-avatar cache sync.
    definitions.append(
        _define(
            "contact_avatar_cache_sync",
            "联系人头像缓存同步",
            "每小时扫描一次联系人头像差异并增量下载,避免启动阶段批量拉头像",
            "every_seconds",
            {"seconds": 3600},
            _build_contact_avatar_cache_sync_handler(robot),
        )
    )

    return definitions
|
||
|
||
def _build_process_pending_images_handler(robot) -> Callable[[], Awaitable[None]]:
    """Build the async handler that processes pending image downloads.

    The returned coroutine function does nothing when ``robot`` has no truthy
    ``message_storage`` attribute; otherwise it awaits
    ``message_storage.process_pending_images(minutes_ago=10, batch_size=20)``.
    """

    async def _handler():
        storage = getattr(robot, "message_storage", None)
        if not storage:
            # Storage is missing or not initialised: nothing to process.
            return
        await storage.process_pending_images(minutes_ago=10, batch_size=20)

    return _handler
|
||
|
||
|
||
def _build_contact_avatar_cache_sync_handler(robot) -> Callable[[], Awaitable[None]]:
    """Build the async handler that runs the scheduled contact-avatar cache sync.

    The returned coroutine function is a no-op when ``robot`` has no truthy
    ``contact_manager``. Otherwise it runs
    ``contact_manager.run_scheduled_avatar_cache_sync("system_job_hourly")``
    in a worker thread via ``asyncio.to_thread``.
    """

    async def _handler():
        manager = getattr(robot, "contact_manager", None)
        if manager:
            # Per the original author's note, the avatar sync performs blocking
            # IO (requests downloads and local file writes):
            #   1. blocking IO should not run directly on the scheduler's event loop;
            #   2. offloading it to a thread keeps other async jobs schedulable;
            #   3. ContactManager is said to skip a run while one is already in
            #      progress, guarding against re-entry (not visible in this file).
            sync_call = manager.run_scheduled_avatar_cache_sync
            await asyncio.to_thread(sync_call, "system_job_hourly")

    return _handler
|
||
|
||
|
||
def _build_login_check_handler(robot) -> Callable[[], Awaitable[bool]]:
    """Build the async handler for the ``login_check`` system job.

    The returned coroutine function resolves to ``False`` (after logging an
    info message) when ``robot.ipad_bot`` is missing/falsy or when it exposes
    no callable ``run_login_health_check``. Otherwise it calls that entry
    point, awaits the result if it is awaitable, and returns it as ``bool``.
    """

    async def _handler() -> bool:
        provider = getattr(robot, "ipad_bot", None)
        if not provider:
            logger.info("系统任务 login_check 已跳过:wechat provider 尚未初始化")
            return False

        health_check = getattr(provider, "run_login_health_check", None)
        if not callable(health_check):
            logger.info("系统任务 login_check 已跳过:当前 provider 未暴露登录巡检能力")
            return False

        # The system-job layer only talks to the provider's unified entry point:
        #   1. whether a second login verification is needed is the provider's call;
        #   2. Robot keeps no legacy methods tied to server-version differences;
        #   3. a new provider only has to implement its own health check to plug
        #      into the existing scheduling chain.
        outcome = health_check()
        # Support both synchronous and asynchronous provider implementations.
        outcome = (await outcome) if inspect.isawaitable(outcome) else outcome
        return bool(outcome)

    return _handler
|
||
|
||
|
||
class SystemJobLoader:
    """Load system-job schedules from the database and register them with ``async_job``.

    Handlers come from ``get_system_job_definitions``; trigger settings, names
    and the enabled flag come from the rows returned by the injected database
    operator.
    """

    # Strong references to compensation tasks scheduled on an already-running
    # event loop. The loop only keeps weak references to tasks, so a task that
    # nobody else references may be garbage-collected before it finishes.
    _background_tasks: set = set()

    def __init__(self, robot, system_job_db: SystemJobDBOperator):
        """Bind the robot and DB operator and index job definitions by ``job_key``."""
        self.robot = robot
        self.db = system_job_db
        self._job_defs = {item["job_key"]: item for item in get_system_job_definitions(robot)}
        self._registered_job_ids: List[str] = []
        # Small tolerance window so that clock skew does not make a job that
        # has just run look like a missed run.
        self._compensation_tolerance_seconds = 120

    @staticmethod
    def _parse_hhmm(value: Any):
        """Parse an ``HH:MM`` value into a ``datetime.time``; return ``None`` if invalid."""
        try:
            return datetime.strptime(str(value), "%H:%M").time()
        except (TypeError, ValueError):
            return None

    @staticmethod
    def _latest_expected_run_before_now(trigger_type: str, trigger_config: Dict[str, Any], now: datetime) -> datetime | None:
        """Compute the most recent time at or before ``now`` at which the job should have run.

        Only used to decide whether a missed run needs compensation; it does
        not replace the real scheduling done by ``async_job``.

        Args:
            trigger_type: One of ``every_seconds``, ``at_times``,
                ``every_weekday_time`` / ``every_week_time`` or
                ``every_month_last_day_time``.
            trigger_config: Settings for the trigger (``seconds``,
                ``time_list``, ``weekday`` + ``time_str``, or ``time_str``).
                ``None`` is treated as an empty config.
            now: Reference time (naive ``datetime``, as produced by ``datetime.now()``).

        Returns:
            The latest expected run time, or ``None`` when the trigger type is
            unknown or its configuration is unusable.
        """
        cfg = trigger_config or {}

        if trigger_type == "every_seconds":
            seconds = int(cfg.get("seconds") or 0)
            if seconds <= 0:
                return None
            return now - timedelta(seconds=seconds)

        if trigger_type == "at_times":
            candidates = []
            for text in cfg.get("time_list") or []:
                tm = SystemJobLoader._parse_hhmm(text)
                if tm is None:
                    # Skip malformed entries instead of failing the whole job.
                    continue
                dt = datetime.combine(now.date(), tm)
                if dt > now:
                    # Today's slot has not arrived yet: use yesterday's.
                    dt -= timedelta(days=1)
                candidates.append(dt)
            return max(candidates) if candidates else None

        if trigger_type in ("every_weekday_time", "every_week_time"):
            try:
                weekday = int(cfg.get("weekday"))
            except (TypeError, ValueError):
                return None
            tm = SystemJobLoader._parse_hhmm(cfg.get("time_str") or "")
            if tm is None:
                return None
            days_ago = (now.weekday() - weekday + 7) % 7
            dt = datetime.combine((now - timedelta(days=days_ago)).date(), tm)
            if dt > now:
                # Same weekday but later today: the last run was a week ago.
                dt -= timedelta(days=7)
            return dt

        if trigger_type == "every_month_last_day_time":
            tm = SystemJobLoader._parse_hhmm(cfg.get("time_str") or "")
            if tm is None:
                return None
            first_of_this_month = datetime(now.year, now.month, 1)
            if now.month == 12:
                first_of_next_month = datetime(now.year + 1, 1, 1)
            else:
                first_of_next_month = datetime(now.year, now.month + 1, 1)
            # Last day of the current month.
            dt = datetime.combine((first_of_next_month - timedelta(days=1)).date(), tm)
            if dt > now:
                # Not reached yet: fall back to the last day of the previous
                # month (the day before the 1st of this month, in any month).
                dt = datetime.combine((first_of_this_month - timedelta(days=1)).date(), tm)
            return dt

        return None

    @staticmethod
    def _on_background_task_done(task) -> None:
        """Release a finished background task and surface its failure, if any."""
        SystemJobLoader._background_tasks.discard(task)
        if task.cancelled():
            return
        # Retrieving the exception also prevents asyncio's
        # "Task exception was never retrieved" warning.
        exc = task.exception()
        if exc is not None:
            logger.error(f"系统任务漏执行补偿后台任务失败: task={task.get_name()}, error={exc}")

    @staticmethod
    def _run_coro_blocking(coro):
        """Run a coroutine from synchronous code.

        Without a running event loop the coroutine is executed to completion
        with ``asyncio.run`` and its result is returned. With a running loop it
        is scheduled on that loop and the ``Task`` is returned immediately; the
        task is kept referenced until it completes and any exception it raises
        is logged by ``_on_background_task_done``.
        """
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            return asyncio.run(coro)
        task = loop.create_task(coro)
        SystemJobLoader._background_tasks.add(task)
        task.add_done_callback(SystemJobLoader._on_background_task_done)
        return task

    def _should_compensate_once(self, job_key: str, trigger_type: str, trigger_config: Dict[str, Any]) -> bool:
        """Return ``True`` when the job's latest log predates its latest expected run.

        The latest log time must be older than the expected run time minus the
        tolerance window for a compensation run to be requested.
        """
        expected_at = self._latest_expected_run_before_now(trigger_type, trigger_config, datetime.now())
        if not expected_at:
            return False
        latest_log_at = self.db.get_latest_log_time(job_key)
        if not latest_log_at:
            # No execution history: do not compensate, so a first deployment
            # does not trigger every job at once.
            return False
        tolerance = timedelta(seconds=self._compensation_tolerance_seconds)
        return latest_log_at < (expected_at - tolerance)

    def init_and_load(self, *, run_startup_compensation: bool = False):
        """Initialise tables, seed default jobs, then load and register jobs from the DB."""
        self.db.init_tables()
        self._seed_defaults()
        self.reload_from_db(run_startup_compensation=run_startup_compensation)

    def _seed_defaults(self):
        """Insert every code-defined job that has no row in the database yet (enabled)."""
        for item in self._job_defs.values():
            if self.db.get_job(item["job_key"]):
                continue
            self.db.upsert_job(
                {
                    "job_key": item["job_key"],
                    "name": item["name"],
                    "description": item.get("description", ""),
                    "trigger_type": item["trigger_type"],
                    "trigger_config": item["trigger_config"],
                    "enabled": True,
                }
            )

    def _make_logged_handler(self, handler, job_key):
        """Wrap ``handler`` so that each execution is persisted as a job log.

        The wrapper accepts both synchronous and asynchronous handlers. On
        failure it writes a ``failed`` log and re-raises the exception.
        """

        async def _wrapped_handler():
            started_at = datetime.now()
            try:
                result = handler()
                # Support both sync and async handler implementations.
                if inspect.isawaitable(result):
                    await result
                duration_ms = int((datetime.now() - started_at).total_seconds() * 1000)
                self.db.create_job_log(
                    job_key,
                    "success",
                    "执行成功",
                    detail={"job_key": job_key},
                    duration_ms=duration_ms,
                )
            except Exception as e:
                duration_ms = int((datetime.now() - started_at).total_seconds() * 1000)
                # Persist the failure, then re-raise so that async_job's own
                # runtime state can also be marked as failed.
                self.db.create_job_log(
                    job_key,
                    "failed",
                    f"执行失败: {e}",
                    detail={"job_key": job_key, "error": str(e)},
                    duration_ms=duration_ms,
                )
                raise

        return _wrapped_handler

    def reload_from_db(self, *, run_startup_compensation: bool = True):
        """Re-register all enabled jobs from the database.

        Steps: re-seed missing default jobs, remove every job this loader
        registered earlier, register each enabled row that has a handler in
        code, and (optionally) run one compensation execution for jobs whose
        latest expected run was missed.

        Args:
            run_startup_compensation: When true, run the missed-execution
                check (and a single catch-up run if needed) for each job.
        """
        # Re-seed defaults first so an accidentally deleted job can be restored.
        self._seed_defaults()

        # Remove the currently registered jobs to avoid duplicate scheduling.
        for job_id in self._registered_job_ids:
            async_job.remove_job(job_id)
        self._registered_job_ids = []

        for row in self.db.list_jobs():
            job_key = row.get("job_key")
            if not row.get("enabled", 1):
                continue
            definition = self._job_defs.get(job_key)
            if not definition:
                logger.warning(f"系统任务 {job_key} 在代码中无处理器,已跳过注册")
                continue

            wrapped_handler = self._make_logged_handler(definition["handler"], job_key)
            trigger_type = row.get("trigger_type", definition["trigger_type"])
            trigger_config = row.get("trigger_config", definition["trigger_config"])

            job_id = async_job.register_callable(
                func=wrapped_handler,
                trigger_type=trigger_type,
                trigger_config=trigger_config,
                job_name=row.get("name") or definition["name"],
                description=row.get("description") or definition.get("description", ""),
                job_key=job_key,
            )
            self._registered_job_ids.append(job_id)

            if not run_startup_compensation:
                continue

            # Missed-execution compensation: if the latest expected run is
            # newer than the latest log, run the job once now. Best-effort:
            # a failure here must not abort loading the remaining jobs.
            try:
                if self._should_compensate_once(job_key, trigger_type, trigger_config):
                    logger.warning(f"系统任务触发漏执行补偿: job_key={job_key}")
                    self._run_coro_blocking(wrapped_handler())
            except Exception as e:
                logger.error(f"系统任务漏执行补偿失败: job_key={job_key}, error={e}")
|