2025-09-24 11:35:48 +08:00

239 lines
9.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
启动一个仅 CUDA 的 HTTP 推理服务,输入/输出均为像素坐标。
内部将像素轨迹拟合圆并转换为单位 (sin,cos) 以适配训练空间,
并把预测结果再还原为像素坐标返回。
POST /predict
body:
{
"sequences": [ # 批量,每个序列为 [S_in, 2] 像素 [x,y]
[[x, y], [x, y], ...],
...
],
"steps": 30, # 预测步数 (>0)
"centers": [[cx,cy], ...], # 可选 [B,2],若已标定则可传入
"radii": [r, ...] # 可选 [B]
}
返回:
{
"pred_xy": [[[x,y], ...], ...], # [B, steps, 2] 像素
"device": "cuda",
"centers": [[cx,cy], ...], # 用于本次变换的参数
"radii": [r, ...]
}
运行示例:
uv run serve_pendulum.py --ckpt ckpt/pendulum_tf_base.pt --host 0.0.0.0 --port 8000
"""
import argparse
import math
import os
from typing import List, Optional
import numpy as np
import torch
import torch.nn as nn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
import uvicorn
# Support both running this file directly as a script and importing it as part of a package.
try:
    from .geom_utils import batch_pixels_to_unit_sincos, unit_sincos_to_pixels  # type: ignore
except ImportError:
    # A relative import outside a package raises ImportError ("no known parent package"),
    # and a missing sibling module raises ModuleNotFoundError (an ImportError subclass).
    # Any other exception originates inside geom_utils itself and must not be masked.
    from geom_utils import batch_pixels_to_unit_sincos, unit_sincos_to_pixels
# ----------------- Modules that must match the training-time definitions -----------------
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a batch-first sequence.

    The table is registered as the buffer ``pe`` with shape [1, max_len, d_model].
    It is part of the state_dict, so its name and shape must match the checkpoint.
    Even columns hold sin(pos * freq) and odd columns hold cos(pos * freq).
    """

    def __init__(self, d_model: int, max_len: int = 4096):
        super().__init__()
        positions = torch.arange(max_len, dtype=torch.float32)[:, None]  # [max_len, 1]
        freqs = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model)
        )
        angles = positions * freqs
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)
        self.register_buffer("pe", table[None])  # [1, max_len, d_model]

    def forward(self, x):
        """Return ``x`` ([B, S, d_model], S <= max_len) plus the encodings of positions 0..S-1."""
        seq_len = x.shape[1]
        return x + self.pe[:, :seq_len]
class TimeSeriesTransformer(nn.Module):
    """Encoder-decoder Transformer that maps an input series to output tokens.

    ``src`` is projected in_dim -> d_model and ``tgt_in`` is projected out_dim -> d_model.
    Both receive sinusoidal positional encodings before entering a batch-first
    ``nn.Transformer``, and decoder states are projected back to out_dim.
    Submodule attribute names are state_dict keys and must match the checkpoint.
    """

    def __init__(self, in_dim=2, d_model=256, nhead=4, enc_layers=3, dec_layers=3, d_ff=512, dropout=0.1, out_dim=2):
        super().__init__()
        # Creation order is kept identical to the training definition.
        self.in_proj = nn.Linear(in_dim, d_model)
        self.out_proj = nn.Linear(d_model, out_dim)
        self.tgt_proj = nn.Linear(out_dim, d_model)
        self.pos_enc = PositionalEncoding(d_model)
        transformer_kwargs = dict(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=enc_layers,
            num_decoder_layers=dec_layers,
            dim_feedforward=d_ff,
            dropout=dropout,
            batch_first=True,
        )
        self.tf = nn.Transformer(**transformer_kwargs)

    def forward(self, src, tgt_in, tgt_mask=None):
        """Run encoder-decoder attention.

        Args:
            src: [B, S_in, in_dim] encoder input.
            tgt_in: [B, S_out, out_dim] decoder input tokens.
            tgt_mask: optional decoder self-attention mask (e.g. causal).
        Returns:
            [B, S_out, out_dim] decoder outputs.
        """
        encoded_src = self.pos_enc(self.in_proj(src))
        encoded_tgt = self.pos_enc(self.tgt_proj(tgt_in))
        hidden = self.tf(encoded_src, encoded_tgt, tgt_mask=tgt_mask)  # [B, S_out, d_model]
        return self.out_proj(hidden)
