FlashAttention for Sliding Window Attention in Triton

This repository provides an implementation of the FlashAttention-2 forward and backward passes for self-attention in Triton. FlashAttention-2 is a memory-efficient algorithm for computing attention that significantly reduces memory bandwidth requirements and improves performance on GPU hardware.

This implementation supports several configurations for both the forward and backward passes, including:

  • Sliding Window Attention (see the mask sketch after this list)
  • Global Attention
  • Causal Attention
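
For reference, the sliding-window masking rule can be expressed in plain PyTorch as below. This is a minimal sketch of the mask logic only, not the Triton kernel; the function name and the exact inclusive/exclusive window boundary are assumptions for illustration.

import torch

def sliding_window_mask(seq_len: int, window_size: int, device: str = "cpu") -> torch.Tensor:
    # Query position i may attend to key position j when j <= i (causal)
    # and i - j < window_size (j lies inside the sliding window).
    i = torch.arange(seq_len, device=device).unsqueeze(1)  # (seq_len, 1)
    j = torch.arange(seq_len, device=device).unsqueeze(0)  # (1, seq_len)
    return (j <= i) & (i - j < window_size)                # bool, True = attend

# Example: with window_size=3, each query sees itself and the 2 previous tokens.
print(sliding_window_mask(5, 3).int())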

Notes:

  • Partially inspired by OpenAI's Fused Attention in Triton.
  • This implementation is intended for educational purposes and can be further optimized.
  • No dropout is applied.
  • Uses FP16 precision.

Introduction

This implementation reproduces the performance trend reported in the FlashAttention-2 paper, with significant speedups over a standard attention implementation.

(Benchmark plot: FlashAttention-2 vs. standard PyTorch attention.)

Up to 2-10x speedup compared to a standard PyTorch attention implementation.

This implementation differs slightly from the pseudo-code in the original paper, mainly in the backward pass, which is adapted from the fused-attention implementation in the Triton documentation for better performance.

Benchmarking

Measure the speed (in TFLOPs/s) of FlashAttention-2 in Triton versus standard PyTorch attention for the forward + backward pass using tests/benchmark.py:

python -m tests.benchmark --attn_mode 'sliding_window' --window_size 500
python -m tests.benchmark --attn_mode 'causal' --batch_size 16
python -m tests.benchmark --attn_mode 'global' --num_heads 8
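
For context, the PyTorch baseline in this kind of benchmark is typically a straightforward softmax(QK^T / sqrt(d))V with an explicit mask. The snippet below is a sketch of such a reference implementation, not the exact code in tests/benchmark.py.

import math
import torch

def reference_attention(q, k, v, mask=None):
    # q, k, v: (batch, heads, seq_len, head_dim); mask: bool, True = attend.
    scores = (q @ k.transpose(-2, -1)) / math.sqrt(q.shape[-1])
    if mask is not None:
        scores = scores.masked_fill(~mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v

Unlike this baseline, FlashAttention-2 never materializes the full (seq_len, seq_len) score matrix, which is where the memory-bandwidth savings come from.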

Testing

Compare the results between the PyTorch and Triton implementations using tests/test.py:

python -m tests.test --attn_mode 'sliding_window' --window_size 1000
python -m tests.test --attn_mode 'causal' --head_dim 128
python -m tests.test --attn_mode 'global' --batch_size 8

Each test verifies:

  • Numerical accuracy against a standard PyTorch implementation.
  • Gradient correctness.
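
A typical way to check both outputs and gradients against a PyTorch reference looks roughly like the sketch below; the function name and tolerance are illustrative assumptions, not the exact code in tests/test.py.

import torch

def check_against_reference(triton_attn, torch_attn, q, k, v, atol=1e-2):
    # q, k, v must have requires_grad=True for the gradient check.
    out_tri = triton_attn(q, k, v)
    out_ref = torch_attn(q, k, v)
    assert torch.allclose(out_tri, out_ref, atol=atol)  # forward outputs match

    # Compare gradients w.r.t. q, k, v using the same upstream gradient.
    grad_out = torch.randn_like(out_ref)
    grads_tri = torch.autograd.grad(out_tri, (q, k, v), grad_out)
    grads_ref = torch.autograd.grad(out_ref, (q, k, v), grad_out)
    for g_tri, g_ref in zip(grads_tri, grads_ref):
        assert torch.allclose(g_tri, g_ref, atol=atol)  # gradients match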

Installation

git clone https://github.com/MaxLSB/flash-attn2.git

To Do

  • Fix the current restrictions for Sliding Window Attention (see the sketch after this list):
    • SEQ_LEN >= 4 * BLOCK_SIZE
    • 2 * BLOCK_SIZE <= WINDOW_SIZE <= SEQ_LEN
  • Improve the autotuning configuration for the backward pass
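
The restrictions above can be made explicit with a small guard like the one below. This is a hypothetical helper, not part of the repository; BLOCK_SIZE here is assumed to mean the Triton tile size along the sequence dimension.

def check_sliding_window_config(seq_len: int, window_size: int, block_size: int) -> None:
    # Current limitations of the sliding-window kernels, per the To Do list above.
    if seq_len < 4 * block_size:
        raise ValueError("SEQ_LEN must be >= 4 * BLOCK_SIZE")
    if not (2 * block_size <= window_size <= seq_len):
        raise ValueError("WINDOW_SIZE must satisfy 2 * BLOCK_SIZE <= WINDOW_SIZE <= SEQ_LEN")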
