feat:初版
This commit is contained in:
20
utils/__init__.py
Normal file
20
utils/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
"""
|
||||
Utils - 工具类模块
|
||||
|
||||
包含插件系统、事件管理、装饰器等核心工具
|
||||
"""
|
||||
|
||||
from .plugin_base import PluginBase
|
||||
from .plugin_manager import PluginManager
|
||||
from .event_manager import EventManager
|
||||
from .decorators import *
|
||||
from .singleton import Singleton
|
||||
from .hookbot import HookBot
|
||||
|
||||
__all__ = [
|
||||
'PluginBase',
|
||||
'PluginManager',
|
||||
'EventManager',
|
||||
'Singleton',
|
||||
'HookBot',
|
||||
]
|
||||
257
utils/decorators.py
Normal file
257
utils/decorators.py
Normal file
@@ -0,0 +1,257 @@
|
||||
from functools import wraps
|
||||
from typing import Callable, Union
|
||||
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from apscheduler.triggers.interval import IntervalTrigger
|
||||
|
||||
scheduler = AsyncIOScheduler()
|
||||
|
||||
|
||||
def schedule(
|
||||
trigger: Union[str, CronTrigger, IntervalTrigger],
|
||||
**trigger_args
|
||||
) -> Callable:
|
||||
"""
|
||||
定时任务装饰器
|
||||
|
||||
例子:
|
||||
|
||||
- @schedule('interval', seconds=30)
|
||||
- @schedule('cron', hour=8, minute=30, second=30)
|
||||
- @schedule('date', run_date='2024-01-01 00:00:00')
|
||||
"""
|
||||
|
||||
def decorator(func: Callable):
|
||||
job_id = f"{func.__module__}.{func.__qualname__}"
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(self, *args, **kwargs):
|
||||
return await func(self, *args, **kwargs)
|
||||
|
||||
setattr(wrapper, '_is_scheduled', True)
|
||||
setattr(wrapper, '_schedule_trigger', trigger)
|
||||
setattr(wrapper, '_schedule_args', trigger_args)
|
||||
setattr(wrapper, '_job_id', job_id)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def add_job_safe(scheduler: AsyncIOScheduler, job_id: str, func: Callable, bot,
|
||||
trigger: Union[str, CronTrigger, IntervalTrigger], **trigger_args):
|
||||
"""添加函数到定时任务中,如果存在则先删除现有的任务"""
|
||||
try:
|
||||
scheduler.remove_job(job_id)
|
||||
except:
|
||||
pass
|
||||
|
||||
# 读取调度器配置
|
||||
try:
|
||||
import tomllib
|
||||
from pathlib import Path
|
||||
config_path = Path("main_config.toml")
|
||||
if config_path.exists():
|
||||
with open(config_path, "rb") as f:
|
||||
config = tomllib.load(f)
|
||||
scheduler_config = config.get("Scheduler", {})
|
||||
else:
|
||||
scheduler_config = {}
|
||||
except:
|
||||
scheduler_config = {}
|
||||
|
||||
# 应用调度器配置
|
||||
job_kwargs = {
|
||||
"coalesce": scheduler_config.get("coalesce", True),
|
||||
"max_instances": scheduler_config.get("max_instances", 1),
|
||||
"misfire_grace_time": scheduler_config.get("misfire_grace_time", 30)
|
||||
}
|
||||
job_kwargs.update(trigger_args)
|
||||
|
||||
scheduler.add_job(func, trigger, args=[bot], id=job_id, **job_kwargs)
|
||||
|
||||
|
||||
def remove_job_safe(scheduler: AsyncIOScheduler, job_id: str):
|
||||
"""从定时任务中移除任务"""
|
||||
try:
|
||||
scheduler.remove_job(job_id)
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
def on_text_message(priority=50):
|
||||
"""文本消息装饰器"""
|
||||
|
||||
def decorator(func):
|
||||
if callable(priority): # 无参数调用时
|
||||
f = priority
|
||||
setattr(f, '_event_type', 'text_message')
|
||||
setattr(f, '_priority', 50)
|
||||
return f
|
||||
# 有参数调用时
|
||||
setattr(func, '_event_type', 'text_message')
|
||||
setattr(func, '_priority', min(max(priority, 0), 99))
|
||||
return func
|
||||
|
||||
return decorator if not callable(priority) else decorator(priority)
|
||||
|
||||
|
||||
def on_image_message(priority=50):
|
||||
"""图片消息装饰器"""
|
||||
|
||||
def decorator(func):
|
||||
if callable(priority):
|
||||
f = priority
|
||||
setattr(f, '_event_type', 'image_message')
|
||||
setattr(f, '_priority', 50)
|
||||
return f
|
||||
setattr(func, '_event_type', 'image_message')
|
||||
setattr(func, '_priority', min(max(priority, 0), 99))
|
||||
return func
|
||||
|
||||
return decorator if not callable(priority) else decorator(priority)
|
||||
|
||||
|
||||
def on_voice_message(priority=50):
|
||||
"""语音消息装饰器"""
|
||||
|
||||
def decorator(func):
|
||||
if callable(priority):
|
||||
f = priority
|
||||
setattr(f, '_event_type', 'voice_message')
|
||||
setattr(f, '_priority', 50)
|
||||
return f
|
||||
setattr(func, '_event_type', 'voice_message')
|
||||
setattr(func, '_priority', min(max(priority, 0), 99))
|
||||
return func
|
||||
|
||||
return decorator if not callable(priority) else decorator(priority)
|
||||
|
||||
|
||||
def on_emoji_message(priority=50):
|
||||
"""表情消息装饰器"""
|
||||
|
||||
def decorator(func):
|
||||
if callable(priority):
|
||||
f = priority
|
||||
setattr(f, '_event_type', 'emoji_message')
|
||||
setattr(f, '_priority', 50)
|
||||
return f
|
||||
setattr(func, '_event_type', 'emoji_message')
|
||||
setattr(func, '_priority', min(max(priority, 0), 99))
|
||||
return func
|
||||
|
||||
return decorator if not callable(priority) else decorator(priority)
|
||||
|
||||
|
||||
def on_file_message(priority=50):
|
||||
"""文件消息装饰器"""
|
||||
|
||||
def decorator(func):
|
||||
if callable(priority):
|
||||
f = priority
|
||||
setattr(f, '_event_type', 'file_message')
|
||||
setattr(f, '_priority', 50)
|
||||
return f
|
||||
setattr(func, '_event_type', 'file_message')
|
||||
setattr(func, '_priority', min(max(priority, 0), 99))
|
||||
return func
|
||||
|
||||
return decorator if not callable(priority) else decorator(priority)
|
||||
|
||||
|
||||
def on_quote_message(priority=50):
|
||||
"""引用消息装饰器"""
|
||||
|
||||
def decorator(func):
|
||||
if callable(priority):
|
||||
f = priority
|
||||
setattr(f, '_event_type', 'quote_message')
|
||||
setattr(f, '_priority', 50)
|
||||
return f
|
||||
setattr(func, '_event_type', 'quote_message')
|
||||
setattr(func, '_priority', min(max(priority, 0), 99))
|
||||
return func
|
||||
|
||||
return decorator if not callable(priority) else decorator(priority)
|
||||
|
||||
|
||||
def on_video_message(priority=50):
|
||||
"""视频消息装饰器"""
|
||||
|
||||
def decorator(func):
|
||||
if callable(priority):
|
||||
f = priority
|
||||
setattr(f, '_event_type', 'video_message')
|
||||
setattr(f, '_priority', 50)
|
||||
return f
|
||||
setattr(func, '_event_type', 'video_message')
|
||||
setattr(func, '_priority', min(max(priority, 0), 99))
|
||||
return func
|
||||
|
||||
return decorator if not callable(priority) else decorator(priority)
|
||||
|
||||
|
||||
def on_pat_message(priority=50):
|
||||
"""拍一拍消息装饰器"""
|
||||
|
||||
def decorator(func):
|
||||
if callable(priority):
|
||||
f = priority
|
||||
setattr(f, '_event_type', 'pat_message')
|
||||
setattr(f, '_priority', 50)
|
||||
return f
|
||||
setattr(func, '_event_type', 'pat_message')
|
||||
setattr(func, '_priority', min(max(priority, 0), 99))
|
||||
return func
|
||||
|
||||
return decorator if not callable(priority) else decorator(priority)
|
||||
|
||||
|
||||
def on_at_message(priority=50):
|
||||
"""被@消息装饰器"""
|
||||
|
||||
def decorator(func):
|
||||
if callable(priority):
|
||||
f = priority
|
||||
setattr(f, '_event_type', 'at_message')
|
||||
setattr(f, '_priority', 50)
|
||||
return f
|
||||
setattr(func, '_event_type', 'at_message')
|
||||
setattr(func, '_priority', min(max(priority, 0), 99))
|
||||
return func
|
||||
|
||||
return decorator if not callable(priority) else decorator(priority)
|
||||
|
||||
|
||||
def on_system_message(priority=50):
|
||||
"""其他消息装饰器"""
|
||||
|
||||
def decorator(func):
|
||||
if callable(priority):
|
||||
f = priority
|
||||
setattr(f, '_event_type', 'system_message')
|
||||
setattr(f, '_priority', 50)
|
||||
return f
|
||||
setattr(func, '_event_type', 'other_message')
|
||||
setattr(func, '_priority', min(max(priority, 0), 99))
|
||||
return func
|
||||
|
||||
return decorator if not callable(priority) else decorator(priority)
|
||||
|
||||
|
||||
def on_other_message(priority=50):
|
||||
"""其他消息装饰器"""
|
||||
|
||||
def decorator(func):
|
||||
if callable(priority):
|
||||
f = priority
|
||||
setattr(f, '_event_type', 'other_message')
|
||||
setattr(f, '_priority', 50)
|
||||
return f
|
||||
setattr(func, '_event_type', 'other_message')
|
||||
setattr(func, '_priority', min(max(priority, 0), 99))
|
||||
return func
|
||||
|
||||
return decorator if not callable(priority) else decorator(priority)
|
||||
71
utils/event_manager.py
Normal file
71
utils/event_manager.py
Normal file
@@ -0,0 +1,71 @@
|
||||
import copy
|
||||
from typing import Callable, Dict, List
|
||||
|
||||
|
||||
class EventManager:
|
||||
_handlers: Dict[str, List[tuple[Callable, object, int]]] = {}
|
||||
|
||||
@classmethod
|
||||
def bind_instance(cls, instance: object):
|
||||
"""将实例绑定到对应的事件处理函数"""
|
||||
from loguru import logger
|
||||
registered_count = 0
|
||||
for method_name in dir(instance):
|
||||
method = getattr(instance, method_name)
|
||||
if hasattr(method, '_event_type'):
|
||||
event_type = getattr(method, '_event_type')
|
||||
priority = getattr(method, '_priority', 50)
|
||||
|
||||
if event_type not in cls._handlers:
|
||||
cls._handlers[event_type] = []
|
||||
cls._handlers[event_type].append((method, instance, priority))
|
||||
# 按优先级排序,优先级高的在前
|
||||
cls._handlers[event_type].sort(key=lambda x: x[2], reverse=True)
|
||||
registered_count += 1
|
||||
logger.debug(f"[EventManager] 注册事件处理器: {instance.__class__.__name__}.{method_name} -> {event_type} (优先级={priority})")
|
||||
|
||||
if registered_count > 0:
|
||||
logger.success(f"[EventManager] {instance.__class__.__name__} 注册了 {registered_count} 个事件处理器")
|
||||
|
||||
@classmethod
|
||||
async def emit(cls, event_type: str, *args, **kwargs) -> None:
|
||||
"""触发事件"""
|
||||
from loguru import logger
|
||||
|
||||
if event_type not in cls._handlers:
|
||||
logger.debug(f"[EventManager] 事件 {event_type} 没有注册的处理器")
|
||||
return
|
||||
|
||||
logger.debug(f"[EventManager] 触发事件: {event_type}, 处理器数量: {len(cls._handlers[event_type])}")
|
||||
|
||||
api_client, message = args
|
||||
for handler, instance, priority in cls._handlers[event_type]:
|
||||
try:
|
||||
logger.debug(f"[EventManager] 调用处理器: {instance.__class__.__name__}.{handler.__name__}")
|
||||
# 不再深拷贝message,让所有处理器共享同一个消息对象
|
||||
# 这样AutoReply设置的标记可以传递给AIChat
|
||||
handler_args = (api_client, message)
|
||||
new_kwargs = kwargs # kwargs也不需要深拷贝
|
||||
|
||||
result = await handler(*handler_args, **new_kwargs)
|
||||
|
||||
if isinstance(result, bool):
|
||||
# True 继续执行 False 停止执行
|
||||
if not result:
|
||||
break
|
||||
else:
|
||||
continue # 我也不知道你返回了个啥玩意,反正继续执行就是了
|
||||
except Exception as e:
|
||||
import traceback
|
||||
logger.error(f"处理器 {handler.__name__} 执行失败: {e}")
|
||||
logger.error(f"详细错误: {traceback.format_exc()}")
|
||||
|
||||
@classmethod
|
||||
def unbind_instance(cls, instance: object):
|
||||
"""解绑实例的所有事件处理函数"""
|
||||
for event_type in cls._handlers:
|
||||
cls._handlers[event_type] = [
|
||||
(handler, inst, priority)
|
||||
for handler, inst, priority in cls._handlers[event_type]
|
||||
if inst is not instance
|
||||
]
|
||||
198
utils/hookbot.py
Normal file
198
utils/hookbot.py
Normal file
@@ -0,0 +1,198 @@
|
||||
"""
|
||||
HookBot - 机器人核心类
|
||||
|
||||
处理消息路由和事件分发
|
||||
"""
|
||||
|
||||
import tomllib
|
||||
from typing import Dict, Any
|
||||
from loguru import logger
|
||||
|
||||
from WechatHook import WechatHookClient, MESSAGE_TYPE_MAP, normalize_message
|
||||
from utils.event_manager import EventManager
|
||||
|
||||
|
||||
class HookBot:
|
||||
"""
|
||||
HookBot 核心类
|
||||
|
||||
负责消息处理、路由和事件分发
|
||||
"""
|
||||
|
||||
def __init__(self, client: WechatHookClient):
|
||||
"""
|
||||
初始化 HookBot
|
||||
|
||||
Args:
|
||||
client: WechatHookClient 实例
|
||||
"""
|
||||
self.client = client
|
||||
self.wxid = None
|
||||
self.nickname = None
|
||||
|
||||
# 读取配置
|
||||
with open("main_config.toml", "rb") as f:
|
||||
main_config = tomllib.load(f)
|
||||
|
||||
bot_config = main_config.get("Bot", {})
|
||||
preset_wxid = bot_config.get("wxid") or bot_config.get("bot_wxid")
|
||||
preset_nickname = bot_config.get("nickname") or bot_config.get("bot_nickname")
|
||||
|
||||
if preset_wxid:
|
||||
self.wxid = preset_wxid
|
||||
logger.info(f"使用配置中的机器人 wxid: {self.wxid}")
|
||||
if preset_nickname:
|
||||
self.nickname = preset_nickname
|
||||
logger.info(f"使用配置中的机器人昵称: {self.nickname}")
|
||||
self.ignore_mode = bot_config.get("ignore-mode", "None")
|
||||
self.whitelist = bot_config.get("whitelist", [])
|
||||
self.blacklist = bot_config.get("blacklist", [])
|
||||
|
||||
# 性能配置
|
||||
perf_config = main_config.get("Performance", {})
|
||||
self.log_sampling_rate = perf_config.get("log_sampling_rate", 1.0)
|
||||
|
||||
# 消息计数和统计
|
||||
self.message_count = 0
|
||||
self.filtered_count = 0
|
||||
self.processed_count = 0
|
||||
|
||||
logger.info("HookBot 初始化完成")
|
||||
|
||||
def update_profile(self, wxid: str, nickname: str):
|
||||
"""
|
||||
更新机器人信息
|
||||
|
||||
Args:
|
||||
wxid: 机器人 wxid
|
||||
nickname: 机器人昵称
|
||||
"""
|
||||
self.wxid = wxid
|
||||
self.nickname = nickname
|
||||
logger.info(f"机器人信息: wxid={wxid}, nickname={nickname}")
|
||||
|
||||
async def process_message(self, msg_type: int, data: dict):
|
||||
"""
|
||||
处理接收到的消息
|
||||
|
||||
Args:
|
||||
msg_type: 消息类型
|
||||
data: 消息数据
|
||||
"""
|
||||
# 过滤 API 响应消息
|
||||
if msg_type in [11174, 11230]:
|
||||
return
|
||||
|
||||
# 消息计数
|
||||
self.message_count += 1
|
||||
|
||||
# 日志采样 - 只记录部分消息以减少日志量
|
||||
should_log = self._should_log_message(msg_type)
|
||||
|
||||
if should_log:
|
||||
logger.debug(f"处理消息: type={msg_type}")
|
||||
|
||||
# 重要事件始终记录
|
||||
if msg_type in [11098, 11099, 11058]: # 群成员变动、系统消息
|
||||
logger.info(f"重要事件: type={msg_type}")
|
||||
|
||||
# 获取事件类型
|
||||
event_type = MESSAGE_TYPE_MAP.get(msg_type)
|
||||
|
||||
if should_log and event_type:
|
||||
logger.info(f"[HookBot] 消息类型映射: {msg_type} -> {event_type}")
|
||||
|
||||
if not event_type:
|
||||
# 记录未知消息类型的详细信息,帮助调试
|
||||
content_preview = str(data.get('raw_msg', data.get('msg', '')))[:200]
|
||||
logger.warning(f"未映射的消息类型: {msg_type}, wx_type: {data.get('wx_type')}, 内容预览: {content_preview}")
|
||||
return
|
||||
|
||||
# 格式转换
|
||||
try:
|
||||
message = normalize_message(msg_type, data)
|
||||
except Exception as e:
|
||||
logger.error(f"格式转换失败: {e}")
|
||||
return
|
||||
|
||||
# 过滤消息
|
||||
if not self._check_filter(message):
|
||||
self.filtered_count += 1
|
||||
if should_log:
|
||||
logger.debug(f"消息被过滤: {message.get('FromWxid')}")
|
||||
return
|
||||
|
||||
self.processed_count += 1
|
||||
|
||||
# 采样记录处理的消息
|
||||
if should_log:
|
||||
content = message.get('Content', '')
|
||||
if len(content) > 50:
|
||||
content = content[:50] + "..."
|
||||
logger.info(f"处理消息: type={event_type}, from={message.get('FromWxid')}, content={content}")
|
||||
|
||||
# 触发事件
|
||||
try:
|
||||
await EventManager.emit(event_type, self.client, message)
|
||||
except Exception as e:
|
||||
logger.error(f"事件处理失败: {e}")
|
||||
|
||||
def _should_log_message(self, msg_type: int) -> bool:
|
||||
"""判断是否应该记录此消息的日志"""
|
||||
# 重要消息类型始终记录
|
||||
important_types = {
|
||||
11058, 11098, 11099, 11025, # 系统消息、群成员变动、登录信息
|
||||
11051, 11047, 11052, 11055 # 视频、图片、表情、文件消息
|
||||
}
|
||||
if msg_type in important_types:
|
||||
return True
|
||||
|
||||
# 其他消息按采样率记录
|
||||
import random
|
||||
return random.random() < self.log_sampling_rate
|
||||
|
||||
def _check_filter(self, message: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
检查消息是否通过过滤
|
||||
|
||||
Args:
|
||||
message: 消息字典
|
||||
|
||||
Returns:
|
||||
是否通过过滤
|
||||
"""
|
||||
from_wxid = message.get("FromWxid", "")
|
||||
sender_wxid = message.get("SenderWxid", "")
|
||||
msg_type = message.get("MsgType", 0)
|
||||
|
||||
# 系统消息(type=11058)不过滤,因为包含重要的群聊事件信息
|
||||
if msg_type == 11058:
|
||||
return True
|
||||
|
||||
# 过滤机器人自己发送的消息,避免无限循环
|
||||
if self.wxid and (from_wxid == self.wxid or sender_wxid == self.wxid):
|
||||
return False
|
||||
|
||||
# None 模式:处理所有消息
|
||||
if self.ignore_mode == "None":
|
||||
return True
|
||||
|
||||
# Whitelist 模式:仅处理白名单
|
||||
if self.ignore_mode == "Whitelist":
|
||||
return from_wxid in self.whitelist or sender_wxid in self.whitelist
|
||||
|
||||
# Blacklist 模式:屏蔽黑名单
|
||||
if self.ignore_mode == "Blacklist":
|
||||
return from_wxid not in self.blacklist and sender_wxid not in self.blacklist
|
||||
|
||||
return True
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
"""获取消息处理统计信息"""
|
||||
return {
|
||||
"total_messages": self.message_count,
|
||||
"filtered_messages": self.filtered_count,
|
||||
"processed_messages": self.processed_count,
|
||||
"filter_rate": self.filtered_count / max(self.message_count, 1),
|
||||
"process_rate": self.processed_count / max(self.message_count, 1)
|
||||
}
|
||||
88
utils/message_hook.py
Normal file
88
utils/message_hook.py
Normal file
@@ -0,0 +1,88 @@
|
||||
"""
|
||||
消息发送钩子工具
|
||||
|
||||
用于自动记录机器人发送的消息到 MessageLogger
|
||||
"""
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
async def log_bot_message(to_wxid: str, content: str, msg_type: str = "text", media_url: str = ""):
|
||||
"""
|
||||
记录机器人发送的消息到 MessageLogger
|
||||
|
||||
Args:
|
||||
to_wxid: 接收者微信ID
|
||||
content: 消息内容
|
||||
msg_type: 消息类型 (text/image/video/file等)
|
||||
media_url: 媒体文件URL (可选)
|
||||
"""
|
||||
try:
|
||||
logger.info(f"message_hook: 开始记录机器人消息")
|
||||
|
||||
# 动态导入避免循环依赖
|
||||
from plugins.MessageLogger.main import MessageLogger
|
||||
logger.info(f"message_hook: MessageLogger 导入成功")
|
||||
|
||||
# 获取 MessageLogger 实例
|
||||
message_logger = MessageLogger.get_instance()
|
||||
logger.info(f"message_hook: MessageLogger 实例: {message_logger}")
|
||||
|
||||
if message_logger:
|
||||
logger.info(f"message_hook: 调用 save_bot_message")
|
||||
await message_logger.save_bot_message(to_wxid, content, msg_type, media_url)
|
||||
logger.info(f"message_hook: save_bot_message 调用完成")
|
||||
else:
|
||||
logger.warning("MessageLogger 实例未找到,跳过消息记录")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"记录机器人消息失败: {e}")
|
||||
import traceback
|
||||
logger.error(f"详细错误: {traceback.format_exc()}")
|
||||
|
||||
|
||||
def create_message_hook(original_method):
|
||||
"""
|
||||
创建消息发送钩子装饰器
|
||||
|
||||
Args:
|
||||
original_method: 原始的发送消息方法
|
||||
|
||||
Returns:
|
||||
包装后的方法
|
||||
"""
|
||||
async def wrapper(self, to_wxid: str, content: str, *args, **kwargs):
|
||||
# 调用原始方法
|
||||
result = await original_method(self, to_wxid, content, *args, **kwargs)
|
||||
|
||||
# 记录消息
|
||||
await log_bot_message(to_wxid, content, "text")
|
||||
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def create_file_message_hook(original_method, msg_type: str):
|
||||
"""
|
||||
创建文件消息发送钩子装饰器
|
||||
|
||||
Args:
|
||||
original_method: 原始的发送文件方法
|
||||
msg_type: 消息类型
|
||||
|
||||
Returns:
|
||||
包装后的方法
|
||||
"""
|
||||
async def wrapper(self, to_wxid: str, file_path: str, *args, **kwargs):
|
||||
# 调用原始方法
|
||||
result = await original_method(self, to_wxid, file_path, *args, **kwargs)
|
||||
|
||||
# 记录消息
|
||||
import os
|
||||
filename = os.path.basename(file_path)
|
||||
await log_bot_message(to_wxid, f"[{msg_type}] {filename}", msg_type, file_path)
|
||||
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
47
utils/plugin_base.py
Normal file
47
utils/plugin_base.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from abc import ABC
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from .decorators import scheduler, add_job_safe, remove_job_safe
|
||||
|
||||
|
||||
class PluginBase(ABC):
|
||||
"""插件基类"""
|
||||
|
||||
# 插件元数据
|
||||
description: str = "暂无描述"
|
||||
author: str = "未知"
|
||||
version: str = "1.0.0"
|
||||
|
||||
def __init__(self):
|
||||
self.enabled = False
|
||||
self._scheduled_jobs = set()
|
||||
|
||||
async def on_enable(self, bot=None):
|
||||
"""插件启用时调用"""
|
||||
|
||||
# 定时任务
|
||||
for method_name in dir(self):
|
||||
method = getattr(self, method_name)
|
||||
if hasattr(method, '_is_scheduled'):
|
||||
job_id = getattr(method, '_job_id')
|
||||
trigger = getattr(method, '_schedule_trigger')
|
||||
trigger_args = getattr(method, '_schedule_args')
|
||||
|
||||
add_job_safe(scheduler, job_id, method, bot, trigger, **trigger_args)
|
||||
self._scheduled_jobs.add(job_id)
|
||||
if self._scheduled_jobs:
|
||||
logger.success("插件 {} 已加载定时任务: {}", self.__class__.__name__, self._scheduled_jobs)
|
||||
|
||||
async def on_disable(self):
|
||||
"""插件禁用时调用"""
|
||||
|
||||
# 移除定时任务
|
||||
for job_id in self._scheduled_jobs:
|
||||
remove_job_safe(scheduler, job_id)
|
||||
logger.info("已卸载定时任务: {}", self._scheduled_jobs)
|
||||
self._scheduled_jobs.clear()
|
||||
|
||||
async def async_init(self):
|
||||
"""插件异步初始化"""
|
||||
return
|
||||
278
utils/plugin_manager.py
Normal file
278
utils/plugin_manager.py
Normal file
@@ -0,0 +1,278 @@
|
||||
import importlib
|
||||
import inspect
|
||||
import os
|
||||
import sys
|
||||
import tomllib
|
||||
import traceback
|
||||
from typing import Dict, Type, List, Union
|
||||
|
||||
from loguru import logger
|
||||
|
||||
# from WechatAPI import WechatAPIClient # 注释掉,WechatHookBot 不需要这个导入
|
||||
from utils.singleton import Singleton
|
||||
from .event_manager import EventManager
|
||||
from .plugin_base import PluginBase
|
||||
|
||||
|
||||
class PluginManager(metaclass=Singleton):
|
||||
def __init__(self):
|
||||
self.plugins: Dict[str, PluginBase] = {}
|
||||
self.plugin_classes: Dict[str, Type[PluginBase]] = {}
|
||||
self.plugin_info: Dict[str, dict] = {} # 新增:存储所有插件信息
|
||||
|
||||
self.bot = None
|
||||
|
||||
with open("main_config.toml", "rb") as f:
|
||||
main_config = tomllib.load(f)
|
||||
|
||||
self.excluded_plugins = main_config.get("Bot", {}).get("disabled-plugins", [])
|
||||
|
||||
def set_bot(self, bot):
|
||||
"""设置 bot 客户端(WechatHookClient)"""
|
||||
self.bot = bot
|
||||
|
||||
async def load_plugin(self, plugin: Union[Type[PluginBase], str]) -> bool:
|
||||
if isinstance(plugin, str):
|
||||
return await self._load_plugin_name(plugin)
|
||||
elif isinstance(plugin, type) and issubclass(plugin, PluginBase):
|
||||
return await self._load_plugin_class(plugin)
|
||||
|
||||
async def _load_plugin_class(self, plugin_class: Type[PluginBase],
|
||||
is_disabled: bool = False) -> bool:
|
||||
"""加载单个插件,接受Type[PluginBase]"""
|
||||
try:
|
||||
plugin_name = plugin_class.__name__
|
||||
|
||||
# 防止重复加载插件
|
||||
if plugin_name in self.plugins:
|
||||
return False
|
||||
|
||||
# 安全获取插件目录名
|
||||
directory = "unknown"
|
||||
try:
|
||||
module_name = plugin_class.__module__
|
||||
if module_name.startswith("plugins."):
|
||||
directory = module_name.split('.')[1]
|
||||
else:
|
||||
logger.warning(f"非常规插件模块路径: {module_name}")
|
||||
except Exception as e:
|
||||
logger.error(f"获取插件目录失败: {e}")
|
||||
directory = "error"
|
||||
|
||||
# 记录插件信息,即使插件被禁用也会记录
|
||||
self.plugin_info[plugin_name] = {
|
||||
"name": plugin_name,
|
||||
"description": plugin_class.description,
|
||||
"author": plugin_class.author,
|
||||
"version": plugin_class.version,
|
||||
"directory": directory,
|
||||
"enabled": False,
|
||||
"class": plugin_class
|
||||
}
|
||||
|
||||
# 如果插件被禁用则不加载
|
||||
if is_disabled:
|
||||
return False
|
||||
|
||||
plugin = plugin_class()
|
||||
EventManager.bind_instance(plugin)
|
||||
await plugin.on_enable(self.bot)
|
||||
await plugin.async_init()
|
||||
self.plugins[plugin_name] = plugin
|
||||
self.plugin_classes[plugin_name] = plugin_class
|
||||
self.plugin_info[plugin_name]["enabled"] = True
|
||||
logger.success(f"加载插件 {plugin_name} 成功")
|
||||
return True
|
||||
except:
|
||||
logger.error(f"加载插件时发生错误: {traceback.format_exc()}")
|
||||
return False
|
||||
|
||||
async def _load_plugin_name(self, plugin_name: str) -> bool:
|
||||
"""从plugins目录加载单个插件
|
||||
|
||||
Args:
|
||||
plugin_name: 插件类名称(不是文件名)
|
||||
|
||||
Returns:
|
||||
bool: 是否成功加载插件
|
||||
"""
|
||||
found = False
|
||||
for dirname in os.listdir("plugins"):
|
||||
try:
|
||||
if os.path.isdir(f"plugins/{dirname}") and os.path.exists(f"plugins/{dirname}/main.py"):
|
||||
module = importlib.import_module(f"plugins.{dirname}.main")
|
||||
importlib.reload(module)
|
||||
|
||||
for name, obj in inspect.getmembers(module):
|
||||
if (inspect.isclass(obj) and
|
||||
issubclass(obj, PluginBase) and
|
||||
obj != PluginBase and
|
||||
obj.__name__ == plugin_name):
|
||||
found = True
|
||||
return await self._load_plugin_class(obj)
|
||||
except:
|
||||
logger.error(f"检查 {dirname} 时发生错误: {traceback.format_exc()}")
|
||||
continue
|
||||
|
||||
if not found:
|
||||
logger.warning(f"未找到插件类 {plugin_name}")
|
||||
|
||||
async def load_plugins(self, load_disabled: bool = True) -> Union[List[str], bool]:
|
||||
loaded_plugins = []
|
||||
|
||||
for dirname in os.listdir("plugins"):
|
||||
if os.path.isdir(f"plugins/{dirname}") and os.path.exists(f"plugins/{dirname}/main.py"):
|
||||
try:
|
||||
module = importlib.import_module(f"plugins.{dirname}.main")
|
||||
for name, obj in inspect.getmembers(module):
|
||||
if inspect.isclass(obj) and issubclass(obj, PluginBase) and obj != PluginBase:
|
||||
is_disabled = False
|
||||
if not load_disabled:
|
||||
is_disabled = obj.__name__ in self.excluded_plugins or dirname in self.excluded_plugins
|
||||
|
||||
if await self._load_plugin_class(obj, is_disabled=is_disabled):
|
||||
loaded_plugins.append(obj.__name__)
|
||||
except:
|
||||
logger.error(f"加载 {dirname} 时发生错误: {traceback.format_exc()}")
|
||||
|
||||
return loaded_plugins
|
||||
|
||||
async def unload_plugin(self, plugin_name: str) -> bool:
|
||||
"""卸载单个插件"""
|
||||
if plugin_name not in self.plugins:
|
||||
return False
|
||||
|
||||
# 防止卸载 ManagePlugin
|
||||
if plugin_name == "ManagePlugin":
|
||||
logger.warning("ManagePlugin 不能被卸载")
|
||||
return False
|
||||
|
||||
try:
|
||||
plugin = self.plugins[plugin_name]
|
||||
await plugin.on_disable()
|
||||
EventManager.unbind_instance(plugin)
|
||||
del self.plugins[plugin_name]
|
||||
del self.plugin_classes[plugin_name]
|
||||
if plugin_name in self.plugin_info.keys():
|
||||
self.plugin_info[plugin_name]["enabled"] = False
|
||||
logger.success(f"卸载插件 {plugin_name} 成功")
|
||||
return True
|
||||
except:
|
||||
logger.error(f"卸载插件 {plugin_name} 时发生错误: {traceback.format_exc()}")
|
||||
return False
|
||||
|
||||
async def unload_plugins(self) -> tuple[List[str], List[str]]:
|
||||
"""卸载所有插件"""
|
||||
unloaded_plugins = []
|
||||
failed_unloads = []
|
||||
for plugin_name in list(self.plugins.keys()):
|
||||
if await self.unload_plugin(plugin_name):
|
||||
unloaded_plugins.append(plugin_name)
|
||||
else:
|
||||
failed_unloads.append(plugin_name)
|
||||
return unloaded_plugins, failed_unloads
|
||||
|
||||
async def reload_plugin(self, plugin_name: str) -> bool:
|
||||
"""重载单个插件"""
|
||||
if plugin_name not in self.plugin_classes:
|
||||
return False
|
||||
|
||||
# 防止重载 ManagePlugin
|
||||
if plugin_name == "ManagePlugin":
|
||||
logger.warning("ManagePlugin 不能被重载")
|
||||
return False
|
||||
|
||||
try:
|
||||
# 获取插件类所在的模块
|
||||
plugin_class = self.plugin_classes[plugin_name]
|
||||
module_name = plugin_class.__module__
|
||||
|
||||
# 先卸载插件
|
||||
if not await self.unload_plugin(plugin_name):
|
||||
return False
|
||||
|
||||
# 重新导入模块
|
||||
module = importlib.import_module(module_name)
|
||||
importlib.reload(module)
|
||||
|
||||
# 从重新加载的模块中获取插件类
|
||||
for name, obj in inspect.getmembers(module):
|
||||
if (inspect.isclass(obj) and
|
||||
issubclass(obj, PluginBase) and
|
||||
obj != PluginBase and
|
||||
obj.__name__ == plugin_name):
|
||||
# 使用新的插件类而不是旧的
|
||||
return await self.load_plugin(obj)
|
||||
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"重载插件 {plugin_name} 时发生错误: {e}")
|
||||
return False
|
||||
|
||||
async def reload_plugins(self) -> List[str]:
|
||||
"""重载所有插件
|
||||
|
||||
Returns:
|
||||
List[str]: 成功重载的插件名称列表
|
||||
"""
|
||||
try:
|
||||
# 记录当前加载的插件名称,排除 ManagePlugin
|
||||
original_plugins = [name for name in self.plugins.keys() if name != "ManagePlugin"]
|
||||
|
||||
# 卸载除 ManagePlugin 外的所有插件
|
||||
for plugin_name in original_plugins:
|
||||
await self.unload_plugin(plugin_name)
|
||||
|
||||
# 重新加载所有模块
|
||||
for module_name in list(sys.modules.keys()):
|
||||
if module_name.startswith('plugins.') and not module_name.endswith('ManagePlugin'):
|
||||
del sys.modules[module_name]
|
||||
|
||||
# 从目录重新加载插件
|
||||
return await self.load_plugins()
|
||||
|
||||
except:
|
||||
logger.error(f"重载所有插件时发生错误: {traceback.format_exc()}")
|
||||
return []
|
||||
|
||||
async def refresh_plugins(self):
|
||||
for dirname in os.listdir("plugins"):
|
||||
try:
|
||||
dirpath = f"plugins/{dirname}"
|
||||
if os.path.isdir(dirpath) and os.path.exists(f"{dirpath}/main.py"):
|
||||
# 验证目录名合法性
|
||||
if not dirname.isidentifier():
|
||||
logger.warning(f"跳过非法插件目录名: {dirname}")
|
||||
continue
|
||||
|
||||
module = importlib.import_module(f"plugins.{dirname}.main")
|
||||
importlib.reload(module)
|
||||
|
||||
for name, obj in inspect.getmembers(module):
|
||||
if inspect.isclass(obj) and issubclass(obj, PluginBase) and obj != PluginBase:
|
||||
if obj.__name__ not in self.plugin_info.keys():
|
||||
self.plugin_info[obj.__name__] = {
|
||||
"name": obj.__name__,
|
||||
"description": obj.description,
|
||||
"author": obj.author,
|
||||
"version": obj.version,
|
||||
"directory": dirname,
|
||||
"enabled": False,
|
||||
"class": obj
|
||||
}
|
||||
except:
|
||||
logger.error(f"检查 {dirname} 时发生错误: {traceback.format_exc()}")
|
||||
continue
|
||||
|
||||
def get_plugin_info(self, plugin_name: str = None) -> Union[dict, List[dict]]:
|
||||
"""获取插件信息
|
||||
|
||||
Args:
|
||||
plugin_name: 插件名称,如果为None则返回所有插件信息
|
||||
|
||||
Returns:
|
||||
如果指定插件名,返回单个插件信息字典;否则返回所有插件信息列表
|
||||
"""
|
||||
if plugin_name:
|
||||
return self.plugin_info.get(plugin_name)
|
||||
return list(self.plugin_info.values())
|
||||
18
utils/singleton.py
Normal file
18
utils/singleton.py
Normal file
@@ -0,0 +1,18 @@
|
||||
class Singleton(type):
|
||||
_instances = {}
|
||||
|
||||
def __call__(cls, *args, **kwargs):
|
||||
if cls not in cls._instances:
|
||||
cls._instances[cls] = super().__call__(*args, **kwargs)
|
||||
return cls._instances[cls]
|
||||
|
||||
@classmethod
|
||||
def reset_instance(mcs, cls):
|
||||
"""重置指定类的单例实例"""
|
||||
if cls in mcs._instances:
|
||||
del mcs._instances[cls]
|
||||
|
||||
@classmethod
|
||||
def reset_all(mcs):
|
||||
"""重置所有单例实例"""
|
||||
mcs._instances.clear()
|
||||
Reference in New Issue
Block a user