[TRTLLM-5838][fix] fix max batch size and max tokens in kv cache estimations for Nemotron-H #5371

Open
tomeras91 wants to merge 19 commits into main from fix-trtllm-bench-for-nemotron-h

Commits (19)
1752239
WIP: consider num_attention_layers for kv cache estimation and add ma…
tomeras91 Jun 12, 2025
7829ec9
Merge branch 'main' into fix-trtllm-bench-for-nemotron-h
tomeras91 Jun 18, 2025
4403183
organize code and logging for max batch size calculation for trtllm-b…
tomeras91 Jun 19, 2025
6ff4602
consider only attention layers when estimating number of tokens in Kv…
tomeras91 Jun 19, 2025
e6615a8
propagate kv_cache_gpu_mem_fraction to calc_engine_setting for trtllm…
tomeras91 Jun 19, 2025
42d65f3
release mamba cache memory when shutting down MambaCacheManager (and …
tomeras91 Jun 19, 2025
17d22e5
small refactor - MambaCacheManager method names to match BaseResource…
tomeras91 Jun 19, 2025
7dfeab8
refactor - is_nemotron_hybrid works on dicts as well
tomeras91 Jun 19, 2025
ee85bac
remove log
tomeras91 Jun 19, 2025
d0d0b7e
Add comment explaining squaring of kv_cache_gpu_mem_fraction + save r…
tomeras91 Jun 19, 2025
63bea92
remove debug print
tomeras91 Jun 19, 2025
c8c71df
fix - use config.get() only if config is a dict
tomeras91 Jun 19, 2025
3e6a30e
Merge branch 'main' into fix-trtllm-bench-for-nemotron-h
tomeras91 Jun 24, 2025
83e0673
optimistic tune max batch size only if not mamba attention hybrid model
tomeras91 Jun 25, 2025
4b2ba21
Merge branch 'main' into fix-trtllm-bench-for-nemotron-h
tomeras91 Jun 25, 2025
e6e65fc
Merge branch 'main' into fix-trtllm-bench-for-nemotron-h
tomeras91 Jun 26, 2025
8cf5ee7
Merge branch 'fix-trtllm-bench-for-nemotron-h' of github.com:tomeras9…
tomeras91 Jun 26, 2025
aa5d87c
fix: Mamba cache size estimation for FP8 - always use NO_QUANT for ma…
tomeras91 Jun 26, 2025
ac481b2
Merge branch 'main' into fix-trtllm-bench-for-nemotron-h
tomeras91 Jun 26, 2025
10 changes: 7 additions & 3 deletions tensorrt_llm/_torch/pyexecutor/_util.py
@@ -82,10 +82,14 @@ def _get_cache_size_per_token(model_config: ModelConfig,
config.hidden_size // config.num_attention_heads,
) * num_key_value_heads // tp_size

if is_nemotron_hybrid(config):
num_attention_layers = config.hybrid_override_pattern.count("*")
else:
num_attention_layers = config.num_hidden_layers
# provide at least 1 layer to prevent division by zero cache size
num_hidden_layers = max(
len(mapping.pp_layers(config.num_hidden_layers)), 1)
mem_per_token *= num_hidden_layers * head_dim
num_attention_layers = max(len(mapping.pp_layers(num_attention_layers)),
1)
mem_per_token *= num_attention_layers * head_dim
Comment on lines +90 to +92
Copilot AI Jun 19, 2025

[nitpick] The variable 'num_attention_layers' is being reassigned to represent the number of mapped pipeline layers instead of its original meaning. Consider using a new variable name (e.g., 'mapped_attention_layers') to preserve clarity and avoid confusion.

Suggested change
num_attention_layers = max(len(mapping.pp_layers(num_attention_layers)),
1)
mem_per_token *= num_attention_layers * head_dim
mapped_attention_layers = max(len(mapping.pp_layers(num_attention_layers)),
1)
mem_per_token *= mapped_attention_layers * head_dim

# K and V
mem_per_token *= kv_factor
return mem_per_token
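For readers skimming the diff: the hunk above changes the KV cache size estimate to count only attention layers (the "*" entries in hybrid_override_pattern) rather than all hidden layers. A standalone, simplified sketch of that idea — not the function above, and with hypothetical values:

```python
# Simplified sketch: per-token KV cache bytes for a hybrid Mamba/attention model.
# Only attention layers ("*" in hybrid_override_pattern) hold KV cache;
# Mamba ("M") layers do not. All concrete numbers below are hypothetical.
def kv_bytes_per_token(hybrid_override_pattern: str,
                       num_key_value_heads: int,
                       head_dim: int,
                       tp_size: int = 1,
                       bytes_per_elem: int = 2,    # fp16/bf16 KV cache assumed
                       kv_factor: int = 2) -> int:  # K and V
    num_attention_layers = hybrid_override_pattern.count("*")
    kv_heads_per_rank = max(num_key_value_heads // tp_size, 1)
    return kv_factor * bytes_per_elem * num_attention_layers * kv_heads_per_rank * head_dim

# e.g. 8 attention layers in a 52-layer hybrid stack, 8 KV heads, head_dim 128:
print(kv_bytes_per_token("M" * 44 + "*" * 8, num_key_value_heads=8, head_dim=128))
```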
5 changes: 2 additions & 3 deletions tensorrt_llm/_torch/pyexecutor/config_utils.py
@@ -1,7 +1,6 @@
def is_nemotron_hybrid(config):
if hasattr(config, "hybrid_override_pattern"):
return True
return False
return getattr(config, "hybrid_override_pattern", None) is not None \
or (isinstance(config, dict) and config.get("hybrid_override_pattern", None) is not None)


def is_mla(config):
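A quick illustration of what the config_utils.py change above enables: is_nemotron_hybrid now accepts both config objects and plain dicts (e.g. a raw HF config.json). The function body is copied from the diff; the inputs are hypothetical.

```python
# Copied from the diff above; accepts attribute-style configs and dicts alike.
def is_nemotron_hybrid(config):
    return getattr(config, "hybrid_override_pattern", None) is not None \
        or (isinstance(config, dict) and config.get("hybrid_override_pattern", None) is not None)

class FakeConfig:
    hybrid_override_pattern = "M*M*MM*M"  # hypothetical pattern

print(is_nemotron_hybrid(FakeConfig()))                         # True
print(is_nemotron_hybrid({"hybrid_override_pattern": "M*M*"}))  # True
print(is_nemotron_hybrid({"num_hidden_layers": 32}))            # False
```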
37 changes: 23 additions & 14 deletions tensorrt_llm/_torch/pyexecutor/resource_manager.py
@@ -594,7 +594,7 @@ def __init__(
device=device,
dtype=torch.int32)

def prepare_mamba_cache_blocks(self, request_ids: List[int]):
def _prepare_mamba_cache_blocks(self, request_ids: List[int]):
state_indices = []
for r in request_ids:
# cache hit
@@ -611,23 +611,21 @@ def prepare_mamba_cache_blocks(self, request_ids: List[int]):
dtype=torch.int32,
device=self.ssm_states.device)

def free_mamba_cache_blocks(self, request_id: int):
if request_id in self.mamba_cache_index:
block = self.mamba_cache_index.pop(request_id)
self.mamba_cache_free_blocks.append(block)

def prepare_mamba_resources(self, scheduled_batch: ScheduledRequests):
def prepare_resources(self, scheduled_batch: ScheduledRequests):
context_ids = [
i.py_request_id for i in scheduled_batch.context_requests
]
generation_ids = [
i.py_request_id for i in scheduled_batch.generation_requests
]
request_ids = context_ids + generation_ids
self.prepare_mamba_cache_blocks(request_ids)
self._prepare_mamba_cache_blocks(request_ids)

def free_mamba_resources(self, request: LlmRequest):
self.free_mamba_cache_blocks(request.py_request_id)
def free_resources(self, request: LlmRequest):
request_id = request.py_request_id
if request_id in self.mamba_cache_index:
block = self.mamba_cache_index.pop(request_id)
self.mamba_cache_free_blocks.append(block)

def get_state_indices(self) -> torch.Tensor:
return self.state_indices
@@ -640,6 +638,13 @@ def get_ssm_states(self, layer_idx: int) -> torch.Tensor:
layer_offset = self.mamba_layer_offsets[layer_idx]
return self.ssm_states[layer_offset]

def shutdown(self):
# release tensor memory, keeping python references as tensors
self.conv_states = torch.tensor([])
self.ssm_states = torch.tensor([])
self.state_indices = torch.tensor([])
torch.cuda.empty_cache()


class MambaHybridCacheManager(KVCacheManager, MambaCacheManager):

@@ -710,12 +715,16 @@ def __init__(
)

def prepare_resources(self, scheduled_batch: ScheduledRequests):
self.prepare_mamba_resources(scheduled_batch)
super().prepare_resources(scheduled_batch)
MambaCacheManager.prepare_resources(self, scheduled_batch)
KVCacheManager.prepare_resources(self, scheduled_batch)

def free_resources(self, request: LlmRequest):
self.free_mamba_resources(request)
super().free_resources(request)
MambaCacheManager.free_resources(self, request)
KVCacheManager.free_resources(self, request)

def shutdown(self):
MambaCacheManager.shutdown(self)
KVCacheManager.shutdown(self)


class SlotManager:
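A side note on the MambaHybridCacheManager hunk above (not from the PR discussion): the class inherits from both KVCacheManager and MambaCacheManager, and after this change both bases define prepare_resources, free_resources, and shutdown under the same names. Calling each base explicitly, rather than relying on super(), guarantees that both the KV cache and the Mamba cache are prepared, freed, and shut down regardless of how the bases chain super(). A minimal sketch of that pattern with hypothetical classes:

```python
# Hypothetical minimal classes; not the PR's code. With same-named methods on
# both bases, a plain super() call would only reach the first base in the MRO
# unless every base cooperatively chains super(); explicit calls reach both.
class KVCache:
    def prepare(self) -> None:
        print("prepare KV cache blocks")

class MambaCache:
    def prepare(self) -> None:
        print("prepare Mamba state blocks")

class HybridCache(KVCache, MambaCache):
    def prepare(self) -> None:
        MambaCache.prepare(self)  # mirrors MambaCacheManager.prepare_resources(self, ...)
        KVCache.prepare(self)     # mirrors KVCacheManager.prepare_resources(self, ...)

HybridCache().prepare()  # prints both lines
```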
1 change: 1 addition & 0 deletions tensorrt_llm/bench/benchmark/utils/general.py
@@ -132,6 +132,7 @@ def get_settings(params: dict, dataset_metadata: DatasetMetadata, model: str,
params.get("pp"),
dataset_metadata.avg_isl,
dataset_metadata.avg_osl,
params.get("kv_cache_free_gpu_mem_fraction"),
)

logger.info(
2 changes: 2 additions & 0 deletions tensorrt_llm/bench/build/build.py
@@ -31,6 +31,7 @@ def get_benchmark_engine_settings(
pp_size: int,
target_input_len: int,
target_output_len: int,
kv_cache_gpu_mem_fraction: float = 0.95,
) -> Tuple[int, int]:
""" Retrieve benchmark settings for a specific model + configuration.

@@ -58,6 +59,7 @@
pp_size,
target_input_len,
target_output_len,
kv_cache_gpu_mem_fraction,
)
else:
max_batch_size = DEFAULT_MAX_BATCH_SIZE
35 changes: 34 additions & 1 deletion tensorrt_llm/bench/build/dataclasses.py
@@ -14,6 +14,8 @@
import json
import struct

from tensorrt_llm._torch.pyexecutor.config_utils import is_nemotron_hybrid


def parse_safetensors_file_metadata(model_path, filename):

@@ -111,6 +113,29 @@ def _parse(filename: str) -> None:
return huggingface_hub.get_safetensors_metadata(model_name_or_path)


class MambaConfig(BaseModel):
d_model: int = Field(
validation_alias=AliasChoices("d_model", "hidden_size", "n_embd"))
d_state: int = Field(
validation_alias=AliasChoices("d_state", "ssm_state_size"))
d_conv: int = Field(validation_alias=AliasChoices("d_conv", "conv_kernel"))
expand: int
n_groups: int
head_dim: int = Field(
validation_alias=AliasChoices("head_dim", "mamba_head_dim"))
d_inner: int = Field(default=None)
n_heads: int = Field(default=None)

@model_validator(mode="after")
def set_values_if_none(self):
""" Set the values if cannot get values from HF config.json. """
if not self.d_inner:
self.d_inner = self.d_model * self.expand
Comment by the PR author (tomeras91):
TODO: fix this for Nemotron-H-4B

if not self.n_heads:
self.n_heads = self.d_inner // self.head_dim
return self


class ModelConfig(BaseModel):
""" Model specific configurations. The parameters are needed in engine
setting calculation.
Expand Down Expand Up @@ -161,6 +186,8 @@ class ModelConfig(BaseModel):
None] = Field(default="float16",
validation_alias=AliasChoices(
"dtype", "torch_dtype"))
hybrid_override_pattern: Optional[str] = Field(default=None)
mamba_config: Optional[MambaConfig] = Field(default=None)

@model_validator(mode="after")
def set_values_if_none(self):
Expand Down Expand Up @@ -193,4 +220,10 @@ def from_hf(cls, model_hf_name, hf_model_path):
model_name_or_path, trust_remote_code=True).to_dict()
param_count = cls.get_param_count(model_hf_name, hf_model_path)

return cls(name=model_hf_name, param_count=param_count, **hf_config)
mamba_config = MambaConfig(
**hf_config) if is_nemotron_hybrid(hf_config) else None

return cls(name=model_hf_name,
param_count=param_count,
mamba_config=mamba_config,
**hf_config)
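As a quick illustration of the MambaConfig defaulting above: when d_inner or n_heads is missing from the HF config.json, the validator derives them from the other fields. The numbers below are assumed, not taken from any real checkpoint.

```python
# Assumed values to illustrate MambaConfig.set_values_if_none:
# d_inner falls back to d_model * expand, n_heads to d_inner // head_dim.
d_model, expand, head_dim = 4096, 2, 64  # hypothetical
d_inner = d_model * expand               # 8192
n_heads = d_inner // head_dim            # 128
print(d_inner, n_heads)
```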
87 changes: 64 additions & 23 deletions tensorrt_llm/bench/build/tuning.py
@@ -1,5 +1,6 @@
from typing import Tuple

from tensorrt_llm._torch.pyexecutor.config_utils import is_nemotron_hybrid
from tensorrt_llm.llmapi.llm_utils import QuantConfig
from tensorrt_llm.logger import logger
from tensorrt_llm.quantization.mode import QuantAlgo
@@ -55,7 +56,15 @@ def calc_engine_setting(

# Each GPU in TP group has at least 1 kv head
adjusted_num_kv_heads = max(tp_size, model_config.num_key_value_heads)
byte_per_token = 2 * model_config.num_hidden_layers * adjusted_num_kv_heads \

is_mamba_attn_hybrid = is_nemotron_hybrid(model_config)
if is_mamba_attn_hybrid:
num_attention_layers = model_config.hybrid_override_pattern.count("*")
else:
num_attention_layers = model_config.num_hidden_layers
logger.info(f"Number of attention layers: {num_attention_layers}")

byte_per_token = 2 * num_attention_layers * adjusted_num_kv_heads \
* model_config.head_size * byte_per_kv_elem / (1024 ** 3)

# Number of GPU used for this run.
@@ -70,19 +79,48 @@
f"{available_memory:.2f} GB")

# Calculate max requests in KV cache based on target ISL and OSL.
kv_cache_memory = available_memory * kv_cache_gpu_mem_fraction
kv_cache_max_tokens = kv_cache_memory / byte_per_token
kv_cache_max_requests = kv_cache_max_tokens / (target_input_len +
target_output_len)
logger.info(f"Estimated total KV cache memory: {kv_cache_memory:.2f} GB")
target_seq_len = target_input_len + target_output_len

if is_mamba_attn_hybrid:
num_mamba_layers = model_config.hybrid_override_pattern.count("M")
conv_dim = model_config.mamba_config.d_inner + 2 * model_config.mamba_config.n_groups * model_config.mamba_config.d_state
num_conv_state_elements = num_mamba_layers * conv_dim * (
model_config.mamba_config.d_conv - 1)
num_ssm_state_elements = num_mamba_layers * model_config.mamba_config.n_heads * model_config.mamba_config.head_dim * model_config.mamba_config.d_state
byte_per_state_elem = BYTES_PER_ELEM.get(QuantAlgo.NO_QUANT)
byte_per_mamba_cache = byte_per_state_elem * (
num_conv_state_elements + num_ssm_state_elements) / (1024**3)

# Each mamba cache entry is pretty large (~50MB), so we are more conservative when estimating the max batch size
kv_cache_gpu_mem_fraction *= kv_cache_gpu_mem_fraction
else:
byte_per_mamba_cache = 0

cache_memory = available_memory * kv_cache_gpu_mem_fraction
kv_cache_max_requests = cache_memory / (byte_per_token * target_seq_len +
byte_per_mamba_cache)
mamba_cache_memory = byte_per_mamba_cache * kv_cache_max_requests
kv_cache_max_tokens = (cache_memory - mamba_cache_memory) / byte_per_token

if is_mamba_attn_hybrid:
kv_cache_memory = kv_cache_max_tokens * byte_per_token
logger.info(
f"Estimated total cache memory: {cache_memory:.2f} GB. KV cache: {kv_cache_memory:.2f} GB, Mamba cache: {mamba_cache_memory:.2f} GB"
)
else:
logger.info(f"Estimated total KV cache memory: {cache_memory:.2f} GB")
logger.info(f"Estimated kv cache max tokens: {kv_cache_max_tokens:.2f}")
logger.info("Estimated max number of requests in KV cache memory: "
f"{kv_cache_max_requests:.2f}")

# Fine-tune the max batch size and num token setting for performance.
max_batch_size, max_num_tokens = finetune_setting(kv_cache_max_requests,
target_input_len,
target_output_len,
pp_size)
# For mamba-attn hybrid models, we disable optimistic tuning because the mamba cache leaves less memory for the KV cache
max_batch_size, max_num_tokens = finetune_setting(
kv_cache_max_requests,
target_input_len,
target_output_len,
pp_size,
enable_optimistic_tuning=not is_mamba_attn_hybrid)

# Functional and performance
if total_gpu_memory < engine_size:
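To make the arithmetic in the hunk above easier to follow, here is a standalone walk-through of the hybrid cache budgeting with hypothetical numbers; none of the concrete values come from the PR.

```python
# Rough walk-through of the hybrid cache budgeting in calc_engine_setting.
# All sizes below are hypothetical.
GiB = 1024 ** 3

# KV cache: only attention ("*") layers contribute.
num_attention_layers = 4
num_kv_heads, head_size, bytes_per_kv_elem = 8, 128, 1      # FP8 KV cache assumed
byte_per_token = 2 * num_attention_layers * num_kv_heads * head_size * bytes_per_kv_elem / GiB

# Mamba cache: every "M" layer keeps conv + SSM state per request,
# sized with unquantized 2-byte elements per the fix above.
num_mamba_layers, d_inner, n_groups, d_state, d_conv = 48, 8192, 8, 128, 4
n_heads, head_dim, bytes_per_state_elem = 128, 64, 2
conv_dim = d_inner + 2 * n_groups * d_state
conv_elems = num_mamba_layers * conv_dim * (d_conv - 1)
ssm_elems = num_mamba_layers * n_heads * head_dim * d_state
byte_per_mamba_cache = bytes_per_state_elem * (conv_elems + ssm_elems) / GiB  # ~0.1 GiB/request

# Split the budget: the request count is bounded jointly by KV tokens and Mamba state.
available_memory, frac, target_seq_len = 60.0, 0.95, 2048 + 2048
frac *= frac                                   # squared for hybrid models (more conservative)
cache_memory = available_memory * frac
max_requests = cache_memory / (byte_per_token * target_seq_len + byte_per_mamba_cache)
mamba_cache_memory = byte_per_mamba_cache * max_requests
kv_cache_max_tokens = (cache_memory - mamba_cache_memory) / byte_per_token
print(f"{max_requests:.0f} requests, {kv_cache_max_tokens:.0f} KV cache tokens")
```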
@@ -107,7 +145,7 @@
if kv_cache_max_requests < 1:
raise RuntimeError("The amount of KV cache memory is insufficient to "
"run this model. Please try with more GPUs.")
if kv_cache_memory / n_gpus < 10.0:
if cache_memory / n_gpus < 10.0:
logger.warning(
f"The KV cache memory per GPU is less than 10 GB. "
"Performance may be undesirable. Please consider using a different "
@@ -121,12 +159,11 @@
return max_batch_size, max_num_tokens


def finetune_setting(
kv_cache_max_requests: float,
input_len: int,
output_len: int,
pp_size: int,
) -> Tuple[int, int]:
def finetune_setting(kv_cache_max_requests: float,
input_len: int,
output_len: int,
pp_size: int,
enable_optimistic_tuning: bool = True) -> Tuple[int, int]:
""" Calculate and fine-tune the engine build settings (max batch size and
max num tokens). Both max batch size and max num tokens are fine-tuned
to be slightly optimistic.
@@ -137,6 +174,7 @@ def finetune_setting(
input_len (int): Input sequence length to compile the engine.
output_len (int): Output sequence length to compile the engine.
pp_size (int): Number of pipeline parallel stages.
enable_optimistic_tuning (bool): Whether to enable optimistic tuning.

Returns:
Tuple[int, int]: Tuple containing fine-tuned values for engine
@@ -148,13 +186,16 @@
raw_token = min(raw_bs * (1 + input_len / output_len), 32768)

# Fine-tune the max batch size.
# Set min BS to be 64.
if raw_bs < 256:
max_bs = max(64, 32 * math.ceil(raw_bs / 32))
elif raw_bs < 1024:
max_bs = 128 * math.ceil(raw_bs / 128)
if enable_optimistic_tuning:
# Set min BS to be 64.
if raw_bs < 256:
max_bs = max(64, 32 * math.ceil(raw_bs / 32))
elif raw_bs < 1024:
max_bs = 128 * math.ceil(raw_bs / 128)
else:
max_bs = 256 * math.ceil(raw_bs / 256)
else:
max_bs = 256 * math.ceil(raw_bs / 256)
max_bs = 2 * math.floor(raw_bs / 2)

# Fine-tune the max num tokens.
# Set min to 2048 to ensure Ctx/Gen overlap efficiency
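Finally, a small sketch of the batch-size rounding controlled by the new enable_optimistic_tuning flag, mirroring the logic in the hunk above (the input value is hypothetical; the real function also tunes max_num_tokens).

```python
import math

# Mirrors the max-batch-size rounding in finetune_setting.
def round_max_bs(raw_bs: float, enable_optimistic_tuning: bool = True) -> int:
    if enable_optimistic_tuning:
        if raw_bs < 256:
            return max(64, 32 * math.ceil(raw_bs / 32))
        elif raw_bs < 1024:
            return 128 * math.ceil(raw_bs / 128)
        return 256 * math.ceil(raw_bs / 256)
    # Mamba/attention hybrid path: round down to an even value instead of up.
    return 2 * math.floor(raw_bs / 2)

print(round_max_bs(150.7))         # 160 (rounded up, optimistic)
print(round_max_bs(150.7, False))  # 150 (rounded down, conservative)
```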