
Commit 8282d6c

[fix] Fix llama4 min latency (#5117)

Signed-off-by: Jin Li <[email protected]>

1 parent: 56abae0

File tree

2 files changed: +2 −2 lines

tensorrt_llm/_torch/models/modeling_llama.py (1 addition, 1 deletion)

@@ -301,7 +301,7 @@ def compute_routed_output(self, hidden_states, all_rank_num_tokens,
         routed_output = self.experts(
             hidden_states,
             router_logits,
-            cutlass_min_latency_mode=cutlass_min_latency_mode,
+            do_finalize=not cutlass_min_latency_mode,
             all_rank_num_tokens=all_rank_num_tokens,
             use_dp_padding=use_dp_padding,
         )
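The hunk above replaces the old `cutlass_min_latency_mode` keyword with the new `do_finalize` keyword, inverting the boolean: min-latency mode means the expert output is *not* finalized. A minimal sketch of that mapping (the `call_experts_compat` wrapper is hypothetical, not part of TensorRT-LLM):

```python
def call_experts_compat(experts, hidden_states, router_logits,
                        cutlass_min_latency_mode=False, **kwargs):
    """Hypothetical adapter from the old flag to the new API.

    The old `cutlass_min_latency_mode=True` meant "skip finalization",
    so it maps to the logical inverse of the new `do_finalize` keyword.
    """
    return experts(hidden_states, router_logits,
                   do_finalize=not cutlass_min_latency_mode,
                   **kwargs)
```

With this mapping, a caller that previously passed `cutlass_min_latency_mode=True` now produces `do_finalize=False`, preserving the original behavior.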

tensorrt_llm/_torch/models/modeling_llama_min_latency.py (1 addition, 1 deletion)

@@ -515,7 +515,7 @@ def forward(

         return super().forward(x,
                                router_logits,
-                               cutlass_min_latency_mode=False,
+                               do_finalize=True,
                                output_dtype=output_dtype)
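In this second hunk, the fallback path that previously passed `cutlass_min_latency_mode=False` now passes `do_finalize=True`, i.e. it always requests finalized output from the parent forward. A hedged sketch of the pattern (the class names here are stand-ins, not the actual TensorRT-LLM classes):

```python
class FusedMoE:
    # Hypothetical stand-in for the parent MoE layer's forward.
    def forward(self, x, router_logits, do_finalize=True, output_dtype=None):
        mode = "finalized" if do_finalize else "raw"
        return f"routed({mode})"


class MinLatencyFusedMoE(FusedMoE):
    def forward(self, x, router_logits, output_dtype=None):
        # Mirrors the patched call: the fallback always finalizes,
        # replacing the old `cutlass_min_latency_mode=False` keyword.
        return super().forward(x,
                               router_logits,
                               do_finalize=True,
                               output_dtype=output_dtype)
```

The two hunks are consistent: `cutlass_min_latency_mode=False` and `do_finalize=True` express the same intent under the inverted flag.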
Comments (0)