2025-09-24 11:35:48 +08:00

336 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
eval_model.py
- 加载训练好的 Transformer 模型 (best_model.pt) 和标准化参数 (scaler.npz)
- 对单个 CSV 进行标准化、滑窗切片并评测wRMSE / MSE / MAE(original)
- 可选输出最后一个窗口的预测与真值(反标准化)到 CSV
使用示例:
uv run chaos_pdl/eval_model.py \
--csv_path chaos_pdl/data/achieve2/raw/IMG_1119_out_metrics.csv \
--model_path outputs/chaos/best_model.pt \
--scaler_path outputs/chaos/scaler.npz \
--seq_len 180 --stride 1 \
--save_dir outputs/chaos
"""
import argparse
import os
from typing import List
import numpy as np
import pandas as pd
import torch
# Reuse the modules and evaluation helpers from the training script.
# Two import styles are supported:
try:
    # Preferred: import through the package path (works when the project is
    # installed as a package or the workspace root is on sys.path).
    from chaos_pdl.pendulum_transformer import (
        TimeSeriesTransformer,
        WeightedRMSELoss,
        WindowedTSDataset,
    )
except ModuleNotFoundError:
    # Fallback: same-directory import (more reliable when this script is
    # executed directly).
    from pendulum_transformer import (
        TimeSeriesTransformer,
        WeightedRMSELoss,
        WindowedTSDataset,
    )
def build_windows_indices(T: int, seq_len: int, horizon: int, stride: int) -> np.ndarray:
    """Return the start index of every sliding window that fits in the series.

    A window starting at ``s`` spans ``seq_len`` input frames followed by
    ``horizon`` target frames, so only starts with
    ``s + seq_len + horizon <= T`` are produced. The original note states this
    is meant to match the windowing used during training.

    Args:
        T: Total number of frames in the series.
        seq_len: Number of input (history) frames per window.
        horizon: Number of target frames per window.
        stride: Distance between consecutive start indices; must be >= 1.

    Returns:
        A 1-D ``int64`` array of start indices. It is empty when the series is
        shorter than ``seq_len + horizon``.

    Raises:
        ValueError: If ``stride`` is smaller than 1.
    """
    # A zero stride makes np.arange fail with an opaque ZeroDivisionError and a
    # negative one silently yields no windows, so reject both up front.
    if stride < 1:
        raise ValueError(f"stride must be a positive integer, got {stride}")

    max_start = T - seq_len - horizon
    if max_start < 0:
        return np.array([], dtype=np.int64)
    return np.arange(0, max_start + 1, stride, dtype=np.int64)
@torch.no_grad()
def evaluate_losses(model, loader, device, weighted_rmse: WeightedRMSELoss):
    """Compute sample-weighted average wRMSE and MSE over a data loader.

    Each batch's loss value is weighted by its batch size, and the sums are
    divided by the total number of samples seen (or by 1 for an empty loader).

    Args:
        model: Callable mapping an input batch to predictions; switched to
            eval mode before iterating.
        loader: Iterable yielding ``(inputs, targets)`` tensor pairs.
        device: Device the batches are moved to before the forward pass.
        weighted_rmse: Callable ``(pred, target) -> scalar tensor``.

    Returns:
        Tuple ``(avg_wrmse, avg_mse)`` of Python floats.
    """
    model.eval()
    mse_criterion = torch.nn.MSELoss()

    wrmse_sum = 0.0
    mse_sum = 0.0
    n_seen = 0
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)
        batch_n = inputs.size(0)
        # Weight per-batch means by batch size so a short final batch does
        # not count as much as a full one.
        wrmse_sum += weighted_rmse(outputs, targets).item() * batch_n
        mse_sum += mse_criterion(outputs, targets).item() * batch_n
        n_seen += batch_n

    denom = max(1, n_seen)
    return wrmse_sum / denom, mse_sum / denom
@torch.no_grad()
def mae_in_original_scale(model, loader, device, mean: np.ndarray, std: np.ndarray):
    """Compute the sample-weighted mean absolute error after de-normalisation.

    Predictions and targets are mapped back with ``x * std + mean`` before the
    absolute error is taken. Each batch mean is weighted by its batch size.

    Args:
        model: Callable mapping an input batch to predictions; switched to
            eval mode before iterating.
        loader: Iterable yielding ``(inputs, targets)`` tensor pairs.
        device: Device the batches and scaler tensors are moved to.
        mean: Per-feature mean, reshaped to ``(1, 1, -1)`` for broadcasting.
        std: Per-feature standard deviation, reshaped the same way.

    Returns:
        The average MAE as a Python float (0.0 for an empty loader).
    """
    model.eval()
    offset = torch.from_numpy(mean).to(device).view(1, 1, -1)
    scale = torch.from_numpy(std).to(device).view(1, 1, -1)

    def to_original(t):
        # Undo the standardisation applied before training/evaluation.
        return t * scale + offset

    abs_err_sum = 0.0
    n_seen = 0
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        batch_n = inputs.size(0)
        batch_mae = torch.abs(to_original(model(inputs)) - to_original(targets)).mean()
        abs_err_sum += batch_mae.item() * batch_n
        n_seen += batch_n
    return abs_err_sum / max(1, n_seen)
def read_single_csv_align(path: str, expect_columns: List[str]) -> pd.DataFrame:
    """Load one CSV and align its numeric columns with the expected layout.

    Steps:
    - if a ``frame_index`` column exists, sort by it and then drop it;
    - keep numeric columns only;
    - require the numeric column set to equal ``expect_columns`` and return
      the columns in that order with a fresh 0..n-1 index.

    Args:
        path: CSV location passed to ``pandas.read_csv``.
        expect_columns: Column names, in the order the caller expects.

    Returns:
        A DataFrame containing exactly ``expect_columns``.

    Raises:
        ValueError: If the numeric columns differ from ``expect_columns``.
    """
    frame = pd.read_csv(path)
    if "frame_index" in frame.columns:
        # Sort first so rows follow frame order, then discard the key.
        frame = frame.sort_values("frame_index").drop(columns=["frame_index"])

    numeric = frame.select_dtypes(include=[np.number]).copy()
    has = list(numeric.columns)

    # A non-empty symmetric difference means a column is missing or extra.
    if set(has) ^ set(expect_columns):
        raise ValueError(
            f"CSV 数值列与训练时不一致\n 期望: {expect_columns}\n 实际: {has}\n"
            "请确保列集合一致(名称与类型),且来源一致。"
        )

    # Reorder to the expected column order.
    return numeric[expect_columns].reset_index(drop=True)
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the evaluation script."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--csv_path", type=str, required=True, help="待评测的 CSV 文件路径")
    ap.add_argument("--model_path", type=str, default="./outputs/chaos/best_model.pt")
    ap.add_argument("--scaler_path", type=str, default="./outputs/chaos/scaler.npz")
    ap.add_argument("--seq_len", type=int, default=180, help="历史窗口长度 L需与训练一致")
    ap.add_argument("--stride", type=int, default=1, help="评测滑窗步长")
    ap.add_argument("--batch_size", type=int, default=1024)
    ap.add_argument("--save_dir", type=str, default="", help="若提供,则保存最后一个窗口的预测/真值 CSV")
    ap.add_argument(
        "--plot_path",
        type=str,
        default="",
        help="每步RMSE(原尺度)曲线图的保存路径(.png。若未提供且有save_dir则保存到 save_dir/per_step_rmse.png",
    )
    ap.add_argument(
        "--plot_heatmap",
        action="store_true",
        help="输出每步×每特征的RMSE热力图原尺度保存到 save_dir/per_step_feature_rmse.png 与 CSV",
    )
    ap.add_argument(
        "--plot_feature",
        type=str,
        default="",
        help="指定一个特征名,对最后一个窗口画出预测与真值曲线;默认取第一个列名",
    )
    return ap


def _load_scaler(scaler_path: str):
    """Read ``mean``, ``std`` and the column order from the scaler archive.

    The archive is opened as a context manager so its file handle is released
    as soon as the arrays have been read.
    """
    # NOTE(security): allow_pickle=True can execute arbitrary code if the
    # archive is untrusted; only load scaler files you produced yourself.
    with np.load(scaler_path, allow_pickle=True) as sc:
        mean = sc["mean"]
        std = sc["std"]
        columns = list(sc["columns"].tolist())
    return mean, std, columns


@torch.no_grad()
def _accumulate_step_errors(model, loader, device, mean, std, horizon: int, per_feature: bool):
    """Sum squared original-scale errors per prediction step over the loader.

    Returns ``(step_sse, feat_sse, elem_count, window_count)`` where
    ``step_sse`` has length ``horizon``, ``feat_sse`` is the per-step,
    per-feature sum (``None`` unless ``per_feature`` is set), ``elem_count``
    is the number of (window, feature) pairs and ``window_count`` the number
    of windows.
    """
    model.eval()
    mean_t = torch.from_numpy(mean).to(device).view(1, 1, -1)
    std_t = torch.from_numpy(std).to(device).view(1, 1, -1)

    step_sse = np.zeros(horizon, dtype=np.float64)
    feat_sse = None
    elem_count = 0
    window_count = 0
    for xb, yb in loader:
        xb = xb.to(device)
        yb = yb.to(device)
        pred_den = model(xb) * std_t + mean_t
        yb_den = yb * std_t + mean_t
        sq_err = (pred_den - yb_den) ** 2  # indexed as [batch, step, feature]

        # Accumulate raw sums and divide once at the end, so batches of
        # different sizes are weighted correctly.
        step_sse += sq_err.sum(dim=(0, 2)).cpu().numpy()
        elem_count += sq_err.size(0) * sq_err.size(2)
        window_count += sq_err.size(0)
        if per_feature:
            batch_feat = sq_err.sum(dim=0).cpu().numpy().astype(np.float64)
            feat_sse = batch_feat if feat_sse is None else feat_sse + batch_feat
    return step_sse, feat_sse, elem_count, window_count


def _report_per_step_rmse(args, model, loader, device, mean, std, horizon: int, columns: List[str]) -> None:
    """Write per-step (and optionally per-feature) RMSE tables and plots.

    The CSV files are written before matplotlib is imported, so a missing or
    broken plotting backend only costs the figures, not the numbers.
    """
    if not (args.save_dir or args.plot_path):
        # Nothing would be written anywhere; skip the extra inference pass.
        return

    step_sse, feat_sse, elem_count, window_count = _accumulate_step_errors(
        model, loader, device, mean, std, horizon, per_feature=args.plot_heatmap
    )
    step_mse = step_sse / max(1, elem_count)
    step_rmse = np.sqrt(np.maximum(step_mse, 1e-12))
    steps = np.arange(1, horizon + 1)

    if args.save_dir:
        os.makedirs(args.save_dir, exist_ok=True)
        csv_path = os.path.join(args.save_dir, "per_step_rmse.csv")
        pd.DataFrame({"step": steps, "rmse": step_rmse, "mse": step_mse}).to_csv(
            csv_path, index=False
        )
        print(f"Saved: {csv_path}")

    per_feat_rmse = None
    if args.plot_heatmap and feat_sse is not None and args.save_dir:
        # Per-step, per-feature MSE averaged over windows.
        per_feat_mse = feat_sse / max(1, window_count)
        per_feat_rmse = np.sqrt(np.maximum(per_feat_mse, 1e-12))
        # Rows are prediction steps, columns are features.
        csv_hm = os.path.join(args.save_dir, "per_step_feature_rmse.csv")
        pd.DataFrame(per_feat_rmse, columns=columns, index=steps).to_csv(csv_hm)
        print(f"Saved: {csv_hm}")

    plot_path = args.plot_path
    if not plot_path and args.save_dir:
        plot_path = os.path.join(args.save_dir, "per_step_rmse.png")
    if not plot_path and per_feat_rmse is None:
        return

    import matplotlib.pyplot as plt

    if plot_path:
        fig = plt.figure(figsize=(8, 4.5))
        plt.plot(steps, step_rmse, marker="o", lw=1.5)
        plt.xlabel("Prediction step (t)")
        plt.ylabel("RMSE (original scale)")
        plt.title("Per-step RMSE vs prediction step")
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.savefig(plot_path, dpi=150)
        plt.close(fig)  # release the figure; pyplot keeps it alive otherwise
        print(f"Saved: {plot_path}")

    if per_feat_rmse is not None:
        fig = plt.figure(figsize=(min(12, 1.5 + 0.35 * len(columns) + 6), 5))
        im = plt.imshow(per_feat_rmse, aspect="auto", origin="lower", cmap="viridis")
        plt.colorbar(im, label="RMSE")
        # Label roughly ten rows; image rows are 0-based, steps are 1-based.
        tick_step = max(1, horizon // 10)
        plt.yticks(ticks=np.arange(0, horizon, tick_step), labels=steps[::tick_step])
        plt.xticks(ticks=np.arange(len(columns)), labels=columns, rotation=45, ha="right")
        plt.xlabel("Feature")
        plt.ylabel("Prediction step (t)")
        plt.title("Per-step per-feature RMSE (original scale)")
        plt.tight_layout()
        hm_path = os.path.join(args.save_dir, "per_step_feature_rmse.png")
        plt.savefig(hm_path, dpi=150)
        plt.close(fig)
        print(f"Saved: {hm_path}")


def _save_last_window(args, model, device, values, values_norm, starts, horizon: int, mean, std, columns: List[str]) -> None:
    """Save the last window's de-normalised prediction and truth, plus a plot."""
    os.makedirs(args.save_dir, exist_ok=True)
    last_start = int(starts[-1])
    ctx_end = last_start + args.seq_len
    x_last = values_norm[last_start:ctx_end]
    y_true = values[ctx_end:ctx_end + horizon]
    with torch.no_grad():
        pred = model(torch.from_numpy(x_last).unsqueeze(0).to(device)).cpu().numpy()[0]
    pred_denorm = pred * std + mean

    # Two CSV files (prediction and truth): one row per step, one column per
    # feature.
    pred_path = os.path.join(args.save_dir, "last_window_pred.csv")
    true_path = os.path.join(args.save_dir, "last_window_true.csv")
    pd.DataFrame(pred_denorm, columns=columns).to_csv(pred_path, index=False)
    pd.DataFrame(y_true, columns=columns).to_csv(true_path, index=False)
    print(f"Saved: {pred_path}\nSaved: {true_path}")

    # Best-effort line plot of one feature; failures only produce a warning.
    try:
        import matplotlib.pyplot as plt

        feat = args.plot_feature or columns[0]
        if feat not in columns:
            print(f"[warn] plot_feature='{feat}' 不在列中,使用默认 {columns[0]}")
            feat = columns[0]
        fi = columns.index(feat)
        steps = np.arange(1, horizon + 1)
        fig = plt.figure(figsize=(7, 4))
        plt.plot(steps, pred_denorm[:, fi], label="pred", marker="o", lw=1.5)
        plt.plot(steps, y_true[:, fi], label="true", marker="o", lw=1.5)
        plt.xlabel("Prediction step (t)")
        plt.ylabel(feat)
        plt.title(f"Last window prediction vs truth: {feat}")
        plt.grid(True, alpha=0.3)
        plt.legend()
        plt.tight_layout()
        line_path = os.path.join(args.save_dir, f"last_window_{feat}_pred_vs_true.png")
        plt.savefig(line_path, dpi=150)
        plt.close(fig)
        print(f"Saved: {line_path}")
    except Exception as e:
        print(f"[warn] 最后窗口曲线绘图失败:{e}")


