Skip to content

Commit 8ef73e6

Browse files
authored
Spec decode: fix broken MTP functionality by using the unified MTP method name (#633)
vLLM PR (vllm-project/vllm#25232) consolidated the speculative decoding method names for MTP into a single name. We need to use "mtp" in the code instead of the model-specific names ("deepseek_mtp", "ernie_mtp", "longcat_flash_mtp"). Signed-off-by: Chen Haifeng <[email protected]>
1 parent 5c09ab7 commit 8ef73e6

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

tests/full_tests/spec_decode.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,7 @@ def test_mtp_model(is_enable, args, prompts, sampling_params, task_key, result_q
243243
enable_expert_parallel=True,
244244
speculative_config={
245245
"num_speculative_tokens": args.num_spec_tokens,
246+
"method": "mtp",
246247
},
247248
disable_log_stats=False,
248249
trust_remote_code=True,
@@ -330,7 +331,7 @@ def test_mtp_model(is_enable, args, prompts, sampling_params, task_key, result_q
330331
task_queue['baseline_eaglemtp'] = {
331332
'proc':
332333
multiprocessing.Process(target=test_eaglemtp_model,
333-
args=(False, args, prompts, sampling_params, 'baseline_mtp', result_queue))
334+
args=(False, args, prompts, sampling_params, 'baseline_eaglemtp', result_queue))
334335
}
335336
task_queue['spec_eaglemtp'] = {
336337
'proc':

vllm_gaudi/v1/spec_decode/hpu_eagle.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,8 @@ def propose(
4545
attn_metadata=common_attn_metadata,
4646
)
4747

48-
if self.method in ("deepseek_mtp", "ernie_mtp"):
48+
# All MTP related method names are now unified to "mtp"
49+
if self.method == "mtp":
4950
last_hidden_states = ret_hidden_states
5051
hidden_states = last_hidden_states
5152
else:
@@ -63,7 +64,7 @@ def propose(
6364
target_positions = target_positions.view(-1)
6465
# [batch_size]
6566
positions = target_positions[last_token_indices]
66-
if self.method in ("deepseek_mtp", "ernie_mtp", "longcat_flash_mtp"):
67+
if self.method == "mtp":
6768
hidden_states = target_hidden_states.view(-1, target_hidden_states.shape[-1])
6869
else:
6970
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

0 commit comments

Comments
 (0)