Commit d36564f

Improve docstrings and type hints in scheduling_consistency_models.py (#12931)
docs: improve docstring scheduling_consistency_models.py
1 parent 441b69e commit d36564f


src/diffusers/schedulers/scheduling_consistency_models.py

Lines changed: 58 additions & 19 deletions
@@ -83,7 +83,7 @@ def __init__(
         s_noise: float = 1.0,
         rho: float = 7.0,
         clip_denoised: bool = True,
-    ):
+    ) -> None:
         # standard deviation of the initial noise distribution
         self.init_noise_sigma = sigma_max

@@ -102,21 +102,29 @@ def __init__(
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
 
     @property
-    def step_index(self):
+    def step_index(self) -> Optional[int]:
         """
         The index counter for current timestep. It will increase 1 after each scheduler step.
+
+        Returns:
+            `int` or `None`:
+                The current step index, or `None` if not yet initialized.
         """
         return self._step_index
 
     @property
-    def begin_index(self):
+    def begin_index(self) -> Optional[int]:
         """
         The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+
+        Returns:
+            `int` or `None`:
+                The begin index, or `None` if not yet set.
         """
         return self._begin_index
 
     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
-    def set_begin_index(self, begin_index: int = 0):
+    def set_begin_index(self, begin_index: int = 0) -> None:
         """
         Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
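
A minimal usage sketch of the annotated accessors above (`CMStochasticIterativeScheduler` is the class defined in this file; the surrounding usage is illustrative):

    from diffusers import CMStochasticIterativeScheduler

    scheduler = CMStochasticIterativeScheduler()
    print(scheduler.step_index)   # None -- not initialized until stepping begins
    print(scheduler.begin_index)  # None -- not set until set_begin_index() is called

    scheduler.set_begin_index(0)  # normally called by the pipeline before inference
    print(scheduler.begin_index)  # 0
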
@@ -151,7 +159,7 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.T
         self.is_scale_input_called = True
         return sample
 
-    def sigma_to_t(self, sigmas: Union[float, np.ndarray]):
+    def sigma_to_t(self, sigmas: Union[float, np.ndarray]) -> np.ndarray:
         """
         Gets scaled timesteps from the Karras sigmas for input to the consistency model.
 
@@ -160,8 +168,8 @@ def sigma_to_t(self, sigmas: Union[float, np.ndarray]):
                 A single Karras sigma or an array of Karras sigmas.
 
         Returns:
-            `float` or `np.ndarray`:
-                A scaled input timestep or scaled input timestep array.
+            `np.ndarray`:
+                A scaled input timestep array.
         """
         if not isinstance(sigmas, np.ndarray):
             sigmas = np.array(sigmas, dtype=np.float64)
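
The narrowed `np.ndarray` return annotation matches the guard at the end of this hunk: a scalar input is converted to a float64 array before any scaling, so the function never returns a bare `float`. A standalone sketch of just that conversion (the helper name is illustrative):

    import numpy as np

    def to_float64_array(sigmas):
        # Mirrors the isinstance guard in sigma_to_t.
        if not isinstance(sigmas, np.ndarray):
            sigmas = np.array(sigmas, dtype=np.float64)
        return sigmas

    print(type(to_float64_array(0.5)))   # <class 'numpy.ndarray'>
    print(to_float64_array(0.5).shape)   # () -- a 0-d array, still an ndarray
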
@@ -173,14 +181,14 @@ def sigma_to_t(self, sigmas: Union[float, np.ndarray]):
     def set_timesteps(
         self,
         num_inference_steps: Optional[int] = None,
-        device: Union[str, torch.device] = None,
+        device: Optional[Union[str, torch.device]] = None,
         timesteps: Optional[List[int]] = None,
-    ):
+    ) -> None:
         """
         Sets the timesteps used for the diffusion chain (to be run before inference).
 
         Args:
-            num_inference_steps (`int`):
+            num_inference_steps (`int`, *optional*):
                 The number of diffusion steps used when generating samples with a pre-trained model.
             device (`str` or `torch.device`, *optional*):
                 The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
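
Both documented call styles remain valid under the tightened `device` annotation; a hedged usage sketch (assumes `scheduler` is an instance of this file's scheduler and that the custom timesteps are within the valid decreasing range):

    # Let the scheduler derive its own timestep schedule.
    scheduler.set_timesteps(num_inference_steps=40, device="cuda")

    # Or supply explicit timesteps instead of a step count.
    scheduler.set_timesteps(timesteps=[22, 0], device="cuda")
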
@@ -244,9 +252,19 @@ def set_timesteps(
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
 
     # Modified _convert_to_karras implementation that takes in ramp as argument
-    def _convert_to_karras(self, ramp):
-        """Constructs the noise schedule of Karras et al. (2022)."""
+    def _convert_to_karras(self, ramp: np.ndarray) -> np.ndarray:
+        """
+        Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
+        Models](https://huggingface.co/papers/2206.00364).
 
+        Args:
+            ramp (`np.ndarray`):
+                A ramp array of values between 0 and 1 used to interpolate between sigma_min and sigma_max.
+
+        Returns:
+            `np.ndarray`:
+                The Karras sigma schedule array.
+        """
         sigma_min: float = self.config.sigma_min
         sigma_max: float = self.config.sigma_max
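
A standalone numpy sketch of the schedule this docstring now describes; the interpolation formula is the one at the top of the next hunk, and the defaults (sigma_min=0.002, sigma_max=80.0, rho=7.0) are this scheduler's config defaults:

    import numpy as np

    def karras_sigmas(ramp, sigma_min=0.002, sigma_max=80.0, rho=7.0):
        # Interpolate between sigma_max and sigma_min in sigma**(1/rho) space.
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

    ramp = np.linspace(0, 1, 5)   # ramp=0 -> sigma_max, ramp=1 -> sigma_min
    print(karras_sigmas(ramp))    # monotonically decreasing sigma schedule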

@@ -256,14 +274,25 @@ def _convert_to_karras(self, ramp):
         sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
         return sigmas
 
-    def get_scalings(self, sigma):
+    def get_scalings(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Computes the scaling factors for the consistency model output.
+
+        Args:
+            sigma (`torch.Tensor`):
+                The current sigma value in the noise schedule.
+
+        Returns:
+            `Tuple[torch.Tensor, torch.Tensor]`:
+                A tuple containing `c_skip` (scaling for the input sample) and `c_out` (scaling for the model output).
+        """
         sigma_data = self.config.sigma_data
 
         c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
         c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
         return c_skip, c_out
 
-    def get_scalings_for_boundary_condition(self, sigma):
+    def get_scalings_for_boundary_condition(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Gets the scalings used in the consistency model parameterization (from Appendix C of the
         [paper](https://huggingface.co/papers/2303.01469)) to enforce boundary condition.
@@ -275,7 +304,7 @@ def get_scalings_for_boundary_condition(self, sigma):
                 The current sigma in the Karras sigma schedule.
 
         Returns:
-            `tuple`:
+            `Tuple[torch.Tensor, torch.Tensor]`:
                 A two-element tuple where `c_skip` (which weights the current sample) is the first element and `c_out`
                 (which weights the consistency model output) is the second element.
         """
@@ -348,13 +377,13 @@ def step(
         Args:
             model_output (`torch.Tensor`):
                 The direct output from the learned diffusion model.
-            timestep (`float`):
+            timestep (`float` or `torch.Tensor`):
                 The current timestep in the diffusion chain.
             sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
             generator (`torch.Generator`, *optional*):
                 A random number generator.
-            return_dict (`bool`, *optional*, defaults to `True`):
+            return_dict (`bool`, defaults to `True`):
                 Whether or not to return a
                 [`~schedulers.scheduling_consistency_models.CMStochasticIterativeSchedulerOutput`] or `tuple`.
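
A hedged sketch of calling `step` with the updated parameter docs (`scheduler`, `model_output`, `t`, and `sample` are illustrative names; `set_timesteps` is assumed to have been called first):

    # return_dict=True (the default) yields a CMStochasticIterativeSchedulerOutput.
    output = scheduler.step(model_output, t, sample)
    sample = output.prev_sample

    # return_dict=False yields a plain one-element tuple instead.
    (sample,) = scheduler.step(model_output, t, sample, return_dict=False)
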
@@ -406,7 +435,10 @@ def step(
         # Noise is not used for onestep sampling.
         if len(self.timesteps) > 1:
             noise = randn_tensor(
-                model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
+                model_output.shape,
+                dtype=model_output.dtype,
+                device=model_output.device,
+                generator=generator,
             )
         else:
             noise = torch.zeros_like(model_output)
@@ -475,5 +507,12 @@ def add_noise(
         noisy_samples = original_samples + noise * sigma
         return noisy_samples
 
-    def __len__(self):
+    def __len__(self) -> int:
+        """
+        Returns the number of training timesteps.
+
+        Returns:
+            `int`:
+                The number of training timesteps configured for the scheduler.
+        """
         return self.config.num_train_timesteps
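
For reference, the `add_noise` line retained as context at the top of this hunk scales the noise by sigma before adding it; a minimal standalone sketch (tensor shapes are illustrative):

    import torch

    original_samples = torch.randn(1, 3, 32, 32)
    noise = torch.randn_like(original_samples)
    sigma = 0.5  # stands in for the sigma matched to the sample's timestep

    # Per the diff: noisy_samples = original_samples + noise * sigma
    noisy_samples = original_samples + noise * sigma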
