-
Notifications
You must be signed in to change notification settings - Fork 168
[Feature]: Support Prefix Cache for Qwen-Next Model (Mamba Cache Mode) #1016
Copy link
Copy link
Open
Labels
Description
🚀 The motivation and feature
Summary
This issue tracks the implementation of prefix cache support for Qwen-Next model, which uses a hybrid architecture combining standard Transformer attention with linear attention (Gated Delta Net). The implementation introduces a new MambaCacheMode configuration to handle state caching for linear attention layers.
Background
Qwen-Next is a hybrid architecture model that alternates between:
- Standard Attention layers (every 4th layer)
- Linear Attention layers (Gated Delta Net, other layers)
Unlike traditional LLM models, linear attention layers have recurrent state dependencies, meaning the current state depends on all historical inputs. This makes traditional KV Cache prefix caching inapplicable.
Architecture Diagram
┌─────────────────────────────────────────────────────────────────────────┐
│ Qwen-Next Layer Structure │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ Layer 0: ┌─────────────────────────────────────────────────────┐ │
│ │ Linear Attention (GDN) │ │
│ │ ┌───────────────┐ ┌────────────────────────────┐ │ │
│ │ │ Conv Cache │ │ SSM Cache (recurrent) │ │ │
│ │ └───────────────┘ └────────────────────────────┘ │ │
│ └─────────────────────────────────────────────────────┘ │
│ │
│ Layer 1: ┌─────────────────────────────────────────────────────┐ │
│ │ Linear Attention (GDN) │ │
│ └─────────────────────────────────────────────────────┘ │
│ │
│ Layer 2: ┌─────────────────────────────────────────────────────┐ │
│ │ Linear Attention (GDN) │ │
│ └─────────────────────────────────────────────────────┘ │
│ │
│ Layer 3: ┌─────────────────────────────────────────────────────┐ │
│ │ Standard Attention │ │
│ │ ┌───────────────┐ ┌────────────────────────────┐ │ │
│ │ │ Key Cache │ │ Value Cache │ │ │
│ │ └───────────────┘ └────────────────────────────┘ │ │
│ └─────────────────────────────────────────────────────┘ │
│ │
│ ... (pattern repeats every 4 layers) │
│ │
└─────────────────────────────────────────────────────────────────────────┘
Key Challenges
- State Continuity: GDN's recurrent state formula
h_t = f(h_{t-1}, input_t)requires state continuity - Cannot Directly Reuse: Traditional KV Cache cannot be directly applied to linear attention
- Alignment Requirement: States must be cached at specific block boundaries
Solution
Introduce MambaCacheMode with three modes:
| Mode | Description | Use Case |
|---|---|---|
none |
No caching for mamba states | Prefix cache disabled |
all |
Cache states at all block boundaries | Models supporting state snapshots |
align |
Cache states only at aligned positions | Qwen-Next (recommended) |
Mamba Cache Mode Comparison
┌─────────────────────────────────────────────────────────────────────────┐
│ Mamba Cache Mode Comparison │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ Token Position: 0 100 200 300 400 500 600 700 800 900 │
│ │ │ │ │ │ │ │ │ │ │ │
│ ▼ ▼ ▼ ▼ ▼ ▼ ▼ ▼ ▼ ▼ │
│ │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ Mode: "none" │ │
│ │ │ │
│ │ [No state caching - always recompute from scratch] │ │
│ │ │ │
│ │ XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX │ │
│ │ (all tokens need recomputation) │ │
│ └─────────────────────────────────────────────────────────────────┘ │
│ │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ Mode: "all" (Not supported by Qwen-Next) │ │
│ │ │ │
│ │ [Cache at every block boundary] │ │
│ │ │ │
│ │ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ │ │
│ │ │ S │ │ S │ │ S │ │ S │ │ S │ │ S │ │ S │ │ │
│ │ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ │ │
│ │ pos:0 128 256 384 512 640 768 896 │ │
│ │ │ │
│ │ S = Saved state │ │
│ └─────────────────────────────────────────────────────────────────┘ │
│ │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ Mode: "align" (Recommended for Qwen-Next) │ │
│ │ │ │
│ │ [Cache only at scheduler step boundaries when aligned] │ │
│ │ │ │
│ │ Prefill Step 1 (560 tokens): │ │
│ │ XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX┌───┐ │ │
│ │ │ S │ │ │
│ │ pos:0 560│ │ │
│ │ └───┘ │ │
│ │ │ │
│ │ Prefill Step 2 (560 tokens): │ │
│ │ [reuse cached state]───────────────────────────┐┌───┐ │ │
│ │ ││ S │ │ │
│ │ pos:560 1120│ │ │
│ │ └───┘ │ │
│ │ │ │
│ │ S = Saved state at block boundary │ │
│ │ X = Tokens computed in this step │ │
│ └─────────────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────┘
Implementation Details
1. New Components
┌─────────────────────────────────────────────────────────────────────────┐
│ Component Architecture │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ Scheduler Layer │ │
│ │ ┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐ │ │
│ │ │ Request Scheduler│ │ Block Allocator │ │ Token Aligner │ │ │
│ │ └────────┬────────┘ └────────┬────────┘ └────────┬────────┘ │ │
│ └───────────┼─────────────────────┼─────────────────────┼──────────┘ │
│ │ │ │ │
│ ▼ ▼ ▼ │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ Cache Management Layer │ │
│ │ ┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐ │ │
│ │ │ PrefixCache │ │ BlockManager │ │ MambaCacheMgr │ │ │
│ │ │ (Token-based) │ │ (Block-based) │ │ (State-based) │ │ │
│ │ └────────┬────────┘ └────────┬────────┘ └────────┬────────┘ │ │
│ └───────────┼─────────────────────┼─────────────────────┼──────────┘ │
│ │ │ │ │
│ ▼ ▼ ▼ │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ Model Layer │ │
│ │ ┌─────────────────────────────────────────────────────────┐ │ │
│ │ │ Qwen3NextDecoderLayer │ │ │
│ │ │ ┌───────────────────┐ ┌───────────────────────────┐ │ │ │
│ │ │ │ Qwen3NextAttention│ │ Qwen3NextGatedDeltaNet │ │ │ │
│ │ │ │ (Standard Attn) │ │ (Linear Attn) │ │ │ │
│ │ │ │ - k,v cache │ │ - conv_cache, ssm_cache │ │ │ │
│ │ │ └───────────────────┘ └───────────────────────────┘ │ │ │
│ │ └─────────────────────────────────────────────────────────┘ │ │
│ └─────────────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────┘
2. State Caching Flow (Align Mode)
┌─────────────────────────────────────────────────────────────────────────┐
│ Prefill Flow with Align Mode │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ Step 1: Prefill (tokens 0-559) │
│ ┌──────────────────────────────────────────────────────────────────┐ │
│ │ │ │
│ │ Input Tokens: [t0, t1, t2, ... t559] │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ ┌─────────────────────────────────────────────────────────┐ │ │
│ │ │ GDN Forward (compute all 560 tokens) │ │ │
│ │ │ │ │ │
│ │ │ h_0 → h_1 → h_2 → ... → h_559 │ │ │
│ │ │ │ │ │
│ │ │ Check: 560 % 560 == 0? → YES │ │ │
│ │ │ │ │ │ │
│ │ │ ▼ │ │ │
│ │ │ Save: conv_cache[Block_0], ssm_cache[Block_0] │ │ │
│ │ └─────────────────────────────────────────────────────────┘ │ │
│ │ │ │
│ └──────────────────────────────────────────────────────────────────┘ │
│ │
│ Step 2: Prefill (tokens 560-1119) - Cache Hit! │
│ ┌──────────────────────────────────────────────────────────────────┐ │
│ │ │ │
│ │ Input Tokens: [t560, t561, ... t1119] │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ ┌─────────────────────────────────────────────────────────┐ │ │
│ │ │ Load cached state from Block_0: │ │ │
│ │ │ - conv_cache[Block_0] → initial_conv_state │ │ │
│ │ │ - ssm_cache[Block_0] → initial_ssm_state │ │ │
│ │ └─────────────────────────────────────────────────────────┘ │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ ┌─────────────────────────────────────────────────────────┐ │ │
│ │ │ GDN Forward (continue from cached state) │ │ │
│ │ │ │ │ │
│ │ │ h_559 → h_560 → ... → h_1119 │ │ │
│ │ │ ↑ │ │ │
│ │ │ └── loaded from cache │ │ │
│ │ │ │ │ │
│ │ │ Check: 1120 % 560 == 0? → YES │ │ │
│ │ │ │ │ │ │
│ │ │ ▼ │ │ │
│ │ │ Save: conv_cache[Block_1], ssm_cache[Block_1] │ │ │
│ │ └─────────────────────────────────────────────────────────┘ │ │
│ │ │ │
│ └──────────────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────┘
3. Modified Components
Qwen3NextGatedDeltaNet: Added support for conditional state saving based onmamba_cache_modeModelInputParams: Addedmamba_cache_modeandmamba_block_sizefieldsModelArgs: Added configuration parametersruntime::Options: Added runtime option formamba_cache_mode
4. State Caching Logic
// In Qwen3NextGatedDeltaNet::forward()
bool should_save_state = false;
if (mamba_cache_mode_ == MambaCacheMode::kAlign) {
if (mamba_block_size_ > 0 && seq_len % mamba_block_size_ == 0) {
should_save_state = true;
}
} else if (mamba_cache_mode_ == MambaCacheMode::kAll) {
should_save_state = true;
}
if (should_save_state) {
// Save conv_state and ssm_state
}Memory Layout
┌─────────────────────────────────────────────────────────────────────────┐
│ GPU Memory Layout │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ KV Cache Memory Pool │ │
│ │ ┌─────────────────────────────────────────────────────────┐ │ │
│ │ │ Key Cache [num_blocks, num_layers, block_size, │ │ │
│ │ │ num_heads, head_dim] │ │ │
│ │ └─────────────────────────────────────────────────────────┘ │ │
│ │ ┌─────────────────────────────────────────────────────────┐ │ │
│ │ │ Value Cache [num_blocks, num_layers, block_size, │ │ │
│ │ │ num_heads, head_dim] │ │ │
│ │ └─────────────────────────────────────────────────────────┘ │ │
│ └─────────────────────────────────────────────────────────────────┘ │
│ │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ Mamba State Memory Pool (NEW) │ │
│ │ ┌─────────────────────────────────────────────────────────┐ │ │
│ │ │ Conv State Cache │ │ │
│ │ │ [num_blocks, num_gdn_layers, conv_dim, kernel_size] │ │ │
│ │ │ │ │ │
│ │ │ For Qwen-Next: kernel_size = 4 │ │ │
│ │ └─────────────────────────────────────────────────────────┘ │ │
│ │ ┌─────────────────────────────────────────────────────────┐ │ │
│ │ │ SSM State Cache │ │ │
│ │ │ [num_blocks, num_gdn_layers, num_heads, k_dim, v_dim] │ │ │
│ │ │ │ │ │
│ │ │ For Qwen-Next: k_dim = 128, v_dim = 128 │ │ │
│ │ └─────────────────────────────────────────────────────────┘ │ │
│ └─────────────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────┘
Usage
// Enable prefix cache with align mode for Qwen-Next
runtime::Options options;
options.set_enable_prefix_cache(true);
options.set_mamba_cache_mode("align");
options.set_block_size(560);Files Changed
| File | Change Type |
|---|---|
core/framework/prefix_cache/mamba_cache_manager.h |
Added |
core/framework/prefix_cache/mamba_cache_manager.cpp |
Added |
core/framework/model/model_args.h |
Modified |
core/framework/model/model_input_params.h |
Modified |
core/layers/common/qwen3_next_gated_delta_net.h |
Modified |
core/layers/common/qwen3_next_gated_delta_net.cpp |
Modified |
core/runtime/options.h |
Modified |
core/runtime/llm_worker_impl.cpp |
Modified |
core/runtime/vlm_worker_impl.cpp |
Modified |
models/llm/qwen3_next.h |
Modified |
Testing
- Unit tests for
MambaCacheManager - Unit tests for state copy functions
- Integration tests for Qwen-Next with prefix cache enabled
- End-to-end accuracy verification
Alternatives
No response
Additional context
No response
Reactions are currently unavailable