[PyTorch] Add caching for attention backend selection results #1381

Draft · wants to merge 4 commits into base: main · Changes from 1 commit
WIP: add caching for backend selection results
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
cyanguwa committed Dec 19, 2024
commit c8141cb176662815b5f3f9fd47ff947eb40f8b3e
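
For orientation, the idea of this draft is to memoize backend selection per attention configuration instead of recomputing it on every forward call. Below is a minimal sketch of such a bounded cache; the class and function names (BackendCache, max_entries, select_backends) are illustrative placeholders, not the identifiers used in this PR, which keeps the cache in the module-level _attention_backends dict with a cap of _max_configs = 10 entries.

# Minimal sketch of a bounded, FIFO-evicting cache for backend selection results,
# keyed on the attention parameters. Illustrative names only.
class BackendCache:
    def __init__(self, max_entries=10):
        self.max_entries = max_entries
        self.params = []    # previously seen attention-parameter objects
        self.results = []   # (available_backends, selected_backend) per cached entry

    def get_or_compute(self, attention_params, select_backends):
        if attention_params in self.params:
            # Cache hit: reuse the earlier selection result.
            return self.results[self.params.index(attention_params)]
        result = select_backends(attention_params)
        self.params.append(attention_params)
        self.results.append(result)
        if len(self.params) > self.max_entries:
            # Evict the oldest entry to keep the cache bounded.
            self.params.pop(0)
            self.results.pop(0)
        return result
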
125 changes: 68 additions & 57 deletions tests/pytorch/fused_attn/test_fused_attn.py
@@ -135,7 +135,6 @@ def _get_attention_backends(
os.environ["NVTE_FLASH_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_UNFUSED_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True

alibi_slopes_shape = None
if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
@@ -156,53 +155,40 @@ def _get_attention_backends(
):
core_attention_bias_requires_grad = True

fused_attn_backends = []
available_backends = None
fused_attention_backend = None

def test():
attention_params = AttentionParams(
qkv_dtype=qkv_dtype,
qkv_layout=qkv_layout,
batch_size=config.batch_size,
num_heads=config.num_heads,
num_gqa_groups=config.num_gqa_groups,
max_seqlen_q=config.max_seqlen_q,
max_seqlen_kv=config.max_seqlen_kv,
head_dim_qk=config.head_dim_qk,
head_dim_v=config.head_dim_v,
attn_mask_type=config.attn_mask_type,
window_size=window_size,
alibi_slopes_shape=alibi_slopes_shape,
core_attention_bias_type=config.attn_bias_type,
core_attention_bias_shape=core_attention_bias_shape,
core_attention_bias_requires_grad=core_attention_bias_requires_grad,
pad_between_seqs=pad_between_seqs,
attention_dropout=config.dropout_p,
context_parallel=context_parallel,
deterministic=deterministic,
fp8=fp8,
fp8_meta=fp8_meta,
)
_, _, fused_attention_backend, _, available_backends = get_attention_backend(
attention_params
)
return available_backends, fused_attention_backend

backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"}
with logging_context():
for i in range(3):
os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i)
_attention_backends["backend_selection_requires_update"] = True
available_backends, fused_attention_backend = test()
if fused_attention_backend == FusedAttnBackend[backends[i]]:
fused_attn_backends.append(fused_attention_backend)
return available_backends, fused_attn_backends
attention_params = AttentionParams(
qkv_dtype=qkv_dtype,
qkv_layout=qkv_layout,
batch_size=config.batch_size,
num_heads=config.num_heads,
num_gqa_groups=config.num_gqa_groups,
max_seqlen_q=config.max_seqlen_q,
max_seqlen_kv=config.max_seqlen_kv,
head_dim_qk=config.head_dim_qk,
head_dim_v=config.head_dim_v,
attn_mask_type=config.attn_mask_type,
window_size=window_size,
alibi_slopes_shape=alibi_slopes_shape,
core_attention_bias_type=config.attn_bias_type,
core_attention_bias_shape=core_attention_bias_shape,
core_attention_bias_requires_grad=core_attention_bias_requires_grad,
pad_between_seqs=pad_between_seqs,
attention_dropout=config.dropout_p,
context_parallel=context_parallel,
deterministic=deterministic,
fp8=fp8,
fp8_meta=fp8_meta,
)
_attention_backends["update_selection"] = attention_params not in _attention_backends["attention_params"]
_attention_backends["update_env_vars_only"] = False
available_backends, _ = get_attention_backend(attention_params)
return available_backends


model_configs_base = {
# test: b, h, hg, d, sq, skv, p, mask, bias # attn , backend
"base_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"), # self , 0
#"base_1_1": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"), # self , 0
#"base_2_1": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"), # self , 0
"base_1_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"), # cross, 0
"base_2_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), # self , 1
"base_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"), # cross, 1
@@ -222,7 +208,7 @@ def test():
@pytest.mark.parametrize("model_configs", [model_configs_base])
@pytest.mark.parametrize("model", model_configs_base.keys())
@pytest.mark.parametrize("ckpt_attn", [False])
@pytest.mark.parametrize("workspace_opt", [True, False])
@pytest.mark.parametrize("workspace_opt", [True])#, False])
@pytest.mark.parametrize("qkv_layout", [None])
@pytest.mark.parametrize("swa", [False])
@pytest.mark.parametrize("pad_between_seqs", [False])
@@ -250,14 +236,15 @@ def test_dot_product_attention(
if swa:
window_size = [2, 2]
config.window_size = check_set_window_size(config.attn_mask_type, window_size)
available_backends, fused_attn_backends = _get_attention_backends(
available_backends = _get_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
window_size=config.window_size,
pad_between_seqs=pad_between_seqs,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
flash_attn_supported, fused_attn_supported, unfused_attn_supported, fused_attn_backends = available_backends
print('----',flash_attn_supported, fused_attn_supported, unfused_attn_supported)
# FlashAttention does not support pad_between_seqs, but _run_dot_product_attention
# manually pads and unpads the input and output of FlashAttention for testing purposes
if pad_between_seqs and not (
@@ -679,12 +666,16 @@ def _run_dot_product_attention(
reset_rng_states()
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
os.environ["NVTE_UNFUSED_ATTN"] = "0"
if backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1" if workspace_opt else "0"
_attention_backends["backend_selection_requires_update"] = True
if backend == "UnfusedDotProductAttention":
os.environ["NVTE_UNFUSED_ATTN"] = "1"
#_attention_backends["backend_selection_requires_update"] = True
_attention_backends["update_env_vars_only"] = True

# Create seqlens
qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
@@ -1045,12 +1036,12 @@ def test_transformer_layer(
workspace_opt = True

# Test backend availability
available_backends, fused_attn_backends = _get_attention_backends(
available_backends = _get_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout="sbh3d" if fused_qkv_params else "sb3hd",
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
flash_attn_supported, fused_attn_supported, unfused_attn_supported, fused_attn_backends = available_backends

# Skip if only unfused backend is supported
if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
@@ -1168,11 +1159,15 @@ def _run_transformer_layer(
reset_rng_states()
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
os.environ["NVTE_UNFUSED_ATTN"] = "0"
if backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True
if backend == "UnfusedDotProductAttention":
os.environ["NVTE_UNFUSED_ATTN"] = "1"
#_attention_backends["backend_selection_requires_update"] = True
_attention_backends["update_env_vars_only"] = True

# Create input tensor
inp = torch.randn(
@@ -1362,15 +1357,19 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
if _flash_attn_3_is_installed and not is_training:
os.environ["NVTE_FLASH_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
os.environ["NVTE_UNFUSED_ATTN"] = "0"
#_attention_backends["backend_selection_requires_update"] = True
_attention_backends["update_selection"] = True
logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
flash_attn_fwd_fp8, param_names, flash_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
dtype, config, True, qkv_format, input_layernorm, RoPE, is_training
)

os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True
os.environ["NVTE_UNFUSED_ATTN"] = "0"
#_attention_backends["backend_selection_requires_update"] = True
_attention_backends["update_selection"] = True
logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
dtype, config, True, qkv_format, input_layernorm, RoPE, is_training
@@ -1541,15 +1540,19 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
if _flash_attn_3_is_installed and not is_training:
os.environ["NVTE_FLASH_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
os.environ["NVTE_UNFUSED_ATTN"] = "0"
#_attention_backends["backend_selection_requires_update"] = True
_attention_backends["update_selection"] = True
logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True")
flash_attn_fwd_fp8, flash_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
dtype, config, True, qkv_layout, is_training
)

os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True
os.environ["NVTE_UNFUSED_ATTN"] = "0"
#_attention_backends["backend_selection_requires_update"] = True
_attention_backends["update_selection"] = True
logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True")
fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
dtype, config, True, qkv_layout, is_training
@@ -1753,7 +1756,7 @@ def test_custom_mha_fp8_vs_f16(dtype, model):
config = model_configs_fp8[model]

fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_custom_mha_fp8(dtype, config, "FusedAttention")
unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16(dtype, config, "UnfusedAttention")
unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16(dtype, config, "UnfusedDotProductAttention")

atol = 5e-1
rtol = 5e-1
@@ -1784,11 +1787,15 @@ def _run_custom_mha_fp8(dtype, config, backend):
reset_rng_states()
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
os.environ["NVTE_UNFUSED_ATTN"] = "0"
if backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True
if backend == "UnfusedDotProductAttention":
os.environ["NVTE_UNFUSED_ATTN"] = "1"
#_attention_backends["backend_selection_requires_update"] = True
_attention_backends["update_selection"] = True

inp = 0.0001 * torch.randint(
-100,
@@ -1838,11 +1845,15 @@ def _run_ref_mha_f16(dtype, config, backend):

os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
os.environ["NVTE_UNFUSED_ATTN"] = "0"
if backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True
if backend == "UnfusedDotProductAttention":
os.environ["NVTE_UNFUSED_ATTN"] = "1"
#_attention_backends["backend_selection_requires_update"] = True
_attention_backends["update_selection"] = True

inp = torch.load("qkv.pt").to(device="cuda")
inp.requires_grad = True
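
Taken together, the test changes above follow one pattern when switching backends through environment variables: toggle the NVTE_* flags, then tell the cache how much of the previous selection can be reused. A hedged sketch of that pattern follows, assuming get_attention_backend, _attention_backends, and AttentionParams are importable from transformer_engine.pytorch.attention as in this diff and that AttentionParams can be constructed with its defaults.

import os
# Assumed imports, mirroring the modules touched in this diff.
from transformer_engine.pytorch.attention import (
    AttentionParams,
    get_attention_backend,
    _attention_backends,
)

params = AttentionParams()  # placeholder config; the tests build this from ModelConfig

# Force the fused backend via environment variables.
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_UNFUSED_ATTN"] = "0"

# Only env vars changed: reuse the cached availability results and recompute
# just the selected backend.
_attention_backends["update_env_vars_only"] = True
available_backends, selected_backend = get_attention_backend(params)

# A full re-selection (e.g. after setting NVTE_FUSED_ATTN_BACKEND) instead sets:
# _attention_backends["update_selection"] = True
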
225 changes: 148 additions & 77 deletions transformer_engine/pytorch/attention.py
@@ -206,13 +206,20 @@ def _get_supported_versions(version_min, version_max):
_flash_attn_3_0_0_beta = PkgVersion("3.0.0b") < _flash_attn_3_version < PkgVersion("3.0.0")
_use_flash_attn_3 = True

# maximum number of configs allowed in _attention_backends cache
_max_configs = 10
# attention config cache
_attention_backends = {
"attention_params": None,
"use_flash_attention": None,
"use_fused_attention": None,
"fused_attention_backend": None,
"use_unfused_attention": None,
"backend_selection_requires_update": False,
# list of configs [attention_params]
"attention_params": [],
# available backends for each cached config
"available_backends": [],
# last used backend for each cached config
"selected_backend": [],
# update selected_backend if only environment variables, e.g. NVTE_FUSED_ATTN, have changed
"update_env_vars_only": False,
# update both available_backends and selected_backend
"update_selection": True,
}


@@ -366,7 +373,7 @@ def get_attention_backend(
fp8 = attention_params.fp8
fp8_meta = attention_params.fp8_meta

# Run config
# Print config
logger = logging.getLogger("DotProductAttention")
logger.setLevel(_log_level)
if not logger.hasHandlers():
@@ -393,22 +400,78 @@ def get_attention_backend(
run_config["NVTE_FP8_DPA_BWD"] = int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
logger.debug("Running with config=%s", run_config)

# check if attention_params already exists in cache
global _attention_backends
#print('entering .....',
# attention_params in _attention_backends["attention_params"],
# _attention_backends["update_env_vars_only"],
# _attention_backends["update_selection"])
#print(_attention_backends["attention_params"])
#print(attention_params)
if (
attention_params in _attention_backends["attention_params"]
and not _attention_backends["update_env_vars_only"]
and not _attention_backends["update_selection"]
):
config_id = _attention_backends["attention_params"].index(attention_params)
available_backends = _attention_backends["available_backends"][config_id]
selected_backend = _attention_backends["selected_backend"][config_id]
print('retrieving', config_id)  #, available_backends, selected_backend)
return available_backends, selected_backend

# recompute the selection if environment variables such as NVTE_FUSED_ATTN change in the middle of the run
def update_env_vars(logger, available_backends):
# Filter: Environment variables
use_flash_attention, use_fused_attention, use_unfused_attention, fused_attention_backends = available_backends
_use_flash_attention = int(os.getenv("NVTE_FLASH_ATTN", "1"))
_use_fused_attention = int(os.getenv("NVTE_FUSED_ATTN", "1"))
_use_unfused_attention = int(os.getenv("NVTE_UNFUSED_ATTN", "1"))
if use_flash_attention and not _use_flash_attention and _flash_attn_is_installed:
logger.debug("Disabling FlashAttention due to NVTE_FLASH_ATTN=0")
if use_fused_attention and not _use_fused_attention:
logger.debug("Disabling FusedAttention due to NVTE_FUSED_ATTN=0")
if use_unfused_attention and not _use_unfused_attention:
logger.debug("Disabling UnfusedDotProductAttention due to NVTE_UNFUSED_ATTN=0")
use_flash_attention = use_flash_attention and _use_flash_attention
use_fused_attention = use_fused_attention and _use_fused_attention
use_unfused_attention = use_unfused_attention and _use_unfused_attention
fused_attention_backend = None
if use_fused_attention:
if len(fused_attention_backends) == 1:
fused_attention_backend = fused_attention_backends[0]
if len(fused_attention_backends) == 2:
sub_backend = int(os.getenv("NVTE_FUSED_ATTN_BACKEND", "1"))
fused_attention_backend = fused_attention_backends[sub_backend]
selected_backend = [use_flash_attention, use_fused_attention, use_unfused_attention, fused_attention_backend]
return selected_backend
if (
attention_params in _attention_backends["attention_params"]
and _attention_backends["update_env_vars_only"]
and not _attention_backends["update_selection"]
):
config_id = _attention_backends["attention_params"].index(attention_params)
available_backends = _attention_backends["available_backends"][config_id]
selected_backend = update_env_vars(logger, available_backends)
_attention_backends["selected_backend"][config_id] = selected_backend
print('retrieving, update env vars only', config_id)  #, available_backends, selected_backend)
return available_backends, selected_backend

# keep unique attention_params
if attention_params in _attention_backends["attention_params"] and _attention_backends["update_selection"]:
config_id = _attention_backends["attention_params"].index(attention_params)
_attention_backends["attention_params"].pop(config_id)
_attention_backends["available_backends"].pop(config_id)
_attention_backends["selected_backend"].pop(config_id)

# The following sections check if `FlashAttention` supports the provided attention params,
# regardless of whether FA2 or FA3 is installed. If FA2 or FA3 is not installed but is
# necessary for performance/functionality, a warning will be issued to prompt users to
# install an appropriate FA version.
global _flash_attn_version_required, _flash_attn_max_version, _use_flash_attn_3

# Filter: Environment variables
use_flash_attention = int(os.getenv("NVTE_FLASH_ATTN", "1"))
use_fused_attention = int(os.getenv("NVTE_FUSED_ATTN", "1"))
use_unfused_attention = int(os.getenv("NVTE_UNFUSED_ATTN", "1"))
if not use_flash_attention and _flash_attn_is_installed:
logger.debug("Disabling FlashAttention due to NVTE_FLASH_ATTN=0")
if not use_fused_attention:
logger.debug("Disabling FusedAttention due to NVTE_FUSED_ATTN=0")
if not use_unfused_attention:
logger.debug("Disabling UnfusedDotProductAttention due to NVTE_UNFUSED_ATTN=0")
use_flash_attention = True
use_fused_attention = True
use_unfused_attention = True

# Filter: ONNX mode
if is_in_onnx_export_mode():
@@ -799,29 +862,45 @@ def get_attention_backend(
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"

# Filter: cuDNN support
fused_attn_avail_backends = []
fused_attention_backend = None
if use_fused_attention:
q_type = TE_DType[qkv_dtype]
kv_type = q_type
if fp8 and fp8_meta["recipe"].fp8_dpa:
q_type = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
kv_type = q_type
fused_attention_backend = tex.get_fused_attn_backend(
q_type,
kv_type,
QKVLayout[qkv_layout],
AttnBiasType[fu_core_attention_bias_type],
AttnMaskType[attn_mask_type],
attention_dropout,
num_heads,
num_gqa_groups,
max_seqlen_q,
max_seqlen_kv,
head_dim_qk,
head_dim_v,
window_size[0],
window_size[1],
)

def get_fused_attn_backend():
fused_attention_backend = tex.get_fused_attn_backend(
q_type,
kv_type,
QKVLayout[qkv_layout],
AttnBiasType[fu_core_attention_bias_type],
AttnMaskType[attn_mask_type],
attention_dropout,
num_heads,
num_gqa_groups,
max_seqlen_q,
max_seqlen_kv,
head_dim_qk,
head_dim_v,
window_size[0],
window_size[1],
)
return fused_attention_backend

# all available cuDNN sub-backends
for k,v in FusedAttnBackend.items():
if k != "No_Backend":
os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(int(v))
backend = get_fused_attn_backend()
if backend == FusedAttnBackend[k]:
fused_attn_avail_backends.append(backend)
del os.environ["NVTE_FUSED_ATTN_BACKEND"]
# selected cuDNN sub-backend
fused_attention_backend = get_fused_attn_backend()

if fused_attention_backend == FusedAttnBackend["No_Backend"]:
logger.debug("Disabling FusedAttention as no backend supports the provided input")
use_fused_attention = False
@@ -833,7 +912,7 @@ def get_attention_backend(
and fused_attention_backend != FusedAttnBackend["F16_arbitrary_seqlen"]
):
logger.debug(
"Disabling FusedAttention as only sub-backend %s does not support "
"Disabling FusedAttention as sub-backend %s does not support "
"slidng window attention",
int(fused_attention_backend),
)
@@ -891,7 +970,11 @@ def get_attention_backend(
use_fused_attention = False

# All available backends
available_backends = [use_flash_attention, use_fused_attention, use_unfused_attention]
available_backends = [use_flash_attention, use_fused_attention, use_unfused_attention, fused_attn_avail_backends]

# Filter: Environment variables
#selected_backend = [use_flash_attention, use_fused_attention, use_unfused_attention, fused_attention_backend]
use_flash_attention, use_fused_attention, use_unfused_attention, fused_attention_backend = update_env_vars(logger, available_backends)

# `FusedAttention` and `FlashAttention` are faster backends than `UnfusedDotProductAttention`.
# When `FusedAttention` does not support the provided attention params, and `FlashAttention`
@@ -909,16 +992,16 @@ def get_attention_backend(
use_flash_attention = False
available_backends[0] = False


fused_attn_str = ""
if len(fused_attn_avail_backends) > 0:
fused_attn_str = " (sub-backend " + " and ".join([str(int(x)) for x in fused_attn_avail_backends]) + ")"
logger.debug(
"Available backends = {FlashAttention=%s, FusedAttention=%s%s,"
" UnfusedDotProductAttention=%s}",
bool(available_backends[0]),
bool(available_backends[1]),
(
f" (sub-backend {int(fused_attention_backend)})"
if fused_attention_backend is not None
else ""
),
fused_attn_str,
bool(available_backends[2]),
)

@@ -949,31 +1032,33 @@ def get_attention_backend(
# Selected backend
if use_flash_attention:
use_fused_attention = False
fused_attention_backend = []
use_unfused_attention = False
elif use_fused_attention:
use_unfused_attention = False
selected_backend = "NoBackend"
backend = "NoBackend"
if use_flash_attention:
selected_backend = "FlashAttention"
backend = "FlashAttention"
elif use_fused_attention:
selected_backend = f"FusedAttention (sub-backend {int(fused_attention_backend)})"
backend = f"FusedAttention (sub-backend {int(fused_attention_backend)})"
elif use_unfused_attention:
selected_backend = "UnfusedDotProductAttention"
logger.debug("Selected backend = %s", selected_backend)

global _attention_backends
_attention_backends["use_flash_attention"] = use_flash_attention
_attention_backends["use_fused_attention"] = use_fused_attention
_attention_backends["fused_attention_backend"] = fused_attention_backend
_attention_backends["use_unfused_attention"] = use_unfused_attention
_attention_backends["backend_selection_requires_update"] = False
backend = "UnfusedDotProductAttention"
logger.debug("Selected backend = %s", backend)
selected_backend = [use_flash_attention, use_fused_attention, use_unfused_attention, fused_attention_backend]

_attention_backends["attention_params"].append(attention_params)
_attention_backends["available_backends"].append(available_backends)
_attention_backends["selected_backend"].append(selected_backend)
if len(_attention_backends["attention_params"]) > _max_configs:
_attention_backends["attention_params"].pop(0)
_attention_backends["available_backends"].pop(0)
_attention_backends["selected_backend"].pop(0)
_attention_backends["update_selection"] = False
print('adding', len(_attention_backends["attention_params"]))  #, available_backends, selected_backend)

return (
use_flash_attention,
use_fused_attention,
fused_attention_backend,
use_unfused_attention,
available_backends,
selected_backend,
)
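
The env-var-only refresh implemented by update_env_vars above can be read as the following standalone sketch: keep the cached availability results and recompute only the selection from the current NVTE_* settings. Names here are illustrative, not the PR's exact code.

import os

# Standalone sketch of the "update_env_vars_only" path: availability was already
# computed and cached; only the user-facing NVTE_* toggles are re-applied.
def reselect_from_env(available_backends):
    flash_ok, fused_ok, unfused_ok, fused_sub_backends = available_backends
    flash_ok = flash_ok and bool(int(os.getenv("NVTE_FLASH_ATTN", "1")))
    fused_ok = fused_ok and bool(int(os.getenv("NVTE_FUSED_ATTN", "1")))
    unfused_ok = unfused_ok and bool(int(os.getenv("NVTE_UNFUSED_ATTN", "1")))
    fused_backend = None
    if fused_ok and fused_sub_backends:
        if len(fused_sub_backends) == 1:
            fused_backend = fused_sub_backends[0]
        else:
            # With multiple cuDNN sub-backends available, NVTE_FUSED_ATTN_BACKEND picks one.
            idx = int(os.getenv("NVTE_FUSED_ATTN_BACKEND", "1"))
            fused_backend = fused_sub_backends[min(idx, len(fused_sub_backends) - 1)]
    return [flash_ok, fused_ok, unfused_ok, fused_backend]
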


@@ -8193,22 +8278,12 @@ def forward(
fp8=self.fp8,
fp8_meta=self.fp8_meta,
)
global _attention_backends, _use_flash_attn_3
if (
_attention_backends["attention_params"] is None
or attention_params != _attention_backends["attention_params"]
):
_attention_backends["attention_params"] = attention_params
_attention_backends["backend_selection_requires_update"] = True
if _attention_backends["backend_selection_requires_update"]:
_use_flash_attn_3 = _flash_attn_3_is_installed
(
use_flash_attention,
use_fused_attention,
fused_attention_backend,
use_unfused_attention,
_,
) = get_attention_backend(attention_params)
is_new_config = (attention_params not in _attention_backends["attention_params"]
or _attention_backends["update_env_vars_only"]
or _attention_backends["update_selection"])
_, selected_backend = get_attention_backend(attention_params)
use_flash_attention, use_fused_attention, use_unfused_attention, fused_attention_backend = selected_backend
if is_new_config:
if use_flash_attention:
self.logger.info(
"Running with FlashAttention backend (version %s)",
@@ -8221,11 +8296,7 @@ def forward(
)
elif use_unfused_attention:
self.logger.info("Running with UnfusedDotProductAttention backend")
else:
use_flash_attention = _attention_backends["use_flash_attention"]
use_fused_attention = _attention_backends["use_fused_attention"]
fused_attention_backend = _attention_backends["fused_attention_backend"]
use_unfused_attention = _attention_backends["use_unfused_attention"]


if use_flash_attention:
if core_attention_bias_type == "alibi":