feat:自动更换ip+流量监控

This commit is contained in:
2026-01-07 17:19:53 +08:00
commit 035da64084
27 changed files with 6182 additions and 0 deletions

1
services/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""Services 包"""

104
services/aws_eip.py Normal file
View File

@@ -0,0 +1,104 @@
"""AWS EC2 Elastic IP 操作"""
from __future__ import annotations
import re
from dataclasses import dataclass
from typing import Any
import boto3
from .aws_region import normalize_aws_region
@dataclass(frozen=True)
class ElasticIpInfo:
allocation_id: str
association_id: str | None
public_ip: str | None
_INSTANCE_ID_RE = re.compile(r"^i-[0-9a-f]{8,17}$", re.IGNORECASE)
def is_valid_instance_id(value: str | None) -> bool:
if not value:
return False
return bool(_INSTANCE_ID_RE.match(value.strip()))
def create_ec2_client(
*,
region: str,
aws_access_key: str | None,
aws_secret_key: str | None,
):
kwargs: dict[str, Any] = {"region_name": normalize_aws_region(region)}
if aws_access_key and aws_secret_key:
kwargs["aws_access_key_id"] = aws_access_key
kwargs["aws_secret_access_key"] = aws_secret_key
return boto3.client("ec2", **kwargs)
def get_instance_elastic_ip(ec2_client, *, instance_id: str) -> ElasticIpInfo | None:
resp = ec2_client.describe_addresses(
Filters=[{"Name": "instance-id", "Values": [instance_id]}]
)
addresses = resp.get("Addresses") or []
if not addresses:
return None
addr = addresses[0]
allocation_id = addr.get("AllocationId")
if not allocation_id:
return None
return ElasticIpInfo(
allocation_id=allocation_id,
association_id=addr.get("AssociationId"),
public_ip=addr.get("PublicIp"),
)
def disassociate_elastic_ip(ec2_client, *, association_id: str) -> None:
ec2_client.disassociate_address(AssociationId=association_id)
def release_elastic_ip(ec2_client, *, allocation_id: str) -> None:
ec2_client.release_address(AllocationId=allocation_id)
def allocate_elastic_ip(ec2_client) -> ElasticIpInfo:
resp = ec2_client.allocate_address(Domain="vpc")
return ElasticIpInfo(
allocation_id=resp["AllocationId"],
association_id=None,
public_ip=resp.get("PublicIp"),
)
def associate_elastic_ip(ec2_client, *, instance_id: str, allocation_id: str) -> str:
resp = ec2_client.associate_address(InstanceId=instance_id, AllocationId=allocation_id)
return resp.get("AssociationId") or ""
def rotate_elastic_ip(
ec2_client,
*,
instance_id: str,
release_old: bool,
) -> dict[str, str | None]:
current = get_instance_elastic_ip(ec2_client, instance_id=instance_id)
if current and current.association_id:
disassociate_elastic_ip(ec2_client, association_id=current.association_id)
if current and release_old:
release_elastic_ip(ec2_client, allocation_id=current.allocation_id)
new_eip = allocate_elastic_ip(ec2_client)
associate_elastic_ip(ec2_client, instance_id=instance_id, allocation_id=new_eip.allocation_id)
return {
"public_ip": new_eip.public_ip,
"allocation_id": new_eip.allocation_id,
}

24
services/aws_region.py Normal file
View File

@@ -0,0 +1,24 @@
"""AWS 区域验证"""
from __future__ import annotations
import re
_AWS_REGION_RE = re.compile(r"^[a-z]{2}(-gov)?-[a-z]+-\d+$")
_AWS_AZ_RE = re.compile(r"^([a-z]{2}(-gov)?-[a-z]+-\d+)[a-z]$")
def normalize_aws_region(value: str | None) -> str:
raw = (value or "").strip().lower()
if not raw:
return ""
m = _AWS_AZ_RE.match(raw)
if m:
return m.group(1)
return raw
def is_valid_aws_region(value: str | None) -> bool:
if not value:
return False
return bool(_AWS_REGION_RE.match(value.strip().lower()))

161
services/cloudflare_dns.py Normal file
View File

@@ -0,0 +1,161 @@
"""Cloudflare DNS API 封装"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
import requests
API_BASE = "https://api.cloudflare.com/client/v4"
@dataclass(frozen=True)
class CloudflareAuth:
auth_type: str # api_token | global_key
api_token: str | None = None
email: str | None = None
api_key: str | None = None
def headers(self) -> dict[str, str]:
if self.auth_type == "api_token":
if not self.api_token:
raise ValueError("Cloudflare API Token 不能为空")
return {"Authorization": f"Bearer {self.api_token}"}
if self.auth_type == "global_key":
if not self.email or not self.api_key:
raise ValueError("Cloudflare Email / Global API Key 不能为空")
return {"X-Auth-Email": self.email, "X-Auth-Key": self.api_key}
raise ValueError("cloudflare_auth_type 只能是 api_token 或 global_key")
def _check_response(resp: requests.Response) -> dict[str, Any]:
try:
data = resp.json()
except Exception:
resp.raise_for_status()
raise
if not resp.ok or not data.get("success"):
errors = data.get("errors") or []
message = errors[0].get("message") if errors else f"Cloudflare API 请求失败: {resp.status_code}"
raise RuntimeError(message)
return data
def _request(
method: str,
path: str,
*,
auth: CloudflareAuth,
params: dict[str, Any] | None = None,
json: dict[str, Any] | None = None,
timeout_seconds: int = 15,
) -> dict[str, Any]:
url = f"{API_BASE}{path}"
headers = {"Content-Type": "application/json", **auth.headers()}
resp = requests.request(
method,
url,
headers=headers,
params=params,
json=json,
timeout=timeout_seconds,
)
return _check_response(resp)
def find_a_record(
*,
zone_id: str,
record_name: str,
auth: CloudflareAuth,
) -> dict[str, Any] | None:
data = _request(
"GET",
f"/zones/{zone_id}/dns_records",
auth=auth,
params={"type": "A", "name": record_name},
)
result = data.get("result") or []
return result[0] if result else None
def update_a_record(
*,
zone_id: str,
record_id: str,
record_name: str,
ip: str,
proxied: bool,
auth: CloudflareAuth,
) -> dict[str, Any]:
data = _request(
"PUT",
f"/zones/{zone_id}/dns_records/{record_id}",
auth=auth,
json={"type": "A", "name": record_name, "content": ip, "proxied": proxied},
)
return data["result"]
def create_a_record(
*,
zone_id: str,
record_name: str,
ip: str,
proxied: bool,
auth: CloudflareAuth,
) -> dict[str, Any]:
data = _request(
"POST",
f"/zones/{zone_id}/dns_records",
auth=auth,
json={"type": "A", "name": record_name, "content": ip, "proxied": proxied},
)
return data["result"]
def upsert_a_record(
*,
zone_id: str,
record_name: str,
ip: str,
proxied: bool,
record_id: str | None,
auth: CloudflareAuth,
) -> dict[str, Any]:
if record_id:
try:
return update_a_record(
zone_id=zone_id,
record_id=record_id,
record_name=record_name,
ip=ip,
proxied=proxied,
auth=auth,
)
except Exception:
pass
record = find_a_record(zone_id=zone_id, record_name=record_name, auth=auth)
if record:
return update_a_record(
zone_id=zone_id,
record_id=record["id"],
record_name=record_name,
ip=ip,
proxied=proxied,
auth=auth,
)
return create_a_record(
zone_id=zone_id,
record_name=record_name,
ip=ip,
proxied=proxied,
auth=auth,
)

