diff --git a/README.md b/README.md
index 33aa214..85b4c91 100644
--- a/README.md
+++ b/README.md
@@ -35,6 +35,14 @@
> trained on data from over **45 global exchanges**.
+
+
+## 📰 News
+* 🚩 **[2025.08.17]** We have released the scripts for fine-tuning! Check them out to adapt Kronos to your own tasks.
+* 🚩 **[2025.08.02]** Our paper is now available on [arXiv](https://arxiv.org/abs/2508.02739)!
+
+
+
## 📜 Introduction
**Kronos** is a family of decoder-only foundation models, pre-trained specifically for the "language" of financial markets—K-line sequences. Unlike general-purpose TSFMs, Kronos is designed to handle the unique, high-noise characteristics of financial data. It leverages a novel two-stage framework:
@@ -158,6 +166,106 @@ Running this script will generate a plot comparing the ground truth data against
Additionally, we also provide a script that makes predictions without Volume and Amount data, which can be found in [`examples/prediction_wo_vol_example.py`](examples/prediction_wo_vol_example.py).
+
+
+## 🔧 Finetuning on Your Own Data (A-Share Market Example)
+
+We provide a complete pipeline for finetuning Kronos on your own datasets. As an example, we demonstrate how to use [Qlib](https://github.com/microsoft/qlib) to prepare data from the Chinese A-share market and conduct a simple backtest.
+
+> **Disclaimer:** This pipeline is intended as a demonstration to illustrate the finetuning process. It is a simplified example and not a production-ready quantitative trading system. A robust quantitative strategy requires more sophisticated techniques, such as portfolio optimization and risk factor neutralization, to achieve stable alpha.
+
+The finetuning process is divided into four main steps:
+
+1. **Configuration**: Set up paths and hyperparameters.
+2. **Data Preparation**: Process and split your data using Qlib.
+3. **Model Finetuning**: Finetune the Tokenizer and the Predictor models.
+4. **Backtesting**: Evaluate the finetuned model's performance.
+
+### Prerequisites
+
+1. First, ensure you have all dependencies from `requirements.txt` installed.
+2. This pipeline relies on `qlib`. Please install it:
+ ```shell
+ pip install pyqlib
+ ```
+3. You will need to prepare your Qlib data. Follow the [official Qlib guide](https://github.com/microsoft/qlib) to download and set up your data locally. The example scripts assume you are using daily frequency data.
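+   For CN daily data, the Qlib documentation describes a download helper roughly like the following (verify the current command in the Qlib README; the target directory must match `qlib_data_path` in `finetune/config.py`):
+   ```shell
+   python -m qlib.run.get_data qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn
+   ```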
+
+### Step 1: Configure Your Experiment
+
+All settings for data, training, and model paths are centralized in `finetune/config.py`. Before running any scripts, please **modify the following paths** according to your environment:
+
+* `qlib_data_path`: Path to your local Qlib data directory.
+* `dataset_path`: Directory where the processed train/validation/test pickle files will be saved.
+* `save_path`: Base directory for saving model checkpoints.
+* `backtest_result_path`: Directory for saving backtesting results.
+* `pretrained_tokenizer_path` and `pretrained_predictor_path`: Paths to the pre-trained models you want to start from (can be local paths or Hugging Face model names).
+
+You can also adjust other parameters like `instrument`, `train_time_range`, `epochs`, and `batch_size` to fit your specific task. If you don't use [Comet.ml](https://www.comet.com/), set `use_comet = False`.
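+
+For reference, a minimal set of edits in `finetune/config.py` might look like the excerpt below (the paths are placeholders for illustration, not values you must keep):
+
+```python
+# finetune/config.py (excerpt) -- adjust before running any script
+self.qlib_data_path = "~/.qlib/qlib_data/cn_data"             # local Qlib data directory
+self.dataset_path = "./data/processed_datasets"               # processed pickle output
+self.save_path = "./outputs/models"                           # base directory for checkpoints
+self.backtest_result_path = "./outputs/backtest_results"      # backtesting outputs
+self.pretrained_tokenizer_path = "path/to/Kronos-Tokenizer-base"  # local path or Hugging Face model name
+self.pretrained_predictor_path = "path/to/Kronos-small"           # local path or Hugging Face model name
+self.use_comet = False                                        # disable Comet ML logging if you don't use it
+```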
+
+### Step 2: Prepare the Dataset
+
+Run the data preprocessing script. This script will load raw market data from your Qlib directory, process it, split it into training, validation, and test sets, and save them as pickle files.
+
+```shell
+python finetune/qlib_data_preprocess.py
+```
+
+After running, you will find `train_data.pkl`, `val_data.pkl`, and `test_data.pkl` in the directory specified by `dataset_path` in your config.
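+
+Each pickle is a dictionary mapping a stock symbol to a DataFrame of its features. A quick sanity check (a minimal sketch, assuming the default `dataset_path`):
+
+```python
+import pickle
+
+# Load the processed training split (path follows `dataset_path` in config.py)
+with open("./data/processed_datasets/train_data.pkl", "rb") as f:
+    train_data = pickle.load(f)
+
+print(f"Symbols: {len(train_data)}")
+symbol, df = next(iter(train_data.items()))
+# Columns follow `feature_list`: ['open', 'high', 'low', 'close', 'vol', 'amt']
+print(symbol, df.shape)
+print(df.head())
+```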
+
+### Step 3: Run the Finetuning
+
+The finetuning process consists of two stages: finetuning the tokenizer and then the predictor. Both training scripts are launched with `torchrun` and support multi-GPU training (a single GPU works with `--nproc_per_node=1`).
+
+#### 3.1 Finetune the Tokenizer
+
+This step adjusts the tokenizer to the data distribution of your specific domain.
+
+```shell
+# Replace NUM_GPUS with the number of GPUs you want to use (e.g., 2)
+torchrun --standalone --nproc_per_node=NUM_GPUS finetune/train_tokenizer.py
+```
+
+The best tokenizer checkpoint will be saved to the path configured in `config.py` (derived from `save_path` and `tokenizer_save_folder_name`).
+
+#### 3.2 Finetune the Predictor
+
+This step finetunes the main Kronos model for the forecasting task.
+
+```shell
+# Replace NUM_GPUS with the number of GPUs you want to use (e.g., 2)
+torchrun --standalone --nproc_per_node=NUM_GPUS finetune/train_predictor.py
+```
+
+The best predictor checkpoint will be saved to the path configured in `config.py`.
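+
+After both stages finish, the checkpoint layout under `save_path` looks roughly like this (assuming the default folder names in `config.py`):
+
+```
+outputs/models/
+├── finetune_tokenizer_demo/
+│   └── checkpoints/best_model/
+└── finetune_predictor_demo/
+    └── checkpoints/best_model/
+```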
+
+### Step 4: Evaluate with Backtesting
+
+Finally, run the backtesting script to evaluate your finetuned model. This script loads the models, performs inference on the test set, generates prediction signals (e.g., forecasted price change), and runs a simple top-K strategy backtest.
+
+```shell
+# Specify the GPU for inference
+python finetune/qlib_test.py --device cuda:0
+```
+
+The script will output a detailed performance analysis in your console and generate a plot showing the cumulative return curves of your strategy against the benchmark, similar to the one below:
+
+![Backtest result example](figures/backtest_result_example.png)
+
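+The raw prediction signals are also saved to `predictions.pkl` under `backtest_result_path`, as a dictionary mapping a signal name (`last`, `mean`, `max`, `min`) to a DataFrame with a datetime index and one column per instrument. A quick way to inspect them (paths assume the defaults in `config.py`):
+
+```python
+import pickle
+
+# Load the prediction signals written by finetune/qlib_test.py
+with open("./outputs/backtest_results/finetune_backtest_demo/predictions.pkl", "rb") as f:
+    signals = pickle.load(f)
+
+print(list(signals.keys()))      # ['last', 'mean', 'max', 'min']
+mean_signal = signals["mean"]    # rows = dates, columns = instruments
+print(mean_signal.shape)
+print(mean_signal.iloc[:3, :5])
+```
+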
+### 💡 From Demo to Production: Important Considerations
+
+* **Raw Signals vs. Pure Alpha**: The signals generated by the model in this demo are raw predictions. In a real-world quantitative workflow, these signals would typically be fed into a portfolio optimization model. This model would apply constraints to neutralize exposure to common risk factors (e.g., market beta, style factors like size and value), thereby isolating the **"pure alpha"** and improving the strategy's robustness.
+* **Data Handling**: The provided `QlibDataset` is an example. For different data sources or formats, you will need to adapt the data loading and preprocessing logic.
+* **Strategy and Backtesting Complexity**: The simple top-K strategy used here is a basic starting point. Production-level strategies often incorporate more complex logic for portfolio construction, dynamic position sizing, and risk management (e.g., stop-loss/take-profit rules). Furthermore, a high-fidelity backtest should meticulously model transaction costs, slippage, and market impact to provide a more accurate estimate of real-world performance.
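+
+For reference, the execution and cost assumptions of the demo backtest live in `run_single_backtest` inside `finetune/qlib_test.py`. The excerpt below shows the settings you would tighten (or extend with slippage and market-impact modeling) when moving beyond the demo:
+
+```python
+# Excerpt from finetune/qlib_test.py (run_single_backtest)
+"exchange_kwargs": {
+    "freq": "day",
+    "limit_threshold": 0.095,   # A-share daily price-limit filter
+    "deal_price": "open",       # trades are priced at the open
+    "open_cost": 0.001,         # buy-side transaction cost
+    "close_cost": 0.0015,       # sell-side transaction cost
+    "min_cost": 5,              # minimum cost per trade
+},
+```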
+
+> **📝 AI-Generated Comments**: Please note that many of the code comments within the `finetune/` directory were generated by an AI assistant (Gemini 2.5 Pro) for explanatory purposes. While they aim to be helpful, they may contain inaccuracies. We recommend treating the code itself as the definitive source of logic.
+
## 📖 Citation
If you use Kronos in your research, we would appreciate a citation to our [paper](https://arxiv.org/abs/2508.02739):
diff --git a/figures/backtest_result_example.png b/figures/backtest_result_example.png
new file mode 100644
index 0000000..055a2e4
Binary files /dev/null and b/figures/backtest_result_example.png differ
diff --git a/finetune/config.py b/finetune/config.py
new file mode 100644
index 0000000..04cc3ee
--- /dev/null
+++ b/finetune/config.py
@@ -0,0 +1,131 @@
+import os
+
+class Config:
+ """
+ Configuration class for the entire project.
+ """
+
+ def __init__(self):
+ # =================================================================
+ # Data & Feature Parameters
+ # =================================================================
+ # TODO: Update this path to your Qlib data directory.
+ self.qlib_data_path = "~/.qlib/qlib_data/cn_data"
+ self.instrument = 'csi300'
+
+ # Overall time range for data loading from Qlib.
+ self.dataset_begin_time = "2011-01-01"
+ self.dataset_end_time = '2025-06-05'
+
+ # Sliding window parameters for creating samples.
+ self.lookback_window = 90 # Number of past time steps for input.
+ self.predict_window = 10 # Number of future time steps for prediction.
+ self.max_context = 512 # Maximum context length for the model.
+
+ # Features to be used from the raw data.
+ self.feature_list = ['open', 'high', 'low', 'close', 'vol', 'amt']
+ # Time-based features to be generated.
+ self.time_feature_list = ['minute', 'hour', 'weekday', 'day', 'month']
+
+ # =================================================================
+ # Dataset Splitting & Paths
+ # =================================================================
+        # Note: Each split's start date is earlier than the end date of the preceding split;
+        # this overlap provides the `lookback_window` of history needed for the first samples.
+ self.train_time_range = ["2011-01-01", "2022-12-31"]
+ self.val_time_range = ["2022-09-01", "2024-06-30"]
+ self.test_time_range = ["2024-04-01", "2025-06-05"]
+ self.backtest_time_range = ["2024-07-01", "2025-06-05"]
+
+ # TODO: Directory to save the processed, pickled datasets.
+ self.dataset_path = "./data/processed_datasets"
+
+ # =================================================================
+ # Training Hyperparameters
+ # =================================================================
+ self.clip = 5.0 # Clipping value for normalized data to prevent outliers.
+
+ self.epochs = 30
+ self.log_interval = 100 # Log training status every N batches.
+ self.batch_size = 50 # Batch size per GPU.
+
+ # Number of samples to draw for one "epoch" of training/validation.
+ # This is useful for large datasets where a true epoch is too long.
+ self.n_train_iter = 2000 * self.batch_size
+ self.n_val_iter = 400 * self.batch_size
+
+ # Learning rates for different model components.
+ self.tokenizer_learning_rate = 2e-4
+ self.predictor_learning_rate = 4e-5
+
+ # Gradient accumulation to simulate a larger batch size.
+ self.accumulation_steps = 1
+
+ # AdamW optimizer parameters.
+ self.adam_beta1 = 0.9
+ self.adam_beta2 = 0.95
+ self.adam_weight_decay = 0.1
+
+ # Miscellaneous
+ self.seed = 100 # Global random seed for reproducibility.
+
+ # =================================================================
+ # Experiment Logging & Saving
+ # =================================================================
+ self.use_comet = True # Set to False if you don't want to use Comet ML
+ self.comet_config = {
+ # It is highly recommended to load secrets from environment variables
+ # for security purposes. Example: os.getenv("COMET_API_KEY")
+ "api_key": "YOUR_COMET_API_KEY",
+ "project_name": "Kronos-Finetune-Demo",
+ "workspace": "your_comet_workspace" # TODO: Change to your Comet ML workspace name
+ }
+ self.comet_tag = 'finetune_demo'
+ self.comet_name = 'finetune_demo'
+
+ # Base directory for saving model checkpoints and results.
+ # Using a general 'outputs' directory is a common practice.
+ self.save_path = "./outputs/models"
+ self.tokenizer_save_folder_name = 'finetune_tokenizer_demo'
+ self.predictor_save_folder_name = 'finetune_predictor_demo'
+ self.backtest_save_folder_name = 'finetune_backtest_demo'
+
+ # Path for backtesting results.
+ self.backtest_result_path = "./outputs/backtest_results"
+
+ # =================================================================
+ # Model & Checkpoint Paths
+ # =================================================================
+ # TODO: Update these paths to your pretrained model locations.
+ # These can be local paths or Hugging Face Hub model identifiers.
+ self.pretrained_tokenizer_path = "path/to/your/Kronos-Tokenizer-base"
+ self.pretrained_predictor_path = "path/to/your/Kronos-small"
+
+ # Paths to the fine-tuned models, derived from the save_path.
+ # These will be generated automatically during training.
+ self.finetuned_tokenizer_path = f"{self.save_path}/{self.tokenizer_save_folder_name}/checkpoints/best_model"
+ self.finetuned_predictor_path = f"{self.save_path}/{self.predictor_save_folder_name}/checkpoints/best_model"
+
+ # =================================================================
+ # Backtesting Parameters
+ # =================================================================
+ self.backtest_n_symbol_hold = 50 # Number of symbols to hold in the portfolio.
+        self.backtest_n_symbol_drop = 5  # Number of symbols to replace at each rebalance (n_drop of TopkDropoutStrategy).
+        self.backtest_hold_thresh = 5  # Minimum holding period (in trading days) before a position can be sold.
+        self.inference_T = 0.6  # Sampling temperature for autoregressive inference.
+        self.inference_top_p = 0.9  # Nucleus (top-p) sampling threshold.
+        self.inference_top_k = 0  # Top-k sampling cutoff.
+        self.inference_sample_count = 5  # Number of sampled forecast paths per input window.
+        self.backtest_batch_size = 1000  # Inference batch budget; the DataLoader uses batch_size // sample_count.
+ self.backtest_benchmark = self._set_benchmark(self.instrument)
+
+ def _set_benchmark(self, instrument):
+ dt_benchmark = {
+ 'csi800': "SH000906",
+ 'csi1000': "SH000852",
+ 'csi300': "SH000300",
+ }
+ if instrument in dt_benchmark:
+ return dt_benchmark[instrument]
+ else:
+ raise ValueError(f"Benchmark not defined for instrument: {instrument}")
diff --git a/finetune/dataset.py b/finetune/dataset.py
new file mode 100644
index 0000000..f955ec1
--- /dev/null
+++ b/finetune/dataset.py
@@ -0,0 +1,145 @@
+import pickle
+import random
+import numpy as np
+import torch
+from torch.utils.data import Dataset
+from config import Config
+
+
+class QlibDataset(Dataset):
+ """
+ A PyTorch Dataset for handling Qlib financial time series data.
+
+ This dataset pre-computes all possible start indices for sliding windows
+ and then randomly samples from them during training/validation.
+
+ Args:
+ data_type (str): The type of dataset to load, either 'train' or 'val'.
+
+ Raises:
+ ValueError: If `data_type` is not 'train' or 'val'.
+ """
+
+ def __init__(self, data_type: str = 'train'):
+ self.config = Config()
+ if data_type not in ['train', 'val']:
+ raise ValueError("data_type must be 'train' or 'val'")
+ self.data_type = data_type
+
+ # Use a dedicated random number generator for sampling to avoid
+ # interfering with other random processes (e.g., in model initialization).
+ self.py_rng = random.Random(self.config.seed)
+
+ # Set paths and number of samples based on the data type.
+ if data_type == 'train':
+ self.data_path = f"{self.config.dataset_path}/train_data.pkl"
+ self.n_samples = self.config.n_train_iter
+ else:
+ self.data_path = f"{self.config.dataset_path}/val_data.pkl"
+ self.n_samples = self.config.n_val_iter
+
+ with open(self.data_path, 'rb') as f:
+ self.data = pickle.load(f)
+
+ self.window = self.config.lookback_window + self.config.predict_window + 1
+
+ self.symbols = list(self.data.keys())
+ self.feature_list = self.config.feature_list
+ self.time_feature_list = self.config.time_feature_list
+
+ # Pre-compute all possible (symbol, start_index) pairs.
+ self.indices = []
+ print(f"[{data_type.upper()}] Pre-computing sample indices...")
+ for symbol in self.symbols:
+ df = self.data[symbol].reset_index()
+ series_len = len(df)
+ num_samples = series_len - self.window + 1
+
+ if num_samples > 0:
+ # Generate time features and store them directly in the dataframe.
+ df['minute'] = df['datetime'].dt.minute
+ df['hour'] = df['datetime'].dt.hour
+ df['weekday'] = df['datetime'].dt.weekday
+ df['day'] = df['datetime'].dt.day
+ df['month'] = df['datetime'].dt.month
+ # Keep only necessary columns to save memory.
+ self.data[symbol] = df[self.feature_list + self.time_feature_list]
+
+ # Add all valid starting indices for this symbol to the global list.
+ for i in range(num_samples):
+ self.indices.append((symbol, i))
+
+ # The effective dataset size is the minimum of the configured iterations
+ # and the total number of available samples.
+ self.n_samples = min(self.n_samples, len(self.indices))
+ print(f"[{data_type.upper()}] Found {len(self.indices)} possible samples. Using {self.n_samples} per epoch.")
+
+ def set_epoch_seed(self, epoch: int):
+ """
+ Sets a new seed for the random sampler for each epoch. This is crucial
+ for reproducibility in distributed training.
+
+ Args:
+ epoch (int): The current epoch number.
+ """
+ epoch_seed = self.config.seed + epoch
+ self.py_rng.seed(epoch_seed)
+
+ def __len__(self) -> int:
+ """Returns the number of samples per epoch."""
+ return self.n_samples
+
+ def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Retrieves a random sample from the dataset.
+
+ Note: The `idx` argument is ignored. Instead, a random index is drawn
+ from the pre-computed `self.indices` list using `self.py_rng`. This
+ ensures random sampling over the entire dataset for each call.
+
+ Args:
+ idx (int): Ignored.
+
+ Returns:
+ tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+ - x_tensor (torch.Tensor): The normalized feature tensor.
+ - x_stamp_tensor (torch.Tensor): The time feature tensor.
+ """
+ # Select a random sample from the entire pool of indices.
+ random_idx = self.py_rng.randint(0, len(self.indices) - 1)
+ symbol, start_idx = self.indices[random_idx]
+
+ # Extract the sliding window from the dataframe.
+ df = self.data[symbol]
+ end_idx = start_idx + self.window
+ win_df = df.iloc[start_idx:end_idx]
+
+ # Separate main features and time features.
+ x = win_df[self.feature_list].values.astype(np.float32)
+ x_stamp = win_df[self.time_feature_list].values.astype(np.float32)
+
+ # Perform instance-level normalization.
+ x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0)
+ x = (x - x_mean) / (x_std + 1e-5)
+ x = np.clip(x, -self.config.clip, self.config.clip)
+
+ # Convert to PyTorch tensors.
+ x_tensor = torch.from_numpy(x)
+ x_stamp_tensor = torch.from_numpy(x_stamp)
+
+ return x_tensor, x_stamp_tensor
+
+
+if __name__ == '__main__':
+ # Example usage and verification.
+ print("Creating training dataset instance...")
+ train_dataset = QlibDataset(data_type='train')
+
+ print(f"Dataset length: {len(train_dataset)}")
+
+ if len(train_dataset) > 0:
+ try_x, try_x_stamp = train_dataset[100] # Index 100 is ignored.
+ print(f"Sample feature shape: {try_x.shape}")
+ print(f"Sample time feature shape: {try_x_stamp.shape}")
+ else:
+ print("Dataset is empty.")
diff --git a/finetune/qlib_data_preprocess.py b/finetune/qlib_data_preprocess.py
new file mode 100644
index 0000000..e9b0288
--- /dev/null
+++ b/finetune/qlib_data_preprocess.py
@@ -0,0 +1,120 @@
+import os
+import pickle
+import numpy as np
+import pandas as pd
+import qlib
+from qlib.config import REG_CN
+from qlib.data import D
+from qlib.data.dataset.loader import QlibDataLoader
+from tqdm import trange
+
+from config import Config
+
+
+class QlibDataPreprocessor:
+ """
+ A class to handle the loading, processing, and splitting of Qlib financial data.
+ """
+
+ def __init__(self):
+ """Initializes the preprocessor with configuration and data fields."""
+ self.config = Config()
+ self.data_fields = ['open', 'close', 'high', 'low', 'volume', 'vwap']
+ self.data = {} # A dictionary to store processed data for each symbol.
+
+ def initialize_qlib(self):
+ """Initializes the Qlib environment."""
+ print("Initializing Qlib...")
+ qlib.init(provider_uri=self.config.qlib_data_path, region=REG_CN)
+
+ def load_qlib_data(self):
+ """
+ Loads raw data from Qlib, processes it symbol by symbol, and stores
+ it in the `self.data` attribute.
+ """
+ print("Loading and processing data from Qlib...")
+ data_fields_qlib = ['$' + f for f in self.data_fields]
+ cal: np.ndarray = D.calendar()
+
+ # Determine the actual start and end times to load, including buffer for lookback and predict windows.
+ start_index = cal.searchsorted(pd.Timestamp(self.config.dataset_begin_time))
+ end_index = cal.searchsorted(pd.Timestamp(self.config.dataset_end_time))
+ real_start_time = cal[start_index - self.config.lookback_window]
+
+ if cal[end_index] != pd.Timestamp(self.config.dataset_end_time):
+ end_index -= 1
+ real_end_time = cal[end_index + self.config.predict_window]
+
+ # Load data using Qlib's data loader.
+ data_df = QlibDataLoader(config=data_fields_qlib).load(
+ self.config.instrument, real_start_time, real_end_time
+ )
+ data_df = data_df.stack().unstack(level=1) # Reshape for easier access.
+
+ symbol_list = list(data_df.columns)
+ for i in trange(len(symbol_list), desc="Processing Symbols"):
+ symbol = symbol_list[i]
+ symbol_df = data_df[symbol]
+
+ # Pivot the table to have features as columns and datetime as index.
+ symbol_df = symbol_df.reset_index().rename(columns={'level_1': 'field'})
+ symbol_df = pd.pivot(symbol_df, index='datetime', columns='field', values=symbol)
+ symbol_df = symbol_df.rename(columns={f'${field}': field for field in self.data_fields})
+
+            # Approximate the trade amount as the average OHLC price times volume, then select the final features.
+ symbol_df['vol'] = symbol_df['volume']
+ symbol_df['amt'] = (symbol_df['open'] + symbol_df['high'] + symbol_df['low'] + symbol_df['close']) / 4 * symbol_df['vol']
+ symbol_df = symbol_df[self.config.feature_list]
+
+ # Filter out symbols with insufficient data.
+ symbol_df = symbol_df.dropna()
+ if len(symbol_df) < self.config.lookback_window + self.config.predict_window + 1:
+ continue
+
+ self.data[symbol] = symbol_df
+
+ def prepare_dataset(self):
+ """
+ Splits the loaded data into train, validation, and test sets and saves them to disk.
+ """
+ print("Splitting data into train, validation, and test sets...")
+ train_data, val_data, test_data = {}, {}, {}
+
+ symbol_list = list(self.data.keys())
+ for i in trange(len(symbol_list), desc="Preparing Datasets"):
+ symbol = symbol_list[i]
+ symbol_df = self.data[symbol]
+
+ # Define time ranges from config.
+ train_start, train_end = self.config.train_time_range
+ val_start, val_end = self.config.val_time_range
+ test_start, test_end = self.config.test_time_range
+
+ # Create boolean masks for each dataset split.
+ train_mask = (symbol_df.index >= train_start) & (symbol_df.index <= train_end)
+ val_mask = (symbol_df.index >= val_start) & (symbol_df.index <= val_end)
+ test_mask = (symbol_df.index >= test_start) & (symbol_df.index <= test_end)
+
+ # Apply masks to create the final datasets.
+ train_data[symbol] = symbol_df[train_mask]
+ val_data[symbol] = symbol_df[val_mask]
+ test_data[symbol] = symbol_df[test_mask]
+
+ # Save the datasets using pickle.
+ os.makedirs(self.config.dataset_path, exist_ok=True)
+ with open(f"{self.config.dataset_path}/train_data.pkl", 'wb') as f:
+ pickle.dump(train_data, f)
+ with open(f"{self.config.dataset_path}/val_data.pkl", 'wb') as f:
+ pickle.dump(val_data, f)
+ with open(f"{self.config.dataset_path}/test_data.pkl", 'wb') as f:
+ pickle.dump(test_data, f)
+
+ print("Datasets prepared and saved successfully.")
+
+
+if __name__ == '__main__':
+ # This block allows the script to be run directly to perform data preprocessing.
+ preprocessor = QlibDataPreprocessor()
+ preprocessor.initialize_qlib()
+ preprocessor.load_qlib_data()
+ preprocessor.prepare_dataset()
diff --git a/finetune/qlib_test.py b/finetune/qlib_test.py
new file mode 100644
index 0000000..29dddb3
--- /dev/null
+++ b/finetune/qlib_test.py
@@ -0,0 +1,358 @@
+import os
+import sys
+import argparse
+import pickle
+from collections import defaultdict
+
+import numpy as np
+import pandas as pd
+import torch
+from torch.utils.data import Dataset, DataLoader
+from tqdm import trange, tqdm
+from matplotlib import pyplot as plt
+
+import qlib
+from qlib.config import REG_CN
+from qlib.backtest import backtest, executor, CommonInfrastructure
+from qlib.contrib.evaluate import risk_analysis
+from qlib.contrib.strategy import TopkDropoutStrategy
+from qlib.utils import flatten_dict
+from qlib.utils.time import Freq
+
+# Ensure project root is in the Python path
+sys.path.append("../")
+from config import Config
+from model.kronos import Kronos, KronosTokenizer, auto_regressive_inference
+
+
+# =================================================================================
+# 1. Data Loading and Processing for Inference
+# =================================================================================
+
+class QlibTestDataset(Dataset):
+ """
+ PyTorch Dataset for handling Qlib test data, specifically for inference.
+
+ This dataset iterates through all possible sliding windows sequentially. It also
+ yields metadata like symbol and timestamp, which are crucial for mapping
+ predictions back to the original time series.
+ """
+
+ def __init__(self, data: dict, config: Config):
+ self.data = data
+ self.config = config
+ self.window_size = config.lookback_window + config.predict_window
+ self.symbols = list(self.data.keys())
+ self.feature_list = config.feature_list
+ self.time_feature_list = config.time_feature_list
+ self.indices = []
+
+ print("Preprocessing and building indices for test dataset...")
+ for symbol in self.symbols:
+ df = self.data[symbol].reset_index()
+ # Generate time features on-the-fly
+ df['minute'] = df['datetime'].dt.minute
+ df['hour'] = df['datetime'].dt.hour
+ df['weekday'] = df['datetime'].dt.weekday
+ df['day'] = df['datetime'].dt.day
+ df['month'] = df['datetime'].dt.month
+ self.data[symbol] = df # Store preprocessed dataframe
+
+ num_samples = len(df) - self.window_size + 1
+ if num_samples > 0:
+ for i in range(num_samples):
+ timestamp = df.iloc[i + self.config.lookback_window - 1]['datetime']
+ self.indices.append((symbol, i, timestamp))
+
+ def __len__(self) -> int:
+ return len(self.indices)
+
+ def __getitem__(self, idx: int):
+ symbol, start_idx, timestamp = self.indices[idx]
+ df = self.data[symbol]
+
+ context_end = start_idx + self.config.lookback_window
+ predict_end = context_end + self.config.predict_window
+
+ context_df = df.iloc[start_idx:context_end]
+ predict_df = df.iloc[context_end:predict_end]
+
+ x = context_df[self.feature_list].values.astype(np.float32)
+ x_stamp = context_df[self.time_feature_list].values.astype(np.float32)
+ y_stamp = predict_df[self.time_feature_list].values.astype(np.float32)
+
+ # Instance-level normalization, consistent with training
+ x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0)
+ x = (x - x_mean) / (x_std + 1e-5)
+ x = np.clip(x, -self.config.clip, self.config.clip)
+
+ return torch.from_numpy(x), torch.from_numpy(x_stamp), torch.from_numpy(y_stamp), symbol, timestamp
+
+
+# =================================================================================
+# 2. Backtesting Logic
+# =================================================================================
+
+class QlibBacktest:
+ """
+ A wrapper class for conducting backtesting experiments using Qlib.
+ """
+
+ def __init__(self, config: Config):
+ self.config = config
+ self.initialize_qlib()
+
+ def initialize_qlib(self):
+ """Initializes the Qlib environment."""
+ print("Initializing Qlib for backtesting...")
+ qlib.init(provider_uri=self.config.qlib_data_path, region=REG_CN)
+
+ def run_single_backtest(self, signal_series: pd.Series) -> pd.DataFrame:
+ """
+ Runs a single backtest for a given prediction signal.
+
+ Args:
+ signal_series (pd.Series): A pandas Series with a MultiIndex
+ (instrument, datetime) and prediction scores.
+ Returns:
+ pd.DataFrame: A DataFrame containing the performance report.
+ """
+ strategy = TopkDropoutStrategy(
+ topk=self.config.backtest_n_symbol_hold,
+ n_drop=self.config.backtest_n_symbol_drop,
+ hold_thresh=self.config.backtest_hold_thresh,
+ signal=signal_series,
+ )
+ executor_config = {
+ "time_per_step": "day",
+ "generate_portfolio_metrics": True,
+ "delay_execution": True,
+ }
+ backtest_config = {
+ "start_time": self.config.backtest_time_range[0],
+ "end_time": self.config.backtest_time_range[1],
+ "account": 100_000_000,
+ "benchmark": self.config.backtest_benchmark,
+ "exchange_kwargs": {
+ "freq": "day", "limit_threshold": 0.095, "deal_price": "open",
+ "open_cost": 0.001, "close_cost": 0.0015, "min_cost": 5,
+ },
+ "executor": executor.SimulatorExecutor(**executor_config),
+ }
+
+ portfolio_metric_dict, _ = backtest(strategy=strategy, **backtest_config)
+ analysis_freq = "{0}{1}".format(*Freq.parse("day"))
+ report, _ = portfolio_metric_dict.get(analysis_freq)
+
+ # --- Analysis and Reporting ---
+ analysis = {
+ "excess_return_without_cost": risk_analysis(report["return"] - report["bench"], freq=analysis_freq),
+ "excess_return_with_cost": risk_analysis(report["return"] - report["bench"] - report["cost"], freq=analysis_freq),
+ }
+ print("\n--- Backtest Analysis ---")
+ print("Benchmark Return:", risk_analysis(report["bench"], freq=analysis_freq), sep='\n')
+ print("\nExcess Return (w/o cost):", analysis["excess_return_without_cost"], sep='\n')
+ print("\nExcess Return (w/ cost):", analysis["excess_return_with_cost"], sep='\n')
+
+ report_df = pd.DataFrame({
+ "cum_bench": report["bench"].cumsum(),
+ "cum_return_w_cost": (report["return"] - report["cost"]).cumsum(),
+ "cum_ex_return_w_cost": (report["return"] - report["bench"] - report["cost"]).cumsum(),
+ })
+ return report_df
+
+ def run_and_plot_results(self, signals: dict[str, pd.DataFrame]):
+ """
+ Runs backtests for multiple signals and plots the cumulative return curves.
+
+ Args:
+ signals (dict[str, pd.DataFrame]): A dictionary where keys are signal names
+ and values are prediction DataFrames.
+ """
+ return_df, ex_return_df, bench_df = pd.DataFrame(), pd.DataFrame(), pd.DataFrame()
+
+ for signal_name, pred_df in signals.items():
+ print(f"\nBacktesting signal: {signal_name}...")
+ pred_series = pred_df.stack()
+ pred_series.index.names = ['datetime', 'instrument']
+ pred_series = pred_series.swaplevel().sort_index()
+ report_df = self.run_single_backtest(pred_series)
+
+ return_df[signal_name] = report_df['cum_return_w_cost']
+ ex_return_df[signal_name] = report_df['cum_ex_return_w_cost']
+ if 'return' not in bench_df:
+ bench_df['return'] = report_df['cum_bench']
+
+ # Plotting results
+ fig, axes = plt.subplots(2, 1, figsize=(12, 8), sharex=True)
+ return_df.plot(ax=axes[0], title='Cumulative Return with Cost', grid=True)
+ axes[0].plot(bench_df['return'], label=self.config.instrument.upper(), color='black', linestyle='--')
+ axes[0].legend()
+ axes[0].set_ylabel("Cumulative Return")
+
+ ex_return_df.plot(ax=axes[1], title='Cumulative Excess Return with Cost', grid=True)
+ axes[1].legend()
+ axes[1].set_xlabel("Date")
+ axes[1].set_ylabel("Cumulative Excess Return")
+
+ plt.tight_layout()
+ plt.savefig("../figures/backtest_result_example.png", dpi=200)
+ plt.show()
+
+
+# =================================================================================
+# 3. Inference Logic
+# =================================================================================
+
+def load_models(config: dict) -> tuple[KronosTokenizer, Kronos]:
+ """Loads the fine-tuned tokenizer and predictor model."""
+ device = torch.device(config['device'])
+ print(f"Loading models onto device: {device}...")
+ tokenizer = KronosTokenizer.from_pretrained(config['tokenizer_path']).to(device).eval()
+ model = Kronos.from_pretrained(config['model_path']).to(device).eval()
+ return tokenizer, model
+
+
+def collate_fn_for_inference(batch):
+ """
+ Custom collate function to handle batches containing Tensors, strings, and Timestamps.
+
+ Args:
+ batch (list): A list of samples, where each sample is the tuple returned by
+ QlibTestDataset.__getitem__.
+
+ Returns:
+ A single tuple containing the batched data.
+ """
+ # Unzip the list of samples into separate lists for each data type
+ x, x_stamp, y_stamp, symbols, timestamps = zip(*batch)
+
+ # Stack the tensors to create a batch
+ x_batch = torch.stack(x, dim=0)
+ x_stamp_batch = torch.stack(x_stamp, dim=0)
+ y_stamp_batch = torch.stack(y_stamp, dim=0)
+
+ # Return the strings and timestamps as lists
+ return x_batch, x_stamp_batch, y_stamp_batch, list(symbols), list(timestamps)
+
+
+def generate_predictions(config: dict, test_data: dict) -> dict[str, pd.DataFrame]:
+ """
+ Runs inference on the test dataset to generate prediction signals.
+
+ Args:
+ config (dict): A dictionary containing inference parameters.
+ test_data (dict): The raw test data loaded from a pickle file.
+
+ Returns:
+ A dictionary where keys are signal types (e.g., 'mean', 'last') and
+ values are DataFrames of predictions (datetime index, symbol columns).
+ """
+ tokenizer, model = load_models(config)
+ device = torch.device(config['device'])
+
+ # Use the Dataset and DataLoader for efficient batching and processing
+ dataset = QlibTestDataset(data=test_data, config=Config())
+ loader = DataLoader(
+ dataset,
+ batch_size=config['batch_size'] // config['sample_count'],
+ shuffle=False,
+ num_workers=os.cpu_count() // 2,
+ collate_fn=collate_fn_for_inference
+ )
+
+ results = defaultdict(list)
+ with torch.no_grad():
+ for x, x_stamp, y_stamp, symbols, timestamps in tqdm(loader, desc="Inference"):
+ preds = auto_regressive_inference(
+ tokenizer, model, x.to(device), x_stamp.to(device), y_stamp.to(device),
+ max_context=config['max_context'], pred_len=config['pred_len'], clip=config['clip'],
+ T=config['T'], top_k=config['top_k'], top_p=config['top_p'], sample_count=config['sample_count']
+ )
+
+ # The 'close' price is at index 3 in `feature_list`
+ last_day_close = x[:, -1, 3].numpy()
+ signals = {
+ 'last': preds[:, -1, 3] - last_day_close,
+ 'mean': np.mean(preds[:, :, 3], axis=1) - last_day_close,
+ 'max': np.max(preds[:, :, 3], axis=1) - last_day_close,
+ 'min': np.min(preds[:, :, 3], axis=1) - last_day_close,
+ }
+
+ for i in range(len(symbols)):
+ for sig_type, sig_values in signals.items():
+ results[sig_type].append((timestamps[i], symbols[i], sig_values[i]))
+
+ print("Post-processing predictions into DataFrames...")
+ prediction_dfs = {}
+ for sig_type, records in results.items():
+ df = pd.DataFrame(records, columns=['datetime', 'instrument', 'score'])
+ pivot_df = df.pivot_table(index='datetime', columns='instrument', values='score')
+ prediction_dfs[sig_type] = pivot_df.sort_index()
+
+ return prediction_dfs
+
+
+# =================================================================================
+# 4. Main Execution
+# =================================================================================
+
+def main():
+ """Main function to set up config, run inference, and execute backtesting."""
+ parser = argparse.ArgumentParser(description="Run Kronos Inference and Backtesting")
+ parser.add_argument("--device", type=str, default="cuda:1", help="Device for inference (e.g., 'cuda:0', 'cpu')")
+ args = parser.parse_args()
+
+ # --- 1. Configuration Setup ---
+ base_config = Config()
+
+ # Create a dedicated dictionary for this run's configuration
+ run_config = {
+ 'device': args.device,
+ 'data_path': base_config.dataset_path,
+ 'result_save_path': base_config.backtest_result_path,
+ 'result_name': base_config.backtest_save_folder_name,
+ 'tokenizer_path': base_config.finetuned_tokenizer_path,
+ 'model_path': base_config.finetuned_predictor_path,
+ 'max_context': base_config.max_context,
+ 'pred_len': base_config.predict_window,
+ 'clip': base_config.clip,
+ 'T': base_config.inference_T,
+ 'top_k': base_config.inference_top_k,
+ 'top_p': base_config.inference_top_p,
+ 'sample_count': base_config.inference_sample_count,
+ 'batch_size': base_config.backtest_batch_size,
+ }
+
+ print("--- Running with Configuration ---")
+ for key, val in run_config.items():
+ print(f"{key:>20}: {val}")
+ print("-" * 35)
+
+ # --- 2. Load Data ---
+ test_data_path = os.path.join(run_config['data_path'], "test_data.pkl")
+ print(f"Loading test data from {test_data_path}...")
+ with open(test_data_path, 'rb') as f:
+ test_data = pickle.load(f)
+    print(f"Loaded test data for {len(test_data)} symbols.")
+ # --- 3. Generate Predictions ---
+ model_preds = generate_predictions(run_config, test_data)
+
+ # --- 4. Save Predictions ---
+ save_dir = os.path.join(run_config['result_save_path'], run_config['result_name'])
+ os.makedirs(save_dir, exist_ok=True)
+ predictions_file = os.path.join(save_dir, "predictions.pkl")
+ print(f"Saving prediction signals to {predictions_file}...")
+ with open(predictions_file, 'wb') as f:
+ pickle.dump(model_preds, f)
+
+ # --- 5. Run Backtesting ---
+ with open(predictions_file, 'rb') as f:
+ model_preds = pickle.load(f)
+
+ backtester = QlibBacktest(base_config)
+ backtester.run_and_plot_results(model_preds)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/finetune/train_predictor.py b/finetune/train_predictor.py
new file mode 100644
index 0000000..1e42587
--- /dev/null
+++ b/finetune/train_predictor.py
@@ -0,0 +1,244 @@
+import os
+import sys
+import json
+import time
+from time import gmtime, strftime
+import torch.distributed as dist
+import torch
+from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+import comet_ml
+
+# Ensure project root is in path
+sys.path.append('../')
+from config import Config
+from dataset import QlibDataset
+from model.kronos import KronosTokenizer, Kronos
+# Import shared utilities
+from utils.training_utils import (
+ setup_ddp,
+ cleanup_ddp,
+ set_seed,
+ get_model_size,
+ format_time
+)
+
+
+def create_dataloaders(config: dict, rank: int, world_size: int):
+ """
+ Creates and returns distributed dataloaders for training and validation.
+
+ Args:
+ config (dict): A dictionary of configuration parameters.
+ rank (int): The global rank of the current process.
+ world_size (int): The total number of processes.
+
+ Returns:
+ tuple: (train_loader, val_loader, train_dataset, valid_dataset).
+ """
+ print(f"[Rank {rank}] Creating distributed dataloaders...")
+ train_dataset = QlibDataset('train')
+ valid_dataset = QlibDataset('val')
+ print(f"[Rank {rank}] Train dataset size: {len(train_dataset)}, Validation dataset size: {len(valid_dataset)}")
+
+ train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True)
+ val_sampler = DistributedSampler(valid_dataset, num_replicas=world_size, rank=rank, shuffle=False)
+
+ train_loader = DataLoader(
+ train_dataset, batch_size=config['batch_size'], sampler=train_sampler,
+ num_workers=config.get('num_workers', 2), pin_memory=True, drop_last=True
+ )
+ val_loader = DataLoader(
+ valid_dataset, batch_size=config['batch_size'], sampler=val_sampler,
+ num_workers=config.get('num_workers', 2), pin_memory=True, drop_last=False
+ )
+ return train_loader, val_loader, train_dataset, valid_dataset
+
+
+def train_model(model, tokenizer, device, config, save_dir, logger, rank, world_size):
+ """
+ The main training and validation loop for the predictor.
+ """
+ start_time = time.time()
+ if rank == 0:
+ effective_bs = config['batch_size'] * world_size
+ print(f"Effective BATCHSIZE per GPU: {config['batch_size']}, Total: {effective_bs}")
+
+ train_loader, val_loader, train_dataset, valid_dataset = create_dataloaders(config, rank, world_size)
+
+ optimizer = torch.optim.AdamW(
+ model.parameters(),
+ lr=config['predictor_learning_rate'],
+ betas=(config['adam_beta1'], config['adam_beta2']),
+ weight_decay=config['adam_weight_decay']
+ )
+ scheduler = torch.optim.lr_scheduler.OneCycleLR(
+ optimizer, max_lr=config['predictor_learning_rate'],
+ steps_per_epoch=len(train_loader), epochs=config['epochs'],
+ pct_start=0.03, div_factor=10
+ )
+
+ best_val_loss = float('inf')
+ dt_result = {}
+ batch_idx_global = 0
+
+ for epoch_idx in range(config['epochs']):
+ epoch_start_time = time.time()
+ model.train()
+ train_loader.sampler.set_epoch(epoch_idx)
+
+ train_dataset.set_epoch_seed(epoch_idx * 10000 + rank)
+ valid_dataset.set_epoch_seed(0)
+
+ for i, (batch_x, batch_x_stamp) in enumerate(train_loader):
+ batch_x = batch_x.squeeze(0).to(device, non_blocking=True)
+ batch_x_stamp = batch_x_stamp.squeeze(0).to(device, non_blocking=True)
+
+ # Tokenize input data on-the-fly
+ with torch.no_grad():
+ token_seq_0, token_seq_1 = tokenizer.encode(batch_x, half=True)
+
+ # Prepare inputs and targets for the language model
+ token_in = [token_seq_0[:, :-1], token_seq_1[:, :-1]]
+ token_out = [token_seq_0[:, 1:], token_seq_1[:, 1:]]
+
+ # Forward pass and loss calculation
+ logits = model(token_in[0], token_in[1], batch_x_stamp[:, :-1, :])
+ loss, s1_loss, s2_loss = model.module.head.compute_loss(logits[0], logits[1], token_out[0], token_out[1])
+
+ # Backward pass and optimization
+ optimizer.zero_grad()
+ loss.backward()
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=3.0)
+ optimizer.step()
+ scheduler.step()
+
+ # Logging (Master Process Only)
+ if rank == 0 and (batch_idx_global + 1) % config['log_interval'] == 0:
+ lr = optimizer.param_groups[0]['lr']
+ print(
+ f"[Rank {rank}, Epoch {epoch_idx + 1}/{config['epochs']}, Step {i + 1}/{len(train_loader)}] "
+ f"LR {lr:.6f}, Loss: {loss.item():.4f}"
+ )
+ if rank == 0 and logger:
+ lr = optimizer.param_groups[0]['lr']
+ logger.log_metric('train_predictor_loss_batch', loss.item(), step=batch_idx_global)
+ logger.log_metric('train_S1_loss_each_batch', s1_loss.item(), step=batch_idx_global)
+ logger.log_metric('train_S2_loss_each_batch', s2_loss.item(), step=batch_idx_global)
+ logger.log_metric('predictor_learning_rate', lr, step=batch_idx_global)
+
+ batch_idx_global += 1
+
+ # --- Validation Loop ---
+ model.eval()
+ tot_val_loss_sum_rank = 0.0
+ val_batches_processed_rank = 0
+ with torch.no_grad():
+ for batch_x, batch_x_stamp in val_loader:
+ batch_x = batch_x.squeeze(0).to(device, non_blocking=True)
+ batch_x_stamp = batch_x_stamp.squeeze(0).to(device, non_blocking=True)
+
+ token_seq_0, token_seq_1 = tokenizer.encode(batch_x, half=True)
+ token_in = [token_seq_0[:, :-1], token_seq_1[:, :-1]]
+ token_out = [token_seq_0[:, 1:], token_seq_1[:, 1:]]
+
+ logits = model(token_in[0], token_in[1], batch_x_stamp[:, :-1, :])
+ val_loss, _, _ = model.module.head.compute_loss(logits[0], logits[1], token_out[0], token_out[1])
+
+ tot_val_loss_sum_rank += val_loss.item()
+ val_batches_processed_rank += 1
+
+ # Reduce validation metrics
+ val_loss_sum_tensor = torch.tensor(tot_val_loss_sum_rank, device=device)
+ val_batches_tensor = torch.tensor(val_batches_processed_rank, device=device)
+ dist.all_reduce(val_loss_sum_tensor, op=dist.ReduceOp.SUM)
+ dist.all_reduce(val_batches_tensor, op=dist.ReduceOp.SUM)
+
+ avg_val_loss = val_loss_sum_tensor.item() / val_batches_tensor.item() if val_batches_tensor.item() > 0 else 0
+
+ # --- End of Epoch Summary & Checkpointing (Master Process Only) ---
+ if rank == 0:
+ print(f"\n--- Epoch {epoch_idx + 1}/{config['epochs']} Summary ---")
+ print(f"Validation Loss: {avg_val_loss:.4f}")
+ print(f"Time This Epoch: {format_time(time.time() - epoch_start_time)}")
+ print(f"Total Time Elapsed: {format_time(time.time() - start_time)}\n")
+ if logger:
+ logger.log_metric('val_predictor_loss_epoch', avg_val_loss, epoch=epoch_idx)
+
+ if avg_val_loss < best_val_loss:
+ best_val_loss = avg_val_loss
+ save_path = f"{save_dir}/checkpoints/best_model"
+ model.module.save_pretrained(save_path)
+ print(f"Best model saved to {save_path} (Val Loss: {best_val_loss:.4f})")
+
+ dist.barrier()
+
+ dt_result['best_val_loss'] = best_val_loss
+ return dt_result
+
+
+def main(config: dict):
+ """Main function to orchestrate the DDP training process."""
+ rank, world_size, local_rank = setup_ddp()
+ device = torch.device(f"cuda:{local_rank}")
+ set_seed(config['seed'], rank)
+
+ save_dir = os.path.join(config['save_path'], config['predictor_save_folder_name'])
+
+ # Logger and summary setup (master process only)
+ comet_logger, master_summary = None, {}
+ if rank == 0:
+ os.makedirs(os.path.join(save_dir, 'checkpoints'), exist_ok=True)
+ master_summary = {
+ 'start_time': strftime("%Y-%m-%dT%H-%M-%S", gmtime()),
+ 'save_directory': save_dir,
+ 'world_size': world_size,
+ }
+ if config['use_comet']:
+ comet_logger = comet_ml.Experiment(
+ api_key=config['comet_config']['api_key'],
+ project_name=config['comet_config']['project_name'],
+ workspace=config['comet_config']['workspace'],
+ )
+ comet_logger.add_tag(config['comet_tag'])
+ comet_logger.set_name(config['comet_name'])
+ comet_logger.log_parameters(config)
+ print("Comet Logger Initialized.")
+
+ dist.barrier()
+
+ # Model Initialization
+ tokenizer = KronosTokenizer.from_pretrained(config['finetuned_tokenizer_path'])
+ tokenizer.eval().to(device)
+
+ model = Kronos.from_pretrained(config['pretrained_predictor_path'])
+ model.to(device)
+ model = DDP(model, device_ids=[local_rank], find_unused_parameters=False)
+
+ if rank == 0:
+ print(f"Predictor Model Size: {get_model_size(model.module)}")
+
+ # Start Training
+ dt_result = train_model(
+ model, tokenizer, device, config, save_dir, comet_logger, rank, world_size
+ )
+
+ if rank == 0:
+ master_summary['final_result'] = dt_result
+ with open(os.path.join(save_dir, 'summary.json'), 'w') as f:
+ json.dump(master_summary, f, indent=4)
+ print('Training finished. Summary file saved.')
+ if comet_logger: comet_logger.end()
+
+ cleanup_ddp()
+
+
+if __name__ == '__main__':
+ # Usage: torchrun --standalone --nproc_per_node=NUM_GPUS train_predictor.py
+ if "WORLD_SIZE" not in os.environ:
+ raise RuntimeError("This script must be launched with `torchrun`.")
+
+ config_instance = Config()
+ main(config_instance.__dict__)
diff --git a/finetune/train_tokenizer.py b/finetune/train_tokenizer.py
new file mode 100644
index 0000000..2fe28cf
--- /dev/null
+++ b/finetune/train_tokenizer.py
@@ -0,0 +1,281 @@
+import os
+import sys
+import json
+import time
+from time import gmtime, strftime
+import argparse
+import datetime
+import torch.distributed as dist
+import torch
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+import comet_ml
+
+# Ensure project root is in path
+sys.path.append("../")
+from config import Config
+from dataset import QlibDataset
+from model.kronos import KronosTokenizer
+# Import shared utilities
+from utils.training_utils import (
+ setup_ddp,
+ cleanup_ddp,
+ set_seed,
+ get_model_size,
+ format_time,
+)
+
+
+def create_dataloaders(config: dict, rank: int, world_size: int):
+ """
+ Creates and returns distributed dataloaders for training and validation.
+
+ Args:
+ config (dict): A dictionary of configuration parameters.
+ rank (int): The global rank of the current process.
+ world_size (int): The total number of processes.
+
+ Returns:
+ tuple: A tuple containing (train_loader, val_loader, train_dataset, valid_dataset).
+ """
+ print(f"[Rank {rank}] Creating distributed dataloaders...")
+ train_dataset = QlibDataset('train')
+ valid_dataset = QlibDataset('val')
+ print(f"[Rank {rank}] Train dataset size: {len(train_dataset)}, Validation dataset size: {len(valid_dataset)}")
+
+ train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True)
+ val_sampler = DistributedSampler(valid_dataset, num_replicas=world_size, rank=rank, shuffle=False)
+
+ train_loader = DataLoader(
+ train_dataset,
+ batch_size=config['batch_size'],
+ sampler=train_sampler,
+ shuffle=False, # Shuffle is handled by the sampler
+ num_workers=config.get('num_workers', 2),
+ pin_memory=True,
+ drop_last=True
+ )
+ val_loader = DataLoader(
+ valid_dataset,
+ batch_size=config['batch_size'],
+ sampler=val_sampler,
+ shuffle=False,
+ num_workers=config.get('num_workers', 2),
+ pin_memory=True,
+ drop_last=False
+ )
+ print(f"[Rank {rank}] Dataloaders created. Train steps/epoch: {len(train_loader)}, Val steps: {len(val_loader)}")
+ return train_loader, val_loader, train_dataset, valid_dataset
+
+
+def train_model(model, device, config, save_dir, logger, rank, world_size):
+ """
+ The main training and validation loop for the tokenizer.
+
+ Args:
+ model (DDP): The DDP-wrapped model to train.
+ device (torch.device): The device for the current process.
+ config (dict): Configuration dictionary.
+ save_dir (str): Directory to save checkpoints.
+ logger (comet_ml.Experiment): Comet logger instance.
+ rank (int): Global rank of the process.
+ world_size (int): Total number of processes.
+
+ Returns:
+ tuple: A tuple containing the trained model and a dictionary of results.
+ """
+ start_time = time.time()
+ if rank == 0:
+ effective_bs = config['batch_size'] * world_size * config['accumulation_steps']
+ print(f"[Rank {rank}] BATCHSIZE (per GPU): {config['batch_size']}")
+ print(f"[Rank {rank}] Effective total batch size: {effective_bs}")
+
+ train_loader, val_loader, train_dataset, valid_dataset = create_dataloaders(config, rank, world_size)
+
+ optimizer = torch.optim.AdamW(
+ model.parameters(),
+ lr=config['tokenizer_learning_rate'],
+ weight_decay=config['adam_weight_decay']
+ )
+
+ scheduler = torch.optim.lr_scheduler.OneCycleLR(
+ optimizer=optimizer,
+ max_lr=config['tokenizer_learning_rate'],
+ steps_per_epoch=len(train_loader),
+ epochs=config['epochs'],
+ pct_start=0.03,
+ div_factor=10
+ )
+
+ best_val_loss = float('inf')
+ dt_result = {}
+ batch_idx_global_train = 0
+
+ for epoch_idx in range(config['epochs']):
+ epoch_start_time = time.time()
+ model.train()
+ train_loader.sampler.set_epoch(epoch_idx)
+
+ # Set dataset seeds for reproducible sampling
+ train_dataset.set_epoch_seed(epoch_idx * 10000 + rank)
+ valid_dataset.set_epoch_seed(0) # Keep validation sampling consistent
+
+ for i, (ori_batch_x, _) in enumerate(train_loader):
+ ori_batch_x = ori_batch_x.squeeze(0).to(device, non_blocking=True)
+
+ # --- Gradient Accumulation Loop ---
+ current_batch_total_loss = 0.0
+ for j in range(config['accumulation_steps']):
+ start_idx = j * (ori_batch_x.shape[0] // config['accumulation_steps'])
+ end_idx = (j + 1) * (ori_batch_x.shape[0] // config['accumulation_steps'])
+ batch_x = ori_batch_x[start_idx:end_idx]
+
+ # Forward pass
+ zs, bsq_loss, _, _ = model(batch_x)
+ z_pre, z = zs
+
+ # Loss calculation
+ recon_loss_pre = F.mse_loss(z_pre, batch_x)
+ recon_loss_all = F.mse_loss(z, batch_x)
+ recon_loss = recon_loss_pre + recon_loss_all
+ loss = (recon_loss + bsq_loss) / 2 # Assuming w_1=w_2=1
+
+ loss_scaled = loss / config['accumulation_steps']
+ current_batch_total_loss += loss.item()
+ loss_scaled.backward()
+
+ # --- Optimizer Step after Accumulation ---
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)
+ optimizer.step()
+ scheduler.step()
+ optimizer.zero_grad()
+
+ # --- Logging (Master Process Only) ---
+ if rank == 0 and (batch_idx_global_train + 1) % config['log_interval'] == 0:
+ avg_loss = current_batch_total_loss / config['accumulation_steps']
+ print(
+ f"[Rank {rank}, Epoch {epoch_idx + 1}/{config['epochs']}, Step {i + 1}/{len(train_loader)}] "
+ f"LR {optimizer.param_groups[0]['lr']:.6f}, Loss: {avg_loss:.4f}"
+ )
+ if rank == 0 and logger:
+ avg_loss = current_batch_total_loss / config['accumulation_steps']
+ logger.log_metric('train_tokenizer_loss_batch', avg_loss, step=batch_idx_global_train)
+ logger.log_metric(f'train_vqvae_vq_loss_each_batch', bsq_loss.item(), step=batch_idx_global_train)
+ logger.log_metric(f'train_recon_loss_pre_each_batch', recon_loss_pre.item(), step=batch_idx_global_train)
+ logger.log_metric(f'train_recon_loss_each_batch', recon_loss_all.item(), step=batch_idx_global_train)
+ logger.log_metric('tokenizer_learning_rate', optimizer.param_groups[0]["lr"], step=batch_idx_global_train)
+
+ batch_idx_global_train += 1
+
+ # --- Validation Loop ---
+ model.eval()
+ tot_val_loss_sum_rank = 0.0
+ val_sample_count_rank = 0
+ with torch.no_grad():
+ for ori_batch_x, _ in val_loader:
+ ori_batch_x = ori_batch_x.squeeze(0).to(device, non_blocking=True)
+ zs, _, _, _ = model(ori_batch_x)
+ _, z = zs
+ val_loss_item = F.mse_loss(z, ori_batch_x)
+
+ tot_val_loss_sum_rank += val_loss_item.item() * ori_batch_x.size(0)
+ val_sample_count_rank += ori_batch_x.size(0)
+
+ # Reduce validation losses from all processes
+ val_loss_sum_tensor = torch.tensor(tot_val_loss_sum_rank, device=device)
+ val_count_tensor = torch.tensor(val_sample_count_rank, device=device)
+ dist.all_reduce(val_loss_sum_tensor, op=dist.ReduceOp.SUM)
+ dist.all_reduce(val_count_tensor, op=dist.ReduceOp.SUM)
+
+ avg_val_loss = val_loss_sum_tensor.item() / val_count_tensor.item() if val_count_tensor.item() > 0 else 0
+
+ # --- End of Epoch Summary & Checkpointing (Master Process Only) ---
+ if rank == 0:
+ print(f"\n--- Epoch {epoch_idx + 1}/{config['epochs']} Summary ---")
+ print(f"Validation Loss: {avg_val_loss:.4f}")
+ print(f"Time This Epoch: {format_time(time.time() - epoch_start_time)}")
+ print(f"Total Time Elapsed: {format_time(time.time() - start_time)}\n")
+ if logger:
+ logger.log_metric('val_tokenizer_loss_epoch', avg_val_loss, epoch=epoch_idx)
+
+ if avg_val_loss < best_val_loss:
+ best_val_loss = avg_val_loss
+ save_path = f"{save_dir}/checkpoints/best_model"
+ model.module.save_pretrained(save_path)
+ print(f"Best model saved to {save_path} (Val Loss: {best_val_loss:.4f})")
+ if logger:
+ logger.log_model("best_model", save_path)
+
+ dist.barrier() # Ensure all processes finish the epoch before starting the next one.
+
+ dt_result['best_val_loss'] = best_val_loss
+ return model, dt_result
+
+
+def main(config: dict):
+ """
+ Main function to orchestrate the DDP training process.
+ """
+ rank, world_size, local_rank = setup_ddp()
+ device = torch.device(f"cuda:{local_rank}")
+ set_seed(config['seed'], rank)
+
+ save_dir = os.path.join(config['save_path'], config['tokenizer_save_folder_name'])
+
+ # Logger and summary setup (master process only)
+ comet_logger, master_summary = None, {}
+ if rank == 0:
+ os.makedirs(os.path.join(save_dir, 'checkpoints'), exist_ok=True)
+ master_summary = {
+ 'start_time': strftime("%Y-%m-%dT%H-%M-%S", gmtime()),
+ 'save_directory': save_dir,
+ 'world_size': world_size,
+ }
+ if config['use_comet']:
+ comet_logger = comet_ml.Experiment(
+ api_key=config['comet_config']['api_key'],
+ project_name=config['comet_config']['project_name'],
+ workspace=config['comet_config']['workspace'],
+ )
+ comet_logger.add_tag(config['comet_tag'])
+ comet_logger.set_name(config['comet_name'])
+ comet_logger.log_parameters(config)
+ print("Comet Logger Initialized.")
+
+ dist.barrier() # Ensure save directory is created before proceeding
+
+ # Model Initialization
+ model = KronosTokenizer.from_pretrained(config['pretrained_tokenizer_path'])
+ model.to(device)
+ model = DDP(model, device_ids=[local_rank], find_unused_parameters=False)
+
+ if rank == 0:
+ print(f"Model Size: {get_model_size(model.module)}")
+
+ # Start Training
+ _, dt_result = train_model(
+ model, device, config, save_dir, comet_logger, rank, world_size
+ )
+
+ # Finalize and save summary (master process only)
+ if rank == 0:
+ master_summary['final_result'] = dt_result
+ with open(os.path.join(save_dir, 'summary.json'), 'w') as f:
+ json.dump(master_summary, f, indent=4)
+ print('Training finished. Summary file saved.')
+ if comet_logger:
+ comet_logger.end()
+
+ cleanup_ddp()
+
+
+if __name__ == '__main__':
+ # Usage: torchrun --standalone --nproc_per_node=NUM_GPUS train_tokenizer.py
+ if "WORLD_SIZE" not in os.environ:
+ raise RuntimeError("This script must be launched with `torchrun`.")
+
+ config_instance = Config()
+ main(config_instance.__dict__)
diff --git a/finetune/utils/__init__.py b/finetune/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/finetune/utils/training_utils.py b/finetune/utils/training_utils.py
new file mode 100644
index 0000000..8756322
--- /dev/null
+++ b/finetune/utils/training_utils.py
@@ -0,0 +1,118 @@
+import os
+import random
+import datetime
+import numpy as np
+import torch
+import torch.distributed as dist
+
+
+def setup_ddp():
+ """
+ Initializes the distributed data parallel environment.
+
+ This function relies on environment variables set by `torchrun` or a similar
+ launcher. It initializes the process group and sets the CUDA device for the
+ current process.
+
+ Returns:
+ tuple: A tuple containing (rank, world_size, local_rank).
+ """
+ if not dist.is_available():
+ raise RuntimeError("torch.distributed is not available.")
+
+ dist.init_process_group(backend="nccl")
+ rank = int(os.environ["RANK"])
+ world_size = int(os.environ["WORLD_SIZE"])
+ local_rank = int(os.environ["LOCAL_RANK"])
+ torch.cuda.set_device(local_rank)
+ print(
+ f"[DDP Setup] Global Rank: {rank}/{world_size}, "
+ f"Local Rank (GPU): {local_rank} on device {torch.cuda.current_device()}"
+ )
+ return rank, world_size, local_rank
+
+
+def cleanup_ddp():
+ """Cleans up the distributed process group."""
+ if dist.is_initialized():
+ dist.destroy_process_group()
+
+
+def set_seed(seed: int, rank: int = 0):
+ """
+ Sets the random seed for reproducibility across all relevant libraries.
+
+ Args:
+ seed (int): The base seed value.
+ rank (int): The process rank, used to ensure different processes have
+ different seeds, which can be important for data loading.
+ """
+ actual_seed = seed + rank
+ random.seed(actual_seed)
+ np.random.seed(actual_seed)
+ torch.manual_seed(actual_seed)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed_all(actual_seed)
+ # The two lines below can impact performance, so they are often
+ # reserved for final experiments where reproducibility is critical.
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+
+
+def get_model_size(model: torch.nn.Module) -> str:
+ """
+ Calculates the number of trainable parameters in a PyTorch model and returns
+ it as a human-readable string.
+
+ Args:
+ model (torch.nn.Module): The PyTorch model.
+
+ Returns:
+ str: A string representing the model size (e.g., "175.0B", "7.1M", "50.5K").
+ """
+ total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+ if total_params >= 1e9:
+ return f"{total_params / 1e9:.1f}B" # Billions
+ elif total_params >= 1e6:
+ return f"{total_params / 1e6:.1f}M" # Millions
+ else:
+ return f"{total_params / 1e3:.1f}K" # Thousands
+
+
+def reduce_tensor(tensor: torch.Tensor, world_size: int, op=dist.ReduceOp.SUM) -> torch.Tensor:
+ """
+ Reduces a tensor's value across all processes in a distributed setup.
+
+ Args:
+ tensor (torch.Tensor): The tensor to be reduced.
+ world_size (int): The total number of processes.
+ op (dist.ReduceOp, optional): The reduction operation (SUM, AVG, etc.).
+ Defaults to dist.ReduceOp.SUM.
+
+ Returns:
+ torch.Tensor: The reduced tensor, which will be identical on all processes.
+ """
+ rt = tensor.clone()
+ dist.all_reduce(rt, op=op)
+ # Note: `dist.ReduceOp.AVG` is available in newer torch versions.
+ # For compatibility, manual division is sometimes used after a SUM.
+ if op == dist.ReduceOp.AVG:
+ rt /= world_size
+ return rt
+
+
+def format_time(seconds: float) -> str:
+ """
+ Formats a duration in seconds into a human-readable H:M:S string.
+
+ Args:
+ seconds (float): The total seconds.
+
+ Returns:
+ str: The formatted time string (e.g., "0:15:32").
+ """
+ return str(datetime.timedelta(seconds=int(seconds)))
+
+
+