diff --git a/src/diffusers/schedulers/scheduling_consistency_models.py b/src/diffusers/schedulers/scheduling_consistency_models.py
index 386a43db0f9c..195ff81b4c91 100644
--- a/src/diffusers/schedulers/scheduling_consistency_models.py
+++ b/src/diffusers/schedulers/scheduling_consistency_models.py
@@ -83,7 +83,7 @@ def __init__(
         s_noise: float = 1.0,
         rho: float = 7.0,
         clip_denoised: bool = True,
-    ):
+    ) -> None:
         # standard deviation of the initial noise distribution
         self.init_noise_sigma = sigma_max
 
@@ -102,21 +102,29 @@ def __init__(
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
 
     @property
-    def step_index(self):
+    def step_index(self) -> Optional[int]:
         """
         The index counter for current timestep. It will increase 1 after each scheduler step.
+
+        Returns:
+            `int` or `None`:
+                The current step index, or `None` if not yet initialized.
         """
         return self._step_index
 
     @property
-    def begin_index(self):
+    def begin_index(self) -> Optional[int]:
         """
         The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+
+        Returns:
+            `int` or `None`:
+                The begin index, or `None` if not yet set.
         """
         return self._begin_index
 
     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
-    def set_begin_index(self, begin_index: int = 0):
+    def set_begin_index(self, begin_index: int = 0) -> None:
         """
         Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
 
@@ -151,7 +159,7 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.T
         self.is_scale_input_called = True
         return sample
 
-    def sigma_to_t(self, sigmas: Union[float, np.ndarray]):
+    def sigma_to_t(self, sigmas: Union[float, np.ndarray]) -> np.ndarray:
         """
         Gets scaled timesteps from the Karras sigmas for input to the consistency model.
 
@@ -160,8 +168,8 @@ def sigma_to_t(self, sigmas: Union[float, np.ndarray]):
                 A single Karras sigma or an array of Karras sigmas.
 
         Returns:
-            `float` or `np.ndarray`:
-                A scaled input timestep or scaled input timestep array.
+            `np.ndarray`:
+                A scaled input timestep array.
         """
         if not isinstance(sigmas, np.ndarray):
             sigmas = np.array(sigmas, dtype=np.float64)
@@ -173,14 +181,14 @@ def sigma_to_t(self, sigmas: Union[float, np.ndarray]):
     def set_timesteps(
         self,
         num_inference_steps: Optional[int] = None,
-        device: Union[str, torch.device] = None,
+        device: Optional[Union[str, torch.device]] = None,
         timesteps: Optional[List[int]] = None,
-    ):
+    ) -> None:
         """
         Sets the timesteps used for the diffusion chain (to be run before inference).
 
         Args:
-            num_inference_steps (`int`):
+            num_inference_steps (`int`, *optional*):
                 The number of diffusion steps used when generating samples with a pre-trained model.
             device (`str` or `torch.device`, *optional*):
                 The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
@@ -244,9 +252,19 @@ def set_timesteps(
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
 
     # Modified _convert_to_karras implementation that takes in ramp as argument
-    def _convert_to_karras(self, ramp):
-        """Constructs the noise schedule of Karras et al. (2022)."""
+    def _convert_to_karras(self, ramp: np.ndarray) -> np.ndarray:
+        """
+        Constructs the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
+        Models](https://huggingface.co/papers/2206.00364).
+
+        Args:
+            ramp (`np.ndarray`):
+                A ramp array of values between 0 and 1 used to interpolate between `sigma_min` and `sigma_max`.
+
+        Returns:
+            `np.ndarray`:
+                The Karras sigma schedule array.
+        """
         sigma_min: float = self.config.sigma_min
         sigma_max: float = self.config.sigma_max
 
@@ -256,14 +274,25 @@ def _convert_to_karras(self, ramp):
         sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
         return sigmas
 
-    def get_scalings(self, sigma):
+    def get_scalings(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Computes the scaling factors for the consistency model output.
+
+        Args:
+            sigma (`torch.Tensor`):
+                The current sigma value in the noise schedule.
+
+        Returns:
+            `Tuple[torch.Tensor, torch.Tensor]`:
+                A tuple containing `c_skip` (scaling for the input sample) and `c_out` (scaling for the model output).
+        """
         sigma_data = self.config.sigma_data
 
         c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
         c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
         return c_skip, c_out
 
-    def get_scalings_for_boundary_condition(self, sigma):
+    def get_scalings_for_boundary_condition(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Gets the scalings used in the consistency model parameterization (from Appendix C of the
         [paper](https://huggingface.co/papers/2303.01469)) to enforce boundary condition.
@@ -275,7 +304,7 @@ def get_scalings_for_boundary_condition(self, sigma):
                 The current sigma in the Karras sigma schedule.
 
         Returns:
-            `tuple`:
+            `Tuple[torch.Tensor, torch.Tensor]`:
                 A two-element tuple where `c_skip` (which weights the current sample) is the first element and `c_out`
                 (which weights the consistency model output) is the second element.
         """
@@ -348,13 +377,13 @@ def step(
         Args:
             model_output (`torch.Tensor`):
                 The direct output from the learned diffusion model.
-            timestep (`float`):
+            timestep (`float` or `torch.Tensor`):
                 The current timestep in the diffusion chain.
             sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
             generator (`torch.Generator`, *optional*):
                 A random number generator.
-            return_dict (`bool`, *optional*, defaults to `True`):
+            return_dict (`bool`, defaults to `True`):
                 Whether or not to return a
                 [`~schedulers.scheduling_consistency_models.CMStochasticIterativeSchedulerOutput`] or `tuple`.
 
@@ -406,7 +435,10 @@ def step(
         # Noise is not used for onestep sampling.
         if len(self.timesteps) > 1:
             noise = randn_tensor(
-                model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
+                model_output.shape,
+                dtype=model_output.dtype,
+                device=model_output.device,
+                generator=generator,
             )
         else:
             noise = torch.zeros_like(model_output)
@@ -475,5 +507,12 @@ def add_noise(
         noisy_samples = original_samples + noise * sigma
         return noisy_samples
 
-    def __len__(self):
+    def __len__(self) -> int:
+        """
+        Returns the number of training timesteps.
+
+        Returns:
+            `int`:
+                The number of training timesteps configured for the scheduler.
+        """
         return self.config.num_train_timesteps
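
As a quick sanity check on the formulas documented above, here is a minimal standalone NumPy sketch of the schedule built by `_convert_to_karras` and the scalings computed by `get_scalings`. The values `sigma_min=0.002`, `sigma_max=80.0`, `sigma_data=0.5`, and the 40-step ramp are assumed for illustration only (of these, just `rho=7.0` is visible in this diff); the scheduler itself reads them from its config.

```python
import numpy as np

# Assumed illustrative config values; the scheduler reads these from self.config.
sigma_min, sigma_max, sigma_data, rho = 0.002, 80.0, 0.5, 7.0

# Karras et al. (2022) schedule, matching the last line of `_convert_to_karras`:
# interpolate between sigma_max and sigma_min in rho-th-root space.
ramp = np.linspace(0, 1, 40)
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

# Scalings as in `get_scalings`: c_skip weights the input sample,
# c_out weights the model output.
c_skip = sigma_data**2 / (sigmas**2 + sigma_data**2)
c_out = sigmas * sigma_data / (sigmas**2 + sigma_data**2) ** 0.5

# The schedule runs from sigma_max down to sigma_min.
print(sigmas[0], sigmas[-1])  # 80.0 ... 0.002
```

Note that at large sigma, `c_skip` approaches 0 and `c_out` approaches `sigma_data`, so the update is dominated by the model output; near `sigma_min`, `c_skip` approaches 1 and the sample passes through mostly unchanged.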