Commit 3b54b02

Merge branch 'master' into fix-context-window-slicing
2 parents 2835f7f + 18de0b2 commit 3b54b02

24 files changed: +2630 / -42 lines changed

comfy/latent_formats.py

Lines changed: 73 additions & 0 deletions

@@ -533,6 +533,79 @@ def __init__(self):
             0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744
         ]).view(1, self.latent_channels, 1, 1, 1)
 
+class HunyuanImage21(LatentFormat):
+    latent_channels = 64
+    latent_dimensions = 2
+    scale_factor = 0.75289
+
+    latent_rgb_factors = [
+        [-0.0154, -0.0397, -0.0521],
+        [ 0.0005, 0.0093, 0.0006],
+        [-0.0805, -0.0773, -0.0586],
+        [-0.0494, -0.0487, -0.0498],
+        [-0.0212, -0.0076, -0.0261],
+        [-0.0179, -0.0417, -0.0505],
+        [ 0.0158, 0.0310, 0.0239],
+        [ 0.0409, 0.0516, 0.0201],
+        [ 0.0350, 0.0553, 0.0036],
+        [-0.0447, -0.0327, -0.0479],
+        [-0.0038, -0.0221, -0.0365],
+        [-0.0423, -0.0718, -0.0654],
+        [ 0.0039, 0.0368, 0.0104],
+        [ 0.0655, 0.0217, 0.0122],
+        [ 0.0490, 0.1638, 0.2053],
+        [ 0.0932, 0.0829, 0.0650],
+        [-0.0186, -0.0209, -0.0135],
+        [-0.0080, -0.0076, -0.0148],
+        [-0.0284, -0.0201, 0.0011],
+        [-0.0642, -0.0294, -0.0777],
+        [-0.0035, 0.0076, -0.0140],
+        [ 0.0519, 0.0731, 0.0887],
+        [-0.0102, 0.0095, 0.0704],
+        [ 0.0068, 0.0218, -0.0023],
+        [-0.0726, -0.0486, -0.0519],
+        [ 0.0260, 0.0295, 0.0263],
+        [ 0.0250, 0.0333, 0.0341],
+        [ 0.0168, -0.0120, -0.0174],
+        [ 0.0226, 0.1037, 0.0114],
+        [ 0.2577, 0.1906, 0.1604],
+        [-0.0646, -0.0137, -0.0018],
+        [-0.0112, 0.0309, 0.0358],
+        [-0.0347, 0.0146, -0.0481],
+        [ 0.0234, 0.0179, 0.0201],
+        [ 0.0157, 0.0313, 0.0225],
+        [ 0.0423, 0.0675, 0.0524],
+        [-0.0031, 0.0027, -0.0255],
+        [ 0.0447, 0.0555, 0.0330],
+        [-0.0152, 0.0103, 0.0299],
+        [-0.0755, -0.0489, -0.0635],
+        [ 0.0853, 0.0788, 0.1017],
+        [-0.0272, -0.0294, -0.0471],
+        [ 0.0440, 0.0400, -0.0137],
+        [ 0.0335, 0.0317, -0.0036],
+        [-0.0344, -0.0621, -0.0984],
+        [-0.0127, -0.0630, -0.0620],
+        [-0.0648, 0.0360, 0.0924],
+        [-0.0781, -0.0801, -0.0409],
+        [ 0.0363, 0.0613, 0.0499],
+        [ 0.0238, 0.0034, 0.0041],
+        [-0.0135, 0.0258, 0.0310],
+        [ 0.0614, 0.1086, 0.0589],
+        [ 0.0428, 0.0350, 0.0205],
+        [ 0.0153, 0.0173, -0.0018],
+        [-0.0288, -0.0455, -0.0091],
+        [ 0.0344, 0.0109, -0.0157],
+        [-0.0205, -0.0247, -0.0187],
+        [ 0.0487, 0.0126, 0.0064],
+        [-0.0220, -0.0013, 0.0074],
+        [-0.0203, -0.0094, -0.0048],
+        [-0.0719, 0.0429, -0.0442],
+        [ 0.1042, 0.0497, 0.0356],
+        [-0.0659, -0.0578, -0.0280],
+        [-0.0060, -0.0322, -0.0234]]
+
+    latent_rgb_factors_bias = [0.0007, -0.0256, -0.0206]
+
 class Hunyuan3Dv2(LatentFormat):
     latent_channels = 64
     latent_dimensions = 1
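The 64×3 latent_rgb_factors table plus latent_rgb_factors_bias lets the frontend draw a cheap RGB preview of a HunyuanImage 2.1 latent without decoding through the VAE. A minimal sketch of how such a projection can be applied; the function name and the clamp range are assumptions for illustration, not ComfyUI's actual preview code:

import torch

def hunyuan_image21_preview(latent, rgb_factors, bias):
    # latent: [B, 64, H, W]; rgb_factors: 64 rows of [r, g, b]; bias: [r, g, b]
    weight = torch.tensor(rgb_factors, dtype=latent.dtype, device=latent.device)  # [64, 3]
    offset = torch.tensor(bias, dtype=latent.dtype, device=latent.device)         # [3]
    # project every 64-channel latent pixel onto RGB, then add the per-channel bias
    rgb = torch.einsum("bchw,cr->brhw", latent, weight) + offset.view(1, 3, 1, 1)
    return rgb.clamp(-1.0, 1.0)  # assumed display range

Usage would look like hunyuan_image21_preview(latent, HunyuanImage21.latent_rgb_factors, HunyuanImage21.latent_rgb_factors_bias) on a [B, 64, H, W] latent.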

comfy/ldm/hunyuan3dv2_1/hunyuandit.py

Lines changed: 1 addition & 1 deletion

@@ -426,7 +426,7 @@ def __init__(
         text_states_dim=1024,
         qk_norm=False,
         norm_layer=nn.LayerNorm,
-        qk_norm_layer=nn.RMSNorm,
+        qk_norm_layer=True,
         qkv_bias=True,
         skip_connection=True,
         timested_modulate=False,

comfy/ldm/hunyuan_video/model.py

Lines changed: 100 additions & 16 deletions

@@ -40,6 +40,8 @@ class HunyuanVideoParams:
     patch_size: list
     qkv_bias: bool
     guidance_embed: bool
+    byt5: bool
+    meanflow: bool
 
 
 class SelfAttentionRef(nn.Module):

@@ -161,6 +163,30 @@ def forward(
         x = self.individual_token_refiner(x, c, mask)
         return x
 
+
+class ByT5Mapper(nn.Module):
+    def __init__(self, in_dim, out_dim, hidden_dim, out_dim1, use_res=False, dtype=None, device=None, operations=None):
+        super().__init__()
+        self.layernorm = operations.LayerNorm(in_dim, dtype=dtype, device=device)
+        self.fc1 = operations.Linear(in_dim, hidden_dim, dtype=dtype, device=device)
+        self.fc2 = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device)
+        self.fc3 = operations.Linear(out_dim, out_dim1, dtype=dtype, device=device)
+        self.use_res = use_res
+        self.act_fn = nn.GELU()
+
+    def forward(self, x):
+        if self.use_res:
+            res = x
+        x = self.layernorm(x)
+        x = self.fc1(x)
+        x = self.act_fn(x)
+        x = self.fc2(x)
+        x2 = self.act_fn(x)
+        x2 = self.fc3(x2)
+        if self.use_res:
+            x2 = x2 + res
+        return x2
+
 class HunyuanVideo(nn.Module):
     """
     Transformer model for flow matching on sequences.
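For a quick shape check, the new ByT5Mapper can be exercised on its own by passing torch.nn as the operations namespace (inside ComfyUI this argument is normally a comfy.ops class that handles dtype casting); a sketch under that assumption, with 3072 standing in for the model's hidden_size:

import torch
import torch.nn as nn

# torch.nn supplies the LayerNorm/Linear factories that ByT5Mapper expects from `operations`
mapper = ByT5Mapper(in_dim=1472, out_dim=2048, hidden_dim=2048, out_dim1=3072,
                    use_res=False, dtype=torch.float32, device="cpu", operations=nn)
byt5_features = torch.randn(1, 32, 1472)    # [batch, glyph tokens, ByT5 feature dim]
print(mapper(byt5_features).shape)          # torch.Size([1, 32, 3072])

This mirrors how byt5_in is constructed below: in_dim=1472, hidden_dim=out_dim=2048, and out_dim1 set to the transformer's hidden_size so the mapped tokens can be concatenated with the regular text tokens.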
@@ -185,9 +211,13 @@ def __init__(self, image_model=None, final_layer=True, dtype=None, device=None,
         self.num_heads = params.num_heads
         self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
 
-        self.img_in = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(None, self.patch_size, self.in_channels, self.hidden_size, conv3d=True, dtype=dtype, device=device, operations=operations)
+        self.img_in = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(None, self.patch_size, self.in_channels, self.hidden_size, conv3d=len(self.patch_size) == 3, dtype=dtype, device=device, operations=operations)
         self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
-        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
+        if params.vec_in_dim is not None:
+            self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
+        else:
+            self.vector_in = None
+
         self.guidance_in = (
             MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
         )

@@ -215,6 +245,23 @@ def __init__(self, image_model=None, final_layer=True, dtype=None, device=None,
             ]
         )
 
+        if params.byt5:
+            self.byt5_in = ByT5Mapper(
+                in_dim=1472,
+                out_dim=2048,
+                hidden_dim=2048,
+                out_dim1=self.hidden_size,
+                use_res=False,
+                dtype=dtype, device=device, operations=operations
+            )
+        else:
+            self.byt5_in = None
+
+        if params.meanflow:
+            self.time_r_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
+        else:
+            self.time_r_in = None
+
         if final_layer:
             self.final_layer = LastLayer(self.hidden_size, self.patch_size[-1], self.out_channels, dtype=dtype, device=device, operations=operations)
 

@@ -226,7 +273,8 @@ def forward_orig(
         txt_ids: Tensor,
         txt_mask: Tensor,
         timesteps: Tensor,
-        y: Tensor,
+        y: Tensor = None,
+        txt_byt5=None,
         guidance: Tensor = None,
         guiding_frame_index=None,
         ref_latent=None,

@@ -240,6 +288,14 @@ def forward_orig(
         img = self.img_in(img)
         vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))
 
+        if self.time_r_in is not None:
+            w = torch.where(transformer_options['sigmas'][0] == transformer_options['sample_sigmas'])[0]  # This most likely could be improved
+            if len(w) > 0:
+                timesteps_r = transformer_options['sample_sigmas'][w[0] + 1]
+                timesteps_r = timesteps_r.unsqueeze(0).to(device=timesteps.device, dtype=timesteps.dtype)
+                vec_r = self.time_r_in(timestep_embedding(timesteps_r, 256, time_factor=1000.0).to(img.dtype))
+                vec = (vec + vec_r) / 2
+
         if ref_latent is not None:
            ref_latent_ids = self.img_ids(ref_latent)
            ref_latent = self.img_in(ref_latent)
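The meanflow path embeds not just the current sigma but also the next one in the sampling schedule, then averages the two time embeddings. A rough sketch of the schedule lookup it performs, with made-up sigma values and plain tensors standing in for the transformer_options entries:

import torch

sample_sigmas = torch.tensor([1.0, 0.75, 0.5, 0.25, 0.0])  # full sampler schedule
sigmas = torch.tensor([0.75])                               # sigma for the current step

w = torch.where(sigmas[0] == sample_sigmas)[0]  # index of the current sigma in the schedule
if len(w) > 0:
    timesteps_r = sample_sigmas[w[0] + 1]        # -> tensor(0.5000), the next step's sigma
    # the model then builds vec_r = time_r_in(timestep_embedding(timesteps_r, 256, time_factor=1000.0))
    # and uses vec = (vec + vec_r) / 2

As the inline comment in the diff notes, this lookup "most likely could be improved".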
@@ -250,13 +306,17 @@ def forward_orig(
 
         if guiding_frame_index is not None:
             token_replace_vec = self.time_in(timestep_embedding(guiding_frame_index, 256, time_factor=1.0))
-            vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
-            vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1)
+            if self.vector_in is not None:
+                vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
+                vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1)
+            else:
+                vec = torch.cat([(token_replace_vec).unsqueeze(1), (vec).unsqueeze(1)], dim=1)
             frame_tokens = (initial_shape[-1] // self.patch_size[-1]) * (initial_shape[-2] // self.patch_size[-2])
             modulation_dims = [(0, frame_tokens, 0), (frame_tokens, None, 1)]
             modulation_dims_txt = [(0, None, 1)]
         else:
-            vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
+            if self.vector_in is not None:
+                vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
             modulation_dims = None
             modulation_dims_txt = None
 

@@ -269,6 +329,12 @@ def forward_orig(
 
         txt = self.txt_in(txt, timesteps, txt_mask)
 
+        if self.byt5_in is not None and txt_byt5 is not None:
+            txt_byt5 = self.byt5_in(txt_byt5)
+            txt_byt5_ids = torch.zeros((txt_ids.shape[0], txt_byt5.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
+            txt = torch.cat((txt, txt_byt5), dim=1)
+            txt_ids = torch.cat((txt_ids, txt_byt5_ids), dim=1)
+
         ids = torch.cat((img_ids, txt_ids), dim=1)
         pe = self.pe_embedder(ids)
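When a byt5_in mapper exists and txt_byt5 is provided, the mapped glyph tokens are appended to the text sequence with all-zero position ids, matching the zero txt_ids already used for the regular text tokens. A shape-only sketch with hypothetical sizes:

import torch

hidden_size = 3072                           # assumed transformer width
txt = torch.randn(1, 256, hidden_size)       # text tokens after txt_in
txt_ids = torch.zeros(1, 256, 3)             # zero position ids (3-axis video case shown)
txt_byt5 = torch.randn(1, 32, hidden_size)   # glyph tokens after byt5_in

txt_byt5_ids = torch.zeros((txt_ids.shape[0], txt_byt5.shape[1], txt_ids.shape[-1]))
txt = torch.cat((txt, txt_byt5), dim=1)              # [1, 288, 3072]
txt_ids = torch.cat((txt_ids, txt_byt5_ids), dim=1)  # [1, 288, 3]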
@@ -328,12 +394,16 @@ def block_wrap(args):
 
         img = self.final_layer(img, vec, modulation_dims=modulation_dims)  # (N, T, patch_size ** 2 * out_channels)
 
-        shape = initial_shape[-3:]
+        shape = initial_shape[-len(self.patch_size):]
         for i in range(len(shape)):
             shape[i] = shape[i] // self.patch_size[i]
         img = img.reshape([img.shape[0]] + shape + [self.out_channels] + self.patch_size)
-        img = img.permute(0, 4, 1, 5, 2, 6, 3, 7)
-        img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
+        if img.ndim == 8:
+            img = img.permute(0, 4, 1, 5, 2, 6, 3, 7)
+            img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
+        else:
+            img = img.permute(0, 3, 1, 4, 2, 5)
+            img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3])
         return img
 
     def img_ids(self, x):
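With a 2-D patch size the tensor entering the unpatchify step has 6 dimensions instead of 8, which is what the new img.ndim branch distinguishes. A small sanity sketch of the 2-D permute/reshape with made-up sizes (a 2×2 patch size is only an example):

import torch

B, C_out, H, W = 1, 64, 8, 8
p0, p1 = 2, 2
tokens = torch.randn(B, H // p0, W // p1, C_out, p0, p1)  # [B, H', W', C, p0, p1], ndim == 6

img = tokens.permute(0, 3, 1, 4, 2, 5)                     # [B, C, H', p0, W', p1]
img = img.reshape(B, C_out, H, W)
print(img.shape)                                            # torch.Size([1, 64, 8, 8])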
@@ -348,16 +418,30 @@ def img_ids(self, x):
         img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
         return repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
 
-    def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
+    def img_ids_2d(self, x):
+        bs, c, h, w = x.shape
+        patch_size = self.patch_size
+        h_len = ((h + (patch_size[0] // 2)) // patch_size[0])
+        w_len = ((w + (patch_size[1] // 2)) // patch_size[1])
+        img_ids = torch.zeros((h_len, w_len, 2), device=x.device, dtype=x.dtype)
+        img_ids[:, :, 0] = img_ids[:, :, 0] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
+        img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
+        return repeat(img_ids, "h w c -> b (h w) c", b=bs)
+
+    def forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
         return comfy.patcher_extension.WrapperExecutor.new_class_executor(
             self._forward,
             self,
             comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
-        ).execute(x, timestep, context, y, guidance, attention_mask, guiding_frame_index, ref_latent, control, transformer_options, **kwargs)
+        ).execute(x, timestep, context, y, txt_byt5, guidance, attention_mask, guiding_frame_index, ref_latent, control, transformer_options, **kwargs)
 
-    def _forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
-        bs, c, t, h, w = x.shape
-        img_ids = self.img_ids(x)
-        txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
-        out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options)
+    def _forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
+        bs = x.shape[0]
+        if len(self.patch_size) == 3:
+            img_ids = self.img_ids(x)
+            txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
+        else:
+            img_ids = self.img_ids_2d(x)
+            txt_ids = torch.zeros((bs, context.shape[1], 2), device=x.device, dtype=x.dtype)
+        out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options)
         return out
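The new img_ids_2d builds a flattened grid of (row, column) coordinates, one pair per image patch, analogous to the (t, h, w) triples img_ids produces for video. The same construction on a tiny 2×3 patch grid:

import torch

h_len, w_len = 2, 3
img_ids = torch.zeros((h_len, w_len, 2))
img_ids[:, :, 0] += torch.linspace(0, h_len - 1, steps=h_len).unsqueeze(1)  # row index
img_ids[:, :, 1] += torch.linspace(0, w_len - 1, steps=w_len).unsqueeze(0)  # column index
print(img_ids.reshape(-1, 2))
# -> [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]], one (row, column) pair per patch

In the model the result is repeated across the batch and concatenated with txt_ids before being passed to pe_embedder in forward_orig.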
