diff --git a/code/CHANGELOG.md b/code/CHANGELOG.md
index 257d2bae..a3ab2d89 100644
--- a/code/CHANGELOG.md
+++ b/code/CHANGELOG.md
@@ -5,6 +5,7 @@ On release, entries get moved under a version heading.
 
 ## Unreleased
 
+- 2026-04-20: added SDPO (Self-Distillation Policy Optimization, Hübotter et al., 2026) as a hybrid GRPO + reverse-KL self-distillation loss in `policy_gradients/`. The highest-reward rollout of each group becomes a teacher demonstration once it clears `success_reward_threshold`; `distill_mask` excludes the successful rollout itself and no-success groups so the surrogate does not collapse to SFT. Adds `distillation_weight`, `success_reward_threshold`, and SDPO stabilization config fields plus `policy_gradients/configs/sdpo.yaml`. Full reference run pending.
 - 2026-04-16: [PR #374](https://github.com/natolambert/rlhf-book/pull/374) added CI ruff lint/format check for PRs touching `code/`, applied ruff format to all existing files, fixed lint errors (unused imports, unsorted imports, `zip()` without `strict=`), and documented linting in README.
 - 2026-04-15: [PR #372](https://github.com/natolambert/rlhf-book/pull/372) documented build-essential requirement for Ubuntu/Debian and uv version guidance in README install section.
 - 2026-04-15: [PR #370](https://github.com/natolambert/rlhf-book/pull/370) cleaned up CLAUDE.md for generic use, removed dead LoRA/QLoRA references from RM docstrings and base.py, moved ORPO/SimPO debug notes to direct_alignment/ORPO_SIMPO.md.
diff --git a/code/policy_gradients/README.md b/code/policy_gradients/README.md
index 21296998..2b71f856 100644
--- a/code/policy_gradients/README.md
+++ b/code/policy_gradients/README.md
@@ -18,6 +18,7 @@ See the parent [`code/README.md`](../README.md) for installation, configuration,
 | **GSPO** | `gspo.yaml` | Group-Sequence Policy Optimization ([Zheng et al., 2025](https://arxiv.org/abs/2505.13818)) |
 | **CISPO** | `cispo.yaml` | Clipped Importance Sampling PO ([MiniMax, 2025](https://arxiv.org/abs/2506.13585)) |
 | **SAPO** | `sapo.yaml` | Soft Adaptive Policy Optimization ([Gao et al., 2025](https://arxiv.org/abs/2511.20347)) |
+| **SDPO** | `sdpo.yaml` | Self-Distillation Policy Optimization ([Hübotter et al., 2026](https://arxiv.org/abs/2601.20802)) — hybrid GRPO + reverse-KL self-distillation |
 
 ## Reference Runs
 
@@ -31,6 +32,7 @@
 | **GSPO** | [run](https://wandb.ai/natolambert/rlhf-book/runs/10sxytli) | ✅ Validated |
 | **CISPO** | [run](https://wandb.ai/natolambert/rlhf-book/runs/6dg0m06n) | ✅ Validated |
 | **SAPO** | [run](https://wandb.ai/natolambert/rlhf-book/runs/79608nwk) | ✅ Validated |
+| **SDPO** | — | ⏳ Pending full validation |
 
 ## Quick Start
 
@@ -46,6 +48,9 @@ uv run python -m policy_gradients.train --config policy_gradients/configs/reinfo
 
 # PPO with value function
 uv run python -m policy_gradients.train --config policy_gradients/configs/ppo.yaml
+
+# SDPO hybrid GRPO + self-distillation
+uv run python -m policy_gradients.train --config policy_gradients/configs/sdpo.yaml
 ```
 
 ## TODOs for Community Contributions
diff --git a/code/policy_gradients/__init__.py b/code/policy_gradients/__init__.py
index ec6bd0d4..6f9dbc58 100644
--- a/code/policy_gradients/__init__.py
+++ b/code/policy_gradients/__init__.py
@@ -8,7 +8,7 @@
 from .buffer import Experience, ReplayBuffer
 from .config import Config, load_config
 
-from .loss import CISPOLoss, GRPOLoss, GSPOLoss, PPOLoss, ReinforceLoss
+from .loss import CISPOLoss, GRPOLoss, GSPOLoss, PPOLoss, ReinforceLoss, SAPOLoss, SDPOLoss
 
 __all__ = [
@@ -21,4 +21,6 @@
     "PPOLoss",
     "ReinforceLoss",
     "CISPOLoss",
+    "SAPOLoss",
+    "SDPOLoss",
 ]
diff --git a/code/policy_gradients/buffer.py b/code/policy_gradients/buffer.py
index bc4e2669..8f246b68 100644
--- a/code/policy_gradients/buffer.py
+++ b/code/policy_gradients/buffer.py
@@ -23,6 +23,11 @@ class Experience:
     - log_probs_old: Log probabilities from the rollout policy
     - log_probs_ref: Log probabilities from the reference policy (for KL)
     - values_old: Value estimates (for PPO)
+    - teacher_sequence_ids: SDPO teacher input = teacher_prompt + completion
+    - teacher_attention_mask: Attention mask for teacher_sequence_ids
+    - teacher_action_mask: Mask for completion tokens in the teacher sequence
+    - distill_mask: Per-sample {0, 1} flag; 0 when the sample is excluded from
+      the SDPO distillation term (no-success group, or the successful rollout itself)
     """
 
     sequence_ids: torch.Tensor
@@ -32,6 +37,10 @@
     log_probs_old: torch.Tensor | None = None
     log_probs_ref: torch.Tensor | None = None
    values_old: torch.Tensor | None = None
+    teacher_sequence_ids: torch.Tensor | None = None
+    teacher_attention_mask: torch.Tensor | None = None
+    teacher_action_mask: torch.Tensor | None = None
+    distill_mask: torch.Tensor | None = None
 
     def to(self, device: torch.device) -> Self:
         """Move all tensors to the specified device."""
diff --git a/code/policy_gradients/config.py b/code/policy_gradients/config.py
index 32f89540..dec0df69 100644
--- a/code/policy_gradients/config.py
+++ b/code/policy_gradients/config.py
@@ -30,10 +30,10 @@ class Config(BaseModel):
     Attributes:
         data: Dataset configuration
-        loss: Loss function (reinforce, rloo, ppo, grpo, drgrpo, gspo, cispo, sapo)
+        loss: Loss function (reinforce, rloo, ppo, grpo, drgrpo, gspo, cispo, sapo, sdpo)
         model_name: HuggingFace model identifier
 
-        # Clipping (GRPO, DrGRPO, GSPO, CISPO, PPO)
+        # Clipping (GRPO, DrGRPO, GSPO, CISPO, PPO, SDPO-hybrid)
         clip_eps_lo: Lower clipping bound for policy ratio
         clip_eps_hi: Upper clipping bound for policy ratio
 
@@ -41,6 +41,12 @@
         sapo_temp_pos: Sigmoid temperature for positive advantages
         sapo_temp_neg: Sigmoid temperature for negative advantages
 
+        # SDPO-specific
+        distillation_weight: Weight on the self-distillation term added to the GRPO loss
+        success_reward_threshold: Minimum reward for a rollout to act as the teacher demo
+        sdpo_teacher_ema_rate: EMA update rate for an optional SDPO teacher copy (0 disables)
+        sdpo_is_clip: Importance-sampling clip applied to the SDPO distillation term
+
         # PPO-specific
         clip_eps_val: Clipping bound for value function
         gamma: Discount factor for GAE
@@ -90,6 +96,12 @@
     sapo_temp_pos: float = 1.0
     sapo_temp_neg: float = 1.05
 
+    # SDPO-specific params (hybrid GRPO + self-distillation)
+    distillation_weight: float = 1.0
+    success_reward_threshold: float = 1.0
+    sdpo_teacher_ema_rate: float = 0.0
+    sdpo_is_clip: float | None = 2.0
+
     # KL penalty (optional, for REINFORCE/RLOO/GRPO when beta > 0)
     beta: float = 0.0
     ref_model_device_id: int = 0
@@ -124,6 +136,15 @@ def validate_rollout_batch_size(self) -> "Config":
             raise ValueError(
                 "prompts_per_step * num_rollouts must be divisible by rollout_batch_size."
             )
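+        # With num_rollouts == 1 the best rollout in a group is always the
+        # sample itself, so distill_mask would zero out every distillation term.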
+        if self.loss == "sdpo":
+            if self.num_rollouts < 2:
+                raise ValueError("SDPO requires at least 2 rollouts per prompt.")
+            if self.distillation_weight <= 0:
+                raise ValueError("SDPO requires distillation_weight > 0.")
+            if not 0.0 <= self.sdpo_teacher_ema_rate <= 1.0:
+                raise ValueError("sdpo_teacher_ema_rate must be in [0, 1].")
+            if self.sdpo_is_clip is not None and self.sdpo_is_clip <= 0:
+                raise ValueError("sdpo_is_clip must be positive when set.")
         return self
diff --git a/code/policy_gradients/configs/sdpo.yaml b/code/policy_gradients/configs/sdpo.yaml
new file mode 100644
index 00000000..10c9d612
--- /dev/null
+++ b/code/policy_gradients/configs/sdpo.yaml
@@ -0,0 +1,34 @@
+data:
+  size: 3000
+  specs:
+    - name: spell_backward
+      weight: 1
+      config:
+        min_word_len: 3
+        max_word_len: 10
+
+loss: sdpo
+model_name: Qwen/Qwen3-1.7B
+clip_eps_lo: 0.2
+clip_eps_hi: 0.2
+beta: 0.0  # KL penalty coefficient (0 = disabled, set >0 to enable)
+distillation_weight: 1.0
+success_reward_threshold: 1.0  # accuracy reward = 1.0 when the answer is correct
+sdpo_teacher_ema_rate: 0.0  # set to 0.01 for a more paper-like regularized teacher if memory allows
+sdpo_is_clip: 2.0
+lr: 5e-6
+temperature: 0.6
+top_p: 0.95
+top_k: 20
+min_p: 0.0
+max_new_tokens: 512
+prompts_per_step: 4
+num_rollouts: 8
+rollout_batch_size: 8
+train_batch_size: 2
+batch_acc: 4
+max_norm: 1.0
+seed: 42
+model_device_id: 0
+wandb_project: rlhf-book
+wandb_run_name: sdpo_spell_backwards
diff --git a/code/policy_gradients/loss.py b/code/policy_gradients/loss.py
index 9b4df83e..2ea84468 100644
--- a/code/policy_gradients/loss.py
+++ b/code/policy_gradients/loss.py
@@ -11,6 +11,7 @@
 # - GSPO (Zheng et al., 2025)
 # - CISPO (MiniMax, 2025)
 # - SAPO (Qwen Team, 2025)
+# - SDPO (Hübotter et al., 2026)
 
 import torch
 import torch.nn as nn
@@ -196,6 +197,63 @@ def forward(self, log_probs: torch.Tensor, experience: Experience, **kwargs) ->
     return loss
 
 
+class SDPOLoss(nn.Module):
+    """Self-Distillation Policy Optimization loss (Hübotter et al., 2026).
+
+    Hybrid objective: GRPO policy loss + reverse-KL self-distillation.
+    The teacher is the same policy model conditioned on a richer context that
+    includes the highest-reward rollout of the current group as a demonstration.
+    The distillation term uses the token-level REINFORCE surrogate for
+    reverse-KL(student || teacher):
+
+        log_ratio = (student_log_probs - teacher_log_probs).detach()
+        distill_loss = log_ratio * student_log_probs
+
+    The successful rollout itself is excluded from the distillation term via
+    `distill_mask`, as are all samples in groups with no successful rollout:
+    the teacher would otherwise see the target completion verbatim in its own
+    context, and the surrogate would collapse to SFT on that completion.
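+
+    Illustrative tensor shapes for the forward pass (a sketch of this
+    implementation's layout, not notation from the paper): with B batch rows
+    and padded completion length M,
+
+        student_comp_log_probs:  [B, M]
+        teacher_comp_log_probs:  [B, M]
+        comp_action_mask:        [B, M]
+        experience.distill_mask: [B, 1]
+
+    so `comp_action_mask * distill_mask` broadcasts to [B, M] and zeroes out
+    whole excluded rows rather than individual tokens.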
+ """ + + def __init__( + self, + clip_eps_lo: float, + clip_eps_hi: float, + beta: float, + distillation_weight: float, + sdpo_is_clip: float | None = None, + **kwargs, + ) -> None: + super().__init__() + self.grpo = GRPOLoss(clip_eps_lo=clip_eps_lo, clip_eps_hi=clip_eps_hi, beta=beta) + self.distillation_weight = distillation_weight + self.sdpo_is_clip = sdpo_is_clip + + def forward( + self, + log_probs: torch.Tensor, + experience: Experience, + student_comp_log_probs: torch.Tensor | None = None, + teacher_comp_log_probs: torch.Tensor | None = None, + comp_action_mask: torch.Tensor | None = None, + old_student_comp_log_probs: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor: + grpo_loss = self.grpo(log_probs=log_probs, experience=experience) + + log_ratio = (student_comp_log_probs - teacher_comp_log_probs).detach() + distill_token_loss = log_ratio * student_comp_log_probs + if self.sdpo_is_clip is not None and old_student_comp_log_probs is not None: + approx_log_ratio = (student_comp_log_probs - old_student_comp_log_probs).detach() + approx_log_ratio = approx_log_ratio.clamp(min=-20.0, max=20.0) + distill_token_loss = distill_token_loss * approx_log_ratio.exp().clamp(max=self.sdpo_is_clip) + # Broadcast distill_mask [B, 1] over completion tokens [B, M]. + effective_mask = comp_action_mask * experience.distill_mask + distill_loss = masked_mean(distill_token_loss, effective_mask, dim=-1).mean(dim=0) + + return grpo_loss + self.distillation_weight * distill_loss + + class PPOLoss(nn.Module): """Proximal Policy Optimization loss (Schulman et al., 2017). diff --git a/code/policy_gradients/train.py b/code/policy_gradients/train.py index a62527f0..9855e9a6 100644 --- a/code/policy_gradients/train.py +++ b/code/policy_gradients/train.py @@ -8,6 +8,7 @@ # - Added SDPA fallback for platforms without flash-attn (e.g., DGX Spark) import argparse +import copy import os import platform import random @@ -22,7 +23,6 @@ import torch.nn as nn import torch.nn.functional as F import torch.optim as optim -import wandb from reasoning_gym.composite import DatasetSpec from reasoning_gym.dataset import ProceduralDataset from reasoning_gym.utils import SYSTEM_PROMPTS, extract_answer @@ -31,6 +31,8 @@ from torch.utils.data import DataLoader from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig +import wandb + from .buffer import Experience, ReplayBuffer, join_experiences_batch from .config import Config, load_config from .loss import ( @@ -40,6 +42,7 @@ PPOLoss, ReinforceLoss, SAPOLoss, + SDPOLoss, approx_kl, masked_mean, ) @@ -104,6 +107,18 @@ def get_ref_model(model_name: str, device_map: Any, beta: float): return ref_model +def get_sdpo_teacher_model(model, loss: str, ema_rate: float): + """Create an optional EMA teacher for SDPO.""" + if loss != "sdpo" or ema_rate <= 0: + return None + teacher_model = copy.deepcopy(model) + if hasattr(teacher_model, "gradient_checkpointing_disable"): + teacher_model.gradient_checkpointing_disable() + teacher_model.requires_grad_(False) + teacher_model.eval() + return teacher_model + + def get_val_model(model_name: str, device_map: Any, loss: str, gradient_checkpointing: bool = True): """Load value model for PPO (only if loss == 'ppo').""" if loss not in ["ppo"]: @@ -127,6 +142,8 @@ def get_loss_objective(loss: str, **kwargs) -> nn.Module: return CISPOLoss(**kwargs) elif loss == "sapo": return SAPOLoss(**kwargs) + elif loss == "sdpo": + return SDPOLoss(**kwargs) elif loss == "ppo": return PPOLoss(**kwargs) raise ValueError(f"Unsupported loss 
type: {loss}") @@ -247,7 +264,7 @@ def compute_advantages( lam: float | None = None, ) -> torch.Tensor: """Compute advantages using the appropriate method for the loss function.""" - if loss in ["grpo", "gspo", "cispo", "sapo"]: + if loss in ["grpo", "gspo", "cispo", "sapo", "sdpo"]: return compute_standardized_advantages(rewards) elif loss in ["drgrpo"]: return compute_nonstandardized_advantages(rewards) @@ -274,6 +291,58 @@ def compute_log_probs( return target_log_probs +def extract_completion_ids( + sequence_ids: torch.Tensor, action_mask: torch.Tensor, pad_token_id: int +) -> tuple[torch.Tensor, torch.Tensor]: + """Extract generated completion token ids aligned by action_mask.""" + target_ids = sequence_ids[:, 1:] + completion_rows = [ids[mask] for ids, mask in zip(target_ids, action_mask, strict=True)] + max_len = max((row.size(0) for row in completion_rows), default=0) + completion_ids = torch.full( + (len(completion_rows), max_len), + fill_value=pad_token_id, + dtype=sequence_ids.dtype, + device=sequence_ids.device, + ) + completion_mask = torch.zeros( + (len(completion_rows), max_len), + dtype=torch.bool, + device=sequence_ids.device, + ) + for i, row in enumerate(completion_rows): + row_len = row.size(0) + if row_len == 0: + continue + completion_ids[i, :row_len] = row + completion_mask[i, :row_len] = True + return completion_ids, completion_mask + + +def extract_masked_log_probs( + log_probs: torch.Tensor, action_mask: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + """Extract per-completion log probs aligned by a token mask.""" + log_prob_rows = [row[mask] for row, mask in zip(log_probs, action_mask, strict=True)] + max_len = max((row.size(0) for row in log_prob_rows), default=0) + completion_log_probs = torch.zeros( + (len(log_prob_rows), max_len), + dtype=log_probs.dtype, + device=log_probs.device, + ) + completion_mask = torch.zeros( + (len(log_prob_rows), max_len), + dtype=log_probs.dtype, + device=log_probs.device, + ) + for i, row in enumerate(log_prob_rows): + row_len = row.size(0) + if row_len == 0: + continue + completion_log_probs[i, :row_len] = row + completion_mask[i, :row_len] = 1.0 + return completion_log_probs, completion_mask + + def compute_values(model, sequence_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: """Compute value estimates for each position (PPO).""" if not model: @@ -350,6 +419,83 @@ def rollout( return sequence_ids, action_mask, attention_mask, rewards, completions +def build_teacher_inputs( + tokenizer: AutoTokenizer, + group_entries: list[dict], + completions: list[str], + rewards: torch.Tensor, + sequence_ids: torch.Tensor, + action_mask: torch.Tensor, + attention_mask: torch.Tensor, + success_reward_threshold: float, + device: torch.device, +): + """Build SDPO teacher inputs for one rollout group (shared prompt, K rollouts). + + Picks the highest-reward rollout; if its reward >= success_reward_threshold, + constructs a teacher prompt that embeds that rollout as a demonstration and + keeps its completion_ids as the target. Otherwise, returns the student's + own tensors as teacher inputs with a zero distill_mask so the distillation + term drops out for the group. The successful rollout itself is also masked + out (i==j would collapse to SFT). 
+ """ + k = len(group_entries) + rewards_list = rewards.squeeze(-1).tolist() + best_idx = int(max(range(k), key=lambda i: rewards_list[i])) + best_reward = rewards_list[best_idx] + pad_token_id = ( + tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id + ) + + if best_reward < success_reward_threshold: + distill_mask = torch.zeros(k, 1, dtype=torch.float32, device=device) + return ( + sequence_ids.to(device), + attention_mask.to(device), + action_mask.to(device), + distill_mask, + ) + + teacher_user = ( + f"{group_entries[0]['question']}\n\n" + f"Correct solution:\n{completions[best_idx]}\n\n" + f"Correctly solve the original question." + ) + teacher_template = tokenizer.apply_chat_template( + [ + {"role": "system", "content": SYSTEM_PROMPTS["DeepSeekZero"]}, + {"role": "user", "content": teacher_user}, + ], + tokenize=False, + add_generation_prompt=True, + enable_thinking=True, + ) + teacher_inputs = tokenizer( + [teacher_template] * k, + return_tensors="pt", + padding=True, + padding_side="left", + return_attention_mask=True, + ).to(device) + teacher_prompt_ids = teacher_inputs["input_ids"] + teacher_prompt_mask = teacher_inputs["attention_mask"].bool() + completion_ids, completion_mask = extract_completion_ids( + sequence_ids=sequence_ids.to(device), + action_mask=action_mask.to(device), + pad_token_id=pad_token_id, + ) + teacher_sequence_ids = torch.cat([teacher_prompt_ids, completion_ids], dim=1) + teacher_attention_mask = torch.cat([teacher_prompt_mask, completion_mask], dim=1) + + teacher_token_mask = torch.zeros_like(teacher_sequence_ids, dtype=torch.bool) + teacher_token_mask[:, teacher_prompt_ids.shape[1] :] = completion_mask + teacher_action_mask = teacher_token_mask[:, 1:] + + distill_mask = torch.ones(k, 1, dtype=torch.float32, device=device) + distill_mask[best_idx] = 0.0 + return teacher_sequence_ids, teacher_attention_mask, teacher_action_mask, distill_mask + + def create_dataset(cfg: Config) -> ProceduralDataset: """Create the training dataset from config.""" specs = [DatasetSpec(name=s.name, weight=s.weight, config=s.config) for s in cfg.data.specs] @@ -384,6 +530,9 @@ def main(cfg: Config): ) model, tokenizer = load_model(cfg.model_name, model_device, gradient_checkpointing=True) ref_model = get_ref_model(cfg.model_name, ref_model_device, cfg.beta) + sdpo_teacher_model = get_sdpo_teacher_model( + model=model, loss=cfg.loss, ema_rate=cfg.sdpo_teacher_ema_rate + ) val_model = get_val_model( cfg.model_name, val_model_device, cfg.loss, gradient_checkpointing=True ) @@ -396,6 +545,8 @@ def main(cfg: Config): beta=cfg.beta, sapo_temp_pos=cfg.sapo_temp_pos, sapo_temp_neg=cfg.sapo_temp_neg, + distillation_weight=cfg.distillation_weight, + sdpo_is_clip=cfg.sdpo_is_clip, ).to(model.device) params = list(model.parameters()) + (list(val_model.parameters()) if val_model else []) optimizer = optim.Adam(params, lr=cfg.lr) @@ -457,6 +608,29 @@ def main(cfg: Config): rewards, cfg.loss, action_mask, values_old, cfg.gamma, cfg.lam ) + if cfg.loss == "sdpo": + ( + teacher_sequence_ids, + teacher_attention_mask, + teacher_action_mask, + distill_mask, + ) = build_teacher_inputs( + tokenizer=tokenizer, + group_entries=batch, + completions=completions, + rewards=rewards, + sequence_ids=sequence_ids, + action_mask=action_mask, + attention_mask=attention_mask, + success_reward_threshold=cfg.success_reward_threshold, + device=model.device, + ) + else: + teacher_sequence_ids = None + teacher_attention_mask = None + teacher_action_mask = None + distill_mask = None 
+
             experience = Experience(
                 sequence_ids=sequence_ids,
                 attention_mask=attention_mask,
@@ -465,6 +639,10 @@
                 log_probs_old=log_probs_old,
                 log_probs_ref=log_probs_ref,
                 values_old=values_old,
+                teacher_sequence_ids=teacher_sequence_ids,
+                teacher_attention_mask=teacher_attention_mask,
+                teacher_action_mask=teacher_action_mask,
+                distill_mask=distill_mask,
             ).to(cpu_device)
             replay_buffer.add(experience)
 
@@ -507,7 +685,39 @@
                    values = compute_values(
                        val_model, experience.sequence_ids, experience.attention_mask
                    )
-                    loss = objective(log_probs=log_probs, experience=experience, values=values)
+                    extra_kwargs: dict[str, Any] = {}
+                    if cfg.loss == "sdpo":
+                        # Teacher output is always detached inside the loss (the
+                        # reverse-KL surrogate puts the gradient on the student).
+                        # Skip autograd graph construction for the teacher pass
+                        # to keep its activations out of memory.
+                        with torch.no_grad():
+                            teacher_log_probs = compute_log_probs(
+                                sdpo_teacher_model if sdpo_teacher_model is not None else model,
+                                experience.teacher_sequence_ids,
+                                experience.teacher_attention_mask,
+                            )
+                        student_comp_log_probs, comp_action_mask = extract_masked_log_probs(
+                            log_probs, experience.action_mask
+                        )
+                        teacher_comp_log_probs, teacher_comp_mask = extract_masked_log_probs(
+                            teacher_log_probs, experience.teacher_action_mask
+                        )
+                        old_student_comp_log_probs, old_comp_mask = extract_masked_log_probs(
+                            experience.log_probs_old, experience.action_mask
+                        )
+                        if not torch.equal(comp_action_mask.bool(), teacher_comp_mask.bool()):
+                            raise RuntimeError("SDPO student/teacher completion masks are misaligned.")
+                        if not torch.equal(comp_action_mask.bool(), old_comp_mask.bool()):
+                            raise RuntimeError("SDPO old/current completion masks are misaligned.")
+                        extra_kwargs = {
+                            "student_comp_log_probs": student_comp_log_probs,
+                            "teacher_comp_log_probs": teacher_comp_log_probs,
+                            "comp_action_mask": comp_action_mask,
+                            "old_student_comp_log_probs": old_student_comp_log_probs,
+                        }
+                    loss = objective(
+                        log_probs=log_probs, experience=experience, values=values, **extra_kwargs
+                    )
                    if not loss.isfinite():
                        continue
                    scaled_loss = loss / cfg.batch_acc
@@ -520,6 +730,12 @@
                ):
                    grad_norm = clip_grad_norm_(params, max_norm=cfg.max_norm)
                    optimizer.step()
+                    if sdpo_teacher_model is not None:
+                        with torch.no_grad():
+                            for teacher_param, student_param in zip(
+                                sdpo_teacher_model.parameters(), model.parameters(), strict=True
+                            ):
+                                teacher_param.lerp_(student_param.detach(), cfg.sdpo_teacher_ema_rate)
                    optimizer.zero_grad(set_to_none=True)
                    torch.cuda.empty_cache()
diff --git a/code/scripts/run_all_policy_gradients.sh b/code/scripts/run_all_policy_gradients.sh
index c1041c5b..31d4bb67 100755
--- a/code/scripts/run_all_policy_gradients.sh
+++ b/code/scripts/run_all_policy_gradients.sh
@@ -47,6 +47,12 @@ echo "Running SAPO..."
 echo "=========================================="
 uv run python -m policy_gradients.train --config policy_gradients/configs/sapo.yaml
 
+# SDPO
+echo "=========================================="
+echo "Running SDPO..."
+echo "=========================================="
+uv run python -m policy_gradients.train --config policy_gradients/configs/sdpo.yaml
+
 echo ""
 echo "=========================================="
 echo "All runs complete!"