feat:初版
This commit is contained in:
443
bot.py
Normal file
443
bot.py
Normal file
@@ -0,0 +1,443 @@
|
||||
"""
|
||||
WechatHookBot - 主入口
|
||||
|
||||
基于个微大客户版 Hook API 的微信机器人框架
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
import tomllib
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
|
||||
from WechatHook import NoveLoader, WechatHookClient
|
||||
from WechatHook.callbacks import (
|
||||
add_callback_handler,
|
||||
wechat_connect_callback,
|
||||
wechat_recv_callback,
|
||||
wechat_close_callback,
|
||||
CONNECT_CALLBACK,
|
||||
RECV_CALLBACK,
|
||||
CLOSE_CALLBACK
|
||||
)
|
||||
from utils.hookbot import HookBot
|
||||
from utils.plugin_manager import PluginManager
|
||||
from utils.decorators import scheduler
|
||||
# from database import KeyvalDB, MessageDB # 不需要数据库
|
||||
|
||||
|
||||
class BotService:
|
||||
"""机器人服务类"""
|
||||
|
||||
def __init__(self):
|
||||
self.loader = None
|
||||
self.client = None
|
||||
self.hookbot = None
|
||||
self.plugin_manager = None
|
||||
self.process_id = None # 微信进程 ID
|
||||
self.socket_client_id = None # Socket 客户端 ID
|
||||
self.is_running = False
|
||||
self.event_loop = None # 事件循环引用
|
||||
|
||||
# 消息队列和性能控制
|
||||
self.message_queue = None
|
||||
self.queue_config = {}
|
||||
self.concurrency_config = {}
|
||||
self.consumer_tasks = []
|
||||
self.processing_semaphore = None
|
||||
self.circuit_breaker_failures = 0
|
||||
self.circuit_breaker_open = False
|
||||
self.circuit_breaker_last_failure = 0
|
||||
|
||||
@CONNECT_CALLBACK(in_class=True)
|
||||
def on_connect(self, client_id):
|
||||
"""连接回调"""
|
||||
logger.success(f"微信客户端已连接: {client_id}")
|
||||
self.socket_client_id = client_id
|
||||
|
||||
@RECV_CALLBACK(in_class=True)
|
||||
def on_receive(self, client_id, msg_type, data):
|
||||
"""接收消息回调"""
|
||||
# 减少日志输出,只记录关键消息类型
|
||||
if msg_type == 11025: # 登录信息
|
||||
logger.success(f"获取到登录信息: wxid={data.get('wxid', 'unknown')}, nickname={data.get('nickname', 'unknown')}")
|
||||
if self.hookbot:
|
||||
self.hookbot.update_profile(data.get('wxid', 'unknown'), data.get('nickname', 'unknown'))
|
||||
|
||||
# 初始化 CDN(必须在登录后执行,才能使用协议 API)
|
||||
if self.client and self.event_loop:
|
||||
logger.info("正在初始化 CDN...")
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.client.cdn_init(),
|
||||
self.event_loop
|
||||
)
|
||||
return
|
||||
|
||||
# 使用消息队列处理其他消息
|
||||
if self.message_queue and self.event_loop:
|
||||
try:
|
||||
# 快速入队,不阻塞回调
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._enqueue_message(msg_type, data),
|
||||
self.event_loop
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"消息入队失败: {e}")
|
||||
|
||||
async def _enqueue_message(self, msg_type, data):
|
||||
"""将消息加入队列"""
|
||||
try:
|
||||
# 检查队列是否已满
|
||||
if self.message_queue.qsize() >= self.queue_config.get("max_size", 1000):
|
||||
overflow_strategy = self.queue_config.get("overflow_strategy", "drop_oldest")
|
||||
|
||||
if overflow_strategy == "drop_oldest":
|
||||
# 丢弃最旧的消息
|
||||
try:
|
||||
self.message_queue.get_nowait()
|
||||
logger.warning("队列已满,丢弃最旧消息")
|
||||
except asyncio.QueueEmpty:
|
||||
pass
|
||||
elif overflow_strategy == "sampling":
|
||||
# 采样处理,随机丢弃
|
||||
import random
|
||||
if random.random() < 0.5: # 50% 概率丢弃
|
||||
logger.debug("队列压力大,采样丢弃消息")
|
||||
return
|
||||
else: # degrade
|
||||
logger.warning("队列已满,降级处理")
|
||||
return
|
||||
|
||||
# 将消息放入队列
|
||||
await self.message_queue.put((msg_type, data))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"消息入队异常: {e}")
|
||||
|
||||
async def _message_consumer(self, consumer_id: int):
|
||||
"""消息消费者协程"""
|
||||
logger.info(f"消息消费者 {consumer_id} 已启动")
|
||||
|
||||
while self.is_running:
|
||||
try:
|
||||
# 从队列获取消息,设置超时避免无限等待
|
||||
msg_type, data = await asyncio.wait_for(
|
||||
self.message_queue.get(),
|
||||
timeout=1.0
|
||||
)
|
||||
|
||||
# 检查熔断器状态
|
||||
if self._check_circuit_breaker():
|
||||
logger.debug("熔断器开启,跳过消息处理")
|
||||
continue
|
||||
|
||||
# 创建并发任务,不等待完成
|
||||
timeout = self.concurrency_config.get("plugin_task_timeout_seconds", 5)
|
||||
|
||||
# 使用信号量控制并发数量
|
||||
async def process_with_semaphore():
|
||||
async with self.processing_semaphore:
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self.hookbot.process_message(msg_type, data),
|
||||
timeout=timeout
|
||||
)
|
||||
self._reset_circuit_breaker()
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"消息处理超时 (>{timeout}s): type={msg_type}")
|
||||
self._record_circuit_breaker_failure()
|
||||
except Exception as e:
|
||||
logger.error(f"消息处理异常: {e}")
|
||||
self._record_circuit_breaker_failure()
|
||||
|
||||
# 创建任务但不等待,实现真正并发
|
||||
asyncio.create_task(process_with_semaphore())
|
||||
|
||||
# 标记任务完成
|
||||
self.message_queue.task_done()
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# 队列为空,继续等待
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"消费者 {consumer_id} 异常: {e}")
|
||||
await asyncio.sleep(0.1) # 短暂休息避免忙等
|
||||
|
||||
def _check_circuit_breaker(self) -> bool:
|
||||
"""检查熔断器状态"""
|
||||
if not self.concurrency_config.get("enable_circuit_breaker", True):
|
||||
return False
|
||||
|
||||
if self.circuit_breaker_open:
|
||||
# 检查是否可以尝试恢复
|
||||
import time
|
||||
if time.time() - self.circuit_breaker_last_failure > 30: # 30秒后尝试恢复
|
||||
self.circuit_breaker_open = False
|
||||
self.circuit_breaker_failures = 0
|
||||
logger.info("熔断器尝试恢复")
|
||||
return False
|
||||
return True
|
||||
return False
|
||||
|
||||
def _record_circuit_breaker_failure(self):
|
||||
"""记录熔断器失败"""
|
||||
if not self.concurrency_config.get("enable_circuit_breaker", True):
|
||||
return
|
||||
|
||||
self.circuit_breaker_failures += 1
|
||||
threshold = self.concurrency_config.get("circuit_breaker_threshold", 5)
|
||||
|
||||
if self.circuit_breaker_failures >= threshold:
|
||||
import time
|
||||
self.circuit_breaker_open = True
|
||||
self.circuit_breaker_last_failure = time.time()
|
||||
logger.warning(f"熔断器开启,连续失败 {self.circuit_breaker_failures} 次")
|
||||
|
||||
def _reset_circuit_breaker(self):
|
||||
"""重置熔断器"""
|
||||
if self.circuit_breaker_failures > 0:
|
||||
self.circuit_breaker_failures = 0
|
||||
|
||||
@CLOSE_CALLBACK(in_class=True)
|
||||
def on_close(self, client_id):
|
||||
"""断开连接回调"""
|
||||
logger.warning(f"微信客户端已断开: {client_id}")
|
||||
|
||||
async def _wait_for_socket(self, timeout_seconds: int = 15) -> bool:
|
||||
"""等待 socket 客户端连接"""
|
||||
elapsed = 0
|
||||
while elapsed < timeout_seconds:
|
||||
if self.socket_client_id:
|
||||
return True
|
||||
await asyncio.sleep(1)
|
||||
elapsed += 1
|
||||
logger.info(f"等待微信客户端连接中... ({elapsed}/{timeout_seconds}s)")
|
||||
return False
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化系统"""
|
||||
logger.info("=" * 60)
|
||||
logger.info("WechatHookBot 启动中...")
|
||||
logger.info("=" * 60)
|
||||
|
||||
# 保存事件循环引用
|
||||
self.event_loop = asyncio.get_event_loop()
|
||||
|
||||
# 读取配置
|
||||
config_path = Path("main_config.toml")
|
||||
if not config_path.exists():
|
||||
logger.error("配置文件不存在: main_config.toml")
|
||||
return False
|
||||
|
||||
with open(config_path, "rb") as f:
|
||||
config = tomllib.load(f)
|
||||
|
||||
# 初始化性能配置
|
||||
self.queue_config = config.get("Queue", {})
|
||||
self.concurrency_config = config.get("Concurrency", {})
|
||||
|
||||
# 创建消息队列
|
||||
queue_size = self.queue_config.get("max_size", 1000)
|
||||
self.message_queue = asyncio.Queue(maxsize=queue_size)
|
||||
logger.info(f"消息队列已创建,容量: {queue_size}")
|
||||
|
||||
# 创建并发控制信号量
|
||||
max_concurrency = self.concurrency_config.get("plugin_max_concurrency", 8)
|
||||
self.processing_semaphore = asyncio.Semaphore(max_concurrency)
|
||||
logger.info(f"并发控制已设置,最大并发: {max_concurrency}")
|
||||
|
||||
# 不需要数据库(简化版本)
|
||||
|
||||
# 获取 DLL 路径
|
||||
hook_config = config.get("WechatHook", {})
|
||||
loader_dll = hook_config.get("loader-dll", "libs/Loader.dll")
|
||||
helper_dll = hook_config.get("helper-dll", "libs/Helper.dll")
|
||||
|
||||
# 创建共享内存(必须在创建 Loader 之前)
|
||||
from WechatHook.loader import create_shared_memory
|
||||
logger.info("创建共享内存...")
|
||||
self.shared_memory_handle, self.shared_memory_address = create_shared_memory()
|
||||
|
||||
# 注册回调(必须在创建 Loader 之前)
|
||||
add_callback_handler(self)
|
||||
|
||||
# 创建 Loader
|
||||
logger.info("加载 Loader.dll...")
|
||||
try:
|
||||
self.loader = NoveLoader(loader_dll)
|
||||
except Exception as e:
|
||||
logger.error(f"加载 Loader.dll 失败: {e}")
|
||||
return False
|
||||
|
||||
try:
|
||||
version = self.loader.GetUserWeChatVersion()
|
||||
logger.info(f"检测到本机微信版本: {version}")
|
||||
except Exception as e:
|
||||
logger.warning(f"无法获取微信版本信息: {e}")
|
||||
|
||||
# 注入微信
|
||||
logger.info("注入微信...")
|
||||
self.process_id = self.loader.InjectWeChat(helper_dll)
|
||||
if not self.process_id:
|
||||
logger.error("注入微信失败")
|
||||
return False
|
||||
|
||||
# 等待 socket 客户端回调
|
||||
if not await self._wait_for_socket(timeout_seconds=20):
|
||||
logger.error("Socket 客户端未连接,请检查微信是否正在运行")
|
||||
return False
|
||||
|
||||
# 额外等待 0.5s 确保稳定
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
self.client = WechatHookClient(self.loader, self.socket_client_id)
|
||||
|
||||
# 创建 HookBot
|
||||
self.hookbot = HookBot(self.client)
|
||||
|
||||
# 获取登录信息
|
||||
logger.info("获取登录信息...")
|
||||
await self.client.get_login_info()
|
||||
await asyncio.sleep(2) # 增加等待时间确保回调执行
|
||||
|
||||
# 检查是否已通过回调获取到登录信息
|
||||
if not self.hookbot.wxid:
|
||||
logger.warning("未能通过回调获取登录信息,使用占位符")
|
||||
self.hookbot.update_profile("unknown", "HookBot")
|
||||
|
||||
# 初始化 CDN(必须在登录后执行,才能使用协议 API)
|
||||
logger.info("正在初始化 CDN...")
|
||||
await self.client.cdn_init()
|
||||
await asyncio.sleep(0.5) # 等待 CDN 初始化完成
|
||||
|
||||
# 加载插件
|
||||
logger.info("加载插件...")
|
||||
self.plugin_manager = PluginManager()
|
||||
self.plugin_manager.set_bot(self.client)
|
||||
loaded_plugins = await self.plugin_manager.load_plugins(load_disabled=False)
|
||||
logger.success(f"已加载插件: {loaded_plugins}")
|
||||
|
||||
# 启动消息消费者
|
||||
consumer_count = self.queue_config.get("consumer_count", 1)
|
||||
for i in range(consumer_count):
|
||||
consumer_task = asyncio.create_task(self._message_consumer(i))
|
||||
self.consumer_tasks.append(consumer_task)
|
||||
logger.success(f"已启动 {consumer_count} 个消息消费者")
|
||||
|
||||
# 启动定时任务
|
||||
if scheduler.state == 0:
|
||||
scheduler.start()
|
||||
logger.success("定时任务已启动")
|
||||
|
||||
# 记录启动时间
|
||||
import time
|
||||
self.start_time = int(time.time())
|
||||
logger.info(f"启动时间: {self.start_time}")
|
||||
|
||||
logger.success("=" * 60)
|
||||
logger.success("WechatHookBot 启动成功!")
|
||||
logger.success("=" * 60)
|
||||
|
||||
return True
|
||||
|
||||
async def run(self):
|
||||
"""运行机器人"""
|
||||
if not await self.initialize():
|
||||
return
|
||||
|
||||
self.is_running = True
|
||||
|
||||
try:
|
||||
logger.info("机器人正在运行,按 Ctrl+C 停止...")
|
||||
while self.is_running:
|
||||
await asyncio.sleep(1)
|
||||
except KeyboardInterrupt:
|
||||
logger.info("收到停止信号...")
|
||||
finally:
|
||||
await self.stop()
|
||||
|
||||
async def stop(self):
|
||||
"""停止机器人"""
|
||||
logger.info("正在停止机器人...")
|
||||
self.is_running = False
|
||||
|
||||
# 停止消息消费者
|
||||
if self.consumer_tasks:
|
||||
logger.info("正在停止消息消费者...")
|
||||
for task in self.consumer_tasks:
|
||||
task.cancel()
|
||||
|
||||
# 等待所有消费者任务完成
|
||||
if self.consumer_tasks:
|
||||
await asyncio.gather(*self.consumer_tasks, return_exceptions=True)
|
||||
self.consumer_tasks.clear()
|
||||
logger.info("消息消费者已停止")
|
||||
|
||||
# 清空消息队列
|
||||
if self.message_queue:
|
||||
while not self.message_queue.empty():
|
||||
try:
|
||||
self.message_queue.get_nowait()
|
||||
self.message_queue.task_done()
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
logger.info("消息队列已清空")
|
||||
|
||||
# 停止定时任务
|
||||
if scheduler.running:
|
||||
scheduler.shutdown()
|
||||
|
||||
# 销毁微信连接
|
||||
if self.loader:
|
||||
self.loader.DestroyWeChat()
|
||||
|
||||
logger.success("机器人已停止")
|
||||
|
||||
|
||||
async def main():
|
||||
"""主函数"""
|
||||
# 读取性能配置
|
||||
config_path = Path("main_config.toml")
|
||||
if config_path.exists():
|
||||
with open(config_path, "rb") as f:
|
||||
config = tomllib.load(f)
|
||||
perf_config = config.get("Performance", {})
|
||||
else:
|
||||
perf_config = {}
|
||||
|
||||
# 配置日志
|
||||
logger.remove()
|
||||
|
||||
# 控制台日志(启动阶段始终启用,稳定后可配置禁用)
|
||||
console_enabled = perf_config.get("log_console_enabled", True)
|
||||
logger.add(
|
||||
sys.stdout,
|
||||
colorize=perf_config.get("log_colorize", True),
|
||||
format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level}</level> | {message}",
|
||||
level=perf_config.get("log_level_console", "INFO"),
|
||||
filter=lambda record: console_enabled or "启动" in record["message"] or "初始化" in record["message"] or "成功" in record["message"] or "失败" in record["message"] or "错误" in record["message"]
|
||||
)
|
||||
|
||||
# 文件日志(始终启用)
|
||||
logger.add(
|
||||
"logs/hookbot.log",
|
||||
format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}",
|
||||
encoding="utf-8",
|
||||
rotation="5mb", # 减小文件大小
|
||||
retention="1 week", # 缩短保留时间
|
||||
level=perf_config.get("log_level_file", "INFO")
|
||||
)
|
||||
|
||||
# 创建并运行服务
|
||||
service = BotService()
|
||||
await service.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 检查 Python 版本
|
||||
if sys.maxsize > 2**32:
|
||||
logger.error("请使用 32位 Python 运行此程序!")
|
||||
sys.exit(1)
|
||||
|
||||
# 运行
|
||||
asyncio.run(main())
|
||||
Reference in New Issue
Block a user