Add SDPO to policy gradients #387
Conversation
Probably a good idea to first merge #385 since it decouples the rollouts and the policy updates - a breaking change.
🫨🫨🤨🤨
#385 has been merged and introduced quite a large refactoring. I know #360 suggests incorporating SDPO under `policy_gradients/` - and while technically feasible - it will hurt readability. If we want to compare the performance of distillation vs. RL, we should make an effort to sync the same metrics in the same WandB project for a 1-to-1 comparison. Apart from that, I think it would be beneficial to have a separate module in the … For example, trl also implements SDPO under a … With that said, I think we can take a look at the new layout of the …
I might be missing something else, but I think this covers most of my points.
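To make that comparison easy, something along these lines would keep the runs aligned (a rough sketch; the project, group, and metric names are placeholders, not anything that exists in the repo):

```python
# Rough sketch: log the distillation and RL runs to the same WandB project
# with identical metric keys so the charts overlay 1-to-1.
# Project/group/metric names below are placeholders.
import wandb


def log_step(run, step, avg_reward, loss, grad_norm):
    # Same keys for every method so panels line up across runs.
    run.log({"avg_reward": avg_reward, "loss": loss, "grad_norm": grad_norm}, step=step)


run = wandb.init(project="distill-vs-rl", group="sdpo", name="sdpo-basic_arithmetic")
log_step(run, step=0, avg_reward=0.0, loss=0.0, grad_norm=0.0)
run.finish()
```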

Summary
Refs #360.
This PR adds SDPO (Self-Distillation Policy Optimization) to
`code/policy_gradients/`, following the issue discussion that SDPO fits the online policy-gradient training loop better than `direct_alignment/`.

Changes include:
- `SDPOLoss`, implemented as GRPO plus token-level reverse-KL self-distillation (a rough sketch of the idea follows this list).
- Per-token terms respect `action_mask`, so variable-length generations do not accidentally include prompt tokens.
- New configuration options `sdpo_teacher_ema_rate` and `sdpo_is_clip`.
- `policy_gradients/configs/sdpo.yaml`, README/changelog updates, and inclusion in `run_all_policy_gradients.sh`.
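Below is a minimal sketch of what "GRPO plus token-level reverse-KL self-distillation" means here, not the PR's actual implementation: a clipped GRPO surrogate plus a per-token reverse-KL term toward an EMA teacher, both masked by `action_mask`. The `distill_coef` name and the exact semantics of `sdpo_is_clip` (a cap on the importance weights of the distillation term) and `sdpo_teacher_ema_rate` (how slowly the teacher tracks the student) are assumptions.

```python
# Minimal sketch (not the PR's exact code): GRPO clipped surrogate plus a
# token-level reverse-KL self-distillation term against an EMA teacher.
# All per-token terms are masked by action_mask so prompt/padding tokens
# never contribute to the loss.
import torch


def sdpo_loss(
    policy_logprobs: torch.Tensor,   # (B, T) log-probs of sampled tokens under the current policy
    old_logprobs: torch.Tensor,      # (B, T) log-probs under the behavior policy that produced the rollouts
    teacher_logprobs: torch.Tensor,  # (B, T) log-probs under the EMA teacher (no grad)
    advantages: torch.Tensor,        # (B,) group-normalized advantages, GRPO-style
    action_mask: torch.Tensor,       # (B, T) 1 for generated tokens, 0 for prompt/padding
    clip_eps: float = 0.2,           # PPO/GRPO ratio clip
    distill_coef: float = 0.1,       # weight of the self-distillation term (assumed knob)
    sdpo_is_clip: float = 5.0,       # cap on importance weights in the distillation term (assumed semantics)
) -> torch.Tensor:
    mask = action_mask.float()
    denom = mask.sum().clamp(min=1.0)

    # GRPO: clipped importance-ratio surrogate with a per-sequence advantage
    ratio = torch.exp(policy_logprobs - old_logprobs)
    adv = advantages.unsqueeze(-1)  # broadcast the sequence advantage over tokens
    surrogate = torch.minimum(ratio * adv, ratio.clamp(1 - clip_eps, 1 + clip_eps) * adv)
    pg_loss = -(surrogate * mask).sum() / denom

    # Token-level reverse KL to the teacher, estimated on the sampled tokens and
    # importance-corrected (with clipping) because the tokens came from the old policy.
    is_weight = ratio.detach().clamp(max=sdpo_is_clip)
    reverse_kl = policy_logprobs - teacher_logprobs.detach()
    distill_loss = (is_weight * reverse_kl * mask).sum() / denom

    return pg_loss + distill_coef * distill_loss


@torch.no_grad()
def update_teacher(teacher: torch.nn.Module, student: torch.nn.Module,
                   sdpo_teacher_ema_rate: float = 0.99) -> None:
    """EMA update of the teacher toward the student, run once per optimizer step."""
    for t_param, s_param in zip(teacher.parameters(), student.parameters()):
        t_param.lerp_(s_param, 1.0 - sdpo_teacher_ema_rate)
```

With `sdpo_teacher_ema_rate` close to 1 the teacher moves slowly, which keeps the distillation target stable while the policy updates.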
Validation

Local checks passed:
- `.venv/bin/python -m compileall policy_gradients` from `code/`
- `uvx ruff check code/policy_gradients`
- `policy_gradients/configs/sdpo.yaml` loads successfully with `policy_gradients.config.load_config`
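The last check was roughly the following (a minimal sketch; the exact `load_config` signature is an assumption and it may take extra arguments):

```python
# Minimal sketch of the config sanity check; assumes load_config accepts a
# path to the YAML file and returns the parsed config object.
from policy_gradients.config import load_config

cfg = load_config("policy_gradients/configs/sdpo.yaml")
print(cfg)  # inspect the parsed SDPO settings
```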
I also ran reduced local wandb probes in `sdpo-test` because I do not have enough compute for the full reference run:

- `spell_backward` smoke test finished, but SDPO stayed inactive because the model did not produce successful rollouts: https://wandb.ai/bokicasheks-loka/sdpo-test/runs/ijgtli7k
- `basic_arithmetic` reduced validation finished with nonzero reward/loss/grad norm, exercising the SDPO path: https://wandb.ai/bokicasheks-loka/sdpo-test/runs/5s619wr7
Final reduced-run summary: `avg_reward=0.7708`, `loss=-0.2717`, `grad_norm=3.9375`.

Maintainer Request
@natolambert, could you run the full `policy_gradients/configs/sdpo.yaml` reference job on proper compute and decide which official wandb run should go into the README table? I left the README status as pending full validation rather than treating my reduced local run as the canonical result.