197
services/email_service.py Normal file
View File

@@ -0,0 +1,197 @@
"""邮件发送服务"""
from __future__ import annotations
import logging
import smtplib
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
from typing import Optional
logger = logging.getLogger(__name__)
def send_email(
smtp_host: str,
smtp_port: int,
smtp_user: str,
smtp_password: str,
use_tls: bool,
to_email: str,
subject: str,
body: str,
html_body: Optional[str] = None,
) -> dict:
"""
发送邮件
Args:
smtp_host: SMTP 服务器地址
smtp_port: SMTP 端口
smtp_user: SMTP 用户名
smtp_password: SMTP 密码
use_tls: 是否使用 TLS
to_email: 收件人邮箱
subject: 邮件主题
body: 纯文本内容
html_body: HTML 内容(可选)
Returns:
{"ok": True/False, "message": "..."}
"""
try:
# 创建邮件
msg = MIMEMultipart("alternative")
msg["Subject"] = subject
msg["From"] = smtp_user
msg["To"] = to_email
# 添加纯文本内容
part1 = MIMEText(body, "plain", "utf-8")
msg.attach(part1)
# 添加 HTML 内容
if html_body:
part2 = MIMEText(html_body, "html", "utf-8")
msg.attach(part2)
# 连接 SMTP 服务器
if use_tls:
server = smtplib.SMTP(smtp_host, smtp_port, timeout=30)
server.starttls()
else:
server = smtplib.SMTP_SSL(smtp_host, smtp_port, timeout=30)
server.login(smtp_user, smtp_password)
server.sendmail(smtp_user, [to_email], msg.as_string())
server.quit()
logger.info("Email sent successfully to %s: %s", to_email, subject)
return {"ok": True, "message": "邮件发送成功"}
except smtplib.SMTPAuthenticationError as e:
logger.error("SMTP authentication failed: %s", e)
return {"ok": False, "message": "SMTP 认证失败,请检查用户名和密码"}
except smtplib.SMTPConnectError as e:
logger.error("SMTP connection failed: %s", e)
return {"ok": False, "message": "无法连接到 SMTP 服务器"}
except Exception as e:
logger.exception("Failed to send email")
return {"ok": False, "message": str(e)}
def send_traffic_alert_email(
smtp_host: str,
smtp_port: int,
smtp_user: str,
smtp_password: str,
use_tls: bool,
to_email: str,
machine_name: str,
aws_service: str,
current_traffic_gb: float,
limit_gb: float,
traffic_type: str, # "total" or "upload"
) -> dict:
"""
发送流量预警邮件
Args:
machine_name: 机器名称
aws_service: 服务类型ec2/lightsail
current_traffic_gb: 当前流量GB
limit_gb: 限制流量GB
traffic_type: 流量类型total=总流量, upload=上传流量)
"""
if traffic_type == "total":
traffic_desc = "总流量(上传+下载)"
else:
traffic_desc = "上传流量"
subject = f"[ProxyAuto] 流量预警 - {machine_name}"
body = f"""
流量预警通知
机器名称: {machine_name}
服务类型: {aws_service.upper()}
预警类型: {traffic_desc}超限
当前{traffic_desc}: {current_traffic_gb:.2f} GB
设定限制: {limit_gb:.2f} GB
超出: {current_traffic_gb - limit_gb:.2f} GB
系统已自动暂停该机器的 IP 自动更换任务。
---
ProxyAuto Pro 自动通知
"""
html_body = f"""
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<style>
body {{ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif; line-height: 1.6; color: #333; }}
.container {{ max-width: 600px; margin: 0 auto; padding: 20px; }}
.header {{ background: linear-gradient(135deg, #1F6BFF, #3D8BFF); color: white; padding: 20px; border-radius: 8px 8px 0 0; }}
.content {{ background: #f9f9f9; padding: 20px; border: 1px solid #e0e0e0; border-top: none; border-radius: 0 0 8px 8px; }}
.alert {{ background: #FEE2E2; border-left: 4px solid #EF4444; padding: 15px; margin: 15px 0; border-radius: 4px; }}
.info-row {{ display: flex; justify-content: space-between; padding: 10px 0; border-bottom: 1px solid #e0e0e0; }}
.info-label {{ color: #666; }}
.info-value {{ font-weight: 600; }}
.footer {{ text-align: center; padding: 15px; color: #999; font-size: 12px; }}
</style>
</head>
<body>
<div class="container">
<div class="header">
<h2 style="margin: 0;">流量预警通知</h2>
</div>
<div class="content">
<div class="alert">
<strong>警告:</strong>{traffic_desc}已超出设定限制!
</div>
<div class="info-row">
<span class="info-label">机器名称</span>
<span class="info-value">{machine_name}</span>
</div>
<div class="info-row">
<span class="info-label">服务类型</span>
<span class="info-value">{aws_service.upper()}</span>
</div>
<div class="info-row">
<span class="info-label">当前{traffic_desc}</span>
<span class="info-value" style="color: #EF4444;">{current_traffic_gb:.2f} GB</span>
</div>
<div class="info-row">
<span class="info-label">设定限制</span>
<span class="info-value">{limit_gb:.2f} GB</span>
</div>
<div class="info-row">
<span class="info-label">超出流量</span>
<span class="info-value" style="color: #EF4444;">{current_traffic_gb - limit_gb:.2f} GB</span>
</div>
<p style="margin-top: 20px; color: #666;">
系统已自动暂停该机器的 IP 自动更换任务。请登录控制台查看详情。
</p>
</div>
<div class="footer">
ProxyAuto Pro 自动通知
</div>
</div>
</body>
</html>
"""
return send_email(
smtp_host=smtp_host,
smtp_port=smtp_port,
smtp_user=smtp_user,
smtp_password=smtp_password,
use_tls=use_tls,
to_email=to_email,
subject=subject,
body=body,
html_body=html_body,
)

