Commit 5abd839

Add more flex attention cases to benchmark.
1 parent fa4cfa0 commit 5abd839

File tree

1 file changed (+96, -79 lines)


benchmarks/triton_kernels_benchmark/flex_attention_benchmark_causal_mask.py

Lines changed: 96 additions & 79 deletions
@@ -3,18 +3,19 @@
 import os
 from torch.nn.attention.flex_attention import (
     create_block_mask,
+    create_mask,
     flex_attention,
 )
 
 import torch
 import torch.nn.functional as F
+
 import triton_kernels_benchmark as benchmark_suit
-from triton_kernels_benchmark import xetla_kernel
 
 torch._dynamo.config.recompile_limit = 100  # pylint: disable=protected-access
 
 # Compile the flex_attention function
-flex_attention = torch.compile(flex_attention, dynamic=False)
+compiled_flex_attention = torch.compile(flex_attention, dynamic=False)
 
 
 @lru_cache
@@ -27,112 +28,128 @@ def causal_mask(_, __, q_idx, kv_idx):
     return q_idx >= kv_idx
 
 
+throughput_test = os.getenv('THROUGHPUT_TEST', '0') == '1'
+batch_sizes = [16, 32, 64] if throughput_test else [1]
+
+
 # Kernel profiling for Backward mode is not working as expected:
 # For details: https://github.com/pytorch/pytorch/issues/144778
 @benchmark_suit.perf_report(
     benchmark_suit.Benchmark(
-        x_names=['Z', 'H', 'N_CTX', 'D_HEAD', 'CAUSAL', 'MODE'],
-        x_vals=[[z, h, 16384 // z, dhead, causal, mode]
-                for z in [1, 2, 4, 8, 16, 32]
-                for (h, dhead) in [(16, 128), (32, 64)]
-                for causal in [True]
-                for mode in [os.getenv('FA_KERNEL_MODE', 'fwd')]]  #
-        + [[4, 48, 1024, 64, True, mode] for mode in [os.getenv('FA_KERNEL_MODE', 'fwd')]]  #
-        + [[z, h, 1024, dhead, True, mode]
-           for z in [1, 2, 4, 8, 16, 32, 64]
-           for (h, dhead) in [(8, 128), (32, 96), (4, 128)]
-           for mode in [os.getenv('FA_KERNEL_MODE', 'fwd')]],
+        x_names=['Z', 'H_q', 'H_kv', 'N_CTX_q', 'N_CTX_kv', 'D_HEAD_qk', 'D_HEAD_v', 'MODE'],
+        x_vals=
+        # Multi-head attention. H_q equals H_kv
+        # Prefill shapes of Phi3-mini-3.8B
+        [[z, 32, 32, 1024, 1024, 96, 96, 'fwd'] for z in batch_sizes] +
+        # Prefill shapes of Deepseek-v3
+        [[z, 128, 128, 1024, 1024, 192, 128, 'fwd'] for z in batch_sizes] +
+        # Append shapes of Phi3-mini-3.8B
+        [[z, 32, 32, 512, 1024 + 128 + 512, 96, 96, 'fwd'] for z in batch_sizes] +
+
+        # Multi-query attention. H_kv equals 1.
+        # Append shapes of Deepseek-v3 (Nope)
+        [
+            # RuntimeError: No valid triton configs. OutOfResources: out of resource: shared memory, Required: 133120, Hardware limit: 131072.
+            # [z, 128, 1, 512, 1024 + 128 + 512, 64, 512, 'fwd'] for z in batch_sizes
+        ] +
+        # Append shapes of Deepseek-v3 (Rope)
+        [] +
+
+        # Grouped-query attention. H_q / H_kv > 1
+        # Prefill shapes of Llama-3.1-8B
+        [[z, 32, 8, 1024, 1024, 128, 128, 'fwd'] for z in batch_sizes] +
+        # Prefill shapes of Qwen2-7B
+        [[z, 28, 4, 1024, 1024, 128, 128, 'fwd'] for z in batch_sizes] +
+        # Append shapes of Llama-3.1-8B
+        [[z, 32, 8, 512, 1024 + 128 + 512, 128, 128, 'fwd'] for z in batch_sizes] +
+        # Append shapes of Qwen2-7B
+        [[z, 28, 4, 512, 1024 + 128 + 512, 128, 128, 'fwd'] for z in batch_sizes] +
+
+        # FlexDecoding configuration. N_CTX_q equals 1. N_CTX_kv >= 1k
+        # Decode shapes of Llama-3.1-8B
+        [[z, 32, 8, 1, 1024 + 64, 128, 128, 'fwd'] for z in batch_sizes] +
+        # Decode shapes of Phi3-mini-3.8B
+        [
+            # acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM)
+            # ValueError: Shape element 2 must be a power of 2
+            # [z, 32, 32, 1, 1024 + 64, 96, 96, 'fwd'] for z in batch_sizes
+        ] +
+        # Decode shapes of Qwen2-7B
+        [
+            # torch._inductor.exc.InductorError: LoweringException: ValueError: Number of shared query heads sharing the same KV head must be power of 2.
+            # [z, 28, 4, 1, 1024 + 64, 128, 128, 'fwd'] for z in batch_sizes
+        ] +
+        # Decode shapes of Deepseek-v3 (Nope)
+        [
+            # RuntimeError: No valid triton configs. OutOfResources: out of resource: shared memory, Required: 264192, Hardware limit: 131072.
+            # [z, 128, 1, 1, 1024, 64, 512, 'fwd'] for z in batch_sizes
+        ] +
+        # Decode shapes of Deepseek-v3 (Rope)
+        [],
         line_arg='provider',
-        line_vals=['triton', 'xetla'],
-        line_names=['Triton', 'XeTLA'],
+        line_vals=['triton'],
+        line_names=['Triton'],
         styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
         ylabel=['GB/s', 'TFlops'],
         plot_name='flexAttnCausal-performance',
         args={},
     ))
-def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, MODE, provider):
-    assert MODE in ['fwd', 'bwd']
-    assert CAUSAL
+def benchmark(Z, H_q, H_kv, N_CTX_q, N_CTX_kv, D_HEAD_qk, D_HEAD_v, MODE, provider):
+    assert MODE in ['fwd']
     dtype = torch.float16
-    q = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype, requires_grad=True)
-    k = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype, requires_grad=True)
-    v = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype, requires_grad=True)
+    q = torch.randn((Z, H_q, N_CTX_q, D_HEAD_qk), device='xpu', dtype=dtype, requires_grad=MODE == 'bwd')
+    k = torch.randn((Z, H_kv, N_CTX_kv, D_HEAD_qk), device='xpu', dtype=dtype, requires_grad=MODE == 'bwd')
+    v = torch.randn((Z, H_kv, N_CTX_kv, D_HEAD_v), device='xpu', dtype=dtype, requires_grad=MODE == 'bwd')
     sm_scale = 0.125
     if MODE == 'bwd':
         sm_scale = 1.3
 
     quantiles = [0.5, 0.0, 1.0]
     if provider == 'triton':
-        kernel_options = {'num_stages': 2, 'num_warps': 16 if D_HEAD == 128 else 8, 'BLOCKS_ARE_CONTIGUOUS': True}
-        block_mask = create_block_mask_cached(causal_mask, 1, 1, N_CTX, N_CTX, device=q.device)
-        triton_fn = lambda: flex_attention(q, k, v, block_mask=block_mask, scale=sm_scale, kernel_options=kernel_options
-                                           )
+        kernel_options = {'num_stages': 2, 'num_warps': 16 if D_HEAD_qk == 128 else 8, 'BLOCKS_ARE_CONTIGUOUS': True}
+        block_mask = create_block_mask_cached(causal_mask, 1, 1, N_CTX_q, N_CTX_kv, device='xpu')
+        triton_fn = lambda: compiled_flex_attention(q, k, v, block_mask=block_mask, scale=sm_scale, enable_gqa=(
+            not H_q == H_kv), kernel_options=kernel_options)
+        torch_fn = lambda: flex_attention(q, k, v, block_mask=block_mask, scale=sm_scale, enable_gqa=not H_q == H_kv)
         if MODE == 'bwd':
             triton_o = triton_fn()
             triton_do = torch.randn_like(triton_o)
             triton_fn = lambda: triton_o.backward(triton_do, retain_graph=True)
-        torch_fn = lambda: F.scaled_dot_product_attention(q.cpu(), k.cpu(), v.cpu(), is_causal=True, scale=sm_scale).to(
-            torch.float32)
-        if MODE == 'bwd':
-            torch_o = torch_fn()
-            torch_do = torch.randn_like(torch_o)
-            torch_fn = lambda: torch_o.backward(torch_do, retain_graph=True)
-        if MODE == 'fwd':
-            atol = 1e-1 if N_CTX == 16384 else 1e-2
-            benchmark_suit.assert_close(triton_fn, torch_fn, atol=atol, rtol=1e-3, err_msg='triton to torch')
-        else:
-            benchmark_suit.assert_close(lambda: triton_o, lambda: torch_o, atol=1e-2, rtol=0, err_msg='triton to torch')
+
+        atol = 1e-1
+        benchmark_suit.assert_close(triton_fn, torch_fn, atol=atol, rtol=1e-3, err_msg='triton to torch')
         _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)
 
-    elif provider == 'xetla':
-        xetla_fn = None
-        if MODE == 'fwd':
-            module_name = 'flash_attn_causal_True'.lower()
-            func = getattr(xetla_kernel, module_name)
-            out = torch.empty_like(q, device='xpu', dtype=dtype)
-            size_score = Z * H * N_CTX * N_CTX
-            size_attn_mask = Z * N_CTX * N_CTX
-            dropout_mask = torch.empty((size_score, ), device='xpu', dtype=torch.uint8)
-            bias = torch.empty((size_attn_mask, ), device='xpu', dtype=dtype)
-            size_ml = Z * H * N_CTX
-            m = torch.empty((size_ml, ), device='xpu', dtype=torch.float)
-            l = torch.empty((size_ml, ), device='xpu', dtype=torch.float)
-            xetla_fn = lambda: func(q, k, v, out, dropout_mask, bias, m, l, Z, H, D_HEAD, N_CTX, N_CTX, sm_scale)
-        if MODE == 'bwd':
-            module_name = 'flash_attn_bwd_causal_True'.lower()
-            func = getattr(xetla_kernel, module_name)
-            grad_out = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True)
-            bias = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True)
-            dropout = torch.empty_like(q, device='xpu', dtype=torch.uint8)
-            out = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True)
-            log_sumexp = torch.zeros(q.size(), device='xpu', dtype=dtype, requires_grad=True)
-            workspace = torch.zeros(q.size(), device='xpu', dtype=dtype, requires_grad=True)
-            grad_q_tmp = torch.zeros(q.size(), device='xpu', dtype=dtype, requires_grad=True)
-            alpha = sm_scale
-            dropout_prob = 0
-            grad_query = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True)
-            grad_key = torch.empty_like(k, device='xpu', dtype=dtype, requires_grad=True)
-            grad_value = torch.empty_like(v, device='xpu', dtype=dtype, requires_grad=True)
-            grad_bias = torch.empty_like(bias, device='xpu', dtype=dtype, requires_grad=True)
-            bias_strideB = -1
-            bias_strideN = -1
-            bias_strideF = -1
-            attn_mask_padding = 0
-
-            xetla_fn = lambda: func(grad_out, q, k, v, bias, dropout, out, log_sumexp, workspace, grad_q_tmp, alpha,
-                                    dropout_prob, grad_query, grad_key, grad_value, grad_bias, Z, H, D_HEAD, N_CTX,
-                                    N_CTX, bias_strideB, bias_strideN, bias_strideF, attn_mask_padding)
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)
+    elif provider == 'onednn':
+        # OneDNN only supports MHA.
+        if H_q == H_kv:
+            mask = create_mask(causal_mask, 1, 1, N_CTX_q, N_CTX_kv, device=q.device)
+            xformers_fn = lambda: F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
+            if MODE == 'bwd':
+                xformers_o = xformers_fn()
+                xformers_do = torch.randn_like(xformers_o)
+                xformers_fn = lambda: xformers_o.backward(xformers_do, retain_graph=True)
+            _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xformers_fn, n_warmup=10, n_repeat=10,
+                                                                  quantiles=quantiles)
+        else:
+            _, min_ms, max_ms, mean, cv = float('nan'), float('nan'), float('nan'), float('nan'), float('nan')
 
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')
 
-    tflops = lambda mean: 2 * 2 * Z * H * N_CTX * N_CTX * D_HEAD * (1e-12) / (mean * 1e-3)
-    gbps = lambda mean: Z * H * (N_CTX * D_HEAD + N_CTX * D_HEAD) * 2 * 2 * (1e-9) / (mean * 1e-3)
+    qk_flops = H_q * N_CTX_q * N_CTX_kv * D_HEAD_qk
+    pv_flops = H_q * N_CTX_q * D_HEAD_v * N_CTX_kv
+    tflops = lambda mean: Z * (qk_flops + pv_flops) * (1e-12) / (mean * 1e-3)
+
+    q_elems = H_q * N_CTX_q * D_HEAD_qk
+    k_elems = H_kv * N_CTX_kv * D_HEAD_qk
+    v_elems = H_kv * N_CTX_kv * D_HEAD_v
+    gbps = lambda mean: Z * (q_elems + k_elems + v_elems) * 2 * (1e-9) / (mean * 1e-3)  # float16 2 bytes
 
     if MODE == 'bwd':
-        tflops = lambda mean: 2.5 * 2 * 2 * Z * H * N_CTX * N_CTX * D_HEAD * (1e-12) / (mean * 1e-3)
-        gbps = lambda mean: 2.5 * Z * H * (N_CTX * D_HEAD + N_CTX * D_HEAD) * 2 * 2 * (1e-9) / (mean * 1e-3)
+        tflops = lambda mean: 2.5 * 2 * 2 * Z * H_q * N_CTX_q * N_CTX_kv * D_HEAD_qk * (1e-12) / (mean * 1e-3)
+        gbps = lambda mean: 2.5 * Z * H_q * (N_CTX_q * D_HEAD_qk + N_CTX_kv * D_HEAD_qk) * 2 * 2 * (1e-9) / (mean * 1e-3
+                                                                                                             )
 
     return (gbps(mean), gbps(max_ms), gbps(min_ms)), (tflops(mean), tflops(max_ms), tflops(min_ms)), cv
 
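
Note on the new 'triton' provider path: the diff routes the measured call through compiled_flex_attention with a cached causal block mask and sets enable_gqa whenever H_q differs from H_kv. The standalone sketch below (not part of the commit) shows the same FlexAttention call pattern on small illustrative CPU shapes, assuming PyTorch >= 2.5 for the enable_gqa flag; the benchmark itself uses float16 tensors on 'xpu' and a torch.compile-wrapped function.

# Minimal sketch of the FlexAttention call pattern exercised by the benchmark.
# Shapes, CPU device, and float32 dtype are illustrative assumptions only.
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention


def causal_mask(_, __, q_idx, kv_idx):
    # Keep a (q_idx, kv_idx) pair only if the query position may attend to it.
    return q_idx >= kv_idx


Z, H_q, H_kv, N_CTX_q, N_CTX_kv, D_HEAD = 1, 8, 2, 128, 128, 64
q = torch.randn(Z, H_q, N_CTX_q, D_HEAD)
k = torch.randn(Z, H_kv, N_CTX_kv, D_HEAD)
v = torch.randn(Z, H_kv, N_CTX_kv, D_HEAD)

# One block mask shared across batch and heads (B=1, H=1 broadcast).
block_mask = create_block_mask(causal_mask, 1, 1, N_CTX_q, N_CTX_kv, device=q.device)

# enable_gqa allows H_q != H_kv as long as H_q is a multiple of H_kv.
out = flex_attention(q, k, v, block_mask=block_mask, scale=0.125,
                     enable_gqa=(H_q != H_kv))
print(out.shape)  # torch.Size([1, 8, 128, 64])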

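The new 'onednn' provider compares against torch.nn.functional.scaled_dot_product_attention using a dense causal mask materialized by create_mask, and only when H_q == H_kv. A minimal sketch of that reference path, again on illustrative CPU shapes rather than the committed 'xpu' setup:

# Sketch of the dense-mask reference path used by the 'onednn' provider
# when H_q == H_kv. Shapes and CPU device are illustrative assumptions.
import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import create_mask


def causal_mask(_, __, q_idx, kv_idx):
    return q_idx >= kv_idx


Z, H, N_CTX_q, N_CTX_kv, D_HEAD = 1, 8, 128, 128, 64
q = torch.randn(Z, H, N_CTX_q, D_HEAD)
k = torch.randn(Z, H, N_CTX_kv, D_HEAD)
v = torch.randn(Z, H, N_CTX_kv, D_HEAD)

# create_mask materializes the mask_mod as a boolean (B, H, Q_LEN, KV_LEN)
# tensor; B=1, H=1 broadcast over batch and heads inside SDPA.
mask = create_mask(causal_mask, 1, 1, N_CTX_q, N_CTX_kv, device=q.device)
ref = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(ref.shape)  # torch.Size([1, 8, 128, 64])
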
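The reworked metric lambdas can be checked by hand. The worked example below uses the Llama-3.1-8B prefill shape from the new grid and a hypothetical mean of 1.0 ms; note that the fwd-mode qk_flops/pv_flops terms count one operation per multiply-accumulate and do not discount the causally masked half of the score matrix.

# Worked example of the fwd-mode tflops/gbps lambdas for Z=1, H_q=32, H_kv=8,
# N_CTX_q=N_CTX_kv=1024, D_HEAD_qk=D_HEAD_v=128; mean=1.0 ms is a made-up timing.
Z, H_q, H_kv = 1, 32, 8
N_CTX_q, N_CTX_kv = 1024, 1024
D_HEAD_qk, D_HEAD_v = 128, 128
mean = 1.0  # milliseconds

qk_flops = H_q * N_CTX_q * N_CTX_kv * D_HEAD_qk  # 4_294_967_296
pv_flops = H_q * N_CTX_q * D_HEAD_v * N_CTX_kv   # 4_294_967_296
tflops = Z * (qk_flops + pv_flops) * 1e-12 / (mean * 1e-3)
print(round(tflops, 2))  # ~8.59 TFLOPS at 1 ms

q_elems = H_q * N_CTX_q * D_HEAD_qk    # 4_194_304
k_elems = H_kv * N_CTX_kv * D_HEAD_qk  # 1_048_576
v_elems = H_kv * N_CTX_kv * D_HEAD_v   # 1_048_576
gbps = Z * (q_elems + k_elems + v_elems) * 2 * 1e-9 / (mean * 1e-3)  # 2 bytes per float16 element
print(round(gbps, 2))  # ~12.58 GB/s at 1 ms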