Refactor dataset and backtesting logic in qlib_test.py

Refactor QlibTestDataset and QlibBacktest classes for improved structure and readability. Update inference logic and main execution flow.
This commit is contained in:
ShiYu 2025-09-10 21:36:27 +08:00 committed by GitHub
parent a4c14cb094
commit 764913b7d0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -269,6 +269,8 @@ def generate_predictions(config: dict, test_data: dict) -> dict[str, pd.DataFram
max_context=config['max_context'], pred_len=config['pred_len'], clip=config['clip'],
T=config['T'], top_k=config['top_k'], top_p=config['top_p'], sample_count=config['sample_count']
)
# you can also try this to drop the history data
# preds = preds[:, -config['pred_len']:, :]
# The 'close' price is at index 3 in `feature_list`
last_day_close = x[:, -1, 3].numpy()
@ -356,3 +358,4 @@ def main():
if __name__ == '__main__':
main()