diff --git a/src/liger_kernel/chunked_loss/fused_linear_ppo.py b/src/liger_kernel/chunked_loss/fused_linear_ppo.py index a382cda1b..b79f8ec64 100644 --- a/src/liger_kernel/chunked_loss/fused_linear_ppo.py +++ b/src/liger_kernel/chunked_loss/fused_linear_ppo.py @@ -44,6 +44,9 @@ def forward( vllm_is_ratio=None, delta=None, use_bias_correction_kl=False, + log_ratio_clamp_value=20.0, + kl_input_clamp_value=20.0, + kl_output_clamp_value=10.0, ): # TODO: check torch compile matmul """Chunked forward pass for PPO loss computation. @@ -125,6 +128,9 @@ def forward( sapo_temperature_neg=sapo_temperature_neg, delta=delta, use_bias_correction_kl=use_bias_correction_kl, + log_ratio_clamp_value=log_ratio_clamp_value, + kl_input_clamp_value=kl_input_clamp_value, + kl_output_clamp_value=kl_output_clamp_value, ) def fused_fwd_bwd( @@ -327,6 +333,9 @@ def _compute_chunk_loss( sapo_temperature_neg=1.05, delta=None, use_bias_correction_kl=False, + log_ratio_clamp_value=20.0, + kl_input_clamp_value=20.0, + kl_output_clamp_value=10.0, ): """Compute loss for a single chunk.""" # Get policy log probabilities using chunk_forward @@ -361,6 +370,9 @@ def _compute_chunk_loss( vllm_is_ratio=vllm_is_ratio_chunk, delta=delta, use_bias_correction_kl=use_bias_correction_kl, + log_ratio_clamp_value=log_ratio_clamp_value, + kl_input_clamp_value=kl_input_clamp_value, + kl_output_clamp_value=kl_output_clamp_value, ) return chunk_loss, chunk_metrics diff --git a/src/liger_kernel/chunked_loss/grpo_loss.py b/src/liger_kernel/chunked_loss/grpo_loss.py index f05cc8744..c38b758a6 100644 --- a/src/liger_kernel/chunked_loss/grpo_loss.py +++ b/src/liger_kernel/chunked_loss/grpo_loss.py @@ -5,10 +5,28 @@ from liger_kernel.chunked_loss.fused_linear_ppo import LigerFusedLinearPPOBase -def k3_loss_fn(log_p, log_q): - # computes k3 estimate of KL[q, p] - # ref: http://joschu.net/blog/kl-approx.html - return torch.exp(log_p - log_q) - (log_p - log_q) - 1.0 +def k3_loss_fn(log_p, log_q, input_clamp_value=None, output_clamp_value=None): + """k3 estimator of KL[q, p]. + + Optionally clamps ``log_p - log_q`` to ``[-input_clamp_value, + input_clamp_value]`` before exponentiation and clamps the resulting kl to + ``[-output_clamp_value, output_clamp_value]``. These guards match the + numerical safety net used by other RL frameworks (e.g. NeMo-RL's + ``calculate_kl``) and prevent fp32 ``exp`` overflow when reference and + policy log-probs diverge — in particular at masked / low-probability + positions, where unbounded ``exp`` produces ``inf`` that survives a + subsequent multiplication by a zero attention mask + (``inf * 0 == nan``) and contaminates the entire reduction. + + ref: http://joschu.net/blog/kl-approx.html + """ + logr = log_p - log_q + if input_clamp_value is not None: + logr = logr.clamp(min=-input_clamp_value, max=input_clamp_value) + kl = torch.exp(logr) - logr - 1.0 + if output_clamp_value is not None: + kl = kl.clamp(min=-output_clamp_value, max=output_clamp_value) + return kl def sapo_loss_fn(importance_ratio: torch.Tensor, temperature: float) -> torch.Tensor: @@ -78,6 +96,9 @@ def ppo_loss_fn( vllm_is_ratio=None, # vLLM importance sampling ratio (chunk_size, seq_len) or (chunk_size, 1) or None delta=None, # Upper clamp for two-sided clipping (INTELLECT-2) use_bias_correction_kl=False, # Importance-sampling-corrected KL (DeepSeek-V3.2) + log_ratio_clamp_value=20.0, # Clamp policy/old log-ratio before exp for numerical stability + kl_input_clamp_value=20.0, # Clamp (ref - policy) log-ratio before exp inside k3 + kl_output_clamp_value=10.0, # Clamp the resulting k3 KL output **kwargs, ): """GRPO Loss Function matching GRPOTrainer implementation.""" @@ -105,6 +126,15 @@ def ppo_loss_fn( # Compute policy gradient loss with importance sampling ratio old_per_token_logps = old_per_token_logps if old_per_token_logps is not None else per_token_logps.detach() log_ratio = per_token_logps - old_per_token_logps + # Clamp the policy/old log-ratio before any subsequent ``exp`` so that + # extreme values at masked / low-probability positions cannot overflow + # fp32 ``exp`` to ``inf`` and contaminate the masked reduction + # (``inf * 0 == nan``). Matches the guard used by NeMo-RL's + # ``calculate_kl`` (``input_clamp_value``). + if log_ratio_clamp_value is not None: + log_ratio = log_ratio.clamp( + min=-log_ratio_clamp_value, max=log_ratio_clamp_value + ) if importance_sampling_level == "token": log_importance_weights = log_ratio @@ -157,7 +187,12 @@ def ppo_loss_fn( if beta != 0.0: # Compute KL penalty (approximates KL[per_token_logps, ref_per_token_logps]) - kl_div = k3_loss_fn(ref_per_token_logps, per_token_logps) + kl_div = k3_loss_fn( + ref_per_token_logps, + per_token_logps, + input_clamp_value=kl_input_clamp_value, + output_clamp_value=kl_output_clamp_value, + ) if use_bias_correction_kl: # Importance-sampling-corrected KL (DeepSeek-V3.2): kl *= token-level coef_1 token_coef_1 = torch.exp(per_token_logps - old_per_token_logps) @@ -251,6 +286,9 @@ def forward( vllm_is_ratio=None, delta=None, use_bias_correction_kl=False, + log_ratio_clamp_value=20.0, + kl_input_clamp_value=20.0, + kl_output_clamp_value=10.0, ): """ Fused linear layer with GRPO loss. @@ -278,6 +316,15 @@ def forward( chunk_size (int): Size of chunks for processing. vllm_is_ratio (torch.Tensor, optional): vLLM importance sampling ratio (batch_size, seq_len) or (batch_size, 1) or None. Used to correct for distribution mismatch when using vLLM for generation. + log_ratio_clamp_value (float, optional): If set, clamps the policy/old + log-ratio to ``[-value, value]`` before exponentiation. Prevents + ``exp`` overflow at masked / low-probability positions where + ``inf * 0 == nan`` would otherwise contaminate the reduction. + kl_input_clamp_value (float, optional): If set, clamps ``ref - policy`` + log-ratio inside the k3 estimator before ``exp``. Same rationale + as ``log_ratio_clamp_value``, but for the KL penalty term. + kl_output_clamp_value (float, optional): If set, clamps the resulting + k3 KL value to ``[-value, value]``. Returns: torch.Tensor: Computed loss """ @@ -317,6 +364,9 @@ def forward( vllm_is_ratio=vllm_is_ratio, delta=delta, use_bias_correction_kl=use_bias_correction_kl, + log_ratio_clamp_value=log_ratio_clamp_value, + kl_input_clamp_value=kl_input_clamp_value, + kl_output_clamp_value=kl_output_clamp_value, ) @staticmethod @@ -352,6 +402,9 @@ def backward(ctx, grad_output, *grad_metrics): None, # grad_vllm_is_ratio None, # grad_delta None, # grad_use_bias_correction_kl + None, # grad_log_ratio_clamp_value + None, # grad_kl_input_clamp_value + None, # grad_kl_output_clamp_value ) @@ -374,6 +427,9 @@ def __init__( temperature: float = 1.0, delta: Optional[float] = None, use_bias_correction_kl: bool = False, + log_ratio_clamp_value: Optional[float] = 20.0, + kl_input_clamp_value: Optional[float] = 20.0, + kl_output_clamp_value: Optional[float] = 10.0, ): """ Args: @@ -393,6 +449,14 @@ def __init__( temperature (float): Temperature for the logits. delta (float, optional): Upper clamp for two-sided clipping (INTELLECT-2). None means disabled. use_bias_correction_kl (bool): If True, multiply KL by importance sampling ratio (DeepSeek-V3.2). + log_ratio_clamp_value (float, optional): If set, clamps the policy/old log-ratio + to ``[-value, value]`` before exponentiation. Prevents fp32 ``exp`` overflow + at masked / low-probability positions where ``inf * 0 == nan`` would otherwise + contaminate the masked reduction. None disables the guard. + kl_input_clamp_value (float, optional): If set, clamps ``ref - policy`` log-ratio + inside the k3 estimator before ``exp``. None disables. + kl_output_clamp_value (float, optional): If set, clamps the resulting k3 KL value + to ``[-value, value]``. None disables. """ super().__init__() # Validate SAPO temperatures to prevent division by zero or numerical instability @@ -416,6 +480,9 @@ def __init__( self.temperature = temperature self.delta = delta self.use_bias_correction_kl = use_bias_correction_kl + self.log_ratio_clamp_value = log_ratio_clamp_value + self.kl_input_clamp_value = kl_input_clamp_value + self.kl_output_clamp_value = kl_output_clamp_value def forward( self, @@ -459,4 +526,7 @@ def forward( vllm_is_ratio, self.delta, self.use_bias_correction_kl, + self.log_ratio_clamp_value, + self.kl_input_clamp_value, + self.kl_output_clamp_value, )