fix: define missing split_token in HierarchicalEmbedding
This commit is contained in:
parent
a5f5aba12d
commit
a7e294cc56
@ -440,9 +440,24 @@ class HierarchicalEmbedding(nn.Module):
|
||||
nn.init.normal_(self.emb_s1.weight, mean=0, std=d_model ** -0.5)
|
||||
nn.init.normal_(self.emb_s2.weight, mean=0, std=d_model ** -0.5)
|
||||
|
||||
def split_token(self, token_ids: torch.Tensor, s2_bits: int):
|
||||
"""Inputs:
|
||||
token_ids (torch.Tensor): Composite token IDs of shape [batch_size, seq_len] or [N], each in range [0, 2^(s1_bits + s2_bits) - 1].
|
||||
s2_bits (int): Number of low bits used for the fine token (s2).
|
||||
"""
|
||||
assert isinstance(s2_bits, int) and s2_bits > 0, "s2_bits must be a positive integer"
|
||||
|
||||
t = token_ids.long()
|
||||
mask = (1 << s2_bits) - 1
|
||||
s2_ids = t & mask # extract low bits
|
||||
s1_ids = t >> s2_bits # extract high bits
|
||||
return s1_ids, s2_ids
|
||||
|
||||
def forward(self, token_ids):
|
||||
"""Inputs:
|
||||
token_ids: [batch_size, seq_len] token ID
|
||||
token_ids:
|
||||
- tuple or list: (s1_ids, s2_ids), each of shape [batch_size, seq_len], or
|
||||
- torch.Tensor: composite token IDs of shape [batch_size, seq_len], which will be split into (s1_ids, s2_ids) internally.
|
||||
Output: [batch_size, seq_len, d_model]
|
||||
"""
|
||||
if isinstance(token_ids, tuple) or isinstance(token_ids, list):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user