Commit 9d5c276

[Feature] DAPO
ghstack-source-id: 936b0ee Pull-Request: #3206
1 parent 2bc3cb7 commit 9d5c276

2 files changed (+95, -11 lines)

test/llm/test_objectives.py (10 additions, 6 deletions)
@@ -86,9 +86,7 @@ def make_silly_trajectory(n_steps=None):
 # Mock infrastructure moved to conftest.py


-def _mock_data_grpo(
-    vocab_size: int, device: torch.device | str = "cpu"
-) -> TensorDict:
+def _mock_data_grpo(vocab_size: int, device: torch.device | str = "cpu") -> TensorDict:
     from transformers import AutoTokenizer

     device = torch.device(device)
@@ -175,11 +173,17 @@ def _mock_data_grpo(


 class TestLosses:
-    def test_grpo(self, mock_transformer_model):
+    @pytest.mark.parametrize("dapo", [True, False], ids=["dapo", "symmetric"])
+    def test_grpo(self, mock_transformer_model, dapo):
         """Test GRPO loss computation with mock models."""
         vocab_size = 1024
         device = torch.device("cpu")
-
+        if dapo:
+            eps_low = 0.20
+            eps_high = 0.28
+            eps = (eps_low, eps_high)
+        else:
+            eps = 0.20
         # Create mock model and wrap it
         model = mock_transformer_model(vocab_size=vocab_size, device=device)
         actor_network = TransformersWrapper(
@@ -190,7 +194,7 @@ def test_grpo(self, mock_transformer_model):
         )

         # Create loss module
-        loss_fn = GRPOLoss(actor_network)
+        loss_fn = GRPOLoss(actor_network, eps=eps)

         # Create fake data
         data = _mock_data_grpo(vocab_size=vocab_size, device=device)
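To see what the dapo/symmetric parametrization changes downstream, here is a small plain-Python sketch (not part of the test suite; the helper name and values are illustrative) of the clipped surrogate for one token with importance ratio 1.5 and positive advantage:

# Illustrative only: how a single token's gain differs under symmetric vs Clip-Higher bounds.
def clipped_gain(ratio, adv, eps_low, eps_high):
    clipped = min(max(ratio, 1.0 - eps_low), 1.0 + eps_high)
    return min(ratio * adv, clipped * adv)

print(clipped_gain(1.5, 1.0, 0.20, 0.20))  # symmetric eps=0.20 -> 1.2
print(clipped_gain(1.5, 1.0, 0.20, 0.28))  # dapo (0.20, 0.28)  -> 1.28

Raising only the upper threshold lets high-ratio, positive-advantage tokens keep contributing to the gradient, which is the Clip-Higher motivation in DAPO.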

torchrl/objectives/llm/grpo.py (85 additions, 5 deletions)
@@ -78,8 +78,10 @@ class GRPOLoss(LossModule):
     The masking strategy must match the strategy used for advantage computation to avoid shape mismatches.

     Keyword Args:
-        clip_epsilon (scalar, optional): weight clipping threshold in the clipped PPO loss equation.
-            default: 0.2
+        clip_epsilon (float | tuple[float, float], optional): clipping threshold(s) for the clipped surrogate.
+            - float x: symmetric clipping [1 - x, 1 + x] (default: 0.2)
+            - tuple (eps_low, eps_high): asymmetric clipping [1 - eps_low, 1 + eps_high] as in DAPO Clip-Higher
+              recommended defaults from DAPO: (0.20, 0.28); see Eq. (10) in the paper.
         entropy_bonus (bool, optional): if ``True``, an entropy bonus will be added to the
             loss to favour exploratory policies.
         samples_mc_entropy (int, optional): if the distribution retrieved from the policy
@@ -115,6 +117,9 @@ class GRPOLoss(LossModule):

     .. note:: Parameters and buffers from the policy / critic will not be cast to that device to ensure that
         the storages match the ones that are passed to other components, such as data collectors.
+
+    .. note:: For non-symmetric clipping thresholds, see the `DAPO <https://arxiv.org/html/2503.14476>`_ paper.
+
     """

     actor_network: LLMWrapperBase
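For reference, a hedged restatement of the clipped surrogate these docstring notes describe (per-token form with asymmetric bounds; the notation is mine and the group/length normalization of Eq. (10) in the DAPO paper is omitted):

\mathcal{J}_{\mathrm{clip}}(\theta)
= \mathbb{E}\Big[\min\Big(r_t(\theta)\,\hat{A}_t,\;
\mathrm{clip}\big(r_t(\theta),\,1-\varepsilon_{\mathrm{low}},\,1+\varepsilon_{\mathrm{high}}\big)\,\hat{A}_t\Big)\Big],
\qquad
r_t(\theta)=\frac{\pi_\theta(o_t\mid q,\,o_{<t})}{\pi_{\theta_{\mathrm{old}}}(o_t\mid q,\,o_{<t})}.

The symmetric case is recovered with eps_low = eps_high.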
@@ -136,7 +141,7 @@ def __init__(
         self,
         actor_network: LLMWrapperBase | None = None,
         *,
-        clip_epsilon: float = 0.2,
+        clip_epsilon: float | tuple[float, float] = 0.2,
         entropy_bonus: bool = True,
         samples_mc_entropy: int = 1,
         entropy_coeff: float = 0.01,
@@ -165,7 +170,28 @@ def __init__(
             device = getattr(
                 torch, "get_default_device", lambda: torch.device("cpu")
             )()
-        self.register_buffer("clip_epsilon", torch.tensor(clip_epsilon, device=device))
+        # Accept symmetric or asymmetric thresholds
+        if isinstance(clip_epsilon, (tuple, list)):
+            if len(clip_epsilon) != 2:
+                raise ValueError(
+                    f"clip_epsilon tuple must have length 2, got {clip_epsilon}."
+                )
+            eps_low, eps_high = clip_epsilon
+        else:
+            eps_low = float(clip_epsilon)
+            eps_high = float(clip_epsilon)
+        # Basic validation
+        if eps_low < 0 or eps_high < 0:
+            raise ValueError(
+                f"clip_epsilon values must be non-negative, got ({eps_low}, {eps_high})."
+            )
+        if eps_low >= 1.0:
+            raise ValueError(
+                f"clip_epsilon low must be < 1 (to keep 1 - eps_low > 0), got {eps_low}."
+            )
+        # Register buffers
+        self.register_buffer("clip_epsilon_low", torch.tensor(eps_low, device=device))
+        self.register_buffer("clip_epsilon_high", torch.tensor(eps_high, device=device))

         self.masking_strategy = masking_strategy
         # Defaults for keys
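A standalone sketch of the accept/reject behaviour implemented above (the parse_clip_epsilon helper is hypothetical, written only to mirror the validation in this hunk):

# Hypothetical mirror of the validation above, for quick experimentation outside the class.
def parse_clip_epsilon(clip_epsilon):
    if isinstance(clip_epsilon, (tuple, list)):
        if len(clip_epsilon) != 2:
            raise ValueError(f"clip_epsilon tuple must have length 2, got {clip_epsilon}.")
        eps_low, eps_high = clip_epsilon
    else:
        eps_low = eps_high = float(clip_epsilon)
    if eps_low < 0 or eps_high < 0:
        raise ValueError(f"clip_epsilon values must be non-negative, got ({eps_low}, {eps_high}).")
    if eps_low >= 1.0:
        raise ValueError(f"clip_epsilon low must be < 1 (to keep 1 - eps_low > 0), got {eps_low}.")
    return eps_low, eps_high

print(parse_clip_epsilon(0.2))           # (0.2, 0.2): symmetric clipping
print(parse_clip_epsilon((0.20, 0.28)))  # (0.2, 0.28): DAPO Clip-Higher
# parse_clip_epsilon((1.0, 0.28)) raises: eps_low must stay below 1 so that 1 - eps_low > 0.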
@@ -178,7 +204,11 @@ def __init__(

     @property
     def _clip_bounds(self):
-        return ((-self.clip_epsilon).log1p(), self.clip_epsilon.log1p())
+        # Returns (log(1 - eps_low), log(1 + eps_high)) for clamping log-weight
+        return (
+            (-self.clip_epsilon_low).log1p(),
+            self.clip_epsilon_high.log1p(),
+        )

     def _set_in_keys(self):
         keys = []
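A quick check (illustrative values; assumes only torch) that clamping the log-weight with these log1p bounds is the same as clipping the importance ratio to [1 - eps_low, 1 + eps_high]:

import torch

eps_low, eps_high = torch.tensor(0.20), torch.tensor(0.28)
log_weight = torch.tensor([0.5, 1.0, 1.5]).log()

# Clamp in log-space with (log(1 - eps_low), log(1 + eps_high)), then exponentiate ...
ratio_from_log = log_weight.clamp((-eps_low).log1p(), eps_high.log1p()).exp()
# ... versus clipping the ratio directly.
ratio_direct = log_weight.exp().clamp(1 - eps_low, 1 + eps_high)
print(torch.allclose(ratio_from_log, ratio_direct))  # True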
@@ -325,6 +355,7 @@ def forward(self, tensordict: TensorDictBase) -> GRPOLossOutput:
         ratio = log_weight_clip.exp()
         gain2 = ratio * advantage

+        # Token-level objective: compute min over clipped/unclipped at the token level
         gain = torch.stack([gain1, gain2], -1).min(dim=-1).values
         td_out = TensorDict({"loss_objective": -gain})
         td_out.set("clip_fraction", clip_fraction)
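The token-level objective above, restated as a minimal sketch with made-up per-token ratios and advantages (only torch is assumed):

import torch

# Made-up per-token importance ratios and advantages for four tokens.
ratio = torch.tensor([0.6, 1.1, 1.5, 1.0])
advantage = torch.tensor([1.0, -1.0, 1.0, 0.5])
eps_low, eps_high = 0.20, 0.28

gain1 = ratio * advantage                                   # unclipped surrogate
gain2 = ratio.clamp(1 - eps_low, 1 + eps_high) * advantage  # clipped surrogate
gain = torch.stack([gain1, gain2], -1).min(dim=-1).values   # pessimistic choice per token
loss_objective = -gain                                      # one loss term per token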
@@ -514,6 +545,55 @@ def _log_weight(
         return log_weight, dist, kl_approx


+class DAPO(GRPOLoss):
+    """DAPO (Clip-Higher over GRPO).
+
+    Validates asymmetric clip thresholds; recommended (0.20, 0.28), see Eq. (10) in DAPO
+    [arXiv](https://arxiv.org/html/2503.14476).
+    """
+
+    def __init__(
+        self,
+        tensordict: TensorDictBase,
+        key: NestedKey = ("next", "ref_log_prob"),
+        ref_log_prob: torch.Tensor | None = None,
+        coeff: float | None = None,
+        mask: torch.Tensor | None = None,
+        dist: d.Distribution | None = None,
+    ):
+        if coeff is None:
+            coeff = self.kl_to_ref_coeff
+        # TODO: customize this
+        if ref_log_prob is None:
+            ref_log_prob = tensordict.get(
+                key,
+                as_padded_tensor=True,
+                padding_side="left",
+                padding_value=0.0,
+            )
+            if ref_log_prob is None:
+                raise KeyError(
+                    f"Couldn't find the ref log-prob {key} in the input data ({tensordict.keys(True)=})."
+                )
+            ref_log_prob = ref_log_prob.squeeze(-1)
+        cur_log_prob = tensordict.get("_cur_log_prob")
+        # TODO: remove this
+        if cur_log_prob.shape != ref_log_prob.shape:
+            raise ValueError(
+                f"cur_log_prob and ref_log_prob must have the same shape, got {cur_log_prob.shape=} and {ref_log_prob.shape=}"
+            )
+        if mask is not None:
+            ref_log_prob = torch.where(
+                expand_as_right(mask, ref_log_prob), ref_log_prob, 0.0
+            )
+            cur_log_prob = torch.where(
+                expand_as_right(mask, cur_log_prob), cur_log_prob, 0.0
+            )
+        diff = ref_log_prob - cur_log_prob
+        kl_penalty = (diff.expm1() - diff).mean()
+        return coeff * kl_penalty, kl_penalty
+
+
 class MCAdvantage(Transform):
     """Monte-Carlo advantage computation engine.
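As committed, DAPO.__init__ reuses the body of the reference-KL helper; going by the class docstring, the intent appears to be a GRPOLoss with asymmetric thresholds baked in. A purely hypothetical sketch of such a thin Clip-Higher wrapper (class name, defaults, and pass-through kwargs are assumptions, not the committed code):

from torchrl.objectives.llm.grpo import GRPOLoss

# Hypothetical sketch only: a thin wrapper that enforces (eps_low, eps_high) clipping,
# matching the docstring's stated intent rather than the committed __init__ above.
class DAPOSketch(GRPOLoss):
    def __init__(self, actor_network=None, *, clip_epsilon=(0.20, 0.28), **kwargs):
        if not (isinstance(clip_epsilon, (tuple, list)) and len(clip_epsilon) == 2):
            raise ValueError(
                f"DAPO expects asymmetric thresholds (eps_low, eps_high), got {clip_epsilon}."
            )
        super().__init__(actor_network, clip_epsilon=clip_epsilon, **kwargs)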
