Files
WechatHookBot/bot.py
2025-12-03 15:48:44 +08:00

444 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
WechatHookBot - 主入口
基于个微大客户版 Hook API 的微信机器人框架
"""
import asyncio
import sys
import tomllib
from pathlib import Path
from loguru import logger
from WechatHook import NoveLoader, WechatHookClient
from WechatHook.callbacks import (
add_callback_handler,
wechat_connect_callback,
wechat_recv_callback,
wechat_close_callback,
CONNECT_CALLBACK,
RECV_CALLBACK,
CLOSE_CALLBACK
)
from utils.hookbot import HookBot
from utils.plugin_manager import PluginManager
from utils.decorators import scheduler
# from database import KeyvalDB, MessageDB # 不需要数据库
class BotService:
"""机器人服务类"""
def __init__(self):
self.loader = None
self.client = None
self.hookbot = None
self.plugin_manager = None
self.process_id = None # 微信进程 ID
self.socket_client_id = None # Socket 客户端 ID
self.is_running = False
self.event_loop = None # 事件循环引用
# 消息队列和性能控制
self.message_queue = None
self.queue_config = {}
self.concurrency_config = {}
self.consumer_tasks = []
self.processing_semaphore = None
self.circuit_breaker_failures = 0
self.circuit_breaker_open = False
self.circuit_breaker_last_failure = 0
@CONNECT_CALLBACK(in_class=True)
def on_connect(self, client_id):
"""连接回调"""
logger.success(f"微信客户端已连接: {client_id}")
self.socket_client_id = client_id
@RECV_CALLBACK(in_class=True)
def on_receive(self, client_id, msg_type, data):
"""接收消息回调"""
# 减少日志输出,只记录关键消息类型
if msg_type == 11025: # 登录信息
logger.success(f"获取到登录信息: wxid={data.get('wxid', 'unknown')}, nickname={data.get('nickname', 'unknown')}")
if self.hookbot:
self.hookbot.update_profile(data.get('wxid', 'unknown'), data.get('nickname', 'unknown'))
# 初始化 CDN必须在登录后执行才能使用协议 API
if self.client and self.event_loop:
logger.info("正在初始化 CDN...")
asyncio.run_coroutine_threadsafe(
self.client.cdn_init(),
self.event_loop
)
return
# 使用消息队列处理其他消息
if self.message_queue and self.event_loop:
try:
# 快速入队,不阻塞回调
asyncio.run_coroutine_threadsafe(
self._enqueue_message(msg_type, data),
self.event_loop
)
except Exception as e:
logger.error(f"消息入队失败: {e}")
async def _enqueue_message(self, msg_type, data):
"""将消息加入队列"""
try:
# 检查队列是否已满
if self.message_queue.qsize() >= self.queue_config.get("max_size", 1000):
overflow_strategy = self.queue_config.get("overflow_strategy", "drop_oldest")
if overflow_strategy == "drop_oldest":
# 丢弃最旧的消息
try:
self.message_queue.get_nowait()
logger.warning("队列已满,丢弃最旧消息")
except asyncio.QueueEmpty:
pass
elif overflow_strategy == "sampling":
# 采样处理,随机丢弃
import random
if random.random() < 0.5: # 50% 概率丢弃
logger.debug("队列压力大,采样丢弃消息")
return
else: # degrade
logger.warning("队列已满,降级处理")
return
# 将消息放入队列
await self.message_queue.put((msg_type, data))
except Exception as e:
logger.error(f"消息入队异常: {e}")
async def _message_consumer(self, consumer_id: int):
"""消息消费者协程"""
logger.info(f"消息消费者 {consumer_id} 已启动")
while self.is_running:
try:
# 从队列获取消息,设置超时避免无限等待
msg_type, data = await asyncio.wait_for(
self.message_queue.get(),
timeout=1.0
)
# 检查熔断器状态
if self._check_circuit_breaker():
logger.debug("熔断器开启,跳过消息处理")
continue
# 创建并发任务,不等待完成
timeout = self.concurrency_config.get("plugin_task_timeout_seconds", 5)
# 使用信号量控制并发数量
async def process_with_semaphore():
async with self.processing_semaphore:
try:
await asyncio.wait_for(
self.hookbot.process_message(msg_type, data),
timeout=timeout
)
self._reset_circuit_breaker()
except asyncio.TimeoutError:
logger.warning(f"消息处理超时 (>{timeout}s): type={msg_type}")
self._record_circuit_breaker_failure()
except Exception as e:
logger.error(f"消息处理异常: {e}")
self._record_circuit_breaker_failure()
# 创建任务但不等待,实现真正并发
asyncio.create_task(process_with_semaphore())
# 标记任务完成
self.message_queue.task_done()
except asyncio.TimeoutError:
# 队列为空,继续等待
continue
except Exception as e:
logger.error(f"消费者 {consumer_id} 异常: {e}")
await asyncio.sleep(0.1) # 短暂休息避免忙等
def _check_circuit_breaker(self) -> bool:
"""检查熔断器状态"""
if not self.concurrency_config.get("enable_circuit_breaker", True):
return False
if self.circuit_breaker_open:
# 检查是否可以尝试恢复
import time
if time.time() - self.circuit_breaker_last_failure > 30: # 30秒后尝试恢复
self.circuit_breaker_open = False
self.circuit_breaker_failures = 0
logger.info("熔断器尝试恢复")
return False
return True
return False
def _record_circuit_breaker_failure(self):
"""记录熔断器失败"""
if not self.concurrency_config.get("enable_circuit_breaker", True):
return
self.circuit_breaker_failures += 1
threshold = self.concurrency_config.get("circuit_breaker_threshold", 5)
if self.circuit_breaker_failures >= threshold:
import time
self.circuit_breaker_open = True
self.circuit_breaker_last_failure = time.time()
logger.warning(f"熔断器开启,连续失败 {self.circuit_breaker_failures}")
def _reset_circuit_breaker(self):
"""重置熔断器"""
if self.circuit_breaker_failures > 0:
self.circuit_breaker_failures = 0
@CLOSE_CALLBACK(in_class=True)
def on_close(self, client_id):
"""断开连接回调"""
logger.warning(f"微信客户端已断开: {client_id}")
async def _wait_for_socket(self, timeout_seconds: int = 15) -> bool:
"""等待 socket 客户端连接"""
elapsed = 0
while elapsed < timeout_seconds:
if self.socket_client_id:
return True
await asyncio.sleep(1)
elapsed += 1
logger.info(f"等待微信客户端连接中... ({elapsed}/{timeout_seconds}s)")
return False
async def initialize(self):
"""初始化系统"""
logger.info("=" * 60)
logger.info("WechatHookBot 启动中...")
logger.info("=" * 60)
# 保存事件循环引用
self.event_loop = asyncio.get_event_loop()
# 读取配置
config_path = Path("main_config.toml")
if not config_path.exists():
logger.error("配置文件不存在: main_config.toml")
return False
with open(config_path, "rb") as f:
config = tomllib.load(f)
# 初始化性能配置
self.queue_config = config.get("Queue", {})
self.concurrency_config = config.get("Concurrency", {})
# 创建消息队列
queue_size = self.queue_config.get("max_size", 1000)
self.message_queue = asyncio.Queue(maxsize=queue_size)
logger.info(f"消息队列已创建,容量: {queue_size}")
# 创建并发控制信号量
max_concurrency = self.concurrency_config.get("plugin_max_concurrency", 8)
self.processing_semaphore = asyncio.Semaphore(max_concurrency)
logger.info(f"并发控制已设置,最大并发: {max_concurrency}")
# 不需要数据库(简化版本)
# 获取 DLL 路径
hook_config = config.get("WechatHook", {})
loader_dll = hook_config.get("loader-dll", "libs/Loader.dll")
helper_dll = hook_config.get("helper-dll", "libs/Helper.dll")
# 创建共享内存(必须在创建 Loader 之前)
from WechatHook.loader import create_shared_memory
logger.info("创建共享内存...")
self.shared_memory_handle, self.shared_memory_address = create_shared_memory()
# 注册回调(必须在创建 Loader 之前)
add_callback_handler(self)
# 创建 Loader
logger.info("加载 Loader.dll...")
try:
self.loader = NoveLoader(loader_dll)
except Exception as e:
logger.error(f"加载 Loader.dll 失败: {e}")
return False
try:
version = self.loader.GetUserWeChatVersion()
logger.info(f"检测到本机微信版本: {version}")
except Exception as e:
logger.warning(f"无法获取微信版本信息: {e}")
# 注入微信
logger.info("注入微信...")
self.process_id = self.loader.InjectWeChat(helper_dll)
if not self.process_id:
logger.error("注入微信失败")
return False
# 等待 socket 客户端回调
if not await self._wait_for_socket(timeout_seconds=20):
logger.error("Socket 客户端未连接,请检查微信是否正在运行")
return False
# 额外等待 0.5s 确保稳定
await asyncio.sleep(0.5)
self.client = WechatHookClient(self.loader, self.socket_client_id)
# 创建 HookBot
self.hookbot = HookBot(self.client)
# 获取登录信息
logger.info("获取登录信息...")
await self.client.get_login_info()
await asyncio.sleep(2) # 增加等待时间确保回调执行
# 检查是否已通过回调获取到登录信息
if not self.hookbot.wxid:
logger.warning("未能通过回调获取登录信息,使用占位符")
self.hookbot.update_profile("unknown", "HookBot")
# 初始化 CDN必须在登录后执行才能使用协议 API
logger.info("正在初始化 CDN...")
await self.client.cdn_init()
await asyncio.sleep(0.5) # 等待 CDN 初始化完成
# 加载插件
logger.info("加载插件...")
self.plugin_manager = PluginManager()
self.plugin_manager.set_bot(self.client)
loaded_plugins = await self.plugin_manager.load_plugins(load_disabled=False)
logger.success(f"已加载插件: {loaded_plugins}")
# 启动消息消费者
consumer_count = self.queue_config.get("consumer_count", 1)
for i in range(consumer_count):
consumer_task = asyncio.create_task(self._message_consumer(i))
self.consumer_tasks.append(consumer_task)
logger.success(f"已启动 {consumer_count} 个消息消费者")
# 启动定时任务
if scheduler.state == 0:
scheduler.start()
logger.success("定时任务已启动")
# 记录启动时间
import time
self.start_time = int(time.time())
logger.info(f"启动时间: {self.start_time}")
logger.success("=" * 60)
logger.success("WechatHookBot 启动成功!")
logger.success("=" * 60)
return True
async def run(self):
"""运行机器人"""
if not await self.initialize():
return
self.is_running = True
try:
logger.info("机器人正在运行,按 Ctrl+C 停止...")
while self.is_running:
await asyncio.sleep(1)
except KeyboardInterrupt:
logger.info("收到停止信号...")
finally:
await self.stop()
async def stop(self):
"""停止机器人"""
logger.info("正在停止机器人...")
self.is_running = False
# 停止消息消费者
if self.consumer_tasks:
logger.info("正在停止消息消费者...")
for task in self.consumer_tasks:
task.cancel()
# 等待所有消费者任务完成
if self.consumer_tasks:
await asyncio.gather(*self.consumer_tasks, return_exceptions=True)
self.consumer_tasks.clear()
logger.info("消息消费者已停止")
# 清空消息队列
if self.message_queue:
while not self.message_queue.empty():
try:
self.message_queue.get_nowait()
self.message_queue.task_done()
except asyncio.QueueEmpty:
break
logger.info("消息队列已清空")
# 停止定时任务
if scheduler.running:
scheduler.shutdown()
# 销毁微信连接
if self.loader:
self.loader.DestroyWeChat()
logger.success("机器人已停止")
async def main():
"""主函数"""
# 读取性能配置
config_path = Path("main_config.toml")
if config_path.exists():
with open(config_path, "rb") as f:
config = tomllib.load(f)
perf_config = config.get("Performance", {})
else:
perf_config = {}
# 配置日志
logger.remove()
# 控制台日志(启动阶段始终启用,稳定后可配置禁用)
console_enabled = perf_config.get("log_console_enabled", True)
logger.add(
sys.stdout,
colorize=perf_config.get("log_colorize", True),
format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level}</level> | {message}",
level=perf_config.get("log_level_console", "INFO"),
filter=lambda record: console_enabled or "启动" in record["message"] or "初始化" in record["message"] or "成功" in record["message"] or "失败" in record["message"] or "错误" in record["message"]
)
# 文件日志(始终启用)
logger.add(
"logs/hookbot.log",
format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}",
encoding="utf-8",
rotation="5mb", # 减小文件大小
retention="1 week", # 缩短保留时间
level=perf_config.get("log_level_file", "INFO")
)
# 创建并运行服务
service = BotService()
await service.run()
if __name__ == "__main__":
# 检查 Python 版本
if sys.maxsize > 2**32:
logger.error("请使用 32位 Python 运行此程序!")
sys.exit(1)
# 运行
asyncio.run(main())