Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions nemo_rl/algorithms/async_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,15 @@ def get_existing_target_weights(self) -> set[int]:
with self._lock:
return set(self.target_weight_versions)

def count_target_weight(self, target_weight_version: int) -> int:
    """Count buffered entries tagged with the given target weight version.

    Args:
        target_weight_version: Target weight version (training step) to
            look for among the buffered entries.

    Returns:
        How many items in ``self.target_weight_versions`` compare equal to
        ``target_weight_version``; 0 when there are none.
    """
    matches = 0
    # Hold the buffer lock for the whole scan so the sequence is not
    # modified by another thread while it is being iterated.
    with self._lock:
        for buffered_version in self.target_weight_versions:
            if buffered_version == target_weight_version:
                matches += 1
    return matches

def sample(
self,
num_prompt_groups: int,
Expand Down Expand Up @@ -247,6 +256,7 @@ def __init__(
master_config: MasterConfig,
replay_buffer: Any,
start_step: int = 0,
rlix_hooks: Any = None,
):
self.policy_generation = policy_generation
self.tokenizer = tokenizer
Expand All @@ -255,6 +265,17 @@ def __init__(
self.replay_buffer = replay_buffer
self.running = False

# F9: Optional RLix hooks for progress reporting.
# When provided, begin/end_progress_batch is driven inside this actor so
# local hook state is consistent with per-push reports.
# Defaults to NoOpRLixHooks so standalone runs need not pass anything.
if rlix_hooks is None:
from nemo_rl.algorithms.rlix_hooks import NoOpRLixHooks

rlix_hooks = NoOpRLixHooks()
self._rlix_hooks = rlix_hooks
self._rlix_progress_step: Optional[int] = None

self._pg_lock: _threading.Lock = _threading.Lock()

# Event for manual pause/resume control
Expand Down Expand Up @@ -352,6 +373,19 @@ def set_weight_version(self, version: int) -> None:
else:
print(f"🔄 Updated weight version to {version}")

def begin_progress_batch(self, step: int, count_intended: int) -> None:
    """Open progress reporting for ``step`` and catch up on buffered work.

    The collector can produce trajectories targeted at future steps, so
    only the step recorded here is reported on each push. That keeps hook
    implementations that track begin/end state locally from seeing
    mismatched target versions.

    Args:
        step: Training step whose progress should be reported from now on.
        count_intended: Forwarded unchanged to the hook's
            ``begin_progress_batch`` call.
    """
    # Record the active step before notifying the hooks, so pushes for
    # this step start being reported as soon as the stream is opened.
    self._rlix_progress_step = step
    hooks = self._rlix_hooks
    hooks.begin_progress_batch(step, count_intended)

    # Entries buffered for this step before it became active were never
    # reported; ask the replay buffer how many there are.
    # NOTE(review): a push for `step` that lands between the assignment
    # above and this count may be reported both per-push and here — confirm
    # whether that double-report can occur. The per-push path reports
    # `repeated_batch.size` while this path reports a group count; verify
    # the units agree.
    pending_count = self.replay_buffer.count_target_weight.remote(step)
    catch_up = ray.get(pending_count)
    if not catch_up:
        return
    hooks.end_progress_batch(step, catch_up)

def _should_pause_for_generation_limits(self) -> bool:
"""Check if collection should be paused due to generation limits."""
try:
Expand Down Expand Up @@ -703,6 +737,15 @@ def _run_prompt_group_worker(
f"📦 Buffered per-prompt group (prompt_idx {prompt_idx}, target_weight {target_weight_version})"
)

# F9: Report only the target step currently being
# scheduled. Future-targeted trajectories are counted
# when that step becomes active, including a catch-up
# count in begin_progress_batch().
if target_weight_version == self._rlix_progress_step:
self._rlix_hooks.end_progress_batch(
target_weight_version, repeated_batch.size
)

# Release reservation when FIRST prompt group for this target is successfully buffered
if prompt_idx == 0:
with self._generation_check_lock:
Expand Down
146 changes: 112 additions & 34 deletions nemo_rl/algorithms/grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2376,6 +2376,7 @@ def async_grpo_train(
grpo_save_state: GRPOSaveState,
master_config: MasterConfig,
max_trajectory_age_steps: int = 1,
rlix_hooks: Optional[Any] = None,
) -> None:
"""Run asynchronous GRPO training with replay buffer.

Expand All @@ -2394,6 +2395,17 @@ def async_grpo_train(
master_config: Master configuration
max_trajectory_age_steps: Maximum age (in training steps) for trajectories to be used in training
"""
# F5/F11: RLix integration flag.
# True when RLIX_CONTROL_PLANE=rlix env var is set; False in standalone mode.
# Controls: skip standalone refit, enable before/after_training hooks, and
# skip prepare_for_generation() / refit_policy_generation() which conflict
# with scheduler-driven sleep/wake.
DO_TIME_SHARING: bool = os.environ.get("RLIX_CONTROL_PLANE") == "rlix"

# F5/F9: Resolve hooks — use injected real implementation or no-op default.
from nemo_rl.algorithms.rlix_hooks import NoOpRLixHooks, RLixHooksProtocol
hooks: RLixHooksProtocol = rlix_hooks if rlix_hooks is not None else NoOpRLixHooks()

# Ensure we are running with a compatible async generation backend
assert _should_use_async_rollouts(master_config), (
"Async GRPO requires vLLM backend with vllm_cfg.async_engine=True. "
Expand Down Expand Up @@ -2515,7 +2527,10 @@ def async_grpo_train(
},
}

nccl_state_snapshot: dict[str, Any] | None = None

# Initialize trajectory collector with synchronized collection
# F9: Pass rlix_hooks so ATC can call end_progress_batch after each push.
trajectory_collector = AsyncTrajectoryCollector.options(
runtime_env=_tc_runtime_env
).remote(
Expand All @@ -2525,44 +2540,58 @@ def async_grpo_train(
master_config=master_config,
replay_buffer=replay_buffer,
start_step=step,
rlix_hooks=hooks,
)

# Start trajectory collection in background
collection_task = trajectory_collector.start_collection.remote(dataloader)

# Ensure collector knows initial weight version
trajectory_collector.set_weight_version.remote(weight_version)

# F6: Register collector handle with pipeline actor so _expand_workers can
# call set_weight_version after each selective sync (before routing activation).
hooks.on_trajectory_collector_created(trajectory_collector)

# F9: Progress begin/end state lives in the collector actor because that is
# where per-push reports are emitted. Open the first stream before
# collection starts.
ray.get(trajectory_collector.begin_progress_batch.remote(step, num_prompts_per_step))

# Start trajectory collection in background
collection_task = trajectory_collector.start_collection.remote(dataloader)

print("📦 Started continuous background trajectory collection")

print(
f"🚀 Starting async GRPO training with buffer_size={optimal_buffer_size}, max_age={max_trajectory_age_steps} steps"
)

print("⏳ Preparing policy generation for training...")
if NEED_REFIT and POLICY_GENERATION_STALE:
print("🔄 Refitting policy generation with actual model weights...")
try:
refit_policy_generation(policy, policy_generation, colocated_inference)
print("✅ Policy generation refit completed successfully")
POLICY_GENERATION_STALE = False
except Exception as e:
print(f"❌ Policy generation refit failed: {e}")
import traceback

traceback.print_exc()
return
else:
print("🔄 Preparing policy generation for inference...")
try:
policy_generation.prepare_for_generation()
print("✅ Policy generation preparation completed successfully")
except Exception as e:
print(f"❌ Policy generation preparation failed: {e}")
import traceback

traceback.print_exc()
return
# F5/F11: In RLix mode, skip initial refit and prepare_for_generation.
# Weights are synced on first scheduler expand; sleep/wake is scheduler-driven.
# Calling prepare_for_generation here would reinitialize already-running inference workers.
if not DO_TIME_SHARING:
if NEED_REFIT and POLICY_GENERATION_STALE:
print("🔄 Refitting policy generation with actual model weights...")
try:
refit_policy_generation(policy, policy_generation, colocated_inference)
print("✅ Policy generation refit completed successfully")
POLICY_GENERATION_STALE = False
except Exception as e:
print(f"❌ Policy generation refit failed: {e}")
import traceback

traceback.print_exc()
return
else:
print("🔄 Preparing policy generation for inference...")
try:
policy_generation.prepare_for_generation()
print("✅ Policy generation preparation completed successfully")
except Exception as e:
print(f"❌ Policy generation preparation failed: {e}")
import traceback

traceback.print_exc()
return

print("✅ Policy generation setup complete, proceeding to validation...")

Expand Down Expand Up @@ -2782,6 +2811,17 @@ def async_grpo_train(

# Training phase (same as sync version)
print("▶ Preparing for logprob inference...")
# F5: Block until scheduler grants actor_train GPUs.
# In RLix mode: scheduler asynchronously shrinks overlap inference
# workers before returning. In standalone mode: no-op.
hooks.before_training(step)
if DO_TIME_SHARING and nccl_state_snapshot:
from nemo_rl.models.megatron.nccl_offload import (
reload_megatron_nccl_groups,
)

reload_megatron_nccl_groups(nccl_state_snapshot)
nccl_state_snapshot = None
with timer.time("logprob_inference_prep"):
policy.prepare_for_lp_inference()

Expand Down Expand Up @@ -2853,7 +2893,44 @@ def async_grpo_train(

print("🔄 Synchronizing policy weights to trajectory collector…")
generation_logger_metrics = None
if NEED_REFIT:
if DO_TIME_SHARING:
# F5/F11: RLix mode — replace standalone refit with scheduler-
# driven expand. The scheduler's resize_infer(add=overlap_ranks)
# calls pipeline._expand_workers() which does the atomic
# wake + selective sync + version update + routing activation (F6).
#
with timer.time("weight_sync"):
# F11: PR #4 owns the Megatron NCCL destroy/reload
# implementation. This branch only invokes it after
# NeMo has offloaded training-side state.
policy.offload_after_refit()
from nemo_rl.models.megatron.nccl_offload import (
destroy_megatron_nccl_groups,
)

nccl_stats = destroy_megatron_nccl_groups()
nccl_state_snapshot = nccl_stats.get("state_snapshot") or None

# Notify scheduler: actor_train GPUs are free.
# Scheduler asynchronously triggers expand + weight sync.
published_version = hooks.after_training(step)
# RLix publishes version=cache_ready_step after active
# refresh completes. Fall back to step for older hooks.
weight_version = (
int(published_version)
if published_version is not None
else int(step)
)
next_progress_step = step + 1
if next_progress_step < master_config["grpo"]["max_num_steps"]:
ray.get(
trajectory_collector.begin_progress_batch.remote(
next_progress_step, num_prompts_per_step
)
)
POLICY_GENERATION_STALE = False
elif NEED_REFIT:
# Standalone mode — original refit path.
# Measure pending-generation wait as exposed_generation time
print("🔄 Coordinating with trajectory collector before refit...")
with timer.time("exposed_generation"):
Expand Down Expand Up @@ -2894,13 +2971,14 @@ def async_grpo_train(
# Pause trajectory collection during validation to reduce memory pressure
trajectory_collector.pause.remote()

if NEED_REFIT and POLICY_GENERATION_STALE:
refit_policy_generation(
policy, policy_generation, colocated_inference
)
POLICY_GENERATION_STALE = False
else:
policy_generation.prepare_for_generation()
if not DO_TIME_SHARING:
if NEED_REFIT and POLICY_GENERATION_STALE:
refit_policy_generation(
policy, policy_generation, colocated_inference
)
POLICY_GENERATION_STALE = False
else:
policy_generation.prepare_for_generation()
val_metrics, validation_timings = validate(
policy_generation,
val_dataloader,
Expand Down
Loading