diff --git a/args_manager.py b/args_manager.py
index bb622c23a..bea9c6fda 100644
--- a/args_manager.py
+++ b/args_manager.py
@@ -7,7 +7,7 @@
                                 help="Disables preset selection in Gradio.")
 
 args_parser.parser.add_argument("--language", type=str, default='default',
-                                help="Translate UI using json files in [language] folder. "
+                                help="Translate UI using json files in [language] folder."
                                      "For example, [--language example] will use [language/example.json] for translation.")
 
 # For example, https://github.com/lllyasviel/Fooocus/issues/849
@@ -15,18 +15,15 @@
                                 help="Force loading models to vram when the unload can be avoided. "
                                      "Some Mac users may need this.")
 
-args_parser.parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None)
+args_parser.parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme.", default=None)
 
 args_parser.parser.add_argument("--disable-image-log", action='store_true',
                                 help="Prevent writing images and logs to the outputs folder.")
 
 args_parser.parser.add_argument("--disable-analytics", action='store_true',
                                 help="Disables analytics for Gradio.")
 
-args_parser.parser.add_argument("--disable-metadata", action='store_true',
-                                help="Disables saving metadata to images.")
-
 args_parser.parser.add_argument("--disable-preset-download", action='store_true',
-                                help="Disables downloading models for presets", default=False)
+                                help="Disables downloading models for presets.", default=False)
 
 args_parser.parser.add_argument("--disable-enhance-output-sorting", action='store_true',
                                 help="Disables enhance output sorting for final image gallery.")
@@ -40,8 +37,10 @@
 args_parser.parser.add_argument("--rebuild-hash-cache", help="Generates missing model and LoRA hashes.",
                                 type=int, nargs="?", metavar="CPU_NUM_THREADS", const=-1)
 
+args_parser.parser.add_argument("--favicon-path", type=str, default=None, help="Set the favicon filepath.")
+args_parser.parser.add_argument("--auth-message", type=str, default=None, help="Message to show for auth.")
+
 args_parser.parser.set_defaults(
-    disable_cuda_malloc=True,
     in_browser=True,
     port=None
 )
diff --git a/ldm_patched/contrib/external.py b/ldm_patched/contrib/external.py
index 927cd3f38..8f6f06c51 100644
--- a/ldm_patched/contrib/external.py
+++ b/ldm_patched/contrib/external.py
@@ -1,5 +1,3 @@
-# https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py
-
 import torch
 
 import os
@@ -10,14 +8,15 @@
 import math
 import time
 import random
+import logging
 
-from PIL import Image, ImageOps, ImageSequence
+from PIL import Image, ImageOps, ImageSequence, ImageFile
 from PIL.PngImagePlugin import PngInfo
+
 import numpy as np
 import safetensors.torch
 
-pass # sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "ldm_patched"))
-
+sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy"))
 
 import ldm_patched.modules.diffusers_load
 import ldm_patched.modules.samplers
@@ -33,21 +32,27 @@
 
 import importlib
 
-import ldm_patched.utils.path_utils
-import ldm_patched.utils.latent_visualization
+import ldm_patched.utils.path_utils as folder_paths
+import ldm_patched.utils.latent_visualization as latent_preview
+import ldm_patched.utils.node_helpers as node_helpers
+
 
 def before_node_execution():
     ldm_patched.modules.model_management.throw_exception_if_processing_interrupted()
 
+
 def interrupt_processing(value=True):
     ldm_patched.modules.model_management.interrupt_current_processing(value)
 
-MAX_RESOLUTION=8192
+
+MAX_RESOLUTION = 16384
+
 
 class CLIPTextEncode:
     @classmethod
     def INPUT_TYPES(s):
-        return {"required": {"text": ("STRING", {"multiline": True}), "clip": ("CLIP", )}}
+        return {"required": {"text": ("STRING", {"multiline": True, "dynamicPrompts": True}), "clip": ("CLIP",)}}
+
     RETURN_TYPES = ("CONDITIONING",)
     FUNCTION = "encode"
 
@@ -56,26 +61,31 @@ def INPUT_TYPES(s):
     def encode(self, clip, text):
         tokens = clip.tokenize(text)
         cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
-        return ([[cond, {"pooled_output": pooled}]], )
+        return ([[cond, {"pooled_output": pooled}]],)
+
 
 class ConditioningCombine:
     @classmethod
     def INPUT_TYPES(s):
-        return {"required": {"conditioning_1": ("CONDITIONING", ), "conditioning_2": ("CONDITIONING", )}}
+        return {"required": {"conditioning_1": ("CONDITIONING",), "conditioning_2": ("CONDITIONING",)}}
+
     RETURN_TYPES = ("CONDITIONING",)
     FUNCTION = "combine"
 
     CATEGORY = "conditioning"
 
     def combine(self, conditioning_1, conditioning_2):
-        return (conditioning_1 + conditioning_2, )
+        return (conditioning_1 + conditioning_2,)
+
 
-class ConditioningAverage :
+class ConditioningAverage:
     @classmethod
     def INPUT_TYPES(s):
-        return {"required": {"conditioning_to": ("CONDITIONING", ), "conditioning_from": ("CONDITIONING", ),
-                             "conditioning_to_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
+        return {"required": {"conditioning_to": ("CONDITIONING",), "conditioning_from": ("CONDITIONING",),
+                             "conditioning_to_strength": (
+                                 "FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
                              }}
+
     RETURN_TYPES = ("CONDITIONING",)
     FUNCTION = "addWeighted"
 
@@ -85,7 +95,8 @@ def addWeighted(self, conditioning_to, conditioning_from, conditioning_to_streng
         out = []
 
         if len(conditioning_from) > 1:
-            print("Warning: ConditioningAverage conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.")
+            logging.warning(
+                "Warning: ConditioningAverage conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.")
 
         cond_from = conditioning_from[0][0]
         pooled_output_from = conditioning_from[0][1].get("pooled_output", None)
@@ -93,20 +104,22 @@ def addWeighted(self, conditioning_to, conditioning_from, conditioning_to_streng
         for i in range(len(conditioning_to)):
             t1 = conditioning_to[i][0]
             pooled_output_to = conditioning_to[i][1].get("pooled_output", pooled_output_from)
-            t0 = cond_from[:,:t1.shape[1]]
+            t0 = cond_from[:, :t1.shape[1]]
             if t0.shape[1] < t1.shape[1]:
                 t0 = torch.cat([t0] + [torch.zeros((1, (t1.shape[1] - t0.shape[1]), t1.shape[2]))], dim=1)
 
             tw = torch.mul(t1, conditioning_to_strength) + torch.mul(t0, (1.0 - conditioning_to_strength))
             t_to = conditioning_to[i][1].copy()
             if pooled_output_from is not None and pooled_output_to is not None:
-                t_to["pooled_output"] = torch.mul(pooled_output_to, conditioning_to_strength) + torch.mul(pooled_output_from, (1.0 - conditioning_to_strength))
+                t_to["pooled_output"] = torch.mul(pooled_output_to, conditioning_to_strength) + torch.mul(
+                    pooled_output_from, (1.0 - conditioning_to_strength))
             elif pooled_output_from is not None:
                 t_to["pooled_output"] = pooled_output_from
 
             n = [tw, t_to]
             out.append(n)
-        return (out, )
+        return (out,)
+
 
 class ConditioningConcat:
     @classmethod
@@ -114,7 +127,8 @@ def INPUT_TYPES(s):
         return {"required": {
             "conditioning_to": ("CONDITIONING",),
             "conditioning_from": ("CONDITIONING",),
-            }}
+        }}
+
     RETURN_TYPES = ("CONDITIONING",)
     FUNCTION = "concat"
 
@@ -124,101 +138,115 @@ def concat(self, conditioning_to, conditioning_from):
         out = []
 
         if len(conditioning_from) > 1:
-            print("Warning: ConditioningConcat conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.")
+            logging.warning(
+                "Warning: ConditioningConcat conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.")
 
         cond_from = conditioning_from[0][0]
 
         for i in range(len(conditioning_to)):
             t1 = conditioning_to[i][0]
-            tw = torch.cat((t1, cond_from),1)
+            tw = torch.cat((t1, cond_from), 1)
             n = [tw, conditioning_to[i][1].copy()]
             out.append(n)
 
-        return (out, )
+        return (out,)
+
 
 class ConditioningSetArea:
     @classmethod
     def INPUT_TYPES(s):
-        return {"required": {"conditioning": ("CONDITIONING", ),
-                             "width": ("INT", {"default": 64, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
-                             "height": ("INT", {"default": 64, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
-                             "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
-                             "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
-                             "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+        return {"required": {"conditioning": ("CONDITIONING",),
+                             "width": ("INT", {"default": 64, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
+                             "height": ("INT", {"default": 64, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
+                             "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
+                             "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
+                             "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
                              }}
+
     RETURN_TYPES = ("CONDITIONING",)
     FUNCTION = "append"
 
     CATEGORY = "conditioning"
 
     def append(self, conditioning, width, height, x, y, strength):
-        c = []
-        for t in conditioning:
-            n = [t[0], t[1].copy()]
-            n[1]['area'] = (height // 8, width // 8, y // 8, x // 8)
-            n[1]['strength'] = strength
-            n[1]['set_area_to_bounds'] = False
-            c.append(n)
-        return (c, )
+        c = node_helpers.conditioning_set_values(conditioning, {"area": (height // 8, width // 8, y // 8, x // 8),
+                                                                "strength": strength,
+                                                                "set_area_to_bounds": False})
+        return (c,)
+
 
 class ConditioningSetAreaPercentage:
     @classmethod
     def INPUT_TYPES(s):
-        return {"required": {"conditioning": ("CONDITIONING", ),
-                             "width": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}),
-                             "height": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}),
-                             "x": ("FLOAT", {"default": 0, "min": 0, "max": 1.0, "step": 0.01}),
-                             "y": ("FLOAT", {"default": 0, "min": 0, "max": 1.0, "step": 0.01}),
-                             "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+        return {"required": {"conditioning": ("CONDITIONING",),
                             "width": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}),
+                             "height": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}),
+                             "x": ("FLOAT", {"default": 0, "min": 0, "max": 1.0, "step": 0.01}),
+                             "y": ("FLOAT", {"default": 0, "min": 0, "max": 1.0, "step": 0.01}),
+                             "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
                              }}
+
     RETURN_TYPES = ("CONDITIONING",)
     FUNCTION = "append"
 
     CATEGORY = "conditioning"
 
     def append(self, conditioning, width, height, x, y, strength):
-        c = []
-        for t in conditioning:
-            n = [t[0], t[1].copy()]
-            n[1]['area'] = ("percentage", height, width, y, x)
-            n[1]['strength'] = strength
-            n[1]['set_area_to_bounds'] = False
-            c.append(n)
-        return (c, )
+        c = node_helpers.conditioning_set_values(conditioning, {"area": ("percentage", height, width, y, x),
+                                                                "strength": strength,
"set_area_to_bounds": False}) + return (c,) + + +class ConditioningSetAreaStrength: + @classmethod + def INPUT_TYPES(s): + return {"required": {"conditioning": ("CONDITIONING",), + "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + }} + + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "append" + + CATEGORY = "conditioning" + + def append(self, conditioning, strength): + c = node_helpers.conditioning_set_values(conditioning, {"strength": strength}) + return (c,) + class ConditioningSetMask: @classmethod def INPUT_TYPES(s): - return {"required": {"conditioning": ("CONDITIONING", ), - "mask": ("MASK", ), - "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "set_cond_area": (["default", "mask bounds"],), + return {"required": {"conditioning": ("CONDITIONING",), + "mask": ("MASK",), + "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "set_cond_area": (["default", "mask bounds"],), }} + RETURN_TYPES = ("CONDITIONING",) FUNCTION = "append" CATEGORY = "conditioning" def append(self, conditioning, mask, set_cond_area, strength): - c = [] set_area_to_bounds = False if set_cond_area != "default": set_area_to_bounds = True if len(mask.shape) < 3: mask = mask.unsqueeze(0) - for t in conditioning: - n = [t[0], t[1].copy()] - _, h, w = mask.shape - n[1]['mask'] = mask - n[1]['set_area_to_bounds'] = set_area_to_bounds - n[1]['mask_strength'] = strength - c.append(n) - return (c, ) + + c = node_helpers.conditioning_set_values(conditioning, {"mask": mask, + "set_area_to_bounds": set_area_to_bounds, + "mask_strength": strength}) + return (c,) + class ConditioningZeroOut: @classmethod def INPUT_TYPES(s): - return {"required": {"conditioning": ("CONDITIONING", )}} + return {"required": {"conditioning": ("CONDITIONING",)}} + RETURN_TYPES = ("CONDITIONING",) FUNCTION = "zero_out" @@ -232,118 +260,115 @@ def zero_out(self, conditioning): d["pooled_output"] = torch.zeros_like(d["pooled_output"]) n = [torch.zeros_like(t[0]), d] c.append(n) - return (c, ) + return (c,) + class ConditioningSetTimestepRange: @classmethod def INPUT_TYPES(s): - return {"required": {"conditioning": ("CONDITIONING", ), + return {"required": {"conditioning": ("CONDITIONING",), "start": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), "end": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}) }} + RETURN_TYPES = ("CONDITIONING",) FUNCTION = "set_range" CATEGORY = "advanced/conditioning" def set_range(self, conditioning, start, end): - c = [] - for t in conditioning: - d = t[1].copy() - d['start_percent'] = start - d['end_percent'] = end - n = [t[0], d] - c.append(n) - return (c, ) + c = node_helpers.conditioning_set_values(conditioning, {"start_percent": start, + "end_percent": end}) + return (c,) + class VAEDecode: @classmethod def INPUT_TYPES(s): - return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}} + return {"required": {"samples": ("LATENT",), "vae": ("VAE",)}} + RETURN_TYPES = ("IMAGE",) FUNCTION = "decode" CATEGORY = "latent" def decode(self, vae, samples): - return (vae.decode(samples["samples"]), ) + return (vae.decode(samples["samples"]),) + class VAEDecodeTiled: @classmethod def INPUT_TYPES(s): - return {"required": {"samples": ("LATENT", ), "vae": ("VAE", ), + return {"required": {"samples": ("LATENT",), "vae": ("VAE",), "tile_size": ("INT", {"default": 512, "min": 320, "max": 4096, "step": 64}) - }} + }} + RETURN_TYPES = ("IMAGE",) FUNCTION = "decode" CATEGORY = "_for_testing" def 
decode(self, vae, samples, tile_size): - return (vae.decode_tiled(samples["samples"], tile_x=tile_size // 8, tile_y=tile_size // 8, ), ) + return (vae.decode_tiled(samples["samples"], tile_x=tile_size // 8, tile_y=tile_size // 8, ),) + class VAEEncode: @classmethod def INPUT_TYPES(s): - return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", )}} + return {"required": {"pixels": ("IMAGE",), "vae": ("VAE",)}} + RETURN_TYPES = ("LATENT",) FUNCTION = "encode" CATEGORY = "latent" - @staticmethod - def vae_encode_crop_pixels(pixels): - x = (pixels.shape[1] // 8) * 8 - y = (pixels.shape[2] // 8) * 8 - if pixels.shape[1] != x or pixels.shape[2] != y: - x_offset = (pixels.shape[1] % 8) // 2 - y_offset = (pixels.shape[2] % 8) // 2 - pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :] - return pixels - def encode(self, vae, pixels): - pixels = self.vae_encode_crop_pixels(pixels) - t = vae.encode(pixels[:,:,:,:3]) - return ({"samples":t}, ) + t = vae.encode(pixels[:, :, :, :3]) + return ({"samples": t},) + class VAEEncodeTiled: @classmethod def INPUT_TYPES(s): - return {"required": {"pixels": ("IMAGE", ), "vae": ("VAE", ), + return {"required": {"pixels": ("IMAGE",), "vae": ("VAE",), "tile_size": ("INT", {"default": 512, "min": 320, "max": 4096, "step": 64}) - }} + }} + RETURN_TYPES = ("LATENT",) FUNCTION = "encode" CATEGORY = "_for_testing" def encode(self, vae, pixels, tile_size): - pixels = VAEEncode.vae_encode_crop_pixels(pixels) - t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, ) - return ({"samples":t}, ) + t = vae.encode_tiled(pixels[:, :, :, :3], tile_x=tile_size, tile_y=tile_size, ) + return ({"samples": t},) + class VAEEncodeForInpaint: @classmethod def INPUT_TYPES(s): - return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", ), "mask": ("MASK", ), "grow_mask_by": ("INT", {"default": 6, "min": 0, "max": 64, "step": 1}),}} + return {"required": {"pixels": ("IMAGE",), "vae": ("VAE",), "mask": ("MASK",), + "grow_mask_by": ("INT", {"default": 6, "min": 0, "max": 64, "step": 1}), }} + RETURN_TYPES = ("LATENT",) FUNCTION = "encode" CATEGORY = "latent/inpaint" def encode(self, vae, pixels, mask, grow_mask_by=6): - x = (pixels.shape[1] // 8) * 8 - y = (pixels.shape[2] // 8) * 8 - mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear") + x = (pixels.shape[1] // vae.downscale_ratio) * vae.downscale_ratio + y = (pixels.shape[2] // vae.downscale_ratio) * vae.downscale_ratio + mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), + size=(pixels.shape[1], pixels.shape[2]), mode="bilinear") pixels = pixels.clone() if pixels.shape[1] != x or pixels.shape[2] != y: - x_offset = (pixels.shape[1] % 8) // 2 - y_offset = (pixels.shape[2] % 8) // 2 - pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:] - mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset] + x_offset = (pixels.shape[1] % vae.downscale_ratio) // 2 + y_offset = (pixels.shape[2] % vae.downscale_ratio) // 2 + pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :] + mask = mask[:, :, x_offset:x + x_offset, y_offset:y + y_offset] - #grow mask by a few pixels to keep things seamless in latent space + # grow mask by a few pixels to keep things seamless in latent space if grow_mask_by == 0: mask_erosion = mask else: @@ -354,25 +379,25 @@ def encode(self, vae, pixels, mask, grow_mask_by=6): m = (1.0 - mask.round()).squeeze(1) for i in 
         for i in range(3):
-            pixels[:,:,:,i] -= 0.5
-            pixels[:,:,:,i] *= m
-            pixels[:,:,:,i] += 0.5
+            pixels[:, :, :, i] -= 0.5
+            pixels[:, :, :, i] *= m
+            pixels[:, :, :, i] += 0.5
         t = vae.encode(pixels)
 
-        return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, )
+        return ({"samples": t, "noise_mask": (mask_erosion[:, :, :x, :y].round())},)
 
 
 class InpaintModelConditioning:
     @classmethod
     def INPUT_TYPES(s):
-        return {"required": {"positive": ("CONDITIONING", ),
-                             "negative": ("CONDITIONING", ),
-                             "vae": ("VAE", ),
-                             "pixels": ("IMAGE", ),
-                             "mask": ("MASK", ),
+        return {"required": {"positive": ("CONDITIONING",),
+                             "negative": ("CONDITIONING",),
+                             "vae": ("VAE",),
+                             "pixels": ("IMAGE",),
+                             "mask": ("MASK",),
                              }}
 
-    RETURN_TYPES = ("CONDITIONING","CONDITIONING","LATENT")
+    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
     RETURN_NAMES = ("positive", "negative", "latent")
     FUNCTION = "encode"
 
@@ -381,21 +406,22 @@ def INPUT_TYPES(s):
     def encode(self, positive, negative, pixels, vae, mask):
         x = (pixels.shape[1] // 8) * 8
         y = (pixels.shape[2] // 8) * 8
-        mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")
+        mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])),
+                                               size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")
 
         orig_pixels = pixels
         pixels = orig_pixels.clone()
         if pixels.shape[1] != x or pixels.shape[2] != y:
             x_offset = (pixels.shape[1] % 8) // 2
             y_offset = (pixels.shape[2] % 8) // 2
-            pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:]
-            mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset]
+            pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :]
+            mask = mask[:, :, x_offset:x + x_offset, y_offset:y + y_offset]
 
         m = (1.0 - mask.round()).squeeze(1)
         for i in range(3):
-            pixels[:,:,:,i] -= 0.5
-            pixels[:,:,:,i] *= m
-            pixels[:,:,:,i] += 0.5
+            pixels[:, :, :, i] -= 0.5
+            pixels[:, :, :, i] *= m
+            pixels[:, :, :, i] += 0.5
         concat_latent = vae.encode(pixels)
         orig_latent = vae.encode(orig_pixels)
 
@@ -406,27 +432,23 @@ def encode(self, positive, negative, pixels, vae, mask):
         out = []
         for conditioning in [positive, negative]:
-            c = []
-            for t in conditioning:
-                d = t[1].copy()
-                d["concat_latent_image"] = concat_latent
-                d["concat_mask"] = mask
-                n = [t[0], d]
-                c.append(n)
+            c = node_helpers.conditioning_set_values(conditioning, {"concat_latent_image": concat_latent,
+                                                                    "concat_mask": mask})
             out.append(c)
         return (out[0], out[1], out_latent)
 
 
 class SaveLatent:
     def __init__(self):
-        self.output_dir = ldm_patched.utils.path_utils.get_output_directory()
+        self.output_dir = folder_paths.get_output_directory()
 
     @classmethod
     def INPUT_TYPES(s):
-        return {"required": { "samples": ("LATENT", ),
-                              "filename_prefix": ("STRING", {"default": "latents/ldm_patched"})},
+        return {"required": {"samples": ("LATENT",),
+                             "filename_prefix": ("STRING", {"default": "latents/ComfyUI"})},
                 "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
                 }
+
     RETURN_TYPES = ()
     FUNCTION = "save"
 
@@ -435,7 +457,7 @@ def INPUT_TYPES(s):
     CATEGORY = "_for_testing"
 
     def save(self, samples, filename_prefix="ldm_patched", prompt=None, extra_pnginfo=None):
-        full_output_folder, filename, counter, subfolder, filename_prefix = ldm_patched.utils.path_utils.get_save_image_path(filename_prefix, self.output_dir)
+        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
 
         # support save metadata for latent sharing
         prompt_info = ""
@@ -443,7 +465,7 @@ def save(self, samples, filename_prefix="ldm_patched", prompt=None, extra_pnginf
             prompt_info = json.dumps(prompt)
 
         metadata = None
-        if not args.disable_server_info:
+        if not args.disable_metadata:
             metadata = {"prompt": prompt_info}
             if extra_pnginfo is not None:
                 for x in extra_pnginfo:
@@ -465,33 +487,34 @@ def save(self, samples, filename_prefix="ldm_patched", prompt=None, extra_pnginf
         output["latent_format_version_0"] = torch.tensor([])
 
         ldm_patched.modules.utils.save_torch_file(output, file, metadata=metadata)
-        return { "ui": { "latents": results } }
+        return {"ui": {"latents": results}}
 
 
 class LoadLatent:
     @classmethod
     def INPUT_TYPES(s):
-        input_dir = ldm_patched.utils.path_utils.get_input_directory()
-        files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f)) and f.endswith(".latent")]
+        input_dir = folder_paths.get_input_directory()
+        files = [f for f in os.listdir(input_dir) if
+                 os.path.isfile(os.path.join(input_dir, f)) and f.endswith(".latent")]
         return {"required": {"latent": [sorted(files), ]}, }
 
     CATEGORY = "_for_testing"
 
-    RETURN_TYPES = ("LATENT", )
+    RETURN_TYPES = ("LATENT",)
     FUNCTION = "load"
 
     def load(self, latent):
-        latent_path = ldm_patched.utils.path_utils.get_annotated_filepath(latent)
+        latent_path = folder_paths.get_annotated_filepath(latent)
         latent = safetensors.torch.load_file(latent_path, device="cpu")
         multiplier = 1.0
         if "latent_format_version_0" not in latent:
             multiplier = 1.0 / 0.18215
         samples = {"samples": latent["latent_tensor"].float() * multiplier}
-        return (samples, )
+        return (samples,)
 
     @classmethod
     def IS_CHANGED(s, latent):
-        image_path = ldm_patched.utils.path_utils.get_annotated_filepath(latent)
+        image_path = folder_paths.get_annotated_filepath(latent)
         m = hashlib.sha256()
         with open(image_path, 'rb') as f:
             m.update(f.read())
@@ -499,7 +522,7 @@ def IS_CHANGED(s, latent):
 
     @classmethod
     def VALIDATE_INPUTS(s, latent):
-        if not ldm_patched.utils.path_utils.exists_annotated_filepath(latent):
+        if not folder_paths.exists_annotated_filepath(latent):
             return "Invalid latent file: {}".format(latent)
         return True
 
@@ -507,81 +530,94 @@ def VALIDATE_INPUTS(s, latent):
 
 class CheckpointLoader:
     @classmethod
     def INPUT_TYPES(s):
-        return {"required": { "config_name": (ldm_patched.utils.path_utils.get_filename_list("configs"), ),
-                              "ckpt_name": (ldm_patched.utils.path_utils.get_filename_list("checkpoints"), )}}
+        return {"required": {"config_name": (folder_paths.get_filename_list("configs"),),
+                             "ckpt_name": (folder_paths.get_filename_list("checkpoints"),)}}
+
     RETURN_TYPES = ("MODEL", "CLIP", "VAE")
     FUNCTION = "load_checkpoint"
 
     CATEGORY = "advanced/loaders"
 
     def load_checkpoint(self, config_name, ckpt_name, output_vae=True, output_clip=True):
-        config_path = ldm_patched.utils.path_utils.get_full_path("configs", config_name)
-        ckpt_path = ldm_patched.utils.path_utils.get_full_path("checkpoints", ckpt_name)
-        return ldm_patched.modules.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=ldm_patched.utils.path_utils.get_folder_paths("embeddings"))
+        config_path = folder_paths.get_full_path("configs", config_name)
+        ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
+        return ldm_patched.modules.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True,
+                                                      embedding_directory=folder_paths.get_folder_paths("embeddings"))
+
 
 class CheckpointLoaderSimple:
     @classmethod
     def INPUT_TYPES(s):
-        return {"required": { "ckpt_name": (ldm_patched.utils.path_utils.get_filename_list("checkpoints"), ),
+        return {"required": {"ckpt_name": (folder_paths.get_filename_list("checkpoints"),),
                              }}
+
     RETURN_TYPES = ("MODEL", "CLIP", "VAE")
     FUNCTION = "load_checkpoint"
 
     CATEGORY = "loaders"
 
     def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
-        ckpt_path = ldm_patched.utils.path_utils.get_full_path("checkpoints", ckpt_name)
-        out = ldm_patched.modules.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=ldm_patched.utils.path_utils.get_folder_paths("embeddings"))
+        ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
+        out = ldm_patched.modules.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True,
                                                                  embedding_directory=folder_paths.get_folder_paths("embeddings"))
         return out[:3]
 
+
 class DiffusersLoader:
     @classmethod
     def INPUT_TYPES(cls):
         paths = []
-        for search_path in ldm_patched.utils.path_utils.get_folder_paths("diffusers"):
+        for search_path in folder_paths.get_folder_paths("diffusers"):
             if os.path.exists(search_path):
                 for root, subdir, files in os.walk(search_path, followlinks=True):
                     if "model_index.json" in files:
                         paths.append(os.path.relpath(root, start=search_path))
 
         return {"required": {"model_path": (paths,), }}
+
     RETURN_TYPES = ("MODEL", "CLIP", "VAE")
     FUNCTION = "load_checkpoint"
 
     CATEGORY = "advanced/loaders/deprecated"
 
     def load_checkpoint(self, model_path, output_vae=True, output_clip=True):
-        for search_path in ldm_patched.utils.path_utils.get_folder_paths("diffusers"):
+        for search_path in folder_paths.get_folder_paths("diffusers"):
             if os.path.exists(search_path):
                 path = os.path.join(search_path, model_path)
                 if os.path.exists(path):
                     model_path = path
                     break
 
-        return ldm_patched.modules.diffusers_load.load_diffusers(model_path, output_vae=output_vae, output_clip=output_clip, embedding_directory=ldm_patched.utils.path_utils.get_folder_paths("embeddings"))
+        return ldm_patched.modules.diffusers_load.load_diffusers(model_path, output_vae=output_vae, output_clip=output_clip,
+                                                                 embedding_directory=folder_paths.get_folder_paths("embeddings"))
 
 
 class unCLIPCheckpointLoader:
     @classmethod
     def INPUT_TYPES(s):
-        return {"required": { "ckpt_name": (ldm_patched.utils.path_utils.get_filename_list("checkpoints"), ),
+        return {"required": {"ckpt_name": (folder_paths.get_filename_list("checkpoints"),),
                              }}
+
     RETURN_TYPES = ("MODEL", "CLIP", "VAE", "CLIP_VISION")
     FUNCTION = "load_checkpoint"
 
     CATEGORY = "loaders"
 
     def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
-        ckpt_path = ldm_patched.utils.path_utils.get_full_path("checkpoints", ckpt_name)
-        out = ldm_patched.modules.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=ldm_patched.utils.path_utils.get_folder_paths("embeddings"))
+        ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
+        out = ldm_patched.modules.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True,
+                                                                  output_clipvision=True,
+                                                                  embedding_directory=folder_paths.get_folder_paths("embeddings"))
         return out
 
+
 class CLIPSetLastLayer:
     @classmethod
     def INPUT_TYPES(s):
-        return {"required": { "clip": ("CLIP", ),
-                              "stop_at_clip_layer": ("INT", {"default": -1, "min": -24, "max": -1, "step": 1}),
-                              }}
+        return {"required": {"clip": ("CLIP",),
+                             "stop_at_clip_layer": ("INT", {"default": -1, "min": -24, "max": -1, "step": 1}),
+                             }}
+
     RETURN_TYPES = ("CLIP",)
     FUNCTION = "set_last_layer"
 
@@ -592,18 +628,20 @@ def set_last_layer(self, clip, stop_at_clip_layer):
         clip.clip_layer(stop_at_clip_layer)
         return (clip,)
 
+
 class LoraLoader:
     def __init__(self):
         self.loaded_lora = None
 
     @classmethod
     def INPUT_TYPES(s):
-        return {"required": { "model": ("MODEL",),
-                              "clip": ("CLIP", ),
-                              "lora_name": (ldm_patched.utils.path_utils.get_filename_list("loras"), ),
-                              "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
-                              "strength_clip": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
-                              }}
+        return {"required": {"model": ("MODEL",),
+                             "clip": ("CLIP",),
+                             "lora_name": (folder_paths.get_filename_list("loras"),),
+                             "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01}),
+                             "strength_clip": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01}),
+                             }}
+
     RETURN_TYPES = ("MODEL", "CLIP")
     FUNCTION = "load_lora"
 
@@ -613,7 +651,7 @@ def load_lora(self, model, clip, lora_name, strength_model, strength_clip):
         if strength_model == 0 and strength_clip == 0:
             return (model, clip)
 
-        lora_path = ldm_patched.utils.path_utils.get_full_path("loras", lora_name)
+        lora_path = folder_paths.get_full_path("loras", lora_name)
         lora = None
         if self.loaded_lora is not None:
             if self.loaded_lora[0] == lora_path:
@@ -630,24 +668,27 @@ def load_lora(self, model, clip, lora_name, strength_model, strength_clip):
         model_lora, clip_lora = ldm_patched.modules.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip)
         return (model_lora, clip_lora)
 
+
 class LoraLoaderModelOnly(LoraLoader):
     @classmethod
     def INPUT_TYPES(s):
-        return {"required": { "model": ("MODEL",),
-                              "lora_name": (ldm_patched.utils.path_utils.get_filename_list("loras"), ),
-                              "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
-                              }}
+        return {"required": {"model": ("MODEL",),
+                             "lora_name": (folder_paths.get_filename_list("loras"),),
+                             "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01}),
+                             }}
+
     RETURN_TYPES = ("MODEL",)
     FUNCTION = "load_lora_model_only"
 
     def load_lora_model_only(self, model, lora_name, strength_model):
         return (self.load_lora(model, None, lora_name, strength_model, 0)[0],)
 
+
 class VAELoader:
     @staticmethod
     def vae_list():
-        vaes = ldm_patched.utils.path_utils.get_filename_list("vae")
-        approx_vaes = ldm_patched.utils.path_utils.get_filename_list("vae_approx")
+        vaes = folder_paths.get_filename_list("vae")
+        approx_vaes = folder_paths.get_filename_list("vae_approx")
         sdxl_taesd_enc = False
         sdxl_taesd_dec = False
         sd1_taesd_enc = False
@@ -671,16 +712,16 @@ def vae_list():
     @staticmethod
     def load_taesd(name):
         sd = {}
-        approx_vaes = ldm_patched.utils.path_utils.get_filename_list("vae_approx")
+        approx_vaes = folder_paths.get_filename_list("vae_approx")
 
         encoder = next(filter(lambda a: a.startswith("{}_encoder.".format(name)), approx_vaes))
         decoder = next(filter(lambda a: a.startswith("{}_decoder.".format(name)), approx_vaes))
 
-        enc = ldm_patched.modules.utils.load_torch_file(ldm_patched.utils.path_utils.get_full_path("vae_approx", encoder))
+        enc = ldm_patched.modules.utils.load_torch_file(folder_paths.get_full_path("vae_approx", encoder))
         for k in enc:
             sd["taesd_encoder.{}".format(k)] = enc[k]
 
-        dec = ldm_patched.modules.utils.load_torch_file(ldm_patched.utils.path_utils.get_full_path("vae_approx", decoder))
+        dec = ldm_patched.modules.utils.load_torch_file(folder_paths.get_full_path("vae_approx", decoder))
         for k in dec:
             sd["taesd_decoder.{}".format(k)] = dec[k]
 
@@ -692,26 +733,28 @@ def load_taesd(name):
     @classmethod
     def INPUT_TYPES(s):
-        return {"required": { "vae_name": (s.vae_list(), )}}
+        return {"required": {"vae_name": (s.vae_list(),)}}
+
     RETURN_TYPES = ("VAE",)
     FUNCTION = "load_vae"
 
     CATEGORY = "loaders"
 
-    #TODO: scale factor?
+    # TODO: scale factor?
     def load_vae(self, vae_name):
         if vae_name in ["taesd", "taesdxl"]:
             sd = self.load_taesd(vae_name)
         else:
-            vae_path = ldm_patched.utils.path_utils.get_full_path("vae", vae_name)
+            vae_path = folder_paths.get_full_path("vae", vae_name)
             sd = ldm_patched.modules.utils.load_torch_file(vae_path)
         vae = ldm_patched.modules.sd.VAE(sd=sd)
         return (vae,)
 
+
 class ControlNetLoader:
     @classmethod
     def INPUT_TYPES(s):
-        return {"required": { "control_net_name": (ldm_patched.utils.path_utils.get_filename_list("controlnet"), )}}
+        return {"required": {"control_net_name": (folder_paths.get_filename_list("controlnet"),)}}
 
     RETURN_TYPES = ("CONTROL_NET",)
     FUNCTION = "load_controlnet"
 
@@ -719,15 +762,16 @@ def INPUT_TYPES(s):
     CATEGORY = "loaders"
 
     def load_controlnet(self, control_net_name):
-        controlnet_path = ldm_patched.utils.path_utils.get_full_path("controlnet", control_net_name)
+        controlnet_path = folder_paths.get_full_path("controlnet", control_net_name)
         controlnet = ldm_patched.modules.controlnet.load_controlnet(controlnet_path)
         return (controlnet,)
 
+
 class DiffControlNetLoader:
     @classmethod
     def INPUT_TYPES(s):
-        return {"required": { "model": ("MODEL",),
-                              "control_net_name": (ldm_patched.utils.path_utils.get_filename_list("controlnet"), )}}
+        return {"required": {"model": ("MODEL",),
+                             "control_net_name": (folder_paths.get_filename_list("controlnet"),)}}
 
     RETURN_TYPES = ("CONTROL_NET",)
     FUNCTION = "load_controlnet"
 
@@ -735,7 +779,7 @@ def INPUT_TYPES(s):
     CATEGORY = "loaders"
 
     def load_controlnet(self, model, control_net_name):
-        controlnet_path = ldm_patched.utils.path_utils.get_full_path("controlnet", control_net_name)
+        controlnet_path = folder_paths.get_full_path("controlnet", control_net_name)
         controlnet = ldm_patched.modules.controlnet.load_controlnet(controlnet_path, model)
         return (controlnet,)
 
@@ -743,11 +787,12 @@ def load_controlnet(self, model, control_net_name):
 class ControlNetApply:
     @classmethod
     def INPUT_TYPES(s):
-        return {"required": {"conditioning": ("CONDITIONING", ),
-                             "control_net": ("CONTROL_NET", ),
-                             "image": ("IMAGE", ),
+        return {"required": {"conditioning": ("CONDITIONING",),
+                             "control_net": ("CONTROL_NET",),
+                             "image": ("IMAGE",),
                              "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01})
                              }}
+
     RETURN_TYPES = ("CONDITIONING",)
     FUNCTION = "apply_controlnet"
 
@@ -755,10 +800,10 @@ def INPUT_TYPES(s):
     def apply_controlnet(self, conditioning, control_net, image, strength):
         if strength == 0:
-            return (conditioning, )
+            return (conditioning,)
 
         c = []
-        control_hint = image.movedim(-1,1)
+        control_hint = image.movedim(-1, 1)
         for t in conditioning:
             n = [t[0], t[1].copy()]
             c_net = control_net.copy().set_cond_hint(control_hint, strength)
@@ -767,22 +812,22 @@ def apply_controlnet(self, conditioning, control_net, image, strength):
             n[1]['control'] = c_net
             n[1]['control_apply_to_uncond'] = True
             c.append(n)
-        return (c, )
+        return (c,)
 
 
 class ControlNetApplyAdvanced:
     @classmethod
     def INPUT_TYPES(s):
-        return {"required": {"positive": ("CONDITIONING", ),
-                             "negative": ("CONDITIONING", ),
-                             "control_net": ("CONTROL_NET", ),
-                             "image": ("IMAGE", ),
+        return {"required": {"positive": ("CONDITIONING",),
+                             "negative": ("CONDITIONING",),
+                             "control_net": ("CONTROL_NET",),
+                             "image": ("IMAGE",),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}) }} - RETURN_TYPES = ("CONDITIONING","CONDITIONING") + RETURN_TYPES = ("CONDITIONING", "CONDITIONING") RETURN_NAMES = ("positive", "negative") FUNCTION = "apply_controlnet" @@ -792,7 +837,7 @@ def apply_controlnet(self, positive, negative, control_net, image, strength, sta if strength == 0: return (positive, negative) - control_hint = image.movedim(-1,1) + control_hint = image.movedim(-1, 1) cnets = {} out = [] @@ -820,70 +865,87 @@ def apply_controlnet(self, positive, negative, control_net, image, strength, sta class UNETLoader: @classmethod def INPUT_TYPES(s): - return {"required": { "unet_name": (ldm_patched.utils.path_utils.get_filename_list("unet"), ), + return {"required": {"unet_name": (folder_paths.get_filename_list("unet"),), }} + RETURN_TYPES = ("MODEL",) FUNCTION = "load_unet" CATEGORY = "advanced/loaders" def load_unet(self, unet_name): - unet_path = ldm_patched.utils.path_utils.get_full_path("unet", unet_name) + unet_path = folder_paths.get_full_path("unet", unet_name) model = ldm_patched.modules.sd.load_unet(unet_path) return (model,) + class CLIPLoader: @classmethod def INPUT_TYPES(s): - return {"required": { "clip_name": (ldm_patched.utils.path_utils.get_filename_list("clip"), ), + return {"required": {"clip_name": (folder_paths.get_filename_list("clip"),), + "type": (["stable_diffusion", "stable_cascade"],), }} + RETURN_TYPES = ("CLIP",) FUNCTION = "load_clip" CATEGORY = "advanced/loaders" - def load_clip(self, clip_name): - clip_path = ldm_patched.utils.path_utils.get_full_path("clip", clip_name) - clip = ldm_patched.modules.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=ldm_patched.utils.path_utils.get_folder_paths("embeddings")) + def load_clip(self, clip_name, type="stable_diffusion"): + clip_type = ldm_patched.modules.sd.CLIPType.STABLE_DIFFUSION + if type == "stable_cascade": + clip_type = ldm_patched.modules.sd.CLIPType.STABLE_CASCADE + + clip_path = folder_paths.get_full_path("clip", clip_name) + clip = ldm_patched.modules.sd.load_clip(ckpt_paths=[clip_path], + embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type) return (clip,) + class DualCLIPLoader: @classmethod def INPUT_TYPES(s): - return {"required": { "clip_name1": (ldm_patched.utils.path_utils.get_filename_list("clip"), ), "clip_name2": (ldm_patched.utils.path_utils.get_filename_list("clip"), ), + return {"required": {"clip_name1": (folder_paths.get_filename_list("clip"),), + "clip_name2": (folder_paths.get_filename_list("clip"),), }} + RETURN_TYPES = ("CLIP",) FUNCTION = "load_clip" CATEGORY = "advanced/loaders" def load_clip(self, clip_name1, clip_name2): - clip_path1 = ldm_patched.utils.path_utils.get_full_path("clip", clip_name1) - clip_path2 = ldm_patched.utils.path_utils.get_full_path("clip", clip_name2) - clip = ldm_patched.modules.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=ldm_patched.utils.path_utils.get_folder_paths("embeddings")) + clip_path1 = folder_paths.get_full_path("clip", clip_name1) + clip_path2 = folder_paths.get_full_path("clip", clip_name2) + clip = ldm_patched.modules.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], + embedding_directory=folder_paths.get_folder_paths("embeddings")) return (clip,) + class CLIPVisionLoader: @classmethod def INPUT_TYPES(s): - return 
{"required": { "clip_name": (ldm_patched.utils.path_utils.get_filename_list("clip_vision"), ), + return {"required": {"clip_name": (folder_paths.get_filename_list("clip_vision"),), }} + RETURN_TYPES = ("CLIP_VISION",) FUNCTION = "load_clip" CATEGORY = "loaders" def load_clip(self, clip_name): - clip_path = ldm_patched.utils.path_utils.get_full_path("clip_vision", clip_name) + clip_path = folder_paths.get_full_path("clip_vision", clip_name) clip_vision = ldm_patched.modules.clip_vision.load(clip_path) return (clip_vision,) + class CLIPVisionEncode: @classmethod def INPUT_TYPES(s): - return {"required": { "clip_vision": ("CLIP_VISION",), - "image": ("IMAGE",) + return {"required": {"clip_vision": ("CLIP_VISION",), + "image": ("IMAGE",) }} + RETURN_TYPES = ("CLIP_VISION_OUTPUT",) FUNCTION = "encode" @@ -893,10 +955,11 @@ def encode(self, clip_vision, image): output = clip_vision.encode_image(image) return (output,) + class StyleModelLoader: @classmethod def INPUT_TYPES(s): - return {"required": { "style_model_name": (ldm_patched.utils.path_utils.get_filename_list("style_models"), )}} + return {"required": {"style_model_name": (folder_paths.get_filename_list("style_models"),)}} RETURN_TYPES = ("STYLE_MODEL",) FUNCTION = "load_style_model" @@ -904,7 +967,7 @@ def INPUT_TYPES(s): CATEGORY = "loaders" def load_style_model(self, style_model_name): - style_model_path = ldm_patched.utils.path_utils.get_full_path("style_models", style_model_name) + style_model_path = folder_paths.get_full_path("style_models", style_model_name) style_model = ldm_patched.modules.sd.load_style_model(style_model_path) return (style_model,) @@ -912,10 +975,11 @@ def load_style_model(self, style_model_name): class StyleModelApply: @classmethod def INPUT_TYPES(s): - return {"required": {"conditioning": ("CONDITIONING", ), - "style_model": ("STYLE_MODEL", ), - "clip_vision_output": ("CLIP_VISION_OUTPUT", ), + return {"required": {"conditioning": ("CONDITIONING",), + "style_model": ("STYLE_MODEL",), + "clip_vision_output": ("CLIP_VISION_OUTPUT",), }} + RETURN_TYPES = ("CONDITIONING",) FUNCTION = "apply_stylemodel" @@ -927,16 +991,18 @@ def apply_stylemodel(self, clip_vision_output, style_model, conditioning): for t in conditioning: n = [torch.cat((t[0], cond), dim=1), t[1].copy()] c.append(n) - return (c, ) + return (c,) + class unCLIPConditioning: @classmethod def INPUT_TYPES(s): - return {"required": {"conditioning": ("CONDITIONING", ), - "clip_vision_output": ("CLIP_VISION_OUTPUT", ), + return {"required": {"conditioning": ("CONDITIONING",), + "clip_vision_output": ("CLIP_VISION_OUTPUT",), "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), "noise_augmentation": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}), }} + RETURN_TYPES = ("CONDITIONING",) FUNCTION = "apply_adm" @@ -944,24 +1010,26 @@ def INPUT_TYPES(s): def apply_adm(self, conditioning, clip_vision_output, strength, noise_augmentation): if strength == 0: - return (conditioning, ) + return (conditioning,) c = [] for t in conditioning: o = t[1].copy() - x = {"clip_vision_output": clip_vision_output, "strength": strength, "noise_augmentation": noise_augmentation} + x = {"clip_vision_output": clip_vision_output, "strength": strength, + "noise_augmentation": noise_augmentation} if "unclip_conditioning" in o: o["unclip_conditioning"] = o["unclip_conditioning"][:] + [x] else: o["unclip_conditioning"] = [x] n = [t[0], o] c.append(n) - return (c, ) + return (c,) + class GLIGENLoader: @classmethod def INPUT_TYPES(s): - return 
{"required": { "gligen_name": (ldm_patched.utils.path_utils.get_filename_list("gligen"), )}} + return {"required": {"gligen_name": (folder_paths.get_filename_list("gligen"),)}} RETURN_TYPES = ("GLIGEN",) FUNCTION = "load_gligen" @@ -969,22 +1037,24 @@ def INPUT_TYPES(s): CATEGORY = "loaders" def load_gligen(self, gligen_name): - gligen_path = ldm_patched.utils.path_utils.get_full_path("gligen", gligen_name) + gligen_path = folder_paths.get_full_path("gligen", gligen_name) gligen = ldm_patched.modules.sd.load_gligen(gligen_path) return (gligen,) + class GLIGENTextBoxApply: @classmethod def INPUT_TYPES(s): - return {"required": {"conditioning_to": ("CONDITIONING", ), - "clip": ("CLIP", ), - "gligen_textbox_model": ("GLIGEN", ), - "text": ("STRING", {"multiline": True}), - "width": ("INT", {"default": 64, "min": 8, "max": MAX_RESOLUTION, "step": 8}), - "height": ("INT", {"default": 64, "min": 8, "max": MAX_RESOLUTION, "step": 8}), - "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), - "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), + return {"required": {"conditioning_to": ("CONDITIONING",), + "clip": ("CLIP",), + "gligen_textbox_model": ("GLIGEN",), + "text": ("STRING", {"multiline": True, "dynamicPrompts": True}), + "width": ("INT", {"default": 64, "min": 8, "max": MAX_RESOLUTION, "step": 8}), + "height": ("INT", {"default": 64, "min": 8, "max": MAX_RESOLUTION, "step": 8}), + "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), + "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), }} + RETURN_TYPES = ("CONDITIONING",) FUNCTION = "append" @@ -992,7 +1062,7 @@ def INPUT_TYPES(s): def append(self, conditioning_to, clip, gligen_textbox_model, text, width, height, x, y): c = [] - cond, cond_pooled = clip.encode_from_tokens(clip.tokenize(text), return_pooled=True) + cond, cond_pooled = clip.encode_from_tokens(clip.tokenize(text), return_pooled="unprojected") for t in conditioning_to: n = [t[0], t[1].copy()] position_params = [(cond_pooled, height // 8, width // 8, y // 8, x // 8)] @@ -1002,7 +1072,8 @@ def append(self, conditioning_to, clip, gligen_textbox_model, text, width, heigh n[1]['gligen'] = ("position", gligen_textbox_model, prev + position_params) c.append(n) - return (c, ) + return (c,) + class EmptyLatentImage: def __init__(self): @@ -1010,9 +1081,10 @@ def __init__(self): @classmethod def INPUT_TYPES(s): - return {"required": { "width": ("INT", {"default": 512, "min": 16, "max": MAX_RESOLUTION, "step": 8}), - "height": ("INT", {"default": 512, "min": 16, "max": MAX_RESOLUTION, "step": 8}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}} + return {"required": {"width": ("INT", {"default": 512, "min": 16, "max": MAX_RESOLUTION, "step": 8}), + "height": ("INT", {"default": 512, "min": 16, "max": MAX_RESOLUTION, "step": 8}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}} + RETURN_TYPES = ("LATENT",) FUNCTION = "generate" @@ -1020,16 +1092,17 @@ def INPUT_TYPES(s): def generate(self, width, height, batch_size=1): latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=self.device) - return ({"samples":latent}, ) + return ({"samples": latent},) class LatentFromBatch: @classmethod def INPUT_TYPES(s): - return {"required": { "samples": ("LATENT",), - "batch_index": ("INT", {"default": 0, "min": 0, "max": 63}), - "length": ("INT", {"default": 1, "min": 1, "max": 64}), - }} + return {"required": {"samples": ("LATENT",), + "batch_index": ("INT", {"default": 
0, "min": 0, "max": 63}), + "length": ("INT", {"default": 1, "min": 1, "max": 64}), + }} + RETURN_TYPES = ("LATENT",) FUNCTION = "frombatch" @@ -1050,17 +1123,19 @@ def frombatch(self, samples, batch_index, length): masks = masks.repeat(math.ceil(s_in.shape[0] / masks.shape[0]), 1, 1, 1)[:s_in.shape[0]] s["noise_mask"] = masks[batch_index:batch_index + length].clone() if "batch_index" not in s: - s["batch_index"] = [x for x in range(batch_index, batch_index+length)] + s["batch_index"] = [x for x in range(batch_index, batch_index + length)] else: s["batch_index"] = samples["batch_index"][batch_index:batch_index + length] return (s,) - + + class RepeatLatentBatch: @classmethod def INPUT_TYPES(s): - return {"required": { "samples": ("LATENT",), - "amount": ("INT", {"default": 1, "min": 1, "max": 64}), - }} + return {"required": {"samples": ("LATENT",), + "amount": ("INT", {"default": 1, "min": 1, "max": 64}), + }} + RETURN_TYPES = ("LATENT",) FUNCTION = "repeat" @@ -1069,28 +1144,30 @@ def INPUT_TYPES(s): def repeat(self, samples, amount): s = samples.copy() s_in = samples["samples"] - - s["samples"] = s_in.repeat((amount, 1,1,1)) + + s["samples"] = s_in.repeat((amount, 1, 1, 1)) if "noise_mask" in samples and samples["noise_mask"].shape[0] > 1: masks = samples["noise_mask"] if masks.shape[0] < s_in.shape[0]: masks = masks.repeat(math.ceil(s_in.shape[0] / masks.shape[0]), 1, 1, 1)[:s_in.shape[0]] - s["noise_mask"] = samples["noise_mask"].repeat((amount, 1,1,1)) + s["noise_mask"] = samples["noise_mask"].repeat((amount, 1, 1, 1)) if "batch_index" in s: offset = max(s["batch_index"]) - min(s["batch_index"]) + 1 s["batch_index"] = s["batch_index"] + [x + (i * offset) for i in range(1, amount) for x in s["batch_index"]] return (s,) + class LatentUpscale: upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "bislerp"] crop_methods = ["disabled", "center"] @classmethod def INPUT_TYPES(s): - return {"required": { "samples": ("LATENT",), "upscale_method": (s.upscale_methods,), - "width": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 8}), - "height": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 8}), - "crop": (s.crop_methods,)}} + return {"required": {"samples": ("LATENT",), "upscale_method": (s.upscale_methods,), + "width": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 8}), + "height": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 8}), + "crop": (s.crop_methods,)}} + RETURN_TYPES = ("LATENT",) FUNCTION = "upscale" @@ -1115,13 +1192,15 @@ def upscale(self, samples, upscale_method, width, height, crop): s["samples"] = ldm_patched.modules.utils.common_upscale(samples["samples"], width // 8, height // 8, upscale_method, crop) return (s,) + class LatentUpscaleBy: upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "bislerp"] @classmethod def INPUT_TYPES(s): - return {"required": { "samples": ("LATENT",), "upscale_method": (s.upscale_methods,), - "scale_by": ("FLOAT", {"default": 1.5, "min": 0.01, "max": 8.0, "step": 0.01}),}} + return {"required": {"samples": ("LATENT",), "upscale_method": (s.upscale_methods,), + "scale_by": ("FLOAT", {"default": 1.5, "min": 0.01, "max": 8.0, "step": 0.01}), }} + RETURN_TYPES = ("LATENT",) FUNCTION = "upscale" @@ -1134,12 +1213,14 @@ def upscale(self, samples, upscale_method, scale_by): s["samples"] = ldm_patched.modules.utils.common_upscale(samples["samples"], width, height, upscale_method, "disabled") return (s,) + class LatentRotate: @classmethod def 
     def INPUT_TYPES(s):
-        return {"required": { "samples": ("LATENT",),
-                              "rotation": (["none", "90 degrees", "180 degrees", "270 degrees"],),
-                              }}
+        return {"required": {"samples": ("LATENT",),
+                             "rotation": (["none", "90 degrees", "180 degrees", "270 degrees"],),
+                             }}
+
     RETURN_TYPES = ("LATENT",)
     FUNCTION = "rotate"
 
@@ -1158,12 +1239,14 @@ def rotate(self, samples, rotation):
         s["samples"] = torch.rot90(samples["samples"], k=rotate_by, dims=[3, 2])
         return (s,)
 
+
 class LatentFlip:
     @classmethod
     def INPUT_TYPES(s):
-        return {"required": { "samples": ("LATENT",),
-                              "flip_method": (["x-axis: vertically", "y-axis: horizontally"],),
-                              }}
+        return {"required": {"samples": ("LATENT",),
+                             "flip_method": (["x-axis: vertically", "y-axis: horizontally"],),
+                             }}
+
     RETURN_TYPES = ("LATENT",)
     FUNCTION = "flip"
 
@@ -1178,22 +1261,24 @@ def flip(self, samples, flip_method):
 
         return (s,)
 
+
 class LatentComposite:
     @classmethod
     def INPUT_TYPES(s):
-        return {"required": { "samples_to": ("LATENT",),
-                              "samples_from": ("LATENT",),
-                              "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
-                              "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
-                              "feather": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
-                              }}
+        return {"required": {"samples_to": ("LATENT",),
+                             "samples_from": ("LATENT",),
+                             "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
+                             "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
+                             "feather": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
+                             }}
+
     RETURN_TYPES = ("LATENT",)
     FUNCTION = "composite"
 
     CATEGORY = "latent"
 
     def composite(self, samples_to, samples_from, x, y, composite_method="normal", feather=0):
-        x =  x // 8
+        x = x // 8
         y = y // 8
         feather = feather // 8
         samples_out = samples_to.copy()
@@ -1201,25 +1286,37 @@ def composite(self, samples_to, samples_from, x, y, composite_method="normal", f
         samples_to = samples_to["samples"]
         samples_from = samples_from["samples"]
         if feather == 0:
-            s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x]
+            s[:, :, y:y + samples_from.shape[2], x:x + samples_from.shape[3]] = samples_from[:, :,
+                                                                                :samples_to.shape[2] - y,
+                                                                                :samples_to.shape[3] - x]
         else:
-            samples_from = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x]
+            samples_from = samples_from[:, :, :samples_to.shape[2] - y, :samples_to.shape[3] - x]
             mask = torch.ones_like(samples_from)
             for t in range(feather):
                 if y != 0:
-                    mask[:,:,t:1+t,:] *= ((1.0/feather) * (t + 1))
+                    mask[:, :, t:1 + t, :] *= ((1.0 / feather) * (t + 1))
 
                 if y + samples_from.shape[2] < samples_to.shape[2]:
-                    mask[:,:,mask.shape[2] -1 -t: mask.shape[2]-t,:] *= ((1.0/feather) * (t + 1))
+                    mask[:, :, mask.shape[2] - 1 - t: mask.shape[2] - t, :] *= ((1.0 / feather) * (t + 1))
                 if x != 0:
-                    mask[:,:,:,t:1+t] *= ((1.0/feather) * (t + 1))
+                    mask[:, :, :, t:1 + t] *= ((1.0 / feather) * (t + 1))
                 if x + samples_from.shape[3] < samples_to.shape[3]:
-                    mask[:,:,:,mask.shape[3]- 1 - t: mask.shape[3]- t] *= ((1.0/feather) * (t + 1))
+                    mask[:, :, :, mask.shape[3] - 1 - t: mask.shape[3] - t] *= ((1.0 / feather) * (t + 1))
             rev_mask = torch.ones_like(mask) - mask
-            s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x] * mask + s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] * rev_mask
+            s[:, :, y:y + samples_from.shape[2], x:x + samples_from.shape[3]] = \
+                samples_from[:, :, :samples_to.shape[2] - y, :samples_to.shape[3] - x] * mask + \
+                s[:, :, y:y + samples_from.shape[2], x:x + samples_from.shape[3]] * rev_mask
         samples_out["samples"] = s
         return (samples_out,)
 
+
 class LatentBlend:
     @classmethod
     def INPUT_TYPES(s):
@@ -1239,7 +1336,7 @@ def INPUT_TYPES(s):
 
     CATEGORY = "_for_testing"
 
-    def blend(self, samples1, samples2, blend_factor:float, blend_mode: str="normal"):
+    def blend(self, samples1, samples2, blend_factor: float, blend_mode: str = "normal"):
 
         samples_out = samples1.copy()
 
         samples1 = samples1["samples"]
@@ -1247,7 +1344,8 @@ def blend(self, samples1, samples2, blend_factor:float, blend_mode: str="normal"
         if samples1.shape != samples2.shape:
             samples2.permute(0, 3, 1, 2)
-            samples2 = ldm_patched.modules.utils.common_upscale(samples2, samples1.shape[3], samples1.shape[2], 'bicubic', crop='center')
+            samples2 = ldm_patched.modules.utils.common_upscale(samples2, samples1.shape[3], samples1.shape[2], 'bicubic',
+                                                                crop='center')
             samples2.permute(0, 2, 3, 1)
 
         samples_blended = self.blend_mode(samples1, samples2, blend_mode)
@@ -1261,15 +1359,17 @@ def blend_mode(self, img1, img2, mode):
         else:
             raise ValueError(f"Unsupported blend mode: {mode}")
 
+
 class LatentCrop:
     @classmethod
     def INPUT_TYPES(s):
-        return {"required": { "samples": ("LATENT",),
-                              "width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
-                              "height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
-                              "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
-                              "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
-                              }}
+        return {"required": {"samples": ("LATENT",),
+                             "width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
+                             "height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
+                             "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
+                             "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
+                             }}
+
     RETURN_TYPES = ("LATENT",)
     FUNCTION = "crop"
 
@@ -1278,10 +1378,10 @@ def INPUT_TYPES(s):
     def crop(self, samples, width, height, x, y):
         s = samples.copy()
         samples = samples['samples']
-        x =  x // 8
+        x = x // 8
         y = y // 8
 
-        #enfonce minimum size of 64
+        # enforce minimum size of 64
         if x > (samples.shape[3] - 8):
             x = samples.shape[3] - 8
         if y > (samples.shape[2] - 8):
@@ -1291,15 +1391,17 @@ def crop(self, samples, width, height, x, y):
         new_width = width // 8
         to_x = new_width + x
         to_y = new_height + y
-        s['samples'] = samples[:,:,y:to_y, x:to_x]
+        s['samples'] = samples[:, :, y:to_y, x:to_x]
         return (s,)
 
+
 class SetLatentNoiseMask:
     @classmethod
     def INPUT_TYPES(s):
-        return {"required": { "samples": ("LATENT",),
-                              "mask": ("MASK",),
-                              }}
+        return {"required": {"samples": ("LATENT",),
+                             "mask": ("MASK",),
+                             }}
+
     RETURN_TYPES = ("LATENT",)
     FUNCTION = "set_mask"
 
@@ -1310,7 +1412,9 @@ def set_mask(self, samples, mask):
         s["noise_mask"] = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
         return (s,)
 
-def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
+
+def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0,
+                    disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
     latent_image = latent["samples"]
     if disable_noise:
device="cpu") @@ -1322,29 +1426,32 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, if "noise_mask" in latent: noise_mask = latent["noise_mask"] - callback = ldm_patched.utils.latent_visualization.prepare_callback(model, steps) + callback = latent_preview.prepare_callback(model, steps) disable_pbar = not ldm_patched.modules.utils.PROGRESS_BAR_ENABLED samples = ldm_patched.modules.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, - denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step, - force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed) + denoise=denoise, disable_noise=disable_noise, start_step=start_step, + last_step=last_step, + force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback, + disable_pbar=disable_pbar, seed=seed) out = latent.copy() out["samples"] = samples - return (out, ) + return (out,) + class KSampler: @classmethod def INPUT_TYPES(s): return {"required": {"model": ("MODEL",), - "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), - "steps": ("INT", {"default": 20, "min": 1, "max": 10000}), - "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}), - "sampler_name": (ldm_patched.modules.samplers.KSampler.SAMPLERS, ), - "scheduler": (ldm_patched.modules.samplers.KSampler.SCHEDULERS, ), - "positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "latent_image": ("LATENT", ), - "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), + "steps": ("INT", {"default": 20, "min": 1, "max": 10000}), + "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step": 0.1, "round": 0.01}), + "sampler_name": (ldm_patched.modules.samplers.KSampler.SAMPLERS,), + "scheduler": (ldm_patched.modules.samplers.KSampler.SCHEDULERS,), + "positive": ("CONDITIONING",), + "negative": ("CONDITIONING",), + "latent_image": ("LATENT",), + "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), } } @@ -1354,25 +1461,27 @@ def INPUT_TYPES(s): CATEGORY = "sampling" def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0): - return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise) + return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, + denoise=denoise) + class KSamplerAdvanced: @classmethod def INPUT_TYPES(s): return {"required": {"model": ("MODEL",), - "add_noise": (["enable", "disable"], ), - "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), - "steps": ("INT", {"default": 20, "min": 1, "max": 10000}), - "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}), - "sampler_name": (ldm_patched.modules.samplers.KSampler.SAMPLERS, ), - "scheduler": (ldm_patched.modules.samplers.KSampler.SCHEDULERS, ), - "positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "latent_image": ("LATENT", ), - "start_at_step": ("INT", {"default": 0, "min": 0, "max": 10000}), - "end_at_step": ("INT", {"default": 10000, "min": 0, "max": 10000}), - "return_with_leftover_noise": (["disable", "enable"], ), + "add_noise": (["enable", "disable"],), + "noise_seed": ("INT", {"default": 0, "min": 0, "max": 
0xffffffffffffffff}), + "steps": ("INT", {"default": 20, "min": 1, "max": 10000}), + "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step": 0.1, "round": 0.01}), + "sampler_name": (ldm_patched.modules.samplers.KSampler.SAMPLERS,), + "scheduler": (ldm_patched.modules.samplers.KSampler.SCHEDULERS,), + "positive": ("CONDITIONING",), + "negative": ("CONDITIONING",), + "latent_image": ("LATENT",), + "start_at_step": ("INT", {"default": 0, "min": 0, "max": 10000}), + "end_at_step": ("INT", {"default": 10000, "min": 0, "max": 10000}), + "return_with_leftover_noise": (["disable", "enable"],), } } @@ -1381,27 +1490,31 @@ def INPUT_TYPES(s): CATEGORY = "sampling" - def sample(self, model, add_noise, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, start_at_step, end_at_step, return_with_leftover_noise, denoise=1.0): + def sample(self, model, add_noise, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, + latent_image, start_at_step, end_at_step, return_with_leftover_noise, denoise=1.0): force_full_denoise = True if return_with_leftover_noise == "enable": force_full_denoise = False disable_noise = False if add_noise == "disable": disable_noise = True - return common_ksampler(model, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_at_step, last_step=end_at_step, force_full_denoise=force_full_denoise) + return common_ksampler(model, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, + denoise=denoise, disable_noise=disable_noise, start_step=start_at_step, + last_step=end_at_step, force_full_denoise=force_full_denoise) + class SaveImage: def __init__(self): - self.output_dir = ldm_patched.utils.path_utils.get_output_directory() + self.output_dir = folder_paths.get_output_directory() self.type = "output" self.prefix_append = "" self.compress_level = 4 @classmethod def INPUT_TYPES(s): - return {"required": - {"images": ("IMAGE", ), - "filename_prefix": ("STRING", {"default": "ldm_patched"})}, + return {"required": + {"images": ("IMAGE",), + "filename_prefix": ("STRING", {"default": "ComfyUI"})}, "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, } @@ -1412,15 +1525,16 @@ def INPUT_TYPES(s): CATEGORY = "image" - def save_images(self, images, filename_prefix="ldm_patched", prompt=None, extra_pnginfo=None): + def save_images(self, images, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): filename_prefix += self.prefix_append - full_output_folder, filename, counter, subfolder, filename_prefix = ldm_patched.utils.path_utils.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]) + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path( + filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]) results = list() - for image in images: + for (batch_number, image) in enumerate(images): i = 255. 
* image.cpu().numpy() img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)) metadata = None - if not args.disable_server_info: + if not args.disable_metadata: metadata = PngInfo() if prompt is not None: metadata.add_text("prompt", json.dumps(prompt)) @@ -1428,7 +1542,8 @@ def save_images(self, images, filename_prefix="ldm_patched", prompt=None, extra_ for x in extra_pnginfo: metadata.add_text(x, json.dumps(extra_pnginfo[x])) - file = f"{filename}_{counter:05}_.png" + filename_with_batch_num = filename.replace("%batch_num%", str(batch_number)) + file = f"{filename_with_batch_num}_{counter:05}_.png" img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=self.compress_level) results.append({ "filename": file, @@ -1437,11 +1552,12 @@ def save_images(self, images, filename_prefix="ldm_patched", prompt=None, extra_ }) counter += 1 - return { "ui": { "images": results } } + return {"ui": {"images": results}} + class PreviewImage(SaveImage): def __init__(self): - self.output_dir = ldm_patched.utils.path_utils.get_temp_directory() + self.output_dir = folder_paths.get_temp_directory() self.type = "temp" self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5)) self.compress_level = 1 @@ -1449,14 +1565,15 @@ def __init__(self): @classmethod def INPUT_TYPES(s): return {"required": - {"images": ("IMAGE", ), }, + {"images": ("IMAGE",), }, "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, } + class LoadImage: @classmethod def INPUT_TYPES(s): - input_dir = ldm_patched.utils.path_utils.get_input_directory() + input_dir = folder_paths.get_input_directory() files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] return {"required": {"image": (sorted(files), {"image_upload": True})}, @@ -1466,27 +1583,43 @@ def INPUT_TYPES(s): RETURN_TYPES = ("IMAGE", "MASK") FUNCTION = "load_image" + def load_image(self, image): - image_path = ldm_patched.utils.path_utils.get_annotated_filepath(image) - img = Image.open(image_path) + image_path = folder_paths.get_annotated_filepath(image) + + img = node_helpers.pillow(Image.open, image_path) + output_images = [] output_masks = [] + w, h = None, None + + excluded_formats = ['MPO'] + for i in ImageSequence.Iterator(img): - i = ImageOps.exif_transpose(i) + i = node_helpers.pillow(ImageOps.exif_transpose, i) + if i.mode == 'I': i = i.point(lambda i: i * (1 / 255)) image = i.convert("RGB") + + if len(output_images) == 0: + w = image.size[0] + h = image.size[1] + + if image.size[0] != w or image.size[1] != h: + continue + image = np.array(image).astype(np.float32) / 255.0 image = torch.from_numpy(image)[None,] if 'A' in i.getbands(): mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0 mask = 1. 
- torch.from_numpy(mask) else: - mask = torch.zeros((64,64), dtype=torch.float32, device="cpu") + mask = torch.zeros((64, 64), dtype=torch.float32, device="cpu") output_images.append(image) output_masks.append(mask.unsqueeze(0)) - if len(output_images) > 1: + if len(output_images) > 1 and img.format not in excluded_formats: output_image = torch.cat(output_images, dim=0) output_mask = torch.cat(output_masks, dim=0) else: @@ -1497,7 +1630,7 @@ def load_image(self, image): @classmethod def IS_CHANGED(s, image): - image_path = ldm_patched.utils.path_utils.get_annotated_filepath(image) + image_path = folder_paths.get_annotated_filepath(image) m = hashlib.sha256() with open(image_path, 'rb') as f: m.update(f.read()) @@ -1505,30 +1638,33 @@ def IS_CHANGED(s, image): @classmethod def VALIDATE_INPUTS(s, image): - if not ldm_patched.utils.path_utils.exists_annotated_filepath(image): + if not folder_paths.exists_annotated_filepath(image): return "Invalid image file: {}".format(image) return True + class LoadImageMask: _color_channels = ["alpha", "red", "green", "blue"] + @classmethod def INPUT_TYPES(s): - input_dir = ldm_patched.utils.path_utils.get_input_directory() + input_dir = folder_paths.get_input_directory() files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] return {"required": {"image": (sorted(files), {"image_upload": True}), - "channel": (s._color_channels, ), } + "channel": (s._color_channels,), } } CATEGORY = "mask" RETURN_TYPES = ("MASK",) FUNCTION = "load_image" + def load_image(self, image, channel): - image_path = ldm_patched.utils.path_utils.get_annotated_filepath(image) - i = Image.open(image_path) - i = ImageOps.exif_transpose(i) + image_path = folder_paths.get_annotated_filepath(image) + i = node_helpers.pillow(Image.open, image_path) + i = node_helpers.pillow(ImageOps.exif_transpose, i) if i.getbands() != ("R", "G", "B", "A"): if i.mode == 'I': i = i.point(lambda i: i * (1 / 255)) @@ -1541,12 +1677,12 @@ def load_image(self, image, channel): if c == 'A': mask = 1. 
- mask else: - mask = torch.zeros((64,64), dtype=torch.float32, device="cpu") + mask = torch.zeros((64, 64), dtype=torch.float32, device="cpu") return (mask.unsqueeze(0),) @classmethod def IS_CHANGED(s, image, channel): - image_path = ldm_patched.utils.path_utils.get_annotated_filepath(image) + image_path = folder_paths.get_annotated_filepath(image) m = hashlib.sha256() with open(image_path, 'rb') as f: m.update(f.read()) @@ -1554,21 +1690,23 @@ def IS_CHANGED(s, image, channel): @classmethod def VALIDATE_INPUTS(s, image): - if not ldm_patched.utils.path_utils.exists_annotated_filepath(image): + if not folder_paths.exists_annotated_filepath(image): return "Invalid image file: {}".format(image) return True + class ImageScale: upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"] crop_methods = ["disabled", "center"] @classmethod def INPUT_TYPES(s): - return {"required": { "image": ("IMAGE",), "upscale_method": (s.upscale_methods,), - "width": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "height": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "crop": (s.crop_methods,)}} + return {"required": {"image": ("IMAGE",), "upscale_method": (s.upscale_methods,), + "width": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1}), + "height": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1}), + "crop": (s.crop_methods,)}} + RETURN_TYPES = ("IMAGE",) FUNCTION = "upscale" @@ -1578,7 +1716,7 @@ def upscale(self, image, upscale_method, width, height, crop): if width == 0 and height == 0: s = image else: - samples = image.movedim(-1,1) + samples = image.movedim(-1, 1) if width == 0: width = max(1, round(samples.shape[3] * height / samples.shape[2])) @@ -1586,34 +1724,37 @@ def upscale(self, image, upscale_method, width, height, crop): height = max(1, round(samples.shape[2] * width / samples.shape[3])) s = ldm_patched.modules.utils.common_upscale(samples, width, height, upscale_method, crop) - s = s.movedim(1,-1) + s = s.movedim(1, -1) return (s,) + class ImageScaleBy: upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"] @classmethod def INPUT_TYPES(s): - return {"required": { "image": ("IMAGE",), "upscale_method": (s.upscale_methods,), - "scale_by": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 8.0, "step": 0.01}),}} + return {"required": {"image": ("IMAGE",), "upscale_method": (s.upscale_methods,), + "scale_by": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 8.0, "step": 0.01}), }} + RETURN_TYPES = ("IMAGE",) FUNCTION = "upscale" CATEGORY = "image/upscaling" def upscale(self, image, upscale_method, scale_by): - samples = image.movedim(-1,1) + samples = image.movedim(-1, 1) width = round(samples.shape[3] * scale_by) height = round(samples.shape[2] * scale_by) s = ldm_patched.modules.utils.common_upscale(samples, width, height, upscale_method, "disabled") - s = s.movedim(1,-1) + s = s.movedim(1, -1) return (s,) + class ImageInvert: @classmethod def INPUT_TYPES(s): - return {"required": { "image": ("IMAGE",)}} + return {"required": {"image": ("IMAGE",)}} RETURN_TYPES = ("IMAGE",) FUNCTION = "invert" @@ -1624,11 +1765,12 @@ def invert(self, image): s = 1.0 - image return (s,) + class ImageBatch: @classmethod def INPUT_TYPES(s): - return {"required": { "image1": ("IMAGE",), "image2": ("IMAGE",)}} + return {"required": {"image1": ("IMAGE",), "image2": ("IMAGE",)}} RETURN_TYPES = ("IMAGE",) FUNCTION = "batch" @@ -1637,21 +1779,24 @@ def INPUT_TYPES(s): def batch(self, image1, 
image2): if image1.shape[1:] != image2.shape[1:]: - image2 = ldm_patched.modules.utils.common_upscale(image2.movedim(-1,1), image1.shape[2], image1.shape[1], "bilinear", "center").movedim(1,-1) + image2 = ldm_patched.modules.utils.common_upscale(image2.movedim(-1, 1), image1.shape[2], image1.shape[1], "bilinear", + "center").movedim(1, -1) s = torch.cat((image1, image2), dim=0) return (s,) + class EmptyImage: def __init__(self, device="cpu"): self.device = device @classmethod def INPUT_TYPES(s): - return {"required": { "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), - "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - "color": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFF, "step": 1, "display": "color"}), - }} + return {"required": {"width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), + "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), + "color": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFF, "step": 1, "display": "color"}), + }} + RETURN_TYPES = ("IMAGE",) FUNCTION = "generate" @@ -1661,7 +1806,8 @@ def generate(self, width, height, batch_size=1, color=0): r = torch.full([batch_size, height, width, 1], ((color >> 16) & 0xFF) / 0xFF) g = torch.full([batch_size, height, width, 1], ((color >> 8) & 0xFF) / 0xFF) b = torch.full([batch_size, height, width, 1], ((color) & 0xFF) / 0xFF) - return (torch.cat((r, g, b), dim=-1), ) + return (torch.cat((r, g, b), dim=-1),) + class ImagePadForOutpaint: @@ -1751,11 +1897,12 @@ def expand_image(self, image, left, top, right, bottom, feathering): "ImageBatch": ImageBatch, "ImagePadForOutpaint": ImagePadForOutpaint, "EmptyImage": EmptyImage, - "ConditioningAverage": ConditioningAverage , + "ConditioningAverage": ConditioningAverage, "ConditioningCombine": ConditioningCombine, "ConditioningConcat": ConditioningConcat, "ConditioningSetArea": ConditioningSetArea, "ConditioningSetAreaPercentage": ConditioningSetAreaPercentage, + "ConditioningSetAreaStrength": ConditioningSetAreaStrength, "ConditioningSetMask": ConditioningSetMask, "KSamplerAdvanced": KSamplerAdvanced, "SetLatentNoiseMask": SetLatentNoiseMask, @@ -1836,7 +1983,7 @@ def expand_image(self, image, left, top, right, bottom, feathering): "LatentUpscaleBy": "Upscale Latent By", "LatentComposite": "Latent Composite", "LatentBlend": "Latent Blend", - "LatentFromBatch" : "Latent From Batch", + "LatentFromBatch": "Latent From Batch", "RepeatLatentBatch": "Repeat Latent Batch", # Image "SaveImage": "Save Image", @@ -1856,12 +2003,14 @@ def expand_image(self, image, left, top, right, bottom, feathering): EXTENSION_WEB_DIRS = {} + def load_custom_node(module_path, ignore=set()): module_name = os.path.basename(module_path) if os.path.isfile(module_path): sp = os.path.splitext(module_path) module_name = sp[0] try: + logging.debug("Trying to load custom node {}".format(module_path)) if os.path.isfile(module_path): module_spec = importlib.util.spec_from_file_location(module_name, module_path) module_dir = os.path.split(module_path)[0] @@ -1882,20 +2031,22 @@ def load_custom_node(module_path, ignore=set()): for name in module.NODE_CLASS_MAPPINGS: if name not in ignore: NODE_CLASS_MAPPINGS[name] = module.NODE_CLASS_MAPPINGS[name] - if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None: + if hasattr(module, 
"NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, + "NODE_DISPLAY_NAME_MAPPINGS") is not None: NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS) return True else: - print(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.") + logging.warning(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.") return False except Exception as e: - print(traceback.format_exc()) - print(f"Cannot import {module_path} module for custom nodes:", e) + logging.warning(traceback.format_exc()) + logging.warning(f"Cannot import {module_path} module for custom nodes: {e}") return False + def load_custom_nodes(): base_node_names = set(NODE_CLASS_MAPPINGS.keys()) - node_paths = ldm_patched.utils.path_utils.get_folder_paths("custom_nodes") + node_paths = folder_paths.get_folder_paths("custom_nodes") node_import_times = [] for custom_node_path in node_paths: possible_modules = os.listdir(os.path.realpath(custom_node_path)) @@ -1911,17 +2062,18 @@ def load_custom_nodes(): node_import_times.append((time.perf_counter() - time_before, module_path, success)) if len(node_import_times) > 0: - print("\nImport times for custom nodes:") + logging.info("\nImport times for custom nodes:") for n in sorted(node_import_times): if n[2]: import_message = "" else: import_message = " (IMPORT FAILED)" - print("{:6.1f} seconds{}:".format(n[0], import_message), n[1]) - print() + logging.info("{:6.1f} seconds{}: {}".format(n[0], import_message, n[1])) + logging.info("") + def init_custom_nodes(): - extras_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "ldm_patched_extras") + extras_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras") extras_files = [ "nodes_latent.py", "nodes_hypernetwork.py", @@ -1946,9 +2098,35 @@ def init_custom_nodes(): "nodes_stable3d.py", "nodes_sdupscale.py", "nodes_photomaker.py", + "nodes_cond.py", + "nodes_morphology.py", + "nodes_stable_cascade.py", + "nodes_differential_diffusion.py", + "nodes_ip2p.py", + "nodes_model_merging_model_specific.py", + "nodes_pag.py", + "nodes_align_your_steps.py", + "nodes_attention_multiply.py", + "nodes_advanced_samplers.py", + "nodes_webcam.py", ] + import_failed = [] for node_file in extras_files: - load_custom_node(os.path.join(extras_dir, node_file)) + if not load_custom_node(os.path.join(extras_dir, node_file)): + import_failed.append(node_file) load_custom_nodes() + + if len(import_failed) > 0: + logging.warning( + "WARNING: some comfy_extras/ nodes did not import correctly. 
This may be because they are missing some dependencies.\n") + for node in import_failed: + logging.warning("IMPORT FAILED: {}".format(node)) + logging.warning( + "\nThis issue might be caused by new missing dependencies added the last time you updated ComfyUI.") + if args.windows_standalone_build: + logging.warning("Please run the update script: update/update_comfyui.bat") + else: + logging.warning("Please do a: pip install -r requirements.txt") + logging.warning("") \ No newline at end of file diff --git a/ldm_patched/contrib/external_advanced_samplers.py b/ldm_patched/contrib/external_advanced_samplers.py new file mode 100644 index 000000000..885d6f5de --- /dev/null +++ b/ldm_patched/contrib/external_advanced_samplers.py @@ -0,0 +1,61 @@ +import ldm_patched.modules.samplers +import ldm_patched.modules.utils +import torch +import numpy as np +from tqdm.auto import trange, tqdm +import math + + +@torch.no_grad() +def sample_lcm_upscale(model, x, sigmas, extra_args=None, callback=None, disable=None, total_upscale=2.0, upscale_method="bislerp", upscale_steps=None): + extra_args = {} if extra_args is None else extra_args + + if upscale_steps is None: + upscale_steps = max(len(sigmas) // 2 + 1, 2) + else: + upscale_steps += 1 + upscale_steps = min(upscale_steps, len(sigmas) + 1) + + upscales = np.linspace(1.0, total_upscale, upscale_steps)[1:] + + orig_shape = x.size() + s_in = x.new_ones([x.shape[0]]) + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + + x = denoised + if i < len(upscales): + x = ldm_patched.modules.utils.common_upscale(x, round(orig_shape[-1] * upscales[i]), round(orig_shape[-2] * upscales[i]), upscale_method, "disabled") + + if sigmas[i + 1] > 0: + x += sigmas[i + 1] * torch.randn_like(x) + return x + + +class SamplerLCMUpscale: + upscale_methods = ["bislerp", "nearest-exact", "bilinear", "area", "bicubic"] + + @classmethod + def INPUT_TYPES(s): + return {"required": + {"scale_ratio": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 20.0, "step": 0.01}), + "scale_steps": ("INT", {"default": -1, "min": -1, "max": 1000, "step": 1}), + "upscale_method": (s.upscale_methods,), + } + } + RETURN_TYPES = ("SAMPLER",) + CATEGORY = "sampling/custom_sampling/samplers" + + FUNCTION = "get_sampler" + + def get_sampler(self, scale_ratio, scale_steps, upscale_method): + if scale_steps < 0: + scale_steps = None + sampler = ldm_patched.modules.samplers.KSAMPLER(sample_lcm_upscale, extra_options={"total_upscale": scale_ratio, "upscale_steps": scale_steps, "upscale_method": upscale_method}) + return (sampler, ) + +NODE_CLASS_MAPPINGS = { + "SamplerLCMUpscale": SamplerLCMUpscale, +} diff --git a/ldm_patched/contrib/external_align_your_steps.py b/ldm_patched/contrib/external_align_your_steps.py index 624bbce2a..3ffe53187 100644 --- a/ldm_patched/contrib/external_align_your_steps.py +++ b/ldm_patched/contrib/external_align_your_steps.py @@ -1,5 +1,3 @@ -# https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py - #from: https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/howto.html import numpy as np import torch @@ -52,4 +50,4 @@ def get_sigmas(self, model_type, steps, denoise): NODE_CLASS_MAPPINGS = { "AlignYourStepsScheduler": AlignYourStepsScheduler, -} \ No newline at end of file +} diff --git a/ldm_patched/contrib/external_attention_multiply.py b/ldm_patched/contrib/external_attention_multiply.py new 
file mode 100644 index 000000000..4747eb395 --- /dev/null +++ b/ldm_patched/contrib/external_attention_multiply.py @@ -0,0 +1,120 @@ + +def attention_multiply(attn, model, q, k, v, out): + m = model.clone() + sd = model.model_state_dict() + + for key in sd: + if key.endswith("{}.to_q.bias".format(attn)) or key.endswith("{}.to_q.weight".format(attn)): + m.add_patches({key: (None,)}, 0.0, q) + if key.endswith("{}.to_k.bias".format(attn)) or key.endswith("{}.to_k.weight".format(attn)): + m.add_patches({key: (None,)}, 0.0, k) + if key.endswith("{}.to_v.bias".format(attn)) or key.endswith("{}.to_v.weight".format(attn)): + m.add_patches({key: (None,)}, 0.0, v) + if key.endswith("{}.to_out.0.bias".format(attn)) or key.endswith("{}.to_out.0.weight".format(attn)): + m.add_patches({key: (None,)}, 0.0, out) + + return m + + +class UNetSelfAttentionMultiply: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "q": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "k": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "v": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "_for_testing/attention_experiments" + + def patch(self, model, q, k, v, out): + m = attention_multiply("attn1", model, q, k, v, out) + return (m, ) + +class UNetCrossAttentionMultiply: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "q": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "k": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "v": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "_for_testing/attention_experiments" + + def patch(self, model, q, k, v, out): + m = attention_multiply("attn2", model, q, k, v, out) + return (m, ) + +class CLIPAttentionMultiply: + @classmethod + def INPUT_TYPES(s): + return {"required": { "clip": ("CLIP",), + "q": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "k": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "v": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + }} + RETURN_TYPES = ("CLIP",) + FUNCTION = "patch" + + CATEGORY = "_for_testing/attention_experiments" + + def patch(self, clip, q, k, v, out): + m = clip.clone() + sd = m.patcher.model_state_dict() + + for key in sd: + if key.endswith("self_attn.q_proj.weight") or key.endswith("self_attn.q_proj.bias"): + m.add_patches({key: (None,)}, 0.0, q) + if key.endswith("self_attn.k_proj.weight") or key.endswith("self_attn.k_proj.bias"): + m.add_patches({key: (None,)}, 0.0, k) + if key.endswith("self_attn.v_proj.weight") or key.endswith("self_attn.v_proj.bias"): + m.add_patches({key: (None,)}, 0.0, v) + if key.endswith("self_attn.out_proj.weight") or key.endswith("self_attn.out_proj.bias"): + m.add_patches({key: (None,)}, 0.0, out) + return (m, ) + +class UNetTemporalAttentionMultiply: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "self_structural": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "self_temporal": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 
0.01}), + "cross_structural": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "cross_temporal": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "_for_testing/attention_experiments" + + def patch(self, model, self_structural, self_temporal, cross_structural, cross_temporal): + m = model.clone() + sd = model.model_state_dict() + + for k in sd: + if (k.endswith("attn1.to_out.0.bias") or k.endswith("attn1.to_out.0.weight")): + if '.time_stack.' in k: + m.add_patches({k: (None,)}, 0.0, self_temporal) + else: + m.add_patches({k: (None,)}, 0.0, self_structural) + elif (k.endswith("attn2.to_out.0.bias") or k.endswith("attn2.to_out.0.weight")): + if '.time_stack.' in k: + m.add_patches({k: (None,)}, 0.0, cross_temporal) + else: + m.add_patches({k: (None,)}, 0.0, cross_structural) + return (m, ) + +NODE_CLASS_MAPPINGS = { + "UNetSelfAttentionMultiply": UNetSelfAttentionMultiply, + "UNetCrossAttentionMultiply": UNetCrossAttentionMultiply, + "CLIPAttentionMultiply": CLIPAttentionMultiply, + "UNetTemporalAttentionMultiply": UNetTemporalAttentionMultiply, +} diff --git a/ldm_patched/contrib/external_canny.py b/ldm_patched/contrib/external_canny.py index 7347ba1ed..c4414e43c 100644 --- a/ldm_patched/contrib/external_canny.py +++ b/ldm_patched/contrib/external_canny.py @@ -1,282 +1,6 @@ -# https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py - -#From https://github.com/kornia/kornia -import math - -import torch -import torch.nn.functional as F +from kornia.filters import canny import ldm_patched.modules.model_management -def get_canny_nms_kernel(device=None, dtype=None): - """Utility function that returns 3x3 kernels for the Canny Non-maximal suppression.""" - return torch.tensor( - [ - [[[0.0, 0.0, 0.0], [0.0, 1.0, -1.0], [0.0, 0.0, 0.0]]], - [[[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, -1.0]]], - [[[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, -1.0, 0.0]]], - [[[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [-1.0, 0.0, 0.0]]], - [[[0.0, 0.0, 0.0], [-1.0, 1.0, 0.0], [0.0, 0.0, 0.0]]], - [[[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]]], - [[[0.0, -1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]]], - [[[0.0, 0.0, -1.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]]], - ], - device=device, - dtype=dtype, - ) - - -def get_hysteresis_kernel(device=None, dtype=None): - """Utility function that returns the 3x3 kernels for the Canny hysteresis.""" - return torch.tensor( - [ - [[[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 0.0]]], - [[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]], - [[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.0, 0.0]]], - [[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]], - [[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 0.0]]], - [[[1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]], - [[[0.0, 1.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]], - [[[0.0, 0.0, 1.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]], - ], - device=device, - dtype=dtype, - ) - -def gaussian_blur_2d(img, kernel_size, sigma): - ksize_half = (kernel_size - 1) * 0.5 - - x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size) - - pdf = torch.exp(-0.5 * (x / sigma).pow(2)) - - x_kernel = pdf / pdf.sum() - x_kernel = x_kernel.to(device=img.device, dtype=img.dtype) - - kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :]) - kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1]) - - padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2] - - img = torch.nn.functional.pad(img, 
padding, mode="reflect") - img = torch.nn.functional.conv2d(img, kernel2d, groups=img.shape[-3]) - - return img - -def get_sobel_kernel2d(device=None, dtype=None): - kernel_x = torch.tensor([[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]], device=device, dtype=dtype) - kernel_y = kernel_x.transpose(0, 1) - return torch.stack([kernel_x, kernel_y]) - -def spatial_gradient(input, normalized: bool = True): - r"""Compute the first order image derivative in both x and y using a Sobel operator. - .. image:: _static/img/spatial_gradient.png - Args: - input: input image tensor with shape :math:`(B, C, H, W)`. - mode: derivatives modality, can be: `sobel` or `diff`. - order: the order of the derivatives. - normalized: whether the output is normalized. - Return: - the derivatives of the input feature map. with shape :math:`(B, C, 2, H, W)`. - .. note:: - See a working example `here `__. - Examples: - >>> input = torch.rand(1, 3, 4, 4) - >>> output = spatial_gradient(input) # 1x3x2x4x4 - >>> output.shape - torch.Size([1, 3, 2, 4, 4]) - """ - # KORNIA_CHECK_IS_TENSOR(input) - # KORNIA_CHECK_SHAPE(input, ['B', 'C', 'H', 'W']) - - # allocate kernel - kernel = get_sobel_kernel2d(device=input.device, dtype=input.dtype) - if normalized: - kernel = normalize_kernel2d(kernel) - - # prepare kernel - b, c, h, w = input.shape - tmp_kernel = kernel[:, None, ...] - - # Pad with "replicate for spatial dims, but with zeros for channel - spatial_pad = [kernel.size(1) // 2, kernel.size(1) // 2, kernel.size(2) // 2, kernel.size(2) // 2] - out_channels: int = 2 - padded_inp = torch.nn.functional.pad(input.reshape(b * c, 1, h, w), spatial_pad, 'replicate') - out = F.conv2d(padded_inp, tmp_kernel, groups=1, padding=0, stride=1) - return out.reshape(b, c, out_channels, h, w) - -def rgb_to_grayscale(image, rgb_weights = None): - r"""Convert a RGB image to grayscale version of image. - - .. image:: _static/img/rgb_to_grayscale.png - - The image data is assumed to be in the range of (0, 1). - - Args: - image: RGB image to be converted to grayscale with shape :math:`(*,3,H,W)`. - rgb_weights: Weights that will be applied on each channel (RGB). - The sum of the weights should add up to one. - Returns: - grayscale version of the image with shape :math:`(*,1,H,W)`. - - .. note:: - See a working example `here `__. - - Example: - >>> input = torch.rand(2, 3, 4, 5) - >>> gray = rgb_to_grayscale(input) # 2x1x4x5 - """ - - if len(image.shape) < 3 or image.shape[-3] != 3: - raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}") - - if rgb_weights is None: - # 8 bit images - if image.dtype == torch.uint8: - rgb_weights = torch.tensor([76, 150, 29], device=image.device, dtype=torch.uint8) - # floating point images - elif image.dtype in (torch.float16, torch.float32, torch.float64): - rgb_weights = torch.tensor([0.299, 0.587, 0.114], device=image.device, dtype=image.dtype) - else: - raise TypeError(f"Unknown data type: {image.dtype}") - else: - # is tensor that we make sure is in the same device/dtype - rgb_weights = rgb_weights.to(image) - - # unpack the color image channels with RGB order - r: Tensor = image[..., 0:1, :, :] - g: Tensor = image[..., 1:2, :, :] - b: Tensor = image[..., 2:3, :, :] - - w_r, w_g, w_b = rgb_weights.unbind() - return w_r * r + w_g * g + w_b * b - -def canny( - input, - low_threshold = 0.1, - high_threshold = 0.2, - kernel_size = 5, - sigma = 1, - hysteresis = True, - eps = 1e-6, -): - r"""Find edges of the input image and filters them using the Canny algorithm. - .. 
image:: _static/img/canny.png - Args: - input: input image tensor with shape :math:`(B,C,H,W)`. - low_threshold: lower threshold for the hysteresis procedure. - high_threshold: upper threshold for the hysteresis procedure. - kernel_size: the size of the kernel for the gaussian blur. - sigma: the standard deviation of the kernel for the gaussian blur. - hysteresis: if True, applies the hysteresis edge tracking. - Otherwise, the edges are divided between weak (0.5) and strong (1) edges. - eps: regularization number to avoid NaN during backprop. - Returns: - - the canny edge magnitudes map, shape of :math:`(B,1,H,W)`. - - the canny edge detection filtered by thresholds and hysteresis, shape of :math:`(B,1,H,W)`. - .. note:: - See a working example `here `__. - Example: - >>> input = torch.rand(5, 3, 4, 4) - >>> magnitude, edges = canny(input) # 5x3x4x4 - >>> magnitude.shape - torch.Size([5, 1, 4, 4]) - >>> edges.shape - torch.Size([5, 1, 4, 4]) - """ - # KORNIA_CHECK_IS_TENSOR(input) - # KORNIA_CHECK_SHAPE(input, ['B', 'C', 'H', 'W']) - # KORNIA_CHECK( - # low_threshold <= high_threshold, - # "Invalid input thresholds. low_threshold should be smaller than the high_threshold. Got: " - # f"{low_threshold}>{high_threshold}", - # ) - # KORNIA_CHECK(0 < low_threshold < 1, f'Invalid low threshold. Should be in range (0, 1). Got: {low_threshold}') - # KORNIA_CHECK(0 < high_threshold < 1, f'Invalid high threshold. Should be in range (0, 1). Got: {high_threshold}') - - device = input.device - dtype = input.dtype - - # To Grayscale - if input.shape[1] == 3: - input = rgb_to_grayscale(input) - - # Gaussian filter - blurred: Tensor = gaussian_blur_2d(input, kernel_size, sigma) - - # Compute the gradients - gradients: Tensor = spatial_gradient(blurred, normalized=False) - - # Unpack the edges - gx: Tensor = gradients[:, :, 0] - gy: Tensor = gradients[:, :, 1] - - # Compute gradient magnitude and angle - magnitude: Tensor = torch.sqrt(gx * gx + gy * gy + eps) - angle: Tensor = torch.atan2(gy, gx) - - # Radians to Degrees - angle = 180.0 * angle / math.pi - - # Round angle to the nearest 45 degree - angle = torch.round(angle / 45) * 45 - - # Non-maximal suppression - nms_kernels: Tensor = get_canny_nms_kernel(device, dtype) - nms_magnitude: Tensor = F.conv2d(magnitude, nms_kernels, padding=nms_kernels.shape[-1] // 2) - - # Get the indices for both directions - positive_idx: Tensor = (angle / 45) % 8 - positive_idx = positive_idx.long() - - negative_idx: Tensor = ((angle / 45) + 4) % 8 - negative_idx = negative_idx.long() - - # Apply the non-maximum suppression to the different directions - channel_select_filtered_positive: Tensor = torch.gather(nms_magnitude, 1, positive_idx) - channel_select_filtered_negative: Tensor = torch.gather(nms_magnitude, 1, negative_idx) - - channel_select_filtered: Tensor = torch.stack( - [channel_select_filtered_positive, channel_select_filtered_negative], 1 - ) - - is_max: Tensor = channel_select_filtered.min(dim=1)[0] > 0.0 - - magnitude = magnitude * is_max - - # Threshold - edges: Tensor = F.threshold(magnitude, low_threshold, 0.0) - - low: Tensor = magnitude > low_threshold - high: Tensor = magnitude > high_threshold - - edges = low * 0.5 + high * 0.5 - edges = edges.to(dtype) - - # Hysteresis - if hysteresis: - edges_old: Tensor = -torch.ones(edges.shape, device=edges.device, dtype=dtype) - hysteresis_kernels: Tensor = get_hysteresis_kernel(device, dtype) - - while ((edges_old - edges).abs() != 0).any(): - weak: Tensor = (edges == 0.5).float() - strong: Tensor = (edges == 
1).float() - - hysteresis_magnitude: Tensor = F.conv2d( - edges, hysteresis_kernels, padding=hysteresis_kernels.shape[-1] // 2 - ) - hysteresis_magnitude = (hysteresis_magnitude == 1).any(1, keepdim=True).to(dtype) - hysteresis_magnitude = hysteresis_magnitude * weak + strong - - edges_old = edges.clone() - edges = hysteresis_magnitude + (hysteresis_magnitude == 0) * weak * 0.5 - - edges = hysteresis_magnitude - - return magnitude, edges - class Canny: @classmethod diff --git a/ldm_patched/contrib/external_clip_sdxl.py b/ldm_patched/contrib/external_clip_sdxl.py index 230321a87..cc1fe274a 100644 --- a/ldm_patched/contrib/external_clip_sdxl.py +++ b/ldm_patched/contrib/external_clip_sdxl.py @@ -1,7 +1,5 @@ -# https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py - import torch -from ldm_patched.contrib.external import MAX_RESOLUTION +from external import MAX_RESOLUTION class CLIPTextEncodeSDXLRefiner: @classmethod @@ -10,7 +8,7 @@ def INPUT_TYPES(s): "ascore": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 1000.0, "step": 0.01}), "width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}), "height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}), - "text": ("STRING", {"multiline": True}), "clip": ("CLIP", ), + "text": ("STRING", {"multiline": True, "dynamicPrompts": True}), "clip": ("CLIP", ), }} RETURN_TYPES = ("CONDITIONING",) FUNCTION = "encode" @@ -32,8 +30,8 @@ def INPUT_TYPES(s): "crop_h": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION}), "target_width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}), "target_height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}), - "text_g": ("STRING", {"multiline": True, "default": "CLIP_G"}), "clip": ("CLIP", ), - "text_l": ("STRING", {"multiline": True, "default": "CLIP_L"}), "clip": ("CLIP", ), + "text_g": ("STRING", {"multiline": True, "dynamicPrompts": True}), "clip": ("CLIP", ), + "text_l": ("STRING", {"multiline": True, "dynamicPrompts": True}), "clip": ("CLIP", ), }} RETURN_TYPES = ("CONDITIONING",) FUNCTION = "encode" diff --git a/ldm_patched/contrib/external_compositing.py b/ldm_patched/contrib/external_compositing.py index 0cf91d9a7..bd9bffd19 100644 --- a/ldm_patched/contrib/external_compositing.py +++ b/ldm_patched/contrib/external_compositing.py @@ -1,5 +1,3 @@ -# https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py - import numpy as np import torch import ldm_patched.modules.utils diff --git a/ldm_patched/contrib/external_cond.py b/ldm_patched/contrib/external_cond.py new file mode 100644 index 000000000..4c3a1d5bf --- /dev/null +++ b/ldm_patched/contrib/external_cond.py @@ -0,0 +1,25 @@ + + +class CLIPTextEncodeControlnet: + @classmethod + def INPUT_TYPES(s): + return {"required": {"clip": ("CLIP", ), "conditioning": ("CONDITIONING", ), "text": ("STRING", {"multiline": True, "dynamicPrompts": True})}} + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "encode" + + CATEGORY = "_for_testing/conditioning" + + def encode(self, clip, conditioning, text): + tokens = clip.tokenize(text) + cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True) + c = [] + for t in conditioning: + n = [t[0], t[1].copy()] + n[1]['cross_attn_controlnet'] = cond + n[1]['pooled_output_controlnet'] = pooled + c.append(n) + return (c, ) + +NODE_CLASS_MAPPINGS = { + "CLIPTextEncodeControlnet": CLIPTextEncodeControlnet +} diff --git a/ldm_patched/contrib/external_custom_sampler.py b/ldm_patched/contrib/external_custom_sampler.py index 60d5e3bd2..1a50893b8 100644 --- 
a/ldm_patched/contrib/external_custom_sampler.py +++ b/ldm_patched/contrib/external_custom_sampler.py @@ -1,11 +1,10 @@ -# https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py - import ldm_patched.modules.samplers import ldm_patched.modules.sample from ldm_patched.k_diffusion import sampling as k_diffusion_sampling -import ldm_patched.utils.latent_visualization +import ldm_patched.utils.latent_visualization as latent_preview import torch import ldm_patched.modules.utils +import ldm_patched.utils.node_helpers as node_helpers class BasicScheduler: @@ -26,10 +25,11 @@ def INPUT_TYPES(s): def get_sigmas(self, model, scheduler, steps, denoise): total_steps = steps if denoise < 1.0: + if denoise <= 0.0: + return (torch.FloatTensor([]),) total_steps = int(steps/denoise) - ldm_patched.modules.model_management.load_models_gpu([model]) - sigmas = ldm_patched.modules.samplers.calculate_sigmas_scheduler(model.model, scheduler, total_steps).cpu() + sigmas = ldm_patched.modules.samplers.calculate_sigmas(model.get_model_object("model_sampling"), scheduler, total_steps).cpu() sigmas = sigmas[-(steps + 1):] return (sigmas, ) @@ -39,8 +39,8 @@ class KarrasScheduler: def INPUT_TYPES(s): return {"required": {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}), - "sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 1000.0, "step":0.01, "round": False}), - "sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 1000.0, "step":0.01, "round": False}), + "sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), + "sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), "rho": ("FLOAT", {"default": 7.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), } } @@ -58,8 +58,8 @@ class ExponentialScheduler: def INPUT_TYPES(s): return {"required": {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}), - "sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 1000.0, "step":0.01, "round": False}), - "sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 1000.0, "step":0.01, "round": False}), + "sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), + "sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), } } RETURN_TYPES = ("SIGMAS",) @@ -76,8 +76,8 @@ class PolyexponentialScheduler: def INPUT_TYPES(s): return {"required": {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}), - "sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 1000.0, "step":0.01, "round": False}), - "sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 1000.0, "step":0.01, "round": False}), + "sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), + "sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), "rho": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), } } @@ -107,7 +107,7 @@ def INPUT_TYPES(s): def get_sigmas(self, model, steps, denoise): start_step = 10 - int(10 * denoise) timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[start_step:start_step + steps] - sigmas = model.model_sampling.sigma(timesteps) + sigmas = model.get_model_object("model_sampling").sigma(timesteps) sigmas = torch.cat([sigmas, sigmas.new_zeros([1])]) return (sigmas, ) @@ -116,8 +116,8 @@ class VPScheduler: def INPUT_TYPES(s): return {"required": {"steps": 
("INT", {"default": 20, "min": 1, "max": 10000}), - "beta_d": ("FLOAT", {"default": 19.9, "min": 0.0, "max": 1000.0, "step":0.01, "round": False}), #TODO: fix default values - "beta_min": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 1000.0, "step":0.01, "round": False}), + "beta_d": ("FLOAT", {"default": 19.9, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), #TODO: fix default values + "beta_min": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), "eps_s": ("FLOAT", {"default": 0.001, "min": 0.0, "max": 1.0, "step":0.0001, "round": False}), } } @@ -139,6 +139,7 @@ def INPUT_TYPES(s): } } RETURN_TYPES = ("SIGMAS","SIGMAS") + RETURN_NAMES = ("high_sigmas", "low_sigmas") CATEGORY = "sampling/custom_sampling/sigmas" FUNCTION = "get_sigmas" @@ -148,6 +149,27 @@ def get_sigmas(self, sigmas, step): sigmas2 = sigmas[step:] return (sigmas1, sigmas2) +class SplitSigmasDenoise: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"sigmas": ("SIGMAS", ), + "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + } + } + RETURN_TYPES = ("SIGMAS","SIGMAS") + RETURN_NAMES = ("high_sigmas", "low_sigmas") + CATEGORY = "sampling/custom_sampling/sigmas" + + FUNCTION = "get_sigmas" + + def get_sigmas(self, sigmas, denoise): + steps = max(sigmas.shape[-1] - 1, 0) + total_steps = round(steps * denoise) + sigmas1 = sigmas[:-(total_steps)] + sigmas2 = sigmas[-(total_steps + 1):] + return (sigmas1, sigmas2) + class FlipSigmas: @classmethod def INPUT_TYPES(s): @@ -161,6 +183,9 @@ def INPUT_TYPES(s): FUNCTION = "get_sigmas" def get_sigmas(self, sigmas): + if len(sigmas) == 0: + return (sigmas,) + sigmas = sigmas.flip(0) if sigmas[0] == 0: sigmas[0] = 0.0001 @@ -182,6 +207,28 @@ def get_sampler(self, sampler_name): sampler = ldm_patched.modules.samplers.sampler_object(sampler_name) return (sampler, ) +class SamplerDPMPP_3M_SDE: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "noise_device": (['gpu', 'cpu'], ), + } + } + RETURN_TYPES = ("SAMPLER",) + CATEGORY = "sampling/custom_sampling/samplers" + + FUNCTION = "get_sampler" + + def get_sampler(self, eta, s_noise, noise_device): + if noise_device == 'cpu': + sampler_name = "dpmpp_3m_sde" + else: + sampler_name = "dpmpp_3m_sde_gpu" + sampler = ldm_patched.modules.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise}) + return (sampler, ) + class SamplerDPMPP_2M_SDE: @classmethod def INPUT_TYPES(s): @@ -229,24 +276,83 @@ def get_sampler(self, eta, s_noise, r, noise_device): sampler = ldm_patched.modules.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r}) return (sampler, ) +class SamplerEulerAncestral: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + } + } + RETURN_TYPES = ("SAMPLER",) + CATEGORY = "sampling/custom_sampling/samplers" + + FUNCTION = "get_sampler" + + def get_sampler(self, eta, s_noise): + sampler = ldm_patched.modules.samplers.ksampler("euler_ancestral", {"eta": eta, "s_noise": s_noise}) + return (sampler, ) -class SamplerTCD: +class SamplerLMS: @classmethod def INPUT_TYPES(s): - return { - "required": { - "eta": ("FLOAT", {"default": 0.3, "min": 
0.0, "max": 1.0, "step": 0.01}), - } - } + return {"required": + {"order": ("INT", {"default": 4, "min": 1, "max": 100}), + } + } + RETURN_TYPES = ("SAMPLER",) + CATEGORY = "sampling/custom_sampling/samplers" + + FUNCTION = "get_sampler" + + def get_sampler(self, order): + sampler = ldm_patched.modules.samplers.ksampler("lms", {"order": order}) + return (sampler, ) + +class SamplerDPMAdaptative: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"order": ("INT", {"default": 3, "min": 2, "max": 3}), + "rtol": ("FLOAT", {"default": 0.05, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "atol": ("FLOAT", {"default": 0.0078, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "h_init": ("FLOAT", {"default": 0.05, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "pcoeff": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "icoeff": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "dcoeff": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "accept_safety": ("FLOAT", {"default": 0.81, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "eta": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + } + } RETURN_TYPES = ("SAMPLER",) CATEGORY = "sampling/custom_sampling/samplers" FUNCTION = "get_sampler" - def get_sampler(self, eta=0.3): - sampler = ldm_patched.modules.samplers.ksampler("tcd", {"eta": eta}) + def get_sampler(self, order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise): + sampler = ldm_patched.modules.samplers.ksampler("dpm_adaptive", {"order": order, "rtol": rtol, "atol": atol, "h_init": h_init, "pcoeff": pcoeff, + "icoeff": icoeff, "dcoeff": dcoeff, "accept_safety": accept_safety, "eta": eta, + "s_noise":s_noise }) return (sampler, ) +class Noise_EmptyNoise: + def __init__(self): + self.seed = 0 + + def generate_noise(self, input_latent): + latent_image = input_latent["samples"] + return torch.zeros(latent_image.shape, dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") + + +class Noise_RandomNoise: + def __init__(self, seed): + self.seed = seed + + def generate_noise(self, input_latent): + latent_image = input_latent["samples"] + batch_inds = input_latent["batch_index"] if "batch_index" in input_latent else None + return ldm_patched.modules.sample.prepare_noise(latent_image, self.seed, batch_inds) class SamplerCustom: @classmethod @@ -275,17 +381,16 @@ def sample(self, model, add_noise, noise_seed, cfg, positive, negative, sampler, latent = latent_image latent_image = latent["samples"] if not add_noise: - noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") + noise = Noise_EmptyNoise().generate_noise(latent) else: - batch_inds = latent["batch_index"] if "batch_index" in latent else None - noise = ldm_patched.modules.sample.prepare_noise(latent_image, noise_seed, batch_inds) + noise = Noise_RandomNoise(noise_seed).generate_noise(latent) noise_mask = None if "noise_mask" in latent: noise_mask = latent["noise_mask"] x0_output = {} - callback = ldm_patched.utils.latent_visualization.prepare_callback(model, sigmas.shape[-1] - 1, x0_output) + callback = latent_preview.prepare_callback(model, sigmas.shape[-1] - 1, x0_output) disable_pbar = not ldm_patched.modules.utils.PROGRESS_BAR_ENABLED samples = 
ldm_patched.modules.sample.sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed) @@ -299,6 +404,207 @@ def sample(self, model, add_noise, noise_seed, cfg, positive, negative, sampler, out_denoised = out return (out, out_denoised) +class Guider_Basic(ldm_patched.modules.samplers.CFGGuider): + def set_conds(self, positive): + self.inner_set_conds({"positive": positive}) + +class BasicGuider: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"model": ("MODEL",), + "conditioning": ("CONDITIONING", ), + } + } + + RETURN_TYPES = ("GUIDER",) + + FUNCTION = "get_guider" + CATEGORY = "sampling/custom_sampling/guiders" + + def get_guider(self, model, conditioning): + guider = Guider_Basic(model) + guider.set_conds(conditioning) + return (guider,) + +class CFGGuider: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"model": ("MODEL",), + "positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}), + } + } + + RETURN_TYPES = ("GUIDER",) + + FUNCTION = "get_guider" + CATEGORY = "sampling/custom_sampling/guiders" + + def get_guider(self, model, positive, negative, cfg): + guider = ldm_patched.modules.samplers.CFGGuider(model) + guider.set_conds(positive, negative) + guider.set_cfg(cfg) + return (guider,) + +class Guider_DualCFG(ldm_patched.modules.samplers.CFGGuider): + def set_cfg(self, cfg1, cfg2): + self.cfg1 = cfg1 + self.cfg2 = cfg2 + + def set_conds(self, positive, middle, negative): + middle = node_helpers.conditioning_set_values(middle, {"prompt_type": "negative"}) + self.inner_set_conds({"positive": positive, "middle": middle, "negative": negative}) + + def predict_noise(self, x, timestep, model_options={}, seed=None): + negative_cond = self.conds.get("negative", None) + middle_cond = self.conds.get("middle", None) + + out = ldm_patched.modules.samplers.calc_cond_batch(self.inner_model, [negative_cond, middle_cond, self.conds.get("positive", None)], x, timestep, model_options) + return ldm_patched.modules.samplers.cfg_function(self.inner_model, out[1], out[0], self.cfg2, x, timestep, model_options=model_options, cond=middle_cond, uncond=negative_cond) + (out[2] - out[1]) * self.cfg1 + +class DualCFGGuider: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"model": ("MODEL",), + "cond1": ("CONDITIONING", ), + "cond2": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "cfg_conds": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}), + "cfg_cond2_negative": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}), + } + } + + RETURN_TYPES = ("GUIDER",) + + FUNCTION = "get_guider" + CATEGORY = "sampling/custom_sampling/guiders" + + def get_guider(self, model, cond1, cond2, negative, cfg_conds, cfg_cond2_negative): + guider = Guider_DualCFG(model) + guider.set_conds(cond1, cond2, negative) + guider.set_cfg(cfg_conds, cfg_cond2_negative) + return (guider,) + +class DisableNoise: + @classmethod + def INPUT_TYPES(s): + return {"required":{ + } + } + + RETURN_TYPES = ("NOISE",) + FUNCTION = "get_noise" + CATEGORY = "sampling/custom_sampling/noise" + + def get_noise(self): + return (Noise_EmptyNoise(),) + + +class RandomNoise(DisableNoise): + @classmethod + def INPUT_TYPES(s): + return {"required":{ + "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), + } + } + + def 
get_noise(self, noise_seed): + return (Noise_RandomNoise(noise_seed),) + + +class SamplerCustomAdvanced: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"noise": ("NOISE", ), + "guider": ("GUIDER", ), + "sampler": ("SAMPLER", ), + "sigmas": ("SIGMAS", ), + "latent_image": ("LATENT", ), + } + } + + RETURN_TYPES = ("LATENT","LATENT") + RETURN_NAMES = ("output", "denoised_output") + + FUNCTION = "sample" + + CATEGORY = "sampling/custom_sampling" + + def sample(self, noise, guider, sampler, sigmas, latent_image): + latent = latent_image + latent_image = latent["samples"] + + noise_mask = None + if "noise_mask" in latent: + noise_mask = latent["noise_mask"] + + x0_output = {} + callback = latent_preview.prepare_callback(guider.model_patcher, sigmas.shape[-1] - 1, x0_output) + + disable_pbar = not ldm_patched.modules.utils.PROGRESS_BAR_ENABLED + samples = guider.sample(noise.generate_noise(latent), latent_image, sampler, sigmas, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise.seed) + samples = samples.to(ldm_patched.modules.model_management.intermediate_device()) + + out = latent.copy() + out["samples"] = samples + if "x0" in x0_output: + out_denoised = latent.copy() + out_denoised["samples"] = guider.model_patcher.model.process_latent_out(x0_output["x0"].cpu()) + else: + out_denoised = out + return (out, out_denoised) + +class AddNoise: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"model": ("MODEL",), + "noise": ("NOISE", ), + "sigmas": ("SIGMAS", ), + "latent_image": ("LATENT", ), + } + } + + RETURN_TYPES = ("LATENT",) + + FUNCTION = "add_noise" + + CATEGORY = "_for_testing/custom_sampling/noise" + + def add_noise(self, model, noise, sigmas, latent_image): + if len(sigmas) == 0: + return (latent_image,) + + latent = latent_image + latent_image = latent["samples"] + + noisy = noise.generate_noise(latent) + + model_sampling = model.get_model_object("model_sampling") + process_latent_out = model.get_model_object("process_latent_out") + process_latent_in = model.get_model_object("process_latent_in") + + if len(sigmas) > 1: + scale = torch.abs(sigmas[0] - sigmas[-1]) + else: + scale = sigmas[0] + + if torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image.
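+ # process_latent_in applies the model's latent shift/scale normalization; an all-zero ("empty") latent skips it so that offset is not baked into pure noise. noise_scaling then combines the noise and the latent at the strength set by the sigma span computed above (for EPS-style model sampling this is roughly noisy * scale + latent_image).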
+ latent_image = process_latent_in(latent_image) + noisy = model_sampling.noise_scaling(scale, noisy, latent_image) + noisy = process_latent_out(noisy) + noisy = torch.nan_to_num(noisy, nan=0.0, posinf=0.0, neginf=0.0) + + out = latent.copy() + out["samples"] = noisy + return (out,) + + NODE_CLASS_MAPPINGS = { "SamplerCustom": SamplerCustom, "BasicScheduler": BasicScheduler, @@ -308,9 +614,21 @@ def sample(self, model, add_noise, noise_seed, cfg, positive, negative, sampler, "VPScheduler": VPScheduler, "SDTurboScheduler": SDTurboScheduler, "KSamplerSelect": KSamplerSelect, + "SamplerEulerAncestral": SamplerEulerAncestral, + "SamplerLMS": SamplerLMS, + "SamplerDPMPP_3M_SDE": SamplerDPMPP_3M_SDE, "SamplerDPMPP_2M_SDE": SamplerDPMPP_2M_SDE, "SamplerDPMPP_SDE": SamplerDPMPP_SDE, - "SamplerTCD": SamplerTCD, + "SamplerDPMAdaptative": SamplerDPMAdaptative, "SplitSigmas": SplitSigmas, + "SplitSigmasDenoise": SplitSigmasDenoise, "FlipSigmas": FlipSigmas, + + "CFGGuider": CFGGuider, + "DualCFGGuider": DualCFGGuider, + "BasicGuider": BasicGuider, + "RandomNoise": RandomNoise, + "DisableNoise": DisableNoise, + "AddNoise": AddNoise, + "SamplerCustomAdvanced": SamplerCustomAdvanced, } diff --git a/ldm_patched/contrib/external_differential_diffusion.py b/ldm_patched/contrib/external_differential_diffusion.py new file mode 100644 index 000000000..98dbbf102 --- /dev/null +++ b/ldm_patched/contrib/external_differential_diffusion.py @@ -0,0 +1,42 @@ +# code adapted from https://github.com/exx8/differential-diffusion + +import torch + +class DifferentialDiffusion(): + @classmethod + def INPUT_TYPES(s): + return {"required": {"model": ("MODEL", ), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "apply" + CATEGORY = "_for_testing" + INIT = False + + def apply(self, model): + model = model.clone() + model.set_model_denoise_mask_function(self.forward) + return (model,) + + def forward(self, sigma: torch.Tensor, denoise_mask: torch.Tensor, extra_options: dict): + model = extra_options["model"] + step_sigmas = extra_options["sigmas"] + sigma_to = model.inner_model.model_sampling.sigma_min + if step_sigmas[-1] > sigma_to: + sigma_to = step_sigmas[-1] + sigma_from = step_sigmas[0] + + ts_from = model.inner_model.model_sampling.timestep(sigma_from) + ts_to = model.inner_model.model_sampling.timestep(sigma_to) + current_ts = model.inner_model.model_sampling.timestep(sigma[0]) + + threshold = (current_ts - ts_to) / (ts_from - ts_to) + + return (denoise_mask >= threshold).to(denoise_mask.dtype) + + +NODE_CLASS_MAPPINGS = { + "DifferentialDiffusion": DifferentialDiffusion, +} +NODE_DISPLAY_NAME_MAPPINGS = { + "DifferentialDiffusion": "Differential Diffusion", +} diff --git a/ldm_patched/contrib/external_freelunch.py b/ldm_patched/contrib/external_freelunch.py index 59ec5babd..c5ebcf26f 100644 --- a/ldm_patched/contrib/external_freelunch.py +++ b/ldm_patched/contrib/external_freelunch.py @@ -1,9 +1,7 @@ -# https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py - #code originally taken from: https://github.com/ChenyangSi/FreeU (under MIT License) import torch - +import logging def Fourier_filter(x, threshold, scale): # FFT @@ -44,14 +42,14 @@ def patch(self, model, b1, b2, s1, s2): on_cpu_devices = {} def output_block_patch(h, hsp, transformer_options): - scale = scale_dict.get(h.shape[1], None) + scale = scale_dict.get(int(h.shape[1]), None) if scale is not None: h[:,:h.shape[1] // 2] = h[:,:h.shape[1] // 2] * scale[0] if hsp.device not in on_cpu_devices: try: hsp = Fourier_filter(hsp, threshold=1, scale=scale[1]) 
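+ # Fourier_filter relies on torch.fft (FFT, scale the centered low-frequency box by scale[1], inverse FFT). Some backends lack torch.fft support, so the except branch below falls back to the CPU and remembers the device in on_cpu_devices.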
except: - print("Device", hsp.device, "does not support the torch.fft functions used in the FreeU node, switching to CPU.") + logging.warning("Device {} does not support the torch.fft functions used in the FreeU node, switching to CPU.".format(hsp.device)) on_cpu_devices[hsp.device] = True hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device) else: @@ -83,7 +81,7 @@ def patch(self, model, b1, b2, s1, s2): on_cpu_devices = {} def output_block_patch(h, hsp, transformer_options): - scale = scale_dict.get(h.shape[1], None) + scale = scale_dict.get(int(h.shape[1]), None) if scale is not None: hidden_mean = h.mean(1).unsqueeze(1) B = hidden_mean.shape[0] @@ -97,7 +95,7 @@ def output_block_patch(h, hsp, transformer_options): try: hsp = Fourier_filter(hsp, threshold=1, scale=scale[1]) except: - print("Device", hsp.device, "does not support the torch.fft functions used in the FreeU node, switching to CPU.") + logging.warning("Device {} does not support the torch.fft functions used in the FreeU node, switching to CPU.".format(hsp.device)) on_cpu_devices[hsp.device] = True hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device) else: diff --git a/ldm_patched/contrib/external_hypernetwork.py b/ldm_patched/contrib/external_hypernetwork.py index 17aaacb00..0f8d84916 100644 --- a/ldm_patched/contrib/external_hypernetwork.py +++ b/ldm_patched/contrib/external_hypernetwork.py @@ -1,8 +1,7 @@ -# https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py - import ldm_patched.modules.utils -import ldm_patched.utils.path_utils +import ldm_patched.utils.path_utils as folder_paths import torch +import logging def load_hypernetwork_patch(path, strength): sd = ldm_patched.modules.utils.load_torch_file(path, safe_load=True) @@ -25,7 +24,7 @@ def load_hypernetwork_patch(path, strength): } if activation_func not in valid_activation: - print("Unsupported Hypernetwork format, if you report it I might implement it.", path, " ", activation_func, is_layer_norm, use_dropout, activate_output, last_layer_dropout) + logging.error("Unsupported Hypernetwork format, if you report it I might implement it. 
{} {} {} {} {} {}".format(path, activation_func, is_layer_norm, use_dropout, activate_output, last_layer_dropout)) return None out = {} @@ -99,7 +98,7 @@ class HypernetworkLoader: @classmethod def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), - "hypernetwork_name": (ldm_patched.utils.path_utils.get_filename_list("hypernetworks"), ), + "hypernetwork_name": (folder_paths.get_filename_list("hypernetworks"), ), "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), }} RETURN_TYPES = ("MODEL",) @@ -108,7 +107,7 @@ def INPUT_TYPES(s): CATEGORY = "loaders" def load_hypernetwork(self, model, hypernetwork_name, strength): - hypernetwork_path = ldm_patched.utils.path_utils.get_full_path("hypernetworks", hypernetwork_name) + hypernetwork_path = folder_paths.get_full_path("hypernetworks", hypernetwork_name) model_hypernetwork = model.clone() patch = load_hypernetwork_patch(hypernetwork_path, strength) if patch is not None: diff --git a/ldm_patched/contrib/external_hypertile.py b/ldm_patched/contrib/external_hypertile.py index 5cf7d9d6d..ae55d23dd 100644 --- a/ldm_patched/contrib/external_hypertile.py +++ b/ldm_patched/contrib/external_hypertile.py @@ -1,5 +1,3 @@ -# https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py - #Taken from: https://github.com/tfernd/HyperTile/ import math diff --git a/ldm_patched/contrib/external_images.py b/ldm_patched/contrib/external_images.py index 17e9c4978..e60515222 100644 --- a/ldm_patched/contrib/external_images.py +++ b/ldm_patched/contrib/external_images.py @@ -1,7 +1,5 @@ -# https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py - -import ldm_patched.contrib.external -import ldm_patched.utils.path_utils +import external +import ldm_patched.utils.path_utils as folder_paths from ldm_patched.modules.args_parser import args from PIL import Image @@ -11,7 +9,7 @@ import json import os -MAX_RESOLUTION = ldm_patched.contrib.external.MAX_RESOLUTION +MAX_RESOLUTION = external.MAX_RESOLUTION class ImageCrop: @classmethod @@ -39,7 +37,7 @@ class RepeatImageBatch: @classmethod def INPUT_TYPES(s): return {"required": { "image": ("IMAGE",), - "amount": ("INT", {"default": 1, "min": 1, "max": 64}), + "amount": ("INT", {"default": 1, "min": 1, "max": 4096}), }} RETURN_TYPES = ("IMAGE",) FUNCTION = "repeat" @@ -50,9 +48,28 @@ def repeat(self, image, amount): s = image.repeat((amount, 1,1,1)) return (s,) +class ImageFromBatch: + @classmethod + def INPUT_TYPES(s): + return {"required": { "image": ("IMAGE",), + "batch_index": ("INT", {"default": 0, "min": 0, "max": 4095}), + "length": ("INT", {"default": 1, "min": 1, "max": 4096}), + }} + RETURN_TYPES = ("IMAGE",) + FUNCTION = "frombatch" + + CATEGORY = "image/batch" + + def frombatch(self, image, batch_index, length): + s_in = image + batch_index = min(s_in.shape[0] - 1, batch_index) + length = min(s_in.shape[0] - batch_index, length) + s = s_in[batch_index:batch_index + length].clone() + return (s,) + class SaveAnimatedWEBP: def __init__(self): - self.output_dir = ldm_patched.utils.path_utils.get_output_directory() + self.output_dir = folder_paths.get_output_directory() self.type = "output" self.prefix_append = "" @@ -61,7 +78,7 @@ def __init__(self): def INPUT_TYPES(s): return {"required": {"images": ("IMAGE", ), - "filename_prefix": ("STRING", {"default": "ldm_patched"}), + "filename_prefix": ("STRING", {"default": "ComfyUI"}), "fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}), "lossless": ("BOOLEAN", {"default": True}), "quality": ("INT", 
{"default": 80, "min": 0, "max": 100}), @@ -81,7 +98,7 @@ def INPUT_TYPES(s): def save_images(self, images, fps, filename_prefix, lossless, quality, method, num_frames=0, prompt=None, extra_pnginfo=None): method = self.methods.get(method) filename_prefix += self.prefix_append - full_output_folder, filename, counter, subfolder, filename_prefix = ldm_patched.utils.path_utils.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]) + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]) results = list() pil_images = [] for image in images: @@ -90,7 +107,7 @@ def save_images(self, images, fps, filename_prefix, lossless, quality, method, n pil_images.append(img) metadata = pil_images[0].getexif() - if not args.disable_server_info: + if not args.disable_metadata: if prompt is not None: metadata[0x0110] = "prompt:{}".format(json.dumps(prompt)) if extra_pnginfo is not None: @@ -118,7 +135,7 @@ def save_images(self, images, fps, filename_prefix, lossless, quality, method, n class SaveAnimatedPNG: def __init__(self): - self.output_dir = ldm_patched.utils.path_utils.get_output_directory() + self.output_dir = folder_paths.get_output_directory() self.type = "output" self.prefix_append = "" @@ -126,7 +143,7 @@ def __init__(self): def INPUT_TYPES(s): return {"required": {"images": ("IMAGE", ), - "filename_prefix": ("STRING", {"default": "ldm_patched"}), + "filename_prefix": ("STRING", {"default": "ComfyUI"}), "fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}), "compress_level": ("INT", {"default": 4, "min": 0, "max": 9}) }, @@ -140,9 +157,9 @@ def INPUT_TYPES(s): CATEGORY = "image/animation" - def save_images(self, images, fps, compress_level, filename_prefix="ldm_patched", prompt=None, extra_pnginfo=None): + def save_images(self, images, fps, compress_level, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): filename_prefix += self.prefix_append - full_output_folder, filename, counter, subfolder, filename_prefix = ldm_patched.utils.path_utils.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]) + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]) results = list() pil_images = [] for image in images: @@ -151,13 +168,13 @@ def save_images(self, images, fps, compress_level, filename_prefix="ldm_patched" pil_images.append(img) metadata = None - if not args.disable_server_info: + if not args.disable_metadata: metadata = PngInfo() if prompt is not None: - metadata.add(b"ldm_patched", "prompt".encode("latin-1", "strict") + b"\0" + json.dumps(prompt).encode("latin-1", "strict"), after_idat=True) + metadata.add(b"comf", "prompt".encode("latin-1", "strict") + b"\0" + json.dumps(prompt).encode("latin-1", "strict"), after_idat=True) if extra_pnginfo is not None: for x in extra_pnginfo: - metadata.add(b"ldm_patched", x.encode("latin-1", "strict") + b"\0" + json.dumps(extra_pnginfo[x]).encode("latin-1", "strict"), after_idat=True) + metadata.add(b"comf", x.encode("latin-1", "strict") + b"\0" + json.dumps(extra_pnginfo[x]).encode("latin-1", "strict"), after_idat=True) file = f"{filename}_{counter:05}_.png" pil_images[0].save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=compress_level, save_all=True, duration=int(1000.0/fps), 
append_images=pil_images[1:]) @@ -172,6 +189,7 @@ def save_images(self, images, fps, compress_level, filename_prefix="ldm_patched" NODE_CLASS_MAPPINGS = { "ImageCrop": ImageCrop, "RepeatImageBatch": RepeatImageBatch, + "ImageFromBatch": ImageFromBatch, "SaveAnimatedWEBP": SaveAnimatedWEBP, "SaveAnimatedPNG": SaveAnimatedPNG, } diff --git a/ldm_patched/contrib/external_ip2p.py b/ldm_patched/contrib/external_ip2p.py new file mode 100644 index 000000000..c2e70a84c --- /dev/null +++ b/ldm_patched/contrib/external_ip2p.py @@ -0,0 +1,45 @@ +import torch + +class InstructPixToPixConditioning: + @classmethod + def INPUT_TYPES(s): + return {"required": {"positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "vae": ("VAE", ), + "pixels": ("IMAGE", ), + }} + + RETURN_TYPES = ("CONDITIONING","CONDITIONING","LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + FUNCTION = "encode" + + CATEGORY = "conditioning/instructpix2pix" + + def encode(self, positive, negative, pixels, vae): + x = (pixels.shape[1] // 8) * 8 + y = (pixels.shape[2] // 8) * 8 + + if pixels.shape[1] != x or pixels.shape[2] != y: + x_offset = (pixels.shape[1] % 8) // 2 + y_offset = (pixels.shape[2] % 8) // 2 + pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:] + + concat_latent = vae.encode(pixels) + + out_latent = {} + out_latent["samples"] = torch.zeros_like(concat_latent) + + out = [] + for conditioning in [positive, negative]: + c = [] + for t in conditioning: + d = t[1].copy() + d["concat_latent_image"] = concat_latent + n = [t[0], d] + c.append(n) + out.append(c) + return (out[0], out[1], out_latent) + +NODE_CLASS_MAPPINGS = { + "InstructPixToPixConditioning": InstructPixToPixConditioning, +} diff --git a/ldm_patched/contrib/external_latent.py b/ldm_patched/contrib/external_latent.py index 6d753d0f7..b0afec15d 100644 --- a/ldm_patched/contrib/external_latent.py +++ b/ldm_patched/contrib/external_latent.py @@ -1,5 +1,3 @@ -# https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py - import ldm_patched.modules.utils import torch @@ -128,7 +126,7 @@ class LatentBatchSeedBehavior: @classmethod def INPUT_TYPES(s): return {"required": { "samples": ("LATENT",), - "seed_behavior": (["random", "fixed"],),}} + "seed_behavior": (["random", "fixed"],{"default": "fixed"}),}} RETURN_TYPES = ("LATENT",) FUNCTION = "op" diff --git a/ldm_patched/contrib/external_mask.py b/ldm_patched/contrib/external_mask.py index a86a7fe69..e81dddfd7 100644 --- a/ldm_patched/contrib/external_mask.py +++ b/ldm_patched/contrib/external_mask.py @@ -1,11 +1,9 @@ -# https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py - import numpy as np import scipy.ndimage import torch import ldm_patched.modules.utils -from ldm_patched.contrib.external import MAX_RESOLUTION +from external import MAX_RESOLUTION def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False): source = source.to(destination.device) @@ -343,6 +341,24 @@ def expand_mask(self, mask, expand, tapered_corners): out.append(output) return (torch.stack(out, dim=0),) +class ThresholdMask: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "mask": ("MASK",), + "value": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}), + } + } + + CATEGORY = "mask" + + RETURN_TYPES = ("MASK",) + FUNCTION = "image_to_mask" + + def image_to_mask(self, mask, value): + mask = (mask > value).float() + return (mask,) NODE_CLASS_MAPPINGS = { @@ -357,6 +373,7 @@ def expand_mask(self, mask, expand, tapered_corners): 
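
`InstructPixToPixConditioning`, added above, center-crops the input image so both spatial dimensions are multiples of 8 before VAE-encoding it into `concat_latent_image`; the sampled latent itself starts from zeros of the matching shape. The crop arithmetic in isolation, assuming ComfyUI's `B,H,W,C` image layout:

```python
import torch

def crop_to_multiple_of_8(pixels):
    x = (pixels.shape[1] // 8) * 8
    y = (pixels.shape[2] // 8) * 8
    if pixels.shape[1] != x or pixels.shape[2] != y:
        x_offset = (pixels.shape[1] % 8) // 2
        y_offset = (pixels.shape[2] % 8) // 2
        pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :]
    return pixels

print(crop_to_multiple_of_8(torch.zeros(1, 517, 771, 3)).shape)
# torch.Size([1, 512, 768, 3]) -- 2 rows trimmed at the top, 3 at the bottom
```
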
"MaskComposite": MaskComposite, "FeatherMask": FeatherMask, "GrowMask": GrowMask, + "ThresholdMask": ThresholdMask, } NODE_DISPLAY_NAME_MAPPINGS = { diff --git a/ldm_patched/contrib/external_model_advanced.py b/ldm_patched/contrib/external_model_advanced.py index b9f0ebdca..779710aa5 100644 --- a/ldm_patched/contrib/external_model_advanced.py +++ b/ldm_patched/contrib/external_model_advanced.py @@ -1,8 +1,7 @@ -# https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py - -import ldm_patched.utils.path_utils +import ldm_patched.utils.path_utils as folder_paths import ldm_patched.modules.sd import ldm_patched.modules.model_sampling +import ldm_patched.modules.latent_formats import torch class LCM(ldm_patched.modules.model_sampling.EPS): @@ -19,6 +18,10 @@ def calculate_denoised(self, sigma, model_output, model_input): return c_out * x0 + c_skip * model_input +class X0(ldm_patched.modules.model_sampling.EPS): + def calculate_denoised(self, sigma, model_output, model_input): + return model_output + class ModelSamplingDiscreteDistilled(ldm_patched.modules.model_sampling.ModelSamplingDiscrete): original_timesteps = 50 @@ -70,7 +73,7 @@ class ModelSamplingDiscrete: @classmethod def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), - "sampling": (["eps", "v_prediction", "lcm", "tcd"]), + "sampling": (["eps", "v_prediction", "lcm", "x0"],), "zsnr": ("BOOLEAN", {"default": False}), }} @@ -90,9 +93,8 @@ def patch(self, model, sampling, zsnr): elif sampling == "lcm": sampling_type = LCM sampling_base = ModelSamplingDiscreteDistilled - elif sampling == "tcd": - sampling_type = ldm_patched.modules.model_sampling.EPS - sampling_base = ModelSamplingDiscreteDistilled + elif sampling == "x0": + sampling_type = X0 class ModelSamplingAdvanced(sampling_base, sampling_type): pass @@ -104,6 +106,32 @@ class ModelSamplingAdvanced(sampling_base, sampling_type): m.add_object_patch("model_sampling", model_sampling) return (m, ) +class ModelSamplingStableCascade: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "shift": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 100.0, "step":0.01}), + }} + + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "advanced/model" + + def patch(self, model, shift): + m = model.clone() + + sampling_base = ldm_patched.modules.model_sampling.StableCascadeSampling + sampling_type = ldm_patched.modules.model_sampling.EPS + + class ModelSamplingAdvanced(sampling_base, sampling_type): + pass + + model_sampling = ModelSamplingAdvanced(model.model.model_config) + model_sampling.set_parameters(shift) + m.add_object_patch("model_sampling", model_sampling) + return (m, ) + class ModelSamplingContinuousEDM: @classmethod def INPUT_TYPES(s): @@ -184,5 +212,6 @@ def rescale_cfg(args): NODE_CLASS_MAPPINGS = { "ModelSamplingDiscrete": ModelSamplingDiscrete, "ModelSamplingContinuousEDM": ModelSamplingContinuousEDM, + "ModelSamplingStableCascade": ModelSamplingStableCascade, "RescaleCFG": RescaleCFG, } diff --git a/ldm_patched/contrib/external_model_downscale.py b/ldm_patched/contrib/external_model_downscale.py index 4f1da54de..f06b9ecd4 100644 --- a/ldm_patched/contrib/external_model_downscale.py +++ b/ldm_patched/contrib/external_model_downscale.py @@ -1,5 +1,3 @@ -# https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py - import torch import ldm_patched.modules.utils @@ -22,8 +20,9 @@ def INPUT_TYPES(s): CATEGORY = "_for_testing" def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, 
downscale_method, upscale_method): - sigma_start = model.model.model_sampling.percent_to_sigma(start_percent) - sigma_end = model.model.model_sampling.percent_to_sigma(end_percent) + model_sampling = model.get_model_object("model_sampling") + sigma_start = model_sampling.percent_to_sigma(start_percent) + sigma_end = model_sampling.percent_to_sigma(end_percent) def input_block_patch(h, transformer_options): if transformer_options["block"][1] == block_number: diff --git a/ldm_patched/contrib/external_model_merging.py b/ldm_patched/contrib/external_model_merging.py index ae8145d4f..6c11c5de2 100644 --- a/ldm_patched/contrib/external_model_merging.py +++ b/ldm_patched/contrib/external_model_merging.py @@ -1,11 +1,12 @@ -# https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py - import ldm_patched.modules.sd import ldm_patched.modules.utils import ldm_patched.modules.model_base import ldm_patched.modules.model_management +import ldm_patched.modules.model_sampling + +import torch +import ldm_patched.utils.path_utils as folder_paths -import ldm_patched.utils.path_utils import json import os @@ -89,6 +90,50 @@ def merge(self, clip1, clip2, ratio): m.add_patches({k: kp[k]}, 1.0 - ratio, ratio) return (m, ) + +class CLIPSubtract: + @classmethod + def INPUT_TYPES(s): + return {"required": { "clip1": ("CLIP",), + "clip2": ("CLIP",), + "multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), + }} + RETURN_TYPES = ("CLIP",) + FUNCTION = "merge" + + CATEGORY = "advanced/model_merging" + + def merge(self, clip1, clip2, multiplier): + m = clip1.clone() + kp = clip2.get_key_patches() + for k in kp: + if k.endswith(".position_ids") or k.endswith(".logit_scale"): + continue + m.add_patches({k: kp[k]}, - multiplier, multiplier) + return (m, ) + + +class CLIPAdd: + @classmethod + def INPUT_TYPES(s): + return {"required": { "clip1": ("CLIP",), + "clip2": ("CLIP",), + }} + RETURN_TYPES = ("CLIP",) + FUNCTION = "merge" + + CATEGORY = "advanced/model_merging" + + def merge(self, clip1, clip2): + m = clip1.clone() + kp = clip2.get_key_patches() + for k in kp: + if k.endswith(".position_ids") or k.endswith(".logit_scale"): + continue + m.add_patches({k: kp[k]}, 1.0, 1.0) + return (m, ) + + class ModelMergeBlocks: @classmethod def INPUT_TYPES(s): @@ -122,7 +167,7 @@ def merge(self, model1, model2, **kwargs): return (m, ) def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefix=None, output_dir=None, prompt=None, extra_pnginfo=None): - full_output_folder, filename, counter, subfolder, filename_prefix = ldm_patched.utils.path_utils.get_save_image_path(filename_prefix, output_dir) + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, output_dir) prompt_info = "" if prompt is not None: prompt_info = json.dumps(prompt) @@ -131,9 +176,14 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi enable_modelspec = True if isinstance(model.model, ldm_patched.modules.model_base.SDXL): - metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-base" + if isinstance(model.model, ldm_patched.modules.model_base.SDXL_instructpix2pix): + metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-edit" + else: + metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-base" elif isinstance(model.model, ldm_patched.modules.model_base.SDXLRefiner): metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-refiner" + elif isinstance(model.model, 
ldm_patched.modules.model_base.SVD_img2vid): + metadata["modelspec.architecture"] = "stable-video-diffusion-img2vid-v1" else: enable_modelspec = False @@ -147,12 +197,19 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi # "stable-diffusion-v2-768-v", "stable-diffusion-v2-unclip-l", "stable-diffusion-v2-unclip-h", # "v2-inpainting" + extra_keys = {} + model_sampling = model.get_model_object("model_sampling") + if isinstance(model_sampling, ldm_patched.modules.model_sampling.ModelSamplingContinuousEDM): + if isinstance(model_sampling, ldm_patched.modules.model_sampling.V_PREDICTION): + extra_keys["edm_vpred.sigma_max"] = torch.tensor(model_sampling.sigma_max).float() + extra_keys["edm_vpred.sigma_min"] = torch.tensor(model_sampling.sigma_min).float() + if model.model.model_type == ldm_patched.modules.model_base.ModelType.EPS: metadata["modelspec.predict_key"] = "epsilon" elif model.model.model_type == ldm_patched.modules.model_base.ModelType.V_PREDICTION: metadata["modelspec.predict_key"] = "v" - if not args.disable_server_info: + if not args.disable_metadata: metadata["prompt"] = prompt_info if extra_pnginfo is not None: for x in extra_pnginfo: @@ -161,18 +218,18 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi output_checkpoint = f"{filename}_{counter:05}_.safetensors" output_checkpoint = os.path.join(full_output_folder, output_checkpoint) - ldm_patched.modules.sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata) + ldm_patched.modules.sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata, extra_keys=extra_keys) class CheckpointSave: def __init__(self): - self.output_dir = ldm_patched.utils.path_utils.get_output_directory() + self.output_dir = folder_paths.get_output_directory() @classmethod def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), "clip": ("CLIP",), "vae": ("VAE",), - "filename_prefix": ("STRING", {"default": "checkpoints/ldm_patched"}),}, + "filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}),}, "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},} RETURN_TYPES = () FUNCTION = "save" @@ -186,12 +243,12 @@ def save(self, model, clip, vae, filename_prefix, prompt=None, extra_pnginfo=Non class CLIPSave: def __init__(self): - self.output_dir = ldm_patched.utils.path_utils.get_output_directory() + self.output_dir = folder_paths.get_output_directory() @classmethod def INPUT_TYPES(s): return {"required": { "clip": ("CLIP",), - "filename_prefix": ("STRING", {"default": "clip/ldm_patched"}),}, + "filename_prefix": ("STRING", {"default": "clip/ComfyUI"}),}, "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},} RETURN_TYPES = () FUNCTION = "save" @@ -205,13 +262,13 @@ def save(self, clip, filename_prefix, prompt=None, extra_pnginfo=None): prompt_info = json.dumps(prompt) metadata = {} - if not args.disable_server_info: + if not args.disable_metadata: metadata["prompt"] = prompt_info if extra_pnginfo is not None: for x in extra_pnginfo: metadata[x] = json.dumps(extra_pnginfo[x]) - ldm_patched.modules.model_management.load_models_gpu([clip.load_model()]) + ldm_patched.modules.model_management.load_models_gpu([clip.load_model()], force_patch_weights=True) clip_sd = clip.get_sd() for prefix in ["clip_l.", "clip_g.", ""]: @@ -230,7 +287,7 @@ def save(self, clip, filename_prefix, prompt=None, extra_pnginfo=None): replace_prefix[prefix] = "" replace_prefix["transformer."] = "" - full_output_folder, filename, 
counter, subfolder, filename_prefix_ = ldm_patched.utils.path_utils.get_save_image_path(filename_prefix_, self.output_dir) + full_output_folder, filename, counter, subfolder, filename_prefix_ = folder_paths.get_save_image_path(filename_prefix_, self.output_dir) output_checkpoint = f"{filename}_{counter:05}_.safetensors" output_checkpoint = os.path.join(full_output_folder, output_checkpoint) @@ -242,12 +299,12 @@ def save(self, clip, filename_prefix, prompt=None, extra_pnginfo=None): class VAESave: def __init__(self): - self.output_dir = ldm_patched.utils.path_utils.get_output_directory() + self.output_dir = folder_paths.get_output_directory() @classmethod def INPUT_TYPES(s): return {"required": { "vae": ("VAE",), - "filename_prefix": ("STRING", {"default": "vae/ldm_patched_vae"}),}, + "filename_prefix": ("STRING", {"default": "vae/ComfyUI_vae"}),}, "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},} RETURN_TYPES = () FUNCTION = "save" @@ -256,13 +313,13 @@ def INPUT_TYPES(s): CATEGORY = "advanced/model_merging" def save(self, vae, filename_prefix, prompt=None, extra_pnginfo=None): - full_output_folder, filename, counter, subfolder, filename_prefix = ldm_patched.utils.path_utils.get_save_image_path(filename_prefix, self.output_dir) + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir) prompt_info = "" if prompt is not None: prompt_info = json.dumps(prompt) metadata = {} - if not args.disable_server_info: + if not args.disable_metadata: metadata["prompt"] = prompt_info if extra_pnginfo is not None: for x in extra_pnginfo: @@ -281,6 +338,8 @@ def save(self, vae, filename_prefix, prompt=None, extra_pnginfo=None): "ModelMergeAdd": ModelAdd, "CheckpointSave": CheckpointSave, "CLIPMergeSimple": CLIPMergeSimple, + "CLIPMergeSubtract": CLIPSubtract, + "CLIPMergeAdd": CLIPAdd, "CLIPSave": CLIPSave, "VAESave": VAESave, } diff --git a/ldm_patched/contrib/external_model_merging_model_specific.py b/ldm_patched/contrib/external_model_merging_model_specific.py new file mode 100644 index 000000000..152e8c998 --- /dev/null +++ b/ldm_patched/contrib/external_model_merging_model_specific.py @@ -0,0 +1,60 @@ +import external_model_merging + +class ModelMergeSD1(external_model_merging.ModelMergeBlocks): + CATEGORY = "advanced/model_merging/model_specific" + @classmethod + def INPUT_TYPES(s): + arg_dict = { "model1": ("MODEL",), + "model2": ("MODEL",)} + + argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}) + + arg_dict["time_embed."] = argument + arg_dict["label_emb."] = argument + + for i in range(12): + arg_dict["input_blocks.{}.".format(i)] = argument + + for i in range(3): + arg_dict["middle_block.{}.".format(i)] = argument + + for i in range(12): + arg_dict["output_blocks.{}.".format(i)] = argument + + arg_dict["out."] = argument + + return {"required": arg_dict} + + +class ModelMergeSDXL(external_model_merging.ModelMergeBlocks): + CATEGORY = "advanced/model_merging/model_specific" + + @classmethod + def INPUT_TYPES(s): + arg_dict = { "model1": ("MODEL",), + "model2": ("MODEL",)} + + argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}) + + arg_dict["time_embed."] = argument + arg_dict["label_emb."] = argument + + for i in range(9): + arg_dict["input_blocks.{}".format(i)] = argument + + for i in range(3): + arg_dict["middle_block.{}".format(i)] = argument + + for i in range(9): + arg_dict["output_blocks.{}".format(i)] = argument + + arg_dict["out."] = argument + 
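
The three CLIP merge nodes above differ only in the pair of strengths handed to `add_patches(patches, strength_patch, strength_model)`. Assuming the usual ModelPatcher semantics for raw weight patches (result = `strength_model * own_weight + strength_patch * patched_weight`), the per-tensor arithmetic works out to:

```python
import torch

w1, w2 = torch.randn(4), torch.randn(4)   # one weight from clip1 / clip2
ratio, multiplier = 0.3, 2.0

merged     = ratio * w1 + (1.0 - ratio) * w2   # CLIPMergeSimple
subtracted = multiplier * (w1 - w2)            # CLIPMergeSubtract
added      = w1 + w2                           # CLIPMergeAdd
```

All three nodes deliberately skip `.position_ids` and `.logit_scale` keys, which are left untouched rather than interpolated.
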
+ return {"required": arg_dict} + + +NODE_CLASS_MAPPINGS = { + "ModelMergeSD1": ModelMergeSD1, + "ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks + "ModelMergeSDXL": ModelMergeSDXL, +} diff --git a/ldm_patched/contrib/external_morphology.py b/ldm_patched/contrib/external_morphology.py new file mode 100644 index 000000000..383753f4a --- /dev/null +++ b/ldm_patched/contrib/external_morphology.py @@ -0,0 +1,49 @@ +import torch +import ldm_patched.modules.model_management + +from kornia.morphology import dilation, erosion, opening, closing, gradient, top_hat, bottom_hat + + +class Morphology: + @classmethod + def INPUT_TYPES(s): + return {"required": {"image": ("IMAGE",), + "operation": (["erode", "dilate", "open", "close", "gradient", "bottom_hat", "top_hat"],), + "kernel_size": ("INT", {"default": 3, "min": 3, "max": 999, "step": 1}), + }} + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "process" + + CATEGORY = "image/postprocessing" + + def process(self, image, operation, kernel_size): + device = ldm_patched.modules.model_management.get_torch_device() + kernel = torch.ones(kernel_size, kernel_size, device=device) + image_k = image.to(device).movedim(-1, 1) + if operation == "erode": + output = erosion(image_k, kernel) + elif operation == "dilate": + output = dilation(image_k, kernel) + elif operation == "open": + output = opening(image_k, kernel) + elif operation == "close": + output = closing(image_k, kernel) + elif operation == "gradient": + output = gradient(image_k, kernel) + elif operation == "top_hat": + output = top_hat(image_k, kernel) + elif operation == "bottom_hat": + output = bottom_hat(image_k, kernel) + else: + raise ValueError(f"Invalid operation {operation} for morphology. Must be one of 'erode', 'dilate', 'open', 'close', 'gradient', 'tophat', 'bottomhat'") + img_out = output.to(ldm_patched.modules.model_management.intermediate_device()).movedim(1, -1) + return (img_out,) + +NODE_CLASS_MAPPINGS = { + "Morphology": Morphology, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "Morphology": "ImageMorphology", +} \ No newline at end of file diff --git a/ldm_patched/contrib/external_pag.py b/ldm_patched/contrib/external_pag.py new file mode 100644 index 000000000..537f28205 --- /dev/null +++ b/ldm_patched/contrib/external_pag.py @@ -0,0 +1,56 @@ +#Modified/simplified version of the node from: https://github.com/pamparamm/sd-perturbed-attention +#If you want the one with more options see the above repo. + +#My modified one here is more basic but has less chances of breaking with ComfyUI updates. 
+ +import ldm_patched.modules.model_patcher +import ldm_patched.modules.samplers + +class PerturbedAttentionGuidance: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": ("MODEL",), + "scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 100.0, "step": 0.1, "round": 0.01}), + } + } + + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "_for_testing" + + def patch(self, model, scale): + unet_block = "middle" + unet_block_id = 0 + m = model.clone() + + def perturbed_attention(q, k, v, extra_options, mask=None): + return v + + def post_cfg_function(args): + model = args["model"] + cond_pred = args["cond_denoised"] + cond = args["cond"] + cfg_result = args["denoised"] + sigma = args["sigma"] + model_options = args["model_options"].copy() + x = args["input"] + + if scale == 0: + return cfg_result + + # Replace Self-attention with PAG + model_options = ldm_patched.modules.model_patcher.set_model_options_patch_replace(model_options, perturbed_attention, "attn1", unet_block, unet_block_id) + (pag,) = ldm_patched.modules.samplers.calc_cond_batch(model, [cond], x, sigma, model_options) + + return cfg_result + (cond_pred - pag) * scale + + m.set_model_sampler_post_cfg_function(post_cfg_function) + + return (m,) + +NODE_CLASS_MAPPINGS = { + "PerturbedAttentionGuidance": PerturbedAttentionGuidance, +} diff --git a/ldm_patched/contrib/external_perpneg.py b/ldm_patched/contrib/external_perpneg.py index ec91681fe..a6341a618 100644 --- a/ldm_patched/contrib/external_perpneg.py +++ b/ldm_patched/contrib/external_perpneg.py @@ -1,18 +1,26 @@ -# https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py - import torch import ldm_patched.modules.model_management -import ldm_patched.modules.sample +import ldm_patched.modules.sampler_helpers import ldm_patched.modules.samplers import ldm_patched.modules.utils +import ldm_patched.utils.node_helpers as node_helpers + +def perp_neg(x, noise_pred_pos, noise_pred_neg, noise_pred_nocond, neg_scale, cond_scale): + pos = noise_pred_pos - noise_pred_nocond + neg = noise_pred_neg - noise_pred_nocond + perp = neg - ((torch.mul(neg, pos).sum())/(torch.norm(pos)**2)) * pos + perp_neg = perp * neg_scale + cfg_result = noise_pred_nocond + cond_scale*(pos - perp_neg) + return cfg_result +#TODO: This node should be removed, it has been replaced with PerpNegGuider class PerpNeg: @classmethod def INPUT_TYPES(s): return {"required": {"model": ("MODEL", ), "empty_conditioning": ("CONDITIONING", ), - "neg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0}), + "neg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}), }} RETURN_TYPES = ("MODEL",) FUNCTION = "patch" @@ -21,7 +29,7 @@ def INPUT_TYPES(s): def patch(self, model, empty_conditioning, neg_scale): m = model.clone() - nocond = ldm_patched.modules.sample.convert_cond(empty_conditioning) + nocond = ldm_patched.modules.sampler_helpers.convert_cond(empty_conditioning) def cfg_function(args): model = args["model"] @@ -33,14 +41,9 @@ def cfg_function(args): model_options = args["model_options"] nocond_processed = ldm_patched.modules.samplers.encode_model_conds(model.extra_conds, nocond, x, x.device, "negative") - (noise_pred_nocond, _) = ldm_patched.modules.samplers.calc_cond_uncond_batch(model, nocond_processed, None, x, sigma, model_options) + (noise_pred_nocond,) = ldm_patched.modules.samplers.calc_cond_batch(model, [nocond_processed], x, sigma, model_options) - pos = noise_pred_pos - noise_pred_nocond - neg = noise_pred_neg - noise_pred_nocond - perp = 
((torch.mul(pos, neg).sum())/(torch.norm(neg)**2)) * neg - perp_neg = perp * neg_scale - cfg_result = noise_pred_nocond + cond_scale*(pos - perp_neg) - cfg_result = x - cfg_result + cfg_result = x - perp_neg(x, noise_pred_pos, noise_pred_neg, noise_pred_nocond, neg_scale, cond_scale) return cfg_result m.set_model_sampler_cfg_function(cfg_function) @@ -48,10 +51,78 @@ def cfg_function(args): return (m, ) +class Guider_PerpNeg(ldm_patched.modules.samplers.CFGGuider): + def set_conds(self, positive, negative, empty_negative_prompt): + empty_negative_prompt = node_helpers.conditioning_set_values(empty_negative_prompt, {"prompt_type": "negative"}) + self.inner_set_conds({"positive": positive, "empty_negative_prompt": empty_negative_prompt, "negative": negative}) + + def set_cfg(self, cfg, neg_scale): + self.cfg = cfg + self.neg_scale = neg_scale + + def predict_noise(self, x, timestep, model_options={}, seed=None): + # in CFGGuider.predict_noise, we call sampling_function(), which uses cfg_function() to compute pos & neg + # but we'd rather do a single batch of sampling pos, neg, and empty, so we call calc_cond_batch([pos,neg,empty]) directly + + positive_cond = self.conds.get("positive", None) + negative_cond = self.conds.get("negative", None) + empty_cond = self.conds.get("empty_negative_prompt", None) + + (noise_pred_pos, noise_pred_neg, noise_pred_empty) = \ + ldm_patched.modules.samplers.calc_cond_batch(self.inner_model, [positive_cond, negative_cond, empty_cond], x, timestep, model_options) + cfg_result = perp_neg(x, noise_pred_pos, noise_pred_neg, noise_pred_empty, self.neg_scale, self.cfg) + + # normally this would be done in cfg_function, but we skipped + # that for efficiency: we can compute the noise predictions in + # a single call to calc_cond_batch() (rather than two) + # so we replicate the hook here + for fn in model_options.get("sampler_post_cfg_function", []): + args = { + "denoised": cfg_result, + "cond": positive_cond, + "uncond": negative_cond, + "model": self.inner_model, + "uncond_denoised": noise_pred_neg, + "cond_denoised": noise_pred_pos, + "sigma": timestep, + "model_options": model_options, + "input": x, + # not in the original call in samplers.py:cfg_function, but made available for future hooks + "empty_cond": empty_cond, + "empty_cond_denoised": noise_pred_empty,} + cfg_result = fn(args) + + return cfg_result + +class PerpNegGuider: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"model": ("MODEL",), + "positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "empty_conditioning": ("CONDITIONING", ), + "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}), + "neg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}), + } + } + + RETURN_TYPES = ("GUIDER",) + + FUNCTION = "get_guider" + CATEGORY = "_for_testing" + + def get_guider(self, model, positive, negative, empty_conditioning, cfg, neg_scale): + guider = Guider_PerpNeg(model) + guider.set_conds(positive, negative, empty_conditioning) + guider.set_cfg(cfg, neg_scale) + return (guider,) + NODE_CLASS_MAPPINGS = { "PerpNeg": PerpNeg, + "PerpNegGuider": PerpNegGuider, } NODE_DISPLAY_NAME_MAPPINGS = { - "PerpNeg": "Perp-Neg", + "PerpNeg": "Perp-Neg (DEPRECATED by PerpNegGuider)", } diff --git a/ldm_patched/contrib/external_photomaker.py b/ldm_patched/contrib/external_photomaker.py index cc7f67100..dbac56747 100644 --- a/ldm_patched/contrib/external_photomaker.py +++ b/ldm_patched/contrib/external_photomaker.py @@ -1,8 +1,6 @@ -# 
https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py - import torch import torch.nn as nn -import ldm_patched.utils.path_utils +import ldm_patched.utils.path_utils as folder_paths import ldm_patched.modules.clip_model import ldm_patched.modules.clip_vision import ldm_patched.modules.ops @@ -120,7 +118,7 @@ def forward(self, id_pixel_values, prompt_embeds, class_tokens_mask): class PhotoMakerLoader: @classmethod def INPUT_TYPES(s): - return {"required": { "photomaker_model_name": (ldm_patched.utils.path_utils.get_filename_list("photomaker"), )}} + return {"required": { "photomaker_model_name": (folder_paths.get_filename_list("photomaker"), )}} RETURN_TYPES = ("PHOTOMAKER",) FUNCTION = "load_photomaker_model" @@ -128,7 +126,7 @@ def INPUT_TYPES(s): CATEGORY = "_for_testing/photomaker" def load_photomaker_model(self, photomaker_model_name): - photomaker_model_path = ldm_patched.utils.path_utils.get_full_path("photomaker", photomaker_model_name) + photomaker_model_path = folder_paths.get_full_path("photomaker", photomaker_model_name) photomaker_model = PhotoMakerIDEncoder() data = ldm_patched.modules.utils.load_torch_file(photomaker_model_path, safe_load=True) if "id_encoder" in data: @@ -143,7 +141,7 @@ def INPUT_TYPES(s): return {"required": { "photomaker": ("PHOTOMAKER",), "image": ("IMAGE",), "clip": ("CLIP", ), - "text": ("STRING", {"multiline": True, "default": "photograph of photomaker"}), + "text": ("STRING", {"multiline": True, "dynamicPrompts": True, "default": "photograph of photomaker"}), }} RETURN_TYPES = ("CONDITIONING",) diff --git a/ldm_patched/contrib/external_post_processing.py b/ldm_patched/contrib/external_post_processing.py index 93cb12122..9dccf538c 100644 --- a/ldm_patched/contrib/external_post_processing.py +++ b/ldm_patched/contrib/external_post_processing.py @@ -1,5 +1,3 @@ -# https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py - import numpy as np import torch import torch.nn.functional as F @@ -7,6 +5,7 @@ import math import ldm_patched.modules.utils +import ldm_patched.modules.model_management class Blend: @@ -104,6 +103,7 @@ def blur(self, image: torch.Tensor, blur_radius: int, sigma: float): if blur_radius == 0: return (image,) + image = image.to(ldm_patched.modules.model_management.get_torch_device()) batch_size, height, width, channels = image.shape kernel_size = blur_radius * 2 + 1 @@ -114,7 +114,7 @@ def blur(self, image: torch.Tensor, blur_radius: int, sigma: float): blurred = F.conv2d(padded_image, kernel, padding=kernel_size // 2, groups=channels)[:,:,blur_radius:-blur_radius, blur_radius:-blur_radius] blurred = blurred.permute(0, 2, 3, 1) - return (blurred,) + return (blurred.to(ldm_patched.modules.model_management.intermediate_device()),) class Quantize: def __init__(self): @@ -206,13 +206,13 @@ def INPUT_TYPES(s): "default": 1.0, "min": 0.1, "max": 10.0, - "step": 0.1 + "step": 0.01 }), "alpha": ("FLOAT", { "default": 1.0, "min": 0.0, "max": 5.0, - "step": 0.1 + "step": 0.01 }), }, } @@ -227,6 +227,7 @@ def sharpen(self, image: torch.Tensor, sharpen_radius: int, sigma:float, alpha: return (image,) batch_size, height, width, channels = image.shape + image = image.to(ldm_patched.modules.model_management.get_torch_device()) kernel_size = sharpen_radius * 2 + 1 kernel = gaussian_kernel(kernel_size, sigma, device=image.device) * -(alpha*10) @@ -241,7 +242,7 @@ def sharpen(self, image: torch.Tensor, sharpen_radius: int, sigma:float, alpha: result = torch.clamp(sharpened, 0, 1) - return (result,) + return 
(result.to(ldm_patched.modules.model_management.intermediate_device()),) class ImageScaleToTotalPixels: upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"] diff --git a/ldm_patched/contrib/external_rebatch.py b/ldm_patched/contrib/external_rebatch.py index c24cc8c32..3010fbd4b 100644 --- a/ldm_patched/contrib/external_rebatch.py +++ b/ldm_patched/contrib/external_rebatch.py @@ -1,5 +1,3 @@ -# https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py - import torch class LatentRebatch: diff --git a/ldm_patched/contrib/external_sag.py b/ldm_patched/contrib/external_sag.py index 804d56113..959a03087 100644 --- a/ldm_patched/contrib/external_sag.py +++ b/ldm_patched/contrib/external_sag.py @@ -1,18 +1,15 @@ -# https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py - import torch from torch import einsum import torch.nn.functional as F import math from einops import rearrange, repeat -import os -from ldm_patched.ldm.modules.attention import optimized_attention, _ATTN_PRECISION +from ldm_patched.ldm.modules.attention import optimized_attention import ldm_patched.modules.samplers -# from ldm_patched.modules/ldm/modules/attention.py +# from comfy/ldm/modules/attention.py # but modified to return attention scores as well as output -def attention_basic_with_sim(q, k, v, heads, mask=None): +def attention_basic_with_sim(q, k, v, heads, mask=None, attn_precision=None): b, _, dim_head = q.shape dim_head //= heads scale = dim_head ** -0.5 @@ -28,7 +25,7 @@ def attention_basic_with_sim(q, k, v, heads, mask=None): ) # force cast to fp32 to avoid overflowing - if _ATTN_PRECISION =="fp32": + if attn_precision == torch.float32: sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale else: sim = einsum('b i d, b j d -> b i j', q, k) * scale @@ -123,13 +120,13 @@ def attn_and_record(q, k, v, extra_options): if 1 in cond_or_uncond: uncond_index = cond_or_uncond.index(1) # do the entire attention operation, but save the attention scores to attn_scores - (out, sim) = attention_basic_with_sim(q, k, v, heads=heads) + (out, sim) = attention_basic_with_sim(q, k, v, heads=heads, attn_precision=extra_options["attn_precision"]) # when using a higher batch size, I BELIEVE the result batch dimension is [uc1, ... ucn, c1, ... 
cn] n_slices = heads * b attn_scores = sim[n_slices * uncond_index:n_slices * (uncond_index+1)] return out else: - return optimized_attention(q, k, v, heads=heads) + return optimized_attention(q, k, v, heads=heads, attn_precision=extra_options["attn_precision"]) def post_cfg_function(args): nonlocal attn_scores @@ -152,7 +149,7 @@ def post_cfg_function(args): degraded = create_blur_map(uncond_pred, uncond_attn, sag_sigma, sag_threshold) degraded_noised = degraded + x - uncond_pred # call into the UNet - (sag, _) = ldm_patched.modules.samplers.calc_cond_uncond_batch(model, uncond, None, degraded_noised, sigma, model_options) + (sag,) = ldm_patched.modules.samplers.calc_cond_batch(model, [uncond], degraded_noised, sigma, model_options) return cfg_result + (degraded - sag) * sag_scale m.set_model_sampler_post_cfg_function(post_cfg_function, disable_cfg1_optimization=True) diff --git a/ldm_patched/contrib/external_sdupscale.py b/ldm_patched/contrib/external_sdupscale.py index 68153c478..61318bcb4 100644 --- a/ldm_patched/contrib/external_sdupscale.py +++ b/ldm_patched/contrib/external_sdupscale.py @@ -1,7 +1,4 @@ -# https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py - import torch -import ldm_patched.contrib.external import ldm_patched.modules.utils class SD_4XUpscale_Conditioning: diff --git a/ldm_patched/contrib/external_stable3d.py b/ldm_patched/contrib/external_stable3d.py index bae2623fa..7ecf3ab55 100644 --- a/ldm_patched/contrib/external_stable3d.py +++ b/ldm_patched/contrib/external_stable3d.py @@ -1,7 +1,5 @@ -# https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py - import torch -import ldm_patched.contrib.external +import external import ldm_patched.modules.utils def camera_embeddings(elevation, azimuth): @@ -28,11 +26,11 @@ def INPUT_TYPES(s): return {"required": { "clip_vision": ("CLIP_VISION",), "init_image": ("IMAGE",), "vae": ("VAE",), - "width": ("INT", {"default": 256, "min": 16, "max": ldm_patched.contrib.external.MAX_RESOLUTION, "step": 8}), - "height": ("INT", {"default": 256, "min": 16, "max": ldm_patched.contrib.external.MAX_RESOLUTION, "step": 8}), + "width": ("INT", {"default": 256, "min": 16, "max": external.MAX_RESOLUTION, "step": 8}), + "height": ("INT", {"default": 256, "min": 16, "max": external.MAX_RESOLUTION, "step": 8}), "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - "elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}), - "azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}), + "elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), + "azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), }} RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") RETURN_NAMES = ("positive", "negative", "latent") @@ -61,13 +59,13 @@ def INPUT_TYPES(s): return {"required": { "clip_vision": ("CLIP_VISION",), "init_image": ("IMAGE",), "vae": ("VAE",), - "width": ("INT", {"default": 256, "min": 16, "max": ldm_patched.contrib.external.MAX_RESOLUTION, "step": 8}), - "height": ("INT", {"default": 256, "min": 16, "max": ldm_patched.contrib.external.MAX_RESOLUTION, "step": 8}), + "width": ("INT", {"default": 256, "min": 16, "max": external.MAX_RESOLUTION, "step": 8}), + "height": ("INT", {"default": 256, "min": 16, "max": external.MAX_RESOLUTION, "step": 8}), "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - "elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}), - "azimuth": ("FLOAT", {"default": 0.0, 
"min": -180.0, "max": 180.0}), - "elevation_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}), - "azimuth_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}), + "elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), + "azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), + "elevation_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), + "azimuth_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), }} RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") RETURN_NAMES = ("positive", "negative", "latent") @@ -97,8 +95,49 @@ def encode(self, clip_vision, init_image, vae, width, height, batch_size, elevat latent = torch.zeros([batch_size, 4, height // 8, width // 8]) return (positive, negative, {"samples":latent, "batch_index": [0] * batch_size}) +class SV3D_Conditioning: + @classmethod + def INPUT_TYPES(s): + return {"required": { "clip_vision": ("CLIP_VISION",), + "init_image": ("IMAGE",), + "vae": ("VAE",), + "width": ("INT", {"default": 576, "min": 16, "max": external.MAX_RESOLUTION, "step": 8}), + "height": ("INT", {"default": 576, "min": 16, "max": external.MAX_RESOLUTION, "step": 8}), + "video_frames": ("INT", {"default": 21, "min": 1, "max": 4096}), + "elevation": ("FLOAT", {"default": 0.0, "min": -90.0, "max": 90.0, "step": 0.1, "round": False}), + }} + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + + FUNCTION = "encode" + + CATEGORY = "conditioning/3d_models" + + def encode(self, clip_vision, init_image, vae, width, height, video_frames, elevation): + output = clip_vision.encode_image(init_image) + pooled = output.image_embeds.unsqueeze(0) + pixels = ldm_patched.modules.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1) + encode_pixels = pixels[:,:,:,:3] + t = vae.encode(encode_pixels) + + azimuth = 0 + azimuth_increment = 360 / (max(video_frames, 2) - 1) + + elevations = [] + azimuths = [] + for i in range(video_frames): + elevations.append(elevation) + azimuths.append(azimuth) + azimuth += azimuth_increment + + positive = [[pooled, {"concat_latent_image": t, "elevation": elevations, "azimuth": azimuths}]] + negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t), "elevation": elevations, "azimuth": azimuths}]] + latent = torch.zeros([video_frames, 4, height // 8, width // 8]) + return (positive, negative, {"samples":latent}) + NODE_CLASS_MAPPINGS = { "StableZero123_Conditioning": StableZero123_Conditioning, "StableZero123_Conditioning_Batched": StableZero123_Conditioning_Batched, + "SV3D_Conditioning": SV3D_Conditioning, } diff --git a/ldm_patched/contrib/external_stable_cascade.py b/ldm_patched/contrib/external_stable_cascade.py new file mode 100644 index 000000000..4716df26c --- /dev/null +++ b/ldm_patched/contrib/external_stable_cascade.py @@ -0,0 +1,140 @@ +""" + This file is part of ComfyUI. + Copyright (C) 2024 Stability AI + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. 
+
+    This program is distributed in the hope that it will be useful,
+    but WITHOUT ANY WARRANTY; without even the implied warranty of
+    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+    GNU General Public License for more details.
+
+    You should have received a copy of the GNU General Public License
+    along with this program. If not, see <https://www.gnu.org/licenses/>.
+"""
+
+import torch
+import external
+import ldm_patched.modules.utils
+
+
+class StableCascade_EmptyLatentImage:
+    def __init__(self, device="cpu"):
+        self.device = device
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {
+            "width": ("INT", {"default": 1024, "min": 256, "max": external.MAX_RESOLUTION, "step": 8}),
+            "height": ("INT", {"default": 1024, "min": 256, "max": external.MAX_RESOLUTION, "step": 8}),
+            "compression": ("INT", {"default": 42, "min": 4, "max": 128, "step": 1}),
+            "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})
+        }}
+    RETURN_TYPES = ("LATENT", "LATENT")
+    RETURN_NAMES = ("stage_c", "stage_b")
+    FUNCTION = "generate"
+
+    CATEGORY = "latent/stable_cascade"
+
+    def generate(self, width, height, compression, batch_size=1):
+        c_latent = torch.zeros([batch_size, 16, height // compression, width // compression])
+        b_latent = torch.zeros([batch_size, 4, height // 4, width // 4])
+        return ({
+            "samples": c_latent,
+        }, {
+            "samples": b_latent,
+        })
+
+class StableCascade_StageC_VAEEncode:
+    def __init__(self, device="cpu"):
+        self.device = device
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {
+            "image": ("IMAGE",),
+            "vae": ("VAE", ),
+            "compression": ("INT", {"default": 42, "min": 4, "max": 128, "step": 1}),
+        }}
+    RETURN_TYPES = ("LATENT", "LATENT")
+    RETURN_NAMES = ("stage_c", "stage_b")
+    FUNCTION = "generate"
+
+    CATEGORY = "latent/stable_cascade"
+
+    def generate(self, image, vae, compression):
+        width = image.shape[-2]
+        height = image.shape[-3]
+        out_width = (width // compression) * vae.downscale_ratio
+        out_height = (height // compression) * vae.downscale_ratio
+
+        s = ldm_patched.modules.utils.common_upscale(image.movedim(-1,1), out_width, out_height, "bicubic", "center").movedim(1,-1)
+
+        c_latent = vae.encode(s[:,:,:,:3])
+        b_latent = torch.zeros([c_latent.shape[0], 4, (height // 8) * 2, (width // 8) * 2])
+        return ({
+            "samples": c_latent,
+        }, {
+            "samples": b_latent,
+        })
+
+class StableCascade_StageB_Conditioning:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "conditioning": ("CONDITIONING",),
+                              "stage_c": ("LATENT",),
+                             }}
+    RETURN_TYPES = ("CONDITIONING",)
+
+    FUNCTION = "set_prior"
+
+    CATEGORY = "conditioning/stable_cascade"
+
+    def set_prior(self, conditioning, stage_c):
+        c = []
+        for t in conditioning:
+            d = t[1].copy()
+            d['stable_cascade_prior'] = stage_c['samples']
+            n = [t[0], d]
+            c.append(n)
+        return (c, )
+
+class StableCascade_SuperResolutionControlnet:
+    def __init__(self, device="cpu"):
+        self.device = device
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {
+            "image": ("IMAGE",),
+            "vae": ("VAE", ),
+        }}
+    RETURN_TYPES = ("IMAGE", "LATENT", "LATENT")
+    RETURN_NAMES = ("controlnet_input", "stage_c", "stage_b")
+    FUNCTION = "generate"
+
+    CATEGORY = "_for_testing/stable_cascade"
+
+    def generate(self, image, vae):
+        width = image.shape[-2]
+        height = image.shape[-3]
+        batch_size = image.shape[0]
+        controlnet_input = vae.encode(image[:,:,:,:3]).movedim(1, -1)
+
+        c_latent = torch.zeros([batch_size, 16, height // 16, width // 16])
+        b_latent = torch.zeros([batch_size, 4, height // 2, width // 2])
+        return (controlnet_input, {
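
Worked shapes for the two-stage latents above: with the default `compression` of 42, a 1024×1024 request yields a 16-channel 24×24 stage-C prior, while stage B remains a 4-channel latent at a quarter of the pixel resolution:

```python
import torch

width = height = 1024
compression = 42
c_latent = torch.zeros([1, 16, height // compression, width // compression])
b_latent = torch.zeros([1, 4, height // 4, width // 4])
print(tuple(c_latent.shape), tuple(b_latent.shape))
# (1, 16, 24, 24) (1, 4, 256, 256)
```
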
"samples": c_latent, + }, { + "samples": b_latent, + }) + +NODE_CLASS_MAPPINGS = { + "StableCascade_EmptyLatentImage": StableCascade_EmptyLatentImage, + "StableCascade_StageB_Conditioning": StableCascade_StageB_Conditioning, + "StableCascade_StageC_VAEEncode": StableCascade_StageC_VAEEncode, + "StableCascade_SuperResolutionControlnet": StableCascade_SuperResolutionControlnet, +} diff --git a/ldm_patched/contrib/external_tomesd.py b/ldm_patched/contrib/external_tomesd.py index b01d6910f..df0485063 100644 --- a/ldm_patched/contrib/external_tomesd.py +++ b/ldm_patched/contrib/external_tomesd.py @@ -1,5 +1,3 @@ -# https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py - #Taken from: https://github.com/dbolya/tomesd import torch diff --git a/ldm_patched/contrib/external_upscale_model.py b/ldm_patched/contrib/external_upscale_model.py index 31d102f0e..5645584e7 100644 --- a/ldm_patched/contrib/external_upscale_model.py +++ b/ldm_patched/contrib/external_upscale_model.py @@ -1,16 +1,23 @@ -# https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py - import os -from ldm_patched.pfn import model_loading +import logging +from spandrel import ModelLoader, ImageModelDescriptor from ldm_patched.modules import model_management import torch import ldm_patched.modules.utils -import ldm_patched.utils.path_utils +import ldm_patched.utils.path_utils as folder_paths + +try: + from spandrel_extra_arches import EXTRA_REGISTRY + from spandrel import MAIN_REGISTRY + MAIN_REGISTRY.add(*EXTRA_REGISTRY) + logging.info("Successfully imported spandrel_extra_arches: support for non commercial upscale models.") +except: + pass class UpscaleModelLoader: @classmethod def INPUT_TYPES(s): - return {"required": { "model_name": (ldm_patched.utils.path_utils.get_filename_list("upscale_models"), ), + return {"required": { "model_name": (folder_paths.get_filename_list("upscale_models"), ), }} RETURN_TYPES = ("UPSCALE_MODEL",) FUNCTION = "load_model" @@ -18,11 +25,15 @@ def INPUT_TYPES(s): CATEGORY = "loaders" def load_model(self, model_name): - model_path = ldm_patched.utils.path_utils.get_full_path("upscale_models", model_name) + model_path = folder_paths.get_full_path("upscale_models", model_name) sd = ldm_patched.modules.utils.load_torch_file(model_path, safe_load=True) if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd: sd = ldm_patched.modules.utils.state_dict_prefix_replace(sd, {"module.":""}) - out = model_loading.load_state_dict(sd).eval() + out = ModelLoader().load_from_state_dict(sd).eval() + + if not isinstance(out, ImageModelDescriptor): + raise Exception("Upscale model must be a single-image model.") + return (out, ) @@ -39,9 +50,14 @@ def INPUT_TYPES(s): def upscale(self, upscale_model, image): device = model_management.get_torch_device() + + memory_required = model_management.module_size(upscale_model.model) + memory_required += (512 * 512 * 3) * image.element_size() * max(upscale_model.scale, 1.0) * 384.0 #The 384.0 is an estimate of how much some of these models take, TODO: make it more accurate + memory_required += image.nelement() * image.element_size() + model_management.free_memory(memory_required, device) + upscale_model.to(device) in_img = image.movedim(-1,-3).to(device) - free_memory = model_management.get_free_memory(device) tile = 512 overlap = 32 @@ -58,7 +74,7 @@ def upscale(self, upscale_model, image): if tile < 128: raise e - upscale_model.cpu() + upscale_model.to("cpu") s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0) return (s,) diff --git 
a/ldm_patched/contrib/external_video_model.py b/ldm_patched/contrib/external_video_model.py index 503df0e18..f5997ae6b 100644 --- a/ldm_patched/contrib/external_video_model.py +++ b/ldm_patched/contrib/external_video_model.py @@ -1,17 +1,15 @@ -# https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py - -import ldm_patched.contrib.external +import external import torch import ldm_patched.modules.utils import ldm_patched.modules.sd -import ldm_patched.utils.path_utils -import ldm_patched.contrib.external_model_merging +import ldm_patched.utils.path_utils as folder_paths +import external_model_merging class ImageOnlyCheckpointLoader: @classmethod def INPUT_TYPES(s): - return {"required": { "ckpt_name": (ldm_patched.utils.path_utils.get_filename_list("checkpoints"), ), + return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ), }} RETURN_TYPES = ("MODEL", "CLIP_VISION", "VAE") FUNCTION = "load_checkpoint" @@ -19,8 +17,8 @@ def INPUT_TYPES(s): CATEGORY = "loaders/video_models" def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True): - ckpt_path = ldm_patched.utils.path_utils.get_full_path("checkpoints", ckpt_name) - out = ldm_patched.modules.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=True, embedding_directory=ldm_patched.utils.path_utils.get_folder_paths("embeddings")) + ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) + out = ldm_patched.modules.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) return (out[0], out[3], out[2]) @@ -30,8 +28,8 @@ def INPUT_TYPES(s): return {"required": { "clip_vision": ("CLIP_VISION",), "init_image": ("IMAGE",), "vae": ("VAE",), - "width": ("INT", {"default": 1024, "min": 16, "max": ldm_patched.contrib.external.MAX_RESOLUTION, "step": 8}), - "height": ("INT", {"default": 576, "min": 16, "max": ldm_patched.contrib.external.MAX_RESOLUTION, "step": 8}), + "width": ("INT", {"default": 1024, "min": 16, "max": external.MAX_RESOLUTION, "step": 8}), + "height": ("INT", {"default": 576, "min": 16, "max": external.MAX_RESOLUTION, "step": 8}), "video_frames": ("INT", {"default": 14, "min": 1, "max": 4096}), "motion_bucket_id": ("INT", {"default": 127, "min": 1, "max": 1023}), "fps": ("INT", {"default": 6, "min": 1, "max": 1024}), @@ -81,7 +79,34 @@ def linear_cfg(args): m.set_model_sampler_cfg_function(linear_cfg) return (m, ) -class ImageOnlyCheckpointSave(ldm_patched.contrib.external_model_merging.CheckpointSave): +class VideoTriangleCFGGuidance: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "min_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "sampling/video_models" + + def patch(self, model, min_cfg): + def linear_cfg(args): + cond = args["cond"] + uncond = args["uncond"] + cond_scale = args["cond_scale"] + period = 1.0 + values = torch.linspace(0, 1, cond.shape[0], device=cond.device) + values = 2 * (values / period - torch.floor(values / period + 0.5)).abs() + scale = (values * (cond_scale - min_cfg) + min_cfg).reshape((cond.shape[0], 1, 1, 1)) + + return uncond + scale * (cond - uncond) + + m = model.clone() + m.set_model_sampler_cfg_function(linear_cfg) + return (m, ) + +class ImageOnlyCheckpointSave(external_model_merging.CheckpointSave): CATEGORY = "_for_testing" @classmethod @@ 
-89,17 +114,18 @@ def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), "clip_vision": ("CLIP_VISION",), "vae": ("VAE",), - "filename_prefix": ("STRING", {"default": "checkpoints/ldm_patched"}),}, + "filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}),}, "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},} def save(self, model, clip_vision, vae, filename_prefix, prompt=None, extra_pnginfo=None): - ldm_patched.contrib.external_model_merging.save_checkpoint(model, clip_vision=clip_vision, vae=vae, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo) + external_model_merging.save_checkpoint(model, clip_vision=clip_vision, vae=vae, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo) return {} NODE_CLASS_MAPPINGS = { "ImageOnlyCheckpointLoader": ImageOnlyCheckpointLoader, "SVD_img2vid_Conditioning": SVD_img2vid_Conditioning, "VideoLinearCFGGuidance": VideoLinearCFGGuidance, + "VideoTriangleCFGGuidance": VideoTriangleCFGGuidance, "ImageOnlyCheckpointSave": ImageOnlyCheckpointSave, } diff --git a/ldm_patched/contrib/external_webcam.py b/ldm_patched/contrib/external_webcam.py new file mode 100644 index 000000000..8ddaa3c00 --- /dev/null +++ b/ldm_patched/contrib/external_webcam.py @@ -0,0 +1,33 @@ +import external +import ldm_patched.utils.path_utils as folder_paths + +MAX_RESOLUTION = external.MAX_RESOLUTION + + +class WebcamCapture(external.LoadImage): + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": ("WEBCAM", {}), + "width": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), + "height": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), + "capture_on_queue": ("BOOLEAN", {"default": True}), + } + } + RETURN_TYPES = ("IMAGE",) + FUNCTION = "load_capture" + + CATEGORY = "image" + + def load_capture(s, image, **kwargs): + return super().load_image(folder_paths.get_annotated_filepath(image)) + + +NODE_CLASS_MAPPINGS = { + "WebcamCapture": WebcamCapture, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "WebcamCapture": "Webcam Capture", +} \ No newline at end of file diff --git a/ldm_patched/controlnet/cldm.py b/ldm_patched/controlnet/cldm.py index 82265ef95..84dc0942b 100644 --- a/ldm_patched/controlnet/cldm.py +++ b/ldm_patched/controlnet/cldm.py @@ -52,6 +52,7 @@ def __init__( adm_in_channels=None, transformer_depth_middle=None, transformer_depth_output=None, + attn_precision=None, device=None, operations=ldm_patched.modules.ops.disable_weight_init, **kwargs, @@ -202,7 +203,7 @@ def __init__( SpatialTransformer( ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim, disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, - use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations + use_checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=self.dtype, device=device, operations=operations ) ) self.input_blocks.append(TimestepEmbedSequential(*layers)) @@ -262,7 +263,7 @@ def __init__( mid_block += [SpatialTransformer( # always uses a self-attn ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim, disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer, - use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations + use_checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=self.dtype, device=device, operations=operations ), ResBlock( ch, @@ -309,4 +310,3 
@@ def forward(self, x, hint, timesteps, context, y=None, **kwargs): outs.append(self.middle_block_out(h, emb, context)) return outs - diff --git a/ldm_patched/k_diffusion/sampling.py b/ldm_patched/k_diffusion/sampling.py index 4d9d4ea64..b23fe6b33 100644 --- a/ldm_patched/k_diffusion/sampling.py +++ b/ldm_patched/k_diffusion/sampling.py @@ -527,6 +527,9 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, @torch.no_grad() def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2): """DPM-Solver++ (stochastic).""" + if len(sigmas) <= 1: + return x + sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() seed = extra_args.get("seed", None) noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler @@ -595,6 +598,8 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No @torch.no_grad() def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'): """DPM-Solver++(2M) SDE.""" + if len(sigmas) <= 1: + return x if solver_type not in {'heun', 'midpoint'}: raise ValueError('solver_type must be \'heun\' or \'midpoint\'') @@ -642,6 +647,9 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): """DPM-Solver++(3M) SDE.""" + if len(sigmas) <= 1: + return x + seed = extra_args.get("seed", None) sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler @@ -690,18 +698,27 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl @torch.no_grad() def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): + if len(sigmas) <= 1: + return x + sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler) @torch.no_grad() def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'): + if len(sigmas) <= 1: + return x + sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type) @torch.no_grad() def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2): + if len(sigmas) <= 1: + return x + sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler return 
sample_dpmpp_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=r) @@ -748,10 +765,11 @@ def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, n x = denoised if sigmas[i + 1] > 0: - x += sigmas[i + 1] * noise_sampler(sigmas[i], sigmas[i + 1]) + x = model.inner_model.inner_model.model_sampling.noise_scaling(sigmas[i + 1], noise_sampler(sigmas[i], sigmas[i + 1]), x) return x + @torch.no_grad() def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): # From MIT licensed: https://github.com/Carzit/sd-webui-samplers-scheduler/ @@ -808,7 +826,7 @@ def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=Non x = x + d_prime * dt return x - +# used-by-Fooocus @torch.no_grad() def sample_tcd(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, eta=0.3): extra_args = {} if extra_args is None else extra_args @@ -837,7 +855,7 @@ def sample_tcd(model, x, sigmas, extra_args=None, callback=None, disable=None, n return x - +# used-by-Fooocus @torch.no_grad() def sample_restart(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., restart_list=None): """Implements restart sampling in Restart Sampling for Improving Generative Processes (2023) @@ -905,4 +923,4 @@ def heun_step(x, old_sigma, new_sigma, second_order=True): x = heun_step(x, old_sigma, new_sigma) last_sigma = new_sigma - return x \ No newline at end of file + return x diff --git a/ldm_patched/ldm/cascade/common.py b/ldm_patched/ldm/cascade/common.py new file mode 100644 index 000000000..d783eb6a9 --- /dev/null +++ b/ldm_patched/ldm/cascade/common.py @@ -0,0 +1,161 @@ +""" + This file is part of ComfyUI. + Copyright (C) 2024 Stability AI + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see <https://www.gnu.org/licenses/>. 
+""" + +import torch +import torch.nn as nn +from ldm_patched.ldm.modules.attention import optimized_attention + +class Linear(torch.nn.Linear): + def reset_parameters(self): + return None + +class Conv2d(torch.nn.Conv2d): + def reset_parameters(self): + return None + +class OptimizedAttention(nn.Module): + def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None): + super().__init__() + self.heads = nhead + + self.to_q = operations.Linear(c, c, bias=True, dtype=dtype, device=device) + self.to_k = operations.Linear(c, c, bias=True, dtype=dtype, device=device) + self.to_v = operations.Linear(c, c, bias=True, dtype=dtype, device=device) + + self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device) + + def forward(self, q, k, v): + q = self.to_q(q) + k = self.to_k(k) + v = self.to_v(v) + + out = optimized_attention(q, k, v, self.heads) + + return self.out_proj(out) + +class Attention2D(nn.Module): + def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None): + super().__init__() + self.attn = OptimizedAttention(c, nhead, dtype=dtype, device=device, operations=operations) + # self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True, dtype=dtype, device=device) + + def forward(self, x, kv, self_attn=False): + orig_shape = x.shape + x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4 + if self_attn: + kv = torch.cat([x, kv], dim=1) + # x = self.attn(x, kv, kv, need_weights=False)[0] + x = self.attn(x, kv, kv) + x = x.permute(0, 2, 1).view(*orig_shape) + return x + + +def LayerNorm2d_op(operations): + class LayerNorm2d(operations.LayerNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x): + return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + return LayerNorm2d + +class GlobalResponseNorm(nn.Module): + "from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105" + def __init__(self, dim, dtype=None, device=None): + super().__init__() + self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim, dtype=dtype, device=device)) + self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim, dtype=dtype, device=device)) + + def forward(self, x): + Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True) + Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) + return self.gamma.to(device=x.device, dtype=x.dtype) * (x * Nx) + self.beta.to(device=x.device, dtype=x.dtype) + x + + +class ResBlock(nn.Module): + def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0, dtype=None, device=None, operations=None): # , num_heads=4, expansion=2): + super().__init__() + self.depthwise = operations.Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c, dtype=dtype, device=device) + # self.depthwise = SAMBlock(c, num_heads, expansion) + self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.channelwise = nn.Sequential( + operations.Linear(c + c_skip, c * 4, dtype=dtype, device=device), + nn.GELU(), + GlobalResponseNorm(c * 4, dtype=dtype, device=device), + nn.Dropout(dropout), + operations.Linear(c * 4, c, dtype=dtype, device=device) + ) + + def forward(self, x, x_skip=None): + x_res = x + x = self.norm(self.depthwise(x)) + if x_skip is not None: + x = torch.cat([x, x_skip], dim=1) + x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + return x + x_res + + +class AttnBlock(nn.Module): + def 
__init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0, dtype=None, device=None, operations=None): + super().__init__() + self.self_attn = self_attn + self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.attention = Attention2D(c, nhead, dropout, dtype=dtype, device=device, operations=operations) + self.kv_mapper = nn.Sequential( + nn.SiLU(), + operations.Linear(c_cond, c, dtype=dtype, device=device) + ) + + def forward(self, x, kv): + kv = self.kv_mapper(kv) + x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn) + return x + + +class FeedForwardBlock(nn.Module): + def __init__(self, c, dropout=0.0, dtype=None, device=None, operations=None): + super().__init__() + self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.channelwise = nn.Sequential( + operations.Linear(c, c * 4, dtype=dtype, device=device), + nn.GELU(), + GlobalResponseNorm(c * 4, dtype=dtype, device=device), + nn.Dropout(dropout), + operations.Linear(c * 4, c, dtype=dtype, device=device) + ) + + def forward(self, x): + x = x + self.channelwise(self.norm(x).permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + return x + + +class TimestepBlock(nn.Module): + def __init__(self, c, c_timestep, conds=['sca'], dtype=None, device=None, operations=None): + super().__init__() + self.mapper = operations.Linear(c_timestep, c * 2, dtype=dtype, device=device) + self.conds = conds + for cname in conds: + setattr(self, f"mapper_{cname}", operations.Linear(c_timestep, c * 2, dtype=dtype, device=device)) + + def forward(self, x, t): + t = t.chunk(len(self.conds) + 1, dim=1) + a, b = self.mapper(t[0])[:, :, None, None].chunk(2, dim=1) + for i, c in enumerate(self.conds): + ac, bc = getattr(self, f"mapper_{c}")(t[i + 1])[:, :, None, None].chunk(2, dim=1) + a, b = a + ac, b + bc + return x * (1 + a) + b diff --git a/ldm_patched/ldm/cascade/controlnet.py b/ldm_patched/ldm/cascade/controlnet.py new file mode 100644 index 000000000..5dac59394 --- /dev/null +++ b/ldm_patched/ldm/cascade/controlnet.py @@ -0,0 +1,93 @@ +""" + This file is part of ComfyUI. + Copyright (C) 2024 Stability AI + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see <https://www.gnu.org/licenses/>. 
+""" + +import torch +import torchvision +from torch import nn +from .common import LayerNorm2d_op + + +class CNetResBlock(nn.Module): + def __init__(self, c, dtype=None, device=None, operations=None): + super().__init__() + self.blocks = nn.Sequential( + LayerNorm2d_op(operations)(c, dtype=dtype, device=device), + nn.GELU(), + operations.Conv2d(c, c, kernel_size=3, padding=1), + LayerNorm2d_op(operations)(c, dtype=dtype, device=device), + nn.GELU(), + operations.Conv2d(c, c, kernel_size=3, padding=1), + ) + + def forward(self, x): + return x + self.blocks(x) + + +class ControlNet(nn.Module): + def __init__(self, c_in=3, c_proj=2048, proj_blocks=None, bottleneck_mode=None, dtype=None, device=None, operations=nn): + super().__init__() + if bottleneck_mode is None: + bottleneck_mode = 'effnet' + self.proj_blocks = proj_blocks + if bottleneck_mode == 'effnet': + embd_channels = 1280 + self.backbone = torchvision.models.efficientnet_v2_s().features.eval() + if c_in != 3: + in_weights = self.backbone[0][0].weight.data + self.backbone[0][0] = operations.Conv2d(c_in, 24, kernel_size=3, stride=2, bias=False, dtype=dtype, device=device) + if c_in > 3: + # nn.init.constant_(self.backbone[0][0].weight, 0) + self.backbone[0][0].weight.data[:, :3] = in_weights[:, :3].clone() + else: + self.backbone[0][0].weight.data = in_weights[:, :c_in].clone() + elif bottleneck_mode == 'simple': + embd_channels = c_in + self.backbone = nn.Sequential( + operations.Conv2d(embd_channels, embd_channels * 4, kernel_size=3, padding=1, dtype=dtype, device=device), + nn.LeakyReLU(0.2, inplace=True), + operations.Conv2d(embd_channels * 4, embd_channels, kernel_size=3, padding=1, dtype=dtype, device=device), + ) + elif bottleneck_mode == 'large': + self.backbone = nn.Sequential( + operations.Conv2d(c_in, 4096 * 4, kernel_size=1, dtype=dtype, device=device), + nn.LeakyReLU(0.2, inplace=True), + operations.Conv2d(4096 * 4, 1024, kernel_size=1, dtype=dtype, device=device), + *[CNetResBlock(1024, dtype=dtype, device=device, operations=operations) for _ in range(8)], + operations.Conv2d(1024, 1280, kernel_size=1, dtype=dtype, device=device), + ) + embd_channels = 1280 + else: + raise ValueError(f'Unknown bottleneck mode: {bottleneck_mode}') + self.projections = nn.ModuleList() + for _ in range(len(proj_blocks)): + self.projections.append(nn.Sequential( + operations.Conv2d(embd_channels, embd_channels, kernel_size=1, bias=False, dtype=dtype, device=device), + nn.LeakyReLU(0.2, inplace=True), + operations.Conv2d(embd_channels, c_proj, kernel_size=1, bias=False, dtype=dtype, device=device), + )) + # nn.init.constant_(self.projections[-1][-1].weight, 0) # zero output projection + self.xl = False + self.input_channels = c_in + self.unshuffle_amount = 8 + + def forward(self, x): + x = self.backbone(x) + proj_outputs = [None for _ in range(max(self.proj_blocks) + 1)] + for i, idx in enumerate(self.proj_blocks): + proj_outputs[idx] = self.projections[i](x) + return proj_outputs diff --git a/ldm_patched/ldm/cascade/stage_a.py b/ldm_patched/ldm/cascade/stage_a.py new file mode 100644 index 000000000..ca8867eaf --- /dev/null +++ b/ldm_patched/ldm/cascade/stage_a.py @@ -0,0 +1,255 @@ +""" + This file is part of ComfyUI. + Copyright (C) 2024 Stability AI + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. 
+ + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see <https://www.gnu.org/licenses/>. +""" + +import torch +from torch import nn +from torch.autograd import Function + +class vector_quantize(Function): + @staticmethod + def forward(ctx, x, codebook): + with torch.no_grad(): + codebook_sqr = torch.sum(codebook ** 2, dim=1) + x_sqr = torch.sum(x ** 2, dim=1, keepdim=True) + + dist = torch.addmm(codebook_sqr + x_sqr, x, codebook.t(), alpha=-2.0, beta=1.0) + _, indices = dist.min(dim=1) + + ctx.save_for_backward(indices, codebook) + ctx.mark_non_differentiable(indices) + + nn = torch.index_select(codebook, 0, indices) + return nn, indices + + @staticmethod + def backward(ctx, grad_output, grad_indices): + grad_inputs, grad_codebook = None, None + + if ctx.needs_input_grad[0]: + grad_inputs = grad_output.clone() + if ctx.needs_input_grad[1]: + # Gradient wrt. the codebook + indices, codebook = ctx.saved_tensors + + grad_codebook = torch.zeros_like(codebook) + grad_codebook.index_add_(0, indices, grad_output) + + return (grad_inputs, grad_codebook) + + +class VectorQuantize(nn.Module): + def __init__(self, embedding_size, k, ema_decay=0.99, ema_loss=False): + """ + Takes an input of variable size (as long as the last dimension matches the embedding size). + Returns one tensor containing the nearest neighbour embeddings to each of the inputs, + with the same size as the input, vq and commitment components for the loss as a tuple + in the second output and the indices of the quantized vectors in the third: + quantized, (vq_loss, commit_loss), indices + """ + super(VectorQuantize, self).__init__() + + self.codebook = nn.Embedding(k, embedding_size) + self.codebook.weight.data.uniform_(-1./k, 1./k) + self.vq = vector_quantize.apply + + self.ema_decay = ema_decay + self.ema_loss = ema_loss + if ema_loss: + self.register_buffer('ema_element_count', torch.ones(k)) + self.register_buffer('ema_weight_sum', torch.zeros_like(self.codebook.weight)) + + def _laplace_smoothing(self, x, epsilon): + n = torch.sum(x) + return ((x + epsilon) / (n + x.size(0) * epsilon) * n) + + def _updateEMA(self, z_e_x, indices): + mask = nn.functional.one_hot(indices, self.ema_element_count.size(0)).float() + elem_count = mask.sum(dim=0) + weight_sum = torch.mm(mask.t(), z_e_x) + + self.ema_element_count = (self.ema_decay * self.ema_element_count) + ((1-self.ema_decay) * elem_count) + self.ema_element_count = self._laplace_smoothing(self.ema_element_count, 1e-5) + self.ema_weight_sum = (self.ema_decay * self.ema_weight_sum) + ((1-self.ema_decay) * weight_sum) + + self.codebook.weight.data = self.ema_weight_sum / self.ema_element_count.unsqueeze(-1) + + def idx2vq(self, idx, dim=-1): + q_idx = self.codebook(idx) + if dim != -1: + q_idx = q_idx.movedim(-1, dim) + return q_idx + + def forward(self, x, get_losses=True, dim=-1): + if dim != -1: + x = x.movedim(dim, -1) + z_e_x = x.contiguous().view(-1, x.size(-1)) if len(x.shape) > 2 else x + z_q_x, indices = self.vq(z_e_x, self.codebook.weight.detach()) + vq_loss, commit_loss = None, None + if self.ema_loss and self.training: + self._updateEMA(z_e_x.detach(), indices.detach()) + # pick the graded embeddings after updating the codebook in order to have a more accurate commitment loss + z_q_x_grd = 
torch.index_select(self.codebook.weight, dim=0, index=indices) + if get_losses: + vq_loss = (z_q_x_grd - z_e_x.detach()).pow(2).mean() + commit_loss = (z_e_x - z_q_x_grd.detach()).pow(2).mean() + + z_q_x = z_q_x.view(x.shape) + if dim != -1: + z_q_x = z_q_x.movedim(-1, dim) + return z_q_x, (vq_loss, commit_loss), indices.view(x.shape[:-1]) + + +class ResBlock(nn.Module): + def __init__(self, c, c_hidden): + super().__init__() + # depthwise/attention + self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6) + self.depthwise = nn.Sequential( + nn.ReplicationPad2d(1), + nn.Conv2d(c, c, kernel_size=3, groups=c) + ) + + # channelwise + self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6) + self.channelwise = nn.Sequential( + nn.Linear(c, c_hidden), + nn.GELU(), + nn.Linear(c_hidden, c), + ) + + self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True) + + # Init weights + def _basic_init(module): + if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + def _norm(self, x, norm): + return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + + def forward(self, x): + mods = self.gammas + + x_temp = self._norm(x, self.norm1) * (1 + mods[0]) + mods[1] + try: + x = x + self.depthwise(x_temp) * mods[2] + except: #operation not implemented for bf16 + x_temp = self.depthwise[0](x_temp.float()).to(x.dtype) + x = x + self.depthwise[1](x_temp) * mods[2] + + x_temp = self._norm(x, self.norm2) * (1 + mods[3]) + mods[4] + x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5] + + return x + + +class StageA(nn.Module): + def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192): + super().__init__() + self.c_latent = c_latent + c_levels = [c_hidden // (2 ** i) for i in reversed(range(levels))] + + # Encoder blocks + self.in_block = nn.Sequential( + nn.PixelUnshuffle(2), + nn.Conv2d(3 * 4, c_levels[0], kernel_size=1) + ) + down_blocks = [] + for i in range(levels): + if i > 0: + down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1)) + block = ResBlock(c_levels[i], c_levels[i] * 4) + down_blocks.append(block) + down_blocks.append(nn.Sequential( + nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False), + nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1 + )) + self.down_blocks = nn.Sequential(*down_blocks) + self.down_blocks[0] + + self.codebook_size = codebook_size + self.vquantizer = VectorQuantize(c_latent, k=codebook_size) + + # Decoder blocks + up_blocks = [nn.Sequential( + nn.Conv2d(c_latent, c_levels[-1], kernel_size=1) + )] + for i in range(levels): + for j in range(bottleneck_blocks if i == 0 else 1): + block = ResBlock(c_levels[levels - 1 - i], c_levels[levels - 1 - i] * 4) + up_blocks.append(block) + if i < levels - 1: + up_blocks.append( + nn.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2, + padding=1)) + self.up_blocks = nn.Sequential(*up_blocks) + self.out_block = nn.Sequential( + nn.Conv2d(c_levels[0], 3 * 4, kernel_size=1), + nn.PixelShuffle(2), + ) + + def encode(self, x, quantize=False): + x = self.in_block(x) + x = self.down_blocks(x) + if quantize: + qe, (vq_loss, commit_loss), indices = self.vquantizer.forward(x, dim=1) + return qe, x, indices, vq_loss + commit_loss * 0.25 + else: + return x + + def decode(self, x): + x = 
self.up_blocks(x) + x = self.out_block(x) + return x + + def forward(self, x, quantize=False): + qe, x, _, vq_loss = self.encode(x, quantize) + x = self.decode(qe) + return x, vq_loss + + +class Discriminator(nn.Module): + def __init__(self, c_in=3, c_cond=0, c_hidden=512, depth=6): + super().__init__() + d = max(depth - 3, 3) + layers = [ + nn.utils.spectral_norm(nn.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)), + nn.LeakyReLU(0.2), + ] + for i in range(depth - 1): + c_in = c_hidden // (2 ** max((d - i), 0)) + c_out = c_hidden // (2 ** max((d - 1 - i), 0)) + layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1))) + layers.append(nn.InstanceNorm2d(c_out)) + layers.append(nn.LeakyReLU(0.2)) + self.encoder = nn.Sequential(*layers) + self.shuffle = nn.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1) + self.logits = nn.Sigmoid() + + def forward(self, x, cond=None): + x = self.encoder(x) + if cond is not None: + cond = cond.view(cond.size(0), cond.size(1), 1, 1, ).expand(-1, -1, x.size(-2), x.size(-1)) + x = torch.cat([x, cond], dim=1) + x = self.shuffle(x) + x = self.logits(x) + return x diff --git a/ldm_patched/ldm/cascade/stage_b.py b/ldm_patched/ldm/cascade/stage_b.py new file mode 100644 index 000000000..7c3d8feab --- /dev/null +++ b/ldm_patched/ldm/cascade/stage_b.py @@ -0,0 +1,256 @@ +""" + This file is part of ComfyUI. + Copyright (C) 2024 Stability AI + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see <https://www.gnu.org/licenses/>. 
+""" + +import math +import torch +from torch import nn +from .common import AttnBlock, LayerNorm2d_op, ResBlock, FeedForwardBlock, TimestepBlock + +class StageB(nn.Module): + def __init__(self, c_in=4, c_out=4, c_r=64, patch_size=2, c_cond=1280, c_hidden=[320, 640, 1280, 1280], + nhead=[-1, -1, 20, 20], blocks=[[2, 6, 28, 6], [6, 28, 6, 2]], + block_repeat=[[1, 1, 1, 1], [3, 3, 2, 2]], level_config=['CT', 'CT', 'CTA', 'CTA'], c_clip=1280, + c_clip_seq=4, c_effnet=16, c_pixels=3, kernel_size=3, dropout=[0, 0, 0.0, 0.0], self_attn=True, + t_conds=['sca'], stable_cascade_stage=None, dtype=None, device=None, operations=None): + super().__init__() + self.dtype = dtype + self.c_r = c_r + self.t_conds = t_conds + self.c_clip_seq = c_clip_seq + if not isinstance(dropout, list): + dropout = [dropout] * len(c_hidden) + if not isinstance(self_attn, list): + self_attn = [self_attn] * len(c_hidden) + + # CONDITIONING + self.effnet_mapper = nn.Sequential( + operations.Conv2d(c_effnet, c_hidden[0] * 4, kernel_size=1, dtype=dtype, device=device), + nn.GELU(), + operations.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1, dtype=dtype, device=device), + LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + ) + self.pixels_mapper = nn.Sequential( + operations.Conv2d(c_pixels, c_hidden[0] * 4, kernel_size=1, dtype=dtype, device=device), + nn.GELU(), + operations.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1, dtype=dtype, device=device), + LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + ) + self.clip_mapper = operations.Linear(c_clip, c_cond * c_clip_seq, dtype=dtype, device=device) + self.clip_norm = operations.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + + self.embedding = nn.Sequential( + nn.PixelUnshuffle(patch_size), + operations.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1, dtype=dtype, device=device), + LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + ) + + def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True): + if block_type == 'C': + return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout, dtype=dtype, device=device, operations=operations) + elif block_type == 'A': + return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout, dtype=dtype, device=device, operations=operations) + elif block_type == 'F': + return FeedForwardBlock(c_hidden, dropout=dropout, dtype=dtype, device=device, operations=operations) + elif block_type == 'T': + return TimestepBlock(c_hidden, c_r, conds=t_conds, dtype=dtype, device=device, operations=operations) + else: + raise Exception(f'Block type {block_type} not supported') + + # BLOCKS + # -- down blocks + self.down_blocks = nn.ModuleList() + self.down_downscalers = nn.ModuleList() + self.down_repeat_mappers = nn.ModuleList() + for i in range(len(c_hidden)): + if i > 0: + self.down_downscalers.append(nn.Sequential( + LayerNorm2d_op(operations)(c_hidden[i - 1], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device), + operations.Conv2d(c_hidden[i - 1], c_hidden[i], kernel_size=2, stride=2, dtype=dtype, device=device), + )) + else: + self.down_downscalers.append(nn.Identity()) + down_block = nn.ModuleList() + for _ in range(blocks[0][i]): + for block_type in level_config[i]: + block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i]) + 
down_block.append(block) + self.down_blocks.append(down_block) + if block_repeat is not None: + block_repeat_mappers = nn.ModuleList() + for _ in range(block_repeat[0][i] - 1): + block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device)) + self.down_repeat_mappers.append(block_repeat_mappers) + + # -- up blocks + self.up_blocks = nn.ModuleList() + self.up_upscalers = nn.ModuleList() + self.up_repeat_mappers = nn.ModuleList() + for i in reversed(range(len(c_hidden))): + if i > 0: + self.up_upscalers.append(nn.Sequential( + LayerNorm2d_op(operations)(c_hidden[i], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device), + operations.ConvTranspose2d(c_hidden[i], c_hidden[i - 1], kernel_size=2, stride=2, dtype=dtype, device=device), + )) + else: + self.up_upscalers.append(nn.Identity()) + up_block = nn.ModuleList() + for j in range(blocks[1][::-1][i]): + for k, block_type in enumerate(level_config[i]): + c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0 + block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i], + self_attn=self_attn[i]) + up_block.append(block) + self.up_blocks.append(up_block) + if block_repeat is not None: + block_repeat_mappers = nn.ModuleList() + for _ in range(block_repeat[1][::-1][i] - 1): + block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device)) + self.up_repeat_mappers.append(block_repeat_mappers) + + # OUTPUT + self.clf = nn.Sequential( + LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device), + operations.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1, dtype=dtype, device=device), + nn.PixelShuffle(patch_size), + ) + + # --- WEIGHT INIT --- + # self.apply(self._init_weights) # General init + # nn.init.normal_(self.clip_mapper.weight, std=0.02) # conditionings + # nn.init.normal_(self.effnet_mapper[0].weight, std=0.02) # conditionings + # nn.init.normal_(self.effnet_mapper[2].weight, std=0.02) # conditionings + # nn.init.normal_(self.pixels_mapper[0].weight, std=0.02) # conditionings + # nn.init.normal_(self.pixels_mapper[2].weight, std=0.02) # conditionings + # torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs + # nn.init.constant_(self.clf[1].weight, 0) # outputs + # + # # blocks + # for level_block in self.down_blocks + self.up_blocks: + # for block in level_block: + # if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock): + # block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0])) + # elif isinstance(block, TimestepBlock): + # for layer in block.modules(): + # if isinstance(layer, nn.Linear): + # nn.init.constant_(layer.weight, 0) + # + # def _init_weights(self, m): + # if isinstance(m, (nn.Conv2d, nn.Linear)): + # torch.nn.init.xavier_uniform_(m.weight) + # if m.bias is not None: + # nn.init.constant_(m.bias, 0) + + def gen_r_embedding(self, r, max_positions=10000): + r = r * max_positions + half_dim = self.c_r // 2 + emb = math.log(max_positions) / (half_dim - 1) + emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp() + emb = r[:, None] * emb[None, :] + emb = torch.cat([emb.sin(), emb.cos()], dim=1) + if self.c_r % 2 == 1: # zero pad + emb = nn.functional.pad(emb, (0, 1), mode='constant') + return emb + + def gen_c_embeddings(self, clip): + if len(clip.shape) == 2: + clip = clip.unsqueeze(1) + clip = self.clip_mapper(clip).view(clip.size(0), clip.size(1) * 
self.c_clip_seq, -1) + clip = self.clip_norm(clip) + return clip + + def _down_encode(self, x, r_embed, clip): + level_outputs = [] + block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers) + for down_block, downscaler, repmap in block_group: + x = downscaler(x) + for i in range(len(repmap) + 1): + for block in down_block: + if isinstance(block, ResBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + ResBlock)): + x = block(x) + elif isinstance(block, AttnBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + AttnBlock)): + x = block(x, clip) + elif isinstance(block, TimestepBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + TimestepBlock)): + x = block(x, r_embed) + else: + x = block(x) + if i < len(repmap): + x = repmap[i](x) + level_outputs.insert(0, x) + return level_outputs + + def _up_decode(self, level_outputs, r_embed, clip): + x = level_outputs[0] + block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers) + for i, (up_block, upscaler, repmap) in enumerate(block_group): + for j in range(len(repmap) + 1): + for k, block in enumerate(up_block): + if isinstance(block, ResBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + ResBlock)): + skip = level_outputs[i] if k == 0 and i > 0 else None + if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)): + x = torch.nn.functional.interpolate(x, skip.shape[-2:], mode='bilinear', + align_corners=True) + x = block(x, skip) + elif isinstance(block, AttnBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + AttnBlock)): + x = block(x, clip) + elif isinstance(block, TimestepBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + TimestepBlock)): + x = block(x, r_embed) + else: + x = block(x) + if j < len(repmap): + x = repmap[j](x) + x = upscaler(x) + return x + + def forward(self, x, r, effnet, clip, pixels=None, **kwargs): + if pixels is None: + pixels = x.new_zeros(x.size(0), 3, 8, 8) + + # Process the conditioning embeddings + r_embed = self.gen_r_embedding(r).to(dtype=x.dtype) + for c in self.t_conds: + t_cond = kwargs.get(c, torch.zeros_like(r)) + r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond).to(dtype=x.dtype)], dim=1) + clip = self.gen_c_embeddings(clip) + + # Model Blocks + x = self.embedding(x) + x = x + self.effnet_mapper( + nn.functional.interpolate(effnet, size=x.shape[-2:], mode='bilinear', align_corners=True)) + x = x + nn.functional.interpolate(self.pixels_mapper(pixels), size=x.shape[-2:], mode='bilinear', + align_corners=True) + level_outputs = self._down_encode(x, r_embed, clip) + x = self._up_decode(level_outputs, r_embed, clip) + return self.clf(x) + + def update_weights_ema(self, src_model, beta=0.999): + for self_params, src_params in zip(self.parameters(), src_model.parameters()): + self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta) + for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()): + self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta) diff --git a/ldm_patched/ldm/cascade/stage_c.py b/ldm_patched/ldm/cascade/stage_c.py new file mode 100644 index 000000000..c85da1f01 --- /dev/null +++ b/ldm_patched/ldm/cascade/stage_c.py @@ -0,0 +1,273 @@ 
+""" + This file is part of ComfyUI. + Copyright (C) 2024 Stability AI + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . +""" + +import torch +from torch import nn +import math +from .common import AttnBlock, LayerNorm2d_op, ResBlock, FeedForwardBlock, TimestepBlock +# from .controlnet import ControlNetDeliverer + +class UpDownBlock2d(nn.Module): + def __init__(self, c_in, c_out, mode, enabled=True, dtype=None, device=None, operations=None): + super().__init__() + assert mode in ['up', 'down'] + interpolation = nn.Upsample(scale_factor=2 if mode == 'up' else 0.5, mode='bilinear', + align_corners=True) if enabled else nn.Identity() + mapping = operations.Conv2d(c_in, c_out, kernel_size=1, dtype=dtype, device=device) + self.blocks = nn.ModuleList([interpolation, mapping] if mode == 'up' else [mapping, interpolation]) + + def forward(self, x): + for block in self.blocks: + x = block(x) + return x + + +class StageC(nn.Module): + def __init__(self, c_in=16, c_out=16, c_r=64, patch_size=1, c_cond=2048, c_hidden=[2048, 2048], nhead=[32, 32], + blocks=[[8, 24], [24, 8]], block_repeat=[[1, 1], [1, 1]], level_config=['CTA', 'CTA'], + c_clip_text=1280, c_clip_text_pooled=1280, c_clip_img=768, c_clip_seq=4, kernel_size=3, + dropout=[0.0, 0.0], self_attn=True, t_conds=['sca', 'crp'], switch_level=[False], stable_cascade_stage=None, + dtype=None, device=None, operations=None): + super().__init__() + self.dtype = dtype + self.c_r = c_r + self.t_conds = t_conds + self.c_clip_seq = c_clip_seq + if not isinstance(dropout, list): + dropout = [dropout] * len(c_hidden) + if not isinstance(self_attn, list): + self_attn = [self_attn] * len(c_hidden) + + # CONDITIONING + self.clip_txt_mapper = operations.Linear(c_clip_text, c_cond, dtype=dtype, device=device) + self.clip_txt_pooled_mapper = operations.Linear(c_clip_text_pooled, c_cond * c_clip_seq, dtype=dtype, device=device) + self.clip_img_mapper = operations.Linear(c_clip_img, c_cond * c_clip_seq, dtype=dtype, device=device) + self.clip_norm = operations.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + + self.embedding = nn.Sequential( + nn.PixelUnshuffle(patch_size), + operations.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1, dtype=dtype, device=device), + LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6) + ) + + def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True): + if block_type == 'C': + return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout, dtype=dtype, device=device, operations=operations) + elif block_type == 'A': + return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout, dtype=dtype, device=device, operations=operations) + elif block_type == 'F': + return FeedForwardBlock(c_hidden, dropout=dropout, dtype=dtype, device=device, operations=operations) + elif block_type == 'T': + return TimestepBlock(c_hidden, c_r, conds=t_conds, dtype=dtype, device=device, 
operations=operations) + else: + raise Exception(f'Block type {block_type} not supported') + + # BLOCKS + # -- down blocks + self.down_blocks = nn.ModuleList() + self.down_downscalers = nn.ModuleList() + self.down_repeat_mappers = nn.ModuleList() + for i in range(len(c_hidden)): + if i > 0: + self.down_downscalers.append(nn.Sequential( + LayerNorm2d_op(operations)(c_hidden[i - 1], elementwise_affine=False, eps=1e-6), + UpDownBlock2d(c_hidden[i - 1], c_hidden[i], mode='down', enabled=switch_level[i - 1], dtype=dtype, device=device, operations=operations) + )) + else: + self.down_downscalers.append(nn.Identity()) + down_block = nn.ModuleList() + for _ in range(blocks[0][i]): + for block_type in level_config[i]: + block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i]) + down_block.append(block) + self.down_blocks.append(down_block) + if block_repeat is not None: + block_repeat_mappers = nn.ModuleList() + for _ in range(block_repeat[0][i] - 1): + block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device)) + self.down_repeat_mappers.append(block_repeat_mappers) + + # -- up blocks + self.up_blocks = nn.ModuleList() + self.up_upscalers = nn.ModuleList() + self.up_repeat_mappers = nn.ModuleList() + for i in reversed(range(len(c_hidden))): + if i > 0: + self.up_upscalers.append(nn.Sequential( + LayerNorm2d_op(operations)(c_hidden[i], elementwise_affine=False, eps=1e-6), + UpDownBlock2d(c_hidden[i], c_hidden[i - 1], mode='up', enabled=switch_level[i - 1], dtype=dtype, device=device, operations=operations) + )) + else: + self.up_upscalers.append(nn.Identity()) + up_block = nn.ModuleList() + for j in range(blocks[1][::-1][i]): + for k, block_type in enumerate(level_config[i]): + c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0 + block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i], + self_attn=self_attn[i]) + up_block.append(block) + self.up_blocks.append(up_block) + if block_repeat is not None: + block_repeat_mappers = nn.ModuleList() + for _ in range(block_repeat[1][::-1][i] - 1): + block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device)) + self.up_repeat_mappers.append(block_repeat_mappers) + + # OUTPUT + self.clf = nn.Sequential( + LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device), + operations.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1, dtype=dtype, device=device), + nn.PixelShuffle(patch_size), + ) + + # --- WEIGHT INIT --- + # self.apply(self._init_weights) # General init + # nn.init.normal_(self.clip_txt_mapper.weight, std=0.02) # conditionings + # nn.init.normal_(self.clip_txt_pooled_mapper.weight, std=0.02) # conditionings + # nn.init.normal_(self.clip_img_mapper.weight, std=0.02) # conditionings + # torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs + # nn.init.constant_(self.clf[1].weight, 0) # outputs + # + # # blocks + # for level_block in self.down_blocks + self.up_blocks: + # for block in level_block: + # if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock): + # block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0])) + # elif isinstance(block, TimestepBlock): + # for layer in block.modules(): + # if isinstance(layer, nn.Linear): + # nn.init.constant_(layer.weight, 0) + # + # def _init_weights(self, m): + # if isinstance(m, (nn.Conv2d, nn.Linear)): + # 
torch.nn.init.xavier_uniform_(m.weight) + # if m.bias is not None: + # nn.init.constant_(m.bias, 0) + + def gen_r_embedding(self, r, max_positions=10000): + r = r * max_positions + half_dim = self.c_r // 2 + emb = math.log(max_positions) / (half_dim - 1) + emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp() + emb = r[:, None] * emb[None, :] + emb = torch.cat([emb.sin(), emb.cos()], dim=1) + if self.c_r % 2 == 1: # zero pad + emb = nn.functional.pad(emb, (0, 1), mode='constant') + return emb + + def gen_c_embeddings(self, clip_txt, clip_txt_pooled, clip_img): + clip_txt = self.clip_txt_mapper(clip_txt) + if len(clip_txt_pooled.shape) == 2: + clip_txt_pooled = clip_txt_pooled.unsqueeze(1) + if len(clip_img.shape) == 2: + clip_img = clip_img.unsqueeze(1) + clip_txt_pool = self.clip_txt_pooled_mapper(clip_txt_pooled).view(clip_txt_pooled.size(0), clip_txt_pooled.size(1) * self.c_clip_seq, -1) + clip_img = self.clip_img_mapper(clip_img).view(clip_img.size(0), clip_img.size(1) * self.c_clip_seq, -1) + clip = torch.cat([clip_txt, clip_txt_pool, clip_img], dim=1) + clip = self.clip_norm(clip) + return clip + + def _down_encode(self, x, r_embed, clip, cnet=None): + level_outputs = [] + block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers) + for down_block, downscaler, repmap in block_group: + x = downscaler(x) + for i in range(len(repmap) + 1): + for block in down_block: + if isinstance(block, ResBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + ResBlock)): + if cnet is not None: + next_cnet = cnet.pop() + if next_cnet is not None: + x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode='bilinear', + align_corners=True).to(x.dtype) + x = block(x) + elif isinstance(block, AttnBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + AttnBlock)): + x = block(x, clip) + elif isinstance(block, TimestepBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + TimestepBlock)): + x = block(x, r_embed) + else: + x = block(x) + if i < len(repmap): + x = repmap[i](x) + level_outputs.insert(0, x) + return level_outputs + + def _up_decode(self, level_outputs, r_embed, clip, cnet=None): + x = level_outputs[0] + block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers) + for i, (up_block, upscaler, repmap) in enumerate(block_group): + for j in range(len(repmap) + 1): + for k, block in enumerate(up_block): + if isinstance(block, ResBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + ResBlock)): + skip = level_outputs[i] if k == 0 and i > 0 else None + if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)): + x = torch.nn.functional.interpolate(x, skip.shape[-2:], mode='bilinear', + align_corners=True) + if cnet is not None: + next_cnet = cnet.pop() + if next_cnet is not None: + x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode='bilinear', + align_corners=True).to(x.dtype) + x = block(x, skip) + elif isinstance(block, AttnBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + AttnBlock)): + x = block(x, clip) + elif isinstance(block, TimestepBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + TimestepBlock)): + x = block(x, r_embed) + else: + x = block(x) + if j < len(repmap): + x = repmap[j](x) + x = upscaler(x) + return x + 
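# Aside: gen_r_embedding above is a standard sinusoidal timestep embedding over a
# ratio r in [0, 1]. A minimal, self-contained sketch, assuming c_r = 64 (the StageC
# default); the name sinusoidal_r_embedding is illustrative, the math mirrors the method.
import math
import torch

def sinusoidal_r_embedding(r, c_r=64, max_positions=10000):
    r = r * max_positions                        # map ratio in [0, 1] onto [0, max_positions]
    half_dim = c_r // 2
    freq = math.log(max_positions) / (half_dim - 1)
    freq = torch.arange(half_dim, device=r.device).float().mul(-freq).exp()  # geometric frequency ladder
    emb = r[:, None] * freq[None, :]             # (batch, half_dim) phase matrix
    emb = torch.cat([emb.sin(), emb.cos()], dim=1)  # (batch, c_r)
    if c_r % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1))  # zero pad odd widths, as in the method
    return emb

# e.g. sinusoidal_r_embedding(torch.tensor([0.5])).shape -> torch.Size([1, 64])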
+ + def forward(self, x, r, clip_text, clip_text_pooled, clip_img, control=None, **kwargs): + # Process the conditioning embeddings + r_embed = self.gen_r_embedding(r).to(dtype=x.dtype) + for c in self.t_conds: + t_cond = kwargs.get(c, torch.zeros_like(r)) + r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond).to(dtype=x.dtype)], dim=1) + clip = self.gen_c_embeddings(clip_text, clip_text_pooled, clip_img) + + if control is not None: + cnet = control.get("input") + else: + cnet = None + + # Model Blocks + x = self.embedding(x) + level_outputs = self._down_encode(x, r_embed, clip, cnet) + x = self._up_decode(level_outputs, r_embed, clip, cnet) + return self.clf(x) + + def update_weights_ema(self, src_model, beta=0.999): + for self_params, src_params in zip(self.parameters(), src_model.parameters()): + self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta) + for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()): + self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta) diff --git a/ldm_patched/ldm/cascade/stage_c_coder.py b/ldm_patched/ldm/cascade/stage_c_coder.py new file mode 100644 index 000000000..0cb7c49fc --- /dev/null +++ b/ldm_patched/ldm/cascade/stage_c_coder.py @@ -0,0 +1,95 @@ +""" + This file is part of ComfyUI. + Copyright (C) 2024 Stability AI + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see <https://www.gnu.org/licenses/>. +""" +import torch +import torchvision +from torch import nn + + +# EfficientNet +class EfficientNetEncoder(nn.Module): + def __init__(self, c_latent=16): + super().__init__() + self.backbone = torchvision.models.efficientnet_v2_s().features.eval() + self.mapper = nn.Sequential( + nn.Conv2d(1280, c_latent, kernel_size=1, bias=False), + nn.BatchNorm2d(c_latent, affine=False), # then normalize them to have mean 0 and std 1 + ) + self.mean = nn.Parameter(torch.tensor([0.485, 0.456, 0.406])) + self.std = nn.Parameter(torch.tensor([0.229, 0.224, 0.225])) + + def forward(self, x): + x = x * 0.5 + 0.5 + x = (x - self.mean.view([3,1,1])) / self.std.view([3,1,1]) + o = self.mapper(self.backbone(x)) + return o + + +# Fast Decoder for Stage C latents. E.g. 
16 x 24 x 24 -> 3 x 192 x 192 +class Previewer(nn.Module): + def __init__(self, c_in=16, c_hidden=512, c_out=3): + super().__init__() + self.blocks = nn.Sequential( + nn.Conv2d(c_in, c_hidden, kernel_size=1), # 16 channels to 512 channels + nn.GELU(), + nn.BatchNorm2d(c_hidden), + + nn.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1), + nn.GELU(), + nn.BatchNorm2d(c_hidden), + + nn.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2), # 16 -> 32 + nn.GELU(), + nn.BatchNorm2d(c_hidden // 2), + + nn.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1), + nn.GELU(), + nn.BatchNorm2d(c_hidden // 2), + + nn.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2), # 32 -> 64 + nn.GELU(), + nn.BatchNorm2d(c_hidden // 4), + + nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1), + nn.GELU(), + nn.BatchNorm2d(c_hidden // 4), + + nn.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2), # 64 -> 128 + nn.GELU(), + nn.BatchNorm2d(c_hidden // 4), + + nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1), + nn.GELU(), + nn.BatchNorm2d(c_hidden // 4), + + nn.Conv2d(c_hidden // 4, c_out, kernel_size=1), + ) + + def forward(self, x): + return (self.blocks(x) - 0.5) * 2.0 + +class StageC_coder(nn.Module): + def __init__(self): + super().__init__() + self.previewer = Previewer() + self.encoder = EfficientNetEncoder() + + def encode(self, x): + return self.encoder(x) + + def decode(self, x): + return self.previewer(x) diff --git a/ldm_patched/ldm/models/autoencoder.py b/ldm_patched/ldm/models/autoencoder.py index c809a0c31..d3b16e773 100644 --- a/ldm_patched/ldm/models/autoencoder.py +++ b/ldm_patched/ldm/models/autoencoder.py @@ -1,6 +1,4 @@ import torch -# import pytorch_lightning as pl -import torch.nn.functional as F from contextlib import contextmanager from typing import Any, Dict, List, Optional, Tuple, Union diff --git a/ldm_patched/ldm/modules/attention.py b/ldm_patched/ldm/modules/attention.py index e10a868d2..f17539b62 100644 --- a/ldm_patched/ldm/modules/attention.py +++ b/ldm_patched/ldm/modules/attention.py @@ -3,9 +3,10 @@ import torch.nn.functional as F from torch import nn, einsum from einops import rearrange, repeat -from typing import Optional, Any +from typing import Optional +import logging -from .diffusionmodules.util import checkpoint, AlphaBlender, timestep_embedding +from .diffusionmodules.util import AlphaBlender, timestep_embedding from .sub_quadratic_attention import efficient_dot_product_attention from ldm_patched.modules import model_management @@ -18,13 +19,14 @@ import ldm_patched.modules.ops ops = ldm_patched.modules.ops.disable_weight_init -# CrossAttn precision handling -if args.disable_attention_upcast: - print("disabling upcasting of attention") - _ATTN_PRECISION = "fp16" -else: - _ATTN_PRECISION = "fp32" +FORCE_UPCAST_ATTENTION_DTYPE = model_management.force_upcast_attention_dtype() +def get_attn_precision(attn_precision): + if args.dont_upcast_attention: + return None + if FORCE_UPCAST_ATTENTION_DTYPE is not None: + return FORCE_UPCAST_ATTENTION_DTYPE + return attn_precision def exists(val): return val is not None @@ -84,7 +86,9 @@ def forward(self, x): def Normalize(in_channels, dtype=None, device=None): return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device) -def attention_basic(q, k, v, heads, mask=None): +def attention_basic(q, k, v, heads, mask=None, attn_precision=None): + attn_precision = 
get_attn_precision(attn_precision) + b, _, dim_head = q.shape dim_head //= heads scale = dim_head ** -0.5 @@ -100,7 +104,7 @@ def attention_basic(q, k, v, heads, mask=None): ) # force cast to fp32 to avoid overflowing - if _ATTN_PRECISION =="fp32": + if attn_precision == torch.float32: sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale else: sim = einsum('b i d, b j d -> b i j', q, k) * scale @@ -114,7 +118,12 @@ def attention_basic(q, k, v, heads, mask=None): mask = repeat(mask, 'b j -> (b h) () j', h=h) sim.masked_fill_(~mask, max_neg_value) else: - sim += mask + if len(mask.shape) == 2: + bs = 1 + else: + bs = mask.shape[0] + mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1]) + sim.add_(mask) # attention, what we cannot get enough of sim = sim.softmax(dim=-1) @@ -129,7 +138,9 @@ def attention_basic(q, k, v, heads, mask=None): return out -def attention_sub_quad(query, key, value, heads, mask=None): +def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None): + attn_precision = get_attn_precision(attn_precision) + b, _, dim_head = query.shape dim_head //= heads @@ -140,7 +151,7 @@ def attention_sub_quad(query, key, value, heads, mask=None): key = key.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 3, 1).reshape(b * heads, dim_head, -1) dtype = query.dtype - upcast_attention = _ATTN_PRECISION =="fp32" and query.dtype != torch.float32 + upcast_attention = attn_precision == torch.float32 and query.dtype != torch.float32 if upcast_attention: bytes_per_token = torch.finfo(torch.float32).bits//8 else: @@ -165,6 +176,13 @@ def attention_sub_quad(query, key, value, heads, mask=None): if query_chunk_size is None: query_chunk_size = 512 + if mask is not None: + if len(mask.shape) == 2: + bs = 1 + else: + bs = mask.shape[0] + mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1]) + hidden_states = efficient_dot_product_attention( query, key, @@ -182,7 +200,9 @@ def attention_sub_quad(query, key, value, heads, mask=None): hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2) return hidden_states -def attention_split(q, k, v, heads, mask=None): +def attention_split(q, k, v, heads, mask=None, attn_precision=None): + attn_precision = get_attn_precision(attn_precision) + b, _, dim_head = q.shape dim_head //= heads scale = dim_head ** -0.5 @@ -201,10 +221,12 @@ def attention_split(q, k, v, heads, mask=None): mem_free_total = model_management.get_free_memory(q.device) - if _ATTN_PRECISION =="fp32": + if attn_precision == torch.float32: element_size = 4 + upcast = True else: element_size = q.element_size() + upcast = False gb = 1024 ** 3 tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * element_size @@ -223,6 +245,13 @@ def attention_split(q, k, v, heads, mask=None): raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). 
' f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free') + if mask is not None: + if len(mask.shape) == 2: + bs = 1 + else: + bs = mask.shape[0] + mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1]) + # print("steps", steps, mem_required, mem_free_total, modifier, q.element_size(), tensor_size) first_op_done = False cleared_cache = False @@ -231,7 +260,7 @@ def attention_split(q, k, v, heads, mask=None): slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] for i in range(0, q.shape[1], slice_size): end = i + slice_size - if _ATTN_PRECISION =="fp32": + if upcast: with torch.autocast(enabled=False, device_type = 'cuda'): s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * scale else: @@ -255,12 +284,12 @@ def attention_split(q, k, v, heads, mask=None): model_management.soft_empty_cache(True) if cleared_cache == False: cleared_cache = True - print("out of memory error, emptying cache and trying again") + logging.warning("out of memory error, emptying cache and trying again") continue steps *= 2 if steps > 64: raise e - print("out of memory error, increasing steps and trying again", steps) + logging.warning("out of memory error, increasing steps and trying again {}".format(steps)) else: raise e @@ -277,24 +306,30 @@ def attention_split(q, k, v, heads, mask=None): BROKEN_XFORMERS = False try: x_vers = xformers.__version__ - #I think 0.0.23 is also broken (q with bs bigger than 65535 gives CUDA error) - BROKEN_XFORMERS = x_vers.startswith("0.0.21") or x_vers.startswith("0.0.22") or x_vers.startswith("0.0.23") + # XFormers bug confirmed on all versions from 0.0.21 to 0.0.26 (q with bs bigger than 65535 gives CUDA error) + BROKEN_XFORMERS = x_vers.startswith("0.0.2") and not x_vers.startswith("0.0.20") except: pass -def attention_xformers(q, k, v, heads, mask=None): +def attention_xformers(q, k, v, heads, mask=None, attn_precision=None): b, _, dim_head = q.shape dim_head //= heads + + disabled_xformers = False + if BROKEN_XFORMERS: if b * heads > 65535: - return attention_pytorch(q, k, v, heads, mask) + disabled_xformers = True + + if not disabled_xformers: + if torch.jit.is_tracing() or torch.jit.is_scripting(): + disabled_xformers = True + + if disabled_xformers: + return attention_pytorch(q, k, v, heads, mask) q, k, v = map( - lambda t: t.unsqueeze(3) - .reshape(b, -1, heads, dim_head) - .permute(0, 2, 1, 3) - .reshape(b * heads, -1, dim_head) - .contiguous(), + lambda t: t.reshape(b, -1, heads, dim_head), (q, k, v), ) @@ -307,14 +342,11 @@ def attention_xformers(q, k, v, heads, mask=None): out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask) out = ( - out.unsqueeze(0) - .reshape(b, heads, -1, dim_head) - .permute(0, 2, 1, 3) - .reshape(b, -1, heads * dim_head) + out.reshape(b, -1, heads * dim_head) ) return out -def attention_pytorch(q, k, v, heads, mask=None): +def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None): b, _, dim_head = q.shape dim_head //= heads q, k, v = map( @@ -332,17 +364,17 @@ def attention_pytorch(q, k, v, heads, mask=None): optimized_attention = attention_basic if model_management.xformers_enabled(): - print("Using xformers cross attention") + logging.info("Using xformers cross attention") optimized_attention = attention_xformers elif model_management.pytorch_attention_enabled(): - print("Using pytorch cross attention") + logging.info("Using pytorch cross attention") optimized_attention = 
attention_pytorch else: - if args.attention_split: - print("Using split optimization for cross attention") + if args.use_split_cross_attention: + logging.info("Using split optimization for cross attention") optimized_attention = attention_split else: - print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --attention-split") + logging.info("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention") optimized_attention = attention_sub_quad optimized_attention_masked = optimized_attention @@ -364,10 +396,11 @@ def optimized_attention_for_device(device, mask=False, small_input=False): class CrossAttention(nn.Module): - def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=ops): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=ops): super().__init__() inner_dim = dim_head * heads context_dim = default(context_dim, query_dim) + self.attn_precision = attn_precision self.heads = heads self.dim_head = dim_head @@ -389,15 +422,15 @@ def forward(self, x, context=None, value=None, mask=None): v = self.to_v(context) if mask is None: - out = optimized_attention(q, k, v, self.heads) + out = optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision) else: - out = optimized_attention_masked(q, k, v, self.heads, mask) + out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision) return self.to_out(out) class BasicTransformerBlock(nn.Module): def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, ff_in=False, inner_dim=None, - disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, dtype=None, device=None, operations=ops): + disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, attn_precision=None, dtype=None, device=None, operations=ops): super().__init__() self.ff_in = ff_in or inner_dim is not None @@ -405,6 +438,7 @@ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff= inner_dim = dim self.is_res = inner_dim == dim + self.attn_precision = attn_precision if self.ff_in: self.norm_in = operations.LayerNorm(dim, dtype=dtype, device=device) @@ -412,7 +446,7 @@ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff= self.disable_self_attn = disable_self_attn self.attn1 = CrossAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout, - context_dim=context_dim if self.disable_self_attn else None, dtype=dtype, device=device, operations=operations) # is a self-attention if not self.disable_self_attn + context_dim=context_dim if self.disable_self_attn else None, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations) # is a self-attention if not self.disable_self_attn self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations) if disable_temporal_crossattention: @@ -426,20 +460,16 @@ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff= context_dim_attn2 = context_dim self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2, - heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations) # is self-attn 
if context is none + heads=n_heads, dim_head=d_head, dropout=dropout, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations) # is self-attn if context is none self.norm2 = operations.LayerNorm(inner_dim, dtype=dtype, device=device) self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device) self.norm3 = operations.LayerNorm(inner_dim, dtype=dtype, device=device) - self.checkpoint = checkpoint self.n_heads = n_heads self.d_head = d_head self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa def forward(self, x, context=None, transformer_options={}): - return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint) - - def _forward(self, x, context=None, transformer_options={}): extra_options = {} block = transformer_options.get("block", None) block_index = transformer_options.get("block_index", 0) @@ -456,6 +486,7 @@ def _forward(self, x, context=None, transformer_options={}): extra_options["n_heads"] = self.n_heads extra_options["dim_head"] = self.d_head + extra_options["attn_precision"] = self.attn_precision if self.ff_in: x_skip = x @@ -566,7 +597,7 @@ class SpatialTransformer(nn.Module): def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0., context_dim=None, disable_self_attn=False, use_linear=False, - use_checkpoint=True, dtype=None, device=None, operations=ops): + use_checkpoint=True, attn_precision=None, dtype=None, device=None, operations=ops): super().__init__() if exists(context_dim) and not isinstance(context_dim, list): context_dim = [context_dim] * depth @@ -584,7 +615,7 @@ def __init__(self, in_channels, n_heads, d_head, self.transformer_blocks = nn.ModuleList( [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], - disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, dtype=dtype, device=device, operations=operations) + disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations) for d in range(depth)] ) if not use_linear: @@ -605,7 +636,7 @@ def forward(self, x, context=None, transformer_options={}): x = self.norm(x) if not self.use_linear: x = self.proj_in(x) - x = rearrange(x, 'b c h w -> b (h w) c').contiguous() + x = x.movedim(1, 3).flatten(1, 2).contiguous() if self.use_linear: x = self.proj_in(x) for i, block in enumerate(self.transformer_blocks): @@ -613,7 +644,7 @@ def forward(self, x, context=None, transformer_options={}): x = block(x, context=context[i], transformer_options=transformer_options) if self.use_linear: x = self.proj_out(x) - x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() + x = x.reshape(x.shape[0], h, w, x.shape[-1]).movedim(3, 1).contiguous() if not self.use_linear: x = self.proj_out(x) return x + x_in @@ -640,6 +671,7 @@ def __init__( disable_self_attn=False, disable_temporal_crossattention=False, max_time_embed_period: int = 10000, + attn_precision=None, dtype=None, device=None, operations=ops ): super().__init__( @@ -652,6 +684,7 @@ def __init__( context_dim=context_dim, use_linear=use_linear, disable_self_attn=disable_self_attn, + attn_precision=attn_precision, dtype=dtype, device=device, operations=operations ) self.time_depth = time_depth @@ -681,6 +714,7 @@ def __init__( inner_dim=time_mix_inner_dim, disable_self_attn=disable_self_attn, disable_temporal_crossattention=disable_temporal_crossattention, + attn_precision=attn_precision, dtype=dtype, device=device, operations=operations ) for _ in 
range(self.depth) diff --git a/ldm_patched/ldm/modules/diffusionmodules/model.py b/ldm_patched/ldm/modules/diffusionmodules/model.py index 1901145c5..9b2553234 100644 --- a/ldm_patched/ldm/modules/diffusionmodules/model.py +++ b/ldm_patched/ldm/modules/diffusionmodules/model.py @@ -3,8 +3,8 @@ import torch import torch.nn as nn import numpy as np -from einops import rearrange from typing import Optional, Any +import logging from ldm_patched.modules import model_management import ldm_patched.modules.ops @@ -190,7 +190,7 @@ def slice_attention(q, k, v): steps *= 2 if steps > 128: raise e - print("out of memory error, increasing steps and trying again", steps) + logging.warning("out of memory error, increasing steps and trying again {}".format(steps)) return r1 @@ -235,7 +235,7 @@ def pytorch_attention(q, k, v): out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) out = out.transpose(2, 3).reshape(B, C, H, W) except model_management.OOM_EXCEPTION as e: - print("scaled_dot_product_attention OOMed: switched to slice attention") + logging.warning("scaled_dot_product_attention OOMed: switched to slice attention") out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W) return out @@ -268,13 +268,13 @@ def __init__(self, in_channels): padding=0) if model_management.xformers_enabled_vae(): - print("Using xformers attention in VAE") + logging.info("Using xformers attention in VAE") self.optimized_attention = xformers_attention elif model_management.pytorch_attention_enabled(): - print("Using pytorch attention in VAE") + logging.info("Using pytorch attention in VAE") self.optimized_attention = pytorch_attention else: - print("Using split attention in VAE") + logging.info("Using split attention in VAE") self.optimized_attention = normal_attention def forward(self, x): @@ -562,7 +562,7 @@ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, block_in = ch*ch_mult[self.num_resolutions-1] curr_res = resolution // 2**(self.num_resolutions-1) self.z_shape = (1,z_channels,curr_res,curr_res) - print("Working with z of shape {} = {} dimensions.".format( + logging.debug("Working with z of shape {} = {} dimensions.".format( self.z_shape, np.prod(self.z_shape))) # z to block_in diff --git a/ldm_patched/ldm/modules/diffusionmodules/openaimodel.py b/ldm_patched/ldm/modules/diffusionmodules/openaimodel.py index 4b695f76a..94c5da9a3 100644 --- a/ldm_patched/ldm/modules/diffusionmodules/openaimodel.py +++ b/ldm_patched/ldm/modules/diffusionmodules/openaimodel.py @@ -4,6 +4,7 @@ import torch.nn as nn import torch.nn.functional as F from einops import rearrange +import logging from .util import ( checkpoint, @@ -257,7 +258,7 @@ def _forward(self, x, emb): else: if emb_out is not None: if self.exchange_temb_dims: - emb_out = rearrange(emb_out, "b t c ... 
-> b c t ...") + emb_out = emb_out.movedim(1, 2) h = h + emb_out h = self.out_layers(h) return self.skip_connection(x) + h @@ -359,7 +360,7 @@ def apply_control(h, control, name): try: h += ctrl except: - print("warning control could not be applied", h.shape, ctrl.shape) + logging.warning("warning control could not be applied {} {}".format(h.shape, ctrl.shape)) return h class UNetModel(nn.Module): @@ -430,6 +431,7 @@ def __init__( video_kernel_size=None, disable_temporal_crossattention=False, max_ddpm_temb_period=10000, + attn_precision=None, device=None, operations=ops, ): @@ -484,7 +486,6 @@ def __init__( self.predict_codebook_ids = n_embed is not None self.default_num_video_frames = None - self.default_image_only_indicator = None time_embed_dim = model_channels * 4 self.time_embed = nn.Sequential( @@ -497,7 +498,7 @@ def __init__( if isinstance(self.num_classes, int): self.label_emb = nn.Embedding(num_classes, time_embed_dim, dtype=self.dtype, device=device) elif self.num_classes == "continuous": - print("setting up linear c_adm embedding layer") + logging.debug("setting up linear c_adm embedding layer") self.label_emb = nn.Linear(1, time_embed_dim) elif self.num_classes == "sequential": assert adm_in_channels is not None @@ -550,13 +551,14 @@ def get_attention_layer( disable_self_attn=disable_self_attn, disable_temporal_crossattention=disable_temporal_crossattention, max_time_embed_period=max_ddpm_temb_period, + attn_precision=attn_precision, dtype=self.dtype, device=device, operations=operations ) else: return SpatialTransformer( ch, num_heads, dim_head, depth=depth, context_dim=context_dim, disable_self_attn=disable_self_attn, use_linear=use_linear_in_transformer, - use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations + use_checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=self.dtype, device=device, operations=operations ) def get_resblock( @@ -708,27 +710,30 @@ def get_resblock( device=device, operations=operations )] - if transformer_depth_middle >= 0: - mid_block += [get_attention_layer( # always uses a self-attn - ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim, - disable_self_attn=disable_middle_self_attn, use_checkpoint=use_checkpoint - ), - get_resblock( - merge_factor=merge_factor, - merge_strategy=merge_strategy, - video_kernel_size=video_kernel_size, - ch=ch, - time_embed_dim=time_embed_dim, - dropout=dropout, - out_channels=None, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - dtype=self.dtype, - device=device, - operations=operations - )] - self.middle_block = TimestepEmbedSequential(*mid_block) + + self.middle_block = None + if transformer_depth_middle >= -1: + if transformer_depth_middle >= 0: + mid_block += [get_attention_layer( # always uses a self-attn + ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim, + disable_self_attn=disable_middle_self_attn, use_checkpoint=use_checkpoint + ), + get_resblock( + merge_factor=merge_factor, + merge_strategy=merge_strategy, + video_kernel_size=video_kernel_size, + ch=ch, + time_embed_dim=time_embed_dim, + dropout=dropout, + out_channels=None, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + dtype=self.dtype, + device=device, + operations=operations + )] + self.middle_block = TimestepEmbedSequential(*mid_block) self._feature_size += ch self.output_blocks = nn.ModuleList([]) @@ -827,7 +832,7 @@ def forward(self, x, timesteps=None, 
context=None, y=None, control=None, transfo transformer_patches = transformer_options.get("patches", {}) num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames) - image_only_indicator = kwargs.get("image_only_indicator", self.default_image_only_indicator) + image_only_indicator = kwargs.get("image_only_indicator", None) time_context = kwargs.get("time_context", None) assert (y is not None) == ( @@ -858,7 +863,8 @@ def forward(self, x, timesteps=None, context=None, y=None, control=None, transfo h = p(h, transformer_options) transformer_options["block"] = ("middle", 0) - h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator) + if self.middle_block is not None: + h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator) h = apply_control(h, control, 'middle') diff --git a/ldm_patched/ldm/modules/diffusionmodules/util.py b/ldm_patched/ldm/modules/diffusionmodules/util.py index e261e06a3..5c96a46f7 100644 --- a/ldm_patched/ldm/modules/diffusionmodules/util.py +++ b/ldm_patched/ldm/modules/diffusionmodules/util.py @@ -46,23 +46,25 @@ def __init__( else: raise ValueError(f"unknown merge strategy {self.merge_strategy}") - def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor: + def get_alpha(self, image_only_indicator: torch.Tensor, device) -> torch.Tensor: # skip_time_mix = rearrange(repeat(skip_time_mix, 'b -> (b t) () () ()', t=t), '(b t) 1 ... -> b 1 t ...', t=t) if self.merge_strategy == "fixed": # make shape compatible # alpha = repeat(self.mix_factor, '1 -> b () t () ()', t=t, b=bs) - alpha = self.mix_factor.to(image_only_indicator.device) + alpha = self.mix_factor.to(device) elif self.merge_strategy == "learned": - alpha = torch.sigmoid(self.mix_factor.to(image_only_indicator.device)) + alpha = torch.sigmoid(self.mix_factor.to(device)) # make shape compatible # alpha = repeat(alpha, '1 -> s () ()', s = t * bs) elif self.merge_strategy == "learned_with_images": - assert image_only_indicator is not None, "need image_only_indicator ..." - alpha = torch.where( - image_only_indicator.bool(), - torch.ones(1, 1, device=image_only_indicator.device), - rearrange(torch.sigmoid(self.mix_factor.to(image_only_indicator.device)), "... -> ... 1"), - ) + if image_only_indicator is None: + alpha = rearrange(torch.sigmoid(self.mix_factor.to(device)), "... -> ... 1") + else: + alpha = torch.where( + image_only_indicator.bool(), + torch.ones(1, 1, device=image_only_indicator.device), + rearrange(torch.sigmoid(self.mix_factor.to(image_only_indicator.device)), "... -> ... 
1"), + ) alpha = rearrange(alpha, self.rearrange_pattern) # make shape compatible # alpha = repeat(alpha, '1 -> s () ()', s = t * bs) @@ -76,7 +78,7 @@ def forward( x_temporal, image_only_indicator=None, ) -> torch.Tensor: - alpha = self.get_alpha(image_only_indicator) + alpha = self.get_alpha(image_only_indicator, x_spatial.device) x = ( alpha.to(x_spatial.dtype) * x_spatial + (1.0 - alpha).to(x_spatial.dtype) * x_temporal @@ -98,7 +100,7 @@ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, alphas = torch.cos(alphas).pow(2) alphas = alphas / alphas[0] betas = 1 - alphas[1:] / alphas[:-1] - betas = np.clip(betas, a_min=0, a_max=0.999) + betas = torch.clamp(betas, min=0, max=0.999) elif schedule == "squaredcos_cap_v2": # used for karlo prior # return early @@ -113,7 +115,7 @@ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 else: raise ValueError(f"schedule '{schedule}' unknown.") - return betas.numpy() + return betas def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): diff --git a/ldm_patched/ldm/modules/sub_quadratic_attention.py b/ldm_patched/ldm/modules/sub_quadratic_attention.py index 9f4c23c7e..732c5e7b5 100644 --- a/ldm_patched/ldm/modules/sub_quadratic_attention.py +++ b/ldm_patched/ldm/modules/sub_quadratic_attention.py @@ -14,6 +14,7 @@ from torch import Tensor from torch.utils.checkpoint import checkpoint import math +import logging try: from typing import Optional, NamedTuple, List, Protocol @@ -170,7 +171,7 @@ def _get_attention_scores_no_kv_chunking( attn_probs = attn_scores.softmax(dim=-1) del attn_scores except model_management.OOM_EXCEPTION: - print("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead") + logging.warning("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead") attn_scores -= attn_scores.max(dim=-1, keepdim=True).values torch.exp(attn_scores, out=attn_scores) summed = torch.sum(attn_scores, dim=-1, keepdim=True) diff --git a/ldm_patched/modules/args_parser.py b/ldm_patched/modules/args_parser.py index bf8737835..917c9ce49 100644 --- a/ldm_patched/modules/args_parser.py +++ b/ldm_patched/modules/args_parser.py @@ -33,94 +33,105 @@ def __call__(self, parser, namespace, values, option_string=None): parser = argparse.ArgumentParser() -parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0") -parser.add_argument("--port", type=int, default=8188) -parser.add_argument("--disable-header-check", type=str, default=None, metavar="ORIGIN", nargs="?", const="*") -parser.add_argument("--web-upload-size", type=float, default=100) -parser.add_argument("--hf-mirror", type=str, default=None) - -parser.add_argument("--external-working-path", type=str, default=None, metavar="PATH", nargs='+', action='append') -parser.add_argument("--output-path", type=str, default=None) -parser.add_argument("--temp-path", type=str, default=None) -parser.add_argument("--cache-path", type=str, default=None) -parser.add_argument("--in-browser", action="store_true") -parser.add_argument("--disable-in-browser", action="store_true") -parser.add_argument("--gpu-device-id", type=int, default=None, metavar="DEVICE_ID") -cm_group = parser.add_mutually_exclusive_group() -cm_group.add_argument("--async-cuda-allocation", 
action="store_true") -cm_group.add_argument("--disable-async-cuda-allocation", action="store_true") - -parser.add_argument("--disable-attention-upcast", action="store_true") +parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)") +parser.add_argument("--port", type=int, default=8188, help="Set the listen port.") +parser.add_argument("--tls-keyfile", type=str, help="Path to TLS (SSL) key file. Enables TLS, makes app accessible at https://... requires --tls-certfile to function") +parser.add_argument("--tls-keyfile-password", type=str, help="Password for the TLS (SSL) key file. Requires --tls-keyfile to function") +parser.add_argument("--tls-certfile", type=str, help="Path to TLS (SSL) certificate file. Enables TLS, makes app accessible at https://... requires --tls-keyfile to function") +parser.add_argument("--tls-verify", type=bool, default=True, help="If False, skips certificate validation which allows self-signed certificates to be used. Requires --tls-keyfile to function") +parser.add_argument("--hf-mirror", type=str, default=None, help="Use HuggingFace mirror URL") + +parser.add_argument("--output-path", type=str, default=None, help="Set the output directory.") +parser.add_argument("--temp-path", type=str, default=None, help="Set the temp directory.") +parser.add_argument("--in-browser", action="store_true", help="Automatically launch in the default browser.") +parser.add_argument("--disable-in-browser", action="store_true", help="Disable auto launching the browser.") +parser.add_argument("--gpu-device-id", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.") fp_group = parser.add_mutually_exclusive_group() -fp_group.add_argument("--all-in-fp32", action="store_true") -fp_group.add_argument("--all-in-fp16", action="store_true") +fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).") +fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.") fpunet_group = parser.add_mutually_exclusive_group() -fpunet_group.add_argument("--unet-in-bf16", action="store_true") -fpunet_group.add_argument("--unet-in-fp16", action="store_true") -fpunet_group.add_argument("--unet-in-fp8-e4m3fn", action="store_true") -fpunet_group.add_argument("--unet-in-fp8-e5m2", action="store_true") +fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. 
This should only be used for testing stuff.") +fpunet_group.add_argument("--fp16-unet", action="store_true", help="Store unet weights in fp16.") +fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.") +fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.") fpvae_group = parser.add_mutually_exclusive_group() -fpvae_group.add_argument("--vae-in-fp16", action="store_true") -fpvae_group.add_argument("--vae-in-fp32", action="store_true") -fpvae_group.add_argument("--vae-in-bf16", action="store_true") +fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.") +fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.") +fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16.") -parser.add_argument("--vae-in-cpu", action="store_true") +parser.add_argument("--cpu-vae", action="store_true", help="Run the VAE on the CPU.") fpte_group = parser.add_mutually_exclusive_group() -fpte_group.add_argument("--clip-in-fp8-e4m3fn", action="store_true") -fpte_group.add_argument("--clip-in-fp8-e5m2", action="store_true") -fpte_group.add_argument("--clip-in-fp16", action="store_true") -fpte_group.add_argument("--clip-in-fp32", action="store_true") +fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Store text encoder weights in fp8 (e4m3fn variant).") +fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store text encoder weights in fp8 (e5m2 variant).") +fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text encoder weights in fp16.") +fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.") -parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1) +parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.") -parser.add_argument("--disable-ipex-hijack", action="store_true") +parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize when loading models with Intel GPUs.") class LatentPreviewMethod(enum.Enum): NoPreviews = "none" Auto = "auto" - Latent2RGB = "fast" + Latent2RGB = "latent2rgb" TAESD = "taesd" -parser.add_argument("--preview-option", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, action=EnumAction) +parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction) attn_group = parser.add_mutually_exclusive_group() -attn_group.add_argument("--attention-split", action="store_true") -attn_group.add_argument("--attention-quad", action="store_true") -attn_group.add_argument("--attention-pytorch", action="store_true") +attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.") +attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization. 
Ignored when xformers is used.") +attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.") + +parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.") + +upcast = parser.add_mutually_exclusive_group() +upcast.add_argument("--force-upcast-attention", action="store_true", help="Force enable attention upcasting, please report if it fixes black images.") +upcast.add_argument("--dont-upcast-attention", action="store_true", help="Disable all upcasting of attention. Should be unnecessary except for debugging.") -parser.add_argument("--disable-xformers", action="store_true") vram_group = parser.add_mutually_exclusive_group() -vram_group.add_argument("--always-gpu", action="store_true") -vram_group.add_argument("--always-high-vram", action="store_true") -vram_group.add_argument("--always-normal-vram", action="store_true") -vram_group.add_argument("--always-low-vram", action="store_true") -vram_group.add_argument("--always-no-vram", action="store_true") -vram_group.add_argument("--always-cpu", type=int, nargs="?", metavar="CPU_NUM_THREADS", const=-1) +vram_group.add_argument("--gpu-only", action="store_true", help="Store and run everything (text encoders/CLIP models, etc.) on the GPU.") +vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.") +vram_group.add_argument("--normalvram", action="store_true", help="Used to force normal vram use if lowvram gets automatically enabled.") +vram_group.add_argument("--lowvram", action="store_true", help="Split the unet in parts to use less vram.") +vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.") +vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).") + + +parser.add_argument("--disable-smart-memory", action="store_true", help="Aggressively offload models to regular ram instead of keeping them in vram when possible.") +parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. 
Note that this might not make images deterministic in all cases.") -parser.add_argument("--always-offload-from-vram", action="store_true") -parser.add_argument("--pytorch-deterministic", action="store_true") +parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.") +parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.") +parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).") -parser.add_argument("--disable-server-log", action="store_true") -parser.add_argument("--debug-mode", action="store_true") -parser.add_argument("--is-windows-embedded-python", action="store_true") +parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.") -parser.add_argument("--disable-server-info", action="store_true") +parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.") + +parser.add_argument("--verbose", action="store_true", help="Enables more debug prints.") -parser.add_argument("--multi-user", action="store_true") if ldm_patched.modules.options.args_parsing: - args = parser.parse_args([]) + args = parser.parse_args() else: args = parser.parse_args([]) -if args.is_windows_embedded_python: +if args.windows_standalone_build: args.in_browser = True if args.disable_in_browser: args.in_browser = False + +import logging +logging_level = logging.WARNING +if args.verbose: + logging_level = logging.DEBUG + +logging.basicConfig(format="%(message)s", level=logging_level) diff --git a/ldm_patched/modules/clip_model.py b/ldm_patched/modules/clip_model.py index aceca86d6..87e9ed90b 100644 --- a/ldm_patched/modules/clip_model.py +++ b/ldm_patched/modules/clip_model.py @@ -97,7 +97,7 @@ def forward(self, input_tokens, attention_mask=None, intermediate_output=None, f x = self.embeddings(input_tokens) mask = None if attention_mask is not None: - mask = 1.0 - attention_mask.to(x.dtype).unsqueeze(1).unsqueeze(1).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]) + mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]) mask = mask.masked_fill(mask.to(torch.bool), float("-inf")) causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1) @@ -119,6 +119,9 @@ def __init__(self, config_dict, dtype, device, operations): super().__init__() self.num_layers = config_dict["num_hidden_layers"] self.text_model = CLIPTextModel_(config_dict, dtype, device, operations) + embed_dim = config_dict["hidden_size"] + self.text_projection = operations.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device) + self.text_projection.weight.copy_(torch.eye(embed_dim)) self.dtype = dtype def get_input_embeddings(self): @@ -128,7 +131,10 @@ def set_input_embeddings(self, embeddings): self.text_model.embeddings.token_embedding = embeddings def forward(self, *args, **kwargs): - return self.text_model(*args, **kwargs) + x = self.text_model(*args, **kwargs) + out = self.text_projection(x[2]) + return (x[0], x[1], out, x[2]) + class CLIPVisionEmbeddings(torch.nn.Module): def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, dtype=None, device=None, 
operations=None): diff --git a/ldm_patched/modules/clip_vision.py b/ldm_patched/modules/clip_vision.py index affdb8b24..ab506285b 100644 --- a/ldm_patched/modules/clip_vision.py +++ b/ldm_patched/modules/clip_vision.py @@ -2,6 +2,7 @@ import os import torch import json +import logging import ldm_patched.modules.ops import ldm_patched.modules.model_patcher @@ -55,7 +56,7 @@ def encode_image(self, image): outputs = Output() outputs["last_hidden_state"] = out[0].to(ldm_patched.modules.model_management.intermediate_device()) outputs["image_embeds"] = out[2].to(ldm_patched.modules.model_management.intermediate_device()) outputs["penultimate_hidden_states"] = out[1].to(ldm_patched.modules.model_management.intermediate_device()) return outputs def convert_to_transformers(sd, prefix): @@ -99,7 +100,7 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False): clip = ClipVisionModel(json_config) m, u = clip.load_sd(sd) if len(m) > 0: - print("extra clip vision:", m) + logging.warning("missing clip vision: {}".format(m)) u = set(u) keys = list(sd.keys()) for k in keys: diff --git a/ldm_patched/modules/conds.py b/ldm_patched/modules/conds.py index 0ee184bc8..ed03bd64b 100644 --- a/ldm_patched/modules/conds.py +++ b/ldm_patched/modules/conds.py @@ -3,6 +3,8 @@ import ldm_patched.modules.utils +def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9) + return abs(a*b) // math.gcd(a, b) class CONDRegular: def __init__(self, cond): @@ -39,7 +41,7 @@ def can_concat(self, other): if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen return False - mult_min = math.lcm(s1[1], s2[1]) + mult_min = lcm(s1[1], s2[1]) diff = mult_min // min(s1[1], s2[1]) if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much return False @@ -50,7 +52,7 @@ def concat(self, others): crossattn_max_len = self.cond.shape[1] for x in others: c = x.cond - crossattn_max_len = math.lcm(crossattn_max_len, c.shape[1]) + crossattn_max_len = lcm(crossattn_max_len, c.shape[1]) conds.append(c) out = [] diff --git a/ldm_patched/modules/controlnet.py b/ldm_patched/modules/controlnet.py index 7e11497fe..473dd96c1 100644 --- a/ldm_patched/modules/controlnet.py +++ b/ldm_patched/modules/controlnet.py @@ -1,6 +1,7 @@ import torch import math import os +import logging import ldm_patched.modules.utils import ldm_patched.modules.model_management import ldm_patched.modules.model_detection @@ -9,6 +10,7 @@ import ldm_patched.controlnet.cldm import ldm_patched.t2ia.adapter +import ldm_patched.ldm.cascade.controlnet def broadcast_image_to(tensor, target_batch_size, batched_number): @@ -37,6 +39,8 @@ def __init__(self, device=None): self.timestep_percent_range = (0.0, 1.0) self.global_average_pooling = False self.timestep_range = None + self.compression_ratio = 8 + self.upscale_algorithm = 'nearest-exact' if device is None: device = ldm_patched.modules.model_management.get_torch_device() @@ -77,6 +81,8 @@ def copy_to(self, c): c.strength = self.strength c.timestep_percent_range = self.timestep_percent_range c.global_average_pooling = self.global_average_pooling + c.compression_ratio = self.compression_ratio + c.upscale_algorithm = self.upscale_algorithm def inference_memory_requirements(self, dtype): if self.previous_controlnet is not None: @@ -132,11 +138,13 @@ def control_merge(self, control_input, control_output, control_prev, output_dtyp return out
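The two ControlBase fields introduced above, compression_ratio and upscale_algorithm, generalize the hard-coded "* 8" scale and 'nearest-exact' filter that the get_control hunks below remove: the Stable Cascade control models loaded in the load_t2i_adapter hunk further down override them with 32/'bilinear' (or 1/'nearest-exact'). A minimal sketch of the resizing contract they encode, using a hypothetical fit_hint_to_latent helper; the real code path is ldm_patched.modules.utils.common_upscale, which additionally center-crops and broadcasts the batch:

import torch
import torch.nn.functional as F

def fit_hint_to_latent(hint, x_noisy, compression_ratio=8, upscale_algorithm='nearest-exact'):
    # The control hint lives in pixel space and x_noisy in latent space, so the
    # hint is resized until hint H/W == latent H/W * compression_ratio.
    size = (x_noisy.shape[2] * compression_ratio, x_noisy.shape[3] * compression_ratio)
    if hint.shape[-2:] == size:
        return hint  # same early-out as the "cond_hint is None or ..." check in get_control
    return F.interpolate(hint, size=size, mode=upscale_algorithm)

hint = torch.rand(1, 3, 300, 300)     # pixel-space control image
x_noisy = torch.zeros(1, 4, 64, 64)   # SD latent
print(fit_hint_to_latent(hint, x_noisy).shape)  # torch.Size([1, 3, 512, 512])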
class ControlNet(ControlBase): - def __init__(self, control_model, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None): + def __init__(self, control_model=None, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None): super().__init__(device) self.control_model = control_model self.load_device = load_device - self.control_model_wrapped = ldm_patched.modules.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=ldm_patched.modules.model_management.unet_offload_device()) + if control_model is not None: + self.control_model_wrapped = ldm_patched.modules.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=ldm_patched.modules.model_management.unet_offload_device()) + self.global_average_pooling = global_average_pooling self.model_sampling_current = None self.manual_cast_dtype = manual_cast_dtype @@ -158,15 +166,15 @@ def get_control(self, x_noisy, t, cond, batched_number): dtype = self.manual_cast_dtype output_dtype = x_noisy.dtype - if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]: + if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]: if self.cond_hint is not None: del self.cond_hint self.cond_hint = None - self.cond_hint = ldm_patched.modules.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device) + self.cond_hint = ldm_patched.modules.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * self.compression_ratio, x_noisy.shape[2] * self.compression_ratio, self.upscale_algorithm, "center").to(dtype).to(self.device) if x_noisy.shape[0] != self.cond_hint.shape[0]: self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number) - context = cond['c_crossattn'] + context = cond.get('crossattn_controlnet', cond['c_crossattn']) y = cond.get('y', None) if y is not None: y = y.to(dtype) @@ -177,7 +185,9 @@ def get_control(self, x_noisy, t, cond, batched_number): return self.control_merge(None, control, control_prev, output_dtype) def copy(self): - c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype) + c = ControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype) + c.control_model = self.control_model + c.control_model_wrapped = self.control_model_wrapped self.copy_to(c) return c @@ -195,7 +205,7 @@ def cleanup(self): super().cleanup() class ControlLoraOps: - class Linear(torch.nn.Module): + class Linear(torch.nn.Module, ldm_patched.modules.ops.CastWeightBiasOp): def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None) -> None: factory_kwargs = {'device': device, 'dtype': dtype} @@ -214,7 +224,7 @@ def forward(self, input): else: return torch.nn.functional.linear(input, weight, bias) - class Conv2d(torch.nn.Module): + class Conv2d(torch.nn.Module, ldm_patched.modules.ops.CastWeightBiasOp): def __init__( self, in_channels, @@ -287,13 +297,13 @@ class control_lora_ops(ControlLoraOps, ldm_patched.modules.ops.manual_cast): for k in sd: weight = sd[k] try: - ldm_patched.modules.utils.set_attr(self.control_model, k, weight) + 
ldm_patched.modules.utils.set_attr_param(self.control_model, k, weight) except: pass for k in self.control_weights: if k not in {"lora_controlnet"}: - ldm_patched.modules.utils.set_attr(self.control_model, k, self.control_weights[k].to(dtype).to(ldm_patched.modules.model_management.get_torch_device())) + ldm_patched.modules.utils.set_attr_param(self.control_model, k, self.control_weights[k].to(dtype).to(ldm_patched.modules.model_management.get_torch_device())) def copy(self): c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling) @@ -318,9 +328,10 @@ def load_controlnet(ckpt_path, model=None): return ControlLora(controlnet_data) controlnet_config = None + supported_inference_dtypes = None + if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format - unet_dtype = ldm_patched.modules.model_management.unet_dtype() - controlnet_config = ldm_patched.modules.model_detection.unet_config_from_diffusers_unet(controlnet_data, unet_dtype) + controlnet_config = ldm_patched.modules.model_detection.unet_config_from_diffusers_unet(controlnet_data) diffusers_keys = ldm_patched.modules.utils.unet_to_diffusers(controlnet_config) diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight" diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias" @@ -361,7 +372,7 @@ def load_controlnet(ckpt_path, model=None): leftover_keys = controlnet_data.keys() if len(leftover_keys) > 0: - print("leftover keys:", leftover_keys) + logging.warning("leftover keys: {}".format(leftover_keys)) controlnet_data = new_sd pth_key = 'control_model.zero_convs.0.0.weight' @@ -376,16 +387,24 @@ def load_controlnet(ckpt_path, model=None): else: net = load_t2i_adapter(controlnet_data) if net is None: - print("error checkpoint does not contain controlnet or t2i adapter data", ckpt_path) + logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path)) return net if controlnet_config is None: - unet_dtype = ldm_patched.modules.model_management.unet_dtype() - controlnet_config = ldm_patched.modules.model_detection.model_config_from_unet(controlnet_data, prefix, unet_dtype, True).unet_config + model_config = ldm_patched.modules.model_detection.model_config_from_unet(controlnet_data, prefix, True) + supported_inference_dtypes = model_config.supported_inference_dtypes + controlnet_config = model_config.unet_config + load_device = ldm_patched.modules.model_management.get_torch_device() + if supported_inference_dtypes is None: + unet_dtype = ldm_patched.modules.model_management.unet_dtype() + else: + unet_dtype = ldm_patched.modules.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes) + manual_cast_dtype = ldm_patched.modules.model_management.unet_manual_cast(unet_dtype, load_device) if manual_cast_dtype is not None: controlnet_config["operations"] = ldm_patched.modules.ops.manual_cast + controlnet_config["dtype"] = unet_dtype controlnet_config.pop("out_channels") controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1] control_model = ldm_patched.controlnet.cldm.ControlNet(**controlnet_config) @@ -403,7 +422,7 @@ def load_controlnet(ckpt_path, model=None): cd = controlnet_data[x] cd += model_sd[sd_key].type(cd.dtype).to(cd.device) else: - print("WARNING: Loaded a diff controlnet without a model. It will very likely not work.") + logging.warning("WARNING: Loaded a diff controlnet without a model. 
It will very likely not work.") class WeightsLoader(torch.nn.Module): pass @@ -412,7 +431,12 @@ class WeightsLoader(torch.nn.Module): missing, unexpected = w.load_state_dict(controlnet_data, strict=False) else: missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False) - print(missing, unexpected) + + if len(missing) > 0: + logging.warning("missing controlnet keys: {}".format(missing)) + + if len(unexpected) > 0: + logging.debug("unexpected controlnet keys: {}".format(unexpected)) global_average_pooling = False filename = os.path.splitext(ckpt_path)[0] @@ -423,11 +447,13 @@ class WeightsLoader(torch.nn.Module): return control class T2IAdapter(ControlBase): - def __init__(self, t2i_model, channels_in, device=None): + def __init__(self, t2i_model, channels_in, compression_ratio, upscale_algorithm, device=None): super().__init__(device) self.t2i_model = t2i_model self.channels_in = channels_in self.control_input = None + self.compression_ratio = compression_ratio + self.upscale_algorithm = upscale_algorithm def scale_image_to(self, width, height): unshuffle_amount = self.t2i_model.unshuffle_amount @@ -447,13 +473,13 @@ def get_control(self, x_noisy, t, cond, batched_number): else: return None - if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]: + if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]: if self.cond_hint is not None: del self.cond_hint self.control_input = None self.cond_hint = None - width, height = self.scale_image_to(x_noisy.shape[3] * 8, x_noisy.shape[2] * 8) - self.cond_hint = ldm_patched.modules.utils.common_upscale(self.cond_hint_original, width, height, 'nearest-exact', "center").float().to(self.device) + width, height = self.scale_image_to(x_noisy.shape[3] * self.compression_ratio, x_noisy.shape[2] * self.compression_ratio) + self.cond_hint = ldm_patched.modules.utils.common_upscale(self.cond_hint_original, width, height, self.upscale_algorithm, "center").float().to(self.device) if self.channels_in == 1 and self.cond_hint.shape[1] > 1: self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True) if x_noisy.shape[0] != self.cond_hint.shape[0]: @@ -472,11 +498,14 @@ def get_control(self, x_noisy, t, cond, batched_number): return self.control_merge(control_input, mid, control_prev, x_noisy.dtype) def copy(self): - c = T2IAdapter(self.t2i_model, self.channels_in) + c = T2IAdapter(self.t2i_model, self.channels_in, self.compression_ratio, self.upscale_algorithm) self.copy_to(c) return c def load_t2i_adapter(t2i_data): + compression_ratio = 8 + upscale_algorithm = 'nearest-exact' + if 'adapter' in t2i_data: t2i_data = t2i_data['adapter'] if 'adapter.body.0.resnets.0.block1.weight' in t2i_data: #diffusers format @@ -504,13 +533,22 @@ def load_t2i_adapter(t2i_data): if cin == 256 or cin == 768: xl = True model_ad = ldm_patched.t2ia.adapter.Adapter(cin=cin, channels=[channel, channel*2, channel*4, channel*4][:4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv, xl=xl) + elif "backbone.0.0.weight" in keys: + model_ad = ldm_patched.ldm.cascade.controlnet.ControlNet(c_in=t2i_data['backbone.0.0.weight'].shape[1], proj_blocks=[0, 4, 8, 12, 51, 55, 59, 63]) + compression_ratio = 32 + upscale_algorithm = 'bilinear' + elif "backbone.10.blocks.0.weight" in keys: + model_ad = ldm_patched.ldm.cascade.controlnet.ControlNet(c_in=t2i_data['backbone.0.weight'].shape[1], 
bottleneck_mode="large", proj_blocks=[0, 4, 8, 12, 51, 55, 59, 63]) + compression_ratio = 1 + upscale_algorithm = 'nearest-exact' else: return None + missing, unexpected = model_ad.load_state_dict(t2i_data) if len(missing) > 0: - print("t2i missing", missing) + logging.warning("t2i missing {}".format(missing)) if len(unexpected) > 0: - print("t2i unexpected", unexpected) + logging.debug("t2i unexpected {}".format(unexpected)) - return T2IAdapter(model_ad, model_ad.input_channels) + return T2IAdapter(model_ad, model_ad.input_channels, compression_ratio, upscale_algorithm) diff --git a/ldm_patched/modules/diffusers_convert.py b/ldm_patched/modules/diffusers_convert.py index a9eb9302f..ed2a45fea 100644 --- a/ldm_patched/modules/diffusers_convert.py +++ b/ldm_patched/modules/diffusers_convert.py @@ -1,5 +1,6 @@ import re import torch +import logging # conversion code from https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py @@ -177,7 +178,7 @@ def convert_vae_state_dict(vae_state_dict): for k, v in new_state_dict.items(): for weight_name in weights_to_convert: if f"mid.attn_1.{weight_name}.weight" in k: - print(f"Reshaping {k} for SD format") + logging.debug(f"Reshaping {k} for SD format") new_state_dict[k] = reshape_weight_for_sd(v) return new_state_dict @@ -205,6 +206,21 @@ def convert_vae_state_dict(vae_state_dict): # Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp code2idx = {"q": 0, "k": 1, "v": 2} +# This function exists because at the time of writing torch.cat can't do fp8 with cuda +def cat_tensors(tensors): + x = 0 + for t in tensors: + x += t.shape[0] + + shape = [x] + list(tensors[0].shape)[1:] + out = torch.empty(shape, device=tensors[0].device, dtype=tensors[0].dtype) + + x = 0 + for t in tensors: + out[x:x + t.shape[0]] = t + x += t.shape[0] + + return out def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""): new_state_dict = {} @@ -237,20 +253,24 @@ def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""): capture_qkv_bias[k_pre][code2idx[k_code]] = v continue - relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k) - new_state_dict[relabelled_key] = v + text_proj = "transformer.text_projection.weight" + if k.endswith(text_proj): + new_state_dict[k.replace(text_proj, "text_projection")] = v.transpose(0, 1).contiguous() + else: + relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k) + new_state_dict[relabelled_key] = v for k_pre, tensors in capture_qkv_weight.items(): if None in tensors: raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing") relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre) - new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors) + new_state_dict[relabelled_key + ".in_proj_weight"] = cat_tensors(tensors) for k_pre, tensors in capture_qkv_bias.items(): if None in tensors: raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing") relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre) - new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors) + new_state_dict[relabelled_key + ".in_proj_bias"] = cat_tensors(tensors) return new_state_dict diff --git a/ldm_patched/modules/gligen.py b/ldm_patched/modules/gligen.py index 11f1ee938..d33acadeb 100644 --- a/ldm_patched/modules/gligen.py +++ b/ldm_patched/modules/gligen.py @@ -1,8 +1,11 
@@ +import math + import torch from torch import nn from ldm_patched.ldm.modules.attention import CrossAttention from inspect import isfunction - +import ldm_patched.modules.ops +ops = ldm_patched.modules.ops.manual_cast def exists(val): return val is not None @@ -22,7 +25,7 @@ def default(val, d): class GEGLU(nn.Module): def __init__(self, dim_in, dim_out): super().__init__() - self.proj = nn.Linear(dim_in, dim_out * 2) + self.proj = ops.Linear(dim_in, dim_out * 2) def forward(self, x): x, gate = self.proj(x).chunk(2, dim=-1) @@ -35,14 +38,14 @@ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): inner_dim = int(dim * mult) dim_out = default(dim_out, dim) project_in = nn.Sequential( - nn.Linear(dim, inner_dim), + ops.Linear(dim, inner_dim), nn.GELU() ) if not glu else GEGLU(dim, inner_dim) self.net = nn.Sequential( project_in, nn.Dropout(dropout), - nn.Linear(inner_dim, dim_out) + ops.Linear(inner_dim, dim_out) ) def forward(self, x): @@ -57,11 +60,12 @@ def __init__(self, query_dim, context_dim, n_heads, d_head): query_dim=query_dim, context_dim=context_dim, heads=n_heads, - dim_head=d_head) + dim_head=d_head, + operations=ops) self.ff = FeedForward(query_dim, glu=True) - self.norm1 = nn.LayerNorm(query_dim) - self.norm2 = nn.LayerNorm(query_dim) + self.norm1 = ops.LayerNorm(query_dim) + self.norm2 = ops.LayerNorm(query_dim) self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.))) self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.))) @@ -87,17 +91,18 @@ def __init__(self, query_dim, context_dim, n_heads, d_head): # we need a linear projection since we need cat visual feature and obj # feature - self.linear = nn.Linear(context_dim, query_dim) + self.linear = ops.Linear(context_dim, query_dim) self.attn = CrossAttention( query_dim=query_dim, context_dim=query_dim, heads=n_heads, - dim_head=d_head) + dim_head=d_head, + operations=ops) self.ff = FeedForward(query_dim, glu=True) - self.norm1 = nn.LayerNorm(query_dim) - self.norm2 = nn.LayerNorm(query_dim) + self.norm1 = ops.LayerNorm(query_dim) + self.norm2 = ops.LayerNorm(query_dim) self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.))) self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.))) @@ -126,14 +131,14 @@ def __init__(self, query_dim, context_dim, n_heads, d_head): # we need a linear projection since we need cat visual feature and obj # feature - self.linear = nn.Linear(context_dim, query_dim) + self.linear = ops.Linear(context_dim, query_dim) self.attn = CrossAttention( - query_dim=query_dim, context_dim=query_dim, dim_head=d_head) + query_dim=query_dim, context_dim=query_dim, dim_head=d_head, operations=ops) self.ff = FeedForward(query_dim, glu=True) - self.norm1 = nn.LayerNorm(query_dim) - self.norm2 = nn.LayerNorm(query_dim) + self.norm1 = ops.LayerNorm(query_dim) + self.norm2 = ops.LayerNorm(query_dim) self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.))) self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.))) @@ -201,11 +206,11 @@ def __init__(self, in_dim, out_dim, fourier_freqs=8): self.position_dim = fourier_freqs * 2 * 4 # 2 is sin&cos, 4 is xyxy self.linears = nn.Sequential( - nn.Linear(self.in_dim + self.position_dim, 512), + ops.Linear(self.in_dim + self.position_dim, 512), nn.SiLU(), - nn.Linear(512, 512), + ops.Linear(512, 512), nn.SiLU(), - nn.Linear(512, out_dim), + ops.Linear(512, out_dim), ) self.null_positive_feature = torch.nn.Parameter( @@ -215,16 +220,15 @@ def __init__(self, in_dim, out_dim, 
fourier_freqs=8): def forward(self, boxes, masks, positive_embeddings): B, N, _ = boxes.shape - dtype = self.linears[0].weight.dtype - masks = masks.unsqueeze(-1).to(dtype) - positive_embeddings = positive_embeddings.to(dtype) + masks = masks.unsqueeze(-1) + positive_embeddings = positive_embeddings # embedding position (it may includes padding as placeholder) - xyxy_embedding = self.fourier_embedder(boxes.to(dtype)) # B*N*4 --> B*N*C + xyxy_embedding = self.fourier_embedder(boxes) # B*N*4 --> B*N*C # learnable null embedding - positive_null = self.null_positive_feature.view(1, 1, -1) - xyxy_null = self.null_position_feature.view(1, 1, -1) + positive_null = self.null_positive_feature.to(device=boxes.device, dtype=boxes.dtype).view(1, 1, -1) + xyxy_null = self.null_position_feature.to(device=boxes.device, dtype=boxes.dtype).view(1, 1, -1) # replace padding with learnable null embedding positive_embeddings = positive_embeddings * \ @@ -251,7 +255,7 @@ def _set_position(self, boxes, masks, positive_embeddings): def func(x, extra_options): key = extra_options["transformer_index"] module = self.module_list[key] - return module(x, objs) + return module(x, objs.to(device=x.device, dtype=x.dtype)) return func def set_position(self, latent_image_shape, position_params, device): diff --git a/ldm_patched/modules/latent_formats.py b/ldm_patched/modules/latent_formats.py index 1606793e0..4ca466d9a 100644 --- a/ldm_patched/modules/latent_formats.py +++ b/ldm_patched/modules/latent_formats.py @@ -101,4 +101,4 @@ def __init__(self): [-0.2093, -0.0222, -0.0195], [-0.3087, -0.1535, 0.0366], [ 0.0290, -0.1574, -0.4078] - ] \ No newline at end of file + ] diff --git a/ldm_patched/modules/lora.py b/ldm_patched/modules/lora.py index cc5a29da8..622cbdc30 100644 --- a/ldm_patched/modules/lora.py +++ b/ldm_patched/modules/lora.py @@ -1,4 +1,5 @@ import ldm_patched.modules.utils +import logging LORA_CLIP_MAP = { "mlp.fc1": "mlp_fc1", @@ -20,6 +21,12 @@ def load_lora(lora, to_load): alpha = lora[alpha_name].item() loaded_keys.add(alpha_name) + dora_scale_name = "{}.dora_scale".format(x) + dora_scale = None + if dora_scale_name in lora.keys(): + dora_scale = lora[dora_scale_name] + loaded_keys.add(dora_scale_name) + regular_lora = "{}.lora_up.weight".format(x) diffusers_lora = "{}_lora.up.weight".format(x) transformers_lora = "{}.lora_linear_layer.up.weight".format(x) @@ -43,7 +50,7 @@ def load_lora(lora, to_load): if mid_name is not None and mid_name in lora.keys(): mid = lora[mid_name] loaded_keys.add(mid_name) - patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid)) + patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale)) loaded_keys.add(A_name) loaded_keys.add(B_name) @@ -64,7 +71,7 @@ def load_lora(lora, to_load): loaded_keys.add(hada_t1_name) loaded_keys.add(hada_t2_name) - patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2)) + patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale)) loaded_keys.add(hada_w1_a_name) loaded_keys.add(hada_w1_b_name) loaded_keys.add(hada_w2_a_name) @@ -116,7 +123,7 @@ def load_lora(lora, to_load): loaded_keys.add(lokr_t2_name) if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None): - patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, 
lokr_w2_b, lokr_t2)) + patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale)) #glora a1_name = "{}.a1.weight".format(x) @@ -124,7 +131,7 @@ def load_lora(lora, to_load): b1_name = "{}.b1.weight".format(x) b2_name = "{}.b2.weight".format(x) if a1_name in lora: - patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha)) + patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale)) loaded_keys.add(a1_name) loaded_keys.add(a2_name) loaded_keys.add(b1_name) @@ -156,7 +163,7 @@ def load_lora(lora, to_load): for x in lora.keys(): if x not in loaded_keys: - print("lora key not loaded", x) + logging.warning("lora key not loaded: {}".format(x)) return patch_dict def model_lora_keys_clip(model, key_map={}): @@ -197,6 +204,15 @@ def model_lora_keys_clip(model, key_map={}): key_map[lora_key] = k lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora key_map[lora_key] = k + lora_key = "lora_prior_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #cascade lora: TODO put lora key prefix in the model config + key_map[lora_key] = k + + + k = "clip_g.transformer.text_projection.weight" + if k in sdk: + key_map["lora_prior_te_text_projection"] = k #cascade lora? + # key_map["text_encoder.text_projection"] = k #TODO: check if other lora have the text_projection too + # key_map["lora_te_text_projection"] = k return key_map @@ -207,6 +223,7 @@ def model_lora_keys_unet(model, key_map={}): if k.startswith("diffusion_model.") and k.endswith(".weight"): key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_") key_map["lora_unet_{}".format(key_lora)] = k + key_map["lora_prior_unet_{}".format(key_lora)] = k #cascade lora: TODO put lora key prefix in the model config diffusers_keys = ldm_patched.modules.utils.unet_to_diffusers(model.model_config.unet_config) for k in diffusers_keys: diff --git a/ldm_patched/modules/model_base.py b/ldm_patched/modules/model_base.py index 9c69e98b8..ab73a8853 100644 --- a/ldm_patched/modules/model_base.py +++ b/ldm_patched/modules/model_base.py @@ -1,5 +1,8 @@ import torch +import logging from ldm_patched.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep +from ldm_patched.ldm.cascade.stage_c import StageC +from ldm_patched.ldm.cascade.stage_b import StageB from ldm_patched.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation from ldm_patched.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation import ldm_patched.modules.model_management @@ -12,9 +15,11 @@ class ModelType(Enum): EPS = 1 V_PREDICTION = 2 V_PREDICTION_EDM = 3 + STABLE_CASCADE = 4 + EDM = 5 -from ldm_patched.modules.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM +from ldm_patched.modules.model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling def model_sampling(model_config, model_type): @@ -27,6 +32,12 @@ def model_sampling(model_config, model_type): elif model_type == ModelType.V_PREDICTION_EDM: c = V_PREDICTION s = ModelSamplingContinuousEDM + elif model_type == ModelType.STABLE_CASCADE: + c = EPS + s = StableCascadeSampling + elif model_type == ModelType.EDM: + c = EDM + s = ModelSamplingContinuousEDM class ModelSampling(s, c): pass @@ -35,7 +46,7 @@ class ModelSampling(s, c): class BaseModel(torch.nn.Module): - def 
__init__(self, model_config, model_type=ModelType.EPS, device=None): + def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_model=UNetModel): super().__init__() unet_config = model_config.unet_config @@ -48,16 +59,17 @@ def __init__(self, model_config, model_type=ModelType.EPS, device=None): operations = ldm_patched.modules.ops.manual_cast else: operations = ldm_patched.modules.ops.disable_weight_init - self.diffusion_model = UNetModel(**unet_config, device=device, operations=operations) + self.diffusion_model = unet_model(**unet_config, device=device, operations=operations) self.model_type = model_type self.model_sampling = model_sampling(model_config, model_type) self.adm_channels = unet_config.get("adm_in_channels", None) if self.adm_channels is None: self.adm_channels = 0 - self.inpaint_model = False - print("model_type", model_type.name) - print("UNet ADM Dimension", self.adm_channels) + + self.concat_keys = () + logging.info("model_type {}".format(model_type.name)) + logging.debug("adm {}".format(self.adm_channels)) def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs): sigma = t @@ -96,8 +108,7 @@ def encode_adm(self, **kwargs): def extra_conds(self, **kwargs): out = {} - if self.inpaint_model: - concat_keys = ("mask", "masked_image") + if len(self.concat_keys) > 0: cond_concat = [] denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None)) concat_latent_image = kwargs.get("concat_latent_image", None) @@ -114,24 +125,16 @@ def extra_conds(self, **kwargs): concat_latent_image = utils.resize_to_batch_size(concat_latent_image, noise.shape[0]) - if len(denoise_mask.shape) == len(noise.shape): - denoise_mask = denoise_mask[:,:1] - - denoise_mask = denoise_mask.reshape((-1, 1, denoise_mask.shape[-2], denoise_mask.shape[-1])) - if denoise_mask.shape[-2:] != noise.shape[-2:]: - denoise_mask = utils.common_upscale(denoise_mask, noise.shape[-1], noise.shape[-2], "bilinear", "center") - denoise_mask = utils.resize_to_batch_size(denoise_mask.round(), noise.shape[0]) + if denoise_mask is not None: + if len(denoise_mask.shape) == len(noise.shape): + denoise_mask = denoise_mask[:,:1] - def blank_inpaint_image_like(latent_image): - blank_image = torch.ones_like(latent_image) - # these are the values for "zero" in pixel space translated to latent space - blank_image[:,0] *= 0.8223 - blank_image[:,1] *= -0.6876 - blank_image[:,2] *= 0.6364 - blank_image[:,3] *= 0.1380 - return blank_image + denoise_mask = denoise_mask.reshape((-1, 1, denoise_mask.shape[-2], denoise_mask.shape[-1])) + if denoise_mask.shape[-2:] != noise.shape[-2:]: + denoise_mask = utils.common_upscale(denoise_mask, noise.shape[-1], noise.shape[-2], "bilinear", "center") + denoise_mask = utils.resize_to_batch_size(denoise_mask.round(), noise.shape[0]) - for ck in concat_keys: + for ck in self.concat_keys: if denoise_mask is not None: if ck == "mask": cond_concat.append(denoise_mask.to(device)) @@ -141,7 +144,7 @@ def blank_inpaint_image_like(latent_image): if ck == "mask": cond_concat.append(torch.ones_like(noise)[:,:1]) elif ck == "masked_image": - cond_concat.append(blank_inpaint_image_like(noise)) + cond_concat.append(self.blank_inpaint_image_like(noise)) data = torch.cat(cond_concat, dim=1) out['c_concat'] = ldm_patched.modules.conds.CONDNoiseShape(data) @@ -153,6 +156,14 @@ def blank_inpaint_image_like(latent_image): if cross_attn is not None: out['c_crossattn'] = ldm_patched.modules.conds.CONDCrossAttn(cross_attn) + cross_attn_cnet = 
kwargs.get("cross_attn_controlnet", None) + if cross_attn_cnet is not None: + out['crossattn_controlnet'] = ldm_patched.modules.conds.CONDCrossAttn(cross_attn_cnet) + + c_concat = kwargs.get("noise_concat", None) + if c_concat is not None: + out['c_concat'] = ldm_patched.modules.conds.CONDNoiseShape(c_concat) + return out def load_model_weights(self, sd, unet_prefix=""): @@ -165,10 +176,10 @@ def load_model_weights(self, sd, unet_prefix=""): to_load = self.model_config.process_unet_state_dict(to_load) m, u = self.diffusion_model.load_state_dict(to_load, strict=False) if len(m) > 0: - print("unet missing:", m) + logging.warning("unet missing: {}".format(m)) if len(u) > 0: - print("unet unexpected:", u) + logging.warning("unet unexpected: {}".format(u)) del to_load return self @@ -202,7 +213,16 @@ def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_ return unet_state_dict def set_inpaint(self): - self.inpaint_model = True + self.concat_keys = ("mask", "masked_image") + def blank_inpaint_image_like(latent_image): + blank_image = torch.ones_like(latent_image) + # these are the values for "zero" in pixel space translated to latent space + blank_image[:,0] *= 0.8223 + blank_image[:,1] *= -0.6876 + blank_image[:,2] *= 0.6364 + blank_image[:,3] *= 0.1380 + return blank_image + self.blank_inpaint_image_like = blank_inpaint_image_like def memory_required(self, input_shape): if ldm_patched.modules.model_management.xformers_enabled() or ldm_patched.modules.model_management.pytorch_attention_flash_attention(): @@ -358,10 +378,39 @@ def extra_conds(self, **kwargs): if "time_conditioning" in kwargs: out["time_context"] = ldm_patched.modules.conds.CONDCrossAttn(kwargs["time_conditioning"]) - out['image_only_indicator'] = ldm_patched.modules.conds.CONDConstant(torch.zeros((1,), device=device)) out['num_video_frames'] = ldm_patched.modules.conds.CONDConstant(noise.shape[0]) return out +class SV3D_u(SVD_img2vid): + def encode_adm(self, **kwargs): + augmentation = kwargs.get("augmentation_level", 0) + + out = [] + out.append(self.embedder(torch.flatten(torch.Tensor([augmentation])))) + + flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0) + return flat + +class SV3D_p(SVD_img2vid): + def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None): + super().__init__(model_config, model_type, device=device) + self.embedder_512 = Timestep(512) + + def encode_adm(self, **kwargs): + augmentation = kwargs.get("augmentation_level", 0) + elevation = kwargs.get("elevation", 0) #elevation and azimuth are in degrees here + azimuth = kwargs.get("azimuth", 0) + noise = kwargs.get("noise", None) + + out = [] + out.append(self.embedder(torch.flatten(torch.Tensor([augmentation])))) + out.append(self.embedder_512(torch.deg2rad(torch.fmod(torch.flatten(90 - torch.Tensor([elevation])), 360.0)))) + out.append(self.embedder_512(torch.deg2rad(torch.fmod(torch.flatten(torch.Tensor([azimuth])), 360.0)))) + + out = list(map(lambda a: utils.resize_to_batch_size(a, noise.shape[0]), out)) + return torch.cat(out, dim=1) + + class Stable_Zero123(BaseModel): def __init__(self, model_config, model_type=ModelType.EPS, device=None, cc_projection_weight=None, cc_projection_bias=None): super().__init__(model_config, model_type, device=device) @@ -423,3 +472,88 @@ def extra_conds(self, **kwargs): out['c_concat'] = ldm_patched.modules.conds.CONDNoiseShape(image) out['y'] = ldm_patched.modules.conds.CONDRegular(noise_level) return out + +class IP2P: + def extra_conds(self, **kwargs): + out = {} + 
+        image = kwargs.get("concat_latent_image", None)
+        noise = kwargs.get("noise", None)
+        device = kwargs["device"]
+
+        if image is None:
+            image = torch.zeros_like(noise)
+
+        if image.shape[1:] != noise.shape[1:]:
+            image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
+
+        image = utils.resize_to_batch_size(image, noise.shape[0])
+
+        out['c_concat'] = ldm_patched.modules.conds.CONDNoiseShape(self.process_ip2p_image_in(image))
+        adm = self.encode_adm(**kwargs)
+        if adm is not None:
+            out['y'] = ldm_patched.modules.conds.CONDRegular(adm)
+        return out
+
+class SD15_instructpix2pix(IP2P, BaseModel):
+    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
+        super().__init__(model_config, model_type, device=device)
+        self.process_ip2p_image_in = lambda image: image
+
+class SDXL_instructpix2pix(IP2P, SDXL):
+    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
+        super().__init__(model_config, model_type, device=device)
+        if model_type == ModelType.V_PREDICTION_EDM:
+            self.process_ip2p_image_in = lambda image: ldm_patched.modules.latent_formats.SDXL().process_in(image) #cosxl ip2p
+        else:
+            self.process_ip2p_image_in = lambda image: image #diffusers ip2p
+
+
+class StableCascade_C(BaseModel):
+    def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
+        super().__init__(model_config, model_type, device=device, unet_model=StageC)
+        self.diffusion_model.eval().requires_grad_(False)
+
+    def extra_conds(self, **kwargs):
+        out = {}
+        clip_text_pooled = kwargs["pooled_output"]
+        if clip_text_pooled is not None:
+            out['clip_text_pooled'] = ldm_patched.modules.conds.CONDRegular(clip_text_pooled)
+
+        if "unclip_conditioning" in kwargs:
+            embeds = []
+            for unclip_cond in kwargs["unclip_conditioning"]:
+                weight = unclip_cond["strength"]
+                embeds.append(unclip_cond["clip_vision_output"].image_embeds.unsqueeze(0) * weight)
+            clip_img = torch.cat(embeds, dim=1)
+        else:
+            clip_img = torch.zeros((1, 1, 768))
+        out["clip_img"] = ldm_patched.modules.conds.CONDRegular(clip_img)
+        out["sca"] = ldm_patched.modules.conds.CONDRegular(torch.zeros((1,)))
+        out["crp"] = ldm_patched.modules.conds.CONDRegular(torch.zeros((1,)))
+
+        cross_attn = kwargs.get("cross_attn", None)
+        if cross_attn is not None:
+            out['clip_text'] = ldm_patched.modules.conds.CONDCrossAttn(cross_attn)
+        return out
+
+
+class StableCascade_B(BaseModel):
+    def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
+        super().__init__(model_config, model_type, device=device, unet_model=StageB)
+        self.diffusion_model.eval().requires_grad_(False)
+
+    def extra_conds(self, **kwargs):
+        out = {}
+        noise = kwargs.get("noise", None)
+
+        clip_text_pooled = kwargs["pooled_output"]
+        if clip_text_pooled is not None:
+            out['clip'] = ldm_patched.modules.conds.CONDRegular(clip_text_pooled)
+
+        #the size of the prior doesn't really matter if it's all zeros, because it gets resized, but I still want it to be batched
+        prior = kwargs.get("stable_cascade_prior", torch.zeros((1, 16, (noise.shape[2] * 4) // 42, (noise.shape[3] * 4) // 42), dtype=noise.dtype, layout=noise.layout, device=noise.device))
+
+        out["effnet"] = ldm_patched.modules.conds.CONDRegular(prior)
+        out["sca"] = ldm_patched.modules.conds.CONDRegular(torch.zeros((1,)))
+        return out
diff --git a/ldm_patched/modules/model_detection.py b/ldm_patched/modules/model_detection.py
index 126386ca8..ced109d14 100644
--- a/ldm_patched/modules/model_detection.py
+++
b/ldm_patched/modules/model_detection.py @@ -1,5 +1,6 @@ import ldm_patched.modules.supported_models import ldm_patched.modules.supported_models_base +import logging def count_blocks(state_dict_keys, prefix_string): count = 0 @@ -28,9 +29,38 @@ def calculate_transformer_depth(prefix, state_dict_keys, state_dict): return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack return None -def detect_unet_config(state_dict, key_prefix, dtype): +def detect_unet_config(state_dict, key_prefix): state_dict_keys = list(state_dict.keys()) + if '{}clf.1.weight'.format(key_prefix) in state_dict_keys: #stable cascade + unet_config = {} + text_mapper_name = '{}clip_txt_mapper.weight'.format(key_prefix) + if text_mapper_name in state_dict_keys: + unet_config['stable_cascade_stage'] = 'c' + w = state_dict[text_mapper_name] + if w.shape[0] == 1536: #stage c lite + unet_config['c_cond'] = 1536 + unet_config['c_hidden'] = [1536, 1536] + unet_config['nhead'] = [24, 24] + unet_config['blocks'] = [[4, 12], [12, 4]] + elif w.shape[0] == 2048: #stage c full + unet_config['c_cond'] = 2048 + elif '{}clip_mapper.weight'.format(key_prefix) in state_dict_keys: + unet_config['stable_cascade_stage'] = 'b' + w = state_dict['{}down_blocks.1.0.channelwise.0.weight'.format(key_prefix)] + if w.shape[-1] == 640: + unet_config['c_hidden'] = [320, 640, 1280, 1280] + unet_config['nhead'] = [-1, -1, 20, 20] + unet_config['blocks'] = [[2, 6, 28, 6], [6, 28, 6, 2]] + unet_config['block_repeat'] = [[1, 1, 1, 1], [3, 3, 2, 2]] + elif w.shape[-1] == 576: #stage b lite + unet_config['c_hidden'] = [320, 576, 1152, 1152] + unet_config['nhead'] = [-1, 9, 18, 18] + unet_config['blocks'] = [[2, 4, 14, 4], [4, 14, 4, 2]] + unet_config['block_repeat'] = [[1, 1, 1, 1], [2, 2, 2, 2]] + + return unet_config + unet_config = { "use_checkpoint": False, "image_size": 32, @@ -45,7 +75,6 @@ def detect_unet_config(state_dict, key_prefix, dtype): else: unet_config["adm_in_channels"] = None - unet_config["dtype"] = dtype model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0] in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1] @@ -123,8 +152,10 @@ def detect_unet_config(state_dict, key_prefix, dtype): channel_mult.append(last_channel_mult) if "{}middle_block.1.proj_in.weight".format(key_prefix) in state_dict_keys: transformer_depth_middle = count_blocks(state_dict_keys, '{}middle_block.1.transformer_blocks.'.format(key_prefix) + '{}') - else: + elif "{}middle_block.0.in_layers.0.weight".format(key_prefix) in state_dict_keys: transformer_depth_middle = -1 + else: + transformer_depth_middle = -2 unet_config["in_channels"] = in_channels unet_config["out_channels"] = out_channels @@ -151,17 +182,17 @@ def detect_unet_config(state_dict, key_prefix, dtype): return unet_config -def model_config_from_unet_config(unet_config): +def model_config_from_unet_config(unet_config, state_dict=None): for model_config in ldm_patched.modules.supported_models.models: - if model_config.matches(unet_config): + if model_config.matches(unet_config, state_dict): return model_config(unet_config) - print("no match", unet_config) + logging.error("no match {}".format(unet_config)) return None -def model_config_from_unet(state_dict, unet_key_prefix, dtype, use_base_if_no_match=False): - unet_config = detect_unet_config(state_dict, unet_key_prefix, dtype) - model_config = model_config_from_unet_config(unet_config) +def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False): + 
unet_config = detect_unet_config(state_dict, unet_key_prefix) + model_config = model_config_from_unet_config(unet_config, state_dict) if model_config is None and use_base_if_no_match: return ldm_patched.modules.supported_models_base.BASE(unet_config) else: @@ -206,7 +237,7 @@ def convert_config(unet_config): return new_config -def unet_config_from_diffusers_unet(state_dict, dtype): +def unet_config_from_diffusers_unet(state_dict, dtype=None): match = {} transformer_depth = [] @@ -214,6 +245,7 @@ def unet_config_from_diffusers_unet(state_dict, dtype): down_blocks = count_blocks(state_dict, "down_blocks.{}") for i in range(down_blocks): attn_blocks = count_blocks(state_dict, "down_blocks.{}.attentions.".format(i) + '{}') + res_blocks = count_blocks(state_dict, "down_blocks.{}.resnets.".format(i) + '{}') for ab in range(attn_blocks): transformer_count = count_blocks(state_dict, "down_blocks.{}.attentions.{}.transformer_blocks.".format(i, ab) + '{}') transformer_depth.append(transformer_count) @@ -222,8 +254,8 @@ def unet_config_from_diffusers_unet(state_dict, dtype): attn_res *= 2 if attn_blocks == 0: - transformer_depth.append(0) - transformer_depth.append(0) + for i in range(res_blocks): + transformer_depth.append(0) match["transformer_depth"] = transformer_depth @@ -289,6 +321,12 @@ def unet_config_from_diffusers_unet(state_dict, dtype): 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10], 'use_temporal_attention': False, 'use_temporal_resblock': False} + SDXL_diffusers_ip2p = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, + 'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 8, 'model_channels': 320, + 'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 10, + 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10], + 'use_temporal_attention': False, 'use_temporal_resblock': False} + SSD_1B = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 4, 4], 'transformer_depth_output': [0, 0, 0, 1, 1, 2, 10, 4, 4], @@ -301,7 +339,32 @@ def unet_config_from_diffusers_unet(state_dict, dtype): 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -1, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'use_temporal_attention': False, 'use_temporal_resblock': False} - supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega] + KOALA_700M = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, + 'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, + 'num_res_blocks': [1, 1, 1], 'transformer_depth': [0, 2, 5], 'transformer_depth_output': [0, 0, 2, 2, 5, 5], + 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -2, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, + 'use_temporal_attention': False, 'use_temporal_resblock': False} + + KOALA_1B = 
{'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, + 'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, + 'num_res_blocks': [1, 1, 1], 'transformer_depth': [0, 2, 6], 'transformer_depth_output': [0, 0, 2, 2, 6, 6], + 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 6, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, + 'use_temporal_attention': False, 'use_temporal_resblock': False} + + SD09_XS = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, + 'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [1, 1, 1], + 'transformer_depth': [1, 1, 1], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -2, 'use_linear_in_transformer': True, + 'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1], + 'use_temporal_attention': False, 'use_temporal_resblock': False, 'disable_self_attentions': [True, False, False]} + + SD_XS = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, + 'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [1, 1, 1], + 'transformer_depth': [0, 1, 1], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -2, 'use_linear_in_transformer': False, + 'context_dim': 768, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 1, 1, 1, 1], + 'use_temporal_attention': False, 'use_temporal_resblock': False} + + + supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p] for unet_config in supported_models: matches = True @@ -313,8 +376,8 @@ def unet_config_from_diffusers_unet(state_dict, dtype): return convert_config(unet_config) return None -def model_config_from_diffusers_unet(state_dict, dtype): - unet_config = unet_config_from_diffusers_unet(state_dict, dtype) +def model_config_from_diffusers_unet(state_dict): + unet_config = unet_config_from_diffusers_unet(state_dict) if unet_config is not None: return model_config_from_unet_config(unet_config) return None diff --git a/ldm_patched/modules/model_management.py b/ldm_patched/modules/model_management.py index 840d79a07..4886722a3 100644 --- a/ldm_patched/modules/model_management.py +++ b/ldm_patched/modules/model_management.py @@ -1,9 +1,10 @@ import psutil +import logging from enum import Enum from ldm_patched.modules.args_parser import args -import ldm_patched.modules.utils import torch import sys +import platform class VRAMState(Enum): DISABLED = 0 #No vram present: no need to move models to vram @@ -28,8 +29,8 @@ class CPUState(Enum): lowvram_available = True xpu_available = False -if args.pytorch_deterministic: - print("Using deterministic algorithms for pytorch") +if args.deterministic: + logging.info("Using deterministic algorithms for pytorch") torch.use_deterministic_algorithms(True, warn_only=True) directml_enabled = False @@ -41,7 +42,7 @@ class CPUState(Enum): directml_device = torch_directml.device() else: directml_device = torch_directml.device(device_index) - print("Using directml with device:", torch_directml.device_name(device_index)) + logging.info("Using directml with device: {}".format(torch_directml.device_name(device_index))) # 
torch_directml.disable_tiled_resources(True) lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default. @@ -59,10 +60,7 @@ class CPUState(Enum): except: pass -if args.always_cpu: - if args.always_cpu > 0: - torch.set_num_threads(args.always_cpu) - print(f"Running on {torch.get_num_threads()} CPU threads") +if args.cpu: cpu_state = CPUState.CPU def is_intel_xpu(): @@ -85,7 +83,7 @@ def get_torch_device(): return torch.device("cpu") else: if is_intel_xpu(): - return torch.device("xpu") + return torch.device("xpu", torch.xpu.current_device()) else: return torch.device(torch.cuda.current_device()) @@ -104,8 +102,8 @@ def get_total_memory(dev=None, torch_total_too=False): elif is_intel_xpu(): stats = torch.xpu.memory_stats(dev) mem_reserved = stats['reserved_bytes.all.current'] - mem_total = torch.xpu.get_device_properties(dev).total_memory mem_total_torch = mem_reserved + mem_total = torch.xpu.get_device_properties(dev).total_memory else: stats = torch.cuda.memory_stats(dev) mem_reserved = stats['reserved_bytes.all.current'] @@ -120,11 +118,12 @@ def get_total_memory(dev=None, torch_total_too=False): total_vram = get_total_memory(get_torch_device()) / (1024 * 1024) total_ram = psutil.virtual_memory().total / (1024 * 1024) -print("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram)) -if not args.always_normal_vram and not args.always_cpu: - if lowvram_available and total_vram <= 4096: - print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --always-normal-vram") - set_vram_to = VRAMState.LOW_VRAM +logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram)) + +try: + logging.info("pytorch version: {}".format(torch.version.__version__)) +except: + pass try: OOM_EXCEPTION = torch.cuda.OutOfMemoryError @@ -146,12 +145,10 @@ def get_total_memory(dev=None, torch_total_too=False): pass try: XFORMERS_VERSION = xformers.version.__version__ - print("xformers version:", XFORMERS_VERSION) + logging.info("xformers version: {}".format(XFORMERS_VERSION)) if XFORMERS_VERSION.startswith("0.0.18"): - print() - print("WARNING: This version of xformers has a major bug where you will get black images when generating high resolution images.") - print("Please downgrade or upgrade xformers to a different version.") - print() + logging.warning("\nWARNING: This version of xformers has a major bug where you will get black images when generating high resolution images.") + logging.warning("Please downgrade or upgrade xformers to a different version.\n") XFORMERS_ENABLED_VAE = False except: pass @@ -166,7 +163,7 @@ def is_nvidia(): return False ENABLE_PYTORCH_ATTENTION = False -if args.attention_pytorch: +if args.use_pytorch_cross_attention: ENABLE_PYTORCH_ATTENTION = True XFORMERS_IS_AVAILABLE = False @@ -176,12 +173,12 @@ def is_nvidia(): if is_nvidia(): torch_version = torch.version.__version__ if int(torch_version[0]) >= 2: - if ENABLE_PYTORCH_ATTENTION == False and args.attention_split == False and args.attention_quad == False: + if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False: ENABLE_PYTORCH_ATTENTION = True if torch.cuda.is_bf16_supported() and torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8: VAE_DTYPE = torch.bfloat16 if is_intel_xpu(): - if args.attention_split == False and args.attention_quad == False: + if args.use_split_cross_attention == False 
and args.use_quad_cross_attention == False: ENABLE_PYTORCH_ATTENTION = True except: pass @@ -189,14 +186,14 @@ def is_nvidia(): if is_intel_xpu(): VAE_DTYPE = torch.bfloat16 -if args.vae_in_cpu: +if args.cpu_vae: VAE_DTYPE = torch.float32 -if args.vae_in_fp16: +if args.fp16_vae: VAE_DTYPE = torch.float16 -elif args.vae_in_bf16: +elif args.bf16_vae: VAE_DTYPE = torch.bfloat16 -elif args.vae_in_fp32: +elif args.fp32_vae: VAE_DTYPE = torch.float32 @@ -205,22 +202,22 @@ def is_nvidia(): torch.backends.cuda.enable_flash_sdp(True) torch.backends.cuda.enable_mem_efficient_sdp(True) -if args.always_low_vram: +if args.lowvram: set_vram_to = VRAMState.LOW_VRAM lowvram_available = True -elif args.always_no_vram: +elif args.novram: set_vram_to = VRAMState.NO_VRAM -elif args.always_high_vram or args.always_gpu: +elif args.highvram or args.gpu_only: vram_state = VRAMState.HIGH_VRAM FORCE_FP32 = False FORCE_FP16 = False -if args.all_in_fp32: - print("Forcing FP32, if this improves things please report it.") +if args.force_fp32: + logging.info("Forcing FP32, if this improves things please report it.") FORCE_FP32 = True -if args.all_in_fp16: - print("Forcing FP16.") +if args.force_fp16: + logging.info("Forcing FP16.") FORCE_FP16 = True if lowvram_available: @@ -234,12 +231,12 @@ def is_nvidia(): if cpu_state == CPUState.MPS: vram_state = VRAMState.SHARED -print(f"Set vram state to: {vram_state.name}") +logging.info(f"Set vram state to: {vram_state.name}") -ALWAYS_VRAM_OFFLOAD = args.always_offload_from_vram +DISABLE_SMART_MEMORY = args.disable_smart_memory -if ALWAYS_VRAM_OFFLOAD: - print("Always offload VRAM") +if DISABLE_SMART_MEMORY: + logging.info("Disabling smart memory management") def get_torch_device_name(device): if hasattr(device, 'type'): @@ -257,11 +254,11 @@ def get_torch_device_name(device): return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device)) try: - print("Device:", get_torch_device_name(get_torch_device())) + logging.info("Device: {}".format(get_torch_device_name(get_torch_device()))) except: - print("Could not pick default device.") + logging.warning("Could not pick default device.") -print("VAE dtype:", VAE_DTYPE) +logging.info("VAE dtype: {}".format(VAE_DTYPE)) current_loaded_models = [] @@ -276,8 +273,9 @@ def module_size(module): class LoadedModel: def __init__(self, model): self.model = model - self.model_accelerated = False self.device = model.load_device + self.weights_loaded = False + self.real_model = None def model_memory(self): return self.model.model_size() @@ -288,55 +286,40 @@ def model_memory_required(self, device): else: return self.model_memory() - def model_load(self, lowvram_model_memory=0): - patch_model_to = None - if lowvram_model_memory == 0: - patch_model_to = self.device + def model_load(self, lowvram_model_memory=0, force_patch_weights=False): + patch_model_to = self.device self.model.model_patches_to(self.device) self.model.model_patches_to(self.model.model_dtype()) + load_weights = not self.weights_loaded + try: - self.real_model = self.model.patch_model(device_to=patch_model_to) #TODO: do something with loras and offloading to CPU + if lowvram_model_memory > 0 and load_weights: + self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights) + else: + self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights) except Exception as e: self.model.unpatch_model(self.model.offload_device) self.model_unload() raise e - 
if lowvram_model_memory > 0: - print("loading in lowvram mode", lowvram_model_memory/(1024 * 1024)) - mem_counter = 0 - for m in self.real_model.modules(): - if hasattr(m, "ldm_patched_cast_weights"): - m.prev_ldm_patched_cast_weights = m.ldm_patched_cast_weights - m.ldm_patched_cast_weights = True - module_mem = module_size(m) - if mem_counter + module_mem < lowvram_model_memory: - m.to(self.device) - mem_counter += module_mem - elif hasattr(m, "weight"): #only modules with ldm_patched_cast_weights can be set to lowvram mode - m.to(self.device) - mem_counter += module_size(m) - print("lowvram: loaded module regularly", m) - - self.model_accelerated = True - - if is_intel_xpu() and not args.disable_ipex_hijack: - self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True) + if is_intel_xpu() and not args.disable_ipex_optimize: + self.real_model = ipex.optimize(self.real_model.eval(), graph_mode=True, concat_linear=True) + self.weights_loaded = True return self.real_model - def model_unload(self): - if self.model_accelerated: - for m in self.real_model.modules(): - if hasattr(m, "prev_ldm_patched_cast_weights"): - m.ldm_patched_cast_weights = m.prev_ldm_patched_cast_weights - del m.prev_ldm_patched_cast_weights - - self.model_accelerated = False + def should_reload_model(self, force_patch_weights=False): + if force_patch_weights and self.model.lowvram_patch_counter > 0: + return True + return False - self.model.unpatch_model(self.model.offload_device) + def model_unload(self, unpatch_weights=True): + self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights) self.model.model_patches_to(self.model.offload_device) + self.weights_loaded = self.weights_loaded and not unpatch_weights + self.real_model = None def __eq__(self, other): return self.model is other.model @@ -344,31 +327,57 @@ def __eq__(self, other): def minimum_inference_memory(): return (1024 * 1024 * 1024) -def unload_model_clones(model): +def unload_model_clones(model, unload_weights_only=True, force_unload=True): to_unload = [] for i in range(len(current_loaded_models)): if model.is_clone(current_loaded_models[i].model): to_unload = [i] + to_unload + if len(to_unload) == 0: + return True + + same_weights = 0 for i in to_unload: - print("unload clone", i) - current_loaded_models.pop(i).model_unload() + if model.clone_has_same_weights(current_loaded_models[i].model): + same_weights += 1 + + if same_weights == len(to_unload): + unload_weight = False + else: + unload_weight = True + + if not force_unload: + if unload_weights_only and unload_weight == False: + return None + + for i in to_unload: + logging.debug("unload clone {} {}".format(i, unload_weight)) + current_loaded_models.pop(i).model_unload(unpatch_weights=unload_weight) + + return unload_weight def free_memory(memory_required, device, keep_loaded=[]): - unloaded_model = False + unloaded_model = [] + can_unload = [] + for i in range(len(current_loaded_models) -1, -1, -1): - if not ALWAYS_VRAM_OFFLOAD: - if get_free_memory(device) > memory_required: - break shift_model = current_loaded_models[i] if shift_model.device == device: if shift_model not in keep_loaded: - m = current_loaded_models.pop(i) - m.model_unload() - del m - unloaded_model = True + can_unload.append((sys.getrefcount(shift_model.model), shift_model.model_memory(), i)) + + for x in sorted(can_unload): + i = x[-1] + if not DISABLE_SMART_MEMORY: + if get_free_memory(device) > memory_required: + break + 
current_loaded_models[i].model_unload() + unloaded_model.append(i) + + for i in sorted(unloaded_model, reverse=True): + current_loaded_models.pop(i) - if unloaded_model: + if len(unloaded_model) > 0: soft_empty_cache() else: if vram_state != VRAMState.HIGH_VRAM: @@ -376,24 +385,36 @@ def free_memory(memory_required, device, keep_loaded=[]): if mem_free_torch > mem_free_total * 0.25: soft_empty_cache() -def load_models_gpu(models, memory_required=0): +def load_models_gpu(models, memory_required=0, force_patch_weights=False): global vram_state inference_memory = minimum_inference_memory() extra_mem = max(inference_memory, memory_required) + models = set(models) + models_to_load = [] models_already_loaded = [] for x in models: loaded_model = LoadedModel(x) + loaded = None - if loaded_model in current_loaded_models: - index = current_loaded_models.index(loaded_model) - current_loaded_models.insert(0, current_loaded_models.pop(index)) - models_already_loaded.append(loaded_model) - else: + try: + loaded_model_index = current_loaded_models.index(loaded_model) + except: + loaded_model_index = None + + if loaded_model_index is not None: + loaded = current_loaded_models[loaded_model_index] + if loaded.should_reload_model(force_patch_weights=force_patch_weights): #TODO: cleanup this model reload logic + current_loaded_models.pop(loaded_model_index).model_unload(unpatch_weights=True) + loaded = None + else: + models_already_loaded.append(loaded) + + if loaded is None: if hasattr(x, "model"): - print(f"Requested to load {x.model.__class__.__name__}") + logging.info(f"Requested to load {x.model.__class__.__name__}") models_to_load.append(loaded_model) if len(models_to_load) == 0: @@ -403,17 +424,22 @@ def load_models_gpu(models, memory_required=0): free_memory(extra_mem, d, models_already_loaded) return - print(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}") + logging.info(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}") total_memory_required = {} for loaded_model in models_to_load: - unload_model_clones(loaded_model.model) - total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device) + if unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) == True:#unload clones where the weights are different + total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device) for device in total_memory_required: if device != torch.device("cpu"): free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded) + for loaded_model in models_to_load: + weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False) #unload the rest of the clones where the weights can stay loaded + if weights_unloaded is not None: + loaded_model.weights_loaded = not weights_unloaded + for loaded_model in models_to_load: model = loaded_model.model torch_dev = model.load_device @@ -426,15 +452,13 @@ def load_models_gpu(models, memory_required=0): model_size = loaded_model.model_memory_required(torch_dev) current_free_mem = get_free_memory(torch_dev) lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 )) - if model_size > (current_free_mem - inference_memory): #only switch to lowvram if really necessary - vram_set_state = VRAMState.LOW_VRAM - else: 
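The rewritten branch here changes how the loader picks between a full load and low-VRAM mode: instead of flipping vram_set_state, it computes a byte budget up front (free VRAM minus roughly 1GB of headroom, divided by a 1.3 safety factor, never below 64MB) and zeroes it out when the whole model fits. Condensed into a standalone helper for clarity (the helper name is ours; the constants come straight from this hunk):

```python
def pick_lowvram_budget(model_size, current_free_mem, inference_memory, no_vram=False):
    MB = 1024 * 1024
    # default low-VRAM budget: free memory minus ~1GB headroom, with a 1.3x safety factor
    budget = int(max(64 * MB, (current_free_mem - 1024 * MB) / 1.3))
    if model_size <= (current_free_mem - inference_memory):
        budget = 0        # the model fits outright: load everything, skip low-VRAM mode
    if no_vram:
        budget = 64 * MB  # the NO_VRAM state pins the budget to the 64MB floor
    return budget
```

A budget of 0 means a normal full load; any positive value is passed to model_load and from there to patch_model_lowvram, which stops moving modules onto the GPU once the budget is spent.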
+ if model_size <= (current_free_mem - inference_memory): #only switch to lowvram if really necessary lowvram_model_memory = 0 if vram_set_state == VRAMState.NO_VRAM: lowvram_model_memory = 64 * 1024 * 1024 - cur_loaded_model = loaded_model.model_load(lowvram_model_memory) + cur_loaded_model = loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights) current_loaded_models.insert(0, loaded_model) return @@ -442,11 +466,15 @@ def load_models_gpu(models, memory_required=0): def load_model_gpu(model): return load_models_gpu([model]) -def cleanup_models(): +def cleanup_models(keep_clone_weights_loaded=False): to_delete = [] for i in range(len(current_loaded_models)): if sys.getrefcount(current_loaded_models[i].model) <= 2: - to_delete = [i] + to_delete + if not keep_clone_weights_loaded: + to_delete = [i] + to_delete + #TODO: find a less fragile way to do this. + elif sys.getrefcount(current_loaded_models[i].real_model) <= 3: #references from .real_model + the .model + to_delete = [i] + to_delete for i in to_delete: x = current_loaded_models.pop(i) @@ -478,7 +506,7 @@ def unet_inital_load_device(parameters, dtype): return torch_dev cpu_dev = torch.device("cpu") - if ALWAYS_VRAM_OFFLOAD: + if DISABLE_SMART_MEMORY: return cpu_dev model_size = dtype_size(dtype) * parameters @@ -490,45 +518,54 @@ def unet_inital_load_device(parameters, dtype): else: return cpu_dev -def unet_dtype(device=None, model_params=0): - if args.unet_in_bf16: +def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]): + if args.bf16_unet: return torch.bfloat16 - if args.unet_in_fp16: + if args.fp16_unet: return torch.float16 - if args.unet_in_fp8_e4m3fn: + if args.fp8_e4m3fn_unet: return torch.float8_e4m3fn - if args.unet_in_fp8_e5m2: + if args.fp8_e5m2_unet: return torch.float8_e5m2 - if should_use_fp16(device=device, model_params=model_params): - return torch.float16 + if should_use_fp16(device=device, model_params=model_params, manual_cast=True): + if torch.float16 in supported_dtypes: + return torch.float16 + if should_use_bf16(device, model_params=model_params, manual_cast=True): + if torch.bfloat16 in supported_dtypes: + return torch.bfloat16 return torch.float32 # None means no manual cast -def unet_manual_cast(weight_dtype, inference_device): +def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]): if weight_dtype == torch.float32: return None - fp16_supported = ldm_patched.modules.model_management.should_use_fp16(inference_device, prioritize_performance=False) + fp16_supported = should_use_fp16(inference_device, prioritize_performance=False) if fp16_supported and weight_dtype == torch.float16: return None - if fp16_supported: + bf16_supported = should_use_bf16(inference_device) + if bf16_supported and weight_dtype == torch.bfloat16: + return None + + if fp16_supported and torch.float16 in supported_dtypes: return torch.float16 + + elif bf16_supported and torch.bfloat16 in supported_dtypes: + return torch.bfloat16 else: return torch.float32 def text_encoder_offload_device(): - if args.always_gpu: + if args.gpu_only: return get_torch_device() else: return torch.device("cpu") def text_encoder_device(): - if args.always_gpu: + if args.gpu_only: return get_torch_device() elif vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.NORMAL_VRAM: - if is_intel_xpu(): - return torch.device("cpu") if should_use_fp16(prioritize_performance=False): return get_torch_device() else: @@ 
-537,36 +574,34 @@ def text_encoder_device(): return torch.device("cpu") def text_encoder_dtype(device=None): - if args.clip_in_fp8_e4m3fn: + if args.fp8_e4m3fn_text_enc: return torch.float8_e4m3fn - elif args.clip_in_fp8_e5m2: + elif args.fp8_e5m2_text_enc: return torch.float8_e5m2 - elif args.clip_in_fp16: + elif args.fp16_text_enc: return torch.float16 - elif args.clip_in_fp32: + elif args.fp32_text_enc: return torch.float32 if is_device_cpu(device): return torch.float16 - if should_use_fp16(device, prioritize_performance=False): - return torch.float16 - else: - return torch.float32 + return torch.float16 + def intermediate_device(): - if args.always_gpu: + if args.gpu_only: return get_torch_device() else: return torch.device("cpu") def vae_device(): - if args.vae_in_cpu: + if args.cpu_vae: return torch.device("cpu") return get_torch_device() def vae_offload_device(): - if args.always_gpu: + if args.gpu_only: return get_torch_device() else: return torch.device("cpu") @@ -594,8 +629,19 @@ def supports_dtype(device, dtype): #TODO def device_supports_non_blocking(device): if is_device_mps(device): return False #pytorch bug? mps doesn't support non blocking + if args.deterministic: #TODO: figure out why deterministic breaks non blocking from gpu to cpu (previews) + return False + if directml_enabled: + return False return True +def device_should_use_non_blocking(device): + if not device_supports_non_blocking(device): + return False + return False + # return True #TODO: figure out why this causes memory issues on Nvidia and possibly others + + def cast_to_device(tensor, device, dtype, copy=False): device_supports_cast = False if tensor.dtype == torch.float32 or tensor.dtype == torch.float16: @@ -606,7 +652,7 @@ def cast_to_device(tensor, device, dtype, copy=False): elif is_intel_xpu(): device_supports_cast = True - non_blocking = device_supports_non_blocking(device) + non_blocking = device_should_use_non_blocking(device) if device_supports_cast: if copy: @@ -649,6 +695,18 @@ def pytorch_attention_flash_attention(): return True return False +def force_upcast_attention_dtype(): + upcast = args.force_upcast_attention + try: + if platform.mac_ver()[0] in ['14.5']: #black image bug on OSX Sonoma 14.5 + upcast = True + except: + pass + if upcast: + return torch.float32 + else: + return None + def get_free_memory(dev=None, torch_free_too=False): global directml_enabled if dev is None: @@ -664,10 +722,10 @@ def get_free_memory(dev=None, torch_free_too=False): elif is_intel_xpu(): stats = torch.xpu.memory_stats(dev) mem_active = stats['active_bytes.all.current'] - mem_allocated = stats['allocated_bytes.all.current'] mem_reserved = stats['reserved_bytes.all.current'] mem_free_torch = mem_reserved - mem_active - mem_free_total = torch.xpu.get_device_properties(dev).total_memory - mem_allocated + mem_free_xpu = torch.xpu.get_device_properties(dev).total_memory - mem_reserved + mem_free_total = mem_free_xpu + mem_free_torch else: stats = torch.cuda.memory_stats(dev) mem_active = stats['active_bytes.all.current'] @@ -689,19 +747,22 @@ def mps_mode(): global cpu_state return cpu_state == CPUState.MPS -def is_device_cpu(device): +def is_device_type(device, type): if hasattr(device, 'type'): - if (device.type == 'cpu'): + if (device.type == type): return True return False +def is_device_cpu(device): + return is_device_type(device, 'cpu') + def is_device_mps(device): - if hasattr(device, 'type'): - if (device.type == 'mps'): - return True - return False + return is_device_type(device, 'mps') + +def 
is_device_cuda(device): + return is_device_type(device, 'cuda') -def should_use_fp16(device=None, model_params=0, prioritize_performance=True): +def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False): global directml_enabled if device is not None: @@ -711,9 +772,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True): if FORCE_FP16: return True - if device is not None: #TODO + if device is not None: if is_device_mps(device): - return False + return True if FORCE_FP32: return False @@ -721,16 +782,22 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True): if directml_enabled: return False - if cpu_mode() or mps_mode(): - return False #TODO ? + if mps_mode(): + return True + + if cpu_mode(): + return False if is_intel_xpu(): return True - if torch.cuda.is_bf16_supported(): + if torch.version.hip: return True props = torch.cuda.get_device_properties("cuda") + if props.major >= 8: + return True + if props.major < 6: return False @@ -738,12 +805,12 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True): #FP16 is confirmed working on a 1080 (GP104) but it's a bit slower than FP32 so it should only be enabled #when the model doesn't actually fit on the card #TODO: actually test if GP106 and others have the same type of behavior - nvidia_10_series = ["1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050"] + nvidia_10_series = ["1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050", "p40", "p100", "p6", "p4"] for x in nvidia_10_series: if x in props.name.lower(): fp16_works = True - if fp16_works: + if fp16_works or manual_cast: free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory()) if (not prioritize_performance) or model_params * 4 > free_model_memory: return True @@ -759,6 +826,43 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True): return True +def should_use_bf16(device=None, model_params=0, prioritize_performance=True, manual_cast=False): + if device is not None: + if is_device_cpu(device): #TODO ? 
bf16 works on CPU but is extremely slow + return False + + if device is not None: #TODO not sure about mps bf16 support + if is_device_mps(device): + return False + + if FORCE_FP32: + return False + + if directml_enabled: + return False + + if cpu_mode() or mps_mode(): + return False + + if is_intel_xpu(): + return True + + if device is None: + device = torch.device("cuda") + + props = torch.cuda.get_device_properties(device) + if props.major >= 8: + return True + + bf16_works = torch.cuda.is_bf16_supported() + + if bf16_works or manual_cast: + free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory()) + if (not prioritize_performance) or model_params * 4 > free_model_memory: + return True + + return False + def soft_empty_cache(force=False): global cpu_state if cpu_state == CPUState.MPS: @@ -775,6 +879,7 @@ def unload_all_models(): def resolve_lowvram_weight(weight, model, key): #TODO: remove + print("WARNING: The comfy.model_management.resolve_lowvram_weight function will be removed soon, please stop using it.") return weight #TODO: might be cleaner to put this somewhere else diff --git a/ldm_patched/modules/model_patcher.py b/ldm_patched/modules/model_patcher.py index dd816e52e..d42fdd57c 100644 --- a/ldm_patched/modules/model_patcher.py +++ b/ldm_patched/modules/model_patcher.py @@ -1,9 +1,55 @@ import torch import copy import inspect +import logging +import uuid import ldm_patched.modules.utils import ldm_patched.modules.model_management +from ldm_patched.modules.types import UnetWrapperFunction + + +def weight_decompose(dora_scale, weight, lora_diff, alpha, strength): + dora_scale = ldm_patched.modules.model_management.cast_to_device(dora_scale, weight.device, torch.float32) + lora_diff *= alpha + weight_calc = weight + lora_diff.type(weight.dtype) + weight_norm = ( + weight_calc.transpose(0, 1) + .reshape(weight_calc.shape[1], -1) + .norm(dim=1, keepdim=True) + .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1)) + .transpose(0, 1) + ) + + weight_calc *= (dora_scale / weight_norm).type(weight.dtype) + if strength != 1.0: + weight_calc -= weight + weight += strength * (weight_calc) + else: + weight[:] = weight_calc + return weight + + +def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None): + to = model_options["transformer_options"].copy() + + if "patches_replace" not in to: + to["patches_replace"] = {} + else: + to["patches_replace"] = to["patches_replace"].copy() + + if name not in to["patches_replace"]: + to["patches_replace"][name] = {} + else: + to["patches_replace"][name] = to["patches_replace"][name].copy() + + if transformer_index is not None: + block = (block_name, number, transformer_index) + else: + block = (block_name, number) + to["patches_replace"][name][block] = patch + model_options["transformer_options"] = to + return model_options class ModelPatcher: def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False): @@ -23,6 +69,9 @@ def __init__(self, model, load_device, offload_device, size=0, current_device=No self.current_device = current_device self.weight_inplace_update = weight_inplace_update + self.model_lowvram = False + self.lowvram_patch_counter = 0 + self.patches_uuid = uuid.uuid4() def model_size(self): if self.size > 0: @@ -37,10 +86,13 @@ def clone(self): n.patches = {} for k in self.patches: n.patches[k] = self.patches[k][:] + n.patches_uuid = self.patches_uuid n.object_patches = self.object_patches.copy() n.model_options = 
copy.deepcopy(self.model_options) n.model_keys = self.model_keys + n.backup = self.backup + n.object_patches_backup = self.object_patches_backup return n def is_clone(self, other): @@ -48,6 +100,19 @@ def is_clone(self, other): return True return False + def clone_has_same_weights(self, clone): + if not self.is_clone(clone): + return False + + if len(self.patches) == 0 and len(clone.patches) == 0: + return True + + if self.patches_uuid == clone.patches_uuid: + if len(self.patches) != len(clone.patches): + logging.warning("WARNING: something went wrong, same patch uuid but different length of patches.") + else: + return True + def memory_required(self, input_shape): return self.model.memory_required(input_shape=input_shape) @@ -64,9 +129,12 @@ def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_op if disable_cfg1_optimization: self.model_options["disable_cfg1_optimization"] = True - def set_model_unet_function_wrapper(self, unet_wrapper_function): + def set_model_unet_function_wrapper(self, unet_wrapper_function: UnetWrapperFunction): self.model_options["model_function_wrapper"] = unet_wrapper_function + def set_model_denoise_mask_function(self, denoise_mask_function): + self.model_options["denoise_mask_function"] = denoise_mask_function + def set_model_patch(self, patch, name): to = self.model_options["transformer_options"] if "patches" not in to: @@ -74,16 +142,7 @@ def set_model_patch(self, patch, name): to["patches"][name] = to["patches"].get(name, []) + [patch] def set_model_patch_replace(self, patch, name, block_name, number, transformer_index=None): - to = self.model_options["transformer_options"] - if "patches_replace" not in to: - to["patches_replace"] = {} - if name not in to["patches_replace"]: - to["patches_replace"][name] = {} - if transformer_index is not None: - block = (block_name, number, transformer_index) - else: - block = (block_name, number) - to["patches_replace"][name][block] = patch + self.model_options = set_model_options_patch_replace(self.model_options, patch, name, block_name, number, transformer_index=transformer_index) def set_model_attn1_patch(self, patch): self.set_model_patch(patch, "attn1_patch") @@ -115,6 +174,15 @@ def set_model_output_block_patch(self, patch): def add_object_patch(self, name, obj): self.object_patches[name] = obj + def get_model_object(self, name): + if name in self.object_patches: + return self.object_patches[name] + else: + if name in self.object_patches_backup: + return self.object_patches_backup[name] + else: + return ldm_patched.modules.utils.get_attr(self.model, name) + def model_patches_to(self, device): to = self.model_options["transformer_options"] if "patches" in to: @@ -149,6 +217,7 @@ def add_patches(self, patches, strength_patch=1.0, strength_model=1.0): current_patches.append((strength_patch, patches[k], strength_model)) self.patches[k] = current_patches + self.patches_uuid = uuid.uuid4() return list(p) def get_key_patches(self, filter_prefix=None): @@ -174,37 +243,41 @@ def model_state_dict(self, filter_prefix=None): sd.pop(k) return sd + def patch_weight_to_device(self, key, device_to=None): + if key not in self.patches: + return + + weight = ldm_patched.modules.utils.get_attr(self.model, key) + + inplace_update = self.weight_inplace_update + + if key not in self.backup: + self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update) + + if device_to is not None: + temp_weight = ldm_patched.modules.model_management.cast_to_device(weight, device_to, torch.float32, copy=True) + 
else: + temp_weight = weight.to(torch.float32, copy=True) + out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype) + if inplace_update: + ldm_patched.modules.utils.copy_to_param(self.model, key, out_weight) + else: + ldm_patched.modules.utils.set_attr_param(self.model, key, out_weight) + def patch_model(self, device_to=None, patch_weights=True): for k in self.object_patches: - old = getattr(self.model, k) + old = ldm_patched.modules.utils.set_attr(self.model, k, self.object_patches[k]) if k not in self.object_patches_backup: self.object_patches_backup[k] = old - setattr(self.model, k, self.object_patches[k]) if patch_weights: model_sd = self.model_state_dict() for key in self.patches: if key not in model_sd: - print("could not patch. key doesn't exist in model:", key) + logging.warning("could not patch. key doesn't exist in model: {}".format(key)) continue - weight = model_sd[key] - - inplace_update = self.weight_inplace_update - - if key not in self.backup: - self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update) - - if device_to is not None: - temp_weight = ldm_patched.modules.model_management.cast_to_device(weight, device_to, torch.float32, copy=True) - else: - temp_weight = weight.to(torch.float32, copy=True) - out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype) - if inplace_update: - ldm_patched.modules.utils.copy_to_param(self.model, key, out_weight) - else: - ldm_patched.modules.utils.set_attr(self.model, key, out_weight) - del temp_weight + self.patch_weight_to_device(key, device_to) if device_to is not None: self.model.to(device_to) @@ -212,9 +285,60 @@ def patch_model(self, device_to=None, patch_weights=True): return self.model + def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False): + self.patch_model(device_to, patch_weights=False) + + logging.info("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024))) + class LowVramPatch: + def __init__(self, key, model_patcher): + self.key = key + self.model_patcher = model_patcher + def __call__(self, weight): + return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key) + + mem_counter = 0 + patch_counter = 0 + for n, m in self.model.named_modules(): + lowvram_weight = False + if hasattr(m, "comfy_cast_weights"): + module_mem = ldm_patched.modules.model_management.module_size(m) + if mem_counter + module_mem >= lowvram_model_memory: + lowvram_weight = True + + weight_key = "{}.weight".format(n) + bias_key = "{}.bias".format(n) + + if lowvram_weight: + if weight_key in self.patches: + if force_patch_weights: + self.patch_weight_to_device(weight_key) + else: + m.weight_function = LowVramPatch(weight_key, self) + patch_counter += 1 + if bias_key in self.patches: + if force_patch_weights: + self.patch_weight_to_device(bias_key) + else: + m.bias_function = LowVramPatch(bias_key, self) + patch_counter += 1 + + m.prev_comfy_cast_weights = m.comfy_cast_weights + m.comfy_cast_weights = True + else: + if hasattr(m, "weight"): + self.patch_weight_to_device(weight_key, device_to) + self.patch_weight_to_device(bias_key, device_to) + m.to(device_to) + mem_counter += ldm_patched.modules.model_management.module_size(m) + logging.debug("lowvram: loaded module regularly {}".format(m)) + + self.model_lowvram = True + self.lowvram_patch_counter = patch_counter + return self.model + def calculate_weight(self, patches, weight, key): for p in patches: - alpha = p[0] + 
strength = p[0] v = p[1] strength_model = p[2] @@ -232,25 +356,33 @@ def calculate_weight(self, patches, weight, key): if patch_type == "diff": w1 = v[0] - if alpha != 0.0: + if strength != 0.0: if w1.shape != weight.shape: - print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape)) + logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape)) else: - weight += alpha * ldm_patched.modules.model_management.cast_to_device(w1, weight.device, weight.dtype) + weight += strength * ldm_patched.modules.model_management.cast_to_device(w1, weight.device, weight.dtype) elif patch_type == "lora": #lora/locon mat1 = ldm_patched.modules.model_management.cast_to_device(v[0], weight.device, torch.float32) mat2 = ldm_patched.modules.model_management.cast_to_device(v[1], weight.device, torch.float32) + dora_scale = v[4] if v[2] is not None: - alpha *= v[2] / mat2.shape[0] + alpha = v[2] / mat2.shape[0] + else: + alpha = 1.0 + if v[3] is not None: #locon mid weights, hopefully the math is fine because I didn't properly test it mat3 = ldm_patched.modules.model_management.cast_to_device(v[3], weight.device, torch.float32) final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]] mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1) try: - weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype) + lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape) + if dora_scale is not None: + weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength) + else: + weight += ((strength * alpha) * lora_diff).type(weight.dtype) except Exception as e: - print("ERROR", key, e) + logging.error("ERROR {} {} {}".format(patch_type, key, e)) elif patch_type == "lokr": w1 = v[0] w2 = v[1] @@ -259,6 +391,7 @@ def calculate_weight(self, patches, weight, key): w2_a = v[5] w2_b = v[6] t2 = v[7] + dora_scale = v[8] dim = None if w1 is None: @@ -284,19 +417,29 @@ def calculate_weight(self, patches, weight, key): if len(w2.shape) == 4: w1 = w1.unsqueeze(2).unsqueeze(2) if v[2] is not None and dim is not None: - alpha *= v[2] / dim + alpha = v[2] / dim + else: + alpha = 1.0 try: - weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype) + lora_diff = torch.kron(w1, w2).reshape(weight.shape) + if dora_scale is not None: + weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength) + else: + weight += ((strength * alpha) * lora_diff).type(weight.dtype) except Exception as e: - print("ERROR", key, e) + logging.error("ERROR {} {} {}".format(patch_type, key, e)) elif patch_type == "loha": w1a = v[0] w1b = v[1] if v[2] is not None: - alpha *= v[2] / w1b.shape[0] + alpha = v[2] / w1b.shape[0] + else: + alpha = 1.0 + w2a = v[3] w2b = v[4] + dora_scale = v[7] if v[5] is not None: #cp decomposition t1 = v[5] t2 = v[6] @@ -316,42 +459,69 @@ def calculate_weight(self, patches, weight, key): ldm_patched.modules.model_management.cast_to_device(w2b, weight.device, torch.float32)) try: - weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype) + lora_diff = (m1 * m2).reshape(weight.shape) + if dora_scale is not None: + weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength) + else: + weight += ((strength * alpha) * lora_diff).type(weight.dtype) except Exception as e: - print("ERROR", key, 
e) + logging.error("ERROR {} {} {}".format(patch_type, key, e)) elif patch_type == "glora": if v[4] is not None: - alpha *= v[4] / v[0].shape[0] + alpha = v[4] / v[0].shape[0] + else: + alpha = 1.0 + + dora_scale = v[5] a1 = ldm_patched.modules.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32) a2 = ldm_patched.modules.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32) b1 = ldm_patched.modules.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32) b2 = ldm_patched.modules.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32) - weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype) + try: + lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)).reshape(weight.shape) + if dora_scale is not None: + weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength) + else: + weight += ((strength * alpha) * lora_diff).type(weight.dtype) + except Exception as e: + logging.error("ERROR {} {} {}".format(patch_type, key, e)) else: - print("patch type not recognized", patch_type, key) + logging.warning("patch type not recognized {} {}".format(patch_type, key)) return weight - def unpatch_model(self, device_to=None): - keys = list(self.backup.keys()) + def unpatch_model(self, device_to=None, unpatch_weights=True): + if unpatch_weights: + if self.model_lowvram: + for m in self.model.modules(): + if hasattr(m, "prev_comfy_cast_weights"): + m.comfy_cast_weights = m.prev_comfy_cast_weights + del m.prev_comfy_cast_weights + m.weight_function = None + m.bias_function = None - if self.weight_inplace_update: - for k in keys: - ldm_patched.modules.utils.copy_to_param(self.model, k, self.backup[k]) - else: - for k in keys: - ldm_patched.modules.utils.set_attr(self.model, k, self.backup[k]) + self.model_lowvram = False + self.lowvram_patch_counter = 0 - self.backup = {} + keys = list(self.backup.keys()) - if device_to is not None: - self.model.to(device_to) - self.current_device = device_to + if self.weight_inplace_update: + for k in keys: + ldm_patched.modules.utils.copy_to_param(self.model, k, self.backup[k]) + else: + for k in keys: + ldm_patched.modules.utils.set_attr_param(self.model, k, self.backup[k]) + + self.backup.clear() + + if device_to is not None: + self.model.to(device_to) + self.current_device = device_to keys = list(self.object_patches_backup.keys()) for k in keys: - setattr(self.model, k, self.object_patches_backup[k]) + ldm_patched.modules.utils.set_attr(self.model, k, self.object_patches_backup[k]) - self.object_patches_backup = {} + self.object_patches_backup.clear() diff --git a/ldm_patched/modules/model_sampling.py b/ldm_patched/modules/model_sampling.py index 8971b4e6e..5bad59c4a 100644 --- a/ldm_patched/modules/model_sampling.py +++ b/ldm_patched/modules/model_sampling.py @@ -72,7 +72,6 @@ def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5 alphas_cumprod = torch.tensor(np.cumprod(alphas, axis=0), dtype=torch.float32) self.set_sigmas(sigmas) - self.set_alphas_cumprod(alphas_cumprod.float()) def set_sigmas(self, sigmas): self.register_buffer('sigmas', sigmas.float()) @@ -206,4 +205,4 @@ def percent_to_sigma(self, percent): return 0.0 percent = 1.0 - percent - return self.sigma(torch.tensor(percent)) \ No newline at end of file + return 
self.sigma(torch.tensor(percent)) diff --git a/ldm_patched/modules/ops.py b/ldm_patched/modules/ops.py index 2d7fa3776..80ac5f1ba 100644 --- a/ldm_patched/modules/ops.py +++ b/ldm_patched/modules/ops.py @@ -1,89 +1,135 @@ +""" + This file is part of ComfyUI. + Copyright (C) 2024 Stability AI + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see <https://www.gnu.org/licenses/>. +""" + import torch import ldm_patched.modules.model_management def cast_bias_weight(s, input): bias = None - non_blocking = ldm_patched.modules.model_management.device_supports_non_blocking(input.device) + non_blocking = ldm_patched.modules.model_management.device_should_use_non_blocking(input.device) if s.bias is not None: bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking) + if s.bias_function is not None: + bias = s.bias_function(bias) weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking) + if s.weight_function is not None: + weight = s.weight_function(weight) return weight, bias +class CastWeightBiasOp: + comfy_cast_weights = False + weight_function = None + bias_function = None class disable_weight_init: - class Linear(torch.nn.Linear): - ldm_patched_cast_weights = False + class Linear(torch.nn.Linear, CastWeightBiasOp): def reset_parameters(self): return None - def forward_ldm_patched_cast_weights(self, input): + def forward_comfy_cast_weights(self, input): weight, bias = cast_bias_weight(self, input) return torch.nn.functional.linear(input, weight, bias) def forward(self, *args, **kwargs): - if self.ldm_patched_cast_weights: - return self.forward_ldm_patched_cast_weights(*args, **kwargs) + if self.comfy_cast_weights: + return self.forward_comfy_cast_weights(*args, **kwargs) else: return super().forward(*args, **kwargs) - class Conv2d(torch.nn.Conv2d): - ldm_patched_cast_weights = False + class Conv2d(torch.nn.Conv2d, CastWeightBiasOp): def reset_parameters(self): return None - def forward_ldm_patched_cast_weights(self, input): + def forward_comfy_cast_weights(self, input): weight, bias = cast_bias_weight(self, input) return self._conv_forward(input, weight, bias) def forward(self, *args, **kwargs): - if self.ldm_patched_cast_weights: - return self.forward_ldm_patched_cast_weights(*args, **kwargs) + if self.comfy_cast_weights: + return self.forward_comfy_cast_weights(*args, **kwargs) else: return super().forward(*args, **kwargs) - class Conv3d(torch.nn.Conv3d): - ldm_patched_cast_weights = False + class Conv3d(torch.nn.Conv3d, CastWeightBiasOp): def reset_parameters(self): return None - def forward_ldm_patched_cast_weights(self, input): + def forward_comfy_cast_weights(self, input): weight, bias = cast_bias_weight(self, input) return self._conv_forward(input, weight, bias) def forward(self, *args, **kwargs): - if self.ldm_patched_cast_weights: - return self.forward_ldm_patched_cast_weights(*args, **kwargs) + if self.comfy_cast_weights: + return self.forward_comfy_cast_weights(*args, **kwargs) else: return super().forward(*args, **kwargs)
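+    # NOTE: each op below repeats the same cast-on-forward pattern: when
+    # comfy_cast_weights is set, cast_bias_weight() copies weight and bias to
+    # the input's device/dtype on every call and applies the optional
+    # weight_function/bias_function hooks; this is the entry point LowVramPatch
+    # (see patch_model_lowvram in model_patcher.py above) uses to apply LoRA
+    # patches lazily. A minimal usage sketch (illustrative only, not part of
+    # the patched file; assumes a half-precision input tensor x):
+    #
+    #     layer = manual_cast.Linear(768, 320)  # comfy_cast_weights = True
+    #     y = layer(x)                          # weight/bias cast to x.device/x.dtype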
- class GroupNorm(torch.nn.GroupNorm): - ldm_patched_cast_weights = False + class GroupNorm(torch.nn.GroupNorm, CastWeightBiasOp): def reset_parameters(self): return None - def forward_ldm_patched_cast_weights(self, input): + def forward_comfy_cast_weights(self, input): weight, bias = cast_bias_weight(self, input) return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps) def forward(self, *args, **kwargs): - if self.ldm_patched_cast_weights: - return self.forward_ldm_patched_cast_weights(*args, **kwargs) + if self.comfy_cast_weights: + return self.forward_comfy_cast_weights(*args, **kwargs) else: return super().forward(*args, **kwargs) - class LayerNorm(torch.nn.LayerNorm): - ldm_patched_cast_weights = False + class LayerNorm(torch.nn.LayerNorm, CastWeightBiasOp): def reset_parameters(self): return None - def forward_ldm_patched_cast_weights(self, input): - weight, bias = cast_bias_weight(self, input) + def forward_comfy_cast_weights(self, input): + if self.weight is not None: + weight, bias = cast_bias_weight(self, input) + else: + weight = None + bias = None return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps) def forward(self, *args, **kwargs): - if self.ldm_patched_cast_weights: - return self.forward_ldm_patched_cast_weights(*args, **kwargs) + if self.comfy_cast_weights: + return self.forward_comfy_cast_weights(*args, **kwargs) + else: + return super().forward(*args, **kwargs) + + class ConvTranspose2d(torch.nn.ConvTranspose2d, CastWeightBiasOp): + def reset_parameters(self): + return None + + def forward_comfy_cast_weights(self, input, output_size=None): + num_spatial_dims = 2 + output_padding = self._output_padding( + input, output_size, self.stride, self.padding, self.kernel_size, + num_spatial_dims, self.dilation) + + weight, bias = cast_bias_weight(self, input) + return torch.nn.functional.conv_transpose2d( + input, weight, bias, self.stride, self.padding, + output_padding, self.groups, self.dilation) + + def forward(self, *args, **kwargs): + if self.comfy_cast_weights: + return self.forward_comfy_cast_weights(*args, **kwargs) else: return super().forward(*args, **kwargs) @@ -99,16 +145,19 @@ def conv_nd(s, dims, *args, **kwargs): class manual_cast(disable_weight_init): class Linear(disable_weight_init.Linear): - ldm_patched_cast_weights = True + comfy_cast_weights = True class Conv2d(disable_weight_init.Conv2d): - ldm_patched_cast_weights = True + comfy_cast_weights = True class Conv3d(disable_weight_init.Conv3d): - ldm_patched_cast_weights = True + comfy_cast_weights = True class GroupNorm(disable_weight_init.GroupNorm): - ldm_patched_cast_weights = True + comfy_cast_weights = True class LayerNorm(disable_weight_init.LayerNorm): - ldm_patched_cast_weights = True + comfy_cast_weights = True + + class ConvTranspose2d(disable_weight_init.ConvTranspose2d): + comfy_cast_weights = True diff --git a/ldm_patched/modules/sample.py b/ldm_patched/modules/sample.py index 0f4839503..9a37ac85c 100644 --- a/ldm_patched/modules/sample.py +++ b/ldm_patched/modules/sample.py @@ -1,10 +1,9 @@ import torch import ldm_patched.modules.model_management import ldm_patched.modules.samplers -import ldm_patched.modules.conds import ldm_patched.modules.utils -import math import numpy as np +import logging def prepare_noise(latent_image, seed, noise_inds=None): """ @@ -25,94 +24,21 @@ def prepare_noise(latent_image, seed, noise_inds=None): noises = torch.cat(noises, axis=0) return noises -def prepare_mask(noise_mask, shape, device): - 
"""ensures noise mask is of proper dimensions""" - noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear") - noise_mask = torch.cat([noise_mask] * shape[1], dim=1) - noise_mask = ldm_patched.modules.utils.repeat_to_batch_size(noise_mask, shape[0]) - noise_mask = noise_mask.to(device) - return noise_mask - -def get_models_from_cond(cond, model_type): - models = [] - for c in cond: - if model_type in c: - models += [c[model_type]] - return models - -def convert_cond(cond): - out = [] - for c in cond: - temp = c[1].copy() - model_conds = temp.get("model_conds", {}) - if c[0] is not None: - model_conds["c_crossattn"] = ldm_patched.modules.conds.CONDCrossAttn(c[0]) #TODO: remove - temp["cross_attn"] = c[0] - temp["model_conds"] = model_conds - out.append(temp) - return out - -def get_additional_models(positive, negative, dtype): - """loads additional models in positive and negative conditioning""" - control_nets = set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control")) - - inference_memory = 0 - control_models = [] - for m in control_nets: - control_models += m.get_models() - inference_memory += m.inference_memory_requirements(dtype) - - gligen = get_models_from_cond(positive, "gligen") + get_models_from_cond(negative, "gligen") - gligen = [x[1] for x in gligen] - models = control_models + gligen - return models, inference_memory - -def cleanup_additional_models(models): - """cleanup additional models that were loaded""" - for m in models: - if hasattr(m, 'cleanup'): - m.cleanup() - def prepare_sampling(model, noise_shape, positive, negative, noise_mask): - device = model.load_device - positive = convert_cond(positive) - negative = convert_cond(negative) - - if noise_mask is not None: - noise_mask = prepare_mask(noise_mask, noise_shape, device) - - real_model = None - models, inference_memory = get_additional_models(positive, negative, model.model_dtype()) - ldm_patched.modules.model_management.load_models_gpu([model] + models, model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory) - real_model = model.model - - return real_model, positive, negative, noise_mask, models + logging.warning("Warning: ldm_patched.modules.sample.prepare_sampling isn't used anymore and can be removed") + return model, positive, negative, noise_mask, [] +def cleanup_additional_models(models): + logging.warning("Warning: ldm_patched.modules.sample.cleanup_additional_models isn't used anymore and can be removed") def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None): - real_model, positive_copy, negative_copy, noise_mask, models = prepare_sampling(model, noise.shape, positive, negative, noise_mask) - - noise = noise.to(model.load_device) - latent_image = latent_image.to(model.load_device) + sampler = ldm_patched.modules.samplers.KSampler(model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options) - sampler = ldm_patched.modules.samplers.KSampler(real_model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options) - - samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, 
latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed) + samples = sampler.sample(noise, positive, negative, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed) samples = samples.to(ldm_patched.modules.model_management.intermediate_device()) - - cleanup_additional_models(models) - cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control"))) return samples def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None): - real_model, positive_copy, negative_copy, noise_mask, models = prepare_sampling(model, noise.shape, positive, negative, noise_mask) - noise = noise.to(model.load_device) - latent_image = latent_image.to(model.load_device) - sigmas = sigmas.to(model.load_device) - - samples = ldm_patched.modules.samplers.sample(real_model, noise, positive_copy, negative_copy, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed) + samples = ldm_patched.modules.samplers.sample(model, noise, positive, negative, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed) samples = samples.to(ldm_patched.modules.model_management.intermediate_device()) - cleanup_additional_models(models) - cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control"))) return samples - diff --git a/ldm_patched/modules/sampler_helpers.py b/ldm_patched/modules/sampler_helpers.py new file mode 100644 index 000000000..6e9e1df02 --- /dev/null +++ b/ldm_patched/modules/sampler_helpers.py @@ -0,0 +1,77 @@ +import torch +import ldm_patched.modules.model_management +import ldm_patched.modules.conds +import ldm_patched.modules.utils + +def prepare_mask(noise_mask, shape, device): + """ensures noise mask is of proper dimensions""" + noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear") + noise_mask = torch.cat([noise_mask] * shape[1], dim=1) + noise_mask = ldm_patched.modules.utils.repeat_to_batch_size(noise_mask, shape[0]) + noise_mask = noise_mask.to(device) + return noise_mask + +def get_models_from_cond(cond, model_type): + models = [] + for c in cond: + if model_type in c: + models += [c[model_type]] + return models + +def convert_cond(cond): + out = [] + for c in cond: + temp = c[1].copy() + model_conds = temp.get("model_conds", {}) + if c[0] is not None: + model_conds["c_crossattn"] = ldm_patched.modules.conds.CONDCrossAttn(c[0]) #TODO: remove + temp["cross_attn"] = c[0] + temp["model_conds"] = model_conds + out.append(temp) + return out + +def get_additional_models(conds, dtype): + """loads additional models in conditioning""" + cnets = [] + gligen = [] + + for k in conds: + cnets += get_models_from_cond(conds[k], "control") + gligen += get_models_from_cond(conds[k], "gligen") + + control_nets = set(cnets) + + inference_memory = 0 + control_models = [] + for m in
control_nets: + control_models += m.get_models() + inference_memory += m.inference_memory_requirements(dtype) + + gligen = [x[1] for x in gligen] + models = control_models + gligen + return models, inference_memory + +def cleanup_additional_models(models): + """cleanup additional models that were loaded""" + for m in models: + if hasattr(m, 'cleanup'): + m.cleanup() + + +def prepare_sampling(model, noise_shape, conds): + device = model.load_device + real_model = None + models, inference_memory = get_additional_models(conds, model.model_dtype()) + ldm_patched.modules.model_management.load_models_gpu([model] + models, model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory) + real_model = model.model + + return real_model, conds, models + +def cleanup_models(conds, models): + cleanup_additional_models(models) + + control_cleanup = [] + for k in conds: + control_cleanup += get_models_from_cond(conds[k], "control") + + cleanup_additional_models(set(control_cleanup)) diff --git a/ldm_patched/modules/samplers.py b/ldm_patched/modules/samplers.py index 05b4b3174..631194702 100644 --- a/ldm_patched/modules/samplers.py +++ b/ldm_patched/modules/samplers.py @@ -4,6 +4,8 @@ import collections from ldm_patched.modules import model_management import math +import logging +import ldm_patched.modules.sampler_helpers def get_area_and_mult(conds, x_in, timestep_in): area = (x_in.shape[2], x_in.shape[3], 0, 0) @@ -32,7 +34,7 @@ def get_area_and_mult(conds, x_in, timestep_in): mask = conds['mask'] assert(mask.shape[1] == x_in.shape[2]) assert(mask.shape[2] == x_in.shape[3]) - mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength + mask = mask[:input_x.shape[0],area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1) else: mask = torch.ones_like(input_x) @@ -126,30 +128,23 @@ def cond_cat(c_list): return out -def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): - out_cond = torch.zeros_like(x_in) - out_count = torch.ones_like(x_in) * 1e-37 - - out_uncond = torch.zeros_like(x_in) - out_uncond_count = torch.ones_like(x_in) * 1e-37 - - COND = 0 - UNCOND = 1 - +def calc_cond_batch(model, conds, x_in, timestep, model_options): + out_conds = [] + out_counts = [] to_run = [] - for x in cond: - p = get_area_and_mult(x, x_in, timestep) - if p is None: - continue - to_run += [(p, COND)] - if uncond is not None: - for x in uncond: - p = get_area_and_mult(x, x_in, timestep) - if p is None: - continue + for i in range(len(conds)): + out_conds.append(torch.zeros_like(x_in)) + out_counts.append(torch.ones_like(x_in) * 1e-37) - to_run += [(p, UNCOND)] + cond = conds[i] + if cond is not None: + for x in cond: + p = get_area_and_mult(x, x_in, timestep) + if p is None: + continue + + to_run += [(p, i)] while len(to_run) > 0: first = to_run[0] @@ -208,6 +203,7 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): cur_patches[p] = cur_patches[p] + patches[p] else: cur_patches[p] = patches[p] + transformer_options["patches"] = cur_patches else: transformer_options["patches"] = patches @@ -220,71 +216,66 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks) else: output = model.apply_model(input_x, timestep_, 
**c).chunk(batch_chunks) - del input_x for o in range(batch_chunks): - if cond_or_uncond[o] == COND: - out_cond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o] - out_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o] - else: - out_uncond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o] - out_uncond_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o] - del mult + cond_index = cond_or_uncond[o] + out_conds[cond_index][:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o] + out_counts[cond_index][:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o] + + for i in range(len(out_conds)): + out_conds[i] /= out_counts[i] - out_cond /= out_count - del out_count - out_uncond /= out_uncond_count - del out_uncond_count - return out_cond, out_uncond + return out_conds + +def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): #TODO: remove + logging.warning("WARNING: The ldm_patched.modules.samplers.calc_cond_uncond_batch function is deprecated, please use the calc_cond_batch one instead.") + return tuple(calc_cond_batch(model, [cond, uncond], x_in, timestep, model_options)) + +def cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_options={}, cond=None, uncond=None): + if "sampler_cfg_function" in model_options: + args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep, + "cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options} + cfg_result = x - model_options["sampler_cfg_function"](args) + else: + cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale + + for fn in model_options.get("sampler_post_cfg_function", []): + args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred, + "sigma": timestep, "model_options": model_options, "input": x} + cfg_result = fn(args) + + return cfg_result #The main sampling function shared by all the samplers #Returns denoised def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None): - if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False: - uncond_ = None - else: - uncond_ = uncond - - cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond_, x, timestep, model_options) - if "sampler_cfg_function" in model_options: - args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep, - "cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options} - cfg_result = x - model_options["sampler_cfg_function"](args) - else: - cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale + if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False: + uncond_ = None + else: + uncond_ = uncond - for fn in model_options.get("sampler_post_cfg_function", []): - args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred, - "sigma": timestep, "model_options": model_options, "input": x} - cfg_result = fn(args) + conds = [cond, uncond_] + out =
calc_cond_batch(model, conds, x, timestep, model_options) + return cfg_function(model, out[0], out[1], cond_scale, x, timestep, model_options=model_options, cond=cond, uncond=uncond_) - return cfg_result -class CFGNoisePredictor(torch.nn.Module): - def __init__(self, model): - super().__init__() +class KSamplerX0Inpaint: + def __init__(self, model, sigmas): self.inner_model = model - def apply_model(self, x, timestep, cond, uncond, cond_scale, model_options={}, seed=None): - out = sampling_function(self.inner_model, x, timestep, uncond, cond, cond_scale, model_options=model_options, seed=seed) - return out - def forward(self, *args, **kwargs): - return self.apply_model(*args, **kwargs) - -class KSamplerX0Inpaint(torch.nn.Module): - def __init__(self, model): - super().__init__() - self.inner_model = model - def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, model_options={}, seed=None): + self.sigmas = sigmas + def __call__(self, x, sigma, denoise_mask, model_options={}, seed=None): if denoise_mask is not None: + if "denoise_mask_function" in model_options: + denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask, extra_options={"model": self.inner_model, "sigmas": self.sigmas}) latent_mask = 1. - denoise_mask - x = x * denoise_mask + (self.latent_image + self.noise * sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1))) * latent_mask - out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, model_options=model_options, seed=seed) + x = x * denoise_mask + self.inner_model.inner_model.model_sampling.noise_scaling(sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1)), self.noise, self.latent_image) * latent_mask + out = self.inner_model(x, sigma, model_options=model_options, seed=seed) if denoise_mask is not None: out = out * denoise_mask + self.latent_image * latent_mask return out -def simple_scheduler(model, steps): - s = model.model_sampling +def simple_scheduler(model_sampling, steps): + s = model_sampling sigs = [] ss = len(s.sigmas) / steps for x in range(steps): @@ -292,10 +283,10 @@ def simple_scheduler(model, steps): sigs += [0.0] return torch.FloatTensor(sigs) -def ddim_scheduler(model, steps): - s = model.model_sampling +def ddim_scheduler(model_sampling, steps): + s = model_sampling sigs = [] - ss = len(s.sigmas) // steps + ss = max(len(s.sigmas) // steps, 1) x = 1 while x < len(s.sigmas): sigs += [float(s.sigmas[x])] @@ -304,8 +295,8 @@ def ddim_scheduler(model, steps): sigs += [0.0] return torch.FloatTensor(sigs) -def normal_scheduler(model, steps, sgm=False, floor=False): - s = model.model_sampling +def normal_scheduler(model_sampling, steps, sgm=False, floor=False): + s = model_sampling start = s.timestep(s.sigma_max) end = s.timestep(s.sigma_min) @@ -513,14 +504,6 @@ def max_denoise(self, model_wrap, sigmas): sigma = float(sigmas[0]) return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma -class UNIPC(Sampler): - def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False): - return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, disable=disable_pbar) - -class UNIPCBH2(Sampler): - def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False): - return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, 
max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar) - KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "heunpp2","dpm_2", "dpm_2_ancestral", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu", "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm", "tcd", "edm_playground_v2.5", "restart"] @@ -533,7 +516,7 @@ def __init__(self, sampler_function, extra_options={}, inpaint_options={}): def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False): extra_args["denoise_mask"] = denoise_mask - model_k = KSamplerX0Inpaint(model_wrap) + model_k = KSamplerX0Inpaint(model_wrap, sigmas) model_k.latent_image = latent_image if self.inpaint_options.get("random", False): #TODO: Should this be the default? generator = torch.manual_seed(extra_args.get("seed", 41) + 1) @@ -541,26 +524,24 @@ def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=N else: model_k.noise = noise - if self.max_denoise(model_wrap, sigmas): - noise = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0) - else: - noise = noise * sigmas[0] + noise = model_wrap.inner_model.model_sampling.noise_scaling(sigmas[0], noise, latent_image, self.max_denoise(model_wrap, sigmas)) k_callback = None total_steps = len(sigmas) - 1 if callback is not None: k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps) - if latent_image is not None: - noise += latent_image - samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **self.extra_options) + samples = model_wrap.inner_model.model_sampling.inverse_noise_scaling(sigmas[-1], samples) return samples def ksampler(sampler_name, extra_options={}, inpaint_options={}): if sampler_name == "dpm_fast": def dpm_fast_function(model, noise, sigmas, extra_args, callback, disable): + if len(sigmas) <= 1: + return noise + sigma_min = sigmas[-1] if sigma_min == 0: sigma_min = sigmas[-2] @@ -568,81 +549,145 @@ def dpm_fast_function(model, noise, sigmas, extra_args, callback, disable): return k_diffusion_sampling.sample_dpm_fast(model, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=callback, disable=disable) sampler_function = dpm_fast_function elif sampler_name == "dpm_adaptive": - def dpm_adaptive_function(model, noise, sigmas, extra_args, callback, disable): + def dpm_adaptive_function(model, noise, sigmas, extra_args, callback, disable, **extra_options): + if len(sigmas) <= 1: + return noise + sigma_min = sigmas[-1] if sigma_min == 0: sigma_min = sigmas[-2] - return k_diffusion_sampling.sample_dpm_adaptive(model, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=callback, disable=disable) + return k_diffusion_sampling.sample_dpm_adaptive(model, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=callback, disable=disable, **extra_options) sampler_function = dpm_adaptive_function else: sampler_function = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name)) return KSAMPLER(sampler_function, extra_options, inpaint_options) -def wrap_model(model): - model_denoise = CFGNoisePredictor(model) - return model_denoise -def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None): - positive = positive[:] - 
negative = negative[:] +def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None): + for k in conds: + conds[k] = conds[k][:] + resolve_areas_and_cond_masks(conds[k], noise.shape[2], noise.shape[3], device) - resolve_areas_and_cond_masks(positive, noise.shape[2], noise.shape[3], device) - resolve_areas_and_cond_masks(negative, noise.shape[2], noise.shape[3], device) + for k in conds: + calculate_start_end_timesteps(model, conds[k]) - model_wrap = wrap_model(model) + if hasattr(model, 'extra_conds'): + for k in conds: + conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed) - calculate_start_end_timesteps(model, negative) - calculate_start_end_timesteps(model, positive) + #make sure each cond area has an opposite one with the same area + for k in conds: + for c in conds[k]: + for kk in conds: + if k != kk: + create_cond_with_same_area_if_none(conds[kk], c) + + for k in conds: + pre_run_control(model, conds[k]) + + if "positive" in conds: + positive = conds["positive"] + for k in conds: + if k != "positive": + apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), conds[k], 'control', lambda cond_cnets, x: cond_cnets[x]) + apply_empty_x_to_equal_area(positive, conds[k], 'gligen', lambda cond_cnets, x: cond_cnets[x]) - if latent_image is not None: - latent_image = model.process_latent_in(latent_image) + return conds - if hasattr(model, 'extra_conds'): - positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask, seed=seed) - negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask, seed=seed) +class CFGGuider: + def __init__(self, model_patcher): + self.model_patcher = model_patcher + self.model_options = model_patcher.model_options + self.original_conds = {} + self.cfg = 1.0 - #make sure each cond area has an opposite one with the same area - for c in positive: - create_cond_with_same_area_if_none(negative, c) - for c in negative: - create_cond_with_same_area_if_none(positive, c) + def set_conds(self, positive, negative): + self.inner_set_conds({"positive": positive, "negative": negative}) + + def set_cfg(self, cfg): + self.cfg = cfg + + def inner_set_conds(self, conds): + for k in conds: + self.original_conds[k] = ldm_patched.modules.sampler_helpers.convert_cond(conds[k]) + + def __call__(self, *args, **kwargs): + return self.predict_noise(*args, **kwargs) + + def predict_noise(self, x, timestep, model_options={}, seed=None): + return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed) - pre_run_control(model, negative + positive) + def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed): + if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image. 
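+            # process_latent_in() applies the model's latent_format scaling and
+            # shift; for formats with a nonzero shift, an all-zero "empty" latent
+            # would pick up a constant offset, hence the count_nonzero guard above.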
+ latent_image = self.inner_model.process_latent_in(latent_image) - apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x]) - apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x]) + self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed) - extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed} + extra_args = {"model_options": self.model_options, "seed":seed} + + samples = sampler.sample(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar) + return self.inner_model.process_latent_out(samples.to(torch.float32)) + + def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None): + if sigmas.shape[-1] == 0: + return latent_image + + self.conds = {} + for k in self.original_conds: + self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k])) + + self.inner_model, self.conds, self.loaded_models = ldm_patched.modules.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds) + device = self.model_patcher.load_device + + if denoise_mask is not None: + denoise_mask = ldm_patched.modules.sampler_helpers.prepare_mask(denoise_mask, noise.shape, device) + + noise = noise.to(device) + latent_image = latent_image.to(device) + sigmas = sigmas.to(device) + + output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed) + + ldm_patched.modules.sampler_helpers.cleanup_models(self.conds, self.loaded_models) + del self.inner_model + del self.conds + del self.loaded_models + return output + + +def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None): + cfg_guider = CFGGuider(model) + cfg_guider.set_conds(positive, negative) + cfg_guider.set_cfg(cfg) + return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed) - samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar) - return model.process_latent_out(samples.to(torch.float32)) SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"] SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"] -def calculate_sigmas_scheduler(model, scheduler_name, steps): +def calculate_sigmas(model_sampling, scheduler_name, steps): if scheduler_name == "karras": - sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max)) + sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max)) elif scheduler_name == "exponential": - sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max)) + sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max)) elif scheduler_name == "normal": - sigmas = normal_scheduler(model, steps) + sigmas = normal_scheduler(model_sampling, steps) elif scheduler_name == "simple": - sigmas = 
simple_scheduler(model, steps) elif scheduler_name == "ddim_uniform": - sigmas = ddim_scheduler(model, steps) + sigmas = ddim_scheduler(model_sampling, steps) elif scheduler_name == "sgm_uniform": - sigmas = normal_scheduler(model, steps, sgm=True) + sigmas = normal_scheduler(model_sampling, steps, sgm=True) else: - print("error invalid scheduler", scheduler_name) + logging.error("error invalid scheduler {}".format(scheduler_name)) return sigmas def sampler_object(name): if name == "uni_pc": - sampler = UNIPC() + sampler = KSAMPLER(uni_pc.sample_unipc) elif name == "uni_pc_bh2": - sampler = UNIPCBH2() + sampler = KSAMPLER(uni_pc.sample_unipc_bh2) elif name == "ddim": sampler = ksampler("euler", inpaint_options={"random": True}) else: @@ -652,6 +697,7 @@ class KSampler: SCHEDULERS = SCHEDULER_NAMES SAMPLERS = SAMPLER_NAMES + DISCARD_PENULTIMATE_SIGMA_SAMPLERS = set(('dpm_2', 'dpm_2_ancestral', 'uni_pc', 'uni_pc_bh2')) def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}): self.model = model @@ -670,11 +716,11 @@ def calculate_sigmas(self, steps): sigmas = None discard_penultimate_sigma = False - if self.sampler in ['dpm_2', 'dpm_2_ancestral', 'uni_pc', 'uni_pc_bh2']: + if self.sampler in self.DISCARD_PENULTIMATE_SIGMA_SAMPLERS: steps += 1 discard_penultimate_sigma = True - sigmas = calculate_sigmas_scheduler(self.model, self.scheduler, steps) + sigmas = calculate_sigmas(self.model.get_model_object("model_sampling"), self.scheduler, steps) if discard_penultimate_sigma: sigmas = torch.cat([sigmas[:-2], sigmas[-1:]]) @@ -685,9 +731,12 @@ def set_steps(self, steps, denoise=None): if denoise is None or denoise > 0.9999: self.sigmas = self.calculate_sigmas(steps).to(self.device) else: - new_steps = int(steps/denoise) - sigmas = self.calculate_sigmas(new_steps).to(self.device) - self.sigmas = sigmas[-(steps + 1):] + if denoise <= 0.0: + self.sigmas = torch.FloatTensor([]) + else: + new_steps = int(steps/denoise) + sigmas = self.calculate_sigmas(new_steps).to(self.device) + self.sigmas = sigmas[-(steps + 1):] def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None): if sigmas is None: diff --git a/ldm_patched/modules/sd.py b/ldm_patched/modules/sd.py index 282f2559a..0ba9143e8 100644 --- a/ldm_patched/modules/sd.py +++ b/ldm_patched/modules/sd.py @@ -1,7 +1,12 @@ import torch +from enum import Enum +import logging from ldm_patched.modules import model_management from ldm_patched.ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine +from ldm_patched.ldm.cascade.stage_a import StageA +from ldm_patched.ldm.cascade.stage_c_coder import StageC_coder + import yaml import ldm_patched.modules.utils @@ -9,7 +14,6 @@ from . import clip_vision from . import gligen from . import diffusers_convert -from . import model_base from . import model_detection from .
import sd1_clip @@ -33,7 +37,7 @@ def load_model_weights(model, sd): w = sd.pop(x) del w if len(m) > 0: - print("extra", m) + logging.warning("missing {}".format(m)) return model def load_clip_weights(model, sd): @@ -48,7 +52,7 @@ def load_clip_weights(model, sd): if ids.dtype == torch.float32: sd['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round() - sd = ldm_patched.modules.utils.transformers_convert(sd, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24) + sd = ldm_patched.modules.utils.clip_text_transformers_convert(sd, "cond_stage_model.model.", "cond_stage_model.transformer.") return load_model_weights(model, sd) @@ -77,7 +81,7 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip): k1 = set(k1) for x in loaded: if (x not in k) and (x not in k1): - print("NOT LOADED", x) + logging.warning("NOT LOADED {}".format(x)) return (new_modelpatcher, new_clip) @@ -119,10 +123,13 @@ def tokenize(self, text, return_word_ids=False): return self.tokenizer.tokenize_with_weights(text, return_word_ids) def encode_from_tokens(self, tokens, return_pooled=False): + self.cond_stage_model.reset_clip_options() + if self.layer_idx is not None: - self.cond_stage_model.clip_layer(self.layer_idx) - else: - self.cond_stage_model.reset_clip_layer() + self.cond_stage_model.set_clip_options({"layer": self.layer_idx}) + + if return_pooled == "unprojected": + self.cond_stage_model.set_clip_options({"projected_pooled": False}) self.load_model() cond, pooled = self.cond_stage_model.encode_token_weights(tokens) @@ -134,8 +141,11 @@ def encode(self, text): tokens = self.tokenize(text) return self.encode_from_tokens(tokens) - def load_sd(self, sd): - return self.cond_stage_model.load_sd(sd) + def load_sd(self, sd, full_model=False): + if full_model: + return self.cond_stage_model.load_state_dict(sd, strict=False) + else: + return self.cond_stage_model.load_sd(sd) def get_sd(self): return self.cond_stage_model.state_dict() @@ -155,7 +165,10 @@ def __init__(self, sd=None, device=None, config=None, dtype=None): self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower) self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype) self.downscale_ratio = 8 + self.upscale_ratio = 8 self.latent_channels = 4 + self.process_input = lambda image: image * 2.0 - 1.0 + self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0) if config is None: if "decoder.mid.block_1.mix_factor" in sd: @@ -168,25 +181,64 @@ def __init__(self, sd=None, device=None, config=None, dtype=None): decoder_config={'target': "ldm_patched.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config}) elif "taesd_decoder.1.weight" in sd: self.first_stage_model = ldm_patched.taesd.taesd.TAESD() - else: + elif "vquantizer.codebook.weight" in sd: #VQGan: stage a of stable cascade + self.first_stage_model = StageA() + self.downscale_ratio = 4 + self.upscale_ratio = 4 + #TODO + #self.memory_used_encode + #self.memory_used_decode + self.process_input = lambda image: image + self.process_output = lambda image: image + elif "backbone.1.0.block.0.1.num_batches_tracked" in sd: #effnet: encoder for stage c latent of stable cascade + self.first_stage_model = StageC_coder() + self.downscale_ratio = 32 + self.latent_channels = 16 + new_sd = {} + for k in sd: + new_sd["encoder.{}".format(k)] = sd[k] 
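+                # StageC_coder bundles the effnet encoder and the previewer in a
+                # single module, so a bare effnet checkpoint is remapped under the
+                # "encoder." prefix before load_state_dict() below; the
+                # previewer-only branch next does the same with "previewer.".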
+ sd = new_sd + elif "blocks.11.num_batches_tracked" in sd: #previewer: decoder for stage c latent of stable cascade + self.first_stage_model = StageC_coder() + self.latent_channels = 16 + new_sd = {} + for k in sd: + new_sd["previewer.{}".format(k)] = sd[k] + sd = new_sd + elif "encoder.backbone.1.0.block.0.1.num_batches_tracked" in sd: #combined effnet and previewer for stable cascade + self.first_stage_model = StageC_coder() + self.downscale_ratio = 32 + self.latent_channels = 16 + elif "decoder.conv_in.weight" in sd: #default SD1.x/SD2.x VAE parameters ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} - if 'encoder.down.2.downsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE + if 'encoder.down.2.downsample.conv.weight' not in sd and 'decoder.up.3.upsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE ddconfig['ch_mult'] = [1, 2, 4] self.downscale_ratio = 4 - - self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4) + self.upscale_ratio = 4 + + self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1] + if 'quant_conv.weight' in sd: + self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4) + else: + self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "ldm_patched.ldm.models.autoencoder.DiagonalGaussianRegularizer"}, + encoder_config={'target': "ldm_patched.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig}, + decoder_config={'target': "ldm_patched.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig}) + else: + logging.warning("WARNING: No VAE weights detected, VAE not initialized.") + self.first_stage_model = None + return else: self.first_stage_model = AutoencoderKL(**(config['params'])) self.first_stage_model = self.first_stage_model.eval() m, u = self.first_stage_model.load_state_dict(sd, strict=False) if len(m) > 0: - print("Missing VAE keys", m) + logging.warning("Missing VAE keys {}".format(m)) if len(u) > 0: - print("Leftover VAE keys", u) + logging.debug("Leftover VAE keys {}".format(u)) if device is None: device = model_management.vae_device() @@ -200,18 +252,27 @@ def __init__(self, sd=None, device=None, config=None, dtype=None): self.patcher = ldm_patched.modules.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device) + def vae_encode_crop_pixels(self, pixels): + x = (pixels.shape[1] // self.downscale_ratio) * self.downscale_ratio + y = (pixels.shape[2] // self.downscale_ratio) * self.downscale_ratio + if pixels.shape[1] != x or pixels.shape[2] != y: + x_offset = (pixels.shape[1] % self.downscale_ratio) // 2 + y_offset = (pixels.shape[2] % self.downscale_ratio) // 2 + pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :] + return pixels + def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16): steps = samples.shape[0] * ldm_patched.modules.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap) steps += samples.shape[0] * ldm_patched.modules.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap) steps += samples.shape[0] * ldm_patched.modules.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap) pbar = ldm_patched.modules.utils.ProgressBar(steps) - decode_fn = lambda a:
(self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)) + 1.0).float() - output = torch.clamp(( - (ldm_patched.modules.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar) + - ldm_patched.modules.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar) + - ldm_patched.modules.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar)) - / 3.0) / 2.0, min=0.0, max=1.0) + decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float() + output = self.process_output( + (ldm_patched.modules.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) + + ldm_patched.modules.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) + + ldm_patched.modules.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar)) + / 3.0) return output def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): @@ -220,7 +281,7 @@ def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): steps += pixel_samples.shape[0] * ldm_patched.modules.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap) pbar = ldm_patched.modules.utils.ProgressBar(steps) - encode_fn = lambda a: self.first_stage_model.encode((2. 
* a - 1.).to(self.vae_dtype).to(self.device)).float() + encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float() samples = ldm_patched.modules.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) samples += ldm_patched.modules.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) samples += ldm_patched.modules.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) @@ -235,12 +296,12 @@ def decode(self, samples_in): batch_number = int(free_memory / memory_used) batch_number = max(1, batch_number) - pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * self.downscale_ratio), round(samples_in.shape[3] * self.downscale_ratio)), device=self.output_device) + pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * self.upscale_ratio), round(samples_in.shape[3] * self.upscale_ratio)), device=self.output_device) for x in range(0, samples_in.shape[0], batch_number): samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device) - pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples).to(self.output_device).float() + 1.0) / 2.0, min=0.0, max=1.0) + pixel_samples[x:x+batch_number] = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float()) except model_management.OOM_EXCEPTION as e: - print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") + logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") pixel_samples = self.decode_tiled_(samples_in) pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1) @@ -252,6 +313,7 @@ def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16): return output.movedim(1,-1) def encode(self, pixel_samples): + pixel_samples = self.vae_encode_crop_pixels(pixel_samples) pixel_samples = pixel_samples.movedim(-1,1) try: memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) @@ -261,16 +323,17 @@ def encode(self, pixel_samples): batch_number = max(1, batch_number) samples = torch.empty((pixel_samples.shape[0], self.latent_channels, round(pixel_samples.shape[2] // self.downscale_ratio), round(pixel_samples.shape[3] // self.downscale_ratio)), device=self.output_device) for x in range(0, pixel_samples.shape[0], batch_number): - pixels_in = (2. 
* pixel_samples[x:x+batch_number] - 1.).to(self.vae_dtype).to(self.device) + pixels_in = self.process_input(pixel_samples[x:x+batch_number]).to(self.vae_dtype).to(self.device) samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).to(self.output_device).float() except model_management.OOM_EXCEPTION as e: - print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.") + logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.") samples = self.encode_tiled_(pixel_samples) return samples def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): + pixel_samples = self.vae_encode_crop_pixels(pixel_samples) model_management.load_model_gpu(self.patcher) pixel_samples = pixel_samples.movedim(-1,1) samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap) @@ -291,14 +354,17 @@ def load_style_model(ckpt_path): model_data = ldm_patched.modules.utils.load_torch_file(ckpt_path, safe_load=True) keys = model_data.keys() if "style_embedding" in keys: - model = ldm_patched.t2ia.adapter.StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8) + model = ldm_patched.modules.t2i_adapter.adapter.StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8) else: raise Exception("invalid style model {}".format(ckpt_path)) model.load_state_dict(model_data) return StyleModel(model) +class CLIPType(Enum): + STABLE_DIFFUSION = 1 + STABLE_CASCADE = 2 -def load_clip(ckpt_paths, embedding_directory=None): +def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION): clip_data = [] for p in ckpt_paths: clip_data.append(ldm_patched.modules.utils.load_torch_file(p, safe_load=True)) @@ -308,14 +374,21 @@ class EmptyClass: for i in range(len(clip_data)): if "transformer.resblocks.0.ln_1.weight" in clip_data[i]: - clip_data[i] = ldm_patched.modules.utils.transformers_convert(clip_data[i], "", "text_model.", 32) + clip_data[i] = ldm_patched.modules.utils.clip_text_transformers_convert(clip_data[i], "", "") + else: + if "text_projection" in clip_data[i]: + clip_data[i]["text_projection.weight"] = clip_data[i]["text_projection"].transpose(0, 1) #old models saved with the CLIPSave node clip_target = EmptyClass() clip_target.params = {} if len(clip_data) == 1: if "text_model.encoder.layers.30.mlp.fc1.weight" in clip_data[0]: - clip_target.clip = sdxl_clip.SDXLRefinerClipModel - clip_target.tokenizer = sdxl_clip.SDXLTokenizer + if clip_type == CLIPType.STABLE_CASCADE: + clip_target.clip = sdxl_clip.StableCascadeClipModel + clip_target.tokenizer = sdxl_clip.StableCascadeTokenizer + else: + clip_target.clip = sdxl_clip.SDXLRefinerClipModel + clip_target.tokenizer = sdxl_clip.SDXLTokenizer elif "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data[0]: clip_target.clip = sd2_clip.SD2ClipModel clip_target.tokenizer = sd2_clip.SD2Tokenizer @@ -330,10 +403,10 @@ class EmptyClass: for c in clip_data: m, u = clip.load_sd(c) if len(m) > 0: - print("clip missing:", m) + logging.warning("clip missing: {}".format(m)) if len(u) > 0: - print("clip unexpected:", u) + logging.debug("clip unexpected: {}".format(u)) return clip def load_gligen(ckpt_path): @@ -344,6 +417,8 @@ def load_gligen(ckpt_path): return ldm_patched.modules.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device()) def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, 
output_clip=True, embedding_directory=None, state_dict=None, config=None): + logging.warning("Warning: The load checkpoint with config function is deprecated and will eventually be removed, please use the other one.") + model, clip, vae, _ = load_checkpoint_guess_config(ckpt_path, output_vae=output_vae, output_clip=output_clip, output_clipvision=False, embedding_directory=embedding_directory, output_model=True) #TODO: this function is a mess and should be removed eventually if config is None: with open(config_path, 'r') as stream: @@ -351,81 +426,20 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl model_config_params = config['model']['params'] clip_config = model_config_params['cond_stage_config'] scale_factor = model_config_params['scale_factor'] - vae_config = model_config_params['first_stage_config'] - - fp16 = False - if "unet_config" in model_config_params: - if "params" in model_config_params["unet_config"]: - unet_config = model_config_params["unet_config"]["params"] - if "use_fp16" in unet_config: - fp16 = unet_config.pop("use_fp16") - if fp16: - unet_config["dtype"] = torch.float16 - - noise_aug_config = None - if "noise_aug_config" in model_config_params: - noise_aug_config = model_config_params["noise_aug_config"] - - model_type = model_base.ModelType.EPS if "parameterization" in model_config_params: if model_config_params["parameterization"] == "v": - model_type = model_base.ModelType.V_PREDICTION - - clip = None - vae = None - - class WeightsLoader(torch.nn.Module): - pass - - if state_dict is None: - state_dict = ldm_patched.modules.utils.load_torch_file(ckpt_path) - - class EmptyClass: - pass + m = model.clone() + class ModelSamplingAdvanced(ldm_patched.modules.model_sampling.ModelSamplingDiscrete, ldm_patched.modules.model_sampling.V_PREDICTION): + pass + m.add_object_patch("model_sampling", ModelSamplingAdvanced(model.model.model_config)) + model = m - model_config = ldm_patched.modules.supported_models_base.BASE({}) + layer_idx = clip_config.get("params", {}).get("layer_idx", None) + if layer_idx is not None: + clip.clip_layer(layer_idx) - from . 
import latent_formats - model_config.latent_format = latent_formats.SD15(scale_factor=scale_factor) - model_config.unet_config = model_detection.convert_config(unet_config) - - if config['model']["target"].endswith("ImageEmbeddingConditionedLatentDiffusion"): - model = model_base.SD21UNCLIP(model_config, noise_aug_config["params"], model_type=model_type) - else: - model = model_base.BaseModel(model_config, model_type=model_type) - - if config['model']["target"].endswith("LatentInpaintDiffusion"): - model.set_inpaint() - - if fp16: - model = model.half() - - offload_device = model_management.unet_offload_device() - model = model.to(offload_device) - model.load_model_weights(state_dict, "model.diffusion_model.") - - if output_vae: - vae_sd = ldm_patched.modules.utils.state_dict_prefix_replace(state_dict, {"first_stage_model.": ""}, filter_keys=True) - vae = VAE(sd=vae_sd, config=vae_config) - - if output_clip: - w = WeightsLoader() - clip_target = EmptyClass() - clip_target.params = clip_config.get("params", {}) - if clip_config["target"].endswith("FrozenOpenCLIPEmbedder"): - clip_target.clip = sd2_clip.SD2ClipModel - clip_target.tokenizer = sd2_clip.SD2Tokenizer - clip = CLIP(clip_target, embedding_directory=embedding_directory) - w.cond_stage_model = clip.cond_stage_model.clip_h - elif clip_config["target"].endswith("FrozenCLIPEmbedder"): - clip_target.clip = sd1_clip.SD1ClipModel - clip_target.tokenizer = sd1_clip.SD1Tokenizer - clip = CLIP(clip_target, embedding_directory=embedding_directory) - w.cond_stage_model = clip.cond_stage_model.clip_l - load_clip_weights(w, state_dict) - - return (ldm_patched.modules.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae) + return (model, clip, vae) def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, vae_filename_param=None): sd = ldm_patched.modules.utils.load_torch_file(ckpt_path) @@ -439,15 +453,12 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o clip_target = None parameters = ldm_patched.modules.utils.calculate_parameters(sd, "model.diffusion_model.") - unet_dtype = model_management.unet_dtype(model_params=parameters) load_device = model_management.get_torch_device() - manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device) - class WeightsLoader(torch.nn.Module): - pass - - model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", unet_dtype) - model_config.set_manual_cast(manual_cast_dtype) + model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.") + unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes) + manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) + model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) if model_config is None: raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path)) @@ -464,7 +475,7 @@ class WeightsLoader(torch.nn.Module): if output_vae: if vae_filename_param is None: - vae_sd = ldm_patched.modules.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True) + vae_sd = ldm_patched.modules.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True) vae_sd = model_config.process_vae_state_dict(vae_sd) else: vae_sd 
= ldm_patched.modules.utils.load_torch_file(vae_filename_param) @@ -472,41 +483,50 @@ class WeightsLoader(torch.nn.Module): vae = VAE(sd=vae_sd) if output_clip: - w = WeightsLoader() clip_target = model_config.clip_target() if clip_target is not None: - clip = CLIP(clip_target, embedding_directory=embedding_directory) - w.cond_stage_model = clip.cond_stage_model - sd = model_config.process_clip_state_dict(sd) - load_model_weights(w, sd) + clip_sd = model_config.process_clip_state_dict(sd) + if len(clip_sd) > 0: + clip = CLIP(clip_target, embedding_directory=embedding_directory) + m, u = clip.load_sd(clip_sd, full_model=True) + if len(m) > 0: + m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m)) + if len(m_filter) > 0: + logging.warning("clip missing: {}".format(m)) + else: + logging.debug("clip missing: {}".format(m)) + + if len(u) > 0: + logging.debug("clip unexpected {}:".format(u)) + else: + logging.warning("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.") left_over = sd.keys() if len(left_over) > 0: - print("left over keys:", left_over) + logging.debug("left over keys: {}".format(left_over)) if output_model: model_patcher = ldm_patched.modules.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device) if inital_load_device != torch.device("cpu"): - print("loaded straight to GPU") + logging.info("loaded straight to GPU") model_management.load_model_gpu(model_patcher) - return model_patcher, clip, vae, vae_filename, clipvision + return (model_patcher, clip, vae, vae_filename, clipvision) def load_unet_state_dict(sd): #load unet in diffusers format parameters = ldm_patched.modules.utils.calculate_parameters(sd) unet_dtype = model_management.unet_dtype(model_params=parameters) load_device = model_management.get_torch_device() - manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device) - if "input_blocks.0.0.weight" in sd: #ldm - model_config = model_detection.model_config_from_unet(sd, "", unet_dtype) + if "input_blocks.0.0.weight" in sd or 'clf.1.weight' in sd: #ldm or stable cascade + model_config = model_detection.model_config_from_unet(sd, "") if model_config is None: return None new_sd = sd else: #diffusers - model_config = model_detection.model_config_from_diffusers_unet(sd, unet_dtype) + model_config = model_detection.model_config_from_diffusers_unet(sd) if model_config is None: return None @@ -517,33 +537,39 @@ def load_unet_state_dict(sd): #load unet in diffusers format if k in sd: new_sd[diffusers_keys[k]] = sd.pop(k) else: - print(diffusers_keys[k], k) + logging.warning("{} {}".format(diffusers_keys[k], k)) + offload_device = model_management.unet_offload_device() - model_config.set_manual_cast(manual_cast_dtype) + unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes) + manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) + model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) model = model_config.get_model(new_sd, "") model = model.to(offload_device) model.load_model_weights(new_sd, "") left_over = sd.keys() if len(left_over) > 0: - print("left over keys in unet:", left_over) + logging.info("left over keys in unet: {}".format(left_over)) return ldm_patched.modules.model_patcher.ModelPatcher(model, load_device=load_device, 
offload_device=offload_device) def load_unet(unet_path): sd = ldm_patched.modules.utils.load_torch_file(unet_path) model = load_unet_state_dict(sd) if model is None: - print("ERROR UNSUPPORTED UNET", unet_path) + logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path)) raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path)) return model -def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None): +def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None, extra_keys={}): clip_sd = None load_models = [model] if clip is not None: load_models.append(clip.load_model()) clip_sd = clip.get_sd() - model_management.load_models_gpu(load_models) + model_management.load_models_gpu(load_models, force_patch_weights=True) clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None sd = model.model.state_dict_for_saving(clip_sd, vae.get_sd(), clip_vision_sd) + for k in extra_keys: + sd[k] = extra_keys[k] + ldm_patched.modules.utils.save_torch_file(sd, output_path, metadata=metadata) diff --git a/ldm_patched/modules/sd1_clip.py b/ldm_patched/modules/sd1_clip.py index 38579cf4c..f2e487b18 100644 --- a/ldm_patched/modules/sd1_clip.py +++ b/ldm_patched/modules/sd1_clip.py @@ -8,6 +8,7 @@ from . import model_management import ldm_patched.modules.clip_model import json +import logging def gen_empty_tokens(special_tokens, length): start_token = special_tokens.get("start", None) @@ -67,7 +68,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): ] def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77, freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=ldm_patched.modules.clip_model.CLIPTextModel, - special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True): # clip-vit-base-patch32 + special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, return_projected_pooled=True): # clip-vit-base-patch32 super().__init__() assert layer in self.LAYERS @@ -86,16 +87,18 @@ def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_le self.layer = layer self.layer_idx = None self.special_tokens = special_tokens - self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1])) + self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055)) - self.enable_attention_masks = False + self.enable_attention_masks = enable_attention_masks self.layer_norm_hidden_state = layer_norm_hidden_state + self.return_projected_pooled = return_projected_pooled + if layer == "hidden": assert layer_idx is not None assert abs(layer_idx) < self.num_layers - self.clip_layer(layer_idx) - self.layer_default = (self.layer, self.layer_idx) + self.set_clip_options({"layer": layer_idx}) + self.options_default = (self.layer, self.layer_idx, self.return_projected_pooled) def freeze(self): self.transformer = self.transformer.eval() @@ -103,16 +106,19 @@ def freeze(self): for param in self.parameters(): param.requires_grad = False - def clip_layer(self, layer_idx): - if abs(layer_idx) > self.num_layers: + def set_clip_options(self, options): + layer_idx = options.get("layer", self.layer_idx) + self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled) + if layer_idx is None or abs(layer_idx) > self.num_layers: self.layer = "last" else: self.layer = "hidden" self.layer_idx = 
layer_idx - def reset_clip_layer(self): - self.layer = self.layer_default[0] - self.layer_idx = self.layer_default[1] + def reset_clip_options(self): + self.layer = self.options_default[0] + self.layer_idx = self.options_default[1] + self.return_projected_pooled = self.options_default[2] def set_up_textual_embeddings(self, tokens, current_embeds): out_tokens = [] @@ -132,7 +138,7 @@ def set_up_textual_embeddings(self, tokens, current_embeds): tokens_temp += [next_new_token] next_new_token += 1 else: - print("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored", y.shape[0], current_embeds.weight.shape[1]) + logging.warning("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored {} != {}".format(y.shape[0], current_embeds.weight.shape[1])) while len(tokens_temp) < len(x): tokens_temp += [self.special_tokens["pad"]] out_tokens += [tokens_temp] @@ -177,23 +183,19 @@ def forward(self, tokens): else: z = outputs[1] - if outputs[2] is not None: - pooled_output = outputs[2].float() - else: - pooled_output = None + pooled_output = None + if len(outputs) >= 3: + if not self.return_projected_pooled and len(outputs) >= 4 and outputs[3] is not None: + pooled_output = outputs[3].float() + elif outputs[2] is not None: + pooled_output = outputs[2].float() - if self.text_projection is not None and pooled_output is not None: - pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float() return z.float(), pooled_output def encode(self, tokens): return self(tokens) def load_sd(self, sd): - if "text_projection" in sd: - self.text_projection[:] = sd.pop("text_projection") - if "text_projection.weight" in sd: - self.text_projection[:] = sd.pop("text_projection.weight").transpose(0, 1) return self.transformer.load_state_dict(sd, strict=False) def parse_parentheses(string): @@ -328,9 +330,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No else: embed = torch.load(embed_path, map_location="cpu", weights_only=True) except Exception as e: - print(traceback.format_exc()) - print() - print("error loading embedding, skipping loading:", embedding_name) + logging.warning("{}\n\nerror loading embedding, skipping loading: {}".format(traceback.format_exc(), embedding_name)) return None if embed_out is None: @@ -354,11 +354,12 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No return embed_out class SDTokenizer: - def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True): + def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True, min_length=None): if tokenizer_path is None: tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer") self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path) self.max_length = max_length + self.min_length = min_length empty = self.tokenizer('')["input_ids"] if has_start_token: @@ -420,7 +421,7 @@ def tokenize_with_weights(self, text:str, return_word_ids=False): embedding_name = word[len(self.embedding_identifier):].strip('\n') embed, leftover = self._try_get_embedding(embedding_name) if embed is None: - print(f"warning, embedding:{embedding_name} does not exist, ignoring") + 
logging.warning(f"warning, embedding:{embedding_name} does not exist, ignoring") else: if len(embed.shape) == 1: tokens.append([(embed, weight)]) @@ -470,6 +471,8 @@ def tokenize_with_weights(self, text:str, return_word_ids=False): batch.append((self.end_token, 1.0, 0)) if self.pad_to_max_length: batch.extend([(pad_token, 1.0, 0)] * (self.max_length - len(batch))) + if self.min_length is not None and len(batch) < self.min_length: + batch.extend([(pad_token, 1.0, 0)] * (self.min_length - len(batch))) if not return_word_ids: batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens] @@ -503,11 +506,11 @@ def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipMod self.clip = "clip_{}".format(self.clip_name) setattr(self, self.clip, clip_model(device=device, dtype=dtype, **kwargs)) - def clip_layer(self, layer_idx): - getattr(self, self.clip).clip_layer(layer_idx) + def set_clip_options(self, options): + getattr(self, self.clip).set_clip_options(options) - def reset_clip_layer(self): - getattr(self, self.clip).reset_clip_layer() + def reset_clip_options(self): + getattr(self, self.clip).reset_clip_options() def encode_token_weights(self, token_weight_pairs): token_weight_pairs = token_weight_pairs[self.clip_name] diff --git a/ldm_patched/modules/sd2_clip.py b/ldm_patched/modules/sd2_clip.py index 41f9e388d..c1fa728e6 100644 --- a/ldm_patched/modules/sd2_clip.py +++ b/ldm_patched/modules/sd2_clip.py @@ -1,5 +1,4 @@ from ldm_patched.modules import sd1_clip -import torch import os class SD2ClipHModel(sd1_clip.SDClipModel): diff --git a/ldm_patched/modules/sdxl_clip.py b/ldm_patched/modules/sdxl_clip.py index 9d3d83d82..f06266df4 100644 --- a/ldm_patched/modules/sdxl_clip.py +++ b/ldm_patched/modules/sdxl_clip.py @@ -40,13 +40,13 @@ def __init__(self, device="cpu", dtype=None): self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False) self.clip_g = SDXLClipG(device=device, dtype=dtype) - def clip_layer(self, layer_idx): - self.clip_l.clip_layer(layer_idx) - self.clip_g.clip_layer(layer_idx) + def set_clip_options(self, options): + self.clip_l.set_clip_options(options) + self.clip_g.set_clip_options(options) - def reset_clip_layer(self): - self.clip_g.reset_clip_layer() - self.clip_l.reset_clip_layer() + def reset_clip_options(self): + self.clip_g.reset_clip_options() + self.clip_l.reset_clip_options() def encode_token_weights(self, token_weight_pairs): token_weight_pairs_g = token_weight_pairs["g"] @@ -64,3 +64,25 @@ def load_sd(self, sd): class SDXLRefinerClipModel(sd1_clip.SD1ClipModel): def __init__(self, device="cpu", dtype=None): super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=SDXLClipG) + + +class StableCascadeClipGTokenizer(sd1_clip.SDTokenizer): + def __init__(self, tokenizer_path=None, embedding_directory=None): + super().__init__(tokenizer_path, pad_with_end=True, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g') + +class StableCascadeTokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None): + super().__init__(embedding_directory=embedding_directory, clip_name="g", tokenizer=StableCascadeClipGTokenizer) + +class StableCascadeClipG(sd1_clip.SDClipModel): + def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None): + textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json") + super().__init__(device=device, 
freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, + special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True) + + def load_sd(self, sd): + return super().load_sd(sd) + +class StableCascadeClipModel(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None): + super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=StableCascadeClipG) diff --git a/ldm_patched/modules/supported_models.py b/ldm_patched/modules/supported_models.py index 1d442d4dd..6ca32e8ee 100644 --- a/ldm_patched/modules/supported_models.py +++ b/ldm_patched/modules/supported_models.py @@ -40,11 +40,16 @@ def process_clip_state_dict(self, state_dict): state_dict['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round() replace_prefix = {} - replace_prefix["cond_stage_model."] = "cond_stage_model.clip_l." - state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix) + replace_prefix["cond_stage_model."] = "clip_l." + state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True) return state_dict def process_clip_state_dict_for_saving(self, state_dict): + pop_keys = ["clip_l.transformer.text_projection.weight", "clip_l.logit_scale"] + for p in pop_keys: + if p in state_dict: + state_dict.pop(p) + replace_prefix = {"clip_l.": "cond_stage_model."} return utils.state_dict_prefix_replace(state_dict, replace_prefix) @@ -60,22 +65,28 @@ class SD20(supported_models_base.BASE): "use_temporal_attention": False, } + unet_extra_config = { + "num_heads": -1, + "num_head_channels": 64, + "attn_precision": torch.float32, + } + latent_format = latent_formats.SD15 def model_type(self, state_dict, prefix=""): if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction k = "{}output_blocks.11.1.transformer_blocks.0.norm1.bias".format(prefix) - out = state_dict[k] - if torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out. + out = state_dict.get(k, None) + if out is not None and torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out. return model_base.ModelType.V_PREDICTION return model_base.ModelType.EPS def process_clip_state_dict(self, state_dict): replace_prefix = {} - replace_prefix["conditioner.embedders.0.model."] = "cond_stage_model.model." #SD2 in sgm format - state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix) - - state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.clip_h.transformer.text_model.", 24) + replace_prefix["conditioner.embedders.0.model."] = "clip_h." #SD2 in sgm format + replace_prefix["cond_stage_model.model."] = "clip_h." + state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True) + state_dict = utils.clip_text_transformers_convert(state_dict, "clip_h.", "clip_h.transformer.") return state_dict def process_clip_state_dict_for_saving(self, state_dict): @@ -131,11 +142,10 @@ def get_model(self, state_dict, prefix="", device=None): def process_clip_state_dict(self, state_dict): keys_to_replace = {} replace_prefix = {} + replace_prefix["conditioner.embedders.0.model."] = "clip_g." 
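Note on the repeated pattern in these supported_models.py hunks: the CLIP state-dict refactors all funnel through utils.state_dict_prefix_replace with filter_keys=True, which both renames matching keys and drops everything else. A minimal sketch of the semantics these hunks rely on (assumed behavior inferred from usage, not the actual ldm_patched implementation):

def state_dict_prefix_replace_sketch(state_dict, replace_prefix, filter_keys=False):
    # With filter_keys=True, start from an empty dict so keys that match no
    # prefix are dropped instead of passed through unchanged.
    out = {} if filter_keys else state_dict
    for old, new in replace_prefix.items():
        for k in [k for k in list(state_dict.keys()) if k.startswith(old)]:
            out[new + k[len(old):]] = state_dict.pop(k)
    return out

# e.g. {"cond_stage_model.model.x": 1, "junk": 2} with
# {"cond_stage_model.model.": "clip_h."} and filter_keys=True -> {"clip_h.x": 1}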
+ state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True) - state_dict = utils.transformers_convert(state_dict, "conditioner.embedders.0.model.", "cond_stage_model.clip_g.transformer.text_model.", 32) - keys_to_replace["conditioner.embedders.0.model.text_projection"] = "cond_stage_model.clip_g.text_projection" - keys_to_replace["conditioner.embedders.0.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale" - + state_dict = utils.clip_text_transformers_convert(state_dict, "clip_g.", "clip_g.transformer.") state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace) return state_dict @@ -164,7 +174,18 @@ class SDXL(supported_models_base.BASE): latent_format = latent_formats.SDXL def model_type(self, state_dict, prefix=""): - if "v_pred" in state_dict: + if 'edm_mean' in state_dict and 'edm_std' in state_dict: #Playground V2.5 + self.latent_format = latent_formats.SDXL_Playground_2_5() + self.sampling_settings["sigma_data"] = 0.5 + self.sampling_settings["sigma_max"] = 80.0 + self.sampling_settings["sigma_min"] = 0.002 + return model_base.ModelType.EDM + elif "edm_vpred.sigma_max" in state_dict: + self.sampling_settings["sigma_max"] = float(state_dict["edm_vpred.sigma_max"].item()) + if "edm_vpred.sigma_min" in state_dict: + self.sampling_settings["sigma_min"] = float(state_dict["edm_vpred.sigma_min"].item()) + return model_base.ModelType.V_PREDICTION_EDM + elif "v_pred" in state_dict: return model_base.ModelType.V_PREDICTION else: return model_base.ModelType.EPS @@ -179,26 +200,28 @@ def process_clip_state_dict(self, state_dict): keys_to_replace = {} replace_prefix = {} - replace_prefix["conditioner.embedders.0.transformer.text_model"] = "cond_stage_model.clip_l.transformer.text_model" - state_dict = utils.transformers_convert(state_dict, "conditioner.embedders.1.model.", "cond_stage_model.clip_g.transformer.text_model.", 32) - keys_to_replace["conditioner.embedders.1.model.text_projection"] = "cond_stage_model.clip_g.text_projection" - keys_to_replace["conditioner.embedders.1.model.text_projection.weight"] = "cond_stage_model.clip_g.text_projection" - keys_to_replace["conditioner.embedders.1.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale" + replace_prefix["conditioner.embedders.0.transformer.text_model"] = "clip_l.transformer.text_model" + replace_prefix["conditioner.embedders.1.model."] = "clip_g." 
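The new SDXL.model_type branches above key entirely off marker tensors in the checkpoint; condensed into a standalone sketch (string names stand in for the model_base.ModelType enum, and the edm_vpred.* entries are assumed to be scalar tensors):

def detect_sdxl_model_type(state_dict, sampling_settings):
    # Mirrors the branch order of the model_type hunk above.
    if "edm_mean" in state_dict and "edm_std" in state_dict:  # Playground V2.5
        sampling_settings.update(sigma_data=0.5, sigma_max=80.0, sigma_min=0.002)
        return "EDM"
    if "edm_vpred.sigma_max" in state_dict:
        sampling_settings["sigma_max"] = float(state_dict["edm_vpred.sigma_max"].item())
        if "edm_vpred.sigma_min" in state_dict:
            sampling_settings["sigma_min"] = float(state_dict["edm_vpred.sigma_min"].item())
        return "V_PREDICTION_EDM"
    return "V_PREDICTION" if "v_pred" in state_dict else "EPS"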
+ state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True) - state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix) state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace) + state_dict = utils.clip_text_transformers_convert(state_dict, "clip_g.", "clip_g.transformer.") return state_dict def process_clip_state_dict_for_saving(self, state_dict): replace_prefix = {} keys_to_replace = {} state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g") - if "clip_g.transformer.text_model.embeddings.position_ids" in state_dict_g: - state_dict_g.pop("clip_g.transformer.text_model.embeddings.position_ids") for k in state_dict: if k.startswith("clip_l"): state_dict_g[k] = state_dict[k] + state_dict_g["clip_l.transformer.text_model.embeddings.position_ids"] = torch.arange(77).expand((1, -1)) + pop_keys = ["clip_l.transformer.text_projection.weight", "clip_l.logit_scale"] + for p in pop_keys: + if p in state_dict_g: + state_dict_g.pop(p) + replace_prefix["clip_g"] = "conditioner.embedders.1.model" replace_prefix["clip_l"] = "conditioner.embedders.0" state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix) @@ -227,6 +250,26 @@ class Segmind_Vega(SDXL): "use_temporal_attention": False, } +class KOALA_700M(SDXL): + unet_config = { + "model_channels": 320, + "use_linear_in_transformer": True, + "transformer_depth": [0, 2, 5], + "context_dim": 2048, + "adm_in_channels": 2816, + "use_temporal_attention": False, + } + +class KOALA_1B(SDXL): + unet_config = { + "model_channels": 320, + "use_linear_in_transformer": True, + "transformer_depth": [0, 2, 6], + "context_dim": 2048, + "adm_in_channels": 2816, + "use_temporal_attention": False, + } + class SVD_img2vid(supported_models_base.BASE): unet_config = { "model_channels": 320, @@ -239,6 +282,12 @@ class SVD_img2vid(supported_models_base.BASE): "use_temporal_resblock": True } + unet_extra_config = { + "num_heads": -1, + "num_head_channels": 64, + "attn_precision": torch.float32, + } + clip_vision_prefix = "conditioner.embedders.0.open_clip.model.visual." latent_format = latent_formats.SD15 @@ -252,6 +301,41 @@ def get_model(self, state_dict, prefix="", device=None): def clip_target(self): return None +class SV3D_u(SVD_img2vid): + unet_config = { + "model_channels": 320, + "in_channels": 8, + "use_linear_in_transformer": True, + "transformer_depth": [1, 1, 1, 1, 1, 1, 0, 0], + "context_dim": 1024, + "adm_in_channels": 256, + "use_temporal_attention": True, + "use_temporal_resblock": True + } + + vae_key_prefix = ["conditioner.embedders.1.encoder."] + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.SV3D_u(self, device=device) + return out + +class SV3D_p(SV3D_u): + unet_config = { + "model_channels": 320, + "in_channels": 8, + "use_linear_in_transformer": True, + "transformer_depth": [1, 1, 1, 1, 1, 1, 0, 0], + "context_dim": 1024, + "adm_in_channels": 1280, + "use_temporal_attention": True, + "use_temporal_resblock": True + } + + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.SV3D_p(self, device=device) + return out + class Stable_Zero123(supported_models_base.BASE): unet_config = { "context_dim": 768, @@ -267,6 +351,11 @@ class Stable_Zero123(supported_models_base.BASE): "num_head_channels": -1, } + required_keys = { + "cc_projection.weight": None, + "cc_projection.bias": None, + } + clip_vision_prefix = "cond_stage_model.model.visual." 
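The required_keys dict added to Stable_Zero123 here (cc_projection.weight/bias) feeds the stricter matches() classmethod introduced in the supported_models_base.py hunk further down. The combined gate amounts to the following sketch (plain dicts assumed for both configs):

def config_matches(class_unet_config, required_keys, unet_config, state_dict=None):
    for k, v in class_unet_config.items():
        # A missing key now also counts as a mismatch, not just a wrong value.
        if k not in unet_config or unet_config[k] != v:
            return False
    if state_dict is not None:
        # Extra guard: every required checkpoint key must actually be present.
        return all(k in state_dict for k in required_keys)
    return True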
latent_format = latent_formats.SD15 @@ -306,5 +395,99 @@ def get_model(self, state_dict, prefix="", device=None): out = model_base.SD_X4Upscaler(self, device=device) return out -models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, Segmind_Vega, SD_X4Upscaler] +class Stable_Cascade_C(supported_models_base.BASE): + unet_config = { + "stable_cascade_stage": 'c', + } + + unet_extra_config = {} + + latent_format = latent_formats.SC_Prior + supported_inference_dtypes = [torch.bfloat16, torch.float32] + + sampling_settings = { + "shift": 2.0, + } + + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoder."] + clip_vision_prefix = "clip_l_vision." + + def process_unet_state_dict(self, state_dict): + key_list = list(state_dict.keys()) + for y in ["weight", "bias"]: + suffix = "in_proj_{}".format(y) + keys = filter(lambda a: a.endswith(suffix), key_list) + for k_from in keys: + weights = state_dict.pop(k_from) + prefix = k_from[:-(len(suffix) + 1)] + shape_from = weights.shape[0] // 3 + for x in range(3): + p = ["to_q", "to_k", "to_v"] + k_to = "{}.{}.{}".format(prefix, p[x], y) + state_dict[k_to] = weights[shape_from*x:shape_from*(x + 1)] + return state_dict + + def process_clip_state_dict(self, state_dict): + state_dict = utils.state_dict_prefix_replace(state_dict, {k: "" for k in self.text_encoder_key_prefix}, filter_keys=True) + if "clip_g.text_projection" in state_dict: + state_dict["clip_g.transformer.text_projection.weight"] = state_dict.pop("clip_g.text_projection").transpose(0, 1) + return state_dict + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.StableCascade_C(self, device=device) + return out + + def clip_target(self): + return supported_models_base.ClipTarget(sdxl_clip.StableCascadeTokenizer, sdxl_clip.StableCascadeClipModel) + +class Stable_Cascade_B(Stable_Cascade_C): + unet_config = { + "stable_cascade_stage": 'b', + } + + unet_extra_config = {} + + latent_format = latent_formats.SC_B + supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32] + + sampling_settings = { + "shift": 1.0, + } + + clip_vision_prefix = None + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.StableCascade_B(self, device=device) + return out + +class SD15_instructpix2pix(SD15): + unet_config = { + "context_dim": 768, + "model_channels": 320, + "use_linear_in_transformer": False, + "adm_in_channels": None, + "use_temporal_attention": False, + "in_channels": 8, + } + + def get_model(self, state_dict, prefix="", device=None): + return model_base.SD15_instructpix2pix(self, device=device) + +class SDXL_instructpix2pix(SDXL): + unet_config = { + "model_channels": 320, + "use_linear_in_transformer": True, + "transformer_depth": [0, 0, 2, 2, 10, 10], + "context_dim": 2048, + "adm_in_channels": 2816, + "use_temporal_attention": False, + "in_channels": 8, + } + + def get_model(self, state_dict, prefix="", device=None): + return model_base.SDXL_instructpix2pix(self, model_type=self.model_type(state_dict, prefix), device=device) + +models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p] + models += [SVD_img2vid] diff --git a/ldm_patched/modules/supported_models_base.py b/ldm_patched/modules/supported_models_base.py index 5baf4bca6..cf7cdff34 100644 --- a/ldm_patched/modules/supported_models_base.py +++ 
b/ldm_patched/modules/supported_models_base.py @@ -16,19 +16,28 @@ class BASE: "num_head_channels": 64, } + required_keys = {} + clip_prefix = [] clip_vision_prefix = None noise_aug_config = None sampling_settings = {} latent_format = latent_formats.LatentFormat + vae_key_prefix = ["first_stage_model."] + text_encoder_key_prefix = ["cond_stage_model."] + supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32] manual_cast_dtype = None @classmethod - def matches(s, unet_config): + def matches(s, unet_config, state_dict=None): for k in s.unet_config: - if s.unet_config[k] != unet_config[k]: + if k not in unet_config or s.unet_config[k] != unet_config[k]: return False + if state_dict is not None: + for k in s.required_keys: + if k not in state_dict: + return False return True def model_type(self, state_dict, prefix=""): @@ -38,7 +47,8 @@ def inpaint_model(self): return self.unet_config["in_channels"] > 4 def __init__(self, unet_config): - self.unet_config = unet_config + self.unet_config = unet_config.copy() + self.sampling_settings = self.sampling_settings.copy() self.latent_format = self.latent_format() for x in self.unet_extra_config: self.unet_config[x] = self.unet_extra_config[x] @@ -53,6 +63,7 @@ def get_model(self, state_dict, prefix="", device=None): return out def process_clip_state_dict(self, state_dict): + state_dict = utils.state_dict_prefix_replace(state_dict, {k: "" for k in self.text_encoder_key_prefix}, filter_keys=True) return state_dict def process_unet_state_dict(self, state_dict): @@ -62,7 +73,7 @@ def process_vae_state_dict(self, state_dict): return state_dict def process_clip_state_dict_for_saving(self, state_dict): - replace_prefix = {"": "cond_stage_model."} + replace_prefix = {"": self.text_encoder_key_prefix[0]} return utils.state_dict_prefix_replace(state_dict, replace_prefix) def process_clip_vision_state_dict_for_saving(self, state_dict): @@ -76,8 +87,9 @@ def process_unet_state_dict_for_saving(self, state_dict): return utils.state_dict_prefix_replace(state_dict, replace_prefix) def process_vae_state_dict_for_saving(self, state_dict): - replace_prefix = {"": "first_stage_model."} + replace_prefix = {"": self.vae_key_prefix[0]} return utils.state_dict_prefix_replace(state_dict, replace_prefix) - def set_manual_cast(self, manual_cast_dtype): + def set_inference_dtype(self, dtype, manual_cast_dtype): + self.unet_config['dtype'] = dtype self.manual_cast_dtype = manual_cast_dtype diff --git a/ldm_patched/modules/types.py b/ldm_patched/modules/types.py new file mode 100644 index 000000000..70cf4b158 --- /dev/null +++ b/ldm_patched/modules/types.py @@ -0,0 +1,32 @@ +import torch +from typing import Callable, Protocol, TypedDict, Optional, List + + +class UnetApplyFunction(Protocol): + """Function signature protocol on comfy.model_base.BaseModel.apply_model""" + + def __call__(self, x: torch.Tensor, t: torch.Tensor, **kwargs) -> torch.Tensor: + pass + + +class UnetApplyConds(TypedDict): + """Optional conditions for unet apply function.""" + + c_concat: Optional[torch.Tensor] + c_crossattn: Optional[torch.Tensor] + control: Optional[torch.Tensor] + transformer_options: Optional[dict] + + +class UnetParams(TypedDict): + # Tensor of shape [B, C, H, W] + input: torch.Tensor + # Tensor of shape [B] + timestep: torch.Tensor + c: UnetApplyConds + # List of [0, 1], [0], [1], ... 
+ # 0 means conditional, 1 means unconditional + cond_or_uncond: List[int] + + +UnetWrapperFunction = Callable[[UnetApplyFunction, UnetParams], torch.Tensor] diff --git a/ldm_patched/modules/utils.py b/ldm_patched/modules/utils.py index f8283a86e..d1381e9d6 100644 --- a/ldm_patched/modules/utils.py +++ b/ldm_patched/modules/utils.py @@ -5,6 +5,7 @@ import safetensors.torch import numpy as np from PIL import Image +import logging def load_torch_file(ckpt, safe_load=False, device=None): if device is None: @@ -14,14 +15,14 @@ def load_torch_file(ckpt, safe_load=False, device=None): else: if safe_load: if not 'weights_only' in torch.load.__code__.co_varnames: - print("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.") + logging.warning("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.") safe_load = False if safe_load: pl_sd = torch.load(ckpt, map_location=device, weights_only=True) else: pl_sd = torch.load(ckpt, map_location=device, pickle_module=ldm_patched.modules.checkpoint_pickle) if "global_step" in pl_sd: - print(f"Global Step: {pl_sd['global_step']}") + logging.debug(f"Global Step: {pl_sd['global_step']}") if "state_dict" in pl_sd: sd = pl_sd["state_dict"] else: @@ -98,8 +99,22 @@ def transformers_convert(sd, prefix_from, prefix_to, number): p = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"] k_to = "{}encoder.layers.{}.{}.{}".format(prefix_to, resblock, p[x], y) sd[k_to] = weights[shape_from*x:shape_from*(x + 1)] + return sd +def clip_text_transformers_convert(sd, prefix_from, prefix_to): + sd = transformers_convert(sd, prefix_from, "{}text_model.".format(prefix_to), 32) + + tp = "{}text_projection.weight".format(prefix_from) + if tp in sd: + sd["{}text_projection.weight".format(prefix_to)] = sd.pop(tp) + + tp = "{}text_projection".format(prefix_from) + if tp in sd: + sd["{}text_projection.weight".format(prefix_to)] = sd.pop(tp).transpose(0, 1).contiguous() + return sd + + UNET_MAP_ATTENTIONS = { "proj_in.weight", "proj_in.bias", @@ -169,6 +184,8 @@ def transformers_convert(sd, prefix_from, prefix_to, number): } def unet_to_diffusers(unet_config): + if "num_res_blocks" not in unet_config: + return {} num_res_blocks = unet_config["num_res_blocks"] channel_mult = unet_config["channel_mult"] transformer_depth = unet_config["transformer_depth"][:] @@ -278,8 +295,11 @@ def set_attr(obj, attr, value): for name in attrs[:-1]: obj = getattr(obj, name) prev = getattr(obj, attrs[-1]) - setattr(obj, attrs[-1], torch.nn.Parameter(value, requires_grad=False)) - del prev + setattr(obj, attrs[-1], value) + return prev + +def set_attr_param(obj, attr, value): + return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False)) def copy_to_param(obj, attr, value): # inplace update tensor instead of replacing it @@ -413,6 +433,8 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_am out_div = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device=output_device) for y in range(0, s.shape[2], tile_y - overlap): for x in range(0, s.shape[3], tile_x - overlap): + x = max(0, min(s.shape[-1] - overlap, x)) + y = max(0, min(s.shape[-2] - overlap, y)) s_in = s[:,:,y:y+tile_y,x:x+tile_x] ps = function(s_in).to(output_device) diff --git a/ldm_patched/unipc/uni_pc.py b/ldm_patched/unipc/uni_pc.py index 08bf0fc9e..56b90300f 100644 --- a/ldm_patched/unipc/uni_pc.py +++ b/ldm_patched/unipc/uni_pc.py @@ -1,4
+1,5 @@ -#code taken from: https://github.com/wl-zhao/UniPC and modified +# code taken from: https://github.com/wl-zhao/UniPC and modified +# updated from https://github.com/comfyanonymous/ComfyUI/blob/a178e25912b01abf436eba1cfaab316ba02d272d/comfy/extra_samplers/uni_pc.py#L874 import torch import torch.nn.functional as F @@ -358,9 +359,6 @@ def __init__( thresholding=False, max_val=1., variant='bh1', - noise_mask=None, - masked_image=None, - noise=None, ): """Construct a UniPC. @@ -372,9 +370,6 @@ def __init__( self.predict_x0 = predict_x0 self.thresholding = thresholding self.max_val = max_val - self.noise_mask = noise_mask - self.masked_image = masked_image - self.noise = noise def dynamic_thresholding_fn(self, x0, t=None): """ @@ -391,10 +386,7 @@ def noise_prediction_fn(self, x, t): """ Return the noise prediction model. """ - if self.noise_mask is not None: - return self.model(x, t) * self.noise_mask - else: - return self.model(x, t) + return self.model(x, t) def data_prediction_fn(self, x, t): """ @@ -409,8 +401,6 @@ def data_prediction_fn(self, x, t): s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims) x0 = torch.clamp(x0, -s, s) / s - if self.noise_mask is not None: - x0 = x0 * self.noise_mask + (1. - self.noise_mask) * self.masked_image return x0 def model_fn(self, x, t): @@ -723,8 +713,6 @@ def sample(self, x, timesteps, t_start=None, t_end=None, order=3, skip_type='tim assert timesteps.shape[0] - 1 == steps # with torch.no_grad(): for step_index in trange(steps, disable=disable_pbar): - if self.noise_mask is not None: - x = x * self.noise_mask + (1. - self.noise_mask) * (self.masked_image * self.noise_schedule.marginal_alpha(timesteps[step_index]) + self.noise * self.noise_schedule.marginal_std(timesteps[step_index])) if step_index == 0: vec_t = timesteps[0].expand((x.shape[0])) model_prev_list = [self.model_fn(x, vec_t)] @@ -766,7 +754,7 @@ def sample(self, x, timesteps, t_start=None, t_end=None, order=3, skip_type='tim model_x = self.model_fn(x, vec_t) model_prev_list[-1] = model_x if callback is not None: - callback(step_index, model_prev_list[-1], x, steps) + callback({'x': x, 'i': step_index, 'denoised': model_prev_list[-1]}) else: raise NotImplementedError() # if denoise_to_zero: @@ -858,7 +846,7 @@ def predict_eps_sigma(model, input, sigma_in, **kwargs): return (input - model(input, sigma_in, **kwargs)) / sigma -def sample_unipc(model, noise, image, sigmas, max_denoise, extra_args=None, callback=None, disable=False, noise_mask=None, variant='bh1'): +def sample_unipc(model, noise, sigmas, extra_args=None, callback=None, disable=False, variant='bh1'): timesteps = sigmas.clone() if sigmas[-1] == 0: timesteps = sigmas[:] @@ -867,16 +855,7 @@ def sample_unipc(model, noise, image, sigmas, max_denoise, extra_args=None, call timesteps = sigmas.clone() ns = SigmaConvert() - if image is not None: - img = image * ns.marginal_alpha(timesteps[0]) - if max_denoise: - noise_mult = 1.0 - else: - noise_mult = ns.marginal_std(timesteps[0]) - img += noise * noise_mult - else: - img = noise - + noise = noise / torch.sqrt(1.0 + timesteps[0] ** 2.0) model_type = "noise" model_fn = model_wrapper( @@ -888,7 +867,10 @@ def sample_unipc(model, noise, image, sigmas, max_denoise, extra_args=None, call ) order = min(3, len(timesteps) - 2) - uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise, variant=variant) - x = 
uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback, disable_pbar=disable) + uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, variant=variant) + x = uni_pc.sample(noise, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback, disable_pbar=disable) x /= ns.marginal_alpha(timesteps[-1]) return x + +def sample_unipc_bh2(model, noise, sigmas, extra_args=None, callback=None, disable=False): + return sample_unipc(model, noise, sigmas, extra_args, callback, disable, variant='bh2') \ No newline at end of file diff --git a/ldm_patched/utils/node_helpers.py b/ldm_patched/utils/node_helpers.py new file mode 100644 index 000000000..f472c2665 --- /dev/null +++ b/ldm_patched/utils/node_helpers.py @@ -0,0 +1,24 @@ +from PIL import ImageFile, UnidentifiedImageError + +def conditioning_set_values(conditioning, values={}): + c = [] + for t in conditioning: + n = [t[0], t[1].copy()] + for k in values: + n[1][k] = values[k] + c.append(n) + + return c + +def pillow(fn, arg): + prev_value = None + try: + x = fn(arg) + except (OSError, UnidentifiedImageError, ValueError): #PIL issues #4472 and #2445, also fixes ComfyUI issue #3416 + prev_value = ImageFile.LOAD_TRUNCATED_IMAGES + ImageFile.LOAD_TRUNCATED_IMAGES = True + x = fn(arg) + finally: + if prev_value is not None: + ImageFile.LOAD_TRUNCATED_IMAGES = prev_value + return x \ No newline at end of file diff --git a/modules/core.py b/modules/core.py index 1c3dacb99..9a80faf73 100644 --- a/modules/core.py +++ b/modules/core.py @@ -16,7 +16,6 @@ from ldm_patched.contrib.external import VAEDecode, EmptyLatentImage, VAEEncode, VAEEncodeTiled, VAEDecodeTiled, \ ControlNetApplyAdvanced from ldm_patched.contrib.external_freelunch import FreeU_V2 -from ldm_patched.modules.sample import prepare_mask from modules.lora import match_lora from modules.util import get_file_from_folder_list from ldm_patched.modules.lora import model_lora_keys_unet, model_lora_keys_clip diff --git a/modules/default_pipeline.py b/modules/default_pipeline.py index 494644d69..6ed908b70 100644 --- a/modules/default_pipeline.py +++ b/modules/default_pipeline.py @@ -38,7 +38,10 @@ def refresh_controlnets(model_paths): if p in loaded_ControlNets: cache[p] = loaded_ControlNets[p] else: - cache[p] = core.load_controlnet(p) + p_m = core.load_controlnet(p) + if p_m is None: + print(f'WARNING: Failed to load ControlNet model: {p}') + cache[p] = p_m loaded_ControlNets = cache return @@ -242,7 +245,7 @@ def refresh_everything(refiner_model_name, base_model_name, loras, final_refiner_vae = None if use_synthetic_refiner and refiner_model_name == 'None': - print('Synthetic Refiner Activated') + print('Using Synthetic Refiner') refresh_base_model(base_model_name, vae_name) synthesize_refiner_model() else: @@ -288,14 +291,14 @@ def vae_parse(latent): @torch.no_grad() @torch.inference_mode() def calculate_sigmas_all(sampler, model, scheduler, steps): - from ldm_patched.modules.samplers import calculate_sigmas_scheduler + from ldm_patched.modules.samplers import calculate_sigmas discard_penultimate_sigma = False if sampler in ['dpm_2', 'dpm_2_ancestral']: steps += 1 discard_penultimate_sigma = True - sigmas = calculate_sigmas_scheduler(model, scheduler, steps) + sigmas = calculate_sigmas(model, scheduler, steps) if discard_penultimate_sigma: sigmas = torch.cat([sigmas[:-2], sigmas[-1:]]) diff --git a/modules/lora.py b/modules/lora.py index 
088545c70..8f7cf286e 100644 --- a/modules/lora.py +++ b/modules/lora.py @@ -14,8 +14,16 @@ def match_lora(lora, to_load): alpha = lora[alpha_name].item() loaded_keys.add(alpha_name) + dora_scale_name = "{}.dora_scale".format(x) + dora_scale = None + if dora_scale_name in lora.keys(): + dora_scale = lora[dora_scale_name] + loaded_keys.add(dora_scale_name) + regular_lora = "{}.lora_up.weight".format(x) diffusers_lora = "{}_lora.up.weight".format(x) + diffusers2_lora = "{}.lora_B.weight".format(x) + diffusers3_lora = "{}.lora.up.weight".format(x) transformers_lora = "{}.lora_linear_layer.up.weight".format(x) A_name = None @@ -27,6 +35,14 @@ def match_lora(lora, to_load): A_name = diffusers_lora B_name = "{}_lora.down.weight".format(x) mid_name = None + elif diffusers2_lora in lora.keys(): + A_name = diffusers2_lora + B_name = "{}.lora_A.weight".format(x) + mid_name = None + elif diffusers3_lora in lora.keys(): + A_name = diffusers3_lora + B_name = "{}.lora.down.weight".format(x) + mid_name = None elif transformers_lora in lora.keys(): A_name = transformers_lora B_name ="{}.lora_linear_layer.down.weight".format(x) @@ -37,7 +53,7 @@ def match_lora(lora, to_load): if mid_name is not None and mid_name in lora.keys(): mid = lora[mid_name] loaded_keys.add(mid_name) - patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid)) + patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale)) loaded_keys.add(A_name) loaded_keys.add(B_name) @@ -58,7 +74,7 @@ def match_lora(lora, to_load): loaded_keys.add(hada_t1_name) loaded_keys.add(hada_t2_name) - patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2)) + patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale)) loaded_keys.add(hada_w1_a_name) loaded_keys.add(hada_w1_b_name) loaded_keys.add(hada_w2_a_name) @@ -110,7 +126,7 @@ def match_lora(lora, to_load): loaded_keys.add(lokr_t2_name) if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None): - patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2)) + patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale)) #glora a1_name = "{}.a1.weight".format(x) @@ -118,7 +134,7 @@ def match_lora(lora, to_load): b1_name = "{}.b1.weight".format(x) b2_name = "{}.b2.weight".format(x) if a1_name in lora: - patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha)) + patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale)) loaded_keys.add(a1_name) loaded_keys.add(a2_name) loaded_keys.add(b1_name) diff --git a/modules/patch.py b/modules/patch.py index 3c2dd8f47..83720984b 100644 --- a/modules/patch.py +++ b/modules/patch.py @@ -21,7 +21,7 @@ import safetensors.torch import modules.constants as constants -from ldm_patched.modules.samplers import calc_cond_uncond_batch +from ldm_patched.modules.samplers import calc_cond_batch from ldm_patched.k_diffusion.sampling import BatchedBrownianTree from ldm_patched.ldm.modules.diffusionmodules.openaimodel import forward_timestep_embed, apply_control from modules.patch_precision import patch_all_precision @@ -227,14 +227,16 @@ def patched_sampling_function(model, x, timestep, 
uncond, cond, cond_scale, mode pid = os.getpid() if math.isclose(cond_scale, 1.0) and not model_options.get("disable_cfg1_optimization", False): - final_x0 = calc_cond_uncond_batch(model, cond, None, x, timestep, model_options)[0] + calc_cond_uncond_batch = tuple(calc_cond_batch(model, [cond, None], x, timestep, model_options)) + final_x0 = calc_cond_uncond_batch[0] if patch_settings[pid].eps_record is not None: patch_settings[pid].eps_record = ((x - final_x0) / timestep).cpu() return final_x0 - positive_x0, negative_x0 = calc_cond_uncond_batch(model, cond, uncond, x, timestep, model_options) + calc_cond_uncond_batch = tuple(calc_cond_batch(model, [cond, uncond], x, timestep, model_options)) + positive_x0, negative_x0 = calc_cond_uncond_batch positive_eps = x - positive_x0 negative_eps = x - negative_x0 @@ -294,7 +296,10 @@ def embedder(number_list): return final_adm -def patched_KSamplerX0Inpaint_forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, model_options={}, seed=None): +def patched_KSamplerX0Inpaint_forward(self, x, sigma, denoise_mask, model_options={}, seed=None): + # uncond = self.inner_model.conds.get("negative", None) + # cond = self.inner_model.conds.get("positive", None) + # cond_scale = self.inner_model.cfg if inpaint_worker.current_task is not None: latent_processor = self.inner_model.inner_model.process_latent_in inpaint_latent = latent_processor(inpaint_worker.current_task.latent).to(x) @@ -310,18 +315,18 @@ def patched_KSamplerX0Inpaint_forward(self, x, sigma, uncond, cond, cond_scale, x = x * inpaint_mask + (inpaint_latent + current_energy) * (1.0 - inpaint_mask) out = self.inner_model(x, sigma, - cond=cond, - uncond=uncond, - cond_scale=cond_scale, + # cond=cond, + # uncond=uncond, + # cond_scale=cond_scale, model_options=model_options, seed=seed) out = out * inpaint_mask + inpaint_latent * (1.0 - inpaint_mask) else: out = self.inner_model(x, sigma, - cond=cond, - uncond=uncond, - cond_scale=cond_scale, + # cond=cond, + # uncond=uncond, + # cond_scale=cond_scale, model_options=model_options, seed=seed) return out @@ -384,7 +389,10 @@ def patched_unet_forward(self, x, timesteps=None, context=None, y=None, control= transformer_patches = transformer_options.get("patches", {}) num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames) - image_only_indicator = kwargs.get("image_only_indicator", self.default_image_only_indicator) + + image_only_indicator = None + if hasattr(self, "default_image_only_indicator"): + image_only_indicator = kwargs.get("image_only_indicator", self.default_image_only_indicator) time_context = kwargs.get("time_context", None) assert (y is not None) == ( @@ -486,6 +494,8 @@ def loader(*args, **kwargs): def patch_all(): + # Fooocus-specific additions over ComfyUI ldm. 
+ # ALSO: marked with 'used-by-Fooocus' if ldm_patched.modules.model_management.directml_enabled: ldm_patched.modules.model_management.lowvram_available = True ldm_patched.modules.model_management.OOM_EXCEPTION = Exception @@ -501,7 +511,7 @@ def patch_all(): ldm_patched.controlnet.cldm.ControlNet.forward = patched_cldm_forward ldm_patched.ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = patched_unet_forward ldm_patched.modules.model_base.SDXL.encode_adm = sdxl_encode_adm_patched - ldm_patched.modules.samplers.KSamplerX0Inpaint.forward = patched_KSamplerX0Inpaint_forward + ldm_patched.modules.samplers.KSamplerX0Inpaint.__call__ = patched_KSamplerX0Inpaint_forward ldm_patched.k_diffusion.sampling.BrownianTreeNoiseSampler = BrownianTreeNoiseSamplerPatched ldm_patched.modules.samplers.sampling_function = patched_sampling_function diff --git a/modules/patch_clip.py b/modules/patch_clip.py index 06b7f01bb..dd6cb2f44 100644 --- a/modules/patch_clip.py +++ b/modules/patch_clip.py @@ -1,5 +1,6 @@ # Consistent with Kohya/A1111 to reduce differences between model training and inference. +import json import os import torch import ldm_patched.controlnet.cldm @@ -62,48 +63,42 @@ def patched_encode_token_weights(self, token_weight_pairs): return torch.cat(output, dim=-2).to(ldm_patched.modules.model_management.intermediate_device()), first_pooled -def patched_SDClipModel__init__(self, max_length=77, freeze=True, layer="last", layer_idx=None, - textmodel_json_config=None, dtype=None, special_tokens=None, - layer_norm_hidden_state=True, **kwargs): +def patched_SDClipModel__init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77, + freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=ldm_patched.modules.clip_model.CLIPTextModel, + special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, return_projected_pooled=True): # clip-vit-base-patch32 torch.nn.Module.__init__(self) assert layer in self.LAYERS - if special_tokens is None: - special_tokens = {"start": 49406, "end": 49407, "pad": 49407} - if textmodel_json_config is None: - textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(ldm_patched.modules.sd1_clip.__file__)), - "sd1_clip_config.json") - - config = CLIPTextConfig.from_json_file(textmodel_json_config) - self.num_layers = config.num_hidden_layers + textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(ldm_patched.modules.sd1_clip.__file__)), "sd1_clip_config.json") - with use_patched_ops(ops.manual_cast): - with modeling_utils.no_init_weights(): - self.transformer = CLIPTextModel(config) - - if dtype is not None: - self.transformer.to(dtype) + with open(textmodel_json_config) as f: + config = json.load(f) + self.transformer = model_class(config, dtype, device, ldm_patched.modules.ops.manual_cast) + self.num_layers = self.transformer.num_layers self.transformer.text_model.embeddings.to(torch.float32) + self.max_length = max_length if freeze: self.freeze() - - self.max_length = max_length self.layer = layer self.layer_idx = None self.special_tokens = special_tokens + + # TODO check if necessary self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1])) self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055)) - self.enable_attention_masks = False + self.enable_attention_masks = enable_attention_masks self.layer_norm_hidden_state = layer_norm_hidden_state + 
diff --git a/modules/patch_clip.py b/modules/patch_clip.py
index 06b7f01bb..dd6cb2f44 100644
--- a/modules/patch_clip.py
+++ b/modules/patch_clip.py
@@ -1,5 +1,6 @@
 # Consistent with Kohya/A1111 to reduce differences between model training and inference.

+import json
 import os
 import torch
 import ldm_patched.controlnet.cldm
@@ -62,48 +63,42 @@ def patched_encode_token_weights(self, token_weight_pairs):
     return torch.cat(output, dim=-2).to(ldm_patched.modules.model_management.intermediate_device()), first_pooled


-def patched_SDClipModel__init__(self, max_length=77, freeze=True, layer="last", layer_idx=None,
-                                textmodel_json_config=None, dtype=None, special_tokens=None,
-                                layer_norm_hidden_state=True, **kwargs):
+def patched_SDClipModel__init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
+                                freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=ldm_patched.modules.clip_model.CLIPTextModel,
+                                special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, return_projected_pooled=True):  # clip-vit-base-patch32
     torch.nn.Module.__init__(self)
     assert layer in self.LAYERS

-    if special_tokens is None:
-        special_tokens = {"start": 49406, "end": 49407, "pad": 49407}
-
-    if textmodel_json_config is None:
-        textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(ldm_patched.modules.sd1_clip.__file__)),
-                                             "sd1_clip_config.json")
-
-    config = CLIPTextConfig.from_json_file(textmodel_json_config)
-    self.num_layers = config.num_hidden_layers
+    textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(ldm_patched.modules.sd1_clip.__file__)), "sd1_clip_config.json")

-    with use_patched_ops(ops.manual_cast):
-        with modeling_utils.no_init_weights():
-            self.transformer = CLIPTextModel(config)
-
-    if dtype is not None:
-        self.transformer.to(dtype)
+    with open(textmodel_json_config) as f:
+        config = json.load(f)
+    self.transformer = model_class(config, dtype, device, ldm_patched.modules.ops.manual_cast)
+    self.num_layers = self.transformer.num_layers

     self.transformer.text_model.embeddings.to(torch.float32)
+    self.max_length = max_length

     if freeze:
         self.freeze()
-
-    self.max_length = max_length
     self.layer = layer
     self.layer_idx = None
     self.special_tokens = special_tokens
+
+    # TODO: check if necessary
     self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1]))
     self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
-    self.enable_attention_masks = False
+    self.enable_attention_masks = enable_attention_masks

     self.layer_norm_hidden_state = layer_norm_hidden_state
+    self.return_projected_pooled = return_projected_pooled
+
     if layer == "hidden":
         assert layer_idx is not None
         assert abs(layer_idx) < self.num_layers
-        self.clip_layer(layer_idx)
-        self.layer_default = (self.layer, self.layer_idx)
+        self.set_clip_options({"layer": layer_idx})
+    self.options_default = (self.layer, self.layer_idx, self.return_projected_pooled)


 def patched_SDClipModel_forward(self, tokens):
@@ -122,8 +117,7 @@ def patched_SDClipModel_forward(self, tokens):
                 if tokens[x, y] == max_token:
                     break

-    outputs = self.transformer(input_ids=tokens, attention_mask=attention_mask,
-                               output_hidden_states=self.layer == "hidden")
+    outputs = self.transformer(tokens, attention_mask, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state)
     self.transformer.set_input_embeddings(backup_embeds)

     if self.layer == "last":
@@ -131,17 +125,14 @@ def patched_SDClipModel_forward(self, tokens):
     elif self.layer == "pooled":
         z = outputs.pooler_output[:, None, :]
     else:
-        z = outputs.hidden_states[self.layer_idx]
-        if self.layer_norm_hidden_state:
-            z = self.transformer.text_model.final_layer_norm(z)
-
-    if hasattr(outputs, "pooler_output"):
-        pooled_output = outputs.pooler_output.float()
-    else:
-        pooled_output = None
-
-    if self.text_projection is not None and pooled_output is not None:
-        pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float()
+        z = outputs[1]
+
+    pooled_output = None
+    if len(outputs) >= 3:
+        if not self.return_projected_pooled and len(outputs) >= 4 and outputs[3] is not None:
+            pooled_output = outputs[3].float()
+        elif outputs[2] is not None:
+            pooled_output = outputs[2].float()

     return z.float(), pooled_output
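The rewritten forward indexes the transformer output as a plain tuple rather than a Hugging Face ModelOutput. A minimal sketch of the selection logic, assuming the tuple layout implied by the code above (last_hidden, intermediate_hidden, projected_pooled, raw_pooled), with trailing entries possibly absent:

def select_clip_outputs(outputs, layer, return_projected_pooled):
    # outputs[0]: last hidden state; outputs[1]: the requested intermediate layer
    z = outputs[0] if layer == "last" else outputs[1]

    pooled = None
    if len(outputs) >= 3:
        if not return_projected_pooled and len(outputs) >= 4 and outputs[3] is not None:
            pooled = outputs[3].float()  # pooled output without the text projection
        elif outputs[2] is not None:
            pooled = outputs[2].float()  # projected pooled output (the default)
    return z.float(), pooled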
diff --git a/modules/patch_precision.py b/modules/patch_precision.py
index 22ffda0ad..51a1e2b78 100644
--- a/modules/patch_precision.py
+++ b/modules/patch_precision.py
@@ -49,10 +49,8 @@ def patched_register_schedule(self, given_betas=None, beta_schedule="linear", ti
     self.num_timesteps = int(timesteps)
     self.linear_start = linear_start
     self.linear_end = linear_end
-    sigmas = torch.tensor(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, dtype=torch.float32)
+    sigmas = (((1 - alphas_cumprod) / alphas_cumprod) ** 0.5).clone().detach().type(torch.float32)
     self.set_sigmas(sigmas)
-    alphas_cumprod = torch.tensor(alphas_cumprod, dtype=torch.float32)
-    self.set_alphas_cumprod(alphas_cumprod)
     return
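The schedule above is the usual variance-preserving conversion sigma_t = sqrt((1 - abar_t) / abar_t); switching from torch.tensor(...) to .clone().detach() avoids re-wrapping an existing tensor and the PyTorch warning that goes with it. A standalone sketch, assuming alphas_cumprod is already a tensor:

import torch

def sigmas_from_alphas_cumprod(alphas_cumprod: torch.Tensor) -> torch.Tensor:
    # sigma_t = sqrt((1 - abar_t) / abar_t) for each timestep t
    return (((1 - alphas_cumprod) / alphas_cumprod) ** 0.5).clone().detach().type(torch.float32)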
diff --git a/modules/sample_hijack.py b/modules/sample_hijack.py
index 84752ede7..5a96e840b 100644
--- a/modules/sample_hijack.py
+++ b/modules/sample_hijack.py
@@ -9,10 +9,12 @@
 from ldm_patched.modules.samplers import normal_scheduler, simple_scheduler, ddim_scheduler
 from ldm_patched.modules.model_base import SDXLRefiner, SDXL
 from ldm_patched.modules.conds import CONDRegular
-from ldm_patched.modules.sample import get_additional_models, get_models_from_cond, cleanup_additional_models
-from ldm_patched.modules.samplers import resolve_areas_and_cond_masks, wrap_model, calculate_start_end_timesteps, \
-    create_cond_with_same_area_if_none, pre_run_control, apply_empty_x_to_equal_area, encode_model_conds
-
+from ldm_patched.modules.sampler_helpers import get_additional_models, get_models_from_cond, cleanup_additional_models
+from ldm_patched.modules.samplers import resolve_areas_and_cond_masks, calculate_start_end_timesteps, \
+    create_cond_with_same_area_if_none, pre_run_control, apply_empty_x_to_equal_area, encode_model_conds, CFGGuider, \
+    process_conds
+from ldm_patched.modules.model_patcher import ModelPatcher
+from modules.util import sys_dump_pythonobj

 current_refiner = None
 refiner_switch_step = -1
@@ -83,97 +85,98 @@ def clip_separate_after_preparation(cond, target_model=None, target_clip=None):

     return results

-
-@torch.no_grad()
-@torch.inference_mode()
 def sample_hacked(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
-    global current_refiner
-
-    positive = positive[:]
-    negative = negative[:]
-
-    resolve_areas_and_cond_masks(positive, noise.shape[2], noise.shape[3], device)
-    resolve_areas_and_cond_masks(negative, noise.shape[2], noise.shape[3], device)
-
-    model_wrap = wrap_model(model)
-
-    calculate_start_end_timesteps(model, negative)
-    calculate_start_end_timesteps(model, positive)
-
-    if latent_image is not None:
-        latent_image = model.process_latent_in(latent_image)
-
-    if hasattr(model, 'extra_conds'):
-        positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask)
-        negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask)
-
-    #make sure each cond area has an opposite one with the same area
-    for c in positive:
-        create_cond_with_same_area_if_none(negative, c)
-    for c in negative:
-        create_cond_with_same_area_if_none(positive, c)
-
-    # pre_run_control(model, negative + positive)
-    pre_run_control(model, positive)  # negative is not necessary in Fooocus, 0.5s faster.
-
-    apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
-    apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
-
-    extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed}
-
-    if current_refiner is not None and hasattr(current_refiner.model, 'extra_conds'):
-        positive_refiner = clip_separate_after_preparation(positive, target_model=current_refiner.model)
-        negative_refiner = clip_separate_after_preparation(negative, target_model=current_refiner.model)
-
-        positive_refiner = encode_model_conds(current_refiner.model.extra_conds, positive_refiner, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask)
-        negative_refiner = encode_model_conds(current_refiner.model.extra_conds, negative_refiner, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask)
-
-    def refiner_switch():
-        cleanup_additional_models(set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control")))
-
-        extra_args["cond"] = positive_refiner
-        extra_args["uncond"] = negative_refiner
-
-        # clear ip-adapter for refiner
-        extra_args['model_options'] = {k: {} if k == 'transformer_options' else v for k, v in extra_args['model_options'].items()}
-
-        models, inference_memory = get_additional_models(positive_refiner, negative_refiner, current_refiner.model_dtype())
-        ldm_patched.modules.model_management.load_models_gpu(
-            [current_refiner] + models,
-            model.memory_required([noise.shape[0] * 2] + list(noise.shape[1:])) + inference_memory)
-
-        model_wrap.inner_model = current_refiner.model
-        print('Refiner Swapped')
-        return
-
-    def callback_wrap(step, x0, x, total_steps):
-        if step == refiner_switch_step and current_refiner is not None:
-            refiner_switch()
-        if callback is not None:
-            # residual_noise_preview = x - x0
-            # residual_noise_preview /= residual_noise_preview.std()
-            # residual_noise_preview *= x0.std()
-            callback(step, x0, x, total_steps)
-
-    samples = sampler.sample(model_wrap, sigmas, extra_args, callback_wrap, noise, latent_image, denoise_mask, disable_pbar)
-    return model.process_latent_out(samples.to(torch.float32))
+    cfg_guider = CFGGuiderHacked(model)
+    cfg_guider.set_conds(positive, negative)
+    # TODO: cfg_guider.inner_set_conds({"positive": positive})  # negative is not necessary in Fooocus, 0.5s faster.
+    cfg_guider.set_cfg(cfg)
+    return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)


+class CFGGuiderHacked(CFGGuider):
+    def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed):
+        global current_refiner
+
+        if latent_image is not None and torch.count_nonzero(latent_image) > 0:  # Don't shift the empty latent image.
+            latent_image = self.inner_model.process_latent_in(latent_image)
+
+        self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed)
+
+        extra_args = {"model_options": self.model_options, "seed": seed}
+
+        if current_refiner is not None and hasattr(current_refiner.model, 'extra_conds'):
+            positive_refiner = clip_separate_after_preparation(self.conds['positive'], target_model=current_refiner.model)
+            negative_refiner = clip_separate_after_preparation(self.conds['negative'], target_model=current_refiner.model)
+
+            positive_refiner = encode_model_conds(current_refiner.model.extra_conds, positive_refiner, noise, device,
+                                                  "positive", latent_image=latent_image, denoise_mask=denoise_mask)
+            negative_refiner = encode_model_conds(current_refiner.model.extra_conds, negative_refiner, noise, device,
+                                                  "negative", latent_image=latent_image, denoise_mask=denoise_mask)
+
+        def refiner_switch():
+            cleanup_additional_models(
+                set(get_models_from_cond(self.conds['positive'], "control") + get_models_from_cond(self.conds['negative'], "control")))
+
+            # extra_args["cond"] = positive_refiner
+            # extra_args["uncond"] = negative_refiner
+            self.set_conds([[None, positive_refiner[0]]], [[None, negative_refiner[0]]])
+
+            # clear ip-adapter for refiner
+            extra_args['model_options'] = {k: {} if k == 'transformer_options' else v for k, v in
+                                           extra_args['model_options'].items()}
+
+            # current_refiner is either a bare SDXL model object or a ModelPatcher
+            model_dtype = None
+            if isinstance(current_refiner, ModelPatcher):
+                model_dtype = current_refiner.model_dtype()
+            else:
+                model_dtype = current_refiner.get_dtype()
+            # models, inference_memory = get_additional_models(positive_refiner, negative_refiner)
+            models, inference_memory = get_additional_models(self.conds, model_dtype)
+            ldm_patched.modules.model_management.load_models_gpu(
+                [current_refiner] + models,
+                self.model_patcher.memory_required([noise.shape[0] * 2] + list(noise.shape[1:])) + inference_memory)
+
+            self.inner_model = current_refiner.model
+            # the conds must be re-processed against the new inner_model
+            self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed)
+            print('Refiner Swapped')
+            return
+
+        def callback_wrap(step, x0, x, total_steps):
+            if step == refiner_switch_step and current_refiner is not None:
+                refiner_switch()
+            if callback is not None:
+                # residual_noise_preview = x - x0
+                # residual_noise_preview /= residual_noise_preview.std()
+                # residual_noise_preview *= x0.std()
+                callback(step, x0, x, total_steps)
+
+        samples = sampler.sample(self, sigmas, extra_args, callback_wrap, noise, latent_image, denoise_mask, disable_pbar)
+        return self.inner_model.process_latent_out(samples.to(torch.float32))


 @torch.no_grad()
 @torch.inference_mode()
 def calculate_sigmas_scheduler_hacked(model, scheduler_name, steps):
+    # model is either a bare SDXL model object or a ModelPatcher
+    # sys_dump_pythonobj(model, False, "- calculate_sigmas_scheduler_hacked model")
+    if isinstance(model, ModelPatcher):
+        model_sampling = model.get_model_object("model_sampling")
+    else:
+        model_sampling = model.model_sampling
     if scheduler_name == "karras":
-        sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max))
+        sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max))
     elif scheduler_name == "exponential":
-        sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max))
+        sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max))
     elif scheduler_name == "normal":
-        sigmas = normal_scheduler(model, steps)
+        sigmas = normal_scheduler(model_sampling, steps)
     elif scheduler_name == "simple":
-        sigmas = simple_scheduler(model, steps)
+        sigmas = simple_scheduler(model_sampling, steps)
     elif scheduler_name == "ddim_uniform":
-        sigmas = ddim_scheduler(model, steps)
+        sigmas = ddim_scheduler(model_sampling, steps)
     elif scheduler_name == "sgm_uniform":
-        sigmas = normal_scheduler(model, steps, sgm=True)
+        sigmas = normal_scheduler(model_sampling, steps, sgm=True)
     elif scheduler_name == "turbo":
         sigmas = SDTurboScheduler().get_sigmas(model=model, steps=steps, denoise=1.0)[0]
     elif scheduler_name == "align_your_steps":
@@ -183,6 +186,6 @@ def calculate_sigmas_scheduler_hacked(model, scheduler_name, steps):
         raise TypeError("error invalid scheduler")

     return sigmas
-
-ldm_patched.modules.samplers.calculate_sigmas_scheduler = calculate_sigmas_scheduler_hacked
+# used-by-Fooocus
+ldm_patched.modules.samplers.calculate_sigmas = calculate_sigmas_scheduler_hacked
 ldm_patched.modules.samplers.sample = sample_hacked
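Both refiner_switch and calculate_sigmas_scheduler_hacked now have to accept either a bare model object or a ModelPatcher wrapper, so the same isinstance dispatch appears twice. A minimal sketch of that pattern (the helper name is hypothetical):

from ldm_patched.modules.model_patcher import ModelPatcher

def resolve_model_sampling(model):
    # ModelPatcher wraps the real model and exposes attributes via an accessor;
    # bare model objects carry model_sampling directly.
    if isinstance(model, ModelPatcher):
        return model.get_model_object("model_sampling")
    return model.model_sampling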
diff --git a/modules/util.py b/modules/util.py
index 30b9f4d18..53034dd6b 100644
--- a/modules/util.py
+++ b/modules/util.py
@@ -513,3 +513,14 @@ def get_image_size_info(image: np.ndarray, aspect_ratios: list) -> str:
         return size_info
     except Exception as e:
         return f'Error reading image: {e}'
+
+def sys_dump_pythonobj(obj, withValue, hintStr=None):
+    if hintStr is None:
+        hintStr = "- object Dump:"
+    print(hintStr, type(obj))
+    for attr in dir(obj):
+        if hasattr(obj, attr):
+            if withValue:
+                print("...%s = %s" % (attr, getattr(obj, attr)))
+            else:
+                print("...%s = ???" % attr)
\ No newline at end of file
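A quick usage sketch of this debugging helper; the second argument toggles whether attribute values are printed, and the commented-out call in sample_hijack.py uses it the same way (the dumped object here is a stand-in):

from modules.util import sys_dump_pythonobj

class StandIn:
    model_sampling = None  # hypothetical attribute, just to produce output

# Dump attribute names only; pass withValue=True to also print each value.
sys_dump_pythonobj(StandIn(), False, "- example dump")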
diff --git a/requirements_versions.txt b/requirements_versions.txt
index 9147db47f..486be4293 100644
--- a/requirements_versions.txt
+++ b/requirements_versions.txt
@@ -21,4 +21,6 @@ tokenizers==0.19.1
 packaging==24.1
 rembg==2.0.57
 groundingdino-py==0.4.0
-segment_anything==1.0
\ No newline at end of file
+segment_anything==1.0
+kornia==0.7.2
+spandrel==0.3.4
\ No newline at end of file
diff --git a/webui.py b/webui.py
index b8159d855..0a0568853 100644
--- a/webui.py
+++ b/webui.py
@@ -1116,13 +1116,19 @@ def dump_default_english_config():

 # dump_default_english_config()

-
 shared.gradio_root.launch(
     inbrowser=args_manager.args.in_browser,
     server_name=args_manager.args.listen,
     server_port=args_manager.args.port,
     share=args_manager.args.share,
     auth=check_auth if (args_manager.args.share or args_manager.args.listen) and auth_enabled else None,
+    auth_message=args_manager.args.auth_message if (args_manager.args.share or args_manager.args.listen) and auth_enabled else None,
     allowed_paths=[modules.config.path_outputs],
-    blocked_paths=[constants.AUTH_FILENAME]
+    blocked_paths=[constants.AUTH_FILENAME],
+    debug=args_manager.args.verbose,
+    ssl_keyfile=args_manager.args.tls_keyfile,
+    ssl_keyfile_password=args_manager.args.tls_keyfile_password,
+    ssl_certfile=args_manager.args.tls_certfile,
+    ssl_verify=args_manager.args.tls_verify,
+    favicon_path=args_manager.args.favicon_path
 )
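The launch call forwards the new CLI flags straight into gradio's Blocks.launch; auth and auth_message are gated on the same share/listen condition. A minimal standalone sketch of that gating (all values hypothetical; auth, auth_message, ssl_verify and favicon_path are standard gradio launch keyword arguments):

import gradio as gr

with gr.Blocks() as demo:
    gr.Markdown("placeholder UI")

requires_auth = True  # stands in for (share or listen) and auth_enabled

demo.launch(
    auth=(lambda user, pwd: user == "admin" and pwd == "secret") if requires_auth else None,
    auth_message="Please sign in." if requires_auth else None,
    ssl_verify=True,    # mirrors --tls-verify
    favicon_path=None,  # mirrors --favicon-path
)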