152
services/ip_change.py Normal file
View File

@@ -0,0 +1,152 @@
"""IP 更换核心逻辑"""
from __future__ import annotations
import logging
from datetime import datetime, timezone
from database import Config, ProxyMachine, get_session, ensure_singleton_config
from services.aws_eip import create_ec2_client, is_valid_instance_id, rotate_elastic_ip
from services.lightsail_static_ip import (
create_lightsail_client,
is_valid_lightsail_instance_name,
rotate_lightsail_static_ip,
)
from services.cloudflare_dns import CloudflareAuth, upsert_a_record
logger = logging.getLogger(__name__)
def _get_cf_auth(config: Config) -> CloudflareAuth:
auth_type = (config.cloudflare_auth_type or "").strip() or "api_token"
if auth_type == "api_token":
return CloudflareAuth(auth_type="api_token", api_token=config.cf_api_token)
return CloudflareAuth(
auth_type="global_key",
email=config.cf_email,
api_key=config.cf_api_key,
)
def change_ip_for_machine(machine: ProxyMachine, config: Config) -> dict[str, str]:
"""为指定机器更换 IP 并更新 DNS"""
if not machine.enabled:
raise ValueError(f"机器 {machine.name} 已被禁用")
aws_service = (machine.aws_service or "ec2").strip().lower()
# 执行 IP 轮换
if aws_service == "lightsail":
if not is_valid_lightsail_instance_name(machine.aws_instance_id):
raise ValueError(
f"Lightsail 实例名格式不正确:{machine.aws_instance_id}"
)
lightsail = create_lightsail_client(
region=machine.aws_region,
aws_access_key=config.aws_access_key,
aws_secret_key=config.aws_secret_key,
)
logger.info(
"Rotating Lightsail Static IP for instance %s (%s)",
machine.aws_instance_id,
machine.name,
)
aws_result = rotate_lightsail_static_ip(
lightsail,
instance_name=machine.aws_instance_id,
release_old=bool(config.release_old_eip),
)
else:
if not is_valid_instance_id(machine.aws_instance_id):
raise ValueError(
f"EC2 Instance ID 格式不正确:{machine.aws_instance_id}(应类似 i-xxxxxxxxxxxxxxxxx"
)
ec2 = create_ec2_client(
region=machine.aws_region,
aws_access_key=config.aws_access_key,
aws_secret_key=config.aws_secret_key,
)
logger.info(
"Rotating Elastic IP for instance %s (%s)",
machine.aws_instance_id,
machine.name,
)
aws_result = rotate_elastic_ip(
ec2,
instance_id=machine.aws_instance_id,
release_old=bool(config.release_old_eip),
)
public_ip = aws_result.get("public_ip")
if not public_ip:
raise RuntimeError("AWS 未返回新的 Public IP")
message = f"IP 已更换为 {public_ip}"
# 如果机器配置了域名,更新 DNS
if machine.cf_zone_id and machine.cf_record_name:
logger.info("Updating Cloudflare A record %s -> %s", machine.cf_record_name, public_ip)
record = upsert_a_record(
zone_id=machine.cf_zone_id,
record_name=machine.cf_record_name,
ip=public_ip,
proxied=bool(machine.cf_proxied),
record_id=machine.cf_record_id,
auth=_get_cf_auth(config),
)
record_id = record.get("id")
if record_id:
machine.cf_record_id = record_id
message = f"已更新 {machine.cf_record_name} -> {public_ip}"
return {
"public_ip": public_ip,
"message": message,
}
def run_ip_change_for_machine(machine_id: int) -> dict:
"""为指定机器执行一次 IP 更换"""
session = get_session()
try:
machine = session.query(ProxyMachine).filter_by(id=machine_id).first()
if not machine:
return {"ok": False, "message": "机器不存在"}
config = ensure_singleton_config(session)
started_at = datetime.now(timezone.utc)
try:
result = change_ip_for_machine(machine, config)
machine.last_run_at = started_at
machine.last_success = True
machine.current_ip = result.get("public_ip")
machine.last_message = result.get("message") or "OK"
session.add(machine)
session.commit()
logger.info("IP change success for %s: %s", machine.name, machine.current_ip)
return {"ok": True, "machine_name": machine.name, **result}
except Exception as exc:
logger.exception("IP change failed for %s", machine.name)
machine.last_run_at = started_at
machine.last_success = False
machine.last_message = str(exc)
session.add(machine)
session.commit()
return {"ok": False, "machine_name": machine.name, "message": str(exc)}
finally:
session.close()
# 兼容旧接口
def run_ip_change() -> dict:
"""执行一次 IP 更换(兼容旧接口)"""
session = get_session()
try:
# 获取第一台启用的机器
machine = session.query(ProxyMachine).filter_by(enabled=True).first()
if not machine:
return {"ok": False, "message": "没有可用的机器"}
return run_ip_change_for_machine(machine.id)
finally:
session.close()

