
Commit dcc6d62

Fix over provisioning of Mamba state memory when max_requests is set (#4114)
Signed-off-by: Keshav Santhanam <ksanthanam@nvidia.com>
1 parent 5b512b4 commit dcc6d62

6 files changed: +232 -8 lines changed


megatron/core/inference/config.py

Lines changed: 7 additions & 2 deletions
@@ -1,6 +1,6 @@
 # Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 
-from dataclasses import dataclass
+from dataclasses import InitVar, dataclass
 from enum import Enum
 from typing import List, Optional, Tuple
 
@@ -309,7 +309,12 @@ class InferenceConfig:
     performance variability for MoEs.
     """
 
-    def __post_init__(self):
+    verbose: InitVar[bool] = False
+    """Whether to log detailed context configuration at initialization.
+    This is an InitVar and is not stored as a field on the config."""
+
+    def __post_init__(self, verbose: bool):
+        self._verbose = verbose
         if not (0.0 <= self.prefix_caching_routing_alpha <= 1.0):
             raise ValueError(
                 f"prefix_caching_routing_alpha must be in [0, 1], "

megatron/core/inference/contexts/dynamic_context.py

Lines changed: 91 additions & 6 deletions
@@ -205,6 +205,8 @@ def deserialize(cls, obj: dict) -> ContextOverflowError:
 
 def get_mem_size_str(n_bytes: int) -> str:
     """Convert number of bytes to human-readable string."""
+    if n_bytes == 0:
+        return "0 bytes"
     for exp, suffix in ((4, "TB"), (3, "GB"), (2, "MB"), (1, "KB"), (0, "bytes")):
         nquery = int(1024**exp)
         if round(n_bytes / nquery) >= 1:
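
The new early return matters because for n_bytes == 0 none of the loop's tiers passes the round(n_bytes / nquery) >= 1 test, so the loop produces nothing for a zero-byte input. A self-contained sketch of the same pattern, with an assumed formatting style for the return value (the repository's actual return statement is not shown in this hunk):

def mem_size_str(n_bytes: int) -> str:
    """Illustrative sketch: convert a byte count to a human-readable string."""
    if n_bytes == 0:
        return "0 bytes"
    for exp, suffix in ((4, "TB"), (3, "GB"), (2, "MB"), (1, "KB"), (0, "bytes")):
        unit = 1024 ** exp
        if round(n_bytes / unit) >= 1:
            return f"{n_bytes / unit:.1f} {suffix}"
    return f"{n_bytes} bytes"  # unreachable for n_bytes >= 1


assert mem_size_str(0) == "0 bytes"
assert mem_size_str(3 * 1024**3) == "3.0 GB"
assert mem_size_str(512) == "512.0 bytes"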
@@ -449,6 +451,26 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceConfig
             buffer_size_bytes = int(buffer_size_bytes * (1.0 - mamba_memory_ratio))
             paused_buffer_size_bytes = int(paused_buffer_size_bytes * (1.0 - mamba_memory_ratio))
 
+            block_count = buffer_size_bytes // self.block_size_bytes
+            block_count = max(2, block_count)  # need >= 1 active block + 1 dummy block
+            paused_block_count = paused_buffer_size_bytes // self.block_size_bytes
+        elif self.is_hybrid_model and inference_config.max_requests is not None:
+            # Auto-derive mamba/KV split from max_requests. Allocate exactly enough
+            # mamba memory for max_requests, and give the rest to KV cache blocks.
+            total_memory = buffer_size_bytes + paused_buffer_size_bytes
+            mamba_memory_needed = inference_config.max_requests * mamba_states_memory_per_request
+            assert mamba_memory_needed < total_memory, (
+                f"Not enough memory for {inference_config.max_requests} mamba requests. "
+                f"Need {mamba_memory_needed / 1024**3:.2f} GB for mamba states, "
+                f"but total buffer is {total_memory / 1024**3:.2f} GB."
+            )
+            mamba_max_requests = inference_config.max_requests
+
+            # Subtract mamba memory proportionally from active and paused buffers.
+            mamba_memory_ratio = mamba_memory_needed / total_memory
+            buffer_size_bytes = int(buffer_size_bytes * (1.0 - mamba_memory_ratio))
+            paused_buffer_size_bytes = int(paused_buffer_size_bytes * (1.0 - mamba_memory_ratio))
+
         block_count = buffer_size_bytes // self.block_size_bytes
         block_count = max(2, block_count)  # need >= 1 active block + 1 dummy block
         paused_block_count = paused_buffer_size_bytes // self.block_size_bytes
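
The derivation in the elif branch is easy to sanity-check by hand. A standalone sketch with hypothetical numbers (64 GB active buffer, 8 GB paused buffer, ~40 KB of Mamba state per request, 2 MB KV blocks; none of these values come from a real config):

# Standalone recomputation of the mamba/KV split derived from max_requests.
# All numbers are illustrative, not taken from a real model configuration.
buffer_size_bytes = 64 * 1024**3              # active KV buffer
paused_buffer_size_bytes = 8 * 1024**3        # paused KV buffer
mamba_states_memory_per_request = 40 * 1024   # conv + SSM state bytes per request
max_requests = 4096
block_size_bytes = 2 * 1024**2                # bytes per KV cache block

total_memory = buffer_size_bytes + paused_buffer_size_bytes
mamba_memory_needed = max_requests * mamba_states_memory_per_request
assert mamba_memory_needed < total_memory

# Shrink both buffers proportionally so the mamba states fit exactly.
mamba_memory_ratio = mamba_memory_needed / total_memory
buffer_size_bytes = int(buffer_size_bytes * (1.0 - mamba_memory_ratio))
paused_buffer_size_bytes = int(paused_buffer_size_bytes * (1.0 - mamba_memory_ratio))

block_count = max(2, buffer_size_bytes // block_size_bytes)
paused_block_count = paused_buffer_size_bytes // block_size_bytes
print(f"mamba ratio: {mamba_memory_ratio:.4%}, "
      f"active blocks: {block_count}, paused blocks: {paused_block_count}")

With these numbers only about 0.2% of the buffer is carved out for Mamba state, which is the point of the fix: the reservation is sized by the max_requests cap rather than by a fixed fraction that can over-provision.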
@@ -594,13 +616,76 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceConfig
         self.initialize_all_tensors()
 
         # Print info.
-        logging.info(
-            "DynamicInferenceContext: allocated context with active buffer size %s (%d blocks)."
-            % (
-                get_mem_size_str(self.kv_block_allocator.active_count * self.block_size_bytes),
-                self.kv_block_allocator.active_count,
+        active_blocks = self.kv_block_allocator.active_count
+        total_blocks = self.kv_block_allocator.total_count
+        paused_blocks = self.kv_block_allocator.paused_count
+        active_kv_bytes = active_blocks * self.block_size_bytes
+        total_kv_bytes = total_blocks * self.block_size_bytes
+        paused_kv_bytes = paused_blocks * self.block_size_bytes
+
+        log_lines = [
+            "DynamicInferenceContext: configuration summary",
+            f"  max_requests: {self.max_requests}",
+            f"  max_tokens: {self.max_tokens}",
+            f"  max_sequence_length: {self.max_sequence_length}",
+            f"  block_size_tokens: {self.block_size_tokens}",
+            f"  max_kv_blocks_per_req: {self.max_kv_block_count}",
+            f"  KV cache:",
+            f"    block_size_bytes: {get_mem_size_str(self.block_size_bytes)}",
+            f"    active_blocks: {active_blocks} ({get_mem_size_str(active_kv_bytes)})",
+            f"    paused_blocks: {paused_blocks} ({get_mem_size_str(paused_kv_bytes)})",
+            f"    total_blocks: {total_blocks} ({get_mem_size_str(total_kv_bytes)})",
+        ]
+
+        if self.is_hybrid_model:
+            mamba_conv_bytes = (
+                math.prod(self.mamba_conv_states_shape)
+                * self.mamba_conv_states_dtype.itemsize
+                * self.num_mamba_layers
             )
-        )
+            mamba_ssm_bytes = (
+                math.prod(self.mamba_ssm_states_shape)
+                * self.mamba_ssm_states_dtype.itemsize
+                * self.num_mamba_layers
+            )
+            mamba_bytes_per_req = mamba_conv_bytes + mamba_ssm_bytes
+            mamba_total_bytes = mamba_bytes_per_req * self.max_requests
+            log_lines += [
+                f"  Mamba states:",
+                f"    num_mamba_layers: {self.num_mamba_layers}",
+                f"    conv_state_shape: {self.mamba_conv_states_shape}",
+                f"    ssm_state_shape: {self.mamba_ssm_states_shape}",
+                f"    per_request: {get_mem_size_str(mamba_bytes_per_req)}",
+                f"    total ({self.max_requests} requests): {get_mem_size_str(mamba_total_bytes)}",
+            ]
+
+            if self.num_speculative_tokens > 0:
+                spec_multiplier = self.num_speculative_tokens + 1
+                spec_bytes_per_req = mamba_bytes_per_req * spec_multiplier
+                spec_total_bytes = spec_bytes_per_req * self.max_requests
+                log_lines += [
+                    f"  Mamba speculative buffers (num_speculative_tokens={self.num_speculative_tokens}):",
+                    f"    per_request: {get_mem_size_str(spec_bytes_per_req)}",
+                    f"    total ({self.max_requests} requests): {get_mem_size_str(spec_total_bytes)}",
+                ]
+
+            prefix_caching_mamba_gb = inference_config.prefix_caching_mamba_gb
+            if (
+                inference_config.enable_prefix_caching
+                and prefix_caching_mamba_gb is not None
+                and prefix_caching_mamba_gb > 0
+            ):
+                prefix_cache_bytes = int(prefix_caching_mamba_gb * 1024**3)
+                prefix_cache_slots = prefix_cache_bytes // mamba_bytes_per_req
+                log_lines += [
+                    f"  Mamba prefix cache:",
+                    f"    budget: {get_mem_size_str(prefix_cache_bytes)}",
+                    f"    slots: {prefix_cache_slots}",
+                    f"    per_slot: {get_mem_size_str(mamba_bytes_per_req)}",
+                ]
+
+        if inference_config._verbose and torch.distributed.get_rank() == 0:
+            logging.info("\n".join(log_lines))
 
     def _allocate_memory_buffer(self):
         """Allocate the KV cache memory buffer."""
megatron/core/inference/engines/dynamic_engine.py

Lines changed: 12 additions & 0 deletions
@@ -3,6 +3,7 @@
 import asyncio
 import concurrent.futures
 import logging
+import math
 import multiprocessing
 import socket
 import struct
@@ -22,6 +23,7 @@
 
 from megatron.core.inference.config import KVCacheManagementMode
 from megatron.core.inference.contexts.dynamic_context import (
+    BlockOverflowError,
     DynamicInferenceContext,
     MaxSequenceLengthOverflowError,
     TokenOverflowError,
@@ -941,6 +943,16 @@ def _add_request(
             request.status = Status.FAILED
             request.add_event_error_nontransient(TokenOverflowError(request_id))
 
+        # Check that the KV cache has enough blocks for this request's max sequence length.
+        max_request_tokens = (
+            len(request.prompt_tokens) + request.sampling_params.num_tokens_to_generate
+        )
+        request_block_count = math.ceil(max_request_tokens / self.context.block_size_tokens)
+        total_blocks = self.context.kv_block_allocator.total_count - 1  # -1 for dummy block
+        if request_block_count > total_blocks:
+            request.status = Status.FAILED
+            request.add_event_error_nontransient(BlockOverflowError(request_id))
+
         # Tokenize stop words if provided
         if request.sampling_params.stop_words:
             stop_word_ids = [
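
The new admission check is simple block arithmetic: the request's worst-case token count, rounded up to whole KV blocks, must fit in the allocator's usable blocks (total minus the dummy block). A standalone sketch with numbers borrowed from the unit test added below (the block count of 9 is hypothetical, standing in for the test's small ~8-block buffer):

import math

# Hypothetical request and context parameters, mirroring the new engine test.
prompt_tokens = 4
num_tokens_to_generate = 8192 - prompt_tokens   # num_tokens_total = 8192
block_size_tokens = 256
total_block_count = 9                           # small buffer: ~8 usable blocks + 1 dummy

max_request_tokens = prompt_tokens + num_tokens_to_generate
request_block_count = math.ceil(max_request_tokens / block_size_tokens)  # 32
usable_blocks = total_block_count - 1                                    # 8

if request_block_count > usable_blocks:
    # In the diff above, the engine marks such a request FAILED and records
    # a BlockOverflowError instead of queueing a request that can never run.
    print(f"reject: needs {request_block_count} blocks, only {usable_blocks} available")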

megatron/inference/utils.py

Lines changed: 1 addition & 0 deletions
@@ -328,6 +328,7 @@ def get_inference_config_from_model_and_args(model: MegatronModule, args):
     )
 
     return InferenceConfig(
+        verbose=True,
         block_size_tokens=args.inference_dynamic_batching_block_size,
         buffer_size_gb=args.inference_dynamic_batching_buffer_size_gb,
         paused_buffer_size_gb=args.inference_dynamic_batching_paused_buffer_size_gb,

tests/unit_tests/inference/contexts/test_dynamic_context.py

Lines changed: 94 additions & 0 deletions
@@ -1399,6 +1399,100 @@ def test_mamba_memory_ratio_allocation(self, ratio):
         assert context.max_requests == expected_max_requests
         assert context.is_hybrid_model is True
 
+    @pytest.mark.internal
+    @rounder_override(1)
+    @pytest.mark.parametrize("max_requests", [1, 4, 64])
+    def test_hybrid_max_requests_auto_derives_mamba_split(self, max_requests):
+        """
+        When max_requests is set on a hybrid model without mamba_memory_ratio,
+        mamba memory should be allocated for exactly max_requests slots, with
+        the remaining memory going to KV cache blocks.
+        """
+
+        buffer_gb = 0.05
+        paused_gb = 0.01
+        block_size = 256
+        num_attention_heads = 8
+        kv_channels = 64
+        params_dtype = torch.float32
+
+        layer_type_list = [Symbols.MAMBA, Symbols.ATTENTION]
+        mamba_conv_states_shape = (544, 4)
+        mamba_ssm_states_shape = (8, 64, 16)
+        mamba_config = MambaInferenceStateConfig(
+            layer_type_list,
+            mamba_conv_states_shape,
+            mamba_ssm_states_shape,
+            params_dtype,
+            params_dtype,
+        )
+
+        context = DynamicInferenceContext(
+            model_config=TransformerConfig(
+                params_dtype=params_dtype,
+                num_layers=2,
+                kv_channels=kv_channels,
+                num_attention_heads=num_attention_heads,
+            ),
+            inference_config=InferenceConfig(
+                max_sequence_length=512,
+                buffer_size_gb=buffer_gb,
+                paused_buffer_size_gb=paused_gb,
+                block_size_tokens=block_size,
+                max_tokens=2048,
+                mamba_inference_state_config=mamba_config,
+                max_requests=max_requests,
+                unified_memory_level=0,
+            ),
+        )
+
+        dtype_size = torch.tensor([], dtype=params_dtype).element_size()
+
+        mamba_mem_per_req = math.prod(mamba_conv_states_shape) + math.prod(mamba_ssm_states_shape)
+        mamba_mem_per_req *= dtype_size
+
+        kv_buffer_bytes = int(buffer_gb * 1024**3)
+        kv_paused_bytes = int(paused_gb * 1024**3)
+        total_mem_bytes = kv_buffer_bytes + kv_paused_bytes
+
+        # Auto-derived ratio from max_requests.
+        mamba_memory_needed = max_requests * mamba_mem_per_req
+        ratio = mamba_memory_needed / total_mem_bytes
+
+        kv_buffer_bytes = int(kv_buffer_bytes * (1.0 - ratio))
+        kv_paused_bytes = int(kv_paused_bytes * (1.0 - ratio))
+
+        kv_block_size_bytes = dtype_size * 2 * 1 * block_size * num_attention_heads * kv_channels
+        expected_active_blocks = kv_buffer_bytes // kv_block_size_bytes
+
+        assert context.kv_block_allocator.total_count == expected_active_blocks
+        assert context.max_requests == max_requests
+
+        # With max_requests=1, more memory goes to KV blocks than with max_requests=64.
+        # Verify we get more blocks with fewer requests.
+        if max_requests == 1:
+            context_many = DynamicInferenceContext(
+                model_config=TransformerConfig(
+                    params_dtype=params_dtype,
+                    num_layers=2,
+                    kv_channels=kv_channels,
+                    num_attention_heads=num_attention_heads,
+                ),
+                inference_config=InferenceConfig(
+                    max_sequence_length=512,
+                    buffer_size_gb=buffer_gb,
+                    paused_buffer_size_gb=paused_gb,
+                    block_size_tokens=block_size,
+                    max_tokens=2048,
+                    mamba_inference_state_config=mamba_config,
+                    max_requests=64,
+                    unified_memory_level=0,
+                ),
+            )
+            assert (
+                context.kv_block_allocator.total_count > context_many.kv_block_allocator.total_count
+            )
+
     @pytest.mark.internal
     @rounder_override(64)
     def test_max_requests_less_than_tp_size(self):
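
The expected block count asserted above can be reproduced outside the test harness. A standalone sketch that repeats the test's arithmetic for max_requests=64, using only the constants from the test (no Megatron imports):

import math

# Constants copied from the test above.
buffer_gb, paused_gb = 0.05, 0.01
block_size_tokens = 256
num_attention_heads, kv_channels = 8, 64
dtype_size = 4  # torch.float32
mamba_conv_states_shape = (544, 4)
mamba_ssm_states_shape = (8, 64, 16)
max_requests = 64

mamba_mem_per_req = (
    math.prod(mamba_conv_states_shape) + math.prod(mamba_ssm_states_shape)
) * dtype_size  # 41472 bytes

kv_buffer_bytes = int(buffer_gb * 1024**3)
kv_paused_bytes = int(paused_gb * 1024**3)
total_mem_bytes = kv_buffer_bytes + kv_paused_bytes

ratio = (max_requests * mamba_mem_per_req) / total_mem_bytes
kv_buffer_bytes = int(kv_buffer_bytes * (1.0 - ratio))

# 2 for K and V, 1 attention layer, block_size tokens of heads x head_dim each.
kv_block_size_bytes = dtype_size * 2 * 1 * block_size_tokens * num_attention_heads * kv_channels
expected_active_blocks = kv_buffer_bytes // kv_block_size_bytes
print(f"{ratio:.4f}, {expected_active_blocks}")  # roughly 0.0412 and 49 blocks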

tests/unit_tests/inference/engines/test_dynamic_engine.py

Lines changed: 27 additions & 0 deletions
@@ -701,6 +701,33 @@ def test_block_overflow(self, model_provider: str) -> None:
         env.engine._add_request(env.requests[0])
         assert list(env.engine.waiting_request_ids) == [0]
 
+    @pytest.mark.internal
+    @pytest.mark.skipif(
+        not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching"
+    )
+    def test_block_overflow_insufficient_kv_cache(self) -> None:
+        """Test that a request fails when KV cache blocks cannot fit the request's sequence."""
+        # Use a large max_sequence_length with a small buffer so that the total
+        # block count is smaller than what a single max-length request needs.
+        # With num_tokens_total=8192 and prompt_length=4, the request needs
+        # ceil(8192 / 256) = 32 blocks, but the small buffer only has ~8 blocks.
+        test_config = DynamicEngineTestConfig(
+            num_requests=1,
+            min_prompt_length=4,
+            max_prompt_length=4,
+            num_tokens_to_generate=None,
+            num_tokens_total=8192,
+            max_sequence_length=8192,
+            context_buffer_size_gb=0.001,
+            context_block_size_tokens=256,
+            context_max_tokens=16384,
+        )
+        env = self._build_test_env(test_config)
+        request = env.requests[0]
+        env.engine._add_request(request)
+        assert request.status == Status.FAILED
+        assert list(env.engine.waiting_request_ids) == []
+
     @pytest.mark.internal
     @pytest.mark.skipif(
         not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching"
