[FlashAttention] Remove XeTLA for fwd mode #4524

Draft · wants to merge 2 commits into main
2 changes: 0 additions & 2 deletions .github/workflows/triton-benchmarks.yml
@@ -276,7 +276,6 @@ jobs:

source ../../scripts/capture-hw-details.sh
python build_report.py $REPORTS/attn-performance.csv $REPORTS/attn-triton-report.csv --benchmark flash-attn --compiler triton --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
python build_report.py $REPORTS/attn-performance.csv $REPORTS/attn-xetla-report.csv --benchmark flash-attn --compiler xetla --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col XeTLA-TFlops --hbm_col "XeTLA-GB/s" --tag $TAG
python build_report.py $REPORTS/attn-performance.csv $REPORTS/attn-cutlass-report.csv --benchmark flash-attn --compiler cutlass --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col CUTLASS-TFlops --hbm_col "CUTLASS-GB/s" --tag $TAG

- name: Run Triton FA bwd kernel benchmark
@@ -302,7 +301,6 @@ jobs:

source ../../scripts/capture-hw-details.sh
python build_report.py $REPORTS/attn-tensor-desc-performance.csv $REPORTS/attn-tensor-desc-triton-report.csv --benchmark flash-attn-tensor-desc --compiler triton --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
python build_report.py $REPORTS/attn-tensor-desc-performance.csv $REPORTS/attn-tensor-desc-xetla-report.csv --benchmark flash-attn-tensor-desc --compiler xetla --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col XeTLA-TFlops --hbm_col "XeTLA-GB/s" --tag $TAG
python build_report.py $REPORTS/attn-tensor-desc-performance.csv $REPORTS/attn-tensor-desc-cutlass-report.csv --benchmark flash-attn-tensor-desc --compiler cutlass --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col CUTLASS-TFlops --hbm_col "CUTLASS-GB/s" --tag $TAG

- name: Run Triton FlexAttention Causal Mask fwd kernel benchmark
73 changes: 13 additions & 60 deletions benchmarks/triton_kernels_benchmark/flash_attention_benchmark.py
@@ -10,7 +10,6 @@
import triton_kernels_benchmark as benchmark_suite
from triton_kernels_benchmark import xetla_kernel
from triton_kernels_benchmark import cutlass_kernel
import numpy as np


# pylint: disable=unused-argument
@@ -529,25 +528,10 @@ def backward(ctx, do):
attention = _attention.apply


def check_close(f_val, f_ref, atol, rtol):
x = f_val()
y = f_ref()
x = x.cpu().detach().numpy()
y = y.cpu().detach().numpy()
close = np.isclose(x, y, atol=atol, rtol=rtol)
num_close = np.count_nonzero(close)
num_not_close = close.size - num_close
num_perc = num_not_close / close.size * 100
if num_not_close != 0:
print(f'Warning: {num_not_close}, out of {close.size} elements do not match ({num_perc:.2f}%) in XeTLA impl')


def get_benchmark(
providers_filter: Optional[list[str]] = None,
fa_kernel_mode='fwd',
attn_fwd=_attn_fwd_with_block_pointers,
xetla_assert_result=False,
xetla_warn_mismatch=False,
):
"""
Returns a Mark object containing a Benchmark object constructed at runtime and parameterized by the provided option values.
@@ -647,33 +631,6 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, MODE, provider):
)

elif provider == 'xetla':
xetla_fn = None
if MODE == 'fwd':
module_name = f'flash_attn_causal_{CAUSAL}'.lower()
func = getattr(xetla_kernel, module_name)
out = torch.empty_like(q, device='xpu', dtype=dtype)
size_score = Z * H * N_CTX * N_CTX
size_attn_mask = Z * N_CTX * N_CTX
dropout_mask = torch.empty((size_score, ), device='xpu', dtype=torch.uint8)
bias = torch.empty((size_attn_mask, ), device='xpu', dtype=dtype)
size_ml = Z * H * N_CTX
m = torch.empty((size_ml, ), device='xpu', dtype=torch.float)
l = torch.empty((size_ml, ), device='xpu', dtype=torch.float)

def xetla_fwd_fn():
func(q, k, v, out, dropout_mask, bias, m, l, Z, H, D_HEAD, N_CTX, N_CTX, sm_scale)
return out

xetla_fn = xetla_fwd_fn

def check_xetla_fwd_result():
if xetla_assert_result:
benchmark_suite.assert_close(xetla_fn, torch_fn, atol=atol, rtol=1e-3, err_msg='xetla to torch')
elif xetla_warn_mismatch:
check_close(xetla_fn, torch_fn, atol, 1e-3)

check_xetla_fwd_result()

if MODE == 'bwd':
module_name = f'flash_attn_bwd_causal_{CAUSAL}'.lower()
func = getattr(xetla_kernel, module_name)
@@ -701,18 +658,20 @@ def xetla_bwd_fn():
bias_strideN, bias_strideF, attn_mask_padding)
return out

xetla_fn = xetla_bwd_fn
_, min_ms, max_ms, mean, cv = benchmark_suite.do_bench(
xetla_bwd_fn,
n_warmup=10,
n_repeat=10,
quantiles=quantiles,
)

_, min_ms, max_ms, mean, cv = benchmark_suite.do_bench(
xetla_fn,
n_warmup=10,
n_repeat=10,
quantiles=quantiles,
)
else:
min_ms = float('nan')
max_ms = float('nan')
mean = float('nan')
cv = float('nan')

elif provider == 'cutlass':
cutlass_fn = None

if MODE == 'fwd':
name = 'attention'
func = getattr(cutlass_kernel, name)
@@ -723,17 +682,15 @@ def cutlass_fwd_fn():
return out

benchmark_suite.assert_close(cutlass_fwd_fn, torch_fn, atol=atol, rtol=1e-3, err_msg='cutlass to torch')
cutlass_fn = cutlass_fwd_fn

_, min_ms, max_ms, mean, cv = benchmark_suite.do_bench(
cutlass_fn,
cutlass_fwd_fn,
n_warmup=10,
n_repeat=10,
quantiles=quantiles,
)

else:
cutlass_fn = None
min_ms = float('nan')
max_ms = float('nan')
mean = float('nan')
@@ -755,9 +712,5 @@


if __name__ == '__main__':
_benchmark = get_benchmark(
fa_kernel_mode=os.getenv('FA_KERNEL_MODE', 'fwd'),
xetla_assert_result=(os.getenv('XETLA_ASSERT_RESULT', '0') == '1'),
xetla_warn_mismatch=(os.getenv('XETLA_WARN_MISMATCH', '0') == '1'),
)
_benchmark = get_benchmark(fa_kernel_mode=os.getenv('FA_KERNEL_MODE', 'fwd'), )
_benchmark.run(show_plots=False, print_data=True)
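
With the XeTLA forward path gone, FA_KERNEL_MODE is the only environment variable this script still reads; the XETLA_ASSERT_RESULT and XETLA_WARN_MISMATCH variables from the old __main__ block no longer exist. A minimal usage sketch, not part of this PR, assuming the benchmarks package is importable as triton_kernels_benchmark on an XPU-enabled Triton build:

```python
# Sketch: driving the benchmark programmatically after this change.
import os

from triton_kernels_benchmark import flash_attention_benchmark

# In 'fwd' mode (the default) the 'xetla' provider now reports NaN placeholders;
# XeTLA timings are collected only when FA_KERNEL_MODE is 'bwd'.
mode = os.getenv('FA_KERNEL_MODE', 'fwd')

benchmark = flash_attention_benchmark.get_benchmark(fa_kernel_mode=mode)
benchmark.run(show_plots=False, print_data=True)
```

The next diff applies the same cleanup to the tensor-descriptor variant of the benchmark, which forwards its arguments to flash_attention_benchmark.get_benchmark.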
@@ -141,22 +141,14 @@ def _attn_fwd_with_tensor_desc(Q, K, V, sm_scale, M, Out, #
def get_benchmark(
providers_filter: Optional[list[str]] = None,
fa_kernel_mode='fwd',
xetla_assert_result=False,
xetla_warn_mismatch=False,
):
return flash_attention_benchmark.get_benchmark(
providers_filter=providers_filter,
fa_kernel_mode=fa_kernel_mode,
attn_fwd=_attn_fwd_with_tensor_desc,
xetla_assert_result=xetla_assert_result,
xetla_warn_mismatch=xetla_warn_mismatch,
)


if __name__ == '__main__':
_benchmark = get_benchmark(
fa_kernel_mode=os.getenv('FA_KERNEL_MODE', 'fwd'),
xetla_assert_result=(os.getenv('XETLA_ASSERT_RESULT', '0') == '1'),
xetla_warn_mismatch=(os.getenv('XETLA_WARN_MISMATCH', '0') == '1'),
)
_benchmark = get_benchmark(fa_kernel_mode=os.getenv('FA_KERNEL_MODE', 'fwd'), )
_benchmark.run(show_plots=False, print_data=True)
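
Both get_benchmark entry points drop the XeTLA verification flags, so any downstream caller still passing them will fail with a TypeError. An illustrative, hedged check of the new signature; inspect is from the standard library and the import path is taken from the diff above:

```python
# Illustrative only: confirm the removed keyword arguments are no longer accepted.
import inspect

from triton_kernels_benchmark import flash_attention_benchmark

params = inspect.signature(flash_attention_benchmark.get_benchmark).parameters
for removed in ('xetla_assert_result', 'xetla_warn_mismatch'):
    # Passing either of these now raises TypeError at call time.
    assert removed not in params, f'{removed} should no longer be accepted'

print(sorted(params))  # expected: ['attn_fwd', 'fa_kernel_mode', 'providers_filter']
```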