Files
abot/utils/plugin_schedule_manager.py
2026-05-01 12:45:40 +08:00

346 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
from __future__ import annotations
import asyncio
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional
from loguru import logger
from db.plugin_schedule_db import PluginScheduleDBOperator
from utils.decorator.async_job import async_job
from utils.robot_cmd.robot_command import GroupBotManager, PermissionStatus
class PluginScheduleManager:
    """Database-driven manager for plugin scheduled tasks.

    Schedule rows are read through ``plugin_schedule_db``, registered with the
    ``async_job`` runtime under the job key ``plugin_schedule:<id>``, and each
    execution is recorded with ``db.create_log``.
    """

    # Strong references to tasks started on an already-running event loop.
    # The event loop only keeps weak references to tasks, so a task nobody
    # holds on to may be garbage-collected before it finishes.
    _background_tasks: set = set()

    def __init__(self, plugin_manager, plugin_schedule_db: "PluginScheduleDBOperator"):
        """Store the collaborators; no I/O happens until ``init_and_load``.

        Args:
            plugin_manager: Object exposing ``plugins`` (a mapping whose values
                are plugin instances) and ``find_plugin_by_name(name)``.
            plugin_schedule_db: Persistence layer for schedules and run logs.
        """
        self.plugin_manager = plugin_manager
        self.db = plugin_schedule_db
        # schedule id -> job id returned by async_job.register_callable
        self._schedule_job_map: Dict[int, str] = {}
        # Slack (seconds) subtracted from the expected run time before it is
        # compared with the latest log time in _should_compensate_once.
        self._compensation_tolerance_seconds = 120

    def init_and_load(self):
        """Create the backing tables, then register all enabled schedules."""
        self.db.init_tables()
        self.reload_from_db()

    def migrate_from_system_jobs(self, system_job_db) -> Dict[str, int]:
        """Copy legacy system-job settings onto plugin schedule rows (idempotent).

        A row is skipped when the legacy job does not exist, when the plugin
        schedule row does not exist, or when the row's payload already carries
        the ``_migrated_from_system_job`` marker for that job key.

        Args:
            system_job_db: Legacy store exposing ``get_job(job_key)``.

        Returns:
            Counters ``{"migrated": n, "skipped": n, "failed": n}``.
        """
        # Migration map: legacy system_job_key -> (plugin display name, plugin action key)
        migration_map = {
            "news_baidu_report_auto": ("每日新闻", "baidu_news_daily_push"),
            "epic_free_games": ("Epic播报", "weekly_free_games_push"),
            "message_ranking_push": ("每日排行", "daily_message_ranking_push"),
            "sehuatang_pdf_push": ("涩图推送", "daily_pdf_push"),
            "xiuren_download": ("秀人图片", "resource_xiuren_download"),
            "shenshi_r15_download": ("秀人图片", "resource_shenshi_r15_download"),
            "update_image_cache": ("秀人图片", "resource_update_image_cache"),
        }
        migrated = 0
        skipped = 0
        failed = 0
        for job_key, (plugin_name, action_key) in migration_map.items():
            try:
                sys_row = system_job_db.get_job(job_key)
                if not sys_row:
                    skipped += 1
                    continue
                schedule_row = self.db.get_schedule_by_plugin_action(plugin_name, action_key)
                if not schedule_row:
                    skipped += 1
                    continue
                # The payload marker records a finished migration so later
                # startups do not overwrite changes the user made afterwards.
                # Work on a copy so the fetched row is left untouched.
                payload = dict(schedule_row.get("payload") or {})
                if payload.get("_migrated_from_system_job") == job_key:
                    skipped += 1
                    continue
                payload["_migrated_from_system_job"] = job_key
                updates = {
                    "trigger_type": sys_row.get("trigger_type", schedule_row.get("trigger_type")),
                    "trigger_config": sys_row.get("trigger_config") or schedule_row.get("trigger_config") or {},
                    "enabled": bool(sys_row.get("enabled", 1)),
                    "payload": payload,
                }
                # Name and description normally stay as shown on the plugin
                # side, but a legacy description is inherited when present.
                if sys_row.get("description"):
                    updates["description"] = sys_row.get("description")
                if self.db.update_schedule(int(schedule_row["id"]), updates):
                    migrated += 1
                else:
                    failed += 1
            except Exception as e:
                # Best effort: one broken row must not abort the whole migration.
                failed += 1
                logger.error(f"系统任务迁移到插件任务失败: job_key={job_key}, error={e}")
        return {"migrated": migrated, "skipped": skipped, "failed": failed}

    def _get_plugin_actions(self) -> List[Dict[str, Any]]:
        """Collect the schedule actions declared by every loaded plugin.

        Plugins without ``get_schedule_actions`` are ignored; a plugin whose
        ``get_schedule_actions`` raises is logged and skipped.

        Returns:
            One normalized dict per declared action, with defaults filled in.
        """
        actions = []
        for plugin in self.plugin_manager.plugins.values():
            if not hasattr(plugin, "get_schedule_actions"):
                continue
            try:
                plugin_actions = plugin.get_schedule_actions() or []
            except Exception as e:
                logger.error(f"读取插件 {plugin.name} 调度动作失败: {e}")
                continue
            for action in plugin_actions:
                actions.append(
                    {
                        "plugin_name": plugin.name,
                        "action_key": action.get("action_key"),
                        "action_name": action.get("name", action.get("action_key", "")),
                        "description": action.get("description", ""),
                        "trigger_type": action.get("trigger_type", "at_times"),
                        "trigger_config": action.get("trigger_config", {"time_list": ["09:00"]}),
                        "target_scope": action.get("target_scope", "all_enabled_groups"),
                        "target_config": action.get("target_config", {}),
                        "payload": action.get("payload", {}),
                        "enabled": bool(action.get("default_enabled", False)),
                    }
                )
        return actions

    def sync_defaults(self):
        """Upsert a default schedule row for every declared plugin action.

        Actions lacking a plugin name or an action key are skipped.
        """
        for item in self._get_plugin_actions():
            if not item.get("plugin_name") or not item.get("action_key"):
                continue
            self.db.upsert_default_schedule(item)

    def _resolve_targets(self, plugin, schedule_row: Dict[str, Any]) -> List[str]:
        """Return the group ids a schedule run should target.

        ``target_scope`` selects the strategy: ``single_group`` uses
        ``target_config["group_id"]``, ``group_whitelist`` uses
        ``target_config["group_ids"]``, and anything else falls back to the
        groups reported by ``GroupBotManager``.
        """
        scope = str(schedule_row.get("target_scope") or "all_enabled_groups")
        target_cfg = schedule_row.get("target_config") or {}
        if scope == "single_group":
            gid = str(target_cfg.get("group_id") or "").strip()
            return [gid] if gid else []
        if scope == "group_whitelist":
            group_ids = target_cfg.get("group_ids") or []
            return [str(x).strip() for x in group_ids if str(x).strip()]
        # Default: every group; narrowed to groups where the plugin's feature
        # is enabled when the plugin declares a feature.
        all_groups = GroupBotManager.get_group_list()
        if not getattr(plugin, "feature", None):
            return all_groups
        return [
            gid
            for gid in all_groups
            if GroupBotManager.get_group_permission(gid, plugin.feature) == PermissionStatus.ENABLED
        ]

    @staticmethod
    def _latest_expected_run_before_now(trigger_type: str, trigger_config: Dict[str, Any], now: datetime) -> Optional[datetime]:
        """Return the most recent moment at or before ``now`` the trigger should have fired.

        Args:
            trigger_type: ``every_seconds``, ``at_times``, ``every_weekday_time``,
                ``every_week_time`` or ``every_month_last_day_time``.
            trigger_config: Settings for the trigger; ``None`` is treated as ``{}``.
            now: Reference time.

        Returns:
            The expected run time, or ``None`` for an unknown trigger type or a
            configuration that cannot be parsed.
        """
        cfg = trigger_config or {}
        if trigger_type == "every_seconds":
            # A malformed interval yields None, like the other trigger types.
            try:
                seconds = int(cfg.get("seconds") or 0)
            except (TypeError, ValueError):
                return None
            if seconds <= 0:
                return None
            return now - timedelta(seconds=seconds)
        if trigger_type == "at_times":
            candidates = []
            for text in cfg.get("time_list") or []:
                try:
                    tm = datetime.strptime(str(text), "%H:%M").time()
                except Exception:
                    # Unparseable entries are ignored.
                    continue
                dt = datetime.combine(now.date(), tm)
                if dt > now:
                    # Not reached yet today, so the last occurrence was yesterday.
                    dt -= timedelta(days=1)
                candidates.append(dt)
            return max(candidates) if candidates else None
        if trigger_type in ("every_weekday_time", "every_week_time"):
            try:
                weekday = int(cfg.get("weekday"))
                tm = datetime.strptime(str(cfg.get("time_str") or ""), "%H:%M").time()
            except Exception:
                return None
            days_ago = (now.weekday() - weekday + 7) % 7
            dt = datetime.combine((now - timedelta(days=days_ago)).date(), tm)
            if dt > now:
                # Same weekday but later today: step back one full week.
                dt -= timedelta(days=7)
            return dt
        if trigger_type == "every_month_last_day_time":
            try:
                tm = datetime.strptime(str(cfg.get("time_str") or ""), "%H:%M").time()
            except Exception:
                return None
            if now.month == 12:
                next_month = datetime(now.year + 1, 1, 1)
            else:
                next_month = datetime(now.year, now.month + 1, 1)
            dt = datetime.combine((next_month - timedelta(days=1)).date(), tm)
            if dt > now:
                # This month's run is still ahead; the last day of the previous
                # month is one day before the first of the current month.
                prev_last_day = datetime(now.year, now.month, 1) - timedelta(days=1)
                dt = datetime.combine(prev_last_day.date(), tm)
            return dt
        return None

    @staticmethod
    def _run_coro_blocking(coro):
        """Run ``coro`` to completion, or schedule it when a loop is already running.

        Returns:
            The coroutine's result when no event loop is running in this thread
            (executed with ``asyncio.run``); otherwise the ``asyncio.Task`` that
            was scheduled on the running loop. In that second case the call does
            not wait for the coroutine to finish.
        """
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            return asyncio.run(coro)
        task = loop.create_task(coro)
        # Hold a reference until the task is done so it cannot be collected
        # while still pending.
        PluginScheduleManager._background_tasks.add(task)
        task.add_done_callback(PluginScheduleManager._background_tasks.discard)
        return task

    def _should_compensate_once(self, schedule_id: int, trigger_type: str, trigger_config: Dict[str, Any]) -> bool:
        """Tell whether a missed run should be executed once after a (re)load.

        Returns ``True`` only when an expected run time can be computed, the
        schedule has at least one log entry, and the newest log is older than
        the expected run time minus the tolerance window.
        """
        expected_at = self._latest_expected_run_before_now(trigger_type, trigger_config, datetime.now())
        if not expected_at:
            return False
        latest_log_at = self.db.get_latest_log_time(schedule_id)
        if not latest_log_at:
            # No log at all: no compensation run is started.
            return False
        return latest_log_at < (expected_at - timedelta(seconds=self._compensation_tolerance_seconds))

    async def _run_one_schedule(self, schedule_row: Dict[str, Any]) -> Dict[str, Any]:
        """Execute one schedule row through its plugin and log the outcome.

        Args:
            schedule_row: Schedule record; ``id`` is required.

        Returns:
            ``{"success": bool, "summary": str, "detail": dict}``. Exceptions
            raised by the plugin are converted into a failed result.
        """
        schedule_id = int(schedule_row["id"])
        action_key = schedule_row.get("action_key")
        plugin_name = schedule_row.get("plugin_name")
        _, plugin = self.plugin_manager.find_plugin_by_name(plugin_name)
        if not plugin:
            detail = {"error": f"未找到插件: {plugin_name}"}
            self.db.create_log(schedule_id, "failed", detail["error"], detail)
            return {"success": False, "summary": detail["error"], "detail": detail}
        if not hasattr(plugin, "run_scheduled_action"):
            detail = {"error": f"插件 {plugin.name} 未实现 run_scheduled_action"}
            self.db.create_log(schedule_id, "failed", detail["error"], detail)
            return {"success": False, "summary": detail["error"], "detail": detail}
        targets = self._resolve_targets(plugin, schedule_row)
        payload = schedule_row.get("payload") or {}
        ctx = {
            "schedule_id": schedule_id,
            "triggered_at": datetime.now().isoformat(timespec="seconds"),
            "target_scope": schedule_row.get("target_scope"),
            "target_config": schedule_row.get("target_config") or {},
            "target_groups": targets,
            "payload": payload,
            "bot": getattr(plugin, "bot", None),
        }
        try:
            res = await plugin.run_scheduled_action(action_key, ctx)
            if not isinstance(res, dict):
                res = {"success": bool(res), "summary": "插件返回非 dict已兼容处理", "detail": {"result": str(res)}}
        except Exception as e:
            res = {"success": False, "summary": f"执行异常: {e}", "detail": {"error": str(e)}}
        status = "success" if res.get("success") else "failed"
        summary = str(res.get("summary") or ("执行成功" if status == "success" else "执行失败"))
        # Plugins may return any value under "detail". Copy dicts so the
        # plugin's own object is not mutated, and wrap other values so that
        # adding target_count cannot raise and prevent the log from being written.
        raw_detail = res.get("detail")
        if isinstance(raw_detail, dict):
            detail = dict(raw_detail)
        elif raw_detail:
            detail = {"result": str(raw_detail)}
        else:
            detail = {}
        detail["target_count"] = len(targets)
        self.db.create_log(schedule_id, status, summary, detail)
        return {"success": status == "success", "summary": summary, "detail": detail}

    def reload_from_db(self):
        """Re-register every enabled schedule with the ``async_job`` runtime.

        Defaults are synced first, previously registered jobs are removed, and
        each enabled row is registered again. After registration a single
        compensation run is started for rows whose latest expected run is not
        covered by the log.
        """
        self.sync_defaults()
        # Drop earlier registrations to avoid duplicates.
        for job_id in list(self._schedule_job_map.values()):
            async_job.remove_job(job_id)
        self._schedule_job_map = {}
        for row in self.db.list_enabled_schedules():
            schedule_id = int(row["id"])
            trigger_type = row.get("trigger_type", "at_times")
            trigger_config = row.get("trigger_config", {"time_list": ["09:00"]})

            # The default argument binds the current row; a plain closure would
            # see whichever row the loop variable points at when it runs.
            async def _runner(_row=row):
                await self._run_one_schedule(_row)

            job_id = async_job.register_callable(
                func=_runner,
                trigger_type=trigger_type,
                trigger_config=trigger_config,
                job_name=f"[插件调度]{row.get('plugin_name')}:{row.get('action_name')}",
                description=row.get("description", ""),
                job_key=f"plugin_schedule:{schedule_id}",
            )
            self._schedule_job_map[schedule_id] = job_id
            # Restart/reload compensation: when the latest expected run time has
            # passed and the log does not cover it, run the schedule once.
            try:
                if self._should_compensate_once(schedule_id, trigger_type, trigger_config):
                    logger.warning(
                        f"插件调度触发漏执行补偿: schedule_id={schedule_id}, "
                        f"plugin={row.get('plugin_name')}, action={row.get('action_key')}"
                    )
                    self._run_coro_blocking(_runner())
            except Exception as e:
                logger.error(f"插件调度漏执行补偿失败: schedule_id={schedule_id}, error={e}")

    def list_schedules_with_runtime(self) -> List[Dict[str, Any]]:
        """Return every schedule row merged with runtime state and latest log.

        Runtime fields come from ``async_job.get_jobs_snapshot()``, matched by
        the ``plugin_schedule:<id>`` job key.
        """
        db_rows = self.db.list_schedules()
        runtime_rows = async_job.get_jobs_snapshot()
        runtime_by_key = {row.get("job_key"): row for row in runtime_rows if row.get("job_key")}
        # Log fallback: the in-memory last_run_at is lost when the process
        # restarts, so the task page restores it from the newest database log.
        schedule_ids = [int(row.get("id")) for row in db_rows if row.get("id") is not None]
        latest_log_by_schedule = self.db.get_latest_logs_map(schedule_ids)
        data = []
        for row in db_rows:
            schedule_id = int(row["id"])
            runtime = runtime_by_key.get(f"plugin_schedule:{schedule_id}", {})
            latest_log = latest_log_by_schedule.get(schedule_id) or {}
            merged = dict(row)
            merged["runtime_job_id"] = runtime.get("id")
            merged["running"] = runtime.get("running", False)
            merged["trigger_text"] = runtime.get("trigger_text", "")
            merged["next_run_at"] = runtime.get("next_run_at")
            # Runtime values win; the newest log fills the gaps so the page
            # does not show blanks.
            merged["last_run_at"] = runtime.get("last_run_at") or latest_log.get("triggered_at")
            merged["last_status"] = runtime.get("last_status") or latest_log.get("status")
            merged["last_error"] = runtime.get("last_error") or ""
            if not merged["last_error"] and str(merged["last_status"]) == "failed":
                merged["last_error"] = str(latest_log.get("summary") or "")
            merged["last_duration_ms"] = runtime.get("last_duration_ms")
            merged["run_count"] = runtime.get("run_count", 0)
            merged["success_count"] = runtime.get("success_count", 0)
            merged["fail_count"] = runtime.get("fail_count", 0)
            data.append(merged)
        return data

    def trigger_now(self, schedule_id: int) -> tuple[bool, str]:
        """Trigger a schedule immediately through the ``async_job`` runtime.

        When the job is not registered, schedules are reloaded once and the
        lookup is retried.

        Returns:
            ``(False, message)`` when the job still cannot be found; otherwise
            whatever ``async_job.trigger_job_now`` returns.
        """
        job_key = f"plugin_schedule:{int(schedule_id)}"
        job_id = async_job.get_job_id_by_key(job_key)
        if not job_id:
            self.reload_from_db()
            job_id = async_job.get_job_id_by_key(job_key)
        if not job_id:
            return False, "该调度未启用或未加载"
        return async_job.trigger_job_now(job_id, operator="dashboard")

    def update_schedule(self, schedule_id: int, updates: Dict[str, Any]) -> bool:
        """Persist ``updates`` for a schedule and reload jobs when the update succeeds."""
        ok = self.db.update_schedule(int(schedule_id), updates)
        if ok:
            self.reload_from_db()
        return ok

    def get_logs(self, schedule_id: int, limit: int = 100) -> List[Dict[str, Any]]:
        """Return up to ``limit`` log entries for the given schedule."""
        return self.db.get_logs(int(schedule_id), limit=limit)

    def get_available_plugin_actions(self) -> List[Dict[str, Any]]:
        """Return the schedule actions declared by the loaded plugins."""
        return self._get_plugin_actions()