# -*- coding: utf-8 -*-
"""
|
||
make_demo_video.py
|
||
- 读取单个 CSV + 模型与标准化,进行滚动 1-step 预测,生成 10 秒预测示意视频
|
||
- 画面:展示绿色/红色摆的真实位置(实线)与上一步预测到当前帧的预测位置(虚线)
|
||
|
||
使用示例:
|
||
uv run -m chaos_pdl.make_demo_video \
|
||
--csv_path chaos_pdl/data/metrics-20250920-000328_angles.csv \
|
||
--model_path outputs/chaos/best_model.pt \
|
||
--scaler_path outputs/chaos/scaler.npz \
|
||
--seq_len 180 \
|
||
--duration 10 \
|
||
--fps 30 \
|
||
--out_path outputs/chaos/demo.mp4
|
||
|
||
若系统未安装 ffmpeg,会自动回退为 GIF:outputs/chaos/demo.gif
|
||
"""
|
||
|
||
import argparse
|
||
import os
|
||
from typing import List, Tuple
|
||
|
||
import numpy as np
|
||
import pandas as pd
|
||
import torch
|
||
import matplotlib.pyplot as plt
|
||
from matplotlib.animation import FuncAnimation, FFMpegWriter, PillowWriter
|
||
|
||
try:
|
||
from chaos_pdl.pendulum_transformer import TimeSeriesTransformer, WindowedTSDataset
|
||
except ModuleNotFoundError:
|
||
from pendulum_transformer import TimeSeriesTransformer, WindowedTSDataset
|
||
|
||
|
||
def read_align_csv(path: str, expect_columns: List[str]) -> pd.DataFrame:
    """Load ``path`` and return its numeric columns in ``expect_columns`` order.

    If a ``frame_index`` column is present, rows are sorted by it and the
    column is then dropped (the original note says this aligns with the
    training logic). Non-numeric columns are discarded. The returned frame
    has a fresh 0..N-1 index.

    Raises:
        ValueError: if the set of numeric columns differs from
            ``expect_columns``.
    """
    raw = pd.read_csv(path)

    if "frame_index" in raw.columns:
        # Sort by frame order first, then remove the index column itself.
        raw = raw.sort_values("frame_index")
        raw = raw.drop(columns=["frame_index"])

    numeric = raw.select_dtypes(include=[np.number]).copy()
    found = list(numeric.columns)

    # Only set membership is compared; column order is fixed below.
    if set(found) == set(expect_columns):
        return numeric[expect_columns].reset_index(drop=True)

    raise ValueError(
        f"CSV 数值列与训练不一致\n 期望: {expect_columns}\n 实际: {found}\n"
    )
|
||
|
||
|
||
def build_one_step_pred_sequence(values_norm: np.ndarray, model: "TimeSeriesTransformer", device: torch.device,
                                 seq_len: int, frame_idx_list: np.ndarray, batch_size: int = 512) -> np.ndarray:
    """Run rolling 1-step predictions for each target frame.

    For every target frame ``t`` the window ``values_norm[t - seq_len : t]``
    (whose last row is ``t - 1``) is fed to the model. The first horizon step
    ``out[:, 0, :]`` is taken as the prediction *for frame t*.

    Args:
        values_norm: standardized series indexed by row (assumed ``[T, F]``).
        model: forecaster called as ``model(xb)`` on a ``[B, seq_len, F]``
            float32 tensor; assumed to return ``[B, H, F_out]``.
            It is switched to eval mode.
        device: device each batch is moved to before the forward pass.
        seq_len: window length L (must be >= 1).
        frame_idx_list: target frame indices ``t``; each must satisfy
            ``seq_len <= t <= len(values_norm)``.
        batch_size: number of windows per forward pass (must be >= 1).

    Returns:
        Predictions still in the standardized space, one row per target
        frame, in ``frame_idx_list`` order. The caller de-standardizes them.

    Raises:
        ValueError: if ``seq_len`` or ``batch_size`` is < 1, or if any
            window would fall outside ``values_norm``.
    """
    model.eval()

    if seq_len < 1:
        raise ValueError(f"seq_len must be >= 1, got {seq_len}")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    targets = np.asarray(frame_idx_list, dtype=np.int64)
    starts = targets - seq_len  # window start for each t (window ends at t-1)

    if starts.size == 0:
        # Nothing to predict. Width assumes the model outputs as many
        # features as it consumes -- NOTE(review): confirm if that ever differs.
        return np.empty((0, values_norm.shape[-1]), dtype=np.float32)

    # Explicit checks instead of `assert` (stripped under -O). The upper bound
    # guards against short slices that would make np.stack fail obscurely.
    if starts.min() < 0 or targets.max() > len(values_norm):
        raise ValueError(
            f"every target frame t must satisfy {seq_len} <= t <= {len(values_norm)}; "
            f"got range [{targets.min()}, {targets.max()}]"
        )

    preds_list = []
    with torch.no_grad():
        for i in range(0, len(starts), batch_size):
            chunk = starts[i:i + batch_size]
            # Stack windows into a [B, L, F] batch.
            batch = np.stack([values_norm[s:s + seq_len] for s in chunk], axis=0).astype(np.float32)
            xb = torch.from_numpy(batch).to(device)
            out = model(xb)
            # Step 0 of the horizon is the prediction for frame t itself.
            preds_list.append(out[:, 0, :].cpu().numpy())
    return np.concatenate(preds_list, axis=0)
