def split_token(self, token_ids: torch.Tensor, s2_bits: int):
    """Split composite token IDs into coarse (s1) and fine (s2) components.

    Each composite ID packs a coarse token in its high bits and a fine
    token in its low ``s2_bits`` bits.

    Args:
        token_ids (torch.Tensor): Composite token IDs of shape
            [batch_size, seq_len] or [N], each in range
            [0, 2^(s1_bits + s2_bits) - 1].
        s2_bits (int): Number of low bits used for the fine token (s2).

    Returns:
        tuple: ``(s1_ids, s2_ids)`` — ``torch.long`` tensors with the same
        shape as ``token_ids``, holding the high-bit (coarse) and low-bit
        (fine) components respectively.

    Raises:
        ValueError: If ``s2_bits`` is not a positive integer.
    """
    # Raise instead of assert: asserts are stripped under `python -O`,
    # which would silently let an invalid bit width through.
    if not isinstance(s2_bits, int) or s2_bits <= 0:
        raise ValueError("s2_bits must be a positive integer")

    t = token_ids.long()
    mask = (1 << s2_bits) - 1  # e.g. s2_bits=2 -> 0b11
    s2_ids = t & mask          # low s2_bits bits -> fine token
    s1_ids = t >> s2_bits      # remaining high bits -> coarse token
    return s1_ids, s2_ids