@@ -83,7 +83,7 @@ def __init__(
         s_noise: float = 1.0,
         rho: float = 7.0,
         clip_denoised: bool = True,
-    ):
+    ) -> None:
         # standard deviation of the initial noise distribution
         self.init_noise_sigma = sigma_max
 
@@ -102,21 +102,29 @@ def __init__(
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
 
     @property
-    def step_index(self):
+    def step_index(self) -> Optional[int]:
         """
         The index counter for current timestep. It will increase 1 after each scheduler step.
+
+        Returns:
+            `int` or `None`:
+                The current step index, or `None` if not yet initialized.
         """
         return self._step_index
 
     @property
-    def begin_index(self):
+    def begin_index(self) -> Optional[int]:
         """
         The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+
+        Returns:
+            `int` or `None`:
+                The begin index, or `None` if not yet set.
         """
         return self._begin_index
 
     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
-    def set_begin_index(self, begin_index: int = 0):
+    def set_begin_index(self, begin_index: int = 0) -> None:
         """
         Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
 
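`set_begin_index` is normally invoked by a pipeline rather than by end users. A minimal sketch of the intended call pattern (the step counts are illustrative and assume the scheduler's default config):

```python
from diffusers import CMStochasticIterativeScheduler

scheduler = CMStochasticIterativeScheduler()
scheduler.set_timesteps(num_inference_steps=10)

# Resume denoising partway through the schedule, e.g. for img2img-style use.
scheduler.set_begin_index(2)
assert scheduler.begin_index == 2  # step_index will start from here
```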
@@ -151,7 +159,7 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.T
         self.is_scale_input_called = True
         return sample
 
-    def sigma_to_t(self, sigmas: Union[float, np.ndarray]):
+    def sigma_to_t(self, sigmas: Union[float, np.ndarray]) -> np.ndarray:
         """
         Gets scaled timesteps from the Karras sigmas for input to the consistency model.
 
@@ -160,8 +168,8 @@ def sigma_to_t(self, sigmas: Union[float, np.ndarray]):
                 A single Karras sigma or an array of Karras sigmas.
 
         Returns:
-            `float` or `np.ndarray`:
-                A scaled input timestep or scaled input timestep array.
+            `np.ndarray`:
+                A scaled input timestep array.
         """
         if not isinstance(sigmas, np.ndarray):
             sigmas = np.array(sigmas, dtype=np.float64)
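The body of `sigma_to_t` continues beyond this hunk; for reference, a standalone sketch of the same transform, assuming the `t = 1000 * 0.25 * ln(sigma + 1e-44)` convention from the original consistency models codebase:

```python
import numpy as np

def sigma_to_t(sigmas):
    # Map Karras sigmas to the scaled timesteps the consistency model expects.
    # The log-scaling convention is an assumption taken from the consistency
    # models codebase, not quoted from this file.
    sigmas = np.asarray(sigmas, dtype=np.float64)
    return 1000 * 0.25 * np.log(sigmas + 1e-44)

print(sigma_to_t(80.0))           # ~1095.6 for the default sigma_max
print(sigma_to_t([0.002, 80.0]))  # scalar or array input, array output
```

The `np.asarray` cast is what makes a scalar input come back as an array, matching the tightened `-> np.ndarray` annotation above.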
@@ -173,14 +181,14 @@ def sigma_to_t(self, sigmas: Union[float, np.ndarray]):
     def set_timesteps(
         self,
         num_inference_steps: Optional[int] = None,
-        device: Union[str, torch.device] = None,
+        device: Optional[Union[str, torch.device]] = None,
         timesteps: Optional[List[int]] = None,
-    ):
+    ) -> None:
         """
         Sets the timesteps used for the diffusion chain (to be run before inference).
 
         Args:
-            num_inference_steps (`int`):
+            num_inference_steps (`int`, *optional*):
                 The number of diffusion steps used when generating samples with a pre-trained model.
             device (`str` or `torch.device`, *optional*):
                 The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
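Marking `num_inference_steps` as *optional* matches the runtime contract: exactly one of `num_inference_steps` or `timesteps` is expected. A quick sketch of both call styles (the concrete values are illustrative):

```python
import torch
from diffusers import CMStochasticIterativeScheduler

scheduler = CMStochasticIterativeScheduler()

# Option 1: let the scheduler space the timesteps itself.
scheduler.set_timesteps(num_inference_steps=10, device="cpu")

# Option 2: supply explicit training timesteps, in decreasing order.
scheduler.set_timesteps(timesteps=[22, 0], device=torch.device("cpu"))
```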
@@ -244,9 +252,19 @@ def set_timesteps(
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
 
     # Modified _convert_to_karras implementation that takes in ramp as argument
-    def _convert_to_karras(self, ramp):
-        """Constructs the noise schedule of Karras et al. (2022)."""
+    def _convert_to_karras(self, ramp: np.ndarray) -> np.ndarray:
+        """
+        Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
+        Models](https://huggingface.co/papers/2206.00364).
 
+        Args:
+            ramp (`np.ndarray`):
+                A ramp array of values between 0 and 1 used to interpolate between sigma_min and sigma_max.
+
+        Returns:
+            `np.ndarray`:
+                The Karras sigma schedule array.
+        """
         sigma_min: float = self.config.sigma_min
         sigma_max: float = self.config.sigma_max
 
@@ -256,14 +274,25 @@ def _convert_to_karras(self, ramp):
         sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
         return sigmas
 
-    def get_scalings(self, sigma):
+    def get_scalings(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Computes the scaling factors for the consistency model output.
+
+        Args:
+            sigma (`torch.Tensor`):
+                The current sigma value in the noise schedule.
+
+        Returns:
+            `Tuple[torch.Tensor, torch.Tensor]`:
+                A tuple containing `c_skip` (scaling for the input sample) and `c_out` (scaling for the model output).
+        """
         sigma_data = self.config.sigma_data
 
         c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
         c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
         return c_skip, c_out
 
-    def get_scalings_for_boundary_condition(self, sigma):
+    def get_scalings_for_boundary_condition(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Gets the scalings used in the consistency model parameterization (from Appendix C of the
         [paper](https://huggingface.co/papers/2303.01469)) to enforce boundary condition.
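Since the arithmetic for the Karras schedule and `get_scalings` is fully visible in the hunk above, here is a standalone sketch reproducing it with numpy (the defaults `sigma_min=0.002`, `sigma_max=80.0`, `rho=7.0`, `sigma_data=0.5` are assumed from the usual Karras setup):

```python
import numpy as np

sigma_min, sigma_max, rho, sigma_data = 0.002, 80.0, 7.0, 0.5

# _convert_to_karras: ramp=0 yields sigma_max, ramp=1 yields sigma_min.
ramp = np.linspace(0, 1, 5)
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
print(sigmas)  # monotonically decreasing from 80.0 to 0.002

# get_scalings: at large sigma, c_skip -> 0 and c_out -> sigma_data.
c_skip = sigma_data**2 / (sigmas**2 + sigma_data**2)
c_out = sigmas * sigma_data / (sigmas**2 + sigma_data**2) ** 0.5
print(c_skip[0], c_out[0])
```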
@@ -275,7 +304,7 @@ def get_scalings_for_boundary_condition(self, sigma):
                 The current sigma in the Karras sigma schedule.
 
         Returns:
-            `tuple`:
+            `Tuple[torch.Tensor, torch.Tensor]`:
                 A two-element tuple where `c_skip` (which weights the current sample) is the first element and `c_out`
                 (which weights the consistency model output) is the second element.
         """
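The function body sits outside this hunk. For contrast with `get_scalings`, the boundary-condition variant shifts `sigma` by `sigma_min` so that `c_skip == 1` and `c_out == 0` exactly at `sigma == sigma_min`; a sketch following Appendix C of the paper (treat the exact expressions as an assumption, not a quote of the file):

```python
import torch

def get_scalings_for_boundary_condition(sigma, sigma_min=0.002, sigma_data=0.5):
    # The sigma_min shift enforces f(x, sigma_min) = x for the consistency model.
    c_skip = sigma_data**2 / ((sigma - sigma_min) ** 2 + sigma_data**2)
    c_out = (sigma - sigma_min) * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
    return c_skip, c_out

c_skip, c_out = get_scalings_for_boundary_condition(torch.tensor(0.002))
print(c_skip.item(), c_out.item())  # 1.0 and 0.0 at the boundary
```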
@@ -348,13 +377,13 @@ def step(
         Args:
             model_output (`torch.Tensor`):
                 The direct output from the learned diffusion model.
-            timestep (`float`):
+            timestep (`float` or `torch.Tensor`):
                 The current timestep in the diffusion chain.
             sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
             generator (`torch.Generator`, *optional*):
                 A random number generator.
-            return_dict (`bool`, *optional*, defaults to `True`):
+            return_dict (`bool`, defaults to `True`):
                 Whether or not to return a
                 [`~schedulers.scheduling_consistency_models.CMStochasticIterativeSchedulerOutput`] or `tuple`.
 
@@ -406,7 +435,10 @@ def step(
         # Noise is not used for onestep sampling.
         if len(self.timesteps) > 1:
             noise = randn_tensor(
-                model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
+                model_output.shape,
+                dtype=model_output.dtype,
+                device=model_output.device,
+                generator=generator,
             )
         else:
             noise = torch.zeros_like(model_output)
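The zero-noise branch makes one-step sampling deterministic; with more than one timestep, fresh noise is re-injected at every step. A sketch of the stochastic update this noise feeds into, per Algorithm 1 of the consistency models paper (the clamping and naming here are assumptions, not the file's exact code):

```python
import torch

def stochastic_iterative_update(denoised, noise, sigma_next, sigma_min=0.002, sigma_max=80.0):
    # Re-noise the denoised prediction up to the next noise level.
    sigma_hat = sigma_next.clamp(min=sigma_min, max=sigma_max)
    return denoised + noise * (sigma_hat**2 - sigma_min**2) ** 0.5
```

At the final step `sigma_next` clamps to `sigma_min`, the scale factor becomes zero, and the update returns `denoised` unchanged, which is consistent with passing zero noise in the one-step branch.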
@@ -475,5 +507,12 @@ def add_noise(
         noisy_samples = original_samples + noise * sigma
         return noisy_samples
 
-    def __len__(self):
+    def __len__(self) -> int:
+        """
+        Returns the number of training timesteps.
+
+        Returns:
+            `int`:
+                The number of training timesteps configured for the scheduler.
+        """
         return self.config.num_train_timesteps
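Putting the typed API together, a minimal two-step sampling loop might look as follows; `my_unet` is a placeholder for a consistency model, not part of this diff:

```python
import torch
from diffusers import CMStochasticIterativeScheduler

scheduler = CMStochasticIterativeScheduler()
scheduler.set_timesteps(num_inference_steps=2)
print(len(scheduler))  # num_train_timesteps from the config

sample = torch.randn(1, 3, 64, 64) * scheduler.init_noise_sigma
for t in scheduler.timesteps:
    scaled = scheduler.scale_model_input(sample, t)
    model_output = my_unet(scaled, t)  # placeholder model call
    sample = scheduler.step(model_output, t, sample).prev_sample
```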