#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
启动一个仅 CUDA 的 HTTP 推理服务,输入/输出均为像素坐标。
内部将像素轨迹拟合圆并转换为单位 (sin,cos) 以适配训练空间,
并把预测结果再还原为像素坐标返回。

POST /predict
body:
{
  "sequences": [               # 批量,每个序列为 [S_in, 2] 像素 [x,y]
    [[x, y], [x, y], ...],
    ...
  ],
  "steps": 30,                 # 预测步数 (>0)
  "centers": [[cx,cy], ...],   # 可选 [B,2],若已标定则可传入
  "radii": [r, ...]            # 可选 [B]
}

返回:
{
  "pred_xy": [[[x,y], ...], ...],  # [B, steps, 2] 像素
  "device": "cuda",
  "centers": [[cx,cy], ...],       # 用于本次变换的参数
  "radii": [r, ...]
}

运行示例:
  uv run serve_pendulum.py --ckpt ckpt/pendulum_tf_base.pt --host 0.0.0.0 --port 8000
"""
import argparse
import math
import os
from typing import List, Optional

import numpy as np
import torch
import torch.nn as nn
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field

# Support both package-relative import and direct script execution.
try:
    from .geom_utils import batch_pixels_to_unit_sincos, unit_sincos_to_pixels  # type: ignore
except Exception:  # noqa
    from geom_utils import batch_pixels_to_unit_sincos, unit_sincos_to_pixels

# ----------------- Modules identical to those used at training time -----------------
class PositionalEncoding(nn.Module):
    """Additive sinusoidal positional encoding (Vaswani et al., 2017).

    A fixed table of shape [1, max_len, d_model] is precomputed once and
    registered as a buffer so it follows the module across devices.
    """

    def __init__(self, d_model: int, max_len: int = 4096):
        super().__init__()
        positions = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)
        freqs = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32)
            * (-math.log(10000.0) / d_model)
        )
        angles = positions * freqs
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)
        # Stored as [1, L, D] so it broadcasts over the batch dimension.
        self.register_buffer("pe", table.unsqueeze(0))

    def forward(self, x):
        """Add the positional table to *x* of shape [B, S, D] (S <= max_len)."""
        seq_len = x.size(1)
        return x + self.pe[:, :seq_len, :]
class TimeSeriesTransformer(nn.Module):
    """Encoder-decoder Transformer for 2-D time-series forecasting.

    Linear projections map inputs/targets into the model dimension, a shared
    sinusoidal positional encoding is added, and a batch-first
    ``nn.Transformer`` performs the sequence-to-sequence mapping.
    """

    def __init__(self, in_dim=2, d_model=256, nhead=4, enc_layers=3, dec_layers=3, d_ff=512, dropout=0.1, out_dim=2):
        super().__init__()
        self.in_proj = nn.Linear(in_dim, d_model)
        self.tgt_proj = nn.Linear(out_dim, d_model)
        self.out_proj = nn.Linear(d_model, out_dim)
        self.pos_enc = PositionalEncoding(d_model)
        self.tf = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=enc_layers,
            num_decoder_layers=dec_layers,
            dim_feedforward=d_ff,
            dropout=dropout,
            batch_first=True,
        )

    def forward(self, src, tgt_in, tgt_mask=None):
        """Run one encoder-decoder pass.

        src:    [B, S_in, in_dim] encoder inputs.
        tgt_in: [B, S_out, out_dim] decoder inputs (teacher forcing/rollout).
        Returns a tensor of shape [B, S_out, out_dim].
        """
        encoded_src = self.pos_enc(self.in_proj(src))
        encoded_tgt = self.pos_enc(self.tgt_proj(tgt_in))
        hidden = self.tf(encoded_src, encoded_tgt, tgt_mask=tgt_mask)  # [B, S_out, D]
        return self.out_proj(hidden)
# ----------------- Inference engine (CUDA only) -----------------
class InferenceEngine:
    """CUDA-only inference wrapper around ``TimeSeriesTransformer``.

    Restores model weights and normalization statistics (mu/std) from a
    checkpoint and exposes an autoregressive :meth:`rollout` operating in
    the raw training space (unit (sin, cos) vectors).
    """

    def __init__(self, ckpt_path: str):
        """Load *ckpt_path* onto the CUDA device.

        Raises:
            RuntimeError: if no CUDA device is available.
            FileNotFoundError: if the checkpoint path does not exist.
        """
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA 不可用:该服务仅支持 CUDA。")
        self.device = torch.device("cuda")

        if not os.path.exists(ckpt_path):
            raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")

        # NOTE(security): weights_only=False deserializes arbitrary pickle
        # payloads — only load checkpoints from trusted local sources.
        payload = torch.load(ckpt_path, map_location=self.device, weights_only=False)

        default_cfg = dict(d_model=256, nhead=4, enc_layers=3, dec_layers=3, d_ff=512)
        ckpt_cfg = payload.get("cfg", default_cfg)
        self.net = TimeSeriesTransformer(in_dim=2, out_dim=2, **ckpt_cfg).to(self.device)
        self.net.load_state_dict(payload["model"], strict=True)
        self.net.eval()
        # Gradients are disabled locally via @torch.no_grad() on rollout()
        # instead of the process-global torch.set_grad_enabled(False), which
        # would silently affect unrelated code in the same process.

        # Normalization statistics saved at training time (numpy arrays).
        self.mu = payload["mu"]
        self.std = payload["std"]
        # Cast to float32 so arithmetic with float32 model inputs never
        # promotes to float64 (checkpoints may store float64 numpy arrays).
        self.mu_t = torch.from_numpy(self.mu).to(self.device).float()
        self.std_t = torch.from_numpy(self.std).to(self.device).float()

    @staticmethod
    def _unitize_t(t: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        """Project vectors onto the unit circle; safe for near-zero norms."""
        n = torch.clamp(torch.linalg.norm(t, dim=-1, keepdim=True), min=eps)
        return t / n

    def _to_raw(self, z: torch.Tensor) -> torch.Tensor:
        """Standardized space -> raw space."""
        return z * self.std_t + self.mu_t

    def _to_std(self, z: torch.Tensor) -> torch.Tensor:
        """Raw space -> standardized space."""
        return (z - self.mu_t) / self.std_t

    @torch.no_grad()
    def rollout(self, xb_raw: torch.Tensor, steps: int) -> torch.Tensor:
        """Autoregressively predict *steps* future points.

        Args:
            xb_raw: [B, S_in, 2] in raw space (unit (sin, cos) vectors,
                the space the model was trained on).
            steps:  number of future points to generate (> 0).

        Returns:
            Raw-space predictions of shape [B, steps, 2].
        """
        xb_raw = self._unitize_t(xb_raw)   # snap inputs onto the unit circle
        xb_std = self._to_std(xb_raw)      # -> standardized space

        tgt_tokens = xb_std[:, -1:, :2].clone()  # seed decoder with last input
        for _ in range(steps):
            mask = nn.Transformer.generate_square_subsequent_mask(tgt_tokens.size(1)).to(self.device)
            out = self.net(xb_std, tgt_tokens, tgt_mask=mask)  # standardized space
            next_std = out[:, -1:, :]
            # Re-project every prediction onto the unit circle before feeding
            # it back, so rollout error cannot drift off the manifold.
            next_raw = self._unitize_t(self._to_raw(next_std))
            tgt_tokens = torch.cat([tgt_tokens, self._to_std(next_raw)], dim=1)

        pred_std = tgt_tokens[:, 1:, :]    # drop the seed token
        pred_raw = self._to_raw(pred_std)
        return self._unitize_t(pred_raw)   # final unit-circle projection
# ----------------- HTTP layer (only /predict) -----------------
class PredictRequest(BaseModel):
    """Request body for POST /predict (all coordinates in pixels)."""
    # Batched trajectories from image tracking, each [S_in, 2] pixel [x, y].
    sequences: List[List[List[float]]] = Field(
        ..., description="批量序列:每个为 [S_in,2] 的像素 [x,y] 轨迹(来自图像跟踪)"
    )
    # Number of autoregressive prediction steps; must be positive.
    steps: int = Field(..., gt=0, description="预测步数 (>0)")
    # Optional: a client that has already calibrated the circle can pass
    # (centers, radii) to skip re-fitting on the server.
    centers: Optional[List[List[float]]] = Field(
        default=None, description="可选 [B,2],每条序列的圆心 (cx,cy)"
    )
    radii: Optional[List[float]] = Field(
        default=None, description="可选 [B],每条序列的半径 r"
    )
class PredictResponse(BaseModel):
    """Response body for POST /predict."""
    pred_xy: List[List[List[float]]]  # future pixel coordinates [B, steps, 2]
    device: str  # device used for inference (always "cuda" in this service)
    centers: Optional[List[List[float]]] = None  # circle centers used for the transform, [B, 2]
    radii: Optional[List[float]] = None  # circle radii used for the transform, [B]
# FastAPI application exposing the single /predict endpoint.
app = FastAPI(title="Pendulum Transformer Inference (CUDA Only)", version="1.0.0")

# CORS: allow any origin (credentials are disabled, so "*" is permitted).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

ENGINE: InferenceEngine | None = None  # injected at runtime by main()
@app.post("/predict", response_model=PredictResponse)
def predict(req: PredictRequest):
    """Predict future pendulum positions in pixel coordinates.

    Pipeline: pixels -> unit (sin, cos) -> autoregressive rollout in the
    model's training space -> pixels, using per-sequence circle parameters
    that are either client-supplied or fitted here.

    Raises:
        HTTPException 400: malformed input.
        HTTPException 500: engine not initialized.
    """
    if ENGINE is None:
        raise HTTPException(500, "Engine not initialized")

    try:
        pix = np.array(req.sequences, dtype=np.float32)  # [B, S_in, 2]
    except Exception as e:
        # Ragged or non-numeric input cannot form a rectangular array;
        # chain the cause so logs show the original numpy error.
        raise HTTPException(400, f"Invalid 'sequences': {e}") from e

    if pix.ndim != 3 or pix.shape[-1] != 2:
        raise HTTPException(400, f"'sequences' must be [B, S_in, 2], got {pix.shape}")
    if req.steps <= 0:
        # Defense in depth: pydantic's gt=0 already enforces this.
        raise HTTPException(400, "'steps' must be > 0")
    if pix.shape[1] < 1:
        raise HTTPException(400, "Each sequence must have S_in >= 1")

    # 1) Pixels -> unit (sin, cos); obtain per-sequence circle (cx, cy, r).
    if req.centers is not None and req.radii is not None:
        # Use calibration parameters supplied by the client.
        centers = np.array(req.centers, dtype=np.float32)  # [B, 2]
        radii = np.array(req.radii, dtype=np.float32)      # [B]
        if centers.shape[0] != pix.shape[0] or centers.shape[1] != 2 or radii.shape[0] != pix.shape[0]:
            raise HTTPException(400, "centers 应为 [B,2] 且 radii 为 [B]")
        radii = np.clip(radii, 1e-6, None)  # guard against division by zero
        # Training-time convention: sin = (x - cx) / r, cos = (cy - y) / r
        # (note the vertical term is flipped: cy - y, not y - cy).
        sin = (pix[..., 0] - centers[:, 0:1]) / radii[:, None]
        cos = (centers[:, 1:2] - pix[..., 1]) / radii[:, None]
        sc = np.stack([sin, cos], axis=-1).astype(np.float32)  # [B, S, 2]
    else:
        # No calibration provided: fit a circle per sequence server-side.
        sc, centers, radii = batch_pixels_to_unit_sincos(pix)

    # 2) Run the engine; it expects raw-space inputs (unit sin/cos vectors).
    x = torch.from_numpy(sc).to(ENGINE.device)
    pred_sc = ENGINE.rollout(x, steps=req.steps)  # [B, steps, 2] unit (sin, cos)
    pred_sc_np = pred_sc.detach().cpu().numpy()

    # 3) Unit (sin, cos) -> pixels, using the same circle parameters.
    pred_pix = unit_sincos_to_pixels(pred_sc_np, centers, radii)  # [B, steps, 2]

    return PredictResponse(
        pred_xy=pred_pix.tolist(), device="cuda",
        centers=centers.tolist(), radii=radii.tolist()
    )
def main():
    """Parse CLI arguments, build the CUDA engine, and serve the HTTP app."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--ckpt", type=str, required=True, help="Path to checkpoint (.pt)")
    cli.add_argument("--host", type=str, default="0.0.0.0")
    cli.add_argument("--port", type=int, default=8000)
    opts = cli.parse_args()

    # Inject the engine into the module-level slot read by the endpoint.
    global ENGINE
    ENGINE = InferenceEngine(opts.ckpt)
    print(f"[serve] Loaded ckpt from {opts.ckpt} on device=cuda")

    uvicorn.run(app, host=opts.host, port=opts.port, log_level="info")


if __name__ == "__main__":
    main()