Bug fix
This commit is contained in:
parent
ac69e16750
commit
87157161d4
@ -370,7 +370,8 @@ class MultiHeadAttentionWithRoPE(nn.Module):
|
|||||||
q, k, v,
|
q, k, v,
|
||||||
attn_mask=attn_mask,
|
attn_mask=attn_mask,
|
||||||
dropout_p=self.attn_dropout_p,
|
dropout_p=self.attn_dropout_p,
|
||||||
is_causal=True
|
is_causal=True,
|
||||||
|
training=self.training
|
||||||
)
|
)
|
||||||
|
|
||||||
attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
|
attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
|
||||||
@ -577,3 +578,4 @@ class TemporalEmbedding(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user