# -*- coding: utf-8 -*-
"""
train_ts_transformer_multi.py

- 同时读取多个 CSV(列名需一致;若存在 frame_index 会按其排序并移除)
- Transformer Encoder 进行多步预测(默认输入 180 帧,预测 45 帧)
- 训练/验证集“均匀交错抽样”划分,可设置 val_gap 减少邻近泄漏
- 标准化仅用训练窗口拟合,避免验证泄漏
- 损失:按预测步 t 递减权重的加权 RMSE(支持指数/线性衰减)
- 保存 best_model.pt 与 scaler.npz
"""
||
import argparse
|
||
import os
|
||
import math
|
||
import random
|
||
from typing import List, Tuple
|
||
|
||
import numpy as np
|
||
import pandas as pd
|
||
import torch
|
||
from torch import nn
|
||
from torch.utils.data import Dataset, DataLoader
|
||
|
||
|
||
def set_seed(seed: int = 42):
    """Seed every RNG this script relies on (python, numpy, torch CPU + all GPUs)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
||
|
||
|
||
class WindowedTSDataset(Dataset):
    """Sliding-window dataset over a normalized multivariate series.

    Each item is an (input, target) pair: the input covers ``seq_len``
    consecutive rows starting at a precomputed global start index, and
    the target covers the ``horizon`` rows immediately after it.
    """

    def __init__(self, X_norm: np.ndarray, starts: np.ndarray, seq_len: int, horizon: int):
        # Feature matrix [T_total, F], stored as float32 for torch.
        self.X = X_norm.astype(np.float32)
        # Global window start indices (multi-file offsets already applied).
        self.starts = starts.astype(np.int64)
        self.seq_len = seq_len
        self.horizon = horizon

    def __len__(self):
        return self.starts.shape[0]

    def __getitem__(self, idx):
        begin = int(self.starts[idx])
        split = begin + self.seq_len
        window = self.X[begin:split]                    # [L, F]
        target = self.X[split:split + self.horizon]     # [H, F]
        return torch.from_numpy(window), torch.from_numpy(target)
||
|
||
|
||
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding (Vaswani et al.) for sequence-first
    tensors, followed by dropout.

    The sin/cos table is precomputed for ``max_len`` positions and added
    to the input in ``forward``.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 10000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        positions = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)
        # Geometric frequency progression: exp(-log(10000) * 2i / d_model).
        freqs = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32)
            * (-math.log(10000.0) / d_model)
        )
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(positions * freqs)
        table[:, 1::2] = torch.cos(positions * freqs)
        # Shape [max_len, 1, d_model] so it broadcasts over the batch dim.
        self.register_buffer("pe", table.unsqueeze(1))

    def forward(self, x: torch.Tensor):
        # x: [S, N, E] (sequence-first layout)
        return self.dropout(x + self.pe[: x.size(0)])
|
||
|
||
class TimeSeriesTransformer(nn.Module):
    """Encoder-only Transformer mapping an input window [B, L, F] to a
    direct multi-step forecast [B, H, F].

    The representation of the final encoder position summarizes the
    window and is projected to all ``horizon`` steps in one shot.
    """

    def __init__(self, in_features: int, d_model: int, nhead: int, num_layers: int,
                 dim_feedforward: int, dropout: float, horizon: int):
        super().__init__()
        self.in_features = in_features
        self.horizon = horizon
        self.input_proj = nn.Linear(in_features, d_model)
        self.pos_enc = PositionalEncoding(d_model, dropout=dropout)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                batch_first=False,   # sequence-first to match PositionalEncoding
                norm_first=True,     # pre-norm for more stable training
            ),
            num_layers=num_layers,
        )
        self.head = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, horizon * in_features),
        )

    def forward(self, x: torch.Tensor):
        # x: [B, L, F] -> sequence-first [L, B, d] for the encoder stack.
        batch = x.shape[0]
        encoded = self.encoder(self.pos_enc(self.input_proj(x).transpose(0, 1)))
        summary = encoded[-1]  # [B, d]: hidden state at the last time step
        return self.head(summary).view(batch, self.horizon, self.in_features)
||
|
||
|
||
def build_windows_indices(T: int, seq_len: int, horizon: int, stride: int) -> np.ndarray:
    """Return valid window start indices for a segment of length ``T``.

    A start ``s`` is valid when ``s + seq_len + horizon <= T``, i.e. both
    the input window and its forecast target fit inside the segment.
    Returns an empty int64 array when the segment is too short.
    """
    last_valid = T - seq_len - horizon
    if last_valid >= 0:
        return np.arange(0, last_valid + 1, stride, dtype=np.int64)
    return np.array([], dtype=np.int64)
||
|
||
|
||
def interleaved_split(starts: np.ndarray, val_ratio: float, val_gap: int = 0) -> Tuple[np.ndarray, np.ndarray]:
    """Split window starts into train/val by uniform interleaving.

    Every ``every``-th window (``every ~= 1/val_ratio``) becomes a
    validation sample; optionally the ``val_gap`` windows on each side
    of every validation index are dropped from training to reduce
    leakage from overlapping windows.
    """
    assert 0.0 < val_ratio < 1.0
    total = len(starts)
    if total == 0:
        return np.array([], dtype=np.int64), np.array([], dtype=np.int64)
    every = max(int(round(1.0 / val_ratio)), 2)
    # Clamp val_gap: if 2*val_gap+1 >= every, the exclusion zones would
    # cover every index and leave the training set empty.
    safe_gap = int(min(max((every - 2) // 2, 0), val_gap))
    if safe_gap != val_gap:
        print(f"[warn] val_gap={val_gap} 过大,按交错周期 every={every} 自动裁剪为 {safe_gap},以避免训练集为空。")
    positions = np.arange(total, dtype=np.int64)
    is_val = (positions % every) == 0
    keep_for_train = ~is_val
    if safe_gap > 0:
        excluded = set()
        for v in positions[is_val]:
            # Drop the [v - gap, v + gap] neighborhood from training.
            excluded.update(range(max(0, v - safe_gap), min(total - 1, v + safe_gap) + 1))
        if excluded:
            keep_for_train[list(excluded)] = False
    return starts[positions[keep_for_train]], starts[positions[is_val]]
||
|
||
|
||
def fit_scaler_from_windows(X: np.ndarray, train_starts: np.ndarray, seq_len: int) -> Tuple[np.ndarray, np.ndarray]:
    """Fit per-feature mean/std using only the training input windows.

    Restricting the statistics to rows covered by training windows keeps
    validation data out of the normalization. Stds below 1e-6 are
    clamped to 1e-6 to avoid division blow-ups on constant features.

    Raises:
        ValueError: if ``train_starts`` is empty.
    """
    windows = [X[s:s + seq_len] for s in train_starts]
    if not windows:
        raise ValueError("训练窗口为空,无法拟合标准化。")
    stacked = np.concatenate(windows, axis=0)  # [N_train*L, F]
    mean = stacked.mean(axis=0)
    std = np.maximum(stacked.std(axis=0), 1e-6)
    return mean, std
|
||
|
||
|
||
class WeightedRMSELoss(nn.Module):
    """RMSE with a decaying weight per forecast step.

    Step ``t`` receives weight ``w_t``:
      - mode "exp":    w_t = gamma ** t
      - mode "linear": w_t = max(1 - alpha * t, min_w)
    Weights are rescaled so they sum to ``horizon`` (mean weight 1),
    keeping the loss scale stable across configurations. ``forward``
    returns a differentiable scalar RMSE.
    """

    def __init__(self, horizon: int, mode: str = "exp",
                 gamma: float = 0.97, alpha: float = 0.0, min_w: float = 0.05):
        super().__init__()
        self.horizon = horizon
        self.mode = mode
        self.gamma = gamma
        self.alpha = alpha
        self.min_w = min_w
        # Buffer so the weights follow the module across devices.
        self.register_buffer("w", self._make_weights())

    def _make_weights(self) -> torch.Tensor:
        steps = range(self.horizon)
        if self.mode == "exp":
            raw = [self.gamma ** t for t in steps]
        elif self.mode == "linear":
            raw = [max(1.0 - self.alpha * t, self.min_w) for t in steps]
        else:
            raise ValueError("mode 必须是 'exp' 或 'linear'")
        weights = torch.tensor(raw, dtype=torch.float32)
        # Normalize to mean 1 (sum == horizon) for numerical stability.
        return weights * (self.horizon / (weights.sum() + 1e-12))  # [H]

    def forward(self, pred: torch.Tensor, target: torch.Tensor):
        # pred, target: [B, H, F]
        if pred.shape != target.shape:
            raise ValueError("pred/target 形状不一致")
        B, H, F = pred.shape
        if H != self.horizon:
            raise ValueError("horizon 不匹配")
        weights = self.w.view(1, H, 1)                  # broadcast over B, F
        weighted_sq = weights * (pred - target) ** 2    # [B, H, F]
        mse = weighted_sq.sum() / (weights.sum() * B * F)
        return torch.sqrt(mse + 1e-12)
||
|
||
|
||
@torch.no_grad()
def evaluate_losses(model, loader, device, weighted_rmse: WeightedRMSELoss):
    """Average weighted-RMSE and plain MSE of ``model`` over ``loader``.

    Metrics are sample-weighted means across all batches, computed in
    normalized space. Returns ``(weighted_rmse, mse)``.
    """
    model.eval()
    plain_mse = nn.MSELoss()
    sum_wrmse = 0.0
    sum_mse = 0.0
    n_samples = 0
    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)
        pred = model(xb)
        bsz = xb.size(0)
        sum_wrmse += weighted_rmse(pred, yb).item() * bsz
        sum_mse += plain_mse(pred, yb).item() * bsz
        n_samples += bsz
    denom = max(1, n_samples)
    return sum_wrmse / denom, sum_mse / denom
|
||
|
||
|
||
@torch.no_grad()
def mae_in_original_scale(model, loader, device, mean: np.ndarray, std: np.ndarray):
    """Mean absolute error of ``model`` over ``loader`` after undoing the
    (mean, std) normalization, i.e. in the data's original units."""
    model.eval()
    mu = torch.from_numpy(mean).to(device).view(1, 1, -1)
    sigma = torch.from_numpy(std).to(device).view(1, 1, -1)
    running_mae = 0.0
    seen = 0
    for xb, yb in loader:
        xb = xb.to(device)
        yb = yb.to(device)
        pred = model(xb)
        # De-normalize both sides before taking |error|.
        batch_mae = torch.abs((pred * sigma + mu) - (yb * sigma + mu)).mean()
        running_mae += batch_mae.item() * xb.size(0)
        seen += xb.size(0)
    return running_mae / max(1, seen)
