
Conversation

@AdityaKane2001 AdityaKane2001 commented Oct 30, 2025

TL;DR: Ported the Flash Attention CUTLASS 3.x kernels to NATTEN as-is, with wrappers to fit them into NATTEN. They are exposed through the flash-fmha backend.

Summary of changes:

  • C++:

    • Added actual kernel files under csrc/include/natten/cuda/flash_fmha/flash_kernel.
    • Added a Torch C++ interface at csrc/src/flash_fmha.cu, which calls into csrc/.../flash_fmha/flash_fmha_{forward/backward}.cuh.
    • Added a utility file csrc/.../flash_kernel/param_utils.h for param conversion.
  • Python:

    • Exposed the C++ function through the flash-fmha backend.
    • Added an autograd function and configs for the same (a sketch of this wrapping pattern is shown after this list).
    • Where possible, the code is arranged to make it easier to implement a flash-fna backend later.
    • Added autogen scripts and tests for flash-fmha.
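
To make the Python-side wiring concrete, here is a minimal sketch of how the exposed C++ ops could be wrapped in an autograd function. The op names (`flash_fmha_forward`, `flash_fmha_backward`), their signatures, and the saved tensors are assumptions for illustration and do not necessarily match the actual bindings in this PR.

```python
# Hypothetical sketch: wrapping compiled flash-fmha ops in an autograd
# function. Op names, signatures, and saved tensors are assumptions, not the
# actual NATTEN bindings introduced by this PR.
import torch
from torch.autograd import Function


class FlashFMHAFunction(Function):
    @staticmethod
    def forward(ctx, query, key, value, scale):
        # Assumed binding: returns the attention output and the logsumexp
        # tensor needed by the backward pass.
        output, logsumexp = torch.ops.natten.flash_fmha_forward(
            query, key, value, scale
        )
        ctx.save_for_backward(query, key, value, output, logsumexp)
        ctx.scale = scale
        return output

    @staticmethod
    def backward(ctx, grad_output):
        query, key, value, output, logsumexp = ctx.saved_tensors
        # Assumed binding: returns gradients w.r.t. query, key, and value.
        d_query, d_key, d_value = torch.ops.natten.flash_fmha_backward(
            grad_output, query, key, value, output, logsumexp, ctx.scale
        )
        return d_query, d_key, d_value, None
```

A backend selector in the Python frontend could then dispatch to `FlashFMHAFunction.apply` whenever the flash-fmha backend is requested.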

Current rough edges:

  1. The Python frontend may have some code-style inconsistencies.
  2. Some template parameters for the flash backward kernel are currently hard-coded in flash_fmha_backward.cuh rather than handled by autogen; adding them to the autogen scripts would make those scripts overly complex (see the illustrative sketch after this list).
  3. Although correctness is covered by tests, no particular refactoring of the ported Flash Attention kernel code was done.
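
To illustrate point 2, below is a rough sketch of the kind of instantiation generator an autogen script tends to be; the template name, parameter axes, and output format are assumptions for illustration only, not the actual NATTEN autogen code.

```python
# Hypothetical sketch of an autogen-style instantiation generator. Template
# name, parameter axes, and output format are illustrative only and do not
# reflect the actual NATTEN autogen scripts.
from itertools import product

DTYPES = ["cutlass::half_t", "cutlass::bfloat16_t"]
HEAD_DIMS = [32, 64, 128]


def emit_forward_instantiations() -> str:
    lines = []
    for dtype, head_dim in product(DTYPES, HEAD_DIMS):
        # One explicit template instantiation per (dtype, head_dim) pair.
        lines.append(
            f"template void run_flash_fmha_forward<{dtype}, {head_dim}>"
            "(Params& params, cudaStream_t stream);"
        )
    return "\n".join(lines)


if __name__ == "__main__":
    print(emit_forward_instantiations())
```

The backward template takes additional parameters on top of axes like these, so enumerating them in the generator would multiply the emitted combinations and the script's complexity; keeping them fixed in flash_fmha_backward.cuh avoids that for now.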
