GEMM + SwiGLU fused Grouped MLP for MXFP8 #2769
Conversation
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Greptile Summary

This PR introduces a fused GEMM + SwiGLU kernel for MXFP8 Grouped MLP on SM100 (Blackwell) GPUs, using CuTe DSL kernels from the cuDNN front-end. It adds … Many correctness concerns (non-contiguous …) remain.

Confidence Score: 2/5. Not safe to merge: multiple prior-round P1 correctness issues remain present in the current head with no fix. Several confirmed runtime-crash or silent-wrong-result bugs flagged in earlier review rounds remain unaddressed:

1. `fc1_dgrad_kernel_out[d_tensor].view(in_shape)` will raise `RuntimeError` because the quant kernel output requires `.permute(2, 0, 1)` before `.view()`.
2. `grouped_fc1_dy` data fields are 2-D and the scales are unpermuted, producing wrong wgrad.
3. `mark_grouped_tensor` hard-asserts `columnwise_data is not None`, crashing on frozen-weight + grad-input.
4. `overwrite_main_grad=True` + `single_grouped_weight=True` silently writes wgrad to a scratch buffer, never updating `main_grad`.

`transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py` (lines 504-516, 611) and `transformer_engine/pytorch/utils.py` (`mark_grouped_tensor`) need the most attention before merge.
Sequence Diagram

```mermaid
sequenceDiagram
    participant Input
    participant FC1_GLU as FC1 GLU Kernel
    participant FC2_QUANT as FC2 QUANT Kernel
    participant Output
    participant FC2_DGLU as FC2 dGLU Kernel
    participant FC1_DGRAD as FC1 dGrad Kernel
    Note over Input,Output: FORWARD PASS
    Input->>FC1_GLU: a_tensor (MXFP8), b_tensor (weight), prob_tensor (scales), bias_tensor
    FC1_GLU-->>FC2_QUANT: d_tensor (MXFP8 act), sfd_row_tensor
    FC1_GLU-->>FC1_GLU: c_tensor (swiglu_in bf16), d_col_tensor, sfd_col_tensor
    FC2_QUANT-->>Output: d_tensor permute(2,0,1).view()
    Note over Input,Output: BACKWARD PASS
    Output->>FC2_DGLU: a_tensor (dy MXFP8), c_tensor (swiglu_in), b_tensor (FC2 weight col)
    FC2_DGLU-->>FC1_DGRAD: d_row_tensor (fc1_dy), sfd_row_tensor
    FC2_DGLU-->>FC2_DGLU: dprob_tensor (grad_scales), dbias_tensor
    FC1_DGRAD-->>Input: d_tensor .view() needs permute(2,0,1) first
    FC2_DGLU-->>FC2_DGLU: wgrad via general_grouped_gemm
    FC1_DGRAD-->>FC1_DGRAD: wgrad via general_grouped_gemm
```
Resolved review thread (outdated): `transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py`
```python
fc1_x_data = grouped_fc1_x.rowwise_data.view(in_shape[0], in_shape[1])
fc1_x_data = fc1_x_data.view(dtype=torch.float8_e4m3fn)
fc1_x_data = fc1_x_data.unsqueeze(0).permute(1, 2, 0)
fc1_x_scales = grouped_fc1_x.scale_inv
fc1_x_scales = fc1_x_scales.view(dtype=torch.float8_e8m0fnu)
fc1_x_scales = fc1_x_scales.view(
    1,
    in_shape[0] // 128,
    in_shape[1] // 128,
    32,
    4,
    4,
)
fc1_x_scales = fc1_x_scales.permute(3, 4, 1, 5, 2, 0)
```
No validation that total token count is divisible by 128
The scale tensor view at lines 272-279 uses integer division `in_shape[0] // 128` to reshape the MXFP8 scale buffer. If `in_shape[0]` (i.e., `sum(split_sizes)`) is not divisible by 128, the view shape product will not match the actual buffer size, and the code will either produce incorrect behavior (wrong permute dimensions) or raise a runtime error with a confusing message.
The constructor checks that in_features % 256 == 0 and out_features % 256 == 0, but nothing validates that the token dimension sum(split_sizes) is divisible by 128 (required by the MXFP8 block-scaling layout). A user passing split sizes like [64, 65] would hit this silently.
The same assumption appears in the backward pass at backward_grouped_mlp.py lines 243–250.
Consider adding a guard before the view:

```python
if in_shape[0] % 128 != 0:
    raise ValueError(
        f"Total token count must be divisible by 128 for MXFP8 fused kernel, "
        f"but got sum(split_sizes)={in_shape[0]}."
    )
ptrendx left a comment:
Not yet done with the full review, but cursory glance shows some leftover debugging code and some other random things that should be cleaned up.
```cpp
Tensor setup_ws("setup_ws", std::vector<size_t>{setup_ws_bytes}, DType::kByte);
Tensor cublas_ws("cublas_ws", std::vector<size_t>{cublas_ws_bytes}, DType::kByte);

nvte_grouped_gemm_with_discrete_out(grouped_A.get_handle(),
```
Not a fan of this name, but it was added in another PR, so not a problem here.
transformer_engine/common/include/transformer_engine/transformer_engine.h
```python
        self._apply_delay_wgrad_param_hooks()

    def _apply_delay_wgrad_param_hooks(self) -> None:
        """Set ``skip_backward_post_hook`` on weights when delaying wgrad (bias uses main backward)."""
        if not self.wgrad_store.delay_wgrad_compute():
            return
        if self.single_grouped_weight:
            self.weight.skip_backward_post_hook = True
        else:
            for group_idx in range(self.num_groups):
                getattr(self, f"weight{group_idx}").skip_backward_post_hook = True
```
AttributeError when device="meta" + single_grouped_weight=True + delay_wgrad_compute=True
When device.type == "meta", line 168 skips reset_parameters(), so make_grouped_weights() is never called and self.weight is never registered. _apply_delay_wgrad_param_hooks() is then called unconditionally at line 173 and immediately accesses self.weight (line 180) → AttributeError at construction time.
This breaks any framework (e.g. Megatron-LM) that initialises models on device="meta" with delayed wgrad and single grouped weights. A simple guard fixes it:
Suggested change (leave `_apply_delay_wgrad_param_hooks` itself unchanged; only the call site in `__init__` gains the check):

```python
if device.type != "meta":
    self._apply_delay_wgrad_param_hooks()
```
pre_first_fuser_forward will call reset_parameters() (and re-invoke _apply_delay_wgrad_param_hooks) once real device params are materialised.
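A minimal stand-alone repro of the construction-order issue (toy stand-ins only: `SimpleNamespace` replaces the real grouped weight, and the class is not TE code):

```python
from types import SimpleNamespace


class FusedGroupedMLPSketch:
    """Toy stand-in for the grouped-MLP construction path (not TE code)."""

    def __init__(self, device_type: str, delay_wgrad: bool = True):
        if device_type != "meta":
            # reset_parameters() -> make_grouped_weights() is skipped on meta,
            # so self.weight is never registered there.
            self.weight = SimpleNamespace()
        self.delay_wgrad = delay_wgrad
        if device_type != "meta":  # the proposed guard
            self._apply_delay_wgrad_param_hooks()

    def _apply_delay_wgrad_param_hooks(self) -> None:
        if not self.delay_wgrad:
            return
        # Without the guard in __init__, this line raises AttributeError on
        # meta devices because self.weight does not exist yet.
        self.weight.skip_backward_post_hook = True


FusedGroupedMLPSketch("meta")  # constructs cleanly with the guard in place
m = FusedGroupedMLPSketch("cuda")
assert m.weight.skip_backward_post_hook is True
```

Removing the `device_type != "meta"` check around the call reproduces the `AttributeError` at construction time.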
```python
fc1_dy_tensor_offsets = fc1_ctx.base_split_offsets * fc1_weight_shape[0]
grouped_fc1_dy = GroupedTensor(
    shape=(out_shape[0], fc1_weight_shape[0]),
    dtype=dtype,
    num_tensors=num_groups,
    quantizer=fc1_ctx.grad_output_quantizer,
    data=fc1_dy_row_data,
    columnwise_data=fc1_dy_col_data,
    scale_inv=fc1_dy_row_scale,
    columnwise_scale_inv=fc1_dy_col_scale,
    first_dims=split_sizes,
    tensor_offsets=fc1_dy_tensor_offsets,
    with_gemm_swizzled_scales=True,
)
```
grouped_fc1_dy data tensors not flattened — inconsistency with forward may produce incorrect wgrad
data=fc1_dy_row_data and columnwise_data=fc1_dy_col_data are passed as 2-D tensors (after .view(out_shape[0], fc1_weight_shape[0])), while GroupedTensorStorage expects 1-D flattened buffers (documented at grouped_tensor_storage.py line 44: "ALL data fields are stored as 1D flattened arrays"). The forward's equivalent construction for grouped_fc2_x explicitly flattens every field:
```python
# forward_grouped_mlp.py - what the forward does
data=fc2_in_row_data.reshape(-1),
columnwise_data=fc2_in_col_data.reshape(-1),
scale_inv=fc2_in_row_scale.reshape(-1),
columnwise_scale_inv=fc2_in_col_scale.reshape(-1),
```

The backward omits `.reshape(-1)` for both data tensors and omits the required `permute(5, 2, 4, 0, 1, 3)` + `reshape(-1)` for the scale tensors. When `general_grouped_gemm_for_grouped_tensor` indexes into the per-group data via element-level offsets, it may read from the wrong memory locations, producing silently wrong FC1 weight gradients.
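To illustrate the layout contract (a NumPy stand-in for torch; the shapes and the 6-D scale blocking are illustrative, not the PR's real dimensions), the storage expects every field flattened to 1-D, with scales permuted into the swizzle order before flattening:

```python
import numpy as np

M, N = 256, 512  # illustrative token count and hidden size (multiples of 128)

# Data fields: 2-D buffers must be flattened before GroupedTensor construction.
row_data = np.zeros((M, N), dtype=np.uint8)
flat_row_data = row_data.reshape(-1)  # what the forward does; the backward omits it

# Scale fields: viewed as 6-D blocks, permuted into the GEMM swizzle order
# quoted in this review thread, then flattened.
scales = np.arange((M // 128) * (N // 128) * 32 * 4 * 4, dtype=np.uint8)
scales_6d = scales.reshape(1, M // 128, N // 128, 32, 4, 4)
swizzled = scales_6d.transpose(5, 2, 4, 0, 1, 3).reshape(-1)

assert flat_row_data.ndim == 1 and swizzled.ndim == 1
assert swizzled.size == scales.size  # the swizzle permutes, never drops, elements
```

The point is purely structural: a GEMM that indexes these buffers by flat element offsets only lands on the right bytes if the producer flattened (and, for scales, swizzled) them the same way.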
```python
accumulate_into_main_grad = not getattr(weight_param, "overwrite_main_grad", False)
if accumulate_into_main_grad:
    grouped_wgrad = GroupedTensor.make_grouped_tensor_from_rowwise_data(
        num_tensors=num_groups,
        tensor_shape=weight_shape,
        rowwise_data=main_grad,
        dtype=main_grad.dtype,
    )

if grouped_wgrad is None:
    grouped_wgrad = GroupedTensor.make_grouped_tensor_with_shapes(
        num_tensors=num_groups,
        shapes=[weight_shape] * num_groups,
        quantizer=None,
        device=device,
        dtype=dtype,
    )
```
overwrite_main_grad=True + single_grouped_weight=True silently drops wgrad into a scratch buffer
When weight_param.overwrite_main_grad is True, accumulate_into_main_grad is set to False (line 97). Because the if accumulate_into_main_grad: branch is skipped, grouped_wgrad remains None and the fallback at line 107 allocates a new scratch buffer entirely unrelated to main_grad. The GEMM writes the weight gradient into this temporary buffer, which is then discarded — main_grad is never updated.
Compare with the single_grouped_weight=False path (lines 116–128): w_list[idx] = wp.main_grad is set unconditionally before the accumulate_into_main_grad determination, so the GEMM always targets main_grad regardless of overwrite_main_grad.
The fix is to populate `grouped_wgrad` from `main_grad` before computing `accumulate_into_main_grad`:

```python
# single_grouped_weight path, before the accumulate flag computation:
grouped_wgrad = GroupedTensor.make_grouped_tensor_from_rowwise_data(
    num_tensors=num_groups,
    tensor_shape=weight_shape,
    rowwise_data=main_grad,
    dtype=main_grad.dtype,
)
accumulate_into_main_grad = not getattr(weight_param, "overwrite_main_grad", False)
```

Comments have been addressed and CI is green now.
Squashed commit history:

* GEMM + Swiglu fused Grouped MLP for MXFP8
* cleanup/lint
* Properly cache the alpha tensor
* nD dummy grad
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* 0 tokens in entire rank
* tmp downgrade cublas version check
* delayed wgrad tests pass for basic gl
* merge everything
* Rebase into fused_mxfp8_grouped_mlp; unit tests for delayed wgrad working
* Fix tests being skipped for fusible ops
* Integrate mxfp8 dbias kernel in group_quantize
* Add bias/dbias fused support with cute GEMMs
* Check bias/dbias support
* Pack biases more efficiently
* GroupedTensor for biases to avoid concat
* format
* Support 1D grouped tensor shape for bias and fix checkpointing
* Fixes and tests
* Refactor grouped tensor marking for paged stashing
* Remove setting logical_shape in mark_grouped_tensor
* Cleanup logical_shape
* pass the tests for now
* address some review comments
* address review comments
* more cleanups
* cleanup
* refactor wgrad logic
* Rename argument from single_grouped_parameter to single_grouped_weight
* Check wgrad store context is not empty for 0 token case
* Test only checks for fusion if fused kernel is available
* fix the tolerance to be of bf16 for the cute gemm
* Update transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py
* address further review comments
* address more review comments
* address more review comments + test for zero grouped tensor work case
* cublaslt remove zero work gemm avoidance
* fix the wgrad test
* split dbias functionality from gq api
* Format and lint
* port fixes and add better doc for page stashing war
* Guard fusion via env
* Change to trigger CI; Remove unnecessary blank line in docstring
* To retrigger CI
* Space to trigger the pipeline
* fix zero work cublas gemm

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Varun Thumbe <vthumbe@nvidia.com>
Co-authored-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Description
Please include a brief summary of the changes, relevant motivation and context.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: