Commit cc449b5

Draft: add more flex attention cases to the benchmark.
1 parent 2be3060 commit cc449b5

File tree

1 file changed (+89, -73 lines)

benchmarks/triton_kernels_benchmark/flex_attention_benchmark_causal_mask.py

Lines changed: 89 additions & 73 deletions
@@ -7,7 +7,6 @@
 )
 
 import torch
-import torch.nn.functional as F
 import triton_kernels_benchmark as benchmark_suit
 from triton_kernels_benchmark import xetla_kernel
 
@@ -31,17 +30,40 @@ def causal_mask(_, __, q_idx, kv_idx):
 # For details: https://github.com/pytorch/pytorch/issues/144778
 @benchmark_suit.perf_report(
     benchmark_suit.Benchmark(
-        x_names=['Z', 'H', 'N_CTX', 'D_HEAD', 'CAUSAL', 'MODE'],
-        x_vals=[[z, h, 16384 // z, dhead, causal, mode]
-                for z in [1, 2, 4, 8, 16, 32]
-                for (h, dhead) in [(16, 128), (32, 64)]
-                for causal in [True]
-                for mode in [os.getenv('FA_KERNEL_MODE', 'fwd')]] #
-        + [[4, 48, 1024, 64, True, mode] for mode in [os.getenv('FA_KERNEL_MODE', 'fwd')]] #
-        + [[z, h, 1024, dhead, True, mode]
-           for z in [1, 2, 4, 8, 16, 32, 64]
-           for (h, dhead) in [(8, 128), (32, 96), (4, 128)]
-           for mode in [os.getenv('FA_KERNEL_MODE', 'fwd')]],
+        x_names=['B', 'H_q', 'H_kv', 'N_CTX_q', 'N_CTX_kv', 'D_HEAD', 'MODE'],
+        x_vals=
+        # Multi-head attention. H_q equals H_kv
+        [
+            # [z, h, h, n_ctx_q, n_ctx_kv, dhead, mode]
+            # for z in [1, 2, 4, 8, 16, 32]
+            # for h in [1, 2, 4, 8, 16, 32]
+            # for n_ctx_q in [128, 256, 512, 1024]
+            # for n_ctx_kv in [128, 256, 512, 1024]
+            # for dhead in [64, 128, 256]
+            # for mode in [os.getenv('FA_KERNEL_MODE', 'fwd')]
+        ] +
+        # Multi-query attention. H_kv equals 1.
+        [] +
+        # Grouped-query attention. H_q / H_kv > 1
+        [[z, h_q, h_kv, n_ctx_q, n_ctx_kv, dhead, mode]
+         for z in [1, 2, 4, 8, 16, 32]
+         for h_q, h_kv in [(2, 1), (4, 1), (4, 2), (8, 4), (16, 8)]
+         for n_ctx_q in [128, 256, 512, 1024]
+         for n_ctx_kv in [128, 256, 512, 1024]
+         for dhead in [64, 128, 256]
+         for mode in [os.getenv('FA_KERNEL_MODE', 'fwd')]] +
+        # FlexDecoding configuration. N_CTX_q equals 1. N_CTX_kv >= 1k
+        [],
+        # [[z, h, 16384 // z, dhead, causal, mode]
+        #  for z in [1, 2, 4, 8, 16, 32]
+        #  for (h, dhead) in [(16, 128), (32, 64)]
+        #  for causal in [True]
+        #  for mode in [os.getenv('FA_KERNEL_MODE', 'fwd')]] #
+        # + [[4, 48, 1024, 64, True, mode] for mode in [os.getenv('FA_KERNEL_MODE', 'fwd')]] #
+        # + [[z, h, 1024, dhead, True, mode]
+        #    for z in [1, 2, 4, 8, 16, 32, 64]
+        #    for (h, dhead) in [(8, 128), (32, 96), (4, 128)]
+        #    for mode in [os.getenv('FA_KERNEL_MODE', 'fwd')]],
         line_arg='provider',
         line_vals=['triton', 'xetla'],
         line_names=['Triton', 'XeTLA'],
@@ -50,89 +72,83 @@ def causal_mask(_, __, q_idx, kv_idx):
         plot_name='flexAttnCausal-performance',
         args={},
     ))
-def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, MODE, provider):
+def benchmark(Z, H_q, H_kv, N_CTX_q, N_CTX_kv, D_HEAD, MODE, provider):
     assert MODE in ['fwd', 'bwd']
-    assert CAUSAL
     dtype = torch.float16
-    q = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype, requires_grad=True)
-    k = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype, requires_grad=True)
-    v = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype, requires_grad=True)
+    q = torch.randn((Z, H_q, N_CTX_q, D_HEAD), device='xpu', dtype=dtype, requires_grad=True)
+    k = torch.randn((Z, H_kv, N_CTX_kv, D_HEAD), device='xpu', dtype=dtype, requires_grad=True)
+    v = torch.randn((Z, H_kv, N_CTX_kv, D_HEAD), device='xpu', dtype=dtype, requires_grad=True)
     sm_scale = 0.125
     if MODE == 'bwd':
         sm_scale = 1.3
 
     quantiles = [0.5, 0.0, 1.0]
     if provider == 'triton':
         kernel_options = {'num_stages': 2, 'num_warps': 16 if D_HEAD == 128 else 8, 'BLOCKS_ARE_CONTIGUOUS': True}
-        block_mask = create_block_mask_cached(causal_mask, 1, 1, N_CTX, N_CTX, device=q.device)
-        triton_fn = lambda: flex_attention(q, k, v, block_mask=block_mask, scale=sm_scale, kernel_options=kernel_options
-                                           )
+        block_mask = create_block_mask_cached(causal_mask, 1, 1, N_CTX_q, N_CTX_kv, device='xpu')
+        triton_fn = lambda: flex_attention(q, k, v, block_mask=block_mask, scale=sm_scale, enable_gqa=(not H_q == H_kv),
+                                           kernel_options=kernel_options)
         if MODE == 'bwd':
             triton_o = triton_fn()
             triton_do = torch.randn_like(triton_o)
             triton_fn = lambda: triton_o.backward(triton_do, retain_graph=True)
-        torch_fn = lambda: F.scaled_dot_product_attention(q.cpu(), k.cpu(), v.cpu(), is_causal=True, scale=sm_scale).to(
-            torch.float32)
-        if MODE == 'bwd':
-            torch_o = torch_fn()
-            torch_do = torch.randn_like(torch_o)
-            torch_fn = lambda: torch_o.backward(torch_do, retain_graph=True)
-        if MODE == 'fwd':
-            atol = 1e-1 if N_CTX == 16384 else 1e-2
-            benchmark_suit.assert_close(triton_fn, torch_fn, atol=atol, rtol=1e-3, err_msg='triton to torch')
-        else:
-            benchmark_suit.assert_close(lambda: triton_o, lambda: torch_o, atol=1e-2, rtol=0, err_msg='triton to torch')
         _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)
 
     elif provider == 'xetla':
-        xetla_fn = None
-        if MODE == 'fwd':
-            module_name = 'flash_attn_causal_True'.lower()
-            func = getattr(xetla_kernel, module_name)
-            out = torch.empty_like(q, device='xpu', dtype=dtype)
-            size_score = Z * H * N_CTX * N_CTX
-            size_attn_mask = Z * N_CTX * N_CTX
-            dropout_mask = torch.empty((size_score, ), device='xpu', dtype=torch.uint8)
-            bias = torch.empty((size_attn_mask, ), device='xpu', dtype=dtype)
-            size_ml = Z * H * N_CTX
-            m = torch.empty((size_ml, ), device='xpu', dtype=torch.float)
-            l = torch.empty((size_ml, ), device='xpu', dtype=torch.float)
-            xetla_fn = lambda: func(q, k, v, out, dropout_mask, bias, m, l, Z, H, D_HEAD, N_CTX, N_CTX, sm_scale)
-        if MODE == 'bwd':
-            module_name = 'flash_attn_bwd_causal_True'.lower()
-            func = getattr(xetla_kernel, module_name)
-            grad_out = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True)
-            bias = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True)
-            dropout = torch.empty_like(q, device='xpu', dtype=torch.uint8)
-            out = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True)
-            log_sumexp = torch.zeros(q.size(), device='xpu', dtype=dtype, requires_grad=True)
-            workspace = torch.zeros(q.size(), device='xpu', dtype=dtype, requires_grad=True)
-            grad_q_tmp = torch.zeros(q.size(), device='xpu', dtype=dtype, requires_grad=True)
-            alpha = sm_scale
-            dropout_prob = 0
-            grad_query = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True)
-            grad_key = torch.empty_like(k, device='xpu', dtype=dtype, requires_grad=True)
-            grad_value = torch.empty_like(v, device='xpu', dtype=dtype, requires_grad=True)
-            grad_bias = torch.empty_like(bias, device='xpu', dtype=dtype, requires_grad=True)
-            bias_strideB = -1
-            bias_strideN = -1
-            bias_strideF = -1
-            attn_mask_padding = 0
-
-            xetla_fn = lambda: func(grad_out, q, k, v, bias, dropout, out, log_sumexp, workspace, grad_q_tmp, alpha,
-                                    dropout_prob, grad_query, grad_key, grad_value, grad_bias, Z, H, D_HEAD, N_CTX,
-                                    N_CTX, bias_strideB, bias_strideN, bias_strideF, attn_mask_padding)
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)
+        if (H_q == H_kv) and (N_CTX_q == N_CTX_kv):
+            xetla_fn = None
+            H = H_q
+            N_CTX = N_CTX_q
+            if MODE == 'fwd':
+                module_name = 'flash_attn_causal_True'.lower()
+                func = getattr(xetla_kernel, module_name)
+                out = torch.empty_like(q, device='xpu', dtype=dtype)
+                size_score = Z * H * N_CTX * N_CTX
+                size_attn_mask = Z * N_CTX * N_CTX
+                dropout_mask = torch.empty((size_score, ), device='xpu', dtype=torch.uint8)
+                bias = torch.empty((size_attn_mask, ), device='xpu', dtype=dtype)
+                size_ml = Z * H * N_CTX
+                m = torch.empty((size_ml, ), device='xpu', dtype=torch.float)
+                l = torch.empty((size_ml, ), device='xpu', dtype=torch.float)
+                xetla_fn = lambda: func(q, k, v, out, dropout_mask, bias, m, l, Z, H, D_HEAD, N_CTX, N_CTX, sm_scale)
+            if MODE == 'bwd':
+                module_name = 'flash_attn_bwd_causal_True'.lower()
+                func = getattr(xetla_kernel, module_name)
+                grad_out = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True)
+                bias = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True)
+                dropout = torch.empty_like(q, device='xpu', dtype=torch.uint8)
+                out = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True)
+                log_sumexp = torch.zeros(q.size(), device='xpu', dtype=dtype, requires_grad=True)
+                workspace = torch.zeros(q.size(), device='xpu', dtype=dtype, requires_grad=True)
+                grad_q_tmp = torch.zeros(q.size(), device='xpu', dtype=dtype, requires_grad=True)
+                alpha = sm_scale
+                dropout_prob = 0
+                grad_query = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True)
+                grad_key = torch.empty_like(k, device='xpu', dtype=dtype, requires_grad=True)
+                grad_value = torch.empty_like(v, device='xpu', dtype=dtype, requires_grad=True)
+                grad_bias = torch.empty_like(bias, device='xpu', dtype=dtype, requires_grad=True)
+                bias_strideB = -1
+                bias_strideN = -1
+                bias_strideF = -1
+                attn_mask_padding = 0
+
+                xetla_fn = lambda: func(grad_out, q, k, v, bias, dropout, out, log_sumexp, workspace, grad_q_tmp, alpha,
+                                        dropout_prob, grad_query, grad_key, grad_value, grad_bias, Z, H, D_HEAD, N_CTX,
+                                        N_CTX, bias_strideB, bias_strideN, bias_strideF, attn_mask_padding)
+            _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10,
+                                                                  quantiles=quantiles)
+        else:
+            _, min_ms, max_ms, mean, cv = float('nan'), float('nan'), float('nan'), float('nan'), float('nan')
 
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')
 
-    tflops = lambda mean: 2 * 2 * Z * H * N_CTX * N_CTX * D_HEAD * (1e-12) / (mean * 1e-3)
-    gbps = lambda mean: Z * H * (N_CTX * D_HEAD + N_CTX * D_HEAD) * 2 * 2 * (1e-9) / (mean * 1e-3)
+    tflops = lambda mean: 2 * 2 * Z * H_q * N_CTX_q * N_CTX_kv * D_HEAD * (1e-12) / (mean * 1e-3)
+    gbps = lambda mean: Z * H_q * (N_CTX_q * D_HEAD + N_CTX_kv * D_HEAD) * 2 * 2 * (1e-9) / (mean * 1e-3)
 
     if MODE == 'bwd':
-        tflops = lambda mean: 2.5 * 2 * 2 * Z * H * N_CTX * N_CTX * D_HEAD * (1e-12) / (mean * 1e-3)
-        gbps = lambda mean: 2.5 * Z * H * (N_CTX * D_HEAD + N_CTX * D_HEAD) * 2 * 2 * (1e-9) / (mean * 1e-3)
+        tflops = lambda mean: 2.5 * 2 * 2 * Z * H_q * N_CTX_q * N_CTX_kv * D_HEAD * (1e-12) / (mean * 1e-3)
+        gbps = lambda mean: 2.5 * Z * H_q * (N_CTX_q * D_HEAD + N_CTX_kv * D_HEAD) * 2 * 2 * (1e-9) / (mean * 1e-3)
 
     return (gbps(mean), gbps(max_ms), gbps(min_ms)), (tflops(mean), tflops(max_ms), tflops(min_ms)), cv
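Note on the Triton path: the benchmark now lets query and key/value tensors differ in both head count (H_q vs. H_kv) and sequence length (N_CTX_q vs. N_CTX_kv), and it passes enable_gqa to flex_attention whenever H_q != H_kv so the smaller set of KV heads is shared across groups of query heads. The sketch below shows that call pattern in isolation, using the public torch.nn.attention.flex_attention API (create_block_mask instead of the benchmark's cached helper) on small, hypothetical CPU shapes; it assumes a PyTorch build new enough to ship flex_attention with the enable_gqa flag (2.5+).

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention


def causal_mask(_, __, q_idx, kv_idx):
    # Keep score (q_idx, kv_idx) only when the key position is not ahead of the query.
    return q_idx >= kv_idx


# Hypothetical grouped-query shapes: 8 query heads share 4 KV heads.
Z, H_q, H_kv, N_CTX_q, N_CTX_kv, D_HEAD = 2, 8, 4, 256, 256, 64
q = torch.randn(Z, H_q, N_CTX_q, D_HEAD)
k = torch.randn(Z, H_kv, N_CTX_kv, D_HEAD)
v = torch.randn(Z, H_kv, N_CTX_kv, D_HEAD)

# One mask broadcast over batch and heads (B=1, H=1), sized N_CTX_q x N_CTX_kv.
block_mask = create_block_mask(causal_mask, 1, 1, N_CTX_q, N_CTX_kv, device=q.device)

# enable_gqa allows H_kv < H_q as long as H_q is a multiple of H_kv.
out = flex_attention(q, k, v, block_mask=block_mask, scale=0.125, enable_gqa=(H_q != H_kv))
print(out.shape)  # torch.Size([2, 8, 256, 64])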

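For planning run time, it helps to note how large the new x_vals grid is. The multi-head group is commented out and the multi-query and FlexDecoding groups are placeholder empty lists, so only the grouped-query block contributes cases: 6 batch sizes, 5 (H_q, H_kv) pairs, 4 query lengths, 4 KV lengths, 3 head dims, and 1 mode. A quick sanity check of that count, re-running the same comprehension outside the benchmark:

import os

# Reproduce the grouped-query block of x_vals from the diff above.
gqa_vals = [[z, h_q, h_kv, n_ctx_q, n_ctx_kv, dhead, mode]
            for z in [1, 2, 4, 8, 16, 32]
            for h_q, h_kv in [(2, 1), (4, 1), (4, 2), (8, 4), (16, 8)]
            for n_ctx_q in [128, 256, 512, 1024]
            for n_ctx_kv in [128, 256, 512, 1024]
            for dhead in [64, 128, 256]
            for mode in [os.getenv('FA_KERNEL_MODE', 'fwd')]]

print(len(gqa_vals))  # 6 * 5 * 4 * 4 * 3 * 1 = 1440 configurations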

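The tflops and gbps lambdas at the bottom were only rewritten in terms of the new variables, but the constants are easy to misread, so here is the forward-mode arithmetic spelled out for one configuration. The leading 2 * 2 is two matmuls (Q x K^T and P x V) times 2 FLOPs per multiply-add; in the bandwidth estimate, (N_CTX_q + N_CTX_kv) * D_HEAD counts the Q and K elements per head, and the trailing * 2 * 2 appears to cover 2 bytes per fp16 element plus a factor of two standing in for the V and output traffic (an interpretation, not stated in the code). The backward-mode variants simply scale both estimates by 2.5.

# Worked example of the reported metrics for one forward-mode configuration.
# Shapes match one of the grouped-query cases above; mean_ms is a made-up latency.
Z, H_q, N_CTX_q, N_CTX_kv, D_HEAD = 8, 16, 1024, 1024, 128
mean_ms = 5.0  # assumed mean latency from do_bench, in milliseconds

# 2 matmuls x 2 FLOPs per multiply-add, over Z * H_q attention maps.
tflops = 2 * 2 * Z * H_q * N_CTX_q * N_CTX_kv * D_HEAD * 1e-12 / (mean_ms * 1e-3)

# Q and K elements per head in fp16 (2 bytes each), doubled again for the remaining traffic.
gbps = Z * H_q * (N_CTX_q * D_HEAD + N_CTX_kv * D_HEAD) * 2 * 2 * 1e-9 / (mean_ms * 1e-3)

print(f'{tflops:.1f} TFLOPS, {gbps:.1f} GB/s')  # ~13.7 TFLOPS, ~26.8 GB/s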