Skip to content

Commit fd7db58

Browse files
committed
Add DPMSolverScheduler trait
1 parent 84b8680 commit fd7db58

File tree

3 files changed

+81
-38
lines changed

3 files changed

+81
-38
lines changed

src/schedulers/dpmsolver.rs

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use tch::Tensor;
2+
13
use crate::schedulers::BetaSchedule;
24
use crate::schedulers::PredictionType;
35

@@ -65,3 +67,46 @@ impl Default for DPMSolverSchedulerConfig {
6567
}
6668
}
6769
}
70+
71+
/// Shared interface for the DPM-Solver family of schedulers.
///
/// Implemented by `DPMSolverMultistepScheduler` and `DPMSolverSinglestepScheduler`,
/// which previously exposed these operations as inherent methods under
/// variant-specific names.
///
/// NOTE(review): `new` returns `Self` without a `where Self: Sized` bound, so this
/// trait presumably cannot be used as `dyn DPMSolverScheduler` — confirm whether
/// dynamic dispatch is intended.
pub trait DPMSolverScheduler {
72+
/// Builds a scheduler from the number of inference steps and a configuration.
///
/// * `inference_steps` - number of steps the sampling loop will run
///   (presumably determines the length of `timesteps()` — confirm in the impls)
/// * `config` - solver settings; the implementations read fields such as
///   `beta_schedule`, `beta_start` and `solver_order` from it
fn new(inference_steps: usize, config: DPMSolverSchedulerConfig) -> Self;
73+
/// Converts the raw model output into the form consumed by the update methods.
///
/// The exact conversion is defined by each implementation and is not specified
/// here; presumably it depends on the configured algorithm and prediction
/// type — confirm against the implementations.
///
/// * `model_output` - direct output from the learned diffusion model
/// * `timestep` - current discrete timestep in the diffusion chain
/// * `sample` - current instance of the sample being created by the diffusion process
fn convert_model_output(
74+
&self,
75+
model_output: &Tensor,
76+
timestep: usize,
77+
sample: &Tensor,
78+
) -> Tensor;
79+
80+
/// One step for the first-order DPM-Solver (equivalent to DDIM).
/// See https://arxiv.org/abs/2206.00927 for the detailed derivation.
///
/// * `model_output` - model output to integrate; taken by value, unlike the
///   higher-order updates which borrow a list
/// * `timestep` - current discrete timestep in the diffusion chain
/// * `prev_timestep` - previous discrete timestep in the diffusion chain
/// * `sample` - current instance of the sample being created by the diffusion process
fn first_order_update(
81+
&self,
82+
model_output: Tensor,
83+
timestep: usize,
84+
prev_timestep: usize,
85+
sample: &Tensor,
86+
) -> Tensor;
87+
88+
/// One step for the second-order DPM-Solver.
///
/// * `model_output_list` - model outputs the update draws on; both `step`
///   implementations pass their stored `model_outputs`
/// * `timestep_list` - two timesteps; the `step` implementations pass the
///   timestep at the preceding schedule index followed by the current one
/// * `prev_timestep` - previous discrete timestep in the diffusion chain
/// * `sample` - current instance of the sample being created by the diffusion process
fn second_order_update(
89+
&self,
90+
model_output_list: &Vec<Tensor>,
91+
timestep_list: [usize; 2],
92+
prev_timestep: usize,
93+
sample: &Tensor,
94+
) -> Tensor;
95+
96+
/// One step for the third-order DPM-Solver.
///
/// * `model_output_list` - model outputs the update draws on; both `step`
///   implementations pass their stored `model_outputs`
/// * `timestep_list` - three timesteps; the multistep `step` passes the
///   timesteps at the two preceding schedule indices followed by the current one
/// * `prev_timestep` - previous discrete timestep in the diffusion chain
/// * `sample` - current instance of the sample being created by the diffusion process
fn third_order_update(
97+
&self,
98+
model_output_list: &Vec<Tensor>,
99+
timestep_list: [usize; 3],
100+
prev_timestep: usize,
101+
sample: &Tensor,
102+
) -> Tensor;
103+
104+
/// Advances the diffusion process by one step, dispatching to the first-,
/// second- or third-order update as chosen by the implementation.
///
/// Takes `&mut self`; presumably implementations update per-step bookkeeping
/// here — confirm against the implementations.
///
/// The single-step implementation looks `timestep` up in its timestep list and
/// unwraps the result, so `timestep` should be a value taken from `timesteps()`.
///
/// * `model_output` - direct output from the learned diffusion model
/// * `timestep` - current discrete timestep in the diffusion chain
/// * `sample` - current instance of the sample being created by the diffusion process
fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Tensor;
105+
106+
/// Returns the scheduler's discrete timesteps as a slice.
fn timesteps(&self) -> &[usize];
107+
/// Ensures interchangeability with schedulers that need to scale the denoising
/// model input depending on the current timestep.
///
/// The single-step implementation returns `sample` unchanged.
fn scale_model_input(&self, sample: Tensor, timestep: usize) -> Tensor;
108+
109+
110+
/// Mixes `noise` into `original_samples` at the given timestep.
///
/// Both implementations compute
/// `sqrt(alphas_cumprod[timestep]) * original_samples
///  + sqrt(1 - alphas_cumprod[timestep]) * noise`,
/// so `timestep` is used as an index into `alphas_cumprod`.
fn add_noise(&self, original_samples: &Tensor, noise: Tensor, timestep: usize) -> Tensor;
111+
/// Returns the scheduler's stored `init_noise_sigma` value.
fn init_noise_sigma(&self) -> f64;
112+
}

