1
1
use std:: iter:: repeat;
2
2
3
- use super :: { betas_for_alpha_bar, BetaSchedule , PredictionType } ;
3
+ use super :: {
4
+ betas_for_alpha_bar,
5
+ dpmsolver:: { DPMSolverAlgorithmType , DPMSolverSchedulerConfig , DPMSolverType } ,
6
+ BetaSchedule , PredictionType ,
7
+ } ;
4
8
use tch:: { kind, Kind , Tensor } ;
5
9
6
- /// The algorithm type for the solver.
7
- ///
8
- #[ derive( Default , Debug , Clone , PartialEq , Eq ) ]
9
- pub enum DPMSolverAlgorithmType {
10
- /// Implements the algorithms defined in <https://arxiv.org/abs/2211.01095>.
11
- #[ default]
12
- DPMSolverPlusPlus ,
13
- /// Implements the algorithms defined in <https://arxiv.org/abs/2206.00927>.
14
- DPMSolver ,
15
- }
16
-
17
- /// The solver type for the second-order solver.
18
- /// The solver type slightly affects the sample quality, especially for
19
- /// small number of steps.
20
- #[ derive( Default , Debug , Clone , PartialEq , Eq ) ]
21
- pub enum DPMSolverType {
22
- #[ default]
23
- Midpoint ,
24
- Heun ,
25
- }
26
-
27
- #[ derive( Debug , Clone ) ]
28
- pub struct DPMSolverSinglestepSchedulerConfig {
29
- /// The value of beta at the beginning of training.
30
- pub beta_start : f64 ,
31
- /// The value of beta at the end of training.
32
- pub beta_end : f64 ,
33
- /// How beta evolved during training.
34
- pub beta_schedule : BetaSchedule ,
35
- /// number of diffusion steps used to train the model.
36
- pub train_timesteps : usize ,
37
- /// the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided
38
- /// sampling, and `solver_order=3` for unconditional sampling.
39
- pub solver_order : usize ,
40
- /// prediction type of the scheduler function
41
- pub prediction_type : PredictionType ,
42
- /// The threshold value for dynamic thresholding. Valid only when `thresholding: true` and
43
- /// `algorithm_type: DPMSolverAlgorithmType::DPMSolverPlusPlus`.
44
- pub sample_max_value : f32 ,
45
- /// The algorithm type for the solver
46
- pub algorithm_type : DPMSolverAlgorithmType ,
47
- /// The solver type for the second-order solver.
48
- pub solver_type : DPMSolverType ,
49
- /// Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically
50
- /// find this can stabilize the sampling of DPM-Solver for `steps < 15`, especially for steps <= 10.
51
- pub lower_order_final : bool ,
52
- }
53
-
54
- impl Default for DPMSolverSinglestepSchedulerConfig {
55
- fn default ( ) -> Self {
56
- Self {
57
- beta_start : 0.0001 ,
58
- beta_end : 0.02 ,
59
- train_timesteps : 1000 ,
60
- beta_schedule : BetaSchedule :: Linear ,
61
- solver_order : 2 ,
62
- prediction_type : PredictionType :: Epsilon ,
63
- sample_max_value : 1.0 ,
64
- algorithm_type : DPMSolverAlgorithmType :: DPMSolverPlusPlus ,
65
- solver_type : DPMSolverType :: Midpoint ,
66
- lower_order_final : true ,
67
- }
68
- }
69
- }
70
-
71
10
pub struct DPMSolverSinglestepScheduler {
72
11
alphas_cumprod : Vec < f64 > ,
73
12
alpha_t : Vec < f64 > ,
74
13
sigma_t : Vec < f64 > ,
75
14
lambda_t : Vec < f64 > ,
76
15
init_noise_sigma : f64 ,
77
16
order_list : Vec < usize > ,
17
+ /// Direct outputs from learned diffusion model at current and latter timesteps
78
18
model_outputs : Vec < Tensor > ,
19
+ /// List of current discrete timesteps in the diffusion chain
79
20
timesteps : Vec < usize > ,
21
+ /// Current instance of sample being created by diffusion process
80
22
sample : Option < Tensor > ,
81
- pub config : DPMSolverSinglestepSchedulerConfig ,
23
+ pub config : DPMSolverSchedulerConfig ,
82
24
}
83
25
84
26
impl DPMSolverSinglestepScheduler {
85
- pub fn new ( inference_steps : usize , config : DPMSolverSinglestepSchedulerConfig ) -> Self {
27
+ pub fn new ( inference_steps : usize , config : DPMSolverSchedulerConfig ) -> Self {
86
28
let betas = match config. beta_schedule {
87
29
BetaSchedule :: ScaledLinear => Tensor :: linspace (
88
30
config. beta_start . sqrt ( ) ,
@@ -462,8 +404,6 @@ fn get_order_list(steps: usize, solver_order: usize, lower_order_final: bool) ->
462
404
. chain ( [ & [ 1 ] [ ..] ] )
463
405
. flatten ( )
464
406
. map ( |v| * v)
465
- . collect ( )
466
- }
467
407
} else if solver_order == 1 {
468
408
repeat ( & [ 1 ] [ ..] ) . take ( steps) . flatten ( ) . map ( |v| * v) . collect ( )
469
409
} else {
@@ -473,11 +413,9 @@ fn get_order_list(steps: usize, solver_order: usize, lower_order_final: bool) ->
473
413
if solver_order == 3 {
474
414
repeat ( & [ 1 , 2 , 3 ] [ ..] ) . take ( steps / 3 ) . flatten ( ) . map ( |v| * v) . collect ( )
475
415
} else if solver_order == 2 {
476
- repeat ( & [ 1 , 2 ] [ ..] ) . take ( steps / 2 ) . flatten ( ) . map ( |v| * v ) . collect ( )
416
+ repeat ( dbg ! ( & [ 1 , 2 ] [ ..] ) ) . take ( dbg ! ( steps / 2 ) ) . flatten ( ) . map ( |v| dbg ! ( * v ) ) . collect ( )
477
417
} else if solver_order == 1 {
478
418
repeat ( & [ 1 ] [ ..] ) . take ( steps) . flatten ( ) . map ( |v| * v) . collect ( )
479
- } else {
480
- panic ! ( "invalid solver_order" ) ;
481
419
}
482
420
}
483
421
}
0 commit comments