diff --git a/src/liger_kernel/chunked_loss/fused_linear_ppo.py b/src/liger_kernel/chunked_loss/fused_linear_ppo.py index b79f8ec64..8f1ed07b4 100644 --- a/src/liger_kernel/chunked_loss/fused_linear_ppo.py +++ b/src/liger_kernel/chunked_loss/fused_linear_ppo.py @@ -47,6 +47,7 @@ def forward( log_ratio_clamp_value=20.0, kl_input_clamp_value=20.0, kl_output_clamp_value=10.0, + return_per_token_logps=False, ): # TODO: check torch compile matmul """Chunked forward pass for PPO loss computation. @@ -131,6 +132,7 @@ def forward( log_ratio_clamp_value=log_ratio_clamp_value, kl_input_clamp_value=kl_input_clamp_value, kl_output_clamp_value=kl_output_clamp_value, + return_per_token_logps=return_per_token_logps, ) def fused_fwd_bwd( @@ -336,6 +338,7 @@ def _compute_chunk_loss( log_ratio_clamp_value=20.0, kl_input_clamp_value=20.0, kl_output_clamp_value=10.0, + return_per_token_logps=False, ): """Compute loss for a single chunk.""" # Get policy log probabilities using chunk_forward @@ -373,6 +376,7 @@ def _compute_chunk_loss( log_ratio_clamp_value=log_ratio_clamp_value, kl_input_clamp_value=kl_input_clamp_value, kl_output_clamp_value=kl_output_clamp_value, + return_per_token_logps=return_per_token_logps, ) return chunk_loss, chunk_metrics diff --git a/src/liger_kernel/chunked_loss/grpo_loss.py b/src/liger_kernel/chunked_loss/grpo_loss.py index c38b758a6..6e7f69692 100644 --- a/src/liger_kernel/chunked_loss/grpo_loss.py +++ b/src/liger_kernel/chunked_loss/grpo_loss.py @@ -99,6 +99,7 @@ def ppo_loss_fn( log_ratio_clamp_value=20.0, # Clamp policy/old log-ratio before exp for numerical stability kl_input_clamp_value=20.0, # Clamp (ref - policy) log-ratio before exp inside k3 kl_output_clamp_value=10.0, # Clamp the resulting k3 KL output + return_per_token_logps=False, # Append per-token policy logp to metrics for diagnostic reuse **kwargs, ): """GRPO Loss Function matching GRPOTrainer implementation.""" @@ -254,6 +255,13 @@ def ppo_loss_fn( is_clipped = is_clipped.expand_as(attention_mask) metrics.append((is_clipped * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)) + + # Optionally expose the per-token policy logps so callers can compute + # additional diagnostic metrics (e.g. probs_ratio, approx_entropy) + # without re-running the lm_head matmul. Detached so it doesn't + # participate in the backward chain. + if return_per_token_logps: + metrics.append(per_token_logps.detach()) return loss, metrics @classmethod @@ -289,6 +297,7 @@ def forward( log_ratio_clamp_value=20.0, kl_input_clamp_value=20.0, kl_output_clamp_value=10.0, + return_per_token_logps=False, ): """ Fused linear layer with GRPO loss. @@ -325,6 +334,12 @@ def forward( as ``log_ratio_clamp_value``, but for the KL penalty term. kl_output_clamp_value (float, optional): If set, clamps the resulting k3 KL value to ``[-value, value]``. + return_per_token_logps (bool): If True, append the (detached) per-token + policy log-probabilities to the returned metrics tuple. Lets + downstream callers compute additional diagnostic metrics + (e.g. probs_ratio, approx_entropy) without re-running the lm_head + matmul. Off by default to avoid the extra [B, T] tensor when + unused. Returns: torch.Tensor: Computed loss """ @@ -367,6 +382,7 @@ def forward( log_ratio_clamp_value=log_ratio_clamp_value, kl_input_clamp_value=kl_input_clamp_value, kl_output_clamp_value=kl_output_clamp_value, + return_per_token_logps=return_per_token_logps, ) @staticmethod @@ -405,6 +421,7 @@ def backward(ctx, grad_output, *grad_metrics): None, # grad_log_ratio_clamp_value None, # grad_kl_input_clamp_value None, # grad_kl_output_clamp_value + None, # grad_return_per_token_logps ) @@ -430,6 +447,7 @@ def __init__( log_ratio_clamp_value: Optional[float] = 20.0, kl_input_clamp_value: Optional[float] = 20.0, kl_output_clamp_value: Optional[float] = 10.0, + return_per_token_logps: bool = False, ): """ Args: @@ -457,6 +475,10 @@ def __init__( inside the k3 estimator before ``exp``. None disables. kl_output_clamp_value (float, optional): If set, clamps the resulting k3 KL value to ``[-value, value]``. None disables. + return_per_token_logps (bool): If True, the returned metrics tuple gets an + additional trailing entry — the (detached) per-token policy log-probs + concatenated across chunks ([B, T]). Lets callers compute extra + diagnostic metrics without a second lm_head forward. """ super().__init__() # Validate SAPO temperatures to prevent division by zero or numerical instability @@ -483,6 +505,7 @@ def __init__( self.log_ratio_clamp_value = log_ratio_clamp_value self.kl_input_clamp_value = kl_input_clamp_value self.kl_output_clamp_value = kl_output_clamp_value + self.return_per_token_logps = return_per_token_logps def forward( self, @@ -529,4 +552,5 @@ def forward( self.log_ratio_clamp_value, self.kl_input_clamp_value, self.kl_output_clamp_value, + self.return_per_token_logps, )