DS2 fix and additional logging
vwxyzjn committed Jan 29, 2025
1 parent 8313b38 commit abfe720
Showing 1 changed file with 18 additions and 7 deletions.
25 changes: 18 additions & 7 deletions open_instruct/grpo_vllm_thread_ray_gtrl.py
@@ -627,7 +627,9 @@ def from_pretrained(
# reference model
ds_config = get_eval_ds_config(
offload=False,
- stage=args.deepspeed_stage,
+ # inference model only has stage 3 (sharding) or stage 0 (no sharding)
+ # stage 2 is optimizer sharding which doesn't apply to inference
+ stage=args.deepspeed_stage if args.deepspeed_stage == 3 else 0,
bf16=True,
)
ds_config["train_micro_batch_size_per_gpu"] = args.per_device_train_batch_size
@@ -666,7 +668,9 @@ def from_pretrained(
disable_dropout_in_model(self.reward_model)
ds_config = get_eval_ds_config(
offload=False,
- stage=args.deepspeed_stage,
+ # inference model only has stage 3 (sharding) or stage 0 (no sharding)
+ # stage 2 is optimizer sharding which doesn't apply to inference
+ stage=args.deepspeed_stage if args.deepspeed_stage == 3 else 0,
bf16=True,
)
ds_config["train_micro_batch_size_per_gpu"] = args.per_device_train_batch_size
@@ -1163,12 +1167,17 @@ def vllm_generate(
logprobs_diff = new_logprobs - mb_logprobs
ratio = torch.exp(logprobs_diff)
pg_losses = -mb_advantage[:, None] * ratio
- pg_losses2 = -mb_advantage[:, None] * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange)
+ pg_losses2 = -mb_advantage[:, None] * torch.clamp(
+     ratio, 1.0 - args.cliprange, 1.0 + args.cliprange
+ )
pg_loss_max = torch.max(pg_losses, pg_losses2)
pg_loss = masked_mean(pg_loss_max, ~padding_mask[micro_batch_inds])
# grpo change: directly subtract KL in loss (add)

- # kl loss should be computed without torch.no_grad()
+ # recalculate kl: difference from PPO because we want the KL loss
+ # to backpropagate through the model
+ # TODO: we potentially need to clip this kl loss in case some logits
+ # exploded, but more on it later.
kl1 = new_logprobs - ref_logprobs[micro_batch_inds]
kl2 = (kl1) ** 2 / 2
kl3 = (-kl1).exp() - 1 + kl1
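The rewritten comment marks a real behavioral point: unlike PPO-style implementations that compute the reference KL under torch.no_grad() and fold it into the reward, GRPO adds beta * KL directly to the loss, so the estimator must stay on the autograd graph. kl1/kl2/kl3 are the standard k1/k2/k3 approximations of KL(pi_theta || pi_ref). A self-contained sketch, with [batch, seq_len] log-prob tensors assumed:

import torch

def kl_estimators(new_logprobs: torch.Tensor, ref_logprobs: torch.Tensor):
    # Per-token log-ratio; gradients must flow through new_logprobs because
    # beta * kl is added to the loss below (hence no torch.no_grad() here).
    log_ratio = new_logprobs - ref_logprobs
    kl1 = log_ratio                             # unbiased, high variance
    kl2 = log_ratio ** 2 / 2                    # low variance, biased
    kl3 = (-log_ratio).exp() - 1 + log_ratio    # low variance, unbiased, always >= 0
    return kl1, kl2, kl3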
@@ -1179,7 +1188,6 @@ def vllm_generate(
elif args.kl_estimator == "kl3":
kl = kl3

-
pg_loss = pg_loss + (args.beta * kl).mean()
loss = pg_loss
self.model.backward(loss)
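Putting the hunk together: the clipped surrogate is averaged over non-padding tokens, and the chosen KL estimator is added on top with weight args.beta. A condensed sketch of the whole micro-batch loss, assuming a simple masked_mean helper (hypothetical signature; the real open_instruct helper may differ):

import torch

def masked_mean(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Average only over positions where mask is True (non-padding tokens).
    return (values * mask).sum() / mask.sum()

def grpo_microbatch_loss(new_logprobs, old_logprobs, ref_logprobs,
                         advantages, response_mask, cliprange, beta):
    ratio = torch.exp(new_logprobs - old_logprobs)
    pg_losses = -advantages[:, None] * ratio
    pg_losses2 = -advantages[:, None] * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)
    pg_loss = masked_mean(torch.max(pg_losses, pg_losses2), response_mask)
    kl = new_logprobs - ref_logprobs           # k1 estimator; k2/k3 also possible
    return pg_loss + beta * kl.mean()          # GRPO: KL penalty added directly to the loss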
@@ -1197,8 +1205,7 @@ def vllm_generate(
# prob_dist = torch.nn.functional.softmax(logits, dim=-1)
# entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1)
non_score_reward = -args.beta * kl
- non_score_reward_sum = non_score_reward.sum(1).mean()
- non_score_reward_sum_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = non_score_reward_sum
+ non_score_reward_sum_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = non_score_reward.sum(1).mean()
# print("step finished", self.rank, "micro batch start", micro_batch_start)
approxkl = 0.5 * (logprobs_diff**2).mean()
approxkl_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl
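The logging change simply inlines the temporary: the "non-score reward" is the per-token KL penalty -beta * kl, summed over the sequence and averaged over the batch, while approxkl is the usual PPO diagnostic on the policy log-ratio. A small sketch of the two statistics, with [batch, seq_len] tensors assumed:

import torch

def kl_logging_stats(kl: torch.Tensor, logprobs_diff: torch.Tensor, beta: float):
    # KL-penalty contribution to the reward ("non-score reward"), per sequence,
    # averaged over the batch for logging.
    non_score_reward_sum = (-beta * kl).sum(dim=1).mean()
    # Approximate KL between the behavior and current policy,
    # 0.5 * E[(log pi_new - log pi_old)^2] (the k2-style estimator).
    approxkl = 0.5 * (logprobs_diff ** 2).mean()
    return non_score_reward_sum, approxkl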
@@ -1236,6 +1243,8 @@ def vllm_generate(
local_metrics[16] = ((-kl).exp() - 1 + kl).sum(1).mean()
local_metrics[17] = verifiable_correct_rate
local_metrics[18] = contain_stop_token.float().mean()
+ local_metrics[19] = sequence_lengths.float().min()
+ local_metrics[20] = sequence_lengths.float().max()
# global_metrics = accelerator.reduce(local_metrics, reduction="mean").tolist()
local_metrics /= dist.get_world_size()
dist.all_reduce(local_metrics, op=dist.ReduceOp.SUM)
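The two new slots feed the sequence-length min/max into the same reduction as every other metric: each rank divides its local metric vector by the world size, then an all-reduce with SUM turns every entry into a cross-rank mean. Note that for the new entries this yields the mean over ranks of each rank's local min/max, not a true global extreme. A sketch of the pattern, assuming torch.distributed is already initialized:

import torch
import torch.distributed as dist

def reduce_metrics_mean(local_metrics: torch.Tensor) -> list:
    # Dividing by world size and then summing across ranks is equivalent to
    # averaging every entry of the metric vector across ranks.
    local_metrics = local_metrics / dist.get_world_size()
    dist.all_reduce(local_metrics, op=dist.ReduceOp.SUM)
    return local_metrics.tolist()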
@@ -1248,6 +1257,8 @@ def vllm_generate(
"time/from_scratch": time.time() - start_time,
"time/training": time.time() - training_time_start,
"val/sequence_lengths": global_metrics[0],
"val/sequence_lengths_min": global_metrics[19],
"val/sequence_lengths_max": global_metrics[20],
"val/num_stop_token_ids": global_metrics[1],
"objective/kl": global_metrics[2],
"objective/kl2": global_metrics[15],
