 import torch
 import torch.distributed
 import torch.nn as nn
+import habana_frameworks.torch as htorch
 from vllm.tasks import SupportedTask
 from vllm_gaudi.extension.profiler import HabanaMemoryProfiler, format_bytes
 
@@ -77,6 +78,9 @@ def __init__(
         self.gc_track_recompiles = bool(
             "PT_HPU_METRICS_GC_DETAILS" in os.environ
             and bool_helper(os.getenv("PT_HPU_METRICS_GC_DETAILS")))
+
+        # Track sleep state per resource so wake_up() can restore weights
+        # and the KV cache independently via its `tags` argument
+        self.model_sleeping = False
+        self.kv_cache_sleeping = False
 
     def init_profiler(self):
         """Initialize the profiler."""
@@ -273,6 +277,88 @@ def profile(self, is_start: bool = True):
         else:
             self.profiler.stop()
 
+    def sleep(self, level: int = 1) -> None:
+        """Put the worker into sleep mode to reduce memory usage.
+
+        Unlike GPU workers, which use custom memory allocators, HPU workers
+        take a simpler approach: move the model to CPU and discard the KV cache.
+
+        Args:
+            level (int): Sleep level (kept for interface compatibility;
+                level 1 operations are always performed).
+        """
+
+        # Not supported in lazy mode unless eager execution is enforced
+        assert not htorch.utils.internal.is_lazy(
+        ) or self.model_config.enforce_eager, "Sleep mode is supported only for torch.compile mode"
+
+        # Handle model - if the model was loaded, move it to CPU
+        if self.model_sleeping:
+            logger.warning("Model is already in sleep mode, skipping moving it to CPU")
+        elif hasattr(self.model_runner, "model") and self.model_runner.model is not None:
+            with HabanaMemoryProfiler() as m:
+                self.model_runner.model.to("cpu")
+                torch.hpu.synchronize()
+            msg = f"Moving model to CPU for sleep mode took {m.get_summary_string()}"
+            logger.info(msg)
+            self.model_sleeping = True
+        else:
+            logger.warning("Model was not loaded yet, skipping moving it to CPU")
+
+        # Handle KV cache - discard it
+        if self.kv_cache_sleeping:
+            logger.warning("KV cache is already in sleep mode, skipping discarding it")
+        else:
+            with HabanaMemoryProfiler() as m:
+                for ve in range(self.parallel_config.pipeline_parallel_size):
+                    del self.cache_engine[ve].gpu_cache
+                    del self.cache_engine[ve].cpu_cache
+                self.cache_engine.clear()
+                self.hpu_cache.clear()
+                self.hpu_cache = None
+                # Swap per-layer KV cache references for empty placeholder
+                # tensors so no layer keeps the freed HPU memory alive
+                for layer_name in self.compilation_config.static_forward_context:
+                    self.compilation_config.static_forward_context[layer_name].kv_cache.clear()
+                    self.compilation_config.static_forward_context[layer_name].kv_cache = [
+                        torch.tensor([]) for _ in range(self.parallel_config.pipeline_parallel_size)
+                    ]
+                torch.hpu.synchronize()
+            msg = f"Discarding KV cache for sleep mode took {m.get_summary_string()}"
+            logger.info(msg)
+            self.kv_cache_sleeping = True
+
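+    # Caller-side sketch of entering sleep mode (hypothetical; `worker` is
+    # an assumed, already-initialized instance, not defined in this diff):
+    #
+    #     worker.sleep()  # weights moved to CPU, KV cache freed on HPU
+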
+    def wake_up(self, tags: list[str] | None = None) -> None:
+        """Wake up the worker from sleep mode.
+
+        Moves the model back to HPU and/or reinitializes the KV cache,
+        depending on the requested tags.
+
+        Args:
+            tags: Optional list of resources to restore ("weights",
+                "kv_cache"); defaults to both.
+        """
+        assert not htorch.utils.internal.is_lazy(
+        ) or self.model_config.enforce_eager, "Sleep mode is supported only for torch.compile mode"
+
+        if tags is None:
+            tags = ["weights", "kv_cache"]
+
+        # Handle model - if the model was loaded, move it back to HPU
+        if "weights" in tags:
+            if not self.model_sleeping:
+                logger.warning("Model is not in sleep mode, skipping moving it to HPU")
+            elif hasattr(self.model_runner, "model") and self.model_runner.model is not None:
+                with HabanaMemoryProfiler() as m:
+                    self.model_runner.model.to(self.device)
+                    torch.hpu.synchronize()
+                msg = f"Waking up model, moving it back to HPU took {m.get_summary_string()}"
+                logger.info(msg)
+                self.model_sleeping = False
+            else:
+                logger.warning("Model was not loaded yet, skipping moving it to HPU")
+
+        # Handle KV cache - reinitialize it
+        if "kv_cache" in tags:
+            if not self.kv_cache_sleeping:
+                logger.warning("KV cache is not in sleep mode, skipping reinitializing it")
+            else:
+                with HabanaMemoryProfiler() as m:
+                    self._init_cache_engine()
+                    torch.hpu.synchronize()
+                msg = f"Waking up KV cache, reinitializing it took {m.get_summary_string()}"
+                logger.info(msg)
+                self.kv_cache_sleeping = False
+
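+    # Selective wake-up sketch (hypothetical; mirrors the `tags` handling
+    # above, with `worker` an assumed instance not defined in this diff):
+    #
+    #     worker.wake_up(tags=["weights"])   # restore weights only
+    #     worker.wake_up(tags=["kv_cache"])  # rebuild the KV cache only
+    #     worker.wake_up()                   # default: restore both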
 
 def init_worker_distributed_environment(
     parallel_config: ParallelConfig,
@@ -297,7 +383,6 @@ def init_worker_distributed_environment(
 
 @contextmanager
 def track_graph_compile(name: str):
-    import habana_frameworks.torch as htorch
     from habana_frameworks.torch.hpu.metrics import metric_localcontext
     with metric_localcontext("graph_compilation") as gc:
         yield