diff --git a/open_instruct/grpo_vllm_thread_ray_gtrl.py b/open_instruct/grpo_vllm_thread_ray_gtrl.py
index b6696773b..29066df52 100644
--- a/open_instruct/grpo_vllm_thread_ray_gtrl.py
+++ b/open_instruct/grpo_vllm_thread_ray_gtrl.py
@@ -1136,7 +1136,6 @@ def vllm_generate(
 
             reward_mean = mean_grouped_rewards
             reward_std = std_grouped_rewards
 
-            # print('training starts')
             # Do multiple epochs of training on on-policy data (PPO-style), with a fresh random shuffle in each epoch
             for epoch_idx in range(args.num_epochs):
@@ -1205,7 +1204,9 @@ def vllm_generate(
                         # prob_dist = torch.nn.functional.softmax(logits, dim=-1)
                         # entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1)
                         non_score_reward = -args.beta * kl
-                        non_score_reward_sum_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = non_score_reward.sum(1).mean()
+                        non_score_reward_sum_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
+                            non_score_reward.sum(1).mean()
+                        )
                         # print("step finished", self.rank, "micro batch start", micro_batch_start)
                         approxkl = 0.5 * (logprobs_diff**2).mean()
                         approxkl_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl
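
A minimal standalone sketch (illustrative only, not part of the patch) of the two statistics the second hunk touches. It assumes `kl` is the per-token log-probability difference between the current policy and the reference policy (the k1 estimator); the tensor shapes, the estimator choice, and the names `new_logprobs`/`ref_logprobs` are assumptions, not taken from the file.

    import torch

    beta = 0.05  # KL penalty coefficient, standing in for args.beta

    # Hypothetical per-token logprobs for a batch of 4 responses of length 16
    new_logprobs = torch.randn(4, 16)  # current policy
    ref_logprobs = torch.randn(4, 16)  # reference policy

    logprobs_diff = new_logprobs - ref_logprobs
    kl = logprobs_diff  # assumed k1 estimator of the per-token KL

    # KL penalty folded into the reward, as in `non_score_reward = -args.beta * kl`
    non_score_reward = -beta * kl
    # Statistic logged by the reformatted line: sum over tokens, mean over the batch
    non_score_reward_sum = non_score_reward.sum(1).mean()

    # Approximate KL used purely for logging, matching `approxkl = 0.5 * (logprobs_diff**2).mean()`
    approxkl = 0.5 * (logprobs_diff**2).mean()

    print(non_score_reward_sum.item(), approxkl.item())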