Add more flex attention cases to benchmark. #3928

Conversation
@liangan1 Please help comment on the expected configuration for the flex attention benchmark.
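A hypothetical sketch of the kind of configuration being asked for, reusing the dimension names from the code snippet later in this thread; the concrete values below are illustrative assumptions, not the benchmark's actual shapes:

```python
# Hypothetical shape tuples for the flex attention benchmark:
# (Z, H_q, H_kv, N_CTX_q, N_CTX_kv, D_HEAD_qk, D_HEAD_v).
# The values are assumptions for illustration only.
shapes = [
    (1, 32, 8, 1024, 1024, 128, 128),  # GQA prefill-like shape
    (1, 32, 8, 1, 1024, 128, 128),     # decode-like shape (q_len == 1)
]
```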
Note the benchmark failure: "TypeError: benchmark() got an unexpected keyword argument 'B'".
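One plausible way this error arises, shown as a minimal sketch (an assumption, not the PR's actual code): Triton's perf_report decorator forwards each entry of x_names to the benchmark function as a keyword argument, so the names must match the function's signature.

```python
import triton.testing

configs = [
    triton.testing.Benchmark(
        x_names=['B', 'H', 'N_CTX'],  # 'B' declared here...
        x_vals=[(2, 16, 1024)],
        line_arg='provider',
        line_vals=['triton'],
        line_names=['Triton'],
        plot_name='flex-attn-demo',   # hypothetical name
        args={},
    )
]

@triton.testing.perf_report(configs)
def benchmark(Z, H, N_CTX, provider):  # ...but the function expects 'Z',
    return 1.0                         # so running it raises:
                                       # TypeError: benchmark() got an
                                       # unexpected keyword argument 'B'
```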
Force-pushed from cc449b5 to 7dca5bc.
Force-pushed from 76d3c3b to 01b9091.
# Count only the attention inputs (Q, K, V) for the memory-traffic estimate.
q_elems = H_q * N_CTX_q * D_HEAD_qk
k_elems = H_kv * N_CTX_kv * D_HEAD_qk
v_elems = H_kv * N_CTX_kv * D_HEAD_v
# float16: 2 bytes per element; mean is in ms, so the result is in GB/s.
gbps = lambda mean: Z * (q_elems + k_elems + v_elems) * 2 * (1e-9) / (mean * 1e-3)
Only the GEMM computation and the inputs are considered when calculating the TFLOPS and GB/s.
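For reference, a minimal sketch (an assumption, not the PR's actual code) of a matching TFLOPS estimate that counts only the two GEMMs in attention, Q@K^T and P@V, reusing the dimension names from the snippet above:

```python
# Each GEMM of shape (M, K) x (K, N) costs 2*M*N*K FLOPs.
qk_flops = 2 * N_CTX_q * N_CTX_kv * D_HEAD_qk  # Q @ K^T
pv_flops = 2 * N_CTX_q * N_CTX_kv * D_HEAD_v   # P @ V
# mean is in ms; scale FLOPs to TFLOPS.
tflops = lambda mean: Z * H_q * (qk_flops + pv_flops) * 1e-12 / (mean * 1e-3)
```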
Force-pushed from b5216d6 to 5abd839.
There is an accuracy issue caused by a regression in flex decoding. The flex decoding regression needs to be solved first.
@chengjunlu the test failed due to a diff in the results. Did you forget to push a local change?
Force-pushed from bf4c520 to 9d64d49.
As with the other problematic shapes, how about we comment out the decode shape in this PR and fix it in another PR? (93882ff)
Force-pushed from 9d64d49 to 93882ff.
Sounds good to me. Let's add the benchmark first and use another issue to track the decoding regression.
Force-pushed from 55afaa4 to f582a67.
The accuracy issue has been fixed by #3999. I will rebase this PR after #3999 is merged.
Force-pushed from f582a67 to 04d6f4f.
Signed-off-by: Lu,Chengjun <[email protected]>
Add the flex attention shapes used by real models to the benchmark for tracking performance.
I commented out 4 cases for now. We will investigate the first issue on the Triton side later.