2 changes: 1 addition & 1 deletion comfy/ldm/wan/model_animate.py
@@ -451,7 +451,7 @@ def __init__(self,
def after_patch_embedding(self, x, pose_latents, face_pixel_values):
if pose_latents is not None:
pose_latents = self.pose_patch_embedding(pose_latents)
x[:, :, 1:] += pose_latents
x[:, :, 1:pose_latents.shape[2] + 1] += pose_latents[:, :, :x.shape[2] - 1]

if face_pixel_values is None:
return x, None
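
Note on the model_animate.py change: the old line x[:, :, 1:] += pose_latents assumed the pose latent spans every frame after the leading reference frame, so a pose latent with a different temporal length raised a shape mismatch. The new indexing clips both tensors to their overlapping temporal range. A minimal sketch of the equivalent logic, using hypothetical shapes that are not taken from the PR:

import torch

# x: patch-embedded video latent [batch, dim, frames, h, w]; the pose signal is added from frame 1 on.
# pose: pose latent with a shorter temporal length (hypothetical sizes).
x = torch.zeros(1, 8, 21, 4, 4)
pose = torch.zeros(1, 8, 13, 4, 4)

# Equivalent to the new slicing: only the overlapping frames are added.
frames = min(pose.shape[2], x.shape[2] - 1)
x[:, :, 1:frames + 1] += pose[:, :, :frames]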
59 changes: 44 additions & 15 deletions comfy_extras/nodes_wan.py
@@ -1128,18 +1128,22 @@ def define_schema(cls):
io.Image.Input("pose_video", optional=True),
io.Int.Input("continue_motion_max_frames", default=5, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Image.Input("continue_motion", optional=True),
io.Int.Input("video_frame_offset", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1, tooltip="The amount of frames to seek in all the input videos. Used for generating longer videos by chunk. Connect to the video_frame_offset output of the previous node for extending a video."),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
io.Int.Output(display_name="trim_latent"),
io.Int.Output(display_name="trim_image"),
io.Int.Output(display_name="video_frame_offset"),
],
is_experimental=True,
)

@classmethod
def execute(cls, positive, negative, vae, width, height, length, batch_size, continue_motion_max_frames, reference_image=None, clip_vision_output=None, face_video=None, pose_video=None, continue_motion=None) -> io.NodeOutput:
def execute(cls, positive, negative, vae, width, height, length, batch_size, continue_motion_max_frames, video_frame_offset, reference_image=None, clip_vision_output=None, face_video=None, pose_video=None, continue_motion=None) -> io.NodeOutput:
trim_to_pose_video = False
latent_length = ((length - 1) // 4) + 1
latent_width = width // 8
latent_height = height // 8
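
The new video_frame_offset input/output is what makes chunked generation stackable: each node seeks that many frames into the pose and face control videos, and its output feeds the next node's input. A rough sketch of the bookkeeping implied by the new execute(), with illustrative values and a made-up helper name (trim_to_pose_video stays hardcoded to False in this version):

def next_chunk_offset(video_frame_offset, length, continue_motion_frames):
    # Frames reused via continue_motion were already generated, so the node
    # seeks that many frames earlier in the pose/face control videos.
    effective_offset = max(0, video_frame_offset - continue_motion_frames)
    # The reported offset is where the next chunk should start reading.
    return effective_offset + length

print(next_chunk_offset(77, 77, 5))  # -> 149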
@@ -1152,35 +1156,60 @@ def execute(cls, positive, negative, vae, width, height, length, batch_size, con
concat_latent_image = vae.encode(image[:, :, :, :3])
mask = torch.zeros((1, 1, concat_latent_image.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=concat_latent_image.device, dtype=concat_latent_image.dtype)
trim_latent += concat_latent_image.shape[2]
ref_motion_latent_length = 0

if continue_motion is None:
image = torch.ones((length, height, width, 3)) * 0.5
else:
continue_motion = continue_motion[-continue_motion_max_frames:]
video_frame_offset -= continue_motion.shape[0]
video_frame_offset = max(0, video_frame_offset)
continue_motion = comfy.utils.common_upscale(continue_motion[-length:].movedim(-1, 1), width, height, "area", "center").movedim(1, -1)
image = torch.ones((length, height, width, continue_motion.shape[-1]), device=continue_motion.device, dtype=continue_motion.dtype) * 0.5
image[:continue_motion.shape[0]] = continue_motion
ref_motion_latent_length += ((continue_motion.shape[0] - 1) // 4) + 1

if clip_vision_output is not None:
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})

if face_video is not None:
face_video = comfy.utils.common_upscale(face_video[:length].movedim(-1, 1), 512, 512, "area", "center") * 2.0 - 1.0
face_video = face_video.movedim(0, 1).unsqueeze(0)
positive = node_helpers.conditioning_set_values(positive, {"face_video_pixels": face_video})
negative = node_helpers.conditioning_set_values(negative, {"face_video_pixels": face_video * 0.0 - 1.0})
if pose_video is not None:
if pose_video.shape[0] <= video_frame_offset:
pose_video = None
else:
pose_video = pose_video[video_frame_offset:]

if pose_video is not None:
pose_video = comfy.utils.common_upscale(pose_video[:length].movedim(-1, 1), width, height, "area", "center").movedim(1, -1)
if not trim_to_pose_video:
if pose_video.shape[0] < length:
pose_video = torch.cat((pose_video,) + (pose_video[-1:],) * (length - pose_video.shape[0]), dim=0)

pose_video_latent = vae.encode(pose_video[:, :, :, :3])
positive = node_helpers.conditioning_set_values(positive, {"pose_video_latent": pose_video_latent})
negative = node_helpers.conditioning_set_values(negative, {"pose_video_latent": pose_video_latent})

if continue_motion is None:
image = torch.ones((length, height, width, 3)) * 0.5
else:
continue_motion = continue_motion[-continue_motion_max_frames:]
continue_motion = comfy.utils.common_upscale(continue_motion[-length:].movedim(-1, 1), width, height, "area", "center").movedim(1, -1)
image = torch.ones((length, height, width, continue_motion.shape[-1]), device=continue_motion.device, dtype=continue_motion.dtype) * 0.5
image[:continue_motion.shape[0]] = continue_motion
if trim_to_pose_video:
latent_length = pose_video_latent.shape[2]
length = latent_length * 4 - 3
image = image[:length]

if face_video is not None:
if face_video.shape[0] <= video_frame_offset:
face_video = None
else:
face_video = face_video[video_frame_offset:]

if face_video is not None:
face_video = comfy.utils.common_upscale(face_video[:length].movedim(-1, 1), 512, 512, "area", "center") * 2.0 - 1.0
face_video = face_video.movedim(0, 1).unsqueeze(0)
positive = node_helpers.conditioning_set_values(positive, {"face_video_pixels": face_video})
negative = node_helpers.conditioning_set_values(negative, {"face_video_pixels": face_video * 0.0 - 1.0})

concat_latent_image = torch.cat((concat_latent_image, vae.encode(image[:, :, :, :3])), dim=2)
mask_refmotion = torch.ones((1, 1, latent_length, concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=mask.device, dtype=mask.dtype)
if continue_motion is not None:
mask_refmotion[:, :, :((continue_motion.shape[0] - 1) // 4) + 1] = 0.0
mask_refmotion[:, :, :ref_motion_latent_length] = 0.0

mask = torch.cat((mask, mask_refmotion), dim=2)
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
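
Both control videos now go through the same seek-or-drop step before being resized: skip the frames that earlier chunks already consumed, and drop the video entirely once the offset runs past its end. A small sketch of that pattern, with hypothetical tensor sizes not taken from the PR:

import torch

def seek_video(video, video_frame_offset):
    # Drop the control video once the offset runs past its last frame,
    # otherwise skip the frames earlier chunks already consumed.
    if video is None or video.shape[0] <= video_frame_offset:
        return None
    return video[video_frame_offset:]

pose_video = torch.rand(60, 8, 8, 3)      # [frames, height, width, channels]
print(seek_video(pose_video, 40).shape)   # torch.Size([20, 8, 8, 3])
print(seek_video(pose_video, 60))         # None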
@@ -1189,7 +1218,7 @@ def execute(cls, positive, negative, vae, width, height, length, batch_size, con
latent = torch.zeros([batch_size, 16, latent_length + trim_latent, latent_height, latent_width], device=comfy.model_management.intermediate_device())
out_latent = {}
out_latent["samples"] = latent
return io.NodeOutput(positive, negative, out_latent, trim_latent)
return io.NodeOutput(positive, negative, out_latent, trim_latent, max(0, ref_motion_latent_length * 4 - 3), video_frame_offset + length)

class Wan22ImageToVideoLatent(io.ComfyNode):
@classmethod
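
The two extra outputs close the loop for chunked generation: trim_image reports how many leading pixel frames of the decoded chunk correspond to the reused continue_motion latents, and video_frame_offset + length is the seek position for the next chunk. A sketch of the frame arithmetic (the 4n - 3 mapping between latent and pixel frames comes from the node itself; the helper name is made up):

def trim_image_frames(continue_motion_frames):
    # n pixel frames land in ((n - 1) // 4) + 1 latent frames, and m latent
    # frames decode back to 4 * m - 3 pixel frames, as in the node above.
    ref_motion_latent_length = ((continue_motion_frames - 1) // 4) + 1
    return max(0, ref_motion_latent_length * 4 - 3)

print(trim_image_frames(5))   # -> 5: the reused motion spans 2 latent frames
print(trim_image_frames(17))  # -> 17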