A PyTorch CUDA extension prototype for the Flash Attention forward pass.
Project layout:

```
flash_attn/
  setup.py           # CUDA extension build script
  flash_attn.py      # PyTorch autograd wrapper
  csrc/
    flash_attn.cpp   # PyBind / extension entrypoint
    flash_attn.cu    # CUDA forward implementation (WIP)
  requirements.txt
```
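For orientation, here is a minimal sketch of what `setup.py` might look like, assuming the standard `torch.utils.cpp_extension` build path; the compiled-module name `flash_attn._C` is an assumption for illustration, not taken from this repo:

```python
# setup.py -- a minimal sketch, assuming torch.utils.cpp_extension is used.
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name="flash_attn",
    packages=["flash_attn"],
    ext_modules=[
        CUDAExtension(
            name="flash_attn._C",  # hypothetical compiled-module name
            sources=["csrc/flash_attn.cpp", "csrc/flash_attn.cu"],
        )
    ],
    # BuildExtension wires nvcc into the setuptools build.
    cmdclass={"build_ext": BuildExtension},
)
```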
Requirements:

- NVIDIA GPU
- CUDA toolkit installed (`nvcc` available in `PATH`) and `CUDA_HOME` set (example: `/usr/local/cuda`)
- Python with PyTorch installed
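A quick way to verify these from Python (a sketch; `shutil.which` only confirms that `nvcc` is on `PATH`):

```python
import shutil
import torch

print(torch.cuda.is_available())  # True if an NVIDIA GPU and driver are visible
print(torch.version.cuda)         # CUDA version PyTorch was built against
print(shutil.which("nvcc"))       # path to nvcc if it is on PATH, else None
```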
Setup:

- Create and activate a virtual environment.
- Install PyTorch first.
- Build/install the extension in editable mode without build isolation:

```
pip install torch
pip install -e flash_attn --no-build-isolation
```

Usage:

```python
import torch
from flash_attn.flash_attn import flash_attention
Q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float32).contiguous()
K = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float32).contiguous()
V = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float32).contiguous()
O = flash_attention(Q, K, V)
print(O.shape)
```
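For context, here is a sketch of how the autograd wrapper in `flash_attn.py` might dispatch into the compiled kernel. The `flash_attn._C` module name and its `forward(q, k, v)` signature are assumptions, and only the forward pass exists in this prototype:

```python
# flash_attn.py -- a sketch, not the repo's actual wrapper. Assumes the
# compiled extension is importable as flash_attn._C with forward(q, k, v).
import torch
from torch.autograd import Function

import flash_attn._C as _C  # hypothetical compiled-module name


class FlashAttentionFunction(Function):
    @staticmethod
    def forward(ctx, q, k, v):
        # Forward-only prototype: nothing is saved for backward because the
        # backward kernel is not implemented yet (WIP).
        return _C.forward(q.contiguous(), k.contiguous(), v.contiguous())

    @staticmethod
    def backward(ctx, grad_out):
        raise NotImplementedError("Flash Attention backward pass is WIP")


def flash_attention(q, k, v):
    return FlashAttentionFunction.apply(q, k, v)
```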
Troubleshooting:

- `CUDA_HOME` is not set: `export CUDA_HOME=/usr/local/cuda`
- `nvcc` was not found: install the CUDA toolkit and add its `bin/` directory to `PATH`
- PyTorch import/build errors: install torch first, then reinstall with `pip install -e flash_attn --no-build-isolation`
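After a successful build, one possible sanity check is to compare the extension's output against PyTorch's reference attention, assuming the kernel implements standard unmasked scaled dot-product attention (the tolerance below is an assumption and depends on the kernel's precision):

```python
import torch
import torch.nn.functional as F
from flash_attn.flash_attn import flash_attention

q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float32).contiguous()
k = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float32).contiguous()
v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float32).contiguous()

out = flash_attention(q, k, v)
ref = F.scaled_dot_product_attention(q, k, v)  # PyTorch reference (>= 2.0)
print(torch.allclose(out, ref, atol=1e-3))     # tolerance is an assumption
```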