18 changes: 14 additions & 4 deletions python/tvm/relax/frontend/nn/llm/kv_cache.py
@@ -860,10 +860,15 @@ def _var(dtype):
return T.alloc_buffer((1,), dtype, scope="local")


def _causal_mask(causal, row, col, kv_len, qo_len):
def _causal_mask(causal, row, col, kv_len, qo_len, sliding_window_size=-1):
lower_bound_condition = T.if_then_else(
sliding_window_size > 0,
col >= kv_len - qo_len + row - sliding_window_size + 1,
True,
)
return T.if_then_else(
causal > 0,
col < kv_len - qo_len + row + 1,
tir.all(col < kv_len - qo_len + row + 1, lower_bound_condition),
col < kv_len,
)

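For intuition, here is a minimal plain-Python sketch of the predicate the updated _causal_mask encodes (illustrative only, not part of the patch; the helper name and toy sizes below are made up): with causal > 0 and sliding_window_size > 0, a query row may attend only to keys no newer than its own position and no older than sliding_window_size positions before it.

def causal_mask_sketch(causal, row, col, kv_len, qo_len, sliding_window_size=-1):
    # Eager restatement of the T.if_then_else / tir.all predicate above.
    if causal > 0:
        upper = col < kv_len - qo_len + row + 1  # usual causal upper bound
        lower = sliding_window_size <= 0 or col >= kv_len - qo_len + row - sliding_window_size + 1
        return upper and lower
    return col < kv_len

# Toy example: kv_len=6, qo_len=2, sliding_window_size=3
# row 0 attends to keys [2, 3, 4]; row 1 attends to keys [3, 4, 5].
for row in range(2):
    print(row, [c for c in range(6) if causal_mask_sketch(1, row, c, 6, 2, 3)])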
@@ -2300,6 +2305,7 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches
var_k_rope_pos_offset: T.handle, # [b]
var_output: T.handle, # [total_len, h_q, d_v]
var_lse: T.handle, # [total_len, h_q]
sliding_window_size: T.int32,
causal: T.int32,
rotary_mode: T.int32,
rope_scale: T.float32,
@@ -2364,6 +2370,7 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches
col=k_idx,
kv_len=kv_indptr[b + 1] - kv_indptr[b],
qo_len=q_indptr[b + 1] - q_indptr[b],
sliding_window_size=sliding_window_size,
):
result[0] = 0.0
for d_idx in T.serial(d_qk):
@@ -2442,6 +2449,7 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches
var_k_rope_pos_offset: T.handle, # [b]
var_output: T.handle, # [total_len, h_q, d_v]
var_lse: T.handle, # [total_len, h_q]
sliding_window_size: T.int32,
causal: T.int32,
rotary_mode: T.int32,
rope_scale: T.float32,
@@ -2609,7 +2617,8 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches
row=row_,
col=L_kv_start + j,
kv_len=kv_chunk_len[0],
qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx]):
qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx],
sliding_window_size=sliding_window_size):
m_new[i] = T.max(m_new[i], S_smem[row, j])
d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i])

@@ -2624,7 +2633,8 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches
row=row_,
col=L_kv_start + j,
kv_len=kv_chunk_len[0],
qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx]):
qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx],
sliding_window_size=sliding_window_size):
S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i])
else:
S_smem[row, j] = T.exp2(-5e4 - m_new[i])
8 changes: 4 additions & 4 deletions src/runtime/vm/attn_backend.h
@@ -286,7 +286,7 @@ class RaggedPrefillFunc : public AttnBackendFunc {
: AttnBackendFunc(std::move(attn_func), attn_kind, backend_kind) {}

virtual void MHA(Tensor q, Tensor k, Tensor v, Tensor qo_indptr, Tensor kv_indptr,
Tensor q_rope_position, Tensor k_rope_pos_offset, bool causal,
Tensor q_rope_position, Tensor k_rope_pos_offset, int sliding_window_size, bool causal,
RoPEMode rope_mode, double rotary_scale, double rotary_theta, double sm_scale,
Tensor attn_output, Tensor attn_lse, TVMStreamHandle compute_stream) {
LOG(FATAL) << "MHA computation is not supported by the current backend";
@@ -308,11 +308,11 @@ class TIRRaggedPrefillFunc : public RaggedPrefillFunc {
: RaggedPrefillFunc(std::move(attn_func), attn_kind, AttnBackendKind::kTIR) {}

void MHA(Tensor q, Tensor k, Tensor v, Tensor qo_indptr, Tensor kv_indptr, Tensor q_rope_position,
Tensor k_rope_pos_offset, bool causal, RoPEMode rope_mode, double rotary_scale,
Tensor k_rope_pos_offset, int sliding_window_size, bool causal, RoPEMode rope_mode, double rotary_scale,
double rotary_theta, double sm_scale, Tensor attn_output, Tensor attn_lse,
TVMStreamHandle compute_stream) final {
attn_func_(q, qo_indptr, k, v, kv_indptr, q_rope_position, k_rope_pos_offset, attn_output,
attn_lse, static_cast<int64_t>(causal),
attn_lse, sliding_window_size, static_cast<int64_t>(causal),
/*rotary_mode=*/static_cast<int64_t>(rope_mode == RoPEMode::kInline), rotary_scale,
rotary_theta, sm_scale);
}
@@ -330,7 +330,7 @@ class FlashInferRaggedPrefillFunc : public RaggedPrefillFunc {
plan_func_(std::move(plan_func)) {}

void MHA(Tensor q, Tensor k, Tensor v, Tensor qo_indptr, Tensor kv_indptr, Tensor q_rope_position,
Tensor k_rope_pos_offset, bool causal, RoPEMode rope_mode, double rotary_scale,
Tensor k_rope_pos_offset, int sliding_window_size, bool causal, RoPEMode rope_mode, double rotary_scale,
double rotary_theta, double sm_scale, Tensor attn_output, Tensor attn_lse,
TVMStreamHandle compute_stream) final {
Device device = q->device;
83 changes: 46 additions & 37 deletions src/runtime/vm/paged_kv_cache.cc
@@ -101,6 +101,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
const bool support_sliding_window_;
/*! \brief A boolean flag indicating if the KV cache has per layer sliding window. */
const bool support_layer_sliding_window_;
/*! \brief The sliding window size for sliding window attention. */
int32_t sliding_window_size_;
/*! \brief The attention kinds for each layer. */
const std::vector<AttnKind> attn_kinds_;

@@ -314,6 +316,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
: support_sliding_window),
support_layer_sliding_window_(std::find(attn_kinds.begin(), attn_kinds.end(),
AttnKind::kMHASliding) != attn_kinds.end()),
sliding_window_size_(-1),
attn_kinds_(std::move(attn_kinds)),
rope_mode_(support_sliding_window && rope_mode != RoPEMode::kNone ? RoPEMode::kInline
: rope_mode),
@@ -766,6 +769,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
// introduce more sink. Therefore, we update the given attn sink size.
it->second.last_block_attn_sink_size = std::max(attn_sink_size - prefix_length, 0);
it->second.sliding_window_size = sliding_window_size;
if (sliding_window_size_ == -1)
sliding_window_size_ = sliding_window_size;
Review comment on lines +772 to +773 (Contributor, severity: medium):

This logic to set sliding_window_size_ is a bit fragile. If EnableSlidingWindowForSeq is called for different sequences with different sliding_window_size values, only the first one will take effect, and subsequent calls with different sizes will be silently ignored. This could lead to unexpected behavior. To make this assumption explicit and prevent silent errors, I'd suggest adding a check to ensure that all calls use a consistent sliding_window_size.

Suggested change:

    // Current:
    if (sliding_window_size_ == -1)
      sliding_window_size_ = sliding_window_size;

    // Suggested:
    if (sliding_window_size_ == -1) {
      sliding_window_size_ = sliding_window_size;
    } else {
      ICHECK_EQ(sliding_window_size_, sliding_window_size)
          << "Inconsistent sliding window sizes are not supported. Previously got "
          << sliding_window_size_ << ", but now got " << sliding_window_size;
    }

}

void PopN(int64_t seq_id, int32_t n) final {
@@ -1009,14 +1014,16 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
// Do the same for page_indices_sliding_window
}

// For sliding window, the first page and last page will both be partially used
page_indptr_sliding_window_h.push_back(
page_indptr_sliding_window_h.back() +
std::min(static_cast<int32_t>(block.page_ids.size()),
static_cast<int32_t>(1024 / page_size_ +
(block.seq_length % page_size_ ? 1 : 0))));
if (support_layer_sliding_window_ && sliding_window_size_ > 0) {
// For sliding window, the first page and last page will both be partially used
page_indptr_sliding_window_h.push_back(
page_indptr_sliding_window_h.back() +
std::min(static_cast<int32_t>(block.page_ids.size()),
static_cast<int32_t>(sliding_window_size_ / page_size_ +
(block.seq_length % page_size_ ? 1 : 0))));
}
for (int i = page_indices_h.size() - page_indptr_sliding_window_h.back();
i < static_cast<int32_t>(page_indices_h.size()); i++) {
i < static_cast<int32_t>(page_indices_h.size()); i++) {
page_indices_sliding_window_h.push_back(page_indices_h[i]);
}
// set up the page indices properly by choosing the last (sliding_window_size /
@@ -1027,8 +1034,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
: (block.seq_length - block.sink_length + block.sliding_window_offset - 1) %
page_size_ +
1);
if (support_layer_sliding_window_) {
if (block.seq_length < 1024) {
if (support_layer_sliding_window_ && sliding_window_size_ > 0) {
if (block.seq_length < sliding_window_size_) {
sliding_window_offset_h.push_back(0);
} else {
sliding_window_offset_h.push_back(block.seq_length % page_size_);
Expand All @@ -1040,9 +1047,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
k_rope_pos_offset_h.push_back(block.start_pos);

// If sliding window, we need to calculate the positional offset
if (support_layer_sliding_window_) {
if (support_layer_sliding_window_ && sliding_window_size_ > 0) {
k_rope_pos_offset_sliding_window_h.push_back(
std::max(0, block.start_pos + block.seq_length - 1024));
std::max(0, block.start_pos + block.seq_length - sliding_window_size_));
}
} else {
// Blocks at maximum depth
@@ -1064,11 +1071,13 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
last_block_id = id;
}
page_indptr_h.push_back(page_indptr_h.back() + num_pages);
page_indptr_sliding_window_h.push_back(
page_indptr_sliding_window_h.back() +
std::min(static_cast<int32_t>(block.page_ids.size()),
static_cast<int32_t>(1024 / page_size_ +
(block.seq_length % page_size_ ? 1 : 0))));
if (support_layer_sliding_window_ && sliding_window_size_ > 0) {
page_indptr_sliding_window_h.push_back(
page_indptr_sliding_window_h.back() +
std::min(static_cast<int32_t>(block.page_ids.size()),
static_cast<int32_t>(sliding_window_size_ / page_size_ +
(block.seq_length % page_size_ ? 1 : 0))));
}
for (int i = page_indices_h.size() - page_indptr_sliding_window_h.back();
i < static_cast<int32_t>(page_indices_h.size()); i++) {
page_indices_sliding_window_h.push_back(page_indices_h[i]);
@@ -1080,8 +1089,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
last_block.sliding_window_offset - 1) %
page_size_ +
1);
if (support_layer_sliding_window_) {
if (last_block.seq_length < 1024) {
if (support_layer_sliding_window_ && sliding_window_size_ > 0) {
if (last_block.seq_length < sliding_window_size_) {
sliding_window_offset_h.push_back(0);
} else {
sliding_window_offset_h.push_back(last_block.seq_length % page_size_);
Expand All @@ -1091,9 +1100,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
}
sink_size_h.push_back(last_block.sink_length);
k_rope_pos_offset_h.push_back(block.start_pos);
if (support_layer_sliding_window_) {
if (support_layer_sliding_window_ && sliding_window_size_ > 0) {
k_rope_pos_offset_sliding_window_h.push_back(
std::max(0, block.start_pos + block.seq_length - 1024));
std::max(0, block.start_pos + block.seq_length - sliding_window_size_));
}
}
}
@@ -1408,8 +1417,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
// The auxiliary data structure on device must have been synchronized.
ICHECK(!dirty_aux_data_device_);

if (attn_kind == AttnKind::kMHA) {
MHASelfAttnInternal(q_data, k_data, v_data, o_data, lse_data, sm_scale);
if (attn_kind == AttnKind::kMHA || attn_kind == AttnKind::kMHASliding) {
MHASelfAttnInternal(local_layer_id, q_data, k_data, v_data, o_data, lse_data, sm_scale);
} else {
MLASelfAttnInternal(q_data, k_data, v_data, o_data, lse_data, sm_scale);
}
@@ -2089,7 +2098,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
if (!append_before_attn_) {
// The first part of attention, which only involves the q and the newly appended k/v.
is_first_kernel = false;
MHASelfAttnInternal(q_data, k_data, v_data, output, merged_attn_lse_view_, sm_scale);
MHASelfAttnInternal(local_layer_id, q_data, k_data, v_data, output, merged_attn_lse_view_, sm_scale);
}
bool self_attn_computed = !is_first_kernel;
bool cross_attn_computed = MHACrossAttnInternal(
@@ -2098,15 +2107,21 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
<< "Both self-attention and cross-attention are not computed.";
}

void MHASelfAttnInternal(Tensor q_data, Tensor k_data, Tensor v_data, Tensor o_data,
void MHASelfAttnInternal(int64_t local_layer_id, Tensor q_data, Tensor k_data, Tensor v_data, Tensor o_data,
Tensor lse_data, double sm_scale) {
if (is_chain_on_depths_[0]) {
// If the batch does not form a tree, use raggedness prefill kernel.

int sliding_window_size =
(!support_sliding_window_ &&
attn_kinds_[local_layer_id + layer_id_begin_offset_] != AttnKind::kMHASliding)
? -1
: sliding_window_size_;
ICHECK_NOTNULL(f_attention_prefill_ragged_);
f_attention_prefill_ragged_->MHA(
q_data, k_data, v_data, cur_append_length_indptr_view_, cur_append_length_indptr_view_,
q_rope_position_map_view_, k_ragged_rope_pos_offset_view_, /*causal=*/true, rope_mode_,
rotary_scale_, rotary_theta_, sm_scale, o_data, lse_data, compute_stream_);
q_rope_position_map_view_, k_ragged_rope_pos_offset_view_, sliding_window_size, /*causal=*/true,
rope_mode_, rotary_scale_, rotary_theta_, sm_scale, o_data, lse_data, compute_stream_);
} else {
// The batch requires tree attention.
ICHECK(f_attention_prefill_with_tree_mask_ != nullptr)
@@ -2127,8 +2142,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
ICHECK_NOTNULL(f_attention_prefill_ragged_);
f_attention_prefill_ragged_->MHA(
q_data, k_data, v_data, cur_append_length_indptr_view_, cur_append_length_indptr_view_,
q_rope_position_map_view_, k_ragged_rope_pos_offset_view_, /*causal=*/true, RoPEMode::kNone,
rotary_scale_, rotary_theta_, sm_scale, o_data, lse_data, compute_stream_);
q_rope_position_map_view_, k_ragged_rope_pos_offset_view_, -1, /*causal=*/true,
RoPEMode::kNone, rotary_scale_, rotary_theta_, sm_scale, o_data, lse_data, compute_stream_);
}

/*! \brief Compute cross-attention for MHA. Return if there is effective computation. */
@@ -2165,45 +2180,39 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
Tensor page_indices;
Tensor length_info;
Tensor k_rope_pos;
double rotary_theta;
double rotary_scale;

if (attn_kinds_[local_layer_id + layer_id_begin_offset_] == AttnKind::kMHASliding) {
page_indptr = page_indptr_sliding_window_on_depths_view_[d];
page_indices = page_indices_sliding_window_on_depths_view_[d];
length_info = layer_sliding_window_length_info_on_depths_view_[d];
k_rope_pos = k_rope_pos_offset_sliding_window_view_[d];
rotary_theta = 10000;
rotary_scale = 1;
} else {
page_indptr = page_indptr_on_depths_view_[d];
page_indices = page_indices_on_depths_view_[d];
length_info = length_info_on_depths_view_[d];
k_rope_pos = k_rope_pos_offset_view_[d];
rotary_theta = rotary_theta_;
rotary_scale = rotary_scale_;
}

if (append_before_attn_ && !is_chain_on_depths_[d]) {
ICHECK_NOTNULL(f_attention_prefill_with_tree_mask_paged_kv_);
f_attention_prefill_with_tree_mask_paged_kv_->MHA(
q_data, qo_indptr_on_depths_view_[d], pages_[local_layer_id], page_indptr, page_indices,
length_info, k_rope_pos, q_rope_position_map_view_, tree_attn_mn_indptr_view_[d],
tree_attn_mask_view_[d], rope_mode_, rotary_scale, rotary_theta, sm_scale, attn_output,
tree_attn_mask_view_[d], rope_mode_, rotary_scale_, rotary_theta_, sm_scale, attn_output,
attn_lse, compute_stream_);
} else if (use_decode_kernel_[d]) {
// Use decode kernel for depth d
ICHECK_NOTNULL(f_decode);
f_decode->MHA(d, q_data, pages_[local_layer_id], page_indptr, page_indices, length_info,
k_rope_pos, q_rope_position_map_view_, rope_mode_, rotary_scale, rotary_theta,
k_rope_pos, q_rope_position_map_view_, rope_mode_, rotary_scale_, rotary_theta_,
sm_scale, attn_output, attn_lse, compute_stream_);
} else {
// Use prefill kernel for depth d
ICHECK_NOTNULL(f_prefill);
f_prefill->MHA(d, q_data, qo_indptr_on_depths_view_[d], pages_[local_layer_id], page_indptr,
page_indices, length_info, q_rope_position_map_view_, k_rope_pos,
/*causal=*/false,
/*rotary_mode=*/rope_mode_, rotary_scale, rotary_theta, sm_scale,
/*rotary_mode=*/rope_mode_, rotary_scale_, rotary_theta_, sm_scale,
attn_output, attn_lse, compute_stream_);
}
