248 lines
9.4 KiB
Python
248 lines
9.4 KiB
Python
from __future__ import annotations
|
||
import io
|
||
import os
|
||
import logging
|
||
from flask import Blueprint, current_app, jsonify, request, send_from_directory
|
||
from werkzeug.utils import secure_filename
|
||
import numpy as np
|
||
import pandas as pd
|
||
|
||
from .analysis import (
|
||
load_data,
|
||
calculate_angle_from_pendulum_length,
|
||
fit_pendulum_model,
|
||
damped_oscillator,
|
||
calculate_physical_parameters,
|
||
identify_periods,
|
||
visualize_results,
|
||
visualize_physical_parameters,
|
||
calculate_large_angle_g,
|
||
savgol_filter,
|
||
fit_circle,
|
||
)
|
||
from .storage import LocalSessionStorage
|
||
|
||
bp = Blueprint("api", __name__, url_prefix="/api")
|
||
log = logging.getLogger(__name__)
|
||
|
||
|
||
def _json_error(message: str, status: int = 400):
    """Build the standard failure response used by this blueprint.

    Args:
        message: Human-readable error text placed under the ``error`` key.
        status: HTTP status code returned alongside the body (default 400).

    Returns:
        A ``(response, status)`` tuple that Flask accepts as a view result.
    """
    payload = {"success": False, "error": message}
    response = jsonify(payload)
    return response, status
|
||
|
||
|
||
@bp.before_app_request
def _housekeeping():
    """Opportunistically purge expired sessions before each request.

    The hook is registered app-wide, so it runs for every request the
    application serves. Cleanup is best-effort: a failure is logged and
    swallowed so that an unrelated request is never failed by housekeeping.

    Returns:
        None, so Flask continues with normal request dispatch.
    """
    try:
        LocalSessionStorage().cleanup_expired()
    except Exception:
        # Request-hook boundary: record the traceback, keep serving requests.
        log.exception("清理过期会话失败")
|
||
|
||
|
||
@bp.route("/health", methods=["GET"])
def health():
    """Liveness probe: always answers with a constant ``{"status": "ok"}`` body."""
    body = {"status": "ok"}
    return jsonify(body)
|
||
|
||
|
||
# Fewest samples accepted for the fit; the fitted model has four parameters
# (theta0, gamma, omega, phi), so fewer points cannot constrain it.
_MIN_FIT_POINTS = 4


def _resolve_pendulum_length(data):
    """Work out the pendulum length (metres) from the uploaded table.

    Returns:
        ``(data, pendulum_length, error)``. ``data`` may be a renamed copy for
        the legacy column layout. ``error`` is ``None`` on success, otherwise
        the user-facing message to return with HTTP 400.
    """
    cols = set(data.columns)

    if {'time', 'x', 'y', 'pendulum_length'}.issubset(cols):
        # New format: the length is carried in the CSV; the first row is used.
        try:
            pendulum_length = float(data['pendulum_length'].iloc[0])
        except (TypeError, ValueError, IndexError):
            # Non-numeric cell or an empty table.
            return data, None, '摆长参数无效'
        log.info("从CSV中读取摆长: %.4f 米", pendulum_length)
        return data, pendulum_length, None

    if {'time', 'x/mm', 'y/mm'}.issubset(cols):
        # Legacy format: coordinates are already in millimetres, so the length
        # is inferred from the radius of a circle fitted to the trajectory.
        log.warning("使用旧数据格式,将通过拟合圆推断摆长")
        data = data.rename(columns={'x/mm': 'x', 'y/mm': 'y'})
        _, _, r_mm = fit_circle(data['x'].values, data['y'].values)
        return data, r_mm / 1000.0, None  # mm -> m

    if {'time', 'x', 'y'}.issubset(cols):
        return data, None, '缺少摆长信息,请在前端输入摆长参数'

    return data, None, '数据格式不兼容,需要time、x、y列以及pendulum_length信息'


def _filter_relative_outliers(amplitudes, values, threshold):
    """Keep the (amplitude, value) pairs whose value deviates from the mean of
    ``values`` by at most ``threshold`` (relative deviation).

    Returns:
        Two parallel lists: the kept amplitudes and the kept values.
    """
    mean = np.mean(values)
    kept = [
        (amp, val)
        for amp, val in zip(amplitudes, values)
        if abs(val - mean) / mean <= threshold
    ]
    return [amp for amp, _ in kept], [val for _, val in kept]


