
Commit 43c9d88

sleep mode level 1
Signed-off-by: Kacper Pietkun <[email protected]>
1 parent e18a075 commit 43c9d88

File tree: 3 files changed, +158 −1 lines changed


tests/full_tests/sleep_mode.py

Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
from vllm import LLM, EngineArgs
from vllm.utils.argparse_utils import FlexibleArgumentParser


def create_parser():
    parser = FlexibleArgumentParser()
    # Add engine args
    EngineArgs.add_cli_args(parser)
    parser.set_defaults(model="Qwen/Qwen3-8B", enforce_eager=False)
    return parser


def print_outputs(outputs):
    print("-" * 50)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
    print("-" * 50)


def main(args):
    """
    End-to-end test of HPU sleep/wake-up support: instantiate a real LLM
    (backed by HPUWorker), generate, sleep, wake up in stages, and generate again.
    """
    llm = LLM(**args)

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    outputs = llm.generate(prompts)
    print_outputs(outputs)

    for _ in range(3):
        assert not llm.llm_engine.is_sleeping()
        llm.sleep()
        assert llm.llm_engine.is_sleeping()
        llm.wake_up(["weights"])
        # Only the weights are restored, so the engine still reports sleeping.
        assert llm.llm_engine.is_sleeping()
        llm.wake_up(["kv_cache"])
        assert not llm.llm_engine.is_sleeping()
        outputs = llm.generate(prompts)
        print_outputs(outputs)


if __name__ == "__main__":
    parser = create_parser()
    args: dict = vars(parser.parse_args())
    try:
        main(args)
    except Exception:
        import traceback
        print("An error occurred during generation:")
        traceback.print_exc()
        os._exit(1)
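
The loop above relies on partial wake-up semantics: after llm.wake_up(["weights"]) the engine still reports sleeping because the KV cache has not been restored yet. For reference, a hedged sketch (not part of this commit) of the same cycle with the defaults spelled out, continuing from the llm object created above:

# Hedged sketch: same cycle with defaults made explicit; llm.sleep() defaults to
# level=1, and llm.wake_up() with no tags restores weights and KV cache together
# (see the worker-side wake_up further below).
llm.sleep(level=1)
llm.wake_up()  # equivalent to llm.wake_up(["weights", "kv_cache"])
assert not llm.llm_engine.is_sleeping()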

vllm_gaudi/platform.py

Lines changed: 2 additions & 0 deletions
@@ -168,6 +168,8 @@ def get_nixl_memory_type(cls) -> str:
             return "VRAM"
         else:
             return "DRAM"
+    def is_sleep_mode_available(cls) -> bool:
+        return True
 
     @classmethod
     def set_torch_compile(cls) -> None:
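
With this override the Gaudi platform advertises sleep-mode support to vLLM. A minimal usage sketch, not part of this commit and assuming vLLM's current_platform helper exposes the method the same way as on other platforms:

# Hedged sketch: gate sleep mode on the platform capability before relying on it.
# Mirrors the engine arguments used by the test above; the import path of
# current_platform is an assumption about vLLM, not something this commit adds.
from vllm import LLM
from vllm.platforms import current_platform

if current_platform.is_sleep_mode_available():
    llm = LLM(model="Qwen/Qwen3-8B", enforce_eager=False)
    llm.sleep(level=1)  # level 1 is the only level the HPU worker accepts
    llm.wake_up()       # no tags: restore weights and reinitialize the KV cache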

vllm_gaudi/v1/worker/hpu_worker.py

Lines changed: 92 additions & 1 deletion
@@ -10,6 +10,7 @@
 import torch
 import torch.distributed
 import torch.nn as nn
+import habana_frameworks.torch as htorch
 from vllm.tasks import SupportedTask
 from vllm_gaudi.extension.debug import init_debug_logger
 from vllm_gaudi.extension.profiler import (HabanaMemoryProfiler, format_bytes, setup_profiler)
@@ -93,6 +94,10 @@ def __init__(
         self.step_profiler = setup_step_profiler(self.profile_steps)
         self.step_debug = init_debug_logger('steps')
 
+        self.model_sleeping = False
+        self.kv_cache_sleeping = False
+        self.kv_cache_config = None
+
     def init_profiler(self):
         """Initialize the profiler."""
         if envs.VLLM_TORCH_PROFILER_DIR:
@@ -233,6 +238,7 @@ def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
         """Allocate GPU KV cache with the specified kv_cache_config."""
 
         with HabanaMemoryProfiler() as m:
+            self.kv_cache_config = kv_cache_config
             self.model_runner.initialize_kv_cache(kv_cache_config)
             torch.hpu.synchronize()
         msg = (f"Usable num_blocks: {kv_cache_config.num_blocks}, "
@@ -316,6 +322,92 @@ def get_kv_connector_handshake_metadata(self) -> dict | None:
         tp_rank = get_tp_group().rank_in_group
         return {tp_rank: metadata}
 
+    def sleep(self, level: int = 1) -> None:
+        """Put the worker into sleep mode to reduce memory usage.
+
+        Unlike GPU workers, which use custom memory allocators, the HPU worker takes a
+        simpler approach: it moves the model to CPU and discards the KV cache.
+
+        Args:
+            level (int): Sleep level (kept for interface compatibility; level 1
+                operations are always performed).
+        """
+
+        assert level == 1, f"Currently, HPU supports only sleep mode level 1 (not level {level})"
+        assert not htorch.utils.internal.is_lazy(
+        ) or self.model_config.enforce_eager, "Sleep mode is supported only for torch.compile mode"
+
+        # Handle the model - if it was loaded, move it to CPU
+        if self.model_sleeping:
+            logger.warning("Model is already in sleep mode, skipping moving it to CPU")
+        elif not hasattr(self.model_runner, "model") or self.model_runner.model is None:
+            logger.warning("Model has not been loaded yet, skipping moving it to CPU")
+        else:
+            with HabanaMemoryProfiler() as m:
+                self.model_runner.model.to("cpu")
+                gc.collect()
+                torch.hpu.synchronize()
+            msg = f"Moving model to CPU for sleep mode took {m.get_summary_string()}"
+            logger.info(msg)
+            self.model_sleeping = True
+
+        # Handle the KV cache - discard it
+        if self.kv_cache_sleeping:
+            logger.warning("KV cache was already discarded by sleep() and has not yet been reinitialized by wake_up(), skipping discarding it again")
+        elif self.kv_cache_config is None:
+            logger.warning("KV cache has not been initialized yet, skipping discarding it")
+        else:
+            with HabanaMemoryProfiler() as m:
+                self.model_runner.kv_caches = []
+
+                forward_context = self.vllm_config.compilation_config.static_forward_context
+                for layer_name in forward_context:
+                    forward_context[layer_name].kv_cache = None
+
+                gc.collect()
+                torch.hpu.synchronize()
+            msg = f"Discarding KV cache for sleep mode took {m.get_summary_string()}"
+            logger.info(msg)
+            self.kv_cache_sleeping = True
+
+    def wake_up(self, tags: list[str] | None = None) -> None:
+        """Wake up the worker from sleep mode.
+
+        Depending on the tags, this moves the model back to HPU and/or reinitializes
+        the KV cache.
+
+        Args:
+            tags: Optional list of tags (kept for interface compatibility)
+        """
+        assert not htorch.utils.internal.is_lazy(
+        ) or self.model_config.enforce_eager, "Sleep mode is supported only for torch.compile mode"
+
+        if tags is None:
+            tags = ["weights", "kv_cache"]
+
+        # Handle the model - if it was loaded, move it back to HPU
+        if "weights" in tags:
+            if not self.model_sleeping:
+                logger.warning("Model is not in sleep mode, skipping moving it to HPU")
+            elif not hasattr(self.model_runner, "model") or self.model_runner.model is None:
+                logger.warning("Model has not been loaded yet, skipping moving it to HPU")
+            else:
+                with HabanaMemoryProfiler() as m:
+                    self.model_runner.model.to(self.vllm_config.device_config.device)
+                    gc.collect()
+                    torch.hpu.synchronize()
+                msg = f"Waking up model, moving it back to HPU took {m.get_summary_string()}"
+                logger.info(msg)
+                self.model_sleeping = False
+
+        # Handle the KV cache - reinitialize it
+        if "kv_cache" in tags:
+            if not self.kv_cache_sleeping:
+                logger.warning("KV cache is not in sleep mode, skipping reinitializing it")
+            elif self.kv_cache_config is None:
+                logger.warning("KV cache config is empty, skipping reinitializing KV cache")
+            else:
+                with HabanaMemoryProfiler() as m:
+                    self.model_runner.initialize_kv_cache(self.kv_cache_config)
+                    gc.collect()
+                    torch.hpu.synchronize()
+                msg = f"Waking up KV cache, reinitializing it took {m.get_summary_string()}"
+                logger.info(msg)
+                self.kv_cache_sleeping = False
 
 def init_worker_distributed_environment(
     vllm_config: VllmConfig,
@@ -338,7 +430,6 @@ def init_worker_distributed_environment(
 
 @contextmanager
 def track_graph_compile(name: str):
-    import habana_frameworks.torch as htorch
     from habana_frameworks.torch.hpu.metrics import metric_localcontext
     with metric_localcontext("graph_compilation") as gc:
         yield
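
The new worker methods are normally reached through llm.sleep() and llm.wake_up(). For debugging, a hedged sketch (not part of this commit) of invoking them directly on every worker, assuming vLLM's LLM.collective_rpc helper dispatches by method name and reusing the llm object from the test script above:

# Hedged sketch: call the worker-level methods directly via collective RPC.
# LLM.collective_rpc and its kwargs signature are assumptions about vLLM's API,
# not something introduced by this commit.
llm.collective_rpc("sleep", kwargs={"level": 1})
llm.collective_rpc("wake_up", kwargs={"tags": ["weights", "kv_cache"]})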
