Files
WechatHookBot/utils/plugin_manager.py

441 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import importlib
import inspect
import os
import sys
import traceback
from typing import Dict, Type, List, Union
from loguru import logger
# from WechatAPI import WechatAPIClient # 注释掉WechatHookBot 不需要这个导入
from utils.singleton import Singleton
from utils.config_manager import get_bot_config
from utils.llm_tooling import register_plugin_tools, unregister_plugin_tools
from .event_manager import EventManager
from .plugin_base import PluginBase
class PluginManager(metaclass=Singleton):
    """Discover, load, unload and reload bot plugins found under ``plugins/``.

    A plugin lives in ``plugins/<dirname>/main.py`` and is a subclass of
    ``PluginBase``. The manager keeps three views of the plugin set:

    * ``plugins``        - plugin class name -> live (enabled) plugin instance
    * ``plugin_classes`` - plugin class name -> class of every live plugin
    * ``plugin_info``    - plugin class name -> metadata dict for every plugin
      discovered so far, including disabled ones (see the ``"enabled"`` key)
    """

    def __init__(self):
        self.plugins: Dict[str, PluginBase] = {}
        self.plugin_classes: Dict[str, Type[PluginBase]] = {}
        # Metadata for every discovered plugin, enabled or not.
        self.plugin_info: Dict[str, dict] = {}
        self.bot = None
        # Policy used by the most recent load_plugins() call, so that
        # reload_plugins() re-applies the same disabled-plugin handling.
        self._last_load_disabled: bool = True
        # Plugin class names or directory names disabled in the bot config.
        bot_config = get_bot_config()
        self.excluded_plugins = bot_config.get("disabled-plugins", [])

    def set_bot(self, bot):
        """Set the bot client (WechatHookClient) handed to plugins' ``on_enable``."""
        self.bot = bot

    async def load_plugin(self, plugin: Union[Type[PluginBase], str]) -> bool:
        """Load one plugin, given either its class or its class name.

        Args:
            plugin: A ``PluginBase`` subclass, or the class name to look up
                under ``plugins/``.

        Returns:
            True if the plugin was loaded. False otherwise, including when
            ``plugin`` is neither a ``str`` nor a ``PluginBase`` subclass.
        """
        if isinstance(plugin, str):
            return await self._load_plugin_name(plugin)
        if isinstance(plugin, type) and issubclass(plugin, PluginBase):
            return await self._load_plugin_class(plugin)
        # Unsupported argument: report failure explicitly instead of returning None.
        logger.warning(f"不支持的插件类型: {plugin!r}")
        return False

    async def _load_plugin_class(self, plugin_class: Type[PluginBase],
                                 is_disabled: bool = False) -> bool:
        """Instantiate, initialise and register a single plugin class.

        Metadata is always written to ``plugin_info``, even for disabled
        plugins, so that they still show up in listings.

        Args:
            plugin_class: The ``PluginBase`` subclass to load.
            is_disabled: If True, only record metadata and do not instantiate.

        Returns:
            True if the plugin was instantiated and registered. False if it was
            already loaded, is disabled, or any lifecycle step raised.
        """
        plugin = None
        handlers_bound = False
        try:
            plugin_name = plugin_class.__name__
            # Refuse to load the same plugin twice.
            if plugin_name in self.plugins:
                return False

            # Derive the plugin directory from the module path "plugins.<dir>...".
            directory = "unknown"
            try:
                module_name = plugin_class.__module__
                if module_name.startswith("plugins."):
                    directory = module_name.split('.')[1]
                else:
                    logger.warning(f"非常规插件模块路径: {module_name}")
            except Exception as e:
                logger.error(f"获取插件目录失败: {e}")
                directory = "error"

            # Record metadata even when the plugin is disabled.
            self.plugin_info[plugin_name] = {
                "name": plugin_name,
                "description": plugin_class.description,
                "author": plugin_class.author,
                "version": plugin_class.version,
                "directory": directory,
                "enabled": False,
                "class": plugin_class
            }

            if is_disabled:
                return False

            plugin = plugin_class()
            # Lifecycle: on_load receives the manager (to reach other plugins).
            await plugin.on_load(self)
            # Lifecycle: async_init, per the original comment loads config/resources.
            await plugin.async_init()
            EventManager.bind_instance(plugin)
            handlers_bound = True
            # Lifecycle: on_enable receives the bot client (original comment:
            # registers scheduled jobs).
            await plugin.on_enable(self.bot)

            # Register with the manager only after every lifecycle step succeeded.
            self.plugins[plugin_name] = plugin
            self.plugin_classes[plugin_name] = plugin_class
            self.plugin_info[plugin_name]["enabled"] = True

            # Register the plugin's LLM tools in the global registry (best effort).
            try:
                tools_config = self._get_tools_config()
                timeout_config = self._get_timeout_config()
                register_plugin_tools(plugin_name, plugin, tools_config, timeout_config)
            except Exception as e:
                logger.warning(f"注册插件 {plugin_name} 的工具时出错: {e}")

            logger.success(f"加载插件 {plugin_name} 成功")
            return True
        except Exception:
            logger.error(f"加载插件时发生错误: {traceback.format_exc()}")
            # A half-loaded plugin must not keep receiving events: its handlers
            # were bound but the plugin never made it into the registry.
            if handlers_bound:
                try:
                    EventManager.unbind_instance(plugin)
                except Exception as e:
                    logger.warning(f"解绑插件事件处理器失败: {e}")
            return False

    async def _load_plugin_name(self, plugin_name: str) -> bool:
        """Find a plugin class by class name under ``plugins/`` and load it.

        Every ``plugins/<dir>/main.py`` is imported and reloaded in turn until
        a ``PluginBase`` subclass named ``plugin_name`` is found.

        Args:
            plugin_name: The plugin's class name (not its file or directory name).

        Returns:
            True if the plugin was loaded. False if it was not found or
            failed to load.
        """
        for dirname in os.listdir("plugins"):
            try:
                if not (os.path.isdir(f"plugins/{dirname}")
                        and os.path.exists(f"plugins/{dirname}/main.py")):
                    continue
                module = importlib.import_module(f"plugins.{dirname}.main")
                importlib.reload(module)
                for _, obj in inspect.getmembers(module, inspect.isclass):
                    if (issubclass(obj, PluginBase)
                            and obj is not PluginBase
                            and obj.__name__ == plugin_name):
                        return await self._load_plugin_class(obj)
            except Exception:
                logger.error(f"检查 {dirname} 时发生错误: {traceback.format_exc()}")
                continue

        logger.warning(f"未找到插件类 {plugin_name}")
        return False

    def _resolve_load_order(self, plugin_classes: List[Type[PluginBase]]) -> List[Type[PluginBase]]:
        """Order plugin classes by dependencies, then by priority.

        Classes are visited in descending ``load_priority`` (default 50). A
        depth-first topological sort then makes each plugin come after the
        entries in its ``dependencies`` list that are also present in
        ``plugin_classes``. When a cycle is found a warning is logged and the
        cycle is broken at that point.

        Args:
            plugin_classes: Plugin classes to order.

        Returns:
            The same classes, ordered by dependency and priority.
        """
        name_to_class = {cls.__name__: cls for cls in plugin_classes}

        # Dependency graph: keep only dependencies that are among the candidates.
        dependencies = {}
        for cls in plugin_classes:
            deps = getattr(cls, 'dependencies', [])
            dependencies[cls.__name__] = [d for d in deps if d in name_to_class]

        sorted_names = []
        visited = set()
        temp_visited = set()

        def visit(name: str):
            if name in temp_visited:
                # Circular dependency: stop descending here.
                logger.warning(f"检测到循环依赖: {name}")
                return
            if name in visited:
                return
            temp_visited.add(name)
            # Emit dependencies before the plugin itself.
            for dep in dependencies.get(name, []):
                visit(dep)
            temp_visited.remove(name)
            visited.add(name)
            sorted_names.append(name)

        # Higher priority plugins are visited (and so emitted) first.
        priority_sorted = sorted(
            plugin_classes,
            key=lambda cls: getattr(cls, 'load_priority', 50),
            reverse=True
        )
        for cls in priority_sorted:
            if cls.__name__ not in visited:
                visit(cls.__name__)

        return [name_to_class[name] for name in sorted_names if name in name_to_class]

    async def load_plugins(self, load_disabled: bool = True) -> Union[List[str], bool]:
        """Discover and load every plugin under ``plugins/`` in dependency order.

        Args:
            load_disabled: If True, plugins listed in the config's
                ``disabled-plugins`` are loaded anyway. If False, they are only
                recorded in ``plugin_info`` as disabled. The value is remembered
                and reused by ``reload_plugins``.

        Returns:
            Names of the plugins that were actually loaded.
        """
        self._last_load_disabled = load_disabled
        loaded_plugins = []

        # Step 1: collect every plugin class and whether it is disabled.
        all_plugin_classes = []
        plugin_disabled_map = {}
        for dirname in os.listdir("plugins"):
            if not (os.path.isdir(f"plugins/{dirname}")
                    and os.path.exists(f"plugins/{dirname}/main.py")):
                continue
            try:
                module = importlib.import_module(f"plugins.{dirname}.main")
                for _, obj in inspect.getmembers(module, inspect.isclass):
                    if issubclass(obj, PluginBase) and obj is not PluginBase:
                        all_plugin_classes.append(obj)
                        is_disabled = False
                        if not load_disabled:
                            is_disabled = (obj.__name__ in self.excluded_plugins
                                           or dirname in self.excluded_plugins)
                        plugin_disabled_map[obj.__name__] = is_disabled
            except Exception:
                logger.error(f"加载 {dirname} 时发生错误: {traceback.format_exc()}")

        # Step 2: sort by dependency and priority.
        sorted_classes = self._resolve_load_order(all_plugin_classes)
        logger.info(f"插件加载顺序: {[cls.__name__ for cls in sorted_classes]}")

        # Step 3: load in order, skipping enabled plugins whose dependencies failed.
        for plugin_class in sorted_classes:
            plugin_name = plugin_class.__name__
            is_disabled = plugin_disabled_map.get(plugin_name, False)

            deps = getattr(plugin_class, 'dependencies', [])
            missing_deps = [dep for dep in deps if dep not in self.plugins]
            if missing_deps and not is_disabled:
                logger.warning(f"插件 {plugin_name} 的依赖未满足: {missing_deps},跳过加载")
                continue

            if await self._load_plugin_class(plugin_class, is_disabled=is_disabled):
                loaded_plugins.append(plugin_name)

        return loaded_plugins

    async def unload_plugin(self, plugin_name: str) -> bool:
        """Unload one loaded plugin. ManagePlugin can never be unloaded.

        Order: ``on_disable`` -> unbind event handlers -> unregister LLM tools
        (best effort) -> ``on_unload`` -> drop it from the registry.

        Returns:
            True on success. False if the plugin is not loaded, is ManagePlugin,
            or a lifecycle hook raised.
        """
        if plugin_name not in self.plugins:
            return False

        if plugin_name == "ManagePlugin":
            logger.warning("ManagePlugin 不能被卸载")
            return False

        try:
            plugin = self.plugins[plugin_name]
            # Lifecycle: on_disable (original comment: clears scheduled jobs).
            await plugin.on_disable()
            EventManager.unbind_instance(plugin)

            try:
                unregister_plugin_tools(plugin_name)
            except Exception as e:
                logger.warning(f"注销插件 {plugin_name} 的工具时出错: {e}")

            # Lifecycle: on_unload (original comment: releases resources).
            await plugin.on_unload()

            del self.plugins[plugin_name]
            del self.plugin_classes[plugin_name]
            if plugin_name in self.plugin_info:
                self.plugin_info[plugin_name]["enabled"] = False

            logger.success(f"卸载插件 {plugin_name} 成功")
            return True
        except Exception:
            logger.error(f"卸载插件 {plugin_name} 时发生错误: {traceback.format_exc()}")
            return False

    async def unload_plugins(self) -> tuple[List[str], List[str]]:
        """Unload every loaded plugin.

        Returns:
            ``(unloaded, failed)`` lists of plugin names. ManagePlugin always
            lands in ``failed`` because it refuses to unload.
        """
        unloaded_plugins = []
        failed_unloads = []
        # Iterate over a snapshot: unload_plugin mutates self.plugins.
        for plugin_name in list(self.plugins):
            if await self.unload_plugin(plugin_name):
                unloaded_plugins.append(plugin_name)
            else:
                failed_unloads.append(plugin_name)
        return unloaded_plugins, failed_unloads

    async def reload_plugin(self, plugin_name: str) -> bool:
        """Reload one plugin from freshly re-imported code, carrying its state over.

        State returned by the old instance's ``on_reload`` is passed to the new
        instance's ``restore_state``. ManagePlugin can never be reloaded.

        Returns:
            True if the new plugin instance was loaded, False otherwise.
        """
        if plugin_name not in self.plugin_classes:
            return False

        if plugin_name == "ManagePlugin":
            logger.warning("ManagePlugin 不能被重载")
            return False

        try:
            plugin_class = self.plugin_classes[plugin_name]
            module_name = plugin_class.__module__

            # Lifecycle: on_reload captures state to restore after reloading.
            saved_state = {}
            if plugin_name in self.plugins:
                try:
                    saved_state = await self.plugins[plugin_name].on_reload()
                except Exception as e:
                    logger.warning(f"保存插件 {plugin_name} 状态失败: {e}")

            if not await self.unload_plugin(plugin_name):
                return False

            module = importlib.import_module(module_name)
            importlib.reload(module)

            # Pick the freshly defined class with the same name.
            for _, obj in inspect.getmembers(module, inspect.isclass):
                if (issubclass(obj, PluginBase)
                        and obj is not PluginBase
                        and obj.__name__ == plugin_name):
                    if not await self.load_plugin(obj):
                        return False
                    if saved_state and plugin_name in self.plugins:
                        try:
                            await self.plugins[plugin_name].restore_state(saved_state)
                        except Exception as e:
                            logger.warning(f"恢复插件 {plugin_name} 状态失败: {e}")
                    return True

            logger.warning(f"重载后未找到插件类 {plugin_name}")
            return False
        except Exception:
            logger.error(f"重载插件 {plugin_name} 时发生错误: {traceback.format_exc()}")
            return False

    async def reload_plugins(self) -> List[str]:
        """Reload every plugin except ManagePlugin from freshly imported code.

        Returns:
            List[str]: Names of the plugins that were loaded again.
        """
        try:
            original_plugins = [name for name in self.plugins if name != "ManagePlugin"]
            for plugin_name in original_plugins:
                await self.unload_plugin(plugin_name)

            # Purge cached plugin modules so load_plugins() re-imports new code.
            # Match on path components so that every ManagePlugin module, including
            # the "plugins.ManagePlugin.main" submodule, is kept.
            for module_name in list(sys.modules):
                if (module_name.startswith('plugins.')
                        and "ManagePlugin" not in module_name.split('.')):
                    del sys.modules[module_name]

            # Re-apply the same disabled-plugin policy as the last full load, so
            # that plugins disabled in config are not enabled by a reload.
            return await self.load_plugins(load_disabled=self._last_load_disabled)
        except Exception:
            logger.error(f"重载所有插件时发生错误: {traceback.format_exc()}")
            return []

    async def refresh_plugins(self):
        """Re-scan ``plugins/`` and record metadata for newly seen plugin classes.

        Nothing is instantiated. Every plugin module is re-imported and
        reloaded, and classes not yet in ``plugin_info`` are added as disabled.
        Directories whose names are not valid identifiers are skipped.
        """
        for dirname in os.listdir("plugins"):
            try:
                dirpath = f"plugins/{dirname}"
                if not (os.path.isdir(dirpath) and os.path.exists(f"{dirpath}/main.py")):
                    continue
                # The directory name becomes part of an import path.
                if not dirname.isidentifier():
                    logger.warning(f"跳过非法插件目录名: {dirname}")
                    continue
                module = importlib.import_module(f"plugins.{dirname}.main")
                importlib.reload(module)
                for _, obj in inspect.getmembers(module, inspect.isclass):
                    if (issubclass(obj, PluginBase)
                            and obj is not PluginBase
                            and obj.__name__ not in self.plugin_info):
                        self.plugin_info[obj.__name__] = {
                            "name": obj.__name__,
                            "description": obj.description,
                            "author": obj.author,
                            "version": obj.version,
                            "directory": dirname,
                            "enabled": False,
                            "class": obj
                        }
            except Exception:
                logger.error(f"检查 {dirname} 时发生错误: {traceback.format_exc()}")
                continue

    def get_plugin_info(self, plugin_name: Union[str, None] = None) -> Union[dict, List[dict], None]:
        """Return plugin metadata.

        Args:
            plugin_name: Plugin class name. If None or empty, return all plugins.

        Returns:
            The metadata dict for ``plugin_name`` (None if unknown), or a list
            of every plugin's metadata dict when no name is given.
        """
        if plugin_name:
            return self.plugin_info.get(plugin_name)
        return list(self.plugin_info.values())

    def _read_aichat_config(self) -> Union[dict, None]:
        """Best-effort parse of ``plugins/AIChat/config.toml``.

        Returns:
            The parsed TOML document. None if the file is missing or cannot be
            read or parsed; failures are logged rather than raised.
        """
        try:
            # Imported lazily so a missing tomllib only disables this feature.
            import tomllib
            from pathlib import Path

            aichat_config_path = Path("plugins/AIChat/config.toml")
            if not aichat_config_path.exists():
                return None
            with aichat_config_path.open("rb") as f:
                return tomllib.load(f)
        except Exception as e:
            logger.warning(f"读取 AIChat 配置失败: {e}")
            return None

    def _get_tools_config(self) -> dict:
        """Return the AIChat ``[tools]`` section used for tool registration.

        Falls back to "all tools, no lists" when the AIChat config is unavailable.
        """
        aichat_config = self._read_aichat_config()
        if aichat_config is not None:
            return aichat_config.get("tools", {})
        return {"mode": "all", "whitelist": [], "blacklist": []}

    def _get_timeout_config(self) -> dict:
        """Return the AIChat ``[tools.timeout]`` section.

        Falls back to ``{"default": 60}`` when the config is unavailable or
        malformed.
        """
        aichat_config = self._read_aichat_config()
        tools = aichat_config.get("tools", {}) if aichat_config is not None else None
        if isinstance(tools, dict):
            return tools.get("timeout", {})
        return {"default": 60}