src/schedulers/dpmsolver_multistep.rs

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,10 @@
1-
use super::{betas_for_alpha_bar, BetaSchedule, PredictionType, dpmsolver::{DPMSolverSchedulerConfig, DPMSolverAlgorithmType, DPMSolverType}};
1+
use super::{
2+
betas_for_alpha_bar,
3+
dpmsolver::{
4+
DPMSolverAlgorithmType, DPMSolverScheduler, DPMSolverSchedulerConfig, DPMSolverType,
5+
},
6+
BetaSchedule, PredictionType,
7+
};
28
use tch::{kind, Kind, Tensor};
39

410
pub struct DPMSolverMultistepScheduler {
@@ -15,8 +21,8 @@ pub struct DPMSolverMultistepScheduler {
1521
pub config: DPMSolverSchedulerConfig,
1622
}
1723

18-
impl DPMSolverMultistepScheduler {
19-
pub fn new(inference_steps: usize, config: DPMSolverSchedulerConfig) -> Self {
24+
impl DPMSolverScheduler for DPMSolverMultistepScheduler {
25+
fn new(inference_steps: usize, config: DPMSolverSchedulerConfig) -> Self {
2026
let betas = match config.beta_schedule {
2127
BetaSchedule::ScaledLinear => Tensor::linspace(
2228
config.beta_start.sqrt(),
@@ -116,7 +122,7 @@ impl DPMSolverMultistepScheduler {
116122

117123
/// One step for the first-order DPM-Solver (equivalent to DDIM).
118124
/// See https://arxiv.org/abs/2206.00927 for the detailed derivation.
119-
fn dpm_solver_first_order_update(
125+
fn first_order_update(
120126
&self,
121127
model_output: Tensor,
122128
timestep: usize,
@@ -138,7 +144,7 @@ impl DPMSolverMultistepScheduler {
138144
}
139145

140146
/// One step for the second-order multistep DPM-Solver.
141-
fn multistep_dpm_solver_second_order_update(
147+
fn second_order_update(
142148
&self,
143149
model_output_list: &Vec<Tensor>,
144150
timestep_list: [usize; 2],
@@ -191,7 +197,7 @@ impl DPMSolverMultistepScheduler {
191197
}
192198

193199
/// One step for the third-order multistep DPM-Solver
194-
fn multistep_dpm_solver_third_order_update(
200+
fn third_order_update(
195201
&self,
196202
model_output_list: &Vec<Tensor>,
197203
timestep_list: [usize; 3],
@@ -236,7 +242,7 @@ impl DPMSolverMultistepScheduler {
236242
}
237243
}
238244

239-
pub fn timesteps(&self) -> &[usize] {
245+
fn timesteps(&self) -> &[usize] {
240246
self.timesteps.as_slice()
241247
}
242248

@@ -271,24 +277,14 @@ impl DPMSolverMultistepScheduler {
271277
|| self.lower_order_nums < 1
272278
|| lower_order_final
273279
{
274-
self.dpm_solver_first_order_update(model_output, timestep, prev_timestep, sample)
280+
self.first_order_update(model_output, timestep, prev_timestep, sample)
275281
} else if self.config.solver_order == 2 || self.lower_order_nums < 2 || lower_order_second {
276282
let timestep_list = [self.timesteps[step_index - 1], timestep];
277-
self.multistep_dpm_solver_second_order_update(
278-
&self.model_outputs,
279-
timestep_list,
280-
prev_timestep,
281-
sample,
282-
)
283+
self.second_order_update(&self.model_outputs, timestep_list, prev_timestep, sample)
283284
} else {
284285
let timestep_list =
285286
[self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep];
286-
self.multistep_dpm_solver_third_order_update(
287-
&self.model_outputs,
288-
timestep_list,
289-
prev_timestep,
290-
sample,
291-
)
287+
self.third_order_update(&self.model_outputs, timestep_list, prev_timestep, sample)
292288
};
293289

294290
if self.lower_order_nums < self.config.solver_order {
@@ -298,12 +294,12 @@ impl DPMSolverMultistepScheduler {
298294
prev_sample
299295
}
300296

301-
pub fn add_noise(&self, original_samples: &Tensor, noise: Tensor, timestep: usize) -> Tensor {
297+
fn add_noise(&self, original_samples: &Tensor, noise: Tensor, timestep: usize) -> Tensor {
302298
self.alphas_cumprod[timestep].sqrt() * original_samples.to_owned()
303299
+ (1.0 - self.alphas_cumprod[timestep]).sqrt() * noise
304300
}
305301

306-
pub fn init_noise_sigma(&self) -> f64 {
302+
fn init_noise_sigma(&self) -> f64 {
307303
self.init_noise_sigma
308304
}
309305
}

src/schedulers/dpmsolver_singlestep.rs

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@ use std::iter::repeat;
22

33
use super::{
44
betas_for_alpha_bar,
5-
dpmsolver::{DPMSolverAlgorithmType, DPMSolverSchedulerConfig, DPMSolverType},
5+
dpmsolver::{
6+
DPMSolverAlgorithmType, DPMSolverScheduler, DPMSolverSchedulerConfig, DPMSolverType,
7+
},
68
BetaSchedule, PredictionType,
79
};
810
use tch::{kind, Kind, Tensor};
@@ -23,8 +25,8 @@ pub struct DPMSolverSinglestepScheduler {
2325
pub config: DPMSolverSchedulerConfig,
2426
}
2527

26-
impl DPMSolverSinglestepScheduler {
27-
pub fn new(inference_steps: usize, config: DPMSolverSchedulerConfig) -> Self {
28+
impl DPMSolverScheduler for DPMSolverSinglestepScheduler {
29+
fn new(inference_steps: usize, config: DPMSolverSchedulerConfig) -> Self {
2830
let betas = match config.beta_schedule {
2931
BetaSchedule::ScaledLinear => Tensor::linspace(
3032
config.beta_start.sqrt(),
@@ -141,9 +143,9 @@ impl DPMSolverSinglestepScheduler {
141143
/// * `timestep` - current discrete timestep in the diffusion chain
142144
/// * `prev_timestep` - previous discrete timestep in the diffusion chain
143145
/// * `sample` - current instance of sample being created by diffusion process
144-
fn dpm_solver_first_order_update(
146+
fn first_order_update(
145147
&self,
146-
model_output: &Tensor,
148+
model_output: Tensor,
147149
timestep: usize,
148150
prev_timestep: usize,
149151
sample: &Tensor,
@@ -171,7 +173,7 @@ impl DPMSolverSinglestepScheduler {
171173
/// * `timestep_list` - current and latter discrete timestep in the diffusion chain
172174
/// * `prev_timestep` - previous discrete timestep in the diffusion chain
173175
/// * `sample` - current instance of sample being created by diffusion process
174-
fn singlestep_dpm_solver_second_order_update(
176+
fn second_order_update(
175177
&self,
176178
model_output_list: &Vec<Tensor>,
177179
timestep_list: [usize; 2],
@@ -232,7 +234,7 @@ impl DPMSolverSinglestepScheduler {
232234
/// * `timestep_list` - current and latter discrete timestep in the diffusion chain
233235
/// * `prev_timestep` - previous discrete timestep in the diffusion chain
234236
/// * `sample` - current instance of sample being created by diffusion process
235-
fn singlestep_dpm_solver_third_order_update(
237+
fn third_order_update(
236238
&self,
237239
model_output_list: &Vec<Tensor>,
238240
timestep_list: [usize; 3],
@@ -290,13 +292,13 @@ impl DPMSolverSinglestepScheduler {
290292
}
291293
}
292294

293-
pub fn timesteps(&self) -> &[usize] {
295+
fn timesteps(&self) -> &[usize] {
294296
self.timesteps.as_slice()
295297
}
296298

297299
/// Ensures interchangeability with schedulers that need to scale the denoising model input
298300
/// depending on the current timestep.
299-
pub fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Tensor {
301+
fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Tensor {
300302
sample
301303
}
302304

@@ -307,7 +309,7 @@ impl DPMSolverSinglestepScheduler {
307309
/// * `model_output` - direct output from learned diffusion model
308310
/// * `timestep` - current discrete timestep in the diffusion chain
309311
/// * `sample` - current instance of sample being created by diffusion process
310-
pub fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Tensor {
312+
fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Tensor {
311313
// https://github.com/huggingface/diffusers/blob/e4fe9413121b78c4c1f109b50f0f3cc1c320a1a2/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py#L535
312314
let step_index: usize = self.timesteps.iter().position(|&t| t == timestep).unwrap();
313315

@@ -329,19 +331,19 @@ impl DPMSolverSinglestepScheduler {
329331
};
330332

331333
match order {
332-
1 => self.dpm_solver_first_order_update(
333-
&self.model_outputs[self.model_outputs.len() - 1],
334+
1 => self.first_order_update(
335+
model_output,
334336
timestep,
335337
prev_timestep,
336338
&self.sample.as_ref().unwrap(),
337339
),
338-
2 => self.singlestep_dpm_solver_second_order_update(
340+
2 => self.second_order_update(
339341
&self.model_outputs,
340342
[self.timesteps[step_index - 1], self.timesteps[step_index]],
341343
prev_timestep,
342344
&self.sample.as_ref().unwrap(),
343345
),
344-
3 => self.singlestep_dpm_solver_third_order_update(
346+
3 => self.third_order_update(
345347
&self.model_outputs,
346348
[
347349
self.timesteps[step_index - 2],
@@ -357,12 +359,12 @@ impl DPMSolverSinglestepScheduler {
357359
}
358360
}
359361

360-
pub fn add_noise(&self, original_samples: &Tensor, noise: Tensor, timestep: usize) -> Tensor {
362+
fn add_noise(&self, original_samples: &Tensor, noise: Tensor, timestep: usize) -> Tensor {
361363
self.alphas_cumprod[timestep].sqrt() * original_samples.to_owned()
362364
+ (1.0 - self.alphas_cumprod[timestep]).sqrt() * noise
363365
}
364366

365-
pub fn init_noise_sigma(&self) -> f64 {
367+
fn init_noise_sigma(&self) -> f64 {
366368
self.init_noise_sigma
367369
}
368370
}

0 commit comments

Comments
 (0)