ball-tracking-cv/app/routes.py
2025-08-10 10:01:43 +08:00

248 lines
9.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from __future__ import annotations
import io
import os
import logging
from flask import Blueprint, current_app, jsonify, request, send_from_directory
from werkzeug.utils import secure_filename
import numpy as np
import pandas as pd
from .analysis import (
load_data,
calculate_angle_from_pendulum_length,
fit_pendulum_model,
damped_oscillator,
calculate_physical_parameters,
identify_periods,
visualize_results,
visualize_physical_parameters,
calculate_large_angle_g,
savgol_filter,
fit_circle,
)
from .storage import LocalSessionStorage
# Module-level logger for this routes module.
log = logging.getLogger(__name__)
# Blueprint for the HTTP API; every route below is mounted under "/api".
bp = Blueprint(name="api", import_name=__name__, url_prefix="/api")
def _json_error(message: str, status: int = 400):
    """Build a JSON error response of the form ``{"success": false, "error": message}``.

    Args:
        message: Human-readable error text placed under the ``"error"`` key.
        status: HTTP status code to pair with the body (default 400).

    Returns:
        A ``(response, status)`` tuple that a Flask view can return directly.
    """
    payload = {"success": False, "error": message}
    response = jsonify(payload)
    return response, status
@bp.before_app_request
def _housekeeping():
    """Opportunistically purge expired storage sessions before each request.

    Registered with ``before_app_request``, so it runs for requests handled by
    any blueprint of the app, not only ``/api`` routes. The cleanup is
    best-effort: a failure is logged and swallowed so that housekeeping
    problems cannot turn an unrelated request into a 500. Returning ``None``
    lets Flask continue with normal request dispatch.
    """
    try:
        LocalSessionStorage().cleanup_expired()
    except Exception:  # boundary: never let cleanup break the actual request
        log.exception("Expired-session cleanup failed; continuing with request")
@bp.route("/health", methods=["GET"])
def health():
    """Liveness probe: always answers with the JSON body ``{"status": "ok"}``."""
    return jsonify(status="ok")
class _AnalysisInputError(ValueError):
    """Input problem detected while preparing an analysis; reported as HTTP 400."""


# ``fit_pendulum_model`` yields four parameters (theta0, gamma, omega, phi),
# so the fitting window must contain at least that many samples.
_MIN_FIT_SAMPLES = 4


def _resolve_pendulum_length(data: pd.DataFrame) -> tuple[pd.DataFrame, float]:
    """Normalize the uploaded columns and determine the pendulum length in metres.

    Recognised layouts:

    * ``time, x, y, pendulum_length``: the length is read from the first row.
    * ``time, x/mm, y/mm`` (legacy): the columns are renamed to ``x``/``y`` and
      the length is inferred from the radius of a circle fitted to the track
      (the radius is in mm and is converted to m).
    * ``time, x, y`` without a length: rejected.

    Returns:
        ``(data, pendulum_length)``, where ``data`` may be a renamed copy.

    Raises:
        _AnalysisInputError: unsupported layout, or a length that is missing,
            non-numeric, non-finite or not positive.
    """
    cols = set(data.columns)
    if {'time', 'x', 'y', 'pendulum_length'}.issubset(cols):
        # New format: the pendulum length is stored in the CSV itself.
        try:
            pendulum_length = float(data['pendulum_length'].iloc[0])
        except (IndexError, TypeError, ValueError):
            # Empty frame or non-numeric value would otherwise surface as a 500.
            raise _AnalysisInputError('摆长参数无效') from None
        log.info("从CSV中读取摆长: %.4f", pendulum_length)
    elif {'time', 'x/mm', 'y/mm'}.issubset(cols):
        # Legacy format: coordinates already in mm; infer the length by circle fit.
        log.warning("使用旧数据格式,将通过拟合圆推断摆长")
        data = data.rename(columns={'x/mm': 'x', 'y/mm': 'y'})
        _, _, r_mm = fit_circle(data['x'].values, data['y'].values)
        pendulum_length = r_mm / 1000.0  # mm -> m
    elif {'time', 'x', 'y'}.issubset(cols):
        raise _AnalysisInputError('缺少摆长信息,请在前端输入摆长参数')
    else:
        raise _AnalysisInputError('数据格式不兼容需要time、x、y列以及pendulum_length信息')

    # NaN compares False against 0, so finiteness must be checked explicitly.
    if pendulum_length is None or not np.isfinite(pendulum_length) or pendulum_length <= 0:
        raise _AnalysisInputError('摆长参数无效')
    return data, pendulum_length


