add custom data finetune

This commit is contained in:
zhangboyu1 2025-10-09 15:48:39 +08:00
parent 083294bd84
commit 84f74ae341
7 changed files with 1733 additions and 0 deletions

finetune_csv/README.md (new file, 101 lines)
@@ -0,0 +1,101 @@
# Kronos Fine-tuning Training with a Custom Dataset
Supports fine-tuning on custom CSV data, driven by a YAML configuration file.
## Quick Start
### 1. Configuration Setup
First edit the `config.yaml` file to set the correct paths and parameters:
```yaml
# Data configuration
data:
  data_path: "/path/to/your/data.csv"
  lookback_window: 512
  predict_window: 48
  # ... other parameters

# Model path configuration
model_paths:
  pretrained_tokenizer: "/path/to/pretrained/tokenizer"
  pretrained_predictor: "/path/to/pretrained/predictor"
  base_save_path: "/path/to/save/models"
  # ... other paths
```
### 2. Run Training
Using `train_sequential.py`:
```bash
# Complete training
python train_sequential.py --config configs/config_ali09988_candle-5min.yaml
# Skip existing models
python train_sequential.py --config configs/config_ali09988_candle-5min.yaml --skip-existing
# Only train tokenizer
python train_sequential.py --config configs/config_ali09988_candle-5min.yaml --skip-basemodel
# Only train basemodel
python train_sequential.py --config configs/config_ali09988_candle-5min.yaml --skip-tokenizer
```
Run each stage separately:
```bash
# Only train tokenizer
python finetune_tokenizer.py --config configs/config_ali09988_candle-5min.yaml
# Only train basemodel (requires fine-tuned tokenizer first)
python finetune_base_model.py --config configs/config_ali09988_candle-5min.yaml
```
DDP training:
```bash
# Choose the communication backend; nccl can be replaced with gloo
DIST_BACKEND=nccl \
torchrun --standalone --nproc_per_node=8 train_sequential.py --config configs/config_ali09988_candle-5min.yaml
```
## Configuration Description
### Main Configuration Items
- **data**: Data-related configuration
  - `data_path`: CSV data file path
  - `lookback_window`: Lookback window size
  - `predict_window`: Prediction window size
  - `train_ratio/val_ratio/test_ratio`: Dataset split ratios
- **training**: Training-related configuration
  - `epochs`: Number of training epochs
  - `batch_size`: Batch size
  - `tokenizer_learning_rate`: Tokenizer learning rate
  - `predictor_learning_rate`: Predictor learning rate
- **model_paths**: Model path configuration
  - `pretrained_tokenizer`: Pre-trained tokenizer path
  - `pretrained_predictor`: Pre-trained predictor path
  - `base_save_path`: Root directory for saved models
  - `finetuned_tokenizer`: Fine-tuned tokenizer path (used for basemodel training)
- **experiment**: Experiment control
  - `train_tokenizer`: Whether to train the tokenizer
  - `train_basemodel`: Whether to train the basemodel
  - `skip_existing`: Whether to skip existing models
## Training Process
1. **Tokenizer Fine-tuning Stage**
   - Load the pre-trained tokenizer
   - Fine-tune it on the custom data
   - Save the fine-tuned tokenizer to `{base_save_path}/tokenizer/best_model/`
2. **Basemodel Fine-tuning Stage**
   - Load the fine-tuned tokenizer and the pre-trained predictor
   - Fine-tune on the custom data
   - Save the fine-tuned basemodel to `{base_save_path}/basemodel/best_model/`

**Data Format**: Ensure the CSV file contains the following columns: `timestamps`, `open`, `high`, `low`, `close`, `volume`, `amount`
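Before launching a long training run, the required header can be checked with the standard library. This is a minimal sketch; `validate_csv_header` is an illustrative helper, not part of this repo:

```python
import csv
import io

REQUIRED_COLUMNS = {"timestamps", "open", "high", "low", "close", "volume", "amount"}

def validate_csv_header(fileobj):
    """Return the set of required columns missing from the CSV header row."""
    header = next(csv.reader(fileobj))
    return REQUIRED_COLUMNS - {h.strip() for h in header}

# A file missing the 'amount' column:
sample = io.StringIO("timestamps,open,high,low,close,volume\n2025-01-01,1,2,0.5,1.5,100\n")
print(validate_csv_header(sample))  # {'amount'}
```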

finetune_csv/README_CN.md (new file, 105 lines)
@@ -0,0 +1,105 @@
# Kronos Fine-tuning Training with a Custom Dataset
Supports fine-tuning on custom CSV data, driven by a YAML configuration file.
## Quick Start
### 1. Configuration Setup
First edit the `config.yaml` file to set the correct paths and parameters:
```yaml
# Data configuration
data:
  data_path: "/path/to/your/data.csv"
  lookback_window: 512
  predict_window: 48
  # ... other parameters

# Model path configuration
model_paths:
  pretrained_tokenizer: "/path/to/pretrained/tokenizer"
  pretrained_predictor: "/path/to/pretrained/predictor"
  base_save_path: "/path/to/save/models"
  # ... other paths
```
### 2. Run Training
Using `train_sequential.py`:
```bash
# Complete training
python train_sequential.py --config configs/config_ali09988_candle-5min.yaml
# Skip existing models
python train_sequential.py --config configs/config_ali09988_candle-5min.yaml --skip-existing
# Only train tokenizer
python train_sequential.py --config configs/config_ali09988_candle-5min.yaml --skip-basemodel
# Only train basemodel
python train_sequential.py --config configs/config_ali09988_candle-5min.yaml --skip-tokenizer
```
Run each stage separately:
```bash
# Only train tokenizer
python finetune_tokenizer.py --config configs/config_ali09988_candle-5min.yaml
# Only train basemodel (requires a fine-tuned tokenizer first)
python finetune_base_model.py --config configs/config_ali09988_candle-5min.yaml
```
DDP training:
```bash
# Choose the communication backend; nccl can be replaced with gloo
DIST_BACKEND=nccl \
torchrun --standalone --nproc_per_node=8 train_sequential.py --config configs/config_ali09988_candle-5min.yaml
```
## Configuration Description
### Main Configuration Items
- **data**: Data-related configuration
  - `data_path`: CSV data file path
  - `lookback_window`: Lookback window size
  - `predict_window`: Prediction window size
  - `train_ratio/val_ratio/test_ratio`: Dataset split ratios
- **training**: Training-related configuration
  - `epochs`: Number of training epochs
  - `batch_size`: Batch size
  - `tokenizer_learning_rate`: Tokenizer learning rate
  - `predictor_learning_rate`: Predictor learning rate
- **model_paths**: Model path configuration
  - `pretrained_tokenizer`: Pre-trained tokenizer path
  - `pretrained_predictor`: Pre-trained predictor path
  - `base_save_path`: Root directory for saved models
  - `finetuned_tokenizer`: Fine-tuned tokenizer path (used for basemodel training)
- **experiment**: Experiment control
  - `train_tokenizer`: Whether to train the tokenizer
  - `train_basemodel`: Whether to train the basemodel
  - `skip_existing`: Whether to skip existing models
## Training Process
1. **Tokenizer Fine-tuning Stage**
   - Load the pre-trained tokenizer
   - Fine-tune it on the custom data
   - Save the fine-tuned tokenizer to `{base_save_path}/tokenizer/best_model/`
2. **Basemodel Fine-tuning Stage**
   - Load the fine-tuned tokenizer and the pre-trained predictor
   - Fine-tune on the custom data
   - Save the fine-tuned basemodel to `{base_save_path}/basemodel/best_model/`

**Data Format**: Ensure the CSV file contains the following columns: `timestamps`, `open`, `high`, `low`, `close`, `volume`, `amount`

@@ -0,0 +1,267 @@
import os
import yaml
from typing import Dict, Any


class ConfigLoader:
    def __init__(self, config_path: str):
        self.config_path = config_path
        self.config = self._load_config()

    def _load_config(self) -> Dict[str, Any]:
        if not os.path.exists(self.config_path):
            raise FileNotFoundError(f"config file not found: {self.config_path}")
        with open(self.config_path, 'r', encoding='utf-8') as f:
            config = yaml.safe_load(f)
        config = self._resolve_dynamic_paths(config)
        return config

    def _resolve_dynamic_paths(self, config: Dict[str, Any]) -> Dict[str, Any]:
        exp_name = config.get('model_paths', {}).get('exp_name', '')
        if not exp_name:
            return config
        base_path = config.get('model_paths', {}).get('base_path', '')
        path_templates = {
            'base_save_path': f"{base_path}/{exp_name}",
            'finetuned_tokenizer': f"{base_path}/{exp_name}/tokenizer/best_model"
        }
        if 'model_paths' in config:
            for key, template in path_templates.items():
                if key in config['model_paths']:
                    # Only use the generated template when the original value is empty.
                    current_value = config['model_paths'][key]
                    if current_value == "" or current_value is None:
                        config['model_paths'][key] = template
                    else:
                        # If the value is non-empty, fill any {exp_name} placeholder.
                        if isinstance(current_value, str) and '{exp_name}' in current_value:
                            config['model_paths'][key] = current_value.format(exp_name=exp_name)
        return config

    def get(self, key: str, default=None):
        keys = key.split('.')
        value = self.config
        try:
            for k in keys:
                value = value[k]
            return value
        except (KeyError, TypeError):
            return default

    def get_data_config(self) -> Dict[str, Any]:
        return self.config.get('data', {})

    def get_training_config(self) -> Dict[str, Any]:
        return self.config.get('training', {})

    def get_model_paths(self) -> Dict[str, str]:
        return self.config.get('model_paths', {})

    def get_experiment_config(self) -> Dict[str, Any]:
        return self.config.get('experiment', {})

    def get_device_config(self) -> Dict[str, Any]:
        return self.config.get('device', {})

    def get_distributed_config(self) -> Dict[str, Any]:
        return self.config.get('distributed', {})

    def update_config(self, updates: Dict[str, Any]):
        def update_nested_dict(d, u):
            for k, v in u.items():
                if isinstance(v, dict):
                    d[k] = update_nested_dict(d.get(k, {}), v)
                else:
                    d[k] = v
            return d
        self.config = update_nested_dict(self.config, updates)

    def save_config(self, save_path: str = None):
        if save_path is None:
            save_path = self.config_path
        with open(save_path, 'w', encoding='utf-8') as f:
            yaml.dump(self.config, f, default_flow_style=False, allow_unicode=True, indent=2)

    def print_config(self):
        print("=" * 50)
        print("Current configuration:")
        print("=" * 50)
        # yaml.dump returns a string when no stream is given; print it explicitly.
        print(yaml.dump(self.config, default_flow_style=False, allow_unicode=True, indent=2))
        print("=" * 50)
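The dotted-key lookup in `ConfigLoader.get` can be reproduced standalone; this sketch shows the same traversal pattern (the function name here is illustrative, not part of the module):

```python
def get_by_dotted_key(config, key, default=None):
    """Walk a nested dict with a dotted key, e.g. 'data.lookback_window'."""
    value = config
    try:
        for part in key.split('.'):
            value = value[part]
        return value
    except (KeyError, TypeError):
        return default

cfg = {"data": {"lookback_window": 512, "predict_window": 48}}
print(get_by_dotted_key(cfg, "data.lookback_window"))  # 512
print(get_by_dotted_key(cfg, "data.missing", 48))      # 48
```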
class CustomFinetuneConfig:
    def __init__(self, config_path: str = None):
        if config_path is None:
            config_path = os.path.join(os.path.dirname(__file__), 'config.yaml')
        self.loader = ConfigLoader(config_path)
        self._load_all_configs()

    def _load_all_configs(self):
        data_config = self.loader.get_data_config()
        self.data_path = data_config.get('data_path')
        self.lookback_window = data_config.get('lookback_window', 512)
        self.predict_window = data_config.get('predict_window', 48)
        self.max_context = data_config.get('max_context', 512)
        self.clip = data_config.get('clip', 5.0)
        self.train_ratio = data_config.get('train_ratio', 0.9)
        self.val_ratio = data_config.get('val_ratio', 0.1)
        self.test_ratio = data_config.get('test_ratio', 0.0)

        # training configuration
        training_config = self.loader.get_training_config()
        # support separate epoch counts for tokenizer and basemodel
        self.tokenizer_epochs = training_config.get('tokenizer_epochs', 30)
        self.basemodel_epochs = training_config.get('basemodel_epochs', 30)
        if 'epochs' in training_config and 'tokenizer_epochs' not in training_config:
            self.tokenizer_epochs = training_config.get('epochs', 30)
        if 'epochs' in training_config and 'basemodel_epochs' not in training_config:
            self.basemodel_epochs = training_config.get('epochs', 30)
        self.batch_size = training_config.get('batch_size', 160)
        self.log_interval = training_config.get('log_interval', 50)
        self.num_workers = training_config.get('num_workers', 6)
        self.seed = training_config.get('seed', 100)
        self.tokenizer_learning_rate = training_config.get('tokenizer_learning_rate', 2e-4)
        self.predictor_learning_rate = training_config.get('predictor_learning_rate', 4e-5)
        self.adam_beta1 = training_config.get('adam_beta1', 0.9)
        self.adam_beta2 = training_config.get('adam_beta2', 0.95)
        self.adam_weight_decay = training_config.get('adam_weight_decay', 0.1)
        self.accumulation_steps = training_config.get('accumulation_steps', 1)

        model_paths = self.loader.get_model_paths()
        self.exp_name = model_paths.get('exp_name', 'default_experiment')
        self.pretrained_tokenizer_path = model_paths.get('pretrained_tokenizer')
        self.pretrained_predictor_path = model_paths.get('pretrained_predictor')
        self.base_save_path = model_paths.get('base_save_path')
        self.tokenizer_save_name = model_paths.get('tokenizer_save_name', 'tokenizer')
        self.basemodel_save_name = model_paths.get('basemodel_save_name', 'basemodel')
        self.finetuned_tokenizer_path = model_paths.get('finetuned_tokenizer')

        experiment_config = self.loader.get_experiment_config()
        self.experiment_name = experiment_config.get('name', 'kronos_custom_finetune')
        self.experiment_description = experiment_config.get('description', '')
        self.use_comet = experiment_config.get('use_comet', False)
        self.train_tokenizer = experiment_config.get('train_tokenizer', True)
        self.train_basemodel = experiment_config.get('train_basemodel', True)
        self.skip_existing = experiment_config.get('skip_existing', False)
        unified_pretrained = experiment_config.get('pre_trained', None)
        self.pre_trained_tokenizer = experiment_config.get(
            'pre_trained_tokenizer', unified_pretrained if unified_pretrained is not None else True)
        self.pre_trained_predictor = experiment_config.get(
            'pre_trained_predictor', unified_pretrained if unified_pretrained is not None else True)

        device_config = self.loader.get_device_config()
        self.use_cuda = device_config.get('use_cuda', True)
        self.device_id = device_config.get('device_id', 0)

        distributed_config = self.loader.get_distributed_config()
        self.use_ddp = distributed_config.get('use_ddp', False)
        self.ddp_backend = distributed_config.get('backend', 'nccl')

        self._compute_full_paths()

    def _compute_full_paths(self):
        self.tokenizer_save_path = os.path.join(self.base_save_path, self.tokenizer_save_name)
        self.tokenizer_best_model_path = os.path.join(self.tokenizer_save_path, 'best_model')
        self.basemodel_save_path = os.path.join(self.base_save_path, self.basemodel_save_name)
        self.basemodel_best_model_path = os.path.join(self.basemodel_save_path, 'best_model')

    def get_tokenizer_config(self):
        return {
            'data_path': self.data_path,
            'lookback_window': self.lookback_window,
            'predict_window': self.predict_window,
            'max_context': self.max_context,
            'clip': self.clip,
            'train_ratio': self.train_ratio,
            'val_ratio': self.val_ratio,
            'test_ratio': self.test_ratio,
            'epochs': self.tokenizer_epochs,
            'batch_size': self.batch_size,
            'log_interval': self.log_interval,
            'num_workers': self.num_workers,
            'seed': self.seed,
            'learning_rate': self.tokenizer_learning_rate,
            'adam_beta1': self.adam_beta1,
            'adam_beta2': self.adam_beta2,
            'adam_weight_decay': self.adam_weight_decay,
            'accumulation_steps': self.accumulation_steps,
            'pretrained_model_path': self.pretrained_tokenizer_path,
            'save_path': self.tokenizer_save_path,
            'use_comet': self.use_comet
        }

    def get_basemodel_config(self):
        return {
            'data_path': self.data_path,
            'lookback_window': self.lookback_window,
            'predict_window': self.predict_window,
            'max_context': self.max_context,
            'clip': self.clip,
            'train_ratio': self.train_ratio,
            'val_ratio': self.val_ratio,
            'test_ratio': self.test_ratio,
            'epochs': self.basemodel_epochs,
            'batch_size': self.batch_size,
            'log_interval': self.log_interval,
            'num_workers': self.num_workers,
            'seed': self.seed,
            'predictor_learning_rate': self.predictor_learning_rate,
            'tokenizer_learning_rate': self.tokenizer_learning_rate,
            'adam_beta1': self.adam_beta1,
            'adam_beta2': self.adam_beta2,
            'adam_weight_decay': self.adam_weight_decay,
            'pretrained_tokenizer_path': self.finetuned_tokenizer_path,
            'pretrained_predictor_path': self.pretrained_predictor_path,
            'save_path': self.basemodel_save_path,
            'use_comet': self.use_comet
        }

    def print_config_summary(self):
        print("=" * 60)
        print("Kronos finetuning configuration summary")
        print("=" * 60)
        print(f"Experiment name: {self.exp_name}")
        print(f"Data path: {self.data_path}")
        print(f"Lookback window: {self.lookback_window}")
        print(f"Predict window: {self.predict_window}")
        print(f"Tokenizer training epochs: {self.tokenizer_epochs}")
        print(f"Basemodel training epochs: {self.basemodel_epochs}")
        print(f"Batch size: {self.batch_size}")
        print(f"Tokenizer learning rate: {self.tokenizer_learning_rate}")
        print(f"Predictor learning rate: {self.predictor_learning_rate}")
        print(f"Train tokenizer: {self.train_tokenizer}")
        print(f"Train basemodel: {self.train_basemodel}")
        print(f"Skip existing: {self.skip_existing}")
        print(f"Use pre-trained tokenizer: {self.pre_trained_tokenizer}")
        print(f"Use pre-trained predictor: {self.pre_trained_predictor}")
        print(f"Base save path: {self.base_save_path}")
        print(f"Tokenizer save path: {self.tokenizer_save_path}")
        print(f"Basemodel save path: {self.basemodel_save_path}")
        print("=" * 60)
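The per-stage epoch fallback in `_load_all_configs` (a shared `epochs` key applies only when the stage-specific key is absent) can be condensed to the following standalone sketch (`resolve_epochs` is an illustrative name, not a repo function):

```python
def resolve_epochs(training_config, default=30):
    """Stage-specific epochs, falling back to a shared 'epochs' key."""
    tokenizer_epochs = training_config.get('tokenizer_epochs', default)
    basemodel_epochs = training_config.get('basemodel_epochs', default)
    if 'epochs' in training_config and 'tokenizer_epochs' not in training_config:
        tokenizer_epochs = training_config['epochs']
    if 'epochs' in training_config and 'basemodel_epochs' not in training_config:
        basemodel_epochs = training_config['epochs']
    return tokenizer_epochs, basemodel_epochs

print(resolve_epochs({'epochs': 10}))                         # (10, 10)
print(resolve_epochs({'epochs': 10, 'tokenizer_epochs': 5}))  # (5, 10)
```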

@@ -0,0 +1,72 @@
# This is a template config for custom fine-tuning Kronos on CSV data

data:
  data_path: "/xxxx/Kronos/finetune_csv2/data/HK_ali_09988_kline_5min_all.csv"
  lookback_window: 512
  predict_window: 48
  max_context: 512
  clip: 5.0
  # dataset split ratios
  train_ratio: 0.9
  val_ratio: 0.1
  test_ratio: 0.0

training:
  # training epochs for the tokenizer and basemodel stages
  tokenizer_epochs: 30
  basemodel_epochs: 20
  batch_size: 32
  log_interval: 50
  num_workers: 6
  seed: 42
  tokenizer_learning_rate: 0.0002
  predictor_learning_rate: 0.000001
  adam_beta1: 0.9
  adam_beta2: 0.95
  adam_weight_decay: 0.1
  # gradient accumulation steps for tokenizer training
  accumulation_steps: 1

# model path configuration
model_paths:
  # pretrained model paths
  pretrained_tokenizer: "/mnt/DigitalHuman2D/boyuzhang/quant/Kronos/pretrained/Kronos-Tokenizer-base"
  pretrained_predictor: "/mnt/DigitalHuman2D/boyuzhang/quant/Kronos/pretrained/Kronos-base"
  # experiment name - other paths will be generated from it
  exp_name: "HK_ali_09988_kline_5min_all"
  base_path: "/mnt/DigitalHuman2D/boyuzhang/quant/Kronos/finetune_csv/finetuned/"
  # the following paths are generated from exp_name; no need to edit them manually
  # way 1: leave an empty string and the system generates the full path
  base_save_path: ""       # /xxxx/Kronos/finetune_csv/finetuned/{exp_name}
  finetuned_tokenizer: ""  # /xxxx/quant/Kronos/finetune_csv/finetuned/{exp_name}/tokenizer/best_model
  # way 2: use a template string; {exp_name} is replaced with the actual experiment name
  # base_save_path: "/xxxx/Kronos/finetune_csv/finetuned/{exp_name}"
  # finetuned_tokenizer: "/xxxx/quant/Kronos/finetune_csv/finetuned/{exp_name}/tokenizer/best_model"
  tokenizer_save_name: "tokenizer"
  basemodel_save_name: "basemodel"

experiment:
  name: "kronos_custom_finetune"
  description: "Custom finetune for HK stock data"
  use_comet: false
  # control which training phases run
  train_tokenizer: true
  train_basemodel: true
  # if true, skip training for models that already exist
  skip_existing: false

# device configuration
device:
  use_cuda: true
  device_id: 0
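The two resolution modes described in the comments above mirror `_resolve_dynamic_paths` in `config_loader.py`. A standalone sketch (function name and paths are illustrative):

```python
def resolve_path(value, exp_name, generated_default):
    # Way 1: an empty value falls back to the generated default.
    if value in ("", None):
        return generated_default
    # Way 2: a template string has its {exp_name} placeholder filled in.
    if isinstance(value, str) and "{exp_name}" in value:
        return value.format(exp_name=exp_name)
    # Otherwise the explicit path is kept as-is.
    return value

exp = "HK_ali_09988_kline_5min_all"
default = f"/models/{exp}"
print(resolve_path("", exp, default))                   # /models/HK_ali_09988_kline_5min_all
print(resolve_path("/ckpts/{exp_name}", exp, default))  # /ckpts/HK_ali_09988_kline_5min_all
```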

@@ -0,0 +1,468 @@
import os
import sys
import json
import time
import pickle
import random
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from time import gmtime, strftime
import logging
from logging.handlers import RotatingFileHandler
import datetime
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
sys.path.append('../')
from model import Kronos, KronosTokenizer, KronosPredictor
from config_loader import CustomFinetuneConfig
class CustomKlineDataset(Dataset):
    def __init__(self, data_path, data_type='train', lookback_window=90, predict_window=10,
                 clip=5.0, seed=100, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15):
        self.data_path = data_path
        self.data_type = data_type
        self.lookback_window = lookback_window
        self.predict_window = predict_window
        self.window = lookback_window + predict_window + 1
        self.clip = clip
        self.seed = seed
        self.train_ratio = train_ratio
        self.val_ratio = val_ratio
        self.test_ratio = test_ratio
        self.feature_list = ['open', 'high', 'low', 'close', 'volume', 'amount']
        self.time_feature_list = ['minute', 'hour', 'weekday', 'day', 'month']
        self.py_rng = random.Random(seed)
        self._load_and_preprocess_data()
        self._split_data_by_time()
        self.n_samples = len(self.data) - self.window + 1
        print(f"[{data_type.upper()}] Data length: {len(self.data)}, Available samples: {self.n_samples}")

    def _load_and_preprocess_data(self):
        df = pd.read_csv(self.data_path)
        df['timestamps'] = pd.to_datetime(df['timestamps'])
        df = df.sort_values('timestamps').reset_index(drop=True)
        self.timestamps = df['timestamps'].copy()
        df['minute'] = df['timestamps'].dt.minute
        df['hour'] = df['timestamps'].dt.hour
        df['weekday'] = df['timestamps'].dt.weekday
        df['day'] = df['timestamps'].dt.day
        df['month'] = df['timestamps'].dt.month
        self.data = df[self.feature_list + self.time_feature_list].copy()
        if self.data.isnull().any().any():
            print("Warning: Missing values found in data, performing forward fill")
            self.data = self.data.ffill()
        print(f"Original data time range: {self.timestamps.min()} to {self.timestamps.max()}")
        print(f"Original data total length: {len(df)} records")

    def _split_data_by_time(self):
        total_length = len(self.data)
        train_end = int(total_length * self.train_ratio)
        val_end = int(total_length * (self.train_ratio + self.val_ratio))
        if self.data_type == 'train':
            self.data = self.data.iloc[:train_end].copy()
            self.timestamps = self.timestamps.iloc[:train_end].copy()
            print(f"[{self.data_type.upper()}] Training set: first {train_end} time points ({self.train_ratio})")
            print(f"[{self.data_type.upper()}] Training set time range: {self.timestamps.min()} to {self.timestamps.max()}")
        elif self.data_type == 'val':
            self.data = self.data.iloc[train_end:val_end].copy()
            self.timestamps = self.timestamps.iloc[train_end:val_end].copy()
            print(f"[{self.data_type.upper()}] Validation set: time points {train_end+1} to {val_end} ({self.val_ratio})")
            print(f"[{self.data_type.upper()}] Validation set time range: {self.timestamps.min()} to {self.timestamps.max()}")
        elif self.data_type == 'test':
            self.data = self.data.iloc[val_end:].copy()
            self.timestamps = self.timestamps.iloc[val_end:].copy()
            print(f"[{self.data_type.upper()}] Test set: after time point {val_end+1}")
            print(f"[{self.data_type.upper()}] Test set time range: {self.timestamps.min()} to {self.timestamps.max()}")
        print(f"[{self.data_type.upper()}] Data length after split: {len(self.data)} records")

    def set_epoch_seed(self, epoch):
        epoch_seed = self.seed + epoch
        self.py_rng.seed(epoch_seed)
        self.current_epoch = epoch

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        max_start = len(self.data) - self.window
        if max_start <= 0:
            raise ValueError("Data length insufficient to create samples")
        if self.data_type == 'train':
            # Deterministic pseudo-random window start that varies with the epoch.
            epoch = getattr(self, 'current_epoch', 0)
            start_idx = (idx * 9973 + (epoch + 1) * 104729) % (max_start + 1)
        else:
            start_idx = idx % (max_start + 1)
        end_idx = start_idx + self.window
        window_data = self.data.iloc[start_idx:end_idx]
        x = window_data[self.feature_list].values.astype(np.float32)
        x_stamp = window_data[self.time_feature_list].values.astype(np.float32)
        # Per-window z-score normalization with outlier clipping.
        x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0)
        x = (x - x_mean) / (x_std + 1e-5)
        x = np.clip(x, -self.clip, self.clip)
        x_tensor = torch.from_numpy(x)
        x_stamp_tensor = torch.from_numpy(x_stamp)
        return x_tensor, x_stamp_tensor
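The per-window normalization step in `__getitem__` can be isolated as follows; `normalize_window` is an illustrative helper for exposition, not a repo function:

```python
import numpy as np

def normalize_window(x, clip=5.0):
    """Per-window, per-column z-score with outlier clipping."""
    mean, std = x.mean(axis=0), x.std(axis=0)
    z = (x - mean) / (std + 1e-5)
    return np.clip(z, -clip, clip)

window = np.array([[1.0, 10.0], [3.0, 30.0], [2.0, 20.0]], dtype=np.float32)
z = normalize_window(window)
print(z.mean(axis=0))  # ~[0, 0]: each column is centered within the window
```

Because statistics are computed per window (not globally), each sample is self-normalized, which is why the loaders above can feed windows from different price regimes to the same model.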
def setup_logging(exp_name: str, log_dir: str, rank: int = 0) -> logging.Logger:
    os.makedirs(log_dir, exist_ok=True)
    logger = logging.getLogger(f"basemodel_training_rank_{rank}")
    logger.setLevel(logging.INFO)
    if logger.handlers:
        return logger
    log_file = os.path.join(log_dir, f"basemodel_training_rank_{rank}.log")
    file_handler = RotatingFileHandler(
        log_file,
        maxBytes=10 * 1024 * 1024,
        backupCount=5,
        encoding='utf-8'
    )
    file_handler.setLevel(logging.INFO)
    console_handler = None
    if rank == 0:
        console_handler = logging.StreamHandler()
        console_handler.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )
    file_handler.setFormatter(formatter)
    if console_handler is not None:
        console_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    if console_handler is not None:
        logger.addHandler(console_handler)
    logger.info("=== Basemodel Training Started ===")
    logger.info(f"Experiment Name: {exp_name}")
    logger.info(f"Log Directory: {log_dir}")
    logger.info(f"Rank: {rank}")
    logger.info(f"Timestamp: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    return logger
def create_dataloaders(config):
    if not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0:
        print("Creating data loaders...")
    train_dataset = CustomKlineDataset(
        data_path=config.data_path,
        data_type='train',
        lookback_window=config.lookback_window,
        predict_window=config.predict_window,
        clip=config.clip,
        seed=config.seed,
        train_ratio=config.train_ratio,
        val_ratio=config.val_ratio,
        test_ratio=config.test_ratio
    )
    val_dataset = CustomKlineDataset(
        data_path=config.data_path,
        data_type='val',
        lookback_window=config.lookback_window,
        predict_window=config.predict_window,
        clip=config.clip,
        seed=config.seed + 1,
        train_ratio=config.train_ratio,
        val_ratio=config.val_ratio,
        test_ratio=config.test_ratio
    )
    use_ddp = dist.is_available() and dist.is_initialized()
    train_sampler = DistributedSampler(train_dataset, num_replicas=dist.get_world_size(),
                                       rank=dist.get_rank(), shuffle=True) if use_ddp else None
    val_sampler = DistributedSampler(val_dataset, num_replicas=dist.get_world_size(),
                                     rank=dist.get_rank(), shuffle=False, drop_last=False) if use_ddp else None
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=(train_sampler is None),
        num_workers=config.num_workers,
        pin_memory=True,
        drop_last=True,
        sampler=train_sampler
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=config.num_workers,
        pin_memory=True,
        drop_last=False,
        sampler=val_sampler
    )
    if not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0:
        print(f"Training set size: {len(train_dataset)}, Validation set size: {len(val_dataset)}")
    return train_loader, val_loader, train_dataset, val_dataset, train_sampler, val_sampler
def train_model(model, tokenizer, device, config, save_dir, logger):
    logger.info("Starting training...")
    use_ddp = dist.is_available() and dist.is_initialized()
    rank = dist.get_rank() if use_ddp else 0
    world_size = dist.get_world_size() if use_ddp else 1
    train_loader, val_loader, train_dataset, val_dataset, train_sampler, val_sampler = create_dataloaders(config)
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.predictor_learning_rate,
        betas=(config.adam_beta1, config.adam_beta2),
        weight_decay=config.adam_weight_decay
    )
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=config.predictor_learning_rate,
        steps_per_epoch=len(train_loader),
        epochs=config.basemodel_epochs,
        pct_start=0.03,
        div_factor=10
    )
    if use_ddp:
        local_rank = int(os.environ.get("LOCAL_RANK", "0"))
        model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=False)

    best_val_loss = float('inf')
    batch_idx_global = 0
    for epoch in range(config.basemodel_epochs):
        epoch_start_time = time.time()
        model.train()
        train_dataset.set_epoch_seed(epoch * 10000)
        val_dataset.set_epoch_seed(0)
        if train_sampler is not None:
            train_sampler.set_epoch(epoch)
        epoch_train_loss = 0.0
        train_batches = 0
        for batch_idx, (batch_x, batch_x_stamp) in enumerate(train_loader):
            batch_x = batch_x.to(device, non_blocking=True)
            batch_x_stamp = batch_x_stamp.to(device, non_blocking=True)
            # Tokenize the continuous window without tracking gradients.
            with torch.no_grad():
                token_seq_0, token_seq_1 = tokenizer.encode(batch_x, half=True)
            # Teacher forcing: inputs are tokens [0, n-1), targets are tokens [1, n).
            token_in = [token_seq_0[:, :-1], token_seq_1[:, :-1]]
            token_out = [token_seq_0[:, 1:], token_seq_1[:, 1:]]
            logits = (model.module if use_ddp else model)(token_in[0], token_in[1], batch_x_stamp[:, :-1, :])
            loss, s1_loss, s2_loss = (model.module if use_ddp else model).head.compute_loss(
                logits[0], logits[1], token_out[0], token_out[1])
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_((model.module if use_ddp else model).parameters(), max_norm=3.0)
            optimizer.step()
            scheduler.step()
            epoch_train_loss += loss.item()
            train_batches += 1
            if (batch_idx_global + 1) % config.log_interval == 0:
                lr = optimizer.param_groups[0]['lr']
                log_msg = (f"[Epoch {epoch+1}/{config.basemodel_epochs}, Step {batch_idx+1}/{len(train_loader)}] "
                           f"LR: {lr:.6f}, Loss: {loss.item():.4f}")
                logger.info(log_msg)
                if rank == 0:
                    print(log_msg)
            batch_idx_global += 1

        model.eval()
        val_loss = 0.0
        val_batches = 0
        with torch.no_grad():
            for batch_x, batch_x_stamp in val_loader:
                batch_x = batch_x.to(device, non_blocking=True)
                batch_x_stamp = batch_x_stamp.to(device, non_blocking=True)
                token_seq_0, token_seq_1 = tokenizer.encode(batch_x, half=True)
                token_in = [token_seq_0[:, :-1], token_seq_1[:, :-1]]
                token_out = [token_seq_0[:, 1:], token_seq_1[:, 1:]]
                logits = (model.module if use_ddp else model)(token_in[0], token_in[1], batch_x_stamp[:, :-1, :])
                loss, _, _ = (model.module if use_ddp else model).head.compute_loss(
                    logits[0], logits[1], token_out[0], token_out[1])
                val_loss += loss.item()
                val_batches += 1
        if use_ddp:
            # Aggregate loss sums and batch counts across ranks before averaging.
            tensor_sum = torch.tensor([epoch_train_loss, train_batches, val_loss, val_batches],
                                      dtype=torch.float64, device=device)
            dist.all_reduce(tensor_sum, op=dist.ReduceOp.SUM)
            epoch_train_loss_all = tensor_sum[0].item()
            train_batches_all = int(tensor_sum[1].item())
            val_loss_all = tensor_sum[2].item()
            val_batches_all = int(tensor_sum[3].item())
            avg_train_loss = (epoch_train_loss_all / train_batches_all) if train_batches_all > 0 else 0.0
            avg_val_loss = (val_loss_all / val_batches_all) if val_batches_all > 0 else 0.0
        else:
            avg_train_loss = epoch_train_loss / train_batches if train_batches > 0 else 0
            avg_val_loss = val_loss / val_batches if val_batches > 0 else 0
        epoch_time = time.time() - epoch_start_time
        epoch_summary = (f"\n--- Epoch {epoch+1}/{config.basemodel_epochs} Summary ---\n"
                         f"Training Loss: {avg_train_loss:.4f}\n"
                         f"Validation Loss: {avg_val_loss:.4f}\n"
                         f"Epoch Time: {epoch_time:.2f} seconds\n")
        logger.info(epoch_summary)
        if rank == 0:
            print(epoch_summary)
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            if rank == 0:
                model_save_path = os.path.join(save_dir, "best_model")
                os.makedirs(model_save_path, exist_ok=True)
                (model.module if use_ddp else model).save_pretrained(model_save_path)
                save_msg = f"Best model saved to: {model_save_path} (validation loss: {best_val_loss:.4f})"
                logger.info(save_msg)
                print(save_msg)
    return best_val_loss
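The token shift used inside `train_model` (inputs `seq[:, :-1]`, targets `seq[:, 1:]`) is the standard teacher-forcing split for next-token prediction. A minimal sketch on plain lists (`shift_for_next_token` is an illustrative name):

```python
def shift_for_next_token(seq):
    """Teacher-forcing split: the model sees t0..t_{n-2} and predicts t1..t_{n-1}."""
    return seq[:-1], seq[1:]

tokens = [101, 7, 42, 9, 200]
inputs, targets = shift_for_next_token(tokens)
print(inputs)   # [101, 7, 42, 9]
print(targets)  # [7, 42, 9, 200]
```

The same shift is applied to the time-stamp features (`batch_x_stamp[:, :-1, :]`) so each input token is paired with its own timestamp.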
def main():
import argparse
parser = argparse.ArgumentParser(description='Kronos Basemodel Fine-tuning Training')
parser.add_argument('--config', type=str, default='config.yaml',
help='Configuration file path (default: config.yaml)')
args = parser.parse_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
config = CustomFinetuneConfig(args.config)
os.makedirs(config.basemodel_save_path, exist_ok=True)
log_dir = os.path.join(config.base_save_path, "logs")
logger = setup_logging(config.exp_name, log_dir, 0)
torch.manual_seed(config.seed)
np.random.seed(config.seed)
random.seed(config.seed)
logger.info("Loading pretrained model or random initialization...")
print("Loading pretrained model or random initialization...")
if getattr(config, 'pre_trained_tokenizer', True):
tokenizer = KronosTokenizer.from_pretrained(config.finetuned_tokenizer_path)
else:
import json, os
print("pre_trained_tokenizer=False, randomly initializing Tokenizer architecture for training")
cfg_path_tok = os.path.join(config.pretrained_tokenizer_path if hasattr(config, 'pretrained_tokenizer_path') else config.finetuned_tokenizer_path, 'config.json')
with open(cfg_path_tok, 'r') as f:
arch_t = json.load(f)
tokenizer = KronosTokenizer(
d_in=arch_t.get('d_in', 6),
d_model=arch_t.get('d_model', 256),
n_heads=arch_t.get('n_heads', 4),
ff_dim=arch_t.get('ff_dim', 512),
n_enc_layers=arch_t.get('n_enc_layers', 4),
n_dec_layers=arch_t.get('n_dec_layers', 4),
ffn_dropout_p=arch_t.get('ffn_dropout_p', 0.0),
attn_dropout_p=arch_t.get('attn_dropout_p', 0.0),
resid_dropout_p=arch_t.get('resid_dropout_p', 0.0),
s1_bits=arch_t.get('s1_bits', 10),
s2_bits=arch_t.get('s2_bits', 10),
beta=arch_t.get('beta', 0.05),
gamma0=arch_t.get('gamma0', 1.0),
gamma=arch_t.get('gamma', 1.1),
zeta=arch_t.get('zeta', 0.05),
group_size=arch_t.get('group_size', 4)
)
if getattr(config, 'pre_trained_predictor', True):
model = Kronos.from_pretrained(config.pretrained_predictor_path)
else:
import json, os
print("pre_trained_predictor=False, randomly initializing Predictor architecture for training")
cfg_path = os.path.join(config.pretrained_predictor_path, 'config.json')
with open(cfg_path, 'r') as f:
arch = json.load(f)
model = Kronos(
s1_bits=arch.get('s1_bits', 10),
s2_bits=arch.get('s2_bits', 10),
n_layers=arch.get('n_layers', 12),
d_model=arch.get('d_model', 832),
n_heads=arch.get('n_heads', 16),
ff_dim=arch.get('ff_dim', 2048),
ffn_dropout_p=arch.get('ffn_dropout_p', 0.2),
attn_dropout_p=arch.get('attn_dropout_p', 0.0),
resid_dropout_p=arch.get('resid_dropout_p', 0.2),
token_dropout_p=arch.get('token_dropout_p', 0.0),
learn_te=arch.get('learn_te', True)
)
tokenizer = tokenizer.to(device)
model = model.to(device)
model_size = sum(p.numel() for p in model.parameters())
logger.info(f"Model parameters: {model_size:,}")
print(f"Model parameters: {model_size:,}")
logger.info("=== Training Configuration ===")
logger.info(f"Data path: {config.data_path}")
logger.info(f"Lookback window: {config.lookback_window}")
logger.info(f"Predict window: {config.predict_window}")
logger.info(f"Batch size: {config.batch_size}")
logger.info(f"Learning rate: {config.predictor_learning_rate}")
logger.info(f"Training epochs: {config.basemodel_epochs}")
logger.info(f"Device: {device}")
logger.info(f"Tokenizer path: {config.finetuned_tokenizer_path}")
logger.info(f"Pretrained model path: {config.pretrained_predictor_path}")
logger.info("Starting fine-tuning training...")
print("Starting fine-tuning training...")
best_val_loss = train_model(model, tokenizer, device, config, config.basemodel_save_path, logger)
final_msg = f"Training completed! Best validation loss: {best_val_loss:.4f}\nModel saved to: {config.basemodel_save_path}"
logger.info(final_msg)
print(final_msg)
if __name__ == "__main__":
main()
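When `pre_trained_predictor` is false, the predictor is rebuilt from the `config.json` inside the pretrained checkpoint directory. A minimal sketch of the fields that file is expected to carry, filled with the script's own `arch.get(...)` fallback defaults (illustrative values only; a real checkpoint's config may differ):

```json
{
  "s1_bits": 10,
  "s2_bits": 10,
  "n_layers": 12,
  "d_model": 832,
  "n_heads": 16,
  "ff_dim": 2048,
  "ffn_dropout_p": 0.2,
  "attn_dropout_p": 0.0,
  "resid_dropout_p": 0.2,
  "token_dropout_p": 0.0,
  "learn_te": true
}
```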

finetune_tokenizer.py (new file, 359 lines)
import os
import sys
import json
import time
import random
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from time import gmtime, strftime
import datetime
import logging
from logging.handlers import RotatingFileHandler
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
sys.path.append("../")
from model import KronosTokenizer
from finetune_base_model import CustomKlineDataset
from config_loader import CustomFinetuneConfig
def set_seed(seed: int, rank: int = 0):
actual_seed = seed
random.seed(actual_seed)
np.random.seed(actual_seed)
torch.manual_seed(actual_seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(actual_seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def get_model_size(model: torch.nn.Module) -> str:
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
if total_params >= 1e9:
return f"{total_params / 1e9:.1f}B"
elif total_params >= 1e6:
return f"{total_params / 1e6:.1f}M"
else:
return f"{total_params / 1e3:.1f}K"
def format_time(seconds: float) -> str:
return str(datetime.timedelta(seconds=int(seconds)))
def setup_logging(exp_name: str, log_dir: str, rank: int = 0) -> logging.Logger:
os.makedirs(log_dir, exist_ok=True)
logger = logging.getLogger(f"tokenizer_training_rank_{rank}")
logger.setLevel(logging.INFO)
if logger.handlers:
return logger
log_file = os.path.join(log_dir, f"tokenizer_training_rank_{rank}.log")
file_handler = RotatingFileHandler(
log_file,
maxBytes=10*1024*1024,
backupCount=5,
encoding='utf-8'
)
file_handler.setLevel(logging.INFO)
console_handler = None
if rank == 0:
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
file_handler.setFormatter(formatter)
if console_handler is not None:
console_handler.setFormatter(formatter)
logger.addHandler(file_handler)
if console_handler is not None:
logger.addHandler(console_handler)
logger.info(f"=== Tokenizer Training Started ===")
logger.info(f"Experiment Name: {exp_name}")
logger.info(f"Log Directory: {log_dir}")
logger.info(f"Rank: {rank}")
logger.info(f"Timestamp: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
return logger
def create_dataloaders(config):
if not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0:
print("Creating tokenizer training data loaders...")
train_dataset = CustomKlineDataset(
data_path=config.data_path,
data_type="train",
lookback_window=config.lookback_window,
predict_window=config.predict_window,
clip=config.clip,
seed=config.seed,
train_ratio=config.train_ratio,
val_ratio=config.val_ratio,
test_ratio=config.test_ratio
)
val_dataset = CustomKlineDataset(
data_path=config.data_path,
data_type="val",
lookback_window=config.lookback_window,
predict_window=config.predict_window,
clip=config.clip,
seed=config.seed + 1,
train_ratio=config.train_ratio,
val_ratio=config.val_ratio,
test_ratio=config.test_ratio
)
use_ddp = dist.is_available() and dist.is_initialized()
train_sampler = DistributedSampler(train_dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True) if use_ddp else None
val_sampler = DistributedSampler(val_dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=False, drop_last=False) if use_ddp else None
train_loader = DataLoader(
train_dataset,
batch_size=config.batch_size,
shuffle=(train_sampler is None),
num_workers=config.num_workers,
pin_memory=True,
drop_last=True,
sampler=train_sampler
)
val_loader = DataLoader(
val_dataset,
batch_size=config.batch_size,
shuffle=False,
num_workers=config.num_workers,
pin_memory=True,
drop_last=False,
sampler=val_sampler
)
if not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0:
print(f"Training set size: {len(train_dataset)}, Validation set size: {len(val_dataset)}")
return train_loader, val_loader, train_dataset, val_dataset, train_sampler, val_sampler
def train_tokenizer(model, device, config, save_dir, logger):
logger.info("Starting tokenizer training...")
use_ddp = dist.is_available() and dist.is_initialized()
rank = dist.get_rank() if use_ddp else 0
world_size = dist.get_world_size() if use_ddp else 1
train_loader, val_loader, train_dataset, val_dataset, train_sampler, val_sampler = create_dataloaders(config)
optimizer = torch.optim.AdamW(
model.parameters(),
lr=config.tokenizer_learning_rate,
weight_decay=config.adam_weight_decay
)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
optimizer,
max_lr=config.tokenizer_learning_rate,
steps_per_epoch=len(train_loader),
epochs=config.tokenizer_epochs,
pct_start=0.03,
div_factor=10
)
if use_ddp:
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=False)
best_val_loss = float("inf")
batch_idx_global = 0
accumulation_steps = getattr(config, 'accumulation_steps', 1)
for epoch in range(config.tokenizer_epochs):
epoch_start_time = time.time()
model.train()
train_dataset.set_epoch_seed(epoch * 10000)
val_dataset.set_epoch_seed(0)
if train_sampler is not None:
train_sampler.set_epoch(epoch)
for batch_idx, (ori_batch_x, _) in enumerate(train_loader):
ori_batch_x = ori_batch_x.squeeze(0).to(device, non_blocking=True)
current_batch_total_loss = 0.0
for j in range(accumulation_steps):
start_idx = j * (ori_batch_x.shape[0] // accumulation_steps)
end_idx = (j + 1) * (ori_batch_x.shape[0] // accumulation_steps)
batch_x = ori_batch_x[start_idx:end_idx]
zs, bsq_loss, _, _ = (model.module if use_ddp else model)(batch_x)
z_pre, z = zs
recon_loss_pre = F.mse_loss(z_pre, batch_x)
recon_loss_all = F.mse_loss(z, batch_x)
recon_loss = recon_loss_pre + recon_loss_all
loss = (recon_loss + bsq_loss) / 2
loss_scaled = loss / accumulation_steps
current_batch_total_loss += loss.item()
loss_scaled.backward()
torch.nn.utils.clip_grad_norm_((model.module if use_ddp else model).parameters(), max_norm=2.0)
optimizer.step()
scheduler.step()
optimizer.zero_grad()
if (batch_idx_global + 1) % config.log_interval == 0:
avg_loss = current_batch_total_loss / accumulation_steps
lr = optimizer.param_groups[0]["lr"]
log_msg = (f"[Epoch {epoch+1}/{config.tokenizer_epochs}, Step {batch_idx+1}/{len(train_loader)}] "
f"LR: {lr:.6f}, Loss: {avg_loss:.4f}")
logger.info(log_msg)
if rank == 0:
print(log_msg)
detail_msg = (f" - VQ Loss: {bsq_loss.item():.4f}\n"
f" - Recon Loss Pre: {recon_loss_pre.item():.4f}\n"
f" - Recon Loss All: {recon_loss_all.item():.4f}")
logger.info(detail_msg)
if rank == 0:
print(detail_msg)
batch_idx_global += 1
model.eval()
tot_val_loss_sum_rank = 0.0
val_sample_count_rank = 0
with torch.no_grad():
for ori_batch_x, _ in val_loader:
ori_batch_x = ori_batch_x.squeeze(0).to(device, non_blocking=True)
zs, _, _, _ = (model.module if use_ddp else model)(ori_batch_x)
_, z = zs
val_loss_item = F.mse_loss(z, ori_batch_x)
tot_val_loss_sum_rank += val_loss_item.item() * ori_batch_x.size(0)
val_sample_count_rank += ori_batch_x.size(0)
if use_ddp:
tensor_sum = torch.tensor([tot_val_loss_sum_rank, val_sample_count_rank], dtype=torch.float64, device=device)
dist.all_reduce(tensor_sum, op=dist.ReduceOp.SUM)
tot_val_loss_all = tensor_sum[0].item()
val_count_all = int(tensor_sum[1].item())
avg_val_loss = (tot_val_loss_all / val_count_all) if val_count_all > 0 else 0.0
else:
avg_val_loss = tot_val_loss_sum_rank / val_sample_count_rank if val_sample_count_rank > 0 else 0
epoch_time = time.time() - epoch_start_time
epoch_summary = (f"\n--- Epoch {epoch+1}/{config.tokenizer_epochs} Summary ---\n"
f"Validation Loss: {avg_val_loss:.4f}\n"
f"Epoch Time: {format_time(epoch_time)}\n"
f"Total Training Time: {format_time(time.time() - epoch_start_time)}\n")
logger.info(epoch_summary)
if rank == 0:
print(epoch_summary)
if avg_val_loss < best_val_loss:
best_val_loss = avg_val_loss
if rank == 0:
model_save_path = os.path.join(save_dir, "best_model")
os.makedirs(model_save_path, exist_ok=True)
(model.module if use_ddp else model).save_pretrained(model_save_path)
save_msg = f"Best model saved to: {model_save_path} (validation loss: {best_val_loss:.4f})"
logger.info(save_msg)
print(save_msg)
return best_val_loss
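# Note: the micro-batch loop in train_tokenizer splits each batch into
# `accumulation_steps` equal slices and scales each slice's loss by
# 1/accumulation_steps before backward(), so the accumulated gradient matches
# one full-batch step. A standalone sketch of the index arithmetic
# (hypothetical helper, not called by the training code; like the loop above,
# integer division silently drops the last batch_size % accumulation_steps
# samples when the split is uneven):
def _accumulation_slices(batch_size: int, accumulation_steps: int):
    micro = batch_size // accumulation_steps
    return [(j * micro, (j + 1) * micro) for j in range(accumulation_steps)]
# e.g. _accumulation_slices(32, 4) -> [(0, 8), (8, 16), (16, 24), (24, 32)]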
def main():
import argparse
parser = argparse.ArgumentParser(description='Kronos Tokenizer Fine-tuning Training')
parser.add_argument('--config', type=str, default='config.yaml',
help='Configuration file path (default: config.yaml)')
args = parser.parse_args()
config = CustomFinetuneConfig(args.config)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
os.makedirs(config.tokenizer_save_path, exist_ok=True)
log_dir = os.path.join(config.base_save_path, "logs")
logger = setup_logging(config.exp_name, log_dir, 0)
set_seed(config.seed)
    # Load the pretrained tokenizer
if getattr(config, 'pre_trained_tokenizer', True):
logger.info("Loading pretrained tokenizer...")
print("Loading pretrained tokenizer...")
tokenizer = KronosTokenizer.from_pretrained(config.pretrained_tokenizer_path)
else:
print("pre_trained_tokenizer=False, randomly initializing Tokenizer architecture")
cfg_path = os.path.join(config.pretrained_tokenizer_path, 'config.json')
with open(cfg_path, 'r') as f:
arch = json.load(f)
tokenizer = KronosTokenizer(
d_in=arch.get('d_in', 6),
d_model=arch.get('d_model', 256),
n_heads=arch.get('n_heads', 4),
ff_dim=arch.get('ff_dim', 512),
n_enc_layers=arch.get('n_enc_layers', 4),
n_dec_layers=arch.get('n_dec_layers', 4),
ffn_dropout_p=arch.get('ffn_dropout_p', 0.0),
attn_dropout_p=arch.get('attn_dropout_p', 0.0),
resid_dropout_p=arch.get('resid_dropout_p', 0.0),
s1_bits=arch.get('s1_bits', 10),
s2_bits=arch.get('s2_bits', 10),
beta=arch.get('beta', 0.05),
gamma0=arch.get('gamma0', 1.0),
gamma=arch.get('gamma', 1.1),
zeta=arch.get('zeta', 0.05),
group_size=arch.get('group_size', 4)
)
tokenizer = tokenizer.to(device)
model_size = get_model_size(tokenizer)
logger.info(f"Tokenizer parameters: {model_size}")
print(f"Tokenizer parameters: {model_size}")
logger.info("=== Training Configuration ===")
logger.info(f"Data path: {config.data_path}")
logger.info(f"Lookback window: {config.lookback_window}")
logger.info(f"Predict window: {config.predict_window}")
logger.info(f"Batch size: {config.batch_size}")
logger.info(f"Learning rate: {config.tokenizer_learning_rate}")
logger.info(f"Training epochs: {config.tokenizer_epochs}")
logger.info(f"Device: {device}")
logger.info(f"Distributed training: False")
logger.info("Starting tokenizer fine-tuning training...")
print("Starting tokenizer fine-tuning training...")
best_val_loss = train_tokenizer(tokenizer, device, config, config.tokenizer_save_path, logger)
final_msg = f"Tokenizer training completed! Best validation loss: {best_val_loss:.4f}\nModel saved to: {config.tokenizer_save_path}"
logger.info(final_msg)
print(final_msg)
if __name__ == "__main__":
main()
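Both this script and the sequential trainer rebuild a randomly initialized tokenizer from the checkpoint's `config.json` when `pre_trained_tokenizer` is false. A sketch of the tokenizer fields read above, using the script's fallback defaults as illustrative values (a real checkpoint's config may differ):

```json
{
  "d_in": 6,
  "d_model": 256,
  "n_heads": 4,
  "ff_dim": 512,
  "n_enc_layers": 4,
  "n_dec_layers": 4,
  "ffn_dropout_p": 0.0,
  "attn_dropout_p": 0.0,
  "resid_dropout_p": 0.0,
  "s1_bits": 10,
  "s2_bits": 10,
  "beta": 0.05,
  "gamma0": 1.0,
  "gamma": 1.1,
  "zeta": 0.05,
  "group_size": 4
}
```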

train_sequential.py (new file, 361 lines)
import os
import sys
import time
import argparse
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.distributed as dist
sys.path.append('../')
from model import Kronos, KronosTokenizer, KronosPredictor
from config_loader import CustomFinetuneConfig
from finetune_tokenizer import train_tokenizer, set_seed, setup_logging as setup_tokenizer_logging
from finetune_base_model import train_model, create_dataloaders, setup_logging as setup_basemodel_logging
class SequentialTrainer:
def __init__(self, config_path: str = None):
self.config = CustomFinetuneConfig(config_path)
self.rank = int(os.environ.get("RANK", "0"))
self.world_size = int(os.environ.get("WORLD_SIZE", "1"))
self.local_rank = int(os.environ.get("LOCAL_RANK", str(self.config.device_id if hasattr(self.config, 'device_id') else 0)))
self.device = self._setup_device()
self.config.print_config_summary()
def _setup_device(self):
if self.config.use_cuda and torch.cuda.is_available():
torch.cuda.set_device(self.local_rank)
device = torch.device(f"cuda:{self.local_rank}")
else:
device = torch.device("cpu")
if self.rank == 0:
print(f"Using device: {device} (rank={self.rank}, world_size={self.world_size}, local_rank={self.local_rank})")
return device
def _setup_distributed(self):
if self.world_size > 1 and torch.cuda.is_available():
backend = os.environ.get("DIST_BACKEND", "nccl").lower()
if not dist.is_initialized():
dist.init_process_group(backend=backend)
if self.rank == 0:
print(f"Distributed training initialized: backend={backend}, world_size={self.world_size}")
else:
if self.rank == 0:
print("Distributed training not enabled, using single GPU/CPU training")
def _check_existing_models(self):
tokenizer_exists = os.path.exists(self.config.tokenizer_best_model_path)
basemodel_exists = os.path.exists(self.config.basemodel_best_model_path)
print(f"Tokenizer model exists: {tokenizer_exists}")
print(f"Basemodel model exists: {basemodel_exists}")
return tokenizer_exists, basemodel_exists
def _create_directories(self):
os.makedirs(self.config.tokenizer_save_path, exist_ok=True)
os.makedirs(self.config.basemodel_save_path, exist_ok=True)
print(f"Created directory: {self.config.tokenizer_save_path}")
print(f"Created directory: {self.config.basemodel_save_path}")
def train_tokenizer_phase(self):
print("\n" + "="*60)
print("Starting Tokenizer Fine-tuning Phase")
print("="*60)
tokenizer_exists, _ = self._check_existing_models()
if tokenizer_exists and self.config.skip_existing:
print("Tokenizer model already exists, skipping training")
return True
log_dir = os.path.join(self.config.base_save_path, "logs")
logger = setup_tokenizer_logging(self.config.exp_name, log_dir, self.rank)
set_seed(self.config.seed)
if getattr(self.config, 'pre_trained_tokenizer', True):
logger.info("Loading pretrained tokenizer...")
if self.rank == 0:
print("Loading pretrained tokenizer...")
tokenizer = KronosTokenizer.from_pretrained(self.config.pretrained_tokenizer_path)
else:
if self.rank == 0:
print("pre_trained_tokenizer=False, randomly initializing Tokenizer architecture")
import json
cfg_path = os.path.join(self.config.pretrained_tokenizer_path, 'config.json')
with open(cfg_path, 'r') as f:
arch = json.load(f)
tokenizer = KronosTokenizer(
d_in=arch.get('d_in', 6),
d_model=arch.get('d_model', 256),
n_heads=arch.get('n_heads', 4),
ff_dim=arch.get('ff_dim', 512),
n_enc_layers=arch.get('n_enc_layers', 4),
n_dec_layers=arch.get('n_dec_layers', 4),
ffn_dropout_p=arch.get('ffn_dropout_p', 0.0),
attn_dropout_p=arch.get('attn_dropout_p', 0.0),
resid_dropout_p=arch.get('resid_dropout_p', 0.0),
s1_bits=arch.get('s1_bits', 10),
s2_bits=arch.get('s2_bits', 10),
beta=arch.get('beta', 0.05),
gamma0=arch.get('gamma0', 1.0),
gamma=arch.get('gamma', 1.1),
zeta=arch.get('zeta', 0.05),
group_size=arch.get('group_size', 4)
)
tokenizer = tokenizer.to(self.device)
model_size = sum(p.numel() for p in tokenizer.parameters())
logger.info(f"Tokenizer parameters: {model_size:,}")
if self.rank == 0:
print(f"Tokenizer parameters: {model_size:,}")
logger.info("=== Training Configuration ===")
logger.info(f"Data path: {self.config.data_path}")
logger.info(f"Lookback window: {self.config.lookback_window}")
logger.info(f"Predict window: {self.config.predict_window}")
logger.info(f"Batch size: {self.config.batch_size}")
logger.info(f"Learning rate: {self.config.tokenizer_learning_rate}")
logger.info(f"Training epochs: {self.config.tokenizer_epochs}")
logger.info(f"Device: {self.device}")
logger.info(f"Distributed training: False")
logger.info("Starting tokenizer fine-tuning training...")
if self.rank == 0:
print("Starting tokenizer fine-tuning training...")
start_time = time.time()
best_val_loss = train_tokenizer(
tokenizer,
self.device,
self.config,
self.config.tokenizer_save_path,
logger,
)
training_time = time.time() - start_time
final_msg = f"Tokenizer training completed! Best validation loss: {best_val_loss:.4f}\nTraining time: {training_time/60:.2f} minutes\nModel saved to: {self.config.tokenizer_save_path}"
logger.info(final_msg)
if self.rank == 0:
print(f"\n{final_msg}")
return True
def train_basemodel_phase(self):
print("\n" + "="*60)
print("Starting Basemodel Fine-tuning Phase")
print("="*60)
if getattr(self.config, 'pre_trained_tokenizer', True):
if not os.path.exists(self.config.finetuned_tokenizer_path):
raise FileNotFoundError(f"Fine-tuned tokenizer does not exist: {self.config.finetuned_tokenizer_path}")
_, basemodel_exists = self._check_existing_models()
if basemodel_exists and self.config.skip_existing:
print("Basemodel model already exists, skipping training")
return True
log_dir = os.path.join(self.config.base_save_path, "logs")
logger = setup_basemodel_logging(self.config.exp_name, log_dir, self.rank)
set_seed(self.config.seed)
if getattr(self.config, 'pre_trained_tokenizer', True):
logger.info("Loading fine-tuned tokenizer...")
if self.rank == 0:
print("Loading fine-tuned tokenizer...")
tokenizer = KronosTokenizer.from_pretrained(self.config.finetuned_tokenizer_path)
else:
if self.rank == 0:
print("pre_trained_tokenizer=False, randomly initializing Tokenizer architecture for Predictor training")
import json
cfg_path = os.path.join(self.config.pretrained_tokenizer_path, 'config.json')
with open(cfg_path, 'r') as f:
arch = json.load(f)
tokenizer = KronosTokenizer(
d_in=arch.get('d_in', 6),
d_model=arch.get('d_model', 256),
n_heads=arch.get('n_heads', 4),
ff_dim=arch.get('ff_dim', 512),
n_enc_layers=arch.get('n_enc_layers', 4),
n_dec_layers=arch.get('n_dec_layers', 4),
ffn_dropout_p=arch.get('ffn_dropout_p', 0.0),
attn_dropout_p=arch.get('attn_dropout_p', 0.0),
resid_dropout_p=arch.get('resid_dropout_p', 0.0),
s1_bits=arch.get('s1_bits', 10),
s2_bits=arch.get('s2_bits', 10),
beta=arch.get('beta', 0.05),
gamma0=arch.get('gamma0', 1.0),
gamma=arch.get('gamma', 1.1),
zeta=arch.get('zeta', 0.05),
group_size=arch.get('group_size', 4)
)
tokenizer = tokenizer.to(self.device)
if getattr(self.config, 'pre_trained_predictor', True):
logger.info("Loading pretrained predictor...")
if self.rank == 0:
print("Loading pretrained predictor...")
model = Kronos.from_pretrained(self.config.pretrained_predictor_path)
else:
if self.rank == 0:
print("pre_trained_predictor=False, randomly initializing Predictor architecture")
import json
cfg_path = os.path.join(self.config.pretrained_predictor_path, 'config.json')
with open(cfg_path, 'r') as f:
arch = json.load(f)
print("model_config: ", arch)
model = Kronos(
s1_bits=arch.get('s1_bits', 10),
s2_bits=arch.get('s2_bits', 10),
n_layers=arch.get('n_layers', 12),
d_model=arch.get('d_model', 832),
n_heads=arch.get('n_heads', 16),
ff_dim=arch.get('ff_dim', 2048),
ffn_dropout_p=arch.get('ffn_dropout_p', 0.2),
attn_dropout_p=arch.get('attn_dropout_p', 0.0),
resid_dropout_p=arch.get('resid_dropout_p', 0.2),
token_dropout_p=arch.get('token_dropout_p', 0.0),
learn_te=arch.get('learn_te', True)
)
model = model.to(self.device)
model_size = sum(p.numel() for p in model.parameters())
logger.info(f"Model parameters: {model_size:,}")
if self.rank == 0:
print(f"Model parameters: {model_size:,}")
logger.info("=== Training Configuration ===")
logger.info(f"Data path: {self.config.data_path}")
logger.info(f"Lookback window: {self.config.lookback_window}")
logger.info(f"Predict window: {self.config.predict_window}")
logger.info(f"Batch size: {self.config.batch_size}")
logger.info(f"Learning rate: {self.config.predictor_learning_rate}")
logger.info(f"Training epochs: {self.config.basemodel_epochs}")
logger.info(f"Device: {self.device}")
logger.info(f"Tokenizer path: {self.config.finetuned_tokenizer_path}")
logger.info(f"Pretrained model path: {self.config.pretrained_predictor_path}")
logger.info("Starting fine-tuning training...")
if self.rank == 0:
print("Starting fine-tuning training...")
start_time = time.time()
best_val_loss = train_model(
model,
tokenizer,
self.device,
self.config,
self.config.basemodel_save_path,
logger,
)
training_time = time.time() - start_time
final_msg = f"Basemodel training completed! Best validation loss: {best_val_loss:.4f}\nTraining time: {training_time/60:.2f} minutes\nModel saved to: {self.config.basemodel_save_path}"
logger.info(final_msg)
if self.rank == 0:
print(f"\n{final_msg}")
return True
def run_training(self):
if self.rank == 0:
print("Starting Kronos model sequential fine-tuning training")
print(f"Experiment name: {self.config.experiment_name}")
print(f"Experiment description: {self.config.experiment_description}")
self._setup_distributed()
self._create_directories()
tokenizer_exists, basemodel_exists = self._check_existing_models()
total_start_time = time.time()
try:
if self.config.train_tokenizer:
success = self.train_tokenizer_phase()
if not success:
print("Tokenizer training failed, terminating training")
return False
else:
print("Skipping Tokenizer training phase")
if self.config.train_basemodel:
success = self.train_basemodel_phase()
if not success:
print("Basemodel training failed, terminating training")
return False
else:
print("Skipping Basemodel training phase")
total_time = time.time() - total_start_time
if self.rank == 0:
print("\n" + "="*60)
print("Training completed!")
print("="*60)
print(f"Total training time: {total_time/60:.2f} minutes")
print(f"Tokenizer model: {self.config.tokenizer_best_model_path}")
print(f"Basemodel model: {self.config.basemodel_best_model_path}")
print("="*60)
return True
except Exception as e:
if self.rank == 0:
print(f"Error occurred during training: {str(e)}")
import traceback
traceback.print_exc()
return False
finally:
pass
def main():
parser = argparse.ArgumentParser(description='Kronos Model Sequential Fine-tuning Training')
parser.add_argument('--config', type=str, default='config.yaml',
help='Configuration file path (default: config.yaml)')
parser.add_argument('--skip-tokenizer', action='store_true',
help='Skip tokenizer training phase')
parser.add_argument('--skip-basemodel', action='store_true',
help='Skip basemodel training phase')
parser.add_argument('--skip-existing', action='store_true',
help='Skip training for existing models')
args = parser.parse_args()
trainer = SequentialTrainer(args.config)
if args.skip_tokenizer:
trainer.config.train_tokenizer = False
if args.skip_basemodel:
trainer.config.train_basemodel = False
if args.skip_existing:
trainer.config.skip_existing = True
success = trainer.run_training()
if success:
print("Training completed successfully!")
if dist.is_available() and dist.is_initialized():
dist.barrier()
dist.destroy_process_group()
sys.exit(0)
else:
print("Training failed!")
if dist.is_available() and dist.is_initialized():
try:
dist.barrier()
dist.destroy_process_group()
except Exception:
pass
sys.exit(1)
if __name__ == "__main__":
main()