@@ -78,8 +78,10 @@ class GRPOLoss(LossModule):
         The masking strategy must match the strategy used for advantage computation to avoid shape mismatches.
 
     Keyword Args:
-        clip_epsilon (scalar, optional): weight clipping threshold in the clipped PPO loss equation.
-            default: 0.2
+        clip_epsilon (float | tuple[float, float], optional): clipping threshold(s) for the clipped surrogate.
+            - float x: symmetric clipping [1 - x, 1 + x] (default: 0.2)
+            - tuple (eps_low, eps_high): asymmetric clipping [1 - eps_low, 1 + eps_high] as in DAPO Clip-Higher;
+              recommended defaults from DAPO are (0.20, 0.28), see Eq. (10) in the paper.
         entropy_bonus (bool, optional): if ``True``, an entropy bonus will be added to the
             loss to favour exploratory policies.
         samples_mc_entropy (int, optional): if the distribution retrieved from the policy
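A minimal usage sketch of the new argument (illustrative only; ``policy`` stands in for an ``LLMWrapperBase`` instance and is not defined in this diff):

    # Symmetric clipping, as before: ratio clipped to [1 - 0.2, 1 + 0.2]
    loss = GRPOLoss(actor_network=policy, clip_epsilon=0.2)

    # Asymmetric DAPO-style Clip-Higher: ratio clipped to [1 - 0.20, 1 + 0.28]
    loss = GRPOLoss(actor_network=policy, clip_epsilon=(0.20, 0.28))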
@@ -115,6 +117,9 @@ class GRPOLoss(LossModule):
 
     .. note:: Parameters and buffers from the policy / critic will not be cast to that device to ensure that
         the storages match the ones that are passed to other components, such as data collectors.
+
+    .. note:: For non-symmetric clipping thresholds, see the `DAPO <https://arxiv.org/html/2503.14476>`_ paper.
+
     """
 
     actor_network: LLMWrapperBase
@@ -136,7 +141,7 @@ def __init__(
         self,
         actor_network: LLMWrapperBase | None = None,
         *,
-        clip_epsilon: float = 0.2,
+        clip_epsilon: float | tuple[float, float] = 0.2,
         entropy_bonus: bool = True,
         samples_mc_entropy: int = 1,
         entropy_coeff: float = 0.01,
@@ -165,7 +170,28 @@ def __init__(
             device = getattr(
                 torch, "get_default_device", lambda: torch.device("cpu")
             )()
-        self.register_buffer("clip_epsilon", torch.tensor(clip_epsilon, device=device))
+        # Accept symmetric or asymmetric thresholds
+        if isinstance(clip_epsilon, (tuple, list)):
+            if len(clip_epsilon) != 2:
+                raise ValueError(
+                    f"clip_epsilon tuple must have length 2, got {clip_epsilon}."
+                )
+            eps_low, eps_high = clip_epsilon
+        else:
+            eps_low = float(clip_epsilon)
+            eps_high = float(clip_epsilon)
+        # Basic validation
+        if eps_low < 0 or eps_high < 0:
+            raise ValueError(
+                f"clip_epsilon values must be non-negative, got ({eps_low}, {eps_high})."
+            )
+        if eps_low >= 1.0:
+            raise ValueError(
+                f"clip_epsilon low must be < 1 (to keep 1 - eps_low > 0), got {eps_low}."
+            )
+        # Register buffers
+        self.register_buffer("clip_epsilon_low", torch.tensor(eps_low, device=device))
+        self.register_buffer("clip_epsilon_high", torch.tensor(eps_high, device=device))
 
         self.masking_strategy = masking_strategy
         # Defaults for keys
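With the parsing above, both the scalar and the pair form end up as two scalar buffers; a small check, again assuming a placeholder ``policy`` wrapper:

    loss = GRPOLoss(actor_network=policy, clip_epsilon=(0.20, 0.28))
    print(loss.clip_epsilon_low, loss.clip_epsilon_high)  # tensor(0.2000) tensor(0.2800)

    GRPOLoss(actor_network=policy, clip_epsilon=(0.2, 0.3, 0.4))  # raises ValueError (length != 2)
    GRPOLoss(actor_network=policy, clip_epsilon=1.5)              # raises ValueError (eps_low must be < 1)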
@@ -178,7 +204,11 @@ def __init__(
 
     @property
     def _clip_bounds(self):
-        return ((-self.clip_epsilon).log1p(), self.clip_epsilon.log1p())
+        # Returns (log(1 - eps_low), log(1 + eps_high)) for clamping the log-weight
+        return (
+            (-self.clip_epsilon_low).log1p(),
+            self.clip_epsilon_high.log1p(),
+        )
 
     def _set_in_keys(self):
         keys = []
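For the recommended (0.20, 0.28) thresholds, the log-space bounds work out as follows (a standalone numerical check, not taken from the diff):

    import math

    low = math.log1p(-0.20)   # log(0.80) ≈ -0.2231
    high = math.log1p(0.28)   # log(1.28) ≈  0.2469
    # Clamping the per-token log importance weight to [low, high] and then
    # exponentiating is equivalent to clamping the ratio itself to [0.80, 1.28].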
@@ -325,6 +355,7 @@ def forward(self, tensordict: TensorDictBase) -> GRPOLossOutput:
         ratio = log_weight_clip.exp()
         gain2 = ratio * advantage
 
+        # Token-level objective: compute min over clipped/unclipped at the token level
         gain = torch.stack([gain1, gain2], -1).min(dim=-1).values
         td_out = TensorDict({"loss_objective": -gain})
         td_out.set("clip_fraction", clip_fraction)
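The lines above implement the usual PPO-style clipped surrogate, evaluated per token rather than per sequence. A self-contained sketch of the same computation (function and variable names here are illustrative, not part of the library API):

    import math
    import torch

    def clipped_token_objective(log_weight, advantage, eps_low=0.20, eps_high=0.28):
        # log_weight: per-token log(pi_theta / pi_old); advantage: per-token advantage
        gain1 = log_weight.exp() * advantage
        gain2 = log_weight.clamp(math.log1p(-eps_low), math.log1p(eps_high)).exp() * advantage
        # Per-token min of the unclipped and clipped terms; the loss is the negated result
        return torch.stack([gain1, gain2], -1).min(dim=-1).values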
@@ -514,6 +545,55 @@ def _log_weight(
         return log_weight, dist, kl_approx
 
 
+class DAPO(GRPOLoss):
+    """DAPO (Clip-Higher over GRPO).
+
+    Validates asymmetric clip thresholds; recommended (0.20, 0.28), see Eq. (10) in the
+    `DAPO <https://arxiv.org/html/2503.14476>`_ paper.
+    """
+
+    def __init__(
+        self,
+        actor_network: LLMWrapperBase | None = None,
+        *,
+        clip_epsilon: tuple[float, float] = (0.20, 0.28),
+        **kwargs,
+    ):
+        # Validate the asymmetric thresholds before deferring to GRPOLoss,
+        # which registers the clip_epsilon_low / clip_epsilon_high buffers.
+        if not isinstance(clip_epsilon, (tuple, list)) or len(clip_epsilon) != 2:
+            raise ValueError(
+                f"DAPO expects clip_epsilon as a (eps_low, eps_high) pair, got {clip_epsilon}."
+            )
+        super().__init__(
+            actor_network,
+            clip_epsilon=tuple(clip_epsilon),
+            **kwargs,
+        )
+
+
 class MCAdvantage(Transform):
     """Monte-Carlo advantage computation engine.
 
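Illustrative usage of the subclass as sketched above (``policy`` again stands in for an ``LLMWrapperBase`` instance):

    # Uses the DAPO-recommended asymmetric thresholds (0.20, 0.28) by default
    loss = DAPO(actor_network=policy)

    # Equivalent explicit construction through the parent class
    loss = GRPOLoss(actor_network=policy, clip_epsilon=(0.20, 0.28))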