77 changes: 58 additions & 19 deletions src/diffusers/schedulers/scheduling_consistency_models.py
@@ -83,7 +83,7 @@ def __init__(
         s_noise: float = 1.0,
         rho: float = 7.0,
         clip_denoised: bool = True,
-    ):
+    ) -> None:
         # standard deviation of the initial noise distribution
         self.init_noise_sigma = sigma_max
 
@@ -102,21 +102,29 @@ def __init__(
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
 
     @property
-    def step_index(self):
+    def step_index(self) -> Optional[int]:
         """
         The index counter for the current timestep. It will increase by 1 after each scheduler step.
+
+        Returns:
+            `int` or `None`:
+                The current step index, or `None` if not yet initialized.
         """
         return self._step_index
 
     @property
-    def begin_index(self):
+    def begin_index(self) -> Optional[int]:
         """
         The index for the first timestep. It should be set from the pipeline with the `set_begin_index` method.
+
+        Returns:
+            `int` or `None`:
+                The begin index, or `None` if not yet set.
         """
         return self._begin_index
 
     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
-    def set_begin_index(self, begin_index: int = 0):
+    def set_begin_index(self, begin_index: int = 0) -> None:
         """
         Sets the begin index for the scheduler. This function should be run from the pipeline before inference.
 
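The `Optional[int]` annotations above make the pre-initialization state of both properties explicit. As a quick illustration, a minimal usage sketch (not part of this PR; it assumes default construction of `CMStochasticIterativeScheduler`):

    from diffusers import CMStochasticIterativeScheduler

    scheduler = CMStochasticIterativeScheduler()
    scheduler.set_timesteps(num_inference_steps=2)

    assert scheduler.step_index is None  # stays None until the first step() call
    scheduler.set_begin_index(0)         # pipelines call this before inference
    assert scheduler.begin_index == 0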
@@ -151,7 +159,7 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
         self.is_scale_input_called = True
         return sample
 
-    def sigma_to_t(self, sigmas: Union[float, np.ndarray]):
+    def sigma_to_t(self, sigmas: Union[float, np.ndarray]) -> np.ndarray:
         """
         Gets scaled timesteps from the Karras sigmas for input to the consistency model.
 
@@ -160,8 +168,8 @@ def sigma_to_t(self, sigmas: Union[float, np.ndarray]):
                 A single Karras sigma or an array of Karras sigmas.
 
         Returns:
-            `float` or `np.ndarray`:
-                A scaled input timestep or scaled input timestep array.
+            `np.ndarray`:
+                A scaled input timestep array.
         """
         if not isinstance(sigmas, np.ndarray):
             sigmas = np.array(sigmas, dtype=np.float64)
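With the narrowed return annotation, array-in/array-out is the documented contract, and scalar inputs are first coerced through `np.array` as the branch above shows. A short sketch, assuming the elided body applies an elementwise mapping to the sigmas:

    import numpy as np
    from diffusers import CMStochasticIterativeScheduler

    scheduler = CMStochasticIterativeScheduler()
    sigmas = np.array([80.0, 0.5, 0.002])  # Karras sigmas, high to low noise
    t = scheduler.sigma_to_t(sigmas)       # scaled timesteps, same shape as the input
    assert t.shape == sigmas.shape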
@@ -173,14 +181,14 @@ def sigma_to_t(self, sigmas: Union[float, np.ndarray]):
     def set_timesteps(
         self,
         num_inference_steps: Optional[int] = None,
-        device: Union[str, torch.device] = None,
+        device: Optional[Union[str, torch.device]] = None,
         timesteps: Optional[List[int]] = None,
-    ):
+    ) -> None:
         """
         Sets the timesteps used for the diffusion chain (to be run before inference).
 
         Args:
-            num_inference_steps (`int`):
+            num_inference_steps (`int`, *optional*):
                 The number of diffusion steps used when generating samples with a pre-trained model.
             device (`str` or `torch.device`, *optional*):
                 The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
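The `Optional[...]` wrapper brings the `device` annotation in line with its `*optional*` docs, and `num_inference_steps` is now documented as optional because explicit `timesteps` can be passed instead. A sketch of both call styles; the `[22, 0]` values mirror the multistep example from the consistency models documentation and are an assumption here:

    from diffusers import CMStochasticIterativeScheduler

    scheduler = CMStochasticIterativeScheduler()
    scheduler.set_timesteps(num_inference_steps=40)           # evenly spaced schedule
    scheduler.set_timesteps(timesteps=[22, 0], device="cpu")  # or explicit custom timesteps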
@@ -244,9 +252,19 @@ def set_timesteps(
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
 
     # Modified _convert_to_karras implementation that takes in ramp as argument
-    def _convert_to_karras(self, ramp):
-        """Constructs the noise schedule of Karras et al. (2022)."""
+    def _convert_to_karras(self, ramp: np.ndarray) -> np.ndarray:
+        """
+        Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
+        Models](https://huggingface.co/papers/2206.00364).
+
+        Args:
+            ramp (`np.ndarray`):
+                A ramp array of values between 0 and 1 used to interpolate between `sigma_min` and `sigma_max`.
+
+        Returns:
+            `np.ndarray`:
+                The Karras sigma schedule array.
+        """
         sigma_min: float = self.config.sigma_min
         sigma_max: float = self.config.sigma_max
 
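For reference, a standalone NumPy sketch of the schedule this method builds. The `min_inv_rho`/`max_inv_rho` lines fall in the span elided between the hunks, so their reconstruction here is an assumption; the defaults match the scheduler config (`sigma_min=0.002`, `sigma_max=80.0`, `rho=7.0`):

    import numpy as np

    def karras_sigmas(ramp, sigma_min=0.002, sigma_max=80.0, rho=7.0):
        # Interpolate in sigma**(1/rho) space, then raise back to the rho-th power.
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

    sigmas = karras_sigmas(np.linspace(0, 1, 10))  # decreases from sigma_max to sigma_min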
@@ -256,14 +274,25 @@ def _convert_to_karras(self, ramp):
         sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
         return sigmas
 
-    def get_scalings(self, sigma):
+    def get_scalings(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Computes the scaling factors for the consistency model output.
+
+        Args:
+            sigma (`torch.Tensor`):
+                The current sigma value in the noise schedule.
+
+        Returns:
+            `Tuple[torch.Tensor, torch.Tensor]`:
+                A tuple containing `c_skip` (scaling for the input sample) and `c_out` (scaling for the model output).
+        """
         sigma_data = self.config.sigma_data
 
         c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
         c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
         return c_skip, c_out
 
-    def get_scalings_for_boundary_condition(self, sigma):
+    def get_scalings_for_boundary_condition(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Gets the scalings used in the consistency model parameterization (from Appendix C of the
         [paper](https://huggingface.co/papers/2303.01469)) to enforce the boundary condition.
@@ -275,7 +304,7 @@ def get_scalings_for_boundary_condition(self, sigma):
                 The current sigma in the Karras sigma schedule.
 
         Returns:
-            `tuple`:
+            `Tuple[torch.Tensor, torch.Tensor]`:
                 A two-element tuple where `c_skip` (which weights the current sample) is the first element and `c_out`
                 (which weights the consistency model output) is the second element.
         """
@@ -348,13 +377,13 @@ def step(
         Args:
             model_output (`torch.Tensor`):
                 The direct output from the learned diffusion model.
-            timestep (`float`):
+            timestep (`float` or `torch.Tensor`):
                 The current timestep in the diffusion chain.
             sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
             generator (`torch.Generator`, *optional*):
                 A random number generator.
-            return_dict (`bool`, *optional*, defaults to `True`):
+            return_dict (`bool`, defaults to `True`):
                 Whether or not to return a
                 [`~schedulers.scheduling_consistency_models.CMStochasticIterativeSchedulerOutput`] or `tuple`.
@@ -406,7 +435,10 @@ def step(
         # Noise is not used for onestep sampling.
         if len(self.timesteps) > 1:
             noise = randn_tensor(
-                model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
+                model_output.shape,
+                dtype=model_output.dtype,
+                device=model_output.device,
+                generator=generator,
             )
         else:
             noise = torch.zeros_like(model_output)
@@ -475,5 +507,12 @@ def add_noise(
         noisy_samples = original_samples + noise * sigma
         return noisy_samples
 
-    def __len__(self):
+    def __len__(self) -> int:
+        """
+        Returns the number of training timesteps.
+
+        Returns:
+            `int`:
+                The number of training timesteps configured for the scheduler.
+        """
         return self.config.num_train_timesteps
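Closing out the diff, a sketch of the documented `add_noise` arithmetic (`original_samples + noise * sigma`) and the `__len__` contract; the tensor shapes are illustrative assumptions:

    import torch
    from diffusers import CMStochasticIterativeScheduler

    scheduler = CMStochasticIterativeScheduler()
    scheduler.set_timesteps(num_inference_steps=2)

    clean = torch.randn(1, 3, 64, 64)
    noise = torch.randn(1, 3, 64, 64)
    noisy = scheduler.add_noise(clean, noise, scheduler.timesteps[:1])  # clean + noise * sigma_t
    assert len(scheduler) == scheduler.config.num_train_timesteps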