1 change: 1 addition & 0 deletions code/CHANGELOG.md
@@ -5,6 +5,7 @@ On release, entries get moved under a version heading.

## Unreleased

- 2026-04-20: added SDPO (Self-Distillation Policy Optimization, Hübotter et al., 2026) as a hybrid GRPO + reverse-KL self-distillation loss in `policy_gradients/`. The highest-reward successful rollout in each group becomes the teacher demonstration; `distill_mask` excludes that rollout itself, and all samples from groups with no success, so the surrogate does not collapse to SFT. Adds `distillation_weight`, `success_reward_threshold`, and SDPO stabilization config fields, plus `policy_gradients/configs/sdpo.yaml`. Full reference run pending.
- 2026-04-16: [PR #374](https://github.com/natolambert/rlhf-book/pull/374) added CI ruff lint/format check for PRs touching `code/`, applied ruff format to all existing files, fixed lint errors (unused imports, unsorted imports, `zip()` without `strict=`), and documented linting in README.
- 2026-04-15: [PR #372](https://github.com/natolambert/rlhf-book/pull/372) documented build-essential requirement for Ubuntu/Debian and uv version guidance in README install section.
- 2026-04-15: [PR #370](https://github.com/natolambert/rlhf-book/pull/370) cleaned up CLAUDE.md for generic use, removed dead LoRA/QLoRA references from RM docstrings and base.py, moved ORPO/SimPO debug notes to direct_alignment/ORPO_SIMPO.md.
5 changes: 5 additions & 0 deletions code/policy_gradients/README.md
@@ -18,6 +18,7 @@ See the parent [`code/README.md`](../README.md) for installation, configuration,
| **GSPO** | `gspo.yaml` | Group-Sequence Policy Optimization ([Zheng et al., 2025](https://arxiv.org/abs/2505.13818)) |
| **CISPO** | `cispo.yaml` | Clipped Importance Sampling PO ([MiniMax, 2025](https://arxiv.org/abs/2506.13585)) |
| **SAPO** | `sapo.yaml` | Soft Adaptive Policy Optimization ([Gao et al., 2025](https://arxiv.org/abs/2511.20347)) |
| **SDPO** | `sdpo.yaml` | Self-Distillation Policy Optimization ([Hübotter et al., 2026](https://arxiv.org/abs/2601.20802)) — hybrid GRPO + reverse-KL self-distillation |

## Reference Runs

@@ -31,6 +32,7 @@ See the parent [`code/README.md`](../README.md) for installation, configuration,
| **GSPO** | [run](https://wandb.ai/natolambert/rlhf-book/runs/10sxytli) | ✅ Validated |
| **CISPO** | [run](https://wandb.ai/natolambert/rlhf-book/runs/6dg0m06n) | ✅ Validated |
| **SAPO** | [run](https://wandb.ai/natolambert/rlhf-book/runs/79608nwk) | ✅ Validated |
| **SDPO** | — | ⏳ Pending full validation |

## Quick Start

@@ -46,6 +48,9 @@ uv run python -m policy_gradients.train --config policy_gradients/configs/reinfo

# PPO with value function
uv run python -m policy_gradients.train --config policy_gradients/configs/ppo.yaml

# SDPO hybrid GRPO + self-distillation
uv run python -m policy_gradients.train --config policy_gradients/configs/sdpo.yaml
```

## TODOs for Community Contributions
4 changes: 3 additions & 1 deletion code/policy_gradients/__init__.py
@@ -8,7 +8,7 @@

from .buffer import Experience, ReplayBuffer
from .config import Config, load_config
from .loss import CISPOLoss, GRPOLoss, GSPOLoss, PPOLoss, ReinforceLoss
from .loss import CISPOLoss, GRPOLoss, GSPOLoss, PPOLoss, ReinforceLoss, SAPOLoss, SDPOLoss


__all__ = [
@@ -21,4 +21,6 @@
"PPOLoss",
"ReinforceLoss",
"CISPOLoss",
"SAPOLoss",
"SDPOLoss",
]
9 changes: 9 additions & 0 deletions code/policy_gradients/buffer.py
@@ -23,6 +23,11 @@ class Experience:
- log_probs_old: Log probabilities from the rollout policy
- log_probs_ref: Log probabilities from the reference policy (for KL)
- values_old: Value estimates (for PPO)
- teacher_sequence_ids: SDPO teacher input = teacher_prompt + completion
- teacher_attention_mask: Attention mask for teacher_sequence_ids
- teacher_action_mask: Mask for completion tokens in teacher sequence
- distill_mask: Per-sample {0,1} mask; 0 when the sample is excluded from the SDPO
  distillation term (no-success group, or the successful rollout itself)
"""

sequence_ids: torch.Tensor
@@ -32,6 +37,10 @@
log_probs_old: torch.Tensor | None = None
log_probs_ref: torch.Tensor | None = None
values_old: torch.Tensor | None = None
teacher_sequence_ids: torch.Tensor | None = None
teacher_attention_mask: torch.Tensor | None = None
teacher_action_mask: torch.Tensor | None = None
distill_mask: torch.Tensor | None = None

def to(self, device: torch.device) -> Self:
"""Move all tensors to the specified device."""
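For orientation, a minimal sketch of how the new teacher fields might be assembled for one sample; the helper name `build_teacher_inputs` and the tokenizer handling are illustrative assumptions, not code from this diff:

```python
import torch


def build_teacher_inputs(tokenizer, teacher_prompt: str, completion_ids: torch.Tensor):
    # Teacher input = teacher_prompt (with the demonstration) + the completion.
    prompt_ids = tokenizer(teacher_prompt, return_tensors="pt").input_ids[0]
    teacher_sequence_ids = torch.cat([prompt_ids, completion_ids])
    teacher_attention_mask = torch.ones_like(teacher_sequence_ids)
    # Only completion tokens count as actions for the distillation term.
    teacher_action_mask = torch.cat(
        [torch.zeros_like(prompt_ids), torch.ones_like(completion_ids)]
    )
    return teacher_sequence_ids, teacher_attention_mask, teacher_action_mask
```

`distill_mask` is then 1 for ordinary samples and 0 for both the demonstration rollout itself and every sample in a group whose best reward stays below `success_reward_threshold`.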
25 changes: 23 additions & 2 deletions code/policy_gradients/config.py
@@ -30,17 +30,23 @@ class Config(BaseModel):

Attributes:
data: Dataset configuration
loss: Loss function (reinforce, rloo, ppo, grpo, drgrpo, gspo, cispo, sapo)
loss: Loss function (reinforce, rloo, ppo, grpo, drgrpo, gspo, cispo, sapo, sdpo)
model_name: HuggingFace model identifier

# Clipping (GRPO, DrGRPO, GSPO, CISPO, PPO)
# Clipping (GRPO, DrGRPO, GSPO, CISPO, PPO, SDPO-hybrid)
clip_eps_lo: Lower clipping bound for policy ratio
clip_eps_hi: Upper clipping bound for policy ratio

# SAPO-specific
sapo_temp_pos: Sigmoid temperature for positive advantages
sapo_temp_neg: Sigmoid temperature for negative advantages

# SDPO-specific
distillation_weight: Weight on the self-distillation term added to GRPO loss
success_reward_threshold: Minimum reward for a rollout to act as the teacher demo
sdpo_teacher_ema_rate: EMA update rate for an optional SDPO teacher copy (0 disables)
sdpo_is_clip: Importance-sampling clip applied to the SDPO distillation term

# PPO-specific
clip_eps_val: Clipping bound for value function
gamma: Discount factor for GAE
@@ -90,6 +96,12 @@ class Config(BaseModel):
sapo_temp_pos: float = 1.0
sapo_temp_neg: float = 1.05

# SDPO-specific params (hybrid GRPO + self-distillation)
distillation_weight: float = 1.0
success_reward_threshold: float = 1.0
sdpo_teacher_ema_rate: float = 0.0
sdpo_is_clip: float | None = 2.0

# KL penalty (optional, for REINFORCE/RLOO/GRPO when beta > 0)
beta: float = 0.0
ref_model_device_id: int = 0
@@ -124,6 +136,15 @@ def validate_rollout_batch_size(self) -> "Config":
raise ValueError(
"prompts_per_step * num_rollouts must be divisible by rollout_batch_size."
)
if self.loss == "sdpo":
if self.num_rollouts < 2:
raise ValueError("SDPO requires at least 2 rollouts per prompt.")
if self.distillation_weight <= 0:
raise ValueError("SDPO requires distillation_weight > 0.")
if not 0.0 <= self.sdpo_teacher_ema_rate <= 1.0:
raise ValueError("sdpo_teacher_ema_rate must be in [0, 1].")
if self.sdpo_is_clip is not None and self.sdpo_is_clip <= 0:
raise ValueError("sdpo_is_clip must be positive when set.")
return self


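For a quick illustration of the new validation branch, a sketch assuming `load_config` takes the YAML path (the failure cases paraphrase the messages raised above):

```python
from policy_gradients import load_config

cfg = load_config("policy_gradients/configs/sdpo.yaml")
assert cfg.loss == "sdpo" and cfg.num_rollouts >= 2

# The validator rejects degenerate SDPO settings, for example:
#   num_rollouts=1            -> "SDPO requires at least 2 rollouts per prompt."
#   distillation_weight=0.0   -> "SDPO requires distillation_weight > 0."
#   sdpo_teacher_ema_rate=1.5 -> "sdpo_teacher_ema_rate must be in [0, 1]."
#   sdpo_is_clip=-1.0         -> "sdpo_is_clip must be positive when set."
```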
34 changes: 34 additions & 0 deletions code/policy_gradients/configs/sdpo.yaml
@@ -0,0 +1,34 @@
data:
size: 3000
specs:
- name: spell_backward
weight: 1
config:
min_word_len: 3
max_word_len: 10

loss: sdpo
model_name: Qwen/Qwen3-1.7B
clip_eps_lo: 0.2
clip_eps_hi: 0.2
beta: 0.0 # KL penalty coefficient (0 = disabled, set >0 to enable)
distillation_weight: 1.0
success_reward_threshold: 1.0 # accuracy reward = 1.0 when answer is correct
sdpo_teacher_ema_rate: 0.0 # set to 0.01 for a more paper-like regularized teacher if memory allows
sdpo_is_clip: 2.0
lr: 5e-6
temperature: 0.6
top_p: 0.95
top_k: 20
min_p: 0.0
max_new_tokens: 512
prompts_per_step: 4
num_rollouts: 8
rollout_batch_size: 8
train_batch_size: 2
batch_acc: 4
max_norm: 1.0
seed: 42
model_device_id: 0
wandb_project: rlhf-book
wandb_run_name: sdpo_spell_backwards
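For reference, with these settings each step samples prompts_per_step × num_rollouts = 4 × 8 = 32 rollouts, which satisfies the divisibility check against rollout_batch_size = 8 (four generation batches), and, assuming `batch_acc` is gradient accumulation, the optimizer sees an effective batch of train_batch_size × batch_acc = 2 × 4 = 8 samples per update.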
58 changes: 58 additions & 0 deletions code/policy_gradients/loss.py
@@ -11,6 +11,7 @@
# - GSPO (Zheng et al., 2025)
# - CISPO (MiniMax, 2025)
# - SAPO (Qwen Team, 2025)
# - SDPO (Hübotter et al., 2026)

import torch
import torch.nn as nn
@@ -196,6 +197,63 @@ def forward(self, log_probs: torch.Tensor, experience: Experience, **kwargs) ->
return loss


class SDPOLoss(nn.Module):
"""Self-Distillation Policy Optimization loss (Hübotter et al., 2026).

Hybrid objective: GRPO policy loss + reverse-KL self-distillation.
The teacher is the same policy model conditioned on a richer context that
includes the highest-reward rollout of the current group as a demonstration.
The distillation term uses the token-level REINFORCE surrogate for
reverse-KL(student || teacher):

log_ratio = (student_log_probs - teacher_log_probs).detach()
distill_loss = log_ratio * student_log_probs

The successful rollout itself is excluded from the distillation term via
`distill_mask`, as are all samples in groups with no successful rollout, since
the teacher would otherwise see the target completion verbatim in its own
context and the surrogate would collapse to SFT on that completion.
"""

def __init__(
self,
clip_eps_lo: float,
clip_eps_hi: float,
beta: float,
distillation_weight: float,
sdpo_is_clip: float | None = None,
**kwargs,
) -> None:
super().__init__()
self.grpo = GRPOLoss(clip_eps_lo=clip_eps_lo, clip_eps_hi=clip_eps_hi, beta=beta)
self.distillation_weight = distillation_weight
self.sdpo_is_clip = sdpo_is_clip

def forward(
self,
log_probs: torch.Tensor,
experience: Experience,
student_comp_log_probs: torch.Tensor | None = None,
teacher_comp_log_probs: torch.Tensor | None = None,
comp_action_mask: torch.Tensor | None = None,
old_student_comp_log_probs: torch.Tensor | None = None,
**kwargs,
) -> torch.Tensor:
grpo_loss = self.grpo(log_probs=log_probs, experience=experience)

# These kwargs are optional in the signature but required for the distillation
# term; fall back to pure GRPO when no teacher pass was provided.
if student_comp_log_probs is None or teacher_comp_log_probs is None or comp_action_mask is None:
    return grpo_loss

log_ratio = (student_comp_log_probs - teacher_comp_log_probs).detach()
distill_token_loss = log_ratio * student_comp_log_probs
if self.sdpo_is_clip is not None and old_student_comp_log_probs is not None:
approx_log_ratio = (student_comp_log_probs - old_student_comp_log_probs).detach()
approx_log_ratio = approx_log_ratio.clamp(min=-20.0, max=20.0)
distill_token_loss = distill_token_loss * approx_log_ratio.exp().clamp(max=self.sdpo_is_clip)
# Broadcast distill_mask [B, 1] over completion tokens [B, M].
effective_mask = comp_action_mask * experience.distill_mask
distill_loss = masked_mean(distill_token_loss, effective_mask, dim=-1).mean(dim=0)

return grpo_loss + self.distillation_weight * distill_loss


class PPOLoss(nn.Module):
"""Proximal Policy Optimization loss (Schulman et al., 2017).

Expand Down
Loading