2025-09-24 11:35:48 +08:00

148 lines
5.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
bench_latency.py
- 加载训练好的 Transformer 模型best_model.pt与 scaler.npz用于确定列与F
- 构造单次前向输入(从指定 CSV 取最后 L 行并标准化,或随机生成)
- 多次重复前向,统计耗时(均值/最佳/p50/p90/p99支持 CUDA 同步与 autocast
用法示例:
uv run chaos_pdl/bench_latency.py \
--model_path outputs/chaos/best_model.pt \
--scaler_path outputs/chaos/scaler.npz \
--csv_path chaos_pdl/data/metrics-20250920-000328_angles.csv \
--seq_len 180 --iters 200 --warmup 50 --dtype fp32
如导入问题,可用模块方式:
uv run -m chaos_pdl.bench_latency ...
"""
import argparse
import time
from typing import List
import numpy as np
import pandas as pd
import torch
try:
from chaos_pdl.pendulum_transformer import TimeSeriesTransformer
except ModuleNotFoundError:
from pendulum_transformer import TimeSeriesTransformer
def read_last_window(csv_path: str, expect_columns: List[str], seq_len: int):
    """Return the trailing `seq_len` rows of the CSV's numeric columns.

    Rows are sorted by ``frame_index`` (if present, then dropped) to match
    the ordering used at training time, and columns are reordered into the
    training layout given by `expect_columns`.

    Raises ValueError when the numeric columns differ from `expect_columns`
    or when the file holds fewer than `seq_len` rows.

    Returns a float32 ndarray of shape [seq_len, len(expect_columns)].
    """
    frame = pd.read_csv(csv_path)
    # Restore the training-time row order, then discard the index column.
    if "frame_index" in frame.columns:
        frame = frame.sort_values("frame_index").drop(columns=["frame_index"])
    numeric = frame.select_dtypes(include=[np.number]).copy()
    found = list(numeric.columns)
    if set(found) != set(expect_columns):
        raise ValueError(
            f"CSV 数值列与训练不一致\n 期望: {expect_columns}\n 实际: {found}\n"
        )
    # Same column set — now force the exact training column order.
    numeric = numeric[expect_columns].reset_index(drop=True)
    if len(numeric) < seq_len:
        raise ValueError(f"数据行不足:需要至少 {seq_len} 行,实际 {len(numeric)}")
    data = numeric.values.astype(np.float32)
    return data[-seq_len:, :]
def main():
    """Benchmark single-forward latency of the trained TimeSeriesTransformer.

    Loads the checkpoint and scaler, builds one [batch, seq_len, F] input
    window (from a CSV tail, or random standard-normal noise), then times
    repeated forward passes with optional CUDA autocast and reports
    mean/std/min and p50/p90/p99 latency in milliseconds.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--model_path", type=str, default="./outputs/chaos/best_model.pt")
    ap.add_argument("--scaler_path", type=str, default="./outputs/chaos/scaler.npz")
    ap.add_argument("--csv_path", type=str, default="", help="用于构造输入的 CSV可选")
    ap.add_argument("--seq_len", type=int, default=180)
    ap.add_argument("--iters", type=int, default=200)
    ap.add_argument("--warmup", type=int, default=50)
    ap.add_argument("--batch", type=int, default=1)
    ap.add_argument("--dtype", type=str, default="fp32", choices=["fp32", "bf16", "fp16"])
    ap.add_argument("--use_autocast", action="store_true", help="使用 autocast 进行混合精度GPU 推荐)")
    args = ap.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if device.type == "cuda":
        print(f"Device: cuda | {torch.cuda.get_device_name(0)}")
    else:
        print("Device: cpu")

    # Scaler holds per-feature mean/std and the numeric column order used
    # at training time; the column count fixes the model's input width F.
    sc = np.load(args.scaler_path, allow_pickle=True)
    mean = sc["mean"].astype(np.float32)
    std = sc["std"].astype(np.float32)
    columns = list(sc["columns"].tolist())
    F = len(columns)

    # Rebuild the architecture from the checkpoint's stored config, then
    # restore the weights and switch to inference mode.
    ckpt = torch.load(args.model_path, map_location=device)
    cfg = ckpt["config"]
    horizon = int(cfg["horizon"])
    model = TimeSeriesTransformer(
        in_features=F,
        d_model=cfg["d_model"],
        nhead=cfg["nhead"],
        num_layers=cfg["num_layers"],
        dim_feedforward=cfg["dim_ff"],
        dropout=cfg["dropout"],
        horizon=horizon,
    ).to(device)
    model.load_state_dict(ckpt["model_state"])
    model.eval()

    # Build the input window, already in standardized space.
    if args.csv_path:
        # Raw CSV data must be standardized with the training scaler.
        arr_norm = (read_last_window(args.csv_path, columns, args.seq_len) - mean) / std
    else:
        # Random input is drawn directly in standardized space, so it must
        # NOT be re-standardized (the previous code normalized it a second
        # time with mean/std, contradicting its own intent).
        arr_norm = np.random.randn(args.seq_len, F).astype(np.float32)
    xb_np = np.repeat(arr_norm[None, :, :], args.batch, axis=0)  # [B,L,F]
    xb = torch.from_numpy(xb_np).to(device)

    # Autocast is only effective on CUDA and only for non-fp32 dtypes.
    dtype_map = {"fp32": torch.float32, "bf16": torch.bfloat16, "fp16": torch.float16}
    amp_dtype = dtype_map[args.dtype]
    use_amp = args.use_autocast and (device.type == "cuda") and (amp_dtype != torch.float32)
    if amp_dtype != torch.float32 and not use_amp:
        # Previously a low-precision request was silently ignored; say so.
        print(f"[warn] dtype={args.dtype} requested but autocast is inactive "
              f"(requires --use_autocast and CUDA); running fp32.")

    def _forward_once():
        # One inference pass, under autocast when enabled.
        if use_amp:
            with torch.autocast(device_type="cuda", dtype=amp_dtype):
                _ = model(xb)
        else:
            _ = model(xb)

    # Warmup: stabilize kernel selection and allocator caches before timing.
    with torch.no_grad():
        for _ in range(max(0, args.warmup)):
            _forward_once()
    if device.type == "cuda":
        torch.cuda.synchronize()

    # Timed loop: synchronize before and after each pass so asynchronous
    # GPU work is fully counted in the wall-clock measurement.
    times = []
    with torch.no_grad():
        for _ in range(args.iters):
            if device.type == "cuda":
                torch.cuda.synchronize()
            t0 = time.perf_counter()
            _forward_once()
            if device.type == "cuda":
                torch.cuda.synchronize()
            t1 = time.perf_counter()
            times.append((t1 - t0) * 1000.0)  # ms

    times = np.array(times, dtype=np.float64)

    def pct(p):
        return float(np.percentile(times, p))

    print("Latency (ms) for single forward:")
    print(f" batch={args.batch}, L={args.seq_len}, F={F}, H={horizon}, dtype={args.dtype}, autocast={use_amp}")
    print(f" mean={times.mean():.3f} | std={times.std():.3f} | min={times.min():.3f}")
    print(f" p50={pct(50):.3f} | p90={pct(90):.3f} | p99={pct(99):.3f}")


if __name__ == "__main__":
    main()