Merge pull request #59 from pengxiao-song/dev_empty_cuda_cache
fix: add torch.cuda.empty_cache() during autoregressive inference
This commit is contained in:
commit
2d1a1ae809
@@ -432,6 +432,8 @@ def auto_regressive_inference(tokenizer, model, x, x_stamp, y_stamp, max_context
|
|||||||
x_token[0] = torch.cat([x_token[0], sample_pre], dim=1)
|
x_token[0] = torch.cat([x_token[0], sample_pre], dim=1)
|
||||||
x_token[1] = torch.cat([x_token[1], sample_post], dim=1)
|
x_token[1] = torch.cat([x_token[1], sample_post], dim=1)
|
||||||
|
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
input_tokens = [t[:, -max_context:].contiguous() for t in x_token]
|
input_tokens = [t[:, -max_context:].contiguous() for t in x_token]
|
||||||
z = tokenizer.decode(input_tokens, half=True)
|
z = tokenizer.decode(input_tokens, half=True)
|
||||||
z = z.reshape(batch_size, sample_count, z.size(1), z.size(2))
|
z = z.reshape(batch_size, sample_count, z.size(1), z.size(2))
|
||||||
@@ -621,3 +623,4 @@ class KronosPredictor:
|
|||||||
pred_dfs.append(pred_df)
|
pred_dfs.append(pred_df)
|
||||||
|
|
||||||
return pred_dfs
|
return pred_dfs
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user