diff --git a/model/module.py b/model/module.py
index 9a7f1e9..d79a5b1 100644
--- a/model/module.py
+++ b/model/module.py
@@ -309,7 +309,7 @@ class RotaryPositionalEmbedding(nn.Module):
         return torch.cat((-x2, x1), dim=-1)
 
 
-def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor:
+def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, training=True) -> torch.Tensor:
     L, S = query.size(-2), key.size(-2)
     scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
     attn_bias = torch.zeros(L, S, dtype=query.dtype).to(query.device)
@@ -332,7 +332,7 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.
         attn_weight += attn_mask_bias
     attn_weight = torch.softmax(attn_weight, dim=-1)
-    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
+    attn_weight = torch.dropout(attn_weight, dropout_p, train=training)
     return attn_weight @ value
 
 
@@ -414,7 +414,8 @@ class MultiHeadCrossAttentionWithRoPE(nn.Module):
             q, k, v,
             attn_mask=attn_mask,
             dropout_p=self.attn_dropout_p,
-            is_causal=is_causal_flag
+            is_causal=is_causal_flag,
+            training=self.training
         )
         attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, q_len, self.d_model)
 
@@ -575,3 +576,4 @@ class TemporalEmbedding(nn.Module):
+
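
Context for reviewers: before this patch, `torch.dropout` was called with `train=True` hard-coded, so attention weights were randomly zeroed even after `model.eval()`, making evaluation outputs nondeterministic. Threading the module's `self.training` flag through restores the standard train/eval contract. The snippet below is an illustrative sketch (not part of the diff) showing the behavioral difference the `train` flag controls in `torch.dropout`.

```python
import torch

torch.manual_seed(0)
x = torch.randn(2, 4, 8)

# With train=True, dropout is active: repeated calls give different outputs,
# which is what the old hard-coded flag caused even during evaluation.
a = torch.dropout(x, p=0.5, train=True)
b = torch.dropout(x, p=0.5, train=True)
print(torch.equal(a, b))  # False (almost surely)

# With train=False, dropout is the identity: the output matches the input,
# which is the behavior eval mode now gets via training=self.training.
c = torch.dropout(x, p=0.5, train=False)
print(torch.equal(c, x))  # True
```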