221 lines
7.8 KiB
Python
221 lines
7.8 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
|
||
"""
|
||
批量处理 CSV:
|
||
- 从 CSV(支持多个路径或通配符)读取中国语系 GB2312 编码数据;
|
||
- 对黄色/红色圆心序列分别做最小二乘圆拟合;
|
||
- 计算角度的 sin/cos 以及按拟合圆归一化后的坐标(pos_sin/pos_cos);
|
||
- 导出 UTF-8 CSV 到指定输出目录;
|
||
- 可选保存拟合示意图 PNG。
|
||
|
||
用法示例:
|
||
1) 处理一个/多个文件:
|
||
python circle_fit_batch.py data/3.csv data/4.csv -o outputs
|
||
|
||
2) 使用通配符批量处理:
|
||
python circle_fit_batch.py "data/*.csv" -o outputs
|
||
|
||
3) 指定不同编码或禁用绘图:
|
||
python circle_fit_batch.py "data/*.csv" -o outputs --encoding gb2312 --no-plot
|
||
|
||
4) 指定中文字体文件用于图例(可选):
|
||
python circle_fit_batch.py "data/*.csv" -o outputs --font "/path/to/PingFang Regular.ttf"
|
||
"""
|
||
import argparse
|
||
import glob
|
||
import math
|
||
import os
|
||
import sys
|
||
from typing import Tuple, List
|
||
|
||
import numpy as np
|
||
import pandas as pd
|
||
|
||
# 后端设为无界面,以便在服务器/批处理环境下保存图像
|
||
import matplotlib
|
||
matplotlib.use("Agg")
|
||
import matplotlib.pyplot as plt
|
||
from matplotlib import font_manager
|
||
|
||
REQUIRED_COLS = [
|
||
"帧序号",
|
||
"黄色与水平线夹角(度)",
|
||
"红色与水平线夹角(度)",
|
||
"黄色外接圆圆心X",
|
||
"黄色外接圆圆心Y",
|
||
"红色相关外接圆圆心X",
|
||
"红色相关外接圆圆心Y",
|
||
]
|
||
|
||
|
||
def fit_circle_lstsq(x: np.ndarray, y: np.ndarray) -> Tuple[float, float, float]:
|
||
"""最小二乘拟合圆:解 x^2 + y^2 = 2 a x + 2 b y + c
|
||
返回 (cx, cy, r)
|
||
"""
|
||
x = np.asarray(x, dtype=np.float64)
|
||
y = np.asarray(y, dtype=np.float64)
|
||
A = np.c_[2 * x, 2 * y, np.ones_like(x)]
|
||
b = x * x + y * y
|
||
sol, *_ = np.linalg.lstsq(A, b, rcond=None)
|
||
a, b_, c = sol
|
||
cx, cy = a, b_
|
||
r = math.sqrt(max(cx * cx + cy * cy + c, 0.0))
|
||
return float(cx), float(cy), float(r)
|
||
|
||
|
||
def deg_to_cos_sin(deg: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
||
"""角度(度) -> (cos, sin)"""
|
||
rad = np.radians(deg.astype(np.float64))
|
||
return np.cos(rad), np.sin(rad)
|
||
|
||
|
||
def expand_inputs(inputs: List[str]) -> List[str]:
|
||
files: List[str] = []
|
||
for pat in inputs:
|
||
matched = glob.glob(pat)
|
||
if matched:
|
||
files.extend(matched)
|
||
else:
|
||
# 若未匹配到且参数本身是一个存在的文件,也加入
|
||
if os.path.isfile(pat):
|
||
files.append(pat)
|
||
# 去重并排序,便于稳定输出
|
||
files = sorted(set(files))
|
||
return files
|
||
|
||
|
||
def ensure_output_dir(out_dir: str) -> None:
|
||
if not os.path.isdir(out_dir):
|
||
os.makedirs(out_dir, exist_ok=True)
|
||
|
||
|
||
def process_one_csv(csv_path: str, out_dir: str, encoding: str, plot: bool, font_path: str = None, overwrite: bool = False) -> None:
|
||
base = os.path.splitext(os.path.basename(csv_path))[0]
|
||
out_csv = os.path.join(out_dir, f"{base}_angles.csv")
|
||
out_png = os.path.join(out_dir, f"{base}_circle_fit.png")
|
||
|
||
if not overwrite and os.path.exists(out_csv):
|
||
print(f"[SKIP] 输出已存在(使用 --overwrite 覆盖):{out_csv}")
|
||
return
|
||
|
||
try:
|
||
df = pd.read_csv(csv_path, encoding=encoding)
|
||
except Exception as e:
|
||
print(f"[ERROR] 读取失败:{csv_path} ({e})")
|
||
return
|
||
|
||
missing = [c for c in REQUIRED_COLS if c not in df.columns]
|
||
if missing:
|
||
print(f"[ERROR] 列缺失:{csv_path} 缺少 {missing}")
|
||
return
|
||
|
||
# 提取列
|
||
yellow_x = df["黄色外接圆圆心X"].astype(float).to_numpy()
|
||
yellow_y = df["黄色外接圆圆心Y"].astype(float).to_numpy()
|
||
red_x = df["红色相关外接圆圆心X"].astype(float).to_numpy()
|
||
red_y = df["红色相关外接圆圆心Y"].astype(float).to_numpy()
|
||
|
||
# 圆拟合
|
||
ycx, ycy, yr = fit_circle_lstsq(yellow_x, yellow_y)
|
||
rcx, rcy, rr = fit_circle_lstsq(red_x, red_y)
|
||
|
||
if yr <= 0 or rr <= 0 or not np.isfinite([yr, rr]).all():
|
||
print(f"[ERROR] 拟合半径非法(<=0 或 非数),文件:{csv_path}")
|
||
return
|
||
|
||
# 角度 -> cos/sin
|
||
yellow_angle_deg = pd.to_numeric(df["黄色与水平线夹角(度)"], errors="coerce").to_numpy()
|
||
red_angle_deg = pd.to_numeric(df["红色与水平线夹角(度)"], errors="coerce").to_numpy()
|
||
yellow_angle_cos, yellow_angle_sin = deg_to_cos_sin(yellow_angle_deg)
|
||
red_angle_cos, red_angle_sin = deg_to_cos_sin(red_angle_deg)
|
||
|
||
# 位置按拟合圆归一化(cosx=沿X方向的归一化偏移;cosy=沿Y方向的归一化偏移)
|
||
yellow_cosx = (yellow_x - ycx) / yr
|
||
yellow_cosy = (yellow_y - ycy) / yr
|
||
red_cosx = (red_x - rcx) / rr
|
||
red_cosy = (red_y - rcy) / rr
|
||
|
||
out_df = pd.DataFrame({
|
||
"frame_index": df["帧序号"],
|
||
"yellow_angle_sin": yellow_angle_sin,
|
||
"yellow_angle_cos": yellow_angle_cos,
|
||
"red_angle_sin": red_angle_sin,
|
||
"red_angle_cos": red_angle_cos,
|
||
"yellow_pos_sin": yellow_cosy,
|
||
"yellow_pos_cos": yellow_cosx,
|
||
"red_pos_sin": red_cosy,
|
||
"red_pos_cos": red_cosx,
|
||
})
|
||
|
||
try:
|
||
out_df.to_csv(out_csv, index=False, encoding="utf-8")
|
||
print(f"[OK] 导出:{out_csv} | 黄圆 (cx,cy,r)=({ycx:.6g},{ycy:.6g},{yr:.6g}) 红圆 (cx,cy,r)=({rcx:.6g},{rcy:.6g},{rr:.6g})")
|
||
except Exception as e:
|
||
print(f"[ERROR] 写出 CSV 失败:{out_csv} ({e})")
|
||
|
||
if plot:
|
||
# 字体(可选)
|
||
if font_path:
|
||
try:
|
||
font_manager.fontManager.addfont(font_path)
|
||
# 若传入中文字体,则优先用它显示中文
|
||
font_name = os.path.splitext(os.path.basename(font_path))[0]
|
||
plt.rcParams["font.sans-serif"] = [font_name]
|
||
except Exception as e:
|
||
print(f"[WARN] 字体加载失败:{font_path} ({e})")
|
||
|
||
fig, ax = plt.subplots()
|
||
ax.set_aspect("equal", "box")
|
||
ax.scatter(yellow_x, yellow_y, s=1, label="黄色")
|
||
ax.scatter(red_x, red_y, s=1, label="红色")
|
||
|
||
circle1 = plt.Circle((ycx, ycy), yr, color="orange", fill=False, label="拟合黄色圆")
|
||
circle2 = plt.Circle((rcx, rcy), rr, color="red", fill=False, label="拟合红色圆")
|
||
ax.add_artist(circle1)
|
||
ax.add_artist(circle2)
|
||
ax.legend()
|
||
|
||
try:
|
||
fig.savefig(out_png, dpi=150, bbox_inches="tight")
|
||
print(f"[OK] 拟合图:{out_png}")
|
||
finally:
|
||
plt.close(fig)
|
||
|
||
|
||
def parse_args():
|
||
p = argparse.ArgumentParser(description="批量对 CSV 进行圆拟合与角度/位置归一化导出")
|
||
p.add_argument("inputs", nargs="+", help="输入 CSV 路径或通配符(可多个),例如 data/3.csv 或 \"data/*.csv\"")
|
||
p.add_argument("-o", "--output-dir", required=True, help="输出目录,将在其中生成 *_angles.csv 与 *_circle_fit.png")
|
||
p.add_argument("--encoding", default="gb2312", help="输入 CSV 编码(默认 gb2312)")
|
||
p.add_argument("--no-plot", action="store_true", help="不保存拟合示意图 PNG")
|
||
p.add_argument("--font", default=None, help="可选:中文字体文件路径(用于图例/标签显示中文)")
|
||
p.add_argument("--overwrite", action="store_true", help="若输出文件存在则覆盖")
|
||
return p.parse_args()
|
||
|
||
|
||
def main():
|
||
args = parse_args()
|
||
files = expand_inputs(args.inputs)
|
||
if not files:
|
||
print("[ERROR] 未找到任何输入文件。请检查路径/通配符。", file=sys.stderr)
|
||
sys.exit(2)
|
||
|
||
ensure_output_dir(args.output_dir)
|
||
|
||
print(f"共 {len(files)} 个文件,将输出到:{args.output_dir}")
|
||
for fp in files:
|
||
print(f"==> 处理:{fp}")
|
||
process_one_csv(
|
||
csv_path=fp,
|
||
out_dir=args.output_dir,
|
||
encoding=args.encoding,
|
||
plot=(not args.no_plot),
|
||
font_path=args.font,
|
||
overwrite=args.overwrite,
|
||
)
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main()
|