Skip to content

Commit 2af1c2a

Browse files
committed
Refactor common parts of multistep/singlestep into dpmsolver
1 parent 26bb145 commit 2af1c2a

File tree

4 files changed

+84
-142
lines changed

4 files changed

+84
-142
lines changed

src/schedulers/dpmsolver.rs

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
use crate::schedulers::BetaSchedule;
2+
use crate::schedulers::PredictionType;
3+
4+
/// The algorithm type for the solver.
5+
///
6+
#[derive(Default, Debug, Clone, PartialEq, Eq)]
7+
pub enum DPMSolverAlgorithmType {
8+
/// Implements the algorithms defined in <https://arxiv.org/abs/2211.01095>.
9+
#[default]
10+
DPMSolverPlusPlus,
11+
/// Implements the algorithms defined in <https://arxiv.org/abs/2206.00927>.
12+
DPMSolver,
13+
}
14+
15+
/// The solver type for the second-order solver.
16+
/// The solver type slightly affects the sample quality, especially for
17+
/// small number of steps.
18+
#[derive(Default, Debug, Clone, PartialEq, Eq)]
19+
pub enum DPMSolverType {
20+
#[default]
21+
Midpoint,
22+
Heun,
23+
}
24+
25+
#[derive(Debug, Clone)]
26+
pub struct DPMSolverSchedulerConfig {
27+
/// The value of beta at the beginning of training.
28+
pub beta_start: f64,
29+
/// The value of beta at the end of training.
30+
pub beta_end: f64,
31+
/// How beta evolved during training.
32+
pub beta_schedule: BetaSchedule,
33+
/// number of diffusion steps used to train the model.
34+
pub train_timesteps: usize,
35+
/// the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided
36+
/// sampling, and `solver_order=3` for unconditional sampling.
37+
pub solver_order: usize,
38+
/// prediction type of the scheduler function
39+
pub prediction_type: PredictionType,
40+
/// The threshold value for dynamic thresholding. Valid only when `thresholding: true` and
41+
/// `algorithm_type: DPMSolverAlgorithmType::DPMSolverPlusPlus`.
42+
pub sample_max_value: f32,
43+
/// The algorithm type for the solver
44+
pub algorithm_type: DPMSolverAlgorithmType,
45+
/// The solver type for the second-order solver.
46+
pub solver_type: DPMSolverType,
47+
/// Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically
48+
/// find this can stabilize the sampling of DPM-Solver for `steps < 15`, especially for steps <= 10.
49+
pub lower_order_final: bool,
50+
}
51+
52+
impl Default for DPMSolverSchedulerConfig {
53+
fn default() -> Self {
54+
Self {
55+
beta_start: 0.0001,
56+
beta_end: 0.02,
57+
beta_schedule: BetaSchedule::Linear,
58+
train_timesteps: 1000,
59+
solver_order: 2,
60+
prediction_type: PredictionType::Epsilon,
61+
sample_max_value: 1.0,
62+
algorithm_type: DPMSolverAlgorithmType::DPMSolverPlusPlus,
63+
solver_type: DPMSolverType::Midpoint,
64+
lower_order_final: true,
65+
}
66+
}
67+
}

src/schedulers/dpmsolver_multistep.rs

Lines changed: 5 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,86 +1,22 @@
1-
use super::{betas_for_alpha_bar, BetaSchedule, PredictionType};
2-
use std::iter;
1+
use super::{betas_for_alpha_bar, BetaSchedule, PredictionType, dpmsolver::{DPMSolverSchedulerConfig, DPMSolverAlgorithmType, DPMSolverType}};
32
use tch::{kind, Kind, Tensor};
43

