
Commit 03ae588

[FlashAttention] Remove XeTLA for fwd mode

Signed-off-by: Whitney Tsang <[email protected]>
Parent: 8b73e5a

3 files changed: +14 −57 lines

.github/workflows/triton-benchmarks.yml

Lines changed: 0 additions & 2 deletions
@@ -276,7 +276,6 @@ jobs:
 
           source ../../scripts/capture-hw-details.sh
           python build_report.py $REPORTS/attn-performance.csv $REPORTS/attn-triton-report.csv --benchmark flash-attn --compiler triton --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
-          python build_report.py $REPORTS/attn-performance.csv $REPORTS/attn-xetla-report.csv --benchmark flash-attn --compiler xetla --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col XeTLA-TFlops --hbm_col "XeTLA-GB/s" --tag $TAG
           python build_report.py $REPORTS/attn-performance.csv $REPORTS/attn-cutlass-report.csv --benchmark flash-attn --compiler cutlass --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col CUTLASS-TFlops --hbm_col "CUTLASS-GB/s" --tag $TAG
 
       - name: Run Triton FA bwd kernel benchmark

@@ -302,7 +301,6 @@ jobs:
 
           source ../../scripts/capture-hw-details.sh
           python build_report.py $REPORTS/attn-tensor-desc-performance.csv $REPORTS/attn-tensor-desc-triton-report.csv --benchmark flash-attn-tensor-desc --compiler triton --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
-          python build_report.py $REPORTS/attn-tensor-desc-performance.csv $REPORTS/attn-tensor-desc-xetla-report.csv --benchmark flash-attn-tensor-desc --compiler xetla --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col XeTLA-TFlops --hbm_col "XeTLA-GB/s" --tag $TAG
           python build_report.py $REPORTS/attn-tensor-desc-performance.csv $REPORTS/attn-tensor-desc-cutlass-report.csv --benchmark flash-attn-tensor-desc --compiler cutlass --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col CUTLASS-TFlops --hbm_col "CUTLASS-GB/s" --tag $TAG
 
       - name: Run Triton FlexAttention Causal Mask fwd kernel benchmark
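For ad-hoc inspection of the same CSV these steps feed to build_report.py, a rough pandas sketch along these lines could be used. It assumes the names passed via --param_cols/--tflops_col above are literal column headers in the CSV; the file path is hypothetical.

    # Hypothetical local inspection of the fwd flash-attn results; column names
    # are taken from the build_report.py flags above and assumed to be CSV headers.
    import pandas as pd

    df = pd.read_csv('reports/attn-performance.csv')  # path is an assumption
    param_cols = ['Z', 'H', 'N_CTX', 'D_HEAD', 'CAUSAL']
    # After this change only the Triton and CUTLASS reports are built for fwd mode.
    print(df[param_cols + ['Triton-TFlops', 'CUTLASS-TFlops']].head())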

benchmarks/triton_kernels_benchmark/flash_attention_benchmark.py

Lines changed: 13 additions & 46 deletions
@@ -546,8 +546,6 @@ def get_benchmark(
         providers_filter: Optional[list[str]] = None,
         fa_kernel_mode='fwd',
         attn_fwd=_attn_fwd_with_block_pointers,
-        xetla_assert_result=False,
-        xetla_warn_mismatch=False,
 ):
     """
     Returns a Mark object containing a Benchmark object constructed at runtime and parameterized by the provided option values.
@@ -647,33 +645,6 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, MODE, provider):
             )
 
         elif provider == 'xetla':
-            xetla_fn = None
-            if MODE == 'fwd':
-                module_name = f'flash_attn_causal_{CAUSAL}'.lower()
-                func = getattr(xetla_kernel, module_name)
-                out = torch.empty_like(q, device='xpu', dtype=dtype)
-                size_score = Z * H * N_CTX * N_CTX
-                size_attn_mask = Z * N_CTX * N_CTX
-                dropout_mask = torch.empty((size_score, ), device='xpu', dtype=torch.uint8)
-                bias = torch.empty((size_attn_mask, ), device='xpu', dtype=dtype)
-                size_ml = Z * H * N_CTX
-                m = torch.empty((size_ml, ), device='xpu', dtype=torch.float)
-                l = torch.empty((size_ml, ), device='xpu', dtype=torch.float)
-
-                def xetla_fwd_fn():
-                    func(q, k, v, out, dropout_mask, bias, m, l, Z, H, D_HEAD, N_CTX, N_CTX, sm_scale)
-                    return out
-
-                xetla_fn = xetla_fwd_fn
-
-                def check_xetla_fwd_result():
-                    if xetla_assert_result:
-                        benchmark_suite.assert_close(xetla_fn, torch_fn, atol=atol, rtol=1e-3, err_msg='xetla to torch')
-                    elif xetla_warn_mismatch:
-                        check_close(xetla_fn, torch_fn, atol, 1e-3)
-
-                check_xetla_fwd_result()
-
             if MODE == 'bwd':
                 module_name = f'flash_attn_bwd_causal_{CAUSAL}'.lower()
                 func = getattr(xetla_kernel, module_name)
@@ -701,18 +672,20 @@ def xetla_bwd_fn():
                          bias_strideN, bias_strideF, attn_mask_padding)
                     return out
 
-                xetla_fn = xetla_bwd_fn
+                _, min_ms, max_ms, mean, cv = benchmark_suite.do_bench(
+                    xetla_bwd_fn,
+                    n_warmup=10,
+                    n_repeat=10,
+                    quantiles=quantiles,
+                )
 
-            _, min_ms, max_ms, mean, cv = benchmark_suite.do_bench(
-                xetla_fn,
-                n_warmup=10,
-                n_repeat=10,
-                quantiles=quantiles,
-            )
+            else:
+                min_ms = float('nan')
+                max_ms = float('nan')
+                mean = float('nan')
+                cv = float('nan')
 
         elif provider == 'cutlass':
-            cutlass_fn = None
-
             if MODE == 'fwd':
                 name = 'attention'
                 func = getattr(cutlass_kernel, name)
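The net effect of this hunk is that the xetla provider is now timed only in bwd mode, while fwd mode reports NaN placeholders so the result row keeps its shape. A small self-contained sketch of that dispatch pattern follows; do_bench_stub is a hypothetical stand-in for benchmark_suite.do_bench, not the library's implementation.

    import timeit

    def do_bench_stub(fn, n_warmup=10, n_repeat=10):
        # Warm up, then time n_repeat runs and report (min, max, mean) in ms.
        for _ in range(n_warmup):
            fn()
        times = [timeit.timeit(fn, number=1) * 1e3 for _ in range(n_repeat)]
        return min(times), max(times), sum(times) / len(times)

    def bench_xetla_like(mode):
        # Only 'bwd' is measured; 'fwd' now yields NaN placeholders.
        if mode == 'bwd':
            min_ms, max_ms, mean = do_bench_stub(lambda: sorted(range(1000)))
        else:
            min_ms = max_ms = mean = float('nan')
        return min_ms, max_ms, mean

    print(bench_xetla_like('bwd'))
    print(bench_xetla_like('fwd'))  # (nan, nan, nan)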
@@ -723,17 +696,15 @@ def cutlass_fwd_fn():
                     return out
 
                 benchmark_suite.assert_close(cutlass_fwd_fn, torch_fn, atol=atol, rtol=1e-3, err_msg='cutlass to torch')
-                cutlass_fn = cutlass_fwd_fn
 
                 _, min_ms, max_ms, mean, cv = benchmark_suite.do_bench(
-                    cutlass_fn,
+                    cutlass_fwd_fn,
                     n_warmup=10,
                     n_repeat=10,
                     quantiles=quantiles,
                 )
 
             else:
-                cutlass_fn = None
                 min_ms = float('nan')
                 max_ms = float('nan')
                 mean = float('nan')
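The cutlass forward path keeps its accuracy check against the torch reference via benchmark_suite.assert_close. As a rough illustration of that kind of callable-vs-callable tolerance check (not the benchmark_suite implementation), something like the following could be written:

    # Illustrative only: compare the outputs of two callables within tolerances,
    # similar in spirit to the assert_close call kept in the hunk above.
    import torch

    def assert_close(actual_fn, expected_fn, atol=1e-2, rtol=1e-3, err_msg=''):
        torch.testing.assert_close(actual_fn(), expected_fn(), atol=atol, rtol=rtol,
                                   msg=err_msg or None)

    q = torch.ones(8)
    assert_close(lambda: q * 2, lambda: q + q, err_msg='example to torch')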
@@ -755,9 +726,5 @@ def cutlass_fwd_fn():
 
 
 if __name__ == '__main__':
-    _benchmark = get_benchmark(
-        fa_kernel_mode=os.getenv('FA_KERNEL_MODE', 'fwd'),
-        xetla_assert_result=(os.getenv('XETLA_ASSERT_RESULT', '0') == '1'),
-        xetla_warn_mismatch=(os.getenv('XETLA_WARN_MISMATCH', '0') == '1'),
-    )
+    _benchmark = get_benchmark(fa_kernel_mode=os.getenv('FA_KERNEL_MODE', 'fwd'), )
     _benchmark.run(show_plots=False, print_data=True)
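With the __main__ block reduced to a single keyword argument, FA_KERNEL_MODE is the only environment variable the entry point still consults. A minimal usage sketch; the import path is an assumption based on the benchmarks/triton_kernels_benchmark/ directory layout and requires the benchmark environment to be installed.

    # Same effect as running the script with FA_KERNEL_MODE=bwd set;
    # XETLA_ASSERT_RESULT and XETLA_WARN_MISMATCH are no longer read.
    from triton_kernels_benchmark import flash_attention_benchmark  # assumed import path

    bench = flash_attention_benchmark.get_benchmark(fa_kernel_mode='bwd')
    bench.run(show_plots=False, print_data=True)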

benchmarks/triton_kernels_benchmark/flash_attention_tensor_desc_benchmark.py

Lines changed: 1 addition & 9 deletions
@@ -141,22 +141,14 @@ def _attn_fwd_with_tensor_desc(Q, K, V, sm_scale, M, Out, #
 def get_benchmark(
         providers_filter: Optional[list[str]] = None,
         fa_kernel_mode='fwd',
-        xetla_assert_result=False,
-        xetla_warn_mismatch=False,
 ):
     return flash_attention_benchmark.get_benchmark(
         providers_filter=providers_filter,
         fa_kernel_mode=fa_kernel_mode,
         attn_fwd=_attn_fwd_with_tensor_desc,
-        xetla_assert_result=xetla_assert_result,
-        xetla_warn_mismatch=xetla_warn_mismatch,
     )
 
 
 if __name__ == '__main__':
-    _benchmark = get_benchmark(
-        fa_kernel_mode=os.getenv('FA_KERNEL_MODE', 'fwd'),
-        xetla_assert_result=(os.getenv('XETLA_ASSERT_RESULT', '0') == '1'),
-        xetla_warn_mismatch=(os.getenv('XETLA_WARN_MISMATCH', '0') == '1'),
-    )
+    _benchmark = get_benchmark(fa_kernel_mode=os.getenv('FA_KERNEL_MODE', 'fwd'), )
     _benchmark.run(show_plots=False, print_data=True)
