Skip to content

Commit 2d2bae2

Browse files
[FIX] Suppression of stacktrace on a shutdown (#187)
--------- Signed-off-by: Wallas Santos <[email protected]> Co-authored-by: Travis Johnson <[email protected]>
1 parent 05fc3d9 commit 2d2bae2

File tree

3 files changed

+38
-0
lines changed

3 files changed

+38
-0
lines changed

tests/spyre_util.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,9 @@ def generate_spyre_vllm_output(model: str, prompts: list[str],
174174
str(val) for val in warmup_batch_size)
175175
os.environ['VLLM_SPYRE_DYNAMO_BACKEND'] = backend
176176
os.environ['VLLM_USE_V1'] = "1" if vllm_version == "V1" else "0"
177+
# Allows to run multiprocess V1 engine without dumping meaningless logs at
178+
# shutdown engine this context.
179+
os.environ['VLLM_SPYRE_OVERRIDE_SIGNALS_HANDLER'] = "1"
177180

178181
vllm_model = LLM(model=model,
179182
tokenizer=model,

vllm_spyre/envs.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
VLLM_SPYRE_RM_PADDED_BLOCKS: bool = False
1111
VLLM_SPYRE_PERF_METRIC_LOGGING_ENABLED: int = 0
1212
VLLM_SPYRE_PERF_METRIC_LOGGING_DIR: str = "/tmp"
13+
VLLM_SPYRE_OVERRIDE_SIGNALS_HANDLER: bool = False
1314

1415
# --8<-- [start:env-vars-definition]
1516
environment_variables: dict[str, Callable[[], Any]] = {
@@ -68,6 +69,12 @@
6869
# logs are written to /tmp.
6970
"VLLM_SPYRE_PERF_METRIC_LOGGING_DIR":
7071
lambda: os.getenv("VLLM_SPYRE_PERF_METRIC_LOGGING_DIR", "/tmp"),
72+
73+
# If set, override the signal handler for vllm-spyre on
74+
# vLLM V1 + torch_sendnn backend to be able to gracefully
75+
# shutdown the engine.
76+
"VLLM_SPYRE_OVERRIDE_SIGNALS_HANDLER":
77+
lambda: bool(int(os.getenv("VLLM_SPYRE_OVERRIDE_SIGNALS_HANDLER", "1"))),
7178
}
7279
# --8<-- [end:env-vars-definition]
7380

vllm_spyre/v1/worker/spyre_worker.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import json
33
import os
44
import platform
5+
import signal
56
import time
67
from typing import Optional, Union, cast
78

@@ -406,6 +407,8 @@ def _warmup_spyre_dynamic_size(self, special_token_ids):
406407
logger.info("Warmup finished.")
407408
logger.info("Warmup took %.3fs", warmup_total_t)
408409

410+
maybe_override_signals_handler()
411+
409412
def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
410413
special_token_ids, batch_size):
411414

@@ -524,6 +527,7 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
524527
logger.info(
525528
"Warmup took %.3fs (for prompt length %d and max output tokens %d)",
526529
warmup_total_t, prompt_len, num_decode_tokens)
530+
maybe_override_signals_handler()
527531

528532
def _warmup_model_forward_pass(
529533
self,
@@ -566,3 +570,27 @@ def execute_model(
566570
) -> Optional[ModelRunnerOutput]:
567571
output = self.model_runner.execute_model(scheduler_output)
568572
return output if self.is_driver_worker else None
573+
574+
575+
# Ref: https://github.com/vllm-project/vllm/blob/5fbbfe9a4c13094ad72ed3d6b4ef208a7ddc0fd7/vllm/v1/executor/multiproc_executor.py#L446 # noqa: E501
576+
# TODO: review this in the future
577+
# This setup is a workaround to suppress logs that are dumped at the shutdown
578+
# of the engine (only on V1) when vllm runs with multiprocess. The undesired
579+
# behavior happens because g3log from Spyre runtime overrides the signal
580+
# handler from vLLM when it starts a process for the engine code. Therefore,
581+
# the engine does not have a chance to gracefully shutdown.
582+
def maybe_override_signals_handler():
583+
if not (envs.VLLM_USE_V1 and envs.VLLM_ENABLE_V1_MULTIPROCESSING
584+
and envs_spyre.VLLM_SPYRE_OVERRIDE_SIGNALS_HANDLER):
585+
return
586+
587+
shutdown_requested = False
588+
589+
def signal_handler(signum, frame):
590+
nonlocal shutdown_requested
591+
if not shutdown_requested:
592+
shutdown_requested = True
593+
raise SystemExit()
594+
595+
signal.signal(signal.SIGTERM, signal_handler)
596+
signal.signal(signal.SIGINT, signal_handler)

0 commit comments

Comments
 (0)