View File

@@ -0,0 +1,132 @@
"""AWS Lightsail Static IP 操作"""
from __future__ import annotations
import re
import time
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any
import boto3
from botocore.exceptions import ClientError
from .aws_region import normalize_aws_region
@dataclass(frozen=True)
class LightsailStaticIp:
name: str
ip_address: str | None
attached_to: str | None
is_attached: bool
_INSTANCE_NAME_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9_.-]{0,127}$")
def is_valid_lightsail_instance_name(value: str | None) -> bool:
if not value:
return False
return bool(_INSTANCE_NAME_RE.match(value.strip()))
def create_lightsail_client(
*,
region: str,
aws_access_key: str | None,
aws_secret_key: str | None,
):
kwargs: dict[str, Any] = {"region_name": normalize_aws_region(region)}
if aws_access_key and aws_secret_key:
kwargs["aws_access_key_id"] = aws_access_key
kwargs["aws_secret_access_key"] = aws_secret_key
return boto3.client("lightsail", **kwargs)
def _list_static_ips(client) -> list[LightsailStaticIp]:
resp = client.get_static_ips()
items = resp.get("staticIps") or []
result: list[LightsailStaticIp] = []
for item in items:
result.append(
LightsailStaticIp(
name=item.get("name") or "",
ip_address=item.get("ipAddress"),
attached_to=item.get("attachedTo"),
is_attached=bool(item.get("isAttached")),
)
)
return [ip for ip in result if ip.name]
def _get_attached_static_ip(client, *, instance_name: str) -> LightsailStaticIp | None:
for ip in _list_static_ips(client):
if ip.is_attached and ip.attached_to == instance_name:
return ip
return None
def _generate_static_ip_name(instance_name: str) -> str:
safe = re.sub(r"[^A-Za-z0-9-]+", "-", instance_name).strip("-").lower() or "instance"
safe = safe[:24]
ts = datetime.now(timezone.utc).strftime("%Y%m%d%H%M%S")
return f"proxyauto-{safe}-{ts}"
def rotate_lightsail_static_ip(
client,
*,
instance_name: str,
release_old: bool,
) -> dict[str, str]:
current = _get_attached_static_ip(client, instance_name=instance_name)
if current:
client.detach_static_ip(staticIpName=current.name)
if release_old:
for attempt in range(8):
try:
client.release_static_ip(staticIpName=current.name)
break
except ClientError as exc:
if attempt == 7:
raise
code = exc.response.get("Error", {}).get("Code")
if code in {"OperationFailureException", "InvalidInputException"}:
time.sleep(1)
continue
raise
new_name = _generate_static_ip_name(instance_name)
for attempt in range(5):
try:
client.allocate_static_ip(staticIpName=new_name)
break
except ClientError as exc:
if attempt == 4:
raise
message = (exc.response.get("Error", {}).get("Message") or "").lower()
if "already exists" in message or "alreadyexist" in message:
new_name = f"{new_name}-{attempt + 1}"
continue
raise
client.attach_static_ip(staticIpName=new_name, instanceName=instance_name)
public_ip: str | None = None
for _ in range(20):
try:
resp = client.get_static_ip(staticIpName=new_name)
static_ip = resp.get("staticIp") or {}
public_ip = static_ip.get("ipAddress")
if public_ip:
break
except ClientError:
pass
time.sleep(1)
if not public_ip:
raise RuntimeError("Lightsail 未返回新的 Static IP 地址,请稍后重试")
return {"public_ip": public_ip, "static_ip_name": new_name}

