Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 43 additions & 1 deletion megatron/rl/agent/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import asyncio
import logging
from abc import ABC, abstractmethod
from collections.abc import AsyncIterable
from collections.abc import AsyncIterable, AsyncIterator
from typing import Generic, TypeVar

import numpy as np
Expand Down Expand Up @@ -299,6 +299,48 @@ async def shutdown_queue_when_done():
task.cancel()


class RolloutStream(AsyncIterator):
    """Wrapper around an async generator that supports non-blocking drain.

    ``try_next``/``drain`` poll the underlying generator with a short
    timeout. When a poll times out, the in-flight ``__anext__`` task is
    stashed in ``_pending_task`` and resumed by the next call, so the
    generator's ``__anext__`` is never awaited concurrently (async
    generators raise RuntimeError if ``__anext__`` overlaps).
    """

    def __init__(self, inner):
        # The wrapped async iterator/generator.
        self._inner = inner
        # In-flight __anext__ task left over from a timed-out poll, if any.
        self._pending_task = None

    async def __anext__(self):
        """Return the next item, resuming any fetch a timed-out poll left behind."""
        if self._pending_task is not None:
            task = self._pending_task
            self._pending_task = None
            return await task
        return await self._inner.__anext__()

    async def aclose(self):
        """Cancel any in-flight fetch and close the underlying generator."""
        if self._pending_task is not None:
            self._pending_task.cancel()
            self._pending_task = None
        await self._inner.aclose()

    async def try_next(self):
        """Attempt to get the next item without blocking.

        Returns:
            The next item, or None if nothing became available within the
            short poll window or the stream is exhausted. On timeout the
            in-flight fetch is kept so a later call (or ``__anext__``)
            consumes its result instead of starting a duplicate fetch.
        """
        # BUGFIX: resume a pending fetch instead of always starting a new
        # one. Starting a second __anext__ while one is still running
        # raises "RuntimeError: anext(): asynchronous generator is already
        # running" on the very next drain after a timeout.
        task = self._pending_task
        self._pending_task = None
        if task is None:
            task = asyncio.ensure_future(self._inner.__anext__())
        try:
            # shield() keeps the fetch alive when wait_for cancels it on timeout.
            return await asyncio.wait_for(asyncio.shield(task), timeout=0.01)
        except asyncio.TimeoutError:
            self._pending_task = task
            return None
        except StopAsyncIteration:
            return None

    def drain(self, n, loop):
        """Synchronously drain up to n items.

        Stops early as soon as a poll yields no item (timeout or stream
        exhaustion). NOTE(review): a legitimate None item would also stop
        the drain — callers are expected never to yield None.
        """
        items = []
        for _ in range(n):
            item = loop.run_until_complete(self.try_next())
            if item is None:
                break
            items.append(item)
        return items


class EvaluationAgent(Agent, ABC):
"""An agent that can take an inference interface and return a benchmark score."""

Expand Down
104 changes: 68 additions & 36 deletions megatron/rl/rl_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
RewardEvaluationResult,
Rollout,
RolloutGroup,
RolloutStream,
Rollouts,
TokenRollout,
)
Expand Down Expand Up @@ -473,23 +474,15 @@ def get_rollout_generator(args, inference_interface, n_prompts, samples_per_grou
filter_groups_with_same_reward=args.grpo_filter_groups_with_same_reward,
enforce_order=args.rl_enforce_generation_order,
)
_ROLLOUT_GENERATOR = agent.get_grouped_rollouts(request)
_ROLLOUT_GENERATOR = RolloutStream(agent.get_grouped_rollouts(request))
return _ROLLOUT_GENERATOR


def get_environment_rollouts(
model: LanguageModule, inference_model: LanguageModule, optimizer: MegatronOptimizer, n_prompts: int, samples_per_group: int
):
"""Sample environment rollouts from an LLM.

Args:
model: Model to sample from.
inference_model: Inference model to use for inference.
n_prompts: Number of prompts to sample for across *all* data parallel workers.
samples_per_group: Amount of trajectories per prompt.
def colocated_inference(model, inference_model, optimizer, n_prompts, samples_per_group, rank):
"""Enter inference mode and collect rollouts via the colocated inference engine.

