init
2
.gitignore
vendored
@ -73,3 +73,5 @@ venv.bak/
|
|||||||
*.temp
|
*.temp
|
||||||
temp/
|
temp/
|
||||||
tmp/
|
tmp/
|
||||||
|
|
||||||
|
figures
|
||||||
@ -1,72 +0,0 @@
|
|||||||
import pandas as pd
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import sys
|
|
||||||
sys.path.append("../")
|
|
||||||
from model import Kronos, KronosTokenizer, KronosPredictor
|
|
||||||
|
|
||||||
|
|
||||||
def plot_prediction(kline_df, pred_df):
|
|
||||||
pred_df.index = kline_df.index[-pred_df.shape[0]:]
|
|
||||||
sr_close = kline_df['close']
|
|
||||||
sr_pred_close = pred_df['close']
|
|
||||||
sr_close.name = 'Ground Truth'
|
|
||||||
sr_pred_close.name = "Prediction"
|
|
||||||
|
|
||||||
sr_volume = kline_df['volume']
|
|
||||||
sr_pred_volume = pred_df['volume']
|
|
||||||
sr_volume.name = 'Ground Truth'
|
|
||||||
sr_pred_volume.name = "Prediction"
|
|
||||||
|
|
||||||
close_df = pd.concat([sr_close, sr_pred_close], axis=1)
|
|
||||||
volume_df = pd.concat([sr_volume, sr_pred_volume], axis=1)
|
|
||||||
|
|
||||||
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 6), sharex=True)
|
|
||||||
|
|
||||||
ax1.plot(close_df['Ground Truth'], label='Ground Truth', color='blue', linewidth=1.5)
|
|
||||||
ax1.plot(close_df['Prediction'], label='Prediction', color='red', linewidth=1.5)
|
|
||||||
ax1.set_ylabel('Close Price', fontsize=14)
|
|
||||||
ax1.legend(loc='lower left', fontsize=12)
|
|
||||||
ax1.grid(True)
|
|
||||||
|
|
||||||
ax2.plot(volume_df['Ground Truth'], label='Ground Truth', color='blue', linewidth=1.5)
|
|
||||||
ax2.plot(volume_df['Prediction'], label='Prediction', color='red', linewidth=1.5)
|
|
||||||
ax2.set_ylabel('Volume', fontsize=14)
|
|
||||||
ax2.legend(loc='upper left', fontsize=12)
|
|
||||||
ax2.grid(True)
|
|
||||||
|
|
||||||
plt.tight_layout()
|
|
||||||
plt.show()
|
|
||||||
|
|
||||||
|
|
||||||
# 1. Load Model and Tokenizer
|
|
||||||
tokenizer = KronosTokenizer.from_pretrained('/home/csc/huggingface/Kronos-Tokenizer-base/')
|
|
||||||
model = Kronos.from_pretrained("/home/csc/huggingface/Kronos-base/")
|
|
||||||
|
|
||||||
# 2. Instantiate Predictor
|
|
||||||
predictor = KronosPredictor(model, tokenizer, device="cuda:0", max_context=512)
|
|
||||||
|
|
||||||
# 3. Prepare Data
|
|
||||||
df = pd.read_csv("./data/XSHG_5min_600977.csv")
|
|
||||||
df['timestamps'] = pd.to_datetime(df['timestamps'])
|
|
||||||
|
|
||||||
lookback = 400
|
|
||||||
pred_len = 120
|
|
||||||
|
|
||||||
dfs = []
|
|
||||||
xtsp = []
|
|
||||||
ytsp = []
|
|
||||||
for i in range(5):
|
|
||||||
idf = df.loc[(i*400):(i*400+lookback-1), ['open', 'high', 'low', 'close', 'volume', 'amount']]
|
|
||||||
i_x_timestamp = df.loc[(i*400):(i*400+lookback-1), 'timestamps']
|
|
||||||
i_y_timestamp = df.loc[(i*400+lookback):(i*400+lookback+pred_len-1), 'timestamps']
|
|
||||||
|
|
||||||
dfs.append(idf)
|
|
||||||
xtsp.append(i_x_timestamp)
|
|
||||||
ytsp.append(i_y_timestamp)
|
|
||||||
|
|
||||||
pred_df = predictor.predict_batch(
|
|
||||||
df_list=dfs,
|
|
||||||
x_timestamp_list=xtsp,
|
|
||||||
y_timestamp_list=ytsp,
|
|
||||||
pred_len=pred_len,
|
|
||||||
)
|
|
||||||
@ -1,80 +0,0 @@
|
|||||||
import pandas as pd
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import sys
|
|
||||||
sys.path.append("../")
|
|
||||||
from model import Kronos, KronosTokenizer, KronosPredictor
|
|
||||||
|
|
||||||
|
|
||||||
def plot_prediction(kline_df, pred_df):
|
|
||||||
pred_df.index = kline_df.index[-pred_df.shape[0]:]
|
|
||||||
sr_close = kline_df['close']
|
|
||||||
sr_pred_close = pred_df['close']
|
|
||||||
sr_close.name = 'Ground Truth'
|
|
||||||
sr_pred_close.name = "Prediction"
|
|
||||||
|
|
||||||
sr_volume = kline_df['volume']
|
|
||||||
sr_pred_volume = pred_df['volume']
|
|
||||||
sr_volume.name = 'Ground Truth'
|
|
||||||
sr_pred_volume.name = "Prediction"
|
|
||||||
|
|
||||||
close_df = pd.concat([sr_close, sr_pred_close], axis=1)
|
|
||||||
volume_df = pd.concat([sr_volume, sr_pred_volume], axis=1)
|
|
||||||
|
|
||||||
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 6), sharex=True)
|
|
||||||
|
|
||||||
ax1.plot(close_df['Ground Truth'], label='Ground Truth', color='blue', linewidth=1.5)
|
|
||||||
ax1.plot(close_df['Prediction'], label='Prediction', color='red', linewidth=1.5)
|
|
||||||
ax1.set_ylabel('Close Price', fontsize=14)
|
|
||||||
ax1.legend(loc='lower left', fontsize=12)
|
|
||||||
ax1.grid(True)
|
|
||||||
|
|
||||||
ax2.plot(volume_df['Ground Truth'], label='Ground Truth', color='blue', linewidth=1.5)
|
|
||||||
ax2.plot(volume_df['Prediction'], label='Prediction', color='red', linewidth=1.5)
|
|
||||||
ax2.set_ylabel('Volume', fontsize=14)
|
|
||||||
ax2.legend(loc='upper left', fontsize=12)
|
|
||||||
ax2.grid(True)
|
|
||||||
|
|
||||||
plt.tight_layout()
|
|
||||||
plt.show()
|
|
||||||
|
|
||||||
|
|
||||||
# 1. Load Model and Tokenizer
|
|
||||||
tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
|
|
||||||
model = Kronos.from_pretrained("NeoQuasar/Kronos-small")
|
|
||||||
|
|
||||||
# 2. Instantiate Predictor
|
|
||||||
predictor = KronosPredictor(model, tokenizer, device="cuda:0", max_context=512)
|
|
||||||
|
|
||||||
# 3. Prepare Data
|
|
||||||
df = pd.read_csv("./data/XSHG_5min_600977.csv")
|
|
||||||
df['timestamps'] = pd.to_datetime(df['timestamps'])
|
|
||||||
|
|
||||||
lookback = 400
|
|
||||||
pred_len = 120
|
|
||||||
|
|
||||||
x_df = df.loc[:lookback-1, ['open', 'high', 'low', 'close', 'volume', 'amount']]
|
|
||||||
x_timestamp = df.loc[:lookback-1, 'timestamps']
|
|
||||||
y_timestamp = df.loc[lookback:lookback+pred_len-1, 'timestamps']
|
|
||||||
|
|
||||||
# 4. Make Prediction
|
|
||||||
pred_df = predictor.predict(
|
|
||||||
df=x_df,
|
|
||||||
x_timestamp=x_timestamp,
|
|
||||||
y_timestamp=y_timestamp,
|
|
||||||
pred_len=pred_len,
|
|
||||||
T=1.0,
|
|
||||||
top_p=0.9,
|
|
||||||
sample_count=1,
|
|
||||||
verbose=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# 5. Visualize Results
|
|
||||||
print("Forecasted Data Head:")
|
|
||||||
print(pred_df.head())
|
|
||||||
|
|
||||||
# Combine historical and forecasted data for plotting
|
|
||||||
kline_df = df.loc[:lookback+pred_len-1]
|
|
||||||
|
|
||||||
# visualize
|
|
||||||
plot_prediction(kline_df, pred_df)
|
|
||||||
|
|
||||||
@ -1,68 +0,0 @@
|
|||||||
import pandas as pd
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import sys
|
|
||||||
sys.path.append("../")
|
|
||||||
from model import Kronos, KronosTokenizer, KronosPredictor
|
|
||||||
|
|
||||||
|
|
||||||
def plot_prediction(kline_df, pred_df):
|
|
||||||
pred_df.index = kline_df.index[-pred_df.shape[0]:]
|
|
||||||
sr_close = kline_df['close']
|
|
||||||
sr_pred_close = pred_df['close']
|
|
||||||
sr_close.name = 'Ground Truth'
|
|
||||||
sr_pred_close.name = "Prediction"
|
|
||||||
|
|
||||||
close_df = pd.concat([sr_close, sr_pred_close], axis=1)
|
|
||||||
|
|
||||||
fig, ax = plt.subplots(1, 1, figsize=(8, 4))
|
|
||||||
|
|
||||||
ax.plot(close_df['Ground Truth'], label='Ground Truth', color='blue', linewidth=1.5)
|
|
||||||
ax.plot(close_df['Prediction'], label='Prediction', color='red', linewidth=1.5)
|
|
||||||
ax.set_ylabel('Close Price', fontsize=14)
|
|
||||||
ax.legend(loc='lower left', fontsize=12)
|
|
||||||
ax.grid(True)
|
|
||||||
|
|
||||||
plt.tight_layout()
|
|
||||||
plt.show()
|
|
||||||
|
|
||||||
|
|
||||||
# 1. Load Model and Tokenizer
|
|
||||||
tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
|
|
||||||
model = Kronos.from_pretrained("NeoQuasar/Kronos-small")
|
|
||||||
|
|
||||||
# 2. Instantiate Predictor
|
|
||||||
predictor = KronosPredictor(model, tokenizer, device="cuda:0", max_context=512)
|
|
||||||
|
|
||||||
# 3. Prepare Data
|
|
||||||
df = pd.read_csv("./data/XSHG_5min_600977.csv")
|
|
||||||
df['timestamps'] = pd.to_datetime(df['timestamps'])
|
|
||||||
|
|
||||||
lookback = 400
|
|
||||||
pred_len = 120
|
|
||||||
|
|
||||||
x_df = df.loc[:lookback-1, ['open', 'high', 'low', 'close']]
|
|
||||||
x_timestamp = df.loc[:lookback-1, 'timestamps']
|
|
||||||
y_timestamp = df.loc[lookback:lookback+pred_len-1, 'timestamps']
|
|
||||||
|
|
||||||
# 4. Make Prediction
|
|
||||||
pred_df = predictor.predict(
|
|
||||||
df=x_df,
|
|
||||||
x_timestamp=x_timestamp,
|
|
||||||
y_timestamp=y_timestamp,
|
|
||||||
pred_len=pred_len,
|
|
||||||
T=1.0,
|
|
||||||
top_p=0.9,
|
|
||||||
sample_count=1,
|
|
||||||
verbose=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# 5. Visualize Results
|
|
||||||
print("Forecasted Data Head:")
|
|
||||||
print(pred_df.head())
|
|
||||||
|
|
||||||
# Combine historical and forecasted data for plotting
|
|
||||||
kline_df = df.loc[:lookback+pred_len-1]
|
|
||||||
|
|
||||||
# visualize
|
|
||||||
plot_prediction(kline_df, pred_df)
|
|
||||||
|
|
||||||
|
Before Width: | Height: | Size: 488 KiB |
BIN
figures/logo.png
|
Before Width: | Height: | Size: 851 KiB |
|
Before Width: | Height: | Size: 1.2 MiB |
|
Before Width: | Height: | Size: 189 KiB |
@ -1,131 +0,0 @@
|
|||||||
import os
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
"""
|
|
||||||
Configuration class for the entire project.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
# =================================================================
|
|
||||||
# Data & Feature Parameters
|
|
||||||
# =================================================================
|
|
||||||
# TODO: Update this path to your Qlib data directory.
|
|
||||||
self.qlib_data_path = "~/.qlib/qlib_data/cn_data"
|
|
||||||
self.instrument = 'csi300'
|
|
||||||
|
|
||||||
# Overall time range for data loading from Qlib.
|
|
||||||
self.dataset_begin_time = "2011-01-01"
|
|
||||||
self.dataset_end_time = '2025-06-05'
|
|
||||||
|
|
||||||
# Sliding window parameters for creating samples.
|
|
||||||
self.lookback_window = 90 # Number of past time steps for input.
|
|
||||||
self.predict_window = 10 # Number of future time steps for prediction.
|
|
||||||
self.max_context = 512 # Maximum context length for the model.
|
|
||||||
|
|
||||||
# Features to be used from the raw data.
|
|
||||||
self.feature_list = ['open', 'high', 'low', 'close', 'vol', 'amt']
|
|
||||||
# Time-based features to be generated.
|
|
||||||
self.time_feature_list = ['minute', 'hour', 'weekday', 'day', 'month']
|
|
||||||
|
|
||||||
# =================================================================
|
|
||||||
# Dataset Splitting & Paths
|
|
||||||
# =================================================================
|
|
||||||
# Note: The validation/test set starts earlier than the training/validation set ends
|
|
||||||
# to account for the `lookback_window`.
|
|
||||||
self.train_time_range = ["2011-01-01", "2022-12-31"]
|
|
||||||
self.val_time_range = ["2022-09-01", "2024-06-30"]
|
|
||||||
self.test_time_range = ["2024-04-01", "2025-06-05"]
|
|
||||||
self.backtest_time_range = ["2024-07-01", "2025-06-05"]
|
|
||||||
|
|
||||||
# TODO: Directory to save the processed, pickled datasets.
|
|
||||||
self.dataset_path = "./data/processed_datasets"
|
|
||||||
|
|
||||||
# =================================================================
|
|
||||||
# Training Hyperparameters
|
|
||||||
# =================================================================
|
|
||||||
self.clip = 5.0 # Clipping value for normalized data to prevent outliers.
|
|
||||||
|
|
||||||
self.epochs = 30
|
|
||||||
self.log_interval = 100 # Log training status every N batches.
|
|
||||||
self.batch_size = 50 # Batch size per GPU.
|
|
||||||
|
|
||||||
# Number of samples to draw for one "epoch" of training/validation.
|
|
||||||
# This is useful for large datasets where a true epoch is too long.
|
|
||||||
self.n_train_iter = 2000 * self.batch_size
|
|
||||||
self.n_val_iter = 400 * self.batch_size
|
|
||||||
|
|
||||||
# Learning rates for different model components.
|
|
||||||
self.tokenizer_learning_rate = 2e-4
|
|
||||||
self.predictor_learning_rate = 4e-5
|
|
||||||
|
|
||||||
# Gradient accumulation to simulate a larger batch size.
|
|
||||||
self.accumulation_steps = 1
|
|
||||||
|
|
||||||
# AdamW optimizer parameters.
|
|
||||||
self.adam_beta1 = 0.9
|
|
||||||
self.adam_beta2 = 0.95
|
|
||||||
self.adam_weight_decay = 0.1
|
|
||||||
|
|
||||||
# Miscellaneous
|
|
||||||
self.seed = 100 # Global random seed for reproducibility.
|
|
||||||
|
|
||||||
# =================================================================
|
|
||||||
# Experiment Logging & Saving
|
|
||||||
# =================================================================
|
|
||||||
self.use_comet = True # Set to False if you don't want to use Comet ML
|
|
||||||
self.comet_config = {
|
|
||||||
# It is highly recommended to load secrets from environment variables
|
|
||||||
# for security purposes. Example: os.getenv("COMET_API_KEY")
|
|
||||||
"api_key": "YOUR_COMET_API_KEY",
|
|
||||||
"project_name": "Kronos-Finetune-Demo",
|
|
||||||
"workspace": "your_comet_workspace" # TODO: Change to your Comet ML workspace name
|
|
||||||
}
|
|
||||||
self.comet_tag = 'finetune_demo'
|
|
||||||
self.comet_name = 'finetune_demo'
|
|
||||||
|
|
||||||
# Base directory for saving model checkpoints and results.
|
|
||||||
# Using a general 'outputs' directory is a common practice.
|
|
||||||
self.save_path = "./outputs/models"
|
|
||||||
self.tokenizer_save_folder_name = 'finetune_tokenizer_demo'
|
|
||||||
self.predictor_save_folder_name = 'finetune_predictor_demo'
|
|
||||||
self.backtest_save_folder_name = 'finetune_backtest_demo'
|
|
||||||
|
|
||||||
# Path for backtesting results.
|
|
||||||
self.backtest_result_path = "./outputs/backtest_results"
|
|
||||||
|
|
||||||
# =================================================================
|
|
||||||
# Model & Checkpoint Paths
|
|
||||||
# =================================================================
|
|
||||||
# TODO: Update these paths to your pretrained model locations.
|
|
||||||
# These can be local paths or Hugging Face Hub model identifiers.
|
|
||||||
self.pretrained_tokenizer_path = "path/to/your/Kronos-Tokenizer-base"
|
|
||||||
self.pretrained_predictor_path = "path/to/your/Kronos-small"
|
|
||||||
|
|
||||||
# Paths to the fine-tuned models, derived from the save_path.
|
|
||||||
# These will be generated automatically during training.
|
|
||||||
self.finetuned_tokenizer_path = f"{self.save_path}/{self.tokenizer_save_folder_name}/checkpoints/best_model"
|
|
||||||
self.finetuned_predictor_path = f"{self.save_path}/{self.predictor_save_folder_name}/checkpoints/best_model"
|
|
||||||
|
|
||||||
# =================================================================
|
|
||||||
# Backtesting Parameters
|
|
||||||
# =================================================================
|
|
||||||
self.backtest_n_symbol_hold = 50 # Number of symbols to hold in the portfolio.
|
|
||||||
self.backtest_n_symbol_drop = 5 # Number of symbols to drop from the pool.
|
|
||||||
self.backtest_hold_thresh = 5 # Minimum holding period for a stock.
|
|
||||||
self.inference_T = 0.6
|
|
||||||
self.inference_top_p = 0.9
|
|
||||||
self.inference_top_k = 0
|
|
||||||
self.inference_sample_count = 5
|
|
||||||
self.backtest_batch_size = 1000
|
|
||||||
self.backtest_benchmark = self._set_benchmark(self.instrument)
|
|
||||||
|
|
||||||
def _set_benchmark(self, instrument):
|
|
||||||
dt_benchmark = {
|
|
||||||
'csi800': "SH000906",
|
|
||||||
'csi1000': "SH000852",
|
|
||||||
'csi300': "SH000300",
|
|
||||||
}
|
|
||||||
if instrument in dt_benchmark:
|
|
||||||
return dt_benchmark[instrument]
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Benchmark not defined for instrument: {instrument}")
|
|
||||||
@ -1,145 +0,0 @@
|
|||||||
import pickle
|
|
||||||
import random
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from torch.utils.data import Dataset
|
|
||||||
from config import Config
|
|
||||||
|
|
||||||
|
|
||||||
class QlibDataset(Dataset):
|
|
||||||
"""
|
|
||||||
A PyTorch Dataset for handling Qlib financial time series data.
|
|
||||||
|
|
||||||
This dataset pre-computes all possible start indices for sliding windows
|
|
||||||
and then randomly samples from them during training/validation.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
data_type (str): The type of dataset to load, either 'train' or 'val'.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If `data_type` is not 'train' or 'val'.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, data_type: str = 'train'):
|
|
||||||
self.config = Config()
|
|
||||||
if data_type not in ['train', 'val']:
|
|
||||||
raise ValueError("data_type must be 'train' or 'val'")
|
|
||||||
self.data_type = data_type
|
|
||||||
|
|
||||||
# Use a dedicated random number generator for sampling to avoid
|
|
||||||
# interfering with other random processes (e.g., in model initialization).
|
|
||||||
self.py_rng = random.Random(self.config.seed)
|
|
||||||
|
|
||||||
# Set paths and number of samples based on the data type.
|
|
||||||
if data_type == 'train':
|
|
||||||
self.data_path = f"{self.config.dataset_path}/train_data.pkl"
|
|
||||||
self.n_samples = self.config.n_train_iter
|
|
||||||
else:
|
|
||||||
self.data_path = f"{self.config.dataset_path}/val_data.pkl"
|
|
||||||
self.n_samples = self.config.n_val_iter
|
|
||||||
|
|
||||||
with open(self.data_path, 'rb') as f:
|
|
||||||
self.data = pickle.load(f)
|
|
||||||
|
|
||||||
self.window = self.config.lookback_window + self.config.predict_window + 1
|
|
||||||
|
|
||||||
self.symbols = list(self.data.keys())
|
|
||||||
self.feature_list = self.config.feature_list
|
|
||||||
self.time_feature_list = self.config.time_feature_list
|
|
||||||
|
|
||||||
# Pre-compute all possible (symbol, start_index) pairs.
|
|
||||||
self.indices = []
|
|
||||||
print(f"[{data_type.upper()}] Pre-computing sample indices...")
|
|
||||||
for symbol in self.symbols:
|
|
||||||
df = self.data[symbol].reset_index()
|
|
||||||
series_len = len(df)
|
|
||||||
num_samples = series_len - self.window + 1
|
|
||||||
|
|
||||||
if num_samples > 0:
|
|
||||||
# Generate time features and store them directly in the dataframe.
|
|
||||||
df['minute'] = df['datetime'].dt.minute
|
|
||||||
df['hour'] = df['datetime'].dt.hour
|
|
||||||
df['weekday'] = df['datetime'].dt.weekday
|
|
||||||
df['day'] = df['datetime'].dt.day
|
|
||||||
df['month'] = df['datetime'].dt.month
|
|
||||||
# Keep only necessary columns to save memory.
|
|
||||||
self.data[symbol] = df[self.feature_list + self.time_feature_list]
|
|
||||||
|
|
||||||
# Add all valid starting indices for this symbol to the global list.
|
|
||||||
for i in range(num_samples):
|
|
||||||
self.indices.append((symbol, i))
|
|
||||||
|
|
||||||
# The effective dataset size is the minimum of the configured iterations
|
|
||||||
# and the total number of available samples.
|
|
||||||
self.n_samples = min(self.n_samples, len(self.indices))
|
|
||||||
print(f"[{data_type.upper()}] Found {len(self.indices)} possible samples. Using {self.n_samples} per epoch.")
|
|
||||||
|
|
||||||
def set_epoch_seed(self, epoch: int):
|
|
||||||
"""
|
|
||||||
Sets a new seed for the random sampler for each epoch. This is crucial
|
|
||||||
for reproducibility in distributed training.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
epoch (int): The current epoch number.
|
|
||||||
"""
|
|
||||||
epoch_seed = self.config.seed + epoch
|
|
||||||
self.py_rng.seed(epoch_seed)
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
|
||||||
"""Returns the number of samples per epoch."""
|
|
||||||
return self.n_samples
|
|
||||||
|
|
||||||
def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
"""
|
|
||||||
Retrieves a random sample from the dataset.
|
|
||||||
|
|
||||||
Note: The `idx` argument is ignored. Instead, a random index is drawn
|
|
||||||
from the pre-computed `self.indices` list using `self.py_rng`. This
|
|
||||||
ensures random sampling over the entire dataset for each call.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
idx (int): Ignored.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple[torch.Tensor, torch.Tensor]: A tuple containing:
|
|
||||||
- x_tensor (torch.Tensor): The normalized feature tensor.
|
|
||||||
- x_stamp_tensor (torch.Tensor): The time feature tensor.
|
|
||||||
"""
|
|
||||||
# Select a random sample from the entire pool of indices.
|
|
||||||
random_idx = self.py_rng.randint(0, len(self.indices) - 1)
|
|
||||||
symbol, start_idx = self.indices[random_idx]
|
|
||||||
|
|
||||||
# Extract the sliding window from the dataframe.
|
|
||||||
df = self.data[symbol]
|
|
||||||
end_idx = start_idx + self.window
|
|
||||||
win_df = df.iloc[start_idx:end_idx]
|
|
||||||
|
|
||||||
# Separate main features and time features.
|
|
||||||
x = win_df[self.feature_list].values.astype(np.float32)
|
|
||||||
x_stamp = win_df[self.time_feature_list].values.astype(np.float32)
|
|
||||||
|
|
||||||
# Perform instance-level normalization.
|
|
||||||
x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0)
|
|
||||||
x = (x - x_mean) / (x_std + 1e-5)
|
|
||||||
x = np.clip(x, -self.config.clip, self.config.clip)
|
|
||||||
|
|
||||||
# Convert to PyTorch tensors.
|
|
||||||
x_tensor = torch.from_numpy(x)
|
|
||||||
x_stamp_tensor = torch.from_numpy(x_stamp)
|
|
||||||
|
|
||||||
return x_tensor, x_stamp_tensor
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
# Example usage and verification.
|
|
||||||
print("Creating training dataset instance...")
|
|
||||||
train_dataset = QlibDataset(data_type='train')
|
|
||||||
|
|
||||||
print(f"Dataset length: {len(train_dataset)}")
|
|
||||||
|
|
||||||
if len(train_dataset) > 0:
|
|
||||||
try_x, try_x_stamp = train_dataset[100] # Index 100 is ignored.
|
|
||||||
print(f"Sample feature shape: {try_x.shape}")
|
|
||||||
print(f"Sample time feature shape: {try_x_stamp.shape}")
|
|
||||||
else:
|
|
||||||
print("Dataset is empty.")
|
|
||||||
@ -1,130 +0,0 @@
|
|||||||
import os
|
|
||||||
import pickle
|
|
||||||
import numpy as np
|
|
||||||
import pandas as pd
|
|
||||||
import qlib
|
|
||||||
from qlib.config import REG_CN
|
|
||||||
from qlib.data import D
|
|
||||||
from qlib.data.dataset.loader import QlibDataLoader
|
|
||||||
from tqdm import trange
|
|
||||||
|
|
||||||
from config import Config
|
|
||||||
|
|
||||||
|
|
||||||
class QlibDataPreprocessor:
|
|
||||||
"""
|
|
||||||
A class to handle the loading, processing, and splitting of Qlib financial data.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
"""Initializes the preprocessor with configuration and data fields."""
|
|
||||||
self.config = Config()
|
|
||||||
self.data_fields = ['open', 'close', 'high', 'low', 'volume', 'vwap']
|
|
||||||
self.data = {} # A dictionary to store processed data for each symbol.
|
|
||||||
|
|
||||||
def initialize_qlib(self):
|
|
||||||
"""Initializes the Qlib environment."""
|
|
||||||
print("Initializing Qlib...")
|
|
||||||
qlib.init(provider_uri=self.config.qlib_data_path, region=REG_CN)
|
|
||||||
|
|
||||||
def load_qlib_data(self):
|
|
||||||
"""
|
|
||||||
Loads raw data from Qlib, processes it symbol by symbol, and stores
|
|
||||||
it in the `self.data` attribute.
|
|
||||||
"""
|
|
||||||
print("Loading and processing data from Qlib...")
|
|
||||||
data_fields_qlib = ['$' + f for f in self.data_fields]
|
|
||||||
cal: np.ndarray = D.calendar()
|
|
||||||
|
|
||||||
# Determine the actual start and end times to load, including buffer for lookback and predict windows.
|
|
||||||
start_index = cal.searchsorted(pd.Timestamp(self.config.dataset_begin_time))
|
|
||||||
end_index = cal.searchsorted(pd.Timestamp(self.config.dataset_end_time))
|
|
||||||
|
|
||||||
# Check if start_index lookbackw_window will cause negative index
|
|
||||||
adjusted_start_index = max(start_index - self.config.lookback_window, 0)
|
|
||||||
real_start_time = cal[adjusted_start_index]
|
|
||||||
|
|
||||||
# Check if end_index exceeds the range of the array
|
|
||||||
if end_index >= len(cal):
|
|
||||||
end_index = len(cal) - 1
|
|
||||||
elif cal[end_index] != pd.Timestamp(self.config.dataset_end_time):
|
|
||||||
end_index -= 1
|
|
||||||
|
|
||||||
# Check if end_index+predictw_window will exceed the range of the array
|
|
||||||
adjusted_end_index = min(end_index + self.config.predict_window, len(cal) - 1)
|
|
||||||
real_end_time = cal[adjusted_end_index]
|
|
||||||
|
|
||||||
# Load data using Qlib's data loader.
|
|
||||||
data_df = QlibDataLoader(config=data_fields_qlib).load(
|
|
||||||
self.config.instrument, real_start_time, real_end_time
|
|
||||||
)
|
|
||||||
data_df = data_df.stack().unstack(level=1) # Reshape for easier access.
|
|
||||||
|
|
||||||
symbol_list = list(data_df.columns)
|
|
||||||
for i in trange(len(symbol_list), desc="Processing Symbols"):
|
|
||||||
symbol = symbol_list[i]
|
|
||||||
symbol_df = data_df[symbol]
|
|
||||||
|
|
||||||
# Pivot the table to have features as columns and datetime as index.
|
|
||||||
symbol_df = symbol_df.reset_index().rename(columns={'level_1': 'field'})
|
|
||||||
symbol_df = pd.pivot(symbol_df, index='datetime', columns='field', values=symbol)
|
|
||||||
symbol_df = symbol_df.rename(columns={f'${field}': field for field in self.data_fields})
|
|
||||||
|
|
||||||
# Calculate amount and select final features.
|
|
||||||
symbol_df['vol'] = symbol_df['volume']
|
|
||||||
symbol_df['amt'] = (symbol_df['open'] + symbol_df['high'] + symbol_df['low'] + symbol_df['close']) / 4 * symbol_df['vol']
|
|
||||||
symbol_df = symbol_df[self.config.feature_list]
|
|
||||||
|
|
||||||
# Filter out symbols with insufficient data.
|
|
||||||
symbol_df = symbol_df.dropna()
|
|
||||||
if len(symbol_df) < self.config.lookback_window + self.config.predict_window + 1:
|
|
||||||
continue
|
|
||||||
|
|
||||||
self.data[symbol] = symbol_df
|
|
||||||
|
|
||||||
def prepare_dataset(self):
|
|
||||||
"""
|
|
||||||
Splits the loaded data into train, validation, and test sets and saves them to disk.
|
|
||||||
"""
|
|
||||||
print("Splitting data into train, validation, and test sets...")
|
|
||||||
train_data, val_data, test_data = {}, {}, {}
|
|
||||||
|
|
||||||
symbol_list = list(self.data.keys())
|
|
||||||
for i in trange(len(symbol_list), desc="Preparing Datasets"):
|
|
||||||
symbol = symbol_list[i]
|
|
||||||
symbol_df = self.data[symbol]
|
|
||||||
|
|
||||||
# Define time ranges from config.
|
|
||||||
train_start, train_end = self.config.train_time_range
|
|
||||||
val_start, val_end = self.config.val_time_range
|
|
||||||
test_start, test_end = self.config.test_time_range
|
|
||||||
|
|
||||||
# Create boolean masks for each dataset split.
|
|
||||||
train_mask = (symbol_df.index >= train_start) & (symbol_df.index <= train_end)
|
|
||||||
val_mask = (symbol_df.index >= val_start) & (symbol_df.index <= val_end)
|
|
||||||
test_mask = (symbol_df.index >= test_start) & (symbol_df.index <= test_end)
|
|
||||||
|
|
||||||
# Apply masks to create the final datasets.
|
|
||||||
train_data[symbol] = symbol_df[train_mask]
|
|
||||||
val_data[symbol] = symbol_df[val_mask]
|
|
||||||
test_data[symbol] = symbol_df[test_mask]
|
|
||||||
|
|
||||||
# Save the datasets using pickle.
|
|
||||||
os.makedirs(self.config.dataset_path, exist_ok=True)
|
|
||||||
with open(f"{self.config.dataset_path}/train_data.pkl", 'wb') as f:
|
|
||||||
pickle.dump(train_data, f)
|
|
||||||
with open(f"{self.config.dataset_path}/val_data.pkl", 'wb') as f:
|
|
||||||
pickle.dump(val_data, f)
|
|
||||||
with open(f"{self.config.dataset_path}/test_data.pkl", 'wb') as f:
|
|
||||||
pickle.dump(test_data, f)
|
|
||||||
|
|
||||||
print("Datasets prepared and saved successfully.")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
# This block allows the script to be run directly to perform data preprocessing.
|
|
||||||
preprocessor = QlibDataPreprocessor()
|
|
||||||
preprocessor.initialize_qlib()
|
|
||||||
preprocessor.load_qlib_data()
|
|
||||||
preprocessor.prepare_dataset()
|
|
||||||
|
|
||||||
@ -1,362 +0,0 @@
|
|||||||
import os
|
|
||||||
import sys
|
|
||||||
import argparse
|
|
||||||
import pickle
|
|
||||||
from collections import defaultdict
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pandas as pd
|
|
||||||
import torch
|
|
||||||
from torch.utils.data import Dataset, DataLoader
|
|
||||||
from tqdm import trange, tqdm
|
|
||||||
from matplotlib import pyplot as plt
|
|
||||||
|
|
||||||
import qlib
|
|
||||||
from qlib.config import REG_CN
|
|
||||||
from qlib.backtest import backtest, executor, CommonInfrastructure
|
|
||||||
from qlib.contrib.evaluate import risk_analysis
|
|
||||||
from qlib.contrib.strategy import TopkDropoutStrategy
|
|
||||||
from qlib.utils import flatten_dict
|
|
||||||
from qlib.utils.time import Freq
|
|
||||||
|
|
||||||
# Ensure project root is in the Python path
|
|
||||||
sys.path.append("../")
|
|
||||||
from config import Config
|
|
||||||
from model.kronos import Kronos, KronosTokenizer, auto_regressive_inference
|
|
||||||
|
|
||||||
|
|
||||||
# =================================================================================
|
|
||||||
# 1. Data Loading and Processing for Inference
|
|
||||||
# =================================================================================
|
|
||||||
|
|
||||||
class QlibTestDataset(Dataset):
|
|
||||||
"""
|
|
||||||
PyTorch Dataset for handling Qlib test data, specifically for inference.
|
|
||||||
|
|
||||||
This dataset iterates through all possible sliding windows sequentially. It also
|
|
||||||
yields metadata like symbol and timestamp, which are crucial for mapping
|
|
||||||
predictions back to the original time series.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, data: dict, config: Config):
|
|
||||||
self.data = data
|
|
||||||
self.config = config
|
|
||||||
self.window_size = config.lookback_window + config.predict_window
|
|
||||||
self.symbols = list(self.data.keys())
|
|
||||||
self.feature_list = config.feature_list
|
|
||||||
self.time_feature_list = config.time_feature_list
|
|
||||||
self.indices = []
|
|
||||||
|
|
||||||
print("Preprocessing and building indices for test dataset...")
|
|
||||||
for symbol in self.symbols:
|
|
||||||
df = self.data[symbol].reset_index()
|
|
||||||
# Generate time features on-the-fly
|
|
||||||
df['minute'] = df['datetime'].dt.minute
|
|
||||||
df['hour'] = df['datetime'].dt.hour
|
|
||||||
df['weekday'] = df['datetime'].dt.weekday
|
|
||||||
df['day'] = df['datetime'].dt.day
|
|
||||||
df['month'] = df['datetime'].dt.month
|
|
||||||
self.data[symbol] = df # Store preprocessed dataframe
|
|
||||||
|
|
||||||
num_samples = len(df) - self.window_size + 1
|
|
||||||
if num_samples > 0:
|
|
||||||
for i in range(num_samples):
|
|
||||||
timestamp = df.iloc[i + self.config.lookback_window - 1]['datetime']
|
|
||||||
self.indices.append((symbol, i, timestamp))
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
|
||||||
return len(self.indices)
|
|
||||||
|
|
||||||
def __getitem__(self, idx: int):
|
|
||||||
symbol, start_idx, timestamp = self.indices[idx]
|
|
||||||
df = self.data[symbol]
|
|
||||||
|
|
||||||
context_end = start_idx + self.config.lookback_window
|
|
||||||
predict_end = context_end + self.config.predict_window
|
|
||||||
|
|
||||||
context_df = df.iloc[start_idx:context_end]
|
|
||||||
predict_df = df.iloc[context_end:predict_end]
|
|
||||||
|
|
||||||
x = context_df[self.feature_list].values.astype(np.float32)
|
|
||||||
x_stamp = context_df[self.time_feature_list].values.astype(np.float32)
|
|
||||||
y_stamp = predict_df[self.time_feature_list].values.astype(np.float32)
|
|
||||||
|
|
||||||
# Instance-level normalization, consistent with training
|
|
||||||
x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0)
|
|
||||||
x = (x - x_mean) / (x_std + 1e-5)
|
|
||||||
x = np.clip(x, -self.config.clip, self.config.clip)
|
|
||||||
|
|
||||||
return torch.from_numpy(x), torch.from_numpy(x_stamp), torch.from_numpy(y_stamp), symbol, timestamp
|
|
||||||
|
|
||||||
|
|
||||||
# =================================================================================
|
|
||||||
# 2. Backtesting Logic
|
|
||||||
# =================================================================================
|
|
||||||
|
|
||||||
class QlibBacktest:
    """
    A wrapper class for conducting backtesting experiments using Qlib.

    Wraps a Top-k dropout strategy around externally supplied prediction
    signals, runs Qlib's daily simulator, and produces/plots cumulative
    return reports.
    """

    def __init__(self, config: Config):
        # Keep the full project config; backtest-specific fields are read lazily.
        self.config = config
        self.initialize_qlib()

    def initialize_qlib(self):
        """Initializes the Qlib environment."""
        print("Initializing Qlib for backtesting...")
        # NOTE(review): assumes a CN-region Qlib data bundle at qlib_data_path — confirm.
        qlib.init(provider_uri=self.config.qlib_data_path, region=REG_CN)

    def run_single_backtest(self, signal_series: pd.Series) -> pd.DataFrame:
        """
        Runs a single backtest for a given prediction signal.

        Args:
            signal_series (pd.Series): A pandas Series with a MultiIndex
                (instrument, datetime) and prediction scores.
        Returns:
            pd.DataFrame: A DataFrame containing the performance report
                with cumulative benchmark, return-with-cost, and
                excess-return-with-cost columns.
        """
        # Top-k dropout: hold `topk` names, rotate out `n_drop` per rebalance.
        strategy = TopkDropoutStrategy(
            topk=self.config.backtest_n_symbol_hold,
            n_drop=self.config.backtest_n_symbol_drop,
            hold_thresh=self.config.backtest_hold_thresh,
            signal=signal_series,
        )
        # delay_execution=True: orders generated today execute on the next bar.
        executor_config = {
            "time_per_step": "day",
            "generate_portfolio_metrics": True,
            "delay_execution": True,
        }
        backtest_config = {
            "start_time": self.config.backtest_time_range[0],
            "end_time": self.config.backtest_time_range[1],
            "account": 100_000_000,
            "benchmark": self.config.backtest_benchmark,
            # Trading frictions: 9.5% limit move, deal at open, asymmetric costs.
            "exchange_kwargs": {
                "freq": "day", "limit_threshold": 0.095, "deal_price": "open",
                "open_cost": 0.001, "close_cost": 0.0015, "min_cost": 5,
            },
            "executor": executor.SimulatorExecutor(**executor_config),
        }

        portfolio_metric_dict, _ = backtest(strategy=strategy, **backtest_config)
        # Qlib keys portfolio metrics by a normalized frequency string, e.g. "1day".
        analysis_freq = "{0}{1}".format(*Freq.parse("day"))
        report, _ = portfolio_metric_dict.get(analysis_freq)

        # --- Analysis and Reporting ---
        analysis = {
            "excess_return_without_cost": risk_analysis(report["return"] - report["bench"], freq=analysis_freq),
            "excess_return_with_cost": risk_analysis(report["return"] - report["bench"] - report["cost"], freq=analysis_freq),
        }
        print("\n--- Backtest Analysis ---")
        print("Benchmark Return:", risk_analysis(report["bench"], freq=analysis_freq), sep='\n')
        print("\nExcess Return (w/o cost):", analysis["excess_return_without_cost"], sep='\n')
        print("\nExcess Return (w/ cost):", analysis["excess_return_with_cost"], sep='\n')

        # Cumulative (simple-sum) curves for plotting downstream.
        report_df = pd.DataFrame({
            "cum_bench": report["bench"].cumsum(),
            "cum_return_w_cost": (report["return"] - report["cost"]).cumsum(),
            "cum_ex_return_w_cost": (report["return"] - report["bench"] - report["cost"]).cumsum(),
        })
        return report_df

    def run_and_plot_results(self, signals: dict[str, pd.DataFrame]):
        """
        Runs backtests for multiple signals and plots the cumulative return curves.

        Args:
            signals (dict[str, pd.DataFrame]): A dictionary where keys are signal names
                and values are prediction DataFrames (datetime index, instrument columns).
        """
        return_df, ex_return_df, bench_df = pd.DataFrame(), pd.DataFrame(), pd.DataFrame()

        for signal_name, pred_df in signals.items():
            print(f"\nBacktesting signal: {signal_name}...")
            # Wide (datetime x instrument) -> long MultiIndex series, then swap
            # to the (instrument, datetime) order Qlib strategies expect.
            pred_series = pred_df.stack()
            pred_series.index.names = ['datetime', 'instrument']
            pred_series = pred_series.swaplevel().sort_index()
            report_df = self.run_single_backtest(pred_series)

            return_df[signal_name] = report_df['cum_return_w_cost']
            ex_return_df[signal_name] = report_df['cum_ex_return_w_cost']
            # The benchmark curve is identical across signals; capture it once.
            if 'return' not in bench_df:
                bench_df['return'] = report_df['cum_bench']

        # Plotting results
        fig, axes = plt.subplots(2, 1, figsize=(12, 8), sharex=True)
        return_df.plot(ax=axes[0], title='Cumulative Return with Cost', grid=True)
        axes[0].plot(bench_df['return'], label=self.config.instrument.upper(), color='black', linestyle='--')
        axes[0].legend()
        axes[0].set_ylabel("Cumulative Return")

        ex_return_df.plot(ax=axes[1], title='Cumulative Excess Return with Cost', grid=True)
        axes[1].legend()
        axes[1].set_xlabel("Date")
        axes[1].set_ylabel("Cumulative Excess Return")

        plt.tight_layout()
        # NOTE(review): relative path assumes CWD is one level below the repo root
        # and that ../figures exists — savefig will fail otherwise.
        plt.savefig("../figures/backtest_result_example.png", dpi=200)
        plt.show()
||||||
|
|
||||||
|
|
||||||
# =================================================================================
|
|
||||||
# 3. Inference Logic
|
|
||||||
# =================================================================================
|
|
||||||
|
|
||||||
def load_models(config: dict) -> tuple[KronosTokenizer, Kronos]:
    """Load the fine-tuned tokenizer and predictor, in eval mode on the target device."""
    device = torch.device(config['device'])
    print(f"Loading models onto device: {device}...")

    tokenizer = KronosTokenizer.from_pretrained(config['tokenizer_path'])
    tokenizer = tokenizer.to(device).eval()

    model = Kronos.from_pretrained(config['model_path'])
    model = model.to(device).eval()

    return tokenizer, model
||||||
def collate_fn_for_inference(batch):
    """
    Collate samples of (x, x_stamp, y_stamp, symbol, timestamp) into one batch.

    The three tensor fields are stacked along a new leading batch dimension;
    the symbol and timestamp fields are plain Python objects, so they are
    returned as lists instead of tensors.

    Args:
        batch (list): Samples as returned by QlibTestDataset.__getitem__.

    Returns:
        A single tuple containing the batched data.
    """
    columns = list(zip(*batch))
    tensor_cols = [torch.stack(col, dim=0) for col in columns[:3]]
    return (
        tensor_cols[0],
        tensor_cols[1],
        tensor_cols[2],
        list(columns[3]),
        list(columns[4]),
    )
||||||
def generate_predictions(config: dict, test_data: dict) -> dict[str, pd.DataFrame]:
    """
    Runs inference on the test dataset to generate prediction signals.

    Args:
        config (dict): A dictionary containing inference parameters
            (device, model paths, pred_len, sampling settings, batch_size, ...).
        test_data (dict): The raw test data loaded from a pickle file.

    Returns:
        A dictionary where keys are signal types ('last', 'mean', 'max', 'min')
        and values are DataFrames of predictions (datetime index, symbol columns).
    """
    tokenizer, model = load_models(config)
    device = torch.device(config['device'])

    # Use the Dataset and DataLoader for efficient batching and processing.
    dataset = QlibTestDataset(data=test_data, config=Config())
    loader = DataLoader(
        dataset,
        # Each item is expanded `sample_count` times inside the model, so the
        # loader batch shrinks accordingly. max(1, ...) guards against a zero
        # batch size when sample_count > batch_size.
        batch_size=max(1, config['batch_size'] // config['sample_count']),
        shuffle=False,
        # os.cpu_count() may return None on some platforms; fall back safely
        # instead of raising TypeError on the integer division.
        num_workers=max(1, (os.cpu_count() or 2) // 2),
        collate_fn=collate_fn_for_inference
    )

    results = defaultdict(list)
    with torch.no_grad():
        for x, x_stamp, y_stamp, symbols, timestamps in tqdm(loader, desc="Inference"):
            preds = auto_regressive_inference(
                tokenizer, model, x.to(device), x_stamp.to(device), y_stamp.to(device),
                max_context=config['max_context'], pred_len=config['pred_len'], clip=config['clip'],
                T=config['T'], top_k=config['top_k'], top_p=config['top_p'], sample_count=config['sample_count']
            )
            # Keep only the forecast horizon; comment this line out to also
            # retain the reconstructed history portion of the output.
            preds = preds[:, -config['pred_len']:, :]

            # The 'close' price is at index 3 in `feature_list`.
            last_day_close = x[:, -1, 3].numpy()
            # Each signal is a predicted-close delta vs. the last observed close.
            signals = {
                'last': preds[:, -1, 3] - last_day_close,
                'mean': np.mean(preds[:, :, 3], axis=1) - last_day_close,
                'max': np.max(preds[:, :, 3], axis=1) - last_day_close,
                'min': np.min(preds[:, :, 3], axis=1) - last_day_close,
            }

            for i in range(len(symbols)):
                for sig_type, sig_values in signals.items():
                    results[sig_type].append((timestamps[i], symbols[i], sig_values[i]))

    print("Post-processing predictions into DataFrames...")
    prediction_dfs = {}
    for sig_type, records in results.items():
        df = pd.DataFrame(records, columns=['datetime', 'instrument', 'score'])
        # Wide format: rows are dates, columns are instruments.
        pivot_df = df.pivot_table(index='datetime', columns='instrument', values='score')
        prediction_dfs[sig_type] = pivot_df.sort_index()

    return prediction_dfs
||||||
# =================================================================================
|
|
||||||
# 4. Main Execution
|
|
||||||
# =================================================================================
|
|
||||||
|
|
||||||
def main():
    """Main function to set up config, run inference, and execute backtesting."""
    parser = argparse.ArgumentParser(description="Run Kronos Inference and Backtesting")
    parser.add_argument("--device", type=str, default="cuda:1", help="Device for inference (e.g., 'cuda:0', 'cpu')")
    args = parser.parse_args()

    # --- 1. Configuration Setup ---
    base_config = Config()

    # Create a dedicated dictionary for this run's configuration
    # (flattens the Config attributes that the inference path consumes).
    run_config = {
        'device': args.device,
        'data_path': base_config.dataset_path,
        'result_save_path': base_config.backtest_result_path,
        'result_name': base_config.backtest_save_folder_name,
        'tokenizer_path': base_config.finetuned_tokenizer_path,
        'model_path': base_config.finetuned_predictor_path,
        'max_context': base_config.max_context,
        'pred_len': base_config.predict_window,
        'clip': base_config.clip,
        'T': base_config.inference_T,
        'top_k': base_config.inference_top_k,
        'top_p': base_config.inference_top_p,
        'sample_count': base_config.inference_sample_count,
        'batch_size': base_config.backtest_batch_size,
    }

    print("--- Running with Configuration ---")
    for key, val in run_config.items():
        print(f"{key:>20}: {val}")
    print("-" * 35)

    # --- 2. Load Data ---
    test_data_path = os.path.join(run_config['data_path'], "test_data.pkl")
    print(f"Loading test data from {test_data_path}...")
    with open(test_data_path, 'rb') as f:
        test_data = pickle.load(f)
    # NOTE(review): prints the entire unpickled dataset — likely a debug
    # leftover; consider removing or summarizing.
    print(test_data)
    # --- 3. Generate Predictions ---
    model_preds = generate_predictions(run_config, test_data)

    # --- 4. Save Predictions ---
    save_dir = os.path.join(run_config['result_save_path'], run_config['result_name'])
    os.makedirs(save_dir, exist_ok=True)
    predictions_file = os.path.join(save_dir, "predictions.pkl")
    print(f"Saving prediction signals to {predictions_file}...")
    with open(predictions_file, 'wb') as f:
        pickle.dump(model_preds, f)

    # --- 5. Run Backtesting ---
    # Reload from disk rather than using the in-memory dict — this round-trip
    # verifies the saved artifact is what the backtest actually consumes.
    with open(predictions_file, 'rb') as f:
        model_preds = pickle.load(f)

    backtester = QlibBacktest(base_config)
    backtester.run_and_plot_results(model_preds)
|
|
||||||
# Script entry point: parse CLI args, run inference, then backtest.
if __name__ == '__main__':
    main()
|
|
||||||
|
|
||||||
|
|
||||||
@ -1,244 +0,0 @@
|
|||||||
import os
|
|
||||||
import sys
|
|
||||||
import json
|
|
||||||
import time
|
|
||||||
from time import gmtime, strftime
|
|
||||||
import torch.distributed as dist
|
|
||||||
import torch
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
||||||
|
|
||||||
import comet_ml
|
|
||||||
|
|
||||||
# Ensure project root is in path
|
|
||||||
sys.path.append('../')
|
|
||||||
from config import Config
|
|
||||||
from dataset import QlibDataset
|
|
||||||
from model.kronos import KronosTokenizer, Kronos
|
|
||||||
# Import shared utilities
|
|
||||||
from utils.training_utils import (
|
|
||||||
setup_ddp,
|
|
||||||
cleanup_ddp,
|
|
||||||
set_seed,
|
|
||||||
get_model_size,
|
|
||||||
format_time
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def create_dataloaders(config: dict, rank: int, world_size: int):
    """
    Creates and returns distributed dataloaders for training and validation.

    Args:
        config (dict): A dictionary of configuration parameters.
        rank (int): The global rank of the current process.
        world_size (int): The total number of processes.

    Returns:
        tuple: (train_loader, val_loader, train_dataset, valid_dataset).
    """
    print(f"[Rank {rank}] Creating distributed dataloaders...")
    train_dataset = QlibDataset('train')
    valid_dataset = QlibDataset('val')
    print(f"[Rank {rank}] Train dataset size: {len(train_dataset)}, Validation dataset size: {len(valid_dataset)}")

    # Each rank sees a disjoint shard; only the training shard is shuffled.
    shared_kwargs = dict(
        batch_size=config['batch_size'],
        num_workers=config.get('num_workers', 2),
        pin_memory=True,
    )
    train_loader = DataLoader(
        train_dataset,
        sampler=DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True),
        drop_last=True,  # keep per-step batch size uniform across ranks
        **shared_kwargs,
    )
    val_loader = DataLoader(
        valid_dataset,
        sampler=DistributedSampler(valid_dataset, num_replicas=world_size, rank=rank, shuffle=False),
        drop_last=False,  # evaluate every validation sample
        **shared_kwargs,
    )
    return train_loader, val_loader, train_dataset, valid_dataset
||||||
|
|
||||||
def train_model(model, tokenizer, device, config, save_dir, logger, rank, world_size):
    """
    The main training and validation loop for the predictor.

    Args:
        model (DDP): DDP-wrapped Kronos predictor.
        tokenizer: Frozen tokenizer used to encode raw windows into token ids.
        device (torch.device): Device for this rank.
        config (dict): Hyperparameters and paths.
        save_dir (str): Directory where the best checkpoint is written.
        logger: Optional Comet experiment (master rank only).
        rank (int): Global process rank.
        world_size (int): Total number of processes.

    Returns:
        dict: {'best_val_loss': float} for the run summary.
    """
    start_time = time.time()
    if rank == 0:
        effective_bs = config['batch_size'] * world_size
        print(f"Effective BATCHSIZE per GPU: {config['batch_size']}, Total: {effective_bs}")

    train_loader, val_loader, train_dataset, valid_dataset = create_dataloaders(config, rank, world_size)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config['predictor_learning_rate'],
        betas=(config['adam_beta1'], config['adam_beta2']),
        weight_decay=config['adam_weight_decay']
    )
    # One-cycle schedule stepped per batch (see scheduler.step() in the loop).
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=config['predictor_learning_rate'],
        steps_per_epoch=len(train_loader), epochs=config['epochs'],
        pct_start=0.03, div_factor=10
    )

    best_val_loss = float('inf')
    dt_result = {}
    batch_idx_global = 0  # monotonically increasing step index across epochs

    for epoch_idx in range(config['epochs']):
        epoch_start_time = time.time()
        model.train()
        # Re-seed the sampler so each epoch sees a different shard ordering.
        train_loader.sampler.set_epoch(epoch_idx)

        train_dataset.set_epoch_seed(epoch_idx * 10000 + rank)
        valid_dataset.set_epoch_seed(0)

        for i, (batch_x, batch_x_stamp) in enumerate(train_loader):
            # NOTE(review): squeeze(0) suggests the dataset returns pre-batched
            # tensors with a dummy leading dim — confirm against QlibDataset.
            batch_x = batch_x.squeeze(0).to(device, non_blocking=True)
            batch_x_stamp = batch_x_stamp.squeeze(0).to(device, non_blocking=True)

            # Tokenize input data on-the-fly (tokenizer is frozen; no grads).
            with torch.no_grad():
                token_seq_0, token_seq_1 = tokenizer.encode(batch_x, half=True)

            # Prepare inputs and targets for the language model
            # (standard next-token shift: input t predicts target t+1).
            token_in = [token_seq_0[:, :-1], token_seq_1[:, :-1]]
            token_out = [token_seq_0[:, 1:], token_seq_1[:, 1:]]

            # Forward pass and loss calculation
            logits = model(token_in[0], token_in[1], batch_x_stamp[:, :-1, :])
            loss, s1_loss, s2_loss = model.module.head.compute_loss(logits[0], logits[1], token_out[0], token_out[1])

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=3.0)
            optimizer.step()
            scheduler.step()

            # Logging (Master Process Only)
            if rank == 0 and (batch_idx_global + 1) % config['log_interval'] == 0:
                lr = optimizer.param_groups[0]['lr']
                print(
                    f"[Rank {rank}, Epoch {epoch_idx + 1}/{config['epochs']}, Step {i + 1}/{len(train_loader)}] "
                    f"LR {lr:.6f}, Loss: {loss.item():.4f}"
                )
            # NOTE(review): unlike the console print above, Comet metrics are
            # logged on EVERY batch (not gated by log_interval) — confirm intended.
            if rank == 0 and logger:
                lr = optimizer.param_groups[0]['lr']
                logger.log_metric('train_predictor_loss_batch', loss.item(), step=batch_idx_global)
                logger.log_metric('train_S1_loss_each_batch', s1_loss.item(), step=batch_idx_global)
                logger.log_metric('train_S2_loss_each_batch', s2_loss.item(), step=batch_idx_global)
                logger.log_metric('predictor_learning_rate', lr, step=batch_idx_global)

            batch_idx_global += 1

        # --- Validation Loop ---
        model.eval()
        tot_val_loss_sum_rank = 0.0
        val_batches_processed_rank = 0
        with torch.no_grad():
            for batch_x, batch_x_stamp in val_loader:
                batch_x = batch_x.squeeze(0).to(device, non_blocking=True)
                batch_x_stamp = batch_x_stamp.squeeze(0).to(device, non_blocking=True)

                token_seq_0, token_seq_1 = tokenizer.encode(batch_x, half=True)
                token_in = [token_seq_0[:, :-1], token_seq_1[:, :-1]]
                token_out = [token_seq_0[:, 1:], token_seq_1[:, 1:]]

                logits = model(token_in[0], token_in[1], batch_x_stamp[:, :-1, :])
                val_loss, _, _ = model.module.head.compute_loss(logits[0], logits[1], token_out[0], token_out[1])

                tot_val_loss_sum_rank += val_loss.item()
                val_batches_processed_rank += 1

        # Reduce validation metrics across ranks (sum of losses and batch counts).
        val_loss_sum_tensor = torch.tensor(tot_val_loss_sum_rank, device=device)
        val_batches_tensor = torch.tensor(val_batches_processed_rank, device=device)
        dist.all_reduce(val_loss_sum_tensor, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_batches_tensor, op=dist.ReduceOp.SUM)

        avg_val_loss = val_loss_sum_tensor.item() / val_batches_tensor.item() if val_batches_tensor.item() > 0 else 0

        # --- End of Epoch Summary & Checkpointing (Master Process Only) ---
        if rank == 0:
            print(f"\n--- Epoch {epoch_idx + 1}/{config['epochs']} Summary ---")
            print(f"Validation Loss: {avg_val_loss:.4f}")
            print(f"Time This Epoch: {format_time(time.time() - epoch_start_time)}")
            print(f"Total Time Elapsed: {format_time(time.time() - start_time)}\n")
            if logger:
                logger.log_metric('val_predictor_loss_epoch', avg_val_loss, epoch=epoch_idx)

            # Checkpoint only on improvement; save the unwrapped module.
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                save_path = f"{save_dir}/checkpoints/best_model"
                model.module.save_pretrained(save_path)
                print(f"Best model saved to {save_path} (Val Loss: {best_val_loss:.4f})")

        # Keep ranks in lockstep at epoch boundaries (checkpointing above).
        dist.barrier()

    dt_result['best_val_loss'] = best_val_loss
    return dt_result
|
||||||
|
|
||||||
def main(config: dict):
    """Main function to orchestrate the DDP training process.

    Sets up the process group, seeds, logging, loads the frozen tokenizer and
    pretrained predictor, runs training, and writes a JSON run summary on the
    master rank.
    """
    rank, world_size, local_rank = setup_ddp()
    device = torch.device(f"cuda:{local_rank}")
    set_seed(config['seed'], rank)

    save_dir = os.path.join(config['save_path'], config['predictor_save_folder_name'])

    # Logger and summary setup (master process only)
    comet_logger, master_summary = None, {}
    if rank == 0:
        os.makedirs(os.path.join(save_dir, 'checkpoints'), exist_ok=True)
        master_summary = {
            'start_time': strftime("%Y-%m-%dT%H-%M-%S", gmtime()),
            'save_directory': save_dir,
            'world_size': world_size,
        }
        if config['use_comet']:
            comet_logger = comet_ml.Experiment(
                api_key=config['comet_config']['api_key'],
                project_name=config['comet_config']['project_name'],
                workspace=config['comet_config']['workspace'],
            )
            comet_logger.add_tag(config['comet_tag'])
            comet_logger.set_name(config['comet_name'])
            comet_logger.log_parameters(config)
            print("Comet Logger Initialized.")

    # Ensure the save directory exists before non-master ranks proceed.
    dist.barrier()

    # Model Initialization
    # Tokenizer is loaded from the fine-tuned checkpoint and kept frozen (eval).
    tokenizer = KronosTokenizer.from_pretrained(config['finetuned_tokenizer_path'])
    tokenizer.eval().to(device)

    model = Kronos.from_pretrained(config['pretrained_predictor_path'])
    model.to(device)
    model = DDP(model, device_ids=[local_rank], find_unused_parameters=False)

    if rank == 0:
        print(f"Predictor Model Size: {get_model_size(model.module)}")

    # Start Training
    dt_result = train_model(
        model, tokenizer, device, config, save_dir, comet_logger, rank, world_size
    )

    if rank == 0:
        master_summary['final_result'] = dt_result
        with open(os.path.join(save_dir, 'summary.json'), 'w') as f:
            json.dump(master_summary, f, indent=4)
        print('Training finished. Summary file saved.')
        if comet_logger: comet_logger.end()

    cleanup_ddp()
||||||
|
|
||||||
if __name__ == '__main__':
    # Usage: torchrun --standalone --nproc_per_node=NUM_GPUS train_predictor.py
    # torchrun sets WORLD_SIZE (among other env vars); its absence means the
    # script was launched directly and DDP setup would fail later.
    if "WORLD_SIZE" not in os.environ:
        raise RuntimeError("This script must be launched with `torchrun`.")

    config_instance = Config()
    main(config_instance.__dict__)
|
||||||
@ -1,281 +0,0 @@
|
|||||||
import os
|
|
||||||
import sys
|
|
||||||
import json
|
|
||||||
import time
|
|
||||||
from time import gmtime, strftime
|
|
||||||
import argparse
|
|
||||||
import datetime
|
|
||||||
import torch.distributed as dist
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
||||||
|
|
||||||
import comet_ml
|
|
||||||
|
|
||||||
# Ensure project root is in path
|
|
||||||
sys.path.append("../")
|
|
||||||
from config import Config
|
|
||||||
from dataset import QlibDataset
|
|
||||||
from model.kronos import KronosTokenizer
|
|
||||||
# Import shared utilities
|
|
||||||
from utils.training_utils import (
|
|
||||||
setup_ddp,
|
|
||||||
cleanup_ddp,
|
|
||||||
set_seed,
|
|
||||||
get_model_size,
|
|
||||||
format_time,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def create_dataloaders(config: dict, rank: int, world_size: int):
    """
    Creates and returns distributed dataloaders for training and validation.

    Args:
        config (dict): A dictionary of configuration parameters.
        rank (int): The global rank of the current process.
        world_size (int): The total number of processes.

    Returns:
        tuple: A tuple containing (train_loader, val_loader, train_dataset, valid_dataset).
    """
    print(f"[Rank {rank}] Creating distributed dataloaders...")
    train_dataset = QlibDataset('train')
    valid_dataset = QlibDataset('val')
    print(f"[Rank {rank}] Train dataset size: {len(train_dataset)}, Validation dataset size: {len(valid_dataset)}")

    # Per-rank sharding; only training data is shuffled (by the sampler).
    common = dict(
        batch_size=config['batch_size'],
        shuffle=False,  # shuffling is delegated to the DistributedSampler
        num_workers=config.get('num_workers', 2),
        pin_memory=True,
    )
    train_loader = DataLoader(
        train_dataset,
        sampler=DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True),
        drop_last=True,
        **common,
    )
    val_loader = DataLoader(
        valid_dataset,
        sampler=DistributedSampler(valid_dataset, num_replicas=world_size, rank=rank, shuffle=False),
        drop_last=False,
        **common,
    )
    print(f"[Rank {rank}] Dataloaders created. Train steps/epoch: {len(train_loader)}, Val steps: {len(val_loader)}")
    return train_loader, val_loader, train_dataset, valid_dataset
|
|
||||||
def train_model(model, device, config, save_dir, logger, rank, world_size):
    """
    The main training and validation loop for the tokenizer.

    Args:
        model (DDP): The DDP-wrapped model to train.
        device (torch.device): The device for the current process.
        config (dict): Configuration dictionary.
        save_dir (str): Directory to save checkpoints.
        logger (comet_ml.Experiment): Comet logger instance.
        rank (int): Global rank of the process.
        world_size (int): Total number of processes.

    Returns:
        tuple: A tuple containing the trained model and a dictionary of results.
    """
    start_time = time.time()
    if rank == 0:
        effective_bs = config['batch_size'] * world_size * config['accumulation_steps']
        print(f"[Rank {rank}] BATCHSIZE (per GPU): {config['batch_size']}")
        print(f"[Rank {rank}] Effective total batch size: {effective_bs}")

    train_loader, val_loader, train_dataset, valid_dataset = create_dataloaders(config, rank, world_size)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config['tokenizer_learning_rate'],
        weight_decay=config['adam_weight_decay']
    )

    # One-cycle LR schedule stepped once per optimizer step (i.e. per loader batch).
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer=optimizer,
        max_lr=config['tokenizer_learning_rate'],
        steps_per_epoch=len(train_loader),
        epochs=config['epochs'],
        pct_start=0.03,
        div_factor=10
    )

    best_val_loss = float('inf')
    dt_result = {}
    batch_idx_global_train = 0  # global step counter across epochs

    for epoch_idx in range(config['epochs']):
        epoch_start_time = time.time()
        model.train()
        train_loader.sampler.set_epoch(epoch_idx)

        # Set dataset seeds for reproducible sampling
        train_dataset.set_epoch_seed(epoch_idx * 10000 + rank)
        valid_dataset.set_epoch_seed(0)  # Keep validation sampling consistent

        for i, (ori_batch_x, _) in enumerate(train_loader):
            # NOTE(review): squeeze(0) implies the dataset yields a dummy
            # leading dimension — confirm against QlibDataset.
            ori_batch_x = ori_batch_x.squeeze(0).to(device, non_blocking=True)

            # --- Gradient Accumulation Loop ---
            # The loader batch is split into `accumulation_steps` micro-batches.
            # NOTE(review): if batch size is not divisible by accumulation_steps,
            # the remainder rows are silently dropped by this slicing.
            current_batch_total_loss = 0.0
            for j in range(config['accumulation_steps']):
                start_idx = j * (ori_batch_x.shape[0] // config['accumulation_steps'])
                end_idx = (j + 1) * (ori_batch_x.shape[0] // config['accumulation_steps'])
                batch_x = ori_batch_x[start_idx:end_idx]

                # Forward pass
                zs, bsq_loss, _, _ = model(batch_x)
                z_pre, z = zs

                # Loss calculation: reconstruction from both decoder stages
                # plus the quantization (BSQ) loss.
                recon_loss_pre = F.mse_loss(z_pre, batch_x)
                recon_loss_all = F.mse_loss(z, batch_x)
                recon_loss = recon_loss_pre + recon_loss_all
                loss = (recon_loss + bsq_loss) / 2  # Assuming w_1=w_2=1

                # Scale so accumulated gradients average over micro-batches.
                loss_scaled = loss / config['accumulation_steps']
                current_batch_total_loss += loss.item()
                loss_scaled.backward()

            # --- Optimizer Step after Accumulation ---
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            # --- Logging (Master Process Only) ---
            if rank == 0 and (batch_idx_global_train + 1) % config['log_interval'] == 0:
                avg_loss = current_batch_total_loss / config['accumulation_steps']
                print(
                    f"[Rank {rank}, Epoch {epoch_idx + 1}/{config['epochs']}, Step {i + 1}/{len(train_loader)}] "
                    f"LR {optimizer.param_groups[0]['lr']:.6f}, Loss: {avg_loss:.4f}"
                )
            # NOTE(review): logged every batch (not gated by log_interval), and
            # bsq/recon values reflect only the LAST micro-batch of the step.
            if rank == 0 and logger:
                avg_loss = current_batch_total_loss / config['accumulation_steps']
                logger.log_metric('train_tokenizer_loss_batch', avg_loss, step=batch_idx_global_train)
                logger.log_metric(f'train_vqvae_vq_loss_each_batch', bsq_loss.item(), step=batch_idx_global_train)
                logger.log_metric(f'train_recon_loss_pre_each_batch', recon_loss_pre.item(), step=batch_idx_global_train)
                logger.log_metric(f'train_recon_loss_each_batch', recon_loss_all.item(), step=batch_idx_global_train)
                logger.log_metric('tokenizer_learning_rate', optimizer.param_groups[0]["lr"], step=batch_idx_global_train)

            batch_idx_global_train += 1

        # --- Validation Loop ---
        model.eval()
        tot_val_loss_sum_rank = 0.0
        val_sample_count_rank = 0
        with torch.no_grad():
            for ori_batch_x, _ in val_loader:
                ori_batch_x = ori_batch_x.squeeze(0).to(device, non_blocking=True)
                zs, _, _, _ = model(ori_batch_x)
                _, z = zs
                # Validation tracks only the final-stage reconstruction MSE.
                val_loss_item = F.mse_loss(z, ori_batch_x)

                # Weight by sample count so the cross-rank mean is exact.
                tot_val_loss_sum_rank += val_loss_item.item() * ori_batch_x.size(0)
                val_sample_count_rank += ori_batch_x.size(0)

        # Reduce validation losses from all processes
        val_loss_sum_tensor = torch.tensor(tot_val_loss_sum_rank, device=device)
        val_count_tensor = torch.tensor(val_sample_count_rank, device=device)
        dist.all_reduce(val_loss_sum_tensor, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_count_tensor, op=dist.ReduceOp.SUM)

        avg_val_loss = val_loss_sum_tensor.item() / val_count_tensor.item() if val_count_tensor.item() > 0 else 0

        # --- End of Epoch Summary & Checkpointing (Master Process Only) ---
        if rank == 0:
            print(f"\n--- Epoch {epoch_idx + 1}/{config['epochs']} Summary ---")
            print(f"Validation Loss: {avg_val_loss:.4f}")
            print(f"Time This Epoch: {format_time(time.time() - epoch_start_time)}")
            print(f"Total Time Elapsed: {format_time(time.time() - start_time)}\n")
            if logger:
                logger.log_metric('val_tokenizer_loss_epoch', avg_val_loss, epoch=epoch_idx)

            # Checkpoint on improvement; save the unwrapped module.
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                save_path = f"{save_dir}/checkpoints/best_model"
                model.module.save_pretrained(save_path)
                print(f"Best model saved to {save_path} (Val Loss: {best_val_loss:.4f})")
                if logger:
                    logger.log_model("best_model", save_path)

        dist.barrier()  # Ensure all processes finish the epoch before starting the next one.

    dt_result['best_val_loss'] = best_val_loss
    return model, dt_result
|
||||||
|
|
||||||
def main(config: dict):
    """
    Main function to orchestrate the DDP training process.

    Sets up the distributed environment, prepares logging/checkpoint
    directories (master rank only), builds the DDP-wrapped tokenizer model,
    runs training, and writes a JSON summary on rank 0.

    Args:
        config (dict): Flat configuration dictionary (see `Config`).
    """
    # Initialize the distributed environment; relies on torchrun-set env vars.
    rank, world_size, local_rank = setup_ddp()
    device = torch.device(f"cuda:{local_rank}")
    # Seed is offset by rank inside set_seed so each process gets a
    # distinct random stream (important for data loading).
    set_seed(config['seed'], rank)

    save_dir = os.path.join(config['save_path'], config['tokenizer_save_folder_name'])

    # Logger and summary setup (master process only)
    comet_logger, master_summary = None, {}
    if rank == 0:
        os.makedirs(os.path.join(save_dir, 'checkpoints'), exist_ok=True)
        master_summary = {
            'start_time': strftime("%Y-%m-%dT%H-%M-%S", gmtime()),
            'save_directory': save_dir,
            'world_size': world_size,
        }
        if config['use_comet']:
            comet_logger = comet_ml.Experiment(
                api_key=config['comet_config']['api_key'],
                project_name=config['comet_config']['project_name'],
                workspace=config['comet_config']['workspace'],
            )
            comet_logger.add_tag(config['comet_tag'])
            comet_logger.set_name(config['comet_name'])
            comet_logger.log_parameters(config)
            print("Comet Logger Initialized.")

    dist.barrier()  # Ensure save directory is created before proceeding

    # Model Initialization: load pretrained weights, move to this rank's GPU,
    # then wrap in DDP for gradient synchronization.
    model = KronosTokenizer.from_pretrained(config['pretrained_tokenizer_path'])
    model.to(device)
    model = DDP(model, device_ids=[local_rank], find_unused_parameters=False)

    if rank == 0:
        print(f"Model Size: {get_model_size(model.module)}")

    # Start Training
    _, dt_result = train_model(
        model, device, config, save_dir, comet_logger, rank, world_size
    )

    # Finalize and save summary (master process only)
    if rank == 0:
        master_summary['final_result'] = dt_result
        with open(os.path.join(save_dir, 'summary.json'), 'w') as f:
            json.dump(master_summary, f, indent=4)
        print('Training finished. Summary file saved.')
        if comet_logger:
            comet_logger.end()

    cleanup_ddp()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
    # Usage: torchrun --standalone --nproc_per_node=NUM_GPUS train_tokenizer.py
    # WORLD_SIZE is injected by torchrun; its absence means the script was
    # launched directly (python script.py), which DDP setup cannot handle.
    if "WORLD_SIZE" not in os.environ:
        raise RuntimeError("This script must be launched with `torchrun`.")

    config_instance = Config()
    main(config_instance.__dict__)
|
|
||||||
@ -1,118 +0,0 @@
|
|||||||
import os
|
|
||||||
import random
|
|
||||||
import datetime
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import torch.distributed as dist
|
|
||||||
|
|
||||||
|
|
||||||
def setup_ddp():
    """
    Initializes the distributed data parallel environment.

    This function relies on environment variables set by `torchrun` or a similar
    launcher. It initializes the process group and sets the CUDA device for the
    current process.

    The communication backend defaults to "nccl" but can be overridden through
    the `DIST_BACKEND` environment variable (e.g. "gloo" for CPU-only setups),
    matching the documented `DIST_BACKEND=nccl torchrun ...` launch usage.

    Returns:
        tuple: A tuple containing (rank, world_size, local_rank).

    Raises:
        RuntimeError: If torch.distributed is not available.
    """
    if not dist.is_available():
        raise RuntimeError("torch.distributed is not available.")

    # Backend is selectable at launch time; previously it was hard-coded to
    # "nccl" even though the project docs advertise DIST_BACKEND.
    backend = os.environ.get("DIST_BACKEND", "nccl")
    dist.init_process_group(backend=backend)
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    print(
        f"[DDP Setup] Global Rank: {rank}/{world_size}, "
        f"Local Rank (GPU): {local_rank} on device {torch.cuda.current_device()}"
    )
    return rank, world_size, local_rank
|
|
||||||
|
|
||||||
|
|
||||||
def cleanup_ddp():
    """Tear down the distributed process group, if one was ever created."""
    if not dist.is_initialized():
        return
    dist.destroy_process_group()
|
|
||||||
|
|
||||||
|
|
||||||
def set_seed(seed: int, rank: int = 0):
    """
    Seed every relevant RNG (random, numpy, torch, CUDA) for reproducibility.

    Args:
        seed (int): The base seed value.
        rank (int): The process rank; added to the base seed so each DDP
            process draws a different random stream, which matters for
            data loading.
    """
    effective_seed = seed + rank
    random.seed(effective_seed)
    np.random.seed(effective_seed)
    torch.manual_seed(effective_seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(effective_seed)
    # Deterministic cuDNN kernels can hurt throughput, so this trade-off is
    # usually accepted only for final, reproducibility-critical experiments.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
|
|
||||||
|
|
||||||
|
|
||||||
def get_model_size(model: torch.nn.Module) -> str:
    """
    Return the trainable-parameter count of *model* as a human-readable string.

    Args:
        model (torch.nn.Module): The PyTorch model.

    Returns:
        str: A string representing the model size (e.g., "175.0B", "7.1M", "50.5K").
    """
    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)

    # Pick the largest applicable unit; anything under a million falls
    # through to thousands.
    for threshold, divisor, suffix in ((1e9, 1e9, "B"), (1e6, 1e6, "M")):
        if n_trainable >= threshold:
            return f"{n_trainable / divisor:.1f}{suffix}"
    return f"{n_trainable / 1e3:.1f}K"
|
|
||||||
|
|
||||||
|
|
||||||
def reduce_tensor(tensor: torch.Tensor, world_size: int, op=dist.ReduceOp.SUM) -> torch.Tensor:
    """
    Reduces a tensor's value across all processes in a distributed setup.

    Args:
        tensor (torch.Tensor): The tensor to be reduced.
        world_size (int): The total number of processes.
        op (dist.ReduceOp, optional): The reduction operation (SUM, AVG, etc.).
            Defaults to dist.ReduceOp.SUM.

    Returns:
        torch.Tensor: The reduced tensor, which will be identical on all processes.
    """
    rt = tensor.clone()
    if op == dist.ReduceOp.AVG:
        # BUG FIX: the previous code all-reduced with AVG (which already
        # averages across ranks on backends that support it) and THEN divided
        # by world_size again, scaling the result by 1/world_size**2.
        # Emulate AVG portably instead: SUM across ranks, divide exactly once.
        dist.all_reduce(rt, op=dist.ReduceOp.SUM)
        rt /= world_size
    else:
        dist.all_reduce(rt, op=op)
    return rt
|
|
||||||
|
|
||||||
|
|
||||||
def format_time(seconds: float) -> str:
    """
    Formats a duration in seconds into a human-readable H:M:S string.

    Args:
        seconds (float): The total seconds; fractional parts are truncated.

    Returns:
        str: The formatted time string (e.g., "0:15:32").
    """
    whole_seconds = int(seconds)
    return str(datetime.timedelta(seconds=whole_seconds))
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -1,120 +0,0 @@
|
|||||||
# Kronos Fine-tuning on Custom CSV Datasets
|
|
||||||
|
|
||||||
This module provides a comprehensive pipeline for fine-tuning Kronos models on your own CSV-formatted financial data. It supports both sequential training (tokenizer followed by predictor) and individual component training, with full distributed training capabilities.
|
|
||||||
|
|
||||||
|
|
||||||
## 1. Data Preparation
|
|
||||||
|
|
||||||
### Required Data Format
|
|
||||||
|
|
||||||
Your CSV file must contain the following columns:
|
|
||||||
- `timestamps`: DateTime stamps for each data point
|
|
||||||
- `open`: Opening price
|
|
||||||
- `high`: Highest price
|
|
||||||
- `low`: Lowest price
|
|
||||||
- `close`: Closing price
|
|
||||||
- `volume`: Trading volume
|
|
||||||
- `amount`: Trading amount
|
|
||||||
|
|
||||||
(volume and amount can be 0 if not available)
|
|
||||||
|
|
||||||
### Sample Data Format
|
|
||||||
|
|
||||||
| timestamps | open | close | high | low | volume | amount |
|
|
||||||
|------------|------|-------|------|-----|--------|--------|
|
|
||||||
| 2019/11/26 9:35 | 182.45215 | 184.45215 | 184.95215 | 182.45215 | 15136000 | 0 |
|
|
||||||
| 2019/11/26 9:40 | 184.35215 | 183.85215 | 184.55215 | 183.45215 | 4433300 | 0 |
|
|
||||||
| 2019/11/26 9:45 | 183.85215 | 183.35215 | 183.95215 | 182.95215 | 3070900 | 0 |
|
|
||||||
|
|
||||||
> **Reference**: Check `data/HK_ali_09988_kline_5min_all.csv` for a complete example of the proper data format.
|
|
||||||
|
|
||||||
|
|
||||||
## 2. Config Preparation
|
|
||||||
|
|
||||||
|
|
||||||
Please edit the correct data path & pretrained model path and set your training parameters.
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
# Data configuration
|
|
||||||
data:
|
|
||||||
data_path: "/path/to/your/data.csv"
|
|
||||||
lookback_window: 512 # Historical data points to use
|
|
||||||
predict_window: 48 # Future points to predict
|
|
||||||
max_context: 512 # Maximum context length
|
|
||||||
|
|
||||||
...
|
|
||||||
|
|
||||||
```
|
|
||||||
There are some other settings here, please see `configs/config_ali09988_candle-5min.yaml` for more comments.
|
|
||||||
|
|
||||||
## 3. Training
|
|
||||||
|
|
||||||
### Method 1: Sequential Training (Recommended)
|
|
||||||
|
|
||||||
The `train_sequential.py` script handles the complete training pipeline automatically:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Complete training (tokenizer + predictor)
|
|
||||||
python train_sequential.py --config configs/config_ali09988_candle-5min.yaml
|
|
||||||
|
|
||||||
# Skip existing models
|
|
||||||
python train_sequential.py --config configs/config_ali09988_candle-5min.yaml --skip-existing
|
|
||||||
|
|
||||||
# Only train tokenizer
|
|
||||||
python train_sequential.py --config configs/config_ali09988_candle-5min.yaml --skip-basemodel
|
|
||||||
|
|
||||||
# Only train predictor
|
|
||||||
python train_sequential.py --config configs/config_ali09988_candle-5min.yaml --skip-tokenizer
|
|
||||||
```
|
|
||||||
|
|
||||||
### Method 2: Individual Component Training
|
|
||||||
|
|
||||||
Train each component separately for more control:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Step 1: Train tokenizer
|
|
||||||
python finetune_tokenizer.py --config configs/config_ali09988_candle-5min.yaml
|
|
||||||
|
|
||||||
# Step 2: Train predictor (requires fine-tuned tokenizer)
|
|
||||||
python finetune_base_model.py --config configs/config_ali09988_candle-5min.yaml
|
|
||||||
```
|
|
||||||
|
|
||||||
### DDP Training
|
|
||||||
|
|
||||||
For faster training on multiple GPUs:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Set communication backend (nccl for NVIDIA GPUs, gloo for CPU/mixed)
|
|
||||||
DIST_BACKEND=nccl \
|
|
||||||
torchrun --standalone --nproc_per_node=8 train_sequential.py --config configs/config_ali09988_candle-5min.yaml
|
|
||||||
```
|
|
||||||
|
|
||||||
## 4. Training Results
|
|
||||||
|
|
||||||
The training process generates several outputs:
|
|
||||||
|
|
||||||
### Model Checkpoints
|
|
||||||
- **Tokenizer**: Saved to `{base_save_path}/{exp_name}/tokenizer/best_model/`
|
|
||||||
- **Predictor**: Saved to `{base_save_path}/{exp_name}/basemodel/best_model/`
|
|
||||||
|
|
||||||
### Training Logs
|
|
||||||
- **Console output**: Real-time training progress and metrics
|
|
||||||
- **Log files**: Detailed logs saved to `{base_save_path}/logs/`
|
|
||||||
- **Validation tracking**: Best models are saved based on validation loss
|
|
||||||
|
|
||||||
## 5. Prediction Visualization
|
|
||||||
|
|
||||||
The following images show example training results on Alibaba (HK stock) data:
|
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -1,118 +0,0 @@
|
|||||||
# Kronos微调-支持自定义CSV数据集
|
|
||||||
|
|
||||||
这是一个在自定义的CSV格式数据上微调Kronos模型的完整流程。包含顺序训练(先训练tokenizer再训练predictor)和单独模块训练,同时支持分布式训练。
|
|
||||||
|
|
||||||
|
|
||||||
## 1. 准备数据
|
|
||||||
|
|
||||||
### 数据格式
|
|
||||||
|
|
||||||
CSV文件必须包含以下列:
|
|
||||||
- `timestamps`: 每个数据点的时间戳
|
|
||||||
- `open`: 开盘价
|
|
||||||
- `high`: 最高价
|
|
||||||
- `low`: 最低价
|
|
||||||
- `close`: 收盘价
|
|
||||||
- `volume`: 交易量
|
|
||||||
- `amount`: 交易金额
|
|
||||||
|
|
||||||
(volume和amount可以全0如果没有这部分的数据)
|
|
||||||
|
|
||||||
### 示例数据格式
|
|
||||||
|
|
||||||
| timestamps | open | close | high | low | volume | amount |
|
|
||||||
|------------|------|-------|------|-----|--------|--------|
|
|
||||||
| 2019/11/26 9:35 | 182.45215 | 184.45215 | 184.95215 | 182.45215 | 15136000 | 0 |
|
|
||||||
| 2019/11/26 9:40 | 184.35215 | 183.85215 | 184.55215 | 183.45215 | 4433300 | 0 |
|
|
||||||
| 2019/11/26 9:45 | 183.85215 | 183.35215 | 183.95215 | 182.95215 | 3070900 | 0 |
|
|
||||||
|
|
||||||
> **标准数据样例**: `data/HK_ali_09988_kline_5min_all.csv`
|
|
||||||
|
|
||||||
## 2. 准备config文件
|
|
||||||
|
|
||||||
data_path及预训练模型路径需要修改,训练参数可以自己调节
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
# 数据配置
|
|
||||||
data:
|
|
||||||
data_path: "/path/to/your/data.csv"
|
|
||||||
lookback_window: 512 # 要使用的历史数据点
|
|
||||||
predict_window: 48 # 要预测的未来点数
|
|
||||||
max_context: 512 # 最大上下文长度
|
|
||||||
|
|
||||||
...
|
|
||||||
|
|
||||||
```
|
|
||||||
这里还有其他一些设置, `configs/config_ali09988_candle-5min.yaml` 有更详细的注释。
|
|
||||||
|
|
||||||
## 3. 训练
|
|
||||||
|
|
||||||
### 方法1: 直接顺序训练
|
|
||||||
|
|
||||||
`train_sequential.py` 脚本自动处理完整的训练流程:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# 完整训练(tokenizer + predictor)
|
|
||||||
python train_sequential.py --config configs/config_ali09988_candle-5min.yaml
|
|
||||||
|
|
||||||
# 跳过已存在的模型
|
|
||||||
python train_sequential.py --config configs/config_ali09988_candle-5min.yaml --skip-existing
|
|
||||||
|
|
||||||
# 只训练tokenizer
|
|
||||||
python train_sequential.py --config configs/config_ali09988_candle-5min.yaml --skip-basemodel
|
|
||||||
|
|
||||||
# 只训练predictor
|
|
||||||
python train_sequential.py --config configs/config_ali09988_candle-5min.yaml --skip-tokenizer
|
|
||||||
```
|
|
||||||
|
|
||||||
### 方法2: 单独组件训练
|
|
||||||
|
|
||||||
可以单独训练每个组件:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# 步骤1: 训练tokenizer
|
|
||||||
python finetune_tokenizer.py --config configs/config_ali09988_candle-5min.yaml
|
|
||||||
|
|
||||||
# 步骤2: 训练predictor(需要微调后的tokenizer)
|
|
||||||
python finetune_base_model.py --config configs/config_ali09988_candle-5min.yaml
|
|
||||||
```
|
|
||||||
|
|
||||||
### DDP训练
|
|
||||||
|
|
||||||
如果有多卡,可以开启ddp加速训练:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# 设置通信后端(NVIDIA GPU用nccl,CPU/混合用gloo)
|
|
||||||
DIST_BACKEND=nccl \
|
|
||||||
torchrun --standalone --nproc_per_node=8 train_sequential.py --config configs/config_ali09988_candle-5min.yaml
|
|
||||||
```
|
|
||||||
|
|
||||||
## 4. 训练结果
|
|
||||||
|
|
||||||
训练过程生成以下输出:
|
|
||||||
|
|
||||||
### 模型检查点
|
|
||||||
- **Tokenizer**: 保存到 `{base_save_path}/{exp_name}/tokenizer/best_model/`
|
|
||||||
- **Predictor**: 保存到 `{base_save_path}/{exp_name}/basemodel/best_model/`
|
|
||||||
|
|
||||||
### 训练日志
|
|
||||||
- **控制台输出**: 实时训练进度和指标
|
|
||||||
- **日志文件**: 详细日志保存到 `{base_save_path}/logs/`
|
|
||||||
- **验证跟踪**: 基于验证损失保存最佳模型
|
|
||||||
|
|
||||||
## 5. 预测可视化
|
|
||||||
|
|
||||||
以下图像显示了kronos在阿里巴巴股票数据上微调后的示例训练结果:
|
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -1,267 +0,0 @@
|
|||||||
import os
|
|
||||||
import yaml
|
|
||||||
from typing import Dict, Any
|
|
||||||
|
|
||||||
|
|
||||||
class ConfigLoader:
    """Loads a YAML config file and resolves experiment-based dynamic paths."""

    def __init__(self, config_path: str):
        """
        Args:
            config_path (str): Path to the YAML configuration file.
        """
        self.config_path = config_path
        self.config = self._load_config()

    def _load_config(self) -> Dict[str, Any]:
        """Read the YAML file and resolve any dynamic path placeholders.

        Raises:
            FileNotFoundError: If the config file does not exist.
        """
        if not os.path.exists(self.config_path):
            raise FileNotFoundError(f"config file not found: {self.config_path}")

        with open(self.config_path, 'r', encoding='utf-8') as f:
            config = yaml.safe_load(f)

        config = self._resolve_dynamic_paths(config)

        return config

    def _resolve_dynamic_paths(self, config: Dict[str, Any]) -> Dict[str, Any]:
        """
        Fill in path entries derived from `model_paths.exp_name`.

        Empty-string/None values are replaced with generated defaults; values
        containing a literal "{exp_name}" placeholder have it substituted.
        """
        exp_name = config.get('model_paths', {}).get('exp_name', '')
        if not exp_name:
            return config

        base_path = config.get('model_paths', {}).get('base_path', '')
        path_templates = {
            'base_save_path': f"{base_path}/{exp_name}",
            'finetuned_tokenizer': f"{base_path}/{exp_name}/tokenizer/best_model"
        }

        if 'model_paths' in config:
            for key, template in path_templates.items():
                if key in config['model_paths']:
                    current_value = config['model_paths'][key]
                    # only use template when the original value is empty string
                    if current_value == "" or current_value is None:
                        config['model_paths'][key] = template
                    else:
                        # if the original value is not empty, substitute any
                        # {exp_name} placeholder with the actual experiment name
                        if isinstance(current_value, str) and '{exp_name}' in current_value:
                            config['model_paths'][key] = current_value.format(exp_name=exp_name)

        return config

    def get(self, key: str, default=None):
        """Dotted-path lookup, e.g. get('data.lookback_window', 512)."""
        keys = key.split('.')
        value = self.config

        try:
            for k in keys:
                value = value[k]
            return value
        except (KeyError, TypeError):
            return default

    def get_data_config(self) -> Dict[str, Any]:
        """Return the 'data' section (empty dict if absent)."""
        return self.config.get('data', {})

    def get_training_config(self) -> Dict[str, Any]:
        """Return the 'training' section (empty dict if absent)."""
        return self.config.get('training', {})

    def get_model_paths(self) -> Dict[str, str]:
        """Return the 'model_paths' section (empty dict if absent)."""
        return self.config.get('model_paths', {})

    def get_experiment_config(self) -> Dict[str, Any]:
        """Return the 'experiment' section (empty dict if absent)."""
        return self.config.get('experiment', {})

    def get_device_config(self) -> Dict[str, Any]:
        """Return the 'device' section (empty dict if absent)."""
        return self.config.get('device', {})

    def get_distributed_config(self) -> Dict[str, Any]:
        """Return the 'distributed' section (empty dict if absent)."""
        return self.config.get('distributed', {})

    def update_config(self, updates: Dict[str, Any]):
        """Deep-merge *updates* into the in-memory configuration."""

        def update_nested_dict(d, u):
            # Recurse into dict values; scalars overwrite in place.
            for k, v in u.items():
                if isinstance(v, dict):
                    d[k] = update_nested_dict(d.get(k, {}), v)
                else:
                    d[k] = v
            return d

        self.config = update_nested_dict(self.config, updates)

    def save_config(self, save_path: str = None):
        """Write the current configuration back to YAML.

        Args:
            save_path (str, optional): Target path; defaults to the file the
                config was loaded from (overwrites it).
        """
        if save_path is None:
            save_path = self.config_path

        with open(save_path, 'w', encoding='utf-8') as f:
            yaml.dump(self.config, f, default_flow_style=False, allow_unicode=True, indent=2)

    def print_config(self):
        """Pretty-print the current configuration as YAML."""
        print("=" * 50)
        print("Current configuration:")
        print("=" * 50)
        # BUG FIX: without a stream argument yaml.dump() *returns* the YAML
        # string; the previous code discarded it, so nothing was printed here.
        print(yaml.dump(self.config, default_flow_style=False, allow_unicode=True, indent=2))
        print("=" * 50)
|
|
||||||
|
|
||||||
|
|
||||||
class CustomFinetuneConfig:
    """
    Flat, attribute-style view over the YAML fine-tuning configuration.

    Wraps ConfigLoader and copies each config section into plain attributes
    (with defaults), so training scripts can read e.g. `cfg.batch_size`
    instead of doing nested dictionary lookups.
    """

    def __init__(self, config_path: str = None):
        """
        Args:
            config_path (str, optional): Path to the YAML config file.
                Defaults to `config.yaml` next to this module.
        """
        if config_path is None:
            config_path = os.path.join(os.path.dirname(__file__), 'config.yaml')

        self.loader = ConfigLoader(config_path)
        self._load_all_configs()

    def _load_all_configs(self):
        """Copy every config section into flat attributes, applying defaults."""
        data_config = self.loader.get_data_config()
        self.data_path = data_config.get('data_path')
        self.lookback_window = data_config.get('lookback_window', 512)
        self.predict_window = data_config.get('predict_window', 48)
        self.max_context = data_config.get('max_context', 512)
        self.clip = data_config.get('clip', 5.0)
        self.train_ratio = data_config.get('train_ratio', 0.9)
        self.val_ratio = data_config.get('val_ratio', 0.1)
        self.test_ratio = data_config.get('test_ratio', 0.0)

        # training configuration
        training_config = self.loader.get_training_config()
        # support training epochs of tokenizer and basemodel separately
        self.tokenizer_epochs = training_config.get('tokenizer_epochs', 30)
        self.basemodel_epochs = training_config.get('basemodel_epochs', 30)

        # Backward compatibility: a single legacy `epochs` key applies to any
        # component that lacks its own dedicated *_epochs entry.
        if 'epochs' in training_config and 'tokenizer_epochs' not in training_config:
            self.tokenizer_epochs = training_config.get('epochs', 30)
        if 'epochs' in training_config and 'basemodel_epochs' not in training_config:
            self.basemodel_epochs = training_config.get('epochs', 30)

        self.batch_size = training_config.get('batch_size', 160)
        self.log_interval = training_config.get('log_interval', 50)
        self.num_workers = training_config.get('num_workers', 6)
        self.seed = training_config.get('seed', 100)
        self.tokenizer_learning_rate = training_config.get('tokenizer_learning_rate', 2e-4)
        self.predictor_learning_rate = training_config.get('predictor_learning_rate', 4e-5)
        self.adam_beta1 = training_config.get('adam_beta1', 0.9)
        self.adam_beta2 = training_config.get('adam_beta2', 0.95)
        self.adam_weight_decay = training_config.get('adam_weight_decay', 0.1)
        self.accumulation_steps = training_config.get('accumulation_steps', 1)

        # Model path configuration (paths may have been resolved from
        # exp_name templates by ConfigLoader).
        model_paths = self.loader.get_model_paths()
        self.exp_name = model_paths.get('exp_name', 'default_experiment')
        self.pretrained_tokenizer_path = model_paths.get('pretrained_tokenizer')
        self.pretrained_predictor_path = model_paths.get('pretrained_predictor')
        self.base_save_path = model_paths.get('base_save_path')
        self.tokenizer_save_name = model_paths.get('tokenizer_save_name', 'tokenizer')
        self.basemodel_save_name = model_paths.get('basemodel_save_name', 'basemodel')
        self.finetuned_tokenizer_path = model_paths.get('finetuned_tokenizer')

        # Experiment-level switches.
        experiment_config = self.loader.get_experiment_config()
        self.experiment_name = experiment_config.get('name', 'kronos_custom_finetune')
        self.experiment_description = experiment_config.get('description', '')
        self.use_comet = experiment_config.get('use_comet', False)
        self.train_tokenizer = experiment_config.get('train_tokenizer', True)
        self.train_basemodel = experiment_config.get('train_basemodel', True)
        self.skip_existing = experiment_config.get('skip_existing', False)

        # A unified `pre_trained` flag acts as the fallback for the two
        # component-specific flags; each specific key still wins if present.
        unified_pretrained = experiment_config.get('pre_trained', None)
        self.pre_trained_tokenizer = experiment_config.get('pre_trained_tokenizer', unified_pretrained if unified_pretrained is not None else True)
        self.pre_trained_predictor = experiment_config.get('pre_trained_predictor', unified_pretrained if unified_pretrained is not None else True)

        # Device configuration.
        device_config = self.loader.get_device_config()
        self.use_cuda = device_config.get('use_cuda', True)
        self.device_id = device_config.get('device_id', 0)

        # Distributed-training configuration.
        distributed_config = self.loader.get_distributed_config()
        self.use_ddp = distributed_config.get('use_ddp', False)
        self.ddp_backend = distributed_config.get('backend', 'nccl')

        self._compute_full_paths()

    def _compute_full_paths(self):
        """Derive save/checkpoint paths from base_save_path and save names."""
        self.tokenizer_save_path = os.path.join(self.base_save_path, self.tokenizer_save_name)
        self.tokenizer_best_model_path = os.path.join(self.tokenizer_save_path, 'best_model')

        self.basemodel_save_path = os.path.join(self.base_save_path, self.basemodel_save_name)
        self.basemodel_best_model_path = os.path.join(self.basemodel_save_path, 'best_model')

    def get_tokenizer_config(self):
        """Return a flat dict of the settings the tokenizer trainer expects."""
        return {
            'data_path': self.data_path,
            'lookback_window': self.lookback_window,
            'predict_window': self.predict_window,
            'max_context': self.max_context,
            'clip': self.clip,
            'train_ratio': self.train_ratio,
            'val_ratio': self.val_ratio,
            'test_ratio': self.test_ratio,
            'epochs': self.tokenizer_epochs,
            'batch_size': self.batch_size,
            'log_interval': self.log_interval,
            'num_workers': self.num_workers,
            'seed': self.seed,
            'learning_rate': self.tokenizer_learning_rate,
            'adam_beta1': self.adam_beta1,
            'adam_beta2': self.adam_beta2,
            'adam_weight_decay': self.adam_weight_decay,
            'accumulation_steps': self.accumulation_steps,
            'pretrained_model_path': self.pretrained_tokenizer_path,
            'save_path': self.tokenizer_save_path,
            'use_comet': self.use_comet
        }

    def get_basemodel_config(self):
        """Return a flat dict of the settings the predictor trainer expects.

        Note: the tokenizer path here is the *fine-tuned* tokenizer, since the
        predictor is trained on top of it.
        """
        return {
            'data_path': self.data_path,
            'lookback_window': self.lookback_window,
            'predict_window': self.predict_window,
            'max_context': self.max_context,
            'clip': self.clip,
            'train_ratio': self.train_ratio,
            'val_ratio': self.val_ratio,
            'test_ratio': self.test_ratio,
            'epochs': self.basemodel_epochs,
            'batch_size': self.batch_size,
            'log_interval': self.log_interval,
            'num_workers': self.num_workers,
            'seed': self.seed,
            'predictor_learning_rate': self.predictor_learning_rate,
            'tokenizer_learning_rate': self.tokenizer_learning_rate,
            'adam_beta1': self.adam_beta1,
            'adam_beta2': self.adam_beta2,
            'adam_weight_decay': self.adam_weight_decay,
            'pretrained_tokenizer_path': self.finetuned_tokenizer_path,
            'pretrained_predictor_path': self.pretrained_predictor_path,
            'save_path': self.basemodel_save_path,
            'use_comet': self.use_comet
        }

    def print_config_summary(self):
        """Print a human-readable summary of the effective configuration."""
        print("=" * 60)
        print("Kronos finetuning configuration summary")
        print("=" * 60)
        print(f"Experiment name: {self.exp_name}")
        print(f"Data path: {self.data_path}")
        print(f"Lookback window: {self.lookback_window}")
        print(f"Predict window: {self.predict_window}")
        print(f"Tokenizer training epochs: {self.tokenizer_epochs}")
        print(f"Basemodel training epochs: {self.basemodel_epochs}")
        print(f"Batch size: {self.batch_size}")
        print(f"Tokenizer learning rate: {self.tokenizer_learning_rate}")
        print(f"Predictor learning rate: {self.predictor_learning_rate}")
        print(f"Train tokenizer: {self.train_tokenizer}")
        print(f"Train basemodel: {self.train_basemodel}")
        print(f"Skip existing: {self.skip_existing}")
        print(f"Use pre-trained tokenizer: {self.pre_trained_tokenizer}")
        print(f"Use pre-trained predictor: {self.pre_trained_predictor}")
        print(f"Base save path: {self.base_save_path}")
        print(f"Tokenizer save path: {self.tokenizer_save_path}")
        print(f"Basemodel save path: {self.basemodel_save_path}")
        print("=" * 60)
|
|
||||||
@ -1,72 +0,0 @@
|
|||||||
# This is a template config for fine-tuning Kronos on custom CSV data
|
|
||||||
#这是一份模板config,用于kronos的csv自定义数据微调
|
|
||||||
|
|
||||||
data:
|
|
||||||
data_path: "/xxxx/Kronos/finetune_csv/data/HK_ali_09988_kline_5min_all.csv"
|
|
||||||
lookback_window: 512
|
|
||||||
predict_window: 48
|
|
||||||
max_context: 512
|
|
||||||
clip: 5.0
|
|
||||||
# dataset split ratio
|
|
||||||
train_ratio: 0.9
|
|
||||||
val_ratio: 0.1
|
|
||||||
test_ratio: 0.0
|
|
||||||
|
|
||||||
training:
|
|
||||||
# control the training epochs of tokenizer and basemodel
|
|
||||||
tokenizer_epochs: 30
|
|
||||||
basemodel_epochs: 20
|
|
||||||
batch_size: 32
|
|
||||||
log_interval: 50
|
|
||||||
num_workers: 6
|
|
||||||
seed: 42
|
|
||||||
|
|
||||||
tokenizer_learning_rate: 0.0002
|
|
||||||
predictor_learning_rate: 0.000001
|
|
||||||
|
|
||||||
adam_beta1: 0.9
|
|
||||||
adam_beta2: 0.95
|
|
||||||
adam_weight_decay: 0.1
|
|
||||||
|
|
||||||
# gradient accumulation steps for tokenizer training
|
|
||||||
accumulation_steps: 1
|
|
||||||
|
|
||||||
# model path configuration
|
|
||||||
model_paths:
|
|
||||||
# pretrained model path
|
|
||||||
pretrained_tokenizer: "/xxx/Kronos/pretrained/Kronos-Tokenizer-base"
|
|
||||||
pretrained_predictor: "/xxx/Kronos/pretrained/Kronos-base"
|
|
||||||
|
|
||||||
# experiment name - other paths will be generated based on this
|
|
||||||
exp_name: "HK_ali_09988_kline_5min_all"
|
|
||||||
base_path: "/xxx/Kronos/finetune_csv/finetuned/"
|
|
||||||
|
|
||||||
# the following paths will be generated based on exp_name, no need to modify manually
|
|
||||||
# way 1: leave empty string, the system will generate the full path
|
|
||||||
base_save_path: "" # /xxxx/Kronos/finetune_csv/finetuned/{exp_name}
|
|
||||||
finetuned_tokenizer: "" # /xxxx/Kronos/finetune_csv/finetuned/{exp_name}/tokenizer/best_model
|
|
||||||
|
|
||||||
# way 2: use template string, {exp_name} will be replaced with the actual experiment name
|
|
||||||
# base_save_path: "/xxxx/Kronos/finetune_csv/finetuned/{exp_name}"
|
|
||||||
# finetuned_tokenizer: "/xxxx/Kronos/finetune_csv/finetuned/{exp_name}/tokenizer/best_model"
|
|
||||||
|
|
||||||
tokenizer_save_name: "tokenizer"
|
|
||||||
basemodel_save_name: "basemodel"
|
|
||||||
|
|
||||||
experiment:
|
|
||||||
name: "kronos_custom_finetune"
|
|
||||||
description: "Custom finetune for HK stock data"
|
|
||||||
use_comet: false
|
|
||||||
|
|
||||||
# control the training phase
|
|
||||||
train_tokenizer: true
|
|
||||||
train_basemodel: true
|
|
||||||
|
|
||||||
# if true, skip the existing model training
|
|
||||||
skip_existing: false
|
|
||||||
|
|
||||||
# device configuration
|
|
||||||
device:
|
|
||||||
use_cuda: true
|
|
||||||
device_id: 0
|
|
||||||
|
|
||||||
|
Before Width: | Height: | Size: 474 KiB |
|
Before Width: | Height: | Size: 473 KiB |
|
Before Width: | Height: | Size: 331 KiB |
|
Before Width: | Height: | Size: 449 KiB |
|
Before Width: | Height: | Size: 530 KiB |
@ -1,468 +0,0 @@
|
|||||||
import os
|
|
||||||
import sys
|
|
||||||
import json
|
|
||||||
import time
|
|
||||||
import pickle
|
|
||||||
import random
|
|
||||||
import pandas as pd
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from torch.utils.data import Dataset, DataLoader
|
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
|
||||||
from time import gmtime, strftime
|
|
||||||
import logging
|
|
||||||
from logging.handlers import RotatingFileHandler
|
|
||||||
import datetime
|
|
||||||
import torch.distributed as dist
|
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
||||||
|
|
||||||
sys.path.append('../')
|
|
||||||
from model import Kronos, KronosTokenizer, KronosPredictor
|
|
||||||
from config_loader import CustomFinetuneConfig
|
|
||||||
|
|
||||||
|
|
||||||
class CustomKlineDataset(Dataset):
    """Sliding-window dataset over a single K-line (candlestick) CSV.

    The CSV must provide a ``timestamps`` column plus the feature columns
    ``open, high, low, close, volume, amount``.  Rows are sorted by time,
    calendar features (minute/hour/weekday/day/month) are derived, and the
    series is split chronologically into train/val/test partitions.  Each
    sample is a window of ``lookback_window + predict_window + 1`` rows,
    z-score normalized per window and clipped to ``[-clip, clip]``.
    """

    def __init__(self, data_path, data_type='train', lookback_window=90, predict_window=10,
                 clip=5.0, seed=100, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15):
        """
        Args:
            data_path: Path to the CSV file (must contain a 'timestamps' column).
            data_type: One of 'train', 'val', 'test' — selects the split.
            lookback_window: Number of context rows per sample.
            predict_window: Number of rows to predict per sample.
            clip: Symmetric clipping bound applied after normalization.
            seed: Base RNG seed (offset per epoch via set_epoch_seed).
            train_ratio/val_ratio/test_ratio: Chronological split fractions.
        """
        self.data_path = data_path
        self.data_type = data_type
        self.lookback_window = lookback_window
        self.predict_window = predict_window
        # +1 row so next-token targets still cover the full predict window.
        self.window = lookback_window + predict_window + 1
        self.clip = clip
        self.seed = seed
        self.train_ratio = train_ratio
        self.val_ratio = val_ratio
        self.test_ratio = test_ratio

        self.feature_list = ['open', 'high', 'low', 'close', 'volume', 'amount']
        self.time_feature_list = ['minute', 'hour', 'weekday', 'day', 'month']

        self.py_rng = random.Random(seed)

        self._load_and_preprocess_data()
        self._split_data_by_time()

        # Clamp at 0 so __len__ never reports a negative count when the
        # split is shorter than one full window.
        self.n_samples = max(0, len(self.data) - self.window + 1)

        print(f"[{data_type.upper()}] Data length: {len(self.data)}, Available samples: {self.n_samples}")

    def _load_and_preprocess_data(self):
        """Read the CSV, sort by time, and derive calendar features."""
        df = pd.read_csv(self.data_path)

        df['timestamps'] = pd.to_datetime(df['timestamps'])
        df = df.sort_values('timestamps').reset_index(drop=True)

        self.timestamps = df['timestamps'].copy()

        df['minute'] = df['timestamps'].dt.minute
        df['hour'] = df['timestamps'].dt.hour
        df['weekday'] = df['timestamps'].dt.weekday
        df['day'] = df['timestamps'].dt.day
        df['month'] = df['timestamps'].dt.month

        self.data = df[self.feature_list + self.time_feature_list].copy()

        if self.data.isnull().any().any():
            print("Warning: Missing values found in data, performing forward fill")
            # fillna(method='ffill') is deprecated and removed in pandas 3.0;
            # DataFrame.ffill() is the supported equivalent.
            self.data = self.data.ffill()

        print(f"Original data time range: {self.timestamps.min()} to {self.timestamps.max()}")
        print(f"Original data total length: {len(df)} records")

    def _split_data_by_time(self):
        """Keep only the chronological slice belonging to this split."""
        total_length = len(self.data)

        train_end = int(total_length * self.train_ratio)
        val_end = int(total_length * (self.train_ratio + self.val_ratio))

        if self.data_type == 'train':
            self.data = self.data.iloc[:train_end].copy()
            self.timestamps = self.timestamps.iloc[:train_end].copy()
            print(f"[{self.data_type.upper()}] Training set: first {train_end} time points ({self.train_ratio})")
            print(f"[{self.data_type.upper()}] Training set time range: {self.timestamps.min()} to {self.timestamps.max()}")
        elif self.data_type == 'val':
            self.data = self.data.iloc[train_end:val_end].copy()
            self.timestamps = self.timestamps.iloc[train_end:val_end].copy()
            print(f"[{self.data_type.upper()}] Validation set: time points {train_end+1} to {val_end} ({self.val_ratio})")
            print(f"[{self.data_type.upper()}] Validation set time range: {self.timestamps.min()} to {self.timestamps.max()}")
        elif self.data_type == 'test':
            self.data = self.data.iloc[val_end:].copy()
            self.timestamps = self.timestamps.iloc[val_end:].copy()
            print(f"[{self.data_type.upper()}] Test set: after time point {val_end+1}")
            print(f"[{self.data_type.upper()}] Test set time range: {self.timestamps.min()} to {self.timestamps.max()}")

        print(f"[{self.data_type.upper()}] Data length after split: {len(self.data)} records")

    def set_epoch_seed(self, epoch):
        """Reseed the sampling RNG for a new epoch (varies the train windows)."""
        epoch_seed = self.seed + epoch
        self.py_rng.seed(epoch_seed)
        self.current_epoch = epoch

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        max_start = len(self.data) - self.window
        # max_start == 0 still yields exactly one valid window; only a
        # strictly negative value means the split is too short (the original
        # `<= 0` check wrongly rejected the single-window case).
        if max_start < 0:
            raise ValueError("Data length insufficient to create samples")

        if self.data_type == 'train':
            # Deterministic pseudo-shuffle: large co-prime multipliers map
            # (idx, epoch) to a different start position without storing a
            # permutation.
            epoch = getattr(self, 'current_epoch', 0)
            start_idx = (idx * 9973 + (epoch + 1) * 104729) % (max_start + 1)
        else:
            start_idx = idx % (max_start + 1)

        end_idx = start_idx + self.window

        window_data = self.data.iloc[start_idx:end_idx]

        x = window_data[self.feature_list].values.astype(np.float32)
        x_stamp = window_data[self.time_feature_list].values.astype(np.float32)

        # Per-window z-score normalization (epsilon guards zero variance),
        # then clip outliers to the configured bound.
        x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0)
        x = (x - x_mean) / (x_std + 1e-5)
        x = np.clip(x, -self.clip, self.clip)

        x_tensor = torch.from_numpy(x)
        x_stamp_tensor = torch.from_numpy(x_stamp)

        return x_tensor, x_stamp_tensor
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def setup_logging(exp_name: str, log_dir: str, rank: int = 0) -> logging.Logger:
    """Create (or fetch) the per-rank rotating-file logger for basemodel training.

    Rank 0 additionally logs to the console.  Repeated calls return the
    already-configured logger untouched.
    """
    os.makedirs(log_dir, exist_ok=True)

    logger = logging.getLogger(f"basemodel_training_rank_{rank}")
    logger.setLevel(logging.INFO)

    # Already wired up by an earlier call — reuse as-is.
    if logger.handlers:
        return logger

    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
    )

    handlers = [
        RotatingFileHandler(
            os.path.join(log_dir, f"basemodel_training_rank_{rank}.log"),
            maxBytes=10 * 1024 * 1024,
            backupCount=5,
            encoding='utf-8',
        ),
    ]
    if rank == 0:
        handlers.append(logging.StreamHandler())

    for handler in handlers:
        handler.setLevel(logging.INFO)
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    logger.info("=== Basemodel Training Started ===")
    logger.info(f"Experiment Name: {exp_name}")
    logger.info(f"Log Directory: {log_dir}")
    logger.info(f"Rank: {rank}")
    logger.info(f"Timestamp: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

    return logger
|
|
||||||
|
|
||||||
|
|
||||||
def create_dataloaders(config):
    """Build train/val datasets, optional DDP samplers, and their DataLoaders.

    Returns (train_loader, val_loader, train_dataset, val_dataset,
    train_sampler, val_sampler); the samplers are None outside DDP.
    """
    ddp_active = dist.is_available() and dist.is_initialized()
    is_main = (not ddp_active) or dist.get_rank() == 0

    if is_main:
        print("Creating data loaders...")

    def _build_dataset(split, split_seed):
        # Both splits share every windowing/ratio setting from the config.
        return CustomKlineDataset(
            data_path=config.data_path,
            data_type=split,
            lookback_window=config.lookback_window,
            predict_window=config.predict_window,
            clip=config.clip,
            seed=split_seed,
            train_ratio=config.train_ratio,
            val_ratio=config.val_ratio,
            test_ratio=config.test_ratio,
        )

    train_dataset = _build_dataset('train', config.seed)
    val_dataset = _build_dataset('val', config.seed + 1)

    if ddp_active:
        world = dist.get_world_size()
        this_rank = dist.get_rank()
        train_sampler = DistributedSampler(train_dataset, num_replicas=world, rank=this_rank, shuffle=True)
        val_sampler = DistributedSampler(val_dataset, num_replicas=world, rank=this_rank, shuffle=False, drop_last=False)
    else:
        train_sampler = None
        val_sampler = None

    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=(train_sampler is None),
        num_workers=config.num_workers,
        pin_memory=True,
        drop_last=True,
        sampler=train_sampler,
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=config.num_workers,
        pin_memory=True,
        drop_last=False,
        sampler=val_sampler,
    )

    if is_main:
        print(f"Training set size: {len(train_dataset)}, Validation set size: {len(val_dataset)}")

    return train_loader, val_loader, train_dataset, val_dataset, train_sampler, val_sampler