124
services/scheduler.py Normal file
View File

@@ -0,0 +1,124 @@
"""后台定时任务调度器 - 支持每台机器独立调度"""
from __future__ import annotations
import logging
import threading
from datetime import datetime, timezone
from typing import Any
from zoneinfo import ZoneInfo
from apscheduler.schedulers.background import BackgroundScheduler
logger = logging.getLogger(__name__)
_scheduler: BackgroundScheduler | None = None
_lock = threading.Lock()
# 使用上海时区
SHANGHAI_TZ = ZoneInfo("Asia/Shanghai")
def _get_job_id(machine_id: int) -> str:
return f"ip_change_machine_{machine_id}"
def get_scheduler() -> BackgroundScheduler:
global _scheduler
with _lock:
if _scheduler is None:
_scheduler = BackgroundScheduler(timezone=SHANGHAI_TZ)
_scheduler.start()
return _scheduler
def get_scheduler_status() -> dict[str, Any]:
"""获取整体调度器状态"""
scheduler = get_scheduler()
jobs = scheduler.get_jobs()
return {
"running": len(jobs) > 0,
"job_count": len(jobs),
}
def get_machine_scheduler_status(machine_id: int) -> dict[str, Any]:
"""获取指定机器的调度状态"""
scheduler = get_scheduler()
job_id = _get_job_id(machine_id)
job = scheduler.get_job(job_id)
next_run_time = None
if job and job.next_run_time:
# 确保转换为上海时区
next_run_time = job.next_run_time.astimezone(SHANGHAI_TZ).isoformat()
return {
"running": bool(job),
"next_run_time": next_run_time,
"job_id": job_id,
}
def _job_func(machine_id: int) -> None:
"""定时任务执行函数"""
from services.ip_change import run_ip_change_for_machine
run_ip_change_for_machine(machine_id)
def start_machine_auto(machine_id: int, interval_seconds: int) -> None:
"""启动指定机器的自动任务"""
scheduler = get_scheduler()
job_id = _get_job_id(machine_id)
interval = max(10, interval_seconds)
scheduler.add_job(
_job_func,
"interval",
args=[machine_id],
seconds=interval,
id=job_id,
replace_existing=True,
max_instances=1,
coalesce=True,
misfire_grace_time=30,
)
logger.info("Machine %d auto job scheduled: every %ss", machine_id, interval)
def stop_machine_auto(machine_id: int) -> None:
"""停止指定机器的自动任务"""
scheduler = get_scheduler()
job_id = _get_job_id(machine_id)
job = scheduler.get_job(job_id)
if job:
scheduler.remove_job(job_id)
logger.info("Machine %d auto job stopped", machine_id)
def update_all_schedulers() -> None:
"""根据数据库配置更新所有机器的调度器状态"""
from database import get_session, ProxyMachine
session = get_session()
try:
machines = session.query(ProxyMachine).filter_by(auto_enabled=True).all()
for machine in machines:
start_machine_auto(machine.id, machine.change_interval_seconds)
finally:
session.close()
# 兼容旧接口
def start_auto(interval_seconds: int) -> None:
"""兼容旧接口"""
pass
def stop_auto() -> None:
"""兼容旧接口"""
pass
def update_scheduler_from_config() -> None:
"""兼容旧接口"""
update_all_schedulers()

