diff --git a/include/flashinfer/attention/generic/permuted_smem.cuh b/include/flashinfer/attention/generic/permuted_smem.cuh index e1081bd696..3bc8a1ad5a 100644 --- a/include/flashinfer/attention/generic/permuted_smem.cuh +++ b/include/flashinfer/attention/generic/permuted_smem.cuh @@ -45,40 +45,70 @@ constexpr __host__ __device__ __forceinline__ uint32_t upcast_size() { } /*! - * \brief The shared memory wrapper. + * \brief Pure arithmetic layout policy for XOR-swizzled shared memory tiles. + * + * Contains no pointer and no memory — only the coordinate arithmetic that maps + * logical (row, col) to physical cell index and vice-versa. Its methods are + * static device functions implementing pure arithmetic that the compiler can + * typically eliminate entirely at compile time. + * + * This type is passed as a template parameter to smem_t (composition), giving + * the caller an explicit bijection handle: any code that derives global-memory + * coordinates from smem_t::Layout is structurally guaranteed to use the same + * swizzle pattern as the LDS read/write path. */ -template -struct smem_t { - // The base pointer. - BasePtrTy* base; - __device__ __forceinline__ smem_t() : base(nullptr) {} - template - __device__ __forceinline__ smem_t(T* base) : base((BasePtrTy*)base) {} +template +struct SwizzleLayout { + // ── Primitive ────────────────────────────────────────────────────────────── + // XOR mask applied to column bits for a given row. + // XOR is self-inverse, so the same mask is used for both the forward + // (LDS write) and inverse (global read) directions. + template + static __device__ __forceinline__ uint32_t col_swizzle_xor(uint32_t row) { + if constexpr (swizzle_mode == SwizzleMode::k128B_16Row) { + constexpr uint32_t period = (stride >= 16u) ? 16u : 8u; + return row & (period - 1u); + } else if constexpr (swizzle_mode == SwizzleMode::k128B) { + return row & 7u; + } else if constexpr (swizzle_mode == SwizzleMode::k64B) { + return (row >> 1u) & 3u; + } else { + return 0u; // kLinear + } + } + + // ── Derived from the primitive ───────────────────────────────────────────── /*! - * \brief Compute the element offset given coordinates in a permuted shared - * memory. - * \tparam stride The stride (in terms of b128_t's) in the permuted shared - * memory. + * \brief Compute the element offset given coordinates in a permuted shared memory. + * \tparam stride The stride (in terms of BasePtrTy elements) in the permuted shared memory. * \param i The row index. * \param j The column index. */ template static __device__ __forceinline__ uint32_t get_permuted_offset(uint32_t i, uint32_t j) { - if constexpr (swizzle_mode == SwizzleMode::k128B) { - return i * stride + (j ^ (i % 8)); - } else if constexpr (swizzle_mode == SwizzleMode::k128B_16Row) { - // Extend the XOR period from 8 to 16 when stride allows it, eliminating - // the 8-way read-path bank conflicts that k128B has on CDNA3 MI300x. - constexpr uint32_t period = (stride >= 16u) ? 16u : 8u; - return i * stride + (j ^ (i % period)); - } else if constexpr (swizzle_mode == SwizzleMode::k64B) { - static_assert(stride == 4); - return i * stride + (j ^ ((i / 2) % 4)); - } else { - // swizzle_mode == SwizzleMode::kLinear - return i * stride + j; + if constexpr (swizzle_mode == SwizzleMode::k64B) { + static_assert(stride == 4, "k64B swizzle requires stride == 4"); } + return i * stride + (j ^ col_swizzle_xor(i)); + } + + /*! + * \brief Inverse of get_permuted_offset: recover (row, col) from a physical LDS cell index. + * + * XOR is self-inverse ((x ^ mask) ^ mask == x), so the same col_swizzle_xor + * expression serves both the forward and inverse direction in all supported + * XOR-based swizzle modes. + * + * \tparam stride The BasePtrTy-unit row stride (UPCAST_STRIDE_K / UPCAST_STRIDE_Q / etc.). + * \param cell Physical cell index (0 .. smem_size-1). + * \returns b64_t{row, col} — the logical (row, column) that maps to this cell. + */ + template + static __device__ __forceinline__ b64_t get_inverse_offset(uint32_t cell) { + const uint32_t row = cell / stride; + const uint32_t col_sw = cell % stride; + return b64_t{row, col_sw ^ col_swizzle_xor(row)}; } // advance_offset_by_column @@ -191,6 +221,57 @@ struct smem_t { return offset + step_size * row_stride; } } +}; + +/*! + * \brief Shared memory wrapper parameterized over a layout policy (composition). + * + * The LayoutPolicy template parameter (typically SwizzleLayout) owns + * all coordinate arithmetic. smem_t exposes thin __forceinline__ wrappers that + * forward to LayoutPolicy, so existing call sites — smem.get_permuted_offset(i,j), + * smem->advance_offset_by_column(...), etc. — compile without changes. + * + * The Layout type alias is the bijection handle: any caller that derives global-memory + * coordinates via smem_t::Layout is structurally tied to the same swizzle pattern + * as the LDS read/write path. It is impossible to accidentally mix modes. + */ +template +struct smem_t { + using Layout = LayoutPolicy; + + // The base pointer. + BasePtrTy* base; + __device__ __forceinline__ smem_t() : base(nullptr) {} + template + __device__ __forceinline__ smem_t(T* base) : base((BasePtrTy*)base) {} + + // ── Thin wrappers forwarding to Layout ──────────────────────────────────── + // All are static __forceinline__ — zero runtime cost, zero binary bloat. + + template + static __device__ __forceinline__ uint32_t get_permuted_offset(uint32_t i, uint32_t j) { + return Layout::template get_permuted_offset(i, j); + } + + template + static __device__ __forceinline__ b64_t get_inverse_offset(uint32_t cell) { + return Layout::template get_inverse_offset(cell); + } + + template + static __device__ __forceinline__ uint32_t advance_offset_by_column(uint32_t offset, + uint32_t step_idx, + uint32_t col_idx = 0) { + return Layout::template advance_offset_by_column(offset, step_idx, col_idx); + } + + template + static __device__ __forceinline__ uint32_t advance_offset_by_row(uint32_t offset, + uint32_t row_idx = 0) { + return Layout::template advance_offset_by_row(offset, row_idx); + } + + // ── LDS memory operations ───────────────────────────────────────────────── template __device__ __forceinline__ void load_fragment(uint32_t offset, T* frag) { diff --git a/include/flashinfer/attention/generic/prefill.cuh b/include/flashinfer/attention/generic/prefill.cuh index 9131be0f7f..e4d156fcc1 100644 --- a/include/flashinfer/attention/generic/prefill.cuh +++ b/include/flashinfer/attention/generic/prefill.cuh @@ -281,95 +281,92 @@ __device__ __forceinline__ void q_frag_apply_llama_rope_with_pos(T* x_first_half } } -template -__device__ __forceinline__ void produce_kv_impl( - uint32_t warp_idx, uint32_t lane_idx, - smem_t smem, uint32_t* smem_offset, - typename KTraits::DTypeKV** gptr, const uint32_t stride_n, const uint32_t kv_idx_base, - const uint32_t kv_len) { - using DTypeKV = typename KTraits::DTypeKV; - constexpr uint32_t KV_THR_LAYOUT_COL = KTraits::KV_THR_LAYOUT_COL; // 16 - constexpr uint32_t NUM_WARPS = KTraits::NUM_WARPS; - constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; - constexpr uint32_t NUM_WARPS_Q = KTraits::NUM_WARPS_Q; - constexpr uint32_t NUM_MMA_D = produce_v ? KTraits::NUM_MMA_D_VO : KTraits::NUM_MMA_D_QK; - constexpr uint32_t UPCAST_STRIDE = - produce_v ? KTraits::UPCAST_STRIDE_V : KTraits::UPCAST_STRIDE_K; - constexpr uint32_t VECTOR_BIT_WIDTH = KTraits::VECTOR_BIT_WIDTH; - - // NOTE: NUM_MMA_KV*4/NUM_WARPS_Q = NUM_WARPS_KV*NUM_MMA_KV*4/num_warps - static_assert(NUM_MMA_KV * 4 % NUM_WARPS_Q == 0); - uint32_t kv_idx = kv_idx_base + warp_idx * 4 + lane_idx / KV_THR_LAYOUT_COL; - -#pragma unroll - for (uint32_t i = 0; i < NUM_MMA_KV * 4 / NUM_WARPS_Q; ++i) { -#pragma unroll - for (uint32_t j = 0; j < NUM_MMA_D / (8 / sizeof(DTypeKV)); ++j) { - smem.template load_vector_async(*smem_offset, *gptr, kv_idx < kv_len); - *smem_offset = smem.template advance_offset_by_column<16>(*smem_offset, j); - *gptr += 16 * upcast_size(); - } - kv_idx += NUM_WARPS * 4; - *smem_offset = smem.template advance_offset_by_row(*smem_offset) - - (sizeof(DTypeKV) * NUM_MMA_D * 2); - *gptr += NUM_WARPS * 4 * stride_n - - sizeof(DTypeKV) * NUM_MMA_D * 2 * upcast_size(); - } - *smem_offset -= KTraits::CTA_TILE_KV * UPCAST_STRIDE; -} - /*! - * \brief Produce k/v fragments from global memory to shared memory. - * \tparam fill_mode The fill mode of the shared memory. - * \tparam NUM_MMA_D_VO The number of fragments in y dimension. - * \tparam NUM_MMA_KV The number of fragments in z dimension. - * \tparam num_warps The number of warps in the threadblock. - * \tparam T The data type of the input tensor. - * \param smem The shared memory to store kv fragments. - * \param gptr The global memory pointer. - * \param kv_idx_base The base kv index. - * \param kv_len The length of kv tensor. + * \brief Load K or V tile from global memory to shared memory using inverse-swizzle addressing. + * + * The inverse-swizzle implementation reads from global memory using a swizzled + * indexing scheme and writes to LDS using a linear offset. The LDS still stores the + * data in a swizzled order, but the permuted index offset is now computed when + * accessing global memory rather than when writing to LDS. + * This inverse-swizzle access pattern preserves global-memory coalescing, because + * every group of 16 contiguous threads still accesses locations within the same + * cache line. + * + * \param smem The K or V shared memory tile. + * \param gptr_base Pointer to k/v[chunk_start * stride_n + kv_head_idx * stride_h] (chunk base). + * \param kv_idx_base Tile start offset within chunk (0, CTA_TILE_KV, 2*CTA_TILE_KV, ...). + * \param kv_len chunk_size (for bounds check: kv_idx_base + row < kv_len). + * \param stride_n Global row stride in DTypeKV elements (= head_dim). */ template __device__ __forceinline__ void produce_kv( - smem_t smem, uint32_t* smem_offset, - typename KTraits::DTypeKV** gptr, const uint32_t stride_n, const uint32_t kv_idx_base, - const uint32_t kv_len, const dim3 tid = threadIdx) { - const uint32_t warp_idx = get_warp_idx(tid.y, tid.z), lane_idx = tid.x; + smem_t, typename KTraits::SmemBasePtrTy>& smem, + const typename KTraits::DTypeKV* __restrict__ gptr_base, const uint32_t kv_idx_base, + const uint32_t kv_len, const uint32_t stride_n, const dim3 tid = threadIdx) { using DTypeKV = typename KTraits::DTypeKV; - constexpr uint32_t KV_THR_LAYOUT_COL = KTraits::KV_THR_LAYOUT_COL; // 16 - constexpr uint32_t NUM_WARPS = KTraits::NUM_WARPS; - constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; - constexpr uint32_t NUM_WARPS_Q = KTraits::NUM_WARPS_Q; - constexpr uint32_t NUM_MMA_D = produce_v ? KTraits::NUM_MMA_D_VO : KTraits::NUM_MMA_D_QK; + using SmemCell = typename KTraits::SmemBasePtrTy; // uint2 constexpr uint32_t UPCAST_STRIDE = produce_v ? KTraits::UPCAST_STRIDE_V : KTraits::UPCAST_STRIDE_K; - constexpr uint32_t VECTOR_BIT_WIDTH = KTraits::VECTOR_BIT_WIDTH; - - // NOTE: NUM_MMA_KV*4/NUM_WARPS_Q = NUM_WARPS_KV*NUM_MMA_KV*4/num_warps - static_assert(NUM_MMA_KV * 4 % NUM_WARPS_Q == 0); - uint32_t kv_idx = kv_idx_base + warp_idx * 4 + lane_idx / KV_THR_LAYOUT_COL; + constexpr uint32_t SMEM_SZ = KTraits::CTA_TILE_KV * UPCAST_STRIDE; + constexpr uint32_t NUM_THREADS = KTraits::NUM_THREADS; + constexpr uint32_t UPCAST_SZ = upcast_size(); + + const uint32_t tid_linear = get_warp_idx(tid.y, tid.z) * 64u + tid.x; + const SmemCell* gptr_u2 = reinterpret_cast(gptr_base); + const uint32_t stride_n_u2 = stride_n / UPCAST_SZ; // stride in uint2 units + + // NITERS is a compile-time constant (SMEM_SZ / NUM_THREADS, typically 4–8). + // Split into two fully-unrolled phases so the compiler can issue all NITERS + // global loads before any LDS write — allowing the hardware to pipeline the + // N outstanding loads with the N subsequent LDS stores + constexpr uint32_t NITERS = SMEM_SZ / NUM_THREADS; + static_assert(SMEM_SZ % NUM_THREADS == 0, "SMEM_SZ must be divisible by NUM_THREADS"); + + // Incremental coordinate computation — eliminates per-iteration division/shift from Phase A. + // When NUM_THREADS % UPCAST_STRIDE == 0 (guaranteed: both are compile-time powers of 2, + // NUM_THREADS = NUM_WARPS*64 ≥ UPCAST_STRIDE = HEAD_DIM/4 for all supported configs): + // col_sw = (tid_linear + i*NUM_THREADS) & (UPCAST_STRIDE-1) + // = tid_linear & (UPCAST_STRIDE-1) [constant — i*NUM_THREADS contributes 0] + // row_i = (tid_linear + i*NUM_THREADS) / UPCAST_STRIDE + // = row_base + i * ROW_STEP [compile-time stride per unrolled step] + // Net cost per Phase-A iteration: 1 ADD + 1 AND + 1 XOR (vs. 1 ADD + 1 SHR + 2 AND + 1 XOR). + static_assert(NUM_THREADS % UPCAST_STRIDE == 0, + "NUM_THREADS must be divisible by UPCAST_STRIDE for incremental coord computation"); + constexpr uint32_t ROW_STEP = NUM_THREADS / UPCAST_STRIDE; + + const uint32_t col_sw = tid_linear & (UPCAST_STRIDE - 1u); // invariant across all NITERS + const uint32_t row_base = tid_linear / UPCAST_STRIDE; // v_lshrrev_b32, computed once + + // SmemTy::Layout is the bijection handle: global-memory coordinate arithmetic + // is derived from the smem parameter's type, making it structurally impossible + // to use a different swizzle mode for the global load than for the LDS write. + using SmemTy = std::remove_reference_t; + + // Zero-init the staging array so OOB slots are well-defined (avoids UB from + // copying indeterminate values in Phase B when fill_mode == kNoFill). + SmemCell loaded[NITERS] = {}; + + // Phase A: all NITERS global loads issued before any LDS write. +#pragma unroll + for (uint32_t i = 0; i < NITERS; ++i) { + const uint32_t row = row_base + i * ROW_STEP; // compile-time offset per unrolled step + const uint32_t col = col_sw ^ SmemTy::Layout::template col_swizzle_xor(row); + if ((kv_idx_base + row) < kv_len) { + loaded[i] = gptr_u2[(kv_idx_base + row) * stride_n_u2 + col]; + } + } + // Phase B: drain loads and write to LDS. #pragma unroll - for (uint32_t i = 0; i < NUM_MMA_KV * 4 / NUM_WARPS_Q; ++i) { -#pragma unroll - for (uint32_t j = 0; j < NUM_MMA_D / (8 / sizeof(DTypeKV)); ++j) { - smem.template load_vector_async(*smem_offset, *gptr, kv_idx < kv_len); - *smem_offset = smem.template advance_offset_by_column<16>(*smem_offset, j); - *gptr += 16 * upcast_size(); - } - kv_idx += NUM_WARPS * 4; - *smem_offset = smem.template advance_offset_by_row(*smem_offset) - - (sizeof(DTypeKV) * NUM_MMA_D * 2); - *gptr += NUM_WARPS * 4 * stride_n - - sizeof(DTypeKV) * NUM_MMA_D * 2 * upcast_size(); + for (uint32_t i = 0; i < NITERS; ++i) { + smem.base[tid_linear + i * NUM_THREADS] = loaded[i]; } - *smem_offset -= KTraits::CTA_TILE_KV * UPCAST_STRIDE; } template __device__ __forceinline__ void page_produce_kv( - smem_t smem, uint32_t* smem_offset, + smem_t, typename KTraits::SmemBasePtrTy> smem, + uint32_t* smem_offset, const paged_kv_t& paged_kv, const uint32_t kv_idx_base, const size_t* thr_local_kv_offset, const uint32_t kv_len, const dim3 tid = threadIdx) { @@ -479,55 +476,66 @@ __device__ __forceinline__ void init_states( } } +/*! + * \brief Load Q tile from global memory to shared memory using inverse-swizzle addressing. + * + * The inverse-swizzle implementation reads global memory using a swizzled + * indexing and writes the LDS using a linear offset. The LDS still stores the + * data in a swizzled order, but the permuted index offset is now calculated + * when global memory is accessed rather than when the LDS is written. + * The inverse-swizzle access pattern preserves global memory coalescing, as + * every 16 contiguous threads still have their accesses within the same cache + * line. Eliminates the nested mma_q × j × mma_do cursor loop and the row_idx tracking + * required by advance_offset_by_row<4> for k128B_16Row. + * + * \param qo_block_base Block-level packed Q offset = bx * CTA_TILE_Q. + * NOTE: This is NOT the per-warp qo_packed_idx_base; it is the + * same value for every warp in the CTA. The absolute smem_row + * (0..CTA_TILE_Q-1) is added directly to this block base. + * \param qo_upper_bound Valid Q row count (for bounds masking). + * \param q_ptr_base Pointer to q[0][kv_head_idx * group_size][0] (head-adjusted base). + * \param q_stride_n Global stride in DTypeQ elements across the sequence dimension. + * \param q_stride_h Global stride in DTypeQ elements across the head dimension. + * \param group_size GQA group size (num_qo_heads / num_kv_heads). + * \param q_smem Q shared memory tile. + */ template __device__ __forceinline__ void load_q_global_smem( - uint32_t packed_offset, const uint32_t qo_upper_bound, typename KTraits::DTypeQ* q_ptr_base, + uint32_t qo_block_base, const uint32_t qo_upper_bound, typename KTraits::DTypeQ* q_ptr_base, const uint32_t q_stride_n, const uint32_t q_stride_h, const uint_fastdiv group_size, - smem_t* q_smem, + smem_t, typename KTraits::SmemBasePtrTy>* q_smem, const dim3 tid = threadIdx) { using DTypeQ = typename KTraits::DTypeQ; - constexpr uint32_t WARP_THREAD_COLS = KTraits::WARP_THREAD_COLS; - constexpr uint32_t WARP_THREAD_ROWS = KTraits::WARP_THREAD_ROWS; - constexpr uint32_t HALF_ELEMS_PER_THREAD = KTraits::HALF_ELEMS_PER_THREAD; - constexpr uint32_t NUM_MMA_D_QK = KTraits::NUM_MMA_D_QK; + using SmemCell = typename KTraits::SmemBasePtrTy; // uint2 constexpr uint32_t UPCAST_STRIDE_Q = KTraits::UPCAST_STRIDE_Q; - constexpr uint32_t VECTOR_BIT_WIDTH = KTraits::VECTOR_BIT_WIDTH; + constexpr uint32_t SMEM_SZ_Q = KTraits::CTA_TILE_Q * UPCAST_STRIDE_Q; + constexpr uint32_t NUM_THREADS_Q = KTraits::NUM_WARPS_Q * 64u; + constexpr uint32_t UPCAST_SZ = upcast_size(); - constexpr uint32_t COLUMN_RESET_OFFSET = (NUM_MMA_D_QK / 4) * WARP_THREAD_COLS; - const uint32_t lane_idx = tid.x, warp_idx_x = get_warp_idx_q(tid.y); - uint32_t row = lane_idx / WARP_THREAD_COLS; - uint32_t col = lane_idx % WARP_THREAD_COLS; + // Only Q-warps participate (same guard as load_q_global_smem) + if (get_warp_idx_kv(tid.z) != 0) return; - if (get_warp_idx_kv(tid.z) == 0) { - uint32_t q_smem_offset_w = q_smem->template get_permuted_offset( - warp_idx_x * KTraits::NUM_MMA_Q * 16 + row, col); - // row_idx_w: the logical smem row immediately before each advance_offset_by_row<4> - // call. Required by k128B_16Row to select the correct xor_mask; ignored by k128B. - uint32_t row_idx_w = warp_idx_x * KTraits::NUM_MMA_Q * 16 + row; + // Thread index within the Q-warp group (0 .. NUM_THREADS_Q-1) + const uint32_t tid_q = get_warp_idx_q(tid.y) * 64u + tid.x; -#pragma unroll - for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { -#pragma unroll - for (uint32_t j = 0; j < 2 * 2; ++j) { - uint32_t q, r; - group_size.divmod(packed_offset + row + mma_q * 16 + j * 4, q, r); - const uint32_t q_idx = q; - DTypeQ* q_ptr = q_ptr_base + q * q_stride_n + r * q_stride_h + - col * upcast_size(); -#pragma unroll - for (uint32_t mma_do = 0; mma_do < KTraits::NUM_MMA_D_QK / 4; ++mma_do) { - // load q fragment from gmem to smem - q_smem->template load_vector_async(q_smem_offset_w, q_ptr, - q_idx < qo_upper_bound); - q_smem_offset_w = - q_smem->template advance_offset_by_column(q_smem_offset_w, mma_do); - q_ptr += WARP_THREAD_COLS * upcast_size(); - } - q_smem_offset_w = q_smem->template advance_offset_by_row( - q_smem_offset_w, row_idx_w) - - COLUMN_RESET_OFFSET; - row_idx_w += WARP_THREAD_ROWS; - } +#pragma unroll 1 + for (uint32_t c = tid_q; c < SMEM_SZ_Q; c += NUM_THREADS_Q) { + // φ⁻¹: all swizzle logic inside smem_t::get_inverse_offset — nothing baked in here + const auto coords = q_smem->template get_inverse_offset(c); + const uint32_t smem_row = coords.x; // absolute row in Q smem (0..CTA_TILE_Q-1) + const uint32_t u2_col = coords.y; + + // Map absolute smem_row to packed token space using the BLOCK-LEVEL base. + // qo_block_base = bx * CTA_TILE_Q — same for all warps in the CTA. + uint32_t q_seq, q_head; + group_size.divmod(qo_block_base + smem_row, q_seq, q_head); + const bool valid = q_seq < qo_upper_bound; + + if (valid) { + q_smem->base[c] = *reinterpret_cast( + q_ptr_base + q_seq * q_stride_n + q_head * q_stride_h + u2_col * UPCAST_SZ); + } else { + q_smem->base[c] = SmemCell{0u, 0u}; } } } @@ -536,7 +544,7 @@ template __device__ __forceinline__ void q_smem_inplace_apply_rotary( const uint32_t q_packed_idx, const uint32_t qo_len, const uint32_t kv_len, const uint_fastdiv group_size, - smem_t* q_smem, + smem_t, typename KTraits::SmemBasePtrTy>* q_smem, uint32_t* q_smem_offset_r, float (*rope_freq)[4], const dim3 tid = threadIdx) { if (get_warp_idx_kv(tid.z) != 0) return; @@ -583,7 +591,7 @@ __device__ __forceinline__ void q_smem_inplace_apply_rotary( template __device__ __forceinline__ void q_smem_inplace_apply_rotary_with_pos( const uint32_t q_packed_idx_base, const typename KTraits::IdType* q_rope_offset, - smem_t* q_smem, + smem_t, typename KTraits::SmemBasePtrTy>* q_smem, const uint_fastdiv group_size, uint32_t* q_smem_offset_r, float (*rope_freq)[4], const dim3 tid = threadIdx) { if (get_warp_idx_kv(tid.z) == 0) { @@ -624,7 +632,7 @@ __device__ __forceinline__ void q_smem_inplace_apply_rotary_with_pos( template __device__ __forceinline__ void k_smem_inplace_apply_rotary( const uint32_t kv_idx_base, - smem_t* k_smem, + smem_t, typename KTraits::SmemBasePtrTy>* k_smem, uint32_t* k_smem_offset_r, float (*rope_freq)[4], const dim3 tid = threadIdx) { using DTypeKV = typename KTraits::DTypeKV; static_assert(sizeof(DTypeKV) == 2); @@ -716,9 +724,9 @@ __device__ __forceinline__ void k_smem_inplace_apply_rotary( template __device__ __forceinline__ void compute_qk( - smem_t* q_smem, + smem_t, typename KTraits::SmemBasePtrTy>* q_smem, uint32_t* q_smem_offset_r, - smem_t* k_smem, + smem_t, typename KTraits::SmemBasePtrTy>* k_smem, uint32_t* k_smem_offset_r, typename KTraits::DTypeQKAccum (*s_frag)[KTraits::NUM_MMA_KV][KTraits::HALF_ELEMS_PER_THREAD], const dim3 tid = threadIdx) { @@ -954,7 +962,7 @@ __device__ __forceinline__ void update_mdo_states( template __device__ __forceinline__ void compute_sfm_v( - smem_t* v_smem, + smem_t, typename KTraits::SmemBasePtrTy>* v_smem, uint32_t* v_smem_offset_r, typename KTraits::DTypeQKAccum (*s_frag)[KTraits::NUM_MMA_KV][KTraits::HALF_ELEMS_PER_THREAD], float (*o_frag)[KTraits::NUM_MMA_D_VO][KTraits::HALF_ELEMS_PER_THREAD], @@ -1240,7 +1248,7 @@ __device__ __forceinline__ void threadblock_sync_mdo_states( template __device__ __forceinline__ void write_o_reg_gmem( float (*o_frag)[KTraits::NUM_MMA_D_VO][KTraits::HALF_ELEMS_PER_THREAD], - smem_t* o_smem, + smem_t, typename KTraits::SmemBasePtrTy>* o_smem, typename KTraits::DTypeO* o_ptr_base, const uint32_t o_packed_idx_base, const uint32_t qo_upper_bound, const uint32_t o_stride_n, const uint32_t o_stride_h, const uint_fastdiv group_size, const dim3 tid = threadIdx) { @@ -1319,8 +1327,8 @@ __device__ __forceinline__ void write_o_reg_gmem( uint32_t o_smem_offset_w = o_smem->template get_permuted_offset( warp_idx_x * KTraits::NUM_MMA_Q * 16 + lane_idx / WARP_THREAD_COLS, lane_idx % WARP_THREAD_COLS); - // row_idx_ow mirrors the row_idx_w tracking in load_q_global_smem. - // Required by k128B_16Row advance_offset_by_row<4>; ignored by k128B. + // row_idx_ow mirrors the row_idx_w tracking required by k128B_16Row + // advance_offset_by_row<4> uint32_t row_idx_ow = warp_idx_x * KTraits::NUM_MMA_Q * 16 + lane_idx / WARP_THREAD_COLS; #pragma unroll @@ -1459,16 +1467,16 @@ __device__ __forceinline__ void SinglePrefillWithKVCacheDevice( // cooperative fetch q fragment from gmem to reg const uint32_t qo_packed_idx_base = (bx * NUM_WARPS_Q + get_warp_idx_q(tid.y)) * NUM_MMA_Q * 16; - smem_t qo_smem(smem_storage.q_smem); + smem_t, typename KTraits::SmemBasePtrTy> qo_smem( + smem_storage.q_smem); const uint32_t o_stride_n = num_qo_heads * HEAD_DIM_VO, o_stride_h = HEAD_DIM_VO; DTypeQ* q_ptr_base = q + (kv_head_idx * group_size) * q_stride_h; DTypeO* o_ptr_base = partition_kv ? o + chunk_idx * o_stride_n + (kv_head_idx * group_size) * o_stride_h : o + (kv_head_idx * group_size) * o_stride_h; - load_q_global_smem(qo_packed_idx_base, qo_len, q_ptr_base, q_stride_n, q_stride_h, + load_q_global_smem(bx * CTA_TILE_Q, qo_len, q_ptr_base, q_stride_n, q_stride_h, group_size, &qo_smem, tid); - uint32_t q_smem_offset_r = qo_smem.template get_permuted_offset( get_warp_idx_q(tid.y) * NUM_MMA_Q * 16 + lane_idx % 16, lane_idx / 16); @@ -1481,8 +1489,10 @@ __device__ __forceinline__ void SinglePrefillWithKVCacheDevice( block.sync(); } - smem_t k_smem(smem_storage.k_smem); - smem_t v_smem(smem_storage.v_smem); + smem_t, typename KTraits::SmemBasePtrTy> k_smem( + smem_storage.k_smem); + smem_t, typename KTraits::SmemBasePtrTy> v_smem( + smem_storage.v_smem); const uint32_t num_iterations = ceil_div(MASK_MODE == MaskMode::kCausal @@ -1504,31 +1514,20 @@ __device__ __forceinline__ void SinglePrefillWithKVCacheDevice( : chunk_size) / CTA_TILE_KV; - DTypeKV* k_ptr = - k + (chunk_start + warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL) * k_stride_n + - kv_head_idx * k_stride_h + - (lane_idx % KV_THR_LAYOUT_COL) * upcast_size(); - - DTypeKV* v_ptr = - v + (chunk_start + warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL) * v_stride_n + - kv_head_idx * v_stride_h + - (lane_idx % KV_THR_LAYOUT_COL) * upcast_size(); + // Base pointers for the KV chunk + const DTypeKV* k_chunk_base = k + chunk_start * k_stride_n + kv_head_idx * k_stride_h; + const DTypeKV* v_chunk_base = v + chunk_start * v_stride_n + kv_head_idx * v_stride_h; uint32_t k_smem_offset_r = k_smem.template get_permuted_offset( get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + lane_idx % 16, (lane_idx / 16)); uint32_t v_smem_offset_r = v_smem.template get_permuted_offset( get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + lane_idx % 16, lane_idx / 16); - uint32_t k_smem_offset_w = k_smem.template get_permuted_offset( - warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL, - lane_idx % KV_THR_LAYOUT_COL), - v_smem_offset_w = v_smem.template get_permuted_offset( - warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL, - lane_idx % KV_THR_LAYOUT_COL); - produce_kv(k_smem, &k_smem_offset_w, &k_ptr, - k_stride_n, 0, chunk_size, tid); + + produce_kv(k_smem, k_chunk_base, 0, chunk_size, + k_stride_n, tid); memory::commit_group(); - produce_kv(v_smem, &v_smem_offset_w, &v_ptr, - v_stride_n, 0, chunk_size, tid); + produce_kv(v_smem, v_chunk_base, 0, chunk_size, + v_stride_n, tid); memory::commit_group(); #pragma unroll 1 @@ -1558,7 +1557,7 @@ __device__ __forceinline__ void SinglePrefillWithKVCacheDevice( update_mdo_states(variant, s_frag, o_frag, m, d); block.sync(); produce_kv( - k_smem, &k_smem_offset_w, &k_ptr, k_stride_n, (iter + 1) * CTA_TILE_KV, chunk_size, tid); + k_smem, k_chunk_base, (iter + 1) * CTA_TILE_KV, chunk_size, k_stride_n, tid); memory::commit_group(); memory::wait_group<1>(); block.sync(); @@ -1567,7 +1566,7 @@ __device__ __forceinline__ void SinglePrefillWithKVCacheDevice( compute_sfm_v(&v_smem, &v_smem_offset_r, s_frag, o_frag, d); block.sync(); produce_kv( - v_smem, &v_smem_offset_w, &v_ptr, v_stride_n, (iter + 1) * CTA_TILE_KV, chunk_size, tid); + v_smem, v_chunk_base, (iter + 1) * CTA_TILE_KV, chunk_size, v_stride_n, tid); memory::commit_group(); } memory::wait_group<0>(); @@ -1870,7 +1869,8 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithRaggedKV const uint32_t qo_packed_idx_base = (qo_tile_idx * NUM_WARPS_Q + get_warp_idx_q(tid.y)) * NUM_MMA_Q * 16; - smem_t qo_smem(smem_storage.q_smem); + smem_t, typename KTraits::SmemBasePtrTy> qo_smem( + smem_storage.q_smem); const uint32_t o_stride_n = num_qo_heads * HEAD_DIM_VO, o_stride_h = HEAD_DIM_VO; DTypeQ* q_ptr_base = @@ -1884,7 +1884,7 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithRaggedKV uint32_t q_smem_offset_r = qo_smem.template get_permuted_offset( get_warp_idx_q(tid.y) * NUM_MMA_Q * 16 + lane_idx % 16, lane_idx / 16); - load_q_global_smem(qo_packed_idx_base, qo_upper_bound, q_ptr_base, q_stride_n, + load_q_global_smem(qo_tile_idx * CTA_TILE_Q, qo_upper_bound, q_ptr_base, q_stride_n, q_stride_h, group_size, &qo_smem, tid); memory::commit_group(); @@ -1929,40 +1929,27 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithRaggedKV : chunk_size) / CTA_TILE_KV; - smem_t k_smem(smem_storage.k_smem); - smem_t v_smem(smem_storage.v_smem); + smem_t, typename KTraits::SmemBasePtrTy> k_smem( + smem_storage.k_smem); + smem_t, typename KTraits::SmemBasePtrTy> v_smem( + smem_storage.v_smem); uint32_t k_smem_offset_r = k_smem.template get_permuted_offset( get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + lane_idx % 16, (lane_idx / 16)); uint32_t v_smem_offset_r = v_smem.template get_permuted_offset( get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + lane_idx % 16, lane_idx / 16); - uint32_t k_smem_offset_w = k_smem.template get_permuted_offset( - warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL, - lane_idx % KV_THR_LAYOUT_COL), - v_smem_offset_w = v_smem.template get_permuted_offset( - warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL, - lane_idx % KV_THR_LAYOUT_COL); + // Base pointers for the KV chunk (no per-thread offset — produce_kv uses cell index) + const DTypeKV* k_chunk_base = + k + (kv_indptr[request_idx] + chunk_start) * k_stride_n + kv_head_idx * k_stride_h; + const DTypeKV* v_chunk_base = + v + (kv_indptr[request_idx] + chunk_start) * v_stride_n + kv_head_idx * v_stride_h; - DTypeKV* k_ptr = k + - (kv_indptr[request_idx] + chunk_start + warp_idx * KV_THR_LAYOUT_ROW + - lane_idx / KV_THR_LAYOUT_COL) * - k_stride_n + - kv_head_idx * k_stride_h + - (lane_idx % KV_THR_LAYOUT_COL) * upcast_size(); - DTypeKV* v_ptr = v + - (kv_indptr[request_idx] + chunk_start + warp_idx * KV_THR_LAYOUT_ROW + - lane_idx / KV_THR_LAYOUT_COL) * - v_stride_n + - kv_head_idx * v_stride_h + - (lane_idx % KV_THR_LAYOUT_COL) * upcast_size(); - - produce_kv(k_smem, &k_smem_offset_w, &k_ptr, - k_stride_n, 0, chunk_size, tid); + produce_kv(k_smem, k_chunk_base, 0, chunk_size, + k_stride_n, tid); memory::commit_group(); - produce_kv(v_smem, &v_smem_offset_w, &v_ptr, - v_stride_n, 0, chunk_size, tid); - + produce_kv(v_smem, v_chunk_base, 0, chunk_size, + v_stride_n, tid); memory::commit_group(); #pragma unroll 1 @@ -2003,7 +1990,7 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithRaggedKV block.sync(); produce_kv( - k_smem, &k_smem_offset_w, &k_ptr, k_stride_n, (iter + 1) * CTA_TILE_KV, chunk_size, tid); + k_smem, k_chunk_base, (iter + 1) * CTA_TILE_KV, chunk_size, k_stride_n, tid); memory::commit_group(); memory::wait_group<1>(); block.sync(); @@ -2013,7 +2000,7 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithRaggedKV block.sync(); produce_kv( - v_smem, &v_smem_offset_w, &v_ptr, v_stride_n, (iter + 1) * CTA_TILE_KV, chunk_size, tid); + v_smem, v_chunk_base, (iter + 1) * CTA_TILE_KV, chunk_size, v_stride_n, tid); memory::commit_group(); } memory::wait_group<0>(); @@ -2157,7 +2144,8 @@ __device__ __forceinline__ void BatchPrefillWithPagedKVCacheDevice( const uint32_t qo_packed_idx_base = (qo_tile_idx * NUM_WARPS_Q + get_warp_idx_q(tid.y)) * NUM_MMA_Q * 16; - smem_t qo_smem(smem_storage.q_smem); + smem_t, typename KTraits::SmemBasePtrTy> qo_smem( + smem_storage.q_smem); const uint32_t o_stride_n = num_qo_heads * HEAD_DIM_VO, o_stride_h = HEAD_DIM_VO; DTypeQ* q_ptr_base = q + q_indptr[request_idx] * q_stride_n + (kv_head_idx * group_size) * q_stride_h; @@ -2168,7 +2156,7 @@ __device__ __forceinline__ void BatchPrefillWithPagedKVCacheDevice( uint32_t q_smem_offset_r = qo_smem.template get_permuted_offset( get_warp_idx_q(tid.y) * NUM_MMA_Q * 16 + lane_idx % 16, lane_idx / 16); - load_q_global_smem(qo_packed_idx_base, qo_upper_bound, q_ptr_base, q_stride_n, + load_q_global_smem(qo_tile_idx * CTA_TILE_Q, qo_upper_bound, q_ptr_base, q_stride_n, q_stride_h, group_size, &qo_smem, tid); memory::commit_group(); @@ -2191,8 +2179,10 @@ __device__ __forceinline__ void BatchPrefillWithPagedKVCacheDevice( block.sync(); } - smem_t k_smem(smem_storage.k_smem); - smem_t v_smem(smem_storage.v_smem); + smem_t, typename KTraits::SmemBasePtrTy> k_smem( + smem_storage.k_smem); + smem_t, typename KTraits::SmemBasePtrTy> v_smem( + smem_storage.v_smem); // The thr_local_kv_offset array stores the offsets into the paged kv cache for each // thread. The size of the array should be equal to the trip count of the initialization loop. diff --git a/include/gpu_iface/backend/hip/memory_ops_hip.h b/include/gpu_iface/backend/hip/memory_ops_hip.h index 50e757effc..157bd9ae6d 100644 --- a/include/gpu_iface/backend/hip/memory_ops_hip.h +++ b/include/gpu_iface/backend/hip/memory_ops_hip.h @@ -11,27 +11,203 @@ namespace memory { namespace detail { namespace hip { -__device__ __forceinline__ void commit_group() { - // Currently a no-op for HIP -} +// ─── Pipeline fence primitives ─────────────────────────────────────────────── + +/** + * @brief No-op on HIP: vmcnt is a continuous hardware counter, so no explicit + * group flush is required before s_waitcnt vmcnt(N). + */ +__device__ __forceinline__ void commit_group() {} +/** + * @brief Stall until at most N outstanding async VMEM operations remain. + * + * With synchronous register-to-LDS stores, vmcnt is always 0 at the point + * this is called, so the instruction would return immediately regardless. + * The s_waitcnt line is disabled until async buffer_load_dword_lds is wired + * into the KV tile load path. + * + * @tparam N Number of in-flight VMEM ops to allow through (0 = drain all). + */ template __device__ __forceinline__ void wait_group() { - // Currently a no-op for HIP + // TODO: Uncomment once async buffer_load_dword_lds is wired into produce_kv. + // asm volatile("s_waitcnt vmcnt(%0)" : : "n"(N) : "memory"); +} + +// ─── Async GMEM → LDS primitives ───────────────────────────────────────────── +// +// GFX9 / CDNA3 buffer_load_dword_lds semantics: +// LDS[ M0 + lane_k * size ] ← global[ rsrc.base + voffset[k] + soffset ] +// where M0 = readfirstlane(lds_ptr). +// +// IMPORTANT: lds_ptr must be uniform (the same value in every lane of the +// wavefront). Use to_sgpr_u32() to enforce this before passing the pointer. + +/// 4-SGPR buffer resource descriptor (V#/SRD) for MUBUF and buffer_load_lds. +/// The ext_vector_type attribute keeps the four words in consecutive SGPRs. +using srsrc_t = int __attribute__((ext_vector_type(4))); + +/// LDS address-space pointer (AMDGPU address space 3 = local/shared memory). +using lds_ptr_t = uint32_t __attribute__((address_space(3)))*; + +/// V# descriptor word 3 for raw MUBUF access on the GFX9 family (gfx942 included). +/// Matches CK_TILE_BUFFER_RESOURCE_3RD_DWORD. +static constexpr uint32_t kBufferResource3rdDword = 0x00020000u; + +/// Wavefront width for GFX9 / CDNA3. +static constexpr uint32_t kWarpSize = 64u; + +/** + * @brief LLVM intrinsic that emits a buffer_load_dword_lds instruction. + * + * @param rsrc Buffer resource descriptor (V#). + * @param lds_ptr Uniform LDS destination pointer. + * @param size Element size in bytes (1, 2, or 4 on GFX9). + * @param voffset Per-lane global byte offset. + * @param soffset Scalar byte offset added to voffset. + * @param offset Immediate byte offset. + * @param aux Auxiliary flags (GLC/SLC/etc.). + */ +__device__ extern void _fi_async_load_to_lds(srsrc_t rsrc, lds_ptr_t lds_ptr, int size, int voffset, + int soffset, int offset, + int aux) __asm("llvm.amdgcn.raw.buffer.load.lds"); + +/** + * @brief Promote a value to the SGPR register class. + * + * readfirstlane makes the value uniform across the wavefront; the asm volatile + * constraint prevents LLVM from re-classifying it as a VGPR after the fact. + * Required when passing an LDS pointer to async_load_dword_to_lds(). + * + * @param x Value to pin to an SGPR. + * @return The same value, guaranteed to reside in an SGPR. + */ +__device__ __forceinline__ uint32_t to_sgpr_u32(uint32_t x) { + x = __builtin_amdgcn_readfirstlane(x); + asm volatile("" : "+s"(x)); + return x; } -/// @brief loads 128 bits from global to shared memory +/** + * @brief Build a V# buffer resource descriptor for a contiguous device buffer. + * + * Typically computed once per warp and reused across inner-loop iterations. + * + * @param base Base device pointer. + * @param num_bytes Byte extent of the region in bytes. Pass 0xFFFFFFFF to + * disable hardware bounds checking when per-lane offsets are + * already validated by the caller. + * @return 4-SGPR buffer resource descriptor. + */ +__device__ __forceinline__ srsrc_t make_srsrc(const void* base, uint32_t num_bytes) { + // V# layout: words 0-1 = 64-bit base address, word 2 = byte count, word 3 = format. + struct __attribute__((packed)) BufRsrc { + uint64_t ptr; + uint32_t range; + uint32_t config; + }; + BufRsrc res{reinterpret_cast(base), num_bytes, kBufferResource3rdDword}; + return __builtin_bit_cast(srsrc_t, res); +} + +/** + * @brief Issue one buffer_load_dword_lds for the whole wavefront (vmcnt += 1). + * + * Each lane loads one dword from global memory into its corresponding LDS slot: + * @code + * LDS[ lds_base_uniform + lane_k * 4 ] ← global[ rsrc.base + voffset[k] ] + * @endcode + * + * @param lds_base_uniform Uniform LDS byte offset for the tile. Must be + * pinned to an SGPR with to_sgpr_u32() before this call. + * @param rsrc Buffer descriptor from make_srsrc(). + * @param voffset Per-lane global byte offset into the buffer. + */ +__device__ __forceinline__ void async_load_dword_to_lds(uint32_t lds_base_uniform, srsrc_t rsrc, + uint32_t voffset) { + _fi_async_load_to_lds(rsrc, (lds_ptr_t)(uintptr_t)lds_base_uniform, 4, static_cast(voffset), + 0, 0, 0); +} + +/** + * @brief Async-load a 64-thread (256-byte) tile from global memory to LDS. + * + * Covers one full wavefront of dwords in a single call. For larger tiles, + * advance lds_cur and global_base in steps of kWarpSize * sizeof(uint32_t). + * + * @param lds_cur Uniform LDS byte offset (to_sgpr_u32()-pinned). + * @param rsrc Buffer descriptor from make_srsrc(). + * @param global_base Uniform byte offset of the tile start within the buffer. + */ +__device__ __forceinline__ void async_load_tile64_to_lds(uint32_t lds_cur, srsrc_t rsrc, + uint32_t global_base) { + uint32_t lane = threadIdx.x & 0x3fu; + async_load_dword_to_lds(lds_cur, rsrc, + global_base + lane * static_cast(sizeof(uint32_t))); +} + +/** + * @brief Deprecated overload with a divergent LDS pointer — do not use. + * + * The original signature accepted a per-lane lds_ptr_t. This is incorrect: + * M0 is derived from readfirstlane(lds_ptr), so all lanes write to the lane-0 + * offset regardless of their own pointer value. The overload is preserved to + * catch remaining call-sites at compile time via the [[deprecated]] attribute. + * + * @param lds_dst Per-lane LDS destination (divergent; see note above). + * @param rsrc Buffer descriptor from make_srsrc(). + * @param global_byte_offset Per-lane byte offset from rsrc base. + */ +// TODO: Remove once all call-sites are migrated to async_load_tile64_to_lds. +[[deprecated("Use async_load_tile64_to_lds with a uniform LDS base instead.")]] +__device__ __forceinline__ void async_load_64b_to_lds(lds_ptr_t lds_dst, srsrc_t rsrc, + uint32_t global_byte_offset) { + _fi_async_load_to_lds(rsrc, lds_dst, 4, static_cast(global_byte_offset), 0, 0, 0); + _fi_async_load_to_lds(rsrc, lds_dst + 1, 4, static_cast(global_byte_offset + 4), 0, 0, 0); +} + +// ─── Synchronous load functions ─────────────────────────────────────────────── + +/** + * @brief Load 128 bits from global to shared memory. + * + * @tparam PrefetchOpt Prefetch hint (unused on HIP; kept for API parity with the CUDA path). + * @tparam T Element type. + * @param smem_ptr Shared memory destination. + * @param gmem_ptr Global memory source. + */ template __device__ __forceinline__ void load_128b(T* smem_ptr, const T* gmem_ptr) { *reinterpret_cast(smem_ptr) = *reinterpret_cast(gmem_ptr); } +/** + * @brief Load 64 bits from global to shared memory. + * + * @tparam PrefetchOpt Prefetch hint (unused on HIP). + * @tparam T Element type. + * @param smem_ptr Shared memory destination. + * @param gmem_ptr Global memory source. + */ template __device__ __forceinline__ void load_64b(T* smem_ptr, const T* gmem_ptr) { *reinterpret_cast(smem_ptr) = *reinterpret_cast(gmem_ptr); } -// Predicated 128-bit load +/** + * @brief Predicated 128-bit load from global to shared memory. + * + * When predicate is false and FillOpt is kFillZero, the destination is zeroed; + * with kNoFill it is left unchanged. + * + * @tparam PrefetchOpt Prefetch hint (unused on HIP). + * @tparam FillOpt Fill mode applied when predicate is false. + * @tparam T Element type. + * @param smem_ptr Shared memory destination. + * @param gmem_ptr Global memory source. + * @param predicate When false, the global load is skipped. + */ template __device__ __forceinline__ void pred_load_128b(T* smem_ptr, const T* gmem_ptr, bool predicate) { if (predicate) { @@ -43,6 +219,16 @@ __device__ __forceinline__ void pred_load_128b(T* smem_ptr, const T* gmem_ptr, b } } +/** + * @brief Predicated 64-bit load from global to shared memory. + * + * @tparam PrefetchOpt Prefetch hint (unused on HIP). + * @tparam FillOpt Fill mode applied when predicate is false. + * @tparam T Element type. + * @param smem_ptr Shared memory destination. + * @param gmem_ptr Global memory source. + * @param predicate When false, the global load is skipped. + */ template __device__ __forceinline__ void pred_load_64b(T* smem_ptr, const T* gmem_ptr, bool predicate) { if (predicate) { @@ -54,7 +240,15 @@ __device__ __forceinline__ void pred_load_64b(T* smem_ptr, const T* gmem_ptr, bo } } -// Generic load with NumBits template parameter +/** + * @brief Load NumBits bits from global to shared memory. + * + * @tparam NumBits Transfer width in bits (128 or 256). + * @tparam PrefetchOpt Prefetch hint (unused on HIP). + * @tparam T Element type. + * @param smem_ptr Shared memory destination. + * @param gmem_ptr Global memory source. + */ template __device__ __forceinline__ void load(T* smem_ptr, const T* gmem_ptr) { static_assert(NumBits == 128 || NumBits == 256, "NumBits must be 128 or 256"); @@ -66,12 +260,21 @@ __device__ __forceinline__ void load(T* smem_ptr, const T* gmem_ptr) { } } -// Generic predicated load with NumBits template parameter +/** + * @brief Predicated load of NumBits bits from global to shared memory. + * + * @tparam NumBits Transfer width in bits (64, 128, or 256). + * @tparam PrefetchOpt Prefetch hint (unused on HIP). + * @tparam FillOpt Fill mode applied when predicate is false. + * @tparam T Element type. + * @param smem_ptr Shared memory destination. + * @param gmem_ptr Global memory source. + * @param predicate When false, the global load is skipped. + */ template __device__ __forceinline__ void pred_load(T* smem_ptr, const T* gmem_ptr, bool predicate) { static_assert(NumBits == 64 || NumBits == 128 || NumBits == 256, "NumBits must be 64, 128 or 256"); - if constexpr (NumBits == 64) { pred_load_64b(smem_ptr, gmem_ptr, predicate); } else if constexpr (NumBits == 128) { diff --git a/include/gpu_iface/fastdiv.cuh b/include/gpu_iface/fastdiv.cuh index a19d01def7..c4a207fc76 100644 --- a/include/gpu_iface/fastdiv.cuh +++ b/include/gpu_iface/fastdiv.cuh @@ -62,7 +62,7 @@ struct uint_fastdiv { #ifdef __CUDA_ARCH__ q = __umulhi(m, n); #else - q = (((unsigned long long)((long long)m * (long long)n)) >> 32); + q = (uint32_t)(((uint64_t)m * (uint64_t)n) >> 32); #endif q += a * n; q >>= s; @@ -80,7 +80,7 @@ __host__ __device__ __forceinline__ uint32_t operator/(const uint32_t n, #ifdef __CUDA_ARCH__ q = __umulhi(divisor.m, n); #else - q = (((unsigned long long)((long long)divisor.m * (long long)n)) >> 32); + q = (uint32_t)(((uint64_t)divisor.m * (uint64_t)n) >> 32); #endif q += divisor.a * n; q >>= divisor.s; diff --git a/include/gpu_iface/memory_ops.hpp b/include/gpu_iface/memory_ops.hpp index ad21ddb2e6..da4b3dc5c8 100644 --- a/include/gpu_iface/memory_ops.hpp +++ b/include/gpu_iface/memory_ops.hpp @@ -124,6 +124,38 @@ __device__ __forceinline__ void pred_load(T* smem_ptr, const T* gmem_ptr, bool p mem_detail::pred_load(smem_ptr, gmem_ptr, predicate); } +#if defined(PLATFORM_HIP_DEVICE) +// === HIP-only async GMEM → LDS primitives + +/** + * @brief Build a buffer resource descriptor (V#) for async GMEM→LDS copies. + * + * @param base Tensor base pointer (K or V head pointer) + * @param num_bytes Byte size of the region (use 0xFFFFFFFF to skip bounds check) + * @return srsrc_t 4-SGPR buffer resource descriptor + */ +__device__ __forceinline__ mem_detail::srsrc_t make_srsrc(const void* base, uint32_t num_bytes) { + return mem_detail::make_srsrc(base, num_bytes); +} + +/** + * @brief Async load 64 bits from global buffer to LDS. + * + * Issues two buffer_load_dword lds instructions (vmcnt += 2). + * The wavefront does not stall; the caller must later call + * wait_group() + __syncthreads() before reading lds_dst. + * + * @param lds_dst LDS destination (2 consecutive uint32 slots) + * @param rsrc Buffer resource from make_srsrc() + * @param global_byte_offset Per-thread byte offset from rsrc base + */ +__device__ __forceinline__ void async_load_64b_to_lds(mem_detail::lds_ptr_t lds_dst, + mem_detail::srsrc_t rsrc, + uint32_t global_byte_offset) { + mem_detail::async_load_64b_to_lds(lds_dst, rsrc, global_byte_offset); +} +#endif // PLATFORM_HIP_DEVICE + } // namespace memory } // namespace gpu_iface } // namespace flashinfer