@@ -143,7 +143,20 @@ def set_begin_index(self, begin_index: int = 0):
143143 self ._begin_index = begin_index
144144
145145 # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_inputs
146- def precondition_inputs (self , sample , sigma ):
146+ def precondition_inputs (self , sample : torch .Tensor , sigma : Union [float , torch .Tensor ]) -> torch .Tensor :
147+ """
148+ Precondition the input sample by scaling it according to the EDM formulation.
149+
150+ Args:
151+ sample (`torch.Tensor`):
152+ The input sample tensor to precondition.
153+ sigma (`float` or `torch.Tensor`):
154+ The current sigma (noise level) value.
155+
156+ Returns:
157+ `torch.Tensor`:
158+ The scaled input sample.
159+ """
147160 c_in = self ._get_conditioning_c_in (sigma )
148161 scaled_sample = sample * c_in
149162 return scaled_sample
@@ -155,7 +168,27 @@ def precondition_noise(self, sigma):
155168 return sigma .atan () / math .pi * 2
156169
157170 # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_outputs
158- def precondition_outputs (self , sample , model_output , sigma ):
171+ def precondition_outputs (
172+ self ,
173+ sample : torch .Tensor ,
174+ model_output : torch .Tensor ,
175+ sigma : Union [float , torch .Tensor ],
176+ ) -> torch .Tensor :
177+ """
178+ Precondition the model outputs according to the EDM formulation.
179+
180+ Args:
181+ sample (`torch.Tensor`):
182+ The input sample tensor.
183+ model_output (`torch.Tensor`):
184+ The direct output from the learned diffusion model.
185+ sigma (`float` or `torch.Tensor`):
186+ The current sigma (noise level) value.
187+
188+ Returns:
189+ `torch.Tensor`:
190+ The denoised sample computed by combining the skip connection and output scaling.
191+ """
159192 sigma_data = self .config .sigma_data
160193 c_skip = sigma_data ** 2 / (sigma ** 2 + sigma_data ** 2 )
161194
@@ -173,13 +206,13 @@ def precondition_outputs(self, sample, model_output, sigma):
173206 # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.scale_model_input
174207 def scale_model_input (self , sample : torch .Tensor , timestep : Union [float , torch .Tensor ]) -> torch .Tensor :
175208 """
176- Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
177- current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm .
209+ Scale the denoising model input to match the Euler algorithm. Ensures interchangeability with schedulers that
210+ need to scale the denoising model input depending on the current timestep .
178211
179212 Args:
180213 sample (`torch.Tensor`):
181- The input sample.
182- timestep (`int`, *optional* ):
214+ The input sample tensor .
215+ timestep (`float` or `torch.Tensor` ):
183216 The current timestep in the diffusion chain.
184217
185218 Returns:
@@ -242,8 +275,27 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
242275 self .noise_sampler = None
243276
244277 # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_karras_sigmas
245- def _compute_karras_sigmas (self , ramp , sigma_min = None , sigma_max = None ) -> torch .Tensor :
246- """Constructs the noise schedule of Karras et al. (2022)."""
278+ def _compute_karras_sigmas (
279+ self ,
280+ ramp : torch .Tensor ,
281+ sigma_min : Optional [float ] = None ,
282+ sigma_max : Optional [float ] = None ,
283+ ) -> torch .Tensor :
284+ """
285+ Construct the noise schedule of [Karras et al. (2022)](https://huggingface.co/papers/2206.00364).
286+
287+ Args:
288+ ramp (`torch.Tensor`):
289+ A tensor of values in [0, 1] representing the interpolation positions.
290+ sigma_min (`float`, *optional*):
291+ Minimum sigma value. If `None`, uses `self.config.sigma_min`.
292+ sigma_max (`float`, *optional*):
293+ Maximum sigma value. If `None`, uses `self.config.sigma_max`.
294+
295+ Returns:
296+ `torch.Tensor`:
297+ The computed Karras sigma schedule.
298+ """
247299 sigma_min = sigma_min or self .config .sigma_min
248300 sigma_max = sigma_max or self .config .sigma_max
249301
@@ -254,10 +306,27 @@ def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.
254306 return sigmas
255307
256308 # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_exponential_sigmas
257- def _compute_exponential_sigmas (self , ramp , sigma_min = None , sigma_max = None ) -> torch .Tensor :
258- """Implementation closely follows k-diffusion.
259-
309+ def _compute_exponential_sigmas (
310+ self ,
311+ ramp : torch .Tensor ,
312+ sigma_min : Optional [float ] = None ,
313+ sigma_max : Optional [float ] = None ,
314+ ) -> torch .Tensor :
315+ """
316+ Compute the exponential sigma schedule. Implementation closely follows k-diffusion:
260317 https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26
318+
319+ Args:
320+ ramp (`torch.Tensor`):
321+ A tensor of values representing the interpolation positions.
322+ sigma_min (`float`, *optional*):
323+ Minimum sigma value. If `None`, uses `self.config.sigma_min`.
324+ sigma_max (`float`, *optional*):
325+ Maximum sigma value. If `None`, uses `self.config.sigma_max`.
326+
327+ Returns:
328+ `torch.Tensor`:
329+ The computed exponential sigma schedule.
261330 """
262331 sigma_min = sigma_min or self .config .sigma_min
263332 sigma_max = sigma_max or self .config .sigma_max
@@ -354,7 +423,10 @@ def dpm_solver_first_order_update(
354423 `torch.Tensor`:
355424 The sample tensor at the previous timestep.
356425 """
357- sigma_t , sigma_s = self .sigmas [self .step_index + 1 ], self .sigmas [self .step_index ]
426+ sigma_t , sigma_s = (
427+ self .sigmas [self .step_index + 1 ],
428+ self .sigmas [self .step_index ],
429+ )
358430 alpha_t , sigma_t = self ._sigma_to_alpha_sigma_t (sigma_t )
359431 alpha_s , sigma_s = self ._sigma_to_alpha_sigma_t (sigma_s )
360432 lambda_t = torch .log (alpha_t ) - torch .log (sigma_t )
@@ -540,7 +612,10 @@ def step(
540612 [g .initial_seed () for g in generator ] if isinstance (generator , list ) else generator .initial_seed ()
541613 )
542614 self .noise_sampler = BrownianTreeNoiseSampler (
543- model_output , sigma_min = self .config .sigma_min , sigma_max = self .config .sigma_max , seed = seed
615+ model_output ,
616+ sigma_min = self .config .sigma_min ,
617+ sigma_max = self .config .sigma_max ,
618+ seed = seed ,
544619 )
545620 noise = self .noise_sampler (self .sigmas [self .step_index ], self .sigmas [self .step_index + 1 ]).to (
546621 model_output .device
@@ -612,7 +687,18 @@ def add_noise(
612687 return noisy_samples
613688
614689 # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._get_conditioning_c_in
615- def _get_conditioning_c_in (self , sigma ):
690+ def _get_conditioning_c_in (self , sigma : Union [float , torch .Tensor ]) -> Union [float , torch .Tensor ]:
691+ """
692+ Compute the input conditioning factor for the EDM formulation.
693+
694+ Args:
695+ sigma (`float` or `torch.Tensor`):
696+ The current sigma (noise level) value.
697+
698+ Returns:
699+ `float` or `torch.Tensor`:
700+ The input conditioning factor `c_in`.
701+ """
616702 c_in = 1 / ((sigma ** 2 + self .config .sigma_data ** 2 ) ** 0.5 )
617703 return c_in
618704
0 commit comments