|
||
|
||
|
||
def read_and_align_csvs(paths: List[str]) -> pd.DataFrame:
    """Read several CSVs, verify their numeric columns match, and stack them.

    For each file: if a ``frame_index`` column exists, rows are sorted by
    it and the column is dropped; then only numeric columns are kept.
    All files must expose the same numeric column names in the same
    order (the first file defines the reference). Returns the row-wise
    concatenation of all files (numeric columns only).

    Raises:
        ValueError: if a file has no numeric columns or its columns
            differ from the reference.
    """
    reference_cols = None
    frames = []
    for p in paths:
        raw = pd.read_csv(p)
        if "frame_index" in raw.columns:
            raw = raw.sort_values("frame_index").drop(columns=["frame_index"])
        numeric = raw.select_dtypes(include=[np.number]).copy()
        if numeric.shape[1] == 0:
            raise ValueError(f"{p} 中没有数值列。")
        cols = list(numeric.columns)
        if reference_cols is None:
            reference_cols = cols
        elif cols != reference_cols:
            raise ValueError(
                f"列不一致:\n 参考列: {reference_cols}\n {p} 列: {cols}\n"
                "请确保所有 CSV 的数值列名与顺序一致。"
            )
        frames.append(numeric.reset_index(drop=True))
    return pd.concat(frames, axis=0, ignore_index=True)
