Why isn't the reference model re-initialized for each epoch in GRPO? #546

Open
Jerrrrykun opened this issue Feb 3, 2025 · 7 comments

@Jerrrrykun

Hi. Thanks for your efficient updates on the GRPO implementation!

When comparing the grpo_vllm_thread_ray_gtrl.py and the pseudocode from the Deepseek-Math paper, I found that the implemented GRPO trainer does not re-initialize the reference model for each epoch, which contradicts the pseudocode in the paper. (This also holds for PPOTrainer in this repo.)

[Image: the iterative GRPO pseudocode from the DeepSeek-Math paper]

```python
for epoch_idx in range(args.num_epochs):
    b_inds = np.random.permutation(args.local_rollout_batch_size * args.number_samples_per_prompt)
    minibatch_idx = 0
    for mini_batch_start in range(
```

So I wonder whether you are also following the engineering trend in huggingface/trl#1112. (Has anyone tried adding the reference-model update step and demonstrated that, for RL4LLM, it can be skipped? I am just curious, since the re-initialization does not take much computation. Why not follow the pseudocode?)

@hamishivi
Collaborator

Hi, thanks for pointing this out! This is actually iterative GRPO rather than base GRPO, but I agree that this could be a useful update, especially since the deepseek GRPO paper finds it helps performance.

We actually did iterative PPO for Olmo 2 (https://arxiv.org/abs/2501.00656, figure 13). You can do this by just... starting the training job with the latest checkpoint from the last RL job. You could do similar with GRPO. Adding a more explicit flag for when to do this would make the end user experience nicer.

As for the linked comment, I believe historically PPO-trained LMs have not updated the reference model, often because drifting far from the base is linked with extreme degradations. Additionally, when using neural-based rewards, the LM often overoptimizes fairly quickly, and so the training simply doesn't run for that long (and people typically spend a lot of time carefully tuning the KL beta to make sure that the policy doesn't get too far from the original reference). For rule-based rewards, though, it does indeed make more sense to run a more iterative version!
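For context, the "KL beta" mentioned above is the coefficient on a per-token KL penalty that keeps the policy close to the frozen reference. A minimal sketch of that reward-shaping term, with illustrative names and shapes (not this repo's actual code):

```python
import torch

def kl_penalized_rewards(policy_logprobs: torch.Tensor,
                         ref_logprobs: torch.Tensor,
                         scores: torch.Tensor,
                         kl_beta: float = 0.05) -> torch.Tensor:
    """Illustrative only: subtract a per-token KL estimate (policy vs. frozen reference)
    from the reward, and add the sequence-level score at the final token.

    policy_logprobs, ref_logprobs: [batch, seq_len] log-probs of the sampled tokens.
    scores: [batch] scalar reward (e.g. from a reward model or a rule) per response.
    """
    per_token_kl = policy_logprobs - ref_logprobs   # simple "k1" estimator of the KL
    rewards = -kl_beta * per_token_kl               # penalize drifting from the reference
    rewards[:, -1] += scores                        # response-level reward lands on the last token
    return rewards
```

The larger `kl_beta` is, the more strongly the policy is pulled back toward whatever the reference currently is, which is why re-initializing the reference changes the effective constraint.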

@vwxyzjn
Collaborator

vwxyzjn commented Feb 3, 2025

This is kind of interesting: if the reference model is updated very frequently, would that be the same as not having a reference model at all?

@Jerrrrykun
Author

Jerrrrykun commented Feb 3, 2025

> Hi, thanks for pointing this out! This is actually iterative GRPO rather than base GRPO, but I agree that this could be a useful update, especially since the deepseek GRPO paper finds it helps performance.
>
> We actually did iterative PPO for Olmo 2 (https://arxiv.org/abs/2501.00656, figure 13). You can do this by just... starting the training job with the latest checkpoint from the last RL job. You could do similar with GRPO. Adding a more explicit flag for when to do this would make the end user experience nicer.
>
> As for the linked comment, I believe historically PPO-trained LMs have not updated the reference model, often because drifting far from the base is linked with extreme degradations. Additionally, when using neural-based rewards, the LM often overoptimizes fairly quickly, and so the training simply doesn't run for that long (and people typically spend a lot of time carefully tuning the KL beta to make sure that the policy doesn't get too far from the original reference). For rule-based rewards, though, it does indeed make more sense to run a more iterative version!

I see. Thanks for your quick reply! Just saw a related PR in trl: huggingface/trl#2700.

Moreover, it seems that to turn this base GRPO into iterative GRPO, we only need to move the response-sampling step to the start of each epoch and add runtime updating of the vLLM weights, which was recently resolved in vllm-project/vllm#5723. Right?

BTW, the current sampling step is also kind of tricky. In grpo_vllm_thread_ray_gtrl.py, we finish all the response sampling before entering each training epoch (off-policy), but the pseudocode says we should sample for each batch (on-policy). 🤔

@hamishivi
Collaborator

hamishivi commented Feb 4, 2025

It would require a few more changes. The snippet you linked to is actually the ppo_epochs loop (the "GRPO iterations" in the pseudocode you linked), which is different from the data-level epochs (confusing, sorry) -- this is because in RL people sometimes sample data and then perform a few passes (epochs) of training on it with the PPO objective. By default we actually do 4 passes over it in this repo!

This is separate from the epochs in the pseudocode, which are epochs over the dataset -- that is, how many passes we have made over the full data, rather than over some subset of samples.

To change this properly, we would need to track when training_step crosses into a new data epoch and update the reference model accordingly.
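A rough sketch of what that could look like, assuming the trainer can expose the policy and reference models and count steps per data epoch (all names here are hypothetical, not the current trainer's API):

```python
def maybe_refresh_reference(policy_model, ref_model, training_step: int, steps_per_data_epoch: int):
    """Hypothetical helper: re-initialize the frozen reference from the current policy
    whenever training crosses into a new data epoch, as in the iterative GRPO pseudocode."""
    if training_step > 0 and training_step % steps_per_data_epoch == 0:
        ref_model.load_state_dict(policy_model.state_dict())  # copy current policy weights
        ref_model.eval()  # reference stays frozen until the next data-epoch boundary
```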

@hamishivi
Collaborator

hamishivi commented Feb 4, 2025

Also: technically, yes, once you take more than 1 step we are off-policy. In the deepseek paper, they say they do 1 'ppo epoch' (i.e., remain on-policy). But in that case, all the clipping logic used in PPO and included in the GRPO paper doesn't come into effect. PPO/GRPO includes this to ensure we don't go 'too far' from the original policy, so our updates are still 'on-policy-ish' (that is, we haven't gone that far from our original policy).

@vwxyzjn feel free to correct me if I'm slightly wrong here 😅
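For reference, the clipped surrogate shared by PPO and GRPO (up to how the advantage $\hat{A}_t$ is computed), written in standard notation rather than this repo's code:

$$
L^{\text{clip}}(\theta) = \mathbb{E}_t\left[\min\Big(r_t(\theta)\,\hat{A}_t,\ \operatorname{clip}\big(r_t(\theta),\,1-\epsilon,\,1+\epsilon\big)\,\hat{A}_t\Big)\right],
\qquad
r_t(\theta) = \frac{\pi_\theta(o_t \mid q, o_{<t})}{\pi_{\theta_{\text{old}}}(o_t \mid q, o_{<t})}.
$$

On the first (or only, when $\mu = 1$) gradient update after sampling, $\pi_\theta = \pi_{\theta_{\text{old}}}$, so $r_t(\theta) = 1$ and the clip has no effect; it only starts to matter on later passes over the same samples.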

@Jerrrrykun
Author

Jerrrrykun commented Feb 5, 2025

> Also: technically, yes, once you take more than 1 step we are off-policy. In the deepseek paper, they say they do 1 'ppo epoch' (i.e., remain on-policy). But in that case, all the clipping logic used in PPO and included in the GRPO paper doesn't come into effect. PPO/GRPO includes this to ensure we don't go 'too far' from the original policy, so our updates are still 'on-policy-ish' (that is, we haven't gone that far from our original policy).
>
> @vwxyzjn feel free to correct me if I'm slightly wrong here 😅

I see. The clipping logic does not take effect because they say they do only a single policy update per exploration step, which means $\mu$ is 1 in the pseudocode. Thanks for your explanations!

And I still have one more question: for Line 7 in the picture, the GRPO pseudocode says we should sample each batch of prompts with the latest policy model. However, I observe that most GRPO implementations complete all generations for the whole training set at the initialization step. Is this because reloading vLLM with the latest weights and re-sampling at each batch are computation-intensive during training? Or because we run only ~one epoch and the deviation is not that large (so we are okay with it)?

@hamishivi
Collaborator

For this, we actually do resample at each step!

To break it down, at each train step:

  1. We sample from the model at lines 999-1043, using vLLM (https://github.com/allenai/open-instruct/blob/main/open_instruct/grpo_vllm_thread_ray_gtrl.py#L999-L1043).
  2. We post-process the responses, compute reward scores, etc., at lines 1055-1194 (https://github.com/allenai/open-instruct/blob/main/open_instruct/grpo_vllm_thread_ray_gtrl.py#L1055-L1194); the group-relative scoring is sketched right after this list.
  3. We take training steps over these responses, potentially in multiple passes and multiple minibatches, at lines 1198-1291 (https://github.com/allenai/open-instruct/blob/main/open_instruct/grpo_vllm_thread_ray_gtrl.py#L1198-L1291). You can set up the hyperparameters such that you only do one pass and one batch (or simulate one batch with gradient accumulation).
  4. We repeat the process. We call steps 1-4 a GRPO/PPO training step.
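As a concrete illustration of the scoring in step 2, GRPO's distinguishing piece is the group-relative advantage: each response's reward is normalized against the other samples drawn for the same prompt. A self-contained toy sketch (shapes and the epsilon are illustrative, not the repo's code):

```python
import torch

def group_relative_advantages(scores: torch.Tensor, num_samples_per_prompt: int) -> torch.Tensor:
    """Normalize each response's reward against the other samples for the same prompt."""
    grouped = scores.view(-1, num_samples_per_prompt)        # [num_prompts, k]
    mean = grouped.mean(dim=1, keepdim=True)
    std = grouped.std(dim=1, keepdim=True)
    return ((grouped - mean) / (std + 1e-8)).view(-1)        # flatten back to [num_prompts * k]

# Example: 2 prompts, 4 samples each, rule-based 0/1 rewards.
scores = torch.tensor([1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0])
print(group_relative_advantages(scores, num_samples_per_prompt=4))
```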

So our samples are on-policy, but you can configure the trainer to take more and more steps on the same responses, which becomes more 'offline'. The PPO clipping logic will help here, but yes, performance will usually drop if you go more 'offline'.

One final caveat is that we usually run our code async, which means we overlap generation and training -- we let the model take a train step while we generate samples with the weights from the previous step. So, technically, the samples are usually from the model one step prior. There is some past work showing this doesn't harm performance (https://arxiv.org/abs/2410.18252), potentially because being only one step off isn't that bad (versus being many steps off).
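Conceptually, the async overlap just lets generation run one step ahead of training; a minimal stand-alone sketch of the pattern with plain Python threads and placeholder callables (not the repo's Ray/vLLM machinery):

```python
import queue
import threading

def async_rl_loop(generate_fn, train_fn, num_steps: int):
    """Generation for step k+1 runs while the learner trains on step k's samples,
    so each training batch comes from weights that are at most one step stale."""
    sample_q = queue.Queue(maxsize=1)  # at most one batch "in flight" ahead of training

    def generator():
        for step in range(num_steps):
            sample_q.put(generate_fn(step))  # produce samples for the next train step

    threading.Thread(target=generator, daemon=True).start()
    for step in range(num_steps):
        batch = sample_q.get()   # samples produced while the previous step was training
        train_fn(batch)

# Toy usage: async_rl_loop(lambda s: f"samples_{s}", lambda b: print("training on", b), num_steps=3)
```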

Hopefully that makes sense! The code is definitely quite complicated.
