@@ -2,9 +2,7 @@ use std::iter::repeat;
2
2
3
3
use super :: {
4
4
betas_for_alpha_bar,
5
- dpmsolver:: {
6
- DPMSolverAlgorithmType , DPMSolverScheduler , DPMSolverSchedulerConfig , DPMSolverType ,
7
- } ,
5
+ dpmsolver:: { DPMSolverAlgorithmType , DPMSolverSchedulerConfig , DPMSolverType } ,
8
6
BetaSchedule , PredictionType ,
9
7
} ;
10
8
use tch:: { kind, Kind , Tensor } ;
@@ -25,8 +23,8 @@ pub struct DPMSolverSinglestepScheduler {
25
23
pub config : DPMSolverSchedulerConfig ,
26
24
}
27
25
28
- impl DPMSolverScheduler for DPMSolverSinglestepScheduler {
29
- fn new ( inference_steps : usize , config : DPMSolverSchedulerConfig ) -> Self {
26
+ impl DPMSolverSinglestepScheduler {
27
+ pub fn new ( inference_steps : usize , config : DPMSolverSchedulerConfig ) -> Self {
30
28
let betas = match config. beta_schedule {
31
29
BetaSchedule :: ScaledLinear => Tensor :: linspace (
32
30
config. beta_start . sqrt ( ) ,
@@ -143,9 +141,9 @@ impl DPMSolverScheduler for DPMSolverSinglestepScheduler {
143
141
/// * `timestep` - current discrete timestep in the diffusion chain
144
142
/// * `prev_timestep` - previous discrete timestep in the diffusion chain
145
143
/// * `sample` - current instance of sample being created by diffusion process
146
- fn first_order_update (
144
+ fn dpm_solver_first_order_update (
147
145
& self ,
148
- model_output : Tensor ,
146
+ model_output : & Tensor ,
149
147
timestep : usize ,
150
148
prev_timestep : usize ,
151
149
sample : & Tensor ,
@@ -173,7 +171,7 @@ impl DPMSolverScheduler for DPMSolverSinglestepScheduler {
173
171
/// * `timestep_list` - current and latter discrete timestep in the diffusion chain
174
172
/// * `prev_timestep` - previous discrete timestep in the diffusion chain
175
173
/// * `sample` - current instance of sample being created by diffusion process
176
- fn second_order_update (
174
+ fn singlestep_dpm_solver_second_order_update (
177
175
& self ,
178
176
model_output_list : & Vec < Tensor > ,
179
177
timestep_list : [ usize ; 2 ] ,
@@ -234,7 +232,7 @@ impl DPMSolverScheduler for DPMSolverSinglestepScheduler {
234
232
/// * `timestep_list` - current and latter discrete timestep in the diffusion chain
235
233
/// * `prev_timestep` - previous discrete timestep in the diffusion chain
236
234
/// * `sample` - current instance of sample being created by diffusion process
237
- fn third_order_update (
235
+ fn singlestep_dpm_solver_third_order_update (
238
236
& self ,
239
237
model_output_list : & Vec < Tensor > ,
240
238
timestep_list : [ usize ; 3 ] ,
@@ -292,13 +290,13 @@ impl DPMSolverScheduler for DPMSolverSinglestepScheduler {
292
290
}
293
291
}
294
292
295
- fn timesteps ( & self ) -> & [ usize ] {
293
+ pub fn timesteps ( & self ) -> & [ usize ] {
296
294
self . timesteps . as_slice ( )
297
295
}
298
296
299
297
/// Ensures interchangeability with schedulers that need to scale the denoising model input
300
298
/// depending on the current timestep.
301
- fn scale_model_input ( & self , sample : Tensor , _timestep : usize ) -> Tensor {
299
+ pub fn scale_model_input ( & self , sample : Tensor , _timestep : usize ) -> Tensor {
302
300
sample
303
301
}
304
302
@@ -309,7 +307,7 @@ impl DPMSolverScheduler for DPMSolverSinglestepScheduler {
309
307
/// * `model_output` - direct output from learned diffusion model
310
308
/// * `timestep` - current discrete timestep in the diffusion chain
311
309
/// * `sample` - current instance of sample being created by diffusion process
312
- fn step ( & mut self , model_output : & Tensor , timestep : usize , sample : & Tensor ) -> Tensor {
310
+ pub fn step ( & mut self , model_output : & Tensor , timestep : usize , sample : & Tensor ) -> Tensor {
313
311
// https://github.com/huggingface/diffusers/blob/e4fe9413121b78c4c1f109b50f0f3cc1c320a1a2/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py#L535
314
312
let step_index: usize = self . timesteps . iter ( ) . position ( |& t| t == timestep) . unwrap ( ) ;
315
313
@@ -331,19 +329,19 @@ impl DPMSolverScheduler for DPMSolverSinglestepScheduler {
331
329
} ;
332
330
333
331
match order {
334
- 1 => self . first_order_update (
335
- model_output ,
332
+ 1 => self . dpm_solver_first_order_update (
333
+ & self . model_outputs [ self . model_outputs . len ( ) - 1 ] ,
336
334
timestep,
337
335
prev_timestep,
338
336
& self . sample . as_ref ( ) . unwrap ( ) ,
339
337
) ,
340
- 2 => self . second_order_update (
338
+ 2 => self . singlestep_dpm_solver_second_order_update (
341
339
& self . model_outputs ,
342
340
[ self . timesteps [ step_index - 1 ] , self . timesteps [ step_index] ] ,
343
341
prev_timestep,
344
342
& self . sample . as_ref ( ) . unwrap ( ) ,
345
343
) ,
346
- 3 => self . third_order_update (
344
+ 3 => self . singlestep_dpm_solver_third_order_update (
347
345
& self . model_outputs ,
348
346
[
349
347
self . timesteps [ step_index - 2 ] ,
@@ -359,12 +357,12 @@ impl DPMSolverScheduler for DPMSolverSinglestepScheduler {
359
357
}
360
358
}
361
359
362
- fn add_noise ( & self , original_samples : & Tensor , noise : Tensor , timestep : usize ) -> Tensor {
360
+ pub fn add_noise ( & self , original_samples : & Tensor , noise : Tensor , timestep : usize ) -> Tensor {
363
361
self . alphas_cumprod [ timestep] . sqrt ( ) * original_samples. to_owned ( )
364
362
+ ( 1.0 - self . alphas_cumprod [ timestep] ) . sqrt ( ) * noise
365
363
}
366
364
367
- fn init_noise_sigma ( & self ) -> f64 {
365
+ pub fn init_noise_sigma ( & self ) -> f64 {
368
366
self . init_noise_sigma
369
367
}
370
368
}
0 commit comments