
Conversation

@nv-yunzheq
Contributor

@nv-yunzheq nv-yunzheq commented Oct 31, 2025

πŸ“Œ Description

πŸ” Related Issues

πŸš€ Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

βœ… Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

πŸ§ͺ Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • SM90 scatter-based epilogue and broader SM100/SM120 MOE/GEMM coverage; new public enum for GEMM stages and explicit runner instantiations.
  • Improvements

    • New runtime controls and parameters exposed: dynamic CGA, swap-AB, swizzled-input SF, unpadded hidden-size, and per-GEMM-stage tactic counts; expanded tile/cluster shape options, finalize-epilogue fusion and fusion/swap-aware dispatch; increased runtime debug logging and profiling.
  • Bug Fixes

    • License/namespace/header cleanups, suppressed compiler warnings, tightened assertions.
  • Tests

    • MXFP8Γ—MXFP4 test now permits SM120 devices.

@coderabbitai
Contributor

coderabbitai bot commented Oct 31, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Threads swizzled_input_sf, unpadded_hidden_size, router_scales, permuted_row_to_unpermuted_row, swap_ab and finalize-fusion flags through MOE/CUTLASS flows; adds SM90 scatter epilogue visitor; extends tile/cluster enums and SM100/SM120 candidate generation; renames many kernel namespaces to cutlass_kernels_oss; adds explicit template instantiations and launcher/signature updates.

Changes

Cohort / File(s) Summary
Fused MOE instantiations
csrc/fused_moe/cutlass_backend/cutlass_fused_moe_instantiation.cu
Added explicit template instantiation for CutlassMoeFCRunner with FP8/uint4/BF16/FP8 combo and INSTANTIATE_FINALIZE_MOE_ROUTING(...) instantiations (half, float, conditional BF16).
Fused MOE kernels
csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh
Threaded swizzled_input_sf, padded_cols/unpadded_cols, router_scales, permuted_row_to_unpermuted_row; updated writeSF, strides, finalize kernels, expandInputRows, fusion path selection, profiler/workspace flows.
Bindings & Runner
csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu, flashinfer/fused_moe/core.py
FusedMoeRunner aggregates gemm1/gemm2 tactics, exposes tactic counts/getters, wires unpadded_hidden_size and swizzled_input_sf through profiling/init, and flashinfer tuner threads gemm_idx_for_tuning for stage-specific tactic selection.
MOE public interfaces
csrc/nv_internal/.../include/moe_kernels.h, .../include/moe_gemm_kernels.h
Added MoeGemmId; expanded getTactics/getConfigs/signatures to accept gemm id/sm/fusion flags; runMoe/gemm2/computeStrides dispatch signatures extended with swizzled_input_sf, unpadded_hidden_size, router_scales, permutation pointers; added use_fused_finalize_ and profiler workspace arrays.
TMA warp inputs & workspace
.../moe_gemm/moe_tma_warp_specialized_input.cu, .../moe_gemm/moe_tma_warp_specialized_traits.h
Workspace buffers increased (17β†’20); A/B renamed to Act/Weight pointers/strides; setFinalizeFusionParams signature changed; SM120/FP4/FP8 specialization checks reworked.
SM90 epilogue visitor
csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/.../sm90_visitor_scatter.hpp
New SM90 scatter pointer-array epilogue visitor, reduction helpers, ScaledAccPerRow/PerColBias types, pointer-array scatter store fusion callbacks and fused-bias/scale/reduction variants.
Gemm config & heuristics
.../cutlass_extensions/gemm_configs.h, .../cutlass_kernels/cutlass_heuristic.{h,cpp}
Added shape_tuple_to_enum/enum_to_shape_tuple, new TileShape/ClusterShape enums, EpilogueFusionType, dynamic/fallback cluster shapes, swap_ab field; added SM100/SM120 candidate generation and DYNAMIC_CGA-aware filtering.
Cutlass OSS namespace moves
many files under csrc/nv_internal/.../fpA_intB_gemm/*, .../moe_gemm/*, flashinfer/jit/gemm/cutlass/*
Public namespace renamed to tensorrt_llm::kernels::cutlass_kernels_oss; updated opening/closing comments, re-exports, generated code, and call sites.
Dispatch & launchers
moe_gemm/moe_gemm_template_dispatch*.h, moe_gemm_tma_ws_launcher.h, moe_gemm_tma_ws_mixed_input_launcher.*
Introduced dispatchMoeGemmFinalDispatchTmaWarpSpecialized and getDispatchFunctionForSM100; dispatches now accept CutlassGemmConfig, dynamic/fallback cluster shapes; template parameter lists extended (EpilogueSchedule, DYNAMIC_CGA, SwapAB).
Gather utils & cuda utils
.../gather_tensor.hpp, csrc/nv_internal/include/tensorrt_llm/common/cudaUtils.h
Moved IndexedGather/CustomStride into cutlass::util namespace; qualified cute:: types; added template<bool VALUE> using ConstBool = ConstExprWrapper<bool, VALUE>;.
MOE launchers & signatures
.../moe_gemm/launchers/*
Multiple launcher signatures extended (biases, bias_is_broadcast, C output, padded/unpadded cols, gemm dims, num_experts, workspace_size, occupancy), and namespace rename to OSS variants.
Template instantiations & licenses
many moe_gemm_kernels_*.cu
Standardized many headers to Apache‑2.0, simplified includes to "moe_gemm_template_dispatch.h", added many explicit MoeGemmRunner template instantiations (various FP/BF/uint combos) and adjusted namespace boundaries.
Codegen & tests
flashinfer/jit/gemm/cutlass/generate_kernels.py, tests/moe/test_trtllm_cutlass_fused_moe.py
Generator adds dynamic_cga and swap_ab flags and emits OSS namespace content including SM103/SM120 variants; test skip list expanded to include SM120 for MXFP8/MXFP4 test.

Sequence Diagram(s)

sequenceDiagram
  participant App
  participant Runner as CutlassMoeFCRunner
  participant Heuristic
  participant Profiler
  participant Dispatcher
  Note over App,Runner: runMoe(..., swizzled_input_sf, unpadded_hidden_size, router_scales, permuted_row_to_unpermuted_row, swap_ab)
  App->>Runner: runMoe(...)
  Runner->>Heuristic: getTactics(gemm_id, sm, supports_finalize_fusion)
  Heuristic-->>Runner: candidate CutlassGemmConfig (may include FINALIZE, swap_ab, dynamic cluster shapes)
  Runner->>Profiler: profile/select (uses unpadded_hidden_size, stage-specific tactic counts)
  Profiler-->>Runner: selected gemm_config
  Runner->>Dispatcher: dispatch(gemm_config, router_scales, permuted_row_to_unpermuted_row, swizzled_input_sf, swap_ab)
  Dispatcher-->>Runner: launches kernel (TMA warp specialized / finalize fused / scatter epilogue)
  Runner-->>App: results

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Areas to focus during review:

  • Namespace rename consistency and re-exports across headers/implementations.
  • Correct propagation and argument ordering of new parameters (swizzled_input_sf, unpadded_hidden_size, router_scales, permuted_row_to_unpermuted_row, swap_ab).
  • FINALIZE epilogue filtering, workspace sizing, SMEM/no-SMEM compatibility.
  • SM100/SM120 candidate generation and DYNAMIC_CGA effects on filtering/expansion.
  • SM90 scatter epilogue correctness (reduction ops, pointer-array scatter, FusionCallbacks).
  • Workspace buffer index/size changes and renamed Act/Weight pointer usages.
  • New explicit instantiations and JIT symbol finalization macros.

Possibly related PRs

Suggested reviewers

  • joker-eph
  • djmmoss
  • yongwww
  • cyx-6
  • wenscarl
  • IwakuraRein
  • kahyunnam

Poem

🐰 I hopped through headers, swizzled scales at dawn,
OSS names stitched, and tile shapes newly drawn.
Buffers grew three, scatter paths hum along,
Flags threaded true β€” kernels sing their song.
A rabbit cheers: compile fast, land strong.

Pre-merge checks

❌ Failed checks (2 warnings, 1 inconclusive)
  • Docstring Coverage (⚠️ Warning): Docstring coverage is 50.00%, which is below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.
  • Description check (⚠️ Warning): The PR description contains only the repository's template with checkboxes, missing the actual implementation details about what changes are made and why. Resolution: add a detailed description of the changes, including the purpose of the MOE/CUTLASS kernel updates, key modifications to namespaces and API signatures, and any migration steps; explain why the finalize fusion and dynamic cluster shape features were added.
  • Title check (❓ Inconclusive): The title 'update trtllm cutlass moe' is too vague and generic; it lacks specific details about what aspect of the CUTLASS MOE was updated. Resolution: use a more descriptive title that highlights the primary change, such as 'Add dynamic cluster shape support and finalize fusion for TensorRT-LLM MOE kernels'.


@gemini-code-assist
Contributor

Summary of Changes

Hello @nv-yunzheq, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request primarily focuses on enhancing the TensorRT-LLM (TRTLLM) CUTLASS Mixture-of-Experts (MoE) implementation, particularly for Hopper and Blackwell architectures. The main objective is to introduce a new FINALIZE epilogue fusion type for TMA warp-specialized grouped GEMM, which allows for more efficient post-processing operations. Additionally, it adds support for dynamic cluster shapes on SM100, expands mixed-precision capabilities with FP8xFP4, and streamlines the codebase by removing the deprecated min-latency mode and performing general refactoring. These changes aim to improve performance and flexibility in MoE computations.

Highlights

  • New Epilogue Fusion Type: A FINALIZE epilogue fusion type has been introduced for TMA warp-specialized grouped GEMM operations, enabling fused post-processing steps directly within the kernel for improved efficiency.
  • Dynamic Cluster Shape Support: The CutlassGemmConfig now supports dynamic cluster shapes for SM100 (Blackwell) architectures, allowing for more flexible kernel configurations at runtime based on workload characteristics.
  • FP8xFP4 Mixed Precision Support: Added support for FP8 activation with FP4 weights (WFP4AFP8) in TMA warp-specialized GEMM, including specific handling for SM103, expanding the range of supported mixed-precision computations.
  • Min Latency Mode Removal: The 'Min Latency Mode' for TMA warp-specialized grouped GEMM has been removed, simplifying the codebase and focusing on more generalized optimizations.
  • Code Refactoring and Cleanup: Various code refactorings were performed, including renaming variables (e.g., ptr_a to ptr_act, stride_a to stride_act), updating copyright years, and removing unused code to enhance maintainability and clarity.

@nv-yunzheq
Contributor Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !104 has been updated with latest changes, and the CI pipeline #37963389 is currently running. I'll report back once the pipeline job completes.

layout_info2.default_epilogue.ptr_d[expert] = nullptr;
}
}
bias2, gemm2_output, router_scales, permuted_row_to_unpermuted_row, expert);

There was a bug identified in TRT-LLM: the asm volatile("griddepcontrol.launch_dependents;"); on line 1302 is incorrect and needs to be moved to the end of the kernel. There should be no observable perf difference from doing this.

Contributor Author

Moved to the end of the kernel

auto id1 = profile_ids.value()[0];
if (id1 != -1) {
TVM_FFI_ICHECK(id1 >= 0 && id1 < static_cast<int64_t>(mAllProfiles.size()))
<< "Invalid gemm1 profile id: " << id1;

Should this check that the tactic is not in the GEMM 2 tactic range? Not all GEMM 2 tactics are valid for GEMM 1.
I think this combined approach is dangerous in general; I didn't implement MOE with any particular thought that we could get GEMM2 tactics for GEMM1, so this may break or have other subtle failures, such as the profiler picking a worse implementation. Any chance we could separate them in the proper API? (Happy for this to be a later PR though.)

Contributor Author

Updated to make sure id1 is smaller than mGemm1TacticCount. Agree to separate them in the future.
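
For readers following the thread: below is a minimal sketch of the range check being discussed, assuming the concatenated tactic list (GEMM1 tactics first, GEMM2 after) described elsewhere in this PR. mAllProfiles and mGemm1TacticCount are the member names from this thread; the struct and everything else are hypothetical.

// Sketch only: validates that each stage's profile id stays inside its own
// sub-range of the concatenated tactic list.
#include <cstdint>
#include <stdexcept>
#include <vector>

struct TacticRangeCheckSketch {
  std::vector<int> mAllProfiles;  // gemm1 tactics first, then gemm2 tactics
  int64_t mGemm1TacticCount = 0;  // number of leading gemm1 entries

  void validate(int64_t id1, int64_t id2) const {
    // GEMM1 may only select from the leading [0, mGemm1TacticCount) range.
    if (id1 != -1 && (id1 < 0 || id1 >= mGemm1TacticCount))
      throw std::out_of_range("Invalid gemm1 profile id");
    // GEMM2 may only select from the trailing range.
    if (id2 != -1 && (id2 < mGemm1TacticCount ||
                      id2 >= static_cast<int64_t>(mAllProfiles.size())))
      throw std::out_of_range("Invalid gemm2 profile id");
  }
};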

@pavanimajety
Contributor

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu (1)

401-417: Critical: Non-OSS path still missing parameters and hardcodes enable_alltoall=false.

Despite being marked as addressed in commit b09efb0, the non-OSS branch still has critical issues:

  1. Line 413: Missing unpadded_hidden_size parameter (compare to OSS line 395)
  2. Line 415: Hardcodes false instead of passing enable_alltoall
  3. Line 415: Missing use_lora flag before lora_params

This causes functional regressions: all-to-all communication is disabled and LoRA configuration is lost in non-OSS builds.

-        quant_params, num_rows, hidden_size, inter_size, num_experts_total,
+        quant_params, num_rows, hidden_size, unpadded_hidden_size, inter_size, num_experts_total,
         static_cast<int>(experts_per_token),
         static_cast<char*>(workspace_info.workspace.data_ptr()), output.data_ptr(),
-        static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, false, lora_params,
+        static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, enable_alltoall,
+        use_lora, lora_params,
         mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, enable_pdl, stream);
csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh (1)

351-355: Move griddepcontrol.launch_dependents to the end of three kernel functions to prevent data races.

Three instances of griddepcontrol.launch_dependents are called before all memory store operations complete, violating the PDL requirement to ensure stores finish in flight:

  • Lines 352-354 in fusedBuildExpertMapsSortFirstTokenKernel: followed by memory writes at lines 356-372 before function end
  • Lines 1508-1510: followed by padding writes at lines 1512+
  • Lines 2215-2217: followed by padding writes at lines 2219+

Since launch_dependents provides no memory visibility guarantee, dependent kernels may see stale data. Move each call to immediately before the enclosing function's closing brace.
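
For clarity, here is a minimal CUDA sketch of the placement being requested: every store is issued before the launch_dependents signal, which becomes the last statement of the kernel. The kernel below is hypothetical (not one of the kernels named above) and assumes compilation for an architecture that supports griddepcontrol (sm_90+).

// Hypothetical kernel showing the PDL pattern with the signal at the end.
__global__ void pdlPatternSketch(float* out, float const* in, int n) {
  // Wait for the upstream kernel's results before reading them.
  asm volatile("griddepcontrol.wait;");
  int const idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < n) {
    out[idx] = in[idx] * 2.0f;  // all stores happen before the signal below
  }
  // Signal dependent kernels only once every store above has been issued.
  asm volatile("griddepcontrol.launch_dependents;");
}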

♻️ Duplicate comments (5)
csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu (1)

585-602: Critical: Non-OSS min-latency path hardcodes enable_alltoall=false.

Line 599 passes false instead of enable_alltoall, disabling all-to-all communication in non-OSS builds. While this path correctly includes use_lora_ml (unlike the regular runMoe), it still has the critical enable_alltoall bug.

-        static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, false, use_lora_ml,
+        static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, enable_alltoall, use_lora_ml,
         lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, enable_pdl,
         stream);
csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh (4)

1011-1050: Deduplicate SF layout branch in writeSF; compute layout once and call helper once.

Removes duplicate cvt_quant_get_sf_out_offset calls and branches on layout only.

   if (sf_out) {
     if (input_sf) {
-      if (swizzled_input_sf) {
-        auto const sf_in =
-            cvt_quant_get_sf_out_offset<TmaWarpSpecializedGroupedGemmInput::ElementSF,
-                                        NumThreadsPerSF>(
-                std::nullopt /* batchIdx */, source_token_id, elem_idx, std::nullopt /* numRows */,
-                num_cols / VecSize,
-                const_cast<TmaWarpSpecializedGroupedGemmInput::ElementSF*>(input_sf),
-                QuantizationSFLayout::SWIZZLED_128x4);
-        *sf_out = *sf_in;
-      } else {
-        auto const sf_in =
-            cvt_quant_get_sf_out_offset<TmaWarpSpecializedGroupedGemmInput::ElementSF,
-                                        NumThreadsPerSF>(
-                std::nullopt /* batchIdx */, source_token_id, elem_idx, std::nullopt /* numRows */,
-                num_cols / VecSize,
-                const_cast<TmaWarpSpecializedGroupedGemmInput::ElementSF*>(input_sf),
-                QuantizationSFLayout::LINEAR);
-        *sf_out = *sf_in;
-      }
+      auto const layout = swizzled_input_sf ? QuantizationSFLayout::SWIZZLED_128x4
+                                            : QuantizationSFLayout::LINEAR;
+      auto const sf_in =
+          cvt_quant_get_sf_out_offset<TmaWarpSpecializedGroupedGemmInput::ElementSF,
+                                      NumThreadsPerSF>(
+              std::nullopt /* batchIdx */, source_token_id, elem_idx, std::nullopt /* numRows */,
+              num_cols / VecSize,
+              const_cast<TmaWarpSpecializedGroupedGemmInput::ElementSF*>(input_sf), layout);
+      *sf_out = *sf_in;
     } else {
       *sf_out = 0x00;
     }
   }

1684-1686: Finalize kernel alignment checks must use element-width-based stride, not hard-coded 4.

Tie asserts to FINALIZE_ELEM_PER_THREAD for correctness across dtypes.

-  assert(padded_cols % 4 == 0);
-  assert(unpadded_cols % 4 == 0);
-  assert(unpadded_cols <= padded_cols);
+  // Load 128-bits per thread, according to smallest IO type
   constexpr int64_t FINALIZE_ELEM_PER_THREAD =
       128 / std::min(sizeof_bits<OutputType>::value, sizeof_bits<GemmOutputType>::value);
+  assert(padded_cols % FINALIZE_ELEM_PER_THREAD == 0);
+  assert(unpadded_cols % FINALIZE_ELEM_PER_THREAD == 0);
+  assert(unpadded_cols <= padded_cols);
-  int64_t const start_offset = threadIdx.x;
-  int64_t const stride = FINALIZE_THREADS_PER_BLOCK;
-  int64_t const num_elems_in_padded_col = padded_cols / FINALIZE_ELEM_PER_THREAD;
-  int64_t const num_elems_in_orig_col = unpadded_cols / FINALIZE_ELEM_PER_THREAD;
+  int64_t const start_offset = threadIdx.x;
+  int64_t const stride = FINALIZE_THREADS_PER_BLOCK;
+  int64_t const num_elems_in_padded_col = padded_cols / FINALIZE_ELEM_PER_THREAD;
+  int64_t const num_elems_in_orig_col = unpadded_cols / FINALIZE_ELEM_PER_THREAD;

Also applies to: 1693-1700


1761-1764: Same alignment issue in no-filling finalize kernel; mirror FINALIZE_ELEM_PER_THREAD asserts.

Apply element-width-based asserts and derived counts.

-  assert(padded_cols % 4 == 0);
-  assert(unpadded_cols % 4 == 0);
-  assert(unpadded_cols <= padded_cols);
+  // Alignment checks moved below after FINALIZE_ELEM_PER_THREAD is known
   ...
   constexpr int64_t FINALIZE_ELEM_PER_THREAD =
       128 / std::min(sizeof_bits<OutputType>::value, sizeof_bits<GemmOutputType>::value);
+  assert(padded_cols % FINALIZE_ELEM_PER_THREAD == 0);
+  assert(unpadded_cols % FINALIZE_ELEM_PER_THREAD == 0);
+  assert(unpadded_cols <= padded_cols);
-  int64_t const num_elems_in_padded_col = padded_cols / FINALIZE_ELEM_PER_THREAD;
-  int64_t const num_elems_in_orig_col = unpadded_cols / FINALIZE_ELEM_PER_THREAD;
+  int64_t const num_elems_in_padded_col = padded_cols / FINALIZE_ELEM_PER_THREAD;
+  int64_t const num_elems_in_orig_col = unpadded_cols / FINALIZE_ELEM_PER_THREAD;

Also applies to: 1799-1803


3951-3952: Min-latency path currently throws; restore functional fallback.

Throwing breaks SM90 min-latency flows. Route LL to non-LL compute until LL is ready.

Option A (minimal): switch setupTmaWarpSpecializedInputs to non-LL compute in LL mode:

@@
   if (min_latency_mode) {
@@
-    return Self::computeStridesTmaWarpSpecializedLowLatency(
-        gemm1_tma_ws_input, gemm2_tma_ws_input, num_rows, fc1_out_size, hidden_size, hidden_size,
-        inter_size, num_experts_per_node, reinterpret_cast<T const*>(gemm1_input),
-        reinterpret_cast<T const*>(gemm2_input), fc1_expert_weights, fc2_expert_weights,
-        quant_params.fp8.dequant_fc1, quant_params.fp8.dequant_fc2, input_sf, fc2_fp4_act_scale_,
-        quant_params, nullptr, nullptr, reinterpret_cast<UnfusedGemmOutputType*>(gemm1_output),
-        reinterpret_cast<UnfusedGemmOutputType*>(fc2_result_),
-        min_latency_params.num_active_experts_per_node, min_latency_params.active_expert_global_ids,
-        start_expert, enable_pdl, stream);
+    // Temporary fallback: use non-LL path to keep correctness
+    return Self::computeStridesTmaWarpSpecialized(
+        expert_first_token_offset_, gemm1_tma_ws_input, gemm2_tma_ws_input, num_rows,
+        expanded_num_rows, fc1_out_size, hidden_size, hidden_size, inter_size, num_experts_per_node,
+        reinterpret_cast<T const*>(gemm1_input), reinterpret_cast<T const*>(gemm2_input),
+        fc1_expert_weights, fc2_expert_weights, quant_params.fp8.dequant_fc1,
+        quant_params.fp8.dequant_fc2, input_sf, fc2_fp4_act_scale_, quant_params,
+        /*bias1=*/nullptr, /*bias2=*/nullptr,
+        reinterpret_cast<UnfusedGemmOutputType*>(gemm1_output),
+        reinterpret_cast<UnfusedGemmOutputType*>(fc2_result_),
+        /*router_scales=*/permuted_token_final_scales_, /*permuted_row_to_unpermuted_row=*/permuted_row_to_unpermuted_row_,
+        enable_pdl, stream);
   }

Alternatively, if you prefer keeping call sites unchanged, implement a similar fallback inside computeStridesTmaWarpSpecializedLowLatency with a warning.

#!/bin/bash
rg -n "computeStridesTmaWarpSpecializedLowLatency" -C2
🧹 Nitpick comments (1)
csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu (1)

226-232: Address djns99's architectural concern about combined GEMM1/GEMM2 tactics.

While the current implementation now correctly tracks the split between GEMM1 and GEMM2 tactics, djns99's comment highlights a fundamental issue: "I didn't implement MOE with any particular thought that we could get GEMM2 tactics for GEMM1, this may break or have other subtle failures such as the profiler picking a worse implementation."

Consider separating the GEMM1 and GEMM2 tactic APIs completely rather than concatenating them, which could help prevent cross-contamination and make the API clearer.

Based on past review comments.
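
One possible shape for such a separation, sketched under the assumption that tactics are stored per stage instead of concatenated; Config stands in for CutlassGemmConfig, and none of these names are the actual FlashInfer API.

// Illustrative per-stage tactic storage: a GEMM2 id can never silently
// select a GEMM1 tactic (or vice versa), and at() throws on a bad index.
#include <cstddef>
#include <vector>

struct Config {};  // placeholder for CutlassGemmConfig

class SeparatedTacticsSketch {
 public:
  std::vector<Config> const& gemm1Tactics() const { return gemm1_; }
  std::vector<Config> const& gemm2Tactics() const { return gemm2_; }
  Config const& gemm1Tactic(std::size_t i) const { return gemm1_.at(i); }
  Config const& gemm2Tactic(std::size_t i) const { return gemm2_.at(i); }

 private:
  std::vector<Config> gemm1_;
  std::vector<Config> gemm2_;
};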

πŸ“œ Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

πŸ“₯ Commits

Reviewing files that changed from the base of the PR and between 22b97b0 and a30f033.

πŸ“’ Files selected for processing (3)
  • csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh (55 hunks)
  • csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu (10 hunks)
  • tests/moe/test_trtllm_cutlass_fused_moe.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu (1)
include/flashinfer/trtllm/fused_moe/runner.h (3)
  • hidden_size (265-265)
  • num_experts (263-263)
  • top_k (270-270)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
πŸ”‡ Additional comments (13)
tests/moe/test_trtllm_cutlass_fused_moe.py (1)

1089-1090: LGTM! SM120 support correctly added.

The skip condition now includes compute capability 12 (SM120) and the reason string has been updated accordingly, addressing the previous review feedback. This change is consistent with the similar update in test_moe_nvfp4 (lines 366-369) and aligns with the broader PR objective of extending SM120 support.

csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu (7)

384-400: LGTM: OSS path correctly threads new parameters.

The OSS branch properly passes swizzled_input_sf, unpadded_hidden_size, and use_lora to the kernel runner.


568-584: LGTM: OSS min-latency path correctly uses new parameters.

The parameters are properly threaded through, matching the regular runMoe pattern.


714-717: LGTM: Exposing GEMM tactic counts improves API clarity.

These new functions allow callers to understand the GEMM1/GEMM2 tactic split, which is essential for correct tactic selection.


785-786: LGTM: Member variables properly track tactic counts.

The zero-initialization and int64_t type are appropriate.


799-804: LGTM: Default GEMM2 profile now correctly selected from GEMM2 subrange.

The fallback logic properly uses mGemm1TacticCount to index into the GEMM2 tactics, fixing the previous critical bug where both GEMMs defaulted to the same GEMM1 tactic.


805-813: LGTM: GEMM1 index validation prevents cross-tactic contamination.

The range check id1 < mGemm1TacticCount ensures GEMM1 can only select from its own tactics, addressing part of djns99's concern.


380-383: Clarify handling of hardcoded kernel parameters as documented temporary constraints.

These three values are marked with HACK/TODO comments indicating they are known limitations:

  • use_lora = false aligns with the "TODO: support lora in the future" comment β€” acceptable for now
  • swizzled_input_sf = true assumes input scale factors are always swizzled; verify this matches all actual inputs
  • unpadded_hidden_size = hidden_size assumes no padding; this assumption must hold for correctness

The same pattern (hardcoded unpadded_hidden_size = hidden_size) repeats in runMoe variants and profiler code throughout this file. If padding is ever used, this will silently compute incorrect results. Either verify padding is never applied in practice, or expose these as parameters so callers can provide correct values.

csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh (5)

1093-1111: swap_ab-aware FP4/MXFPX stride setup looks correct.

Good: symmetrical SFA/SFB mapping and explicit transpose handling.


1196-1210: Fused finalize epilogue pointers are wired per-expert correctly.

Setting ptr_source_token_index, ptr_router_scales, and optional bias per expert is sound.


3598-3617: NoSmem epilogue + finalize fusion guardrails are correct.

Appropriate TLLM_CHECKs preventing unsupported combinations.

Consider adding a unit/profiler case asserting this path errors cleanly.


4018-4040: Finalize fusion gating matches runner config and excludes w4_groupwise/LoRA.

Consistent with prepareTmaWsInputs; good.

Ensure mInterface->use_fused_finalize_ is set coherently where GemmProfilerBackend is constructed.


1641-1647: Public macro instantiations: signature/arg ordering consistent.

Launchers’ new params (input_sf, swizzled_input_sf, padded/unpadded) flow through correctly.

Also applies to: 1912-1921

@flashinfer-bot
Collaborator

[SUCCESS] Pipeline #37963389: 13/17 passed

@yzh119
Collaborator

yzh119 commented Nov 6, 2025

@nv-yunzheq there are some compilation errors in CI: https://ci.tlcpack.ai/blue/organizations/jenkins/flashinfer-ci/detail/PR-2020/22/pipeline/23; can you double-check? (I'm not sure if it's because of CUTLASS versions.)

@nv-yunzheq
Contributor Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !104 has been updated with latest changes, and the CI pipeline #37982790 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[CANCELING] Pipeline #37982790: canceled

@nv-yunzheq
Contributor Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !104 has been updated with latest changes, and the CI pipeline #37985869 is currently running. I'll report back once the pipeline job completes.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh (1)

221-236: Fix OOB read: s_local_experts indexed with global expert id.

In min-latency map build, s_local_experts is sized to local experts but is indexed by the global expert, causing OOB when expert βˆ‰ [start_expert, end_expert). Guard and subtract start_expert.

Apply:

-    bool is_valid_expert =
-        smart_routing ? s_local_experts[expert] : (expert >= start_expert && expert < end_expert);
+    bool const expert_in_node = (expert >= start_expert && expert < end_expert);
+    bool is_valid_expert = smart_routing
+                               ? (expert_in_node && s_local_experts[expert - start_expert])
+                               : expert_in_node;

Also consider mirroring this guard wherever s_store_experts[expert - start_expert] is used to avoid underflow when expert_in_node == false.

♻️ Duplicate comments (5)
csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh (5)

1724-1734: Restore defensive check for invalid permutation indices (debug-only OK).

expanded_permuted_row = unpermuted_row_to_permuted_row[...] has no validity guard. If upstream builds ever leave sentinel values, this will read OOB from expanded_permuted_rows.

-      int64_t const expanded_permuted_row = unpermuted_row_to_permuted_row[expanded_original_row];
+      int64_t const expanded_permuted_row = unpermuted_row_to_permuted_row[expanded_original_row];
+#ifndef NDEBUG
+      if (expanded_permuted_row < 0) { continue; }
+#endif

Alternatively add an unconditional if (expanded_permuted_row < 0) continue; if negative is a valid sentinel in production.


1031-1050: De-duplicate swizzled vs linear SF input handling.

Simplify by computing layout once and a single call to cvt_quant_get_sf_out_offset.

-      if (swizzled_input_sf) {
-        auto const sf_in =
-            cvt_quant_get_sf_out_offset<TmaWarpSpecializedGroupedGemmInput::ElementSF,
-                                        NumThreadsPerSF>(
-                std::nullopt /* batchIdx */, source_token_id, elem_idx, std::nullopt /* numRows */,
-                num_cols / VecSize,
-                const_cast<TmaWarpSpecializedGroupedGemmInput::ElementSF*>(input_sf),
-                QuantizationSFLayout::SWIZZLED_128x4);
-        *sf_out = *sf_in;
-      } else {
-        auto const sf_in =
-            cvt_quant_get_sf_out_offset<TmaWarpSpecializedGroupedGemmInput::ElementSF,
-                                        NumThreadsPerSF>(
-                std::nullopt /* batchIdx */, source_token_id, elem_idx, std::nullopt /* numRows */,
-                num_cols / VecSize,
-                const_cast<TmaWarpSpecializedGroupedGemmInput::ElementSF*>(input_sf),
-                QuantizationSFLayout::LINEAR);
-        *sf_out = *sf_in;
-      }
+      auto const layout = swizzled_input_sf ? QuantizationSFLayout::SWIZZLED_128x4
+                                            : QuantizationSFLayout::LINEAR;
+      auto const sf_in =
+          cvt_quant_get_sf_out_offset<TmaWarpSpecializedGroupedGemmInput::ElementSF,
+                                      NumThreadsPerSF>(
+              std::nullopt /* batchIdx */, source_token_id, elem_idx, std::nullopt /* numRows */,
+              num_cols / VecSize,
+              const_cast<TmaWarpSpecializedGroupedGemmInput::ElementSF*>(input_sf),
+              layout);
+      *sf_out = *sf_in;

3937-3955: Min‑latency path currently throws; add safe fallback or gate dispatch.

computeStridesTmaWarpSpecializedLowLatency unconditionally throws, breaking callers (setupTmaWarpSpecializedInputs min‑latency branch).

Options:

  • Route min‑latency to the non‑LL computeStridesTmaWarpSpecialized with a temporary expert_first_token_offset built from num_active_experts_per/active_expert_global_ids, or
  • Gate all LL dispatch sites behind TLLM_CHECK_WITH_INFO(!min_latency_mode) to avoid calling this until LL is reintroduced. Do you want a minimal fallback drafted?

1684-1700: Align checks to element width, not constant 4.

Hardcoding % 4 can break for dtypes where FINALIZE_ELEM_PER_THREAD != 4. Use the computed constant.

-  assert(padded_cols % 4 == 0);
-  assert(unpadded_cols % 4 == 0);
-  assert(unpadded_cols <= padded_cols);
+  assert(unpadded_cols <= padded_cols);
+  constexpr int64_t FINALIZE_ELEM_PER_THREAD =
+      128 / std::min(sizeof_bits<OutputType>::value, sizeof_bits<GemmOutputType>::value);
+  assert(padded_cols   % FINALIZE_ELEM_PER_THREAD == 0);
+  assert(unpadded_cols % FINALIZE_ELEM_PER_THREAD == 0);

As per earlier feedback.


1761-1764: Same alignment issue + simplify loop bound.

Mirror the FINALIZE_ELEM_PER_THREAD-based asserts and iterate to num_elems_in_orig_col to avoid per-iteration branch.

-  assert(padded_cols % 4 == 0);
-  assert(unpadded_cols % 4 == 0);
-  assert(unpadded_cols <= padded_cols);
+  assert(unpadded_cols <= padded_cols);
+  constexpr int64_t FINALIZE_ELEM_PER_THREAD =
+      128 / std::min(sizeof_bits<OutputType>::value, sizeof_bits<GemmOutputType>::value);
+  assert(padded_cols   % FINALIZE_ELEM_PER_THREAD == 0);
+  assert(unpadded_cols % FINALIZE_ELEM_PER_THREAD == 0);
@@
-    for (int elem_index = start_offset; elem_index < num_elems_in_padded_col;
-         elem_index += stride) {
-      if (elem_index >= num_elems_in_orig_col) continue;  // Skip writing beyond original columns
+    for (int elem_index = start_offset; elem_index < num_elems_in_orig_col;
+         elem_index += stride) {

As per earlier feedback.

Also applies to: 1799-1806

🧹 Nitpick comments (2)
csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh (2)

280-286: CUDA dynamic shared memory attr check: allow == max.

Guard rejects shared_size >= max_smem_per_block; typically equality is valid. Prefer > to avoid unnecessary fallback.

-  if (shared_size >= static_cast<size_t>(max_smem_per_block)) {
+  if (shared_size > static_cast<size_t>(max_smem_per_block)) {

Also applies to: 606-620


197-217: Smart routing: active_expert_global_ids semantics.

In the smart-routing branch, the stored id is i (local expert index), while in the else-branch it is i + start_expert (global id). If consumers expect global ids in both modes (per comment), convert to global via i + start_expert. Otherwise, please add a comment clarifying that LL mode uses local ids.

πŸ“œ Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

πŸ“₯ Commits

Reviewing files that changed from the base of the PR and between 31b0df0 and 91a85ad.

πŸ“’ Files selected for processing (1)
  • csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh (56 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
πŸ”‡ Additional comments (3)
csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh (3)

1886-1894: Incorrect reference in line citation.

The concern about padded_cols and unpadded_cols consistency applies only to lines 1886-1894 and 1901-1909 (the actual kernel launches within finalizeMoeRoutingKernelLauncher). Lines 3850-3858 reference a different function call (Self::gemm2()) and should not be included. All three finalizeMoeRoutingKernelLauncher call sites (2925-2930, 3302-3308, 3309-3315) correctly pass hidden_size and unpadded_hidden_size with consistent semantic mapping to the kernel parameters.

Likely an incorrect or invalid review comment.


270-286: Incorrect line references in review comment.

Lines 637-645 and 686-696 are kernel device code (globalExpertPrefixSumLargeKernel and globalExpertPrefixSumKernel function implementations), not cudaLaunchKernelEx call sites. The actual launches in cutlass_fused_moe_kernels.cuh at lines 281, 617, 738, 743, and 801 all consistently set attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization with .programmaticStreamSerializationAllowed = enable_pdl.

Likely an incorrect or invalid review comment.


4040-4051: Aliasing and memset size are already correctly handled.

The code at lines 2811-2812 allocates both fc1_fp4_act_scale_ and fc2_fp4_act_scale_ from the same workspace buffer key ("fp4_act_scale"), ensuring they point to identical memory. The workspace allocation (lines 2606-2612) uses std::max(fc1_fp4_act_scale_size, fc2_fp4_act_scale_size) to reserve space, and the memset at lines 4066-4069 uses the identical std::max(fc1_sf_offset, fc2_sf_offset) logic to compute the fill size. The assertion at line 4049 confirms the aliasing invariant. Since both allocation and memset employ the same max-size calculation, the memset is guaranteed to fit within the workspace allocation.

@flashinfer-bot
Collaborator

[CANCELING] Pipeline #37985869: canceled

@nv-yunzheq
Contributor Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !104 has been updated with latest changes, and the CI pipeline #37989907 is currently running. I'll report back once the pipeline job completes.

Collaborator

@yzh119 yzh119 left a comment


The failed UT on gb300 is not relevant; LGTM on my side.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #37989907: 12/17 passed

@yzh119
Collaborator

yzh119 commented Nov 6, 2025

There are still some remaining cu126 compilation issues such as:

[2025-11-06T07:40:48.794Z] FAILED: [code=2] fused_moe_90/moe_gemm_kernels_fp8_fp8.cuda.o 
[2025-11-06T07:40:48.794Z] /usr/local/cuda/bin/nvcc --generate-dependencies-with-compile --dependency-output fused_moe_90/moe_gemm_kernels_fp8_fp8.cuda.o.d -DPy_LIMITED_API=0x03090000 -D_GLIBCXX_USE_CXX11_ABI=1 -I/workspace/csrc/nv_internal -I/workspace/csrc/nv_internal/include -I/workspace/csrc/nv_internal/tensorrt_llm/cutlass_extensions/include -I/workspace/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include -I/workspace/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels -isystem /opt/conda/envs/py312/include/python3.12 -isystem /usr/local/cuda/include -isystem /usr/local/cuda/include/cccl -isystem /tmp/build-env-vsqya_iz/lib/python3.12/site-packages/tvm_ffi/include -isystem /tmp/build-env-vsqya_iz/lib/python3.12/site-packages/tvm_ffi/include -isystem /workspace/include -isystem /workspace/csrc -isystem /workspace/3rdparty/cutlass/include -isystem /workspace/3rdparty/cutlass/tools/util/include -isystem /workspace/3rdparty/spdlog/include --compiler-options=-fPIC --expt-relaxed-constexpr -DFLASHINFER_ENABLE_FP8_E8M0 -DFLASHINFER_ENABLE_FP4_E2M1 -std=c++17 --threads=1 -use_fast_math -DFLASHINFER_ENABLE_F16 -DFLASHINFER_ENABLE_BF16 -DFLASHINFER_ENABLE_FP8_E4M3 -DFLASHINFER_ENABLE_FP8_E5M2 -DNDEBUG -O3 -gencode=arch=compute_90a,code=sm_90a -DFLASHINFER_ENABLE_FP8_E8M0 -DFLASHINFER_ENABLE_FP4_E2M1 -DCOMPILE_HOPPER_TMA_GEMMS -DCOMPILE_HOPPER_TMA_GROUPED_GEMMS -DENABLE_BF16 -DENABLE_FP8 -DUSING_OSS_CUTLASS_MOE_GEMM -c /workspace/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_fp8.cu -o fused_moe_90/moe_gemm_kernels_fp8_fp8.cuda.o 
[2025-11-06T07:40:48.794Z] /workspace/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h(99): error: identifier "__nv_fp4_e2m1" is undefined
[2025-11-06T07:40:48.794Z]                     cutlass::platform::is_same<WeightType, __nv_fp4_e2m1>::value ||
[2025-11-06T07:40:48.794Z]                                                            ^
[2025-11-06T07:40:48.795Z] 
[2025-11-06T07:40:48.795Z] /workspace/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h(740): error: identifier "__nv_fp4_e2m1" is undefined
[2025-11-06T07:40:48.795Z]       if constexpr (!std::is_same_v<WeightType, __nv_fp4_e2m1>) {
[2025-11-06T07:40:48.795Z]                                                 ^
[2025-11-06T07:40:48.795Z] 
[2025-11-06T07:40:48.795Z] /workspace/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h(748): error: identifier "__nv_fp4_e2m1" is undefined
[2025-11-06T07:40:48.795Z]       if constexpr (!std::is_same_v<WeightType, __nv_fp4_e2m1>) {
[2025-11-06T07:40:48.795Z]                                                 ^
[2025-11-06T07:40:48.795Z] 
[2025-11-06T07:40:48.795Z] 3 errors detected in the compilation of "/workspace/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_fp8.cu".

Likely because we didn't guard the usage of __nv_fp4_e2m1 properly (cu126 does not support fp4).

We will retire cu126 at some point, but not now (considering cu126 is still one of the three CUDA versions supported by PyTorch).

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h (1)

99-102: FP4 guards insufficient for CUDA 12.6 compatibility

The guard only checks ENABLE_FP4, but __nv_fp4_e2m1 requires CUDA 12.8+. The CI failure on cu126 (reported in PR objectives) confirms this: the identifier is undefined because CUDA 12.6 doesn't provide it. Same issue exists at lines 249-253, 742-746, and 755-759.

Apply guards that also check CUDA version:

-#if defined(ENABLE_FP4)
+#if defined(ENABLE_FP4) && CUDA_VERSION >= 12080
                   cutlass::platform::is_same<WeightType, __nv_fp4_e2m1>::value ||
 #endif

Repeat for all FP4 type references at lines 249-253, 742-746, and 755-759.

♻️ Duplicate comments (1)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h (1)

672-676: Fix zero-argument call to supportsTmaWarpSpecialized

This duplicates a past review concern: isTmaWarpSpecialized calls supportsTmaWarpSpecialized() without arguments on line 675, but the signature at lines 679-688 now requires an int sm parameter. The same issue occurs at line 920 in calcMaxWorkspaceSize.

Apply this diff to forward the member's sm_:

-  return supportsTmaWarpSpecialized() && config_is_tma_warp_specialized;
+  return supportsTmaWarpSpecialized(sm_) && config_is_tma_warp_specialized;

Also fix line 920:

-  if (!supportsTmaWarpSpecialized()) {
+  if (!supportsTmaWarpSpecialized(sm_)) {

Alternatively, add a const wrapper in the class:

bool supportsTmaWarpSpecialized() const {
  return supportsTmaWarpSpecialized(sm_);
}

Based on learnings

🧹 Nitpick comments (1)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h (1)

953-956: Consider extending FINALIZE fusion workspace calculation beyond SM90

FINALIZE fusion workspace size is currently only calculated for SM90 (line 954). If other architectures (e.g., SM100+) support finalize fusion, they should also be included in this calculation to avoid underestimating workspace requirements.
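
A sketch of the extension this nitpick suggests, under the assumption that the finalize-fusion buffer is an additive extra on top of the base workspace; the function name, the SM gate, and both size parameters are hypothetical.

// Hypothetical sizing helper: reserve the finalize-fusion extra on every
// architecture that supports the fusion, not just SM90, so the workspace is
// never undersized.
#include <cstddef>

std::size_t maxWorkspaceSizeSketch(int sm, std::size_t base_size,
                                   std::size_t finalize_fusion_extra) {
  bool const supports_finalize_fusion = (sm == 90) || (sm >= 100);  // assumption
  return supports_finalize_fusion ? base_size + finalize_fusion_extra
                                  : base_size;
}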

πŸ“œ Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

πŸ“₯ Commits

Reviewing files that changed from the base of the PR and between e15a96c and 33aec35.

πŸ“’ Files selected for processing (1)
  • csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h (14 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h (6)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h (3)
  • tensorrt_llm (63-112)
  • std (81-95)
  • calcMaxWorkspaceSizeTmaWarpSpecialized (490-502)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/common.h (1)
  • tensorrt_llm (19-34)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws_mixed_dtype.h (1)
  • tensorrt_llm (60-274)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h (9)
  • tensorrt_llm (33-150)
  • kernels (34-149)
  • cutlass (114-116)
  • cutlass (120-122)
  • cutlass (127-129)
  • cutlass (132-134)
  • cutlass (140-142)
  • cutlass_kernels (35-148)
  • __nv_fp8_e5m2 (91-93)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp (2)
  • get_candidate_configs (638-689)
  • get_candidate_configs (638-640)
csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h (1)
  • EpilogueScheduleType (197-433)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
πŸ”‡ Additional comments (4)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h (4)

530-544: LGTM: Clean signature updates for finalize fusion support

The addition of the supports_finalize_fusion parameter to both the const member and static getConfigs methods properly threads this capability flag through the config selection pipeline.


624-629: Verify SM103 FP4 config selection strategy

The code explicitly adds SM100 configs when running on SM103 with FP4. Ensure this cross-architecture config reuse is validated and doesn't cause performance regressions or compatibility issues.


631-666: Well-structured finalize fusion and swap_ab config expansion

The logic correctly:

  1. Duplicates configs and marks them with FINALIZE fusion type when supported (lines 631-640)
  2. Removes FINALIZE configs that lack epilogue SMEM (lines 642-650)
  3. Adds swap_ab variants for all configs (lines 653-659) with a defensive check
  4. Filters to swap_ab=true only for w4_groupwise mode (lines 661-666)

978-1007: Activation type dispatch looks correct

The switch statement appropriately handles the supported activation types (Relu, Gelu, Silu, Identity, Swiglu, Geglu) and throws for invalid types. Note that Relu2 from the ActivationType enum is not handled, which appears intentional per the AI summary noting "Relu2 path removed (no longer supported)".

@nv-yunzheq
Contributor Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !104 has been updated with latest changes, and the CI pipeline #38037173 is currently running. I'll report back once the pipeline job completes.

Contributor

@nvmbreughe nvmbreughe left a comment


LGTM.
Perhaps just add the additional tests for DSR1 and autotuner we discussed.

cute::make_shape(gemm_n, gemm_k, 1));
}
if (layout_info.stride_c) {
// TODO Enable 1xN bias matrix as C
Contributor

Does this mean we don't support batch size = 1?

Contributor Author

@nv-yunzheq nv-yunzheq Nov 7, 2025


No, it's just that the bias tensor can't be 1xN.

