Skip to content

Commit f5522d0

Browse files
Add get_device_total_memory method
Signed-off-by: Paweł Olejniczak <[email protected]>
1 parent 856f980 commit f5522d0

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

vllm_gaudi/platform.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,12 @@ def set_device(cls, device: torch.device) -> None:
8282
def get_device_name(cls, device_id: int = 0) -> str:
8383
return cls.device_name
8484

85+
@classmethod
86+
def get_device_total_memory(cls, device_id: int = 0) -> int:
87+
"""Get the total memory of a device in bytes."""
88+
total_hpu_memory = torch.hpu.mem_get_info()[1]
89+
return total_hpu_memory
90+
8591
@classmethod
8692
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
8793
parallel_config = vllm_config.parallel_config

0 commit comments

Comments
 (0)