54 changes: 44 additions & 10 deletions src/diffusers/schedulers/scheduling_consistency_decoder.py
@@ -71,14 +71,30 @@ class ConsistencyDecoderSchedulerOutput(BaseOutput):


class ConsistencyDecoderScheduler(SchedulerMixin, ConfigMixin):
"""
A scheduler for the consistency decoder used in Stable Diffusion pipelines.

This scheduler implements a two-step denoising process using consistency models for decoding latent representations
into images.

This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
methods the library implements for all schedulers such as loading and saving.

Args:
num_train_timesteps (`int`, *optional*, defaults to `1024`):
The number of diffusion steps to train the model.
sigma_data (`float`, *optional*, defaults to `0.5`):
The standard deviation of the data distribution. Used for computing the skip and output scaling factors.
"""

order = 1

@register_to_config
def __init__(
self,
num_train_timesteps: int = 1024,
sigma_data: float = 0.5,
):
) -> None:
betas = betas_for_alpha_bar(num_train_timesteps)

alphas = 1.0 - betas
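
As a quick illustration of the defaults documented in the new class docstring, here is a minimal sketch (not part of this PR) that constructs the scheduler from the module this diff touches and reads the values registered via `@register_to_config`:

```python
# Minimal sketch (not part of this PR), assuming only what the new docstring states:
# the registered defaults are 1024 training timesteps and sigma_data = 0.5.
from diffusers.schedulers.scheduling_consistency_decoder import ConsistencyDecoderScheduler

scheduler = ConsistencyDecoderScheduler()  # same as num_train_timesteps=1024, sigma_data=0.5
print(scheduler.config.num_train_timesteps)  # 1024
print(scheduler.config.sigma_data)           # 0.5
```
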
@@ -98,8 +114,18 @@ def __init__(
def set_timesteps(
self,
num_inference_steps: Optional[int] = None,
device: Union[str, torch.device] = None,
):
device: Optional[Union[str, torch.device]] = None,
) -> None:
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).

Args:
num_inference_steps (`int`, *optional*):
The number of diffusion steps used when generating samples with a pre-trained model. Currently, only
`2` inference steps are supported.
device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
"""
if num_inference_steps != 2:
            raise ValueError("Currently only 2 inference steps are supported.")
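
A short usage sketch (not from the diff) of the documented two-step restriction; the device string is illustrative:

```python
# Sketch: set_timesteps accepts exactly 2 steps and can optionally move the
# timesteps and the precomputed scaling tensors to a device.
from diffusers.schedulers.scheduling_consistency_decoder import ConsistencyDecoderScheduler

scheduler = ConsistencyDecoderScheduler()
scheduler.set_timesteps(num_inference_steps=2, device="cpu")
print(scheduler.timesteps)  # the two discrete timesteps used for decoding

try:
    scheduler.set_timesteps(num_inference_steps=3)  # anything other than 2 raises
except ValueError as err:
    print(err)
```
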

@@ -111,7 +137,15 @@ def set_timesteps(
self.c_in = self.c_in.to(device)

@property
def init_noise_sigma(self):
def init_noise_sigma(self) -> torch.Tensor:
"""
Return the standard deviation of the initial noise distribution.

Returns:
`torch.Tensor`:
The initial noise sigma value from the precomputed `sqrt_one_minus_alphas_cumprod` at the first
timestep.
"""
return self.sqrt_one_minus_alphas_cumprod[self.timesteps[0]]

def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
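
The `init_noise_sigma` property and `scale_model_input` documented above are typically used to prepare the initial sample before it is passed to the decoder model. A hedged sketch follows; the sample shape is an assumption, not taken from this PR:

```python
# Sketch: scale initial Gaussian noise by init_noise_sigma, then apply the
# per-timestep input scaling before the sample is passed to the decoder model.
import torch

from diffusers.schedulers.scheduling_consistency_decoder import ConsistencyDecoderScheduler

scheduler = ConsistencyDecoderScheduler()
scheduler.set_timesteps(2)

sample = torch.randn(1, 3, 64, 64) * scheduler.init_noise_sigma  # shape is illustrative
t = scheduler.timesteps[0]
model_input = scheduler.scale_model_input(sample, t)  # rescales using the precomputed c_in
```
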
@@ -146,20 +180,20 @@ def step(
Args:
model_output (`torch.Tensor`):
The direct output from the learned diffusion model.
timestep (`float`):
timestep (`float` or `torch.Tensor`):
The current timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
generator (`torch.Generator`, *optional*):
A random number generator.
A random number generator for reproducibility.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a
[`~schedulers.scheduling_consistency_models.ConsistencyDecoderSchedulerOutput`] or `tuple`.
[`~schedulers.scheduling_consistency_decoder.ConsistencyDecoderSchedulerOutput`] or `tuple`.

Returns:
[`~schedulers.scheduling_consistency_models.ConsistencyDecoderSchedulerOutput`] or `tuple`:
If return_dict is `True`,
[`~schedulers.scheduling_consistency_models.ConsistencyDecoderSchedulerOutput`] is returned, otherwise
[`~schedulers.scheduling_consistency_decoder.ConsistencyDecoderSchedulerOutput`] or `tuple`:
If `return_dict` is `True`,
[`~schedulers.scheduling_consistency_decoder.ConsistencyDecoderSchedulerOutput`] is returned, otherwise
a tuple is returned where the first element is the sample tensor.
"""
x_0 = self.c_out[timestep] * model_output + self.c_skip[timestep] * sample
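
To tie the pieces together, here is a sketch of the two-step decoding loop implied by `step`; the `decoder_model` callable and the tensor shape are stand-ins (assumptions), not code from this PR:

```python
# Sketch of the two-step loop around `step`: each iteration forms the consistency
# prediction x_0 = c_out[t] * model_output + c_skip[t] * sample (the line above),
# and `step` re-noises between the two timesteps.
import torch

from diffusers.schedulers.scheduling_consistency_decoder import ConsistencyDecoderScheduler


def decoder_model(sample, timestep):
    # Placeholder for the learned consistency decoder.
    return torch.zeros_like(sample)


scheduler = ConsistencyDecoderScheduler()
scheduler.set_timesteps(2)
generator = torch.Generator().manual_seed(0)

sample = torch.randn(1, 3, 64, 64, generator=generator) * scheduler.init_noise_sigma
for t in scheduler.timesteps:
    model_output = decoder_model(scheduler.scale_model_input(sample, t), t)
    sample = scheduler.step(model_output, t, sample, generator=generator).prev_sample
```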