Skip to content

fix(CP, FA): the conditional logic in the FA version contains a vulnerability when processing the output of Flash Attn forward pass#2825

Open
zhujian19891203 wants to merge 1 commit intoNVIDIA:mainfrom
021ai:CP_FA_fix
Open

fix(CP, FA): the conditional logic in the FA version contains a vulnerability when processing the output of Flash Attn forward pass#2825
zhujian19891203 wants to merge 1 commit intoNVIDIA:mainfrom
021ai:CP_FA_fix

Conversation

@zhujian19891203
Copy link
Copy Markdown
Contributor

Description

The conditional logic in the FA version contains a vulnerability when processing the output of Flash Attn forward pass, this maybe caused accidentally.

For example, when I only install FA3, and FA2 is totally not installed, something is wrong.
image

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Just context_parallel.py file

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

…he conditional logic in the FA version contains a vulnerability, fix it
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps bot commented Apr 2, 2026

Greptile Summary

This PR fixes a conditional logic bug in three symmetric locations within context_parallel.py where the Flash Attention forward-pass output indices were selected based solely on fa_utils.v2_7_0_plus, ignoring use_flash_attn_3. When only FA3 is installed (FA2 absent), the original code fell into the old-FA2 branch and incorrectly accessed fa_outputs[4]/fa_outputs[5]/fa_outputs[7], whereas FA3's output is indexed at [0]/[1]. The fix adds not use_flash_attn_3 to the guard so FA3 always routes through the correct else branch regardless of the FA2 version flags.

Confidence Score: 5/5

Safe to merge — the fix is minimal, correct, and consistently applied across all three affected call sites.

All three changes follow the same correct pattern: the use_flash_attn_3 guard is added to prevent FA3 from being misrouted into the old FA2 index scheme. The else branch already handled FA3 correctly, and the rng_state/rng_states assignments are consistent with the pre-existing None path for FA3. No new logic paths are introduced; only the dead-code/wrong-index branch for FA3 is eliminated. No P0/P1 findings remain.

No files require special attention.

Important Files Changed

Filename Overview
transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py Three symmetric fixes to the FA forward-pass output index selection: adds not use_flash_attn_3 guard so FA3-only installations (no FA2) correctly use output indices [0,1] instead of the FA2-old-format indices [4,5,7].

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[flash_attn_fwd returns fa_outputs] --> B{use_flash_attn_3?}
    B -- Yes --> C[else branch
out = fa_outputs 0
lse = fa_outputs 1
rng = None]
    B -- No --> D{fa_utils.v2_7_0_plus?}
    D -- No
Old FA2 format --> E[if branch
out = fa_outputs 4
lse = fa_outputs 5
rng = fa_outputs 7]
    D -- Yes
New FA2 format --> F[else branch
out = fa_outputs 0
lse = fa_outputs 1
rng = fa_outputs 3]
Loading

Reviews (1): Last reviewed commit: "fix(CP, FA): when processing the output ..." | Re-trigger Greptile

@zhujian19891203 zhujian19891203 changed the title Fix bug: the conditional logic in the FA version contains a vulnerability when processing the output of Flash Attn forward pass Fix(CP, FA): the conditional logic in the FA version contains a vulnerability when processing the output of Flash Attn forward pass Apr 3, 2026
@zhujian19891203 zhujian19891203 changed the title Fix(CP, FA): the conditional logic in the FA version contains a vulnerability when processing the output of Flash Attn forward pass fix(CP, FA): the conditional logic in the FA version contains a vulnerability when processing the output of Flash Attn forward pass Apr 3, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant