
Commit 6e24cd8

delmalih and stevhliu authored
Improve docstrings and type hints in scheduling_ddim_parallel.py (#13023)
* docs: improve docstring scheduling_ddim_parallel.py
* docs: improve docstring scheduling_ddim_parallel.py
* Update src/diffusers/schedulers/scheduling_ddim_parallel.py
  Co-authored-by: Steven Liu <[email protected]>
* Update src/diffusers/schedulers/scheduling_ddim_parallel.py
  Co-authored-by: Steven Liu <[email protected]>
* Update src/diffusers/schedulers/scheduling_ddim_parallel.py
  Co-authored-by: Steven Liu <[email protected]>
* Update src/diffusers/schedulers/scheduling_ddim_parallel.py
  Co-authored-by: Steven Liu <[email protected]>
* fix style

---------

Co-authored-by: Steven Liu <[email protected]>
1 parent 981eb80 commit 6e24cd8

File tree

1 file changed: +19 -14 lines changed


src/diffusers/schedulers/scheduling_ddim_parallel.py

Lines changed: 19 additions & 14 deletions
@@ -101,7 +101,7 @@ def alpha_bar_fn(t):
 
 
 # Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
-def rescale_zero_terminal_snr(betas):
+def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
     """
     Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
 
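For reference, zero terminal SNR rescaling (Algorithm 1 in the linked paper) shifts and rescales the cumulative alpha schedule so that the final timestep carries no signal. A minimal sketch of that algorithm follows; the function name is illustrative and the body is not necessarily identical to the library function being annotated here.

import torch

def rescale_zero_terminal_snr_sketch(betas: torch.Tensor) -> torch.Tensor:
    # Work in sqrt(alpha_bar) space, as Algorithm 1 does.
    alphas = 1.0 - betas
    alphas_bar_sqrt = torch.cumprod(alphas, dim=0).sqrt()

    first, last = alphas_bar_sqrt[0].clone(), alphas_bar_sqrt[-1].clone()
    # Shift so the terminal value is exactly zero, then rescale so the first value is unchanged.
    alphas_bar_sqrt = (alphas_bar_sqrt - last) * first / (first - last)

    # Convert back from cumulative products to per-step betas.
    alphas_bar = alphas_bar_sqrt**2
    alphas = torch.cat([alphas_bar[:1], alphas_bar[1:] / alphas_bar[:-1]])
    return 1.0 - alphas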
@@ -266,7 +266,7 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None
         """
         return sample
 
-    def _get_variance(self, timestep, prev_timestep=None):
+    def _get_variance(self, timestep: int, prev_timestep: Optional[int] = None) -> torch.Tensor:
         if prev_timestep is None:
             prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
 

@@ -279,7 +279,7 @@ def _get_variance(self, timestep, prev_timestep=None):
 
         return variance
 
-    def _batch_get_variance(self, t, prev_t):
+    def _batch_get_variance(self, t: torch.Tensor, prev_t: torch.Tensor) -> torch.Tensor:
         alpha_prod_t = self.alphas_cumprod[t]
         alpha_prod_t_prev = self.alphas_cumprod[torch.clip(prev_t, min=0)]
         alpha_prod_t_prev[prev_t < 0] = torch.tensor(1.0)
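Both `_get_variance` and `_batch_get_variance` compute the same DDIM posterior variance; the batched variant simply indexes `alphas_cumprod` with tensors of timesteps and treats negative "previous" timesteps as alpha_bar = 1. A hedged sketch of the shared formula, written as a free function with illustrative names:

import torch

def ddim_variance_sketch(alphas_cumprod: torch.Tensor, t: torch.Tensor, prev_t: torch.Tensor) -> torch.Tensor:
    alpha_prod_t = alphas_cumprod[t]
    alpha_prod_t_prev = alphas_cumprod[torch.clip(prev_t, min=0)]
    alpha_prod_t_prev[prev_t < 0] = torch.tensor(1.0)  # before the first step, alpha_bar is 1

    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev
    # DDIM posterior variance: (1 - alpha_bar_prev) / (1 - alpha_bar) * (1 - alpha_bar / alpha_bar_prev)
    return (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)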
@@ -335,7 +335,7 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
         return sample
 
     # Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler.set_timesteps
-    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None) -> None:
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
 
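The `set_timesteps` change is annotation-only, so typical usage is unchanged. An illustrative call, assuming `DDIMParallelScheduler` is importable from the top-level `diffusers` package and using a standard 1000-step training schedule:

from diffusers import DDIMParallelScheduler

scheduler = DDIMParallelScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(num_inference_steps=50, device="cpu")
print(scheduler.timesteps[:5])  # descending discrete timesteps, e.g. tensor([980, 960, 940, 920, 900])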
@@ -392,7 +392,7 @@ def step(
         sample: torch.Tensor,
         eta: float = 0.0,
         use_clipped_model_output: bool = False,
-        generator=None,
+        generator: Optional[torch.Generator] = None,
         variance_noise: Optional[torch.Tensor] = None,
         return_dict: bool = True,
     ) -> Union[DDIMParallelSchedulerOutput, Tuple]:
@@ -406,11 +406,13 @@ def step(
             sample (`torch.Tensor`):
                 current instance of sample being created by diffusion process.
             eta (`float`): weight of noise for added noise in diffusion step.
-            use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
-                predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when
-                `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would
-                coincide with the one provided as input and `use_clipped_model_output` will have not effect.
-            generator: random number generator.
+            use_clipped_model_output (`bool`, defaults to `False`):
+                If `True`, compute "corrected" `model_output` from the clipped predicted original sample. This
+                correction is necessary because the predicted original sample is clipped to [-1, 1] when
+                `self.config.clip_sample` is `True`. If no clipping occurred, the "corrected" `model_output` matches
+                the input and `use_clipped_model_output` has no effect.
+            generator (`torch.Generator`, *optional*):
+                Random number generator.
             variance_noise (`torch.Tensor`): instead of generating noise for the variance using `generator`, we
                 can directly provide the noise for the variance itself. This is useful for methods such as
                 CycleDiffusion. (https://huggingface.co/papers/2210.05559)
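The new `generator` annotation reflects how the argument is normally supplied when `eta > 0` or when variance noise is drawn inside `step`. A hedged sketch of a denoising loop around `step`; the `unet` and `latents` names are placeholders, not part of this file:

import torch

generator = torch.Generator(device="cpu").manual_seed(0)
for t in scheduler.timesteps:
    noise_pred = unet(latents, t).sample        # placeholder model call
    out = scheduler.step(noise_pred, t, latents, eta=0.0, generator=generator)
    latents = out.prev_sample                   # DDIMParallelSchedulerOutput.prev_sample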
@@ -496,7 +498,10 @@ def step(
 
         if variance_noise is None:
             variance_noise = randn_tensor(
-                model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
+                model_output.shape,
+                generator=generator,
+                device=model_output.device,
+                dtype=model_output.dtype,
             )
             variance = std_dev_t * variance_noise
 

@@ -513,7 +518,7 @@ def step(
     def batch_step_no_noise(
         self,
         model_output: torch.Tensor,
-        timesteps: List[int],
+        timesteps: torch.Tensor,
         sample: torch.Tensor,
         eta: float = 0.0,
         use_clipped_model_output: bool = False,
@@ -528,7 +533,7 @@ def batch_step_no_noise(
 
         Args:
             model_output (`torch.Tensor`): direct output from learned diffusion model.
-            timesteps (`List[int]`):
+            timesteps (`torch.Tensor`):
                 current discrete timesteps in the diffusion chain. This is now a list of integers.
             sample (`torch.Tensor`):
                 current instance of sample being created by diffusion process.
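Annotating `timesteps` as `torch.Tensor` matches how the parallel sampler drives this method: one model output and one sample per timestep in a single batched call. A hedged sketch, with tensor shapes chosen purely for illustration:

import torch

timesteps = torch.tensor([980, 960, 940, 920])   # (num_parallel,)
model_outputs = torch.randn(4, 4, 64, 64)        # stacked epsilon predictions
samples = torch.randn(4, 4, 64, 64)              # corresponding latents
prev_samples = scheduler.batch_step_no_noise(model_outputs, timesteps, samples, eta=0.0)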
@@ -696,5 +701,5 @@ def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: tor
         velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
         return velocity
 
-    def __len__(self):
+    def __len__(self) -> int:
         return self.config.num_train_timesteps
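`__len__` simply reports the length of the training schedule, hence the new `int` annotation. For example, with the scheduler configured as in the earlier sketch:

assert len(scheduler) == scheduler.config.num_train_timesteps == 1000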
