Skip to content

Commit 9a2d1e8

Browse files
committed
default lower_order_final is true
1 parent 74e33bc commit 9a2d1e8

File tree

1 file changed

+15
-9
lines changed

1 file changed

+15
-9
lines changed

src/schedulers/dpmsolver_singlestep.rs

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ impl DPMSolverSinglestepScheduler {
4949
let lambda_t = alpha_t.log() - sigma_t.log();
5050

5151
let step = (config.train_timesteps - 1) as f64 / inference_steps as f64;
52-
// https://github.com/huggingface/diffusers/blob/e4fe9413121b78c4c1f109b50f0f3cc1c320a1a2/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py#L199-L204
52+
// https://github.com/huggingface/diffusers/blob/e4fe9413121b78c4c1f109b50f0f3cc1c320a1a2/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py#L172-L173
5353
let timesteps: Vec<usize> = (0..inference_steps + 1)
5454
.map(|i| (i as f64 * step).round() as usize)
5555
// discards the 0.0 element
@@ -62,13 +62,15 @@ impl DPMSolverSinglestepScheduler {
6262
model_outputs.push(Tensor::new());
6363
}
6464

65+
let order_list = get_order_list(inference_steps, config.solver_order, true);
66+
6567
Self {
6668
alphas_cumprod: Vec::<f64>::from(alphas_cumprod),
6769
alpha_t: Vec::<f64>::from(alpha_t),
6870
sigma_t: Vec::<f64>::from(sigma_t),
6971
lambda_t: Vec::<f64>::from(lambda_t),
7072
init_noise_sigma: 1.,
71-
order_list: get_order_list(inference_steps, config.solver_order, false),
73+
order_list,
7274
model_outputs,
7375
timesteps,
7476
config,
@@ -292,6 +294,12 @@ impl DPMSolverSinglestepScheduler {
292294
self.timesteps.as_slice()
293295
}
294296

297+
/// Ensures interchangeability with schedulers that need to scale the denoising model input
298+
/// depending on the current timestep.
299+
pub fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Tensor {
300+
sample
301+
}
302+
295303
/// Step function propagating the sample with the singlestep DPM-Solver
296304
///
297305
/// # Arguments
@@ -311,7 +319,7 @@ impl DPMSolverSinglestepScheduler {
311319
self.model_outputs[i] = self.model_outputs[i + 1].shallow_clone();
312320
}
313321
let m = self.model_outputs.len();
314-
self.model_outputs[m - 1] = model_output;
322+
self.model_outputs[m - 1] = model_output.shallow_clone();
315323

316324
let order = self.order_list[step_index];
317325

@@ -320,7 +328,7 @@ impl DPMSolverSinglestepScheduler {
320328
self.sample = Some(sample.shallow_clone());
321329
};
322330

323-
let prev_sample = match order {
331+
match order {
324332
1 => self.dpm_solver_first_order_update(
325333
&self.model_outputs[self.model_outputs.len() - 1],
326334
timestep,
@@ -331,7 +339,7 @@ impl DPMSolverSinglestepScheduler {
331339
&self.model_outputs,
332340
[self.timesteps[step_index - 1], self.timesteps[step_index]],
333341
prev_timestep,
334-
self.sample.as_ref().unwrap(),
342+
&self.sample.as_ref().unwrap(),
335343
),
336344
3 => self.singlestep_dpm_solver_third_order_update(
337345
&self.model_outputs,
@@ -341,14 +349,12 @@ impl DPMSolverSinglestepScheduler {
341349
self.timesteps[step_index],
342350
],
343351
prev_timestep,
344-
self.sample.as_ref().unwrap(),
352+
&self.sample.as_ref().unwrap(),
345353
),
346354
_ => {
347355
panic!("invalid order");
348356
}
349-
};
350-
351-
prev_sample
357+
}
352358
}
353359

354360
pub fn add_noise(&self, original_samples: &Tensor, noise: Tensor, timestep: usize) -> Tensor {

0 commit comments

Comments
 (0)