242 lines
8.7 KiB
Python
242 lines
8.7 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
|
||
"""
|
||
批量处理 CSV:
|
||
- 从 CSV(支持多个路径或通配符)读取中国语系 GB2312 编码数据;
|
||
- 对黄色/红色圆心序列分别做最小二乘圆拟合;
|
||
- 计算角度的 sin/cos 以及按拟合圆归一化后的坐标(pos_sin/pos_cos);
|
||
- 导出 UTF-8 CSV 到指定输出目录;
|
||
- 可选保存拟合示意图 PNG。
|
||
|
||
用法示例:
|
||
1) 处理一个/多个文件:
|
||
python circle_fit_batch.py data/3.csv data/4.csv -o outputs
|
||
|
||
2) 使用通配符批量处理:
|
||
python circle_fit_batch.py "data/*.csv" -o outputs
|
||
|
||
3) 指定不同编码或禁用绘图:
|
||
python circle_fit_batch.py "data/*.csv" -o outputs --encoding gb2312 --no-plot
|
||
|
||
4) 指定中文字体文件用于图例(可选):
|
||
python circle_fit_batch.py "data/*.csv" -o outputs --font "/path/to/PingFang Regular.ttf"
|
||
"""
|
||
import argparse
|
||
import glob
|
||
import math
|
||
import os
|
||
import sys
|
||
from typing import Tuple, List
|
||
|
||
import numpy as np
|
||
import pandas as pd
|
||
|
||
# 后端设为无界面,以便在服务器/批处理环境下保存图像
|
||
import matplotlib
|
||
matplotlib.use("Agg")
|
||
import matplotlib.pyplot as plt
|
||
from matplotlib import font_manager
|
||
|
||
REQUIRED_COLS = [
|
||
"frameIndex",
|
||
"time",
|
||
"greenRectRotation",
|
||
"blueRectRotation",
|
||
"greenRectCenterX",
|
||
"greenRectCenterY",
|
||
"blueRectCenterX",
|
||
"blueRectCenterY",
|
||
]
|
||
|
||
|
||
def fit_circle_lstsq(x: np.ndarray, y: np.ndarray) -> Tuple[float, float, float]:
|
||
"""最小二乘拟合圆:解 x^2 + y^2 = 2 a x + 2 b y + c
|
||
返回 (cx, cy, r)
|
||
"""
|
||
x = np.asarray(x, dtype=np.float64)
|
||
y = np.asarray(y, dtype=np.float64)
|
||
A = np.c_[2 * x, 2 * y, np.ones_like(x)]
|
||
b = x * x + y * y
|
||
sol, *_ = np.linalg.lstsq(A, b, rcond=None)
|
||
a, b_, c = sol
|
||
cx, cy = a, b_
|
||
r = math.sqrt(max(cx * cx + cy * cy + c, 0.0))
|
||
return float(cx), float(cy), float(r)
|
||
|
||
|
||
def deg_to_cos_sin(deg: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
||
"""角度(度) -> (cos, sin)"""
|
||
rad = np.radians(deg.astype(np.float64))
|
||
return np.cos(rad), np.sin(rad)
|
||
|
||
|
||
def expand_inputs(inputs: List[str]) -> List[str]:
|
||
files: List[str] = []
|
||
for pat in inputs:
|
||
matched = glob.glob(pat)
|
||
if matched:
|
||
files.extend(matched)
|
||
else:
|
||
# 若未匹配到且参数本身是一个存在的文件,也加入
|
||
if os.path.isfile(pat):
|
||
files.append(pat)
|
||
# 去重并排序,便于稳定输出
|
||
files = sorted(set(files))
|
||
return files
|
||
|
||
def ensure_output_dir(out_dir: str) -> None:
|
||
if not os.path.isdir(out_dir):
|
||
os.makedirs(out_dir, exist_ok=True)
|
||
|
||
|
||
def process_one_csv(csv_path: str, out_dir: str, encoding: str, plot: bool, font_path: str = None, overwrite: bool = False) -> None:
|
||
base = os.path.splitext(os.path.basename(csv_path))[0]
|
||
out_csv = os.path.join(out_dir, f"{base}_angles.csv")
|
||
out_png = os.path.join(out_dir, f"{base}_circle_fit.png")
|
||
|
||
if not overwrite and os.path.exists(out_csv):
|
||
print(f"[SKIP] 输出已存在(使用 --overwrite 覆盖):{out_csv}")
|
||
return
|
||
|
||
try:
|
||
df = pd.read_csv(csv_path, encoding=encoding)
|
||
except Exception as e:
|
||
print(f"[ERROR] 读取失败:{csv_path} ({e})")
|
||
return
|
||
|
||
missing = [c for c in REQUIRED_COLS if c not in df.columns]
|
||
if missing:
|
||
print(f"[ERROR] 列缺失:{csv_path} 缺少 {missing}")
|
||
return
|
||
|
||
# 将必要列转为数值,无法解析的置为 NaN
|
||
for c in REQUIRED_COLS:
|
||
df[c] = pd.to_numeric(df[c], errors="coerce")
|
||
|
||
# 逐帧过滤:若任一必要列缺失/非数,则剔除该帧
|
||
valid_mask = df[REQUIRED_COLS].notna().all(axis=1)
|
||
dropped = int(len(df) - valid_mask.sum())
|
||
df_valid = df.loc[valid_mask].reset_index(drop=True)
|
||
if dropped > 0:
|
||
print(f"[INFO] 已过滤 {dropped} 帧(必要列缺失),保留 {len(df_valid)} 帧:{csv_path}")
|
||
|
||
# 至少需要 2 帧来计算 deltaTime
|
||
if len(df_valid) < 2:
|
||
print(f"[ERROR] 有效帧不足2帧,无法计算deltaTime:{csv_path}")
|
||
return
|
||
|
||
# 基于过滤后的时间序列计算 deltaTime;输出与其它列对齐(从第2帧开始)
|
||
time = df_valid["time"].to_numpy(dtype=float)
|
||
delta_time = np.diff(time)
|
||
|
||
# 从第2帧开始提取所有数据(与 delta_time 对齐)
|
||
green_x = df_valid["greenRectCenterX"].to_numpy(dtype=float)[1:]
|
||
green_y = df_valid["greenRectCenterY"].to_numpy(dtype=float)[1:]
|
||
red_x = df_valid["blueRectCenterX"].to_numpy(dtype=float)[1:]
|
||
red_y = df_valid["blueRectCenterY"].to_numpy(dtype=float)[1:]
|
||
|
||
# 圆拟合
|
||
ycx, ycy, yr = fit_circle_lstsq(green_x, green_y)
|
||
rcx, rcy, rr = fit_circle_lstsq(red_x, red_y)
|
||
|
||
if yr <= 0 or rr <= 0 or not np.isfinite([yr, rr]).all():
|
||
print(f"[ERROR] 拟合半径非法(<=0 或 非数),文件:{csv_path}")
|
||
return
|
||
|
||
# 角度 -> cos/sin (从第2帧开始)
|
||
green_angle_deg = df_valid["greenRectRotation"].to_numpy(dtype=float)[1:]
|
||
red_angle_deg = df_valid["blueRectRotation"].to_numpy(dtype=float)[1:]
|
||
green_angle_cos, green_angle_sin = deg_to_cos_sin(green_angle_deg)
|
||
red_angle_cos, red_angle_sin = deg_to_cos_sin(red_angle_deg)
|
||
|
||
# 位置按拟合圆归一化(cosx=沿X方向的归一化偏移;cosy=沿Y方向的归一化偏移)
|
||
green_cosx = (green_x - ycx) / yr
|
||
green_cosy = (green_y - ycy) / yr
|
||
red_cosx = (red_x - rcx) / rr
|
||
red_cosy = (red_y - rcy) / rr
|
||
|
||
out_df = pd.DataFrame({
|
||
"frame_index": df_valid["frameIndex"].to_numpy()[1:], # 从第2帧开始
|
||
"delta_time": delta_time,
|
||
"green_angle_sin": green_angle_sin,
|
||
"green_angle_cos": green_angle_cos,
|
||
"red_angle_sin": red_angle_sin,
|
||
"red_angle_cos": red_angle_cos,
|
||
"green_pos_sin": green_cosy,
|
||
"green_pos_cos": green_cosx,
|
||
"red_pos_sin": red_cosy,
|
||
"red_pos_cos": red_cosx,
|
||
})
|
||
|
||
try:
|
||
out_df.to_csv(out_csv, index=False, encoding="utf-8")
|
||
print(f"[OK] 导出:{out_csv} | 黄圆 (cx,cy,r)=({ycx:.6g},{ycy:.6g},{yr:.6g}) 红圆 (cx,cy,r)=({rcx:.6g},{rcy:.6g},{rr:.6g})")
|
||
except Exception as e:
|
||
print(f"[ERROR] 写出 CSV 失败:{out_csv} ({e})")
|
||
|
||
if plot:
|
||
# 字体(可选)
|
||
if font_path:
|
||
try:
|
||
font_manager.fontManager.addfont(font_path)
|
||
# 若传入中文字体,则优先用它显示中文
|
||
font_name = os.path.splitext(os.path.basename(font_path))[0]
|
||
plt.rcParams["font.sans-serif"] = [font_name]
|
||
except Exception as e:
|
||
print(f"[WARN] 字体加载失败:{font_path} ({e})")
|
||
|
||
fig, ax = plt.subplots()
|
||
ax.set_aspect("equal", "box")
|
||
ax.scatter(green_x, green_y, s=1, label="黄色")
|
||
ax.scatter(red_x, red_y, s=1, label="红色")
|
||
|
||
circle1 = plt.Circle((ycx, ycy), yr, color="orange", fill=False, label="拟合黄色圆")
|
||
circle2 = plt.Circle((rcx, rcy), rr, color="red", fill=False, label="拟合红色圆")
|
||
ax.add_artist(circle1)
|
||
ax.add_artist(circle2)
|
||
ax.legend()
|
||
|
||
try:
|
||
fig.savefig(out_png, dpi=150, bbox_inches="tight")
|
||
print(f"[OK] 拟合图:{out_png}")
|
||
finally:
|
||
plt.close(fig)
|
||
|
||
|
||
def parse_args():
|
||
p = argparse.ArgumentParser(description="批量对 CSV 进行圆拟合与角度/位置归一化导出")
|
||
p.add_argument("inputs", nargs="+", help="输入 CSV 路径或通配符(可多个),例如 data/3.csv 或 \"data/*.csv\"")
|
||
p.add_argument("-o", "--output-dir", required=True, help="输出目录,将在其中生成 *_angles.csv 与 *_circle_fit.png")
|
||
p.add_argument("--encoding", default="utf-8", help="输入 CSV 编码(默认 utf-8)")
|
||
p.add_argument("--no-plot", action="store_true", help="不保存拟合示意图 PNG")
|
||
p.add_argument("--font", default=None, help="可选:中文字体文件路径(用于图例/标签显示中文)")
|
||
p.add_argument("--overwrite", action="store_true", help="若输出文件存在则覆盖")
|
||
return p.parse_args()
|
||
|
||
|
||
def main():
|
||
args = parse_args()
|
||
files = expand_inputs(args.inputs)
|
||
if not files:
|
||
print("[ERROR] 未找到任何输入文件。请检查路径/通配符。", file=sys.stderr)
|
||
sys.exit(2)
|
||
|
||
ensure_output_dir(args.output_dir)
|
||
|
||
print(f"共 {len(files)} 个文件,将输出到:{args.output_dir}")
|
||
for fp in files:
|
||
print(f"==> 处理:{fp}")
|
||
process_one_csv(
|
||
csv_path=fp,
|
||
out_dir=args.output_dir,
|
||
encoding=args.encoding,
|
||
plot=(not args.no_plot),
|
||
font_path=args.font,
|
||
overwrite=args.overwrite,
|
||
)
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main()
|