Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions lmdeploy/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def add_parser_chat():
tb_group._group_actions.append(cache_max_entry_act)
tb_group._group_actions.append(prefix_caching_act)
tb_group._group_actions.append(quant_policy)
ArgumentHelper.linear_prefix_cache_interval_blocks(tb_group)
ArgumentHelper.model_format(tb_group)
ArgumentHelper.rope_scaling_factor(tb_group)
ArgumentHelper.communicator(tb_group)
Expand Down
8 changes: 6 additions & 2 deletions lmdeploy/cli/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def add_parser_api_server():
cache_max_entry_act = ArgumentHelper.cache_max_entry_count(pt_group)
cache_block_seq_len_act = ArgumentHelper.cache_block_seq_len(pt_group)
prefix_caching_act = ArgumentHelper.enable_prefix_caching(pt_group)
linear_prefix_cache_interval_blocks_act = ArgumentHelper.linear_prefix_cache_interval_blocks(pt_group)
max_prefill_token_num_act = ArgumentHelper.max_prefill_token_num(pt_group)
quant_policy = ArgumentHelper.quant_policy(pt_group)
model_format = ArgumentHelper.model_format(pt_group)
Expand All @@ -140,6 +141,7 @@ def add_parser_api_server():
tb_group._group_actions.append(cache_max_entry_act)
tb_group._group_actions.append(cache_block_seq_len_act)
tb_group._group_actions.append(prefix_caching_act)
tb_group._group_actions.append(linear_prefix_cache_interval_blocks_act)
tb_group._group_actions.append(max_prefill_token_num_act)
tb_group._group_actions.append(quant_policy)
tb_group._group_actions.append(model_format)
Expand Down Expand Up @@ -208,13 +210,14 @@ def api_server(args):
"""Serve LLMs with restful api using fastapi."""
from lmdeploy.archs import autoget_backend

max_batch_size = args.max_batch_size if args.max_batch_size \
else get_max_batch_size(args.device)
backend = args.backend
if backend != 'pytorch':
# set auto backend mode
backend = autoget_backend(args.model_path)

max_batch_size = args.max_batch_size if args.max_batch_size \
else get_max_batch_size(args.device)

if backend == 'pytorch':
from lmdeploy.messages import PytorchEngineConfig
adapters = get_lora_adapters(args.adapters)
Expand Down Expand Up @@ -266,6 +269,7 @@ def api_server(args):
cache_max_entry_count=args.cache_max_entry_count,
cache_block_seq_len=args.cache_block_seq_len,
enable_prefix_caching=args.enable_prefix_caching,
linear_prefix_cache_interval_blocks=args.linear_prefix_cache_interval_blocks,
max_prefill_token_num=args.max_prefill_token_num,
num_tokens_per_iter=args.num_tokens_per_iter,
max_prefill_iters=args.max_prefill_iters,
Expand Down
11 changes: 11 additions & 0 deletions lmdeploy/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,17 @@ def enable_prefix_caching(parser):
default=False,
help='Enable cache and match prefix')

@staticmethod
def linear_prefix_cache_interval_blocks(parser):
    """Register the ``--linear-prefix-cache-interval-blocks`` option on parser.

    Returns the argparse action created by ``add_argument`` so callers can
    append it to other argument groups.
    """
    help_msg = ('Hybrid linear-attention prefix checkpoint interval in '
                'KV cache blocks. Larger values reduce GDN checkpoint memory '
                'usage but increase recompute after a prefix hit')
    return parser.add_argument('--linear-prefix-cache-interval-blocks',
                               type=int,
                               default=64,
                               help=help_msg)

@staticmethod
def num_tokens_per_iter(parser):
return parser.add_argument('--num-tokens-per-iter',
Expand Down
8 changes: 8 additions & 0 deletions lmdeploy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,11 @@ class TurbomindEngineConfig:
a k/v block, default to 64
enable_prefix_caching: enable cache prompts for block reuse,
default to False
linear_prefix_cache_interval_blocks: hybrid linear-attention prefix
checkpoint interval, in KV cache blocks. Larger values reduce GDN
checkpoint memory but may require more recompute after a prefix
hit. Applies only to hybrid models with prefix caching enabled.
Default to 64
quant_policy: default to 0. When k/v is quantized into 4 or 8
bit, set it to 4 or 8, respectively
rope_scaling_factor: scaling factor used for dynamic ntk,
Expand Down Expand Up @@ -278,6 +283,7 @@ class TurbomindEngineConfig:
cache_chunk_size: int = -1
cache_block_seq_len: int = 64
enable_prefix_caching: bool = False
linear_prefix_cache_interval_blocks: int = 64
quant_policy: int = 0
Comment on lines 283 to 287
Copy link

Copilot AI Apr 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PR description says the default linear_prefix_cache_interval_blocks is 2 KV blocks, but the implementation sets the default to 64 here (and tests assert 64). Please either update the default value to match the PR description or adjust the PR description/user-facing docs so they align with the shipped default.

Copilot uses AI. Check for mistakes.
rope_scaling_factor: float = 0.0
use_logn_attn: bool = False
Expand All @@ -298,6 +304,8 @@ def __post_init__(self):
assert self.dtype in ['auto', 'float16', 'bfloat16']
assert self.tp >= 1, 'tp must be a positive integer'
assert self.cache_max_entry_count > 0, 'invalid cache_max_entry_count'
assert self.linear_prefix_cache_interval_blocks >= 1, \
'invalid linear_prefix_cache_interval_blocks'
assert self.quant_policy in (0, 4, 8), 'invalid quant_policy'
assert self.rope_scaling_factor >= 0, 'invalid rope_scaling_factor'
assert self.max_prefill_token_num >= 0, \
Expand Down
1 change: 1 addition & 0 deletions src/turbomind/engine/engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ void Engine::Impl::CreateSequenceManager()
param_.cache_max_block_count,
param_.cache_chunk_size,
param_.enable_prefix_caching,
param_.linear_prefix_cache_interval_blocks,
tp_rank_,
param_.attn_cp_size,
core::Context::alloc(kDEVICE),
Expand Down
159 changes: 146 additions & 13 deletions src/turbomind/models/llama/BlockTrie.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,46 @@ size_t hash(const std::vector<int>& vec)
return seed;
}

BlockTrie::BlockTrie(size_t block_len, std::shared_ptr<BlockManager> block_manager):
block_seq_len_(block_len), block_manager_(block_manager)
// Constructs the prefix-cache trie over KV-cache blocks.
// block_len: number of tokens covered by one KV cache block (block_seq_len_).
// linear_prefix_cache_interval_blocks: checkpoint interval for hybrid
//   linear-attention state, expressed in blocks; converted to tokens below.
// linear_state_slot_capacity: number of slots for conv/recurrent checkpoint
//   tensors; a value <= 0 disables linear-state checkpointing entirely.
// conv_state_shape / recurrent_state_shape and their dtypes describe the
//   per-slot tensors allocated in AcquireLinearStateSlot.
BlockTrie::BlockTrie(size_t block_len,
std::shared_ptr<BlockManager> block_manager,
int linear_prefix_cache_interval_blocks,
int linear_state_slot_capacity,
std::vector<ssize_t> conv_state_shape,
DataType conv_state_dtype,
std::vector<ssize_t> recurrent_state_shape,
DataType recurrent_state_dtype):
block_seq_len_(block_len),
block_manager_(block_manager),
linear_prefix_cache_interval_blocks_(linear_prefix_cache_interval_blocks),
// Interval in tokens = interval in blocks * tokens per block.
linear_prefix_cache_interval_tokens_(linear_prefix_cache_interval_blocks * block_len),
linear_state_slot_capacity_(linear_state_slot_capacity),
conv_state_shape_(std::move(conv_state_shape)),
recurrent_state_shape_(std::move(recurrent_state_shape)),
conv_state_dtype_(conv_state_dtype),
recurrent_state_dtype_(recurrent_state_dtype)
{
root_ = std::make_shared<TrieNode>();

// Linear-state checkpointing is active only with a positive slot capacity;
// in that case the interval and state shapes must be valid.
if (linear_state_slot_capacity_ > 0) {
TM_CHECK_GT(linear_prefix_cache_interval_blocks_, 0);
TM_CHECK(!conv_state_shape_.empty());
TM_CHECK(!recurrent_state_shape_.empty());
// Slot tensors themselves are allocated lazily in AcquireLinearStateSlot;
// here we only size the bookkeeping vectors.
linear_conv_states_.resize(linear_state_slot_capacity_);
linear_recurrent_states_.resize(linear_state_slot_capacity_);
free_linear_state_slots_.reserve(linear_state_slot_capacity_);
// Push slots in descending order so that pop_back hands out slot 0 first.
for (int slot = linear_state_slot_capacity_ - 1; slot >= 0; --slot) {
free_linear_state_slots_.push_back(slot);
}
}
}

