)

import torch
-import torch.nn.functional as F
import triton_kernels_benchmark as benchmark_suit
from triton_kernels_benchmark import xetla_kernel

@@ -31,17 +30,40 @@ def causal_mask(_, __, q_idx, kv_idx):
# For details: https://github.com/pytorch/pytorch/issues/144778
@benchmark_suit.perf_report(
    benchmark_suit.Benchmark(
-        x_names=['Z', 'H', 'N_CTX', 'D_HEAD', 'CAUSAL', 'MODE'],
-        x_vals=[[z, h, 16384 // z, dhead, causal, mode]
-                for z in [1, 2, 4, 8, 16, 32]
-                for (h, dhead) in [(16, 128), (32, 64)]
-                for causal in [True]
-                for mode in [os.getenv('FA_KERNEL_MODE', 'fwd')]]  #
-        + [[4, 48, 1024, 64, True, mode] for mode in [os.getenv('FA_KERNEL_MODE', 'fwd')]]  #
-        + [[z, h, 1024, dhead, True, mode]
-           for z in [1, 2, 4, 8, 16, 32, 64]
-           for (h, dhead) in [(8, 128), (32, 96), (4, 128)]
-           for mode in [os.getenv('FA_KERNEL_MODE', 'fwd')]],
+        x_names=['B', 'H_q', 'H_kv', 'N_CTX_q', 'N_CTX_kv', 'D_HEAD', 'MODE'],
+        x_vals=
+        # Multi-head attention. H_q equals H_kv
+        [
+            # [z, h, h, n_ctx_q, n_ctx_kv, dhead, mode]
+            # for z in [1, 2, 4, 8, 16, 32]
+            # for h in [1, 2, 4, 8, 16, 32]
+            # for n_ctx_q in [128, 256, 512, 1024]
+            # for n_ctx_kv in [128, 256, 512, 1024]
+            # for dhead in [64, 128, 256]
+            # for mode in [os.getenv('FA_KERNEL_MODE', 'fwd')]
+        ] +
+        # Multi-query attention. H_kv equals 1.
+        [] +
+        # Grouped-query attention. H_q / H_kv > 1
+        [[z, h_q, h_kv, n_ctx_q, n_ctx_kv, dhead, mode]
+         for z in [1, 2, 4, 8, 16, 32]
+         for h_q, h_kv in [(2, 1), (4, 1), (4, 2), (8, 4), (16, 8)]
+         for n_ctx_q in [128, 256, 512, 1024]
+         for n_ctx_kv in [128, 256, 512, 1024]
+         for dhead in [64, 128, 256]
+         for mode in [os.getenv('FA_KERNEL_MODE', 'fwd')]] +
+        # FlexDecoding configuration. N_CTX_q equals 1. N_CTX_kv >= 1k
+        [],
+        # [[z, h, 16384 // z, dhead, causal, mode]
+        #  for z in [1, 2, 4, 8, 16, 32]
+        #  for (h, dhead) in [(16, 128), (32, 64)]
+        #  for causal in [True]
+        #  for mode in [os.getenv('FA_KERNEL_MODE', 'fwd')]]  #
+        # + [[4, 48, 1024, 64, True, mode] for mode in [os.getenv('FA_KERNEL_MODE', 'fwd')]]  #
+        # + [[z, h, 1024, dhead, True, mode]
+        #    for z in [1, 2, 4, 8, 16, 32, 64]
+        #    for (h, dhead) in [(8, 128), (32, 96), (4, 128)]
+        #    for mode in [os.getenv('FA_KERNEL_MODE', 'fwd')]],
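+        # Note: only the grouped-query sweep above is populated for now; the multi-head,
+        # multi-query and FlexDecoding entries are placeholders (commented out or empty lists).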
        line_arg='provider',
        line_vals=['triton', 'xetla'],
        line_names=['Triton', 'XeTLA'],
@@ -50,89 +72,83 @@ def causal_mask(_, __, q_idx, kv_idx):
        plot_name='flexAttnCausal-performance',
        args={},
    ))
-def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, MODE, provider):
+def benchmark(Z, H_q, H_kv, N_CTX_q, N_CTX_kv, D_HEAD, MODE, provider):
    assert MODE in ['fwd', 'bwd']
-    assert CAUSAL
    dtype = torch.float16
-    q = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype, requires_grad=True)
-    k = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype, requires_grad=True)
-    v = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype, requires_grad=True)
+    q = torch.randn((Z, H_q, N_CTX_q, D_HEAD), device='xpu', dtype=dtype, requires_grad=True)
+    k = torch.randn((Z, H_kv, N_CTX_kv, D_HEAD), device='xpu', dtype=dtype, requires_grad=True)
+    v = torch.randn((Z, H_kv, N_CTX_kv, D_HEAD), device='xpu', dtype=dtype, requires_grad=True)
    sm_scale = 0.125
    if MODE == 'bwd':
        sm_scale = 1.3

    quantiles = [0.5, 0.0, 1.0]
    if provider == 'triton':
        kernel_options = {'num_stages': 2, 'num_warps': 16 if D_HEAD == 128 else 8, 'BLOCKS_ARE_CONTIGUOUS': True}
-        block_mask = create_block_mask_cached(causal_mask, 1, 1, N_CTX, N_CTX, device=q.device)
-        triton_fn = lambda: flex_attention(q, k, v, block_mask=block_mask, scale=sm_scale, kernel_options=kernel_options
-                                           )
+        block_mask = create_block_mask_cached(causal_mask, 1, 1, N_CTX_q, N_CTX_kv, device='xpu')
+        triton_fn = lambda: flex_attention(q, k, v, block_mask=block_mask, scale=sm_scale, enable_gqa=(not H_q == H_kv),
+                                           kernel_options=kernel_options)
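+        # enable_gqa is set whenever H_q != H_kv, letting flex_attention broadcast the K/V heads
+        # across the query heads for the grouped-query configurations.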
        if MODE == 'bwd':
            triton_o = triton_fn()
            triton_do = torch.randn_like(triton_o)
            triton_fn = lambda: triton_o.backward(triton_do, retain_graph=True)
-        torch_fn = lambda: F.scaled_dot_product_attention(q.cpu(), k.cpu(), v.cpu(), is_causal=True, scale=sm_scale).to(
-            torch.float32)
-        if MODE == 'bwd':
-            torch_o = torch_fn()
-            torch_do = torch.randn_like(torch_o)
-            torch_fn = lambda: torch_o.backward(torch_do, retain_graph=True)
-        if MODE == 'fwd':
-            atol = 1e-1 if N_CTX == 16384 else 1e-2
-            benchmark_suit.assert_close(triton_fn, torch_fn, atol=atol, rtol=1e-3, err_msg='triton to torch')
-        else:
-            benchmark_suit.assert_close(lambda: triton_o, lambda: torch_o, atol=1e-2, rtol=0, err_msg='triton to torch')
        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)

    elif provider == 'xetla':
-        xetla_fn = None
-        if MODE == 'fwd':
-            module_name = 'flash_attn_causal_True'.lower()
-            func = getattr(xetla_kernel, module_name)
-            out = torch.empty_like(q, device='xpu', dtype=dtype)
-            size_score = Z * H * N_CTX * N_CTX
-            size_attn_mask = Z * N_CTX * N_CTX
-            dropout_mask = torch.empty((size_score, ), device='xpu', dtype=torch.uint8)
-            bias = torch.empty((size_attn_mask, ), device='xpu', dtype=dtype)
-            size_ml = Z * H * N_CTX
-            m = torch.empty((size_ml, ), device='xpu', dtype=torch.float)
-            l = torch.empty((size_ml, ), device='xpu', dtype=torch.float)
-            xetla_fn = lambda: func(q, k, v, out, dropout_mask, bias, m, l, Z, H, D_HEAD, N_CTX, N_CTX, sm_scale)
-        if MODE == 'bwd':
-            module_name = 'flash_attn_bwd_causal_True'.lower()
-            func = getattr(xetla_kernel, module_name)
-            grad_out = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True)
-            bias = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True)
-            dropout = torch.empty_like(q, device='xpu', dtype=torch.uint8)
-            out = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True)
-            log_sumexp = torch.zeros(q.size(), device='xpu', dtype=dtype, requires_grad=True)
-            workspace = torch.zeros(q.size(), device='xpu', dtype=dtype, requires_grad=True)
-            grad_q_tmp = torch.zeros(q.size(), device='xpu', dtype=dtype, requires_grad=True)
-            alpha = sm_scale
-            dropout_prob = 0
-            grad_query = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True)
-            grad_key = torch.empty_like(k, device='xpu', dtype=dtype, requires_grad=True)
-            grad_value = torch.empty_like(v, device='xpu', dtype=dtype, requires_grad=True)
-            grad_bias = torch.empty_like(bias, device='xpu', dtype=dtype, requires_grad=True)
-            bias_strideB = -1
-            bias_strideN = -1
-            bias_strideF = -1
-            attn_mask_padding = 0
-
-            xetla_fn = lambda: func(grad_out, q, k, v, bias, dropout, out, log_sumexp, workspace, grad_q_tmp, alpha,
-                                    dropout_prob, grad_query, grad_key, grad_value, grad_bias, Z, H, D_HEAD, N_CTX,
-                                    N_CTX, bias_strideB, bias_strideN, bias_strideF, attn_mask_padding)
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)
+        if (H_q == H_kv) and (N_CTX_q == N_CTX_kv):
+            xetla_fn = None
+            H = H_q
+            N_CTX = N_CTX_q
+            if MODE == 'fwd':
+                module_name = 'flash_attn_causal_True'.lower()
+                func = getattr(xetla_kernel, module_name)
+                out = torch.empty_like(q, device='xpu', dtype=dtype)
+                size_score = Z * H * N_CTX * N_CTX
+                size_attn_mask = Z * N_CTX * N_CTX
+                dropout_mask = torch.empty((size_score, ), device='xpu', dtype=torch.uint8)
+                bias = torch.empty((size_attn_mask, ), device='xpu', dtype=dtype)
+                size_ml = Z * H * N_CTX
+                m = torch.empty((size_ml, ), device='xpu', dtype=torch.float)
+                l = torch.empty((size_ml, ), device='xpu', dtype=torch.float)
+                xetla_fn = lambda: func(q, k, v, out, dropout_mask, bias, m, l, Z, H, D_HEAD, N_CTX, N_CTX, sm_scale)
+            if MODE == 'bwd':
+                module_name = 'flash_attn_bwd_causal_True'.lower()
+                func = getattr(xetla_kernel, module_name)
+                grad_out = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True)
+                bias = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True)
+                dropout = torch.empty_like(q, device='xpu', dtype=torch.uint8)
+                out = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True)
+                log_sumexp = torch.zeros(q.size(), device='xpu', dtype=dtype, requires_grad=True)
+                workspace = torch.zeros(q.size(), device='xpu', dtype=dtype, requires_grad=True)
+                grad_q_tmp = torch.zeros(q.size(), device='xpu', dtype=dtype, requires_grad=True)
+                alpha = sm_scale
+                dropout_prob = 0
+                grad_query = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True)
+                grad_key = torch.empty_like(k, device='xpu', dtype=dtype, requires_grad=True)
+                grad_value = torch.empty_like(v, device='xpu', dtype=dtype, requires_grad=True)
+                grad_bias = torch.empty_like(bias, device='xpu', dtype=dtype, requires_grad=True)
+                bias_strideB = -1
+                bias_strideN = -1
+                bias_strideF = -1
+                attn_mask_padding = 0
+
+                xetla_fn = lambda: func(grad_out, q, k, v, bias, dropout, out, log_sumexp, workspace, grad_q_tmp, alpha,
+                                        dropout_prob, grad_query, grad_key, grad_value, grad_bias, Z, H, D_HEAD, N_CTX,
+                                        N_CTX, bias_strideB, bias_strideN, bias_strideF, attn_mask_padding)
+            _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10,
+                                                                  quantiles=quantiles)
+        else:
+            _, min_ms, max_ms, mean, cv = float('nan'), float('nan'), float('nan'), float('nan'), float('nan')
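+            # The XeTLA flash-attention kernels only cover the standard multi-head case
+            # (H_q == H_kv with equal query/key-value lengths); other configs are reported as NaN.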

    else:
        raise NotImplementedError(f'Unsupported provider {provider}')

-    tflops = lambda mean: 2 * 2 * Z * H * N_CTX * N_CTX * D_HEAD * (1e-12) / (mean * 1e-3)
-    gbps = lambda mean: Z * H * (N_CTX * D_HEAD + N_CTX * D_HEAD) * 2 * 2 * (1e-9) / (mean * 1e-3)
+    tflops = lambda mean: 2 * 2 * Z * H_q * N_CTX_q * N_CTX_kv * D_HEAD * (1e-12) / (mean * 1e-3)
+    gbps = lambda mean: Z * H_q * (N_CTX_q * D_HEAD + N_CTX_kv * D_HEAD) * 2 * 2 * (1e-9) / (mean * 1e-3)
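+    # tflops counts the two attention matmuls (QK^T and PV) at 2 FLOPs per multiply-add over
+    # Z * H_q * N_CTX_q * N_CTX_kv * D_HEAD; gbps is a rough fp16 traffic estimate (2 bytes per element).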

    if MODE == 'bwd':
-        tflops = lambda mean: 2.5 * 2 * 2 * Z * H * N_CTX * N_CTX * D_HEAD * (1e-12) / (mean * 1e-3)
-        gbps = lambda mean: 2.5 * Z * H * (N_CTX * D_HEAD + N_CTX * D_HEAD) * 2 * 2 * (1e-9) / (mean * 1e-3)
+        tflops = lambda mean: 2.5 * 2 * 2 * Z * H_q * N_CTX_q * N_CTX_kv * D_HEAD * (1e-12) / (mean * 1e-3)
+        gbps = lambda mean: 2.5 * Z * H_q * (N_CTX_q * D_HEAD + N_CTX_kv * D_HEAD) * 2 * 2 * (1e-9) / (mean * 1e-3)
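+        # the 2.5x factor is the usual flash-attention estimate of backward cost relative to forward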

    return (gbps(mean), gbps(max_ms), gbps(min_ms)), (tflops(mean), tflops(max_ms), tflops(min_ms)), cv
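
A minimal standalone sketch of what one grouped-query configuration from the sweep above exercises, using stock PyTorch (>= 2.5) on CPU. The causal_mask body, the concrete sizes, and the use of create_block_mask instead of the benchmark's create_block_mask_cached helper are illustrative assumptions; the benchmark itself runs in float16 on 'xpu'.

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def causal_mask(_, __, q_idx, kv_idx):
    return q_idx >= kv_idx  # standard causal mask_mod

# one GQA config: H_q / H_kv = 2, e.g. [z, h_q, h_kv, n_ctx_q, n_ctx_kv, dhead] = [1, 4, 2, 256, 256, 64]
Z, H_q, H_kv, N_CTX_q, N_CTX_kv, D_HEAD = 1, 4, 2, 256, 256, 64
q = torch.randn(Z, H_q, N_CTX_q, D_HEAD)
k = torch.randn(Z, H_kv, N_CTX_kv, D_HEAD)
v = torch.randn(Z, H_kv, N_CTX_kv, D_HEAD)
block_mask = create_block_mask(causal_mask, 1, 1, N_CTX_q, N_CTX_kv, device='cpu')
out = flex_attention(q, k, v, block_mask=block_mask, scale=0.125, enable_gqa=True)
print(out.shape)  # torch.Size([1, 4, 256, 64])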