Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 1 addition & 7 deletions tensorrt_llm/_torch/pyexecutor/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,6 @@ def __init__(
self._kv_cache_manager_cls = get_kv_cache_manager_cls(
model_engine.model.model_config)

def _get_free_gpu_memory_fraction(self) -> float:
    """Return the configured KV-cache GPU memory fraction.

    Reads ``free_gpu_memory_fraction`` from ``self._kv_cache_config`` and
    falls back to 0.9 when the configuration leaves it as ``None``. An
    explicit 0.0 is preserved because only ``None`` triggers the fallback.
    """
    configured = self._kv_cache_config.free_gpu_memory_fraction
    return 0.9 if configured is None else configured

def _get_kv_size_per_token(self):
model_config = self._model_engine.model.model_config
mapping = self._mapping
Expand Down Expand Up @@ -299,7 +293,7 @@ def configure_kv_cache_capacity(self, py_executor: PyExecutor) -> None:
# TODO: support CP by generating dummy requests for it.
assert 'cp_type' not in mapping.cp_config

fraction = self._get_free_gpu_memory_fraction()
fraction = self._kv_cache_config.free_gpu_memory_fraction

torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
Expand Down
12 changes: 11 additions & 1 deletion tensorrt_llm/llmapi/llm_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -1259,7 +1259,7 @@ class KvCacheConfig(StrictBaseModel, PybindMirror):
description=
"Number of sink tokens (tokens to always keep in attention window).")
free_gpu_memory_fraction: Optional[float] = Field(
default=None,
default=0.9,
description=
"The fraction of GPU memory fraction that should be allocated for the KV cache. Default is 90%. If both `max_tokens` and `free_gpu_memory_fraction` are specified, memory corresponding to the minimum will be used."
)
Expand Down Expand Up @@ -1341,6 +1341,16 @@ def _to_pybind(self):
attention_dp_events_gather_period_ms,
max_gpu_total_bytes=self.max_gpu_total_bytes)

@field_validator('free_gpu_memory_fraction')
@classmethod
def validate_free_gpu_memory_fraction(cls, v: Optional[float]):
    """Validate that the KV-cache memory fraction lies in [0, 1].

    The field is declared as ``Optional[float]``, so an explicit ``None``
    can reach this validator. ``None`` is returned unchanged rather than
    being compared against the bounds, which would raise a ``TypeError``
    instead of a validation error.

    Args:
        v: Candidate value for ``free_gpu_memory_fraction``.

    Returns:
        ``v`` unchanged when it is ``None`` or within the closed interval
        [0, 1].

    Raises:
        ValueError: If ``v`` is a number outside [0, 1].
    """
    # NOTE(review): code that reads this field must tolerate None when it is
    # passed explicitly -- confirm against the consumers of KvCacheConfig.
    if v is None:
        return v
    if not 0 <= v <= 1:
        raise ValueError(
            "kv_cache_config.free_gpu_memory_fraction must be a float between 0 and 1"
        )
    return v

@field_validator('max_gpu_total_bytes')
@classmethod
def validate_max_gpu_total_bytes(cls, v: int):
Expand Down
Loading