8 changes: 7 additions & 1 deletion tests/full_tests/ci_gsm8k_tests.sh
@@ -277,7 +277,12 @@ run_pd_disaggregate_nixl_libfabric_test() {
echo "✅ PD disaggregate through NIXL libfabric."
}


# sleep mode
run_sleep_mode_test() {
echo "Testing basic model with sleep mode / wake up functionality"
HABANA_VISIBLE_DEVICES=all VLLM_SKIP_WARMUP=true PT_HPU_LAZY_MODE=0 VLLM_ENABLE_V1_MULTIPROCESSING=0 python -u "${VLLM_GAUDI_PREFIX}/tests/full_tests/sleep_mode.py" --model facebook/opt-125m
echo "✅ Test with sleep mode passed."
}

# --- Script Entry Point ---

@@ -316,6 +321,7 @@ launch_all_tests() {
run_spec_decode_eagle3_test
run_spec_decode_eagle3_num_spec_2_test
run_llama3_70b_inc_dynamic_quant_test
run_sleep_mode_test
#run_embedding_model_test
echo "🎉 All test suites passed successfully!"
}
80 changes: 80 additions & 0 deletions tests/full_tests/sleep_mode.py
@@ -0,0 +1,80 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
from vllm import LLM, EngineArgs
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm_gaudi.extension.profiler import HabanaMemoryProfiler


def create_parser():
parser = FlexibleArgumentParser()
# Add engine args
EngineArgs.add_cli_args(parser)
parser.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct", enforce_eager=False)
return parser


def print_outputs(outputs):
print("-" * 50)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
print("-" * 50)


def assert_model_device(model_runner, target_device):
if model_runner:
params_devices = list({p.device for p in model_runner.model.parameters()})
assert len(params_devices) == 1
assert params_devices[0].type == target_device


def main(args):
"""
Test script to actually instantiate HPUWorker and test sleep/wakeup functionality.
This test creates a real HPUWorker instance and calls the methods.
"""
llm = LLM(**args)

prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]

# The driver worker's model_runner is only reachable in-process when V1
# multiprocessing is disabled (VLLM_ENABLE_V1_MULTIPROCESSING=0), as in the CI invocation.
multiproc = os.getenv("VLLM_ENABLE_V1_MULTIPROCESSING")
model_runner = None
if multiproc == "0":
model_runner = llm.llm_engine.model_executor.driver_worker.worker.model_runner

outputs = llm.generate(prompts)
print_outputs(outputs)

# Run a few sleep/wake cycles to make sure the transition is repeatable.
for _ in range(3):
with HabanaMemoryProfiler() as m:
llm.sleep()
assert m.consumed_device_memory < -60 * 1024 * 1024 * 1024 # check if more than 60GB was freed
assert_model_device(model_runner, "cpu")

with HabanaMemoryProfiler() as m:
llm.wake_up(["weights", "kv_cache"])
assert m.consumed_device_memory > 60 * 1024 * 1024 * 1024 # check if more than 60GB was allocated
assert_model_device(model_runner, "hpu")

outputs = llm.generate(prompts)
print_outputs(outputs)


if __name__ == "__main__":
parser = create_parser()
args: dict = vars(parser.parse_args())
try:
main(args)
except Exception:
import traceback
print("An error occurred during generation:")
traceback.print_exc()
os._exit(1)
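
The assertions above rely on the sign convention of HabanaMemoryProfiler.consumed_device_memory: a negative value means device memory was released inside the profiled block, a positive value means it was allocated. If more sleep/wake cycles or models are added later, the thresholds could be factored into small helpers; a minimal sketch under that same convention (helper names and the 60 GiB bound are illustrative, not part of this test):

GiB = 1024 ** 3

def assert_freed(profiler, min_gib: float = 60) -> None:
    # Negative delta: memory was released while the profiler was active.
    freed_gib = -profiler.consumed_device_memory / GiB
    assert freed_gib > min_gib, f"expected > {min_gib} GiB freed, saw {freed_gib:.1f} GiB"

def assert_allocated(profiler, min_gib: float = 60) -> None:
    # Positive delta: memory was allocated while the profiler was active.
    allocated_gib = profiler.consumed_device_memory / GiB
    assert allocated_gib > min_gib, f"expected > {min_gib} GiB allocated, saw {allocated_gib:.1f} GiB"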
3 changes: 3 additions & 0 deletions vllm_gaudi/platform.py
@@ -169,6 +169,9 @@ def get_nixl_memory_type(cls) -> str:
else:
return "DRAM"

def is_sleep_mode_available(cls) -> bool:
return True

@classmethod
def set_torch_compile(cls) -> None:
# NOTE: PT HPU lazy backend (PT_HPU_LAZY_MODE = 1)
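Returning True from is_sleep_mode_available is what makes sleep requests acceptable on Gaudi. As a rough caller-side illustration (a sketch, not part of this diff; it assumes the standard vllm.platforms.current_platform accessor and the LLM.sleep entrypoint exercised by the test above):

from vllm import LLM
from vllm.platforms import current_platform

def sleep_if_supported(llm: LLM, level: int = 1) -> bool:
    """Put the engine to sleep only when the active platform advertises support."""
    if not current_platform.is_sleep_mode_available():
        # Platforms that return False here would reject the sleep request upstream.
        return False
    llm.sleep(level=level)
    return True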
98 changes: 97 additions & 1 deletion vllm_gaudi/v1/worker/hpu_worker.py
@@ -10,8 +10,10 @@
import torch
import torch.distributed
import torch.nn as nn
import habana_frameworks.torch as htorch
from vllm.tasks import SupportedTask
from vllm_gaudi.extension.debug import init_debug_logger
from vllm_gaudi.extension.defragmentation import OnlineDefragmenter
from vllm_gaudi.extension.profiler import (HabanaMemoryProfiler, format_bytes, setup_profiler)
from vllm_gaudi.extension.runtime import get_config

@@ -93,6 +95,10 @@ def __init__(
self.step_profiler = setup_step_profiler(self.profile_steps)
self.step_debug = init_debug_logger('steps')

# Sleep-mode bookkeeping: track which components are currently offloaded/discarded so that
# sleep() and wake_up() stay idempotent, and keep the KV cache config around for re-initialization.
self.model_sleeping = False
self.kv_cache_sleeping = False
self.kv_cache_config = None

def init_profiler(self):
"""Initialize the profiler."""
if envs.VLLM_TORCH_PROFILER_DIR:
@@ -236,6 +242,7 @@ def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
"""Allocate GPU KV cache with the specified kv_cache_config."""

with HabanaMemoryProfiler() as m:
self.kv_cache_config = kv_cache_config
self.model_runner.initialize_kv_cache(kv_cache_config)
torch.hpu.synchronize()
msg = (f"Usable num_blocks: {kv_cache_config.num_blocks}, "
@@ -319,6 +326,96 @@ def get_kv_connector_handshake_metadata(self) -> dict | None:
tp_rank = get_tp_group().rank_in_group
return {tp_rank: metadata}

def sleep(self, level: int = 1) -> None:
"""Put the worker into sleep mode to reduce memory usage. Unlike GPU workers that use custom
memory allocators, HPU workers use a simpler approach of moving model to CPU and clearing KV cache.
Args:
level (int): Sleep level (kept for interface compatibility, always performs level 1 operations)
"""

if level == 2:
logger.warning("Currently, HPU does not support level 2 sleep mode. Performing level 1 operations")
assert not htorch.utils.internal.is_lazy(
) or self.model_config.enforce_eager, "Sleep mode is supported only for torch.compile mode"

# Handle model - if model was loaded move it to CPU
if self.model_sleeping:
logger.warning("Model is already in a sleep mode, skipping moving it to CPU")
elif not hasattr(self.model_runner, "model") or self.model_runner.model is None:
logger.warning("Model was not loaded yet, skipping moving it to CPU")
else:
with HabanaMemoryProfiler() as m:
self.model_runner.model.to("cpu")
gc.collect()
torch.hpu.synchronize()
msg = f"Moving model to CPU for sleep mode took {m.get_summary_string()}"
logger.info(msg)
self.model_sleeping = True

# Handle KV cache - discard it
if self.kv_cache_sleeping:
logger.warning("KV cache has already been discarded by calling sleep method and it has not been "
"reinitialized by calling wake up method yet, skipping discarding it again")
elif self.kv_cache_config is None:
logger.warning("KV cache has not been initialized yet, skipping discarding it")
else:
with HabanaMemoryProfiler() as m:
self.model_runner.defragmenter.cache_utils.kv_caches = None
self.model_runner.kv_caches = []
forward_context = self.vllm_config.compilation_config.static_forward_context
for layer_name in forward_context:
forward_context[layer_name].kv_cache = None
gc.collect()
torch.hpu.synchronize()
msg = f"Discarding KV cache for sleep mode took {m.get_summary_string()}"
logger.info(msg)
self.kv_cache_sleeping = True

def wake_up(self, tags: list[str] | None = None) -> None:
"""Wake up the worker from sleep mode.
It can move the model back to HPU and/or reinitialize KV cache.

Args:
tags: Optional list of tags (kept for interface compatibility)
"""
assert not htorch.utils.internal.is_lazy(
) or self.model_config.enforce_eager, "Sleep mode is supported only for torch.compile mode"

if tags is None:
tags = ["weights", "kv_cache"]

# Handle model - if model was loaded, move it back to HPU
if "weights" in tags:
if not self.model_sleeping:
logger.warning("Model is not in a sleep mode, skipping moving it to HPU")
elif not hasattr(self.model_runner, "model") or self.model_runner.model is None:
logger.warning("Model was not loaded yet, skipping moving it to HPU")
else:
with HabanaMemoryProfiler() as m:
self.model_runner.model.to(self.vllm_config.device_config.device)
gc.collect()
torch.hpu.synchronize()
msg = f"Waking up model, moving it back to HPU took {m.get_summary_string()}"
logger.info(msg)
self.model_sleeping = False

# Handle KV cache - reinitialize it
if "kv_cache" in tags:
if not self.kv_cache_sleeping:
logger.warning("KV cache is not in a sleep mode, skipping reinitializing it")
elif self.kv_cache_config is None:
logger.warning("KV cache config is empty, skipping reinitializing KV cache")
else:
with HabanaMemoryProfiler() as m:
self.model_runner.initialize_kv_cache(self.kv_cache_config)
self.model_runner.defragmenter = OnlineDefragmenter()
self.model_runner.defragmenter.initialize(self.model_runner.kv_caches, self.model_runner.block_size)
gc.collect()
torch.hpu.synchronize()
msg = f"Waking up KV cache, reinitializing it took {m.get_summary_string()}"
logger.info(msg)
self.kv_cache_sleeping = False
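
Taken together, the model_sleeping / kv_cache_sleeping flags make repeated calls idempotent and let the two components be woken independently. A short illustration of the intended call pattern at the worker level (a sketch only; worker stands for an already-initialized HPUWorker and is not code from this diff):

# Repeated sleep calls only log a warning; nothing is moved or freed twice.
worker.sleep(level=1)
worker.sleep(level=1)          # warns that the model / KV cache are already sleeping

# Components can be woken independently via tags.
worker.wake_up(["weights"])    # parameters return to HPU, KV cache stays discarded
worker.wake_up(["kv_cache"])   # KV cache is re-initialized from the stored kv_cache_config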


def init_worker_distributed_environment(
vllm_config: VllmConfig,
@@ -341,7 +438,6 @@ def init_worker_distributed_environment(

@contextmanager
def track_graph_compile(name: str):
import habana_frameworks.torch as htorch
from habana_frameworks.torch.hpu.metrics import metric_localcontext
with metric_localcontext("graph_compilation") as gc:
yield