feat(plugin): support auto bot injection and file-based hot reload

This commit is contained in:
liuwei
2026-04-16 13:54:56 +08:00
parent 041a3f30d8
commit f0414e0dff
4 changed files with 330 additions and 18 deletions

View File

@@ -2,6 +2,8 @@ import importlib
import inspect import inspect
import os import os
import sys import sys
import threading
import time
from typing import Dict, List, Any, Optional, Tuple from typing import Dict, List, Any, Optional, Tuple
from loguru import logger from loguru import logger
@@ -11,6 +13,7 @@ from base.plugin_common.message_plugin_interface import MessagePluginInterface
from base.plugin_common.scheduled_plugin_interface import ScheduledPluginInterface from base.plugin_common.scheduled_plugin_interface import ScheduledPluginInterface
from base.plugin_common.plugin_registry import PluginRegistry from base.plugin_common.plugin_registry import PluginRegistry
from base.plugin_common.event_system import EventSystem, EventType from base.plugin_common.event_system import EventSystem, EventType
from utils.decorator.async_job import async_job
from wechat_ipad import WechatAPIClient from wechat_ipad import WechatAPIClient
@@ -53,6 +56,14 @@ class PluginManager:
self.plugin_modules = {} # 插件模块字典键为module_name self.plugin_modules = {} # 插件模块字典键为module_name
self.module_to_display = {} # 模块名到显示名的映射 self.module_to_display = {} # 模块名到显示名的映射
self.system_context = {} # 系统上下文 self.system_context = {} # 系统上下文
self.current_bot: Optional[WechatAPIClient] = None
# 热加载相关
self._watcher_thread: Optional[threading.Thread] = None
self._watcher_stop_event = threading.Event()
self._watcher_interval = 2.0
self._module_file_state: Dict[str, Dict[str, float]] = {}
self._watcher_lock = threading.RLock()
self.LOG = logger self.LOG = logger
# 确保插件目录存在 # 确保插件目录存在
@@ -71,6 +82,139 @@ class PluginManager:
context: 系统上下文 context: 系统上下文
""" """
self.system_context = context self.system_context = context
bot = context.get("bot")
if bot is not None:
self.current_bot = bot
def _build_module_file_state(self, module_name: str) -> Optional[Dict[str, float]]:
"""
构建模块的文件状态快照,用于检测文件变更
"""
plugin_folder = os.path.join(self.plugin_dir, module_name)
file_state: Dict[str, float] = {}
if os.path.isdir(plugin_folder):
for root, _, files in os.walk(plugin_folder):
for filename in files:
if not (filename.endswith(".py") or filename.endswith(".toml")):
continue
file_path = os.path.join(root, filename)
try:
file_state[os.path.abspath(file_path)] = os.path.getmtime(file_path)
except OSError:
continue
return file_state if file_state else None
single_file = os.path.join(self.plugin_dir, f"{module_name}.py")
if os.path.exists(single_file):
try:
file_state[os.path.abspath(single_file)] = os.path.getmtime(single_file)
except OSError:
return None
return file_state
return None
def _refresh_module_file_state(self, module_name: str):
state = self._build_module_file_state(module_name)
if state is None:
self._module_file_state.pop(module_name, None)
else:
self._module_file_state[module_name] = state
def _inject_bot_to_plugin(self, plugin: PluginInterface):
bot = self.current_bot or self.system_context.get("bot")
if not bot:
return
if hasattr(plugin, "set_bot"):
try:
plugin.set_bot(bot)
except Exception as e:
self.LOG.error(f"自动注入 bot 到插件 {plugin.name} 失败: {e}")
def start_hot_reload_watcher(self, interval_seconds: float = 2.0):
"""
启动插件目录变更监听线程(轮询)
"""
with self._watcher_lock:
if self._watcher_thread and self._watcher_thread.is_alive():
self.LOG.debug("PluginManager热加载监听线程已运行跳过重复启动")
return
self._watcher_interval = max(float(interval_seconds), 0.5)
self._watcher_stop_event.clear()
# 初始化快照
for module_name in self.discover_plugins():
self._refresh_module_file_state(module_name)
self._watcher_thread = threading.Thread(
target=self._hot_reload_watch_loop,
name="plugin-hot-reload-watcher",
daemon=True,
)
self._watcher_thread.start()
self.LOG.info(f"PluginManager插件热加载监听已启动轮询间隔 {self._watcher_interval}s")
def stop_hot_reload_watcher(self):
"""
停止插件目录变更监听线程
"""
with self._watcher_lock:
if not self._watcher_thread:
return
self._watcher_stop_event.set()
thread = self._watcher_thread
self._watcher_thread = None
if thread.is_alive():
thread.join(timeout=2.0)
self.LOG.info("PluginManager插件热加载监听已停止")
def _hot_reload_watch_loop(self):
while not self._watcher_stop_event.is_set():
try:
discovered = set(self.discover_plugins())
loaded_modules = set(self.module_to_display.keys())
# 1. 新增插件 -> 自动加载并启动
new_modules = discovered - loaded_modules
for module_name in new_modules:
plugin = self.load_plugin(module_name)
if plugin:
self.start_plugin(plugin.name)
self.LOG.info(f"PluginManager检测到新增插件 {module_name},已自动加载")
self._refresh_module_file_state(module_name)
# 2. 已删除插件 -> 自动卸载
removed_modules = loaded_modules - discovered
for module_name in removed_modules:
if self.unload_plugin(module_name):
self.LOG.info(f"PluginManager检测到插件 {module_name} 已删除,已自动卸载")
self._module_file_state.pop(module_name, None)
# 3. 文件变更 -> 自动重载
for module_name in list(self.module_to_display.keys()):
new_state = self._build_module_file_state(module_name)
old_state = self._module_file_state.get(module_name)
if new_state is None:
continue
if old_state is None:
self._module_file_state[module_name] = new_state
continue
if new_state != old_state:
reloaded = self.reload_plugin(module_name)
if reloaded:
self.LOG.info(f"PluginManager检测到插件 {module_name} 文件变更,已自动重载")
self._module_file_state[module_name] = new_state
else:
self.LOG.warning(f"PluginManager插件 {module_name} 自动重载失败")
except Exception as e:
self.LOG.error(f"PluginManager热加载监听异常: {e}", exc_info=True)
time.sleep(self._watcher_interval)
def discover_plugins(self) -> List[str]: def discover_plugins(self) -> List[str]:
""" """
@@ -207,6 +351,7 @@ class PluginManager:
if module_name not in self.module_to_display: if module_name not in self.module_to_display:
self.module_to_display[module_name] = display_name self.module_to_display[module_name] = display_name
self.LOG.debug(f"PluginManager添加缺失的模块映射 {module_name} -> {display_name}") self.LOG.debug(f"PluginManager添加缺失的模块映射 {module_name} -> {display_name}")
self._inject_bot_to_plugin(plugin)
return plugin return plugin
except Exception as e: except Exception as e:
self.LOG.warning(f"获取插件 {display_name} 的模块名时出错: {e}") self.LOG.warning(f"获取插件 {display_name} 的模块名时出错: {e}")
@@ -258,12 +403,15 @@ class PluginManager:
# 加载插件配置 # 加载插件配置
if not plugin.load_config(): if not plugin.load_config():
self.LOG.error(f"PluginManager插件模块 {module_name} 加载配置失败") self.LOG.error(f"PluginManager插件模块 {module_name} 加载配置失败")
async_job.remove_jobs_by_owner(plugin)
return None return None
# 初始化插件 # 初始化插件
if not plugin.initialize(self.system_context): if not plugin.initialize(self.system_context):
self.LOG.error(f"PluginManager插件模块 {module_name} 初始化失败") self.LOG.error(f"PluginManager插件模块 {module_name} 初始化失败")
async_job.remove_jobs_by_owner(plugin)
return None return None
self._inject_bot_to_plugin(plugin)
# 注册插件 # 注册插件
PluginRegistry().register(plugin) PluginRegistry().register(plugin)
@@ -276,6 +424,7 @@ class PluginManager:
# 添加模块名到显示名的映射 # 添加模块名到显示名的映射
self.module_to_display[module_name] = display_name self.module_to_display[module_name] = display_name
self._refresh_module_file_state(module_name)
# self.LOG.info(f"PluginManager添加模块映射 {module_name} -> {display_name}") # self.LOG.info(f"PluginManager添加模块映射 {module_name} -> {display_name}")
# 发布插件加载事件 # 发布插件加载事件
@@ -299,18 +448,22 @@ class PluginManager:
# 加载插件配置 # 加载插件配置
if not plugin.load_config(): if not plugin.load_config():
self.LOG.error(f"PluginManager插件模块 {module_name} 加载配置失败") self.LOG.error(f"PluginManager插件模块 {module_name} 加载配置失败")
async_job.remove_jobs_by_owner(plugin)
return None return None
# 修改检查enable状态的代码遍历所有配置节点 # 修改检查enable状态的代码遍历所有配置节点
for section in plugin._config.values(): for section in plugin._config.values():
if isinstance(section, dict) and not section.get("enable", True): if isinstance(section, dict) and not section.get("enable", True):
self.LOG.debug(f"PluginManager插件 {module_name} 已禁用,跳过加载") self.LOG.debug(f"PluginManager插件 {module_name} 已禁用,跳过加载")
async_job.remove_jobs_by_owner(plugin)
return None return None
# 初始化插件 # 初始化插件
if not plugin.initialize(self.system_context): if not plugin.initialize(self.system_context):
self.LOG.error(f"PluginManager插件模块 {module_name} 初始化失败") self.LOG.error(f"PluginManager插件模块 {module_name} 初始化失败")
async_job.remove_jobs_by_owner(plugin)
return None return None
self._inject_bot_to_plugin(plugin)
# 注册插件 # 注册插件
PluginRegistry().register(plugin) PluginRegistry().register(plugin)
@@ -323,6 +476,7 @@ class PluginManager:
# 添加模块名到显示名的映射 # 添加模块名到显示名的映射
self.module_to_display[module_name] = display_name self.module_to_display[module_name] = display_name
self._refresh_module_file_state(module_name)
# self.LOG.info(f"PluginManager添加模块映射 {module_name} -> {display_name}") # self.LOG.info(f"PluginManager添加模块映射 {module_name} -> {display_name}")
# 发布插件加载事件 # 发布插件加载事件
@@ -331,6 +485,9 @@ class PluginManager:
return plugin return plugin
except Exception as e: except Exception as e:
plugin_obj = locals().get("plugin")
if plugin_obj is not None:
async_job.remove_jobs_by_owner(plugin_obj)
self.LOG.exception(f"PluginManager加载插件模块 {module_name} 失败: {e}", exc_info=True) self.LOG.exception(f"PluginManager加载插件模块 {module_name} 失败: {e}", exc_info=True)
return None return None
@@ -358,6 +515,10 @@ class PluginManager:
return False return False
plugin.status = PluginStatus.STOPPED # 确保状态更新 plugin.status = PluginStatus.STOPPED # 确保状态更新
removed_jobs = async_job.remove_jobs_by_owner(plugin)
if removed_jobs:
self.LOG.debug(f"PluginManager已移除插件 {display_name} 的定时任务 {removed_jobs}")
# 清理插件资源 # 清理插件资源
if not plugin.cleanup(): if not plugin.cleanup():
self.LOG.debug(f"PluginManager清理插件 {display_name} 资源失败") self.LOG.debug(f"PluginManager清理插件 {display_name} 资源失败")
@@ -374,6 +535,7 @@ class PluginManager:
if module_name and module_name in self.module_to_display: if module_name and module_name in self.module_to_display:
del self.module_to_display[module_name] del self.module_to_display[module_name]
self.LOG.debug(f"PluginManager清理模块映射 {module_name} -> {display_name}") self.LOG.debug(f"PluginManager清理模块映射 {module_name} -> {display_name}")
self._module_file_state.pop(module_name, None)
# 移除插件实例 # 移除插件实例
del self.plugins[display_name] del self.plugins[display_name]
@@ -498,6 +660,7 @@ class PluginManager:
是否全部成功卸载 是否全部成功卸载
""" """
success = True success = True
self.stop_hot_reload_watcher()
# 创建插件名称的副本因为在卸载过程中会修改self.plugins字典 # 创建插件名称的副本因为在卸载过程中会修改self.plugins字典
display_names = list(self.plugins.keys()) display_names = list(self.plugins.keys())
@@ -574,6 +737,7 @@ class PluginManager:
return None, None return None, None
def inject_bot(self, bot: WechatAPIClient): def inject_bot(self, bot: WechatAPIClient):
self.current_bot = bot
self.system_context["bot"] = bot self.system_context["bot"] = bot
for name, plugin in self.plugins.items(): for name, plugin in self.plugins.items():
# self.LOG.debug(f"plugin name{name}, plugin: {plugin}") # self.LOG.debug(f"plugin name{name}, plugin: {plugin}")

