import importlib
import inspect
import os
import sys
import traceback
from typing import Dict, Type, List, Union

from loguru import logger

# from WechatAPI import WechatAPIClient  # Commented out: WechatHookBot does not need this import
from utils.singleton import Singleton
from utils.config_manager import get_bot_config
from utils.llm_tooling import register_plugin_tools, unregister_plugin_tools

from .event_manager import EventManager
from .plugin_base import PluginBase


class PluginManager(metaclass=Singleton):
    """Discover, load, unload and reload plugins found under ``plugins/``.

    A plugin is a ``PluginBase`` subclass exposed by ``plugins/<dir>/main.py``.
    Loading runs ``on_load`` -> ``async_init`` -> (bind event handlers) ->
    ``on_enable`` and registers the plugin's LLM tools; unloading runs
    ``on_disable`` -> (unbind handlers, unregister tools) -> ``on_unload``.
    """

    def __init__(self):
        # Currently loaded (enabled) plugin instances, keyed by class name.
        self.plugins: Dict[str, PluginBase] = {}
        # Classes of the currently loaded plugins, keyed by class name.
        self.plugin_classes: Dict[str, Type[PluginBase]] = {}
        # Metadata of every discovered plugin, including disabled ones.
        self.plugin_info: Dict[str, dict] = {}
        self.bot = None

        # Use the unified config manager to read the list of disabled plugins.
        bot_config = get_bot_config()
        self.excluded_plugins = bot_config.get("disabled-plugins", [])

    def set_bot(self, bot):
        """Set the bot client (WechatHookClient) passed to plugins' ``on_enable``."""
        self.bot = bot

    @staticmethod
    def _is_plugin_class(obj) -> bool:
        """Return True if ``obj`` is a ``PluginBase`` subclass other than ``PluginBase`` itself."""
        return inspect.isclass(obj) and issubclass(obj, PluginBase) and obj != PluginBase

    @staticmethod
    def _plugin_directory(plugin_class) -> str:
        """Return the ``plugins/<dir>`` name a plugin class was imported from.

        Returns ``"unknown"`` for modules outside the ``plugins.`` package and
        ``"error"`` if the module name cannot be inspected.
        """
        try:
            module_name = plugin_class.__module__
            if module_name.startswith("plugins."):
                return module_name.split('.')[1]
            logger.warning(f"非常规插件模块路径: {module_name}")
            return "unknown"
        except Exception as e:
            logger.error(f"获取插件目录失败: {e}")
            return "error"

    async def load_plugin(self, plugin: Union[Type[PluginBase], str]) -> bool:
        """Load a plugin given either its class or its class name.

        Args:
            plugin: A ``PluginBase`` subclass, or the class name to search for
                under ``plugins/``.

        Returns:
            bool: True if the plugin was loaded; False otherwise, including
            when ``plugin`` is of an unsupported type.
        """
        if isinstance(plugin, str):
            return await self._load_plugin_name(plugin)
        if isinstance(plugin, type) and issubclass(plugin, PluginBase):
            return await self._load_plugin_class(plugin)
        logger.warning(f"不支持的插件类型: {plugin!r}")
        return False

    async def _load_plugin_class(self, plugin_class: Type[PluginBase], is_disabled: bool = False) -> bool:
        """Load a single plugin from its class (accepts ``Type[PluginBase]``).

        The plugin's metadata is recorded in ``plugin_info`` even when it is
        disabled. If a lifecycle hook fails after the event handlers were
        bound, the handlers are unbound again so the half-loaded instance
        does not keep receiving events.

        Args:
            plugin_class: The plugin class to instantiate.
            is_disabled: When True, only the metadata is recorded.

        Returns:
            bool: True if the plugin was loaded and enabled; False if it was
            already loaded, is disabled, or failed to load.
        """
        plugin = None
        handlers_bound = False
        plugin_name = None
        try:
            plugin_name = plugin_class.__name__

            # Prevent loading the same plugin twice.
            if plugin_name in self.plugins:
                return False

            directory = self._plugin_directory(plugin_class)

            # Record the plugin info even if the plugin is disabled.
            self.plugin_info[plugin_name] = {
                "name": plugin_name,
                "description": plugin_class.description,
                "author": plugin_class.author,
                "version": plugin_class.version,
                "directory": directory,
                "enabled": False,
                "class": plugin_class,
            }

            # Disabled plugins are not instantiated.
            if is_disabled:
                return False

            plugin = plugin_class()
            # Lifecycle: on_load (may access other plugins).
            await plugin.on_load(self)
            # Lifecycle: async_init (load config and resources).
            await plugin.async_init()
            # Bind event handlers.
            EventManager.bind_instance(plugin)
            handlers_bound = True
            # Lifecycle: on_enable (register scheduled tasks).
            await plugin.on_enable(self.bot)

            # Register with the plugin manager.
            self.plugins[plugin_name] = plugin
            self.plugin_classes[plugin_name] = plugin_class
            self.plugin_info[plugin_name]["enabled"] = True

            # Register the plugin's LLM tools in the global registry; a failure
            # here is logged but does not undo the load.
            try:
                tools_config = self._get_tools_config()
                timeout_config = self._get_timeout_config()
                register_plugin_tools(plugin_name, plugin, tools_config, timeout_config)
            except Exception as e:
                logger.warning(f"注册插件 {plugin_name} 的工具时出错: {e}")

            logger.success(f"加载插件 {plugin_name} 成功")
            return True
        except Exception:
            logger.error(f"加载插件时发生错误: {traceback.format_exc()}")
            # Roll back the handler binding of a plugin that never got registered.
            if handlers_bound and self.plugins.get(plugin_name) is not plugin:
                try:
                    EventManager.unbind_instance(plugin)
                except Exception as e:
                    logger.warning(f"回滚插件 {plugin_name} 的事件绑定时出错: {e}")
            return False

    async def _load_plugin_name(self, plugin_name: str) -> bool:
        """Load a single plugin from the ``plugins`` directory by class name.

        Every ``plugins/<dir>/main.py`` is imported and reloaded (to pick up
        code changes) until a matching plugin class is found.

        Args:
            plugin_name: Plugin class name (not the file name).

        Returns:
            bool: Whether the plugin was loaded successfully; False if no class
            with that name was found.
        """
        for dirname in os.listdir("plugins"):
            try:
                if not (os.path.isdir(f"plugins/{dirname}") and os.path.exists(f"plugins/{dirname}/main.py")):
                    continue
                module = importlib.import_module(f"plugins.{dirname}.main")
                importlib.reload(module)
                for _, obj in inspect.getmembers(module):
                    if self._is_plugin_class(obj) and obj.__name__ == plugin_name:
                        return await self._load_plugin_class(obj)
            except Exception:
                logger.error(f"检查 {dirname} 时发生错误: {traceback.format_exc()}")

        logger.warning(f"未找到插件类 {plugin_name}")
        return False

    def _resolve_load_order(self, plugin_classes: List[Type[PluginBase]]) -> List[Type[PluginBase]]:
        """Resolve the plugin load order (topological sort + priority ordering).

        Roots are visited in descending ``load_priority`` (default 50), and
        each plugin's in-set ``dependencies`` are emitted before the plugin
        itself. Dependencies that are not in ``plugin_classes`` are ignored
        here. A cycle is logged and broken at the point it is detected.

        Args:
            plugin_classes: Plugin classes to order.

        Returns:
            The plugin classes sorted by dependencies and priority.
        """
        # Map plugin name -> class.
        name_to_class = {cls.__name__: cls for cls in plugin_classes}

        # Build the dependency graph, restricted to known plugins.
        dependencies = {}
        for cls in plugin_classes:
            deps = getattr(cls, 'dependencies', [])
            dependencies[cls.__name__] = [d for d in deps if d in name_to_class]

        # Depth-first topological sort.
        sorted_names = []
        visited = set()
        temp_visited = set()

        def visit(name: str):
            if name in temp_visited:
                # Circular dependency detected.
                logger.warning(f"检测到循环依赖: {name}")
                return
            if name in visited:
                return

            temp_visited.add(name)
            # Visit dependencies first.
            for dep in dependencies.get(name, []):
                visit(dep)
            temp_visited.remove(name)

            visited.add(name)
            sorted_names.append(name)

        # Order roots by priority before running the topological sort.
        priority_sorted = sorted(
            plugin_classes,
            key=lambda cls: getattr(cls, 'load_priority', 50),
            reverse=True
        )

        for cls in priority_sorted:
            if cls.__name__ not in visited:
                visit(cls.__name__)

        # Map the ordered names back to classes.
        return [name_to_class[name] for name in sorted_names if name in name_to_class]

    async def load_plugins(self, load_disabled: bool = True) -> Union[List[str], bool]:
        """Load all plugins under ``plugins/`` in dependency order.

        Args:
            load_disabled: When False, plugins whose class name or directory
                is in ``disabled-plugins`` are only recorded in
                ``plugin_info``. When True (the default), the disabled list is
                not consulted.

        Returns:
            List[str]: Names of the plugins that were loaded successfully.
        """
        loaded_plugins = []

        # Step 1: collect all plugin classes.
        all_plugin_classes = []
        plugin_disabled_map = {}

        for dirname in os.listdir("plugins"):
            if not (os.path.isdir(f"plugins/{dirname}") and os.path.exists(f"plugins/{dirname}/main.py")):
                continue
            try:
                module = importlib.import_module(f"plugins.{dirname}.main")
                for _, obj in inspect.getmembers(module):
                    if not self._is_plugin_class(obj):
                        continue
                    all_plugin_classes.append(obj)
                    # Record whether the plugin is disabled.
                    is_disabled = False
                    if not load_disabled:
                        is_disabled = obj.__name__ in self.excluded_plugins or dirname in self.excluded_plugins
                    plugin_disabled_map[obj.__name__] = is_disabled
            except Exception:
                logger.error(f"加载 {dirname} 时发生错误: {traceback.format_exc()}")

        # Step 2: sort by dependencies.
        sorted_classes = self._resolve_load_order(all_plugin_classes)
        logger.info(f"插件加载顺序: {[cls.__name__ for cls in sorted_classes]}")

        # Step 3: load the plugins in order.
        for plugin_class in sorted_classes:
            plugin_name = plugin_class.__name__
            is_disabled = plugin_disabled_map.get(plugin_name, False)

            # Skip plugins whose dependencies did not load.
            deps = getattr(plugin_class, 'dependencies', [])
            deps_satisfied = all(dep in self.plugins for dep in deps)

            if not deps_satisfied and not is_disabled:
                missing_deps = [dep for dep in deps if dep not in self.plugins]
                logger.warning(f"插件 {plugin_name} 的依赖未满足: {missing_deps},跳过加载")
                continue

            if await self._load_plugin_class(plugin_class, is_disabled=is_disabled):
                loaded_plugins.append(plugin_name)

        return loaded_plugins

    async def unload_plugin(self, plugin_name: str) -> bool:
        """Unload a single plugin.

        Returns:
            bool: True on success; False if the plugin is not loaded, is
            ManagePlugin, or a teardown step raised.
        """
        if plugin_name not in self.plugins:
            return False

        # ManagePlugin must never be unloaded.
        if plugin_name == "ManagePlugin":
            logger.warning("ManagePlugin 不能被卸载")
            return False

        try:
            plugin = self.plugins[plugin_name]
            # Lifecycle: on_disable (clean up scheduled tasks).
            await plugin.on_disable()
            # Unbind event handlers.
            EventManager.unbind_instance(plugin)

            # Unregister the plugin's tools from the global registry.
            try:
                unregister_plugin_tools(plugin_name)
            except Exception as e:
                logger.warning(f"注销插件 {plugin_name} 的工具时出错: {e}")

            # Lifecycle: on_unload (release resources).
            await plugin.on_unload()

            del self.plugins[plugin_name]
            del self.plugin_classes[plugin_name]
            if plugin_name in self.plugin_info:
                self.plugin_info[plugin_name]["enabled"] = False
            logger.success(f"卸载插件 {plugin_name} 成功")
            return True
        except Exception:
            logger.error(f"卸载插件 {plugin_name} 时发生错误: {traceback.format_exc()}")
            return False

    async def unload_plugins(self) -> tuple[List[str], List[str]]:
        """Unload all plugins.

        Returns:
            tuple: (names unloaded successfully, names that failed to unload).
        """
        unloaded_plugins = []
        failed_unloads = []
        for plugin_name in list(self.plugins.keys()):
            if await self.unload_plugin(plugin_name):
                unloaded_plugins.append(plugin_name)
            else:
                failed_unloads.append(plugin_name)
        return unloaded_plugins, failed_unloads

    async def reload_plugin(self, plugin_name: str) -> bool:
        """Reload a single plugin, saving and restoring its state.

        The state returned by ``on_reload()`` is passed to the new instance's
        ``restore_state()``.

        Returns:
            bool: True if the plugin was reloaded; False otherwise.
        """
        if plugin_name not in self.plugin_classes:
            return False

        # ManagePlugin must never be reloaded.
        if plugin_name == "ManagePlugin":
            logger.warning("ManagePlugin 不能被重载")
            return False

        try:
            # Module that defines the plugin class.
            plugin_class = self.plugin_classes[plugin_name]
            module_name = plugin_class.__module__

            # Lifecycle: on_reload (save state).
            saved_state = {}
            if plugin_name in self.plugins:
                try:
                    saved_state = await self.plugins[plugin_name].on_reload()
                except Exception as e:
                    logger.warning(f"保存插件 {plugin_name} 状态失败: {e}")

            if not await self.unload_plugin(plugin_name):
                return False

            # Re-import the module to pick up code changes.
            module = importlib.import_module(module_name)
            importlib.reload(module)

            # Find the plugin class in the reloaded module.
            for _, obj in inspect.getmembers(module):
                if self._is_plugin_class(obj) and obj.__name__ == plugin_name:
                    if await self.load_plugin(obj):
                        # Restore the saved state.
                        if saved_state and plugin_name in self.plugins:
                            try:
                                await self.plugins[plugin_name].restore_state(saved_state)
                            except Exception as e:
                                logger.warning(f"恢复插件 {plugin_name} 状态失败: {e}")
                        return True

            logger.warning(f"重载插件 {plugin_name} 失败")
            return False
        except Exception:
            logger.error(f"重载插件 {plugin_name} 时发生错误: {traceback.format_exc()}")
            return False

    async def reload_plugins(self) -> List[str]:
        """Reload all plugins except ManagePlugin.

        Returns:
            List[str]: Names of the plugins that were reloaded successfully.
        """
        try:
            # Remember the loaded plugins, excluding ManagePlugin.
            original_plugins = [name for name in self.plugins.keys() if name != "ManagePlugin"]

            # Unload every plugin except ManagePlugin.
            for plugin_name in original_plugins:
                await self.unload_plugin(plugin_name)

            # Packages whose modules must survive the purge: ManagePlugin's own
            # package (e.g. plugins.ManagePlugin and plugins.ManagePlugin.main).
            manage_dir = self.plugin_info.get("ManagePlugin", {}).get("directory")
            protected_packages = {"plugins.ManagePlugin"}
            if manage_dir and manage_dir not in ("unknown", "error"):
                protected_packages.add(f"plugins.{manage_dir}")

            # Drop cached plugin modules so they are re-imported from disk.
            for module_name in list(sys.modules.keys()):
                if not module_name.startswith('plugins.'):
                    continue
                if module_name.endswith('ManagePlugin') or any(
                    module_name == pkg or module_name.startswith(pkg + ".")
                    for pkg in protected_packages
                ):
                    continue
                del sys.modules[module_name]

            # Load plugins from the directory again.
            return await self.load_plugins()
        except Exception:
            logger.error(f"重载所有插件时发生错误: {traceback.format_exc()}")
            return []

    async def refresh_plugins(self):
        """Rescan ``plugins/`` and record newly found plugin classes in ``plugin_info``.

        Plugins already present in ``plugin_info`` are left untouched; new ones
        are recorded as not enabled. No plugin is loaded.
        """
        for dirname in os.listdir("plugins"):
            try:
                dirpath = f"plugins/{dirname}"
                if not (os.path.isdir(dirpath) and os.path.exists(f"{dirpath}/main.py")):
                    continue
                # The directory name must be a valid module name.
                if not dirname.isidentifier():
                    logger.warning(f"跳过非法插件目录名: {dirname}")
                    continue

                module = importlib.import_module(f"plugins.{dirname}.main")
                importlib.reload(module)

                for _, obj in inspect.getmembers(module):
                    if self._is_plugin_class(obj) and obj.__name__ not in self.plugin_info:
                        self.plugin_info[obj.__name__] = {
                            "name": obj.__name__,
                            "description": obj.description,
                            "author": obj.author,
                            "version": obj.version,
                            "directory": dirname,
                            "enabled": False,
                            "class": obj,
                        }
            except Exception:
                logger.error(f"检查 {dirname} 时发生错误: {traceback.format_exc()}")

    def get_plugin_info(self, plugin_name: str = None) -> Union[dict, List[dict]]:
        """Return plugin info.

        Args:
            plugin_name: Plugin name; if None (or empty), info for all plugins
                is returned.

        Returns:
            The info dict for ``plugin_name`` (None if unknown), or a list of
            all plugin info dicts.
        """
        if plugin_name:
            return self.plugin_info.get(plugin_name)
        return list(self.plugin_info.values())

    @staticmethod
    def _read_aichat_config() -> dict | None:
        """Parse ``plugins/AIChat/config.toml``.

        Returns:
            The parsed TOML document, or None if the file does not exist or
            cannot be read/parsed (the failure is logged).
        """
        try:
            import tomllib
            from pathlib import Path

            aichat_config_path = Path("plugins/AIChat/config.toml")
            if not aichat_config_path.exists():
                return None
            with open(aichat_config_path, "rb") as f:
                return tomllib.load(f)
        except Exception as e:
            logger.warning(f"读取 AIChat 配置失败: {e}")
            return None

    def _get_tools_config(self) -> dict:
        """Return the tool configuration used for tool registration.

        Read from the ``tools`` table of the AIChat plugin config; falls back to
        ``{"mode": "all", "whitelist": [], "blacklist": []}`` when the config
        file is missing or unreadable.
        """
        aichat_config = self._read_aichat_config()
        if aichat_config is not None:
            return aichat_config.get("tools", {})
        return {"mode": "all", "whitelist": [], "blacklist": []}

    def _get_timeout_config(self) -> dict:
        """Return the tool timeout configuration.

        Read from ``tools.timeout`` in the AIChat plugin config; falls back to
        ``{"default": 60}`` when the file is missing/unreadable or ``tools`` is
        not a table.
        """
        aichat_config = self._read_aichat_config()
        if aichat_config is not None:
            tools = aichat_config.get("tools", {})
            if isinstance(tools, dict):
                return tools.get("timeout", {})
        return {"default": 60}