Large launch overhead for aiter.flash_attn_func #2129

@SampoAMD

Description

Hi,

When using aiter.flash_attn_func I noticed it has a large launch overhead. While investigating, I found that the following happens in aiter/aiter/jit/core.py, lines 818-823:

if module is None:
    try:
        module = get_module(md_name)
    except Exception:
        md = custom_build_args.get("md_name", md_name)
        module = get_module(md)
  1. It always calls get_module with module_mha_fwd first.
  2. That call chain raises an exception from:

     __mds[md_name] = importlib.import_module(f"{__package__}.{md_name}")

     with the message: No module named 'aiter.jit.module_mha_fwd'. (This happens on every call to flash_attn_func, and something similar appears to happen for the backward pass.)
  3. Execution then takes the exception path in the snippet above and calls get_module with mha_fwd_bf16_nbias_nmask_nlse_ndropout_nqscale, which works as intended.

In my very simple test script, launch overhead improves 2-3x just by commenting out the path that tries get_module with module_mha_fwd first: from around 230 us down to 83 us. For comparison, in my application the actual attention kernels take less than 100 us, so this is quite a large overhead.

I tested with this setup:

docker image: rocm/pytorch:rocm7.1_ubuntu22.04_py3.10_pytorch_release_2.8.0
aiter commit: 7ee5bc8

aiter installation with:

git clone --recursive https://github.com/ROCm/aiter.git
cd aiter
python3 setup.py develop

python module:

import torch
import aiter    

if __name__ == "__main__":

    torch.manual_seed(1234)
    device = torch.device("cuda")
    dtype = torch.bfloat16
    batch_size, seqlen, num_heads, head_dim = 1, 128, 8, 64
    q = torch.randn(batch_size, seqlen, num_heads, head_dim, device=device, dtype=dtype)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    for _ in range(5):
        out = aiter.flash_attn_func(q, k, v)
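
The per-call numbers above (230 us vs 83 us) can be reproduced with a simple wall-clock harness along these lines. This is only a generic sketch (time_per_call is my own helper, and the dummy workload stands in for aiter.flash_attn_func; for real GPU timing you would also synchronize the device):

```python
import time

def time_per_call(fn, warmup=5, iters=100):
    """Return the average wall-clock time per call of fn, in microseconds.

    Note: for real GPU measurements you would call torch.cuda.synchronize()
    before reading the clock on either side of the loop; that is omitted
    here to keep the sketch dependency-free.
    """
    for _ in range(warmup):          # warm up caches / lazy init first
        fn()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    elapsed = time.perf_counter() - start
    return elapsed / iters * 1e6

# Dummy CPU workload standing in for aiter.flash_attn_func(q, k, v):
per_call_us = time_per_call(lambda: sum(range(1000)))
```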

My question: can we get rid of the behavior where get_module is always tried first with module_mha_fwd and fails on every call to aiter.flash_attn_func (and similarly for the backward pass)?
