
Streamline group Hadamard ComputeKernel loads#2810

Open
cael-ling wants to merge 13 commits into NVIDIA:main from cael-ling:refactor/grp-hadamard-ldmatrix-transpose

Conversation

Contributor

@cael-ling cael-ling commented Mar 29, 2026

Description

Superseded: This work has been rolled into #2820; please review that PR instead.

Reorders ComputeKernel to Transposed → Pre-RHT → Identity (per enabled kReturn* flags).
Transposed path: loads fragments directly with a transposed ldmatrix_x4_m8n8_shared_b16 instead of a row-major load followed by four in-register transposes, feeding the same WMMA operand pattern. This reduces instruction count and warp-synchronous work on the hot path, improving performance.
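
For illustration, a minimal before/after sketch of the transposed path. It assumes the helper signatures visible elsewhere in this page (ldmatrix_x4_m8n8_shared_b16, matrix_transpose_m8_n8_b16_inplace, and the in_sh_ptr/swizzle_idx operands); the exact code in the tree may differ:

```cuda
// Before (sketch): row-major load, then four in-register 8x8 transposes,
// each of which is extra warp-synchronous work on the hot path.
ldmatrix_x4_m8n8_shared_b16<false>(a_frag[0], a_frag[1], a_frag[2], a_frag[3],
                                   reinterpret_cast<uint4*>(in_sh_ptr) + swizzle_idx);
matrix_transpose_m8_n8_b16_inplace(a_frag[0]);
matrix_transpose_m8_n8_b16_inplace(a_frag[1]);
matrix_transpose_m8_n8_b16_inplace(a_frag[2]);
matrix_transpose_m8_n8_b16_inplace(a_frag[3]);

// After (sketch): one transposed ldmatrix load; the transpose happens inside
// the load itself, so the four explicit transposes disappear. The WMMA then
// consumes the fragments in the swapped operand order matching the layout.
ldmatrix_x4_m8n8_shared_b16<true>(a_frag[0], a_frag[1], a_frag[2], a_frag[3],
                                  reinterpret_cast<uint4*>(in_sh_ptr) + swizzle_idx);
```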

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@cael-ling cael-ling closed this Mar 29, 2026
@cael-ling cael-ling reopened this Mar 29, 2026
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps bot commented Mar 29, 2026

Greptile Summary

This PR refactors the ComputeKernel in all three Hadamard transform files by (1) reordering execution from Transposed → Pre-RHT → Identity, (2) replacing the row-major load + four in-register matrix_transpose_m8_n8_b16_inplace calls with a single ldmatrix_x4_m8n8_shared_b16<true> transposed load, and (3) hoisting the constant swizzle_idx computation out of the inner function into the kernel's setup code.

  • Stale-frag fix (previously flagged issue is resolved): The identity amax block (kReturnIdentityAmax) now unconditionally performs a fresh ldmatrix_x4_m8n8_shared_b16<false> load before the MMA. This guarantees correct register state regardless of which earlier paths ran.
  • Pre-RHT reuse of transposed fragments is correct: When kReturnTransposedAmax=true, the pre-RHT block skips the reload and operates on a_frag already loaded by the transposed path. Since the pre-RHT reduction only computes max-abs across all fragment registers (max.xorsign.abs.bf16x2 tree over a_frag[0..3]), and all 32 warp lanes together cover the full 16×16 matrix regardless of whether the layout is transposed, the final warp-reduced pre-RHT amax is mathematically invariant under transposition.
  • swizzle_idx hoisting is correct: The index depends only on threadIdx.x and the compile-time constant kHadamardDimension, both of which are fixed for the lifetime of the kernel. Different stage_y/stage_x iterations pass different in_sh_ptr base pointers, but the relative swizzle_idx within each chunk is the same, so hoisting it once is safe.
  • The three files (hadamard_transform.cu, group_hadamard_transform.cu, graph_safe_group_hadamard_transform.cu) receive identical changes, keeping them consistent.
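
The transpose-invariance claim in the second bullet follows because transposition only permutes matrix entries without changing their values:

```latex
\max_{i,j} \bigl| A_{ij} \bigr|
  \;=\; \max_{i,j} \bigl| (A^{\mathsf{T}})_{ij} \bigr|
  \;=\; \max_{j,i} \bigl| A_{ji} \bigr|,
```

so a max-abs reduction over transposed fragments yields the same pre-RHT amax as one over row-major fragments.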

Confidence Score: 5/5

This PR is safe to merge; the previously flagged stale-register bug is resolved and all execution paths are mathematically correct.

All three template-path combinations have been verified: the transposed path uses a correct direct load; the pre-RHT path correctly reuses transposed fragments for a max-abs reduction that is transpose-invariant; and the identity path always reloads fresh data, directly addressing the previously raised concern. No new logic defects, data-integrity issues, or correctness risks are introduced.

No files require special attention

Important Files Changed

Filename — Overview
transformer_engine/common/hadamard_transform/group_hadamard_transform.cu — ComputeKernel reordered to Transposed → Pre-RHT → Identity; swizzle_idx hoisted outside; transposed path now uses direct ldmatrix_x4 transposed load; identity always does a fresh reload (fixing previously flagged stale-frag bug)
transformer_engine/common/hadamard_transform/hadamard_transform.cu — Identical structural refactoring as group_hadamard_transform.cu: same reordering, swizzle_idx hoisting, direct transposed load, and unconditional identity reload
transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu — Same refactoring applied to the graph-safe variant; changes are symmetric with the other two files

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    Entry["ComputeKernel(b_frag_i, b_frag_t, in_sh_ptr, swizzle_idx, ...)"]

    Entry --> T{kReturnTransposedAmax?}

    T -- yes --> TL["ldmatrix_x4 transposed load → a_frag[0..3]"]
    TL --> TMMA["MMA (transposed layout)\na_frag[0,2,1,3] × b_frag_t → c_frag\nupdate local_amax_t_reg"]
    TMMA --> P

    T -- no --> P{kReturnPreRhtAmax?}

    P -- yes, Transposed ran --> PMax["max-abs reduction over a_frag[0..3]\n(transposed frags; result is transpose-invariant)\nupdate local_pre_rht_amax_reg"]
    P -- yes, Transposed did NOT run --> PR["ldmatrix_x4 row-major load → a_frag[0..3]"]
    PR --> PMax
    PMax --> I

    P -- no --> I{kReturnIdentityAmax?}

    I -- yes --> IL["ldmatrix_x4 row-major load → a_frag[0..3]\n(unconditional fresh reload)"]
    IL --> IMMA["MMA (identity layout)\na_frag[0..3] × b_frag_i → c_frag\nupdate local_amax_reg"]
    IMMA --> Done["return"]

    I -- no --> Done

Reviews (5): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."

Comment on lines +93 to +96
if (kReturnTransposedAmax || (!kReturnTransposedAmax && !kReturnPreRhtAmax)) {
  ldmatrix_x4_m8n8_shared_b16<false>(a_frag[0], a_frag[1], a_frag[2], a_frag[3],
                                     reinterpret_cast<uint4*>(in_sh_ptr) + swizzle_idx);
}
Contributor


P1 Stale a_frag used for identity MMA when only Pre-RHT + Identity are enabled

The condition on line 93 simplifies via Boolean algebra to kReturnTransposedAmax || !kReturnPreRhtAmax, which means the reload is skipped when kReturnTransposedAmax=false and kReturnPreRhtAmax=true.

In that code path the pre-RHT block just ran (lines 72–90) and left a_frag[0] and a_frag[2] overwritten with intermediate max-reduction results:

a_frag[0] ← max(|a_frag[0]|, |a_frag[1]|)   // line 79
a_frag[2] ← max(|a_frag[2]|, |a_frag[3]|)   // line 82
a_frag[0] ← max(|a_frag[0]|, |a_frag[2]|)   // line 85

When the identity MMA then runs without a fresh load, it consumes these scalar amax values instead of the original matrix fragment data, producing an incorrect identity amax.

Because a fresh row-major load is required in every caller configuration (transposed data from the kReturnTransposedAmax branch is unusable, and kReturnPreRhtAmax corrupts registers regardless), the guard should be dropped and the load made unconditional:

Suggested change
if (kReturnTransposedAmax || (!kReturnTransposedAmax && !kReturnPreRhtAmax)) {
  ldmatrix_x4_m8n8_shared_b16<false>(a_frag[0], a_frag[1], a_frag[2], a_frag[3],
                                     reinterpret_cast<uint4*>(in_sh_ptr) + swizzle_idx);
}
ldmatrix_x4_m8n8_shared_b16<false>(a_frag[0], a_frag[1], a_frag[2], a_frag[3],
                                   reinterpret_cast<uint4*>(in_sh_ptr) + swizzle_idx);

Collaborator


valid concern, can you take a look? @cael-ling

Contributor Author


Updated kReturnIdentityAmax path: if it is true, perform one extra reload of the values to guarantee correct behavior. @zhongbozhu

cael-ling and others added 4 commits March 29, 2026 02:01
Signed-off-by: Cael Ling <caell@nvidia.com>
Signed-off-by: Cael Ling <caell@nvidia.com>
…ranspose' into refactor/grp-hadamard-ldmatrix-transpose

Made-with: Cursor
@cael-ling cael-ling force-pushed the refactor/grp-hadamard-ldmatrix-transpose branch from cc3e5f5 to d90ac01 Compare March 29, 2026 09:55
Signed-off-by: Cael Ling <caell@nvidia.com>
}

if (kReturnIdentityAmax) {
  if (kReturnTransposedAmax || (!kReturnTransposedAmax && !kReturnPreRhtAmax)) {
Collaborator


the double if looks confusing

Collaborator


let's not mix the three cases; we should be able to use constexpr to remove the runtime overhead of the if, so having some duplicated code makes it more readable.
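
A sketch of the shape this suggestion implies, with hypothetical structure (the real kernel's branch bodies differ, and `src` stands in for the swizzled shared-memory pointer `reinterpret_cast<uint4*>(in_sh_ptr) + swizzle_idx`). Since the kReturn* flags are compile-time template parameters, `if constexpr` compiles out untaken branches, so keeping the three cases separate costs nothing at runtime:

```cuda
// Hypothetical sketch: three independent cases, each owning its load.
if constexpr (kReturnTransposedAmax) {
  ldmatrix_x4_m8n8_shared_b16<true>(a_frag[0], a_frag[1], a_frag[2], a_frag[3], src);
  // ... transposed MMA, update local_amax_t_reg ...
}
if constexpr (kReturnPreRhtAmax) {
  if constexpr (!kReturnTransposedAmax) {  // otherwise reuse a_frag from above
    ldmatrix_x4_m8n8_shared_b16<false>(a_frag[0], a_frag[1], a_frag[2], a_frag[3], src);
  }
  // ... max-abs reduction, update local_pre_rht_amax_reg ...
}
if constexpr (kReturnIdentityAmax) {
  // Unconditional fresh reload: earlier paths leave a_frag transposed or
  // overwritten with reduction intermediates.
  ldmatrix_x4_m8n8_shared_b16<false>(a_frag[0], a_frag[1], a_frag[2], a_frag[3], src);
  // ... identity MMA, update local_amax_reg ...
}
```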

cael-ling and others added 8 commits March 30, 2026 16:26
Signed-off-by: Cael Ling <caell@nvidia.com>
Signed-off-by: Cael Ling <caell@nvidia.com>
Signed-off-by: Cael Ling <caell@nvidia.com>
…com:cael-ling/TransformerEngine into refactor/grp-hadamard-ldmatrix-transpose
…:cael-ling/TransformerEngine into refactor/grp-hadamard-ldmatrix-transpose
@cael-ling
Contributor Author

The change has been applied to all three variants: group_hadamard_transform.cu, hadamard_transform.cu, and graph_safe_group_hadamard_transform.cu.

