Skip to content

Commit 41e0286

Browse files
committed
Revert "Add DPMSolverScheduler trait"
This reverts commit 83b28b3.
1 parent fd7db58 commit 41e0286

File tree

3 files changed

+39
-81
lines changed

3 files changed

+39
-81
lines changed

src/schedulers/dpmsolver.rs

Lines changed: 0 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
use tch::Tensor;
2-
31
use crate::schedulers::BetaSchedule;
42
use crate::schedulers::PredictionType;
53

@@ -67,46 +65,3 @@ impl Default for DPMSolverSchedulerConfig {
6765
}
6866
}
6967
}
70-
71-
pub trait DPMSolverScheduler {
72-
fn new(inference_steps: usize, config: DPMSolverSchedulerConfig) -> Self;
73-
fn convert_model_output(
74-
&self,
75-
model_output: &Tensor,
76-
timestep: usize,
77-
sample: &Tensor,
78-
) -> Tensor;
79-
80-
fn first_order_update(
81-
&self,
82-
model_output: Tensor,
83-
timestep: usize,
84-
prev_timestep: usize,
85-
sample: &Tensor,
86-
) -> Tensor;
87-
88-
fn second_order_update(
89-
&self,
90-
model_output_list: &Vec<Tensor>,
91-
timestep_list: [usize; 2],
92-
prev_timestep: usize,
93-
sample: &Tensor,
94-
) -> Tensor;
95-
96-
fn third_order_update(
97-
&self,
98-
model_output_list: &Vec<Tensor>,
99-
timestep_list: [usize; 3],
100-
prev_timestep: usize,
101-
sample: &Tensor,
102-
) -> Tensor;
103-
104-
fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Tensor;
105-
106-
fn timesteps(&self) -> &[usize];
107-
fn scale_model_input(&self, sample: Tensor, timestep: usize) -> Tensor;
108-
109-
110-
fn add_noise(&self, original_samples: &Tensor, noise: Tensor, timestep: usize) -> Tensor;
111-
fn init_noise_sigma(&self) -> f64;
112-
}

src/schedulers/dpmsolver_multistep.rs

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,5 @@
1-
use super::{
2-
betas_for_alpha_bar,
3-
dpmsolver::{
4-
DPMSolverAlgorithmType, DPMSolverScheduler, DPMSolverSchedulerConfig, DPMSolverType,
5-
},
6-
BetaSchedule, PredictionType,
7-
};
1+
use super::{betas_for_alpha_bar, BetaSchedule, PredictionType, dpmsolver::{DPMSolverSchedulerConfig, DPMSolverAlgorithmType, DPMSolverType}};
2+
use std::iter;
83
use tch::{kind, Kind, Tensor};
94

