346 lines
16 KiB
Python
346 lines
16 KiB
Python
# -*- coding: utf-8 -*-
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
from datetime import datetime, timedelta
|
||
from typing import Any, Dict, List, Optional
|
||
|
||
from loguru import logger
|
||
|
||
from db.plugin_schedule_db import PluginScheduleDBOperator
|
||
from utils.decorator.async_job import async_job
|
||
from utils.robot_cmd.robot_command import GroupBotManager, PermissionStatus
|
||
|
||
|
||
class PluginScheduleManager:
    """Database-driven manager for plugin scheduled tasks.

    Schedule rows live in the plugin schedule table (``self.db``).  Every
    enabled row is registered with ``async_job`` under the job key
    ``plugin_schedule:<id>``; when it fires, the owning plugin's
    ``run_scheduled_action(action_key, ctx)`` is awaited and the outcome is
    written back through ``self.db.create_log``.
    """

    # Strong references to fire-and-forget tasks created by _run_coro_blocking.
    # The event loop keeps only weak references to tasks, so an unreferenced
    # task may be garbage-collected before it finishes.  Shared on purpose.
    _background_tasks: set[asyncio.Task] = set()

    def __init__(self, plugin_manager, plugin_schedule_db: PluginScheduleDBOperator):
        self.plugin_manager = plugin_manager
        self.db = plugin_schedule_db
        # schedule id -> runtime job id returned by async_job.register_callable
        self._schedule_job_map: Dict[int, str] = {}
        # A missed run is compensated only when the latest log is older than the
        # expected run time by more than this many seconds.
        self._compensation_tolerance_seconds = 120

    def init_and_load(self, *, run_startup_compensation: bool = False):
        """Create the schedule tables if needed, then load schedules from the DB.

        Args:
            run_startup_compensation: forwarded to ``reload_from_db``.  Note the
                default here is ``False`` while ``reload_from_db`` defaults to
                ``True``.
        """
        self.db.init_tables()
        self.reload_from_db(run_startup_compensation=run_startup_compensation)

    def migrate_from_system_jobs(self, system_job_db) -> Dict[str, int]:
        """Copy legacy system-job settings onto the matching plugin schedules.

        Idempotent: a migrated schedule gets a ``_migrated_from_system_job``
        marker in its payload, and rows already carrying the marker for the same
        job key are skipped, so later user edits are not overwritten on every
        start-up.

        Args:
            system_job_db: legacy store exposing ``get_job(job_key)``, which
                returns a dict-like row or a falsy value.

        Returns:
            Counters ``{"migrated": n, "skipped": n, "failed": n}``.
        """
        # legacy system_job_key -> (plugin display name, plugin action key)
        migration_map = {
            "news_baidu_report_auto": ("每日新闻", "baidu_news_daily_push"),
            "epic_free_games": ("Epic播报", "weekly_free_games_push"),
            "message_ranking_push": ("每日排行", "daily_message_ranking_push"),
            "sehuatang_pdf_push": ("涩图推送", "daily_pdf_push"),
            "xiuren_download": ("秀人图片", "resource_xiuren_download"),
            "shenshi_r15_download": ("秀人图片", "resource_shenshi_r15_download"),
            "update_image_cache": ("秀人图片", "resource_update_image_cache"),
        }

        migrated = skipped = failed = 0
        for job_key, (plugin_name, action_key) in migration_map.items():
            try:
                sys_row = system_job_db.get_job(job_key)
                if not sys_row:
                    skipped += 1
                    continue

                schedule_row = self.db.get_schedule_by_plugin_action(plugin_name, action_key)
                if not schedule_row:
                    skipped += 1
                    continue

                # Copy so the row object handed back by the DB layer is not mutated.
                payload = dict(schedule_row.get("payload") or {})
                # The payload marker records a completed migration so user edits
                # made afterwards are not overwritten on every start-up.
                if payload.get("_migrated_from_system_job") == job_key:
                    skipped += 1
                    continue
                payload["_migrated_from_system_job"] = job_key

                updates = {
                    # `or` instead of a .get() default: a legacy row storing NULL
                    # must not overwrite the plugin's trigger type with None.
                    "trigger_type": sys_row.get("trigger_type") or schedule_row.get("trigger_type"),
                    "trigger_config": sys_row.get("trigger_config") or schedule_row.get("trigger_config") or {},
                    "enabled": bool(sys_row.get("enabled", 1)),
                    "payload": payload,
                }
                # Keep the plugin-side display text, but inherit the legacy
                # description when one exists.
                if sys_row.get("description"):
                    updates["description"] = sys_row.get("description")

                if self.db.update_schedule(int(schedule_row["id"]), updates):
                    migrated += 1
                else:
                    failed += 1
            except Exception as e:
                failed += 1
                logger.error(f"系统任务迁移到插件任务失败: job_key={job_key}, error={e}")

        return {"migrated": migrated, "skipped": skipped, "failed": failed}

    def _get_plugin_actions(self) -> List[Dict[str, Any]]:
        """Collect schedule action declarations from every loaded plugin.

        Plugins opt in by defining ``get_schedule_actions()``.  A plugin whose
        declaration call raises is logged and skipped; entries that are not
        dicts are skipped with a warning instead of aborting the whole listing.
        Missing fields are filled with the defaults shown below.
        """
        actions: List[Dict[str, Any]] = []
        for plugin in self.plugin_manager.plugins.values():
            if not hasattr(plugin, "get_schedule_actions"):
                continue
            try:
                plugin_actions = plugin.get_schedule_actions() or []
            except Exception as e:
                logger.error(f"读取插件 {plugin.name} 调度动作失败: {e}")
                continue

            for action in plugin_actions:
                if not isinstance(action, dict):
                    logger.warning(f"插件 {plugin.name} 返回了无效的调度动作: {action!r}")
                    continue
                actions.append(
                    {
                        "plugin_name": plugin.name,
                        "action_key": action.get("action_key"),
                        "action_name": action.get("name", action.get("action_key", "")),
                        "description": action.get("description", ""),
                        "trigger_type": action.get("trigger_type", "at_times"),
                        "trigger_config": action.get("trigger_config", {"time_list": ["09:00"]}),
                        "target_scope": action.get("target_scope", "all_enabled_groups"),
                        "target_config": action.get("target_config", {}),
                        "payload": action.get("payload", {}),
                        "enabled": bool(action.get("default_enabled", False)),
                    }
                )
        return actions

    def sync_defaults(self):
        """Upsert every plugin-declared action as a default schedule row.

        Entries without a plugin name or action key are ignored.  A failing
        upsert is logged and does not stop the remaining actions from syncing.
        """
        for item in self._get_plugin_actions():
            if not item.get("plugin_name") or not item.get("action_key"):
                continue
            try:
                self.db.upsert_default_schedule(item)
            except Exception as e:
                logger.error(
                    f"同步插件默认调度失败: plugin={item.get('plugin_name')}, "
                    f"action={item.get('action_key')}, error={e}"
                )

    def _resolve_targets(self, plugin, schedule_row: Dict[str, Any]) -> List[str]:
        """Return the group ids a run of *schedule_row* should target.

        Scopes, read from ``schedule_row["target_scope"]``:

        * ``single_group``: ``target_config["group_id"]``, or nothing if blank.
        * ``group_whitelist``: ``target_config["group_ids"]``, given as a list
          or as a comma-separated string.  Blank entries are dropped and
          duplicates removed (first occurrence kept), so no group is pushed
          twice.
        * anything else: all groups from ``GroupBotManager``.  When the plugin
          has a ``feature``, only groups whose permission for it is
          ``PermissionStatus.ENABLED`` are kept.
        """
        scope = str(schedule_row.get("target_scope") or "all_enabled_groups")
        target_cfg = schedule_row.get("target_config") or {}

        if scope == "single_group":
            gid = str(target_cfg.get("group_id") or "").strip()
            return [gid] if gid else []

        if scope == "group_whitelist":
            group_ids = target_cfg.get("group_ids") or []
            if isinstance(group_ids, str):
                # Iterating a bare string would yield single characters.
                group_ids = group_ids.split(",")
            cleaned = (str(x).strip() for x in group_ids)
            return list(dict.fromkeys(gid for gid in cleaned if gid))

        # Default: every group in which the plugin is enabled.
        all_groups = GroupBotManager.get_group_list() or []
        feature = getattr(plugin, "feature", None)
        if not feature:
            return list(all_groups)
        return [
            gid
            for gid in all_groups
            if GroupBotManager.get_group_permission(gid, feature) == PermissionStatus.ENABLED
        ]

    @staticmethod
    def _latest_expected_run_before_now(trigger_type: str, trigger_config: Dict[str, Any], now: datetime) -> datetime | None:
        """Return the most recent time at or before *now* the trigger should have fired.

        Supported trigger types:

        * ``every_seconds``: ``now - seconds``.
        * ``at_times``: the latest ``HH:MM`` in ``time_list``, today or yesterday.
        * ``every_weekday_time`` / ``every_week_time``: the latest
          ``weekday`` (compared with ``datetime.weekday()``) at ``time_str``.
        * ``every_month_last_day_time``: the latest last-day-of-month at
          ``time_str``.

        Returns ``None`` for unknown trigger types or unusable configuration.
        """
        cfg = trigger_config or {}

        if trigger_type == "every_seconds":
            try:
                seconds = int(cfg.get("seconds") or 0)
            except (TypeError, ValueError):
                return None
            if seconds <= 0:
                return None
            return now - timedelta(seconds=seconds)

        if trigger_type == "at_times":
            candidates = []
            for text in cfg.get("time_list") or []:
                try:
                    tm = datetime.strptime(str(text), "%H:%M").time()
                except ValueError:
                    continue
                dt = datetime.combine(now.date(), tm)
                if dt > now:
                    # Not reached yet today: the latest occurrence was yesterday.
                    dt -= timedelta(days=1)
                candidates.append(dt)
            return max(candidates) if candidates else None

        if trigger_type in ("every_weekday_time", "every_week_time"):
            try:
                weekday = int(cfg.get("weekday"))
                tm = datetime.strptime(str(cfg.get("time_str") or ""), "%H:%M").time()
            except (TypeError, ValueError):
                return None
            days_ago = (now.weekday() - weekday + 7) % 7
            dt = datetime.combine((now - timedelta(days=days_ago)).date(), tm)
            if dt > now:
                dt -= timedelta(days=7)
            return dt

        if trigger_type == "every_month_last_day_time":
            try:
                tm = datetime.strptime(str(cfg.get("time_str") or ""), "%H:%M").time()
            except ValueError:
                return None
            month_first = datetime(now.year, now.month, 1)
            if now.month == 12:
                next_month_first = datetime(now.year + 1, 1, 1)
            else:
                next_month_first = datetime(now.year, now.month + 1, 1)
            dt = datetime.combine((next_month_first - timedelta(days=1)).date(), tm)
            if dt > now:
                # Not reached yet this month: use the previous month's last day
                # (the day before this month's first day, across year ends too).
                dt = datetime.combine((month_first - timedelta(days=1)).date(), tm)
            return dt

        return None

    @classmethod
    def _run_coro_blocking(cls, coro):
        """Run *coro* now: synchronously if possible, else as a background task.

        With no event loop running in this thread, the coroutine is run to
        completion via ``asyncio.run`` and its result returned.  Inside a
        running loop it cannot block, so it is scheduled with
        ``loop.create_task`` and the task is returned without waiting.  The
        task is kept in ``_background_tasks`` until done, so it cannot be
        garbage-collected mid-run, and any exception it raises is logged.
        """
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            return asyncio.run(coro)
        task = loop.create_task(coro)
        cls._background_tasks.add(task)
        task.add_done_callback(cls._on_background_task_done)
        return task

    @classmethod
    def _on_background_task_done(cls, task: asyncio.Task) -> None:
        """Release a finished background task and log its exception, if any."""
        cls._background_tasks.discard(task)
        if task.cancelled():
            return
        exc = task.exception()  # also marks the exception as retrieved
        if exc is not None:
            logger.opt(exception=exc).error("插件调度后台补偿任务执行异常")

    def _should_compensate_once(self, schedule_id: int, trigger_type: str, trigger_config: Dict[str, Any]) -> bool:
        """Return True when the latest expected run is not covered by any log.

        A schedule with no log at all is never compensated.  Otherwise it is
        compensated when its newest log is older than the latest expected run
        time minus ``_compensation_tolerance_seconds``.
        """
        expected_at = self._latest_expected_run_before_now(trigger_type, trigger_config, datetime.now())
        if expected_at is None:
            return False

        latest_log_at = self.db.get_latest_log_time(schedule_id)
        if not latest_log_at:
            return False
        if isinstance(latest_log_at, str):
            # NOTE(review): the DB helper's return type is not visible here;
            # accept an ISO-8601 string as well as a datetime.
            try:
                latest_log_at = datetime.fromisoformat(latest_log_at)
            except ValueError:
                return False

        return latest_log_at < (expected_at - timedelta(seconds=self._compensation_tolerance_seconds))

    def _log_failure(self, schedule_id: int, message: str) -> Dict[str, Any]:
        """Write a ``failed`` log row for *schedule_id* and return the run result."""
        detail = {"error": message}
        self.db.create_log(schedule_id, "failed", message, detail)
        return {"success": False, "summary": message, "detail": detail}

    async def _run_one_schedule(self, schedule_row: Dict[str, Any]) -> Dict[str, Any]:
        """Execute one schedule row through its plugin and log the outcome.

        Exactly one log row is written per call, including when the plugin is
        missing, when target resolution fails, or when the plugin raises.

        Returns:
            ``{"success": bool, "summary": str, "detail": dict}``.  ``detail``
            always contains ``target_count`` once the plugin has been found.
        """
        schedule_id = int(schedule_row["id"])
        action_key = schedule_row.get("action_key")
        plugin_name = schedule_row.get("plugin_name")

        _, plugin = self.plugin_manager.find_plugin_by_name(plugin_name)
        if not plugin:
            return self._log_failure(schedule_id, f"未找到插件: {plugin_name}")
        if not hasattr(plugin, "run_scheduled_action"):
            return self._log_failure(schedule_id, f"插件 {plugin.name} 未实现 run_scheduled_action")

        targets: List[str] = []
        try:
            targets = self._resolve_targets(plugin, schedule_row)
            ctx = {
                "schedule_id": schedule_id,
                "triggered_at": datetime.now().isoformat(timespec="seconds"),
                "target_scope": schedule_row.get("target_scope"),
                "target_config": schedule_row.get("target_config") or {},
                "target_groups": targets,
                "payload": schedule_row.get("payload") or {},
                "bot": getattr(plugin, "bot", None),
            }
            res = await plugin.run_scheduled_action(action_key, ctx)
            if not isinstance(res, dict):
                res = {"success": bool(res), "summary": "插件返回非 dict,已兼容处理", "detail": {"result": str(res)}}
        except Exception as e:
            logger.exception(f"插件调度执行异常: schedule_id={schedule_id}, plugin={plugin_name}, action={action_key}")
            res = {"success": False, "summary": f"执行异常: {e}", "detail": {"error": str(e)}}

        success = bool(res.get("success"))
        status = "success" if success else "failed"
        summary = str(res.get("summary") or ("执行成功" if success else "执行失败"))
        detail = res.get("detail") or {}
        if not isinstance(detail, dict):
            # A non-dict detail would make the item assignment below raise and
            # the run would go unlogged.
            detail = {"result": str(detail)}
        detail["target_count"] = len(targets)
        self.db.create_log(schedule_id, status, summary, detail)
        return {"success": success, "summary": summary, "detail": detail}

    def reload_from_db(self, *, run_startup_compensation: bool = True):
        """Re-sync plugin defaults and re-register every enabled schedule.

        Previously registered jobs are removed first so reloads never produce
        duplicates.  Rows are registered independently: a row whose registration
        or compensation fails is logged and skipped, so it cannot prevent the
        other schedules from loading.

        Args:
            run_startup_compensation: when true, a schedule whose latest expected
                run is not covered by its logs is run once immediately.
        """
        self.sync_defaults()

        # Drop the old registrations to avoid duplicate jobs.
        for job_id in list(self._schedule_job_map.values()):
            async_job.remove_job(job_id)
        self._schedule_job_map = {}

        for row in self.db.list_enabled_schedules() or []:
            self._register_schedule_row(row, run_startup_compensation=run_startup_compensation)

    def _register_schedule_row(self, row: Dict[str, Any], *, run_startup_compensation: bool) -> None:
        """Register one enabled schedule row with async_job and optionally compensate a missed run."""
        schedule_id = int(row["id"])
        # `or` rather than a .get() default so NULL columns also fall back.
        trigger_type = row.get("trigger_type") or "at_times"
        trigger_config = row.get("trigger_config") or {"time_list": ["09:00"]}

        async def _runner(_row=row):
            await self._run_one_schedule(_row)

        try:
            job_id = async_job.register_callable(
                func=_runner,
                trigger_type=trigger_type,
                trigger_config=trigger_config,
                job_name=f"[插件调度]{row.get('plugin_name')}:{row.get('action_name')}",
                description=row.get("description") or "",
                job_key=f"plugin_schedule:{schedule_id}",
            )
        except Exception as e:
            logger.error(f"插件调度注册失败: schedule_id={schedule_id}, error={e}")
            return
        self._schedule_job_map[schedule_id] = job_id

        if not run_startup_compensation:
            return
        # Restart/reload compensation: if the latest expected run has passed and
        # no log covers it, run the schedule once now.
        try:
            if self._should_compensate_once(schedule_id, trigger_type, trigger_config):
                logger.warning(
                    f"插件调度触发漏执行补偿: schedule_id={schedule_id}, "
                    f"plugin={row.get('plugin_name')}, action={row.get('action_key')}"
                )
                self._run_coro_blocking(_runner())
        except Exception as e:
            logger.error(f"插件调度漏执行补偿失败: schedule_id={schedule_id}, error={e}")

    def list_schedules_with_runtime(self) -> List[Dict[str, Any]]:
        """Return every schedule row merged with its async_job runtime state.

        Runtime fields come from ``async_job.get_jobs_snapshot()`` matched by job
        key.  After a restart the in-memory ``last_run_at``/``last_status`` are
        gone, so those fall back to the newest log row from the database.
        """
        db_rows = self.db.list_schedules()
        runtime_rows = async_job.get_jobs_snapshot() or []
        runtime_by_key = {row.get("job_key"): row for row in runtime_rows if row.get("job_key")}

        schedule_ids = [int(row.get("id")) for row in db_rows if row.get("id") is not None]
        latest_log_by_schedule = self.db.get_latest_logs_map(schedule_ids) or {}

        data = []
        for row in db_rows:
            schedule_id = int(row["id"])
            runtime = runtime_by_key.get(f"plugin_schedule:{schedule_id}", {})
            latest_log = latest_log_by_schedule.get(schedule_id) or {}

            merged = dict(row)
            merged["runtime_job_id"] = runtime.get("id")
            merged["running"] = runtime.get("running", False)
            merged["trigger_text"] = runtime.get("trigger_text", "")
            merged["next_run_at"] = runtime.get("next_run_at")
            # Prefer runtime values; fall back to the latest log so the page is
            # not blank after a restart.
            merged["last_run_at"] = runtime.get("last_run_at") or latest_log.get("triggered_at")
            merged["last_status"] = runtime.get("last_status") or latest_log.get("status")
            merged["last_error"] = runtime.get("last_error") or ""
            if not merged["last_error"] and str(merged["last_status"]) == "failed":
                merged["last_error"] = str(latest_log.get("summary") or "")
            merged["last_duration_ms"] = runtime.get("last_duration_ms")
            merged["run_count"] = runtime.get("run_count", 0)
            merged["success_count"] = runtime.get("success_count", 0)
            merged["fail_count"] = runtime.get("fail_count", 0)
            data.append(merged)
        return data

    def trigger_now(self, schedule_id: int) -> tuple[bool, str]:
        """Trigger a schedule's registered job immediately from the dashboard.

        If the job is not registered, schedules are reloaded once first.

        Returns:
            ``(ok, message)``; ``(False, "该调度未启用或未加载")`` when the job is
            still missing after the reload.
        """
        job_key = f"plugin_schedule:{int(schedule_id)}"
        job_id = async_job.get_job_id_by_key(job_key)
        if not job_id:
            # NOTE(review): this reload uses the default compensation setting,
            # which may run this schedule once before the manual trigger below.
            self.reload_from_db()
            job_id = async_job.get_job_id_by_key(job_key)
        if not job_id:
            return False, "该调度未启用或未加载"
        return async_job.trigger_job_now(job_id, operator="dashboard")

    def update_schedule(self, schedule_id: int, updates: Dict[str, Any]) -> bool:
        """Persist *updates* for a schedule and reload all jobs when the DB write succeeds."""
        ok = self.db.update_schedule(int(schedule_id), updates)
        if ok:
            self.reload_from_db()
        return ok

    def get_logs(self, schedule_id: int, limit: int = 100) -> List[Dict[str, Any]]:
        """Return up to *limit* log rows for a schedule, as provided by the DB layer."""
        return self.db.get_logs(int(schedule_id), limit=limit)

    def get_available_plugin_actions(self) -> List[Dict[str, Any]]:
        """Return the normalized schedule actions currently declared by plugins."""
        return self._get_plugin_actions()