193
services/traffic_alert.py Normal file
View File

@@ -0,0 +1,193 @@
"""流量预警检查服务"""
from __future__ import annotations
import logging
from datetime import datetime
from zoneinfo import ZoneInfo
from database import Config, ProxyMachine, get_session, ensure_singleton_config
from services.traffic_monitor import get_current_month_traffic
from services.email_service import send_traffic_alert_email
from services.scheduler import stop_machine_auto
logger = logging.getLogger(__name__)
SHANGHAI_TZ = ZoneInfo("Asia/Shanghai")
# 字节转GB
BYTES_PER_GB = 1024 * 1024 * 1024
def check_machine_traffic_alert(machine: ProxyMachine, config: Config) -> dict:
"""
检查单台机器的流量预警
Returns:
{"triggered": bool, "message": str}
"""
if not machine.traffic_alert_enabled or not machine.traffic_alert_limit_gb:
return {"triggered": False, "message": "未启用流量预警"}
if not config.aws_access_key or not config.aws_secret_key:
return {"triggered": False, "message": "AWS 凭证未配置"}
try:
# 获取当月流量
traffic_data = get_current_month_traffic(
aws_service=machine.aws_service,
region=machine.aws_region,
instance_id=machine.aws_instance_id,
aws_access_key=config.aws_access_key,
aws_secret_key=config.aws_secret_key,
)
if not traffic_data.get("ok"):
return {"triggered": False, "message": f"获取流量失败: {traffic_data.get('message')}"}
# 根据服务类型判断预警条件
if machine.aws_service == "lightsail":
# Lightsail: 总流量预警
current_bytes = traffic_data["total"]
traffic_type = "total"
else:
# EC2: 上传流量预警
current_bytes = traffic_data["network_out"]
traffic_type = "upload"
current_gb = current_bytes / BYTES_PER_GB
limit_gb = machine.traffic_alert_limit_gb
if current_gb >= limit_gb:
return {
"triggered": True,
"message": f"流量超限: {current_gb:.2f} GB / {limit_gb:.2f} GB",
"current_gb": current_gb,
"limit_gb": limit_gb,
"traffic_type": traffic_type,
}
return {
"triggered": False,
"message": f"流量正常: {current_gb:.2f} GB / {limit_gb:.2f} GB",
"current_gb": current_gb,
"limit_gb": limit_gb,
}
except Exception as e:
logger.exception("Failed to check traffic for %s", machine.name)
return {"triggered": False, "message": f"检查失败: {str(e)}"}
def handle_traffic_alert(machine: ProxyMachine, config: Config, alert_result: dict) -> None:
"""
处理流量预警:暂停机器 + 发送邮件
"""
session = get_session()
try:
# 重新获取机器对象(确保在当前 session 中)
db_machine = session.query(ProxyMachine).filter_by(id=machine.id).first()
if not db_machine:
return
# 标记已触发预警
db_machine.traffic_alert_triggered = True
db_machine.traffic_last_check_at = datetime.now(SHANGHAI_TZ)
# 暂停自动任务
if db_machine.auto_enabled:
db_machine.auto_enabled = False
stop_machine_auto(db_machine.id)
logger.warning("Machine %s auto job stopped due to traffic alert", db_machine.name)
session.commit()
# 发送邮件通知
if config.smtp_host and config.alert_email:
send_traffic_alert_email(
smtp_host=config.smtp_host,
smtp_port=config.smtp_port or 587,
smtp_user=config.smtp_user,
smtp_password=config.smtp_password,
use_tls=config.smtp_use_tls,
to_email=config.alert_email,
machine_name=db_machine.name,
aws_service=db_machine.aws_service,
current_traffic_gb=alert_result["current_gb"],
limit_gb=alert_result["limit_gb"],
traffic_type=alert_result["traffic_type"],
)
else:
logger.warning("Email not sent: SMTP or alert email not configured")
finally:
session.close()
def check_all_traffic_alerts() -> dict:
"""
检查所有机器的流量预警
Returns:
{"checked": int, "triggered": int, "results": [...]}
"""
session = get_session()
try:
config = ensure_singleton_config(session)
# 获取所有启用了流量预警且未触发的机器
machines = session.query(ProxyMachine).filter(
ProxyMachine.traffic_alert_enabled == True,
ProxyMachine.traffic_alert_triggered == False,
ProxyMachine.enabled == True,
).all()
results = []
triggered_count = 0
for machine in machines:
result = check_machine_traffic_alert(machine, config)
results.append({
"machine_id": machine.id,
"machine_name": machine.name,
**result,
})
if result.get("triggered"):
triggered_count += 1
handle_traffic_alert(machine, config, result)
logger.warning("Traffic alert triggered for %s: %s", machine.name, result["message"])
# 更新检查时间
machine.traffic_last_check_at = datetime.now(SHANGHAI_TZ)
session.commit()
return {
"ok": True,
"checked": len(machines),
"triggered": triggered_count,
"results": results,
}
except Exception as e:
logger.exception("Failed to check traffic alerts")
return {"ok": False, "message": str(e), "checked": 0, "triggered": 0, "results": []}
finally:
session.close()
def reset_machine_alert(machine_id: int) -> dict:
"""
重置机器的预警状态(手动解除预警)
"""
session = get_session()
try:
machine = session.query(ProxyMachine).filter_by(id=machine_id).first()
if not machine:
return {"ok": False, "message": "机器不存在"}
machine.traffic_alert_triggered = False
session.commit()
return {"ok": True, "message": f"已重置 {machine.name} 的预警状态"}
finally:
session.close()

