From f0414e0dff50e35c2e7130141efe30081c850833 Mon Sep 17 00:00:00 2001 From: liuwei Date: Thu, 16 Apr 2026 13:54:56 +0800 Subject: [PATCH] feat(plugin): support auto bot injection and file-based hot reload --- base/plugin_common/plugin_manager.py | 164 ++++++++++++++++++++++++ plugins/message_summary/main.py | 3 + robot.py | 1 + utils/decorator/async_job.py | 180 ++++++++++++++++++++++++--- 4 files changed, 330 insertions(+), 18 deletions(-) diff --git a/base/plugin_common/plugin_manager.py b/base/plugin_common/plugin_manager.py index 16ad3dc..d614b1d 100644 --- a/base/plugin_common/plugin_manager.py +++ b/base/plugin_common/plugin_manager.py @@ -2,6 +2,8 @@ import importlib import inspect import os import sys +import threading +import time from typing import Dict, List, Any, Optional, Tuple from loguru import logger @@ -11,6 +13,7 @@ from base.plugin_common.message_plugin_interface import MessagePluginInterface from base.plugin_common.scheduled_plugin_interface import ScheduledPluginInterface from base.plugin_common.plugin_registry import PluginRegistry from base.plugin_common.event_system import EventSystem, EventType +from utils.decorator.async_job import async_job from wechat_ipad import WechatAPIClient @@ -53,6 +56,14 @@ class PluginManager: self.plugin_modules = {} # 插件模块字典,键为module_name self.module_to_display = {} # 模块名到显示名的映射 self.system_context = {} # 系统上下文 + self.current_bot: Optional[WechatAPIClient] = None + + # 热加载相关 + self._watcher_thread: Optional[threading.Thread] = None + self._watcher_stop_event = threading.Event() + self._watcher_interval = 2.0 + self._module_file_state: Dict[str, Dict[str, float]] = {} + self._watcher_lock = threading.RLock() self.LOG = logger # 确保插件目录存在 @@ -71,6 +82,139 @@ class PluginManager: context: 系统上下文 """ self.system_context = context + bot = context.get("bot") + if bot is not None: + self.current_bot = bot + + def _build_module_file_state(self, module_name: str) -> Optional[Dict[str, float]]: + """ + 构建模块的文件状态快照,用于检测文件变更 + """ + plugin_folder = os.path.join(self.plugin_dir, module_name) + file_state: Dict[str, float] = {} + + if os.path.isdir(plugin_folder): + for root, _, files in os.walk(plugin_folder): + for filename in files: + if not (filename.endswith(".py") or filename.endswith(".toml")): + continue + file_path = os.path.join(root, filename) + try: + file_state[os.path.abspath(file_path)] = os.path.getmtime(file_path) + except OSError: + continue + return file_state if file_state else None + + single_file = os.path.join(self.plugin_dir, f"{module_name}.py") + if os.path.exists(single_file): + try: + file_state[os.path.abspath(single_file)] = os.path.getmtime(single_file) + except OSError: + return None + return file_state + + return None + + def _refresh_module_file_state(self, module_name: str): + state = self._build_module_file_state(module_name) + if state is None: + self._module_file_state.pop(module_name, None) + else: + self._module_file_state[module_name] = state + + def _inject_bot_to_plugin(self, plugin: PluginInterface): + bot = self.current_bot or self.system_context.get("bot") + if not bot: + return + if hasattr(plugin, "set_bot"): + try: + plugin.set_bot(bot) + except Exception as e: + self.LOG.error(f"自动注入 bot 到插件 {plugin.name} 失败: {e}") + + def start_hot_reload_watcher(self, interval_seconds: float = 2.0): + """ + 启动插件目录变更监听线程(轮询) + """ + with self._watcher_lock: + if self._watcher_thread and self._watcher_thread.is_alive(): + self.LOG.debug("PluginManager:热加载监听线程已运行,跳过重复启动") + return + + self._watcher_interval = max(float(interval_seconds), 0.5) + self._watcher_stop_event.clear() + + # 初始化快照 + for module_name in self.discover_plugins(): + self._refresh_module_file_state(module_name) + + self._watcher_thread = threading.Thread( + target=self._hot_reload_watch_loop, + name="plugin-hot-reload-watcher", + daemon=True, + ) + self._watcher_thread.start() + self.LOG.info(f"PluginManager:插件热加载监听已启动,轮询间隔 {self._watcher_interval}s") + + def stop_hot_reload_watcher(self): + """ + 停止插件目录变更监听线程 + """ + with self._watcher_lock: + if not self._watcher_thread: + return + self._watcher_stop_event.set() + thread = self._watcher_thread + self._watcher_thread = None + + if thread.is_alive(): + thread.join(timeout=2.0) + self.LOG.info("PluginManager:插件热加载监听已停止") + + def _hot_reload_watch_loop(self): + while not self._watcher_stop_event.is_set(): + try: + discovered = set(self.discover_plugins()) + loaded_modules = set(self.module_to_display.keys()) + + # 1. 新增插件 -> 自动加载并启动 + new_modules = discovered - loaded_modules + for module_name in new_modules: + plugin = self.load_plugin(module_name) + if plugin: + self.start_plugin(plugin.name) + self.LOG.info(f"PluginManager:检测到新增插件 {module_name},已自动加载") + self._refresh_module_file_state(module_name) + + # 2. 已删除插件 -> 自动卸载 + removed_modules = loaded_modules - discovered + for module_name in removed_modules: + if self.unload_plugin(module_name): + self.LOG.info(f"PluginManager:检测到插件 {module_name} 已删除,已自动卸载") + self._module_file_state.pop(module_name, None) + + # 3. 文件变更 -> 自动重载 + for module_name in list(self.module_to_display.keys()): + new_state = self._build_module_file_state(module_name) + old_state = self._module_file_state.get(module_name) + if new_state is None: + continue + if old_state is None: + self._module_file_state[module_name] = new_state + continue + + if new_state != old_state: + reloaded = self.reload_plugin(module_name) + if reloaded: + self.LOG.info(f"PluginManager:检测到插件 {module_name} 文件变更,已自动重载") + self._module_file_state[module_name] = new_state + else: + self.LOG.warning(f"PluginManager:插件 {module_name} 自动重载失败") + + except Exception as e: + self.LOG.error(f"PluginManager:热加载监听异常: {e}", exc_info=True) + + time.sleep(self._watcher_interval) def discover_plugins(self) -> List[str]: """ @@ -207,6 +351,7 @@ class PluginManager: if module_name not in self.module_to_display: self.module_to_display[module_name] = display_name self.LOG.debug(f"PluginManager:添加缺失的模块映射 {module_name} -> {display_name}") + self._inject_bot_to_plugin(plugin) return plugin except Exception as e: self.LOG.warning(f"获取插件 {display_name} 的模块名时出错: {e}") @@ -258,12 +403,15 @@ class PluginManager: # 加载插件配置 if not plugin.load_config(): self.LOG.error(f"PluginManager:插件模块 {module_name} 加载配置失败") + async_job.remove_jobs_by_owner(plugin) return None # 初始化插件 if not plugin.initialize(self.system_context): self.LOG.error(f"PluginManager:插件模块 {module_name} 初始化失败") + async_job.remove_jobs_by_owner(plugin) return None + self._inject_bot_to_plugin(plugin) # 注册插件 PluginRegistry().register(plugin) @@ -276,6 +424,7 @@ class PluginManager: # 添加模块名到显示名的映射 self.module_to_display[module_name] = display_name + self._refresh_module_file_state(module_name) # self.LOG.info(f"PluginManager:添加模块映射 {module_name} -> {display_name}") # 发布插件加载事件 @@ -299,18 +448,22 @@ class PluginManager: # 加载插件配置 if not plugin.load_config(): self.LOG.error(f"PluginManager:插件模块 {module_name} 加载配置失败") + async_job.remove_jobs_by_owner(plugin) return None # 修改检查enable状态的代码:遍历所有配置节点 for section in plugin._config.values(): if isinstance(section, dict) and not section.get("enable", True): self.LOG.debug(f"PluginManager:插件 {module_name} 已禁用,跳过加载") + async_job.remove_jobs_by_owner(plugin) return None # 初始化插件 if not plugin.initialize(self.system_context): self.LOG.error(f"PluginManager:插件模块 {module_name} 初始化失败") + async_job.remove_jobs_by_owner(plugin) return None + self._inject_bot_to_plugin(plugin) # 注册插件 PluginRegistry().register(plugin) @@ -323,6 +476,7 @@ class PluginManager: # 添加模块名到显示名的映射 self.module_to_display[module_name] = display_name + self._refresh_module_file_state(module_name) # self.LOG.info(f"PluginManager:添加模块映射 {module_name} -> {display_name}") # 发布插件加载事件 @@ -331,6 +485,9 @@ class PluginManager: return plugin except Exception as e: + plugin_obj = locals().get("plugin") + if plugin_obj is not None: + async_job.remove_jobs_by_owner(plugin_obj) self.LOG.exception(f"PluginManager:加载插件模块 {module_name} 失败: {e}", exc_info=True) return None @@ -358,6 +515,10 @@ class PluginManager: return False plugin.status = PluginStatus.STOPPED # 确保状态更新 + removed_jobs = async_job.remove_jobs_by_owner(plugin) + if removed_jobs: + self.LOG.debug(f"PluginManager:已移除插件 {display_name} 的定时任务 {removed_jobs} 个") + # 清理插件资源 if not plugin.cleanup(): self.LOG.debug(f"PluginManager:清理插件 {display_name} 资源失败") @@ -374,6 +535,7 @@ class PluginManager: if module_name and module_name in self.module_to_display: del self.module_to_display[module_name] self.LOG.debug(f"PluginManager:清理模块映射 {module_name} -> {display_name}") + self._module_file_state.pop(module_name, None) # 移除插件实例 del self.plugins[display_name] @@ -498,6 +660,7 @@ class PluginManager: 是否全部成功卸载 """ success = True + self.stop_hot_reload_watcher() # 创建插件名称的副本,因为在卸载过程中会修改self.plugins字典 display_names = list(self.plugins.keys()) @@ -574,6 +737,7 @@ class PluginManager: return None, None def inject_bot(self, bot: WechatAPIClient): + self.current_bot = bot self.system_context["bot"] = bot for name, plugin in self.plugins.items(): # self.LOG.debug(f"plugin name{name}, plugin: {plugin}") diff --git a/plugins/message_summary/main.py b/plugins/message_summary/main.py index e1807f3..50da87b 100644 --- a/plugins/message_summary/main.py +++ b/plugins/message_summary/main.py @@ -548,6 +548,9 @@ class MessageSummaryPlugin(MessagePluginInterface): async def daily_summary_job(self): """定时任务:每天早上9点总结昨天的聊天信息""" try: + if not self.bot: + self.LOG.warning("每日聊天总结任务跳过:bot 尚未注入") + return self.LOG.info("开始执行每日聊天总结任务") # 计算昨天的时间范围 diff --git a/robot.py b/robot.py index af77ed9..807ced1 100644 --- a/robot.py +++ b/robot.py @@ -98,6 +98,7 @@ class Robot: self.plugin_manager = PluginManager(plugin_dir=getattr(self.config, "plugin_dir", "plugins")) self.plugin_manager.set_system_context(self.system_context) self.plugins = self.plugin_manager.load_all_plugins() + self.plugin_manager.start_hot_reload_watcher(interval_seconds=2.0) # 加载插件 self.LOG.debug("插件系统初始化完成") diff --git a/utils/decorator/async_job.py b/utils/decorator/async_job.py index fab315f..a41a09f 100644 --- a/utils/decorator/async_job.py +++ b/utils/decorator/async_job.py @@ -1,20 +1,115 @@ import asyncio +import threading from datetime import datetime, timedelta -from typing import Callable, Awaitable, List, Dict +from typing import Callable, Awaitable, List, Dict, Optional, Any class AsyncJob: def __init__(self): - self.tasks: List[Callable[[], Awaitable]] = [] + self._jobs: Dict[str, Dict[str, Any]] = {} + self._running_tasks: Dict[str, asyncio.Task] = {} + self._running = False + self._loop: Optional[asyncio.AbstractEventLoop] = None + self._stop_event: Optional[asyncio.Event] = None + self._lock = threading.RLock() + self._job_seq = 0 + + def _next_job_id(self) -> str: + with self._lock: + self._job_seq += 1 + return f"job-{self._job_seq}" + + @staticmethod + def _infer_owner(func: Callable) -> Optional[Any]: + owner = getattr(func, "__self__", None) + if owner is not None: + return owner + + closure = getattr(func, "__closure__", None) or [] + for cell in closure: + try: + value = cell.cell_contents + except ValueError: + continue + if all(hasattr(value, attr) for attr in ("initialize", "start", "stop")): + return value + return None + + def _register(self, func: Callable, wrapper: Callable[[], Awaitable], trigger: str): + owner = self._infer_owner(func) + job_id = self._next_job_id() + with self._lock: + self._jobs[job_id] = { + "func": func, + "wrapper": wrapper, + "trigger": trigger, + "owner_id": id(owner) if owner is not None else None, + "owner_name": owner.__class__.__name__ if owner is not None else None, + } + if self._running and self._loop and self._loop.is_running(): + self._loop.call_soon_threadsafe(self._start_job_in_loop, job_id) + + def _start_job_in_loop(self, job_id: str): + job = self._jobs.get(job_id) + if not job or job_id in self._running_tasks: + return + + async def runner(): + try: + await job["wrapper"]() + except asyncio.CancelledError: + raise + except Exception as e: + print(f"[AsyncJob] 任务异常退出: {job_id}, trigger={job.get('trigger')}, error={e}") + + task = asyncio.create_task(runner(), name=f"async_job:{job_id}") + self._running_tasks[job_id] = task + task.add_done_callback(lambda _task, _id=job_id: self._running_tasks.pop(_id, None)) + + def _cancel_job_in_loop(self, job_id: str): + task = self._running_tasks.pop(job_id, None) + if task: + task.cancel() + + def remove_job(self, job_id: str) -> bool: + with self._lock: + existed = job_id in self._jobs or job_id in self._running_tasks + self._jobs.pop(job_id, None) + loop = self._loop + running = self._running + + if running and loop and loop.is_running(): + loop.call_soon_threadsafe(self._cancel_job_in_loop, job_id) + else: + task = self._running_tasks.pop(job_id, None) + if task: + task.cancel() + return existed + + def remove_jobs_by_owner(self, owner: Any) -> int: + owner_id = id(owner) + with self._lock: + job_ids = [job_id for job_id, meta in self._jobs.items() if meta.get("owner_id") == owner_id] + + removed = 0 + for job_id in job_ids: + if self.remove_job(job_id): + removed += 1 + return removed def every_seconds(self, seconds: int): def decorator(func: Callable): async def wrapper(): while True: - await func() + try: + await func() + except asyncio.CancelledError: + raise + except Exception as e: + print(f"[AsyncJob] every_seconds 任务执行异常: {e}") await asyncio.sleep(seconds) - self.tasks.append(wrapper) + self._register(func, wrapper, f"every_seconds({seconds})") return func return decorator @@ -27,18 +122,29 @@ class AsyncJob: def at_times(self, time_list: List[str]): def decorator(func: Callable): + parsed_times = [datetime.strptime(t, "%H:%M").time() for t in time_list] + async def wrapper(): while True: now = datetime.now() - for t in time_list: - target = datetime.strptime(t, "%H:%M").replace(year=now.year, month=now.month, day=now.day) - if target < now: + targets = [] + for t in parsed_times: + target = datetime.combine(now.date(), t) + if target <= now: target += timedelta(days=1) - wait_seconds = (target - now).total_seconds() - await asyncio.sleep(wait_seconds) - await func() + targets.append(target) - self.tasks.append(wrapper) + next_target = min(targets) + wait_seconds = (next_target - now).total_seconds() + await asyncio.sleep(max(wait_seconds, 0)) + try: + await func() + except asyncio.CancelledError: + raise + except Exception as e: + print(f"[AsyncJob] at_times 任务执行异常: {e}") + + self._register(func, wrapper, f"at_times({time_list})") return func return decorator @@ -64,9 +170,14 @@ class AsyncJob: sleep_secs = (target_dt - now).total_seconds() await asyncio.sleep(sleep_secs) - await func() + try: + await func() + except asyncio.CancelledError: + raise + except Exception as e: + print(f"[AsyncJob] every_weekday_time 任务执行异常: {e}") - self.tasks.append(wrapper) + self._register(func, wrapper, f"every_weekday_time({weekday}, {time_str})") return func return decorator @@ -91,9 +202,14 @@ class AsyncJob: sleep_secs = (target_dt - now).total_seconds() await asyncio.sleep(sleep_secs) - await func() + try: + await func() + except asyncio.CancelledError: + raise + except Exception as e: + print(f"[AsyncJob] every_week_time 任务执行异常: {e}") - self.tasks.append(wrapper) + self._register(func, wrapper, f"every_week_time({weekday}, {time_str})") return func return decorator @@ -128,15 +244,43 @@ class AsyncJob: sleep_secs = (target_dt - now).total_seconds() await asyncio.sleep(sleep_secs) - await func() + try: + await func() + except asyncio.CancelledError: + raise + except Exception as e: + print(f"[AsyncJob] every_month_last_day_time 任务执行异常: {e}") - self.tasks.append(wrapper) + self._register(func, wrapper, f"every_month_last_day_time({time_str})") return func return decorator async def run_all(self): - await asyncio.gather(*(task() for task in self.tasks)) + with self._lock: + self._running = True + self._loop = asyncio.get_running_loop() + self._stop_event = asyncio.Event() + job_ids = list(self._jobs.keys()) + + for job_id in job_ids: + self._start_job_in_loop(job_id) + + await self._stop_event.wait() + + def stop_all(self): + with self._lock: + self._running = False + loop = self._loop + self._jobs.clear() + self._job_seq = 0 + stop_event = self._stop_event + + if loop and loop.is_running(): + for job_id in list(self._running_tasks.keys()): + loop.call_soon_threadsafe(self._cancel_job_in_loop, job_id) + if stop_event: + loop.call_soon_threadsafe(stop_event.set) # 全局唯一 job 管理器