
GRPO questions #2608

Open
natolambert opened this issue Jan 22, 2025 · 9 comments
Labels
🏋 GRPO Related to GRPO ❓ question Seeking clarification or more information

Comments

@natolambert
Contributor

Hey friends! I have some questions on the GRPO implementation, happy to discuss.

  1. It looks like you apply the KL distance in the advantages, while the DeepSeekMath paper says “Also note that, instead of adding KL penalty in the reward, GRPO regularizes by directly adding the KL divergence between the trained policy and the reference policy to the loss, avoiding complicating the calculation of Â”
  2. Did any thought go into making this a sum of the loss rather than a mean? We aren't sure about this line:
    loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
  3. I didn’t see the PPO clipping logic in the policy gradient loss, coming soon?
@github-actions github-actions bot added 🏋 PPO Related to PPO ❓ question Seeking clarification or more information labels Jan 22, 2025
@August-murr August-murr added 🏋 GRPO Related to GRPO and removed 🏋 PPO Related to PPO labels Jan 23, 2025
@qgallouedec
Member

qgallouedec commented Jan 23, 2025

Hey @natolambert!!

  1. It looks like you apply the KL distance in the advantages, while the DeepSeekMath paper says “Also note that, instead of adding KL penalty in the reward, GRPO regularizes by directly adding the KL divergence between the trained policy and the reference policy to the loss, avoiding complicating the calculation of Â”

Here's how I understand it:
In PPO, the KL div is subtracted from the reward (adding KL penalty in the reward)

# 4. compute rewards
kl = logprobs - ref_logprobs
non_score_reward = -args.kl_coef * kl
rewards = non_score_reward.clone()
actual_start = torch.arange(rewards.size(0), device=rewards.device)
actual_end = torch.where(sequence_lengths_p1 < rewards.size(1), sequence_lengths_p1, sequence_lengths)
rewards[[actual_start, actual_end]] += scores

Here, in GRPO, the advantage is computed without the KL term. It's just the output of the reward function, normalised per group:

# Compute grouped-wise rewards
mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
# Normalize the rewards to compute the advantages
mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)

and later you subtract the KL term

per_token_loss = -(advantages - self.beta * per_token_kl)

I don't know how complicated it would have been to integrate KL into the reward. It would probably look something like

# Subtract KL
rewards  = rewards - self.beta * per_token_kl  # <- THIS IS ADDED

# Compute grouped-wise rewards
mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)

# Normalize the rewards to compute the advantages
mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)

# x - x.detach() allows for preserving gradients from x
advantages = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
per_token_loss = -advantages  # <- THIS IS MODIFIED

which seems pretty simple. But perhaps they mean that the subsequent equations and calculations would have been more complicated.
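For reference, my reading of the paper's outcome-supervision setting is that the advantage is just the group-normalized reward, and the "KL in the reward" variant sketched above would fold a KL term into $r$ before this normalization:

$$\hat{A}_{i,t} = \tilde{r}_i = \frac{r_i - \operatorname{mean}(\{r_1, \dots, r_G\})}{\operatorname{std}(\{r_1, \dots, r_G\})}$$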

@qgallouedec
Member

qgallouedec commented Jan 23, 2025

  2. Did any thought go into making this a sum of the loss rather than a mean?

No. What is the underlying intuition? Something like this?

loss = ((per_token_loss * completion_mask).sum(dim=1)).mean() 
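For what it's worth, a tiny toy comparison of the two aggregations (made-up tensors, not trainer code): the current per-sequence mean gives every completion the same weight regardless of length, while the sum variant weights longer completions more.

import torch

# Toy example: 2 completions with lengths 2 and 4, uniform per-token loss of 1.0
per_token_loss = torch.ones(2, 4)
completion_mask = torch.tensor([[1, 1, 0, 0],
                                [1, 1, 1, 1]], dtype=torch.float32)

# Current implementation: per-sequence mean, then batch mean -> 1.0 for both lengths
mean_agg = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()

# Sum variant: longer completions contribute more tokens -> 3.0 here
sum_agg = ((per_token_loss * completion_mask).sum(dim=1)).mean()

print(mean_agg.item(), sum_agg.item())  # 1.0 3.0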

@qgallouedec
Member

qgallouedec commented Jan 23, 2025

  3. I didn’t see the PPO clipping logic in the policy gradient loss, coming soon?

In the current implementation, we just update once after each generation. In fact, we align with this sentence from the paper:

The policy model only has a single update following each exploration stage.

It therefore implies that $\pi_{\theta_{\text{old}}} = \pi_\theta$, and the equation

$$\mathcal{J}_{\text{GRPO}}(\theta) =\frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|}\sum_{t=1}^{|o_i|}\left[\min \left(\frac{\pi_\theta(o_{i,t} | q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} | q, o_{i,< t})} \hat{A}_{i,t}, \text{clip}\left(\frac{\pi_\theta(o_{i,t} | q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} | q, o_{i,< t})}, 1 - \epsilon, 1 + \epsilon\right) \hat{A}_{i,t}\right) - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right]\right]$$

can be simplified to

$$\mathcal{J}_{\text{GRPO}}(\theta) = \frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} \left[ \frac{\pi_\theta(o_{i,t} | q, o_{i,< t})}{\left[\pi_\theta(o_{i,t} | q, o_{i,< t})\right]_\cancel{\nabla}} \hat{A}_{i,t} - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right] \right].$$

But we could support multiple updates after each generation, and that would require this PPO clipping logic. It would probably allow reusing generations and being more sample-efficient. On the other hand, it would probably require a rather hard-to-read implementation, since the optimization step is performed in the parent trainer class. It would look something like:

def __init__(self, ...):
    ...
    # [prompt0, prompt1] -> [prompt0, prompt0, prompt0, prompt1, prompt1, prompt1]
    self.train_dataset = repeat_interleave(train_dataset, self.num_grpo_iterations)

def compute_loss(self, model, inputs):
    if self.step % self.num_grpo_iterations == 0:  # self.num_grpo_iterations is 𝜇 in the paper
        completions = model.generate(prompts)
        self.old_log_probs = model(cat(prompts, completions))

    log_probs = model(cat(prompts, completions))
    log_ratio = log_probs - self.old_log_probs
    ratio = torch.exp(log_ratio)
    losses = torch.min(ratio * advantages, torch.clamp(ratio, 1 - epsilon, 1 + epsilon) * advantages)
    losses = losses - beta * kl

not sure if it's worth having this extra complexity.
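If it helps, here is a minimal, self-contained sketch of just the clipped token loss (hypothetical tensor names and a free function, not the trainer's actual API):

import torch

def clipped_grpo_token_loss(per_token_logps, old_per_token_logps, per_token_kl, advantages, beta=0.04, epsilon=0.2):
    # Probability ratio between the current policy and the sampling (old) policy, per token
    ratio = torch.exp(per_token_logps - old_per_token_logps)
    # Broadcast the per-sequence advantage over the token dimension
    adv = advantages.unsqueeze(1)
    # PPO-style clipped surrogate
    surrogate = torch.min(ratio * adv, torch.clamp(ratio, 1 - epsilon, 1 + epsilon) * adv)
    # KL regularizer added directly to the loss, as in the single-update case
    return -(surrogate - beta * per_token_kl)

The masking and per-sequence averaging would presumably stay the same as in the current compute_loss.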

@SeunghyunSEO

SeunghyunSEO commented Jan 23, 2025

[Quotes @qgallouedec's first reply above in full.]

I think there is a miscommunication. @natolambert pointed out that the KL term should not be folded into the advantage term, following the original GRPO formulation, and the trl implementation already follows the paper well. But the variable name makes readers confused.

# Normalize the rewards to compute the advantages
mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)

# x - x.detach() allows for preserving gradients from x
advantages = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
per_token_loss = -(advantages - self.beta * per_token_kl)
loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()

Here, per_token_loss in the last line means per_token_prob * A - kld, not A - kld, so the current implementation is right.
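In equation form (my notation, matching the simplified objective above), that last line computes

$$\ell_{i,t} = -\left( \frac{\pi_\theta(o_{i,t} | q, o_{i,< t})}{\left[\pi_\theta(o_{i,t} | q, o_{i,< t})\right]_{\cancel{\nabla}}} \hat{A}_{i,t} - \beta\, \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right] \right),$$

with the first factor equal to 1 in value but carrying the policy gradient.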

And for the second question, I think we should take the mean over both the group dimension and the token dimension, following the paper, so the current implementation LGTM again.

[image: the GRPO objective from the paper]

P.S. The absence of clipping also makes sense: the current implementation does not iteratively update the policy on trajectories sampled from an old policy, so the probability ratio is always 1 (the log ratio is always 0).
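A tiny standalone check of that last point (toy values, nothing from the trainer): the ratio against a detached copy of the same log-probs is exactly 1 in the forward pass, but it still carries the policy gradient.

import torch

per_token_logps = torch.tensor([-1.2, -0.7, -2.3], requires_grad=True)
advantages = torch.tensor([0.5, 0.5, 0.5])

# Ratio against a frozen copy of the current policy: exactly 1 in the forward pass
ratio = torch.exp(per_token_logps - per_token_logps.detach())
print(ratio)  # tensor([1., 1., 1.], grad_fn=...)

# ...but the backward pass still sees d(ratio * A)/d(logp) = A
loss = -(ratio * advantages).sum()
loss.backward()
print(per_token_logps.grad)  # tensor([-0.5000, -0.5000, -0.5000])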

@qgallouedec
Member

qgallouedec commented Jan 23, 2025

So this instead (just a renaming)?:

# x - x.detach() allows for preserving gradients from x
per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
per_token_loss = -(per_token_loss - self.beta * per_token_kl)
loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()

Maybe more aligned with the formulation

directly adding the KL divergence between the trained policy and the reference policy to the loss

@SeunghyunSEO

[Quotes @qgallouedec's suggested renaming above.]

lgtm :)

@qgallouedec
Member

#2616

@natolambert
Contributor Author

natolambert commented Jan 23, 2025

Yeah, I'm mostly aligned now (I haven't fully checked the math), but it seems like, practically, it's because you don't do minibatches that you can get away with it? Do you have a line-by-line derivation (lol, I will ask Claude)? 👀

Thanks @SeunghyunSEO for the snipe on the variable names, that's what caught me up. Looks fine now and thanks for updating it.

EDIT: okay, I see. If they are the same, then the min and clip become redundant (the ratio is clipped to 1, and the min is over two identical quantities), so it simplifies to what you have.
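Spelling out that substitution in the paper's notation: with a single update per generation the ratio is exactly 1, so

$$\min\left(1 \cdot \hat{A}_{i,t},\ \text{clip}\left(1, 1 - \epsilon, 1 + \epsilon\right) \hat{A}_{i,t}\right) = \min\left(\hat{A}_{i,t}, \hat{A}_{i,t}\right) = \hat{A}_{i,t},$$

which is why only the straight-through ratio and the KL term remain in the simplified objective above.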

@natolambert
Contributor Author

For example, you can also see our implementation in Open Instruct, which we just added: allenai/open-instruct#523
