Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update ldm patched && DORA support #3454

Open
wants to merge 11 commits into
base: develop
Choose a base branch
from
Prev Previous commit
Next Next commit
calculate_sigmas - Type unification
Some parameters (default_image_only_indicator) default to None to avoid exceptions
Some deprecated methods unwrapped
Samplers/Schedulers - tested, seem to be OK
- karras, euler, heun, restart, some others && samplers
- uni_pc/uni_pc_bh2: updated from latest comfy; the original code had a function parameter mismatch
ControlNets - tested, seems to be ok

Not tested: Refiners (SDXL/SD15), Inpainting
I R committed Aug 5, 2024
commit 8569b70e13795636d7932aca08374e9f94aa43ae
40 changes: 11 additions & 29 deletions ldm_patched/unipc/uni_pc.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#code taken from: https://github.com/wl-zhao/UniPC and modified
# code taken from: https://github.com/wl-zhao/UniPC and modified
# updated from https://github.com/comfyanonymous/ComfyUI/blob/a178e25912b01abf436eba1cfaab316ba02d272d/comfy/extra_samplers/uni_pc.py#L874

import torch
import torch.nn.functional as F
@@ -358,9 +359,6 @@ def __init__(
thresholding=False,
max_val=1.,
variant='bh1',
noise_mask=None,
masked_image=None,
noise=None,
):
"""Construct a UniPC.

@@ -372,9 +370,6 @@ def __init__(
self.predict_x0 = predict_x0
self.thresholding = thresholding
self.max_val = max_val
self.noise_mask = noise_mask
self.masked_image = masked_image
self.noise = noise

def dynamic_thresholding_fn(self, x0, t=None):
"""
@@ -391,10 +386,7 @@ def noise_prediction_fn(self, x, t):
"""
Return the noise prediction model.
"""
if self.noise_mask is not None:
return self.model(x, t) * self.noise_mask
else:
return self.model(x, t)
return self.model(x, t)

