Commit e9aa8b2

[https://nvbugs/5556020][fix] cherry-pick fix test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3 dimension mismatch (#8644)
Signed-off-by: qgai <[email protected]>
1 parent beafc39 commit e9aa8b2

1 file changed (+2 -2 lines changed)

tensorrt_llm/_torch/speculative/model_drafter.py

Lines changed: 2 additions & 2 deletions
@@ -470,9 +470,9 @@ def _update_target_inputs_with_draft_tokens(
                 continue
 
             # Get the index of the draft/target tokens in the device tensor
-            draft_idx = req_idx if self.use_static_draft_loop else request.py_batch_idx
+            draft_idx = req_idx if self.use_static_draft_loop else request.py_seq_slot
             target_idx = req_id_to_old_request[
-                request.py_request_id].py_batch_idx
+                request.py_request_id].py_seq_slot
             target_inputs.new_tokens[draft_position + 1:draft_position +
                                      draft_length + 1, target_idx,
                                      0] = draft_tensors[0:draft_length,
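
The change above swaps one per-request index for another when writing into `target_inputs.new_tokens`. The following is a minimal sketch of why that choice can matter, not the TensorRT-LLM implementation. It assumes (this is not confirmed by the commit itself) that `py_batch_idx` is a request's position in the current batch, that `py_seq_slot` is a persistent slot in the token buffer, and that the buffer's second dimension is laid out by slot. The `Request` class, `write_draft_tokens` helper, and buffer sizes are hypothetical, and plain nested lists stand in for the device tensor.

```python
from dataclasses import dataclass


@dataclass
class Request:
    """Hypothetical stand-in for a request; only the index fields matter here."""
    py_request_id: int
    py_batch_idx: int  # assumed: position within the current batch
    py_seq_slot: int   # assumed: persistent slot in the token buffer


def write_draft_tokens(new_tokens, request, draft_tokens, draft_position=0):
    """Write draft tokens into the column selected by `py_seq_slot`.

    `new_tokens` is a [step][slot] nested list standing in for a tensor
    whose second dimension is assumed to be indexed by sequence slot.
    """
    slot = request.py_seq_slot
    for offset, token in enumerate(draft_tokens):
        new_tokens[draft_position + 1 + offset][slot] = token


MAX_STEPS, MAX_SLOTS = 4, 4  # hypothetical buffer shape
new_tokens = [[0] * MAX_SLOTS for _ in range(MAX_STEPS)]

# The request is first in the batch (index 0) but lives in slot 2.
req = Request(py_request_id=7, py_batch_idx=0, py_seq_slot=2)
write_draft_tokens(new_tokens, req, [11, 12, 13])
```

Under these assumptions, indexing by `py_batch_idx` would have written into column 0, which belongs to whichever request holds slot 0, while indexing by `py_seq_slot` writes into the request's own column.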
