diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py
index 6fd04f50918..625e588a132 100644
--- a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py
+++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py
@@ -41,6 +41,7 @@ def fused_mlp_moe_kernel(
     topk_weights_ptr,
     sorted_token_ids_ptr,
     expert_ids_ptr,
+    num_tokens_post_padded_ptr,
     # Matrix dimensions
     N,
     K,
@@ -84,6 +85,10 @@ def fused_mlp_moe_kernel(
     pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
     pid_n = (pid % num_pid_in_group) // group_size_m
 
+    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
+    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
+        return
+
     offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
     # Bounds check: EM might not be a multiple of BLOCK_SIZE_M
     # so offs_token_id can exceed EM-1. Load with mask to avoid out-of-bounds.
@@ -270,6 +275,7 @@ def _grid(META):
         topk_weights if topk_weights is not None else C,
         sorted_token_ids,
         expert_ids,
+        num_tokens_post_padded,
         B.size(1),
         B.size(2),
         EM,
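
Note on the change: the launch grid for `fused_mlp_moe_kernel` is sized against `EM`, the worst-case padded token count, while the actual number of valid padded tokens is only known at runtime. Loading `num_tokens_post_padded` and returning early lets blocks beyond the real workload exit immediately instead of computing on padding. Below is a minimal host-side sketch of how such a padded count is typically produced before the launch; the helper name `align_block_size` and its signature are illustrative assumptions, not the actual code in `triton_moe.py`.

```python
# Sketch (assumption): host-side computation of num_tokens_post_padded.
# Each expert's token count is rounded up to a multiple of BLOCK_SIZE_M so
# every block in the sorted-token layout works on a single expert; the sum
# of the rounded counts is the value the kernel guard compares against.
import torch


def align_block_size(topk_ids: torch.Tensor, num_experts: int, block_size: int) -> torch.Tensor:
    """Return the total token count after per-expert padding to block_size."""
    # Count how many (token, expert) assignments each expert receives.
    counts = torch.bincount(topk_ids.flatten(), minlength=num_experts)
    # Round each expert's count up to the next multiple of block_size.
    padded = ((counts + block_size - 1) // block_size) * block_size
    return padded.sum()


# Example: 3 tokens with top-2 routing over 2 experts, BLOCK_SIZE_M = 4.
# Each expert gets 3 assignments, padded to 4, so the total is 8; any block
# with pid_m * BLOCK_SIZE_M >= 8 returns immediately in the kernel.
topk_ids = torch.tensor([[0, 1], [0, 0], [1, 1]])
print(align_block_size(topk_ids, num_experts=2, block_size=4))  # tensor(8)
```

Because the value is passed as a pointer (`num_tokens_post_padded_ptr`) rather than a scalar, the guard stays correct under CUDA graph capture, where the launch grid is fixed but the padded count can change between replays.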