diff --git a/vllm_gaudi/platform.py b/vllm_gaudi/platform.py
index d05dce25..0f12ce30 100644
--- a/vllm_gaudi/platform.py
+++ b/vllm_gaudi/platform.py
@@ -82,6 +82,24 @@ def set_device(cls, device: torch.device) -> None:
     def get_device_name(cls, device_id: int = 0) -> str:
         return cls.device_name
 
+    @classmethod
+    def get_device_total_memory(cls, device_id: int = 0) -> int:
+        """Get the total memory of a device in bytes."""
+        # NOTE: This is a workaround.
+        # The correct implementation of this method would be:
+        #     total_hpu_memory = torch.hpu.mem_get_info()[1]
+        # A value of 0 is returned to preserve the current logic in
+        # vllm/vllm/engine/arg_utils.py -> get_batch_defaults() ->
+        # default_max_num_batched_tokens, in order to avoid the
+        # error in hpu_perf_test, while also preventing a
+        # NotImplementedError in test_defaults_with_usage_context.
+        logger.warning("This is a workaround! Please check the NOTE "
+                       "in the get_device_total_memory definition.")
+
+        total_hpu_memory = 0
+
+        return total_hpu_memory
+
     @classmethod
     def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         parallel_config = vllm_config.parallel_config