From 187af86806e441af08f74a0df41cb277d1e13f67 Mon Sep 17 00:00:00 2001 From: Hui Zhou <90194592+kitecats@users.noreply.github.com> Date: Sat, 31 May 2025 09:33:51 +0800 Subject: [PATCH] Use 128-bit data loading --- kernels/flash-attn/cutlass/flash_attn_cute.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/kernels/flash-attn/cutlass/flash_attn_cute.cu b/kernels/flash-attn/cutlass/flash_attn_cute.cu index c570910f..6fe03fa0 100644 --- a/kernels/flash-attn/cutlass/flash_attn_cute.cu +++ b/kernels/flash-attn/cutlass/flash_attn_cute.cu @@ -37,7 +37,7 @@ struct FlashAttnConfig { // Gmem2Smem config using GmemCopyAtom = - Copy_Atom, T>; + Copy_Atom, T>; static constexpr int GmemValsPerLoad = sizeof(uint128_t) / sizeof(T); static constexpr int GmemThreadsPerRow = HeadDim / GmemValsPerLoad; // each thread reads 128 bit @@ -56,7 +56,7 @@ struct FlashAttnConfig { using SmemCopyAtomTransposed = Copy_Atom; // for column major load using SmemCopyAtomO = - Copy_Atom, + Copy_Atom, T>; // NOTE: stmatrix is only available after sm90, we use a // vectorized copy instead