2025-09-24 11:35:48 +08:00

250 lines
10 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
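"""
make_demo_video.py
- Load a single CSV together with the trained model and scaler, run rolling
  1-step predictions, and render a short demo video of them (10 seconds by default).
- Each frame shows the true positions of the green/red pendulums (solid lines)
  and, for the current frame, the position predicted one step earlier (dashed lines).
Usage example:
uv run -m chaos_pdl.make_demo_video \
--csv_path chaos_pdl/data/metrics-20250920-000328_angles.csv \
--model_path outputs/chaos/best_model.pt \
--scaler_path outputs/chaos/scaler.npz \
--seq_len 180 \
--duration 10 \
--fps 30 \
--out_path outputs/chaos/demo.mp4
If ffmpeg is not installed, the script automatically falls back to writing a GIF
(e.g. outputs/chaos/demo.gif).
"""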
import argparse
import os
from typing import List, Tuple
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, FFMpegWriter, PillowWriter
try:
from chaos_pdl.pendulum_transformer import TimeSeriesTransformer, WindowedTSDataset
except ModuleNotFoundError:
from pendulum_transformer import TimeSeriesTransformer, WindowedTSDataset
def read_align_csv(path: str, expect_columns: List[str]) -> pd.DataFrame:
    """Load a CSV and return its numeric columns in the training column order.

    If a ``frame_index`` column is present, rows are ordered by it and the
    column is then dropped, mirroring the training-time preprocessing.

    Args:
        path: CSV file to read.
        expect_columns: Column names, in order, that the model was trained on.

    Returns:
        A DataFrame containing exactly ``expect_columns`` (in that order) with
        a fresh 0..N-1 index.

    Raises:
        ValueError: if the set of numeric columns differs from ``expect_columns``.
    """
    frame = pd.read_csv(path)
    if "frame_index" in frame.columns:
        # Align with the training logic: order rows by frame index, then discard it.
        frame = frame.sort_values("frame_index")
        frame = frame.drop(columns=["frame_index"])

    numeric = frame.select_dtypes(include=[np.number]).copy()
    found = list(numeric.columns)
    if set(found) == set(expect_columns):
        return numeric[expect_columns].reset_index(drop=True)

    raise ValueError(
        f"CSV 数值列与训练不一致\n 期望: {expect_columns}\n 实际: {found}\n"
    )
def build_one_step_pred_sequence(values_norm: np.ndarray, model: "TimeSeriesTransformer", device: torch.device,
                                 seq_len: int, frame_idx_list: np.ndarray, batch_size: int = 512) -> np.ndarray:
    """Run the model once per target frame and collect its first-step prediction.

    For every target frame ``t`` in ``frame_idx_list``, the input window is
    ``values_norm[t - seq_len:t]``, whose last row is frame ``t - 1``. The
    model's first horizon step ``out[:, 0, :]`` is taken as the prediction for
    frame ``t``.

    Predictions stay in the standardized space used for training. The caller
    is responsible for inverse-transforming them.

    Args:
        values_norm: Standardized series indexed as ``[time, feature]``.
        model: Forecasting model. It is called on a ``[B, seq_len, F]`` float32
            tensor; its output is indexed as ``[B, H, F]`` (see ``out[:, 0, :]``).
        device: Device the input batches are moved to.
        seq_len: Window length ``L``.
        frame_idx_list: Target frame indices ``t``.
        batch_size: Number of windows per forward pass.

    Returns:
        ``[N, F]`` array of standardized 1-step predictions, one row per target
        frame. It is ``[0, F]`` when ``frame_idx_list`` is empty.

    Raises:
        ValueError: if ``seq_len`` or ``batch_size`` is not positive, or if a
            target frame has fewer than ``seq_len`` preceding rows or lies
            beyond ``len(values_norm)``.
    """
    if seq_len <= 0:
        raise ValueError(f"seq_len must be positive, got {seq_len}")
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    targets = np.asarray(frame_idx_list, dtype=np.int64).reshape(-1)
    if targets.size == 0:
        # np.concatenate([]) would raise; an empty request yields an empty result.
        return np.empty((0, values_norm.shape[-1]), dtype=np.float32)

    starts = targets - seq_len  # window start for each t (the window's last frame is t-1)
    # Explicit checks instead of `assert` (stripped under `python -O`). A window
    # running past the end would otherwise silently be shorter than seq_len.
    bad = (starts < 0) | (targets > len(values_norm))
    if bad.any():
        raise ValueError(
            f"target frames {targets[bad][:10].tolist()} need a full window of "
            f"{seq_len} rows inside a series of length {len(values_norm)}"
        )

    model.eval()
    preds_list = []
    with torch.no_grad():
        for i in range(0, len(starts), batch_size):
            batch_starts = starts[i:i + batch_size]
            # Stack the windows into [B, L, F].
            batch = np.stack([values_norm[s:s + seq_len] for s in batch_starts], axis=0).astype(np.float32)
            xb = torch.from_numpy(batch).to(device)
            out = model(xb)  # indexed as [B, H, F] below
            # Horizon step 0 is the frame right after the window, i.e. target frame t.
            preds_list.append(out[:, 0, :].cpu().numpy())
    return np.concatenate(preds_list, axis=0)  # [N, F]
def angle_from_sin_cos(sin_v: float, cos_v: float) -> float:
    """Recover an angle in radians, within [-pi, pi], from its sine and cosine.

    Uses the quadrant-aware two-argument arctangent, so the signs of both
    components are respected.
    """
    theta = np.arctan2(sin_v, cos_v)
    return float(theta)
def endpoint_from_angle(theta: float, length: float = 1.0) -> Tuple[float, float]:
    """Return the (x, y) tip of a rod of ``length`` anchored at the origin.

    The mapping is ``x = length*sin(theta)`` and ``y = -length*cos(theta)``, so
    ``theta = 0`` points straight down and positive angles swing toward +x.
    """
    x = length * np.sin(theta)
    y = -length * np.cos(theta)
    return x, y
def draw_frame(ax, true_row: np.ndarray, pred_row: np.ndarray, col_idx: dict, title: str):
    """Clear ``ax`` and draw the true and predicted green/red pendulum rods.

    Both rods are anchored at the origin. Angles are recovered from the
    ``<color>_angle_sin`` / ``<color>_angle_cos`` entries of each row (their
    positions are looked up in ``col_idx``). They are mapped to rod tips with
    ``endpoint_from_angle``, so angle 0 points straight down. Ground truth is
    drawn solid; the prediction is drawn dashed and semi-transparent.

    Args:
        ax: Matplotlib axes to draw on (cleared first).
        true_row: Observed feature row for the current frame.
        pred_row: Predicted feature row for the same frame.
        col_idx: Mapping from column name to position within the rows.
        title: Axes title.
    """
    pendulums = ('green', 'red')
    colours = ['tab:green', 'tab:red']

    def tip_of(row, name):
        theta = angle_from_sin_cos(row[col_idx[f'{name}_angle_sin']], row[col_idx[f'{name}_angle_cos']])
        return endpoint_from_angle(theta)

    ax.clear()
    ax.set_aspect('equal', 'box')
    ax.set_xlim(-1.2, 1.2)
    ax.set_ylim(-1.2, 1.2)
    ax.grid(True, alpha=0.2)

    # (row, label suffix, rod kwargs, marker kwargs). Truth is drawn before the
    # prediction, so the legend reads green-true, red-true, green-pred, red-pred.
    layers = (
        (true_row, 'true', {'lw': 3}, {'s': 40}),
        (pred_row, 'pred', {'lw': 2, 'ls': '--', 'alpha': 0.7}, {'s': 30, 'alpha': 0.7}),
    )
    for row, suffix, rod_kw, marker_kw in layers:
        tips = [tip_of(row, name) for name in pendulums]
        for name, colour, (x, y) in zip(pendulums, colours, tips):
            ax.plot([0, x], [0, y], color=colour, label=f'{name}-{suffix}', **rod_kw)
        ax.scatter([x for x, _ in tips], [y for _, y in tips], color=colours, **marker_kw)

    ax.legend(loc='upper right')
    ax.set_title(title)
def _draw_frame_right(ax, true_row: np.ndarray, pred_row: np.ndarray, col_idx: dict, title: str):
    """Draw one frame like ``draw_frame``, but with x = cos(theta), y = sin(theta).

    With this mapping, angle 0 points to the right. It is defined once at
    module level rather than being re-created for every animation frame.
    """
    def tip(row, which):
        theta = float(np.arctan2(row[col_idx[f'{which}_angle_sin']], row[col_idx[f'{which}_angle_cos']]))
        return np.cos(theta), np.sin(theta)

    ax.clear()
    ax.set_aspect('equal', 'box')
    ax.set_xlim(-1.2, 1.2)
    ax.set_ylim(-1.2, 1.2)
    ax.grid(True, alpha=0.2)
    # Ground truth
    gx, gy = tip(true_row, 'green')
    rx, ry = tip(true_row, 'red')
    # Prediction for the current frame
    pgx, pgy = tip(pred_row, 'green')
    prx, pry = tip(pred_row, 'red')
    ax.plot([0, gx], [0, gy], color='tab:green', lw=3, label='green-true')
    ax.plot([0, rx], [0, ry], color='tab:red', lw=3, label='red-true')
    ax.scatter([gx, rx], [gy, ry], color=['tab:green', 'tab:red'], s=40)
    ax.plot([0, pgx], [0, pgy], color='tab:green', lw=2, ls='--', alpha=0.7, label='green-pred')
    ax.plot([0, prx], [0, pry], color='tab:red', lw=2, ls='--', alpha=0.7, label='red-pred')
    ax.scatter([pgx, prx], [pgy, pry], color=['tab:green', 'tab:red'], s=30, alpha=0.7)
    ax.legend(loc='upper right')
    ax.set_title(title)


def main():
    """CLI entry point: predict 1 step ahead for a frame range and render a video.

    Steps:
    1. Load the scaler (mean/std/columns) and the model checkpoint.
    2. Load and align the CSV, then pick ``duration * fps`` target frames.
    3. Predict each frame from its preceding ``seq_len`` rows.
    4. Animate truth vs. prediction and save it as mp4, falling back to GIF if
       the mp4 write fails.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('--csv_path', type=str, required=True, help='输入 CSV需要包含训练使用的列')
    ap.add_argument('--model_path', type=str, default='outputs/chaos/best_model.pt')
    ap.add_argument('--scaler_path', type=str, default='outputs/chaos/scaler.npz')
    ap.add_argument('--seq_len', type=int, default=180)
    ap.add_argument('--duration', type=int, default=10, help='视频秒数')
    ap.add_argument('--fps', type=int, default=30)
    ap.add_argument('--out_path', type=str, default='outputs/chaos/demo.mp4')
    ap.add_argument('--start_from_end', action='store_true', help='从序列末尾倒推所需时长;不设则从 seq_len 处开始')
    ap.add_argument('--batch_size', type=int, default=512)
    ap.add_argument('--orient', type=str, default='down', choices=['down', 'right'],
                    help='坐标映射down -> (x=sin(theta), y=-cos(theta)); right -> (x=cos(theta), y=sin(theta))')
    ap.add_argument('--debug_stats', action='store_true', help='打印窗口内若干关键列的最小/最大值,排查坐标恒正等问题')
    args = ap.parse_args()
    # Without this check, a zero or negative value yields an empty frame list
    # and a crash deep inside prediction.
    if args.duration <= 0 or args.fps <= 0:
        ap.error('--duration and --fps must be positive integers')

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # NOTE(review): allow_pickle=True and torch.load both unpickle data; only
    # point these at trusted, locally produced files.
    # `with` closes the NpzFile handle once the arrays have been read.
    with np.load(args.scaler_path, allow_pickle=True) as sc:
        mean = sc['mean'].astype(np.float32)
        std = sc['std'].astype(np.float32)
        columns = list(sc['columns'].tolist())
    col_idx = {c: i for i, c in enumerate(columns)}

    # Load the model
    ckpt = torch.load(args.model_path, map_location=device)
    cfg = ckpt['config']
    H = int(cfg['horizon'])
    model = TimeSeriesTransformer(
        in_features=len(columns),
        d_model=cfg['d_model'],
        nhead=cfg['nhead'],
        num_layers=cfg['num_layers'],
        dim_feedforward=cfg['dim_ff'],
        dropout=cfg['dropout'],
        horizon=H,
    ).to(device)
    model.load_state_dict(ckpt['model_state'])
    model.eval()

    # Data
    df = read_align_csv(args.csv_path, columns)
    values = df.values.astype(np.float32)
    T = len(values)
    if T < args.seq_len + 2:
        raise ValueError(f'数据太短:需要至少 {args.seq_len + 2} 行,当前 {T}')

    # Choose the frame range. Every target t needs seq_len rows before it, so
    # the earliest usable frame is seq_len.
    N_frames = args.duration * args.fps
    if args.start_from_end:  # count backwards from the last frame
        end_t = T - 1
        start_t = max(args.seq_len, end_t - N_frames + 1)
    else:
        start_t = args.seq_len
        end_t = min(T - 1, start_t + N_frames - 1)
    frame_ts = np.arange(start_t, end_t + 1, dtype=np.int64)
    N_frames = len(frame_ts)

    # Standardize with the training scaler.
    # NOTE(review): assumes std has no zero entries — confirm the scaler guards this.
    values_norm = (values - mean) / std
    # 1-step prediction for each t: the window is [t-L, t) and the target is t.
    preds_norm = build_one_step_pred_sequence(values_norm, model, device, args.seq_len, frame_ts, args.batch_size)
    preds = preds_norm * std + mean  # undo standardization
    # Matching ground-truth rows
    truths = values[frame_ts]

    # Optional debugging: print the range of the angle columns in this window.
    if args.debug_stats:
        for k in ('green_angle_sin', 'green_angle_cos', 'red_angle_sin', 'red_angle_cos'):
            if k not in col_idx:
                continue
            column = truths[:, col_idx[k]]
            lo, hi = float(column.min()), float(column.max())
            print(f"[debug] {k}: min={lo:.4f}, max={hi:.4f}")

    # Prepare the animation; the coordinate mapping is chosen once, not per frame.
    fig, ax = plt.subplots(figsize=(5, 5))
    render = draw_frame if args.orient == 'down' else _draw_frame_right

    def update(i):
        t = frame_ts[i]
        title = f't={int(t)} (frame {i+1}/{N_frames})'
        render(ax, truths[i], preds[i], col_idx, title)
        return []

    anim = FuncAnimation(fig, update, frames=N_frames, interval=1000.0 / args.fps, blit=False)

    # Save. os.path.dirname is '' for a bare filename, and os.makedirs('') raises.
    out_dir = os.path.dirname(args.out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    try:
        anim.save(args.out_path, writer=FFMpegWriter(fps=args.fps, bitrate=1800))
    except Exception as e:  # broad on purpose: any mp4 failure (e.g. missing ffmpeg) falls back to GIF
        gif_path = os.path.splitext(args.out_path)[0] + '.gif'
        print(f'[warn] 写入 mp4 失败(可能缺少 ffmpeg{e}\n回退为 GIF -> {gif_path}')
        anim.save(gif_path, writer=PillowWriter(fps=args.fps))
        print(f'Saved GIF: {gif_path}')
    else:
        print(f'Saved video: {args.out_path}')
    plt.close(fig)
if __name__ == "__main__":
    # Run the CLI only when executed as a script or via `python -m`, not on import.
    main()