def main() -> None:
    """Evaluate a trained model on a single CSV and optionally save reports.

    Loads the scaler and checkpoint, standardises the CSV, builds sliding
    windows, prints wRMSE / MSE / original-scale MAE, and, depending on the
    command-line flags, writes per-step RMSE tables/plots and the last
    window's prediction versus truth.

    Raises:
        ValueError: If the CSV columns do not match the scaler, or the series
            is too short to form a single window.
    """
    args = _build_arg_parser().parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    # Scaler (also carries the column order).
    mean, std, columns = _load_scaler(args.scaler_path)
    print(f"Scaler loaded. Columns={columns}")

    # Read the CSV and align it with the scaler's columns.
    df = read_single_csv_align(args.csv_path, columns)
    values = df.values.astype(np.float32)
    n_frames, n_features = values.shape
    print(f"Data loaded. T={n_frames}, F={n_features}")

    # Rebuild the model from the hyper-parameters stored in the checkpoint.
    ckpt = torch.load(args.model_path, map_location=device)
    cfg = ckpt["config"]
    loss_cfg = ckpt.get("loss_cfg", {"mode": "exp", "gamma": 0.99, "alpha": 0.02, "min_w": 0.05})
    horizon = int(cfg["horizon"])  # the checkpoint's horizon is authoritative
    model = TimeSeriesTransformer(
        in_features=n_features,
        d_model=cfg["d_model"],
        nhead=cfg["nhead"],
        num_layers=cfg["num_layers"],
        dim_feedforward=cfg["dim_ff"],
        dropout=cfg["dropout"],
        horizon=horizon,
    ).to(device)
    model.load_state_dict(ckpt["model_state"])
    model.eval()

    # Evaluation windows.
    starts = build_windows_indices(n_frames, args.seq_len, horizon, args.stride)
    if len(starts) == 0:
        raise ValueError(
            f"数据不足T={n_frames},需要至少 seq_len({args.seq_len})+horizon({horizon}) 帧。"
        )
    print(f"Eval windows: {len(starts)} (L={args.seq_len}, H={horizon}, stride={args.stride})")

    # Standardise with the training mean/std. A float64 scaler would promote
    # the result to float64, which is later handed to the model directly for
    # the last-window prediction, so cast back to float32 explicitly.
    values_norm = ((values - mean) / std).astype(np.float32)

    ds = WindowedTSDataset(values_norm, starts, args.seq_len, horizon)
    loader = torch.utils.data.DataLoader(ds, batch_size=args.batch_size, shuffle=False)

    weighted_rmse = WeightedRMSELoss(
        horizon=horizon,
        mode=loss_cfg.get("mode", "exp"),
        gamma=loss_cfg.get("gamma", 0.99),
        alpha=loss_cfg.get("alpha", 0.02),
        min_w=loss_cfg.get("min_w", 0.05),
    ).to(device)

    val_wrmse, val_mse = evaluate_losses(model, loader, device, weighted_rmse)
    val_mae_orig = mae_in_original_scale(model, loader, device, mean, std)
    print(
        f"[Eval] wRMSE={val_wrmse:.6f} | MSE={val_mse:.6f} | MAE(original)={val_mae_orig:.6f}"
    )

    # Per-step RMSE report is best-effort: a failure is reported as a warning
    # and does not abort the rest of the evaluation.
    try:
        _report_per_step_rmse(args, model, loader, device, mean, std, horizon, columns)
    except Exception as e:
        print(f"[warn] 每步RMSE绘图失败{e}")

    # Optional: last window's prediction and truth in the original scale.
    if args.save_dir:
        _save_last_window(
            args, model, device, values, values_norm, starts, horizon, mean, std, columns
        )
# Run the evaluation only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()