std::tuple<BlockIds, UniqueIds> BlockTrie::Match(const Sequence& seq)
BlockTrieMatch BlockTrie::Match(const Sequence& seq)
{
BlockIds block_ids;
UniqueIds unique_ids;
BlockTrieMatch match;

auto node = root_;
auto first = seq.prompt.begin();
auto node = root_;
auto first = seq.prompt.begin();
auto linear_prefix_state = root_;

// Warning: Do not use "<=" operator even when seq.prompt length is evenly
// divisible by block_seq_len_. The model needs at least one input token to generate output.
Expand All @@ -35,9 +62,12 @@ std::tuple<BlockIds, UniqueIds> BlockTrie::Match(const Sequence& seq)
const size_t hash_key = hash(segment);
if (const auto it = node->children.find(hash_key); it != node->children.end()) {
if (segment == it->second->tokens) {
block_ids.push_back(it->second->block_id);
unique_ids.push_back(it->second->block_unique_id);
match.block_ids.push_back(it->second->block_id);
match.unique_ids.push_back(it->second->block_unique_id);
node = it->second;
if (node->linear_state_slot >= 0) {
linear_prefix_state = node;
}
first += block_seq_len_;
}
else {
Expand All @@ -50,7 +80,14 @@ std::tuple<BlockIds, UniqueIds> BlockTrie::Match(const Sequence& seq)
}
}

return std::make_tuple(block_ids, unique_ids);
if (linear_prefix_state != root_) {
const int slot = linear_prefix_state->linear_state_slot;
match.linear_cache_len = linear_prefix_state->num_matched;
match.conv_states = LinearConvState(slot);
match.recurrent_states = LinearRecurrentState(slot);
}

return match;
}

