From 9f361aa02e513d7040eca5737ffaff75252abec3 Mon Sep 17 00:00:00 2001 From: Gareth Jones Date: Sun, 23 Feb 2025 18:23:07 -0800 Subject: [PATCH 1/5] Stage accumulator fragment to shared memory using tiled copy --- csrc/flash_fwd_mla_kernel.h | 738 ++++++++++++++++++------------------ 1 file changed, 369 insertions(+), 369 deletions(-) diff --git a/csrc/flash_fwd_mla_kernel.h b/csrc/flash_fwd_mla_kernel.h index 55f6811..a06cc03 100644 --- a/csrc/flash_fwd_mla_kernel.h +++ b/csrc/flash_fwd_mla_kernel.h @@ -13,10 +13,14 @@ using namespace cute; #include "static_switch.h" #include "flash_mla.h" - -template +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Helper: Decide K-Layout at SMEM level given type and dimension. +/// Swizzling is determined primarily by alignment constraints. +/// Return GMMA Layout at compile time. +//////////////////////////////////////////////////////////////////////////////////////////////////// +template constexpr auto getSmemLayoutK() { - constexpr int headSizeBytes = sizeof(PrecType) * DIM; + constexpr int headSizeBytes = sizeof(PrecType) * DIM; constexpr int headSizeBytes2 = sizeof(PrecType) * DIM2; if constexpr (headSizeBytes % 128 == 0 && headSizeBytes2 % 128 == 0) { @@ -28,466 +32,462 @@ constexpr auto getSmemLayoutK() { } } -template -struct Flash_fwd_kernel_traits_mla { - using Element = elem_type; +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Kernel Trait: FWD MLA for Flash Attention +/// - Templated on HeadDim (kHeadDim_), block tiling, warp usage, etc. +/// - Provides all necessary sub-layouts for Q/K/V, softmax partials, etc. +//////////////////////////////////////////////////////////////////////////////////////////////////// +template < + int kHeadDim_, + int kBlockM_, + int kBlockN_, + int kNumWarps_, + typename ElemType = cutlass::bfloat16_t, + int kHeadDimV_ = 0 +> +struct FlashFwdKernelTraitsMLA { + using Element = ElemType; using ElementAccum = float; - using index_t = int64_t; - - static constexpr int kNWarps = kNWarps_; - static constexpr int kNThreads = kNWarps * 32; - static constexpr int kNWarpsS = 4; - static constexpr int kNThreadsS = kNWarpsS * 32; - - static constexpr int kBlockM = kBlockM_; - static constexpr int kBlockN = kBlockN_; - static constexpr int kHeadDim = kHeadDim_; + using IndexT = int64_t; + + // Warp organization + static constexpr int kNumWarps = kNumWarps_; + static constexpr int kNumThreads = kNumWarps * 32; + static constexpr int kNumWarpsSoftmax = 4; + static constexpr int kNumThreadsSoftmax = kNumWarpsSoftmax * 32; + + // Tiling in M, N, K + static constexpr int kBlockM = kBlockM_; + static constexpr int kBlockN = kBlockN_; + static constexpr int kHeadDim = kHeadDim_; static_assert(kHeadDim % 32 == 0); - static constexpr int kHeadDimV = kHeadDimV_ != 0 ? kHeadDimV_ : kHeadDim; + + // Possibly distinct V-dimension + static constexpr int kHeadDimV = (kHeadDimV_ != 0) ? kHeadDimV_ : kHeadDim; static_assert(kHeadDimV % 32 == 0); static_assert(kHeadDimV <= kHeadDim); - static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; - static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; - using TiledMma = decltype(make_tiled_mma( - cute::GMMA::ss_op_selector, Int, Int>, - GMMA::Major::K, GMMA::Major::K>(), - Layout, _1, _1>>{})); + // SMEM swizzling for partial K/V + static constexpr int kBlockKSmem = (kHeadDim % 64 == 0) ? 64 : 32; + static constexpr int kSwizzle = (kBlockKSmem == 32) ? 
2 : 3; - static constexpr int AtomLayoutNO = kNThreads / kNThreadsS; + // GMMA Tiled Mma + // Q*K -> S + using TiledMma = decltype(make_tiled_mma( + cute::GMMA::ss_op_selector< + Element, Element, ElementAccum, + Shape, Int, Int>, + GMMA::Major::K, GMMA::Major::K + >(), + Layout, _1, _1>>{} + )); + + // S*V -> O + // For the O “outer product,” we define the shape in [M, HeadDimV, N]. + static constexpr int AtomLayoutNO = kNumThreads / kNumThreadsSoftmax; using TiledMmaO = decltype(make_tiled_mma( - cute::GMMA::rs_op_selector, Int, Int>, - GMMA::Major::K, GMMA::Major::MN>(), - Layout, Int, _1>>{})); - - using SmemLayoutQ = decltype(tile_to_shape( + cute::GMMA::rs_op_selector< + Element, Element, ElementAccum, + Shape, Int, Int>, + GMMA::Major::K, GMMA::Major::MN + >(), + Layout, Int, _1>>{} + )); + + //////////////////////////////////////////////////////////////////////////////////////////////////// + /// SMEM Layout definitions: Q/K/V, P, row-scale, etc. + //////////////////////////////////////////////////////////////////////////////////////////////////// + using SmemLayoutQ = decltype( + tile_to_shape( getSmemLayoutK(), - Shape, Int>{})); + Shape, Int>{} + ) + ); - using SmemLayoutK = decltype(tile_to_shape( + using SmemLayoutK = decltype( + tile_to_shape( getSmemLayoutK(), - Shape, Int>{})); + Shape, Int>{} + ) + ); - using SmemLayoutV = decltype(tile_to_shape( + using SmemLayoutV = decltype( + tile_to_shape( getSmemLayoutK(), - Shape, Int>{})); - using SmemLayoutVtransposed = decltype(composition(SmemLayoutV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); + Shape, Int>{} + ) + ); + using SmemLayoutVtransposed = decltype( + composition( + SmemLayoutV{}, + make_layout( + Shape, Int>{}, + GenRowMajor{} + ) + ) + ); - using SmemLayoutP = Layout, Int, _1, Int>>; - using SmemLayoutRow = Layout>, Stride<_1, _2>>; + // For partial S data (softmax region) + using SmemLayoutP = Layout, Int, _1, Int>>; + using SmemLayoutRow = Layout>, Stride<_1, _2>>; - using SmemLayoutAtomO = decltype(composition( + // Layout for the O tile in smem + using SmemLayoutAtomO = decltype( + composition( Swizzle{}, - Layout, Int>, Stride, _1>>{})); - using SmemLayoutO = decltype(tile_to_shape( + Layout, Int>, Stride, _1>>{} + ) + ); + using SmemLayoutO = decltype( + tile_to_shape( SmemLayoutAtomO{}, - Shape, Int>{})); - using SmemCopyAtomO = Copy_Atom; - using SmemCopyAtomOaccum = Copy_Atom, ElementAccum>; + Shape, Int>{} + ) + ); - static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); - static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); + //////////////////////////////////////////////////////////////////////////////////////////////////// + /// Copy Atoms for SMEM read/write + //////////////////////////////////////////////////////////////////////////////////////////////////// + using SmemCopyAtomO = Copy_Atom; + using SmemCopyAtomOaccum = Copy_Atom, ElementAccum>; + + //////////////////////////////////////////////////////////////////////////////////////////////////// + /// GMEM Tiled Copies for Q/K/V + //////////////////////////////////////////////////////////////////////////////////////////////////// + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must align with vector load size"); static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; - using Gmem_copy_struct = SM80_CP_ASYNC_CACHEGLOBAL; - static constexpr int 
kNThreadsLoad = kNThreads - kNThreadsS; - static_assert(kNThreadsLoad % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); + using GmemCopyStruct = SM80_CP_ASYNC_CACHEGLOBAL; + static constexpr int kNumThreadsLoad = kNumThreads - kNumThreadsSoftmax; + static_assert(kNumThreadsLoad % kGmemThreadsPerRow == 0, "Thread counts must match row partitions"); using GmemLayoutAtom = Layout< - Shape, Int>, - Stride, _1>>; - using GmemTiledCopy = decltype(make_tiled_copy( - Copy_Atom{}, + Shape, Int>, + Stride, _1> + >; + using GmemTiledCopy = decltype( + make_tiled_copy( + Copy_Atom{}, GmemLayoutAtom{}, - Layout>{})); // Val layout, 8 vals per read + Layout>{} // 8 vals per read + ) + ); + // For storing O to GMEM using GmemLayoutAtomO = Layout< - Shape, Int>, - Stride, _1>>; - using GmemTiledCopyO = decltype(make_tiled_copy( + Shape, Int>, + Stride, _1> + >; + using GmemTiledCopyO = decltype( + make_tiled_copy( Copy_Atom, Element>{}, GmemLayoutAtomO{}, - Layout>{})); // Val layout, 8 vals per store + Layout>{} // 8 vals per store + ) + ); - static constexpr int kGmemElemsPerLoadAccum = sizeof(cute::uint128_t) / sizeof(ElementAccum); + // For accumulation path (split) + static constexpr int kGmemElemsPerLoadAccum = sizeof(cute::uint128_t) / sizeof(ElementAccum); static constexpr int kGmemThreadsPerRowAccum = kBlockKSmem / kGmemElemsPerLoadAccum; using GmemLayoutAtomOaccum = Layout< - Shape, Int>, - Stride, _1>>; - using GmemTiledCopyOaccum = decltype(make_tiled_copy( + Shape, Int>, + Stride, _1> + >; + using GmemTiledCopyOaccum = decltype( + make_tiled_copy( Copy_Atom, ElementAccum>{}, GmemLayoutAtomOaccum{}, - Layout>{})); // Val layout, 4 vals per store + Layout>{} // 4 vals per store + ) + ); }; +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Shared Storage Container for MLA +/// - Re-used union across Q/K/P/O or row sums, etc. +//////////////////////////////////////////////////////////////////////////////////////////////////// namespace flash { using namespace cute; -template +template struct SharedStorageMLA { union { struct { - cute::array_aligned> smem_q; - cute::array_aligned * 2> smem_k; // Double buffer - cute::array_aligned> smem_p; - cute::array_aligned> smem_scale; + cute::array_aligned> smem_q; + cute::array_aligned * 2> smem_k; // double buffer + cute::array_aligned> smem_p; + cute::array_aligned> smem_scale; }; struct { - cute::array_aligned> smem_max; - cute::array_aligned> smem_sum; - cute::array_aligned> smem_o; + cute::array_aligned> smem_max; + cute::array_aligned> smem_sum; + cute::array_aligned> smem_o; }; }; }; //////////////////////////////////////////////////////////////////////////////////////////////////// - -template -__forceinline__ __device__ void store(const Flash_fwd_mla_params ¶ms, const int bidb, const int bidh, const int m_block, const int n_split_idx, - SharedStorage &shared_storage, AccO tOrO, Softmax softmax) { - constexpr int kBlockM = Kernel_traits::kBlockM; - constexpr int kHeadDimV = Kernel_traits::kHeadDimV; - constexpr int kNThreadsS = Kernel_traits::kNThreadsS; - using Element = typename Kernel_traits::Element; - using ElementAccum = typename Kernel_traits::ElementAccum; - using index_t = typename Kernel_traits::index_t; +/// store() Epilogue for partial or non-partial results +/// - Manages writing O/accumulation to global memory + writing out LSE for row block. 
+//////////////////////////////////////////////////////////////////////////////////////////////////// +template < + typename KernelTraits, + bool Split, + typename SharedStorage, + typename AccO, + typename Softmax +> +__forceinline__ __device__ +void store( + const Flash_fwd_mla_params ¶ms, + const int batch_id, + const int head_id, + const int m_block, + const int n_split_idx, + SharedStorage &shared_storage, + AccO tOrO, + Softmax softmax +) { + constexpr int kBlockM = KernelTraits::kBlockM; + constexpr int kHeadDimV = KernelTraits::kHeadDimV; + constexpr int kNumThreadsS = KernelTraits::kNumThreadsSoftmax; + using Element = typename KernelTraits::Element; + using ElementAccum = typename KernelTraits::ElementAccum; + using IndexT = typename KernelTraits::IndexT; const int tidx = threadIdx.x; - typename Kernel_traits::TiledMmaO tiled_mma_o; + typename KernelTraits::TiledMmaO tiled_mma_o; auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx); - // Epilogue + // Softmax LSE for final normalization + auto lse = softmax.template normalize_softmax_lse(tOrO, params.scale_softmax); - const int split_offset = __ldg(params.num_splits_ptr + bidb); + // Decide if writing ephemeral partial results (float accumulation) or final (Element). + using ElementO = std::conditional_t; - Tensor lse = softmax.template normalize_softmax_lse(tOrO, params.scale_softmax); + // Prepare SMEM for O + Tensor sOaccum = make_tensor( + make_smem_ptr(reinterpret_cast(shared_storage.smem_o.data())), + typename KernelTraits::SmemLayoutO{} + ); + auto smem_tiled_copy_Oaccum = make_tiled_copy_C( + std::conditional_t{}, + tiled_mma_o + ); - using ElementO = std::conditional_t; - Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast(shared_storage.smem_o.data())), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) - // Partition sO to match the accumulator partitioning - using SmemTiledCopyO = std::conditional_t< - !Split, - typename Kernel_traits::SmemCopyAtomO, - typename Kernel_traits::SmemCopyAtomOaccum - >; - auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma_o); auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx); - Tensor rO = flash::convert_type(tOrO); - Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) - Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum); // ((Atom,AtomNum),PIPE_M,PIPE_N) + Tensor rO = flash::convert_type(tOrO); + Tensor taccOrO = smem_thr_copy_Oaccum.retile_S(rO); + Tensor taccOsO = smem_thr_copy_Oaccum.partition_D(sOaccum); __syncthreads(); + cute::copy(smem_tiled_copy_Oaccum, taccOrO, taccOsO); + + // Compute GMEM offsets + const IndexT row_offset_o = batch_id * params.o_batch_stride + + m_block * kBlockM * params.o_row_stride + + head_id * params.o_head_stride; + const IndexT row_offset_oaccum = (((__ldg(params.num_splits_ptr + batch_id) + n_split_idx) + * params.h + head_id) + * params.seqlen_q + (m_block * kBlockM)) * params.d_v; + const IndexT row_offset_lse = (batch_id * params.h + head_id) * params.seqlen_q + m_block * kBlockM; + const IndexT row_offset_lseaccum = (((__ldg(params.num_splits_ptr + batch_id) + n_split_idx) + * params.h + head_id) + * params.seqlen_q + (m_block * kBlockM)); + + // Prepare GMEM for final or partial O + Tensor gOaccum = make_tensor( + make_gmem_ptr( + reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + + (Split ? row_offset_oaccum : row_offset_o) + ), + Shape, Int>{}, + make_stride(Split ? 
kHeadDimV : params.o_row_stride, _1{}) + ); - cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum); - - const index_t row_offset_o = bidb * params.o_batch_stride + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; - const index_t row_offset_oaccum = (((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_v; - const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; - const index_t row_offset_lseaccum = ((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM; - - Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)), - Shape, Int>{}, - make_stride(Split ? kHeadDimV : params.o_row_stride, _1{})); - Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + (Split ? row_offset_lseaccum : row_offset_lse)), - Shape>{}, Stride<_1>{}); + // Prepare GMEM LSE + Tensor gLSEaccum = make_tensor( + make_gmem_ptr( + reinterpret_cast( + Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr + ) + (Split ? row_offset_lseaccum : row_offset_lse) + ), + Shape>{}, + Stride<_1>{} + ); - using GmemTiledCopyO = std::conditional_t; - GmemTiledCopyO gmem_tiled_copy_Oaccum; + // Tiled copy from SMEM -> GMEM for O + using GmemTiledCopyOAccum = std::conditional_t< + !Split, + typename KernelTraits::GmemTiledCopyO, + typename KernelTraits::GmemTiledCopyOaccum + >; + GmemTiledCopyOAccum gmem_tiled_copy_Oaccum; auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); - Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum); // ((Atom,AtomNum),ATOM_M,ATOM_N) + + Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum); Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum); __syncthreads(); - if (tidx >= kNThreadsS) { return; } + // If out of range of the "softmax" portion, do not store + if (tidx >= kNumThreadsS) { return; } + // Load from SMEM Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum); - Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) - Tensor taccOcO = thr_mma_o.partition_C(caccO); // ((MMA=4, X), MMA_M, MMA_K=1) - Tensor taccOcO_row = taccOcO(make_coord(0, _, 0), _, 0); - CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M + // Write out the LSE + auto caccO = make_identity_tensor(Shape, Int>{}); + auto taccOcO = thr_mma_o.partition_C(caccO); + auto taccOcO_row = taccOcO(make_coord(0, _, 0), _, 0); + CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); + if (get<1>(taccOcO_row(0)) == 0) { #pragma unroll for (int mi = 0; mi < size(lse); ++mi) { const int row = get<0>(taccOcO_row(mi)); - if (row < params.seqlen_q - m_block * kBlockM) { gLSEaccum(row) = lse(mi); } + if (row < params.seqlen_q - m_block * kBlockM) { + gLSEaccum(row) = lse(mi); + } } } - // Construct identity layout for sO - Tensor cO = make_identity_tensor(make_shape(size<0>(sOaccum), size<1>(sOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k) - // Repeat the partitioning with identity layouts - Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) - Tensor tOpO = make_tensor(make_shape(size<2>(tOgOaccum))); - // Clear_OOB_K must be false since we don't want to write zeros to gmem + // Identity layout for sO + auto cO = make_identity_tensor( + make_shape(size<0>(sOaccum), 
size<1>(sOaccum)) + ); + auto tOcO = gmem_thr_copy_Oaccum.partition_D(cO); + auto tOpO = make_tensor(make_shape(size<2>(tOgOaccum))); + + // Copy final O back to GMEM flash::copy( - gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, params.seqlen_q - m_block * kBlockM + gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, + params.seqlen_q - m_block * kBlockM ); } -template -__forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_fwd_mla_params ¶ms, - const int bidb, const int bidh, const int m_block, - const int n_split_idx, const int seqlen_k, - const int n_block_min, const int n_block_max, const bool NoSplit, - SharedStorage &shared_storage) { - constexpr int kBlockM = Kernel_traits::kBlockM; - constexpr int kBlockN = Kernel_traits::kBlockN; - constexpr int kHeadDim = Kernel_traits::kHeadDim; - constexpr int kHeadDimV = Kernel_traits::kHeadDimV; - constexpr int kNThreads = Kernel_traits::kNThreads; - constexpr int kNThreadsS = Kernel_traits::kNThreadsS; - static_assert(kNThreads == 256 and kNThreadsS == 128); - using Element = typename Kernel_traits::Element; - using index_t = typename Kernel_traits::index_t; +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// compute_attn_1rowblock_splitkv_mla() +/// - Core logic for Q*K -> S -> Softmax -> S*V -> O +/// - Includes partial accumulation for splits and optional causal masking. +//////////////////////////////////////////////////////////////////////////////////////////////////// +template +__forceinline__ __device__ +void compute_attn_1rowblock_splitkv_mla( + const Flash_fwd_mla_params ¶ms, + const int batch_id, + const int head_id, + const int m_block, + const int n_split_idx, + const int seqlen_k, + const int n_block_min, + const int n_block_max, + const bool no_split, + SharedStorage &shared_storage +) { + constexpr int kBlockM = KernelTraits::kBlockM; + constexpr int kBlockN = KernelTraits::kBlockN; + constexpr int kHeadDim = KernelTraits::kHeadDim; + constexpr int kHeadDimV = KernelTraits::kHeadDimV; + constexpr int kNumThreads = KernelTraits::kNumThreads; + constexpr int kNumThreadsS = KernelTraits::kNumThreadsSoftmax; + using Element = typename KernelTraits::Element; + using IndexT = typename KernelTraits::IndexT; + + static_assert(kNumThreads == 256 && kNumThreadsS == 128, "Expected 256 main threads, 128 softmax threads."); const int tidx = threadIdx.x; - int n_block = n_block_max - 1; - - Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Kernel_traits::SmemLayoutQ{}); - Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutK{}); - Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutV{}); - Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtransposed{}); - - Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Kernel_traits::SmemLayoutP{}); - Tensor tPsP = sP(_, tidx % kNThreadsS, _, _); - Tensor sScale_o = make_tensor(make_smem_ptr(shared_storage.smem_scale.data()), typename Kernel_traits::SmemLayoutRow{}); - Tensor tScale_osScale_o = sScale_o(_, tidx % kNThreadsS); - Tensor sRow_max = make_tensor(make_smem_ptr(shared_storage.smem_max.data()), typename Kernel_traits::SmemLayoutRow{}); - Tensor tRow_maxsRow_max = sRow_max(_, tidx % kNThreadsS); - Tensor sRow_sum = make_tensor(make_smem_ptr(shared_storage.smem_sum.data()), typename 
Kernel_traits::SmemLayoutRow{}); - Tensor tRow_sumsRow_sum = sRow_sum(_, tidx % kNThreadsS); - - typename Kernel_traits::TiledMmaO tiled_mma_o; - auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx); - Tensor tOrVt = thr_mma_o.partition_fragment_B(sVt); // (MMA, MMA_K,MMA_N) - Tensor tOrO = partition_fragment_C(tiled_mma_o, Shape, Int>{}); // ((MMA=4, X), MMA_M, MMA_N=1) - clear(tOrO); - - flash::Softmax<2 * size<1>(tOrO)> softmax; - - int warp_group_idx = cutlass::canonical_warp_group_idx(); - if (warp_group_idx == 0) { - typename Kernel_traits::TiledMma tiled_mma; - auto thr_mma = tiled_mma.get_thread_slice(tidx); - Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) - Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) - - if (n_block % 2 == 1) { - // Double buffer for sK - constexpr int sK_offset = size(sK); - tSrK.data() = tSrK.data() + sK_offset / 8; - tOrVt.data() = tOrVt.data() + sK_offset / 8; - } - - // We need masking on S for the very last block when K and V has length not multiple of kBlockN. - // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. - // We will have at least 1 "masking" iteration. - // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to - // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. - constexpr int n_masking_steps = !Is_causal ? 1 : cute::ceil_div(kBlockM, kBlockN) + 1; -#pragma unroll 1 - for (int masking_step = n_masking_steps; n_block >= n_block_min; --masking_step, --n_block) { - __syncthreads(); + int n_block = n_block_max - 1; - Tensor tSrS = partition_fragment_C(tiled_mma, Shape, Int>{}); // ((MMA=4, X), MMA_M, MMA_N=1) - flash::gemm(tiled_mma, tSrQ, tSrK, tSrS); - - const bool is_masking_step = masking_step > 0; - const bool is_first_masking_step = masking_step == n_masking_steps; - - if (is_masking_step) { - Tensor cS = make_identity_tensor(Shape, Int>{}); - Tensor tScS = thr_mma.partition_C(cS); -#pragma unroll - for (int i = 0; i < size(tSrS); ++i) { - if constexpr (!Is_causal) { // Just masking based on col - if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) tSrS(i) = -INFINITY; - } else { - // Ensure seqlen_k - 1 - (n_block * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups - // col <= seqlen_k - 1 - n_block * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups - int row = int(get<0>(tScS(i))); - int col_limit_right = seqlen_k - 1 - n_block * kBlockN - (params.seqlen_q - 1 - (m_block * kBlockM + row)) / params.ngroups; - if (int(get<1>(tScS(i))) > col_limit_right) tSrS(i) = -INFINITY; - } - } - } - - // We have key_padding_mask so we'll need to Check_inf - Tensor scale_o = is_first_masking_step - ? softmax.template softmax(tSrS, params.scale_softmax_log2) - : is_masking_step ? - softmax.template softmax(tSrS, params.scale_softmax_log2) - : softmax.template softmax(tSrS, params.scale_softmax_log2); + // Smem pointers for Q, K, V, partial S, etc. 
+ Tensor sQ = make_tensor( + make_smem_ptr(shared_storage.smem_q.data()), + typename KernelTraits::SmemLayoutQ{} + ); + Tensor sK = make_tensor( + make_smem_ptr(shared_storage.smem_k.data()), + typename KernelTraits::SmemLayoutK{} + ); + Tensor sV = make_tensor( + make_smem_ptr(shared_storage.smem_k.data()), + typename KernelTraits::SmemLayoutV{} + ); + Tensor sVt = make_tensor( + make_smem_ptr(shared_storage.smem_k.data()), + typename KernelTraits::SmemLayoutVtransposed{} + ); - Tensor rP = flash::convert_type(tSrS); - cute::copy(rP, tPsP); - cute::copy(scale_o, tScale_osScale_o); + // Softmax partial + Tensor sP = make_tensor( + make_smem_ptr(shared_storage.smem_p.data()), + typename KernelTraits::SmemLayoutP{} + ); + Tensor tPsP = sP(_, tidx % kNumThreadsS, _, _); - cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast(NamedBarriers::SReady)); + // Row-based scale, sum, etc. + Tensor sScale = make_tensor( + make_smem_ptr(shared_storage.smem_scale.data()), + typename KernelTraits::SmemLayoutRow{} + ); + Tensor tScale = sScale(_, tidx % kNumThreadsS); + Tensor sRowMax = make_tensor( + make_smem_ptr(shared_storage.smem_max.data()), + typename KernelTraits::SmemLayoutRow{} + ); + Tensor tRowMax = sRowMax(_, tidx % kNumThreadsS); + Tensor sRowSum = make_tensor( + make_smem_ptr(shared_storage.smem_sum.data()), + typename KernelTraits::SmemLayoutRow{} + ); + Tensor tRowSum = sRowSum(_, tidx % kNumThreadsS); - flash::rescale_o(tOrO, scale_o); + // Mma for O + typename KernelTraits::TiledMmaO tiled_mma_o; + auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx); + Tensor tOrVt = thr_mma_o.partition_fragment_B(sVt); + Tensor tOrO = partition_fragment_C(tiled_mma_o, Shape, Int>{}); + clear(tOrO); - Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); - flash::gemm(tiled_mma_o, tOrP, tOrVt, tOrO); + // Combined softmax utility + flash::Softmax<2 * size<1>(tOrO)> softmax; - // Double buffer for sK - const int sK_offset = n_block % 2 == 0 ? 
size(sK) : -size(sK); - tSrK.data() = tSrK.data() + sK_offset / 8; - tOrVt.data() = tOrVt.data() + sK_offset / 8; - } + // Warp group logic: warpGroupIdx=0 does Q*K->S, warpGroupIdx=1 does async loads for next iteration + int warpGroupIdx = cutlass::canonical_warp_group_idx(); + if (warpGroupIdx == 0) { + // Main matmul Q*K -> S + typename KernelTraits::TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(tidx); - cute::copy(softmax.row_max, tRow_maxsRow_max); - cute::copy(softmax.row_sum, tRow_sumsRow_sum); - cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast(NamedBarriers::SoftmaxReady)); - } else { - const int *block_table = params.block_table + bidb * params.block_table_batch_stride; - int cur_block_table = __ldg(&block_table[n_block]); - - const index_t row_offset_q = bidb * params.q_batch_stride + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; - Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), - Shape, Int>{}, - make_stride(params.q_row_stride, _1{})); - typename Kernel_traits::GmemTiledCopy gmem_tiled_copy_Q; - auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx - kNThreadsS); - Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ); - Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ); - Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) - Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) - Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); - - // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs - flash::copy(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, tQpQ, - params.seqlen_q - m_block * kBlockM); - - const index_t row_offset_k = (bidh / params.h_h_k_ratio) * params.k_head_stride; - Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), - Shape, Int>{}, - make_stride(params.k_row_stride, _1{})); - typename Kernel_traits::GmemTiledCopy gmem_tiled_copy_K; - auto gmem_thr_copy_K = gmem_tiled_copy_K.get_thread_slice(tidx - kNThreadsS); - Tensor tKgK = gmem_thr_copy_K.partition_S(gK); - Tensor tKsK = gmem_thr_copy_K.partition_D(sK); - Tensor cK = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) - Tensor tKcK = gmem_thr_copy_K.partition_S(cK); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) - Tensor tKpK = make_tensor(make_shape(size<2>(tKsK))); + Tensor tSrQ = thr_mma.partition_fragment_A(sQ); + Tensor tSrK = thr_mma.partition_fragment_B(sK); + // If n_block is odd => shift for double-buffer if (n_block % 2 == 1) { - // Double buffer for sK - constexpr int sK_offset = size(sK); - tKsK.data() = tKsK.data() + sK_offset; - tOrVt.data() = tOrVt.data() + sK_offset / 8; - } - - // We need to clear the sK smem tiles because K is V. - const index_t offset_k = cur_block_table * params.k_batch_stride; - tKgK.data() = tKgK.data() + offset_k; - flash::copy(gmem_tiled_copy_K, tKgK, tKsK, tKcK, tKpK, - seqlen_k - n_block * kBlockN); - tKgK.data() = tKgK.data() + -offset_k; - cute::cp_async_fence(); - - if (n_block - 1 >= n_block_min) { - cur_block_table = __ldg(&block_table[n_block - 1]); - } - -#pragma unroll 1 - for (; n_block >= n_block_min; --n_block) { - flash::cp_async_wait<0>(); - __syncthreads(); - - if (n_block - 1 >= n_block_min) { - // Double buffer for sK - const int sK_offset = n_block % 2 == 0 ? 
size(sK) : -size(sK); - tKsK.data() = tKsK.data() + sK_offset; - - const index_t offset_k = cur_block_table * params.k_batch_stride; - tKgK.data() = tKgK.data() + offset_k; - flash::copy(gmem_tiled_copy_K, tKgK, tKsK, tKcK, tKpK); - tKgK.data() = tKgK.data() + -offset_k; - cute::cp_async_fence(); - } - - cutlass::arch::NamedBarrier::sync(kNThreads, static_cast(NamedBarriers::SReady)); - - if (n_block - 2 >= n_block_min) { - cur_block_table = __ldg(&block_table[n_block - 2]); - } - - typename Kernel_traits::TiledMma tiled_mma; - auto tSrS_layout = partition_fragment_C(tiled_mma, Shape, Int>{}).layout(); - Tensor rP = make_tensor(tSrS_layout); - Tensor scale_o = make_tensor(Shape<_2>{}); - cute::copy(tScale_osScale_o, scale_o); - cute::copy(tPsP, rP); - - flash::rescale_o(tOrO, scale_o); - - Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); - flash::gemm(tiled_mma_o, tOrP, tOrVt, tOrO); - - // Double buffer for sK - const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK); - tOrVt.data() = tOrVt.data() + sK_offset / 8; + constexpr int sKOffset = size(sK); + tSrK.data() += (sKOffset / 8); + tOrVt.data() += (sKOffset / 8); } - cutlass::arch::NamedBarrier::sync(kNThreads, static_cast(NamedBarriers::SoftmaxReady)); - cute::copy(tRow_maxsRow_max, softmax.row_max); - cute::copy(tRow_sumsRow_sum, softmax.row_sum); - } - - if (NoSplit) - store(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax); - else - store(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax); -} - -template -__global__ void __launch_bounds__(Kernel_traits::kNThreads, 1, 1) -flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params) { - constexpr int kBlockN = Kernel_traits::kBlockN; - const int m_block = blockIdx.x; - const int bidh = blockIdx.y; - const int partition_idx = blockIdx.z; - - extern __shared__ char shared_memory[]; - auto &shared_storage = *reinterpret_cast(shared_memory); - - int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr + partition_idx * TileSchedulerMetaDataSize; - int4 tile_scheduler_metadata = __ldg(reinterpret_cast(tile_scheduler_metadata_ptr)); - int begin_idx = tile_scheduler_metadata.x; - int begin_seqlen = tile_scheduler_metadata.y; - int end_idx = tile_scheduler_metadata.z; - int end_seqlen = tile_scheduler_metadata.w; - if (begin_idx >= params.b) return; - int begin_n_split_idx = __ldg(tile_scheduler_metadata_ptr + 4); + // We have a loop from n_block_max-1 down to n_block_min + // Need to do “masking step(s)” for partial or causal scenarios. + constexpr int nMaskingSteps = !IsCausal + ? 1 + : cute::ceil_div(kBlockM, kBlockN) + 1; #pragma unroll 1 - for (int batch_id = begin_idx; batch_id <= end_idx; ++batch_id) { - const int n_split_idx = batch_id == begin_idx ? begin_n_split_idx : 0; - const int seqlen_k = __ldg(params.cu_seqlens_k + batch_id); - const int n_block_min = batch_id == begin_idx ? begin_seqlen / kBlockN : 0; - const int n_block_max = batch_id == end_idx ? cute::ceil_div(end_seqlen, kBlockN) : cute::ceil_div(seqlen_k, kBlockN); - const bool NoSplit = n_block_min == 0 && n_block_max == cute::ceil_div(seqlen_k, kBlockN); - if (batch_id > begin_idx) { - __syncthreads(); // Barrier between two tiles. 
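
The loop above turns one tile-scheduler partition (begin_idx, begin_seqlen, end_idx, end_seqlen, begin_n_split_idx) into a per-batch range of K blocks. A minimal host-side C++ sketch of that mapping, with hypothetical names (BlockRange, decode_partition) and assuming cu_seqlens_k holds each batch's key length as in the kernel:

#include <vector>

// Hypothetical helper, not part of the patch: re-states the per-batch block-range
// computation done inside flash_fwd_splitkv_mla_kernel for one scheduler partition.
struct BlockRange { int n_split_idx, n_block_min, n_block_max; bool no_split; };

constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

std::vector<BlockRange> decode_partition(int begin_idx, int begin_seqlen,
                                         int end_idx, int end_seqlen,
                                         int begin_n_split_idx,
                                         const std::vector<int> &cu_seqlens_k,
                                         int kBlockN) {
    std::vector<BlockRange> out;
    for (int batch_id = begin_idx; batch_id <= end_idx; ++batch_id) {
        const int seqlen_k    = cu_seqlens_k[batch_id];
        const int n_split_idx = (batch_id == begin_idx) ? begin_n_split_idx : 0;
        const int n_block_min = (batch_id == begin_idx) ? begin_seqlen / kBlockN : 0;
        const int n_block_max = (batch_id == end_idx) ? ceil_div(end_seqlen, kBlockN)
                                                      : ceil_div(seqlen_k, kBlockN);
        // NoSplit: this partition covers the batch's whole K range, so the kernel can
        // write final O/LSE directly instead of partial float accumulators.
        const bool no_split = (n_block_min == 0) && (n_block_max == ceil_div(seqlen_k, kBlockN));
        out.push_back({n_split_idx, n_block_min, n_block_max, no_split});
    }
    return out;
}

The no_split flag mirrors the NoSplit boolean that later selects store<Kernel_traits, false> (direct O and LSE write) over store<Kernel_traits, true> (partial accumulators left for the split-KV combine kernel).
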
- } - flash::compute_attn_1rowblock_splitkv_mla(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, shared_storage); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -__global__ void __launch_bounds__(256, 1, 1) -flash_fwd_splitkv_mla_combine_kernel(__grid_constant__ const Flash_fwd_mla_params params) { - constexpr int kNThreads = 128; - - const int tidx = threadIdx.x; - const int bidx = blockIdx.x; + for (int masking const int hs = params.h * params.seqlen_q; const int batch_idx = bidx / hs; const int hs_idx = bidx % hs; From 5fb94d668f46ed383e17ab5640a3a2b0b854da22 Mon Sep 17 00:00:00 2001 From: Gareth Jones Date: Sun, 23 Feb 2025 18:38:15 -0800 Subject: [PATCH 2/5] Stage accumulator fragment to shared memory using tiled copy --- csrc/flash_fwd_mla_kernel.h | 739 ++++++++++++++++++------------------ 1 file changed, 370 insertions(+), 369 deletions(-) diff --git a/csrc/flash_fwd_mla_kernel.h b/csrc/flash_fwd_mla_kernel.h index a06cc03..e3e46fd 100644 --- a/csrc/flash_fwd_mla_kernel.h +++ b/csrc/flash_fwd_mla_kernel.h @@ -13,14 +13,10 @@ using namespace cute; #include "static_switch.h" #include "flash_mla.h" -//////////////////////////////////////////////////////////////////////////////////////////////////// -/// Helper: Decide K-Layout at SMEM level given type and dimension. -/// Swizzling is determined primarily by alignment constraints. -/// Return GMMA Layout at compile time. -//////////////////////////////////////////////////////////////////////////////////////////////////// -template + +template constexpr auto getSmemLayoutK() { - constexpr int headSizeBytes = sizeof(PrecType) * DIM; + constexpr int headSizeBytes = sizeof(PrecType) * DIM; constexpr int headSizeBytes2 = sizeof(PrecType) * DIM2; if constexpr (headSizeBytes % 128 == 0 && headSizeBytes2 % 128 == 0) { @@ -32,462 +28,467 @@ constexpr auto getSmemLayoutK() { } } -//////////////////////////////////////////////////////////////////////////////////////////////////// -/// Kernel Trait: FWD MLA for Flash Attention -/// - Templated on HeadDim (kHeadDim_), block tiling, warp usage, etc. -/// - Provides all necessary sub-layouts for Q/K/V, softmax partials, etc. -//////////////////////////////////////////////////////////////////////////////////////////////////// -template < - int kHeadDim_, - int kBlockM_, - int kBlockN_, - int kNumWarps_, - typename ElemType = cutlass::bfloat16_t, - int kHeadDimV_ = 0 -> -struct FlashFwdKernelTraitsMLA { - using Element = ElemType; +template +struct Flash_fwd_kernel_traits_mla { + using Element = elem_type; using ElementAccum = float; - using IndexT = int64_t; - - // Warp organization - static constexpr int kNumWarps = kNumWarps_; - static constexpr int kNumThreads = kNumWarps * 32; - static constexpr int kNumWarpsSoftmax = 4; - static constexpr int kNumThreadsSoftmax = kNumWarpsSoftmax * 32; - - // Tiling in M, N, K - static constexpr int kBlockM = kBlockM_; - static constexpr int kBlockN = kBlockN_; - static constexpr int kHeadDim = kHeadDim_; - static_assert(kHeadDim % 32 == 0); + using index_t = int64_t; + + static constexpr int kNWarps = kNWarps_; + static constexpr int kNThreads = kNWarps * 32; + static constexpr int kNWarpsS = 4; + static constexpr int kNThreadsS = kNWarpsS * 32; - // Possibly distinct V-dimension - static constexpr int kHeadDimV = (kHeadDimV_ != 0) ? 
kHeadDimV_ : kHeadDim; + static constexpr int kBlockM = kBlockM_; + static constexpr int kBlockN = kBlockN_; + static constexpr int kHeadDim = kHeadDim_; + static_assert(kHeadDim % 32 == 0); + static constexpr int kHeadDimV = kHeadDimV_ != 0 ? kHeadDimV_ : kHeadDim; static_assert(kHeadDimV % 32 == 0); static_assert(kHeadDimV <= kHeadDim); + static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; + static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; - // SMEM swizzling for partial K/V - static constexpr int kBlockKSmem = (kHeadDim % 64 == 0) ? 64 : 32; - static constexpr int kSwizzle = (kBlockKSmem == 32) ? 2 : 3; - - // GMMA Tiled Mma - // Q*K -> S using TiledMma = decltype(make_tiled_mma( - cute::GMMA::ss_op_selector< - Element, Element, ElementAccum, - Shape, Int, Int>, - GMMA::Major::K, GMMA::Major::K - >(), - Layout, _1, _1>>{} - )); - - // S*V -> O - // For the O “outer product,” we define the shape in [M, HeadDimV, N]. - static constexpr int AtomLayoutNO = kNumThreads / kNumThreadsSoftmax; + cute::GMMA::ss_op_selector, Int, Int>, + GMMA::Major::K, GMMA::Major::K>(), + Layout, _1, _1>>{})); + + static constexpr int AtomLayoutNO = kNThreads / kNThreadsS; using TiledMmaO = decltype(make_tiled_mma( - cute::GMMA::rs_op_selector< - Element, Element, ElementAccum, - Shape, Int, Int>, - GMMA::Major::K, GMMA::Major::MN - >(), - Layout, Int, _1>>{} - )); - - //////////////////////////////////////////////////////////////////////////////////////////////////// - /// SMEM Layout definitions: Q/K/V, P, row-scale, etc. - //////////////////////////////////////////////////////////////////////////////////////////////////// - using SmemLayoutQ = decltype( - tile_to_shape( + cute::GMMA::rs_op_selector, Int, Int>, + GMMA::Major::K, GMMA::Major::MN>(), + Layout, Int, _1>>{})); + + using SmemLayoutQ = decltype(tile_to_shape( getSmemLayoutK(), - Shape, Int>{} - ) - ); + Shape, Int>{})); - using SmemLayoutK = decltype( - tile_to_shape( + using SmemLayoutK = decltype(tile_to_shape( getSmemLayoutK(), - Shape, Int>{} - ) - ); + Shape, Int>{})); - using SmemLayoutV = decltype( - tile_to_shape( + using SmemLayoutV = decltype(tile_to_shape( getSmemLayoutK(), - Shape, Int>{} - ) - ); - using SmemLayoutVtransposed = decltype( - composition( - SmemLayoutV{}, - make_layout( - Shape, Int>{}, - GenRowMajor{} - ) - ) - ); + Shape, Int>{})); + using SmemLayoutVtransposed = decltype(composition(SmemLayoutV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); - // For partial S data (softmax region) - using SmemLayoutP = Layout, Int, _1, Int>>; - using SmemLayoutRow = Layout>, Stride<_1, _2>>; + using SmemLayoutP = Layout, Int, _1, Int>>; + using SmemLayoutRow = Layout>, Stride<_1, _2>>; - // Layout for the O tile in smem - using SmemLayoutAtomO = decltype( - composition( + using SmemLayoutAtomO = decltype(composition( Swizzle{}, - Layout, Int>, Stride, _1>>{} - ) - ); - using SmemLayoutO = decltype( - tile_to_shape( + Layout, Int>, Stride, _1>>{})); + using SmemLayoutO = decltype(tile_to_shape( SmemLayoutAtomO{}, - Shape, Int>{} - ) - ); + Shape, Int>{})); + using SmemCopyAtomO = Copy_Atom; + using SmemCopyAtomOaccum = Copy_Atom, ElementAccum>; - //////////////////////////////////////////////////////////////////////////////////////////////////// - /// Copy Atoms for SMEM read/write - //////////////////////////////////////////////////////////////////////////////////////////////////// - using SmemCopyAtomO = Copy_Atom; - using SmemCopyAtomOaccum = Copy_Atom, ElementAccum>; - - 
//////////////////////////////////////////////////////////////////////////////////////////////////// - /// GMEM Tiled Copies for Q/K/V - //////////////////////////////////////////////////////////////////////////////////////////////////// - static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); - static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must align with vector load size"); + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; - using GmemCopyStruct = SM80_CP_ASYNC_CACHEGLOBAL; - static constexpr int kNumThreadsLoad = kNumThreads - kNumThreadsSoftmax; - static_assert(kNumThreadsLoad % kGmemThreadsPerRow == 0, "Thread counts must match row partitions"); + using Gmem_copy_struct = SM80_CP_ASYNC_CACHEGLOBAL; + static constexpr int kNThreadsLoad = kNThreads - kNThreadsS; + static_assert(kNThreadsLoad % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); using GmemLayoutAtom = Layout< - Shape, Int>, - Stride, _1> - >; - using GmemTiledCopy = decltype( - make_tiled_copy( - Copy_Atom{}, + Shape, Int>, + Stride, _1>>; + using GmemTiledCopy = decltype(make_tiled_copy( + Copy_Atom{}, GmemLayoutAtom{}, - Layout>{} // 8 vals per read - ) - ); + Layout>{})); // Val layout, 8 vals per read - // For storing O to GMEM using GmemLayoutAtomO = Layout< - Shape, Int>, - Stride, _1> - >; - using GmemTiledCopyO = decltype( - make_tiled_copy( + Shape, Int>, + Stride, _1>>; + using GmemTiledCopyO = decltype(make_tiled_copy( Copy_Atom, Element>{}, GmemLayoutAtomO{}, - Layout>{} // 8 vals per store - ) - ); + Layout>{})); // Val layout, 8 vals per store - // For accumulation path (split) - static constexpr int kGmemElemsPerLoadAccum = sizeof(cute::uint128_t) / sizeof(ElementAccum); + static constexpr int kGmemElemsPerLoadAccum = sizeof(cute::uint128_t) / sizeof(ElementAccum); static constexpr int kGmemThreadsPerRowAccum = kBlockKSmem / kGmemElemsPerLoadAccum; using GmemLayoutAtomOaccum = Layout< - Shape, Int>, - Stride, _1> - >; - using GmemTiledCopyOaccum = decltype( - make_tiled_copy( + Shape, Int>, + Stride, _1>>; + using GmemTiledCopyOaccum = decltype(make_tiled_copy( Copy_Atom, ElementAccum>{}, GmemLayoutAtomOaccum{}, - Layout>{} // 4 vals per store - ) - ); + Layout>{})); // Val layout, 4 vals per store }; -//////////////////////////////////////////////////////////////////////////////////////////////////// -/// Shared Storage Container for MLA -/// - Re-used union across Q/K/P/O or row sums, etc. 
-//////////////////////////////////////////////////////////////////////////////////////////////////// namespace flash { using namespace cute; -template +template struct SharedStorageMLA { union { struct { - cute::array_aligned> smem_q; - cute::array_aligned * 2> smem_k; // double buffer - cute::array_aligned> smem_p; - cute::array_aligned> smem_scale; + cute::array_aligned> smem_q; + cute::array_aligned * 2> smem_k; // Double buffer + cute::array_aligned> smem_p; + cute::array_aligned> smem_scale; }; struct { - cute::array_aligned> smem_max; - cute::array_aligned> smem_sum; - cute::array_aligned> smem_o; + cute::array_aligned> smem_max; + cute::array_aligned> smem_sum; + cute::array_aligned> smem_o; }; }; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -/// store() Epilogue for partial or non-partial results -/// - Manages writing O/accumulation to global memory + writing out LSE for row block. -//////////////////////////////////////////////////////////////////////////////////////////////////// -template < - typename KernelTraits, - bool Split, - typename SharedStorage, - typename AccO, - typename Softmax -> -__forceinline__ __device__ -void store( - const Flash_fwd_mla_params ¶ms, - const int batch_id, - const int head_id, - const int m_block, - const int n_split_idx, - SharedStorage &shared_storage, - AccO tOrO, - Softmax softmax -) { - constexpr int kBlockM = KernelTraits::kBlockM; - constexpr int kHeadDimV = KernelTraits::kHeadDimV; - constexpr int kNumThreadsS = KernelTraits::kNumThreadsSoftmax; - using Element = typename KernelTraits::Element; - using ElementAccum = typename KernelTraits::ElementAccum; - using IndexT = typename KernelTraits::IndexT; + +template +__forceinline__ __device__ void store(const Flash_fwd_mla_params ¶ms, const int bidb, const int bidh, const int m_block, const int n_split_idx, + SharedStorage &shared_storage, AccO tOrO, Softmax softmax) { + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kHeadDimV = Kernel_traits::kHeadDimV; + constexpr int kNThreadsS = Kernel_traits::kNThreadsS; + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; const int tidx = threadIdx.x; - typename KernelTraits::TiledMmaO tiled_mma_o; + typename Kernel_traits::TiledMmaO tiled_mma_o; auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx); - // Softmax LSE for final normalization - auto lse = softmax.template normalize_softmax_lse(tOrO, params.scale_softmax); + // Epilogue - // Decide if writing ephemeral partial results (float accumulation) or final (Element). 
- using ElementO = std::conditional_t; + const int split_offset = __ldg(params.num_splits_ptr + bidb); - // Prepare SMEM for O - Tensor sOaccum = make_tensor( - make_smem_ptr(reinterpret_cast(shared_storage.smem_o.data())), - typename KernelTraits::SmemLayoutO{} - ); - auto smem_tiled_copy_Oaccum = make_tiled_copy_C( - std::conditional_t{}, - tiled_mma_o - ); + Tensor lse = softmax.template normalize_softmax_lse(tOrO, params.scale_softmax); + using ElementO = std::conditional_t; + Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast(shared_storage.smem_o.data())), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) + // Partition sO to match the accumulator partitioning + using SmemTiledCopyO = std::conditional_t< + !Split, + typename Kernel_traits::SmemCopyAtomO, + typename Kernel_traits::SmemCopyAtomOaccum + >; + auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma_o); auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx); - Tensor rO = flash::convert_type(tOrO); - Tensor taccOrO = smem_thr_copy_Oaccum.retile_S(rO); - Tensor taccOsO = smem_thr_copy_Oaccum.partition_D(sOaccum); + Tensor rO = flash::convert_type(tOrO); + Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum); // ((Atom,AtomNum),PIPE_M,PIPE_N) __syncthreads(); - cute::copy(smem_tiled_copy_Oaccum, taccOrO, taccOsO); - - // Compute GMEM offsets - const IndexT row_offset_o = batch_id * params.o_batch_stride - + m_block * kBlockM * params.o_row_stride - + head_id * params.o_head_stride; - const IndexT row_offset_oaccum = (((__ldg(params.num_splits_ptr + batch_id) + n_split_idx) - * params.h + head_id) - * params.seqlen_q + (m_block * kBlockM)) * params.d_v; - const IndexT row_offset_lse = (batch_id * params.h + head_id) * params.seqlen_q + m_block * kBlockM; - const IndexT row_offset_lseaccum = (((__ldg(params.num_splits_ptr + batch_id) + n_split_idx) - * params.h + head_id) - * params.seqlen_q + (m_block * kBlockM)); - - // Prepare GMEM for final or partial O - Tensor gOaccum = make_tensor( - make_gmem_ptr( - reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) - + (Split ? row_offset_oaccum : row_offset_o) - ), - Shape, Int>{}, - make_stride(Split ? kHeadDimV : params.o_row_stride, _1{}) - ); - // Prepare GMEM LSE - Tensor gLSEaccum = make_tensor( - make_gmem_ptr( - reinterpret_cast( - Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr - ) + (Split ? 
row_offset_lseaccum : row_offset_lse) - ), - Shape>{}, - Stride<_1>{} - ); + // Stage accumulator fragment to shared memory using tiled copy + cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum); - // Tiled copy from SMEM -> GMEM for O - using GmemTiledCopyOAccum = std::conditional_t< - !Split, - typename KernelTraits::GmemTiledCopyO, - typename KernelTraits::GmemTiledCopyOaccum - >; - GmemTiledCopyOAccum gmem_tiled_copy_Oaccum; - auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + const index_t row_offset_o = bidb * params.o_batch_stride + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_oaccum = (((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_v; + const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; + const index_t row_offset_lseaccum = ((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM; + + Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)), + Shape, Int>{}, + make_stride(Split ? kHeadDimV : params.o_row_stride, _1{})); + Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + (Split ? row_offset_lseaccum : row_offset_lse)), + Shape>{}, Stride<_1>{}); - Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum); + using GmemTiledCopyO = std::conditional_t; + GmemTiledCopyO gmem_tiled_copy_Oaccum; + auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum); // ((Atom,AtomNum),ATOM_M,ATOM_N) Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum); __syncthreads(); - // If out of range of the "softmax" portion, do not store - if (tidx >= kNumThreadsS) { return; } + if (tidx >= kNThreadsS) { return; } - // Load from SMEM Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum); - // Write out the LSE - auto caccO = make_identity_tensor(Shape, Int>{}); - auto taccOcO = thr_mma_o.partition_C(caccO); - auto taccOcO_row = taccOcO(make_coord(0, _, 0), _, 0); - CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); - + Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor taccOcO = thr_mma_o.partition_C(caccO); // ((MMA=4, X), MMA_M, MMA_K=1) + Tensor taccOcO_row = taccOcO(make_coord(0, _, 0), _, 0); + CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M if (get<1>(taccOcO_row(0)) == 0) { #pragma unroll for (int mi = 0; mi < size(lse); ++mi) { const int row = get<0>(taccOcO_row(mi)); - if (row < params.seqlen_q - m_block * kBlockM) { - gLSEaccum(row) = lse(mi); - } + if (row < params.seqlen_q - m_block * kBlockM) { gLSEaccum(row) = lse(mi); } } } - // Identity layout for sO - auto cO = make_identity_tensor( - make_shape(size<0>(sOaccum), size<1>(sOaccum)) - ); - auto tOcO = gmem_thr_copy_Oaccum.partition_D(cO); - auto tOpO = make_tensor(make_shape(size<2>(tOgOaccum))); - - // Copy final O back to GMEM + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(sOaccum), size<1>(sOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tOpO = 
make_tensor(make_shape(size<2>(tOgOaccum))); + // Clear_OOB_K must be false since we don't want to write zeros to gmem flash::copy( - gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, - params.seqlen_q - m_block * kBlockM + gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, params.seqlen_q - m_block * kBlockM ); } -//////////////////////////////////////////////////////////////////////////////////////////////////// -/// compute_attn_1rowblock_splitkv_mla() -/// - Core logic for Q*K -> S -> Softmax -> S*V -> O -/// - Includes partial accumulation for splits and optional causal masking. -//////////////////////////////////////////////////////////////////////////////////////////////////// -template -__forceinline__ __device__ -void compute_attn_1rowblock_splitkv_mla( - const Flash_fwd_mla_params ¶ms, - const int batch_id, - const int head_id, - const int m_block, - const int n_split_idx, - const int seqlen_k, - const int n_block_min, - const int n_block_max, - const bool no_split, - SharedStorage &shared_storage -) { - constexpr int kBlockM = KernelTraits::kBlockM; - constexpr int kBlockN = KernelTraits::kBlockN; - constexpr int kHeadDim = KernelTraits::kHeadDim; - constexpr int kHeadDimV = KernelTraits::kHeadDimV; - constexpr int kNumThreads = KernelTraits::kNumThreads; - constexpr int kNumThreadsS = KernelTraits::kNumThreadsSoftmax; - using Element = typename KernelTraits::Element; - using IndexT = typename KernelTraits::IndexT; - - static_assert(kNumThreads == 256 && kNumThreadsS == 128, "Expected 256 main threads, 128 softmax threads."); +template +__forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_fwd_mla_params ¶ms, + const int bidb, const int bidh, const int m_block, + const int n_split_idx, const int seqlen_k, + const int n_block_min, const int n_block_max, const bool NoSplit, + SharedStorage &shared_storage) { + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kBlockN = Kernel_traits::kBlockN; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kHeadDimV = Kernel_traits::kHeadDimV; + constexpr int kNThreads = Kernel_traits::kNThreads; + constexpr int kNThreadsS = Kernel_traits::kNThreadsS; + static_assert(kNThreads == 256 and kNThreadsS == 128); + using Element = typename Kernel_traits::Element; + using index_t = typename Kernel_traits::index_t; const int tidx = threadIdx.x; - int n_block = n_block_max - 1; - - // Smem pointers for Q, K, V, partial S, etc. - Tensor sQ = make_tensor( - make_smem_ptr(shared_storage.smem_q.data()), - typename KernelTraits::SmemLayoutQ{} - ); - Tensor sK = make_tensor( - make_smem_ptr(shared_storage.smem_k.data()), - typename KernelTraits::SmemLayoutK{} - ); - Tensor sV = make_tensor( - make_smem_ptr(shared_storage.smem_k.data()), - typename KernelTraits::SmemLayoutV{} - ); - Tensor sVt = make_tensor( - make_smem_ptr(shared_storage.smem_k.data()), - typename KernelTraits::SmemLayoutVtransposed{} - ); - - // Softmax partial - Tensor sP = make_tensor( - make_smem_ptr(shared_storage.smem_p.data()), - typename KernelTraits::SmemLayoutP{} - ); - Tensor tPsP = sP(_, tidx % kNumThreadsS, _, _); - - // Row-based scale, sum, etc. 
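
The epilogue restored above is where both patches in this series carry the change named in the subject line: the O accumulator is converted, staged into sOaccum with a tiled copy matched to the MMA partitioning, and only then written to global memory with a separate, store-friendly tiled copy. A minimal CUDA sketch of that general register -> shared -> global staging pattern, independent of CuTe; the tile size, thread count and scattered "MMA-like" ownership below are illustrative assumptions, not the kernel's layouts:

#include <cuda_runtime.h>

constexpr int kTileElems = 64 * 64;               // stand-in for one kBlockM x kHeadDimV O tile
constexpr int kThreads   = 128;
constexpr int kFrag      = kTileElems / kThreads; // register values held per thread

__global__ void stage_and_store(const float *acc_in, float *out) {
    __shared__ alignas(16) float s_tile[kTileElems];
    const int t = threadIdx.x;

    // 1) "Accumulator fragment": each thread owns a strided, non-contiguous slice of the
    //    tile, the way an MMA partitioning hands it back in registers (loaded here only
    //    so the sketch is self-contained).
    float frag[kFrag];
    for (int i = 0; i < kFrag; ++i) frag[i] = acc_in[i * kThreads + t];

    // 2) Stage the fragment into shared memory using that same scattered mapping.
    for (int i = 0; i < kFrag; ++i) s_tile[i * kThreads + t] = frag[i];
    __syncthreads();

    // 3) Drain with a different, contiguous mapping: vectorized 128-bit stores that the
    //    scattered register layout would not allow directly (out assumed 16B-aligned).
    const float4 *s_vec = reinterpret_cast<const float4 *>(s_tile);
    float4       *g_vec = reinterpret_cast<float4 *>(out);
    for (int j = 0; j < kTileElems / 4 / kThreads; ++j)
        g_vec[j * kThreads + t] = s_vec[j * kThreads + t];
}

Routing the fragment through shared memory is what lets the final global writes use the row-contiguous GmemTiledCopyO/GmemTiledCopyOaccum mappings (one uint128_t per store) rather than the MMA's register layout.
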
- Tensor sScale = make_tensor( - make_smem_ptr(shared_storage.smem_scale.data()), - typename KernelTraits::SmemLayoutRow{} - ); - Tensor tScale = sScale(_, tidx % kNumThreadsS); - Tensor sRowMax = make_tensor( - make_smem_ptr(shared_storage.smem_max.data()), - typename KernelTraits::SmemLayoutRow{} - ); - Tensor tRowMax = sRowMax(_, tidx % kNumThreadsS); - Tensor sRowSum = make_tensor( - make_smem_ptr(shared_storage.smem_sum.data()), - typename KernelTraits::SmemLayoutRow{} - ); - Tensor tRowSum = sRowSum(_, tidx % kNumThreadsS); - - // Mma for O - typename KernelTraits::TiledMmaO tiled_mma_o; + int n_block = n_block_max - 1; + + Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Kernel_traits::SmemLayoutQ{}); + Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutK{}); + Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutV{}); + Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtransposed{}); + + Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Kernel_traits::SmemLayoutP{}); + Tensor tPsP = sP(_, tidx % kNThreadsS, _, _); + Tensor sScale_o = make_tensor(make_smem_ptr(shared_storage.smem_scale.data()), typename Kernel_traits::SmemLayoutRow{}); + Tensor tScale_osScale_o = sScale_o(_, tidx % kNThreadsS); + Tensor sRow_max = make_tensor(make_smem_ptr(shared_storage.smem_max.data()), typename Kernel_traits::SmemLayoutRow{}); + Tensor tRow_maxsRow_max = sRow_max(_, tidx % kNThreadsS); + Tensor sRow_sum = make_tensor(make_smem_ptr(shared_storage.smem_sum.data()), typename Kernel_traits::SmemLayoutRow{}); + Tensor tRow_sumsRow_sum = sRow_sum(_, tidx % kNThreadsS); + + typename Kernel_traits::TiledMmaO tiled_mma_o; auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx); - Tensor tOrVt = thr_mma_o.partition_fragment_B(sVt); - Tensor tOrO = partition_fragment_C(tiled_mma_o, Shape, Int>{}); + Tensor tOrVt = thr_mma_o.partition_fragment_B(sVt); // (MMA, MMA_K,MMA_N) + Tensor tOrO = partition_fragment_C(tiled_mma_o, Shape, Int>{}); // ((MMA=4, X), MMA_M, MMA_N=1) clear(tOrO); - // Combined softmax utility flash::Softmax<2 * size<1>(tOrO)> softmax; - // Warp group logic: warpGroupIdx=0 does Q*K->S, warpGroupIdx=1 does async loads for next iteration - int warpGroupIdx = cutlass::canonical_warp_group_idx(); - if (warpGroupIdx == 0) { - // Main matmul Q*K -> S - typename KernelTraits::TiledMma tiled_mma; + int warp_group_idx = cutlass::canonical_warp_group_idx(); + if (warp_group_idx == 0) { + typename Kernel_traits::TiledMma tiled_mma; auto thr_mma = tiled_mma.get_thread_slice(tidx); + Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) + Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) + + if (n_block % 2 == 1) { + // Double buffer for sK + constexpr int sK_offset = size(sK); + tSrK.data() = tSrK.data() + sK_offset / 8; + tOrVt.data() = tOrVt.data() + sK_offset / 8; + } + + // We need masking on S for the very last block when K and V has length not multiple of kBlockN. + // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. + // We will have at least 1 "masking" iteration. + // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to + // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. + constexpr int n_masking_steps = !Is_causal ? 
1 : cute::ceil_div(kBlockM, kBlockN) + 1; +#pragma unroll 1 + for (int masking_step = n_masking_steps; n_block >= n_block_min; --masking_step, --n_block) { + __syncthreads(); + + Tensor tSrS = partition_fragment_C(tiled_mma, Shape, Int>{}); // ((MMA=4, X), MMA_M, MMA_N=1) + flash::gemm(tiled_mma, tSrQ, tSrK, tSrS); + + const bool is_masking_step = masking_step > 0; + const bool is_first_masking_step = masking_step == n_masking_steps; + + if (is_masking_step) { + Tensor cS = make_identity_tensor(Shape, Int>{}); + Tensor tScS = thr_mma.partition_C(cS); +#pragma unroll + for (int i = 0; i < size(tSrS); ++i) { + if constexpr (!Is_causal) { // Just masking based on col + if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) tSrS(i) = -INFINITY; + } else { + // Ensure seqlen_k - 1 - (n_block * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups + // col <= seqlen_k - 1 - n_block * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups + int row = int(get<0>(tScS(i))); + int col_limit_right = seqlen_k - 1 - n_block * kBlockN - (params.seqlen_q - 1 - (m_block * kBlockM + row)) / params.ngroups; + if (int(get<1>(tScS(i))) > col_limit_right) tSrS(i) = -INFINITY; + } + } + } + + // We have key_padding_mask so we'll need to Check_inf + Tensor scale_o = is_first_masking_step + ? softmax.template softmax(tSrS, params.scale_softmax_log2) + : is_masking_step ? + softmax.template softmax(tSrS, params.scale_softmax_log2) + : softmax.template softmax(tSrS, params.scale_softmax_log2); + + Tensor rP = flash::convert_type(tSrS); + cute::copy(rP, tPsP); + cute::copy(scale_o, tScale_osScale_o); + + cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast(NamedBarriers::SReady)); - Tensor tSrQ = thr_mma.partition_fragment_A(sQ); - Tensor tSrK = thr_mma.partition_fragment_B(sK); + flash::rescale_o(tOrO, scale_o); + + Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); + flash::gemm(tiled_mma_o, tOrP, tOrVt, tOrO); + + // Double buffer for sK + const int sK_offset = n_block % 2 == 0 ? 
size(sK) : -size(sK); + tSrK.data() = tSrK.data() + sK_offset / 8; + tOrVt.data() = tOrVt.data() + sK_offset / 8; + } + + cute::copy(softmax.row_max, tRow_maxsRow_max); + cute::copy(softmax.row_sum, tRow_sumsRow_sum); + cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast(NamedBarriers::SoftmaxReady)); + } else { + const int *block_table = params.block_table + bidb * params.block_table_batch_stride; + int cur_block_table = __ldg(&block_table[n_block]); + + const index_t row_offset_q = bidb * params.q_batch_stride + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; + Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), + Shape, Int>{}, + make_stride(params.q_row_stride, _1{})); + typename Kernel_traits::GmemTiledCopy gmem_tiled_copy_Q; + auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx - kNThreadsS); + Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ); + Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ); + Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); + + // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs + flash::copy(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, tQpQ, + params.seqlen_q - m_block * kBlockM); + + const index_t row_offset_k = (bidh / params.h_h_k_ratio) * params.k_head_stride; + Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), + Shape, Int>{}, + make_stride(params.k_row_stride, _1{})); + typename Kernel_traits::GmemTiledCopy gmem_tiled_copy_K; + auto gmem_thr_copy_K = gmem_tiled_copy_K.get_thread_slice(tidx - kNThreadsS); + Tensor tKgK = gmem_thr_copy_K.partition_S(gK); + Tensor tKsK = gmem_thr_copy_K.partition_D(sK); + Tensor cK = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + Tensor tKcK = gmem_thr_copy_K.partition_S(cK); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + Tensor tKpK = make_tensor(make_shape(size<2>(tKsK))); - // If n_block is odd => shift for double-buffer if (n_block % 2 == 1) { - constexpr int sKOffset = size(sK); - tSrK.data() += (sKOffset / 8); - tOrVt.data() += (sKOffset / 8); + // Double buffer for sK + constexpr int sK_offset = size(sK); + tKsK.data() = tKsK.data() + sK_offset; + tOrVt.data() = tOrVt.data() + sK_offset / 8; + } + + // We need to clear the sK smem tiles because K is V. + const index_t offset_k = cur_block_table * params.k_batch_stride; + tKgK.data() = tKgK.data() + offset_k; + flash::copy(gmem_tiled_copy_K, tKgK, tKsK, tKcK, tKpK, + seqlen_k - n_block * kBlockN); + tKgK.data() = tKgK.data() + -offset_k; + cute::cp_async_fence(); + + if (n_block - 1 >= n_block_min) { + cur_block_table = __ldg(&block_table[n_block - 1]); + } + +#pragma unroll 1 + for (; n_block >= n_block_min; --n_block) { + flash::cp_async_wait<0>(); + __syncthreads(); + + if (n_block - 1 >= n_block_min) { + // Double buffer for sK + const int sK_offset = n_block % 2 == 0 ? 
size(sK) : -size(sK); + tKsK.data() = tKsK.data() + sK_offset; + + const index_t offset_k = cur_block_table * params.k_batch_stride; + tKgK.data() = tKgK.data() + offset_k; + flash::copy(gmem_tiled_copy_K, tKgK, tKsK, tKcK, tKpK); + tKgK.data() = tKgK.data() + -offset_k; + cute::cp_async_fence(); + } + + cutlass::arch::NamedBarrier::sync(kNThreads, static_cast(NamedBarriers::SReady)); + + if (n_block - 2 >= n_block_min) { + cur_block_table = __ldg(&block_table[n_block - 2]); + } + + typename Kernel_traits::TiledMma tiled_mma; + auto tSrS_layout = partition_fragment_C(tiled_mma, Shape, Int>{}).layout(); + Tensor rP = make_tensor(tSrS_layout); + Tensor scale_o = make_tensor(Shape<_2>{}); + cute::copy(tScale_osScale_o, scale_o); + cute::copy(tPsP, rP); + + flash::rescale_o(tOrO, scale_o); + + Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); + flash::gemm(tiled_mma_o, tOrP, tOrVt, tOrO); + + // Double buffer for sK + const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK); + tOrVt.data() = tOrVt.data() + sK_offset / 8; } - // We have a loop from n_block_max-1 down to n_block_min - // Need to do “masking step(s)” for partial or causal scenarios. - constexpr int nMaskingSteps = !IsCausal - ? 1 - : cute::ceil_div(kBlockM, kBlockN) + 1; + cutlass::arch::NamedBarrier::sync(kNThreads, static_cast(NamedBarriers::SoftmaxReady)); + cute::copy(tRow_maxsRow_max, softmax.row_max); + cute::copy(tRow_sumsRow_sum, softmax.row_sum); + } + + if (NoSplit) + store(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax); + else + store(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax); +} + +template +__global__ void __launch_bounds__(Kernel_traits::kNThreads, 1, 1) +flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params) { + constexpr int kBlockN = Kernel_traits::kBlockN; + const int m_block = blockIdx.x; + const int bidh = blockIdx.y; + const int partition_idx = blockIdx.z; + + extern __shared__ char shared_memory[]; + auto &shared_storage = *reinterpret_cast(shared_memory); + + int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr + partition_idx * TileSchedulerMetaDataSize; + int4 tile_scheduler_metadata = __ldg(reinterpret_cast(tile_scheduler_metadata_ptr)); + int begin_idx = tile_scheduler_metadata.x; + int begin_seqlen = tile_scheduler_metadata.y; + int end_idx = tile_scheduler_metadata.z; + int end_seqlen = tile_scheduler_metadata.w; + if (begin_idx >= params.b) return; + int begin_n_split_idx = __ldg(tile_scheduler_metadata_ptr + 4); #pragma unroll 1 - for (int masking + for (int batch_id = begin_idx; batch_id <= end_idx; ++batch_id) { + const int n_split_idx = batch_id == begin_idx ? begin_n_split_idx : 0; + const int seqlen_k = __ldg(params.cu_seqlens_k + batch_id); + const int n_block_min = batch_id == begin_idx ? begin_seqlen / kBlockN : 0; + const int n_block_max = batch_id == end_idx ? cute::ceil_div(end_seqlen, kBlockN) : cute::ceil_div(seqlen_k, kBlockN); + const bool NoSplit = n_block_min == 0 && n_block_max == cute::ceil_div(seqlen_k, kBlockN); + if (batch_id > begin_idx) { + __syncthreads(); // Barrier between two tiles. 
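+            // All warps must be done with the previous batch's shared-memory tiles
+            // (K staging, softmax partials, staged output) before this iteration reuses them.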
+        }
+        flash::compute_attn_1rowblock_splitkv_mla<Kernel_traits, Is_causal>(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, shared_storage);
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template
+__global__ void __launch_bounds__(256, 1, 1)
+flash_fwd_splitkv_mla_combine_kernel(__grid_constant__ const Flash_fwd_mla_params params) {
+    constexpr int kNThreads = 128;
+
+    const int tidx = threadIdx.x;
+    const int bidx = blockIdx.x;
+    const int hs = params.h * params.seqlen_q;
+    const int batch_idx = bidx / hs;
+    const int hs_idx = bidx % hs;

From ccb208bcac49cf8fcc4ccefa2930271898abb69d Mon Sep 17 00:00:00 2001
From: Gareth Jones
Date: Sun, 23 Feb 2025 18:44:25 -0800
Subject: [PATCH 3/5] Cache output stride parameters in registers to reduce
 global loads

---
 csrc/flash_fwd_mla_kernel.h | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/csrc/flash_fwd_mla_kernel.h b/csrc/flash_fwd_mla_kernel.h
index e3e46fd..9265a1a 100644
--- a/csrc/flash_fwd_mla_kernel.h
+++ b/csrc/flash_fwd_mla_kernel.h
@@ -154,6 +154,11 @@ __forceinline__ __device__ void store(const Flash_fwd_mla_params &params, const
     const int tidx = threadIdx.x;
 
+    // Cache frequently used parameters into registers for optimization
+    const index_t o_batch_stride = __ldg(&params.o_batch_stride);
+    const index_t o_row_stride = __ldg(&params.o_row_stride);
+    const index_t o_head_stride = __ldg(&params.o_head_stride);
+
     typename Kernel_traits::TiledMmaO tiled_mma_o;
     auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx);

From 46bafd9e033137fc5b24470c0b82f47fe62bc3f6 Mon Sep 17 00:00:00 2001
From: Gareth Jones
Date: Sun, 23 Feb 2025 18:45:40 -0800
Subject: [PATCH 4/5] Cache output stride parameters in registers to reduce
 global loads

---
 csrc/flash_fwd_mla_kernel.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/csrc/flash_fwd_mla_kernel.h b/csrc/flash_fwd_mla_kernel.h
index 9265a1a..b78247d 100644
--- a/csrc/flash_fwd_mla_kernel.h
+++ b/csrc/flash_fwd_mla_kernel.h
@@ -28,7 +28,7 @@ constexpr auto getSmemLayoutK() {
     }
 }
 
-template
+template
 struct Flash_fwd_kernel_traits_mla {
     using Element = elem_type;
     using ElementAccum = float;

From 33e110bb66ddda2185f2efaa71f3037b615120a1 Mon Sep 17 00:00:00 2001
From: Gareth Jones
Date: Sun, 23 Feb 2025 20:08:19 -0800
Subject: [PATCH 5/5] implement the index

---
 csrc/flash_fwd_mla_kernel.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/csrc/flash_fwd_mla_kernel.h b/csrc/flash_fwd_mla_kernel.h
index b78247d..63f05f8 100644
--- a/csrc/flash_fwd_mla_kernel.h
+++ b/csrc/flash_fwd_mla_kernel.h
@@ -187,7 +187,7 @@ __forceinline__ __device__ void store(const Flash_fwd_mla_params &params, const
     // Stage accumulator fragment to shared memory using tiled copy
     cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum);
 
-    const index_t row_offset_o = bidb * params.o_batch_stride + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
+    const index_t row_offset_o = bidb * o_batch_stride + m_block * kBlockM * o_row_stride + bidh * o_head_stride;
     const index_t row_offset_oaccum = (((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_v;
     const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
     const index_t row_offset_lseaccum = ((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
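
Note on the pattern in patches 3-5: store() reads the three output strides out of the __grid_constant__ parameter struct once, via __ldg, and keeps them in registers for the offset arithmetic that follows, instead of re-reading params.o_* at each use. The sketch below is a minimal, standalone illustration of that access pattern only; ToyParams, toy_store, the grid mapping, and the placeholder write are invented for the example and are not part of FlashMLA.

#include <cuda_runtime.h>
#include <cstdint>

// Hypothetical stand-in for Flash_fwd_mla_params, reduced to the fields used here.
struct ToyParams {
    int64_t o_batch_stride;
    int64_t o_row_stride;
    int64_t o_head_stride;
    float *o_ptr;
};

__global__ void toy_store(__grid_constant__ const ToyParams params,
                          const int kBlockM) {  // rows per M-tile (a compile-time constant in the real kernel)
    const int m_block = blockIdx.x;  // tile index along the query dimension
    const int bidh    = blockIdx.y;  // head index
    const int bidb    = blockIdx.z;  // batch index

    // Read each stride once through the read-only path and keep it in a register,
    // instead of touching params.o_* every time an offset is formed.
    const int64_t o_batch_stride = __ldg(&params.o_batch_stride);
    const int64_t o_row_stride   = __ldg(&params.o_row_stride);
    const int64_t o_head_stride  = __ldg(&params.o_head_stride);

    const int64_t row_offset_o = bidb * o_batch_stride
                               + int64_t(m_block) * kBlockM * o_row_stride
                               + bidh * o_head_stride;

    if (threadIdx.x == 0) {
        params.o_ptr[row_offset_o] = 0.0f;  // placeholder write so the offset is exercised
    }
}

Whether __ldg buys anything over a plain load of a __grid_constant__ member depends on the compiler and target architecture; the sketch only mirrors the access pattern the patches introduce.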