diff --git a/.github/workflows/triton-benchmarks.yml b/.github/workflows/triton-benchmarks.yml
index e5358a97fa..a76646f7a3 100644
--- a/.github/workflows/triton-benchmarks.yml
+++ b/.github/workflows/triton-benchmarks.yml
@@ -276,7 +276,6 @@ jobs:
           source ../../scripts/capture-hw-details.sh
           python build_report.py $REPORTS/attn-performance.csv $REPORTS/attn-triton-report.csv --benchmark flash-attn --compiler triton --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
-          python build_report.py $REPORTS/attn-performance.csv $REPORTS/attn-xetla-report.csv --benchmark flash-attn --compiler xetla --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col XeTLA-TFlops --hbm_col "XeTLA-GB/s" --tag $TAG
           python build_report.py $REPORTS/attn-performance.csv $REPORTS/attn-cutlass-report.csv --benchmark flash-attn --compiler cutlass --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col CUTLASS-TFlops --hbm_col "CUTLASS-GB/s" --tag $TAG

       - name: Run Triton FA bwd kernel benchmark
@@ -302,7 +301,6 @@ jobs:
           source ../../scripts/capture-hw-details.sh
           python build_report.py $REPORTS/attn-tensor-desc-performance.csv $REPORTS/attn-tensor-desc-triton-report.csv --benchmark flash-attn-tensor-desc --compiler triton --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
-          python build_report.py $REPORTS/attn-tensor-desc-performance.csv $REPORTS/attn-tensor-desc-xetla-report.csv --benchmark flash-attn-tensor-desc --compiler xetla --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col XeTLA-TFlops --hbm_col "XeTLA-GB/s" --tag $TAG
           python build_report.py $REPORTS/attn-tensor-desc-performance.csv $REPORTS/attn-tensor-desc-cutlass-report.csv --benchmark flash-attn-tensor-desc --compiler cutlass --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col CUTLASS-TFlops --hbm_col "CUTLASS-GB/s" --tag $TAG

       - name: Run Triton FlexAttention Causal Mask fwd kernel benchmark
diff --git a/benchmarks/triton_kernels_benchmark/flash_attention_benchmark.py b/benchmarks/triton_kernels_benchmark/flash_attention_benchmark.py
index b096a154f9..52f02c6f3a 100644
--- a/benchmarks/triton_kernels_benchmark/flash_attention_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/flash_attention_benchmark.py
@@ -10,7 +10,6 @@
 import triton_kernels_benchmark as benchmark_suite
 from triton_kernels_benchmark import xetla_kernel
 from triton_kernels_benchmark import cutlass_kernel
-import numpy as np

 # pylint: disable=unused-argument

@@ -529,25 +528,10 @@ def backward(ctx, do):
 attention = _attention.apply


-def check_close(f_val, f_ref, atol, rtol):
-    x = f_val()
-    y = f_ref()
-    x = x.cpu().detach().numpy()
-    y = y.cpu().detach().numpy()
-    close = np.isclose(x, y, atol=atol, rtol=rtol)
-    num_close = np.count_nonzero(close)
-    num_not_close = close.size - num_close
-    num_perc = num_not_close / close.size * 100
-    if num_not_close != 0:
-        print(f'Warning: {num_not_close}, out of {close.size} elements do not match ({num_perc:.2f}%) in XeTLA impl')
-
-
 def get_benchmark(
     providers_filter: Optional[list[str]] = None,
     fa_kernel_mode='fwd',
     attn_fwd=_attn_fwd_with_block_pointers,
-    xetla_assert_result=False,
-    xetla_warn_mismatch=False,
 ):
     """
     Returns a Mark object containing a Benchmark object constructed at runtime and parameterized by the provided option values.
@@ -647,33 +631,6 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, MODE, provider):
             )

         elif provider == 'xetla':
-            xetla_fn = None
-            if MODE == 'fwd':
-                module_name = f'flash_attn_causal_{CAUSAL}'.lower()
-                func = getattr(xetla_kernel, module_name)
-                out = torch.empty_like(q, device='xpu', dtype=dtype)
-                size_score = Z * H * N_CTX * N_CTX
-                size_attn_mask = Z * N_CTX * N_CTX
-                dropout_mask = torch.empty((size_score, ), device='xpu', dtype=torch.uint8)
-                bias = torch.empty((size_attn_mask, ), device='xpu', dtype=dtype)
-                size_ml = Z * H * N_CTX
-                m = torch.empty((size_ml, ), device='xpu', dtype=torch.float)
-                l = torch.empty((size_ml, ), device='xpu', dtype=torch.float)
-
-                def xetla_fwd_fn():
-                    func(q, k, v, out, dropout_mask, bias, m, l, Z, H, D_HEAD, N_CTX, N_CTX, sm_scale)
-                    return out
-
-                xetla_fn = xetla_fwd_fn
-
-                def check_xetla_fwd_result():
-                    if xetla_assert_result:
-                        benchmark_suite.assert_close(xetla_fn, torch_fn, atol=atol, rtol=1e-3, err_msg='xetla to torch')
-                    elif xetla_warn_mismatch:
-                        check_close(xetla_fn, torch_fn, atol, 1e-3)
-
-                check_xetla_fwd_result()
-
             if MODE == 'bwd':
                 module_name = f'flash_attn_bwd_causal_{CAUSAL}'.lower()
                 func = getattr(xetla_kernel, module_name)
@@ -701,18 +658,20 @@ def xetla_bwd_fn():
                          bias_strideN, bias_strideF, attn_mask_padding)
                     return out

-                xetla_fn = xetla_bwd_fn
+                _, min_ms, max_ms, mean, cv = benchmark_suite.do_bench(
+                    xetla_bwd_fn,
+                    n_warmup=10,
+                    n_repeat=10,
+                    quantiles=quantiles,
+                )

-            _, min_ms, max_ms, mean, cv = benchmark_suite.do_bench(
-                xetla_fn,
-                n_warmup=10,
-                n_repeat=10,
-                quantiles=quantiles,
-            )
+            else:
+                min_ms = float('nan')
+                max_ms = float('nan')
+                mean = float('nan')
+                cv = float('nan')

         elif provider == 'cutlass':
-            cutlass_fn = None
-
             if MODE == 'fwd':
                 name = 'attention'
                 func = getattr(cutlass_kernel, name)
@@ -723,17 +682,15 @@ def cutlass_fwd_fn():
                     return out

                 benchmark_suite.assert_close(cutlass_fwd_fn, torch_fn, atol=atol, rtol=1e-3, err_msg='cutlass to torch')
-                cutlass_fn = cutlass_fwd_fn

                 _, min_ms, max_ms, mean, cv = benchmark_suite.do_bench(
-                    cutlass_fn,
+                    cutlass_fwd_fn,
                     n_warmup=10,
                     n_repeat=10,
                     quantiles=quantiles,
                 )

             else:
-                cutlass_fn = None
                 min_ms = float('nan')
                 max_ms = float('nan')
                 mean = float('nan')
@@ -755,9 +712,5 @@ def cutlass_fwd_fn():


 if __name__ == '__main__':
-    _benchmark = get_benchmark(
-        fa_kernel_mode=os.getenv('FA_KERNEL_MODE', 'fwd'),
-        xetla_assert_result=(os.getenv('XETLA_ASSERT_RESULT', '0') == '1'),
-        xetla_warn_mismatch=(os.getenv('XETLA_WARN_MISMATCH', '0') == '1'),
-    )
+    _benchmark = get_benchmark(fa_kernel_mode=os.getenv('FA_KERNEL_MODE', 'fwd'), )
     _benchmark.run(show_plots=False, print_data=True)
diff --git a/benchmarks/triton_kernels_benchmark/flash_attention_tensor_desc_benchmark.py b/benchmarks/triton_kernels_benchmark/flash_attention_tensor_desc_benchmark.py
index c3d4639418..99c409b532 100644
--- a/benchmarks/triton_kernels_benchmark/flash_attention_tensor_desc_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/flash_attention_tensor_desc_benchmark.py
@@ -141,22 +141,14 @@ def _attn_fwd_with_tensor_desc(Q, K, V, sm_scale, M, Out,  #
 def get_benchmark(
     providers_filter: Optional[list[str]] = None,
     fa_kernel_mode='fwd',
-    xetla_assert_result=False,
-    xetla_warn_mismatch=False,
 ):
     return flash_attention_benchmark.get_benchmark(
         providers_filter=providers_filter,
         fa_kernel_mode=fa_kernel_mode,
         attn_fwd=_attn_fwd_with_tensor_desc,
-        xetla_assert_result=xetla_assert_result,
-        xetla_warn_mismatch=xetla_warn_mismatch,
     )


 if __name__ == '__main__':
-    _benchmark = get_benchmark(
-        fa_kernel_mode=os.getenv('FA_KERNEL_MODE', 'fwd'),
-        xetla_assert_result=(os.getenv('XETLA_ASSERT_RESULT', '0') == '1'),
-        xetla_warn_mismatch=(os.getenv('XETLA_WARN_MISMATCH', '0') == '1'),
-    )
+    _benchmark = get_benchmark(fa_kernel_mode=os.getenv('FA_KERNEL_MODE', 'fwd'), )
     _benchmark.run(show_plots=False, print_data=True)
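For reference, a minimal usage sketch (not part of the diff) of how the trimmed-down entry point is driven after this change: only FA_KERNEL_MODE remains; the XETLA_ASSERT_RESULT and XETLA_WARN_MISMATCH switches no longer exist. The call shapes are taken from the __main__ blocks above; the package-style import path is an assumption for illustration.

# Hypothetical driver sketch; assumes the benchmark module is importable as a
# package submodule (import path is an assumption, not confirmed by the diff).
import os

from triton_kernels_benchmark import flash_attention_benchmark

# Forward vs. backward flash-attention kernels are still selected via the
# remaining environment variable, defaulting to 'fwd'.
os.environ.setdefault('FA_KERNEL_MODE', 'fwd')

benchmark = flash_attention_benchmark.get_benchmark(
    fa_kernel_mode=os.getenv('FA_KERNEL_MODE', 'fwd'),
)
benchmark.run(show_plots=False, print_data=True)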