@@ -1447,27 +1447,34 @@ def post_cfg_function(args):
14471447 old_d = d
14481448 return x
14491449
1450+
14501451@torch .no_grad ()
14511452def sample_gradient_estimation_cfg_pp (model , x , sigmas , extra_args = None , callback = None , disable = None , ge_gamma = 2. ):
14521453 return sample_gradient_estimation (model , x , sigmas , extra_args = extra_args , callback = callback , disable = disable , ge_gamma = ge_gamma , cfg_pp = True )
14531454
1455+
14541456@torch .no_grad ()
1455- def sample_er_sde (model , x , sigmas , extra_args = None , callback = None , disable = None , s_noise = 1. , noise_sampler = None , noise_scaler = None , max_stage = 3 ):
1456- """
1457- Extended Reverse-Time SDE solver (VE ER-SDE-Solver-3). Arxiv: https://arxiv.org/abs/2309.06169.
1457+ def sample_er_sde (model , x , sigmas , extra_args = None , callback = None , disable = None , s_noise = 1.0 , noise_sampler = None , noise_scaler = None , max_stage = 3 ):
1458+ """Extended Reverse-Time SDE solver (VP ER-SDE-Solver-3). arXiv: https://arxiv.org/abs/2309.06169.
14581459 Code reference: https://github.com/QinpengCui/ER-SDE-Solver/blob/main/er_sde_solver.py.
14591460 """
14601461 extra_args = {} if extra_args is None else extra_args
14611462 seed = extra_args .get ("seed" , None )
14621463 noise_sampler = default_noise_sampler (x , seed = seed ) if noise_sampler is None else noise_sampler
14631464 s_in = x .new_ones ([x .shape [0 ]])
14641465
1465- def default_noise_scaler (sigma ):
1466- return sigma * ((sigma ** 0.3 ).exp () + 10.0 )
1467- noise_scaler = default_noise_scaler if noise_scaler is None else noise_scaler
1466+ def default_er_sde_noise_scaler (x ):
1467+ return x * ((x ** 0.3 ).exp () + 10.0 )
1468+
1469+ noise_scaler = default_er_sde_noise_scaler if noise_scaler is None else noise_scaler
14681470 num_integration_points = 200.0
14691471 point_indice = torch .arange (0 , num_integration_points , dtype = torch .float32 , device = x .device )
14701472
1473+ model_sampling = model .inner_model .model_patcher .get_model_object ("model_sampling" )
1474+ sigmas = offset_first_sigma_for_snr (sigmas , model_sampling )
1475+ half_log_snrs = sigma_to_half_log_snr (sigmas , model_sampling )
1476+ er_lambdas = half_log_snrs .neg ().exp () # er_lambda_t = sigma_t / alpha_t
1477+
14711478 old_denoised = None
14721479 old_denoised_d = None
14731480
@@ -1478,32 +1485,36 @@ def default_noise_scaler(sigma):
14781485 stage_used = min (max_stage , i + 1 )
14791486 if sigmas [i + 1 ] == 0 :
14801487 x = denoised
1481- elif stage_used == 1 :
1482- r = noise_scaler (sigmas [i + 1 ]) / noise_scaler (sigmas [i ])
1483- x = r * x + (1 - r ) * denoised
14841488 else :
1485- r = noise_scaler (sigmas [i + 1 ]) / noise_scaler (sigmas [i ])
1486- x = r * x + (1 - r ) * denoised
1487-
1488- dt = sigmas [i + 1 ] - sigmas [i ]
1489- sigma_step_size = - dt / num_integration_points
1490- sigma_pos = sigmas [i + 1 ] + point_indice * sigma_step_size
1491- scaled_pos = noise_scaler (sigma_pos )
1492-
1493- # Stage 2
1494- s = torch .sum (1 / scaled_pos ) * sigma_step_size
1495- denoised_d = (denoised - old_denoised ) / (sigmas [i ] - sigmas [i - 1 ])
1496- x = x + (dt + s * noise_scaler (sigmas [i + 1 ])) * denoised_d
1497-
1498- if stage_used >= 3 :
1499- # Stage 3
1500- s_u = torch .sum ((sigma_pos - sigmas [i ]) / scaled_pos ) * sigma_step_size
1501- denoised_u = (denoised_d - old_denoised_d ) / ((sigmas [i ] - sigmas [i - 2 ]) / 2 )
1502- x = x + ((dt ** 2 ) / 2 + s_u * noise_scaler (sigmas [i + 1 ])) * denoised_u
1503- old_denoised_d = denoised_d
1504-
1505- if s_noise != 0 and sigmas [i + 1 ] > 0 :
1506- x = x + noise_sampler (sigmas [i ], sigmas [i + 1 ]) * s_noise * (sigmas [i + 1 ] ** 2 - sigmas [i ] ** 2 * r ** 2 ).sqrt ().nan_to_num (nan = 0.0 )
1489+ er_lambda_s , er_lambda_t = er_lambdas [i ], er_lambdas [i + 1 ]
1490+ alpha_s = sigmas [i ] / er_lambda_s
1491+ alpha_t = sigmas [i + 1 ] / er_lambda_t
1492+ r_alpha = alpha_t / alpha_s
1493+ r = noise_scaler (er_lambda_t ) / noise_scaler (er_lambda_s )
1494+
1495+ # Stage 1 Euler
1496+ x = r_alpha * r * x + alpha_t * (1 - r ) * denoised
1497+
1498+ if stage_used >= 2 :
1499+ dt = er_lambda_t - er_lambda_s
1500+ lambda_step_size = - dt / num_integration_points
1501+ lambda_pos = er_lambda_t + point_indice * lambda_step_size
1502+ scaled_pos = noise_scaler (lambda_pos )
1503+
1504+ # Stage 2
1505+ s = torch .sum (1 / scaled_pos ) * lambda_step_size
1506+ denoised_d = (denoised - old_denoised ) / (er_lambda_s - er_lambdas [i - 1 ])
1507+ x = x + alpha_t * (dt + s * noise_scaler (er_lambda_t )) * denoised_d
1508+
1509+ if stage_used >= 3 :
1510+ # Stage 3
1511+ s_u = torch .sum ((lambda_pos - er_lambda_s ) / scaled_pos ) * lambda_step_size
1512+ denoised_u = (denoised_d - old_denoised_d ) / ((er_lambda_s - er_lambdas [i - 2 ]) / 2 )
1513+ x = x + alpha_t * ((dt ** 2 ) / 2 + s_u * noise_scaler (er_lambda_t )) * denoised_u
1514+ old_denoised_d = denoised_d
1515+
1516+ if s_noise > 0 :
1517+ x = x + alpha_t * noise_sampler (sigmas [i ], sigmas [i + 1 ]) * s_noise * (er_lambda_t ** 2 - er_lambda_s ** 2 * r ** 2 ).sqrt ().nan_to_num (nan = 0.0 )
15071518 old_denoised = denoised
15081519 return x
15091520
0 commit comments