diff --git a/finetune_csv/README.md b/finetune_csv/README.md new file mode 100644 index 0000000..759096a --- /dev/null +++ b/finetune_csv/README.md @@ -0,0 +1,101 @@ +# Kronos Fine-tuning Training with Custom Dataset + +Supports fine-tuning training with custom CSV data using configuration files + +## Quick Start + +### 1. Configuration Setup + +First edit the `config.yaml` file to set the correct paths and parameters: + +```yaml +# Data configuration +data: + data_path: "/path/to/your/data.csv" + lookback_window: 512 + predict_window: 48 + # ... other parameters + +# Model path configuration +model_paths: + pretrained_tokenizer: "/path/to/pretrained/tokenizer" + pretrained_predictor: "/path/to/pretrained/predictor" + base_save_path: "/path/to/save/models" + # ... other paths +``` + +### 2. Run Training + +Using train_sequential + +```bash +# Complete training +python train_sequential.py --config configs/config_ali09988_candle-5min.yaml + +# Skip existing models +python train_sequential.py --config configs/config_ali09988_candle-5min.yaml --skip-existing + +# Only train tokenizer +python train_sequential.py --config configs/config_ali09988_candle-5min.yaml --skip-basemodel + +# Only train basemodel +python train_sequential.py --config configs/config_ali09988_candle-5min.yaml --skip-tokenizer +``` + +Run each stage separately + +```bash +# Only train tokenizer +python finetune_tokenizer.py --config configs/config_ali09988_candle-5min.yaml + +# Only train basemodel (requires fine-tuned tokenizer first) +python finetune_base_model.py --config configs/config_ali09988_candle-5min.yaml +``` + +DDP Training +```bash +# Choose communication protocol yourself, nccl can be replaced with gloo +DIST_BACKEND=nccl \ +torchrun --standalone --nproc_per_node=8 train_sequential.py --config configs/config_ali09988_candle-5min.yaml +``` + +## Configuration Description + +### Main Configuration Items + +- **data**: Data-related configuration + - `data_path`: CSV data file path + - `lookback_window`: Lookback window size + - `predict_window`: Prediction window size + - `train_ratio/val_ratio/test_ratio`: Dataset split ratios + +- **training**: Training-related configuration + - `epochs`: Number of training epochs + - `batch_size`: Batch size + - `tokenizer_learning_rate`: Tokenizer learning rate + - `predictor_learning_rate`: Predictor learning rate + +- **model_paths**: Model path configuration + - `pretrained_tokenizer`: Pre-trained tokenizer path + - `pretrained_predictor`: Pre-trained predictor path + - `base_save_path`: Model save root directory + - `finetuned_tokenizer`: Fine-tuned tokenizer path (for basemodel training) + +- **experiment**: Experiment control + - `train_tokenizer`: Whether to train tokenizer + - `train_basemodel`: Whether to train basemodel + - `skip_existing`: Whether to skip existing models + +## Training Process + +1. **Tokenizer Fine-tuning Stage** + - Load pre-trained tokenizer + - Fine-tune on custom data + - Save fine-tuned tokenizer to `{base_save_path}/tokenizer/best_model/` + +2. **Basemodel Fine-tuning Stage** + - Load fine-tuned tokenizer and pre-trained predictor + - Fine-tune on custom data + - Save fine-tuned basemodel to `{base_save_path}/basemodel/best_model/` + +**Data Format**: Ensure CSV file contains the following columns: `timestamps`, `open`, `high`, `low`, `close`, `volume`, `amount` diff --git a/finetune_csv/README_CN.md b/finetune_csv/README_CN.md new file mode 100644 index 0000000..05269ee --- /dev/null +++ b/finetune_csv/README_CN.md @@ -0,0 +1,105 @@ +# 自定义数据集的Kronos微调训练 + +支持使用配置文件进行自定义csv数据的微调训练 + +## 快速开始 + +### 1. 配置设置 + +首先编辑 `config.yaml` 文件,设置正确的路径和参数: + +```yaml +# 数据配置 +data: + data_path: "/path/to/your/data.csv" + lookback_window: 512 + predict_window: 48 + # ... 其他参数 + +# 模型路径配置 +model_paths: + pretrained_tokenizer: "/path/to/pretrained/tokenizer" + pretrained_predictor: "/path/to/pretrained/predictor" + base_save_path: "/path/to/save/models" + # ... 其他路径 +``` + +### 2. 运行训练 + + +使用train_sequential + +```bash +# 完整训练 +python train_sequential.py --config configs/config_ali09988_candle-5min.yaml + +# 跳过已存在的模型 +python train_sequential.py --config configs/config_ali09988_candle-5min.yaml --skip-existing + +# 只训练tokenizer +python train_sequential.py --config configs/config_ali09988_candle-5min.yaml --skip-basemodel + +# 只训练basemodel +python train_sequential.py --config configs/config_ali09988_candle-5min.yaml --skip-tokenizer +``` + +单独运行各个阶段 + +```bash +# 只训练tokenizer +python finetune_tokenizer.py --config configs/config_ali09988_candle-5min.yaml + +# 只训练basemodel(需要先有微调后的tokenizer) +python finetune_base_model.py --config configs/config_ali09988_candle-5min.yaml +``` + +DDP训练 +```bash +# 通信协议自行选择,nccl可替换gloo +DIST_BACKEND=nccl \ +torchrun --standalone --nproc_per_node=8 train_sequential.py --config configs/config_ali09988_candle-5min.yaml +``` + +## 配置说明 + +### 主要配置项 + +- **data**: 数据相关配置 + - `data_path`: CSV数据文件路径 + - `lookback_window`: 回望窗口大小 + - `predict_window`: 预测窗口大小 + - `train_ratio/val_ratio/test_ratio`: 数据集分割比例 + +- **training**: 训练相关配置 + - `epochs`: 训练轮数 + - `batch_size`: 批次大小 + - `tokenizer_learning_rate`: Tokenizer学习率 + - `predictor_learning_rate`: Predictor学习率 + +- **model_paths**: 模型路径配置 + - `pretrained_tokenizer`: 预训练tokenizer路径 + - `pretrained_predictor`: 预训练predictor路径 + - `base_save_path`: 模型保存根目录 + - `finetuned_tokenizer`: 微调后tokenizer路径(用于basemodel训练) + +- **experiment**: 实验控制 + - `train_tokenizer`: 是否训练tokenizer + - `train_basemodel`: 是否训练basemodel + - `skip_existing`: 是否跳过已存在的模型 + +## 训练流程 + +1. **Tokenizer微调阶段** + - 加载预训练tokenizer + - 在自定义数据上微调 + - 保存微调后的tokenizer到 `{base_save_path}/tokenizer/best_model/` + +2. **Basemodel微调阶段** + - 加载微调后的tokenizer和预训练predictor + - 在自定义数据上微调 + - 保存微调后的basemodel到 `{base_save_path}/basemodel/best_model/` + + + **数据格式**: 确保CSV文件包含以下列:`timestamps`, `open`, `high`, `low`, `close`, `volume`, `amount` + + diff --git a/finetune_csv/config_loader.py b/finetune_csv/config_loader.py new file mode 100644 index 0000000..6bddcae --- /dev/null +++ b/finetune_csv/config_loader.py @@ -0,0 +1,267 @@ +import os +import yaml +from typing import Dict, Any + + +class ConfigLoader: + + def __init__(self, config_path: str): + + self.config_path = config_path + self.config = self._load_config() + + def _load_config(self) -> Dict[str, Any]: + + if not os.path.exists(self.config_path): + raise FileNotFoundError(f"config file not found: {self.config_path}") + + with open(self.config_path, 'r', encoding='utf-8') as f: + config = yaml.safe_load(f) + + config = self._resolve_dynamic_paths(config) + + return config + + def _resolve_dynamic_paths(self, config: Dict[str, Any]) -> Dict[str, Any]: + + exp_name = config.get('model_paths', {}).get('exp_name', '') + if not exp_name: + return config + + base_path = config.get('model_paths', {}).get('base_path', '') + path_templates = { + 'base_save_path': f"{base_path}/{exp_name}", + 'finetuned_tokenizer': f"{base_path}/{exp_name}/tokenizer/best_model" + } + + if 'model_paths' in config: + for key, template in path_templates.items(): + if key in config['model_paths']: + # only use template when the original value is empty string + current_value = config['model_paths'][key] + if current_value == "" or current_value is None: + config['model_paths'][key] = template + else: + # if the original value is not empty, use template to replace the {exp_name} placeholder + if isinstance(current_value, str) and '{exp_name}' in current_value: + config['model_paths'][key] = current_value.format(exp_name=exp_name) + + return config + + def get(self, key: str, default=None): + + keys = key.split('.') + value = self.config + + try: + for k in keys: + value = value[k] + return value + except (KeyError, TypeError): + return default + + def get_data_config(self) -> Dict[str, Any]: + return self.config.get('data', {}) + + def get_training_config(self) -> Dict[str, Any]: + return self.config.get('training', {}) + + def get_model_paths(self) -> Dict[str, str]: + return self.config.get('model_paths', {}) + + def get_experiment_config(self) -> Dict[str, Any]: + return self.config.get('experiment', {}) + + def get_device_config(self) -> Dict[str, Any]: + return self.config.get('device', {}) + + def get_distributed_config(self) -> Dict[str, Any]: + return self.config.get('distributed', {}) + + def update_config(self, updates: Dict[str, Any]): + + def update_nested_dict(d, u): + for k, v in u.items(): + if isinstance(v, dict): + d[k] = update_nested_dict(d.get(k, {}), v) + else: + d[k] = v + return d + + self.config = update_nested_dict(self.config, updates) + + def save_config(self, save_path: str = None): + + if save_path is None: + save_path = self.config_path + + with open(save_path, 'w', encoding='utf-8') as f: + yaml.dump(self.config, f, default_flow_style=False, allow_unicode=True, indent=2) + + def print_config(self): + print("=" * 50) + print("Current configuration:") + print("=" * 50) + yaml.dump(self.config, default_flow_style=False, allow_unicode=True, indent=2) + print("=" * 50) + + +class CustomFinetuneConfig: + + def __init__(self, config_path: str = None): + + if config_path is None: + config_path = os.path.join(os.path.dirname(__file__), 'config.yaml') + + self.loader = ConfigLoader(config_path) + self._load_all_configs() + + def _load_all_configs(self): + + data_config = self.loader.get_data_config() + self.data_path = data_config.get('data_path') + self.lookback_window = data_config.get('lookback_window', 512) + self.predict_window = data_config.get('predict_window', 48) + self.max_context = data_config.get('max_context', 512) + self.clip = data_config.get('clip', 5.0) + self.train_ratio = data_config.get('train_ratio', 0.9) + self.val_ratio = data_config.get('val_ratio', 0.1) + self.test_ratio = data_config.get('test_ratio', 0.0) + + # training configuration + training_config = self.loader.get_training_config() + # support training epochs of tokenizer and basemodel separately + self.tokenizer_epochs = training_config.get('tokenizer_epochs', 30) + self.basemodel_epochs = training_config.get('basemodel_epochs', 30) + + if 'epochs' in training_config and 'tokenizer_epochs' not in training_config: + self.tokenizer_epochs = training_config.get('epochs', 30) + if 'epochs' in training_config and 'basemodel_epochs' not in training_config: + self.basemodel_epochs = training_config.get('epochs', 30) + + self.batch_size = training_config.get('batch_size', 160) + self.log_interval = training_config.get('log_interval', 50) + self.num_workers = training_config.get('num_workers', 6) + self.seed = training_config.get('seed', 100) + self.tokenizer_learning_rate = training_config.get('tokenizer_learning_rate', 2e-4) + self.predictor_learning_rate = training_config.get('predictor_learning_rate', 4e-5) + self.adam_beta1 = training_config.get('adam_beta1', 0.9) + self.adam_beta2 = training_config.get('adam_beta2', 0.95) + self.adam_weight_decay = training_config.get('adam_weight_decay', 0.1) + self.accumulation_steps = training_config.get('accumulation_steps', 1) + + model_paths = self.loader.get_model_paths() + self.exp_name = model_paths.get('exp_name', 'default_experiment') + self.pretrained_tokenizer_path = model_paths.get('pretrained_tokenizer') + self.pretrained_predictor_path = model_paths.get('pretrained_predictor') + self.base_save_path = model_paths.get('base_save_path') + self.tokenizer_save_name = model_paths.get('tokenizer_save_name', 'tokenizer') + self.basemodel_save_name = model_paths.get('basemodel_save_name', 'basemodel') + self.finetuned_tokenizer_path = model_paths.get('finetuned_tokenizer') + + experiment_config = self.loader.get_experiment_config() + self.experiment_name = experiment_config.get('name', 'kronos_custom_finetune') + self.experiment_description = experiment_config.get('description', '') + self.use_comet = experiment_config.get('use_comet', False) + self.train_tokenizer = experiment_config.get('train_tokenizer', True) + self.train_basemodel = experiment_config.get('train_basemodel', True) + self.skip_existing = experiment_config.get('skip_existing', False) + + unified_pretrained = experiment_config.get('pre_trained', None) + self.pre_trained_tokenizer = experiment_config.get('pre_trained_tokenizer', unified_pretrained if unified_pretrained is not None else True) + self.pre_trained_predictor = experiment_config.get('pre_trained_predictor', unified_pretrained if unified_pretrained is not None else True) + + device_config = self.loader.get_device_config() + self.use_cuda = device_config.get('use_cuda', True) + self.device_id = device_config.get('device_id', 0) + + distributed_config = self.loader.get_distributed_config() + self.use_ddp = distributed_config.get('use_ddp', False) + self.ddp_backend = distributed_config.get('backend', 'nccl') + + self._compute_full_paths() + + def _compute_full_paths(self): + + self.tokenizer_save_path = os.path.join(self.base_save_path, self.tokenizer_save_name) + self.tokenizer_best_model_path = os.path.join(self.tokenizer_save_path, 'best_model') + + self.basemodel_save_path = os.path.join(self.base_save_path, self.basemodel_save_name) + self.basemodel_best_model_path = os.path.join(self.basemodel_save_path, 'best_model') + + def get_tokenizer_config(self): + + return { + 'data_path': self.data_path, + 'lookback_window': self.lookback_window, + 'predict_window': self.predict_window, + 'max_context': self.max_context, + 'clip': self.clip, + 'train_ratio': self.train_ratio, + 'val_ratio': self.val_ratio, + 'test_ratio': self.test_ratio, + 'epochs': self.tokenizer_epochs, + 'batch_size': self.batch_size, + 'log_interval': self.log_interval, + 'num_workers': self.num_workers, + 'seed': self.seed, + 'learning_rate': self.tokenizer_learning_rate, + 'adam_beta1': self.adam_beta1, + 'adam_beta2': self.adam_beta2, + 'adam_weight_decay': self.adam_weight_decay, + 'accumulation_steps': self.accumulation_steps, + 'pretrained_model_path': self.pretrained_tokenizer_path, + 'save_path': self.tokenizer_save_path, + 'use_comet': self.use_comet + } + + def get_basemodel_config(self): + + return { + 'data_path': self.data_path, + 'lookback_window': self.lookback_window, + 'predict_window': self.predict_window, + 'max_context': self.max_context, + 'clip': self.clip, + 'train_ratio': self.train_ratio, + 'val_ratio': self.val_ratio, + 'test_ratio': self.test_ratio, + 'epochs': self.basemodel_epochs, + 'batch_size': self.batch_size, + 'log_interval': self.log_interval, + 'num_workers': self.num_workers, + 'seed': self.seed, + 'predictor_learning_rate': self.predictor_learning_rate, + 'tokenizer_learning_rate': self.tokenizer_learning_rate, + 'adam_beta1': self.adam_beta1, + 'adam_beta2': self.adam_beta2, + 'adam_weight_decay': self.adam_weight_decay, + 'pretrained_tokenizer_path': self.finetuned_tokenizer_path, + 'pretrained_predictor_path': self.pretrained_predictor_path, + 'save_path': self.basemodel_save_path, + 'use_comet': self.use_comet + } + + def print_config_summary(self): + + print("=" * 60) + print("Kronos finetuning configuration summary") + print("=" * 60) + print(f"Experiment name: {self.exp_name}") + print(f"Data path: {self.data_path}") + print(f"Lookback window: {self.lookback_window}") + print(f"Predict window: {self.predict_window}") + print(f"Tokenizer training epochs: {self.tokenizer_epochs}") + print(f"Basemodel training epochs: {self.basemodel_epochs}") + print(f"Batch size: {self.batch_size}") + print(f"Tokenizer learning rate: {self.tokenizer_learning_rate}") + print(f"Predictor learning rate: {self.predictor_learning_rate}") + print(f"Train tokenizer: {self.train_tokenizer}") + print(f"Train basemodel: {self.train_basemodel}") + print(f"Skip existing: {self.skip_existing}") + print(f"Use pre-trained tokenizer: {self.pre_trained_tokenizer}") + print(f"Use pre-trained predictor: {self.pre_trained_predictor}") + print(f"Base save path: {self.base_save_path}") + print(f"Tokenizer save path: {self.tokenizer_save_path}") + print(f"Basemodel save path: {self.basemodel_save_path}") + print("=" * 60) diff --git a/finetune_csv/configs/config_ali09988_candle-5min.yaml b/finetune_csv/configs/config_ali09988_candle-5min.yaml new file mode 100644 index 0000000..94f4487 --- /dev/null +++ b/finetune_csv/configs/config_ali09988_candle-5min.yaml @@ -0,0 +1,72 @@ +#This is a template config for custom finetuning kronos on csv data +#这是一份模板config,用于kronos的csv自定义数据微调 + +data: + data_path: "/xxxx/Kronos/finetune_csv2/data/HK_ali_09988_kline_5min_all.csv" + lookback_window: 512 + predict_window: 48 + max_context: 512 + clip: 5.0 + # dataset split ratio + train_ratio: 0.9 + val_ratio: 0.1 + test_ratio: 0.0 + +training: + # control the training epochs of tokenizer and basemodel + tokenizer_epochs: 30 + basemodel_epochs: 20 + batch_size: 32 + log_interval: 50 + num_workers: 6 + seed: 42 + + tokenizer_learning_rate: 0.0002 + predictor_learning_rate: 0.000001 + + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_weight_decay: 0.1 + + # gradient accumulation steps for tokenizer training + accumulation_steps: 1 + +# model path configuration +model_paths: + # pretrained model path + pretrained_tokenizer: "/mnt/DigitalHuman2D/boyuzhang/quant/Kronos/pretrained/Kronos-Tokenizer-base" + pretrained_predictor: "/mnt/DigitalHuman2D/boyuzhang/quant/Kronos/pretrained/Kronos-base" + + # experiment name - other paths will be generated based on this + exp_name: "HK_ali_09988_kline_5min_all" + base_path: "/mnt/DigitalHuman2D/boyuzhang/quant/Kronos/finetune_csv/finetuned/" + + # the following paths will be generated based on exp_name, no need to modify manually + # way 1: leave empty string, the system will generate the full path + base_save_path: "" # /xxxx/Kronos/finetune_csv/finetuned/{exp_name} + finetuned_tokenizer: "" # /xxxx/quant/Kronos/finetune_csv/finetuned/{exp_name}/tokenizer/best_model + + # way 2: use template string, {exp_name} will be replaced with the actual experiment name + # base_save_path: "/xxxx/Kronos/finetune_csv/finetuned/{exp_name}" + # finetuned_tokenizer: "/xxxx/quant/Kronos/finetune_csv/finetuned/{exp_name}/tokenizer/best_model" + + tokenizer_save_name: "tokenizer" + basemodel_save_name: "basemodel" + +experiment: + name: "kronos_custom_finetune" + description: "Custom finetune for HK stock data" + use_comet: false + + # control the training phase + train_tokenizer: true + train_basemodel: true + + # if true, skip the existing model training + skip_existing: false + +# device configuration +device: + use_cuda: true + device_id: 0 + diff --git a/finetune_csv/finetune_base_model.py b/finetune_csv/finetune_base_model.py new file mode 100644 index 0000000..d21c22d --- /dev/null +++ b/finetune_csv/finetune_base_model.py @@ -0,0 +1,468 @@ +import os +import sys +import json +import time +import pickle +import random +import pandas as pd +import numpy as np +import torch +import torch.nn as nn +from torch.utils.data import Dataset, DataLoader +from torch.utils.data.distributed import DistributedSampler +from time import gmtime, strftime +import logging +from logging.handlers import RotatingFileHandler +import datetime +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP + +sys.path.append('../') +from model import Kronos, KronosTokenizer, KronosPredictor +from config_loader import CustomFinetuneConfig + + +class CustomKlineDataset(Dataset): + + def __init__(self, data_path, data_type='train', lookback_window=90, predict_window=10, + clip=5.0, seed=100, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15): + self.data_path = data_path + self.data_type = data_type + self.lookback_window = lookback_window + self.predict_window = predict_window + self.window = lookback_window + predict_window + 1 + self.clip = clip + self.seed = seed + self.train_ratio = train_ratio + self.val_ratio = val_ratio + self.test_ratio = test_ratio + + self.feature_list = ['open', 'high', 'low', 'close', 'volume', 'amount'] + self.time_feature_list = ['minute', 'hour', 'weekday', 'day', 'month'] + + self.py_rng = random.Random(seed) + + self._load_and_preprocess_data() + self._split_data_by_time() + + self.n_samples = len(self.data) - self.window + 1 + + print(f"[{data_type.upper()}] Data length: {len(self.data)}, Available samples: {self.n_samples}") + + def _load_and_preprocess_data(self): + df = pd.read_csv(self.data_path) + + df['timestamps'] = pd.to_datetime(df['timestamps']) + df = df.sort_values('timestamps').reset_index(drop=True) + + self.timestamps = df['timestamps'].copy() + + df['minute'] = df['timestamps'].dt.minute + df['hour'] = df['timestamps'].dt.hour + df['weekday'] = df['timestamps'].dt.weekday + df['day'] = df['timestamps'].dt.day + df['month'] = df['timestamps'].dt.month + + self.data = df[self.feature_list + self.time_feature_list].copy() + + if self.data.isnull().any().any(): + print("Warning: Missing values found in data, performing forward fill") + self.data = self.data.fillna(method='ffill') + + print(f"Original data time range: {self.timestamps.min()} to {self.timestamps.max()}") + print(f"Original data total length: {len(df)} records") + + def _split_data_by_time(self): + total_length = len(self.data) + + train_end = int(total_length * self.train_ratio) + val_end = int(total_length * (self.train_ratio + self.val_ratio)) + + if self.data_type == 'train': + self.data = self.data.iloc[:train_end].copy() + self.timestamps = self.timestamps.iloc[:train_end].copy() + print(f"[{self.data_type.upper()}] Training set: first {train_end} time points ({self.train_ratio})") + print(f"[{self.data_type.upper()}] Training set time range: {self.timestamps.min()} to {self.timestamps.max()}") + elif self.data_type == 'val': + self.data = self.data.iloc[train_end:val_end].copy() + self.timestamps = self.timestamps.iloc[train_end:val_end].copy() + print(f"[{self.data_type.upper()}] Validation set: time points {train_end+1} to {val_end} ({self.val_ratio})") + print(f"[{self.data_type.upper()}] Validation set time range: {self.timestamps.min()} to {self.timestamps.max()}") + elif self.data_type == 'test': + self.data = self.data.iloc[val_end:].copy() + self.timestamps = self.timestamps.iloc[val_end:].copy() + print(f"[{self.data_type.upper()}] Test set: after time point {val_end+1}") + print(f"[{self.data_type.upper()}] Test set time range: {self.timestamps.min()} to {self.timestamps.max()}") + + print(f"[{self.data_type.upper()}] Data length after split: {len(self.data)} records") + + def set_epoch_seed(self, epoch): + epoch_seed = self.seed + epoch + self.py_rng.seed(epoch_seed) + self.current_epoch = epoch + + def __len__(self): + return self.n_samples + + def __getitem__(self, idx): + max_start = len(self.data) - self.window + if max_start <= 0: + raise ValueError("Data length insufficient to create samples") + + if self.data_type == 'train': + epoch = getattr(self, 'current_epoch', 0) + start_idx = (idx * 9973 + (epoch + 1) * 104729) % (max_start + 1) + else: + start_idx = idx % (max_start + 1) + + end_idx = start_idx + self.window + + window_data = self.data.iloc[start_idx:end_idx] + + x = window_data[self.feature_list].values.astype(np.float32) + x_stamp = window_data[self.time_feature_list].values.astype(np.float32) + + x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0) + x = (x - x_mean) / (x_std + 1e-5) + x = np.clip(x, -self.clip, self.clip) + + x_tensor = torch.from_numpy(x) + x_stamp_tensor = torch.from_numpy(x_stamp) + + return x_tensor, x_stamp_tensor + + + + +def setup_logging(exp_name: str, log_dir: str, rank: int = 0) -> logging.Logger: + os.makedirs(log_dir, exist_ok=True) + + logger = logging.getLogger(f"basemodel_training_rank_{rank}") + logger.setLevel(logging.INFO) + + if logger.handlers: + return logger + + log_file = os.path.join(log_dir, f"basemodel_training_rank_{rank}.log") + file_handler = RotatingFileHandler( + log_file, + maxBytes=10*1024*1024, + backupCount=5, + encoding='utf-8' + ) + file_handler.setLevel(logging.INFO) + + console_handler = None + if rank == 0: + console_handler = logging.StreamHandler() + console_handler.setLevel(logging.INFO) + + formatter = logging.Formatter( + '%(asctime)s - %(name)s - %(levelname)s - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S' + ) + file_handler.setFormatter(formatter) + if console_handler is not None: + console_handler.setFormatter(formatter) + + logger.addHandler(file_handler) + if console_handler is not None: + logger.addHandler(console_handler) + + logger.info(f"=== Basemodel Training Started ===") + logger.info(f"Experiment Name: {exp_name}") + logger.info(f"Log Directory: {log_dir}") + logger.info(f"Rank: {rank}") + logger.info(f"Timestamp: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") + + return logger + + +def create_dataloaders(config): + if not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0: + print("Creating data loaders...") + + train_dataset = CustomKlineDataset( + data_path=config.data_path, + data_type='train', + lookback_window=config.lookback_window, + predict_window=config.predict_window, + clip=config.clip, + seed=config.seed, + train_ratio=config.train_ratio, + val_ratio=config.val_ratio, + test_ratio=config.test_ratio + ) + + val_dataset = CustomKlineDataset( + data_path=config.data_path, + data_type='val', + lookback_window=config.lookback_window, + predict_window=config.predict_window, + clip=config.clip, + seed=config.seed + 1, + train_ratio=config.train_ratio, + val_ratio=config.val_ratio, + test_ratio=config.test_ratio + ) + + use_ddp = dist.is_available() and dist.is_initialized() + train_sampler = DistributedSampler(train_dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True) if use_ddp else None + val_sampler = DistributedSampler(val_dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=False, drop_last=False) if use_ddp else None + + train_loader = DataLoader( + train_dataset, + batch_size=config.batch_size, + shuffle=(train_sampler is None), + num_workers=config.num_workers, + pin_memory=True, + drop_last=True, + sampler=train_sampler + ) + + val_loader = DataLoader( + val_dataset, + batch_size=config.batch_size, + shuffle=False, + num_workers=config.num_workers, + pin_memory=True, + drop_last=False, + sampler=val_sampler + ) + + if not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0: + print(f"Training set size: {len(train_dataset)}, Validation set size: {len(val_dataset)}") + + return train_loader, val_loader, train_dataset, val_dataset, train_sampler, val_sampler + + +def train_model(model, tokenizer, device, config, save_dir, logger): + logger.info("Starting training...") + use_ddp = dist.is_available() and dist.is_initialized() + rank = dist.get_rank() if use_ddp else 0 + world_size = dist.get_world_size() if use_ddp else 1 + + train_loader, val_loader, train_dataset, val_dataset, train_sampler, val_sampler = create_dataloaders(config) + optimizer = torch.optim.AdamW( + model.parameters(), + lr=config.predictor_learning_rate, + betas=(config.adam_beta1, config.adam_beta2), + weight_decay=config.adam_weight_decay + ) + + scheduler = torch.optim.lr_scheduler.OneCycleLR( + optimizer, + max_lr=config.predictor_learning_rate, + steps_per_epoch=len(train_loader), + epochs=config.basemodel_epochs, + pct_start=0.03, + div_factor=10 + ) + + if use_ddp: + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=False) + + best_val_loss = float('inf') + batch_idx_global = 0 + + for epoch in range(config.basemodel_epochs): + epoch_start_time = time.time() + model.train() + + train_dataset.set_epoch_seed(epoch * 10000) + val_dataset.set_epoch_seed(0) + if train_sampler is not None: + train_sampler.set_epoch(epoch) + + epoch_train_loss = 0.0 + train_batches = 0 + + for batch_idx, (batch_x, batch_x_stamp) in enumerate(train_loader): + batch_x = batch_x.to(device, non_blocking=True) + batch_x_stamp = batch_x_stamp.to(device, non_blocking=True) + + with torch.no_grad(): + token_seq_0, token_seq_1 = tokenizer.encode(batch_x, half=True) + + token_in = [token_seq_0[:, :-1], token_seq_1[:, :-1]] + token_out = [token_seq_0[:, 1:], token_seq_1[:, 1:]] + + logits = (model.module if use_ddp else model)(token_in[0], token_in[1], batch_x_stamp[:, :-1, :]) + loss, s1_loss, s2_loss = (model.module if use_ddp else model).head.compute_loss(logits[0], logits[1], token_out[0], token_out[1]) + + optimizer.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_((model.module if use_ddp else model).parameters(), max_norm=3.0) + optimizer.step() + scheduler.step() + + epoch_train_loss += loss.item() + train_batches += 1 + + if (batch_idx_global + 1) % config.log_interval == 0: + lr = optimizer.param_groups[0]['lr'] + log_msg = (f"[Epoch {epoch+1}/{config.basemodel_epochs}, Step {batch_idx+1}/{len(train_loader)}] " + f"LR: {lr:.6f}, Loss: {loss.item():.4f}") + logger.info(log_msg) + if rank == 0: + print(log_msg) + + batch_idx_global += 1 + + model.eval() + val_loss = 0.0 + val_batches = 0 + + with torch.no_grad(): + for batch_x, batch_x_stamp in val_loader: + batch_x = batch_x.to(device, non_blocking=True) + batch_x_stamp = batch_x_stamp.to(device, non_blocking=True) + + token_seq_0, token_seq_1 = tokenizer.encode(batch_x, half=True) + token_in = [token_seq_0[:, :-1], token_seq_1[:, :-1]] + token_out = [token_seq_0[:, 1:], token_seq_1[:, 1:]] + + logits = (model.module if use_ddp else model)(token_in[0], token_in[1], batch_x_stamp[:, :-1, :]) + loss, _, _ = (model.module if use_ddp else model).head.compute_loss(logits[0], logits[1], token_out[0], token_out[1]) + + val_loss += loss.item() + val_batches += 1 + + if use_ddp: + tensor_sum = torch.tensor([epoch_train_loss, train_batches, val_loss, val_batches], dtype=torch.float64, device=device) + dist.all_reduce(tensor_sum, op=dist.ReduceOp.SUM) + epoch_train_loss_all = tensor_sum[0].item() + train_batches_all = int(tensor_sum[1].item()) + val_loss_all = tensor_sum[2].item() + val_batches_all = int(tensor_sum[3].item()) + avg_train_loss = (epoch_train_loss_all / train_batches_all) if train_batches_all > 0 else 0.0 + avg_val_loss = (val_loss_all / val_batches_all) if val_batches_all > 0 else 0.0 + else: + avg_train_loss = epoch_train_loss / train_batches if train_batches > 0 else 0 + avg_val_loss = val_loss / val_batches if val_batches > 0 else 0 + + epoch_time = time.time() - epoch_start_time + epoch_summary = (f"\n--- Epoch {epoch+1}/{config.basemodel_epochs} Summary ---\n" + f"Training Loss: {avg_train_loss:.4f}\n" + f"Validation Loss: {avg_val_loss:.4f}\n" + f"Epoch Time: {epoch_time:.2f} seconds\n") + logger.info(epoch_summary) + if rank == 0: + print(epoch_summary) + + if avg_val_loss < best_val_loss: + best_val_loss = avg_val_loss + if rank == 0: + model_save_path = os.path.join(save_dir, "best_model") + os.makedirs(model_save_path, exist_ok=True) + (model.module if use_ddp else model).save_pretrained(model_save_path) + save_msg = f"Best model saved to: {model_save_path} (validation loss: {best_val_loss:.4f})" + logger.info(save_msg) + print(save_msg) + + return best_val_loss + + +def main(): + import argparse + + parser = argparse.ArgumentParser(description='Kronos Basemodel Fine-tuning Training') + parser.add_argument('--config', type=str, default='config.yaml', + help='Configuration file path (default: config.yaml)') + args = parser.parse_args() + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Using device: {device}") + + config = CustomFinetuneConfig(args.config) + + os.makedirs(config.basemodel_save_path, exist_ok=True) + + log_dir = os.path.join(config.base_save_path, "logs") + logger = setup_logging(config.exp_name, log_dir, 0) + + torch.manual_seed(config.seed) + np.random.seed(config.seed) + random.seed(config.seed) + + logger.info("Loading pretrained model or random initialization...") + print("Loading pretrained model or random initialization...") + if getattr(config, 'pre_trained_tokenizer', True): + tokenizer = KronosTokenizer.from_pretrained(config.finetuned_tokenizer_path) + else: + import json, os + print("pre_trained_tokenizer=False, randomly initializing Tokenizer architecture for training") + cfg_path_tok = os.path.join(config.pretrained_tokenizer_path if hasattr(config, 'pretrained_tokenizer_path') else config.finetuned_tokenizer_path, 'config.json') + with open(cfg_path_tok, 'r') as f: + arch_t = json.load(f) + tokenizer = KronosTokenizer( + d_in=arch_t.get('d_in', 6), + d_model=arch_t.get('d_model', 256), + n_heads=arch_t.get('n_heads', 4), + ff_dim=arch_t.get('ff_dim', 512), + n_enc_layers=arch_t.get('n_enc_layers', 4), + n_dec_layers=arch_t.get('n_dec_layers', 4), + ffn_dropout_p=arch_t.get('ffn_dropout_p', 0.0), + attn_dropout_p=arch_t.get('attn_dropout_p', 0.0), + resid_dropout_p=arch_t.get('resid_dropout_p', 0.0), + s1_bits=arch_t.get('s1_bits', 10), + s2_bits=arch_t.get('s2_bits', 10), + beta=arch_t.get('beta', 0.05), + gamma0=arch_t.get('gamma0', 1.0), + gamma=arch_t.get('gamma', 1.1), + zeta=arch_t.get('zeta', 0.05), + group_size=arch_t.get('group_size', 4) + ) + + if getattr(config, 'pre_trained_predictor', True): + model = Kronos.from_pretrained(config.pretrained_predictor_path) + else: + import json, os + print("pre_trained_predictor=False, randomly initializing Predictor architecture for training") + cfg_path = os.path.join(config.pretrained_predictor_path, 'config.json') + with open(cfg_path, 'r') as f: + arch = json.load(f) + model = Kronos( + s1_bits=arch.get('s1_bits', 10), + s2_bits=arch.get('s2_bits', 10), + n_layers=arch.get('n_layers', 12), + d_model=arch.get('d_model', 832), + n_heads=arch.get('n_heads', 16), + ff_dim=arch.get('ff_dim', 2048), + ffn_dropout_p=arch.get('ffn_dropout_p', 0.2), + attn_dropout_p=arch.get('attn_dropout_p', 0.0), + resid_dropout_p=arch.get('resid_dropout_p', 0.2), + token_dropout_p=arch.get('token_dropout_p', 0.0), + learn_te=arch.get('learn_te', True) + ) + + tokenizer = tokenizer.to(device) + model = model.to(device) + + model_size = sum(p.numel() for p in model.parameters()) + logger.info(f"Model parameters: {model_size:,}") + print(f"Model parameters: {model_size:,}") + + logger.info("=== Training Configuration ===") + logger.info(f"Data path: {config.data_path}") + logger.info(f"Lookback window: {config.lookback_window}") + logger.info(f"Predict window: {config.predict_window}") + logger.info(f"Batch size: {config.batch_size}") + logger.info(f"Learning rate: {config.predictor_learning_rate}") + logger.info(f"Training epochs: {config.basemodel_epochs}") + logger.info(f"Device: {device}") + logger.info(f"Tokenizer path: {config.finetuned_tokenizer_path}") + logger.info(f"Pretrained model path: {config.pretrained_predictor_path}") + + logger.info("Starting fine-tuning training...") + print("Starting fine-tuning training...") + best_val_loss = train_model(model, tokenizer, device, config, config.basemodel_save_path, logger) + + final_msg = f"Training completed! Best validation loss: {best_val_loss:.4f}\nModel saved to: {config.basemodel_save_path}" + logger.info(final_msg) + print(final_msg) + + +if __name__ == "__main__": + main() diff --git a/finetune_csv/finetune_tokenizer.py b/finetune_csv/finetune_tokenizer.py new file mode 100644 index 0000000..3f8c0e0 --- /dev/null +++ b/finetune_csv/finetune_tokenizer.py @@ -0,0 +1,359 @@ +import os +import sys +import json +import time +import random +import numpy as np +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler +from time import gmtime, strftime +import datetime +import logging +from logging.handlers import RotatingFileHandler +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP + +sys.path.append("../") +from model import KronosTokenizer +from finetune_base_model import CustomKlineDataset +from config_loader import CustomFinetuneConfig + + +def set_seed(seed: int, rank: int = 0): + actual_seed = seed + random.seed(actual_seed) + np.random.seed(actual_seed) + torch.manual_seed(actual_seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(actual_seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +def get_model_size(model: torch.nn.Module) -> str: + total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + if total_params >= 1e9: + return f"{total_params / 1e9:.1f}B" + elif total_params >= 1e6: + return f"{total_params / 1e6:.1f}M" + else: + return f"{total_params / 1e3:.1f}K" + + +def format_time(seconds: float) -> str: + return str(datetime.timedelta(seconds=int(seconds))) + + +def setup_logging(exp_name: str, log_dir: str, rank: int = 0) -> logging.Logger: + os.makedirs(log_dir, exist_ok=True) + + logger = logging.getLogger(f"tokenizer_training_rank_{rank}") + logger.setLevel(logging.INFO) + + if logger.handlers: + return logger + + log_file = os.path.join(log_dir, f"tokenizer_training_rank_{rank}.log") + file_handler = RotatingFileHandler( + log_file, + maxBytes=10*1024*1024, + backupCount=5, + encoding='utf-8' + ) + file_handler.setLevel(logging.INFO) + + console_handler = None + if rank == 0: + console_handler = logging.StreamHandler() + console_handler.setLevel(logging.INFO) + + formatter = logging.Formatter( + '%(asctime)s - %(name)s - %(levelname)s - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S' + ) + file_handler.setFormatter(formatter) + if console_handler is not None: + console_handler.setFormatter(formatter) + + logger.addHandler(file_handler) + if console_handler is not None: + logger.addHandler(console_handler) + + logger.info(f"=== Tokenizer Training Started ===") + logger.info(f"Experiment Name: {exp_name}") + logger.info(f"Log Directory: {log_dir}") + logger.info(f"Rank: {rank}") + logger.info(f"Timestamp: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") + + return logger + + +def create_dataloaders(config): + if not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0: + print("Creating tokenizer training data loaders...") + + train_dataset = CustomKlineDataset( + data_path=config.data_path, + data_type="train", + lookback_window=config.lookback_window, + predict_window=config.predict_window, + clip=config.clip, + seed=config.seed, + train_ratio=config.train_ratio, + val_ratio=config.val_ratio, + test_ratio=config.test_ratio + ) + + val_dataset = CustomKlineDataset( + data_path=config.data_path, + data_type="val", + lookback_window=config.lookback_window, + predict_window=config.predict_window, + clip=config.clip, + seed=config.seed + 1, + train_ratio=config.train_ratio, + val_ratio=config.val_ratio, + test_ratio=config.test_ratio + ) + + use_ddp = dist.is_available() and dist.is_initialized() + train_sampler = DistributedSampler(train_dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True) if use_ddp else None + val_sampler = DistributedSampler(val_dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=False, drop_last=False) if use_ddp else None + + train_loader = DataLoader( + train_dataset, + batch_size=config.batch_size, + shuffle=(train_sampler is None), + num_workers=config.num_workers, + pin_memory=True, + drop_last=True, + sampler=train_sampler + ) + + val_loader = DataLoader( + val_dataset, + batch_size=config.batch_size, + shuffle=False, + num_workers=config.num_workers, + pin_memory=True, + drop_last=False, + sampler=val_sampler + ) + + if not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0: + print(f"Training set size: {len(train_dataset)}, Validation set size: {len(val_dataset)}") + + return train_loader, val_loader, train_dataset, val_dataset, train_sampler, val_sampler + + +def train_tokenizer(model, device, config, save_dir, logger): + logger.info("Starting tokenizer training...") + use_ddp = dist.is_available() and dist.is_initialized() + rank = dist.get_rank() if use_ddp else 0 + world_size = dist.get_world_size() if use_ddp else 1 + + train_loader, val_loader, train_dataset, val_dataset, train_sampler, val_sampler = create_dataloaders(config) + + optimizer = torch.optim.AdamW( + model.parameters(), + lr=config.tokenizer_learning_rate, + weight_decay=config.adam_weight_decay + ) + + scheduler = torch.optim.lr_scheduler.OneCycleLR( + optimizer, + max_lr=config.tokenizer_learning_rate, + steps_per_epoch=len(train_loader), + epochs=config.tokenizer_epochs, + pct_start=0.03, + div_factor=10 + ) + + if use_ddp: + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=False) + + best_val_loss = float("inf") + batch_idx_global = 0 + + accumulation_steps = getattr(config, 'accumulation_steps', 1) + + for epoch in range(config.tokenizer_epochs): + epoch_start_time = time.time() + model.train() + + train_dataset.set_epoch_seed(epoch * 10000) + val_dataset.set_epoch_seed(0) + if train_sampler is not None: + train_sampler.set_epoch(epoch) + + for batch_idx, (ori_batch_x, _) in enumerate(train_loader): + ori_batch_x = ori_batch_x.squeeze(0).to(device, non_blocking=True) + + current_batch_total_loss = 0.0 + for j in range(accumulation_steps): + start_idx = j * (ori_batch_x.shape[0] // accumulation_steps) + end_idx = (j + 1) * (ori_batch_x.shape[0] // accumulation_steps) + batch_x = ori_batch_x[start_idx:end_idx] + + zs, bsq_loss, _, _ = (model.module if use_ddp else model)(batch_x) + z_pre, z = zs + + recon_loss_pre = F.mse_loss(z_pre, batch_x) + recon_loss_all = F.mse_loss(z, batch_x) + recon_loss = recon_loss_pre + recon_loss_all + loss = (recon_loss + bsq_loss) / 2 + + loss_scaled = loss / accumulation_steps + current_batch_total_loss += loss.item() + loss_scaled.backward() + + torch.nn.utils.clip_grad_norm_((model.module if use_ddp else model).parameters(), max_norm=2.0) + optimizer.step() + scheduler.step() + optimizer.zero_grad() + + if (batch_idx_global + 1) % config.log_interval == 0: + avg_loss = current_batch_total_loss / accumulation_steps + lr = optimizer.param_groups[0]["lr"] + log_msg = (f"[Epoch {epoch+1}/{config.tokenizer_epochs}, Step {batch_idx+1}/{len(train_loader)}] " + f"LR: {lr:.6f}, Loss: {avg_loss:.4f}") + logger.info(log_msg) + if rank == 0: + print(log_msg) + + detail_msg = (f" - VQ Loss: {bsq_loss.item():.4f}\n" + f" - Recon Loss Pre: {recon_loss_pre.item():.4f}\n" + f" - Recon Loss All: {recon_loss_all.item():.4f}") + logger.info(detail_msg) + if rank == 0: + print(detail_msg) + + batch_idx_global += 1 + + model.eval() + tot_val_loss_sum_rank = 0.0 + val_sample_count_rank = 0 + + with torch.no_grad(): + for ori_batch_x, _ in val_loader: + ori_batch_x = ori_batch_x.squeeze(0).to(device, non_blocking=True) + zs, _, _, _ = (model.module if use_ddp else model)(ori_batch_x) + _, z = zs + val_loss_item = F.mse_loss(z, ori_batch_x) + + tot_val_loss_sum_rank += val_loss_item.item() * ori_batch_x.size(0) + val_sample_count_rank += ori_batch_x.size(0) + + if use_ddp: + tensor_sum = torch.tensor([tot_val_loss_sum_rank, val_sample_count_rank], dtype=torch.float64, device=device) + dist.all_reduce(tensor_sum, op=dist.ReduceOp.SUM) + tot_val_loss_all = tensor_sum[0].item() + val_count_all = int(tensor_sum[1].item()) + avg_val_loss = (tot_val_loss_all / val_count_all) if val_count_all > 0 else 0.0 + else: + avg_val_loss = tot_val_loss_sum_rank / val_sample_count_rank if val_sample_count_rank > 0 else 0 + + epoch_time = time.time() - epoch_start_time + epoch_summary = (f"\n--- Epoch {epoch+1}/{config.tokenizer_epochs} Summary ---\n" + f"Validation Loss: {avg_val_loss:.4f}\n" + f"Epoch Time: {format_time(epoch_time)}\n" + f"Total Training Time: {format_time(time.time() - epoch_start_time)}\n") + logger.info(epoch_summary) + if rank == 0: + print(epoch_summary) + + if avg_val_loss < best_val_loss: + best_val_loss = avg_val_loss + if rank == 0: + model_save_path = os.path.join(save_dir, "best_model") + os.makedirs(model_save_path, exist_ok=True) + (model.module if use_ddp else model).save_pretrained(model_save_path) + save_msg = f"Best model saved to: {model_save_path} (validation loss: {best_val_loss:.4f})" + logger.info(save_msg) + print(save_msg) + + return best_val_loss + + +def main(): + import argparse + + parser = argparse.ArgumentParser(description='Kronos Tokenizer Fine-tuning Training') + parser.add_argument('--config', type=str, default='config.yaml', + help='Configuration file path (default: config.yaml)') + args = parser.parse_args() + + config = CustomFinetuneConfig(args.config) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Using device: {device}") + + config = CustomFinetuneConfig(args.config) + + os.makedirs(config.tokenizer_save_path, exist_ok=True) + + log_dir = os.path.join(config.base_save_path, "logs") + logger = setup_logging(config.exp_name, log_dir, 0) + + set_seed(config.seed) + + # 加载预训练tokenizer + if getattr(config, 'pre_trained_tokenizer', True): + logger.info("Loading pretrained tokenizer...") + print("Loading pretrained tokenizer...") + tokenizer = KronosTokenizer.from_pretrained(config.pretrained_tokenizer_path) + else: + print("pre_trained_tokenizer=False, randomly initializing Tokenizer architecture") + import json, os + cfg_path = os.path.join(config.pretrained_tokenizer_path, 'config.json') + with open(cfg_path, 'r') as f: + arch = json.load(f) + tokenizer = KronosTokenizer( + d_in=arch.get('d_in', 6), + d_model=arch.get('d_model', 256), + n_heads=arch.get('n_heads', 4), + ff_dim=arch.get('ff_dim', 512), + n_enc_layers=arch.get('n_enc_layers', 4), + n_dec_layers=arch.get('n_dec_layers', 4), + ffn_dropout_p=arch.get('ffn_dropout_p', 0.0), + attn_dropout_p=arch.get('attn_dropout_p', 0.0), + resid_dropout_p=arch.get('resid_dropout_p', 0.0), + s1_bits=arch.get('s1_bits', 10), + s2_bits=arch.get('s2_bits', 10), + beta=arch.get('beta', 0.05), + gamma0=arch.get('gamma0', 1.0), + gamma=arch.get('gamma', 1.1), + zeta=arch.get('zeta', 0.05), + group_size=arch.get('group_size', 4) + ) + tokenizer = tokenizer.to(device) + + model_size = get_model_size(tokenizer) + logger.info(f"Tokenizer parameters: {model_size}") + print(f"Tokenizer parameters: {model_size}") + + logger.info("=== Training Configuration ===") + logger.info(f"Data path: {config.data_path}") + logger.info(f"Lookback window: {config.lookback_window}") + logger.info(f"Predict window: {config.predict_window}") + logger.info(f"Batch size: {config.batch_size}") + logger.info(f"Learning rate: {config.tokenizer_learning_rate}") + logger.info(f"Training epochs: {config.tokenizer_epochs}") + logger.info(f"Device: {device}") + logger.info(f"Distributed training: False") + + logger.info("Starting tokenizer fine-tuning training...") + print("Starting tokenizer fine-tuning training...") + best_val_loss = train_tokenizer(tokenizer, device, config, config.tokenizer_save_path, logger) + + final_msg = f"Tokenizer training completed! Best validation loss: {best_val_loss:.4f}\nModel saved to: {config.tokenizer_save_path}" + logger.info(final_msg) + print(final_msg) + + +if __name__ == "__main__": + main() + diff --git a/finetune_csv/train_sequential.py b/finetune_csv/train_sequential.py new file mode 100644 index 0000000..533f66a --- /dev/null +++ b/finetune_csv/train_sequential.py @@ -0,0 +1,361 @@ +import os +import sys +import time +import argparse +import torch +import torch.nn as nn +from torch.utils.data import DataLoader +import torch.distributed as dist + +sys.path.append('../') +from model import Kronos, KronosTokenizer, KronosPredictor + +from config_loader import CustomFinetuneConfig +from finetune_tokenizer import train_tokenizer, set_seed, setup_logging as setup_tokenizer_logging +from finetune_base_model import train_model, create_dataloaders, setup_logging as setup_basemodel_logging + + +class SequentialTrainer: + + def __init__(self, config_path: str = None): + self.config = CustomFinetuneConfig(config_path) + self.rank = int(os.environ.get("RANK", "0")) + self.world_size = int(os.environ.get("WORLD_SIZE", "1")) + self.local_rank = int(os.environ.get("LOCAL_RANK", str(self.config.device_id if hasattr(self.config, 'device_id') else 0))) + self.device = self._setup_device() + + self.config.print_config_summary() + + def _setup_device(self): + if self.config.use_cuda and torch.cuda.is_available(): + torch.cuda.set_device(self.local_rank) + device = torch.device(f"cuda:{self.local_rank}") + else: + device = torch.device("cpu") + + if self.rank == 0: + print(f"Using device: {device} (rank={self.rank}, world_size={self.world_size}, local_rank={self.local_rank})") + return device + + def _setup_distributed(self): + if self.world_size > 1 and torch.cuda.is_available(): + backend = os.environ.get("DIST_BACKEND", "nccl").lower() + if not dist.is_initialized(): + dist.init_process_group(backend=backend) + if self.rank == 0: + print(f"Distributed training initialized: backend={backend}, world_size={self.world_size}") + else: + if self.rank == 0: + print("Distributed training not enabled, using single GPU/CPU training") + + def _check_existing_models(self): + tokenizer_exists = os.path.exists(self.config.tokenizer_best_model_path) + basemodel_exists = os.path.exists(self.config.basemodel_best_model_path) + + print(f"Tokenizer model exists: {tokenizer_exists}") + print(f"Basemodel model exists: {basemodel_exists}") + + return tokenizer_exists, basemodel_exists + + def _create_directories(self): + os.makedirs(self.config.tokenizer_save_path, exist_ok=True) + os.makedirs(self.config.basemodel_save_path, exist_ok=True) + print(f"Created directory: {self.config.tokenizer_save_path}") + print(f"Created directory: {self.config.basemodel_save_path}") + + def train_tokenizer_phase(self): + print("\n" + "="*60) + print("Starting Tokenizer Fine-tuning Phase") + print("="*60) + + tokenizer_exists, _ = self._check_existing_models() + if tokenizer_exists and self.config.skip_existing: + print("Tokenizer model already exists, skipping training") + return True + + log_dir = os.path.join(self.config.base_save_path, "logs") + logger = setup_tokenizer_logging(self.config.exp_name, log_dir, self.rank) + + set_seed(self.config.seed) + + if getattr(self.config, 'pre_trained_tokenizer', True): + logger.info("Loading pretrained tokenizer...") + if self.rank == 0: + print("Loading pretrained tokenizer...") + tokenizer = KronosTokenizer.from_pretrained(self.config.pretrained_tokenizer_path) + else: + if self.rank == 0: + print("pre_trained_tokenizer=False, randomly initializing Tokenizer architecture") + import json + cfg_path = os.path.join(self.config.pretrained_tokenizer_path, 'config.json') + with open(cfg_path, 'r') as f: + arch = json.load(f) + tokenizer = KronosTokenizer( + d_in=arch.get('d_in', 6), + d_model=arch.get('d_model', 256), + n_heads=arch.get('n_heads', 4), + ff_dim=arch.get('ff_dim', 512), + n_enc_layers=arch.get('n_enc_layers', 4), + n_dec_layers=arch.get('n_dec_layers', 4), + ffn_dropout_p=arch.get('ffn_dropout_p', 0.0), + attn_dropout_p=arch.get('attn_dropout_p', 0.0), + resid_dropout_p=arch.get('resid_dropout_p', 0.0), + s1_bits=arch.get('s1_bits', 10), + s2_bits=arch.get('s2_bits', 10), + beta=arch.get('beta', 0.05), + gamma0=arch.get('gamma0', 1.0), + gamma=arch.get('gamma', 1.1), + zeta=arch.get('zeta', 0.05), + group_size=arch.get('group_size', 4) + ) + tokenizer = tokenizer.to(self.device) + + model_size = sum(p.numel() for p in tokenizer.parameters()) + logger.info(f"Tokenizer parameters: {model_size:,}") + if self.rank == 0: + print(f"Tokenizer parameters: {model_size:,}") + + logger.info("=== Training Configuration ===") + logger.info(f"Data path: {self.config.data_path}") + logger.info(f"Lookback window: {self.config.lookback_window}") + logger.info(f"Predict window: {self.config.predict_window}") + logger.info(f"Batch size: {self.config.batch_size}") + logger.info(f"Learning rate: {self.config.tokenizer_learning_rate}") + logger.info(f"Training epochs: {self.config.tokenizer_epochs}") + logger.info(f"Device: {self.device}") + logger.info(f"Distributed training: False") + + logger.info("Starting tokenizer fine-tuning training...") + if self.rank == 0: + print("Starting tokenizer fine-tuning training...") + start_time = time.time() + best_val_loss = train_tokenizer( + tokenizer, + self.device, + self.config, + self.config.tokenizer_save_path, + logger, + ) + training_time = time.time() - start_time + + final_msg = f"Tokenizer training completed! Best validation loss: {best_val_loss:.4f}\nTraining time: {training_time/60:.2f} minutes\nModel saved to: {self.config.tokenizer_save_path}" + logger.info(final_msg) + if self.rank == 0: + print(f"\n{final_msg}") + + return True + + def train_basemodel_phase(self): + print("\n" + "="*60) + print("Starting Basemodel Fine-tuning Phase") + print("="*60) + + if getattr(self.config, 'pre_trained_tokenizer', True): + if not os.path.exists(self.config.finetuned_tokenizer_path): + raise FileNotFoundError(f"Fine-tuned tokenizer does not exist: {self.config.finetuned_tokenizer_path}") + + _, basemodel_exists = self._check_existing_models() + if basemodel_exists and self.config.skip_existing: + print("Basemodel model already exists, skipping training") + return True + + log_dir = os.path.join(self.config.base_save_path, "logs") + logger = setup_basemodel_logging(self.config.exp_name, log_dir, self.rank) + + set_seed(self.config.seed) + + if getattr(self.config, 'pre_trained_tokenizer', True): + logger.info("Loading fine-tuned tokenizer...") + if self.rank == 0: + print("Loading fine-tuned tokenizer...") + tokenizer = KronosTokenizer.from_pretrained(self.config.finetuned_tokenizer_path) + else: + if self.rank == 0: + print("pre_trained_tokenizer=False, randomly initializing Tokenizer architecture for Predictor training") + import json + cfg_path = os.path.join(self.config.pretrained_tokenizer_path, 'config.json') + with open(cfg_path, 'r') as f: + arch = json.load(f) + tokenizer = KronosTokenizer( + d_in=arch.get('d_in', 6), + d_model=arch.get('d_model', 256), + n_heads=arch.get('n_heads', 4), + ff_dim=arch.get('ff_dim', 512), + n_enc_layers=arch.get('n_enc_layers', 4), + n_dec_layers=arch.get('n_dec_layers', 4), + ffn_dropout_p=arch.get('ffn_dropout_p', 0.0), + attn_dropout_p=arch.get('attn_dropout_p', 0.0), + resid_dropout_p=arch.get('resid_dropout_p', 0.0), + s1_bits=arch.get('s1_bits', 10), + s2_bits=arch.get('s2_bits', 10), + beta=arch.get('beta', 0.05), + gamma0=arch.get('gamma0', 1.0), + gamma=arch.get('gamma', 1.1), + zeta=arch.get('zeta', 0.05), + group_size=arch.get('group_size', 4) + ) + tokenizer = tokenizer.to(self.device) + + if getattr(self.config, 'pre_trained_predictor', True): + logger.info("Loading pretrained predictor...") + if self.rank == 0: + print("Loading pretrained predictor...") + model = Kronos.from_pretrained(self.config.pretrained_predictor_path) + else: + if self.rank == 0: + print("pre_trained_predictor=False, randomly initializing Predictor architecture") + import json + cfg_path = os.path.join(self.config.pretrained_predictor_path, 'config.json') + with open(cfg_path, 'r') as f: + arch = json.load(f) + print("model_config: ", arch) + model = Kronos( + s1_bits=arch.get('s1_bits', 10), + s2_bits=arch.get('s2_bits', 10), + n_layers=arch.get('n_layers', 12), + d_model=arch.get('d_model', 832), + n_heads=arch.get('n_heads', 16), + ff_dim=arch.get('ff_dim', 2048), + ffn_dropout_p=arch.get('ffn_dropout_p', 0.2), + attn_dropout_p=arch.get('attn_dropout_p', 0.0), + resid_dropout_p=arch.get('resid_dropout_p', 0.2), + token_dropout_p=arch.get('token_dropout_p', 0.0), + learn_te=arch.get('learn_te', True) + ) + model = model.to(self.device) + + model_size = sum(p.numel() for p in model.parameters()) + logger.info(f"Model parameters: {model_size:,}") + if self.rank == 0: + print(f"Model parameters: {model_size:,}") + + logger.info("=== Training Configuration ===") + logger.info(f"Data path: {self.config.data_path}") + logger.info(f"Lookback window: {self.config.lookback_window}") + logger.info(f"Predict window: {self.config.predict_window}") + logger.info(f"Batch size: {self.config.batch_size}") + logger.info(f"Learning rate: {self.config.predictor_learning_rate}") + logger.info(f"Training epochs: {self.config.basemodel_epochs}") + logger.info(f"Device: {self.device}") + logger.info(f"Tokenizer path: {self.config.finetuned_tokenizer_path}") + logger.info(f"Pretrained model path: {self.config.pretrained_predictor_path}") + + logger.info("Starting fine-tuning training...") + if self.rank == 0: + print("Starting fine-tuning training...") + start_time = time.time() + best_val_loss = train_model( + model, + tokenizer, + self.device, + self.config, + self.config.basemodel_save_path, + logger, + ) + training_time = time.time() - start_time + + final_msg = f"Basemodel training completed! Best validation loss: {best_val_loss:.4f}\nTraining time: {training_time/60:.2f} minutes\nModel saved to: {self.config.basemodel_save_path}" + logger.info(final_msg) + if self.rank == 0: + print(f"\n{final_msg}") + + return True + + def run_training(self): + if self.rank == 0: + print("Starting Kronos model sequential fine-tuning training") + print(f"Experiment name: {self.config.experiment_name}") + print(f"Experiment description: {self.config.experiment_description}") + + self._setup_distributed() + + self._create_directories() + + tokenizer_exists, basemodel_exists = self._check_existing_models() + + total_start_time = time.time() + + try: + if self.config.train_tokenizer: + success = self.train_tokenizer_phase() + if not success: + print("Tokenizer training failed, terminating training") + return False + else: + print("Skipping Tokenizer training phase") + + if self.config.train_basemodel: + success = self.train_basemodel_phase() + if not success: + print("Basemodel training failed, terminating training") + return False + else: + print("Skipping Basemodel training phase") + + total_time = time.time() - total_start_time + + if self.rank == 0: + print("\n" + "="*60) + print("Training completed!") + print("="*60) + print(f"Total training time: {total_time/60:.2f} minutes") + print(f"Tokenizer model: {self.config.tokenizer_best_model_path}") + print(f"Basemodel model: {self.config.basemodel_best_model_path}") + print("="*60) + + return True + + except Exception as e: + if self.rank == 0: + print(f"Error occurred during training: {str(e)}") + import traceback + traceback.print_exc() + return False + + finally: + pass + + +def main(): + parser = argparse.ArgumentParser(description='Kronos Model Sequential Fine-tuning Training') + parser.add_argument('--config', type=str, default='config.yaml', + help='Configuration file path (default: config.yaml)') + parser.add_argument('--skip-tokenizer', action='store_true', + help='Skip tokenizer training phase') + parser.add_argument('--skip-basemodel', action='store_true', + help='Skip basemodel training phase') + parser.add_argument('--skip-existing', action='store_true', + help='Skip training for existing models') + + args = parser.parse_args() + + trainer = SequentialTrainer(args.config) + + if args.skip_tokenizer: + trainer.config.train_tokenizer = False + if args.skip_basemodel: + trainer.config.train_basemodel = False + if args.skip_existing: + trainer.config.skip_existing = True + + success = trainer.run_training() + + if success: + print("Training completed successfully!") + if dist.is_available() and dist.is_initialized(): + dist.barrier() + dist.destroy_process_group() + sys.exit(0) + else: + print("Training failed!") + if dist.is_available() and dist.is_initialized(): + try: + dist.barrier() + dist.destroy_process_group() + except Exception: + pass + sys.exit(1) + + +if __name__ == "__main__": + main()