[TRTLLM-5838][fix] fix max batch size and max tokens in kv cache estimations for Nemotron-H #5371
Open: tomeras91 wants to merge 19 commits into NVIDIA:main from tomeras91:fix-trtllm-bench-for-nemotron-h
Commits (19)
- 1752239 WIP: consider num_attention_layers for kv cache estimation and add ma… (tomeras91)
- 7829ec9 Merge branch 'main' into fix-trtllm-bench-for-nemotron-h (tomeras91)
- 4403183 organize code and logging for max batch size calculation for trtllm-b… (tomeras91)
- 6ff4602 consider only attention layers when estimating number of tokens in Kv… (tomeras91)
- e6615a8 propagate kv_cache_gpu_mem_fraction to calc_engine_setting for trtllm… (tomeras91)
- 42d65f3 release mamba cache memory when shutting down MambaCacheManager (and … (tomeras91)
- 17d22e5 small refactor - MambaCacheManager method names to match BaseResource… (tomeras91)
- 7dfeab8 refactor - is_nemotron_hybrid works on dicts as well (tomeras91)
- ee85bac remove log (tomeras91)
- d0d0b7e Add comment explaining squaring of kv_cache_gpu_mem_fraction + save r… (tomeras91)
- 63bea92 remove debug print (tomeras91)
- c8c71df fix - use config.get() only if config is a dict (tomeras91)
- 3e6a30e Merge branch 'main' into fix-trtllm-bench-for-nemotron-h (tomeras91)
- 83e0673 optimistic tune max batch size only if not mamba attention hybrid model (tomeras91)
- 4b2ba21 Merge branch 'main' into fix-trtllm-bench-for-nemotron-h (tomeras91)
- e6e65fc Merge branch 'main' into fix-trtllm-bench-for-nemotron-h (tomeras91)
- 8cf5ee7 Merge branch 'fix-trtllm-bench-for-nemotron-h' of github.com:tomeras9… (tomeras91)
- aa5d87c fix: Mamba cache size estimation for FP8 - always use NO_QUANT for ma… (tomeras91)
- ac481b2 Merge branch 'main' into fix-trtllm-bench-for-nemotron-h (tomeras91)
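Taken together, the commits make two changes to trtllm-bench's memory estimation for hybrid models: only attention layers are counted when estimating how many tokens fit in the KV cache, and the fixed-size Mamba cache is budgeted separately. A minimal sketch of the first idea, assuming Nemotron-H's hybrid_override_pattern convention where "*" marks attention layers ("M" Mamba, "-" MLP); names and signatures here are hypothetical, not the PR's code:

```python
from typing import Optional


# Hypothetical sketch, not the PR's exact code.
def num_attention_layers(hybrid_override_pattern: Optional[str],
                         num_hidden_layers: int) -> int:
    """Count only the layers that allocate per-token KV cache."""
    if hybrid_override_pattern is None:  # plain transformer: every layer attends
        return num_hidden_layers
    return hybrid_override_pattern.count("*")  # "*" marks attention layers


def estimate_max_kv_tokens(free_bytes: int, num_kv_heads: int, head_dim: int,
                           kv_dtype_bytes: int, attention_layers: int) -> int:
    """Token capacity: free memory over per-token KV bytes (K and V planes)."""
    bytes_per_token = 2 * num_kv_heads * head_dim * kv_dtype_bytes * attention_layers
    return free_bytes // bytes_per_token
```

Dividing by all layers, as a pure-transformer estimate would, overstates per-token KV bytes for a hybrid and so understates token capacity, which in turn skews the tuned max batch size and max num tokens.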
Files changed
```diff
@@ -14,6 +14,8 @@
 import json
 import struct
 
+from tensorrt_llm._torch.pyexecutor.config_utils import is_nemotron_hybrid
+
 
 def parse_safetensors_file_metadata(model_path, filename):
```
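The imported helper is also exercised on raw HF config dicts later in this diff (see from_hf below), a behavior touched by commits 7dfeab8 and c8c71df. A plausible shape for it, stated as an assumption rather than the library's actual implementation:

```python
# Assumed behavior of is_nemotron_hybrid, for illustration only: a model is
# treated as a Nemotron-H hybrid when its config carries a hybrid_override_pattern.
def is_nemotron_hybrid(config) -> bool:
    if isinstance(config, dict):  # use .get() only on dicts (commit c8c71df)
        return config.get("hybrid_override_pattern") is not None
    return getattr(config, "hybrid_override_pattern", None) is not None
```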
```diff
@@ -111,6 +113,29 @@ def _parse(filename: str) -> None:
     return huggingface_hub.get_safetensors_metadata(model_name_or_path)
 
 
+class MambaConfig(BaseModel):
+    d_model: int = Field(
+        validation_alias=AliasChoices("d_model", "hidden_size", "n_embd"))
+    d_state: int = Field(
+        validation_alias=AliasChoices("d_state", "ssm_state_size"))
+    d_conv: int = Field(validation_alias=AliasChoices("d_conv", "conv_kernel"))
+    expand: int
+    n_groups: int
+    head_dim: int = Field(
+        validation_alias=AliasChoices("head_dim", "mamba_head_dim"))
+    d_inner: int = Field(default=None)
+    n_heads: int = Field(default=None)
+
+    @model_validator(mode="after")
+    def set_values_if_none(self):
+        """ Set the values if cannot get values from HF config.json. """
+        if not self.d_inner:
+            self.d_inner = self.d_model * self.expand
+        if not self.n_heads:
+            self.n_heads = self.d_inner // self.head_dim
+        return self
+
+
 class ModelConfig(BaseModel):
     """ Model specific configurations. The parameters are needed in engine
     setting calculation.
```

Review comment on `self.d_inner = self.d_model * self.expand`: TODO: fix this for Nemotron-H-4B
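To see what the validator does, here is a small usage example with made-up Nemotron-H-style numbers; it assumes the MambaConfig class from the hunk above is in scope:

```python
# Values are illustrative only; real checkpoints ship their own config.json.
hf_config = {
    "hidden_size": 4096,    # matched to d_model via AliasChoices
    "ssm_state_size": 128,  # matched to d_state
    "conv_kernel": 4,       # matched to d_conv
    "expand": 2,
    "n_groups": 8,
    "mamba_head_dim": 64,   # matched to head_dim
}

cfg = MambaConfig(**hf_config)
print(cfg.d_inner)  # 8192 = d_model * expand, backfilled by the validator
print(cfg.n_heads)  # 128  = d_inner // head_dim, backfilled by the validator
```

Note the review TODO above: the d_inner = d_model * expand fallback reportedly does not hold for Nemotron-H-4B.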
```diff
@@ -161,6 +186,8 @@ class ModelConfig(BaseModel):
                        None] = Field(default="float16",
                                      validation_alias=AliasChoices(
                                          "dtype", "torch_dtype"))
+    hybrid_override_pattern: Optional[str] = Field(default=None)
+    mamba_config: Optional[MambaConfig] = Field(default=None)
 
     @model_validator(mode="after")
     def set_values_if_none(self):
```
```diff
@@ -193,4 +220,10 @@ def from_hf(cls, model_hf_name, hf_model_path):
             model_name_or_path, trust_remote_code=True).to_dict()
         param_count = cls.get_param_count(model_hf_name, hf_model_path)
 
-        return cls(name=model_hf_name, param_count=param_count, **hf_config)
+        mamba_config = MambaConfig(
+            **hf_config) if is_nemotron_hybrid(hf_config) else None
+
+        return cls(name=model_hf_name,
+                   param_count=param_count,
+                   mamba_config=mamba_config,
+                   **hf_config)
```
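With mamba_config attached to ModelConfig, the bench can budget the per-request Mamba cache separately from the token-based KV cache. A rough sketch of that sizing under Mamba2-style state shapes; the shapes are an assumption (the authoritative layout lives in TRT-LLM's MambaCacheManager), and per commit aa5d87c the Mamba cache is sized unquantized even for FP8 models:

```python
# Illustrative sketch, not MambaCacheManager's actual layout.
# cfg is a MambaConfig-like object from the diff above.
def mamba_cache_bytes_per_request(cfg, num_mamba_layers: int,
                                  dtype_bytes: int = 2) -> int:
    """Fixed-size recurrent state per request, independent of sequence length."""
    conv_dim = cfg.d_inner + 2 * cfg.n_groups * cfg.d_state
    conv_state = conv_dim * (cfg.d_conv - 1)               # rolling conv window
    ssm_state = cfg.n_heads * cfg.head_dim * cfg.d_state   # recurrent SSM state
    return (conv_state + ssm_state) * num_mamba_layers * dtype_bytes
```

Unlike the KV cache, this cost grows with batch size but not with sequence length, which is why hybrid models need their own max-batch-size reasoning (commit 83e0673 disables the optimistic batch-size tuning for them).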
Copilot review comment (nitpick): The variable num_attention_layers is being reassigned to represent the number of mapped pipeline layers instead of its original meaning. Consider using a new variable name (e.g., mapped_attention_layers) to preserve clarity and avoid confusion.