@@ -149,38 +149,41 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
149149 For more details, see the original paper: https://huggingface.co/papers/2006.11239
150150
151151 Args:
152- num_train_timesteps (`int`): number of diffusion steps used to train the model.
153- beta_start (`float`): the starting `beta` value of inference.
154- beta_end (`float`): the final `beta` value.
155- beta_schedule (`str`):
156- the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
152+ num_train_timesteps (`int`, defaults to 1000):
153+ The number of diffusion steps to train the model.
154+ beta_start (`float`, defaults to 0.0001):
155+ The starting `beta` value of inference.
156+ beta_end (`float`, defaults to 0.02):
157+ The final `beta` value.
158+ beta_schedule (`str`, defaults to `"linear"`):
159+ The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
157160 `linear`, `scaled_linear`, `squaredcos_cap_v2` or `sigmoid`.
158- trained_betas (`np.ndarray`, optional):
159- option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
160- variance_type (`str`):
161- options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
161+ trained_betas (`np.ndarray`, *optional*):
162+ Option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
163+ variance_type (`str`, defaults to `"fixed_small"`):
164+ Options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
162165 `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
163- clip_sample (`bool`, default `True`):
164- option to clip predicted sample for numerical stability.
165- clip_sample_range (`float`, default `1.0`):
166- the maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
167- prediction_type (`str`, default `epsilon`, optional):
168- prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
166+ clip_sample (`bool`, defaults to `True`):
167+ Option to clip predicted sample for numerical stability.
168+ prediction_type (`str`, defaults to `"epsilon"`):
169+ Prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
169170 process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4
170171 https://huggingface.co/papers/2210.02303)
171- thresholding (`bool`, default `False`):
172- whether to use the "dynamic thresholding" method (introduced by Imagen,
172+ thresholding (`bool`, defaults to `False`):
173+ Whether to use the "dynamic thresholding" method (introduced by Imagen,
173174 https://huggingface.co/papers/2205.11487). Note that the thresholding method is unsuitable for latent-space
174175 diffusion models (such as stable-diffusion).
175- dynamic_thresholding_ratio (`float`, default `0.995`):
176- the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
176+ dynamic_thresholding_ratio (`float`, defaults to 0.995):
177+ The ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
177178 (https://huggingface.co/papers/2205.11487). Valid only when `thresholding=True`.
178- sample_max_value (`float`, default `1.0`):
179- the threshold value for dynamic thresholding. Valid only when `thresholding=True`.
180- timestep_spacing (`str`, default `"leading"`):
179+ clip_sample_range (`float`, defaults to 1.0):
180+ The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
181+ sample_max_value (`float`, defaults to 1.0):
182+ The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
183+ timestep_spacing (`str`, defaults to `"leading"`):
181184 The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
182185 Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
183- steps_offset (`int`, default `0`):
186+ steps_offset (`int`, defaults to 0):
184187 An offset added to the inference steps, as required by some model families.
185188 rescale_betas_zero_snr (`bool`, defaults to `False`):
186189 Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
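
A minimal usage sketch for the constructor arguments documented above (illustrative only, not part of this diff; it assumes `DDPMParallelScheduler` is importable from the package root and simply spells out the stated defaults):

```python
# Hedged sketch: construct the scheduler with the documented defaults.
from diffusers import DDPMParallelScheduler

scheduler = DDPMParallelScheduler(
    num_train_timesteps=1000,
    beta_start=0.0001,
    beta_end=0.02,
    beta_schedule="linear",
    variance_type="fixed_small",
    clip_sample=True,
    prediction_type="epsilon",
    thresholding=False,
    timestep_spacing="leading",
    steps_offset=0,
)
```
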
@@ -293,7 +296,7 @@ def set_timesteps(
293296 Sets the discrete timesteps used for the diffusion chain (to be run before inference).
294297
295298 Args:
296- num_inference_steps (`int`):
299+ num_inference_steps (`int`, *optional*):
297300 The number of diffusion steps used when generating samples with a pre-trained model. If used,
298301 `timesteps` must be `None`.
299302 device (`str` or `torch.device`, *optional*):
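
A short sketch of the `set_timesteps` contract described above (illustrative; pass either `num_inference_steps` or explicit `timesteps`, not both):

```python
# Hedged sketch: discretize the diffusion chain before inference.
from diffusers import DDPMParallelScheduler

scheduler = DDPMParallelScheduler()
scheduler.set_timesteps(num_inference_steps=50, device="cpu")
print(scheduler.timesteps[:3])  # tensor([980, 960, 940]) with the default "leading" spacing
```
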
@@ -478,7 +481,7 @@ def step(
478481 model_output: torch.Tensor,
479482 timestep: int,
480483 sample: torch.Tensor,
481- generator=None,
484+ generator: Optional[torch.Generator] = None,
482485 return_dict: bool = True,
483486 ) -> Union[DDPMParallelSchedulerOutput, Tuple]:
484487 """
@@ -490,7 +493,8 @@ def step(
490493 timestep (`int`): current discrete timestep in the diffusion chain.
491494 sample (`torch.Tensor`):
492495 current instance of sample being created by diffusion process.
493- generator: random number generator.
496+ generator (`torch.Generator`, *optional*):
497+ Random number generator.
494498 return_dict (`bool`): option for returning tuple rather than DDPMParallelSchedulerOutput class
495499
496500 Returns:
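
An illustrative sketch of a single `step` call with the now-typed `generator` argument (the random tensors merely stand in for a UNet prediction and a noisy sample; only the call pattern is the point):

```python
# Hedged sketch: one reverse-diffusion step with reproducible variance noise.
import torch
from diffusers import DDPMParallelScheduler

scheduler = DDPMParallelScheduler()
scheduler.set_timesteps(50)

sample = torch.randn(1, 3, 32, 32)        # current noisy sample
model_output = torch.randn(1, 3, 32, 32)  # stand-in for the model's epsilon prediction
t = scheduler.timesteps[0]
generator = torch.Generator().manual_seed(0)

out = scheduler.step(model_output, t, sample, generator=generator)
prev_sample = out.prev_sample  # DDPMParallelSchedulerOutput.prev_sample
```
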
@@ -503,7 +507,10 @@ def step(
503507
504508 prev_t = self.previous_timestep(t)
505509
506- if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
510+ if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in [
511+ "learned",
512+ "learned_range",
513+ ]:
507514 model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
508515 else:
509516 predicted_variance = None
@@ -552,7 +559,10 @@ def step(
552559 if t > 0:
553560 device = model_output.device
554561 variance_noise = randn_tensor(
555- model_output.shape, generator=generator, device=device, dtype=model_output.dtype
562+ model_output.shape,
563+ generator=generator,
564+ device=device,
565+ dtype=model_output.dtype,
556566 )
557567 if self.variance_type == "fixed_small_log":
558568 variance = self._get_variance(t, predicted_variance=predicted_variance) * variance_noise
@@ -575,7 +585,7 @@ def step(
575585 def batch_step_no_noise(
576586 self,
577587 model_output: torch.Tensor,
578- timesteps: List[int],
588+ timesteps: torch.Tensor,
579589 sample: torch.Tensor,
580590 ) -> torch.Tensor:
581591 """
@@ -588,8 +598,8 @@ def batch_step_no_noise(
588598
589599 Args:
590600 model_output (`torch.Tensor`): direct output from learned diffusion model.
591- timesteps (`List[int]`):
592- current discrete timesteps in the diffusion chain. This is now a list of integers.
601+ timesteps (`torch.Tensor`):
602+ Current discrete timesteps in the diffusion chain. This is a tensor of integers.
593603 sample (`torch.Tensor`):
594604 current instance of sample being created by diffusion process.
595605
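
A hedged sketch of the batched call documented above (illustrative; `timesteps` is now a tensor with one entry per batch element, as used by parallel, ParaDiGMS-style sampling):

```python
# Hedged sketch: noise-free batched update over several timesteps at once.
import torch
from diffusers import DDPMParallelScheduler

scheduler = DDPMParallelScheduler()
scheduler.set_timesteps(50)

timesteps = scheduler.timesteps[:4]       # four consecutive discrete timesteps
sample = torch.randn(4, 3, 32, 32)        # one noisy sample per timestep
model_output = torch.randn(4, 3, 32, 32)  # stand-in model predictions

prev_samples = scheduler.batch_step_no_noise(model_output, timesteps, sample)
print(prev_samples.shape)  # torch.Size([4, 3, 32, 32])
```
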
@@ -603,7 +613,10 @@ def batch_step_no_noise(
603613 t = t.view(-1, *([1] * (model_output.ndim - 1)))
604614 prev_t = prev_t.view(-1, *([1] * (model_output.ndim - 1)))
605615
606- if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
616+ if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in [
617+ "learned",
618+ "learned_range",
619+ ]:
607620 model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
608621 else:
609622 pass
@@ -734,7 +747,7 @@ def __len__(self):
734747 return self.config.num_train_timesteps
735748
736749 # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
737- def previous_timestep(self, timestep):
750+ def previous_timestep(self, timestep: int) -> Union[int, torch.Tensor]:
738751 """
739752 Compute the previous timestep in the diffusion chain.
740753
@@ -743,7 +756,7 @@ def previous_timestep(self, timestep):
743756 The current timestep.
744757
745758 Returns:
746- `int`:
759+ `int` or `torch.Tensor`:
747760 The previous timestep.
748761 """
749762 if self.custom_timesteps or self.num_inference_steps:
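
A small illustration of the widened return type (this assumes the branch shown above looks the previous step up in `scheduler.timesteps` once a discretized schedule exists, so the result comes back as a tensor in that case):

```python
# Hedged sketch: previous_timestep after set_timesteps returns a tensor entry.
from diffusers import DDPMParallelScheduler

scheduler = DDPMParallelScheduler()
scheduler.set_timesteps(50)  # timesteps: 980, 960, 940, ...

print(scheduler.previous_timestep(980))  # tensor(960)
```
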