def _plot_period_analysis(periods, L, gamma, g_reference, output_path):
    """Render the per-period figure (period vs. amplitude, g vs. amplitude).

    Args:
        periods: Non-empty sequence of mappings with ``max_amplitude`` and
            ``duration`` entries.
        L: Pendulum length used in the period formulas.
        gamma: Damping parameter taken from the fit.
        g_reference: Gravity value used for the theoretical period curve.
        output_path: Destination file for the PNG.
    """
    # Imported lazily, as in the original handler, so pyplot is only loaded
    # when this figure is actually produced.
    import matplotlib.pyplot as plt

    fig = plt.figure(figsize=(10, 6))
    try:
        # Axes are addressed explicitly instead of through pyplot's implicit
        # "current figure", which is process-global state.
        amplitudes = [p['max_amplitude'] for p in periods]
        durations = [p['duration'] for p in periods]
        theta_max = np.asarray(amplitudes, dtype=float)
        T = np.asarray(durations, dtype=float)

        # --- Top panel: measured period against maximum swing angle -------
        ax_period = fig.add_subplot(2, 1, 1)
        f_amp, f_dur = _filter_relative_outliers(amplitudes, durations, 0.008)
        ax_period.plot(f_amp, f_dur, 'bo-')
        ax_period.set_xlabel('最大摆角 (rad)')
        ax_period.set_ylabel('周期 T (s)')
        ax_period.set_title('周期与最大摆角的关系')
        ax_period.grid(True)

        theta_theory = np.linspace(0, max(amplitudes) * 1.1, 100)
        T0 = 2 * np.pi * np.sqrt(L / g_reference)
        T_angle_only = T0 * (1 + theta_theory**2 / 16 + 11 * theta_theory**4 / 3072)
        ax_period.plot(theta_theory, T_angle_only, 'r-', label='大摆角修正')
        ax_period.legend()

        # --- Bottom panel: g recovered from each period -------------------
        ax_g = fig.add_subplot(2, 1, 2)
        angle_correction = 1 + theta_max**2 / 16 + 11 * theta_max**4 / 3072
        g_uncorrected = (4 * np.pi**2 * L) / T**2

        # The measured period is T = T0 * correction, hence
        # g = 4*pi^2*L / T0^2 = 4*pi^2*L * correction^2 / T^2.
        # The correction must therefore multiply (not divide) the naive value,
        # matching the combined formula below.
        g_values_angle = g_uncorrected * angle_correction**2

        omega0_sq = g_uncorrected / L
        damping_factor = 1 + gamma**2 / (8 * omega0_sq)
        combined_correction = angle_correction * damping_factor
        g_values_combined = g_uncorrected * combined_correction**2

        fa1, fg1 = _filter_relative_outliers(amplitudes, g_values_angle, 0.01)
        fa2, fg2 = _filter_relative_outliers(amplitudes, g_values_combined, 0.01)

        ax_g.plot(fa1, fg1, 'ro-', label='仅大摆角修正')
        ax_g.plot(fa2, fg2, 'go-', label='大摆角+阻尼修正')
        ax_g.axhline(y=9.8, color='k', linestyle='--', label='标准重力加速度 9.8 m/s²')
        ax_g.set_xlabel('最大摆角 (rad)')
        ax_g.set_ylabel('计算的重力加速度 (m/s²)')
        ax_g.set_title('各个周期计算的重力加速度 (不同修正方法)')
        ax_g.grid(True)
        ax_g.legend()

        fig.tight_layout()
        fig.savefig(output_path, dpi=300)
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)


