Description
Hi,
When using aiter.flash_attn_func I noticed it has large launch overhead. When investigating I noticed that the following happens in aiter/aiter/jit/core.py lines 818 - 823 in the following code segment:
```python
if module is None:
    try:
        module = get_module(md_name)
    except Exception:
        md = custom_build_args.get("md_name", md_name)
        module = get_module(md)
```

- It always calls `get_module` with `module_mha_fwd` first.
- That call chain throws an exception from:

  ```python
  __mds[md_name] = importlib.import_module(f"{__package__}.{md_name}")
  ```

  with the message: `No module named 'aiter.jit.module_mha_fwd'`
(This happens on every call to `flash_attn_func`, and a similar thing appears to happen for the backward pass.)
- It then falls into the exception path of the code above and calls `get_module` with `mha_fwd_bf16_nbias_nmask_nlse_ndropout_nqscale`, which works as intended.
In my very simple test script, launch overhead improves 2-3x just by commenting out the path that tries `get_module` with `module_mha_fwd`: from around 230 us down to 83 us per call. For comparison, in my application the actual attention kernels take less than 100 us, so this is quite a large overhead.
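The gap between a failing import attempt and a plain cache hit can be seen with a small standalone microbenchmark (illustrative only: the module name is a stand-in, and these are not the 230 us / 83 us numbers measured above, which also include the rest of the dispatch path):

```python
import importlib
import timeit

def failing_import():
    # A failed submodule import is not cached in sys.modules, so the
    # full finder machinery runs again on every attempt.
    try:
        importlib.import_module("json.module_that_does_not_exist")
    except ModuleNotFoundError:
        pass

cache = {"hit": object()}

def dict_hit():
    # Baseline: the cost of a successful cache lookup.
    cache["hit"]

n = 1000
t_fail = timeit.timeit(failing_import, number=n) / n
t_hit = timeit.timeit(dict_hit, number=n) / n
print(f"failing import: {t_fail * 1e6:.1f} us/call, dict hit: {t_hit * 1e6:.3f} us/call")
```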
I tested with this setup:
- docker image: `rocm/pytorch:rocm7.1_ubuntu22.04_py3.10_pytorch_release_2.8.0`
- aiter commit: `7ee5bc8`
- aiter installed with:

```
git clone --recursive https://github.com/ROCm/aiter.git
cd aiter
python3 setup.py develop
```

Python test script:
```python
import torch
import aiter

if __name__ == "__main__":
    torch.manual_seed(1234)
    device = torch.device("cuda")
    dtype = torch.bfloat16

    batch_size, seqlen, num_heads, head_dim = 1, 128, 8, 64
    q = torch.randn(batch_size, seqlen, num_heads, head_dim, device=device, dtype=dtype)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    for _ in range(5):
        out = aiter.flash_attn_func(q, k, v)
```

The question is: can we get rid of the behavior where every call to `aiter.flash_attn_func` first tries `get_module` with `module_mha_fwd` and fails, and similarly for the backward pass?
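One possible direction, sketched here with stub objects rather than aiter's real `get_module`, would be to reorder the lookup so the name from `custom_build_args` (the one that succeeds in this case) is tried first, with the generic `module_mha_fwd` name only as a fallback. This is a hypothetical sketch, not a tested patch:

```python
def get_module(name, _known={"mha_fwd_bf16_nbias_nmask_nlse_ndropout_nqscale": object()}):
    # Stand-in for aiter's get_module: only the specialized name "exists".
    if name not in _known:
        raise ModuleNotFoundError(name)
    return _known[name]

def resolve_module(md_name, custom_build_args, module=None):
    # Hypothetical reordering of the lookup in aiter/jit/core.py:
    # try the custom_build_args name first, fall back to md_name.
    if module is None:
        preferred = custom_build_args.get("md_name", md_name)
        try:
            module = get_module(preferred)   # succeeds on the first try here
        except Exception:
            module = get_module(md_name)     # fall back to the generic name
    return module

mod = resolve_module(
    "module_mha_fwd",
    {"md_name": "mha_fwd_bf16_nbias_nmask_nlse_ndropout_nqscale"},
)
```

Under this ordering, the `module_mha_fwd` case would never raise at all when `custom_build_args` names a module that exists, which is exactly the hot path hit by `flash_attn_func` here.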