@@ -2,7 +2,9 @@ use std::iter::repeat;
2
2
3
3
use super :: {
4
4
betas_for_alpha_bar,
5
- dpmsolver:: { DPMSolverAlgorithmType , DPMSolverSchedulerConfig , DPMSolverType } ,
5
+ dpmsolver:: {
6
+ DPMSolverAlgorithmType , DPMSolverScheduler , DPMSolverSchedulerConfig , DPMSolverType ,
7
+ } ,
6
8
BetaSchedule , PredictionType ,
7
9
} ;
8
10
use tch:: { kind, Kind , Tensor } ;
@@ -23,8 +25,8 @@ pub struct DPMSolverSinglestepScheduler {
23
25
pub config : DPMSolverSchedulerConfig ,
24
26
}
25
27
26
- impl DPMSolverSinglestepScheduler {
27
- pub fn new ( inference_steps : usize , config : DPMSolverSchedulerConfig ) -> Self {
28
+ impl DPMSolverScheduler for DPMSolverSinglestepScheduler {
29
+ fn new ( inference_steps : usize , config : DPMSolverSchedulerConfig ) -> Self {
28
30
let betas = match config. beta_schedule {
29
31
BetaSchedule :: ScaledLinear => Tensor :: linspace (
30
32
config. beta_start . sqrt ( ) ,
@@ -141,9 +143,9 @@ impl DPMSolverSinglestepScheduler {
141
143
/// * `timestep` - current discrete timestep in the diffusion chain
142
144
/// * `prev_timestep` - previous discrete timestep in the diffusion chain
143
145
/// * `sample` - current instance of sample being created by diffusion process
144
- fn dpm_solver_first_order_update (
146
+ fn first_order_update (
145
147
& self ,
146
- model_output : & Tensor ,
148
+ model_output : Tensor ,
147
149
timestep : usize ,
148
150
prev_timestep : usize ,
149
151
sample : & Tensor ,
@@ -171,7 +173,7 @@ impl DPMSolverSinglestepScheduler {
171
173
/// * `timestep_list` - current and latter discrete timestep in the diffusion chain
172
174
/// * `prev_timestep` - previous discrete timestep in the diffusion chain
173
175
/// * `sample` - current instance of sample being created by diffusion process
174
- fn singlestep_dpm_solver_second_order_update (
176
+ fn second_order_update (
175
177
& self ,
176
178
model_output_list : & Vec < Tensor > ,
177
179
timestep_list : [ usize ; 2 ] ,
@@ -232,7 +234,7 @@ impl DPMSolverSinglestepScheduler {
232
234
/// * `timestep_list` - current and latter discrete timestep in the diffusion chain
233
235
/// * `prev_timestep` - previous discrete timestep in the diffusion chain
234
236
/// * `sample` - current instance of sample being created by diffusion process
235
- fn singlestep_dpm_solver_third_order_update (
237
+ fn third_order_update (
236
238
& self ,
237
239
model_output_list : & Vec < Tensor > ,
238
240
timestep_list : [ usize ; 3 ] ,
@@ -290,13 +292,13 @@ impl DPMSolverSinglestepScheduler {
290
292
}
291
293
}
292
294
293
- pub fn timesteps ( & self ) -> & [ usize ] {
295
+ fn timesteps ( & self ) -> & [ usize ] {
294
296
self . timesteps . as_slice ( )
295
297
}
296
298
297
299
/// Ensures interchangeability with schedulers that need to scale the denoising model input
298
300
/// depending on the current timestep.
299
- pub fn scale_model_input ( & self , sample : Tensor , _timestep : usize ) -> Tensor {
301
+ fn scale_model_input ( & self , sample : Tensor , _timestep : usize ) -> Tensor {
300
302
sample
301
303
}
302
304
@@ -307,7 +309,7 @@ impl DPMSolverSinglestepScheduler {
307
309
/// * `model_output` - direct output from learned diffusion model
308
310
/// * `timestep` - current discrete timestep in the diffusion chain
309
311
/// * `sample` - current instance of sample being created by diffusion process
310
- pub fn step ( & mut self , model_output : & Tensor , timestep : usize , sample : & Tensor ) -> Tensor {
312
+ fn step ( & mut self , model_output : & Tensor , timestep : usize , sample : & Tensor ) -> Tensor {
311
313
// https://github.com/huggingface/diffusers/blob/e4fe9413121b78c4c1f109b50f0f3cc1c320a1a2/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py#L535
312
314
let step_index: usize = self . timesteps . iter ( ) . position ( |& t| t == timestep) . unwrap ( ) ;
313
315
@@ -329,19 +331,19 @@ impl DPMSolverSinglestepScheduler {
329
331
} ;
330
332
331
333
match order {
332
- 1 => self . dpm_solver_first_order_update (
333
- & self . model_outputs [ self . model_outputs . len ( ) - 1 ] ,
334
+ 1 => self . first_order_update (
335
+ model_output ,
334
336
timestep,
335
337
prev_timestep,
336
338
& self . sample . as_ref ( ) . unwrap ( ) ,
337
339
) ,
338
- 2 => self . singlestep_dpm_solver_second_order_update (
340
+ 2 => self . second_order_update (
339
341
& self . model_outputs ,
340
342
[ self . timesteps [ step_index - 1 ] , self . timesteps [ step_index] ] ,
341
343
prev_timestep,
342
344
& self . sample . as_ref ( ) . unwrap ( ) ,
343
345
) ,
344
- 3 => self . singlestep_dpm_solver_third_order_update (
346
+ 3 => self . third_order_update (
345
347
& self . model_outputs ,
346
348
[
347
349
self . timesteps [ step_index - 2 ] ,
@@ -357,12 +359,12 @@ impl DPMSolverSinglestepScheduler {
357
359
}
358
360
}
359
361
360
- pub fn add_noise ( & self , original_samples : & Tensor , noise : Tensor , timestep : usize ) -> Tensor {
362
+ fn add_noise ( & self , original_samples : & Tensor , noise : Tensor , timestep : usize ) -> Tensor {
361
363
self . alphas_cumprod [ timestep] . sqrt ( ) * original_samples. to_owned ( )
362
364
+ ( 1.0 - self . alphas_cumprod [ timestep] ) . sqrt ( ) * noise
363
365
}
364
366
365
- pub fn init_noise_sigma ( & self ) -> f64 {
367
+ fn init_noise_sigma ( & self ) -> f64 {
366
368
self . init_noise_sigma
367
369
}
368
370
}
0 commit comments