105
pub struct DPMSolverMultistepScheduler {
@@ -21,8 +16,8 @@ pub struct DPMSolverMultistepScheduler {
2116
pub config: DPMSolverSchedulerConfig,
2217
}
2318

24-
impl DPMSolverScheduler for DPMSolverMultistepScheduler {
25-
fn new(inference_steps: usize, config: DPMSolverSchedulerConfig) -> Self {
19+
impl DPMSolverMultistepScheduler {
20+
pub fn new(inference_steps: usize, config: DPMSolverSchedulerConfig) -> Self {
2621
let betas = match config.beta_schedule {
2722
BetaSchedule::ScaledLinear => Tensor::linspace(
2823
config.beta_start.sqrt(),
@@ -122,7 +117,7 @@ impl DPMSolverScheduler for DPMSolverMultistepScheduler {
122117

123118
/// One step for the first-order DPM-Solver (equivalent to DDIM).
124119
/// See https://arxiv.org/abs/2206.00927 for the detailed derivation.
125-
fn first_order_update(
120+
fn dpm_solver_first_order_update(
126121
&self,
127122
model_output: Tensor,
128123
timestep: usize,
@@ -144,7 +139,7 @@ impl DPMSolverScheduler for DPMSolverMultistepScheduler {
144139
}
145140

146141
/// One step for the second-order multistep DPM-Solver.
147-
fn second_order_update(
142+
fn multistep_dpm_solver_second_order_update(
148143
&self,
149144
model_output_list: &Vec<Tensor>,
150145
timestep_list: [usize; 2],
@@ -197,7 +192,7 @@ impl DPMSolverScheduler for DPMSolverMultistepScheduler {
197192
}
198193

199194
/// One step for the third-order multistep DPM-Solver
200-
fn third_order_update(
195+
fn multistep_dpm_solver_third_order_update(
201196
&self,
202197
model_output_list: &Vec<Tensor>,
203198
timestep_list: [usize; 3],
@@ -242,7 +237,7 @@ impl DPMSolverScheduler for DPMSolverMultistepScheduler {
242237
}
243238
}
244239

245-
fn timesteps(&self) -> &[usize] {
240+
pub fn timesteps(&self) -> &[usize] {
246241
self.timesteps.as_slice()
247242
}
248243

@@ -277,14 +272,24 @@ impl DPMSolverScheduler for DPMSolverMultistepScheduler {
277272
|| self.lower_order_nums < 1
278273
|| lower_order_final
279274
{
280-
self.first_order_update(model_output, timestep, prev_timestep, sample)
275+
self.dpm_solver_first_order_update(model_output, timestep, prev_timestep, sample)
281276
} else if self.config.solver_order == 2 || self.lower_order_nums < 2 || lower_order_second {
282277
let timestep_list = [self.timesteps[step_index - 1], timestep];
283-
self.second_order_update(&self.model_outputs, timestep_list, prev_timestep, sample)
278+
self.multistep_dpm_solver_second_order_update(
279+
&self.model_outputs,
280+
timestep_list,
281+
prev_timestep,
282+
sample,
283+
)
284284
} else {
285285
let timestep_list =
286286
[self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep];
287-
self.third_order_update(&self.model_outputs, timestep_list, prev_timestep, sample)
287+
self.multistep_dpm_solver_third_order_update(
288+
&self.model_outputs,
289+
timestep_list,
290+
prev_timestep,
291+
sample,
292+
)
288293
};
289294

290295
if self.lower_order_nums < self.config.solver_order {
@@ -294,12 +299,12 @@ impl DPMSolverScheduler for DPMSolverMultistepScheduler {
294299
prev_sample
295300
}
296301

297-
fn add_noise(&self, original_samples: &Tensor, noise: Tensor, timestep: usize) -> Tensor {
302+
pub fn add_noise(&self, original_samples: &Tensor, noise: Tensor, timestep: usize) -> Tensor {
298303
self.alphas_cumprod[timestep].sqrt() * original_samples.to_owned()
299304
+ (1.0 - self.alphas_cumprod[timestep]).sqrt() * noise
300305
}
301306

302-
fn init_noise_sigma(&self) -> f64 {
307+
pub fn init_noise_sigma(&self) -> f64 {
303308
self.init_noise_sigma
304309
}
305310
}

src/schedulers/dpmsolver_singlestep.rs

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,7 @@ use std::iter::repeat;
22

33
use super::{
44
betas_for_alpha_bar,
5-
dpmsolver::{
6-
DPMSolverAlgorithmType, DPMSolverScheduler, DPMSolverSchedulerConfig, DPMSolverType,
7-
},
5+
dpmsolver::{DPMSolverAlgorithmType, DPMSolverSchedulerConfig, DPMSolverType},
86
BetaSchedule, PredictionType,
97
};
108
use tch::{kind, Kind, Tensor};
@@ -25,8 +23,8 @@ pub struct DPMSolverSinglestepScheduler {
2523
pub config: DPMSolverSchedulerConfig,
2624
}
2725

28-
impl DPMSolverScheduler for DPMSolverSinglestepScheduler {
29-
fn new(inference_steps: usize, config: DPMSolverSchedulerConfig) -> Self {
26+
impl DPMSolverSinglestepScheduler {
27+
pub fn new(inference_steps: usize, config: DPMSolverSchedulerConfig) -> Self {
3028
let betas = match config.beta_schedule {
3129
BetaSchedule::ScaledLinear => Tensor::linspace(
3230
config.beta_start.sqrt(),
@@ -143,9 +141,9 @@ impl DPMSolverScheduler for DPMSolverSinglestepScheduler {
143141
/// * `timestep` - current discrete timestep in the diffusion chain
144142
/// * `prev_timestep` - previous discrete timestep in the diffusion chain
145143
/// * `sample` - current instance of sample being created by diffusion process
146-
fn first_order_update(
144+
fn dpm_solver_first_order_update(
147145
&self,
148-
model_output: Tensor,
146+
model_output: &Tensor,
149147
timestep: usize,
150148
prev_timestep: usize,
151149
sample: &Tensor,
@@ -173,7 +171,7 @@ impl DPMSolverScheduler for DPMSolverSinglestepScheduler {
173171
/// * `timestep_list` - current and latter discrete timestep in the diffusion chain
174172
/// * `prev_timestep` - previous discrete timestep in the diffusion chain
175173
/// * `sample` - current instance of sample being created by diffusion process
176-
fn second_order_update(
174+
fn singlestep_dpm_solver_second_order_update(
177175
&self,
178176
model_output_list: &Vec<Tensor>,
179177
timestep_list: [usize; 2],
@@ -234,7 +232,7 @@ impl DPMSolverScheduler for DPMSolverSinglestepScheduler {
234232
/// * `timestep_list` - current and latter discrete timestep in the diffusion chain
235233
/// * `prev_timestep` - previous discrete timestep in the diffusion chain
236234
/// * `sample` - current instance of sample being created by diffusion process
237-
fn third_order_update(
235+
fn singlestep_dpm_solver_third_order_update(
238236
&self,
239237
model_output_list: &Vec<Tensor>,
240238
timestep_list: [usize; 3],
@@ -292,13 +290,13 @@ impl DPMSolverScheduler for DPMSolverSinglestepScheduler {
292290
}
293291
}
294292

295-
fn timesteps(&self) -> &[usize] {
293+
pub fn timesteps(&self) -> &[usize] {
296294
self.timesteps.as_slice()
297295
}
298296

299297
/// Ensures interchangeability with schedulers that need to scale the denoising model input
300298
/// depending on the current timestep.
301-
fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Tensor {
299+
pub fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Tensor {
302300
sample
303301
}
304302

@@ -309,7 +307,7 @@ impl DPMSolverScheduler for DPMSolverSinglestepScheduler {
309307
/// * `model_output` - direct output from learned diffusion model
310308
/// * `timestep` - current discrete timestep in the diffusion chain
311309
/// * `sample` - current instance of sample being created by diffusion process
312-
fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Tensor {
310+
pub fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Tensor {
313311
// https://github.com/huggingface/diffusers/blob/e4fe9413121b78c4c1f109b50f0f3cc1c320a1a2/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py#L535
314312
let step_index: usize = self.timesteps.iter().position(|&t| t == timestep).unwrap();
315313

@@ -331,19 +329,19 @@ impl DPMSolverScheduler for DPMSolverSinglestepScheduler {
331329
};
332330

333331
match order {
334-
1 => self.first_order_update(
335-
model_output,
332+
1 => self.dpm_solver_first_order_update(
333+
&self.model_outputs[self.model_outputs.len() - 1],
336334
timestep,
337335
prev_timestep,
338336
&self.sample.as_ref().unwrap(),
339337
),
340-
2 => self.second_order_update(
338+
2 => self.singlestep_dpm_solver_second_order_update(
341339
&self.model_outputs,
342340
[self.timesteps[step_index - 1], self.timesteps[step_index]],
343341
prev_timestep,
344342
&self.sample.as_ref().unwrap(),
345343
),
346-
3 => self.third_order_update(
344+
3 => self.singlestep_dpm_solver_third_order_update(
347345
&self.model_outputs,
348346
[
349347
self.timesteps[step_index - 2],
@@ -359,12 +357,12 @@ impl DPMSolverScheduler for DPMSolverSinglestepScheduler {
359357
}
360358
}
361359

362-
fn add_noise(&self, original_samples: &Tensor, noise: Tensor, timestep: usize) -> Tensor {
360+
pub fn add_noise(&self, original_samples: &Tensor, noise: Tensor, timestep: usize) -> Tensor {
363361
self.alphas_cumprod[timestep].sqrt() * original_samples.to_owned()
364362
+ (1.0 - self.alphas_cumprod[timestep]).sqrt() * noise
365363
}
366364

367-
fn init_noise_sigma(&self) -> f64 {
365+
pub fn init_noise_sigma(&self) -> f64 {
368366
self.init_noise_sigma
369367
}
370368
}

0 commit comments

Comments
 (0)