diff --git a/python/tvm/relax/frontend/nn/llm/kv_cache.py b/python/tvm/relax/frontend/nn/llm/kv_cache.py
index e94d5c42957b..8e1dd1e5b443 100644
--- a/python/tvm/relax/frontend/nn/llm/kv_cache.py
+++ b/python/tvm/relax/frontend/nn/llm/kv_cache.py
@@ -860,10 +860,15 @@ def _var(dtype):
     return T.alloc_buffer((1,), dtype, scope="local")
 
 
-def _causal_mask(causal, row, col, kv_len, qo_len):
+def _causal_mask(causal, row, col, kv_len, qo_len, sliding_window_size=-1):
+    lower_bound_condition = T.if_then_else(
+        sliding_window_size > 0,
+        col >= kv_len - qo_len + row - sliding_window_size + 1,
+        True,
+    )
     return T.if_then_else(
         causal > 0,
-        col < kv_len - qo_len + row + 1,
+        tir.all(col < kv_len - qo_len + row + 1, lower_bound_condition),
         col < kv_len,
     )
 
@@ -2300,6 +2305,7 @@ def batch_prefill_ragged_kv(  # pylint: disable=too-many-branches
         var_k_rope_pos_offset: T.handle,  # [b]
         var_output: T.handle,  # [total_len, h_q, d_v]
         var_lse: T.handle,  # [total_len, h_q]
+        sliding_window_size: T.int32,
         causal: T.int32,
         rotary_mode: T.int32,
         rope_scale: T.float32,
@@ -2364,6 +2370,7 @@ def batch_prefill_ragged_kv(  # pylint: disable=too-many-branches
                                 col=k_idx,
                                 kv_len=kv_indptr[b + 1] - kv_indptr[b],
                                 qo_len=q_indptr[b + 1] - q_indptr[b],
+                                sliding_window_size=sliding_window_size,
                             ):
                                 result[0] = 0.0
                                 for d_idx in T.serial(d_qk):
@@ -2442,6 +2449,7 @@ def batch_prefill_ragged_kv(  # pylint: disable=too-many-branches
         var_k_rope_pos_offset: T.handle,  # [b]
         var_output: T.handle,  # [total_len, h_q, d_v]
         var_lse: T.handle,  # [total_len, h_q]
+        sliding_window_size: T.int32,
        causal: T.int32,
        rotary_mode: T.int32,
        rope_scale: T.float32,
@@ -2609,7 +2617,8 @@ def batch_prefill_ragged_kv(  # pylint: disable=too-many-branches
                                         row=row_,
                                         col=L_kv_start + j,
                                         kv_len=kv_chunk_len[0],
-                                        qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx]):
+                                        qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx],
+                                        sliding_window_size=sliding_window_size):
                                         m_new[i] = T.max(m_new[i], S_smem[row, j])
                                 d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i])
@@ -2624,7 +2633,8 @@ def batch_prefill_ragged_kv(  # pylint: disable=too-many-branches
                                         row=row_,
                                         col=L_kv_start + j,
                                         kv_len=kv_chunk_len[0],
-                                        qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx]):
+                                        qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx],
+                                        sliding_window_size=sliding_window_size):
                                         S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i])
                                     else:
                                         S_smem[row, j] = T.exp2(-5e4 - m_new[i])
diff --git a/src/runtime/vm/attn_backend.h b/src/runtime/vm/attn_backend.h
index 31f1ce9f4ad2..600ac28444c9 100644
--- a/src/runtime/vm/attn_backend.h
+++ b/src/runtime/vm/attn_backend.h
@@ -286,7 +286,7 @@ class RaggedPrefillFunc : public AttnBackendFunc {
       : AttnBackendFunc(std::move(attn_func), attn_kind, backend_kind) {}
 
   virtual void MHA(Tensor q, Tensor k, Tensor v, Tensor qo_indptr, Tensor kv_indptr,
-                   Tensor q_rope_position, Tensor k_rope_pos_offset, bool causal,
+                   Tensor q_rope_position, Tensor k_rope_pos_offset, int sliding_window_size, bool causal,
                    RoPEMode rope_mode, double rotary_scale, double rotary_theta, double sm_scale,
                    Tensor attn_output, Tensor attn_lse, TVMStreamHandle compute_stream) {
     LOG(FATAL) << "MHA computation is not supported by the current backend";
@@ -308,11 +308,11 @@ class TIRRaggedPrefillFunc : public RaggedPrefillFunc {
       : RaggedPrefillFunc(std::move(attn_func), attn_kind, AttnBackendKind::kTIR) {}
 
   void MHA(Tensor q, Tensor k, Tensor v, Tensor qo_indptr, Tensor kv_indptr, Tensor q_rope_position,
-           Tensor k_rope_pos_offset, bool causal, RoPEMode rope_mode, double rotary_scale,
+           Tensor k_rope_pos_offset, int sliding_window_size, bool causal, RoPEMode rope_mode, double rotary_scale,
            double rotary_theta, double sm_scale, Tensor attn_output, Tensor attn_lse,
            TVMStreamHandle compute_stream) final {
     attn_func_(q, qo_indptr, k, v, kv_indptr, q_rope_position, k_rope_pos_offset, attn_output,
-               attn_lse, static_cast<int>(causal),
+               attn_lse, sliding_window_size, static_cast<int>(causal),
               /*rotary_mode=*/static_cast<int>(rope_mode == RoPEMode::kInline), rotary_scale,
               rotary_theta, sm_scale);
   }
@@ -330,7 +330,7 @@ class FlashInferRaggedPrefillFunc : public RaggedPrefillFunc {
         plan_func_(std::move(plan_func)) {}
 
   void MHA(Tensor q, Tensor k, Tensor v, Tensor qo_indptr, Tensor kv_indptr, Tensor q_rope_position,
-           Tensor k_rope_pos_offset, bool causal, RoPEMode rope_mode, double rotary_scale,
+           Tensor k_rope_pos_offset, int sliding_window_size, bool causal, RoPEMode rope_mode, double rotary_scale,
            double rotary_theta, double sm_scale, Tensor attn_output, Tensor attn_lse,
            TVMStreamHandle compute_stream) final {
     Device device = q->device;
diff --git a/src/runtime/vm/paged_kv_cache.cc b/src/runtime/vm/paged_kv_cache.cc
index 4fb3cd69d60f..f57dab2d0f87 100644
--- a/src/runtime/vm/paged_kv_cache.cc
+++ b/src/runtime/vm/paged_kv_cache.cc
@@ -101,6 +101,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
   const bool support_sliding_window_;
   /*! \brief A boolean flag indicating if the KV cache has per layer sliding window. */
   const bool support_layer_sliding_window_;
+  /*! \brief The sliding window size for sliding window attention. */
+  int32_t sliding_window_size_;
   /*! \brief The attention kinds for each layer. */
   const std::vector<AttnKind> attn_kinds_;
@@ -314,6 +316,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
                                             : support_sliding_window),
        support_layer_sliding_window_(std::find(attn_kinds.begin(), attn_kinds.end(),
                                                AttnKind::kMHASliding) != attn_kinds.end()),
+       sliding_window_size_(-1),
        attn_kinds_(std::move(attn_kinds)),
        rope_mode_(support_sliding_window && rope_mode != RoPEMode::kNone ? RoPEMode::kInline
                                                                          : rope_mode),
@@ -766,6 +769,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
     // introduce more sink. Therefore, we update the given attn sink size.
     it->second.last_block_attn_sink_size = std::max(attn_sink_size - prefix_length, 0);
     it->second.sliding_window_size = sliding_window_size;
+    if (sliding_window_size_ == -1)
+      sliding_window_size_ = sliding_window_size;
   }
 
   void PopN(int64_t seq_id, int32_t n) final {
@@ -1009,14 +1014,16 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
           // Do the same for page_indices_sliding_window
         }
-        // For sliding window, the first page and last page will both be partially used
-        page_indptr_sliding_window_h.push_back(
-            page_indptr_sliding_window_h.back() +
-            std::min(static_cast<int>(block.page_ids.size()),
-                     static_cast<int>(1024 / page_size_ +
-                                      (block.seq_length % page_size_ ? 1 : 0))));
+        if (support_layer_sliding_window_ && sliding_window_size_ > 0) {
+          // For sliding window, the first page and last page will both be partially used
+          page_indptr_sliding_window_h.push_back(
+              page_indptr_sliding_window_h.back() +
+              std::min(static_cast<int>(block.page_ids.size()),
+                       static_cast<int>(sliding_window_size_ / page_size_ +
+                                        (block.seq_length % page_size_ ? 1 : 0))));
+        }
         for (int i = page_indices_h.size() - page_indptr_sliding_window_h.back();
-             i < static_cast<int>(page_indices_h.size()); i++) {
+             i < static_cast<int>(page_indices_h.size()); i++) {
           page_indices_sliding_window_h.push_back(page_indices_h[i]);
         }
         // set up the page indices properly by choosing the last (sliding_window_size /
@@ -1027,8 +1034,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
                 : (block.seq_length - block.sink_length + block.sliding_window_offset - 1) %
                           page_size_ +
                       1);
-        if (support_layer_sliding_window_) {
-          if (block.seq_length < 1024) {
+        if (support_layer_sliding_window_ && sliding_window_size_ > 0) {
+          if (block.seq_length < sliding_window_size_) {
            sliding_window_offset_h.push_back(0);
          } else {
            sliding_window_offset_h.push_back(block.seq_length % page_size_);
@@ -1040,9 +1047,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
         k_rope_pos_offset_h.push_back(block.start_pos);
 
         // If sliding window, we need to calculate the positional offset
-        if (support_layer_sliding_window_) {
+        if (support_layer_sliding_window_ && sliding_window_size_ > 0) {
           k_rope_pos_offset_sliding_window_h.push_back(
-              std::max(0, block.start_pos + block.seq_length - 1024));
+              std::max(0, block.start_pos + block.seq_length - sliding_window_size_));
         }
       } else {
         // Blocks at maximum depth
@@ -1064,11 +1071,13 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
           last_block_id = id;
         }
         page_indptr_h.push_back(page_indptr_h.back() + num_pages);
-        page_indptr_sliding_window_h.push_back(
-            page_indptr_sliding_window_h.back() +
-            std::min(static_cast<int>(block.page_ids.size()),
-                     static_cast<int>(1024 / page_size_ +
-                                      (block.seq_length % page_size_ ? 1 : 0))));
+        if (support_layer_sliding_window_ && sliding_window_size_ > 0) {
+          page_indptr_sliding_window_h.push_back(
+              page_indptr_sliding_window_h.back() +
+              std::min(static_cast<int>(block.page_ids.size()),
+                       static_cast<int>(sliding_window_size_ / page_size_ +
+                                        (block.seq_length % page_size_ ? 1 : 0))));
+        }
         for (int i = page_indices_h.size() - page_indptr_sliding_window_h.back();
              i < static_cast<int>(page_indices_h.size()); i++) {
           page_indices_sliding_window_h.push_back(page_indices_h[i]);
@@ -1080,8 +1089,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
                                                last_block.sliding_window_offset - 1) %
                                                   page_size_ +
                                               1);
-        if (support_layer_sliding_window_) {
-          if (last_block.seq_length < 1024) {
+        if (support_layer_sliding_window_ && sliding_window_size_ > 0) {
+          if (last_block.seq_length < sliding_window_size_) {
            sliding_window_offset_h.push_back(0);
          } else {
            sliding_window_offset_h.push_back(last_block.seq_length % page_size_);
@@ -1091,9 +1100,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
        }
        sink_size_h.push_back(last_block.sink_length);
        k_rope_pos_offset_h.push_back(block.start_pos);
-       if (support_layer_sliding_window_) {
+       if (support_layer_sliding_window_ && sliding_window_size_ > 0) {
          k_rope_pos_offset_sliding_window_h.push_back(
-             std::max(0, block.start_pos + block.seq_length - 1024));
+             std::max(0, block.start_pos + block.seq_length - sliding_window_size_));
        }
      }
    }
@@ -1408,8 +1417,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
     // The auxiliary data structure on device must have been synchronized.
     ICHECK(!dirty_aux_data_device_);
-    if (attn_kind == AttnKind::kMHA) {
-      MHASelfAttnInternal(q_data, k_data, v_data, o_data, lse_data, sm_scale);
+    if (attn_kind == AttnKind::kMHA || attn_kind == AttnKind::kMHASliding) {
+      MHASelfAttnInternal(local_layer_id, q_data, k_data, v_data, o_data, lse_data, sm_scale);
     } else {
       MLASelfAttnInternal(q_data, k_data, v_data, o_data, lse_data, sm_scale);
     }
@@ -2089,7 +2098,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
     if (!append_before_attn_) {
       // The first part of attention, which only involves the q and the newly appended k/v.
       is_first_kernel = false;
-      MHASelfAttnInternal(q_data, k_data, v_data, output, merged_attn_lse_view_, sm_scale);
+      MHASelfAttnInternal(local_layer_id, q_data, k_data, v_data, output, merged_attn_lse_view_, sm_scale);
     }
     bool self_attn_computed = !is_first_kernel;
     bool cross_attn_computed = MHACrossAttnInternal(
@@ -2098,15 +2107,21 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
         << "Both self-attention and cross-attention are not computed.";
   }
 
-  void MHASelfAttnInternal(Tensor q_data, Tensor k_data, Tensor v_data, Tensor o_data,
+  void MHASelfAttnInternal(int64_t local_layer_id, Tensor q_data, Tensor k_data, Tensor v_data, Tensor o_data,
                            Tensor lse_data, double sm_scale) {
     if (is_chain_on_depths_[0]) {
       // If the batch does not form a tree, use raggedness prefill kernel.
+
+      int sliding_window_size =
+          (!support_sliding_window_ &&
+           attn_kinds_[local_layer_id + layer_id_begin_offset_] != AttnKind::kMHASliding)
+              ? -1
+              : sliding_window_size_;
       ICHECK_NOTNULL(f_attention_prefill_ragged_);
       f_attention_prefill_ragged_->MHA(
           q_data, k_data, v_data, cur_append_length_indptr_view_, cur_append_length_indptr_view_,
-          q_rope_position_map_view_, k_ragged_rope_pos_offset_view_, /*causal=*/true, rope_mode_,
-          rotary_scale_, rotary_theta_, sm_scale, o_data, lse_data, compute_stream_);
+          q_rope_position_map_view_, k_ragged_rope_pos_offset_view_, sliding_window_size, /*causal=*/true,
+          rope_mode_, rotary_scale_, rotary_theta_, sm_scale, o_data, lse_data, compute_stream_);
     } else {
       // The batch requires tree attention.
       ICHECK(f_attention_prefill_with_tree_mask_ != nullptr)
@@ -2127,8 +2142,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
     ICHECK_NOTNULL(f_attention_prefill_ragged_);
     f_attention_prefill_ragged_->MHA(
         q_data, k_data, v_data, cur_append_length_indptr_view_, cur_append_length_indptr_view_,
-        q_rope_position_map_view_, k_ragged_rope_pos_offset_view_, /*causal=*/true, RoPEMode::kNone,
-        rotary_scale_, rotary_theta_, sm_scale, o_data, lse_data, compute_stream_);
+        q_rope_position_map_view_, k_ragged_rope_pos_offset_view_, -1, /*causal=*/true,
+        RoPEMode::kNone, rotary_scale_, rotary_theta_, sm_scale, o_data, lse_data, compute_stream_);
   }
 
   /*! \brief Compute cross-attention for MHA. Return if there is effective computation. */
@@ -2165,23 +2180,17 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
     Tensor page_indices;
     Tensor length_info;
     Tensor k_rope_pos;
-    double rotary_theta;
-    double rotary_scale;
     if (attn_kinds_[local_layer_id + layer_id_begin_offset_] == AttnKind::kMHASliding) {
       page_indptr = page_indptr_sliding_window_on_depths_view_[d];
       page_indices = page_indices_sliding_window_on_depths_view_[d];
       length_info = layer_sliding_window_length_info_on_depths_view_[d];
       k_rope_pos = k_rope_pos_offset_sliding_window_view_[d];
-      rotary_theta = 10000;
-      rotary_scale = 1;
     } else {
       page_indptr = page_indptr_on_depths_view_[d];
       page_indices = page_indices_on_depths_view_[d];
       length_info = length_info_on_depths_view_[d];
       k_rope_pos = k_rope_pos_offset_view_[d];
-      rotary_theta = rotary_theta_;
-      rotary_scale = rotary_scale_;
     }
 
     if (append_before_attn_ && !is_chain_on_depths_[d]) {
@@ -2189,13 +2198,13 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
       f_attention_prefill_with_tree_mask_paged_kv_->MHA(
           q_data, qo_indptr_on_depths_view_[d], pages_[local_layer_id], page_indptr, page_indices,
           length_info, k_rope_pos, q_rope_position_map_view_, tree_attn_mn_indptr_view_[d],
-          tree_attn_mask_view_[d], rope_mode_, rotary_scale, rotary_theta, sm_scale, attn_output,
+          tree_attn_mask_view_[d], rope_mode_, rotary_scale_, rotary_theta_, sm_scale, attn_output,
          attn_lse, compute_stream_);
     } else if (use_decode_kernel_[d]) {
       // Use decode kernel for depth d
       ICHECK_NOTNULL(f_decode);
       f_decode->MHA(d, q_data, pages_[local_layer_id], page_indptr, page_indices, length_info,
-                    k_rope_pos, q_rope_position_map_view_, rope_mode_, rotary_scale, rotary_theta,
+                    k_rope_pos, q_rope_position_map_view_, rope_mode_, rotary_scale_, rotary_theta_,
                     sm_scale, attn_output, attn_lse, compute_stream_);
     } else {
       // Use prefill kernel for depth d
@@ -2203,7 +2212,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
       f_prefill->MHA(d, q_data, qo_indptr_on_depths_view_[d], pages_[local_layer_id], page_indptr,
                      page_indices, length_info, q_rope_position_map_view_, k_rope_pos,
                      /*causal=*/false,
-                     /*rotary_mode=*/rope_mode_, rotary_scale, rotary_theta, sm_scale,
+                     /*rotary_mode=*/rope_mode_, rotary_scale_, rotary_theta_, sm_scale,
                      attn_output, attn_lse, compute_stream_);
     }
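
For reference, the masking rule that the updated _causal_mask encodes can be stated as: key position col is visible to query row only if it lies at or before the query's absolute position (kv_len - qo_len + row) and, when sliding_window_size > 0, no more than sliding_window_size - 1 positions behind it. The snippet below is an illustrative NumPy sketch of that predicate only; it is not part of the patch, and the function name is hypothetical.

import numpy as np

def ragged_prefill_mask(kv_len, qo_len, causal=True, sliding_window_size=-1):
    """Boolean mask mirroring _causal_mask: entry [row, col] is True when key
    position `col` is attended by query position `row`. Sketch for illustration."""
    row = np.arange(qo_len)[:, None]  # query rows
    col = np.arange(kv_len)[None, :]  # key columns
    if not causal:
        # Non-causal branch: every key position within kv_len is visible.
        return np.broadcast_to(col < kv_len, (qo_len, kv_len))
    upper = col < kv_len - qo_len + row + 1  # ordinary causal upper bound
    if sliding_window_size > 0:
        # Sliding-window lower bound added by this change.
        lower = col >= kv_len - qo_len + row - sliding_window_size + 1
        return upper & lower
    return upper

# Example: 4 queries appended to an 8-token context with a window of 3.
# Each query attends to its own position and at most the 2 positions before it.
print(ragged_prefill_mask(kv_len=8, qo_len=4, sliding_window_size=3).astype(int))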