Skip to content

Commit

Permalink
Indexing ref_logprobs
Browse files Browse the repository at this point in the history
  • Loading branch information
Gaurav Pandey [email protected] committed Jan 29, 2025
1 parent 7a2df6a commit 617fae1
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion open_instruct/grpo_vllm_thread_ray_gtrl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1169,7 +1169,7 @@ def vllm_generate(
# grpo change: directly subtract KL in loss (add)

# kl loss should be computed without torch.no_grad()
kl1 = new_logprobs - ref_logprobs
kl1 = new_logprobs - ref_logprobs[micro_batch_inds]
kl2 = (kl1) ** 2 / 2
kl3 = (-kl1).exp() - 1 + kl1
if args.kl_estimator == "kl1":
Expand Down

0 comments on commit 617fae1

Please sign in to comment.