fix: define missing split_token in HierarchicalEmbedding
This commit is contained in:
parent
a5f5aba12d
commit
a7e294cc56
@ -440,9 +440,24 @@ class HierarchicalEmbedding(nn.Module):
|
|||||||
nn.init.normal_(self.emb_s1.weight, mean=0, std=d_model ** -0.5)
|
nn.init.normal_(self.emb_s1.weight, mean=0, std=d_model ** -0.5)
|
||||||
nn.init.normal_(self.emb_s2.weight, mean=0, std=d_model ** -0.5)
|
nn.init.normal_(self.emb_s2.weight, mean=0, std=d_model ** -0.5)
|
||||||
|
|
||||||
|
def split_token(self, token_ids: torch.Tensor, s2_bits: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Split composite token IDs into coarse (s1) and fine (s2) token IDs.

    Each composite ID packs the coarse ID in the high bits and the fine ID
    in the low ``s2_bits`` bits.

    Args:
        token_ids (torch.Tensor): Composite token IDs of shape
            [batch_size, seq_len] or [N], each in range
            [0, 2^(s1_bits + s2_bits) - 1].
        s2_bits (int): Number of low bits used for the fine token (s2).

    Returns:
        tuple[torch.Tensor, torch.Tensor]: ``(s1_ids, s2_ids)`` with the
        same shape as ``token_ids``, dtype ``torch.long``.

    Raises:
        ValueError: If ``s2_bits`` is not a positive integer.
    """
    # Explicit raise instead of `assert`: asserts are stripped under
    # `python -O`, which would silently disable this validation.
    if not isinstance(s2_bits, int) or s2_bits <= 0:
        raise ValueError("s2_bits must be a positive integer")

    # Work in int64 so the bit shifts/masks are well-defined regardless of
    # the incoming integer dtype.
    t = token_ids.long()
    mask = (1 << s2_bits) - 1
    s2_ids = t & mask        # extract low bits (fine token)
    s1_ids = t >> s2_bits    # extract high bits (coarse token)
    return s1_ids, s2_ids
||||||
def forward(self, token_ids):
|
def forward(self, token_ids):
|
||||||
"""Inputs:
|
"""Inputs:
|
||||||
token_ids: [batch_size, seq_len] token ID
|
token_ids:
|
||||||
|
- tuple or list: (s1_ids, s2_ids), each of shape [batch_size, seq_len], or
|
||||||
|
- torch.Tensor: composite token IDs of shape [batch_size, seq_len], which will be split into (s1_ids, s2_ids) internally.
|
||||||
Output: [batch_size, seq_len, d_model]
|
Output: [batch_size, seq_len, d_model]
|
||||||
"""
|
"""
|
||||||
if isinstance(token_ids, tuple) or isinstance(token_ids, list):
|
if isinstance(token_ids, tuple) or isinstance(token_ids, list):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user