Commit c441048

Make VAE Encode tiled node work with video VAE.
1 parent 9f4b181 commit c441048

3 files changed: 70 additions, 17 deletions

comfy/sd.py (48 additions, 8 deletions)
@@ -336,6 +336,7 @@ def __init__(self, sd=None, device=None, config=None, dtype=None):
                 self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
                 self.memory_used_encode = lambda shape, dtype: (1.5 * max(shape[2], 7) * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
                 self.upscale_ratio = (lambda a: max(0, a * 6 - 5), 8, 8)
+                self.downscale_ratio = (lambda a: max(0, (a + 3) / 6), 8, 8)
                 self.working_dtypes = [torch.float16, torch.float32]
             elif "decoder.up_blocks.0.res_blocks.0.conv1.conv.weight" in sd: #lightricks ltxv
                 self.first_stage_model = comfy.ldm.lightricks.vae.causal_video_autoencoder.VideoVAE()
@@ -344,12 +345,14 @@ def __init__(self, sd=None, device=None, config=None, dtype=None):
                 self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
                 self.memory_used_encode = lambda shape, dtype: (70 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
                 self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 32, 32)
+                self.downscale_ratio = (lambda a: max(0, (a + 4) / 8), 32, 32)
                 self.working_dtypes = [torch.bfloat16, torch.float32]
             elif "decoder.conv_in.conv.weight" in sd:
                 ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
                 ddconfig["conv3d"] = True
                 ddconfig["time_compress"] = 4
                 self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
+                self.downscale_ratio = (lambda a: max(0, (a + 2) / 4), 8, 8)
                 self.latent_dim = 3
                 self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
                 self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
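
Each new downscale_ratio mirrors the upscale_ratio of its branch: a causal video VAE decodes latent frame count a to max(0, a*k - (k-1)) pixel frames, and the added lambdas invert that mapping up to the rounding the tiling code applies. A standalone sanity check, not part of the commit:

```python
# Round-trip check for the three temporal (upscale, downscale) lambda pairs
# added above; round() mirrors how tiled_scale_multidim treats callable scales.
pairs = [
    (lambda a: max(0, a * 6 - 5), lambda a: max(0, (a + 3) / 6)),  # 6x time compression
    (lambda a: max(0, a * 8 - 7), lambda a: max(0, (a + 4) / 8)),  # 8x (ltxv branch)
    (lambda a: max(0, a * 4 - 3), lambda a: max(0, (a + 2) / 4)),  # 4x (conv3d branch)
]
for up, down in pairs:
    for latent_frames in range(1, 8):
        assert round(down(up(latent_frames))) == latent_frames
```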
@@ -385,10 +388,12 @@ def __init__(self, sd=None, device=None, config=None, dtype=None):
         logging.debug("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))

     def vae_encode_crop_pixels(self, pixels):
+        downscale_ratio = self.spacial_compression_encode()
+
         dims = pixels.shape[1:-1]
         for d in range(len(dims)):
-            x = (dims[d] // self.downscale_ratio) * self.downscale_ratio
-            x_offset = (dims[d] % self.downscale_ratio) // 2
+            x = (dims[d] // downscale_ratio) * downscale_ratio
+            x_offset = (dims[d] % downscale_ratio) // 2
             if x != dims[d]:
                 pixels = pixels.narrow(d + 1, x_offset, x)
         return pixels
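
vae_encode_crop_pixels now asks the VAE for its spatial compression instead of assuming downscale_ratio is a plain integer, since it can be a tuple for video VAEs. A minimal sketch of the cropping behavior, assuming a compression factor of 8:

```python
import torch

def crop_to_multiple(pixels, downscale_ratio=8):
    # Truncate each spatial dim (NHWC layout) to a multiple of the compression
    # factor, dropping pixels evenly from both sides to keep the crop centered.
    dims = pixels.shape[1:-1]
    for d in range(len(dims)):
        x = (dims[d] // downscale_ratio) * downscale_ratio
        x_offset = (dims[d] % downscale_ratio) // 2
        if x != dims[d]:
            pixels = pixels.narrow(d + 1, x_offset, x)
    return pixels

print(crop_to_multiple(torch.zeros(1, 517, 780, 3)).shape)  # torch.Size([1, 512, 776, 3])
```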
@@ -409,7 +414,7 @@ def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):

     def decode_tiled_1d(self, samples, tile_x=128, overlap=32):
         decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
-        return comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device)
+        return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))

     def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
         decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
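
The 1d path previously returned raw decoder output; it now goes through process_output like the other decode paths. A sketch of what that postprocessing typically does in this codebase, stated as an assumption rather than taken from the diff:

```python
import torch

def process_output(image):
    # Assumed behavior: map the model's [-1, 1] output range to [0, 1] and clamp.
    return torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)

print(process_output(torch.tensor([-1.0, 0.0, 1.0, 2.0])))  # tensor([0.0000, 0.5000, 1.0000, 1.0000])
```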
@@ -432,6 +437,10 @@ def encode_tiled_1d(self, samples, tile_x=128 * 2048, overlap=32 * 2048):
         encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
         return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device)

+    def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)):
+        encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
+        return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, output_device=self.output_device)
+
     def decode(self, samples_in):
         pixel_samples = None
         try:
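
The new encode_tiled_3d tiles over (T, H, W) in pixel space; with the default tile_t=9999 the time axis always fits in one tile, so only height and width are actually split. A rough sketch of the resulting tile grid, mirroring the position arithmetic in comfy.utils.tiled_scale_multidim (input sizes below are illustrative):

```python
def tile_positions(size, tile, overlap):
    # Start positions step by (tile - overlap) and are clamped so the last
    # tile does not run off the end, as in tiled_scale_multidim.
    return [max(0, min(size - (overlap + 1), pos))
            for pos in range(0, size, tile - overlap)]

shape = (49, 1024, 1536)                       # (T, H, W) pixel video
tile, overlap = (9999, 512, 512), (1, 64, 64)  # encode_tiled_3d defaults for x/y
grid = [tile_positions(s, t, o) for s, t, o in zip(shape, tile, overlap)]
print([len(g) for g in grid])                  # [1, 3, 4] tiles along T, H, W
```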
@@ -504,18 +513,43 @@ def encode(self, pixel_samples):

         except model_management.OOM_EXCEPTION:
             logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
-            if len(pixel_samples.shape) == 3:
+            if self.latent_dim == 3:
+                tile = 256
+                overlap = tile // 4
+                samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
+            elif self.latent_dim == 1:
                 samples = self.encode_tiled_1d(pixel_samples)
             else:
                 samples = self.encode_tiled_(pixel_samples)

         return samples

-    def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
+    def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None):
         pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
-        model_management.load_model_gpu(self.patcher)
-        pixel_samples = pixel_samples.movedim(-1,1)
-        samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap)
+        dims = self.latent_dim
+        pixel_samples = pixel_samples.movedim(-1, 1)
+        if dims == 3:
+            pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
+
+        memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) # TODO: calculate mem required for tile
+        model_management.load_models_gpu([self.patcher], memory_required=memory_used)
+
+        args = {}
+        if tile_x is not None:
+            args["tile_x"] = tile_x
+        if tile_y is not None:
+            args["tile_y"] = tile_y
+        if overlap is not None:
+            args["overlap"] = overlap
+
+        if dims == 1:
+            args.pop("tile_y")
+            samples = self.encode_tiled_1d(pixel_samples, **args)
+        elif dims == 2:
+            samples = self.encode_tiled_(pixel_samples, **args)
+        elif dims == 3:
+            samples = self.encode_tiled_3d(pixel_samples, **args)

         return samples

     def get_sd(self):
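
encode_tiled now normalizes input shapes itself: ComfyUI images arrive as (frames, H, W, C), which 2d VAEs consume as an (N, C, H, W) batch, while a video VAE needs a single (1, C, T, H, W) clip. A shape-only sketch of the two movedim steps above:

```python
import torch

frames = torch.zeros(49, 512, 768, 3)    # 49 RGB frames as ComfyUI provides them
nchw = frames.movedim(-1, 1)             # (49, 3, 512, 768): batch of images, 2d path
video = nchw.movedim(1, 0).unsqueeze(0)  # (1, 3, 49, 512, 768): one clip, 3d path
print(nchw.shape, video.shape)
```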
@@ -527,6 +561,12 @@ def spacial_compression_decode(self):
         except:
             return self.upscale_ratio

+    def spacial_compression_encode(self):
+        try:
+            return self.downscale_ratio[-1]
+        except:
+            return self.downscale_ratio
+
 class StyleModel:
     def __init__(self, model, device="cpu"):
         self.model = model
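
The try/except mirrors spacial_compression_decode: downscale_ratio is a plain int for image VAEs but a (temporal_fn, h, w) tuple for video VAEs, and indexing the int raises TypeError. A standalone illustration (catching TypeError explicitly where the source uses a bare except):

```python
def spacial_compression(downscale_ratio):
    # Tuple (video VAE): the last element is the spatial factor.
    # Int (image VAE): indexing raises TypeError, so return the scalar as-is.
    try:
        return downscale_ratio[-1]
    except TypeError:
        return downscale_ratio

print(spacial_compression(8))                                      # 8
print(spacial_compression((lambda a: max(0, (a + 3) / 6), 8, 8)))  # 8
```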

comfy/utils.py (17 additions, 5 deletions)
@@ -751,7 +751,7 @@ def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
     return rows * cols

 @torch.inference_mode()
-def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
+def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", downscale=False, pbar=None):
     dims = len(tile)

     if not (isinstance(upscale_amount, (tuple, list))):
@@ -767,10 +767,22 @@ def get_upscale(dim, val):
         else:
             return up * val

+    def get_downscale(dim, val):
+        up = upscale_amount[dim]
+        if callable(up):
+            return up(val)
+        else:
+            return val / up
+
+    if downscale:
+        get_scale = get_downscale
+    else:
+        get_scale = get_upscale
+
     def mult_list_upscale(a):
         out = []
         for i in range(len(a)):
-            out.append(round(get_upscale(i, a[i])))
+            out.append(round(get_scale(i, a[i])))
         return out

     output = torch.empty([samples.shape[0], out_channels] + mult_list_upscale(samples.shape[2:]), device=output_device)
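
With downscale=True the same upscale_amount argument is reinterpreted for encoding: numeric entries divide instead of multiply, while callable entries (the temporal term of a video VAE) are applied as-is, since the caller already passes the direction-appropriate lambda. A sketch assuming the ltxv-style ratio from comfy/sd.py above:

```python
upscale_amount = (lambda a: max(0, (a + 4) / 8), 32, 32)  # assumed ltxv-style ratio

def get_scale(dim, val, downscale):
    up = upscale_amount[dim]
    if callable(up):
        return up(val)                          # the lambda already encodes direction
    return val / up if downscale else up * val

# Encoding a (49, 512, 768) pixel video allocates a (7, 16, 24) latent grid:
print([round(get_scale(d, v, downscale=True)) for d, v in enumerate((49, 512, 768))])
```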
@@ -798,13 +810,13 @@ def mult_list_upscale(a):
                 pos = max(0, min(s.shape[d + 2] - (overlap[d] + 1), it[d]))
                 l = min(tile[d], s.shape[d + 2] - pos)
                 s_in = s_in.narrow(d + 2, pos, l)
-                upscaled.append(round(get_upscale(d, pos)))
+                upscaled.append(round(get_scale(d, pos)))

             ps = function(s_in).to(output_device)
             mask = torch.ones_like(ps)

             for d in range(2, dims + 2):
-                feather = round(get_upscale(d - 2, overlap[d - 2]))
+                feather = round(get_scale(d - 2, overlap[d - 2]))
                 if feather >= mask.shape[d]:
                     continue
                 for t in range(feather):
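
feather is now computed with get_scale, so during encoding the blend width is expressed in latent units rather than pixel units. The feathering itself, visible in the surrounding context, appears to ramp the tile mask linearly at both edges; a minimal one-dimensional illustration under that assumption:

```python
import torch

mask = torch.ones(1, 1, 8)   # per-tile weight mask along one tiled axis
feather = 3
for t in range(feather):
    a = (t + 1) / feather    # 1/3, 2/3, 1.0 ramp from the border inward
    mask.narrow(-1, t, 1).mul_(a)
    mask.narrow(-1, mask.shape[-1] - 1 - t, 1).mul_(a)
print(mask)  # tensor([[[0.3333, 0.6667, 1.0000, 1.0000, 1.0000, 1.0000, 0.6667, 0.3333]]])
```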
@@ -828,7 +840,7 @@ def mult_list_upscale(a):
     return output

 def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
-    return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap, upscale_amount, out_channels, output_device, pbar)
+    return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar)

 PROGRESS_BAR_ENABLED = True
 def set_progress_bar_enabled(enabled):
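
Switching tiled_scale to keyword arguments is not cosmetic: the new downscale parameter was inserted before pbar, so the old positional call would have silently bound the progress bar to downscale. A toy demonstration of the pitfall:

```python
def multidim(samples, function, tile=(64, 64), overlap=8, upscale_amount=4,
             out_channels=3, output_device="cpu", downscale=False, pbar=None):
    return downscale, pbar  # just echo the two trailing parameters

sentinel_pbar = object()
print(multidim(None, None, (64, 64), 8, 4, 3, "cpu", sentinel_pbar))  # pbar lands in downscale
print(multidim(None, None, (64, 64), pbar=sentinel_pbar))             # keywords bind correctly
```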

nodes.py (5 additions, 4 deletions)
@@ -291,7 +291,7 @@ class VAEDecodeTiled:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": {"samples": ("LATENT", ), "vae": ("VAE", ),
-                             "tile_size": ("INT", {"default": 512, "min": 128, "max": 4096, "step": 32}),
+                             "tile_size": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 32}),
                              "overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32}),
                              }}
     RETURN_TYPES = ("IMAGE",)
@@ -325,15 +325,16 @@ class VAEEncodeTiled:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": {"pixels": ("IMAGE", ), "vae": ("VAE", ),
-                             "tile_size": ("INT", {"default": 512, "min": 320, "max": 4096, "step": 64})
+                             "tile_size": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 64}),
+                             "overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32}),
                              }}
     RETURN_TYPES = ("LATENT",)
     FUNCTION = "encode"

     CATEGORY = "_for_testing"

-    def encode(self, vae, pixels, tile_size):
-        t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, )
+    def encode(self, vae, pixels, tile_size, overlap):
+        t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, overlap=overlap)
         return ({"samples":t}, )

 class VAEEncodeForInpaint:
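
The node now exposes overlap and forwards it to vae.encode_tiled; both values are in pixels. As a rough guide to what a setting costs, the tile count for one frame can be estimated the way comfy.utils.get_tiled_scale_steps does (a sketch mirroring that formula, not the node's own code):

```python
import math

def encode_tiles(width, height, tile, overlap):
    # One tile if the image fits, else ceiling division with overlapping strides.
    rows = 1 if height <= tile else math.ceil((height - overlap) / (tile - overlap))
    cols = 1 if width <= tile else math.ceil((width - overlap) / (tile - overlap))
    return rows * cols

print(encode_tiles(1024, 1024, tile=512, overlap=64))  # 9 tiles per frame at the defaults
```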
