fix: 解决Kronos模型预测连续性问题,通过价格转换保持数据连续性

This commit is contained in:
Charles 2025-08-26 17:02:24 +08:00
parent ca2502757f
commit b4b270954b

View File

@ -122,6 +122,31 @@ def load_data_file(file_path):
except Exception as e:
return None, f"加载文件失败: {str(e)}"
def convert_price_to_returns(df, price_cols=['open', 'high', 'low', 'close']):
"""将绝对价格转换为相对变化率解决Kronos模型的连续性问题"""
returns_df = df.copy()
# 计算相对变化率 (pct_change)
for col in price_cols:
if col in returns_df.columns:
returns_df[col] = returns_df[col].pct_change()
# 第一行会是NaN用0填充
returns_df = returns_df.fillna(0)
return returns_df
def convert_returns_to_price(returns_df, initial_prices, price_cols=['open', 'high', 'low', 'close']):
"""将相对变化率转换回绝对价格"""
price_df = returns_df.copy()
for col in price_cols:
if col in price_df.columns:
# 从初始价格开始,逐步计算绝对价格
price_df[col] = (1 + returns_df[col]) * initial_prices[col]
return price_df
def save_prediction_results(file_path, prediction_type, prediction_results, actual_data, input_data, prediction_params):
"""保存预测结果到文件"""
try:
@ -476,8 +501,21 @@ def predict():
if isinstance(y_timestamp, pd.DatetimeIndex):
y_timestamp = pd.Series(y_timestamp, name='timestamps')
pred_df = predictor.predict(
df=x_df,
# 解决Kronos模型连续性问题将绝对价格转换为相对变化率
original_x_df = x_df.copy()
initial_prices = {
'open': x_df['open'].iloc[0],
'high': x_df['high'].iloc[0],
'low': x_df['low'].iloc[0],
'close': x_df['close'].iloc[0]
}
# 转换为相对变化率
x_df_returns = convert_price_to_returns(x_df, ['open', 'high', 'low', 'close'])
# 使用转换后的数据进行预测
pred_df_returns = predictor.predict(
df=x_df_returns,
x_timestamp=x_timestamp,
y_timestamp=y_timestamp,
pred_len=pred_len,
@ -486,6 +524,15 @@ def predict():
sample_count=sample_count
)
# 将预测结果转换回绝对价格
pred_df = convert_returns_to_price(pred_df_returns, initial_prices, ['open', 'high', 'low', 'close'])
# 保持volume列不变如果存在
if 'volume' in pred_df.columns:
pred_df['volume'] = pred_df_returns['volume']
if 'amount' in pred_df.columns:
pred_df['amount'] = pred_df_returns['amount']
except Exception as e:
return jsonify({'error': f'Kronos模型预测失败: {str(e)}'}), 500
else: