709 lines
29 KiB
Python
709 lines
29 KiB
Python
import os
|
||
import pandas as pd
|
||
import numpy as np
|
||
import json
|
||
import plotly.graph_objects as go
|
||
import plotly.utils
|
||
from flask import Flask, render_template, request, jsonify
|
||
from flask_cors import CORS
|
||
import sys
|
||
import warnings
|
||
import datetime
|
||
warnings.filterwarnings('ignore')
|
||
|
||
# 添加项目根目录到路径
|
||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||
|
||
try:
|
||
from model import Kronos, KronosTokenizer, KronosPredictor
|
||
MODEL_AVAILABLE = True
|
||
except ImportError:
|
||
MODEL_AVAILABLE = False
|
||
print("警告: Kronos模型无法导入,将使用模拟数据进行演示")
|
||
|
||
app = Flask(__name__)
|
||
CORS(app)
|
||
|
||
# 全局变量存储模型
|
||
tokenizer = None
|
||
model = None
|
||
predictor = None
|
||
|
||
# 可用的模型配置
|
||
AVAILABLE_MODELS = {
|
||
'kronos-mini': {
|
||
'name': 'Kronos-mini',
|
||
'model_id': 'NeoQuasar/Kronos-mini',
|
||
'tokenizer_id': 'NeoQuasar/Kronos-Tokenizer-2k',
|
||
'context_length': 2048,
|
||
'params': '4.1M',
|
||
'description': '轻量级模型,适合快速预测'
|
||
},
|
||
'kronos-small': {
|
||
'name': 'Kronos-small',
|
||
'model_id': 'NeoQuasar/Kronos-small',
|
||
'tokenizer_id': 'NeoQuasar/Kronos-Tokenizer-base',
|
||
'context_length': 512,
|
||
'params': '24.7M',
|
||
'description': '小型模型,平衡性能和速度'
|
||
},
|
||
'kronos-base': {
|
||
'name': 'Kronos-base',
|
||
'model_id': 'NeoQuasar/Kronos-base',
|
||
'tokenizer_id': 'NeoQuasar/Kronos-Tokenizer-base',
|
||
'context_length': 512,
|
||
'params': '102.3M',
|
||
'description': '基础模型,提供更好的预测质量'
|
||
}
|
||
}
|
||
|
||
def load_data_files():
|
||
"""扫描data目录并返回可用的数据文件"""
|
||
data_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'data')
|
||
data_files = []
|
||
|
||
if os.path.exists(data_dir):
|
||
for file in os.listdir(data_dir):
|
||
if file.endswith(('.csv', '.feather')):
|
||
file_path = os.path.join(data_dir, file)
|
||
file_size = os.path.getsize(file_path)
|
||
data_files.append({
|
||
'name': file,
|
||
'path': file_path,
|
||
'size': f"{file_size / 1024:.1f} KB" if file_size < 1024*1024 else f"{file_size / (1024*1024):.1f} MB"
|
||
})
|
||
|
||
return data_files
|
||
|
||
def load_data_file(file_path):
|
||
"""加载数据文件"""
|
||
try:
|
||
if file_path.endswith('.csv'):
|
||
df = pd.read_csv(file_path)
|
||
elif file_path.endswith('.feather'):
|
||
df = pd.read_feather(file_path)
|
||
else:
|
||
return None, "不支持的文件格式"
|
||
|
||
# 检查必要的列
|
||
required_cols = ['open', 'high', 'low', 'close']
|
||
if not all(col in df.columns for col in required_cols):
|
||
return None, f"缺少必要的列: {required_cols}"
|
||
|
||
# 处理时间戳列
|
||
if 'timestamps' in df.columns:
|
||
df['timestamps'] = pd.to_datetime(df['timestamps'])
|
||
elif 'timestamp' in df.columns:
|
||
df['timestamps'] = pd.to_datetime(df['timestamp'])
|
||
elif 'date' in df.columns:
|
||
# 如果列名是'date',将其重命名为'timestamps'
|
||
df['timestamps'] = pd.to_datetime(df['date'])
|
||
else:
|
||
# 如果没有时间戳列,创建一个
|
||
df['timestamps'] = pd.date_range(start='2024-01-01', periods=len(df), freq='1H')
|
||
|
||
# 确保数值列是数值类型
|
||
for col in ['open', 'high', 'low', 'close']:
|
||
df[col] = pd.to_numeric(df[col], errors='coerce')
|
||
|
||
# 处理volume列(可选)
|
||
if 'volume' in df.columns:
|
||
df['volume'] = pd.to_numeric(df['volume'], errors='coerce')
|
||
|
||
# 处理amount列(可选,但不用于预测)
|
||
if 'amount' in df.columns:
|
||
df['amount'] = pd.to_numeric(df['amount'], errors='coerce')
|
||
|
||
# 删除包含NaN的行
|
||
df = df.dropna()
|
||
|
||
return df, None
|
||
|
||
except Exception as e:
|
||
return None, f"加载文件失败: {str(e)}"
|
||
|
||
def save_prediction_results(file_path, prediction_type, prediction_results, actual_data, input_data, prediction_params):
|
||
"""保存预测结果到文件"""
|
||
try:
|
||
# 创建预测结果目录
|
||
results_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prediction_results')
|
||
os.makedirs(results_dir, exist_ok=True)
|
||
|
||
# 生成文件名
|
||
timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
|
||
filename = f'prediction_{timestamp}.json'
|
||
filepath = os.path.join(results_dir, filename)
|
||
|
||
# 准备保存的数据
|
||
save_data = {
|
||
'timestamp': datetime.datetime.now().isoformat(),
|
||
'file_path': file_path,
|
||
'prediction_type': prediction_type,
|
||
'prediction_params': prediction_params,
|
||
'input_data_summary': {
|
||
'rows': len(input_data),
|
||
'columns': list(input_data.columns),
|
||
'price_range': {
|
||
'open': {'min': float(input_data['open'].min()), 'max': float(input_data['open'].max())},
|
||
'high': {'min': float(input_data['high'].min()), 'max': float(input_data['high'].max())},
|
||
'low': {'min': float(input_data['low'].min()), 'max': float(input_data['low'].max())},
|
||
'close': {'min': float(input_data['close'].min()), 'max': float(input_data['close'].max())}
|
||
},
|
||
'last_values': {
|
||
'open': float(input_data['open'].iloc[-1]),
|
||
'high': float(input_data['high'].iloc[-1]),
|
||
'low': float(input_data['low'].iloc[-1]),
|
||
'close': float(input_data['close'].iloc[-1])
|
||
}
|
||
},
|
||
'prediction_results': prediction_results,
|
||
'actual_data': actual_data,
|
||
'analysis': {}
|
||
}
|
||
|
||
# 如果有实际数据,进行对比分析
|
||
if actual_data and len(actual_data) > 0:
|
||
# 计算连续性分析
|
||
if len(prediction_results) > 0 and len(actual_data) > 0:
|
||
last_pred = prediction_results[0] # 第一个预测点
|
||
first_actual = actual_data[0] # 第一个实际点
|
||
|
||
save_data['analysis']['continuity'] = {
|
||
'last_prediction': {
|
||
'open': last_pred['open'],
|
||
'high': last_pred['high'],
|
||
'low': last_pred['low'],
|
||
'close': last_pred['close']
|
||
},
|
||
'first_actual': {
|
||
'open': first_actual['open'],
|
||
'high': first_actual['high'],
|
||
'low': first_actual['low'],
|
||
'close': first_actual['close']
|
||
},
|
||
'gaps': {
|
||
'open_gap': abs(last_pred['open'] - first_actual['open']),
|
||
'high_gap': abs(last_pred['high'] - first_actual['high']),
|
||
'low_gap': abs(last_pred['low'] - first_actual['low']),
|
||
'close_gap': abs(last_pred['close'] - first_actual['close'])
|
||
},
|
||
'gap_percentages': {
|
||
'open_gap_pct': (abs(last_pred['open'] - first_actual['open']) / first_actual['open']) * 100,
|
||
'high_gap_pct': (abs(last_pred['high'] - first_actual['high']) / first_actual['high']) * 100,
|
||
'low_gap_pct': (abs(last_pred['low'] - first_actual['low']) / first_actual['low']) * 100,
|
||
'close_gap_pct': (abs(last_pred['close'] - first_actual['close']) / first_actual['close']) * 100
|
||
}
|
||
}
|
||
|
||
# 保存到文件
|
||
with open(filepath, 'w', encoding='utf-8') as f:
|
||
json.dump(save_data, f, indent=2, ensure_ascii=False)
|
||
|
||
print(f"预测结果已保存到: {filepath}")
|
||
return filepath
|
||
|
||
except Exception as e:
|
||
print(f"保存预测结果失败: {e}")
|
||
return None
|
||
|
||
def create_prediction_chart(df, pred_df, lookback, pred_len, actual_df=None, historical_start_idx=0):
|
||
"""创建预测图表"""
|
||
# 使用指定的历史数据起始位置,而不是总是从df的开头开始
|
||
if historical_start_idx + lookback + pred_len <= len(df):
|
||
# 显示指定位置开始的lookback个历史点 + pred_len个预测点
|
||
historical_df = df.iloc[historical_start_idx:historical_start_idx+lookback]
|
||
prediction_range = range(historical_start_idx+lookback, historical_start_idx+lookback+pred_len)
|
||
else:
|
||
# 如果数据不够,调整到可用的最大范围
|
||
available_lookback = min(lookback, len(df) - historical_start_idx)
|
||
available_pred_len = min(pred_len, max(0, len(df) - historical_start_idx - available_lookback))
|
||
historical_df = df.iloc[historical_start_idx:historical_start_idx+available_lookback]
|
||
prediction_range = range(historical_start_idx+available_lookback, historical_start_idx+available_lookback+available_pred_len)
|
||
|
||
# 创建图表
|
||
fig = go.Figure()
|
||
|
||
# 添加历史数据(K线图)
|
||
fig.add_trace(go.Candlestick(
|
||
x=historical_df['timestamps'] if 'timestamps' in historical_df.columns else historical_df.index,
|
||
open=historical_df['open'],
|
||
high=historical_df['high'],
|
||
low=historical_df['low'],
|
||
close=historical_df['close'],
|
||
name='历史数据 (400个数据点)',
|
||
increasing_line_color='#26A69A',
|
||
decreasing_line_color='#EF5350'
|
||
))
|
||
|
||
# 添加预测数据(K线图)
|
||
if pred_df is not None and len(pred_df) > 0:
|
||
# 计算预测数据的时间戳 - 确保与历史数据连续
|
||
if 'timestamps' in df.columns and len(historical_df) > 0:
|
||
# 从历史数据的最后一个时间点开始,按相同的时间间隔创建预测时间戳
|
||
last_timestamp = historical_df['timestamps'].iloc[-1]
|
||
time_diff = df['timestamps'].iloc[1] - df['timestamps'].iloc[0] if len(df) > 1 else pd.Timedelta(hours=1)
|
||
|
||
pred_timestamps = pd.date_range(
|
||
start=last_timestamp + time_diff,
|
||
periods=len(pred_df),
|
||
freq=time_diff
|
||
)
|
||
else:
|
||
# 如果没有时间戳,使用索引
|
||
pred_timestamps = range(len(historical_df), len(historical_df) + len(pred_df))
|
||
|
||
fig.add_trace(go.Candlestick(
|
||
x=pred_timestamps,
|
||
open=pred_df['open'],
|
||
high=pred_df['high'],
|
||
low=pred_df['low'],
|
||
close=pred_df['close'],
|
||
name='预测数据 (120个数据点)',
|
||
increasing_line_color='#66BB6A',
|
||
decreasing_line_color='#FF7043'
|
||
))
|
||
|
||
# 添加实际数据用于对比(如果存在)
|
||
if actual_df is not None and len(actual_df) > 0:
|
||
# 实际数据应该与预测数据在同一个时间段
|
||
if 'timestamps' in df.columns:
|
||
# 实际数据应该使用与预测数据相同的时间戳,确保时间对齐
|
||
if 'pred_timestamps' in locals():
|
||
actual_timestamps = pred_timestamps
|
||
else:
|
||
# 如果没有预测时间戳,从历史数据最后一个时间点开始计算
|
||
if len(historical_df) > 0:
|
||
last_timestamp = historical_df['timestamps'].iloc[-1]
|
||
time_diff = df['timestamps'].iloc[1] - df['timestamps'].iloc[0] if len(df) > 1 else pd.Timedelta(hours=1)
|
||
actual_timestamps = pd.date_range(
|
||
start=last_timestamp + time_diff,
|
||
periods=len(actual_df),
|
||
freq=time_diff
|
||
)
|
||
else:
|
||
actual_timestamps = range(len(historical_df), len(historical_df) + len(actual_df))
|
||
else:
|
||
actual_timestamps = range(len(historical_df), len(historical_df) + len(actual_df))
|
||
|
||
fig.add_trace(go.Candlestick(
|
||
x=actual_timestamps,
|
||
open=actual_df['open'],
|
||
high=actual_df['high'],
|
||
low=actual_df['low'],
|
||
close=actual_df['close'],
|
||
name='实际数据 (120个数据点)',
|
||
increasing_line_color='#FF9800',
|
||
decreasing_line_color='#F44336'
|
||
))
|
||
|
||
# 更新布局
|
||
fig.update_layout(
|
||
title='Kronos 金融预测结果 - 400个历史点 + 120个预测点 vs 120个实际点',
|
||
xaxis_title='时间',
|
||
yaxis_title='价格',
|
||
template='plotly_white',
|
||
height=600,
|
||
showlegend=True
|
||
)
|
||
|
||
# 确保x轴时间连续
|
||
if 'timestamps' in historical_df.columns:
|
||
# 获取所有时间戳并排序
|
||
all_timestamps = []
|
||
if len(historical_df) > 0:
|
||
all_timestamps.extend(historical_df['timestamps'])
|
||
if 'pred_timestamps' in locals():
|
||
all_timestamps.extend(pred_timestamps)
|
||
if 'actual_timestamps' in locals():
|
||
all_timestamps.extend(actual_timestamps)
|
||
|
||
if all_timestamps:
|
||
all_timestamps = sorted(all_timestamps)
|
||
fig.update_xaxes(
|
||
range=[all_timestamps[0], all_timestamps[-1]],
|
||
rangeslider_visible=False,
|
||
type='date'
|
||
)
|
||
|
||
return json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder)
|
||
|
||
@app.route('/')
|
||
def index():
|
||
"""主页"""
|
||
return render_template('index.html')
|
||
|
||
@app.route('/api/data-files')
|
||
def get_data_files():
|
||
"""获取可用的数据文件列表"""
|
||
data_files = load_data_files()
|
||
return jsonify(data_files)
|
||
|
||
@app.route('/api/load-data', methods=['POST'])
|
||
def load_data():
|
||
"""加载数据文件"""
|
||
try:
|
||
data = request.get_json()
|
||
file_path = data.get('file_path')
|
||
|
||
if not file_path:
|
||
return jsonify({'error': '文件路径不能为空'}), 400
|
||
|
||
df, error = load_data_file(file_path)
|
||
if error:
|
||
return jsonify({'error': error}), 400
|
||
|
||
# 检测数据的时间频率
|
||
def detect_timeframe(df):
|
||
if len(df) < 2:
|
||
return "未知"
|
||
|
||
time_diffs = []
|
||
for i in range(1, min(10, len(df))): # 检查前10个时间差
|
||
diff = df['timestamps'].iloc[i] - df['timestamps'].iloc[i-1]
|
||
time_diffs.append(diff)
|
||
|
||
if not time_diffs:
|
||
return "未知"
|
||
|
||
# 计算平均时间差
|
||
avg_diff = sum(time_diffs, pd.Timedelta(0)) / len(time_diffs)
|
||
|
||
# 转换为可读格式
|
||
if avg_diff < pd.Timedelta(minutes=1):
|
||
return f"{avg_diff.total_seconds():.0f}秒"
|
||
elif avg_diff < pd.Timedelta(hours=1):
|
||
return f"{avg_diff.total_seconds() / 60:.0f}分钟"
|
||
elif avg_diff < pd.Timedelta(days=1):
|
||
return f"{avg_diff.total_seconds() / 3600:.0f}小时"
|
||
else:
|
||
return f"{avg_diff.days}天"
|
||
|
||
# 返回数据信息
|
||
data_info = {
|
||
'rows': len(df),
|
||
'columns': list(df.columns),
|
||
'start_date': df['timestamps'].min().isoformat() if 'timestamps' in df.columns else 'N/A',
|
||
'end_date': df['timestamps'].max().isoformat() if 'timestamps' in df.columns else 'N/A',
|
||
'price_range': {
|
||
'min': float(df[['open', 'high', 'low', 'close']].min().min()),
|
||
'max': float(df[['open', 'high', 'low', 'close']].max().max())
|
||
},
|
||
'prediction_columns': ['open', 'high', 'low', 'close'] + (['volume'] if 'volume' in df.columns else []),
|
||
'timeframe': detect_timeframe(df)
|
||
}
|
||
|
||
return jsonify({
|
||
'success': True,
|
||
'data_info': data_info,
|
||
'message': f'成功加载数据,共 {len(df)} 行'
|
||
})
|
||
|
||
except Exception as e:
|
||
return jsonify({'error': f'加载数据失败: {str(e)}'}), 500
|
||
|
||
@app.route('/api/predict', methods=['POST'])
|
||
def predict():
|
||
"""进行预测"""
|
||
try:
|
||
data = request.get_json()
|
||
file_path = data.get('file_path')
|
||
lookback = int(data.get('lookback', 400))
|
||
pred_len = int(data.get('pred_len', 120))
|
||
|
||
# 获取预测质量参数
|
||
temperature = float(data.get('temperature', 1.0))
|
||
top_p = float(data.get('top_p', 0.9))
|
||
sample_count = int(data.get('sample_count', 1))
|
||
|
||
if not file_path:
|
||
return jsonify({'error': '文件路径不能为空'}), 400
|
||
|
||
# 加载数据
|
||
df, error = load_data_file(file_path)
|
||
if error:
|
||
return jsonify({'error': error}), 400
|
||
|
||
if len(df) < lookback:
|
||
return jsonify({'error': f'数据长度不足,需要至少 {lookback} 行数据'}), 400
|
||
|
||
# 进行预测
|
||
if MODEL_AVAILABLE and predictor is not None:
|
||
try:
|
||
# 使用真实的Kronos模型
|
||
# 只使用必要的列:OHLCV,不包含amount
|
||
required_cols = ['open', 'high', 'low', 'close']
|
||
if 'volume' in df.columns:
|
||
required_cols.append('volume')
|
||
|
||
# 处理时间段选择
|
||
start_date = data.get('start_date')
|
||
|
||
if start_date:
|
||
# 自定义时间段 - 修复逻辑:使用选择的窗口内的数据
|
||
start_dt = pd.to_datetime(start_date)
|
||
|
||
# 找到开始时间之后的数据
|
||
mask = df['timestamps'] >= start_dt
|
||
time_range_df = df[mask]
|
||
|
||
# 确保有足够的数据:lookback + pred_len
|
||
if len(time_range_df) < lookback + pred_len:
|
||
return jsonify({'error': f'从开始时间 {start_dt.strftime("%Y-%m-%d %H:%M")} 开始的数据不足,需要至少 {lookback + pred_len} 个数据点,当前只有 {len(time_range_df)} 个'}), 400
|
||
|
||
# 使用选择的窗口内的前lookback个数据点进行预测
|
||
x_df = time_range_df.iloc[:lookback][required_cols]
|
||
x_timestamp = time_range_df.iloc[:lookback]['timestamps']
|
||
|
||
# 使用选择的窗口内的后pred_len个数据点作为实际值
|
||
y_timestamp = time_range_df.iloc[lookback:lookback+pred_len]['timestamps']
|
||
|
||
# 计算实际的时间段长度
|
||
start_timestamp = time_range_df['timestamps'].iloc[0]
|
||
end_timestamp = time_range_df['timestamps'].iloc[lookback+pred_len-1]
|
||
time_span = end_timestamp - start_timestamp
|
||
|
||
prediction_type = f"Kronos模型预测 (选择的窗口内:前{lookback}个数据点预测,后{pred_len}个数据点对比,时间跨度: {time_span})"
|
||
else:
|
||
# 使用最新数据
|
||
x_df = df.iloc[:lookback][required_cols]
|
||
x_timestamp = df.iloc[:lookback]['timestamps']
|
||
y_timestamp = df.iloc[lookback:lookback+pred_len]['timestamps']
|
||
prediction_type = "Kronos模型预测 (最新数据)"
|
||
|
||
# 确保时间戳是Series格式,不是DatetimeIndex,避免Kronos模型的.dt属性错误
|
||
if isinstance(x_timestamp, pd.DatetimeIndex):
|
||
x_timestamp = pd.Series(x_timestamp, name='timestamps')
|
||
if isinstance(y_timestamp, pd.DatetimeIndex):
|
||
y_timestamp = pd.Series(y_timestamp, name='timestamps')
|
||
|
||
pred_df = predictor.predict(
|
||
df=x_df,
|
||
x_timestamp=x_timestamp,
|
||
y_timestamp=y_timestamp,
|
||
pred_len=pred_len,
|
||
T=temperature,
|
||
top_p=top_p,
|
||
sample_count=sample_count
|
||
)
|
||
|
||
except Exception as e:
|
||
return jsonify({'error': f'Kronos模型预测失败: {str(e)}'}), 500
|
||
else:
|
||
return jsonify({'error': 'Kronos模型未加载,请先加载模型'}), 400
|
||
|
||
# 准备实际数据用于对比(如果存在)
|
||
actual_data = []
|
||
actual_df = None
|
||
|
||
if start_date: # 自定义时间段
|
||
# 修复逻辑:使用选择的窗口内的数据
|
||
# 预测使用的是选择的窗口内的前400个数据点
|
||
# 实际数据应该是选择的窗口内的后120个数据点
|
||
start_dt = pd.to_datetime(start_date)
|
||
|
||
# 找到从start_date开始的数据
|
||
mask = df['timestamps'] >= start_dt
|
||
time_range_df = df[mask]
|
||
|
||
if len(time_range_df) >= lookback + pred_len:
|
||
# 获取选择的窗口内的后120个数据点作为实际值
|
||
actual_df = time_range_df.iloc[lookback:lookback+pred_len]
|
||
|
||
for i, (_, row) in enumerate(actual_df.iterrows()):
|
||
actual_data.append({
|
||
'timestamp': row['timestamps'].isoformat(),
|
||
'open': float(row['open']),
|
||
'high': float(row['high']),
|
||
'low': float(row['low']),
|
||
'close': float(row['close']),
|
||
'volume': float(row['volume']) if 'volume' in row else 0,
|
||
'amount': float(row['amount']) if 'amount' in row else 0
|
||
})
|
||
else: # 最新数据
|
||
# 预测使用的是前400个数据点
|
||
# 实际数据应该是400个数据点之后的120个数据点
|
||
if len(df) >= lookback + pred_len:
|
||
actual_df = df.iloc[lookback:lookback+pred_len]
|
||
for i, (_, row) in enumerate(actual_df.iterrows()):
|
||
actual_data.append({
|
||
'timestamp': row['timestamps'].isoformat(),
|
||
'open': float(row['open']),
|
||
'high': float(row['high']),
|
||
'low': float(row['low']),
|
||
'close': float(row['close']),
|
||
'volume': float(row['volume']) if 'volume' in row else 0,
|
||
'amount': float(row['amount']) if 'amount' in row else 0
|
||
})
|
||
|
||
# 创建图表 - 传递历史数据的起始位置
|
||
if start_date:
|
||
# 自定义时间段:找到历史数据在原始df中的起始位置
|
||
start_dt = pd.to_datetime(start_date)
|
||
mask = df['timestamps'] >= start_dt
|
||
historical_start_idx = df[mask].index[0] if len(df[mask]) > 0 else 0
|
||
else:
|
||
# 最新数据:从开头开始
|
||
historical_start_idx = 0
|
||
|
||
chart_json = create_prediction_chart(df, pred_df, lookback, pred_len, actual_df, historical_start_idx)
|
||
|
||
# 准备预测结果数据 - 修复时间戳计算逻辑
|
||
if 'timestamps' in df.columns:
|
||
if start_date:
|
||
# 自定义时间段:使用选择的窗口数据计算时间戳
|
||
start_dt = pd.to_datetime(start_date)
|
||
mask = df['timestamps'] >= start_dt
|
||
time_range_df = df[mask]
|
||
|
||
if len(time_range_df) >= lookback:
|
||
# 从选择的窗口的最后一个时间点开始计算预测时间戳
|
||
last_timestamp = time_range_df['timestamps'].iloc[lookback-1]
|
||
time_diff = df['timestamps'].iloc[1] - df['timestamps'].iloc[0]
|
||
future_timestamps = pd.date_range(
|
||
start=last_timestamp + time_diff,
|
||
periods=pred_len,
|
||
freq=time_diff
|
||
)
|
||
else:
|
||
future_timestamps = []
|
||
else:
|
||
# 最新数据:从整个数据文件的最后时间点开始计算
|
||
last_timestamp = df['timestamps'].iloc[-1]
|
||
time_diff = df['timestamps'].iloc[1] - df['timestamps'].iloc[0]
|
||
future_timestamps = pd.date_range(
|
||
start=last_timestamp + time_diff,
|
||
periods=pred_len,
|
||
freq=time_diff
|
||
)
|
||
else:
|
||
future_timestamps = range(len(df), len(df) + pred_len)
|
||
|
||
prediction_results = []
|
||
for i, (_, row) in enumerate(pred_df.iterrows()):
|
||
prediction_results.append({
|
||
'timestamp': future_timestamps[i].isoformat() if i < len(future_timestamps) else f"T{i}",
|
||
'open': float(row['open']),
|
||
'high': float(row['high']),
|
||
'low': float(row['low']),
|
||
'close': float(row['close']),
|
||
'volume': float(row['volume']) if 'volume' in row else 0,
|
||
'amount': float(row['amount']) if 'amount' in row else 0
|
||
})
|
||
|
||
# 保存预测结果到文件
|
||
try:
|
||
save_prediction_results(
|
||
file_path=file_path,
|
||
prediction_type=prediction_type,
|
||
prediction_results=prediction_results,
|
||
actual_data=actual_data,
|
||
input_data=x_df,
|
||
prediction_params={
|
||
'lookback': lookback,
|
||
'pred_len': pred_len,
|
||
'temperature': temperature,
|
||
'top_p': top_p,
|
||
'sample_count': sample_count,
|
||
'start_date': start_date if start_date else 'latest'
|
||
}
|
||
)
|
||
except Exception as e:
|
||
print(f"保存预测结果失败: {e}")
|
||
|
||
return jsonify({
|
||
'success': True,
|
||
'prediction_type': prediction_type,
|
||
'chart': chart_json,
|
||
'prediction_results': prediction_results,
|
||
'actual_data': actual_data,
|
||
'has_comparison': len(actual_data) > 0,
|
||
'message': f'预测完成,生成了 {pred_len} 个预测点' + (f',包含 {len(actual_data)} 个实际数据点用于对比' if len(actual_data) > 0 else '')
|
||
})
|
||
|
||
except Exception as e:
|
||
return jsonify({'error': f'预测失败: {str(e)}'}), 500
|
||
|
||
@app.route('/api/load-model', methods=['POST'])
|
||
def load_model():
|
||
"""加载Kronos模型"""
|
||
global tokenizer, model, predictor
|
||
|
||
try:
|
||
if not MODEL_AVAILABLE:
|
||
return jsonify({'error': 'Kronos模型库不可用'}), 400
|
||
|
||
data = request.get_json()
|
||
model_key = data.get('model_key', 'kronos-small')
|
||
device = data.get('device', 'cpu')
|
||
|
||
if model_key not in AVAILABLE_MODELS:
|
||
return jsonify({'error': f'不支持的模型: {model_key}'}), 400
|
||
|
||
model_config = AVAILABLE_MODELS[model_key]
|
||
|
||
# 加载tokenizer和模型
|
||
tokenizer = KronosTokenizer.from_pretrained(model_config['tokenizer_id'])
|
||
model = Kronos.from_pretrained(model_config['model_id'])
|
||
|
||
# 创建predictor
|
||
predictor = KronosPredictor(model, tokenizer, device=device, max_context=model_config['context_length'])
|
||
|
||
return jsonify({
|
||
'success': True,
|
||
'message': f'模型加载成功: {model_config["name"]} ({model_config["params"]}) on {device}',
|
||
'model_info': {
|
||
'name': model_config['name'],
|
||
'params': model_config['params'],
|
||
'context_length': model_config['context_length'],
|
||
'description': model_config['description']
|
||
}
|
||
})
|
||
|
||
except Exception as e:
|
||
return jsonify({'error': f'模型加载失败: {str(e)}'}), 500
|
||
|
||
@app.route('/api/available-models')
|
||
def get_available_models():
|
||
"""获取可用的模型列表"""
|
||
return jsonify({
|
||
'models': AVAILABLE_MODELS,
|
||
'model_available': MODEL_AVAILABLE
|
||
})
|
||
|
||
@app.route('/api/model-status')
|
||
def get_model_status():
|
||
"""获取模型状态"""
|
||
if MODEL_AVAILABLE:
|
||
if predictor is not None:
|
||
return jsonify({
|
||
'available': True,
|
||
'loaded': True,
|
||
'message': 'Kronos模型已加载并可用',
|
||
'current_model': {
|
||
'name': predictor.model.__class__.__name__,
|
||
'device': str(next(predictor.model.parameters()).device)
|
||
}
|
||
})
|
||
else:
|
||
return jsonify({
|
||
'available': True,
|
||
'loaded': False,
|
||
'message': 'Kronos模型可用但未加载'
|
||
})
|
||
else:
|
||
return jsonify({
|
||
'available': False,
|
||
'loaded': False,
|
||
'message': 'Kronos模型库不可用,请安装相关依赖'
|
||
})
|
||
|
||
if __name__ == '__main__':
|
||
print("启动Kronos Web UI...")
|
||
print(f"模型可用性: {MODEL_AVAILABLE}")
|
||
if MODEL_AVAILABLE:
|
||
print("提示: 可以通过 /api/load-model 接口加载Kronos模型")
|
||
else:
|
||
print("提示: 将使用模拟数据进行演示")
|
||
|
||
app.run(debug=True, host='0.0.0.0', port=7070)
|