Files
abot/configuration.py

535 lines
24 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import copy
import logging
import os
import re
from typing import Optional

import yaml
class Config:
    """Global configuration loader.

    Design goals:
    1. Stay compatible with the project's existing `config.yaml` layout so no
       large one-shot refactor is required;
    2. Support `${ENV_NAME}` / `${ENV_NAME:default}` environment-variable
       injection inside YAML string values;
    3. Surface missing items, weak settings and plaintext secrets as early as
       possible (at startup) to reduce the risk of misconfiguration;
    4. Offer a single entry point for sanitized display and config audits.

    After construction (and after every `reload()`), the resolved sections are
    exposed as plain attributes: `environment`, `plugin_dir`,
    `plugin_hot_reload`, `mariadb`, `redis`, `email`, `wx_config`,
    `wechat_ipad`, `llm`, plus `validation_report`.
    """

    # Environment-variable placeholder syntax:
    # 1. `${ABOT_DB_PASSWORD}` must be supplied by the environment;
    # 2. `${ABOT_DB_HOST:127.0.0.1}` falls back to the default when the
    #    variable is missing;
    # 3. Names may contain letters, digits and underscores.
    ENV_PATTERN = re.compile(r"\$\{([A-Za-z0-9_]+)(?::([^}]*))?\}")

    # Sensitive-key keywords:
    # 1. Used to decide which config items must be masked;
    # 2. Also used to scan the raw YAML for plaintext secrets;
    # 3. Matching is "key contains keyword" rather than equality, so names such
    #    as `sender_password` / `api_key` are covered as well.
    SENSITIVE_KEYWORDS = {
        "password",
        "passwd",
        "secret",
        "token",
        "api_key",
        "apikey",
        "access_key",
        "private_key",
    }

    def __init__(self, config_path: Optional[str] = None) -> None:
        """Load `.env`, then read, resolve, normalize and validate the config.

        Args:
            config_path: Path of the YAML file. Defaults to `config.yaml` next
                to this module.

        Raises:
            OSError: If the YAML file cannot be opened.
            ValueError: If the YAML root is not a mapping.
        """
        self.project_dir = os.path.dirname(os.path.abspath(__file__))
        self.config_path = config_path or os.path.join(self.project_dir, "config.yaml")

        # Try to auto-load a `.env` sitting next to this module first:
        # 1. Operators only need to drop a `.env` file into the deployment
        #    directory instead of exporting variables by hand;
        # 2. Variables already present in the process environment win, so
        #    values injected by the ops layer are never overwritten;
        # 3. No third-party dependency is pulled in for this.
        self._load_local_env_file(os.path.join(self.project_dir, ".env"))

        self.raw_config = {}
        self.resolved_config = {}
        self.unresolved_placeholders = []
        self.validation_report = {"errors": [], "warnings": []}
        self.reload()

    def _load_config(self) -> dict:
        """Read the YAML configuration from disk.

        Returns:
            The parsed top-level mapping (an empty dict for an empty file).

        Raises:
            OSError: If the file cannot be opened.
            ValueError: If the top-level YAML node is not a mapping. Without
                this check the failure would only appear later as an opaque
                `AttributeError` on `.get`.
        """
        with open(self.config_path, "r", encoding="utf-8") as fp:
            yconfig = yaml.safe_load(fp) or {}
        if not isinstance(yconfig, dict):
            raise ValueError(
                f"配置文件 {self.config_path} 的顶层结构必须是 mapping,"
                f"当前为 {type(yconfig).__name__}。"
            )
        return yconfig

    @staticmethod
    def _strip_optional_quotes(value: str) -> str:
        """Strip one pair of matching leading/trailing quotes, as seen in `.env`."""
        text = str(value or "").strip()
        if len(text) >= 2 and text[0] == text[-1] and text[0] in {"'", '"'}:
            return text[1:-1]
        return text

    @classmethod
    def _parse_env_line(cls, raw_line: str) -> Optional[tuple]:
        """Parse one `.env` line into `(key, value)`.

        Returns:
            `None` for blank lines, `#` comments, lines without `=` and lines
            whose key is empty; otherwise the stripped key and the unquoted
            value.
        """
        line = str(raw_line or "").strip()
        if not line or line.startswith("#"):
            return None
        # Accept `export KEY=value` so shell-style files can be reused as-is.
        if line.startswith("export "):
            line = line[len("export "):].strip()
        key, separator, value = line.partition("=")
        key = key.strip()
        if not separator or not key:
            return None
        return key, cls._strip_optional_quotes(value)

    def _load_local_env_file(self, env_path: str) -> None:
        """Inject variables from a local `.env` file into `os.environ`.

        This is best-effort: no failure here is allowed to abort startup.
        Required values that end up missing are reported later by `validate()`.

        Args:
            env_path: Path of the `.env` file; a missing file is ignored.
        """
        if not os.path.exists(env_path):
            return
        log = logging.getLogger(__name__)
        try:
            # `utf-8-sig` drops a leading BOM (common when the file was saved
            # by a Windows editor); otherwise the BOM would become part of the
            # first key and that variable would never match its placeholder.
            with open(env_path, "r", encoding="utf-8-sig") as env_file:
                for line_number, raw_line in enumerate(env_file, start=1):
                    parsed = self._parse_env_line(raw_line)
                    if parsed is None:
                        continue
                    key, value = parsed
                    # `.env` only fills in variables the process does not have:
                    # 1. Explicit system environment variables take priority;
                    # 2. Local debugging and ops can both override `.env`;
                    # 3. Secrets injected by an ops platform are not clobbered.
                    try:
                        if key not in os.environ:
                            os.environ[key] = value
                    except (ValueError, OSError) as exc:
                        # One unusable entry (e.g. an embedded NUL) must not
                        # prevent the remaining lines from loading. Only the
                        # position is logged, never the value.
                        log.warning(
                            "Skipping line %d of %s: %s", line_number, env_path, exc
                        )
        except Exception as exc:  # deliberate best-effort boundary
            log.warning("Failed to load env file %s: %s", env_path, exc)

    def _resolve_env_placeholders_in_string(self, raw_value: str, path: str) -> str:
        """Substitute every `${NAME}` / `${NAME:default}` in one string.

        Args:
            raw_value: String that may contain placeholders.
            path: Dotted location of the value, recorded for error reporting.

        Returns:
            The string with placeholders replaced. A placeholder with neither
            an environment value nor a default becomes `""` and is appended to
            `self.unresolved_placeholders`.
        """

        def _replace(match: re.Match) -> str:
            env_name = str(match.group(1) or "").strip()
            default_value = match.group(2)
            env_value = os.environ.get(env_name)
            # Priority: non-empty environment value, then the template default,
            # otherwise record the placeholder as missing so that validate()
            # can raise a fatal error for it.
            if env_value not in (None, ""):
                return str(env_value)
            if default_value is not None:
                return str(default_value)
            self.unresolved_placeholders.append({
                "path": path,
                "env_name": env_name,
            })
            return ""

        return self.ENV_PATTERN.sub(_replace, raw_value)

    def _resolve_config_tree(self, node, path: str = "root"):
        """Recursively resolve placeholders in dicts, lists and strings.

        Dict keys are kept as-is; non-string leaves are returned unchanged.
        """
        if isinstance(node, dict):
            return {
                key: self._resolve_config_tree(value, f"{path}.{key}")
                for key, value in node.items()
            }
        if isinstance(node, list):
            return [
                self._resolve_config_tree(value, f"{path}[{index}]")
                for index, value in enumerate(node)
            ]
        if isinstance(node, str):
            return self._resolve_env_placeholders_in_string(node, path)
        return node

    @staticmethod
    def _safe_int(value, default: int):
        """Convert a YAML / environment value to `int`, or return `default`.

        `None`, `""` and anything `int()` rejects yield `default`. This
        includes infinities (YAML `.inf`), which raise `OverflowError`.
        """
        try:
            if value in (None, ""):
                return default
            return int(value)
        except (TypeError, ValueError, OverflowError):
            return default

    @staticmethod
    def _safe_bool(value, default: bool):
        """Convert common string / numeric spellings to `bool`.

        Recognizes 1/true/yes/y/on and 0/false/no/n/off (case-insensitive);
        `None` and anything unrecognized yield `default`.
        """
        if value is None:
            return default
        if isinstance(value, bool):
            return value
        text = str(value).strip().lower()
        if text in {"1", "true", "yes", "y", "on"}:
            return True
        if text in {"0", "false", "no", "n", "off"}:
            return False
        return default

    @staticmethod
    def _text_option(section: dict, key: str, default: str = "") -> str:
        """Return `section[key]` as a stripped string; falsy values become `default`."""
        return str(section.get(key, default) or default).strip()

    def _normalize_config(self, yconfig: dict) -> dict:
        """Normalize structure and types of the resolved configuration.

        Works on a deep copy; the argument is not modified. The sections
        `db_config`, `redis_config`, `email_config`, `plugin_hot_reload` and
        `wechat_ipad` always exist in the result.
        """
        normalized = copy.deepcopy(yconfig or {})

        # Database section:
        # 1. Historical configs use the misspelled key `prot`;
        # 2. Per the original author's note, the connection layer reads `port`;
        # 3. Both keys are therefore back-filled with the same value.
        # An empty `port` (e.g. `${DB_PORT:}`) must not hide a legacy `prot`.
        db_config = dict(normalized.get("db_config", {}) or {})
        db_port = db_config.get("port")
        if db_port in (None, ""):
            db_port = db_config.get("prot")
        db_config["port"] = self._safe_int(db_port, 3306)
        db_config["prot"] = db_config["port"]
        db_config["pool_size"] = self._safe_int(db_config.get("pool_size", 10), 10)
        normalized["db_config"] = db_config

        # Many Redis / e-mail values come from environment variables and so
        # arrive as strings; cast them once here so business code does not
        # have to guard against numeric strings everywhere.
        redis_config = dict(normalized.get("redis_config", {}) or {})
        redis_config["port"] = self._safe_int(redis_config.get("port", 6379), 6379)
        redis_config["db"] = self._safe_int(redis_config.get("db", 0), 0)
        redis_config["max_connections"] = self._safe_int(
            redis_config.get("max_connections", 30), 30
        )
        normalized["redis_config"] = redis_config

        email_config = dict(normalized.get("email_config", {}) or {})
        email_config["smtp_port"] = self._safe_int(email_config.get("smtp_port", 465), 465)
        normalized["email_config"] = email_config

        # Plugin hot reload (described by the original author as a continuous
        # disk-scanning thread) defaults to disabled; `enabled` and
        # `interval_seconds` are standardized so it can be toggled per
        # environment.
        plugin_hot_reload = dict(normalized.get("plugin_hot_reload", {}) or {})
        plugin_hot_reload["enabled"] = self._safe_bool(
            plugin_hot_reload.get("enabled", False), False
        )
        plugin_hot_reload["interval_seconds"] = self._safe_int(
            plugin_hot_reload.get("interval_seconds", 600), 600
        )
        normalized["plugin_hot_reload"] = plugin_hot_reload

        # wechat_ipad section:
        # 1. Static connection parameters come from config.yaml + .env;
        # 2. `legacy_config_path` defaults to the historical
        #    `wechat_ipad/config.toml` location;
        # 3. `server_key` is normalized here so that an empty value is caught
        #    by validation with a precise message instead of failing later.
        wechat_ipad_config = dict(normalized.get("wechat_ipad", {}) or {})
        wechat_ipad_config["server_port"] = self._safe_int(
            wechat_ipad_config.get("server_port", 8059), 8059
        )
        text_defaults = (
            ("server_url", ""),
            ("server_ip", ""),
            ("server_type", "legacy_855"),
            ("server_key", ""),
            ("login_qr_api", "new_x"),
            ("login_way", "mac"),
            ("wxid", ""),
            ("device_name", ""),
            ("device_id", ""),
            ("state_file", ""),
            ("legacy_config_path", "wechat_ipad/config.toml"),
        )
        for option_name, option_default in text_defaults:
            wechat_ipad_config[option_name] = self._text_option(
                wechat_ipad_config, option_name, option_default
            )
        normalized["wechat_ipad"] = wechat_ipad_config
        return normalized

    @classmethod
    def _contains_placeholder(cls, value: str) -> bool:
        """Return True if the raw string still contains a `${...}` template."""
        return bool(cls.ENV_PATTERN.search(str(value or "")))

    @classmethod
    def _is_sensitive_key(cls, key: str) -> bool:
        """Return True if the lower-cased key contains any sensitive keyword."""
        lowered_key = str(key or "").strip().lower()
        return any(keyword in lowered_key for keyword in cls.SENSITIVE_KEYWORDS)

    def _append_issue(self, bucket: list, code: str, path: str, message: str) -> None:
        """Append one issue record (`code`, `path`, `message`) to `bucket`."""
        bucket.append({
            "code": code,
            "path": path,
            "message": message,
        })

    def _validate_required_sections(self, report: dict) -> None:
        """Check that the core runtime dependencies (DB, Redis, LLM) are configured."""
        db_config = self.mariadb or {}
        redis_config = self.redis or {}
        llm_config = self.llm or {}
        llm_backends = dict(llm_config.get("backends", {}) or {})
        default_backend = str(llm_config.get("default_backend", "") or "").strip()

        required_db_fields = {
            "host": "数据库 host",
            "user": "数据库 user",
            "password": "数据库 password",
            "database": "数据库 database",
        }
        for field_name, display_name in required_db_fields.items():
            if not str(db_config.get(field_name, "") or "").strip():
                self._append_issue(
                    report["errors"],
                    "missing_db_field",
                    f"db_config.{field_name}",
                    f"{display_name} 未配置,机器人无法正常连接 MySQL。",
                )
        if not db_config.get("port"):
            self._append_issue(
                report["errors"],
                "missing_db_port",
                "db_config.port",
                "数据库 port 未配置,机器人无法正常连接 MySQL。",
            )

        if not str(redis_config.get("host", "") or "").strip():
            self._append_issue(
                report["errors"],
                "missing_redis_host",
                "redis_config.host",
                "Redis host 未配置,机器人无法正常连接 Redis。",
            )
        if not redis_config.get("port"):
            self._append_issue(
                report["errors"],
                "missing_redis_port",
                "redis_config.port",
                "Redis port 未配置,机器人无法正常连接 Redis。",
            )

        # Without any backend there is nothing further to cross-check.
        if not llm_backends:
            self._append_issue(
                report["warnings"],
                "missing_llm_backends",
                "llm.backends",
                "当前未配置任何 LLM backend,依赖 AI 的插件将不可用。",
            )
            return
        if not default_backend:
            self._append_issue(
                report["warnings"],
                "missing_default_llm_backend",
                "llm.default_backend",
                "未配置 llm.default_backend,建议指定默认 AI 路由。",
            )
        elif default_backend not in llm_backends:
            self._append_issue(
                report["errors"],
                "invalid_default_llm_backend",
                "llm.default_backend",
                f"默认 backend `{default_backend}` 不存在于 llm.backends 中。",
            )

    def _validate_email_config(self, report: dict) -> None:
        """Warn when the e-mail alert configuration is only partially filled in."""
        email_config = self.email or {}
        sender_email = str(email_config.get("sender_email", "") or "").strip()
        sender_password = str(email_config.get("sender_password", "") or "").strip()
        alert_recipient = str(email_config.get("alert_recipient", "") or "").strip()

        if sender_email and not sender_password:
            self._append_issue(
                report["warnings"],
                "missing_email_password",
                "email_config.sender_password",
                "已配置 sender_email,但缺少 sender_password,邮件告警发送会失败。",
            )
        if alert_recipient and (not sender_email or not sender_password):
            self._append_issue(
                report["warnings"],
                "email_alert_incomplete",
                "email_config.alert_recipient",
                "已配置告警接收人,但发件邮箱配置不完整,告警链路不可用。",
            )

    def _validate_wechat_ipad_config(self, report: dict) -> None:
        """Check that the static wechat_ipad connection settings are complete."""
        wechat_ipad_config = self.wechat_ipad or {}
        server_url = str(wechat_ipad_config.get("server_url", "") or "").strip()
        server_ip = str(wechat_ipad_config.get("server_ip", "") or "").strip()
        server_port = wechat_ipad_config.get("server_port", 0)
        server_type = str(wechat_ipad_config.get("server_type", "") or "").strip().lower()
        server_key = str(wechat_ipad_config.get("server_key", "") or "").strip()

        if not server_url:
            self._append_issue(
                report["errors"],
                "missing_wechat_server_url",
                "wechat_ipad.server_url",
                "wechat_ipad server_url 未配置,机器人无法连接 wechat_ipad server。",
            )
        if not server_ip:
            self._append_issue(
                report["errors"],
                "missing_wechat_server_ip",
                "wechat_ipad.server_ip",
                "wechat_ipad server_ip 未配置,机器人无法连接 wechat_ipad server。",
            )
        if not server_port:
            self._append_issue(
                report["errors"],
                "missing_wechat_server_port",
                "wechat_ipad.server_port",
                "wechat_ipad server_port 未配置,机器人无法连接 wechat_ipad server。",
            )
        # The 864-style server requires a static `server_key`; report it before
        # startup rather than letting the provider fail later with a vague
        # error (rationale taken from the original author's note).
        if server_type in {"864", "server_864"} and not server_key:
            self._append_issue(
                report["errors"],
                "missing_wechat_server_key",
                "wechat_ipad.server_key",
                "server_864 模式必须配置 wechat_ipad.server_key(建议通过 .env 的 WECHAT_SERVER_KEY 注入)。",
            )

    def _validate_llm_config(self, report: dict) -> None:
        """Check LLM backends for completeness and scenes for routing consistency."""
        llm_config = self.llm or {}
        backends = dict(llm_config.get("backends", {}) or {})
        scenes = dict(llm_config.get("scenes", {}) or {})

        for backend_name, backend_config in backends.items():
            backend_config = backend_config or {}
            provider = str(backend_config.get("provider", "") or "").strip()
            if not provider:
                self._append_issue(
                    report["warnings"],
                    "missing_llm_provider",
                    f"llm.backends.{backend_name}.provider",
                    f"LLM backend `{backend_name}` 未配置 provider。",
                )
            # An empty api_key is only a warning: a backend that is not in use
            # yet must not block startup, but the operator is still reminded.
            api_key = str(backend_config.get("api_key", "") or "").strip()
            if not api_key:
                self._append_issue(
                    report["warnings"],
                    "missing_llm_api_key",
                    f"llm.backends.{backend_name}.api_key",
                    f"LLM backend `{backend_name}` 未配置 api_key,相关 AI 能力将不可用。",
                )

        for scene_name, backend_name in scenes.items():
            backend_name = str(backend_name or "").strip()
            if backend_name and backend_name not in backends:
                self._append_issue(
                    report["warnings"],
                    "invalid_llm_scene_backend",
                    f"llm.scenes.{scene_name}",
                    f"场景 `{scene_name}` 指向了不存在的 backend `{backend_name}`。",
                )

    def _validate_unresolved_placeholders(self, report: dict) -> None:
        """Turn every missing environment variable into a readable startup error."""
        for unresolved_item in self.unresolved_placeholders:
            self._append_issue(
                report["errors"],
                "missing_environment_variable",
                unresolved_item.get("path", "root"),
                f"环境变量 `{unresolved_item.get('env_name', '')}` 未提供,且未设置默认值。",
            )

    def _validate_plaintext_secrets(self, report: dict) -> None:
        """Scan the raw YAML for sensitive values still written in plaintext.

        NOTE(review): only `str` values are inspected, so a numeric secret
        (e.g. `password: 123456`) is not reported. Widening the check would
        also flag non-secret numeric keys such as `max_tokens`.
        """

        def _walk(node, path: str = "root") -> None:
            if isinstance(node, dict):
                for key, value in node.items():
                    next_path = f"{path}.{key}"
                    if isinstance(value, str) and self._is_sensitive_key(key):
                        stripped_value = value.strip()
                        if stripped_value and not self._contains_placeholder(stripped_value):
                            self._append_issue(
                                report["warnings"],
                                "plaintext_sensitive_value",
                                next_path,
                                "该敏感配置仍以明文形式写在 YAML 中,建议改为环境变量注入。",
                            )
                    _walk(value, next_path)
                return
            if isinstance(node, list):
                for index, value in enumerate(node):
                    _walk(value, f"{path}[{index}]")

        _walk(self.raw_config)

    def validate(self) -> dict:
        """Run every check and return a fresh `{"errors": [...], "warnings": [...]}` report."""
        report = {"errors": [], "warnings": []}
        self._validate_unresolved_placeholders(report)
        self._validate_required_sections(report)
        self._validate_email_config(report)
        self._validate_wechat_ipad_config(report)
        self._validate_llm_config(report)
        self._validate_plaintext_secrets(report)
        return report

    @staticmethod
    def _mask_secret_value(value: str) -> str:
        """Mask a secret, keeping the first and last two characters when it is long enough.

        Values of six characters or fewer are replaced entirely by `*`.
        """
        text = str(value or "")
        if not text:
            return ""
        if len(text) <= 6:
            return "*" * len(text)
        return f"{text[:2]}{'*' * (len(text) - 4)}{text[-2:]}"

    def _sanitize_config_tree(self, node, parent_key: str = ""):
        """Recursively build a masked copy of the config for logs / admin display.

        String leaves whose nearest dict key is sensitive are masked; list
        items inherit the key of the list itself. Non-string leaves are
        returned unchanged.
        """
        if isinstance(node, dict):
            return {
                key: self._sanitize_config_tree(value, str(key))
                for key, value in node.items()
            }
        if isinstance(node, list):
            return [self._sanitize_config_tree(value, parent_key) for value in node]
        if isinstance(node, str) and self._is_sensitive_key(parent_key):
            return self._mask_secret_value(node)
        return node

    def get_validation_report(self) -> dict:
        """Return a deep copy of the report so callers cannot mutate internal state."""
        return copy.deepcopy(self.validation_report)

    def has_fatal_issues(self) -> bool:
        """Return True if the current report contains any error-level issue."""
        return bool(self.validation_report.get("errors"))

    def get_sanitized_snapshot(self) -> dict:
        """Return a masked snapshot of the resolved config that is safe to print."""
        return self._sanitize_config_tree(self.resolved_config)

    def reload(self) -> None:
        """Re-read the YAML file and refresh public attributes and the validation report.

        Raises:
            OSError: If the YAML file cannot be opened.
            ValueError: If the YAML root is not a mapping.
        """
        self.raw_config = self._load_config()
        self.unresolved_placeholders = []
        resolved_config = self._resolve_config_tree(copy.deepcopy(self.raw_config))
        self.resolved_config = self._normalize_config(resolved_config)

        # Keep the historical top-level attribute mapping so existing callers
        # continue to work unchanged.
        self.environment = str(
            self.resolved_config.get("environment", "development") or "development"
        ).strip()
        self.plugin_dir = str(self.resolved_config.get("plugin_dir", "plugins") or "plugins").strip()
        self.plugin_hot_reload = self.resolved_config.get("plugin_hot_reload", {})
        self.mariadb = self.resolved_config.get("db_config", {})
        self.redis = self.resolved_config.get("redis_config", {})
        self.email = self.resolved_config.get("email_config", {})
        self.wx_config = self.resolved_config.get("wx_config", {})
        self.wechat_ipad = self.resolved_config.get("wechat_ipad", {})
        self.llm = self.resolved_config.get("llm", {})
        self.validation_report = self.validate()