diff --git a/nodes.py b/nodes.py
index 834d335d..3b87ddc3 100644
--- a/nodes.py
+++ b/nodes.py
@@ -7,7 +7,9 @@ import hashlib
 from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
-from .wanvideo.modules.model import rope_params
+# Import the VACE-related classes
+from .wanvideo.modules.model import rope_params,VaceWanModel, VaceWanAttentionBlock, BaseWanAttentionBlock
+
 from .custom_linear import remove_lora_from_module, set_lora_params
 from .wanvideo.schedulers import get_scheduler, get_sampling_sigmas, retrieve_timesteps, scheduler_list
 from .gguf.gguf import set_lora_params_gguf
@@ -27,6 +29,10 @@ from comfy.cli_args import args, LatentPreviewMethod
 import folder_paths
+# Import the necessary FramePack classes
+from .wanvideo.framepack_vace import FramepackVace
+from .wanvideo.wan_video_vae import WanVideoVAE
+
 script_directory = os.path.dirname(os.path.abspath(__file__))
 device = mm.get_torch_device()
@@ -3191,6 +3197,99 @@ def predict_with_cfg(z, cfg_scale, positive_embeds, negative_embeds, timestep, i
         "samples": callback_latent.unsqueeze(0).cpu() if callback is not None else None,
     })
+# FramePack VACE-specific sampler
+class WanVACEVideoFramepackSampler:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "model": ("WANVIDEOMODEL",), # Expects the loaded FramepackVace model
+                "steps": ("INT", {"default": 30, "min": 1}),
+                "cfg": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 30.0, "step": 0.01}),
+                "shift": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
+                "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
+                "scheduler": (scheduler_list, {"default": "uni_pc",}),
+                "text_embeds": ("WANVIDEOTEXTEMBEDS", ),
+                "frame_num": ("INT", {"default": 81, "min": 1}), # Total number of frames for the output video
+                "context_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 2.0, "step": 0.1}),
+                "image_width": ("INT", {"default": 832, "min": 16}),
+                "image_height": ("INT", {"default": 480, "min": 16}),
+                "src_video": ("VIDEO", {"default": None}),
+                "src_mask": ("MASK", {"default": None}),
+                "src_ref_images": ("IMAGE", {"default": None}),
+                "force_offload": ("BOOLEAN", {"default": True, "tooltip": "Moves the model to the offload device after sampling"}),
+            },
+            "optional": {
+                "denoise_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
+                "freeinit_args": ("FREEINITARGS", ),
+                "start_step": ("INT", {"default": 0, "min": 0, "max": 10000, "step": 1, "tooltip": "Start step for the sampling, 0 means full sampling, otherwise samples only from this step"}),
+                "end_step": ("INT", {"default": -1, "min": -1, "max": 10000, "step": 1, "tooltip": "End step for the sampling, -1 means full sampling, otherwise samples only until this step"}),
+            }
+        }
+
+    RETURN_TYPES = ("LATENT",)
+    RETURN_NAMES = ("samples",)
+    FUNCTION = "process"
+    CATEGORY = "WanVideoWrapper"
+    DESCRIPTION = "A sampler specifically for the FramePack algorithm for long video generation."
+
+    def process(self, model, steps, cfg, shift, seed, scheduler, text_embeds, frame_num, context_scale,
+                image_width, image_height, src_video=None, src_mask=None, src_ref_images=None, force_offload=True,
+                denoise_strength=1.0, freeinit_args=None, start_step=0, end_step=-1):
+        # Ensure the provided model is an instance of FramepackVace
+        if not isinstance(model.model, FramepackVace):
+            raise TypeError("This sampler requires a FramepackVace model. Please check your model loader node.")
+
+        # Get the FramepackVace instance and the current device
+        framepack_vace = model.model
+        current_device = mm.get_torch_device()
+
+        # Extract the positive and negative prompt embeddings from the text_embeds dictionary
+        prompt = text_embeds["prompt_embeds"][0] if text_embeds["prompt_embeds"] else ""
+        n_prompt = text_embeds["negative_prompt_embeds"][0][0] if text_embeds["negative_prompt_embeds"] else ""
+
+        # ComfyUI's VIDEO format is (B, F, H, W, C). The model expects (B, C, F, H, W).
+        # Handle src_video: (B, F, H, W, C) -> (B, C, F, H, W)
+        input_frames_list = [src_video.permute(0, 4, 1, 2, 3)] if src_video is not None else [None]
+
+        # Handle src_mask: (B, H, W) -> (B, 1, F, H, W)
+        input_masks_list = [src_mask.unsqueeze(0).unsqueeze(0).permute(0, 1, 4, 2, 3)] if src_mask is not None else [None]
+
+        # Handle src_ref_images: (B, H, W, C) -> (B, C, 1, H, W)
+        input_ref_images_list = [src_ref_images.permute(0, 3, 1, 2).unsqueeze(1)] if src_ref_images is not None else [None]
+
+        # Prepare the source video, mask and reference images with the FramepackVace instance.
+        src_video_prepared, src_mask_prepared, src_ref_images_prepared = framepack_vace.prepare_source(
+            input_frames_list,
+            input_masks_list,
+            input_ref_images_list,
+            frame_num,
+            (image_height, image_width),
+            current_device
+        )
+
+        # Call the FramePack generation method.
+        log.info(f"Starting FramePack generation for {frame_num} frames.")
+        final_video_latent = framepack_vace.generate_with_framepack(
+            input_prompt=prompt,
+            input_frames=src_video_prepared,
+            input_masks=src_mask_prepared,
+            input_ref_images=src_ref_images_prepared,
+            size=(image_width, image_height),
+            frame_num=frame_num,
+            sample_solver=scheduler,
+            sampling_steps=steps,
+            guide_scale=cfg,
+            n_prompt=n_prompt,
+            seed=seed,
+            offload_model=force_offload
+        )
+
+        log.info(f"FramePack generation complete. Output latent tensor shape: {final_video_latent.shape}")
+
+        # The output of generate_with_framepack is a single tensor (C, T, H, W).
+ return ({"samples": final_video_latent.unsqueeze(0).cpu(),},) + #region VideoDecode class WanVideoDecode: @classmethod @@ -3432,6 +3531,7 @@ def encode(self, samples, direction): "WanVideoTextEncodeCached": WanVideoTextEncodeCached, "WanVideoAddExtraLatent": WanVideoAddExtraLatent, "WanVideoLatentReScale": WanVideoLatentReScale, + "WanVACEVideoFramepackSampler": WanVACEVideoFramepackSampler, } NODE_DISPLAY_NAME_MAPPINGS = { "WanVideoSampler": "WanVideo Sampler", @@ -3463,4 +3563,5 @@ def encode(self, samples, direction): "WanVideoTextEncodeCached": "WanVideo TextEncode Cached", "WanVideoAddExtraLatent": "WanVideo Add Extra Latent", "WanVideoLatentReScale": "WanVideo Latent ReScale", + "WanVACEVideoFramepackSampler": "WanVideo Framepack Sampler", } diff --git a/nodes_framepack.py b/nodes_framepack.py new file mode 100644 index 00000000..3979ad0d --- /dev/null +++ b/nodes_framepack.py @@ -0,0 +1,3448 @@ +# Initializing nodes_framepack.py +import os, gc, math +import torch +import torch.nn.functional as F +import numpy as np +from tqdm import tqdm +import inspect +import hashlib +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler + +from .wanvideo.modules.model import rope_params +from .custom_linear import remove_lora_from_module, set_lora_params +from .wanvideo.schedulers import get_scheduler, get_sampling_sigmas, retrieve_timesteps, scheduler_list +from .gguf.gguf import set_lora_params_gguf +from .multitalk.multitalk import timestep_transform, add_noise +from .utils import(log, print_memory, apply_lora, clip_encode_image_tiled, fourier_filter, + add_noise_to_reference_video, optimized_scale, setup_radial_attention, + compile_model, dict_to_device, tangential_projection, set_module_tensor_to_device, get_raag_guidance) +from .cache_methods.cache_methods import cache_report +from .enhance_a_video.globals import set_enhance_weight, set_num_frames +from .taehv import TAEHV + +from einops import rearrange + +from comfy import model_management as mm +from comfy.utils import ProgressBar, common_upscale +from comfy.clip_vision import clip_preprocess, ClipVisionModel +from comfy.cli_args import args, LatentPreviewMethod +import folder_paths + +script_directory = os.path.dirname(os.path.abspath(__file__)) + +device = mm.get_torch_device() +offload_device = mm.unet_offload_device() + +VAE_STRIDE = (4, 8, 8) +PATCH_SIZE = (1, 2, 2) + +def offload_transformer(transformer): + transformer.teacache_state.clear_all() + transformer.magcache_state.clear_all() + transformer.easycache_state.clear_all() + transformer.to(offload_device) + mm.soft_empty_cache() + gc.collect() + +class WanVideoEnhanceAVideo: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "weight": ("FLOAT", {"default": 2.0, "min": 0, "max": 100, "step": 0.01, "tooltip": "The feta Weight of the Enhance-A-Video"}), + "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percentage of the steps to apply Enhance-A-Video"}), + "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percentage of the steps to apply Enhance-A-Video"}), + }, + } + RETURN_TYPES = ("FETAARGS",) + RETURN_NAMES = ("feta_args",) + FUNCTION = "setargs" + CATEGORY = "WanVideoWrapper" + DESCRIPTION = "https://github.com/NUS-HPC-AI-Lab/Enhance-A-Video" + + def setargs(self, **kwargs): + return (kwargs, ) + +class WanVideoSetBlockSwap: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": ("WANVIDEOMODEL", ), + }, + "optional": { + 
"block_swap_args": ("BLOCKSWAPARGS", ), + } + } + + RETURN_TYPES = ("WANVIDEOMODEL",) + RETURN_NAMES = ("model", ) + FUNCTION = "loadmodel" + CATEGORY = "WanVideoWrapper" + + def loadmodel(self, model, block_swap_args=None): + if block_swap_args is None: + return (model,) + patcher = model.clone() + if 'transformer_options' not in patcher.model_options: + patcher.model_options['transformer_options'] = {} + patcher.model_options["transformer_options"]["block_swap_args"] = block_swap_args + + return (patcher,) + +class WanVideoSetRadialAttention: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": ("WANVIDEOMODEL", ), + "dense_attention_mode": ([ + "sdpa", + "flash_attn_2", + "flash_attn_3", + "sageattn", + "sparse_sage_attention", + ], {"default": "sageattn", "tooltip": "The attention mode for dense attention"}), + "dense_blocks": ("INT", {"default": 1, "min": 0, "max": 40, "step": 1, "tooltip": "Number of blocks to apply normal attention to"}), + "dense_vace_blocks": ("INT", {"default": 1, "min": 0, "max": 15, "step": 1, "tooltip": "Number of vace blocks to apply normal attention to"}), + "dense_timesteps": ("INT", {"default": 2, "min": 0, "max": 100, "step": 1, "tooltip": "The step to start applying sparse attention"}), + "decay_factor": ("FLOAT", {"default": 0.2, "min": 0, "max": 1, "step": 0.01, "tooltip": "Controls how quickly the attention window shrinks as the distance between frames increases in the sparse attention mask."}), + "block_size":([128, 64], {"default": 128, "tooltip": "Radial attention block size, larger blocks are faster but restricts usable dimensions more."}), + } + } + + RETURN_TYPES = ("WANVIDEOMODEL",) + RETURN_NAMES = ("model", ) + FUNCTION = "loadmodel" + CATEGORY = "WanVideoWrapper" + DESCRIPTION = "Sets radial attention parameters, dense attention refers to normal attention" + + def loadmodel(self, model, dense_attention_mode, dense_blocks, dense_vace_blocks, dense_timesteps, decay_factor, block_size): + if "radial" not in model.model.diffusion_model.attention_mode: + raise Exception("Enable radial attention first in the model loader.") + + patcher = model.clone() + if 'transformer_options' not in patcher.model_options: + patcher.model_options['transformer_options'] = {} + + patcher.model_options["transformer_options"]["dense_attention_mode"] = dense_attention_mode + patcher.model_options["transformer_options"]["dense_blocks"] = dense_blocks + patcher.model_options["transformer_options"]["dense_vace_blocks"] = dense_vace_blocks + patcher.model_options["transformer_options"]["dense_timesteps"] = dense_timesteps + patcher.model_options["transformer_options"]["decay_factor"] = decay_factor + patcher.model_options["transformer_options"]["block_size"] = block_size + + return (patcher,) + +class WanVideoBlockList: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "blocks": ("STRING", {"default": "1", "multiline":True}), + } + } + + RETURN_TYPES = ("INT",) + RETURN_NAMES = ("block_list", ) + FUNCTION = "create_list" + CATEGORY = "WanVideoWrapper" + DESCRIPTION = "Comma separated list of blocks to apply block swap to, can also use ranges like '0-5' or '0,2,3-5' etc., can be connected to the dense_blocks input of 'WanVideoSetRadialAttention' node" + + def create_list(self, blocks): + block_list = [] + for line in blocks.splitlines(): + for part in line.split(","): + part = part.strip() + if not part: + continue + if "-" in part: + try: + start, end = map(int, part.split("-", 1)) + block_list.extend(range(start, end + 1)) + except 
Exception: + raise ValueError(f"Invalid range: '{part}'") + else: + try: + block_list.append(int(part)) + except Exception: + raise ValueError(f"Invalid integer: '{part}'") + return (block_list,) + +# In-memory cache for prompt extender output +_extender_cache = {} + +cache_dir = os.path.join(script_directory, 'text_embed_cache') + +def get_cache_path(prompt): + cache_key = prompt.strip() + cache_hash = hashlib.sha256(cache_key.encode('utf-8')).hexdigest() + return os.path.join(cache_dir, f"{cache_hash}.pt") + +def get_cached_text_embeds(positive_prompt, negative_prompt): + + os.makedirs(cache_dir, exist_ok=True) + + context = None + context_null = None + + pos_cache_path = get_cache_path(positive_prompt) + neg_cache_path = get_cache_path(negative_prompt) + + # Try to load positive prompt embeds + if os.path.exists(pos_cache_path): + try: + log.info(f"Loading prompt embeds from cache: {pos_cache_path}") + context = torch.load(pos_cache_path) + except Exception as e: + log.warning(f"Failed to load cache: {e}, will re-encode.") + + # Try to load negative prompt embeds + if os.path.exists(neg_cache_path): + try: + log.info(f"Loading prompt embeds from cache: {neg_cache_path}") + context_null = torch.load(neg_cache_path) + except Exception as e: + log.warning(f"Failed to load cache: {e}, will re-encode.") + + return context, context_null + +class WanVideoTextEncodeCached: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "model_name": (folder_paths.get_filename_list("text_encoders"), {"tooltip": "These models are loaded from 'ComfyUI/models/text_encoders'"}), + "precision": (["fp32", "bf16"], + {"default": "bf16"} + ), + "positive_prompt": ("STRING", {"default": "", "multiline": True} ), + "negative_prompt": ("STRING", {"default": "", "multiline": True} ), + "quantization": (['disabled', 'fp8_e4m3fn'], {"default": 'disabled', "tooltip": "optional quantization method"}), + "use_disk_cache": ("BOOLEAN", {"default": True, "tooltip": "Cache the text embeddings to disk for faster re-use, under the custom_nodes/ComfyUI-WanVideoWrapper/text_embed_cache directory"}), + "device": (["gpu", "cpu"], {"default": "gpu", "tooltip": "Device to run the text encoding on."}), + }, + "optional": { + "extender_args": ("WANVIDEOPROMPTEXTENDER_ARGS", {"tooltip": "Use this node to extend the prompt with additional text."}), + } + } + + RETURN_TYPES = ("WANVIDEOTEXTEMBEDS", "WANVIDEOTEXTEMBEDS", "STRING") + RETURN_NAMES = ("text_embeds", "negative_text_embeds", "positive_prompt") + OUTPUT_TOOLTIPS = ("The text embeddings for both prompts", "The text embeddings for the negative prompt only (for NAG)", "Positive prompt to display prompt extender results") + FUNCTION = "process" + CATEGORY = "WanVideoWrapper" + DESCRIPTION = """Encodes text prompts into text embeddings. This node loads and completely unloads the T5 after done, +leaving no VRAM or RAM imprint. If prompts have been cached before T5 is not loaded at all. +negative output is meant to be used with NAG, it contains only negative prompt embeddings. + +Additionally you can provide a Qwen LLM model to extend the positive prompt with either one +of the original Wan templates or a custom system prompt. 
+""" + + + def process(self, model_name, precision, positive_prompt, negative_prompt, quantization='disabled', use_disk_cache=True, device="gpu", extender_args=None): + from .nodes_model_loading import LoadWanVideoT5TextEncoder + pbar = ProgressBar(3) + + echoshot = True if "[1]" in positive_prompt else False + + # Handle prompt extension with in-memory cache + orig_prompt = positive_prompt + if extender_args is not None: + extender_key = (orig_prompt, str(extender_args)) + if extender_key in _extender_cache: + positive_prompt = _extender_cache[extender_key] + log.info(f"Loaded extended prompt from in-memory cache: {positive_prompt}") + else: + from .qwen.qwen import QwenLoader, WanVideoPromptExtender + log.info("Using WanVideoPromptExtender to process prompts") + qwen, = QwenLoader().load( + extender_args["model"], + load_device="main_device" if device == "gpu" else "cpu", + precision=precision) + positive_prompt, = WanVideoPromptExtender().generate( + qwen=qwen, + max_new_tokens=extender_args["max_new_tokens"], + prompt=orig_prompt, + device=device, + force_offload=False, + custom_system_prompt=extender_args["system_prompt"], + seed=extender_args["seed"] + ) + log.info(f"Extended positive prompt: {positive_prompt}") + _extender_cache[extender_key] = positive_prompt + del qwen + pbar.update(1) + + # Now check disk cache using the (possibly extended) prompt + if use_disk_cache: + context, context_null = get_cached_text_embeds(positive_prompt, negative_prompt) + if context is not None and context_null is not None: + return{ + "prompt_embeds": context, + "negative_prompt_embeds": context_null, + "echoshot": echoshot, + },{"prompt_embeds": context_null}, positive_prompt + + t5, = LoadWanVideoT5TextEncoder().loadmodel(model_name, precision, "main_device", quantization) + pbar.update(1) + + prompt_embeds_dict, = WanVideoTextEncode().process( + positive_prompt=positive_prompt, + negative_prompt=negative_prompt, + t5=t5, + force_offload=False, + model_to_offload=None, + use_disk_cache=use_disk_cache, + device=device + ) + pbar.update(1) + del t5 + mm.soft_empty_cache() + gc.collect() + return (prompt_embeds_dict, {"prompt_embeds": prompt_embeds_dict["negative_prompt_embeds"]}, positive_prompt) + +#region TextEncode +class WanVideoTextEncode: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "positive_prompt": ("STRING", {"default": "", "multiline": True} ), + "negative_prompt": ("STRING", {"default": "", "multiline": True} ), + }, + "optional": { + "t5": ("WANTEXTENCODER",), + "force_offload": ("BOOLEAN", {"default": True}), + "model_to_offload": ("WANVIDEOMODEL", {"tooltip": "Model to move to offload_device before encoding"}), + "use_disk_cache": ("BOOLEAN", {"default": False, "tooltip": "Cache the text embeddings to disk for faster re-use, under the custom_nodes/ComfyUI-WanVideoWrapper/text_embed_cache directory"}), + "device": (["gpu", "cpu"], {"default": "gpu", "tooltip": "Device to run the text encoding on."}), + } + } + + RETURN_TYPES = ("WANVIDEOTEXTEMBEDS", ) + RETURN_NAMES = ("text_embeds",) + FUNCTION = "process" + CATEGORY = "WanVideoWrapper" + DESCRIPTION = "Encodes text prompts into text embeddings. For rudimentary prompt travel you can input multiple prompts separated by '|', they will be equally spread over the video length" + + + def process(self, positive_prompt, negative_prompt, t5=None, force_offload=True, model_to_offload=None, use_disk_cache=False, device="gpu"): + if t5 is None and not use_disk_cache: + raise ValueError("T5 encoder is required for text encoding. 
Please provide a valid T5 encoder or enable disk cache.") + + echoshot = True if "[1]" in positive_prompt else False + + if use_disk_cache: + context, context_null = get_cached_text_embeds(positive_prompt, negative_prompt) + if context is not None and context_null is not None: + return{ + "prompt_embeds": context, + "negative_prompt_embeds": context_null, + "echoshot": echoshot, + }, + + if t5 is None: + raise ValueError("No cached text embeds found for prompts, please provide a T5 encoder.") + + if model_to_offload is not None and device == "gpu": + log.info(f"Moving video model to {offload_device}") + model_to_offload.model.to(offload_device) + + encoder = t5["model"] + dtype = t5["dtype"] + + positive_prompts = [] + all_weights = [] + + # Split positive prompts and process each with weights + if "|" in positive_prompt: + log.info("Multiple positive prompts detected, splitting by '|'") + positive_prompts_raw = [p.strip() for p in positive_prompt.split('|')] + elif "[1]" in positive_prompt: + log.info("Multiple positive prompts detected, splitting by [#] and enabling EchoShot") + import re + segments = re.split(r'\[\d+\]', positive_prompt) + positive_prompts_raw = [segment.strip() for segment in segments if segment.strip()] + assert len(positive_prompts_raw) > 1 and len(positive_prompts_raw) < 7, 'Input shot num must between 2~6 !' + else: + positive_prompts_raw = [positive_prompt.strip()] + + for p in positive_prompts_raw: + cleaned_prompt, weights = self.parse_prompt_weights(p) + positive_prompts.append(cleaned_prompt) + all_weights.append(weights) + + mm.soft_empty_cache() + + if device == "gpu": + device_to = mm.get_torch_device() + else: + device_to = torch.device("cpu") + + if encoder.quantization == "fp8_e4m3fn": + cast_dtype = torch.float8_e4m3fn + else: + cast_dtype = encoder.dtype + + params_to_keep = {'norm', 'pos_embedding', 'token_embedding'} + for name, param in encoder.model.named_parameters(): + dtype_to_use = dtype if any(keyword in name for keyword in params_to_keep) else cast_dtype + value = encoder.state_dict[name] if hasattr(encoder, 'state_dict') else encoder.model.state_dict()[name] + set_module_tensor_to_device(encoder.model, name, device=device_to, dtype=dtype_to_use, value=value) + if hasattr(encoder, 'state_dict'): + del encoder.state_dict + mm.soft_empty_cache() + gc.collect() + + with torch.autocast(device_type=mm.get_autocast_device(device_to), dtype=encoder.dtype, enabled=encoder.quantization != 'disabled'): + # Encode positive if not loaded from cache + if use_disk_cache and context is not None: + pass + else: + context = encoder(positive_prompts, device_to) + # Apply weights to embeddings if any were extracted + for i, weights in enumerate(all_weights): + for text, weight in weights.items(): + log.info(f"Applying weight {weight} to prompt: {text}") + if len(weights) > 0: + context[i] = context[i] * weight + + # Encode negative if not loaded from cache + if use_disk_cache and context_null is not None: + pass + else: + context_null = encoder([negative_prompt], device_to) + + if force_offload: + encoder.model.to(offload_device) + mm.soft_empty_cache() + gc.collect() + + prompt_embeds_dict = { + "prompt_embeds": context, + "negative_prompt_embeds": context_null, + "echoshot": echoshot, + } + + # Save each part to its own cache file if needed + if use_disk_cache: + pos_cache_path = get_cache_path(positive_prompt) + neg_cache_path = get_cache_path(negative_prompt) + try: + if not os.path.exists(pos_cache_path): + torch.save(context, pos_cache_path) + 
log.info(f"Saved prompt embeds to cache: {pos_cache_path}") + except Exception as e: + log.warning(f"Failed to save cache: {e}") + try: + if not os.path.exists(neg_cache_path): + torch.save(context_null, neg_cache_path) + log.info(f"Saved prompt embeds to cache: {neg_cache_path}") + except Exception as e: + log.warning(f"Failed to save cache: {e}") + + return (prompt_embeds_dict,) + + def parse_prompt_weights(self, prompt): + """Extract text and weights from prompts with (text:weight) format""" + import re + + # Parse all instances of (text:weight) in the prompt + pattern = r'\((.*?):([\d\.]+)\)' + matches = re.findall(pattern, prompt) + + # Replace each match with just the text part + cleaned_prompt = prompt + weights = {} + + for match in matches: + text, weight = match + orig_text = f"({text}:{weight})" + cleaned_prompt = cleaned_prompt.replace(orig_text, text) + weights[text] = float(weight) + + return cleaned_prompt, weights + +class WanVideoTextEncodeSingle: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "prompt": ("STRING", {"default": "", "multiline": True} ), + }, + "optional": { + "t5": ("WANTEXTENCODER",), + "force_offload": ("BOOLEAN", {"default": True}), + "model_to_offload": ("WANVIDEOMODEL", {"tooltip": "Model to move to offload_device before encoding"}), + "use_disk_cache": ("BOOLEAN", {"default": False, "tooltip": "Cache the text embeddings to disk for faster re-use, under the custom_nodes/ComfyUI-WanVideoWrapper/text_embed_cache directory"}), + "device": (["gpu", "cpu"], {"default": "gpu", "tooltip": "Device to run the text encoding on."}), + } + } + + RETURN_TYPES = ("WANVIDEOTEXTEMBEDS", ) + RETURN_NAMES = ("text_embeds",) + FUNCTION = "process" + CATEGORY = "WanVideoWrapper" + DESCRIPTION = "Encodes text prompt into text embedding." 
+ + def process(self, prompt, t5=None, force_offload=True, model_to_offload=None, use_disk_cache=False, device="gpu"): + # Unified cache logic: use a single cache file per unique prompt + encoded = None + echoshot = True if "[1]" in prompt else False + if use_disk_cache: + cache_dir = os.path.join(script_directory, 'text_embed_cache') + os.makedirs(cache_dir, exist_ok=True) + def get_cache_path(prompt): + cache_key = prompt.strip() + cache_hash = hashlib.sha256(cache_key.encode('utf-8')).hexdigest() + return os.path.join(cache_dir, f"{cache_hash}.pt") + cache_path = get_cache_path(prompt) + if os.path.exists(cache_path): + try: + log.info(f"Loading prompt embeds from cache: {cache_path}") + encoded = torch.load(cache_path) + except Exception as e: + log.warning(f"Failed to load cache: {e}, will re-encode.") + + if t5 is None and encoded is None: + raise ValueError("No cached text embeds found for prompts, please provide a T5 encoder.") + + if encoded is None: + if model_to_offload is not None and device == "gpu": + log.info(f"Moving video model to {offload_device}") + model_to_offload.model.to(offload_device) + mm.soft_empty_cache() + + encoder = t5["model"] + dtype = t5["dtype"] + + if device == "gpu": + device_to = mm.get_torch_device() + else: + device_to = torch.device("cpu") + + if encoder.quantization == "fp8_e4m3fn": + cast_dtype = torch.float8_e4m3fn + else: + cast_dtype = encoder.dtype + params_to_keep = {'norm', 'pos_embedding', 'token_embedding'} + for name, param in encoder.model.named_parameters(): + dtype_to_use = dtype if any(keyword in name for keyword in params_to_keep) else cast_dtype + value = encoder.state_dict[name] if hasattr(encoder, 'state_dict') else encoder.model.state_dict()[name] + set_module_tensor_to_device(encoder.model, name, device=device_to, dtype=dtype_to_use, value=value) + if hasattr(encoder, 'state_dict'): + del encoder.state_dict + mm.soft_empty_cache() + gc.collect() + with torch.autocast(device_type=mm.get_autocast_device(device_to), dtype=encoder.dtype, enabled=encoder.quantization != 'disabled'): + encoded = encoder([prompt], device_to) + + if force_offload: + encoder.model.to(offload_device) + mm.soft_empty_cache() + + # Save to cache if enabled + if use_disk_cache: + try: + if not os.path.exists(cache_path): + torch.save(encoded, cache_path) + log.info(f"Saved prompt embeds to cache: {cache_path}") + except Exception as e: + log.warning(f"Failed to save cache: {e}") + + prompt_embeds_dict = { + "prompt_embeds": encoded, + "negative_prompt_embeds": None, + "echoshot": echoshot + } + return (prompt_embeds_dict,) + +class WanVideoApplyNAG: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "original_text_embeds": ("WANVIDEOTEXTEMBEDS",), + "nag_text_embeds": ("WANVIDEOTEXTEMBEDS",), + "nag_scale": ("FLOAT", {"default": 11.0, "min": 0.0, "max": 100.0, "step": 0.1}), + "nag_tau": ("FLOAT", {"default": 2.5, "min": 0.0, "max": 10.0, "step": 0.1}), + "nag_alpha": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.01}), + }, + } + + RETURN_TYPES = ("WANVIDEOTEXTEMBEDS", ) + RETURN_NAMES = ("text_embeds",) + FUNCTION = "process" + CATEGORY = "WanVideoWrapper" + DESCRIPTION = "Adds NAG prompt embeds to original prompt embeds: 'https://github.com/ChenDarYen/Normalized-Attention-Guidance'" + + def process(self, original_text_embeds, nag_text_embeds, nag_scale, nag_tau, nag_alpha): + prompt_embeds_dict_copy = original_text_embeds.copy() + prompt_embeds_dict_copy.update({ + "nag_prompt_embeds": nag_text_embeds["prompt_embeds"], + 
"nag_params": { + "nag_scale": nag_scale, + "nag_tau": nag_tau, + "nag_alpha": nag_alpha, + } + }) + return (prompt_embeds_dict_copy,) + +class WanVideoTextEmbedBridge: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "positive": ("CONDITIONING",), + }, + "optional": { + "negative": ("CONDITIONING",), + } + } + + RETURN_TYPES = ("WANVIDEOTEXTEMBEDS", ) + RETURN_NAMES = ("text_embeds",) + FUNCTION = "process" + CATEGORY = "WanVideoWrapper" + DESCRIPTION = "Bridge between ComfyUI native text embedding and WanVideoWrapper text embedding" + + def process(self, positive, negative=None): + prompt_embeds_dict = { + "prompt_embeds": positive[0][0].to(device), + "negative_prompt_embeds": negative[0][0].to(device) if negative is not None else None, + } + return (prompt_embeds_dict,) + +#region clip vision +class WanVideoClipVisionEncode: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "clip_vision": ("CLIP_VISION",), + "image_1": ("IMAGE", {"tooltip": "Image to encode"}), + "strength_1": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Additional clip embed multiplier"}), + "strength_2": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Additional clip embed multiplier"}), + "crop": (["center", "disabled"], {"default": "center", "tooltip": "Crop image to 224x224 before encoding"}), + "combine_embeds": (["average", "sum", "concat", "batch"], {"default": "average", "tooltip": "Method to combine multiple clip embeds"}), + "force_offload": ("BOOLEAN", {"default": True}), + }, + "optional": { + "image_2": ("IMAGE", ), + "negative_image": ("IMAGE", {"tooltip": "image to use for uncond"}), + "tiles": ("INT", {"default": 0, "min": 0, "max": 16, "step": 2, "tooltip": "Use matteo's tiled image encoding for improved accuracy"}), + "ratio": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Ratio of the tile average"}), + } + } + + RETURN_TYPES = ("WANVIDIMAGE_CLIPEMBEDS",) + RETURN_NAMES = ("image_embeds",) + FUNCTION = "process" + CATEGORY = "WanVideoWrapper" + + def process(self, clip_vision, image_1, strength_1, strength_2, force_offload, crop, combine_embeds, image_2=None, negative_image=None, tiles=0, ratio=1.0): + image_mean = [0.48145466, 0.4578275, 0.40821073] + image_std = [0.26862954, 0.26130258, 0.27577711] + + if image_2 is not None: + image = torch.cat([image_1, image_2], dim=0) + else: + image = image_1 + + clip_vision.model.to(device) + + negative_clip_embeds = None + + if tiles > 0: + log.info("Using tiled image encoding") + clip_embeds = clip_encode_image_tiled(clip_vision, image.to(device), tiles=tiles, ratio=ratio) + if negative_image is not None: + negative_clip_embeds = clip_encode_image_tiled(clip_vision, negative_image.to(device), tiles=tiles, ratio=ratio) + else: + if isinstance(clip_vision, ClipVisionModel): + clip_embeds = clip_vision.encode_image(image).penultimate_hidden_states.to(device) + if negative_image is not None: + negative_clip_embeds = clip_vision.encode_image(negative_image).penultimate_hidden_states.to(device) + else: + pixel_values = clip_preprocess(image.to(device), size=224, mean=image_mean, std=image_std, crop=(not crop == "disabled")).float() + clip_embeds = clip_vision.visual(pixel_values) + if negative_image is not None: + pixel_values = clip_preprocess(negative_image.to(device), size=224, mean=image_mean, std=image_std, crop=(not crop == "disabled")).float() + negative_clip_embeds = clip_vision.visual(pixel_values) + + log.info(f"Clip embeds shape: 
{clip_embeds.shape}, dtype: {clip_embeds.dtype}") + + weighted_embeds = [] + weighted_embeds.append(clip_embeds[0:1] * strength_1) + + # Handle all additional embeddings + if clip_embeds.shape[0] > 1: + weighted_embeds.append(clip_embeds[1:2] * strength_2) + + if clip_embeds.shape[0] > 2: + for i in range(2, clip_embeds.shape[0]): + weighted_embeds.append(clip_embeds[i:i+1]) # Add as-is without strength modifier + + # Combine all weighted embeddings + if combine_embeds == "average": + clip_embeds = torch.mean(torch.stack(weighted_embeds), dim=0) + elif combine_embeds == "sum": + clip_embeds = torch.sum(torch.stack(weighted_embeds), dim=0) + elif combine_embeds == "concat": + clip_embeds = torch.cat(weighted_embeds, dim=1) + elif combine_embeds == "batch": + clip_embeds = torch.cat(weighted_embeds, dim=0) + else: + clip_embeds = weighted_embeds[0] + + + log.info(f"Combined clip embeds shape: {clip_embeds.shape}") + + if force_offload: + clip_vision.model.to(offload_device) + mm.soft_empty_cache() + + clip_embeds_dict = { + "clip_embeds": clip_embeds, + "negative_clip_embeds": negative_clip_embeds + } + + return (clip_embeds_dict,) + +class WanVideoRealisDanceLatents: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "ref_latent": ("LATENT", {"tooltip": "Reference image to encode"}), + "pose_cond_start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percent of the SMPL model"}), + "pose_cond_end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percent of the SMPL model"}), + }, + "optional": { + "smpl_latent": ("LATENT", {"tooltip": "SMPL pose image to encode"}), + "hamer_latent": ("LATENT", {"tooltip": "Hamer hand pose image to encode"}), + }, + } + + RETURN_TYPES = ("ADD_COND_LATENTS",) + RETURN_NAMES = ("add_cond_latents",) + FUNCTION = "process" + CATEGORY = "WanVideoWrapper" + + def process(self, ref_latent, pose_cond_start_percent, pose_cond_end_percent, hamer_latent=None, smpl_latent=None): + if smpl_latent is None and hamer_latent is None: + raise Exception("At least one of smpl_latent or hamer_latent must be provided") + if smpl_latent is None: + smpl = torch.zeros_like(hamer_latent["samples"]) + else: + smpl = smpl_latent["samples"] + if hamer_latent is None: + hamer = torch.zeros_like(smpl_latent["samples"]) + else: + hamer = hamer_latent["samples"] + + pose_latent = torch.cat((smpl, hamer), dim=1) + + add_cond_latents = { + "ref_latent": ref_latent["samples"], + "pose_latent": pose_latent, + "pose_cond_start_percent": pose_cond_start_percent, + "pose_cond_end_percent": pose_cond_end_percent, + } + + return (add_cond_latents,) + +class WanVideoImageToVideoEncode: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "vae": ("WANVAE",), + "width": ("INT", {"default": 832, "min": 64, "max": 8096, "step": 8, "tooltip": "Width of the image to encode"}), + "height": ("INT", {"default": 480, "min": 64, "max": 8096, "step": 8, "tooltip": "Height of the image to encode"}), + "num_frames": ("INT", {"default": 81, "min": 1, "max": 10000, "step": 4, "tooltip": "Number of frames to encode"}), + "noise_aug_strength": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Strength of noise augmentation, helpful for I2V where some noise can add motion and give sharper results"}), + "start_latent_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Additional latent multiplier, helpful for I2V where lower values allow for more 
motion"}), + "end_latent_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Additional latent multiplier, helpful for I2V where lower values allow for more motion"}), + "force_offload": ("BOOLEAN", {"default": True}), + }, + "optional": { + "clip_embeds": ("WANVIDIMAGE_CLIPEMBEDS", {"tooltip": "Clip vision encoded image"}), + "start_image": ("IMAGE", {"tooltip": "Image to encode"}), + "end_image": ("IMAGE", {"tooltip": "end frame"}), + "control_embeds": ("WANVIDIMAGE_EMBEDS", {"tooltip": "Control signal for the Fun -model"}), + "fun_or_fl2v_model": ("BOOLEAN", {"default": True, "tooltip": "Enable when using official FLF2V or Fun model"}), + "temporal_mask": ("MASK", {"tooltip": "mask"}), + "extra_latents": ("LATENT", {"tooltip": "Extra latents to add to the input front, used for Skyreels A2 reference images"}), + "tiled_vae": ("BOOLEAN", {"default": False, "tooltip": "Use tiled VAE encoding for reduced memory use"}), + "add_cond_latents": ("ADD_COND_LATENTS", {"advanced": True, "tooltip": "Additional cond latents WIP"}), + } + } + + RETURN_TYPES = ("WANVIDIMAGE_EMBEDS",) + RETURN_NAMES = ("image_embeds",) + FUNCTION = "process" + CATEGORY = "WanVideoWrapper" + + def process(self, vae, width, height, num_frames, force_offload, noise_aug_strength, + start_latent_strength, end_latent_strength, start_image=None, end_image=None, control_embeds=None, fun_or_fl2v_model=False, + temporal_mask=None, extra_latents=None, clip_embeds=None, tiled_vae=False, add_cond_latents=None): + + H = height + W = width + + lat_h = H // 8 + lat_w = W // 8 + + num_frames = ((num_frames - 1) // 4) * 4 + 1 + two_ref_images = start_image is not None and end_image is not None + + base_frames = num_frames + (1 if two_ref_images and not fun_or_fl2v_model else 0) + if temporal_mask is None: + mask = torch.zeros(1, base_frames, lat_h, lat_w, device=device) + if start_image is not None: + mask[:, 0:start_image.shape[0]] = 1 # First frame + if end_image is not None: + mask[:, -end_image.shape[0]:] = 1 # End frame if exists + else: + mask = common_upscale(temporal_mask.unsqueeze(1).to(device), lat_w, lat_h, "nearest", "disabled").squeeze(1) + if mask.shape[0] > base_frames: + mask = mask[:base_frames] + elif mask.shape[0] < base_frames: + mask = torch.cat([mask, torch.zeros(base_frames - mask.shape[0], lat_h, lat_w, device=device)]) + mask = mask.unsqueeze(0).to(device) + + # Repeat first frame and optionally end frame + start_mask_repeated = torch.repeat_interleave(mask[:, 0:1], repeats=4, dim=1) # T, C, H, W + if end_image is not None and not fun_or_fl2v_model: + end_mask_repeated = torch.repeat_interleave(mask[:, -1:], repeats=4, dim=1) # T, C, H, W + mask = torch.cat([start_mask_repeated, mask[:, 1:-1], end_mask_repeated], dim=1) + else: + mask = torch.cat([start_mask_repeated, mask[:, 1:]], dim=1) + + # Reshape mask into groups of 4 frames + mask = mask.view(1, mask.shape[1] // 4, 4, lat_h, lat_w) # 1, T, C, H, W + mask = mask.movedim(1, 2)[0]# C, T, H, W + + # Resize and rearrange the input image dimensions + if start_image is not None: + resized_start_image = common_upscale(start_image.movedim(-1, 1), W, H, "lanczos", "disabled").movedim(0, 1) + resized_start_image = resized_start_image * 2 - 1 + if noise_aug_strength > 0.0: + resized_start_image = add_noise_to_reference_video(resized_start_image, ratio=noise_aug_strength) + + if end_image is not None: + resized_end_image = common_upscale(end_image.movedim(-1, 1), W, H, "lanczos", "disabled").movedim(0, 1) + resized_end_image = 
resized_end_image * 2 - 1 + if noise_aug_strength > 0.0: + resized_end_image = add_noise_to_reference_video(resized_end_image, ratio=noise_aug_strength) + + # Concatenate image with zero frames and encode + vae.to(device) + + if temporal_mask is None: + if start_image is not None and end_image is None: + zero_frames = torch.zeros(3, num_frames-start_image.shape[0], H, W, device=device) + concatenated = torch.cat([resized_start_image.to(device), zero_frames], dim=1) + elif start_image is None and end_image is not None: + zero_frames = torch.zeros(3, num_frames-end_image.shape[0], H, W, device=device) + concatenated = torch.cat([zero_frames, resized_end_image.to(device)], dim=1) + elif start_image is None and end_image is None: + concatenated = torch.zeros(3, num_frames, H, W, device=device) + else: + if fun_or_fl2v_model: + zero_frames = torch.zeros(3, num_frames-(start_image.shape[0]+end_image.shape[0]), H, W, device=device) + else: + zero_frames = torch.zeros(3, num_frames-1, H, W, device=device) + concatenated = torch.cat([resized_start_image.to(device), zero_frames, resized_end_image.to(device)], dim=1) + else: + temporal_mask = common_upscale(temporal_mask.unsqueeze(1), W, H, "nearest", "disabled").squeeze(1) + concatenated = resized_start_image[:,:num_frames] * temporal_mask[:num_frames].unsqueeze(0) + + y = vae.encode([concatenated.to(device=device, dtype=vae.dtype)], device, end_=(end_image is not None and not fun_or_fl2v_model),tiled=tiled_vae)[0] + has_ref = False + if extra_latents is not None: + samples = extra_latents["samples"].squeeze(0) + y = torch.cat([samples, y], dim=1) + mask = torch.cat([torch.ones_like(mask[:, 0:samples.shape[1]]), mask], dim=1) + num_frames += samples.shape[1] * 4 + has_ref = True + y[:, :1] *= start_latent_strength + y[:, -1:] *= end_latent_strength + if control_embeds is None: + y = torch.cat([mask, y]) + else: + if end_image is None: + y[:, 1:] = 0 + elif start_image is None: + y[:, -1:] = 0 + + # Calculate maximum sequence length + patches_per_frame = lat_h * lat_w // (PATCH_SIZE[1] * PATCH_SIZE[2]) + frames_per_stride = (num_frames - 1) // 4 + (2 if end_image is not None and not fun_or_fl2v_model else 1) + max_seq_len = frames_per_stride * patches_per_frame + + if add_cond_latents is not None: + add_cond_latents["ref_latent_neg"] = vae.encode(torch.zeros(1, 3, 1, H, W, device=device, dtype=vae.dtype), device) + + vae.model.clear_cache() + if force_offload: + vae.model.to(offload_device) + mm.soft_empty_cache() + gc.collect() + + image_embeds = { + "image_embeds": y, + "clip_context": clip_embeds.get("clip_embeds", None) if clip_embeds is not None else None, + "negative_clip_context": clip_embeds.get("negative_clip_embeds", None) if clip_embeds is not None else None, + "max_seq_len": max_seq_len, + "num_frames": num_frames, + "lat_h": lat_h, + "lat_w": lat_w, + "control_embeds": control_embeds["control_embeds"] if control_embeds is not None else None, + "end_image": resized_end_image if end_image is not None else None, + "fun_or_fl2v_model": fun_or_fl2v_model, + "has_ref": has_ref, + "add_cond_latents": add_cond_latents, + "mask": mask if control_embeds is not None else None, # for 2.2 Fun control as it can handle masks + } + + return (image_embeds,) + +class WanVideoEmptyEmbeds: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "width": ("INT", {"default": 832, "min": 64, "max": 8096, "step": 8, "tooltip": "Width of the image to encode"}), + "height": ("INT", {"default": 480, "min": 64, "max": 8096, "step": 8, "tooltip": "Height of 
the image to encode"}), + "num_frames": ("INT", {"default": 81, "min": 1, "max": 10000, "step": 4, "tooltip": "Number of frames to encode"}), + }, + "optional": { + "control_embeds": ("WANVIDIMAGE_EMBEDS", {"tooltip": "control signal for the Fun -model"}), + "extra_latents": ("LATENT", {"tooltip": "First latent to use for the Pusa -model"}), + } + } + + RETURN_TYPES = ("WANVIDIMAGE_EMBEDS", ) + RETURN_NAMES = ("image_embeds",) + FUNCTION = "process" + CATEGORY = "WanVideoWrapper" + + def process(self, num_frames, width, height, control_embeds=None, extra_latents=None): + target_shape = (16, (num_frames - 1) // VAE_STRIDE[0] + 1, + height // VAE_STRIDE[1], + width // VAE_STRIDE[2]) + + embeds = { + "target_shape": target_shape, + "num_frames": num_frames, + "control_embeds": control_embeds["control_embeds"] if control_embeds is not None else None, + } + if extra_latents is not None: + embeds["extra_latents"] = [{ + "samples": extra_latents["samples"], + "index": 0, + }] + + return (embeds,) + +class WanVideoAddExtraLatent: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "embeds": ("WANVIDIMAGE_EMBEDS",), + "extra_latents": ("LATENT",), + "latent_index": ("INT", {"default": 0, "min": -1000, "max": 1000, "step": 1, "tooltip": "Index to insert the extra latents at in latent space"}), + } + } + + RETURN_TYPES = ("WANVIDIMAGE_EMBEDS",) + RETURN_NAMES = ("image_embeds",) + FUNCTION = "add" + CATEGORY = "WanVideoWrapper" + + def add(self, embeds, extra_latents, latent_index): + # Prepare the new extra latent entry + new_entry = { + "samples": extra_latents["samples"], + "index": latent_index, + } + # Get previous extra_latents list, or start a new one + prev_extra_latents = embeds.get("extra_latents", None) + if prev_extra_latents is None: + extra_latents_list = [new_entry] + elif isinstance(prev_extra_latents, list): + extra_latents_list = prev_extra_latents + [new_entry] + else: + extra_latents_list = [prev_extra_latents, new_entry] + + # Return a new dict with updated extra_latents + updated = dict(embeds) + updated["extra_latents"] = extra_latents_list + return (updated,) + +class WanVideoMiniMaxRemoverEmbeds: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "width": ("INT", {"default": 832, "min": 64, "max": 8096, "step": 8, "tooltip": "Width of the image to encode"}), + "height": ("INT", {"default": 480, "min": 64, "max": 8096, "step": 8, "tooltip": "Height of the image to encode"}), + "num_frames": ("INT", {"default": 81, "min": 1, "max": 10000, "step": 4, "tooltip": "Number of frames to encode"}), + "latents": ("LATENT", {"tooltip": "Encoded latents to use as control signals"}), + "mask_latents": ("LATENT", {"tooltip": "Encoded latents to use as mask"}), + }, + } + + RETURN_TYPES = ("WANVIDIMAGE_EMBEDS", ) + RETURN_NAMES = ("image_embeds",) + FUNCTION = "process" + CATEGORY = "WanVideoWrapper" + + def process(self, num_frames, width, height, latents, mask_latents): + target_shape = (16, (num_frames - 1) // VAE_STRIDE[0] + 1, + height // VAE_STRIDE[1], + width // VAE_STRIDE[2]) + + embeds = { + "target_shape": target_shape, + "num_frames": num_frames, + "minimax_latents": latents["samples"].squeeze(0), + "minimax_mask_latents": mask_latents["samples"].squeeze(0), + } + + return (embeds,) + +# region phantom +class WanVideoPhantomEmbeds: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "num_frames": ("INT", {"default": 81, "min": 1, "max": 10000, "step": 4, "tooltip": "Number of frames to encode"}), + "phantom_latent_1": ("LATENT", {"tooltip": 
"reference latents for the phantom model"}), + + "phantom_cfg_scale": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "CFG scale for the extra phantom cond pass"}), + "phantom_start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percent of the phantom model"}), + "phantom_end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percent of the phantom model"}), + }, + "optional": { + "phantom_latent_2": ("LATENT", {"tooltip": "reference latents for the phantom model"}), + "phantom_latent_3": ("LATENT", {"tooltip": "reference latents for the phantom model"}), + "phantom_latent_4": ("LATENT", {"tooltip": "reference latents for the phantom model"}), + "vace_embeds": ("WANVIDIMAGE_EMBEDS", {"tooltip": "VACE embeds"}), + } + } + + RETURN_TYPES = ("WANVIDIMAGE_EMBEDS", ) + RETURN_NAMES = ("image_embeds",) + FUNCTION = "process" + CATEGORY = "WanVideoWrapper" + + def process(self, num_frames, phantom_cfg_scale, phantom_start_percent, phantom_end_percent, phantom_latent_1, phantom_latent_2=None, phantom_latent_3=None, phantom_latent_4=None, vace_embeds=None): + samples = phantom_latent_1["samples"].squeeze(0) + if phantom_latent_2 is not None: + samples = torch.cat([samples, phantom_latent_2["samples"].squeeze(0)], dim=1) + if phantom_latent_3 is not None: + samples = torch.cat([samples, phantom_latent_3["samples"].squeeze(0)], dim=1) + if phantom_latent_4 is not None: + samples = torch.cat([samples, phantom_latent_4["samples"].squeeze(0)], dim=1) + C, T, H, W = samples.shape + + log.info(f"Phantom latents shape: {samples.shape}") + + target_shape = (16, (num_frames - 1) // VAE_STRIDE[0] + 1 + T, + H * 8 // VAE_STRIDE[1], + W * 8 // VAE_STRIDE[2]) + + embeds = { + "target_shape": target_shape, + "num_frames": num_frames, + "phantom_latents": samples, + "phantom_cfg_scale": phantom_cfg_scale, + "phantom_start_percent": phantom_start_percent, + "phantom_end_percent": phantom_end_percent, + } + if vace_embeds is not None: + vace_input = { + "vace_context": vace_embeds["vace_context"], + "vace_scale": vace_embeds["vace_scale"], + "has_ref": vace_embeds["has_ref"], + "vace_start_percent": vace_embeds["vace_start_percent"], + "vace_end_percent": vace_embeds["vace_end_percent"], + "vace_seq_len": vace_embeds["vace_seq_len"], + "additional_vace_inputs": vace_embeds["additional_vace_inputs"], + } + embeds.update(vace_input) + + return (embeds,) + +class WanVideoControlEmbeds: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "latents": ("LATENT", {"tooltip": "Encoded latents to use as control signals"}), + "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percent of the control signal"}), + "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percent of the control signal"}), + }, + "optional": { + "fun_ref_image": ("LATENT", {"tooltip": "Reference latent for the Fun 1.1 -model"}), + } + } + + RETURN_TYPES = ("WANVIDIMAGE_EMBEDS", ) + RETURN_NAMES = ("image_embeds",) + FUNCTION = "process" + CATEGORY = "WanVideoWrapper" + + def process(self, latents, start_percent, end_percent, fun_ref_image=None): + + samples = latents["samples"].squeeze(0) + C, T, H, W = samples.shape + + num_frames = (T - 1) * 4 + 1 + seq_len = math.ceil((H * W) / 4 * ((num_frames - 1) // 4 + 1)) + + embeds = { + "max_seq_len": seq_len, + "target_shape": samples.shape, + "num_frames": num_frames, + "control_embeds": { + 
"control_images": samples, + "start_percent": start_percent, + "end_percent": end_percent, + "fun_ref_image": fun_ref_image["samples"][:,:, 0] if fun_ref_image is not None else None, + } + } + + return (embeds,) + +class WanVideoSLG: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "blocks": ("STRING", {"default": "10", "tooltip": "Blocks to skip uncond on, separated by comma, index starts from 0"}), + "start_percent": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percent of the control signal"}), + "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percent of the control signal"}), + }, + } + + RETURN_TYPES = ("SLGARGS", ) + RETURN_NAMES = ("slg_args",) + FUNCTION = "process" + CATEGORY = "WanVideoWrapper" + DESCRIPTION = "Skips uncond on the selected blocks" + + def process(self, blocks, start_percent, end_percent): + slg_block_list = [int(x.strip()) for x in blocks.split(",")] + + slg_args = { + "blocks": slg_block_list, + "start_percent": start_percent, + "end_percent": end_percent, + } + return (slg_args,) + +#region VACE +class WanVideoVACEEncode: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "vae": ("WANVAE",), + "width": ("INT", {"default": 832, "min": 64, "max": 8096, "step": 8, "tooltip": "Width of the image to encode"}), + "height": ("INT", {"default": 480, "min": 64, "max": 8096, "step": 8, "tooltip": "Height of the image to encode"}), + "num_frames": ("INT", {"default": 81, "min": 1, "max": 10000, "step": 4, "tooltip": "Number of frames to encode"}), + "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}), + "vace_start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percent of the steps to apply VACE"}), + "vace_end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percent of the steps to apply VACE"}), + }, + "optional": { + "input_frames": ("IMAGE",), + "ref_images": ("IMAGE",), + "input_masks": ("MASK",), + "prev_vace_embeds": ("WANVIDIMAGE_EMBEDS",), + "tiled_vae": ("BOOLEAN", {"default": False, "tooltip": "Use tiled VAE encoding for reduced memory use"}), + }, + } + + RETURN_TYPES = ("WANVIDIMAGE_EMBEDS", ) + RETURN_NAMES = ("vace_embeds",) + FUNCTION = "process" + CATEGORY = "WanVideoWrapper" + + def process(self, vae, width, height, num_frames, strength, vace_start_percent, vace_end_percent, input_frames=None, ref_images=None, input_masks=None, prev_vace_embeds=None, tiled_vae=False): + vae = vae.to(device) + + width = (width // 16) * 16 + height = (height // 16) * 16 + + target_shape = (16, (num_frames - 1) // VAE_STRIDE[0] + 1, + height // VAE_STRIDE[1], + width // VAE_STRIDE[2]) + # vace context encode + if input_frames is None: + input_frames = torch.zeros((1, 3, num_frames, height, width), device=device, dtype=vae.dtype) + else: + input_frames = input_frames[:num_frames] + input_frames = common_upscale(input_frames.clone().movedim(-1, 1), width, height, "lanczos", "disabled").movedim(1, -1) + input_frames = input_frames.to(vae.dtype).to(device).unsqueeze(0).permute(0, 4, 1, 2, 3) # B, C, T, H, W + input_frames = input_frames * 2 - 1 + if input_masks is None: + input_masks = torch.ones_like(input_frames, device=device) + else: + log.info(f"input_masks shape: {input_masks.shape}") + input_masks = input_masks[:num_frames] + input_masks = common_upscale(input_masks.clone().unsqueeze(1), width, height, "nearest-exact", "disabled").squeeze(1) + 
input_masks = input_masks.to(vae.dtype).to(device) + input_masks = input_masks.unsqueeze(-1).unsqueeze(0).permute(0, 4, 1, 2, 3).repeat(1, 3, 1, 1, 1) # B, C, T, H, W + + if ref_images is not None: + # Create padded image + if ref_images.shape[0] > 1: + ref_images = torch.cat([ref_images[i] for i in range(ref_images.shape[0])], dim=1).unsqueeze(0) + + B, H, W, C = ref_images.shape + current_aspect = W / H + target_aspect = width / height + if current_aspect > target_aspect: + # Image is wider than target, pad height + new_h = int(W / target_aspect) + pad_h = (new_h - H) // 2 + padded = torch.ones(ref_images.shape[0], new_h, W, ref_images.shape[3], device=ref_images.device, dtype=ref_images.dtype) + padded[:, pad_h:pad_h+H, :, :] = ref_images + ref_images = padded + elif current_aspect < target_aspect: + # Image is taller than target, pad width + new_w = int(H * target_aspect) + pad_w = (new_w - W) // 2 + padded = torch.ones(ref_images.shape[0], H, new_w, ref_images.shape[3], device=ref_images.device, dtype=ref_images.dtype) + padded[:, :, pad_w:pad_w+W, :] = ref_images + ref_images = padded + ref_images = common_upscale(ref_images.movedim(-1, 1), width, height, "lanczos", "center").movedim(1, -1) + + ref_images = ref_images.to(vae.dtype).to(device).unsqueeze(0).permute(0, 4, 1, 2, 3).unsqueeze(0) + ref_images = ref_images * 2 - 1 + + z0 = self.vace_encode_frames(vae, input_frames, ref_images, masks=input_masks, tiled_vae=tiled_vae) + vae.model.clear_cache() + m0 = self.vace_encode_masks(input_masks, ref_images) + z = self.vace_latent(z0, m0) + + vae.to(offload_device) + + vace_input = { + "vace_context": z, + "vace_scale": strength, + "has_ref": ref_images is not None, + "num_frames": num_frames, + "target_shape": target_shape, + "vace_start_percent": vace_start_percent, + "vace_end_percent": vace_end_percent, + "vace_seq_len": math.ceil((z[0].shape[2] * z[0].shape[3]) / 4 * z[0].shape[1]), + "additional_vace_inputs": [], + } + + if prev_vace_embeds is not None: + if "additional_vace_inputs" in prev_vace_embeds and prev_vace_embeds["additional_vace_inputs"]: + vace_input["additional_vace_inputs"] = prev_vace_embeds["additional_vace_inputs"].copy() + vace_input["additional_vace_inputs"].append(prev_vace_embeds) + + return (vace_input,) + def vace_encode_frames(self, vae, frames, ref_images, masks=None, tiled_vae=False): + if ref_images is None: + ref_images = [None] * len(frames) + else: + assert len(frames) == len(ref_images) + + if masks is None: + latents = vae.encode(frames, device=device, tiled=tiled_vae) + else: + inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)] + reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)] + inactive = vae.encode(inactive, device=device, tiled=tiled_vae) + reactive = vae.encode(reactive, device=device, tiled=tiled_vae) + latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)] + vae.model.clear_cache() + cat_latents = [] + + pbar = ProgressBar(len(frames)) + for latent, refs in zip(latents, ref_images): + if refs is not None: + if masks is None: + ref_latent = vae.encode(refs, device=device, tiled=tiled_vae) + else: + ref_latent = vae.encode(refs, device=device, tiled=tiled_vae) + ref_latent = [torch.cat((u, torch.zeros_like(u)), dim=0) for u in ref_latent] + assert all([x.shape[1] == 1 for x in ref_latent]) + latent = torch.cat([*ref_latent, latent], dim=1) + cat_latents.append(latent) + pbar.update(1) + return cat_latents + + def vace_encode_masks(self, masks, ref_images=None): + if ref_images is None: + ref_images 
= [None] * len(masks) + else: + assert len(masks) == len(ref_images) + + result_masks = [] + pbar = ProgressBar(len(masks)) + for mask, refs in zip(masks, ref_images): + _c, depth, height, width = mask.shape + new_depth = int((depth + 3) // VAE_STRIDE[0]) + height = 2 * (int(height) // (VAE_STRIDE[1] * 2)) + width = 2 * (int(width) // (VAE_STRIDE[2] * 2)) + + # reshape + mask = mask[0, :, :, :] + mask = mask.view( + depth, height, VAE_STRIDE[1], width, VAE_STRIDE[1] + ) # depth, height, 8, width, 8 + mask = mask.permute(2, 4, 0, 1, 3) # 8, 8, depth, height, width + mask = mask.reshape( + VAE_STRIDE[1] * VAE_STRIDE[2], depth, height, width + ) # 8*8, depth, height, width + + # interpolation + mask = F.interpolate(mask.unsqueeze(0), size=(new_depth, height, width), mode='nearest-exact').squeeze(0) + + if refs is not None: + length = len(refs) + mask_pad = torch.zeros_like(mask[:, :length, :, :]) + mask = torch.cat((mask_pad, mask), dim=1) + result_masks.append(mask) + pbar.update(1) + return result_masks + + def vace_latent(self, z, m): + return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)] + +#region context options +class WanVideoContextOptions: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "context_schedule": (["uniform_standard", "uniform_looped", "static_standard"],), + "context_frames": ("INT", {"default": 81, "min": 2, "max": 1000, "step": 1, "tooltip": "Number of pixel frames in the context, NOTE: the latent space has 4 frames in 1"} ), + "context_stride": ("INT", {"default": 4, "min": 4, "max": 100, "step": 1, "tooltip": "Context stride as pixel frames, NOTE: the latent space has 4 frames in 1"} ), + "context_overlap": ("INT", {"default": 16, "min": 4, "max": 100, "step": 1, "tooltip": "Context overlap as pixel frames, NOTE: the latent space has 4 frames in 1"} ), + "freenoise": ("BOOLEAN", {"default": True, "tooltip": "Shuffle the noise"}), + "verbose": ("BOOLEAN", {"default": False, "tooltip": "Print debug output"}), + }, + "optional": { + "fuse_method": (["linear", "pyramid"], {"default": "linear", "tooltip": "Window weight function: linear=ramps at edges only, pyramid=triangular weights peaking in middle"}), + "reference_latent": ("LATENT", {"tooltip": "Image to be used as init for I2V models for windows where first frame is not the actual first frame. Mostly useful with MAGREF model"}), + } + } + + RETURN_TYPES = ("WANVIDCONTEXT", ) + RETURN_NAMES = ("context_options",) + FUNCTION = "process" + CATEGORY = "WanVideoWrapper" + DESCRIPTION = "Context options for WanVideo, allows splitting the video into context windows and attemps blending them for longer generations than the model and memory otherwise would allow." 
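+    # Note: context_frames, context_stride and context_overlap are given in pixel frames; the latent space packs 4 pixel frames into one latent frame (see the tooltips above).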
+ + def process(self, context_schedule, context_frames, context_stride, context_overlap, freenoise, verbose, image_cond_start_step=6, image_cond_window_count=2, vae=None, fuse_method="linear", reference_latent=None): + context_options = { + "context_schedule":context_schedule, + "context_frames":context_frames, + "context_stride":context_stride, + "context_overlap":context_overlap, + "freenoise":freenoise, + "verbose":verbose, + "fuse_method":fuse_method, + "reference_latent":reference_latent["samples"][0] if reference_latent is not None else None, + } + + return (context_options,) + +class WanVideoFlowEdit: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "source_embeds": ("WANVIDEOTEXTEMBEDS", ), + "skip_steps": ("INT", {"default": 4, "min": 0}), + "drift_steps": ("INT", {"default": 0, "min": 0}), + "drift_flow_shift": ("FLOAT", {"default": 3.0, "min": 1.0, "max": 30.0, "step": 0.01}), + "source_cfg": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 30.0, "step": 0.01}), + "drift_cfg": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 30.0, "step": 0.01}), + }, + "optional": { + "source_image_embeds": ("WANVIDIMAGE_EMBEDS", ), + } + } + + RETURN_TYPES = ("FLOWEDITARGS", ) + RETURN_NAMES = ("flowedit_args",) + FUNCTION = "process" + CATEGORY = "WanVideoWrapper" + DESCRIPTION = "Flowedit options for WanVideo" + + def process(self, **kwargs): + return (kwargs,) + +class WanVideoLoopArgs: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "shift_skip": ("INT", {"default": 6, "min": 0, "tooltip": "Skip step of latent shift"}), + "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percent of the looping effect"}), + "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percent of the looping effect"}), + }, + } + + RETURN_TYPES = ("LOOPARGS", ) + RETURN_NAMES = ("loop_args",) + FUNCTION = "process" + CATEGORY = "WanVideoWrapper" + DESCRIPTION = "Looping through latent shift as shown in https://github.com/YisuiTT/Mobius/" + + def process(self, **kwargs): + return (kwargs,) + +class WanVideoExperimentalArgs: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "video_attention_split_steps": ("STRING", {"default": "", "tooltip": "Steps to split self attention when using multiple prompts"}), + "cfg_zero_star": ("BOOLEAN", {"default": False, "tooltip": "https://github.com/WeichenFan/CFG-Zero-star"}), + "use_zero_init": ("BOOLEAN", {"default": False}), + "zero_star_steps": ("INT", {"default": 0, "min": 0, "tooltip": "Number of first steps where the prediction is zeroed out when zero init is enabled (CFG-Zero*)"}), + "use_fresca": ("BOOLEAN", {"default": False, "tooltip": "https://github.com/WikiChao/FreSca"}), + "fresca_scale_low": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "fresca_scale_high": ("FLOAT", {"default": 1.25, "min": 0.0, "max": 10.0, "step": 0.01}), + "fresca_freq_cutoff": ("INT", {"default": 20, "min": 0, "max": 10000, "step": 1}), + "use_tcfg": ("BOOLEAN", {"default": False, "tooltip": "https://arxiv.org/abs/2503.18137 TCFG: Tangential Damping Classifier-free Guidance. 
CFG artifacts reduction."}), + "raag_alpha": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "Alpha value for RAAG, 1.0 is default, 0.0 is disabled."}), + }, + } + + RETURN_TYPES = ("EXPERIMENTALARGS", ) + RETURN_NAMES = ("exp_args",) + FUNCTION = "process" + CATEGORY = "WanVideoWrapper" + DESCRIPTION = "Experimental stuff" + EXPERIMENTAL = True + + def process(self, **kwargs): + return (kwargs,) + +class WanVideoFreeInitArgs: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "freeinit_num_iters": ("INT", {"default": 3, "min": 1, "max": 10, "tooltip": "Number of FreeInit iterations"}), + "freeinit_method": (["butterworth", "ideal", "gaussian", "none"], {"default": "ideal", "tooltip": "Frequency filter type"}), + "freeinit_n": ("INT", {"default": 4, "min": 1, "max": 10, "tooltip": "Butterworth filter order (only for butterworth)"}), + "freeinit_d_s": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "Spatial filter cutoff"}), + "freeinit_d_t": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "Temporal filter cutoff"}), + }, + } + + RETURN_TYPES = ("FREEINITARGS", ) + RETURN_NAMES = ("freeinit_args",) + FUNCTION = "process" + CATEGORY = "WanVideoWrapper" + DESCRIPTION = "https://github.com/TianxingWu/FreeInit; FreeInit, a concise yet effective method to improve temporal consistency of videos generated by diffusion models" + EXPERIMENTAL = True + + def process(self, **kwargs): + return (kwargs,) + +#region Sampler +class WanVideoSampler: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": ("WANVIDEOMODEL",), + "image_embeds": ("WANVIDIMAGE_EMBEDS", ), + "steps": ("INT", {"default": 30, "min": 1}), + "cfg": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 30.0, "step": 0.01}), + "shift": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 1000.0, "step": 0.01}), + "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), + "force_offload": ("BOOLEAN", {"default": True, "tooltip": "Moves the model to the offload device after sampling"}), + "scheduler": (scheduler_list, {"default": "uni_pc",}), + "riflex_freq_index": ("INT", {"default": 0, "min": 0, "max": 1000, "step": 1, "tooltip": "Frequency index for RIFLEX, disabled when 0, default 6. Allows for new frames to be generated after without looping"}), + # Customized for framepack + "total_second_length": ("FLOAT", {"default": 5.0, "min": 1.0, "max": 120.0, "step": 0.1}), + "latent_window_size": ("INT", {"default": 8, "min": 1, "max": 50, "step": 1}), + }, + "optional": { + "text_embeds": ("WANVIDEOTEXTEMBEDS", ), + "samples": ("LATENT", {"tooltip": "init Latents to use for video2video process"} ), + "denoise_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + "feta_args": ("FETAARGS", ), + "context_options": ("WANVIDCONTEXT", ), + "cache_args": ("CACHEARGS", ), + "flowedit_args": ("FLOWEDITARGS", ), + "batched_cfg": ("BOOLEAN", {"default": False, "tooltip": "Batch cond and uncond for faster sampling, possibly faster on some hardware, uses more memory"}), + "slg_args": ("SLGARGS", ), + "rope_function": (["default", "comfy", "comfy_chunked"], {"default": "comfy", "tooltip": "Comfy's RoPE implementation doesn't use complex numbers and can thus be compiled, that should be a lot faster when using torch.compile. 
Chunked version has reduced peak VRAM usage when not using torch.compile"}), + "loop_args": ("LOOPARGS", ), + "experimental_args": ("EXPERIMENTALARGS", ), + "sigmas": ("SIGMAS", ), + "unianimate_poses": ("UNIANIMATE_POSE", ), + "fantasytalking_embeds": ("FANTASYTALKING_EMBEDS", ), + "uni3c_embeds": ("UNI3C_EMBEDS", ), + "multitalk_embeds": ("MULTITALK_EMBEDS", ), + "freeinit_args": ("FREEINITARGS", ), + "start_step": ("INT", {"default": 0, "min": 0, "max": 10000, "step": 1, "tooltip": "Start step for the sampling, 0 means full sampling, otherwise samples only from this step"}), + "end_step": ("INT", {"default": -1, "min": -1, "max": 10000, "step": 1, "tooltip": "End step for the sampling, -1 means full sampling, otherwise samples only until this step"}), + "add_noise_to_samples": ("BOOLEAN", {"default": False, "tooltip": "Add noise to the samples before sampling, needed for video2video sampling when starting from clean video"}), + } + } + + RETURN_TYPES = ("LATENT", "LATENT",) + RETURN_NAMES = ("samples", "denoised_samples",) + FUNCTION = "process" + CATEGORY = "WanVideoWrapper" + #Region model pred for a single section + def _predict_with_cfg(self, z, cfg_scale, positive_embeds, negative_embeds, timestep, idx, **kwargs): + + nonlocal transformer + z_pos = kwargs.get('z_pos', z) + z_neg = kwargs.get('z_neg', z) + z = z.to(dtype) + with torch.autocast(device_type=mm.get_autocast_device(device), dtype=dtype, enabled=("fp8" in model["quantization"])): + if use_cfg_zero_star and (idx <= zero_star_steps) and use_zero_init: + return z*0, None + nonlocal patcher + current_step_percentage = idx / len(timesteps) + control_lora_enabled = False + image_cond_input = None + if control_latents is not None: + if control_lora: + control_lora_enabled = True + else: + if (control_start_percent <= current_step_percentage <= control_end_percent) or \ + (control_end_percent > 0 and idx == 0 and current_step_percentage >= control_start_percent): + image_cond_input = torch.cat([control_latents.to(z), image_cond.to(z)]) + else: + image_cond_input = torch.cat([torch.zeros_like(control_latents, dtype=dtype), image_cond.to(z)]) + if fun_ref_image is not None: + fun_ref_input = fun_ref_image.to(z) + else: + fun_ref_input = torch.zeros_like(z, dtype=z.dtype)[:, 0].unsqueeze(1) + #fun_ref_input = None + if control_lora: + if not control_start_percent <= current_step_percentage <= control_end_percent: + control_lora_enabled = False + if patcher.model.is_patched: + log.info("Unloading LoRA...") + patcher.unpatch_model(device) + patcher.model.is_patched = False + else: + image_cond_input = control_latents.to(z) + if not patcher.model.is_patched: + log.info("Loading LoRA...") + patcher = apply_lora(patcher, device, device, low_mem_load=False, control_lora=True) + patcher.model.is_patched = True + + elif ATI_tracks is not None and ((ati_start_percent <= current_step_percentage <= ati_end_percent) or + (ati_end_percent > 0 and idx == 0 and current_step_percentage >= ati_start_percent)): + image_cond_input = image_cond_ati.to(z) + else: + image_cond_input = image_cond.to(z) if image_cond is not None else None + if control_camera_latents is not None: + if (control_camera_start_percent <= current_step_percentage <= control_camera_end_percent) or \ + (control_end_percent > 0 and idx == 0 and current_step_percentage >= control_camera_start_percent): + control_camera_input = control_camera_latents.to(z) + else: + control_camera_input = None + if recammaster is not None: + z = torch.cat([z, recam_latents.to(z)], dim=1) + + 
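# --- Illustrative sketch, not part of this patch ---
# Several of the conditioning inputs above (control latents, ATI tracks, camera
# latents) are gated by a start/end percentage of the sampling schedule. The
# helper below restates that gating on its own; the extra idx == 0 check mirrors
# the code's intent of keeping the very first step active whenever the window
# has a non-zero end percent.
def cond_is_active(idx, total_steps, start_percent, end_percent):
    pct = idx / total_steps
    return (start_percent <= pct <= end_percent) or (
        end_percent > 0 and idx == 0 and pct >= start_percent
    )

# e.g. cond_is_active(0, 30, 0.0, 0.5) -> True, cond_is_active(20, 30, 0.0, 0.5) -> False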
use_phantom = False + if phantom_latents is not None: + if (phantom_start_percent <= current_step_percentage <= phantom_end_percent) or \ + (phantom_end_percent > 0 and idx == 0 and current_step_percentage >= phantom_start_percent): + z_pos = torch.cat([z[:,:-phantom_latents.shape[1]], phantom_latents.to(z)], dim=1) + z_phantom_img = torch.cat([z[:,:-phantom_latents.shape[1]], phantom_latents.to(z)], dim=1) + z_neg = torch.cat([z[:,:-phantom_latents.shape[1]], torch.zeros_like(phantom_latents).to(z)], dim=1) + use_phantom = True + if cache_state is not None and len(cache_state) != 3: + cache_state.append(None) + if not use_phantom: + z_pos = z_neg = z + if controlnet_latents is not None: + if (controlnet_start <= current_step_percentage < controlnet_end): + self.controlnet.to(device) + controlnet_states = self.controlnet( + hidden_states=z.unsqueeze(0).to(device, self.controlnet.dtype), + timestep=timestep, + encoder_hidden_states=positive_embeds[0].unsqueeze(0).to(device, self.controlnet.dtype), + attention_kwargs=None, + controlnet_states=controlnet_latents.to(device, self.controlnet.dtype), + return_dict=False, + )[0] + if isinstance(controlnet_states, (tuple, list)): + controlnet["controlnet_states"] = [x.to(z) for x in controlnet_states] + else: + controlnet["controlnet_states"] = controlnet_states.to(z) + add_cond_input = None + if add_cond is not None: + if (add_cond_start_percent <= current_step_percentage <= add_cond_end_percent) or \ + (add_cond_end_percent > 0 and idx == 0 and current_step_percentage >= add_cond_start_percent): + add_cond_input = add_cond + if minimax_latents is not None: + if context_window is not None: + z_pos = z_neg = torch.cat([z, minimax_latents[:, context_window], minimax_mask_latents[:, context_window]], dim=0) + else: + z_pos = z_neg = torch.cat([z, minimax_latents, minimax_mask_latents], dim=0) + + if not multitalk_sampling and multitalk_audio_embedding is not None: + audio_embedding = multitalk_audio_embedding + audio_embs = [] + indices = (torch.arange(4 + 1) - 2) * 1 + human_num = len(audio_embedding) + # split audio with window size + if context_window is None: + for human_idx in range(human_num): + center_indices = torch.arange( + 0, + latent_video_length * 4 + 1 if add_cond is not None else (latent_video_length-1) * 4 + 1, + 1).unsqueeze(1) + indices.unsqueeze(0) + center_indices = torch.clamp(center_indices, min=0, max=audio_embedding[human_idx].shape[0] - 1) + audio_emb = audio_embedding[human_idx][center_indices].unsqueeze(0).to(device) + audio_embs.append(audio_emb) + else: + for human_idx in range(human_num): + audio_start = context_window[0] * 4 + audio_end = context_window[-1] * 4 + 1 + #print("audio_start: ", audio_start, "audio_end: ", audio_end) + center_indices = torch.arange(audio_start, audio_end, 1).unsqueeze(1) + indices.unsqueeze(0) + center_indices = torch.clamp(center_indices, min=0, max=audio_embedding[human_idx].shape[0] - 1) + audio_emb = audio_embedding[human_idx][center_indices].unsqueeze(0).to(device) + audio_embs.append(audio_emb) + multitalk_audio_input = torch.concat(audio_embs, dim=0).to(dtype) + + elif multitalk_sampling and multitalk_audio_embeds is not None: + multitalk_audio_input = multitalk_audio_embeds + + if context_window is not None and pcd_data is not None and pcd_data["render_latent"].shape[2] != context_frames: + pcd_data_input = {"render_latent": pcd_data["render_latent"][:, :, context_window]} + for k in pcd_data: + if k != "render_latent": + pcd_data_input[k] = pcd_data[k] + else: + pcd_data_input = 
pcd_data + + base_params = { + 'seq_len': seq_len, + 'device': device, + 'freqs': freqs, + 't': timestep, + 'current_step': idx, + 'last_step': len(timesteps) - 1 == idx, + 'control_lora_enabled': control_lora_enabled, + 'enhance_enabled': enhance_enabled, + 'camera_embed': camera_embed, + 'unianim_data': unianim_data, + 'fun_ref': fun_ref_input if fun_ref_image is not None else None, + 'fun_camera': control_camera_input if control_camera_latents is not None else None, + 'audio_proj': audio_proj if fantasytalking_embeds is not None else None, + 'audio_scale': audio_scale, + "pcd_data": pcd_data_input, + "controlnet": controlnet, + "add_cond": add_cond_input, + "nag_params": text_embeds.get("nag_params", {}), + "nag_context": text_embeds.get("nag_prompt_embeds", None), + "multitalk_audio": multitalk_audio_input if multitalk_audio_embedding is not None else None, + "ref_target_masks": ref_target_masks if multitalk_audio_embedding is not None else None, + "inner_t": [shot_len] if shot_len else None, + } + batch_size = 1 + if not math.isclose(cfg_scale, 1.0): + if negative_embeds is None: + raise ValueError("Negative embeddings must be provided for CFG scale > 1.0") + if len(positive_embeds) > 1: + negative_embeds = negative_embeds * len(positive_embeds) + try: + if not batched_cfg: + #cond + noise_pred_cond, cache_state_cond = transformer( + [z_pos], context=positive_embeds, y=[image_cond_input] if image_cond_input is not None else None, + clip_fea=clip_fea, is_uncond=False, current_step_percentage=current_step_percentage, + pred_id=cache_state[0] if cache_state else None, + vace_data=vace_data, attn_cond=attn_cond, + **base_params + ) + noise_pred_cond = noise_pred_cond[0].to(intermediate_device) + if math.isclose(cfg_scale, 1.0): + if use_fresca: + noise_pred_cond = fourier_filter( + noise_pred_cond, + scale_low=fresca_scale_low, + scale_high=fresca_scale_high, + freq_cutoff=fresca_freq_cutoff, + ) + return noise_pred_cond, [cache_state_cond] + #uncond + if fantasytalking_embeds is not None: + if not math.isclose(audio_cfg_scale[idx], 1.0): + base_params['audio_proj'] = None + noise_pred_uncond, cache_state_uncond = transformer( + [z_neg], context=negative_embeds, clip_fea=clip_fea_neg if clip_fea_neg is not None else clip_fea, + y=[image_cond_input] if image_cond_input is not None else None, + is_uncond=True, current_step_percentage=current_step_percentage, + pred_id=cache_state[1] if cache_state else None, + vace_data=vace_data, attn_cond=attn_cond_neg, + **base_params + ) + noise_pred_uncond = noise_pred_uncond[0].to(intermediate_device) + #phantom + if use_phantom and not math.isclose(phantom_cfg_scale[idx], 1.0): + noise_pred_phantom, cache_state_phantom = transformer( + [z_phantom_img], context=negative_embeds, clip_fea=clip_fea_neg if clip_fea_neg is not None else clip_fea, + y=[image_cond_input] if image_cond_input is not None else None, + is_uncond=True, current_step_percentage=current_step_percentage, + pred_id=cache_state[2] if cache_state else None, + vace_data=None, + **base_params + ) + noise_pred_phantom = noise_pred_phantom[0].to(intermediate_device) + + noise_pred = noise_pred_uncond + phantom_cfg_scale[idx] * (noise_pred_phantom - noise_pred_uncond) + cfg_scale * (noise_pred_cond - noise_pred_phantom) + return noise_pred, [cache_state_cond, cache_state_uncond, cache_state_phantom] + #fantasytalking + if fantasytalking_embeds is not None: + if not math.isclose(audio_cfg_scale[idx], 1.0): + if cache_state is not None and len(cache_state) != 3: + cache_state.append(None) + 
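# --- Illustrative sketch, not part of this patch ---
# The phantom branch above and the FantasyTalking branch below combine three
# predictions so that each guidance signal gets its own scale, generalizing
# plain two-term CFG. These helpers restate that arithmetic in isolation.
def cfg_combine(cond, uncond, scale):
    # Standard classifier-free guidance.
    return uncond + scale * (cond - uncond)

def three_way_cfg(cond, mid, uncond, mid_scale, cond_scale):
    # Matches e.g. uncond + cfg*(no_audio - uncond) + audio_cfg*(cond - no_audio),
    # with mid = the prediction made without the extra conditioning signal.
    return uncond + mid_scale * (mid - uncond) + cond_scale * (cond - mid)

# e.g. with scalars: three_way_cfg(1.0, 0.8, 0.0, mid_scale=6.0, cond_scale=1.0) -> 5.0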
base_params['audio_proj'] = None + noise_pred_no_audio, cache_state_audio = transformer( + [z_pos], context=positive_embeds, y=[image_cond_input] if image_cond_input is not None else None, + clip_fea=clip_fea, is_uncond=False, current_step_percentage=current_step_percentage, + pred_id=cache_state[2] if cache_state else None, + vace_data=vace_data, + **base_params + ) + noise_pred_no_audio = noise_pred_no_audio[0].to(intermediate_device) + noise_pred = ( + noise_pred_uncond + + cfg_scale * (noise_pred_no_audio - noise_pred_uncond) + + audio_cfg_scale[idx] * (noise_pred_cond - noise_pred_no_audio) + ) + return noise_pred, [cache_state_cond, cache_state_uncond, cache_state_audio] + elif multitalk_audio_embedding is not None: + if not math.isclose(audio_cfg_scale[idx], 1.0): + if cache_state is not None and len(cache_state) != 3: + cache_state.append(None) + base_params['multitalk_audio'] = torch.zeros_like(multitalk_audio_input)[-1:] + noise_pred_no_audio, cache_state_audio = transformer( + [z_pos], context=negative_embeds, y=[image_cond_input] if image_cond_input is not None else None, + clip_fea=clip_fea, is_uncond=False, current_step_percentage=current_step_percentage, + pred_id=cache_state[2] if cache_state else None, + vace_data=vace_data, + **base_params + ) + noise_pred_no_audio = noise_pred_no_audio[0].to(intermediate_device) + noise_pred = ( + noise_pred_no_audio + + cfg_scale * (noise_pred_cond - noise_pred_uncond) + + audio_cfg_scale[idx] * (noise_pred_uncond - noise_pred_no_audio) + ) + return noise_pred, [cache_state_cond, cache_state_uncond, cache_state_audio] + #batched + else: + cache_state_uncond = None + [noise_pred_cond, noise_pred_uncond], cache_state_cond = transformer( + [z] + [z], context=positive_embeds + negative_embeds, + y=[image_cond_input] + [image_cond_input] if image_cond_input is not None else None, + clip_fea=clip_fea.repeat(2,1,1), is_uncond=False, current_step_percentage=current_step_percentage, + pred_id=cache_state[0] if cache_state else None, + **base_params + ) + except Exception as e: + log.error(f"Error during model prediction: {e}") + if force_offload: + if model["manual_offloading"]: + offload_transformer(transformer) + raise e + #https://github.com/WeichenFan/CFG-Zero-star/ + if use_cfg_zero_star: + alpha = optimized_scale( + noise_pred_cond.view(batch_size, -1), + noise_pred_uncond.view(batch_size, -1) + ).view(batch_size, 1, 1, 1) + else: + alpha = 1.0 + + noise_pred_uncond_scaled = noise_pred_uncond * alpha + if use_tangential: + noise_pred_uncond_scaled = tangential_projection(noise_pred_cond, noise_pred_uncond_scaled) + + # RAAG (RATIO-aware Adaptive Guidance) + if raag_alpha > 0.0: + cfg_scale = get_raag_guidance(noise_pred_cond, noise_pred_uncond_scaled, cfg_scale, raag_alpha) + log.info(f"RAAG modified cfg: {cfg_scale}") + + #https://github.com/WikiChao/FreSca + if use_fresca: + filtered_cond = fourier_filter( + noise_pred_cond - noise_pred_uncond, + scale_low=fresca_scale_low, + scale_high=fresca_scale_high, + freq_cutoff=fresca_freq_cutoff, + ) + noise_pred = noise_pred_uncond_scaled + cfg_scale * filtered_cond * alpha + else: + noise_pred = noise_pred_uncond_scaled + cfg_scale * (noise_pred_cond - noise_pred_uncond_scaled) + + return noise_pred, [cache_state_cond, cache_state_uncond] + + def process(self, model, image_embeds, shift, steps, cfg, seed, scheduler, riflex_freq_index, text_embeds=None, + force_offload=True, samples=None, feta_args=None, denoise_strength=1.0, context_options=None, + cache_args=None, teacache_args=None, 
flowedit_args=None, batched_cfg=False, slg_args=None, rope_function="default", loop_args=None, + experimental_args=None, sigmas=None, unianimate_poses=None, fantasytalking_embeds=None, uni3c_embeds=None, multitalk_embeds=None, + freeinit_args=None, start_step=0, end_step=-1, add_noise_to_samples=False, total_second_length=5.0, latent_window_size=8): + + patcher = model + model = model.model + transformer = model.diffusion_model + + dtype = model["dtype"] + fp8_matmul = model["fp8_matmul"] + gguf = model["gguf"] + control_lora = model["control_lora"] + + transformer_options = patcher.model_options.get("transformer_options", None) + merge_loras = transformer_options["merge_loras"] + + is_5b = transformer.out_dim == 48 + vae_upscale_factor = 16 if is_5b else 8 + + patch_linear = transformer_options.get("patch_linear", False) + + if gguf: + set_lora_params_gguf(transformer, patcher.patches) + elif len(patcher.patches) != 0 and patch_linear: + log.info(f"Using {len(patcher.patches)} LoRA weight patches for WanVideo model") + if not merge_loras and fp8_matmul: + raise NotImplementedError("FP8 matmul with unmerged LoRAs is not supported") + set_lora_params(transformer, patcher.patches) + else: + remove_lora_from_module(transformer) + + transformer.lora_scheduling_enabled = transformer_options.get("lora_scheduling_enabled", False) + + #torch.compile + if model["auto_cpu_offload"] is False: + transformer = compile_model(transformer, model["compile_args"]) + + multitalk_sampling = image_embeds.get("multitalk_sampling", False) + if not multitalk_sampling and scheduler == "multitalk": + raise Exception("multitalk scheduler is only for multitalk sampling when using ImagetoVideoMultiTalk -node") + + if text_embeds == None: + text_embeds = { + "prompt_embeds": [], + "negative_prompt_embeds": [], + } + else: + text_embeds = dict_to_device(text_embeds, device) + + seed_g = torch.Generator(device=torch.device("cpu")) + seed_g.manual_seed(seed) + + #region Scheduler + if scheduler != "multitalk": + sample_scheduler, timesteps = get_scheduler(scheduler, steps, shift, device, transformer.dim, flowedit_args, denoise_strength, sigmas=sigmas) + else: + timesteps = torch.tensor([1000, 750, 500, 250], device=device) + log.info(f"sigmas: {sample_scheduler.sigmas}") + + steps = len(timesteps) + + if end_step != -1 and start_step >= end_step: + raise ValueError("start_step must be less than end_step") + + if denoise_strength < 1.0: + if start_step != 0: + raise ValueError("start_step must be 0 when denoise_strength is used") + start_step = steps - int(steps * denoise_strength) - 1 + add_noise_to_samples = True #for now to not break old workflows + + first_sampler = (end_step != -1 or end_step >= steps) + + if isinstance(cfg, list): + if steps != len(cfg): + log.info(f"Received {len(cfg)} cfg values, but only {steps} steps. 
Setting step count to match.") + steps = len(cfg) + else: + cfg = [cfg] * (steps + 1) + + if first_sampler: + timesteps = timesteps[:end_step] + sample_scheduler.sigmas = sample_scheduler.sigmas[:end_step+1] + log.info(f"Sampling until step {end_step}, timestep: {timesteps[-1]}") + if start_step > 0: + timesteps = timesteps[start_step:] + sample_scheduler.sigmas = sample_scheduler.sigmas[start_step:] + log.info(f"Skipping first {start_step} steps, starting from timestep {timesteps[0]}") + + log.info(f"timesteps: {timesteps}") + + if hasattr(sample_scheduler, 'timesteps'): + sample_scheduler.timesteps = timesteps + + scheduler_step_args = {"generator": seed_g} + step_sig = inspect.signature(sample_scheduler.step) + for arg in list(scheduler_step_args.keys()): + if arg not in step_sig.parameters: + scheduler_step_args.pop(arg) + + control_latents = control_camera_latents = clip_fea = clip_fea_neg = end_image = recammaster = camera_embed = unianim_data = None + vace_data = vace_context = vace_scale = None + fun_or_fl2v_model = has_ref = drop_last = False + phantom_latents = fun_ref_image = ATI_tracks = None + add_cond = attn_cond = attn_cond_neg = None + + #I2V + image_cond = image_embeds.get("image_embeds", None) + if image_cond is not None: + if transformer.in_dim == 16: + raise ValueError("T2V (text to video) model detected, encoded images only work with I2V (Image to video) models") + log.info(f"image_cond shape: {image_cond.shape}") + #ATI tracks + if transformer_options is not None: + ATI_tracks = transformer_options.get("ati_tracks", None) + if ATI_tracks is not None: + from .ATI.motion_patch import patch_motion + topk = transformer_options.get("ati_topk", 2) + temperature = transformer_options.get("ati_temperature", 220.0) + ati_start_percent = transformer_options.get("ati_start_percent", 0.0) + ati_end_percent = transformer_options.get("ati_end_percent", 1.0) + image_cond_ati = patch_motion(ATI_tracks.to(image_cond.device, image_cond.dtype), image_cond, topk=topk, temperature=temperature) + log.info(f"ATI tracks shape: {ATI_tracks.shape}") + + add_cond_latents = image_embeds.get("add_cond_latents", None) + if add_cond_latents is not None: + add_cond = add_cond_latents["pose_latent"] + attn_cond = add_cond_latents["ref_latent"] + attn_cond_neg = add_cond_latents["ref_latent_neg"] + add_cond_start_percent = add_cond_latents["pose_cond_start_percent"] + add_cond_end_percent = add_cond_latents["pose_cond_end_percent"] + + end_image = image_embeds.get("end_image", None) + fun_or_fl2v_model = image_embeds.get("fun_or_fl2v_model", False) + + noise = torch.randn( #C, T, H, W + 48 if is_5b else 16, + (image_embeds["num_frames"] - 1) // 4 + (2 if end_image is not None and not fun_or_fl2v_model else 1), + image_embeds["lat_h"], + image_embeds["lat_w"], + dtype=torch.float32, + generator=seed_g, + device=torch.device("cpu")) + seq_len = image_embeds["max_seq_len"] + + clip_fea = image_embeds.get("clip_context", None) + if clip_fea is not None: + clip_fea = clip_fea.to(dtype) + clip_fea_neg = image_embeds.get("negative_clip_context", None) + if clip_fea_neg is not None: + clip_fea_neg = clip_fea_neg.to(dtype) + + control_embeds = image_embeds.get("control_embeds", None) + if control_embeds is not None: + if transformer.in_dim not in [52, 48, 32]: + raise ValueError("Control signal only works with Fun-Control model") + if transformer.in_dim == 52: #fun 2.2 control + image_cond_mask = image_embeds.get("mask", None) + if image_cond_mask is not None: + image_cond = torch.cat([image_cond_mask, 
image_cond]) + control_latents = control_embeds.get("control_images", None) + control_camera_latents = control_embeds.get("control_camera_latents", None) + control_camera_start_percent = control_embeds.get("control_camera_start_percent", 0.0) + control_camera_end_percent = control_embeds.get("control_camera_end_percent", 1.0) + control_start_percent = control_embeds.get("start_percent", 0.0) + control_end_percent = control_embeds.get("end_percent", 1.0) + drop_last = image_embeds.get("drop_last", False) + has_ref = image_embeds.get("has_ref", False) + else: #t2v + target_shape = image_embeds.get("target_shape", None) + if target_shape is None: + raise ValueError("Empty image embeds must be provided for T2V models") + + has_ref = image_embeds.get("has_ref", False) + vace_context = image_embeds.get("vace_context", None) + vace_scale = image_embeds.get("vace_scale", None) + if not isinstance(vace_scale, list): + vace_scale = [vace_scale] * (steps+1) + vace_start_percent = image_embeds.get("vace_start_percent", 0.0) + vace_end_percent = image_embeds.get("vace_end_percent", 1.0) + vace_seqlen = image_embeds.get("vace_seq_len", None) + + vace_additional_embeds = image_embeds.get("additional_vace_inputs", []) + if vace_context is not None: + vace_data = [ + {"context": vace_context, + "scale": vace_scale, + "start": vace_start_percent, + "end": vace_end_percent, + "seq_len": vace_seqlen + } + ] + if len(vace_additional_embeds) > 0: + for i in range(len(vace_additional_embeds)): + if vace_additional_embeds[i].get("has_ref", False): + has_ref = True + vace_scale = vace_additional_embeds[i]["vace_scale"] + if not isinstance(vace_scale, list): + vace_scale = [vace_scale] * (steps+1) + vace_data.append({ + "context": vace_additional_embeds[i]["vace_context"], + "scale": vace_scale, + "start": vace_additional_embeds[i]["vace_start_percent"], + "end": vace_additional_embeds[i]["vace_end_percent"], + "seq_len": vace_additional_embeds[i]["vace_seq_len"] + }) + + noise = torch.randn( + 48 if is_5b else 16, + target_shape[1] + 1 if has_ref else target_shape[1], + target_shape[2] // 2 if is_5b else target_shape[2], #todo make this smarter + target_shape[3] // 2 if is_5b else target_shape[3], #todo make this smarter + dtype=torch.float32, + device=torch.device("cpu"), + generator=seed_g) + + seq_len = math.ceil((noise.shape[2] * noise.shape[3]) / 4 * noise.shape[1]) + + recammaster = image_embeds.get("recammaster", None) + if recammaster is not None: + camera_embed = recammaster.get("camera_embed", None) + recam_latents = recammaster.get("source_latents", None) + orig_noise_len = noise.shape[1] + log.info(f"RecamMaster camera embed shape: {camera_embed.shape}") + log.info(f"RecamMaster source video shape: {recam_latents.shape}") + seq_len *= 2 + + control_embeds = image_embeds.get("control_embeds", None) + if control_embeds is not None: + control_latents = control_embeds.get("control_images", None) + if control_latents is not None: + control_latents = control_latents.to(device) + control_camera_latents = control_embeds.get("control_camera_latents", None) + control_camera_start_percent = control_embeds.get("control_camera_start_percent", 0.0) + control_camera_end_percent = control_embeds.get("control_camera_end_percent", 1.0) + if control_camera_latents is not None: + control_camera_latents = control_camera_latents.to(device) + + if control_lora: + image_cond = control_latents.to(device) + if not patcher.model.is_patched: + log.info("Re-loading control LoRA...") + patcher = apply_lora(patcher, device, device, 
low_mem_load=False, control_lora=True) + patcher.model.is_patched = True + else: + if transformer.in_dim not in [48, 32, 52]: + raise ValueError("Control signal only works with Fun-Control model") + image_cond = torch.zeros_like(noise).to(device) #fun control + if transformer.in_dim == 52: #fun 2.2 control + mask_latents = torch.tile( + torch.zeros_like(noise[:1]), [4, 1, 1, 1] + ) + masked_video_latents_input = torch.zeros_like(noise) + image_cond = torch.cat([mask_latents, masked_video_latents_input], dim=0).to(device) + clip_fea = None + fun_ref_image = control_embeds.get("fun_ref_image", None) + control_start_percent = control_embeds.get("start_percent", 0.0) + control_end_percent = control_embeds.get("end_percent", 1.0) + else: + if transformer.in_dim == 36: #fun inp + mask_latents = torch.tile( + torch.zeros_like(noise[:1]), [4, 1, 1, 1] + ) + masked_video_latents_input = torch.zeros_like(noise) + image_cond = torch.cat([mask_latents, masked_video_latents_input], dim=0).to(device) + + phantom_latents = image_embeds.get("phantom_latents", None) + phantom_cfg_scale = image_embeds.get("phantom_cfg_scale", None) + if not isinstance(phantom_cfg_scale, list): + phantom_cfg_scale = [phantom_cfg_scale] * (steps +1) + phantom_start_percent = image_embeds.get("phantom_start_percent", 0.0) + phantom_end_percent = image_embeds.get("phantom_end_percent", 1.0) + if phantom_latents is not None: + phantom_latents = phantom_latents.to(device) + + latent_video_length = noise.shape[1] + + # Initialize FreeInit filter if enabled + freq_filter = None + if freeinit_args is not None: + from .freeinit.freeinit_utils import get_freq_filter, freq_mix_3d + filter_shape = list(noise.shape) # [batch, C, T, H, W] + freq_filter = get_freq_filter( + filter_shape, + device=device, + filter_type=freeinit_args.get("freeinit_method", "butterworth"), + n=freeinit_args.get("freeinit_n", 4) if freeinit_args.get("freeinit_method", "butterworth") == "butterworth" else None, + d_s=freeinit_args.get("freeinit_s", 1.0), + d_t=freeinit_args.get("freeinit_t", 1.0) + ) + if samples is not None: + saved_generator_state = samples.get("generator_state", None) + if saved_generator_state is not None: + seed_g.set_state(saved_generator_state) + + # UniAnimate + if unianimate_poses is not None: + transformer.dwpose_embedding.to(device, model["dtype"]) + dwpose_data = unianimate_poses["pose"].to(device, model["dtype"]) + dwpose_data = torch.cat([dwpose_data[:,:,:1].repeat(1,1,3,1,1), dwpose_data], dim=2) + dwpose_data = transformer.dwpose_embedding(dwpose_data) + log.info(f"UniAnimate pose embed shape: {dwpose_data.shape}") + if dwpose_data.shape[2] > latent_video_length: + log.warning(f"UniAnimate pose embed length {dwpose_data.shape[2]} is longer than the video length {latent_video_length}, truncating") + dwpose_data = dwpose_data[:,:, :latent_video_length] + elif dwpose_data.shape[2] < latent_video_length: + log.warning(f"UniAnimate pose embed length {dwpose_data.shape[2]} is shorter than the video length {latent_video_length}, padding with last pose") + pad_len = latent_video_length - dwpose_data.shape[2] + pad = dwpose_data[:,:,:1].repeat(1,1,pad_len,1,1) + dwpose_data = torch.cat([dwpose_data, pad], dim=2) + dwpose_data_flat = rearrange(dwpose_data, 'b c f h w -> b (f h w) c').contiguous() + + random_ref_dwpose_data = None + if image_cond is not None: + transformer.randomref_embedding_pose.to(device) + random_ref_dwpose = unianimate_poses.get("ref", None) + if random_ref_dwpose is not None: + random_ref_dwpose_data = 
transformer.randomref_embedding_pose( + random_ref_dwpose.to(device) + ).unsqueeze(2).to(model["dtype"]) # [1, 20, 104, 60] + + unianim_data = { + "dwpose": dwpose_data_flat, + "random_ref": random_ref_dwpose_data.squeeze(0) if random_ref_dwpose_data is not None else None, + "strength": unianimate_poses["strength"], + "start_percent": unianimate_poses["start_percent"], + "end_percent": unianimate_poses["end_percent"] + } + + # FantasyTalking + audio_proj = multitalk_audio_embedding = None + audio_scale = 1.0 + if fantasytalking_embeds is not None: + audio_proj = fantasytalking_embeds["audio_proj"].to(device) + audio_scale = fantasytalking_embeds["audio_scale"] + audio_cfg_scale = fantasytalking_embeds["audio_cfg_scale"] + if not isinstance(audio_cfg_scale, list): + audio_cfg_scale = [audio_cfg_scale] * (steps +1) + log.info(f"Audio proj shape: {audio_proj.shape}") + elif multitalk_embeds is not None: + # Handle single or multiple speaker embeddings + audio_features_in = multitalk_embeds.get("audio_features", None) + if audio_features_in is None: + multitalk_audio_embedding = None + else: + if isinstance(audio_features_in, list): + multitalk_audio_embedding = [emb.to(device, dtype) for emb in audio_features_in] + else: + # keep backward-compatibility with single tensor input + multitalk_audio_embedding = [audio_features_in.to(device, dtype)] + + audio_scale = multitalk_embeds.get("audio_scale", 1.0) + audio_cfg_scale = multitalk_embeds.get("audio_cfg_scale", 1.0) + ref_target_masks = multitalk_embeds.get("ref_target_masks", None) + if not isinstance(audio_cfg_scale, list): + audio_cfg_scale = [audio_cfg_scale] * (steps + 1) + + shapes = [tuple(e.shape) for e in multitalk_audio_embedding] + log.info(f"Multitalk audio features shapes (per speaker): {shapes}") + + # MiniMax Remover + minimax_latents = minimax_mask_latents = None + minimax_latents = image_embeds.get("minimax_latents", None) + minimax_mask_latents = image_embeds.get("minimax_mask_latents", None) + if minimax_latents is not None: + log.info(f"minimax_latents: {minimax_latents.shape}") + log.info(f"minimax_mask_latents: {minimax_mask_latents.shape}") + minimax_latents = minimax_latents.to(device, dtype) + minimax_mask_latents = minimax_mask_latents.to(device, dtype) + + # Context windows + is_looped = False + context_reference_latent = None + if context_options is not None: + context_schedule = context_options["context_schedule"] + context_frames = (context_options["context_frames"] - 1) // 4 + 1 + context_stride = context_options["context_stride"] // 4 + context_overlap = context_options["context_overlap"] // 4 + context_reference_latent = context_options.get("reference_latent", None) + + # Get total number of prompts + num_prompts = len(text_embeds["prompt_embeds"]) + log.info(f"Number of prompts: {num_prompts}") + # Calculate which section this context window belongs to + section_size = (latent_video_length / num_prompts) if num_prompts != 0 else 1 + log.info(f"Section size: {section_size}") + is_looped = context_schedule == "uniform_looped" + + seq_len = math.ceil((noise.shape[2] * noise.shape[3]) / 4 * context_frames) + + if context_options["freenoise"]: + log.info("Applying FreeNoise") + # code from AnimateDiff-Evolved by Kosinkadink (https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved) + delta = context_frames - context_overlap + for start_idx in range(0, latent_video_length-context_frames, delta): + place_idx = start_idx + context_frames + if place_idx >= latent_video_length: + break + end_idx = place_idx - 1 + + if 
end_idx + delta >= latent_video_length: + final_delta = latent_video_length - place_idx + list_idx = torch.tensor(list(range(start_idx,start_idx+final_delta)), device=torch.device("cpu"), dtype=torch.long) + list_idx = list_idx[torch.randperm(final_delta, generator=seed_g)] + noise[:, place_idx:place_idx + final_delta, :, :] = noise[:, list_idx, :, :] + break + list_idx = torch.tensor(list(range(start_idx,start_idx+delta)), device=torch.device("cpu"), dtype=torch.long) + list_idx = list_idx[torch.randperm(delta, generator=seed_g)] + noise[:, place_idx:place_idx + delta, :, :] = noise[:, list_idx, :, :] + + log.info(f"Context schedule enabled: {context_frames} frames, {context_stride} stride, {context_overlap} overlap") + from .context_windows.context import get_context_scheduler, create_window_mask, WindowTracker + self.window_tracker = WindowTracker(verbose=context_options["verbose"]) + context = get_context_scheduler(context_schedule) + + # vid2vid + if samples is not None: + saved_generator_state = samples.get("generator_state", None) + if saved_generator_state is not None: + seed_g.set_state(saved_generator_state) + input_samples = samples["samples"].squeeze(0).to(noise) + if input_samples.shape[1] != noise.shape[1]: + input_samples = torch.cat([input_samples[:, :1].repeat(1, noise.shape[1] - input_samples.shape[1], 1, 1), input_samples], dim=1) + + if add_noise_to_samples: + latent_timestep = timesteps[:1].to(noise) + noise = noise * latent_timestep / 1000 + (1 - latent_timestep / 1000) * input_samples + else: + noise = input_samples + mask = samples.get("mask", None) + if mask is not None: + original_image = input_samples.to(device) + if mask.shape[2] != noise.shape[1]: + mask = torch.cat([torch.zeros(1, noise.shape[0], noise.shape[1] - mask.shape[2], noise.shape[2], noise.shape[3]), mask], dim=2) + + # extra latents (Pusa) and 5b + latents_to_insert = add_index = None + if (extra_latents := image_embeds.get("extra_latents", None)) is not None: + all_indices = [] + for entry in extra_latents: + add_index = entry["index"] + num_extra_frames = entry["samples"].shape[2] + noise[:, add_index:add_index+num_extra_frames] = entry["samples"].to(noise) + log.info(f"Adding extra samples to latent indices {add_index} to {add_index+num_extra_frames-1}") + all_indices.extend(range(add_index, add_index+num_extra_frames)) + + + latent = noise.to(device) + + #controlnet + controlnet_latents = controlnet = None + if transformer_options is not None: + controlnet = transformer_options.get("controlnet", None) + if controlnet is not None: + self.controlnet = controlnet["controlnet"] + controlnet_start = controlnet["controlnet_start"] + controlnet_end = controlnet["controlnet_end"] + controlnet_latents = controlnet["control_latents"] + controlnet["controlnet_weight"] = controlnet["controlnet_strength"] + controlnet["controlnet_stride"] = controlnet["control_stride"] + + #uni3c + pcd_data = pcd_data_input = None + if uni3c_embeds is not None: + transformer.controlnet = uni3c_embeds["controlnet"] + pcd_data = { + "render_latent": uni3c_embeds["render_latent"].to(dtype), + "render_mask": uni3c_embeds["render_mask"], + "camera_embedding": uni3c_embeds["camera_embedding"], + "controlnet_weight": uni3c_embeds["controlnet_weight"], + "start": uni3c_embeds["start"], + "end": uni3c_embeds["end"], + } + + # Enhance-a-video (feta) + if feta_args is not None and latent_video_length > 1: + set_enhance_weight(feta_args["weight"]) + feta_start_percent = feta_args["start_percent"] + feta_end_percent = 
feta_args["end_percent"] + if context_options is not None: + set_num_frames(context_frames) + else: + set_num_frames(latent_video_length) + enhance_enabled = True + else: + feta_args = None + enhance_enabled = False + + # EchoShot https://github.com/D2I-ai/EchoShot + echoshot = False + shot_len = None + if text_embeds is not None: + echoshot = text_embeds.get("echoshot", False) + if echoshot: + shot_num = len(text_embeds["prompt_embeds"]) + shot_len = [latent_video_length//shot_num] * (shot_num-1) + shot_len.append(latent_video_length-sum(shot_len)) + rope_function = "default" #echoshot does not support comfy rope function + log.info(f"Number of shots in prompt: {shot_num}, Shot token lengths: {shot_len}") + + #region transformer settings + #rope + freqs = None + transformer.rope_embedder.k = None + transformer.rope_embedder.num_frames = None + if "comfy" in rope_function: + transformer.rope_embedder.k = riflex_freq_index + transformer.rope_embedder.num_frames = latent_video_length + else: + d = transformer.dim // transformer.num_heads + freqs = torch.cat([ + rope_params(1024, d - 4 * (d // 6), L_test=latent_video_length, k=riflex_freq_index), + rope_params(1024, 2 * (d // 6)), + rope_params(1024, 2 * (d // 6)) + ], + dim=1) + transformer.rope_func = rope_function + for block in transformer.blocks: + block.rope_func = rope_function + if transformer.vace_layers is not None: + for block in transformer.vace_blocks: + block.rope_func = rope_function + + #blockswap init + + mm.unload_all_models() + mm.soft_empty_cache() + gc.collect() + + if transformer_options is not None: + block_swap_args = transformer_options.get("block_swap_args", None) + + if block_swap_args is not None: + transformer.use_non_blocking = block_swap_args.get("use_non_blocking", False) + for name, param in transformer.named_parameters(): + if "block" not in name: + param.data = param.data.to(device) + if "control_adapter" in name: + param.data = param.data.to(device) + elif block_swap_args["offload_txt_emb"] and "txt_emb" in name: + param.data = param.data.to(offload_device) + elif block_swap_args["offload_img_emb"] and "img_emb" in name: + param.data = param.data.to(offload_device) + + transformer.block_swap( + block_swap_args["blocks_to_swap"] - 1 , + block_swap_args["offload_txt_emb"], + block_swap_args["offload_img_emb"], + vace_blocks_to_swap = block_swap_args.get("vace_blocks_to_swap", None), + prefetch_blocks = block_swap_args.get("prefetch_blocks", 0), + block_swap_debug = block_swap_args.get("block_swap_debug", False), + ) + elif model["auto_cpu_offload"]: + for module in transformer.modules(): + if hasattr(module, "offload"): + module.offload() + if hasattr(module, "onload"): + module.onload() + for block in transformer.blocks: + block.modulation = torch.nn.Parameter(block.modulation.to(device)) + transformer.head.modulation = torch.nn.Parameter(transformer.head.modulation.to(device)) + + elif model["manual_offloading"]: + transformer.to(device) + + # Initialize Cache if enabled + transformer.enable_teacache = transformer.enable_magcache = transformer.enable_easycache = False + cache_args = teacache_args if teacache_args is not None else cache_args #for backward compatibility on old workflows + if cache_args is not None: + from .cache_methods.cache_methods import set_transformer_cache_method + transformer = set_transformer_cache_method(transformer, timesteps, cache_args) + + # Initialize cache state + self.cache_state = [None, None] + if phantom_latents is not None: + log.info(f"Phantom latents shape: 
{phantom_latents.shape}") + self.cache_state = [None, None, None] + self.cache_state_source = [None, None] + self.cache_states_context = [] + + # Skip layer guidance (SLG) + if slg_args is not None: + assert batched_cfg is not None, "Batched cfg is not supported with SLG" + transformer.slg_blocks = slg_args["blocks"] + transformer.slg_start_percent = slg_args["start_percent"] + transformer.slg_end_percent = slg_args["end_percent"] + else: + transformer.slg_blocks = None + + # Setup radial attention + if transformer.attention_mode == "radial_sage_attention": + setup_radial_attention(transformer, transformer_options, latent, seq_len, latent_video_length, context_options=context_options) + + # FlowEdit setup + if flowedit_args is not None: + source_embeds = flowedit_args["source_embeds"] + source_embeds = dict_to_device(source_embeds, device) + source_image_embeds = flowedit_args.get("source_image_embeds", image_embeds) + source_image_cond = source_image_embeds.get("image_embeds", None) + source_clip_fea = source_image_embeds.get("clip_fea", clip_fea) + if source_image_cond is not None: + source_image_cond = source_image_cond.to(dtype) + skip_steps = flowedit_args["skip_steps"] + drift_steps = flowedit_args["drift_steps"] + source_cfg = flowedit_args["source_cfg"] + if not isinstance(source_cfg, list): + source_cfg = [source_cfg] * (steps +1) + drift_cfg = flowedit_args["drift_cfg"] + if not isinstance(drift_cfg, list): + drift_cfg = [drift_cfg] * (steps +1) + + x_init = samples["samples"].clone().squeeze(0).to(device) + x_tgt = samples["samples"].squeeze(0).to(device) + + sample_scheduler = FlowMatchEulerDiscreteScheduler( + num_train_timesteps=1000, + shift=flowedit_args["drift_flow_shift"], + use_dynamic_shifting=False) + + sampling_sigmas = get_sampling_sigmas(steps, flowedit_args["drift_flow_shift"]) + + drift_timesteps, _ = retrieve_timesteps( + sample_scheduler, + device=device, + sigmas=sampling_sigmas) + + if drift_steps > 0: + drift_timesteps = torch.cat([drift_timesteps, torch.tensor([0]).to(drift_timesteps.device)]).to(drift_timesteps.device) + timesteps[-drift_steps:] = drift_timesteps[-drift_steps:] + + # Experimental args + use_cfg_zero_star = use_tangential = use_fresca = False + raag_alpha = 0.0 + if experimental_args is not None: + video_attention_split_steps = experimental_args.get("video_attention_split_steps", []) + if video_attention_split_steps: + transformer.video_attention_split_steps = [int(x.strip()) for x in video_attention_split_steps.split(",")] + else: + transformer.video_attention_split_steps = [] + + use_zero_init = experimental_args.get("use_zero_init", True) + use_cfg_zero_star = experimental_args.get("cfg_zero_star", False) + use_tangential = experimental_args.get("use_tcfg", False) + zero_star_steps = experimental_args.get("zero_star_steps", 0) + raag_alpha = experimental_args.get("raag_alpha", 0.0) + + use_fresca = experimental_args.get("use_fresca", False) + if use_fresca: + fresca_scale_low = experimental_args.get("fresca_scale_low", 1.0) + fresca_scale_high = experimental_args.get("fresca_scale_high", 1.25) + fresca_freq_cutoff = experimental_args.get("fresca_freq_cutoff", 20) + + log.info(f"Seq len: {seq_len}") + + if args.preview_method in [LatentPreviewMethod.Auto, LatentPreviewMethod.Latent2RGB]: #default for latent2rgb + from latent_preview import prepare_callback + else: + from .latent_preview import prepare_callback #custom for tiny VAE previews + callback = prepare_callback(patcher, len(timesteps)) + + log.info(f"Sampling 
{(latent_video_length-1) * 4 + 1} frames at {latent.shape[3]*vae_upscale_factor}x{latent.shape[2]*vae_upscale_factor} with {steps} steps") + + intermediate_device = device + + # diff diff prep + masks = None + if samples is not None and mask is not None: + mask = 1 - mask + thresholds = torch.arange(len(timesteps), dtype=original_image.dtype) / len(timesteps) + thresholds = thresholds.unsqueeze(1).unsqueeze(1).unsqueeze(1).unsqueeze(1).to(device) + masks = mask.repeat(len(timesteps), 1, 1, 1, 1).to(device) + masks = masks > thresholds + + latent_shift_loop = False + if loop_args is not None: + latent_shift_loop = True + is_looped = True + latent_skip = loop_args["shift_skip"] + latent_shift_start_percent = loop_args["start_percent"] + latent_shift_end_percent = loop_args["end_percent"] + shift_idx = 0 + + #clear memory before sampling + mm.soft_empty_cache() + gc.collect() + try: + torch.cuda.reset_peak_memory_stats(device) + #torch.cuda.memory._record_memory_history(max_entries=100000) + except: + pass + + # Main sampling loop with FreeInit iterations + iterations = freeinit_args.get("freeinit_num_iters", 3) if freeinit_args is not None else 1 + current_latent = latent + + for iter_idx in range(iterations): + + # FreeInit noise reinitialization (after first iteration) + if freeinit_args is not None and iter_idx > 0: + # restart scheduler for each iteration + sample_scheduler, timesteps = get_scheduler(scheduler, steps, shift, device, transformer.dim, flowedit_args, denoise_strength, sigmas=sigmas) + + # Re-apply start_step and end_step logic to timesteps and sigmas + if end_step != -1: + timesteps = timesteps[:end_step] + sample_scheduler.sigmas = sample_scheduler.sigmas[:end_step+1] + if start_step > 0: + timesteps = timesteps[start_step:] + sample_scheduler.sigmas = sample_scheduler.sigmas[start_step:] + if hasattr(sample_scheduler, 'timesteps'): + sample_scheduler.timesteps = timesteps + + # Diffuse current latent to t=999 + diffuse_timesteps = torch.full((noise.shape[0],), 999, device=device, dtype=torch.long) + z_T = add_noise( + current_latent.to(device), + initial_noise_saved.to(device), + diffuse_timesteps + ) + + # Generate new random noise + z_rand = torch.randn(z_T.shape, dtype=torch.float32, generator=seed_g, device=torch.device("cpu")) + + # Apply frequency mixing + current_latent = freq_mix_3d(z_T.to(torch.float32), z_rand.to(device), LPF=freq_filter) + current_latent = current_latent.to(dtype) + + # Store initial noise for first iteration + if freeinit_args is not None and iter_idx == 0: + initial_noise_saved = current_latent.detach().clone() + if samples is not None: + current_latent = input_samples.to(device) + continue + + # Reset per-iteration states + self.cache_state = [None, None] + self.cache_state_source = [None, None] + self.cache_states_context = [] + if context_options is not None: + self.window_tracker = WindowTracker(verbose=context_options["verbose"]) + + # Set latent for denoising + latent = current_latent + + try: + pbar = ProgressBar(len(timesteps)) + #region main loop start + for idx, t in enumerate(tqdm(timesteps)): + if flowedit_args is not None: + if idx < skip_steps: + continue + + # diff diff + if masks is not None: + if idx < len(timesteps) - 1: + noise_timestep = timesteps[idx+1] + image_latent = sample_scheduler.scale_noise( + original_image, torch.tensor([noise_timestep]), noise.to(device) + ) + mask = masks[idx] + mask = mask.to(latent) + latent = image_latent * mask + latent * (1-mask) + # end diff diff + + latent_model_input = 
latent.to(device) + + current_step_percentage = idx / len(timesteps) + + timestep = torch.tensor([t]).to(device) + if scheduler == "flowmatch_pusa" or (is_5b and 'all_indices' in locals()): + orig_timestep = timestep + timestep = timestep.unsqueeze(1).repeat(1, latent_video_length) + if extra_latents is not None: + if 'all_indices' in locals() and all_indices: + timestep[:, all_indices] = 0 + #print("timestep: ", timestep) + + ### latent shift + if latent_shift_loop: + if latent_shift_start_percent <= current_step_percentage <= latent_shift_end_percent: + latent_model_input = torch.cat([latent_model_input[:, shift_idx:]] + [latent_model_input[:, :shift_idx]], dim=1) + + #enhance-a-video + enhance_enabled = False + if feta_args is not None and feta_start_percent <= current_step_percentage <= feta_end_percent: + enhance_enabled = True + + #flow-edit + if flowedit_args is not None: + sigma = t / 1000.0 + sigma_prev = (timesteps[idx + 1] if idx < len(timesteps) - 1 else timesteps[-1]) / 1000.0 + noise = torch.randn(x_init.shape, generator=seed_g, device=torch.device("cpu")) + if idx < len(timesteps) - drift_steps: + cfg = drift_cfg + + zt_src = (1-sigma) * x_init + sigma * noise.to(t) + zt_tgt = x_tgt + zt_src - x_init + + #source + if idx < len(timesteps) - drift_steps: + if context_options is not None: + counter = torch.zeros_like(zt_src, device=intermediate_device) + vt_src = torch.zeros_like(zt_src, device=intermediate_device) + context_queue = list(context(idx, steps, latent_video_length, context_frames, context_stride, context_overlap)) + for c in context_queue: + window_id = self.window_tracker.get_window_id(c) + + if cache_args is not None: + current_teacache = self.window_tracker.get_teacache(window_id, self.cache_state) + else: + current_teacache = None + + prompt_index = min(int(max(c) / section_size), num_prompts - 1) + if context_options["verbose"]: + log.info(f"Prompt index: {prompt_index}") + + if len(source_embeds["prompt_embeds"]) > 1: + positive = source_embeds["prompt_embeds"][prompt_index] + else: + positive = source_embeds["prompt_embeds"] + + partial_img_emb = None + if source_image_cond is not None: + partial_img_emb = source_image_cond[:, c, :, :] + partial_img_emb[:, 0, :, :] = source_image_cond[:, 0, :, :].to(intermediate_device) + + partial_zt_src = zt_src[:, c, :, :] + vt_src_context, new_teacache = predict_with_cfg( + partial_zt_src, cfg[idx], + positive, source_embeds["negative_prompt_embeds"], + timestep, idx, partial_img_emb, control_latents, + source_clip_fea, current_teacache) + + if cache_args is not None: + self.window_tracker.cache_states[window_id] = new_teacache + + window_mask = create_window_mask(vt_src_context, c, latent_video_length, context_overlap) + vt_src[:, c, :, :] += vt_src_context * window_mask + counter[:, c, :, :] += window_mask + vt_src /= counter + else: + vt_src, self.cache_state_source = predict_with_cfg( + zt_src, cfg[idx], + source_embeds["prompt_embeds"], + source_embeds["negative_prompt_embeds"], + timestep, idx, source_image_cond, + source_clip_fea, control_latents, + cache_state=self.cache_state_source) + else: + if idx == len(timesteps) - drift_steps: + x_tgt = zt_tgt + zt_tgt = x_tgt + vt_src = 0 + #target + if context_options is not None: + counter = torch.zeros_like(zt_tgt, device=intermediate_device) + vt_tgt = torch.zeros_like(zt_tgt, device=intermediate_device) + context_queue = list(context(idx, steps, latent_video_length, context_frames, context_stride, context_overlap)) + for c in context_queue: + window_id = 
self.window_tracker.get_window_id(c) + + if cache_args is not None: + current_teacache = self.window_tracker.get_teacache(window_id, self.cache_state) + else: + current_teacache = None + + prompt_index = min(int(max(c) / section_size), num_prompts - 1) + if context_options["verbose"]: + log.info(f"Prompt index: {prompt_index}") + + if len(text_embeds["prompt_embeds"]) > 1: + positive = text_embeds["prompt_embeds"][prompt_index] + else: + positive = text_embeds["prompt_embeds"] + + partial_img_emb = None + partial_control_latents = None + if image_cond is not None: + partial_img_emb = image_cond[:, c, :, :] + partial_img_emb[:, 0, :, :] = image_cond[:, 0, :, :].to(intermediate_device) + if control_latents is not None: + partial_control_latents = control_latents[:, c, :, :] + + partial_zt_tgt = zt_tgt[:, c, :, :] + vt_tgt_context, new_teacache = predict_with_cfg( + partial_zt_tgt, cfg[idx], + positive, text_embeds["negative_prompt_embeds"], + timestep, idx, partial_img_emb, partial_control_latents, + clip_fea, current_teacache) + + if cache_args is not None: + self.window_tracker.cache_states[window_id] = new_teacache + + window_mask = create_window_mask(vt_tgt_context, c, latent_video_length, context_overlap) + vt_tgt[:, c, :, :] += vt_tgt_context * window_mask + counter[:, c, :, :] += window_mask + vt_tgt /= counter + else: + vt_tgt, self.cache_state = predict_with_cfg( + zt_tgt, cfg[idx], + text_embeds["prompt_embeds"], + text_embeds["negative_prompt_embeds"], + timestep, idx, image_cond, clip_fea, control_latents, + cache_state=self.cache_state) + v_delta = vt_tgt - vt_src + x_tgt = x_tgt.to(torch.float32) + v_delta = v_delta.to(torch.float32) + x_tgt = x_tgt + (sigma_prev - sigma) * v_delta + x0 = x_tgt + #region context windowing + elif context_options is not None: + counter = torch.zeros_like(latent_model_input, device=intermediate_device) + noise_pred = torch.zeros_like(latent_model_input, device=intermediate_device) + context_queue = list(context(idx, steps, latent_video_length, context_frames, context_stride, context_overlap)) + fraction_per_context = 1.0 / len(context_queue) + context_pbar = ProgressBar(steps) + step_start_progress = idx + + # Validate all context windows before processing + max_idx = latent_model_input.shape[1] if latent_model_input.ndim > 1 else 0 + for window_indices in context_queue: + if not all(0 <= idx < max_idx for idx in window_indices): + raise ValueError(f"Invalid context window indices {window_indices} for latent_model_input with shape {latent_model_input.shape}") + + for i, c in enumerate(context_queue): + window_id = self.window_tracker.get_window_id(c) + + if cache_args is not None: + current_teacache = self.window_tracker.get_teacache(window_id, self.cache_state) + else: + current_teacache = None + + prompt_index = min(int(max(c) / section_size), num_prompts - 1) + if context_options["verbose"]: + log.info(f"Prompt index: {prompt_index}") + + # Use the appropriate prompt for this section + if len(text_embeds["prompt_embeds"]) > 1: + positive = [text_embeds["prompt_embeds"][prompt_index]] + else: + positive = text_embeds["prompt_embeds"] + + partial_img_emb = None + partial_control_latents = None + if image_cond is not None: + partial_img_emb = image_cond[:, c] + + if c[0] != 0 and context_reference_latent is not None: + new_init_image = context_reference_latent[:, 0].to(intermediate_device) + # Concatenate the first 4 channels of partial_img_emb with new_init_image to match the required shape + if new_init_image.shape[0] + 4 == 
partial_img_emb.shape[0]: + partial_img_emb[:, 0] = torch.cat([ + image_cond[:4, 0], + new_init_image + ], dim=0) + else: + # fallback to original assignment if shape matches + partial_img_emb[:, 0] = new_init_image + else: + new_init_image = image_cond[:, 0].to(intermediate_device) + partial_img_emb[:, 0] = new_init_image + + if control_latents is not None: + partial_control_latents = control_latents[:, c] + + partial_control_camera_latents = None + if control_camera_latents is not None: + partial_control_camera_latents = control_camera_latents[:, :, c] + + partial_vace_context = None + if vace_data is not None: + window_vace_data = [] + for vace_entry in vace_data: + partial_context = vace_entry["context"][0][:, c] + if has_ref: + partial_context[:, 0] = vace_entry["context"][0][:, 0] + + window_vace_data.append({ + "context": [partial_context], + "scale": vace_entry["scale"], + "start": vace_entry["start"], + "end": vace_entry["end"], + "seq_len": vace_entry["seq_len"] + }) + + partial_vace_context = window_vace_data + + partial_audio_proj = None + if fantasytalking_embeds is not None: + partial_audio_proj = audio_proj[:, c] + + partial_latent_model_input = latent_model_input[:, c] + if latents_to_insert is not None and c[0] != 0: + partial_latent_model_input[:, :1] = latents_to_insert + + partial_unianim_data = None + if unianim_data is not None: + partial_dwpose = dwpose_data[:, :, c] + partial_dwpose_flat=rearrange(partial_dwpose, 'b c f h w -> b (f h w) c') + partial_unianim_data = { + "dwpose": partial_dwpose_flat, + "random_ref": unianim_data["random_ref"], + "strength": unianimate_poses["strength"], + "start_percent": unianimate_poses["start_percent"], + "end_percent": unianimate_poses["end_percent"] + } + + partial_add_cond = None + if add_cond is not None: + partial_add_cond = add_cond[:, :, c].to(device, dtype) + + if len(timestep.shape) != 1: + partial_timestep = timestep[:, c] + partial_timestep[:, :1] = 0 + else: + partial_timestep = timestep + #print("Partial timestep:", partial_timestep) + + noise_pred_context, new_teacache = predict_with_cfg( + partial_latent_model_input, + cfg[idx], positive, + text_embeds["negative_prompt_embeds"], + partial_timestep, idx, partial_img_emb, clip_fea, partial_control_latents, partial_vace_context, partial_unianim_data,partial_audio_proj, + partial_control_camera_latents, partial_add_cond, current_teacache, context_window=c) + + if cache_args is not None: + self.window_tracker.cache_states[window_id] = new_teacache + + window_mask = create_window_mask(noise_pred_context, c, latent_video_length, context_overlap, looped=is_looped, window_type=context_options["fuse_method"]) + noise_pred[:, c] += noise_pred_context * window_mask + counter[:, c] += window_mask + context_pbar.update_absolute(step_start_progress + (i + 1) * fraction_per_context, steps) + noise_pred /= counter + #region multitalk + elif multitalk_sampling: + original_image = cond_image = image_embeds.get("multitalk_start_image", None) + offload = image_embeds.get("force_offload", False) + tiled_vae = image_embeds.get("tiled_vae", False) + frame_num = clip_length = image_embeds.get("num_frames", 81) + vae = image_embeds.get("vae", None) + clip_embeds = image_embeds.get("clip_context", None) + colormatch = image_embeds.get("colormatch", "disabled") + motion_frame = image_embeds.get("motion_frame", 25) + target_w = image_embeds.get("target_w", None) + target_h = image_embeds.get("target_h", None) + + gen_video_list = [] + is_first_clip = True + arrive_last_frame = False + 
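+                # A minimal sketch of the clip/overlap arithmetic used by this loop, assuming the node
+                # defaults frame_num=81 and motion_frame=25 (both are taken from image_embeds above):
+                #   new_frames_per_clip = frame_num - motion_frame        # 81 - 25 = 56 fresh frames per clip
+                #   estimated_iterations = total_frames // new_frames_per_clip + 1
+                # After the first clip, the last `motion_frame` decoded frames are kept as conditioning for
+                # the next clip and the audio window advances by the same 56 frames, so clips overlap but
+                # only the non-overlapping part is appended to gen_video_list.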
cur_motion_frames_num = 1 + audio_start_idx = iteration_count = 0 + audio_end_idx = audio_start_idx + clip_length + indices = (torch.arange(4 + 1) - 2) * 1 + + if multitalk_embeds is not None: + total_frames = len(multitalk_audio_embedding) + + estimated_iterations = total_frames // (frame_num - motion_frame) + 1 + loop_pbar = tqdm(total=estimated_iterations, desc="Generating video clips") + callback = prepare_callback(patcher, estimated_iterations) + + audio_embedding = multitalk_audio_embedding + human_num = len(audio_embedding) + audio_embs = None + while True: # start video generation iteratively + if multitalk_embeds is not None: + audio_embs = [] + # split audio with window size + for human_idx in range(human_num): + center_indices = torch.arange(audio_start_idx, audio_end_idx, 1).unsqueeze(1) + indices.unsqueeze(0) + center_indices = torch.clamp(center_indices, min=0, max=audio_embedding[human_idx].shape[0]-1) + audio_emb = audio_embedding[human_idx][center_indices].unsqueeze(0).to(device) + audio_embs.append(audio_emb) + audio_embs = torch.concat(audio_embs, dim=0).to(dtype) + + h, w = cond_image.shape[-2], cond_image.shape[-1] + lat_h, lat_w = h // VAE_STRIDE[1], w // VAE_STRIDE[2] + seq_len = ((frame_num - 1) // VAE_STRIDE[0] + 1) * lat_h * lat_w // (PATCH_SIZE[1] * PATCH_SIZE[2]) + + noise = torch.randn( + 16, (frame_num - 1) // 4 + 1, + lat_h, lat_w, dtype=torch.float32, device=torch.device("cpu"), generator=seed_g).to(device) + + # get mask + msk = torch.ones(1, frame_num, lat_h, lat_w, device=device) + msk[:, cur_motion_frames_num:] = 0 + msk = torch.concat([ + torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:] + ], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w) + msk = msk.transpose(1, 2).to(dtype) # B 4 T H W + + mm.soft_empty_cache() + + # zero padding and vae encode + video_frames = torch.zeros(1, cond_image.shape[1], frame_num-cond_image.shape[2], target_h, target_w, device=device, dtype=vae.dtype) + padding_frames_pixels_values = torch.concat([cond_image.to(device, vae.dtype), video_frames], dim=2) + + vae.to(device) + y = vae.encode(padding_frames_pixels_values, device=device, tiled=tiled_vae).to(dtype) + vae.to(offload_device) + + cur_motion_frames_latent_num = int(1 + (cur_motion_frames_num-1) // 4) + latent_motion_frames = y[:, :, :cur_motion_frames_latent_num][0] # C T H W + y = torch.concat([msk, y], dim=1) # B 4+C T H W + mm.soft_empty_cache() + + if scheduler == "multitalk": + timesteps = list(np.linspace(1000, 1, steps, dtype=np.float32)) + timesteps.append(0.) 
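+                    # The "multitalk" schedule above is a plain linspace from 1000 down to 1 with a final 0
+                    # appended; timestep_transform below then warps it. Assuming the usual flow-matching
+                    # shift mapping (the helper itself is not part of this diff), the warp is roughly:
+                    #   t_norm = t / 1000
+                    #   t_shifted = 1000 * shift * t_norm / (1 + (shift - 1) * t_norm)
+                    # which, for shift > 1, spends more of the step budget at high noise levels. The Euler
+                    # update further down is then latent += -noise_pred * (t_i - t_{i+1}) / 1000.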
+ timesteps = [torch.tensor([t], device=device) for t in timesteps] + timesteps = [timestep_transform(t, shift=shift, num_timesteps=1000) for t in timesteps] + else: + sample_scheduler, timesteps = get_scheduler(scheduler, steps, shift, device, transformer.dim, flowedit_args, denoise_strength, sigmas=sigmas) + + transformed_timesteps = [] + for t in timesteps: + t_tensor = torch.tensor([t.item()], device=device) + transformed_timesteps.append(t_tensor) + + transformed_timesteps.append(torch.tensor([0.], device=device)) + timesteps = transformed_timesteps + + # sample videos + latent = noise + + # injecting motion frames + if not is_first_clip: + latent_motion_frames = latent_motion_frames.to(latent.dtype).to(device) + motion_add_noise = torch.randn(latent_motion_frames.shape, device=torch.device("cpu"), generator=seed_g).to(device).contiguous() + add_latent = add_noise(latent_motion_frames, motion_add_noise, timesteps[0]) + _, T_m, _, _ = add_latent.shape + latent[:, :T_m] = add_latent + + if offload: + #blockswap init + if transformer_options is not None: + block_swap_args = transformer_options.get("block_swap_args", None) + + if block_swap_args is not None: + transformer.use_non_blocking = block_swap_args.get("use_non_blocking", False) + for name, param in transformer.named_parameters(): + if "block" not in name: + param.data = param.data.to(device) + if "control_adapter" in name: + param.data = param.data.to(device) + elif block_swap_args["offload_txt_emb"] and "txt_emb" in name: + param.data = param.data.to(offload_device) + elif block_swap_args["offload_img_emb"] and "img_emb" in name: + param.data = param.data.to(offload_device) + + transformer.block_swap( + block_swap_args["blocks_to_swap"] - 1 , + block_swap_args["offload_txt_emb"], + block_swap_args["offload_img_emb"], + vace_blocks_to_swap = block_swap_args.get("vace_blocks_to_swap", None), + ) + + elif model["auto_cpu_offload"]: + for module in transformer.modules(): + if hasattr(module, "offload"): + module.offload() + if hasattr(module, "onload"): + module.onload() + elif model["manual_offloading"]: + transformer.to(device) + + comfy_pbar = ProgressBar(len(timesteps)-1) + for i in tqdm(range(len(timesteps)-1)): + timestep = timesteps[i] + latent_model_input = latent.to(device) + + noise_pred, self.cache_state = predict_with_cfg( + latent_model_input, + cfg[idx], + text_embeds["prompt_embeds"], + text_embeds["negative_prompt_embeds"], + timestep, idx, y.squeeze(0), clip_embeds.to(dtype), control_latents, vace_data, unianim_data, audio_proj, control_camera_latents, add_cond, + cache_state=self.cache_state, multitalk_audio_embeds=audio_embs) + + if callback is not None: + callback_latent = (latent_model_input.to(device) - noise_pred.to(device) * t.to(device) / 1000).detach().permute(1,0,2,3) + callback(iteration_count, callback_latent, None, estimated_iterations) + + # update latent + if scheduler == "multitalk": + noise_pred = -noise_pred + dt = timesteps[i] - timesteps[i + 1] + dt = dt / 1000 + latent = latent + noise_pred * dt[:, None, None, None] + else: + latent = latent.to(intermediate_device) + + temp_x0 = sample_scheduler.step( + noise_pred.unsqueeze(0), + timestep, + latent.unsqueeze(0), + **scheduler_step_args)[0] + latent = temp_x0.squeeze(0) + + # injecting motion frames + if not is_first_clip: + latent_motion_frames = latent_motion_frames.to(latent.dtype).to(device) + motion_add_noise = torch.randn(latent_motion_frames.shape, device=torch.device("cpu"), generator=seed_g).to(device).contiguous() + add_latent = 
add_noise(latent_motion_frames, motion_add_noise, timesteps[i+1]) + _, T_m, _, _ = add_latent.shape + latent[:, :T_m] = add_latent + + x0 = latent.to(device) + del latent_model_input, timestep + comfy_pbar.update(1) + + if offload: + transformer.to(offload_device) + vae.to(device) + videos = vae.decode(x0.unsqueeze(0).to(vae.dtype), device=device, tiled=tiled_vae) + vae.to(offload_device) + + # cache generated samples + videos = torch.stack(videos).cpu() # B C T H W + if colormatch != "disabled": + videos = videos[0].permute(1, 2, 3, 0).cpu().float().numpy() + from color_matcher import ColorMatcher + cm = ColorMatcher() + cm_result_list = [] + for img in videos: + cm_result = cm.transfer(src=img, ref=original_image[0].permute(1, 2, 3, 0).squeeze(0).cpu().numpy(), method=colormatch) + cm_result_list.append(torch.from_numpy(cm_result)) + + videos = torch.stack(cm_result_list, dim=0).to(torch.float32).permute(3, 0, 1, 2).unsqueeze(0) + + if is_first_clip: + gen_video_list.append(videos) + else: + gen_video_list.append(videos[:, :, cur_motion_frames_num:]) + + # decide whether is done + if arrive_last_frame: + loop_pbar.update(estimated_iterations - iteration_count) + loop_pbar.close() + break + + # update next condition frames + is_first_clip = False + cur_motion_frames_num = motion_frame + + cond_image = videos[:, :, -cur_motion_frames_num:].to(torch.float32).to(device) + + # Update progress bar + iteration_count += 1 + loop_pbar.update(1) + + # Repeat audio emb + if multitalk_embeds is not None: + audio_start_idx += (frame_num - cur_motion_frames_num) + audio_end_idx = audio_start_idx + clip_length + if audio_end_idx >= len(audio_embedding[0]): + arrive_last_frame = True + miss_lengths = [] + source_frames = [] + for human_inx in range(human_num): + source_frame = len(audio_embedding[human_inx]) + source_frames.append(source_frame) + if audio_end_idx >= len(audio_embedding[human_inx]): + miss_length = audio_end_idx - len(audio_embedding[human_inx]) + 3 + add_audio_emb = torch.flip(audio_embedding[human_inx][-1*miss_length:], dims=[0]) + audio_embedding[human_inx] = torch.cat([audio_embedding[human_inx], add_audio_emb], dim=0) + miss_lengths.append(miss_length) + else: + miss_lengths.append(0) + + gen_video_samples = torch.cat(gen_video_list, dim=2).to(torch.float32) + + del noise, latent + if force_offload: + if model["manual_offloading"]: + transformer.to(offload_device) + mm.soft_empty_cache() + gc.collect() + try: + print_memory(device) + torch.cuda.reset_peak_memory_stats(device) + except: + pass + return {"video": gen_video_samples[0].permute(1, 2, 3, 0).cpu()}, + + #region normal inference + else: + noise_pred, self.cache_state = predict_with_cfg( + latent_model_input, + cfg[idx], + text_embeds["prompt_embeds"], + text_embeds["negative_prompt_embeds"], + timestep, idx, image_cond, clip_fea, control_latents, vace_data, unianim_data, audio_proj, control_camera_latents, add_cond, + cache_state=self.cache_state) + + if latent_shift_loop: + #reverse latent shift + if latent_shift_start_percent <= current_step_percentage <= latent_shift_end_percent: + noise_pred = torch.cat([noise_pred[:, latent_video_length - shift_idx:]] + [noise_pred[:, :latent_video_length - shift_idx]], dim=1) + shift_idx = (shift_idx + latent_skip) % latent_video_length + + + if flowedit_args is None: + latent = latent.to(intermediate_device) + + if len(timestep.shape) != 1 and scheduler != "flowmatch_pusa": #5b + # all_indices is a list of indices to skip + total_indices = list(range(latent.shape[1])) + 
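+                        # A small worked example of the partition below, assuming all_indices = [0] and a
+                        # latent with 5 frames: total_indices = [0, 1, 2, 3, 4] and process_indices = [1, 2, 3, 4].
+                        # Only the frames in process_indices are stepped by sample_scheduler.step(); the skipped
+                        # frames (extra latents whose per-frame timestep was forced to 0 earlier) are copied back
+                        # unchanged when the tensor is reassembled with torch.cat along dim=1.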
process_indices = [i for i in total_indices if i not in all_indices] + if process_indices: + latent_to_process = latent[:, process_indices] + noise_pred_to_process = noise_pred[:, process_indices] + latent_slice = sample_scheduler.step( + noise_pred_to_process.unsqueeze(0), + orig_timestep, + latent_to_process.unsqueeze(0), + **scheduler_step_args + )[0].squeeze(0) + # Reconstruct the latent tensor: keep skipped indices as-is, update others + new_latent = [] + for i in total_indices: + if i in all_indices: + new_latent.append(latent[:, i:i+1]) + else: + j = process_indices.index(i) + new_latent.append(latent_slice[:, j:j+1]) + latent = torch.cat(new_latent, dim=1) + else: + latent = sample_scheduler.step( + noise_pred[:, :orig_noise_len].unsqueeze(0) if recammaster is not None else noise_pred.unsqueeze(0), + timestep, + latent[:, :orig_noise_len].unsqueeze(0) if recammaster is not None else latent.unsqueeze(0), + **scheduler_step_args)[0].squeeze(0) + + if freeinit_args is not None: + current_latent = latent.clone() + + if callback is not None: + if recammaster is not None: + callback_latent = (latent_model_input[:, :orig_noise_len].to(device) - noise_pred[:, :orig_noise_len].to(device) * t.to(device) / 1000).detach() + elif phantom_latents is not None: + callback_latent = (latent_model_input[:,:-phantom_latents.shape[1]].to(device) - noise_pred[:,:-phantom_latents.shape[1]].to(device) * t.to(device) / 1000).detach() + else: + callback_latent = (latent_model_input.to(device) - noise_pred.to(device) * t.to(device) / 1000).detach() + callback(idx, callback_latent.permute(1,0,2,3), None, len(timesteps)) + else: + pbar.update(1) + else: + if callback is not None: + callback_latent = (zt_tgt.to(device) - vt_tgt.to(device) * t.to(device) / 1000).detach() + callback(idx, callback_latent.permute(1,0,2,3), None, len(timesteps)) + else: + pbar.update(1) + except Exception as e: + log.error(f"Error during sampling: {e}") + if force_offload: + if model["manual_offloading"]: + offload_transformer(transformer) + raise e + + if phantom_latents is not None: + latent = latent[:,:-phantom_latents.shape[1]] + + if cache_args is not None: + cache_report(transformer, cache_args) + + if force_offload: + if model["manual_offloading"]: + offload_transformer(transformer) + + try: + print_memory(device) + #torch.cuda.memory._dump_snapshot("wanvideowrapper_memory_dump.pt") + #torch.cuda.memory._record_memory_history(enabled=None) + torch.cuda.reset_peak_memory_stats(device) + except: + pass + return ({ + "samples": latent.unsqueeze(0).cpu(), + "looped": is_looped, + "end_image": end_image if not fun_or_fl2v_model else None, + "has_ref": has_ref, + "drop_last": drop_last, + "generator_state": seed_g.get_state(), + },{ + "samples": callback_latent.unsqueeze(0).cpu() if callback is not None else None, + }) + +#region VideoDecode +class WanVideoDecode: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "vae": ("WANVAE",), + "samples": ("LATENT",), + "enable_vae_tiling": ("BOOLEAN", {"default": False, "tooltip": ( + "Drastically reduces memory use but will introduce seams at tile stride boundaries. " + "The location and number of seams is dictated by the tile stride size. " + "The visibility of seams can be controlled by increasing the tile size. " + "Seams become less obvious at 1.5x stride and are barely noticeable at 2x stride size. " + "Which is to say if you use a stride width of 160, the seams are barely noticeable with a tile width of 320." 
+ )}), + "tile_x": ("INT", {"default": 272, "min": 40, "max": 2048, "step": 8, "tooltip": "Tile width in pixels. Smaller values use less VRAM but will make seams more obvious."}), + "tile_y": ("INT", {"default": 272, "min": 40, "max": 2048, "step": 8, "tooltip": "Tile height in pixels. Smaller values use less VRAM but will make seams more obvious."}), + "tile_stride_x": ("INT", {"default": 144, "min": 32, "max": 2040, "step": 8, "tooltip": "Tile stride width in pixels. Smaller values use less VRAM but will introduce more seams."}), + "tile_stride_y": ("INT", {"default": 128, "min": 32, "max": 2040, "step": 8, "tooltip": "Tile stride height in pixels. Smaller values use less VRAM but will introduce more seams."}), + }, + "optional": { + "normalization": (["default", "minmax"], {"advanced": True}), + } + } + + @classmethod + def VALIDATE_INPUTS(s, tile_x, tile_y, tile_stride_x, tile_stride_y): + if tile_x <= tile_stride_x: + return "Tile width must be larger than the tile stride width." + if tile_y <= tile_stride_y: + return "Tile height must be larger than the tile stride height." + return True + + RETURN_TYPES = ("IMAGE",) + RETURN_NAMES = ("images",) + FUNCTION = "decode" + CATEGORY = "WanVideoWrapper" + + def decode(self, vae, samples, enable_vae_tiling, tile_x, tile_y, tile_stride_x, tile_stride_y, normalization="default"): + mm.soft_empty_cache() + video = samples.get("video", None) + if video is not None: + video = torch.clamp(video, -1.0, 1.0) + video = (video + 1.0) / 2.0 + return video.cpu(), + latents = samples["samples"] + end_image = samples.get("end_image", None) + has_ref = samples.get("has_ref", False) + drop_last = samples.get("drop_last", False) + is_looped = samples.get("looped", False) + + vae.to(device) + + latents = latents.to(device = device, dtype = vae.dtype) + + mm.soft_empty_cache() + + if has_ref: + latents = latents[:, :, 1:] + if drop_last: + latents = latents[:, :, :-1] + + if type(vae).__name__ == "TAEHV": + images = vae.decode_video(latents.permute(0, 2, 1, 3, 4))[0].permute(1, 0, 2, 3) + images = torch.clamp(images, 0.0, 1.0) + images = images.permute(1, 2, 3, 0).cpu().float() + return (images,) + else: + if end_image is not None: + enable_vae_tiling = False + images = vae.decode(latents, device=device, end_=(end_image is not None), tiled=enable_vae_tiling, tile_size=(tile_x//8, tile_y//8), tile_stride=(tile_stride_x//8, tile_stride_y//8))[0] + vae.model.clear_cache() + + images = images.cpu().float() + + if normalization == "minmax": + images.sub_(images.min()).div_(images.max() - images.min()) + else: + images.clamp_(-1.0, 1.0) + images.add_(1.0).div_(2.0) + + if is_looped: + temp_latents = torch.cat([latents[:, :, -3:]] + [latents[:, :, :2]], dim=2) + temp_images = vae.decode(temp_latents, device=device, end_=(end_image is not None), tiled=enable_vae_tiling, tile_size=(tile_x//vae.upsampling_factor, tile_y//vae.upsampling_factor), tile_stride=(tile_stride_x//vae.upsampling_factor, tile_stride_y//vae.upsampling_factor))[0] + temp_images = temp_images.cpu().float() + temp_images = (temp_images - temp_images.min()) / (temp_images.max() - temp_images.min()) + images = torch.cat([temp_images[:, 9:].to(images), images[:, 5:]], dim=1) + + if end_image is not None: + images = images[:, 0:-1] + + vae.model.clear_cache() + vae.to(offload_device) + mm.soft_empty_cache() + + images.clamp_(0.0, 1.0) + + return (images.permute(1, 2, 3, 0),) + +#region VideoEncode +class WanVideoEncode: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "vae": ("WANVAE",), + 
"image": ("IMAGE",), + "enable_vae_tiling": ("BOOLEAN", {"default": False, "tooltip": "Drastically reduces memory use but may introduce seams"}), + "tile_x": ("INT", {"default": 272, "min": 64, "max": 2048, "step": 1, "tooltip": "Tile size in pixels, smaller values use less VRAM, may introduce more seams"}), + "tile_y": ("INT", {"default": 272, "min": 64, "max": 2048, "step": 1, "tooltip": "Tile size in pixels, smaller values use less VRAM, may introduce more seams"}), + "tile_stride_x": ("INT", {"default": 144, "min": 32, "max": 2048, "step": 32, "tooltip": "Tile stride in pixels, smaller values use less VRAM, may introduce more seams"}), + "tile_stride_y": ("INT", {"default": 128, "min": 32, "max": 2048, "step": 32, "tooltip": "Tile stride in pixels, smaller values use less VRAM, may introduce more seams"}), + }, + "optional": { + "noise_aug_strength": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Strength of noise augmentation, helpful for leapfusion I2V where some noise can add motion and give sharper results"}), + "latent_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Additional latent multiplier, helpful for leapfusion I2V where lower values allow for more motion"}), + "mask": ("MASK", ), + } + } + + RETURN_TYPES = ("LATENT",) + RETURN_NAMES = ("samples",) + FUNCTION = "encode" + CATEGORY = "WanVideoWrapper" + + def encode(self, vae, image, enable_vae_tiling, tile_x, tile_y, tile_stride_x, tile_stride_y, noise_aug_strength=0.0, latent_strength=1.0, mask=None): + vae.to(device) + + image = image.clone() + + B, H, W, C = image.shape + if W % 16 != 0 or H % 16 != 0: + new_height = (H // 16) * 16 + new_width = (W // 16) * 16 + log.warning(f"Image size {W}x{H} is not divisible by 16, resizing to {new_width}x{new_height}") + image = common_upscale(image.movedim(-1, 1), new_width, new_height, "lanczos", "disabled").movedim(1, -1) + + if image.shape[-1] == 4: + image = image[..., :3] + image = image.to(vae.dtype).to(device).unsqueeze(0).permute(0, 4, 1, 2, 3) # B, C, T, H, W + + if noise_aug_strength > 0.0: + image = add_noise_to_reference_video(image, ratio=noise_aug_strength) + + if isinstance(vae, TAEHV): + latents = vae.encode_video(image.permute(0, 2, 1, 3, 4), parallel=False)# B, T, C, H, W + latents = latents.permute(0, 2, 1, 3, 4) + else: + latents = vae.encode(image * 2.0 - 1.0, device=device, tiled=enable_vae_tiling, tile_size=(tile_x//vae.upsampling_factor, tile_y//vae.upsampling_factor), tile_stride=(tile_stride_x//vae.upsampling_factor, tile_stride_y//vae.upsampling_factor)) + vae.model.clear_cache() + if latent_strength != 1.0: + latents *= latent_strength + + log.info(f"encoded latents shape {latents.shape}") + latent_mask = None + if mask is None: + vae.to(offload_device) + else: + target_h, target_w = latents.shape[3:] + + mask = torch.nn.functional.interpolate( + mask.unsqueeze(0).unsqueeze(0), # Add batch and channel dims [1,1,T,H,W] + size=(latents.shape[2], target_h, target_w), + mode='trilinear', + align_corners=False + ).squeeze(0) # Remove batch dim, keep channel dim + + # Add batch & channel dims for final output + latent_mask = mask.unsqueeze(0).repeat(1, latents.shape[1], 1, 1, 1) + log.info(f"latent mask shape {latent_mask.shape}") + vae.to(offload_device) + mm.soft_empty_cache() + + return ({"samples": latents, "mask": latent_mask},) + +class WanVideoLatentReScale: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "samples": ("LATENT",), + "direction": (["comfy_to_wrapper", 
"wrapper_to_comfy"], {"tooltip": "Direction to rescale latents, from comfy to wrapper or vice versa"}), + } + } + + RETURN_TYPES = ("LATENT",) + RETURN_NAMES = ("samples",) + FUNCTION = "encode" + CATEGORY = "WanVideoWrapper" + DESCRIPTION = "Rescale latents to match the expected range for encoding or decoding. Can be used to " + + def encode(self, samples, direction): + samples = samples.copy() + latents = samples["samples"] + + mean = [ + -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, + 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921 + ] + std = [ + 2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, + 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160 + ] + mean = torch.tensor(mean).view(1, latents.shape[1], 1, 1, 1) + std = torch.tensor(std).view(1, latents.shape[1], 1, 1, 1) + inv_std = (1.0 / std).view(1, latents.shape[1], 1, 1, 1) + if direction == "comfy_to_wrapper": + latents = (latents - mean.to(latents)) * inv_std.to(latents) + elif direction == "wrapper_to_comfy": + latents = latents / inv_std.to(latents) + mean.to(latents) + + samples["samples"] = latents + + return (samples,) + +NODE_CLASS_MAPPINGS = { + "WanVideoSampler": WanVideoSampler, + "WanVideoDecode": WanVideoDecode, + "WanVideoTextEncode": WanVideoTextEncode, + "WanVideoTextEncodeSingle": WanVideoTextEncodeSingle, + "WanVideoClipVisionEncode": WanVideoClipVisionEncode, + "WanVideoImageToVideoEncode": WanVideoImageToVideoEncode, + "WanVideoEncode": WanVideoEncode, + "WanVideoEmptyEmbeds": WanVideoEmptyEmbeds, + "WanVideoEnhanceAVideo": WanVideoEnhanceAVideo, + "WanVideoContextOptions": WanVideoContextOptions, + "WanVideoTextEmbedBridge": WanVideoTextEmbedBridge, + "WanVideoFlowEdit": WanVideoFlowEdit, + "WanVideoControlEmbeds": WanVideoControlEmbeds, + "WanVideoSLG": WanVideoSLG, + "WanVideoLoopArgs": WanVideoLoopArgs, + "WanVideoSetBlockSwap": WanVideoSetBlockSwap, + "WanVideoExperimentalArgs": WanVideoExperimentalArgs, + "WanVideoVACEEncode": WanVideoVACEEncode, + "WanVideoPhantomEmbeds": WanVideoPhantomEmbeds, + "WanVideoRealisDanceLatents": WanVideoRealisDanceLatents, + "WanVideoApplyNAG": WanVideoApplyNAG, + "WanVideoMiniMaxRemoverEmbeds": WanVideoMiniMaxRemoverEmbeds, + "WanVideoFreeInitArgs": WanVideoFreeInitArgs, + "WanVideoSetRadialAttention": WanVideoSetRadialAttention, + "WanVideoBlockList": WanVideoBlockList, + "WanVideoTextEncodeCached": WanVideoTextEncodeCached, + "WanVideoAddExtraLatent": WanVideoAddExtraLatent, + "WanVideoLatentReScale": WanVideoLatentReScale, + } +NODE_DISPLAY_NAME_MAPPINGS = { + "WanVideoSampler": "WanVideo Sampler", + "WanVideoDecode": "WanVideo Decode", + "WanVideoTextEncode": "WanVideo TextEncode", + "WanVideoTextEncodeSingle": "WanVideo TextEncodeSingle", + "WanVideoTextImageEncode": "WanVideo TextImageEncode (IP2V)", + "WanVideoClipVisionEncode": "WanVideo ClipVision Encode", + "WanVideoImageToVideoEncode": "WanVideo ImageToVideo Encode", + "WanVideoEncode": "WanVideo Encode", + "WanVideoEmptyEmbeds": "WanVideo Empty Embeds", + "WanVideoEnhanceAVideo": "WanVideo Enhance-A-Video", + "WanVideoContextOptions": "WanVideo Context Options", + "WanVideoTextEmbedBridge": "WanVideo TextEmbed Bridge", + "WanVideoFlowEdit": "WanVideo FlowEdit", + "WanVideoControlEmbeds": "WanVideo Control Embeds", + "WanVideoSLG": "WanVideo SLG", + "WanVideoLoopArgs": "WanVideo Loop Args", + "WanVideoSetBlockSwap": "WanVideo Set BlockSwap", + "WanVideoExperimentalArgs": "WanVideo Experimental Args", + "WanVideoVACEEncode": "WanVideo VACE 
Encode", + "WanVideoPhantomEmbeds": "WanVideo Phantom Embeds", + "WanVideoRealisDanceLatents": "WanVideo RealisDance Latents", + "WanVideoApplyNAG": "WanVideo Apply NAG", + "WanVideoMiniMaxRemoverEmbeds": "WanVideo MiniMax Remover Embeds", + "WanVideoFreeInitArgs": "WanVideo Free Init Args", + "WanVideoSetRadialAttention": "WanVideo Set Radial Attention", + "WanVideoBlockList": "WanVideo Block List", + "WanVideoTextEncodeCached": "WanVideo TextEncode Cached", + "WanVideoAddExtraLatent": "WanVideo Add Extra Latent", + "WanVideoLatentReScale": "WanVideo Latent ReScale", + } diff --git a/nodes_old.py b/nodes_old.py new file mode 100644 index 00000000..834d335d --- /dev/null +++ b/nodes_old.py @@ -0,0 +1,3466 @@ +import os, gc, math +import torch +import torch.nn.functional as F +import numpy as np +from tqdm import tqdm +import inspect +import hashlib +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler + +from .wanvideo.modules.model import rope_params +from .custom_linear import remove_lora_from_module, set_lora_params +from .wanvideo.schedulers import get_scheduler, get_sampling_sigmas, retrieve_timesteps, scheduler_list +from .gguf.gguf import set_lora_params_gguf +from .multitalk.multitalk import timestep_transform, add_noise +from .utils import(log, print_memory, apply_lora, clip_encode_image_tiled, fourier_filter, + add_noise_to_reference_video, optimized_scale, setup_radial_attention, + compile_model, dict_to_device, tangential_projection, set_module_tensor_to_device, get_raag_guidance) +from .cache_methods.cache_methods import cache_report +from .enhance_a_video.globals import set_enhance_weight, set_num_frames +from .taehv import TAEHV + +from einops import rearrange + +from comfy import model_management as mm +from comfy.utils import ProgressBar, common_upscale +from comfy.clip_vision import clip_preprocess, ClipVisionModel +from comfy.cli_args import args, LatentPreviewMethod +import folder_paths + +script_directory = os.path.dirname(os.path.abspath(__file__)) + +device = mm.get_torch_device() +offload_device = mm.unet_offload_device() + +VAE_STRIDE = (4, 8, 8) +PATCH_SIZE = (1, 2, 2) + +def offload_transformer(transformer): + transformer.teacache_state.clear_all() + transformer.magcache_state.clear_all() + transformer.easycache_state.clear_all() + transformer.to(offload_device) + mm.soft_empty_cache() + gc.collect() + +class WanVideoEnhanceAVideo: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "weight": ("FLOAT", {"default": 2.0, "min": 0, "max": 100, "step": 0.01, "tooltip": "The feta Weight of the Enhance-A-Video"}), + "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percentage of the steps to apply Enhance-A-Video"}), + "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percentage of the steps to apply Enhance-A-Video"}), + }, + } + RETURN_TYPES = ("FETAARGS",) + RETURN_NAMES = ("feta_args",) + FUNCTION = "setargs" + CATEGORY = "WanVideoWrapper" + DESCRIPTION = "https://github.com/NUS-HPC-AI-Lab/Enhance-A-Video" + + def setargs(self, **kwargs): + return (kwargs, ) + +class WanVideoSetBlockSwap: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": ("WANVIDEOMODEL", ), + }, + "optional": { + "block_swap_args": ("BLOCKSWAPARGS", ), + } + } + + RETURN_TYPES = ("WANVIDEOMODEL",) + RETURN_NAMES = ("model", ) + FUNCTION = "loadmodel" + CATEGORY = "WanVideoWrapper" + + def loadmodel(self, model, block_swap_args=None): + if 
block_swap_args is None: + return (model,) + patcher = model.clone() + if 'transformer_options' not in patcher.model_options: + patcher.model_options['transformer_options'] = {} + patcher.model_options["transformer_options"]["block_swap_args"] = block_swap_args + + return (patcher,) + +class WanVideoSetRadialAttention: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": ("WANVIDEOMODEL", ), + "dense_attention_mode": ([ + "sdpa", + "flash_attn_2", + "flash_attn_3", + "sageattn", + "sparse_sage_attention", + ], {"default": "sageattn", "tooltip": "The attention mode for dense attention"}), + "dense_blocks": ("INT", {"default": 1, "min": 0, "max": 40, "step": 1, "tooltip": "Number of blocks to apply normal attention to"}), + "dense_vace_blocks": ("INT", {"default": 1, "min": 0, "max": 15, "step": 1, "tooltip": "Number of vace blocks to apply normal attention to"}), + "dense_timesteps": ("INT", {"default": 2, "min": 0, "max": 100, "step": 1, "tooltip": "The step to start applying sparse attention"}), + "decay_factor": ("FLOAT", {"default": 0.2, "min": 0, "max": 1, "step": 0.01, "tooltip": "Controls how quickly the attention window shrinks as the distance between frames increases in the sparse attention mask."}), + "block_size":([128, 64], {"default": 128, "tooltip": "Radial attention block size, larger blocks are faster but restricts usable dimensions more."}), + } + } + + RETURN_TYPES = ("WANVIDEOMODEL",) + RETURN_NAMES = ("model", ) + FUNCTION = "loadmodel" + CATEGORY = "WanVideoWrapper" + DESCRIPTION = "Sets radial attention parameters, dense attention refers to normal attention" + + def loadmodel(self, model, dense_attention_mode, dense_blocks, dense_vace_blocks, dense_timesteps, decay_factor, block_size): + if "radial" not in model.model.diffusion_model.attention_mode: + raise Exception("Enable radial attention first in the model loader.") + + patcher = model.clone() + if 'transformer_options' not in patcher.model_options: + patcher.model_options['transformer_options'] = {} + + patcher.model_options["transformer_options"]["dense_attention_mode"] = dense_attention_mode + patcher.model_options["transformer_options"]["dense_blocks"] = dense_blocks + patcher.model_options["transformer_options"]["dense_vace_blocks"] = dense_vace_blocks + patcher.model_options["transformer_options"]["dense_timesteps"] = dense_timesteps + patcher.model_options["transformer_options"]["decay_factor"] = decay_factor + patcher.model_options["transformer_options"]["block_size"] = block_size + + return (patcher,) + +class WanVideoBlockList: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "blocks": ("STRING", {"default": "1", "multiline":True}), + } + } + + RETURN_TYPES = ("INT",) + RETURN_NAMES = ("block_list", ) + FUNCTION = "create_list" + CATEGORY = "WanVideoWrapper" + DESCRIPTION = "Comma separated list of blocks to apply block swap to, can also use ranges like '0-5' or '0,2,3-5' etc., can be connected to the dense_blocks input of 'WanVideoSetRadialAttention' node" + + def create_list(self, blocks): + block_list = [] + for line in blocks.splitlines(): + for part in line.split(","): + part = part.strip() + if not part: + continue + if "-" in part: + try: + start, end = map(int, part.split("-", 1)) + block_list.extend(range(start, end + 1)) + except Exception: + raise ValueError(f"Invalid range: '{part}'") + else: + try: + block_list.append(int(part)) + except Exception: + raise ValueError(f"Invalid integer: '{part}'") + return (block_list,) + + + +# In-memory cache for prompt 
extender output +_extender_cache = {} + +cache_dir = os.path.join(script_directory, 'text_embed_cache') + +def get_cache_path(prompt): + cache_key = prompt.strip() + cache_hash = hashlib.sha256(cache_key.encode('utf-8')).hexdigest() + return os.path.join(cache_dir, f"{cache_hash}.pt") + +def get_cached_text_embeds(positive_prompt, negative_prompt): + + os.makedirs(cache_dir, exist_ok=True) + + context = None + context_null = None + + pos_cache_path = get_cache_path(positive_prompt) + neg_cache_path = get_cache_path(negative_prompt) + + # Try to load positive prompt embeds + if os.path.exists(pos_cache_path): + try: + log.info(f"Loading prompt embeds from cache: {pos_cache_path}") + context = torch.load(pos_cache_path) + except Exception as e: + log.warning(f"Failed to load cache: {e}, will re-encode.") + + # Try to load negative prompt embeds + if os.path.exists(neg_cache_path): + try: + log.info(f"Loading prompt embeds from cache: {neg_cache_path}") + context_null = torch.load(neg_cache_path) + except Exception as e: + log.warning(f"Failed to load cache: {e}, will re-encode.") + + return context, context_null + +class WanVideoTextEncodeCached: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "model_name": (folder_paths.get_filename_list("text_encoders"), {"tooltip": "These models are loaded from 'ComfyUI/models/text_encoders'"}), + "precision": (["fp32", "bf16"], + {"default": "bf16"} + ), + "positive_prompt": ("STRING", {"default": "", "multiline": True} ), + "negative_prompt": ("STRING", {"default": "", "multiline": True} ), + "quantization": (['disabled', 'fp8_e4m3fn'], {"default": 'disabled', "tooltip": "optional quantization method"}), + "use_disk_cache": ("BOOLEAN", {"default": True, "tooltip": "Cache the text embeddings to disk for faster re-use, under the custom_nodes/ComfyUI-WanVideoWrapper/text_embed_cache directory"}), + "device": (["gpu", "cpu"], {"default": "gpu", "tooltip": "Device to run the text encoding on."}), + }, + "optional": { + "extender_args": ("WANVIDEOPROMPTEXTENDER_ARGS", {"tooltip": "Use this node to extend the prompt with additional text."}), + } + } + + RETURN_TYPES = ("WANVIDEOTEXTEMBEDS", "WANVIDEOTEXTEMBEDS", "STRING") + RETURN_NAMES = ("text_embeds", "negative_text_embeds", "positive_prompt") + OUTPUT_TOOLTIPS = ("The text embeddings for both prompts", "The text embeddings for the negative prompt only (for NAG)", "Positive prompt to display prompt extender results") + FUNCTION = "process" + CATEGORY = "WanVideoWrapper" + DESCRIPTION = """Encodes text prompts into text embeddings. This node loads and completely unloads the T5 after done, +leaving no VRAM or RAM imprint. If prompts have been cached before T5 is not loaded at all. +negative output is meant to be used with NAG, it contains only negative prompt embeddings. + +Additionally you can provide a Qwen LLM model to extend the positive prompt with either one +of the original Wan templates or a custom system prompt. 
+""" + + + def process(self, model_name, precision, positive_prompt, negative_prompt, quantization='disabled', use_disk_cache=True, device="gpu", extender_args=None): + from .nodes_model_loading import LoadWanVideoT5TextEncoder + pbar = ProgressBar(3) + + echoshot = True if "[1]" in positive_prompt else False + + # Handle prompt extension with in-memory cache + orig_prompt = positive_prompt + if extender_args is not None: + extender_key = (orig_prompt, str(extender_args)) + if extender_key in _extender_cache: + positive_prompt = _extender_cache[extender_key] + log.info(f"Loaded extended prompt from in-memory cache: {positive_prompt}") + else: + from .qwen.qwen import QwenLoader, WanVideoPromptExtender + log.info("Using WanVideoPromptExtender to process prompts") + qwen, = QwenLoader().load( + extender_args["model"], + load_device="main_device" if device == "gpu" else "cpu", + precision=precision) + positive_prompt, = WanVideoPromptExtender().generate( + qwen=qwen, + max_new_tokens=extender_args["max_new_tokens"], + prompt=orig_prompt, + device=device, + force_offload=False, + custom_system_prompt=extender_args["system_prompt"], + seed=extender_args["seed"] + ) + log.info(f"Extended positive prompt: {positive_prompt}") + _extender_cache[extender_key] = positive_prompt + del qwen + pbar.update(1) + + # Now check disk cache using the (possibly extended) prompt + if use_disk_cache: + context, context_null = get_cached_text_embeds(positive_prompt, negative_prompt) + if context is not None and context_null is not None: + return{ + "prompt_embeds": context, + "negative_prompt_embeds": context_null, + "echoshot": echoshot, + },{"prompt_embeds": context_null}, positive_prompt + + t5, = LoadWanVideoT5TextEncoder().loadmodel(model_name, precision, "main_device", quantization) + pbar.update(1) + + prompt_embeds_dict, = WanVideoTextEncode().process( + positive_prompt=positive_prompt, + negative_prompt=negative_prompt, + t5=t5, + force_offload=False, + model_to_offload=None, + use_disk_cache=use_disk_cache, + device=device + ) + pbar.update(1) + del t5 + mm.soft_empty_cache() + gc.collect() + return (prompt_embeds_dict, {"prompt_embeds": prompt_embeds_dict["negative_prompt_embeds"]}, positive_prompt) + +#region TextEncode +class WanVideoTextEncode: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "positive_prompt": ("STRING", {"default": "", "multiline": True} ), + "negative_prompt": ("STRING", {"default": "", "multiline": True} ), + }, + "optional": { + "t5": ("WANTEXTENCODER",), + "force_offload": ("BOOLEAN", {"default": True}), + "model_to_offload": ("WANVIDEOMODEL", {"tooltip": "Model to move to offload_device before encoding"}), + "use_disk_cache": ("BOOLEAN", {"default": False, "tooltip": "Cache the text embeddings to disk for faster re-use, under the custom_nodes/ComfyUI-WanVideoWrapper/text_embed_cache directory"}), + "device": (["gpu", "cpu"], {"default": "gpu", "tooltip": "Device to run the text encoding on."}), + } + } + + RETURN_TYPES = ("WANVIDEOTEXTEMBEDS", ) + RETURN_NAMES = ("text_embeds",) + FUNCTION = "process" + CATEGORY = "WanVideoWrapper" + DESCRIPTION = "Encodes text prompts into text embeddings. For rudimentary prompt travel you can input multiple prompts separated by '|', they will be equally spread over the video length" + + + def process(self, positive_prompt, negative_prompt, t5=None, force_offload=True, model_to_offload=None, use_disk_cache=False, device="gpu"): + if t5 is None and not use_disk_cache: + raise ValueError("T5 encoder is required for text encoding. 
Please provide a valid T5 encoder or enable disk cache.") + + echoshot = True if "[1]" in positive_prompt else False + + if use_disk_cache: + context, context_null = get_cached_text_embeds(positive_prompt, negative_prompt) + if context is not None and context_null is not None: + return{ + "prompt_embeds": context, + "negative_prompt_embeds": context_null, + "echoshot": echoshot, + }, + + if t5 is None: + raise ValueError("No cached text embeds found for prompts, please provide a T5 encoder.") + + if model_to_offload is not None and device == "gpu": + log.info(f"Moving video model to {offload_device}") + model_to_offload.model.to(offload_device) + + encoder = t5["model"] + dtype = t5["dtype"] + + positive_prompts = [] + all_weights = [] + + # Split positive prompts and process each with weights + if "|" in positive_prompt: + log.info("Multiple positive prompts detected, splitting by '|'") + positive_prompts_raw = [p.strip() for p in positive_prompt.split('|')] + elif "[1]" in positive_prompt: + log.info("Multiple positive prompts detected, splitting by [#] and enabling EchoShot") + import re + segments = re.split(r'\[\d+\]', positive_prompt) + positive_prompts_raw = [segment.strip() for segment in segments if segment.strip()] + assert len(positive_prompts_raw) > 1 and len(positive_prompts_raw) < 7, 'Input shot num must between 2~6 !' + else: + positive_prompts_raw = [positive_prompt.strip()] + + for p in positive_prompts_raw: + cleaned_prompt, weights = self.parse_prompt_weights(p) + positive_prompts.append(cleaned_prompt) + all_weights.append(weights) + + mm.soft_empty_cache() + + if device == "gpu": + device_to = mm.get_torch_device() + else: + device_to = torch.device("cpu") + + if encoder.quantization == "fp8_e4m3fn": + cast_dtype = torch.float8_e4m3fn + else: + cast_dtype = encoder.dtype + + params_to_keep = {'norm', 'pos_embedding', 'token_embedding'} + for name, param in encoder.model.named_parameters(): + dtype_to_use = dtype if any(keyword in name for keyword in params_to_keep) else cast_dtype + value = encoder.state_dict[name] if hasattr(encoder, 'state_dict') else encoder.model.state_dict()[name] + set_module_tensor_to_device(encoder.model, name, device=device_to, dtype=dtype_to_use, value=value) + if hasattr(encoder, 'state_dict'): + del encoder.state_dict + mm.soft_empty_cache() + gc.collect() + + with torch.autocast(device_type=mm.get_autocast_device(device_to), dtype=encoder.dtype, enabled=encoder.quantization != 'disabled'): + # Encode positive if not loaded from cache + if use_disk_cache and context is not None: + pass + else: + context = encoder(positive_prompts, device_to) + # Apply weights to embeddings if any were extracted + for i, weights in enumerate(all_weights): + for text, weight in weights.items(): + log.info(f"Applying weight {weight} to prompt: {text}") + if len(weights) > 0: + context[i] = context[i] * weight + + # Encode negative if not loaded from cache + if use_disk_cache and context_null is not None: + pass + else: + context_null = encoder([negative_prompt], device_to) + + if force_offload: + encoder.model.to(offload_device) + mm.soft_empty_cache() + gc.collect() + + prompt_embeds_dict = { + "prompt_embeds": context, + "negative_prompt_embeds": context_null, + "echoshot": echoshot, + } + + # Save each part to its own cache file if needed + if use_disk_cache: + pos_cache_path = get_cache_path(positive_prompt) + neg_cache_path = get_cache_path(negative_prompt) + try: + if not os.path.exists(pos_cache_path): + torch.save(context, pos_cache_path) + 
log.info(f"Saved prompt embeds to cache: {pos_cache_path}") + except Exception as e: + log.warning(f"Failed to save cache: {e}") + try: + if not os.path.exists(neg_cache_path): + torch.save(context_null, neg_cache_path) + log.info(f"Saved prompt embeds to cache: {neg_cache_path}") + except Exception as e: + log.warning(f"Failed to save cache: {e}") + + return (prompt_embeds_dict,) + + def parse_prompt_weights(self, prompt): + """Extract text and weights from prompts with (text:weight) format""" + import re + + # Parse all instances of (text:weight) in the prompt + pattern = r'\((.*?):([\d\.]+)\)' + matches = re.findall(pattern, prompt) + + # Replace each match with just the text part + cleaned_prompt = prompt + weights = {} + + for match in matches: + text, weight = match + orig_text = f"({text}:{weight})" + cleaned_prompt = cleaned_prompt.replace(orig_text, text) + weights[text] = float(weight) + + return cleaned_prompt, weights + +class WanVideoTextEncodeSingle: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "prompt": ("STRING", {"default": "", "multiline": True} ), + }, + "optional": { + "t5": ("WANTEXTENCODER",), + "force_offload": ("BOOLEAN", {"default": True}), + "model_to_offload": ("WANVIDEOMODEL", {"tooltip": "Model to move to offload_device before encoding"}), + "use_disk_cache": ("BOOLEAN", {"default": False, "tooltip": "Cache the text embeddings to disk for faster re-use, under the custom_nodes/ComfyUI-WanVideoWrapper/text_embed_cache directory"}), + "device": (["gpu", "cpu"], {"default": "gpu", "tooltip": "Device to run the text encoding on."}), + } + } + + RETURN_TYPES = ("WANVIDEOTEXTEMBEDS", ) + RETURN_NAMES = ("text_embeds",) + FUNCTION = "process" + CATEGORY = "WanVideoWrapper" + DESCRIPTION = "Encodes text prompt into text embedding." 
+ + def process(self, prompt, t5=None, force_offload=True, model_to_offload=None, use_disk_cache=False, device="gpu"): + # Unified cache logic: use a single cache file per unique prompt + encoded = None + echoshot = True if "[1]" in prompt else False + if use_disk_cache: + cache_dir = os.path.join(script_directory, 'text_embed_cache') + os.makedirs(cache_dir, exist_ok=True) + def get_cache_path(prompt): + cache_key = prompt.strip() + cache_hash = hashlib.sha256(cache_key.encode('utf-8')).hexdigest() + return os.path.join(cache_dir, f"{cache_hash}.pt") + cache_path = get_cache_path(prompt) + if os.path.exists(cache_path): + try: + log.info(f"Loading prompt embeds from cache: {cache_path}") + encoded = torch.load(cache_path) + except Exception as e: + log.warning(f"Failed to load cache: {e}, will re-encode.") + + if t5 is None and encoded is None: + raise ValueError("No cached text embeds found for prompts, please provide a T5 encoder.") + + if encoded is None: + if model_to_offload is not None and device == "gpu": + log.info(f"Moving video model to {offload_device}") + model_to_offload.model.to(offload_device) + mm.soft_empty_cache() + + encoder = t5["model"] + dtype = t5["dtype"] + + if device == "gpu": + device_to = mm.get_torch_device() + else: + device_to = torch.device("cpu") + + if encoder.quantization == "fp8_e4m3fn": + cast_dtype = torch.float8_e4m3fn + else: + cast_dtype = encoder.dtype + params_to_keep = {'norm', 'pos_embedding', 'token_embedding'} + for name, param in encoder.model.named_parameters(): + dtype_to_use = dtype if any(keyword in name for keyword in params_to_keep) else cast_dtype + value = encoder.state_dict[name] if hasattr(encoder, 'state_dict') else encoder.model.state_dict()[name] + set_module_tensor_to_device(encoder.model, name, device=device_to, dtype=dtype_to_use, value=value) + if hasattr(encoder, 'state_dict'): + del encoder.state_dict + mm.soft_empty_cache() + gc.collect() + with torch.autocast(device_type=mm.get_autocast_device(device_to), dtype=encoder.dtype, enabled=encoder.quantization != 'disabled'): + encoded = encoder([prompt], device_to) + + if force_offload: + encoder.model.to(offload_device) + mm.soft_empty_cache() + + # Save to cache if enabled + if use_disk_cache: + try: + if not os.path.exists(cache_path): + torch.save(encoded, cache_path) + log.info(f"Saved prompt embeds to cache: {cache_path}") + except Exception as e: + log.warning(f"Failed to save cache: {e}") + + prompt_embeds_dict = { + "prompt_embeds": encoded, + "negative_prompt_embeds": None, + "echoshot": echoshot + } + return (prompt_embeds_dict,) + +class WanVideoApplyNAG: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "original_text_embeds": ("WANVIDEOTEXTEMBEDS",), + "nag_text_embeds": ("WANVIDEOTEXTEMBEDS",), + "nag_scale": ("FLOAT", {"default": 11.0, "min": 0.0, "max": 100.0, "step": 0.1}), + "nag_tau": ("FLOAT", {"default": 2.5, "min": 0.0, "max": 10.0, "step": 0.1}), + "nag_alpha": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.01}), + }, + } + + RETURN_TYPES = ("WANVIDEOTEXTEMBEDS", ) + RETURN_NAMES = ("text_embeds",) + FUNCTION = "process" + CATEGORY = "WanVideoWrapper" + DESCRIPTION = "Adds NAG prompt embeds to original prompt embeds: 'https://github.com/ChenDarYen/Normalized-Attention-Guidance'" + + def process(self, original_text_embeds, nag_text_embeds, nag_scale, nag_tau, nag_alpha): + prompt_embeds_dict_copy = original_text_embeds.copy() + prompt_embeds_dict_copy.update({ + "nag_prompt_embeds": nag_text_embeds["prompt_embeds"], + 
"nag_params": { + "nag_scale": nag_scale, + "nag_tau": nag_tau, + "nag_alpha": nag_alpha, + } + }) + return (prompt_embeds_dict_copy,) + +class WanVideoTextEmbedBridge: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "positive": ("CONDITIONING",), + }, + "optional": { + "negative": ("CONDITIONING",), + } + } + + RETURN_TYPES = ("WANVIDEOTEXTEMBEDS", ) + RETURN_NAMES = ("text_embeds",) + FUNCTION = "process" + CATEGORY = "WanVideoWrapper" + DESCRIPTION = "Bridge between ComfyUI native text embedding and WanVideoWrapper text embedding" + + def process(self, positive, negative=None): + prompt_embeds_dict = { + "prompt_embeds": positive[0][0].to(device), + "negative_prompt_embeds": negative[0][0].to(device) if negative is not None else None, + } + return (prompt_embeds_dict,) + +#region clip vision +class WanVideoClipVisionEncode: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "clip_vision": ("CLIP_VISION",), + "image_1": ("IMAGE", {"tooltip": "Image to encode"}), + "strength_1": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Additional clip embed multiplier"}), + "strength_2": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Additional clip embed multiplier"}), + "crop": (["center", "disabled"], {"default": "center", "tooltip": "Crop image to 224x224 before encoding"}), + "combine_embeds": (["average", "sum", "concat", "batch"], {"default": "average", "tooltip": "Method to combine multiple clip embeds"}), + "force_offload": ("BOOLEAN", {"default": True}), + }, + "optional": { + "image_2": ("IMAGE", ), + "negative_image": ("IMAGE", {"tooltip": "image to use for uncond"}), + "tiles": ("INT", {"default": 0, "min": 0, "max": 16, "step": 2, "tooltip": "Use matteo's tiled image encoding for improved accuracy"}), + "ratio": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Ratio of the tile average"}), + } + } + + RETURN_TYPES = ("WANVIDIMAGE_CLIPEMBEDS",) + RETURN_NAMES = ("image_embeds",) + FUNCTION = "process" + CATEGORY = "WanVideoWrapper" + + def process(self, clip_vision, image_1, strength_1, strength_2, force_offload, crop, combine_embeds, image_2=None, negative_image=None, tiles=0, ratio=1.0): + image_mean = [0.48145466, 0.4578275, 0.40821073] + image_std = [0.26862954, 0.26130258, 0.27577711] + + if image_2 is not None: + image = torch.cat([image_1, image_2], dim=0) + else: + image = image_1 + + clip_vision.model.to(device) + + negative_clip_embeds = None + + if tiles > 0: + log.info("Using tiled image encoding") + clip_embeds = clip_encode_image_tiled(clip_vision, image.to(device), tiles=tiles, ratio=ratio) + if negative_image is not None: + negative_clip_embeds = clip_encode_image_tiled(clip_vision, negative_image.to(device), tiles=tiles, ratio=ratio) + else: + if isinstance(clip_vision, ClipVisionModel): + clip_embeds = clip_vision.encode_image(image).penultimate_hidden_states.to(device) + if negative_image is not None: + negative_clip_embeds = clip_vision.encode_image(negative_image).penultimate_hidden_states.to(device) + else: + pixel_values = clip_preprocess(image.to(device), size=224, mean=image_mean, std=image_std, crop=(not crop == "disabled")).float() + clip_embeds = clip_vision.visual(pixel_values) + if negative_image is not None: + pixel_values = clip_preprocess(negative_image.to(device), size=224, mean=image_mean, std=image_std, crop=(not crop == "disabled")).float() + negative_clip_embeds = clip_vision.visual(pixel_values) + + log.info(f"Clip embeds shape: 
{clip_embeds.shape}, dtype: {clip_embeds.dtype}") + + weighted_embeds = [] + weighted_embeds.append(clip_embeds[0:1] * strength_1) + + # Handle all additional embeddings + if clip_embeds.shape[0] > 1: + weighted_embeds.append(clip_embeds[1:2] * strength_2) + + if clip_embeds.shape[0] > 2: + for i in range(2, clip_embeds.shape[0]): + weighted_embeds.append(clip_embeds[i:i+1]) # Add as-is without strength modifier + + # Combine all weighted embeddings + if combine_embeds == "average": + clip_embeds = torch.mean(torch.stack(weighted_embeds), dim=0) + elif combine_embeds == "sum": + clip_embeds = torch.sum(torch.stack(weighted_embeds), dim=0) + elif combine_embeds == "concat": + clip_embeds = torch.cat(weighted_embeds, dim=1) + elif combine_embeds == "batch": + clip_embeds = torch.cat(weighted_embeds, dim=0) + else: + clip_embeds = weighted_embeds[0] + + + log.info(f"Combined clip embeds shape: {clip_embeds.shape}") + + if force_offload: + clip_vision.model.to(offload_device) + mm.soft_empty_cache() + + clip_embeds_dict = { + "clip_embeds": clip_embeds, + "negative_clip_embeds": negative_clip_embeds + } + + return (clip_embeds_dict,) + +class WanVideoRealisDanceLatents: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "ref_latent": ("LATENT", {"tooltip": "Reference image to encode"}), + "pose_cond_start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percent of the SMPL model"}), + "pose_cond_end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percent of the SMPL model"}), + }, + "optional": { + "smpl_latent": ("LATENT", {"tooltip": "SMPL pose image to encode"}), + "hamer_latent": ("LATENT", {"tooltip": "Hamer hand pose image to encode"}), + }, + } + + RETURN_TYPES = ("ADD_COND_LATENTS",) + RETURN_NAMES = ("add_cond_latents",) + FUNCTION = "process" + CATEGORY = "WanVideoWrapper" + + def process(self, ref_latent, pose_cond_start_percent, pose_cond_end_percent, hamer_latent=None, smpl_latent=None): + if smpl_latent is None and hamer_latent is None: + raise Exception("At least one of smpl_latent or hamer_latent must be provided") + if smpl_latent is None: + smpl = torch.zeros_like(hamer_latent["samples"]) + else: + smpl = smpl_latent["samples"] + if hamer_latent is None: + hamer = torch.zeros_like(smpl_latent["samples"]) + else: + hamer = hamer_latent["samples"] + + pose_latent = torch.cat((smpl, hamer), dim=1) + + add_cond_latents = { + "ref_latent": ref_latent["samples"], + "pose_latent": pose_latent, + "pose_cond_start_percent": pose_cond_start_percent, + "pose_cond_end_percent": pose_cond_end_percent, + } + + return (add_cond_latents,) + +class WanVideoImageToVideoEncode: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "vae": ("WANVAE",), + "width": ("INT", {"default": 832, "min": 64, "max": 8096, "step": 8, "tooltip": "Width of the image to encode"}), + "height": ("INT", {"default": 480, "min": 64, "max": 8096, "step": 8, "tooltip": "Height of the image to encode"}), + "num_frames": ("INT", {"default": 81, "min": 1, "max": 10000, "step": 4, "tooltip": "Number of frames to encode"}), + "noise_aug_strength": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Strength of noise augmentation, helpful for I2V where some noise can add motion and give sharper results"}), + "start_latent_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Additional latent multiplier, helpful for I2V where lower values allow for more 
motion"}), + "end_latent_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Additional latent multiplier, helpful for I2V where lower values allow for more motion"}), + "force_offload": ("BOOLEAN", {"default": True}), + }, + "optional": { + "clip_embeds": ("WANVIDIMAGE_CLIPEMBEDS", {"tooltip": "Clip vision encoded image"}), + "start_image": ("IMAGE", {"tooltip": "Image to encode"}), + "end_image": ("IMAGE", {"tooltip": "end frame"}), + "control_embeds": ("WANVIDIMAGE_EMBEDS", {"tooltip": "Control signal for the Fun -model"}), + "fun_or_fl2v_model": ("BOOLEAN", {"default": True, "tooltip": "Enable when using official FLF2V or Fun model"}), + "temporal_mask": ("MASK", {"tooltip": "mask"}), + "extra_latents": ("LATENT", {"tooltip": "Extra latents to add to the input front, used for Skyreels A2 reference images"}), + "tiled_vae": ("BOOLEAN", {"default": False, "tooltip": "Use tiled VAE encoding for reduced memory use"}), + "add_cond_latents": ("ADD_COND_LATENTS", {"advanced": True, "tooltip": "Additional cond latents WIP"}), + } + } + + RETURN_TYPES = ("WANVIDIMAGE_EMBEDS",) + RETURN_NAMES = ("image_embeds",) + FUNCTION = "process" + CATEGORY = "WanVideoWrapper" + + def process(self, vae, width, height, num_frames, force_offload, noise_aug_strength, + start_latent_strength, end_latent_strength, start_image=None, end_image=None, control_embeds=None, fun_or_fl2v_model=False, + temporal_mask=None, extra_latents=None, clip_embeds=None, tiled_vae=False, add_cond_latents=None): + + H = height + W = width + + lat_h = H // 8 + lat_w = W // 8 + + num_frames = ((num_frames - 1) // 4) * 4 + 1 + two_ref_images = start_image is not None and end_image is not None + + base_frames = num_frames + (1 if two_ref_images and not fun_or_fl2v_model else 0) + if temporal_mask is None: + mask = torch.zeros(1, base_frames, lat_h, lat_w, device=device) + if start_image is not None: + mask[:, 0:start_image.shape[0]] = 1 # First frame + if end_image is not None: + mask[:, -end_image.shape[0]:] = 1 # End frame if exists + else: + mask = common_upscale(temporal_mask.unsqueeze(1).to(device), lat_w, lat_h, "nearest", "disabled").squeeze(1) + if mask.shape[0] > base_frames: + mask = mask[:base_frames] + elif mask.shape[0] < base_frames: + mask = torch.cat([mask, torch.zeros(base_frames - mask.shape[0], lat_h, lat_w, device=device)]) + mask = mask.unsqueeze(0).to(device) + + # Repeat first frame and optionally end frame + start_mask_repeated = torch.repeat_interleave(mask[:, 0:1], repeats=4, dim=1) # T, C, H, W + if end_image is not None and not fun_or_fl2v_model: + end_mask_repeated = torch.repeat_interleave(mask[:, -1:], repeats=4, dim=1) # T, C, H, W + mask = torch.cat([start_mask_repeated, mask[:, 1:-1], end_mask_repeated], dim=1) + else: + mask = torch.cat([start_mask_repeated, mask[:, 1:]], dim=1) + + # Reshape mask into groups of 4 frames + mask = mask.view(1, mask.shape[1] // 4, 4, lat_h, lat_w) # 1, T, C, H, W + mask = mask.movedim(1, 2)[0]# C, T, H, W + + # Resize and rearrange the input image dimensions + if start_image is not None: + resized_start_image = common_upscale(start_image.movedim(-1, 1), W, H, "lanczos", "disabled").movedim(0, 1) + resized_start_image = resized_start_image * 2 - 1 + if noise_aug_strength > 0.0: + resized_start_image = add_noise_to_reference_video(resized_start_image, ratio=noise_aug_strength) + + if end_image is not None: + resized_end_image = common_upscale(end_image.movedim(-1, 1), W, H, "lanczos", "disabled").movedim(0, 1) + resized_end_image = 
resized_end_image * 2 - 1 + if noise_aug_strength > 0.0: + resized_end_image = add_noise_to_reference_video(resized_end_image, ratio=noise_aug_strength) + + # Concatenate image with zero frames and encode + vae.to(device) + + if temporal_mask is None: + if start_image is not None and end_image is None: + zero_frames = torch.zeros(3, num_frames-start_image.shape[0], H, W, device=device) + concatenated = torch.cat([resized_start_image.to(device), zero_frames], dim=1) + elif start_image is None and end_image is not None: + zero_frames = torch.zeros(3, num_frames-end_image.shape[0], H, W, device=device) + concatenated = torch.cat([zero_frames, resized_end_image.to(device)], dim=1) + elif start_image is None and end_image is None: + concatenated = torch.zeros(3, num_frames, H, W, device=device) + else: + if fun_or_fl2v_model: + zero_frames = torch.zeros(3, num_frames-(start_image.shape[0]+end_image.shape[0]), H, W, device=device) + else: + zero_frames = torch.zeros(3, num_frames-1, H, W, device=device) + concatenated = torch.cat([resized_start_image.to(device), zero_frames, resized_end_image.to(device)], dim=1) + else: + temporal_mask = common_upscale(temporal_mask.unsqueeze(1), W, H, "nearest", "disabled").squeeze(1) + concatenated = resized_start_image[:,:num_frames] * temporal_mask[:num_frames].unsqueeze(0) + + y = vae.encode([concatenated.to(device=device, dtype=vae.dtype)], device, end_=(end_image is not None and not fun_or_fl2v_model),tiled=tiled_vae)[0] + has_ref = False + if extra_latents is not None: + samples = extra_latents["samples"].squeeze(0) + y = torch.cat([samples, y], dim=1) + mask = torch.cat([torch.ones_like(mask[:, 0:samples.shape[1]]), mask], dim=1) + num_frames += samples.shape[1] * 4 + has_ref = True + y[:, :1] *= start_latent_strength + y[:, -1:] *= end_latent_strength + if control_embeds is None: + y = torch.cat([mask, y]) + else: + if end_image is None: + y[:, 1:] = 0 + elif start_image is None: + y[:, -1:] = 0 + + # Calculate maximum sequence length + patches_per_frame = lat_h * lat_w // (PATCH_SIZE[1] * PATCH_SIZE[2]) + frames_per_stride = (num_frames - 1) // 4 + (2 if end_image is not None and not fun_or_fl2v_model else 1) + max_seq_len = frames_per_stride * patches_per_frame + + if add_cond_latents is not None: + add_cond_latents["ref_latent_neg"] = vae.encode(torch.zeros(1, 3, 1, H, W, device=device, dtype=vae.dtype), device) + + vae.model.clear_cache() + if force_offload: + vae.model.to(offload_device) + mm.soft_empty_cache() + gc.collect() + + image_embeds = { + "image_embeds": y, + "clip_context": clip_embeds.get("clip_embeds", None) if clip_embeds is not None else None, + "negative_clip_context": clip_embeds.get("negative_clip_embeds", None) if clip_embeds is not None else None, + "max_seq_len": max_seq_len, + "num_frames": num_frames, + "lat_h": lat_h, + "lat_w": lat_w, + "control_embeds": control_embeds["control_embeds"] if control_embeds is not None else None, + "end_image": resized_end_image if end_image is not None else None, + "fun_or_fl2v_model": fun_or_fl2v_model, + "has_ref": has_ref, + "add_cond_latents": add_cond_latents, + "mask": mask if control_embeds is not None else None, # for 2.2 Fun control as it can handle masks + } + + return (image_embeds,) + +class WanVideoEmptyEmbeds: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "width": ("INT", {"default": 832, "min": 64, "max": 8096, "step": 8, "tooltip": "Width of the image to encode"}), + "height": ("INT", {"default": 480, "min": 64, "max": 8096, "step": 8, "tooltip": "Height of 
the image to encode"}), + "num_frames": ("INT", {"default": 81, "min": 1, "max": 10000, "step": 4, "tooltip": "Number of frames to encode"}), + }, + "optional": { + "control_embeds": ("WANVIDIMAGE_EMBEDS", {"tooltip": "control signal for the Fun -model"}), + "extra_latents": ("LATENT", {"tooltip": "First latent to use for the Pusa -model"}), + } + } + + RETURN_TYPES = ("WANVIDIMAGE_EMBEDS", ) + RETURN_NAMES = ("image_embeds",) + FUNCTION = "process" + CATEGORY = "WanVideoWrapper" + + def process(self, num_frames, width, height, control_embeds=None, extra_latents=None): + target_shape = (16, (num_frames - 1) // VAE_STRIDE[0] + 1, + height // VAE_STRIDE[1], + width // VAE_STRIDE[2]) + + embeds = { + "target_shape": target_shape, + "num_frames": num_frames, + "control_embeds": control_embeds["control_embeds"] if control_embeds is not None else None, + } + if extra_latents is not None: + embeds["extra_latents"] = [{ + "samples": extra_latents["samples"], + "index": 0, + }] + + return (embeds,) + +class WanVideoAddExtraLatent: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "embeds": ("WANVIDIMAGE_EMBEDS",), + "extra_latents": ("LATENT",), + "latent_index": ("INT", {"default": 0, "min": -1000, "max": 1000, "step": 1, "tooltip": "Index to insert the extra latents at in latent space"}), + } + } + + RETURN_TYPES = ("WANVIDIMAGE_EMBEDS",) + RETURN_NAMES = ("image_embeds",) + FUNCTION = "add" + CATEGORY = "WanVideoWrapper" + + def add(self, embeds, extra_latents, latent_index): + # Prepare the new extra latent entry + new_entry = { + "samples": extra_latents["samples"], + "index": latent_index, + } + # Get previous extra_latents list, or start a new one + prev_extra_latents = embeds.get("extra_latents", None) + if prev_extra_latents is None: + extra_latents_list = [new_entry] + elif isinstance(prev_extra_latents, list): + extra_latents_list = prev_extra_latents + [new_entry] + else: + extra_latents_list = [prev_extra_latents, new_entry] + + # Return a new dict with updated extra_latents + updated = dict(embeds) + updated["extra_latents"] = extra_latents_list + return (updated,) + +class WanVideoMiniMaxRemoverEmbeds: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "width": ("INT", {"default": 832, "min": 64, "max": 8096, "step": 8, "tooltip": "Width of the image to encode"}), + "height": ("INT", {"default": 480, "min": 64, "max": 8096, "step": 8, "tooltip": "Height of the image to encode"}), + "num_frames": ("INT", {"default": 81, "min": 1, "max": 10000, "step": 4, "tooltip": "Number of frames to encode"}), + "latents": ("LATENT", {"tooltip": "Encoded latents to use as control signals"}), + "mask_latents": ("LATENT", {"tooltip": "Encoded latents to use as mask"}), + }, + } + + RETURN_TYPES = ("WANVIDIMAGE_EMBEDS", ) + RETURN_NAMES = ("image_embeds",) + FUNCTION = "process" + CATEGORY = "WanVideoWrapper" + + def process(self, num_frames, width, height, latents, mask_latents): + target_shape = (16, (num_frames - 1) // VAE_STRIDE[0] + 1, + height // VAE_STRIDE[1], + width // VAE_STRIDE[2]) + + embeds = { + "target_shape": target_shape, + "num_frames": num_frames, + "minimax_latents": latents["samples"].squeeze(0), + "minimax_mask_latents": mask_latents["samples"].squeeze(0), + } + + return (embeds,) + +# region phantom +class WanVideoPhantomEmbeds: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "num_frames": ("INT", {"default": 81, "min": 1, "max": 10000, "step": 4, "tooltip": "Number of frames to encode"}), + "phantom_latent_1": ("LATENT", {"tooltip": 
"reference latents for the phantom model"}), + + "phantom_cfg_scale": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "CFG scale for the extra phantom cond pass"}), + "phantom_start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percent of the phantom model"}), + "phantom_end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percent of the phantom model"}), + }, + "optional": { + "phantom_latent_2": ("LATENT", {"tooltip": "reference latents for the phantom model"}), + "phantom_latent_3": ("LATENT", {"tooltip": "reference latents for the phantom model"}), + "phantom_latent_4": ("LATENT", {"tooltip": "reference latents for the phantom model"}), + "vace_embeds": ("WANVIDIMAGE_EMBEDS", {"tooltip": "VACE embeds"}), + } + } + + RETURN_TYPES = ("WANVIDIMAGE_EMBEDS", ) + RETURN_NAMES = ("image_embeds",) + FUNCTION = "process" + CATEGORY = "WanVideoWrapper" + + def process(self, num_frames, phantom_cfg_scale, phantom_start_percent, phantom_end_percent, phantom_latent_1, phantom_latent_2=None, phantom_latent_3=None, phantom_latent_4=None, vace_embeds=None): + samples = phantom_latent_1["samples"].squeeze(0) + if phantom_latent_2 is not None: + samples = torch.cat([samples, phantom_latent_2["samples"].squeeze(0)], dim=1) + if phantom_latent_3 is not None: + samples = torch.cat([samples, phantom_latent_3["samples"].squeeze(0)], dim=1) + if phantom_latent_4 is not None: + samples = torch.cat([samples, phantom_latent_4["samples"].squeeze(0)], dim=1) + C, T, H, W = samples.shape + + log.info(f"Phantom latents shape: {samples.shape}") + + target_shape = (16, (num_frames - 1) // VAE_STRIDE[0] + 1 + T, + H * 8 // VAE_STRIDE[1], + W * 8 // VAE_STRIDE[2]) + + embeds = { + "target_shape": target_shape, + "num_frames": num_frames, + "phantom_latents": samples, + "phantom_cfg_scale": phantom_cfg_scale, + "phantom_start_percent": phantom_start_percent, + "phantom_end_percent": phantom_end_percent, + } + if vace_embeds is not None: + vace_input = { + "vace_context": vace_embeds["vace_context"], + "vace_scale": vace_embeds["vace_scale"], + "has_ref": vace_embeds["has_ref"], + "vace_start_percent": vace_embeds["vace_start_percent"], + "vace_end_percent": vace_embeds["vace_end_percent"], + "vace_seq_len": vace_embeds["vace_seq_len"], + "additional_vace_inputs": vace_embeds["additional_vace_inputs"], + } + embeds.update(vace_input) + + return (embeds,) + +class WanVideoControlEmbeds: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "latents": ("LATENT", {"tooltip": "Encoded latents to use as control signals"}), + "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percent of the control signal"}), + "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percent of the control signal"}), + }, + "optional": { + "fun_ref_image": ("LATENT", {"tooltip": "Reference latent for the Fun 1.1 -model"}), + } + } + + RETURN_TYPES = ("WANVIDIMAGE_EMBEDS", ) + RETURN_NAMES = ("image_embeds",) + FUNCTION = "process" + CATEGORY = "WanVideoWrapper" + + def process(self, latents, start_percent, end_percent, fun_ref_image=None): + + samples = latents["samples"].squeeze(0) + C, T, H, W = samples.shape + + num_frames = (T - 1) * 4 + 1 + seq_len = math.ceil((H * W) / 4 * ((num_frames - 1) // 4 + 1)) + + embeds = { + "max_seq_len": seq_len, + "target_shape": samples.shape, + "num_frames": num_frames, + "control_embeds": { + 
"control_images": samples, + "start_percent": start_percent, + "end_percent": end_percent, + "fun_ref_image": fun_ref_image["samples"][:,:, 0] if fun_ref_image is not None else None, + } + } + + return (embeds,) + +class WanVideoSLG: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "blocks": ("STRING", {"default": "10", "tooltip": "Blocks to skip uncond on, separated by comma, index starts from 0"}), + "start_percent": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percent of the control signal"}), + "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percent of the control signal"}), + }, + } + + RETURN_TYPES = ("SLGARGS", ) + RETURN_NAMES = ("slg_args",) + FUNCTION = "process" + CATEGORY = "WanVideoWrapper" + DESCRIPTION = "Skips uncond on the selected blocks" + + def process(self, blocks, start_percent, end_percent): + slg_block_list = [int(x.strip()) for x in blocks.split(",")] + + slg_args = { + "blocks": slg_block_list, + "start_percent": start_percent, + "end_percent": end_percent, + } + return (slg_args,) + +#region VACE +class WanVideoVACEEncode: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "vae": ("WANVAE",), + "width": ("INT", {"default": 832, "min": 64, "max": 8096, "step": 8, "tooltip": "Width of the image to encode"}), + "height": ("INT", {"default": 480, "min": 64, "max": 8096, "step": 8, "tooltip": "Height of the image to encode"}), + "num_frames": ("INT", {"default": 81, "min": 1, "max": 10000, "step": 4, "tooltip": "Number of frames to encode"}), + "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}), + "vace_start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percent of the steps to apply VACE"}), + "vace_end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percent of the steps to apply VACE"}), + }, + "optional": { + "input_frames": ("IMAGE",), + "ref_images": ("IMAGE",), + "input_masks": ("MASK",), + "prev_vace_embeds": ("WANVIDIMAGE_EMBEDS",), + "tiled_vae": ("BOOLEAN", {"default": False, "tooltip": "Use tiled VAE encoding for reduced memory use"}), + }, + } + + RETURN_TYPES = ("WANVIDIMAGE_EMBEDS", ) + RETURN_NAMES = ("vace_embeds",) + FUNCTION = "process" + CATEGORY = "WanVideoWrapper" + + def process(self, vae, width, height, num_frames, strength, vace_start_percent, vace_end_percent, input_frames=None, ref_images=None, input_masks=None, prev_vace_embeds=None, tiled_vae=False): + vae = vae.to(device) + + width = (width // 16) * 16 + height = (height // 16) * 16 + + target_shape = (16, (num_frames - 1) // VAE_STRIDE[0] + 1, + height // VAE_STRIDE[1], + width // VAE_STRIDE[2]) + # vace context encode + if input_frames is None: + input_frames = torch.zeros((1, 3, num_frames, height, width), device=device, dtype=vae.dtype) + else: + input_frames = input_frames[:num_frames] + input_frames = common_upscale(input_frames.clone().movedim(-1, 1), width, height, "lanczos", "disabled").movedim(1, -1) + input_frames = input_frames.to(vae.dtype).to(device).unsqueeze(0).permute(0, 4, 1, 2, 3) # B, C, T, H, W + input_frames = input_frames * 2 - 1 + if input_masks is None: + input_masks = torch.ones_like(input_frames, device=device) + else: + log.info(f"input_masks shape: {input_masks.shape}") + input_masks = input_masks[:num_frames] + input_masks = common_upscale(input_masks.clone().unsqueeze(1), width, height, "nearest-exact", "disabled").squeeze(1) + 
input_masks = input_masks.to(vae.dtype).to(device) + input_masks = input_masks.unsqueeze(-1).unsqueeze(0).permute(0, 4, 1, 2, 3).repeat(1, 3, 1, 1, 1) # B, C, T, H, W + + if ref_images is not None: + # Create padded image + if ref_images.shape[0] > 1: + ref_images = torch.cat([ref_images[i] for i in range(ref_images.shape[0])], dim=1).unsqueeze(0) + + B, H, W, C = ref_images.shape + current_aspect = W / H + target_aspect = width / height + if current_aspect > target_aspect: + # Image is wider than target, pad height + new_h = int(W / target_aspect) + pad_h = (new_h - H) // 2 + padded = torch.ones(ref_images.shape[0], new_h, W, ref_images.shape[3], device=ref_images.device, dtype=ref_images.dtype) + padded[:, pad_h:pad_h+H, :, :] = ref_images + ref_images = padded + elif current_aspect < target_aspect: + # Image is taller than target, pad width + new_w = int(H * target_aspect) + pad_w = (new_w - W) // 2 + padded = torch.ones(ref_images.shape[0], H, new_w, ref_images.shape[3], device=ref_images.device, dtype=ref_images.dtype) + padded[:, :, pad_w:pad_w+W, :] = ref_images + ref_images = padded + ref_images = common_upscale(ref_images.movedim(-1, 1), width, height, "lanczos", "center").movedim(1, -1) + + ref_images = ref_images.to(vae.dtype).to(device).unsqueeze(0).permute(0, 4, 1, 2, 3).unsqueeze(0) + ref_images = ref_images * 2 - 1 + + z0 = self.vace_encode_frames(vae, input_frames, ref_images, masks=input_masks, tiled_vae=tiled_vae) + vae.model.clear_cache() + m0 = self.vace_encode_masks(input_masks, ref_images) + z = self.vace_latent(z0, m0) + + vae.to(offload_device) + + vace_input = { + "vace_context": z, + "vace_scale": strength, + "has_ref": ref_images is not None, + "num_frames": num_frames, + "target_shape": target_shape, + "vace_start_percent": vace_start_percent, + "vace_end_percent": vace_end_percent, + "vace_seq_len": math.ceil((z[0].shape[2] * z[0].shape[3]) / 4 * z[0].shape[1]), + "additional_vace_inputs": [], + } + + if prev_vace_embeds is not None: + if "additional_vace_inputs" in prev_vace_embeds and prev_vace_embeds["additional_vace_inputs"]: + vace_input["additional_vace_inputs"] = prev_vace_embeds["additional_vace_inputs"].copy() + vace_input["additional_vace_inputs"].append(prev_vace_embeds) + + return (vace_input,) + def vace_encode_frames(self, vae, frames, ref_images, masks=None, tiled_vae=False): + if ref_images is None: + ref_images = [None] * len(frames) + else: + assert len(frames) == len(ref_images) + + if masks is None: + latents = vae.encode(frames, device=device, tiled=tiled_vae) + else: + inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)] + reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)] + inactive = vae.encode(inactive, device=device, tiled=tiled_vae) + reactive = vae.encode(reactive, device=device, tiled=tiled_vae) + latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)] + vae.model.clear_cache() + cat_latents = [] + + pbar = ProgressBar(len(frames)) + for latent, refs in zip(latents, ref_images): + if refs is not None: + if masks is None: + ref_latent = vae.encode(refs, device=device, tiled=tiled_vae) + else: + ref_latent = vae.encode(refs, device=device, tiled=tiled_vae) + ref_latent = [torch.cat((u, torch.zeros_like(u)), dim=0) for u in ref_latent] + assert all([x.shape[1] == 1 for x in ref_latent]) + latent = torch.cat([*ref_latent, latent], dim=1) + cat_latents.append(latent) + pbar.update(1) + return cat_latents + + def vace_encode_masks(self, masks, ref_images=None): + if ref_images is None: + ref_images 
= [None] * len(masks) + else: + assert len(masks) == len(ref_images) + + result_masks = [] + pbar = ProgressBar(len(masks)) + for mask, refs in zip(masks, ref_images): + _c, depth, height, width = mask.shape + new_depth = int((depth + 3) // VAE_STRIDE[0]) + height = 2 * (int(height) // (VAE_STRIDE[1] * 2)) + width = 2 * (int(width) // (VAE_STRIDE[2] * 2)) + + # reshape + mask = mask[0, :, :, :] + mask = mask.view( + depth, height, VAE_STRIDE[1], width, VAE_STRIDE[1] + ) # depth, height, 8, width, 8 + mask = mask.permute(2, 4, 0, 1, 3) # 8, 8, depth, height, width + mask = mask.reshape( + VAE_STRIDE[1] * VAE_STRIDE[2], depth, height, width + ) # 8*8, depth, height, width + + # interpolation + mask = F.interpolate(mask.unsqueeze(0), size=(new_depth, height, width), mode='nearest-exact').squeeze(0) + + if refs is not None: + length = len(refs) + mask_pad = torch.zeros_like(mask[:, :length, :, :]) + mask = torch.cat((mask_pad, mask), dim=1) + result_masks.append(mask) + pbar.update(1) + return result_masks + + def vace_latent(self, z, m): + return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)] + + +#region context options +class WanVideoContextOptions: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "context_schedule": (["uniform_standard", "uniform_looped", "static_standard"],), + "context_frames": ("INT", {"default": 81, "min": 2, "max": 1000, "step": 1, "tooltip": "Number of pixel frames in the context, NOTE: the latent space has 4 frames in 1"} ), + "context_stride": ("INT", {"default": 4, "min": 4, "max": 100, "step": 1, "tooltip": "Context stride as pixel frames, NOTE: the latent space has 4 frames in 1"} ), + "context_overlap": ("INT", {"default": 16, "min": 4, "max": 100, "step": 1, "tooltip": "Context overlap as pixel frames, NOTE: the latent space has 4 frames in 1"} ), + "freenoise": ("BOOLEAN", {"default": True, "tooltip": "Shuffle the noise"}), + "verbose": ("BOOLEAN", {"default": False, "tooltip": "Print debug output"}), + }, + "optional": { + "fuse_method": (["linear", "pyramid"], {"default": "linear", "tooltip": "Window weight function: linear=ramps at edges only, pyramid=triangular weights peaking in middle"}), + "reference_latent": ("LATENT", {"tooltip": "Image to be used as init for I2V models for windows where first frame is not the actual first frame. Mostly useful with MAGREF model"}), + } + } + + RETURN_TYPES = ("WANVIDCONTEXT", ) + RETURN_NAMES = ("context_options",) + FUNCTION = "process" + CATEGORY = "WanVideoWrapper" + DESCRIPTION = "Context options for WanVideo, allows splitting the video into context windows and attempts to blend them for longer generations than the model and memory otherwise would allow." 
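+ # context_frames, context_stride and context_overlap are given in pixel frames; the sampler converts them to latent frames ((frames - 1) // 4 + 1 for the window size, value // 4 for stride and overlap), since the VAE packs 4 pixel frames into one latent frame.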
+ + def process(self, context_schedule, context_frames, context_stride, context_overlap, freenoise, verbose, image_cond_start_step=6, image_cond_window_count=2, vae=None, fuse_method="linear", reference_latent=None): + context_options = { + "context_schedule":context_schedule, + "context_frames":context_frames, + "context_stride":context_stride, + "context_overlap":context_overlap, + "freenoise":freenoise, + "verbose":verbose, + "fuse_method":fuse_method, + "reference_latent":reference_latent["samples"][0] if reference_latent is not None else None, + } + + return (context_options,) + + +class WanVideoFlowEdit: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "source_embeds": ("WANVIDEOTEXTEMBEDS", ), + "skip_steps": ("INT", {"default": 4, "min": 0}), + "drift_steps": ("INT", {"default": 0, "min": 0}), + "drift_flow_shift": ("FLOAT", {"default": 3.0, "min": 1.0, "max": 30.0, "step": 0.01}), + "source_cfg": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 30.0, "step": 0.01}), + "drift_cfg": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 30.0, "step": 0.01}), + }, + "optional": { + "source_image_embeds": ("WANVIDIMAGE_EMBEDS", ), + } + } + + RETURN_TYPES = ("FLOWEDITARGS", ) + RETURN_NAMES = ("flowedit_args",) + FUNCTION = "process" + CATEGORY = "WanVideoWrapper" + DESCRIPTION = "Flowedit options for WanVideo" + + def process(self, **kwargs): + return (kwargs,) + +class WanVideoLoopArgs: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "shift_skip": ("INT", {"default": 6, "min": 0, "tooltip": "Skip step of latent shift"}), + "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percent of the looping effect"}), + "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percent of the looping effect"}), + }, + } + + RETURN_TYPES = ("LOOPARGS", ) + RETURN_NAMES = ("loop_args",) + FUNCTION = "process" + CATEGORY = "WanVideoWrapper" + DESCRIPTION = "Looping through latent shift as shown in https://github.com/YisuiTT/Mobius/" + + def process(self, **kwargs): + return (kwargs,) + +class WanVideoExperimentalArgs: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "video_attention_split_steps": ("STRING", {"default": "", "tooltip": "Steps to split self attention when using multiple prompts"}), + "cfg_zero_star": ("BOOLEAN", {"default": False, "tooltip": "https://github.com/WeichenFan/CFG-Zero-star"}), + "use_zero_init": ("BOOLEAN", {"default": False}), + "zero_star_steps": ("INT", {"default": 0, "min": 0, "tooltip": "Steps to split self attention when using multiple prompts"}), + "use_fresca": ("BOOLEAN", {"default": False, "tooltip": "https://github.com/WikiChao/FreSca"}), + "fresca_scale_low": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "fresca_scale_high": ("FLOAT", {"default": 1.25, "min": 0.0, "max": 10.0, "step": 0.01}), + "fresca_freq_cutoff": ("INT", {"default": 20, "min": 0, "max": 10000, "step": 1}), + "use_tcfg": ("BOOLEAN", {"default": False, "tooltip": "https://arxiv.org/abs/2503.18137 TCFG: Tangential Damping Classifier-free Guidance. 
CFG artifacts reduction."}), + "raag_alpha": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "Alpha value for RAAG, 1.0 is default, 0.0 is disabled."}), + }, + } + + RETURN_TYPES = ("EXPERIMENTALARGS", ) + RETURN_NAMES = ("exp_args",) + FUNCTION = "process" + CATEGORY = "WanVideoWrapper" + DESCRIPTION = "Experimental stuff" + EXPERIMENTAL = True + + def process(self, **kwargs): + return (kwargs,) + +class WanVideoFreeInitArgs: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "freeinit_num_iters": ("INT", {"default": 3, "min": 1, "max": 10, "tooltip": "Number of FreeInit iterations"}), + "freeinit_method": (["butterworth", "ideal", "gaussian", "none"], {"default": "ideal", "tooltip": "Frequency filter type"}), + "freeinit_n": ("INT", {"default": 4, "min": 1, "max": 10, "tooltip": "Butterworth filter order (only for butterworth)"}), + "freeinit_d_s": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "Spatial filter cutoff"}), + "freeinit_d_t": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "Temporal filter cutoff"}), + }, + } + + RETURN_TYPES = ("FREEINITARGS", ) + RETURN_NAMES = ("freeinit_args",) + FUNCTION = "process" + CATEGORY = "WanVideoWrapper" + DESCRIPTION = "https://github.com/TianxingWu/FreeInit; FreeInit, a concise yet effective method to improve temporal consistency of videos generated by diffusion models" + EXPERIMENTAL = True + + def process(self, **kwargs): + return (kwargs,) + +#region Sampler +class WanVideoSampler: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": ("WANVIDEOMODEL",), + "image_embeds": ("WANVIDIMAGE_EMBEDS", ), + "steps": ("INT", {"default": 30, "min": 1}), + "cfg": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 30.0, "step": 0.01}), + "shift": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 1000.0, "step": 0.01}), + "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), + "force_offload": ("BOOLEAN", {"default": True, "tooltip": "Moves the model to the offload device after sampling"}), + "scheduler": (scheduler_list, {"default": "uni_pc",}), + "riflex_freq_index": ("INT", {"default": 0, "min": 0, "max": 1000, "step": 1, "tooltip": "Frequency index for RIFLEX, disabled when 0, default 6. Allows for new frames to be generated after without looping"}), + }, + "optional": { + "text_embeds": ("WANVIDEOTEXTEMBEDS", ), + "samples": ("LATENT", {"tooltip": "init Latents to use for video2video process"} ), + "denoise_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + "feta_args": ("FETAARGS", ), + "context_options": ("WANVIDCONTEXT", ), + "cache_args": ("CACHEARGS", ), + "flowedit_args": ("FLOWEDITARGS", ), + "batched_cfg": ("BOOLEAN", {"default": False, "tooltip": "Batch cond and uncond for faster sampling, possibly faster on some hardware, uses more memory"}), + "slg_args": ("SLGARGS", ), + "rope_function": (["default", "comfy", "comfy_chunked"], {"default": "comfy", "tooltip": "Comfy's RoPE implementation doesn't use complex numbers and can thus be compiled, that should be a lot faster when using torch.compile. 
Chunked version has reduced peak VRAM usage when not using torch.compile"}), + "loop_args": ("LOOPARGS", ), + "experimental_args": ("EXPERIMENTALARGS", ), + "sigmas": ("SIGMAS", ), + "unianimate_poses": ("UNIANIMATE_POSE", ), + "fantasytalking_embeds": ("FANTASYTALKING_EMBEDS", ), + "uni3c_embeds": ("UNI3C_EMBEDS", ), + "multitalk_embeds": ("MULTITALK_EMBEDS", ), + "freeinit_args": ("FREEINITARGS", ), + "start_step": ("INT", {"default": 0, "min": 0, "max": 10000, "step": 1, "tooltip": "Start step for the sampling, 0 means full sampling, otherwise samples only from this step"}), + "end_step": ("INT", {"default": -1, "min": -1, "max": 10000, "step": 1, "tooltip": "End step for the sampling, -1 means full sampling, otherwise samples only until this step"}), + "add_noise_to_samples": ("BOOLEAN", {"default": False, "tooltip": "Add noise to the samples before sampling, needed for video2video sampling when starting from clean video"}), + } + } + + RETURN_TYPES = ("LATENT", "LATENT",) + RETURN_NAMES = ("samples", "denoised_samples",) + FUNCTION = "process" + CATEGORY = "WanVideoWrapper" + + def process(self, model, image_embeds, shift, steps, cfg, seed, scheduler, riflex_freq_index, text_embeds=None, + force_offload=True, samples=None, feta_args=None, denoise_strength=1.0, context_options=None, + cache_args=None, teacache_args=None, flowedit_args=None, batched_cfg=False, slg_args=None, rope_function="default", loop_args=None, + experimental_args=None, sigmas=None, unianimate_poses=None, fantasytalking_embeds=None, uni3c_embeds=None, multitalk_embeds=None, freeinit_args=None, start_step=0, end_step=-1, add_noise_to_samples=False): + + patcher = model + model = model.model + transformer = model.diffusion_model + + dtype = model["dtype"] + fp8_matmul = model["fp8_matmul"] + gguf = model["gguf"] + control_lora = model["control_lora"] + + transformer_options = patcher.model_options.get("transformer_options", None) + merge_loras = transformer_options["merge_loras"] + + is_5b = transformer.out_dim == 48 + vae_upscale_factor = 16 if is_5b else 8 + + patch_linear = transformer_options.get("patch_linear", False) + + if gguf: + set_lora_params_gguf(transformer, patcher.patches) + elif len(patcher.patches) != 0 and patch_linear: + log.info(f"Using {len(patcher.patches)} LoRA weight patches for WanVideo model") + if not merge_loras and fp8_matmul: + raise NotImplementedError("FP8 matmul with unmerged LoRAs is not supported") + set_lora_params(transformer, patcher.patches) + else: + remove_lora_from_module(transformer) + + transformer.lora_scheduling_enabled = transformer_options.get("lora_scheduling_enabled", False) + + #torch.compile + if model["auto_cpu_offload"] is False: + transformer = compile_model(transformer, model["compile_args"]) + + multitalk_sampling = image_embeds.get("multitalk_sampling", False) + if not multitalk_sampling and scheduler == "multitalk": + raise Exception("multitalk scheduler is only for multitalk sampling when using ImagetoVideoMultiTalk -node") + + if text_embeds == None: + text_embeds = { + "prompt_embeds": [], + "negative_prompt_embeds": [], + } + else: + text_embeds = dict_to_device(text_embeds, device) + + seed_g = torch.Generator(device=torch.device("cpu")) + seed_g.manual_seed(seed) + + #region Scheduler + if scheduler != "multitalk": + sample_scheduler, timesteps = get_scheduler(scheduler, steps, shift, device, transformer.dim, flowedit_args, denoise_strength, sigmas=sigmas) + else: + timesteps = torch.tensor([1000, 750, 500, 250], device=device) + log.info(f"sigmas: 
{sample_scheduler.sigmas}") + + steps = len(timesteps) + + if end_step != -1 and start_step >= end_step: + raise ValueError("start_step must be less than end_step") + + if denoise_strength < 1.0: + if start_step != 0: + raise ValueError("start_step must be 0 when denoise_strength is used") + start_step = steps - int(steps * denoise_strength) - 1 + add_noise_to_samples = True #for now to not break old workflows + + first_sampler = (end_step != -1 or end_step >= steps) + + if isinstance(cfg, list): + if steps != len(cfg): + log.info(f"Received {len(cfg)} cfg values, but only {steps} steps. Setting step count to match.") + steps = len(cfg) + else: + cfg = [cfg] * (steps + 1) + + if first_sampler: + timesteps = timesteps[:end_step] + sample_scheduler.sigmas = sample_scheduler.sigmas[:end_step+1] + log.info(f"Sampling until step {end_step}, timestep: {timesteps[-1]}") + if start_step > 0: + timesteps = timesteps[start_step:] + sample_scheduler.sigmas = sample_scheduler.sigmas[start_step:] + log.info(f"Skipping first {start_step} steps, starting from timestep {timesteps[0]}") + + log.info(f"timesteps: {timesteps}") + + if hasattr(sample_scheduler, 'timesteps'): + sample_scheduler.timesteps = timesteps + + scheduler_step_args = {"generator": seed_g} + step_sig = inspect.signature(sample_scheduler.step) + for arg in list(scheduler_step_args.keys()): + if arg not in step_sig.parameters: + scheduler_step_args.pop(arg) + + control_latents = control_camera_latents = clip_fea = clip_fea_neg = end_image = recammaster = camera_embed = unianim_data = None + vace_data = vace_context = vace_scale = None + fun_or_fl2v_model = has_ref = drop_last = False + phantom_latents = fun_ref_image = ATI_tracks = None + add_cond = attn_cond = attn_cond_neg = None + + #I2V + image_cond = image_embeds.get("image_embeds", None) + if image_cond is not None: + if transformer.in_dim == 16: + raise ValueError("T2V (text to video) model detected, encoded images only work with I2V (Image to video) models") + log.info(f"image_cond shape: {image_cond.shape}") + #ATI tracks + if transformer_options is not None: + ATI_tracks = transformer_options.get("ati_tracks", None) + if ATI_tracks is not None: + from .ATI.motion_patch import patch_motion + topk = transformer_options.get("ati_topk", 2) + temperature = transformer_options.get("ati_temperature", 220.0) + ati_start_percent = transformer_options.get("ati_start_percent", 0.0) + ati_end_percent = transformer_options.get("ati_end_percent", 1.0) + image_cond_ati = patch_motion(ATI_tracks.to(image_cond.device, image_cond.dtype), image_cond, topk=topk, temperature=temperature) + log.info(f"ATI tracks shape: {ATI_tracks.shape}") + + add_cond_latents = image_embeds.get("add_cond_latents", None) + if add_cond_latents is not None: + add_cond = add_cond_latents["pose_latent"] + attn_cond = add_cond_latents["ref_latent"] + attn_cond_neg = add_cond_latents["ref_latent_neg"] + add_cond_start_percent = add_cond_latents["pose_cond_start_percent"] + add_cond_end_percent = add_cond_latents["pose_cond_end_percent"] + + end_image = image_embeds.get("end_image", None) + fun_or_fl2v_model = image_embeds.get("fun_or_fl2v_model", False) + + noise = torch.randn( #C, T, H, W + 48 if is_5b else 16, + (image_embeds["num_frames"] - 1) // 4 + (2 if end_image is not None and not fun_or_fl2v_model else 1), + image_embeds["lat_h"], + image_embeds["lat_w"], + dtype=torch.float32, + generator=seed_g, + device=torch.device("cpu")) + seq_len = image_embeds["max_seq_len"] + + clip_fea = image_embeds.get("clip_context", 
None) + if clip_fea is not None: + clip_fea = clip_fea.to(dtype) + clip_fea_neg = image_embeds.get("negative_clip_context", None) + if clip_fea_neg is not None: + clip_fea_neg = clip_fea_neg.to(dtype) + + control_embeds = image_embeds.get("control_embeds", None) + if control_embeds is not None: + if transformer.in_dim not in [52, 48, 32]: + raise ValueError("Control signal only works with Fun-Control model") + if transformer.in_dim == 52: #fun 2.2 control + image_cond_mask = image_embeds.get("mask", None) + if image_cond_mask is not None: + image_cond = torch.cat([image_cond_mask, image_cond]) + control_latents = control_embeds.get("control_images", None) + control_camera_latents = control_embeds.get("control_camera_latents", None) + control_camera_start_percent = control_embeds.get("control_camera_start_percent", 0.0) + control_camera_end_percent = control_embeds.get("control_camera_end_percent", 1.0) + control_start_percent = control_embeds.get("start_percent", 0.0) + control_end_percent = control_embeds.get("end_percent", 1.0) + drop_last = image_embeds.get("drop_last", False) + has_ref = image_embeds.get("has_ref", False) + else: #t2v + target_shape = image_embeds.get("target_shape", None) + if target_shape is None: + raise ValueError("Empty image embeds must be provided for T2V models") + + has_ref = image_embeds.get("has_ref", False) + vace_context = image_embeds.get("vace_context", None) + vace_scale = image_embeds.get("vace_scale", None) + if not isinstance(vace_scale, list): + vace_scale = [vace_scale] * (steps+1) + vace_start_percent = image_embeds.get("vace_start_percent", 0.0) + vace_end_percent = image_embeds.get("vace_end_percent", 1.0) + vace_seqlen = image_embeds.get("vace_seq_len", None) + + vace_additional_embeds = image_embeds.get("additional_vace_inputs", []) + if vace_context is not None: + vace_data = [ + {"context": vace_context, + "scale": vace_scale, + "start": vace_start_percent, + "end": vace_end_percent, + "seq_len": vace_seqlen + } + ] + if len(vace_additional_embeds) > 0: + for i in range(len(vace_additional_embeds)): + if vace_additional_embeds[i].get("has_ref", False): + has_ref = True + vace_scale = vace_additional_embeds[i]["vace_scale"] + if not isinstance(vace_scale, list): + vace_scale = [vace_scale] * (steps+1) + vace_data.append({ + "context": vace_additional_embeds[i]["vace_context"], + "scale": vace_scale, + "start": vace_additional_embeds[i]["vace_start_percent"], + "end": vace_additional_embeds[i]["vace_end_percent"], + "seq_len": vace_additional_embeds[i]["vace_seq_len"] + }) + + noise = torch.randn( + 48 if is_5b else 16, + target_shape[1] + 1 if has_ref else target_shape[1], + target_shape[2] // 2 if is_5b else target_shape[2], #todo make this smarter + target_shape[3] // 2 if is_5b else target_shape[3], #todo make this smarter + dtype=torch.float32, + device=torch.device("cpu"), + generator=seed_g) + + seq_len = math.ceil((noise.shape[2] * noise.shape[3]) / 4 * noise.shape[1]) + + recammaster = image_embeds.get("recammaster", None) + if recammaster is not None: + camera_embed = recammaster.get("camera_embed", None) + recam_latents = recammaster.get("source_latents", None) + orig_noise_len = noise.shape[1] + log.info(f"RecamMaster camera embed shape: {camera_embed.shape}") + log.info(f"RecamMaster source video shape: {recam_latents.shape}") + seq_len *= 2 + + control_embeds = image_embeds.get("control_embeds", None) + if control_embeds is not None: + control_latents = control_embeds.get("control_images", None) + if control_latents is not None: 
+ control_latents = control_latents.to(device) + control_camera_latents = control_embeds.get("control_camera_latents", None) + control_camera_start_percent = control_embeds.get("control_camera_start_percent", 0.0) + control_camera_end_percent = control_embeds.get("control_camera_end_percent", 1.0) + if control_camera_latents is not None: + control_camera_latents = control_camera_latents.to(device) + + if control_lora: + image_cond = control_latents.to(device) + if not patcher.model.is_patched: + log.info("Re-loading control LoRA...") + patcher = apply_lora(patcher, device, device, low_mem_load=False, control_lora=True) + patcher.model.is_patched = True + else: + if transformer.in_dim not in [48, 32, 52]: + raise ValueError("Control signal only works with Fun-Control model") + image_cond = torch.zeros_like(noise).to(device) #fun control + if transformer.in_dim == 52: #fun 2.2 control + mask_latents = torch.tile( + torch.zeros_like(noise[:1]), [4, 1, 1, 1] + ) + masked_video_latents_input = torch.zeros_like(noise) + image_cond = torch.cat([mask_latents, masked_video_latents_input], dim=0).to(device) + clip_fea = None + fun_ref_image = control_embeds.get("fun_ref_image", None) + control_start_percent = control_embeds.get("start_percent", 0.0) + control_end_percent = control_embeds.get("end_percent", 1.0) + else: + if transformer.in_dim == 36: #fun inp + mask_latents = torch.tile( + torch.zeros_like(noise[:1]), [4, 1, 1, 1] + ) + masked_video_latents_input = torch.zeros_like(noise) + image_cond = torch.cat([mask_latents, masked_video_latents_input], dim=0).to(device) + + phantom_latents = image_embeds.get("phantom_latents", None) + phantom_cfg_scale = image_embeds.get("phantom_cfg_scale", None) + if not isinstance(phantom_cfg_scale, list): + phantom_cfg_scale = [phantom_cfg_scale] * (steps +1) + phantom_start_percent = image_embeds.get("phantom_start_percent", 0.0) + phantom_end_percent = image_embeds.get("phantom_end_percent", 1.0) + if phantom_latents is not None: + phantom_latents = phantom_latents.to(device) + + latent_video_length = noise.shape[1] + + # Initialize FreeInit filter if enabled + freq_filter = None + if freeinit_args is not None: + from .freeinit.freeinit_utils import get_freq_filter, freq_mix_3d + filter_shape = list(noise.shape) # [batch, C, T, H, W] + freq_filter = get_freq_filter( + filter_shape, + device=device, + filter_type=freeinit_args.get("freeinit_method", "butterworth"), + n=freeinit_args.get("freeinit_n", 4) if freeinit_args.get("freeinit_method", "butterworth") == "butterworth" else None, + d_s=freeinit_args.get("freeinit_d_s", 1.0), + d_t=freeinit_args.get("freeinit_d_t", 1.0) + ) + if samples is not None: + saved_generator_state = samples.get("generator_state", None) + if saved_generator_state is not None: + seed_g.set_state(saved_generator_state) + + # UniAnimate + if unianimate_poses is not None: + transformer.dwpose_embedding.to(device, model["dtype"]) + dwpose_data = unianimate_poses["pose"].to(device, model["dtype"]) + dwpose_data = torch.cat([dwpose_data[:,:,:1].repeat(1,1,3,1,1), dwpose_data], dim=2) + dwpose_data = transformer.dwpose_embedding(dwpose_data) + log.info(f"UniAnimate pose embed shape: {dwpose_data.shape}") + if dwpose_data.shape[2] > latent_video_length: + log.warning(f"UniAnimate pose embed length {dwpose_data.shape[2]} is longer than the video length {latent_video_length}, truncating") + dwpose_data = dwpose_data[:,:, :latent_video_length] + elif dwpose_data.shape[2] < latent_video_length: + log.warning(f"UniAnimate pose embed length 
{dwpose_data.shape[2]} is shorter than the video length {latent_video_length}, padding with last pose") + pad_len = latent_video_length - dwpose_data.shape[2] + pad = dwpose_data[:,:,:1].repeat(1,1,pad_len,1,1) + dwpose_data = torch.cat([dwpose_data, pad], dim=2) + dwpose_data_flat = rearrange(dwpose_data, 'b c f h w -> b (f h w) c').contiguous() + + random_ref_dwpose_data = None + if image_cond is not None: + transformer.randomref_embedding_pose.to(device) + random_ref_dwpose = unianimate_poses.get("ref", None) + if random_ref_dwpose is not None: + random_ref_dwpose_data = transformer.randomref_embedding_pose( + random_ref_dwpose.to(device) + ).unsqueeze(2).to(model["dtype"]) # [1, 20, 104, 60] + + unianim_data = { + "dwpose": dwpose_data_flat, + "random_ref": random_ref_dwpose_data.squeeze(0) if random_ref_dwpose_data is not None else None, + "strength": unianimate_poses["strength"], + "start_percent": unianimate_poses["start_percent"], + "end_percent": unianimate_poses["end_percent"] + } + + # FantasyTalking + audio_proj = multitalk_audio_embedding = None + audio_scale = 1.0 + if fantasytalking_embeds is not None: + audio_proj = fantasytalking_embeds["audio_proj"].to(device) + audio_scale = fantasytalking_embeds["audio_scale"] + audio_cfg_scale = fantasytalking_embeds["audio_cfg_scale"] + if not isinstance(audio_cfg_scale, list): + audio_cfg_scale = [audio_cfg_scale] * (steps +1) + log.info(f"Audio proj shape: {audio_proj.shape}") + elif multitalk_embeds is not None: + # Handle single or multiple speaker embeddings + audio_features_in = multitalk_embeds.get("audio_features", None) + if audio_features_in is None: + multitalk_audio_embedding = None + else: + if isinstance(audio_features_in, list): + multitalk_audio_embedding = [emb.to(device, dtype) for emb in audio_features_in] + else: + # keep backward-compatibility with single tensor input + multitalk_audio_embedding = [audio_features_in.to(device, dtype)] + + audio_scale = multitalk_embeds.get("audio_scale", 1.0) + audio_cfg_scale = multitalk_embeds.get("audio_cfg_scale", 1.0) + ref_target_masks = multitalk_embeds.get("ref_target_masks", None) + if not isinstance(audio_cfg_scale, list): + audio_cfg_scale = [audio_cfg_scale] * (steps + 1) + + shapes = [tuple(e.shape) for e in multitalk_audio_embedding] + log.info(f"Multitalk audio features shapes (per speaker): {shapes}") + + # MiniMax Remover + minimax_latents = minimax_mask_latents = None + minimax_latents = image_embeds.get("minimax_latents", None) + minimax_mask_latents = image_embeds.get("minimax_mask_latents", None) + if minimax_latents is not None: + log.info(f"minimax_latents: {minimax_latents.shape}") + log.info(f"minimax_mask_latents: {minimax_mask_latents.shape}") + minimax_latents = minimax_latents.to(device, dtype) + minimax_mask_latents = minimax_mask_latents.to(device, dtype) + + # Context windows + is_looped = False + context_reference_latent = None + if context_options is not None: + context_schedule = context_options["context_schedule"] + context_frames = (context_options["context_frames"] - 1) // 4 + 1 + context_stride = context_options["context_stride"] // 4 + context_overlap = context_options["context_overlap"] // 4 + context_reference_latent = context_options.get("reference_latent", None) + + # Get total number of prompts + num_prompts = len(text_embeds["prompt_embeds"]) + log.info(f"Number of prompts: {num_prompts}") + # Calculate which section this context window belongs to + section_size = (latent_video_length / num_prompts) if num_prompts != 0 else 1 + 
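+ # With a single prompt section_size equals the full latent length; with multiple prompts it splits the latent frames evenly so each context window can be mapped to a prompt section.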
log.info(f"Section size: {section_size}") + is_looped = context_schedule == "uniform_looped" + + seq_len = math.ceil((noise.shape[2] * noise.shape[3]) / 4 * context_frames) + + if context_options["freenoise"]: + log.info("Applying FreeNoise") + # code from AnimateDiff-Evolved by Kosinkadink (https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved) + delta = context_frames - context_overlap + for start_idx in range(0, latent_video_length-context_frames, delta): + place_idx = start_idx + context_frames + if place_idx >= latent_video_length: + break + end_idx = place_idx - 1 + + if end_idx + delta >= latent_video_length: + final_delta = latent_video_length - place_idx + list_idx = torch.tensor(list(range(start_idx,start_idx+final_delta)), device=torch.device("cpu"), dtype=torch.long) + list_idx = list_idx[torch.randperm(final_delta, generator=seed_g)] + noise[:, place_idx:place_idx + final_delta, :, :] = noise[:, list_idx, :, :] + break + list_idx = torch.tensor(list(range(start_idx,start_idx+delta)), device=torch.device("cpu"), dtype=torch.long) + list_idx = list_idx[torch.randperm(delta, generator=seed_g)] + noise[:, place_idx:place_idx + delta, :, :] = noise[:, list_idx, :, :] + + log.info(f"Context schedule enabled: {context_frames} frames, {context_stride} stride, {context_overlap} overlap") + from .context_windows.context import get_context_scheduler, create_window_mask, WindowTracker + self.window_tracker = WindowTracker(verbose=context_options["verbose"]) + context = get_context_scheduler(context_schedule) + + # vid2vid + if samples is not None: + saved_generator_state = samples.get("generator_state", None) + if saved_generator_state is not None: + seed_g.set_state(saved_generator_state) + input_samples = samples["samples"].squeeze(0).to(noise) + if input_samples.shape[1] != noise.shape[1]: + input_samples = torch.cat([input_samples[:, :1].repeat(1, noise.shape[1] - input_samples.shape[1], 1, 1), input_samples], dim=1) + + if add_noise_to_samples: + latent_timestep = timesteps[:1].to(noise) + noise = noise * latent_timestep / 1000 + (1 - latent_timestep / 1000) * input_samples + else: + noise = input_samples + mask = samples.get("mask", None) + if mask is not None: + original_image = input_samples.to(device) + if mask.shape[2] != noise.shape[1]: + mask = torch.cat([torch.zeros(1, noise.shape[0], noise.shape[1] - mask.shape[2], noise.shape[2], noise.shape[3]), mask], dim=2) + + # extra latents (Pusa) and 5b + latents_to_insert = add_index = None + if (extra_latents := image_embeds.get("extra_latents", None)) is not None: + all_indices = [] + for entry in extra_latents: + add_index = entry["index"] + num_extra_frames = entry["samples"].shape[2] + noise[:, add_index:add_index+num_extra_frames] = entry["samples"].to(noise) + log.info(f"Adding extra samples to latent indices {add_index} to {add_index+num_extra_frames-1}") + all_indices.extend(range(add_index, add_index+num_extra_frames)) + + + latent = noise.to(device) + + #controlnet + controlnet_latents = controlnet = None + if transformer_options is not None: + controlnet = transformer_options.get("controlnet", None) + if controlnet is not None: + self.controlnet = controlnet["controlnet"] + controlnet_start = controlnet["controlnet_start"] + controlnet_end = controlnet["controlnet_end"] + controlnet_latents = controlnet["control_latents"] + controlnet["controlnet_weight"] = controlnet["controlnet_strength"] + controlnet["controlnet_stride"] = controlnet["control_stride"] + + #uni3c + pcd_data = pcd_data_input = None + if 
uni3c_embeds is not None: + transformer.controlnet = uni3c_embeds["controlnet"] + pcd_data = { + "render_latent": uni3c_embeds["render_latent"].to(dtype), + "render_mask": uni3c_embeds["render_mask"], + "camera_embedding": uni3c_embeds["camera_embedding"], + "controlnet_weight": uni3c_embeds["controlnet_weight"], + "start": uni3c_embeds["start"], + "end": uni3c_embeds["end"], + } + + # Enhance-a-video (feta) + if feta_args is not None and latent_video_length > 1: + set_enhance_weight(feta_args["weight"]) + feta_start_percent = feta_args["start_percent"] + feta_end_percent = feta_args["end_percent"] + if context_options is not None: + set_num_frames(context_frames) + else: + set_num_frames(latent_video_length) + enhance_enabled = True + else: + feta_args = None + enhance_enabled = False + + # EchoShot https://github.com/D2I-ai/EchoShot + echoshot = False + shot_len = None + if text_embeds is not None: + echoshot = text_embeds.get("echoshot", False) + if echoshot: + shot_num = len(text_embeds["prompt_embeds"]) + shot_len = [latent_video_length//shot_num] * (shot_num-1) + shot_len.append(latent_video_length-sum(shot_len)) + rope_function = "default" #echoshot does not support comfy rope function + log.info(f"Number of shots in prompt: {shot_num}, Shot token lengths: {shot_len}") + + #region transformer settings + #rope + freqs = None + transformer.rope_embedder.k = None + transformer.rope_embedder.num_frames = None + if "comfy" in rope_function: + transformer.rope_embedder.k = riflex_freq_index + transformer.rope_embedder.num_frames = latent_video_length + else: + d = transformer.dim // transformer.num_heads + freqs = torch.cat([ + rope_params(1024, d - 4 * (d // 6), L_test=latent_video_length, k=riflex_freq_index), + rope_params(1024, 2 * (d // 6)), + rope_params(1024, 2 * (d // 6)) + ], + dim=1) + transformer.rope_func = rope_function + for block in transformer.blocks: + block.rope_func = rope_function + if transformer.vace_layers is not None: + for block in transformer.vace_blocks: + block.rope_func = rope_function + + #blockswap init + + mm.unload_all_models() + mm.soft_empty_cache() + gc.collect() + + if transformer_options is not None: + block_swap_args = transformer_options.get("block_swap_args", None) + + if block_swap_args is not None: + transformer.use_non_blocking = block_swap_args.get("use_non_blocking", False) + for name, param in transformer.named_parameters(): + if "block" not in name: + param.data = param.data.to(device) + if "control_adapter" in name: + param.data = param.data.to(device) + elif block_swap_args["offload_txt_emb"] and "txt_emb" in name: + param.data = param.data.to(offload_device) + elif block_swap_args["offload_img_emb"] and "img_emb" in name: + param.data = param.data.to(offload_device) + + transformer.block_swap( + block_swap_args["blocks_to_swap"] - 1 , + block_swap_args["offload_txt_emb"], + block_swap_args["offload_img_emb"], + vace_blocks_to_swap = block_swap_args.get("vace_blocks_to_swap", None), + prefetch_blocks = block_swap_args.get("prefetch_blocks", 0), + block_swap_debug = block_swap_args.get("block_swap_debug", False), + ) + elif model["auto_cpu_offload"]: + for module in transformer.modules(): + if hasattr(module, "offload"): + module.offload() + if hasattr(module, "onload"): + module.onload() + for block in transformer.blocks: + block.modulation = torch.nn.Parameter(block.modulation.to(device)) + transformer.head.modulation = torch.nn.Parameter(transformer.head.modulation.to(device)) + + elif model["manual_offloading"]: + transformer.to(device) 
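+ # Memory management recap: block_swap_args streams the selected number of transformer blocks between the main and offload devices (optionally offloading text/image embeddings), auto_cpu_offload relies on each module's own offload()/onload() hooks, and manual_offloading simply moves the whole transformer onto the main device before sampling.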
+ + # Initialize Cache if enabled + transformer.enable_teacache = transformer.enable_magcache = transformer.enable_easycache = False + cache_args = teacache_args if teacache_args is not None else cache_args #for backward compatibility on old workflows + if cache_args is not None: + from .cache_methods.cache_methods import set_transformer_cache_method + transformer = set_transformer_cache_method(transformer, timesteps, cache_args) + + # Initialize cache state + self.cache_state = [None, None] + if phantom_latents is not None: + log.info(f"Phantom latents shape: {phantom_latents.shape}") + self.cache_state = [None, None, None] + self.cache_state_source = [None, None] + self.cache_states_context = [] + + # Skip layer guidance (SLG) + if slg_args is not None: + assert not batched_cfg, "Batched cfg is not supported with SLG" + transformer.slg_blocks = slg_args["blocks"] + transformer.slg_start_percent = slg_args["start_percent"] + transformer.slg_end_percent = slg_args["end_percent"] + else: + transformer.slg_blocks = None + + # Setup radial attention + if transformer.attention_mode == "radial_sage_attention": + setup_radial_attention(transformer, transformer_options, latent, seq_len, latent_video_length, context_options=context_options) + + # FlowEdit setup + if flowedit_args is not None: + source_embeds = flowedit_args["source_embeds"] + source_embeds = dict_to_device(source_embeds, device) + source_image_embeds = flowedit_args.get("source_image_embeds", image_embeds) + source_image_cond = source_image_embeds.get("image_embeds", None) + source_clip_fea = source_image_embeds.get("clip_fea", clip_fea) + if source_image_cond is not None: + source_image_cond = source_image_cond.to(dtype) + skip_steps = flowedit_args["skip_steps"] + drift_steps = flowedit_args["drift_steps"] + source_cfg = flowedit_args["source_cfg"] + if not isinstance(source_cfg, list): + source_cfg = [source_cfg] * (steps +1) + drift_cfg = flowedit_args["drift_cfg"] + if not isinstance(drift_cfg, list): + drift_cfg = [drift_cfg] * (steps +1) + + x_init = samples["samples"].clone().squeeze(0).to(device) + x_tgt = samples["samples"].squeeze(0).to(device) + + sample_scheduler = FlowMatchEulerDiscreteScheduler( + num_train_timesteps=1000, + shift=flowedit_args["drift_flow_shift"], + use_dynamic_shifting=False) + + sampling_sigmas = get_sampling_sigmas(steps, flowedit_args["drift_flow_shift"]) + + drift_timesteps, _ = retrieve_timesteps( + sample_scheduler, + device=device, + sigmas=sampling_sigmas) + + if drift_steps > 0: + drift_timesteps = torch.cat([drift_timesteps, torch.tensor([0]).to(drift_timesteps.device)]).to(drift_timesteps.device) + timesteps[-drift_steps:] = drift_timesteps[-drift_steps:] + + # Experimental args + use_cfg_zero_star = use_tangential = use_fresca = False + raag_alpha = 0.0 + if experimental_args is not None: + video_attention_split_steps = experimental_args.get("video_attention_split_steps", []) + if video_attention_split_steps: + transformer.video_attention_split_steps = [int(x.strip()) for x in video_attention_split_steps.split(",")] + else: + transformer.video_attention_split_steps = [] + + use_zero_init = experimental_args.get("use_zero_init", True) + use_cfg_zero_star = experimental_args.get("cfg_zero_star", False) + use_tangential = experimental_args.get("use_tcfg", False) + zero_star_steps = experimental_args.get("zero_star_steps", 0) + raag_alpha = experimental_args.get("raag_alpha", 0.0) + + use_fresca = experimental_args.get("use_fresca", False) + if use_fresca: + fresca_scale_low = 
experimental_args.get("fresca_scale_low", 1.0) + fresca_scale_high = experimental_args.get("fresca_scale_high", 1.25) + fresca_freq_cutoff = experimental_args.get("fresca_freq_cutoff", 20) + + #region model pred + def predict_with_cfg(z, cfg_scale, positive_embeds, negative_embeds, timestep, idx, image_cond=None, clip_fea=None, + control_latents=None, vace_data=None, unianim_data=None, audio_proj=None, control_camera_latents=None, + add_cond=None, cache_state=None, context_window=None, multitalk_audio_embeds=None): + nonlocal transformer + z = z.to(dtype) + with torch.autocast(device_type=mm.get_autocast_device(device), dtype=dtype, enabled=("fp8" in model["quantization"])): + + if use_cfg_zero_star and (idx <= zero_star_steps) and use_zero_init: + return z*0, None + + nonlocal patcher + current_step_percentage = idx / len(timesteps) + control_lora_enabled = False + image_cond_input = None + if control_latents is not None: + if control_lora: + control_lora_enabled = True + else: + if (control_start_percent <= current_step_percentage <= control_end_percent) or \ + (control_end_percent > 0 and idx == 0 and current_step_percentage >= control_start_percent): + image_cond_input = torch.cat([control_latents.to(z), image_cond.to(z)]) + else: + image_cond_input = torch.cat([torch.zeros_like(control_latents, dtype=dtype), image_cond.to(z)]) + if fun_ref_image is not None: + fun_ref_input = fun_ref_image.to(z) + else: + fun_ref_input = torch.zeros_like(z, dtype=z.dtype)[:, 0].unsqueeze(1) + #fun_ref_input = None + + if control_lora: + if not control_start_percent <= current_step_percentage <= control_end_percent: + control_lora_enabled = False + if patcher.model.is_patched: + log.info("Unloading LoRA...") + patcher.unpatch_model(device) + patcher.model.is_patched = False + else: + image_cond_input = control_latents.to(z) + if not patcher.model.is_patched: + log.info("Loading LoRA...") + patcher = apply_lora(patcher, device, device, low_mem_load=False, control_lora=True) + patcher.model.is_patched = True + + elif ATI_tracks is not None and ((ati_start_percent <= current_step_percentage <= ati_end_percent) or + (ati_end_percent > 0 and idx == 0 and current_step_percentage >= ati_start_percent)): + image_cond_input = image_cond_ati.to(z) + else: + image_cond_input = image_cond.to(z) if image_cond is not None else None + + if control_camera_latents is not None: + if (control_camera_start_percent <= current_step_percentage <= control_camera_end_percent) or \ + (control_end_percent > 0 and idx == 0 and current_step_percentage >= control_camera_start_percent): + control_camera_input = control_camera_latents.to(z) + else: + control_camera_input = None + + if recammaster is not None: + z = torch.cat([z, recam_latents.to(z)], dim=1) + + use_phantom = False + if phantom_latents is not None: + if (phantom_start_percent <= current_step_percentage <= phantom_end_percent) or \ + (phantom_end_percent > 0 and idx == 0 and current_step_percentage >= phantom_start_percent): + + z_pos = torch.cat([z[:,:-phantom_latents.shape[1]], phantom_latents.to(z)], dim=1) + z_phantom_img = torch.cat([z[:,:-phantom_latents.shape[1]], phantom_latents.to(z)], dim=1) + z_neg = torch.cat([z[:,:-phantom_latents.shape[1]], torch.zeros_like(phantom_latents).to(z)], dim=1) + use_phantom = True + if cache_state is not None and len(cache_state) != 3: + cache_state.append(None) + if not use_phantom: + z_pos = z_neg = z + + if controlnet_latents is not None: + if (controlnet_start <= current_step_percentage < controlnet_end): + 
self.controlnet.to(device) + controlnet_states = self.controlnet( + hidden_states=z.unsqueeze(0).to(device, self.controlnet.dtype), + timestep=timestep, + encoder_hidden_states=positive_embeds[0].unsqueeze(0).to(device, self.controlnet.dtype), + attention_kwargs=None, + controlnet_states=controlnet_latents.to(device, self.controlnet.dtype), + return_dict=False, + )[0] + if isinstance(controlnet_states, (tuple, list)): + controlnet["controlnet_states"] = [x.to(z) for x in controlnet_states] + else: + controlnet["controlnet_states"] = controlnet_states.to(z) + + add_cond_input = None + if add_cond is not None: + if (add_cond_start_percent <= current_step_percentage <= add_cond_end_percent) or \ + (add_cond_end_percent > 0 and idx == 0 and current_step_percentage >= add_cond_start_percent): + add_cond_input = add_cond + + if minimax_latents is not None: + if context_window is not None: + z_pos = z_neg = torch.cat([z, minimax_latents[:, context_window], minimax_mask_latents[:, context_window]], dim=0) + else: + z_pos = z_neg = torch.cat([z, minimax_latents, minimax_mask_latents], dim=0) + + if not multitalk_sampling and multitalk_audio_embedding is not None: + audio_embedding = multitalk_audio_embedding + audio_embs = [] + indices = (torch.arange(4 + 1) - 2) * 1 + human_num = len(audio_embedding) + # split audio with window size + if context_window is None: + for human_idx in range(human_num): + center_indices = torch.arange( + 0, + latent_video_length * 4 + 1 if add_cond is not None else (latent_video_length-1) * 4 + 1, + 1).unsqueeze(1) + indices.unsqueeze(0) + center_indices = torch.clamp(center_indices, min=0, max=audio_embedding[human_idx].shape[0] - 1) + audio_emb = audio_embedding[human_idx][center_indices].unsqueeze(0).to(device) + audio_embs.append(audio_emb) + else: + for human_idx in range(human_num): + audio_start = context_window[0] * 4 + audio_end = context_window[-1] * 4 + 1 + #print("audio_start: ", audio_start, "audio_end: ", audio_end) + center_indices = torch.arange(audio_start, audio_end, 1).unsqueeze(1) + indices.unsqueeze(0) + center_indices = torch.clamp(center_indices, min=0, max=audio_embedding[human_idx].shape[0] - 1) + audio_emb = audio_embedding[human_idx][center_indices].unsqueeze(0).to(device) + audio_embs.append(audio_emb) + multitalk_audio_input = torch.concat(audio_embs, dim=0).to(dtype) + + elif multitalk_sampling and multitalk_audio_embeds is not None: + multitalk_audio_input = multitalk_audio_embeds + + if context_window is not None and pcd_data is not None and pcd_data["render_latent"].shape[2] != context_frames: + pcd_data_input = {"render_latent": pcd_data["render_latent"][:, :, context_window]} + for k in pcd_data: + if k != "render_latent": + pcd_data_input[k] = pcd_data[k] + else: + pcd_data_input = pcd_data + + + base_params = { + 'seq_len': seq_len, + 'device': device, + 'freqs': freqs, + 't': timestep, + 'current_step': idx, + 'last_step': len(timesteps) - 1 == idx, + 'control_lora_enabled': control_lora_enabled, + 'enhance_enabled': enhance_enabled, + 'camera_embed': camera_embed, + 'unianim_data': unianim_data, + 'fun_ref': fun_ref_input if fun_ref_image is not None else None, + 'fun_camera': control_camera_input if control_camera_latents is not None else None, + 'audio_proj': audio_proj if fantasytalking_embeds is not None else None, + 'audio_scale': audio_scale, + "pcd_data": pcd_data_input, + "controlnet": controlnet, + "add_cond": add_cond_input, + "nag_params": text_embeds.get("nag_params", {}), + "nag_context": 
text_embeds.get("nag_prompt_embeds", None), + "multitalk_audio": multitalk_audio_input if multitalk_audio_embedding is not None else None, + "ref_target_masks": ref_target_masks if multitalk_audio_embedding is not None else None, + "inner_t": [shot_len] if shot_len else None, + } + + batch_size = 1 + + if not math.isclose(cfg_scale, 1.0): + if negative_embeds is None: + raise ValueError("Negative embeddings must be provided for CFG scale > 1.0") + if len(positive_embeds) > 1: + negative_embeds = negative_embeds * len(positive_embeds) + + try: + if not batched_cfg: + #cond + noise_pred_cond, cache_state_cond = transformer( + [z_pos], context=positive_embeds, y=[image_cond_input] if image_cond_input is not None else None, + clip_fea=clip_fea, is_uncond=False, current_step_percentage=current_step_percentage, + pred_id=cache_state[0] if cache_state else None, + vace_data=vace_data, attn_cond=attn_cond, + **base_params + ) + noise_pred_cond = noise_pred_cond[0].to(intermediate_device) + if math.isclose(cfg_scale, 1.0): + if use_fresca: + noise_pred_cond = fourier_filter( + noise_pred_cond, + scale_low=fresca_scale_low, + scale_high=fresca_scale_high, + freq_cutoff=fresca_freq_cutoff, + ) + return noise_pred_cond, [cache_state_cond] + #uncond + if fantasytalking_embeds is not None: + if not math.isclose(audio_cfg_scale[idx], 1.0): + base_params['audio_proj'] = None + noise_pred_uncond, cache_state_uncond = transformer( + [z_neg], context=negative_embeds, clip_fea=clip_fea_neg if clip_fea_neg is not None else clip_fea, + y=[image_cond_input] if image_cond_input is not None else None, + is_uncond=True, current_step_percentage=current_step_percentage, + pred_id=cache_state[1] if cache_state else None, + vace_data=vace_data, attn_cond=attn_cond_neg, + **base_params + ) + noise_pred_uncond = noise_pred_uncond[0].to(intermediate_device) + #phantom + if use_phantom and not math.isclose(phantom_cfg_scale[idx], 1.0): + noise_pred_phantom, cache_state_phantom = transformer( + [z_phantom_img], context=negative_embeds, clip_fea=clip_fea_neg if clip_fea_neg is not None else clip_fea, + y=[image_cond_input] if image_cond_input is not None else None, + is_uncond=True, current_step_percentage=current_step_percentage, + pred_id=cache_state[2] if cache_state else None, + vace_data=None, + **base_params + ) + noise_pred_phantom = noise_pred_phantom[0].to(intermediate_device) + + noise_pred = noise_pred_uncond + phantom_cfg_scale[idx] * (noise_pred_phantom - noise_pred_uncond) + cfg_scale * (noise_pred_cond - noise_pred_phantom) + return noise_pred, [cache_state_cond, cache_state_uncond, cache_state_phantom] + #fantasytalking + if fantasytalking_embeds is not None: + if not math.isclose(audio_cfg_scale[idx], 1.0): + if cache_state is not None and len(cache_state) != 3: + cache_state.append(None) + base_params['audio_proj'] = None + noise_pred_no_audio, cache_state_audio = transformer( + [z_pos], context=positive_embeds, y=[image_cond_input] if image_cond_input is not None else None, + clip_fea=clip_fea, is_uncond=False, current_step_percentage=current_step_percentage, + pred_id=cache_state[2] if cache_state else None, + vace_data=vace_data, + **base_params + ) + noise_pred_no_audio = noise_pred_no_audio[0].to(intermediate_device) + noise_pred = ( + noise_pred_uncond + + cfg_scale * (noise_pred_no_audio - noise_pred_uncond) + + audio_cfg_scale[idx] * (noise_pred_cond - noise_pred_no_audio) + ) + return noise_pred, [cache_state_cond, cache_state_uncond, cache_state_audio] + elif multitalk_audio_embedding is not None: 
+ if not math.isclose(audio_cfg_scale[idx], 1.0): + if cache_state is not None and len(cache_state) != 3: + cache_state.append(None) + base_params['multitalk_audio'] = torch.zeros_like(multitalk_audio_input)[-1:] + noise_pred_no_audio, cache_state_audio = transformer( + [z_pos], context=negative_embeds, y=[image_cond_input] if image_cond_input is not None else None, + clip_fea=clip_fea, is_uncond=False, current_step_percentage=current_step_percentage, + pred_id=cache_state[2] if cache_state else None, + vace_data=vace_data, + **base_params + ) + noise_pred_no_audio = noise_pred_no_audio[0].to(intermediate_device) + noise_pred = ( + noise_pred_no_audio + + cfg_scale * (noise_pred_cond - noise_pred_uncond) + + audio_cfg_scale[idx] * (noise_pred_uncond - noise_pred_no_audio) + ) + return noise_pred, [cache_state_cond, cache_state_uncond, cache_state_audio] + + #batched + else: + cache_state_uncond = None + [noise_pred_cond, noise_pred_uncond], cache_state_cond = transformer( + [z] + [z], context=positive_embeds + negative_embeds, + y=[image_cond_input] + [image_cond_input] if image_cond_input is not None else None, + clip_fea=clip_fea.repeat(2,1,1), is_uncond=False, current_step_percentage=current_step_percentage, + pred_id=cache_state[0] if cache_state else None, + **base_params + ) + except Exception as e: + log.error(f"Error during model prediction: {e}") + if force_offload: + if model["manual_offloading"]: + offload_transformer(transformer) + raise e + + #https://github.com/WeichenFan/CFG-Zero-star/ + if use_cfg_zero_star: + alpha = optimized_scale( + noise_pred_cond.view(batch_size, -1), + noise_pred_uncond.view(batch_size, -1) + ).view(batch_size, 1, 1, 1) + else: + alpha = 1.0 + + noise_pred_uncond_scaled = noise_pred_uncond * alpha + + if use_tangential: + noise_pred_uncond_scaled = tangential_projection(noise_pred_cond, noise_pred_uncond_scaled) + + # RAAG (RATIO-aware Adaptive Guidance) + if raag_alpha > 0.0: + cfg_scale = get_raag_guidance(noise_pred_cond, noise_pred_uncond_scaled, cfg_scale, raag_alpha) + log.info(f"RAAG modified cfg: {cfg_scale}") + + #https://github.com/WikiChao/FreSca + if use_fresca: + filtered_cond = fourier_filter( + noise_pred_cond - noise_pred_uncond, + scale_low=fresca_scale_low, + scale_high=fresca_scale_high, + freq_cutoff=fresca_freq_cutoff, + ) + noise_pred = noise_pred_uncond_scaled + cfg_scale * filtered_cond * alpha + else: + noise_pred = noise_pred_uncond_scaled + cfg_scale * (noise_pred_cond - noise_pred_uncond_scaled) + + + return noise_pred, [cache_state_cond, cache_state_uncond] + + log.info(f"Seq len: {seq_len}") + + + + if args.preview_method in [LatentPreviewMethod.Auto, LatentPreviewMethod.Latent2RGB]: #default for latent2rgb + from latent_preview import prepare_callback + else: + from .latent_preview import prepare_callback #custom for tiny VAE previews + callback = prepare_callback(patcher, len(timesteps)) + + log.info(f"Sampling {(latent_video_length-1) * 4 + 1} frames at {latent.shape[3]*vae_upscale_factor}x{latent.shape[2]*vae_upscale_factor} with {steps} steps") + + intermediate_device = device + + # diff diff prep + masks = None + if samples is not None and mask is not None: + mask = 1 - mask + thresholds = torch.arange(len(timesteps), dtype=original_image.dtype) / len(timesteps) + thresholds = thresholds.unsqueeze(1).unsqueeze(1).unsqueeze(1).unsqueeze(1).to(device) + masks = mask.repeat(len(timesteps), 1, 1, 1, 1).to(device) + masks = masks > thresholds + + latent_shift_loop = False + if loop_args is not None: + latent_shift_loop 
= True + is_looped = True + latent_skip = loop_args["shift_skip"] + latent_shift_start_percent = loop_args["start_percent"] + latent_shift_end_percent = loop_args["end_percent"] + shift_idx = 0 + + #clear memory before sampling + mm.soft_empty_cache() + gc.collect() + try: + torch.cuda.reset_peak_memory_stats(device) + #torch.cuda.memory._record_memory_history(max_entries=100000) + except: + pass + + # Main sampling loop with FreeInit iterations + iterations = freeinit_args.get("freeinit_num_iters", 3) if freeinit_args is not None else 1 + current_latent = latent + + for iter_idx in range(iterations): + + # FreeInit noise reinitialization (after first iteration) + if freeinit_args is not None and iter_idx > 0: + # restart scheduler for each iteration + sample_scheduler, timesteps = get_scheduler(scheduler, steps, shift, device, transformer.dim, flowedit_args, denoise_strength, sigmas=sigmas) + + # Re-apply start_step and end_step logic to timesteps and sigmas + if end_step != -1: + timesteps = timesteps[:end_step] + sample_scheduler.sigmas = sample_scheduler.sigmas[:end_step+1] + if start_step > 0: + timesteps = timesteps[start_step:] + sample_scheduler.sigmas = sample_scheduler.sigmas[start_step:] + if hasattr(sample_scheduler, 'timesteps'): + sample_scheduler.timesteps = timesteps + + # Diffuse current latent to t=999 + diffuse_timesteps = torch.full((noise.shape[0],), 999, device=device, dtype=torch.long) + z_T = add_noise( + current_latent.to(device), + initial_noise_saved.to(device), + diffuse_timesteps + ) + + # Generate new random noise + z_rand = torch.randn(z_T.shape, dtype=torch.float32, generator=seed_g, device=torch.device("cpu")) + + # Apply frequency mixing + current_latent = freq_mix_3d(z_T.to(torch.float32), z_rand.to(device), LPF=freq_filter) + current_latent = current_latent.to(dtype) + + # Store initial noise for first iteration + if freeinit_args is not None and iter_idx == 0: + initial_noise_saved = current_latent.detach().clone() + if samples is not None: + current_latent = input_samples.to(device) + continue + + # Reset per-iteration states + self.cache_state = [None, None] + self.cache_state_source = [None, None] + self.cache_states_context = [] + if context_options is not None: + self.window_tracker = WindowTracker(verbose=context_options["verbose"]) + + # Set latent for denoising + latent = current_latent + + try: + pbar = ProgressBar(len(timesteps)) + #region main loop start + for idx, t in enumerate(tqdm(timesteps)): + if flowedit_args is not None: + if idx < skip_steps: + continue + + # diff diff + if masks is not None: + if idx < len(timesteps) - 1: + noise_timestep = timesteps[idx+1] + image_latent = sample_scheduler.scale_noise( + original_image, torch.tensor([noise_timestep]), noise.to(device) + ) + mask = masks[idx] + mask = mask.to(latent) + latent = image_latent * mask + latent * (1-mask) + # end diff diff + + latent_model_input = latent.to(device) + + current_step_percentage = idx / len(timesteps) + + timestep = torch.tensor([t]).to(device) + if scheduler == "flowmatch_pusa" or (is_5b and 'all_indices' in locals()): + orig_timestep = timestep + timestep = timestep.unsqueeze(1).repeat(1, latent_video_length) + if extra_latents is not None: + if 'all_indices' in locals() and all_indices: + timestep[:, all_indices] = 0 + #print("timestep: ", timestep) + + ### latent shift + if latent_shift_loop: + if latent_shift_start_percent <= current_step_percentage <= latent_shift_end_percent: + latent_model_input = torch.cat([latent_model_input[:, shift_idx:]] + 
[latent_model_input[:, :shift_idx]], dim=1) + + #enhance-a-video + enhance_enabled = False + if feta_args is not None and feta_start_percent <= current_step_percentage <= feta_end_percent: + enhance_enabled = True + + #flow-edit + if flowedit_args is not None: + sigma = t / 1000.0 + sigma_prev = (timesteps[idx + 1] if idx < len(timesteps) - 1 else timesteps[-1]) / 1000.0 + noise = torch.randn(x_init.shape, generator=seed_g, device=torch.device("cpu")) + if idx < len(timesteps) - drift_steps: + cfg = drift_cfg + + zt_src = (1-sigma) * x_init + sigma * noise.to(t) + zt_tgt = x_tgt + zt_src - x_init + + #source + if idx < len(timesteps) - drift_steps: + if context_options is not None: + counter = torch.zeros_like(zt_src, device=intermediate_device) + vt_src = torch.zeros_like(zt_src, device=intermediate_device) + context_queue = list(context(idx, steps, latent_video_length, context_frames, context_stride, context_overlap)) + for c in context_queue: + window_id = self.window_tracker.get_window_id(c) + + if cache_args is not None: + current_teacache = self.window_tracker.get_teacache(window_id, self.cache_state) + else: + current_teacache = None + + prompt_index = min(int(max(c) / section_size), num_prompts - 1) + if context_options["verbose"]: + log.info(f"Prompt index: {prompt_index}") + + if len(source_embeds["prompt_embeds"]) > 1: + positive = source_embeds["prompt_embeds"][prompt_index] + else: + positive = source_embeds["prompt_embeds"] + + partial_img_emb = None + if source_image_cond is not None: + partial_img_emb = source_image_cond[:, c, :, :] + partial_img_emb[:, 0, :, :] = source_image_cond[:, 0, :, :].to(intermediate_device) + + partial_zt_src = zt_src[:, c, :, :] + vt_src_context, new_teacache = predict_with_cfg( + partial_zt_src, cfg[idx], + positive, source_embeds["negative_prompt_embeds"], + timestep, idx, partial_img_emb, control_latents, + source_clip_fea, current_teacache) + + if cache_args is not None: + self.window_tracker.cache_states[window_id] = new_teacache + + window_mask = create_window_mask(vt_src_context, c, latent_video_length, context_overlap) + vt_src[:, c, :, :] += vt_src_context * window_mask + counter[:, c, :, :] += window_mask + vt_src /= counter + else: + vt_src, self.cache_state_source = predict_with_cfg( + zt_src, cfg[idx], + source_embeds["prompt_embeds"], + source_embeds["negative_prompt_embeds"], + timestep, idx, source_image_cond, + source_clip_fea, control_latents, + cache_state=self.cache_state_source) + else: + if idx == len(timesteps) - drift_steps: + x_tgt = zt_tgt + zt_tgt = x_tgt + vt_src = 0 + #target + if context_options is not None: + counter = torch.zeros_like(zt_tgt, device=intermediate_device) + vt_tgt = torch.zeros_like(zt_tgt, device=intermediate_device) + context_queue = list(context(idx, steps, latent_video_length, context_frames, context_stride, context_overlap)) + for c in context_queue: + window_id = self.window_tracker.get_window_id(c) + + if cache_args is not None: + current_teacache = self.window_tracker.get_teacache(window_id, self.cache_state) + else: + current_teacache = None + + prompt_index = min(int(max(c) / section_size), num_prompts - 1) + if context_options["verbose"]: + log.info(f"Prompt index: {prompt_index}") + + if len(text_embeds["prompt_embeds"]) > 1: + positive = text_embeds["prompt_embeds"][prompt_index] + else: + positive = text_embeds["prompt_embeds"] + + partial_img_emb = None + partial_control_latents = None + if image_cond is not None: + partial_img_emb = image_cond[:, c, :, :] + partial_img_emb[:, 0, :, :] 
= image_cond[:, 0, :, :].to(intermediate_device) + if control_latents is not None: + partial_control_latents = control_latents[:, c, :, :] + + partial_zt_tgt = zt_tgt[:, c, :, :] + vt_tgt_context, new_teacache = predict_with_cfg( + partial_zt_tgt, cfg[idx], + positive, text_embeds["negative_prompt_embeds"], + timestep, idx, partial_img_emb, partial_control_latents, + clip_fea, current_teacache) + + if cache_args is not None: + self.window_tracker.cache_states[window_id] = new_teacache + + window_mask = create_window_mask(vt_tgt_context, c, latent_video_length, context_overlap) + vt_tgt[:, c, :, :] += vt_tgt_context * window_mask + counter[:, c, :, :] += window_mask + vt_tgt /= counter + else: + vt_tgt, self.cache_state = predict_with_cfg( + zt_tgt, cfg[idx], + text_embeds["prompt_embeds"], + text_embeds["negative_prompt_embeds"], + timestep, idx, image_cond, clip_fea, control_latents, + cache_state=self.cache_state) + v_delta = vt_tgt - vt_src + x_tgt = x_tgt.to(torch.float32) + v_delta = v_delta.to(torch.float32) + x_tgt = x_tgt + (sigma_prev - sigma) * v_delta + x0 = x_tgt + #region context windowing + elif context_options is not None: + counter = torch.zeros_like(latent_model_input, device=intermediate_device) + noise_pred = torch.zeros_like(latent_model_input, device=intermediate_device) + context_queue = list(context(idx, steps, latent_video_length, context_frames, context_stride, context_overlap)) + fraction_per_context = 1.0 / len(context_queue) + context_pbar = ProgressBar(steps) + step_start_progress = idx + + # Validate all context windows before processing + max_idx = latent_model_input.shape[1] if latent_model_input.ndim > 1 else 0 + for window_indices in context_queue: + if not all(0 <= idx < max_idx for idx in window_indices): + raise ValueError(f"Invalid context window indices {window_indices} for latent_model_input with shape {latent_model_input.shape}") + + for i, c in enumerate(context_queue): + window_id = self.window_tracker.get_window_id(c) + + if cache_args is not None: + current_teacache = self.window_tracker.get_teacache(window_id, self.cache_state) + else: + current_teacache = None + + prompt_index = min(int(max(c) / section_size), num_prompts - 1) + if context_options["verbose"]: + log.info(f"Prompt index: {prompt_index}") + + # Use the appropriate prompt for this section + if len(text_embeds["prompt_embeds"]) > 1: + positive = [text_embeds["prompt_embeds"][prompt_index]] + else: + positive = text_embeds["prompt_embeds"] + + partial_img_emb = None + partial_control_latents = None + if image_cond is not None: + partial_img_emb = image_cond[:, c] + + if c[0] != 0 and context_reference_latent is not None: + new_init_image = context_reference_latent[:, 0].to(intermediate_device) + # Concatenate the first 4 channels of partial_img_emb with new_init_image to match the required shape + if new_init_image.shape[0] + 4 == partial_img_emb.shape[0]: + partial_img_emb[:, 0] = torch.cat([ + image_cond[:4, 0], + new_init_image + ], dim=0) + else: + # fallback to original assignment if shape matches + partial_img_emb[:, 0] = new_init_image + else: + new_init_image = image_cond[:, 0].to(intermediate_device) + partial_img_emb[:, 0] = new_init_image + + if control_latents is not None: + partial_control_latents = control_latents[:, c] + + partial_control_camera_latents = None + if control_camera_latents is not None: + partial_control_camera_latents = control_camera_latents[:, :, c] + + partial_vace_context = None + if vace_data is not None: + window_vace_data = [] + for vace_entry in 
vace_data: + partial_context = vace_entry["context"][0][:, c] + if has_ref: + partial_context[:, 0] = vace_entry["context"][0][:, 0] + + window_vace_data.append({ + "context": [partial_context], + "scale": vace_entry["scale"], + "start": vace_entry["start"], + "end": vace_entry["end"], + "seq_len": vace_entry["seq_len"] + }) + + partial_vace_context = window_vace_data + + partial_audio_proj = None + if fantasytalking_embeds is not None: + partial_audio_proj = audio_proj[:, c] + + partial_latent_model_input = latent_model_input[:, c] + if latents_to_insert is not None and c[0] != 0: + partial_latent_model_input[:, :1] = latents_to_insert + + partial_unianim_data = None + if unianim_data is not None: + partial_dwpose = dwpose_data[:, :, c] + partial_dwpose_flat=rearrange(partial_dwpose, 'b c f h w -> b (f h w) c') + partial_unianim_data = { + "dwpose": partial_dwpose_flat, + "random_ref": unianim_data["random_ref"], + "strength": unianimate_poses["strength"], + "start_percent": unianimate_poses["start_percent"], + "end_percent": unianimate_poses["end_percent"] + } + + partial_add_cond = None + if add_cond is not None: + partial_add_cond = add_cond[:, :, c].to(device, dtype) + + if len(timestep.shape) != 1: + partial_timestep = timestep[:, c] + partial_timestep[:, :1] = 0 + else: + partial_timestep = timestep + #print("Partial timestep:", partial_timestep) + + noise_pred_context, new_teacache = predict_with_cfg( + partial_latent_model_input, + cfg[idx], positive, + text_embeds["negative_prompt_embeds"], + partial_timestep, idx, partial_img_emb, clip_fea, partial_control_latents, partial_vace_context, partial_unianim_data,partial_audio_proj, + partial_control_camera_latents, partial_add_cond, current_teacache, context_window=c) + + if cache_args is not None: + self.window_tracker.cache_states[window_id] = new_teacache + + window_mask = create_window_mask(noise_pred_context, c, latent_video_length, context_overlap, looped=is_looped, window_type=context_options["fuse_method"]) + noise_pred[:, c] += noise_pred_context * window_mask + counter[:, c] += window_mask + context_pbar.update_absolute(step_start_progress + (i + 1) * fraction_per_context, steps) + noise_pred /= counter + #region multitalk + elif multitalk_sampling: + original_image = cond_image = image_embeds.get("multitalk_start_image", None) + offload = image_embeds.get("force_offload", False) + tiled_vae = image_embeds.get("tiled_vae", False) + frame_num = clip_length = image_embeds.get("num_frames", 81) + vae = image_embeds.get("vae", None) + clip_embeds = image_embeds.get("clip_context", None) + colormatch = image_embeds.get("colormatch", "disabled") + motion_frame = image_embeds.get("motion_frame", 25) + target_w = image_embeds.get("target_w", None) + target_h = image_embeds.get("target_h", None) + + gen_video_list = [] + is_first_clip = True + arrive_last_frame = False + cur_motion_frames_num = 1 + audio_start_idx = iteration_count = 0 + audio_end_idx = audio_start_idx + clip_length + indices = (torch.arange(4 + 1) - 2) * 1 + + if multitalk_embeds is not None: + total_frames = len(multitalk_audio_embedding) + + estimated_iterations = total_frames // (frame_num - motion_frame) + 1 + loop_pbar = tqdm(total=estimated_iterations, desc="Generating video clips") + callback = prepare_callback(patcher, estimated_iterations) + + audio_embedding = multitalk_audio_embedding + human_num = len(audio_embedding) + audio_embs = None + while True: # start video generation iteratively + if multitalk_embeds is not None: + audio_embs = [] + # split 
audio with window size + for human_idx in range(human_num): + center_indices = torch.arange(audio_start_idx, audio_end_idx, 1).unsqueeze(1) + indices.unsqueeze(0) + center_indices = torch.clamp(center_indices, min=0, max=audio_embedding[human_idx].shape[0]-1) + audio_emb = audio_embedding[human_idx][center_indices].unsqueeze(0).to(device) + audio_embs.append(audio_emb) + audio_embs = torch.concat(audio_embs, dim=0).to(dtype) + + h, w = cond_image.shape[-2], cond_image.shape[-1] + lat_h, lat_w = h // VAE_STRIDE[1], w // VAE_STRIDE[2] + seq_len = ((frame_num - 1) // VAE_STRIDE[0] + 1) * lat_h * lat_w // (PATCH_SIZE[1] * PATCH_SIZE[2]) + + noise = torch.randn( + 16, (frame_num - 1) // 4 + 1, + lat_h, lat_w, dtype=torch.float32, device=torch.device("cpu"), generator=seed_g).to(device) + + # get mask + msk = torch.ones(1, frame_num, lat_h, lat_w, device=device) + msk[:, cur_motion_frames_num:] = 0 + msk = torch.concat([ + torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:] + ], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w) + msk = msk.transpose(1, 2).to(dtype) # B 4 T H W + + mm.soft_empty_cache() + + # zero padding and vae encode + video_frames = torch.zeros(1, cond_image.shape[1], frame_num-cond_image.shape[2], target_h, target_w, device=device, dtype=vae.dtype) + padding_frames_pixels_values = torch.concat([cond_image.to(device, vae.dtype), video_frames], dim=2) + + vae.to(device) + y = vae.encode(padding_frames_pixels_values, device=device, tiled=tiled_vae).to(dtype) + vae.to(offload_device) + + cur_motion_frames_latent_num = int(1 + (cur_motion_frames_num-1) // 4) + latent_motion_frames = y[:, :, :cur_motion_frames_latent_num][0] # C T H W + y = torch.concat([msk, y], dim=1) # B 4+C T H W + mm.soft_empty_cache() + + if scheduler == "multitalk": + timesteps = list(np.linspace(1000, 1, steps, dtype=np.float32)) + timesteps.append(0.) 
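                            # The "multitalk" schedule is built by hand: linear spacing from 1000 to 1, a terminal 0,
                            # then timestep_transform() applies the flow shift (presumably t' = shift*u / (1 + (shift - 1)*u)
                            # with u = t / 1000), which concentrates steps at the high-noise end.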
+ timesteps = [torch.tensor([t], device=device) for t in timesteps] + timesteps = [timestep_transform(t, shift=shift, num_timesteps=1000) for t in timesteps] + else: + sample_scheduler, timesteps = get_scheduler(scheduler, steps, shift, device, transformer.dim, flowedit_args, denoise_strength, sigmas=sigmas) + + transformed_timesteps = [] + for t in timesteps: + t_tensor = torch.tensor([t.item()], device=device) + transformed_timesteps.append(t_tensor) + + transformed_timesteps.append(torch.tensor([0.], device=device)) + timesteps = transformed_timesteps + + # sample videos + latent = noise + + # injecting motion frames + if not is_first_clip: + latent_motion_frames = latent_motion_frames.to(latent.dtype).to(device) + motion_add_noise = torch.randn(latent_motion_frames.shape, device=torch.device("cpu"), generator=seed_g).to(device).contiguous() + add_latent = add_noise(latent_motion_frames, motion_add_noise, timesteps[0]) + _, T_m, _, _ = add_latent.shape + latent[:, :T_m] = add_latent + + if offload: + #blockswap init + if transformer_options is not None: + block_swap_args = transformer_options.get("block_swap_args", None) + + if block_swap_args is not None: + transformer.use_non_blocking = block_swap_args.get("use_non_blocking", False) + for name, param in transformer.named_parameters(): + if "block" not in name: + param.data = param.data.to(device) + if "control_adapter" in name: + param.data = param.data.to(device) + elif block_swap_args["offload_txt_emb"] and "txt_emb" in name: + param.data = param.data.to(offload_device) + elif block_swap_args["offload_img_emb"] and "img_emb" in name: + param.data = param.data.to(offload_device) + + transformer.block_swap( + block_swap_args["blocks_to_swap"] - 1 , + block_swap_args["offload_txt_emb"], + block_swap_args["offload_img_emb"], + vace_blocks_to_swap = block_swap_args.get("vace_blocks_to_swap", None), + ) + + elif model["auto_cpu_offload"]: + for module in transformer.modules(): + if hasattr(module, "offload"): + module.offload() + if hasattr(module, "onload"): + module.onload() + elif model["manual_offloading"]: + transformer.to(device) + + comfy_pbar = ProgressBar(len(timesteps)-1) + for i in tqdm(range(len(timesteps)-1)): + timestep = timesteps[i] + latent_model_input = latent.to(device) + + noise_pred, self.cache_state = predict_with_cfg( + latent_model_input, + cfg[idx], + text_embeds["prompt_embeds"], + text_embeds["negative_prompt_embeds"], + timestep, idx, y.squeeze(0), clip_embeds.to(dtype), control_latents, vace_data, unianim_data, audio_proj, control_camera_latents, add_cond, + cache_state=self.cache_state, multitalk_audio_embeds=audio_embs) + + if callback is not None: + callback_latent = (latent_model_input.to(device) - noise_pred.to(device) * t.to(device) / 1000).detach().permute(1,0,2,3) + callback(iteration_count, callback_latent, None, estimated_iterations) + + # update latent + if scheduler == "multitalk": + noise_pred = -noise_pred + dt = timesteps[i] - timesteps[i + 1] + dt = dt / 1000 + latent = latent + noise_pred * dt[:, None, None, None] + else: + latent = latent.to(intermediate_device) + + temp_x0 = sample_scheduler.step( + noise_pred.unsqueeze(0), + timestep, + latent.unsqueeze(0), + **scheduler_step_args)[0] + latent = temp_x0.squeeze(0) + + # injecting motion frames + if not is_first_clip: + latent_motion_frames = latent_motion_frames.to(latent.dtype).to(device) + motion_add_noise = torch.randn(latent_motion_frames.shape, device=torch.device("cpu"), generator=seed_g).to(device).contiguous() + add_latent = 
add_noise(latent_motion_frames, motion_add_noise, timesteps[i+1]) + _, T_m, _, _ = add_latent.shape + latent[:, :T_m] = add_latent + + x0 = latent.to(device) + del latent_model_input, timestep + comfy_pbar.update(1) + + if offload: + transformer.to(offload_device) + vae.to(device) + videos = vae.decode(x0.unsqueeze(0).to(vae.dtype), device=device, tiled=tiled_vae) + vae.to(offload_device) + + # cache generated samples + videos = torch.stack(videos).cpu() # B C T H W + if colormatch != "disabled": + videos = videos[0].permute(1, 2, 3, 0).cpu().float().numpy() + from color_matcher import ColorMatcher + cm = ColorMatcher() + cm_result_list = [] + for img in videos: + cm_result = cm.transfer(src=img, ref=original_image[0].permute(1, 2, 3, 0).squeeze(0).cpu().numpy(), method=colormatch) + cm_result_list.append(torch.from_numpy(cm_result)) + + videos = torch.stack(cm_result_list, dim=0).to(torch.float32).permute(3, 0, 1, 2).unsqueeze(0) + + if is_first_clip: + gen_video_list.append(videos) + else: + gen_video_list.append(videos[:, :, cur_motion_frames_num:]) + + # decide whether is done + if arrive_last_frame: + loop_pbar.update(estimated_iterations - iteration_count) + loop_pbar.close() + break + + # update next condition frames + is_first_clip = False + cur_motion_frames_num = motion_frame + + cond_image = videos[:, :, -cur_motion_frames_num:].to(torch.float32).to(device) + + # Update progress bar + iteration_count += 1 + loop_pbar.update(1) + + # Repeat audio emb + if multitalk_embeds is not None: + audio_start_idx += (frame_num - cur_motion_frames_num) + audio_end_idx = audio_start_idx + clip_length + if audio_end_idx >= len(audio_embedding[0]): + arrive_last_frame = True + miss_lengths = [] + source_frames = [] + for human_inx in range(human_num): + source_frame = len(audio_embedding[human_inx]) + source_frames.append(source_frame) + if audio_end_idx >= len(audio_embedding[human_inx]): + miss_length = audio_end_idx - len(audio_embedding[human_inx]) + 3 + add_audio_emb = torch.flip(audio_embedding[human_inx][-1*miss_length:], dims=[0]) + audio_embedding[human_inx] = torch.cat([audio_embedding[human_inx], add_audio_emb], dim=0) + miss_lengths.append(miss_length) + else: + miss_lengths.append(0) + + gen_video_samples = torch.cat(gen_video_list, dim=2).to(torch.float32) + + del noise, latent + if force_offload: + if model["manual_offloading"]: + transformer.to(offload_device) + mm.soft_empty_cache() + gc.collect() + try: + print_memory(device) + torch.cuda.reset_peak_memory_stats(device) + except: + pass + return {"video": gen_video_samples[0].permute(1, 2, 3, 0).cpu()}, + + #region normal inference + else: + noise_pred, self.cache_state = predict_with_cfg( + latent_model_input, + cfg[idx], + text_embeds["prompt_embeds"], + text_embeds["negative_prompt_embeds"], + timestep, idx, image_cond, clip_fea, control_latents, vace_data, unianim_data, audio_proj, control_camera_latents, add_cond, + cache_state=self.cache_state) + + if latent_shift_loop: + #reverse latent shift + if latent_shift_start_percent <= current_step_percentage <= latent_shift_end_percent: + noise_pred = torch.cat([noise_pred[:, latent_video_length - shift_idx:]] + [noise_pred[:, :latent_video_length - shift_idx]], dim=1) + shift_idx = (shift_idx + latent_skip) % latent_video_length + + + if flowedit_args is None: + latent = latent.to(intermediate_device) + + if len(timestep.shape) != 1 and scheduler != "flowmatch_pusa": #5b + # all_indices is a list of indices to skip + total_indices = list(range(latent.shape[1])) + 
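                                # Frames listed in all_indices carry injected latents (their per-frame timestep was
                                # pinned to 0 above); they are excluded from the scheduler step and spliced back in
                                # unchanged, so only the remaining frame indices are actually denoised.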
process_indices = [i for i in total_indices if i not in all_indices] + if process_indices: + latent_to_process = latent[:, process_indices] + noise_pred_to_process = noise_pred[:, process_indices] + latent_slice = sample_scheduler.step( + noise_pred_to_process.unsqueeze(0), + orig_timestep, + latent_to_process.unsqueeze(0), + **scheduler_step_args + )[0].squeeze(0) + # Reconstruct the latent tensor: keep skipped indices as-is, update others + new_latent = [] + for i in total_indices: + if i in all_indices: + new_latent.append(latent[:, i:i+1]) + else: + j = process_indices.index(i) + new_latent.append(latent_slice[:, j:j+1]) + latent = torch.cat(new_latent, dim=1) + else: + latent = sample_scheduler.step( + noise_pred[:, :orig_noise_len].unsqueeze(0) if recammaster is not None else noise_pred.unsqueeze(0), + timestep, + latent[:, :orig_noise_len].unsqueeze(0) if recammaster is not None else latent.unsqueeze(0), + **scheduler_step_args)[0].squeeze(0) + + if freeinit_args is not None: + current_latent = latent.clone() + + if callback is not None: + if recammaster is not None: + callback_latent = (latent_model_input[:, :orig_noise_len].to(device) - noise_pred[:, :orig_noise_len].to(device) * t.to(device) / 1000).detach() + elif phantom_latents is not None: + callback_latent = (latent_model_input[:,:-phantom_latents.shape[1]].to(device) - noise_pred[:,:-phantom_latents.shape[1]].to(device) * t.to(device) / 1000).detach() + else: + callback_latent = (latent_model_input.to(device) - noise_pred.to(device) * t.to(device) / 1000).detach() + callback(idx, callback_latent.permute(1,0,2,3), None, len(timesteps)) + else: + pbar.update(1) + else: + if callback is not None: + callback_latent = (zt_tgt.to(device) - vt_tgt.to(device) * t.to(device) / 1000).detach() + callback(idx, callback_latent.permute(1,0,2,3), None, len(timesteps)) + else: + pbar.update(1) + except Exception as e: + log.error(f"Error during sampling: {e}") + if force_offload: + if model["manual_offloading"]: + offload_transformer(transformer) + raise e + + if phantom_latents is not None: + latent = latent[:,:-phantom_latents.shape[1]] + + if cache_args is not None: + cache_report(transformer, cache_args) + + if force_offload: + if model["manual_offloading"]: + offload_transformer(transformer) + + try: + print_memory(device) + #torch.cuda.memory._dump_snapshot("wanvideowrapper_memory_dump.pt") + #torch.cuda.memory._record_memory_history(enabled=None) + torch.cuda.reset_peak_memory_stats(device) + except: + pass + return ({ + "samples": latent.unsqueeze(0).cpu(), + "looped": is_looped, + "end_image": end_image if not fun_or_fl2v_model else None, + "has_ref": has_ref, + "drop_last": drop_last, + "generator_state": seed_g.get_state(), + },{ + "samples": callback_latent.unsqueeze(0).cpu() if callback is not None else None, + }) + +#region VideoDecode +class WanVideoDecode: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "vae": ("WANVAE",), + "samples": ("LATENT",), + "enable_vae_tiling": ("BOOLEAN", {"default": False, "tooltip": ( + "Drastically reduces memory use but will introduce seams at tile stride boundaries. " + "The location and number of seams is dictated by the tile stride size. " + "The visibility of seams can be controlled by increasing the tile size. " + "Seams become less obvious at 1.5x stride and are barely noticeable at 2x stride size. " + "Which is to say if you use a stride width of 160, the seams are barely noticeable with a tile width of 320." 
+ )}), + "tile_x": ("INT", {"default": 272, "min": 40, "max": 2048, "step": 8, "tooltip": "Tile width in pixels. Smaller values use less VRAM but will make seams more obvious."}), + "tile_y": ("INT", {"default": 272, "min": 40, "max": 2048, "step": 8, "tooltip": "Tile height in pixels. Smaller values use less VRAM but will make seams more obvious."}), + "tile_stride_x": ("INT", {"default": 144, "min": 32, "max": 2040, "step": 8, "tooltip": "Tile stride width in pixels. Smaller values use less VRAM but will introduce more seams."}), + "tile_stride_y": ("INT", {"default": 128, "min": 32, "max": 2040, "step": 8, "tooltip": "Tile stride height in pixels. Smaller values use less VRAM but will introduce more seams."}), + }, + "optional": { + "normalization": (["default", "minmax"], {"advanced": True}), + } + } + + @classmethod + def VALIDATE_INPUTS(s, tile_x, tile_y, tile_stride_x, tile_stride_y): + if tile_x <= tile_stride_x: + return "Tile width must be larger than the tile stride width." + if tile_y <= tile_stride_y: + return "Tile height must be larger than the tile stride height." + return True + + RETURN_TYPES = ("IMAGE",) + RETURN_NAMES = ("images",) + FUNCTION = "decode" + CATEGORY = "WanVideoWrapper" + + def decode(self, vae, samples, enable_vae_tiling, tile_x, tile_y, tile_stride_x, tile_stride_y, normalization="default"): + mm.soft_empty_cache() + video = samples.get("video", None) + if video is not None: + video = torch.clamp(video, -1.0, 1.0) + video = (video + 1.0) / 2.0 + return video.cpu(), + latents = samples["samples"] + end_image = samples.get("end_image", None) + has_ref = samples.get("has_ref", False) + drop_last = samples.get("drop_last", False) + is_looped = samples.get("looped", False) + + vae.to(device) + + latents = latents.to(device = device, dtype = vae.dtype) + + mm.soft_empty_cache() + + if has_ref: + latents = latents[:, :, 1:] + if drop_last: + latents = latents[:, :, :-1] + + if type(vae).__name__ == "TAEHV": + images = vae.decode_video(latents.permute(0, 2, 1, 3, 4))[0].permute(1, 0, 2, 3) + images = torch.clamp(images, 0.0, 1.0) + images = images.permute(1, 2, 3, 0).cpu().float() + return (images,) + else: + if end_image is not None: + enable_vae_tiling = False + images = vae.decode(latents, device=device, end_=(end_image is not None), tiled=enable_vae_tiling, tile_size=(tile_x//8, tile_y//8), tile_stride=(tile_stride_x//8, tile_stride_y//8))[0] + vae.model.clear_cache() + + images = images.cpu().float() + + if normalization == "minmax": + images.sub_(images.min()).div_(images.max() - images.min()) + else: + images.clamp_(-1.0, 1.0) + images.add_(1.0).div_(2.0) + + if is_looped: + temp_latents = torch.cat([latents[:, :, -3:]] + [latents[:, :, :2]], dim=2) + temp_images = vae.decode(temp_latents, device=device, end_=(end_image is not None), tiled=enable_vae_tiling, tile_size=(tile_x//vae.upsampling_factor, tile_y//vae.upsampling_factor), tile_stride=(tile_stride_x//vae.upsampling_factor, tile_stride_y//vae.upsampling_factor))[0] + temp_images = temp_images.cpu().float() + temp_images = (temp_images - temp_images.min()) / (temp_images.max() - temp_images.min()) + images = torch.cat([temp_images[:, 9:].to(images), images[:, 5:]], dim=1) + + if end_image is not None: + images = images[:, 0:-1] + + vae.model.clear_cache() + vae.to(offload_device) + mm.soft_empty_cache() + + images.clamp_(0.0, 1.0) + + return (images.permute(1, 2, 3, 0),) + +#region VideoEncode +class WanVideoEncode: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "vae": ("WANVAE",), + 
"image": ("IMAGE",), + "enable_vae_tiling": ("BOOLEAN", {"default": False, "tooltip": "Drastically reduces memory use but may introduce seams"}), + "tile_x": ("INT", {"default": 272, "min": 64, "max": 2048, "step": 1, "tooltip": "Tile size in pixels, smaller values use less VRAM, may introduce more seams"}), + "tile_y": ("INT", {"default": 272, "min": 64, "max": 2048, "step": 1, "tooltip": "Tile size in pixels, smaller values use less VRAM, may introduce more seams"}), + "tile_stride_x": ("INT", {"default": 144, "min": 32, "max": 2048, "step": 32, "tooltip": "Tile stride in pixels, smaller values use less VRAM, may introduce more seams"}), + "tile_stride_y": ("INT", {"default": 128, "min": 32, "max": 2048, "step": 32, "tooltip": "Tile stride in pixels, smaller values use less VRAM, may introduce more seams"}), + }, + "optional": { + "noise_aug_strength": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Strength of noise augmentation, helpful for leapfusion I2V where some noise can add motion and give sharper results"}), + "latent_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Additional latent multiplier, helpful for leapfusion I2V where lower values allow for more motion"}), + "mask": ("MASK", ), + } + } + + RETURN_TYPES = ("LATENT",) + RETURN_NAMES = ("samples",) + FUNCTION = "encode" + CATEGORY = "WanVideoWrapper" + + def encode(self, vae, image, enable_vae_tiling, tile_x, tile_y, tile_stride_x, tile_stride_y, noise_aug_strength=0.0, latent_strength=1.0, mask=None): + vae.to(device) + + image = image.clone() + + B, H, W, C = image.shape + if W % 16 != 0 or H % 16 != 0: + new_height = (H // 16) * 16 + new_width = (W // 16) * 16 + log.warning(f"Image size {W}x{H} is not divisible by 16, resizing to {new_width}x{new_height}") + image = common_upscale(image.movedim(-1, 1), new_width, new_height, "lanczos", "disabled").movedim(1, -1) + + if image.shape[-1] == 4: + image = image[..., :3] + image = image.to(vae.dtype).to(device).unsqueeze(0).permute(0, 4, 1, 2, 3) # B, C, T, H, W + + if noise_aug_strength > 0.0: + image = add_noise_to_reference_video(image, ratio=noise_aug_strength) + + if isinstance(vae, TAEHV): + latents = vae.encode_video(image.permute(0, 2, 1, 3, 4), parallel=False)# B, T, C, H, W + latents = latents.permute(0, 2, 1, 3, 4) + else: + latents = vae.encode(image * 2.0 - 1.0, device=device, tiled=enable_vae_tiling, tile_size=(tile_x//vae.upsampling_factor, tile_y//vae.upsampling_factor), tile_stride=(tile_stride_x//vae.upsampling_factor, tile_stride_y//vae.upsampling_factor)) + vae.model.clear_cache() + if latent_strength != 1.0: + latents *= latent_strength + + log.info(f"encoded latents shape {latents.shape}") + latent_mask = None + if mask is None: + vae.to(offload_device) + else: + target_h, target_w = latents.shape[3:] + + mask = torch.nn.functional.interpolate( + mask.unsqueeze(0).unsqueeze(0), # Add batch and channel dims [1,1,T,H,W] + size=(latents.shape[2], target_h, target_w), + mode='trilinear', + align_corners=False + ).squeeze(0) # Remove batch dim, keep channel dim + + # Add batch & channel dims for final output + latent_mask = mask.unsqueeze(0).repeat(1, latents.shape[1], 1, 1, 1) + log.info(f"latent mask shape {latent_mask.shape}") + vae.to(offload_device) + mm.soft_empty_cache() + + return ({"samples": latents, "mask": latent_mask},) + +class WanVideoLatentReScale: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "samples": ("LATENT",), + "direction": (["comfy_to_wrapper", 
"wrapper_to_comfy"], {"tooltip": "Direction to rescale latents, from comfy to wrapper or vice versa"}), + } + } + + RETURN_TYPES = ("LATENT",) + RETURN_NAMES = ("samples",) + FUNCTION = "encode" + CATEGORY = "WanVideoWrapper" + DESCRIPTION = "Rescale latents to match the expected range for encoding or decoding. Can be used to " + + def encode(self, samples, direction): + samples = samples.copy() + latents = samples["samples"] + + mean = [ + -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, + 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921 + ] + std = [ + 2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, + 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160 + ] + mean = torch.tensor(mean).view(1, latents.shape[1], 1, 1, 1) + std = torch.tensor(std).view(1, latents.shape[1], 1, 1, 1) + inv_std = (1.0 / std).view(1, latents.shape[1], 1, 1, 1) + if direction == "comfy_to_wrapper": + latents = (latents - mean.to(latents)) * inv_std.to(latents) + elif direction == "wrapper_to_comfy": + latents = latents / inv_std.to(latents) + mean.to(latents) + + samples["samples"] = latents + + return (samples,) + +NODE_CLASS_MAPPINGS = { + "WanVideoSampler": WanVideoSampler, + "WanVideoDecode": WanVideoDecode, + "WanVideoTextEncode": WanVideoTextEncode, + "WanVideoTextEncodeSingle": WanVideoTextEncodeSingle, + "WanVideoClipVisionEncode": WanVideoClipVisionEncode, + "WanVideoImageToVideoEncode": WanVideoImageToVideoEncode, + "WanVideoEncode": WanVideoEncode, + "WanVideoEmptyEmbeds": WanVideoEmptyEmbeds, + "WanVideoEnhanceAVideo": WanVideoEnhanceAVideo, + "WanVideoContextOptions": WanVideoContextOptions, + "WanVideoTextEmbedBridge": WanVideoTextEmbedBridge, + "WanVideoFlowEdit": WanVideoFlowEdit, + "WanVideoControlEmbeds": WanVideoControlEmbeds, + "WanVideoSLG": WanVideoSLG, + "WanVideoLoopArgs": WanVideoLoopArgs, + "WanVideoSetBlockSwap": WanVideoSetBlockSwap, + "WanVideoExperimentalArgs": WanVideoExperimentalArgs, + "WanVideoVACEEncode": WanVideoVACEEncode, + "WanVideoPhantomEmbeds": WanVideoPhantomEmbeds, + "WanVideoRealisDanceLatents": WanVideoRealisDanceLatents, + "WanVideoApplyNAG": WanVideoApplyNAG, + "WanVideoMiniMaxRemoverEmbeds": WanVideoMiniMaxRemoverEmbeds, + "WanVideoFreeInitArgs": WanVideoFreeInitArgs, + "WanVideoSetRadialAttention": WanVideoSetRadialAttention, + "WanVideoBlockList": WanVideoBlockList, + "WanVideoTextEncodeCached": WanVideoTextEncodeCached, + "WanVideoAddExtraLatent": WanVideoAddExtraLatent, + "WanVideoLatentReScale": WanVideoLatentReScale, + } +NODE_DISPLAY_NAME_MAPPINGS = { + "WanVideoSampler": "WanVideo Sampler", + "WanVideoDecode": "WanVideo Decode", + "WanVideoTextEncode": "WanVideo TextEncode", + "WanVideoTextEncodeSingle": "WanVideo TextEncodeSingle", + "WanVideoTextImageEncode": "WanVideo TextImageEncode (IP2V)", + "WanVideoClipVisionEncode": "WanVideo ClipVision Encode", + "WanVideoImageToVideoEncode": "WanVideo ImageToVideo Encode", + "WanVideoEncode": "WanVideo Encode", + "WanVideoEmptyEmbeds": "WanVideo Empty Embeds", + "WanVideoEnhanceAVideo": "WanVideo Enhance-A-Video", + "WanVideoContextOptions": "WanVideo Context Options", + "WanVideoTextEmbedBridge": "WanVideo TextEmbed Bridge", + "WanVideoFlowEdit": "WanVideo FlowEdit", + "WanVideoControlEmbeds": "WanVideo Control Embeds", + "WanVideoSLG": "WanVideo SLG", + "WanVideoLoopArgs": "WanVideo Loop Args", + "WanVideoSetBlockSwap": "WanVideo Set BlockSwap", + "WanVideoExperimentalArgs": "WanVideo Experimental Args", + "WanVideoVACEEncode": "WanVideo VACE 
Encode", + "WanVideoPhantomEmbeds": "WanVideo Phantom Embeds", + "WanVideoRealisDanceLatents": "WanVideo RealisDance Latents", + "WanVideoApplyNAG": "WanVideo Apply NAG", + "WanVideoMiniMaxRemoverEmbeds": "WanVideo MiniMax Remover Embeds", + "WanVideoFreeInitArgs": "WanVideo Free Init Args", + "WanVideoSetRadialAttention": "WanVideo Set Radial Attention", + "WanVideoBlockList": "WanVideo Block List", + "WanVideoTextEncodeCached": "WanVideo TextEncode Cached", + "WanVideoAddExtraLatent": "WanVideo Add Extra Latent", + "WanVideoLatentReScale": "WanVideo Latent ReScale", + } diff --git a/wanvideo/distributed/__init__.py b/wanvideo/distributed/__init__.py new file mode 100644 index 00000000..13a6b25a --- /dev/null +++ b/wanvideo/distributed/__init__.py @@ -0,0 +1,3 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Alibaba, Inc. and its affiliates. +from .xdit_context_parallel import pad_freqs, rope_apply, usp_dit_forward_vace, usp_dit_forward, usp_attn_forward \ No newline at end of file diff --git a/wanvideo/distributed/xdit_context_parallel.py b/wanvideo/distributed/xdit_context_parallel.py new file mode 100644 index 00000000..aacf47fa --- /dev/null +++ b/wanvideo/distributed/xdit_context_parallel.py @@ -0,0 +1,227 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Alibaba, Inc. and its affiliates. +import torch +import torch.cuda.amp as amp +from xfuser.core.distributed import (get_sequence_parallel_rank, + get_sequence_parallel_world_size, + get_sp_group) +from xfuser.core.long_ctx_attention import xFuserLongContextAttention + +from ..modules.model import sinusoidal_embedding_1d + + +def pad_freqs(original_tensor, target_len): + seq_len, s1, s2 = original_tensor.shape + pad_size = target_len - seq_len + padding_tensor = torch.ones( + pad_size, + s1, + s2, + dtype=original_tensor.dtype, + device=original_tensor.device) + padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0) + return padded_tensor + + +@amp.autocast(enabled=False) +def rope_apply(x, grid_sizes, freqs): + """ + x: [B, L, N, C]. + grid_sizes: [B, 3]. + freqs: [M, C // 2]. 
+ """ + s, n, c = x.size(1), x.size(2), x.size(3) // 2 + # split freqs + freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) + + # loop over samples + output = [] + for i, (f, h, w) in enumerate(grid_sizes.tolist()): + seq_len = f * h * w + + # precompute multipliers + x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape( + s, n, -1, 2)) + freqs_i = torch.cat([ + freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), + freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) + ], + dim=-1).reshape(seq_len, 1, -1) + + # apply rotary embedding + sp_size = get_sequence_parallel_world_size() + sp_rank = get_sequence_parallel_rank() + freqs_i = pad_freqs(freqs_i, s * sp_size) + s_per_rank = s + freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) * + s_per_rank), :, :] + x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2) + x_i = torch.cat([x_i, x[i, s:]]) + + # append to collection + output.append(x_i) + return torch.stack(output).float() + + +def usp_dit_forward_vace( + self, + x, + vace_context, + seq_len, + kwargs +): + # embeddings + c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context] + c = [u.flatten(2).transpose(1, 2) for u in c] + c = torch.cat([ + torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], + dim=1) for u in c + ]) + + # arguments + new_kwargs = dict(x=x) + new_kwargs.update(kwargs) + + # Context Parallel + c = torch.chunk( + c, get_sequence_parallel_world_size(), + dim=1)[get_sequence_parallel_rank()] + + for block in self.vace_blocks: + c = block(c, **new_kwargs) + hints = torch.unbind(c)[:-1] + return hints + + +def usp_dit_forward( + self, + x, + t, + vace_context, + context, + seq_len, + vace_context_scale=1.0, + clip_fea=None, + y=None, +): + """ + x: A list of videos each with shape [C, T, H, W]. + t: [B]. + context: A list of text embeddings each with shape [L, C]. 
+ """ + # params + device = self.patch_embedding.weight.device + if self.freqs.device != device: + self.freqs = self.freqs.to(device) + + # if y is not None: + # x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] + + # embeddings + x = [self.patch_embedding(u.unsqueeze(0)) for u in x] + grid_sizes = torch.stack( + [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) + x = [u.flatten(2).transpose(1, 2) for u in x] + seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) + assert seq_lens.max() <= seq_len + x = torch.cat([ + torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1) + for u in x + ]) + + # time embeddings + with amp.autocast(dtype=torch.float32): + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, t).float()) + e0 = self.time_projection(e).unflatten(1, (6, self.dim)) + assert e.dtype == torch.float32 and e0.dtype == torch.float32 + + # context + context_lens = None + context = self.text_embedding( + torch.stack([ + torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) + for u in context + ])) + + # if clip_fea is not None: + # context_clip = self.img_emb(clip_fea) # bs x 257 x dim + # context = torch.concat([context_clip, context], dim=1) + + # arguments + kwargs = dict( + e=e0, + seq_lens=seq_lens, + grid_sizes=grid_sizes, + freqs=self.freqs, + context=context, + context_lens=context_lens) + + # Context Parallel + x = torch.chunk( + x, get_sequence_parallel_world_size(), + dim=1)[get_sequence_parallel_rank()] + + hints = self.forward_vace(x, vace_context, seq_len, kwargs) + kwargs['hints'] = hints + kwargs['context_scale'] = vace_context_scale + + for block in self.blocks: + x = block(x, **kwargs) + + # head + x = self.head(x, e) + + # Context Parallel + x = get_sp_group().all_gather(x, dim=1) + + # unpatchify + x = self.unpatchify(x, grid_sizes) + return [u.float() for u in x] + + +def usp_attn_forward(self, + x, + seq_lens, + grid_sizes, + freqs, + dtype=torch.bfloat16): + b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim + half_dtypes = (torch.float16, torch.bfloat16) + + def half(x): + return x if x.dtype in half_dtypes else x.to(dtype) + + # query, key, value function + def qkv_fn(x): + q = self.norm_q(self.q(x)).view(b, s, n, d) + k = self.norm_k(self.k(x)).view(b, s, n, d) + v = self.v(x).view(b, s, n, d) + return q, k, v + + q, k, v = qkv_fn(x) + q = rope_apply(q, grid_sizes, freqs) + k = rope_apply(k, grid_sizes, freqs) + + # TODO: We should use unpaded q,k,v for attention. + # k_lens = seq_lens // get_sequence_parallel_world_size() + # if k_lens is not None: + # q = torch.cat([u[:l] for u, l in zip(q, k_lens)]).unsqueeze(0) + # k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0) + # v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0) + + x = xFuserLongContextAttention()( + None, + query=half(q), + key=half(k), + value=half(v), + window_size=self.window_size) + + # TODO: padding after attention. 
+ # x = torch.cat([x, x.new_zeros(b, s - x.size(1), n, d)], dim=1) + + # output + x = x.flatten(2) + x = self.o(x) + return x diff --git a/wanvideo/framepack_vace.py b/wanvideo/framepack_vace.py new file mode 100644 index 00000000..44e8f34f --- /dev/null +++ b/wanvideo/framepack_vace.py @@ -0,0 +1,824 @@ +import os +import sys +import gc +import math +import time +import random +import types +import logging +import traceback +from contextlib import contextmanager +from functools import partial +import time +from PIL import Image +import numpy as np +import torchvision.transforms.functional as TF +import torch +import torch.nn.functional as F +import torch.cuda.amp as amp +import torch.distributed as dist +import torch.multiprocessing as mp +from tqdm import tqdm + +from wan.text2video import (WanT2V, T5EncoderModel, WanVAE, shard_model, FlowDPMSolverMultistepScheduler, + get_sampling_sigmas, retrieve_timesteps, FlowUniPCMultistepScheduler) +from .modules.model import VaceWanModel +from ..utils.preprocessor import VaceVideoProcessor + + +class FramepackVace(WanT2V): + def __init__( + self, + config, + checkpoint_dir, + device_id=0, + rank=0, + t5_fsdp=False, + dit_fsdp=False, + use_usp=False, + t5_cpu=False, + ): + r""" + Initializes the Wan text-to-video generation model components. + Args: + config (EasyDict): + Object containing model parameters initialized from config.py + checkpoint_dir (`str`): + Path to directory containing model checkpoints + device_id (`int`, *optional*, defaults to 0): + Id of target GPU device + rank (`int`, *optional*, defaults to 0): + Process rank for distributed training + t5_fsdp (`bool`, *optional*, defaults to False): + Enable FSDP sharding for T5 model + dit_fsdp (`bool`, *optional*, defaults to False): + Enable FSDP sharding for DiT model + use_usp (`bool`, *optional*, defaults to False): + Enable distribution strategy of USP. + t5_cpu (`bool`, *optional*, defaults to False): + Whether to place T5 model on CPU. Only works without t5_fsdp. 
+ """ + self.device = torch.device(f"cuda:{device_id}") + self.config = config + self.rank = rank + self.t5_cpu = t5_cpu + + self.num_train_timesteps = config.num_train_timesteps + self.param_dtype = config.param_dtype + + shard_fn = partial(shard_model, device_id=device_id) + self.text_encoder = T5EncoderModel( + text_len=config.text_len, + dtype=config.t5_dtype, + device=torch.device('cpu'), + checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint), + tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer), + + shard_fn=shard_fn if t5_fsdp else None) + + self.vae_stride = config.vae_stride + self.patch_size = config.patch_size + self.vae = WanVAE( + vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), + device=self.device) + + logging.info(f"Creating VaceWanModel from {checkpoint_dir}") + self.model = VaceWanModel.from_pretrained(checkpoint_dir) + self.model.eval().requires_grad_(False) + + if use_usp: + from xfuser.core.distributed import \ + get_sequence_parallel_world_size + + from .distributed.xdit_context_parallel import (usp_attn_forward, + usp_dit_forward, + usp_dit_forward_vace) + for block in self.model.blocks: + block.self_attn.forward = types.MethodType( + usp_attn_forward, block.self_attn) + for block in self.model.vace_blocks: + block.self_attn.forward = types.MethodType( + usp_attn_forward, block.self_attn) + self.model.forward = types.MethodType(usp_dit_forward, self.model) + self.model.forward_vace = types.MethodType(usp_dit_forward_vace, self.model) + self.sp_size = get_sequence_parallel_world_size() + else: + self.sp_size = 1 + + if dist.is_initialized(): + dist.barrier() + if dit_fsdp: + self.model = shard_fn(self.model) + else: + self.model.to(self.device) + + self.sample_neg_prompt = config.sample_neg_prompt + + self.vid_proc = VaceVideoProcessor(downsample=tuple([x * y for x, y in zip(config.vae_stride, self.patch_size)]), + min_area=480 * 832, + max_area=480 * 832, + min_fps=self.config.sample_fps, + max_fps=self.config.sample_fps, + zero_start=True, + seq_len=32760, + keep_last=True) + + def vace_encode_frames(self, frames, ref_images, masks=None, vae=None): + vae = self.vae if vae is None else vae + if ref_images is None: + ref_images = [None] * len(frames) + else: + assert len(frames) == len(ref_images) + + if masks is None: + latents = vae.encode(frames) + else: + masks = [torch.where(m > 0.5, 1.0, 0.0) for m in masks] + inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)] + reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)] + inactive = vae.encode(inactive) + reactive = vae.encode(reactive) + latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)] + + cat_latents = [] + for latent, refs in zip(latents, ref_images): + if refs is not None: + if masks is None: + ref_latent = vae.encode(refs) + else: + ref_latent = vae.encode(refs) + ref_latent = [torch.cat((u, torch.zeros_like(u)), dim=0) for u in ref_latent] + assert all([x.shape[1] == 1 for x in ref_latent]) + latent = torch.cat([*ref_latent, latent], dim=1) + cat_latents.append(latent) + return cat_latents + + def vace_encode_masks(self, masks, ref_images=None, vae_stride=None): + vae_stride = self.vae_stride if vae_stride is None else vae_stride + if ref_images is None: + ref_images = [None] * len(masks) + else: + assert len(masks) == len(ref_images) + + result_masks = [] + for mask, refs in zip(masks, ref_images): + c, depth, height, width = mask.shape + new_depth = int((depth + 3) // vae_stride[0]) + height = 2 * (int(height) // 
(vae_stride[1] * 2)) + width = 2 * (int(width) // (vae_stride[2] * 2)) + + # reshape + mask = mask[0, :, :, :] + mask = mask.view( + depth, height, vae_stride[1], width, vae_stride[1] + ) # depth, height, 8, width, 8 + mask = mask.permute(2, 4, 0, 1, 3) # 8, 8, depth, height, width + mask = mask.reshape( + vae_stride[1] * vae_stride[2], depth, height, width + ) # 8*8, depth, height, width + + # interpolation + mask = F.interpolate(mask.unsqueeze(0), size=(new_depth, height, width), mode='nearest-exact').squeeze(0) + + if refs is not None: + length = len(refs) + mask_pad = torch.zeros_like(mask[:, :length, :, :]) + mask = torch.cat((mask_pad, mask), dim=1) + result_masks.append(mask) + return result_masks + + def vace_latent(self, z, m): + return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)] + + def prepare_source(self, src_video, src_mask, src_ref_images, num_frames, image_size, device): + area = image_size[0] * image_size[1] + self.vid_proc.set_area(area) + if area == 720*1280: + self.vid_proc.set_seq_len(75600) + elif area == 480*832: + self.vid_proc.set_seq_len(32760) + else: + raise NotImplementedError(f'image_size {image_size} is not supported') + + image_size = (image_size[1], image_size[0]) + image_sizes = [] + for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)): + if sub_src_mask is not None and sub_src_video is not None: + src_video[i], src_mask[i], _, _, _ = self.vid_proc.load_video_pair(sub_src_video, sub_src_mask) + src_video[i] = src_video[i].to(device) + src_mask[i] = src_mask[i].to(device) + src_mask[i] = torch.clamp((src_mask[i][:1, :, :, :] + 1) / 2, min=0, max=1) + image_sizes.append(src_video[i].shape[2:]) + elif sub_src_video is None: + src_video[i] = torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device) + src_mask[i] = torch.ones_like(src_video[i], device=device) + image_sizes.append(image_size) + else: + src_video[i], _, _, _ = self.vid_proc.load_video(sub_src_video) + src_video[i] = src_video[i].to(device) + src_mask[i] = torch.ones_like(src_video[i], device=device) + image_sizes.append(src_video[i].shape[2:]) + + for i, ref_images in enumerate(src_ref_images): + if ref_images is not None: + image_size = image_sizes[i] + for j, ref_img in enumerate(ref_images): + if ref_img is not None: + ref_img = Image.open(ref_img).convert("RGB") + ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1) + if ref_img.shape[-2:] != image_size: + canvas_height, canvas_width = image_size + ref_height, ref_width = ref_img.shape[-2:] + white_canvas = torch.ones((3, 1, canvas_height, canvas_width), device=device) # [-1, 1] + scale = min(canvas_height / ref_height, canvas_width / ref_width) + new_height = int(ref_height * scale) + new_width = int(ref_width * scale) + resized_image = F.interpolate(ref_img.squeeze(1).unsqueeze(0), size=(new_height, new_width), mode='bilinear', align_corners=False).squeeze(0).unsqueeze(1) + top = (canvas_height - new_height) // 2 + left = (canvas_width - new_width) // 2 + white_canvas[:, :, top:top + new_height, left:left + new_width] = resized_image + ref_img = white_canvas + src_ref_images[i][j] = ref_img.to(device) + + return src_video, src_mask, src_ref_images + + def prepare_video(self, src_video): + import decord + decord.bridge.set_bridge('torch') + reader = decord.VideoReader(src_video) + + total_frames = len(reader) + fps = reader.get_avg_fps() + num_frames=None + # Get frame indices + if num_frames is None: + frame_ids = list(range(total_frames)) + else: + # Sample frames evenly + 
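+            # num_frames is fixed to None above, so this branch is not taken in the
+            # current implementation; when enabled, linspace picks num_frames indices
+            # spread evenly across [0, total_frames - 1].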
frame_ids = np.linspace(0, total_frames-1, num_frames, dtype=int).tolist() + + # Load frames + video = reader.get_batch(frame_ids) # [T, H, W, C] + video = video.permute(3, 0, 1, 2) # [C, T, H, W] + + # Convert to float and normalize + video = video.float() + C, T, H, W = video.shape + chunk_length=81 + video = video / 255.0 + C, T, H, W = video.shape + usable_frames = (T // chunk_length) * chunk_length + video = video[:, :usable_frames, :, :] # Trim excess frames + + chunks = [] + for i in range(0, usable_frames, chunk_length): + chunk = video[:, i:i+chunk_length, :, :] + chunks.append(chunk) + + return chunks + def decode_latent(self, zs, ref_images=None, vae=None): + vae = self.vae if vae is None else vae + + # No need to check ref_images length or trim anymore + return vae.decode(zs) + + + def generate_with_framepack(self, + input_prompt, + input_frames, + input_masks, + input_ref_images, + + + size=(1280, 720), + frame_num=41, + context_scale=1.0, + shift=5.0, + sample_solver='dpm++', + sampling_steps=20, + guide_scale=5.0, + n_prompt="", + seed=-1, + offload_model=True): + """ + Generates long videos using hierarchical context with frame packing. + + Major changes: + 1. Fixed context propagation between sections + 2. Improved hierarchical frame selection + 3. Better mask generation for consistent 22-frame structure + 4. Enhanced debugging and visualization + """ + + LATENT_WINDOW = 41 + GENERATION_FRAMES = 30 + CONTEXT_FRAMES = 11 + # frame_num=300 + section_window = 41 + section_num = math.ceil(frame_num / section_window) + + + all_generated_latents = [] + accumulated_latents = [] + context_buffer = None + + print(f'Total frames requested: {frame_num}') + print(f'Total sections to generate: {section_num}') + print(f'Latent structure: {CONTEXT_FRAMES} context + {GENERATION_FRAMES} generation = {LATENT_WINDOW} total') + + if n_prompt == "": + n_prompt = self.sample_neg_prompt + + # Base seed management + if seed == -1: + import time + base_seed = int(time.time() * 1000) % (1 << 32) + else: + base_seed = seed + self.model.to(self.device) + for section_id in range(section_num): + torch.cuda.synchronize() + print(f"\n{'='*60}") + print(f"SECTION {section_id+1} / {section_num}") + print(f"{'='*60}\n") + + def get_tensor_list_memory(tensor_list): + total_bytes = 0 + for tensor in tensor_list: + if isinstance(tensor, torch.Tensor): + total_bytes += tensor.numel() * tensor.element_size() + total_mb = total_bytes / (1024 ** 2) + total_gb = total_bytes / (1024 ** 3) + print(f"Total memory used by tensor list: {total_mb:.2f} MB ({total_gb:.4f} GB)") + + get_tensor_list_memory(accumulated_latents) + get_tensor_list_memory(all_generated_latents) + + # Create unique seed for each section + section_seed = base_seed + section_id * 1000 + section_generator = torch.Generator(device=self.device) + section_generator.manual_seed(section_seed) + + + if not self.t5_cpu: + self.text_encoder.model.to(self.device) + context = self.text_encoder([input_prompt], self.device) + context_null = self.text_encoder([n_prompt], self.device) + if offload_model: + self.text_encoder.model.cpu() + else: + context = self.text_encoder([input_prompt], torch.device('cpu')) + context_null = self.text_encoder([n_prompt], torch.device('cpu')) + context = [t.to(self.device) for t in context] + context_null = [t.to(self.device) for t in context_null] + + + if section_id == 0: + + print("First section - using input frames") + + current_frames=input_frames + current_masks=input_masks + current_ref_images = input_ref_images + 
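+                # The first section is conditioned directly on the caller-supplied
+                # frames/masks/reference images and starts at temporal offset 0;
+                # later sections rebuild their conditioning from the accumulated
+                # latents instead (see the else branch below).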
frame_offset = 0 + context_scale_section = context_scale + + else: + + print(f"Section {section_id} - building hierarchical context") + + + context_latent = self.build_hierarchical_context_latent( + accumulated_latents, section_id) + + + # context_decoded = self.decode_latent([context_latent], None) + + # get_tensor_list_memory(context_decoded) + # self.model.to(self.device) + + if section_id > 1: + appearance, motion = self.separate_appearance_and_motion(context_latent) + motion_noise = torch.randn_like(motion) * 0.3 + motion_perturbed = motion + motion_noise + context_decoded = [appearance + motion_perturbed * 0.5] + + + hierarchical_frames = self.pick_context_v2(context_latent, section_id) + current_frames = self.decode_latent([hierarchical_frames], None) + print('current frames shape', current_frames[0].shape ) + current_masks = self.create_temporal_blend_mask_v2( + current_frames[0].shape, section_id) + current_ref_images = None + print('current mask shape', current_masks[0].shape ) + + frame_offset = min(LATENT_WINDOW + (section_id - 1) * GENERATION_FRAMES, 100) + + + context_variation = 0.7 + torch.rand(1).item() * 0.6 + context_scale_section = context_scale * context_variation + + + z0 = self.vace_encode_frames(current_frames, current_ref_images, masks=current_masks) + m0 = self.vace_encode_masks(current_masks, current_ref_images) + z = self.vace_latent(z0, m0) + print(f"Context latent shape: {z0[0].shape}") + print(f"Context scale: {context_scale_section:.3f}") + print(f"Frame offset: {frame_offset}") + + + target_shape = list(z0[0].shape) + target_shape[0] = int(target_shape[0] / 2) + del z0,m0,current_frames,current_masks + noise_base = torch.randn( + target_shape[0], + target_shape[1], + target_shape[2], + target_shape[3], + dtype=torch.float32, + device=self.device, + generator=section_generator) + + if section_id > 0 and accumulated_latents: + noise = [noise_base] + + else: + noise = [noise_base] + + print(f"Noise shape: {noise[0].shape}") + + seq_len = math.ceil((target_shape[2] * target_shape[3]) / + (self.patch_size[1] * self.patch_size[2]) * + target_shape[1] / self.sp_size) * self.sp_size + + + @contextmanager + def noop_no_sync(): + yield + sample_solver ='dpm++' + no_sync = getattr(self.model, 'no_sync', noop_no_sync) + # sample_solver='dpm++' + # sampling_steps=20 + with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync(): + # Setup scheduler + if sample_solver == 'dpm++': + sample_scheduler = FlowDPMSolverMultistepScheduler( + num_train_timesteps=self.num_train_timesteps, + shift=1, + use_dynamic_shifting=False) + sampling_sigmas = get_sampling_sigmas(sampling_steps, shift) + timesteps, _ = retrieve_timesteps( + sample_scheduler, + device=self.device, + sigmas=sampling_sigmas) + else: + raise NotImplementedError(f"Unsupported solver: {sample_solver}") + + latents = noise + arg_c = {'context': context, 'seq_len': seq_len, 'frame_offset': frame_offset} + arg_null = {'context': context_null, 'seq_len': seq_len, 'frame_offset': frame_offset} + + # Denoising loop + for step_idx, t in enumerate(tqdm(timesteps, desc=f"Section {section_id+1}")): + latent_model_input = latents + timestep = torch.stack([t]) + + + + noise_pred_cond = self.model( + latent_model_input, t=timestep, vace_context=z, + vace_context_scale=context_scale_section, **arg_c)[0] + noise_pred_uncond = self.model( + latent_model_input, t=timestep, vace_context=z, + vace_context_scale=context_scale_section, **arg_null)[0] + + noise_pred = noise_pred_uncond + guide_scale * ( + noise_pred_cond - 
noise_pred_uncond) + + + temp_x0 = sample_scheduler.step( + noise_pred.unsqueeze(0), + t, + latents[0].unsqueeze(0), + return_dict=False, + generator=section_generator)[0] + latents = [temp_x0.squeeze(0)] + + + # Debug first and last steps + if step_idx == 0 or step_idx == len(timesteps) - 1: + print(f" Step {step_idx}: t={t.item():.3f}, " + f"latent stats: mean={latents[0].mean().item():.3f}, " + f"std={latents[0].std().item():.3f}") + del noise_pred, noise_pred_uncond,noise_pred_cond + + del context, context_null,z, noise + if section_id == 0: + print(f"Section 0: Removing {1} reference frames from latent") + + if section_num==1: + latent_without_ref = latents[0] + accumulated_latents.append(latent_without_ref) + + + all_generated_latents.append(latent_without_ref) + + else: + latent_without_ref = latents[0][:, 1:-10, :, :] + accumulated_latents.append(latent_without_ref) + + + all_generated_latents.append(latent_without_ref) + + else: + if section_id > 2: + accumulated_latents.pop(0) + new=latents[0][:, -GENERATION_FRAMES:, :, :] + accumulated_latents.append(new) + + if section_id == 0: + # First section without reference images + all_generated_latents.append(latents[0]) + else: + # Take only newly generated frames + new_content = latents[0][:, -GENERATION_FRAMES:, :, :] + new_content = new_content[:, 11:, :, :] + all_generated_latents.append(new_content) + del new_content + torch.cuda.synchronize() + torch.cuda.empty_cache() + gc.collect() + + # if self.rank == 0: + # section_decoded = self.decode_latent(latents, None) + # self.save_section_debug(section_decoded[0], section_id, accumulated_latents) + + # print(f"Section {section_id} completed. Generated latent shape: {latents[0].shape}") + + # Final video assembly + if self.rank == 0 and all_generated_latents: + + final_latent = torch.cat(all_generated_latents, dim=1) + print(f"\nFinal latent shape: {final_latent.shape}") + + + final_video = self.decode_latent([final_latent], None) + + if offload_model: + self.model.cpu() + torch.cuda.empty_cache() + + + if offload_model: + gc.collect() + torch.cuda.synchronize() + + if dist.is_initialized(): + dist.barrier() + + return final_video[0] + + return None + + def build_hierarchical_context_latent(self, accumulated_latents, section_id): + """ + Build hierarchical context from accumulated latents. + + """ + if not accumulated_latents: + raise ValueError("No accumulated latents available") + + all_prev = torch.cat(accumulated_latents, dim=1) + total_frames = all_prev.shape[1] + + print(f"Building context from {total_frames} accumulated frames") + + return all_prev + + + def pick_context_v2(self, frames, section_id, initial=False): + """ + Enhanced hierarchical context selection with constant 22-frame output. + + Changes from original: + 1. Better handling of initial frames + 2. More robust frame selection with proper bounds checking + 3. 
Improved debugging output + """ + + # Constants + LONG_FRAMES = 5 + MID_FRAMES = 3 + RECENT_FRAMES = 1 + OVERLAP_FRAMES = 2 + GEN_FRAMES = 30 + TOTAL_FRAMES = 41 + + C, T, H, W = frames.shape + + if initial and T == TOTAL_FRAMES: + return frames + + if initial and T < TOTAL_FRAMES: + padding_needed = TOTAL_FRAMES - T + padding = torch.zeros((C, padding_needed, H, W), device=frames.device) + return torch.cat([frames, padding], dim=1) + + selected_indices = [] + + if T >= 40: + step = max(4, T // 20) + long_indices = [] + for i in range(LONG_FRAMES): + idx = min(i * step, T - 15) + long_indices.append(idx) + selected_indices.extend(long_indices) + else: + # Not enough frames - take evenly spaced + if T >= LONG_FRAMES: + step = T // LONG_FRAMES + long_indices = [i * step for i in range(LONG_FRAMES)] + else: + long_indices = list(range(T)) + # Pad by repeating last frame + while len(long_indices) < LONG_FRAMES: + long_indices.append(T - 1) + selected_indices.extend(long_indices[:LONG_FRAMES]) + + mid_start = max(LONG_FRAMES, T - 15) + mid_indices = [ + min(mid_start, T - 1), + min(mid_start + 2, T - 1) + ] + selected_indices.extend(mid_indices) + + recent_idx = max(0, T - 5) + selected_indices.append(recent_idx) + + overlap_start = max(0, T - OVERLAP_FRAMES) + overlap_indices = list(range(overlap_start, T)) + + while len(overlap_indices) < OVERLAP_FRAMES: + overlap_indices.append(T - 1) + selected_indices.extend(overlap_indices[:OVERLAP_FRAMES]) + context_frames = frames[:, selected_indices, :, :] + + gen_placeholder = torch.zeros((C, GEN_FRAMES, H, W), device=frames.device) + + final_frames = torch.cat([ + context_frames[:, :LONG_FRAMES], + context_frames[:, LONG_FRAMES:LONG_FRAMES+MID_FRAMES], + context_frames[:, LONG_FRAMES+MID_FRAMES:LONG_FRAMES+MID_FRAMES+RECENT_FRAMES], + context_frames[:, -OVERLAP_FRAMES:], + gen_placeholder + ], dim=1) + + assert final_frames.shape[1] == TOTAL_FRAMES, \ + f"Expected {TOTAL_FRAMES} frames, got {final_frames.shape[1]}" + + if section_id % 5 == 0: + print(f"\nContext selection debug (section {section_id}):") + print(f" Input frames: {T}") + print(f" Selected indices: {selected_indices}") + print(f" Output shape: {final_frames.shape}") + + return final_frames + + + def create_temporal_blend_mask_v2(self, frame_shape, section_id, initial=False): + """ + Enhanced mask creation that handles decoded frame dimensions + """ + C, T, H, W = frame_shape + LONG_FRAMES = 5 + MID_FRAMES = 3 + RECENT_FRAMES = 1 + OVERLAP_FRAMES = 2 + GEN_FRAMES = 30 + TOTAL_FRAMES = 41 + # Calculate the temporal expansion ratio + LATENT_FRAMES = 41 + decoded_frames = T + expansion_ratio = decoded_frames / LATENT_FRAMES + + mask = torch.zeros(1, decoded_frames, H, W, device=self.device) + + # Scale all frame counts by the expansion ratio + LONG_FRAMES = int(5 * expansion_ratio) + MID_FRAMES = int(3 * expansion_ratio) + RECENT_FRAMES = int(1 * expansion_ratio) + OVERLAP_FRAMES = int(2 * expansion_ratio) + GEN_FRAMES = decoded_frames - (LONG_FRAMES + MID_FRAMES + RECENT_FRAMES + OVERLAP_FRAMES) + + if initial: + mask[:, :-GEN_FRAMES] = 0.0 + mask[:, -GEN_FRAMES:] = 1.0 + return [mask] + + # Apply mask values with expanded frame counts + idx = 0 + mask[:, idx:idx+LONG_FRAMES] = 0.05 + idx += LONG_FRAMES + + mask[:, idx:idx+MID_FRAMES] = 0.2 + idx += MID_FRAMES + + mask[:, idx:idx+RECENT_FRAMES] = 0.3 + idx += RECENT_FRAMES + + for i in range(OVERLAP_FRAMES): + blend_value = 0.4 + (i / (OVERLAP_FRAMES - 1)) * 0.4 + mask[:, idx+i] = blend_value + idx += OVERLAP_FRAMES + + mask[:, idx:] = 1.0 
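+        # Resulting mask layout along the time axis: long-range context ~0.05,
+        # mid-range ~0.2, most recent context ~0.3, an overlap ramp from 0.4 to 0.8,
+        # and 1.0 for the frames that are to be fully (re)generated.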
+ + return [mask] + def create_spatial_variation(self, H, W): + """Create spatial variation mask for natural blending.""" + y_coords = torch.linspace(-1, 1, H, device=self.device) + x_coords = torch.linspace(-1, 1, W, device=self.device) + y_grid, x_grid = torch.meshgrid(y_coords, x_coords, indexing='ij') + + + distance = torch.sqrt(x_grid**2 + y_grid**2) / 1.414 + variation = 1.0 - 0.3 * torch.exp(-3 * distance**2) + + return variation + + def separate_appearance_and_motion(self, frames): + """Use frequency domain to separate appearance from motion""" + + C, T, H, W = frames.shape + + + fft_frames = torch.fft.rfft2(frames, dim=(-2, -1)) + + + fft_h = H + fft_w = W // 2 + 1 + + h_freqs = torch.fft.fftfreq(H, device=frames.device) + + w_freqs = torch.fft.rfftfreq(W, device=frames.device) + + + h_grid, w_grid = torch.meshgrid(h_freqs, w_freqs, indexing='ij') + + + freq_magnitude = torch.sqrt(h_grid**2 + w_grid**2) + + + cutoff = 0.1 + low_pass_mask = (freq_magnitude < cutoff).float().to(frames.device) + + + if low_pass_mask.shape != fft_frames.shape[-2:]: + print(f"Mask shape: {low_pass_mask.shape}, FFT shape: {fft_frames.shape}") + + low_pass_mask = low_pass_mask[:fft_h, :fft_w] + + + while low_pass_mask.dim() < fft_frames.dim(): + low_pass_mask = low_pass_mask.unsqueeze(0) + + + appearance_fft = fft_frames * low_pass_mask + motion_fft = fft_frames * (1 - low_pass_mask) + + + appearance = torch.fft.irfft2(appearance_fft, s=(H, W)) + motion = torch.fft.irfft2(motion_fft, s=(H, W)) + + return appearance, motion + + + def save_section_debug(self, video_tensor, section_id, accumulated_latents): + """Enhanced debugging output with more information.""" + import imageio + import numpy as np + + # Save video + output_path = f"debug_section_{section_id:03d}.mp4" + + video_np = video_tensor.cpu().detach() + video_np = (video_np + 1.0) / 2.0 + video_np = torch.clamp(video_np, 0.0, 1.0) + video_np = video_np.permute(1, 2, 3, 0).numpy() + video_np_uint8 = (video_np * 255).astype(np.uint8) + + imageio.mimsave(output_path, video_np_uint8, fps=12) + + # Save debug info + debug_info = { + 'section_id': section_id, + 'video_shape': list(video_tensor.shape), + 'accumulated_latents': len(accumulated_latents), + 'total_latent_frames': sum(l.shape[1] for l in accumulated_latents) + } + + import json + with open(f"debug_section_{section_id:03d}.json", 'w') as f: + json.dump(debug_info, f, indent=2) + + print(f"Saved debug output to {output_path}") diff --git a/wanvideo/modules/__init__.py b/wanvideo/modules/__init__.py index 2d9adae5..e9ef3efa 100644 --- a/wanvideo/modules/__init__.py +++ b/wanvideo/modules/__init__.py @@ -2,6 +2,7 @@ from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model from .tokenizers import HuggingfaceTokenizer +from .model import WanModel, WanAttentionBlock, VaceWanModel, VaceWanAttentionBlock, BaseWanAttentionBlock __all__ = [ 'WanModel', 'T5Model', diff --git a/wanvideo/modules/model.py b/wanvideo/modules/model.py index f4d77df5..d4da9f60 100644 --- a/wanvideo/modules/model.py +++ b/wanvideo/modules/model.py @@ -1607,7 +1607,9 @@ def forward( # MultiTalk if multitalk_audio is not None: - self.audio_proj.to(self.main_device) + if hasattr(self, 'audio_proj') and self.audio_proj is not None: + self.audio_proj.to(self.main_device) + # self.audio_proj.to(self.main_device) audio_cond = multitalk_audio.to(device=x.device, dtype=x.dtype) first_frame_audio_emb_s = audio_cond[:, :1, ...] latter_frame_audio_emb = audio_cond[:, 1:, ...] 
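The frequency-domain split used by separate_appearance_and_motion above can be exercised in isolation. The following is a minimal, self-contained sketch (not part of the patch) that mirrors its rFFT layout and 0.1 cutoff on a toy tensor, mainly to show that the low-pass "appearance" and residual "motion" components add back to the input:

import torch

# Toy latent block: C channels, T frames, H x W spatial grid.
C, T, H, W = 4, 8, 32, 32
frames = torch.randn(C, T, H, W)

# 2D real FFT over the spatial dims, as in separate_appearance_and_motion.
fft_frames = torch.fft.rfft2(frames, dim=(-2, -1))

# Radial frequency magnitude on the (H, W//2 + 1) rFFT grid.
h_freqs = torch.fft.fftfreq(H)
w_freqs = torch.fft.rfftfreq(W)
h_grid, w_grid = torch.meshgrid(h_freqs, w_freqs, indexing='ij')
freq_magnitude = torch.sqrt(h_grid ** 2 + w_grid ** 2)

# Low spatial frequencies -> "appearance", the remainder -> "motion".
low_pass = (freq_magnitude < 0.1).float()
appearance = torch.fft.irfft2(fft_frames * low_pass, s=(H, W))
motion = torch.fft.irfft2(fft_frames * (1 - low_pass), s=(H, W))

# The split is exact up to FFT round-off: the two parts reconstruct the input.
assert torch.allclose(appearance + motion, frames, atol=1e-4)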
@@ -1926,3 +1928,254 @@ def unpatchify(self, x, grid_sizes): u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)]) out.append(u) return out + + +# -*- coding: utf-8 -*- +# Copyright (c) Alibaba, Inc. and its affiliates. +import torch +import torch.cuda.amp as amp +import torch.nn as nn +from diffusers.configuration_utils import register_to_config +from wan.modules.model import WanModel, WanAttentionBlock, sinusoidal_embedding_1d + +class VaceWanAttentionBlock(WanAttentionBlock): + def __init__( + self, + cross_attn_type, + dim, + ffn_dim, + num_heads, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=False, + eps=1e-6, + block_id=0 + ): + super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps) + self.block_id = block_id + if block_id == 0: + self.before_proj = nn.Linear(self.dim, self.dim) + nn.init.zeros_(self.before_proj.weight) + nn.init.zeros_(self.before_proj.bias) + self.after_proj = nn.Linear(self.dim, self.dim) + nn.init.zeros_(self.after_proj.weight) + nn.init.zeros_(self.after_proj.bias) + + def forward(self, c, x, frame_offset=0 , **kwargs): + kwargs['frame_offset'] = frame_offset + if self.block_id == 0: + c = self.before_proj(c) + x + all_c = [] + else: + all_c = list(torch.unbind(c)) + c = all_c.pop(-1) + c = super().forward(c, **kwargs) + c_skip = self.after_proj(c) + all_c += [c_skip, c] + c = torch.stack(all_c) + return c + + +class BaseWanAttentionBlock(WanAttentionBlock): + def __init__( + self, + cross_attn_type, + dim, + ffn_dim, + num_heads, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=False, + eps=1e-6, + block_id=None + ): + super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps) + self.block_id = block_id + + def forward(self, x, hints, context_scale=1.0,frame_offset=0, **kwargs): + kwargs['frame_offset'] = frame_offset + x = super().forward(x, **kwargs) + if self.block_id is not None: + x = x + hints[self.block_id] * context_scale + return x + + +class VaceWanModel(WanModel): + @register_to_config + def __init__(self, + vace_layers=None, + vace_in_dim=None, + model_type='t2v', + patch_size=(1, 2, 2), + text_len=512, + in_dim=16, + dim=2048, + ffn_dim=8192, + freq_dim=256, + text_dim=4096, + out_dim=16, + num_heads=16, + num_layers=32, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=True, + eps=1e-6): + model_type = "t2v" # TODO: Hard code for both preview and official versions. 
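+        # The base WanModel is built first; self.blocks is then replaced with
+        # BaseWanAttentionBlock instances so that layers listed in vace_layers
+        # (every second layer by default) add their VACE hint, and one
+        # VaceWanAttentionBlock is created per entry in vace_layers to produce
+        # those hints.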
+ super().__init__(model_type, patch_size, text_len, in_dim, dim, ffn_dim, freq_dim, text_dim, out_dim, + num_heads, num_layers, window_size, qk_norm, cross_attn_norm, eps) + + self.vace_layers = [i for i in range(0, self.num_layers, 2)] if vace_layers is None else vace_layers + self.vace_in_dim = self.in_dim if vace_in_dim is None else vace_in_dim + + assert 0 in self.vace_layers + self.vace_layers_mapping = {i: n for n, i in enumerate(self.vace_layers)} + + # blocks + self.blocks = nn.ModuleList([ + BaseWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm, + self.cross_attn_norm, self.eps, + block_id=self.vace_layers_mapping[i] if i in self.vace_layers else None) + for i in range(self.num_layers) + ]) + + # vace blocks + self.vace_blocks = nn.ModuleList([ + VaceWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm, + self.cross_attn_norm, self.eps, block_id=i) + for i in self.vace_layers + ]) + + # vace patch embeddings + self.vace_patch_embedding = nn.Conv3d( + self.vace_in_dim, self.dim, kernel_size=self.patch_size, stride=self.patch_size + ) + + def forward_vace( + self, + x, + vace_context, + seq_len, + frame_offset, + kwargs + ): + # embeddings + c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context] + c = [u.flatten(2).transpose(1, 2) for u in c] + c = torch.cat([ + torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], + dim=1) for u in c + ]) + + # arguments + new_kwargs = dict(x=x) + new_kwargs.update(kwargs) + + for block in self.vace_blocks: + c = block(c, **new_kwargs) + hints = torch.unbind(c)[:-1] + return hints + + def forward( + self, + x, + t, + vace_context, + context, + seq_len, + vace_context_scale=1.0, + clip_fea=None, + frame_offset=0, + + y=None, + ): + r""" + Forward pass through the diffusion model + + Args: + x (List[Tensor]): + List of input video tensors, each with shape [C_in, F, H, W] + t (Tensor): + Diffusion timesteps tensor of shape [B] + context (List[Tensor]): + List of text embeddings each with shape [L, C] + seq_len (`int`): + Maximum sequence length for positional encoding + clip_fea (Tensor, *optional*): + CLIP image features for image-to-video mode + y (List[Tensor], *optional*): + Conditional video inputs for image-to-video mode, same shape as x + + Returns: + List[Tensor]: + List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8] + """ + # if self.model_type == 'i2v': + # assert clip_fea is not None and y is not None + # params + self.section_embedding = nn.Embedding(20, self.dim) + + device = self.patch_embedding.weight.device + if self.freqs.device != device: + self.freqs = self.freqs.to(device) + + # if y is not None: + # x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] + + # embeddings + x = [self.patch_embedding(u.unsqueeze(0)) for u in x] + grid_sizes = torch.stack( + [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) + x = [u.flatten(2).transpose(1, 2) for u in x] + seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) + assert seq_lens.max() <= seq_len + x = torch.cat([ + torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], + dim=1) for u in x + ]) + + # time embeddings + with amp.autocast(dtype=torch.float32): + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, t).float()) + e0 = self.time_projection(e).unflatten(1, (6, self.dim)) + assert e.dtype == torch.float32 and e0.dtype == torch.float32 + + # context + context_lens = None + 
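+        # Text embeddings are zero-padded to self.text_len before projection,
+        # mirroring the padding applied to the flattened video tokens above.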
context = self.text_embedding( + torch.stack([ + torch.cat( + [u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) + for u in context + ])) + + + # vace_context = [z_i + section_embed for z_i in vace_context] + # if clip_fea is not None: + # context_clip = self.img_emb(clip_fea) # bs x 257 x dim + # context = torch.concat([context_clip, context], dim=1) + + # arguments + kwargs = dict( + e=e0, + seq_lens=seq_lens, + grid_sizes=grid_sizes, + freqs=self.freqs, + context=context, + context_lens=context_lens, + frame_offset=frame_offset + ) + print('context',vace_context_scale) + hints = self.forward_vace(x, vace_context, seq_len,frame_offset, kwargs) + kwargs['hints'] = hints + kwargs['context_scale'] = vace_context_scale + + for block in self.blocks: + x = block(x, **kwargs) + + # head + x = self.head(x, e) + + # unpatchify + x = self.unpatchify(x, grid_sizes) + return [u.float() for u in x] \ No newline at end of file diff --git a/wanvideo/preprocessor.py b/wanvideo/preprocessor.py new file mode 100644 index 00000000..a0788111 --- /dev/null +++ b/wanvideo/preprocessor.py @@ -0,0 +1,271 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Alibaba, Inc. and its affiliates. +import numpy as np +from PIL import Image +import torch +import torch.nn.functional as F +import torchvision.transforms.functional as TF + + +class VaceImageProcessor(object): + def __init__(self, downsample=None, seq_len=None): + self.downsample = downsample + self.seq_len = seq_len + + def _pillow_convert(self, image, cvt_type='RGB'): + if image.mode != cvt_type: + if image.mode == 'P': + image = image.convert(f'{cvt_type}A') + if image.mode == f'{cvt_type}A': + bg = Image.new(cvt_type, + size=(image.width, image.height), + color=(255, 255, 255)) + bg.paste(image, (0, 0), mask=image) + image = bg + else: + image = image.convert(cvt_type) + return image + + def _load_image(self, img_path): + if img_path is None or img_path == '': + return None + img = Image.open(img_path) + img = self._pillow_convert(img) + return img + + def _resize_crop(self, img, oh, ow, normalize=True): + """ + Resize, center crop, convert to tensor, and normalize. 
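+        When `normalize` is True the result is a tensor of shape (C, 1, oh, ow) in [-1, 1]; otherwise the cropped PIL image is returned.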
+ """ + # resize and crop + iw, ih = img.size + if iw != ow or ih != oh: + # resize + scale = max(ow / iw, oh / ih) + img = img.resize( + (round(scale * iw), round(scale * ih)), + resample=Image.Resampling.LANCZOS + ) + assert img.width >= ow and img.height >= oh + + # center crop + x1 = (img.width - ow) // 2 + y1 = (img.height - oh) // 2 + img = img.crop((x1, y1, x1 + ow, y1 + oh)) + + # normalize + if normalize: + img = TF.to_tensor(img).sub_(0.5).div_(0.5).unsqueeze(1) + return img + + def _image_preprocess(self, img, oh, ow, normalize=True, **kwargs): + return self._resize_crop(img, oh, ow, normalize) + + def load_image(self, data_key, **kwargs): + return self.load_image_batch(data_key, **kwargs) + + def load_image_pair(self, data_key, data_key2, **kwargs): + return self.load_image_batch(data_key, data_key2, **kwargs) + + def load_image_batch(self, *data_key_batch, normalize=True, seq_len=None, **kwargs): + seq_len = self.seq_len if seq_len is None else seq_len + imgs = [] + for data_key in data_key_batch: + img = self._load_image(data_key) + imgs.append(img) + w, h = imgs[0].size + dh, dw = self.downsample[1:] + + # compute output size + scale = min(1., np.sqrt(seq_len / ((h / dh) * (w / dw)))) + oh = int(h * scale) // dh * dh + ow = int(w * scale) // dw * dw + assert (oh // dh) * (ow // dw) <= seq_len + imgs = [self._image_preprocess(img, oh, ow, normalize) for img in imgs] + return *imgs, (oh, ow) + + +class VaceVideoProcessor(object): + def __init__(self, downsample, min_area, max_area, min_fps, max_fps, zero_start, seq_len, keep_last, **kwargs): + self.downsample = downsample + self.min_area = min_area + self.max_area = max_area + self.min_fps = min_fps + self.max_fps = max_fps + self.zero_start = zero_start + self.keep_last = keep_last + self.seq_len = seq_len + assert seq_len >= min_area / (self.downsample[1] * self.downsample[2]) + + def set_area(self, area): + self.min_area = area + self.max_area = area + + def set_seq_len(self, seq_len): + self.seq_len = seq_len + + @staticmethod + def resize_crop(video: torch.Tensor, oh: int, ow: int): + """ + Resize, center crop and normalize for decord loaded video (torch.Tensor type) + + Parameters: + video - video to process (torch.Tensor): Tensor from `reader.get_batch(frame_ids)`, in shape of (T, H, W, C) + oh - target height (int) + ow - target width (int) + + Returns: + The processed video (torch.Tensor): Normalized tensor range [-1, 1], in shape of (C, T, H, W) + + Raises: + """ + # permute ([t, h, w, c] -> [t, c, h, w]) + video = video.permute(0, 3, 1, 2) + + # resize and crop + ih, iw = video.shape[2:] + if ih != oh or iw != ow: + # resize + scale = max(ow / iw, oh / ih) + video = F.interpolate( + video, + size=(round(scale * ih), round(scale * iw)), + mode='bicubic', + antialias=True + ) + assert video.size(3) >= ow and video.size(2) >= oh + + # center crop + x1 = (video.size(3) - ow) // 2 + y1 = (video.size(2) - oh) // 2 + video = video[:, :, y1:y1 + oh, x1:x1 + ow] + + # permute ([t, c, h, w] -> [c, t, h, w]) and normalize + video = video.transpose(0, 1).float().div_(127.5).sub_(1.) 
+ return video + + def _video_preprocess(self, video, oh, ow): + return self.resize_crop(video, oh, ow) + + def _get_frameid_bbox_default(self, fps, frame_timestamps, h, w, crop_box, rng): + target_fps = min(fps, self.max_fps) + duration = frame_timestamps[-1].mean() + x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box + h, w = y2 - y1, x2 - x1 + ratio = h / w + df, dh, dw = self.downsample + + area_z = min(self.seq_len, self.max_area / (dh * dw), (h // dh) * (w // dw)) + of = min( + (int(duration * target_fps) - 1) // df + 1, + int(self.seq_len / area_z) + ) + + # deduce target shape of the [latent video] + target_area_z = min(area_z, int(self.seq_len / of)) + oh = round(np.sqrt(target_area_z * ratio)) + ow = int(target_area_z / oh) + of = (of - 1) * df + 1 + oh *= dh + ow *= dw + + # sample frame ids + target_duration = of / target_fps + begin = 0. if self.zero_start else rng.uniform(0, duration - target_duration) + timestamps = np.linspace(begin, begin + target_duration, of) + frame_ids = np.argmax(np.logical_and( + timestamps[:, None] >= frame_timestamps[None, :, 0], + timestamps[:, None] < frame_timestamps[None, :, 1] + ), axis=1).tolist() + return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps + + def _get_frameid_bbox_adjust_last(self, fps, frame_timestamps, h, w, crop_box, rng): + duration = frame_timestamps[-1].mean() + x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box + h, w = y2 - y1, x2 - x1 + ratio = h / w + df, dh, dw = self.downsample + + area_z = min(self.seq_len, self.max_area / (dh * dw), (h // dh) * (w // dw)) + of = min( + (len(frame_timestamps) - 1) // df + 1, + int(self.seq_len / area_z) + ) + + # deduce target shape of the [latent video] + target_area_z = min(area_z, int(self.seq_len / of)) + oh = round(np.sqrt(target_area_z * ratio)) + ow = int(target_area_z / oh) + of = (of - 1) * df + 1 + oh *= dh + ow *= dw + + # sample frame ids + target_duration = duration + target_fps = of / target_duration + timestamps = np.linspace(0., target_duration, of) + frame_ids = np.argmax(np.logical_and( + timestamps[:, None] >= frame_timestamps[None, :, 0], + timestamps[:, None] <= frame_timestamps[None, :, 1] + ), axis=1).tolist() + # print(oh, ow, of, target_duration, target_fps, len(frame_timestamps), len(frame_ids)) + return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps + + + def _get_frameid_bbox(self, fps, frame_timestamps, h, w, crop_box, rng): + if self.keep_last: + return self._get_frameid_bbox_adjust_last(fps, frame_timestamps, h, w, crop_box, rng) + else: + return self._get_frameid_bbox_default(fps, frame_timestamps, h, w, crop_box, rng) + + def load_video(self, data_key, crop_box=None, seed=2024, **kwargs): + return self.load_video_batch(data_key, crop_box=crop_box, seed=seed, **kwargs) + + def load_video_pair(self, data_key, data_key2, crop_box=None, seed=2024, **kwargs): + return self.load_video_batch(data_key, data_key2, crop_box=crop_box, seed=seed, **kwargs) + + def load_video_batch(self, *data_key_batch, crop_box=None, seed=2024, **kwargs): + rng = np.random.default_rng(seed + hash(data_key_batch[0]) % 10000) + # read video + import decord + decord.bridge.set_bridge('torch') + readers = [] + for data_k in data_key_batch: + reader = decord.VideoReader(data_k) + readers.append(reader) + + fps = readers[0].get_avg_fps() + length = min([len(r) for r in readers]) + frame_timestamps = [readers[0].get_frame_timestamp(i) for i in range(length)] + frame_timestamps = np.array(frame_timestamps, dtype=np.float32) + h, w = 
readers[0].next().shape[:2] + frame_ids, (x1, x2, y1, y2), (oh, ow), fps = self._get_frameid_bbox(fps, frame_timestamps, h, w, crop_box, rng) + + # preprocess video + videos = [reader.get_batch(frame_ids)[:, y1:y2, x1:x2, :] for reader in readers] + videos = [self._video_preprocess(video, oh, ow) for video in videos] + return *videos, frame_ids, (oh, ow), fps + # return videos if len(videos) > 1 else videos[0] + + +def prepare_source(src_video, src_mask, src_ref_images, num_frames, image_size, device): + for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)): + if sub_src_video is None and sub_src_mask is None: + src_video[i] = torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device) + src_mask[i] = torch.ones((1, num_frames, image_size[0], image_size[1]), device=device) + for i, ref_images in enumerate(src_ref_images): + if ref_images is not None: + for j, ref_img in enumerate(ref_images): + if ref_img is not None and ref_img.shape[-2:] != image_size: + canvas_height, canvas_width = image_size + ref_height, ref_width = ref_img.shape[-2:] + white_canvas = torch.ones((3, 1, canvas_height, canvas_width), device=device) # [-1, 1] + scale = min(canvas_height / ref_height, canvas_width / ref_width) + new_height = int(ref_height * scale) + new_width = int(ref_width * scale) + resized_image = F.interpolate(ref_img.squeeze(1).unsqueeze(0), size=(new_height, new_width), mode='bilinear', align_corners=False).squeeze(0).unsqueeze(1) + top = (canvas_height - new_height) // 2 + left = (canvas_width - new_width) // 2 + white_canvas[:, :, top:top + new_height, left:left + new_width] = resized_image + src_ref_images[i][j] = white_canvas + return src_video, src_mask, src_ref_images diff --git a/wanvideo/wan_vace.py b/wanvideo/wan_vace.py new file mode 100644 index 00000000..d388c507 --- /dev/null +++ b/wanvideo/wan_vace.py @@ -0,0 +1,719 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import sys +import gc +import math +import time +import random +import types +import logging +import traceback +from contextlib import contextmanager +from functools import partial + +from PIL import Image +import torchvision.transforms.functional as TF +import torch +import torch.nn.functional as F +import torch.cuda.amp as amp +import torch.distributed as dist +import torch.multiprocessing as mp +from tqdm import tqdm + +from wan.text2video import (WanT2V, T5EncoderModel, WanVAE, shard_model, FlowDPMSolverMultistepScheduler, + get_sampling_sigmas, retrieve_timesteps, FlowUniPCMultistepScheduler) +from .modules.model import VaceWanModel +from ..utils.preprocessor import VaceVideoProcessor + + +class WanVace(WanT2V): + def __init__( + self, + config, + checkpoint_dir, + device_id=0, + rank=0, + t5_fsdp=False, + dit_fsdp=False, + use_usp=False, + t5_cpu=False, + ): + r""" + Initializes the Wan text-to-video generation model components. + + Args: + config (EasyDict): + Object containing model parameters initialized from config.py + checkpoint_dir (`str`): + Path to directory containing model checkpoints + device_id (`int`, *optional*, defaults to 0): + Id of target GPU device + rank (`int`, *optional*, defaults to 0): + Process rank for distributed training + t5_fsdp (`bool`, *optional*, defaults to False): + Enable FSDP sharding for T5 model + dit_fsdp (`bool`, *optional*, defaults to False): + Enable FSDP sharding for DiT model + use_usp (`bool`, *optional*, defaults to False): + Enable distribution strategy of USP. 
+ t5_cpu (`bool`, *optional*, defaults to False): + Whether to place T5 model on CPU. Only works without t5_fsdp. + """ + self.device = torch.device(f"cuda:{device_id}") + self.config = config + self.rank = rank + self.t5_cpu = t5_cpu + + self.num_train_timesteps = config.num_train_timesteps + self.param_dtype = config.param_dtype + + shard_fn = partial(shard_model, device_id=device_id) + self.text_encoder = T5EncoderModel( + text_len=config.text_len, + dtype=config.t5_dtype, + device=torch.device('cpu'), + checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint), + tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer), + shard_fn=shard_fn if t5_fsdp else None) + + self.vae_stride = config.vae_stride + self.patch_size = config.patch_size + self.vae = WanVAE( + vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), + device=self.device) + + logging.info(f"Creating VaceWanModel from {checkpoint_dir}") + self.model = VaceWanModel.from_pretrained(checkpoint_dir) + self.model.eval().requires_grad_(False) + + if use_usp: + from xfuser.core.distributed import \ + get_sequence_parallel_world_size + + from .distributed.xdit_context_parallel import (usp_attn_forward, + usp_dit_forward, + usp_dit_forward_vace) + for block in self.model.blocks: + block.self_attn.forward = types.MethodType( + usp_attn_forward, block.self_attn) + for block in self.model.vace_blocks: + block.self_attn.forward = types.MethodType( + usp_attn_forward, block.self_attn) + self.model.forward = types.MethodType(usp_dit_forward, self.model) + self.model.forward_vace = types.MethodType(usp_dit_forward_vace, self.model) + self.sp_size = get_sequence_parallel_world_size() + else: + self.sp_size = 1 + + if dist.is_initialized(): + dist.barrier() + if dit_fsdp: + self.model = shard_fn(self.model) + else: + self.model.to(self.device) + + self.sample_neg_prompt = config.sample_neg_prompt + + self.vid_proc = VaceVideoProcessor(downsample=tuple([x * y for x, y in zip(config.vae_stride, self.patch_size)]), + min_area=480 * 832, + max_area=480 * 832, + min_fps=self.config.sample_fps, + max_fps=self.config.sample_fps, + zero_start=True, + seq_len=32760, + keep_last=True) + + def vace_encode_frames(self, frames, ref_images, masks=None, vae=None): + vae = self.vae if vae is None else vae + if ref_images is None: + ref_images = [None] * len(frames) + else: + assert len(frames) == len(ref_images) + + if masks is None: + latents = vae.encode(frames) + else: + masks = [torch.where(m > 0.5, 1.0, 0.0) for m in masks] + inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)] + reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)] + inactive = vae.encode(inactive) + reactive = vae.encode(reactive) + latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)] + + cat_latents = [] + for latent, refs in zip(latents, ref_images): + if refs is not None: + if masks is None: + ref_latent = vae.encode(refs) + else: + ref_latent = vae.encode(refs) + ref_latent = [torch.cat((u, torch.zeros_like(u)), dim=0) for u in ref_latent] + assert all([x.shape[1] == 1 for x in ref_latent]) + latent = torch.cat([*ref_latent, latent], dim=1) + cat_latents.append(latent) + return cat_latents + + def vace_encode_masks(self, masks, ref_images=None, vae_stride=None): + vae_stride = self.vae_stride if vae_stride is None else vae_stride + if ref_images is None: + ref_images = [None] * len(masks) + else: + assert len(masks) == len(ref_images) + + result_masks = [] + for mask, refs in zip(masks, ref_images): + c, depth, 
height, width = mask.shape + new_depth = int((depth + 3) // vae_stride[0]) + height = 2 * (int(height) // (vae_stride[1] * 2)) + width = 2 * (int(width) // (vae_stride[2] * 2)) + + # reshape + mask = mask[0, :, :, :] + mask = mask.view( + depth, height, vae_stride[1], width, vae_stride[1] + ) # depth, height, 8, width, 8 + mask = mask.permute(2, 4, 0, 1, 3) # 8, 8, depth, height, width + mask = mask.reshape( + vae_stride[1] * vae_stride[2], depth, height, width + ) # 8*8, depth, height, width + + # interpolation + mask = F.interpolate(mask.unsqueeze(0), size=(new_depth, height, width), mode='nearest-exact').squeeze(0) + + if refs is not None: + length = len(refs) + mask_pad = torch.zeros_like(mask[:, :length, :, :]) + mask = torch.cat((mask_pad, mask), dim=1) + result_masks.append(mask) + return result_masks + + def vace_latent(self, z, m): + return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)] + + def prepare_source(self, src_video, src_mask, src_ref_images, num_frames, image_size, device): + area = image_size[0] * image_size[1] + self.vid_proc.set_area(area) + if area == 720*1280: + self.vid_proc.set_seq_len(75600) + elif area == 480*832: + self.vid_proc.set_seq_len(32760) + else: + raise NotImplementedError(f'image_size {image_size} is not supported') + + image_size = (image_size[1], image_size[0]) + image_sizes = [] + for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)): + if sub_src_mask is not None and sub_src_video is not None: + src_video[i], src_mask[i], _, _, _ = self.vid_proc.load_video_pair(sub_src_video, sub_src_mask) + src_video[i] = src_video[i].to(device) + src_mask[i] = src_mask[i].to(device) + src_mask[i] = torch.clamp((src_mask[i][:1, :, :, :] + 1) / 2, min=0, max=1) + image_sizes.append(src_video[i].shape[2:]) + elif sub_src_video is None: + src_video[i] = torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device) + src_mask[i] = torch.ones_like(src_video[i], device=device) + image_sizes.append(image_size) + else: + src_video[i], _, _, _ = self.vid_proc.load_video(sub_src_video) + src_video[i] = src_video[i].to(device) + src_mask[i] = torch.ones_like(src_video[i], device=device) + image_sizes.append(src_video[i].shape[2:]) + + for i, ref_images in enumerate(src_ref_images): + if ref_images is not None: + image_size = image_sizes[i] + for j, ref_img in enumerate(ref_images): + if ref_img is not None: + ref_img = Image.open(ref_img).convert("RGB") + ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1) + if ref_img.shape[-2:] != image_size: + canvas_height, canvas_width = image_size + ref_height, ref_width = ref_img.shape[-2:] + white_canvas = torch.ones((3, 1, canvas_height, canvas_width), device=device) # [-1, 1] + scale = min(canvas_height / ref_height, canvas_width / ref_width) + new_height = int(ref_height * scale) + new_width = int(ref_width * scale) + resized_image = F.interpolate(ref_img.squeeze(1).unsqueeze(0), size=(new_height, new_width), mode='bilinear', align_corners=False).squeeze(0).unsqueeze(1) + top = (canvas_height - new_height) // 2 + left = (canvas_width - new_width) // 2 + white_canvas[:, :, top:top + new_height, left:left + new_width] = resized_image + ref_img = white_canvas + src_ref_images[i][j] = ref_img.to(device) + return src_video, src_mask, src_ref_images + + def decode_latent(self, zs, ref_images=None, vae=None): + vae = self.vae if vae is None else vae + if ref_images is None: + ref_images = [None] * len(zs) + else: + assert len(zs) == len(ref_images) + + trimed_zs = [] + for z, refs in 
zip(zs, ref_images): + if refs is not None: + z = z[:, len(refs):, :, :] + trimed_zs.append(z) + + return vae.decode(trimed_zs) + + + + def generate(self, + input_prompt, + input_frames, + input_masks, + input_ref_images, + size=(1280, 720), + frame_num=81, + context_scale=1.0, + shift=5.0, + sample_solver='unipc', + sampling_steps=50, + guide_scale=5.0, + n_prompt="", + seed=-1, + offload_model=True): + r""" + Generates video frames from text prompt using diffusion process. + + Args: + input_prompt (`str`): + Text prompt for content generation + size (tupele[`int`], *optional*, defaults to (1280,720)): + Controls video resolution, (width,height). + frame_num (`int`, *optional*, defaults to 81): + How many frames to sample from a video. The number should be 4n+1 + shift (`float`, *optional*, defaults to 5.0): + Noise schedule shift parameter. Affects temporal dynamics + sample_solver (`str`, *optional*, defaults to 'unipc'): + Solver used to sample the video. + sampling_steps (`int`, *optional*, defaults to 40): + Number of diffusion sampling steps. Higher values improve quality but slow generation + guide_scale (`float`, *optional*, defaults 5.0): + Classifier-free guidance scale. Controls prompt adherence vs. creativity + n_prompt (`str`, *optional*, defaults to ""): + Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt` + seed (`int`, *optional*, defaults to -1): + Random seed for noise generation. If -1, use random seed. + offload_model (`bool`, *optional*, defaults to True): + If True, offloads models to CPU during generation to save VRAM + + Returns: + torch.Tensor: + Generated video frames tensor. Dimensions: (C, N H, W) where: + - C: Color channels (3 for RGB) + - N: Number of frames (81) + - H: Frame height (from size) + - W: Frame width from size) + """ + # preprocess + # F = frame_num + # target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1, + # size[1] // self.vae_stride[1], + # size[0] // self.vae_stride[2]) + # + # seq_len = math.ceil((target_shape[2] * target_shape[3]) / + # (self.patch_size[1] * self.patch_size[2]) * + # target_shape[1] / self.sp_size) * self.sp_size + + if n_prompt == "": + n_prompt = self.sample_neg_prompt + seed = seed if seed >= 0 else random.randint(0, sys.maxsize) + seed_g = torch.Generator(device=self.device) + seed_g.manual_seed(seed) + + if not self.t5_cpu: + self.text_encoder.model.to(self.device) + context = self.text_encoder([input_prompt], self.device) + context_null = self.text_encoder([n_prompt], self.device) + if offload_model: + self.text_encoder.model.cpu() + else: + context = self.text_encoder([input_prompt], torch.device('cpu')) + context_null = self.text_encoder([n_prompt], torch.device('cpu')) + context = [t.to(self.device) for t in context] + context_null = [t.to(self.device) for t in context_null] + + # vace context encode + z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks) + m0 = self.vace_encode_masks(input_masks, input_ref_images) + z = self.vace_latent(z0, m0) + + target_shape = list(z0[0].shape) + target_shape[0] = int(target_shape[0] / 2) + noise = [ + torch.randn( + target_shape[0], + target_shape[1], + target_shape[2], + target_shape[3], + dtype=torch.float32, + device=self.device, + generator=seed_g) + ] + seq_len = math.ceil((target_shape[2] * target_shape[3]) / + (self.patch_size[1] * self.patch_size[2]) * + target_shape[1] / self.sp_size) * self.sp_size + + @contextmanager + def noop_no_sync(): + yield + + no_sync = getattr(self.model, 
'no_sync', noop_no_sync) + + # evaluation mode + with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync(): + + if sample_solver == 'unipc': + sample_scheduler = FlowUniPCMultistepScheduler( + num_train_timesteps=self.num_train_timesteps, + shift=1, + use_dynamic_shifting=False) + sample_scheduler.set_timesteps( + sampling_steps, device=self.device, shift=shift) + timesteps = sample_scheduler.timesteps + elif sample_solver == 'dpm++': + sample_scheduler = FlowDPMSolverMultistepScheduler( + num_train_timesteps=self.num_train_timesteps, + shift=1, + use_dynamic_shifting=False) + sampling_sigmas = get_sampling_sigmas(sampling_steps, shift) + timesteps, _ = retrieve_timesteps( + sample_scheduler, + device=self.device, + sigmas=sampling_sigmas) + else: + raise NotImplementedError("Unsupported solver.") + + # sample videos + latents = noise + + arg_c = {'context': context, 'seq_len': seq_len} + arg_null = {'context': context_null, 'seq_len': seq_len} + + for _, t in enumerate(tqdm(timesteps)): + latent_model_input = latents + timestep = [t] + + timestep = torch.stack(timestep) + + self.model.to(self.device) + noise_pred_cond = self.model( + latent_model_input, t=timestep, vace_context=z, vace_context_scale=context_scale, **arg_c)[0] + noise_pred_uncond = self.model( + latent_model_input, t=timestep, vace_context=z, vace_context_scale=context_scale,**arg_null)[0] + + noise_pred = noise_pred_uncond + guide_scale * ( + noise_pred_cond - noise_pred_uncond) + + temp_x0 = sample_scheduler.step( + noise_pred.unsqueeze(0), + t, + latents[0].unsqueeze(0), + return_dict=False, + generator=seed_g)[0] + latents = [temp_x0.squeeze(0)] + + x0 = latents + if offload_model: + self.model.cpu() + torch.cuda.empty_cache() + if self.rank == 0: + videos = self.decode_latent(x0, input_ref_images) + + del noise, latents + del sample_scheduler + if offload_model: + gc.collect() + torch.cuda.synchronize() + if dist.is_initialized(): + dist.barrier() + + return videos[0] if self.rank == 0 else None + + +class WanVaceMP(WanVace): + def __init__( + self, + config, + checkpoint_dir, + use_usp=False, + ulysses_size=None, + ring_size=None + ): + self.config = config + self.checkpoint_dir = checkpoint_dir + self.use_usp = use_usp + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '12345' + os.environ['RANK'] = '0' + os.environ['WORLD_SIZE'] = '1' + self.in_q_list = None + self.out_q = None + self.inference_pids = None + self.ulysses_size = ulysses_size + self.ring_size = ring_size + self.dynamic_load() + + self.device = 'cpu' if torch.cuda.is_available() else 'cpu' + self.vid_proc = VaceVideoProcessor( + downsample=tuple([x * y for x, y in zip(config.vae_stride, config.patch_size)]), + min_area=720 * 1280, + max_area=720 * 1280, + min_fps=config.sample_fps, + max_fps=config.sample_fps, + zero_start=True, + seq_len=75600, + keep_last=True) + + + def dynamic_load(self): + if hasattr(self, 'inference_pids') and self.inference_pids is not None: + return + gpu_infer = os.environ.get('LOCAL_WORLD_SIZE') or torch.cuda.device_count() + pmi_rank = int(os.environ['RANK']) + pmi_world_size = int(os.environ['WORLD_SIZE']) + in_q_list = [torch.multiprocessing.Manager().Queue() for _ in range(gpu_infer)] + out_q = torch.multiprocessing.Manager().Queue() + initialized_events = [torch.multiprocessing.Manager().Event() for _ in range(gpu_infer)] + context = mp.spawn(self.mp_worker, nprocs=gpu_infer, args=(gpu_infer, pmi_rank, pmi_world_size, in_q_list, out_q, initialized_events, self), join=False) + 
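+        # Block until every spawned worker signals (via its Event) that its model
+        # shard, text encoder and VAE are loaded before exposing the request and
+        # response queues.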
+        all_initialized = False
+        while not all_initialized:
+            all_initialized = all(event.is_set() for event in initialized_events)
+            if not all_initialized:
+                time.sleep(0.1)
+        print('Inference model is initialized', flush=True)
+        self.in_q_list = in_q_list
+        self.out_q = out_q
+        self.inference_pids = context.pids()
+        self.initialized_events = initialized_events
+
+    def transfer_data_to_cuda(self, data, device):
+        if data is None:
+            return None
+        else:
+            if isinstance(data, torch.Tensor):
+                data = data.to(device)
+            elif isinstance(data, list):
+                data = [self.transfer_data_to_cuda(subdata, device) for subdata in data]
+            elif isinstance(data, dict):
+                data = {key: self.transfer_data_to_cuda(val, device) for key, val in data.items()}
+            return data
+
+    def mp_worker(self, gpu, gpu_infer, pmi_rank, pmi_world_size, in_q_list, out_q, initialized_events, work_env):
+        try:
+            world_size = pmi_world_size * gpu_infer
+            rank = pmi_rank * gpu_infer + gpu
+            print("world_size", world_size, "rank", rank, flush=True)
+
+            torch.cuda.set_device(gpu)
+            dist.init_process_group(
+                backend='nccl',
+                init_method='env://',
+                rank=rank,
+                world_size=world_size
+            )
+
+            from xfuser.core.distributed import (initialize_model_parallel,
+                                                 init_distributed_environment)
+            init_distributed_environment(
+                rank=dist.get_rank(), world_size=dist.get_world_size())
+
+            initialize_model_parallel(
+                sequence_parallel_degree=dist.get_world_size(),
+                ring_degree=self.ring_size or 1,
+                ulysses_degree=self.ulysses_size or 1
+            )
+
+            num_train_timesteps = self.config.num_train_timesteps
+            param_dtype = self.config.param_dtype
+            shard_fn = partial(shard_model, device_id=gpu)
+            text_encoder = T5EncoderModel(
+                text_len=self.config.text_len,
+                dtype=self.config.t5_dtype,
+                device=torch.device('cpu'),
+                checkpoint_path=os.path.join(self.checkpoint_dir, self.config.t5_checkpoint),
+                tokenizer_path=os.path.join(self.checkpoint_dir, self.config.t5_tokenizer),
+                shard_fn=shard_fn)
+            text_encoder.model.to(gpu)
+            vae_stride = self.config.vae_stride
+            patch_size = self.config.patch_size
+            vae = WanVAE(
+                vae_pth=os.path.join(self.checkpoint_dir, self.config.vae_checkpoint),
+                device=gpu)
+            logging.info(f"Creating VaceWanModel from {self.checkpoint_dir}")
+            model = VaceWanModel.from_pretrained(self.checkpoint_dir)
+            model.eval().requires_grad_(False)
+
+            if self.use_usp:
+                from xfuser.core.distributed import get_sequence_parallel_world_size
+                from .distributed.xdit_context_parallel import (usp_attn_forward,
+                                                                usp_dit_forward,
+                                                                usp_dit_forward_vace)
+                for block in model.blocks:
+                    block.self_attn.forward = types.MethodType(
+                        usp_attn_forward, block.self_attn)
+                for block in model.vace_blocks:
+                    block.self_attn.forward = types.MethodType(
+                        usp_attn_forward, block.self_attn)
+                model.forward = types.MethodType(usp_dit_forward, model)
+                model.forward_vace = types.MethodType(usp_dit_forward_vace, model)
+                sp_size = get_sequence_parallel_world_size()
+            else:
+                sp_size = 1
+
+            dist.barrier()
+            model = shard_fn(model)
+            sample_neg_prompt = self.config.sample_neg_prompt
+
+            torch.cuda.empty_cache()
+            event = initialized_events[gpu]
+            in_q = in_q_list[gpu]
+            event.set()
+
+            while True:
+                item = in_q.get()
+                input_prompt, input_frames, input_masks, input_ref_images, size, frame_num, context_scale, \
+                    shift, sample_solver, sampling_steps, guide_scale, n_prompt, seed, offload_model = item
+                input_frames = self.transfer_data_to_cuda(input_frames, gpu)
+                input_masks = self.transfer_data_to_cuda(input_masks, gpu)
+                input_ref_images = self.transfer_data_to_cuda(input_ref_images, gpu)
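+
+                # From here the loop mirrors WanVace.generate, but runs on this
+                # worker's GPU with the locally built text encoder, VAE and
+                # sharded model.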
+
+                if n_prompt == "":
+                    n_prompt = sample_neg_prompt
+                seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
+                seed_g = torch.Generator(device=gpu)
+                seed_g.manual_seed(seed)
+
+                context = text_encoder([input_prompt], gpu)
+                context_null = text_encoder([n_prompt], gpu)
+
+                # vace context encode
+                z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, vae=vae)
+                m0 = self.vace_encode_masks(input_masks, input_ref_images, vae_stride=vae_stride)
+                z = self.vace_latent(z0, m0)
+
+                target_shape = list(z0[0].shape)
+                target_shape[0] = int(target_shape[0] / 2)
+                noise = [
+                    torch.randn(
+                        target_shape[0],
+                        target_shape[1],
+                        target_shape[2],
+                        target_shape[3],
+                        dtype=torch.float32,
+                        device=gpu,
+                        generator=seed_g)
+                ]
+                seq_len = math.ceil((target_shape[2] * target_shape[3]) /
+                                    (patch_size[1] * patch_size[2]) *
+                                    target_shape[1] / sp_size) * sp_size
+
+                @contextmanager
+                def noop_no_sync():
+                    yield
+
+                no_sync = getattr(model, 'no_sync', noop_no_sync)
+
+                # evaluation mode
+                with amp.autocast(dtype=param_dtype), torch.no_grad(), no_sync():
+
+                    if sample_solver == 'unipc':
+                        sample_scheduler = FlowUniPCMultistepScheduler(
+                            num_train_timesteps=num_train_timesteps,
+                            shift=1,
+                            use_dynamic_shifting=False)
+                        sample_scheduler.set_timesteps(
+                            sampling_steps, device=gpu, shift=shift)
+                        timesteps = sample_scheduler.timesteps
+                    elif sample_solver == 'dpm++':
+                        sample_scheduler = FlowDPMSolverMultistepScheduler(
+                            num_train_timesteps=num_train_timesteps,
+                            shift=1,
+                            use_dynamic_shifting=False)
+                        sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
+                        timesteps, _ = retrieve_timesteps(
+                            sample_scheduler,
+                            device=gpu,
+                            sigmas=sampling_sigmas)
+                    else:
+                        raise NotImplementedError("Unsupported solver.")
+
+                    # sample videos
+                    latents = noise
+
+                    arg_c = {'context': context, 'seq_len': seq_len}
+                    arg_null = {'context': context_null, 'seq_len': seq_len}
+
+                    for _, t in enumerate(tqdm(timesteps)):
+                        latent_model_input = latents
+                        timestep = [t]
+
+                        timestep = torch.stack(timestep)
+
+                        model.to(gpu)
+                        noise_pred_cond = model(
+                            latent_model_input, t=timestep, vace_context=z, vace_context_scale=context_scale, **arg_c)[0]
+                        noise_pred_uncond = model(
+                            latent_model_input, t=timestep, vace_context=z, vace_context_scale=context_scale, **arg_null)[0]
+
+                        noise_pred = noise_pred_uncond + guide_scale * (
+                            noise_pred_cond - noise_pred_uncond)
+
+                        temp_x0 = sample_scheduler.step(
+                            noise_pred.unsqueeze(0),
+                            t,
+                            latents[0].unsqueeze(0),
+                            return_dict=False,
+                            generator=seed_g)[0]
+                        latents = [temp_x0.squeeze(0)]
+
+                    torch.cuda.empty_cache()
+                    x0 = latents
+                    if rank == 0:
+                        videos = self.decode_latent(x0, input_ref_images, vae=vae)
+
+                del noise, latents
+                del sample_scheduler
+                if offload_model:
+                    gc.collect()
+                    torch.cuda.synchronize()
+                if dist.is_initialized():
+                    dist.barrier()
+
+                if rank == 0:
+                    out_q.put(videos[0].cpu())
+
+        except Exception as e:
+            trace_info = traceback.format_exc()
+            print(trace_info, flush=True)
+            print(e, flush=True)
+
+
+
+    def generate(self,
+                 input_prompt,
+                 input_frames,
+                 input_masks,
+                 input_ref_images,
+                 size=(1280, 720),
+                 frame_num=81,
+                 context_scale=1.0,
+                 shift=5.0,
+                 sample_solver='unipc',
+                 sampling_steps=50,
+                 guide_scale=5.0,
+                 n_prompt="",
+                 seed=-1,
+                 offload_model=True):
+
+        input_data = (input_prompt, input_frames, input_masks, input_ref_images, size, frame_num, context_scale,
+                      shift, sample_solver, sampling_steps, guide_scale, n_prompt, seed, offload_model)
+        for in_q in self.in_q_list:
+            in_q.put(input_data)
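+        # Every worker receives the same request, but only rank 0 puts a decoded
+        # video on out_q (see mp_worker), so a single blocking get() returns the result.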
+        value_output = self.out_q.get()
+
+        return value_output
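+
+# Usage sketch (illustrative only; it assumes a Wan `config` object and a
+# `checkpoint_dir` path are available, which are not defined in this module):
+#
+#   wan_vace = WanVaceMP(config, checkpoint_dir, use_usp=False)
+#   video = wan_vace.generate("a cat surfing a wave",
+#                             input_frames, input_masks, input_ref_images,
+#                             size=(832, 480), frame_num=81, seed=42)
+#
+# `video` is the decoded (C, N, H, W) tensor produced by the rank-0 worker.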