diff --git a/nemo_curator/core/constants.py b/nemo_curator/core/constants.py index b76a11f52..f62640003 100644 --- a/nemo_curator/core/constants.py +++ b/nemo_curator/core/constants.py @@ -24,3 +24,4 @@ # We cannot use a free port between 10000 and 19999 as it is used by Ray. DEFAULT_RAY_MIN_WORKER_PORT = 10002 DEFAULT_RAY_MAX_WORKER_PORT = 19999 +RAY_CLUSTER_START_VERIFICATION_TIMEOUT = 300 diff --git a/nemo_curator/core/utils.py b/nemo_curator/core/utils.py index d1066b18a..98e4b25a7 100644 --- a/nemo_curator/core/utils.py +++ b/nemo_curator/core/utils.py @@ -18,13 +18,16 @@ from typing import TYPE_CHECKING import ray +import tenacity from loguru import logger +from ray._private.services import canonicalize_bootstrap_address, find_gcs_addresses from nemo_curator.core.constants import ( DEFAULT_RAY_AUTOSCALER_METRIC_PORT, DEFAULT_RAY_DASHBOARD_METRIC_PORT, DEFAULT_RAY_MAX_WORKER_PORT, DEFAULT_RAY_MIN_WORKER_PORT, + RAY_CLUSTER_START_VERIFICATION_TIMEOUT, ) if TYPE_CHECKING: @@ -68,6 +71,47 @@ def _logger_custom_deserializer( return logger +@tenacity.retry( + wait=tenacity.wait_fixed(1), + stop=tenacity.stop_after_delay(RAY_CLUSTER_START_VERIFICATION_TIMEOUT), + retry=tenacity.retry_if_result(lambda x: x is False), + reraise=True, +) +def _verify_gcs_running(expected_address: str, proc: subprocess.Popen) -> bool: + """Verify that the Ray GCS is running at the expected address. + + Args: + expected_address: The expected GCS address (ip:port format) + proc: The subprocess running the Ray cluster + + Returns: + True if GCS is running at expected address, False otherwise + + Raises: + RuntimeError: If the Ray process exited with an error + """ + # Check if the process exited with an error + returncode = proc.poll() + if returncode is not None: + msg = f"Ray cluster failed to start. Process exited with code {returncode}." + logger.error(msg) + raise RuntimeError(msg) + + # Check if GCS is running at the expected address + gcs_addresses = find_gcs_addresses() + if gcs_addresses: + # Canonicalize both addresses for comparison + canonical_gcs_addresses = [] + for gcs_address in gcs_addresses: + canonical_gcs_addresses.append(canonicalize_bootstrap_address(gcs_address)) + canonical_expected_address = canonicalize_bootstrap_address(expected_address) + if canonical_expected_address in canonical_gcs_addresses: + logger.info(f"Ray cluster successfully started at {expected_address}") + return True + logger.debug(f"Found GCS at {gcs_addresses}, waiting for {expected_address}") + return False + + def init_cluster( # noqa: PLR0913 ray_port: int, ray_temp_dir: str, @@ -123,4 +167,22 @@ def init_cluster( # noqa: PLR0913 proc = subprocess.Popen(ray_command, shell=False) # noqa: S603 logger.info(f"Ray start command: {' '.join(ray_command)}") + + # Verify that Ray cluster actually started successfully using tenacity retry logic + expected_address = f"{ip_address}:{ray_port}" + try: + _verify_gcs_running(expected_address, proc) + except tenacity.RetryError: + # Check one final time if process failed + returncode = proc.poll() + if returncode is not None: + msg = f"Ray cluster failed to start. Process exited with code {returncode}." + logger.error(msg) + raise RuntimeError(msg) # noqa: B904 + + # Process is still running but GCS not detected + msg = f"Ray cluster verification timeout after {RAY_CLUSTER_START_VERIFICATION_TIMEOUT}s. GCS address not detected at {expected_address}." + logger.error(msg) + raise RuntimeError(msg) # noqa: B904 + return proc diff --git a/pyproject.toml b/pyproject.toml index b82358de6..ef0cf9e57 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,6 +61,7 @@ dependencies = [ "pandas>=2.1.0", "pyarrow", "ray[default,data]>=2.49", + "tenacity", "torch", "transformers==4.55.2", ] diff --git a/uv.lock b/uv.lock index fbec3d739..c3de21719 100644 --- a/uv.lock +++ b/uv.lock @@ -3925,6 +3925,7 @@ dependencies = [ { name = "pyarrow", version = "19.0.1", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'aarch64' or platform_machine == 'x86_64'" }, { name = "pyarrow", version = "21.0.0", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine != 'aarch64' and platform_machine != 'x86_64'" }, { name = "ray", extra = ["data", "default"] }, + { name = "tenacity" }, { name = "torch", version = "2.7.1", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'darwin'" }, { name = "torch", version = "2.8.0", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine != 'x86_64'" }, { name = "torch", version = "2.8.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "platform_machine == 'x86_64' and sys_platform != 'darwin'" }, @@ -4169,6 +4170,7 @@ requires-dist = [ { name = "resiliparse", marker = "extra == 'text-cpu'" }, { name = "s5cmd", marker = "extra == 'text-cpu'" }, { name = "sentencepiece", marker = "extra == 'text-cpu'" }, + { name = "tenacity" }, { name = "torch", marker = "platform_machine != 'x86_64' or sys_platform == 'darwin'", index = "https://pypi.org/simple" }, { name = "torch", marker = "platform_machine == 'x86_64' and sys_platform != 'darwin'", index = "https://download.pytorch.org/whl/cu128" }, { name = "torch", marker = "platform_machine == 'x86_64' and sys_platform != 'darwin' and extra == 'video-cuda12'", specifier = "<=2.8.0", index = "https://download.pytorch.org/whl/cu128" }, @@ -7804,6 +7806,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/27/44/aa5c8b10b2cce7a053018e0d132bd58e27527a0243c4985383d5b6fd93e9/tblib-3.1.0-py3-none-any.whl", hash = "sha256:670bb4582578134b3d81a84afa1b016128b429f3d48e6cbbaecc9d15675e984e", size = 12552, upload-time = "2025-03-31T12:58:26.142Z" }, ] +[[package]] +name = "tenacity" +version = "9.1.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0a/d4/2b0cd0fe285e14b36db076e78c93766ff1d529d70408bd1d2a5a84f1d929/tenacity-9.1.2.tar.gz", hash = "sha256:1169d376c297e7de388d18b4481760d478b0e99a777cad3a9c86e556f4b697cb", size = 48036, upload-time = "2025-04-02T08:25:09.966Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e5/30/643397144bfbfec6f6ef821f36f33e57d35946c44a2352d3c9f0ae847619/tenacity-9.1.2-py3-none-any.whl", hash = "sha256:f77bf36710d8b73a50b2dd155c97b870017ad21afe6ab300326b0371b3b05138", size = 28248, upload-time = "2025-04-02T08:25:07.678Z" }, +] + [[package]] name = "tensorboard" version = "2.20.0"