[JAX] Support segment_ids/pos as FA inputs #1406

Open
wants to merge 21 commits into
base: main

zlsh80826
Collaborator

@zlsh80826 zlsh80826 commented Jan 13, 2025

Description

This PR adds limited support for segment_ids/pos and deprecates the fused_attn_thd API.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactor

Changes

  • Add a new SequenceDescriptor class that covers the different ways of describing sequences:
    • from_seqlens for non-THD
    • from_seqlens_and_offsets for THD
    • from_segment_ids_and_pos for THD + ring attn (not yet implemented)
  • Change the old fused_attn mask parameter to a SequenceDescriptor. Passing a mask in that positional argument will keep working for a while but will generate a deprecation warning.
  • Deprecate the fused_attn_thd API, since the refactored fused_attn also supports the THD format.
  • Remove the small inputs from test_fused_attn.py, since the long-sequence inputs already cover them.
  • Add tests for the different sequence input formats to test_fused_attn.py.
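
For readers unfamiliar with this kind of transition, the mask-to-descriptor fallback described above could look roughly like the sketch below. This is not the code in this PR: `ToySequenceDescriptor` and `normalize_sequence_descriptor` are hypothetical names, and the mask convention (True marks a padded position) is an assumption made only for the illustration.

```python
import warnings
from dataclasses import dataclass


@dataclass(frozen=True)
class ToySequenceDescriptor:
    """Hypothetical stand-in for SequenceDescriptor (not the real class)."""

    seqlens: tuple

    @classmethod
    def from_seqlens(cls, seqlens):
        return cls(tuple(int(n) for n in seqlens))


def normalize_sequence_descriptor(arg):
    """Accept a descriptor, or a legacy padding mask with a deprecation warning.

    Assumption: the legacy mask is a list of rows in which True marks a
    padded position, so the sequence length is the count of False entries.
    """
    if isinstance(arg, ToySequenceDescriptor):
        return arg
    warnings.warn(
        "Passing a mask is deprecated; pass a sequence descriptor instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    seqlens = [sum(1 for padded in row if not padded) for row in arg]
    return ToySequenceDescriptor.from_seqlens(seqlens)


# New-style call: no warning is raised.
new_style = normalize_sequence_descriptor(ToySequenceDescriptor.from_seqlens([2, 1]))

# Legacy call: still works, but emits a DeprecationWarning.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    legacy = normalize_sequence_descriptor([[False, False, True], [False, True, True]])
```

Keeping the legacy value in the same positional slot means existing callers keep running unchanged while the warning points them at the new argument type.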

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@zlsh80826
Collaborator Author

/te-ci jax L1

Signed-off-by: Reese Wang <[email protected]>
@zlsh80826 zlsh80826 force-pushed the rewang/test-segment-ids branch from 08a7582 to e62c049 Compare January 14, 2025 09:57
@zlsh80826
Collaborator Author

/te-ci jax L1

@@ -709,10 +727,7 @@ def check_dqkv(primitive, reference, pad):
@pytest.mark.parametrize(
"b, s_q, s_kv, h_q, h_kv, d, dtype",
[
pytest.param(4, 128, 128, 16, 16, 64, jnp.bfloat16, id="4-128-128-16-16-64-BF16-SELF"),
Collaborator

Are these removals meant to cut down on the number of [redundant] test cases?

Collaborator Author

Yes, I think those tests are unnecessary because we no longer use max_512 kernels for seqlen < 512. The kernels for seqlen=128 are the same as seqlen=2048, so the 2048 test cases already cover the 128 cases.

@pytest.mark.parametrize(
"seq_desc_format",
[
pytest.param(SeqDescFormat.Mask, id="Mask"),
Collaborator

Would it help to cut down on test cases by creating a standalone unit test for the SequenceDesc to cover and check all of the cases? Then in this unit test we can use either Seqlens or SegmentIDs depending on THD or BSHD?

Collaborator Author

Good idea

else:

def generate_default_pos(segment_ids):
seqlen = segment_ids.shape[-1]
Collaborator

I couldn't see whether max_segments_per_seq is applied anywhere when generating the seqlens and offsets passed to cuDNN. I found it very useful for limiting the overhead of the JAX code that sets up seqlen/offset, rather than assuming max_seq_len (see _get_seqlens_and_offsets). We should run a quick benchmark to see how much overhead, if any, is incurred.

Collaborator Author

I added fea6b0e and cfe00a8 to reduce the seqlen and offset shapes to (batch, max_segments_per_seq).
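
To make the shape reduction concrete, here is a rough sketch of deriving per-segment lengths and offsets from segment_ids and storing them in arrays of shape (batch, max_segments_per_seq). It is not the code in fea6b0e or cfe00a8. The helper name is hypothetical, it uses NumPy with plain loops for readability rather than traced JAX ops, and it assumes that segment id 0 marks padding, that each segment is a contiguous run of one nonzero id, and that unused slots are filled with -1.

```python
import numpy as np


def seqlens_and_offsets_from_segment_ids(segment_ids, max_segments_per_seq):
    """Hypothetical reference helper, written for clarity rather than speed.

    Assumptions (not taken from the PR): id 0 is padding, each segment is a
    contiguous run of a single nonzero id, and unused slots hold -1.
    """
    segment_ids = np.asarray(segment_ids)
    batch, seqlen = segment_ids.shape
    seqlens = np.full((batch, max_segments_per_seq), -1, dtype=np.int32)
    offsets = np.full((batch, max_segments_per_seq), -1, dtype=np.int32)
    for b in range(batch):
        row = segment_ids[b]
        # A segment starts where a nonzero id differs from the id before it.
        prev = np.concatenate(([0], row[:-1]))
        starts = np.flatnonzero((row != 0) & (row != prev))
        if starts.size > max_segments_per_seq:
            raise ValueError("row has more segments than max_segments_per_seq")
        for i, start in enumerate(starts):
            # Walk forward while the id stays the same to find the segment end.
            end = start
            while end < seqlen and row[end] == row[start]:
                end += 1
            seqlens[b, i] = end - start
            offsets[b, i] = start
    return seqlens, offsets


example_ids = [
    [1, 1, 2, 2, 2, 0, 0, 0],
    [1, 1, 1, 1, 2, 2, 0, 0],
]
seqlens, offsets = seqlens_and_offsets_from_segment_ids(example_ids, max_segments_per_seq=3)
```

With this layout the output size scales with max_segments_per_seq instead of the padded sequence length, which is the overhead concern raised in the comment above.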

Signed-off-by: Reese Wang <[email protected]>
2 participants