5-
/// The algorithm type for the solver.
6-
///
7-
#[derive(Default, Debug, Clone, PartialEq, Eq)]
8-
pub enum DPMSolverAlgorithmType {
9-
/// Implements the algorithms defined in <https://arxiv.org/abs/2211.01095>.
10-
#[default]
11-
DPMSolverPlusPlus,
12-
/// Implements the algorithms defined in <https://arxiv.org/abs/2206.00927>.
13-
DPMSolver,
14-
}
15-
16-
/// The solver type for the second-order solver.
17-
/// The solver type slightly affects the sample quality, especially for
18-
/// small number of steps.
19-
#[derive(Default, Debug, Clone, PartialEq, Eq)]
20-
pub enum DPMSolverType {
21-
#[default]
22-
Midpoint,
23-
Heun,
24-
}
25-
26-
#[derive(Debug, Clone)]
27-
pub struct DPMSolverMultistepSchedulerConfig {
28-
/// The value of beta at the beginning of training.
29-
pub beta_start: f64,
30-
/// The value of beta at the end of training.
31-
pub beta_end: f64,
32-
/// How beta evolved during training.
33-
pub beta_schedule: BetaSchedule,
34-
/// number of diffusion steps used to train the model.
35-
pub train_timesteps: usize,
36-
/// the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided
37-
/// sampling, and `solver_order=3` for unconditional sampling.
38-
pub solver_order: usize,
39-
/// prediction type of the scheduler function
40-
pub prediction_type: PredictionType,
41-
/// The threshold value for dynamic thresholding. Valid only when `thresholding: true` and
42-
/// `algorithm_type: DPMSolverAlgorithmType::DPMSolverPlusPlus`.
43-
pub sample_max_value: f32,
44-
/// The algorithm type for the solver
45-
pub algorithm_type: DPMSolverAlgorithmType,
46-
/// The solver type for the second-order solver.
47-
pub solver_type: DPMSolverType,
48-
/// Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically
49-
/// find this can stabilize the sampling of DPM-Solver for `steps < 15`, especially for steps <= 10.
50-
pub lower_order_final: bool,
51-
}
52-
53-
impl Default for DPMSolverMultistepSchedulerConfig {
54-
fn default() -> Self {
55-
Self {
56-
beta_start: 0.00085,
57-
beta_end: 0.012,
58-
beta_schedule: BetaSchedule::ScaledLinear,
59-
train_timesteps: 1000,
60-
solver_order: 2,
61-
prediction_type: PredictionType::Epsilon,
62-
sample_max_value: 1.0,
63-
algorithm_type: DPMSolverAlgorithmType::DPMSolverPlusPlus,
64-
solver_type: DPMSolverType::Midpoint,
65-
lower_order_final: true,
66-
}
67-
}
68-
}
69-
704
pub struct DPMSolverMultistepScheduler {
715
alphas_cumprod: Vec<f64>,
726
alpha_t: Vec<f64>,
737
sigma_t: Vec<f64>,
748
lambda_t: Vec<f64>,
759
init_noise_sigma: f64,
7610
lower_order_nums: usize,
11+
/// Direct outputs from learned diffusion model at current and latter timesteps
7712
model_outputs: Vec<Tensor>,
13+
/// List of current discrete timesteps in the diffusion chain
7814
timesteps: Vec<usize>,
79-
pub config: DPMSolverMultistepSchedulerConfig,
15+
pub config: DPMSolverSchedulerConfig,
8016
}
8117

