@@ -101,7 +101,7 @@ def alpha_bar_fn(t):
101101
102102
103103# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
104- def rescale_zero_terminal_snr (betas ) :
104+ def rescale_zero_terminal_snr (betas : torch . Tensor ) -> torch . Tensor :
105105 """
106106 Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
107107
@@ -266,7 +266,7 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None
266266 """
267267 return sample
268268
269- def _get_variance (self , timestep , prev_timestep = None ):
269+ def _get_variance (self , timestep : int , prev_timestep : Optional [ int ] = None ) -> torch . Tensor :
270270 if prev_timestep is None :
271271 prev_timestep = timestep - self .config .num_train_timesteps // self .num_inference_steps
272272
@@ -279,7 +279,7 @@ def _get_variance(self, timestep, prev_timestep=None):
279279
280280 return variance
281281
282- def _batch_get_variance (self , t , prev_t ) :
282+ def _batch_get_variance (self , t : torch . Tensor , prev_t : torch . Tensor ) -> torch . Tensor :
283283 alpha_prod_t = self .alphas_cumprod [t ]
284284 alpha_prod_t_prev = self .alphas_cumprod [torch .clip (prev_t , min = 0 )]
285285 alpha_prod_t_prev [prev_t < 0 ] = torch .tensor (1.0 )
@@ -335,7 +335,7 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
335335 return sample
336336
337337 # Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler.set_timesteps
338- def set_timesteps (self , num_inference_steps : int , device : Union [str , torch .device ] = None ):
338+ def set_timesteps (self , num_inference_steps : int , device : Union [str , torch .device ] = None ) -> None :
339339 """
340340 Sets the discrete timesteps used for the diffusion chain (to be run before inference).
341341
@@ -392,7 +392,7 @@ def step(
392392 sample : torch .Tensor ,
393393 eta : float = 0.0 ,
394394 use_clipped_model_output : bool = False ,
395- generator = None ,
395+ generator : Optional [ torch . Generator ] = None ,
396396 variance_noise : Optional [torch .Tensor ] = None ,
397397 return_dict : bool = True ,
398398 ) -> Union [DDIMParallelSchedulerOutput , Tuple ]:
@@ -406,11 +406,13 @@ def step(
406406 sample (`torch.Tensor`):
407407 current instance of sample being created by diffusion process.
408408 eta (`float`): weight of noise for added noise in diffusion step.
409- use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
410- predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when
411- `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would
412- coincide with the one provided as input and `use_clipped_model_output` will have not effect.
413- generator: random number generator.
409+ use_clipped_model_output (`bool`, defaults to `False`):
410+ If `True`, compute "corrected" `model_output` from the clipped predicted original sample. This
411+ correction is necessary because the predicted original sample is clipped to [-1, 1] when
412+ `self.config.clip_sample` is `True`. If no clipping occurred, the "corrected" `model_output` matches
413+ the input and `use_clipped_model_output` has no effect.
414+ generator (`torch.Generator`, *optional*):
415+ Random number generator.
414416 variance_noise (`torch.Tensor`): Instead of generating noise for the variance using `generator`, the
415417 noise for the variance can be provided directly. This is useful for methods such as
416418 CycleDiffusion. (https://huggingface.co/papers/2210.05559)
@@ -496,7 +498,10 @@ def step(
496498
497499 if variance_noise is None :
498500 variance_noise = randn_tensor (
499- model_output .shape , generator = generator , device = model_output .device , dtype = model_output .dtype
501+ model_output .shape ,
502+ generator = generator ,
503+ device = model_output .device ,
504+ dtype = model_output .dtype ,
500505 )
501506 variance = std_dev_t * variance_noise
502507
@@ -513,7 +518,7 @@ def step(
513518 def batch_step_no_noise (
514519 self ,
515520 model_output : torch .Tensor ,
516- timesteps : List [ int ] ,
521+ timesteps : torch . Tensor ,
517522 sample : torch .Tensor ,
518523 eta : float = 0.0 ,
519524 use_clipped_model_output : bool = False ,
@@ -528,7 +533,7 @@ def batch_step_no_noise(
528533
529534 Args:
530535 model_output (`torch.Tensor`): direct output from learned diffusion model.
531- timesteps (`List[int] `):
536+ timesteps (`torch.Tensor `):
532537 current discrete timesteps in the diffusion chain. This is now a tensor of integers.
533538 sample (`torch.Tensor`):
534539 current instance of sample being created by diffusion process.
@@ -696,5 +701,5 @@ def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: tor
696701 velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
697702 return velocity
698703
699- def __len__ (self ):
704+ def __len__ (self ) -> int :
700705 return self .config .num_train_timesteps
0 commit comments