Skip to content

Commit b22e97d

Browse files
authored
Migrate ER-SDE from VE to VP algorithm and add its sampler node (#8744)
Apply alpha scaling in the algorithm for reverse-time SDE and add custom ER-SDE sampler node for other solver types (SDE, ODE).
1 parent f02de13 commit b22e97d

File tree

2 files changed

+84
-31
lines changed

2 files changed

+84
-31
lines changed

comfy/k_diffusion/sampling.py

Lines changed: 42 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1447,27 +1447,34 @@ def post_cfg_function(args):
14471447
old_d = d
14481448
return x
14491449

1450+
14501451
@torch.no_grad()
def sample_gradient_estimation_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2.):
    """Gradient-estimation sampler with CFG++ enabled.

    Thin convenience wrapper: forwards every argument unchanged to
    ``sample_gradient_estimation`` and simply forces ``cfg_pp=True``.
    """
    return sample_gradient_estimation(
        model,
        x,
        sigmas,
        extra_args=extra_args,
        callback=callback,
        disable=disable,
        ge_gamma=ge_gamma,
        cfg_pp=True,
    )
14531454

1455+
14541456
@torch.no_grad()
1455-
def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None, noise_scaler=None, max_stage=3):
1456-
"""
1457-
Extended Reverse-Time SDE solver (VE ER-SDE-Solver-3). Arxiv: https://arxiv.org/abs/2309.06169.
1457+
def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1.0, noise_sampler=None, noise_scaler=None, max_stage=3):
1458+
"""Extended Reverse-Time SDE solver (VP ER-SDE-Solver-3). arXiv: https://arxiv.org/abs/2309.06169.
14581459
Code reference: https://github.com/QinpengCui/ER-SDE-Solver/blob/main/er_sde_solver.py.
14591460
"""
14601461
extra_args = {} if extra_args is None else extra_args
14611462
seed = extra_args.get("seed", None)
14621463
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
14631464
s_in = x.new_ones([x.shape[0]])
14641465

1465-
def default_noise_scaler(sigma):
1466-
return sigma * ((sigma ** 0.3).exp() + 10.0)
1467-
noise_scaler = default_noise_scaler if noise_scaler is None else noise_scaler
1466+
def default_er_sde_noise_scaler(x):
1467+
return x * ((x ** 0.3).exp() + 10.0)
1468+
1469+
noise_scaler = default_er_sde_noise_scaler if noise_scaler is None else noise_scaler
14681470
num_integration_points = 200.0
14691471
point_indice = torch.arange(0, num_integration_points, dtype=torch.float32, device=x.device)
14701472

1473+
model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
1474+
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
1475+
half_log_snrs = sigma_to_half_log_snr(sigmas, model_sampling)
1476+
er_lambdas = half_log_snrs.neg().exp() # er_lambda_t = sigma_t / alpha_t
1477+
14711478
old_denoised = None
14721479
old_denoised_d = None
14731480

@@ -1478,32 +1485,36 @@ def default_noise_scaler(sigma):
14781485
stage_used = min(max_stage, i + 1)
14791486
if sigmas[i + 1] == 0:
14801487
x = denoised
1481-
elif stage_used == 1:
1482-
r = noise_scaler(sigmas[i + 1]) / noise_scaler(sigmas[i])
1483-
x = r * x + (1 - r) * denoised
14841488
else:
1485-
r = noise_scaler(sigmas[i + 1]) / noise_scaler(sigmas[i])
1486-
x = r * x + (1 - r) * denoised
1487-
1488-
dt = sigmas[i + 1] - sigmas[i]
1489-
sigma_step_size = -dt / num_integration_points
1490-
sigma_pos = sigmas[i + 1] + point_indice * sigma_step_size
1491-
scaled_pos = noise_scaler(sigma_pos)
1492-
1493-
# Stage 2
1494-
s = torch.sum(1 / scaled_pos) * sigma_step_size
1495-
denoised_d = (denoised - old_denoised) / (sigmas[i] - sigmas[i - 1])
1496-
x = x + (dt + s * noise_scaler(sigmas[i + 1])) * denoised_d
1497-
1498-
if stage_used >= 3:
1499-
# Stage 3
1500-
s_u = torch.sum((sigma_pos - sigmas[i]) / scaled_pos) * sigma_step_size
1501-
denoised_u = (denoised_d - old_denoised_d) / ((sigmas[i] - sigmas[i - 2]) / 2)
1502-
x = x + ((dt ** 2) / 2 + s_u * noise_scaler(sigmas[i + 1])) * denoised_u
1503-
old_denoised_d = denoised_d
1504-
1505-
if s_noise != 0 and sigmas[i + 1] > 0:
1506-
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (sigmas[i + 1] ** 2 - sigmas[i] ** 2 * r ** 2).sqrt().nan_to_num(nan=0.0)
1489+
er_lambda_s, er_lambda_t = er_lambdas[i], er_lambdas[i + 1]
1490+
alpha_s = sigmas[i] / er_lambda_s
1491+
alpha_t = sigmas[i + 1] / er_lambda_t
1492+
r_alpha = alpha_t / alpha_s
1493+
r = noise_scaler(er_lambda_t) / noise_scaler(er_lambda_s)
1494+
1495+
# Stage 1 Euler
1496+
x = r_alpha * r * x + alpha_t * (1 - r) * denoised
1497+
1498+
if stage_used >= 2:
1499+
dt = er_lambda_t - er_lambda_s
1500+
lambda_step_size = -dt / num_integration_points
1501+
lambda_pos = er_lambda_t + point_indice * lambda_step_size
1502+
scaled_pos = noise_scaler(lambda_pos)
1503+
1504+
# Stage 2
1505+
s = torch.sum(1 / scaled_pos) * lambda_step_size
1506+
denoised_d = (denoised - old_denoised) / (er_lambda_s - er_lambdas[i - 1])
1507+
x = x + alpha_t * (dt + s * noise_scaler(er_lambda_t)) * denoised_d
1508+
1509+
if stage_used >= 3:
1510+
# Stage 3
1511+
s_u = torch.sum((lambda_pos - er_lambda_s) / scaled_pos) * lambda_step_size
1512+
denoised_u = (denoised_d - old_denoised_d) / ((er_lambda_s - er_lambdas[i - 2]) / 2)
1513+
x = x + alpha_t * ((dt ** 2) / 2 + s_u * noise_scaler(er_lambda_t)) * denoised_u
1514+
old_denoised_d = denoised_d
1515+
1516+
if s_noise > 0:
1517+
x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (er_lambda_t ** 2 - er_lambda_s ** 2 * r ** 2).sqrt().nan_to_num(nan=0.0)
15071518
old_denoised = denoised
15081519
return x
15091520

comfy_extras/nodes_custom_sampler.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import comfy.samplers
33
import comfy.sample
44
from comfy.k_diffusion import sampling as k_diffusion_sampling
5+
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict
56
import latent_preview
67
import torch
78
import comfy.utils
@@ -480,6 +481,46 @@ def get_sampler(self, order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_
480481
"s_noise":s_noise })
481482
return (sampler, )
482483

484+
485+
class SamplerER_SDE(ComfyNodeABC):
    """Custom-sampler node for the ER-SDE family of solvers.

    Builds a KSAMPLER wrapping ``sample_er_sde`` and lets the user pick
    between the ER-SDE, reverse-time SDE, and ODE solver variants.
    """

    @classmethod
    def INPUT_TYPES(cls) -> InputTypeDict:
        return {
            "required": {
                "solver_type": (IO.COMBO, {"options": ["ER-SDE", "Reverse-time SDE", "ODE"]}),
                "max_stage": (IO.INT, {"default": 3, "min": 1, "max": 3}),
                "eta": (
                    IO.FLOAT,
                    {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": False, "tooltip": "Stochastic strength of reverse-time SDE.\nWhen eta=0, it reduces to deterministic ODE. This setting doesn't apply to ER-SDE solver type."},
                ),
                "s_noise": (IO.FLOAT, {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": False}),
            }
        }

    RETURN_TYPES = (IO.SAMPLER,)
    CATEGORY = "sampling/custom_sampling/samplers"

    FUNCTION = "get_sampler"

    def get_sampler(self, solver_type, max_stage, eta, s_noise):
        # Pure ODE (explicitly chosen, or reverse-time SDE degenerated by
        # eta=0) is deterministic: zero out both stochastic knobs.
        is_deterministic = solver_type == "ODE" or (solver_type == "Reverse-time SDE" and eta == 0)
        if is_deterministic:
            eta = 0
            s_noise = 0

        def reverse_time_sde_noise_scaler(x):
            # eta controls how strongly noise is re-injected along the path.
            return x ** (eta + 1)

        # ER-SDE uses the default scaler defined inside sample_er_sde();
        # the other solver types supply the eta-parameterized one above.
        noise_scaler = None if solver_type == "ER-SDE" else reverse_time_sde_noise_scaler

        sampler = comfy.samplers.ksampler(
            "er_sde",
            {"s_noise": s_noise, "noise_scaler": noise_scaler, "max_stage": max_stage},
        )
        return (sampler,)
522+
523+
483524
class Noise_EmptyNoise:
484525
def __init__(self):
485526
self.seed = 0
@@ -787,6 +828,7 @@ def add_noise(self, model, noise, sigmas, latent_image):
787828
"SamplerDPMPP_SDE": SamplerDPMPP_SDE,
788829
"SamplerDPMPP_2S_Ancestral": SamplerDPMPP_2S_Ancestral,
789830
"SamplerDPMAdaptative": SamplerDPMAdaptative,
831+
"SamplerER_SDE": SamplerER_SDE,
790832
"SplitSigmas": SplitSigmas,
791833
"SplitSigmasDenoise": SplitSigmasDenoise,
792834
"FlipSigmas": FlipSigmas,

0 commit comments

Comments (0)