|
||
|
||
|
||
def build_multi_file_starts(lengths: List[int], seq_len: int, horizon: int, stride: int) -> np.ndarray:
    """Window starts for several concatenated segments, never crossing a
    segment boundary.

    Each segment of length ``T_i`` contributes its own local starts,
    shifted by the cumulative length of the segments before it; the
    shifted indices are merged into one array.
    """
    collected = []
    offset = 0
    for seg_len in lengths:
        local = build_windows_indices(seg_len, seq_len, horizon, stride)
        if local.size:
            collected.append(local + offset)
        offset += seg_len
    if not collected:
        return np.array([], dtype=np.int64)
    return np.concatenate(collected, axis=0)
|
||
|
||
|
||
def main():
    """Training entry point.

    Parses CLI arguments, loads and aligns the CSV files, builds sliding
    windows that never cross file boundaries, splits them into
    interleaved train/validation sets, fits the scaler on training
    windows only, trains the Transformer with a horizon-weighted RMSE
    loss (early stopping on validation wRMSE), and saves the best
    checkpoint plus the scaler.
    """
    ap = argparse.ArgumentParser()
    # Key defaults: 180-frame input window, 45-step forecast horizon
    # (the previous comment claiming 240/30 contradicted these values).
    ap.add_argument("--csv_paths", type=str, nargs="+", help="多个 CSV 路径(空格分隔)")
    ap.add_argument("--csv_path", type=str, default="", help="兼容单文件旧参数;若提供也会被加入")
    ap.add_argument("--seq_len", type=int, default=180, help="输入窗口长度 L")
    # Fix: the old help text said "输出 30 帧" while the default is 45.
    ap.add_argument("--pred_horizon", type=int, default=45, help="预测步数 H")
    ap.add_argument("--stride", type=int, default=1, help="窗口步长")
    ap.add_argument("--val_ratio", type=float, default=0.2, help="验证集比例(交错采样)")
    ap.add_argument("--val_gap", type=int, default=1, help="从训练中剔除与验证窗口相邻的窗口数量")
    ap.add_argument("--batch_size", type=int, default=512)
    ap.add_argument("--epochs", type=int, default=70)
    ap.add_argument("--lr", type=float, default=3e-4)
    ap.add_argument("--d_model", type=int, default=256)
    ap.add_argument("--nhead", type=int, default=8)
    ap.add_argument("--num_layers", type=int, default=4)
    ap.add_argument("--dim_ff", type=int, default=1024)
    ap.add_argument("--dropout", type=float, default=0.1)
    ap.add_argument("--grad_clip", type=float, default=1.0)
    ap.add_argument("--patience", type=int, default=20)
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--save_dir", type=str, default="./outputs/chaos")

    # Weighted-RMSE configuration.
    ap.add_argument("--horizon_weight_mode", type=str, default="exp", choices=["exp", "linear"],
                    help="预测步权重模式:exp=指数衰减;linear=线性衰减")
    ap.add_argument("--horizon_gamma", type=float, default=0.99,
                    help="指数衰减系数 gamma (0<gamma<=1),越小靠后权重越低")
    ap.add_argument("--horizon_alpha", type=float, default=0.02,
                    help="线性衰减步长 alpha(仅当 mode=linear 生效)")
    ap.add_argument("--horizon_min_w", type=float, default=0.05,
                    help="线性衰减的最小权重下限(避免为 0)")

    args = ap.parse_args()
    os.makedirs(args.save_dir, exist_ok=True)
    set_seed(args.seed)
    # Bug fix: the device used to be hard-coded to "cuda", which raises
    # RuntimeError on CPU-only machines; fall back to CPU gracefully.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    # ====== Collect CSV paths (multi-file flag plus legacy single-file flag) ======
    paths: List[str] = []
    if args.csv_paths:
        paths.extend(args.csv_paths)
    if args.csv_path:
        paths.append(args.csv_path)
    if not paths:
        raise ValueError("请通过 --csv_paths 指定至少一个 CSV。")

    big_df = read_and_align_csvs(paths)
    values = big_df.values.astype(np.float32)  # [T_total, F]
    T_total, F = values.shape
    print(f"Loaded {len(paths)} files. Total T={T_total}, F={F}. Columns={list(big_df.columns)}")

    # Per-file lengths, needed so windows never span a file boundary.
    # NOTE: the files are read a second time because the concatenated
    # frame does not retain per-file row counts.
    lengths = []
    for p in paths:
        df = pd.read_csv(p)
        if "frame_index" in df.columns:
            df = df.sort_values("frame_index").drop(columns=["frame_index"])
        lengths.append(len(df))

    # ====== Global window starts (segment-local, then offset) ======
    starts_all = build_multi_file_starts(lengths, args.seq_len, args.pred_horizon, args.stride)
    print(f"Windows total: {len(starts_all)}")
    if len(starts_all) == 0:
        raise ValueError("无可用窗口,请检查 seq_len/pred_horizon/stride 与数据长度。")

    # ====== Interleaved train/validation split ======
    train_starts, val_starts = interleaved_split(starts_all, args.val_ratio, val_gap=args.val_gap)
    print(f"Split -> train: {len(train_starts)}, val: {len(val_starts)}")

    # ====== Fit the scaler on training windows only (no validation leakage) ======
    mean, std = fit_scaler_from_windows(values, train_starts, args.seq_len)
    values_norm = (values - mean) / std

    # ====== DataLoaders ======
    train_ds = WindowedTSDataset(values_norm, train_starts, args.seq_len, args.pred_horizon)
    val_ds = WindowedTSDataset(values_norm, val_starts, args.seq_len, args.pred_horizon)
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False)

    # ====== Model / loss / optimizer ======
    model = TimeSeriesTransformer(
        in_features=F, d_model=args.d_model, nhead=args.nhead,
        num_layers=args.num_layers, dim_feedforward=args.dim_ff,
        dropout=args.dropout, horizon=args.pred_horizon
    ).to(device)

    weighted_rmse = WeightedRMSELoss(
        horizon=args.pred_horizon,
        mode=args.horizon_weight_mode,
        gamma=args.horizon_gamma,
        alpha=args.horizon_alpha,
        min_w=args.horizon_min_w
    ).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

    # ====== Training loop (weighted RMSE objective, early stop on val wRMSE) ======
    best_val = float("inf")
    best_path = os.path.join(args.save_dir, "best_model.pt")
    no_improve = 0

    for epoch in range(1, args.epochs + 1):
        # Train one epoch.
        model.train()
        total_train_loss = 0.0
        n_train = 0
        for xb, yb in train_loader:
            xb = xb.to(device); yb = yb.to(device)
            optimizer.zero_grad(set_to_none=True)
            pred = model(xb)                # [B, H, F]
            loss = weighted_rmse(pred, yb)  # horizon-weighted RMSE
            loss.backward()
            if args.grad_clip is not None:
                nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            optimizer.step()
            total_train_loss += loss.item() * xb.size(0)
            n_train += xb.size(0)

        # Validate and step the LR schedule on validation wRMSE.
        val_wrmse, val_mse = evaluate_losses(model, val_loader, device, weighted_rmse)
        train_wrmse = total_train_loss / max(1, n_train)
        scheduler.step(val_wrmse)

        print(f"Epoch {epoch:03d}/{args.epochs} | train wRMSE={train_wrmse:.6f} | "
              f"val wRMSE={val_wrmse:.6f} | val MSE={val_mse:.6f}")

        if val_wrmse < best_val - 1e-8:
            best_val = val_wrmse
            no_improve = 0
            torch.save(
                {
                    "model_state": model.state_dict(),
                    "config": {
                        "in_features": F, "d_model": args.d_model, "nhead": args.nhead,
                        "num_layers": args.num_layers, "dim_ff": args.dim_ff,
                        "dropout": args.dropout, "horizon": args.pred_horizon
                    },
                    "loss_cfg": {
                        "mode": args.horizon_weight_mode,
                        "gamma": args.horizon_gamma,
                        "alpha": args.horizon_alpha,
                        "min_w": args.horizon_min_w
                    }
                },
                best_path
            )
        else:
            no_improve += 1
            if no_improve >= args.patience:
                print(f"Early stopping at epoch {epoch}. Best val wRMSE={best_val:.6f}")
                break

    # Persist the scaler next to the checkpoint.
    np.savez(os.path.join(args.save_dir, "scaler.npz"),
             mean=mean, std=std, columns=np.array(big_df.columns))

    # Reload the best checkpoint and report final validation metrics,
    # including MAE back in the data's original scale.
    ckpt = torch.load(best_path, map_location=device)
    model.load_state_dict(ckpt["model_state"])
    val_wrmse, val_mse = evaluate_losses(model, val_loader, device, weighted_rmse)
    val_mae_orig = mae_in_original_scale(model, val_loader, device, mean, std)
    print(f"[Best] Val wRMSE={val_wrmse:.6f} | Val MSE={val_mse:.6f} | Val MAE(original)={val_mae_orig:.6f}")

    # Sample inference: de-normalized multi-step forecast for the last window.
    with torch.no_grad():
        last_start = int(starts_all[-1])
        x_last = values_norm[last_start: last_start + args.seq_len]
        y_true = values[last_start + args.seq_len: last_start + args.seq_len + args.pred_horizon]
        pred = model(torch.from_numpy(x_last).unsqueeze(0).to(device)).cpu().numpy()[0]
        pred_denorm = pred * std + mean
        print("\n=== 示例(最后一个窗口) ===")
        for h in range(args.pred_horizon):
            line = f"+{h+1:02d}: pred={np.round(pred_denorm[h], 6)}"
            if h < y_true.shape[0]:
                line += f" | true={np.round(y_true[h], 6)}"
            print(line)

    print(f"\n保存:\n- 模型: {best_path}\n- 标准化: {os.path.join(args.save_dir, 'scaler.npz')}")
||
|
||
# Script entry point: run training only when executed directly.
if __name__ == "__main__":
    main()
|