commit ac69e16750
parent f82acd69bd
Author: ShiYu (committed via GitHub)
Date:   2025-09-16 10:32:03 +08:00


@@ -309,7 +309,7 @@ class RotaryPositionalEmbedding(nn.Module):
         return torch.cat((-x2, x1), dim=-1)
 
-def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor:
+def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, training=True) -> torch.Tensor:
     L, S = query.size(-2), key.size(-2)
     scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
     attn_bias = torch.zeros(L, S, dtype=query.dtype).to(query.device)
@@ -332,7 +332,7 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.
         attn_weight += attn_mask_bias
     attn_weight = torch.softmax(attn_weight, dim=-1)
-    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
+    attn_weight = torch.dropout(attn_weight, dropout_p, train=training)
     return attn_weight @ value
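
The core of the fix is in this hunk: torch.dropout with a hard-coded train=True applies dropout unconditionally, even when the model is in eval mode, making inference outputs non-deterministic. Threading a training flag through lets the caller gate it. Below is a minimal runnable sketch of the patched function under simplifying assumptions (attn_mask is treated as an additive float bias; the full file constructs attn_mask_bias separately):

import math
import torch

def scaled_dot_product_attention(query, key, value, attn_mask=None,
                                 dropout_p=0.0, is_causal=False, scale=None,
                                 training=True) -> torch.Tensor:
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        # Block attention to future positions with -inf before the softmax.
        causal = torch.ones(L, S, dtype=torch.bool, device=query.device).tril()
        attn_bias = attn_bias.masked_fill(~causal, float("-inf"))
    if attn_mask is not None:
        attn_bias = attn_bias + attn_mask  # assumption: additive float mask
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight = attn_weight + attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    # With train=False, torch.dropout is the identity, so eval-mode
    # attention is deterministic; train=True reproduces the old behavior.
    attn_weight = torch.dropout(attn_weight, dropout_p, train=training)
    return attn_weight @ value

q = k = v = torch.randn(1, 2, 4, 8)
out1 = scaled_dot_product_attention(q, k, v, dropout_p=0.5, training=False)
out2 = scaled_dot_product_attention(q, k, v, dropout_p=0.5, training=False)
assert torch.equal(out1, out2)  # no dropout noise once training=False
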
@@ -414,7 +414,8 @@ class MultiHeadCrossAttentionWithRoPE(nn.Module):
             q, k, v,
             attn_mask=attn_mask,
             dropout_p=self.attn_dropout_p,
-            is_causal=is_causal_flag
+            is_causal=is_causal_flag,
+            training=self.training
         )
         attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, q_len, self.d_model)
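
At the call site, the module forwards its own nn.Module training flag, so calling .train() or .eval() on the parent network now controls attention dropout as well. A hypothetical minimal module showing the same pattern (TinyAttention is illustrative, not from this repo, and reuses the sketch above):

import torch
import torch.nn as nn

class TinyAttention(nn.Module):
    # Illustrative stand-in for MultiHeadCrossAttentionWithRoPE's call site.
    def __init__(self, attn_dropout_p=0.1):
        super().__init__()
        self.attn_dropout_p = attn_dropout_p

    def forward(self, q, k, v):
        # self.training is flipped by .train()/.eval(), so dropout follows
        # the module's mode instead of firing unconditionally.
        return scaled_dot_product_attention(
            q, k, v,
            dropout_p=self.attn_dropout_p,
            training=self.training,
        )

m = TinyAttention().eval()  # .eval() sets m.training = False
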
@@ -575,3 +576,4 @@ class TemporalEmbedding(nn.Module):