|
|
||||||
|
|
||||||
|
|
||||||
def train_model(model, tokenizer, device, config, save_dir, logger):
    """Fine-tune the Kronos predictor with next-token prediction on tokenized windows.

    The tokenizer is used as a frozen encoder (its `encode` runs under
    no_grad); the predictor is trained on both token streams at once.
    Supports single-process and DDP runs; only rank 0 prints and saves
    checkpoints.  Returns the best (lowest) average validation loss.
    """
    logger.info("Starting training...")
    use_ddp = dist.is_available() and dist.is_initialized()
    rank = dist.get_rank() if use_ddp else 0
    world_size = dist.get_world_size() if use_ddp else 1

    train_loader, val_loader, train_dataset, val_dataset, train_sampler, val_sampler = create_dataloaders(config)
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.predictor_learning_rate,
        betas=(config.adam_beta1, config.adam_beta2),
        weight_decay=config.adam_weight_decay
    )

    # One-cycle LR schedule, stepped once per batch (len(train_loader) * epochs steps).
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=config.predictor_learning_rate,
        steps_per_epoch=len(train_loader),
        epochs=config.basemodel_epochs,
        pct_start=0.03,
        div_factor=10
    )

    if use_ddp:
        local_rank = int(os.environ.get("LOCAL_RANK", "0"))
        model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=False)

    best_val_loss = float('inf')
    batch_idx_global = 0  # counts batches across ALL epochs, drives log_interval

    for epoch in range(config.basemodel_epochs):
        epoch_start_time = time.time()
        model.train()

        # New window sampling each epoch for train; validation always uses
        # the same (epoch-0) windows so losses are comparable across epochs.
        train_dataset.set_epoch_seed(epoch * 10000)
        val_dataset.set_epoch_seed(0)
        if train_sampler is not None:
            train_sampler.set_epoch(epoch)

        epoch_train_loss = 0.0
        train_batches = 0

        for batch_idx, (batch_x, batch_x_stamp) in enumerate(train_loader):
            batch_x = batch_x.to(device, non_blocking=True)
            batch_x_stamp = batch_x_stamp.to(device, non_blocking=True)

            # Frozen tokenizer encode — no gradients flow into it.
            with torch.no_grad():
                token_seq_0, token_seq_1 = tokenizer.encode(batch_x, half=True)

            # Teacher forcing: inputs are tokens [0..T-2], targets [1..T-1].
            token_in = [token_seq_0[:, :-1], token_seq_1[:, :-1]]
            token_out = [token_seq_0[:, 1:], token_seq_1[:, 1:]]

            logits = (model.module if use_ddp else model)(token_in[0], token_in[1], batch_x_stamp[:, :-1, :])
            loss, s1_loss, s2_loss = (model.module if use_ddp else model).head.compute_loss(logits[0], logits[1], token_out[0], token_out[1])

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_((model.module if use_ddp else model).parameters(), max_norm=3.0)
            optimizer.step()
            scheduler.step()

            epoch_train_loss += loss.item()
            train_batches += 1

            if (batch_idx_global + 1) % config.log_interval == 0:
                lr = optimizer.param_groups[0]['lr']
                log_msg = (f"[Epoch {epoch+1}/{config.basemodel_epochs}, Step {batch_idx+1}/{len(train_loader)}] "
                           f"LR: {lr:.6f}, Loss: {loss.item():.4f}")
                logger.info(log_msg)
                if rank == 0:
                    print(log_msg)

            batch_idx_global += 1

        # ---- validation pass (same tokenize/forward path, no grads) ----
        model.eval()
        val_loss = 0.0
        val_batches = 0

        with torch.no_grad():
            for batch_x, batch_x_stamp in val_loader:
                batch_x = batch_x.to(device, non_blocking=True)
                batch_x_stamp = batch_x_stamp.to(device, non_blocking=True)

                token_seq_0, token_seq_1 = tokenizer.encode(batch_x, half=True)
                token_in = [token_seq_0[:, :-1], token_seq_1[:, :-1]]
                token_out = [token_seq_0[:, 1:], token_seq_1[:, 1:]]

                logits = (model.module if use_ddp else model)(token_in[0], token_in[1], batch_x_stamp[:, :-1, :])
                loss, _, _ = (model.module if use_ddp else model).head.compute_loss(logits[0], logits[1], token_out[0], token_out[1])

                val_loss += loss.item()
                val_batches += 1

        if use_ddp:
            # Aggregate (sum of losses, batch counts) across all ranks, then
            # average globally so every rank agrees on avg losses.
            tensor_sum = torch.tensor([epoch_train_loss, train_batches, val_loss, val_batches], dtype=torch.float64, device=device)
            dist.all_reduce(tensor_sum, op=dist.ReduceOp.SUM)
            epoch_train_loss_all = tensor_sum[0].item()
            train_batches_all = int(tensor_sum[1].item())
            val_loss_all = tensor_sum[2].item()
            val_batches_all = int(tensor_sum[3].item())
            avg_train_loss = (epoch_train_loss_all / train_batches_all) if train_batches_all > 0 else 0.0
            avg_val_loss = (val_loss_all / val_batches_all) if val_batches_all > 0 else 0.0
        else:
            avg_train_loss = epoch_train_loss / train_batches if train_batches > 0 else 0
            avg_val_loss = val_loss / val_batches if val_batches > 0 else 0

        epoch_time = time.time() - epoch_start_time
        epoch_summary = (f"\n--- Epoch {epoch+1}/{config.basemodel_epochs} Summary ---\n"
                         f"Training Loss: {avg_train_loss:.4f}\n"
                         f"Validation Loss: {avg_val_loss:.4f}\n"
                         f"Epoch Time: {epoch_time:.2f} seconds\n")
        logger.info(epoch_summary)
        if rank == 0:
            print(epoch_summary)

        # Checkpoint only on improvement; all ranks agree on avg_val_loss
        # (post all_reduce), so only rank 0 needs to write.
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            if rank == 0:
                model_save_path = os.path.join(save_dir, "best_model")
                os.makedirs(model_save_path, exist_ok=True)
                (model.module if use_ddp else model).save_pretrained(model_save_path)
                save_msg = f"Best model saved to: {model_save_path} (validation loss: {best_val_loss:.4f})"
                logger.info(save_msg)
                print(save_msg)

    return best_val_loss
|
|
||||||
|
|
||||||
|
|
||||||
def main():
    """Entry point: parse CLI args, load config and models, run basemodel fine-tuning."""
    import argparse

    parser = argparse.ArgumentParser(description='Kronos Basemodel Fine-tuning Training')
    parser.add_argument('--config', type=str, default='config.yaml',
                        help='Configuration file path (default: config.yaml)')
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    config = CustomFinetuneConfig(args.config)

    os.makedirs(config.basemodel_save_path, exist_ok=True)

    log_dir = os.path.join(config.base_save_path, "logs")
    logger = setup_logging(config.exp_name, log_dir, 0)

    # Global seeding (this entry point is the single-process path; DDP ranks
    # would need rank-offset seeds).
    torch.manual_seed(config.seed)
    np.random.seed(config.seed)
    random.seed(config.seed)

    logger.info("Loading pretrained model or random initialization...")
    print("Loading pretrained model or random initialization...")
    if getattr(config, 'pre_trained_tokenizer', True):
        tokenizer = KronosTokenizer.from_pretrained(config.finetuned_tokenizer_path)
    else:
        import json, os
        print("pre_trained_tokenizer=False, randomly initializing Tokenizer architecture for training")
        # Architecture hyper-parameters come from the checkpoint's config.json;
        # falls back to the finetuned tokenizer dir when no pretrained path is set.
        cfg_path_tok = os.path.join(config.pretrained_tokenizer_path if hasattr(config, 'pretrained_tokenizer_path') else config.finetuned_tokenizer_path, 'config.json')
        with open(cfg_path_tok, 'r') as f:
            arch_t = json.load(f)
        tokenizer = KronosTokenizer(
            d_in=arch_t.get('d_in', 6),
            d_model=arch_t.get('d_model', 256),
            n_heads=arch_t.get('n_heads', 4),
            ff_dim=arch_t.get('ff_dim', 512),
            n_enc_layers=arch_t.get('n_enc_layers', 4),
            n_dec_layers=arch_t.get('n_dec_layers', 4),
            ffn_dropout_p=arch_t.get('ffn_dropout_p', 0.0),
            attn_dropout_p=arch_t.get('attn_dropout_p', 0.0),
            resid_dropout_p=arch_t.get('resid_dropout_p', 0.0),
            s1_bits=arch_t.get('s1_bits', 10),
            s2_bits=arch_t.get('s2_bits', 10),
            beta=arch_t.get('beta', 0.05),
            gamma0=arch_t.get('gamma0', 1.0),
            gamma=arch_t.get('gamma', 1.1),
            zeta=arch_t.get('zeta', 0.05),
            group_size=arch_t.get('group_size', 4)
        )

    if getattr(config, 'pre_trained_predictor', True):
        model = Kronos.from_pretrained(config.pretrained_predictor_path)
    else:
        import json, os
        print("pre_trained_predictor=False, randomly initializing Predictor architecture for training")
        cfg_path = os.path.join(config.pretrained_predictor_path, 'config.json')
        with open(cfg_path, 'r') as f:
            arch = json.load(f)
        model = Kronos(
            s1_bits=arch.get('s1_bits', 10),
            s2_bits=arch.get('s2_bits', 10),
            n_layers=arch.get('n_layers', 12),
            d_model=arch.get('d_model', 832),
            n_heads=arch.get('n_heads', 16),
            ff_dim=arch.get('ff_dim', 2048),
            ffn_dropout_p=arch.get('ffn_dropout_p', 0.2),
            attn_dropout_p=arch.get('attn_dropout_p', 0.0),
            resid_dropout_p=arch.get('resid_dropout_p', 0.2),
            token_dropout_p=arch.get('token_dropout_p', 0.0),
            learn_te=arch.get('learn_te', True)
        )

    tokenizer = tokenizer.to(device)
    model = model.to(device)

    model_size = sum(p.numel() for p in model.parameters())
    logger.info(f"Model parameters: {model_size:,}")
    print(f"Model parameters: {model_size:,}")

    logger.info("=== Training Configuration ===")
    logger.info(f"Data path: {config.data_path}")
    logger.info(f"Lookback window: {config.lookback_window}")
    logger.info(f"Predict window: {config.predict_window}")
    logger.info(f"Batch size: {config.batch_size}")
    logger.info(f"Learning rate: {config.predictor_learning_rate}")
    logger.info(f"Training epochs: {config.basemodel_epochs}")
    logger.info(f"Device: {device}")
    logger.info(f"Tokenizer path: {config.finetuned_tokenizer_path}")
    logger.info(f"Pretrained model path: {config.pretrained_predictor_path}")

    logger.info("Starting fine-tuning training...")
    print("Starting fine-tuning training...")
    best_val_loss = train_model(model, tokenizer, device, config, config.basemodel_save_path, logger)

    final_msg = f"Training completed! Best validation loss: {best_val_loss:.4f}\nModel saved to: {config.basemodel_save_path}"
    logger.info(final_msg)
    print(final_msg)
|
|
||||||
|
|
||||||
|
|
||||||
# Script entry point.
if __name__ == "__main__":
    main()
