Skip to content

Commit a7cb14e

Browse files
delmalih and stevhliu
authored
Improve docstrings and type hints in scheduling_ddpm_parallel.py (#13027)
* docs: improve docstring scheduling_ddpm_parallel.py * Update scheduling_ddpm_parallel.py Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --------- Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
1 parent e8e88ff commit a7cb14e

File tree

4 files changed

+52
-39
lines changed

4 files changed

+52
-39
lines changed

src/diffusers/schedulers/scheduling_ddpm.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,7 @@ def set_timesteps(
281281
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
282282
283283
Args:
284-
num_inference_steps (`int`):
284+
num_inference_steps (`int`, *optional*):
285285
The number of diffusion steps used when generating samples with a pre-trained model. If used,
286286
`timesteps` must be `None`.
287287
device (`str` or `torch.device`, *optional*):
@@ -646,7 +646,7 @@ def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: tor
646646
def __len__(self) -> int:
647647
return self.config.num_train_timesteps
648648

649-
def previous_timestep(self, timestep: int) -> int:
649+
def previous_timestep(self, timestep: int) -> Union[int, torch.Tensor]:
650650
"""
651651
Compute the previous timestep in the diffusion chain.
652652
@@ -655,7 +655,7 @@ def previous_timestep(self, timestep: int) -> int:
655655
The current timestep.
656656
657657
Returns:
658-
`int`:
658+
`int` or `torch.Tensor`:
659659
The previous timestep.
660660
"""
661661
if self.custom_timesteps or self.num_inference_steps:

src/diffusers/schedulers/scheduling_ddpm_parallel.py

Lines changed: 47 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -149,38 +149,41 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
149149
For more details, see the original paper: https://huggingface.co/papers/2006.11239
150150
151151
Args:
152-
num_train_timesteps (`int`): number of diffusion steps used to train the model.
153-
beta_start (`float`): the starting `beta` value of inference.
154-
beta_end (`float`): the final `beta` value.
155-
beta_schedule (`str`):
156-
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
152+
num_train_timesteps (`int`, defaults to 1000):
153+
The number of diffusion steps to train the model.
154+
beta_start (`float`, defaults to 0.0001):
155+
The starting `beta` value of inference.
156+
beta_end (`float`, defaults to 0.02):
157+
The final `beta` value.
158+
beta_schedule (`str`, defaults to `"linear"`):
159+
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
157160
`linear`, `scaled_linear`, `squaredcos_cap_v2` or `sigmoid`.
158-
trained_betas (`np.ndarray`, optional):
159-
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
160-
variance_type (`str`):
161-
options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
161+
trained_betas (`np.ndarray`, *optional*):
162+
Option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
163+
variance_type (`str`, defaults to `"fixed_small"`):
164+
Options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
162165
`fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
163-
clip_sample (`bool`, default `True`):
164-
option to clip predicted sample for numerical stability.
165-
clip_sample_range (`float`, default `1.0`):
166-
the maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
167-
prediction_type (`str`, default `epsilon`, optional):
168-
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
166+
clip_sample (`bool`, defaults to `True`):
167+
Option to clip predicted sample for numerical stability.
168+
prediction_type (`str`, defaults to `"epsilon"`):
169+
Prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
169170
process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
170171
https://huggingface.co/papers/2210.02303)
171-
thresholding (`bool`, default `False`):
172-
whether to use the "dynamic thresholding" method (introduced by Imagen,
172+
thresholding (`bool`, defaults to `False`):
173+
Whether to use the "dynamic thresholding" method (introduced by Imagen,
173174
https://huggingface.co/papers/2205.11487). Note that the thresholding method is unsuitable for latent-space
174175
diffusion models (such as stable-diffusion).
175-
dynamic_thresholding_ratio (`float`, default `0.995`):
176-
the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
176+
dynamic_thresholding_ratio (`float`, defaults to 0.995):
177+
The ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
177178
(https://huggingface.co/papers/2205.11487). Valid only when `thresholding=True`.
178-
sample_max_value (`float`, default `1.0`):
179-
the threshold value for dynamic thresholding. Valid only when `thresholding=True`.
180-
timestep_spacing (`str`, default `"leading"`):
179+
clip_sample_range (`float`, defaults to 1.0):
180+
The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
181+
sample_max_value (`float`, defaults to 1.0):
182+
The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
183+
timestep_spacing (`str`, defaults to `"leading"`):
181184
The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
182185
Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
183-
steps_offset (`int`, default `0`):
186+
steps_offset (`int`, defaults to 0):
184187
An offset added to the inference steps, as required by some model families.
185188
rescale_betas_zero_snr (`bool`, defaults to `False`):
186189
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
@@ -293,7 +296,7 @@ def set_timesteps(
293296
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
294297
295298
Args:
296-
num_inference_steps (`int`):
299+
num_inference_steps (`int`, *optional*):
297300
The number of diffusion steps used when generating samples with a pre-trained model. If used,
298301
`timesteps` must be `None`.
299302
device (`str` or `torch.device`, *optional*):
@@ -478,7 +481,7 @@ def step(
478481
model_output: torch.Tensor,
479482
timestep: int,
480483
sample: torch.Tensor,
481-
generator=None,
484+
generator: Optional[torch.Generator] = None,
482485
return_dict: bool = True,
483486
) -> Union[DDPMParallelSchedulerOutput, Tuple]:
484487
"""
@@ -490,7 +493,8 @@ def step(
490493
timestep (`int`): current discrete timestep in the diffusion chain.
491494
sample (`torch.Tensor`):
492495
current instance of sample being created by diffusion process.
493-
generator: random number generator.
496+
generator (`torch.Generator`, *optional*):
497+
Random number generator.
494498
return_dict (`bool`): option for returning tuple rather than DDPMParallelSchedulerOutput class
495499
496500
Returns:
@@ -503,7 +507,10 @@ def step(
503507

504508
prev_t = self.previous_timestep(t)
505509

506-
if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
510+
if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in [
511+
"learned",
512+
"learned_range",
513+
]:
507514
model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
508515
else:
509516
predicted_variance = None
@@ -552,7 +559,10 @@ def step(
552559
if t > 0:
553560
device = model_output.device
554561
variance_noise = randn_tensor(
555-
model_output.shape, generator=generator, device=device, dtype=model_output.dtype
562+
model_output.shape,
563+
generator=generator,
564+
device=device,
565+
dtype=model_output.dtype,
556566
)
557567
if self.variance_type == "fixed_small_log":
558568
variance = self._get_variance(t, predicted_variance=predicted_variance) * variance_noise
@@ -575,7 +585,7 @@ def step(
575585
def batch_step_no_noise(
576586
self,
577587
model_output: torch.Tensor,
578-
timesteps: List[int],
588+
timesteps: torch.Tensor,
579589
sample: torch.Tensor,
580590
) -> torch.Tensor:
581591
"""
@@ -588,8 +598,8 @@ def batch_step_no_noise(
588598
589599
Args:
590600
model_output (`torch.Tensor`): direct output from learned diffusion model.
591-
timesteps (`List[int]`):
592-
current discrete timesteps in the diffusion chain. This is now a list of integers.
601+
timesteps (`torch.Tensor`):
602+
Current discrete timesteps in the diffusion chain. This is a tensor of integers.
593603
sample (`torch.Tensor`):
594604
current instance of sample being created by diffusion process.
595605
@@ -603,7 +613,10 @@ def batch_step_no_noise(
603613
t = t.view(-1, *([1] * (model_output.ndim - 1)))
604614
prev_t = prev_t.view(-1, *([1] * (model_output.ndim - 1)))
605615

606-
if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
616+
if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in [
617+
"learned",
618+
"learned_range",
619+
]:
607620
model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
608621
else:
609622
pass
@@ -734,7 +747,7 @@ def __len__(self):
734747
return self.config.num_train_timesteps
735748

736749
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
737-
def previous_timestep(self, timestep):
750+
def previous_timestep(self, timestep: int) -> Union[int, torch.Tensor]:
738751
"""
739752
Compute the previous timestep in the diffusion chain.
740753
@@ -743,7 +756,7 @@ def previous_timestep(self, timestep):
743756
The current timestep.
744757
745758
Returns:
746-
`int`:
759+
`int` or `torch.Tensor`:
747760
The previous timestep.
748761
"""
749762
if self.custom_timesteps or self.num_inference_steps:

src/diffusers/schedulers/scheduling_lcm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -722,7 +722,7 @@ def previous_timestep(self, timestep):
722722
The current timestep.
723723
724724
Returns:
725-
`int`:
725+
`int` or `torch.Tensor`:
726726
The previous timestep.
727727
"""
728728
if self.custom_timesteps or self.num_inference_steps:

src/diffusers/schedulers/scheduling_tcd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -777,7 +777,7 @@ def previous_timestep(self, timestep):
777777
The current timestep.
778778
779779
Returns:
780-
`int`:
780+
`int` or `torch.Tensor`:
781781
The previous timestep.
782782
"""
783783
if self.custom_timesteps or self.num_inference_steps:

0 commit comments

Comments
 (0)