diff --git a/angelslim/compressor/speculative/train/data/data_utils.py b/angelslim/compressor/speculative/train/data/data_utils.py index 76d1c31e..a452d5b3 100644 --- a/angelslim/compressor/speculative/train/data/data_utils.py +++ b/angelslim/compressor/speculative/train/data/data_utils.py @@ -133,10 +133,21 @@ def process_token_dict_to_mappings( used_tokens = [key for key, freq in top_N] used_tokens.sort() - d2t = [used_tokens[i] - i for i in range(len(used_tokens))] - t2d = [i in used_tokens for i in range(target_vocab_size)] - d2t = torch.tensor(d2t) - t2d = torch.tensor(t2d) + used_set = set(used_tokens) + d2t = torch.tensor( + [used_tokens[i] - i for i in range(len(used_tokens))], + dtype=torch.int64, # must match register_buffer dtype in Eagle3LlamaForCausalLM + ) + t2d = torch.tensor( + [i in used_set for i in range(target_vocab_size)], + dtype=torch.bool, # must match register_buffer dtype in Eagle3LlamaForCausalLM + ) + + assert d2t.shape == (draft_vocab_size,), f"d2t shape {d2t.shape} != ({draft_vocab_size},)" + assert t2d.shape == (target_vocab_size,), f"t2d shape {t2d.shape} != ({target_vocab_size},)" + assert ( + t2d.sum().item() == draft_vocab_size + ), f"t2d has {t2d.sum().item()} True entries, expected {draft_vocab_size}" return d2t, t2d