|
||
|
||
|
||
def angle_from_sin_cos(sin_v: float, cos_v: float) -> float:
    """Recover an angle in radians, in ``[-pi, pi]``, from its sine and cosine."""
    theta = np.arctan2(sin_v, cos_v)
    return float(theta)
|
||
|
||
|
||
def endpoint_from_angle(theta: float, length: float = 1.0) -> Tuple[float, float]:
    """Return the tip ``(x, y)`` of a rod of ``length`` hung from the origin.

    Mapping: ``x = length * sin(theta)``, ``y = -length * cos(theta)``, so
    ``theta = 0`` points straight down.
    """
    x_tip = length * np.sin(theta)
    y_tip = -length * np.cos(theta)
    return x_tip, y_tip
|
||
|
||
|
||
def draw_frame(ax, true_row: np.ndarray, pred_row: np.ndarray, col_idx: dict, title: str):
    """Redraw ``ax`` with the true (solid) and predicted (dashed) pendulum rods.

    Angles are rebuilt from each row's ``*_angle_sin`` / ``*_angle_cos``
    entries (looked up through ``col_idx``). They are mapped with
    ``endpoint_from_angle``, where angle 0 points down. The axes are cleared
    and fixed to ``[-1.2, 1.2]`` on both axes.
    """
    def tip(row, sin_key, cos_key):
        # Rod endpoint for one pendulum in one row.
        return endpoint_from_angle(angle_from_sin_cos(row[col_idx[sin_key]], row[col_idx[cos_key]]))

    ax.clear()
    ax.set_aspect('equal', 'box')
    ax.set_xlim(-1.2, 1.2)
    ax.set_ylim(-1.2, 1.2)
    ax.grid(True, alpha=0.2)

    green_keys = ('green_angle_sin', 'green_angle_cos')
    red_keys = ('red_angle_sin', 'red_angle_cos')

    # All endpoints are computed before anything is plotted.
    true_green = tip(true_row, *green_keys)
    true_red = tip(true_row, *red_keys)
    pred_green = tip(pred_row, *green_keys)
    pred_red = tip(pred_row, *red_keys)

    # Each layer: (green tip, red tip, green label, red label, line kwargs, marker kwargs).
    # True rods go first so the legend order is unchanged.
    layers = (
        (true_green, true_red, 'green-true', 'red-true',
         {'lw': 3}, {'s': 40}),
        (pred_green, pred_red, 'green-pred', 'red-pred',
         {'lw': 2, 'ls': '--', 'alpha': 0.7}, {'s': 30, 'alpha': 0.7}),
    )
    for (x_g, y_g), (x_r, y_r), green_label, red_label, line_kw, dot_kw in layers:
        ax.plot([0, x_g], [0, y_g], color='tab:green', label=green_label, **line_kw)
        ax.plot([0, x_r], [0, y_r], color='tab:red', label=red_label, **line_kw)
        ax.scatter([x_g, x_r], [y_g, y_r], color=['tab:green', 'tab:red'], **dot_kw)

    ax.legend(loc='upper right')
    ax.set_title(title)
