Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 32 additions & 3 deletions rlix/pipeline/miles_coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import asyncio
import logging
import math
import os
import threading
import time
from copy import deepcopy
Expand All @@ -37,7 +38,7 @@
get_pipeline_namespace,
)
from rlix.protocol.validation import validate_pipeline_id
from rlix.utils.env import pipeline_identity_env_vars
from rlix.utils.env import parse_env_positive_float, pipeline_identity_env_vars
from rlix.utils.ray import get_actor_or_raise

logger = logging.getLogger(__name__)
Expand All @@ -60,9 +61,13 @@ def _build_pipeline_env_vars(*, pipeline_id: str, ray_namespace: str) -> Dict[st
runtime_env. Reads ``RLIX_CONTROL_PLANE`` from the environment so
actors inside an existing pipeline preserve the inherited value.
"""
return pipeline_identity_env_vars(
env_vars = pipeline_identity_env_vars(
pipeline_id=str(pipeline_id), ray_namespace=str(ray_namespace)
)
for key in ("MILES_MAX_RESIDUAL_GPU_MEM_GB",):
if (value := os.environ.get(key)) is not None:
env_vars[key] = value
return env_vars


class MilesCoordinator(Coordinator):
Expand Down Expand Up @@ -431,7 +436,31 @@ def _shrink_workers(self, engine_indices: Set[int]) -> None:
if rollout_manager is None:
raise RuntimeError("resource registration missing for shrink")
# RPC outside the lock.
ray.get(rollout_manager.shrink_engines.remote(sorted(engine_indices)))
# Per-engine PROCESS-resident GPU memory threshold (GiB) passed to
# MILES shrink_engines -> assert_post_sleep_process_vram_below_threshold.
# Default 3.0: an offloaded SGLang scheduler process measures ~1.8 GiB
# resident (mostly non-offloadable CUDA context), so 3.0 leaves margin
# over that baseline while still catching large residuals such as an
# unoffloaded KV pool. (A 0.5B weight-only offload miss adds only ~1 GiB
# and may not trip it; the gate targets large KV/full-offload failures.)
# This is NOT whole-GPU used and NOT /server_info accounting.
residual_threshold_gb = parse_env_positive_float(
"MILES_MAX_RESIDUAL_GPU_MEM_GB", 3.0
)
shrunk = ray.get(
rollout_manager.shrink_engines.remote(
sorted(engine_indices),
post_sleep_vram_threshold_gb=residual_threshold_gb,
)
)
logger.info(
"[MilesCoordinator] shrink_engines complete pipeline_id=%s "
"engine_indices=%s per_process_residual_threshold=%.1f GB "
"(per-engine resident gate ran inside shrink_engines; fail-open if unmeasurable)",
self._pipeline_id,
sorted(shrunk),
residual_threshold_gb,
)
# Commit under lock.
with self._resize_sync_lock:
self._active_engine_indices -= engine_indices
Expand Down
87 changes: 30 additions & 57 deletions rlix/pipeline/miles_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,16 +503,14 @@ def _init_phase_b_infer(self) -> None:

def _wait_for_overlap_engines_offloaded(self, allocated_train_gpus, *, timeout_s: float = 60.0) -> None:
"""After scheduler grants actor_train, poll the rollout manager
until the engines on overlap GPUs have transitioned to ``offloaded``
AND the OS-reported GPU memory is actually free. SGLang's HTTP
``/release_memory_occupation`` 200 OK + state="offloaded" do not
by themselves guarantee the CUDA driver has returned the memory
to the OS pool — the wake_up in the next-process train actor
would then OOM. Verify actual GPU mem free by parsing
``nvidia-smi --query-gpu=memory.free`` on the same node, since
miles' single-node smoke topology has driver+actors+engines all
on the head node and ``CUDA_VISIBLE_DEVICES`` is the per-actor
slice of the shared physical pool.
until the engines on overlap GPUs have transitioned to ``offloaded``.

The hard residual-allocation safety check runs during
``RolloutManager.shrink_engines`` via SGLang ``/server_info``
(weight + kvcache + graph). This method only waits for the state
transition and logs raw OS-level ``nvidia-smi memory.used`` as a
diagnostic, because process-level GPU usage includes CUDA / Ray /
runtime overhead beyond SGLang's offloadable allocations.
"""
rollout_manager = getattr(self, "_rollout_manager", None)
if rollout_manager is None:
Expand Down Expand Up @@ -573,52 +571,27 @@ def _wait_for_overlap_engines_offloaded(self, allocated_train_gpus, *, timeout_s
timeout_s, target_indices, uniq,
)

# Phase 2: probe nvidia-smi for OS-level free memory on the
# overlap GPU IDs. The train actor will need ~3.7 GB for the
# 0.5B model + a few GB for activations; aim for ≥20 GB free
# before we let _before_training proceed to wake_up.
target_free_gb = 20.0
deadline2 = time.time() + float(timeout_s)
last_min_free_gb: Optional[float] = None
nvidia_smi_unavail_count = 0
while time.time() < deadline2:
min_free_gb = self._probe_min_free_gpu_mem_gb(target_gpu_ids)
if min_free_gb is None:
# F5 (m11-review.review-report.md §2): nvidia-smi unavailable
# or unparseable. Was logged at DEBUG only — promoted to INFO
# so operators see the fallback without flipping log levels.
# If this fires repeatedly across sessions, it's a hardware
# / image regression worth investigating (driver missing,
# nvidia-smi path changed, etc.).
nvidia_smi_unavail_count += 1
logger.info(
"_wait_for_overlap_engines_offloaded: nvidia-smi probe "
"unavailable (count=%d); falling back to 3s grace sleep",
nvidia_smi_unavail_count,
)
time.sleep(3.0)
return
last_min_free_gb = min_free_gb
if min_free_gb >= target_free_gb:
logger.info(
"_wait_for_overlap_engines_offloaded: OS-level GPU mem free "
"min=%.2f GB across overlap GPUs %s (target=%.1f GB)",
min_free_gb, target_gpu_ids, target_free_gb,
)
return
time.sleep(0.5)
logger.warning(
"_wait_for_overlap_engines_offloaded: free-mem timeout after %.1fs; "
"min_free_gb=%.2f below %.1f GB target on GPUs %s — wake_up may OOM",
timeout_s,
last_min_free_gb if last_min_free_gb is not None else float("nan"),
target_free_gb,
# Phase 2: log raw nvidia-smi used memory as diagnostics only.
# The hard safety check now runs inside RolloutManager.shrink_engines
# via SGLang /server_info (weight + kvcache + graph), which is a
# narrower residual-allocation signal than process-level GPU usage.
max_used_gb = self._probe_max_used_gpu_mem_gb(target_gpu_ids)
if max_used_gb is None:
logger.info(
"_wait_for_overlap_engines_offloaded: nvidia-smi probe unavailable; "
"server-side SGLang residual check already ran during shrink"
)
return
logger.info(
"_wait_for_overlap_engines_offloaded: OS-level GPU mem used max=%.2f GB "
"across overlap GPUs %s (diagnostic; SGLang residual assert is the gate)",
max_used_gb,
target_gpu_ids,
)

@staticmethod
def _probe_min_free_gpu_mem_gb(gpu_ids: list[int]) -> Optional[float]:
"""Return the minimum free GPU memory (GB) across ``gpu_ids`` as
def _probe_max_used_gpu_mem_gb(gpu_ids: list[int]) -> Optional[float]:
"""Return the maximum used GPU memory (GB) across ``gpu_ids`` as
reported by ``nvidia-smi``. Returns ``None`` if nvidia-smi is
not available or output cannot be parsed.
"""
Expand All @@ -634,7 +607,7 @@ def _probe_min_free_gpu_mem_gb(gpu_ids: list[int]) -> Optional[float]:
[
"nvidia-smi",
f"--id={','.join(str(g) for g in gpu_ids)}",
"--query-gpu=memory.free",
"--query-gpu=memory.used",
"--format=csv,noheader,nounits",
],
stderr=subprocess.STDOUT,
Expand All @@ -643,18 +616,18 @@ def _probe_min_free_gpu_mem_gb(gpu_ids: list[int]) -> Optional[float]:
except (subprocess.SubprocessError, OSError) as exc:
logger.debug("nvidia-smi probe failed: %r", exc)
return None
free_mibs: list[float] = []
used_mibs: list[float] = []
for line in out.strip().splitlines():
line = line.strip()
if not line:
continue
try:
free_mibs.append(float(line))
used_mibs.append(float(line))
except ValueError:
continue
if not free_mibs:
if not used_mibs:
return None
return min(free_mibs) / 1024.0
return max(used_mibs) / 1024.0

def _before_training(self, step: int) -> None:
if not self._initialized:
Expand Down
18 changes: 18 additions & 0 deletions rlix/utils/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,21 @@ def parse_env_timeout_s(env_key: str, default_s: Optional[float] = None) -> Opti
except ValueError as exc:
raise RuntimeError(f"{env_key} must be a number, got: {raw!r}") from exc
return None if value <= 0 else value


def parse_env_positive_float(env_key: str, default: float) -> float:
    """Read a strictly positive float from an env var; fail-fast on invalid values.

    Args:
        env_key: Name of the environment variable to read.
        default: Value returned (coerced with ``float()``) when the variable
            is unset. The default itself is not validated.

    Returns:
        The parsed value, or ``float(default)`` when the variable is unset.
        Positive infinity is accepted because it satisfies ``> 0``.

    Raises:
        RuntimeError: If the value cannot be parsed as a number, is NaN, or
            is <= 0.
    """
    raw = os.environ.get(env_key)
    if raw is None:
        return float(default)
    try:
        value = float(raw)
    except ValueError as exc:
        raise RuntimeError(f"{env_key} must be a number, got: {raw!r}") from exc
    # ``float("nan")`` parses successfully and every ordered comparison with
    # NaN is False, so ``value <= 0.0`` would let NaN through. Testing the
    # negation of ``value > 0.0`` rejects NaN together with zero/negatives.
    if not value > 0.0:
        raise RuntimeError(f"{env_key} must be > 0, got: {value!r}")
    return value
56 changes: 56 additions & 0 deletions tests/test_env_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import importlib
import sys
import types
from pathlib import Path

import pytest

REPO_ROOT = Path(__file__).resolve().parents[1]
RLIX_ROOT = REPO_ROOT / "rlix"


def _load_env_module(monkeypatch):
    """Import ``rlix.utils.env`` in isolation from the real ``rlix`` package.

    Every already-imported ``rlix`` / ``rlix.*`` entry is removed from
    ``sys.modules`` and replaced by bare stub packages for ``rlix`` and
    ``rlix.utils`` whose ``__path__`` points at the source tree, so importing
    the env module does not execute the packages' ``__init__`` modules.
    All ``sys.modules`` changes go through *monkeypatch* and are undone by it.
    """
    # Snapshot the keys first: entries are deleted while we iterate.
    for loaded_name in list(sys.modules):
        is_rlix_module = loaded_name == "rlix" or loaded_name.startswith("rlix.")
        if is_rlix_module:
            monkeypatch.delitem(sys.modules, loaded_name, raising=False)

    stub_packages = (
        ("rlix", RLIX_ROOT),
        ("rlix.utils", RLIX_ROOT / "utils"),
    )
    for pkg_name, pkg_dir in stub_packages:
        stub = types.ModuleType(pkg_name)
        # A ``__path__`` attribute is what makes the stub behave as a package.
        stub.__path__ = [str(pkg_dir)]  # type: ignore[attr-defined]
        monkeypatch.setitem(sys.modules, pkg_name, stub)

    return importlib.import_module("rlix.utils.env")


def test_parse_env_positive_float_uses_default_when_unset(monkeypatch):
    """With the variable absent, the supplied default is returned."""
    env_module = _load_env_module(monkeypatch)
    monkeypatch.delenv("MILES_MAX_RESIDUAL_GPU_MEM_GB", raising=False)

    parsed = env_module.parse_env_positive_float("MILES_MAX_RESIDUAL_GPU_MEM_GB", 2.0)

    assert parsed == 2.0


def test_parse_env_positive_float_reads_override(monkeypatch):
    """A numeric value set in the environment wins over the default."""
    env_module = _load_env_module(monkeypatch)
    monkeypatch.setenv("MILES_MAX_RESIDUAL_GPU_MEM_GB", "40.5")

    parsed = env_module.parse_env_positive_float("MILES_MAX_RESIDUAL_GPU_MEM_GB", 2.0)

    assert parsed == 40.5


def test_parse_env_positive_float_rejects_non_positive(monkeypatch):
    """A value of zero raises RuntimeError mentioning the ``> 0`` requirement."""
    env_module = _load_env_module(monkeypatch)
    monkeypatch.setenv("MILES_MAX_RESIDUAL_GPU_MEM_GB", "0")

    expected_failure = pytest.raises(RuntimeError, match="must be > 0")
    with expected_failure:
        env_module.parse_env_positive_float("MILES_MAX_RESIDUAL_GPU_MEM_GB", 2.0)


def test_parse_env_positive_float_rejects_non_numeric(monkeypatch):
    """An unparseable value raises RuntimeError mentioning ``must be a number``."""
    env_module = _load_env_module(monkeypatch)
    monkeypatch.setenv("MILES_MAX_RESIDUAL_GPU_MEM_GB", "not-a-number")

    expected_failure = pytest.raises(RuntimeError, match="must be a number")
    with expected_failure:
        env_module.parse_env_positive_float("MILES_MAX_RESIDUAL_GPU_MEM_GB", 2.0)
68 changes: 68 additions & 0 deletions tests/test_miles_residual_threshold_wiring.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from __future__ import annotations

import ast
from pathlib import Path

REPO_ROOT = Path(__file__).resolve().parents[1]


def _is_name(node: ast.AST, name: str) -> bool:
return isinstance(node, ast.Name) and node.id == name


def _is_attr(node: ast.AST, attr: str) -> bool:
return isinstance(node, ast.Attribute) and node.attr == attr


def test_miles_shrink_uses_server_side_residual_threshold() -> None:
    """Statically verify the threshold wiring inside ``_shrink_workers``.

    Parses ``rlix/pipeline/miles_coordinator.py`` without importing it and
    checks that ``_shrink_workers`` (a) calls
    ``parse_env_positive_float("MILES_MAX_RESIDUAL_GPU_MEM_GB", 3.0)`` and
    (b) passes the ``residual_threshold_gb`` name as the
    ``post_sleep_vram_threshold_gb`` keyword of some ``.remote(...)`` call.
    """
    coordinator_path = REPO_ROOT / "rlix" / "pipeline" / "miles_coordinator.py"
    module_ast = ast.parse(coordinator_path.read_text(encoding="utf-8"))

    shrink_fn = next(
        candidate
        for candidate in ast.walk(module_ast)
        if isinstance(candidate, ast.FunctionDef)
        and candidate.name == "_shrink_workers"
    )

    def parses_threshold_env(node: ast.AST) -> bool:
        # Matches parse_env_positive_float(<env name literal>, 3.0, ...).
        if not isinstance(node, ast.Call):
            return False
        if not _is_name(node.func, "parse_env_positive_float"):
            return False
        if len(node.args) < 2:
            return False
        env_arg, default_arg = node.args[0], node.args[1]
        return (
            isinstance(env_arg, ast.Constant)
            and env_arg.value == "MILES_MAX_RESIDUAL_GPU_MEM_GB"
            and isinstance(default_arg, ast.Constant)
            and default_arg.value == 3.0
        )

    def forwards_threshold_kwarg(node: ast.AST) -> bool:
        # Matches <something>.remote(..., post_sleep_vram_threshold_gb=residual_threshold_gb).
        if not isinstance(node, ast.Call):
            return False
        if not _is_attr(node.func, "remote"):
            return False
        return any(
            keyword.arg == "post_sleep_vram_threshold_gb"
            and _is_name(keyword.value, "residual_threshold_gb")
            for keyword in node.keywords
        )

    shrink_nodes = list(ast.walk(shrink_fn))

    assert any(
        parses_threshold_env(node) for node in shrink_nodes
    ), "_shrink_workers must parse the residual threshold env var with 3GB default"

    assert any(
        forwards_threshold_kwarg(node) for node in shrink_nodes
    ), "shrink_engines must receive post_sleep_vram_threshold_gb"


def test_miles_coordinator_forwards_residual_threshold_env_var() -> None:
    """Statically verify ``_build_pipeline_env_vars`` mentions the threshold env var.

    Parses ``rlix/pipeline/miles_coordinator.py`` without importing it and
    checks that the string constant ``"MILES_MAX_RESIDUAL_GPU_MEM_GB"``
    appears somewhere inside ``_build_pipeline_env_vars``.
    """
    coordinator_path = REPO_ROOT / "rlix" / "pipeline" / "miles_coordinator.py"
    module_ast = ast.parse(coordinator_path.read_text(encoding="utf-8"))

    build_env_fn = next(
        candidate
        for candidate in ast.walk(module_ast)
        if isinstance(candidate, ast.FunctionDef)
        and candidate.name == "_build_pipeline_env_vars"
    )

    def is_threshold_env_literal(node: ast.AST) -> bool:
        if not isinstance(node, ast.Constant):
            return False
        return node.value == "MILES_MAX_RESIDUAL_GPU_MEM_GB"

    assert any(
        is_threshold_env_literal(node) for node in ast.walk(build_env_fn)
    ), "_build_pipeline_env_vars must forward MILES_MAX_RESIDUAL_GPU_MEM_GB"