View File

@@ -548,6 +548,9 @@ class MessageSummaryPlugin(MessagePluginInterface):
async def daily_summary_job(self): async def daily_summary_job(self):
"""定时任务每天早上9点总结昨天的聊天信息""" """定时任务每天早上9点总结昨天的聊天信息"""
try: try:
if not self.bot:
self.LOG.warning("每日聊天总结任务跳过bot 尚未注入")
return
self.LOG.info("开始执行每日聊天总结任务") self.LOG.info("开始执行每日聊天总结任务")
# 计算昨天的时间范围 # 计算昨天的时间范围

View File

@@ -98,6 +98,7 @@ class Robot:
self.plugin_manager = PluginManager(plugin_dir=getattr(self.config, "plugin_dir", "plugins")) self.plugin_manager = PluginManager(plugin_dir=getattr(self.config, "plugin_dir", "plugins"))
self.plugin_manager.set_system_context(self.system_context) self.plugin_manager.set_system_context(self.system_context)
self.plugins = self.plugin_manager.load_all_plugins() self.plugins = self.plugin_manager.load_all_plugins()
self.plugin_manager.start_hot_reload_watcher(interval_seconds=2.0)
# 加载插件 # 加载插件
self.LOG.debug("插件系统初始化完成") self.LOG.debug("插件系统初始化完成")

View File

@@ -1,20 +1,115 @@
import asyncio import asyncio
import threading
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Callable, Awaitable, List, Dict from typing import Callable, Awaitable, List, Dict, Optional, Any
class AsyncJob: class AsyncJob:
def __init__(self): def __init__(self):
self.tasks: List[Callable[[], Awaitable]] = [] self._jobs: Dict[str, Dict[str, Any]] = {}
self._running_tasks: Dict[str, asyncio.Task] = {}
self._running = False
self._loop: Optional[asyncio.AbstractEventLoop] = None
self._stop_event: Optional[asyncio.Event] = None
self._lock = threading.RLock()
self._job_seq = 0
def _next_job_id(self) -> str:
with self._lock:
self._job_seq += 1
return f"job-{self._job_seq}"
@staticmethod
def _infer_owner(func: Callable) -> Optional[Any]:
owner = getattr(func, "__self__", None)
if owner is not None:
return owner
closure = getattr(func, "__closure__", None) or []
for cell in closure:
try:
value = cell.cell_contents
except ValueError:
continue
if all(hasattr(value, attr) for attr in ("initialize", "start", "stop")):
return value
return None
def _register(self, func: Callable, wrapper: Callable[[], Awaitable], trigger: str):
owner = self._infer_owner(func)
job_id = self._next_job_id()
with self._lock:
self._jobs[job_id] = {
"func": func,
"wrapper": wrapper,
"trigger": trigger,
"owner_id": id(owner) if owner is not None else None,
"owner_name": owner.__class__.__name__ if owner is not None else None,
}
if self._running and self._loop and self._loop.is_running():
self._loop.call_soon_threadsafe(self._start_job_in_loop, job_id)
def _start_job_in_loop(self, job_id: str):
job = self._jobs.get(job_id)
if not job or job_id in self._running_tasks:
return
async def runner():
try:
await job["wrapper"]()
except asyncio.CancelledError:
raise
except Exception as e:
print(f"[AsyncJob] 任务异常退出: {job_id}, trigger={job.get('trigger')}, error={e}")
task = asyncio.create_task(runner(), name=f"async_job:{job_id}")
self._running_tasks[job_id] = task
task.add_done_callback(lambda _task, _id=job_id: self._running_tasks.pop(_id, None))
def _cancel_job_in_loop(self, job_id: str):
task = self._running_tasks.pop(job_id, None)
if task:
task.cancel()
def remove_job(self, job_id: str) -> bool:
with self._lock:
existed = job_id in self._jobs or job_id in self._running_tasks
self._jobs.pop(job_id, None)
loop = self._loop
running = self._running
if running and loop and loop.is_running():
loop.call_soon_threadsafe(self._cancel_job_in_loop, job_id)
else:
task = self._running_tasks.pop(job_id, None)
if task:
task.cancel()
return existed
def remove_jobs_by_owner(self, owner: Any) -> int:
owner_id = id(owner)
with self._lock:
job_ids = [job_id for job_id, meta in self._jobs.items() if meta.get("owner_id") == owner_id]
removed = 0
for job_id in job_ids:
if self.remove_job(job_id):
removed += 1
return removed
def every_seconds(self, seconds: int): def every_seconds(self, seconds: int):
def decorator(func: Callable): def decorator(func: Callable):
async def wrapper(): async def wrapper():
while True: while True:
try:
await func() await func()
except asyncio.CancelledError:
raise
except Exception as e:
print(f"[AsyncJob] every_seconds 任务执行异常: {e}")
await asyncio.sleep(seconds) await asyncio.sleep(seconds)
self.tasks.append(wrapper) self._register(func, wrapper, f"every_seconds({seconds})")
return func return func
return decorator return decorator
@@ -27,18 +122,29 @@ class AsyncJob:
def at_times(self, time_list: List[str]): def at_times(self, time_list: List[str]):
def decorator(func: Callable): def decorator(func: Callable):
parsed_times = [datetime.strptime(t, "%H:%M").time() for t in time_list]
async def wrapper(): async def wrapper():
while True: while True:
now = datetime.now() now = datetime.now()
for t in time_list: targets = []
target = datetime.strptime(t, "%H:%M").replace(year=now.year, month=now.month, day=now.day) for t in parsed_times:
if target < now: target = datetime.combine(now.date(), t)
if target <= now:
target += timedelta(days=1) target += timedelta(days=1)
wait_seconds = (target - now).total_seconds() targets.append(target)
await asyncio.sleep(wait_seconds)
await func()
self.tasks.append(wrapper) next_target = min(targets)
wait_seconds = (next_target - now).total_seconds()
await asyncio.sleep(max(wait_seconds, 0))
try:
await func()
except asyncio.CancelledError:
raise
except Exception as e:
print(f"[AsyncJob] at_times 任务执行异常: {e}")
self._register(func, wrapper, f"at_times({time_list})")
return func return func
return decorator return decorator
@@ -64,9 +170,14 @@ class AsyncJob:
sleep_secs = (target_dt - now).total_seconds() sleep_secs = (target_dt - now).total_seconds()
await asyncio.sleep(sleep_secs) await asyncio.sleep(sleep_secs)
try:
await func() await func()
except asyncio.CancelledError:
raise
except Exception as e:
print(f"[AsyncJob] every_weekday_time 任务执行异常: {e}")
self.tasks.append(wrapper) self._register(func, wrapper, f"every_weekday_time({weekday}, {time_str})")
return func return func
return decorator return decorator
@@ -91,9 +202,14 @@ class AsyncJob:
sleep_secs = (target_dt - now).total_seconds() sleep_secs = (target_dt - now).total_seconds()
await asyncio.sleep(sleep_secs) await asyncio.sleep(sleep_secs)
try:
await func() await func()
except asyncio.CancelledError:
raise
except Exception as e:
print(f"[AsyncJob] every_week_time 任务执行异常: {e}")
self.tasks.append(wrapper) self._register(func, wrapper, f"every_week_time({weekday}, {time_str})")
return func return func
return decorator return decorator
@@ -128,15 +244,43 @@ class AsyncJob:
sleep_secs = (target_dt - now).total_seconds() sleep_secs = (target_dt - now).total_seconds()
await asyncio.sleep(sleep_secs) await asyncio.sleep(sleep_secs)
try:
await func() await func()
except asyncio.CancelledError:
raise
except Exception as e:
print(f"[AsyncJob] every_month_last_day_time 任务执行异常: {e}")
self.tasks.append(wrapper) self._register(func, wrapper, f"every_month_last_day_time({time_str})")
return func return func
return decorator return decorator
async def run_all(self): async def run_all(self):
await asyncio.gather(*(task() for task in self.tasks)) with self._lock:
self._running = True
self._loop = asyncio.get_running_loop()
self._stop_event = asyncio.Event()
job_ids = list(self._jobs.keys())
for job_id in job_ids:
self._start_job_in_loop(job_id)
await self._stop_event.wait()
def stop_all(self):
with self._lock:
self._running = False
loop = self._loop
self._jobs.clear()
self._job_seq = 0
stop_event = self._stop_event
if loop and loop.is_running():
for job_id in list(self._running_tasks.keys()):
loop.call_soon_threadsafe(self._cancel_job_in_loop, job_id)
if stop_event:
loop.call_soon_threadsafe(stop_event.set)
# 全局唯一 job 管理器 # 全局唯一 job 管理器