fix: add torch.cuda.empty_cache() during autoregressive inference
Without releasing cached GPU memory, usage will keep growing during autoregressive prediction, leading to significant memory increase or OOM. Calling torch.cuda.empty_cache() prevents this accumulation.
This commit is contained in:
parent
939986adb1
commit
e027051b38
@ -432,6 +432,8 @@ def auto_regressive_inference(tokenizer, model, x, x_stamp, y_stamp, max_context
|
||||
x_token[0] = torch.cat([x_token[0], sample_pre], dim=1)
|
||||
x_token[1] = torch.cat([x_token[1], sample_post], dim=1)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
input_tokens = [t[:, -max_context:].contiguous() for t in x_token]
|
||||
z = tokenizer.decode(input_tokens, half=True)
|
||||
z = z.reshape(batch_size, sample_count, z.size(1), z.size(2))
|
||||
@ -621,3 +623,4 @@ class KronosPredictor:
|
||||
pred_dfs.append(pred_df)
|
||||
|
||||
return pred_dfs
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user