# ----------------- Inference engine (CUDA only) -----------------
class InferenceEngine:
    """Load a trained TimeSeriesTransformer checkpoint and run autoregressive rollouts.

    The checkpoint dict must contain "model" (state_dict), "mu" and "std"
    (normalization statistics). It may contain "cfg" (constructor kwargs).
    """

    def __init__(self, ckpt_path: str):
        """Load the checkpoint onto CUDA.

        Raises:
            RuntimeError: CUDA is not available.
            FileNotFoundError: ``ckpt_path`` does not exist.
        """
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA 不可用:该服务仅支持 CUDA。")
        self.device = torch.device("cuda")
        if not os.path.exists(ckpt_path):
            raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
        # weights_only=False disables torch.load's safe-unpickling restriction
        # (presumably needed for the numpy mu/std). Only load trusted, local checkpoints.
        payload = torch.load(ckpt_path, map_location=self.device, weights_only=False)
        default_cfg = dict(d_model=256, nhead=4, enc_layers=3, dec_layers=3, d_ff=512)
        # `or` also covers checkpoints that stored cfg=None.
        ckpt_cfg = payload.get("cfg") or default_cfg
        self.net = TimeSeriesTransformer(in_dim=2, out_dim=2, **ckpt_cfg).to(self.device)
        self.net.load_state_dict(payload["model"], strict=True)
        self.net.eval()
        # Freeze parameters so no autograd graph is recorded, regardless of the calling thread.
        self.net.requires_grad_(False)
        # NOTE: grad mode is thread-local and only affects the constructing thread.
        # FastAPI runs sync endpoints in worker threads, so rollout() disables autograd itself.
        torch.set_grad_enabled(False)
        # Normalization statistics exactly as stored in the checkpoint.
        self.mu = payload["mu"]
        self.std = payload["std"]
        # Force float32 to match the network parameters. float64 stats would promote
        # activations to float64 and make nn.Linear fail with a dtype mismatch.
        # as_tensor also accepts stats saved as torch tensors or lists.
        self.mu_t = torch.as_tensor(self.mu, dtype=torch.float32, device=self.device)
        self.std_t = torch.as_tensor(self.std, dtype=torch.float32, device=self.device)

    @staticmethod
    def _unitize_t(t: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        """Scale vectors along the last dim to unit L2 norm (norms clamped to >= eps)."""
        n = torch.clamp(torch.linalg.norm(t, dim=-1, keepdim=True), min=eps)
        return t / n

    def _to_raw(self, z: torch.Tensor) -> torch.Tensor:
        """Undo standardization: standardized -> raw space."""
        return z * self.std_t + self.mu_t

    def _to_std(self, z: torch.Tensor) -> torch.Tensor:
        """Standardize: raw -> (z - mu) / std."""
        return (z - self.mu_t) / self.std_t

    @torch.inference_mode()
    def rollout(self, xb_raw: torch.Tensor, steps: int) -> torch.Tensor:
        """Autoregressively predict ``steps`` future points.

        Runs under inference_mode. The constructor's set_grad_enabled(False) does not
        reach the worker threads that call this, and without this guard every request
        would build an autograd graph.

        Args:
            xb_raw: [B, S_in, 2] history in raw space, i.e. the (sin, cos) unit vectors
                used in training. Inputs are moved to the engine device as float32.
            steps: number of future points to generate (0 yields an empty time axis).
        Returns:
            [B, steps, 2] float32 raw-space predictions, each re-normalized to unit length.
        """
        xb_raw = xb_raw.to(device=self.device, dtype=torch.float32)
        xb_raw = self._unitize_t(xb_raw)  # snap inputs onto the unit circle
        xb_std = self._to_std(xb_raw)  # -> standardized space
        tgt_tokens = xb_std[:, -1:, :2].clone()  # seed the decoder with the last observed point
        for _ in range(steps):
            mask = nn.Transformer.generate_square_subsequent_mask(tgt_tokens.size(1)).to(self.device)
            out = self.net(xb_std, tgt_tokens, tgt_mask=mask)  # standardized space
            next_raw = self._unitize_t(self._to_raw(out[:, -1:, :]))  # unit-circle constraint
            tgt_tokens = torch.cat([tgt_tokens, self._to_std(next_raw)], dim=1)
        pred_raw = self._to_raw(tgt_tokens[:, 1:, :])  # drop the seed token
        return self._unitize_t(pred_raw)  # re-normalize before returning
# ----------------- HTTP layer (/predict only) -----------------
# Request body for POST /predict: pixel-space trajectories plus optional circle calibration.
# (Documented with comments rather than a class docstring, because pydantic would copy a
# docstring into the OpenAPI schema.)
class PredictRequest(BaseModel):
    # Batch of trajectories; each is a list of [x, y] pixel points.
    sequences: List[List[List[float]]] = Field(
        ..., description="批量序列:每个为 [S_in,2] 的像素 [x,y] 轨迹(来自图像跟踪)"
    )
    # Number of future points to predict; pydantic rejects values <= 0.
    steps: int = Field(..., gt=0, description="预测步数 (>0)")
    # Optional: clients that have already calibrated the circle can pass it to skip fitting.
    # NOTE(review): the /predict handler only uses these when BOTH centers and radii are given.
    centers: Optional[List[List[float]]] = Field(
        default=None, description="可选 [B,2],每条序列的圆心 (cx,cy)"
    )
    radii: Optional[List[float]] = Field(
        default=None, description="可选 [B],每条序列的半径 r"
    )
# Response body for POST /predict (comments only: a docstring would leak into the OpenAPI schema).
class PredictResponse(BaseModel):
    pred_xy: List[List[List[float]]]  # predicted future pixel coordinates [B, steps, 2]
    device: str  # device string reported by the handler
    centers: Optional[List[List[float]]] = None  # per-sequence circle centers used for the transform
    radii: Optional[List[float]] = None  # per-sequence radii used for the transform
# FastAPI application.
app = FastAPI(title="Pendulum Transformer Inference (CUDA Only)", version="1.0.0")

# CORS is fully open: any origin, method and header, but credentials are not allowed.
_CORS_SETTINGS = dict(
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)

# Injected by main() before serving. While it is None, /predict answers 500.
ENGINE: InferenceEngine | None = None
@app.post("/predict", response_model=PredictResponse)
def predict(req: PredictRequest):
    """Predict future pixel positions for a batch of pendulum trajectories.

    Pipeline: pixels -> unit (sin, cos) -> ENGINE.rollout -> pixels. The first step
    uses the client calibration when both centers and radii are given, and circle
    fitting otherwise.

    Raises:
        HTTPException(500): the engine has not been initialized.
        HTTPException(400): malformed or non-finite sequences/calibration, or
            sequence/step counts beyond the model's positional-encoding length.
    """
    if ENGINE is None:
        raise HTTPException(500, "Engine not initialized")
    try:
        pix = np.array(req.sequences, dtype=np.float32)  # [B, S_in, 2]
    except (TypeError, ValueError) as e:  # e.g. ragged sequences
        raise HTTPException(400, f"Invalid 'sequences': {e}") from e
    if pix.ndim != 3 or pix.shape[-1] != 2:
        raise HTTPException(400, f"'sequences' must be [B, S_in, 2], got {pix.shape}")
    if req.steps <= 0:
        raise HTTPException(400, "'steps' must be > 0")
    if pix.shape[1] < 1:
        raise HTTPException(400, "Each sequence must have S_in >= 1")
    if not np.isfinite(pix).all():
        raise HTTPException(400, "'sequences' must contain only finite numbers")
    # The positional-encoding table bounds the encoder length (S_in) and the decoder
    # length. During rollout the decoder sees at most `steps` tokens. Longer inputs
    # would fail inside the model with an opaque 500.
    max_len = ENGINE.net.pos_enc.pe.size(1)
    if pix.shape[1] > max_len:
        raise HTTPException(400, f"Each sequence must have S_in <= {max_len}")
    if req.steps > max_len:
        raise HTTPException(400, f"'steps' must be <= {max_len}")

    # 1) Pixels -> unit (sin, cos), plus the per-sequence (cx, cy, r) used.
    if req.centers is not None and req.radii is not None:
        # Use the calibration supplied by the client.
        try:
            centers = np.array(req.centers, dtype=np.float32)  # [B,2]
            radii = np.array(req.radii, dtype=np.float32)  # [B]
        except (TypeError, ValueError) as e:  # e.g. ragged centers
            raise HTTPException(400, f"Invalid 'centers'/'radii': {e}") from e
        # Whole-shape comparison also rejects wrong ranks, which previously raised IndexError.
        if centers.shape != (pix.shape[0], 2) or radii.shape != (pix.shape[0],):
            raise HTTPException(400, "centers 应为 [B,2] 且 radii 为 [B]")
        if not (np.isfinite(centers).all() and np.isfinite(radii).all()):
            raise HTTPException(400, "'centers'/'radii' must contain only finite numbers")
        radii = np.clip(radii, 1e-6, None)  # guard against division by zero / negative radii
        # Training-time convention: sin = (x - cx) / r, cos = (cy - y) / r
        sin = (pix[..., 0] - centers[:, 0:1]) / radii[:, None]
        cos = (centers[:, 1:2] - pix[..., 1]) / radii[:, None]
        sc = np.stack([sin, cos], axis=-1).astype(np.float32)  # [B,S,2]
    else:
        sc, centers, radii = batch_pixels_to_unit_sincos(pix)

    # 2) Run the engine on raw-space (unit sin, cos) input. The network is float32, and
    # the dtype returned by batch_pixels_to_unit_sincos is not visible here, so force it.
    x = torch.from_numpy(np.ascontiguousarray(sc, dtype=np.float32)).to(ENGINE.device)
    pred_sc = ENGINE.rollout(x, steps=req.steps)  # [B, steps, 2] unit (sin, cos)
    pred_sc_np = pred_sc.detach().cpu().numpy()

    # 3) Unit (sin, cos) -> pixels.
    pred_pix = unit_sincos_to_pixels(pred_sc_np, centers, radii)  # [B, steps, 2]
    return PredictResponse(
        pred_xy=np.asarray(pred_pix).tolist(),
        device="cuda",
        centers=np.asarray(centers).tolist(),
        radii=np.asarray(radii).tolist(),
    )
def main():
    """Parse CLI options, load the checkpoint into ENGINE, and serve ``app`` with uvicorn."""
    global ENGINE

    cli = argparse.ArgumentParser()
    cli.add_argument("--ckpt", type=str, required=True, help="Path to checkpoint (.pt)")
    cli.add_argument("--host", type=str, default="0.0.0.0")
    cli.add_argument("--port", type=int, default=8000)
    opts = cli.parse_args()

    ENGINE = InferenceEngine(opts.ckpt)
    print(f"[serve] Loaded ckpt from {opts.ckpt} on device=cuda")
    # The app object itself (not an import string) is served, so it sees the ENGINE set above.
    uvicorn.run(app, host=opts.host, port=opts.port, log_level="info")


if __name__ == "__main__":
    main()