|
||
|
||
|
||
def _parse_args() -> argparse.Namespace:
    """Parse the command line and reject non-positive size arguments."""
    ap = argparse.ArgumentParser()
    ap.add_argument('--csv_path', type=str, required=True, help='输入 CSV(需要包含训练使用的列)')
    ap.add_argument('--model_path', type=str, default='outputs/chaos/best_model.pt')
    ap.add_argument('--scaler_path', type=str, default='outputs/chaos/scaler.npz')
    ap.add_argument('--seq_len', type=int, default=180)
    ap.add_argument('--duration', type=int, default=10, help='视频秒数')
    ap.add_argument('--fps', type=int, default=30)
    ap.add_argument('--out_path', type=str, default='outputs/chaos/demo.mp4')
    ap.add_argument('--start_from_end', action='store_true', help='从序列末尾倒推所需时长;不设则从 seq_len 处开始')
    ap.add_argument('--batch_size', type=int, default=512)
    ap.add_argument('--orient', type=str, default='down', choices=['down', 'right'],
                    help='坐标映射:down -> (x=sin(theta), y=-cos(theta)); right -> (x=cos(theta), y=sin(theta))')
    ap.add_argument('--debug_stats', action='store_true', help='打印窗口内若干关键列的最小/最大值,排查坐标恒正等问题')
    args = ap.parse_args()
    # Non-positive values would otherwise surface much later as an empty frame
    # range or a zero-step batch loop with an opaque error.
    for name in ('seq_len', 'duration', 'fps', 'batch_size'):
        if getattr(args, name) <= 0:
            ap.error(f'--{name} must be a positive integer')
    return args


def _load_scaler(scaler_path: str):
    """Return ``(mean, std, columns)`` read from the training scaler ``.npz``."""
    # allow_pickle presumably because 'columns' was saved as an object array;
    # only load scaler files you trust. `with` closes the NpzFile handle.
    with np.load(scaler_path, allow_pickle=True) as sc:
        mean = sc['mean'].astype(np.float32)
        std = sc['std'].astype(np.float32)
        columns = list(sc['columns'].tolist())
    return mean, std, columns


def _load_model(model_path: str, n_features: int, device: torch.device):
    """Rebuild the TimeSeriesTransformer from the checkpoint config and load its weights."""
    # NOTE(review): torch.load may unpickle objects from the file (depending on
    # the installed torch's weights_only default); only load trusted checkpoints.
    ckpt = torch.load(model_path, map_location=device)
    cfg = ckpt['config']
    model = TimeSeriesTransformer(
        in_features=n_features,
        d_model=cfg['d_model'],
        nhead=cfg['nhead'],
        num_layers=cfg['num_layers'],
        dim_feedforward=cfg['dim_ff'],
        dropout=cfg['dropout'],
        horizon=int(cfg['horizon']),
    ).to(device)
    model.load_state_dict(ckpt['model_state'])
    model.eval()
    return model


def _select_frame_range(T: int, seq_len: int, n_frames: int, from_end: bool) -> np.ndarray:
    """Return the int64 target-frame indices to render (clipped to the data)."""
    if from_end:
        # Count back from the last row; never start before a full window exists.
        end_t = T - 1
        start_t = max(seq_len, end_t - n_frames + 1)
    else:
        start_t = seq_len
        end_t = min(T - 1, start_t + n_frames - 1)
    return np.arange(start_t, end_t + 1, dtype=np.int64)


def _print_debug_stats(truths: np.ndarray, col_idx: dict) -> None:
    """Print min/max of whichever sin/cos angle columns are present."""
    for k in ('green_angle_sin', 'green_angle_cos', 'red_angle_sin', 'red_angle_cos'):
        if k not in col_idx:
            continue
        col = truths[:, col_idx[k]]
        lo, hi = float(col.min()), float(col.max())
        print(f"[debug] {k}: min={lo:.4f}, max={hi:.4f}")