Returns:
GroupedRollouts object which is a nested list with each element being a list of rollouts of a group.
Handles optimizer offload/restore, model weight swap, and CUDA graph management.
Returns the collected rollouts on rank 0, None on other ranks.
"""
args = get_args()
nvtx_range = get_nvtx_range()
Expand Down Expand Up @@ -524,10 +517,7 @@ def get_environment_rollouts(
else:
inference_model = model

inference_pg_collection = get_attr_wrapped_model(inference_model[0], "pg_collection")
pg_size = get_pg_size(inference_pg_collection.ep)
assert (n_prompts % pg_size == 0), f"{n_prompts=} must be divisible by {pg_size=}"

rollouts = None
with nvtx_range("rl/rollout-collection", time=True):
loop = get_asyncio_loop()
with megatron_rl_inference_mode(
Expand All @@ -544,8 +534,6 @@ def get_environment_rollouts(
args, inference_interface, n_prompts, samples_per_group
)

# NOTE(jbarker): we need to double check this when using PP>1
rank = torch.distributed.get_rank()
with nvtx_range("rl/collect-rollouts", time=True):
if rank == 0:
log_single_rank(
Expand All @@ -567,15 +555,6 @@ def get_environment_rollouts(
assert False, "Unexpected group left in generator."
except StopAsyncIteration:
break
else:
# Just set up space to collect the rollouts
rollouts = [[None for _ in range(samples_per_group)] for _ in range(n_prompts)]

with nvtx_range("rl/sync-rollouts", time=True):
# Wait for Rollouts to be collected
# TODO(jbarker): double check why this isn't causing rank 0 memory allocations
torch.distributed.broadcast_object_list(rollouts, src=0)
logger.debug(f"Got rollouts on rank {rank}")

if args.rl_offload_optimizer_during_inference:
with nvtx_range("rl/restore-optimizer-after-inference", time=True):
Expand All @@ -584,6 +563,67 @@ def get_environment_rollouts(
with nvtx_range("rl/restore/optimizer-state", time=True):
optimizer.restore_from_cpu()

return rollouts


def get_environment_rollouts(
model: LanguageModule,
inference_model: LanguageModule,
optimizer: MegatronOptimizer,
n_prompts: int,
samples_per_group: int,
):
"""Sample environment rollouts from an LLM.

Args:
model: Model to sample from.
inference_model: Inference model to use for inference.
n_prompts: Number of prompts to sample for across *all* data parallel workers.
samples_per_group: Amount of trajectories per prompt.