|
|
||||||
@ -1,359 +0,0 @@
|
|||||||
import os
|
|
||||||
import sys
|
|
||||||
import json
|
|
||||||
import time
|
|
||||||
import random
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
|
||||||
from time import gmtime, strftime
|
|
||||||
import datetime
|
|
||||||
import logging
|
|
||||||
from logging.handlers import RotatingFileHandler
|
|
||||||
import torch.distributed as dist
|
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
||||||
|
|
||||||
sys.path.append("../")
|
|
||||||
from model import KronosTokenizer
|
|
||||||
from finetune_base_model import CustomKlineDataset
|
|
||||||
from config_loader import CustomFinetuneConfig
|
|
||||||
|
|
||||||
|
|
||||||
def set_seed(seed: int, rank: int = 0):
    """Seed Python, NumPy and torch RNGs for reproducible runs.

    The ``rank`` parameter is accepted for signature compatibility but does
    not offset the seed.  With CUDA available, cuDNN is also switched to
    deterministic mode (disabling autotuned kernels).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        # Trade kernel-selection speed for bitwise reproducibility.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
|
|
||||||
|
|
||||||
|
|
||||||
def get_model_size(model: torch.nn.Module) -> str:
|
|
||||||
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
|
||||||
if total_params >= 1e9:
|
|
||||||
return f"{total_params / 1e9:.1f}B"
|
|
||||||
elif total_params >= 1e6:
|
|
||||||
return f"{total_params / 1e6:.1f}M"
|
|
||||||
else:
|
|
||||||
return f"{total_params / 1e3:.1f}K"
|
|
||||||
|
|
||||||
|
|
||||||
def format_time(seconds: float) -> str:
    """Render a duration in seconds as H:MM:SS (fractional part truncated)."""
    whole_seconds = int(seconds)
    return str(datetime.timedelta(seconds=whole_seconds))
|
|
||||||
|
|
||||||
|
|
||||||
def setup_logging(exp_name: str, log_dir: str, rank: int = 0) -> logging.Logger:
    """Create (or fetch) the per-rank rotating-file logger for tokenizer training.

    Rank 0 additionally logs to the console.  Repeated calls return the
    already-configured logger untouched.
    """
    os.makedirs(log_dir, exist_ok=True)

    logger = logging.getLogger(f"tokenizer_training_rank_{rank}")
    logger.setLevel(logging.INFO)

    # Already wired up by an earlier call — reuse as-is.
    if logger.handlers:
        return logger

    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
    )

    handlers = [
        RotatingFileHandler(
            os.path.join(log_dir, f"tokenizer_training_rank_{rank}.log"),
            maxBytes=10 * 1024 * 1024,
            backupCount=5,
            encoding='utf-8',
        ),
    ]
    if rank == 0:
        handlers.append(logging.StreamHandler())

    for handler in handlers:
        handler.setLevel(logging.INFO)
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    logger.info("=== Tokenizer Training Started ===")
    logger.info(f"Experiment Name: {exp_name}")
    logger.info(f"Log Directory: {log_dir}")
    logger.info(f"Rank: {rank}")
    logger.info(f"Timestamp: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

    return logger
|
|
||||||
|
|
||||||
|
|
||||||
def create_dataloaders(config):
    """Build train/val datasets, optional DDP samplers, and their DataLoaders
    for tokenizer training.

    Returns (train_loader, val_loader, train_dataset, val_dataset,
    train_sampler, val_sampler); the samplers are None outside DDP.
    """
    ddp_active = dist.is_available() and dist.is_initialized()
    is_main = (not ddp_active) or dist.get_rank() == 0

    if is_main:
        print("Creating tokenizer training data loaders...")

    def _build_dataset(split, split_seed):
        # Both splits share every windowing/ratio setting from the config.
        return CustomKlineDataset(
            data_path=config.data_path,
            data_type=split,
            lookback_window=config.lookback_window,
            predict_window=config.predict_window,
            clip=config.clip,
            seed=split_seed,
            train_ratio=config.train_ratio,
            val_ratio=config.val_ratio,
            test_ratio=config.test_ratio,
        )

    train_dataset = _build_dataset("train", config.seed)
    val_dataset = _build_dataset("val", config.seed + 1)

    if ddp_active:
        world = dist.get_world_size()
        this_rank = dist.get_rank()
        train_sampler = DistributedSampler(train_dataset, num_replicas=world, rank=this_rank, shuffle=True)
        val_sampler = DistributedSampler(val_dataset, num_replicas=world, rank=this_rank, shuffle=False, drop_last=False)
    else:
        train_sampler = None
        val_sampler = None

    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=(train_sampler is None),
        num_workers=config.num_workers,
        pin_memory=True,
        drop_last=True,
        sampler=train_sampler,
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=config.num_workers,
        pin_memory=True,
        drop_last=False,
        sampler=val_sampler,
    )

    if is_main:
        print(f"Training set size: {len(train_dataset)}, Validation set size: {len(val_dataset)}")

    return train_loader, val_loader, train_dataset, val_dataset, train_sampler, val_sampler
|
|
||||||
|
|
||||||
|
|
||||||
def train_tokenizer(model, device, config, save_dir, logger):
    """Fine-tune the KronosTokenizer (reconstruction + BSQ quantization loss).

    Supports single-process and DDP runs with optional gradient accumulation
    over micro-batches; only rank 0 prints and saves checkpoints.  Returns
    the best (lowest) average validation loss.
    """
    logger.info("Starting tokenizer training...")
    use_ddp = dist.is_available() and dist.is_initialized()
    rank = dist.get_rank() if use_ddp else 0
    world_size = dist.get_world_size() if use_ddp else 1

    train_loader, val_loader, train_dataset, val_dataset, train_sampler, val_sampler = create_dataloaders(config)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.tokenizer_learning_rate,
        weight_decay=config.adam_weight_decay
    )

    # One-cycle LR schedule, stepped once per outer batch.
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=config.tokenizer_learning_rate,
        steps_per_epoch=len(train_loader),
        epochs=config.tokenizer_epochs,
        pct_start=0.03,
        div_factor=10
    )

    if use_ddp:
        local_rank = int(os.environ.get("LOCAL_RANK", "0"))
        model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=False)

    best_val_loss = float("inf")
    batch_idx_global = 0  # counts batches across ALL epochs, drives log_interval

    accumulation_steps = getattr(config, 'accumulation_steps', 1)

    for epoch in range(config.tokenizer_epochs):
        epoch_start_time = time.time()
        model.train()

        # New window sampling each epoch for train; validation always uses
        # the same (epoch-0) windows so losses are comparable across epochs.
        train_dataset.set_epoch_seed(epoch * 10000)
        val_dataset.set_epoch_seed(0)
        if train_sampler is not None:
            train_sampler.set_epoch(epoch)

        for batch_idx, (ori_batch_x, _) in enumerate(train_loader):
            ori_batch_x = ori_batch_x.squeeze(0).to(device, non_blocking=True)

            current_batch_total_loss = 0.0
            # Gradient accumulation: split the batch into accumulation_steps
            # equal micro-batches and accumulate scaled gradients.
            # NOTE(review): with batch sizes not divisible by accumulation_steps
            # the trailing remainder rows are silently dropped.
            for j in range(accumulation_steps):
                start_idx = j * (ori_batch_x.shape[0] // accumulation_steps)
                end_idx = (j + 1) * (ori_batch_x.shape[0] // accumulation_steps)
                batch_x = ori_batch_x[start_idx:end_idx]

                zs, bsq_loss, _, _ = (model.module if use_ddp else model)(batch_x)
                z_pre, z = zs

                # Reconstruction loss on both the pre-quantization and final outputs.
                recon_loss_pre = F.mse_loss(z_pre, batch_x)
                recon_loss_all = F.mse_loss(z, batch_x)
                recon_loss = recon_loss_pre + recon_loss_all
                loss = (recon_loss + bsq_loss) / 2

                # Scale so accumulated gradients match a full-batch step.
                loss_scaled = loss / accumulation_steps
                current_batch_total_loss += loss.item()
                loss_scaled.backward()

            # One optimizer/scheduler step per outer batch, after accumulation.
            torch.nn.utils.clip_grad_norm_((model.module if use_ddp else model).parameters(), max_norm=2.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            if (batch_idx_global + 1) % config.log_interval == 0:
                avg_loss = current_batch_total_loss / accumulation_steps
                lr = optimizer.param_groups[0]["lr"]
                log_msg = (f"[Epoch {epoch+1}/{config.tokenizer_epochs}, Step {batch_idx+1}/{len(train_loader)}] "
                           f"LR: {lr:.6f}, Loss: {avg_loss:.4f}")
                logger.info(log_msg)
                if rank == 0:
                    print(log_msg)

                # These components come from the LAST micro-batch only.
                detail_msg = (f" - VQ Loss: {bsq_loss.item():.4f}\n"
                              f" - Recon Loss Pre: {recon_loss_pre.item():.4f}\n"
                              f" - Recon Loss All: {recon_loss_all.item():.4f}")
                logger.info(detail_msg)
                if rank == 0:
                    print(detail_msg)

            batch_idx_global += 1

        # ---- validation: sample-weighted MSE on the final reconstruction ----
        model.eval()
        tot_val_loss_sum_rank = 0.0
        val_sample_count_rank = 0

        with torch.no_grad():
            for ori_batch_x, _ in val_loader:
                ori_batch_x = ori_batch_x.squeeze(0).to(device, non_blocking=True)
                zs, _, _, _ = (model.module if use_ddp else model)(ori_batch_x)
                _, z = zs
                val_loss_item = F.mse_loss(z, ori_batch_x)

                # Weight by batch size so uneven final batches average correctly.
                tot_val_loss_sum_rank += val_loss_item.item() * ori_batch_x.size(0)
                val_sample_count_rank += ori_batch_x.size(0)

        if use_ddp:
            # Aggregate (loss sum, sample count) across ranks, then average.
            tensor_sum = torch.tensor([tot_val_loss_sum_rank, val_sample_count_rank], dtype=torch.float64, device=device)
            dist.all_reduce(tensor_sum, op=dist.ReduceOp.SUM)
            tot_val_loss_all = tensor_sum[0].item()
            val_count_all = int(tensor_sum[1].item())
            avg_val_loss = (tot_val_loss_all / val_count_all) if val_count_all > 0 else 0.0
        else:
            avg_val_loss = tot_val_loss_sum_rank / val_sample_count_rank if val_sample_count_rank > 0 else 0

        epoch_time = time.time() - epoch_start_time
        epoch_summary = (f"\n--- Epoch {epoch+1}/{config.tokenizer_epochs} Summary ---\n"
                         f"Validation Loss: {avg_val_loss:.4f}\n"
                         f"Epoch Time: {format_time(epoch_time)}\n"
                         # NOTE(review): this reuses epoch_start_time, so "Total
                         # Training Time" always equals the epoch time — a start
                         # time captured BEFORE the epoch loop was likely intended.
                         f"Total Training Time: {format_time(time.time() - epoch_start_time)}\n")
        logger.info(epoch_summary)
        if rank == 0:
            print(epoch_summary)

        # Checkpoint only on improvement; all ranks agree on avg_val_loss
        # (post all_reduce), so only rank 0 needs to write.
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            if rank == 0:
                model_save_path = os.path.join(save_dir, "best_model")
                os.makedirs(model_save_path, exist_ok=True)
                (model.module if use_ddp else model).save_pretrained(model_save_path)
                save_msg = f"Best model saved to: {model_save_path} (validation loss: {best_val_loss:.4f})"
                logger.info(save_msg)
                print(save_msg)

    return best_val_loss
|
|
||||||
|
|
||||||
|
|
||||||
def main():
    """Entry point: fine-tune the Kronos tokenizer from a YAML config.

    Parses ``--config``, builds the tokenizer (pretrained weights, or a
    randomly initialized copy of the pretrained architecture read from its
    ``config.json``), and runs single-device training via ``train_tokenizer``.
    """
    import argparse

    parser = argparse.ArgumentParser(description='Kronos Tokenizer Fine-tuning Training')
    parser.add_argument('--config', type=str, default='config.yaml',
                        help='Configuration file path (default: config.yaml)')
    args = parser.parse_args()

    # Load the experiment configuration once (it was previously loaded twice).
    config = CustomFinetuneConfig(args.config)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    os.makedirs(config.tokenizer_save_path, exist_ok=True)

    log_dir = os.path.join(config.base_save_path, "logs")
    logger = setup_logging(config.exp_name, log_dir, 0)

    set_seed(config.seed)

    # Load the pretrained tokenizer, or rebuild the same architecture with
    # random weights using the checkpoint's config.json hyperparameters.
    if getattr(config, 'pre_trained_tokenizer', True):
        logger.info("Loading pretrained tokenizer...")
        print("Loading pretrained tokenizer...")
        tokenizer = KronosTokenizer.from_pretrained(config.pretrained_tokenizer_path)
    else:
        print("pre_trained_tokenizer=False, randomly initializing Tokenizer architecture")
        import json
        cfg_path = os.path.join(config.pretrained_tokenizer_path, 'config.json')
        with open(cfg_path, 'r') as f:
            arch = json.load(f)
        tokenizer = KronosTokenizer(
            d_in=arch.get('d_in', 6),
            d_model=arch.get('d_model', 256),
            n_heads=arch.get('n_heads', 4),
            ff_dim=arch.get('ff_dim', 512),
            n_enc_layers=arch.get('n_enc_layers', 4),
            n_dec_layers=arch.get('n_dec_layers', 4),
            ffn_dropout_p=arch.get('ffn_dropout_p', 0.0),
            attn_dropout_p=arch.get('attn_dropout_p', 0.0),
            resid_dropout_p=arch.get('resid_dropout_p', 0.0),
            s1_bits=arch.get('s1_bits', 10),
            s2_bits=arch.get('s2_bits', 10),
            beta=arch.get('beta', 0.05),
            gamma0=arch.get('gamma0', 1.0),
            gamma=arch.get('gamma', 1.1),
            zeta=arch.get('zeta', 0.05),
            group_size=arch.get('group_size', 4)
        )
    tokenizer = tokenizer.to(device)

    model_size = get_model_size(tokenizer)
    logger.info(f"Tokenizer parameters: {model_size}")
    print(f"Tokenizer parameters: {model_size}")

    logger.info("=== Training Configuration ===")
    logger.info(f"Data path: {config.data_path}")
    logger.info(f"Lookback window: {config.lookback_window}")
    logger.info(f"Predict window: {config.predict_window}")
    logger.info(f"Batch size: {config.batch_size}")
    logger.info(f"Learning rate: {config.tokenizer_learning_rate}")
    logger.info(f"Training epochs: {config.tokenizer_epochs}")
    logger.info(f"Device: {device}")
    logger.info(f"Distributed training: False")

    logger.info("Starting tokenizer fine-tuning training...")
    print("Starting tokenizer fine-tuning training...")
    best_val_loss = train_tokenizer(tokenizer, device, config, config.tokenizer_save_path, logger)

    final_msg = f"Tokenizer training completed! Best validation loss: {best_val_loss:.4f}\nModel saved to: {config.tokenizer_save_path}"
    logger.info(final_msg)
    print(final_msg)


if __name__ == "__main__":
    main()
|
|
||||||
|
|
||||||
@ -1,361 +0,0 @@
|
|||||||
import os
|
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
import argparse
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
import torch.distributed as dist
|
|
||||||
|
|
||||||
sys.path.append('../')
|
|
||||||
from model import Kronos, KronosTokenizer, KronosPredictor
|
|
||||||
|
|
||||||
from config_loader import CustomFinetuneConfig
|
|
||||||
from finetune_tokenizer import train_tokenizer, set_seed, setup_logging as setup_tokenizer_logging
|
|
||||||
from finetune_base_model import train_model, create_dataloaders, setup_logging as setup_basemodel_logging
|
|
||||||
|
|
||||||
|
|
||||||
class SequentialTrainer:
    """Sequentially fine-tunes the Kronos tokenizer and then the base model.

    Driven by a CustomFinetuneConfig; distributed identity (rank / world_size /
    local_rank) is taken from the launcher's environment variables.
    """

    def __init__(self, config_path: str = None):
        self.config = CustomFinetuneConfig(config_path)
        self.rank = int(os.environ.get("RANK", "0"))
        self.world_size = int(os.environ.get("WORLD_SIZE", "1"))
        self.local_rank = int(os.environ.get("LOCAL_RANK", str(self.config.device_id if hasattr(self.config, 'device_id') else 0)))
        self.device = self._setup_device()

        self.config.print_config_summary()

    def _setup_device(self):
        """Select this rank's CUDA device, falling back to CPU."""
        if self.config.use_cuda and torch.cuda.is_available():
            torch.cuda.set_device(self.local_rank)
            device = torch.device(f"cuda:{self.local_rank}")
        else:
            device = torch.device("cpu")

        if self.rank == 0:
            print(f"Using device: {device} (rank={self.rank}, world_size={self.world_size}, local_rank={self.local_rank})")
        return device

    def _setup_distributed(self):
        """Initialize the process group when launched with world_size > 1."""
        if self.world_size > 1 and torch.cuda.is_available():
            backend = os.environ.get("DIST_BACKEND", "nccl").lower()
            if not dist.is_initialized():
                dist.init_process_group(backend=backend)
            if self.rank == 0:
                print(f"Distributed training initialized: backend={backend}, world_size={self.world_size}")
        else:
            if self.rank == 0:
                print("Distributed training not enabled, using single GPU/CPU training")

    def _check_existing_models(self):
        """Return (tokenizer_exists, basemodel_exists) for the best-model paths."""
        tokenizer_exists = os.path.exists(self.config.tokenizer_best_model_path)
        basemodel_exists = os.path.exists(self.config.basemodel_best_model_path)

        print(f"Tokenizer model exists: {tokenizer_exists}")
        print(f"Basemodel model exists: {basemodel_exists}")

        return tokenizer_exists, basemodel_exists

    def _create_directories(self):
        """Ensure both checkpoint directories exist."""
        os.makedirs(self.config.tokenizer_save_path, exist_ok=True)
        os.makedirs(self.config.basemodel_save_path, exist_ok=True)
        print(f"Created directory: {self.config.tokenizer_save_path}")
        print(f"Created directory: {self.config.basemodel_save_path}")

    def _load_arch(self, checkpoint_path):
        """Read the architecture hyperparameters from a checkpoint's config.json."""
        import json
        cfg_path = os.path.join(checkpoint_path, 'config.json')
        with open(cfg_path, 'r') as f:
            return json.load(f)

    def _build_random_tokenizer(self):
        """Randomly initialize a KronosTokenizer matching the pretrained architecture.

        Deduplicates the identical construction blocks previously inlined in
        both training phases.
        """
        arch = self._load_arch(self.config.pretrained_tokenizer_path)
        return KronosTokenizer(
            d_in=arch.get('d_in', 6),
            d_model=arch.get('d_model', 256),
            n_heads=arch.get('n_heads', 4),
            ff_dim=arch.get('ff_dim', 512),
            n_enc_layers=arch.get('n_enc_layers', 4),
            n_dec_layers=arch.get('n_dec_layers', 4),
            ffn_dropout_p=arch.get('ffn_dropout_p', 0.0),
            attn_dropout_p=arch.get('attn_dropout_p', 0.0),
            resid_dropout_p=arch.get('resid_dropout_p', 0.0),
            s1_bits=arch.get('s1_bits', 10),
            s2_bits=arch.get('s2_bits', 10),
            beta=arch.get('beta', 0.05),
            gamma0=arch.get('gamma0', 1.0),
            gamma=arch.get('gamma', 1.1),
            zeta=arch.get('zeta', 0.05),
            group_size=arch.get('group_size', 4)
        )

    def train_tokenizer_phase(self):
        """Phase 1: fine-tune (or train from scratch) the tokenizer.

        Returns True on completion (including the skip-existing short-circuit).
        """
        print("\n" + "="*60)
        print("Starting Tokenizer Fine-tuning Phase")
        print("="*60)

        tokenizer_exists, _ = self._check_existing_models()
        if tokenizer_exists and self.config.skip_existing:
            print("Tokenizer model already exists, skipping training")
            return True

        log_dir = os.path.join(self.config.base_save_path, "logs")
        logger = setup_tokenizer_logging(self.config.exp_name, log_dir, self.rank)

        set_seed(self.config.seed)

        if getattr(self.config, 'pre_trained_tokenizer', True):
            logger.info("Loading pretrained tokenizer...")
            if self.rank == 0:
                print("Loading pretrained tokenizer...")
            tokenizer = KronosTokenizer.from_pretrained(self.config.pretrained_tokenizer_path)
        else:
            if self.rank == 0:
                print("pre_trained_tokenizer=False, randomly initializing Tokenizer architecture")
            tokenizer = self._build_random_tokenizer()
        tokenizer = tokenizer.to(self.device)

        model_size = sum(p.numel() for p in tokenizer.parameters())
        logger.info(f"Tokenizer parameters: {model_size:,}")
        if self.rank == 0:
            print(f"Tokenizer parameters: {model_size:,}")

        logger.info("=== Training Configuration ===")
        logger.info(f"Data path: {self.config.data_path}")
        logger.info(f"Lookback window: {self.config.lookback_window}")
        logger.info(f"Predict window: {self.config.predict_window}")
        logger.info(f"Batch size: {self.config.batch_size}")
        logger.info(f"Learning rate: {self.config.tokenizer_learning_rate}")
        logger.info(f"Training epochs: {self.config.tokenizer_epochs}")
        logger.info(f"Device: {self.device}")
        # NOTE(review): hard-coded "False" even under DDP launches — confirm intent.
        logger.info(f"Distributed training: False")

        logger.info("Starting tokenizer fine-tuning training...")
        if self.rank == 0:
            print("Starting tokenizer fine-tuning training...")
        start_time = time.time()
        best_val_loss = train_tokenizer(
            tokenizer,
            self.device,
            self.config,
            self.config.tokenizer_save_path,
            logger,
        )
        training_time = time.time() - start_time

        final_msg = f"Tokenizer training completed! Best validation loss: {best_val_loss:.4f}\nTraining time: {training_time/60:.2f} minutes\nModel saved to: {self.config.tokenizer_save_path}"
        logger.info(final_msg)
        if self.rank == 0:
            print(f"\n{final_msg}")

        return True

    def train_basemodel_phase(self):
        """Phase 2: fine-tune the Kronos predictor on top of the tokenizer.

        Raises FileNotFoundError when a fine-tuned tokenizer is required but
        missing; returns True on completion.
        """
        print("\n" + "="*60)
        print("Starting Basemodel Fine-tuning Phase")
        print("="*60)

        if getattr(self.config, 'pre_trained_tokenizer', True):
            if not os.path.exists(self.config.finetuned_tokenizer_path):
                raise FileNotFoundError(f"Fine-tuned tokenizer does not exist: {self.config.finetuned_tokenizer_path}")

        _, basemodel_exists = self._check_existing_models()
        if basemodel_exists and self.config.skip_existing:
            print("Basemodel model already exists, skipping training")
            return True

        log_dir = os.path.join(self.config.base_save_path, "logs")
        logger = setup_basemodel_logging(self.config.exp_name, log_dir, self.rank)

        set_seed(self.config.seed)

        if getattr(self.config, 'pre_trained_tokenizer', True):
            logger.info("Loading fine-tuned tokenizer...")
            if self.rank == 0:
                print("Loading fine-tuned tokenizer...")
            tokenizer = KronosTokenizer.from_pretrained(self.config.finetuned_tokenizer_path)
        else:
            if self.rank == 0:
                print("pre_trained_tokenizer=False, randomly initializing Tokenizer architecture for Predictor training")
            tokenizer = self._build_random_tokenizer()
        tokenizer = tokenizer.to(self.device)

        if getattr(self.config, 'pre_trained_predictor', True):
            logger.info("Loading pretrained predictor...")
            if self.rank == 0:
                print("Loading pretrained predictor...")
            model = Kronos.from_pretrained(self.config.pretrained_predictor_path)
        else:
            if self.rank == 0:
                print("pre_trained_predictor=False, randomly initializing Predictor architecture")
            arch = self._load_arch(self.config.pretrained_predictor_path)
            print("model_config: ", arch)
            model = Kronos(
                s1_bits=arch.get('s1_bits', 10),
                s2_bits=arch.get('s2_bits', 10),
                n_layers=arch.get('n_layers', 12),
                d_model=arch.get('d_model', 832),
                n_heads=arch.get('n_heads', 16),
                ff_dim=arch.get('ff_dim', 2048),
                ffn_dropout_p=arch.get('ffn_dropout_p', 0.2),
                attn_dropout_p=arch.get('attn_dropout_p', 0.0),
                resid_dropout_p=arch.get('resid_dropout_p', 0.2),
                token_dropout_p=arch.get('token_dropout_p', 0.0),
                learn_te=arch.get('learn_te', True)
            )
        model = model.to(self.device)

        model_size = sum(p.numel() for p in model.parameters())
        logger.info(f"Model parameters: {model_size:,}")
        if self.rank == 0:
            print(f"Model parameters: {model_size:,}")

        logger.info("=== Training Configuration ===")
        logger.info(f"Data path: {self.config.data_path}")
        logger.info(f"Lookback window: {self.config.lookback_window}")
        logger.info(f"Predict window: {self.config.predict_window}")
        logger.info(f"Batch size: {self.config.batch_size}")
        logger.info(f"Learning rate: {self.config.predictor_learning_rate}")
        logger.info(f"Training epochs: {self.config.basemodel_epochs}")
        logger.info(f"Device: {self.device}")
        logger.info(f"Tokenizer path: {self.config.finetuned_tokenizer_path}")
        logger.info(f"Pretrained model path: {self.config.pretrained_predictor_path}")

        logger.info("Starting fine-tuning training...")
        if self.rank == 0:
            print("Starting fine-tuning training...")
        start_time = time.time()
        best_val_loss = train_model(
            model,
            tokenizer,
            self.device,
            self.config,
            self.config.basemodel_save_path,
            logger,
        )
        training_time = time.time() - start_time

        final_msg = f"Basemodel training completed! Best validation loss: {best_val_loss:.4f}\nTraining time: {training_time/60:.2f} minutes\nModel saved to: {self.config.basemodel_save_path}"
        logger.info(final_msg)
        if self.rank == 0:
            print(f"\n{final_msg}")

        return True

    def run_training(self):
        """Run both phases per config flags; returns True on overall success."""
        if self.rank == 0:
            print("Starting Kronos model sequential fine-tuning training")
            print(f"Experiment name: {self.config.experiment_name}")
            print(f"Experiment description: {self.config.experiment_description}")

        self._setup_distributed()

        self._create_directories()

        tokenizer_exists, basemodel_exists = self._check_existing_models()

        total_start_time = time.time()

        try:
            if self.config.train_tokenizer:
                success = self.train_tokenizer_phase()
                if not success:
                    print("Tokenizer training failed, terminating training")
                    return False
            else:
                print("Skipping Tokenizer training phase")

            if self.config.train_basemodel:
                success = self.train_basemodel_phase()
                if not success:
                    print("Basemodel training failed, terminating training")
                    return False
            else:
                print("Skipping Basemodel training phase")

            total_time = time.time() - total_start_time

            if self.rank == 0:
                print("\n" + "="*60)
                print("Training completed!")
                print("="*60)
                print(f"Total training time: {total_time/60:.2f} minutes")
                print(f"Tokenizer model: {self.config.tokenizer_best_model_path}")
                print(f"Basemodel model: {self.config.basemodel_best_model_path}")
                print("="*60)

            return True

        except Exception as e:
            if self.rank == 0:
                print(f"Error occurred during training: {str(e)}")
                import traceback
                traceback.print_exc()
            return False

        finally:
            pass
|
|
||||||
|
|
||||||
|
|
||||||
def main():
    """CLI entry point for sequential fine-tuning (tokenizer, then basemodel).

    CLI flags override the YAML config; exits 0 on success, 1 on failure.
    """
    parser = argparse.ArgumentParser(description='Kronos Model Sequential Fine-tuning Training')
    parser.add_argument('--config', type=str, default='config.yaml',
                        help='Configuration file path (default: config.yaml)')
    parser.add_argument('--skip-tokenizer', action='store_true',
                        help='Skip tokenizer training phase')
    parser.add_argument('--skip-basemodel', action='store_true',
                        help='Skip basemodel training phase')
    parser.add_argument('--skip-existing', action='store_true',
                        help='Skip training for existing models')

    args = parser.parse_args()

    trainer = SequentialTrainer(args.config)

    # Command-line switches take precedence over the config file.
    if args.skip_tokenizer:
        trainer.config.train_tokenizer = False
    if args.skip_basemodel:
        trainer.config.train_basemodel = False
    if args.skip_existing:
        trainer.config.skip_existing = True

    success = trainer.run_training()

    def _cleanup_distributed():
        # Best-effort teardown, shared by both exit paths (the original
        # duplicated this block and only guarded the failure path).
        if dist.is_available() and dist.is_initialized():
            try:
                dist.barrier()
                dist.destroy_process_group()
            except Exception:
                pass

    if success:
        print("Training completed successfully!")
        _cleanup_distributed()
        sys.exit(0)
    else:
        print("Training failed!")
        _cleanup_distributed()
        sys.exit(1)


if __name__ == "__main__":
    main()
|
|
||||||
419
trader.py
Normal file
@ -0,0 +1,419 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
BTCUSD 15m 多次采样预测 → 区间聚合可视化
|
||||||
|
Author: 你(鹅)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import math
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import random
|
||||||
|
import argparse
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from pybit.unified_trading import HTTP
|
||||||
|
|
||||||
|
# ==== Kronos ====
|
||||||
|
from model import Kronos, KronosTokenizer, KronosPredictor
|
||||||
|
|
||||||
|
# ========== Matplotlib fonts ==========
# Prefer CJK-capable fonts so the Chinese labels/legends render correctly.
plt.rcParams["font.sans-serif"] = ["Microsoft YaHei", "SimHei", "Arial"]
plt.rcParams["axes.unicode_minus"] = False

# ========== Default parameters ==========
NUM_SAMPLES = 20  # number of sampling passes (10~50 recommended)
QUANTILES = [0.1, 0.5, 0.9]  # forecast-interval quantiles (lower / median / upper)
LOOKBACK = 360  # history window length in bars
PRED_LEN = 96  # number of future bars to predict
INTERVAL_MIN = 15  # kline interval in minutes
DEVICE = os.getenv("KRONOS_DEVICE", "cuda:0")  # overridable via KRONOS_DEVICE env var
TEMPERATURE = 1.0  # sampling temperature
TOP_P = 0.9  # nucleus-sampling threshold

OUTPUT_DIR = "figures"  # directory for saved figures
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
# ========== 实用函数 ==========
|
||||||
|
def set_global_seed(seed: int):
    """Pin randomness as far as possible (random / numpy / torch samplers)."""
    random.seed(seed)
    np.random.seed(seed)
    # torch is optional here: seed it only if it is importable and usable.
    try:
        import torch

        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
    except Exception:
        pass
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_kline(
    session: HTTP, symbol="BTCUSD", interval="15", limit=1000, category="inverse"
):
    """Fetch klines from Bybit and return them oldest-first as floats.

    Adds an ``amount`` column mirroring ``turnover``; raises RuntimeError on a
    non-zero API return code.
    """
    print("正在获取K线数据...")
    resp = session.get_kline(
        category=category, symbol=symbol, interval=interval, limit=limit
    )
    if resp.get("retCode", -1) != 0:
        raise RuntimeError(f"获取数据失败: {resp.get('retMsg')}")

    columns = ["timestamps", "open", "high", "low", "close", "volume", "turnover"]
    frame = pd.DataFrame(resp["result"]["list"], columns=columns)
    # The API returns millisecond epoch strings; convert and coerce numerics.
    frame["timestamps"] = pd.to_datetime(frame["timestamps"].astype(float), unit="ms")
    numeric_cols = columns[1:]
    frame[numeric_cols] = frame[numeric_cols].astype(float)
    frame = frame.sort_values("timestamps").reset_index(drop=True)
    frame["amount"] = frame["turnover"]
    print(f"获取到 {len(frame)} 根K线数据")
    return frame
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_io_windows(
    df: pd.DataFrame, lookback=LOOKBACK, pred_len=PRED_LEN, interval_min=INTERVAL_MIN
):
    """Split df into the model's input window and the future timestamp axis.

    Returns (x_df, x_timestamp, y_timestamp, start_idx, end_idx) where x_df
    holds the last `lookback` feature rows and y_timestamp covers `pred_len`
    future bars spaced `interval_min` minutes apart.
    """
    end_idx = len(df)
    start_idx = max(0, end_idx - lookback)

    feature_cols = ["open", "high", "low", "close", "volume", "amount"]
    window = df.loc[start_idx : end_idx - 1]
    x_df = window[feature_cols].reset_index(drop=True)
    x_timestamp = window["timestamps"].reset_index(drop=True)

    # Future axis starts one bar after the last observed timestamp.
    last_ts = df.loc[end_idx - 1, "timestamps"]
    step = pd.Timedelta(minutes=interval_min)
    future_index = pd.date_range(
        start=last_ts + step, periods=pred_len, freq=f"{interval_min}min"
    )
    y_timestamp = pd.Series(future_index)
    y_timestamp.index = range(len(y_timestamp))

    print(f"数据总量: {len(df)} 根K线")
    print(f"使用最新的 {lookback} 根K线(索引 {start_idx} 到 {end_idx-1})")
    print(f"最后一根历史K线时间: {last_ts}")
    print(f"预测未来 {pred_len} 根K线")
    return x_df, x_timestamp, y_timestamp, start_idx, end_idx
|
||||||
|
|
||||||
|
|
||||||
|
def load_kronos(device=DEVICE):
    """Build a KronosPredictor from the published pretrained checkpoints."""
    print("正在加载 Kronos 模型...")
    tok = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
    base_model = Kronos.from_pretrained("NeoQuasar/Kronos-base")
    predictor = KronosPredictor(base_model, tok, device=device, max_context=512)
    print("模型加载完成!\n")
    return predictor
