GRPO questions #2608
Hey @natolambert!!

Here's how I understand it. In PPO, the KL penalty is folded into the reward: trl/trl/trainer/ppo_trainer.py, lines 509 to 515 (at 949db23).

In GRPO, the advantage is computed without the KL term. It's just the output of the reward function, normalised per group: trl/trl/trainer/grpo_trainer.py, lines 274 to 281 (at 949db23). Later, you subtract the KL term: trl/trl/trainer/grpo_trainer.py, line 285 (at 949db23).
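For context, a rough sketch of how that per-token KL is estimated, as I read the trainer code (shapes are made up for illustration): it uses the k3 estimator against the frozen reference model.

```python
import torch

# Sketch only (not copied verbatim from grpo_trainer.py, shapes are illustrative):
# the per-token KL against the reference model via the k3 estimator
#   kl = exp(ref_logp - logp) - (ref_logp - logp) - 1,
# which is always >= 0 and has low variance.
per_token_logps = torch.randn(4, 6)      # log-probs under the current policy
ref_per_token_logps = torch.randn(4, 6)  # log-probs under the reference model
per_token_kl = (
    torch.exp(ref_per_token_logps - per_token_logps)
    - (ref_per_token_logps - per_token_logps)
    - 1
)
```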
I don't know how complicated it would have been to integrate KL into the reward. It would probably look something like:

```python
# Subtract KL
rewards = rewards - self.beta * per_token_kl  # <- THIS IS ADDED

# Compute grouped-wise rewards
mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)

# Normalize the rewards to compute the advantages
mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)

# x - x.detach() allows for preserving gradients from x
advantages = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
per_token_loss = -advantages  # <- THIS IS MODIFIED
```

which seems pretty simple. But perhaps they meant that the subsequent equations and calculations would have been more complicated.
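As a quick aside on the `x - x.detach()` trick used above, here is a toy check with a made-up tensor: the forward value is exactly 1, but the gradient with respect to `x` is preserved.

```python
import torch

# exp(x - x.detach()) evaluates to 1 in the forward pass, while the gradient
# d/dx exp(x - x.detach()) = exp(0) * 1 = 1 still flows to x.
x = torch.tensor([0.3, -1.2], requires_grad=True)
y = torch.exp(x - x.detach())
print(y)        # tensor([1., 1.], grad_fn=...)
y.sum().backward()
print(x.grad)   # tensor([1., 1.])
```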
No. What is the underlying intuition? Something like this?

```python
loss = ((per_token_loss * completion_mask).sum(dim=1)).mean()
```
In the current implementation, we just update once after each generation. In fact, we align with the paper, which states that the policy is only updated once following each generation phase. Since the old policy and the current policy therefore coincide, the ratio $\pi_\theta / \pi_{\theta_\text{old}}$ is always 1, and the per-token objective

$$\min\left(\frac{\pi_\theta(o_{i,t})}{\pi_{\theta_\text{old}}(o_{i,t})}\,\hat{A}_{i,t},\ \operatorname{clip}\left(\frac{\pi_\theta(o_{i,t})}{\pi_{\theta_\text{old}}(o_{i,t})},\,1-\epsilon,\,1+\epsilon\right)\hat{A}_{i,t}\right) - \beta\,\mathbb{D}_{\mathrm{KL}}\left[\pi_\theta \,\|\, \pi_{\mathrm{ref}}\right]$$

can be simplified to

$$\hat{A}_{i,t} - \beta\,\mathbb{D}_{\mathrm{KL}}\left[\pi_\theta \,\|\, \pi_{\mathrm{ref}}\right].$$

But we could support multiple updates after each generation, and that would require having this PPO clipping logic. It would probably allow reusing generations and be more sample-efficient. On the other hand, it would probably require a rather hard-to-read implementation, since the optimization step is performed in the parent trainer class. It would look something like:

```python
def __init__(self, ...):
    ...
    # [prompt0, prompt1] -> [prompt0, prompt0, prompt0, prompt1, prompt1, prompt1]
    self.train_dataset = repeat_interleave(train_dataset, self.num_grpo_iterations)

def compute_loss(self, model, inputs):
    if self.step % self.num_grpo_iterations == 0:  # self.num_grpo_iterations is 𝜇 in the paper
        completions = model.generate(prompts)
        self.old_log_probs = model(cat(prompts, completions))
    log_probs = model(cat(prompts, completions))
    log_ratio = log_probs - self.old_log_probs
    losses = min(exp(log_ratio) * advantages, clip(exp(log_ratio), 1 - epsilon, 1 + epsilon) * advantages)
    losses = losses - beta * kl
```

Not sure if it's worth having this extra complexity.
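Fleshed out with concrete tensors (purely illustrative names and shapes, not the trainer's actual variables), the clipped surrogate from that pseudocode would look roughly like this:

```python
import torch

# Illustrative only: one advantage, one log-ratio and one KL value per token.
epsilon, beta = 0.2, 0.04
advantages = torch.randn(4, 6)        # per-token advantages
log_ratio = 0.1 * torch.randn(4, 6)   # log pi_theta - log pi_theta_old
per_token_kl = torch.rand(4, 6)       # per-token KL penalty

ratio = torch.exp(log_ratio)
unclipped = ratio * advantages
clipped = torch.clamp(ratio, 1 - epsilon, 1 + epsilon) * advantages
per_token_loss = -(torch.min(unclipped, clipped) - beta * per_token_kl)
```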
I think there is a miscommunication.

```python
# Normalize the rewards to compute the advantages
mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)

# x - x.detach() allows for preserving gradients from x
advantages = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
per_token_loss = -(advantages - self.beta * per_token_kl)
loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
```

Here, and for the second question, I think we should average both per group and per token, following the paper.

P.S. The reason there is no clipping makes sense: the current implementation does not allow iteratively updating the policy with trajectories sampled from the old policy, so the ratio will always be 1.
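To illustrate the reduction being discussed, here is a toy example with made-up numbers: summing per-token losses without dividing by completion length over-weights long completions, whereas dividing by length first gives every completion equal weight.

```python
import torch

# Two completions, the second one masked down to length 2.
per_token_loss = torch.ones(2, 6)
completion_mask = torch.tensor([[1, 1, 1, 1, 1, 1],
                                [1, 1, 0, 0, 0, 0]], dtype=torch.float32)

sum_then_mean = (per_token_loss * completion_mask).sum(dim=1).mean()
mean_then_mean = ((per_token_loss * completion_mask).sum(dim=1)
                  / completion_mask.sum(dim=1)).mean()
print(sum_then_mean)   # tensor(4.)  -- long completion dominates
print(mean_then_mean)  # tensor(1.)  -- each completion weighted equally
```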
So this instead (just a renaming)?

```python
# x - x.detach() allows for preserving gradients from x
per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
per_token_loss = -(per_token_loss - self.beta * per_token_kl)
loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
```

Maybe more aligned with the formulation.
lgtm :)
Yeah, I'm mostly aligned now (I haven't fully checked the math), but it seems like, practically, you can get away with it because you don't do minibatches? Do you have a line-by-line derivation (lol, I will ask Claude)? 👀 Thanks @SeunghyunSEO for the snipe on the variable names, that's what caught me up. Looks fine now, and thanks for updating it.

EDIT: okay, I see, if they are the same then the min and clip become redundant (clipped to 1, and the min of two identical quantities), so it simplifies to what you have.
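A quick toy check of that EDIT (made-up numbers): with a ratio of exactly 1, the clip is a no-op and the min of two identical quantities is just the advantages themselves.

```python
import torch

eps = 0.2
advantages = torch.tensor([0.5, -1.3, 2.0])
ratio = torch.ones_like(advantages)
# clip(1, 1-eps, 1+eps) == 1, so both arguments of min are identical.
surrogate = torch.min(ratio * advantages, torch.clamp(ratio, 1 - eps, 1 + eps) * advantages)
assert torch.equal(surrogate, advantages)
```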
For example, you can see our implementation in open-instruct that we just added, too: allenai/open-instruct#523
Hey friends! I have some questions on the GRPO implementation, happy to discuss.
trl/trl/trainer/grpo_trainer.py, line 286 (at fe4b5ef)