Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/rlix/run_miles_dual.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,7 @@ def _build_pipeline(
pipeline_runtime_env_vars["PYTHONPATH"] = pythonpath
for _k in (
"MILES_TMS_HOOK_MODE",
"MILES_MAX_RESIDUAL_GPU_MEM_GB",
"MILES_SKIP_TMS_PAUSE",
"MILES_SKIP_NODE_PG_PIN",
"TMS_INIT_ENABLE_CPU_BACKUP",
Expand Down
1 change: 1 addition & 0 deletions examples/rlix/run_miles_rlix.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ class MilesPipelineConfig:
# the parent driver's env by default).
for _k in (
"MILES_TMS_HOOK_MODE",
"MILES_MAX_RESIDUAL_GPU_MEM_GB",
"MILES_SKIP_TMS_PAUSE",
"MILES_SKIP_NODE_PG_PIN",
"TMS_INIT_ENABLE_CPU_BACKUP",
Expand Down
88 changes: 88 additions & 0 deletions miles/backends/sglang_utils/sglang_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from miles.backends.megatron_utils.lora_utils import LORA_ADAPTER_NAME, convert_target_modules_to_hf, is_lora_enabled
from miles.ray.ray_actor import RayActor
from miles.utils.gpu_probe import query_process_tree_gpu_used_gb
from miles.utils.env_report import collect_and_print_node_env_report
from miles.utils.http_utils import get_host_info

Expand Down Expand Up @@ -756,6 +757,93 @@ def assert_post_sleep_vram_below_threshold(
)
return observed_max_gb

def _server_info_residual_gb(self, timeout_s: float = 5.0):
    """Return SGLang ``/server_info`` weight+kvcache+graph, max across DPs (GiB).

    This is *accounting* (KV static-pool size). It does NOT drop after a
    torch_memory_saver pause, so it is logged for diagnostics only and is
    never used as a hard gate.

    Args:
        timeout_s: Accepted so the signature matches the other probes. It is
            not forwarded: ``get_server_info()`` is called without arguments.

    Returns:
        The largest per-state ``weight + kvcache + graph`` total (never below
        0.0), or ``None`` when the figure is unavailable: the request raised,
        the payload has no non-empty ``internal_states`` list, or no state
        carried a usable ``memory_usage`` mapping.
    """
    try:
        body = self.get_server_info()
    except Exception:
        # Diagnostics only: a failed request must never break the caller.
        return None
    internal_states = body.get("internal_states") if isinstance(body, dict) else None
    if not isinstance(internal_states, list) or not internal_states:
        return None
    observed_max_gb = 0.0
    measured = False
    for state in internal_states:
        mem = state.get("memory_usage") if isinstance(state, dict) else None
        if not isinstance(mem, dict):
            continue
        try:
            # Missing or null components count as 0.0.
            total_gb = sum(float(mem.get(k) or 0.0) for k in ("weight", "kvcache", "graph"))
        except (TypeError, ValueError):
            # Non-numeric entry: skip this state rather than let a
            # diagnostics helper raise into the caller.
            continue
        measured = True
        observed_max_gb = max(observed_max_gb, total_gb)
    # "Nothing could be read" must surface as None, not as a misleading 0.0.
    return observed_max_gb if measured else None

def get_process_tree_gpu_used_gb(self, timeout_s: float = 5.0):
    """Measure the resident GPU memory (GiB) of this engine's process tree.

    Delegates to ``query_process_tree_gpu_used_gb`` from
    ``miles.utils.gpu_probe``, rooted at ``self.process.pid``. A missing
    ``process`` attribute or a missing ``pid`` is passed on as ``None``.

    Returns ``None`` (fail-open) when the value cannot be measured, e.g.
    nvidia-smi is missing or PIDs do not match inside a container. Callers
    MUST treat ``None`` as "cannot measure", never as 0.
    """
    engine_proc = getattr(self, "process", None)
    root_pid = getattr(engine_proc, "pid", None)
    return query_process_tree_gpu_used_gb(root_pid, timeout_s=timeout_s)

def assert_post_sleep_process_vram_below_threshold(
    self, threshold_gb: float, timeout_s: float = 5.0
):
    """Fail-open hard gate on the engine's real resident GPU memory.

    Intended to run after ``release_memory_occupation``. The gated value is
    whatever ``get_process_tree_gpu_used_gb`` reports; the ``/server_info``
    accounting figure is fetched too, but only for the log line.

    Outcomes:
    - ``node_rank != 0``: nothing is probed, returns ``None``.
    - resident value available and above ``threshold_gb``: ``RuntimeError``.
    - resident value available and at or below the threshold: returns it.
    - resident value unavailable (``None``): logs a warning and returns
      ``None`` without raising, so a missing metric cannot kill an
      otherwise healthy pipeline.
    """
    if self.node_rank != 0:
        return None

    def _gib_or_na(value):
        # Render a GiB figure for the log line, tolerating "not measured".
        return "n/a" if value is None else "%.3f" % value

    probe_log = logging.getLogger(__name__)
    # Accounting is fetched first, then the real per-process measurement.
    account_gb = self._server_info_residual_gb(timeout_s=timeout_s)
    resident_gb = self.get_process_tree_gpu_used_gb(timeout_s=timeout_s)
    resident_text = _gib_or_na(resident_gb)
    account_text = _gib_or_na(account_gb)
    limit_gb = float(threshold_gb)
    probe_log.info(
        "post-sleep residual engine=%s:%s process_resident=%s GiB "
        "server_info_accounting(weight+kvcache+graph)=%s GiB threshold=%.3f GiB",
        self.server_host,
        self.server_port,
        resident_text,
        account_text,
        limit_gb,
    )

    if resident_gb is None:
        probe_log.warning(
            "post-sleep process-resident probe unavailable on engine "
            "%s:%s (nvidia-smi missing or PID-namespace mismatch); "
            "skipping hard gate (fail-open).",
            self.server_host,
            self.server_port,
        )
        return None

    if not resident_gb > limit_gb:
        return resident_gb

    raise RuntimeError(
        f"Post-sleep process-resident GPU memory {resident_gb:.3f} GiB "
        f"exceeds threshold {limit_gb:.3f} GiB on engine "
        f"{self.server_host}:{self.server_port} — offload did not free "
        f"this engine's GPU memory (check release_memory_occupation / "
        f"torch_memory_saver)."
    )

def resume_memory_occupation(self, tags: list[str] = None):
"""
Available tags for multi-stage resume: weights, kv_cache
Expand Down
25 changes: 22 additions & 3 deletions miles/ray/rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -1031,16 +1031,35 @@ def shrink_engines(
)
# Step 4: release memory.
ray.get([h.release_memory_occupation.remote(tags=None) for h in handles])
# Step 5: optional post-sleep VRAM assert.
# Step 5: optional post-sleep VRAM hard gate. Gate on each
# engine's REAL per-process resident GPU memory (nvidia-smi
# compute-apps over the SGLang process tree), NOT /server_info
# accounting — the latter reports KV static-pool size and does
# not drop after a torch_memory_saver pause, so it falsely shows
# ~9 GiB for an offloaded 0.5B engine whose process is actually
# ~1.8 GiB resident. fail-open: an unmeasurable engine (PID-
# namespace mismatch / no nvidia-smi) returns None and is skipped
# rather than killing a healthy pipeline.
if post_sleep_vram_threshold_gb is not None:
ray.get(
observed_resident_gbs = ray.get(
[
h.assert_post_sleep_vram_below_threshold.remote(
h.assert_post_sleep_process_vram_below_threshold.remote(
threshold_gb=post_sleep_vram_threshold_gb
)
for h in handles
]
)
measured = [v for v in observed_resident_gbs if v is not None]
logger.info(
"shrink_engines: post-sleep process-resident GPU residual "
"max=%s GiB per_engine=%s threshold=%.3f GiB "
"engine_indices=%s (real resident; server_info accounting "
"logged per-engine)",
("%.3f" % max(measured)) if measured else "n/a",
[None if v is None else round(float(v), 3) for v in observed_resident_gbs],
float(post_sleep_vram_threshold_gb),
indices,
)
except Exception:
# Reset the abort cache on failure so retry re-aborts new
# in-flights that arrived during the failed cycle.
Expand Down
165 changes: 165 additions & 0 deletions miles/utils/gpu_probe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
"""GPU residual probing helpers — pure, dependency-free, unit-testable.

Used by :class:`SGLangEngine` to measure a SGLang server's REAL per-process
resident GPU memory after offload, via ``nvidia-smi`` compute-apps over the
engine's process tree. Kept free of sglang/torch imports so the parsing and
fail-open logic can be unit-tested without a GPU.

Semantics: ``MILES_MAX_RESIDUAL_GPU_MEM_GB`` is the **max per-GPU** resident
residual for an engine — for each GPU, sum the engine's process-tree usage on
that GPU, then take the max across GPUs. This supports TP>1 without summing
across cards (which would over-count and falsely trip the gate).
"""
from __future__ import annotations

import logging
import shutil
import subprocess

logger = logging.getLogger(__name__)


def build_process_tree(root_pid, proc_root: str = "/proc") -> set:
    """Collect ``root_pid`` and every descendant PID listed under ``proc_root``.

    Parent links are read from ``<proc_root>/<pid>/stat``; no psutil is
    needed. The whole tree is walked because ``self.process.pid`` is the
    multiprocessing spawn parent, while the GPU-resident process is the
    ``sglang::scheduler`` child.

    Returns ``{root_pid}`` alone when ``proc_root`` cannot be listed. Entries
    whose ``stat`` file is unreadable or unparsable are ignored.
    """
    import os

    try:
        names = os.listdir(proc_root)
    except OSError:
        return {root_pid}
    pids = [int(name) for name in names if name.isdigit()]

    kids_of: dict = {}
    for pid in pids:
        stat_path = os.path.join(proc_root, str(pid), "stat")
        try:
            with open(stat_path, "rb") as handle:
                raw = handle.read()
        except OSError:
            continue
        # The comm field is wrapped in parentheses and may itself contain
        # spaces or parentheses, so anchor on the LAST ')'. The tokens after
        # it are: state, ppid, ...
        close = raw.rfind(b")")
        if close < 0:
            continue
        fields = raw[close + 2:].split()
        if len(fields) < 2:
            continue
        try:
            parent = int(fields[1])
        except ValueError:
            continue
        if parent in kids_of:
            kids_of[parent].append(pid)
        else:
            kids_of[parent] = [pid]

    # Level-by-level expansion from the root; `seen` also guards against
    # visiting a PID twice.
    seen = {root_pid}
    frontier = [root_pid]
    while frontier:
        upcoming = []
        for parent in frontier:
            for child in kids_of.get(parent, ()):
                if child in seen:
                    continue
                seen.add(child)
                upcoming.append(child)
        frontier = upcoming
    return seen


def parse_compute_apps_per_gpu_max_gb(nvidia_csv: str, tree_pids: set):
    """Parse ``nvidia-smi --query-compute-apps=gpu_bus_id,pid,used_memory``.

    Rows whose PID is in ``tree_pids`` have their ``used_memory`` (MiB)
    summed per ``gpu_bus_id``; the largest per-GPU sum is returned in GiB.
    This is the ``MILES_MAX_RESIDUAL_GPU_MEM_GB`` semantics: the engine's
    worst single-GPU residual, with no summing across cards.

    Rows with fewer than three columns, or with a non-numeric PID or memory
    value, are ignored. Returns ``None`` (fail-open) if no tree PID matched.
    """
    mib_by_gpu: dict = {}
    for row in nvidia_csv.strip().splitlines():
        cols = [col.strip() for col in row.split(",")]
        if len(cols) < 3:
            continue
        try:
            pid, used_mib = int(cols[1]), float(cols[2])
        except ValueError:
            continue
        if pid not in tree_pids:
            continue
        gpu = cols[0]
        mib_by_gpu[gpu] = mib_by_gpu.get(gpu, 0.0) + used_mib
    # Every matched row creates a key, so an empty map means "no match".
    if not mib_by_gpu:
        return None
    return max(mib_by_gpu.values()) / 1024.0


def parse_compute_apps_used_gb(nvidia_csv: str, tree_pids: set):
    """Parse the legacy two-column ``pid,used_memory`` listing (no ``gpu_bus_id``).

    Adds up ``used_memory`` (MiB) over every row whose PID is in
    ``tree_pids`` and returns the total in GiB. It is the fallback for when
    the GPU-aware query is unavailable. It cannot tell GPUs apart, so it
    over-estimates for a multi-GPU engine.

    Rows with fewer than two columns, or with a non-numeric PID or memory
    value, are ignored. Returns ``None`` (fail-open) if no tree PID matched.
    """
    hits = 0
    total_mib = 0.0
    for row in nvidia_csv.strip().splitlines():
        cols = [col.strip() for col in row.split(",")]
        if len(cols) < 2:
            continue
        try:
            pid, used_mib = int(cols[0]), float(cols[1])
        except ValueError:
            continue
        if pid not in tree_pids:
            continue
        hits += 1
        total_mib += used_mib
    if hits == 0:
        return None
    return total_mib / 1024.0


def _run_nvidia_smi(args, timeout_s: float):
    """Run ``nvidia-smi`` with ``args`` and return its output as text.

    stderr is merged into stdout, and undecodable bytes are replaced rather
    than raising. Returns ``None`` when the subprocess call fails
    (``subprocess.SubprocessError``) or the binary cannot be executed
    (``OSError``).
    """
    command = ["nvidia-smi"] + args
    try:
        raw = subprocess.check_output(
            command, stderr=subprocess.STDOUT, timeout=timeout_s
        )
    except (subprocess.SubprocessError, OSError):
        return None
    return raw.decode("utf-8", errors="replace")


def query_process_tree_gpu_used_gb(root_pid, timeout_s: float = 5.0,
                                   proc_root: str = "/proc"):
    """Return the max per-GPU resident memory (GiB) of ``root_pid``'s process tree.

    The GPU-aware query (``gpu_bus_id,pid,used_memory``) is tried first:
    usage is summed per GPU and the largest sum wins. If that call yields no
    output, a warning is logged and the legacy two-column query
    (``pid,used_memory``, summed over all rows) is used instead.

    Returns ``None`` (fail-open) when ``root_pid`` is ``None``, nvidia-smi is
    not on PATH, both calls fail, or no tree PID appears in the listing
    (e.g. a PID-namespace mismatch inside a container). Callers MUST treat
    ``None`` as "cannot measure", never as 0.
    """
    if root_pid is None:
        return None
    if shutil.which("nvidia-smi") is None:
        return None

    pids = build_process_tree(root_pid, proc_root=proc_root)
    if not pids:
        return None

    csv_format = "--format=csv,noheader,nounits"
    per_gpu_csv = _run_nvidia_smi(
        ["--query-compute-apps=gpu_bus_id,pid,used_memory", csv_format],
        timeout_s,
    )
    if per_gpu_csv is not None:
        return parse_compute_apps_per_gpu_max_gb(per_gpu_csv, pids)

    # The GPU-aware call produced nothing usable: retry with the legacy
    # two-column query, which can only give a cross-GPU sum.
    logger.warning(
        "nvidia-smi gpu_bus_id query unavailable; falling back to "
        "pid,used_memory (summed; cannot distinguish per-GPU)"
    )
    legacy_csv = _run_nvidia_smi(
        ["--query-compute-apps=pid,used_memory", csv_format],
        timeout_s,
    )
    return None if legacy_csv is None else parse_compute_apps_used_gb(legacy_csv, pids)
Loading