diff --git a/README.md b/README.md index 9d7704a..2e399b7 100644 --- a/README.md +++ b/README.md @@ -154,6 +154,40 @@ print(pred_df.head()) The `predict` method returns a pandas DataFrame containing the forecasted values for `open`, `high`, `low`, `close`, `volume`, and `amount`, indexed by the `y_timestamp` you provided. +For efficient processing of multiple time series, Kronos provides a `predict_batch` method that enables parallel prediction on multiple datasets simultaneously. This is particularly useful when you need to forecast multiple assets or time periods at once. + +```python +# Prepare multiple datasets for batch prediction +df_list = [df1, df2, df3] # List of DataFrames +x_timestamp_list = [x_ts1, x_ts2, x_ts3] # List of historical timestamps +y_timestamp_list = [y_ts1, y_ts2, y_ts3] # List of future timestamps + +# Generate batch predictions +pred_df_list = predictor.predict_batch( + df_list=df_list, + x_timestamp_list=x_timestamp_list, + y_timestamp_list=y_timestamp_list, + pred_len=pred_len, + T=1.0, + top_p=0.9, + sample_count=1, + verbose=True +) + +# pred_df_list contains prediction results in the same order as input +for i, pred_df in enumerate(pred_df_list): + print(f"Predictions for series {i}:") + print(pred_df.head()) +``` + +**Important Requirements for Batch Prediction:** +- All series must have the same historical length (lookback window) +- All series must have the same prediction length (`pred_len`) +- Each DataFrame must contain the required columns: `['open', 'high', 'low', 'close']` +- `volume` and `amount` columns are optional: if `volume` is missing, both are filled with zeros; if only `amount` is missing, it is computed as `volume` times the mean of the price columns + +The `predict_batch` method leverages GPU parallelism for efficient processing and automatically handles normalization and denormalization for each series independently. + #### 5. Example and Visualization For a complete, runnable script that includes data loading, prediction, and plotting, please see [`examples/prediction_example.py`](examples/prediction_example.py). 
import pandas as pd
import matplotlib.pyplot as plt
import sys
sys.path.append("../")
from model import Kronos, KronosTokenizer, KronosPredictor

# Local checkpoint directories and input data; point these at your own copies.
TOKENIZER_PATH = '/home/csc/huggingface/Kronos-Tokenizer-base/'
MODEL_PATH = "/home/csc/huggingface/Kronos-base/"
DATA_PATH = "./data/XSHG_5min_600977.csv"

NUM_SERIES = 5       # number of windows forecast together in one batch
WINDOW_STRIDE = 400  # row offset between the starts of consecutive windows


def plot_prediction(kline_df, pred_df):
    """Plot predicted close price and volume against the ground truth.

    Args:
        kline_df (pd.DataFrame): Ground-truth rows covering the history plus the
            forecast horizon; must contain `close` and `volume` columns.
        pred_df (pd.DataFrame): Predictions with `close` and `volume` columns. Its
            rows are aligned with the last `len(pred_df)` rows of `kline_df`.

    Neither input frame is modified.
    """
    # Re-index a copy so the caller's prediction frame keeps its own index.
    aligned_pred = pred_df.copy()
    aligned_pred.index = kline_df.index[-aligned_pred.shape[0]:]

    # `rename` returns new Series, so the columns of the input frames keep their names.
    close_df = pd.concat(
        [kline_df['close'].rename('Ground Truth'), aligned_pred['close'].rename("Prediction")],
        axis=1,
    )
    volume_df = pd.concat(
        [kline_df['volume'].rename('Ground Truth'), aligned_pred['volume'].rename("Prediction")],
        axis=1,
    )

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 6), sharex=True)

    ax1.plot(close_df['Ground Truth'], label='Ground Truth', color='blue', linewidth=1.5)
    ax1.plot(close_df['Prediction'], label='Prediction', color='red', linewidth=1.5)
    ax1.set_ylabel('Close Price', fontsize=14)
    ax1.legend(loc='lower left', fontsize=12)
    ax1.grid(True)

    ax2.plot(volume_df['Ground Truth'], label='Ground Truth', color='blue', linewidth=1.5)
    ax2.plot(volume_df['Prediction'], label='Prediction', color='red', linewidth=1.5)
    ax2.set_ylabel('Volume', fontsize=14)
    ax2.legend(loc='upper left', fontsize=12)
    ax2.grid(True)

    plt.tight_layout()
    plt.show()


def main():
    """Run a batch forecast over several windows of one CSV and plot each result."""
    # 1. Load Model and Tokenizer
    tokenizer = KronosTokenizer.from_pretrained(TOKENIZER_PATH)
    model = Kronos.from_pretrained(MODEL_PATH)

    # 2. Instantiate Predictor
    predictor = KronosPredictor(model, tokenizer, device="cuda:0", max_context=512)

    # 3. Prepare Data
    df = pd.read_csv(DATA_PATH)
    df['timestamps'] = pd.to_datetime(df['timestamps'])

    lookback = 400
    pred_len = 120

    # `.loc` slicing silently truncates past the end of the frame, which would only
    # surface later as a length error inside predict_batch; fail early and clearly.
    required_rows = (NUM_SERIES - 1) * WINDOW_STRIDE + lookback + pred_len
    if len(df) < required_rows:
        raise ValueError(f"{DATA_PATH} has {len(df)} rows; this example needs at least {required_rows}.")

    feature_cols = ['open', 'high', 'low', 'close', 'volume', 'amount']
    dfs = []
    xtsp = []
    ytsp = []
    for i in range(NUM_SERIES):
        start = i * WINDOW_STRIDE
        hist_end = start + lookback - 1  # label-based .loc slices include both ends
        dfs.append(df.loc[start:hist_end, feature_cols])
        xtsp.append(df.loc[start:hist_end, 'timestamps'])
        ytsp.append(df.loc[hist_end + 1:hist_end + pred_len, 'timestamps'])

    # 4. Make Batch Prediction (returns one DataFrame per input series, in order)
    pred_dfs = predictor.predict_batch(
        df_list=dfs,
        x_timestamp_list=xtsp,
        y_timestamp_list=ytsp,
        pred_len=pred_len,
    )

    # 5. Inspect and Visualize each series against its ground truth
    for i, pred_df in enumerate(pred_dfs):
        start = i * WINDOW_STRIDE
        kline_df = df.loc[start:start + lookback + pred_len - 1]
        print(f"Predictions for series {i}:")
        print(pred_df.head())
        plot_prediction(kline_df, pred_df)


if __name__ == "__main__":
    main()