338
services/traffic_monitor.py Normal file
View File

@@ -0,0 +1,338 @@
"""流量监控服务 - 获取 EC2/Lightsail 流量数据"""
from __future__ import annotations
import logging
from datetime import datetime, timedelta
from typing import Any
from zoneinfo import ZoneInfo
import boto3
logger = logging.getLogger(__name__)
SHANGHAI_TZ = ZoneInfo("Asia/Shanghai")
def create_cloudwatch_client(region: str, aws_access_key: str, aws_secret_key: str):
"""创建 CloudWatch 客户端"""
return boto3.client(
"cloudwatch",
region_name=region,
aws_access_key_id=aws_access_key,
aws_secret_access_key=aws_secret_key,
)
def get_ec2_traffic(
region: str,
instance_id: str,
aws_access_key: str,
aws_secret_key: str,
start_time: datetime,
end_time: datetime,
period: int = 3600,
) -> dict[str, Any]:
"""
获取 EC2 实例的流量数据
Args:
region: AWS 区域
instance_id: EC2 实例 ID
aws_access_key: AWS Access Key
aws_secret_key: AWS Secret Key
start_time: 开始时间
end_time: 结束时间
period: 数据点间隔默认1小时
Returns:
{
"network_in": 下载流量(字节),
"network_out": 上传流量(字节),
"data_points": 详细数据点列表
}
"""
try:
cloudwatch = create_cloudwatch_client(region, aws_access_key, aws_secret_key)
# 获取 NetworkIn下载
network_in_response = cloudwatch.get_metric_statistics(
Namespace="AWS/EC2",
MetricName="NetworkIn",
Dimensions=[{"Name": "InstanceId", "Value": instance_id}],
StartTime=start_time,
EndTime=end_time,
Period=period,
Statistics=["Sum"],
)
# 获取 NetworkOut上传
network_out_response = cloudwatch.get_metric_statistics(
Namespace="AWS/EC2",
MetricName="NetworkOut",
Dimensions=[{"Name": "InstanceId", "Value": instance_id}],
StartTime=start_time,
EndTime=end_time,
Period=period,
Statistics=["Sum"],
)
# 计算总流量
network_in_total = sum(dp["Sum"] for dp in network_in_response.get("Datapoints", []))
network_out_total = sum(dp["Sum"] for dp in network_out_response.get("Datapoints", []))
# 合并数据点用于图表
data_points = []
in_points = {dp["Timestamp"]: dp["Sum"] for dp in network_in_response.get("Datapoints", [])}
out_points = {dp["Timestamp"]: dp["Sum"] for dp in network_out_response.get("Datapoints", [])}
all_timestamps = sorted(set(in_points.keys()) | set(out_points.keys()))
for ts in all_timestamps:
data_points.append({
"timestamp": ts.astimezone(SHANGHAI_TZ).isoformat(),
"network_in": in_points.get(ts, 0),
"network_out": out_points.get(ts, 0),
})
return {
"ok": True,
"network_in": network_in_total,
"network_out": network_out_total,
"total": network_in_total + network_out_total,
"data_points": data_points,
}
except Exception as e:
logger.exception("Failed to get EC2 traffic for %s", instance_id)
return {"ok": False, "message": str(e), "network_in": 0, "network_out": 0, "total": 0, "data_points": []}
def get_lightsail_traffic(
region: str,
instance_name: str,
aws_access_key: str,
aws_secret_key: str,
start_time: datetime,
end_time: datetime,
period: int = 3600,
) -> dict[str, Any]:
"""
获取 Lightsail 实例的流量数据
Args:
region: AWS 区域
instance_name: Lightsail 实例名称
aws_access_key: AWS Access Key
aws_secret_key: AWS Secret Key
start_time: 开始时间
end_time: 结束时间
period: 数据点间隔默认1小时
Returns:
{
"network_in": 下载流量(字节),
"network_out": 上传流量(字节),
"data_points": 详细数据点列表
}
"""
try:
lightsail = boto3.client(
"lightsail",
region_name=region,
aws_access_key_id=aws_access_key,
aws_secret_access_key=aws_secret_key,
)
# 获取 NetworkIn下载
network_in_response = lightsail.get_instance_metric_data(
instanceName=instance_name,
metricName="NetworkIn",
period=period,
startTime=start_time,
endTime=end_time,
unit="Bytes",
statistics=["Sum"],
)
# 获取 NetworkOut上传
network_out_response = lightsail.get_instance_metric_data(
instanceName=instance_name,
metricName="NetworkOut",
period=period,
startTime=start_time,
endTime=end_time,
unit="Bytes",
statistics=["Sum"],
)
# 计算总流量
in_data = network_in_response.get("metricData", [])
out_data = network_out_response.get("metricData", [])
network_in_total = sum(dp.get("sum", 0) for dp in in_data)
network_out_total = sum(dp.get("sum", 0) for dp in out_data)
# 合并数据点
data_points = []
in_points = {dp["timestamp"]: dp.get("sum", 0) for dp in in_data}
out_points = {dp["timestamp"]: dp.get("sum", 0) for dp in out_data}
all_timestamps = sorted(set(in_points.keys()) | set(out_points.keys()))
for ts in all_timestamps:
data_points.append({
"timestamp": ts.astimezone(SHANGHAI_TZ).isoformat(),
"network_in": in_points.get(ts, 0),
"network_out": out_points.get(ts, 0),
})
return {
"ok": True,
"network_in": network_in_total,
"network_out": network_out_total,
"total": network_in_total + network_out_total,
"data_points": data_points,
}
except Exception as e:
logger.exception("Failed to get Lightsail traffic for %s", instance_name)
return {"ok": False, "message": str(e), "network_in": 0, "network_out": 0, "total": 0, "data_points": []}
def get_machine_traffic(
aws_service: str,
region: str,
instance_id: str,
aws_access_key: str,
aws_secret_key: str,
start_time: datetime,
end_time: datetime,
period: int = 3600,
) -> dict[str, Any]:
"""
统一接口:根据服务类型获取流量数据
"""
if aws_service == "lightsail":
return get_lightsail_traffic(
region=region,
instance_name=instance_id,
aws_access_key=aws_access_key,
aws_secret_key=aws_secret_key,
start_time=start_time,
end_time=end_time,
period=period,
)
else:
return get_ec2_traffic(
region=region,
instance_id=instance_id,
aws_access_key=aws_access_key,
aws_secret_key=aws_secret_key,
start_time=start_time,
end_time=end_time,
period=period,
)
def get_current_month_traffic(
aws_service: str,
region: str,
instance_id: str,
aws_access_key: str,
aws_secret_key: str,
) -> dict[str, Any]:
"""获取当月流量数据"""
now = datetime.now(SHANGHAI_TZ)
start_of_month = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
return get_machine_traffic(
aws_service=aws_service,
region=region,
instance_id=instance_id,
aws_access_key=aws_access_key,
aws_secret_key=aws_secret_key,
start_time=start_of_month,
end_time=now,
period=3600, # 每小时一个数据点
)
def get_current_day_traffic(
aws_service: str,
region: str,
instance_id: str,
aws_access_key: str,
aws_secret_key: str,
) -> dict[str, Any]:
"""获取当日流量数据"""
now = datetime.now(SHANGHAI_TZ)
start_of_day = now.replace(hour=0, minute=0, second=0, microsecond=0)
return get_machine_traffic(
aws_service=aws_service,
region=region,
instance_id=instance_id,
aws_access_key=aws_access_key,
aws_secret_key=aws_secret_key,
start_time=start_of_day,
end_time=now,
period=300, # 每5分钟一个数据点
)
def get_all_time_traffic(
aws_service: str,
region: str,
instance_id: str,
aws_access_key: str,
aws_secret_key: str,
created_at: datetime = None,
) -> dict[str, Any]:
"""
获取建站至今的总流量数据
注意: CloudWatch 数据保留期限有限:
- 小于60秒的数据点保留3小时
- 60秒(1分钟)的数据点保留15天
- 300秒(5分钟)的数据点保留63天
- 3600秒(1小时)的数据点保留455天(约15个月)
因此这里只能获取最近约15个月的数据
"""
now = datetime.now(SHANGHAI_TZ)
# 如果提供了创建时间使用它否则使用15个月前
if created_at:
# 确保时区正确
if created_at.tzinfo is None:
start_time = created_at.replace(tzinfo=SHANGHAI_TZ)
else:
start_time = created_at.astimezone(SHANGHAI_TZ)
else:
# CloudWatch 最多保留约15个月的小时级数据
start_time = now - timedelta(days=455)
return get_machine_traffic(
aws_service=aws_service,
region=region,
instance_id=instance_id,
aws_access_key=aws_access_key,
aws_secret_key=aws_secret_key,
start_time=start_time,
end_time=now,
period=86400, # 每天一个数据点
)
def format_bytes(bytes_value: float) -> str:
"""格式化字节数为可读格式"""
if bytes_value < 0:
return "0 B"
units = ["B", "KB", "MB", "GB", "TB"]
unit_index = 0
value = float(bytes_value)
while value >= 1024 and unit_index < len(units) - 1:
value /= 1024
unit_index += 1
if unit_index == 0:
return f"{int(value)} {units[unit_index]}"
return f"{value:.2f} {units[unit_index]}"