feat(grpo): expose per-token policy logprobs via optional return flag #5
Draft
WyldeCat wants to merge 1 commit into
The chunked GRPO forward already computes per_token_logps via
log_probs.gather(...) for the importance-ratio path; expose it (under
an opt-in flag) so downstream callers can compute additional diagnostic
metrics (probs_ratio, approx_entropy, etc.) without running the
lm_head matmul a second time.
Add return_per_token_logps: bool = False to:
- LigerFusedLinearGRPOFunction.{ppo_loss_fn, forward, backward}
- LigerFusedLinearGRPOLoss.{__init__, forward}
- LigerFusedLinearPPOBase.{forward, _compute_chunk_loss} (threading)
When True, the (detached) per-token logps for each chunk are appended to
the metrics list; the existing tensor-aggregator in the base class
concatenates chunks via torch.cat into a single [B, T] tensor
returned as the last metrics entry. Default False keeps the return
shape identical for existing callers.
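
A minimal sketch of the accumulation described above, assuming the batch is chunked along dim 0. The helper `chunk_logps`, the tensor sizes, and `chunk_size` are hypothetical and only illustrate the detach-append-concatenate pattern; this is not the code in `grpo_loss.py`.

```python
import torch

torch.manual_seed(0)
B, T, H, V = 4, 3, 16, 10  # batch, seq len, hidden size, vocab size (illustrative)
chunk_size = 2             # batch rows per chunk (assumed)

hidden = torch.randn(B, T, H, requires_grad=True)
lm_head_weight = torch.randn(V, H)
token_ids = torch.randint(0, V, (B, T))


def chunk_logps(h, ids, weight):
    """Hypothetical helper: per-token log-probs for one chunk, [b, T, H] -> [b, T]."""
    logits = h @ weight.t()
    log_probs = torch.log_softmax(logits.float(), dim=-1)
    return log_probs.gather(-1, ids.unsqueeze(-1)).squeeze(-1)


# Append the detached per-chunk values, as the flag would do when enabled.
collected = []
for h, ids in zip(hidden.split(chunk_size, dim=0), token_ids.split(chunk_size, dim=0)):
    collected.append(chunk_logps(h, ids, lm_head_weight).detach())

# Concatenate the chunks along the batch dimension into one [B, T] tensor.
per_token_logps = torch.cat(collected, dim=0)

# Reference: the same quantity computed without chunking.
reference = chunk_logps(hidden, token_ids, lm_head_weight)
```

Because each chunk is detached before it is stored, the concatenated tensor carries no autograd graph and can be used for logging without affecting the backward pass.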
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Summary
Adds an optional `return_per_token_logps: bool = False` argument to `LigerFusedLinearGRPOFunction` / `LigerFusedLinearGRPOLoss`. When enabled, the per-token policy log-probs of each chunk are accumulated (detached) and returned as a `[B, T]` tensor at the end of the metrics tuple.
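Motivation
The chunked GRPO forward already computes per-token log-probs via `per_token_logps = log_probs.gather(...)` for the importance-ratio computation (`grpo_loss.py:91-93`), but does not expose them. To compute additional diagnostic metrics (e.g. `probs_ratio`, `approx_entropy`), a caller therefore has to run the lm_head matmul forward a second time. Returning the existing value behind an optional flag lets callers reuse it without that cost.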
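
As a rough illustration of the reuse this enables, the sketch below builds per-token log-probs with the same `log_softmax` + `gather` pattern and derives two diagnostics from them. The tensor sizes, the `old_logps` input, the mask, and in particular the definition of `approx_entropy` used here (negative masked mean of the sampled-token log-probs) are assumptions made for illustration, not definitions taken from this PR.

```python
import torch

torch.manual_seed(0)
B, T, V = 2, 4, 8  # batch, sequence length, vocab size (illustrative)

logits = torch.randn(B, T, V)
token_ids = torch.randint(0, V, (B, T))
mask = torch.ones(B, T)  # 1 for completion tokens, 0 for padding (assumed)

# Log-softmax over the vocabulary, then gather the log-prob of each
# selected token, giving a [B, T] tensor.
log_probs = torch.log_softmax(logits, dim=-1)
per_token_logps = log_probs.gather(-1, token_ids.unsqueeze(-1)).squeeze(-1)

# Hypothetical log-probs recorded from the sampling-time policy.
old_logps = per_token_logps + 0.05 * torch.randn(B, T)

# Diagnostics derived from the returned tensor alone, with no second matmul.
ratio = torch.exp(per_token_logps.detach() - old_logps)
probs_ratio = (ratio * mask).sum() / mask.sum()

# Assumed definition: sample-based entropy estimate, -mean(log pi(token)).
approx_entropy = -(per_token_logps.detach() * mask).sum() / mask.sum()
```

Both quantities need only the `[B, T]` log-prob tensor and a mask, which is why exposing the already-computed value is enough for this kind of logging.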
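Changes
Add `return_per_token_logps` to three signatures:
- `LigerFusedLinearGRPOFunction.{ppo_loss_fn, forward, backward}`
- `LigerFusedLinearGRPOLoss.{__init__, forward}`
- `LigerFusedLinearPPOBase.{forward, _compute_chunk_loss}` (threading)
When True:
- `per_token_logps.detach()` is appended to the metrics list for each chunk (`grpo_loss.py:206-216`).
- The chunks are merged with `torch.cat(metric, dim=0)` into a single `[B, T]` tensor.
Backwards compatibility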
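The default is `False`, so the return shape is unchanged for existing callers. When the flag is enabled, `tuple(final_metrics)` grows by one entry; callers concerned about that can simply read the new tensor at index -1.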
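
A caller-side guard could look roughly like the sketch below. `fake_loss_forward` is a hypothetical stand-in that only mimics an assumed `(loss, metrics_tuple)` return shape with placeholder strings; it is not the Liger API, and `split_metrics` is likewise an illustrative helper.

```python
def fake_loss_forward(return_per_token_logps=False):
    """Hypothetical stand-in for the loss call: returns (loss, metrics)."""
    metrics = ["metric_a", "metric_b"]  # placeholders for the existing metrics
    if return_per_token_logps:
        # Placeholder for the [B, T] tensor appended as the last entry.
        metrics.append("per_token_logps")
    return 0.0, tuple(metrics)


def split_metrics(metrics, return_per_token_logps):
    """Separate the optional trailing entry from the pre-existing metrics."""
    if return_per_token_logps:
        return metrics[:-1], metrics[-1]
    return metrics, None


# Flag off: the metrics tuple is exactly what existing callers expect.
_, metrics_off = fake_loss_forward()
base_off, logps_off = split_metrics(metrics_off, False)

# Flag on: one extra entry, read from index -1.
_, metrics_on = fake_loss_forward(return_per_token_logps=True)
base_on, logps_on = split_metrics(metrics_on, True)
```

Keying the split on the flag the caller passed, instead of on the tuple length, keeps the guard valid if other metrics are added later.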
Tests
Numerical diff comparison of `probs_ratio` values against the ClippedPGLossFn path.
🤖 Generated with Claude Code