115 changes: 114 additions & 1 deletion comfy/model_base.py
@@ -1073,6 +1073,112 @@ def extra_conds(self, **kwargs):
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out

def ind_sel(target: torch.Tensor, ind: torch.Tensor, dim: int = 1):
"""Index selection utility function"""
assert (
len(ind.shape) > dim
), "Index must have the target dim, but got dim: %d, ind shape: %s" % (dim, str(ind.shape))

target = target.expand(
*tuple(
[ind.shape[k] if target.shape[k] == 1 else -1 for k in range(dim)]
+ [
-1,
]
* (len(target.shape) - dim)
)
)

ind_pad = ind

if len(target.shape) > dim + 1:
for _ in range(len(target.shape) - (dim + 1)):
ind_pad = ind_pad.unsqueeze(-1)
ind_pad = ind_pad.expand(*(-1,) * (dim + 1), *target.shape[(dim + 1) : :])

return torch.gather(target, dim=dim, index=ind_pad)
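
Reviewer note (not part of the diff): a minimal sketch of what ind_sel computes, assuming the function above is in scope. It broadcasts the leading dims of target against ind and then gathers along dim, so a (1, V, C) table indexed with a (B, K) index tensor yields (B, K, C):

import torch

table = torch.arange(15.0).reshape(1, 5, 3)          # (1, V=5, C=3), shared across batches
rows = torch.tensor([[0, 4, 2, 2], [1, 1, 3, 0]])    # (B=2, K=4) row indices per batch
picked = ind_sel(table, rows, dim=1)                 # -> (2, 4, 3)
assert torch.equal(picked[0, 1], table[0, 4])        # batch 0, second pick is row 4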


def merge_final(vert_attr: torch.Tensor, weight: torch.Tensor, vert_assign: torch.Tensor):
"""Merge vertex attributes with weights"""
target_dim = len(vert_assign.shape) - 1
if len(vert_attr.shape) == 2:
assert vert_attr.shape[0] > vert_assign.max()
new_shape = [1] * target_dim + list(vert_attr.shape)
tensor = vert_attr.reshape(new_shape)
sel_attr = ind_sel(tensor, vert_assign.type(torch.long), dim=target_dim)
else:
assert vert_attr.shape[1] > vert_assign.max()
new_shape = [vert_attr.shape[0]] + [1] * (target_dim - 1) + list(vert_attr.shape[1:])
tensor = vert_attr.reshape(new_shape)
sel_attr = ind_sel(tensor, vert_assign.type(torch.long), dim=target_dim)

final_attr = torch.sum(sel_attr * weight.unsqueeze(-1), dim=-2)
return final_attr
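
Reviewer note (not part of the diff): merge_final blends per-point attributes into a dense grid using the top-k indices and weights that patch_motion produces below. A shape-only sketch with made-up sizes:

import torch

V, C, T, H, W, K = 6, 16, 4, 8, 8, 2
point_feature = torch.randn(V, C)                    # one feature vector per tracked point
vert_weight = torch.rand(T, H, W, K)                 # soft weights of the K selected points
vert_index = torch.randint(0, V, (T, H, W, K))       # which point each weight refers to
blended = merge_final(point_feature, vert_weight, vert_index)
assert blended.shape == (T, H, W, C)                 # weighted sum over the K selected points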


def patch_motion(
tracks: torch.FloatTensor, # (B, T, N, 4)
vid: torch.FloatTensor, # (C, T, H, W)
temperature: float = 220.0,
vae_divide: tuple = (4, 16),
topk: int = 2,
):
"""Apply motion patching based on tracks"""
with torch.no_grad():
Owner: You can remove torch.no_grad

_, T, H, W = vid.shape
N = tracks.shape[2]
_, tracks_xy, visible = torch.split(
tracks, [1, 2, 1], dim=-1
) # (B, T, N, 2) | (B, T, N, 1)
tracks_n = tracks_xy / torch.tensor([W / min(H, W), H / min(H, W)], device=tracks_xy.device)
tracks_n = tracks_n.clamp(-1, 1)
visible = visible.clamp(0, 1)

xx = torch.linspace(-W / min(H, W), W / min(H, W), W)
yy = torch.linspace(-H / min(H, W), H / min(H, W), H)

grid = torch.stack(torch.meshgrid(yy, xx, indexing="ij")[::-1], dim=-1).to(
tracks_xy.device
)

tracks_pad = tracks_xy[:, 1:]
visible_pad = visible[:, 1:]

visible_align = visible_pad.view(T - 1, 4, *visible_pad.shape[2:]).sum(1)
tracks_align = (tracks_pad * visible_pad).view(T - 1, 4, *tracks_pad.shape[2:]).sum(
1
) / (visible_align + 1e-5)
dist_ = (
(tracks_align[:, None, None] - grid[None, :, :, None]).pow(2).sum(-1)
) # T, H, W, N
weight = torch.exp(-dist_ * temperature) * visible_align.clamp(0, 1).view(
T - 1, 1, 1, N
)
vert_weight, vert_index = torch.topk(
weight, k=min(topk, weight.shape[-1]), dim=-1
)

grid_mode = "bilinear"
point_feature = torch.nn.functional.grid_sample(
vid[vae_divide[0]:].permute(1, 0, 2, 3)[:1],
tracks_n[:, :1].type(vid.dtype),
mode=grid_mode,
padding_mode="zeros",
align_corners=False,
)
point_feature = point_feature.squeeze(0).squeeze(1).permute(1, 0) # N, C=16

out_feature = merge_final(point_feature, vert_weight, vert_index).permute(3, 0, 1, 2) # T - 1, H, W, C => C, T - 1, H, W
out_weight = vert_weight.sum(-1) # T - 1, H, W

# out feature -> already soft weighted
mix_feature = out_feature + vid[vae_divide[0]:, 1:] * (1 - out_weight.clamp(0, 1))

out_feature_full = torch.cat([vid[vae_divide[0]:, :1], mix_feature], dim=1) # C, T, H, W
out_mask_full = torch.cat([torch.ones_like(out_weight[:1]), out_weight], dim=0) # T, H, W
return torch.cat([out_mask_full[None].expand(vae_divide[0], -1, -1, -1), out_feature_full], dim=0)
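
Reviewer note (not part of the diff): patch_motion takes pixel-space tracks with a batch of 1 (as produced by WanTrackToVideo below) whose time axis is 4 * (T - 1) + 1, where T is the latent frame count of vid, together with the 20-channel mask+latent tensor built in concat_cond, and returns a tensor of the same shape as vid. A rough shape check with dummy data:

import torch

lat_t, lat_h, lat_w, n_pts = 21, 60, 104, 3              # e.g. 81 pixel frames at 480x832
vid = torch.randn(4 + 16, lat_t, lat_h, lat_w)           # (C, T, H, W): 4 mask + 16 latent channels
tracks = torch.randn(1, 4 * (lat_t - 1) + 1, n_pts, 4)   # (B=1, 81, N, [t, x, y, visible])
patched = patch_motion(tracks, vid, temperature=220.0, vae_divide=(4, 16), topk=2)
assert patched.shape == vid.shape                         # mask channels + motion-blended latents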

class WAN21(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
@@ -1117,7 +1223,14 @@ def concat_cond(self, **kwargs):
mask = mask.repeat(1, 4, 1, 1, 1)
mask = utils.resize_to_batch_size(mask, noise.shape[0])

return torch.cat((mask, image), dim=1)
res = torch.cat((mask, image), dim=1)

# "tracks", "ati_temperature" and "ati_topk" are set on the conditioning by WanTrackToVideo (comfy_extras/nodes_wan.py)
tracks = kwargs.get("tracks", None)
if tracks is not None:
res = patch_motion(tracks.to(device), res[0], kwargs.get("ati_temperature", 220.0), (4, 16), kwargs.get("ati_topk", 2))[None]

return res



def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
215 changes: 214 additions & 1 deletion comfy_extras/nodes_wan.py
@@ -1,11 +1,14 @@
import math
import nodes
import node_helpers
import torch
import comfy.model_management
import comfy.utils
import comfy.latent_formats
import comfy.clip_vision

import json
import numpy as np
from typing import Tuple

class WanImageToVideo:
@classmethod
@@ -383,7 +386,217 @@ def encode(self, positive, negative, vae, width, height, length, batch_size, ima
out_latent["samples"] = latent
return (positive, cond2, negative, out_latent)

def parse_json_tracks(tracks):
"""Parse JSON track data into a standardized format"""
tracks_data = []
try:
# If tracks is a string, try to parse it as JSON
if isinstance(tracks, str):
parsed = json.loads(tracks.replace("'", '"'))
tracks_data.extend(parsed)
else:
# If tracks is a list of strings, parse each one
for track_str in tracks:
parsed = json.loads(track_str.replace("'", '"'))
tracks_data.append(parsed)

# Check if we have a single track (dict with x,y) or a list of tracks
if tracks_data and isinstance(tracks_data[0], dict) and 'x' in tracks_data[0]:
# Single track detected, wrap it in a list
tracks_data = [tracks_data]
elif tracks_data and isinstance(tracks_data[0], list) and tracks_data[0] and isinstance(tracks_data[0][0], dict) and 'x' in tracks_data[0][0]:
# Already a list of tracks, nothing to do
pass
else:
# Unexpected format
pass

except json.JSONDecodeError:
tracks_data = []
return tracks_data
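
Reviewer note (not part of the diff): the two input shapes this parser normalizes, with made-up coordinates. A single track (a JSON list of {x, y} points) is wrapped so the result is always a list of tracks:

single_track = '[{"x": 100, "y": 200}, {"x": 105, "y": 198}]'
multi_track = '[[{"x": 100, "y": 200}], [{"x": 50, "y": 60}]]'
assert parse_json_tracks(single_track) == [[{"x": 100, "y": 200}, {"x": 105, "y": 198}]]
assert len(parse_json_tracks(multi_track)) == 2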

def tracks_to_tensor(tracks_data, length, width, height, batch_size=1):
"""Convert parsed track data to tensor format (B, T, N, 4)"""
if not tracks_data:
# Return empty tracks if no data
return torch.zeros((batch_size, length, 1, 4))

num_tracks = len(tracks_data)
tracks_tensor = torch.zeros((batch_size, length, num_tracks, 4))

for batch_idx in range(batch_size):
for track_idx, track in enumerate(tracks_data):
for frame_idx in range(min(length, len(track))):
point = track[frame_idx]
if isinstance(point, dict):
x = point.get('x', 0)
y = point.get('y', 0)
# Normalize coordinates to [-1, 1] range
x_norm = (x / width) * 2 - 1
y_norm = (y / height) * 2 - 1
visible = point.get('visible', 1)

tracks_tensor[batch_idx, frame_idx, track_idx] = torch.tensor([
track_idx, # track_id
x_norm, # x coordinate
y_norm, # y coordinate
visible # visibility
])

return tracks_tensor

def process_tracks(tracks_np: np.ndarray, frame_size: Tuple[int, int], num_frames, quant_multi: int = 8, **kwargs):
# tracks_np: (num_tracks, 121, 1, 3) or (121, num_tracks, 1, 3), last dim (x, y, visible); samples at 24 fps, model trained at 16 fps
# frame_size: (width, height) in pixels
tracks = torch.from_numpy(tracks_np).float()

if tracks.shape[1] == 121:
tracks = torch.permute(tracks, (1, 0, 2, 3))

tracks, visibles = tracks[..., :2], tracks[..., 2:3]

short_edge = min(*frame_size)

frame_center = torch.tensor([*frame_size]).type_as(tracks) / 2
tracks = tracks - frame_center

tracks = tracks / short_edge * 2

visibles = visibles * 2 - 1

trange = torch.linspace(-1, 1, tracks.shape[0]).view(-1, 1, 1, 1).expand(*visibles.shape)

out_ = torch.cat([trange, tracks, visibles], dim=-1).view(121, -1, 4)

out_0 = out_[:1]

out_l = out_[1:] # 121 => 120 | 1
a = 120 // math.gcd(120, num_frames)
b = num_frames // math.gcd(120, num_frames)
out_l = torch.repeat_interleave(out_l, b, dim=0)[1::a] # 120 => 120 * b => 120 * b / a == F

final_result = torch.cat([out_0, out_l], dim=0)

return final_result

FIXED_LENGTH = 121
def pad_pts(tr):
"""Convert list of {x,y} to (FIXED_LENGTH,1,3) array, padding/truncating."""
pts = np.array([[p['x'], p['y'], 1] for p in tr], dtype=np.float32)
n = pts.shape[0]
if n < FIXED_LENGTH:
pad = np.zeros((FIXED_LENGTH - n, 3), dtype=np.float32)
pts = np.vstack((pts, pad))
else:
pts = pts[:FIXED_LENGTH]
return pts.reshape(FIXED_LENGTH, 1, 3)
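
Reviewer note (not part of the diff): how pad_pts and process_tracks fit together in encode below, with dummy points. For the default length=81, num_frames = length - 1 = 80: gcd(120, 80) = 40, so a = 3 and b = 2; repeat_interleave expands the 120 trailing samples to 240, [1::3] keeps exactly 80, and the untouched first sample brings the total back to 81 track points:

import numpy as np

track = [{"x": 416.0, "y": 240.0}] * 30                # a 30-point track, padded to 121 by pad_pts
tracks_np = np.stack([pad_pts(track)], axis=0)         # (1, 121, 1, 3)
out = process_tracks(tracks_np, (832, 480), num_frames=80)
assert out.shape == (81, 1, 4)                         # (frames, tracks, [t, x, y, visible])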

class WanTrackToVideo:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE", ),
"tracks": ("STRING", {"multiline": True, "default": "[]"}),
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
"temperature": ("FLOAT", {"default": 220.0, "min": 1.0, "max": 1000.0, "step": 0.1}),
"topk": ("INT", {"default": 2, "min": 1, "max": 10}),
"start_image": ("IMAGE", ),
},
"optional": {
"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
}}

RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"

CATEGORY = "conditioning/video_models"

def encode(self, positive, negative, vae, tracks, width, height, length, batch_size,
temperature, topk, start_image=None, clip_vision_output=None):

# Parse tracks from JSON
tracks_data = parse_json_tracks(tracks)

if not tracks_data:
return WanImageToVideo().encode(positive, negative, vae, width, height, length, batch_size, start_image=start_image, clip_vision_output=clip_vision_output)

latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8],
device=comfy.model_management.intermediate_device())
# Convert tracks to tensor format
arrs = []
for i, track in enumerate(tracks_data):
pts = pad_pts(track)
arrs.append(pts)

tracks_np = np.stack(arrs, axis=0)
processed_tracks = process_tracks(tracks_np, (width, height), length - 1).unsqueeze(0)

if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)

lat_h = height // 8
lat_w = width // 8

msk = torch.ones(1, length, lat_h, lat_w, device=start_image.device)
msk[:, 1:] = 0

# repeat first frame 4 times
msk = torch.concat([
torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]
],
dim=1)

# Reshape mask into groups of 4 frames
msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)

# rearrange into (1, 4, latent_frames, lat_h, lat_w)
msk = msk.transpose(1, 2)

# gray placeholder frames for every frame after the first one
dummy_frames = torch.ones(3, length - 1, height, width, device=start_image.device) * 0.5

start_image = start_image.permute(3, 0, 1, 2) # C, T, H, W
res = torch.concat([start_image, dummy_frames], dim=1)

res = res.permute(1, 2, 3, 0)[:, :, :, :3] # T, H, W, C

y = vae.encode(res)

# Add motion features to conditioning
positive = node_helpers.conditioning_set_values(positive,
{"tracks": processed_tracks,
"concat_mask": msk,
"concat_latent_image": y,
"ati_temperature": temperature,
"ati_topk": topk})
negative = node_helpers.conditioning_set_values(negative,
{"tracks": processed_tracks,
"concat_mask": msk,
"concat_latent_image": y,
"ati_temperature": temperature,
"ati_topk": topk})


# Handle clip vision output if provided
if clip_vision_output is not None:
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})

out_latent = {}
out_latent["samples"] = latent
return (positive, negative, out_latent)

NODE_CLASS_MAPPINGS = {
"WanTrackToVideo": WanTrackToVideo,
"WanImageToVideo": WanImageToVideo,
"WanFunControlToVideo": WanFunControlToVideo,
"WanFunInpaintToVideo": WanFunInpaintToVideo,