Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 106 additions & 25 deletions include/flashinfer/attention/generic/permuted_smem.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -45,40 +45,70 @@ constexpr __host__ __device__ __forceinline__ uint32_t upcast_size() {
}

/*!
* \brief The shared memory wrapper.
* \brief Pure arithmetic layout policy for XOR-swizzled shared memory tiles.
*
* Contains no pointer and no memory — only the coordinate arithmetic that maps
* logical (row, col) to physical cell index and vice-versa. Its methods are
* static device functions implementing pure arithmetic that the compiler can
* typically eliminate entirely at compile time.
*
* This type is passed as a template parameter to smem_t (composition), giving
* the caller an explicit bijection handle: any code that derives global-memory
* coordinates from smem_t::Layout is structurally guaranteed to use the same
* swizzle pattern as the LDS read/write path.
*/
template <SwizzleMode swizzle_mode, typename BasePtrTy = b128_t>
struct smem_t {
// The base pointer.
BasePtrTy* base;
__device__ __forceinline__ smem_t() : base(nullptr) {}
template <typename T>
__device__ __forceinline__ smem_t(T* base) : base((BasePtrTy*)base) {}
template <SwizzleMode swizzle_mode>
struct SwizzleLayout {
// ── Primitive ──────────────────────────────────────────────────────────────
// XOR mask applied to column bits for a given row.
// XOR is self-inverse, so the same mask is used for both the forward
// (LDS write) and inverse (global read) directions.
// Returns the XOR mask that is applied to the column index of a given row.
// The mask depends only on the compile-time swizzle mode (and, for
// k128B_16Row, on the compile-time stride); modes not listed below get a zero mask.
template <uint32_t stride>
static __device__ __forceinline__ uint32_t col_swizzle_xor(uint32_t row) {
  if constexpr (swizzle_mode == SwizzleMode::k64B) {
    // Pairs of consecutive rows share a mask; the mask cycles through 0..3.
    return (row / 2u) % 4u;
  } else if constexpr (swizzle_mode == SwizzleMode::k128B) {
    // The mask cycles through 0..7, one value per row.
    return row % 8u;
  } else if constexpr (swizzle_mode == SwizzleMode::k128B_16Row) {
    // 16-row period when the stride is at least 16 cells, otherwise 8 rows.
    constexpr uint32_t low_bits = (stride >= 16u) ? 15u : 7u;
    return row & low_bits;
  } else {
    // kLinear: columns are left unswizzled.
    return 0u;
  }
}

// ── Derived from the primitive ─────────────────────────────────────────────

/*!
* \brief Compute the element offset given coordinates in a permuted shared
* memory.
* \tparam stride The stride (in terms of b128_t's) in the permuted shared
* memory.
* \brief Compute the element offset given coordinates in a permuted shared memory.
* \tparam stride The row stride of the permuted shared memory, counted in cells of the wrapping smem_t's BasePtrTy (b128_t by default).
* \param i The row index.
* \param j The column index.
*/
template <uint32_t stride>
static __device__ __forceinline__ uint32_t get_permuted_offset(uint32_t i, uint32_t j) {
if constexpr (swizzle_mode == SwizzleMode::k128B) {
return i * stride + (j ^ (i % 8));
} else if constexpr (swizzle_mode == SwizzleMode::k128B_16Row) {
// Extend the XOR period from 8 to 16 when stride allows it, eliminating
// the 8-way read-path bank conflicts that k128B has on CDNA3 MI300x.
constexpr uint32_t period = (stride >= 16u) ? 16u : 8u;
return i * stride + (j ^ (i % period));
} else if constexpr (swizzle_mode == SwizzleMode::k64B) {
static_assert(stride == 4);
return i * stride + (j ^ ((i / 2) % 4));
} else {
// swizzle_mode == SwizzleMode::kLinear
return i * stride + j;
if constexpr (swizzle_mode == SwizzleMode::k64B) {
static_assert(stride == 4, "k64B swizzle requires stride == 4");
}
return i * stride + (j ^ col_swizzle_xor<stride>(i));
}

/*!
* \brief Inverse of get_permuted_offset: recover (row, col) from a physical LDS cell index.
*
* XOR is self-inverse ((x ^ mask) ^ mask == x), so the same col_swizzle_xor
* expression serves both the forward and inverse direction in all supported
* XOR-based swizzle modes.
*
Comment on lines +97 to +102
Copy link

Copilot AI Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Doxygen comment here has a formatting error: the “XOR-based swizzle modes.” line is missing the leading *, so it won’t be included correctly in generated docs.

Copilot uses AI. Check for mistakes.
* \tparam stride The BasePtrTy-unit row stride (UPCAST_STRIDE_K / UPCAST_STRIDE_Q / etc.).
* \param cell Physical cell index (0 .. smem_size-1).
* \returns b64_t{row, col} — the logical (row, column) that maps to this cell.
*/
// Inverse of get_permuted_offset: splits a physical cell index into its row and
// swizzled column, then undoes the column XOR to recover the logical column.
// Returns b64_t{row, col}.
template <uint32_t stride>
static __device__ __forceinline__ b64_t get_inverse_offset(uint32_t cell) {
  // cell / stride and cell % stride are undefined for a zero stride, so reject
  // that instantiation at compile time instead of dividing by zero.
  static_assert(stride > 0, "get_inverse_offset requires a non-zero stride");
  if constexpr (swizzle_mode == SwizzleMode::k64B) {
    // Mirror the precondition enforced by get_permuted_offset so the inverse is
    // only instantiated for strides the forward mapping accepts.
    static_assert(stride == 4, "k64B swizzle requires stride == 4");
  }
  const uint32_t row = cell / stride;
  const uint32_t col_sw = cell % stride;
  // XOR is self-inverse: re-applying the row's mask yields the logical column.
  return b64_t{row, col_sw ^ col_swizzle_xor<stride>(row)};
}

// advance_offset_by_column
Expand Down Expand Up @@ -191,6 +221,57 @@ struct smem_t {
return offset + step_size * row_stride;
}
}
};

/*!
* \brief Shared memory wrapper parameterized over a layout policy (composition).
*
* The LayoutPolicy template parameter (typically SwizzleLayout<swizzle_mode>) owns
* all coordinate arithmetic. smem_t exposes thin __forceinline__ wrappers that
* forward to LayoutPolicy, so existing call sites — smem.get_permuted_offset<stride>(i,j),
* smem->advance_offset_by_column<step,...>(...), etc. — compile without changes.
*
* The Layout type alias is the bijection handle: any caller that derives global-memory
* coordinates via smem_t::Layout is structurally tied to the same swizzle pattern
* as the LDS read/write path. It is impossible to accidentally mix modes.
*/
template <typename LayoutPolicy, typename BasePtrTy = b128_t>
struct smem_t {
// Re-export of the layout policy so callers can reach the coordinate
// arithmetic through smem_t::Layout.
using Layout = LayoutPolicy;

// The base pointer.
BasePtrTy* base;
// Default constructor: leaves the wrapper with a null base pointer.
__device__ __forceinline__ smem_t() : base(nullptr) {}
// Wraps a pointer of any pointee type T; the pointer is converted to
// BasePtrTy* with a C-style cast, so no type or qualifier checking is done here.
template <typename T>
__device__ __forceinline__ smem_t(T* base) : base((BasePtrTy*)base) {}

// ── Thin wrappers forwarding to Layout ────────────────────────────────────
// All are static __forceinline__ functions that only forward to Layout, so they are expected to inline away.

// Forwards to the layout policy: maps logical (row i, column j) to the
// physical cell offset for the given compile-time stride.
template <uint32_t stride>
static __device__ __forceinline__ uint32_t get_permuted_offset(uint32_t i, uint32_t j) {
  const uint32_t permuted = LayoutPolicy::template get_permuted_offset<stride>(i, j);
  return permuted;
}

// Forwards to the layout policy: recovers the logical coordinates packed in a
// b64_t from a physical cell index.
template <uint32_t stride>
static __device__ __forceinline__ b64_t get_inverse_offset(uint32_t cell) {
  const b64_t logical = LayoutPolicy::template get_inverse_offset<stride>(cell);
  return logical;
}

// Forwards to the layout policy's advance_offset_by_column with the same
// template arguments and runtime arguments, returning its result unchanged.
template <uint32_t step_size, uint32_t stride = 0>
static __device__ __forceinline__ uint32_t advance_offset_by_column(uint32_t offset,
                                                                    uint32_t step_idx,
                                                                    uint32_t col_idx = 0) {
  const uint32_t advanced =
      LayoutPolicy::template advance_offset_by_column<step_size, stride>(offset, step_idx, col_idx);
  return advanced;
}

// Forwards to the layout policy's advance_offset_by_row with the same
// template arguments and runtime arguments, returning its result unchanged.
template <uint32_t step_size, uint32_t row_stride>
static __device__ __forceinline__ uint32_t advance_offset_by_row(uint32_t offset,
                                                                 uint32_t row_idx = 0) {
  const uint32_t advanced =
      LayoutPolicy::template advance_offset_by_row<step_size, row_stride>(offset, row_idx);
  return advanced;
}

// ── LDS memory operations ─────────────────────────────────────────────────

template <typename T = uint32_t>
__device__ __forceinline__ void load_fragment(uint32_t offset, T* frag) {
Expand Down
Loading
Loading