# NOTE: removed non-Python page-export residue (timestamp, line/size counts,
# "Raw Permalink Blame History", Unicode-warning banner) that preceded the
# actual source and would break parsing.
# -*- coding: utf-8 -*-
"""
serve_api.py
- FastAPI 推理服务:加载 best_model.pt 与 scaler.npz
- /predict接收最近 L=180 帧(按训练列顺序),返回未来 H 帧预测(原尺度)
- /health健康检查/meta返回模型与列信息
运行示例:
uv run -m uvicorn chaos_pdl.serve_api:app --host 0.0.0.0 --port 8000 --workers 1
请求示例:
POST /predict
{
"frames": [[...F个数值...], ..., 共180行]
}
可选:若你以对象形式传参(带列名),也支持:
{
"frame_objects": [
{"delta_time": 0.02, "green_angle_sin": 0.1, ...},
... 共180个对象 ...
]
}
"""
from typing import List, Optional, Dict, Any
import os
import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
try:
from chaos_pdl.pendulum_transformer import TimeSeriesTransformer
except ModuleNotFoundError:
from pendulum_transformer import TimeSeriesTransformer
class PredictRequest(BaseModel):
    """Request schema for /predict.

    Exactly one of ``frames`` / ``frame_objects`` must be provided; this is
    enforced in the /predict handler, not by the model itself.
    """
    # Row-major [L, F] matrix; column order must match scaler.npz "columns".
    frames: Optional[List[List[float]]] = Field(
        default=None, description="最近L帧形状 [L,F],按 scaler.npz 的列顺序提供"
    )
    # Alternative form: L dicts of {column name: value}; reordered server-side.
    frame_objects: Optional[List[Dict[str, float]]] = Field(
        default=None, description="最近L帧形状 [L],每项是 {列名: 数值} 对象"
    )
class PredictResponse(BaseModel):
    """Response schema for /predict."""
    horizon: int  # number of predicted future frames H
    columns: List[str]  # feature column order, matching scaler.npz
    predictions: List[List[float]]  # [H, F] in the original (de-normalized) scale
def load_artifacts(model_path: str, scaler_path: str, device: torch.device):
    """Load the trained checkpoint and the normalization statistics.

    Parameters
    ----------
    model_path : path to best_model.pt (torch checkpoint holding "config"
        and "model_state").
    scaler_path : path to scaler.npz (arrays "mean", "std", "columns").
    device : torch device the model is moved to.

    Returns
    -------
    (model, mean, std, columns, horizon) — model is in eval() mode.

    Raises
    ------
    FileNotFoundError
        If either artifact path is missing, with a clear message instead of
        an opaque traceback from inside np.load / torch.load.
    """
    if not os.path.exists(scaler_path):
        raise FileNotFoundError(f"scaler file not found: {scaler_path}")
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"model file not found: {model_path}")
    sc = np.load(scaler_path, allow_pickle=True)
    mean = sc["mean"].astype(np.float32)
    std = sc["std"].astype(np.float32)
    # Guard against zero-variance columns: dividing by std == 0 during
    # normalization would produce inf/NaN; treat such columns as scale 1
    # (the value then round-trips through mean unchanged).
    std = np.where(std == 0.0, np.float32(1.0), std)
    columns = list(sc["columns"].tolist())
    # NOTE(review): torch.load with default settings unpickles arbitrary
    # objects — only load checkpoints from trusted sources.
    ckpt = torch.load(model_path, map_location=device)
    cfg = ckpt["config"]
    H = int(cfg["horizon"])
    F = len(columns)
    model = TimeSeriesTransformer(
        in_features=F,
        d_model=cfg["d_model"],
        nhead=cfg["nhead"],
        num_layers=cfg["num_layers"],
        dim_feedforward=cfg["dim_ff"],
        dropout=cfg["dropout"],
        horizon=H,
    ).to(device)
    model.load_state_dict(ckpt["model_state"])
    model.eval()
    return model, mean, std, columns, H
# Artifact locations and the input window length are overridable via
# environment variables so the same image can serve different training runs.
MODEL_PATH = os.environ.get("MODEL_PATH", "./outputs/chaos/best_model.pt")
SCALER_PATH = os.environ.get("SCALER_PATH", "./outputs/chaos/scaler.npz")
SEQ_LEN = int(os.environ.get("SEQ_LEN", 180))  # expected number of input frames L
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Loaded once at import time; each uvicorn worker loads its own copy, hence
# the suggested --workers 1 in the module docstring.
model, mean, std, columns, H = load_artifacts(MODEL_PATH, SCALER_PATH, device)
app = FastAPI(title="Pendulum Transformer Inference API", version="0.1.0")
@app.get("/health")
def health():
    """Liveness probe: report device, prediction horizon and window length."""
    payload = {
        "status": "ok",
        "device": str(device),
        "horizon": H,
        "seq_len": SEQ_LEN,
    }
    return payload
@app.get("/meta")
def meta():
    """Describe the model's I/O contract: column order, horizon, window length."""
    info = {
        "columns": columns,
        "horizon": H,
        "seq_len": SEQ_LEN,
    }
    return info
def frames_from_objects(objs: List[Dict[str, Any]], cols: List[str]) -> np.ndarray:
    """Convert a list of {column: value} dicts into a float32 [L, F] array.

    Values are taken in the order given by *cols*; a missing key or a
    non-numeric value is reported to the client as an HTTP 400.
    """
    rows = []
    try:
        for obj in objs:
            rows.append([float(obj[name]) for name in cols])
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"frame_objects 解析失败: {e}")
    return np.asarray(rows, dtype=np.float32)
@app.post("/predict", response_model=PredictResponse)
def predict(req: PredictRequest):
    """Predict the next H frames from the last SEQ_LEN observed frames.

    Accepts exactly one of ``frames`` ([L, F], columns in scaler order) or
    ``frame_objects`` (L dicts keyed by column name), and returns the H-step
    forecast mapped back to the original scale.
    """
    # Exactly one of the two input forms must be present.
    if (req.frames is None) == (req.frame_objects is None):
        raise HTTPException(status_code=400, detail="需要提供 frames 或 frame_objects 二选一")
    if req.frames is not None:
        # A ragged nested list makes np.asarray raise; surface it as a 400
        # to the client instead of an unhandled 500.
        try:
            x = np.asarray(req.frames, dtype=np.float32)
        except (ValueError, TypeError) as e:
            raise HTTPException(status_code=400, detail=f"frames 解析失败: {e}")
        if x.ndim != 2:
            raise HTTPException(status_code=400, detail="frames 必须是二维数组 [L,F]")
        if x.shape[1] != len(columns):
            raise HTTPException(status_code=400, detail=f"列数不匹配期望F={len(columns)}")
    else:
        x = frames_from_objects(req.frame_objects, columns)
    if x.shape[0] != SEQ_LEN:
        raise HTTPException(status_code=400, detail=f"帧数不匹配期望L={SEQ_LEN}")
    # Normalize with training-time statistics, run the model, then map the
    # predictions back to the original scale.
    x_norm = (x - mean) / std
    xb = torch.from_numpy(x_norm).unsqueeze(0).to(device)  # [1, L, F]
    with torch.no_grad():
        pred = model(xb).cpu().numpy()[0]  # [H, F] in normalized space
    pred_denorm = pred * std + mean
    return PredictResponse(horizon=H, columns=columns, predictions=pred_denorm.tolist())