Skip to content

[Feature]: Support Prefix Cache for Qwen-Next Model (Mamba Cache Mode) #1016

@JC-ut0

Description

@JC-ut0

🚀 The motivation and feature

Summary

This issue tracks the implementation of prefix cache support for Qwen-Next model, which uses a hybrid architecture combining standard Transformer attention with linear attention (Gated Delta Net). The implementation introduces a new MambaCacheMode configuration to handle state caching for linear attention layers.

Background

Qwen-Next is a hybrid architecture model that alternates between:

  • Standard Attention layers (every 4th layer)
  • Linear Attention layers (Gated Delta Net, other layers)

Unlike traditional LLM models, linear attention layers have recurrent state dependencies, meaning the current state depends on all historical inputs. This makes traditional KV Cache prefix caching inapplicable.

Architecture Diagram

┌─────────────────────────────────────────────────────────────────────────┐
│                         Qwen-Next Layer Structure                        │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│   Layer 0:    ┌─────────────────────────────────────────────────────┐   │
│               │  Linear Attention (GDN)                              │   │
│               │  ┌───────────────┐  ┌────────────────────────────┐  │   │
│               │  │  Conv Cache   │  │  SSM Cache (recurrent)     │  │   │
│               │  └───────────────┘  └────────────────────────────┘  │   │
│               └─────────────────────────────────────────────────────┘   │
│                                                                          │
│   Layer 1:    ┌─────────────────────────────────────────────────────┐   │
│               │  Linear Attention (GDN)                              │   │
│               └─────────────────────────────────────────────────────┘   │
│                                                                          │
│   Layer 2:    ┌─────────────────────────────────────────────────────┐   │
│               │  Linear Attention (GDN)                              │   │
│               └─────────────────────────────────────────────────────┘   │
│                                                                          │
│   Layer 3:    ┌─────────────────────────────────────────────────────┐   │
│               │  Standard Attention                                  │   │
│               │  ┌───────────────┐  ┌────────────────────────────┐  │   │
│               │  │   Key Cache   │  │   Value Cache              │  │   │
│               │  └───────────────┘  └────────────────────────────┘  │   │
│               └─────────────────────────────────────────────────────┘   │
│                                                                          │
│   ... (pattern repeats every 4 layers)                                  │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

Key Challenges

  1. State Continuity: GDN's recurrent state formula h_t = f(h_{t-1}, input_t) requires state continuity
  2. Cannot Directly Reuse: Traditional KV Cache cannot be directly applied to linear attention
  3. Alignment Requirement: States must be cached at specific block boundaries

Solution

Introduce MambaCacheMode with three modes:

Mode Description Use Case
none No caching for mamba states Prefix cache disabled
all Cache states at all block boundaries Models supporting state snapshots
align Cache states only at aligned positions Qwen-Next (recommended)

Mamba Cache Mode Comparison

┌─────────────────────────────────────────────────────────────────────────┐
│                        Mamba Cache Mode Comparison                       │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  Token Position:  0   100  200  300  400  500  600  700  800  900      │
│                   │    │    │    │    │    │    │    │    │    │       │
│                   ▼    ▼    ▼    ▼    ▼    ▼    ▼    ▼    ▼    ▼       │
│                                                                          │
│  ┌─────────────────────────────────────────────────────────────────┐   │
│  │ Mode: "none"                                                     │   │
│  │                                                                  │   │
│  │  [No state caching - always recompute from scratch]             │   │
│  │                                                                  │   │
│  │  XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX   │   │
│  │  (all tokens need recomputation)                                 │   │
│  └─────────────────────────────────────────────────────────────────┘   │
│                                                                          │
│  ┌─────────────────────────────────────────────────────────────────┐   │
│  │ Mode: "all" (Not supported by Qwen-Next)                        │   │
│  │                                                                  │   │
│  │  [Cache at every block boundary]                                │   │
│  │                                                                  │   │
│  │  ┌───┐   ┌───┐   ┌───┐   ┌───┐   ┌───┐   ┌───┐   ┌───┐        │   │
│  │  │ S │   │ S │   │ S │   │ S │   │ S │   │ S │   │ S │        │   │
│  │  └───┘   └───┘   └───┘   └───┘   └───┘   └───┘   └───┘        │   │
│  │  pos:0   128    256    384    512    640    768    896        │   │
│  │                                                                  │   │
│  │  S = Saved state                                                 │   │
│  └─────────────────────────────────────────────────────────────────┘   │
│                                                                          │
│  ┌─────────────────────────────────────────────────────────────────┐   │
│  │ Mode: "align" (Recommended for Qwen-Next)                       │   │
│  │                                                                  │   │
│  │  [Cache only at scheduler step boundaries when aligned]         │   │
│  │                                                                  │   │
│  │  Prefill Step 1 (560 tokens):                                    │   │
│  │  XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX┌───┐         │   │
│  │                                                   │ S │         │   │
│  │  pos:0                                           560│         │   │
│  │                                                   └───┘         │   │
│  │                                                                  │   │
│  │  Prefill Step 2 (560 tokens):                                    │   │
│  │  [reuse cached state]───────────────────────────┐┌───┐         │   │
│  │                                                 ││ S │         │   │
│  │  pos:560                                       1120│         │   │
│  │                                                 └───┘         │   │
│  │                                                                  │   │
│  │  S = Saved state at block boundary                              │   │
│  │  X = Tokens computed in this step                               │   │
│  └─────────────────────────────────────────────────────────────────┘   │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

