add custom data finetune
parent 083294bd84 · commit 84f74ae341
finetune_csv/README.md (new file, 101 lines)
@@ -0,0 +1,101 @@
# Kronos Fine-tuning Training with a Custom Dataset

Fine-tune Kronos on your own CSV data, driven by a YAML configuration file.

## Quick Start

### 1. Configuration Setup

First, edit the `config.yaml` file to set the correct paths and parameters:

```yaml
# Data configuration
data:
  data_path: "/path/to/your/data.csv"
  lookback_window: 512
  predict_window: 48
  # ... other parameters

# Model path configuration
model_paths:
  pretrained_tokenizer: "/path/to/pretrained/tokenizer"
  pretrained_predictor: "/path/to/pretrained/predictor"
  base_save_path: "/path/to/save/models"
  # ... other paths
```

### 2. Run Training

**Using `train_sequential.py`**

```bash
# Complete training
python train_sequential.py --config configs/config_ali09988_candle-5min.yaml

# Skip existing models
python train_sequential.py --config configs/config_ali09988_candle-5min.yaml --skip-existing

# Only train the tokenizer
python train_sequential.py --config configs/config_ali09988_candle-5min.yaml --skip-basemodel

# Only train the basemodel
python train_sequential.py --config configs/config_ali09988_candle-5min.yaml --skip-tokenizer
```

**Running each stage separately**

```bash
# Only train the tokenizer
python finetune_tokenizer.py --config configs/config_ali09988_candle-5min.yaml

# Only train the basemodel (requires a fine-tuned tokenizer first)
python finetune_base_model.py --config configs/config_ali09988_candle-5min.yaml
```

**DDP training**

```bash
# Choose the communication backend; nccl can be replaced with gloo
DIST_BACKEND=nccl \
torchrun --standalone --nproc_per_node=8 train_sequential.py --config configs/config_ali09988_candle-5min.yaml
```

## Configuration Description

### Main Configuration Items

These keys map directly onto attributes of `CustomFinetuneConfig` in `config_loader.py`; a short loading example follows the list.

- **data**: Data-related configuration
  - `data_path`: CSV data file path
  - `lookback_window`: Lookback window size
  - `predict_window`: Prediction window size
  - `train_ratio/val_ratio/test_ratio`: Dataset split ratios

- **training**: Training-related configuration
  - `epochs`: Number of training epochs (use `tokenizer_epochs` / `basemodel_epochs` to set the two stages separately)
  - `batch_size`: Batch size
  - `tokenizer_learning_rate`: Tokenizer learning rate
  - `predictor_learning_rate`: Predictor learning rate

- **model_paths**: Model path configuration
  - `pretrained_tokenizer`: Pre-trained tokenizer path
  - `pretrained_predictor`: Pre-trained predictor path
  - `base_save_path`: Root directory for saved models
  - `finetuned_tokenizer`: Fine-tuned tokenizer path (used for basemodel training)

- **experiment**: Experiment control
  - `train_tokenizer`: Whether to train the tokenizer
  - `train_basemodel`: Whether to train the basemodel
  - `skip_existing`: Whether to skip models that already exist
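The sketch below is a minimal, illustrative example (not part of the training scripts) of reading the same nested keys with PyYAML, mirroring what `ConfigLoader.get` does with dot-separated paths; it assumes the template config shipped in `configs/`:

```python
# Minimal sketch: read the nested keys that config_loader.py exposes.
import yaml

with open("configs/config_ali09988_candle-5min.yaml", "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

print(cfg["data"]["lookback_window"])       # 512
print(cfg["training"]["tokenizer_epochs"])  # 30
print(cfg["model_paths"]["exp_name"])       # "HK_ali_09988_kline_5min_all"
```
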
## Training Process

1. **Tokenizer Fine-tuning Stage**
   - Load the pre-trained tokenizer
   - Fine-tune it on the custom data
   - Save the fine-tuned tokenizer to `{base_save_path}/tokenizer/best_model/`

2. **Basemodel Fine-tuning Stage**
   - Load the fine-tuned tokenizer and the pre-trained predictor
   - Fine-tune on the custom data
   - Save the fine-tuned basemodel to `{base_save_path}/basemodel/best_model/`

**Data Format**: Ensure the CSV file contains the following columns: `timestamps`, `open`, `high`, `low`, `close`, `volume`, `amount`
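Before launching training, a quick sanity check of the CSV can save a failed run. The sketch below is illustrative only (the path is a placeholder, pandas is assumed to be installed) and mirrors the checks the dataset class performs:

```python
# Optional sanity check for the input CSV (sketch).
import pandas as pd

required = ["timestamps", "open", "high", "low", "close", "volume", "amount"]
df = pd.read_csv("/path/to/your/data.csv")

missing = [c for c in required if c not in df.columns]
assert not missing, f"missing columns: {missing}"

# timestamps must be parseable; the dataset class sorts by time before windowing
df["timestamps"] = pd.to_datetime(df["timestamps"])
print(df["timestamps"].min(), "->", df["timestamps"].max(), f"({len(df)} rows)")
```
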
finetune_csv/README_CN.md (new file, 105 lines)
@@ -0,0 +1,105 @@
# Kronos Fine-tuning Training with a Custom Dataset

Supports fine-tuning on custom CSV data using a configuration file.

## Quick Start

### 1. Configuration Setup

First, edit the `config.yaml` file to set the correct paths and parameters:

```yaml
# Data configuration
data:
  data_path: "/path/to/your/data.csv"
  lookback_window: 512
  predict_window: 48
  # ... other parameters

# Model path configuration
model_paths:
  pretrained_tokenizer: "/path/to/pretrained/tokenizer"
  pretrained_predictor: "/path/to/pretrained/predictor"
  base_save_path: "/path/to/save/models"
  # ... other paths
```

### 2. Run Training

**Using `train_sequential.py`**

```bash
# Complete training
python train_sequential.py --config configs/config_ali09988_candle-5min.yaml

# Skip existing models
python train_sequential.py --config configs/config_ali09988_candle-5min.yaml --skip-existing

# Only train the tokenizer
python train_sequential.py --config configs/config_ali09988_candle-5min.yaml --skip-basemodel

# Only train the basemodel
python train_sequential.py --config configs/config_ali09988_candle-5min.yaml --skip-tokenizer
```

**Running each stage separately**

```bash
# Only train the tokenizer
python finetune_tokenizer.py --config configs/config_ali09988_candle-5min.yaml

# Only train the basemodel (requires a fine-tuned tokenizer first)
python finetune_base_model.py --config configs/config_ali09988_candle-5min.yaml
```

**DDP training**

```bash
# Choose the communication backend; nccl can be replaced with gloo
DIST_BACKEND=nccl \
torchrun --standalone --nproc_per_node=8 train_sequential.py --config configs/config_ali09988_candle-5min.yaml
```

## Configuration Description

### Main Configuration Items

- **data**: Data-related configuration
  - `data_path`: CSV data file path
  - `lookback_window`: Lookback window size
  - `predict_window`: Prediction window size
  - `train_ratio/val_ratio/test_ratio`: Dataset split ratios

- **training**: Training-related configuration
  - `epochs`: Number of training epochs
  - `batch_size`: Batch size
  - `tokenizer_learning_rate`: Tokenizer learning rate
  - `predictor_learning_rate`: Predictor learning rate

- **model_paths**: Model path configuration
  - `pretrained_tokenizer`: Pre-trained tokenizer path
  - `pretrained_predictor`: Pre-trained predictor path
  - `base_save_path`: Root directory for saved models
  - `finetuned_tokenizer`: Fine-tuned tokenizer path (used for basemodel training)

- **experiment**: Experiment control
  - `train_tokenizer`: Whether to train the tokenizer
  - `train_basemodel`: Whether to train the basemodel
  - `skip_existing`: Whether to skip models that already exist

## Training Process

1. **Tokenizer Fine-tuning Stage**
   - Load the pre-trained tokenizer
   - Fine-tune it on the custom data
   - Save the fine-tuned tokenizer to `{base_save_path}/tokenizer/best_model/`

2. **Basemodel Fine-tuning Stage**
   - Load the fine-tuned tokenizer and the pre-trained predictor
   - Fine-tune on the custom data
   - Save the fine-tuned basemodel to `{base_save_path}/basemodel/best_model/`

**Data Format**: Ensure the CSV file contains the following columns: `timestamps`, `open`, `high`, `low`, `close`, `volume`, `amount`
finetune_csv/config_loader.py (new file, 267 lines)
@@ -0,0 +1,267 @@
import os
import yaml
from typing import Dict, Any


class ConfigLoader:
    """Load a YAML config file and resolve exp_name-based dynamic paths."""

    def __init__(self, config_path: str):
        self.config_path = config_path
        self.config = self._load_config()

    def _load_config(self) -> Dict[str, Any]:
        if not os.path.exists(self.config_path):
            raise FileNotFoundError(f"config file not found: {self.config_path}")

        with open(self.config_path, 'r', encoding='utf-8') as f:
            config = yaml.safe_load(f)

        config = self._resolve_dynamic_paths(config)
        return config

    def _resolve_dynamic_paths(self, config: Dict[str, Any]) -> Dict[str, Any]:
        exp_name = config.get('model_paths', {}).get('exp_name', '')
        if not exp_name:
            return config

        base_path = config.get('model_paths', {}).get('base_path', '')
        path_templates = {
            'base_save_path': f"{base_path}/{exp_name}",
            'finetuned_tokenizer': f"{base_path}/{exp_name}/tokenizer/best_model"
        }

        if 'model_paths' in config:
            for key, template in path_templates.items():
                if key in config['model_paths']:
                    # only use the template when the original value is empty
                    current_value = config['model_paths'][key]
                    if current_value == "" or current_value is None:
                        config['model_paths'][key] = template
                    else:
                        # if the original value is not empty, replace the {exp_name} placeholder
                        if isinstance(current_value, str) and '{exp_name}' in current_value:
                            config['model_paths'][key] = current_value.format(exp_name=exp_name)

        return config

    def get(self, key: str, default=None):
        # dot-separated access, e.g. get("data.lookback_window")
        keys = key.split('.')
        value = self.config
        try:
            for k in keys:
                value = value[k]
            return value
        except (KeyError, TypeError):
            return default

    def get_data_config(self) -> Dict[str, Any]:
        return self.config.get('data', {})

    def get_training_config(self) -> Dict[str, Any]:
        return self.config.get('training', {})

    def get_model_paths(self) -> Dict[str, str]:
        return self.config.get('model_paths', {})

    def get_experiment_config(self) -> Dict[str, Any]:
        return self.config.get('experiment', {})

    def get_device_config(self) -> Dict[str, Any]:
        return self.config.get('device', {})

    def get_distributed_config(self) -> Dict[str, Any]:
        return self.config.get('distributed', {})

    def update_config(self, updates: Dict[str, Any]):
        def update_nested_dict(d, u):
            for k, v in u.items():
                if isinstance(v, dict):
                    d[k] = update_nested_dict(d.get(k, {}), v)
                else:
                    d[k] = v
            return d

        self.config = update_nested_dict(self.config, updates)

    def save_config(self, save_path: str = None):
        if save_path is None:
            save_path = self.config_path

        with open(save_path, 'w', encoding='utf-8') as f:
            yaml.dump(self.config, f, default_flow_style=False, allow_unicode=True, indent=2)

    def print_config(self):
        print("=" * 50)
        print("Current configuration:")
        print("=" * 50)
        print(yaml.dump(self.config, default_flow_style=False, allow_unicode=True, indent=2))
        print("=" * 50)


class CustomFinetuneConfig:
    """Flattened view of the YAML config consumed by the fine-tuning scripts."""

    def __init__(self, config_path: str = None):
        if config_path is None:
            config_path = os.path.join(os.path.dirname(__file__), 'config.yaml')

        self.loader = ConfigLoader(config_path)
        self._load_all_configs()

    def _load_all_configs(self):
        data_config = self.loader.get_data_config()
        self.data_path = data_config.get('data_path')
        self.lookback_window = data_config.get('lookback_window', 512)
        self.predict_window = data_config.get('predict_window', 48)
        self.max_context = data_config.get('max_context', 512)
        self.clip = data_config.get('clip', 5.0)
        self.train_ratio = data_config.get('train_ratio', 0.9)
        self.val_ratio = data_config.get('val_ratio', 0.1)
        self.test_ratio = data_config.get('test_ratio', 0.0)

        # training configuration
        training_config = self.loader.get_training_config()
        # support separate training epochs for tokenizer and basemodel
        self.tokenizer_epochs = training_config.get('tokenizer_epochs', 30)
        self.basemodel_epochs = training_config.get('basemodel_epochs', 30)

        if 'epochs' in training_config and 'tokenizer_epochs' not in training_config:
            self.tokenizer_epochs = training_config.get('epochs', 30)
        if 'epochs' in training_config and 'basemodel_epochs' not in training_config:
            self.basemodel_epochs = training_config.get('epochs', 30)

        self.batch_size = training_config.get('batch_size', 160)
        self.log_interval = training_config.get('log_interval', 50)
        self.num_workers = training_config.get('num_workers', 6)
        self.seed = training_config.get('seed', 100)
        self.tokenizer_learning_rate = training_config.get('tokenizer_learning_rate', 2e-4)
        self.predictor_learning_rate = training_config.get('predictor_learning_rate', 4e-5)
        self.adam_beta1 = training_config.get('adam_beta1', 0.9)
        self.adam_beta2 = training_config.get('adam_beta2', 0.95)
        self.adam_weight_decay = training_config.get('adam_weight_decay', 0.1)
        self.accumulation_steps = training_config.get('accumulation_steps', 1)

        model_paths = self.loader.get_model_paths()
        self.exp_name = model_paths.get('exp_name', 'default_experiment')
        self.pretrained_tokenizer_path = model_paths.get('pretrained_tokenizer')
        self.pretrained_predictor_path = model_paths.get('pretrained_predictor')
        self.base_save_path = model_paths.get('base_save_path')
        self.tokenizer_save_name = model_paths.get('tokenizer_save_name', 'tokenizer')
        self.basemodel_save_name = model_paths.get('basemodel_save_name', 'basemodel')
        self.finetuned_tokenizer_path = model_paths.get('finetuned_tokenizer')

        experiment_config = self.loader.get_experiment_config()
        self.experiment_name = experiment_config.get('name', 'kronos_custom_finetune')
        self.experiment_description = experiment_config.get('description', '')
        self.use_comet = experiment_config.get('use_comet', False)
        self.train_tokenizer = experiment_config.get('train_tokenizer', True)
        self.train_basemodel = experiment_config.get('train_basemodel', True)
        self.skip_existing = experiment_config.get('skip_existing', False)

        unified_pretrained = experiment_config.get('pre_trained', None)
        self.pre_trained_tokenizer = experiment_config.get('pre_trained_tokenizer', unified_pretrained if unified_pretrained is not None else True)
        self.pre_trained_predictor = experiment_config.get('pre_trained_predictor', unified_pretrained if unified_pretrained is not None else True)

        device_config = self.loader.get_device_config()
        self.use_cuda = device_config.get('use_cuda', True)
        self.device_id = device_config.get('device_id', 0)

        distributed_config = self.loader.get_distributed_config()
        self.use_ddp = distributed_config.get('use_ddp', False)
        self.ddp_backend = distributed_config.get('backend', 'nccl')

        self._compute_full_paths()

    def _compute_full_paths(self):
        self.tokenizer_save_path = os.path.join(self.base_save_path, self.tokenizer_save_name)
        self.tokenizer_best_model_path = os.path.join(self.tokenizer_save_path, 'best_model')

        self.basemodel_save_path = os.path.join(self.base_save_path, self.basemodel_save_name)
        self.basemodel_best_model_path = os.path.join(self.basemodel_save_path, 'best_model')

    def get_tokenizer_config(self):
        return {
            'data_path': self.data_path, 'lookback_window': self.lookback_window,
            'predict_window': self.predict_window, 'max_context': self.max_context, 'clip': self.clip,
            'train_ratio': self.train_ratio, 'val_ratio': self.val_ratio, 'test_ratio': self.test_ratio,
            'epochs': self.tokenizer_epochs, 'batch_size': self.batch_size,
            'log_interval': self.log_interval, 'num_workers': self.num_workers, 'seed': self.seed,
            'learning_rate': self.tokenizer_learning_rate,
            'adam_beta1': self.adam_beta1, 'adam_beta2': self.adam_beta2,
            'adam_weight_decay': self.adam_weight_decay, 'accumulation_steps': self.accumulation_steps,
            'pretrained_model_path': self.pretrained_tokenizer_path,
            'save_path': self.tokenizer_save_path, 'use_comet': self.use_comet
        }

    def get_basemodel_config(self):
        return {
            'data_path': self.data_path, 'lookback_window': self.lookback_window,
            'predict_window': self.predict_window, 'max_context': self.max_context, 'clip': self.clip,
            'train_ratio': self.train_ratio, 'val_ratio': self.val_ratio, 'test_ratio': self.test_ratio,
            'epochs': self.basemodel_epochs, 'batch_size': self.batch_size,
            'log_interval': self.log_interval, 'num_workers': self.num_workers, 'seed': self.seed,
            'predictor_learning_rate': self.predictor_learning_rate,
            'tokenizer_learning_rate': self.tokenizer_learning_rate,
            'adam_beta1': self.adam_beta1, 'adam_beta2': self.adam_beta2,
            'adam_weight_decay': self.adam_weight_decay,
            'pretrained_tokenizer_path': self.finetuned_tokenizer_path,
            'pretrained_predictor_path': self.pretrained_predictor_path,
            'save_path': self.basemodel_save_path, 'use_comet': self.use_comet
        }

    def print_config_summary(self):
        print("=" * 60)
        print("Kronos fine-tuning configuration summary")
        print("=" * 60)
        print(f"Experiment name: {self.exp_name}")
        print(f"Data path: {self.data_path}")
        print(f"Lookback window: {self.lookback_window}")
        print(f"Predict window: {self.predict_window}")
        print(f"Tokenizer training epochs: {self.tokenizer_epochs}")
        print(f"Basemodel training epochs: {self.basemodel_epochs}")
        print(f"Batch size: {self.batch_size}")
        print(f"Tokenizer learning rate: {self.tokenizer_learning_rate}")
        print(f"Predictor learning rate: {self.predictor_learning_rate}")
        print(f"Train tokenizer: {self.train_tokenizer}")
        print(f"Train basemodel: {self.train_basemodel}")
        print(f"Skip existing: {self.skip_existing}")
        print(f"Use pre-trained tokenizer: {self.pre_trained_tokenizer}")
        print(f"Use pre-trained predictor: {self.pre_trained_predictor}")
        print(f"Base save path: {self.base_save_path}")
        print(f"Tokenizer save path: {self.tokenizer_save_path}")
        print(f"Basemodel save path: {self.basemodel_save_path}")
        print("=" * 60)
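A minimal usage sketch of the loader defined above, assuming the template config in `configs/` is present (not part of the diff's training flow, which builds the same object inside each script):

```python
# Sketch: load a config and inspect the derived per-stage settings.
from config_loader import CustomFinetuneConfig

config = CustomFinetuneConfig("configs/config_ali09988_candle-5min.yaml")
config.print_config_summary()

tok_cfg = config.get_tokenizer_config()    # per-stage parameter dict for the tokenizer stage
base_cfg = config.get_basemodel_config()   # per-stage parameter dict for the basemodel stage
print(tok_cfg["save_path"], base_cfg["save_path"])
```
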
finetune_csv/configs/config_ali09988_candle-5min.yaml (new file, 72 lines)
@@ -0,0 +1,72 @@
# This is a template config for fine-tuning Kronos on custom CSV data

data:
  data_path: "/xxxx/Kronos/finetune_csv2/data/HK_ali_09988_kline_5min_all.csv"
  lookback_window: 512
  predict_window: 48
  max_context: 512
  clip: 5.0
  # dataset split ratios
  train_ratio: 0.9
  val_ratio: 0.1
  test_ratio: 0.0

training:
  # control the training epochs of the tokenizer and the basemodel separately
  tokenizer_epochs: 30
  basemodel_epochs: 20
  batch_size: 32
  log_interval: 50
  num_workers: 6
  seed: 42

  tokenizer_learning_rate: 0.0002
  predictor_learning_rate: 0.000001

  adam_beta1: 0.9
  adam_beta2: 0.95
  adam_weight_decay: 0.1

  # gradient accumulation steps for tokenizer training
  accumulation_steps: 1

# model path configuration
model_paths:
  # pretrained model paths
  pretrained_tokenizer: "/mnt/DigitalHuman2D/boyuzhang/quant/Kronos/pretrained/Kronos-Tokenizer-base"
  pretrained_predictor: "/mnt/DigitalHuman2D/boyuzhang/quant/Kronos/pretrained/Kronos-base"

  # experiment name - other paths will be generated based on this
  exp_name: "HK_ali_09988_kline_5min_all"
  base_path: "/mnt/DigitalHuman2D/boyuzhang/quant/Kronos/finetune_csv/finetuned/"

  # the following paths are generated from exp_name; no need to modify them manually
  # way 1: leave an empty string and the loader will generate the full path
  base_save_path: ""         # /xxxx/Kronos/finetune_csv/finetuned/{exp_name}
  finetuned_tokenizer: ""    # /xxxx/quant/Kronos/finetune_csv/finetuned/{exp_name}/tokenizer/best_model

  # way 2: use a template string; {exp_name} will be replaced with the actual experiment name
  # base_save_path: "/xxxx/Kronos/finetune_csv/finetuned/{exp_name}"
  # finetuned_tokenizer: "/xxxx/quant/Kronos/finetune_csv/finetuned/{exp_name}/tokenizer/best_model"

  tokenizer_save_name: "tokenizer"
  basemodel_save_name: "basemodel"

experiment:
  name: "kronos_custom_finetune"
  description: "Custom finetune for HK stock data"
  use_comet: false

  # control the training phases
  train_tokenizer: true
  train_basemodel: true

  # if true, skip training for models that already exist
  skip_existing: false

# device configuration
device:
  use_cuda: true
  device_id: 0
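As a cross-check of the path handling described in the comments above, a small sketch of what `ConfigLoader` derives when `base_save_path` and `finetuned_tokenizer` are left empty (it concatenates `base_path` and `exp_name`; the doubled slash produced by the trailing slash in `base_path` is harmless on POSIX paths):

```python
# Sketch: inspect the paths the loader resolves for the template config above.
from config_loader import ConfigLoader

loader = ConfigLoader("configs/config_ali09988_candle-5min.yaml")
paths = loader.get_model_paths()
print(paths["base_save_path"])       # .../finetune_csv/finetuned//HK_ali_09988_kline_5min_all
print(paths["finetuned_tokenizer"])  # .../HK_ali_09988_kline_5min_all/tokenizer/best_model
```
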
finetune_csv/finetune_base_model.py (new file, 468 lines)
@@ -0,0 +1,468 @@
import os
import sys
import json
import time
import pickle
import random
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from time import gmtime, strftime
import logging
from logging.handlers import RotatingFileHandler
import datetime
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

sys.path.append('../')
from model import Kronos, KronosTokenizer, KronosPredictor
from config_loader import CustomFinetuneConfig


class CustomKlineDataset(Dataset):

    def __init__(self, data_path, data_type='train', lookback_window=90, predict_window=10,
                 clip=5.0, seed=100, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15):
        self.data_path = data_path
        self.data_type = data_type
        self.lookback_window = lookback_window
        self.predict_window = predict_window
        self.window = lookback_window + predict_window + 1
        self.clip = clip
        self.seed = seed
        self.train_ratio = train_ratio
        self.val_ratio = val_ratio
        self.test_ratio = test_ratio

        self.feature_list = ['open', 'high', 'low', 'close', 'volume', 'amount']
        self.time_feature_list = ['minute', 'hour', 'weekday', 'day', 'month']

        self.py_rng = random.Random(seed)

        self._load_and_preprocess_data()
        self._split_data_by_time()

        self.n_samples = len(self.data) - self.window + 1

        print(f"[{data_type.upper()}] Data length: {len(self.data)}, Available samples: {self.n_samples}")

    def _load_and_preprocess_data(self):
        df = pd.read_csv(self.data_path)

        df['timestamps'] = pd.to_datetime(df['timestamps'])
        df = df.sort_values('timestamps').reset_index(drop=True)

        self.timestamps = df['timestamps'].copy()

        df['minute'] = df['timestamps'].dt.minute
        df['hour'] = df['timestamps'].dt.hour
        df['weekday'] = df['timestamps'].dt.weekday
        df['day'] = df['timestamps'].dt.day
        df['month'] = df['timestamps'].dt.month

        self.data = df[self.feature_list + self.time_feature_list].copy()

        if self.data.isnull().any().any():
            print("Warning: Missing values found in data, performing forward fill")
            self.data = self.data.ffill()

        print(f"Original data time range: {self.timestamps.min()} to {self.timestamps.max()}")
        print(f"Original data total length: {len(df)} records")

    def _split_data_by_time(self):
        total_length = len(self.data)

        train_end = int(total_length * self.train_ratio)
        val_end = int(total_length * (self.train_ratio + self.val_ratio))

        if self.data_type == 'train':
            self.data = self.data.iloc[:train_end].copy()
            self.timestamps = self.timestamps.iloc[:train_end].copy()
            print(f"[{self.data_type.upper()}] Training set: first {train_end} time points ({self.train_ratio})")
            print(f"[{self.data_type.upper()}] Training set time range: {self.timestamps.min()} to {self.timestamps.max()}")
        elif self.data_type == 'val':
            self.data = self.data.iloc[train_end:val_end].copy()
            self.timestamps = self.timestamps.iloc[train_end:val_end].copy()
            print(f"[{self.data_type.upper()}] Validation set: time points {train_end+1} to {val_end} ({self.val_ratio})")
            print(f"[{self.data_type.upper()}] Validation set time range: {self.timestamps.min()} to {self.timestamps.max()}")
        elif self.data_type == 'test':
            self.data = self.data.iloc[val_end:].copy()
            self.timestamps = self.timestamps.iloc[val_end:].copy()
            print(f"[{self.data_type.upper()}] Test set: after time point {val_end+1}")
            print(f"[{self.data_type.upper()}] Test set time range: {self.timestamps.min()} to {self.timestamps.max()}")

        print(f"[{self.data_type.upper()}] Data length after split: {len(self.data)} records")

    def set_epoch_seed(self, epoch):
        epoch_seed = self.seed + epoch
        self.py_rng.seed(epoch_seed)
        self.current_epoch = epoch

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        max_start = len(self.data) - self.window
        if max_start <= 0:
            raise ValueError("Data length insufficient to create samples")

        if self.data_type == 'train':
            epoch = getattr(self, 'current_epoch', 0)
            start_idx = (idx * 9973 + (epoch + 1) * 104729) % (max_start + 1)
        else:
            start_idx = idx % (max_start + 1)

        end_idx = start_idx + self.window

        window_data = self.data.iloc[start_idx:end_idx]

        x = window_data[self.feature_list].values.astype(np.float32)
        x_stamp = window_data[self.time_feature_list].values.astype(np.float32)

        # per-window normalization with clipping
        x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0)
        x = (x - x_mean) / (x_std + 1e-5)
        x = np.clip(x, -self.clip, self.clip)

        x_tensor = torch.from_numpy(x)
        x_stamp_tensor = torch.from_numpy(x_stamp)

        return x_tensor, x_stamp_tensor


def setup_logging(exp_name: str, log_dir: str, rank: int = 0) -> logging.Logger:
    os.makedirs(log_dir, exist_ok=True)

    logger = logging.getLogger(f"basemodel_training_rank_{rank}")
    logger.setLevel(logging.INFO)

    if logger.handlers:
        return logger

    log_file = os.path.join(log_dir, f"basemodel_training_rank_{rank}.log")
    file_handler = RotatingFileHandler(log_file, maxBytes=10*1024*1024, backupCount=5, encoding='utf-8')
    file_handler.setLevel(logging.INFO)

    console_handler = None
    if rank == 0:
        console_handler = logging.StreamHandler()
        console_handler.setLevel(logging.INFO)

    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s',
                                  datefmt='%Y-%m-%d %H:%M:%S')
    file_handler.setFormatter(formatter)
    if console_handler is not None:
        console_handler.setFormatter(formatter)

    logger.addHandler(file_handler)
    if console_handler is not None:
        logger.addHandler(console_handler)

    logger.info("=== Basemodel Training Started ===")
    logger.info(f"Experiment Name: {exp_name}")
    logger.info(f"Log Directory: {log_dir}")
    logger.info(f"Rank: {rank}")
    logger.info(f"Timestamp: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

    return logger


def create_dataloaders(config):
    if not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0:
        print("Creating data loaders...")

    train_dataset = CustomKlineDataset(
        data_path=config.data_path, data_type='train',
        lookback_window=config.lookback_window, predict_window=config.predict_window,
        clip=config.clip, seed=config.seed,
        train_ratio=config.train_ratio, val_ratio=config.val_ratio, test_ratio=config.test_ratio
    )

    val_dataset = CustomKlineDataset(
        data_path=config.data_path, data_type='val',
        lookback_window=config.lookback_window, predict_window=config.predict_window,
        clip=config.clip, seed=config.seed + 1,
        train_ratio=config.train_ratio, val_ratio=config.val_ratio, test_ratio=config.test_ratio
    )

    use_ddp = dist.is_available() and dist.is_initialized()
    train_sampler = DistributedSampler(train_dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True) if use_ddp else None
    val_sampler = DistributedSampler(val_dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=False, drop_last=False) if use_ddp else None

    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=(train_sampler is None),
                              num_workers=config.num_workers, pin_memory=True, drop_last=True, sampler=train_sampler)

    val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False,
                            num_workers=config.num_workers, pin_memory=True, drop_last=False, sampler=val_sampler)

    if not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0:
        print(f"Training set size: {len(train_dataset)}, Validation set size: {len(val_dataset)}")

    return train_loader, val_loader, train_dataset, val_dataset, train_sampler, val_sampler


def train_model(model, tokenizer, device, config, save_dir, logger):
    logger.info("Starting training...")
    use_ddp = dist.is_available() and dist.is_initialized()
    rank = dist.get_rank() if use_ddp else 0
    world_size = dist.get_world_size() if use_ddp else 1

    train_loader, val_loader, train_dataset, val_dataset, train_sampler, val_sampler = create_dataloaders(config)
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.predictor_learning_rate,
                                  betas=(config.adam_beta1, config.adam_beta2),
                                  weight_decay=config.adam_weight_decay)

    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=config.predictor_learning_rate,
                                                    steps_per_epoch=len(train_loader),
                                                    epochs=config.basemodel_epochs,
                                                    pct_start=0.03, div_factor=10)

    if use_ddp:
        local_rank = int(os.environ.get("LOCAL_RANK", "0"))
        model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=False)

    best_val_loss = float('inf')
    batch_idx_global = 0

    for epoch in range(config.basemodel_epochs):
        epoch_start_time = time.time()
        model.train()

        train_dataset.set_epoch_seed(epoch * 10000)
        val_dataset.set_epoch_seed(0)
        if train_sampler is not None:
            train_sampler.set_epoch(epoch)

        epoch_train_loss = 0.0
        train_batches = 0

        for batch_idx, (batch_x, batch_x_stamp) in enumerate(train_loader):
            batch_x = batch_x.to(device, non_blocking=True)
            batch_x_stamp = batch_x_stamp.to(device, non_blocking=True)

            # tokenize the continuous window without tracking gradients
            with torch.no_grad():
                token_seq_0, token_seq_1 = tokenizer.encode(batch_x, half=True)

            token_in = [token_seq_0[:, :-1], token_seq_1[:, :-1]]
            token_out = [token_seq_0[:, 1:], token_seq_1[:, 1:]]

            logits = (model.module if use_ddp else model)(token_in[0], token_in[1], batch_x_stamp[:, :-1, :])
            loss, s1_loss, s2_loss = (model.module if use_ddp else model).head.compute_loss(logits[0], logits[1], token_out[0], token_out[1])

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_((model.module if use_ddp else model).parameters(), max_norm=3.0)
            optimizer.step()
            scheduler.step()

            epoch_train_loss += loss.item()
            train_batches += 1

            if (batch_idx_global + 1) % config.log_interval == 0:
                lr = optimizer.param_groups[0]['lr']
                log_msg = (f"[Epoch {epoch+1}/{config.basemodel_epochs}, Step {batch_idx+1}/{len(train_loader)}] "
                           f"LR: {lr:.6f}, Loss: {loss.item():.4f}")
                logger.info(log_msg)
                if rank == 0:
                    print(log_msg)

            batch_idx_global += 1

        model.eval()
        val_loss = 0.0
        val_batches = 0

        with torch.no_grad():
            for batch_x, batch_x_stamp in val_loader:
                batch_x = batch_x.to(device, non_blocking=True)
                batch_x_stamp = batch_x_stamp.to(device, non_blocking=True)

                token_seq_0, token_seq_1 = tokenizer.encode(batch_x, half=True)
                token_in = [token_seq_0[:, :-1], token_seq_1[:, :-1]]
                token_out = [token_seq_0[:, 1:], token_seq_1[:, 1:]]

                logits = (model.module if use_ddp else model)(token_in[0], token_in[1], batch_x_stamp[:, :-1, :])
                loss, _, _ = (model.module if use_ddp else model).head.compute_loss(logits[0], logits[1], token_out[0], token_out[1])

                val_loss += loss.item()
                val_batches += 1

        if use_ddp:
            tensor_sum = torch.tensor([epoch_train_loss, train_batches, val_loss, val_batches], dtype=torch.float64, device=device)
            dist.all_reduce(tensor_sum, op=dist.ReduceOp.SUM)
            epoch_train_loss_all = tensor_sum[0].item()
            train_batches_all = int(tensor_sum[1].item())
            val_loss_all = tensor_sum[2].item()
            val_batches_all = int(tensor_sum[3].item())
            avg_train_loss = (epoch_train_loss_all / train_batches_all) if train_batches_all > 0 else 0.0
            avg_val_loss = (val_loss_all / val_batches_all) if val_batches_all > 0 else 0.0
        else:
            avg_train_loss = epoch_train_loss / train_batches if train_batches > 0 else 0
            avg_val_loss = val_loss / val_batches if val_batches > 0 else 0

        epoch_time = time.time() - epoch_start_time
        epoch_summary = (f"\n--- Epoch {epoch+1}/{config.basemodel_epochs} Summary ---\n"
                         f"Training Loss: {avg_train_loss:.4f}\n"
                         f"Validation Loss: {avg_val_loss:.4f}\n"
                         f"Epoch Time: {epoch_time:.2f} seconds\n")
        logger.info(epoch_summary)
        if rank == 0:
            print(epoch_summary)

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            if rank == 0:
                model_save_path = os.path.join(save_dir, "best_model")
                os.makedirs(model_save_path, exist_ok=True)
                (model.module if use_ddp else model).save_pretrained(model_save_path)
                save_msg = f"Best model saved to: {model_save_path} (validation loss: {best_val_loss:.4f})"
                logger.info(save_msg)
                print(save_msg)

    return best_val_loss


def main():
    import argparse

    parser = argparse.ArgumentParser(description='Kronos Basemodel Fine-tuning Training')
    parser.add_argument('--config', type=str, default='config.yaml',
                        help='Configuration file path (default: config.yaml)')
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    config = CustomFinetuneConfig(args.config)

    os.makedirs(config.basemodel_save_path, exist_ok=True)

    log_dir = os.path.join(config.base_save_path, "logs")
    logger = setup_logging(config.exp_name, log_dir, 0)

    torch.manual_seed(config.seed)
    np.random.seed(config.seed)
    random.seed(config.seed)

    logger.info("Loading pretrained model or random initialization...")
    print("Loading pretrained model or random initialization...")
    if getattr(config, 'pre_trained_tokenizer', True):
        tokenizer = KronosTokenizer.from_pretrained(config.finetuned_tokenizer_path)
    else:
        print("pre_trained_tokenizer=False, randomly initializing Tokenizer architecture for training")
        cfg_path_tok = os.path.join(config.pretrained_tokenizer_path if hasattr(config, 'pretrained_tokenizer_path') else config.finetuned_tokenizer_path, 'config.json')
        with open(cfg_path_tok, 'r') as f:
            arch_t = json.load(f)
        tokenizer = KronosTokenizer(
            d_in=arch_t.get('d_in', 6), d_model=arch_t.get('d_model', 256),
            n_heads=arch_t.get('n_heads', 4), ff_dim=arch_t.get('ff_dim', 512),
            n_enc_layers=arch_t.get('n_enc_layers', 4), n_dec_layers=arch_t.get('n_dec_layers', 4),
            ffn_dropout_p=arch_t.get('ffn_dropout_p', 0.0), attn_dropout_p=arch_t.get('attn_dropout_p', 0.0),
            resid_dropout_p=arch_t.get('resid_dropout_p', 0.0),
            s1_bits=arch_t.get('s1_bits', 10), s2_bits=arch_t.get('s2_bits', 10),
            beta=arch_t.get('beta', 0.05), gamma0=arch_t.get('gamma0', 1.0),
            gamma=arch_t.get('gamma', 1.1), zeta=arch_t.get('zeta', 0.05),
            group_size=arch_t.get('group_size', 4)
        )

    if getattr(config, 'pre_trained_predictor', True):
        model = Kronos.from_pretrained(config.pretrained_predictor_path)
    else:
        print("pre_trained_predictor=False, randomly initializing Predictor architecture for training")
        cfg_path = os.path.join(config.pretrained_predictor_path, 'config.json')
        with open(cfg_path, 'r') as f:
            arch = json.load(f)
        model = Kronos(
            s1_bits=arch.get('s1_bits', 10), s2_bits=arch.get('s2_bits', 10),
            n_layers=arch.get('n_layers', 12), d_model=arch.get('d_model', 832),
            n_heads=arch.get('n_heads', 16), ff_dim=arch.get('ff_dim', 2048),
            ffn_dropout_p=arch.get('ffn_dropout_p', 0.2), attn_dropout_p=arch.get('attn_dropout_p', 0.0),
            resid_dropout_p=arch.get('resid_dropout_p', 0.2), token_dropout_p=arch.get('token_dropout_p', 0.0),
            learn_te=arch.get('learn_te', True)
        )

    tokenizer = tokenizer.to(device)
    model = model.to(device)

    model_size = sum(p.numel() for p in model.parameters())
    logger.info(f"Model parameters: {model_size:,}")
    print(f"Model parameters: {model_size:,}")

    logger.info("=== Training Configuration ===")
    logger.info(f"Data path: {config.data_path}")
    logger.info(f"Lookback window: {config.lookback_window}")
    logger.info(f"Predict window: {config.predict_window}")
    logger.info(f"Batch size: {config.batch_size}")
    logger.info(f"Learning rate: {config.predictor_learning_rate}")
    logger.info(f"Training epochs: {config.basemodel_epochs}")
    logger.info(f"Device: {device}")
    logger.info(f"Tokenizer path: {config.finetuned_tokenizer_path}")
    logger.info(f"Pretrained model path: {config.pretrained_predictor_path}")

    logger.info("Starting fine-tuning training...")
    print("Starting fine-tuning training...")
    best_val_loss = train_model(model, tokenizer, device, config, config.basemodel_save_path, logger)

    final_msg = f"Training completed! Best validation loss: {best_val_loss:.4f}\nModel saved to: {config.basemodel_save_path}"
    logger.info(final_msg)
    print(final_msg)


if __name__ == "__main__":
    main()
finetune_csv/finetune_tokenizer.py (new file, 359 lines)
@@ -0,0 +1,359 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import random
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
|
from time import gmtime, strftime
|
||||||
|
import datetime
|
||||||
|
import logging
|
||||||
|
from logging.handlers import RotatingFileHandler
|
||||||
|
import torch.distributed as dist
|
||||||
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
|
|
||||||
|
sys.path.append("../")
|
||||||
|
from model import KronosTokenizer
|
||||||
|
from finetune_base_model import CustomKlineDataset
|
||||||
|
from config_loader import CustomFinetuneConfig
|
||||||
|
|
||||||
|
|
||||||
|
def set_seed(seed: int, rank: int = 0):
|
||||||
|
actual_seed = seed
|
||||||
|
random.seed(actual_seed)
|
||||||
|
np.random.seed(actual_seed)
|
||||||
|
torch.manual_seed(actual_seed)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.manual_seed_all(actual_seed)
|
||||||
|
torch.backends.cudnn.deterministic = True
|
||||||
|
torch.backends.cudnn.benchmark = False
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_size(model: torch.nn.Module) -> str:
|
||||||
|
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||||
|
if total_params >= 1e9:
|
||||||
|
return f"{total_params / 1e9:.1f}B"
|
||||||
|
elif total_params >= 1e6:
|
||||||
|
return f"{total_params / 1e6:.1f}M"
|
||||||
|
else:
|
||||||
|
return f"{total_params / 1e3:.1f}K"
|
||||||
|
|
||||||
|
|
||||||
|
def format_time(seconds: float) -> str:
|
||||||
|
return str(datetime.timedelta(seconds=int(seconds)))
|
||||||
|
|
||||||
|
|
||||||
|
def setup_logging(exp_name: str, log_dir: str, rank: int = 0) -> logging.Logger:
|
||||||
|
os.makedirs(log_dir, exist_ok=True)
|
||||||
|
|
||||||
|
logger = logging.getLogger(f"tokenizer_training_rank_{rank}")
|
||||||
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
if logger.handlers:
|
||||||
|
return logger
|
||||||
|
|
||||||
|
log_file = os.path.join(log_dir, f"tokenizer_training_rank_{rank}.log")
|
||||||
|
file_handler = RotatingFileHandler(
|
||||||
|
log_file,
|
||||||
|
maxBytes=10*1024*1024,
|
||||||
|
backupCount=5,
|
||||||
|
encoding='utf-8'
|
||||||
|
)
|
||||||
|
file_handler.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
console_handler = None
|
||||||
|
if rank == 0:
|
||||||
|
console_handler = logging.StreamHandler()
|
||||||
|
console_handler.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
formatter = logging.Formatter(
|
||||||
|
'%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||||
|
datefmt='%Y-%m-%d %H:%M:%S'
|
||||||
|
)
|
||||||
|
file_handler.setFormatter(formatter)
|
||||||
|
if console_handler is not None:
|
||||||
|
console_handler.setFormatter(formatter)
|
||||||
|
|
||||||
|
logger.addHandler(file_handler)
|
||||||
|
if console_handler is not None:
|
||||||
|
logger.addHandler(console_handler)
|
||||||
|
|
||||||
|
logger.info(f"=== Tokenizer Training Started ===")
|
||||||
|
logger.info(f"Experiment Name: {exp_name}")
|
||||||
|
logger.info(f"Log Directory: {log_dir}")
|
||||||
|
logger.info(f"Rank: {rank}")
|
||||||
|
logger.info(f"Timestamp: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
||||||
|
|
||||||
|
return logger
|
||||||
|
|
||||||
|
|
||||||
|
def create_dataloaders(config):
|
||||||
|
if not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0:
|
||||||
|
print("Creating tokenizer training data loaders...")
|
||||||
|
|
||||||
|
train_dataset = CustomKlineDataset(
|
||||||
|
data_path=config.data_path,
|
||||||
|
data_type="train",
|
||||||
|
lookback_window=config.lookback_window,
|
||||||
|
predict_window=config.predict_window,
|
||||||
|
clip=config.clip,
|
||||||
|
seed=config.seed,
|
||||||
|
train_ratio=config.train_ratio,
|
||||||
|
val_ratio=config.val_ratio,
|
||||||
|
test_ratio=config.test_ratio
|
||||||
|
)
|
||||||
|
|
||||||
|
val_dataset = CustomKlineDataset(
|
||||||
|
data_path=config.data_path,
|
||||||
|
data_type="val",
|
||||||
|
lookback_window=config.lookback_window,
|
||||||
|
predict_window=config.predict_window,
|
||||||
|
clip=config.clip,
|
||||||
|
seed=config.seed + 1,
|
||||||
|
train_ratio=config.train_ratio,
|
||||||
|
val_ratio=config.val_ratio,
|
||||||
|
test_ratio=config.test_ratio
|
||||||
|
)
|
||||||
|
|
||||||
|
use_ddp = dist.is_available() and dist.is_initialized()
|
||||||
|
train_sampler = DistributedSampler(train_dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True) if use_ddp else None
|
||||||
|
val_sampler = DistributedSampler(val_dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=False, drop_last=False) if use_ddp else None
|
||||||
|
|
||||||
|
train_loader = DataLoader(
|
||||||
|
train_dataset,
|
||||||
|
batch_size=config.batch_size,
|
||||||
|
shuffle=(train_sampler is None),
|
||||||
|
num_workers=config.num_workers,
|
||||||
|
pin_memory=True,
|
||||||
|
drop_last=True,
|
||||||
|
sampler=train_sampler
|
||||||
|
)
|
||||||
|
|
||||||
|
val_loader = DataLoader(
|
||||||
|
val_dataset,
|
||||||
|
batch_size=config.batch_size,
|
||||||
|
shuffle=False,
|
||||||
|
num_workers=config.num_workers,
|
||||||
|
pin_memory=True,
|
||||||
|
drop_last=False,
|
||||||
|
sampler=val_sampler
|
||||||
|
)
|
||||||
|
|
||||||
|
if not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0:
|
||||||
|
print(f"Training set size: {len(train_dataset)}, Validation set size: {len(val_dataset)}")
|
||||||
|
|
||||||
|
return train_loader, val_loader, train_dataset, val_dataset, train_sampler, val_sampler
|
||||||
|
|
||||||
|
|
||||||
|
def train_tokenizer(model, device, config, save_dir, logger):
|
||||||
|
logger.info("Starting tokenizer training...")
|
||||||
|
use_ddp = dist.is_available() and dist.is_initialized()
|
||||||
|
rank = dist.get_rank() if use_ddp else 0
|
||||||
|
world_size = dist.get_world_size() if use_ddp else 1
|
||||||
|
|
||||||
|
train_loader, val_loader, train_dataset, val_dataset, train_sampler, val_sampler = create_dataloaders(config)
|
||||||
|
|
||||||
|
optimizer = torch.optim.AdamW(
|
||||||
|
model.parameters(),
|
||||||
|
lr=config.tokenizer_learning_rate,
|
||||||
|
weight_decay=config.adam_weight_decay
|
||||||
|
)
|
||||||
|
|
||||||
|
scheduler = torch.optim.lr_scheduler.OneCycleLR(
|
||||||
|
optimizer,
|
||||||
|
max_lr=config.tokenizer_learning_rate,
|
||||||
|
steps_per_epoch=len(train_loader),
|
||||||
|
epochs=config.tokenizer_epochs,
|
||||||
|
pct_start=0.03,
|
||||||
|
div_factor=10
|
||||||
|
)
|
||||||
|
|
||||||
|
if use_ddp:
|
||||||
|
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
|
||||||
|
model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=False)
|
||||||
|
|
||||||
|
best_val_loss = float("inf")
|
||||||
|
batch_idx_global = 0
|
||||||
|
|
||||||
|
accumulation_steps = getattr(config, 'accumulation_steps', 1)
|
||||||
|
|
||||||
|
for epoch in range(config.tokenizer_epochs):
|
||||||
|
epoch_start_time = time.time()
|
||||||
|
model.train()
|
||||||
|
|
||||||
|
train_dataset.set_epoch_seed(epoch * 10000)
|
||||||
|
val_dataset.set_epoch_seed(0)
|
||||||
|
if train_sampler is not None:
|
||||||
|
train_sampler.set_epoch(epoch)
|
||||||
|
|
||||||
|
for batch_idx, (ori_batch_x, _) in enumerate(train_loader):
|
||||||
|
ori_batch_x = ori_batch_x.squeeze(0).to(device, non_blocking=True)
|
||||||
|
|
||||||
|
current_batch_total_loss = 0.0
|
||||||
|
for j in range(accumulation_steps):
|
||||||
|
start_idx = j * (ori_batch_x.shape[0] // accumulation_steps)
|
||||||
|
end_idx = (j + 1) * (ori_batch_x.shape[0] // accumulation_steps)
|
||||||
|
batch_x = ori_batch_x[start_idx:end_idx]
|
||||||
|
|
||||||
|
zs, bsq_loss, _, _ = (model.module if use_ddp else model)(batch_x)
|
||||||
|
z_pre, z = zs
|
||||||
|
|
||||||
|
recon_loss_pre = F.mse_loss(z_pre, batch_x)
|
||||||
|
recon_loss_all = F.mse_loss(z, batch_x)
|
||||||
|
recon_loss = recon_loss_pre + recon_loss_all
|
||||||
|
loss = (recon_loss + bsq_loss) / 2
|
||||||
|
|
||||||
|
loss_scaled = loss / accumulation_steps
|
||||||
|
current_batch_total_loss += loss.item()
|
||||||
|
loss_scaled.backward()
|
||||||
|
|
||||||
|
torch.nn.utils.clip_grad_norm_((model.module if use_ddp else model).parameters(), max_norm=2.0)
|
||||||
|
optimizer.step()
|
||||||
|
scheduler.step()
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
if (batch_idx_global + 1) % config.log_interval == 0:
|
||||||
|
avg_loss = current_batch_total_loss / accumulation_steps
|
||||||
|
lr = optimizer.param_groups[0]["lr"]
|
||||||
|
log_msg = (f"[Epoch {epoch+1}/{config.tokenizer_epochs}, Step {batch_idx+1}/{len(train_loader)}] "
|
||||||
|
f"LR: {lr:.6f}, Loss: {avg_loss:.4f}")
|
||||||
|
logger.info(log_msg)
|
||||||
|
if rank == 0:
|
||||||
|
print(log_msg)
|
||||||
|
|
||||||
|
detail_msg = (f" - VQ Loss: {bsq_loss.item():.4f}\n"
|
||||||
|
f" - Recon Loss Pre: {recon_loss_pre.item():.4f}\n"
|
||||||
|
f" - Recon Loss All: {recon_loss_all.item():.4f}")
|
||||||
|
logger.info(detail_msg)
|
||||||
|
if rank == 0:
|
||||||
|
print(detail_msg)
|
||||||
|
|
||||||
|
batch_idx_global += 1
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
tot_val_loss_sum_rank = 0.0
|
||||||
|
val_sample_count_rank = 0
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for ori_batch_x, _ in val_loader:
|
||||||
|
ori_batch_x = ori_batch_x.squeeze(0).to(device, non_blocking=True)
|
||||||
|
zs, _, _, _ = (model.module if use_ddp else model)(ori_batch_x)
|
||||||
|
_, z = zs
|
||||||
|
val_loss_item = F.mse_loss(z, ori_batch_x)
|
||||||
|
|
||||||
|
tot_val_loss_sum_rank += val_loss_item.item() * ori_batch_x.size(0)
|
||||||
|
val_sample_count_rank += ori_batch_x.size(0)
|
||||||
|
|
||||||
|
if use_ddp:
|
||||||
|
tensor_sum = torch.tensor([tot_val_loss_sum_rank, val_sample_count_rank], dtype=torch.float64, device=device)
|
||||||
|
dist.all_reduce(tensor_sum, op=dist.ReduceOp.SUM)
|
||||||
|
tot_val_loss_all = tensor_sum[0].item()
|
||||||
|
val_count_all = int(tensor_sum[1].item())
|
||||||
|
avg_val_loss = (tot_val_loss_all / val_count_all) if val_count_all > 0 else 0.0
|
||||||
|
else:
|
||||||
|
avg_val_loss = tot_val_loss_sum_rank / val_sample_count_rank if val_sample_count_rank > 0 else 0
|
||||||
|
|
||||||
|
epoch_time = time.time() - epoch_start_time
|
||||||
|
epoch_summary = (f"\n--- Epoch {epoch+1}/{config.tokenizer_epochs} Summary ---\n"
|
||||||
|
f"Validation Loss: {avg_val_loss:.4f}\n"
|
||||||
|
f"Epoch Time: {format_time(epoch_time)}\n"
|
||||||
|
f"Total Training Time: {format_time(time.time() - epoch_start_time)}\n")
|
||||||
|
logger.info(epoch_summary)
|
||||||
|
if rank == 0:
|
||||||
|
print(epoch_summary)
|
||||||
|
|
||||||
|
if avg_val_loss < best_val_loss:
|
||||||
|
best_val_loss = avg_val_loss
|
||||||
|
if rank == 0:
|
||||||
|
model_save_path = os.path.join(save_dir, "best_model")
|
||||||
|
os.makedirs(model_save_path, exist_ok=True)
|
||||||
|
(model.module if use_ddp else model).save_pretrained(model_save_path)
|
||||||
|
save_msg = f"Best model saved to: {model_save_path} (validation loss: {best_val_loss:.4f})"
|
||||||
|
logger.info(save_msg)
|
||||||
|
print(save_msg)
|
||||||
|
|
||||||
|
return best_val_loss
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description='Kronos Tokenizer Fine-tuning Training')
|
||||||
|
parser.add_argument('--config', type=str, default='config.yaml',
|
||||||
|
help='Configuration file path (default: config.yaml)')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
config = CustomFinetuneConfig(args.config)
|
||||||
|
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
print(f"Using device: {device}")
|
||||||
|
|
||||||
|
config = CustomFinetuneConfig(args.config)
|
||||||
|
|
||||||
|
os.makedirs(config.tokenizer_save_path, exist_ok=True)
|
||||||
|
|
||||||
|
log_dir = os.path.join(config.base_save_path, "logs")
|
||||||
|
logger = setup_logging(config.exp_name, log_dir, 0)
|
||||||
|
|
||||||
|
set_seed(config.seed)
|
||||||
|
|
||||||
|
    # Load the pretrained tokenizer
    if getattr(config, 'pre_trained_tokenizer', True):
        logger.info("Loading pretrained tokenizer...")
        print("Loading pretrained tokenizer...")
        tokenizer = KronosTokenizer.from_pretrained(config.pretrained_tokenizer_path)
    else:
        print("pre_trained_tokenizer=False, randomly initializing Tokenizer architecture")
        import json
        cfg_path = os.path.join(config.pretrained_tokenizer_path, 'config.json')
        with open(cfg_path, 'r') as f:
            arch = json.load(f)
        tokenizer = KronosTokenizer(
            d_in=arch.get('d_in', 6),
            d_model=arch.get('d_model', 256),
            n_heads=arch.get('n_heads', 4),
            ff_dim=arch.get('ff_dim', 512),
            n_enc_layers=arch.get('n_enc_layers', 4),
            n_dec_layers=arch.get('n_dec_layers', 4),
            ffn_dropout_p=arch.get('ffn_dropout_p', 0.0),
            attn_dropout_p=arch.get('attn_dropout_p', 0.0),
            resid_dropout_p=arch.get('resid_dropout_p', 0.0),
            s1_bits=arch.get('s1_bits', 10),
            s2_bits=arch.get('s2_bits', 10),
            beta=arch.get('beta', 0.05),
            gamma0=arch.get('gamma0', 1.0),
            gamma=arch.get('gamma', 1.1),
            zeta=arch.get('zeta', 0.05),
            group_size=arch.get('group_size', 4)
        )
    tokenizer = tokenizer.to(device)

    model_size = get_model_size(tokenizer)
    logger.info(f"Tokenizer parameters: {model_size}")
    print(f"Tokenizer parameters: {model_size}")

    logger.info("=== Training Configuration ===")
    logger.info(f"Data path: {config.data_path}")
    logger.info(f"Lookback window: {config.lookback_window}")
    logger.info(f"Predict window: {config.predict_window}")
    logger.info(f"Batch size: {config.batch_size}")
    logger.info(f"Learning rate: {config.tokenizer_learning_rate}")
    logger.info(f"Training epochs: {config.tokenizer_epochs}")
    logger.info(f"Device: {device}")
    logger.info("Distributed training: False")

    logger.info("Starting tokenizer fine-tuning training...")
    print("Starting tokenizer fine-tuning training...")
    best_val_loss = train_tokenizer(tokenizer, device, config, config.tokenizer_save_path, logger)

    final_msg = f"Tokenizer training completed! Best validation loss: {best_val_loss:.4f}\nModel saved to: {config.tokenizer_save_path}"
    logger.info(final_msg)
    print(final_msg)


if __name__ == "__main__":
    main()
361
finetune_csv/train_sequential.py
Normal file
@@ -0,0 +1,361 @@
import os
import sys
import time
import argparse
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.distributed as dist

sys.path.append('../')
from model import Kronos, KronosTokenizer, KronosPredictor

from config_loader import CustomFinetuneConfig
from finetune_tokenizer import train_tokenizer, set_seed, setup_logging as setup_tokenizer_logging
from finetune_base_model import train_model, create_dataloaders, setup_logging as setup_basemodel_logging


class SequentialTrainer:

    def __init__(self, config_path: str = None):
        self.config = CustomFinetuneConfig(config_path)
        self.rank = int(os.environ.get("RANK", "0"))
        self.world_size = int(os.environ.get("WORLD_SIZE", "1"))
        self.local_rank = int(os.environ.get("LOCAL_RANK", str(self.config.device_id if hasattr(self.config, 'device_id') else 0)))
        self.device = self._setup_device()

        self.config.print_config_summary()

    def _setup_device(self):
        if self.config.use_cuda and torch.cuda.is_available():
            torch.cuda.set_device(self.local_rank)
            device = torch.device(f"cuda:{self.local_rank}")
        else:
            device = torch.device("cpu")

        if self.rank == 0:
            print(f"Using device: {device} (rank={self.rank}, world_size={self.world_size}, local_rank={self.local_rank})")
        return device

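    # torchrun exports RANK, WORLD_SIZE and LOCAL_RANK for every worker, so the
    # process group below is only created when more than one process was launched.
    # DIST_BACKEND selects the communication backend (nccl by default, gloo as an
    # alternative), matching the DDP invocation shown in the README.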
    def _setup_distributed(self):
        if self.world_size > 1 and torch.cuda.is_available():
            backend = os.environ.get("DIST_BACKEND", "nccl").lower()
            if not dist.is_initialized():
                dist.init_process_group(backend=backend)
            if self.rank == 0:
                print(f"Distributed training initialized: backend={backend}, world_size={self.world_size}")
        else:
            if self.rank == 0:
                print("Distributed training not enabled, using single GPU/CPU training")

    def _check_existing_models(self):
        tokenizer_exists = os.path.exists(self.config.tokenizer_best_model_path)
        basemodel_exists = os.path.exists(self.config.basemodel_best_model_path)

        print(f"Tokenizer model exists: {tokenizer_exists}")
        print(f"Basemodel model exists: {basemodel_exists}")

        return tokenizer_exists, basemodel_exists

    def _create_directories(self):
        os.makedirs(self.config.tokenizer_save_path, exist_ok=True)
        os.makedirs(self.config.basemodel_save_path, exist_ok=True)
        print(f"Created directory: {self.config.tokenizer_save_path}")
        print(f"Created directory: {self.config.basemodel_save_path}")

    def train_tokenizer_phase(self):
        print("\n" + "="*60)
        print("Starting Tokenizer Fine-tuning Phase")
        print("="*60)

        tokenizer_exists, _ = self._check_existing_models()
        if tokenizer_exists and self.config.skip_existing:
            print("Tokenizer model already exists, skipping training")
            return True

        log_dir = os.path.join(self.config.base_save_path, "logs")
        logger = setup_tokenizer_logging(self.config.exp_name, log_dir, self.rank)

        set_seed(self.config.seed)

        if getattr(self.config, 'pre_trained_tokenizer', True):
            logger.info("Loading pretrained tokenizer...")
            if self.rank == 0:
                print("Loading pretrained tokenizer...")
            tokenizer = KronosTokenizer.from_pretrained(self.config.pretrained_tokenizer_path)
        else:
            if self.rank == 0:
                print("pre_trained_tokenizer=False, randomly initializing Tokenizer architecture")
            import json
            cfg_path = os.path.join(self.config.pretrained_tokenizer_path, 'config.json')
            with open(cfg_path, 'r') as f:
                arch = json.load(f)
            tokenizer = KronosTokenizer(
                d_in=arch.get('d_in', 6),
                d_model=arch.get('d_model', 256),
                n_heads=arch.get('n_heads', 4),
                ff_dim=arch.get('ff_dim', 512),
                n_enc_layers=arch.get('n_enc_layers', 4),
                n_dec_layers=arch.get('n_dec_layers', 4),
                ffn_dropout_p=arch.get('ffn_dropout_p', 0.0),
                attn_dropout_p=arch.get('attn_dropout_p', 0.0),
                resid_dropout_p=arch.get('resid_dropout_p', 0.0),
                s1_bits=arch.get('s1_bits', 10),
                s2_bits=arch.get('s2_bits', 10),
                beta=arch.get('beta', 0.05),
                gamma0=arch.get('gamma0', 1.0),
                gamma=arch.get('gamma', 1.1),
                zeta=arch.get('zeta', 0.05),
                group_size=arch.get('group_size', 4)
            )
        tokenizer = tokenizer.to(self.device)

        model_size = sum(p.numel() for p in tokenizer.parameters())
        logger.info(f"Tokenizer parameters: {model_size:,}")
        if self.rank == 0:
            print(f"Tokenizer parameters: {model_size:,}")

        logger.info("=== Training Configuration ===")
        logger.info(f"Data path: {self.config.data_path}")
        logger.info(f"Lookback window: {self.config.lookback_window}")
        logger.info(f"Predict window: {self.config.predict_window}")
        logger.info(f"Batch size: {self.config.batch_size}")
        logger.info(f"Learning rate: {self.config.tokenizer_learning_rate}")
        logger.info(f"Training epochs: {self.config.tokenizer_epochs}")
        logger.info(f"Device: {self.device}")
        logger.info(f"Distributed training: {self.world_size > 1}")

logger.info("Starting tokenizer fine-tuning training...")
|
||||||
|
if self.rank == 0:
|
||||||
|
print("Starting tokenizer fine-tuning training...")
|
||||||
|
start_time = time.time()
|
||||||
|
best_val_loss = train_tokenizer(
|
||||||
|
tokenizer,
|
||||||
|
self.device,
|
||||||
|
self.config,
|
||||||
|
self.config.tokenizer_save_path,
|
||||||
|
logger,
|
||||||
|
)
|
||||||
|
training_time = time.time() - start_time
|
||||||
|
|
||||||
|
final_msg = f"Tokenizer training completed! Best validation loss: {best_val_loss:.4f}\nTraining time: {training_time/60:.2f} minutes\nModel saved to: {self.config.tokenizer_save_path}"
|
||||||
|
logger.info(final_msg)
|
||||||
|
if self.rank == 0:
|
||||||
|
print(f"\n{final_msg}")
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
    def train_basemodel_phase(self):
        print("\n" + "="*60)
        print("Starting Basemodel Fine-tuning Phase")
        print("="*60)

        if getattr(self.config, 'pre_trained_tokenizer', True):
            if not os.path.exists(self.config.finetuned_tokenizer_path):
                raise FileNotFoundError(f"Fine-tuned tokenizer does not exist: {self.config.finetuned_tokenizer_path}")

        _, basemodel_exists = self._check_existing_models()
        if basemodel_exists and self.config.skip_existing:
            print("Basemodel model already exists, skipping training")
            return True

        log_dir = os.path.join(self.config.base_save_path, "logs")
        logger = setup_basemodel_logging(self.config.exp_name, log_dir, self.rank)

        set_seed(self.config.seed)

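        # The basemodel (predictor) phase reuses the tokenizer produced in phase 1:
        # when pre_trained_tokenizer is True it loads the fine-tuned checkpoint from
        # finetuned_tokenizer_path (which is why the README requires the tokenizer
        # to be fine-tuned first); otherwise it rebuilds the architecture from
        # config.json exactly as in the tokenizer phase.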
        if getattr(self.config, 'pre_trained_tokenizer', True):
            logger.info("Loading fine-tuned tokenizer...")
            if self.rank == 0:
                print("Loading fine-tuned tokenizer...")
            tokenizer = KronosTokenizer.from_pretrained(self.config.finetuned_tokenizer_path)
        else:
            if self.rank == 0:
                print("pre_trained_tokenizer=False, randomly initializing Tokenizer architecture for Predictor training")
            import json
            cfg_path = os.path.join(self.config.pretrained_tokenizer_path, 'config.json')
            with open(cfg_path, 'r') as f:
                arch = json.load(f)
            tokenizer = KronosTokenizer(
                d_in=arch.get('d_in', 6),
                d_model=arch.get('d_model', 256),
                n_heads=arch.get('n_heads', 4),
                ff_dim=arch.get('ff_dim', 512),
                n_enc_layers=arch.get('n_enc_layers', 4),
                n_dec_layers=arch.get('n_dec_layers', 4),
                ffn_dropout_p=arch.get('ffn_dropout_p', 0.0),
                attn_dropout_p=arch.get('attn_dropout_p', 0.0),
                resid_dropout_p=arch.get('resid_dropout_p', 0.0),
                s1_bits=arch.get('s1_bits', 10),
                s2_bits=arch.get('s2_bits', 10),
                beta=arch.get('beta', 0.05),
                gamma0=arch.get('gamma0', 1.0),
                gamma=arch.get('gamma', 1.1),
                zeta=arch.get('zeta', 0.05),
                group_size=arch.get('group_size', 4)
            )
        tokenizer = tokenizer.to(self.device)

        if getattr(self.config, 'pre_trained_predictor', True):
            logger.info("Loading pretrained predictor...")
            if self.rank == 0:
                print("Loading pretrained predictor...")
            model = Kronos.from_pretrained(self.config.pretrained_predictor_path)
        else:
            if self.rank == 0:
                print("pre_trained_predictor=False, randomly initializing Predictor architecture")
            import json
            cfg_path = os.path.join(self.config.pretrained_predictor_path, 'config.json')
            with open(cfg_path, 'r') as f:
                arch = json.load(f)
            print("model_config: ", arch)
            model = Kronos(
                s1_bits=arch.get('s1_bits', 10),
                s2_bits=arch.get('s2_bits', 10),
                n_layers=arch.get('n_layers', 12),
                d_model=arch.get('d_model', 832),
                n_heads=arch.get('n_heads', 16),
                ff_dim=arch.get('ff_dim', 2048),
                ffn_dropout_p=arch.get('ffn_dropout_p', 0.2),
                attn_dropout_p=arch.get('attn_dropout_p', 0.0),
                resid_dropout_p=arch.get('resid_dropout_p', 0.2),
                token_dropout_p=arch.get('token_dropout_p', 0.0),
                learn_te=arch.get('learn_te', True)
            )
        model = model.to(self.device)

        model_size = sum(p.numel() for p in model.parameters())
        logger.info(f"Model parameters: {model_size:,}")
        if self.rank == 0:
            print(f"Model parameters: {model_size:,}")

        logger.info("=== Training Configuration ===")
        logger.info(f"Data path: {self.config.data_path}")
        logger.info(f"Lookback window: {self.config.lookback_window}")
        logger.info(f"Predict window: {self.config.predict_window}")
        logger.info(f"Batch size: {self.config.batch_size}")
        logger.info(f"Learning rate: {self.config.predictor_learning_rate}")
        logger.info(f"Training epochs: {self.config.basemodel_epochs}")
        logger.info(f"Device: {self.device}")
        logger.info(f"Tokenizer path: {self.config.finetuned_tokenizer_path}")
        logger.info(f"Pretrained model path: {self.config.pretrained_predictor_path}")

logger.info("Starting fine-tuning training...")
|
||||||
|
if self.rank == 0:
|
||||||
|
print("Starting fine-tuning training...")
|
||||||
|
start_time = time.time()
|
||||||
|
best_val_loss = train_model(
|
||||||
|
model,
|
||||||
|
tokenizer,
|
||||||
|
self.device,
|
||||||
|
self.config,
|
||||||
|
self.config.basemodel_save_path,
|
||||||
|
logger,
|
||||||
|
)
|
||||||
|
training_time = time.time() - start_time
|
||||||
|
|
||||||
|
final_msg = f"Basemodel training completed! Best validation loss: {best_val_loss:.4f}\nTraining time: {training_time/60:.2f} minutes\nModel saved to: {self.config.basemodel_save_path}"
|
||||||
|
logger.info(final_msg)
|
||||||
|
if self.rank == 0:
|
||||||
|
print(f"\n{final_msg}")
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
    def run_training(self):
        if self.rank == 0:
            print("Starting Kronos model sequential fine-tuning training")
            print(f"Experiment name: {self.config.experiment_name}")
            print(f"Experiment description: {self.config.experiment_description}")

        self._setup_distributed()

        self._create_directories()

        tokenizer_exists, basemodel_exists = self._check_existing_models()

        total_start_time = time.time()

        try:
            if self.config.train_tokenizer:
                success = self.train_tokenizer_phase()
                if not success:
                    print("Tokenizer training failed, terminating training")
                    return False
            else:
                print("Skipping Tokenizer training phase")

            if self.config.train_basemodel:
                success = self.train_basemodel_phase()
                if not success:
                    print("Basemodel training failed, terminating training")
                    return False
            else:
                print("Skipping Basemodel training phase")

            total_time = time.time() - total_start_time

            if self.rank == 0:
                print("\n" + "="*60)
                print("Training completed!")
                print("="*60)
                print(f"Total training time: {total_time/60:.2f} minutes")
                print(f"Tokenizer model: {self.config.tokenizer_best_model_path}")
                print(f"Basemodel model: {self.config.basemodel_best_model_path}")
                print("="*60)

            return True

        except Exception as e:
            if self.rank == 0:
                print(f"Error occurred during training: {str(e)}")
                import traceback
                traceback.print_exc()
            return False

        finally:
            pass

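# The CLI flags below only override the corresponding config fields:
# --skip-tokenizer / --skip-basemodel disable a phase, and --skip-existing
# reuses models that are already on disk, as described in the README.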
def main():
    parser = argparse.ArgumentParser(description='Kronos Model Sequential Fine-tuning Training')
    parser.add_argument('--config', type=str, default='config.yaml',
                        help='Configuration file path (default: config.yaml)')
    parser.add_argument('--skip-tokenizer', action='store_true',
                        help='Skip tokenizer training phase')
    parser.add_argument('--skip-basemodel', action='store_true',
                        help='Skip basemodel training phase')
    parser.add_argument('--skip-existing', action='store_true',
                        help='Skip training for existing models')

    args = parser.parse_args()

    trainer = SequentialTrainer(args.config)

    if args.skip_tokenizer:
        trainer.config.train_tokenizer = False
    if args.skip_basemodel:
        trainer.config.train_basemodel = False
    if args.skip_existing:
        trainer.config.skip_existing = True

    success = trainer.run_training()

    if success:
        print("Training completed successfully!")
        if dist.is_available() and dist.is_initialized():
            dist.barrier()
            dist.destroy_process_group()
        sys.exit(0)
    else:
        print("Training failed!")
        if dist.is_available() and dist.is_initialized():
            try:
                dist.barrier()
                dist.destroy_process_group()
            except Exception:
                pass
        sys.exit(1)


if __name__ == "__main__":
|
||||||
|
main()
|
||||||