Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions vllm_gaudi/platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,24 @@ def set_device(cls, device: torch.device) -> None:
def get_device_name(cls, device_id: int = 0) -> str:
return cls.device_name

@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
"""Get the total memory of a device in bytes."""
# NOTE: This is a workaround.
# The correct implementation of the method in this place should look as follows:
# total_hpu_memory = torch.hpu.mem_get_info()[1]
# A value of 0 is returned to preserve the current logic in
# vllm/vllm/engine/arg_utils.py → get_batch_defaults() →
# default_max_num_batched_tokens, in order to avoid the
# error in hpu_perf_test, while also preventing a
# NotImplementedError in test_defaults_with_usage_context.
logger.warning("This is a workaround! Please check the NOTE "
"in the get_device_total_memory definition.")

total_hpu_memory = 0

return total_hpu_memory

@classmethod
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
parallel_config = vllm_config.parallel_config
Expand Down