Merge pull request #138 from Luciferbobo/master
add CSV-based finetuning pipeline for Kronos models
This commit is contained in:
commit
082ab7ef62
1
.gitignore
vendored
@ -45,7 +45,6 @@ Desktop.ini
# Data files (large files)
*.feather
*.csv
*.parquet
*.h5
*.hdf5
120
finetune_csv/README.md
Normal file
@ -0,0 +1,120 @@
# Kronos Fine-tuning on Custom CSV Datasets

This module provides a comprehensive pipeline for fine-tuning Kronos models on your own CSV-formatted financial data. It supports both sequential training (tokenizer followed by predictor) and individual component training, with full distributed training capabilities.

## 1. Data Preparation

### Required Data Format

Your CSV file must contain the following columns:
- `timestamps`: DateTime stamp for each data point
- `open`: Opening price
- `high`: Highest price
- `low`: Lowest price
- `close`: Closing price
- `volume`: Trading volume
- `amount`: Trading amount

(`volume` and `amount` can be 0 if that data is not available.)

### Sample Data Format

| timestamps | open | close | high | low | volume | amount |
|------------|------|-------|------|-----|--------|--------|
| 2019/11/26 9:35 | 182.45215 | 184.45215 | 184.95215 | 182.45215 | 15136000 | 0 |
| 2019/11/26 9:40 | 184.35215 | 183.85215 | 184.55215 | 183.45215 | 4433300 | 0 |
| 2019/11/26 9:45 | 183.85215 | 183.35215 | 183.95215 | 182.95215 | 3070900 | 0 |

> **Reference**: Check `data/HK_ali_09988_kline_5min_all.csv` for a complete example of the proper data format.
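Before launching training, it can help to sanity-check a CSV against this schema. A minimal sketch using pandas (`validate_kline_csv` is a hypothetical helper, not part of this PR; the column names follow the table above):

```python
import pandas as pd

REQUIRED_COLUMNS = ["timestamps", "open", "high", "low", "close", "volume", "amount"]

def validate_kline_csv(path: str) -> pd.DataFrame:
    """Load a k-line CSV and check that it matches the expected schema."""
    df = pd.read_csv(path)
    missing = [c for c in REQUIRED_COLUMNS if c not in df.columns]
    if missing:
        raise ValueError(f"CSV is missing required columns: {missing}")
    # Parse timestamps and make sure rows are in chronological order.
    df["timestamps"] = pd.to_datetime(df["timestamps"])
    df = df.sort_values("timestamps").reset_index(drop=True)
    # OHLC prices must be present; volume/amount may legitimately be 0.
    if df[["open", "high", "low", "close"]].isnull().any().any():
        raise ValueError("OHLC columns contain missing values")
    return df
```

The check is intentionally light: the training code normalizes each window on the fly, so only column presence, ordering, and missing OHLC values need to be caught up front.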
## 2. Config Preparation

Edit the config so that the data path and pretrained model paths are correct, then set your training parameters.

```yaml
# Data configuration
data:
  data_path: "/path/to/your/data.csv"
  lookback_window: 512   # Historical data points to use
  predict_window: 48     # Future points to predict
  max_context: 512       # Maximum context length

...
```

There are some other settings as well; see `configs/config_ali09988_candle-5min.yaml` for more detailed comments.
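The pipeline reads this file with PyYAML. As a minimal, self-contained sketch of how a nested key such as `data.lookback_window` can be fetched with a dotted path (`get_nested` is a hypothetical helper mirroring `ConfigLoader.get` from `config_loader.py` in this PR):

```python
import yaml

def get_nested(config: dict, dotted_key: str, default=None):
    """Fetch a value like 'data.lookback_window' from a nested dict."""
    value = config
    for part in dotted_key.split("."):
        if not isinstance(value, dict) or part not in value:
            return default
        value = value[part]
    return value

config_text = """
data:
  data_path: "/path/to/your/data.csv"
  lookback_window: 512
  predict_window: 48
"""
config = yaml.safe_load(config_text)
print(get_nested(config, "data.lookback_window"))  # 512
```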
## 3. Training

### Method 1: Sequential Training (Recommended)

The `train_sequential.py` script handles the complete training pipeline automatically:

```bash
# Complete training (tokenizer + predictor)
python train_sequential.py --config configs/config_ali09988_candle-5min.yaml

# Skip existing models
python train_sequential.py --config configs/config_ali09988_candle-5min.yaml --skip-existing

# Only train tokenizer
python train_sequential.py --config configs/config_ali09988_candle-5min.yaml --skip-basemodel

# Only train predictor
python train_sequential.py --config configs/config_ali09988_candle-5min.yaml --skip-tokenizer
```
### Method 2: Individual Component Training

Train each component separately for more control:

```bash
# Step 1: Train tokenizer
python finetune_tokenizer.py --config configs/config_ali09988_candle-5min.yaml

# Step 2: Train predictor (requires fine-tuned tokenizer)
python finetune_base_model.py --config configs/config_ali09988_candle-5min.yaml
```
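Method 1's `train_sequential.py` essentially chains these two commands. A minimal orchestration sketch (`build_steps` is a hypothetical helper mirroring the `--skip-*` flags, not the actual script, which also adds checkpoint checks and logging):

```python
def build_steps(config: str, skip_tokenizer: bool = False, skip_basemodel: bool = False):
    """Build the ordered list of commands for sequential fine-tuning."""
    steps = []
    if not skip_tokenizer:
        steps.append(["python", "finetune_tokenizer.py", "--config", config])
    if not skip_basemodel:
        steps.append(["python", "finetune_base_model.py", "--config", config])
    return steps

# Running them in order reproduces Method 1:
# import subprocess
# for cmd in build_steps("configs/config_ali09988_candle-5min.yaml"):
#     subprocess.run(cmd, check=True)
```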
### DDP Training

For faster training on multiple GPUs:

```bash
# Set communication backend (nccl for NVIDIA GPUs, gloo for CPU/mixed)
DIST_BACKEND=nccl \
torchrun --standalone --nproc_per_node=8 train_sequential.py --config configs/config_ali09988_candle-5min.yaml
```
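Inside the scripts, the backend is taken from this `DIST_BACKEND` environment variable when the process group is initialized. A small sketch of such a selection rule (`pick_backend` and its fallback logic are illustrative assumptions, not the PR's exact code):

```python
import os

def pick_backend() -> str:
    """Choose a torch.distributed backend from the DIST_BACKEND env var.

    Falls back to 'nccl' when CUDA is available and 'gloo' otherwise;
    the fallback is an illustrative assumption.
    """
    backend = os.environ.get("DIST_BACKEND")
    if backend:
        return backend
    try:
        import torch
        return "nccl" if torch.cuda.is_available() else "gloo"
    except ImportError:
        return "gloo"
```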
## 4. Training Results

The training process generates several outputs:

### Model Checkpoints
- **Tokenizer**: Saved to `{base_save_path}/{exp_name}/tokenizer/best_model/`
- **Predictor**: Saved to `{base_save_path}/{exp_name}/basemodel/best_model/`
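These checkpoint locations can be composed programmatically. A small sketch following the `{base_save_path}/{exp_name}/...` layout above (`checkpoint_paths` is a hypothetical helper):

```python
import os

def checkpoint_paths(base_save_path: str, exp_name: str) -> dict:
    """Compute where the pipeline stores its best checkpoints."""
    root = os.path.join(base_save_path, exp_name)
    return {
        "tokenizer": os.path.join(root, "tokenizer", "best_model"),
        "predictor": os.path.join(root, "basemodel", "best_model"),
    }
```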
### Training Logs
- **Console output**: Real-time training progress and metrics
- **Log files**: Detailed logs saved to `{base_save_path}/logs/`
- **Validation tracking**: The best models are saved based on validation loss

## 5. Prediction Visualization

The following images show example training results on Alibaba (HK stock) data:









118
finetune_csv/README_CN.md
Normal file
@ -0,0 +1,118 @@
# Kronos Fine-tuning with Custom CSV Datasets

This is a complete pipeline for fine-tuning Kronos models on your own CSV-formatted data. It covers sequential training (tokenizer first, then predictor) and individual module training, and also supports distributed training.

## 1. Data Preparation

### Data Format

The CSV file must contain the following columns:
- `timestamps`: timestamp of each data point
- `open`: opening price
- `high`: highest price
- `low`: lowest price
- `close`: closing price
- `volume`: trading volume
- `amount`: trading amount

(`volume` and `amount` can be all zeros if that data is unavailable.)

### Sample Data Format

| timestamps | open | close | high | low | volume | amount |
|------------|------|-------|------|-----|--------|--------|
| 2019/11/26 9:35 | 182.45215 | 184.45215 | 184.95215 | 182.45215 | 15136000 | 0 |
| 2019/11/26 9:40 | 184.35215 | 183.85215 | 184.55215 | 183.45215 | 4433300 | 0 |
| 2019/11/26 9:45 | 183.85215 | 183.35215 | 183.95215 | 182.95215 | 3070900 | 0 |

> **Reference data sample**: `data/HK_ali_09988_kline_5min_all.csv`

## 2. Config Preparation

`data_path` and the pretrained model paths must be updated; the training parameters can be tuned as you like.

```yaml
# Data configuration
data:
  data_path: "/path/to/your/data.csv"
  lookback_window: 512   # Historical data points to use
  predict_window: 48     # Future points to predict
  max_context: 512       # Maximum context length

...
```

There are some other settings here as well; `configs/config_ali09988_candle-5min.yaml` has more detailed comments.

## 3. Training

### Method 1: Sequential Training

The `train_sequential.py` script handles the complete training pipeline automatically:

```bash
# Complete training (tokenizer + predictor)
python train_sequential.py --config configs/config_ali09988_candle-5min.yaml

# Skip existing models
python train_sequential.py --config configs/config_ali09988_candle-5min.yaml --skip-existing

# Only train the tokenizer
python train_sequential.py --config configs/config_ali09988_candle-5min.yaml --skip-basemodel

# Only train the predictor
python train_sequential.py --config configs/config_ali09988_candle-5min.yaml --skip-tokenizer
```

### Method 2: Individual Component Training

Each component can be trained separately:

```bash
# Step 1: Train the tokenizer
python finetune_tokenizer.py --config configs/config_ali09988_candle-5min.yaml

# Step 2: Train the predictor (requires the fine-tuned tokenizer)
python finetune_base_model.py --config configs/config_ali09988_candle-5min.yaml
```

### DDP Training

With multiple GPUs, DDP can be enabled to speed up training:

```bash
# Set the communication backend (nccl for NVIDIA GPUs, gloo for CPU/mixed)
DIST_BACKEND=nccl \
torchrun --standalone --nproc_per_node=8 train_sequential.py --config configs/config_ali09988_candle-5min.yaml
```

## 4. Training Results

The training process generates the following outputs:

### Model Checkpoints
- **Tokenizer**: saved to `{base_save_path}/{exp_name}/tokenizer/best_model/`
- **Predictor**: saved to `{base_save_path}/{exp_name}/basemodel/best_model/`

### Training Logs
- **Console output**: real-time training progress and metrics
- **Log files**: detailed logs saved to `{base_save_path}/logs/`
- **Validation tracking**: the best models are saved based on validation loss

## 5. Prediction Visualization

The following images show example results after fine-tuning Kronos on Alibaba (HK) stock data:









267
finetune_csv/config_loader.py
Normal file
@ -0,0 +1,267 @@
import os
import yaml
from typing import Dict, Any


class ConfigLoader:

    def __init__(self, config_path: str):
        self.config_path = config_path
        self.config = self._load_config()

    def _load_config(self) -> Dict[str, Any]:
        if not os.path.exists(self.config_path):
            raise FileNotFoundError(f"Config file not found: {self.config_path}")

        with open(self.config_path, 'r', encoding='utf-8') as f:
            config = yaml.safe_load(f)

        config = self._resolve_dynamic_paths(config)

        return config

    def _resolve_dynamic_paths(self, config: Dict[str, Any]) -> Dict[str, Any]:
        exp_name = config.get('model_paths', {}).get('exp_name', '')
        if not exp_name:
            return config

        base_path = config.get('model_paths', {}).get('base_path', '')
        path_templates = {
            'base_save_path': f"{base_path}/{exp_name}",
            'finetuned_tokenizer': f"{base_path}/{exp_name}/tokenizer/best_model"
        }

        if 'model_paths' in config:
            for key, template in path_templates.items():
                if key in config['model_paths']:
                    # Only use the template when the original value is empty.
                    current_value = config['model_paths'][key]
                    if current_value == "" or current_value is None:
                        config['model_paths'][key] = template
                    else:
                        # If the original value is not empty, substitute any
                        # {exp_name} placeholder it contains.
                        if isinstance(current_value, str) and '{exp_name}' in current_value:
                            config['model_paths'][key] = current_value.format(exp_name=exp_name)

        return config

    def get(self, key: str, default=None):
        keys = key.split('.')
        value = self.config

        try:
            for k in keys:
                value = value[k]
            return value
        except (KeyError, TypeError):
            return default

    def get_data_config(self) -> Dict[str, Any]:
        return self.config.get('data', {})

    def get_training_config(self) -> Dict[str, Any]:
        return self.config.get('training', {})

    def get_model_paths(self) -> Dict[str, str]:
        return self.config.get('model_paths', {})

    def get_experiment_config(self) -> Dict[str, Any]:
        return self.config.get('experiment', {})

    def get_device_config(self) -> Dict[str, Any]:
        return self.config.get('device', {})

    def get_distributed_config(self) -> Dict[str, Any]:
        return self.config.get('distributed', {})

    def update_config(self, updates: Dict[str, Any]):
        def update_nested_dict(d, u):
            for k, v in u.items():
                if isinstance(v, dict):
                    d[k] = update_nested_dict(d.get(k, {}), v)
                else:
                    d[k] = v
            return d

        self.config = update_nested_dict(self.config, updates)

    def save_config(self, save_path: str = None):
        if save_path is None:
            save_path = self.config_path

        with open(save_path, 'w', encoding='utf-8') as f:
            yaml.dump(self.config, f, default_flow_style=False, allow_unicode=True, indent=2)

    def print_config(self):
        print("=" * 50)
        print("Current configuration:")
        print("=" * 50)
        # yaml.dump returns a string when no stream is given; print it.
        print(yaml.dump(self.config, default_flow_style=False, allow_unicode=True, indent=2))
        print("=" * 50)


class CustomFinetuneConfig:

    def __init__(self, config_path: str = None):
        if config_path is None:
            config_path = os.path.join(os.path.dirname(__file__), 'config.yaml')

        self.loader = ConfigLoader(config_path)
        self._load_all_configs()

    def _load_all_configs(self):
        data_config = self.loader.get_data_config()
        self.data_path = data_config.get('data_path')
        self.lookback_window = data_config.get('lookback_window', 512)
        self.predict_window = data_config.get('predict_window', 48)
        self.max_context = data_config.get('max_context', 512)
        self.clip = data_config.get('clip', 5.0)
        self.train_ratio = data_config.get('train_ratio', 0.9)
        self.val_ratio = data_config.get('val_ratio', 0.1)
        self.test_ratio = data_config.get('test_ratio', 0.0)

        # Training configuration
        training_config = self.loader.get_training_config()
        # Support separate epoch counts for tokenizer and basemodel training.
        self.tokenizer_epochs = training_config.get('tokenizer_epochs', 30)
        self.basemodel_epochs = training_config.get('basemodel_epochs', 30)

        if 'epochs' in training_config and 'tokenizer_epochs' not in training_config:
            self.tokenizer_epochs = training_config.get('epochs', 30)
        if 'epochs' in training_config and 'basemodel_epochs' not in training_config:
            self.basemodel_epochs = training_config.get('epochs', 30)

        self.batch_size = training_config.get('batch_size', 160)
        self.log_interval = training_config.get('log_interval', 50)
        self.num_workers = training_config.get('num_workers', 6)
        self.seed = training_config.get('seed', 100)
        self.tokenizer_learning_rate = training_config.get('tokenizer_learning_rate', 2e-4)
        self.predictor_learning_rate = training_config.get('predictor_learning_rate', 4e-5)
        self.adam_beta1 = training_config.get('adam_beta1', 0.9)
        self.adam_beta2 = training_config.get('adam_beta2', 0.95)
        self.adam_weight_decay = training_config.get('adam_weight_decay', 0.1)
        self.accumulation_steps = training_config.get('accumulation_steps', 1)

        model_paths = self.loader.get_model_paths()
        self.exp_name = model_paths.get('exp_name', 'default_experiment')
        self.pretrained_tokenizer_path = model_paths.get('pretrained_tokenizer')
        self.pretrained_predictor_path = model_paths.get('pretrained_predictor')
        self.base_save_path = model_paths.get('base_save_path')
        self.tokenizer_save_name = model_paths.get('tokenizer_save_name', 'tokenizer')
        self.basemodel_save_name = model_paths.get('basemodel_save_name', 'basemodel')
        self.finetuned_tokenizer_path = model_paths.get('finetuned_tokenizer')

        experiment_config = self.loader.get_experiment_config()
        self.experiment_name = experiment_config.get('name', 'kronos_custom_finetune')
        self.experiment_description = experiment_config.get('description', '')
        self.use_comet = experiment_config.get('use_comet', False)
        self.train_tokenizer = experiment_config.get('train_tokenizer', True)
        self.train_basemodel = experiment_config.get('train_basemodel', True)
        self.skip_existing = experiment_config.get('skip_existing', False)

        unified_pretrained = experiment_config.get('pre_trained', None)
        self.pre_trained_tokenizer = experiment_config.get('pre_trained_tokenizer', unified_pretrained if unified_pretrained is not None else True)
        self.pre_trained_predictor = experiment_config.get('pre_trained_predictor', unified_pretrained if unified_pretrained is not None else True)

        device_config = self.loader.get_device_config()
        self.use_cuda = device_config.get('use_cuda', True)
        self.device_id = device_config.get('device_id', 0)

        distributed_config = self.loader.get_distributed_config()
        self.use_ddp = distributed_config.get('use_ddp', False)
        self.ddp_backend = distributed_config.get('backend', 'nccl')

        self._compute_full_paths()

    def _compute_full_paths(self):
        self.tokenizer_save_path = os.path.join(self.base_save_path, self.tokenizer_save_name)
        self.tokenizer_best_model_path = os.path.join(self.tokenizer_save_path, 'best_model')

        self.basemodel_save_path = os.path.join(self.base_save_path, self.basemodel_save_name)
        self.basemodel_best_model_path = os.path.join(self.basemodel_save_path, 'best_model')

    def get_tokenizer_config(self):
        return {
            'data_path': self.data_path,
            'lookback_window': self.lookback_window,
            'predict_window': self.predict_window,
            'max_context': self.max_context,
            'clip': self.clip,
            'train_ratio': self.train_ratio,
            'val_ratio': self.val_ratio,
            'test_ratio': self.test_ratio,
            'epochs': self.tokenizer_epochs,
            'batch_size': self.batch_size,
            'log_interval': self.log_interval,
            'num_workers': self.num_workers,
            'seed': self.seed,
            'learning_rate': self.tokenizer_learning_rate,
            'adam_beta1': self.adam_beta1,
            'adam_beta2': self.adam_beta2,
            'adam_weight_decay': self.adam_weight_decay,
            'accumulation_steps': self.accumulation_steps,
            'pretrained_model_path': self.pretrained_tokenizer_path,
            'save_path': self.tokenizer_save_path,
            'use_comet': self.use_comet
        }

    def get_basemodel_config(self):
        return {
            'data_path': self.data_path,
            'lookback_window': self.lookback_window,
            'predict_window': self.predict_window,
            'max_context': self.max_context,
            'clip': self.clip,
            'train_ratio': self.train_ratio,
            'val_ratio': self.val_ratio,
            'test_ratio': self.test_ratio,
            'epochs': self.basemodel_epochs,
            'batch_size': self.batch_size,
            'log_interval': self.log_interval,
            'num_workers': self.num_workers,
            'seed': self.seed,
            'predictor_learning_rate': self.predictor_learning_rate,
            'tokenizer_learning_rate': self.tokenizer_learning_rate,
            'adam_beta1': self.adam_beta1,
            'adam_beta2': self.adam_beta2,
            'adam_weight_decay': self.adam_weight_decay,
            'pretrained_tokenizer_path': self.finetuned_tokenizer_path,
            'pretrained_predictor_path': self.pretrained_predictor_path,
            'save_path': self.basemodel_save_path,
            'use_comet': self.use_comet
        }

    def print_config_summary(self):
        print("=" * 60)
        print("Kronos finetuning configuration summary")
        print("=" * 60)
        print(f"Experiment name: {self.exp_name}")
        print(f"Data path: {self.data_path}")
        print(f"Lookback window: {self.lookback_window}")
        print(f"Predict window: {self.predict_window}")
        print(f"Tokenizer training epochs: {self.tokenizer_epochs}")
        print(f"Basemodel training epochs: {self.basemodel_epochs}")
        print(f"Batch size: {self.batch_size}")
        print(f"Tokenizer learning rate: {self.tokenizer_learning_rate}")
        print(f"Predictor learning rate: {self.predictor_learning_rate}")
        print(f"Train tokenizer: {self.train_tokenizer}")
        print(f"Train basemodel: {self.train_basemodel}")
        print(f"Skip existing: {self.skip_existing}")
        print(f"Use pre-trained tokenizer: {self.pre_trained_tokenizer}")
        print(f"Use pre-trained predictor: {self.pre_trained_predictor}")
        print(f"Base save path: {self.base_save_path}")
        print(f"Tokenizer save path: {self.tokenizer_save_path}")
        print(f"Basemodel save path: {self.basemodel_save_path}")
        print("=" * 60)
72
finetune_csv/configs/config_ali09988_candle-5min.yaml
Normal file
@ -0,0 +1,72 @@
# This is a template config for custom fine-tuning of Kronos on CSV data

data:
  data_path: "/xxxx/Kronos/finetune_csv/data/HK_ali_09988_kline_5min_all.csv"
  lookback_window: 512
  predict_window: 48
  max_context: 512
  clip: 5.0
  # dataset split ratios
  train_ratio: 0.9
  val_ratio: 0.1
  test_ratio: 0.0

training:
  # control the training epochs of tokenizer and basemodel
  tokenizer_epochs: 30
  basemodel_epochs: 20
  batch_size: 32
  log_interval: 50
  num_workers: 6
  seed: 42

  tokenizer_learning_rate: 0.0002
  predictor_learning_rate: 0.000001

  adam_beta1: 0.9
  adam_beta2: 0.95
  adam_weight_decay: 0.1

  # gradient accumulation steps for tokenizer training
  accumulation_steps: 1

# model path configuration
model_paths:
  # pretrained model paths
  pretrained_tokenizer: "/xxx/Kronos/pretrained/Kronos-Tokenizer-base"
  pretrained_predictor: "/xxx/Kronos/pretrained/Kronos-base"

  # experiment name - other paths will be generated based on this
  exp_name: "HK_ali_09988_kline_5min_all"
  base_path: "/xxx/Kronos/finetune_csv/finetuned/"

  # The following paths are generated from exp_name; no need to set them manually.
  # Way 1: leave an empty string and the system will generate the full path.
  base_save_path: ""        # /xxxx/Kronos/finetune_csv/finetuned/{exp_name}
  finetuned_tokenizer: ""   # /xxxx/Kronos/finetune_csv/finetuned/{exp_name}/tokenizer/best_model

  # Way 2: use a template string; {exp_name} will be replaced with the actual experiment name.
  # base_save_path: "/xxxx/Kronos/finetune_csv/finetuned/{exp_name}"
  # finetuned_tokenizer: "/xxxx/Kronos/finetune_csv/finetuned/{exp_name}/tokenizer/best_model"

  tokenizer_save_name: "tokenizer"
  basemodel_save_name: "basemodel"

experiment:
  name: "kronos_custom_finetune"
  description: "Custom finetune for HK stock data"
  use_comet: false

  # control the training phases
  train_tokenizer: true
  train_basemodel: true

  # if true, skip training for models that already exist
  skip_existing: false

# device configuration
device:
  use_cuda: true
  device_id: 0
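The two path conventions above (an empty string vs. a `{exp_name}` template) are resolved by `_resolve_dynamic_paths` in `config_loader.py`. A self-contained sketch of that substitution rule (`resolve_path` is an illustrative extraction, not the actual method):

```python
def resolve_path(value, exp_name: str, default_template: str) -> str:
    """Resolve a model path the way config_loader.py does.

    - empty string or None        -> use the generated default for this experiment
    - contains '{exp_name}'       -> substitute the experiment name
    - anything else               -> keep the value as-is
    """
    if value in ("", None):
        return default_template.format(exp_name=exp_name)
    if isinstance(value, str) and "{exp_name}" in value:
        return value.format(exp_name=exp_name)
    return value

base = "/xxx/Kronos/finetune_csv/finetuned"
print(resolve_path("", "HK_ali", base + "/{exp_name}"))  # -> /xxx/Kronos/finetune_csv/finetuned/HK_ali
```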
93913
finetune_csv/data/HK_ali_09988_kline_5min_all.csv
Normal file
File diff suppressed because it is too large
Binary file not shown.
After Width: | Height: | Size: 474 KiB
Binary file not shown.
After Width: | Height: | Size: 473 KiB
Binary file not shown.
After Width: | Height: | Size: 331 KiB
Binary file not shown.
After Width: | Height: | Size: 449 KiB
Binary file not shown.
After Width: | Height: | Size: 530 KiB
468
finetune_csv/finetune_base_model.py
Normal file
@ -0,0 +1,468 @@
import os
import sys
import json
import time
import pickle
import random
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from time import gmtime, strftime
import logging
from logging.handlers import RotatingFileHandler
import datetime
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

sys.path.append('../')
from model import Kronos, KronosTokenizer, KronosPredictor
from config_loader import CustomFinetuneConfig


class CustomKlineDataset(Dataset):

    def __init__(self, data_path, data_type='train', lookback_window=90, predict_window=10,
                 clip=5.0, seed=100, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15):
        self.data_path = data_path
        self.data_type = data_type
        self.lookback_window = lookback_window
        self.predict_window = predict_window
        self.window = lookback_window + predict_window + 1
        self.clip = clip
        self.seed = seed
        self.train_ratio = train_ratio
        self.val_ratio = val_ratio
        self.test_ratio = test_ratio

        self.feature_list = ['open', 'high', 'low', 'close', 'volume', 'amount']
        self.time_feature_list = ['minute', 'hour', 'weekday', 'day', 'month']

        self.py_rng = random.Random(seed)

        self._load_and_preprocess_data()
        self._split_data_by_time()

        self.n_samples = len(self.data) - self.window + 1

        print(f"[{data_type.upper()}] Data length: {len(self.data)}, Available samples: {self.n_samples}")

    def _load_and_preprocess_data(self):
        df = pd.read_csv(self.data_path)

        df['timestamps'] = pd.to_datetime(df['timestamps'])
        df = df.sort_values('timestamps').reset_index(drop=True)

        self.timestamps = df['timestamps'].copy()

        df['minute'] = df['timestamps'].dt.minute
        df['hour'] = df['timestamps'].dt.hour
        df['weekday'] = df['timestamps'].dt.weekday
        df['day'] = df['timestamps'].dt.day
        df['month'] = df['timestamps'].dt.month

        self.data = df[self.feature_list + self.time_feature_list].copy()

        if self.data.isnull().any().any():
            print("Warning: Missing values found in data, performing forward fill")
            # fillna(method='ffill') is removed in pandas 2.x; use ffill() instead.
            self.data = self.data.ffill()

        print(f"Original data time range: {self.timestamps.min()} to {self.timestamps.max()}")
        print(f"Original data total length: {len(df)} records")

    def _split_data_by_time(self):
        total_length = len(self.data)

        train_end = int(total_length * self.train_ratio)
        val_end = int(total_length * (self.train_ratio + self.val_ratio))

        if self.data_type == 'train':
            self.data = self.data.iloc[:train_end].copy()
            self.timestamps = self.timestamps.iloc[:train_end].copy()
            print(f"[{self.data_type.upper()}] Training set: first {train_end} time points ({self.train_ratio})")
            print(f"[{self.data_type.upper()}] Training set time range: {self.timestamps.min()} to {self.timestamps.max()}")
        elif self.data_type == 'val':
            self.data = self.data.iloc[train_end:val_end].copy()
            self.timestamps = self.timestamps.iloc[train_end:val_end].copy()
            print(f"[{self.data_type.upper()}] Validation set: time points {train_end+1} to {val_end} ({self.val_ratio})")
            print(f"[{self.data_type.upper()}] Validation set time range: {self.timestamps.min()} to {self.timestamps.max()}")
        elif self.data_type == 'test':
            self.data = self.data.iloc[val_end:].copy()
            self.timestamps = self.timestamps.iloc[val_end:].copy()
            print(f"[{self.data_type.upper()}] Test set: after time point {val_end+1}")
            print(f"[{self.data_type.upper()}] Test set time range: {self.timestamps.min()} to {self.timestamps.max()}")

        print(f"[{self.data_type.upper()}] Data length after split: {len(self.data)} records")

    def set_epoch_seed(self, epoch):
        epoch_seed = self.seed + epoch
        self.py_rng.seed(epoch_seed)
        self.current_epoch = epoch

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        max_start = len(self.data) - self.window
        if max_start <= 0:
            raise ValueError("Data length insufficient to create samples")

        if self.data_type == 'train':
            epoch = getattr(self, 'current_epoch', 0)
            start_idx = (idx * 9973 + (epoch + 1) * 104729) % (max_start + 1)
        else:
            start_idx = idx % (max_start + 1)

        end_idx = start_idx + self.window

        window_data = self.data.iloc[start_idx:end_idx]

        x = window_data[self.feature_list].values.astype(np.float32)
        x_stamp = window_data[self.time_feature_list].values.astype(np.float32)

        x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0)
        x = (x - x_mean) / (x_std + 1e-5)
        x = np.clip(x, -self.clip, self.clip)

        x_tensor = torch.from_numpy(x)
        x_stamp_tensor = torch.from_numpy(x_stamp)

        return x_tensor, x_stamp_tensor
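The per-window normalization in `__getitem__` above can be shown in isolation (NumPy only; `normalize_window` is an illustrative extraction, with the same `clip` default as the config):

```python
import numpy as np

def normalize_window(x: np.ndarray, clip: float = 5.0) -> np.ndarray:
    """Z-score each feature column within one window, then clip outliers.

    Mirrors the per-sample normalization in CustomKlineDataset.__getitem__:
    statistics come from the window itself, so no global scaler is needed.
    """
    x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0)
    x = (x - x_mean) / (x_std + 1e-5)
    return np.clip(x, -clip, clip)

window = np.array([[1.0, 100.0], [2.0, 200.0], [3.0, 300.0]], dtype=np.float32)
out = normalize_window(window)
```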
|
||||
def setup_logging(exp_name: str, log_dir: str, rank: int = 0) -> logging.Logger:
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
|
||||
logger = logging.getLogger(f"basemodel_training_rank_{rank}")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
if logger.handlers:
|
||||
return logger
|
||||
|
||||
log_file = os.path.join(log_dir, f"basemodel_training_rank_{rank}.log")
|
||||
file_handler = RotatingFileHandler(
|
||||
log_file,
|
||||
maxBytes=10*1024*1024,
|
||||
backupCount=5,
|
||||
encoding='utf-8'
|
||||
)
|
||||
file_handler.setLevel(logging.INFO)
|
||||
|
||||
console_handler = None
|
||||
if rank == 0:
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setLevel(logging.INFO)
|
||||
|
||||
formatter = logging.Formatter(
|
||||
'%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S'
|
||||
)
|
||||
file_handler.setFormatter(formatter)
|
||||
if console_handler is not None:
|
||||
console_handler.setFormatter(formatter)
|
||||
|
||||
logger.addHandler(file_handler)
|
||||
if console_handler is not None:
|
||||
logger.addHandler(console_handler)
|
||||
|
||||
logger.info(f"=== Basemodel Training Started ===")
|
||||
logger.info(f"Experiment Name: {exp_name}")
|
||||
logger.info(f"Log Directory: {log_dir}")
|
||||
logger.info(f"Rank: {rank}")
|
||||
logger.info(f"Timestamp: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
def create_dataloaders(config):
|
||||
if not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0:
|
||||
print("Creating data loaders...")
|
||||
|
||||
train_dataset = CustomKlineDataset(
|
||||
data_path=config.data_path,
|
||||
data_type='train',
|
||||
lookback_window=config.lookback_window,
|
||||
predict_window=config.predict_window,
|
||||
clip=config.clip,
|
||||
seed=config.seed,
|
||||
train_ratio=config.train_ratio,
|
||||
val_ratio=config.val_ratio,
|
||||
test_ratio=config.test_ratio
|
||||
)
|
||||
|
||||
val_dataset = CustomKlineDataset(
|
||||
data_path=config.data_path,
|
||||
data_type='val',
|
||||
lookback_window=config.lookback_window,
|
||||
predict_window=config.predict_window,
|
||||
clip=config.clip,
|
||||
seed=config.seed + 1,
|
||||
train_ratio=config.train_ratio,
|
||||
val_ratio=config.val_ratio,
|
||||
test_ratio=config.test_ratio
|
||||
)
|
||||
|
||||
use_ddp = dist.is_available() and dist.is_initialized()
|
||||
train_sampler = DistributedSampler(train_dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True) if use_ddp else None
|
||||
val_sampler = DistributedSampler(val_dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=False, drop_last=False) if use_ddp else None
|
||||
|
||||
train_loader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=config.batch_size,
|
||||
shuffle=(train_sampler is None),
|
||||
num_workers=config.num_workers,
|
||||
pin_memory=True,
|
||||
drop_last=True,
|
||||
sampler=train_sampler
|
||||
)
|
||||
|
||||
val_loader = DataLoader(
|
||||
val_dataset,
|
||||
batch_size=config.batch_size,
|
||||
shuffle=False,
|
||||
num_workers=config.num_workers,
|
||||
pin_memory=True,
|
||||
drop_last=False,
|
||||
sampler=val_sampler
|
||||
)
|
||||
|
||||
if not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0:
|
||||
print(f"Training set size: {len(train_dataset)}, Validation set size: {len(val_dataset)}")
|
||||
|
||||
return train_loader, val_loader, train_dataset, val_dataset, train_sampler, val_sampler
|
||||
|
||||
|
||||
def train_model(model, tokenizer, device, config, save_dir, logger):
    logger.info("Starting training...")
    use_ddp = dist.is_available() and dist.is_initialized()
    rank = dist.get_rank() if use_ddp else 0
    world_size = dist.get_world_size() if use_ddp else 1

    train_loader, val_loader, train_dataset, val_dataset, train_sampler, val_sampler = create_dataloaders(config)
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.predictor_learning_rate,
        betas=(config.adam_beta1, config.adam_beta2),
        weight_decay=config.adam_weight_decay
    )

    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=config.predictor_learning_rate,
        steps_per_epoch=len(train_loader),
        epochs=config.basemodel_epochs,
        pct_start=0.03,
        div_factor=10
    )

    if use_ddp:
        local_rank = int(os.environ.get("LOCAL_RANK", "0"))
        model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=False)

    best_val_loss = float('inf')
    batch_idx_global = 0

    for epoch in range(config.basemodel_epochs):
        epoch_start_time = time.time()
        model.train()

        train_dataset.set_epoch_seed(epoch * 10000)
        val_dataset.set_epoch_seed(0)
        if train_sampler is not None:
            train_sampler.set_epoch(epoch)

        epoch_train_loss = 0.0
        train_batches = 0

        for batch_idx, (batch_x, batch_x_stamp) in enumerate(train_loader):
            batch_x = batch_x.to(device, non_blocking=True)
            batch_x_stamp = batch_x_stamp.to(device, non_blocking=True)

            with torch.no_grad():
                token_seq_0, token_seq_1 = tokenizer.encode(batch_x, half=True)

            token_in = [token_seq_0[:, :-1], token_seq_1[:, :-1]]
            token_out = [token_seq_0[:, 1:], token_seq_1[:, 1:]]

            logits = (model.module if use_ddp else model)(token_in[0], token_in[1], batch_x_stamp[:, :-1, :])
            loss, s1_loss, s2_loss = (model.module if use_ddp else model).head.compute_loss(logits[0], logits[1], token_out[0], token_out[1])

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_((model.module if use_ddp else model).parameters(), max_norm=3.0)
            optimizer.step()
            scheduler.step()

            epoch_train_loss += loss.item()
            train_batches += 1

            if (batch_idx_global + 1) % config.log_interval == 0:
                lr = optimizer.param_groups[0]['lr']
                log_msg = (f"[Epoch {epoch+1}/{config.basemodel_epochs}, Step {batch_idx+1}/{len(train_loader)}] "
                           f"LR: {lr:.6f}, Loss: {loss.item():.4f}")
                logger.info(log_msg)
                if rank == 0:
                    print(log_msg)

            batch_idx_global += 1

        model.eval()
        val_loss = 0.0
        val_batches = 0

        with torch.no_grad():
            for batch_x, batch_x_stamp in val_loader:
                batch_x = batch_x.to(device, non_blocking=True)
                batch_x_stamp = batch_x_stamp.to(device, non_blocking=True)

                token_seq_0, token_seq_1 = tokenizer.encode(batch_x, half=True)
                token_in = [token_seq_0[:, :-1], token_seq_1[:, :-1]]
                token_out = [token_seq_0[:, 1:], token_seq_1[:, 1:]]

                logits = (model.module if use_ddp else model)(token_in[0], token_in[1], batch_x_stamp[:, :-1, :])
                loss, _, _ = (model.module if use_ddp else model).head.compute_loss(logits[0], logits[1], token_out[0], token_out[1])

                val_loss += loss.item()
                val_batches += 1

        if use_ddp:
            tensor_sum = torch.tensor([epoch_train_loss, train_batches, val_loss, val_batches], dtype=torch.float64, device=device)
            dist.all_reduce(tensor_sum, op=dist.ReduceOp.SUM)
            epoch_train_loss_all = tensor_sum[0].item()
            train_batches_all = int(tensor_sum[1].item())
            val_loss_all = tensor_sum[2].item()
            val_batches_all = int(tensor_sum[3].item())
            avg_train_loss = (epoch_train_loss_all / train_batches_all) if train_batches_all > 0 else 0.0
            avg_val_loss = (val_loss_all / val_batches_all) if val_batches_all > 0 else 0.0
        else:
            avg_train_loss = epoch_train_loss / train_batches if train_batches > 0 else 0
            avg_val_loss = val_loss / val_batches if val_batches > 0 else 0

        epoch_time = time.time() - epoch_start_time
        epoch_summary = (f"\n--- Epoch {epoch+1}/{config.basemodel_epochs} Summary ---\n"
                         f"Training Loss: {avg_train_loss:.4f}\n"
                         f"Validation Loss: {avg_val_loss:.4f}\n"
                         f"Epoch Time: {epoch_time:.2f} seconds\n")
        logger.info(epoch_summary)
        if rank == 0:
            print(epoch_summary)

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            if rank == 0:
                model_save_path = os.path.join(save_dir, "best_model")
                os.makedirs(model_save_path, exist_ok=True)
                (model.module if use_ddp else model).save_pretrained(model_save_path)
                save_msg = f"Best model saved to: {model_save_path} (validation loss: {best_val_loss:.4f})"
                logger.info(save_msg)
                print(save_msg)

    return best_val_loss


def main():
    import argparse

    parser = argparse.ArgumentParser(description='Kronos Basemodel Fine-tuning Training')
    parser.add_argument('--config', type=str, default='config.yaml',
                        help='Configuration file path (default: config.yaml)')
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    config = CustomFinetuneConfig(args.config)

    os.makedirs(config.basemodel_save_path, exist_ok=True)

    log_dir = os.path.join(config.base_save_path, "logs")
    logger = setup_logging(config.exp_name, log_dir, 0)

    torch.manual_seed(config.seed)
    np.random.seed(config.seed)
    random.seed(config.seed)

    logger.info("Loading pretrained model or random initialization...")
    print("Loading pretrained model or random initialization...")
    if getattr(config, 'pre_trained_tokenizer', True):
        tokenizer = KronosTokenizer.from_pretrained(config.finetuned_tokenizer_path)
    else:
        import json
        print("pre_trained_tokenizer=False, randomly initializing Tokenizer architecture for training")
        cfg_path_tok = os.path.join(config.pretrained_tokenizer_path if hasattr(config, 'pretrained_tokenizer_path') else config.finetuned_tokenizer_path, 'config.json')
        with open(cfg_path_tok, 'r') as f:
            arch_t = json.load(f)
        tokenizer = KronosTokenizer(
            d_in=arch_t.get('d_in', 6),
            d_model=arch_t.get('d_model', 256),
            n_heads=arch_t.get('n_heads', 4),
            ff_dim=arch_t.get('ff_dim', 512),
            n_enc_layers=arch_t.get('n_enc_layers', 4),
            n_dec_layers=arch_t.get('n_dec_layers', 4),
            ffn_dropout_p=arch_t.get('ffn_dropout_p', 0.0),
            attn_dropout_p=arch_t.get('attn_dropout_p', 0.0),
            resid_dropout_p=arch_t.get('resid_dropout_p', 0.0),
            s1_bits=arch_t.get('s1_bits', 10),
            s2_bits=arch_t.get('s2_bits', 10),
            beta=arch_t.get('beta', 0.05),
            gamma0=arch_t.get('gamma0', 1.0),
            gamma=arch_t.get('gamma', 1.1),
            zeta=arch_t.get('zeta', 0.05),
            group_size=arch_t.get('group_size', 4)
        )

    if getattr(config, 'pre_trained_predictor', True):
        model = Kronos.from_pretrained(config.pretrained_predictor_path)
    else:
        import json
        print("pre_trained_predictor=False, randomly initializing Predictor architecture for training")
        cfg_path = os.path.join(config.pretrained_predictor_path, 'config.json')
        with open(cfg_path, 'r') as f:
            arch = json.load(f)
        model = Kronos(
            s1_bits=arch.get('s1_bits', 10),
            s2_bits=arch.get('s2_bits', 10),
            n_layers=arch.get('n_layers', 12),
            d_model=arch.get('d_model', 832),
            n_heads=arch.get('n_heads', 16),
            ff_dim=arch.get('ff_dim', 2048),
            ffn_dropout_p=arch.get('ffn_dropout_p', 0.2),
            attn_dropout_p=arch.get('attn_dropout_p', 0.0),
            resid_dropout_p=arch.get('resid_dropout_p', 0.2),
            token_dropout_p=arch.get('token_dropout_p', 0.0),
            learn_te=arch.get('learn_te', True)
        )

    tokenizer = tokenizer.to(device)
    model = model.to(device)

    model_size = sum(p.numel() for p in model.parameters())
    logger.info(f"Model parameters: {model_size:,}")
    print(f"Model parameters: {model_size:,}")

    logger.info("=== Training Configuration ===")
    logger.info(f"Data path: {config.data_path}")
    logger.info(f"Lookback window: {config.lookback_window}")
    logger.info(f"Predict window: {config.predict_window}")
    logger.info(f"Batch size: {config.batch_size}")
    logger.info(f"Learning rate: {config.predictor_learning_rate}")
    logger.info(f"Training epochs: {config.basemodel_epochs}")
    logger.info(f"Device: {device}")
    logger.info(f"Tokenizer path: {config.finetuned_tokenizer_path}")
    logger.info(f"Pretrained model path: {config.pretrained_predictor_path}")

    logger.info("Starting fine-tuning training...")
    print("Starting fine-tuning training...")
    best_val_loss = train_model(model, tokenizer, device, config, config.basemodel_save_path, logger)

    final_msg = f"Training completed! Best validation loss: {best_val_loss:.4f}\nModel saved to: {config.basemodel_save_path}"
    logger.info(final_msg)
    print(final_msg)


if __name__ == "__main__":
    main()
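Both training scripts in this PR aggregate epoch statistics across DDP ranks by all-reducing a `(loss_sum, batch_count)` vector with `ReduceOp.SUM` and dividing once afterwards, which yields a mean weighted by each rank's batch count. A minimal sketch of that arithmetic, without `torch.distributed` (the per-rank numbers below are made-up illustration values):

```python
# Each rank contributes (sum_of_batch_losses, batch_count); summing the pairs
# across ranks first and dividing once is equivalent to what the
# all_reduce(SUM) + divide pattern in the diff computes.
per_rank = [
    (12.0, 4),  # rank 0: loss sum, batches seen (illustrative numbers)
    (9.0, 3),   # rank 1
]

loss_sum = sum(s for s, _ in per_rank)      # what all_reduce leaves in slot 0
batch_count = sum(n for _, n in per_rank)   # ...and in slot 1
avg_loss = loss_sum / batch_count if batch_count > 0 else 0.0

print(avg_loss)  # 21.0 / 7 = 3.0
```

Dividing per rank and then averaging the quotients would weight all ranks equally even when their batch counts differ; summing numerators and denominators separately avoids that bias.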
359
finetune_csv/finetune_tokenizer.py
Normal file
@ -0,0 +1,359 @@
import os
import sys
import json
import time
import random
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from time import gmtime, strftime
import datetime
import logging
from logging.handlers import RotatingFileHandler
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

sys.path.append("../")
from model import KronosTokenizer
from finetune_base_model import CustomKlineDataset
from config_loader import CustomFinetuneConfig


def set_seed(seed: int, rank: int = 0):
    actual_seed = seed
    random.seed(actual_seed)
    np.random.seed(actual_seed)
    torch.manual_seed(actual_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(actual_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def get_model_size(model: torch.nn.Module) -> str:
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    if total_params >= 1e9:
        return f"{total_params / 1e9:.1f}B"
    elif total_params >= 1e6:
        return f"{total_params / 1e6:.1f}M"
    else:
        return f"{total_params / 1e3:.1f}K"


def format_time(seconds: float) -> str:
    return str(datetime.timedelta(seconds=int(seconds)))


def setup_logging(exp_name: str, log_dir: str, rank: int = 0) -> logging.Logger:
    os.makedirs(log_dir, exist_ok=True)

    logger = logging.getLogger(f"tokenizer_training_rank_{rank}")
    logger.setLevel(logging.INFO)

    if logger.handlers:
        return logger

    log_file = os.path.join(log_dir, f"tokenizer_training_rank_{rank}.log")
    file_handler = RotatingFileHandler(
        log_file,
        maxBytes=10*1024*1024,
        backupCount=5,
        encoding='utf-8'
    )
    file_handler.setLevel(logging.INFO)

    console_handler = None
    if rank == 0:
        console_handler = logging.StreamHandler()
        console_handler.setLevel(logging.INFO)

    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )
    file_handler.setFormatter(formatter)
    if console_handler is not None:
        console_handler.setFormatter(formatter)

    logger.addHandler(file_handler)
    if console_handler is not None:
        logger.addHandler(console_handler)

    logger.info("=== Tokenizer Training Started ===")
    logger.info(f"Experiment Name: {exp_name}")
    logger.info(f"Log Directory: {log_dir}")
    logger.info(f"Rank: {rank}")
    logger.info(f"Timestamp: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

    return logger


def create_dataloaders(config):
    if not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0:
        print("Creating tokenizer training data loaders...")

    train_dataset = CustomKlineDataset(
        data_path=config.data_path,
        data_type="train",
        lookback_window=config.lookback_window,
        predict_window=config.predict_window,
        clip=config.clip,
        seed=config.seed,
        train_ratio=config.train_ratio,
        val_ratio=config.val_ratio,
        test_ratio=config.test_ratio
    )

    val_dataset = CustomKlineDataset(
        data_path=config.data_path,
        data_type="val",
        lookback_window=config.lookback_window,
        predict_window=config.predict_window,
        clip=config.clip,
        seed=config.seed + 1,
        train_ratio=config.train_ratio,
        val_ratio=config.val_ratio,
        test_ratio=config.test_ratio
    )

    use_ddp = dist.is_available() and dist.is_initialized()
    train_sampler = DistributedSampler(train_dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True) if use_ddp else None
    val_sampler = DistributedSampler(val_dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=False, drop_last=False) if use_ddp else None

    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=(train_sampler is None),
        num_workers=config.num_workers,
        pin_memory=True,
        drop_last=True,
        sampler=train_sampler
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=config.num_workers,
        pin_memory=True,
        drop_last=False,
        sampler=val_sampler
    )

    if not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0:
        print(f"Training set size: {len(train_dataset)}, Validation set size: {len(val_dataset)}")

    return train_loader, val_loader, train_dataset, val_dataset, train_sampler, val_sampler


def train_tokenizer(model, device, config, save_dir, logger):
    logger.info("Starting tokenizer training...")
    use_ddp = dist.is_available() and dist.is_initialized()
    rank = dist.get_rank() if use_ddp else 0
    world_size = dist.get_world_size() if use_ddp else 1

    train_loader, val_loader, train_dataset, val_dataset, train_sampler, val_sampler = create_dataloaders(config)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.tokenizer_learning_rate,
        weight_decay=config.adam_weight_decay
    )

    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=config.tokenizer_learning_rate,
        steps_per_epoch=len(train_loader),
        epochs=config.tokenizer_epochs,
        pct_start=0.03,
        div_factor=10
    )

    if use_ddp:
        local_rank = int(os.environ.get("LOCAL_RANK", "0"))
        model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=False)

    best_val_loss = float("inf")
    batch_idx_global = 0
    training_start_time = time.time()

    accumulation_steps = getattr(config, 'accumulation_steps', 1)

    for epoch in range(config.tokenizer_epochs):
        epoch_start_time = time.time()
        model.train()

        train_dataset.set_epoch_seed(epoch * 10000)
        val_dataset.set_epoch_seed(0)
        if train_sampler is not None:
            train_sampler.set_epoch(epoch)

        for batch_idx, (ori_batch_x, _) in enumerate(train_loader):
            ori_batch_x = ori_batch_x.squeeze(0).to(device, non_blocking=True)

            current_batch_total_loss = 0.0
            for j in range(accumulation_steps):
                start_idx = j * (ori_batch_x.shape[0] // accumulation_steps)
                end_idx = (j + 1) * (ori_batch_x.shape[0] // accumulation_steps)
                batch_x = ori_batch_x[start_idx:end_idx]

                zs, bsq_loss, _, _ = (model.module if use_ddp else model)(batch_x)
                z_pre, z = zs

                recon_loss_pre = F.mse_loss(z_pre, batch_x)
                recon_loss_all = F.mse_loss(z, batch_x)
                recon_loss = recon_loss_pre + recon_loss_all
                loss = (recon_loss + bsq_loss) / 2

                loss_scaled = loss / accumulation_steps
                current_batch_total_loss += loss.item()
                loss_scaled.backward()

            torch.nn.utils.clip_grad_norm_((model.module if use_ddp else model).parameters(), max_norm=2.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            if (batch_idx_global + 1) % config.log_interval == 0:
                avg_loss = current_batch_total_loss / accumulation_steps
                lr = optimizer.param_groups[0]["lr"]
                log_msg = (f"[Epoch {epoch+1}/{config.tokenizer_epochs}, Step {batch_idx+1}/{len(train_loader)}] "
                           f"LR: {lr:.6f}, Loss: {avg_loss:.4f}")
                logger.info(log_msg)
                if rank == 0:
                    print(log_msg)

                detail_msg = (f" - VQ Loss: {bsq_loss.item():.4f}\n"
                              f" - Recon Loss Pre: {recon_loss_pre.item():.4f}\n"
                              f" - Recon Loss All: {recon_loss_all.item():.4f}")
                logger.info(detail_msg)
                if rank == 0:
                    print(detail_msg)

            batch_idx_global += 1

        model.eval()
        tot_val_loss_sum_rank = 0.0
        val_sample_count_rank = 0

        with torch.no_grad():
            for ori_batch_x, _ in val_loader:
                ori_batch_x = ori_batch_x.squeeze(0).to(device, non_blocking=True)
                zs, _, _, _ = (model.module if use_ddp else model)(ori_batch_x)
                _, z = zs
                val_loss_item = F.mse_loss(z, ori_batch_x)

                tot_val_loss_sum_rank += val_loss_item.item() * ori_batch_x.size(0)
                val_sample_count_rank += ori_batch_x.size(0)

        if use_ddp:
            tensor_sum = torch.tensor([tot_val_loss_sum_rank, val_sample_count_rank], dtype=torch.float64, device=device)
            dist.all_reduce(tensor_sum, op=dist.ReduceOp.SUM)
            tot_val_loss_all = tensor_sum[0].item()
            val_count_all = int(tensor_sum[1].item())
            avg_val_loss = (tot_val_loss_all / val_count_all) if val_count_all > 0 else 0.0
        else:
            avg_val_loss = tot_val_loss_sum_rank / val_sample_count_rank if val_sample_count_rank > 0 else 0

        epoch_time = time.time() - epoch_start_time
        epoch_summary = (f"\n--- Epoch {epoch+1}/{config.tokenizer_epochs} Summary ---\n"
                         f"Validation Loss: {avg_val_loss:.4f}\n"
                         f"Epoch Time: {format_time(epoch_time)}\n"
                         f"Total Training Time: {format_time(time.time() - training_start_time)}\n")
        logger.info(epoch_summary)
        if rank == 0:
            print(epoch_summary)

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            if rank == 0:
                model_save_path = os.path.join(save_dir, "best_model")
                os.makedirs(model_save_path, exist_ok=True)
                (model.module if use_ddp else model).save_pretrained(model_save_path)
                save_msg = f"Best model saved to: {model_save_path} (validation loss: {best_val_loss:.4f})"
                logger.info(save_msg)
                print(save_msg)

    return best_val_loss


def main():
    import argparse

    parser = argparse.ArgumentParser(description='Kronos Tokenizer Fine-tuning Training')
    parser.add_argument('--config', type=str, default='config.yaml',
                        help='Configuration file path (default: config.yaml)')
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    config = CustomFinetuneConfig(args.config)

    os.makedirs(config.tokenizer_save_path, exist_ok=True)

    log_dir = os.path.join(config.base_save_path, "logs")
    logger = setup_logging(config.exp_name, log_dir, 0)

    set_seed(config.seed)

    # Load the pretrained tokenizer
    if getattr(config, 'pre_trained_tokenizer', True):
        logger.info("Loading pretrained tokenizer...")
        print("Loading pretrained tokenizer...")
        tokenizer = KronosTokenizer.from_pretrained(config.pretrained_tokenizer_path)
    else:
        print("pre_trained_tokenizer=False, randomly initializing Tokenizer architecture")
        cfg_path = os.path.join(config.pretrained_tokenizer_path, 'config.json')
        with open(cfg_path, 'r') as f:
            arch = json.load(f)
        tokenizer = KronosTokenizer(
            d_in=arch.get('d_in', 6),
            d_model=arch.get('d_model', 256),
            n_heads=arch.get('n_heads', 4),
            ff_dim=arch.get('ff_dim', 512),
            n_enc_layers=arch.get('n_enc_layers', 4),
            n_dec_layers=arch.get('n_dec_layers', 4),
            ffn_dropout_p=arch.get('ffn_dropout_p', 0.0),
            attn_dropout_p=arch.get('attn_dropout_p', 0.0),
            resid_dropout_p=arch.get('resid_dropout_p', 0.0),
            s1_bits=arch.get('s1_bits', 10),
            s2_bits=arch.get('s2_bits', 10),
            beta=arch.get('beta', 0.05),
            gamma0=arch.get('gamma0', 1.0),
            gamma=arch.get('gamma', 1.1),
            zeta=arch.get('zeta', 0.05),
            group_size=arch.get('group_size', 4)
        )
    tokenizer = tokenizer.to(device)

    model_size = get_model_size(tokenizer)
    logger.info(f"Tokenizer parameters: {model_size}")
    print(f"Tokenizer parameters: {model_size}")

    logger.info("=== Training Configuration ===")
    logger.info(f"Data path: {config.data_path}")
    logger.info(f"Lookback window: {config.lookback_window}")
    logger.info(f"Predict window: {config.predict_window}")
    logger.info(f"Batch size: {config.batch_size}")
    logger.info(f"Learning rate: {config.tokenizer_learning_rate}")
    logger.info(f"Training epochs: {config.tokenizer_epochs}")
    logger.info(f"Device: {device}")
    logger.info(f"Distributed training: False")

    logger.info("Starting tokenizer fine-tuning training...")
    print("Starting tokenizer fine-tuning training...")
    best_val_loss = train_tokenizer(tokenizer, device, config, config.tokenizer_save_path, logger)

    final_msg = f"Tokenizer training completed! Best validation loss: {best_val_loss:.4f}\nModel saved to: {config.tokenizer_save_path}"
    logger.info(final_msg)
    print(final_msg)


if __name__ == "__main__":
    main()
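The gradient-accumulation loop in `train_tokenizer` above splits each loaded batch into `accumulation_steps` equal chunks via integer division of the batch size. A small sketch of that index arithmetic, using a hypothetical batch of 10 plain-Python elements in place of `ori_batch_x`:

```python
# Chunk j of a batch of B samples covers [j * (B // steps), (j + 1) * (B // steps)).
# Note that with integer division, the B % steps trailing samples are dropped
# whenever the batch size is not a multiple of accumulation_steps.
batch = list(range(10))      # stand-in for ori_batch_x with B = 10 samples
accumulation_steps = 3

chunks = []
for j in range(accumulation_steps):
    start_idx = j * (len(batch) // accumulation_steps)
    end_idx = (j + 1) * (len(batch) // accumulation_steps)
    chunks.append(batch[start_idx:end_idx])

print(chunks)  # [[0, 1, 2], [3, 4, 5], [6, 7, 8]] -- sample 9 is skipped
```

Scaling each micro-batch loss by `1 / accumulation_steps` before `backward()`, as the training loop does, keeps the accumulated gradient equal to the gradient of the mean loss over the full batch.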
361
finetune_csv/train_sequential.py
Normal file
@ -0,0 +1,361 @@
import os
import sys
import time
import argparse
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.distributed as dist

sys.path.append('../')
from model import Kronos, KronosTokenizer, KronosPredictor

from config_loader import CustomFinetuneConfig
from finetune_tokenizer import train_tokenizer, set_seed, setup_logging as setup_tokenizer_logging
from finetune_base_model import train_model, create_dataloaders, setup_logging as setup_basemodel_logging


class SequentialTrainer:

    def __init__(self, config_path: str = None):
        self.config = CustomFinetuneConfig(config_path)
        self.rank = int(os.environ.get("RANK", "0"))
        self.world_size = int(os.environ.get("WORLD_SIZE", "1"))
        self.local_rank = int(os.environ.get("LOCAL_RANK", str(self.config.device_id if hasattr(self.config, 'device_id') else 0)))
        self.device = self._setup_device()

        self.config.print_config_summary()

    def _setup_device(self):
        if self.config.use_cuda and torch.cuda.is_available():
            torch.cuda.set_device(self.local_rank)
            device = torch.device(f"cuda:{self.local_rank}")
        else:
            device = torch.device("cpu")

        if self.rank == 0:
            print(f"Using device: {device} (rank={self.rank}, world_size={self.world_size}, local_rank={self.local_rank})")
        return device

    def _setup_distributed(self):
        if self.world_size > 1 and torch.cuda.is_available():
            backend = os.environ.get("DIST_BACKEND", "nccl").lower()
            if not dist.is_initialized():
                dist.init_process_group(backend=backend)
            if self.rank == 0:
                print(f"Distributed training initialized: backend={backend}, world_size={self.world_size}")
        else:
            if self.rank == 0:
                print("Distributed training not enabled, using single GPU/CPU training")

    def _check_existing_models(self):
        tokenizer_exists = os.path.exists(self.config.tokenizer_best_model_path)
        basemodel_exists = os.path.exists(self.config.basemodel_best_model_path)

        print(f"Tokenizer model exists: {tokenizer_exists}")
        print(f"Basemodel model exists: {basemodel_exists}")

        return tokenizer_exists, basemodel_exists

    def _create_directories(self):
        os.makedirs(self.config.tokenizer_save_path, exist_ok=True)
        os.makedirs(self.config.basemodel_save_path, exist_ok=True)
        print(f"Created directory: {self.config.tokenizer_save_path}")
        print(f"Created directory: {self.config.basemodel_save_path}")

    def train_tokenizer_phase(self):
        print("\n" + "="*60)
        print("Starting Tokenizer Fine-tuning Phase")
        print("="*60)

        tokenizer_exists, _ = self._check_existing_models()
        if tokenizer_exists and self.config.skip_existing:
            print("Tokenizer model already exists, skipping training")
            return True

        log_dir = os.path.join(self.config.base_save_path, "logs")
        logger = setup_tokenizer_logging(self.config.exp_name, log_dir, self.rank)

        set_seed(self.config.seed)

        if getattr(self.config, 'pre_trained_tokenizer', True):
            logger.info("Loading pretrained tokenizer...")
            if self.rank == 0:
                print("Loading pretrained tokenizer...")
            tokenizer = KronosTokenizer.from_pretrained(self.config.pretrained_tokenizer_path)
        else:
            if self.rank == 0:
                print("pre_trained_tokenizer=False, randomly initializing Tokenizer architecture")
            import json
            cfg_path = os.path.join(self.config.pretrained_tokenizer_path, 'config.json')
            with open(cfg_path, 'r') as f:
                arch = json.load(f)
            tokenizer = KronosTokenizer(
                d_in=arch.get('d_in', 6),
                d_model=arch.get('d_model', 256),
                n_heads=arch.get('n_heads', 4),
                ff_dim=arch.get('ff_dim', 512),
                n_enc_layers=arch.get('n_enc_layers', 4),
                n_dec_layers=arch.get('n_dec_layers', 4),
                ffn_dropout_p=arch.get('ffn_dropout_p', 0.0),
                attn_dropout_p=arch.get('attn_dropout_p', 0.0),
                resid_dropout_p=arch.get('resid_dropout_p', 0.0),
                s1_bits=arch.get('s1_bits', 10),
                s2_bits=arch.get('s2_bits', 10),
                beta=arch.get('beta', 0.05),
                gamma0=arch.get('gamma0', 1.0),
                gamma=arch.get('gamma', 1.1),
                zeta=arch.get('zeta', 0.05),
                group_size=arch.get('group_size', 4)
            )
        tokenizer = tokenizer.to(self.device)

        model_size = sum(p.numel() for p in tokenizer.parameters())
        logger.info(f"Tokenizer parameters: {model_size:,}")
        if self.rank == 0:
            print(f"Tokenizer parameters: {model_size:,}")

        logger.info("=== Training Configuration ===")
        logger.info(f"Data path: {self.config.data_path}")
        logger.info(f"Lookback window: {self.config.lookback_window}")
        logger.info(f"Predict window: {self.config.predict_window}")
        logger.info(f"Batch size: {self.config.batch_size}")
        logger.info(f"Learning rate: {self.config.tokenizer_learning_rate}")
        logger.info(f"Training epochs: {self.config.tokenizer_epochs}")
        logger.info(f"Device: {self.device}")
        logger.info(f"Distributed training: False")

        logger.info("Starting tokenizer fine-tuning training...")
        if self.rank == 0:
            print("Starting tokenizer fine-tuning training...")
        start_time = time.time()
        best_val_loss = train_tokenizer(
            tokenizer,
            self.device,
            self.config,
            self.config.tokenizer_save_path,
            logger,
        )
        training_time = time.time() - start_time

        final_msg = f"Tokenizer training completed! Best validation loss: {best_val_loss:.4f}\nTraining time: {training_time/60:.2f} minutes\nModel saved to: {self.config.tokenizer_save_path}"
        logger.info(final_msg)
        if self.rank == 0:
            print(f"\n{final_msg}")

        return True

    def train_basemodel_phase(self):
        print("\n" + "="*60)
        print("Starting Basemodel Fine-tuning Phase")
        print("="*60)

        if getattr(self.config, 'pre_trained_tokenizer', True):
            if not os.path.exists(self.config.finetuned_tokenizer_path):
                raise FileNotFoundError(f"Fine-tuned tokenizer does not exist: {self.config.finetuned_tokenizer_path}")

        _, basemodel_exists = self._check_existing_models()
        if basemodel_exists and self.config.skip_existing:
            print("Basemodel model already exists, skipping training")
            return True

        log_dir = os.path.join(self.config.base_save_path, "logs")
        logger = setup_basemodel_logging(self.config.exp_name, log_dir, self.rank)

        set_seed(self.config.seed)

        if getattr(self.config, 'pre_trained_tokenizer', True):
            logger.info("Loading fine-tuned tokenizer...")
            if self.rank == 0:
                print("Loading fine-tuned tokenizer...")
            tokenizer = KronosTokenizer.from_pretrained(self.config.finetuned_tokenizer_path)
        else:
            if self.rank == 0:
                print("pre_trained_tokenizer=False, randomly initializing Tokenizer architecture for Predictor training")
            import json
            cfg_path = os.path.join(self.config.pretrained_tokenizer_path, 'config.json')
            with open(cfg_path, 'r') as f:
                arch = json.load(f)
            tokenizer = KronosTokenizer(
                d_in=arch.get('d_in', 6),
                d_model=arch.get('d_model', 256),
                n_heads=arch.get('n_heads', 4),
                ff_dim=arch.get('ff_dim', 512),
                n_enc_layers=arch.get('n_enc_layers', 4),
                n_dec_layers=arch.get('n_dec_layers', 4),
                ffn_dropout_p=arch.get('ffn_dropout_p', 0.0),
                attn_dropout_p=arch.get('attn_dropout_p', 0.0),
                resid_dropout_p=arch.get('resid_dropout_p', 0.0),
                s1_bits=arch.get('s1_bits', 10),
                s2_bits=arch.get('s2_bits', 10),
                beta=arch.get('beta', 0.05),
                gamma0=arch.get('gamma0', 1.0),
                gamma=arch.get('gamma', 1.1),
                zeta=arch.get('zeta', 0.05),
                group_size=arch.get('group_size', 4)
            )
        tokenizer = tokenizer.to(self.device)

        if getattr(self.config, 'pre_trained_predictor', True):
            logger.info("Loading pretrained predictor...")
            if self.rank == 0:
                print("Loading pretrained predictor...")
            model = Kronos.from_pretrained(self.config.pretrained_predictor_path)
        else:
            if self.rank == 0:
                print("pre_trained_predictor=False, randomly initializing Predictor architecture")
            import json
            cfg_path = os.path.join(self.config.pretrained_predictor_path, 'config.json')
            with open(cfg_path, 'r') as f:
                arch = json.load(f)
            print("model_config: ", arch)
            model = Kronos(
                s1_bits=arch.get('s1_bits', 10),
                s2_bits=arch.get('s2_bits', 10),
                n_layers=arch.get('n_layers', 12),
                d_model=arch.get('d_model', 832),
                n_heads=arch.get('n_heads', 16),
                ff_dim=arch.get('ff_dim', 2048),
                ffn_dropout_p=arch.get('ffn_dropout_p', 0.2),
                attn_dropout_p=arch.get('attn_dropout_p', 0.0),
                resid_dropout_p=arch.get('resid_dropout_p', 0.2),
                token_dropout_p=arch.get('token_dropout_p', 0.0),
                learn_te=arch.get('learn_te', True)
            )
        model = model.to(self.device)

        model_size = sum(p.numel() for p in model.parameters())
        logger.info(f"Model parameters: {model_size:,}")
        if self.rank == 0:
            print(f"Model parameters: {model_size:,}")

        logger.info("=== Training Configuration ===")
        logger.info(f"Data path: {self.config.data_path}")
        logger.info(f"Lookback window: {self.config.lookback_window}")
        logger.info(f"Predict window: {self.config.predict_window}")
        logger.info(f"Batch size: {self.config.batch_size}")
        logger.info(f"Learning rate: {self.config.predictor_learning_rate}")
        logger.info(f"Training epochs: {self.config.basemodel_epochs}")
        logger.info(f"Device: {self.device}")
        logger.info(f"Tokenizer path: {self.config.finetuned_tokenizer_path}")
        logger.info(f"Pretrained model path: {self.config.pretrained_predictor_path}")

        logger.info("Starting fine-tuning training...")
        if self.rank == 0:
            print("Starting fine-tuning training...")
        start_time = time.time()
        best_val_loss = train_model(
            model,
            tokenizer,
            self.device,
            self.config,
            self.config.basemodel_save_path,
            logger,
        )
        training_time = time.time() - start_time
|
||||
|
||||
final_msg = f"Basemodel training completed! Best validation loss: {best_val_loss:.4f}\nTraining time: {training_time/60:.2f} minutes\nModel saved to: {self.config.basemodel_save_path}"
|
||||
logger.info(final_msg)
|
||||
if self.rank == 0:
|
||||
print(f"\n{final_msg}")
|
||||
|
||||
return True
|
||||
|
||||
def run_training(self):
|
||||
if self.rank == 0:
|
||||
print("Starting Kronos model sequential fine-tuning training")
|
||||
print(f"Experiment name: {self.config.experiment_name}")
|
||||
print(f"Experiment description: {self.config.experiment_description}")
|
||||
|
||||
self._setup_distributed()
|
||||
|
||||
self._create_directories()
|
||||
|
||||
tokenizer_exists, basemodel_exists = self._check_existing_models()
|
||||
|
||||
total_start_time = time.time()
|
||||
|
||||
try:
|
||||
if self.config.train_tokenizer:
|
||||
success = self.train_tokenizer_phase()
|
||||
if not success:
|
||||
print("Tokenizer training failed, terminating training")
|
||||
return False
|
||||
else:
|
||||
print("Skipping Tokenizer training phase")
|
||||
|
||||
if self.config.train_basemodel:
|
||||
success = self.train_basemodel_phase()
|
||||
if not success:
|
||||
print("Basemodel training failed, terminating training")
|
||||
return False
|
||||
else:
|
||||
print("Skipping Basemodel training phase")
|
||||
|
||||
total_time = time.time() - total_start_time
|
||||
|
||||
if self.rank == 0:
|
||||
print("\n" + "="*60)
|
||||
print("Training completed!")
|
||||
print("="*60)
|
||||
print(f"Total training time: {total_time/60:.2f} minutes")
|
||||
print(f"Tokenizer model: {self.config.tokenizer_best_model_path}")
|
||||
print(f"Basemodel model: {self.config.basemodel_best_model_path}")
|
||||
print("="*60)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
if self.rank == 0:
|
||||
print(f"Error occurred during training: {str(e)}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
finally:
|
||||
pass
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='Kronos Model Sequential Fine-tuning Training')
|
||||
parser.add_argument('--config', type=str, default='config.yaml',
|
||||
help='Configuration file path (default: config.yaml)')
|
||||
parser.add_argument('--skip-tokenizer', action='store_true',
|
||||
help='Skip tokenizer training phase')
|
||||
parser.add_argument('--skip-basemodel', action='store_true',
|
||||
help='Skip basemodel training phase')
|
||||
parser.add_argument('--skip-existing', action='store_true',
|
||||
help='Skip training for existing models')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
trainer = SequentialTrainer(args.config)
|
||||
|
||||
if args.skip_tokenizer:
|
||||
trainer.config.train_tokenizer = False
|
||||
if args.skip_basemodel:
|
||||
trainer.config.train_basemodel = False
|
||||
if args.skip_existing:
|
||||
trainer.config.skip_existing = True
|
||||
|
||||
success = trainer.run_training()
|
||||
|
||||
if success:
|
||||
print("Training completed successfully!")
|
||||
if dist.is_available() and dist.is_initialized():
|
||||
dist.barrier()
|
||||
dist.destroy_process_group()
|
||||
sys.exit(0)
|
||||
else:
|
||||
print("Training failed!")
|
||||
if dist.is_available() and dist.is_initialized():
|
||||
try:
|
||||
dist.barrier()
|
||||
dist.destroy_process_group()
|
||||
except Exception:
|
||||
pass
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||