# -*- coding: utf-8 -*-
"""
serve_api.py
- FastAPI 推理服务:加载 best_model.pt 与 scaler.npz
- /predict:接收最近 L=180 帧(按训练列顺序),返回未来 H 帧预测(原尺度)
- /health:健康检查;/meta:返回模型与列信息

运行示例:
  uv run -m uvicorn chaos_pdl.serve_api:app --host 0.0.0.0 --port 8000 --workers 1

请求示例:
  POST /predict
  {
    "frames": [[...F个数值...], ..., 共180行]
  }

可选:若你以对象形式传参(带列名),也支持:
  {
    "frame_objects": [
      {"delta_time": 0.02, "green_angle_sin": 0.1, ...},
      ... 共180个对象 ...
    ]
  }
"""

import os
from typing import Any, Dict, List, Optional

import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field

try:
|
||
from chaos_pdl.pendulum_transformer import TimeSeriesTransformer
|
||
except ModuleNotFoundError:
|
||
from pendulum_transformer import TimeSeriesTransformer
|
||
|
||
|
||
class PredictRequest(BaseModel):
|
||
frames: Optional[List[List[float]]] = Field(
|
||
default=None, description="最近L帧,形状 [L,F],按 scaler.npz 的列顺序提供"
|
||
)
|
||
frame_objects: Optional[List[Dict[str, float]]] = Field(
|
||
default=None, description="最近L帧,形状 [L],每项是 {列名: 数值} 对象"
|
||
)
|
||
|
||
|
||
class PredictResponse(BaseModel):
|
||
horizon: int
|
||
columns: List[str]
|
||
predictions: List[List[float]] # [H,F] 原尺度
|
||
|
||
|
||
def load_artifacts(model_path: str, scaler_path: str, device: torch.device):
|
||
sc = np.load(scaler_path, allow_pickle=True)
|
||
mean = sc["mean"].astype(np.float32)
|
||
std = sc["std"].astype(np.float32)
|
||
columns = list(sc["columns"].tolist())
|
||
|
||
ckpt = torch.load(model_path, map_location=device)
|
||
cfg = ckpt["config"]
|
||
H = int(cfg["horizon"])
|
||
F = len(columns)
|
||
|
||
model = TimeSeriesTransformer(
|
||
in_features=F,
|
||
d_model=cfg["d_model"],
|
||
nhead=cfg["nhead"],
|
||
num_layers=cfg["num_layers"],
|
||
dim_feedforward=cfg["dim_ff"],
|
||
dropout=cfg["dropout"],
|
||
horizon=H,
|
||
).to(device)
|
||
model.load_state_dict(ckpt["model_state"])
|
||
model.eval()
|
||
return model, mean, std, columns, H
|
||
|
||
|
||
MODEL_PATH = os.environ.get("MODEL_PATH", "./outputs/chaos/best_model.pt")
|
||
SCALER_PATH = os.environ.get("SCALER_PATH", "./outputs/chaos/scaler.npz")
|
||
SEQ_LEN = int(os.environ.get("SEQ_LEN", 180))
|
||
|
||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||
model, mean, std, columns, H = load_artifacts(MODEL_PATH, SCALER_PATH, device)
|
||
|
||
app = FastAPI(title="Pendulum Transformer Inference API", version="0.1.0")
|
||
|
||
|
||
@app.get("/health")
|
||
def health():
|
||
return {"status": "ok", "device": str(device), "horizon": H, "seq_len": SEQ_LEN}
|
||
|
||
|
||
@app.get("/meta")
|
||
def meta():
|
||
return {"columns": columns, "horizon": H, "seq_len": SEQ_LEN}
|
||
|
||
|
||
def frames_from_objects(objs: List[Dict[str, Any]], cols: List[str]) -> np.ndarray:
|
||
try:
|
||
arr = np.array([[float(o[c]) for c in cols] for o in objs], dtype=np.float32)
|
||
except Exception as e:
|
||
raise HTTPException(status_code=400, detail=f"frame_objects 解析失败: {e}")
|
||
return arr
|
||
|
||
|
||
@app.post("/predict", response_model=PredictResponse)
|
||
def predict(req: PredictRequest):
|
||
# 取输入
|
||
if (req.frames is None) == (req.frame_objects is None):
|
||
raise HTTPException(status_code=400, detail="需要提供 frames 或 frame_objects 二选一")
|
||
|
||
if req.frames is not None:
|
||
x = np.asarray(req.frames, dtype=np.float32)
|
||
if x.ndim != 2:
|
||
raise HTTPException(status_code=400, detail="frames 必须是二维数组 [L,F]")
|
||
if x.shape[1] != len(columns):
|
||
raise HTTPException(status_code=400, detail=f"列数不匹配,期望F={len(columns)}")
|
||
else:
|
||
x = frames_from_objects(req.frame_objects, columns)
|
||
|
||
if x.shape[0] != SEQ_LEN:
|
||
raise HTTPException(status_code=400, detail=f"帧数不匹配,期望L={SEQ_LEN}")
|
||
|
||
# 标准化
|
||
x_norm = (x - mean) / std
|
||
xb = torch.from_numpy(x_norm).unsqueeze(0).to(device) # [1,L,F]
|
||
|
||
with torch.no_grad():
|
||
pred = model(xb).cpu().numpy()[0] # [H,F] 标准化空间
|
||
pred_denorm = pred * std + mean
|
||
return PredictResponse(horizon=H, columns=columns, predictions=pred_denorm.tolist())
|