Merge pull request #167 from SoYuCry/master

fix: define missing split_token in HierarchicalEmbedding
ShiYu 2025-10-26 20:30:18 +08:00 committed by GitHub
commit eeb3168f71

@@ -440,9 +440,24 @@ class HierarchicalEmbedding(nn.Module):
        nn.init.normal_(self.emb_s1.weight, mean=0, std=d_model ** -0.5)
        nn.init.normal_(self.emb_s2.weight, mean=0, std=d_model ** -0.5)

    def split_token(self, token_ids: torch.Tensor, s2_bits: int):
        """Split composite token IDs into coarse (s1) and fine (s2) IDs.

        Inputs:
            token_ids (torch.Tensor): Composite token IDs of shape [batch_size, seq_len] or [N],
                each in range [0, 2^(s1_bits + s2_bits) - 1].
            s2_bits (int): Number of low bits used for the fine token (s2).

        Returns:
            (s1_ids, s2_ids): Tensors with the same shape as token_ids.
        """
        assert isinstance(s2_bits, int) and s2_bits > 0, "s2_bits must be a positive integer"
        t = token_ids.long()
        mask = (1 << s2_bits) - 1  # all-ones mask over the low s2_bits
        s2_ids = t & mask          # low bits -> fine token IDs
        s1_ids = t >> s2_bits      # high bits -> coarse token IDs
        return s1_ids, s2_ids
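
    # Example (hypothetical values, not part of this change): with s2_bits = 4,
    # composite ID 42 = 0b101010 splits into s1_ids = 0b10 = 2 (high bits) and
    # s2_ids = 0b1010 = 10 (low bits), and (2 << 4) | 10 == 42 recombines them.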
    def forward(self, token_ids):
        """Inputs:
            token_ids:
                - tuple or list: (s1_ids, s2_ids), each of shape [batch_size, seq_len], or
                - torch.Tensor: composite token IDs of shape [batch_size, seq_len], which will
                  be split into (s1_ids, s2_ids) internally.
        Output: [batch_size, seq_len, d_model]
        """
        if isinstance(token_ids, (tuple, list)):
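
The hunk ends at the type check above. For readers who want to try the new helper outside the class, here is a minimal standalone sketch of the split-then-embed flow. The bit widths, table sizes, and the summing of the two embedding tables are illustrative assumptions, not something this diff confirms:

import torch
import torch.nn as nn

# Hypothetical sizes: 6 coarse (s1) bits and 4 fine (s2) bits.
s1_bits, s2_bits = 6, 4
d_model = 8
emb_s1 = nn.Embedding(1 << s1_bits, d_model)  # 64 coarse entries
emb_s2 = nn.Embedding(1 << s2_bits, d_model)  # 16 fine entries

token_ids = torch.tensor([[42, 255], [7, 16]])  # composite IDs, [batch, seq]
mask = (1 << s2_bits) - 1
s1_ids, s2_ids = token_ids >> s2_bits, token_ids & mask
# e.g. 42 = 0b101010 -> s1 = 2 (high bits), s2 = 10 (low bits)

out = emb_s1(s1_ids) + emb_s2(s2_ids)  # assumption: the two levels are summed
print(out.shape)  # torch.Size([2, 2, 8])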