|
||||||
|
|
||||||
|
|
||||||
|
def run_one_prediction(
    predictor,
    x_df,
    x_timestamp,
    y_timestamp,
    T=TEMPERATURE,
    top_p=TOP_P,
    seed=None,
    verbose=False,
):
    """Draw one sampled forecast; reseeds all RNGs when `seed` is given."""
    if seed is not None:
        set_global_seed(seed)
    forecast_kwargs = dict(
        df=x_df,
        x_timestamp=x_timestamp,
        y_timestamp=y_timestamp,
        pred_len=len(y_timestamp),
        T=T,
        top_p=top_p,
        verbose=verbose,
    )
    return predictor.predict(**forecast_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def run_multi_predictions(
    predictor, x_df, x_timestamp, y_timestamp, num_samples=NUM_SAMPLES, base_seed=42,
    temperature=TEMPERATURE, top_p=TOP_P
):
    """Sample `num_samples` forecasts with distinct seeds.

    Returns a list of DataFrames trimmed to the expected OHLCV(+amount)
    columns, each re-indexed from zero.
    """
    wanted_cols = ["open", "high", "low", "close", "volume", "amount"]
    samples = []
    print(f"正在进行多次预测:{num_samples} 次(T={temperature}, top_p={top_p})...")
    for draw in tqdm(range(num_samples), desc="预测进度", unit="次", ncols=100):
        pred_df = run_one_prediction(
            predictor,
            x_df,
            x_timestamp,
            y_timestamp,
            T=temperature,
            top_p=top_p,
            seed=base_seed + draw,  # distinct seed per draw
            verbose=False,
        )
        # Compatibility: keep only the expected columns that are present.
        trimmed = pred_df[[c for c in wanted_cols if c in pred_df.columns]].copy()
        trimmed.reset_index(drop=True, inplace=True)
        samples.append(trimmed)
    print("多次预测完成。\n")
    return samples
|
||||||
|
|
||||||
|
|
||||||
|
def aggregate_quantiles(pred_list, quantiles=QUANTILES):
    """
    Collapse a list of sampled forecasts into per-column quantile series.

    Output columns are named `<col>_q10`, `<col>_q50`, `<col>_q90`
    (driven by the values in `quantiles`).
    """
    columns = pred_list[0].columns.tolist()
    out = {}
    for col in columns:
        # time x samples matrix for this feature
        stacked = np.stack(
            [sample[col].values.astype(float) for sample in pred_list], axis=1
        )
        for q in quantiles:
            out[f"{col}_q{int(q*100):02d}"] = np.quantile(stacked, q, axis=1)
    return pd.DataFrame(out)
|
||||||
|
|
||||||
|
|
||||||
|
def plot_results(historical_df, y_timestamp, agg_df, title_prefix="BTCUSD 15分钟"):
    """
    Two stacked subplots sharing the X axis:
    - Top (height 3): historical close + forecast close band (q10~q90) + median line (q50)
    - Bottom (height 1): historical volume bars + forecast median volume bars (q50 only, no band)

    Returns the matplotlib Figure.
    """
    import matplotlib.gridspec as gridspec

    # Infer sensible bar widths from the observed timestamp spacing.
    try:
        if len(historical_df) >= 2:
            hist_step = (historical_df["timestamps"].iloc[-1] - historical_df["timestamps"].iloc[-2])
        else:
            hist_step = pd.Timedelta(minutes=10)
        if len(y_timestamp) >= 2:
            pred_step = (y_timestamp.iloc[1] - y_timestamp.iloc[0])
        else:
            pred_step = pd.Timedelta(minutes=10)
        bar_width_hist = hist_step * 0.8
        bar_width_pred = pred_step * 0.8
    except Exception:
        bar_width_hist = pd.Timedelta(minutes=10)
        bar_width_pred = pd.Timedelta(minutes=10)

    # Pull the forecast quantile columns (None when absent from agg_df).
    close_q10 = agg_df["close_q10"] if "close_q10" in agg_df else None
    close_q50 = agg_df["close_q50"] if "close_q50" in agg_df else None
    close_q90 = agg_df["close_q90"] if "close_q90" in agg_df else None

    vol_q50 = agg_df["volume_q50"] if "volume_q50" in agg_df else None

    # Figure and grid layout (price panel 3x taller than volume panel).
    fig = plt.figure(figsize=(18, 10))
    gs = gridspec.GridSpec(2, 1, height_ratios=[3, 1], hspace=0.08)

    # ===== Top: close price =====
    ax_price = fig.add_subplot(gs[0])

    ax_price.plot(
        historical_df["timestamps"],
        historical_df["close"],
        label="历史收盘价",
        linewidth=1.8,
    )

    # Forecast close band and median line.
    if close_q10 is not None and close_q90 is not None:
        ax_price.fill_between(
            y_timestamp.values, close_q10, close_q90, alpha=0.25, label="预测收盘区间(q10~q90)"
        )
    if close_q50 is not None:
        ax_price.plot(
            y_timestamp.values, close_q50, linestyle="--", linewidth=2, label="预测收盘中位线(q50)"
        )

    # Forecast start marker and shaded forecast region.
    if len(y_timestamp) > 0:
        ax_price.axvline(x=y_timestamp.iloc[0], color="gray", linestyle=":", linewidth=1, label="预测起点")
        ax_price.axvspan(y_timestamp.iloc[0], y_timestamp.iloc[-1], color="yellow", alpha=0.08)

    # Draw a vertical dashed line at 16:00 each day and annotate the close.
    # Combine historical and forecast timestamps/prices into one series
    # (forecast prices use the q50 median; NaN when no median is available).
    all_timestamps = pd.concat([historical_df["timestamps"], y_timestamp], ignore_index=True)
    hist_close = historical_df["close"]
    pred_close = close_q50 if close_q50 is not None else pd.Series([np.nan] * len(y_timestamp))
    all_close = pd.concat([hist_close.reset_index(drop=True), pred_close.reset_index(drop=True)], ignore_index=True)

    if len(all_timestamps) > 0:
        start_time = all_timestamps.min()
        end_time = all_timestamps.max()

        # Generate every daily 16:00 marker within the plotted range.
        # NOTE(review): 16:00 is interpreted in the timestamps' own frame
        # (comments suggest UTC+8) — confirm against the data source.
        current_date = start_time.normalize()  # midnight of that day
        while current_date <= end_time:
            target_time = current_date + pd.Timedelta(hours=16)  # 16:00
            if start_time <= target_time <= end_time:
                # Dashed marker line.
                ax_price.axvline(x=target_time, color='blue', linestyle='--', linewidth=0.8, alpha=0.5)

                # Find the bar closest to 16:00 and its close price.
                time_diffs = (all_timestamps - target_time).abs()
                closest_idx = time_diffs.idxmin()
                closest_time = all_timestamps.iloc[closest_idx]
                closest_price = all_close.iloc[closest_idx]

                # Annotate only when within one hour of 16:00 and the price is known.
                if time_diffs.iloc[closest_idx] <= pd.Timedelta(hours=1) and not np.isnan(closest_price):
                    ax_price.plot(closest_time, closest_price, 'o', color='blue', markersize=6, alpha=0.7)
                    ax_price.text(
                        closest_time, closest_price,
                        f' ${closest_price:.1f}',
                        fontsize=9,
                        color='blue',
                        verticalalignment='bottom',
                        horizontalalignment='left',
                        bbox=dict(boxstyle='round,pad=0.3', facecolor='white', edgecolor='blue', alpha=0.7)
                    )
            current_date += pd.Timedelta(days=1)

    ax_price.set_ylabel("价格 (USD)", fontsize=11)
    ax_price.set_title(f"{title_prefix} - 收盘价 & 成交量(历史 + 预测)", fontsize=15, fontweight="bold")
    ax_price.grid(True, alpha=0.3)
    ax_price.legend(loc="best", fontsize=9)

    # ===== Bottom: volume (forecast shows only the median bars) =====
    ax_vol = fig.add_subplot(gs[1], sharex=ax_price)

    # Historical volume bars.
    ax_vol.bar(
        historical_df["timestamps"],
        historical_df["volume"],
        width=bar_width_hist,
        alpha=0.35,
        label="历史成交量",
    )

    # Forecast median volume bars (no interval band).
    if vol_q50 is not None:
        ax_vol.bar(
            y_timestamp.values,
            vol_q50,
            width=bar_width_pred,
            alpha=0.6,
            label="预测成交量中位(q50)",
        )

    # Forecast start marker.
    if len(y_timestamp) > 0:
        ax_vol.axvline(x=y_timestamp.iloc[0], color="gray", linestyle=":", linewidth=1)

    ax_vol.set_ylabel("成交量", fontsize=11)
    ax_vol.set_xlabel("时间", fontsize=11)
    ax_vol.grid(True, alpha=0.25)
    ax_vol.legend(loc="best", fontsize=9)

    # Hide the top panel's x tick labels to avoid overlap with the bottom panel.
    plt.setp(ax_price.get_xticklabels(), visible=False)
    plt.tight_layout()
    return fig
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="Kronos 多次预测区间可视化")
|
||||||
|
parser.add_argument("--symbol", default="BTCUSD")
|
||||||
|
parser.add_argument("--category", default="inverse")
|
||||||
|
parser.add_argument("--interval", default="15")
|
||||||
|
parser.add_argument("--limit", type=int, default=1000)
|
||||||
|
parser.add_argument("--lookback", type=int, default=LOOKBACK)
|
||||||
|
parser.add_argument("--pred_len", type=int, default=PRED_LEN)
|
||||||
|
parser.add_argument("--samples", type=int, default=NUM_SAMPLES)
|
||||||
|
parser.add_argument("--temperature", type=float, default=TEMPERATURE)
|
||||||
|
parser.add_argument("--top_p", type=float, default=TOP_P)
|
||||||
|
parser.add_argument("--quantiles", default="0.1,0.5,0.9")
|
||||||
|
parser.add_argument("--device", default=DEVICE)
|
||||||
|
parser.add_argument("--seed", type=int, default=42)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# 载入env & HTTP
|
||||||
|
load_dotenv()
|
||||||
|
api_key = os.getenv("BYBIT_API_KEY")
|
||||||
|
api_secret = os.getenv("BYBIT_API_SECRET")
|
||||||
|
session = HTTP(testnet=False, api_key=api_key, api_secret=api_secret)
|
||||||
|
# Kronos
|
||||||
|
predictor = load_kronos(device=args.device)
|
||||||
|
|
||||||
|
# 拉数据
|
||||||
|
df = fetch_kline(
|
||||||
|
session,
|
||||||
|
symbol=args.symbol,
|
||||||
|
interval=args.interval,
|
||||||
|
limit=args.limit,
|
||||||
|
category=args.category,
|
||||||
|
)
|
||||||
|
x_df, x_ts, y_ts, start_idx, end_idx = prepare_io_windows(
|
||||||
|
df,
|
||||||
|
lookback=args.lookback,
|
||||||
|
pred_len=args.pred_len,
|
||||||
|
interval_min=int(args.interval),
|
||||||
|
)
|
||||||
|
|
||||||
|
# 历史用于画图的窗口(最近200根,可按需调整)
|
||||||
|
plot_start = max(0, end_idx - 200)
|
||||||
|
historical_df = df.loc[plot_start : end_idx - 1].copy()
|
||||||
|
|
||||||
|
# 采样参数(直接使用局部变量,不需要修改全局变量)
|
||||||
|
temperature = args.temperature
|
||||||
|
top_p = args.top_p
|
||||||
|
qs = [float(x) for x in args.quantiles.split(",") if x.strip()]
|
||||||
|
|
||||||
|
# 多次预测(传入局部变量)
|
||||||
|
preds = run_multi_predictions(
|
||||||
|
predictor, x_df, x_ts, y_ts, num_samples=args.samples, base_seed=args.seed,
|
||||||
|
temperature=temperature, top_p=top_p
|
||||||
|
)
|
||||||
|
|
||||||
|
# 聚合分位
|
||||||
|
agg_df = aggregate_quantiles(preds, quantiles=qs)
|
||||||
|
agg_df["timestamps"] = y_ts.values
|
||||||
|
|
||||||
|
# 输出 CSV(聚合)
|
||||||
|
out_csv = os.path.join(
|
||||||
|
OUTPUT_DIR, f"{args.symbol.lower()}_interval_pred_{args.interval}m.csv"
|
||||||
|
)
|
||||||
|
agg_df.to_csv(out_csv, index=False)
|
||||||
|
print(f"预测分位数据已保存到: {out_csv}")
|
||||||
|
|
||||||
|
# 作图
|
||||||
|
fig = plot_results(
|
||||||
|
historical_df, y_ts, agg_df, title_prefix=f"{args.symbol} {args.interval}分钟"
|
||||||
|
)
|
||||||
|
out_png = os.path.join(
|
||||||
|
OUTPUT_DIR, f"{args.symbol.lower()}_interval_pred_{args.interval}m.png"
|
||||||
|
)
|
||||||
|
fig.savefig(out_png, dpi=150, bbox_inches="tight")
|
||||||
|
print(f"预测区间图表已保存到: {out_png}")
|
||||||
|
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
try:
|
||||||
|
main()
|
||||||
|
except Exception as e:
|
||||||
|
print("运行出错:", repr(e))
|
||||||
|
sys.exit(1)
|
||||||
135
webui/README.md
@ -1,135 +0,0 @@
|
|||||||
# Kronos Web UI
|
|
||||||
|
|
||||||
Web user interface for Kronos financial prediction model, providing intuitive graphical operation interface.
|
|
||||||
|
|
||||||
## ✨ Features
|
|
||||||
|
|
||||||
- **Multi-format data support**: Supports CSV, Feather and other financial data formats
|
|
||||||
- **Smart time window**: Fixed 400+120 data point time window slider selection
|
|
||||||
- **Real model prediction**: Integrated real Kronos model, supports multiple model sizes
|
|
||||||
- **Prediction quality control**: Adjustable temperature, nucleus sampling, sample count and other parameters
|
|
||||||
- **Multi-device support**: Supports CPU, CUDA, MPS and other computing devices
|
|
||||||
- **Comparison analysis**: Detailed comparison between prediction results and actual data
|
|
||||||
- **K-line chart display**: Professional financial K-line chart display
|
|
||||||
|
|
||||||
## 🚀 Quick Start
|
|
||||||
|
|
||||||
### Method 1: Start with Python script
|
|
||||||
```bash
|
|
||||||
cd webui
|
|
||||||
python run.py
|
|
||||||
```
|
|
||||||
|
|
||||||
### Method 2: Start with Shell script
|
|
||||||
```bash
|
|
||||||
cd webui
|
|
||||||
chmod +x start.sh
|
|
||||||
./start.sh
|
|
||||||
```
|
|
||||||
|
|
||||||
### Method 3: Start Flask application directly
|
|
||||||
```bash
|
|
||||||
cd webui
|
|
||||||
python app.py
|
|
||||||
```
|
|
||||||
|
|
||||||
After successful startup, visit http://localhost:7070
|
|
||||||
|
|
||||||
## 📋 Usage Steps
|
|
||||||
|
|
||||||
1. **Load data**: Select financial data file from data directory
|
|
||||||
2. **Load model**: Select Kronos model and computing device
|
|
||||||
3. **Set parameters**: Adjust prediction quality parameters
|
|
||||||
4. **Select time window**: Use slider to select 400+120 data point time range
|
|
||||||
5. **Start prediction**: Click prediction button to generate results
|
|
||||||
6. **View results**: View prediction results in charts and tables
|
|
||||||
|
|
||||||
## 🔧 Prediction Quality Parameters
|
|
||||||
|
|
||||||
### Temperature (T)
|
|
||||||
- **Range**: 0.1 - 2.0
|
|
||||||
- **Effect**: Controls prediction randomness
|
|
||||||
- **Recommendation**: 1.2-1.5 for better prediction quality
|
|
||||||
|
|
||||||
### Nucleus Sampling (top_p)
|
|
||||||
- **Range**: 0.1 - 1.0
|
|
||||||
- **Effect**: Controls prediction diversity
|
|
||||||
- **Recommendation**: 0.95-1.0 to consider more possibilities
|
|
||||||
|
|
||||||
### Sample Count
|
|
||||||
- **Range**: 1 - 5
|
|
||||||
- **Effect**: Generate multiple prediction samples
|
|
||||||
- **Recommendation**: 2-3 samples to improve quality
|
|
||||||
|
|
||||||
## 📊 Supported Data Formats
|
|
||||||
|
|
||||||
### Required Columns
|
|
||||||
- `open`: Opening price
|
|
||||||
- `high`: Highest price
|
|
||||||
- `low`: Lowest price
|
|
||||||
- `close`: Closing price
|
|
||||||
|
|
||||||
### Optional Columns
|
|
||||||
- `volume`: Trading volume
|
|
||||||
- `amount`: Trading amount (not used for prediction)
|
|
||||||
- `timestamps`/`timestamp`/`date`: Timestamp
|
|
||||||
|
|
||||||
## 🤖 Model Support
|
|
||||||
|
|
||||||
- **Kronos-mini**: 4.1M parameters, lightweight fast prediction
|
|
||||||
- **Kronos-small**: 24.7M parameters, balanced performance and speed
|
|
||||||
- **Kronos-base**: 102.3M parameters, high quality prediction
|
|
||||||
|
|
||||||
## 🖥️ GPU Acceleration Support
|
|
||||||
|
|
||||||
- **CPU**: General computing, best compatibility
|
|
||||||
- **CUDA**: NVIDIA GPU acceleration, best performance
|
|
||||||
- **MPS**: Apple Silicon GPU acceleration, recommended for Mac users
|
|
||||||
|
|
||||||
## ⚠️ Notes
|
|
||||||
|
|
||||||
- `amount` column is not used for prediction, only for display
|
|
||||||
- Time window is fixed at 400+120=520 data points
|
|
||||||
- Ensure data file contains sufficient historical data
|
|
||||||
- First model loading may require download, please be patient
|
|
||||||
|
|
||||||
## 🔍 Comparison Analysis
|
|
||||||
|
|
||||||
The system automatically provides comparison analysis between prediction results and actual data, including:
|
|
||||||
- Price difference statistics
|
|
||||||
- Error analysis
|
|
||||||
- Prediction quality assessment
|
|
||||||
|
|
||||||
## 🛠️ Technical Architecture
|
|
||||||
|
|
||||||
- **Backend**: Flask + Python
|
|
||||||
- **Frontend**: HTML + CSS + JavaScript
|
|
||||||
- **Charts**: Plotly.js
|
|
||||||
- **Data processing**: Pandas + NumPy
|
|
||||||
- **Model**: Hugging Face Transformers
|
|
||||||
|
|
||||||
## 📝 Troubleshooting
|
|
||||||
|
|
||||||
### Common Issues
|
|
||||||
1. **Port occupied**: Modify port number in app.py
|
|
||||||
2. **Missing dependencies**: Run `pip install -r requirements.txt`
|
|
||||||
3. **Model loading failed**: Check network connection and model ID
|
|
||||||
4. **Data format error**: Ensure data column names and format are correct
|
|
||||||
|
|
||||||
### Log Viewing
|
|
||||||
Detailed runtime information will be displayed in the console at startup, including model status and error messages.
|
|
||||||
|
|
||||||
## 📄 License
|
|
||||||
|
|
||||||
This project follows the license terms of the original Kronos project.
|
|
||||||
|
|
||||||
## 🤝 Contributing
|
|
||||||
|
|
||||||
Welcome to submit Issues and Pull Requests to improve this Web UI!
|
|
||||||
|
|
||||||
## 📞 Support
|
|
||||||
|
|
||||||
If you have questions, please check:
|
|
||||||
1. Project documentation
|
|
||||||
2. GitHub Issues
|
|
||||||
3. Console error messages
|
|
||||||
708
webui/app.py
@ -1,708 +0,0 @@
|
|||||||
import os
|
|
||||||
import pandas as pd
|
|
||||||
import numpy as np
|
|
||||||
import json
|
|
||||||
import plotly.graph_objects as go
|
|
||||||
import plotly.utils
|
|
||||||
from flask import Flask, render_template, request, jsonify
|
|
||||||
from flask_cors import CORS
|
|
||||||
import sys
|
|
||||||
import warnings
|
|
||||||
import datetime
|
|
||||||
warnings.filterwarnings('ignore')
|
|
||||||
|
|
||||||
# Add project root directory to path
|
|
||||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
||||||
|
|
||||||
try:
|
|
||||||
from model import Kronos, KronosTokenizer, KronosPredictor
|
|
||||||
MODEL_AVAILABLE = True
|
|
||||||
except ImportError:
|
|
||||||
MODEL_AVAILABLE = False
|
|
||||||
print("Warning: Kronos model cannot be imported, will use simulated data for demonstration")
|
|
||||||
|
|
||||||
app = Flask(__name__)
|
|
||||||
CORS(app)
|
|
||||||
|
|
||||||
# Global variables to store models
|
|
||||||
tokenizer = None
|
|
||||||
model = None
|
|
||||||
predictor = None
|
|
||||||
|
|
||||||
# Available model configurations
|
|
||||||
AVAILABLE_MODELS = {
|
|
||||||
'kronos-mini': {
|
|
||||||
'name': 'Kronos-mini',
|
|
||||||
'model_id': 'NeoQuasar/Kronos-mini',
|
|
||||||
'tokenizer_id': 'NeoQuasar/Kronos-Tokenizer-2k',
|
|
||||||
'context_length': 2048,
|
|
||||||
'params': '4.1M',
|
|
||||||
'description': 'Lightweight model, suitable for fast prediction'
|
|
||||||
},
|
|
||||||
'kronos-small': {
|
|
||||||
'name': 'Kronos-small',
|
|
||||||
'model_id': 'NeoQuasar/Kronos-small',
|
|
||||||
'tokenizer_id': 'NeoQuasar/Kronos-Tokenizer-base',
|
|
||||||
'context_length': 512,
|
|
||||||
'params': '24.7M',
|
|
||||||
'description': 'Small model, balanced performance and speed'
|
|
||||||
},
|
|
||||||
'kronos-base': {
|
|
||||||
'name': 'Kronos-base',
|
|
||||||
'model_id': 'NeoQuasar/Kronos-base',
|
|
||||||
'tokenizer_id': 'NeoQuasar/Kronos-Tokenizer-base',
|
|
||||||
'context_length': 512,
|
|
||||||
'params': '102.3M',
|
|
||||||
'description': 'Base model, provides better prediction quality'
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
def load_data_files():
|
|
||||||
"""Scan data directory and return available data files"""
|
|
||||||
data_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'data')
|
|
||||||
data_files = []
|
|
||||||
|
|
||||||
if os.path.exists(data_dir):
|
|
||||||
for file in os.listdir(data_dir):
|
|
||||||
if file.endswith(('.csv', '.feather')):
|
|
||||||
file_path = os.path.join(data_dir, file)
|
|
||||||
file_size = os.path.getsize(file_path)
|
|
||||||
data_files.append({
|
|
||||||
'name': file,
|
|
||||||
'path': file_path,
|
|
||||||
'size': f"{file_size / 1024:.1f} KB" if file_size < 1024*1024 else f"{file_size / (1024*1024):.1f} MB"
|
|
||||||
})
|
|
||||||
|
|
||||||
return data_files
|
|
||||||
|
|
||||||
def load_data_file(file_path):
|
|
||||||
"""Load data file"""
|
|
||||||
try:
|
|
||||||
if file_path.endswith('.csv'):
|
|
||||||
df = pd.read_csv(file_path)
|
|
||||||
elif file_path.endswith('.feather'):
|
|
||||||
df = pd.read_feather(file_path)
|
|
||||||
else:
|
|
||||||
return None, "Unsupported file format"
|
|
||||||
|
|
||||||
# Check required columns
|
|
||||||
required_cols = ['open', 'high', 'low', 'close']
|
|
||||||
if not all(col in df.columns for col in required_cols):
|
|
||||||
return None, f"Missing required columns: {required_cols}"
|
|
||||||
|
|
||||||
# Process timestamp column
|
|
||||||
if 'timestamps' in df.columns:
|
|
||||||
df['timestamps'] = pd.to_datetime(df['timestamps'])
|
|
||||||
elif 'timestamp' in df.columns:
|
|
||||||
df['timestamps'] = pd.to_datetime(df['timestamp'])
|
|
||||||
elif 'date' in df.columns:
|
|
||||||
# If column name is 'date', rename it to 'timestamps'
|
|
||||||
df['timestamps'] = pd.to_datetime(df['date'])
|
|
||||||
else:
|
|
||||||
# If no timestamp column exists, create one
|
|
||||||
df['timestamps'] = pd.date_range(start='2024-01-01', periods=len(df), freq='1H')
|
|
||||||
|
|
||||||
# Ensure numeric columns are numeric type
|
|
||||||
for col in ['open', 'high', 'low', 'close']:
|
|
||||||
df[col] = pd.to_numeric(df[col], errors='coerce')
|
|
||||||
|
|
||||||
# Process volume column (optional)
|
|
||||||
if 'volume' in df.columns:
|
|
||||||
df['volume'] = pd.to_numeric(df['volume'], errors='coerce')
|
|
||||||
|
|
||||||
# Process amount column (optional, but not used for prediction)
|
|
||||||
if 'amount' in df.columns:
|
|
||||||
df['amount'] = pd.to_numeric(df['amount'], errors='coerce')
|
|
||||||
|
|
||||||
# Remove rows containing NaN values
|
|
||||||
df = df.dropna()
|
|
||||||
|
|
||||||
return df, None
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return None, f"Failed to load file: {str(e)}"
|
|
||||||
|
|
||||||
def save_prediction_results(file_path, prediction_type, prediction_results, actual_data, input_data, prediction_params):
|
|
||||||
"""Save prediction results to file"""
|
|
||||||
try:
|
|
||||||
# Create prediction results directory
|
|
||||||
results_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prediction_results')
|
|
||||||
os.makedirs(results_dir, exist_ok=True)
|
|
||||||
|
|
||||||
# Generate filename
|
|
||||||
timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
|
|
||||||
filename = f'prediction_{timestamp}.json'
|
|
||||||
filepath = os.path.join(results_dir, filename)
|
|
||||||
|
|
||||||
# Prepare data for saving
|
|
||||||
save_data = {
|
|
||||||
'timestamp': datetime.datetime.now().isoformat(),
|
|
||||||
'file_path': file_path,
|
|
||||||
'prediction_type': prediction_type,
|
|
||||||
'prediction_params': prediction_params,
|
|
||||||
'input_data_summary': {
|
|
||||||
'rows': len(input_data),
|
|
||||||
'columns': list(input_data.columns),
|
|
||||||
'price_range': {
|
|
||||||
'open': {'min': float(input_data['open'].min()), 'max': float(input_data['open'].max())},
|
|
||||||
'high': {'min': float(input_data['high'].min()), 'max': float(input_data['high'].max())},
|
|
||||||
'low': {'min': float(input_data['low'].min()), 'max': float(input_data['low'].max())},
|
|
||||||
'close': {'min': float(input_data['close'].min()), 'max': float(input_data['close'].max())}
|
|
||||||
},
|
|
||||||
'last_values': {
|
|
||||||
'open': float(input_data['open'].iloc[-1]),
|
|
||||||
'high': float(input_data['high'].iloc[-1]),
|
|
||||||
'low': float(input_data['low'].iloc[-1]),
|
|
||||||
'close': float(input_data['close'].iloc[-1])
|
|
||||||
}
|
|
||||||
},
|
|
||||||
'prediction_results': prediction_results,
|
|
||||||
'actual_data': actual_data,
|
|
||||||
'analysis': {}
|
|
||||||
}
|
|
||||||
|
|
||||||
# If actual data exists, perform comparison analysis
|
|
||||||
if actual_data and len(actual_data) > 0:
|
|
||||||
# Calculate continuity analysis
|
|
||||||
if len(prediction_results) > 0 and len(actual_data) > 0:
|
|
||||||
last_pred = prediction_results[0] # First prediction point
|
|
||||||
first_actual = actual_data[0] # First actual point
|
|
||||||
|
|
||||||
save_data['analysis']['continuity'] = {
|
|
||||||
'last_prediction': {
|
|
||||||
'open': last_pred['open'],
|
|
||||||
'high': last_pred['high'],
|
|
||||||
'low': last_pred['low'],
|
|
||||||
'close': last_pred['close']
|
|
||||||
},
|
|
||||||
'first_actual': {
|
|
||||||
'open': first_actual['open'],
|
|
||||||
'high': first_actual['high'],
|
|
||||||
'low': first_actual['low'],
|
|
||||||
'close': first_actual['close']
|
|
||||||
},
|
|
||||||
'gaps': {
|
|
||||||
'open_gap': abs(last_pred['open'] - first_actual['open']),
|
|
||||||
'high_gap': abs(last_pred['high'] - first_actual['high']),
|
|
||||||
'low_gap': abs(last_pred['low'] - first_actual['low']),
|
|
||||||
'close_gap': abs(last_pred['close'] - first_actual['close'])
|
|
||||||
},
|
|
||||||
'gap_percentages': {
|
|
||||||
'open_gap_pct': (abs(last_pred['open'] - first_actual['open']) / first_actual['open']) * 100,
|
|
||||||
'high_gap_pct': (abs(last_pred['high'] - first_actual['high']) / first_actual['high']) * 100,
|
|
||||||
'low_gap_pct': (abs(last_pred['low'] - first_actual['low']) / first_actual['low']) * 100,
|
|
||||||
'close_gap_pct': (abs(last_pred['close'] - first_actual['close']) / first_actual['close']) * 100
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
# Save to file
|
|
||||||
with open(filepath, 'w', encoding='utf-8') as f:
|
|
||||||
json.dump(save_data, f, indent=2, ensure_ascii=False)
|
|
||||||
|
|
||||||
print(f"Prediction results saved to: {filepath}")
|
|
||||||
return filepath
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Failed to save prediction results: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def create_prediction_chart(df, pred_df, lookback, pred_len, actual_df=None, historical_start_idx=0):
|
|
||||||
"""Create prediction chart"""
|
|
||||||
# Use specified historical data start position, not always from the beginning of df
|
|
||||||
if historical_start_idx + lookback + pred_len <= len(df):
|
|
||||||
# Display lookback historical points + pred_len prediction points starting from specified position
|
|
||||||
historical_df = df.iloc[historical_start_idx:historical_start_idx+lookback]
|
|
||||||
prediction_range = range(historical_start_idx+lookback, historical_start_idx+lookback+pred_len)
|
|
||||||
else:
|
|
||||||
# If data is insufficient, adjust to maximum available range
|
|
||||||
available_lookback = min(lookback, len(df) - historical_start_idx)
|
|
||||||
available_pred_len = min(pred_len, max(0, len(df) - historical_start_idx - available_lookback))
|
|
||||||
historical_df = df.iloc[historical_start_idx:historical_start_idx+available_lookback]
|
|
||||||
prediction_range = range(historical_start_idx+available_lookback, historical_start_idx+available_lookback+available_pred_len)
|
|
||||||
|
|
||||||
# Create chart
|
|
||||||
fig = go.Figure()
|
|
||||||
|
|
||||||
# Add historical data (candlestick chart)
|
|
||||||
fig.add_trace(go.Candlestick(
|
|
||||||
x=historical_df['timestamps'] if 'timestamps' in historical_df.columns else historical_df.index,
|
|
||||||
open=historical_df['open'],
|
|
||||||
high=historical_df['high'],
|
|
||||||
low=historical_df['low'],
|
|
||||||
close=historical_df['close'],
|
|
||||||
name='Historical Data (400 data points)',
|
|
||||||
increasing_line_color='#26A69A',
|
|
||||||
decreasing_line_color='#EF5350'
|
|
||||||
))
|
|
||||||
|
|
||||||
# Add prediction data (candlestick chart)
|
|
||||||
if pred_df is not None and len(pred_df) > 0:
|
|
||||||
# Calculate prediction data timestamps - ensure continuity with historical data
|
|
||||||
if 'timestamps' in df.columns and len(historical_df) > 0:
|
|
||||||
# Start from the last timestamp of historical data, create prediction timestamps with the same time interval
|
|
||||||
last_timestamp = historical_df['timestamps'].iloc[-1]
|
|
||||||
time_diff = df['timestamps'].iloc[1] - df['timestamps'].iloc[0] if len(df) > 1 else pd.Timedelta(hours=1)
|
|
||||||
|
|
||||||
pred_timestamps = pd.date_range(
|
|
||||||
start=last_timestamp + time_diff,
|
|
||||||
periods=len(pred_df),
|
|
||||||
freq=time_diff
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# If no timestamps, use index
|
|
||||||
pred_timestamps = range(len(historical_df), len(historical_df) + len(pred_df))
|
|
||||||
|
|
||||||
fig.add_trace(go.Candlestick(
|
|
||||||
x=pred_timestamps,
|
|
||||||
open=pred_df['open'],
|
|
||||||
high=pred_df['high'],
|
|
||||||
low=pred_df['low'],
|
|
||||||
close=pred_df['close'],
|
|
||||||
name='Prediction Data (120 data points)',
|
|
||||||
increasing_line_color='#66BB6A',
|
|
||||||
decreasing_line_color='#FF7043'
|
|
||||||
))
|
|
||||||
|
|
||||||
# Add actual data for comparison (if exists)
|
|
||||||
if actual_df is not None and len(actual_df) > 0:
|
|
||||||
# Actual data should be in the same time period as prediction data
|
|
||||||
if 'timestamps' in df.columns:
|
|
||||||
# Actual data should use the same timestamps as prediction data to ensure time alignment
|
|
||||||
if 'pred_timestamps' in locals():
|
|
||||||
actual_timestamps = pred_timestamps
|
|
||||||
else:
|
|
||||||
# If no prediction timestamps, calculate from the last timestamp of historical data
|
|
||||||
if len(historical_df) > 0:
|
|
||||||
last_timestamp = historical_df['timestamps'].iloc[-1]
|
|
||||||
time_diff = df['timestamps'].iloc[1] - df['timestamps'].iloc[0] if len(df) > 1 else pd.Timedelta(hours=1)
|
|
||||||
actual_timestamps = pd.date_range(
|
|
||||||
start=last_timestamp + time_diff,
|
|
||||||
periods=len(actual_df),
|
|
||||||
freq=time_diff
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
actual_timestamps = range(len(historical_df), len(historical_df) + len(actual_df))
|
|
||||||
else:
|
|
||||||
actual_timestamps = range(len(historical_df), len(historical_df) + len(actual_df))
|
|
||||||
|
|
||||||
fig.add_trace(go.Candlestick(
|
|
||||||
x=actual_timestamps,
|
|
||||||
open=actual_df['open'],
|
|
||||||
high=actual_df['high'],
|
|
||||||
low=actual_df['low'],
|
|
||||||
close=actual_df['close'],
|
|
||||||
name='Actual Data (120 data points)',
|
|
||||||
increasing_line_color='#FF9800',
|
|
||||||
decreasing_line_color='#F44336'
|
|
||||||
))
|
|
||||||
|
|
||||||
# Update layout
|
|
||||||
fig.update_layout(
|
|
||||||
title='Kronos Financial Prediction Results - 400 Historical Points + 120 Prediction Points vs 120 Actual Points',
|
|
||||||
xaxis_title='Time',
|
|
||||||
yaxis_title='Price',
|
|
||||||
template='plotly_white',
|
|
||||||
height=600,
|
|
||||||
showlegend=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# Ensure x-axis time continuity
|
|
||||||
if 'timestamps' in historical_df.columns:
|
|
||||||
# Get all timestamps and sort them
|
|
||||||
all_timestamps = []
|
|
||||||
if len(historical_df) > 0:
|
|
||||||
all_timestamps.extend(historical_df['timestamps'])
|
|
||||||
if 'pred_timestamps' in locals():
|
|
||||||
all_timestamps.extend(pred_timestamps)
|
|
||||||
if 'actual_timestamps' in locals():
|
|
||||||
all_timestamps.extend(actual_timestamps)
|
|
||||||
|
|
||||||
if all_timestamps:
|
|
||||||
all_timestamps = sorted(all_timestamps)
|
|
||||||
fig.update_xaxes(
|
|
||||||
range=[all_timestamps[0], all_timestamps[-1]],
|
|
||||||
rangeslider_visible=False,
|
|
||||||
type='date'
|
|
||||||
)
|
|
||||||
|
|
||||||
return json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder)
|
|
||||||
|
|
||||||
@app.route('/')
|
|
||||||
def index():
|
|
||||||
"""Home page"""
|
|
||||||
return render_template('index.html')
|
|
||||||
|
|
||||||
@app.route('/api/data-files')
|
|
||||||
def get_data_files():
|
|
||||||
"""Get available data file list"""
|
|
||||||
data_files = load_data_files()
|
|
||||||
return jsonify(data_files)
|
|
||||||
|
|
||||||
@app.route('/api/load-data', methods=['POST'])
|
|
||||||
def load_data():
|
|
||||||
"""Load data file"""
|
|
||||||
try:
|
|
||||||
data = request.get_json()
|
|
||||||
file_path = data.get('file_path')
|
|
||||||
|
|
||||||
if not file_path:
|
|
||||||
return jsonify({'error': 'File path cannot be empty'}), 400
|
|
||||||
|
|
||||||
df, error = load_data_file(file_path)
|
|
||||||
if error:
|
|
||||||
return jsonify({'error': error}), 400
|
|
||||||
|
|
||||||
# Detect data time frequency
|
|
||||||
def detect_timeframe(df):
|
|
||||||
if len(df) < 2:
|
|
||||||
return "Unknown"
|
|
||||||
|
|
||||||
time_diffs = []
|
|
||||||
for i in range(1, min(10, len(df))): # Check first 10 time differences
|
|
||||||
diff = df['timestamps'].iloc[i] - df['timestamps'].iloc[i-1]
|
|
||||||
time_diffs.append(diff)
|
|
||||||
|
|
||||||
if not time_diffs:
|
|
||||||
return "Unknown"
|
|
||||||
|
|
||||||
# Calculate average time difference
|
|
||||||
avg_diff = sum(time_diffs, pd.Timedelta(0)) / len(time_diffs)
|
|
||||||
|
|
||||||
# Convert to readable format
|
|
||||||
if avg_diff < pd.Timedelta(minutes=1):
|
|
||||||
return f"{avg_diff.total_seconds():.0f} seconds"
|
|
||||||
elif avg_diff < pd.Timedelta(hours=1):
|
|
||||||
return f"{avg_diff.total_seconds() / 60:.0f} minutes"
|
|
||||||
elif avg_diff < pd.Timedelta(days=1):
|
|
||||||
return f"{avg_diff.total_seconds() / 3600:.0f} hours"
|
|
||||||
else:
|
|
||||||
return f"{avg_diff.days} days"
|
|
||||||
|
|
||||||
# Return data information
|
|
||||||
data_info = {
|
|
||||||
'rows': len(df),
|
|
||||||
'columns': list(df.columns),
|
|
||||||
'start_date': df['timestamps'].min().isoformat() if 'timestamps' in df.columns else 'N/A',
|
|
||||||
'end_date': df['timestamps'].max().isoformat() if 'timestamps' in df.columns else 'N/A',
|
|
||||||
'price_range': {
|
|
||||||
'min': float(df[['open', 'high', 'low', 'close']].min().min()),
|
|
||||||
'max': float(df[['open', 'high', 'low', 'close']].max().max())
|
|
||||||
},
|
|
||||||
'prediction_columns': ['open', 'high', 'low', 'close'] + (['volume'] if 'volume' in df.columns else []),
|
|
||||||
'timeframe': detect_timeframe(df)
|
|
||||||
}
|
|
||||||
|
|
||||||
return jsonify({
|
|
||||||
'success': True,
|
|
||||||
'data_info': data_info,
|
|
||||||
'message': f'Successfully loaded data, total {len(df)} rows'
|
|
||||||
})
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return jsonify({'error': f'Failed to load data: {str(e)}'}), 500
|
|
||||||
|
|
||||||
@app.route('/api/predict', methods=['POST'])
|
|
||||||
def predict():
|
|
||||||
"""Perform prediction"""
|
|
||||||
try:
|
|
||||||
data = request.get_json()
|
|
||||||
file_path = data.get('file_path')
|
|
||||||
lookback = int(data.get('lookback', 400))
|
|
||||||
pred_len = int(data.get('pred_len', 120))
|
|
||||||
|
|
||||||
# Get prediction quality parameters
|
|
||||||
temperature = float(data.get('temperature', 1.0))
|
|
||||||
top_p = float(data.get('top_p', 0.9))
|
|
||||||
sample_count = int(data.get('sample_count', 1))
|
|
||||||
|
|
||||||
if not file_path:
|
|
||||||
return jsonify({'error': 'File path cannot be empty'}), 400
|
|
||||||
|
|
||||||
# Load data
|
|
||||||
df, error = load_data_file(file_path)
|
|
||||||
if error:
|
|
||||||
return jsonify({'error': error}), 400
|
|
||||||
|
|
||||||
if len(df) < lookback:
|
|
||||||
return jsonify({'error': f'Insufficient data length, need at least {lookback} rows'}), 400
|
|
||||||
|
|
||||||
# Perform prediction
|
|
||||||
if MODEL_AVAILABLE and predictor is not None:
|
|
||||||
try:
|
|
||||||
# Use real Kronos model
|
|
||||||
# Only use necessary columns: OHLCV, excluding amount
|
|
||||||
required_cols = ['open', 'high', 'low', 'close']
|
|
||||||
if 'volume' in df.columns:
|
|
||||||
required_cols.append('volume')
|
|
||||||
|
|
||||||
# Process time period selection
|
|
||||||
start_date = data.get('start_date')
|
|
||||||
|
|
||||||
if start_date:
|
|
||||||
# Custom time period - fix logic: use data within selected window
|
|
||||||
start_dt = pd.to_datetime(start_date)
|
|
||||||
|
|
||||||
# Find data after start time
|
|
||||||
mask = df['timestamps'] >= start_dt
|
|
||||||
time_range_df = df[mask]
|
|
||||||
|
|
||||||
# Ensure sufficient data: lookback + pred_len
|
|
||||||
if len(time_range_df) < lookback + pred_len:
|
|
||||||
return jsonify({'error': f'Insufficient data from start time {start_dt.strftime("%Y-%m-%d %H:%M")}, need at least {lookback + pred_len} data points, currently only {len(time_range_df)} available'}), 400
|
|
||||||
|
|
||||||
# Use first lookback data points within selected window for prediction
|
|
||||||
x_df = time_range_df.iloc[:lookback][required_cols]
|
|
||||||
x_timestamp = time_range_df.iloc[:lookback]['timestamps']
|
|
||||||
|
|
||||||
# Use last pred_len data points within selected window as actual values
|
|
||||||
y_timestamp = time_range_df.iloc[lookback:lookback+pred_len]['timestamps']
|
|
||||||
|
|
||||||
# Calculate actual time period length
|
|
||||||
start_timestamp = time_range_df['timestamps'].iloc[0]
|
|
||||||
end_timestamp = time_range_df['timestamps'].iloc[lookback+pred_len-1]
|
|
||||||
time_span = end_timestamp - start_timestamp
|
|
||||||
|
|
||||||
prediction_type = f"Kronos model prediction (within selected window: first {lookback} data points for prediction, last {pred_len} data points for comparison, time span: {time_span})"
|
|
||||||
else:
|
|
||||||
# Use latest data
|
|
||||||
x_df = df.iloc[:lookback][required_cols]
|
|
||||||
x_timestamp = df.iloc[:lookback]['timestamps']
|
|
||||||
y_timestamp = df.iloc[lookback:lookback+pred_len]['timestamps']
|
|
||||||
prediction_type = "Kronos model prediction (latest data)"
|
|
||||||
|
|
||||||
# Ensure timestamps are Series format, not DatetimeIndex, to avoid .dt attribute error in Kronos model
|
|
||||||
if isinstance(x_timestamp, pd.DatetimeIndex):
|
|
||||||
x_timestamp = pd.Series(x_timestamp, name='timestamps')
|
|
||||||
if isinstance(y_timestamp, pd.DatetimeIndex):
|
|
||||||
y_timestamp = pd.Series(y_timestamp, name='timestamps')
|
|
||||||
|
|
||||||
pred_df = predictor.predict(
|
|
||||||
df=x_df,
|
|
||||||
x_timestamp=x_timestamp,
|
|
||||||
y_timestamp=y_timestamp,
|
|
||||||
pred_len=pred_len,
|
|
||||||
T=temperature,
|
|
||||||
top_p=top_p,
|
|
||||||
sample_count=sample_count
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return jsonify({'error': f'Kronos model prediction failed: {str(e)}'}), 500
|
|
||||||
else:
|
|
||||||
return jsonify({'error': 'Kronos model not loaded, please load model first'}), 400
|
|
||||||
|
|
||||||
# Prepare actual data for comparison (if exists)
|
|
||||||
actual_data = []
|
|
||||||
actual_df = None
|
|
||||||
|
|
||||||
if start_date: # Custom time period
|
|
||||||
# Fix logic: use data within selected window
|
|
||||||
# Prediction uses first 400 data points within selected window
|
|
||||||
# Actual data should be last 120 data points within selected window
|
|
||||||
start_dt = pd.to_datetime(start_date)
|
|
||||||
|
|
||||||
# Find data starting from start_date
|
|
||||||
mask = df['timestamps'] >= start_dt
|
|
||||||
time_range_df = df[mask]
|
|
||||||
|
|
||||||
if len(time_range_df) >= lookback + pred_len:
|
|
||||||
# Get last 120 data points within selected window as actual values
|
|
||||||
actual_df = time_range_df.iloc[lookback:lookback+pred_len]
|
|
||||||
|
|
||||||
for i, (_, row) in enumerate(actual_df.iterrows()):
|
|
||||||
actual_data.append({
|
|
||||||
'timestamp': row['timestamps'].isoformat(),
|
|
||||||
'open': float(row['open']),
|
|
||||||
'high': float(row['high']),
|
|
||||||
'low': float(row['low']),
|
|
||||||
'close': float(row['close']),
|
|
||||||
'volume': float(row['volume']) if 'volume' in row else 0,
|
|
||||||
'amount': float(row['amount']) if 'amount' in row else 0
|
|
||||||
})
|
|
||||||
else: # Latest data
|
|
||||||
# Prediction uses first 400 data points
|
|
||||||
# Actual data should be 120 data points after first 400 data points
|
|
||||||
if len(df) >= lookback + pred_len:
|
|
||||||
actual_df = df.iloc[lookback:lookback+pred_len]
|
|
||||||
for i, (_, row) in enumerate(actual_df.iterrows()):
|
|
||||||
actual_data.append({
|
|
||||||
'timestamp': row['timestamps'].isoformat(),
|
|
||||||
'open': float(row['open']),
|
|
||||||
'high': float(row['high']),
|
|
||||||
'low': float(row['low']),
|
|
||||||
'close': float(row['close']),
|
|
||||||
'volume': float(row['volume']) if 'volume' in row else 0,
|
|
||||||
'amount': float(row['amount']) if 'amount' in row else 0
|
|
||||||
})
|
|
||||||
|
|
||||||
# Create chart - pass historical data start position
|
|
||||||
if start_date:
|
|
||||||
# Custom time period: find starting position of historical data in original df
|
|
||||||
start_dt = pd.to_datetime(start_date)
|
|
||||||
mask = df['timestamps'] >= start_dt
|
|
||||||
historical_start_idx = df[mask].index[0] if len(df[mask]) > 0 else 0
|
|
||||||
else:
|
|
||||||
# Latest data: start from beginning
|
|
||||||
historical_start_idx = 0
|
|
||||||
|
|
||||||
chart_json = create_prediction_chart(df, pred_df, lookback, pred_len, actual_df, historical_start_idx)
|
|
||||||
|
|
||||||
# Prepare prediction result data - fix timestamp calculation logic
|
|
||||||
if 'timestamps' in df.columns:
|
|
||||||
if start_date:
|
|
||||||
# Custom time period: use selected window data to calculate timestamps
|
|
||||||
start_dt = pd.to_datetime(start_date)
|
|
||||||
mask = df['timestamps'] >= start_dt
|
|
||||||
time_range_df = df[mask]
|
|
||||||
|
|
||||||
if len(time_range_df) >= lookback:
|
|
||||||
# Calculate prediction timestamps starting from last time point of selected window
|
|
||||||
last_timestamp = time_range_df['timestamps'].iloc[lookback-1]
|
|
||||||
time_diff = df['timestamps'].iloc[1] - df['timestamps'].iloc[0]
|
|
||||||
future_timestamps = pd.date_range(
|
|
||||||
start=last_timestamp + time_diff,
|
|
||||||
periods=pred_len,
|
|
||||||
freq=time_diff
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
future_timestamps = []
|
|
||||||
else:
|
|
||||||
# Latest data: calculate from last time point of entire data file
|
|
||||||
last_timestamp = df['timestamps'].iloc[-1]
|
|
||||||
time_diff = df['timestamps'].iloc[1] - df['timestamps'].iloc[0]
|
|
||||||
future_timestamps = pd.date_range(
|
|
||||||
start=last_timestamp + time_diff,
|
|
||||||
periods=pred_len,
|
|
||||||
freq=time_diff
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
future_timestamps = range(len(df), len(df) + pred_len)
|
|
||||||
|
|
||||||
prediction_results = []
|
|
||||||
for i, (_, row) in enumerate(pred_df.iterrows()):
|
|
||||||
prediction_results.append({
|
|
||||||
'timestamp': future_timestamps[i].isoformat() if i < len(future_timestamps) else f"T{i}",
|
|
||||||
'open': float(row['open']),
|
|
||||||
'high': float(row['high']),
|
|
||||||
'low': float(row['low']),
|
|
||||||
'close': float(row['close']),
|
|
||||||
'volume': float(row['volume']) if 'volume' in row else 0,
|
|
||||||
'amount': float(row['amount']) if 'amount' in row else 0
|
|
||||||
})
|
|
||||||
|
|
||||||
# Save prediction results to file
|
|
||||||
try:
|
|
||||||
save_prediction_results(
|
|
||||||
file_path=file_path,
|
|
||||||
prediction_type=prediction_type,
|
|
||||||
prediction_results=prediction_results,
|
|
||||||
actual_data=actual_data,
|
|
||||||
input_data=x_df,
|
|
||||||
prediction_params={
|
|
||||||
'lookback': lookback,
|
|
||||||
'pred_len': pred_len,
|
|
||||||
'temperature': temperature,
|
|
||||||
'top_p': top_p,
|
|
||||||
'sample_count': sample_count,
|
|
||||||
'start_date': start_date if start_date else 'latest'
|
|
||||||
}
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Failed to save prediction results: {e}")
|
|
||||||
|
|
||||||
return jsonify({
|
|
||||||
'success': True,
|
|
||||||
'prediction_type': prediction_type,
|
|
||||||
'chart': chart_json,
|
|
||||||
'prediction_results': prediction_results,
|
|
||||||
'actual_data': actual_data,
|
|
||||||
'has_comparison': len(actual_data) > 0,
|
|
||||||
'message': f'Prediction completed, generated {pred_len} prediction points' + (f', including {len(actual_data)} actual data points for comparison' if len(actual_data) > 0 else '')
|
|
||||||
})
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return jsonify({'error': f'Prediction failed: {str(e)}'}), 500
|
|
||||||
|
|
||||||
@app.route('/api/load-model', methods=['POST'])
def load_model():
    """Load a Kronos model and build the global predictor.

    Expects a JSON body with optional keys:
        model_key: key into AVAILABLE_MODELS (default 'kronos-small')
        device: device string passed to KronosPredictor (default 'cpu')

    Returns:
        400 when the model library is missing or model_key is unknown,
        500 when loading fails, otherwise a JSON payload with model_info.
    """
    global tokenizer, model, predictor

    try:
        if not MODEL_AVAILABLE:
            return jsonify({'error': 'Kronos model library not available'}), 400

        # Fix: get_json() yields None for a missing/invalid JSON body, which
        # would crash on .get() and surface as a misleading 500 error. Fall
        # back to {} so the documented defaults apply instead.
        data = request.get_json(silent=True) or {}
        model_key = data.get('model_key', 'kronos-small')
        device = data.get('device', 'cpu')

        if model_key not in AVAILABLE_MODELS:
            return jsonify({'error': f'Unsupported model: {model_key}'}), 400

        model_config = AVAILABLE_MODELS[model_key]

        # Load tokenizer and model weights (from_pretrained may download).
        tokenizer = KronosTokenizer.from_pretrained(model_config['tokenizer_id'])
        model = Kronos.from_pretrained(model_config['model_id'])

        # Create predictor bound to the requested device.
        predictor = KronosPredictor(model, tokenizer, device=device, max_context=model_config['context_length'])

        return jsonify({
            'success': True,
            'message': f'Model loaded successfully: {model_config["name"]} ({model_config["params"]}) on {device}',
            'model_info': {
                'name': model_config['name'],
                'params': model_config['params'],
                'context_length': model_config['context_length'],
                'description': model_config['description']
            }
        })

    except Exception as e:
        return jsonify({'error': f'Model loading failed: {str(e)}'}), 500
|
|
||||||
|
|
||||||
@app.route('/api/available-models')
def get_available_models():
    """Return the configured model catalog and the library-availability flag."""
    payload = {
        'models': AVAILABLE_MODELS,
        'model_available': MODEL_AVAILABLE,
    }
    return jsonify(payload)
|
|
||||||
|
|
||||||
@app.route('/api/model-status')
def get_model_status():
    """Report whether the Kronos library is installed and a model is loaded."""
    # Guard clauses replace the nested if/else: bail out early for the two
    # "not ready" states, then describe the loaded model.
    if not MODEL_AVAILABLE:
        return jsonify({
            'available': False,
            'loaded': False,
            'message': 'Kronos model library not available, please install related dependencies'
        })

    if predictor is None:
        return jsonify({
            'available': True,
            'loaded': False,
            'message': 'Kronos model available but not loaded'
        })

    return jsonify({
        'available': True,
        'loaded': True,
        'message': 'Kronos model loaded and available',
        'current_model': {
            'name': predictor.model.__class__.__name__,
            'device': str(next(predictor.model.parameters()).device)
        }
    })
|
|
||||||
|
|
||||||
if __name__ == '__main__':
    # Startup banner, then launch the development server.
    print("Starting Kronos Web UI...")
    print(f"Model availability: {MODEL_AVAILABLE}")

    tip = (
        "Tip: You can load Kronos model through /api/load-model endpoint"
        if MODEL_AVAILABLE
        else "Tip: Will use simulated data for demonstration"
    )
    print(tip)

    app.run(debug=True, host='0.0.0.0', port=7070)
|
|
||||||
@ -1,7 +0,0 @@
|
|||||||
flask==2.3.3
|
|
||||||
flask-cors==4.0.0
|
|
||||||
pandas==2.2.2
|
|
||||||
numpy==1.24.3
|
|
||||||
plotly==5.17.0
|
|
||||||
torch>=2.1.0
|
|
||||||
huggingface_hub==0.33.1
|
|
||||||
89
webui/run.py
@ -1,89 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
Kronos Web UI startup script
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import subprocess
|
|
||||||
import webbrowser
|
|
||||||
import time
|
|
||||||
|
|
||||||
def check_dependencies():
    """Check whether all Web UI runtime dependencies can be imported.

    Returns:
        bool: True if every required package imports cleanly, False otherwise
        (in which case a hint to run pip install is printed).
    """
    import importlib  # local import: only needed for this probe

    # Single source of truth for the requirement list instead of five
    # repeated import statements; importlib performs the same import check.
    required = ("flask", "flask_cors", "pandas", "numpy", "plotly")
    try:
        for module_name in required:
            importlib.import_module(module_name)
    except ImportError as e:
        print(f"❌ Missing dependency: {e}")
        print("Please run: pip install -r requirements.txt")
        return False
    print("✅ All dependencies installed")
    return True
|
|
||||||
|
|
||||||
def install_dependencies():
    """Install the packages listed in requirements.txt via pip.

    Returns True on success, False if pip exits non-zero.
    """
    print("Installing dependencies...")
    pip_cmd = [sys.executable, "-m", "pip", "install", "-r", "requirements.txt"]
    try:
        subprocess.check_call(pip_cmd)
    except subprocess.CalledProcessError:
        print("❌ Dependencies installation failed")
        return False
    print("✅ Dependencies installation completed")
    return True
|
|
||||||
|
|
||||||
def main():
    """Run the interactive startup flow for the Kronos Web UI.

    Checks dependencies (offering to auto-install them), probes whether the
    Kronos model library can be imported, then starts the Flask server and
    opens a browser tab pointing at it.
    """
    print("🚀 Starting Kronos Web UI...")
    print("=" * 50)

    # Dependency check with optional interactive auto-install.
    if not check_dependencies():
        print("\nAuto-install dependencies? (y/n): ", end="")
        if input().lower() == 'y':
            if not install_dependencies():
                return
        else:
            print("Please manually install dependencies and retry")
            return

    # Probe model availability: the import itself is the check; the imported
    # names are intentionally unused here (app.py does its own import). The
    # dead `model_available` local the probe used to set has been removed.
    try:
        sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
        from model import Kronos, KronosTokenizer, KronosPredictor  # noqa: F401
        print("✅ Kronos model library available")
    except ImportError:
        print("⚠️ Kronos model library not available, will use simulated prediction")

    # Start Flask application
    print("\n🌐 Starting Web server...")

    # Set environment variables for Flask's CLI/debug machinery.
    os.environ['FLASK_APP'] = 'app.py'
    os.environ['FLASK_ENV'] = 'development'

    try:
        from app import app
        print("✅ Web server started successfully!")
        print("🌐 Access URL: http://localhost:7070")
        print("💡 Tip: Press Ctrl+C to stop server")

        # Give the server a moment, then open the UI in the default browser.
        time.sleep(2)
        webbrowser.open('http://localhost:7070')

        # Blocks until the server is stopped (Ctrl+C).
        app.run(debug=True, host='0.0.0.0', port=7070)

    except Exception as e:
        print(f"❌ Startup failed: {e}")
        print("Please check if port 7070 is occupied")
|
|
||||||
# Entry point guard: run the interactive startup flow only when this file
# is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|
||||||
@ -1,40 +0,0 @@
|
|||||||
#!/bin/bash

# Kronos Web UI startup script
# Verifies python3, the working directory, and Python dependencies,
# then runs the Flask app (app.py).

echo "🚀 Starting Kronos Web UI..."
echo "================================"

# Check if Python is installed
if ! command -v python3 &> /dev/null; then
    echo "❌ Python3 not installed, please install Python3 first"
    exit 1
fi

# Check if in correct directory
if [ ! -f "app.py" ]; then
    echo "❌ Please run this script in the webui directory"
    exit 1
fi

# Check dependencies
echo "📦 Checking dependencies..."
if ! python3 -c "import flask, flask_cors, pandas, numpy, plotly" &> /dev/null; then
    echo "⚠️ Missing dependencies, installing..."
    # Fix: use "python3 -m pip" so packages are installed into the same
    # interpreter probed above ("pip3" can belong to a different Python),
    # and test the command directly instead of the fragile "$?" check.
    if ! python3 -m pip install -r requirements.txt; then
        echo "❌ Dependencies installation failed"
        exit 1
    fi
    echo "✅ Dependencies installation completed"
else
    echo "✅ All dependencies installed"
fi

# Start application
echo "🌐 Starting Web server..."
echo "Access URL: http://localhost:7070"
echo "Press Ctrl+C to stop server"
echo ""

python3 app.py
|
|
||||||