def predict_batch(self, df_list, x_timestamp_list, y_timestamp_list, pred_len, T=1.0, top_k=0, top_p=0.9, sample_count=1, verbose=True):
    """
    Predict several time series with a single batched call to `self.generate`.

    Each series is normalized on its own: per-column mean and standard deviation are
    computed over its history, the normalized values are clipped to
    `[-self.clip, self.clip]`, and the same statistics map that series' predictions
    back to the original scale. All series must share one historical length so they
    can be stacked into one batch.

    Args:
        df_list (List[pd.DataFrame]): Historical data, one DataFrame per series. Each must
            contain every column in `self.price_cols`. If `self.vol_col` is absent, both the
            volume and the amount column are set to 0.0; if only `self.amt_vol` is absent it
            is filled with volume times the row-wise mean of the price columns. The input
            DataFrames are not modified (a copy is used).
        x_timestamp_list (List): Historical timestamps, one entry per series, in a form
            accepted by `calc_time_stamps`; must yield one row per DataFrame row.
        y_timestamp_list (List): Future timestamps, one entry per series, in a form accepted
            by `calc_time_stamps`; must yield exactly `pred_len` rows. Each entry also
            becomes the index of the corresponding result.
        pred_len (int): Number of prediction steps.
        T (float): Sampling temperature, forwarded to `self.generate`.
        top_k (int): Top-k filtering threshold, forwarded to `self.generate`.
        top_p (float): Top-p (nucleus sampling) threshold, forwarded to `self.generate`.
        sample_count (int): Number of samples per series, forwarded to `self.generate`
            (which is expected to aggregate them; the aggregation is not done here).
        verbose (bool): Progress-display flag, forwarded to `self.generate`.

    Returns:
        List[pd.DataFrame]: One DataFrame per input series, in input order, with columns
        `self.price_cols + [self.vol_col, self.amt_vol]`.

    Raises:
        ValueError: If the three inputs are not lists/tuples of equal length, if the batch
            is empty, if an element is not a DataFrame or lacks a price column, if the
            feature columns contain NaN, if timestamps and data lengths disagree, or if the
            series do not all have the same historical length.
    """
    # Basic validation
    if not all(isinstance(seq, (list, tuple)) for seq in (df_list, x_timestamp_list, y_timestamp_list)):
        raise ValueError("df_list, x_timestamp_list, y_timestamp_list must be list or tuple types.")
    if not (len(df_list) == len(x_timestamp_list) == len(y_timestamp_list)):
        raise ValueError("df_list, x_timestamp_list, y_timestamp_list must have consistent lengths.")
    if not df_list:
        # Without this guard an empty batch falls through to the length-consistency
        # check below and reports a misleading "got: []" message.
        raise ValueError("df_list is empty; predict_batch requires at least one series.")

    feature_cols = self.price_cols + [self.vol_col, self.amt_vol]
    eps = 1e-5  # keeps the division finite for constant columns; reused when denormalizing

    x_list = []
    x_stamp_list = []
    y_stamp_list = []
    means = []
    stds = []

    for i, (df, x_timestamp, y_timestamp) in enumerate(zip(df_list, x_timestamp_list, y_timestamp_list)):
        if not isinstance(df, pd.DataFrame):
            raise ValueError(f"Input at index {i} is not a pandas DataFrame.")
        if not all(col in df.columns for col in self.price_cols):
            raise ValueError(f"DataFrame at index {i} is missing price columns {self.price_cols}.")

        df = df.copy()  # never write the filled-in columns back into the caller's frame
        if self.vol_col not in df.columns:
            df[self.vol_col] = 0.0
            df[self.amt_vol] = 0.0
        elif self.amt_vol not in df.columns:
            df[self.amt_vol] = df[self.vol_col] * df[self.price_cols].mean(axis=1)

        if df[feature_cols].isnull().values.any():
            raise ValueError(f"DataFrame at index {i} contains NaN values in price or volume columns.")

        x_time_df = calc_time_stamps(x_timestamp)
        y_time_df = calc_time_stamps(y_timestamp)

        x = df[feature_cols].values.astype(np.float32)
        x_stamp = x_time_df.values.astype(np.float32)
        y_stamp = y_time_df.values.astype(np.float32)

        if x.shape[0] != x_stamp.shape[0]:
            raise ValueError(f"Inconsistent lengths at index {i}: x has {x.shape[0]} vs x_stamp has {x_stamp.shape[0]}.")
        if y_stamp.shape[0] != pred_len:
            raise ValueError(f"y_timestamp length at index {i} should equal pred_len={pred_len}, got {y_stamp.shape[0]}.")

        # Per-series statistics: each series is scaled independently of the others.
        x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0)
        x_norm = np.clip((x - x_mean) / (x_std + eps), -self.clip, self.clip)

        x_list.append(x_norm)
        x_stamp_list.append(x_stamp)
        y_stamp_list.append(y_stamp)
        means.append(x_mean)
        stds.append(x_std)

    # Stacking needs one common history length. Prediction lengths are already pinned to
    # pred_len by the per-series check above, so no separate check is needed for them.
    seq_lens = [x_norm.shape[0] for x_norm in x_list]
    if len(set(seq_lens)) != 1:
        raise ValueError(f"Parallel prediction requires all series to have consistent historical lengths, got: {seq_lens}")

    x_batch = np.stack(x_list, axis=0).astype(np.float32)  # (B, seq_len, feat)
    x_stamp_batch = np.stack(x_stamp_list, axis=0).astype(np.float32)  # (B, seq_len, time_feat)
    y_stamp_batch = np.stack(y_stamp_list, axis=0).astype(np.float32)  # (B, pred_len, time_feat)

    preds = self.generate(x_batch, x_stamp_batch, y_stamp_batch, pred_len, T, top_k, top_p, sample_count, verbose)
    # preds is indexed per series below; expected layout is (B, pred_len, feat)

    # Undo each series' own normalization and attach its future timestamps as the index.
    return [
        pd.DataFrame(preds[i] * (stds[i] + eps) + means[i], columns=feature_cols, index=y_timestamp)
        for i, y_timestamp in enumerate(y_timestamp_list)
    ]