181 lines
6.1 KiB
Python
181 lines
6.1 KiB
Python
"""数据库模型和配置管理"""
|
||
from __future__ import annotations
|
||
|
||
import os
|
||
import hashlib
|
||
from datetime import datetime, timezone
|
||
from pathlib import Path
|
||
from typing import Optional
|
||
|
||
from dotenv import load_dotenv
|
||
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, String, Text, Float, create_engine
|
||
from sqlalchemy.orm import Session, declarative_base, relationship, sessionmaker
|
||
|
||
# Load settings from the .env file that sits next to this module.
APP_DIR = Path(__file__).resolve().parent
load_dotenv(APP_DIR.joinpath(".env"))

# Working directories; each one is created at import time if it is missing.
DATA_DIR = APP_DIR / "data"
DATA_DIR.mkdir(parents=True, exist_ok=True)
LOG_DIR = APP_DIR / "logs"
LOG_DIR.mkdir(parents=True, exist_ok=True)

# Database configuration.
# MySQL URL format:
#   mysql+pymysql://<user>:<password>@<host>:<port>/<database>?charset=utf8mb4
# Set DATABASE_URL in the .env file to override the default below.
DEFAULT_MYSQL_URL = "mysql+pymysql://proxyauto:proxyauto@localhost:3306/proxyauto?charset=utf8mb4"
DATABASE_URL = os.getenv("DATABASE_URL", DEFAULT_MYSQL_URL)

engine = create_engine(
    DATABASE_URL,
    echo=False,
    pool_pre_ping=True,  # check that a pooled connection is alive before using it
    pool_recycle=3600,  # recycle pooled connections after 3600 seconds
)
SessionLocal = sessionmaker(bind=engine)
Base = declarative_base()
|
||
|
||
|
||
def utcnow() -> datetime:
    """Return the current time as a timezone-aware ``datetime`` in UTC."""
    return datetime.now(tz=timezone.utc)
|
||
|
||
|
||
def hash_password(password: str) -> str:
    """Return the hex-encoded SHA-256 digest of ``password``.

    The password is encoded with the default codec (UTF-8) before hashing.

    NOTE(review): this is a single unsalted SHA-256 pass, which is weak for
    password storage. A salted, slow KDF (e.g. ``hashlib.scrypt`` or
    ``hashlib.pbkdf2_hmac``) would be preferable, but switching schemes
    invalidates every stored hash and needs a migration, so the behavior is
    left unchanged here.
    """
    hasher = hashlib.sha256()
    hasher.update(password.encode())
    return hasher.hexdigest()
|
||
|
||
|
||
class User(Base):
    """User account table."""

    __tablename__ = "users"

    id = Column(Integer, primary_key=True)
    username = Column(String(64), nullable=False, unique=True)
    # Hex digest as produced by hash_password().
    password_hash = Column(String(128), nullable=False)
    is_admin = Column(Boolean, nullable=False, default=False)
    created_at = Column(DateTime(timezone=True), nullable=False, default=utcnow)
    updated_at = Column(DateTime(timezone=True), nullable=False, default=utcnow, onupdate=utcnow)

    def check_password(self, password: str) -> bool:
        """Return True when ``password`` hashes to the stored ``password_hash``.

        The digests are compared with ``hmac.compare_digest`` so the time
        taken does not reveal how many leading characters matched.
        """
        # Imported locally so this block does not depend on a new
        # module-level import.
        import hmac

        # A missing hash can never match a 64-character hex digest, so this
        # still yields False rather than raising.
        stored = self.password_hash or ""
        candidate = hash_password(password)
        # Compare as bytes: compare_digest rejects non-ASCII str arguments.
        return hmac.compare_digest(stored.encode("utf-8"), candidate.encode("utf-8"))
|
||
|
||
|
||
class ProxyMachine(Base):
    """Proxy machine table.

    Each machine carries its own domain binding and its own IP-change
    interval.
    """

    __tablename__ = "proxy_machines"

    id = Column(Integer, primary_key=True)
    name = Column(String(128), nullable=False, unique=True)
    # Either "ec2" or "lightsail".
    aws_service = Column(String(32), nullable=False, default="ec2")
    aws_region = Column(String(64), nullable=False, default="us-east-1")
    aws_instance_id = Column(String(64), nullable=False)
    note = Column(String(255), nullable=True)
    enabled = Column(Boolean, nullable=False, default=True)

    # Domain settings - bound per machine.
    cf_zone_id = Column(String(64), nullable=True)
    cf_record_name = Column(String(255), nullable=True)
    cf_record_id = Column(String(64), nullable=True)
    cf_proxied = Column(Boolean, nullable=False, default=False)

    # Per-machine change schedule.
    change_interval_seconds = Column(Integer, nullable=False, default=3600)
    auto_enabled = Column(Boolean, nullable=False, default=False)

    # Current state.
    current_ip = Column(String(64), nullable=True)
    last_run_at = Column(DateTime(timezone=True), nullable=True)
    last_success = Column(Boolean, nullable=True)
    last_message = Column(Text, nullable=True)

    # Traffic alert settings.
    #   Lightsail: alert on total traffic (upload + download), in GB.
    #   EC2: alert on upload traffic, in GB.
    traffic_alert_enabled = Column(Boolean, nullable=False, default=False)
    # Traffic limit in GB.
    traffic_alert_limit_gb = Column(Float, nullable=True)
    # Whether the alert has already been triggered.
    traffic_alert_triggered = Column(Boolean, nullable=False, default=False)
    # Time of the last traffic check.
    traffic_last_check_at = Column(DateTime(timezone=True), nullable=True)

    created_at = Column(DateTime(timezone=True), nullable=False, default=utcnow)
    updated_at = Column(
        DateTime(timezone=True),
        nullable=False,
        default=utcnow,
        onupdate=utcnow,
    )
