-
Notifications
You must be signed in to change notification settings - Fork 680
feat: Turbomind linear gdn prefix caching #4465
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
891a40c
0397f35
6c38cf9
ec5e8aa
ad6573b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
|
|
@@ -14,19 +14,46 @@ size_t hash(const std::vector<int>& vec) | |||||||
| return seed; | ||||||||
| } | ||||||||
|
|
||||||||
| BlockTrie::BlockTrie(size_t block_len, std::shared_ptr<BlockManager> block_manager): | ||||||||
| block_seq_len_(block_len), block_manager_(block_manager) | ||||||||
| BlockTrie::BlockTrie(size_t block_len, | ||||||||
| std::shared_ptr<BlockManager> block_manager, | ||||||||
| int linear_prefix_cache_interval_blocks, | ||||||||
| int linear_state_slot_capacity, | ||||||||
| std::vector<ssize_t> conv_state_shape, | ||||||||
| DataType conv_state_dtype, | ||||||||
| std::vector<ssize_t> recurrent_state_shape, | ||||||||
| DataType recurrent_state_dtype): | ||||||||
| block_seq_len_(block_len), | ||||||||
| block_manager_(block_manager), | ||||||||
| linear_prefix_cache_interval_blocks_(linear_prefix_cache_interval_blocks), | ||||||||
| linear_prefix_cache_interval_tokens_(linear_prefix_cache_interval_blocks * block_len), | ||||||||
| linear_state_slot_capacity_(linear_state_slot_capacity), | ||||||||
| conv_state_shape_(std::move(conv_state_shape)), | ||||||||
| recurrent_state_shape_(std::move(recurrent_state_shape)), | ||||||||
| conv_state_dtype_(conv_state_dtype), | ||||||||
| recurrent_state_dtype_(recurrent_state_dtype) | ||||||||
| { | ||||||||
| root_ = std::make_shared<TrieNode>(); | ||||||||
|
|
||||||||
| if (linear_state_slot_capacity_ > 0) { | ||||||||
| TM_CHECK_GT(linear_prefix_cache_interval_blocks_, 0); | ||||||||
| TM_CHECK(!conv_state_shape_.empty()); | ||||||||
| TM_CHECK(!recurrent_state_shape_.empty()); | ||||||||
| linear_conv_states_.resize(linear_state_slot_capacity_); | ||||||||
| linear_recurrent_states_.resize(linear_state_slot_capacity_); | ||||||||
| free_linear_state_slots_.reserve(linear_state_slot_capacity_); | ||||||||
| for (int slot = linear_state_slot_capacity_ - 1; slot >= 0; --slot) { | ||||||||
| free_linear_state_slots_.push_back(slot); | ||||||||
| } | ||||||||
| } | ||||||||
| } | ||||||||
|
|
||||||||
| std::tuple<BlockIds, UniqueIds> BlockTrie::Match(const Sequence& seq) | ||||||||
| BlockTrieMatch BlockTrie::Match(const Sequence& seq) | ||||||||
| { | ||||||||
| BlockIds block_ids; | ||||||||
| UniqueIds unique_ids; | ||||||||
| BlockTrieMatch match; | ||||||||
|
|
||||||||
| auto node = root_; | ||||||||
| auto first = seq.prompt.begin(); | ||||||||
| auto node = root_; | ||||||||
| auto first = seq.prompt.begin(); | ||||||||
| auto linear_prefix_state = root_; | ||||||||
|
|
||||||||
| // Warning: Do not use "<=" operator even when seq.prompt length is evenly | ||||||||
| // divisible by block_seq_len_. The model needs at least one input token to generate output. | ||||||||
|
|
@@ -35,9 +62,12 @@ std::tuple<BlockIds, UniqueIds> BlockTrie::Match(const Sequence& seq) | |||||||
| const size_t hash_key = hash(segment); | ||||||||
| if (const auto it = node->children.find(hash_key); it != node->children.end()) { | ||||||||
| if (segment == it->second->tokens) { | ||||||||
| block_ids.push_back(it->second->block_id); | ||||||||
| unique_ids.push_back(it->second->block_unique_id); | ||||||||
| match.block_ids.push_back(it->second->block_id); | ||||||||
| match.unique_ids.push_back(it->second->block_unique_id); | ||||||||
| node = it->second; | ||||||||
| if (node->linear_state_slot >= 0) { | ||||||||
| linear_prefix_state = node; | ||||||||
| } | ||||||||
| first += block_seq_len_; | ||||||||
| } | ||||||||
| else { | ||||||||
|
|
@@ -50,7 +80,14 @@ std::tuple<BlockIds, UniqueIds> BlockTrie::Match(const Sequence& seq) | |||||||
| } | ||||||||
| } | ||||||||
|
|
||||||||
| return std::make_tuple(block_ids, unique_ids); | ||||||||
| if (linear_prefix_state != root_) { | ||||||||
| const int slot = linear_prefix_state->linear_state_slot; | ||||||||
| match.linear_cache_len = linear_prefix_state->num_matched; | ||||||||
| match.conv_states = LinearConvState(slot); | ||||||||
| match.recurrent_states = LinearRecurrentState(slot); | ||||||||
| } | ||||||||
|
|
||||||||
| return match; | ||||||||
| } | ||||||||
|
|
||||||||
| std::tuple<BlockIds, UniqueIds> BlockTrie::Cache(const Sequence& seq, const std::vector<int>& tokens) | ||||||||
|
|
@@ -66,7 +103,10 @@ std::tuple<BlockIds, UniqueIds> BlockTrie::Cache(const Sequence& seq, const std: | |||||||
|
|
||||||||
| const int n_blocks = std::min(seq.cache_len, (int)tokens.size()) / block_seq_len_; | ||||||||
|
|
||||||||
| int new_cached = 0; | ||||||||
| int new_cached = 0; | ||||||||
| const int checkpoint_base = linear_prefix_cache_interval_tokens_ ? | ||||||||
| seq.pending_linear_prefix_capture_base_len / linear_prefix_cache_interval_tokens_ : | ||||||||
| 0; | ||||||||
|
|
||||||||
| for (int idx = 0; idx < n_blocks; ++idx) { | ||||||||
| auto start = tokens.begin() + idx * block_seq_len_; | ||||||||
|
|
@@ -75,14 +115,16 @@ std::tuple<BlockIds, UniqueIds> BlockTrie::Cache(const Sequence& seq, const std: | |||||||
| const std::vector<int> segment(start, end); | ||||||||
| const size_t hash_key = hash(segment); // TODO(lvhan): add salt to ensure the hash security | ||||||||
|
|
||||||||
| int block_id = seq.blocks[idx]; | ||||||||
| uint64_t block_unique_id = seq.block_unique_ids[idx]; | ||||||||
| int block_id = seq.blocks[idx]; | ||||||||
| uint64_t block_unique_id = seq.block_unique_ids[idx]; | ||||||||
| const int num_matched = (idx + 1) * block_seq_len_; | ||||||||
|
|
||||||||
| if (auto it = node->children.find(hash_key); it != node->children.end()) { | ||||||||
| if (segment == it->second->tokens) { // fast-forward | ||||||||
| node = it->second; | ||||||||
| node->block_id = block_id; | ||||||||
| node->block_unique_id = block_unique_id; | ||||||||
| node->num_matched = num_matched; | ||||||||
| } | ||||||||
| else { | ||||||||
| TM_LOG_WARNING("[BlockTrie][cache] Hash collision detected"); | ||||||||
|
|
@@ -96,8 +138,24 @@ std::tuple<BlockIds, UniqueIds> BlockTrie::Cache(const Sequence& seq, const std: | |||||||
| node->tokens = segment; | ||||||||
| node->block_id = block_id; | ||||||||
| node->block_unique_id = block_unique_id; | ||||||||
| node->num_matched = num_matched; | ||||||||
| new_cached += block_seq_len_; | ||||||||
| } | ||||||||
| if (IsLinearCheckpointNode(num_matched)) { | ||||||||
| const int checkpoint_idx = num_matched / linear_prefix_cache_interval_tokens_ - checkpoint_base - 1; | ||||||||
| if (checkpoint_idx >= 0 && checkpoint_idx < seq.pending_linear_prefix_capture_count | ||||||||
| && seq.pending_linear_prefix_conv_states && seq.pending_linear_prefix_recurrent_states) { | ||||||||
| if (node->linear_state_slot < 0) { | ||||||||
| node->linear_state_slot = AcquireLinearStateSlot(); | ||||||||
| } | ||||||||
| if (node->linear_state_slot >= 0) { | ||||||||
| Copy(seq.pending_linear_prefix_conv_states.slice(checkpoint_idx).squeeze(0), | ||||||||
| LinearConvState(node->linear_state_slot)); | ||||||||
| Copy(seq.pending_linear_prefix_recurrent_states.slice(checkpoint_idx).squeeze(0), | ||||||||
| LinearRecurrentState(node->linear_state_slot)); | ||||||||
| } | ||||||||
| } | ||||||||
| } | ||||||||
| cache_block_ids.emplace_back(block_id); | ||||||||
| cache_block_unique_ids.emplace_back(block_unique_id); | ||||||||
| } | ||||||||
|
|
@@ -117,6 +175,7 @@ void BlockTrie::DFS(std::shared_ptr<TrieNode>& node) | |||||||
| for (auto it = node->children.begin(); it != node->children.end();) { | ||||||||
| if (block_manager_->unique_id(it->second->block_id) != it->second->block_unique_id) { | ||||||||
| // child invalid | ||||||||
| ReleaseLinearPrefixState(it->second); | ||||||||
| it = node->children.erase(it); | ||||||||
| } | ||||||||
| else { | ||||||||
|
|
@@ -126,4 +185,78 @@ void BlockTrie::DFS(std::shared_ptr<TrieNode>& node) | |||||||
| } | ||||||||
| } | ||||||||
|
|
||||||||
| // Walks the subtree rooted at `node` and passes every node's linear-state slot | ||||||||
| // to ReleaseLinearStateSlot, children before parent. A null `node` is ignored. | ||||||||
| void BlockTrie::ReleaseLinearPrefixState(std::shared_ptr<TrieNode>& node) | ||||||||
| { | ||||||||
|     if (node) { | ||||||||
|         // Descendants first, then this node's own slot. | ||||||||
|         for (auto& entry : node->children) { | ||||||||
|             ReleaseLinearPrefixState(entry.second); | ||||||||
|         } | ||||||||
|         ReleaseLinearStateSlot(node->linear_state_slot); | ||||||||
|         // A negative slot means "no checkpoint attached" (Match tests `>= 0`). | ||||||||
|         node->linear_state_slot = -1; | ||||||||
|     } | ||||||||
| } | ||||||||
|
|
||||||||
| // Returns true when `num_matched` is a positive multiple of the linear prefix | ||||||||
| // cache interval (in tokens). Always false when the interval is not positive. | ||||||||
| bool BlockTrie::IsLinearCheckpointNode(int num_matched) const | ||||||||
| { | ||||||||
|     const auto interval_tokens = linear_prefix_cache_interval_tokens_; | ||||||||
|     if (!(interval_tokens > 0) || !(num_matched > 0)) { | ||||||||
|         return false; | ||||||||
|     } | ||||||||
|     return num_matched % interval_tokens == 0; | ||||||||
| } | ||||||||
|
|
||||||||
| // Takes a slot index from the back of the free list and makes sure its conv and | ||||||||
| // recurrent state entries are populated (built from the stored shape, dtype and | ||||||||
| // kDEVICE when still empty). | ||||||||
| // | ||||||||
| // Returns the slot index, or -1 when the free list is empty or populating the | ||||||||
| // slot throws a std::exception. In the latter case the slot goes back on the | ||||||||
| // free list. Each failure kind is logged once, guarded by its own flag. | ||||||||
| int BlockTrie::AcquireLinearStateSlot() | ||||||||
| { | ||||||||
|     if (!free_linear_state_slots_.empty()) { | ||||||||
|         const int slot_id = free_linear_state_slots_.back(); | ||||||||
|         free_linear_state_slots_.pop_back(); | ||||||||
|         try { | ||||||||
|             auto& conv_state = linear_conv_states_[slot_id]; | ||||||||
|             if (!conv_state) { | ||||||||
|                 conv_state = {conv_state_shape_, conv_state_dtype_, kDEVICE}; | ||||||||
|             } | ||||||||
|             auto& recurrent_state = linear_recurrent_states_[slot_id]; | ||||||||
|             if (!recurrent_state) { | ||||||||
|                 recurrent_state = {recurrent_state_shape_, recurrent_state_dtype_, kDEVICE}; | ||||||||
|             } | ||||||||
|             return slot_id; | ||||||||
|         } | ||||||||
|         catch (const std::exception& err) { | ||||||||
|             // Hand the slot back so a later call can retry it. | ||||||||
|             free_linear_state_slots_.push_back(slot_id); | ||||||||
|             if (!warned_linear_state_pool_oom_) { | ||||||||
|                 TM_LOG_WARNING("[BlockTrie] failed to allocate hybrid prefix checkpoint state: %s. " | ||||||||
|                                "Further GDN prefix checkpoints will be skipped until memory is freed.", | ||||||||
|                                err.what()); | ||||||||
|                 warned_linear_state_pool_oom_ = true; | ||||||||
|             } | ||||||||
|             return -1; | ||||||||
|         } | ||||||||
|     } | ||||||||
|  | ||||||||
|     // Free list is empty: warn the first time only, then decline. | ||||||||
|     if (!warned_linear_state_pool_exhausted_) { | ||||||||
|         TM_LOG_WARNING("[BlockTrie] linear prefix checkpoint pool exhausted; deeper hybrid prefix checkpoints " | ||||||||
|                        "will be skipped until cached entries are evicted"); | ||||||||
|         warned_linear_state_pool_exhausted_ = true; | ||||||||
|     } | ||||||||
|     return -1; | ||||||||
| } | ||||||||
|
|
||||||||
| void BlockTrie::ReleaseLinearStateSlot(int slot) | ||||||||
| { | ||||||||
| if (slot >= 0) { | ||||||||
| linear_conv_states_[slot] = {}; | ||||||||
| linear_recurrent_states_[slot] = {}; | ||||||||
|
Comment on lines +242 to +243
|
||||||||
| linear_conv_states_[slot] = {}; | |
| linear_recurrent_states_[slot] = {}; | |
| // Keep tensors allocated to allow true O(1) reuse and avoid allocator churn. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The PR description says the default `linear_prefix_cache_interval_blocks` is 2 KV blocks, but the implementation sets the default to 64 here (and the tests assert 64). Please either update the default value to match the PR description, or adjust the PR description and user-facing docs so they align with the shipped default.