diff --git a/model/kronos.py b/model/kronos.py index b22ee8c..e3fb417 100644 --- a/model/kronos.py +++ b/model/kronos.py @@ -432,6 +432,8 @@ def auto_regressive_inference(tokenizer, model, x, x_stamp, y_stamp, max_context x_token[0] = torch.cat([x_token[0], sample_pre], dim=1) x_token[1] = torch.cat([x_token[1], sample_post], dim=1) + torch.cuda.empty_cache() + input_tokens = [t[:, -max_context:].contiguous() for t in x_token] z = tokenizer.decode(input_tokens, half=True) z = z.reshape(batch_size, sample_count, z.size(1), z.size(2)) @@ -621,3 +623,4 @@ class KronosPredictor: pred_dfs.append(pred_df) return pred_dfs +