diff --git a/examples/rlix/run_miles_dual.py b/examples/rlix/run_miles_dual.py index 59174a99bd..e44a436579 100644 --- a/examples/rlix/run_miles_dual.py +++ b/examples/rlix/run_miles_dual.py @@ -289,6 +289,7 @@ def _build_pipeline( pipeline_runtime_env_vars["PYTHONPATH"] = pythonpath for _k in ( "MILES_TMS_HOOK_MODE", + "MILES_MAX_RESIDUAL_GPU_MEM_GB", "MILES_SKIP_TMS_PAUSE", "MILES_SKIP_NODE_PG_PIN", "TMS_INIT_ENABLE_CPU_BACKUP", diff --git a/examples/rlix/run_miles_rlix.py b/examples/rlix/run_miles_rlix.py index 49092fc712..78e6dfd300 100644 --- a/examples/rlix/run_miles_rlix.py +++ b/examples/rlix/run_miles_rlix.py @@ -178,6 +178,7 @@ class MilesPipelineConfig: # the parent driver's env by default). for _k in ( "MILES_TMS_HOOK_MODE", + "MILES_MAX_RESIDUAL_GPU_MEM_GB", "MILES_SKIP_TMS_PAUSE", "MILES_SKIP_NODE_PG_PIN", "TMS_INIT_ENABLE_CPU_BACKUP", diff --git a/miles/backends/sglang_utils/sglang_engine.py b/miles/backends/sglang_utils/sglang_engine.py index beea2342ac..2804e92b36 100644 --- a/miles/backends/sglang_utils/sglang_engine.py +++ b/miles/backends/sglang_utils/sglang_engine.py @@ -15,6 +15,7 @@ from miles.backends.megatron_utils.lora_utils import LORA_ADAPTER_NAME, convert_target_modules_to_hf, is_lora_enabled from miles.ray.ray_actor import RayActor +from miles.utils.gpu_probe import query_process_tree_gpu_used_gb from miles.utils.env_report import collect_and_print_node_env_report from miles.utils.http_utils import get_host_info @@ -756,6 +757,93 @@ def assert_post_sleep_vram_below_threshold( ) return observed_max_gb + def _server_info_residual_gb(self, timeout_s: float = 5.0): + """SGLang /server_info weight+kvcache+graph, max across DPs (GiB). + + This is *accounting* (KV static-pool size). It does NOT drop after a + torch_memory_saver pause, so it is logged for diagnostics only and is + never used as a hard gate. Returns None if unavailable (0.0 if no state carries a ``memory_usage`` dict; ``timeout_s`` is currently unused). 
+ """ + try: + body = self.get_server_info() + except Exception: + return None + internal_states = body.get("internal_states") if isinstance(body, dict) else None + if not isinstance(internal_states, list) or not internal_states: + return None + observed_max_gb = 0.0 + for state in internal_states: + mem = state.get("memory_usage") if isinstance(state, dict) else None + if not isinstance(mem, dict): + continue + total_gb = sum(float(mem.get(k, 0.0) or 0.0) for k in ("weight", "kvcache", "graph")) + observed_max_gb = max(observed_max_gb, total_gb) + return observed_max_gb + + def get_process_tree_gpu_used_gb(self, timeout_s: float = 5.0): + """Real resident GPU memory (GiB) of this engine's SGLang process + tree, via nvidia-smi compute-apps (see ``miles.utils.gpu_probe``). + + Returns ``None`` (fail-open) when unmeasurable — nvidia-smi missing or + a PID-namespace mismatch inside a container. Callers MUST treat + ``None`` as "cannot measure", never as 0. + """ + root = getattr(self, "process", None) + return query_process_tree_gpu_used_gb( + getattr(root, "pid", None), timeout_s=timeout_s + ) + + def assert_post_sleep_process_vram_below_threshold( + self, threshold_gb: float, timeout_s: float = 5.0 + ): + """Hard gate (fail-open) on this engine's REAL resident GPU memory + after ``release_memory_occupation``, measured per-process via + ``nvidia-smi`` compute-apps over the engine's process tree. + + Behavior: + - non-rank-0 node: no-op, return None. + - measurable and > threshold: raise RuntimeError. + - measurable and <= threshold: return observed GiB. + - NOT measurable (nvidia-smi missing / parse fail / PID-namespace + mismatch): fail-open — log a warning and return None WITHOUT + raising, so a missing metric never kills a healthy pipeline + (engine-state polling stays the liveness gate). + + ``/server_info`` accounting is logged alongside as diagnostic. 
+ """ + if self.node_rank != 0: + return None + _log = logging.getLogger(__name__) + account_gb = self._server_info_residual_gb(timeout_s=timeout_s) + resident_gb = self.get_process_tree_gpu_used_gb(timeout_s=timeout_s) + _log.info( + "post-sleep residual engine=%s:%s process_resident=%s GiB " + "server_info_accounting(weight+kvcache+graph)=%s GiB threshold=%.3f GiB", + self.server_host, + self.server_port, + ("%.3f" % resident_gb) if resident_gb is not None else "n/a", + ("%.3f" % account_gb) if account_gb is not None else "n/a", + float(threshold_gb), + ) + if resident_gb is None: + _log.warning( + "post-sleep process-resident probe unavailable on engine " + "%s:%s (nvidia-smi missing or PID-namespace mismatch); " + "skipping hard gate (fail-open).", + self.server_host, + self.server_port, + ) + return None + if resident_gb > float(threshold_gb): + raise RuntimeError( + f"Post-sleep process-resident GPU memory {resident_gb:.3f} GiB " + f"exceeds threshold {float(threshold_gb):.3f} GiB on engine " + f"{self.server_host}:{self.server_port} — offload did not free " + f"this engine's GPU memory (check release_memory_occupation / " + f"torch_memory_saver)." + ) + return resident_gb + def resume_memory_occupation(self, tags: list[str] = None): """ Available tags for multi-stage resume: weights, kv_cache diff --git a/miles/ray/rollout.py b/miles/ray/rollout.py index 30c7d78db9..283b9f8f8e 100644 --- a/miles/ray/rollout.py +++ b/miles/ray/rollout.py @@ -1031,16 +1031,35 @@ def shrink_engines( ) # Step 4: release memory. ray.get([h.release_memory_occupation.remote(tags=None) for h in handles]) - # Step 5: optional post-sleep VRAM assert. + # Step 5: optional post-sleep VRAM hard gate. 
Gate on each + # engine's REAL per-process resident GPU memory (nvidia-smi + # compute-apps over the SGLang process tree), NOT /server_info + # accounting — the latter reports KV static-pool size and does + # not drop after a torch_memory_saver pause, so it falsely shows + # ~9 GiB for an offloaded 0.5B engine whose process is actually + # ~1.8 GiB resident. fail-open: an unmeasurable engine (PID- + # namespace mismatch / no nvidia-smi) returns None and is skipped + # rather than killing a healthy pipeline. if post_sleep_vram_threshold_gb is not None: - ray.get( + observed_resident_gbs = ray.get( [ - h.assert_post_sleep_vram_below_threshold.remote( + h.assert_post_sleep_process_vram_below_threshold.remote( threshold_gb=post_sleep_vram_threshold_gb ) for h in handles ] ) + measured = [v for v in observed_resident_gbs if v is not None] + logger.info( + "shrink_engines: post-sleep process-resident GPU residual " + "max=%s GiB per_engine=%s threshold=%.3f GiB " + "engine_indices=%s (real resident; server_info accounting " + "logged per-engine)", + ("%.3f" % max(measured)) if measured else "n/a", + [None if v is None else round(float(v), 3) for v in observed_resident_gbs], + float(post_sleep_vram_threshold_gb), + indices, + ) except Exception: # Reset the abort cache on failure so retry re-aborts new # in-flights that arrived during the failed cycle. diff --git a/miles/utils/gpu_probe.py b/miles/utils/gpu_probe.py new file mode 100644 index 0000000000..ba865de174 --- /dev/null +++ b/miles/utils/gpu_probe.py @@ -0,0 +1,165 @@ +"""GPU residual probing helpers — pure, dependency-free, unit-testable. + +Used by :class:`SGLangEngine` to measure a SGLang server's REAL per-process +resident GPU memory after offload, via ``nvidia-smi`` compute-apps over the +engine's process tree. Kept free of sglang/torch imports so the parsing and +fail-open logic can be unit-tested without a GPU. 
+ +Semantics: ``MILES_MAX_RESIDUAL_GPU_MEM_GB`` is the **max per-GPU** resident +residual for an engine — for each GPU, sum the engine's process-tree usage on +that GPU, then take the max across GPUs. This supports TP>1 without summing +across cards (which would over-count and falsely trip the gate). +""" +from __future__ import annotations + +import logging +import shutil +import subprocess + +logger = logging.getLogger(__name__) + + +def build_process_tree(root_pid, proc_root: str = "/proc") -> set: + """Return ``root_pid`` plus all descendant PIDs by reading + ``{proc_root}/{pid}/stat`` ppid links. Pure /proc walk, no psutil. + + ``self.process.pid`` is the multiprocessing spawn parent; the real + GPU-resident process is the ``sglang::scheduler`` child, so the whole + tree must be walked. + """ + import os + + try: + entries = [int(p) for p in os.listdir(proc_root) if p.isdigit()] + except OSError: + return {root_pid} + children: dict = {} + for pid in entries: + try: + with open(os.path.join(proc_root, str(pid), "stat"), "rb") as f: + data = f.read() + except OSError: + continue + # comm (2nd field) is paren-wrapped and may contain spaces/parens; + # ppid is the 2nd whitespace token after the final ')'. + try: + rparen = data.rindex(b")") + ppid = int(data[rparen + 2:].split()[1]) + except (ValueError, IndexError): + continue + children.setdefault(ppid, []).append(pid) + tree = {root_pid} + stack = [root_pid] + while stack: + cur = stack.pop() + for ch in children.get(cur, ()): + if ch not in tree: + tree.add(ch) + stack.append(ch) + return tree + + +def parse_compute_apps_per_gpu_max_gb(nvidia_csv: str, tree_pids: set): + """Parse ``nvidia-smi --query-compute-apps=gpu_bus_id,pid,used_memory``. + + For PIDs in ``tree_pids``: sum ``used_memory`` (MiB) per GPU + (keyed by ``gpu_bus_id``), then take the MAX across GPUs and return GiB. + This is the ``MILES_MAX_RESIDUAL_GPU_MEM_GB`` semantics: the engine's + worst single-GPU resident residual (TP-safe — no cross-card summing). 
+ + Returns ``None`` (fail-open) if no tree PID appears in the listing. + """ + per_gpu: dict = {} + matched = False + for line in nvidia_csv.strip().splitlines(): + parts = [p.strip() for p in line.split(",")] + if len(parts) < 3: + continue + bus_id = parts[0] + try: + pid = int(parts[1]) + used = float(parts[2]) + except ValueError: + continue + if pid in tree_pids: + per_gpu[bus_id] = per_gpu.get(bus_id, 0.0) + used + matched = True + if not matched: + return None + return max(per_gpu.values()) / 1024.0 + + +def parse_compute_apps_used_gb(nvidia_csv: str, tree_pids: set): + """Fallback parser for the legacy 2-col ``pid,used_memory`` query (no + ``gpu_bus_id``). Sums all matched rows -> GiB. Used only when the + GPU-aware query is unavailable; it cannot distinguish per-GPU, so it + over-estimates for a multi-GPU engine. + + Returns ``None`` (fail-open) if no tree PID appears. + """ + total_mib = 0.0 + matched = False + for line in nvidia_csv.strip().splitlines(): + parts = [p.strip() for p in line.split(",")] + if len(parts) < 2: + continue + try: + pid = int(parts[0]) + used = float(parts[1]) + except ValueError: + continue + if pid in tree_pids: + total_mib += used + matched = True + if not matched: + return None + return total_mib / 1024.0 + + +def _run_nvidia_smi(args, timeout_s: float): + try: + return subprocess.check_output( + ["nvidia-smi"] + args, stderr=subprocess.STDOUT, timeout=timeout_s + ).decode("utf-8", errors="replace") + except (subprocess.SubprocessError, OSError): + return None + + +def query_process_tree_gpu_used_gb(root_pid, timeout_s: float = 5.0, + proc_root: str = "/proc"): + """Max per-GPU resident GPU memory (GiB) of ``root_pid``'s process tree. + + Prefers the GPU-aware query (``gpu_bus_id,pid,used_memory``): per-GPU + sum, max across GPUs. Falls back to the legacy 2-col query + (``pid,used_memory``, summed) with a warning whenever the GPU-aware + call fails (non-zero exit, timeout, or OSError), not only when unsupported. 
+ + Returns ``None`` (fail-open) when nvidia-smi is missing, the call fails, + or no tree PID appears (PID-namespace mismatch inside a container). + Callers MUST treat ``None`` as "cannot measure", never as 0. + """ + if root_pid is None or shutil.which("nvidia-smi") is None: + return None + tree = build_process_tree(root_pid, proc_root=proc_root) + if not tree: + return None + out = _run_nvidia_smi( + ["--query-compute-apps=gpu_bus_id,pid,used_memory", + "--format=csv,noheader,nounits"], + timeout_s, + ) + if out is not None: + return parse_compute_apps_per_gpu_max_gb(out, tree) + # GPU-aware query unsupported -> legacy 2-col fallback (summed). + logger.warning( + "nvidia-smi gpu_bus_id query unavailable; falling back to " + "pid,used_memory (summed; cannot distinguish per-GPU)" + ) + out = _run_nvidia_smi( + ["--query-compute-apps=pid,used_memory", + "--format=csv,noheader,nounits"], + timeout_s, + ) + if out is None: + return None + return parse_compute_apps_used_gb(out, tree) diff --git a/tests/test_gpu_probe.py b/tests/test_gpu_probe.py new file mode 100644 index 0000000000..1725c27400 --- /dev/null +++ b/tests/test_gpu_probe.py @@ -0,0 +1,113 @@ +"""Unit tests for miles.utils.gpu_probe — the per-process GPU residual probe. + +Key properties under test: +- FAIL-OPEN: when none of the engine's process-tree PIDs appear in nvidia-smi + compute-apps (e.g. PID-namespace mismatch in a container), parsers return + None ("cannot measure"), never 0 — otherwise a hard gate would falsely pass. +- MAX-PER-GPU: MILES_MAX_RESIDUAL_GPU_MEM_GB = sum within a GPU, max across + GPUs (TP-safe; never sum across cards). + +Dependency-free: runnable via pytest OR directly (`python3 tests/test_gpu_probe.py`). 
+""" +import os +import sys +import tempfile + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from miles.utils.gpu_probe import ( + build_process_tree, + parse_compute_apps_per_gpu_max_gb, + parse_compute_apps_used_gb, +) + + +# --- P2: GPU-aware parser (sum within a GPU, max across GPUs) --- + +def test_per_gpu_max_takes_max_across_gpus(): + csv = "0000:98:00.0, 100, 500\n0000:98:00.0, 200, 1024\n0000:A8:00.0, 100, 2000" + # bus 98 = 500+1024 = 1524 MiB; bus A8 = 2000 MiB; max = 2000 + got = parse_compute_apps_per_gpu_max_gb(csv, {100, 200}) + assert abs(got - 2000 / 1024.0) < 1e-6 + + +def test_per_gpu_max_sums_within_same_gpu(): + # two engine procs on the SAME gpu must be summed, not max'd + csv = "0000:98:00.0,100,500\n0000:98:00.0,200,256" + got = parse_compute_apps_per_gpu_max_gb(csv, {100, 200}) + assert abs(got - 756 / 1024.0) < 1e-6 + + +def test_per_gpu_max_excludes_non_tree_pids(): + csv = "0000:98:00.0,100,500\n0000:A8:00.0,999,9000" + got = parse_compute_apps_per_gpu_max_gb(csv, {100}) + assert abs(got - 500 / 1024.0) < 1e-6 + + +def test_per_gpu_max_no_match_returns_none_not_zero(): + assert parse_compute_apps_per_gpu_max_gb("0000:A8:00.0,999,9000", {100, 200}) is None + + +def test_per_gpu_max_empty_returns_none(): + assert parse_compute_apps_per_gpu_max_gb("", {100}) is None + + +# --- legacy 2-col fallback parser --- + +def test_fallback_sums_only_tree_pids(): + got = parse_compute_apps_used_gb("100, 500\n200, 1024\n999, 8000", {100, 200}) + assert abs(got - (1524 / 1024.0)) < 1e-6 + + +def test_fallback_no_match_returns_none_not_zero(): + assert parse_compute_apps_used_gb("999, 8000\n888, 4000", {100, 200}) is None + + +def test_fallback_skips_unparsable_rows(): + csv = "100, [N/A]\n100, 512\nbad line\n, \n200,256" + got = parse_compute_apps_used_gb(csv, {100, 200}) + assert abs(got - (768 / 1024.0)) < 1e-6 + + +# --- process-tree walk --- + +def _stat(proc, pid, comm, ppid): + d = os.path.join(proc, str(pid)) 
+ os.makedirs(d) + with open(os.path.join(d, "stat"), "w") as f: + f.write(f"{pid} ({comm}) S {ppid} 0 0 0 0\n") + + +def test_build_process_tree_walks_descendants(): + with tempfile.TemporaryDirectory() as tmp: + proc = os.path.join(tmp, "proc") + os.makedirs(proc) + _stat(proc, 100, "spawn_main", 1) + _stat(proc, 200, "sglang::sched", 100) + _stat(proc, 300, "sglang::detok", 200) + _stat(proc, 999, "unrelated", 1) + assert build_process_tree(100, proc_root=proc) == {100, 200, 300} + + +def test_build_process_tree_comm_with_spaces_and_parens(): + with tempfile.TemporaryDirectory() as tmp: + proc = os.path.join(tmp, "proc") + os.makedirs(proc) + _stat(proc, 100, "spawn", 1) + d = os.path.join(proc, "200") + os.makedirs(d) + with open(os.path.join(d, "stat"), "w") as f: + f.write("200 (weird (c o m m)) S 100 0\n") + assert build_process_tree(100, proc_root=proc) == {100, 200} + + +def test_build_process_tree_missing_proc_returns_root(): + assert build_process_tree(100, proc_root="/nonexistent_proc_xyz") == {100} + + +if __name__ == "__main__": + fns = [v for k, v in sorted(globals().items()) if k.startswith("test_")] + for fn in fns: + fn() + print("PASS", fn.__name__) + print(f"\n{len(fns)} passed") diff --git a/tests/test_residual_gpu_mem_wiring.py b/tests/test_residual_gpu_mem_wiring.py new file mode 100644 index 0000000000..2a974d77c7 --- /dev/null +++ b/tests/test_residual_gpu_mem_wiring.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +from pathlib import Path + +REPO_ROOT = Path(__file__).resolve().parents[1] + + +def test_rlix_drivers_forward_residual_gpu_mem_env_var() -> None: + for relpath in ( + "examples/rlix/run_miles_rlix.py", + "examples/rlix/run_miles_dual.py", + ): + source = (REPO_ROOT / relpath).read_text(encoding="utf-8") + assert ( + '"MILES_MAX_RESIDUAL_GPU_MEM_GB"' in source + ), f"{relpath} must forward residual threshold env into runtime_env" + + +def test_shrink_gates_on_process_resident_gpu_memory() -> None: + source = (REPO_ROOT 
/ "miles" / "ray" / "rollout.py").read_text( + encoding="utf-8" + ) + # Hard gate uses per-process resident GPU memory, not /server_info accounting. + assert "assert_post_sleep_process_vram_below_threshold" in source + assert "post-sleep process-resident GPU residual" in source