Hi,
This piece of code in both the ppo and grpo codebases seems incorrect (for args.number_samples_per_prompt > 1):
When args.number_samples_per_prompt > 1, dividing by args.number_samples_per_prompt here makes the learning rate decay to 0 too quickly.

With args.number_samples_per_prompt > 1, multiple optimizer updates take place within a single training step: every model.step() inside a training step advances the scheduler, so the learning rate decreases and reaches 0 earlier than it should.

The num_training_steps above is only correct when args.number_samples_per_prompt is 1.
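To make the mismatch concrete, here is a small self-contained sketch (the numbers and names are hypothetical, mirroring the issue rather than the actual codebase). A linear scheduler sized with the divided-down horizon hits 0 partway through training, while one sized by the true number of model.step() calls decays as intended:

```python
def linear_lr(opt_step, scheduler_horizon, base_lr=1e-5):
    """Linear decay to 0 over `scheduler_horizon` optimizer steps."""
    return base_lr * max(0.0, 1.0 - opt_step / scheduler_horizon)

number_samples_per_prompt = 4
outer_training_steps = 100  # iterations of the training loop
# Each outer step performs number_samples_per_prompt model.step() calls,
# so the scheduler is actually stepped this many times:
actual_optimizer_steps = outer_training_steps * number_samples_per_prompt  # 400

# Buggy horizon: the step count divided by number_samples_per_prompt.
buggy_horizon = actual_optimizer_steps // number_samples_per_prompt  # 100

# Halfway through training (optimizer step 200) the LR is already 0 ...
assert linear_lr(200, buggy_horizon) == 0.0
# ... while a horizon matching the real optimizer-step count still has half the LR left.
assert linear_lr(200, actual_optimizer_steps) == 0.5e-5
```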
Also, args.num_train_epochs appears to be unused (redundant) in the code.
To fix this, we can edit the code as:
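(A hedged sketch of one possible fix; the helper and argument names below are assumptions for illustration, not the codebase's actual API. The idea is to size the scheduler by the total number of optimizer steps, i.e. multiply by number_samples_per_prompt instead of dividing by it.)

```python
# Hypothetical helper: compute the scheduler horizon from the total number of
# model.step() calls, since each outer training step performs
# number_samples_per_prompt optimizer updates.
def scheduler_num_training_steps(total_episodes, rollout_batch_size,
                                 number_samples_per_prompt):
    outer_steps = total_episodes // rollout_batch_size
    return outer_steps * number_samples_per_prompt

print(scheduler_num_training_steps(10_000, 100, 4))  # 400
```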