Commit: Address comments
hamishivi committed Jan 23, 2025
1 parent 0591bf8 commit b8abcad
Showing 1 changed file with 32 additions and 26 deletions.
58 changes: 32 additions & 26 deletions open_instruct/grpo_vllm_thread_ray_gtrl.py
@@ -1142,6 +1142,10 @@ def vllm_generate(
elif args.kl_estimator == "kl3":
kl = kl3

non_score_reward = -args.beta * kl
non_score_reward_sum = non_score_reward.sum(1)
rlhf_reward = scores + non_score_reward_sum

# print('training starts')
# Do multiple epochs of training on on-policy data (PPO-style), with a fresh random shuffle in each epoch
for epoch_idx in range(args.num_epochs):
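
For reference, a minimal, self-contained sketch of the reward shaping introduced in this hunk: the per-token KL penalty is scaled by the KL coefficient (args.beta in the script; a stand-in value is used below) and summed over the response before being added to the scalar scores. Tensor shapes and values are illustrative only.

import torch

beta = 0.05                                     # stand-in for args.beta
scores = torch.tensor([1.0, 0.0])               # per-sequence reward/verifier scores (toy values)
kl = torch.rand(2, 4) * 0.1                     # per-token KL estimate, shape (batch, response_len)

non_score_reward = -beta * kl                   # per-token KL penalty
non_score_reward_sum = non_score_reward.sum(1)  # sum over response tokens -> shape (batch,)
rlhf_reward = scores + non_score_reward_sum     # final per-sequence RLHF reward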
@@ -1213,19 +1217,21 @@ def vllm_generate(
local_metrics[1] = (responses == args.stop_token_id).sum().float().mean()
local_metrics[2] = kl.sum(1).mean()
local_metrics[3] = (-logprobs).sum(1).mean()
local_metrics[4] = scores.mean()
local_metrics[5] = approxkl_stats.mean()
local_metrics[6] = pg_clipfrac_stats.mean()
local_metrics[7] = pg_loss_stats.mean()
local_metrics[8] = reward_mean.mean()
local_metrics[9] = reward_std.mean()
local_metrics[10] = entropy_stats.mean()
local_metrics[11] = ratio_stats.mean()
local_metrics[12] = ratio_stats.var()
local_metrics[13] = ((kl) ** 2 / 2).sum(1).mean()
local_metrics[14] = ((-kl).exp() - 1 + kl).sum(1).mean()
local_metrics[15] = verifiable_correct_rate
local_metrics[16] = contain_stop_token.float().mean()
local_metrics[4] = non_score_reward_sum.mean()
local_metrics[5] = rlhf_reward.mean()
local_metrics[6] = scores.mean()
local_metrics[7] = approxkl_stats.mean()
local_metrics[8] = pg_clipfrac_stats.mean()
local_metrics[9] = pg_loss_stats.mean()
local_metrics[10] = reward_mean.mean()
local_metrics[11] = reward_std.mean()
local_metrics[12] = entropy_stats.mean()
local_metrics[13] = ratio_stats.mean()
local_metrics[14] = ratio_stats.var()
local_metrics[15] = ((kl) ** 2 / 2).sum(1).mean()
local_metrics[16] = ((-kl).exp() - 1 + kl).sum(1).mean()
local_metrics[17] = verifiable_correct_rate
local_metrics[18] = contain_stop_token.float().mean()
# global_metrics = accelerator.reduce(local_metrics, reduction="mean").tolist()
local_metrics /= dist.get_world_size()
dist.all_reduce(local_metrics, op=dist.ReduceOp.SUM)
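
The two lines above implement a cross-rank mean: each rank divides its metric vector by the world size, and an all-reduce SUM then accumulates the pre-scaled vectors. A minimal sketch of that pattern, assuming the torch.distributed process group has already been initialized elsewhere:

import torch
import torch.distributed as dist

def all_reduce_mean(local_metrics: torch.Tensor) -> torch.Tensor:
    # Pre-dividing by the world size and summing across ranks yields
    # the element-wise mean of local_metrics over all processes.
    local_metrics = local_metrics / dist.get_world_size()
    dist.all_reduce(local_metrics, op=dist.ReduceOp.SUM)
    return local_metrics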
@@ -1240,22 +1246,22 @@ def vllm_generate(
"val/sequence_lengths": global_metrics[0],
"val/num_stop_token_ids": global_metrics[1],
"objective/kl": global_metrics[2],
"objective/kl2": global_metrics[13],
"objective/kl3": global_metrics[14],
"objective/kl2": global_metrics[15],
"objective/kl3": global_metrics[16],
"objective/entropy": global_metrics[3],
"objective/non_score_reward": global_metrics[4],
"objective/rlhf_reward": global_metrics[5],
"objective/scores": global_metrics[4],
"policy/approxkl_avg": global_metrics[5],
"policy/clipfrac_avg": global_metrics[6],
"loss/policy_avg": global_metrics[7],
"objective/scores_mean": global_metrics[8],
"objective/reward_std": global_metrics[9],
"policy/entropy_avg": global_metrics[10],
"val/ratio": global_metrics[11],
"val/ratio_var": global_metrics[12],
"objective/verifiable_correct_rate": global_metrics[15],
"val/stop_token_rate": global_metrics[16],
"objective/scores": global_metrics[6],
"policy/approxkl_avg": global_metrics[7],
"policy/clipfrac_avg": global_metrics[8],
"loss/policy_avg": global_metrics[9],
"objective/scores_mean": global_metrics[10],
"objective/reward_std": global_metrics[11],
"policy/entropy_avg": global_metrics[12],
"val/ratio": global_metrics[13],
"val/ratio_var": global_metrics[14],
"objective/verifiable_correct_rate": global_metrics[17],
"val/stop_token_rate": global_metrics[18],
}
if accelerator.is_main_process:
print_rich_single_line_metrics(metrics)
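
For context, the ((kl) ** 2 / 2) and ((-kl).exp() - 1 + kl) expressions logged as objective/kl2 and objective/kl3 match the commonly used k2 and k3 KL estimators when kl holds the per-token log-probability ratio. A small sketch under that assumption (the name logr and the toy values are illustrative, not taken from the script):

import torch

logr = torch.randn(2, 4) * 0.1       # assumed per-token log-ratio: logprobs - ref_logprobs (toy values)

kl1 = logr                           # naive estimator of KL(policy || ref)
kl2 = logr ** 2 / 2                  # low-variance, biased estimator
kl3 = (-logr).exp() - 1 + logr       # non-negative, lower-variance estimator

print(kl1.sum(1).mean(), kl2.sum(1).mean(), kl3.sum(1).mean())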
