@@ -336,6 +336,7 @@ def __init__(self, sd=None, device=None, config=None, dtype=None):
                 self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
                 self.memory_used_encode = lambda shape, dtype: (1.5 * max(shape[2], 7) * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
                 self.upscale_ratio = (lambda a: max(0, a * 6 - 5), 8, 8)
+                self.downscale_ratio = (lambda a: max(0, (a + 3) / 6), 8, 8)
                 self.working_dtypes = [torch.float16, torch.float32]
             elif "decoder.up_blocks.0.res_blocks.0.conv1.conv.weight" in sd: #lightricks ltxv
                 self.first_stage_model = comfy.ldm.lightricks.vae.causal_video_autoencoder.VideoVAE()
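
Note (editor): the new downscale_ratio follows the same (temporal_fn, height, width) tuple convention as the existing upscale_ratio: the callable maps a frame count along the time axis, and the two integers are the spatial compression factors. The (a + 3) / 6 form is an approximate inverse of the decoder's a * 6 - 5 frame expansion. A standalone sketch, not part of the patch (helper names are illustrative):

upscale_t = lambda a: max(0, a * 6 - 5)      # latent frames -> pixel frames
downscale_t = lambda a: max(0, (a + 3) / 6)  # pixel frames -> approx. latent frames

pixel_frames = upscale_t(9)       # 9 * 6 - 5 = 49
print(pixel_frames)               # 49
print(downscale_t(pixel_frames))  # prints 8.666..., an approximate inverse of 9
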
@@ -344,12 +345,14 @@ def __init__(self, sd=None, device=None, config=None, dtype=None):
                 self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
                 self.memory_used_encode = lambda shape, dtype: (70 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
                 self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 32, 32)
+                self.downscale_ratio = (lambda a: max(0, (a + 4) / 8), 32, 32)
                 self.working_dtypes = [torch.bfloat16, torch.float32]
             elif "decoder.conv_in.conv.weight" in sd:
                 ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
                 ddconfig["conv3d"] = True
                 ddconfig["time_compress"] = 4
                 self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
+                self.downscale_ratio = (lambda a: max(0, (a + 2) / 4), 8, 8)
                 self.latent_dim = 3
                 self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
                 self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
@@ -385,10 +388,12 @@ def __init__(self, sd=None, device=None, config=None, dtype=None):
         logging.debug("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
 
     def vae_encode_crop_pixels(self, pixels):
+        downscale_ratio = self.spacial_compression_encode()
+
         dims = pixels.shape[1:-1]
         for d in range(len(dims)):
-            x = (dims[d] // self.downscale_ratio) * self.downscale_ratio
-            x_offset = (dims[d] % self.downscale_ratio) // 2
+            x = (dims[d] // downscale_ratio) * downscale_ratio
+            x_offset = (dims[d] % downscale_ratio) // 2
             if x != dims[d]:
                 pixels = pixels.narrow(d + 1, x_offset, x)
         return pixels
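
Note (editor): with the stride now taken from spacial_compression_encode() (added later in this diff), the crop keeps every spatial dimension a multiple of the encoder's compression factor and centers the crop by splitting the remainder. A self-contained sketch of the same arithmetic on a channels-last batch, not part of the patch (shapes are illustrative):

import torch

def crop_to_multiple(pixels, stride):
    # Center-crop each dim between batch and channels down to the nearest
    # multiple of `stride`, mirroring vae_encode_crop_pixels above.
    dims = pixels.shape[1:-1]
    for d in range(len(dims)):
        x = (dims[d] // stride) * stride    # largest fitting multiple
        x_offset = (dims[d] % stride) // 2  # split the remainder evenly
        if x != dims[d]:
            pixels = pixels.narrow(d + 1, x_offset, x)
    return pixels

img = torch.zeros(1, 517, 771, 3)      # H and W not divisible by 8
print(crop_to_multiple(img, 8).shape)  # torch.Size([1, 512, 768, 3])
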
@@ -409,7 +414,7 @@ def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
 
     def decode_tiled_1d(self, samples, tile_x=128, overlap=32):
         decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
-        return comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device)
+        return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))
 
     def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
         decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
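
Note (editor): the decode_tiled_1d change wraps the stitched result in self.process_output, so the tiled 1D path now applies the same output postprocessing as a non-tiled decode instead of returning the raw stitched tensor. For image VAEs that default, set in this class's __init__, maps decoder output from [-1, 1] to [0, 1]; shown standalone for reference:

import torch

# Default output postprocess from the VAE constructor (not part of this patch).
process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
print(process_output(torch.tensor([-2.0, 0.0, 2.0])))  # tensor([0.0000, 0.5000, 1.0000])
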
@@ -432,6 +437,10 @@ def encode_tiled_1d(self, samples, tile_x=128 * 2048, overlap=32 * 2048):
         encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
         return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=(1 / self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device)
 
+    def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)):
+        encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
+        return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, output_device=self.output_device)
+
     def decode(self, samples_in):
         pixel_samples = None
         try:
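
Note (editor): encode_tiled_3d is the encode-side counterpart of decode_tiled_3d. The downscale=True flag tells comfy.utils.tiled_scale_multidim that the model shrinks its input, so each tile's result is placed at coordinates divided by the downscale ratio rather than multiplied by an upscale amount. A toy 1D version of that placement logic, not part of the patch (non-overlapping tiles, with average pooling standing in for the encoder; the real helper also feathers overlapping tiles):

import torch

def tiled_downscale_1d(x, fn, tile, ratio):
    # Run `fn` (which shrinks its input by `ratio`) one tile at a time and
    # write each result at the downscaled coordinate.
    out = torch.empty(x.shape[0], x.shape[1] // ratio)
    for start in range(0, x.shape[1], tile):
        piece = x[:, start:start + tile]
        out[:, start // ratio:(start + piece.shape[1]) // ratio] = fn(piece)
    return out

x = torch.arange(32.0).reshape(1, 32)
pool = lambda t: torch.nn.functional.avg_pool1d(t.unsqueeze(1), 8).squeeze(1)
print(torch.equal(tiled_downscale_1d(x, pool, tile=16, ratio=8), pool(x)))  # True
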
@@ -504,18 +513,43 @@ def encode(self, pixel_samples):
 
         except model_management.OOM_EXCEPTION:
             logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
-            if len(pixel_samples.shape) == 3:
+            if self.latent_dim == 3:
+                tile = 256
+                overlap = tile // 4
+                samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
+            elif self.latent_dim == 1:
                 samples = self.encode_tiled_1d(pixel_samples)
             else:
                 samples = self.encode_tiled_(pixel_samples)
 
         return samples
 
-    def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
+    def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None):
         pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
-        model_management.load_model_gpu(self.patcher)
-        pixel_samples = pixel_samples.movedim(-1,1)
-        samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap)
+        dims = self.latent_dim
+        pixel_samples = pixel_samples.movedim(-1, 1)
+        if dims == 3:
+            pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
+
+        memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)  # TODO: calculate mem required for tile
+        model_management.load_models_gpu([self.patcher], memory_required=memory_used)
+
+        args = {}
+        if tile_x is not None:
+            args["tile_x"] = tile_x
+        if tile_y is not None:
+            args["tile_y"] = tile_y
+        if overlap is not None:
+            args["overlap"] = overlap
+
+        if dims == 1:
+            args.pop("tile_y")
+            samples = self.encode_tiled_1d(pixel_samples, **args)
+        elif dims == 2:
+            samples = self.encode_tiled_(pixel_samples, **args)
+        elif dims == 3:
+            samples = self.encode_tiled_3d(pixel_samples, **args)
+
         return samples
 
     def get_sd(self):
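
Note (editor): encode_tiled now defaults its tile arguments to None and forwards only the values the caller actually set, so each latent_dim branch keeps its own method defaults (for instance encode_tiled_1d's much larger audio-sized tiles). Hypothetical caller-side usage, assuming `vae` is a loaded comfy.sd.VAE and `pixels` is a channels-last batch (parameter values are illustrative):

latent = vae.encode_tiled(pixels)                          # per-model defaults
latent = vae.encode_tiled(pixels, tile_x=256, tile_y=256)  # explicit spatial tiles
latent = vae.encode_tiled(pixels, overlap=32)              # override only the overlap
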
@@ -527,6 +561,12 @@ def spacial_compression_decode(self):
         except:
             return self.upscale_ratio
 
+    def spacial_compression_encode(self):
+        try:
+            return self.downscale_ratio[-1]
+        except:
+            return self.downscale_ratio
+
 class StyleModel:
     def __init__(self, model, device="cpu"):
         self.model = model
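
Note (editor): spacial_compression_encode mirrors spacial_compression_decode above it. For the video VAEs in this diff, downscale_ratio is a (temporal_fn, h, w) tuple, so indexing [-1] returns the spatial factor; for plain image VAEs it is an int, the indexing raises TypeError, and the bare except falls back to the scalar itself. A standalone illustration (values taken from the diff):

video_ratio = (lambda a: max(0, (a + 4) / 8), 32, 32)  # ltxv: tuple, [-1] gives 32
image_ratio = 8                                        # classic image VAE: plain int

def spatial(ratio):
    try:
        return ratio[-1]  # tuple case
    except TypeError:
        return ratio      # scalar / legacy case

print(spatial(video_ratio), spatial(image_ratio))      # 32 8
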