Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/liger_kernel/chunked_loss/fused_linear_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ def forward(
vllm_is_ratio=None,
delta=None,
use_bias_correction_kl=False,
log_ratio_clamp_value=20.0,
kl_input_clamp_value=20.0,
kl_output_clamp_value=10.0,
):
# TODO: check torch compile matmul
"""Chunked forward pass for PPO loss computation.
Expand Down Expand Up @@ -125,6 +128,9 @@ def forward(
sapo_temperature_neg=sapo_temperature_neg,
delta=delta,
use_bias_correction_kl=use_bias_correction_kl,
log_ratio_clamp_value=log_ratio_clamp_value,
kl_input_clamp_value=kl_input_clamp_value,
kl_output_clamp_value=kl_output_clamp_value,
)

def fused_fwd_bwd(
Expand Down Expand Up @@ -327,6 +333,9 @@ def _compute_chunk_loss(
sapo_temperature_neg=1.05,
delta=None,
use_bias_correction_kl=False,
log_ratio_clamp_value=20.0,
kl_input_clamp_value=20.0,
kl_output_clamp_value=10.0,
):
"""Compute loss for a single chunk."""
# Get policy log probabilities using chunk_forward
Expand Down Expand Up @@ -361,6 +370,9 @@ def _compute_chunk_loss(
vllm_is_ratio=vllm_is_ratio_chunk,
delta=delta,
use_bias_correction_kl=use_bias_correction_kl,
log_ratio_clamp_value=log_ratio_clamp_value,
kl_input_clamp_value=kl_input_clamp_value,
kl_output_clamp_value=kl_output_clamp_value,
)

return chunk_loss, chunk_metrics
Expand Down
80 changes: 75 additions & 5 deletions src/liger_kernel/chunked_loss/grpo_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,28 @@
from liger_kernel.chunked_loss.fused_linear_ppo import LigerFusedLinearPPOBase


def k3_loss_fn(log_p, log_q):
# computes k3 estimate of KL[q, p]
# ref: http://joschu.net/blog/kl-approx.html
return torch.exp(log_p - log_q) - (log_p - log_q) - 1.0
def k3_loss_fn(log_p, log_q, input_clamp_value=None, output_clamp_value=None):
"""k3 estimator of KL[q, p].

Optionally clamps ``log_p - log_q`` to ``[-input_clamp_value,
input_clamp_value]`` before exponentiation and clamps the resulting kl to
``[-output_clamp_value, output_clamp_value]``. These guards match the
numerical safety net used by other RL frameworks (e.g. NeMo-RL's
``calculate_kl``) and prevent fp32 ``exp`` overflow when reference and
policy log-probs diverge — in particular at masked / low-probability
positions, where unbounded ``exp`` produces ``inf`` that survives a
subsequent multiplication by a zero attention mask
(``inf * 0 == nan``) and contaminates the entire reduction.

