
Scheduler Issue in PPO/GRPO implementation #537

Closed
ashish230897 opened this issue Jan 31, 2025 · 0 comments · Fixed by #560

Comments

@ashish230897

Hi,

The following piece of code in both the PPO and GRPO codebases looks incorrect when args.number_samples_per_prompt > 1:

args.num_training_steps = args.total_episodes // (args.rollout_batch_size * args.number_samples_per_prompt)
num_training_steps = args.num_training_steps * args.num_train_epochs * args.num_epochs
warm_up_steps = args.warm_up_steps
if args.warmup_ratio >= 0.0:
    warm_up_steps = int(num_training_steps * args.warmup_ratio)
scheduler = get_scheduler(
    args.lr_scheduler_type,
    optimizer=self.optimizer,
    num_warmup_steps=warm_up_steps,
    num_training_steps=num_training_steps,
)

When args.number_samples_per_prompt > 1, dividing by args.number_samples_per_prompt makes the learning rate decay to 0 too fast.
With args.number_samples_per_prompt > 1, multiple optimizer updates take place within a single training step, and every model.step() inside a training step advances the scheduler, so the learning rate decreases and reaches 0 earlier than it should.
The num_training_steps above works correctly only when args.number_samples_per_prompt is 1.
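
For concreteness, here is a small arithmetic sketch of the mismatch; the numbers are made up for illustration and are not from the repo:

# Illustrative values only (not from the repo).
total_episodes = 1024
rollout_batch_size = 64
number_samples_per_prompt = 4
num_epochs = 1

# Current code: the per-prompt sample count shrinks the step count.
num_training_steps = total_episodes // (rollout_batch_size * number_samples_per_prompt)  # 4

# But each training step performs number_samples_per_prompt optimizer updates,
# and each update advances the scheduler, so it actually ticks this many times:
actual_scheduler_ticks = num_training_steps * num_epochs * number_samples_per_prompt  # 16

# A linear schedule told num_training_steps=4 but stepped 16 times
# drives the learning rate to 0 after the first quarter of training.
print(num_training_steps, actual_scheduler_ticks)  # 4 16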

Also, the args.num_train_epochs factor seems redundant in this code.

To fix this, we can edit the code as follows:

num_scheduler_steps = args.num_training_steps * args.num_epochs * args.number_samples_per_prompt
warm_up_steps = args.warm_up_steps
if args.warmup_ratio >= 0.0:
    warm_up_steps = int(num_scheduler_steps * args.warmup_ratio)
scheduler = get_scheduler(
    args.lr_scheduler_type,
    optimizer=self.optimizer,
    num_warmup_steps=warm_up_steps,
    num_training_steps=num_scheduler_steps,
)
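
As a sanity check, the following standalone sketch (my own code, not from the repo; the step counts are illustrative) steps a "linear" scheduler from transformers for the corrected number of steps and shows the learning rate reaching 0 only on the final tick:

import torch
from transformers import get_scheduler

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=1e-5)

# Corrected count: one scheduler tick per optimizer update.
num_scheduler_steps = 16  # e.g. 4 training steps * 1 epoch * 4 samples per prompt
scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=2,
    num_training_steps=num_scheduler_steps,
)

for step in range(num_scheduler_steps):
    optimizer.step()
    scheduler.step()
    print(step, scheduler.get_last_lr()[0])  # decays to 0 exactly at the last step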