From 175223962d4b7834956375c487d8379e74406fe0 Mon Sep 17 00:00:00 2001 From: Tomer Asida <57313761+tomeras91@users.noreply.github.com> Date: Thu, 12 Jun 2025 15:47:19 +0000 Subject: [PATCH 01/13] WIP: consider num_attention_layers for kv cache estimation and add mamba cache memory estimation Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com> --- .../_torch/pyexecutor/config_utils.py | 3 +- tensorrt_llm/bench/build/dataclasses.py | 39 ++++++++++++++++-- tensorrt_llm/bench/build/tuning.py | 40 ++++++++++++++++++- 3 files changed, 77 insertions(+), 5 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/config_utils.py b/tensorrt_llm/_torch/pyexecutor/config_utils.py index c0f0482674..aedb092141 100644 --- a/tensorrt_llm/_torch/pyexecutor/config_utils.py +++ b/tensorrt_llm/_torch/pyexecutor/config_utils.py @@ -1,5 +1,6 @@ def is_nemotron_hybrid(config): - if hasattr(config, "hybrid_override_pattern"): + if hasattr(config, "hybrid_override_pattern" + ) and config.hybrid_override_pattern is not None: return True return False diff --git a/tensorrt_llm/bench/build/dataclasses.py b/tensorrt_llm/bench/build/dataclasses.py index 2aab385799..ab5a74bbec 100755 --- a/tensorrt_llm/bench/build/dataclasses.py +++ b/tensorrt_llm/bench/build/dataclasses.py @@ -14,6 +14,8 @@ import json import struct +from tensorrt_llm._torch.pyexecutor.config_utils import is_nemotron_hybrid + def parse_safetensors_file_metadata(model_path, filename): @@ -111,6 +113,29 @@ def _parse(filename: str) -> None: return huggingface_hub.get_safetensors_metadata(model_name_or_path) +class MambaConfig(BaseModel): + d_model: int = Field( + validation_alias=AliasChoices("d_model", "hidden_size", "n_embd")) + d_state: int = Field( + validation_alias=AliasChoices("d_state", "ssm_state_size")) + d_conv: int = Field(validation_alias=AliasChoices("d_conv", "conv_kernel")) + expand: int + n_groups: int + head_dim: int = Field( + validation_alias=AliasChoices("head_dim", "mamba_head_dim")) + d_inner: int = Field(default=None) + n_heads: int = Field(default=None) + + @model_validator(mode="after") + def set_values_if_none(self): + """ Set the values if cannot get values from HF config.json. """ + if not self.d_inner: + self.d_inner = self.d_model * self.expand + if not self.n_heads: + self.n_heads = self.d_inner // self.head_dim + return self + + class ModelConfig(BaseModel): """ Model specific configurations. The parameters are needed in engine setting calculation. 
@@ -161,6 +186,8 @@ class ModelConfig(BaseModel): None] = Field(default="float16", validation_alias=AliasChoices( "dtype", "torch_dtype")) + hybrid_override_pattern: Optional[str] = Field(default=None) + mamba_config: Optional[MambaConfig] = Field(default=None) @model_validator(mode="after") def set_values_if_none(self): @@ -189,8 +216,14 @@ def get_param_count(cls, model_hf_name, hf_model_path): @classmethod def from_hf(cls, model_hf_name, hf_model_path): model_name_or_path = hf_model_path or model_hf_name - hf_config = AutoConfig.from_pretrained( - model_name_or_path, trust_remote_code=True).to_dict() + hf_config = AutoConfig.from_pretrained(model_name_or_path, + trust_remote_code=True) param_count = cls.get_param_count(model_hf_name, hf_model_path) - return cls(name=model_hf_name, param_count=param_count, **hf_config) + mamba_config = MambaConfig( + **hf_config.to_dict()) if is_nemotron_hybrid(hf_config) else None + + return cls(name=model_hf_name, + param_count=param_count, + mamba_config=mamba_config, + **hf_config.to_dict()) diff --git a/tensorrt_llm/bench/build/tuning.py b/tensorrt_llm/bench/build/tuning.py index f67e4a6f5f..6875f9f68f 100755 --- a/tensorrt_llm/bench/build/tuning.py +++ b/tensorrt_llm/bench/build/tuning.py @@ -1,5 +1,6 @@ from typing import Tuple +from tensorrt_llm._torch.pyexecutor.config_utils import is_nemotron_hybrid from tensorrt_llm.llmapi.llm_utils import QuantConfig from tensorrt_llm.logger import logger from tensorrt_llm.quantization.mode import QuantAlgo @@ -50,12 +51,20 @@ def calc_engine_setting( Tuple[int, int]: Tuple containing engine configuration information for engine build (max_num_tokens, max_batch_size). """ + print(f"{kv_cache_gpu_mem_fraction=}") byte_per_elem = BYTES_PER_ELEM.get(quant_config.quant_algo, 2) byte_per_kv_elem = BYTES_PER_ELEM.get(quant_config.kv_cache_quant_algo, 2) # Each GPU in TP group has at least 1 kv head adjusted_num_kv_heads = max(tp_size, model_config.num_key_value_heads) - byte_per_token = 2 * model_config.num_hidden_layers * adjusted_num_kv_heads \ + # TODO: change num_hidden_layers to num_attention_layers + # TODO: Add estimation for mamba cache memory + if is_nemotron_hybrid(model_config): + num_attention_layers = model_config.hybrid_override_pattern.count("*") + else: + num_attention_layers = model_config.num_hidden_layers + logger.info(f"Number of attention layers: {num_attention_layers}") + byte_per_token = 2 * num_attention_layers * adjusted_num_kv_heads \ * model_config.head_size * byte_per_kv_elem / (1024 ** 3) # Number of GPU used for this run. @@ -63,6 +72,7 @@ def calc_engine_setting( # Total engine size. engine_size = model_config.param_count * byte_per_elem / (1024**3) total_gpu_memory = get_device_memory() * n_gpus + total_gpu_memory = 24.0 * n_gpus # Available memory to allocate KV cache. 
available_memory = total_gpu_memory - engine_size logger.info(f"Estimated engine size: {engine_size:.2f} GB") @@ -74,7 +84,34 @@ def calc_engine_setting( kv_cache_max_tokens = kv_cache_memory / byte_per_token kv_cache_max_requests = kv_cache_max_tokens / (target_input_len + target_output_len) + + total_seq_len = target_input_len + target_output_len + + if is_nemotron_hybrid(model_config): + num_mamba_layers = model_config.hybrid_override_pattern.count("M") + conv_dim = model_config.mamba_config.d_inner + 2 * model_config.mamba_config.n_groups * model_config.mamba_config.d_state + num_conv_state_elements = num_mamba_layers * conv_dim * ( + model_config.mamba_config.d_conv - 1) + num_ssm_state_elements = num_mamba_layers * model_config.mamba_config.n_heads * model_config.mamba_config.head_dim * model_config.mamba_config.d_state + print(f"{num_conv_state_elements=}") + print(f"{num_ssm_state_elements=}") + byte_per_state_elem = BYTES_PER_ELEM.get(quant_config.quant_algo, 2) + byte_per_mamba_cache = byte_per_state_elem * ( + num_conv_state_elements + num_ssm_state_elements) / (1024**3) + + kv_cache_gpu_mem_fraction *= kv_cache_gpu_mem_fraction + else: + byte_per_mamba_cache = 0 + + kv_cache_max_requests = available_memory * kv_cache_gpu_mem_fraction / ( + byte_per_token * total_seq_len + byte_per_mamba_cache) + + print( + f"cache memory estimation: {(byte_per_mamba_cache + total_seq_len * byte_per_token)*kv_cache_max_requests:.2f} GB" + ) + logger.info(f"Estimated total KV cache memory: {kv_cache_memory:.2f} GB") + logger.info(f"Estimated kv cache max tokens: {kv_cache_max_tokens:.2f}") logger.info("Estimated max number of requests in KV cache memory: " f"{kv_cache_max_requests:.2f}") @@ -158,6 +195,7 @@ def finetune_setting( # Fine-tune the max num tokens. # Set min to 2048 to ensure Ctx/Gen overlap efficiency + logger.info(f"Estimated max num tokens (before fine-tune): {raw_token}") if raw_token < 4096: max_token = max(2048, 256 * math.ceil(raw_token / 256)) elif raw_token < 8192: From 4403183ca1c71cf9f26d162a4b62e5efc184e578 Mon Sep 17 00:00:00 2001 From: Tomer Asida <57313761+tomeras91@users.noreply.github.com> Date: Thu, 19 Jun 2025 09:29:03 +0000 Subject: [PATCH 02/13] organize code and logging for max batch size calculation for trtllm-bench Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com> --- tensorrt_llm/bench/build/tuning.py | 37 ++++++++++++++---------------- 1 file changed, 17 insertions(+), 20 deletions(-) diff --git a/tensorrt_llm/bench/build/tuning.py b/tensorrt_llm/bench/build/tuning.py index 6875f9f68f..560586bd3a 100755 --- a/tensorrt_llm/bench/build/tuning.py +++ b/tensorrt_llm/bench/build/tuning.py @@ -51,19 +51,19 @@ def calc_engine_setting( Tuple[int, int]: Tuple containing engine configuration information for engine build (max_num_tokens, max_batch_size). 
""" - print(f"{kv_cache_gpu_mem_fraction=}") + kv_cache_gpu_mem_fraction = 0.9 byte_per_elem = BYTES_PER_ELEM.get(quant_config.quant_algo, 2) byte_per_kv_elem = BYTES_PER_ELEM.get(quant_config.kv_cache_quant_algo, 2) # Each GPU in TP group has at least 1 kv head adjusted_num_kv_heads = max(tp_size, model_config.num_key_value_heads) - # TODO: change num_hidden_layers to num_attention_layers - # TODO: Add estimation for mamba cache memory + if is_nemotron_hybrid(model_config): num_attention_layers = model_config.hybrid_override_pattern.count("*") else: num_attention_layers = model_config.num_hidden_layers logger.info(f"Number of attention layers: {num_attention_layers}") + byte_per_token = 2 * num_attention_layers * adjusted_num_kv_heads \ * model_config.head_size * byte_per_kv_elem / (1024 ** 3) @@ -72,7 +72,6 @@ def calc_engine_setting( # Total engine size. engine_size = model_config.param_count * byte_per_elem / (1024**3) total_gpu_memory = get_device_memory() * n_gpus - total_gpu_memory = 24.0 * n_gpus # Available memory to allocate KV cache. available_memory = total_gpu_memory - engine_size logger.info(f"Estimated engine size: {engine_size:.2f} GB") @@ -80,12 +79,7 @@ def calc_engine_setting( f"{available_memory:.2f} GB") # Calculate max requests in KV cache based on target ISL and OSL. - kv_cache_memory = available_memory * kv_cache_gpu_mem_fraction - kv_cache_max_tokens = kv_cache_memory / byte_per_token - kv_cache_max_requests = kv_cache_max_tokens / (target_input_len + - target_output_len) - - total_seq_len = target_input_len + target_output_len + target_seq_len = target_input_len + target_output_len if is_nemotron_hybrid(model_config): num_mamba_layers = model_config.hybrid_override_pattern.count("M") @@ -93,8 +87,6 @@ def calc_engine_setting( num_conv_state_elements = num_mamba_layers * conv_dim * ( model_config.mamba_config.d_conv - 1) num_ssm_state_elements = num_mamba_layers * model_config.mamba_config.n_heads * model_config.mamba_config.head_dim * model_config.mamba_config.d_state - print(f"{num_conv_state_elements=}") - print(f"{num_ssm_state_elements=}") byte_per_state_elem = BYTES_PER_ELEM.get(quant_config.quant_algo, 2) byte_per_mamba_cache = byte_per_state_elem * ( num_conv_state_elements + num_ssm_state_elements) / (1024**3) @@ -103,14 +95,19 @@ def calc_engine_setting( else: byte_per_mamba_cache = 0 - kv_cache_max_requests = available_memory * kv_cache_gpu_mem_fraction / ( - byte_per_token * total_seq_len + byte_per_mamba_cache) + cache_memory = available_memory * kv_cache_gpu_mem_fraction + kv_cache_max_requests = cache_memory / (byte_per_token * target_seq_len + + byte_per_mamba_cache) + mamba_cache_memory = byte_per_mamba_cache * kv_cache_max_requests + kv_cache_max_tokens = (cache_memory - mamba_cache_memory) / byte_per_token - print( - f"cache memory estimation: {(byte_per_mamba_cache + total_seq_len * byte_per_token)*kv_cache_max_requests:.2f} GB" - ) - - logger.info(f"Estimated total KV cache memory: {kv_cache_memory:.2f} GB") + if is_nemotron_hybrid(model_config): + kv_cache_memory = kv_cache_max_tokens * byte_per_token + logger.info( + f"Estimated total cache memory: {cache_memory:.2f} GB. 
KV cache: {kv_cache_memory:.2f} GB, Mamba cache: {mamba_cache_memory:.2f} GB" + ) + else: + logger.info(f"Estimated total KV cache memory: {cache_memory:.2f} GB") logger.info(f"Estimated kv cache max tokens: {kv_cache_max_tokens:.2f}") logger.info("Estimated max number of requests in KV cache memory: " f"{kv_cache_max_requests:.2f}") @@ -144,7 +141,7 @@ def calc_engine_setting( if kv_cache_max_requests < 1: raise RuntimeError("The amount of KV cache memory is insufficient to " "run this model. Please try with more GPUs.") - if kv_cache_memory / n_gpus < 10.0: + if cache_memory / n_gpus < 10.0: logger.warning( f"The KV cache memory per GPU is less than 10 GB. " "Performance may be undesirable. Please consider using a different " From 6ff46028a18823b3f71a6e4fd0358b44036f0d0b Mon Sep 17 00:00:00 2001 From: Tomer Asida <57313761+tomeras91@users.noreply.github.com> Date: Thu, 19 Jun 2025 09:33:19 +0000 Subject: [PATCH 03/13] consider only attention layers when estimating number of tokens in KvCacheCreator Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/_util.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index a34b946d26..37e25b721c 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -82,10 +82,14 @@ def _get_cache_size_per_token(model_config: ModelConfig, config.hidden_size // config.num_attention_heads, ) * num_key_value_heads // tp_size + if is_nemotron_hybrid(config): + num_attention_layers = config.hybrid_override_pattern.count("*") + else: + num_attention_layers = config.num_hidden_layers # provide at least 1 layer to prevent division by zero cache size - num_hidden_layers = max( - len(mapping.pp_layers(config.num_hidden_layers)), 1) - mem_per_token *= num_hidden_layers * head_dim + num_attention_layers = max(len(mapping.pp_layers(num_attention_layers)), + 1) + mem_per_token *= num_attention_layers * head_dim # K and V mem_per_token *= kv_factor return mem_per_token From e6615a8b5a874c52b875d06ce80ff5962b1e7de1 Mon Sep 17 00:00:00 2001 From: Tomer Asida <57313761+tomeras91@users.noreply.github.com> Date: Thu, 19 Jun 2025 09:59:08 +0000 Subject: [PATCH 04/13] propagate kv_cache_gpu_mem_fraction to calc_engine_setting for trtllm-bench throughput command Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com> --- tensorrt_llm/bench/benchmark/utils/general.py | 1 + tensorrt_llm/bench/build/build.py | 2 ++ tensorrt_llm/bench/build/tuning.py | 1 - 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/tensorrt_llm/bench/benchmark/utils/general.py b/tensorrt_llm/bench/benchmark/utils/general.py index ed6712a6c1..9afcf2e656 100755 --- a/tensorrt_llm/bench/benchmark/utils/general.py +++ b/tensorrt_llm/bench/benchmark/utils/general.py @@ -132,6 +132,7 @@ def get_settings(params: dict, dataset_metadata: DatasetMetadata, model: str, params.get("pp"), dataset_metadata.avg_isl, dataset_metadata.avg_osl, + params.get("kv_cache_free_gpu_mem_fraction"), ) logger.info( diff --git a/tensorrt_llm/bench/build/build.py b/tensorrt_llm/bench/build/build.py index e3bd6cbef5..586ecb20cf 100644 --- a/tensorrt_llm/bench/build/build.py +++ b/tensorrt_llm/bench/build/build.py @@ -31,6 +31,7 @@ def get_benchmark_engine_settings( pp_size: int, target_input_len: int, target_output_len: int, + kv_cache_gpu_mem_fraction: float = 0.95, ) -> Tuple[int, int]: """ Retrieve benchmark 
settings for a specific model + configuration. @@ -58,6 +59,7 @@ def get_benchmark_engine_settings( pp_size, target_input_len, target_output_len, + kv_cache_gpu_mem_fraction, ) else: max_batch_size = DEFAULT_MAX_BATCH_SIZE diff --git a/tensorrt_llm/bench/build/tuning.py b/tensorrt_llm/bench/build/tuning.py index 560586bd3a..78c573dcfa 100755 --- a/tensorrt_llm/bench/build/tuning.py +++ b/tensorrt_llm/bench/build/tuning.py @@ -51,7 +51,6 @@ def calc_engine_setting( Tuple[int, int]: Tuple containing engine configuration information for engine build (max_num_tokens, max_batch_size). """ - kv_cache_gpu_mem_fraction = 0.9 byte_per_elem = BYTES_PER_ELEM.get(quant_config.quant_algo, 2) byte_per_kv_elem = BYTES_PER_ELEM.get(quant_config.kv_cache_quant_algo, 2) From 42d65f3548ac99b71ca34052a3e8f445cdecedeb Mon Sep 17 00:00:00 2001 From: Tomer Asida <57313761+tomeras91@users.noreply.github.com> Date: Thu, 19 Jun 2025 10:09:53 +0000 Subject: [PATCH 05/13] release mamba cache memory when shutting down MambaCacheManager (and MambaHybridCacheManager) Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/resource_manager.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index ec8b3e705f..3b5a61164b 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -631,6 +631,13 @@ def get_ssm_states(self, layer_idx: int) -> torch.Tensor: layer_offset = self.mamba_layer_offsets[layer_idx] return self.ssm_states[layer_offset] + def shutdown(self): + # release tensor memory, keeping python references as tensors + self.conv_states = torch.tensor([]) + self.ssm_states = torch.tensor([]) + self.state_indices = torch.tensor([]) + torch.cuda.empty_cache() + class MambaHybridCacheManager(KVCacheManager, MambaCacheManager): @@ -708,6 +715,10 @@ def free_resources(self, request: LlmRequest): self.free_mamba_resources(request) super().free_resources(request) + def shutdown(self): + MambaCacheManager.shutdown(self) + KVCacheManager.shutdown(self) + class BaseDraftTokenManager(BaseResourceManager): From 17d22e50f71f190ef5eb0ac24a830acab0733fe9 Mon Sep 17 00:00:00 2001 From: Tomer Asida <57313761+tomeras91@users.noreply.github.com> Date: Thu, 19 Jun 2025 10:53:48 +0000 Subject: [PATCH 06/13] small refactor - MambaCacheManager method names to match BaseResourceManager, and explicit call to MambaCacheManager and KVCacheManager functions in MambaHybridCacheManager to reduce confusion Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com> --- .../_torch/pyexecutor/resource_manager.py | 26 +++++++++---------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index 3b5a61164b..0c972dd2fc 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -585,7 +585,7 @@ def __init__( device=device, dtype=torch.int32) - def prepare_mamba_cache_blocks(self, request_ids: List[int]): + def _prepare_mamba_cache_blocks(self, request_ids: List[int]): state_indices = [] for r in request_ids: # cache hit @@ -602,12 +602,7 @@ def prepare_mamba_cache_blocks(self, request_ids: List[int]): dtype=torch.int32, device=self.ssm_states.device) - def free_mamba_cache_blocks(self, request_id: int): - if request_id in 
self.mamba_cache_index: - block = self.mamba_cache_index.pop(request_id) - self.mamba_cache_free_blocks.append(block) - - def prepare_mamba_resources(self, scheduled_batch: ScheduledRequests): + def prepare_resources(self, scheduled_batch: ScheduledRequests): context_ids = [ i.py_request_id for i in scheduled_batch.context_requests ] @@ -615,10 +610,13 @@ def prepare_mamba_resources(self, scheduled_batch: ScheduledRequests): i.py_request_id for i in scheduled_batch.generation_requests ] request_ids = context_ids + generation_ids - self.prepare_mamba_cache_blocks(request_ids) + self._prepare_mamba_cache_blocks(request_ids) - def free_mamba_resources(self, request: LlmRequest): - self.free_mamba_cache_blocks(request.py_request_id) + def free_resources(self, request: LlmRequest): + request_id = request.py_request_id + if request_id in self.mamba_cache_index: + block = self.mamba_cache_index.pop(request_id) + self.mamba_cache_free_blocks.append(block) def get_state_indices(self) -> torch.Tensor: return self.state_indices @@ -708,12 +706,12 @@ def __init__( ) def prepare_resources(self, scheduled_batch: ScheduledRequests): - self.prepare_mamba_resources(scheduled_batch) - super().prepare_resources(scheduled_batch) + MambaCacheManager.prepare_resources(self, scheduled_batch) + KVCacheManager.prepare_resources(self, scheduled_batch) def free_resources(self, request: LlmRequest): - self.free_mamba_resources(request) - super().free_resources(request) + MambaCacheManager.free_resources(self, request) + KVCacheManager.free_resources(self, request) def shutdown(self): MambaCacheManager.shutdown(self) From 7dfeab8ec1cd2f282b11aebe6b329c06f8539e56 Mon Sep 17 00:00:00 2001 From: Tomer Asida <57313761+tomeras91@users.noreply.github.com> Date: Thu, 19 Jun 2025 11:03:50 +0000 Subject: [PATCH 07/13] refactor - is_nemotron_hybrid works on dicts as well Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/config_utils.py | 6 ++---- tensorrt_llm/bench/build/dataclasses.py | 8 ++++---- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/config_utils.py b/tensorrt_llm/_torch/pyexecutor/config_utils.py index aedb092141..11a7c7a884 100644 --- a/tensorrt_llm/_torch/pyexecutor/config_utils.py +++ b/tensorrt_llm/_torch/pyexecutor/config_utils.py @@ -1,8 +1,6 @@ def is_nemotron_hybrid(config): - if hasattr(config, "hybrid_override_pattern" - ) and config.hybrid_override_pattern is not None: - return True - return False + return getattr(config, "hybrid_override_pattern", None) is not None \ + or config.get("hybrid_override_pattern", None) is not None def is_mla(config): diff --git a/tensorrt_llm/bench/build/dataclasses.py b/tensorrt_llm/bench/build/dataclasses.py index ab5a74bbec..414ad3a757 100755 --- a/tensorrt_llm/bench/build/dataclasses.py +++ b/tensorrt_llm/bench/build/dataclasses.py @@ -216,14 +216,14 @@ def get_param_count(cls, model_hf_name, hf_model_path): @classmethod def from_hf(cls, model_hf_name, hf_model_path): model_name_or_path = hf_model_path or model_hf_name - hf_config = AutoConfig.from_pretrained(model_name_or_path, - trust_remote_code=True) + hf_config = AutoConfig.from_pretrained( + model_name_or_path, trust_remote_code=True).to_dict() param_count = cls.get_param_count(model_hf_name, hf_model_path) mamba_config = MambaConfig( - **hf_config.to_dict()) if is_nemotron_hybrid(hf_config) else None + **hf_config) if is_nemotron_hybrid(hf_config) else None return cls(name=model_hf_name, param_count=param_count, 
mamba_config=mamba_config, - **hf_config.to_dict()) + **hf_config) From ee85bac5abe25b7747c343042f036ad13d33f285 Mon Sep 17 00:00:00 2001 From: Tomer Asida <57313761+tomeras91@users.noreply.github.com> Date: Thu, 19 Jun 2025 11:07:04 +0000 Subject: [PATCH 08/13] remove log Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com> --- tensorrt_llm/bench/build/tuning.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tensorrt_llm/bench/build/tuning.py b/tensorrt_llm/bench/build/tuning.py index 78c573dcfa..7041268874 100755 --- a/tensorrt_llm/bench/build/tuning.py +++ b/tensorrt_llm/bench/build/tuning.py @@ -191,7 +191,6 @@ def finetune_setting( # Fine-tune the max num tokens. # Set min to 2048 to ensure Ctx/Gen overlap efficiency - logger.info(f"Estimated max num tokens (before fine-tune): {raw_token}") if raw_token < 4096: max_token = max(2048, 256 * math.ceil(raw_token / 256)) elif raw_token < 8192: From d0d0b7e4ee13bd1669bc9083982432c3e984990d Mon Sep 17 00:00:00 2001 From: Tomer Asida <57313761+tomeras91@users.noreply.github.com> Date: Thu, 19 Jun 2025 12:40:20 +0000 Subject: [PATCH 09/13] Add comment explaining squaring of kv_cache_gpu_mem_fraction + save result of is_nemotron_hybrid to increase readability Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com> --- tensorrt_llm/bench/build/tuning.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tensorrt_llm/bench/build/tuning.py b/tensorrt_llm/bench/build/tuning.py index 7041268874..6b3c6b3f05 100755 --- a/tensorrt_llm/bench/build/tuning.py +++ b/tensorrt_llm/bench/build/tuning.py @@ -57,7 +57,8 @@ def calc_engine_setting( # Each GPU in TP group has at least 1 kv head adjusted_num_kv_heads = max(tp_size, model_config.num_key_value_heads) - if is_nemotron_hybrid(model_config): + is_mamba_attn_hybrid = is_nemotron_hybrid(model_config) + if is_mamba_attn_hybrid: num_attention_layers = model_config.hybrid_override_pattern.count("*") else: num_attention_layers = model_config.num_hidden_layers @@ -80,7 +81,7 @@ def calc_engine_setting( # Calculate max requests in KV cache based on target ISL and OSL. target_seq_len = target_input_len + target_output_len - if is_nemotron_hybrid(model_config): + if is_mamba_attn_hybrid: num_mamba_layers = model_config.hybrid_override_pattern.count("M") conv_dim = model_config.mamba_config.d_inner + 2 * model_config.mamba_config.n_groups * model_config.mamba_config.d_state num_conv_state_elements = num_mamba_layers * conv_dim * ( @@ -90,6 +91,9 @@ def calc_engine_setting( byte_per_mamba_cache = byte_per_state_elem * ( num_conv_state_elements + num_ssm_state_elements) / (1024**3) + print(f"byte_per_mamba_cache: {byte_per_mamba_cache}") + + # Each mamba cache entry is pretty large (~50MB), so we are more conservative when estimating the max batch size kv_cache_gpu_mem_fraction *= kv_cache_gpu_mem_fraction else: byte_per_mamba_cache = 0 @@ -100,7 +104,7 @@ def calc_engine_setting( mamba_cache_memory = byte_per_mamba_cache * kv_cache_max_requests kv_cache_max_tokens = (cache_memory - mamba_cache_memory) / byte_per_token - if is_nemotron_hybrid(model_config): + if is_mamba_attn_hybrid: kv_cache_memory = kv_cache_max_tokens * byte_per_token logger.info( f"Estimated total cache memory: {cache_memory:.2f} GB. 
KV cache: {kv_cache_memory:.2f} GB, Mamba cache: {mamba_cache_memory:.2f} GB" From 63bea92171577da59702d9c6f833440088520982 Mon Sep 17 00:00:00 2001 From: Tomer Asida <57313761+tomeras91@users.noreply.github.com> Date: Thu, 19 Jun 2025 12:44:36 +0000 Subject: [PATCH 10/13] remove debug print Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com> --- tensorrt_llm/bench/build/tuning.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tensorrt_llm/bench/build/tuning.py b/tensorrt_llm/bench/build/tuning.py index 6b3c6b3f05..2cbe9becc9 100755 --- a/tensorrt_llm/bench/build/tuning.py +++ b/tensorrt_llm/bench/build/tuning.py @@ -91,8 +91,6 @@ def calc_engine_setting( byte_per_mamba_cache = byte_per_state_elem * ( num_conv_state_elements + num_ssm_state_elements) / (1024**3) - print(f"byte_per_mamba_cache: {byte_per_mamba_cache}") - # Each mamba cache entry is pretty large (~50MB), so we are more conservative when estimating the max batch size kv_cache_gpu_mem_fraction *= kv_cache_gpu_mem_fraction else: From c8c71dfa1e225fc1619a5b928d9e69a5bd7cd278 Mon Sep 17 00:00:00 2001 From: Tomer Asida <57313761+tomeras91@users.noreply.github.com> Date: Thu, 19 Jun 2025 16:13:39 +0000 Subject: [PATCH 11/13] fix - use config.get() only if config is a dict Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/config_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorrt_llm/_torch/pyexecutor/config_utils.py b/tensorrt_llm/_torch/pyexecutor/config_utils.py index 11a7c7a884..22205a7cd2 100644 --- a/tensorrt_llm/_torch/pyexecutor/config_utils.py +++ b/tensorrt_llm/_torch/pyexecutor/config_utils.py @@ -1,6 +1,6 @@ def is_nemotron_hybrid(config): return getattr(config, "hybrid_override_pattern", None) is not None \ - or config.get("hybrid_override_pattern", None) is not None + or (isinstance(config, dict) and config.get("hybrid_override_pattern", None) is not None) def is_mla(config): From 83e0673d29ae9a41c3c3625aa6d21bb51ae7b446 Mon Sep 17 00:00:00 2001 From: Tomer Asida <57313761+tomeras91@users.noreply.github.com> Date: Wed, 25 Jun 2025 07:55:22 +0000 Subject: [PATCH 12/13] optimistic tune max batch size only if not mamba attention hybrid model Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com> --- tensorrt_llm/bench/build/tuning.py | 38 +++++++++++++++++------------- 1 file changed, 22 insertions(+), 16 deletions(-) diff --git a/tensorrt_llm/bench/build/tuning.py b/tensorrt_llm/bench/build/tuning.py index 2cbe9becc9..d49315d51b 100755 --- a/tensorrt_llm/bench/build/tuning.py +++ b/tensorrt_llm/bench/build/tuning.py @@ -114,10 +114,13 @@ def calc_engine_setting( f"{kv_cache_max_requests:.2f}") # Fine-tune the max batch size and num token setting for performance. 
- max_batch_size, max_num_tokens = finetune_setting(kv_cache_max_requests, - target_input_len, - target_output_len, - pp_size) + # For mamba-attn hybrid models, we disable optimistic tuning because the mamba cache leaves less memory for the KV cache + max_batch_size, max_num_tokens = finetune_setting( + kv_cache_max_requests, + target_input_len, + target_output_len, + pp_size, + enable_optimistic_tuning=not is_mamba_attn_hybrid) # Functional and performance if total_gpu_memory < engine_size: @@ -156,12 +159,11 @@ def calc_engine_setting( return max_batch_size, max_num_tokens -def finetune_setting( - kv_cache_max_requests: float, - input_len: int, - output_len: int, - pp_size: int, -) -> Tuple[int, int]: +def finetune_setting(kv_cache_max_requests: float, + input_len: int, + output_len: int, + pp_size: int, + enable_optimistic_tuning: bool = True) -> Tuple[int, int]: """ Calculate and fine-tune the engine build settings (max batch size and max num tokens). Both max batch size and max num tokens are fine-tuned to be slightly optimistic. @@ -172,6 +174,7 @@ def finetune_setting( input_len (int): Input sequence length to compile the engine. output_len (int): Output sequence length to compile the engine. pp_size (int): Number of pipeline parallel stages. + enable_optimistic_tuning (bool): Whether to enable optimistic tuning. Returns: Tuple[int, int]: Tuple containing fine-tuned values for engine @@ -183,13 +186,16 @@ def finetune_setting( raw_token = min(raw_bs * (1 + input_len / output_len), 32768) # Fine-tune the max batch size. - # Set min BS to be 64. - if raw_bs < 256: - max_bs = max(64, 32 * math.ceil(raw_bs / 32)) - elif raw_bs < 1024: - max_bs = 128 * math.ceil(raw_bs / 128) + if enable_optimistic_tuning: + # Set min BS to be 64. + if raw_bs < 256: + max_bs = max(64, 32 * math.ceil(raw_bs / 32)) + elif raw_bs < 1024: + max_bs = 128 * math.ceil(raw_bs / 128) + else: + max_bs = 256 * math.ceil(raw_bs / 256) else: - max_bs = 256 * math.ceil(raw_bs / 256) + max_bs = 2 * math.floor(raw_bs / 2) # Fine-tune the max num tokens. # Set min to 2048 to ensure Ctx/Gen overlap efficiency From aa5d87c92e34be1858bcab984e6147ebb8c4f084 Mon Sep 17 00:00:00 2001 From: Tomer Asida <57313761+tomeras91@users.noreply.github.com> Date: Thu, 26 Jun 2025 13:50:58 +0000 Subject: [PATCH 13/13] fix: Mamba cache size estimation for FP8 - always use NO_QUANT for mamba cache bytes_per_elem since it's not quantized Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com> --- tensorrt_llm/bench/build/tuning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorrt_llm/bench/build/tuning.py b/tensorrt_llm/bench/build/tuning.py index d49315d51b..d1fa809bc8 100755 --- a/tensorrt_llm/bench/build/tuning.py +++ b/tensorrt_llm/bench/build/tuning.py @@ -87,7 +87,7 @@ def calc_engine_setting( num_conv_state_elements = num_mamba_layers * conv_dim * ( model_config.mamba_config.d_conv - 1) num_ssm_state_elements = num_mamba_layers * model_config.mamba_config.n_heads * model_config.mamba_config.head_dim * model_config.mamba_config.d_state - byte_per_state_elem = BYTES_PER_ELEM.get(quant_config.quant_algo, 2) + byte_per_state_elem = BYTES_PER_ELEM.get(QuantAlgo.NO_QUANT) byte_per_mamba_cache = byte_per_state_elem * ( num_conv_state_elements + num_ssm_state_elements) / (1024**3)
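A minimal standalone sketch (separate from the patches above, not part of the series) of the cache-budget arithmetic that tuning.py converges on by PATCH 13 for a mamba-attention hybrid model. Bytes-per-element is simplified to a flat 2 (fp16/bf16) for both the KV cache and the unquantized mamba state, and every numeric input in the usage block is a hypothetical placeholder rather than a real model configuration.

def estimate_hybrid_cache_budget(
        hybrid_override_pattern,      # "*" marks an attention layer, "M" a mamba layer
        num_key_value_heads, head_size, tp_size,
        d_inner, n_groups, d_state, d_conv, n_heads, head_dim,
        available_memory_gib, kv_cache_gpu_mem_fraction,
        target_input_len, target_output_len):
    gib = 1024 ** 3
    byte_per_kv_elem = 2      # assumed fp16/bf16 KV cache
    byte_per_state_elem = 2   # mamba state stays unquantized (PATCH 13), assumed 2 bytes

    # KV cache bytes per token: only attention layers ("*") hold K/V state.
    num_attention_layers = hybrid_override_pattern.count("*")
    adjusted_num_kv_heads = max(tp_size, num_key_value_heads)
    byte_per_token = (2 * num_attention_layers * adjusted_num_kv_heads
                      * head_size * byte_per_kv_elem) / gib

    # Mamba cache bytes per request: conv states plus SSM states for every "M" layer.
    num_mamba_layers = hybrid_override_pattern.count("M")
    conv_dim = d_inner + 2 * n_groups * d_state
    conv_elems = num_mamba_layers * conv_dim * (d_conv - 1)
    ssm_elems = num_mamba_layers * n_heads * head_dim * d_state
    byte_per_mamba_cache = byte_per_state_elem * (conv_elems + ssm_elems) / gib

    # Each mamba entry is large, so the memory fraction is squared to stay conservative.
    kv_cache_gpu_mem_fraction *= kv_cache_gpu_mem_fraction

    target_seq_len = target_input_len + target_output_len
    cache_memory = available_memory_gib * kv_cache_gpu_mem_fraction
    max_requests = cache_memory / (byte_per_token * target_seq_len + byte_per_mamba_cache)
    mamba_cache_gib = byte_per_mamba_cache * max_requests
    kv_cache_max_tokens = (cache_memory - mamba_cache_gib) / byte_per_token
    return max_requests, kv_cache_max_tokens


if __name__ == "__main__":
    # Hypothetical placeholder values, not a real model config.
    reqs, toks = estimate_hybrid_cache_budget(
        hybrid_override_pattern="M" * 46 + "*" * 6,
        num_key_value_heads=8, head_size=128, tp_size=1,
        d_inner=8192, n_groups=8, d_state=128, d_conv=4, n_heads=128, head_dim=64,
        available_memory_gib=60.0, kv_cache_gpu_mem_fraction=0.95,
        target_input_len=1024, target_output_len=1024)
    print(f"max requests ~ {reqs:.1f}, kv-cache tokens ~ {toks:.0f}")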