Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion megatron/rl/rl_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,29 @@ def get_inference_interface(args, loop, model):
def get_rollout_generator(args, inference_interface, n_prompts, samples_per_group):
global _ROLLOUT_GENERATOR
if not (streaming := args.rl_partial_rollouts) or _ROLLOUT_GENERATOR is None:
# Autotune parallel generation tasks based on engine capacity.
engine = inference_interface._inference_engine
max_requests = engine.context.max_requests
dp_size = len(dist.get_process_group_ranks(engine.pg_collection.dp))
G = args.grpo_group_size
P = args.grpo_prompts_per_step
max_effective_tasks = max(1, dp_size * max_requests // G)
max_effective_lag = max_effective_tasks / P - 1
if args.rl_desired_lag is not None:
if args.rl_parallel_generation_tasks > max_effective_tasks:
print_rank_0(
f"WARNING: --rl-desired-lag {args.rl_desired_lag} results in "
f"{args.rl_parallel_generation_tasks} parallel tasks, which exceeds "
f"the maximum effective {max_effective_tasks} "
f"(DP={dp_size}, max_requests={max_requests}, G={G}). "
f"Maximum effective lag is {max_effective_lag:.2f}.")
else:
args.rl_parallel_generation_tasks = max_effective_tasks
print_rank_0(
f"Autotuned rl_parallel_generation_tasks={max_effective_tasks} "
f"(effective lag={max_effective_lag:.2f}, "
f"DP={dp_size}, max_requests={max_requests}, G={G}).")

agent = get_agent(args, parallel_generation_tasks=args.rl_parallel_generation_tasks)
request = GroupedRolloutRequest(
num_groups=args.rl_generation_batch_size,
Expand All @@ -471,7 +494,7 @@ def get_rollout_generator(args, inference_interface, n_prompts, samples_per_grou
'top_k': args.rl_default_top_k,
},
filter_groups_with_same_reward=args.grpo_filter_groups_with_same_reward,
enforce_order=args.rl_enforce_generation_order,
enforce_order=args.rl_use_strict_lag,
)
_ROLLOUT_GENERATOR = agent.get_grouped_rollouts(request)
return _ROLLOUT_GENERATOR
Expand Down
83 changes: 26 additions & 57 deletions megatron/training/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import argparse
import dataclasses
import json
from math import gcd
import os
from pathlib import Path
import re
Expand Down Expand Up @@ -406,49 +407,26 @@ def validate_args(args, defaults={}):
"installed. See https://github.com/fzyzcjy/torch_memory_saver."
)

# Resolve deprecated --rl-parallel-generation-tasks -> --rl-num-parallel-generations.
assert args.rl_num_parallel_generations is None \
or args.rl_parallel_generation_tasks is None, \
"Cannot specify both --rl-num-parallel-generations and " \
"--rl-parallel-generation-tasks. Use --rl-num-parallel-generations " \
"(--rl-parallel-generation-tasks is deprecated)."
if args.rl_parallel_generation_tasks is not None:
print_rank_0(
"WARNING: --rl-parallel-generation-tasks is deprecated, "
"use --rl-num-parallel-generations instead.")
args.rl_num_parallel_generations = (
args.rl_parallel_generation_tasks * args.grpo_group_size)

# Resolve --rl-num-parallel-generations / --rl-num-parallel-generation-batches.
assert args.rl_num_parallel_generations is None \
or args.rl_num_parallel_generation_batches is None, \
"--rl-num-parallel-generations and --rl-num-parallel-generation-batches " \
"are mutually exclusive."
if args.rl_num_parallel_generations is not None:
# Resolve --rl-desired-lag / --rl-use-strict-lag into internal parameters.
P = args.grpo_prompts_per_step
if args.rl_use_strict_lag:
assert args.rl_desired_lag is not None, \
"--rl-use-strict-lag requires --rl-desired-lag."
if args.rl_desired_lag is not None:
assert args.rl_partial_rollouts, \
"--rl-num-parallel-generations requires --rl-partial-rollouts."
assert args.rl_num_parallel_generations % args.grpo_group_size == 0, \
f"--rl-num-parallel-generations ({args.rl_num_parallel_generations}) " \
f"must be divisible by --grpo-group-size ({args.grpo_group_size})."
args.rl_parallel_generation_tasks = (
args.rl_num_parallel_generations // args.grpo_group_size)
if args.rl_generation_batch_size is None:
"--rl-desired-lag requires --rl-partial-rollouts."
assert args.rl_desired_lag >= -1, \
f"--rl-desired-lag ({args.rl_desired_lag}) must be >= -1."
tasks = max(1, round((args.rl_desired_lag + 1) * P))
args.rl_parallel_generation_tasks = tasks
if args.rl_use_strict_lag:
args.rl_generation_batch_size = gcd(tasks, P)
else:
args.rl_generation_batch_size = 1
elif args.rl_num_parallel_generation_batches is not None:
assert args.rl_partial_rollouts, \
"--rl-num-parallel-generation-batches requires --rl-partial-rollouts."
if args.rl_generation_batch_size is None:
args.rl_generation_batch_size = args.grpo_prompts_per_step
args.rl_parallel_generation_tasks = (
args.rl_num_parallel_generation_batches * args.rl_generation_batch_size)
else:
if args.rl_generation_batch_size is None:
args.rl_generation_batch_size = 1
args.rl_generation_batch_size = 1
args.rl_parallel_generation_tasks = 512

# Derive enforce_order after all resolution is complete.
args.rl_enforce_generation_order = (args.rl_generation_batch_size > 1)

args.grpo_samples_per_iteration = args.grpo_prompts_per_step * args.grpo_group_size

if args.rl_use_sequence_packing:
Expand Down Expand Up @@ -2289,21 +2267,15 @@ def _add_rl_args(parser):
help="Number of GRPO groups (G in the paper).")
group.add_argument('--grpo-group-size', type=int, default=2,
help="Number of samples per a GRPO group.")
group.add_argument('--rl-num-parallel-generations', type=int, default=None,
help='Number of rollouts being generated by the inference engine simultaneously. '
'Internally divided by grpo_group_size. '
'Requires --rl-partial-rollouts. '
'Mutually exclusive with --rl-num-parallel-generation-batches.')
group.add_argument('--rl-num-parallel-generation-batches', type=int, default=None,
help='Number of generation batches in flight. '
'Set to L+1 to allow for L steps of staleness between the inference and training policies. '
'Each batch contains grpo_prompts_per_step groups by default. '
'Requires --rl-partial-rollouts. '
'Mutually exclusive with --rl-num-parallel-generations.')
group.add_argument('--rl-generation-batch-size', type=int, default=None,
help='Override the number of groups per generation batch. '
'Defaults to grpo_prompts_per_step when '
'--rl-num-parallel-generation-batches is set.')
group.add_argument('--rl-desired-lag', type=float, default=None,
help='Desired collection lag: the number of training steps worth of rollouts '
'generated ahead of the current training step. A lag of L means L+1 '
'steps worth of prompt groups are in flight simultaneously. '
'Requires --rl-partial-rollouts.')
group.add_argument('--rl-use-strict-lag', action='store_true', default=False,
help='Enforce strict ordering of generation batches so that the lag is '
'deterministic rather than statistical. When set, rollouts are yielded '
'in batch order. Requires --rl-desired-lag.')
group.add_argument('--grpo-iterations', type=int, default=2,
help="Number of iterations per a GRPO implementation.")
# As in DAPO, we keep upper/lower eps different.
Expand Down Expand Up @@ -2341,8 +2313,7 @@ def _add_rl_args(parser):
help='Allow inference to continue generating rollouts while training updates '
'the policy weights. This enables off-policy training where rollouts may '
'be generated with a stale version of the policy. Use '
'--rl-num-parallel-generations or --rl-num-parallel-generation-batches '
'to control the degree of staleness.')
'--rl-desired-lag to control the degree of staleness.')
group.add_argument('--rl-inference-logprobs-is-correction', action=argparse.BooleanOptionalAction, type=bool, default=False,
help='If set, use inference logprobs in importance sampling correction of the loss.')
group.add_argument('--rl-importance-sampling-truncation-coef', type=float, default=None,
Expand Down Expand Up @@ -2415,8 +2386,6 @@ def _add_rl_args(parser):
help='If set, verify that the model weights were correctly transferred by comparing forward pass outputs on'
'the first swap of model weights.')

group.add_argument('--rl-parallel-generation-tasks', type=int, default=None,
help='Deprecated: use --rl-num-parallel-generations instead.')
group.add_argument('--rl-skip-bos-token', action=argparse.BooleanOptionalAction, type=bool, default=False,
help='Skip BOS token at the beginning of the sequences. Default is False.')
group.add_argument('--rl-inference-parsers', nargs='*', default=[],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ MODEL_ARGS:
--rl-use-sequence-packing: true
--rl-sequence-packing-algo: fifo
--rl-offload-optimizer-during-inference: true
--rl-num-parallel-generations: 2
--rl-desired-lag: -0.5
--cuda-graph-impl: local
--micro-batch-size: 1
--global-batch-size: 4
Expand Down
Loading