Question about p_sample :) #135

Open · tryumanshow opened this issue Dec 6, 2022 · 2 comments

tryumanshow commented Dec 6, 2022

Hi!
Thank you for the great code you always provide!

Anyway, there are two points that I can't understand.

  1. Why do you use the posterior mean and variance in the reverse step (in the p_sample function)?
    I expected equation (11) of the original DDPM paper to be used, but I don't see it in this code.
    Can you explain this? :)
    def p_mean_variance(self, x, t, x_self_cond = None, clip_denoised = True):
        preds = self.model_predictions(x, t, x_self_cond)
        x_start = preds.pred_x_start

        if clip_denoised:
            x_start.clamp_(-1., 1.)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_start, x_t = x, t = t)
        return model_mean, posterior_variance, posterior_log_variance, x_start

    @torch.no_grad()
    def p_sample(self, x, t: int, x_self_cond = None, clip_denoised = True):
        b, *_, device = *x.shape, x.device # b: 4
        batched_times = torch.full((x.shape[0],), t, device = x.device, dtype = torch.long)
        model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, # Isotropic Gaussian
                                                                          t = batched_times, 
                                                                          x_self_cond = x_self_cond, 
                                                                          clip_denoised = clip_denoised)  #  <------ This part!
        noise = torch.randn_like(x) if t > 0 else 0. # no noise if t == 0
        pred_img = model_mean + (0.5 * model_log_variance).exp() * noise
        return pred_img, x_start
  2. Can you explain the intent of `(0.5 * model_log_variance).exp()` in `pred_img = model_mean + (0.5 * model_log_variance).exp() * noise` in p_sample?

The full code is as follows:

    @torch.no_grad()
    def p_sample(self, x, t: int, x_self_cond = None, clip_denoised = True):
        b, *_, device = *x.shape, x.device 
        batched_times = torch.full((x.shape[0],), t, device = x.device, dtype = torch.long)
        model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, # Isotropic Gaussian
                                                                          t = batched_times, 
                                                                          x_self_cond = x_self_cond, 
                                                                          clip_denoised = clip_denoised)
        noise = torch.randn_like(x) if t > 0 else 0. # no noise if t == 0
        pred_img = model_mean + (0.5 * model_log_variance).exp() * noise  # <------------ This part!
        return pred_img, x_start

lhaippp (Contributor) commented Dec 16, 2022

I have the same question.

In my mind, posterior_variance is the one that we need,
as in `pred_img = x_start + posterior_variance * noise`.

robert-graf commented

1: It is step 4 of the sampling algorithm (Algorithm 2), but using the equivalence between equations 9 and 7.
In equation 9, the second argument of mean_tilde_t( ..., HERE ) is x_0 recovered from x_t and the predicted noise; the code computes that x_0 estimate and then plugs it into mean_tilde_t via equation 7 (the posterior mean), instead of using the collapsed form in equation 11.
I tested equation 11 and ran into stability issues; I think small numbers cause some floating point errors.
Computing an explicit x_0 estimate also enables clamping the image.
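
Roughly, as a sketch (not the repo's exact code; the linear beta schedule and tensor names here are my own assumptions), the two routes give the same mean, but only the second one lets you clamp the image first:

    import torch

    # standard DDPM linear beta schedule, assumed here only for illustration
    T = 1000
    betas = torch.linspace(1e-4, 0.02, T)                      # beta_t
    alphas = 1. - betas                                        # alpha_t
    alphas_cumprod = torch.cumprod(alphas, dim = 0)            # alpha_bar_t
    alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]])

    t = 500
    x_t = torch.randn(4, 3, 32, 32)
    eps_pred = torch.randn_like(x_t)                           # stand-in for the network's noise prediction

    # Route A: equation (11), mean computed directly from the predicted noise
    mean_eq11 = (x_t - betas[t] / (1. - alphas_cumprod[t]).sqrt() * eps_pred) / alphas[t].sqrt()

    # Route B: equation (9) inner term -> x_0 estimate, then equation (7) posterior mean
    x0_pred = (x_t - (1. - alphas_cumprod[t]).sqrt() * eps_pred) / alphas_cumprod[t].sqrt()
    # x0_pred.clamp_(-1., 1.)                                  # <- the clamp that route A cannot apply
    mean_eq7 = (alphas_cumprod_prev[t].sqrt() * betas[t] / (1. - alphas_cumprod[t])) * x0_pred \
             + (alphas[t].sqrt() * (1. - alphas_cumprod_prev[t]) / (1. - alphas_cumprod[t])) * x_t

    print(torch.allclose(mean_eq11, mean_eq7, atol = 1e-4))    # True (with the clamp commented out)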

2: There is a comment somewhere, "# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain". I can't remember whether that comment is from this repository. This is probably again a numerical fix: the first posterior variance is exactly 0, so taking its log directly would give -inf.
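
Something like this is the usual fix (names are my own, not necessarily this repository's code):

    import torch

    T = 1000
    betas = torch.linspace(1e-4, 0.02, T)
    alphas_cumprod = torch.cumprod(1. - betas, dim = 0)
    alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]])

    # beta_tilde_t from equation (7); its first entry is exactly 0 because alpha_bar_0 = 1
    posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
    posterior_log_variance_clipped = torch.log(posterior_variance.clamp(min = 1e-20))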

Note: sqrt(posterior_variance) == exp(0.5 * log(posterior_variance))
posterior_variance stores the square of σ, so (0.5 * model_log_variance).exp() is just σ_t from step 4 of the sampling algorithm
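
A tiny self-contained check of that identity:

    import torch

    var = torch.rand(8) * 0.02 + 1e-6         # some positive posterior variances (sigma^2)
    sigma_a = var.sqrt()                       # sqrt(posterior_variance)
    sigma_b = (0.5 * var.log()).exp()          # exp(0.5 * log(posterior_variance)), as in p_sample
    print(torch.allclose(sigma_a, sigma_b))    # True: both are sigma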
