diff --git a/model/module.py b/model/module.py index d79a5b1..72c0619 100644 --- a/model/module.py +++ b/model/module.py @@ -370,7 +370,8 @@ class MultiHeadAttentionWithRoPE(nn.Module): q, k, v, attn_mask=attn_mask, dropout_p=self.attn_dropout_p, - is_causal=True + is_causal=True, + training=self.training ) attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model) @@ -577,3 +578,4 @@ class TemporalEmbedding(nn.Module): +