Refactor qlib_test.py

This commit is contained in:
ShiYu 2025-09-10 21:37:57 +08:00 committed by GitHub
parent 764913b7d0
commit f82acd69bd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -269,8 +269,8 @@ def generate_predictions(config: dict, test_data: dict) -> dict[str, pd.DataFram
max_context=config['max_context'], pred_len=config['pred_len'], clip=config['clip'], max_context=config['max_context'], pred_len=config['pred_len'], clip=config['clip'],
T=config['T'], top_k=config['top_k'], top_p=config['top_p'], sample_count=config['sample_count'] T=config['T'], top_k=config['top_k'], top_p=config['top_p'], sample_count=config['sample_count']
) )
# you can also try this to drop the history data # You can try commenting on this line to keep the history data
# preds = preds[:, -config['pred_len']:, :] preds = preds[:, -config['pred_len']:, :]
# The 'close' price is at index 3 in `feature_list` # The 'close' price is at index 3 in `feature_list`
last_day_close = x[:, -1, 3].numpy() last_day_close = x[:, -1, 3].numpy()
@ -359,3 +359,4 @@ def main():
if __name__ == '__main__': if __name__ == '__main__':
main() main()