Bug fix
This commit is contained in:
parent
ac69e16750
commit
87157161d4
@ -370,7 +370,8 @@ class MultiHeadAttentionWithRoPE(nn.Module):
|
||||
q, k, v,
|
||||
attn_mask=attn_mask,
|
||||
dropout_p=self.attn_dropout_p,
|
||||
-is_causal=True
|
||||
+is_causal=True,
|
||||
+training=self.training
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
|
||||
@ -577,3 +578,4 @@ class TemporalEmbedding(nn.Module):
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user