ref: http://joschu.net/blog/kl-approx.html
"""
logr = log_p - log_q
if input_clamp_value is not None:
logr = logr.clamp(min=-input_clamp_value, max=input_clamp_value)
kl = torch.exp(logr) - logr - 1.0
if output_clamp_value is not None:
kl = kl.clamp(min=-output_clamp_value, max=output_clamp_value)
return kl


def sapo_loss_fn(importance_ratio: torch.Tensor, temperature: float) -> torch.Tensor:
Expand Down Expand Up @@ -78,6 +96,9 @@ def ppo_loss_fn(
vllm_is_ratio=None, # vLLM importance sampling ratio (chunk_size, seq_len) or (chunk_size, 1) or None
delta=None, # Upper clamp for two-sided clipping (INTELLECT-2)
use_bias_correction_kl=False, # Importance-sampling-corrected KL (DeepSeek-V3.2)
log_ratio_clamp_value=20.0, # Clamp policy/old log-ratio before exp for numerical stability
kl_input_clamp_value=20.0, # Clamp (ref - policy) log-ratio before exp inside k3
kl_output_clamp_value=10.0, # Clamp the resulting k3 KL output
**kwargs,
):
"""GRPO Loss Function matching GRPOTrainer implementation."""
Expand Down Expand Up @@ -105,6 +126,15 @@ def ppo_loss_fn(
# Compute policy gradient loss with importance sampling ratio
old_per_token_logps = old_per_token_logps if old_per_token_logps is not None else per_token_logps.detach()
log_ratio = per_token_logps - old_per_token_logps
# Clamp the policy/old log-ratio before any subsequent ``exp`` so that
# extreme values at masked / low-probability positions cannot overflow
# fp32 ``exp`` to ``inf`` and contaminate the masked reduction
# (``inf * 0 == nan``). Matches the guard used by NeMo-RL's
# ``calculate_kl`` (``input_clamp_value``).
if log_ratio_clamp_value is not None:
log_ratio = log_ratio.clamp(
min=-log_ratio_clamp_value, max=log_ratio_clamp_value
)

if importance_sampling_level == "token":
log_importance_weights = log_ratio
Expand Down Expand Up @@ -157,7 +187,12 @@ def ppo_loss_fn(

if beta != 0.0:
# Compute KL penalty (approximates KL[per_token_logps, ref_per_token_logps])
kl_div = k3_loss_fn(ref_per_token_logps, per_token_logps)
kl_div = k3_loss_fn(
ref_per_token_logps,
per_token_logps,
input_clamp_value=kl_input_clamp_value,
output_clamp_value=kl_output_clamp_value,
)
if use_bias_correction_kl:
# Importance-sampling-corrected KL (DeepSeek-V3.2): kl *= token-level coef_1
token_coef_1 = torch.exp(per_token_logps - old_per_token_logps)
Expand Down Expand Up @@ -251,6 +286,9 @@ def forward(
vllm_is_ratio=None,
delta=None,
use_bias_correction_kl=False,
log_ratio_clamp_value=20.0,
kl_input_clamp_value=20.0,
kl_output_clamp_value=10.0,
):
"""
Fused linear layer with GRPO loss.
Expand Down Expand Up @@ -278,6 +316,15 @@ def forward(
chunk_size (int): Size of chunks for processing.
vllm_is_ratio (torch.Tensor, optional): vLLM importance sampling ratio (batch_size, seq_len) or (batch_size, 1) or None.
Used to correct for distribution mismatch when using vLLM for generation.
log_ratio_clamp_value (float, optional): If set, clamps the policy/old
log-ratio to ``[-value, value]`` before exponentiation. Prevents
``exp`` overflow at masked / low-probability positions where
``inf * 0 == nan`` would otherwise contaminate the reduction.
kl_input_clamp_value (float, optional): If set, clamps ``ref - policy``
log-ratio inside the k3 estimator before ``exp``. Same rationale
as ``log_ratio_clamp_value``, but for the KL penalty term.
kl_output_clamp_value (float, optional): If set, clamps the resulting
k3 KL value to ``[-value, value]``.
Returns:
torch.Tensor: Computed loss
"""
Expand Down Expand Up @@ -317,6 +364,9 @@ def forward(
vllm_is_ratio=vllm_is_ratio,
delta=delta,
use_bias_correction_kl=use_bias_correction_kl,
log_ratio_clamp_value=log_ratio_clamp_value,
kl_input_clamp_value=kl_input_clamp_value,
kl_output_clamp_value=kl_output_clamp_value,
)

@staticmethod
Expand Down Expand Up @@ -352,6 +402,9 @@ def backward(ctx, grad_output, *grad_metrics):
None, # grad_vllm_is_ratio
None, # grad_delta
None, # grad_use_bias_correction_kl
None, # grad_log_ratio_clamp_value
None, # grad_kl_input_clamp_value
None, # grad_kl_output_clamp_value
)


Expand All @@ -374,6 +427,9 @@ def __init__(
temperature: float = 1.0,
delta: Optional[float] = None,
use_bias_correction_kl: bool = False,
log_ratio_clamp_value: Optional[float] = 20.0,
kl_input_clamp_value: Optional[float] = 20.0,
kl_output_clamp_value: Optional[float] = 10.0,
):
"""
Args:
Expand All @@ -393,6 +449,14 @@ def __init__(
temperature (float): Temperature for the logits.
delta (float, optional): Upper clamp for two-sided clipping (INTELLECT-2). None means disabled.
use_bias_correction_kl (bool): If True, multiply KL by importance sampling ratio (DeepSeek-V3.2).
log_ratio_clamp_value (float, optional): If set, clamps the policy/old log-ratio
to ``[-value, value]`` before exponentiation. Prevents fp32 ``exp`` overflow
at masked / low-probability positions where ``inf * 0 == nan`` would otherwise
contaminate the masked reduction. None disables the guard.
kl_input_clamp_value (float, optional): If set, clamps ``ref - policy`` log-ratio
inside the k3 estimator before ``exp``. None disables.
kl_output_clamp_value (float, optional): If set, clamps the resulting k3 KL value
to ``[-value, value]``. None disables.
"""
super().__init__()
# Validate SAPO temperatures to prevent division by zero or numerical instability
Expand All @@ -416,6 +480,9 @@ def __init__(
self.temperature = temperature
self.delta = delta
self.use_bias_correction_kl = use_bias_correction_kl
self.log_ratio_clamp_value = log_ratio_clamp_value
self.kl_input_clamp_value = kl_input_clamp_value
self.kl_output_clamp_value = kl_output_clamp_value

def forward(
self,
Expand Down Expand Up @@ -459,4 +526,7 @@ def forward(
vllm_is_ratio,
self.delta,
self.use_bias_correction_kl,
self.log_ratio_clamp_value,
self.kl_input_clamp_value,
self.kl_output_clamp_value,
)