def _draw_frame_right(ax, true_row, pred_row, col_idx, title):
    """Draw one frame with the 'right' mapping ``x=cos(theta), y=sin(theta)``."""
    def tip(row, which):
        theta = float(np.arctan2(row[col_idx[f'{which}_angle_sin']], row[col_idx[f'{which}_angle_cos']]))
        return (np.cos(theta), np.sin(theta))

    ax.clear()
    ax.set_aspect('equal', 'box')
    ax.set_xlim(-1.2, 1.2)
    ax.set_ylim(-1.2, 1.2)
    ax.grid(True, alpha=0.2)

    gx, gy = tip(true_row, 'green')
    rx, ry = tip(true_row, 'red')
    pgx, pgy = tip(pred_row, 'green')
    prx, pry = tip(pred_row, 'red')

    # True rods: solid; predictions: dashed and semi-transparent.
    ax.plot([0, gx], [0, gy], color='tab:green', lw=3, label='green-true')
    ax.plot([0, rx], [0, ry], color='tab:red', lw=3, label='red-true')
    ax.scatter([gx, rx], [gy, ry], color=['tab:green', 'tab:red'], s=40)
    ax.plot([0, pgx], [0, pgy], color='tab:green', lw=2, ls='--', alpha=0.7, label='green-pred')
    ax.plot([0, prx], [0, pry], color='tab:red', lw=2, ls='--', alpha=0.7, label='red-pred')
    ax.scatter([pgx, prx], [pgy, pry], color=['tab:green', 'tab:red'], s=30, alpha=0.7)
    ax.legend(loc='upper right')
    ax.set_title(title)


def _save_animation(anim, out_path: str, fps: int) -> None:
    """Save ``anim`` as MP4 via ffmpeg; on any failure fall back to a GIF."""
    out_dir = os.path.dirname(out_path)
    # os.makedirs('') raises FileNotFoundError, so only create a directory
    # when out_path actually contains one (e.g. not for plain 'demo.mp4').
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    try:
        writer = FFMpegWriter(fps=fps, bitrate=1800)
        anim.save(out_path, writer=writer)
        print(f'Saved video: {out_path}')
    except Exception as e:
        # Deliberate best-effort fallback at the CLI boundary (e.g. missing ffmpeg).
        gif_path = os.path.splitext(out_path)[0] + '.gif'
        print(f'[warn] 写入 mp4 失败(可能缺少 ffmpeg):{e}\n回退为 GIF -> {gif_path}')
        writer = PillowWriter(fps=fps)
        anim.save(gif_path, writer=writer)
        print(f'Saved GIF: {gif_path}')


def main():
    """CLI entry point: render true vs. rolling 1-step predicted pendulum positions.

    Loads the scaler and checkpoint, and reads and aligns the CSV. It then
    predicts each selected frame ``t`` from the window ``[t - seq_len, t)``
    and saves an animation to ``--out_path``, with a GIF fallback.

    Raises:
        ValueError: if the CSV has fewer than ``seq_len + 2`` rows (or, from
            ``read_align_csv``, mismatched columns).
    """
    args = _parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    mean, std, columns = _load_scaler(args.scaler_path)
    col_idx = {c: i for i, c in enumerate(columns)}

    model = _load_model(args.model_path, len(columns), device)

    # Data, in the same column order as training.
    df = read_align_csv(args.csv_path, columns)
    values = df.values.astype(np.float32)
    T = len(values)
    if T < args.seq_len + 2:
        raise ValueError(f'数据太短:需要至少 {args.seq_len + 2} 行,当前 {T}')

    frame_ts = _select_frame_range(T, args.seq_len, args.duration * args.fps, args.start_from_end)
    n_frames = len(frame_ts)

    # Standardize, predict each frame t from window [t-L, t), then de-standardize.
    values_norm = (values - mean) / std
    preds_norm = build_one_step_pred_sequence(values_norm, model, device, args.seq_len, frame_ts, args.batch_size)
    preds = preds_norm * std + mean

    truths = values[frame_ts]

    if args.debug_stats:
        _print_debug_stats(truths, col_idx)

    fig, ax = plt.subplots(figsize=(5, 5))
    draw = draw_frame if args.orient == 'down' else _draw_frame_right

    def update(i):
        t = frame_ts[i]
        title = f't={int(t)} (frame {i+1}/{n_frames})'
        draw(ax, truths[i], preds[i], col_idx, title)
        return []

    anim = FuncAnimation(fig, update, frames=n_frames, interval=1000.0 / args.fps, blit=False)
    try:
        _save_animation(anim, args.out_path, args.fps)
    finally:
        # Release the figure once the frames have been written.
        plt.close(fig)


if __name__ == '__main__':
    main()
|