pendulum-transformer-server/chaos_pdl/pendulum_transformer.py
2025-09-24 11:35:48 +08:00

449 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
train_ts_transformer_multi.py
- 同时读取多个 CSV列名需一致若存在 frame_index 会按其排序并移除)
- Transformer Encoder 进行多步预测(默认输入 180 帧,预测 45 帧)
- 训练/验证集“均匀交错抽样”划分,可设置 val_gap 减少邻近泄漏
- 标准化仅用训练窗口拟合,避免验证泄漏
- 损失:按预测步 t 递减权重的加权 RMSE支持指数/线性衰减)
- 保存 best_model.pt 与 scaler.npz
"""
import argparse
import os
import math
import random
from typing import List, Tuple
import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
def set_seed(seed: int = 42):
    """Seed the Python, NumPy, and torch (CPU + all CUDA) RNGs for reproducibility."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
class WindowedTSDataset(Dataset):
    """Sliding-window dataset over a pre-normalized multivariate series.

    Each item is a (past, future) pair: ``seq_len`` input frames followed by
    ``horizon`` target frames, sliced from the global series at a precomputed
    start offset (multi-file offsets are already baked into ``starts``).
    """

    def __init__(self, X_norm: np.ndarray, starts: np.ndarray, seq_len: int, horizon: int):
        self.X = X_norm.astype(np.float32)      # [T_total, F] normalized series
        self.starts = starts.astype(np.int64)   # global window start offsets
        self.seq_len = seq_len
        self.horizon = horizon

    def __len__(self):
        return len(self.starts)

    def __getitem__(self, idx):
        start = int(self.starts[idx])
        stop = start + self.seq_len
        past = self.X[start:stop]                   # [L, F]
        future = self.X[stop:stop + self.horizon]   # [H, F]
        return torch.from_numpy(past), torch.from_numpy(future)
class PositionalEncoding(nn.Module):
    """Additive sinusoidal position encoding for [S, N, E] sequences.

    Precomputes the standard sin/cos table (Vaswani et al.) up to ``max_len``
    positions and adds the first ``S`` rows to the input, followed by dropout.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 10000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        positions = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)  # [max_len, 1]
        freqs = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32)
            * (-math.log(10000.0) / d_model)
        )
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(positions * freqs)
        table[:, 1::2] = torch.cos(positions * freqs)
        # Shape [max_len, 1, d_model] so it broadcasts over the batch dimension.
        self.register_buffer("pe", table.unsqueeze(1))

    def forward(self, x: torch.Tensor):
        """x: [S, N, E] -> x plus the first S position encodings, with dropout."""
        return self.dropout(x + self.pe[: x.size(0)])
class TimeSeriesTransformer(nn.Module):
    """Encoder-only Transformer forecaster.

    Maps an input window [B, L, F] to a multi-step forecast [B, H, F] by
    encoding the sequence and projecting the final time step's embedding
    to all H * F output values at once.
    """

    def __init__(self, in_features: int, d_model: int, nhead: int, num_layers: int,
                 dim_feedforward: int, dropout: float, horizon: int):
        super().__init__()
        self.in_features = in_features
        self.horizon = horizon
        self.input_proj = nn.Linear(in_features, d_model)
        self.pos_enc = PositionalEncoding(d_model, dropout=dropout)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                batch_first=False,  # encoder consumes [S, N, E]
                norm_first=True,    # pre-norm layout
            ),
            num_layers=num_layers,
        )
        # One linear readout produces the whole horizon in a single shot.
        self.head = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, horizon * in_features),
        )

    def forward(self, x: torch.Tensor):
        """x: [B, L, F] -> forecast [B, H, F]."""
        batch = x.size(0)
        hidden = self.pos_enc(self.input_proj(x).transpose(0, 1))  # [L, B, d]
        encoded = self.encoder(hidden)                              # [L, B, d]
        # The last position's representation summarizes the window.
        flat = self.head(encoded[-1])                               # [B, H*F]
        return flat.view(batch, self.horizon, self.in_features)
def build_windows_indices(T: int, seq_len: int, horizon: int, stride: int) -> np.ndarray:
    """Return start offsets of every complete (seq_len + horizon) window in a
    length-T segment, stepping by ``stride``; empty array when T is too short."""
    window = seq_len + horizon
    if T < window:
        return np.array([], dtype=np.int64)
    return np.arange(0, T - window + 1, stride, dtype=np.int64)
def interleaved_split(starts: np.ndarray, val_ratio: float, val_gap: int = 0) -> Tuple[np.ndarray, np.ndarray]:
    """Split window starts into train/val by taking every k-th window for validation.

    k = round(1 / val_ratio), at least 2, so validation windows are spread
    uniformly along the sequence. ``val_gap`` additionally removes training
    windows within that many positions of any validation window to reduce
    leakage from overlapping windows.

    Args:
        starts: ordered global window start offsets.
        val_ratio: target validation fraction, strictly in (0, 1).
        val_gap: neighbors (per side) of each val index dropped from training;
            clipped so the exclusion zones cannot cover everything.

    Returns:
        (train_starts, val_starts), both subsets of ``starts``.
    """
    assert 0.0 < val_ratio < 1.0
    total = len(starts)
    if total == 0:
        return np.array([], dtype=np.int64), np.array([], dtype=np.int64)
    every = max(int(round(1.0 / val_ratio)), 2)
    # Clip val_gap: if 2*val_gap+1 >= every, the exclusion zones would cover
    # every index and leave the training set empty.
    safe_gap = int(min(max((every - 2) // 2, 0), val_gap))
    if safe_gap != val_gap:
        print(f"[warn] val_gap={val_gap} 过大,按交错周期 every={every} 自动裁剪为 {safe_gap},以避免训练集为空。")
    idx_all = np.arange(total, dtype=np.int64)
    phase = idx_all % every
    val_mask = phase == 0
    train_mask = ~val_mask
    if safe_gap > 0:
        # Vectorized neighbor exclusion (replaces the original O(total*gap)
        # Python loop over a set). Index t is within safe_gap of the val index
        # below it iff t % every <= safe_gap; of the val index above it iff
        # every - t % every <= safe_gap AND that upper index actually exists
        # (guard keeps indices near the end from being dropped for a phantom
        # val index beyond the array).
        upper_dist = every - phase
        near_below = phase <= safe_gap
        near_above = (upper_dist <= safe_gap) & (idx_all + upper_dist < total)
        train_mask &= ~(near_below | near_above)
    return starts[idx_all[train_mask]], starts[idx_all[val_mask]]
def fit_scaler_from_windows(X: np.ndarray, train_starts: np.ndarray, seq_len: int) -> Tuple[np.ndarray, np.ndarray]:
    """Fit per-feature mean/std on training input windows only (no val leakage).

    Stacks the ``seq_len`` rows of every training window and computes
    column-wise statistics; std is floored at 1e-6 so later division is safe.

    Raises:
        ValueError: when ``train_starts`` is empty.
    """
    if len(train_starts) == 0:
        raise ValueError("训练窗口为空,无法拟合标准化。")
    stacked = np.concatenate([X[s:s + seq_len] for s in train_starts], axis=0)  # [N_train*L, F]
    mean = stacked.mean(axis=0)
    std = np.maximum(stacked.std(axis=0), 1e-6)
    return mean, std
class WeightedRMSELoss(nn.Module):
    """Weighted RMSE over the forecast horizon.

    Forecast step t receives a decaying weight w_t:
      - "exp":    w_t = gamma^t
      - "linear": w_t = max(1 - alpha * t, min_w)
    Weights are rescaled to sum to ``horizon`` (average weight 1) so the loss
    scale stays stable. Returns a scalar RMSE suitable for backprop.
    """

    def __init__(self, horizon: int, mode: str = "exp",
                 gamma: float = 0.97, alpha: float = 0.0, min_w: float = 0.05):
        super().__init__()
        self.horizon = horizon
        self.mode = mode
        self.gamma = gamma
        self.alpha = alpha
        self.min_w = min_w
        self.register_buffer("w", self._make_weights())

    def _make_weights(self) -> torch.Tensor:
        steps = range(self.horizon)
        if self.mode == "exp":
            raw = [self.gamma ** t for t in steps]
        elif self.mode == "linear":
            raw = [max(1.0 - self.alpha * t, self.min_w) for t in steps]
        else:
            raise ValueError("mode 必须是 'exp''linear'")
        w = torch.tensor(raw, dtype=torch.float32)
        # Rescale so the weights sum to H (mean weight 1) for numeric stability.
        return w * (self.horizon / (w.sum() + 1e-12))  # [H]

    def forward(self, pred: torch.Tensor, target: torch.Tensor):
        # pred, target: [B, H, F]
        if pred.shape != target.shape:
            raise ValueError("pred/target 形状不一致")
        B, H, F = pred.shape
        if H != self.horizon:
            raise ValueError("horizon 不匹配")
        weights = self.w.view(1, H, 1)                       # broadcast over B and F
        weighted_sq_err = ((pred - target) ** 2) * weights   # [B, H, F]
        mse = weighted_sq_err.sum() / (weights.sum() * B * F)  # weighted MSE
        return torch.sqrt(mse + 1e-12)                         # RMSE
@torch.no_grad()
def evaluate_losses(model, loader, device, weighted_rmse: WeightedRMSELoss):
    """Sample-weighted average of weighted RMSE and plain MSE over ``loader``.

    Puts the model in eval mode, runs every batch on ``device``, and returns
    (avg_weighted_rmse, avg_mse); both averages are weighted by batch size.
    """
    model.eval()
    plain_mse = nn.MSELoss()
    wrmse_sum = 0.0
    mse_sum = 0.0
    n_samples = 0
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)
        batch = inputs.size(0)
        wrmse_sum += weighted_rmse(outputs, targets).item() * batch
        mse_sum += plain_mse(outputs, targets).item() * batch
        n_samples += batch
    denom = max(1, n_samples)
    return wrmse_sum / denom, mse_sum / denom
@torch.no_grad()
def mae_in_original_scale(model, loader, device, mean: np.ndarray, std: np.ndarray):
    """Mean absolute error after undoing normalization (x * std + mean).

    Averages the per-batch MAE weighted by batch size; ``mean``/``std`` are
    per-feature vectors broadcast over [B, H, F].
    """
    model.eval()
    shift = torch.from_numpy(mean).to(device).view(1, 1, -1)
    scale = torch.from_numpy(std).to(device).view(1, 1, -1)
    mae_sum = 0.0
    n_samples = 0
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        preds = model(inputs)
        # De-normalize both sides before measuring error in original units.
        abs_err = torch.abs((preds * scale + shift) - (targets * scale + shift))
        mae_sum += abs_err.mean().item() * inputs.size(0)
        n_samples += inputs.size(0)
    return mae_sum / max(1, n_samples)
def read_and_align_csvs(paths: List[str]) -> pd.DataFrame:
    """Load several CSVs, keep only numeric columns, and row-concatenate them.

    If a ``frame_index`` column exists it is used to sort rows and is then
    dropped. All files must expose the same numeric column names in the same
    order; otherwise a ValueError is raised.
    """
    frames = []
    ref_cols = None
    for path in paths:
        raw = pd.read_csv(path)
        if "frame_index" in raw.columns:
            raw = raw.sort_values("frame_index").drop(columns=["frame_index"])
        numeric = raw.select_dtypes(include=[np.number]).copy()
        if numeric.shape[1] == 0:
            raise ValueError(f"{path} 中没有数值列。")
        cols = list(numeric.columns)
        if ref_cols is None:
            ref_cols = cols
        elif cols != ref_cols:
            raise ValueError(
                f"列不一致:\n 参考列: {ref_cols}\n {path} 列: {cols}\n"
                "请确保所有 CSV 的数值列名与顺序一致。"
            )
        frames.append(numeric.reset_index(drop=True))
    return pd.concat(frames, axis=0, ignore_index=True)
def build_multi_file_starts(lengths: List[int], seq_len: int, horizon: int, stride: int) -> np.ndarray:
    """Window start offsets for several concatenated segments.

    Generates per-segment starts with ``build_windows_indices`` and shifts
    each set by its segment's cumulative offset, so no window ever crosses a
    segment (file) boundary. Returns an empty int64 array when no segment can
    hold a full window.
    """
    pieces = []
    offset = 0
    for segment_len in lengths:
        local = build_windows_indices(segment_len, seq_len, horizon, stride)
        if local.size:
            pieces.append(local + offset)
        offset += segment_len
    if not pieces:
        return np.array([], dtype=np.int64)
    return np.concatenate(pieces, axis=0)
def main():
    """CLI entry point: load CSVs, build windows, train the Transformer
    forecaster with weighted-RMSE loss and early stopping, save the best
    checkpoint plus the scaler, then print metrics and a sample forecast."""
    ap = argparse.ArgumentParser()
    # Windowing defaults: 180-frame input, 45-frame forecast.
    ap.add_argument("--csv_paths", type=str, nargs="+", help="多个 CSV 路径(空格分隔)")
    ap.add_argument("--csv_path", type=str, default="", help="兼容单文件旧参数;若提供也会被加入")
    ap.add_argument("--seq_len", type=int, default=180, help="输入窗口长度 L")
    ap.add_argument("--pred_horizon", type=int, default=45, help="预测步数 H")
    ap.add_argument("--stride", type=int, default=1, help="窗口步长")
    ap.add_argument("--val_ratio", type=float, default=0.2, help="验证集比例(交错采样)")
    ap.add_argument("--val_gap", type=int, default=1, help="从训练中剔除与验证窗口相邻的窗口数量")
    ap.add_argument("--batch_size", type=int, default=512)
    ap.add_argument("--epochs", type=int, default=70)
    ap.add_argument("--lr", type=float, default=3e-4)
    ap.add_argument("--d_model", type=int, default=256)
    ap.add_argument("--nhead", type=int, default=8)
    ap.add_argument("--num_layers", type=int, default=4)
    ap.add_argument("--dim_ff", type=int, default=1024)
    ap.add_argument("--dropout", type=float, default=0.1)
    ap.add_argument("--grad_clip", type=float, default=1.0)
    ap.add_argument("--patience", type=int, default=20)
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--save_dir", type=str, default="./outputs/chaos")
    # Weighted-RMSE loss configuration.
    ap.add_argument("--horizon_weight_mode", type=str, default="exp", choices=["exp", "linear"],
                    help="预测步权重模式exp=指数衰减linear=线性衰减")
    ap.add_argument("--horizon_gamma", type=float, default=0.99,
                    help="指数衰减系数 gamma (0<gamma<=1),越小靠后权重越低")
    ap.add_argument("--horizon_alpha", type=float, default=0.02,
                    help="线性衰减步长 alpha仅当 mode=linear 生效)")
    ap.add_argument("--horizon_min_w", type=float, default=0.05,
                    help="线性衰减的最小权重下限(避免为 0")
    args = ap.parse_args()
    os.makedirs(args.save_dir, exist_ok=True)
    set_seed(args.seed)
    # BUGFIX: the device was hard-coded to "cuda", which crashes on CPU-only
    # machines; fall back to CPU when CUDA is unavailable.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    # ---- Load all CSVs and align their numeric columns ----
    paths: List[str] = []
    if args.csv_paths:
        paths.extend(args.csv_paths)
    if args.csv_path:
        paths.append(args.csv_path)
    if not paths:
        raise ValueError("请通过 --csv_paths 指定至少一个 CSV。")
    big_df = read_and_align_csvs(paths)
    values = big_df.values.astype(np.float32)  # [T_total, F]
    T_total, F = values.shape
    print(f"Loaded {len(paths)} files. Total T={T_total}, F={F}. Columns={list(big_df.columns)}")

    # Per-file row counts, used to keep windows from crossing file boundaries.
    # Sorting by frame_index / dropping columns does not change the row count,
    # so a plain re-read of each CSV is sufficient here (the previous version
    # also re-sorted each file just to count rows).
    lengths = [len(pd.read_csv(p)) for p in paths]

    # ---- Global window start offsets (never crossing a segment) ----
    starts_all = build_multi_file_starts(lengths, args.seq_len, args.pred_horizon, args.stride)
    print(f"Windows total: {len(starts_all)}")
    if len(starts_all) == 0:
        raise ValueError("无可用窗口,请检查 seq_len/pred_horizon/stride 与数据长度。")

    # ---- Interleaved train/val split ----
    train_starts, val_starts = interleaved_split(starts_all, args.val_ratio, val_gap=args.val_gap)
    print(f"Split -> train: {len(train_starts)}, val: {len(val_starts)}")

    # ---- Fit normalization on training windows only (no val leakage) ----
    mean, std = fit_scaler_from_windows(values, train_starts, args.seq_len)
    values_norm = (values - mean) / std

    # ---- DataLoaders ----
    train_ds = WindowedTSDataset(values_norm, train_starts, args.seq_len, args.pred_horizon)
    val_ds = WindowedTSDataset(values_norm, val_starts, args.seq_len, args.pred_horizon)
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False)

    # ---- Model / optimizer / loss ----
    model = TimeSeriesTransformer(
        in_features=F, d_model=args.d_model, nhead=args.nhead,
        num_layers=args.num_layers, dim_feedforward=args.dim_ff,
        dropout=args.dropout, horizon=args.pred_horizon
    ).to(device)
    weighted_rmse = WeightedRMSELoss(
        horizon=args.pred_horizon,
        mode=args.horizon_weight_mode,
        gamma=args.horizon_gamma,
        alpha=args.horizon_alpha,
        min_w=args.horizon_min_w
    ).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

    # ---- Training: weighted RMSE is both the loss and the early-stop metric ----
    best_val = float("inf")
    best_path = os.path.join(args.save_dir, "best_model.pt")
    no_improve = 0
    for epoch in range(1, args.epochs + 1):
        model.train()
        total_train_loss = 0.0
        n_train = 0
        for xb, yb in train_loader:
            xb = xb.to(device); yb = yb.to(device)
            optimizer.zero_grad(set_to_none=True)
            pred = model(xb)                  # [B, H, F]
            loss = weighted_rmse(pred, yb)    # weighted RMSE
            loss.backward()
            if args.grad_clip is not None:
                nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            optimizer.step()
            total_train_loss += loss.item() * xb.size(0)
            n_train += xb.size(0)
        # Validation + LR schedule on the validation weighted RMSE.
        val_wrmse, val_mse = evaluate_losses(model, val_loader, device, weighted_rmse)
        train_wrmse = total_train_loss / max(1, n_train)
        scheduler.step(val_wrmse)
        print(f"Epoch {epoch:03d}/{args.epochs} | train wRMSE={train_wrmse:.6f} | "
              f"val wRMSE={val_wrmse:.6f} | val MSE={val_mse:.6f}")
        if val_wrmse < best_val - 1e-8:
            best_val = val_wrmse
            no_improve = 0
            # Save model weights plus enough config to rebuild it for inference.
            torch.save(
                {
                    "model_state": model.state_dict(),
                    "config": {
                        "in_features": F, "d_model": args.d_model, "nhead": args.nhead,
                        "num_layers": args.num_layers, "dim_ff": args.dim_ff,
                        "dropout": args.dropout, "horizon": args.pred_horizon
                    },
                    "loss_cfg": {
                        "mode": args.horizon_weight_mode,
                        "gamma": args.horizon_gamma,
                        "alpha": args.horizon_alpha,
                        "min_w": args.horizon_min_w
                    }
                },
                best_path
            )
        else:
            no_improve += 1
            if no_improve >= args.patience:
                print(f"Early stopping at epoch {epoch}. Best val wRMSE={best_val:.6f}")
                break

    # ---- Persist the scaler next to the checkpoint ----
    np.savez(os.path.join(args.save_dir, "scaler.npz"),
             mean=mean, std=std, columns=np.array(big_df.columns))

    # ---- Reload best model, report val metrics and original-scale MAE ----
    ckpt = torch.load(best_path, map_location=device)
    model.load_state_dict(ckpt["model_state"])
    val_wrmse, val_mse = evaluate_losses(model, val_loader, device, weighted_rmse)
    val_mae_orig = mae_in_original_scale(model, val_loader, device, mean, std)
    print(f"[Best] Val wRMSE={val_wrmse:.6f} | Val MSE={val_mse:.6f} | Val MAE(original)={val_mae_orig:.6f}")

    # ---- Sample inference: forecast from the last window (de-normalized) ----
    model.eval()  # ensure dropout is off for the demo forecast
    with torch.no_grad():
        last_start = int(starts_all[-1])
        x_last = values_norm[last_start: last_start + args.seq_len]
        y_true = values[last_start + args.seq_len: last_start + args.seq_len + args.pred_horizon]
        pred = model(torch.from_numpy(x_last).unsqueeze(0).to(device)).cpu().numpy()[0]
        pred_denorm = pred * std + mean
    print("\n=== 示例(最后一个窗口) ===")
    for h in range(args.pred_horizon):
        line = f"+{h+1:02d}: pred={np.round(pred_denorm[h], 6)}"
        if h < y_true.shape[0]:
            line += f" | true={np.round(y_true[h], 6)}"
        print(line)
    print(f"\n保存:\n- 模型: {best_path}\n- 标准化: {os.path.join(args.save_dir, 'scaler.npz')}")
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()