@@ -40,6 +40,8 @@ class HunyuanVideoParams:
     patch_size: list
     qkv_bias: bool
     guidance_embed: bool
+    byt5: bool
+    meanflow: bool
 
 
 class SelfAttentionRef(nn.Module):
@@ -161,6 +163,30 @@ def forward(
         x = self.individual_token_refiner(x, c, mask)
         return x
 
+
+class ByT5Mapper(nn.Module):
+    def __init__(self, in_dim, out_dim, hidden_dim, out_dim1, use_res=False, dtype=None, device=None, operations=None):
+        super().__init__()
+        self.layernorm = operations.LayerNorm(in_dim, dtype=dtype, device=device)
+        self.fc1 = operations.Linear(in_dim, hidden_dim, dtype=dtype, device=device)
+        self.fc2 = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device)
+        self.fc3 = operations.Linear(out_dim, out_dim1, dtype=dtype, device=device)
+        self.use_res = use_res
+        self.act_fn = nn.GELU()
+
+    def forward(self, x):
+        if self.use_res:
+            res = x
+        x = self.layernorm(x)
+        x = self.fc1(x)
+        x = self.act_fn(x)
+        x = self.fc2(x)
+        x2 = self.act_fn(x)
+        x2 = self.fc3(x2)
+        if self.use_res:
+            x2 = x2 + res
+        return x2
+
 class HunyuanVideo(nn.Module):
     """
     Transformer model for flow matching on sequences.
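
ByT5Mapper is a small LayerNorm → Linear → GELU → Linear → GELU → Linear stack (with an optional residual) that lifts the ByT5 text features up to the transformer width. A quick standalone shape check of that path, assuming plain `torch.nn` can stand in for the comfy `operations` namespace, with the 1472/2048 dimensions taken from the `byt5_in` wiring further down and an illustrative `hidden_size` of 3072:

```python
# Shape sanity check for the ByT5 mapping path (standalone sketch; torch.nn stands in
# for the comfy `operations` namespace, and 3072 is an illustrative hidden_size).
import torch
import torch.nn as nn

in_dim, hidden_dim, out_dim, out_dim1 = 1472, 2048, 2048, 3072

layernorm = nn.LayerNorm(in_dim)
fc1 = nn.Linear(in_dim, hidden_dim)
fc2 = nn.Linear(hidden_dim, out_dim)
fc3 = nn.Linear(out_dim, out_dim1)
act = nn.GELU()

txt_byt5 = torch.randn(1, 32, in_dim)   # (batch, byt5 tokens, byt5 feature width)
x = act(fc1(layernorm(txt_byt5)))        # 1472 -> 2048
x2 = fc3(act(fc2(x)))                    # 2048 -> 2048 -> 3072; use_res=False, so no skip
print(x2.shape)                          # torch.Size([1, 32, 3072])
```
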
@@ -185,9 +211,13 @@ def __init__(self, image_model=None, final_layer=True, dtype=None, device=None,
         self.num_heads = params.num_heads
         self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
 
-        self.img_in = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(None, self.patch_size, self.in_channels, self.hidden_size, conv3d=True, dtype=dtype, device=device, operations=operations)
+        self.img_in = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(None, self.patch_size, self.in_channels, self.hidden_size, conv3d=len(self.patch_size) == 3, dtype=dtype, device=device, operations=operations)
         self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
-        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
+        if params.vec_in_dim is not None:
+            self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
+        else:
+            self.vector_in = None
+
         self.guidance_in = (
             MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
         )
@@ -215,6 +245,23 @@ def __init__(self, image_model=None, final_layer=True, dtype=None, device=None,
             ]
         )
 
+        if params.byt5:
+            self.byt5_in = ByT5Mapper(
+                in_dim=1472,
+                out_dim=2048,
+                hidden_dim=2048,
+                out_dim1=self.hidden_size,
+                use_res=False,
+                dtype=dtype, device=device, operations=operations
+            )
+        else:
+            self.byt5_in = None
+
+        if params.meanflow:
+            self.time_r_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
+        else:
+            self.time_r_in = None
+
         if final_layer:
             self.final_layer = LastLayer(self.hidden_size, self.patch_size[-1], self.out_channels, dtype=dtype, device=device, operations=operations)
 
@@ -226,7 +273,8 @@ def forward_orig(
         txt_ids: Tensor,
         txt_mask: Tensor,
         timesteps: Tensor,
-        y: Tensor,
+        y: Tensor = None,
+        txt_byt5=None,
         guidance: Tensor = None,
         guiding_frame_index=None,
         ref_latent=None,
@@ -240,6 +288,14 @@ def forward_orig(
         img = self.img_in(img)
         vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))
 
+        if self.time_r_in is not None:
+            w = torch.where(transformer_options['sigmas'][0] == transformer_options['sample_sigmas'])[0]  # This most likely could be improved
+            if len(w) > 0:
+                timesteps_r = transformer_options['sample_sigmas'][w[0] + 1]
+                timesteps_r = timesteps_r.unsqueeze(0).to(device=timesteps.device, dtype=timesteps.dtype)
+                vec_r = self.time_r_in(timestep_embedding(timesteps_r, 256, time_factor=1000.0).to(img.dtype))
+                vec = (vec + vec_r) / 2
+
         if ref_latent is not None:
             ref_latent_ids = self.img_ids(ref_latent)
             ref_latent = self.img_in(ref_latent)
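
When `params.meanflow` is set, the model additionally embeds the sigma of the *next* step in the schedule (with `time_factor=1000.0`) and averages it with the current timestep embedding via `vec = (vec + vec_r) / 2`. The lookup finds where the current sigma sits in `sample_sigmas` and takes the following entry. A small sketch of just that lookup, with illustrative values standing in for what ComfyUI passes through `transformer_options`:

```python
# Sketch of the next-sigma lookup used by the meanflow branch (values are illustrative;
# in the model, `sigmas` and `sample_sigmas` come from transformer_options).
import torch

sample_sigmas = torch.tensor([1.0, 0.75, 0.5, 0.25, 0.0])  # full schedule for this sampling run
sigmas = torch.tensor([0.75])                               # sigma of the current step

w = torch.where(sigmas[0] == sample_sigmas)[0]              # index of the current sigma in the schedule
if len(w) > 0:
    timesteps_r = sample_sigmas[w[0] + 1]                   # the next sigma, i.e. where this step lands
    print(timesteps_r)                                      # tensor(0.5000)
```
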
@@ -250,13 +306,17 @@
 
         if guiding_frame_index is not None:
             token_replace_vec = self.time_in(timestep_embedding(guiding_frame_index, 256, time_factor=1.0))
-            vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
-            vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1)
+            if self.vector_in is not None:
+                vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
+                vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1)
+            else:
+                vec = torch.cat([(token_replace_vec).unsqueeze(1), (vec).unsqueeze(1)], dim=1)
             frame_tokens = (initial_shape[-1] // self.patch_size[-1]) * (initial_shape[-2] // self.patch_size[-2])
             modulation_dims = [(0, frame_tokens, 0), (frame_tokens, None, 1)]
             modulation_dims_txt = [(0, None, 1)]
         else:
-            vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
+            if self.vector_in is not None:
+                vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
             modulation_dims = None
             modulation_dims_txt = None
 
@@ -269,6 +329,12 @@ def forward_orig(
 
         txt = self.txt_in(txt, timesteps, txt_mask)
 
+        if self.byt5_in is not None and txt_byt5 is not None:
+            txt_byt5 = self.byt5_in(txt_byt5)
+            txt_byt5_ids = torch.zeros((txt_ids.shape[0], txt_byt5.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
+            txt = torch.cat((txt, txt_byt5), dim=1)
+            txt_ids = torch.cat((txt_ids, txt_byt5_ids), dim=1)
+
         ids = torch.cat((img_ids, txt_ids), dim=1)
         pe = self.pe_embedder(ids)
 
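
The mapped ByT5 tokens are simply appended to the regular text tokens, and their rotary ids are all zeros so they share the same null position as the rest of the text stream. A shape-only sketch of that concatenation (illustrative token counts, `hidden_size` of 3072, and 3 rope axes assumed):

```python
# Sketch of appending the mapped ByT5 tokens to the text stream (shapes illustrative).
import torch

txt = torch.randn(1, 256, 3072)      # text tokens after txt_in
txt_ids = torch.zeros(1, 256, 3)     # text rope ids are all zero
txt_byt5 = torch.randn(1, 32, 3072)  # output of byt5_in

txt_byt5_ids = torch.zeros((txt_ids.shape[0], txt_byt5.shape[1], txt_ids.shape[-1]))
txt = torch.cat((txt, txt_byt5), dim=1)              # (1, 288, 3072)
txt_ids = torch.cat((txt_ids, txt_byt5_ids), dim=1)  # (1, 288, 3)
```
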
@@ -328,12 +394,16 @@ def block_wrap(args):
 
         img = self.final_layer(img, vec, modulation_dims=modulation_dims)  # (N, T, patch_size ** 2 * out_channels)
 
-        shape = initial_shape[-3:]
+        shape = initial_shape[-len(self.patch_size):]
         for i in range(len(shape)):
             shape[i] = shape[i] // self.patch_size[i]
         img = img.reshape([img.shape[0]] + shape + [self.out_channels] + self.patch_size)
-        img = img.permute(0, 4, 1, 5, 2, 6, 3, 7)
-        img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
+        if img.ndim == 8:
+            img = img.permute(0, 4, 1, 5, 2, 6, 3, 7)
+            img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
+        else:
+            img = img.permute(0, 3, 1, 4, 2, 5)
+            img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3])
         return img
 
     def img_ids(self, x):
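
With a 2-element patch size, the reshaped tensor has 6 dimensions instead of 8, so the unpatchify falls through to the new 2-D permute. A standalone sketch of that branch, with illustrative shapes (2x2 patches, a 64x64 latent, 16 output channels):

```python
# Sketch of the 2-D unpatchify path (shapes illustrative).
import torch

bs, out_channels = 1, 16
initial_shape = [bs, out_channels, 64, 64]
patch_size = [2, 2]

h = initial_shape[-2] // patch_size[0]   # 32
w = initial_shape[-1] // patch_size[1]   # 32
tokens = torch.randn(bs, h * w, out_channels * patch_size[0] * patch_size[1])  # final_layer output

img = tokens.reshape([bs, h, w, out_channels] + patch_size)  # (1, 32, 32, 16, 2, 2), ndim == 6
img = img.permute(0, 3, 1, 4, 2, 5)                          # (1, 16, 32, 2, 32, 2)
img = img.reshape(initial_shape[0], out_channels, initial_shape[2], initial_shape[3])
print(img.shape)                                             # torch.Size([1, 16, 64, 64])
```
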
@@ -348,16 +418,30 @@ def img_ids(self, x):
         img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
         return repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
 
-    def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
+    def img_ids_2d(self, x):
+        bs, c, h, w = x.shape
+        patch_size = self.patch_size
+        h_len = ((h + (patch_size[0] // 2)) // patch_size[0])
+        w_len = ((w + (patch_size[1] // 2)) // patch_size[1])
+        img_ids = torch.zeros((h_len, w_len, 2), device=x.device, dtype=x.dtype)
+        img_ids[:, :, 0] = img_ids[:, :, 0] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
+        img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
+        return repeat(img_ids, "h w c -> b (h w) c", b=bs)
+
+    def forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
         return comfy.patcher_extension.WrapperExecutor.new_class_executor(
             self._forward,
             self,
             comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
-        ).execute(x, timestep, context, y, guidance, attention_mask, guiding_frame_index, ref_latent, control, transformer_options, **kwargs)
+        ).execute(x, timestep, context, y, txt_byt5, guidance, attention_mask, guiding_frame_index, ref_latent, control, transformer_options, **kwargs)
 
-    def _forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
-        bs, c, t, h, w = x.shape
-        img_ids = self.img_ids(x)
-        txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
-        out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options)
+    def _forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
+        bs = x.shape[0]
+        if len(self.patch_size) == 3:
+            img_ids = self.img_ids(x)
+            txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
+        else:
+            img_ids = self.img_ids_2d(x)
+            txt_ids = torch.zeros((bs, context.shape[1], 2), device=x.device, dtype=x.dtype)
+        out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options)
         return out
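
img_ids_2d is the 2-axis analogue of the video img_ids: it builds per-token (row, column) coordinates for the rope embedder, and the matching text ids become 2-wide zeros. A standalone sketch with an illustrative 2x2 patch size and a tiny 4x6 latent:

```python
# Standalone sketch of img_ids_2d (illustrative 2x2 patch size and a 4x6 latent).
import torch
from einops import repeat

x = torch.randn(1, 16, 4, 6)
patch_size = [2, 2]
bs, c, h, w = x.shape

h_len = (h + patch_size[0] // 2) // patch_size[0]   # 2
w_len = (w + patch_size[1] // 2) // patch_size[1]   # 3

img_ids = torch.zeros((h_len, w_len, 2))
img_ids[:, :, 0] += torch.linspace(0, h_len - 1, steps=h_len).unsqueeze(1)  # row index per token
img_ids[:, :, 1] += torch.linspace(0, w_len - 1, steps=w_len).unsqueeze(0)  # column index per token
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)                       # (1, 6, 2)
print(img_ids[0])
# rows flatten first: [0,0], [0,1], [0,2], [1,0], [1,1], [1,2]
```
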