Bug fix: attention dropout ignored eval mode (hard-coded train=True)
commit ac69e16750
parent f82acd69bd
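
The bug: `torch.dropout` takes an explicit `train` flag and never consults the surrounding module's mode, so the hard-coded `train=True` kept dropping (and rescaling) attention weights even after `model.eval()`. The fix threads the caller's training state through `scaled_dot_product_attention`. A minimal sketch of the failure mode, using a toy tensor (names and shapes here are illustrative, not from this repo):

import torch

attn_weight = torch.full((4, 4), 0.25)  # toy post-softmax attention weights

# Hard-coded train=True: dropout fires regardless of the model's mode,
# zeroing some weights and rescaling survivors by 1 / (1 - p).
buggy = torch.dropout(attn_weight, 0.5, train=True)

# Fixed behaviour at inference: train=False makes dropout a no-op.
fixed = torch.dropout(attn_weight, 0.5, train=False)
assert torch.equal(fixed, attn_weight)
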
@@ -309,7 +309,7 @@ class RotaryPositionalEmbedding(nn.Module):
         return torch.cat((-x2, x1), dim=-1)
 
 
-def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor:
+def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, training=True) -> torch.Tensor:
     L, S = query.size(-2), key.size(-2)
     scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
     attn_bias = torch.zeros(L, S, dtype=query.dtype).to(query.device)
@@ -332,7 +332,7 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.
         attn_weight += attn_mask_bias
 
     attn_weight = torch.softmax(attn_weight, dim=-1)
-    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
+    attn_weight = torch.dropout(attn_weight, dropout_p, train=training)
     return attn_weight @ value
 
 
@@ -414,7 +414,8 @@ class MultiHeadCrossAttentionWithRoPE(nn.Module):
             q, k, v,
             attn_mask=attn_mask,
             dropout_p=self.attn_dropout_p,
-            is_causal=is_causal_flag
+            is_causal=is_causal_flag,
+            training=self.training
         )
 
         attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, q_len, self.d_model)
@@ -575,3 +576,4 @@ class TemporalEmbedding(nn.Module):
 
 
 
+
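
For reference, the patched function assembled from the hunks above. The diff does not show the masking logic between the two hunks (the causal bias and `attn_mask` handling), so that middle section is a hedged reconstruction of typical scaled-dot-product-attention code, not the repository's exact implementation:

import math
import torch

def scaled_dot_product_attention(query, key, value, attn_mask=None,
                                 dropout_p=0.0, is_causal=False, scale=None,
                                 training=True) -> torch.Tensor:
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype).to(query.device)

    # --- reconstruction: this region is elided from the diff ---
    if is_causal:
        causal = torch.ones(L, S, dtype=torch.bool, device=query.device).tril()
        attn_bias = attn_bias.masked_fill(~causal, float("-inf"))
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight = attn_weight + attn_bias
    if attn_mask is not None:
        # Boolean masks become additive -inf biases; float masks add directly.
        attn_mask_bias = (torch.zeros_like(attn_weight).masked_fill(~attn_mask, float("-inf"))
                          if attn_mask.dtype == torch.bool else attn_mask)
        attn_weight += attn_mask_bias
    # --- end reconstruction ---

    attn_weight = torch.softmax(attn_weight, dim=-1)
    # The fix: honour the caller's training state instead of hard-coding True.
    attn_weight = torch.dropout(attn_weight, dropout_p, train=training)
    return attn_weight @ value

At the call site, `training=self.training` ties attention dropout to `nn.Module.train()` / `.eval()`, so it now switches off together with every other dropout in the model.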