
Add more flex attention cases to benchmark. #3928


Merged
merged 1 commit into main from chengjun/add_more_flex_attention_varient on Apr 28, 2025

Conversation

@chengjunlu (Contributor) commented Apr 15, 2025

Add the flex attention shapes used by real models to the benchmark for tracking performance.

I commented out 4 cases for now, for the following reasons:

  1. There is not enough shared local memory for the Triton kernel.
  • Append shapes of Deepseek-v3 (Nope)
  • Decode shapes of Deepseek-v3 (Nope)
  2. Flex Attention doesn't support this kind of shape: Error: LoweringException: ValueError: Number of shared query heads sharing the same KV head must be power of 2.
  • Decode shapes of Qwen2-7B
  3. Triton kernel block shapes must be a power of 2.
  • Decode shapes of Phi3-mini-3.8B

We will investigate the first issue on the Triton side later.
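
As a rough sketch of how such benchmark cases can be laid out (the parameter names match the gbps helper discussed later in this thread; the concrete values below are placeholders, not the exact shapes added by this PR):

```python
# Hypothetical layout: each case is
# (Z, H_q, H_kv, N_CTX_q, N_CTX_kv, D_HEAD_qk, D_HEAD_v),
# i.e. batch size, query heads, KV heads, query/KV sequence lengths, and head dims.
flex_attention_cases = [
    (4, 32, 8, 1024, 1024, 128, 128),  # placeholder GQA prefill-style shape
    (4, 32, 4, 1, 4096, 128, 128),     # placeholder GQA decode-style shape
    # Commented out until the issues listed above are resolved, e.g.:
    # (1, 128, 128, 4096, 4096, 192, 128),  # large shapes that run out of shared local memory
    # (4, 28, 4, 1, 4096, 128, 128),        # 28 / 4 = 7 query heads per KV head, not a power of 2
    # (4, 32, 8, 1, 4096, 96, 96),          # head dim 96 forces non-power-of-2 kernel blocks
]
```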

@chengjunlu (Contributor, Author)

@liangan1 Please help comment on the expected configurations for the flex attention benchmark.
I will fill in the configurations with reasonable cases from real use cases.

@mfrancepillois linked an issue on Apr 15, 2025 that may be closed by this pull request
@etiotto (Contributor) commented Apr 17, 2025

Note the benchmark failure "TypeError: benchmark() got an unexpected keyword argument 'B'"
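
For context, this kind of TypeError usually means the x_names declared in the triton.testing.Benchmark config do not match the parameter names of the decorated benchmark function. A minimal sketch (names and values are illustrative, not this PR's actual config):

```python
import triton.testing

configs = [
    triton.testing.Benchmark(
        # These names must match the parameters of benchmark() below; declaring "B"
        # here while the function takes Z would raise
        # "TypeError: benchmark() got an unexpected keyword argument 'B'".
        x_names=["Z", "H", "N_CTX", "D_HEAD"],
        x_vals=[(4, 32, 1024, 64)],
        line_arg="provider",
        line_vals=["triton"],
        line_names=["Triton"],
        ylabel="GB/s",
        plot_name="flex-attention-sketch",
        args={},
    )
]

@triton.testing.perf_report(configs)
def benchmark(Z, H, N_CTX, D_HEAD, provider):
    # Placeholder body; a real benchmark would time the flex attention kernel here
    # and return the measured runtime in milliseconds.
    return 1.0
```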

@vlad-penkin self-requested a review on April 18, 2025 at 09:55
@chengjunlu force-pushed the chengjun/add_more_flex_attention_varient branch from cc449b5 to 7dca5bc on April 23, 2025 at 01:47
@chengjunlu changed the title from "Draft add more flex attention cases to benchmark." to "Add more flex attention cases to benchmark." on Apr 23, 2025
@chengjunlu force-pushed the chengjun/add_more_flex_attention_varient branch 3 times, most recently from 76d3c3b to 01b9091 on April 23, 2025 at 02:19
# Element counts of the Q, K, and V tensors for one batch entry.
q_elems = H_q * N_CTX_q * D_HEAD_qk
k_elems = H_kv * N_CTX_kv * D_HEAD_qk
v_elems = H_kv * N_CTX_kv * D_HEAD_v
# GB/s: bytes read for Q/K/V across the batch (float16 -> 2 bytes per element),
# divided by the mean runtime converted from ms to s.
gbps = lambda mean: Z * (q_elems + k_elems + v_elems) * 2 * (1e-9) / (mean * 1e-3)
@chengjunlu (Contributor, Author) commented on the snippet above:

Only the GEMM computations and their inputs are considered when calculating the TFLOPS and GB/s.
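
For reference, a matching TFLOPS estimate that counts only the two attention GEMMs (Q·Kᵀ and P·V) and reuses the same shape variables as the gbps helper above might look like this sketch (an illustration of that convention, not necessarily the exact expression in the benchmark; masking/sparsity is ignored):

```python
# Hypothetical sketch: FLOPs of the two GEMMs only.
# Q @ K^T contributes 2 * N_CTX_q * N_CTX_kv * D_HEAD_qk FLOPs per (batch, query head);
# P @ V   contributes 2 * N_CTX_q * N_CTX_kv * D_HEAD_v.
attn_flops = 2 * Z * H_q * N_CTX_q * N_CTX_kv * (D_HEAD_qk + D_HEAD_v)
tflops = lambda mean: attn_flops * 1e-12 / (mean * 1e-3)  # mean runtime is in ms
```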

@chengjunlu force-pushed the chengjun/add_more_flex_attention_varient branch 2 times, most recently from b5216d6 to 5abd839 on April 23, 2025 at 03:37
@chengjunlu (Contributor, Author) commented Apr 23, 2025

There is an accuracy issue caused by a regression in flex decoding. We need to solve the flex decoding regression first.

AssertionError: 
Not equal to tolerance rtol=0.001, atol=0.1

Mismatched elements: 344 / 4096 (8.4%)
Max absolute difference: 1.766
Max relative difference: 0.654
 x: array([[[[ 0.4414  , -0.003366, -0.4326  , ...,  0.0951 ,  0.3289  ,
          -0.614   ]],
...
 y: array([[[[ 1.273  , -0.00971, -1.248  , ...,  0.2744 ,  0.9487 ,
          -1.7705 ]],

@etiotto (Contributor) commented Apr 23, 2025

@chengjunlu the test failed due to a diff in the results. Did you forget to push a local change?

@whitneywhtsang force-pushed the chengjun/add_more_flex_attention_varient branch from bf4c520 to 9d64d49 on April 23, 2025 at 22:23
@whitneywhtsang (Contributor) commented Apr 23, 2025

> There is an accuracy issue caused by a regression in flex decoding. We need to solve the flex decoding regression first.
>
> AssertionError: 
> Not equal to tolerance rtol=0.001, atol=0.1
>
> Mismatched elements: 344 / 4096 (8.4%)
> Max absolute difference: 1.766
> Max relative difference: 0.654
>  x: array([[[[ 0.4414  , -0.003366, -0.4326  , ...,  0.0951 ,  0.3289  ,
>           -0.614   ]],
> ...
>  y: array([[[[ 1.273  , -0.00971, -1.248  , ...,  0.2744 ,  0.9487 ,
>           -1.7705 ]],

Similar to other problematic shapes, how about we comment out the decode shape in this PR, and fix it in another PR? (93882ff)

@whitneywhtsang force-pushed the chengjun/add_more_flex_attention_varient branch from 9d64d49 to 93882ff on April 24, 2025 at 00:40
@chengjunlu (Contributor, Author)

> There is an accuracy issue caused by a regression in flex decoding. We need to solve the flex decoding regression first.
>
> AssertionError: 
> Not equal to tolerance rtol=0.001, atol=0.1
>
> Mismatched elements: 344 / 4096 (8.4%)
> Max absolute difference: 1.766
> Max relative difference: 0.654
>  x: array([[[[ 0.4414  , -0.003366, -0.4326  , ...,  0.0951 ,  0.3289  ,
>           -0.614   ]],
> ...
>  y: array([[[[ 1.273  , -0.00971, -1.248  , ...,  0.2744 ,  0.9487 ,
>           -1.7705 ]],
>
> Similar to other problematic shapes, how about we comment out the decode shape in this PR, and fix it in another PR? (93882ff)

Sounds good to me. Let's add the benchmark first and use a separate issue to track the decoding regression.

@chengjunlu force-pushed the chengjun/add_more_flex_attention_varient branch 2 times, most recently from 55afaa4 to f582a67 on April 24, 2025 at 05:41
@chengjunlu (Contributor, Author)

> There is an accuracy issue caused by a regression in flex decoding. We need to solve the flex decoding regression first.
>
> AssertionError: 
> Not equal to tolerance rtol=0.001, atol=0.1
>
> Mismatched elements: 344 / 4096 (8.4%)
> Max absolute difference: 1.766
> Max relative difference: 0.654
>  x: array([[[[ 0.4414  , -0.003366, -0.4326  , ...,  0.0951 ,  0.3289  ,
>           -0.614   ]],
> ...
>  y: array([[[[ 1.273  , -0.00971, -1.248  , ...,  0.2744 ,  0.9487 ,
>           -1.7705 ]],
>
> Similar to other problematic shapes, how about we comment out the decode shape in this PR, and fix it in another PR? (93882ff)

The accuracy issue has been fixed by #3999.

I will rebase this PR after #3999 is merged.

@chengjunlu force-pushed the chengjun/add_more_flex_attention_varient branch from f582a67 to 04d6f4f on April 25, 2025 at 07:33
@chengjunlu merged commit 94b3473 into main on Apr 28, 2025
10 checks passed
@chengjunlu deleted the chengjun/add_more_flex_attention_varient branch on April 28, 2025 at 02:03

Successfully merging this pull request may close these issues.

[FlexAttention] Enhance benchmark with GQA
3 participants