# Method of ``KronosPredictor`` (model/kronos.py). It is written at top level here
# because the source hunk carried no usable indentation; indent it one level when
# placing it inside the class body.
def predict_batch(self, df_list, x_timestamp_list, y_timestamp_list, pred_len,
                  T=1.0, top_k=0, top_p=0.9, sample_count=1, verbose=True):
    """Predict several time series with a single batched ``self.generate`` call.

    Every series is normalised independently (per-column z-score using its own
    mean/std, then clipped to ``[-self.clip, self.clip]``), the normalised
    series are stacked into one batch, and each prediction is de-normalised
    with the statistics of the series it belongs to.

    All series must share the same number of historical rows, and every
    ``y_timestamp`` must contain exactly ``pred_len`` entries.

    Args:
        df_list (list | tuple of pd.DataFrame): One frame per series. Each must
            contain every column in ``self.price_cols``. If ``self.vol_col`` is
            absent, both the volume and the amount column are filled with
            zeros; if only ``self.amt_vol`` is absent it is derived as
            ``volume * mean(price columns)``. The caller's frames are not
            modified (a copy is taken).
        x_timestamp_list (list | tuple): Historical timestamps, one entry per
            series, handed to ``calc_time_stamps``. Each entry must yield as
            many rows as its DataFrame.
            NOTE(review): the accepted container type (Series vs.
            DatetimeIndex) depends on ``calc_time_stamps`` - confirm there.
        y_timestamp_list (list | tuple): Future timestamps, one entry per
            series, each of length ``pred_len``. Also used as the index of the
            returned frames.
        pred_len (int): Number of steps to predict.
        T (float): Sampling temperature, forwarded to ``self.generate``.
        top_k (int): Top-k filtering threshold, forwarded to ``self.generate``.
        top_p (float): Nucleus-sampling threshold, forwarded to
            ``self.generate``.
        sample_count (int): Number of samples per series, forwarded to
            ``self.generate`` (any aggregation of samples happens there).
        verbose (bool): Forwarded to ``self.generate``; presumably toggles
            progress output.

    Returns:
        list of pd.DataFrame: Predictions in input order. Each frame has the
        columns ``self.price_cols + [self.vol_col, self.amt_vol]`` and is
        indexed by the matching entry of ``y_timestamp_list``.

    Raises:
        ValueError: If the three inputs are not lists/tuples of equal, non-zero
            length; if an element is not a DataFrame, lacks a price column, has
            no rows, or contains NaN in a feature column; if timestamp lengths
            disagree with the data or with ``pred_len``; or if the series do
            not all share one historical length.
    """
    # --- container-level validation -------------------------------------
    containers = (df_list, x_timestamp_list, y_timestamp_list)
    if not all(isinstance(seq, (list, tuple)) for seq in containers):
        raise ValueError("df_list, x_timestamp_list, y_timestamp_list must be list or tuple types.")
    if not (len(df_list) == len(x_timestamp_list) == len(y_timestamp_list)):
        raise ValueError("df_list, x_timestamp_list, y_timestamp_list must have consistent lengths.")
    if not df_list:
        # Without this guard an empty batch falls through to the length-consistency
        # check below and is reported as "inconsistent historical lengths".
        raise ValueError("df_list must contain at least one series.")

    feature_cols = self.price_cols + [self.vol_col, self.amt_vol]

    x_list = []
    x_stamp_list = []
    y_stamp_list = []
    means = []
    stds = []

    # --- per-series validation and normalisation ------------------------
    for i, (df, x_timestamp, y_timestamp) in enumerate(zip(df_list, x_timestamp_list, y_timestamp_list)):
        if not isinstance(df, pd.DataFrame):
            raise ValueError(f"Input at index {i} is not a pandas DataFrame.")
        if not all(col in df.columns for col in self.price_cols):
            raise ValueError(f"DataFrame at index {i} is missing price columns {self.price_cols}.")
        if len(df) == 0:
            # mean/std of zero rows would be NaN and poison the whole batch.
            raise ValueError(f"DataFrame at index {i} is empty.")

        df = df.copy()  # never mutate the caller's frame
        if self.vol_col not in df.columns:
            df[self.vol_col] = 0.0
            df[self.amt_vol] = 0.0
        if self.amt_vol not in df.columns:
            # Volume is guaranteed to exist at this point; approximate the amount
            # as volume times the row-wise mean of the price columns.
            df[self.amt_vol] = df[self.vol_col] * df[self.price_cols].mean(axis=1)

        if df[feature_cols].isnull().values.any():
            raise ValueError(f"DataFrame at index {i} contains NaN values in price or volume columns.")

        x_stamp = calc_time_stamps(x_timestamp).values.astype(np.float32)
        y_stamp = calc_time_stamps(y_timestamp).values.astype(np.float32)
        x = df[feature_cols].values.astype(np.float32)

        if x.shape[0] != x_stamp.shape[0]:
            raise ValueError(f"Inconsistent lengths at index {i}: x has {x.shape[0]} vs x_stamp has {x_stamp.shape[0]}.")
        if y_stamp.shape[0] != pred_len:
            raise ValueError(f"y_timestamp length at index {i} should equal pred_len={pred_len}, got {y_stamp.shape[0]}.")

        # Per-series statistics are kept so each prediction can be mapped back
        # to its own scale; the epsilon guards against constant columns.
        x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0)
        x_norm = np.clip((x - x_mean) / (x_std + 1e-5), -self.clip, self.clip)

        x_list.append(x_norm)
        x_stamp_list.append(x_stamp)
        y_stamp_list.append(y_stamp)
        means.append(x_mean)
        stds.append(x_std)

    # Stacking needs a rectangular batch. Prediction lengths are already
    # pinned to pred_len above, so only the history length can still differ.
    seq_lens = [x_norm.shape[0] for x_norm in x_list]
    if len(set(seq_lens)) != 1:
        raise ValueError(f"Parallel prediction requires all series to have consistent historical lengths, got: {seq_lens}")

    x_batch = np.stack(x_list, axis=0).astype(np.float32)  # (B, seq_len, feat)
    x_stamp_batch = np.stack(x_stamp_list, axis=0).astype(np.float32)  # (B, seq_len, time_feat)
    y_stamp_batch = np.stack(y_stamp_list, axis=0).astype(np.float32)  # (B, pred_len, time_feat)

    # Indexed per series below, i.e. treated as (B, pred_len, feat).
    preds = self.generate(x_batch, x_stamp_batch, y_stamp_batch, pred_len, T, top_k, top_p, sample_count, verbose)

    # --- de-normalise each series with its own statistics ---------------
    return [
        pd.DataFrame(preds[i] * (std + 1e-5) + mean, columns=feature_cols, index=y_timestamp)
        for i, (mean, std, y_timestamp) in enumerate(zip(means, stds, y_timestamp_list))
    ]