diff --git a/kernels/flash-attn/cutlass/flash_attn_cute.cu b/kernels/flash-attn/cutlass/flash_attn_cute.cu index c570910f..6fe03fa0 100644 --- a/kernels/flash-attn/cutlass/flash_attn_cute.cu +++ b/kernels/flash-attn/cutlass/flash_attn_cute.cu @@ -37,7 +37,7 @@ struct FlashAttnConfig { // Gmem2Smem config using GmemCopyAtom = - Copy_Atom, T>; + Copy_Atom, T>; static constexpr int GmemValsPerLoad = sizeof(uint128_t) / sizeof(T); static constexpr int GmemThreadsPerRow = HeadDim / GmemValsPerLoad; // each thread reads 128 bit @@ -56,7 +56,7 @@ struct FlashAttnConfig { using SmemCopyAtomTransposed = Copy_Atom; // for column major load using SmemCopyAtomO = - Copy_Atom, + Copy_Atom, T>; // NOTE: stmatrix is only available after sm90, we use a // vectorized copy instead