-
Notifications
You must be signed in to change notification settings - Fork 3.8k
[KVCache] Enable sliding window for ragged prefill (SelfAttention)
#18630
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -101,6 +101,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { | |||||||||||||||||||
| const bool support_sliding_window_; | ||||||||||||||||||||
| /*! \brief A boolean flag indicating if the KV cache has per layer sliding window. */ | ||||||||||||||||||||
| const bool support_layer_sliding_window_; | ||||||||||||||||||||
| /*! \brief The sliding window size for sliding window attention. */ | ||||||||||||||||||||
| int32_t sliding_window_size_; | ||||||||||||||||||||
| /*! \brief The attention kinds for each layer. */ | ||||||||||||||||||||
| const std::vector<AttnKind> attn_kinds_; | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
@@ -314,6 +316,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { | |||||||||||||||||||
| : support_sliding_window), | ||||||||||||||||||||
| support_layer_sliding_window_(std::find(attn_kinds.begin(), attn_kinds.end(), | ||||||||||||||||||||
| AttnKind::kMHASliding) != attn_kinds.end()), | ||||||||||||||||||||
| sliding_window_size_(-1), | ||||||||||||||||||||
| attn_kinds_(std::move(attn_kinds)), | ||||||||||||||||||||
| rope_mode_(support_sliding_window && rope_mode != RoPEMode::kNone ? RoPEMode::kInline | ||||||||||||||||||||
| : rope_mode), | ||||||||||||||||||||
|
|
@@ -766,6 +769,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { | |||||||||||||||||||
| // introduce more sink. Therefore, we update the given attn sink size. | ||||||||||||||||||||
| it->second.last_block_attn_sink_size = std::max(attn_sink_size - prefix_length, 0); | ||||||||||||||||||||
| it->second.sliding_window_size = sliding_window_size; | ||||||||||||||||||||
| if (sliding_window_size_ == -1) | ||||||||||||||||||||
| sliding_window_size_ = sliding_window_size; | ||||||||||||||||||||
|
Comment on lines
+772
to
+773
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This logic to set
Suggested change
|
||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| void PopN(int64_t seq_id, int32_t n) final { | ||||||||||||||||||||
|
|
@@ -1009,14 +1014,16 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { | |||||||||||||||||||
| // Do the same for page_indices_sliding_window | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| // For sliding window, the first page and last page will both be partially used | ||||||||||||||||||||
| page_indptr_sliding_window_h.push_back( | ||||||||||||||||||||
| page_indptr_sliding_window_h.back() + | ||||||||||||||||||||
| std::min(static_cast<int32_t>(block.page_ids.size()), | ||||||||||||||||||||
| static_cast<int32_t>(1024 / page_size_ + | ||||||||||||||||||||
| (block.seq_length % page_size_ ? 1 : 0)))); | ||||||||||||||||||||
| if (support_layer_sliding_window_ && sliding_window_size_ > 0) { | ||||||||||||||||||||
| // For sliding window, the first page and last page will both be partially used | ||||||||||||||||||||
| page_indptr_sliding_window_h.push_back( | ||||||||||||||||||||
| page_indptr_sliding_window_h.back() + | ||||||||||||||||||||
| std::min(static_cast<int32_t>(block.page_ids.size()), | ||||||||||||||||||||
| static_cast<int32_t>(sliding_window_size_ / page_size_ + | ||||||||||||||||||||
| (block.seq_length % page_size_ ? 1 : 0)))); | ||||||||||||||||||||
| } | ||||||||||||||||||||
| for (int i = page_indices_h.size() - page_indptr_sliding_window_h.back(); | ||||||||||||||||||||
| i < static_cast<int32_t>(page_indices_h.size()); i++) { | ||||||||||||||||||||
| i < static_cast<int32_t>(page_indices_h.size()); i++) { | ||||||||||||||||||||
| page_indices_sliding_window_h.push_back(page_indices_h[i]); | ||||||||||||||||||||
| } | ||||||||||||||||||||
| // set up the page indices properly by choosing the last (sliding_window_size / | ||||||||||||||||||||
|
|
@@ -1027,8 +1034,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { | |||||||||||||||||||
| : (block.seq_length - block.sink_length + block.sliding_window_offset - 1) % | ||||||||||||||||||||
| page_size_ + | ||||||||||||||||||||
| 1); | ||||||||||||||||||||
| if (support_layer_sliding_window_) { | ||||||||||||||||||||
| if (block.seq_length < 1024) { | ||||||||||||||||||||
| if (support_layer_sliding_window_ && sliding_window_size_ > 0) { | ||||||||||||||||||||
| if (block.seq_length < sliding_window_size_) { | ||||||||||||||||||||
| sliding_window_offset_h.push_back(0); | ||||||||||||||||||||
| } else { | ||||||||||||||||||||
| sliding_window_offset_h.push_back(block.seq_length % page_size_); | ||||||||||||||||||||
|
|
@@ -1040,9 +1047,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { | |||||||||||||||||||
| k_rope_pos_offset_h.push_back(block.start_pos); | ||||||||||||||||||||
|
|
||||||||||||||||||||
| // If sliding window, we need to calculate the positional offset | ||||||||||||||||||||
| if (support_layer_sliding_window_) { | ||||||||||||||||||||
| if (support_layer_sliding_window_ && sliding_window_size_ > 0) { | ||||||||||||||||||||
| k_rope_pos_offset_sliding_window_h.push_back( | ||||||||||||||||||||
| std::max(0, block.start_pos + block.seq_length - 1024)); | ||||||||||||||||||||
| std::max(0, block.start_pos + block.seq_length - sliding_window_size_)); | ||||||||||||||||||||
| } | ||||||||||||||||||||
| } else { | ||||||||||||||||||||
| // Blocks at maximum depth | ||||||||||||||||||||
|
|
@@ -1064,11 +1071,13 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { | |||||||||||||||||||
| last_block_id = id; | ||||||||||||||||||||
| } | ||||||||||||||||||||
| page_indptr_h.push_back(page_indptr_h.back() + num_pages); | ||||||||||||||||||||
| page_indptr_sliding_window_h.push_back( | ||||||||||||||||||||
| page_indptr_sliding_window_h.back() + | ||||||||||||||||||||
| std::min(static_cast<int32_t>(block.page_ids.size()), | ||||||||||||||||||||
| static_cast<int32_t>(1024 / page_size_ + | ||||||||||||||||||||
| (block.seq_length % page_size_ ? 1 : 0)))); | ||||||||||||||||||||
| if (support_layer_sliding_window_ && sliding_window_size_ > 0) { | ||||||||||||||||||||
| page_indptr_sliding_window_h.push_back( | ||||||||||||||||||||
| page_indptr_sliding_window_h.back() + | ||||||||||||||||||||
| std::min(static_cast<int32_t>(block.page_ids.size()), | ||||||||||||||||||||
| static_cast<int32_t>(sliding_window_size_ / page_size_ + | ||||||||||||||||||||
| (block.seq_length % page_size_ ? 1 : 0)))); | ||||||||||||||||||||
| } | ||||||||||||||||||||
| for (int i = page_indices_h.size() - page_indptr_sliding_window_h.back(); | ||||||||||||||||||||
| i < static_cast<int32_t>(page_indices_h.size()); i++) { | ||||||||||||||||||||
| page_indices_sliding_window_h.push_back(page_indices_h[i]); | ||||||||||||||||||||
|
|
@@ -1080,8 +1089,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { | |||||||||||||||||||
| last_block.sliding_window_offset - 1) % | ||||||||||||||||||||
| page_size_ + | ||||||||||||||||||||
| 1); | ||||||||||||||||||||
| if (support_layer_sliding_window_) { | ||||||||||||||||||||
| if (last_block.seq_length < 1024) { | ||||||||||||||||||||
| if (support_layer_sliding_window_ && sliding_window_size_ > 0) { | ||||||||||||||||||||
| if (last_block.seq_length < sliding_window_size_) { | ||||||||||||||||||||
| sliding_window_offset_h.push_back(0); | ||||||||||||||||||||
| } else { | ||||||||||||||||||||
| sliding_window_offset_h.push_back(last_block.seq_length % page_size_); | ||||||||||||||||||||
|
|
@@ -1091,9 +1100,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { | |||||||||||||||||||
| } | ||||||||||||||||||||
| sink_size_h.push_back(last_block.sink_length); | ||||||||||||||||||||
| k_rope_pos_offset_h.push_back(block.start_pos); | ||||||||||||||||||||
| if (support_layer_sliding_window_) { | ||||||||||||||||||||
| if (support_layer_sliding_window_ && sliding_window_size_ > 0) { | ||||||||||||||||||||
| k_rope_pos_offset_sliding_window_h.push_back( | ||||||||||||||||||||
| std::max(0, block.start_pos + block.seq_length - 1024)); | ||||||||||||||||||||
| std::max(0, block.start_pos + block.seq_length - sliding_window_size_)); | ||||||||||||||||||||
| } | ||||||||||||||||||||
| } | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
@@ -1408,8 +1417,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { | |||||||||||||||||||
| // The auxiliary data structure on device must have been synchronized. | ||||||||||||||||||||
| ICHECK(!dirty_aux_data_device_); | ||||||||||||||||||||
|
|
||||||||||||||||||||
| if (attn_kind == AttnKind::kMHA) { | ||||||||||||||||||||
| MHASelfAttnInternal(q_data, k_data, v_data, o_data, lse_data, sm_scale); | ||||||||||||||||||||
| if (attn_kind == AttnKind::kMHA || attn_kind == AttnKind::kMHASliding) { | ||||||||||||||||||||
| MHASelfAttnInternal(layer_id, q_data, k_data, v_data, o_data, lse_data, sm_scale); | ||||||||||||||||||||
grf53 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||||||||||||
| } else { | ||||||||||||||||||||
| MLASelfAttnInternal(q_data, k_data, v_data, o_data, lse_data, sm_scale); | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
@@ -2089,7 +2098,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { | |||||||||||||||||||
| if (!append_before_attn_) { | ||||||||||||||||||||
| // The first part of attention, which only involves the q and the newly appended k/v. | ||||||||||||||||||||
| is_first_kernel = false; | ||||||||||||||||||||
| MHASelfAttnInternal(q_data, k_data, v_data, output, merged_attn_lse_view_, sm_scale); | ||||||||||||||||||||
| MHASelfAttnInternal(local_layer_id, q_data, k_data, v_data, output, merged_attn_lse_view_, sm_scale); | ||||||||||||||||||||
| } | ||||||||||||||||||||
| bool self_attn_computed = !is_first_kernel; | ||||||||||||||||||||
| bool cross_attn_computed = MHACrossAttnInternal( | ||||||||||||||||||||
|
|
@@ -2098,15 +2107,21 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { | |||||||||||||||||||
| << "Both self-attention and cross-attention are not computed."; | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| void MHASelfAttnInternal(Tensor q_data, Tensor k_data, Tensor v_data, Tensor o_data, | ||||||||||||||||||||
| void MHASelfAttnInternal(int64_t local_layer_id, Tensor q_data, Tensor k_data, Tensor v_data, Tensor o_data, | ||||||||||||||||||||
| Tensor lse_data, double sm_scale) { | ||||||||||||||||||||
| if (is_chain_on_depths_[0]) { | ||||||||||||||||||||
| // If the batch does not form a tree, use raggedness prefill kernel. | ||||||||||||||||||||
|
|
||||||||||||||||||||
| int sliding_window_size = | ||||||||||||||||||||
| (!support_sliding_window_ && | ||||||||||||||||||||
| attn_kinds_[local_layer_id + layer_id_begin_offset_] != AttnKind::kMHASliding) | ||||||||||||||||||||
| ? -1 | ||||||||||||||||||||
| : sliding_window_size_; | ||||||||||||||||||||
| ICHECK_NOTNULL(f_attention_prefill_ragged_); | ||||||||||||||||||||
| f_attention_prefill_ragged_->MHA( | ||||||||||||||||||||
| q_data, k_data, v_data, cur_append_length_indptr_view_, cur_append_length_indptr_view_, | ||||||||||||||||||||
| q_rope_position_map_view_, k_ragged_rope_pos_offset_view_, /*causal=*/true, rope_mode_, | ||||||||||||||||||||
| rotary_scale_, rotary_theta_, sm_scale, o_data, lse_data, compute_stream_); | ||||||||||||||||||||
| q_rope_position_map_view_, k_ragged_rope_pos_offset_view_, sliding_window_size, /*causal=*/true, | ||||||||||||||||||||
| rope_mode_, rotary_scale_, rotary_theta_, sm_scale, o_data, lse_data, compute_stream_); | ||||||||||||||||||||
| } else { | ||||||||||||||||||||
| // The batch requires tree attention. | ||||||||||||||||||||
| ICHECK(f_attention_prefill_with_tree_mask_ != nullptr) | ||||||||||||||||||||
|
|
@@ -2127,8 +2142,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { | |||||||||||||||||||
| ICHECK_NOTNULL(f_attention_prefill_ragged_); | ||||||||||||||||||||
| f_attention_prefill_ragged_->MHA( | ||||||||||||||||||||
| q_data, k_data, v_data, cur_append_length_indptr_view_, cur_append_length_indptr_view_, | ||||||||||||||||||||
| q_rope_position_map_view_, k_ragged_rope_pos_offset_view_, /*causal=*/true, RoPEMode::kNone, | ||||||||||||||||||||
| rotary_scale_, rotary_theta_, sm_scale, o_data, lse_data, compute_stream_); | ||||||||||||||||||||
| q_rope_position_map_view_, k_ragged_rope_pos_offset_view_, -1, /*causal=*/true, | ||||||||||||||||||||
| RoPEMode::kNone, rotary_scale_, rotary_theta_, sm_scale, o_data, lse_data, compute_stream_); | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| /*! \brief Compute cross-attention for MHA. Return if there is effective computation. */ | ||||||||||||||||||||
|
|
@@ -2165,45 +2180,39 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { | |||||||||||||||||||
| Tensor page_indices; | ||||||||||||||||||||
| Tensor length_info; | ||||||||||||||||||||
| Tensor k_rope_pos; | ||||||||||||||||||||
| double rotary_theta; | ||||||||||||||||||||
| double rotary_scale; | ||||||||||||||||||||
|
|
||||||||||||||||||||
| if (attn_kinds_[local_layer_id + layer_id_begin_offset_] == AttnKind::kMHASliding) { | ||||||||||||||||||||
| page_indptr = page_indptr_sliding_window_on_depths_view_[d]; | ||||||||||||||||||||
| page_indices = page_indices_sliding_window_on_depths_view_[d]; | ||||||||||||||||||||
| length_info = layer_sliding_window_length_info_on_depths_view_[d]; | ||||||||||||||||||||
| k_rope_pos = k_rope_pos_offset_sliding_window_view_[d]; | ||||||||||||||||||||
| rotary_theta = 10000; | ||||||||||||||||||||
| rotary_scale = 1; | ||||||||||||||||||||
| } else { | ||||||||||||||||||||
| page_indptr = page_indptr_on_depths_view_[d]; | ||||||||||||||||||||
| page_indices = page_indices_on_depths_view_[d]; | ||||||||||||||||||||
| length_info = length_info_on_depths_view_[d]; | ||||||||||||||||||||
| k_rope_pos = k_rope_pos_offset_view_[d]; | ||||||||||||||||||||
| rotary_theta = rotary_theta_; | ||||||||||||||||||||
| rotary_scale = rotary_scale_; | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| if (append_before_attn_ && !is_chain_on_depths_[d]) { | ||||||||||||||||||||
| ICHECK_NOTNULL(f_attention_prefill_with_tree_mask_paged_kv_); | ||||||||||||||||||||
| f_attention_prefill_with_tree_mask_paged_kv_->MHA( | ||||||||||||||||||||
| q_data, qo_indptr_on_depths_view_[d], pages_[local_layer_id], page_indptr, page_indices, | ||||||||||||||||||||
| length_info, k_rope_pos, q_rope_position_map_view_, tree_attn_mn_indptr_view_[d], | ||||||||||||||||||||
| tree_attn_mask_view_[d], rope_mode_, rotary_scale, rotary_theta, sm_scale, attn_output, | ||||||||||||||||||||
| tree_attn_mask_view_[d], rope_mode_, rotary_scale_, rotary_theta_, sm_scale, attn_output, | ||||||||||||||||||||
| attn_lse, compute_stream_); | ||||||||||||||||||||
| } else if (use_decode_kernel_[d]) { | ||||||||||||||||||||
| // Use decode kernel for depth d | ||||||||||||||||||||
| ICHECK_NOTNULL(f_decode); | ||||||||||||||||||||
| f_decode->MHA(d, q_data, pages_[local_layer_id], page_indptr, page_indices, length_info, | ||||||||||||||||||||
| k_rope_pos, q_rope_position_map_view_, rope_mode_, rotary_scale, rotary_theta, | ||||||||||||||||||||
| k_rope_pos, q_rope_position_map_view_, rope_mode_, rotary_scale_, rotary_theta_, | ||||||||||||||||||||
| sm_scale, attn_output, attn_lse, compute_stream_); | ||||||||||||||||||||
| } else { | ||||||||||||||||||||
| // Use prefill kernel for depth d | ||||||||||||||||||||
| ICHECK_NOTNULL(f_prefill); | ||||||||||||||||||||
| f_prefill->MHA(d, q_data, qo_indptr_on_depths_view_[d], pages_[local_layer_id], page_indptr, | ||||||||||||||||||||
| page_indices, length_info, q_rope_position_map_view_, k_rope_pos, | ||||||||||||||||||||
| /*causal=*/false, | ||||||||||||||||||||
| /*rotary_mode=*/rope_mode_, rotary_scale, rotary_theta, sm_scale, | ||||||||||||||||||||
| /*rotary_mode=*/rope_mode_, rotary_scale_, rotary_theta_, sm_scale, | ||||||||||||||||||||
| attn_output, attn_lse, compute_stream_); | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
||||||||||||||||||||
Uh oh!
There was an error while loading. Please reload this page.