 import torch
 import torch.distributed
 import torch.nn as nn
+import habana_frameworks.torch as htorch
 from vllm.tasks import SupportedTask
 from vllm_gaudi.extension.debug import init_debug_logger
 from vllm_gaudi.extension.profiler import (HabanaMemoryProfiler, format_bytes, setup_profiler)
@@ -93,6 +94,10 @@ def __init__(
         self.step_profiler = setup_step_profiler(self.profile_steps)
         self.step_debug = init_debug_logger('steps')
 
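+        # Sleep-mode bookkeeping: whether the weights/KV cache are currently offloaded,
+        # plus the KV cache config captured so the cache can be rebuilt on wake_up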
+        self.model_sleeping = False
+        self.kv_cache_sleeping = False
+        self.kv_cache_config = None
+
     def init_profiler(self):
         """Initialize the profiler."""
         if envs.VLLM_TORCH_PROFILER_DIR:
@@ -233,6 +238,7 @@ def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
233238 """Allocate GPU KV cache with the specified kv_cache_config."""
234239
235240 with HabanaMemoryProfiler () as m :
241+ self .kv_cache_config = kv_cache_config
236242 self .model_runner .initialize_kv_cache (kv_cache_config )
237243 torch .hpu .synchronize ()
238244 msg = (f"Usable num_blocks: { kv_cache_config .num_blocks } , "
@@ -316,6 +322,92 @@ def get_kv_connector_handshake_metadata(self) -> dict | None:
         tp_rank = get_tp_group().rank_in_group
         return {tp_rank: metadata}
 
+    def sleep(self, level: int = 1) -> None:
+        """Put the worker into sleep mode to reduce memory usage.
+
+        Unlike GPU workers, which rely on custom memory allocators, HPU workers take a
+        simpler approach: the model is moved to CPU and the KV cache is discarded.
+
+        Args:
+            level (int): Sleep level (kept for interface compatibility; level 1
+                operations are always performed).
+        """
+
+        assert level == 1, f"HPU supports only sleep mode level 1, got level {level}"
+        assert not htorch.utils.internal.is_lazy() or self.model_config.enforce_eager, \
+            "Sleep mode is supported only in torch.compile mode"
+
+        # Handle model - if the model was loaded, move it to CPU
+        if self.model_sleeping:
+            logger.warning("Model is already in sleep mode, skipping moving it to CPU")
+        elif not hasattr(self.model_runner, "model") or self.model_runner.model is None:
+            logger.warning("Model has not been loaded yet, skipping moving it to CPU")
+        else:
+            with HabanaMemoryProfiler() as m:
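+                # Copying the parameters to CPU releases their HPU allocations once collected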
+                self.model_runner.model.to("cpu")
+                gc.collect()
+                torch.hpu.synchronize()
+            msg = f"Moving model to CPU for sleep mode took {m.get_summary_string()}"
+            logger.info(msg)
+            self.model_sleeping = True
+
+        # Handle KV cache - discard it
+        if self.kv_cache_sleeping:
+            logger.warning("KV cache has already been discarded by a previous sleep call and "
+                           "has not been reinitialized by wake_up yet, skipping discarding it again")
+        elif self.kv_cache_config is None:
+            logger.warning("KV cache has not been initialized yet, skipping discarding it")
+        else:
+            with HabanaMemoryProfiler() as m:
+                self.model_runner.kv_caches = []
+
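+                # Also drop the per-layer references held in the static forward context,
+                # otherwise the KV cache tensors stay alive and cannot be freed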
+                forward_context = self.vllm_config.compilation_config.static_forward_context
+                for layer_name in forward_context:
+                    forward_context[layer_name].kv_cache = None
+
+                gc.collect()
+                torch.hpu.synchronize()
+            msg = f"Discarding KV cache for sleep mode took {m.get_summary_string()}"
+            logger.info(msg)
+            self.kv_cache_sleeping = True
+
+    def wake_up(self, tags: list[str] | None = None) -> None:
+        """Wake up the worker from sleep mode.
+
+        Depending on the requested tags, this moves the model back to HPU and/or
+        reinitializes the KV cache.
+
+        Args:
+            tags: Optional list of tags (kept for interface compatibility).
+                Defaults to ["weights", "kv_cache"].
+        """
+        assert not htorch.utils.internal.is_lazy() or self.model_config.enforce_eager, \
+            "Sleep mode is supported only in torch.compile mode"
+
+        if tags is None:
+            tags = ["weights", "kv_cache"]
+
+        # Handle model - if the model was loaded, move it back to HPU
+        if "weights" in tags:
+            if not self.model_sleeping:
+                logger.warning("Model is not in sleep mode, skipping moving it to HPU")
+            elif not hasattr(self.model_runner, "model") or self.model_runner.model is None:
+                logger.warning("Model has not been loaded yet, skipping moving it to HPU")
+            else:
+                with HabanaMemoryProfiler() as m:
+                    self.model_runner.model.to(self.vllm_config.device_config.device)
+                    gc.collect()
+                    torch.hpu.synchronize()
+                msg = f"Waking up model, moving it back to HPU took {m.get_summary_string()}"
+                logger.info(msg)
+                self.model_sleeping = False
+
+        # Handle KV cache - reinitialize it
+        if "kv_cache" in tags:
+            if not self.kv_cache_sleeping:
+                logger.warning("KV cache is not in sleep mode, skipping reinitializing it")
+            elif self.kv_cache_config is None:
+                logger.warning("KV cache config is empty, skipping reinitializing KV cache")
+            else:
+                with HabanaMemoryProfiler() as m:
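+                    # Rebuild the KV cache on HPU from the config captured in initialize_from_config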
+                    self.model_runner.initialize_kv_cache(self.kv_cache_config)
+                    gc.collect()
+                    torch.hpu.synchronize()
+                msg = f"Waking up KV cache, reinitializing it took {m.get_summary_string()}"
+                logger.info(msg)
+                self.kv_cache_sleeping = False
 
 def init_worker_distributed_environment(
     vllm_config: VllmConfig,
@@ -338,7 +430,6 @@ def init_worker_distributed_environment(
 
 @contextmanager
 def track_graph_compile(name: str):
-    import habana_frameworks.torch as htorch
     from habana_frameworks.torch.hpu.metrics import metric_localcontext
     with metric_localcontext("graph_compilation") as gc:
         yield
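
For context, a minimal usage sketch of how these worker methods get driven. It assumes vLLM's standard sleep-mode entry points (`LLM(..., enable_sleep_mode=True)`, `llm.sleep(level=1)`, `llm.wake_up()`), which dispatch to the worker's `sleep`/`wake_up` above; the model name is a placeholder:

    from vllm import LLM

    # enable_sleep_mode is required for sleep()/wake_up() to be available
    llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct", enable_sleep_mode=True)
    llm.generate("Hello")

    # Level 1: on HPU this moves the weights to CPU and discards the KV cache
    llm.sleep(level=1)

    # ... the freed HPU memory can be used by another workload here ...

    # Move the weights back to HPU and reinitialize the KV cache
    llm.wake_up()  # equivalent to llm.wake_up(tags=["weights", "kv_cache"])
    llm.generate("Hello again")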