def data_prediction_fn(self, x, t):
"""
@@ -409,8 +401,6 @@ def data_prediction_fn(self, x, t):
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
x0 = torch.clamp(x0, -s, s) / s
if self.noise_mask is not None:
x0 = x0 * self.noise_mask + (1. - self.noise_mask) * self.masked_image
return x0

def model_fn(self, x, t):
@@ -723,8 +713,6 @@ def sample(self, x, timesteps, t_start=None, t_end=None, order=3, skip_type='tim
assert timesteps.shape[0] - 1 == steps
# with torch.no_grad():
for step_index in trange(steps, disable=disable_pbar):
if self.noise_mask is not None:
x = x * self.noise_mask + (1. - self.noise_mask) * (self.masked_image * self.noise_schedule.marginal_alpha(timesteps[step_index]) + self.noise * self.noise_schedule.marginal_std(timesteps[step_index]))
if step_index == 0:
vec_t = timesteps[0].expand((x.shape[0]))
model_prev_list = [self.model_fn(x, vec_t)]
@@ -766,7 +754,7 @@ def sample(self, x, timesteps, t_start=None, t_end=None, order=3, skip_type='tim
model_x = self.model_fn(x, vec_t)
model_prev_list[-1] = model_x
if callback is not None:
callback(step_index, model_prev_list[-1], x, steps)
callback({'x': x, 'i': step_index, 'denoised': model_prev_list[-1]})
else:
raise NotImplementedError()
# if denoise_to_zero:
@@ -858,7 +846,7 @@ def predict_eps_sigma(model, input, sigma_in, **kwargs):
return (input - model(input, sigma_in, **kwargs)) / sigma


def sample_unipc(model, noise, image, sigmas, max_denoise, extra_args=None, callback=None, disable=False, noise_mask=None, variant='bh1'):
def sample_unipc(model, noise, sigmas, extra_args=None, callback=None, disable=False, variant='bh1'):
timesteps = sigmas.clone()
if sigmas[-1] == 0:
timesteps = sigmas[:]
@@ -867,16 +855,7 @@ def sample_unipc(model, noise, image, sigmas, max_denoise, extra_args=None, call
timesteps = sigmas.clone()
ns = SigmaConvert()

if image is not None:
img = image * ns.marginal_alpha(timesteps[0])
if max_denoise:
noise_mult = 1.0
else:
noise_mult = ns.marginal_std(timesteps[0])
img += noise * noise_mult
else:
img = noise

noise = noise / torch.sqrt(1.0 + timesteps[0] ** 2.0)
model_type = "noise"

model_fn = model_wrapper(
@@ -888,7 +867,10 @@ def sample_unipc(model, noise, image, sigmas, max_denoise, extra_args=None, call
)

order = min(3, len(timesteps) - 2)
uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise, variant=variant)
x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback, disable_pbar=disable)
uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, variant=variant)
x = uni_pc.sample(noise, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback, disable_pbar=disable)
x /= ns.marginal_alpha(timesteps[-1])
return x

def sample_unipc_bh2(model, noise, sigmas, extra_args=None, callback=None, disable=False):
    """Run ``sample_unipc`` with the ``'bh2'`` variant selected.

    All arguments are forwarded unchanged to ``sample_unipc``; the return
    value is whatever that call returns.
    """
    return sample_unipc(
        model,
        noise,
        sigmas,
        extra_args=extra_args,
        callback=callback,
        disable=disable,
        variant='bh2',
    )
13 changes: 9 additions & 4 deletions modules/patch.py
Original file line number Diff line number Diff line change
@@ -21,7 +21,7 @@
import safetensors.torch
import modules.constants as constants

from ldm_patched.modules.samplers import calc_cond_uncond_batch
from ldm_patched.modules.samplers import calc_cond_batch
from ldm_patched.k_diffusion.sampling import BatchedBrownianTree
from ldm_patched.ldm.modules.diffusionmodules.openaimodel import forward_timestep_embed, apply_control
from modules.patch_precision import patch_all_precision
@@ -227,14 +227,16 @@ def patched_sampling_function(model, x, timestep, uncond, cond, cond_scale, mode
pid = os.getpid()

if math.isclose(cond_scale, 1.0) and not model_options.get("disable_cfg1_optimization", False):
final_x0 = calc_cond_uncond_batch(model, cond, None, x, timestep, model_options)[0]
calc_cond_uncond_batch = tuple(calc_cond_batch(model, [cond, None], x, timestep, model_options))
final_x0 = calc_cond_uncond_batch()[0]

if patch_settings[pid].eps_record is not None:
patch_settings[pid].eps_record = ((x - final_x0) / timestep).cpu()

return final_x0

positive_x0, negative_x0 = calc_cond_uncond_batch(model, cond, uncond, x, timestep, model_options)
calc_cond_uncond_batch = tuple(calc_cond_batch(model, [cond, uncond], x, timestep, model_options))
positive_x0, negative_x0 = calc_cond_uncond_batch

positive_eps = x - positive_x0
negative_eps = x - negative_x0
@@ -384,7 +386,10 @@ def patched_unet_forward(self, x, timesteps=None, context=None, y=None, control=
transformer_patches = transformer_options.get("patches", {})

num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames)
image_only_indicator = kwargs.get("image_only_indicator", self.default_image_only_indicator)

image_only_indicator = None
if hasattr(self, "default_image_only_indicator"):
image_only_indicator = kwargs.get("image_only_indicator", self.default_image_only_indicator)
time_context = kwargs.get("time_context", None)

assert (y is not None) == (
53 changes: 45 additions & 8 deletions modules/sample_hijack.py
Original file line number Diff line number Diff line change
@@ -13,7 +13,8 @@
from ldm_patched.modules.samplers import resolve_areas_and_cond_masks, calculate_start_end_timesteps, \
create_cond_with_same_area_if_none, pre_run_control, apply_empty_x_to_equal_area, encode_model_conds, CFGGuider, \
process_conds

from ldm_patched.modules.model_patcher import ModelPatcher
from modules.util import sys_dump_pythonobj

current_refiner = None
refiner_switch_step = -1
@@ -84,6 +85,38 @@ def clip_separate_after_preparation(cond, target_model=None, target_clip=None):

return results

@torch.no_grad()
@torch.inference_mode()
def sample_unipc(model, noise, sigmas, extra_args=None, callback=None, disable=False, variant='bh1'):
    """Run the UniPC multistep sampler over a sigma schedule.

    Args:
        model: denoiser callable, handed to ``predict_eps_sigma``.
        noise: starting tensor; it is divided by ``sqrt(1 + sigmas[0] ** 2)``
            before being passed to ``UniPC.sample``.
        sigmas: schedule supporting ``clone()`` and indexing (presumably a
            1-D torch tensor -- confirm against callers). It is never
            modified by this function.
        extra_args: forwarded to ``model_wrapper`` as ``model_kwargs``.
        callback: forwarded to ``UniPC.sample``.
        disable: forwarded to ``UniPC.sample`` as ``disable_pbar``.
        variant: forwarded to the ``UniPC`` constructor.

    Returns:
        The tensor returned by ``UniPC.sample``, divided in place by
        ``ns.marginal_alpha(timesteps[-1])``.
    """
    # Work on a private copy. The previous code used ``sigmas[:]`` in the
    # zero-terminated branch, which is a view, so writing the last element
    # silently overwrote the caller's schedule.
    timesteps = sigmas.clone()
    if sigmas[-1] == 0:
        # A trailing sigma of exactly 0 is replaced by 0.001 in the copy.
        timesteps[-1] = 0.001
    ns = SigmaConvert()

    noise = noise / torch.sqrt(1.0 + timesteps[0] ** 2.0)
    model_type = "noise"

    model_fn = model_wrapper(
        lambda input, sigma, **kwargs: predict_eps_sigma(model, input, sigma, **kwargs),
        ns,
        model_type=model_type,
        guidance_type="uncond",
        model_kwargs=extra_args,
    )

    # Order is capped at 3 and otherwise derived from the schedule length.
    order = min(3, len(timesteps) - 2)
    uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, variant=variant)
    x = uni_pc.sample(
        noise,
        timesteps=timesteps,
        skip_type="time_uniform",
        method="multistep",
        order=order,
        lower_order_final=True,
        callback=callback,
        disable_pbar=disable,
    )
    x /= ns.marginal_alpha(timesteps[-1])
    return x

@torch.no_grad()
@torch.inference_mode()
def sample_unipc_bh2(model, noise, sigmas, extra_args=None, callback=None, disable=False):
    """Run ``sample_unipc`` with the ``'bh2'`` variant selected.

    All arguments are forwarded unchanged to ``sample_unipc``; the return
    value is whatever that call returns.
    """
    return sample_unipc(
        model,
        noise,
        sigmas,
        extra_args=extra_args,
        callback=callback,
        disable=disable,
        variant='bh2',
    )

# @torch.no_grad()
# @torch.inference_mode()
@@ -224,18 +257,23 @@ def callback_wrap(step, x0, x, total_steps):
@torch.no_grad()
@torch.inference_mode()
def calculate_sigmas_scheduler_hacked(model, scheduler_name, steps):
# sys_dump_pythonobj(model, False, "- calculate_sigmas_scheduler_hacked model")
if isinstance(model, ModelPatcher):
model_sampling = model.get_model_object("model_sampling")
else:
model_sampling = model.model_sampling
if scheduler_name == "karras":
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max))
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max))
elif scheduler_name == "exponential":
sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max))
sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max))
elif scheduler_name == "normal":
sigmas = normal_scheduler(model.model_sampling, steps)
sigmas = normal_scheduler(model_sampling, steps)
elif scheduler_name == "simple":
sigmas = simple_scheduler(model.model_sampling, steps)
sigmas = simple_scheduler(model_sampling, steps)
elif scheduler_name == "ddim_uniform":
sigmas = ddim_scheduler(model.model_sampling, steps)
sigmas = ddim_scheduler(model_sampling, steps)
elif scheduler_name == "sgm_uniform":
sigmas = normal_scheduler(model.model_sampling, steps, sgm=True)
sigmas = normal_scheduler(model_sampling, steps, sgm=True)
elif scheduler_name == "turbo":
sigmas = SDTurboScheduler().get_sigmas(model=model, steps=steps, denoise=1.0)[0]
elif scheduler_name == "align_your_steps":
@@ -245,6 +283,5 @@ def calculate_sigmas_scheduler_hacked(model, scheduler_name, steps):
raise TypeError("error invalid scheduler")
return sigmas


ldm_patched.modules.samplers.calculate_sigmas = calculate_sigmas_scheduler_hacked
ldm_patched.modules.samplers.sample = sample_hacked
11 changes: 11 additions & 0 deletions modules/util.py
Original file line number Diff line number Diff line change
@@ -500,3 +500,14 @@ def get_image_size_info(image: np.ndarray, aspect_ratios: list) -> str:
return size_info
except Exception as e:
return f'Error reading image: {e}'

def sys_dump_pythonobj(obj, withValue, hintStr = None):
    """Print a debug listing of the attributes of ``obj`` to stdout.

    A header line (``hintStr`` followed by ``type(obj)``) is printed first,
    then one ``...name = value`` line per name returned by ``dir(obj)``.

    Args:
        obj: object to inspect.
        withValue: when truthy, print each attribute's value; otherwise
            print ``???`` in its place.
        hintStr: header text; defaults to ``"- object Dump:"`` when None.
    """
    if hintStr is None:
        hintStr = "- object Dump:"
    print(hintStr, type(obj))
    for attr in dir(obj):
        # Fetch each attribute exactly once. The earlier hasattr() + getattr()
        # pair evaluated properties/descriptors twice per name.
        try:
            value = getattr(obj, attr)
        except AttributeError:
            # Name is listed by dir() but cannot be read: skip it, matching
            # what the hasattr() guard did.
            continue
        if withValue:
            print(f"...{attr} = {value}")
        else:
            print(f"...{attr} = ???")