2025-09-24 11:35:48 +08:00

117 lines
3.6 KiB
Python

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
几何与坐标变换工具:
- 基于最小二乘的圆拟合
- 像素坐标 <-> 单位圆 (sin, cos) 变换
与训练脚本的定义保持一致:
theta = atan2(x - cx, cy - y)
sin = sin(theta) = (x - cx) / r
cos = cos(theta) = (cy - y) / r = -(y - cy) / r
因此,反变换:
x = cx + r * sin
y = cy - r * cos
"""
from __future__ import annotations
import math
import numpy as np
from typing import Tuple
def fit_circle_lstsq(x: np.ndarray, y: np.ndarray) -> Tuple[float, float, float]:
    """Fit a circle to 2-D points by linear least squares.

    The circle equation is written in the linear form
    ``x^2 + y^2 = 2*a*x + 2*b*y + c`` and solved for ``(a, b, c)``.
    The centre is ``(a, b)`` and the radius is ``sqrt(a^2 + b^2 + c)``.

    Args:
        x: x coordinates of the sample points (converted to float64).
        y: y coordinates of the sample points (converted to float64).

    Returns:
        ``(cx, cy, r)`` as plain Python floats.
    """
    xs = np.asarray(x, dtype=np.float64)
    ys = np.asarray(y, dtype=np.float64)

    # One row per point: [2x, 2y, 1] . [a, b, c] = x^2 + y^2
    design = np.column_stack((2 * xs, 2 * ys, np.ones_like(xs)))
    rhs = np.square(xs) + np.square(ys)
    params = np.linalg.lstsq(design, rhs, rcond=None)[0]

    cx, cy, offset = params
    # Clamp at zero so sqrt never receives a negative squared radius.
    radius_sq = max(cx * cx + cy * cy + offset, 0.0)
    return float(cx), float(cy), float(math.sqrt(radius_sq))
def _estimate_center_radius(seq_xy: np.ndarray) -> Tuple[float, float, float]:
    """Estimate ``(cx, cy, r)`` for a single point sequence.

    - With 3 or more points, the circle is obtained from
      ``fit_circle_lstsq``.
    - With fewer points, the centre is the mean of the points and the
      radius is the median distance to that centre (1.0 when there are
      no points at all).

    In both cases the returned radius is never smaller than 1e-6.

    Args:
        seq_xy: array of shape [S, 2] holding (x, y) per row.

    Returns:
        ``(cx, cy, r)`` as plain Python floats.
    """
    min_radius = 1e-6
    points = np.asarray(seq_xy, dtype=np.float64)

    if points.shape[0] >= 3:
        cx, cy, radius = fit_circle_lstsq(points[:, 0], points[:, 1])
        return float(cx), float(cy), float(max(radius, min_radius))

    # Too few points for a fit: fall back to centroid + median distance.
    centroid = np.mean(points, axis=0)
    cx, cy = centroid
    dists = np.linalg.norm(points - centroid, axis=1)
    if dists.size:
        radius = float(np.median(dists))
    else:
        radius = 1.0
    return float(cx), float(cy), float(max(radius, min_radius))
def pixels_to_unit_sincos(seq_xy: np.ndarray) -> Tuple[np.ndarray, Tuple[float, float, float]]:
    """Convert one pixel trajectory [S, 2] into unit (sin, cos) pairs [S, 2].

    The circle ``(cx, cy, r)`` is estimated from the trajectory itself.
    Each point is then mapped to the sine and cosine of its angle
    ``theta = atan2(x - cx, cy - y)`` around the estimated centre.
    The radius is returned but is not used to scale the (sin, cos) output.

    Args:
        seq_xy: array of shape [S, 2] holding (x, y) pixel coordinates.

    Returns:
        ``(sc, (cx, cy, r))`` where ``sc`` is a float32 array of shape
        [S, 2] with columns (sin, cos).
    """
    points = np.asarray(seq_xy, dtype=np.float64)
    cx, cy, r = _estimate_center_radius(points)

    # Same angle definition as used in training.
    dx = points[:, 0] - cx
    dy = cy - points[:, 1]
    angle = np.arctan2(dx, dy)

    sincos = np.stack((np.sin(angle), np.cos(angle)), axis=-1)
    return sincos.astype(np.float32), (float(cx), float(cy), float(r))
def batch_pixels_to_unit_sincos(batch_xy: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Convert a batch of pixel trajectories to unit (sin, cos) pairs.

    Each sequence in the batch is converted independently with
    ``pixels_to_unit_sincos``, so every sequence gets its own estimated
    centre and radius.

    Args:
        batch_xy: [B, S, 2] pixel coordinates.

    Returns:
        sc:      [B, S, 2] unit (sin, cos), float32
        centers: [B, 2] (cx, cy), float32
        radii:   [B] r, float32
    """
    batch = np.asarray(batch_xy, dtype=np.float64)
    n_seq, seq_len, _ = batch.shape

    sc = np.empty((n_seq, seq_len, 2), dtype=np.float32)
    centers = np.empty((n_seq, 2), dtype=np.float32)
    radii = np.empty((n_seq,), dtype=np.float32)

    for idx, track in enumerate(batch):
        track_sc, circle = pixels_to_unit_sincos(track)
        cx, cy, r = circle
        sc[idx] = track_sc
        centers[idx, 0] = cx
        centers[idx, 1] = cy
        radii[idx] = r

    return sc, centers, radii
def unit_sincos_to_pixels(sc: np.ndarray, centers: np.ndarray, radii: np.ndarray) -> np.ndarray:
    """Map unit (sin, cos) pairs back to pixel coordinates.

    Applies the inverse transform ``x = cx + r * sin``, ``y = cy - r * cos``.

    Batched form (broadcasting between ``sc`` and the per-sequence
    centre/radius is supported):
        sc:      [B, S, 2]
        centers: [B, 2]
        radii:   [B] or [B, 1]
        returns: [B, S, 2]

    Single-sequence form, e.g. the direct output of
    ``pixels_to_unit_sincos``:
        sc:      [S, 2]
        centers: [2]  (cx, cy)
        radii:   scalar
        returns: [S, 2]

    Args:
        sc: (sin, cos) pairs in the last axis.
        centers: circle centres, one (cx, cy) per sequence, or a single
            1-D (cx, cy) pair.
        radii: circle radii, one per sequence, or a single scalar.

    Returns:
        float32 array of pixel coordinates with (x, y) in the last axis.
    """
    sc = np.asarray(sc, dtype=np.float64)
    centers = np.asarray(centers, dtype=np.float64)
    # One radius per row so it broadcasts against [B, S].
    radii_col = np.asarray(radii, dtype=np.float64).reshape(-1, 1)  # [B, 1]

    # A bare (cx, cy) pair is treated as a batch of one centre.
    single_center = centers.ndim == 1
    if single_center:
        centers = centers[np.newaxis, :]

    sin = sc[..., 0]
    cos = sc[..., 1]
    cx = centers[:, 0].reshape(-1, 1)
    cy = centers[:, 1].reshape(-1, 1)

    x = cx + radii_col * sin  # [B, 1] * [B, S] -> [B, S]
    y = cy - radii_col * cos
    pix = np.stack([x, y], axis=-1)

    # For a fully unbatched call, drop the synthetic batch axis again so
    # the output shape mirrors the input ``sc`` shape.
    if single_center and sc.ndim == 2 and radii_col.shape[0] == 1:
        pix = pix[0]
    return pix.astype(np.float32)