Merge pull request #138 from Luciferbobo/master

add CSV-based finetuning pipeline for Kronos models
This commit is contained in:
ShiYu 2025-10-12 17:06:30 +08:00 committed by GitHub
commit 082ab7ef62
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 95678 additions and 1 deletion

1
.gitignore vendored

@@ -45,7 +45,6 @@ Desktop.ini
 # Data files (large files)
 *.feather
-*.csv
 *.parquet
 *.h5
 *.hdf5

120
finetune_csv/README.md Normal file

@@ -0,0 +1,120 @@
# Kronos Fine-tuning on Custom CSV Datasets
This module provides a comprehensive pipeline for fine-tuning Kronos models on your own CSV-formatted financial data. It supports both sequential training (tokenizer followed by predictor) and individual component training, with full distributed training capabilities.
## 1. Data Preparation
### Required Data Format
Your CSV file must contain the following columns:
- `timestamps`: DateTime stamps for each data point
- `open`: Opening price
- `high`: Highest price
- `low`: Lowest price
- `close`: Closing price
- `volume`: Trading volume
- `amount`: Trading amount
(`volume` and `amount` may be set to 0 if that data is unavailable.)
### Sample Data Format
| timestamps | open | close | high | low | volume | amount |
|------------|------|-------|------|-----|--------|--------|
| 2019/11/26 9:35 | 182.45215 | 184.45215 | 184.95215 | 182.45215 | 15136000 | 0 |
| 2019/11/26 9:40 | 184.35215 | 183.85215 | 184.55215 | 183.45215 | 4433300 | 0 |
| 2019/11/26 9:45 | 183.85215 | 183.35215 | 183.95215 | 182.95215 | 3070900 | 0 |
> **Reference**: Check `data/HK_ali_09988_kline_5min_all.csv` for a complete example of the proper data format.
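As a quick sanity check before training, the required columns can be validated with pandas. This is a minimal illustrative sketch; `validate_kline_csv` and `REQUIRED_COLUMNS` are not part of this module:

```python
import pandas as pd

# Columns the fine-tuning pipeline expects (see the list above).
REQUIRED_COLUMNS = ["timestamps", "open", "high", "low", "close", "volume", "amount"]

def validate_kline_csv(path: str) -> pd.DataFrame:
    """Load a K-line CSV and check that it has the expected columns."""
    df = pd.read_csv(path, parse_dates=["timestamps"])
    missing = [c for c in REQUIRED_COLUMNS if c not in df.columns]
    if missing:
        raise ValueError(f"CSV is missing required columns: {missing}")
    # Sort chronologically, matching what the training dataset does internally.
    return df.sort_values("timestamps").reset_index(drop=True)
```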
## 2. Config Preparation
Edit the data path and pretrained model paths, then set your training parameters.
```yaml
# Data configuration
data:
  data_path: "/path/to/your/data.csv"
  lookback_window: 512 # Historical data points to use
  predict_window: 48 # Future points to predict
  max_context: 512 # Maximum context length
  ...
```
There are additional settings as well; see `configs/config_ali09988_candle-5min.yaml` for more detailed inline comments.
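Programmatically, these values are just nested YAML keys. A minimal sketch of reading them with PyYAML (the `load_data_config` helper is illustrative, not part of the module):

```python
import yaml

def load_data_config(text: str) -> dict:
    """Parse a YAML config snippet and return its `data` section."""
    cfg = yaml.safe_load(text)
    return cfg.get("data", {})

snippet = """
data:
  data_path: "/path/to/your/data.csv"
  lookback_window: 512
  predict_window: 48
  max_context: 512
"""
data_cfg = load_data_config(snippet)
```

The shipped `config_loader.py` wraps the same idea in a `ConfigLoader` class with dotted-key access and dynamic path resolution.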
## 3. Training
### Method 1: Sequential Training (Recommended)
The `train_sequential.py` script handles the complete training pipeline automatically:
```bash
# Complete training (tokenizer + predictor)
python train_sequential.py --config configs/config_ali09988_candle-5min.yaml
# Skip existing models
python train_sequential.py --config configs/config_ali09988_candle-5min.yaml --skip-existing
# Only train tokenizer
python train_sequential.py --config configs/config_ali09988_candle-5min.yaml --skip-basemodel
# Only train predictor
python train_sequential.py --config configs/config_ali09988_candle-5min.yaml --skip-tokenizer
```
### Method 2: Individual Component Training
Train each component separately for more control:
```bash
# Step 1: Train tokenizer
python finetune_tokenizer.py --config configs/config_ali09988_candle-5min.yaml
# Step 2: Train predictor (requires fine-tuned tokenizer)
python finetune_base_model.py --config configs/config_ali09988_candle-5min.yaml
```
### DDP Training
For faster training on multiple GPUs:
```bash
# Set communication backend (nccl for NVIDIA GPUs, gloo for CPU/mixed)
DIST_BACKEND=nccl \
torchrun --standalone --nproc_per_node=8 train_sequential.py --config configs/config_ali09988_candle-5min.yaml
```
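The backend choice in the command above follows a simple rule. A sketch of the selection logic (the `pick_backend` helper is illustrative, not a function in these scripts):

```python
import os

def pick_backend(env=None, cuda_available=False):
    """Choose a torch.distributed backend: an explicit DIST_BACKEND env var wins,
    otherwise use nccl on NVIDIA GPUs and gloo for CPU or mixed setups."""
    env = dict(os.environ) if env is None else env
    return env.get("DIST_BACKEND") or ("nccl" if cuda_available else "gloo")
```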
## 4. Training Results
The training process generates several outputs:
### Model Checkpoints
- **Tokenizer**: Saved to `{base_save_path}/{exp_name}/tokenizer/best_model/`
- **Predictor**: Saved to `{base_save_path}/{exp_name}/basemodel/best_model/`
### Training Logs
- **Console output**: Real-time training progress and metrics
- **Log files**: Detailed logs saved to `{base_save_path}/logs/`
- **Validation tracking**: Best models are saved based on validation loss
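The checkpoint directories above can be reconstructed from the config values. A small illustrative helper (`best_model_dirs` is not part of the module):

```python
import os

def best_model_dirs(base_save_path: str, exp_name: str) -> dict:
    """Return the best-model checkpoint directories for a finished run,
    following the {base_save_path}/{exp_name}/... layout described above."""
    root = os.path.join(base_save_path, exp_name)
    return {
        "tokenizer": os.path.join(root, "tokenizer", "best_model"),
        "predictor": os.path.join(root, "basemodel", "best_model"),
    }
```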
## 5. Prediction Visualization
The following images show example training results on Alibaba (HK stock, 09988) data:
![Training Result 1](examples/HK_ali_09988_kline_5min_all_historical_20250919_073929.png)
![Training Result 2](examples/HK_ali_09988_kline_5min_all_historical_20250919_073944.png)
![Training Result 3](examples/HK_ali_09988_kline_5min_all_historical_20250919_074012.png)
![Training Result 4](examples/HK_ali_09988_kline_5min_all_historical_20250919_074042.png)
![Training Result 5](examples/HK_ali_09988_kline_5min_all_historical_20250919_074251.png)

118
finetune_csv/README_CN.md Normal file

@@ -0,0 +1,118 @@
# Kronos Fine-tuning with Custom CSV Datasets
This is a complete pipeline for fine-tuning Kronos models on custom CSV-formatted data. It covers sequential training (tokenizer first, then predictor) as well as individual component training, with distributed training support.
## 1. Data Preparation
### Data Format
The CSV file must contain the following columns:
- `timestamps`: Timestamp of each data point
- `open`: Opening price
- `high`: Highest price
- `low`: Lowest price
- `close`: Closing price
- `volume`: Trading volume
- `amount`: Trading amount
(`volume` and `amount` may be all zeros if that data is unavailable.)
### Sample Data Format
| timestamps | open | close | high | low | volume | amount |
|------------|------|-------|------|-----|--------|--------|
| 2019/11/26 9:35 | 182.45215 | 184.45215 | 184.95215 | 182.45215 | 15136000 | 0 |
| 2019/11/26 9:40 | 184.35215 | 183.85215 | 184.55215 | 183.45215 | 4433300 | 0 |
| 2019/11/26 9:45 | 183.85215 | 183.35215 | 183.95215 | 182.95215 | 3070900 | 0 |
> **Reference example**: `data/HK_ali_09988_kline_5min_all.csv`
## 2. Config Preparation
Edit `data_path` and the pretrained model paths; other training parameters can be tuned as needed:
```yaml
# Data configuration
data:
  data_path: "/path/to/your/data.csv"
  lookback_window: 512 # Historical data points to use
  predict_window: 48 # Future points to predict
  max_context: 512 # Maximum context length
  ...
```
There are additional settings as well; see `configs/config_ali09988_candle-5min.yaml` for more detailed comments.
## 3. Training
### Method 1: Sequential Training
The `train_sequential.py` script handles the complete training pipeline automatically:
```bash
# Complete training (tokenizer + predictor)
python train_sequential.py --config configs/config_ali09988_candle-5min.yaml
# Skip existing models
python train_sequential.py --config configs/config_ali09988_candle-5min.yaml --skip-existing
# Only train the tokenizer
python train_sequential.py --config configs/config_ali09988_candle-5min.yaml --skip-basemodel
# Only train the predictor
python train_sequential.py --config configs/config_ali09988_candle-5min.yaml --skip-tokenizer
```
### Method 2: Individual Component Training
Each component can be trained separately:
```bash
# Step 1: Train the tokenizer
python finetune_tokenizer.py --config configs/config_ali09988_candle-5min.yaml
# Step 2: Train the predictor (requires the fine-tuned tokenizer)
python finetune_base_model.py --config configs/config_ali09988_candle-5min.yaml
```
### DDP Training
With multiple GPUs, DDP can be enabled to speed up training:
```bash
# Set the communication backend (nccl for NVIDIA GPUs, gloo for CPU/mixed)
DIST_BACKEND=nccl \
torchrun --standalone --nproc_per_node=8 train_sequential.py --config configs/config_ali09988_candle-5min.yaml
```
## 4. Training Results
The training process generates the following outputs:
### Model Checkpoints
- **Tokenizer**: Saved to `{base_save_path}/{exp_name}/tokenizer/best_model/`
- **Predictor**: Saved to `{base_save_path}/{exp_name}/basemodel/best_model/`
### Training Logs
- **Console output**: Real-time training progress and metrics
- **Log files**: Detailed logs saved to `{base_save_path}/logs/`
- **Validation tracking**: Best models are saved based on validation loss
## 5. Prediction Visualization
The following images show example results after fine-tuning Kronos on Alibaba (HK) stock data:
![Training Result 1](examples/HK_ali_09988_kline_5min_all_historical_20250919_073929.png)
![Training Result 2](examples/HK_ali_09988_kline_5min_all_historical_20250919_073944.png)
![Training Result 3](examples/HK_ali_09988_kline_5min_all_historical_20250919_074012.png)
![Training Result 4](examples/HK_ali_09988_kline_5min_all_historical_20250919_074042.png)
![Training Result 5](examples/HK_ali_09988_kline_5min_all_historical_20250919_074251.png)


@@ -0,0 +1,267 @@
import os
import yaml
from typing import Dict, Any
class ConfigLoader:
def __init__(self, config_path: str):
self.config_path = config_path
self.config = self._load_config()
def _load_config(self) -> Dict[str, Any]:
if not os.path.exists(self.config_path):
raise FileNotFoundError(f"config file not found: {self.config_path}")
with open(self.config_path, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
config = self._resolve_dynamic_paths(config)
return config
def _resolve_dynamic_paths(self, config: Dict[str, Any]) -> Dict[str, Any]:
exp_name = config.get('model_paths', {}).get('exp_name', '')
if not exp_name:
return config
base_path = config.get('model_paths', {}).get('base_path', '')
path_templates = {
'base_save_path': f"{base_path}/{exp_name}",
'finetuned_tokenizer': f"{base_path}/{exp_name}/tokenizer/best_model"
}
if 'model_paths' in config:
for key, template in path_templates.items():
if key in config['model_paths']:
# only use template when the original value is empty string
current_value = config['model_paths'][key]
if current_value == "" or current_value is None:
config['model_paths'][key] = template
else:
# if the original value is not empty, use template to replace the {exp_name} placeholder
if isinstance(current_value, str) and '{exp_name}' in current_value:
config['model_paths'][key] = current_value.format(exp_name=exp_name)
return config
def get(self, key: str, default=None):
keys = key.split('.')
value = self.config
try:
for k in keys:
value = value[k]
return value
except (KeyError, TypeError):
return default
def get_data_config(self) -> Dict[str, Any]:
return self.config.get('data', {})
def get_training_config(self) -> Dict[str, Any]:
return self.config.get('training', {})
def get_model_paths(self) -> Dict[str, str]:
return self.config.get('model_paths', {})
def get_experiment_config(self) -> Dict[str, Any]:
return self.config.get('experiment', {})
def get_device_config(self) -> Dict[str, Any]:
return self.config.get('device', {})
def get_distributed_config(self) -> Dict[str, Any]:
return self.config.get('distributed', {})
def update_config(self, updates: Dict[str, Any]):
def update_nested_dict(d, u):
for k, v in u.items():
if isinstance(v, dict):
d[k] = update_nested_dict(d.get(k, {}), v)
else:
d[k] = v
return d
self.config = update_nested_dict(self.config, updates)
def save_config(self, save_path: str = None):
if save_path is None:
save_path = self.config_path
with open(save_path, 'w', encoding='utf-8') as f:
yaml.dump(self.config, f, default_flow_style=False, allow_unicode=True, indent=2)
def print_config(self):
print("=" * 50)
print("Current configuration:")
print("=" * 50)
print(yaml.dump(self.config, default_flow_style=False, allow_unicode=True, indent=2))
print("=" * 50)
class CustomFinetuneConfig:
def __init__(self, config_path: str = None):
if config_path is None:
config_path = os.path.join(os.path.dirname(__file__), 'config.yaml')
self.loader = ConfigLoader(config_path)
self._load_all_configs()
def _load_all_configs(self):
data_config = self.loader.get_data_config()
self.data_path = data_config.get('data_path')
self.lookback_window = data_config.get('lookback_window', 512)
self.predict_window = data_config.get('predict_window', 48)
self.max_context = data_config.get('max_context', 512)
self.clip = data_config.get('clip', 5.0)
self.train_ratio = data_config.get('train_ratio', 0.9)
self.val_ratio = data_config.get('val_ratio', 0.1)
self.test_ratio = data_config.get('test_ratio', 0.0)
# training configuration
training_config = self.loader.get_training_config()
# support training epochs of tokenizer and basemodel separately
self.tokenizer_epochs = training_config.get('tokenizer_epochs', 30)
self.basemodel_epochs = training_config.get('basemodel_epochs', 30)
if 'epochs' in training_config and 'tokenizer_epochs' not in training_config:
self.tokenizer_epochs = training_config.get('epochs', 30)
if 'epochs' in training_config and 'basemodel_epochs' not in training_config:
self.basemodel_epochs = training_config.get('epochs', 30)
self.batch_size = training_config.get('batch_size', 160)
self.log_interval = training_config.get('log_interval', 50)
self.num_workers = training_config.get('num_workers', 6)
self.seed = training_config.get('seed', 100)
self.tokenizer_learning_rate = training_config.get('tokenizer_learning_rate', 2e-4)
self.predictor_learning_rate = training_config.get('predictor_learning_rate', 4e-5)
self.adam_beta1 = training_config.get('adam_beta1', 0.9)
self.adam_beta2 = training_config.get('adam_beta2', 0.95)
self.adam_weight_decay = training_config.get('adam_weight_decay', 0.1)
self.accumulation_steps = training_config.get('accumulation_steps', 1)
model_paths = self.loader.get_model_paths()
self.exp_name = model_paths.get('exp_name', 'default_experiment')
self.pretrained_tokenizer_path = model_paths.get('pretrained_tokenizer')
self.pretrained_predictor_path = model_paths.get('pretrained_predictor')
self.base_save_path = model_paths.get('base_save_path')
self.tokenizer_save_name = model_paths.get('tokenizer_save_name', 'tokenizer')
self.basemodel_save_name = model_paths.get('basemodel_save_name', 'basemodel')
self.finetuned_tokenizer_path = model_paths.get('finetuned_tokenizer')
experiment_config = self.loader.get_experiment_config()
self.experiment_name = experiment_config.get('name', 'kronos_custom_finetune')
self.experiment_description = experiment_config.get('description', '')
self.use_comet = experiment_config.get('use_comet', False)
self.train_tokenizer = experiment_config.get('train_tokenizer', True)
self.train_basemodel = experiment_config.get('train_basemodel', True)
self.skip_existing = experiment_config.get('skip_existing', False)
unified_pretrained = experiment_config.get('pre_trained', None)
self.pre_trained_tokenizer = experiment_config.get('pre_trained_tokenizer', unified_pretrained if unified_pretrained is not None else True)
self.pre_trained_predictor = experiment_config.get('pre_trained_predictor', unified_pretrained if unified_pretrained is not None else True)
device_config = self.loader.get_device_config()
self.use_cuda = device_config.get('use_cuda', True)
self.device_id = device_config.get('device_id', 0)
distributed_config = self.loader.get_distributed_config()
self.use_ddp = distributed_config.get('use_ddp', False)
self.ddp_backend = distributed_config.get('backend', 'nccl')
self._compute_full_paths()
def _compute_full_paths(self):
self.tokenizer_save_path = os.path.join(self.base_save_path, self.tokenizer_save_name)
self.tokenizer_best_model_path = os.path.join(self.tokenizer_save_path, 'best_model')
self.basemodel_save_path = os.path.join(self.base_save_path, self.basemodel_save_name)
self.basemodel_best_model_path = os.path.join(self.basemodel_save_path, 'best_model')
def get_tokenizer_config(self):
return {
'data_path': self.data_path,
'lookback_window': self.lookback_window,
'predict_window': self.predict_window,
'max_context': self.max_context,
'clip': self.clip,
'train_ratio': self.train_ratio,
'val_ratio': self.val_ratio,
'test_ratio': self.test_ratio,
'epochs': self.tokenizer_epochs,
'batch_size': self.batch_size,
'log_interval': self.log_interval,
'num_workers': self.num_workers,
'seed': self.seed,
'learning_rate': self.tokenizer_learning_rate,
'adam_beta1': self.adam_beta1,
'adam_beta2': self.adam_beta2,
'adam_weight_decay': self.adam_weight_decay,
'accumulation_steps': self.accumulation_steps,
'pretrained_model_path': self.pretrained_tokenizer_path,
'save_path': self.tokenizer_save_path,
'use_comet': self.use_comet
}
def get_basemodel_config(self):
return {
'data_path': self.data_path,
'lookback_window': self.lookback_window,
'predict_window': self.predict_window,
'max_context': self.max_context,
'clip': self.clip,
'train_ratio': self.train_ratio,
'val_ratio': self.val_ratio,
'test_ratio': self.test_ratio,
'epochs': self.basemodel_epochs,
'batch_size': self.batch_size,
'log_interval': self.log_interval,
'num_workers': self.num_workers,
'seed': self.seed,
'predictor_learning_rate': self.predictor_learning_rate,
'tokenizer_learning_rate': self.tokenizer_learning_rate,
'adam_beta1': self.adam_beta1,
'adam_beta2': self.adam_beta2,
'adam_weight_decay': self.adam_weight_decay,
'pretrained_tokenizer_path': self.finetuned_tokenizer_path,
'pretrained_predictor_path': self.pretrained_predictor_path,
'save_path': self.basemodel_save_path,
'use_comet': self.use_comet
}
def print_config_summary(self):
print("=" * 60)
print("Kronos finetuning configuration summary")
print("=" * 60)
print(f"Experiment name: {self.exp_name}")
print(f"Data path: {self.data_path}")
print(f"Lookback window: {self.lookback_window}")
print(f"Predict window: {self.predict_window}")
print(f"Tokenizer training epochs: {self.tokenizer_epochs}")
print(f"Basemodel training epochs: {self.basemodel_epochs}")
print(f"Batch size: {self.batch_size}")
print(f"Tokenizer learning rate: {self.tokenizer_learning_rate}")
print(f"Predictor learning rate: {self.predictor_learning_rate}")
print(f"Train tokenizer: {self.train_tokenizer}")
print(f"Train basemodel: {self.train_basemodel}")
print(f"Skip existing: {self.skip_existing}")
print(f"Use pre-trained tokenizer: {self.pre_trained_tokenizer}")
print(f"Use pre-trained predictor: {self.pre_trained_predictor}")
print(f"Base save path: {self.base_save_path}")
print(f"Tokenizer save path: {self.tokenizer_save_path}")
print(f"Basemodel save path: {self.basemodel_save_path}")
print("=" * 60)


@@ -0,0 +1,72 @@
# This is a template config for fine-tuning Kronos on custom CSV data
data:
  data_path: "/xxxx/Kronos/finetune_csv/data/HK_ali_09988_kline_5min_all.csv"
  lookback_window: 512
  predict_window: 48
  max_context: 512
  clip: 5.0
  # dataset split ratios
  train_ratio: 0.9
  val_ratio: 0.1
  test_ratio: 0.0
training:
  # training epochs for the tokenizer and the base model
  tokenizer_epochs: 30
  basemodel_epochs: 20
  batch_size: 32
  log_interval: 50
  num_workers: 6
  seed: 42
  tokenizer_learning_rate: 0.0002
  predictor_learning_rate: 0.000001
  adam_beta1: 0.9
  adam_beta2: 0.95
  adam_weight_decay: 0.1
  # gradient accumulation steps for tokenizer training
  accumulation_steps: 1
# model path configuration
model_paths:
  # pretrained model paths
  pretrained_tokenizer: "/xxx/Kronos/pretrained/Kronos-Tokenizer-base"
  pretrained_predictor: "/xxx/Kronos/pretrained/Kronos-base"
  # experiment name - other paths are generated from it
  exp_name: "HK_ali_09988_kline_5min_all"
  base_path: "/xxx/Kronos/finetune_csv/finetuned/"
  # the following paths are generated from exp_name; no need to edit them manually
  # way 1: leave an empty string and the system will generate the full path
  base_save_path: "" # /xxxx/Kronos/finetune_csv/finetuned/{exp_name}
  finetuned_tokenizer: "" # /xxxx/Kronos/finetune_csv/finetuned/{exp_name}/tokenizer/best_model
  # way 2: use a template string; {exp_name} will be replaced with the actual experiment name
  # base_save_path: "/xxxx/Kronos/finetune_csv/finetuned/{exp_name}"
  # finetuned_tokenizer: "/xxxx/Kronos/finetune_csv/finetuned/{exp_name}/tokenizer/best_model"
  tokenizer_save_name: "tokenizer"
  basemodel_save_name: "basemodel"
experiment:
  name: "kronos_custom_finetune"
  description: "Custom finetune for HK stock data"
  use_comet: false
  # control which training phases run
  train_tokenizer: true
  train_basemodel: true
  # if true, skip training for models that already exist
  skip_existing: false
# device configuration
device:
  use_cuda: true
  device_id: 0

File diff suppressed because it is too large.

5 binary image files added (474 KiB, 473 KiB, 331 KiB, 449 KiB, 530 KiB); content not shown.


@@ -0,0 +1,468 @@
import os
import sys
import json
import time
import pickle
import random
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from time import gmtime, strftime
import logging
from logging.handlers import RotatingFileHandler
import datetime
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
sys.path.append('../')
from model import Kronos, KronosTokenizer, KronosPredictor
from config_loader import CustomFinetuneConfig
class CustomKlineDataset(Dataset):
def __init__(self, data_path, data_type='train', lookback_window=90, predict_window=10,
clip=5.0, seed=100, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15):
self.data_path = data_path
self.data_type = data_type
self.lookback_window = lookback_window
self.predict_window = predict_window
self.window = lookback_window + predict_window + 1
self.clip = clip
self.seed = seed
self.train_ratio = train_ratio
self.val_ratio = val_ratio
self.test_ratio = test_ratio
self.feature_list = ['open', 'high', 'low', 'close', 'volume', 'amount']
self.time_feature_list = ['minute', 'hour', 'weekday', 'day', 'month']
self.py_rng = random.Random(seed)
self._load_and_preprocess_data()
self._split_data_by_time()
self.n_samples = len(self.data) - self.window + 1
print(f"[{data_type.upper()}] Data length: {len(self.data)}, Available samples: {self.n_samples}")
def _load_and_preprocess_data(self):
df = pd.read_csv(self.data_path)
df['timestamps'] = pd.to_datetime(df['timestamps'])
df = df.sort_values('timestamps').reset_index(drop=True)
self.timestamps = df['timestamps'].copy()
df['minute'] = df['timestamps'].dt.minute
df['hour'] = df['timestamps'].dt.hour
df['weekday'] = df['timestamps'].dt.weekday
df['day'] = df['timestamps'].dt.day
df['month'] = df['timestamps'].dt.month
self.data = df[self.feature_list + self.time_feature_list].copy()
if self.data.isnull().any().any():
print("Warning: Missing values found in data, performing forward fill")
self.data = self.data.ffill()
print(f"Original data time range: {self.timestamps.min()} to {self.timestamps.max()}")
print(f"Original data total length: {len(df)} records")
def _split_data_by_time(self):
total_length = len(self.data)
train_end = int(total_length * self.train_ratio)
val_end = int(total_length * (self.train_ratio + self.val_ratio))
if self.data_type == 'train':
self.data = self.data.iloc[:train_end].copy()
self.timestamps = self.timestamps.iloc[:train_end].copy()
print(f"[{self.data_type.upper()}] Training set: first {train_end} time points ({self.train_ratio})")
print(f"[{self.data_type.upper()}] Training set time range: {self.timestamps.min()} to {self.timestamps.max()}")
elif self.data_type == 'val':
self.data = self.data.iloc[train_end:val_end].copy()
self.timestamps = self.timestamps.iloc[train_end:val_end].copy()
print(f"[{self.data_type.upper()}] Validation set: time points {train_end+1} to {val_end} ({self.val_ratio})")
print(f"[{self.data_type.upper()}] Validation set time range: {self.timestamps.min()} to {self.timestamps.max()}")
elif self.data_type == 'test':
self.data = self.data.iloc[val_end:].copy()
self.timestamps = self.timestamps.iloc[val_end:].copy()
print(f"[{self.data_type.upper()}] Test set: after time point {val_end+1}")
print(f"[{self.data_type.upper()}] Test set time range: {self.timestamps.min()} to {self.timestamps.max()}")
print(f"[{self.data_type.upper()}] Data length after split: {len(self.data)} records")
def set_epoch_seed(self, epoch):
epoch_seed = self.seed + epoch
self.py_rng.seed(epoch_seed)
self.current_epoch = epoch
def __len__(self):
return self.n_samples
def __getitem__(self, idx):
max_start = len(self.data) - self.window
if max_start <= 0:
raise ValueError("Data length insufficient to create samples")
if self.data_type == 'train':
epoch = getattr(self, 'current_epoch', 0)
start_idx = (idx * 9973 + (epoch + 1) * 104729) % (max_start + 1)
else:
start_idx = idx % (max_start + 1)
end_idx = start_idx + self.window
window_data = self.data.iloc[start_idx:end_idx]
x = window_data[self.feature_list].values.astype(np.float32)
x_stamp = window_data[self.time_feature_list].values.astype(np.float32)
x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0)
x = (x - x_mean) / (x_std + 1e-5)
x = np.clip(x, -self.clip, self.clip)
x_tensor = torch.from_numpy(x)
x_stamp_tensor = torch.from_numpy(x_stamp)
return x_tensor, x_stamp_tensor
def setup_logging(exp_name: str, log_dir: str, rank: int = 0) -> logging.Logger:
os.makedirs(log_dir, exist_ok=True)
logger = logging.getLogger(f"basemodel_training_rank_{rank}")
logger.setLevel(logging.INFO)
if logger.handlers:
return logger
log_file = os.path.join(log_dir, f"basemodel_training_rank_{rank}.log")
file_handler = RotatingFileHandler(
log_file,
maxBytes=10*1024*1024,
backupCount=5,
encoding='utf-8'
)
file_handler.setLevel(logging.INFO)
console_handler = None
if rank == 0:
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
file_handler.setFormatter(formatter)
if console_handler is not None:
console_handler.setFormatter(formatter)
logger.addHandler(file_handler)
if console_handler is not None:
logger.addHandler(console_handler)
logger.info(f"=== Basemodel Training Started ===")
logger.info(f"Experiment Name: {exp_name}")
logger.info(f"Log Directory: {log_dir}")
logger.info(f"Rank: {rank}")
logger.info(f"Timestamp: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
return logger
def create_dataloaders(config):
if not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0:
print("Creating data loaders...")
train_dataset = CustomKlineDataset(
data_path=config.data_path,
data_type='train',
lookback_window=config.lookback_window,
predict_window=config.predict_window,
clip=config.clip,
seed=config.seed,
train_ratio=config.train_ratio,
val_ratio=config.val_ratio,
test_ratio=config.test_ratio
)
val_dataset = CustomKlineDataset(
data_path=config.data_path,
data_type='val',
lookback_window=config.lookback_window,
predict_window=config.predict_window,
clip=config.clip,
seed=config.seed + 1,
train_ratio=config.train_ratio,
val_ratio=config.val_ratio,
test_ratio=config.test_ratio
)
use_ddp = dist.is_available() and dist.is_initialized()
train_sampler = DistributedSampler(train_dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True) if use_ddp else None
val_sampler = DistributedSampler(val_dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=False, drop_last=False) if use_ddp else None
train_loader = DataLoader(
train_dataset,
batch_size=config.batch_size,
shuffle=(train_sampler is None),
num_workers=config.num_workers,
pin_memory=True,
drop_last=True,
sampler=train_sampler
)
val_loader = DataLoader(
val_dataset,
batch_size=config.batch_size,
shuffle=False,
num_workers=config.num_workers,
pin_memory=True,
drop_last=False,
sampler=val_sampler
)
if not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0:
print(f"Training set size: {len(train_dataset)}, Validation set size: {len(val_dataset)}")
return train_loader, val_loader, train_dataset, val_dataset, train_sampler, val_sampler
def train_model(model, tokenizer, device, config, save_dir, logger):
logger.info("Starting training...")
use_ddp = dist.is_available() and dist.is_initialized()
rank = dist.get_rank() if use_ddp else 0
world_size = dist.get_world_size() if use_ddp else 1
train_loader, val_loader, train_dataset, val_dataset, train_sampler, val_sampler = create_dataloaders(config)
optimizer = torch.optim.AdamW(
model.parameters(),
lr=config.predictor_learning_rate,
betas=(config.adam_beta1, config.adam_beta2),
weight_decay=config.adam_weight_decay
)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
optimizer,
max_lr=config.predictor_learning_rate,
steps_per_epoch=len(train_loader),
epochs=config.basemodel_epochs,
pct_start=0.03,
div_factor=10
)
if use_ddp:
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=False)
best_val_loss = float('inf')
batch_idx_global = 0
for epoch in range(config.basemodel_epochs):
epoch_start_time = time.time()
model.train()
train_dataset.set_epoch_seed(epoch * 10000)
val_dataset.set_epoch_seed(0)
if train_sampler is not None:
train_sampler.set_epoch(epoch)
epoch_train_loss = 0.0
train_batches = 0
for batch_idx, (batch_x, batch_x_stamp) in enumerate(train_loader):
batch_x = batch_x.to(device, non_blocking=True)
batch_x_stamp = batch_x_stamp.to(device, non_blocking=True)
with torch.no_grad():
token_seq_0, token_seq_1 = tokenizer.encode(batch_x, half=True)
token_in = [token_seq_0[:, :-1], token_seq_1[:, :-1]]
token_out = [token_seq_0[:, 1:], token_seq_1[:, 1:]]
logits = (model.module if use_ddp else model)(token_in[0], token_in[1], batch_x_stamp[:, :-1, :])
loss, s1_loss, s2_loss = (model.module if use_ddp else model).head.compute_loss(logits[0], logits[1], token_out[0], token_out[1])
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_((model.module if use_ddp else model).parameters(), max_norm=3.0)
optimizer.step()
scheduler.step()
epoch_train_loss += loss.item()
train_batches += 1
if (batch_idx_global + 1) % config.log_interval == 0:
lr = optimizer.param_groups[0]['lr']
log_msg = (f"[Epoch {epoch+1}/{config.basemodel_epochs}, Step {batch_idx+1}/{len(train_loader)}] "
f"LR: {lr:.6f}, Loss: {loss.item():.4f}")
logger.info(log_msg)
if rank == 0:
print(log_msg)
batch_idx_global += 1
model.eval()
val_loss = 0.0
val_batches = 0
with torch.no_grad():
for batch_x, batch_x_stamp in val_loader:
batch_x = batch_x.to(device, non_blocking=True)
batch_x_stamp = batch_x_stamp.to(device, non_blocking=True)
token_seq_0, token_seq_1 = tokenizer.encode(batch_x, half=True)
token_in = [token_seq_0[:, :-1], token_seq_1[:, :-1]]
token_out = [token_seq_0[:, 1:], token_seq_1[:, 1:]]
logits = (model.module if use_ddp else model)(token_in[0], token_in[1], batch_x_stamp[:, :-1, :])
loss, _, _ = (model.module if use_ddp else model).head.compute_loss(logits[0], logits[1], token_out[0], token_out[1])
val_loss += loss.item()
val_batches += 1
if use_ddp:
tensor_sum = torch.tensor([epoch_train_loss, train_batches, val_loss, val_batches], dtype=torch.float64, device=device)
dist.all_reduce(tensor_sum, op=dist.ReduceOp.SUM)
epoch_train_loss_all = tensor_sum[0].item()
train_batches_all = int(tensor_sum[1].item())
val_loss_all = tensor_sum[2].item()
val_batches_all = int(tensor_sum[3].item())
avg_train_loss = (epoch_train_loss_all / train_batches_all) if train_batches_all > 0 else 0.0
avg_val_loss = (val_loss_all / val_batches_all) if val_batches_all > 0 else 0.0
else:
avg_train_loss = epoch_train_loss / train_batches if train_batches > 0 else 0
avg_val_loss = val_loss / val_batches if val_batches > 0 else 0
epoch_time = time.time() - epoch_start_time
epoch_summary = (f"\n--- Epoch {epoch+1}/{config.basemodel_epochs} Summary ---\n"
f"Training Loss: {avg_train_loss:.4f}\n"
f"Validation Loss: {avg_val_loss:.4f}\n"
f"Epoch Time: {epoch_time:.2f} seconds\n")
logger.info(epoch_summary)
if rank == 0:
print(epoch_summary)
if avg_val_loss < best_val_loss:
best_val_loss = avg_val_loss
if rank == 0:
model_save_path = os.path.join(save_dir, "best_model")
os.makedirs(model_save_path, exist_ok=True)
(model.module if use_ddp else model).save_pretrained(model_save_path)
save_msg = f"Best model saved to: {model_save_path} (validation loss: {best_val_loss:.4f})"
logger.info(save_msg)
print(save_msg)
return best_val_loss
def main():
import argparse
parser = argparse.ArgumentParser(description='Kronos Basemodel Fine-tuning Training')
parser.add_argument('--config', type=str, default='config.yaml',
help='Configuration file path (default: config.yaml)')
args = parser.parse_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
config = CustomFinetuneConfig(args.config)
os.makedirs(config.basemodel_save_path, exist_ok=True)
log_dir = os.path.join(config.base_save_path, "logs")
logger = setup_logging(config.exp_name, log_dir, 0)
torch.manual_seed(config.seed)
np.random.seed(config.seed)
random.seed(config.seed)
logger.info("Loading pretrained model or random initialization...")
print("Loading pretrained model or random initialization...")
if getattr(config, 'pre_trained_tokenizer', True):
tokenizer = KronosTokenizer.from_pretrained(config.finetuned_tokenizer_path)
else:
print("pre_trained_tokenizer=False, randomly initializing Tokenizer architecture for training")
cfg_path_tok = os.path.join(config.pretrained_tokenizer_path if hasattr(config, 'pretrained_tokenizer_path') else config.finetuned_tokenizer_path, 'config.json')
with open(cfg_path_tok, 'r') as f:
arch_t = json.load(f)
tokenizer = KronosTokenizer(
d_in=arch_t.get('d_in', 6),
d_model=arch_t.get('d_model', 256),
n_heads=arch_t.get('n_heads', 4),
ff_dim=arch_t.get('ff_dim', 512),
n_enc_layers=arch_t.get('n_enc_layers', 4),
n_dec_layers=arch_t.get('n_dec_layers', 4),
ffn_dropout_p=arch_t.get('ffn_dropout_p', 0.0),
attn_dropout_p=arch_t.get('attn_dropout_p', 0.0),
resid_dropout_p=arch_t.get('resid_dropout_p', 0.0),
s1_bits=arch_t.get('s1_bits', 10),
s2_bits=arch_t.get('s2_bits', 10),
beta=arch_t.get('beta', 0.05),
gamma0=arch_t.get('gamma0', 1.0),
gamma=arch_t.get('gamma', 1.1),
zeta=arch_t.get('zeta', 0.05),
group_size=arch_t.get('group_size', 4)
)
if getattr(config, 'pre_trained_predictor', True):
model = Kronos.from_pretrained(config.pretrained_predictor_path)
else:
import json
print("pre_trained_predictor=False, randomly initializing Predictor architecture for training")
cfg_path = os.path.join(config.pretrained_predictor_path, 'config.json')
with open(cfg_path, 'r') as f:
arch = json.load(f)
model = Kronos(
s1_bits=arch.get('s1_bits', 10),
s2_bits=arch.get('s2_bits', 10),
n_layers=arch.get('n_layers', 12),
d_model=arch.get('d_model', 832),
n_heads=arch.get('n_heads', 16),
ff_dim=arch.get('ff_dim', 2048),
ffn_dropout_p=arch.get('ffn_dropout_p', 0.2),
attn_dropout_p=arch.get('attn_dropout_p', 0.0),
resid_dropout_p=arch.get('resid_dropout_p', 0.2),
token_dropout_p=arch.get('token_dropout_p', 0.0),
learn_te=arch.get('learn_te', True)
)
tokenizer = tokenizer.to(device)
model = model.to(device)
model_size = sum(p.numel() for p in model.parameters())
logger.info(f"Model parameters: {model_size:,}")
print(f"Model parameters: {model_size:,}")
logger.info("=== Training Configuration ===")
logger.info(f"Data path: {config.data_path}")
logger.info(f"Lookback window: {config.lookback_window}")
logger.info(f"Predict window: {config.predict_window}")
logger.info(f"Batch size: {config.batch_size}")
logger.info(f"Learning rate: {config.predictor_learning_rate}")
logger.info(f"Training epochs: {config.basemodel_epochs}")
logger.info(f"Device: {device}")
logger.info(f"Tokenizer path: {config.finetuned_tokenizer_path}")
logger.info(f"Pretrained model path: {config.pretrained_predictor_path}")
logger.info("Starting fine-tuning training...")
print("Starting fine-tuning training...")
best_val_loss = train_model(model, tokenizer, device, config, config.basemodel_save_path, logger)
final_msg = f"Training completed! Best validation loss: {best_val_loss:.4f}\nModel saved to: {config.basemodel_save_path}"
logger.info(final_msg)
print(final_msg)
if __name__ == "__main__":
main()


@ -0,0 +1,359 @@
import os
import sys
import json
import time
import random
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from time import gmtime, strftime
import datetime
import logging
from logging.handlers import RotatingFileHandler
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
sys.path.append("../")
from model import KronosTokenizer
from finetune_base_model import CustomKlineDataset
from config_loader import CustomFinetuneConfig
def set_seed(seed: int, rank: int = 0):
actual_seed = seed
random.seed(actual_seed)
np.random.seed(actual_seed)
torch.manual_seed(actual_seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(actual_seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def get_model_size(model: torch.nn.Module) -> str:
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
if total_params >= 1e9:
return f"{total_params / 1e9:.1f}B"
elif total_params >= 1e6:
return f"{total_params / 1e6:.1f}M"
else:
return f"{total_params / 1e3:.1f}K"
def format_time(seconds: float) -> str:
return str(datetime.timedelta(seconds=int(seconds)))
def setup_logging(exp_name: str, log_dir: str, rank: int = 0) -> logging.Logger:
os.makedirs(log_dir, exist_ok=True)
logger = logging.getLogger(f"tokenizer_training_rank_{rank}")
logger.setLevel(logging.INFO)
if logger.handlers:
return logger
log_file = os.path.join(log_dir, f"tokenizer_training_rank_{rank}.log")
file_handler = RotatingFileHandler(
log_file,
maxBytes=10*1024*1024,
backupCount=5,
encoding='utf-8'
)
file_handler.setLevel(logging.INFO)
console_handler = None
if rank == 0:
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
file_handler.setFormatter(formatter)
if console_handler is not None:
console_handler.setFormatter(formatter)
logger.addHandler(file_handler)
if console_handler is not None:
logger.addHandler(console_handler)
logger.info("=== Tokenizer Training Started ===")
logger.info(f"Experiment Name: {exp_name}")
logger.info(f"Log Directory: {log_dir}")
logger.info(f"Rank: {rank}")
logger.info(f"Timestamp: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
return logger
def create_dataloaders(config):
if not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0:
print("Creating tokenizer training data loaders...")
train_dataset = CustomKlineDataset(
data_path=config.data_path,
data_type="train",
lookback_window=config.lookback_window,
predict_window=config.predict_window,
clip=config.clip,
seed=config.seed,
train_ratio=config.train_ratio,
val_ratio=config.val_ratio,
test_ratio=config.test_ratio
)
val_dataset = CustomKlineDataset(
data_path=config.data_path,
data_type="val",
lookback_window=config.lookback_window,
predict_window=config.predict_window,
clip=config.clip,
seed=config.seed + 1,
train_ratio=config.train_ratio,
val_ratio=config.val_ratio,
test_ratio=config.test_ratio
)
use_ddp = dist.is_available() and dist.is_initialized()
train_sampler = DistributedSampler(train_dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True) if use_ddp else None
val_sampler = DistributedSampler(val_dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=False, drop_last=False) if use_ddp else None
train_loader = DataLoader(
train_dataset,
batch_size=config.batch_size,
shuffle=(train_sampler is None),
num_workers=config.num_workers,
pin_memory=True,
drop_last=True,
sampler=train_sampler
)
val_loader = DataLoader(
val_dataset,
batch_size=config.batch_size,
shuffle=False,
num_workers=config.num_workers,
pin_memory=True,
drop_last=False,
sampler=val_sampler
)
if not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0:
print(f"Training set size: {len(train_dataset)}, Validation set size: {len(val_dataset)}")
return train_loader, val_loader, train_dataset, val_dataset, train_sampler, val_sampler
def train_tokenizer(model, device, config, save_dir, logger):
logger.info("Starting tokenizer training...")
use_ddp = dist.is_available() and dist.is_initialized()
rank = dist.get_rank() if use_ddp else 0
world_size = dist.get_world_size() if use_ddp else 1
train_loader, val_loader, train_dataset, val_dataset, train_sampler, val_sampler = create_dataloaders(config)
optimizer = torch.optim.AdamW(
model.parameters(),
lr=config.tokenizer_learning_rate,
weight_decay=config.adam_weight_decay
)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
optimizer,
max_lr=config.tokenizer_learning_rate,
steps_per_epoch=len(train_loader),
epochs=config.tokenizer_epochs,
pct_start=0.03,
div_factor=10
)
if use_ddp:
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=False)
best_val_loss = float("inf")
batch_idx_global = 0
accumulation_steps = getattr(config, 'accumulation_steps', 1)
for epoch in range(config.tokenizer_epochs):
epoch_start_time = time.time()
model.train()
train_dataset.set_epoch_seed(epoch * 10000)
val_dataset.set_epoch_seed(0)
if train_sampler is not None:
train_sampler.set_epoch(epoch)
for batch_idx, (ori_batch_x, _) in enumerate(train_loader):
ori_batch_x = ori_batch_x.squeeze(0).to(device, non_blocking=True)
current_batch_total_loss = 0.0
for j in range(accumulation_steps):
start_idx = j * (ori_batch_x.shape[0] // accumulation_steps)
end_idx = (j + 1) * (ori_batch_x.shape[0] // accumulation_steps)
batch_x = ori_batch_x[start_idx:end_idx]
zs, bsq_loss, _, _ = (model.module if use_ddp else model)(batch_x)
z_pre, z = zs
recon_loss_pre = F.mse_loss(z_pre, batch_x)
recon_loss_all = F.mse_loss(z, batch_x)
recon_loss = recon_loss_pre + recon_loss_all
loss = (recon_loss + bsq_loss) / 2
loss_scaled = loss / accumulation_steps
current_batch_total_loss += loss.item()
loss_scaled.backward()
torch.nn.utils.clip_grad_norm_((model.module if use_ddp else model).parameters(), max_norm=2.0)
optimizer.step()
scheduler.step()
optimizer.zero_grad()
if (batch_idx_global + 1) % config.log_interval == 0:
avg_loss = current_batch_total_loss / accumulation_steps
lr = optimizer.param_groups[0]["lr"]
log_msg = (f"[Epoch {epoch+1}/{config.tokenizer_epochs}, Step {batch_idx+1}/{len(train_loader)}] "
f"LR: {lr:.6f}, Loss: {avg_loss:.4f}")
logger.info(log_msg)
if rank == 0:
print(log_msg)
detail_msg = (f" - VQ Loss: {bsq_loss.item():.4f}\n"
f" - Recon Loss Pre: {recon_loss_pre.item():.4f}\n"
f" - Recon Loss All: {recon_loss_all.item():.4f}")
logger.info(detail_msg)
if rank == 0:
print(detail_msg)
batch_idx_global += 1
model.eval()
tot_val_loss_sum_rank = 0.0
val_sample_count_rank = 0
with torch.no_grad():
for ori_batch_x, _ in val_loader:
ori_batch_x = ori_batch_x.squeeze(0).to(device, non_blocking=True)
zs, _, _, _ = (model.module if use_ddp else model)(ori_batch_x)
_, z = zs
val_loss_item = F.mse_loss(z, ori_batch_x)
tot_val_loss_sum_rank += val_loss_item.item() * ori_batch_x.size(0)
val_sample_count_rank += ori_batch_x.size(0)
if use_ddp:
tensor_sum = torch.tensor([tot_val_loss_sum_rank, val_sample_count_rank], dtype=torch.float64, device=device)
dist.all_reduce(tensor_sum, op=dist.ReduceOp.SUM)
tot_val_loss_all = tensor_sum[0].item()
val_count_all = int(tensor_sum[1].item())
avg_val_loss = (tot_val_loss_all / val_count_all) if val_count_all > 0 else 0.0
else:
avg_val_loss = tot_val_loss_sum_rank / val_sample_count_rank if val_sample_count_rank > 0 else 0
epoch_time = time.time() - epoch_start_time
epoch_summary = (f"\n--- Epoch {epoch+1}/{config.tokenizer_epochs} Summary ---\n"
f"Validation Loss: {avg_val_loss:.4f}\n"
f"Epoch Time: {format_time(epoch_time)}\n")
logger.info(epoch_summary)
if rank == 0:
print(epoch_summary)
if avg_val_loss < best_val_loss:
best_val_loss = avg_val_loss
if rank == 0:
model_save_path = os.path.join(save_dir, "best_model")
os.makedirs(model_save_path, exist_ok=True)
(model.module if use_ddp else model).save_pretrained(model_save_path)
save_msg = f"Best model saved to: {model_save_path} (validation loss: {best_val_loss:.4f})"
logger.info(save_msg)
print(save_msg)
return best_val_loss
def main():
import argparse
parser = argparse.ArgumentParser(description='Kronos Tokenizer Fine-tuning Training')
parser.add_argument('--config', type=str, default='config.yaml',
help='Configuration file path (default: config.yaml)')
args = parser.parse_args()
config = CustomFinetuneConfig(args.config)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
os.makedirs(config.tokenizer_save_path, exist_ok=True)
log_dir = os.path.join(config.base_save_path, "logs")
logger = setup_logging(config.exp_name, log_dir, 0)
set_seed(config.seed)
# Load the pretrained tokenizer
if getattr(config, 'pre_trained_tokenizer', True):
logger.info("Loading pretrained tokenizer...")
print("Loading pretrained tokenizer...")
tokenizer = KronosTokenizer.from_pretrained(config.pretrained_tokenizer_path)
else:
print("pre_trained_tokenizer=False, randomly initializing Tokenizer architecture")
cfg_path = os.path.join(config.pretrained_tokenizer_path, 'config.json')
with open(cfg_path, 'r') as f:
arch = json.load(f)
tokenizer = KronosTokenizer(
d_in=arch.get('d_in', 6),
d_model=arch.get('d_model', 256),
n_heads=arch.get('n_heads', 4),
ff_dim=arch.get('ff_dim', 512),
n_enc_layers=arch.get('n_enc_layers', 4),
n_dec_layers=arch.get('n_dec_layers', 4),
ffn_dropout_p=arch.get('ffn_dropout_p', 0.0),
attn_dropout_p=arch.get('attn_dropout_p', 0.0),
resid_dropout_p=arch.get('resid_dropout_p', 0.0),
s1_bits=arch.get('s1_bits', 10),
s2_bits=arch.get('s2_bits', 10),
beta=arch.get('beta', 0.05),
gamma0=arch.get('gamma0', 1.0),
gamma=arch.get('gamma', 1.1),
zeta=arch.get('zeta', 0.05),
group_size=arch.get('group_size', 4)
)
tokenizer = tokenizer.to(device)
model_size = get_model_size(tokenizer)
logger.info(f"Tokenizer parameters: {model_size}")
print(f"Tokenizer parameters: {model_size}")
logger.info("=== Training Configuration ===")
logger.info(f"Data path: {config.data_path}")
logger.info(f"Lookback window: {config.lookback_window}")
logger.info(f"Predict window: {config.predict_window}")
logger.info(f"Batch size: {config.batch_size}")
logger.info(f"Learning rate: {config.tokenizer_learning_rate}")
logger.info(f"Training epochs: {config.tokenizer_epochs}")
logger.info(f"Device: {device}")
logger.info(f"Distributed training: False")
logger.info("Starting tokenizer fine-tuning training...")
print("Starting tokenizer fine-tuning training...")
best_val_loss = train_tokenizer(tokenizer, device, config, config.tokenizer_save_path, logger)
final_msg = f"Tokenizer training completed! Best validation loss: {best_val_loss:.4f}\nModel saved to: {config.tokenizer_save_path}"
logger.info(final_msg)
print(final_msg)
if __name__ == "__main__":
main()


@ -0,0 +1,361 @@
import os
import sys
import time
import argparse
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.distributed as dist
sys.path.append('../')
from model import Kronos, KronosTokenizer, KronosPredictor
from config_loader import CustomFinetuneConfig
from finetune_tokenizer import train_tokenizer, set_seed, setup_logging as setup_tokenizer_logging
from finetune_base_model import train_model, create_dataloaders, setup_logging as setup_basemodel_logging
class SequentialTrainer:
def __init__(self, config_path: str = None):
self.config = CustomFinetuneConfig(config_path)
self.rank = int(os.environ.get("RANK", "0"))
self.world_size = int(os.environ.get("WORLD_SIZE", "1"))
self.local_rank = int(os.environ.get("LOCAL_RANK", str(self.config.device_id if hasattr(self.config, 'device_id') else 0)))
self.device = self._setup_device()
self.config.print_config_summary()
def _setup_device(self):
if self.config.use_cuda and torch.cuda.is_available():
torch.cuda.set_device(self.local_rank)
device = torch.device(f"cuda:{self.local_rank}")
else:
device = torch.device("cpu")
if self.rank == 0:
print(f"Using device: {device} (rank={self.rank}, world_size={self.world_size}, local_rank={self.local_rank})")
return device
def _setup_distributed(self):
if self.world_size > 1 and torch.cuda.is_available():
backend = os.environ.get("DIST_BACKEND", "nccl").lower()
if not dist.is_initialized():
dist.init_process_group(backend=backend)
if self.rank == 0:
print(f"Distributed training initialized: backend={backend}, world_size={self.world_size}")
else:
if self.rank == 0:
print("Distributed training not enabled, using single GPU/CPU training")
def _check_existing_models(self):
tokenizer_exists = os.path.exists(self.config.tokenizer_best_model_path)
basemodel_exists = os.path.exists(self.config.basemodel_best_model_path)
print(f"Tokenizer model exists: {tokenizer_exists}")
print(f"Basemodel model exists: {basemodel_exists}")
return tokenizer_exists, basemodel_exists
def _create_directories(self):
os.makedirs(self.config.tokenizer_save_path, exist_ok=True)
os.makedirs(self.config.basemodel_save_path, exist_ok=True)
print(f"Created directory: {self.config.tokenizer_save_path}")
print(f"Created directory: {self.config.basemodel_save_path}")
def train_tokenizer_phase(self):
print("\n" + "="*60)
print("Starting Tokenizer Fine-tuning Phase")
print("="*60)
tokenizer_exists, _ = self._check_existing_models()
if tokenizer_exists and self.config.skip_existing:
print("Tokenizer model already exists, skipping training")
return True
log_dir = os.path.join(self.config.base_save_path, "logs")
logger = setup_tokenizer_logging(self.config.exp_name, log_dir, self.rank)
set_seed(self.config.seed)
if getattr(self.config, 'pre_trained_tokenizer', True):
logger.info("Loading pretrained tokenizer...")
if self.rank == 0:
print("Loading pretrained tokenizer...")
tokenizer = KronosTokenizer.from_pretrained(self.config.pretrained_tokenizer_path)
else:
if self.rank == 0:
print("pre_trained_tokenizer=False, randomly initializing Tokenizer architecture")
import json
cfg_path = os.path.join(self.config.pretrained_tokenizer_path, 'config.json')
with open(cfg_path, 'r') as f:
arch = json.load(f)
tokenizer = KronosTokenizer(
d_in=arch.get('d_in', 6),
d_model=arch.get('d_model', 256),
n_heads=arch.get('n_heads', 4),
ff_dim=arch.get('ff_dim', 512),
n_enc_layers=arch.get('n_enc_layers', 4),
n_dec_layers=arch.get('n_dec_layers', 4),
ffn_dropout_p=arch.get('ffn_dropout_p', 0.0),
attn_dropout_p=arch.get('attn_dropout_p', 0.0),
resid_dropout_p=arch.get('resid_dropout_p', 0.0),
s1_bits=arch.get('s1_bits', 10),
s2_bits=arch.get('s2_bits', 10),
beta=arch.get('beta', 0.05),
gamma0=arch.get('gamma0', 1.0),
gamma=arch.get('gamma', 1.1),
zeta=arch.get('zeta', 0.05),
group_size=arch.get('group_size', 4)
)
tokenizer = tokenizer.to(self.device)
model_size = sum(p.numel() for p in tokenizer.parameters())
logger.info(f"Tokenizer parameters: {model_size:,}")
if self.rank == 0:
print(f"Tokenizer parameters: {model_size:,}")
logger.info("=== Training Configuration ===")
logger.info(f"Data path: {self.config.data_path}")
logger.info(f"Lookback window: {self.config.lookback_window}")
logger.info(f"Predict window: {self.config.predict_window}")
logger.info(f"Batch size: {self.config.batch_size}")
logger.info(f"Learning rate: {self.config.tokenizer_learning_rate}")
logger.info(f"Training epochs: {self.config.tokenizer_epochs}")
logger.info(f"Device: {self.device}")
logger.info(f"Distributed training: {self.world_size > 1}")
logger.info("Starting tokenizer fine-tuning training...")
if self.rank == 0:
print("Starting tokenizer fine-tuning training...")
start_time = time.time()
best_val_loss = train_tokenizer(
tokenizer,
self.device,
self.config,
self.config.tokenizer_save_path,
logger,
)
training_time = time.time() - start_time
final_msg = f"Tokenizer training completed! Best validation loss: {best_val_loss:.4f}\nTraining time: {training_time/60:.2f} minutes\nModel saved to: {self.config.tokenizer_save_path}"
logger.info(final_msg)
if self.rank == 0:
print(f"\n{final_msg}")
return True
def train_basemodel_phase(self):
print("\n" + "="*60)
print("Starting Basemodel Fine-tuning Phase")
print("="*60)
if getattr(self.config, 'pre_trained_tokenizer', True):
if not os.path.exists(self.config.finetuned_tokenizer_path):
raise FileNotFoundError(f"Fine-tuned tokenizer does not exist: {self.config.finetuned_tokenizer_path}")
_, basemodel_exists = self._check_existing_models()
if basemodel_exists and self.config.skip_existing:
print("Basemodel model already exists, skipping training")
return True
log_dir = os.path.join(self.config.base_save_path, "logs")
logger = setup_basemodel_logging(self.config.exp_name, log_dir, self.rank)
set_seed(self.config.seed)
if getattr(self.config, 'pre_trained_tokenizer', True):
logger.info("Loading fine-tuned tokenizer...")
if self.rank == 0:
print("Loading fine-tuned tokenizer...")
tokenizer = KronosTokenizer.from_pretrained(self.config.finetuned_tokenizer_path)
else:
if self.rank == 0:
print("pre_trained_tokenizer=False, randomly initializing Tokenizer architecture for Predictor training")
import json
cfg_path = os.path.join(self.config.pretrained_tokenizer_path, 'config.json')
with open(cfg_path, 'r') as f:
arch = json.load(f)
tokenizer = KronosTokenizer(
d_in=arch.get('d_in', 6),
d_model=arch.get('d_model', 256),
n_heads=arch.get('n_heads', 4),
ff_dim=arch.get('ff_dim', 512),
n_enc_layers=arch.get('n_enc_layers', 4),
n_dec_layers=arch.get('n_dec_layers', 4),
ffn_dropout_p=arch.get('ffn_dropout_p', 0.0),
attn_dropout_p=arch.get('attn_dropout_p', 0.0),
resid_dropout_p=arch.get('resid_dropout_p', 0.0),
s1_bits=arch.get('s1_bits', 10),
s2_bits=arch.get('s2_bits', 10),
beta=arch.get('beta', 0.05),
gamma0=arch.get('gamma0', 1.0),
gamma=arch.get('gamma', 1.1),
zeta=arch.get('zeta', 0.05),
group_size=arch.get('group_size', 4)
)
tokenizer = tokenizer.to(self.device)
if getattr(self.config, 'pre_trained_predictor', True):
logger.info("Loading pretrained predictor...")
if self.rank == 0:
print("Loading pretrained predictor...")
model = Kronos.from_pretrained(self.config.pretrained_predictor_path)
else:
if self.rank == 0:
print("pre_trained_predictor=False, randomly initializing Predictor architecture")
import json
cfg_path = os.path.join(self.config.pretrained_predictor_path, 'config.json')
with open(cfg_path, 'r') as f:
arch = json.load(f)
print("model_config: ", arch)
model = Kronos(
s1_bits=arch.get('s1_bits', 10),
s2_bits=arch.get('s2_bits', 10),
n_layers=arch.get('n_layers', 12),
d_model=arch.get('d_model', 832),
n_heads=arch.get('n_heads', 16),
ff_dim=arch.get('ff_dim', 2048),
ffn_dropout_p=arch.get('ffn_dropout_p', 0.2),
attn_dropout_p=arch.get('attn_dropout_p', 0.0),
resid_dropout_p=arch.get('resid_dropout_p', 0.2),
token_dropout_p=arch.get('token_dropout_p', 0.0),
learn_te=arch.get('learn_te', True)
)
model = model.to(self.device)
model_size = sum(p.numel() for p in model.parameters())
logger.info(f"Model parameters: {model_size:,}")
if self.rank == 0:
print(f"Model parameters: {model_size:,}")
logger.info("=== Training Configuration ===")
logger.info(f"Data path: {self.config.data_path}")
logger.info(f"Lookback window: {self.config.lookback_window}")
logger.info(f"Predict window: {self.config.predict_window}")
logger.info(f"Batch size: {self.config.batch_size}")
logger.info(f"Learning rate: {self.config.predictor_learning_rate}")
logger.info(f"Training epochs: {self.config.basemodel_epochs}")
logger.info(f"Device: {self.device}")
logger.info(f"Tokenizer path: {self.config.finetuned_tokenizer_path}")
logger.info(f"Pretrained model path: {self.config.pretrained_predictor_path}")
logger.info("Starting fine-tuning training...")
if self.rank == 0:
print("Starting fine-tuning training...")
start_time = time.time()
best_val_loss = train_model(
model,
tokenizer,
self.device,
self.config,
self.config.basemodel_save_path,
logger,
)
training_time = time.time() - start_time
final_msg = f"Basemodel training completed! Best validation loss: {best_val_loss:.4f}\nTraining time: {training_time/60:.2f} minutes\nModel saved to: {self.config.basemodel_save_path}"
logger.info(final_msg)
if self.rank == 0:
print(f"\n{final_msg}")
return True
def run_training(self):
if self.rank == 0:
print("Starting Kronos model sequential fine-tuning training")
print(f"Experiment name: {self.config.experiment_name}")
print(f"Experiment description: {self.config.experiment_description}")
self._setup_distributed()
self._create_directories()
tokenizer_exists, basemodel_exists = self._check_existing_models()
total_start_time = time.time()
try:
if self.config.train_tokenizer:
success = self.train_tokenizer_phase()
if not success:
print("Tokenizer training failed, terminating training")
return False
else:
print("Skipping Tokenizer training phase")
if self.config.train_basemodel:
success = self.train_basemodel_phase()
if not success:
print("Basemodel training failed, terminating training")
return False
else:
print("Skipping Basemodel training phase")
total_time = time.time() - total_start_time
if self.rank == 0:
print("\n" + "="*60)
print("Training completed!")
print("="*60)
print(f"Total training time: {total_time/60:.2f} minutes")
print(f"Tokenizer model: {self.config.tokenizer_best_model_path}")
print(f"Basemodel model: {self.config.basemodel_best_model_path}")
print("="*60)
return True
except Exception as e:
if self.rank == 0:
print(f"Error occurred during training: {str(e)}")
import traceback
traceback.print_exc()
return False
def main():
parser = argparse.ArgumentParser(description='Kronos Model Sequential Fine-tuning Training')
parser.add_argument('--config', type=str, default='config.yaml',
help='Configuration file path (default: config.yaml)')
parser.add_argument('--skip-tokenizer', action='store_true',
help='Skip tokenizer training phase')
parser.add_argument('--skip-basemodel', action='store_true',
help='Skip basemodel training phase')
parser.add_argument('--skip-existing', action='store_true',
help='Skip training for existing models')
args = parser.parse_args()
trainer = SequentialTrainer(args.config)
if args.skip_tokenizer:
trainer.config.train_tokenizer = False
if args.skip_basemodel:
trainer.config.train_basemodel = False
if args.skip_existing:
trainer.config.skip_existing = True
success = trainer.run_training()
if success:
print("Training completed successfully!")
if dist.is_available() and dist.is_initialized():
dist.barrier()
dist.destroy_process_group()
sys.exit(0)
else:
print("Training failed!")
if dist.is_available() and dist.is_initialized():
try:
dist.barrier()
dist.destroy_process_group()
except Exception:
pass
sys.exit(1)
if __name__ == "__main__":
main()