|
||
|
||
|
||
class Config(Base):
    """Global configuration table holding AWS / Cloudflare credentials."""

    __tablename__ = "configs"

    id = Column(Integer, primary_key=True)

    # AWS credentials (shared globally).
    aws_access_key = Column(String(255), nullable=True)
    aws_secret_key = Column(String(255), nullable=True)

    # Cloudflare credentials (shared globally).
    # cloudflare_auth_type is either "api_token" or "global_key".
    cloudflare_auth_type = Column(String(32), nullable=False, default="api_token")
    cf_api_token = Column(String(255), nullable=True)
    cf_email = Column(String(255), nullable=True)
    cf_api_key = Column(String(255), nullable=True)

    # Global settings.
    release_old_eip = Column(Boolean, nullable=False, default=True)

    # E-mail notification settings.
    smtp_host = Column(String(255), nullable=True)
    smtp_port = Column(Integer, nullable=True, default=587)
    smtp_user = Column(String(255), nullable=True)
    smtp_password = Column(String(255), nullable=True)
    smtp_use_tls = Column(Boolean, nullable=False, default=True)
    # Mailbox that receives alerts.
    alert_email = Column(String(255), nullable=True)

    created_at = Column(DateTime(timezone=True), nullable=False, default=utcnow)
    updated_at = Column(
        DateTime(timezone=True),
        nullable=False,
        default=utcnow,
        onupdate=utcnow,
    )
|
||
|
||
|
||
# Create every table registered on Base (runs at import time).
Base.metadata.create_all(bind=engine)
|
||
|
||
|
||
def get_session() -> Session:
    """Return a new session created by the module-level ``SessionLocal`` factory."""
    session = SessionLocal()
    return session
|
||
|
||
|
||
def ensure_singleton_config(session: Session) -> Config:
    """Return the ``Config`` row with the lowest id, creating one if none exists.

    A newly created row uses ``cloudflare_auth_type="api_token"`` and
    ``release_old_eip=True`` and is committed before being returned.
    """
    existing = session.query(Config).order_by(Config.id.asc()).first()
    if not existing:
        existing = Config(cloudflare_auth_type="api_token", release_old_eip=True)
        session.add(existing)
        session.commit()
    return existing
|
||
|
||
|
||
def ensure_admin_user(session: Session) -> User:
    """Return the user named "admin", creating and committing it if absent.

    The initial password of a newly created admin is taken from the
    ``ADMIN_PASSWORD`` environment variable when it is set and non-empty;
    otherwise the legacy built-in default is used so existing deployments
    keep working.
    """
    existing = session.query(User).filter_by(username="admin").first()
    if existing:
        return existing

    # SECURITY: a password hard-coded in source is known to anyone who can
    # read the repository. Prefer setting ADMIN_PASSWORD (e.g. in .env); the
    # literal below remains only as a backward-compatible fallback.
    initial_password = os.environ.get("ADMIN_PASSWORD") or "80012029Lz@"

    admin = User(
        username="admin",
        password_hash=hash_password(initial_password),
        is_admin=True,
    )
    session.add(admin)
    session.commit()
    return admin
|
||
|
||
|
||
def mask_secret(value: Optional[str], keep_start: int = 3, keep_end: int = 3) -> str:
    """Return ``value`` with its middle replaced by six bullet characters.

    Args:
        value: The secret to mask. ``None`` or an empty string yields ``""``.
        keep_start: Number of leading characters left visible. Negative
            values are treated as 0.
        keep_end: Number of trailing characters left visible. Negative
            values are treated as 0.

    Returns:
        ``""`` for a missing value; only the bullets when the stripped value
        is no longer than ``keep_start + keep_end + 2`` characters; otherwise
        the kept prefix, the bullets, and the kept suffix.
    """
    if not value:
        return ""

    # Negative counts would turn the slices below into "all but the last /
    # first N characters" and expose most of the secret.
    keep_start = max(keep_start, 0)
    keep_end = max(keep_end, 0)

    text = value.strip()
    if len(text) <= keep_start + keep_end + 2:
        return "••••••"

    head = text[:keep_start]
    # text[-0:] is the whole string, so a zero-length suffix must be handled
    # explicitly or the full secret would be appended to the mask.
    tail = text[-keep_end:] if keep_end else ""
    return f"{head}••••••{tail}"
|