import os
from torch.nn.attention.flex_attention import (
    create_block_mask,
+     create_mask,
    flex_attention,
)

import torch
import torch.nn.functional as F
+
import triton_kernels_benchmark as benchmark_suit
- from triton_kernels_benchmark import xetla_kernel

torch._dynamo.config.recompile_limit = 100  # pylint: disable=protected-access

# Compile the flex_attention function
- flex_attention = torch.compile(flex_attention, dynamic=False)
+ compiled_flex_attention = torch.compile(flex_attention, dynamic=False)


@lru_cache
@@ -27,112 +28,128 @@ def causal_mask(_, __, q_idx, kv_idx):
    return q_idx >= kv_idx


+ throughput_test = os.getenv('THROUGHPUT_TEST', '0') == '1'
+ batch_sizes = [16, 32, 64] if throughput_test else [1]
+
+
# Kernel profiling for Backward mode is not working as expected:
# For details: https://github.com/pytorch/pytorch/issues/144778
@benchmark_suit.perf_report(
    benchmark_suit.Benchmark(
-         x_names=['Z', 'H', 'N_CTX', 'D_HEAD', 'CAUSAL', 'MODE'],
-         x_vals=[[z, h, 16384 // z, dhead, causal, mode]
-                 for z in [1, 2, 4, 8, 16, 32]
-                 for (h, dhead) in [(16, 128), (32, 64)]
-                 for causal in [True]
-                 for mode in [os.getenv('FA_KERNEL_MODE', 'fwd')]]  #
-         + [[4, 48, 1024, 64, True, mode] for mode in [os.getenv('FA_KERNEL_MODE', 'fwd')]]  #
-         + [[z, h, 1024, dhead, True, mode]
-            for z in [1, 2, 4, 8, 16, 32, 64]
-            for (h, dhead) in [(8, 128), (32, 96), (4, 128)]
-            for mode in [os.getenv('FA_KERNEL_MODE', 'fwd')]],
+         x_names=['Z', 'H_q', 'H_kv', 'N_CTX_q', 'N_CTX_kv', 'D_HEAD_qk', 'D_HEAD_v', 'MODE'],
+         x_vals=
+         # Multi-head attention. H_q equals H_kv
+         # Prefill shapes of Phi3-mini-3.8B
+         [[z, 32, 32, 1024, 1024, 96, 96, 'fwd'] for z in batch_sizes] +
+         # Prefill shapes of Deepseek-v3
+         [[z, 128, 128, 1024, 1024, 192, 128, 'fwd'] for z in batch_sizes] +
+         # Append shapes of Phi3-mini-3.8B
+         [[z, 32, 32, 512, 1024 + 128 + 512, 96, 96, 'fwd'] for z in batch_sizes] +
+
+         # Multi-query attention. H_kv equals 1.
+         # Append shapes of Deepseek-v3 (Nope)
+         [
+             # RuntimeError: No valid triton configs. OutOfResources: out of resource: shared memory, Required: 133120, Hardware limit: 131072.
+             # [z, 128, 1, 512, 1024 + 128 + 512, 64, 512, 'fwd'] for z in batch_sizes
+         ] +
+         # Append shapes of Deepseek-v3 (Rope)
+         [] +
+
+         # Grouped-query attention. H_q / H_kv > 1
+         # Prefill shapes of Llama-3.1-8B
+         [[z, 32, 8, 1024, 1024, 128, 128, 'fwd'] for z in batch_sizes] +
+         # Prefill shapes of Qwen2-7B
+         [[z, 28, 4, 1024, 1024, 128, 128, 'fwd'] for z in batch_sizes] +
+         # Append shapes of Llama-3.1-8B
+         [[z, 32, 8, 512, 1024 + 128 + 512, 128, 128, 'fwd'] for z in batch_sizes] +
+         # Append shapes of Qwen2-7B
+         [[z, 28, 4, 512, 1024 + 128 + 512, 128, 128, 'fwd'] for z in batch_sizes] +
+
+         # FlexDecoding configuration. N_CTX_q equals 1. N_CTX_kv >= 1k
+         # Decode shapes of Llama-3.1-8B
+         [[z, 32, 8, 1, 1024 + 64, 128, 128, 'fwd'] for z in batch_sizes] +
+         # Decode shapes of Phi3-mini-3.8B
+         [
+             # acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM)
+             # ValueError: Shape element 2 must be a power of 2
+             # [z, 32, 32, 1, 1024 + 64, 96, 96, 'fwd'] for z in batch_sizes
+         ] +
+         # Decode shapes of Qwen2-7B
+         [
+             # torch._inductor.exc.InductorError: LoweringException: ValueError: Number of shared query heads sharing the same KV head must be power of 2.
+             # [z, 28, 4, 1, 1024 + 64, 128, 128, 'fwd'] for z in batch_sizes
+         ] +
+         # Decode shapes of Deepseek-v3 (Nope)
+         [
+             # RuntimeError: No valid triton configs. OutOfResources: out of resource: shared memory, Required: 264192, Hardware limit: 131072.
+             # [z, 128, 1, 1, 1024, 64, 512, 'fwd'] for z in batch_sizes
+         ] +
+         # Decode shapes of Deepseek-v3 (Rope)
+         [],
        line_arg='provider',
-         line_vals=['triton', 'xetla'],
-         line_names=['Triton', 'XeTLA'],
+         line_vals=['triton'],
+         line_names=['Triton'],
        styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
        ylabel=['GB/s', 'TFlops'],
        plot_name='flexAttnCausal-performance',
        args={},
    ))
- def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, MODE, provider):
-     assert MODE in ['fwd', 'bwd']
-     assert CAUSAL
+ def benchmark(Z, H_q, H_kv, N_CTX_q, N_CTX_kv, D_HEAD_qk, D_HEAD_v, MODE, provider):
+     assert MODE in ['fwd']
    dtype = torch.float16
-     q = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype, requires_grad=True)
-     k = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype, requires_grad=True)
-     v = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype, requires_grad=True)
+     q = torch.randn((Z, H_q, N_CTX_q, D_HEAD_qk), device='xpu', dtype=dtype, requires_grad=MODE == 'bwd')
+     k = torch.randn((Z, H_kv, N_CTX_kv, D_HEAD_qk), device='xpu', dtype=dtype, requires_grad=MODE == 'bwd')
+     v = torch.randn((Z, H_kv, N_CTX_kv, D_HEAD_v), device='xpu', dtype=dtype, requires_grad=MODE == 'bwd')
    sm_scale = 0.125
    if MODE == 'bwd':
        sm_scale = 1.3

    quantiles = [0.5, 0.0, 1.0]
    if provider == 'triton':
-         kernel_options = {'num_stages': 2, 'num_warps': 16 if D_HEAD == 128 else 8, 'BLOCKS_ARE_CONTIGUOUS': True}
-         block_mask = create_block_mask_cached(causal_mask, 1, 1, N_CTX, N_CTX, device=q.device)
-         triton_fn = lambda: flex_attention(q, k, v, block_mask=block_mask, scale=sm_scale, kernel_options=kernel_options
-                                            )
+         kernel_options = {'num_stages': 2, 'num_warps': 16 if D_HEAD_qk == 128 else 8, 'BLOCKS_ARE_CONTIGUOUS': True}
+         block_mask = create_block_mask_cached(causal_mask, 1, 1, N_CTX_q, N_CTX_kv, device='xpu')
+         triton_fn = lambda: compiled_flex_attention(q, k, v, block_mask=block_mask, scale=sm_scale, enable_gqa=(
+             not H_q == H_kv), kernel_options=kernel_options)
+         torch_fn = lambda: flex_attention(q, k, v, block_mask=block_mask, scale=sm_scale, enable_gqa=not H_q == H_kv)
        if MODE == 'bwd':
            triton_o = triton_fn()
            triton_do = torch.randn_like(triton_o)
            triton_fn = lambda: triton_o.backward(triton_do, retain_graph=True)
-         torch_fn = lambda: F.scaled_dot_product_attention(q.cpu(), k.cpu(), v.cpu(), is_causal=True, scale=sm_scale).to(
-             torch.float32)
-         if MODE == 'bwd':
-             torch_o = torch_fn()
-             torch_do = torch.randn_like(torch_o)
-             torch_fn = lambda: torch_o.backward(torch_do, retain_graph=True)
-         if MODE == 'fwd':
-             atol = 1e-1 if N_CTX == 16384 else 1e-2
-             benchmark_suit.assert_close(triton_fn, torch_fn, atol=atol, rtol=1e-3, err_msg='triton to torch')
-         else:
-             benchmark_suit.assert_close(lambda: triton_o, lambda: torch_o, atol=1e-2, rtol=0, err_msg='triton to torch')
+
+         atol = 1e-1
+         benchmark_suit.assert_close(triton_fn, torch_fn, atol=atol, rtol=1e-3, err_msg='triton to torch')
        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)

-     elif provider == 'xetla':
-         xetla_fn = None
-         if MODE == 'fwd':
-             module_name = 'flash_attn_causal_True'.lower()
-             func = getattr(xetla_kernel, module_name)
-             out = torch.empty_like(q, device='xpu', dtype=dtype)
-             size_score = Z * H * N_CTX * N_CTX
-             size_attn_mask = Z * N_CTX * N_CTX
-             dropout_mask = torch.empty((size_score, ), device='xpu', dtype=torch.uint8)
-             bias = torch.empty((size_attn_mask, ), device='xpu', dtype=dtype)
-             size_ml = Z * H * N_CTX
-             m = torch.empty((size_ml, ), device='xpu', dtype=torch.float)
-             l = torch.empty((size_ml, ), device='xpu', dtype=torch.float)
-             xetla_fn = lambda: func(q, k, v, out, dropout_mask, bias, m, l, Z, H, D_HEAD, N_CTX, N_CTX, sm_scale)
-         if MODE == 'bwd':
-             module_name = 'flash_attn_bwd_causal_True'.lower()
-             func = getattr(xetla_kernel, module_name)
-             grad_out = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True)
-             bias = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True)
-             dropout = torch.empty_like(q, device='xpu', dtype=torch.uint8)
-             out = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True)
-             log_sumexp = torch.zeros(q.size(), device='xpu', dtype=dtype, requires_grad=True)
-             workspace = torch.zeros(q.size(), device='xpu', dtype=dtype, requires_grad=True)
-             grad_q_tmp = torch.zeros(q.size(), device='xpu', dtype=dtype, requires_grad=True)
-             alpha = sm_scale
-             dropout_prob = 0
-             grad_query = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True)
-             grad_key = torch.empty_like(k, device='xpu', dtype=dtype, requires_grad=True)
-             grad_value = torch.empty_like(v, device='xpu', dtype=dtype, requires_grad=True)
-             grad_bias = torch.empty_like(bias, device='xpu', dtype=dtype, requires_grad=True)
-             bias_strideB = -1
-             bias_strideN = -1
-             bias_strideF = -1
-             attn_mask_padding = 0
-
-             xetla_fn = lambda: func(grad_out, q, k, v, bias, dropout, out, log_sumexp, workspace, grad_q_tmp, alpha,
-                                     dropout_prob, grad_query, grad_key, grad_value, grad_bias, Z, H, D_HEAD, N_CTX,
-                                     N_CTX, bias_strideB, bias_strideN, bias_strideF, attn_mask_padding)
-         _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)
+     elif provider == 'onednn':
+         # OneDNN only supports MHA.
+         if H_q == H_kv:
+             mask = create_mask(causal_mask, 1, 1, N_CTX_q, N_CTX_kv, device=q.device)
+             xformers_fn = lambda: F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
+             if MODE == 'bwd':
+                 xformers_o = xformers_fn()
+                 xformers_do = torch.randn_like(xformers_o)
+                 xformers_fn = lambda: xformers_o.backward(xformers_do, retain_graph=True)
+             _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xformers_fn, n_warmup=10, n_repeat=10,
+                                                                   quantiles=quantiles)
+         else:
+             _, min_ms, max_ms, mean, cv = float('nan'), float('nan'), float('nan'), float('nan'), float('nan')

    else:
        raise NotImplementedError(f'Unsupported provider {provider}')

-     tflops = lambda mean: 2 * 2 * Z * H * N_CTX * N_CTX * D_HEAD * (1e-12) / (mean * 1e-3)
-     gbps = lambda mean: Z * H * (N_CTX * D_HEAD + N_CTX * D_HEAD) * 2 * 2 * (1e-9) / (mean * 1e-3)
+     qk_flops = H_q * N_CTX_q * N_CTX_kv * D_HEAD_qk
+     pv_flops = H_q * N_CTX_q * D_HEAD_v * N_CTX_kv
+     tflops = lambda mean: Z * (qk_flops + pv_flops) * (1e-12) / (mean * 1e-3)
+
+     q_elems = H_q * N_CTX_q * D_HEAD_qk
+     k_elems = H_kv * N_CTX_kv * D_HEAD_qk
+     v_elems = H_kv * N_CTX_kv * D_HEAD_v
+     gbps = lambda mean: Z * (q_elems + k_elems + v_elems) * 2 * (1e-9) / (mean * 1e-3)  # float16 2 bytes

    if MODE == 'bwd':
-         tflops = lambda mean: 2.5 * 2 * 2 * Z * H * N_CTX * N_CTX * D_HEAD * (1e-12) / (mean * 1e-3)
-         gbps = lambda mean: 2.5 * Z * H * (N_CTX * D_HEAD + N_CTX * D_HEAD) * 2 * 2 * (1e-9) / (mean * 1e-3)
+         tflops = lambda mean: 2.5 * 2 * 2 * Z * H_q * N_CTX_q * N_CTX_kv * D_HEAD_qk * (1e-12) / (mean * 1e-3)
+         gbps = lambda mean: 2.5 * Z * H_q * (N_CTX_q * D_HEAD_qk + N_CTX_kv * D_HEAD_qk) * 2 * 2 * (1e-9) / (mean * 1e-3)

    return (gbps(mean), gbps(max_ms), gbps(min_ms)), (tflops(mean), tflops(max_ms), tflops(min_ms)), cv
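
The new reporting lambdas above estimate compute from the two attention matmuls (Q·Kᵀ and P·V) and memory traffic from reading Q, K and V once in float16. As a quick sanity check, here are those same formulas worked out for the Llama-3.1-8B prefill row from x_vals (Z=1, H_q=32, H_kv=8, N_CTX_q=N_CTX_kv=1024, D_HEAD_qk=D_HEAD_v=128), with a purely hypothetical mean latency of 1 ms plugged in:

# Llama-3.1-8B prefill row from x_vals: [1, 32, 8, 1024, 1024, 128, 128, 'fwd']
Z, H_q, H_kv = 1, 32, 8
N_CTX_q = N_CTX_kv = 1024
D_HEAD_qk = D_HEAD_v = 128

qk_flops = H_q * N_CTX_q * N_CTX_kv * D_HEAD_qk  # Q·Kᵀ score-matrix work per batch element
pv_flops = H_q * N_CTX_q * D_HEAD_v * N_CTX_kv   # P·V output-matrix work per batch element

q_elems = H_q * N_CTX_q * D_HEAD_qk
k_elems = H_kv * N_CTX_kv * D_HEAD_qk
v_elems = H_kv * N_CTX_kv * D_HEAD_v

mean_ms = 1.0  # hypothetical latency, only to exercise the formulas
tflops = Z * (qk_flops + pv_flops) * 1e-12 / (mean_ms * 1e-3)
gbps = Z * (q_elems + k_elems + v_elems) * 2 * 1e-9 / (mean_ms * 1e-3)  # 2 bytes per float16 element

print(f'{tflops:.2f} TFLOP/s, {gbps:.2f} GB/s')  # 8.59 TFLOP/s, 12.58 GB/s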
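
For context on what the 'triton' provider path exercises: it builds a block-sparse causal mask with create_block_mask, calls the compiled flex_attention with that BlockMask, and sets enable_gqa whenever the query and key/value head counts differ. Below is a minimal standalone sketch of that call pattern; it assumes PyTorch 2.5+ and a CUDA device (the benchmark itself targets 'xpu'), and the shapes are illustrative rather than taken from the x_vals list.

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention


def causal_mask(_, __, q_idx, kv_idx):
    # Keep key/value positions that do not lie ahead of the query position.
    return q_idx >= kv_idx


# Illustrative grouped-query shapes: 8 query heads share 2 KV heads.
Z, H_q, H_kv, N_CTX, D_HEAD = 1, 8, 2, 1024, 128
device = 'cuda'

q = torch.randn(Z, H_q, N_CTX, D_HEAD, device=device, dtype=torch.float16)
k = torch.randn(Z, H_kv, N_CTX, D_HEAD, device=device, dtype=torch.float16)
v = torch.randn(Z, H_kv, N_CTX, D_HEAD, device=device, dtype=torch.float16)

# Only tiles that contain at least one allowed (q_idx, kv_idx) pair are ever visited.
block_mask = create_block_mask(causal_mask, 1, 1, N_CTX, N_CTX, device=device)

compiled = torch.compile(flex_attention, dynamic=False)
out = compiled(q, k, v, block_mask=block_mask, scale=0.125, enable_gqa=(H_q != H_kv))
print(out.shape)  # torch.Size([1, 8, 1024, 128])

The 'onednn' path in the diff uses create_mask instead, which materializes the same causal predicate as a dense mask suitable for F.scaled_dot_product_attention.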