Implementation Details

1. New Components

┌─────────────────────────────────────────────────────────────────────────┐
│                         Component Architecture                           │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  ┌─────────────────────────────────────────────────────────────────┐   │
│  │                      Scheduler Layer                             │   │
│  │  ┌─────────────────┐  ┌─────────────────┐  ┌─────────────────┐  │   │
│  │  │ Request Scheduler│  │ Block Allocator │  │ Token Aligner   │  │   │
│  │  └────────┬────────┘  └────────┬────────┘  └────────┬────────┘  │   │
│  └───────────┼─────────────────────┼─────────────────────┼──────────┘   │
│              │                     │                     │              │
│              ▼                     ▼                     ▼              │
│  ┌─────────────────────────────────────────────────────────────────┐   │
│  │                      Cache Management Layer                      │   │
│  │  ┌─────────────────┐  ┌─────────────────┐  ┌─────────────────┐  │   │
│  │  │  PrefixCache    │  │  BlockManager   │  │ MambaCacheMgr   │  │   │
│  │  │  (Token-based)  │  │  (Block-based)  │  │ (State-based)   │  │   │
│  │  └────────┬────────┘  └────────┬────────┘  └────────┬────────┘  │   │
│  └───────────┼─────────────────────┼─────────────────────┼──────────┘   │
│              │                     │                     │              │
│              ▼                     ▼                     ▼              │
│  ┌─────────────────────────────────────────────────────────────────┐   │
│  │                       Model Layer                                │   │
│  │  ┌─────────────────────────────────────────────────────────┐    │   │
│  │  │              Qwen3NextDecoderLayer                       │    │   │
│  │  │  ┌───────────────────┐  ┌───────────────────────────┐   │    │   │
│  │  │  │ Qwen3NextAttention│  │ Qwen3NextGatedDeltaNet    │   │    │   │
│  │  │  │ (Standard Attn)   │  │ (Linear Attn)             │   │    │   │
│  │  │  │ - k,v cache       │  │ - conv_cache, ssm_cache   │   │    │   │
│  │  │  └───────────────────┘  └───────────────────────────┘   │    │   │
│  │  └─────────────────────────────────────────────────────────┘    │   │
│  └─────────────────────────────────────────────────────────────────┘   │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

2. State Caching Flow (Align Mode)

┌─────────────────────────────────────────────────────────────────────────┐
│                    Prefill Flow with Align Mode                          │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  Step 1: Prefill (tokens 0-559)                                         │
│  ┌──────────────────────────────────────────────────────────────────┐  │
│  │                                                                   │  │
│  │  Input Tokens: [t0, t1, t2, ... t559]                            │  │
│  │       │                                                          │  │
│  │       ▼                                                          │  │
│  │  ┌─────────────────────────────────────────────────────────┐    │  │
│  │  │  GDN Forward (compute all 560 tokens)                    │    │  │
│  │  │                                                          │    │  │
│  │  │  h_0 → h_1 → h_2 → ... → h_559                          │    │  │
│  │  │                                                          │    │  │
│  │  │  Check: 560 % 560 == 0? → YES                           │    │  │
│  │  │       │                                                  │    │  │
│  │  │       ▼                                                  │    │  │
│  │  │  Save: conv_cache[Block_0], ssm_cache[Block_0]          │    │  │
│  │  └─────────────────────────────────────────────────────────┘    │  │
│  │                                                                   │  │
│  └──────────────────────────────────────────────────────────────────┘  │
│                                                                          │
│  Step 2: Prefill (tokens 560-1119) - Cache Hit!                         │
│  ┌──────────────────────────────────────────────────────────────────┐  │
│  │                                                                   │  │
│  │  Input Tokens: [t560, t561, ... t1119]                           │  │
│  │       │                                                          │  │
│  │       ▼                                                          │  │
│  │  ┌─────────────────────────────────────────────────────────┐    │  │
│  │  │  Load cached state from Block_0:                        │    │  │
│  │  │  - conv_cache[Block_0] → initial_conv_state             │    │  │
│  │  │  - ssm_cache[Block_0] → initial_ssm_state               │    │  │
│  │  └─────────────────────────────────────────────────────────┘    │  │
│  │       │                                                          │  │
│  │       ▼                                                          │  │
│  │  ┌─────────────────────────────────────────────────────────┐    │  │
│  │  │  GDN Forward (continue from cached state)               │    │  │
│  │  │                                                          │    │  │
│  │  │  h_559 → h_560 → ... → h_1119                           │    │  │
│  │  │   ↑                                                      │    │  │
│  │  │   └── loaded from cache                                  │    │  │
│  │  │                                                          │    │  │
│  │  │  Check: 1120 % 560 == 0? → YES                          │    │  │
│  │  │       │                                                  │    │  │
│  │  │       ▼                                                  │    │  │
│  │  │  Save: conv_cache[Block_1], ssm_cache[Block_1]          │    │  │
│  │  └─────────────────────────────────────────────────────────┘    │  │
│  │                                                                   │  │
│  └──────────────────────────────────────────────────────────────────┘  │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

