@@ -546,8 +546,6 @@ def get_benchmark(
546
546
providers_filter : Optional [list [str ]] = None ,
547
547
fa_kernel_mode = 'fwd' ,
548
548
attn_fwd = _attn_fwd_with_block_pointers ,
549
- xetla_assert_result = False ,
550
- xetla_warn_mismatch = False ,
551
549
):
552
550
"""
553
551
Returns a Mark object containing a Benchmark object constructed at runtime and parameterized by the provided option values.
@@ -647,33 +645,6 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, MODE, provider):
647
645
)
648
646
649
647
elif provider == 'xetla' :
650
- xetla_fn = None
651
- if MODE == 'fwd' :
652
- module_name = f'flash_attn_causal_{ CAUSAL } ' .lower ()
653
- func = getattr (xetla_kernel , module_name )
654
- out = torch .empty_like (q , device = 'xpu' , dtype = dtype )
655
- size_score = Z * H * N_CTX * N_CTX
656
- size_attn_mask = Z * N_CTX * N_CTX
657
- dropout_mask = torch .empty ((size_score , ), device = 'xpu' , dtype = torch .uint8 )
658
- bias = torch .empty ((size_attn_mask , ), device = 'xpu' , dtype = dtype )
659
- size_ml = Z * H * N_CTX
660
- m = torch .empty ((size_ml , ), device = 'xpu' , dtype = torch .float )
661
- l = torch .empty ((size_ml , ), device = 'xpu' , dtype = torch .float )
662
-
663
- def xetla_fwd_fn ():
664
- func (q , k , v , out , dropout_mask , bias , m , l , Z , H , D_HEAD , N_CTX , N_CTX , sm_scale )
665
- return out
666
-
667
- xetla_fn = xetla_fwd_fn
668
-
669
- def check_xetla_fwd_result ():
670
- if xetla_assert_result :
671
- benchmark_suite .assert_close (xetla_fn , torch_fn , atol = atol , rtol = 1e-3 , err_msg = 'xetla to torch' )
672
- elif xetla_warn_mismatch :
673
- check_close (xetla_fn , torch_fn , atol , 1e-3 )
674
-
675
- check_xetla_fwd_result ()
676
-
677
648
if MODE == 'bwd' :
678
649
module_name = f'flash_attn_bwd_causal_{ CAUSAL } ' .lower ()
679
650
func = getattr (xetla_kernel , module_name )
@@ -701,18 +672,20 @@ def xetla_bwd_fn():
701
672
bias_strideN , bias_strideF , attn_mask_padding )
702
673
return out
703
674
704
- xetla_fn = xetla_bwd_fn
675
+ _ , min_ms , max_ms , mean , cv = benchmark_suite .do_bench (
676
+ xetla_bwd_fn ,
677
+ n_warmup = 10 ,
678
+ n_repeat = 10 ,
679
+ quantiles = quantiles ,
680
+ )
705
681
706
- _ , min_ms , max_ms , mean , cv = benchmark_suite .do_bench (
707
- xetla_fn ,
708
- n_warmup = 10 ,
709
- n_repeat = 10 ,
710
- quantiles = quantiles ,
711
- )
682
+ else :
683
+ min_ms = float ('nan' )
684
+ max_ms = float ('nan' )
685
+ mean = float ('nan' )
686
+ cv = float ('nan' )
712
687
713
688
elif provider == 'cutlass' :
714
- cutlass_fn = None
715
-
716
689
if MODE == 'fwd' :
717
690
name = 'attention'
718
691
func = getattr (cutlass_kernel , name )
@@ -723,17 +696,15 @@ def cutlass_fwd_fn():
723
696
return out
724
697
725
698
benchmark_suite .assert_close (cutlass_fwd_fn , torch_fn , atol = atol , rtol = 1e-3 , err_msg = 'cutlass to torch' )
726
- cutlass_fn = cutlass_fwd_fn
727
699
728
700
_ , min_ms , max_ms , mean , cv = benchmark_suite .do_bench (
729
- cutlass_fn ,
701
+ cutlass_fwd_fn ,
730
702
n_warmup = 10 ,
731
703
n_repeat = 10 ,
732
704
quantiles = quantiles ,
733
705
)
734
706
735
707
else :
736
- cutlass_fn = None
737
708
min_ms = float ('nan' )
738
709
max_ms = float ('nan' )
739
710
mean = float ('nan' )
@@ -755,9 +726,5 @@ def cutlass_fwd_fn():
755
726
756
727
757
728
if __name__ == '__main__' :
758
- _benchmark = get_benchmark (
759
- fa_kernel_mode = os .getenv ('FA_KERNEL_MODE' , 'fwd' ),
760
- xetla_assert_result = (os .getenv ('XETLA_ASSERT_RESULT' , '0' ) == '1' ),
761
- xetla_warn_mismatch = (os .getenv ('XETLA_WARN_MISMATCH' , '0' ) == '1' ),
762
- )
729
+ _benchmark = get_benchmark (fa_kernel_mode = os .getenv ('FA_KERNEL_MODE' , 'fwd' ), )
763
730
_benchmark .run (show_plots = False , print_data = True )
0 commit comments