std::tuple<BlockIds, UniqueIds> BlockTrie::Cache(const Sequence& seq, const std::vector<int>& tokens)
Expand All @@ -66,7 +103,10 @@ std::tuple<BlockIds, UniqueIds> BlockTrie::Cache(const Sequence& seq, const std:

const int n_blocks = std::min(seq.cache_len, (int)tokens.size()) / block_seq_len_;

int new_cached = 0;
int new_cached = 0;
const int checkpoint_base = linear_prefix_cache_interval_tokens_ ?
seq.pending_linear_prefix_capture_base_len / linear_prefix_cache_interval_tokens_ :
0;

for (int idx = 0; idx < n_blocks; ++idx) {
auto start = tokens.begin() + idx * block_seq_len_;
Expand All @@ -75,14 +115,16 @@ std::tuple<BlockIds, UniqueIds> BlockTrie::Cache(const Sequence& seq, const std:
const std::vector<int> segment(start, end);
const size_t hash_key = hash(segment); // TODO(lvhan): add salt to ensure the hash security

int block_id = seq.blocks[idx];
uint64_t block_unique_id = seq.block_unique_ids[idx];
int block_id = seq.blocks[idx];
uint64_t block_unique_id = seq.block_unique_ids[idx];
const int num_matched = (idx + 1) * block_seq_len_;

if (auto it = node->children.find(hash_key); it != node->children.end()) {
if (segment == it->second->tokens) { // fast-forward
node = it->second;
node->block_id = block_id;
node->block_unique_id = block_unique_id;
node->num_matched = num_matched;
}
else {
TM_LOG_WARNING("[BlockTrie][cache] Hash collision detected");
Expand All @@ -96,8 +138,24 @@ std::tuple<BlockIds, UniqueIds> BlockTrie::Cache(const Sequence& seq, const std:
node->tokens = segment;
node->block_id = block_id;
node->block_unique_id = block_unique_id;
node->num_matched = num_matched;
new_cached += block_seq_len_;
}
if (IsLinearCheckpointNode(num_matched)) {
const int checkpoint_idx = num_matched / linear_prefix_cache_interval_tokens_ - checkpoint_base - 1;
if (checkpoint_idx >= 0 && checkpoint_idx < seq.pending_linear_prefix_capture_count
&& seq.pending_linear_prefix_conv_states && seq.pending_linear_prefix_recurrent_states) {
if (node->linear_state_slot < 0) {
node->linear_state_slot = AcquireLinearStateSlot();
}
if (node->linear_state_slot >= 0) {
Copy(seq.pending_linear_prefix_conv_states.slice(checkpoint_idx).squeeze(0),
LinearConvState(node->linear_state_slot));
Copy(seq.pending_linear_prefix_recurrent_states.slice(checkpoint_idx).squeeze(0),
LinearRecurrentState(node->linear_state_slot));
}
}
}
cache_block_ids.emplace_back(block_id);
cache_block_unique_ids.emplace_back(block_unique_id);
}
Expand All @@ -117,6 +175,7 @@ void BlockTrie::DFS(std::shared_ptr<TrieNode>& node)
for (auto it = node->children.begin(); it != node->children.end();) {
if (block_manager_->unique_id(it->second->block_id) != it->second->block_unique_id) {
// child invalid
ReleaseLinearPrefixState(it->second);
it = node->children.erase(it);
}
else {
Expand All @@ -126,4 +185,78 @@ void BlockTrie::DFS(std::shared_ptr<TrieNode>& node)
}
}

void BlockTrie::ReleaseLinearPrefixState(std::shared_ptr<TrieNode>& node)
{
if (!node) {
return;
}
for (auto& [_, child] : node->children) {
ReleaseLinearPrefixState(child);
}
ReleaseLinearStateSlot(node->linear_state_slot);
node->linear_state_slot = -1;
}

bool BlockTrie::IsLinearCheckpointNode(int num_matched) const
{
    // A node is a checkpoint node only when checkpointing is enabled
    // (positive interval) and the matched-token count lands exactly on an
    // interval boundary.
    if (linear_prefix_cache_interval_tokens_ <= 0 || num_matched <= 0) {
        return false;
    }
    return num_matched % linear_prefix_cache_interval_tokens_ == 0;
}

// Pops a slot from the free pool and lazily allocates its device tensors.
// Returns the slot index, or -1 when the pool is exhausted or allocation
// fails; each failure mode logs a warning only once.
int BlockTrie::AcquireLinearStateSlot()
{
if (free_linear_state_slots_.empty()) {
if (!warned_linear_state_pool_exhausted_) {
TM_LOG_WARNING("[BlockTrie] linear prefix checkpoint pool exhausted; deeper hybrid prefix checkpoints "
"will be skipped until cached entries are evicted");
warned_linear_state_pool_exhausted_ = true;
}
return -1;
}
const int slot = free_linear_state_slots_.back();
free_linear_state_slots_.pop_back();
try {
// Allocate only if the slot does not already hold a tensor.
if (!linear_conv_states_[slot]) {
linear_conv_states_[slot] = {conv_state_shape_, conv_state_dtype_, kDEVICE};
}
if (!linear_recurrent_states_[slot]) {
linear_recurrent_states_[slot] = {recurrent_state_shape_, recurrent_state_dtype_, kDEVICE};
}
}
catch (const std::exception& e) {
// Roll back: return the slot to the free list so it can be retried later,
// and warn once about the allocation failure.
free_linear_state_slots_.push_back(slot);
if (!warned_linear_state_pool_oom_) {
TM_LOG_WARNING("[BlockTrie] failed to allocate hybrid prefix checkpoint state: %s. "
"Further GDN prefix checkpoints will be skipped until memory is freed.",
e.what());
warned_linear_state_pool_oom_ = true;
}
return -1;
}
return slot;
}

void BlockTrie::ReleaseLinearStateSlot(int slot)
{
if (slot >= 0) {
linear_conv_states_[slot] = {};
linear_recurrent_states_[slot] = {};
Comment on lines +242 to +243
Copy link

Copilot AI Apr 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ReleaseLinearStateSlot clears the tensors (linear_conv_states_[slot] = {} etc.) before returning the slot to the free list. This defeats the purpose of pre-sizing a slot pool and can cause frequent GPU reallocations/churn when prefix checkpoints are captured repeatedly. Consider keeping the tensors allocated and just returning the slot to the free list (or only freeing under memory-pressure), so reuse is truly O(1) and avoids allocator overhead.

Suggested change
linear_conv_states_[slot] = {};
linear_recurrent_states_[slot] = {};
// Keep tensors allocated to allow true O(1) reuse and avoid allocator churn.

Copilot uses AI. Check for mistakes.
free_linear_state_slots_.push_back(slot);
}
}

Tensor BlockTrie::LinearConvState(int slot) const
{
    // Fetch the conv-state checkpoint tensor after validating the slot index.
    TM_CHECK_GE(slot, 0);
    TM_CHECK_LT(slot, static_cast<int>(linear_conv_states_.size()));
    return linear_conv_states_[slot];
}

Tensor BlockTrie::LinearRecurrentState(int slot) const
{
    // Fetch the recurrent-state checkpoint tensor after validating the slot
    // index.
    TM_CHECK_GE(slot, 0);
    TM_CHECK_LT(slot, static_cast<int>(linear_recurrent_states_.size()));
    return linear_recurrent_states_[slot];
}

} // namespace turbomind
Loading
Loading