Compare commits

...

10 Commits

Author SHA1 Message Date
2b5bcc55f8 init 2025-10-29 14:11:27 +08:00
ShiYu
eeb3168f71
Merge pull request #167 from SoYuCry/master
fix: define missing split_token in HierarchicalEmbedding
2025-10-26 20:30:18 +08:00
YuCry
a7e294cc56 fix: define missing split_token in HierarchicalEmbedding 2025-10-25 23:37:30 +08:00
ShiYu
a5f5aba12d
Merge pull request #152 from RahulPatel2727/master
Update README.md
2025-10-19 16:25:33 +08:00
Rahul Patel
b4e24f3e1b
Update README.md 2025-10-18 10:25:56 +05:30
ShiYu
082ab7ef62
Merge pull request #138 from Luciferbobo/master
add CSV-based finetuning pipeline for Kronos models
2025-10-12 17:06:30 +08:00
zhangboyu1
7f658d9672 update config & readme 2025-10-09 18:14:04 +08:00
zhangboyu1
166b4162fb add example data 2025-10-09 17:36:56 +08:00
zhangboyu1
a8df339586 update readme 2025-10-09 16:35:59 +08:00
zhangboyu1
38b5176cb7 update readme & figs 2025-10-09 16:24:35 +08:00
67 changed files with 439 additions and 73003 deletions

3
.gitignore vendored
View File

@ -45,7 +45,6 @@ Desktop.ini
# Data files (large files) # Data files (large files)
*.feather *.feather
*.csv
*.parquet *.parquet
*.h5 *.h5
*.hdf5 *.hdf5
@ -74,3 +73,5 @@ venv.bak/
*.temp *.temp
temp/ temp/
tmp/ tmp/
figures

View File

@ -332,3 +332,4 @@ This project is licensed under the [MIT License](./LICENSE).

File diff suppressed because it is too large Load Diff

View File

@ -1,72 +0,0 @@
import pandas as pd
import matplotlib.pyplot as plt
import sys
sys.path.append("../")
from model import Kronos, KronosTokenizer, KronosPredictor
def plot_prediction(kline_df, pred_df):
pred_df.index = kline_df.index[-pred_df.shape[0]:]
sr_close = kline_df['close']
sr_pred_close = pred_df['close']
sr_close.name = 'Ground Truth'
sr_pred_close.name = "Prediction"
sr_volume = kline_df['volume']
sr_pred_volume = pred_df['volume']
sr_volume.name = 'Ground Truth'
sr_pred_volume.name = "Prediction"
close_df = pd.concat([sr_close, sr_pred_close], axis=1)
volume_df = pd.concat([sr_volume, sr_pred_volume], axis=1)
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 6), sharex=True)
ax1.plot(close_df['Ground Truth'], label='Ground Truth', color='blue', linewidth=1.5)
ax1.plot(close_df['Prediction'], label='Prediction', color='red', linewidth=1.5)
ax1.set_ylabel('Close Price', fontsize=14)
ax1.legend(loc='lower left', fontsize=12)
ax1.grid(True)
ax2.plot(volume_df['Ground Truth'], label='Ground Truth', color='blue', linewidth=1.5)
ax2.plot(volume_df['Prediction'], label='Prediction', color='red', linewidth=1.5)
ax2.set_ylabel('Volume', fontsize=14)
ax2.legend(loc='upper left', fontsize=12)
ax2.grid(True)
plt.tight_layout()
plt.show()
# 1. Load Model and Tokenizer
tokenizer = KronosTokenizer.from_pretrained('/home/csc/huggingface/Kronos-Tokenizer-base/')
model = Kronos.from_pretrained("/home/csc/huggingface/Kronos-base/")
# 2. Instantiate Predictor
predictor = KronosPredictor(model, tokenizer, device="cuda:0", max_context=512)
# 3. Prepare Data
df = pd.read_csv("./data/XSHG_5min_600977.csv")
df['timestamps'] = pd.to_datetime(df['timestamps'])
lookback = 400
pred_len = 120
dfs = []
xtsp = []
ytsp = []
for i in range(5):
idf = df.loc[(i*400):(i*400+lookback-1), ['open', 'high', 'low', 'close', 'volume', 'amount']]
i_x_timestamp = df.loc[(i*400):(i*400+lookback-1), 'timestamps']
i_y_timestamp = df.loc[(i*400+lookback):(i*400+lookback+pred_len-1), 'timestamps']
dfs.append(idf)
xtsp.append(i_x_timestamp)
ytsp.append(i_y_timestamp)
pred_df = predictor.predict_batch(
df_list=dfs,
x_timestamp_list=xtsp,
y_timestamp_list=ytsp,
pred_len=pred_len,
)

View File

@ -1,80 +0,0 @@
import pandas as pd
import matplotlib.pyplot as plt
import sys
sys.path.append("../")
from model import Kronos, KronosTokenizer, KronosPredictor
def plot_prediction(kline_df, pred_df):
pred_df.index = kline_df.index[-pred_df.shape[0]:]
sr_close = kline_df['close']
sr_pred_close = pred_df['close']
sr_close.name = 'Ground Truth'
sr_pred_close.name = "Prediction"
sr_volume = kline_df['volume']
sr_pred_volume = pred_df['volume']
sr_volume.name = 'Ground Truth'
sr_pred_volume.name = "Prediction"
close_df = pd.concat([sr_close, sr_pred_close], axis=1)
volume_df = pd.concat([sr_volume, sr_pred_volume], axis=1)
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 6), sharex=True)
ax1.plot(close_df['Ground Truth'], label='Ground Truth', color='blue', linewidth=1.5)
ax1.plot(close_df['Prediction'], label='Prediction', color='red', linewidth=1.5)
ax1.set_ylabel('Close Price', fontsize=14)
ax1.legend(loc='lower left', fontsize=12)
ax1.grid(True)
ax2.plot(volume_df['Ground Truth'], label='Ground Truth', color='blue', linewidth=1.5)
ax2.plot(volume_df['Prediction'], label='Prediction', color='red', linewidth=1.5)
ax2.set_ylabel('Volume', fontsize=14)
ax2.legend(loc='upper left', fontsize=12)
ax2.grid(True)
plt.tight_layout()
plt.show()
# 1. Load Model and Tokenizer
tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
model = Kronos.from_pretrained("NeoQuasar/Kronos-small")
# 2. Instantiate Predictor
predictor = KronosPredictor(model, tokenizer, device="cuda:0", max_context=512)
# 3. Prepare Data
df = pd.read_csv("./data/XSHG_5min_600977.csv")
df['timestamps'] = pd.to_datetime(df['timestamps'])
lookback = 400
pred_len = 120
x_df = df.loc[:lookback-1, ['open', 'high', 'low', 'close', 'volume', 'amount']]
x_timestamp = df.loc[:lookback-1, 'timestamps']
y_timestamp = df.loc[lookback:lookback+pred_len-1, 'timestamps']
# 4. Make Prediction
pred_df = predictor.predict(
df=x_df,
x_timestamp=x_timestamp,
y_timestamp=y_timestamp,
pred_len=pred_len,
T=1.0,
top_p=0.9,
sample_count=1,
verbose=True
)
# 5. Visualize Results
print("Forecasted Data Head:")
print(pred_df.head())
# Combine historical and forecasted data for plotting
kline_df = df.loc[:lookback+pred_len-1]
# visualize
plot_prediction(kline_df, pred_df)

View File

@ -1,68 +0,0 @@
import pandas as pd
import matplotlib.pyplot as plt
import sys
sys.path.append("../")
from model import Kronos, KronosTokenizer, KronosPredictor
def plot_prediction(kline_df, pred_df):
pred_df.index = kline_df.index[-pred_df.shape[0]:]
sr_close = kline_df['close']
sr_pred_close = pred_df['close']
sr_close.name = 'Ground Truth'
sr_pred_close.name = "Prediction"
close_df = pd.concat([sr_close, sr_pred_close], axis=1)
fig, ax = plt.subplots(1, 1, figsize=(8, 4))
ax.plot(close_df['Ground Truth'], label='Ground Truth', color='blue', linewidth=1.5)
ax.plot(close_df['Prediction'], label='Prediction', color='red', linewidth=1.5)
ax.set_ylabel('Close Price', fontsize=14)
ax.legend(loc='lower left', fontsize=12)
ax.grid(True)
plt.tight_layout()
plt.show()
# 1. Load Model and Tokenizer
tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
model = Kronos.from_pretrained("NeoQuasar/Kronos-small")
# 2. Instantiate Predictor
predictor = KronosPredictor(model, tokenizer, device="cuda:0", max_context=512)
# 3. Prepare Data
df = pd.read_csv("./data/XSHG_5min_600977.csv")
df['timestamps'] = pd.to_datetime(df['timestamps'])
lookback = 400
pred_len = 120
x_df = df.loc[:lookback-1, ['open', 'high', 'low', 'close']]
x_timestamp = df.loc[:lookback-1, 'timestamps']
y_timestamp = df.loc[lookback:lookback+pred_len-1, 'timestamps']
# 4. Make Prediction
pred_df = predictor.predict(
df=x_df,
x_timestamp=x_timestamp,
y_timestamp=y_timestamp,
pred_len=pred_len,
T=1.0,
top_p=0.9,
sample_count=1,
verbose=True
)
# 5. Visualize Results
print("Forecasted Data Head:")
print(pred_df.head())
# Combine historical and forecasted data for plotting
kline_df = df.loc[:lookback+pred_len-1]
# visualize
plot_prediction(kline_df, pred_df)

Binary file not shown.

Before

Width:  |  Height:  |  Size: 488 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 851 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.2 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 189 KiB

View File

@ -1,131 +0,0 @@
import os
class Config:
"""
Configuration class for the entire project.
"""
def __init__(self):
# =================================================================
# Data & Feature Parameters
# =================================================================
# TODO: Update this path to your Qlib data directory.
self.qlib_data_path = "~/.qlib/qlib_data/cn_data"
self.instrument = 'csi300'
# Overall time range for data loading from Qlib.
self.dataset_begin_time = "2011-01-01"
self.dataset_end_time = '2025-06-05'
# Sliding window parameters for creating samples.
self.lookback_window = 90 # Number of past time steps for input.
self.predict_window = 10 # Number of future time steps for prediction.
self.max_context = 512 # Maximum context length for the model.
# Features to be used from the raw data.
self.feature_list = ['open', 'high', 'low', 'close', 'vol', 'amt']
# Time-based features to be generated.
self.time_feature_list = ['minute', 'hour', 'weekday', 'day', 'month']
# =================================================================
# Dataset Splitting & Paths
# =================================================================
# Note: The validation/test set starts earlier than the training/validation set ends
# to account for the `lookback_window`.
self.train_time_range = ["2011-01-01", "2022-12-31"]
self.val_time_range = ["2022-09-01", "2024-06-30"]
self.test_time_range = ["2024-04-01", "2025-06-05"]
self.backtest_time_range = ["2024-07-01", "2025-06-05"]
# TODO: Directory to save the processed, pickled datasets.
self.dataset_path = "./data/processed_datasets"
# =================================================================
# Training Hyperparameters
# =================================================================
self.clip = 5.0 # Clipping value for normalized data to prevent outliers.
self.epochs = 30
self.log_interval = 100 # Log training status every N batches.
self.batch_size = 50 # Batch size per GPU.
# Number of samples to draw for one "epoch" of training/validation.
# This is useful for large datasets where a true epoch is too long.
self.n_train_iter = 2000 * self.batch_size
self.n_val_iter = 400 * self.batch_size
# Learning rates for different model components.
self.tokenizer_learning_rate = 2e-4
self.predictor_learning_rate = 4e-5
# Gradient accumulation to simulate a larger batch size.
self.accumulation_steps = 1
# AdamW optimizer parameters.
self.adam_beta1 = 0.9
self.adam_beta2 = 0.95
self.adam_weight_decay = 0.1
# Miscellaneous
self.seed = 100 # Global random seed for reproducibility.
# =================================================================
# Experiment Logging & Saving
# =================================================================
self.use_comet = True # Set to False if you don't want to use Comet ML
self.comet_config = {
# It is highly recommended to load secrets from environment variables
# for security purposes. Example: os.getenv("COMET_API_KEY")
"api_key": "YOUR_COMET_API_KEY",
"project_name": "Kronos-Finetune-Demo",
"workspace": "your_comet_workspace" # TODO: Change to your Comet ML workspace name
}
self.comet_tag = 'finetune_demo'
self.comet_name = 'finetune_demo'
# Base directory for saving model checkpoints and results.
# Using a general 'outputs' directory is a common practice.
self.save_path = "./outputs/models"
self.tokenizer_save_folder_name = 'finetune_tokenizer_demo'
self.predictor_save_folder_name = 'finetune_predictor_demo'
self.backtest_save_folder_name = 'finetune_backtest_demo'
# Path for backtesting results.
self.backtest_result_path = "./outputs/backtest_results"
# =================================================================
# Model & Checkpoint Paths
# =================================================================
# TODO: Update these paths to your pretrained model locations.
# These can be local paths or Hugging Face Hub model identifiers.
self.pretrained_tokenizer_path = "path/to/your/Kronos-Tokenizer-base"
self.pretrained_predictor_path = "path/to/your/Kronos-small"
# Paths to the fine-tuned models, derived from the save_path.
# These will be generated automatically during training.
self.finetuned_tokenizer_path = f"{self.save_path}/{self.tokenizer_save_folder_name}/checkpoints/best_model"
self.finetuned_predictor_path = f"{self.save_path}/{self.predictor_save_folder_name}/checkpoints/best_model"
# =================================================================
# Backtesting Parameters
# =================================================================
self.backtest_n_symbol_hold = 50 # Number of symbols to hold in the portfolio.
self.backtest_n_symbol_drop = 5 # Number of symbols to drop from the pool.
self.backtest_hold_thresh = 5 # Minimum holding period for a stock.
self.inference_T = 0.6
self.inference_top_p = 0.9
self.inference_top_k = 0
self.inference_sample_count = 5
self.backtest_batch_size = 1000
self.backtest_benchmark = self._set_benchmark(self.instrument)
def _set_benchmark(self, instrument):
dt_benchmark = {
'csi800': "SH000906",
'csi1000': "SH000852",
'csi300': "SH000300",
}
if instrument in dt_benchmark:
return dt_benchmark[instrument]
else:
raise ValueError(f"Benchmark not defined for instrument: {instrument}")

View File

@ -1,145 +0,0 @@
import pickle
import random
import numpy as np
import torch
from torch.utils.data import Dataset
from config import Config
class QlibDataset(Dataset):
"""
A PyTorch Dataset for handling Qlib financial time series data.
This dataset pre-computes all possible start indices for sliding windows
and then randomly samples from them during training/validation.
Args:
data_type (str): The type of dataset to load, either 'train' or 'val'.
Raises:
ValueError: If `data_type` is not 'train' or 'val'.
"""
def __init__(self, data_type: str = 'train'):
self.config = Config()
if data_type not in ['train', 'val']:
raise ValueError("data_type must be 'train' or 'val'")
self.data_type = data_type
# Use a dedicated random number generator for sampling to avoid
# interfering with other random processes (e.g., in model initialization).
self.py_rng = random.Random(self.config.seed)
# Set paths and number of samples based on the data type.
if data_type == 'train':
self.data_path = f"{self.config.dataset_path}/train_data.pkl"
self.n_samples = self.config.n_train_iter
else:
self.data_path = f"{self.config.dataset_path}/val_data.pkl"
self.n_samples = self.config.n_val_iter
with open(self.data_path, 'rb') as f:
self.data = pickle.load(f)
self.window = self.config.lookback_window + self.config.predict_window + 1
self.symbols = list(self.data.keys())
self.feature_list = self.config.feature_list
self.time_feature_list = self.config.time_feature_list
# Pre-compute all possible (symbol, start_index) pairs.
self.indices = []
print(f"[{data_type.upper()}] Pre-computing sample indices...")
for symbol in self.symbols:
df = self.data[symbol].reset_index()
series_len = len(df)
num_samples = series_len - self.window + 1
if num_samples > 0:
# Generate time features and store them directly in the dataframe.
df['minute'] = df['datetime'].dt.minute
df['hour'] = df['datetime'].dt.hour
df['weekday'] = df['datetime'].dt.weekday
df['day'] = df['datetime'].dt.day
df['month'] = df['datetime'].dt.month
# Keep only necessary columns to save memory.
self.data[symbol] = df[self.feature_list + self.time_feature_list]
# Add all valid starting indices for this symbol to the global list.
for i in range(num_samples):
self.indices.append((symbol, i))
# The effective dataset size is the minimum of the configured iterations
# and the total number of available samples.
self.n_samples = min(self.n_samples, len(self.indices))
print(f"[{data_type.upper()}] Found {len(self.indices)} possible samples. Using {self.n_samples} per epoch.")
def set_epoch_seed(self, epoch: int):
"""
Sets a new seed for the random sampler for each epoch. This is crucial
for reproducibility in distributed training.
Args:
epoch (int): The current epoch number.
"""
epoch_seed = self.config.seed + epoch
self.py_rng.seed(epoch_seed)
def __len__(self) -> int:
"""Returns the number of samples per epoch."""
return self.n_samples
def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
"""
Retrieves a random sample from the dataset.
Note: The `idx` argument is ignored. Instead, a random index is drawn
from the pre-computed `self.indices` list using `self.py_rng`. This
ensures random sampling over the entire dataset for each call.
Args:
idx (int): Ignored.
Returns:
tuple[torch.Tensor, torch.Tensor]: A tuple containing:
- x_tensor (torch.Tensor): The normalized feature tensor.
- x_stamp_tensor (torch.Tensor): The time feature tensor.
"""
# Select a random sample from the entire pool of indices.
random_idx = self.py_rng.randint(0, len(self.indices) - 1)
symbol, start_idx = self.indices[random_idx]
# Extract the sliding window from the dataframe.
df = self.data[symbol]
end_idx = start_idx + self.window
win_df = df.iloc[start_idx:end_idx]
# Separate main features and time features.
x = win_df[self.feature_list].values.astype(np.float32)
x_stamp = win_df[self.time_feature_list].values.astype(np.float32)
# Perform instance-level normalization.
x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0)
x = (x - x_mean) / (x_std + 1e-5)
x = np.clip(x, -self.config.clip, self.config.clip)
# Convert to PyTorch tensors.
x_tensor = torch.from_numpy(x)
x_stamp_tensor = torch.from_numpy(x_stamp)
return x_tensor, x_stamp_tensor
if __name__ == '__main__':
# Example usage and verification.
print("Creating training dataset instance...")
train_dataset = QlibDataset(data_type='train')
print(f"Dataset length: {len(train_dataset)}")
if len(train_dataset) > 0:
try_x, try_x_stamp = train_dataset[100] # Index 100 is ignored.
print(f"Sample feature shape: {try_x.shape}")
print(f"Sample time feature shape: {try_x_stamp.shape}")
else:
print("Dataset is empty.")

View File

@ -1,130 +0,0 @@
import os
import pickle
import numpy as np
import pandas as pd
import qlib
from qlib.config import REG_CN
from qlib.data import D
from qlib.data.dataset.loader import QlibDataLoader
from tqdm import trange
from config import Config
class QlibDataPreprocessor:
"""
A class to handle the loading, processing, and splitting of Qlib financial data.
"""
def __init__(self):
"""Initializes the preprocessor with configuration and data fields."""
self.config = Config()
self.data_fields = ['open', 'close', 'high', 'low', 'volume', 'vwap']
self.data = {} # A dictionary to store processed data for each symbol.
def initialize_qlib(self):
"""Initializes the Qlib environment."""
print("Initializing Qlib...")
qlib.init(provider_uri=self.config.qlib_data_path, region=REG_CN)
def load_qlib_data(self):
"""
Loads raw data from Qlib, processes it symbol by symbol, and stores
it in the `self.data` attribute.
"""
print("Loading and processing data from Qlib...")
data_fields_qlib = ['$' + f for f in self.data_fields]
cal: np.ndarray = D.calendar()
# Determine the actual start and end times to load, including buffer for lookback and predict windows.
start_index = cal.searchsorted(pd.Timestamp(self.config.dataset_begin_time))
end_index = cal.searchsorted(pd.Timestamp(self.config.dataset_end_time))
# Check if start_index lookbackw_window will cause negative index
adjusted_start_index = max(start_index - self.config.lookback_window, 0)
real_start_time = cal[adjusted_start_index]
# Check if end_index exceeds the range of the array
if end_index >= len(cal):
end_index = len(cal) - 1
elif cal[end_index] != pd.Timestamp(self.config.dataset_end_time):
end_index -= 1
# Check if end_index+predictw_window will exceed the range of the array
adjusted_end_index = min(end_index + self.config.predict_window, len(cal) - 1)
real_end_time = cal[adjusted_end_index]
# Load data using Qlib's data loader.
data_df = QlibDataLoader(config=data_fields_qlib).load(
self.config.instrument, real_start_time, real_end_time
)
data_df = data_df.stack().unstack(level=1) # Reshape for easier access.
symbol_list = list(data_df.columns)
for i in trange(len(symbol_list), desc="Processing Symbols"):
symbol = symbol_list[i]
symbol_df = data_df[symbol]
# Pivot the table to have features as columns and datetime as index.
symbol_df = symbol_df.reset_index().rename(columns={'level_1': 'field'})
symbol_df = pd.pivot(symbol_df, index='datetime', columns='field', values=symbol)
symbol_df = symbol_df.rename(columns={f'${field}': field for field in self.data_fields})
# Calculate amount and select final features.
symbol_df['vol'] = symbol_df['volume']
symbol_df['amt'] = (symbol_df['open'] + symbol_df['high'] + symbol_df['low'] + symbol_df['close']) / 4 * symbol_df['vol']
symbol_df = symbol_df[self.config.feature_list]
# Filter out symbols with insufficient data.
symbol_df = symbol_df.dropna()
if len(symbol_df) < self.config.lookback_window + self.config.predict_window + 1:
continue
self.data[symbol] = symbol_df
def prepare_dataset(self):
"""
Splits the loaded data into train, validation, and test sets and saves them to disk.
"""
print("Splitting data into train, validation, and test sets...")
train_data, val_data, test_data = {}, {}, {}
symbol_list = list(self.data.keys())
for i in trange(len(symbol_list), desc="Preparing Datasets"):
symbol = symbol_list[i]
symbol_df = self.data[symbol]
# Define time ranges from config.
train_start, train_end = self.config.train_time_range
val_start, val_end = self.config.val_time_range
test_start, test_end = self.config.test_time_range
# Create boolean masks for each dataset split.
train_mask = (symbol_df.index >= train_start) & (symbol_df.index <= train_end)
val_mask = (symbol_df.index >= val_start) & (symbol_df.index <= val_end)
test_mask = (symbol_df.index >= test_start) & (symbol_df.index <= test_end)
# Apply masks to create the final datasets.
train_data[symbol] = symbol_df[train_mask]
val_data[symbol] = symbol_df[val_mask]
test_data[symbol] = symbol_df[test_mask]
# Save the datasets using pickle.
os.makedirs(self.config.dataset_path, exist_ok=True)
with open(f"{self.config.dataset_path}/train_data.pkl", 'wb') as f:
pickle.dump(train_data, f)
with open(f"{self.config.dataset_path}/val_data.pkl", 'wb') as f:
pickle.dump(val_data, f)
with open(f"{self.config.dataset_path}/test_data.pkl", 'wb') as f:
pickle.dump(test_data, f)
print("Datasets prepared and saved successfully.")
if __name__ == '__main__':
# This block allows the script to be run directly to perform data preprocessing.
preprocessor = QlibDataPreprocessor()
preprocessor.initialize_qlib()
preprocessor.load_qlib_data()
preprocessor.prepare_dataset()

View File

@ -1,362 +0,0 @@
import os
import sys
import argparse
import pickle
from collections import defaultdict
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import trange, tqdm
from matplotlib import pyplot as plt
import qlib
from qlib.config import REG_CN
from qlib.backtest import backtest, executor, CommonInfrastructure
from qlib.contrib.evaluate import risk_analysis
from qlib.contrib.strategy import TopkDropoutStrategy
from qlib.utils import flatten_dict
from qlib.utils.time import Freq
# Ensure project root is in the Python path
sys.path.append("../")
from config import Config
from model.kronos import Kronos, KronosTokenizer, auto_regressive_inference
# =================================================================================
# 1. Data Loading and Processing for Inference
# =================================================================================
class QlibTestDataset(Dataset):
"""
PyTorch Dataset for handling Qlib test data, specifically for inference.
This dataset iterates through all possible sliding windows sequentially. It also
yields metadata like symbol and timestamp, which are crucial for mapping
predictions back to the original time series.
"""
def __init__(self, data: dict, config: Config):
self.data = data
self.config = config
self.window_size = config.lookback_window + config.predict_window
self.symbols = list(self.data.keys())
self.feature_list = config.feature_list
self.time_feature_list = config.time_feature_list
self.indices = []
print("Preprocessing and building indices for test dataset...")
for symbol in self.symbols:
df = self.data[symbol].reset_index()
# Generate time features on-the-fly
df['minute'] = df['datetime'].dt.minute
df['hour'] = df['datetime'].dt.hour
df['weekday'] = df['datetime'].dt.weekday
df['day'] = df['datetime'].dt.day
df['month'] = df['datetime'].dt.month
self.data[symbol] = df # Store preprocessed dataframe
num_samples = len(df) - self.window_size + 1
if num_samples > 0:
for i in range(num_samples):
timestamp = df.iloc[i + self.config.lookback_window - 1]['datetime']
self.indices.append((symbol, i, timestamp))
def __len__(self) -> int:
return len(self.indices)
def __getitem__(self, idx: int):
symbol, start_idx, timestamp = self.indices[idx]
df = self.data[symbol]
context_end = start_idx + self.config.lookback_window
predict_end = context_end + self.config.predict_window
context_df = df.iloc[start_idx:context_end]
predict_df = df.iloc[context_end:predict_end]
x = context_df[self.feature_list].values.astype(np.float32)
x_stamp = context_df[self.time_feature_list].values.astype(np.float32)
y_stamp = predict_df[self.time_feature_list].values.astype(np.float32)
# Instance-level normalization, consistent with training
x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0)
x = (x - x_mean) / (x_std + 1e-5)
x = np.clip(x, -self.config.clip, self.config.clip)
return torch.from_numpy(x), torch.from_numpy(x_stamp), torch.from_numpy(y_stamp), symbol, timestamp
# =================================================================================
# 2. Backtesting Logic
# =================================================================================
class QlibBacktest:
"""
A wrapper class for conducting backtesting experiments using Qlib.
"""
def __init__(self, config: Config):
self.config = config
self.initialize_qlib()
def initialize_qlib(self):
"""Initializes the Qlib environment."""
print("Initializing Qlib for backtesting...")
qlib.init(provider_uri=self.config.qlib_data_path, region=REG_CN)
def run_single_backtest(self, signal_series: pd.Series) -> pd.DataFrame:
"""
Runs a single backtest for a given prediction signal.
Args:
signal_series (pd.Series): A pandas Series with a MultiIndex
(instrument, datetime) and prediction scores.
Returns:
pd.DataFrame: A DataFrame containing the performance report.
"""
strategy = TopkDropoutStrategy(
topk=self.config.backtest_n_symbol_hold,
n_drop=self.config.backtest_n_symbol_drop,
hold_thresh=self.config.backtest_hold_thresh,
signal=signal_series,
)
executor_config = {
"time_per_step": "day",
"generate_portfolio_metrics": True,
"delay_execution": True,
}
backtest_config = {
"start_time": self.config.backtest_time_range[0],
"end_time": self.config.backtest_time_range[1],
"account": 100_000_000,
"benchmark": self.config.backtest_benchmark,
"exchange_kwargs": {
"freq": "day", "limit_threshold": 0.095, "deal_price": "open",
"open_cost": 0.001, "close_cost": 0.0015, "min_cost": 5,
},
"executor": executor.SimulatorExecutor(**executor_config),
}
portfolio_metric_dict, _ = backtest(strategy=strategy, **backtest_config)
analysis_freq = "{0}{1}".format(*Freq.parse("day"))
report, _ = portfolio_metric_dict.get(analysis_freq)
# --- Analysis and Reporting ---
analysis = {
"excess_return_without_cost": risk_analysis(report["return"] - report["bench"], freq=analysis_freq),
"excess_return_with_cost": risk_analysis(report["return"] - report["bench"] - report["cost"], freq=analysis_freq),
}
print("\n--- Backtest Analysis ---")
print("Benchmark Return:", risk_analysis(report["bench"], freq=analysis_freq), sep='\n')
print("\nExcess Return (w/o cost):", analysis["excess_return_without_cost"], sep='\n')
print("\nExcess Return (w/ cost):", analysis["excess_return_with_cost"], sep='\n')
report_df = pd.DataFrame({
"cum_bench": report["bench"].cumsum(),
"cum_return_w_cost": (report["return"] - report["cost"]).cumsum(),
"cum_ex_return_w_cost": (report["return"] - report["bench"] - report["cost"]).cumsum(),
})
return report_df
def run_and_plot_results(self, signals: dict[str, pd.DataFrame]):
"""
Runs backtests for multiple signals and plots the cumulative return curves.
Args:
signals (dict[str, pd.DataFrame]): A dictionary where keys are signal names
and values are prediction DataFrames.
"""
return_df, ex_return_df, bench_df = pd.DataFrame(), pd.DataFrame(), pd.DataFrame()
for signal_name, pred_df in signals.items():
print(f"\nBacktesting signal: {signal_name}...")
pred_series = pred_df.stack()
pred_series.index.names = ['datetime', 'instrument']
pred_series = pred_series.swaplevel().sort_index()
report_df = self.run_single_backtest(pred_series)
return_df[signal_name] = report_df['cum_return_w_cost']
ex_return_df[signal_name] = report_df['cum_ex_return_w_cost']
if 'return' not in bench_df:
bench_df['return'] = report_df['cum_bench']
# Plotting results
fig, axes = plt.subplots(2, 1, figsize=(12, 8), sharex=True)
return_df.plot(ax=axes[0], title='Cumulative Return with Cost', grid=True)
axes[0].plot(bench_df['return'], label=self.config.instrument.upper(), color='black', linestyle='--')
axes[0].legend()
axes[0].set_ylabel("Cumulative Return")
ex_return_df.plot(ax=axes[1], title='Cumulative Excess Return with Cost', grid=True)
axes[1].legend()
axes[1].set_xlabel("Date")
axes[1].set_ylabel("Cumulative Excess Return")
plt.tight_layout()
plt.savefig("../figures/backtest_result_example.png", dpi=200)
plt.show()
# =================================================================================
# 3. Inference Logic
# =================================================================================
def load_models(config: dict) -> tuple[KronosTokenizer, Kronos]:
"""Loads the fine-tuned tokenizer and predictor model."""
device = torch.device(config['device'])
print(f"Loading models onto device: {device}...")
tokenizer = KronosTokenizer.from_pretrained(config['tokenizer_path']).to(device).eval()
model = Kronos.from_pretrained(config['model_path']).to(device).eval()
return tokenizer, model
def collate_fn_for_inference(batch):
"""
Custom collate function to handle batches containing Tensors, strings, and Timestamps.
Args:
batch (list): A list of samples, where each sample is the tuple returned by
QlibTestDataset.__getitem__.
Returns:
A single tuple containing the batched data.
"""
# Unzip the list of samples into separate lists for each data type
x, x_stamp, y_stamp, symbols, timestamps = zip(*batch)
# Stack the tensors to create a batch
x_batch = torch.stack(x, dim=0)
x_stamp_batch = torch.stack(x_stamp, dim=0)
y_stamp_batch = torch.stack(y_stamp, dim=0)
# Return the strings and timestamps as lists
return x_batch, x_stamp_batch, y_stamp_batch, list(symbols), list(timestamps)
def generate_predictions(config: dict, test_data: dict) -> dict[str, pd.DataFrame]:
"""
Runs inference on the test dataset to generate prediction signals.
Args:
config (dict): A dictionary containing inference parameters.
test_data (dict): The raw test data loaded from a pickle file.
Returns:
A dictionary where keys are signal types (e.g., 'mean', 'last') and
values are DataFrames of predictions (datetime index, symbol columns).
"""
tokenizer, model = load_models(config)
device = torch.device(config['device'])
# Use the Dataset and DataLoader for efficient batching and processing
dataset = QlibTestDataset(data=test_data, config=Config())
loader = DataLoader(
dataset,
batch_size=config['batch_size'] // config['sample_count'],
shuffle=False,
num_workers=os.cpu_count() // 2,
collate_fn=collate_fn_for_inference
)
results = defaultdict(list)
with torch.no_grad():
for x, x_stamp, y_stamp, symbols, timestamps in tqdm(loader, desc="Inference"):
preds = auto_regressive_inference(
tokenizer, model, x.to(device), x_stamp.to(device), y_stamp.to(device),
max_context=config['max_context'], pred_len=config['pred_len'], clip=config['clip'],
T=config['T'], top_k=config['top_k'], top_p=config['top_p'], sample_count=config['sample_count']
)
# You can try commenting on this line to keep the history data
preds = preds[:, -config['pred_len']:, :]
# The 'close' price is at index 3 in `feature_list`
last_day_close = x[:, -1, 3].numpy()
signals = {
'last': preds[:, -1, 3] - last_day_close,
'mean': np.mean(preds[:, :, 3], axis=1) - last_day_close,
'max': np.max(preds[:, :, 3], axis=1) - last_day_close,
'min': np.min(preds[:, :, 3], axis=1) - last_day_close,
}
for i in range(len(symbols)):
for sig_type, sig_values in signals.items():
results[sig_type].append((timestamps[i], symbols[i], sig_values[i]))
print("Post-processing predictions into DataFrames...")
prediction_dfs = {}
for sig_type, records in results.items():
df = pd.DataFrame(records, columns=['datetime', 'instrument', 'score'])
pivot_df = df.pivot_table(index='datetime', columns='instrument', values='score')
prediction_dfs[sig_type] = pivot_df.sort_index()
return prediction_dfs
# =================================================================================
# 4. Main Execution
# =================================================================================
def main():
"""Main function to set up config, run inference, and execute backtesting."""
parser = argparse.ArgumentParser(description="Run Kronos Inference and Backtesting")
parser.add_argument("--device", type=str, default="cuda:1", help="Device for inference (e.g., 'cuda:0', 'cpu')")
args = parser.parse_args()
# --- 1. Configuration Setup ---
base_config = Config()
# Create a dedicated dictionary for this run's configuration
run_config = {
'device': args.device,
'data_path': base_config.dataset_path,
'result_save_path': base_config.backtest_result_path,
'result_name': base_config.backtest_save_folder_name,
'tokenizer_path': base_config.finetuned_tokenizer_path,
'model_path': base_config.finetuned_predictor_path,
'max_context': base_config.max_context,
'pred_len': base_config.predict_window,
'clip': base_config.clip,
'T': base_config.inference_T,
'top_k': base_config.inference_top_k,
'top_p': base_config.inference_top_p,
'sample_count': base_config.inference_sample_count,
'batch_size': base_config.backtest_batch_size,
}
print("--- Running with Configuration ---")
for key, val in run_config.items():
print(f"{key:>20}: {val}")
print("-" * 35)
# --- 2. Load Data ---
test_data_path = os.path.join(run_config['data_path'], "test_data.pkl")
print(f"Loading test data from {test_data_path}...")
with open(test_data_path, 'rb') as f:
test_data = pickle.load(f)
print(test_data)
# --- 3. Generate Predictions ---
model_preds = generate_predictions(run_config, test_data)
# --- 4. Save Predictions ---
save_dir = os.path.join(run_config['result_save_path'], run_config['result_name'])
os.makedirs(save_dir, exist_ok=True)
predictions_file = os.path.join(save_dir, "predictions.pkl")
print(f"Saving prediction signals to {predictions_file}...")
with open(predictions_file, 'wb') as f:
pickle.dump(model_preds, f)
# --- 5. Run Backtesting ---
with open(predictions_file, 'rb') as f:
model_preds = pickle.load(f)
backtester = QlibBacktest(base_config)
backtester.run_and_plot_results(model_preds)
if __name__ == '__main__':
main()

View File

@ -1,244 +0,0 @@
import os
import sys
import json
import time
from time import gmtime, strftime
import torch.distributed as dist
import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
import comet_ml
# Ensure project root is in path
sys.path.append('../')
from config import Config
from dataset import QlibDataset
from model.kronos import KronosTokenizer, Kronos
# Import shared utilities
from utils.training_utils import (
setup_ddp,
cleanup_ddp,
set_seed,
get_model_size,
format_time
)
def create_dataloaders(config: dict, rank: int, world_size: int):
"""
Creates and returns distributed dataloaders for training and validation.
Args:
config (dict): A dictionary of configuration parameters.
rank (int): The global rank of the current process.
world_size (int): The total number of processes.
Returns:
tuple: (train_loader, val_loader, train_dataset, valid_dataset).
"""
print(f"[Rank {rank}] Creating distributed dataloaders...")
train_dataset = QlibDataset('train')
valid_dataset = QlibDataset('val')
print(f"[Rank {rank}] Train dataset size: {len(train_dataset)}, Validation dataset size: {len(valid_dataset)}")
train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True)
val_sampler = DistributedSampler(valid_dataset, num_replicas=world_size, rank=rank, shuffle=False)
train_loader = DataLoader(
train_dataset, batch_size=config['batch_size'], sampler=train_sampler,
num_workers=config.get('num_workers', 2), pin_memory=True, drop_last=True
)
val_loader = DataLoader(
valid_dataset, batch_size=config['batch_size'], sampler=val_sampler,
num_workers=config.get('num_workers', 2), pin_memory=True, drop_last=False
)
return train_loader, val_loader, train_dataset, valid_dataset
def train_model(model, tokenizer, device, config, save_dir, logger, rank, world_size):
"""
The main training and validation loop for the predictor.
"""
start_time = time.time()
if rank == 0:
effective_bs = config['batch_size'] * world_size
print(f"Effective BATCHSIZE per GPU: {config['batch_size']}, Total: {effective_bs}")
train_loader, val_loader, train_dataset, valid_dataset = create_dataloaders(config, rank, world_size)
optimizer = torch.optim.AdamW(
model.parameters(),
lr=config['predictor_learning_rate'],
betas=(config['adam_beta1'], config['adam_beta2']),
weight_decay=config['adam_weight_decay']
)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
optimizer, max_lr=config['predictor_learning_rate'],
steps_per_epoch=len(train_loader), epochs=config['epochs'],
pct_start=0.03, div_factor=10
)
best_val_loss = float('inf')
dt_result = {}
batch_idx_global = 0
for epoch_idx in range(config['epochs']):
epoch_start_time = time.time()
model.train()
train_loader.sampler.set_epoch(epoch_idx)
train_dataset.set_epoch_seed(epoch_idx * 10000 + rank)
valid_dataset.set_epoch_seed(0)
for i, (batch_x, batch_x_stamp) in enumerate(train_loader):
batch_x = batch_x.squeeze(0).to(device, non_blocking=True)
batch_x_stamp = batch_x_stamp.squeeze(0).to(device, non_blocking=True)
# Tokenize input data on-the-fly
with torch.no_grad():
token_seq_0, token_seq_1 = tokenizer.encode(batch_x, half=True)
# Prepare inputs and targets for the language model
token_in = [token_seq_0[:, :-1], token_seq_1[:, :-1]]
token_out = [token_seq_0[:, 1:], token_seq_1[:, 1:]]
# Forward pass and loss calculation
logits = model(token_in[0], token_in[1], batch_x_stamp[:, :-1, :])
loss, s1_loss, s2_loss = model.module.head.compute_loss(logits[0], logits[1], token_out[0], token_out[1])
# Backward pass and optimization
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=3.0)
optimizer.step()
scheduler.step()
# Logging (Master Process Only)
if rank == 0 and (batch_idx_global + 1) % config['log_interval'] == 0:
lr = optimizer.param_groups[0]['lr']
print(
f"[Rank {rank}, Epoch {epoch_idx + 1}/{config['epochs']}, Step {i + 1}/{len(train_loader)}] "
f"LR {lr:.6f}, Loss: {loss.item():.4f}"
)
if rank == 0 and logger:
lr = optimizer.param_groups[0]['lr']
logger.log_metric('train_predictor_loss_batch', loss.item(), step=batch_idx_global)
logger.log_metric('train_S1_loss_each_batch', s1_loss.item(), step=batch_idx_global)
logger.log_metric('train_S2_loss_each_batch', s2_loss.item(), step=batch_idx_global)
logger.log_metric('predictor_learning_rate', lr, step=batch_idx_global)
batch_idx_global += 1
# --- Validation Loop ---
model.eval()
tot_val_loss_sum_rank = 0.0
val_batches_processed_rank = 0
with torch.no_grad():
for batch_x, batch_x_stamp in val_loader:
batch_x = batch_x.squeeze(0).to(device, non_blocking=True)
batch_x_stamp = batch_x_stamp.squeeze(0).to(device, non_blocking=True)
token_seq_0, token_seq_1 = tokenizer.encode(batch_x, half=True)
token_in = [token_seq_0[:, :-1], token_seq_1[:, :-1]]
token_out = [token_seq_0[:, 1:], token_seq_1[:, 1:]]
logits = model(token_in[0], token_in[1], batch_x_stamp[:, :-1, :])
val_loss, _, _ = model.module.head.compute_loss(logits[0], logits[1], token_out[0], token_out[1])
tot_val_loss_sum_rank += val_loss.item()
val_batches_processed_rank += 1
# Reduce validation metrics
val_loss_sum_tensor = torch.tensor(tot_val_loss_sum_rank, device=device)
val_batches_tensor = torch.tensor(val_batches_processed_rank, device=device)
dist.all_reduce(val_loss_sum_tensor, op=dist.ReduceOp.SUM)
dist.all_reduce(val_batches_tensor, op=dist.ReduceOp.SUM)
avg_val_loss = val_loss_sum_tensor.item() / val_batches_tensor.item() if val_batches_tensor.item() > 0 else 0
# --- End of Epoch Summary & Checkpointing (Master Process Only) ---
if rank == 0:
print(f"\n--- Epoch {epoch_idx + 1}/{config['epochs']} Summary ---")
print(f"Validation Loss: {avg_val_loss:.4f}")
print(f"Time This Epoch: {format_time(time.time() - epoch_start_time)}")
print(f"Total Time Elapsed: {format_time(time.time() - start_time)}\n")
if logger:
logger.log_metric('val_predictor_loss_epoch', avg_val_loss, epoch=epoch_idx)
if avg_val_loss < best_val_loss:
best_val_loss = avg_val_loss
save_path = f"{save_dir}/checkpoints/best_model"
model.module.save_pretrained(save_path)
print(f"Best model saved to {save_path} (Val Loss: {best_val_loss:.4f})")
dist.barrier()
dt_result['best_val_loss'] = best_val_loss
return dt_result
def main(config: dict):
"""Main function to orchestrate the DDP training process."""
rank, world_size, local_rank = setup_ddp()
device = torch.device(f"cuda:{local_rank}")
set_seed(config['seed'], rank)
save_dir = os.path.join(config['save_path'], config['predictor_save_folder_name'])
# Logger and summary setup (master process only)
comet_logger, master_summary = None, {}
if rank == 0:
os.makedirs(os.path.join(save_dir, 'checkpoints'), exist_ok=True)
master_summary = {
'start_time': strftime("%Y-%m-%dT%H-%M-%S", gmtime()),
'save_directory': save_dir,
'world_size': world_size,
}
if config['use_comet']:
comet_logger = comet_ml.Experiment(
api_key=config['comet_config']['api_key'],
project_name=config['comet_config']['project_name'],
workspace=config['comet_config']['workspace'],
)
comet_logger.add_tag(config['comet_tag'])
comet_logger.set_name(config['comet_name'])
comet_logger.log_parameters(config)
print("Comet Logger Initialized.")
dist.barrier()
# Model Initialization
tokenizer = KronosTokenizer.from_pretrained(config['finetuned_tokenizer_path'])
tokenizer.eval().to(device)
model = Kronos.from_pretrained(config['pretrained_predictor_path'])
model.to(device)
model = DDP(model, device_ids=[local_rank], find_unused_parameters=False)
if rank == 0:
print(f"Predictor Model Size: {get_model_size(model.module)}")
# Start Training
dt_result = train_model(
model, tokenizer, device, config, save_dir, comet_logger, rank, world_size
)
if rank == 0:
master_summary['final_result'] = dt_result
with open(os.path.join(save_dir, 'summary.json'), 'w') as f:
json.dump(master_summary, f, indent=4)
print('Training finished. Summary file saved.')
if comet_logger: comet_logger.end()
cleanup_ddp()
if __name__ == '__main__':
# Usage: torchrun --standalone --nproc_per_node=NUM_GPUS train_predictor.py
if "WORLD_SIZE" not in os.environ:
raise RuntimeError("This script must be launched with `torchrun`.")
config_instance = Config()
main(config_instance.__dict__)

