Files
ProxyAuto/database.py

181 lines
6.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""数据库模型和配置管理"""
from __future__ import annotations

import hashlib
import hmac
import os
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional

from dotenv import load_dotenv
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, String, Text, Float, create_engine
from sqlalchemy.orm import Session, declarative_base, relationship, sessionmaker
# Load environment overrides from the .env file that sits next to this module.
APP_DIR = Path(__file__).resolve().parent
load_dotenv(APP_DIR / ".env")

# Directories for persistent data and log files; created eagerly on import.
DATA_DIR = APP_DIR / "data"
LOG_DIR = APP_DIR / "logs"
DATA_DIR.mkdir(exist_ok=True, parents=True)
LOG_DIR.mkdir(exist_ok=True, parents=True)

# Database settings.
# Expected MySQL URL shape:
#   mysql+pymysql://<user>:<password>@<host>:<port>/<database>?charset=utf8mb4
# Override it by setting DATABASE_URL (for example in the .env file).
# NOTE(review): the fallback URL embeds default credentials; set DATABASE_URL
# explicitly in any real deployment.
DEFAULT_MYSQL_URL = "mysql+pymysql://proxyauto:proxyauto@localhost:3306/proxyauto?charset=utf8mb4"
DATABASE_URL = os.getenv("DATABASE_URL", DEFAULT_MYSQL_URL)

# pool_pre_ping tests each pooled connection on checkout and transparently
# replaces dead ones; pool_recycle retires connections older than 3600 s.
engine = create_engine(
    DATABASE_URL,
    pool_recycle=3600,
    pool_pre_ping=True,
    echo=False,
)

# Session factory bound to the engine, plus the declarative base for models.
SessionLocal = sessionmaker(bind=engine)
Base = declarative_base()
def utcnow() -> datetime:
    """Return the current moment as a timezone-aware datetime in UTC.

    Used as the ``default``/``onupdate`` callable of the timestamp columns.
    """
    return datetime.now(tz=timezone.utc)
def hash_password(password: str) -> str:
    """Hash *password* and return the digest as a lowercase hex string.

    The password is UTF-8 encoded and run through a single, unsalted SHA-256
    pass, so equal passwords always produce equal hashes.

    NOTE(review): an unsalted fast hash is weak for password storage; a salted
    KDF (e.g. ``hashlib.scrypt`` / ``hashlib.pbkdf2_hmac``) would be stronger,
    but switching requires a migration path for hashes already stored.
    """
    hasher = hashlib.sha256()
    hasher.update(password.encode("utf-8"))
    return hasher.hexdigest()
class User(Base):
    """User account table (``users``)."""

    __tablename__ = "users"

    id = Column(Integer, primary_key=True)
    username = Column(String(64), nullable=False, unique=True)
    password_hash = Column(String(128), nullable=False)  # output of hash_password()
    is_admin = Column(Boolean, nullable=False, default=False)
    created_at = Column(DateTime(timezone=True), nullable=False, default=utcnow)
    updated_at = Column(DateTime(timezone=True), nullable=False, default=utcnow, onupdate=utcnow)

    def check_password(self, password: str) -> bool:
        """Return True if *password* matches this user's stored hash.

        The candidate is hashed with ``hash_password`` and compared using
        ``hmac.compare_digest``, a constant-time comparison, so response timing
        does not reveal how much of the hash matched. An unset hash never
        matches.
        """
        # Compare as bytes: compare_digest rejects non-ASCII str arguments.
        expected = (self.password_hash or "").encode("utf-8")
        candidate = hash_password(password).encode("utf-8")
        return hmac.compare_digest(expected, candidate)
class ProxyMachine(Base):
    """Proxy machine table - each machine has its own domain binding and change schedule.

    One row per AWS instance (``aws_service`` is ``"ec2"`` or ``"lightsail"``)
    together with its DNS record settings, automatic-change settings, the
    outcome of the last run, and traffic-alert settings.
    """
    __tablename__ = "proxy_machines"
    id = Column(Integer, primary_key=True)
    name = Column(String(128), nullable=False, unique=True)
    aws_service = Column(String(32), nullable=False, default="ec2") # ec2 | lightsail
    aws_region = Column(String(64), nullable=False, default="us-east-1")
    aws_instance_id = Column(String(64), nullable=False)
    note = Column(String(255), nullable=True)
    enabled = Column(Boolean, nullable=False, default=True)

    # Domain settings - bound independently per machine.
    # NOTE(review): the cf_* prefix presumably means a Cloudflare DNS record,
    # matching the global Cloudflare credentials in Config; confirm with callers.
    cf_zone_id = Column(String(64), nullable=True)
    cf_record_name = Column(String(255), nullable=True)
    cf_record_id = Column(String(64), nullable=True)
    cf_proxied = Column(Boolean, nullable=False, default=False)

    # Per-machine change schedule (interval defaults to 3600 seconds; the
    # automatic schedule is off by default).
    change_interval_seconds = Column(Integer, nullable=False, default=3600)
    auto_enabled = Column(Boolean, nullable=False, default=False)

    # Current state / result of the most recent run.
    current_ip = Column(String(64), nullable=True)
    last_run_at = Column(DateTime(timezone=True), nullable=True)
    last_success = Column(Boolean, nullable=True)
    last_message = Column(Text, nullable=True)

    # Traffic alert settings:
    #   Lightsail: alert on total traffic (upload + download), in GB.
    #   EC2: alert on upload traffic, in GB.
    traffic_alert_enabled = Column(Boolean, nullable=False, default=False)
    traffic_alert_limit_gb = Column(Float, nullable=True) # traffic limit, in GB
    traffic_alert_triggered = Column(Boolean, nullable=False, default=False) # whether the alert has already fired
    traffic_last_check_at = Column(DateTime(timezone=True), nullable=True) # time of the last traffic check

    created_at = Column(DateTime(timezone=True), nullable=False, default=utcnow)
    updated_at = Column(DateTime(timezone=True), nullable=False, default=utcnow, onupdate=utcnow)
class Config(Base):
    """Global configuration table - AWS/Cloudflare credentials and app-wide settings.

    Intended to hold a single row (see ``ensure_singleton_config``).
    NOTE(review): secrets below are plain String columns; no encryption is
    visible in this module.
    """
    __tablename__ = "configs"
    id = Column(Integer, primary_key=True)

    # AWS credentials (shared by all machines).
    aws_access_key = Column(String(255), nullable=True)
    aws_secret_key = Column(String(255), nullable=True)

    # Cloudflare credentials (shared by all machines).
    cloudflare_auth_type = Column(String(32), nullable=False, default="api_token") # api_token | global_key
    cf_api_token = Column(String(255), nullable=True)
    cf_email = Column(String(255), nullable=True)
    cf_api_key = Column(String(255), nullable=True)

    # Global settings.
    release_old_eip = Column(Boolean, nullable=False, default=True)

    # Email (SMTP) notification settings.
    smtp_host = Column(String(255), nullable=True)
    smtp_port = Column(Integer, nullable=True, default=587)
    smtp_user = Column(String(255), nullable=True)
    smtp_password = Column(String(255), nullable=True)
    smtp_use_tls = Column(Boolean, nullable=False, default=True)
    alert_email = Column(String(255), nullable=True) # address that receives alert emails

    created_at = Column(DateTime(timezone=True), nullable=False, default=utcnow)
    updated_at = Column(DateTime(timezone=True), nullable=False, default=utcnow, onupdate=utcnow)
# Create the tables for every model declared above, at import time. Tables
# that already exist are skipped (create_all checks first by default).
# NOTE(review): create_all does not add new columns to an existing table, so
# schema changes to existing tables need a separate migration.
Base.metadata.create_all(bind=engine)
def get_session() -> Session:
    """Open a new ORM session from ``SessionLocal`` and return it.

    The session is not closed here; the caller owns it and must close it.
    """
    new_session = SessionLocal()
    return new_session
def ensure_singleton_config(session: Session) -> Config:
    """Return the global ``Config`` row, inserting one if the table is empty.

    The row with the smallest ``id`` is treated as the singleton. When no row
    exists, a new one is created with ``cloudflare_auth_type="api_token"``
    and ``release_old_eip=True``, added to *session*, and committed.

    NOTE(review): no locking is done, so two sessions racing on an empty
    table could each insert a row; the lowest id wins on later reads.
    """
    existing = (
        session.query(Config)
        .order_by(Config.id.asc())
        .first()
    )
    if existing is not None:
        return existing

    fresh = Config(release_old_eip=True, cloudflare_auth_type="api_token")
    session.add(fresh)
    session.commit()
    return fresh
def ensure_admin_user(session: Session, password: Optional[str] = None) -> User:
    """Make sure a user named ``admin`` exists, creating it if necessary.

    An existing ``admin`` user is returned unchanged; its password is NOT
    reset. Otherwise a new admin user is inserted and the session committed.

    Args:
        session: Open ORM session used for the lookup and the insert.
        password: Initial password for a newly created admin. When omitted or
            empty, the ``ADMIN_PASSWORD`` environment variable (which may be
            supplied through the ``.env`` file) is used. If that is also
            unset, the legacy built-in default is used.

    Returns:
        The existing or newly created ``admin`` user.
    """
    admin = session.query(User).filter_by(username="admin").first()
    if admin:
        return admin
    # Resolution order: explicit argument, then environment, then the
    # historical hard-coded default (kept so existing setups behave the same).
    # NOTE(review): the fallback literal is committed to source control; set
    # ADMIN_PASSWORD in production or change the password after first login.
    initial_password = password or os.environ.get("ADMIN_PASSWORD") or "80012029Lz@"
    admin = User(
        username="admin",
        password_hash=hash_password(initial_password),
        is_admin=True,
    )
    session.add(admin)
    session.commit()
    return admin
def mask_secret(value: Optional[str], keep_start: int = 3, keep_end: int = 3) -> str:
    """Return a display-safe, partially masked version of a secret.

    Surrounding whitespace is stripped first. Then the first ``keep_start``
    and last ``keep_end`` characters are kept, and the middle is replaced by
    a fixed run of six bullets. A secret of at most
    ``keep_start + keep_end + 2`` characters, including a whitespace-only
    one, is hidden completely.

    Args:
        value: The secret to mask; ``None`` or ``""`` yields ``""``.
        keep_start: Leading characters to reveal; negative values count as 0.
        keep_end: Trailing characters to reveal; negative values count as 0.

    Returns:
        The masked string.
    """
    if not value:
        return ""
    # Clamp: a negative count would turn the slices below into "all but N
    # characters" and reveal most of the secret.
    head_len = max(keep_start, 0)
    tail_len = max(keep_end, 0)
    text = value.strip()
    if len(text) <= head_len + tail_len + 2:
        return "••••••"
    # Slice the tail by absolute index: text[-0:] is the WHOLE string, so a
    # negative-index slice would leak the full secret when keep_end == 0.
    tail = text[len(text) - tail_len:]
    return f"{text[:head_len]}••••••{tail}"