@bp.route("/analyze", methods=["POST"])
def analyze():
    """Analyze an uploaded pendulum-tracking CSV and return fit results.

    Expects a multipart upload with a ``csv_file`` part. On success the JSON
    body contains ``success``, ``session_id``, ``images`` (URLs of the figures
    that were actually generated) and ``analysis_data``. Input problems are
    reported through ``_json_error`` with HTTP 400.
    """
    if 'csv_file' not in request.files:
        return _json_error('没有上传CSV文件', 400)

    file = request.files['csv_file']
    if not file.filename:
        return _json_error('空文件名', 400)

    # NOTE(review): upload size limits are presumably enforced through the
    # app's MAX_CONTENT_LENGTH setting; that configuration is not visible here.
    # The upload is read into memory and handed over as a file-like object.
    buf = io.BytesIO(file.read())
    buf.seek(0)
    data = load_data(buf)
    if data is None:
        return _json_error('CSV文件格式错误或无法读取', 400)

    # Column normalization and pendulum length extraction.
    data, pendulum_length, error = _resolve_pendulum_length(data)
    if error is not None:
        return _json_error(error, 400)
    # NaN compares False against 0, so finiteness is checked explicitly.
    if pendulum_length is None or not np.isfinite(pendulum_length) or pendulum_length <= 0:
        return _json_error('摆长参数无效', 400)

    # Core analysis.
    t, theta, L, pivot = calculate_angle_from_pendulum_length(data, pendulum_length)

    # Scale factor from coordinate units to millimetres (used for plotting).
    _, _, r_pixels = fit_circle(data['x'].values, data['y'].values)
    if not np.isfinite(r_pixels) or r_pixels <= 0:
        # A degenerate circle fit would otherwise cause a division by zero.
        return _json_error('无法根据轨迹拟合圆,数据无效', 400)
    mm_per_pixel = (pendulum_length * 1000) / r_pixels

    # Smoothing: the window must be odd, at least 3 and shorter than the data.
    n_samples = len(theta)
    window_length = min(11, n_samples - 1)
    if window_length < 3:
        return _json_error('数据点过少,无法进行分析', 400)
    if window_length % 2 == 0:
        window_length -= 1
    theta_smooth = savgol_filter(theta, window_length=window_length, polyorder=2)

    # Analysis range: trim samples at both ends of the recording.
    if n_samples <= 200:
        start_idx = 10
        end_idx = max(n_samples - 10, start_idx + 5)
    else:
        start_idx = min(100, n_samples // 10)
        end_idx = min(20000, n_samples - 10)

    t_analyze = t[start_idx:end_idx]
    theta_analyze = theta_smooth[start_idx:end_idx]
    if len(t_analyze) < _MIN_FIT_POINTS:
        # Short recordings leave an empty or tiny slice after trimming.
        return _json_error('数据点过少,无法进行分析', 400)

    popt, pcov = fit_pendulum_model(t_analyze, theta_analyze)
    if popt is None:
        return _json_error('拟合模型失败', 400)

    theta_fit = damped_oscillator(t_analyze, *popt)

    # Allocate a session and write the figures into it.
    storage = LocalSessionStorage()
    session_id = storage.new_session()
    session_dir = os.path.join(storage.base_dir, session_id)
    base = "/api/images"

    # Figure 1: fit overview.
    visualize_results(
        t_analyze,
        theta[start_idx:end_idx],
        theta_analyze,
        theta_fit,
        popt,
        data,
        pivot,
        mm_per_pixel,
        output_path=os.path.join(session_dir, 'pendulum_fit_analysis.png'),
    )

    # Physical parameters, detected periods and the corrected g.
    g, b_over_m, g_no_damping = calculate_physical_parameters(L, popt)
    periods = identify_periods(t_analyze, theta_analyze)
    _theta0, gamma, omega, _phi = popt
    g_corrected = calculate_large_angle_g(L, periods, gamma)

    # Figure 2: parameter summary.
    visualize_physical_parameters(
        L,
        popt,
        g,
        b_over_m,
        g_no_damping,
        g_corrected,
        output_path=os.path.join(session_dir, 'pendulum_parameters_summary.png'),
    )

    images = [
        f"{base}/{session_id}/pendulum_fit_analysis.png",
        f"{base}/{session_id}/pendulum_parameters_summary.png",
    ]

    # Figure 3 (period analysis) exists only when periods were detected, so
    # its URL is advertised only in that case.
    if periods:
        _plot_period_analysis(
            periods,
            L,
            gamma,
            g_corrected or g,
            os.path.join(session_dir, 'pendulum_period_analysis.png'),
        )
        images.append(f"{base}/{session_id}/pendulum_period_analysis.png")

    # Response payload.
    analysis_data = {
        'pendulum_length_m': round(float(L), 6),
        'gravity_acceleration_m_s2': round(float(g_corrected or g), 6),
        'damping_coefficient_s_inv': round(float(gamma), 6),
        'period_s': round(float(2 * np.pi / omega), 6),
        'angle_max_rad': round(float(popt[0]), 6),
    }

    return jsonify({
        'success': True,
        'session_id': session_id,
        'images': images,
        'analysis_data': analysis_data,
    })
|
||
|
||
|
||
@bp.route('/images/<session_id>/<path:filename>', methods=['GET'])
def get_image(session_id: str, filename: str):
    """Serve a generated image from a session directory.

    Args:
        session_id: Session identifier taken from the URL (untrusted).
        filename: Path of the file inside the session directory (untrusted;
            passed to ``send_from_directory``).

    Returns:
        The file response, or a JSON 404 error when ``session_id`` resolves
        to a location outside the storage root.
    """
    storage = LocalSessionStorage()
    directory = storage._session_path(session_id)

    # The session id comes straight from the URL. Make sure the directory it
    # maps to really sits below the storage root (e.g. reject "..") before
    # anything is served from it.
    root = os.path.normcase(os.path.realpath(storage.base_dir))
    resolved = os.path.normcase(os.path.realpath(directory))
    if not resolved.startswith(os.path.join(root, '')):
        return _json_error('无效的会话ID', 404)

    return send_from_directory(directory, filename)
|
||
|
||
|
||
@bp.route('/sessions/<session_id>', methods=['DELETE'])
def delete_session(session_id: str):
    """Delete a session's stored data.

    Args:
        session_id: Session identifier taken from the URL (untrusted).

    Returns:
        JSON ``{'success': <result of storage.delete_session>}``, or a JSON
        400 error when ``session_id`` would resolve outside the storage root.
    """
    storage = LocalSessionStorage()

    # Session files are written under <base_dir>/<session_id>; refuse ids
    # (e.g. "..") that would point a delete at anything outside that root.
    root = os.path.normcase(os.path.realpath(storage.base_dir))
    target = os.path.normcase(
        os.path.realpath(os.path.join(storage.base_dir, session_id))
    )
    if not target.startswith(os.path.join(root, '')):
        return _json_error('无效的会话ID', 400)

    ok = storage.delete_session(session_id)
    return jsonify({'success': ok})