Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/liger_kernel/chunked_loss/fused_linear_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def forward(
log_ratio_clamp_value=20.0,
kl_input_clamp_value=20.0,
kl_output_clamp_value=10.0,
return_per_token_logps=False,
):
# TODO: check torch compile matmul
"""Chunked forward pass for PPO loss computation.
Expand Down Expand Up @@ -131,6 +132,7 @@ def forward(
log_ratio_clamp_value=log_ratio_clamp_value,
kl_input_clamp_value=kl_input_clamp_value,
kl_output_clamp_value=kl_output_clamp_value,
return_per_token_logps=return_per_token_logps,
)

def fused_fwd_bwd(
Expand Down Expand Up @@ -336,6 +338,7 @@ def _compute_chunk_loss(
log_ratio_clamp_value=20.0,
kl_input_clamp_value=20.0,
kl_output_clamp_value=10.0,
return_per_token_logps=False,
):
"""Compute loss for a single chunk."""
# Get policy log probabilities using chunk_forward
Expand Down Expand Up @@ -373,6 +376,7 @@ def _compute_chunk_loss(
log_ratio_clamp_value=log_ratio_clamp_value,
kl_input_clamp_value=kl_input_clamp_value,
kl_output_clamp_value=kl_output_clamp_value,
return_per_token_logps=return_per_token_logps,
)

return chunk_loss, chunk_metrics
Expand Down
24 changes: 24 additions & 0 deletions src/liger_kernel/chunked_loss/grpo_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def ppo_loss_fn(
log_ratio_clamp_value=20.0, # Clamp policy/old log-ratio before exp for numerical stability
kl_input_clamp_value=20.0, # Clamp (ref - policy) log-ratio before exp inside k3
kl_output_clamp_value=10.0, # Clamp the resulting k3 KL output
return_per_token_logps=False, # Append per-token policy logp to metrics for diagnostic reuse
**kwargs,
):
"""GRPO Loss Function matching GRPOTrainer implementation."""
Expand Down Expand Up @@ -254,6 +255,13 @@ def ppo_loss_fn(
is_clipped = is_clipped.expand_as(attention_mask)

metrics.append((is_clipped * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0))

# Optionally expose the per-token policy logps so callers can compute
# additional diagnostic metrics (e.g. probs_ratio, approx_entropy)
# without re-running the lm_head matmul. Detached so it doesn't
# participate in the backward chain.
if return_per_token_logps:
metrics.append(per_token_logps.detach())
return loss, metrics

@classmethod
Expand Down Expand Up @@ -289,6 +297,7 @@ def forward(
log_ratio_clamp_value=20.0,
kl_input_clamp_value=20.0,
kl_output_clamp_value=10.0,
return_per_token_logps=False,
):
"""
Fused linear layer with GRPO loss.
Expand Down Expand Up @@ -325,6 +334,12 @@ def forward(
as ``log_ratio_clamp_value``, but for the KL penalty term.
kl_output_clamp_value (float, optional): If set, clamps the resulting
k3 KL value to ``[-value, value]``.
return_per_token_logps (bool): If True, append the (detached) per-token
policy log-probabilities to the returned metrics tuple. Lets
downstream callers compute additional diagnostic metrics
(e.g. probs_ratio, approx_entropy) without re-running the lm_head
matmul. Off by default to avoid the extra [B, T] tensor when
unused.
Returns:
torch.Tensor: Computed loss
"""
Expand Down Expand Up @@ -367,6 +382,7 @@ def forward(
log_ratio_clamp_value=log_ratio_clamp_value,
kl_input_clamp_value=kl_input_clamp_value,
kl_output_clamp_value=kl_output_clamp_value,
return_per_token_logps=return_per_token_logps,
)

@staticmethod
Expand Down Expand Up @@ -405,6 +421,7 @@ def backward(ctx, grad_output, *grad_metrics):
None, # grad_log_ratio_clamp_value
None, # grad_kl_input_clamp_value
None, # grad_kl_output_clamp_value
None, # grad_return_per_token_logps
)


Expand All @@ -430,6 +447,7 @@ def __init__(
log_ratio_clamp_value: Optional[float] = 20.0,
kl_input_clamp_value: Optional[float] = 20.0,
kl_output_clamp_value: Optional[float] = 10.0,
return_per_token_logps: bool = False,
):
"""
Args:
Expand Down Expand Up @@ -457,6 +475,10 @@ def __init__(
inside the k3 estimator before ``exp``. None disables.
kl_output_clamp_value (float, optional): If set, clamps the resulting k3 KL value
to ``[-value, value]``. None disables.
return_per_token_logps (bool): If True, the returned metrics tuple gets an
additional trailing entry — the (detached) per-token policy log-probs
concatenated across chunks ([B, T]). Lets callers compute extra
diagnostic metrics without a second lm_head forward.
"""
super().__init__()
# Validate SAPO temperatures to prevent division by zero or numerical instability
Expand All @@ -483,6 +505,7 @@ def __init__(
self.log_ratio_clamp_value = log_ratio_clamp_value
self.kl_input_clamp_value = kl_input_clamp_value
self.kl_output_clamp_value = kl_output_clamp_value
self.return_per_token_logps = return_per_token_logps

def forward(
self,
Expand Down Expand Up @@ -529,4 +552,5 @@ def forward(
self.log_ratio_clamp_value,
self.kl_input_clamp_value,
self.kl_output_clamp_value,
self.return_per_token_logps,
)