diff --git a/finetune/qlib_test.py b/finetune/qlib_test.py index 43a2f0c..97aebe4 100644 --- a/finetune/qlib_test.py +++ b/finetune/qlib_test.py @@ -269,8 +269,8 @@ def generate_predictions(config: dict, test_data: dict) -> dict[str, pd.DataFram max_context=config['max_context'], pred_len=config['pred_len'], clip=config['clip'], T=config['T'], top_k=config['top_k'], top_p=config['top_p'], sample_count=config['sample_count'] ) - # you can also try this to drop the history data - # preds = preds[:, -config['pred_len']:, :] + # You can try commenting on this line to keep the history data + preds = preds[:, -config['pred_len']:, :] # The 'close' price is at index 3 in `feature_list` last_day_close = x[:, -1, 3].numpy() @@ -359,3 +359,4 @@ def main(): if __name__ == '__main__': main() +