perf(store): Cache output stride parameters in registers to reduce global loads #6

Open: wants to merge 5 commits into base: main
8 changes: 7 additions & 1 deletion csrc/flash_fwd_mla_kernel.h
@@ -154,6 +154,11 @@ __forceinline__ __device__ void store(const Flash_fwd_mla_params &params, const
 
     const int tidx = threadIdx.x;
 
+    // Cache frequently used parameters into registers for optimization
+    const index_t o_batch_stride = __ldg(&params.o_batch_stride);
+    const index_t o_row_stride = __ldg(&params.o_row_stride);
+    const index_t o_head_stride = __ldg(&params.o_head_stride);
+
     typename Kernel_traits::TiledMmaO tiled_mma_o;
     auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx);
 
@@ -179,9 +184,10 @@ __forceinline__ __device__ void store(const Flash_fwd_mla_params &params, const
 
     __syncthreads();
 
+    // Stage accumulator fragment to shared memory using tiled copy
     cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum);
 
-    const index_t row_offset_o = bidb * params.o_batch_stride + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
+    const index_t row_offset_o = bidb * o_batch_stride + m_block * kBlockM * o_row_stride + bidh * o_head_stride;
     const index_t row_offset_oaccum = (((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_v;
     const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
     const index_t row_offset_lseaccum = ((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
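
For reviewers who want to sanity-check the `row_offset_o` arithmetic, here is a minimal host-side C++ sketch. It only illustrates that copying the three strides into locals leaves the computed offset unchanged; it does not model `__ldg`, register allocation, or any performance effect. The `OutputStrides` struct, both function names, the 64-bit `index_t` alias, and the example values are hypothetical and not taken from the repository.

```cpp
#include <cassert>
#include <cstdint>

// Assumption: index_t is a 64-bit signed integer, standing in for the kernel's index_t.
using index_t = std::int64_t;

// Hypothetical stand-in for the three stride fields read from Flash_fwd_mla_params.
struct OutputStrides {
    index_t o_batch_stride;
    index_t o_row_stride;
    index_t o_head_stride;
};

// Mirrors the removed line: each stride is read through the struct at the point of use.
inline index_t row_offset_o_direct(const OutputStrides &p, int bidb, int bidh,
                                   int m_block, int kBlockM) {
    return bidb * p.o_batch_stride + m_block * kBlockM * p.o_row_stride +
           bidh * p.o_head_stride;
}

// Mirrors the added lines: strides are copied into locals once, then reused.
inline index_t row_offset_o_hoisted(const OutputStrides &p, int bidb, int bidh,
                                    int m_block, int kBlockM) {
    const index_t o_batch_stride = p.o_batch_stride;
    const index_t o_row_stride = p.o_row_stride;
    const index_t o_head_stride = p.o_head_stride;
    return bidb * o_batch_stride + m_block * kBlockM * o_row_stride +
           bidh * o_head_stride;
}

// Checks that both forms agree for one set of block indices.
inline void check_same_offset(const OutputStrides &p, int bidb, int bidh,
                              int m_block, int kBlockM) {
    assert(row_offset_o_direct(p, bidb, bidh, m_block, kBlockM) ==
           row_offset_o_hoisted(p, bidb, bidh, m_block, kBlockM));
}
```

As a usage example under an assumed contiguous (batch, seqlen_q, heads, d_v) layout with seqlen_q = 128, 4 heads, and d_v = 512, the strides would be `{262144, 2048, 512}`. With `bidb = 1`, `bidh = 3`, `m_block = 1`, and `kBlockM = 64`, both functions return 262144 + 131072 + 1536 = 394752.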