View File

@ -1,281 +0,0 @@
import os
import sys
import json
import time
from time import gmtime, strftime
import argparse
import datetime
import torch.distributed as dist
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
import comet_ml
# Ensure project root is in path
sys.path.append("../")
from config import Config
from dataset import QlibDataset
from model.kronos import KronosTokenizer
# Import shared utilities
from utils.training_utils import (
setup_ddp,
cleanup_ddp,
set_seed,
get_model_size,
format_time,
)
def create_dataloaders(config: dict, rank: int, world_size: int):
"""
Creates and returns distributed dataloaders for training and validation.
Args:
config (dict): A dictionary of configuration parameters.
rank (int): The global rank of the current process.
world_size (int): The total number of processes.
Returns:
tuple: A tuple containing (train_loader, val_loader, train_dataset, valid_dataset).
"""
print(f"[Rank {rank}] Creating distributed dataloaders...")
train_dataset = QlibDataset('train')
valid_dataset = QlibDataset('val')
print(f"[Rank {rank}] Train dataset size: {len(train_dataset)}, Validation dataset size: {len(valid_dataset)}")
train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True)
val_sampler = DistributedSampler(valid_dataset, num_replicas=world_size, rank=rank, shuffle=False)
train_loader = DataLoader(
train_dataset,
batch_size=config['batch_size'],
sampler=train_sampler,
shuffle=False, # Shuffle is handled by the sampler
num_workers=config.get('num_workers', 2),
pin_memory=True,
drop_last=True
)
val_loader = DataLoader(
valid_dataset,
batch_size=config['batch_size'],
sampler=val_sampler,
shuffle=False,
num_workers=config.get('num_workers', 2),
pin_memory=True,
drop_last=False
)
print(f"[Rank {rank}] Dataloaders created. Train steps/epoch: {len(train_loader)}, Val steps: {len(val_loader)}")
return train_loader, val_loader, train_dataset, valid_dataset
def train_model(model, device, config, save_dir, logger, rank, world_size):
"""
The main training and validation loop for the tokenizer.
Args:
model (DDP): The DDP-wrapped model to train.
device (torch.device): The device for the current process.
config (dict): Configuration dictionary.
save_dir (str): Directory to save checkpoints.
logger (comet_ml.Experiment): Comet logger instance.
rank (int): Global rank of the process.
world_size (int): Total number of processes.
Returns:
tuple: A tuple containing the trained model and a dictionary of results.
"""
start_time = time.time()
if rank == 0:
effective_bs = config['batch_size'] * world_size * config['accumulation_steps']
print(f"[Rank {rank}] BATCHSIZE (per GPU): {config['batch_size']}")
print(f"[Rank {rank}] Effective total batch size: {effective_bs}")
train_loader, val_loader, train_dataset, valid_dataset = create_dataloaders(config, rank, world_size)
optimizer = torch.optim.AdamW(
model.parameters(),
lr=config['tokenizer_learning_rate'],
weight_decay=config['adam_weight_decay']
)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
optimizer=optimizer,
max_lr=config['tokenizer_learning_rate'],
steps_per_epoch=len(train_loader),
epochs=config['epochs'],
pct_start=0.03,
div_factor=10
)
best_val_loss = float('inf')
dt_result = {}
batch_idx_global_train = 0
for epoch_idx in range(config['epochs']):
epoch_start_time = time.time()
model.train()
train_loader.sampler.set_epoch(epoch_idx)
# Set dataset seeds for reproducible sampling
train_dataset.set_epoch_seed(epoch_idx * 10000 + rank)
valid_dataset.set_epoch_seed(0) # Keep validation sampling consistent
for i, (ori_batch_x, _) in enumerate(train_loader):
ori_batch_x = ori_batch_x.squeeze(0).to(device, non_blocking=True)
# --- Gradient Accumulation Loop ---
current_batch_total_loss = 0.0
for j in range(config['accumulation_steps']):
start_idx = j * (ori_batch_x.shape[0] // config['accumulation_steps'])
end_idx = (j + 1) * (ori_batch_x.shape[0] // config['accumulation_steps'])
batch_x = ori_batch_x[start_idx:end_idx]
# Forward pass
zs, bsq_loss, _, _ = model(batch_x)
z_pre, z = zs
# Loss calculation
recon_loss_pre = F.mse_loss(z_pre, batch_x)
recon_loss_all = F.mse_loss(z, batch_x)
recon_loss = recon_loss_pre + recon_loss_all
loss = (recon_loss + bsq_loss) / 2 # Assuming w_1=w_2=1
loss_scaled = loss / config['accumulation_steps']
current_batch_total_loss += loss.item()
loss_scaled.backward()
# --- Optimizer Step after Accumulation ---
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)
optimizer.step()
scheduler.step()
optimizer.zero_grad()
# --- Logging (Master Process Only) ---
if rank == 0 and (batch_idx_global_train + 1) % config['log_interval'] == 0:
avg_loss = current_batch_total_loss / config['accumulation_steps']
print(
f"[Rank {rank}, Epoch {epoch_idx + 1}/{config['epochs']}, Step {i + 1}/{len(train_loader)}] "
f"LR {optimizer.param_groups[0]['lr']:.6f}, Loss: {avg_loss:.4f}"
)
if rank == 0 and logger:
avg_loss = current_batch_total_loss / config['accumulation_steps']
logger.log_metric('train_tokenizer_loss_batch', avg_loss, step=batch_idx_global_train)
logger.log_metric(f'train_vqvae_vq_loss_each_batch', bsq_loss.item(), step=batch_idx_global_train)
logger.log_metric(f'train_recon_loss_pre_each_batch', recon_loss_pre.item(), step=batch_idx_global_train)
logger.log_metric(f'train_recon_loss_each_batch', recon_loss_all.item(), step=batch_idx_global_train)
logger.log_metric('tokenizer_learning_rate', optimizer.param_groups[0]["lr"], step=batch_idx_global_train)
batch_idx_global_train += 1
# --- Validation Loop ---
model.eval()
tot_val_loss_sum_rank = 0.0
val_sample_count_rank = 0
with torch.no_grad():
for ori_batch_x, _ in val_loader:
ori_batch_x = ori_batch_x.squeeze(0).to(device, non_blocking=True)
zs, _, _, _ = model(ori_batch_x)
_, z = zs
val_loss_item = F.mse_loss(z, ori_batch_x)
tot_val_loss_sum_rank += val_loss_item.item() * ori_batch_x.size(0)
val_sample_count_rank += ori_batch_x.size(0)
# Reduce validation losses from all processes
val_loss_sum_tensor = torch.tensor(tot_val_loss_sum_rank, device=device)
val_count_tensor = torch.tensor(val_sample_count_rank, device=device)
dist.all_reduce(val_loss_sum_tensor, op=dist.ReduceOp.SUM)
dist.all_reduce(val_count_tensor, op=dist.ReduceOp.SUM)
avg_val_loss = val_loss_sum_tensor.item() / val_count_tensor.item() if val_count_tensor.item() > 0 else 0
# --- End of Epoch Summary & Checkpointing (Master Process Only) ---
if rank == 0:
print(f"\n--- Epoch {epoch_idx + 1}/{config['epochs']} Summary ---")
print(f"Validation Loss: {avg_val_loss:.4f}")
print(f"Time This Epoch: {format_time(time.time() - epoch_start_time)}")
print(f"Total Time Elapsed: {format_time(time.time() - start_time)}\n")
if logger:
logger.log_metric('val_tokenizer_loss_epoch', avg_val_loss, epoch=epoch_idx)
if avg_val_loss < best_val_loss:
best_val_loss = avg_val_loss
save_path = f"{save_dir}/checkpoints/best_model"
model.module.save_pretrained(save_path)
print(f"Best model saved to {save_path} (Val Loss: {best_val_loss:.4f})")
if logger:
logger.log_model("best_model", save_path)
dist.barrier() # Ensure all processes finish the epoch before starting the next one.
dt_result['best_val_loss'] = best_val_loss
return model, dt_result
def main(config: dict):
"""
Main function to orchestrate the DDP training process.
"""
rank, world_size, local_rank = setup_ddp()
device = torch.device(f"cuda:{local_rank}")
set_seed(config['seed'], rank)
save_dir = os.path.join(config['save_path'], config['tokenizer_save_folder_name'])
# Logger and summary setup (master process only)
comet_logger, master_summary = None, {}
if rank == 0:
os.makedirs(os.path.join(save_dir, 'checkpoints'), exist_ok=True)
master_summary = {
'start_time': strftime("%Y-%m-%dT%H-%M-%S", gmtime()),
'save_directory': save_dir,
'world_size': world_size,
}
if config['use_comet']:
comet_logger = comet_ml.Experiment(
api_key=config['comet_config']['api_key'],
project_name=config['comet_config']['project_name'],
workspace=config['comet_config']['workspace'],
)
comet_logger.add_tag(config['comet_tag'])
comet_logger.set_name(config['comet_name'])
comet_logger.log_parameters(config)
print("Comet Logger Initialized.")
dist.barrier() # Ensure save directory is created before proceeding
# Model Initialization
model = KronosTokenizer.from_pretrained(config['pretrained_tokenizer_path'])
model.to(device)
model = DDP(model, device_ids=[local_rank], find_unused_parameters=False)
if rank == 0:
print(f"Model Size: {get_model_size(model.module)}")
# Start Training
_, dt_result = train_model(
model, device, config, save_dir, comet_logger, rank, world_size
)
# Finalize and save summary (master process only)
if rank == 0:
master_summary['final_result'] = dt_result
with open(os.path.join(save_dir, 'summary.json'), 'w') as f:
json.dump(master_summary, f, indent=4)
print('Training finished. Summary file saved.')
if comet_logger:
comet_logger.end()
cleanup_ddp()
if __name__ == '__main__':
# Usage: torchrun --standalone --nproc_per_node=NUM_GPUS train_tokenizer.py
if "WORLD_SIZE" not in os.environ:
raise RuntimeError("This script must be launched with `torchrun`.")
config_instance = Config()
main(config_instance.__dict__)

View File

@ -1,118 +0,0 @@
import os
import random
import datetime
import numpy as np
import torch
import torch.distributed as dist
def setup_ddp():
"""
Initializes the distributed data parallel environment.
This function relies on environment variables set by `torchrun` or a similar
launcher. It initializes the process group and sets the CUDA device for the
current process.
Returns:
tuple: A tuple containing (rank, world_size, local_rank).
"""
if not dist.is_available():
raise RuntimeError("torch.distributed is not available.")
dist.init_process_group(backend="nccl")
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
print(
f"[DDP Setup] Global Rank: {rank}/{world_size}, "
f"Local Rank (GPU): {local_rank} on device {torch.cuda.current_device()}"
)
return rank, world_size, local_rank
def cleanup_ddp():
"""Cleans up the distributed process group."""
if dist.is_initialized():
dist.destroy_process_group()
def set_seed(seed: int, rank: int = 0):
"""
Sets the random seed for reproducibility across all relevant libraries.
Args:
seed (int): The base seed value.
rank (int): The process rank, used to ensure different processes have
different seeds, which can be important for data loading.
"""
actual_seed = seed + rank
random.seed(actual_seed)
np.random.seed(actual_seed)
torch.manual_seed(actual_seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(actual_seed)
# The two lines below can impact performance, so they are often
# reserved for final experiments where reproducibility is critical.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def get_model_size(model: torch.nn.Module) -> str:
"""
Calculates the number of trainable parameters in a PyTorch model and returns
it as a human-readable string.
Args:
model (torch.nn.Module): The PyTorch model.
Returns:
str: A string representing the model size (e.g., "175.0B", "7.1M", "50.5K").
"""
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
if total_params >= 1e9:
return f"{total_params / 1e9:.1f}B" # Billions
elif total_params >= 1e6:
return f"{total_params / 1e6:.1f}M" # Millions
else:
return f"{total_params / 1e3:.1f}K" # Thousands
def reduce_tensor(tensor: torch.Tensor, world_size: int, op=dist.ReduceOp.SUM) -> torch.Tensor:
"""
Reduces a tensor's value across all processes in a distributed setup.
Args:
tensor (torch.Tensor): The tensor to be reduced.
world_size (int): The total number of processes.
op (dist.ReduceOp, optional): The reduction operation (SUM, AVG, etc.).
Defaults to dist.ReduceOp.SUM.
Returns:
torch.Tensor: The reduced tensor, which will be identical on all processes.
"""
rt = tensor.clone()
dist.all_reduce(rt, op=op)
# Note: `dist.ReduceOp.AVG` is available in newer torch versions.
# For compatibility, manual division is sometimes used after a SUM.
if op == dist.ReduceOp.AVG:
rt /= world_size
return rt
def format_time(seconds: float) -> str:
"""
Formats a duration in seconds into a human-readable H:M:S string.
Args:
seconds (float): The total seconds.
Returns:
str: The formatted time string (e.g., "0:15:32").
"""
return str(datetime.timedelta(seconds=int(seconds)))

View File

@ -1,88 +0,0 @@
# Kronos Finetuning on Your Custom csv Dataset
Supports fine-tuning training with custom CSV data using configuration files
## 1. Prepare Your Data
**Data Format**: Ensure CSV file contains the following columns: `timestamps`, `open`, `high`, `low`, `close`, `volume`, `amount`
A good csv data should be like:
| timestamps | open | close | high | low | volume | amount |
|------------|------|-------|------|-----|--------|--------|
| 2019/11/26 9:35 | 182.45215 | 184.45215 | 184.95215 | 182.45215 | 15136000 | 0 |
| 2019/11/26 9:40 | 184.35215 | 183.85215 | 184.55215 | 183.45215 | 4433300 | 0 |
| ... | ... | ... | ... | ... | ... | ... |
| ... | ... | ... | ... | ... | ... | ... |
You can check "data/HK_ali_09988_kline_5min_all.csv" to find out the proper format.
## 2. Training
### Configuration Setup
First edit the `config.yaml` file to set the correct paths and parameters:
```yaml
# Data configuration
data:
data_path: "/path/to/your/data.csv"
lookback_window: 512
predict_window: 48
# ... other parameters
# Model path configuration
model_paths:
pretrained_tokenizer: "/path/to/pretrained/tokenizer"
pretrained_predictor: "/path/to/pretrained/predictor"
base_save_path: "/path/to/save/models"
# ... other paths
```
### Run Training
Using train_sequential
```bash
# Complete training
python train_sequential.py --config configs/config_ali09988_candle-5min.yaml
# Skip existing models
python train_sequential.py --config configs/config_ali09988_candle-5min.yaml --skip-existing
# Only train tokenizer
python train_sequential.py --config configs/config_ali09988_candle-5min.yaml --skip-basemodel
# Only train basemodel
python train_sequential.py --config configs/config_ali09988_candle-5min.yaml --skip-tokenizer
```
Run each stage separately
```bash
# Only train tokenizer
python finetune_tokenizer.py --config configs/config_ali09988_candle-5min.yaml
# Only train basemodel (requires fine-tuned tokenizer first)
python finetune_base_model.py --config configs/config_ali09988_candle-5min.yaml
```
DDP Training
```bash
# Choose communication protocol yourself, nccl can be replaced with gloo
DIST_BACKEND=nccl \
torchrun --standalone --nproc_per_node=8 train_sequential.py --config configs/config_ali09988_candle-5min.yaml
```
## 2. Training Results
![HK_ali_09988_kline_5min_all_historical_20250919_073929](examples/HK_ali_09988_kline_5min_all_historical_20250919_073929.png)
![HK_ali_09988_kline_5min_all_historical_20250919_073944](examples/HK_ali_09988_kline_5min_all_historical_20250919_073944.png)
![HK_ali_09988_kline_5min_all_historical_20250919_074012](examples/HK_ali_09988_kline_5min_all_historical_20250919_074012.png)
![HK_ali_09988_kline_5min_all_historical_20250919_074042](examples/HK_ali_09988_kline_5min_all_historical_20250919_074042.png)
![HK_ali_09988_kline_5min_all_historical_20250919_074251](examples/HK_ali_09988_kline_5min_all_historical_20250919_074251.png)

View File

@ -1,105 +0,0 @@
# 自定义数据集的Kronos微调训练
支持使用配置文件进行自定义csv数据的微调训练
## 快速开始
### 1. 配置设置
首先编辑 `config.yaml` 文件,设置正确的路径和参数:
```yaml
# 数据配置
data:
data_path: "/path/to/your/data.csv"
lookback_window: 512
predict_window: 48
# ... 其他参数
# 模型路径配置
model_paths:
pretrained_tokenizer: "/path/to/pretrained/tokenizer"
pretrained_predictor: "/path/to/pretrained/predictor"
base_save_path: "/path/to/save/models"
# ... 其他路径
```
### 2. 运行训练
使用train_sequential
```bash
# 完整训练
python train_sequential.py --config configs/config_ali09988_candle-5min.yaml
# 跳过已存在的模型
python train_sequential.py --config configs/config_ali09988_candle-5min.yaml --skip-existing
# 只训练tokenizer
python train_sequential.py --config configs/config_ali09988_candle-5min.yaml --skip-basemodel
# 只训练basemodel
python train_sequential.py --config configs/config_ali09988_candle-5min.yaml --skip-tokenizer
```
单独运行各个阶段
```bash
# 只训练tokenizer
python finetune_tokenizer.py --config configs/config_ali09988_candle-5min.yaml
# 只训练basemodel需要先有微调后的tokenizer
python finetune_base_model.py --config configs/config_ali09988_candle-5min.yaml
```
DDP训练
```bash
# 通信协议自行选择nccl可替换gloo
DIST_BACKEND=nccl \
torchrun --standalone --nproc_per_node=8 train_sequential.py --config configs/config_ali09988_candle-5min.yaml
```
## 配置说明
### 主要配置项
- **data**: 数据相关配置
- `data_path`: CSV数据文件路径
- `lookback_window`: 回望窗口大小
- `predict_window`: 预测窗口大小
- `train_ratio/val_ratio/test_ratio`: 数据集分割比例
- **training**: 训练相关配置
- `epochs`: 训练轮数
- `batch_size`: 批次大小
- `tokenizer_learning_rate`: Tokenizer学习率
- `predictor_learning_rate`: Predictor学习率
- **model_paths**: 模型路径配置
- `pretrained_tokenizer`: 预训练tokenizer路径
- `pretrained_predictor`: 预训练predictor路径
- `base_save_path`: 模型保存根目录
- `finetuned_tokenizer`: 微调后tokenizer路径用于basemodel训练
- **experiment**: 实验控制
- `train_tokenizer`: 是否训练tokenizer
- `train_basemodel`: 是否训练basemodel
- `skip_existing`: 是否跳过已存在的模型
## 训练流程
1. **Tokenizer微调阶段**
- 加载预训练tokenizer
- 在自定义数据上微调
- 保存微调后的tokenizer到 `{base_save_path}/tokenizer/best_model/`
2. **Basemodel微调阶段**
- 加载微调后的tokenizer和预训练predictor
- 在自定义数据上微调
- 保存微调后的basemodel到 `{base_save_path}/basemodel/best_model/`
**数据格式**: 确保CSV文件包含以下列`timestamps`, `open`, `high`, `low`, `close`, `volume`, `amount`

View File

@ -1,267 +0,0 @@
import os
import yaml
from typing import Dict, Any
class ConfigLoader:
def __init__(self, config_path: str):
self.config_path = config_path
self.config = self._load_config()
def _load_config(self) -> Dict[str, Any]:
if not os.path.exists(self.config_path):
raise FileNotFoundError(f"config file not found: {self.config_path}")
with open(self.config_path, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
config = self._resolve_dynamic_paths(config)
return config
def _resolve_dynamic_paths(self, config: Dict[str, Any]) -> Dict[str, Any]:
exp_name = config.get('model_paths', {}).get('exp_name', '')
if not exp_name:
return config
base_path = config.get('model_paths', {}).get('base_path', '')
path_templates = {
'base_save_path': f"{base_path}/{exp_name}",
'finetuned_tokenizer': f"{base_path}/{exp_name}/tokenizer/best_model"
}
if 'model_paths' in config:
for key, template in path_templates.items():
if key in config['model_paths']:
# only use template when the original value is empty string
current_value = config['model_paths'][key]
if current_value == "" or current_value is None:
config['model_paths'][key] = template
else:
# if the original value is not empty, use template to replace the {exp_name} placeholder
if isinstance(current_value, str) and '{exp_name}' in current_value:
config['model_paths'][key] = current_value.format(exp_name=exp_name)
return config
def get(self, key: str, default=None):
keys = key.split('.')
value = self.config
try:
for k in keys:
value = value[k]
return value
except (KeyError, TypeError):
return default
def get_data_config(self) -> Dict[str, Any]:
return self.config.get('data', {})
def get_training_config(self) -> Dict[str, Any]:
return self.config.get('training', {})
def get_model_paths(self) -> Dict[str, str]:
return self.config.get('model_paths', {})
def get_experiment_config(self) -> Dict[str, Any]:
return self.config.get('experiment', {})
def get_device_config(self) -> Dict[str, Any]:
return self.config.get('device', {})
def get_distributed_config(self) -> Dict[str, Any]:
return self.config.get('distributed', {})
def update_config(self, updates: Dict[str, Any]):
def update_nested_dict(d, u):
for k, v in u.items():
if isinstance(v, dict):
d[k] = update_nested_dict(d.get(k, {}), v)
else:
d[k] = v
return d
self.config = update_nested_dict(self.config, updates)
def save_config(self, save_path: str = None):
if save_path is None:
save_path = self.config_path
with open(save_path, 'w', encoding='utf-8') as f:
yaml.dump(self.config, f, default_flow_style=False, allow_unicode=True, indent=2)
def print_config(self):
print("=" * 50)
print("Current configuration:")
print("=" * 50)
yaml.dump(self.config, default_flow_style=False, allow_unicode=True, indent=2)
print("=" * 50)
class CustomFinetuneConfig:
def __init__(self, config_path: str = None):
if config_path is None:
config_path = os.path.join(os.path.dirname(__file__), 'config.yaml')
self.loader = ConfigLoader(config_path)
self._load_all_configs()
def _load_all_configs(self):
data_config = self.loader.get_data_config()
self.data_path = data_config.get('data_path')
self.lookback_window = data_config.get('lookback_window', 512)
self.predict_window = data_config.get('predict_window', 48)
self.max_context = data_config.get('max_context', 512)
self.clip = data_config.get('clip', 5.0)
self.train_ratio = data_config.get('train_ratio', 0.9)
self.val_ratio = data_config.get('val_ratio', 0.1)
self.test_ratio = data_config.get('test_ratio', 0.0)
# training configuration
training_config = self.loader.get_training_config()
# support training epochs of tokenizer and basemodel separately
self.tokenizer_epochs = training_config.get('tokenizer_epochs', 30)
self.basemodel_epochs = training_config.get('basemodel_epochs', 30)
if 'epochs' in training_config and 'tokenizer_epochs' not in training_config:
self.tokenizer_epochs = training_config.get('epochs', 30)
if 'epochs' in training_config and 'basemodel_epochs' not in training_config:
self.basemodel_epochs = training_config.get('epochs', 30)
self.batch_size = training_config.get('batch_size', 160)
self.log_interval = training_config.get('log_interval', 50)
self.num_workers = training_config.get('num_workers', 6)
self.seed = training_config.get('seed', 100)
self.tokenizer_learning_rate = training_config.get('tokenizer_learning_rate', 2e-4)
self.predictor_learning_rate = training_config.get('predictor_learning_rate', 4e-5)
self.adam_beta1 = training_config.get('adam_beta1', 0.9)
self.adam_beta2 = training_config.get('adam_beta2', 0.95)
self.adam_weight_decay = training_config.get('adam_weight_decay', 0.1)
self.accumulation_steps = training_config.get('accumulation_steps', 1)
model_paths = self.loader.get_model_paths()
self.exp_name = model_paths.get('exp_name', 'default_experiment')
self.pretrained_tokenizer_path = model_paths.get('pretrained_tokenizer')
self.pretrained_predictor_path = model_paths.get('pretrained_predictor')
self.base_save_path = model_paths.get('base_save_path')
self.tokenizer_save_name = model_paths.get('tokenizer_save_name', 'tokenizer')
self.basemodel_save_name = model_paths.get('basemodel_save_name', 'basemodel')
self.finetuned_tokenizer_path = model_paths.get('finetuned_tokenizer')
experiment_config = self.loader.get_experiment_config()
self.experiment_name = experiment_config.get('name', 'kronos_custom_finetune')
self.experiment_description = experiment_config.get('description', '')
self.use_comet = experiment_config.get('use_comet', False)
self.train_tokenizer = experiment_config.get('train_tokenizer', True)
self.train_basemodel = experiment_config.get('train_basemodel', True)
self.skip_existing = experiment_config.get('skip_existing', False)
unified_pretrained = experiment_config.get('pre_trained', None)
self.pre_trained_tokenizer = experiment_config.get('pre_trained_tokenizer', unified_pretrained if unified_pretrained is not None else True)
self.pre_trained_predictor = experiment_config.get('pre_trained_predictor', unified_pretrained if unified_pretrained is not None else True)
device_config = self.loader.get_device_config()
self.use_cuda = device_config.get('use_cuda', True)
self.device_id = device_config.get('device_id', 0)
distributed_config = self.loader.get_distributed_config()
self.use_ddp = distributed_config.get('use_ddp', False)
self.ddp_backend = distributed_config.get('backend', 'nccl')
self._compute_full_paths()
def _compute_full_paths(self):
self.tokenizer_save_path = os.path.join(self.base_save_path, self.tokenizer_save_name)
self.tokenizer_best_model_path = os.path.join(self.tokenizer_save_path, 'best_model')
self.basemodel_save_path = os.path.join(self.base_save_path, self.basemodel_save_name)
self.basemodel_best_model_path = os.path.join(self.basemodel_save_path, 'best_model')
def get_tokenizer_config(self):
return {
'data_path': self.data_path,
'lookback_window': self.lookback_window,
'predict_window': self.predict_window,
'max_context': self.max_context,
'clip': self.clip,
'train_ratio': self.train_ratio,
'val_ratio': self.val_ratio,
'test_ratio': self.test_ratio,
'epochs': self.tokenizer_epochs,
'batch_size': self.batch_size,
'log_interval': self.log_interval,
'num_workers': self.num_workers,
'seed': self.seed,
'learning_rate': self.tokenizer_learning_rate,
'adam_beta1': self.adam_beta1,
'adam_beta2': self.adam_beta2,
'adam_weight_decay': self.adam_weight_decay,
'accumulation_steps': self.accumulation_steps,
'pretrained_model_path': self.pretrained_tokenizer_path,
'save_path': self.tokenizer_save_path,
'use_comet': self.use_comet
}
def get_basemodel_config(self):
return {
'data_path': self.data_path,
'lookback_window': self.lookback_window,
'predict_window': self.predict_window,
'max_context': self.max_context,
'clip': self.clip,
'train_ratio': self.train_ratio,
'val_ratio': self.val_ratio,
'test_ratio': self.test_ratio,
'epochs': self.basemodel_epochs,
'batch_size': self.batch_size,
'log_interval': self.log_interval,
'num_workers': self.num_workers,
'seed': self.seed,
'predictor_learning_rate': self.predictor_learning_rate,
'tokenizer_learning_rate': self.tokenizer_learning_rate,
'adam_beta1': self.adam_beta1,
'adam_beta2': self.adam_beta2,
'adam_weight_decay': self.adam_weight_decay,
'pretrained_tokenizer_path': self.finetuned_tokenizer_path,
'pretrained_predictor_path': self.pretrained_predictor_path,
'save_path': self.basemodel_save_path,
'use_comet': self.use_comet
}
def print_config_summary(self):
print("=" * 60)
print("Kronos finetuning configuration summary")
print("=" * 60)
print(f"Experiment name: {self.exp_name}")
print(f"Data path: {self.data_path}")
print(f"Lookback window: {self.lookback_window}")
print(f"Predict window: {self.predict_window}")
print(f"Tokenizer training epochs: {self.tokenizer_epochs}")
print(f"Basemodel training epochs: {self.basemodel_epochs}")
print(f"Batch size: {self.batch_size}")
print(f"Tokenizer learning rate: {self.tokenizer_learning_rate}")
print(f"Predictor learning rate: {self.predictor_learning_rate}")
print(f"Train tokenizer: {self.train_tokenizer}")
print(f"Train basemodel: {self.train_basemodel}")
print(f"Skip existing: {self.skip_existing}")
print(f"Use pre-trained tokenizer: {self.pre_trained_tokenizer}")
print(f"Use pre-trained predictor: {self.pre_trained_predictor}")
print(f"Base save path: {self.base_save_path}")
print(f"Tokenizer save path: {self.tokenizer_save_path}")
print(f"Basemodel save path: {self.basemodel_save_path}")
print("=" * 60)

View File

@ -1,72 +0,0 @@
#This is a template config for custom finetuning kronos on csv data
#这是一份模板config用于kronos的csv自定义数据微调
data:
data_path: "/xxxx/Kronos/finetune_csv2/data/HK_ali_09988_kline_5min_all.csv"
lookback_window: 512
predict_window: 48
max_context: 512
clip: 5.0
# dataset split ratio
train_ratio: 0.9
val_ratio: 0.1
test_ratio: 0.0
training:
# control the training epochs of tokenizer and basemodel
tokenizer_epochs: 30
basemodel_epochs: 20
batch_size: 32
log_interval: 50
num_workers: 6
seed: 42
tokenizer_learning_rate: 0.0002
predictor_learning_rate: 0.000001
adam_beta1: 0.9
adam_beta2: 0.95
adam_weight_decay: 0.1
# gradient accumulation steps for tokenizer training
accumulation_steps: 1
# model path configuration
model_paths:
# pretrained model path
pretrained_tokenizer: "/mnt/DigitalHuman2D/boyuzhang/quant/Kronos/pretrained/Kronos-Tokenizer-base"
pretrained_predictor: "/mnt/DigitalHuman2D/boyuzhang/quant/Kronos/pretrained/Kronos-base"
# experiment name - other paths will be generated based on this
exp_name: "HK_ali_09988_kline_5min_all"
base_path: "/mnt/DigitalHuman2D/boyuzhang/quant/Kronos/finetune_csv/finetuned/"
# the following paths will be generated based on exp_name, no need to modify manually
# way 1: leave empty string, the system will generate the full path
base_save_path: "" # /xxxx/Kronos/finetune_csv/finetuned/{exp_name}
finetuned_tokenizer: "" # /xxxx/quant/Kronos/finetune_csv/finetuned/{exp_name}/tokenizer/best_model
# way 2: use template string, {exp_name} will be replaced with the actual experiment name
# base_save_path: "/xxxx/Kronos/finetune_csv/finetuned/{exp_name}"
# finetuned_tokenizer: "/xxxx/quant/Kronos/finetune_csv/finetuned/{exp_name}/tokenizer/best_model"
tokenizer_save_name: "tokenizer"
basemodel_save_name: "basemodel"
experiment:
name: "kronos_custom_finetune"
description: "Custom finetune for HK stock data"
use_comet: false
# control the training phase
train_tokenizer: true
train_basemodel: true
# if true, skip the existing model training
skip_existing: false
# device configuration
device:
use_cuda: true
device_id: 0

Binary file not shown.

Before

Width:  |  Height:  |  Size: 474 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 473 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 331 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 449 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 530 KiB

View File

@ -1,468 +0,0 @@
import os
import sys
import json
import time
import pickle
import random
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from time import gmtime, strftime
import logging
from logging.handlers import RotatingFileHandler
import datetime
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
sys.path.append('../')
from model import Kronos, KronosTokenizer, KronosPredictor
from config_loader import CustomFinetuneConfig
class CustomKlineDataset(Dataset):
def __init__(self, data_path, data_type='train', lookback_window=90, predict_window=10,
clip=5.0, seed=100, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15):
self.data_path = data_path
self.data_type = data_type
self.lookback_window = lookback_window
self.predict_window = predict_window
self.window = lookback_window + predict_window + 1
self.clip = clip
self.seed = seed
self.train_ratio = train_ratio
self.val_ratio = val_ratio
self.test_ratio = test_ratio
self.feature_list = ['open', 'high', 'low', 'close', 'volume', 'amount']
self.time_feature_list = ['minute', 'hour', 'weekday', 'day', 'month']
self.py_rng = random.Random(seed)
self._load_and_preprocess_data()
self._split_data_by_time()
self.n_samples = len(self.data) - self.window + 1
print(f"[{data_type.upper()}] Data length: {len(self.data)}, Available samples: {self.n_samples}")
def _load_and_preprocess_data(self):
df = pd.read_csv(self.data_path)
df['timestamps'] = pd.to_datetime(df['timestamps'])
df = df.sort_values('timestamps').reset_index(drop=True)
self.timestamps = df['timestamps'].copy()
df['minute'] = df['timestamps'].dt.minute
df['hour'] = df['timestamps'].dt.hour
df['weekday'] = df['timestamps'].dt.weekday
df['day'] = df['timestamps'].dt.day
df['month'] = df['timestamps'].dt.month
self.data = df[self.feature_list + self.time_feature_list].copy()
if self.data.isnull().any().any():
print("Warning: Missing values found in data, performing forward fill")
self.data = self.data.fillna(method='ffill')
print(f"Original data time range: {self.timestamps.min()} to {self.timestamps.max()}")
print(f"Original data total length: {len(df)} records")
def _split_data_by_time(self):
total_length = len(self.data)
train_end = int(total_length * self.train_ratio)
val_end = int(total_length * (self.train_ratio + self.val_ratio))
if self.data_type == 'train':
self.data = self.data.iloc[:train_end].copy()
self.timestamps = self.timestamps.iloc[:train_end].copy()
print(f"[{self.data_type.upper()}] Training set: first {train_end} time points ({self.train_ratio})")
print(f"[{self.data_type.upper()}] Training set time range: {self.timestamps.min()} to {self.timestamps.max()}")
elif self.data_type == 'val':
self.data = self.data.iloc[train_end:val_end].copy()
self.timestamps = self.timestamps.iloc[train_end:val_end].copy()
print(f"[{self.data_type.upper()}] Validation set: time points {train_end+1} to {val_end} ({self.val_ratio})")
print(f"[{self.data_type.upper()}] Validation set time range: {self.timestamps.min()} to {self.timestamps.max()}")
elif self.data_type == 'test':
self.data = self.data.iloc[val_end:].copy()
self.timestamps = self.timestamps.iloc[val_end:].copy()
print(f"[{self.data_type.upper()}] Test set: after time point {val_end+1}")
print(f"[{self.data_type.upper()}] Test set time range: {self.timestamps.min()} to {self.timestamps.max()}")
print(f"[{self.data_type.upper()}] Data length after split: {len(self.data)} records")
def set_epoch_seed(self, epoch):
epoch_seed = self.seed + epoch
self.py_rng.seed(epoch_seed)
self.current_epoch = epoch
def __len__(self):
return self.n_samples
def __getitem__(self, idx):
max_start = len(self.data) - self.window
if max_start <= 0:
raise ValueError("Data length insufficient to create samples")
if self.data_type == 'train':
epoch = getattr(self, 'current_epoch', 0)
start_idx = (idx * 9973 + (epoch + 1) * 104729) % (max_start + 1)
else:
start_idx = idx % (max_start + 1)
end_idx = start_idx + self.window
window_data = self.data.iloc[start_idx:end_idx]
x = window_data[self.feature_list].values.astype(np.float32)
x_stamp = window_data[self.time_feature_list].values.astype(np.float32)
x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0)
x = (x - x_mean) / (x_std + 1e-5)
x = np.clip(x, -self.clip, self.clip)
x_tensor = torch.from_numpy(x)
x_stamp_tensor = torch.from_numpy(x_stamp)
return x_tensor, x_stamp_tensor
def setup_logging(exp_name: str, log_dir: str, rank: int = 0) -> logging.Logger:
os.makedirs(log_dir, exist_ok=True)
logger = logging.getLogger(f"basemodel_training_rank_{rank}")
logger.setLevel(logging.INFO)
if logger.handlers:
return logger
log_file = os.path.join(log_dir, f"basemodel_training_rank_{rank}.log")
file_handler = RotatingFileHandler(
log_file,
maxBytes=10*1024*1024,
backupCount=5,
encoding='utf-8'
)
file_handler.setLevel(logging.INFO)
console_handler = None
if rank == 0:
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
file_handler.setFormatter(formatter)
if console_handler is not None:
console_handler.setFormatter(formatter)
logger.addHandler(file_handler)
if console_handler is not None:
logger.addHandler(console_handler)
logger.info(f"=== Basemodel Training Started ===")
logger.info(f"Experiment Name: {exp_name}")
logger.info(f"Log Directory: {log_dir}")
logger.info(f"Rank: {rank}")
logger.info(f"Timestamp: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
return logger
def create_dataloaders(config):
if not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0:
print("Creating data loaders...")
train_dataset = CustomKlineDataset(
data_path=config.data_path,
data_type='train',
lookback_window=config.lookback_window,
predict_window=config.predict_window,
clip=config.clip,
seed=config.seed,
train_ratio=config.train_ratio,
val_ratio=config.val_ratio,
test_ratio=config.test_ratio
)
val_dataset = CustomKlineDataset(
data_path=config.data_path,
data_type='val',
lookback_window=config.lookback_window,
predict_window=config.predict_window,
clip=config.clip,
seed=config.seed + 1,
train_ratio=config.train_ratio,
val_ratio=config.val_ratio,
test_ratio=config.test_ratio
)
use_ddp = dist.is_available() and dist.is_initialized()
train_sampler = DistributedSampler(train_dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True) if use_ddp else None
val_sampler = DistributedSampler(val_dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=False, drop_last=False) if use_ddp else None
train_loader = DataLoader(
train_dataset,
batch_size=config.batch_size,
shuffle=(train_sampler is None),
num_workers=config.num_workers,
pin_memory=True,
drop_last=True,
sampler=train_sampler
)
val_loader = DataLoader(
val_dataset,
batch_size=config.batch_size,
shuffle=False,
num_workers=config.num_workers,
pin_memory=True,
drop_last=False,
sampler=val_sampler
)
if not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0:
print(f"Training set size: {len(train_dataset)}, Validation set size: {len(val_dataset)}")
return train_loader, val_loader, train_dataset, val_dataset, train_sampler, val_sampler
def train_model(model, tokenizer, device, config, save_dir, logger):
logger.info("Starting training...")
use_ddp = dist.is_available() and dist.is_initialized()
rank = dist.get_rank() if use_ddp else 0
world_size = dist.get_world_size() if use_ddp else 1
train_loader, val_loader, train_dataset, val_dataset, train_sampler, val_sampler = create_dataloaders(config)
optimizer = torch.optim.AdamW(
model.parameters(),
lr=config.predictor_learning_rate,
betas=(config.adam_beta1, config.adam_beta2),
weight_decay=config.adam_weight_decay
)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
optimizer,
max_lr=config.predictor_learning_rate,
steps_per_epoch=len(train_loader),
epochs=config.basemodel_epochs,
pct_start=0.03,
div_factor=10
)
if use_ddp:
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=False)
best_val_loss = float('inf')
batch_idx_global = 0
for epoch in range(config.basemodel_epochs):
epoch_start_time = time.time()
model.train()
train_dataset.set_epoch_seed(epoch * 10000)
val_dataset.set_epoch_seed(0)
if train_sampler is not None:
train_sampler.set_epoch(epoch)
epoch_train_loss = 0.0
train_batches = 0
for batch_idx, (batch_x, batch_x_stamp) in enumerate(train_loader):
batch_x = batch_x.to(device, non_blocking=True)
batch_x_stamp = batch_x_stamp.to(device, non_blocking=True)
with torch.no_grad():
token_seq_0, token_seq_1 = tokenizer.encode(batch_x, half=True)
token_in = [token_seq_0[:, :-1], token_seq_1[:, :-1]]
token_out = [token_seq_0[:, 1:], token_seq_1[:, 1:]]
logits = (model.module if use_ddp else model)(token_in[0], token_in[1], batch_x_stamp[:, :-1, :])
loss, s1_loss, s2_loss = (model.module if use_ddp else model).head.compute_loss(logits[0], logits[1], token_out[0], token_out[1])
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_((model.module if use_ddp else model).parameters(), max_norm=3.0)
optimizer.step()
scheduler.step()
epoch_train_loss += loss.item()
train_batches += 1
if (batch_idx_global + 1) % config.log_interval == 0:
lr = optimizer.param_groups[0]['lr']
log_msg = (f"[Epoch {epoch+1}/{config.basemodel_epochs}, Step {batch_idx+1}/{len(train_loader)}] "
f"LR: {lr:.6f}, Loss: {loss.item():.4f}")
logger.info(log_msg)
if rank == 0:
print(log_msg)
batch_idx_global += 1
model.eval()
val_loss = 0.0
val_batches = 0
with torch.no_grad():
for batch_x, batch_x_stamp in val_loader:
batch_x = batch_x.to(device, non_blocking=True)
batch_x_stamp = batch_x_stamp.to(device, non_blocking=True)
token_seq_0, token_seq_1 = tokenizer.encode(batch_x, half=True)
token_in = [token_seq_0[:, :-1], token_seq_1[:, :-1]]
token_out = [token_seq_0[:, 1:], token_seq_1[:, 1:]]
logits = (model.module if use_ddp else model)(token_in[0], token_in[1], batch_x_stamp[:, :-1, :])
loss, _, _ = (model.module if use_ddp else model).head.compute_loss(logits[0], logits[1], token_out[0], token_out[1])
val_loss += loss.item()
val_batches += 1
if use_ddp:
tensor_sum = torch.tensor([epoch_train_loss, train_batches, val_loss, val_batches], dtype=torch.float64, device=device)
dist.all_reduce(tensor_sum, op=dist.ReduceOp.SUM)
epoch_train_loss_all = tensor_sum[0].item()
train_batches_all = int(tensor_sum[1].item())
val_loss_all = tensor_sum[2].item()
val_batches_all = int(tensor_sum[3].item())
avg_train_loss = (epoch_train_loss_all / train_batches_all) if train_batches_all > 0 else 0.0
avg_val_loss = (val_loss_all / val_batches_all) if val_batches_all > 0 else 0.0
else:
avg_train_loss = epoch_train_loss / train_batches if train_batches > 0 else 0
avg_val_loss = val_loss / val_batches if val_batches > 0 else 0
epoch_time = time.time() - epoch_start_time
epoch_summary = (f"\n--- Epoch {epoch+1}/{config.basemodel_epochs} Summary ---\n"
f"Training Loss: {avg_train_loss:.4f}\n"
f"Validation Loss: {avg_val_loss:.4f}\n"
f"Epoch Time: {epoch_time:.2f} seconds\n")
logger.info(epoch_summary)
if rank == 0:
print(epoch_summary)
if avg_val_loss < best_val_loss:
best_val_loss = avg_val_loss
if rank == 0:
model_save_path = os.path.join(save_dir, "best_model")
os.makedirs(model_save_path, exist_ok=True)
(model.module if use_ddp else model).save_pretrained(model_save_path)
save_msg = f"Best model saved to: {model_save_path} (validation loss: {best_val_loss:.4f})"
logger.info(save_msg)
print(save_msg)
return best_val_loss
def main():
import argparse
parser = argparse.ArgumentParser(description='Kronos Basemodel Fine-tuning Training')
parser.add_argument('--config', type=str, default='config.yaml',
help='Configuration file path (default: config.yaml)')
args = parser.parse_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
config = CustomFinetuneConfig(args.config)
os.makedirs(config.basemodel_save_path, exist_ok=True)
log_dir = os.path.join(config.base_save_path, "logs")
logger = setup_logging(config.exp_name, log_dir, 0)
torch.manual_seed(config.seed)
np.random.seed(config.seed)
random.seed(config.seed)
logger.info("Loading pretrained model or random initialization...")
print("Loading pretrained model or random initialization...")
if getattr(config, 'pre_trained_tokenizer', True):
tokenizer = KronosTokenizer.from_pretrained(config.finetuned_tokenizer_path)
else:
import json, os
print("pre_trained_tokenizer=False, randomly initializing Tokenizer architecture for training")
cfg_path_tok = os.path.join(config.pretrained_tokenizer_path if hasattr(config, 'pretrained_tokenizer_path') else config.finetuned_tokenizer_path, 'config.json')
with open(cfg_path_tok, 'r') as f:
arch_t = json.load(f)
tokenizer = KronosTokenizer(
d_in=arch_t.get('d_in', 6),
d_model=arch_t.get('d_model', 256),
n_heads=arch_t.get('n_heads', 4),
ff_dim=arch_t.get('ff_dim', 512),
n_enc_layers=arch_t.get('n_enc_layers', 4),
n_dec_layers=arch_t.get('n_dec_layers', 4),
ffn_dropout_p=arch_t.get('ffn_dropout_p', 0.0),
attn_dropout_p=arch_t.get('attn_dropout_p', 0.0),
resid_dropout_p=arch_t.get('resid_dropout_p', 0.0),
s1_bits=arch_t.get('s1_bits', 10),
s2_bits=arch_t.get('s2_bits', 10),
beta=arch_t.get('beta', 0.05),
gamma0=arch_t.get('gamma0', 1.0),
gamma=arch_t.get('gamma', 1.1),
zeta=arch_t.get('zeta', 0.05),
group_size=arch_t.get('group_size', 4)
)
if getattr(config, 'pre_trained_predictor', True):
model = Kronos.from_pretrained(config.pretrained_predictor_path)
else:
import json, os
print("pre_trained_predictor=False, randomly initializing Predictor architecture for training")
cfg_path = os.path.join(config.pretrained_predictor_path, 'config.json')
with open(cfg_path, 'r') as f:
arch = json.load(f)
model = Kronos(
s1_bits=arch.get('s1_bits', 10),
s2_bits=arch.get('s2_bits', 10),
n_layers=arch.get('n_layers', 12),
d_model=arch.get('d_model', 832),
n_heads=arch.get('n_heads', 16),
ff_dim=arch.get('ff_dim', 2048),
ffn_dropout_p=arch.get('ffn_dropout_p', 0.2),
attn_dropout_p=arch.get('attn_dropout_p', 0.0),
resid_dropout_p=arch.get('resid_dropout_p', 0.2),
token_dropout_p=arch.get('token_dropout_p', 0.0),
learn_te=arch.get('learn_te', True)
)
tokenizer = tokenizer.to(device)
model = model.to(device)
model_size = sum(p.numel() for p in model.parameters())
logger.info(f"Model parameters: {model_size:,}")
print(f"Model parameters: {model_size:,}")
logger.info("=== Training Configuration ===")
logger.info(f"Data path: {config.data_path}")
logger.info(f"Lookback window: {config.lookback_window}")
logger.info(f"Predict window: {config.predict_window}")
logger.info(f"Batch size: {config.batch_size}")
logger.info(f"Learning rate: {config.predictor_learning_rate}")
logger.info(f"Training epochs: {config.basemodel_epochs}")
logger.info(f"Device: {device}")
logger.info(f"Tokenizer path: {config.finetuned_tokenizer_path}")
logger.info(f"Pretrained model path: {config.pretrained_predictor_path}")
logger.info("Starting fine-tuning training...")
print("Starting fine-tuning training...")
best_val_loss = train_model(model, tokenizer, device, config, config.basemodel_save_path, logger)
final_msg = f"Training completed! Best validation loss: {best_val_loss:.4f}\nModel saved to: {config.basemodel_save_path}"
logger.info(final_msg)
print(final_msg)
if __name__ == "__main__":
main()

View File

@ -1,359 +0,0 @@
import os
import sys
import json
import time
import random
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from time import gmtime, strftime
import datetime
import logging
from logging.handlers import RotatingFileHandler
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
sys.path.append("../")
from model import KronosTokenizer
from finetune_base_model import CustomKlineDataset
from config_loader import CustomFinetuneConfig
def set_seed(seed: int, rank: int = 0):
actual_seed = seed
random.seed(actual_seed)
np.random.seed(actual_seed)
torch.manual_seed(actual_seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(actual_seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def get_model_size(model: torch.nn.Module) -> str:
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
if total_params >= 1e9:
return f"{total_params / 1e9:.1f}B"
elif total_params >= 1e6:
return f"{total_params / 1e6:.1f}M"
else:
return f"{total_params / 1e3:.1f}K"
def format_time(seconds: float) -> str:
return str(datetime.timedelta(seconds=int(seconds)))
def setup_logging(exp_name: str, log_dir: str, rank: int = 0) -> logging.Logger:
os.makedirs(log_dir, exist_ok=True)
logger = logging.getLogger(f"tokenizer_training_rank_{rank}")
logger.setLevel(logging.INFO)
if logger.handlers:
return logger
log_file = os.path.join(log_dir, f"tokenizer_training_rank_{rank}.log")
file_handler = RotatingFileHandler(
log_file,
maxBytes=10*1024*1024,
backupCount=5,
encoding='utf-8'
)
file_handler.setLevel(logging.INFO)
console_handler = None
if rank == 0:
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
file_handler.setFormatter(formatter)
if console_handler is not None:
console_handler.setFormatter(formatter)
logger.addHandler(file_handler)
if console_handler is not None:
logger.addHandler(console_handler)
logger.info(f"=== Tokenizer Training Started ===")
logger.info(f"Experiment Name: {exp_name}")
logger.info(f"Log Directory: {log_dir}")
logger.info(f"Rank: {rank}")
logger.info(f"Timestamp: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
return logger
def create_dataloaders(config):
if not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0:
print("Creating tokenizer training data loaders...")
train_dataset = CustomKlineDataset(
data_path=config.data_path,
data_type="train",
lookback_window=config.lookback_window,
predict_window=config.predict_window,
clip=config.clip,
seed=config.seed,
train_ratio=config.train_ratio,
val_ratio=config.val_ratio,
test_ratio=config.test_ratio
)
val_dataset = CustomKlineDataset(
data_path=config.data_path,
data_type="val",
lookback_window=config.lookback_window,
predict_window=config.predict_window,
clip=config.clip,
seed=config.seed + 1,
train_ratio=config.train_ratio,
val_ratio=config.val_ratio,
test_ratio=config.test_ratio
)
use_ddp = dist.is_available() and dist.is_initialized()
train_sampler = DistributedSampler(train_dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True) if use_ddp else None
val_sampler = DistributedSampler(val_dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=False, drop_last=False) if use_ddp else None
train_loader = DataLoader(
train_dataset,
batch_size=config.batch_size,
shuffle=(train_sampler is None),
num_workers=config.num_workers,
pin_memory=True,
drop_last=True,
sampler=train_sampler
)
val_loader = DataLoader(
val_dataset,
batch_size=config.batch_size,
shuffle=False,
num_workers=config.num_workers,
pin_memory=True,
drop_last=False,
sampler=val_sampler
)
if not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0:
print(f"Training set size: {len(train_dataset)}, Validation set size: {len(val_dataset)}")
return train_loader, val_loader, train_dataset, val_dataset, train_sampler, val_sampler
def train_tokenizer(model, device, config, save_dir, logger):
logger.info("Starting tokenizer training...")
use_ddp = dist.is_available() and dist.is_initialized()
rank = dist.get_rank() if use_ddp else 0
world_size = dist.get_world_size() if use_ddp else 1
train_loader, val_loader, train_dataset, val_dataset, train_sampler, val_sampler = create_dataloaders(config)
optimizer = torch.optim.AdamW(
model.parameters(),
lr=config.tokenizer_learning_rate,
weight_decay=config.adam_weight_decay
)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
optimizer,
max_lr=config.tokenizer_learning_rate,
steps_per_epoch=len(train_loader),
epochs=config.tokenizer_epochs,
pct_start=0.03,
div_factor=10
)
if use_ddp:
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=False)
best_val_loss = float("inf")
batch_idx_global = 0
accumulation_steps = getattr(config, 'accumulation_steps', 1)
for epoch in range(config.tokenizer_epochs):
epoch_start_time = time.time()
model.train()
train_dataset.set_epoch_seed(epoch * 10000)
val_dataset.set_epoch_seed(0)
if train_sampler is not None:
train_sampler.set_epoch(epoch)
for batch_idx, (ori_batch_x, _) in enumerate(train_loader):
ori_batch_x = ori_batch_x.squeeze(0).to(device, non_blocking=True)
current_batch_total_loss = 0.0
for j in range(accumulation_steps):
start_idx = j * (ori_batch_x.shape[0] // accumulation_steps)
end_idx = (j + 1) * (ori_batch_x.shape[0] // accumulation_steps)
batch_x = ori_batch_x[start_idx:end_idx]
zs, bsq_loss, _, _ = (model.module if use_ddp else model)(batch_x)
z_pre, z = zs
recon_loss_pre = F.mse_loss(z_pre, batch_x)
recon_loss_all = F.mse_loss(z, batch_x)
recon_loss = recon_loss_pre + recon_loss_all
loss = (recon_loss + bsq_loss) / 2
loss_scaled = loss / accumulation_steps
current_batch_total_loss += loss.item()
loss_scaled.backward()
torch.nn.utils.clip_grad_norm_((model.module if use_ddp else model).parameters(), max_norm=2.0)
optimizer.step()
scheduler.step()
optimizer.zero_grad()
if (batch_idx_global + 1) % config.log_interval == 0:
avg_loss = current_batch_total_loss / accumulation_steps
lr = optimizer.param_groups[0]["lr"]
log_msg = (f"[Epoch {epoch+1}/{config.tokenizer_epochs}, Step {batch_idx+1}/{len(train_loader)}] "
f"LR: {lr:.6f}, Loss: {avg_loss:.4f}")
logger.info(log_msg)
if rank == 0:
print(log_msg)
detail_msg = (f" - VQ Loss: {bsq_loss.item():.4f}\n"
f" - Recon Loss Pre: {recon_loss_pre.item():.4f}\n"
f" - Recon Loss All: {recon_loss_all.item():.4f}")
logger.info(detail_msg)
if rank == 0:
print(detail_msg)
batch_idx_global += 1
model.eval()
tot_val_loss_sum_rank = 0.0
val_sample_count_rank = 0
with torch.no_grad():
for ori_batch_x, _ in val_loader:
ori_batch_x = ori_batch_x.squeeze(0).to(device, non_blocking=True)
zs, _, _, _ = (model.module if use_ddp else model)(ori_batch_x)
_, z = zs
val_loss_item = F.mse_loss(z, ori_batch_x)
tot_val_loss_sum_rank += val_loss_item.item() * ori_batch_x.size(0)
val_sample_count_rank += ori_batch_x.size(0)
if use_ddp:
tensor_sum = torch.tensor([tot_val_loss_sum_rank, val_sample_count_rank], dtype=torch.float64, device=device)
dist.all_reduce(tensor_sum, op=dist.ReduceOp.SUM)
tot_val_loss_all = tensor_sum[0].item()
val_count_all = int(tensor_sum[1].item())
avg_val_loss = (tot_val_loss_all / val_count_all) if val_count_all > 0 else 0.0
else:
avg_val_loss = tot_val_loss_sum_rank / val_sample_count_rank if val_sample_count_rank > 0 else 0
epoch_time = time.time() - epoch_start_time
epoch_summary = (f"\n--- Epoch {epoch+1}/{config.tokenizer_epochs} Summary ---\n"
f"Validation Loss: {avg_val_loss:.4f}\n"
f"Epoch Time: {format_time(epoch_time)}\n"
f"Total Training Time: {format_time(time.time() - epoch_start_time)}\n")
logger.info(epoch_summary)
if rank == 0:
print(epoch_summary)
if avg_val_loss < best_val_loss:
best_val_loss = avg_val_loss
if rank == 0:
model_save_path = os.path.join(save_dir, "best_model")
os.makedirs(model_save_path, exist_ok=True)
(model.module if use_ddp else model).save_pretrained(model_save_path)
save_msg = f"Best model saved to: {model_save_path} (validation loss: {best_val_loss:.4f})"
logger.info(save_msg)
print(save_msg)
return best_val_loss
def main():
import argparse
parser = argparse.ArgumentParser(description='Kronos Tokenizer Fine-tuning Training')
parser.add_argument('--config', type=str, default='config.yaml',
help='Configuration file path (default: config.yaml)')
args = parser.parse_args()
config = CustomFinetuneConfig(args.config)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
config = CustomFinetuneConfig(args.config)
os.makedirs(config.tokenizer_save_path, exist_ok=True)
log_dir = os.path.join(config.base_save_path, "logs")
logger = setup_logging(config.exp_name, log_dir, 0)
set_seed(config.seed)
# 加载预训练tokenizer
if getattr(config, 'pre_trained_tokenizer', True):
logger.info("Loading pretrained tokenizer...")
print("Loading pretrained tokenizer...")
tokenizer = KronosTokenizer.from_pretrained(config.pretrained_tokenizer_path)
else:
print("pre_trained_tokenizer=False, randomly initializing Tokenizer architecture")
import json, os
cfg_path = os.path.join(config.pretrained_tokenizer_path, 'config.json')
with open(cfg_path, 'r') as f:
arch = json.load(f)
tokenizer = KronosTokenizer(
d_in=arch.get('d_in', 6),
d_model=arch.get('d_model', 256),
n_heads=arch.get('n_heads', 4),
ff_dim=arch.get('ff_dim', 512),
n_enc_layers=arch.get('n_enc_layers', 4),
n_dec_layers=arch.get('n_dec_layers', 4),
ffn_dropout_p=arch.get('ffn_dropout_p', 0.0),
attn_dropout_p=arch.get('attn_dropout_p', 0.0),
resid_dropout_p=arch.get('resid_dropout_p', 0.0),
s1_bits=arch.get('s1_bits', 10),
s2_bits=arch.get('s2_bits', 10),
beta=arch.get('beta', 0.05),
gamma0=arch.get('gamma0', 1.0),
gamma=arch.get('gamma', 1.1),
zeta=arch.get('zeta', 0.05),
group_size=arch.get('group_size', 4)
)
tokenizer = tokenizer.to(device)
model_size = get_model_size(tokenizer)
logger.info(f"Tokenizer parameters: {model_size}")
print(f"Tokenizer parameters: {model_size}")
logger.info("=== Training Configuration ===")
logger.info(f"Data path: {config.data_path}")
logger.info(f"Lookback window: {config.lookback_window}")
logger.info(f"Predict window: {config.predict_window}")
logger.info(f"Batch size: {config.batch_size}")
logger.info(f"Learning rate: {config.tokenizer_learning_rate}")
logger.info(f"Training epochs: {config.tokenizer_epochs}")
logger.info(f"Device: {device}")
logger.info(f"Distributed training: False")
logger.info("Starting tokenizer fine-tuning training...")
print("Starting tokenizer fine-tuning training...")
best_val_loss = train_tokenizer(tokenizer, device, config, config.tokenizer_save_path, logger)
final_msg = f"Tokenizer training completed! Best validation loss: {best_val_loss:.4f}\nModel saved to: {config.tokenizer_save_path}"
logger.info(final_msg)
print(final_msg)
if __name__ == "__main__":
main()

View File

@ -1,361 +0,0 @@
import os
import sys
import time
import argparse
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.distributed as dist
sys.path.append('../')
from model import Kronos, KronosTokenizer, KronosPredictor
from config_loader import CustomFinetuneConfig
from finetune_tokenizer import train_tokenizer, set_seed, setup_logging as setup_tokenizer_logging
from finetune_base_model import train_model, create_dataloaders, setup_logging as setup_basemodel_logging
class SequentialTrainer:
def __init__(self, config_path: str = None):
self.config = CustomFinetuneConfig(config_path)
self.rank = int(os.environ.get("RANK", "0"))
self.world_size = int(os.environ.get("WORLD_SIZE", "1"))
self.local_rank = int(os.environ.get("LOCAL_RANK", str(self.config.device_id if hasattr(self.config, 'device_id') else 0)))
self.device = self._setup_device()
self.config.print_config_summary()
def _setup_device(self):
if self.config.use_cuda and torch.cuda.is_available():
torch.cuda.set_device(self.local_rank)
device = torch.device(f"cuda:{self.local_rank}")
else:
device = torch.device("cpu")
if self.rank == 0:
print(f"Using device: {device} (rank={self.rank}, world_size={self.world_size}, local_rank={self.local_rank})")
return device
def _setup_distributed(self):
if self.world_size > 1 and torch.cuda.is_available():
backend = os.environ.get("DIST_BACKEND", "nccl").lower()
if not dist.is_initialized():
dist.init_process_group(backend=backend)
if self.rank == 0:
print(f"Distributed training initialized: backend={backend}, world_size={self.world_size}")
else:
if self.rank == 0:
print("Distributed training not enabled, using single GPU/CPU training")
def _check_existing_models(self):
tokenizer_exists = os.path.exists(self.config.tokenizer_best_model_path)
basemodel_exists = os.path.exists(self.config.basemodel_best_model_path)
print(f"Tokenizer model exists: {tokenizer_exists}")
print(f"Basemodel model exists: {basemodel_exists}")
return tokenizer_exists, basemodel_exists
def _create_directories(self):
os.makedirs(self.config.tokenizer_save_path, exist_ok=True)
os.makedirs(self.config.basemodel_save_path, exist_ok=True)
print(f"Created directory: {self.config.tokenizer_save_path}")
print(f"Created directory: {self.config.basemodel_save_path}")
def train_tokenizer_phase(self):
print("\n" + "="*60)
print("Starting Tokenizer Fine-tuning Phase")
print("="*60)
tokenizer_exists, _ = self._check_existing_models()
if tokenizer_exists and self.config.skip_existing:
print("Tokenizer model already exists, skipping training")
return True
log_dir = os.path.join(self.config.base_save_path, "logs")
logger = setup_tokenizer_logging(self.config.exp_name, log_dir, self.rank)
set_seed(self.config.seed)
if getattr(self.config, 'pre_trained_tokenizer', True):
logger.info("Loading pretrained tokenizer...")
if self.rank == 0:
print("Loading pretrained tokenizer...")
tokenizer = KronosTokenizer.from_pretrained(self.config.pretrained_tokenizer_path)
else:
if self.rank == 0:
print("pre_trained_tokenizer=False, randomly initializing Tokenizer architecture")
import json
cfg_path = os.path.join(self.config.pretrained_tokenizer_path, 'config.json')
with open(cfg_path, 'r') as f:
arch = json.load(f)
tokenizer = KronosTokenizer(
d_in=arch.get('d_in', 6),
d_model=arch.get('d_model', 256),
n_heads=arch.get('n_heads', 4),
ff_dim=arch.get('ff_dim', 512),
n_enc_layers=arch.get('n_enc_layers', 4),
n_dec_layers=arch.get('n_dec_layers', 4),
ffn_dropout_p=arch.get('ffn_dropout_p', 0.0),
attn_dropout_p=arch.get('attn_dropout_p', 0.0),
resid_dropout_p=arch.get('resid_dropout_p', 0.0),
s1_bits=arch.get('s1_bits', 10),
s2_bits=arch.get('s2_bits', 10),
beta=arch.get('beta', 0.05),
gamma0=arch.get('gamma0', 1.0),
gamma=arch.get('gamma', 1.1),
zeta=arch.get('zeta', 0.05),
group_size=arch.get('group_size', 4)
)
tokenizer = tokenizer.to(self.device)
model_size = sum(p.numel() for p in tokenizer.parameters())
logger.info(f"Tokenizer parameters: {model_size:,}")
if self.rank == 0:
print(f"Tokenizer parameters: {model_size:,}")
logger.info("=== Training Configuration ===")
logger.info(f"Data path: {self.config.data_path}")
logger.info(f"Lookback window: {self.config.lookback_window}")
logger.info(f"Predict window: {self.config.predict_window}")
logger.info(f"Batch size: {self.config.batch_size}")
logger.info(f"Learning rate: {self.config.tokenizer_learning_rate}")
logger.info(f"Training epochs: {self.config.tokenizer_epochs}")
logger.info(f"Device: {self.device}")
logger.info(f"Distributed training: False")
logger.info("Starting tokenizer fine-tuning training...")
if self.rank == 0:
print("Starting tokenizer fine-tuning training...")
start_time = time.time()
best_val_loss = train_tokenizer(
tokenizer,
self.device,
self.config,
self.config.tokenizer_save_path,
logger,
)
training_time = time.time() - start_time
final_msg = f"Tokenizer training completed! Best validation loss: {best_val_loss:.4f}\nTraining time: {training_time/60:.2f} minutes\nModel saved to: {self.config.tokenizer_save_path}"
logger.info(final_msg)
if self.rank == 0:
print(f"\n{final_msg}")
return True
def train_basemodel_phase(self):
print("\n" + "="*60)
print("Starting Basemodel Fine-tuning Phase")
print("="*60)
if getattr(self.config, 'pre_trained_tokenizer', True):
if not os.path.exists(self.config.finetuned_tokenizer_path):
raise FileNotFoundError(f"Fine-tuned tokenizer does not exist: {self.config.finetuned_tokenizer_path}")
_, basemodel_exists = self._check_existing_models()
if basemodel_exists and self.config.skip_existing:
print("Basemodel model already exists, skipping training")
return True
log_dir = os.path.join(self.config.base_save_path, "logs")
logger = setup_basemodel_logging(self.config.exp_name, log_dir, self.rank)
set_seed(self.config.seed)
if getattr(self.config, 'pre_trained_tokenizer', True):
logger.info("Loading fine-tuned tokenizer...")
if self.rank == 0:
print("Loading fine-tuned tokenizer...")
tokenizer = KronosTokenizer.from_pretrained(self.config.finetuned_tokenizer_path)
else:
if self.rank == 0:
print("pre_trained_tokenizer=False, randomly initializing Tokenizer architecture for Predictor training")
import json
cfg_path = os.path.join(self.config.pretrained_tokenizer_path, 'config.json')
with open(cfg_path, 'r') as f:
arch = json.load(f)
tokenizer = KronosTokenizer(
d_in=arch.get('d_in', 6),
d_model=arch.get('d_model', 256),
n_heads=arch.get('n_heads', 4),
ff_dim=arch.get('ff_dim', 512),
n_enc_layers=arch.get('n_enc_layers', 4),
n_dec_layers=arch.get('n_dec_layers', 4),
ffn_dropout_p=arch.get('ffn_dropout_p', 0.0),
attn_dropout_p=arch.get('attn_dropout_p', 0.0),
resid_dropout_p=arch.get('resid_dropout_p', 0.0),
s1_bits=arch.get('s1_bits', 10),
s2_bits=arch.get('s2_bits', 10),
beta=arch.get('beta', 0.05),
gamma0=arch.get('gamma0', 1.0),
gamma=arch.get('gamma', 1.1),
zeta=arch.get('zeta', 0.05),
group_size=arch.get('group_size', 4)
)
tokenizer = tokenizer.to(self.device)
if getattr(self.config, 'pre_trained_predictor', True):
logger.info("Loading pretrained predictor...")
if self.rank == 0:
print("Loading pretrained predictor...")
model = Kronos.from_pretrained(self.config.pretrained_predictor_path)
else:
if self.rank == 0:
print("pre_trained_predictor=False, randomly initializing Predictor architecture")
import json
cfg_path = os.path.join(self.config.pretrained_predictor_path, 'config.json')
with open(cfg_path, 'r') as f:
arch = json.load(f)
print("model_config: ", arch)
model = Kronos(
s1_bits=arch.get('s1_bits', 10),
s2_bits=arch.get('s2_bits', 10),
n_layers=arch.get('n_layers', 12),
d_model=arch.get('d_model', 832),
n_heads=arch.get('n_heads', 16),
ff_dim=arch.get('ff_dim', 2048),
ffn_dropout_p=arch.get('ffn_dropout_p', 0.2),
attn_dropout_p=arch.get('attn_dropout_p', 0.0),
resid_dropout_p=arch.get('resid_dropout_p', 0.2),
token_dropout_p=arch.get('token_dropout_p', 0.0),
learn_te=arch.get('learn_te', True)
)
model = model.to(self.device)
model_size = sum(p.numel() for p in model.parameters())
logger.info(f"Model parameters: {model_size:,}")
if self.rank == 0:
print(f"Model parameters: {model_size:,}")
logger.info("=== Training Configuration ===")
logger.info(f"Data path: {self.config.data_path}")
logger.info(f"Lookback window: {self.config.lookback_window}")
logger.info(f"Predict window: {self.config.predict_window}")
logger.info(f"Batch size: {self.config.batch_size}")
logger.info(f"Learning rate: {self.config.predictor_learning_rate}")
logger.info(f"Training epochs: {self.config.basemodel_epochs}")
logger.info(f"Device: {self.device}")
logger.info(f"Tokenizer path: {self.config.finetuned_tokenizer_path}")
logger.info(f"Pretrained model path: {self.config.pretrained_predictor_path}")
logger.info("Starting fine-tuning training...")
if self.rank == 0:
print("Starting fine-tuning training...")
start_time = time.time()
best_val_loss = train_model(
model,
tokenizer,
self.device,
self.config,
self.config.basemodel_save_path,
logger,
)
training_time = time.time() - start_time
final_msg = f"Basemodel training completed! Best validation loss: {best_val_loss:.4f}\nTraining time: {training_time/60:.2f} minutes\nModel saved to: {self.config.basemodel_save_path}"
logger.info(final_msg)
if self.rank == 0:
print(f"\n{final_msg}")
return True
def run_training(self):
if self.rank == 0:
print("Starting Kronos model sequential fine-tuning training")
print(f"Experiment name: {self.config.experiment_name}")
print(f"Experiment description: {self.config.experiment_description}")
self._setup_distributed()
self._create_directories()
tokenizer_exists, basemodel_exists = self._check_existing_models()
total_start_time = time.time()
try:
if self.config.train_tokenizer:
success = self.train_tokenizer_phase()
if not success:
print("Tokenizer training failed, terminating training")
return False
else:
print("Skipping Tokenizer training phase")
if self.config.train_basemodel:
success = self.train_basemodel_phase()
if not success:
print("Basemodel training failed, terminating training")
return False
else:
print("Skipping Basemodel training phase")
total_time = time.time() - total_start_time
if self.rank == 0:
print("\n" + "="*60)
print("Training completed!")
print("="*60)
print(f"Total training time: {total_time/60:.2f} minutes")
print(f"Tokenizer model: {self.config.tokenizer_best_model_path}")
print(f"Basemodel model: {self.config.basemodel_best_model_path}")
print("="*60)
return True
except Exception as e:
if self.rank == 0:
print(f"Error occurred during training: {str(e)}")
import traceback
traceback.print_exc()
return False
finally:
pass
def main():
parser = argparse.ArgumentParser(description='Kronos Model Sequential Fine-tuning Training')
parser.add_argument('--config', type=str, default='config.yaml',
help='Configuration file path (default: config.yaml)')
parser.add_argument('--skip-tokenizer', action='store_true',
help='Skip tokenizer training phase')
parser.add_argument('--skip-basemodel', action='store_true',
help='Skip basemodel training phase')
parser.add_argument('--skip-existing', action='store_true',
help='Skip training for existing models')
args = parser.parse_args()
trainer = SequentialTrainer(args.config)
if args.skip_tokenizer:
trainer.config.train_tokenizer = False
if args.skip_basemodel:
trainer.config.train_basemodel = False
if args.skip_existing:
trainer.config.skip_existing = True
success = trainer.run_training()
if success:
print("Training completed successfully!")
if dist.is_available() and dist.is_initialized():
dist.barrier()
dist.destroy_process_group()
sys.exit(0)
else:
print("Training failed!")
if dist.is_available() and dist.is_initialized():
try:
dist.barrier()
dist.destroy_process_group()
except Exception:
pass
sys.exit(1)
if __name__ == "__main__":
main()

View File

@ -440,9 +440,24 @@ class HierarchicalEmbedding(nn.Module):
nn.init.normal_(self.emb_s1.weight, mean=0, std=d_model ** -0.5) nn.init.normal_(self.emb_s1.weight, mean=0, std=d_model ** -0.5)
nn.init.normal_(self.emb_s2.weight, mean=0, std=d_model ** -0.5) nn.init.normal_(self.emb_s2.weight, mean=0, std=d_model ** -0.5)
def split_token(self, token_ids: torch.Tensor, s2_bits: int):
"""Inputs:
token_ids (torch.Tensor): Composite token IDs of shape [batch_size, seq_len] or [N], each in range [0, 2^(s1_bits + s2_bits) - 1].
s2_bits (int): Number of low bits used for the fine token (s2).
"""
assert isinstance(s2_bits, int) and s2_bits > 0, "s2_bits must be a positive integer"
t = token_ids.long()
mask = (1 << s2_bits) - 1
s2_ids = t & mask # extract low bits
s1_ids = t >> s2_bits # extract high bits
return s1_ids, s2_ids
def forward(self, token_ids): def forward(self, token_ids):
"""Inputs: """Inputs:
token_ids: [batch_size, seq_len] token ID token_ids:
- tuple or list: (s1_ids, s2_ids), each of shape [batch_size, seq_len], or
- torch.Tensor: composite token IDs of shape [batch_size, seq_len], which will be split into (s1_ids, s2_ids) internally.
Output: [batch_size, seq_len, d_model] Output: [batch_size, seq_len, d_model]
""" """
if isinstance(token_ids, tuple) or isinstance(token_ids, list): if isinstance(token_ids, tuple) or isinstance(token_ids, list):

419
trader.py Normal file
View File

@ -0,0 +1,419 @@
# -*- coding: utf-8 -*-
"""
BTCUSD 15m 多次采样预测 区间聚合可视化
Author:
"""
import os
import sys
import math
import json
import time
import random
import argparse
from datetime import datetime
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from dotenv import load_dotenv
from pybit.unified_trading import HTTP
# ==== Kronos ====
from model import Kronos, KronosTokenizer, KronosPredictor
# ========== Matplotlib 字体 ==========
plt.rcParams["font.sans-serif"] = ["Microsoft YaHei", "SimHei", "Arial"]
plt.rcParams["axes.unicode_minus"] = False
# ========== 默认参数 ==========
NUM_SAMPLES = 20 # 多次采样次数(建议 10~50
QUANTILES = [0.1, 0.5, 0.9] # 预测区间分位(下/中/上)
LOOKBACK = 360 # 历史窗口
PRED_LEN = 96 # 未来预测点数
INTERVAL_MIN = 15 # K 线周期(分钟)
DEVICE = os.getenv("KRONOS_DEVICE", "cuda:0")
TEMPERATURE = 1.0
TOP_P = 0.9
OUTPUT_DIR = "figures"
os.makedirs(OUTPUT_DIR, exist_ok=True)
# ========== 实用函数 ==========
def set_global_seed(seed: int):
"""尽可能固定随机性(若底层采样逻辑使用 torch/np/random"""
try:
import torch
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
except Exception:
pass
random.seed(seed)
np.random.seed(seed)
def fetch_kline(
session: HTTP, symbol="BTCUSD", interval="15", limit=1000, category="inverse"
):
print("正在获取K线数据...")
resp = session.get_kline(
category=category, symbol=symbol, interval=interval, limit=limit
)
if resp.get("retCode", -1) != 0:
raise RuntimeError(f"获取数据失败: {resp.get('retMsg')}")
lst = resp["result"]["list"]
df = pd.DataFrame(
lst,
columns=["timestamps", "open", "high", "low", "close", "volume", "turnover"],
)
df["timestamps"] = pd.to_datetime(df["timestamps"].astype(float), unit="ms")
for c in ["open", "high", "low", "close", "volume", "turnover"]:
df[c] = df[c].astype(float)
df = df.sort_values("timestamps").reset_index(drop=True)
df["amount"] = df["turnover"]
print(f"获取到 {len(df)} 根K线数据")
return df
def prepare_io_windows(
df: pd.DataFrame, lookback=LOOKBACK, pred_len=PRED_LEN, interval_min=INTERVAL_MIN
):
end_idx = len(df)
start_idx = max(0, end_idx - lookback)
x_df = df.loc[
start_idx : end_idx - 1, ["open", "high", "low", "close", "volume", "amount"]
].reset_index(drop=True)
x_timestamp = df.loc[start_idx : end_idx - 1, "timestamps"].reset_index(drop=True)
last_ts = df.loc[end_idx - 1, "timestamps"]
future_timestamps = pd.date_range(
start=last_ts + pd.Timedelta(minutes=interval_min),
periods=pred_len,
freq=f"{interval_min}min",
)
y_timestamp = pd.Series(future_timestamps)
y_timestamp.index = range(len(y_timestamp))
print(f"数据总量: {len(df)} 根K线")
print(f"使用最新的 {lookback} 根K线索引 {start_idx}{end_idx-1}")
print(f"最后一根历史K线时间: {last_ts}")
print(f"预测未来 {pred_len} 根K线")
return x_df, x_timestamp, y_timestamp, start_idx, end_idx
def load_kronos(device=DEVICE):
print("正在加载 Kronos 模型...")
tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
model = Kronos.from_pretrained("NeoQuasar/Kronos-base")
predictor = KronosPredictor(model, tokenizer, device=device, max_context=512)
print("模型加载完成!\n")
return predictor
def run_one_prediction(
predictor,
x_df,
x_timestamp,
y_timestamp,
T=TEMPERATURE,
top_p=TOP_P,
seed=None,
verbose=False,
):
if seed is not None:
set_global_seed(seed)
return predictor.predict(
df=x_df,
x_timestamp=x_timestamp,
y_timestamp=y_timestamp,
pred_len=len(y_timestamp),
T=T,
top_p=top_p,
verbose=verbose,
)
def run_multi_predictions(
predictor, x_df, x_timestamp, y_timestamp, num_samples=NUM_SAMPLES, base_seed=42,
temperature=TEMPERATURE, top_p=TOP_P
):
preds = []
print(f"正在进行多次预测:{num_samples}T={temperature}, top_p={top_p}...")
for i in tqdm(range(num_samples), desc="预测进度", unit="", ncols=100):
seed = base_seed + i # 每次不同 seed
pred_df = run_one_prediction(
predictor,
x_df,
x_timestamp,
y_timestamp,
T=temperature,
top_p=top_p,
seed=seed,
verbose=False,
)
# 兼容性处理:确保只取需要的列
cols_present = [
c
for c in ["open", "high", "low", "close", "volume", "amount"]
if c in pred_df.columns
]
pred_df = pred_df[cols_present].copy()
pred_df.reset_index(drop=True, inplace=True)
preds.append(pred_df)
print("多次预测完成。\n")
return preds
def aggregate_quantiles(pred_list, quantiles=QUANTILES):
"""
将多次预测列表聚合成分位数 DataFrame
输出列命名<col>_q10, <col>_q50, <col>_q90 quantiles 中的值
"""
# 先把每次预测拼成 3Dtime x feature x samples
keys = pred_list[0].columns.tolist()
T_len = len(pred_list[0])
S = len(pred_list)
data = {k: np.zeros((T_len, S), dtype=float) for k in keys}
for j, pdf in enumerate(pred_list):
for k in keys:
data[k][:, j] = pdf[k].values
out = {}
for k in keys:
for q in quantiles:
qv = np.quantile(data[k], q, axis=1)
out[f"{k}_q{int(q*100):02d}"] = qv
agg_df = pd.DataFrame(out)
return agg_df
def plot_results(historical_df, y_timestamp, agg_df, title_prefix="BTCUSD 15分钟"):
"""
上下两个子图共享X轴
- 上方高度3历史收盘价 + 预测收盘价区间(q10~q90) + 中位线(q50)
- 下方高度1历史成交量柱 + 预测成交量中位柱仅q50不显示区间
"""
import matplotlib.gridspec as gridspec
# 智能推断K线宽度柱宽
try:
if len(historical_df) >= 2:
hist_step = (historical_df["timestamps"].iloc[-1] - historical_df["timestamps"].iloc[-2])
else:
hist_step = pd.Timedelta(minutes=10)
if len(y_timestamp) >= 2:
pred_step = (y_timestamp.iloc[1] - y_timestamp.iloc[0])
else:
pred_step = pd.Timedelta(minutes=10)
bar_width_hist = hist_step * 0.8
bar_width_pred = pred_step * 0.8
except Exception:
bar_width_hist = pd.Timedelta(minutes=10)
bar_width_pred = pd.Timedelta(minutes=10)
# 取出预测分位
close_q10 = agg_df["close_q10"] if "close_q10" in agg_df else None
close_q50 = agg_df["close_q50"] if "close_q50" in agg_df else None
close_q90 = agg_df["close_q90"] if "close_q90" in agg_df else None
vol_q50 = agg_df["volume_q50"] if "volume_q50" in agg_df else None
# 图形与网格
fig = plt.figure(figsize=(18, 10))
gs = gridspec.GridSpec(2, 1, height_ratios=[3, 1], hspace=0.08) # 高度1:3
# ===== 上:收盘价 =====
ax_price = fig.add_subplot(gs[0])
ax_price.plot(
historical_df["timestamps"],
historical_df["close"],
label="历史收盘价",
linewidth=1.8,
)
# 预测收盘价区间与中位线
if close_q10 is not None and close_q90 is not None:
ax_price.fill_between(
y_timestamp.values, close_q10, close_q90, alpha=0.25, label="预测收盘区间(q10~q90)"
)
if close_q50 is not None:
ax_price.plot(
y_timestamp.values, close_q50, linestyle="--", linewidth=2, label="预测收盘中位线(q50)"
)
# 预测起点与阴影
if len(y_timestamp) > 0:
ax_price.axvline(x=y_timestamp.iloc[0], color="gray", linestyle=":", linewidth=1, label="预测起点")
ax_price.axvspan(y_timestamp.iloc[0], y_timestamp.iloc[-1], color="yellow", alpha=0.08)
# 绘制每天 16:00 UTC+8 的竖直虚线并标注收盘价
# 合并历史和预测时间范围与价格数据
all_timestamps = pd.concat([historical_df["timestamps"], y_timestamp], ignore_index=True)
hist_close = historical_df["close"]
pred_close = close_q50 if close_q50 is not None else pd.Series([np.nan] * len(y_timestamp))
all_close = pd.concat([hist_close.reset_index(drop=True), pred_close.reset_index(drop=True)], ignore_index=True)
if len(all_timestamps) > 0:
start_time = all_timestamps.min()
end_time = all_timestamps.max()
# 生成所有16:00时间点UTC+8
current_date = start_time.normalize() # 当天零点
while current_date <= end_time:
target_time = current_date + pd.Timedelta(hours=16) # 16:00
if start_time <= target_time <= end_time:
# 画虚线
ax_price.axvline(x=target_time, color='blue', linestyle='--', linewidth=0.8, alpha=0.5)
# 找到最接近16:00的时间点的收盘价
time_diffs = (all_timestamps - target_time).abs()
closest_idx = time_diffs.idxmin()
closest_time = all_timestamps.iloc[closest_idx]
closest_price = all_close.iloc[closest_idx]
# 如果时间差不超过1小时则标注价格
if time_diffs.iloc[closest_idx] <= pd.Timedelta(hours=1) and not np.isnan(closest_price):
ax_price.plot(closest_time, closest_price, 'o', color='blue', markersize=6, alpha=0.7)
ax_price.text(
closest_time, closest_price,
f' ${closest_price:.1f}',
fontsize=9,
color='blue',
verticalalignment='bottom',
horizontalalignment='left',
bbox=dict(boxstyle='round,pad=0.3', facecolor='white', edgecolor='blue', alpha=0.7)
)
current_date += pd.Timedelta(days=1)
ax_price.set_ylabel("价格 (USD)", fontsize=11)
ax_price.set_title(f"{title_prefix} - 收盘价 & 成交量(历史 + 预测)", fontsize=15, fontweight="bold")
ax_price.grid(True, alpha=0.3)
ax_price.legend(loc="best", fontsize=9)
# ===== 下:成交量(仅预测中位量能柱)=====
ax_vol = fig.add_subplot(gs[1], sharex=ax_price)
# 历史量能柱
ax_vol.bar(
historical_df["timestamps"],
historical_df["volume"],
width=bar_width_hist,
alpha=0.35,
label="历史成交量",
)
# 预测中位量能柱(不画区间)
if vol_q50 is not None:
ax_vol.bar(
y_timestamp.values,
vol_q50,
width=bar_width_pred,
alpha=0.6,
label="预测成交量中位(q50)",
)
# 预测起点线
if len(y_timestamp) > 0:
ax_vol.axvline(x=y_timestamp.iloc[0], color="gray", linestyle=":", linewidth=1)
ax_vol.set_ylabel("成交量", fontsize=11)
ax_vol.set_xlabel("时间", fontsize=11)
ax_vol.grid(True, alpha=0.25)
ax_vol.legend(loc="best", fontsize=9)
# 避免X轴标签重叠
plt.setp(ax_price.get_xticklabels(), visible=False)
plt.tight_layout()
return fig
def main():
parser = argparse.ArgumentParser(description="Kronos 多次预测区间可视化")
parser.add_argument("--symbol", default="BTCUSD")
parser.add_argument("--category", default="inverse")
parser.add_argument("--interval", default="15")
parser.add_argument("--limit", type=int, default=1000)
parser.add_argument("--lookback", type=int, default=LOOKBACK)
parser.add_argument("--pred_len", type=int, default=PRED_LEN)
parser.add_argument("--samples", type=int, default=NUM_SAMPLES)
parser.add_argument("--temperature", type=float, default=TEMPERATURE)
parser.add_argument("--top_p", type=float, default=TOP_P)
parser.add_argument("--quantiles", default="0.1,0.5,0.9")
parser.add_argument("--device", default=DEVICE)
parser.add_argument("--seed", type=int, default=42)
args = parser.parse_args()
# 载入env & HTTP
load_dotenv()
api_key = os.getenv("BYBIT_API_KEY")
api_secret = os.getenv("BYBIT_API_SECRET")
session = HTTP(testnet=False, api_key=api_key, api_secret=api_secret)
# Kronos
predictor = load_kronos(device=args.device)
# 拉数据
df = fetch_kline(
session,
symbol=args.symbol,
interval=args.interval,
limit=args.limit,
category=args.category,
)
x_df, x_ts, y_ts, start_idx, end_idx = prepare_io_windows(
df,
lookback=args.lookback,
pred_len=args.pred_len,
interval_min=int(args.interval),
)
# 历史用于画图的窗口最近200根可按需调整
plot_start = max(0, end_idx - 200)
historical_df = df.loc[plot_start : end_idx - 1].copy()
# 采样参数(直接使用局部变量,不需要修改全局变量)
temperature = args.temperature
top_p = args.top_p
qs = [float(x) for x in args.quantiles.split(",") if x.strip()]
# 多次预测(传入局部变量)
preds = run_multi_predictions(
predictor, x_df, x_ts, y_ts, num_samples=args.samples, base_seed=args.seed,
temperature=temperature, top_p=top_p
)
# 聚合分位
agg_df = aggregate_quantiles(preds, quantiles=qs)
agg_df["timestamps"] = y_ts.values
# 输出 CSV聚合
out_csv = os.path.join(
OUTPUT_DIR, f"{args.symbol.lower()}_interval_pred_{args.interval}m.csv"
)
agg_df.to_csv(out_csv, index=False)
print(f"预测分位数据已保存到: {out_csv}")
# 作图
fig = plot_results(
historical_df, y_ts, agg_df, title_prefix=f"{args.symbol} {args.interval}分钟"
)
out_png = os.path.join(
OUTPUT_DIR, f"{args.symbol.lower()}_interval_pred_{args.interval}m.png"
)
fig.savefig(out_png, dpi=150, bbox_inches="tight")
print(f"预测区间图表已保存到: {out_png}")
plt.show()
if __name__ == "__main__":
try:
main()
except Exception as e:
print("运行出错:", repr(e))
sys.exit(1)

View File

@ -1,135 +0,0 @@
# Kronos Web UI
Web user interface for Kronos financial prediction model, providing intuitive graphical operation interface.
## ✨ Features
- **Multi-format data support**: Supports CSV, Feather and other financial data formats
- **Smart time window**: Fixed 400+120 data point time window slider selection
- **Real model prediction**: Integrated real Kronos model, supports multiple model sizes
- **Prediction quality control**: Adjustable temperature, nucleus sampling, sample count and other parameters
- **Multi-device support**: Supports CPU, CUDA, MPS and other computing devices
- **Comparison analysis**: Detailed comparison between prediction results and actual data
- **K-line chart display**: Professional financial K-line chart display
## 🚀 Quick Start
### Method 1: Start with Python script
```bash
cd webui
python run.py
```
### Method 2: Start with Shell script
```bash
cd webui
chmod +x start.sh
./start.sh
```
### Method 3: Start Flask application directly
```bash
cd webui
python app.py
```
After successful startup, visit http://localhost:7070
## 📋 Usage Steps
1. **Load data**: Select financial data file from data directory
2. **Load model**: Select Kronos model and computing device
3. **Set parameters**: Adjust prediction quality parameters
4. **Select time window**: Use slider to select 400+120 data point time range
5. **Start prediction**: Click prediction button to generate results
6. **View results**: View prediction results in charts and tables
## 🔧 Prediction Quality Parameters
### Temperature (T)
- **Range**: 0.1 - 2.0
- **Effect**: Controls prediction randomness
- **Recommendation**: 1.2-1.5 for better prediction quality
### Nucleus Sampling (top_p)
- **Range**: 0.1 - 1.0
- **Effect**: Controls prediction diversity
- **Recommendation**: 0.95-1.0 to consider more possibilities
### Sample Count
- **Range**: 1 - 5
- **Effect**: Generate multiple prediction samples
- **Recommendation**: 2-3 samples to improve quality
## 📊 Supported Data Formats
### Required Columns
- `open`: Opening price
- `high`: Highest price
- `low`: Lowest price
- `close`: Closing price
### Optional Columns
- `volume`: Trading volume
- `amount`: Trading amount (not used for prediction)
- `timestamps`/`timestamp`/`date`: Timestamp
## 🤖 Model Support
- **Kronos-mini**: 4.1M parameters, lightweight fast prediction
- **Kronos-small**: 24.7M parameters, balanced performance and speed
- **Kronos-base**: 102.3M parameters, high quality prediction
## 🖥️ GPU Acceleration Support
- **CPU**: General computing, best compatibility
- **CUDA**: NVIDIA GPU acceleration, best performance
- **MPS**: Apple Silicon GPU acceleration, recommended for Mac users
## ⚠️ Notes
- `amount` column is not used for prediction, only for display
- Time window is fixed at 400+120=520 data points
- Ensure data file contains sufficient historical data
- First model loading may require download, please be patient
## 🔍 Comparison Analysis
The system automatically provides comparison analysis between prediction results and actual data, including:
- Price difference statistics
- Error analysis
- Prediction quality assessment
## 🛠️ Technical Architecture
- **Backend**: Flask + Python
- **Frontend**: HTML + CSS + JavaScript
- **Charts**: Plotly.js
- **Data processing**: Pandas + NumPy
- **Model**: Hugging Face Transformers
## 📝 Troubleshooting
### Common Issues
1. **Port occupied**: Modify port number in app.py
2. **Missing dependencies**: Run `pip install -r requirements.txt`
3. **Model loading failed**: Check network connection and model ID
4. **Data format error**: Ensure data column names and format are correct
### Log Viewing
Detailed runtime information will be displayed in the console at startup, including model status and error messages.
## 📄 License
This project follows the license terms of the original Kronos project.
## 🤝 Contributing
Welcome to submit Issues and Pull Requests to improve this Web UI!
## 📞 Support
If you have questions, please check:
1. Project documentation
2. GitHub Issues
3. Console error messages

View File

@ -1,708 +0,0 @@
import os
import pandas as pd
import numpy as np
import json
import plotly.graph_objects as go
import plotly.utils
from flask import Flask, render_template, request, jsonify
from flask_cors import CORS
import sys
import warnings
import datetime
warnings.filterwarnings('ignore')
# Add project root directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
try:
from model import Kronos, KronosTokenizer, KronosPredictor
MODEL_AVAILABLE = True
except ImportError:
MODEL_AVAILABLE = False
print("Warning: Kronos model cannot be imported, will use simulated data for demonstration")
app = Flask(__name__)
CORS(app)
# Global variables to store models
tokenizer = None
model = None
predictor = None
# Available model configurations
AVAILABLE_MODELS = {
'kronos-mini': {
'name': 'Kronos-mini',
'model_id': 'NeoQuasar/Kronos-mini',
'tokenizer_id': 'NeoQuasar/Kronos-Tokenizer-2k',
'context_length': 2048,
'params': '4.1M',
'description': 'Lightweight model, suitable for fast prediction'
},
'kronos-small': {
'name': 'Kronos-small',
'model_id': 'NeoQuasar/Kronos-small',
'tokenizer_id': 'NeoQuasar/Kronos-Tokenizer-base',
'context_length': 512,
'params': '24.7M',
'description': 'Small model, balanced performance and speed'
},
'kronos-base': {
'name': 'Kronos-base',
'model_id': 'NeoQuasar/Kronos-base',
'tokenizer_id': 'NeoQuasar/Kronos-Tokenizer-base',
'context_length': 512,
'params': '102.3M',
'description': 'Base model, provides better prediction quality'
}
}
def load_data_files():
"""Scan data directory and return available data files"""
data_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'data')
data_files = []
if os.path.exists(data_dir):
for file in os.listdir(data_dir):
if file.endswith(('.csv', '.feather')):
file_path = os.path.join(data_dir, file)
file_size = os.path.getsize(file_path)
data_files.append({
'name': file,
'path': file_path,
'size': f"{file_size / 1024:.1f} KB" if file_size < 1024*1024 else f"{file_size / (1024*1024):.1f} MB"
})
return data_files
def load_data_file(file_path):
"""Load data file"""
try:
if file_path.endswith('.csv'):
df = pd.read_csv(file_path)
elif file_path.endswith('.feather'):
df = pd.read_feather(file_path)
else:
return None, "Unsupported file format"
# Check required columns
required_cols = ['open', 'high', 'low', 'close']
if not all(col in df.columns for col in required_cols):
return None, f"Missing required columns: {required_cols}"
# Process timestamp column
if 'timestamps' in df.columns:
df['timestamps'] = pd.to_datetime(df['timestamps'])
elif 'timestamp' in df.columns:
df['timestamps'] = pd.to_datetime(df['timestamp'])
elif 'date' in df.columns:
# If column name is 'date', rename it to 'timestamps'
df['timestamps'] = pd.to_datetime(df['date'])
else:
# If no timestamp column exists, create one
df['timestamps'] = pd.date_range(start='2024-01-01', periods=len(df), freq='1H')
# Ensure numeric columns are numeric type
for col in ['open', 'high', 'low', 'close']:
df[col] = pd.to_numeric(df[col], errors='coerce')
# Process volume column (optional)
if 'volume' in df.columns:
df['volume'] = pd.to_numeric(df['volume'], errors='coerce')
# Process amount column (optional, but not used for prediction)
if 'amount' in df.columns:
df['amount'] = pd.to_numeric(df['amount'], errors='coerce')
# Remove rows containing NaN values
df = df.dropna()
return df, None
except Exception as e:
return None, f"Failed to load file: {str(e)}"
def save_prediction_results(file_path, prediction_type, prediction_results, actual_data, input_data, prediction_params):
"""Save prediction results to file"""
try:
# Create prediction results directory
results_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prediction_results')
os.makedirs(results_dir, exist_ok=True)
# Generate filename
timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
filename = f'prediction_{timestamp}.json'
filepath = os.path.join(results_dir, filename)
# Prepare data for saving
save_data = {
'timestamp': datetime.datetime.now().isoformat(),
'file_path': file_path,
'prediction_type': prediction_type,
'prediction_params': prediction_params,
'input_data_summary': {
'rows': len(input_data),
'columns': list(input_data.columns),
'price_range': {
'open': {'min': float(input_data['open'].min()), 'max': float(input_data['open'].max())},
'high': {'min': float(input_data['high'].min()), 'max': float(input_data['high'].max())},
'low': {'min': float(input_data['low'].min()), 'max': float(input_data['low'].max())},
'close': {'min': float(input_data['close'].min()), 'max': float(input_data['close'].max())}
},
'last_values': {
'open': float(input_data['open'].iloc[-1]),
'high': float(input_data['high'].iloc[-1]),
'low': float(input_data['low'].iloc[-1]),
'close': float(input_data['close'].iloc[-1])
}
},
'prediction_results': prediction_results,
'actual_data': actual_data,
'analysis': {}
}
# If actual data exists, perform comparison analysis
if actual_data and len(actual_data) > 0:
# Calculate continuity analysis
if len(prediction_results) > 0 and len(actual_data) > 0:
last_pred = prediction_results[0] # First prediction point
first_actual = actual_data[0] # First actual point
save_data['analysis']['continuity'] = {
'last_prediction': {
'open': last_pred['open'],
'high': last_pred['high'],
'low': last_pred['low'],
'close': last_pred['close']
},
'first_actual': {
'open': first_actual['open'],
'high': first_actual['high'],
'low': first_actual['low'],
'close': first_actual['close']
},
'gaps': {
'open_gap': abs(last_pred['open'] - first_actual['open']),
'high_gap': abs(last_pred['high'] - first_actual['high']),
'low_gap': abs(last_pred['low'] - first_actual['low']),
'close_gap': abs(last_pred['close'] - first_actual['close'])
},
'gap_percentages': {
'open_gap_pct': (abs(last_pred['open'] - first_actual['open']) / first_actual['open']) * 100,
'high_gap_pct': (abs(last_pred['high'] - first_actual['high']) / first_actual['high']) * 100,
'low_gap_pct': (abs(last_pred['low'] - first_actual['low']) / first_actual['low']) * 100,
'close_gap_pct': (abs(last_pred['close'] - first_actual['close']) / first_actual['close']) * 100
}
}
# Save to file
with open(filepath, 'w', encoding='utf-8') as f:
json.dump(save_data, f, indent=2, ensure_ascii=False)
print(f"Prediction results saved to: {filepath}")
return filepath
except Exception as e:
print(f"Failed to save prediction results: {e}")
return None
def create_prediction_chart(df, pred_df, lookback, pred_len, actual_df=None, historical_start_idx=0):
"""Create prediction chart"""
# Use specified historical data start position, not always from the beginning of df
if historical_start_idx + lookback + pred_len <= len(df):
# Display lookback historical points + pred_len prediction points starting from specified position
historical_df = df.iloc[historical_start_idx:historical_start_idx+lookback]
prediction_range = range(historical_start_idx+lookback, historical_start_idx+lookback+pred_len)
else:
# If data is insufficient, adjust to maximum available range
available_lookback = min(lookback, len(df) - historical_start_idx)
available_pred_len = min(pred_len, max(0, len(df) - historical_start_idx - available_lookback))
historical_df = df.iloc[historical_start_idx:historical_start_idx+available_lookback]
prediction_range = range(historical_start_idx+available_lookback, historical_start_idx+available_lookback+available_pred_len)
# Create chart
fig = go.Figure()
# Add historical data (candlestick chart)
fig.add_trace(go.Candlestick(
x=historical_df['timestamps'] if 'timestamps' in historical_df.columns else historical_df.index,
open=historical_df['open'],
high=historical_df['high'],
low=historical_df['low'],
close=historical_df['close'],
name='Historical Data (400 data points)',
increasing_line_color='#26A69A',
decreasing_line_color='#EF5350'
))
# Add prediction data (candlestick chart)
if pred_df is not None and len(pred_df) > 0:
# Calculate prediction data timestamps - ensure continuity with historical data
if 'timestamps' in df.columns and len(historical_df) > 0:
# Start from the last timestamp of historical data, create prediction timestamps with the same time interval
last_timestamp = historical_df['timestamps'].iloc[-1]
time_diff = df['timestamps'].iloc[1] - df['timestamps'].iloc[0] if len(df) > 1 else pd.Timedelta(hours=1)
pred_timestamps = pd.date_range(
start=last_timestamp + time_diff,
periods=len(pred_df),
freq=time_diff
)
else:
# If no timestamps, use index
pred_timestamps = range(len(historical_df), len(historical_df) + len(pred_df))
fig.add_trace(go.Candlestick(
x=pred_timestamps,
open=pred_df['open'],
high=pred_df['high'],
low=pred_df['low'],
close=pred_df['close'],
name='Prediction Data (120 data points)',
increasing_line_color='#66BB6A',
decreasing_line_color='#FF7043'
))
# Add actual data for comparison (if exists)
if actual_df is not None and len(actual_df) > 0:
# Actual data should be in the same time period as prediction data
if 'timestamps' in df.columns:
# Actual data should use the same timestamps as prediction data to ensure time alignment
if 'pred_timestamps' in locals():
actual_timestamps = pred_timestamps
else:
# If no prediction timestamps, calculate from the last timestamp of historical data
if len(historical_df) > 0:
last_timestamp = historical_df['timestamps'].iloc[-1]
time_diff = df['timestamps'].iloc[1] - df['timestamps'].iloc[0] if len(df) > 1 else pd.Timedelta(hours=1)
actual_timestamps = pd.date_range(
start=last_timestamp + time_diff,
periods=len(actual_df),
freq=time_diff
)
else:
actual_timestamps = range(len(historical_df), len(historical_df) + len(actual_df))
else:
actual_timestamps = range(len(historical_df), len(historical_df) + len(actual_df))
fig.add_trace(go.Candlestick(
x=actual_timestamps,
open=actual_df['open'],
high=actual_df['high'],
low=actual_df['low'],
close=actual_df['close'],
name='Actual Data (120 data points)',
increasing_line_color='#FF9800',
decreasing_line_color='#F44336'
))
# Update layout
fig.update_layout(
title='Kronos Financial Prediction Results - 400 Historical Points + 120 Prediction Points vs 120 Actual Points',
xaxis_title='Time',
yaxis_title='Price',
template='plotly_white',
height=600,
showlegend=True
)
# Ensure x-axis time continuity
if 'timestamps' in historical_df.columns:
# Get all timestamps and sort them
all_timestamps = []
if len(historical_df) > 0:
all_timestamps.extend(historical_df['timestamps'])
if 'pred_timestamps' in locals():
all_timestamps.extend(pred_timestamps)
if 'actual_timestamps' in locals():
all_timestamps.extend(actual_timestamps)
if all_timestamps:
all_timestamps = sorted(all_timestamps)
fig.update_xaxes(
range=[all_timestamps[0], all_timestamps[-1]],
rangeslider_visible=False,
type='date'
)
return json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder)
@app.route('/')
def index():
"""Home page"""
return render_template('index.html')
@app.route('/api/data-files')
def get_data_files():
"""Get available data file list"""
data_files = load_data_files()
return jsonify(data_files)
@app.route('/api/load-data', methods=['POST'])
def load_data():
"""Load data file"""
try:
data = request.get_json()
file_path = data.get('file_path')
if not file_path:
return jsonify({'error': 'File path cannot be empty'}), 400
df, error = load_data_file(file_path)
if error:
return jsonify({'error': error}), 400
# Detect data time frequency
def detect_timeframe(df):
if len(df) < 2:
return "Unknown"
time_diffs = []
for i in range(1, min(10, len(df))): # Check first 10 time differences
diff = df['timestamps'].iloc[i] - df['timestamps'].iloc[i-1]
time_diffs.append(diff)
if not time_diffs:
return "Unknown"
# Calculate average time difference
avg_diff = sum(time_diffs, pd.Timedelta(0)) / len(time_diffs)
# Convert to readable format
if avg_diff < pd.Timedelta(minutes=1):
return f"{avg_diff.total_seconds():.0f} seconds"
elif avg_diff < pd.Timedelta(hours=1):
return f"{avg_diff.total_seconds() / 60:.0f} minutes"
elif avg_diff < pd.Timedelta(days=1):
return f"{avg_diff.total_seconds() / 3600:.0f} hours"
else:
return f"{avg_diff.days} days"
# Return data information
data_info = {
'rows': len(df),
'columns': list(df.columns),
'start_date': df['timestamps'].min().isoformat() if 'timestamps' in df.columns else 'N/A',
'end_date': df['timestamps'].max().isoformat() if 'timestamps' in df.columns else 'N/A',
'price_range': {
'min': float(df[['open', 'high', 'low', 'close']].min().min()),
'max': float(df[['open', 'high', 'low', 'close']].max().max())
},
'prediction_columns': ['open', 'high', 'low', 'close'] + (['volume'] if 'volume' in df.columns else []),
'timeframe': detect_timeframe(df)
}
return jsonify({
'success': True,
'data_info': data_info,
'message': f'Successfully loaded data, total {len(df)} rows'
})
except Exception as e:
return jsonify({'error': f'Failed to load data: {str(e)}'}), 500
@app.route('/api/predict', methods=['POST'])
def predict():
"""Perform prediction"""
try:
data = request.get_json()
file_path = data.get('file_path')
lookback = int(data.get('lookback', 400))
pred_len = int(data.get('pred_len', 120))
# Get prediction quality parameters
temperature = float(data.get('temperature', 1.0))
top_p = float(data.get('top_p', 0.9))
sample_count = int(data.get('sample_count', 1))
if not file_path:
return jsonify({'error': 'File path cannot be empty'}), 400
# Load data
df, error = load_data_file(file_path)
if error:
return jsonify({'error': error}), 400
if len(df) < lookback:
return jsonify({'error': f'Insufficient data length, need at least {lookback} rows'}), 400
# Perform prediction
if MODEL_AVAILABLE and predictor is not None:
try:
# Use real Kronos model
# Only use necessary columns: OHLCV, excluding amount
required_cols = ['open', 'high', 'low', 'close']
if 'volume' in df.columns:
required_cols.append('volume')
# Process time period selection
start_date = data.get('start_date')
if start_date:
# Custom time period - fix logic: use data within selected window
start_dt = pd.to_datetime(start_date)
# Find data after start time
mask = df['timestamps'] >= start_dt
time_range_df = df[mask]
# Ensure sufficient data: lookback + pred_len
if len(time_range_df) < lookback + pred_len:
return jsonify({'error': f'Insufficient data from start time {start_dt.strftime("%Y-%m-%d %H:%M")}, need at least {lookback + pred_len} data points, currently only {len(time_range_df)} available'}), 400
# Use first lookback data points within selected window for prediction
x_df = time_range_df.iloc[:lookback][required_cols]
x_timestamp = time_range_df.iloc[:lookback]['timestamps']
# Use last pred_len data points within selected window as actual values
y_timestamp = time_range_df.iloc[lookback:lookback+pred_len]['timestamps']
# Calculate actual time period length
start_timestamp = time_range_df['timestamps'].iloc[0]
end_timestamp = time_range_df['timestamps'].iloc[lookback+pred_len-1]
time_span = end_timestamp - start_timestamp
prediction_type = f"Kronos model prediction (within selected window: first {lookback} data points for prediction, last {pred_len} data points for comparison, time span: {time_span})"
else:
# Use latest data
x_df = df.iloc[:lookback][required_cols]
x_timestamp = df.iloc[:lookback]['timestamps']
y_timestamp = df.iloc[lookback:lookback+pred_len]['timestamps']
prediction_type = "Kronos model prediction (latest data)"
# Ensure timestamps are Series format, not DatetimeIndex, to avoid .dt attribute error in Kronos model
if isinstance(x_timestamp, pd.DatetimeIndex):
x_timestamp = pd.Series(x_timestamp, name='timestamps')
if isinstance(y_timestamp, pd.DatetimeIndex):
y_timestamp = pd.Series(y_timestamp, name='timestamps')
pred_df = predictor.predict(
df=x_df,
x_timestamp=x_timestamp,
y_timestamp=y_timestamp,
pred_len=pred_len,
T=temperature,
top_p=top_p,
sample_count=sample_count
)
except Exception as e:
return jsonify({'error': f'Kronos model prediction failed: {str(e)}'}), 500
else:
return jsonify({'error': 'Kronos model not loaded, please load model first'}), 400
# Prepare actual data for comparison (if exists)
actual_data = []
actual_df = None
if start_date: # Custom time period
# Fix logic: use data within selected window
# Prediction uses first 400 data points within selected window
# Actual data should be last 120 data points within selected window
start_dt = pd.to_datetime(start_date)
# Find data starting from start_date
mask = df['timestamps'] >= start_dt
time_range_df = df[mask]
if len(time_range_df) >= lookback + pred_len:
# Get last 120 data points within selected window as actual values
actual_df = time_range_df.iloc[lookback:lookback+pred_len]
for i, (_, row) in enumerate(actual_df.iterrows()):
actual_data.append({
'timestamp': row['timestamps'].isoformat(),
'open': float(row['open']),
'high': float(row['high']),
'low': float(row['low']),
'close': float(row['close']),
'volume': float(row['volume']) if 'volume' in row else 0,
'amount': float(row['amount']) if 'amount' in row else 0
})
else: # Latest data
# Prediction uses first 400 data points
# Actual data should be 120 data points after first 400 data points
if len(df) >= lookback + pred_len:
actual_df = df.iloc[lookback:lookback+pred_len]
for i, (_, row) in enumerate(actual_df.iterrows()):
actual_data.append({
'timestamp': row['timestamps'].isoformat(),
'open': float(row['open']),
'high': float(row['high']),
'low': float(row['low']),
'close': float(row['close']),
'volume': float(row['volume']) if 'volume' in row else 0,
'amount': float(row['amount']) if 'amount' in row else 0
})
# Create chart - pass historical data start position
if start_date:
# Custom time period: find starting position of historical data in original df
start_dt = pd.to_datetime(start_date)
mask = df['timestamps'] >= start_dt
historical_start_idx = df[mask].index[0] if len(df[mask]) > 0 else 0
else:
# Latest data: start from beginning
historical_start_idx = 0
chart_json = create_prediction_chart(df, pred_df, lookback, pred_len, actual_df, historical_start_idx)
# Prepare prediction result data - fix timestamp calculation logic
if 'timestamps' in df.columns:
if start_date:
# Custom time period: use selected window data to calculate timestamps
start_dt = pd.to_datetime(start_date)
mask = df['timestamps'] >= start_dt
time_range_df = df[mask]
if len(time_range_df) >= lookback:
# Calculate prediction timestamps starting from last time point of selected window
last_timestamp = time_range_df['timestamps'].iloc[lookback-1]
time_diff = df['timestamps'].iloc[1] - df['timestamps'].iloc[0]
future_timestamps = pd.date_range(
start=last_timestamp + time_diff,
periods=pred_len,
freq=time_diff
)
else:
future_timestamps = []
else:
# Latest data: calculate from last time point of entire data file
last_timestamp = df['timestamps'].iloc[-1]
time_diff = df['timestamps'].iloc[1] - df['timestamps'].iloc[0]
future_timestamps = pd.date_range(
start=last_timestamp + time_diff,
periods=pred_len,
freq=time_diff
)
else:
future_timestamps = range(len(df), len(df) + pred_len)
prediction_results = []
for i, (_, row) in enumerate(pred_df.iterrows()):
prediction_results.append({
'timestamp': future_timestamps[i].isoformat() if i < len(future_timestamps) else f"T{i}",
'open': float(row['open']),
'high': float(row['high']),
'low': float(row['low']),
'close': float(row['close']),
'volume': float(row['volume']) if 'volume' in row else 0,
'amount': float(row['amount']) if 'amount' in row else 0
})
# Save prediction results to file
try:
save_prediction_results(
file_path=file_path,
prediction_type=prediction_type,
prediction_results=prediction_results,
actual_data=actual_data,
input_data=x_df,
prediction_params={
'lookback': lookback,
'pred_len': pred_len,
'temperature': temperature,
'top_p': top_p,
'sample_count': sample_count,
'start_date': start_date if start_date else 'latest'
}
)
except Exception as e:
print(f"Failed to save prediction results: {e}")
return jsonify({
'success': True,
'prediction_type': prediction_type,
'chart': chart_json,
'prediction_results': prediction_results,
'actual_data': actual_data,
'has_comparison': len(actual_data) > 0,
'message': f'Prediction completed, generated {pred_len} prediction points' + (f', including {len(actual_data)} actual data points for comparison' if len(actual_data) > 0 else '')
})
except Exception as e:
return jsonify({'error': f'Prediction failed: {str(e)}'}), 500
@app.route('/api/load-model', methods=['POST'])
def load_model():
"""Load Kronos model"""
global tokenizer, model, predictor
try:
if not MODEL_AVAILABLE:
return jsonify({'error': 'Kronos model library not available'}), 400
data = request.get_json()
model_key = data.get('model_key', 'kronos-small')
device = data.get('device', 'cpu')
if model_key not in AVAILABLE_MODELS:
return jsonify({'error': f'Unsupported model: {model_key}'}), 400
model_config = AVAILABLE_MODELS[model_key]
# Load tokenizer and model
tokenizer = KronosTokenizer.from_pretrained(model_config['tokenizer_id'])
model = Kronos.from_pretrained(model_config['model_id'])
# Create predictor
predictor = KronosPredictor(model, tokenizer, device=device, max_context=model_config['context_length'])
return jsonify({
'success': True,
'message': f'Model loaded successfully: {model_config["name"]} ({model_config["params"]}) on {device}',
'model_info': {
'name': model_config['name'],
'params': model_config['params'],
'context_length': model_config['context_length'],
'description': model_config['description']
}
})
except Exception as e:
return jsonify({'error': f'Model loading failed: {str(e)}'}), 500
@app.route('/api/available-models')
def get_available_models():
"""Get available model list"""
return jsonify({
'models': AVAILABLE_MODELS,
'model_available': MODEL_AVAILABLE
})
@app.route('/api/model-status')
def get_model_status():
"""Get model status"""
if MODEL_AVAILABLE:
if predictor is not None:
return jsonify({
'available': True,
'loaded': True,
'message': 'Kronos model loaded and available',
'current_model': {
'name': predictor.model.__class__.__name__,
'device': str(next(predictor.model.parameters()).device)
}
})
else:
return jsonify({
'available': True,
'loaded': False,
'message': 'Kronos model available but not loaded'
})
else:
return jsonify({
'available': False,
'loaded': False,
'message': 'Kronos model library not available, please install related dependencies'
})
if __name__ == '__main__':
print("Starting Kronos Web UI...")
print(f"Model availability: {MODEL_AVAILABLE}")
if MODEL_AVAILABLE:
print("Tip: You can load Kronos model through /api/load-model endpoint")
else:
print("Tip: Will use simulated data for demonstration")
app.run(debug=True, host='0.0.0.0', port=7070)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,7 +0,0 @@
flask==2.3.3
flask-cors==4.0.0
pandas==2.2.2
numpy==1.24.3
plotly==5.17.0
torch>=2.1.0
huggingface_hub==0.33.1

View File

@ -1,89 +0,0 @@
#!/usr/bin/env python3
"""
Kronos Web UI startup script
"""
import os
import sys
import subprocess
import webbrowser
import time
def check_dependencies():
"""Check if dependencies are installed"""
try:
import flask
import flask_cors
import pandas
import numpy
import plotly
print("✅ All dependencies installed")
return True
except ImportError as e:
print(f"❌ Missing dependency: {e}")
print("Please run: pip install -r requirements.txt")
return False
def install_dependencies():
"""Install dependencies"""
print("Installing dependencies...")
try:
subprocess.check_call([sys.executable, "-m", "pip", "install", "-r", "requirements.txt"])
print("✅ Dependencies installation completed")
return True
except subprocess.CalledProcessError:
print("❌ Dependencies installation failed")
return False
def main():
"""Main function"""
print("🚀 Starting Kronos Web UI...")
print("=" * 50)
# Check dependencies
if not check_dependencies():
print("\nAuto-install dependencies? (y/n): ", end="")
if input().lower() == 'y':
if not install_dependencies():
return
else:
print("Please manually install dependencies and retry")
return
# Check model availability
try:
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from model import Kronos, KronosTokenizer, KronosPredictor
print("✅ Kronos model library available")
model_available = True
except ImportError:
print("⚠️ Kronos model library not available, will use simulated prediction")
model_available = False
# Start Flask application
print("\n🌐 Starting Web server...")
# Set environment variables
os.environ['FLASK_APP'] = 'app.py'
os.environ['FLASK_ENV'] = 'development'
# Start server
try:
from app import app
print("✅ Web server started successfully!")
print(f"🌐 Access URL: http://localhost:7070")
print("💡 Tip: Press Ctrl+C to stop server")
# Auto-open browser
time.sleep(2)
webbrowser.open('http://localhost:7070')
# Start Flask application
app.run(debug=True, host='0.0.0.0', port=7070)
except Exception as e:
print(f"❌ Startup failed: {e}")
print("Please check if port 7070 is occupied")
if __name__ == "__main__":
main()

View File

@ -1,40 +0,0 @@
#!/bin/bash
# Kronos Web UI startup script
echo "🚀 Starting Kronos Web UI..."
echo "================================"
# Check if Python is installed
if ! command -v python3 &> /dev/null; then
echo "❌ Python3 not installed, please install Python3 first"
exit 1
fi
# Check if in correct directory
if [ ! -f "app.py" ]; then
echo "❌ Please run this script in the webui directory"
exit 1
fi
# Check dependencies
echo "📦 Checking dependencies..."
if ! python3 -c "import flask, flask_cors, pandas, numpy, plotly" &> /dev/null; then
echo "⚠️ Missing dependencies, installing..."
pip3 install -r requirements.txt
if [ $? -ne 0 ]; then
echo "❌ Dependencies installation failed"
exit 1
fi
echo "✅ Dependencies installation completed"
else
echo "✅ All dependencies installed"
fi
# Start application
echo "🌐 Starting Web server..."
echo "Access URL: http://localhost:7070"
echo "Press Ctrl+C to stop server"
echo ""
python3 app.py

File diff suppressed because it is too large Load Diff