Merge pull request #59 from pengxiao-song/dev_empty_cuda_cache
fix: add torch.cuda.empty_cache() during autoregressive inference
This commit is contained in:
commit
2d1a1ae809
@ -432,6 +432,8 @@ def auto_regressive_inference(tokenizer, model, x, x_stamp, y_stamp, max_context
|
||||
x_token[0] = torch.cat([x_token[0], sample_pre], dim=1)
|
||||
x_token[1] = torch.cat([x_token[1], sample_post], dim=1)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
input_tokens = [t[:, -max_context:].contiguous() for t in x_token]
|
||||
z = tokenizer.decode(input_tokens, half=True)
|
||||
z = z.reshape(batch_size, sample_count, z.size(1), z.size(2))
|
||||
@ -621,3 +623,4 @@ class KronosPredictor:
|
||||
pred_dfs.append(pred_df)
|
||||
|
||||
return pred_dfs
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user