Why isn't the reference model re-initialized for each epoch in GRPO? #546
Hi, thanks for pointing this out! This is actually iterative GRPO rather than base GRPO, but I agree that this could be a useful update, especially since the DeepSeek GRPO paper finds it helps performance. We actually did iterative PPO for OLMo 2 (https://arxiv.org/abs/2501.00656, Figure 13). You can do this today by simply starting the next training job from the latest checkpoint of the previous RL job; you could do the same with GRPO. Adding a more explicit flag for when to do this would make the end-user experience nicer.

As for the linked comment, I believe PPO-trained LMs have historically not updated the reference model, often because drifting too far from the base model is linked with severe degradation. Additionally, when using neural reward models, the LM often overoptimizes fairly quickly, so training simply doesn't run for that long (and people typically spend a lot of time carefully tuning the KL beta to make sure the policy doesn't stray too far from the original reference). For rule-based rewards, though, it does indeed make more sense to run a more iterative version!
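To make the "start the next job from the latest checkpoint" suggestion concrete, here is a minimal sketch of such an outer loop. The flag names and paths are illustrative assumptions, not the exact open-instruct CLI:

```python
import subprocess

# Hypothetical driver for iterative RL by chaining training jobs: each round
# starts from the previous round's final checkpoint, so the reference model
# (loaded from the starting checkpoint at launch) is effectively refreshed
# every round. Flag names and paths below are illustrative, not the exact CLI.
checkpoint = "path/to/initial_checkpoint"
for round_idx in range(3):
    output_dir = f"output/grpo_round_{round_idx}"
    subprocess.run(
        [
            "python", "open_instruct/grpo_vllm_thread_ray_gtrl.py",
            "--model_name_or_path", checkpoint,
            "--output_dir", output_dir,
        ],
        check=True,
    )
    # The next round's policy *and* reference both start from this checkpoint.
    checkpoint = output_dir
```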
This is kind of interesting: if the reference model is updated very frequently, would that be the same as not having a reference model at all?
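For the KL term, at least, that intuition can be checked directly against the per-token estimator used in the GRPO paper:

$$
\mathbb{D}_{\mathrm{KL}}\left[\pi_\theta \,\|\, \pi_{\mathrm{ref}}\right]
= \frac{\pi_{\mathrm{ref}}(o_{i,t} \mid q, o_{i,<t})}{\pi_\theta(o_{i,t} \mid q, o_{i,<t})}
- \log \frac{\pi_{\mathrm{ref}}(o_{i,t} \mid q, o_{i,<t})}{\pi_\theta(o_{i,t} \mid q, o_{i,<t})} - 1,
$$

which is exactly 0 when $\pi_{\mathrm{ref}} = \pi_\theta$, so a reference refreshed every step would contribute no KL penalty at all.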
I see. Thanks for your quick reply! Just saw a mirror PR in

Moreover, it seems that to turn this base GRPO into iterative GRPO, we would only need to move the response-sampling step to the start of each epoch and add runtime updating of the vLLM weights, which was resolved recently (vllm-project/vllm#5723). Right?

BTW, the current sampling step is also kind of tricky. In
It would require a few more changes. The snippet you linked to is actually the

This is separate from the epochs in the pseudocode, which are epochs over the dataset -- that is, how many passes over the data we have done, rather than passes over some subset of samples. To change this properly, we would need to track when training_step crosses into a new epoch and update the reference model accordingly.
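To illustrate the kind of change described above, a rough sketch with hypothetical names (the real trainer distributes the models across Ray/DeepSpeed workers, so the actual change is more involved):

```python
import torch

def maybe_refresh_reference(
    policy: torch.nn.Module,
    ref_policy: torch.nn.Module,
    training_step: int,
    steps_per_dataset_epoch: int,
) -> None:
    """Copy the current policy weights into the frozen reference model whenever
    training_step crosses into a new pass over the dataset (a dataset epoch)."""
    if training_step > 0 and training_step % steps_per_dataset_epoch == 0:
        ref_policy.load_state_dict(policy.state_dict())
        for param in ref_policy.parameters():
            param.requires_grad_(False)
```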
Also: technically, yes, once you take more than 1 step we are off-policy. In the DeepSeek paper, they say they do 1 'PPO epoch' (i.e., they remain on-policy), but in that case all the clipping logic used in PPO and included in the GRPO paper never comes into effect. PPO/GRPO includes this clipping to ensure we don't go 'too far' from the original policy, so our updates are still 'on-policy-ish' (that is, we haven't gone that far from the policy that generated the samples). @vwxyzjn feel free to correct me if I'm slightly wrong here 😅
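For concreteness, the clipping being referred to is the standard PPO-style clipped surrogate (a generic per-token PyTorch sketch, not code from this repo):

```python
import torch

def clipped_policy_loss(
    logprobs: torch.Tensor,      # log-probs under the current policy
    old_logprobs: torch.Tensor,  # log-probs under the policy that sampled the responses
    advantages: torch.Tensor,    # group-normalized advantages in GRPO
    clip_eps: float = 0.2,
) -> torch.Tensor:
    # Importance ratio between the current policy and the sampling policy.
    ratio = torch.exp(logprobs - old_logprobs)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    # Taking the elementwise minimum cuts the gradient off once the ratio
    # leaves [1 - eps, 1 + eps] in the direction the advantage favors.
    return -torch.min(unclipped, clipped).mean()
```

With a single policy update per batch of samples, `logprobs == old_logprobs`, so the ratio is exactly 1 and the clipping never activates -- which is the point above.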
I see. The clipping logic does not take effect since, as they said, they only make a single policy update at each exploration step, which means we have the

And I still have one more question: for Line 7 in the picture, the GRPO pseudocode says that we should sample prompts in each batch with the latest policy model. However, I observe that most GRPO implementations complete all generations for the whole training set at initialization. Is this because reloading vLLM with the latest weights and re-sampling at each batch would be too computationally expensive during training? Or because we only run ~one epoch, so the deviation is small enough to be acceptable?
For this, we actually do resample at each step! To break it down, at each train step:

1. we sample fresh responses for the current batch of prompts from the latest policy (via vLLM),
2. we score those responses and compute advantages,
3. we take one or more gradient steps on them (configurable).
So our samples are on-policy, but you can configure the trainer to take more and more steps on the same responses, which becomes more 'offline'. The PPO clipping logic will help here, but yes, performance will usually drop the more 'offline' you go. One final caveat is that we usually run our code async, which means we overlap generation and training -- the model takes a train step while we generate the samples for the next step. So, technically, the samples are usually from the model one step prior. There is some past work showing this doesn't harm performance (https://arxiv.org/abs/2410.18252), potentially because being only 1 step off isn't that bad (vs. being many steps off). Hopefully that makes sense! The code is definitely quite complicated.
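A toy sketch of the async scheduling described above (the placeholder functions stand in for vLLM generation and the actual gradient update):

```python
import time
from concurrent.futures import ThreadPoolExecutor

def generate_responses(step: int) -> list[str]:
    # Placeholder for vLLM generation with the most recently pushed weights.
    time.sleep(0.1)
    return [f"response batch for step {step}"]

def train_step(batch: list[str]) -> None:
    # Placeholder for one GRPO/PPO gradient update on the sampled batch.
    time.sleep(0.1)

def async_rl_loop(num_steps: int = 5) -> None:
    """While the trainer updates on the batch generated for step t, the
    generator is already producing the batch for step t + 1, so the training
    data is always one policy version stale."""
    with ThreadPoolExecutor(max_workers=1) as generator:
        pending = generator.submit(generate_responses, 0)
        for step in range(num_steps):
            batch = pending.result()                          # samples from the previous policy version
            pending = generator.submit(generate_responses, step + 1)
            train_step(batch)                                 # overlaps with the generation above

if __name__ == "__main__":
    async_rl_loop()
```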
Hi. Thanks for your quick updates to the GRPO implementation!
When comparing grpo_vllm_thread_ray_gtrl.py with the pseudocode from the DeepSeek-Math paper, I found that the implemented GRPO trainer does not re-initialize the reference model for each epoch, which contradicts the pseudocode in the paper. (This also holds for the PPOTrainer in this repo.)

open_instruct/grpo_vllm_thread_ray_gtrl.py, lines 1177 to 1180 (at commit 4afebde)
So, I wonder whether you are also following the engineering trend discussed in huggingface/trl#1112. (Has anyone tried adding the reference-model update step and shown that it can be skipped in RL4LLM? I am just curious, since the re-initialization does not take much computation. Why not follow the pseudocode?)