rebase glm mask to main #2

Open · wants to merge 4 commits into base: main
70 changes: 61 additions & 9 deletions csrc/flash_attn/flash_api.cpp
@@ -36,7 +36,8 @@ void set_params_fprop(Flash_fwd_params &params,
void *softmax_lse_d,
float p_dropout,
float softmax_scale,
bool is_causal) {
bool is_causal,
void *glm_mask_d) {

// Reset the parameters
memset(&params, 0, sizeof(params));
@@ -103,6 +104,8 @@ void set_params_fprop(Flash_fwd_params &params,

params.is_causal = is_causal;
params.is_seqlens_k_cumulative = true;
    params.is_glm_causal = (glm_mask_d != nullptr);
params.glm_mask = static_cast<int *>(glm_mask_d);
}

void set_params_dgrad(Flash_bwd_params &params,
@@ -134,7 +137,8 @@ void set_params_dgrad(Flash_bwd_params &params,
void *dsoftmax_sum_d,
float p_dropout,
float softmax_scale,
bool is_causal) {
bool is_causal,
void *glm_mask_d) {

set_params_fprop(params,
b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded,
@@ -145,7 +149,8 @@ void set_params_dgrad(Flash_bwd_params &params,
softmax_lse_d,
p_dropout,
softmax_scale,
is_causal);
is_causal,
glm_mask_d);

// Set the pointers and strides.
params.do_ptr = dout.data_ptr();
@@ -239,7 +244,9 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
const float softmax_scale,
const bool is_causal,
const bool return_softmax,
c10::optional<at::Generator> gen_) {
c10::optional<at::Generator> gen_,
const c10::optional<at::Tensor> &glm_mask // batch_size
) {

auto dprops = at::cuda::getCurrentDeviceProperties();
// bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
@@ -325,6 +332,15 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts);
}

// glm_mask support
if (glm_mask.has_value()) {
TORCH_CHECK(is_causal, "is_causal must be true");
TORCH_CHECK(glm_mask.value().dtype() == torch::kInt32, "glm_mask must have dtype int32");
        TORCH_CHECK(glm_mask.value().is_cuda(), "glm_mask must be on a CUDA device");
TORCH_CHECK(glm_mask.value().is_contiguous(), "glm_mask must be contiguous");
CHECK_SHAPE(glm_mask.value(), batch_size);
}

Flash_fwd_params params;
set_params_fprop(params,
batch_size,
@@ -339,7 +355,8 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
softmax_lse.data_ptr(),
p_dropout,
softmax_scale,
is_causal);
is_causal,
glm_mask.has_value() ? glm_mask->data_ptr() : nullptr);

// This needs to match with run_mha_fwd_splitkv_dispatch
const int block_n = is_sm90 || is_sm8x
@@ -403,7 +420,9 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
const bool zero_tensors,
const bool is_causal,
const bool return_softmax,
c10::optional<at::Generator> gen_) {
c10::optional<at::Generator> gen_,
const c10::optional<at::Tensor> &glm_mask // batch_size
) {

auto dprops = at::cuda::getCurrentDeviceProperties();
// bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
@@ -503,6 +522,15 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
if (return_softmax) {p.zero_();}
}

// glm_mask support
if (glm_mask.has_value()) {
TORCH_CHECK(is_causal, "is_causal must be true");
TORCH_CHECK(glm_mask.value().dtype() == torch::kInt32, "glm_mask must have dtype int32");
        TORCH_CHECK(glm_mask.value().is_cuda(), "glm_mask must be on a CUDA device");
TORCH_CHECK(glm_mask.value().is_contiguous(), "glm_mask must be contiguous");
CHECK_SHAPE(glm_mask.value(), batch_size);
}

Flash_fwd_params params;
set_params_fprop(params,
batch_size,
@@ -517,7 +545,8 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
softmax_lse.data_ptr(),
p_dropout,
softmax_scale,
is_causal);
is_causal,
glm_mask.has_value() ? glm_mask->data_ptr() : nullptr);

// number of times random will be generated per thread, to offset philox counter in thc random
// state
@@ -584,6 +613,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
const float softmax_scale,
const bool is_causal,
c10::optional<at::Generator> gen_,
const c10::optional<at::Tensor> &glm_mask, // batch_size
c10::optional<at::Tensor> &rng_state) {
auto dprops = at::cuda::getCurrentDeviceProperties();
// bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
@@ -714,6 +744,15 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
dv_expanded = dv;
}

// glm_mask support
if (glm_mask.has_value()) {
TORCH_CHECK(is_causal, "is_causal must be true");
TORCH_CHECK(glm_mask.value().dtype() == torch::kInt32, "glm_mask must have dtype int32");
        TORCH_CHECK(glm_mask.value().is_cuda(), "glm_mask must be on a CUDA device");
TORCH_CHECK(glm_mask.value().is_contiguous(), "glm_mask must be contiguous");
CHECK_SHAPE(glm_mask.value(), batch_size);
}

Flash_bwd_params params;

set_params_dgrad(params,
@@ -735,7 +774,8 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
softmax_d.data_ptr(),
p_dropout,
softmax_scale,
is_causal);
is_causal,
glm_mask.has_value() ? glm_mask->data_ptr() : nullptr);

auto launch = &run_mha_bwd;
// launch(params, stream, /*configure=*/true);
@@ -792,6 +832,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
const bool zero_tensors,
const bool is_causal,
c10::optional<at::Generator> gen_,
const c10::optional<at::Tensor> &glm_mask, // batch_size
c10::optional<at::Tensor> &rng_state
) {
auto dprops = at::cuda::getCurrentDeviceProperties();
@@ -934,6 +975,16 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
softmax_d.zero_();
}


// glm_mask support
if (glm_mask.has_value()) {
TORCH_CHECK(is_causal, "is_causal must be true");
TORCH_CHECK(glm_mask.value().dtype() == torch::kInt32, "glm_mask must have dtype int32");
        TORCH_CHECK(glm_mask.value().is_cuda(), "glm_mask must be on a CUDA device");
TORCH_CHECK(glm_mask.value().is_contiguous(), "glm_mask must be contiguous");
CHECK_SHAPE(glm_mask.value(), batch_size);
}

Flash_bwd_params params;

set_params_dgrad(params,
@@ -953,7 +1004,8 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
softmax_d.data_ptr(),
p_dropout,
softmax_scale,
is_causal);
is_causal,
glm_mask.has_value() ? glm_mask->data_ptr() : nullptr);

auto launch = &run_mha_bwd;
// launch(params, stream, /*configure=*/true);
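Note (editor's sketch, not part of this diff): the checks added in mha_fwd, mha_varlen_fwd, mha_bwd, and mha_varlen_bwd imply that glm_mask is a contiguous int32 CUDA tensor of shape (batch_size,), one break point per sequence. A minimal host-side helper that would satisfy those checks might look like the following; the name make_glm_mask is an assumption for illustration only.

// Sketch only: build a per-batch glm_mask that passes the TORCH_CHECKs above.
// Each entry is the length of the bidirectional prefix of that sequence.
#include <torch/torch.h>
#include <vector>

torch::Tensor make_glm_mask(const std::vector<int32_t> &prefix_lengths) {
    return torch::tensor(prefix_lengths,
                         torch::dtype(torch::kInt32).device(torch::kCUDA))
        .contiguous();
}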
9 changes: 6 additions & 3 deletions csrc/flash_attn/src/block_info.h
@@ -19,7 +19,9 @@ struct BlockInfo {
// If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
, seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb]))
, actual_seqlen_k(seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
    // original logic: actual_seqlen_k(seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
, actual_seqlen_k(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : params.cu_seqlens_k[bidb + 1] - sum_s_k)
, break_point(params.is_glm_causal ? params.glm_mask[bidb] : 0)
{
}

@@ -35,10 +37,11 @@ struct BlockInfo {

const int sum_s_q;
const int sum_s_k;
const int actual_seqlen_q;
// We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0.
const int seqlen_k_cache;
const int actual_seqlen_k;
const uint32_t actual_seqlen_q;
const uint32_t actual_seqlen_k;
const int break_point;
};

////////////////////////////////////////////////////////////////////////////////////////////////////
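To make the new break_point field concrete, here is a toy CPU reference of the attention pattern it is meant to encode (an editor's assumption based on this diff, stated for the seqlen_q == seqlen_k case): every key column before the break point is visible to every query row, and the remaining columns follow the usual causal rule.

// Toy reference (assumption): the dense mask the blocked kernels approximate.
// Row i may attend to column j iff j <= i (causal) or j < break_point (bidirectional prefix).
#include <vector>

std::vector<std::vector<bool>> glm_reference_mask(int seqlen, int break_point) {
    std::vector<std::vector<bool>> allowed(seqlen, std::vector<bool>(seqlen, false));
    for (int i = 0; i < seqlen; ++i) {
        for (int j = 0; j < seqlen; ++j) {
            allowed[i][j] = (j <= i) || (j < break_point);
        }
    }
    return allowed;
}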
3 changes: 3 additions & 0 deletions csrc/flash_attn/src/flash.h
@@ -116,6 +116,9 @@ struct Flash_fwd_params : public Qkv_params {
bool is_seqlens_k_cumulative;

int num_splits; // For split-KV version
// glm mask
bool is_glm_causal;
int * __restrict__ glm_mask;
};

////////////////////////////////////////////////////////////////////////////////////////////////////
56 changes: 46 additions & 10 deletions csrc/flash_attn/src/flash_bwd_kernel.h
@@ -654,14 +654,50 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
tdQgdQaccum.data() = tdQgdQaccum.data() + kBlockM * params.d_rounded;

int m_block = m_block_max - 1;
int m_block_min = !Is_causal ? 0 : std::max(0, (n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k) / kBlockM);
// We're guaranteed that m_block_min <= m_block:
// We checked earlier that n_block * kBlockN < actual_seqlen_k, so in the causal case,
// n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k < actual_seqlen_q.
// So m_block_min <= (actual_seqlen_q - 1) / kBlockM.
// Recall that m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM) = (actual_seqlen_q + kBlockM - 1) / kBlockM.
// So m_block_m - 1 = (actual_seqlen_q - 1) / kBlockM.
// We conclude that m_block_min <= m_block, so we will always have at least 1 iteration of the for loop.
int m_block_min = 0;
if (Is_causal) {
m_block_min = std::max(0, (n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k) / kBlockM);
if (params.is_glm_causal) {
m_block_min = binfo.break_point > n_block * kBlockN ? 0 : m_block_min;
}
}

// We might need to exit early and write 0 to dK and dV.
// Otherwise we get wrong result for the case where we don't enter the for loop.
// And we might read OOB elements from gQ and gdO.
// TODO: what if we're not parallelizing, do we need to compute dot_do_o?
if (Is_causal && m_block < m_block_min) {
const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb)
+ n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride;
const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb)
+ n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride;
Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dk_ptr) + row_offset_dk),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.dk_row_stride, _1{}));
Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dv_ptr) + row_offset_dv),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.dv_row_stride, _1{}));
auto gmem_thr_copy_dKV = typename Kernel_traits::GmemTiledCopydKV{}.get_thread_slice(tidx);
Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK);
Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV);
Tensor tdKrdK = make_tensor<Element>(shape(tdKgdK));
Tensor tdVrdV = make_tensor<Element>(shape(tdVgdV));
clear(tdKrdK);
clear(tdVrdV);
Tensor cdKV = make_identity_tensor(make_shape(size<0>(gdK), size<1>(gdK))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
Tensor tdKVpdKV = make_tensor<bool>(make_shape(size<2>(tdKgdK)));
#pragma unroll
for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; }
// Clear_OOB_K must be false since we don't want to write zeros to gmem
flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_thr_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
);
flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_thr_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
);
return;
}

if (Double_buffer && m_block % 2 == 1) { // Double buffer for sQ
tQsQ.data() = tQsQ.data() + size(sQ);
Expand Down Expand Up @@ -792,7 +828,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)),
binfo.actual_seqlen_q,
// binfo.actual_seqlen_k, m_block * kBlockM + (tidx / 32) % AtomLayoutMS * 16 + (tidx % 32) / 4,
AtomLayoutMS * 16);
AtomLayoutMS * 16, binfo.break_point);
}
}
// if (cute::thread(32, 0)) { print(scores); }
@@ -1338,7 +1374,7 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)),
// binfo.actual_seqlen_k, m_block * kBlockM + (tidx / 32) % AtomLayoutMS * 16 + (tidx % 32) / 4,
binfo.actual_seqlen_q,
AtomLayoutMS * 16);
AtomLayoutMS * 16, binfo.break_point);
}
// Compute the exponential value.
flash::scale_apply_exp2</*scale_max=*/false>(scores, lse, params.scale_softmax_log2);
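The m_block_min change above can be read as follows: for a KV block starting at column n_block * kBlockN, the plain causal bound skips query blocks that cannot see it, but if that KV block begins inside the bidirectional prefix, every query block can see at least part of it, so the bound collapses to 0. A scalar sketch of that rule (an assumption mirroring the hunk, not code copied from the PR):

// Sketch (assumption): lower bound on the query-block index for one KV block.
#include <algorithm>

int glm_m_block_min(int n_block, int kBlockN, int kBlockM,
                    int actual_seqlen_q, int actual_seqlen_k,
                    int break_point, bool is_causal, bool is_glm_causal) {
    if (!is_causal) { return 0; }
    int m_block_min = std::max(0, (n_block * kBlockN + actual_seqlen_q - actual_seqlen_k) / kBlockM);
    if (is_glm_causal && break_point > n_block * kBlockN) {
        // The KV block starts inside the bidirectional prefix, so no query block may be skipped.
        m_block_min = 0;
    }
    return m_block_min;
}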
8 changes: 5 additions & 3 deletions csrc/flash_attn/src/flash_fwd_kernel.h
@@ -142,8 +142,10 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi

int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN);
if (Is_causal) {
n_block_max = std::min(n_block_max,
cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q, kBlockN));
n_block_max = std::min(n_block_max, cute::ceil_div((m_block + 1) * kBlockM, kBlockN));
if (params.is_glm_causal) {
n_block_max = std::max(n_block_max, cute::ceil_div(binfo.break_point, kBlockN));
}
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
// printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max);
// }
@@ -426,7 +428,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// m_block * kBlockM + get<0>(idx_row(0)),
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
binfo.actual_seqlen_q,
kNWarps * 16);
kNWarps * 16, binfo.break_point);
// m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16);
// m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16);
}
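The forward-pass counterpart caps how many KV blocks each query block visits: the causal limit now comes from the query block alone, and with a GLM mask it is widened so that every block overlapping the prefix (columns below break_point) is still processed. A scalar sketch of that bound (an assumption mirroring the hunk above):

// Sketch (assumption): upper bound on the KV-block index for one query block.
#include <algorithm>

int glm_n_block_max(int m_block, int kBlockM, int kBlockN,
                    int actual_seqlen_k, int break_point,
                    bool is_causal, bool is_glm_causal) {
    auto ceil_div = [](int a, int b) { return (a + b - 1) / b; };
    int n_block_max = ceil_div(actual_seqlen_k, kBlockN);
    if (is_causal) {
        n_block_max = std::min(n_block_max, ceil_div((m_block + 1) * kBlockM, kBlockN));
        if (is_glm_causal) {
            n_block_max = std::max(n_block_max, ceil_div(break_point, kBlockN));
        }
    }
    return n_block_max;
}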
10 changes: 5 additions & 5 deletions csrc/flash_attn/src/softmax.h
@@ -141,9 +141,9 @@ inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_
}

template <typename Engine, typename Layout>
inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
const int max_seqlen_k, const int row_idx_offset_,
const int max_seqlen_q, const int warp_row_stride) {
inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const uint32_t col_idx_offset_,
const uint32_t max_seqlen_k, const uint32_t row_idx_offset_,
const uint32_t warp_row_stride, const uint32_t break_point) {
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
static_assert(Layout::rank == 2, "Only support 2D Tensor");
const int lane_id = threadIdx.x % 32;
@@ -162,8 +162,8 @@ inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const i
const int col_idx_base = col_idx_offset + nj * 8;
#pragma unroll
for (int j = 0; j < size<1, 0>(tensor); ++j) {
const int col_idx = col_idx_base + j;
if (col_idx >= col_idx_limit) {
const uint32_t col_idx = col_idx_base + j;
if (col_idx >= col_idx_limit && col_idx >= break_point) {
tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
}
}
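The element-wise rule enforced by the updated apply_mask_causal is that a score survives if it is inside the causal window or inside the bidirectional prefix; only scores failing both tests are set to -infinity. A CPU sketch of that rule for the seqlen_q == seqlen_k case (an editor's assumption for illustration; the kernel operates on register fragments, not a dense matrix):

// CPU sketch (assumption) of the element-wise GLM-causal masking rule.
#include <limits>
#include <vector>

void apply_glm_causal_mask_ref(std::vector<float> &scores, int seqlen, int break_point) {
    const float neg_inf = -std::numeric_limits<float>::infinity();
    for (int i = 0; i < seqlen; ++i) {        // query row
        for (int j = 0; j < seqlen; ++j) {    // key column
            const bool past_causal_limit = j > i;          // col_idx >= col_idx_limit
            const bool outside_prefix = j >= break_point;  // col_idx >= break_point
            if (past_causal_limit && outside_prefix) {
                scores[i * seqlen + j] = neg_inf;
            }
        }
    }
}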