def _analysis_window(n_samples: int) -> tuple[int, int]:
    """Return the ``[start, end)`` sample range used for model fitting.

    Short recordings (<= 200 samples) drop 10 samples at the start and aim to
    drop 10 at the end. Longer ones skip up to the first 100 samples (at most a
    tenth of the data), drop the last 10, and stop at sample 20000. ``end`` is
    clamped to ``n_samples`` so that ``end - start`` is the true window length
    (slicing results are unchanged by the clamp).
    """
    if n_samples <= 200:
        start = 10
        end = max(n_samples - 10, start + 5)
    else:
        start = min(100, n_samples // 10)
        end = min(20000, n_samples - 10)
    return start, min(end, n_samples)


def _filter_by_relative_deviation(xs, ys, threshold: float):
    """Keep the ``(x, y)`` pairs whose ``y`` deviates from ``mean(ys)`` by at most
    ``threshold`` (relative to that mean).

    Returns:
        Two lists ``(kept_xs, kept_ys)`` in the original order.
    """
    y_mean = np.mean(ys)
    kept_xs, kept_ys = [], []
    for x, y in zip(xs, ys):
        if abs(y - y_mean) / y_mean <= threshold:
            kept_xs.append(x)
            kept_ys.append(y)
    return kept_xs, kept_ys


def _large_angle_factor(theta_max):
    """Series ratio ``T / T0 = 1 + theta^2/16 + 11*theta^4/3072`` for swing amplitude
    ``theta_max`` (rad). Works element-wise on NumPy arrays.
    """
    return 1 + theta_max**2 / 16 + 11 * theta_max**4 / 3072


def _render_period_analysis(periods, L, g_ref, gamma, output_path):
    """Plot per-period diagnostics and save them as a PNG at ``output_path``.

    Top panel: measured period vs. maximum swing angle (points whose duration
    deviates more than 0.8 % from the mean are dropped) plus the large-angle
    theory curve built on ``T0 = 2*pi*sqrt(L / g_ref)``.

    Bottom panel: g computed from each period with the large-angle correction
    alone, and with large-angle plus damping correction (points deviating more
    than 1 % from their mean are dropped), against a 9.8 m/s^2 reference line.

    Args:
        periods: Items with ``'max_amplitude'`` (rad) and ``'duration'`` (s) keys,
            as produced by ``identify_periods``. Must be non-empty.
        L: Pendulum length used for g.
        g_ref: Gravity used for the theory curve.
        gamma: Fitted damping parameter from the oscillator model.
        output_path: Destination file path.
    """
    import matplotlib.pyplot as plt

    amplitudes = [p['max_amplitude'] for p in periods]
    durations = [p['duration'] for p in periods]

    fig = plt.figure(figsize=(10, 6))
    try:
        # Top panel: period vs. amplitude with the large-angle theory curve.
        ax_period = fig.add_subplot(2, 1, 1)
        kept_amp, kept_dur = _filter_by_relative_deviation(amplitudes, durations, 0.008)
        ax_period.plot(kept_amp, kept_dur, 'bo-')
        ax_period.set_xlabel('最大摆角 (rad)')
        ax_period.set_ylabel('周期 T (s)')
        ax_period.set_title('周期与最大摆角的关系')
        ax_period.grid(True)
        theta_theory = np.linspace(0, max(amplitudes) * 1.1, 100)
        small_angle_period = 2 * np.pi * np.sqrt(L / g_ref)
        ax_period.plot(theta_theory, small_angle_period * _large_angle_factor(theta_theory),
                       'r-', label='大摆角修正')
        ax_period.legend()

        # Bottom panel: g estimated from every individual period.
        g_angle_only, g_combined = [], []
        for p in periods:
            theta_max = p['max_amplitude']
            T = p['duration']
            angle_factor = _large_angle_factor(theta_max)
            # T = T0 * f(theta)  =>  g = 4*pi^2*L / T0^2 = 4*pi^2*L * f^2 / T^2.
            g_angle_only.append(4 * np.pi**2 * L * angle_factor**2 / T**2)

            g_approx = (4 * np.pi**2 * L) / T**2
            omega0_approx = np.sqrt(g_approx / L)
            # NOTE(review): the 1/8 coefficient depends on how gamma enters
            # damped_oscillator (e.g. b/m vs b/2m); confirm against analysis.py.
            damping_factor = 1 + gamma**2 / (8 * omega0_approx**2)
            combined_factor = angle_factor * damping_factor
            g_combined.append((4 * np.pi**2 * L) / (T**2 / combined_factor**2))

        ax_g = fig.add_subplot(2, 1, 2)
        fa1, fg1 = _filter_by_relative_deviation(amplitudes, g_angle_only, 0.01)
        fa2, fg2 = _filter_by_relative_deviation(amplitudes, g_combined, 0.01)
        ax_g.plot(fa1, fg1, 'ro-', label='仅大摆角修正')
        ax_g.plot(fa2, fg2, 'go-', label='大摆角+阻尼修正')
        ax_g.axhline(y=9.8, color='k', linestyle='--', label='标准重力加速度 9.8 m/s²')
        ax_g.set_xlabel('最大摆角 (rad)')
        ax_g.set_ylabel('计算的重力加速度 (m/s²)')
        ax_g.set_title('各个周期计算的重力加速度 (不同修正方法)')
        ax_g.grid(True)
        ax_g.legend()

        fig.tight_layout()
        fig.savefig(output_path, dpi=300)
    finally:
        plt.close(fig)


@bp.route("/analyze", methods=["POST"])
def analyze():
    """Analyse an uploaded pendulum-tracking CSV and render diagnostic figures.

    Expects a multipart upload in the ``csv_file`` field. The CSV must contain
    ``time, x, y, pendulum_length`` or the legacy ``time, x/mm, y/mm`` columns.
    The angle series is smoothed (Savitzky-Golay), a damped-oscillator model is
    fitted over a trimmed window, physical parameters are derived, and PNG
    figures are written into a new storage session.

    Returns:
        JSON ``{success, session_id, images, analysis_data}``. ``images`` lists
        only figures that were actually written; the period-analysis figure is
        omitted when ``identify_periods`` returns no periods. Input problems
        produce a 400 JSON error.
    """
    if 'csv_file' not in request.files:
        return _json_error('没有上传CSV文件', 400)
    file = request.files['csv_file']
    if not file.filename:
        return _json_error('空文件名', 400)

    # Upload size is bounded by Flask's MAX_CONTENT_LENGTH; parse from memory.
    data = load_data(io.BytesIO(file.read()))
    if data is None:
        return _json_error('CSV文件格式错误或无法读取', 400)

    try:
        data, pendulum_length = _resolve_pendulum_length(data)
    except _AnalysisInputError as exc:
        return _json_error(str(exc), 400)

    # Core analysis: time base, angle series, length and pivot.
    t, theta, L, pivot = calculate_angle_from_pendulum_length(data, pendulum_length)

    # Data-units -> mm scale for the plots: the fitted circle radius
    # corresponds to the pendulum length.
    _, _, r_pixels = fit_circle(data['x'].values, data['y'].values)
    mm_per_pixel = (pendulum_length * 1000) / r_pixels

    # Savitzky-Golay needs an odd window shorter than the series.
    window_length = min(11, len(theta) - 1)
    if window_length < 3:
        return _json_error('数据点过少,无法进行分析', 400)
    if window_length % 2 == 0:
        window_length -= 1
    theta_smooth = savgol_filter(theta, window_length=window_length, polyorder=2)

    start_idx, end_idx = _analysis_window(len(theta))
    if end_idx - start_idx < _MIN_FIT_SAMPLES:
        # Short inputs would otherwise hand an empty/tiny window to the fit.
        return _json_error('数据点过少,无法进行分析', 400)
    t_analyze = t[start_idx:end_idx]
    theta_analyze = theta_smooth[start_idx:end_idx]

    popt, _ = fit_pendulum_model(t_analyze, theta_analyze)
    if popt is None:
        return _json_error('拟合模型失败', 400)
    theta_fit = damped_oscillator(t_analyze, *popt)
    theta0, gamma, omega, _ = popt

    storage = LocalSessionStorage()
    session_id = storage.new_session()
    session_dir = os.path.join(storage.base_dir, session_id)
    fit_png = 'pendulum_fit_analysis.png'
    summary_png = 'pendulum_parameters_summary.png'
    period_png = 'pendulum_period_analysis.png'

    # Figure 1: raw / smoothed / fitted angle over the analysis window.
    visualize_results(
        t_analyze, theta[start_idx:end_idx], theta_analyze, theta_fit, popt,
        data, pivot, mm_per_pixel,
        output_path=os.path.join(session_dir, fit_png),
    )

    g, b_over_m, g_no_damping = calculate_physical_parameters(L, popt)
    periods = identify_periods(t_analyze, theta_analyze)
    g_corrected = calculate_large_angle_g(L, periods, gamma)
    g_report = g_corrected or g  # fall back to the model g when no correction

    # Figure 2: summary of the derived physical parameters.
    visualize_physical_parameters(
        L, popt, g, b_over_m, g_no_damping, g_corrected,
        output_path=os.path.join(session_dir, summary_png),
    )

    image_names = [fit_png, summary_png]
    if periods:
        # Figure 3: per-period analysis; only advertised when actually written.
        _render_period_analysis(periods, L, g_report, gamma,
                                os.path.join(session_dir, period_png))
        image_names.append(period_png)

    analysis_data = {
        'pendulum_length_m': round(float(L), 6),
        'gravity_acceleration_m_s2': round(float(g_report), 6),
        'damping_coefficient_s_inv': round(float(gamma), 6),
        'period_s': round(float(2 * np.pi / omega), 6),
        'angle_max_rad': round(float(theta0), 6),
    }
    return jsonify({
        'success': True,
        'session_id': session_id,
        'images': [f"/api/images/{session_id}/{name}" for name in image_names],
        'analysis_data': analysis_data,
    })
@bp.route('/images/<session_id>/<path:filename>', methods=['GET'])
def get_image(session_id: str, filename: str):
    """Serve a generated image from a session's storage directory.

    ``filename`` is resolved by ``send_from_directory``, which refuses paths
    that escape the given directory. ``session_id`` comes straight from the URL
    and is used to build that directory itself, so dot-segments and path
    separators are rejected before touching storage.
    """
    if session_id in ('', '.', '..') or any(
        sep and sep in session_id for sep in (os.sep, os.altsep)
    ):
        return _json_error('无效的会话ID', 404)
    storage = LocalSessionStorage()
    directory = storage._session_path(session_id)
    return send_from_directory(directory, filename)
@bp.route('/sessions/<session_id>', methods=['DELETE'])
def delete_session(session_id: str):
    """Delete a storage session and report whether deletion succeeded.

    ``session_id`` comes straight from the URL and selects what storage
    deletes, so dot-segments and path separators are rejected up front.

    Returns:
        JSON ``{"success": <bool>}`` from ``LocalSessionStorage.delete_session``,
        or a 400 JSON error for a malformed id.
    """
    if session_id in ('', '.', '..') or any(
        sep and sep in session_id for sep in (os.sep, os.altsep)
    ):
        return _json_error('无效的会话ID', 400)
    storage = LocalSessionStorage()
    ok = storage.delete_session(session_id)
    return jsonify({'success': ok})