8218
impl DPMSolverMultistepScheduler {
83-
pub fn new(inference_steps: usize, config: DPMSolverMultistepSchedulerConfig) -> Self {
19+
pub fn new(inference_steps: usize, config: DPMSolverSchedulerConfig) -> Self {
8420
let betas = match config.beta_schedule {
8521
BetaSchedule::ScaledLinear => Tensor::linspace(
8622
config.beta_start.sqrt(),

src/schedulers/dpmsolver_singlestep.rs

Lines changed: 11 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1,88 +1,30 @@
11
use std::iter::repeat;
22

3-
use super::{betas_for_alpha_bar, BetaSchedule, PredictionType};
3+
use super::{
4+
betas_for_alpha_bar,
5+
dpmsolver::{DPMSolverAlgorithmType, DPMSolverSchedulerConfig, DPMSolverType},
6+
BetaSchedule, PredictionType,
7+
};
48
use tch::{kind, Kind, Tensor};
59

6-
/// The algorithm type for the solver.
7-
///
8-
#[derive(Default, Debug, Clone, PartialEq, Eq)]
9-
pub enum DPMSolverAlgorithmType {
10-
/// Implements the algorithms defined in <https://arxiv.org/abs/2211.01095>.
11-
#[default]
12-
DPMSolverPlusPlus,
13-
/// Implements the algorithms defined in <https://arxiv.org/abs/2206.00927>.
14-
DPMSolver,
15-
}
16-
17-
/// The solver type for the second-order solver.
18-
/// The solver type slightly affects the sample quality, especially for
19-
/// small number of steps.
20-
#[derive(Default, Debug, Clone, PartialEq, Eq)]
21-
pub enum DPMSolverType {
22-
#[default]
23-
Midpoint,
24-
Heun,
25-
}
26-
27-
#[derive(Debug, Clone)]
28-
pub struct DPMSolverSinglestepSchedulerConfig {
29-
/// The value of beta at the beginning of training.
30-
pub beta_start: f64,
31-
/// The value of beta at the end of training.
32-
pub beta_end: f64,
33-
/// How beta evolved during training.
34-
pub beta_schedule: BetaSchedule,
35-
/// number of diffusion steps used to train the model.
36-
pub train_timesteps: usize,
37-
/// the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided
38-
/// sampling, and `solver_order=3` for unconditional sampling.
39-
pub solver_order: usize,
40-
/// prediction type of the scheduler function
41-
pub prediction_type: PredictionType,
42-
/// The threshold value for dynamic thresholding. Valid only when `thresholding: true` and
43-
/// `algorithm_type: DPMSolverAlgorithmType::DPMSolverPlusPlus`.
44-
pub sample_max_value: f32,
45-
/// The algorithm type for the solver
46-
pub algorithm_type: DPMSolverAlgorithmType,
47-
/// The solver type for the second-order solver.
48-
pub solver_type: DPMSolverType,
49-
/// Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically
50-
/// find this can stabilize the sampling of DPM-Solver for `steps < 15`, especially for steps <= 10.
51-
pub lower_order_final: bool,
52-
}
53-
54-
impl Default for DPMSolverSinglestepSchedulerConfig {
55-
fn default() -> Self {
56-
Self {
57-
beta_start: 0.0001,
58-
beta_end: 0.02,
59-
train_timesteps: 1000,
60-
beta_schedule: BetaSchedule::Linear,
61-
solver_order: 2,
62-
prediction_type: PredictionType::Epsilon,
63-
sample_max_value: 1.0,
64-
algorithm_type: DPMSolverAlgorithmType::DPMSolverPlusPlus,
65-
solver_type: DPMSolverType::Midpoint,
66-
lower_order_final: true,
67-
}
68-
}
69-
}
70-
7110
pub struct DPMSolverSinglestepScheduler {
7211
alphas_cumprod: Vec<f64>,
7312
alpha_t: Vec<f64>,
7413
sigma_t: Vec<f64>,
7514
lambda_t: Vec<f64>,
7615
init_noise_sigma: f64,
7716
order_list: Vec<usize>,
17+
/// Direct outputs from learned diffusion model at current and latter timesteps
7818
model_outputs: Vec<Tensor>,
19+
/// List of current discrete timesteps in the diffusion chain
7920
timesteps: Vec<usize>,
21+
/// Current instance of sample being created by diffusion process
8022
sample: Option<Tensor>,
81-
pub config: DPMSolverSinglestepSchedulerConfig,
23+
pub config: DPMSolverSchedulerConfig,
8224
}
8325

8426
impl DPMSolverSinglestepScheduler {
85-
pub fn new(inference_steps: usize, config: DPMSolverSinglestepSchedulerConfig) -> Self {
27+
pub fn new(inference_steps: usize, config: DPMSolverSchedulerConfig) -> Self {
8628
let betas = match config.beta_schedule {
8729
BetaSchedule::ScaledLinear => Tensor::linspace(
8830
config.beta_start.sqrt(),
@@ -462,8 +404,6 @@ fn get_order_list(steps: usize, solver_order: usize, lower_order_final: bool) ->
462404
.chain([&[1][..]])
463405
.flatten()
464406
.map(|v| *v)
465-
.collect()
466-
}
467407
} else if solver_order == 1 {
468408
repeat(&[1][..]).take(steps).flatten().map(|v| *v).collect()
469409
} else {
@@ -473,11 +413,9 @@ fn get_order_list(steps: usize, solver_order: usize, lower_order_final: bool) ->
473413
if solver_order == 3 {
474414
repeat(&[1, 2, 3][..]).take(steps / 3).flatten().map(|v| *v).collect()
475415
} else if solver_order == 2 {
476-
repeat(&[1, 2][..]).take(steps / 2).flatten().map(|v| *v).collect()
416+
repeat(dbg!(&[1, 2][..])).take(dbg!(steps / 2)).flatten().map(|v| dbg!(*v)).collect()
477417
} else if solver_order == 1 {
478418
repeat(&[1][..]).take(steps).flatten().map(|v| *v).collect()
479-
} else {
480-
panic!("invalid solver_order");
481419
}
482420
}
483421
}

src/schedulers/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use tch::{Kind, Tensor};
77

88
pub mod ddim;
99
pub mod ddpm;
10+
pub mod dpmsolver;
1011
pub mod dpmsolver_multistep;
1112
pub mod dpmsolver_singlestep;
1213
pub mod euler_ancestral_discrete;

0 commit comments

Comments
 (0)