Returns:
GroupedRollouts object which is a nested list
where each element being a list of rollouts of a group.
"""
args = get_args()
nvtx_range = get_nvtx_range()
# NOTE(jbarker): we need to double check this when using PP>1
rank = torch.distributed.get_rank()

inference_pg_collection = get_attr_wrapped_model(model[0], "pg_collection")
pg_size = get_pg_size(inference_pg_collection.ep)
assert (n_prompts % pg_size == 0), f"{n_prompts=} must be divisible by {pg_size=}"

fast_drained = False

# Fast path: when partial rollouts are enabled and the generator already has
# pre-generated groups buffered, drain them without entering inference mode
# (skips optimizer offload, weight swap, CUDA graph toggle, etc.).
if args.rl_partial_rollouts and _ROLLOUT_GENERATOR is not None:
loop = get_asyncio_loop()
if rank == 0:
drained = _ROLLOUT_GENERATOR.drain(n_prompts, loop)
fast_drained = len(drained) >= n_prompts
# All ranks must agree on the path (full inference has collective ops)
flag = torch.tensor([fast_drained], device='cuda', dtype=torch.bool)
torch.distributed.broadcast(flag, src=0)
fast_drained = flag.item()

if fast_drained and rank == 0:
rollouts = drained[:n_prompts]

if not fast_drained:
rollouts = colocated_inference(
model, inference_model, optimizer, n_prompts, samples_per_group, rank
)

# Shared broadcast + return path (used by both fast drain and full inference)
if rank != 0:
rollouts = [[None for _ in range(samples_per_group)] for _ in range(n_prompts)]
with nvtx_range("rl/sync-rollouts", time=True):
torch.distributed.broadcast_object_list(rollouts, src=0)
logger.debug(f"Got rollouts on rank {rank}")

if lang_rl_log_dir and rank == get_pg_rank(inference_pg_collection.tp):
with open(
lang_rl_log_dir
Expand Down Expand Up @@ -1516,14 +1556,6 @@ def prepare_data_for_update(
data = TensorDataset(*dataset_tensors)
loader = DataLoader(data, batch_size=args.micro_batch_size)

with nvtx_range("rl/log-wandb-tb", time=True):
maybe_log_training_metrics(
group_stats=group_stats,
current_iteration=args.curr_iteration,
tokenizer=tokenizer,
example_groups=example_groups,
)

return RerunDataIterator(itertools.cycle(loader)), group_stats, example_groups


Expand Down
70 changes: 69 additions & 1 deletion tests/unit_tests/rl/test_grouped_rollouts.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,42 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

import asyncio
from unittest.mock import MagicMock
from contextlib import nullcontext
from unittest.mock import MagicMock, patch

import pytest
import torch

from megatron.rl import rl_utils
from megatron.rl.agent.api import (
GroupedRolloutGenerator,
GroupedRolloutRequest,
Rollout,
RolloutGenerator,
RolloutGroup,
RolloutStream,
)
from megatron.rl.agent.weighted_multi_task import AgentConfig, WeightedMultiTask
from megatron.rl.inference import ReturnsRaw


def _make_group(i, rollouts_per_group=1):
    """Build a RolloutGroup holding `rollouts_per_group` dummy rollouts for batch `i`."""
    group_rollouts = []
    for _ in range(rollouts_per_group):
        group_rollouts.append(
            Rollout(
                trajectory=[f"t{i}"],
                reward=float(i),
                policy_epoch=[[(0, 0)]],
                kv_cache_epoch=[[(0, 0)]],
                num_evictions=[0],
            )
        )
    return RolloutGroup(rollouts=group_rollouts, batch_id=i, index_in_batch=0)


class MockGenerator(RolloutGenerator, GroupedRolloutGenerator):
"""Mock generator with configurable per-call delays."""

Expand Down Expand Up @@ -117,3 +138,50 @@ async def spy(req, orig=original):
assert sub_req.num_groups in (1, 3) # distributed proportionally by weight
assert sub_req.enforce_order == request.enforce_order
assert sub_req.streaming == request.streaming


@pytest.mark.parametrize("buffered_groups, expect_colocated_call", [
    pytest.param(6, False, id="drain_sufficient_skips_inference"),
    pytest.param(1, True, id="drain_insufficient_calls_inference"),
])
def test_get_environment_rollouts(self, buffered_groups, expect_colocated_call):
    """Verify the fast-drain path of get_environment_rollouts.

    When the partial-rollout stream already buffers at least n_prompts
    groups, colocated_inference must be skipped; when it buffers fewer,
    colocated_inference must be called.
    """
    n_prompts = 4

    async def gen():
        # Yield `buffered_groups` ready-made groups, then block so the
        # stream's drain sees "nothing available" afterwards.
        for i in range(buffered_groups):
            yield _make_group(i)
        # Block forever so drain sees "nothing available" after the buffered items.
        await asyncio.sleep(1000)

    def mock_nvtx(*args, **kwargs):
        # Stand-in for get_nvtx_range(): a no-op context manager factory.
        return nullcontext()

    mock_args = MagicMock()
    # Enables the fast-drain branch under test.
    mock_args.rl_partial_rollouts = True
    mock_args.curr_iteration = 1
    mock_args.langrl_env_config = "test.yaml"

    loop = asyncio.new_event_loop()
    # Stub for the slow path: returns a full batch so the function's
    # return shape is valid when the drain is insufficient.
    mock_colocated = MagicMock(return_value=[_make_group(i) for i in range(n_prompts)])
    try:
        # Patch out distributed/inference machinery; rank 0 so the drain
        # runs locally and the broadcasts are no-ops.
        with patch.multiple('megatron.rl.rl_utils',
            colocated_inference=mock_colocated,
            get_args=MagicMock(return_value=mock_args),
            get_nvtx_range=MagicMock(return_value=mock_nvtx),
            get_asyncio_loop=MagicMock(return_value=loop),
            get_attr_wrapped_model=MagicMock(return_value=MagicMock()),
            get_pg_size=MagicMock(return_value=1),
            get_pg_rank=MagicMock(return_value=0),
            lang_rl_log_dir=None,
            _ROLLOUT_GENERATOR=RolloutStream(gen()),
        ), patch('torch.distributed.get_rank', return_value=0), \
            patch('torch.distributed.broadcast'), \
            patch('torch.distributed.broadcast_object_list'):
            rollouts = rl_utils.get_environment_rollouts(
                model=[MagicMock()], inference_model=None,
                optimizer=MagicMock(), n_prompts=n_prompts, samples_per_group=1,
            )
        # The slow path runs exactly when the buffered groups were not enough.
        assert mock_colocated.called == expect_colocated_call
        assert len(rollouts) == n_prompts
    finally:
        # Always release the event loop, even if the assertions fail.
        loop.close()
Loading