3. Modified Components

  • Qwen3NextGatedDeltaNet: Added support for conditional state saving based on mamba_cache_mode
  • ModelInputParams: Added mamba_cache_mode and mamba_block_size fields
  • ModelArgs: Added configuration parameters
  • runtime::Options: Added runtime option for mamba_cache_mode

4. State Caching Logic

// In Qwen3NextGatedDeltaNet::forward()
bool should_save_state = false;
if (mamba_cache_mode_ == MambaCacheMode::kAlign) {
  if (mamba_block_size_ > 0 && seq_len % mamba_block_size_ == 0) {
    should_save_state = true;
  }
} else if (mamba_cache_mode_ == MambaCacheMode::kAll) {
  should_save_state = true;
}

if (should_save_state) {
  // Save conv_state and ssm_state
}

Memory Layout

┌─────────────────────────────────────────────────────────────────────────┐
│                              GPU Memory Layout                           │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  ┌─────────────────────────────────────────────────────────────────┐   │
│  │                    KV Cache Memory Pool                          │   │
│  │  ┌─────────────────────────────────────────────────────────┐    │   │
│  │  │  Key Cache   [num_blocks, num_layers, block_size,       │    │   │
│  │  │               num_heads, head_dim]                      │    │   │
│  │  └─────────────────────────────────────────────────────────┘    │   │
│  │  ┌─────────────────────────────────────────────────────────┐    │   │
│  │  │  Value Cache [num_blocks, num_layers, block_size,       │    │   │
│  │  │               num_heads, head_dim]                      │    │   │
│  │  └─────────────────────────────────────────────────────────┘    │   │
│  └─────────────────────────────────────────────────────────────────┘   │
│                                                                          │
│  ┌─────────────────────────────────────────────────────────────────┐   │
│  │                  Mamba State Memory Pool (NEW)                   │   │
│  │  ┌─────────────────────────────────────────────────────────┐    │   │
│  │  │  Conv State Cache                                        │    │   │
│  │  │  [num_blocks, num_gdn_layers, conv_dim, kernel_size]    │    │   │
│  │  │                                                          │    │   │
│  │  │  For Qwen-Next: kernel_size = 4                          │    │   │
│  │  └─────────────────────────────────────────────────────────┘    │   │
│  │  ┌─────────────────────────────────────────────────────────┐    │   │
│  │  │  SSM State Cache                                         │    │   │
│  │  │  [num_blocks, num_gdn_layers, num_heads, k_dim, v_dim]  │    │   │
│  │  │                                                          │    │   │
│  │  │  For Qwen-Next: k_dim = 128, v_dim = 128                 │    │   │
│  │  └─────────────────────────────────────────────────────────┘    │   │
│  └─────────────────────────────────────────────────────────────────┘   │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

Usage

// Enable prefix cache with align mode for Qwen-Next
runtime::Options options;
options.set_enable_prefix_cache(true);
options.set_mamba_cache_mode("align");
options.set_block_size(560);

Files Changed

File Change Type
core/framework/prefix_cache/mamba_cache_manager.h Added
core/framework/prefix_cache/mamba_cache_manager.cpp Added
core/framework/model/model_args.h Modified
core/framework/model/model_input_params.h Modified
core/layers/common/qwen3_next_gated_delta_net.h Modified
core/layers/common/qwen3_next_gated_delta_net.cpp Modified
core/runtime/options.h Modified
core/runtime/llm_worker_impl.cpp Modified
core/runtime/vlm_worker_impl.cpp Modified
models/llm/qwen3_next.h Modified

Testing

  • Unit tests for MambaCacheManager
  • Unit tests for state copy functions
  • Integration tests for Qwen-Next with prefix cache enabled
  • End-to-end accuracy verification

Alternatives

No response

Additional context

No response

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions