Skip to content

Commit f4fc826

Browse files
committed
Add support for KVCache reuse for DSv32
Signed-off-by: Iman Tabrizian <[email protected]>
1 parent cc5a058 commit f4fc826

File tree

2 files changed

+8
-14
lines changed

tensorrt_llm/_torch/attention_backend/sparse/dsa.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -930,15 +930,15 @@ def prepare(metadata: DSAtrtllmAttentionMetadata):
930930
start_idx=0,
931931
)
932932

933-
if len(chunk_groups) > 1:
933+
if len(chunk_groups
934+
) > 1 or metadata.enable_context_mla_with_cached_kv:
934935
metadata.indexer_prefill_chunks = [
935936
Indexer.prepare_one_prefill_chunk(
936937
metadata,
937938
chunk_specs,
938939
) for chunk_specs in chunk_groups
939940
]
940941
else:
941-
# Single chunk - use non-chunked fallback path
942942
metadata.indexer_prefill_chunks = None
943943

944944
host_cu_seqlen_ks, _ = compute_cu_seqlen_kv_bounds_with_cache(
@@ -1016,9 +1016,9 @@ def prepare(metadata: DSAtrtllmAttentionMetadata):
10161016
metadata.slot_mapping_scale[:total_tokens].copy_(
10171017
metadata.host_slot_mapping_scale[:total_tokens], non_blocking=True)
10181018

1019-
# Only when MLA chunked prefill is enabled, we need to gather the full KV for indexer's logit computation.
1019+
# When chunked prefill or KVCache reuse is enabled, we need to gather the full KV for indexer's logit computation.
10201020
# Indexer's own chunking does not need full KV gathering, instead it gathers only the current chunk with loop-based gathering.
1021-
_need_full_kv_gathering = num_contexts > 0 and has_mla_chunked_prefill
1021+
_need_full_kv_gathering = num_contexts > 0 and metadata.enable_context_mla_with_cached_kv
10221022
if _need_full_kv_gathering:
10231023
total_kv_len = metadata.host_ctx_kv_indptr[num_contexts].item()
10241024
total_kv_per_request = seq_lens[:

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2417,17 +2417,13 @@ def test_fp8_blockscale(self, tp_size, pp_size, ep_size, mtp_nextn, fp8kv,
24172417
if get_sm_version() == 100 or get_sm_version() == 103:
24182418
moe_backend = "DEEPGEMM" if moe_backend == "_DEFAULT" else moe_backend
24192419
moe_config = MoeConfig(backend=moe_backend, max_num_tokens=16384)
2420-
# TODO: Support block reuse for DeepSeek-V3.2
2421-
kv_cache_config = KvCacheConfig(enable_block_reuse=False,
2422-
free_gpu_memory_fraction=0.6,
2420+
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.6,
24232421
tokens_per_block=64)
24242422
else:
24252423
if moe_backend != "_DEFAULT":
24262424
pytest.skip("Not supported MoE backend!")
24272425
moe_config = MoeConfig()
2428-
# TODO: Support block reuse for DeepSeek-V3.2
2429-
kv_cache_config = KvCacheConfig(enable_block_reuse=False,
2430-
free_gpu_memory_fraction=0.7,
2426+
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7,
24312427
tokens_per_block=64)
24322428

24332429
pytorch_config = dict(
@@ -2490,8 +2486,7 @@ def test_nvfp4_multi_gpus(self, tp_size, pp_size, ep_size, mtp_nextn, fp8kv,
24902486
"MOE TRTLLM backend does not support SM version 120 or 121")
24912487

24922488
moe_config = MoeConfig(backend=moe_backend, max_num_tokens=16384)
2493-
kv_cache_config = KvCacheConfig(enable_block_reuse=True,
2494-
free_gpu_memory_fraction=0.7,
2489+
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7,
24952490
tokens_per_block=64)
24962491
cuda_graph_config = CudaGraphConfig(
24972492
enable_padding=True,
@@ -2550,8 +2545,7 @@ def test_nvfp4_multi_gpus_chunked_prefill(self, tp_size, pp_size, ep_size,
25502545
"MOE TRTLLM backend does not support SM version 120 or 121")
25512546

25522547
moe_config = MoeConfig(backend=moe_backend, max_num_tokens=16384)
2553-
kv_cache_config = KvCacheConfig(enable_block_reuse=False,
2554-
free_gpu_memory_fraction=0.7,
2548+
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7,
25552549
tokens_per_block=64)
25562550
cuda_graph_config = CudaGraphConfig(
25572551
enable_padding=True,

0 commit comments

Comments (0)