import importlib
import inspect
import os
import sys
import tomllib
import traceback
from typing import Dict, List, Optional, Type, Union

from loguru import logger

# WechatAPIClient is intentionally not imported: WechatHookBot does not need it.
# from WechatAPI import WechatAPIClient
from utils.singleton import Singleton

from .event_manager import EventManager
from .plugin_base import PluginBase


class PluginManager(metaclass=Singleton):
    """Discover, load, unload and reload plugins found under ``plugins/<dir>/main.py``.

    A plugin is any ``PluginBase`` subclass exposed by such a module. Loaded
    instances are bound to ``EventManager`` and tracked by their class name.
    """

    def __init__(self):
        # Loaded plugin instances, keyed by plugin class name.
        self.plugins: Dict[str, PluginBase] = {}
        # Classes of the currently loaded plugins, keyed by plugin class name.
        self.plugin_classes: Dict[str, Type[PluginBase]] = {}
        # Metadata for every plugin seen so far, whether loaded or not.
        self.plugin_info: Dict[str, dict] = {}
        # Bot client passed to each plugin's ``on_enable``; set via ``set_bot``.
        self.bot = None

        with open("main_config.toml", "rb") as f:
            main_config = tomllib.load(f)

        # Plugin class names and/or plugin directory names listed under
        # [Bot] disabled-plugins in main_config.toml.
        self.excluded_plugins = main_config.get("Bot", {}).get("disabled-plugins", [])

    def set_bot(self, bot):
        """Set the bot client (WechatHookClient) handed to plugins' ``on_enable``."""
        self.bot = bot

    async def load_plugin(self, plugin: Union[Type[PluginBase], str]) -> bool:
        """Load a single plugin given either its class or its class name.

        Args:
            plugin: A ``PluginBase`` subclass, or the class name to search for
                under ``plugins/``.

        Returns:
            True if the plugin was loaded; False otherwise, including when
            ``plugin`` is neither a string nor a ``PluginBase`` subclass.
        """
        if isinstance(plugin, str):
            return await self._load_plugin_name(plugin)
        if isinstance(plugin, type) and issubclass(plugin, PluginBase):
            return await self._load_plugin_class(plugin)
        # Previously this fell through and returned None despite the bool contract.
        logger.warning(f"无法加载插件,参数类型不支持: {plugin!r}")
        return False

    @staticmethod
    def _module_directory(plugin_class: type) -> Optional[str]:
        """Return ``<dir>`` when the class is defined in a ``plugins.<dir>...`` module, else None."""
        module_name = getattr(plugin_class, "__module__", "") or ""
        parts = module_name.split(".")
        if len(parts) >= 2 and parts[0] == "plugins":
            return parts[1]
        return None

    async def _load_plugin_class(self, plugin_class: Type[PluginBase], is_disabled: bool = False) -> bool:
        """Instantiate, bind and enable a single plugin class.

        Metadata is recorded in ``plugin_info`` even when the plugin is
        disabled, so disabled plugins are still listed.

        Args:
            plugin_class: The ``PluginBase`` subclass to load.
            is_disabled: When True, only record metadata and do not instantiate.

        Returns:
            True if the plugin was instantiated and enabled; False if it was
            already loaded, is disabled, or failed to initialise. On failure any
            partially-applied setup (event binding, ``on_enable``) is rolled back.
        """
        plugin = None
        bound = False
        enabled = False
        try:
            plugin_name = plugin_class.__name__

            # Never load the same plugin twice.
            if plugin_name in self.plugins:
                return False

            # Derive the plugin directory name from its module path, defensively.
            directory = "unknown"
            try:
                module_name = plugin_class.__module__
                if module_name.startswith("plugins."):
                    directory = module_name.split('.')[1]
                else:
                    logger.warning(f"非常规插件模块路径: {module_name}")
            except Exception as e:
                logger.error(f"获取插件目录失败: {e}")
                directory = "error"

            # Record metadata even for disabled plugins.
            self.plugin_info[plugin_name] = {
                "name": plugin_name,
                "description": plugin_class.description,
                "author": plugin_class.author,
                "version": plugin_class.version,
                "directory": directory,
                "enabled": False,
                "class": plugin_class,
            }

            if is_disabled:
                return False

            plugin = plugin_class()
            EventManager.bind_instance(plugin)
            bound = True
            await plugin.on_enable(self.bot)
            enabled = True
            await plugin.async_init()

            self.plugins[plugin_name] = plugin
            self.plugin_classes[plugin_name] = plugin_class
            self.plugin_info[plugin_name]["enabled"] = True
            logger.success(f"加载插件 {plugin_name} 成功")
            return True
        except Exception:
            # Log first: format_exc() must run while this exception is current.
            logger.error(f"加载插件时发生错误: {traceback.format_exc()}")
            await self._rollback_partial_load(plugin, bound, enabled)
            return False

    @staticmethod
    async def _rollback_partial_load(plugin, bound: bool, enabled: bool) -> None:
        """Undo the steps of a failed load so no half-initialised plugin stays bound."""
        if plugin is None:
            return
        # Mirror unload_plugin: on_disable first, then unbind from the event system.
        if enabled:
            try:
                await plugin.on_disable()
            except Exception:
                logger.error(f"回滚插件时调用 on_disable 发生错误: {traceback.format_exc()}")
        if bound:
            try:
                EventManager.unbind_instance(plugin)
            except Exception:
                logger.error(f"回滚插件时解绑事件发生错误: {traceback.format_exc()}")

    async def _load_plugin_name(self, plugin_name: str) -> bool:
        """Find a plugin class by class name under ``plugins/`` and load it.

        Every candidate ``plugins/<dir>/main.py`` module that is inspected is
        imported and then reloaded, so on-disk edits are picked up.

        Args:
            plugin_name: Plugin class name (not the file or directory name).

        Returns:
            bool: True if the plugin was found and loaded, False otherwise.
        """
        for dirname in os.listdir("plugins"):
            try:
                if not (os.path.isdir(f"plugins/{dirname}") and os.path.exists(f"plugins/{dirname}/main.py")):
                    continue
                module = importlib.import_module(f"plugins.{dirname}.main")
                importlib.reload(module)
                for _, obj in inspect.getmembers(module):
                    if (inspect.isclass(obj) and issubclass(obj, PluginBase)
                            and obj != PluginBase and obj.__name__ == plugin_name):
                        return await self._load_plugin_class(obj)
            except Exception:
                logger.error(f"检查 {dirname} 时发生错误: {traceback.format_exc()}")
                continue

        logger.warning(f"未找到插件类 {plugin_name}")
        # Previously returned None here despite the bool contract.
        return False

    def _resolve_load_order(self, plugin_classes: List[Type[PluginBase]]) -> List[Type[PluginBase]]:
        """Order plugin classes so dependencies come first, tie-broken by priority.

        Classes are visited in descending ``load_priority`` (default 50) and
        each is emitted after a depth-first visit of its ``dependencies``.
        Dependencies not present in ``plugin_classes`` are ignored here. A
        dependency cycle is logged and broken rather than raising.

        Args:
            plugin_classes: Candidate plugin classes; duplicates by class name
                collapse to the last one seen.

        Returns:
            The plugin classes in load order, one per class name.
        """
        # Map class name -> class (later duplicates override earlier ones).
        name_to_class = {cls.__name__: cls for cls in plugin_classes}

        # Dependency graph restricted to plugins that are actually present.
        dependencies = {}
        for cls in plugin_classes:
            deps = getattr(cls, 'dependencies', [])
            dependencies[cls.__name__] = [d for d in deps if d in name_to_class]

        sorted_names = []
        visited = set()
        temp_visited = set()

        def visit(name: str):
            if name in temp_visited:
                # Back edge: a dependency cycle; stop descending to break it.
                logger.warning(f"检测到循环依赖: {name}")
                return
            if name in visited:
                return
            temp_visited.add(name)
            for dep in dependencies.get(name, []):
                visit(dep)
            temp_visited.remove(name)
            visited.add(name)
            sorted_names.append(name)

        # Higher priority roots are visited first, so they load earlier when
        # dependencies allow.
        priority_sorted = sorted(
            plugin_classes,
            key=lambda cls: getattr(cls, 'load_priority', 50),
            reverse=True,
        )
        for cls in priority_sorted:
            if cls.__name__ not in visited:
                visit(cls.__name__)

        return [name_to_class[name] for name in sorted_names if name in name_to_class]

    async def load_plugins(self, load_disabled: bool = True) -> Union[List[str], bool]:
        """Load every plugin found under ``plugins/`` in dependency order.

        Args:
            load_disabled: When True, plugins listed in ``excluded_plugins`` are
                loaded anyway. When False, a plugin is treated as disabled if
                its class name or its home directory is excluded.

        Returns:
            Names of the plugins that were loaded by this call. Plugins whose
            dependencies are not loaded are skipped with a warning.
        """
        loaded_plugins: List[str] = []

        # Step 1: collect every PluginBase subclass exposed by plugins/<dir>/main.py.
        all_plugin_classes: List[Type[PluginBase]] = []
        plugin_disabled_map: Dict[str, bool] = {}
        for dirname in os.listdir("plugins"):
            if not (os.path.isdir(f"plugins/{dirname}") and os.path.exists(f"plugins/{dirname}/main.py")):
                continue
            try:
                module = importlib.import_module(f"plugins.{dirname}.main")
                for _, obj in inspect.getmembers(module):
                    if not (inspect.isclass(obj) and issubclass(obj, PluginBase) and obj != PluginBase):
                        continue
                    all_plugin_classes.append(obj)

                    is_disabled = False
                    if not load_disabled:
                        # Use the directory the class is *defined* in. A plugin
                        # class imported into another plugin's main.py would
                        # otherwise be judged by the importer's directory and
                        # could have its disabled flag overwritten to False.
                        home_dir = self._module_directory(obj) or dirname
                        is_disabled = obj.__name__ in self.excluded_plugins or home_dir in self.excluded_plugins
                    plugin_disabled_map[obj.__name__] = is_disabled
            except Exception:
                logger.error(f"加载 {dirname} 时发生错误: {traceback.format_exc()}")

        # Step 2: order by dependencies and priority.
        sorted_classes = self._resolve_load_order(all_plugin_classes)
        logger.info(f"插件加载顺序: {[cls.__name__ for cls in sorted_classes]}")

        # Step 3: load in order, skipping plugins whose dependencies did not load.
        for plugin_class in sorted_classes:
            plugin_name = plugin_class.__name__
            is_disabled = plugin_disabled_map.get(plugin_name, False)

            deps = getattr(plugin_class, 'dependencies', [])
            missing_deps = [dep for dep in deps if dep not in self.plugins]
            if missing_deps and not is_disabled:
                logger.warning(f"插件 {plugin_name} 的依赖未满足: {missing_deps},跳过加载")
                continue

            if await self._load_plugin_class(plugin_class, is_disabled=is_disabled):
                loaded_plugins.append(plugin_name)

        return loaded_plugins

    async def unload_plugin(self, plugin_name: str) -> bool:
        """Disable and unbind one loaded plugin.

        ManagePlugin is never unloaded.

        Returns:
            True on success; False if the plugin is not loaded, is
            ManagePlugin, or raised during ``on_disable``/unbinding.
        """
        if plugin_name not in self.plugins:
            return False

        if plugin_name == "ManagePlugin":
            logger.warning("ManagePlugin 不能被卸载")
            return False

        try:
            plugin = self.plugins[plugin_name]
            await plugin.on_disable()
            EventManager.unbind_instance(plugin)
            del self.plugins[plugin_name]
            self.plugin_classes.pop(plugin_name, None)
            if plugin_name in self.plugin_info:
                self.plugin_info[plugin_name]["enabled"] = False
            logger.success(f"卸载插件 {plugin_name} 成功")
            return True
        except Exception:
            logger.error(f"卸载插件 {plugin_name} 时发生错误: {traceback.format_exc()}")
            return False

    async def unload_plugins(self) -> tuple[List[str], List[str]]:
        """Try to unload every loaded plugin.

        Returns:
            ``(unloaded, failed)`` name lists. ManagePlugin, if loaded, always
            ends up in ``failed`` because ``unload_plugin`` refuses it.
        """
        unloaded_plugins = []
        failed_unloads = []
        # Snapshot the keys: unload_plugin mutates self.plugins.
        for plugin_name in list(self.plugins.keys()):
            if await self.unload_plugin(plugin_name):
                unloaded_plugins.append(plugin_name)
            else:
                failed_unloads.append(plugin_name)
        return unloaded_plugins, failed_unloads

    async def reload_plugin(self, plugin_name: str) -> bool:
        """Unload a plugin, reload its module, and load the fresh class.

        ManagePlugin is never reloaded.

        Returns:
            True if the reloaded plugin loaded successfully, False otherwise.
            Note that a failure after the unload step leaves the plugin unloaded.
        """
        if plugin_name not in self.plugin_classes:
            return False

        if plugin_name == "ManagePlugin":
            logger.warning("ManagePlugin 不能被重载")
            return False

        try:
            plugin_class = self.plugin_classes[plugin_name]
            module_name = plugin_class.__module__

            if not await self.unload_plugin(plugin_name):
                return False

            module = importlib.import_module(module_name)
            importlib.reload(module)

            for _, obj in inspect.getmembers(module):
                if (inspect.isclass(obj) and issubclass(obj, PluginBase)
                        and obj != PluginBase and obj.__name__ == plugin_name):
                    # Load the freshly reloaded class, not the stale one.
                    return await self.load_plugin(obj)

            logger.warning(f"重载后未找到插件类 {plugin_name}")
            return False
        except Exception:
            # Include the traceback, consistent with the other error paths.
            logger.error(f"重载插件 {plugin_name} 时发生错误: {traceback.format_exc()}")
            return False

    async def reload_plugins(self) -> List[str]:
        """Unload all plugins except ManagePlugin, purge their modules, and reload.

        Returns:
            List[str]: Names of plugins loaded by the final ``load_plugins``
            call (ManagePlugin is already loaded and therefore not included).
        """
        try:
            original_plugins = [name for name in self.plugins.keys() if name != "ManagePlugin"]

            for plugin_name in original_plugins:
                await self.unload_plugin(plugin_name)

            # Keep ManagePlugin's module tree in sys.modules. The suffix check
            # alone never matched "plugins.<dir>.main", so it used to be purged.
            keep_prefixes = set()
            manage_class = self.plugin_classes.get("ManagePlugin")
            if manage_class is not None:
                manage_dir = self._module_directory(manage_class)
                if manage_dir is not None:
                    keep_prefixes.add(f"plugins.{manage_dir}")

            for module_name in list(sys.modules.keys()):
                if not module_name.startswith('plugins.'):
                    continue
                if module_name.endswith('ManagePlugin'):
                    continue
                if any(module_name == p or module_name.startswith(p + ".") for p in keep_prefixes):
                    continue
                del sys.modules[module_name]

            # NOTE(review): load_plugins() defaults to load_disabled=True, so a
            # full reload also loads plugins listed in disabled-plugins — confirm
            # this is intended.
            return await self.load_plugins()
        except Exception:
            logger.error(f"重载所有插件时发生错误: {traceback.format_exc()}")
            return []

    async def refresh_plugins(self):
        """Rescan ``plugins/`` and add metadata for newly discovered plugin classes.

        Existing ``plugin_info`` entries are left untouched; nothing is loaded.
        Directories whose names are not valid identifiers are skipped.
        """
        for dirname in os.listdir("plugins"):
            try:
                dirpath = f"plugins/{dirname}"
                if not (os.path.isdir(dirpath) and os.path.exists(f"{dirpath}/main.py")):
                    continue
                if not dirname.isidentifier():
                    logger.warning(f"跳过非法插件目录名: {dirname}")
                    continue

                module = importlib.import_module(f"plugins.{dirname}.main")
                importlib.reload(module)

                for _, obj in inspect.getmembers(module):
                    if not (inspect.isclass(obj) and issubclass(obj, PluginBase) and obj != PluginBase):
                        continue
                    if obj.__name__ in self.plugin_info:
                        continue
                    self.plugin_info[obj.__name__] = {
                        "name": obj.__name__,
                        "description": obj.description,
                        "author": obj.author,
                        "version": obj.version,
                        "directory": dirname,
                        "enabled": False,
                        "class": obj,
                    }
            except Exception:
                logger.error(f"检查 {dirname} 时发生错误: {traceback.format_exc()}")
                continue

    def get_plugin_info(self, plugin_name: Optional[str] = None) -> Union[Optional[dict], List[dict]]:
        """Return plugin metadata.

        Args:
            plugin_name: Plugin class name; when falsy, metadata for all plugins
                is returned.

        Returns:
            The metadata dict for ``plugin_name`` (None if unknown), or a list
            of all metadata dicts when no name is given.
        """
        if plugin_name:
            return self.plugin_info.get(plugin_name)
        return list(self.plugin_info.values())