def _set_signal_handlers(self) -> None:
    """Install `_emergency_cleanup` as the handler for each supported signal.

    Walks the names listed in `SIGNALS_TO_HANDLE_IF_AVAILABLE`, skips any name
    the running platform's `signal.Signals` enum does not define, and
    registers `self._emergency_cleanup` for the rest. The handler that was
    installed before is remembered in `self._PREVIOUS_SIGNAL_HANDLERS` so it
    can be chained once our own cleanup has run.
    """
    for name in SIGNALS_TO_HANDLE_IF_AVAILABLE:
        sig = getattr(signal.Signals, name, None)
        if sig is None:
            # Signal is not defined on this platform.
            continue

        # HACK: Despite what the python documentation says, the existence of a
        # signal is not enough to guarantee that it can be caught.
        # `signal.signal` raises `ValueError` (e.g. when not called from the
        # main thread) or `OSError` (signal cannot be caught at all).
        try:
            previous_handler = signal.signal(sig, self._emergency_cleanup)
        except (ValueError, OSError):
            # Best effort: leave this signal alone and do not record a
            # previous handler for it, as nothing was replaced.
            continue

        self._PREVIOUS_SIGNAL_HANDLERS[sig] = previous_handler
""" + self._set_signal_handlers() _set_workers_neps_state(self.state) logger.info("Launching NePS") @@ -580,15 +600,21 @@ def run(self) -> None: # noqa: C901, PLR0912, PLR0915 continue # We (this worker) has managed to set it to evaluating, now we can evaluate it - with _set_global_trial(trial_to_eval): - evaluated_trial, report = evaluate_trial( - trial=trial_to_eval, - evaluation_fn=self.evaluation_fn, - default_report_values=self.settings.default_report_values, - ) - evaluation_duration = evaluated_trial.metadata.evaluation_duration - assert evaluation_duration is not None - self.worker_cumulative_evaluation_time_seconds += evaluation_duration + try: + with _set_global_trial(trial_to_eval): + evaluated_trial, report = evaluate_trial( + trial=trial_to_eval, + evaluation_fn=self.evaluation_fn, + default_report_values=self.settings.default_report_values, + ) + except KeyboardInterrupt as e: + # This throws and we have stopped the worker at this point + self._emergency_cleanup(signum=signal.SIGINT, frame=None, rethrow=e) + return + + evaluation_duration = evaluated_trial.metadata.evaluation_duration + assert evaluation_duration is not None + self.worker_cumulative_evaluation_time_seconds += evaluation_duration self.worker_cumulative_eval_count += 1 @@ -630,6 +656,39 @@ def run(self) -> None: # noqa: C901, PLR0912, PLR0915 "Learning Curve %s: %s", evaluated_trial.id, report.learning_curve ) + def _emergency_cleanup( + self, + signum: int, + frame: Any, + rethrow: KeyboardInterrupt | None = None, + ) -> None: + """Handle signals.""" + global _CURRENTLY_RUNNING_TRIAL_IN_PROCESS # noqa: PLW0603 + logger.error( + f"Worker '{self.worker_id}' received signal {signum}. Stopping worker now!" + ) + if _CURRENTLY_RUNNING_TRIAL_IN_PROCESS is not None: + logger.error( + "Worker '%s' was interrupted while evaluating trial: %s. 
Setting" + " trial to pending!", + self.worker_id, + _CURRENTLY_RUNNING_TRIAL_IN_PROCESS.id, + ) + _CURRENTLY_RUNNING_TRIAL_IN_PROCESS.reset() + try: + self.state.put_updated_trial(_CURRENTLY_RUNNING_TRIAL_IN_PROCESS) + except NePSError as e: + logger.exception(e) + finally: + _CURRENTLY_RUNNING_TRIAL_IN_PROCESS = None + + previous_handler = self._PREVIOUS_SIGNAL_HANDLERS.get(signum) + if previous_handler is not None and callable(previous_handler): + previous_handler(signum, frame) + if rethrow is not None: + raise rethrow + raise KeyboardInterrupt(f"Worker was interrupted by signal {signum}.") + def _launch_ddp_runtime( *, diff --git a/pyproject.toml b/pyproject.toml index cf1961d9..de3305d4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,6 +61,7 @@ dependencies = [ "ifbo>=0.3.10", "botorch>=0.12", "gpytorch==1.13.0", + "psutil>=7.0.0", ] [project.urls] diff --git a/tests/test_runtime/test_error_handling_strategies.py b/tests/test_runtime/test_error_handling_strategies.py index c1762a02..dd2b3de6 100644 --- a/tests/test_runtime/test_error_handling_strategies.py +++ b/tests/test_runtime/test_error_handling_strategies.py @@ -1,16 +1,20 @@ from __future__ import annotations import contextlib +import multiprocessing +import signal +import time from dataclasses import dataclass from pathlib import Path +import psutil import pytest from pytest_cases import fixture, parametrize from neps.exceptions import WorkerRaiseError from neps.optimizers import OptimizerInfo from neps.optimizers.algorithms import random_search -from neps.runtime import DefaultWorker +from neps.runtime import SIGNALS_TO_HANDLE_IF_AVAILABLE, DefaultWorker from neps.space import Float, SearchSpace from neps.state import ( DefaultReportValues, @@ -209,3 +213,89 @@ def __call__(self, *args, **kwargs) -> float: # noqa: ARG002 assert neps_state.lock_and_get_next_pending_trial() is None assert len(neps_state.lock_and_get_errors()) == 1 + + +def sleep_function(*args, **kwargs) -> float: + 
def sleep_function(*args, **kwargs) -> float:
    """Evaluation function that blocks long enough to be interrupted mid-run."""
    time.sleep(20)
    return 10


# Signals the worker claims to handle, restricted to those this platform defines.
SIGNALS: list[signal.Signals] = [
    getattr(signal.Signals, name)
    for name in SIGNALS_TO_HANDLE_IF_AVAILABLE
    if hasattr(signal.Signals, name)
]


@pytest.mark.parametrize("signum", SIGNALS)
def test_worker_reset_evaluating_to_pending_on_ctrl_c(
    signum: signal.Signals,
    neps_state: NePSState,
) -> None:
    """A worker signalled while evaluating must put its trial back to pending.

    A worker is started in a child process with an evaluation function that
    sleeps, the signal is sent while the single trial is in the evaluating
    state, and the trial is then expected to be pending again once the child
    process has exited.
    """
    # Decide whether to skip before spawning anything, so no worker process is
    # started (and left behind) for a case that is skipped anyway.
    if (
        signum == signal.SIGTERM
        and multiprocessing.get_start_method() == "spawn"
        and multiprocessing.current_process().name == "MainProcess"
    ):
        pytest.skip("SIGTERM is not supported on Windows with spawn start method")

    optimizer = random_search(SearchSpace({"a": Float(0, 1)}))
    settings = WorkerSettings(
        on_error=OnErrorPossibilities.IGNORE,  # <- Highlight
        default_report_values=DefaultReportValues(),
        max_evaluations_total=None,
        include_in_progress_evaluations_towards_maximum=False,
        max_cost_total=None,
        max_evaluations_for_worker=1,
        max_evaluation_time_total_seconds=None,
        max_wallclock_time_for_worker_seconds=None,
        max_evaluation_time_for_worker_seconds=None,
        max_cost_for_worker=None,
        batch_size=None,
    )

    worker1 = DefaultWorker.new(
        state=neps_state,
        optimizer=optimizer,
        evaluation_fn=sleep_function,
        settings=settings,
    )

    # Run the worker in its own process so that it can be signalled.
    p = multiprocessing.Process(target=worker1.run)
    p.start()
    try:
        # Give the worker time to sample a trial and start evaluating it.
        time.sleep(5)
        assert p.pid is not None
        assert p.is_alive()

        # Should be evaluating at this stage
        trials = neps_state.lock_and_read_trials()
        assert len(trials) == 1
        assert next(iter(trials.values())).metadata.state == Trial.State.EVALUATING

        # Interrupt the process while it's evaluating. If sending the signal
        # fails, skip the test, as most likely the signal is not supported on
        # this platform.
        try:
            psutil.Process(p.pid).send_signal(signum)
        except ValueError as e:
            pytest.skip(f"Signal error: {e}")

        p.join(timeout=5)  # Wait for the process to terminate
        if p.is_alive():
            pytest.fail("Worker did not terminate after receiving signal!")

        trials_after = neps_state.lock_and_read_trials()
        assert len(trials_after) == 1
        assert (
            next(iter(trials_after.values())).metadata.state == Trial.State.PENDING
        )
    finally:
        # Never leave the child behind, whichever way the test ends
        # (pass, fail or skip).
        if p.is_alive():
            p.terminate()
            p.join()