
Commit d6922a2

sleep mode level 1
Signed-off-by: Kacper Pietkun <[email protected]>
1 parent ab65f9b commit d6922a2

2 files changed (+90 / -1 lines)


vllm_gaudi/platform.py

Lines changed: 4 additions & 0 deletions
@@ -138,6 +138,10 @@ def supports_v1(cls, model_config: ModelConfig) -> bool:
         # V1 support on HPU is experimental
         return True
 
+    @classmethod
+    def is_sleep_mode_available(cls) -> bool:
+        return True
+
     @classmethod
     def set_torch_compile(cls) -> None:
         # NOTE: PT HPU lazy backend (PT_HPU_LAZY_MODE = 1)
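
The new classmethod simply advertises sleep-mode support to the rest of vLLM's platform layer. As a minimal sketch of how a caller might gate the feature on this flag (the `current_platform` accessor is part of vLLM, while the helper function and error message below are illustrative only, not code from this commit):

    from vllm.platforms import current_platform

    def validate_sleep_mode(enable_sleep_mode: bool) -> None:
        # With this commit, HpuPlatform.is_sleep_mode_available() returns True,
        # so requesting sleep mode on Gaudi would pass a check like this one.
        if enable_sleep_mode and not current_platform.is_sleep_mode_available():
            raise ValueError("Sleep mode is not supported on the current platform")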

vllm_gaudi/v1/worker/hpu_worker.py

Lines changed: 86 additions & 1 deletion
@@ -10,6 +10,7 @@
 import torch
 import torch.distributed
 import torch.nn as nn
+import habana_frameworks.torch as htorch
 from vllm.tasks import SupportedTask
 from vllm_gaudi.extension.profiler import HabanaMemoryProfiler, format_bytes

@@ -77,6 +78,9 @@ def __init__(
         self.gc_track_recompiles = bool(
             "PT_HPU_METRICS_GC_DETAILS" in os.environ
             and bool_helper(os.getenv("PT_HPU_METRICS_GC_DETAILS")))
+
+        self.model_sleeping = False
+        self.kv_cache_sleeping = False
 
     def init_profiler(self):
         """Initialize the profiler."""
@@ -273,6 +277,88 @@ def profile(self, is_start: bool = True):
         else:
             self.profiler.stop()
 
+    def sleep(self, level: int = 1) -> None:
+        """Put the worker into sleep mode to reduce memory usage. Unlike GPU workers, which use custom
+        memory allocators, HPU workers use a simpler approach of moving the model to CPU and clearing the KV cache.
+        Args:
+            level (int): Sleep level (kept for interface compatibility; level 1 operations are always performed)
+        """
+
+        assert not htorch.utils.internal.is_lazy(
+        ) or self.model_config.enforce_eager, "Sleep mode is supported only for torch.compile mode"
+
+        # Handle model - if the model was loaded, move it to CPU
+        if self.model_sleeping:
+            logger.warning("Model is already in sleep mode, skipping moving it to CPU")
+        elif hasattr(self.model_runner, "model") and self.model_runner.model is not None:
+            with HabanaMemoryProfiler() as m:
+                self.model_runner.model.to("cpu")
+                torch.hpu.synchronize()
+            msg = f"Moving model to CPU for sleep mode took {m.get_summary_string()}"
+            logger.info(msg)
+            self.model_sleeping = True
+        else:
+            logger.warning("Model was not loaded yet, skipping moving it to CPU")
+
+        # Handle KV cache - discard it
+        if self.kv_cache_sleeping:
+            logger.warning("KV cache is already in sleep mode, skipping discarding it")
+        else:
+            with HabanaMemoryProfiler() as m:
+                for ve in range(self.parallel_config.pipeline_parallel_size):
+                    del self.cache_engine[ve].gpu_cache
+                    del self.cache_engine[ve].cpu_cache
+                self.cache_engine.clear()
+                self.hpu_cache.clear()
+                self.hpu_cache = None
+                for layer_name in self.compilation_config.static_forward_context:
+                    self.compilation_config.static_forward_context[layer_name].kv_cache.clear()
+                    self.compilation_config.static_forward_context[layer_name].kv_cache = [
+                        torch.tensor([]) for _ in range(self.parallel_config.pipeline_parallel_size)
+                    ]
+                torch.hpu.synchronize()
+            msg = f"Discarding KV cache for sleep mode took {m.get_summary_string()}"
+            logger.info(msg)
+            self.kv_cache_sleeping = True
+
+    def wake_up(self, tags: list[str] | None = None) -> None:
+        """Wake up the worker from sleep mode. Moves the model back to HPU and optionally reinitializes the KV cache.
+
+        Args:
+            tags: Optional list of tags (kept for interface compatibility)
+        """
+        assert not htorch.utils.internal.is_lazy(
+        ) or self.model_config.enforce_eager, "Sleep mode is supported only for torch.compile mode"
+
+        if tags is None:
+            tags = ["weights", "kv_cache"]
+
+        # Handle model - if the model was loaded, move it back to HPU
+        if "weights" in tags:
+            if not self.model_sleeping:
+                logger.warning("Model is not in sleep mode, skipping moving it to HPU")
+            elif hasattr(self.model_runner, "model") and self.model_runner.model is not None:
+                with HabanaMemoryProfiler() as m:
+                    self.model_runner.model.to(self.device)
+                    torch.hpu.synchronize()
+                msg = f"Waking up model, moving it back to HPU took {m.get_summary_string()}"
+                logger.info(msg)
+                self.model_sleeping = False
+            else:
+                logger.warning("Model was not loaded yet, skipping moving it to HPU")
+
+        # Handle KV cache - reinitialize it
+        if "kv_cache" in tags:
+            if not self.kv_cache_sleeping:
+                logger.warning("KV cache is not in sleep mode, skipping reinitializing it")
+            else:
+                with HabanaMemoryProfiler() as m:
+                    self._init_cache_engine()
+                    torch.hpu.synchronize()
+                msg = f"Waking up KV cache, reinitializing it took {m.get_summary_string()}"
+                logger.info(msg)
+                self.kv_cache_sleeping = False
+
 
 def init_worker_distributed_environment(
     parallel_config: ParallelConfig,
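
To make the wake-up tag semantics above concrete, a hypothetical call sequence against an already-initialized worker; `worker` stands for an HPUWorker with a loaded model and an allocated KV cache, and in practice these methods are driven by the engine rather than called directly:

    worker.sleep(level=1)               # weights -> CPU, KV cache discarded
    worker.wake_up(tags=["weights"])    # restore only the weights on HPU
    worker.wake_up(tags=["kv_cache"])   # then reallocate the KV cache
    # worker.wake_up() with no tags restores both in a single call
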
@@ -297,7 +383,6 @@ def init_worker_distributed_environment(
 
 @contextmanager
 def track_graph_compile(name: str):
-    import habana_frameworks.torch as htorch
     from habana_frameworks.torch.hpu.metrics import metric_localcontext
     with metric_localcontext("graph_compilation") as gc:
         yield
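
End to end, the feature is normally reached through vLLM's public sleep/wake API rather than by calling the worker methods directly. A rough usage sketch, assuming an offline `LLM` instance; the model name is a placeholder and the behavior on Gaudi follows the level-1 semantics implemented above:

    from vllm import LLM

    # enable_sleep_mode exposes sleep()/wake_up() on the engine; with this commit
    # the platform check passes on Gaudi. Note that, per the asserts above, HPU
    # sleep mode expects torch.compile mode (or enforce_eager), not lazy mode.
    llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", enable_sleep_mode=True)

    llm.generate(["Hello"])   # normal inference
    llm.sleep(level=1)        # free HPU memory between bursts of work
    llm.wake_up()             # restore weights and KV cache before the next batch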
