@@ -49,7 +49,7 @@ impl DPMSolverSinglestepScheduler {
49
49
let lambda_t = alpha_t. log ( ) - sigma_t. log ( ) ;
50
50
51
51
let step = ( config. train_timesteps - 1 ) as f64 / inference_steps as f64 ;
52
- // https://github.com/huggingface/diffusers/blob/e4fe9413121b78c4c1f109b50f0f3cc1c320a1a2/src/diffusers/schedulers/scheduling_dpmsolver_multistep .py#L199-L204
52
+ // https://github.com/huggingface/diffusers/blob/e4fe9413121b78c4c1f109b50f0f3cc1c320a1a2/src/diffusers/schedulers/scheduling_dpmsolver_singlestep .py#L172-L173
53
53
let timesteps: Vec < usize > = ( 0 ..inference_steps + 1 )
54
54
. map ( |i| ( i as f64 * step) . round ( ) as usize )
55
55
// discards the 0.0 element
@@ -62,13 +62,15 @@ impl DPMSolverSinglestepScheduler {
62
62
model_outputs. push ( Tensor :: new ( ) ) ;
63
63
}
64
64
65
+ let order_list = get_order_list ( inference_steps, config. solver_order , true ) ;
66
+
65
67
Self {
66
68
alphas_cumprod : Vec :: < f64 > :: from ( alphas_cumprod) ,
67
69
alpha_t : Vec :: < f64 > :: from ( alpha_t) ,
68
70
sigma_t : Vec :: < f64 > :: from ( sigma_t) ,
69
71
lambda_t : Vec :: < f64 > :: from ( lambda_t) ,
70
72
init_noise_sigma : 1. ,
71
- order_list : get_order_list ( inference_steps , config . solver_order , false ) ,
73
+ order_list,
72
74
model_outputs,
73
75
timesteps,
74
76
config,
@@ -292,6 +294,12 @@ impl DPMSolverSinglestepScheduler {
292
294
self . timesteps . as_slice ( )
293
295
}
294
296
297
+ /// Ensures interchangeability with schedulers that need to scale the denoising model input
298
+ /// depending on the current timestep.
299
+ pub fn scale_model_input ( & self , sample : Tensor , _timestep : usize ) -> Tensor {
300
+ sample
301
+ }
302
+
295
303
/// Step function propagating the sample with the singlestep DPM-Solver
296
304
///
297
305
/// # Arguments
@@ -311,7 +319,7 @@ impl DPMSolverSinglestepScheduler {
311
319
self . model_outputs [ i] = self . model_outputs [ i + 1 ] . shallow_clone ( ) ;
312
320
}
313
321
let m = self . model_outputs . len ( ) ;
314
- self . model_outputs [ m - 1 ] = model_output;
322
+ self . model_outputs [ m - 1 ] = model_output. shallow_clone ( ) ;
315
323
316
324
let order = self . order_list [ step_index] ;
317
325
@@ -320,7 +328,7 @@ impl DPMSolverSinglestepScheduler {
320
328
self . sample = Some ( sample. shallow_clone ( ) ) ;
321
329
} ;
322
330
323
- let prev_sample = match order {
331
+ match order {
324
332
1 => self . dpm_solver_first_order_update (
325
333
& self . model_outputs [ self . model_outputs . len ( ) - 1 ] ,
326
334
timestep,
@@ -331,7 +339,7 @@ impl DPMSolverSinglestepScheduler {
331
339
& self . model_outputs ,
332
340
[ self . timesteps [ step_index - 1 ] , self . timesteps [ step_index] ] ,
333
341
prev_timestep,
334
- self . sample . as_ref ( ) . unwrap ( ) ,
342
+ & self . sample . as_ref ( ) . unwrap ( ) ,
335
343
) ,
336
344
3 => self . singlestep_dpm_solver_third_order_update (
337
345
& self . model_outputs ,
@@ -341,14 +349,12 @@ impl DPMSolverSinglestepScheduler {
341
349
self . timesteps [ step_index] ,
342
350
] ,
343
351
prev_timestep,
344
- self . sample . as_ref ( ) . unwrap ( ) ,
352
+ & self . sample . as_ref ( ) . unwrap ( ) ,
345
353
) ,
346
354
_ => {
347
355
panic ! ( "invalid order" ) ;
348
356
}
349
- } ;
350
-
351
- prev_sample
357
+ }
352
358
}
353
359
354
360
pub fn add_noise ( & self , original_samples : & Tensor , noise : Tensor , timestep : usize ) -> Tensor {
0 commit comments