Commit e01e99d

Support hunyuan image distilled model. (#9807)
1 parent 72212fe commit e01e99d

2 files changed (+24, -2 lines)

comfy/ldm/hunyuan_video/model.py

Lines changed: 14 additions & 0 deletions
@@ -41,6 +41,7 @@ class HunyuanVideoParams:
    qkv_bias: bool
    guidance_embed: bool
    byt5: bool
+    meanflow: bool


class SelfAttentionRef(nn.Module):
@@ -256,6 +257,11 @@ def __init__(self, image_model=None, final_layer=True, dtype=None, device=None,
        else:
            self.byt5_in = None

+        if params.meanflow:
+            self.time_r_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
+        else:
+            self.time_r_in = None
+
        if final_layer:
            self.final_layer = LastLayer(self.hidden_size, self.patch_size[-1], self.out_channels, dtype=dtype, device=device, operations=operations)
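For context, time_r_in mirrors the existing time_in embedder so the second timestep r gets its own projection into the model width. A minimal sketch of the assumed layout of MLPEmbedder (Linear -> SiLU -> Linear, per the flux-style embedders ComfyUI uses; not part of this diff):

import torch
import torch.nn as nn

class MLPEmbedderSketch(nn.Module):
    # Assumed layout of the MLPEmbedder referenced above: Linear -> SiLU -> Linear.
    def __init__(self, in_dim: int, hidden_dim: int):
        super().__init__()
        self.in_layer = nn.Linear(in_dim, hidden_dim)
        self.silu = nn.SiLU()
        self.out_layer = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.out_layer(self.silu(self.in_layer(x)))

# Example: embed a 256-dim sinusoidal timestep vector into the hidden size
# (3072 for HunyuanVideo).
emb = MLPEmbedderSketch(in_dim=256, hidden_dim=3072)
print(emb(torch.randn(1, 256)).shape)  # torch.Size([1, 3072])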
@@ -282,6 +288,14 @@ def forward_orig(
        img = self.img_in(img)
        vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))

+        if self.time_r_in is not None:
+            w = torch.where(transformer_options['sigmas'][0] == transformer_options['sample_sigmas'])[0] # This most likely could be improved
+            if len(w) > 0:
+                timesteps_r = transformer_options['sample_sigmas'][w[0] + 1]
+                timesteps_r = timesteps_r.unsqueeze(0).to(device=timesteps.device, dtype=timesteps.dtype)
+                vec_r = self.time_r_in(timestep_embedding(timesteps_r, 256, time_factor=1000.0).to(img.dtype))
+                vec = (vec + vec_r) / 2
+
        if ref_latent is not None:
            ref_latent_ids = self.img_ids(ref_latent)
            ref_latent = self.img_in(ref_latent)
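The forward_orig addition looks up the current sigma in the sampler's full schedule, embeds the next sigma as a second timestep r, and averages the two embeddings, consistent with the two-timestep (t, r) conditioning of MeanFlow-style distilled models that the new meanflow flag suggests. A standalone sketch of just the lookup, with a toy schedule (hypothetical values; the real tensors arrive via transformer_options):

import torch

# Toy sigma schedule as a sampler would provide it (hypothetical values).
sample_sigmas = torch.tensor([1.0, 0.75, 0.5, 0.25, 0.0])
sigmas = torch.tensor([0.75])  # sigma for the current model call

# Same lookup as the diff: find where the current sigma sits in the schedule.
w = torch.where(sigmas[0] == sample_sigmas)[0]
if len(w) > 0:
    # The *next* entry in the schedule becomes the second timestep r.
    timesteps_r = sample_sigmas[w[0] + 1].unsqueeze(0)
    print(timesteps_r)  # tensor([0.5000])

If the current sigma is not found in the schedule (len(w) == 0), vec is left as the plain single-timestep embedding, so ordinary sampling paths are unaffected.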

comfy/model_detection.py

Lines changed: 10 additions & 2 deletions
@@ -142,12 +142,20 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
        dit_config["in_channels"] = in_w.shape[1] #SkyReels img2video has 32 input channels
        dit_config["patch_size"] = list(in_w.shape[2:])
        dit_config["out_channels"] = out_w.shape[0] // math.prod(dit_config["patch_size"])
-        if '{}vector_in.in_layer.weight'.format(key_prefix) in state_dict:
+        if any(s.startswith('{}vector_in.'.format(key_prefix)) for s in state_dict_keys):
            dit_config["vec_in_dim"] = 768
-            dit_config["axes_dim"] = [16, 56, 56]
        else:
            dit_config["vec_in_dim"] = None
+
+        if len(dit_config["patch_size"]) == 2:
            dit_config["axes_dim"] = [64, 64]
+        else:
+            dit_config["axes_dim"] = [16, 56, 56]
+
+        if any(s.startswith('{}time_r_in.'.format(key_prefix)) for s in state_dict_keys):
+            dit_config["meanflow"] = True
+        else:
+            dit_config["meanflow"] = False

        dit_config["context_in_dim"] = state_dict['{}txt_in.input_embedder.weight'.format(key_prefix)].shape[1]
        dit_config["hidden_size"] = in_w.shape[0]
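The detection change scans key prefixes instead of requiring one exact key (presumably so checkpoints whose weight tensors are renamed, e.g. quantized variants, still match), derives axes_dim from the patch dimensionality (2D patches mean the image model), and flags meanflow whenever a time_r_in module is present. A condensed sketch of those three branches on a hypothetical state dict (key names invented for illustration):

# Hypothetical keys for a hunyuan image distilled checkpoint.
key_prefix = 'model.diffusion_model.'
state_dict_keys = [
    'model.diffusion_model.img_in.proj.weight',
    'model.diffusion_model.time_r_in.in_layer.weight',  # present -> meanflow
]

dit_config = {"patch_size": [2, 2]}  # 2D patches -> image variant

# No 'vector_in.' keys -> no 768-dim pooled text vector input.
if any(s.startswith('{}vector_in.'.format(key_prefix)) for s in state_dict_keys):
    dit_config["vec_in_dim"] = 768
else:
    dit_config["vec_in_dim"] = None

# RoPE axes: two axes for images, three (t, h, w) for video.
dit_config["axes_dim"] = [64, 64] if len(dit_config["patch_size"]) == 2 else [16, 56, 56]

# A time_r_in embedder in the checkpoint marks a meanflow (distilled) model.
dit_config["meanflow"] = any(s.startswith('{}time_r_in.'.format(key_prefix)) for s in state_dict_keys)

print(dit_config)
# {'patch_size': [2, 2], 'vec_in_dim': None, 'axes_dim': [64, 64], 'meanflow': True}

The sketch condenses the diff's if/else assignments into single expressions; the detected behavior is the same.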
