diff --git a/context_windows/context.py b/context_windows/context.py
index 4f64f3f9..4bb6fe42 100644
--- a/context_windows/context.py
+++ b/context_windows/context.py
@@ -186,8 +186,18 @@ def get_total_steps(
 
 def create_window_mask(noise_pred_context, c, latent_video_length, context_overlap, looped=False, window_type="linear"):
     window_mask = torch.ones_like(noise_pred_context)
-
-    if window_type == "pyramid":
+
+    if window_type == "flashvsr":
+        # Special mode for FlashVSR: use overlap for context but don't blend
+        # First chunk: keep all frames (weight=1)
+        # Later chunks: mask out overlap region (weight=0), keep only new frames (weight=1)
+        if min(c) > 0:  # Not the first chunk
+            # Set overlap region to 0 so these frames don't contribute to final result
+            window_mask[:, :context_overlap] = 0
+        # All other frames get weight 1 (no blending/ramping)
+        return window_mask
+
+    elif window_type == "pyramid":
         # Create pyramid weights that peak in the middle
         length = noise_pred_context.shape[1]
         if length % 2 == 0:
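The effect of the new `flashvsr` branch can be reproduced in isolation. The sketch below is a reduced stand-alone version (a 1-D mask instead of the full `noise_pred_context`-shaped tensor; `flashvsr_mask` is a hypothetical helper for illustration, not part of the wrapper):

```python
import torch

def flashvsr_mask(num_frames: int, context_overlap: int, c: list) -> torch.Tensor:
    # Reduced version of the flashvsr branch above: every frame gets weight 1,
    # except the overlap region of non-first chunks, which gets weight 0 --
    # overlapping frames are recomputed for context but never blended
    # into the final result.
    mask = torch.ones(num_frames)
    if min(c) > 0:  # not the first chunk
        mask[:context_overlap] = 0
    return mask

print(flashvsr_mask(8, 2, c=[0, 1, 2, 3, 4, 5, 6, 7]))    # all ones
print(flashvsr_mask(8, 2, c=[6, 7, 8, 9, 10, 11, 12, 13]))  # [0., 0., 1., 1., ...]
```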
diff --git a/nodes.py b/nodes.py
index acb70e2a..45dc3406 100644
--- a/nodes.py
+++ b/nodes.py
@@ -1654,7 +1654,7 @@ def INPUT_TYPES(s):
                 "verbose": ("BOOLEAN", {"default": False, "tooltip": "Print debug output"}),
             },
             "optional": {
-                "fuse_method": (["linear", "pyramid"], {"default": "linear", "tooltip": "Window weight function: linear=ramps at edges only, pyramid=triangular weights peaking in middle"}),
+                "fuse_method": (["linear", "pyramid", "flashvsr"], {"default": "linear", "tooltip": "Window weight function: linear=ramps at edges only, pyramid=triangular weights peaking in middle, flashvsr=no blending (use for FlashVSR upscaling)"}),
                 "reference_latent": ("LATENT", {"tooltip": "Image to be used as init for I2V models for windows where first frame is not the actual first frame. Mostly useful with MAGREF model"}),
             }
         }
@@ -1999,6 +1999,9 @@ def decode(self, vae, samples, enable_vae_tiling, tile_x, tile_y, tile_stride_x,
         flashvsr_LQ_images = samples.get("flashvsr_LQ_images", None)
 
+        # Get context_options from samples dictionary (passed from WanVideoSampler)
+        context_options = samples.get("context_options", None)
+
         vae.to(device)
 
         latents = latents.to(device = device, dtype = vae.dtype)
 
@@ -2010,11 +2013,104 @@ def decode(self, vae, samples, enable_vae_tiling, tile_x, tile_y, tile_stride_x,
         if drop_last:
             latents = latents[:, :, :-1]
 
-        if type(vae).__name__ == "TAEHV":
-            images = vae.decode_video(latents.permute(0, 2, 1, 3, 4), cond=flashvsr_LQ_images.to(vae.dtype))[0].permute(1, 0, 2, 3)
-            images = torch.clamp(images, 0.0, 1.0)
-            images = images.permute(1, 2, 3, 0).cpu().float()
-            return (images,)
+        if type(vae).__name__ == "TAEHV":
+            # FlashVSR decoding with chunking for memory efficiency
+            # Convert context_frames from pixel to latent space for comparison
+            latent_context_frames_threshold = max(1, context_options["context_frames"] // 4) if context_options is not None else 999999
+
+            if context_options is not None and latents.shape[2] > latent_context_frames_threshold:
+                # Chunk the decoding with overlap for temporal continuity
+                context_frames = context_options["context_frames"]
+                context_overlap = context_options.get("context_overlap", 16)
+                num_frames = latents.shape[2]
+
+                # Work in latent space for chunking
+                latent_context_frames = max(1, context_frames // 4)
+                latent_overlap = max(1, context_overlap // 4)
+                stride = latent_context_frames - latent_overlap
+
+                # FlashVSR decoder outputs at PIXEL temporal resolution (4x latent)
+                # So we need to discard overlap in pixel space
+                pixel_overlap = context_overlap
+
+                log.info(f"Decoding FlashVSR with overlap: {latent_context_frames} latent frames per chunk, {latent_overlap} latent overlap ({pixel_overlap} pixel overlap), {num_frames} total latent frames")
+
+                decoded_frames = []
+
+                # Process chunks with overlap, but trim the overlap
+                for chunk_idx, start_idx in enumerate(range(0, num_frames, stride)):
+                    end_idx = min(start_idx + latent_context_frames, num_frames)
+
+                    # Extract latent chunk
+                    chunk_latents = latents[:, :, start_idx:end_idx]
+
+                    # Extract corresponding LQ images if they exist
+                    chunk_LQ = None
+                    if flashvsr_LQ_images is not None:
+                        lq_start = start_idx * 4
+                        lq_end = end_idx * 4
+                        chunk_LQ = flashvsr_LQ_images[:, :, lq_start:lq_end].to(vae.dtype)
+
+                    # Decode this chunk with sequential processing
+                    # Output is at PIXEL temporal resolution (not latent!)
+                    chunk_images = vae.decode_video(
+                        chunk_latents.permute(0, 2, 1, 3, 4),
+                        cond=chunk_LQ,
+                        parallel=False,  # Frame-by-frame within chunk
+                        show_progress_bar=True
+                    )[0].permute(1, 0, 2, 3)
+
+                    # Keep only non-overlapping frames (discard in PIXEL space)
+                    if chunk_idx == 0:
+                        # First chunk: keep all frames
+                        keep_frames = chunk_images
+                    else:
+                        # Calculate actual overlap based on decoder output
+                        # Decoder doesn't always output exactly 4x latent frames
+                        actual_pixel_frames = chunk_images.shape[1]
+                        overlap_ratio = latent_overlap / latent_context_frames
+                        actual_overlap = int(actual_pixel_frames * overlap_ratio)
+
+                        # Discard the overlap frames
+                        if chunk_images.shape[1] > actual_overlap:
+                            keep_frames = chunk_images[:, actual_overlap:]
+                        else:
+                            keep_frames = chunk_images
+
+                    decoded_frames.append(keep_frames.cpu())
+
+                    # Log before cleanup
+                    if chunk_idx == 0:
+                        log.info(f"Decoded chunk {start_idx}-{end_idx} latent ({start_idx*4}-{end_idx*4} pixel), decoder output {chunk_images.shape[1]} frames, kept all")
+                    else:
+                        log.info(f"Decoded chunk {start_idx}-{end_idx} latent ({start_idx*4}-{end_idx*4} pixel), decoder output {chunk_images.shape[1]} frames, discarded {actual_overlap}, kept {decoded_frames[-1].shape[1]}")
+
+                    # Clean up
+                    del chunk_latents, chunk_images, keep_frames
+                    if chunk_LQ is not None:
+                        del chunk_LQ
+                    mm.soft_empty_cache()
+
+                # Concatenate all chunks
+                images = torch.cat(decoded_frames, dim=1)
+                images = torch.clamp(images, 0.0, 1.0)
+                images = images.permute(1, 2, 3, 0).float()
+
+                del decoded_frames
+                mm.soft_empty_cache()
+
+                return (images,)
+            else:
+                # Single-pass decoding for short videos
+                images = vae.decode_video(
+                    latents.permute(0, 2, 1, 3, 4),
+                    cond=flashvsr_LQ_images.to(vae.dtype) if flashvsr_LQ_images is not None else None,
+                    parallel=True
+                )[0].permute(1, 0, 2, 3)
+
+                images = torch.clamp(images, 0.0, 1.0)
+                images = images.permute(1, 2, 3, 0).cpu().float()
+                return (images,)
         else:
             if end_image is not None:
                 enable_vae_tiling = False
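The stride/overlap bookkeeping in the chunked-decode path above is easiest to check with a dry run. The numbers below are assumptions chosen for illustration (21 latent frames, `context_frames=32`, `context_overlap=16`), not defaults taken from the node:

```python
# Dry run of the chunk/overlap arithmetic used in the TAEHV branch above.
num_frames = 21                                   # latent frames (assumed)
latent_context_frames = max(1, 32 // 4)           # from context_frames=32 pixel
latent_overlap = max(1, 16 // 4)                  # from context_overlap=16 pixel
stride = latent_context_frames - latent_overlap   # 4

for chunk_idx, start_idx in enumerate(range(0, num_frames, stride)):
    end_idx = min(start_idx + latent_context_frames, num_frames)
    # The decoder emits roughly 4x the latent chunk length in pixel frames;
    # non-first chunks discard a proportional share as overlap.
    pixel_frames = (end_idx - start_idx) * 4
    overlap_ratio = latent_overlap / latent_context_frames
    discarded = 0 if chunk_idx == 0 else int(pixel_frames * overlap_ratio)
    print(f"chunk {chunk_idx}: latent [{start_idx}, {end_idx}), "
          f"{pixel_frames} pixel frames, {discarded} discarded")
```

Note that the ratio-based discard is an estimate, not an exact count; as the comment in the diff says, the decoder doesn't always output exactly 4x the latent frames.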
diff --git a/nodes_sampler.py b/nodes_sampler.py
index 7ed6829d..e512f8f7 100644
--- a/nodes_sampler.py
+++ b/nodes_sampler.py
@@ -834,14 +834,12 @@ def process(self, model, image_embeds, shift, steps, cfg, seed, scheduler, rifle
         flashvsr_LQ_images = image_embeds.get("flashvsr_LQ_images", None)
         flashvsr_strength = image_embeds.get("flashvsr_strength", 1.0)
         if flashvsr_LQ_images is not None:
-            if flashvsr_LQ_images.shape[0] < num_frames + 4:
-                missing_frames = num_frames + 4 - flashvsr_LQ_images.shape[0]
-                last_frame = flashvsr_LQ_images[-1:].repeat(missing_frames, 1, 1, 1)
-                flashvsr_LQ_images = torch.cat([flashvsr_LQ_images, last_frame], dim=0)
-            LQ_images = flashvsr_LQ_images[:num_frames+4].unsqueeze(0).movedim(-1, 1).to(dtype) * 2 - 1
+            LQ_images = flashvsr_LQ_images.unsqueeze(0).movedim(-1, 1).to(device, dtype) * 2 - 1
             if context_options is None:
-                flashvsr_LQ_latent = transformer.LQ_proj_in(LQ_images.to(device))
+                flashvsr_LQ_latent = transformer.LQ_proj_in(LQ_images)
                 log.info(f"flashvsr_LQ_latent: {flashvsr_LQ_latent[0].shape}")
+                if noise.shape[1] != 1:
+                    noise = noise[:, :-1]
 
         seq_len = math.ceil((noise.shape[2] * noise.shape[3]) / 4 * noise.shape[1])
         latent = noise
@@ -1955,7 +1953,7 @@ def predict_with_cfg(z, cfg_scale, positive_embeds, negative_embeds, timestep, i
                     end = c[-1] * 4 + 1 + 4
                     center_indices = torch.arange(start, end, 1)
                     center_indices = torch.clamp(center_indices, min=0, max=LQ_images.shape[2] - 1)
-                    partial_flashvsr_LQ_images = LQ_images[:, :, center_indices].to(device)
+                    partial_flashvsr_LQ_images = LQ_images[:, :, center_indices].to(device, dtype)
                     partial_flashvsr_LQ_latent = transformer.LQ_proj_in(partial_flashvsr_LQ_images)
 
                     if len(timestep.shape) != 1:
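The clamp in the hunk above matters for windows near the end of the video: pixel indices past the last LQ frame are pinned to it, so the conditioning simply repeats the final frame. A stand-alone illustration follows; the window `c` and frame count are made-up values, and `start` is assumed to follow the same `c[0] * 4` pattern as the visible `end` line, since its definition sits outside the hunk:

```python
import torch

c = [18, 19, 20]       # latent indices of one context window (assumed)
num_LQ_frames = 81     # stand-in for LQ_images.shape[2] (assumed)

start = c[0] * 4       # assumption; the real definition is outside the hunk
end = c[-1] * 4 + 1 + 4
center_indices = torch.arange(start, end, 1)
center_indices = torch.clamp(center_indices, min=0, max=num_LQ_frames - 1)
print(center_indices)  # tail: 79, 80, 80, 80, 80, 80 -- clamped indices repeat the last frame
```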
@@ -2999,7 +2997,7 @@ def predict_with_cfg(z, cfg_scale, positive_embeds, negative_embeds, timestep, i
                 torch.cuda.reset_peak_memory_stats(device)
             except:
                 pass
-        return ({
+        output_dict = {
             "samples": latent.unsqueeze(0).cpu(),
             "looped": is_looped,
             "end_image": end_image if not fun_or_fl2v_model else None,
@@ -3010,7 +3008,13 @@ def predict_with_cfg(z, cfg_scale, positive_embeds, negative_embeds, timestep, i
             "cache_states": cache_states,
             "latent_ovi_audio": latent_ovi.unsqueeze(0).transpose(1, 2).cpu() if latent_ovi is not None else None,
             "flashvsr_LQ_images": LQ_images,
-        },{
+        }
+
+        # Only pass context_options if it's actually being used (not None)
+        if context_options is not None:
+            output_dict["context_options"] = context_options
+
+        return (output_dict, {
             "samples": callback_latent.unsqueeze(0).cpu() if callback is not None else None,
         })
 
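The `output_dict` refactor exists so the sampler can conditionally forward `context_options` to the decoder, which reads it back via `samples.get("context_options", None)` in the nodes.py hunk above; omitting the key when windowing is off keeps single-pass decoding as the default. A minimal sketch of that handshake, with placeholder functions standing in for the real node methods:

```python
def sampler_output(latent, context_options=None):
    out = {"samples": latent}
    # Only attach the key when windowing is actually in use,
    # mirroring the conditional added in the diff above.
    if context_options is not None:
        out["context_options"] = context_options
    return out

def decode(samples):
    context_options = samples.get("context_options", None)
    if context_options is None:
        return "single-pass decode"
    return f"chunked decode, {context_options['context_frames']} pixel frames per chunk"

print(decode(sampler_output("latent")))                          # single-pass decode
print(decode(sampler_output("latent", {"context_frames": 81})))  # chunked decode, ...
```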