-
Notifications
You must be signed in to change notification settings - Fork 347
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[JAX] Support segment_ids/pos as FA inputs #1406
base: main
Are you sure you want to change the base?
Conversation
/te-ci jax L1 |
Signed-off-by: Reese Wang <[email protected]>
Signed-off-by: Reese Wang <[email protected]>
Signed-off-by: Reese Wang <[email protected]>
Signed-off-by: Reese Wang <[email protected]>
Signed-off-by: Reese Wang <[email protected]>
Signed-off-by: Reese Wang <[email protected]>
Signed-off-by: Reese Wang <[email protected]>
Signed-off-by: Reese Wang <[email protected]>
Signed-off-by: Reese Wang <[email protected]>
Signed-off-by: Reese Wang <[email protected]>
Signed-off-by: Reese Wang <[email protected]>
Signed-off-by: Reese Wang <[email protected]>
Signed-off-by: Reese Wang <[email protected]>
Signed-off-by: Reese Wang <[email protected]>
Signed-off-by: Reese Wang <[email protected]>
Signed-off-by: Reese Wang <[email protected]>
Signed-off-by: Reese Wang <[email protected]>
08a7582
to
e62c049
Compare
/te-ci jax L1 |
@@ -709,10 +727,7 @@ def check_dqkv(primitive, reference, pad): | |||
@pytest.mark.parametrize( | |||
"b, s_q, s_kv, h_q, h_kv, d, dtype", | |||
[ | |||
pytest.param(4, 128, 128, 16, 16, 64, jnp.bfloat16, id="4-128-128-16-16-64-BF16-SELF"), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are these removals to cut down on number of [redundant] test cases?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I think those tests are unnecessary because we no longer use max_512 kernels for seqlen < 512. The kernels for seqlen=128 are the same as seqlen=2048, so the 2048 test cases already cover the 128 cases.
@pytest.mark.parametrize( | ||
"seq_desc_format", | ||
[ | ||
pytest.param(SeqDescFormat.Mask, id="Mask"), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would it help to cut down on test cases by creating a standalone unit test for the SequenceDesc
to cover and check all of the cases? Then in this unit test we can use either Seqlens
or SegmentIDs
depending on THD or BSHD?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good idea
else: | ||
|
||
def generate_default_pos(segment_ids): | ||
seqlen = segment_ids.shape[-1] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I couldn't see if we applied max_segments_per_seq
anywhere when generating the seqlen
and offset
to cudnn. I found that this was a very useful to help limit the overhead of the jax code that sets up seqlen/offset rather than assuming max_seq_len
(see _get_seqlens_and_offsets
). We should do a quick benchmark to see how much overhead if any is incurred.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Signed-off-by: Reese Wang <[email protected]>
Signed-off-by: Reese Wang <[email protected]>
Signed-off-by: Reese Wang <[email protected]>
Signed-off-by: Reese Wang <[email protected]>
Description
This PR adds
segment_ids/pos
limited support and deprecatedfused_attn_thd
API.Type of change
Changes
SequenceDescriptor
class for different sequence descriptions scenario.from_seqlens
for non-THDfrom_seqlens_and_offsets
for THDfrom_segment_ids_and_pos
for THD + ring attn (haven't implemented)fused_attn
mask
parameter toSequenceDescriptor
. Passingmask
in the position argument will work for a while but generating deprecation warning.fused_attn_thd
API as the refactoredfused_attn
can also support THD format.test_fused_attn.py
as the long sequence inputs should cover.test_fused_attn.py
Checklist: