feat: 添加预测结果保存功能,用于分析预测质量问题

This commit is contained in:
Charles 2025-08-26 16:35:34 +08:00
parent 1f394cace3
commit 609235c077

View File

@ -8,6 +8,7 @@ from flask import Flask, render_template, request, jsonify
from flask_cors import CORS
import sys
import warnings
import datetime
warnings.filterwarnings('ignore')
# 添加项目根目录到路径
@ -121,6 +122,90 @@ def load_data_file(file_path):
except Exception as e:
return None, f"加载文件失败: {str(e)}"
def save_prediction_results(file_path, prediction_type, prediction_results, actual_data, input_data, prediction_params):
"""保存预测结果到文件"""
try:
# 创建预测结果目录
results_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'prediction_results')
os.makedirs(results_dir, exist_ok=True)
# 生成文件名
timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
filename = f'prediction_{timestamp}.json'
filepath = os.path.join(results_dir, filename)
# 准备保存的数据
save_data = {
'timestamp': datetime.datetime.now().isoformat(),
'file_path': file_path,
'prediction_type': prediction_type,
'prediction_params': prediction_params,
'input_data_summary': {
'rows': len(input_data),
'columns': list(input_data.columns),
'price_range': {
'open': {'min': float(input_data['open'].min()), 'max': float(input_data['open'].max())},
'high': {'min': float(input_data['high'].min()), 'max': float(input_data['high'].max())},
'low': {'min': float(input_data['low'].min()), 'max': float(input_data['low'].max())},
'close': {'min': float(input_data['close'].min()), 'max': float(input_data['close'].max())}
},
'last_values': {
'open': float(input_data['open'].iloc[-1]),
'high': float(input_data['high'].iloc[-1]),
'low': float(input_data['low'].iloc[-1]),
'close': float(input_data['close'].iloc[-1])
}
},
'prediction_results': prediction_results,
'actual_data': actual_data,
'analysis': {}
}
# 如果有实际数据,进行对比分析
if actual_data and len(actual_data) > 0:
# 计算连续性分析
if len(prediction_results) > 0 and len(actual_data) > 0:
last_pred = prediction_results[0] # 第一个预测点
first_actual = actual_data[0] # 第一个实际点
save_data['analysis']['continuity'] = {
'last_prediction': {
'open': last_pred['open'],
'high': last_pred['high'],
'low': last_pred['low'],
'close': last_pred['close']
},
'first_actual': {
'open': first_actual['open'],
'high': first_actual['high'],
'low': first_actual['low'],
'close': first_actual['close']
},
'gaps': {
'open_gap': abs(last_pred['open'] - first_actual['open']),
'high_gap': abs(last_pred['high'] - first_actual['high']),
'low_gap': abs(last_pred['low'] - first_actual['low']),
'close_gap': abs(last_pred['close'] - first_actual['close'])
},
'gap_percentages': {
'open_gap_pct': (abs(last_pred['open'] - first_actual['open']) / first_actual['open']) * 100,
'high_gap_pct': (abs(last_pred['high'] - first_actual['high']) / first_actual['high']) * 100,
'low_gap_pct': (abs(last_pred['low'] - first_actual['low']) / first_actual['low']) * 100,
'close_gap_pct': (abs(last_pred['close'] - first_actual['close']) / first_actual['close']) * 100
}
}
# 保存到文件
with open(filepath, 'w', encoding='utf-8') as f:
json.dump(save_data, f, indent=2, ensure_ascii=False)
print(f"预测结果已保存到: {filepath}")
return filepath
except Exception as e:
print(f"保存预测结果失败: {e}")
return None
def create_prediction_chart(df, pred_df, lookback, pred_len, actual_df=None, historical_start_idx=0):
"""创建预测图表"""
# 使用指定的历史数据起始位置而不是总是从df的开头开始
@ -505,6 +590,26 @@ def predict():
'amount': float(row['amount']) if 'amount' in row else 0
})
# 保存预测结果到文件
try:
save_prediction_results(
file_path=file_path,
prediction_type=prediction_type,
prediction_results=prediction_results,
actual_data=actual_data,
input_data=x_df,
prediction_params={
'lookback': lookback,
'pred_len': pred_len,
'temperature': temperature,
'top_p': top_p,
'sample_count': sample_count,
'start_date': start_date if start_date else 'latest'
}
)
except Exception as e:
print(f"保存预测结果失败: {e}")
return jsonify({
'success': True,
'prediction_type': prediction_type,