Files
WeChatHookBot/bot.py

703 lines
28 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
WechatHookBot - 主入口
基于新版 HTTP Hook API 的微信机器人框架
特点:
- HTTP 回调接收消息
- HTTP API 发送消息
- 无需 DLL 注入(DLL 放到微信目录自动加载)
- 优先级消息队列
- 自适应熔断器
- 配置热更新
- 性能监控
- 优雅关闭
"""
import asyncio
import signal
import sys
import time
from pathlib import Path
from loguru import logger
from WechatHook import WechatHookClient
from WechatHook.http_server import CallbackServer
from WechatHook.message_types import normalize_from_callback, get_internal_msg_type
from utils.hookbot import HookBot
from utils.config_manager import get_config, get_main_config_path, get_project_root
from utils.plugin_manager import PluginManager
from utils.decorators import scheduler
from utils.message_queue import PriorityMessageQueue, MessagePriority
from utils.bot_utils import (
PRIORITY_MESSAGE_TYPES,
AdaptiveCircuitBreaker,
ConfigWatcher,
PerformanceMonitor,
get_performance_monitor
)
from utils.operation_lock import OperationLock
class BotService:
    """Bot service class: holds every runtime component of the robot."""

    def __init__(self):
        # Locations resolved by the config-manager helpers.
        self.base_dir = get_project_root()
        self.config_path = get_main_config_path()

        # Lifecycle flags and the event loop captured by initialize().
        self.is_running = False
        self.is_shutting_down = False
        self.event_loop = None

        # Main collaborators; they stay None until initialize() creates them.
        self.client: WechatHookClient = None
        self.hookbot: HookBot = None
        self.callback_server: CallbackServer = None
        self.plugin_manager: PluginManager = None
        self.webui_server = None

        # Message queue and throughput control.
        self.message_queue: PriorityMessageQueue = None
        self.consumer_tasks = []
        self.processing_semaphore = None
        self.circuit_breaker: AdaptiveCircuitBreaker = None

        # Hot-reload watcher and performance monitor.
        self.config_watcher: ConfigWatcher = None
        self.performance_monitor: PerformanceMonitor = None

        # Full configuration plus the two sections read most often.
        self.config = {}
        self.queue_config = {}
        self.concurrency_config = {}
async def on_message_callback(self, message_type: str, data: dict):
    """
    Handle one message delivered through the HTTP callback.

    Args:
        message_type: Callback category (private_message/group_message/
            moments_message/chatroom_member_add/chatroom_member_remove/
            chatroom_info_change/chatroom_member_nickname_change).
        data: Raw message payload.
    """
    # Messages are ignored while the operation lock reports "paused".
    if OperationLock.is_paused():
        logger.debug(f"更新中忽略消息: type={message_type}")
        return

    # Nothing new is accepted once shutdown has begun.
    if self.is_shutting_down:
        logger.debug(f"关闭中忽略消息: type={message_type}")
        return

    # Moments (timeline) messages are skipped entirely.
    if message_type == "moments_message":
        logger.debug("跳过朋友圈消息")
        return

    # Chatroom events are dispatched directly and never enter the queue.
    chatroom_events = (
        "chatroom_member_add",
        "chatroom_member_remove",
        "chatroom_info_change",
        "chatroom_member_nickname_change",
    )
    if message_type in chatroom_events:
        await self._handle_chatroom_event(message_type, data)
        return

    # Ordinary messages need both the queue and the loop to be in place.
    if not (self.message_queue and self.event_loop):
        logger.warning(f"消息队列未就绪: queue={self.message_queue is not None}, loop={self.event_loop is not None}")
        return

    try:
        await self._enqueue_message(message_type, data)
    except Exception as e:
        logger.error(f"消息入队失败: {e}")
async def _handle_chatroom_event(self, event_type: str, data: dict):
    """
    Normalize a chatroom event and emit it directly to the event manager.

    Chatroom events bypass the message queue. Any exception raised while
    building or emitting the event is logged and swallowed.

    Args:
        event_type: Event name (chatroom_member_add/chatroom_member_remove/
            chatroom_info_change/chatroom_member_nickname_change).
        data: Raw event payload; the event body is read from its "data" key.
    """
    try:
        logger.info(f"[群事件] 收到事件: {event_type}")

        # `or {}` also covers a payload whose "data" key is present but None.
        event_data = data.get("data") or {}
        room_wxid = event_data.get("roomid", "")
        member_count = event_data.get("membercount", 0)

        # "memberlist" may be a single member object or a list of them.
        # A missing or empty value must yield no members at all: the old
        # `{}` default produced one phantom member with empty fields.
        raw_members = event_data.get("memberlist")
        if isinstance(raw_members, dict):
            raw_members = [raw_members] if raw_members else []
        elif not isinstance(raw_members, list):
            raw_members = []

        normalized_msg = {
            "MsgType": self._get_event_msg_type(event_type),
            "RoomWxid": room_wxid,
            "MemberCount": member_count,
            "MemberList": [
                {
                    "wxid": member.get("userName", ""),
                    "nickname": member.get("nickName", ""),
                    "display_name": member.get("displayName", ""),
                    "avatar": member.get("bigHeadImgUrl", ""),
                }
                for member in raw_members
                # Skip malformed entries instead of failing the whole event.
                if isinstance(member, dict)
            ],
        }
        logger.info(f"[群事件] 标准化消息: room={room_wxid}, members={len(normalized_msg['MemberList'])}")

        # Emit right away (not through the message queue).
        from utils.event_manager import EventManager
        await EventManager.emit(event_type, self.client, normalized_msg)
    except Exception as e:
        logger.error(f"处理群事件失败: {e}")
        import traceback
        logger.error(f"详细错误: {traceback.format_exc()}")
def _get_event_msg_type(self, event_type: str) -> int:
    """Return the message type code that corresponds to a chatroom event name."""
    from WechatHook.message_types import MessageType

    codes_by_event = dict(
        chatroom_member_add=MessageType.MT_CHATROOM_MEMBER_ADD,
        chatroom_member_remove=MessageType.MT_CHATROOM_MEMBER_REMOVE,
        chatroom_info_change=MessageType.MT_CHATROOM_INFO_CHANGE,
        chatroom_member_nickname_change=MessageType.MT_CHATROOM_MEMBER_NICKNAME_CHANGE,
    )
    try:
        return codes_by_event[event_type]
    except KeyError:
        # Fallback code for event names that are not in the table above.
        return 11046
async def _enqueue_message(self, message_type: str, data: dict):
    """Classify a callback message and put it on the priority queue.

    Exceptions are logged here and never propagate to the caller.
    """
    try:
        # Count the message as received.
        if self.performance_monitor:
            self.performance_monitor.record_message_received()

        # Resolve the internal type first; it selects the queue priority.
        raw_type = str(data.get("msgType", "1"))
        internal_type = get_internal_msg_type(raw_type, data)
        priority = PRIORITY_MESSAGE_TYPES.get(internal_type, MessagePriority.NORMAL)

        # The callback type travels inside the payload under "_callback_type".
        payload = {"_callback_type": message_type, **data}
        accepted = await self.message_queue.put(
            internal_type,
            payload,
            priority=priority
        )

        if accepted:
            # Record the queue size after a successful put.
            if self.performance_monitor:
                self.performance_monitor.record_queue_size(self.message_queue.qsize())
        elif self.performance_monitor:
            # The queue refused the message; count it as dropped.
            self.performance_monitor.record_message_dropped()
    except Exception as e:
        logger.error(f"消息入队异常: {e}")
async def _ensure_consumer_count(self, target_count: int):
    """Grow or shrink the set of consumer tasks to the target size (minimum 1)."""
    wanted = max(int(target_count), 1)
    running = len(self.consumer_tasks)

    if wanted > running:
        # Scale up: new consumers take the next free ids.
        self.consumer_tasks.extend(
            asyncio.create_task(self._message_consumer(consumer_id))
            for consumer_id in range(running, wanted)
        )
        logger.info(f"消息消费者数量已扩容到 {wanted}")
    elif wanted < running:
        # Scale down: cancel the tail of the list and wait for it to finish.
        surplus = self.consumer_tasks[wanted:]
        self.consumer_tasks = self.consumer_tasks[:wanted]
        for surplus_task in surplus:
            surplus_task.cancel()
        await asyncio.gather(*surplus, return_exceptions=True)
        logger.info(f"消息消费者数量已缩容到 {wanted}")
async def _message_consumer(self, consumer_id: int):
    """
    Consumer coroutine: pull messages from the queue and process them.

    The loop ends when shutdown is in progress and the queue is missing or
    empty, or when the task is cancelled. Any other exception is logged and
    the loop continues after a short pause.

    Args:
        consumer_id: Identifier used only in this consumer's log lines.
    """
    logger.info(f"消息消费者 {consumer_id} 已启动")
    while True:
        # Leave once shutdown has started and nothing is left to drain.
        if self.is_shutting_down and (not self.message_queue or self.message_queue.empty()):
            break
        try:
            # While the operation lock is paused, wait and re-check the loop condition.
            if OperationLock.is_paused():
                await OperationLock.wait_if_paused()
                continue
            # Tracks whether get() returned, so task_done() is only called
            # for a message that was really taken from the queue.
            message_acquired = False
            msg_type = None
            data = None
            try:
                # Fetch a message; the 1 second timeout lets the loop
                # re-check the shutdown and pause conditions regularly.
                msg_type, data = await asyncio.wait_for(
                    self.message_queue.get(),
                    timeout=1.0
                )
                message_acquired = True
                # Check the circuit breaker. When it is open the message is
                # not processed; the finally block still marks it done.
                if self.circuit_breaker and self.circuit_breaker.is_open():
                    logger.debug("熔断器开启,跳过消息处理")
                    self.circuit_breaker.record_rejection()
                    continue
                # Normalize the message. The callback type was stored in the
                # payload by _enqueue_message and is removed again here.
                callback_type = data.pop("_callback_type", "private_message")
                normalized_msg = normalize_from_callback(callback_type, data)
                # Extract group member info from the message and cache it.
                if callback_type == "group_message" and self.client:
                    sender_profile = data.get("sender_profile") or {}
                    new_chatroom_data = sender_profile.get("newChatroomData") or {}
                    members = new_chatroom_data.get("chatRoomMember") or []
                    room_id = normalized_msg.get("RoomWxid", "")
                    if members and room_id:
                        self.client.update_chatroom_members_cache(room_id, members)
                # Process the message under the configured timeout.
                timeout = self.concurrency_config.get("plugin_task_timeout_seconds", 720)
                start_time = time.time()
                try:
                    # The semaphore wraps the timed call, so time spent
                    # waiting for the semaphore is outside the timeout.
                    if self.processing_semaphore:
                        async with self.processing_semaphore:
                            await asyncio.wait_for(
                                self.hookbot.process_message(msg_type, normalized_msg),
                                timeout=timeout
                            )
                    else:
                        await asyncio.wait_for(
                            self.hookbot.process_message(msg_type, normalized_msg),
                            timeout=timeout
                        )
                    processing_time = time.time() - start_time
                    if self.circuit_breaker:
                        self.circuit_breaker.record_success()
                    if self.performance_monitor:
                        self.performance_monitor.record_message_processed(processing_time)
                except asyncio.TimeoutError:
                    # Processing exceeded the timeout: count it as a failure.
                    logger.warning(f"消息处理超时 (>{timeout}s): type={msg_type}")
                    if self.circuit_breaker:
                        self.circuit_breaker.record_failure()
                    if self.performance_monitor:
                        self.performance_monitor.record_message_failed()
                except Exception as e:
                    logger.error(f"消息处理异常: {e}")
                    if self.circuit_breaker:
                        self.circuit_breaker.record_failure()
                    if self.performance_monitor:
                        self.performance_monitor.record_message_failed()
                # Push the latest circuit breaker stats into the monitor.
                if self.performance_monitor and self.circuit_breaker:
                    self.performance_monitor.update_circuit_breaker_stats(
                        self.circuit_breaker.get_stats()
                    )
                # Pause between messages (configured in milliseconds).
                message_interval = self.concurrency_config.get("message_interval_ms", 100)
                if message_interval > 0:
                    await asyncio.sleep(message_interval / 1000.0)
            except asyncio.TimeoutError:
                # Reached when the queue get() timed out (processing
                # timeouts are handled by the inner try above).
                if self.is_shutting_down and self.message_queue and self.message_queue.empty():
                    break
                continue
            finally:
                # Runs on every exit path, including continue and break.
                if message_acquired and self.message_queue:
                    self.message_queue.task_done()
        except asyncio.CancelledError:
            logger.info(f"消费者 {consumer_id} 收到取消信号")
            break
        except Exception as e:
            logger.error(f"消费者 {consumer_id} 异常: {e}")
            import traceback
            logger.error(traceback.format_exc())
            # Brief pause so a persistent error does not spin the loop.
            await asyncio.sleep(0.1)
    logger.info(f"消费者 {consumer_id} 已退出")
async def initialize(self):
    """
    Create and start every runtime component of the bot.

    Startup order: configuration, priority queue, concurrency semaphore,
    circuit breaker, performance monitor, config watcher, HTTP client and
    HookBot, WeChat init and login info, callback server, optional WebUI,
    plugins, message consumers, scheduler.

    Returns:
        True when startup completed. False when the main config file does
        not exist or the callback server failed to start; in the latter
        case the config watcher and HTTP client started earlier are
        released before returning.
    """
    logger.info("=" * 60)
    logger.info("WechatHookBot 启动中... (HTTP 协议版本)")
    logger.info("=" * 60)
    # Inside a coroutine the running loop is the one to record.
    self.event_loop = asyncio.get_running_loop()

    # Load configuration.
    config_path = self.config_path
    if not config_path.exists():
        logger.error(f"配置文件不存在: {config_path}")
        return False
    self.config = get_config().get_all()

    # Performance-related sections.
    self.queue_config = self.config.get("Queue", {})
    self.concurrency_config = self.config.get("Concurrency", {})

    # Priority message queue.
    self.message_queue = PriorityMessageQueue.from_config(self.queue_config)
    logger.info(
        f"优先级消息队列已创建,容量: {self.message_queue.maxsize}, "
        f"溢出策略: {self.message_queue.overflow_strategy.value}"
    )

    # Semaphore that bounds concurrent message processing.
    max_concurrency = self.concurrency_config.get("plugin_max_concurrency", 8)
    self.processing_semaphore = asyncio.Semaphore(max_concurrency)
    logger.info(f"并发控制已设置,最大并发: {max_concurrency}")

    # Adaptive circuit breaker (enabled unless the config turns it off).
    if self.concurrency_config.get("enable_circuit_breaker", True):
        self.circuit_breaker = AdaptiveCircuitBreaker(
            failure_threshold=self.concurrency_config.get("circuit_breaker_threshold", 10),
            success_threshold=3,
            initial_recovery_time=5.0,
            max_recovery_time=300.0
        )
        logger.info("自适应熔断器已创建")

    # Performance monitor.
    self.performance_monitor = get_performance_monitor()
    logger.info("性能监控器已创建")

    # Config hot-reload watcher.
    self.config_watcher = ConfigWatcher(str(self.config_path), check_interval=5.0)
    self.config_watcher.register_callback(self._on_config_update)
    await self.config_watcher.start()
    logger.info("配置热更新监听器已启动")

    # HTTP settings.
    http_config = self.config.get("HttpHook", {})
    api_base_url = http_config.get("api-url", "http://127.0.0.1:8888")
    callback_host = http_config.get("callback-host", "0.0.0.0")
    callback_port = http_config.get("callback-port", 9999)

    # HTTP client.
    logger.info(f"连接 Hook API: {api_base_url}")
    self.client = WechatHookClient(base_url=api_base_url)

    # HookBot.
    self.hookbot = HookBot(self.client)

    # WeChat initialization (refreshes the friend and group list caches).
    logger.info("执行微信初始化...")
    if await self.client.wechat_init():
        logger.success("微信初始化成功")
    else:
        logger.warning("微信初始化失败,部分功能可能受影响")

    # Login info; fall back to the configured identity when unavailable.
    logger.info("获取登录信息...")
    login_info = await self.client.get_login_info()
    if login_info and self.client.wxid:
        logger.success(f"获取登录信息成功: wxid={self.client.wxid}, nickname={self.client.nickname}")
        self.hookbot.update_profile(self.client.wxid, self.client.nickname)
    else:
        bot_config = self.config.get("Bot", {})
        fallback_wxid = bot_config.get("wxid", "unknown")
        fallback_nickname = bot_config.get("nickname", "HookBot")
        logger.warning(f"获取登录信息失败,使用配置中的备用信息: {fallback_wxid}")
        self.hookbot.update_profile(fallback_wxid, fallback_nickname)

    # Create and start the callback server.
    logger.info(f"启动回调服务器: {callback_host}:{callback_port}")
    self.callback_server = CallbackServer(host=callback_host, port=callback_port)
    self.callback_server.add_message_handler(self.on_message_callback)
    if not await self.callback_server.start():
        logger.error("回调服务器启动失败")
        # run() returns without calling stop() when initialization fails,
        # so release what was already started instead of leaking it.
        # The references are cleared so a later stop() does not close twice.
        await self.config_watcher.stop()
        self.config_watcher = None
        await self.client.close()
        self.client = None
        return False

    # Optional WebUI; a failure here is logged and does not abort startup.
    webui_config = self.config.get("WebUI", {})
    if webui_config.get("enabled", False):
        try:
            from utils.webui import WebUIServer
            webui_host = webui_config.get("host", "0.0.0.0")
            webui_port = webui_config.get("port", 5001)
            self.webui_server = WebUIServer(host=webui_host, port=webui_port, config_path=str(self.config_path))
            await self.webui_server.start()
        except Exception as e:
            logger.error(f"WebUI 启动失败: {e}")

    # Load plugins.
    logger.info("加载插件...")
    self.plugin_manager = PluginManager()
    self.plugin_manager.set_bot(self.client)
    loaded_plugins = await self.plugin_manager.load_plugins(load_disabled=False)
    logger.success(f"已加载插件: {loaded_plugins}")

    # Start message consumers. _ensure_consumer_count clamps the requested
    # value to at least 1, so report the number that is really running.
    consumer_count = self.queue_config.get("consumer_count", 1)
    await self._ensure_consumer_count(consumer_count)
    logger.success(f"已启动 {len(self.consumer_tasks)} 个消息消费者")

    # Start the scheduler only when its state is 0.
    # NOTE(review): presumably 0 means "stopped" — confirm against the scheduler type.
    if scheduler.state == 0:
        scheduler.start()
        logger.success("定时任务已启动")

    # Record the start time (epoch seconds).
    self.start_time = int(time.time())
    logger.success("=" * 60)
    logger.success("WechatHookBot 启动成功!")
    logger.success(f"回调地址: http://{callback_host}:{callback_port}")
    logger.success("请确保 Hook 已配置正确的回调地址")
    logger.success("=" * 60)
    return True
async def _on_config_update(self, new_config: dict):
    """
    Apply a reloaded configuration to the running service.

    Registered with the config watcher in initialize(). Updates the queue
    settings, enables/disables/retunes the circuit breaker, replaces the
    processing semaphore when the concurrency limit changed, and resizes
    the consumer pool.

    Args:
        new_config: The reloaded configuration. None or an empty mapping is
            tolerated: the current configuration is kept in that case.
    """
    logger.info("正在应用新配置...")
    # The original guarded self.config against a falsy new_config but then
    # called new_config.get(...), which crashed on None. Normalize once.
    new_config = new_config or {}
    if new_config:
        self.config = new_config
    old_queue = self.queue_config
    self.queue_config = new_config.get("Queue", self.queue_config)
    old_concurrency = self.concurrency_config
    self.concurrency_config = new_config.get("Concurrency", self.concurrency_config)

    # Queue settings: missing keys keep the queue's current values.
    if self.message_queue:
        await self.message_queue.update_config(
            maxsize=self.queue_config.get("max_size", self.message_queue.maxsize),
            overflow_strategy=self.queue_config.get("overflow_strategy", self.message_queue.overflow_strategy.value),
            sampling_rate=self.queue_config.get("sampling_rate", self.message_queue.sampling_rate),
        )
        logger.info(
            f"消息队列配置已更新: max_size={self.message_queue.maxsize}, "
            f"overflow={self.message_queue.overflow_strategy.value}"
        )

    # Circuit breaker: create or drop it to match the new flag.
    enable_circuit_breaker = self.concurrency_config.get("enable_circuit_breaker", True)
    if enable_circuit_breaker and not self.circuit_breaker:
        self.circuit_breaker = AdaptiveCircuitBreaker(
            failure_threshold=self.concurrency_config.get("circuit_breaker_threshold", 10),
            success_threshold=3,
            initial_recovery_time=5.0,
            max_recovery_time=300.0
        )
        logger.info("已按新配置启用熔断器")
    elif not enable_circuit_breaker and self.circuit_breaker:
        self.circuit_breaker = None
        logger.info("已按新配置禁用熔断器")

    # Retune the threshold of an existing breaker when it changed.
    if self.circuit_breaker:
        new_threshold = self.concurrency_config.get("circuit_breaker_threshold", 10)
        if new_threshold != old_concurrency.get("circuit_breaker_threshold", 10):
            self.circuit_breaker.failure_threshold = new_threshold
            logger.info(f"熔断器阈值已更新: {new_threshold}")

    # A changed concurrency limit is applied by swapping in a new semaphore.
    new_max_concurrency = self.concurrency_config.get("plugin_max_concurrency", 8)
    if new_max_concurrency != old_concurrency.get("plugin_max_concurrency", 8):
        self.processing_semaphore = asyncio.Semaphore(new_max_concurrency)
        logger.info(f"插件并发上限已更新: {new_max_concurrency}")

    # Consumer pool size; without a configured value the current size is kept.
    new_consumer_count = self.queue_config.get("consumer_count", len(self.consumer_tasks) or 1)
    if new_consumer_count != len(self.consumer_tasks):
        await self._ensure_consumer_count(new_consumer_count)
    if self.queue_config.get("consumer_count") != old_queue.get("consumer_count"):
        logger.info(f"消息消费者数量已更新: {new_consumer_count}")
    logger.success("配置热更新完成")
async def run(self):
    """Initialize the bot, idle while is_running is set, then call stop()."""
    initialized = await self.initialize()
    if not initialized:
        return
    self.is_running = True

    async def report_stats_periodically():
        # Print the monitor's stats every 300 seconds while the bot runs.
        while self.is_running:
            await asyncio.sleep(300)
            if not (self.performance_monitor and self.is_running):
                continue
            self.performance_monitor.print_stats()

    reporter = asyncio.create_task(report_stats_periodically())
    try:
        logger.info("机器人正在运行,按 Ctrl+C 停止...")
        while self.is_running:
            await asyncio.sleep(1)
    except KeyboardInterrupt:
        logger.info("收到停止信号...")
    finally:
        # Always cancel the reporter and run the shutdown sequence.
        reporter.cancel()
        await self.stop()
async def stop(self):
    """Shut the bot down step by step; a second call returns immediately."""
    if self.is_shutting_down:
        return
    self.is_shutting_down = True

    banner = "=" * 60
    logger.info(banner)
    logger.info("正在优雅关闭机器人...")
    logger.info(banner)

    # 1. Stop accepting new messages and close the callback entry point.
    self.is_running = False
    logger.info("[1/7] 停止接收新消息")
    if self.callback_server:
        logger.info("[1/7] 关闭消息入口...")
        await self.callback_server.stop()
        logger.info("[1/7] 消息入口已关闭")

    # 2. Give queued messages up to 30 seconds to be processed.
    has_backlog = self.message_queue and not self.message_queue.empty()
    if not has_backlog:
        logger.info("[2/7] 队列为空,无需等待")
    else:
        queue_size = self.message_queue.qsize()
        logger.info(f"[2/7] 等待队列中 {queue_size} 条消息处理完成...")
        try:
            await asyncio.wait_for(self.message_queue.join(), timeout=30)
        except asyncio.TimeoutError:
            logger.warning("[2/7] 队列消息未在 30 秒内处理完成,将在停止消费者后清空剩余消息")
        else:
            logger.info("[2/7] 队列消息已全部处理完成")

    # 3. Cancel the consumer tasks and wait for them to finish.
    if not self.consumer_tasks:
        logger.info("[3/7] 无消费者需要停止")
    else:
        logger.info(f"[3/7] 停止 {len(self.consumer_tasks)} 个消息消费者...")
        for consumer in self.consumer_tasks:
            consumer.cancel()
        await asyncio.gather(*self.consumer_tasks, return_exceptions=True)
        self.consumer_tasks.clear()
        logger.info("[3/7] 消息消费者已停止")

    # 4. The callback server was already closed in step 1; log only.
    logger.info("[4/7] 回调服务器已停止" if self.callback_server else "[4/7] 无回调服务器")

    # 4.5 Stop the WebUI.
    if self.webui_server:
        await self.webui_server.stop()

    # 5. Stop the config watcher.
    if not self.config_watcher:
        logger.info("[5/7] 无配置监听器")
    else:
        logger.info("[5/7] 停止配置监听器...")
        await self.config_watcher.stop()
        logger.info("[5/7] 配置监听器已停止")

    # 6. Unload plugins.
    if not self.plugin_manager:
        logger.info("[6/7] 无插件需要卸载")
    else:
        logger.info("[6/7] 卸载插件...")
        await self.plugin_manager.unload_plugins()
        logger.info("[6/7] 插件已卸载")

    # 7. Drop leftover messages, stop the scheduler, close the client.
    logger.info("[7/7] 清理资源...")
    if self.message_queue and not self.message_queue.empty():
        remaining = self.message_queue.qsize()
        self.message_queue.clear()
        logger.warning(f"[7/7] 已清空剩余队列消息: {remaining}")
    if scheduler.running:
        scheduler.shutdown()
    if self.client:
        await self.client.close()

    # Final performance report.
    if self.performance_monitor:
        logger.info("最终性能报告:")
        self.performance_monitor.print_stats()

    logger.success(banner)
    logger.success("机器人已优雅关闭")
    logger.success(banner)
async def main():
    """Main entry: configure the log sinks, then create and run the service.

    Logging options are read from the "Performance" config section when the
    main config file exists; defaults are used otherwise.
    """
    project_root = get_project_root()
    logs_dir = project_root / "logs"
    logs_dir.mkdir(parents=True, exist_ok=True)

    # Read the performance section; tolerate a missing file or empty section.
    config_path = get_main_config_path()
    if config_path.exists():
        perf_config = get_config().get_section("Performance") or {}
    else:
        perf_config = {}

    # Configure logging.
    logger.remove()
    console_enabled = perf_config.get("log_console_enabled", True)
    # With console logging disabled, only messages containing one of these
    # keywords still reach stdout.
    essential_keywords = ("启动", "初始化", "成功", "失败", "错误")

    def console_filter(record):
        if console_enabled:
            return True
        message = record["message"]
        return any(keyword in message for keyword in essential_keywords)

    logger.add(
        sys.stdout,
        colorize=perf_config.get("log_colorize", True),
        format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level}</level> | {message}",
        level=perf_config.get("log_level_console", "INFO"),
        filter=console_filter
    )
    # NOTE(review): loguru may read a lowercase "b" in a size string as bits
    # rather than bytes; confirm "5mb" rotates at the intended size.
    logger.add(
        str(logs_dir / "hookbot.log"),
        format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}",
        encoding="utf-8",
        rotation="5mb",
        retention="1 week",
        level=perf_config.get("log_level_file", "INFO")
    )

    # WebUI log sink: optional, so a failure must not block startup, but it
    # is recorded instead of being silently discarded.
    try:
        from utils.webui import loguru_sink
        logger.add(
            loguru_sink,
            format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}",
            level=perf_config.get("log_level_console", "INFO"),
        )
    except Exception as e:
        logger.debug(f"WebUI 日志 sink 未注册: {e}")

    # Create and run the service.
    service = BotService()
    await service.run()
if __name__ == "__main__":
    # Note: the new protocol no longer requires 32-bit Python.
    asyncio.run(main())