
feat(grpo): expose per-token policy logprobs via optional return flag #5

Draft
WyldeCat wants to merge 1 commit into `feat/flce-num-chunks-override-v2` from `feat/grpo-return-per-token-logps`

@WyldeCat
Member

Summary

Adds an optional `return_per_token_logps: bool = False` argument to `LigerFusedLinearGRPOFunction` / `LigerFusedLinearGRPOLoss`. When enabled, the per-token policy log-probs of each chunk are accumulated (detached) and returned as a `[B, T]` tensor at the end of the metrics tuple.

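Motivation

The chunked GRPO forward already computes per-token log-probs via `per_token_logps = log_probs.gather(...)` for the importance-ratio calculation (`grpo_loss.py:91-93`), but it does not expose them. A caller that wants additional diagnostic metrics (e.g. `probs_ratio`, `approx_entropy`) therefore has to run the lm_head matmul forward a second time. Exposing the existing value behind an optional flag lets callers reuse it without that cost.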
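For readers unfamiliar with the `gather` step, here is a minimal sketch of how per-token log-probs are typically extracted from logits. It is not the code in `grpo_loss.py`; the helper name `per_token_logps_from_logits` and the toy shapes are assumptions made for illustration.

```python
import torch


def per_token_logps_from_logits(logits, token_ids):
    """Hypothetical helper: log-prob of each selected token.

    logits:    [B, T, V] unnormalized scores
    token_ids: [B, T] ids of the tokens that were actually sampled
    returns:   [B, T] log-probabilities of those tokens
    """
    log_probs = torch.log_softmax(logits.float(), dim=-1)
    # gather picks, for every (b, t), the entry at index token_ids[b, t]
    return log_probs.gather(dim=-1, index=token_ids.unsqueeze(-1)).squeeze(-1)


torch.manual_seed(0)
B, T, V = 2, 4, 8  # toy sizes, chosen only for the example
logits = torch.randn(B, T, V)
token_ids = torch.randint(0, V, (B, T))
per_token_logps = per_token_logps_from_logits(logits, token_ids)
```

The expensive part in a real model is producing `logits` through the lm_head projection, which is why reusing the already-gathered values is attractive.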

Changes

`return_per_token_logps` is added to three groups of signatures:

  • `LigerFusedLinearGRPOFunction.ppo_loss_fn` / `forward` / `backward` (extra `None` grad)
  • `LigerFusedLinearGRPOLoss.__init__` / `forward`
  • `LigerFusedLinearPPOBase.forward` / `_compute_chunk_loss` (threading only)

When `True`:

  1. `ppo_loss_fn` appends `per_token_logps.detach()` to the end of the metrics list (`grpo_loss.py:206-216`).
  2. The existing tensor-aggregator in the base class (`fused_linear_ppo.py:183-195`) handles the list of tensors: it appends one entry per chunk, then at the end of forward merges them with `torch.cat(metric, dim=0)` into a single `[B, T]` tensor.
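The two steps above can be sketched in isolation as follows. This is a simplified stand-in, not the Liger implementation: the function `chunked_per_token_logps`, the plain matmul in place of the fused kernel, and chunking along the batch dimension are all assumptions for illustration.

```python
import torch


def chunked_per_token_logps(hidden, weight, token_ids, num_chunks):
    """Hypothetical sketch: collect detached per-token log-probs chunk by chunk.

    hidden:    [B, T, H] final hidden states
    weight:    [V, H] lm_head weight
    token_ids: [B, T] selected token ids
    returns:   [B, T] detached log-probs, concatenated over batch chunks
    """
    collected = []
    hidden_chunks = torch.chunk(hidden, num_chunks, dim=0)
    id_chunks = torch.chunk(token_ids, num_chunks, dim=0)
    for h_chunk, ids_chunk in zip(hidden_chunks, id_chunks):
        logits = h_chunk @ weight.t()  # [b, T, V] for this chunk only
        log_probs = torch.log_softmax(logits.float(), dim=-1)
        logps = log_probs.gather(-1, ids_chunk.unsqueeze(-1)).squeeze(-1)
        collected.append(logps.detach())  # step 1: append the detached chunk
    return torch.cat(collected, dim=0)  # step 2: merge chunks into [B, T]


torch.manual_seed(0)
B, T, H, V = 4, 3, 8, 16  # toy sizes, chosen only for the example
hidden = torch.randn(B, T, H)
weight = torch.randn(V, H, requires_grad=True)
token_ids = torch.randint(0, V, (B, T))

chunked = chunked_per_token_logps(hidden, weight, token_ids, num_chunks=2)
unchunked = chunked_per_token_logps(hidden, weight, token_ids, num_chunks=1)
```

Because each chunk is detached before being stored, the concatenated tensor carries no autograd graph, so keeping it around for diagnostics should not retain the per-chunk logits.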

Backwards compatibility

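The default is `False`, so the return shape is unchanged for existing callers. When the flag is enabled, `tuple(final_metrics)` grows by one entry; callers concerned about that can simply read the extra tensor with index -1.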
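A caller-side guard could look roughly like the sketch below. The helper `split_metrics` is hypothetical, and the tuples are stand-ins shaped like the `(kl, is_clipped)` layout mentioned in the testing notes of this PR rather than real output from the loss module.

```python
import torch


def split_metrics(metrics, return_per_token_logps):
    """Hypothetical helper: separate scalar metrics from the optional trailing tensor."""
    if return_per_token_logps:
        # The per-token log-probs are the last entry when the flag is on.
        return tuple(metrics[:-1]), metrics[-1]
    return tuple(metrics), None


# Stand-in values; a real call would take these from the loss module's output.
kl, is_clipped = torch.tensor(0.01), torch.tensor(0.2)
default_metrics = (kl, is_clipped)
flagged_metrics = (kl, is_clipped, torch.zeros(2, 4))

scalars, logps = split_metrics(flagged_metrics, True)
legacy_scalars, legacy_logps = split_metrics(default_metrics, False)
```

Keying the split on the flag rather than on `len(metrics)` keeps the guard valid if further scalar metrics are added later.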

Testing

  • Module smoke: with the default `False`, `len(metrics)==2` (`(kl, is_clipped)`, both scalars). With `True`, `len(metrics)==3` and the last entry is a `[B, T]` tensor with `requires_grad=False` (detached).
  • Functional: after wiring up the NeMoRL-torchtitan side, compare the `probs_ratio` values numerically against the standard ClippedPGLossFn path.

🤖 Generated with Claude Code

The chunked GRPO forward already computes per_token_logps via
log_probs.gather(...) for the importance-ratio path; expose it (under
an opt-in flag) so downstream callers can compute additional diagnostic
metrics (probs_ratio, approx_entropy, etc.) without re-running the
lm_head matmul a second time.

Add return_per_token_logps: bool = False to:
- LigerFusedLinearGRPOFunction.{ppo_loss_fn, forward, backward}
- LigerFusedLinearGRPOLoss.{__init__, forward}
- LigerFusedLinearPPOBase.{forward, _compute_chunk_loss} (threading)

When True, the (detached) per-token logps for each chunk are appended to
the metrics list; the existing tensor-aggregator in the base class
concatenates chunks via torch.cat into a single [B, T] tensor
returned as the last metrics entry. Default False keeps the return
shape identical for existing callers.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@WyldeCat WyldeCat marked this pull request as draft May 21, 2026 06:32