diff --git a/scripts/convert_ltx2_to_diffusers.py b/scripts/convert_ltx2_to_diffusers.py index eb0b010075b4..fa6da10bee78 100644 --- a/scripts/convert_ltx2_to_diffusers.py +++ b/scripts/convert_ltx2_to_diffusers.py @@ -16,7 +16,7 @@ LTX2Pipeline, LTX2VideoTransformer3DModel, ) -from diffusers.pipelines.ltx2 import LTX2TextConnectors, LTX2Vocoder +from diffusers.pipelines.ltx2 import LTX2LatentUpsamplerModel, LTX2TextConnectors, LTX2Vocoder from diffusers.utils.import_utils import is_accelerate_available @@ -577,6 +577,33 @@ def convert_ltx2_vocoder(original_state_dict: Dict[str, Any], version: str) -> D return vocoder +def get_ltx2_spatial_latent_upsampler_config(version: str): + if version == "2.0": + config = { + "in_channels": 128, + "mid_channels": 1024, + "num_blocks_per_stage": 4, + "dims": 3, + "spatial_upsample": True, + "temporal_upsample": False, + "rational_spatial_scale": 2.0, + } + else: + raise ValueError(f"Unsupported version: {version}") + return config + + +def convert_ltx2_spatial_latent_upsampler( + original_state_dict: Dict[str, Any], config: Dict[str, Any], dtype: torch.dtype +): + with init_empty_weights(): + latent_upsampler = LTX2LatentUpsamplerModel(**config) + + latent_upsampler.load_state_dict(original_state_dict, strict=True, assign=True) + latent_upsampler.to(dtype) + return latent_upsampler + + def load_original_checkpoint(args, filename: Optional[str]) -> Dict[str, Any]: if args.original_state_dict_repo_id is not None: ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=filename) @@ -682,6 +709,12 @@ def get_args(): type=str, help="HF Hub id for the LTX 2.0 text tokenizer", ) + parser.add_argument( + "--latent_upsampler_filename", + default="rc1/ltx-2-spatial-upscaler-x2-1.0-rc1.safetensors", + type=str, + help="Latent upsampler filename", + ) parser.add_argument("--vae", action="store_true", help="Whether to convert the video VAE model") parser.add_argument("--audio_vae", action="store_true", help="Whether to convert the audio VAE model") @@ -689,6 +722,7 @@ def get_args(): parser.add_argument("--connectors", action="store_true", help="Whether to convert the connector model") parser.add_argument("--vocoder", action="store_true", help="Whether to convert the vocoder model") parser.add_argument("--text_encoder", action="store_true", help="Whether to conver the text encoder") + parser.add_argument("--latent_upsampler", action="store_true", help="Whether to convert the latent upsampler") parser.add_argument( "--full_pipeline", action="store_true", @@ -788,6 +822,18 @@ def main(args): if not args.full_pipeline: tokenizer.save_pretrained(os.path.join(args.output_path, "tokenizer")) + if args.latent_upsampler: + original_latent_upsampler_ckpt = load_hub_or_local_checkpoint( + repo_id=args.original_state_dict_repo_id, filename=args.latent_upsampler_filename + ) + latent_upsampler_config = get_ltx2_spatial_latent_upsampler_config(args.version) + latent_upsampler = convert_ltx2_spatial_latent_upsampler( + original_latent_upsampler_ckpt, + latent_upsampler_config, + dtype=vae_dtype, + ) + latent_upsampler.save_pretrained(os.path.join(args.output_path, "latent_upsampler")) + if args.full_pipeline: scheduler = FlowMatchEulerDiscreteScheduler( use_dynamic_shifting=True, diff --git a/scripts/ltx2_test_latent_upsampler.py b/scripts/ltx2_test_latent_upsampler.py new file mode 100644 index 000000000000..6b2e088f23c9 --- /dev/null +++ b/scripts/ltx2_test_latent_upsampler.py @@ -0,0 +1,174 @@ +import argparse +import gc +import os + +import torch + 
+from diffusers import AutoencoderKLLTX2Video +from diffusers.pipelines.ltx2 import LTX2ImageToVideoPipeline, LTX2LatentUpsamplePipeline, LTX2LatentUpsamplerModel +from diffusers.pipelines.ltx2.export_utils import encode_video +from diffusers.utils import load_image + + +def parse_args(): + parser = argparse.ArgumentParser() + + parser.add_argument("--model_id", type=str, default="diffusers-internal-dev/new-ltx-model") + parser.add_argument("--revision", type=str, default="main") + + parser.add_argument("--image_path", required=True, type=str) + parser.add_argument( + "--prompt", + type=str, + default=( + "An astronaut hatches from a fragile egg on the surface of the Moon, the shell cracking and peeling apart " + "in gentle low-gravity motion. Fine lunar dust lifts and drifts outward with each movement, floating in " + "slow arcs before settling back onto the ground. The astronaut pushes free in a deliberate, weightless " + "motion, small fragments of the egg tumbling and spinning through the air. In the background, the deep " + "darkness of space subtly shifts as stars glide with the camera's movement, emphasizing vast depth and " + "scale. The camera performs a smooth, cinematic slow push-in, with natural parallax between the foreground " + "dust, the astronaut, and the distant starfield. Ultra-realistic detail, physically accurate low-gravity " + "motion, cinematic lighting, and a breath-taking, movie-like shot." + ), + ) + parser.add_argument( + "--negative_prompt", + type=str, + default=( + "shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion " + "artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static." + ), + ) + + parser.add_argument("--num_inference_steps", type=int, default=40) + parser.add_argument("--height", type=int, default=512) + parser.add_argument("--width", type=int, default=768) + parser.add_argument("--num_frames", type=int, default=121) + parser.add_argument("--frame_rate", type=float, default=25.0) + parser.add_argument("--guidance_scale", type=float, default=3.0) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--apply_scheduler_fix", action="store_true") + + parser.add_argument("--device", type=str, default="cuda:0") + parser.add_argument("--dtype", type=str, default="bf16") + parser.add_argument("--cpu_offload", action="store_true") + parser.add_argument("--vae_tiling", action="store_true") + parser.add_argument("--use_video_latents", action="store_true") + + parser.add_argument( + "--output_dir", + type=str, + default="samples", + help="Output directory for generated video", + ) + parser.add_argument( + "--output_filename", + type=str, + default="ltx2_i2v_video_upsampled.mp4", + help="Filename of the exported generated video", + ) + + args = parser.parse_args() + args.dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float32 + return args + + +def main(args): + pipeline = LTX2ImageToVideoPipeline.from_pretrained( + args.model_id, + revision=args.revision, + torch_dtype=args.dtype, + ) + if args.cpu_offload: + pipeline.enable_model_cpu_offload() + else: + pipeline.to(device=args.device) + + image = load_image(args.image_path) + + first_stage_output_type = "pil" + if args.use_video_latents: + first_stage_output_type = "latent" + + video, audio = pipeline( + image=image, + prompt=args.prompt, + negative_prompt=args.negative_prompt, + height=args.height, + width=args.width, + num_frames=args.num_frames, + frame_rate=args.frame_rate, + 
num_inference_steps=args.num_inference_steps, + guidance_scale=args.guidance_scale, + generator=torch.Generator(device=args.device).manual_seed(args.seed), + output_type=first_stage_output_type, + return_dict=False, + ) + + if args.use_video_latents: + # Manually convert the audio latents to a waveform + audio = audio.to(pipeline.audio_vae.dtype) + audio = pipeline.audio_vae.decode(audio, return_dict=False)[0] + audio = pipeline.vocoder(audio) + + # Get some pipeline configs for upsampling + spatial_patch_size = pipeline.transformer_spatial_patch_size + temporal_patch_size = pipeline.transformer_temporal_patch_size + + # upsample_pipeline = LTX2LatentUpsamplePipeline.from_pretrained( + # args.model_id, revision=args.revision, torch_dtype=args.dtype, + # ) + output_sampling_rate = pipeline.vocoder.config.output_sampling_rate + del pipeline # Otherwise there might be an OOM error? + torch.cuda.empty_cache() + gc.collect() + + vae = AutoencoderKLLTX2Video.from_pretrained( + args.model_id, + subfolder="vae", + revision=args.revision, + torch_dtype=args.dtype, + ) + latent_upsampler = LTX2LatentUpsamplerModel.from_pretrained( + args.model_id, + subfolder="latent_upsampler", + revision=args.revision, + torch_dtype=args.dtype, + ) + upsample_pipeline = LTX2LatentUpsamplePipeline(vae=vae, latent_upsampler=latent_upsampler) + upsample_pipeline.to(device=args.device) + if args.vae_tiling: + upsample_pipeline.vae.enable_tiling() + + upsample_kwargs = { + "height": args.height, + "width": args.width, + "output_type": "np", + "return_dict": False, + } + if args.use_video_latents: + upsample_kwargs["latents"] = video + upsample_kwargs["num_frames"] = args.num_frames + upsample_kwargs["spatial_patch_size"] = spatial_patch_size + upsample_kwargs["temporal_patch_size"] = temporal_patch_size + else: + upsample_kwargs["video"] = video + + video = upsample_pipeline(**upsample_kwargs)[0] + + # Convert video to uint8 (but keep as NumPy array) + video = (video * 255).round().astype("uint8") + video = torch.from_numpy(video) + + encode_video( + video[0], + fps=args.frame_rate, + audio=audio[0].float().cpu(), + audio_sample_rate=output_sampling_rate, # should be 24000 + output_path=os.path.join(args.output_dir, args.output_filename), + ) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 770afd24e0b5..c749bad4be47 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -542,6 +542,7 @@ "LongCatImageEditPipeline", "LongCatImagePipeline", "LTX2ImageToVideoPipeline", + "LTX2LatentUpsamplePipeline", "LTX2Pipeline", "LTXConditionPipeline", "LTXI2VLongMultiPromptPipeline", @@ -1263,6 +1264,7 @@ LongCatImageEditPipeline, LongCatImagePipeline, LTX2ImageToVideoPipeline, + LTX2LatentUpsamplePipeline, LTX2Pipeline, LTXConditionPipeline, LTXI2VLongMultiPromptPipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 464b24f34069..b94319ffcbdc 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -290,7 +290,7 @@ "LTXLatentUpsamplePipeline", "LTXI2VLongMultiPromptPipeline", ] - _import_structure["ltx2"] = ["LTX2Pipeline", "LTX2ImageToVideoPipeline"] + _import_structure["ltx2"] = ["LTX2Pipeline", "LTX2ImageToVideoPipeline", "LTX2LatentUpsamplePipeline"] _import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"] _import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"] _import_structure["lucy"] = 
["LucyEditPipeline"] @@ -738,7 +738,7 @@ LTXLatentUpsamplePipeline, LTXPipeline, ) - from .ltx2 import LTX2ImageToVideoPipeline, LTX2Pipeline + from .ltx2 import LTX2ImageToVideoPipeline, LTX2LatentUpsamplePipeline, LTX2Pipeline from .lucy import LucyEditPipeline from .lumina import LuminaPipeline, LuminaText2ImgPipeline from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline diff --git a/src/diffusers/pipelines/ltx2/__init__.py b/src/diffusers/pipelines/ltx2/__init__.py index 2760f8f7feeb..115e83e827a4 100644 --- a/src/diffusers/pipelines/ltx2/__init__.py +++ b/src/diffusers/pipelines/ltx2/__init__.py @@ -23,8 +23,10 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["connectors"] = ["LTX2TextConnectors"] + _import_structure["latent_upsampler"] = ["LTX2LatentUpsamplerModel"] _import_structure["pipeline_ltx2"] = ["LTX2Pipeline"] _import_structure["pipeline_ltx2_image2video"] = ["LTX2ImageToVideoPipeline"] + _import_structure["pipeline_ltx2_latent_upsample"] = ["LTX2LatentUpsamplePipeline"] _import_structure["vocoder"] = ["LTX2Vocoder"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: @@ -36,8 +38,10 @@ from ...utils.dummy_torch_and_transformers_objects import * else: from .connectors import LTX2TextConnectors + from .latent_upsampler import LTX2LatentUpsamplerModel from .pipeline_ltx2 import LTX2Pipeline from .pipeline_ltx2_image2video import LTX2ImageToVideoPipeline + from .pipeline_ltx2_latent_upsample import LTX2LatentUpsamplePipeline from .vocoder import LTX2Vocoder else: diff --git a/src/diffusers/pipelines/ltx2/latent_upsampler.py b/src/diffusers/pipelines/ltx2/latent_upsampler.py new file mode 100644 index 000000000000..69a9b1d9193f --- /dev/null +++ b/src/diffusers/pipelines/ltx2/latent_upsampler.py @@ -0,0 +1,285 @@ +# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import math +from typing import Optional + +import torch +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models.modeling_utils import ModelMixin + + +RATIONAL_RESAMPLER_SCALE_MAPPING = { + 0.75: (3, 4), + 1.5: (3, 2), + 2.0: (2, 1), + 4.0: (4, 1), +} + + +# Copied from diffusers.pipelines.ltx.modeling_latent_upsampler.ResBlock +class ResBlock(torch.nn.Module): + def __init__(self, channels: int, mid_channels: Optional[int] = None, dims: int = 3): + super().__init__() + if mid_channels is None: + mid_channels = channels + + Conv = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d + + self.conv1 = Conv(channels, mid_channels, kernel_size=3, padding=1) + self.norm1 = torch.nn.GroupNorm(32, mid_channels) + self.conv2 = Conv(mid_channels, channels, kernel_size=3, padding=1) + self.norm2 = torch.nn.GroupNorm(32, channels) + self.activation = torch.nn.SiLU() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + residual = hidden_states + hidden_states = self.conv1(hidden_states) + hidden_states = self.norm1(hidden_states) + hidden_states = self.activation(hidden_states) + hidden_states = self.conv2(hidden_states) + hidden_states = self.norm2(hidden_states) + hidden_states = self.activation(hidden_states + residual) + return hidden_states + + +# Copied from diffusers.pipelines.ltx.modeling_latent_upsampler.PixelShuffleND +class PixelShuffleND(torch.nn.Module): + def __init__(self, dims, upscale_factors=(2, 2, 2)): + super().__init__() + + self.dims = dims + self.upscale_factors = upscale_factors + + if dims not in [1, 2, 3]: + raise ValueError("dims must be 1, 2, or 3") + + def forward(self, x): + if self.dims == 3: + # spatiotemporal: b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3) + return ( + x.unflatten(1, (-1, *self.upscale_factors[:3])) + .permute(0, 1, 5, 2, 6, 3, 7, 4) + .flatten(6, 7) + .flatten(4, 5) + .flatten(2, 3) + ) + elif self.dims == 2: + # spatial: b (c p1 p2) h w -> b c (h p1) (w p2) + return ( + x.unflatten(1, (-1, *self.upscale_factors[:2])).permute(0, 1, 4, 2, 5, 3).flatten(4, 5).flatten(2, 3) + ) + elif self.dims == 1: + # temporal: b (c p1) f h w -> b c (f p1) h w + return x.unflatten(1, (-1, *self.upscale_factors[:1])).permute(0, 1, 3, 2, 4, 5).flatten(2, 3) + + +class BlurDownsample(torch.nn.Module): + """ + Anti-aliased spatial downsampling by integer stride using a fixed separable binomial kernel. Applies only on H,W. + Works for dims=2 or dims=3 (per-frame). + """ + + def __init__(self, dims: int, stride: int, kernel_size: int = 5) -> None: + super().__init__() + + if dims not in (2, 3): + raise ValueError(f"`dims` must be either 2 or 3 but is {dims}") + if kernel_size < 3 or kernel_size % 2 != 1: + raise ValueError(f"`kernel_size` must be an odd number >= 3 but is {kernel_size}") + + self.dims = dims + self.stride = stride + self.kernel_size = kernel_size + + # 5x5 separable binomial kernel using binomial coefficients [1, 4, 6, 4, 1] from + # the 4th row of Pascal's triangle. This kernel is used for anti-aliasing and + # provides a smooth approximation of a Gaussian filter (often called a "binomial filter"). + # The 2D kernel is constructed as the outer product and normalized. 
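+        # For the default kernel_size=5, k = [1, 4, 6, 4, 1] (sum 16), so the normalized 2D
+        # kernel sums to 1 with a peak center weight of 6 * 6 / 256 = 0.140625.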
+ k = torch.tensor([math.comb(kernel_size - 1, k) for k in range(kernel_size)]) + k2d = k[:, None] @ k[None, :] + k2d = (k2d / k2d.sum()).float() # shape (kernel_size, kernel_size) + self.register_buffer("kernel", k2d[None, None, :, :]) # (1, 1, kernel_size, kernel_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.stride == 1: + return x + + if self.dims == 2: + c = x.shape[1] + weight = self.kernel.expand(c, 1, self.kernel_size, self.kernel_size) # depthwise + x = F.conv2d(x, weight=weight, bias=None, stride=self.stride, padding=self.kernel_size // 2, groups=c) + else: + # dims == 3: apply per-frame on H,W + b, c, f, _, _ = x.shape + x = x.transpose(1, 2).flatten(0, 1) # [B, C, F, H, W] --> [B * F, C, H, W] + + weight = self.kernel.expand(c, 1, self.kernel_size, self.kernel_size) # depthwise + x = F.conv2d(x, weight=weight, bias=None, stride=self.stride, padding=self.kernel_size // 2, groups=c) + + h2, w2 = x.shape[-2:] + x = x.unflatten(0, (b, f)).reshape(b, -1, f, h2, w2) # [B * F, C, H, W] --> [B, C, F, H, W] + return x + + +class SpatialRationalResampler(torch.nn.Module): + """ + Scales by the spatial size of the input by a rational number `scale`. For example, `scale = 0.75` will downsample + by a factor of 3 / 4, while `scale = 1.5` will upsample by a factor of 3 / 2. This works by first upsampling the + input by the (integer) numerator of `scale`, and then performing a blur + stride anti-aliased downsample by the + (integer) denominator. + """ + + def __init__(self, mid_channels: int = 1024, scale: float = 2.0): + super().__init__() + self.scale = float(scale) + num_denom = RATIONAL_RESAMPLER_SCALE_MAPPING.get(scale, None) + if num_denom is None: + raise ValueError( + f"The supplied `scale` {scale} is not supported; supported scales are {list(RATIONAL_RESAMPLER_SCALE_MAPPING.keys())}" + ) + self.num, self.den = num_denom + + self.conv = torch.nn.Conv2d(mid_channels, (self.num**2) * mid_channels, kernel_size=3, padding=1) + self.pixel_shuffle = PixelShuffleND(2, upscale_factors=(self.num, self.num)) + self.blur_down = BlurDownsample(dims=2, stride=self.den) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Expected x shape: [B * F, C, H, W] + # b, _, f, h, w = x.shape + # x = x.transpose(1, 2).flatten(0, 1) # [B, C, F, H, W] --> [B * F, C, H, W] + x = self.conv(x) + x = self.pixel_shuffle(x) + x = self.blur_down(x) + # x = x.unflatten(0, (b, f)).reshape(b, -1, f, h, w) # [B * F, C, H, W] --> [B, C, F, H, W] + return x + + +class LTX2LatentUpsamplerModel(ModelMixin, ConfigMixin): + """ + Model to spatially upsample VAE latents. 
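+
+    Latents are processed by an initial convolution, `num_blocks_per_stage` ResBlocks, an upsampler
+    (a pixel-shuffle block, or a `SpatialRationalResampler` when only `spatial_upsample` is enabled
+    and `rational_spatial_scale` is set), another `num_blocks_per_stage` ResBlocks, and a final
+    convolution back to `in_channels` channels.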
+ + Args: + in_channels (`int`, defaults to `128`): + Number of channels in the input latent + mid_channels (`int`, defaults to `512`): + Number of channels in the middle layers + num_blocks_per_stage (`int`, defaults to `4`): + Number of ResBlocks to use in each stage (pre/post upsampling) + dims (`int`, defaults to `3`): + Number of dimensions for convolutions (2 or 3) + spatial_upsample (`bool`, defaults to `True`): + Whether to spatially upsample the latent + temporal_upsample (`bool`, defaults to `False`): + Whether to temporally upsample the latent + """ + + @register_to_config + def __init__( + self, + in_channels: int = 128, + mid_channels: int = 1024, + num_blocks_per_stage: int = 4, + dims: int = 3, + spatial_upsample: bool = True, + temporal_upsample: bool = False, + rational_spatial_scale: Optional[float] = 2.0, + ): + super().__init__() + + self.in_channels = in_channels + self.mid_channels = mid_channels + self.num_blocks_per_stage = num_blocks_per_stage + self.dims = dims + self.spatial_upsample = spatial_upsample + self.temporal_upsample = temporal_upsample + + ConvNd = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d + + self.initial_conv = ConvNd(in_channels, mid_channels, kernel_size=3, padding=1) + self.initial_norm = torch.nn.GroupNorm(32, mid_channels) + self.initial_activation = torch.nn.SiLU() + + self.res_blocks = torch.nn.ModuleList([ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]) + + if spatial_upsample and temporal_upsample: + self.upsampler = torch.nn.Sequential( + torch.nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(3), + ) + elif spatial_upsample: + if rational_spatial_scale is not None: + self.upsampler = SpatialRationalResampler(mid_channels=mid_channels, scale=rational_spatial_scale) + else: + self.upsampler = torch.nn.Sequential( + torch.nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(2), + ) + elif temporal_upsample: + self.upsampler = torch.nn.Sequential( + torch.nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(1), + ) + else: + raise ValueError("Either spatial_upsample or temporal_upsample must be True") + + self.post_upsample_res_blocks = torch.nn.ModuleList( + [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)] + ) + + self.final_conv = ConvNd(mid_channels, in_channels, kernel_size=3, padding=1) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + + if self.dims == 2: + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) + hidden_states = self.initial_conv(hidden_states) + hidden_states = self.initial_norm(hidden_states) + hidden_states = self.initial_activation(hidden_states) + + for block in self.res_blocks: + hidden_states = block(hidden_states) + + hidden_states = self.upsampler(hidden_states) + + for block in self.post_upsample_res_blocks: + hidden_states = block(hidden_states) + + hidden_states = self.final_conv(hidden_states) + hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + else: + hidden_states = self.initial_conv(hidden_states) + hidden_states = self.initial_norm(hidden_states) + hidden_states = self.initial_activation(hidden_states) + + for block in self.res_blocks: + hidden_states = block(hidden_states) + + if self.temporal_upsample: + hidden_states = self.upsampler(hidden_states) + hidden_states = hidden_states[:, :, 1:, :, :] + else: + 
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) + hidden_states = self.upsampler(hidden_states) + hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + + for block in self.post_upsample_res_blocks: + hidden_states = block(hidden_states) + + hidden_states = self.final_conv(hidden_states) + + return hidden_states diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py index cbfb5b5c4a1b..3e47a695d5ef 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py @@ -1082,21 +1082,27 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() + latents = self._unpack_latents( + latents, + latent_num_frames, + latent_height, + latent_width, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + latents = self._denormalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + + audio_latents = self._denormalize_audio_latents( + audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std + ) + audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins) + if output_type == "latent": video = latents audio = audio_latents else: - latents = self._unpack_latents( - latents, - latent_num_frames, - latent_height, - latent_width, - self.transformer_spatial_patch_size, - self.transformer_temporal_patch_size, - ) - latents = self._denormalize_latents( - latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor - ) latents = latents.to(prompt_embeds.dtype) if not self.vae.config.timestep_conditioning: @@ -1121,10 +1127,6 @@ def __call__( video = self.video_processor.postprocess_video(video, output_type=output_type) audio_latents = audio_latents.to(self.audio_vae.dtype) - audio_latents = self._denormalize_audio_latents( - audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std - ) - audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins) generated_mel_spectrograms = self.audio_vae.decode(audio_latents, return_dict=False)[0] audio = self.vocoder(generated_mel_spectrograms) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py index 652955fee129..bff977e6a711 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py @@ -1179,21 +1179,27 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() + latents = self._unpack_latents( + latents, + latent_num_frames, + latent_height, + latent_width, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + latents = self._denormalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + + audio_latents = self._denormalize_audio_latents( + audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std + ) + audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins) + if output_type == "latent": video = latents audio = audio_latents else: - latents = self._unpack_latents( - latents, - latent_num_frames, - latent_height, - latent_width, - self.transformer_spatial_patch_size, - self.transformer_temporal_patch_size, - ) - latents = self._denormalize_latents( - latents, self.vae.latents_mean, self.vae.latents_std, 
self.vae.config.scaling_factor - ) latents = latents.to(prompt_embeds.dtype) if not self.vae.config.timestep_conditioning: @@ -1218,10 +1224,6 @@ def __call__( video = self.video_processor.postprocess_video(video, output_type=output_type) audio_latents = audio_latents.to(self.audio_vae.dtype) - audio_latents = self._denormalize_audio_latents( - audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std - ) - audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins) generated_mel_spectrograms = self.audio_vae.decode(audio_latents, return_dict=False)[0] audio = self.vocoder(generated_mel_spectrograms) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py new file mode 100644 index 000000000000..94cddad42fc3 --- /dev/null +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py @@ -0,0 +1,432 @@ +# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Union + +import torch + +from ...image_processor import PipelineImageInput +from ...models import AutoencoderKLLTX2Video +from ...utils import get_logger, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..ltx.pipeline_output import LTXPipelineOutput +from ..pipeline_utils import DiffusionPipeline +from .latent_upsampler import LTX2LatentUpsamplerModel + + +logger = get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import LTX2ImageToVideoPipeline, LTX2 + >>> from diffusers.utils import load_image + >>> from diffusers.pipelines.ltx2.export_utils import encode_video + + >>> pipe = LTX2ImageToVideoPipeline.from_pretrained("Lightricks/LTX-Video-2", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> image = load_image( + ... "https://huggingface.co/datasets/a-r-r-o-w/tiny-meme-dataset-captioned/resolve/main/images/8.png" + ... ) + >>> prompt = "A young girl stands calmly in the foreground, looking directly at the camera, as a house fire rages in the background." + >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted" + + >>> video, audio = pipe( + ... image=image, + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... width=768, + ... height=512, + ... num_frames=121, + ... frame_rate=25.0, + ... num_inference_steps=40, + ... guidance_scale=3.0, + ... output_type="pil", + ... return_dict=False, + ... ) + + >>> upsample_pipe = LTX2LatentUpsamplePipeline.from_pretrained( + ... "Lightricks/LTX-Video-2", torch_dtype=torch.bfloat16 + ... ) + >>> upsample_pipe.to("cuda") + + >>> video = upsample_pipe( + ... video=video, + ... width=768, + ... height=512, + ... output_type="pil", + ... return_dict=False, + ... 
)[0] + + >>> video = (video * 255).round().astype("uint8") + >>> video = torch.from_numpy(video) + >>> encode_video(video[0], fps=25.0, audio=audio[0].float().cpu(), output_path="output.mp4") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class LTX2LatentUpsamplePipeline(DiffusionPipeline): + model_cpu_offload_seq = "vae->latent_upsampler" + + def __init__( + self, + vae: AutoencoderKLLTX2Video, + latent_upsampler: LTX2LatentUpsamplerModel, + ) -> None: + super().__init__() + + self.register_modules(vae=vae, latent_upsampler=latent_upsampler) + + self.vae_spatial_compression_ratio = ( + self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32 + ) + self.vae_temporal_compression_ratio = ( + self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8 + ) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio) + + def prepare_latents( + self, + video: Optional[torch.Tensor] = None, + batch_size: int = 1, + num_frames: int = 121, + height: int = 512, + width: int = 768, + spatial_patch_size: int = 1, + temporal_patch_size: int = 1, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if latents is not None: + if latents.ndim == 3: + # Convert token seq [B, S, D] to latent video [B, C, F, H, W] + latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + latent_height = height // self.vae_spatial_compression_ratio + latent_width = width // self.vae_spatial_compression_ratio + latents = self._unpack_latents( + latents, latent_num_frames, latent_height, latent_width, spatial_patch_size, temporal_patch_size + ) + return latents.to(device=device, dtype=dtype) + + video = video.to(device=device, dtype=self.vae.dtype) + if isinstance(generator, list): + if len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + init_latents = [ + retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator[i]) for i in range(batch_size) + ] + else: + init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video] + + init_latents = torch.cat(init_latents, dim=0).to(dtype) + # NOTE: latent upsampler operates on the unnormalized latents, so don't normalize here + # init_latents = self._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std) + return init_latents + + def adain_filter_latent(self, latents: torch.Tensor, reference_latents: torch.Tensor, factor: float = 1.0): + """ + Applies Adaptive Instance Normalization (AdaIN) to a latent tensor based on statistics from a reference latent + tensor. 
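+
+        For every sample and channel, the mean and standard deviation of `latents` are matched to
+        those of `reference_latents`; the re-normalized result is then blended with the original
+        latents via `torch.lerp` using `factor`.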
+ + Args: + latent (`torch.Tensor`): + Input latents to normalize + reference_latents (`torch.Tensor`): + The reference latents providing style statistics. + factor (`float`): + Blending factor between original and transformed latent. Range: -10.0 to 10.0, Default: 1.0 + + Returns: + torch.Tensor: The transformed latent tensor + """ + result = latents.clone() + + for i in range(latents.size(0)): + for c in range(latents.size(1)): + r_sd, r_mean = torch.std_mean(reference_latents[i, c], dim=None) # index by original dim order + i_sd, i_mean = torch.std_mean(result[i, c], dim=None) + + result[i, c] = ((result[i, c] - i_mean) / i_sd) * r_sd + r_mean + + result = torch.lerp(latents, result, factor) + return result + + def tone_map_latents(self, latents: torch.Tensor, compression: float) -> torch.Tensor: + """ + Applies a non-linear tone-mapping function to latent values to reduce their dynamic range in a perceptually + smooth way using a sigmoid-based compression. + + This is useful for regularizing high-variance latents or for conditioning outputs during generation, especially + when controlling dynamic behavior with a `compression` factor. + + Args: + latents : torch.Tensor + Input latent tensor with arbitrary shape. Expected to be roughly in [-1, 1] or [0, 1] range. + compression : float + Compression strength in the range [0, 1]. + - 0.0: No tone-mapping (identity transform) + - 1.0: Full compression effect + + Returns: + torch.Tensor + The tone-mapped latent tensor of the same shape as input. + """ + # Remap [0-1] to [0-0.75] and apply sigmoid compression in one shot + scale_factor = compression * 0.75 + abs_latents = torch.abs(latents) + + # Sigmoid compression: sigmoid shifts large values toward 0.2, small values stay ~1.0 + # When scale_factor=0, sigmoid term vanishes, when scale_factor=0.75, full effect + sigmoid_term = torch.sigmoid(4.0 * scale_factor * (abs_latents - 1.0)) + scales = 1.0 - 0.8 * scale_factor * sigmoid_term + + filtered = latents * scales + return filtered + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2_image2video.LTX2ImageToVideoPipeline._normalize_latents + def _normalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + # Normalize latents across the channel dimension [B, C, F, H, W] + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = (latents - latents_mean) * scaling_factor / latents_std + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._denormalize_latents + def _denormalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + # Denormalize latents across the channel dimension [B, C, F, H, W] + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = latents * latents_std / scaling_factor + latents_mean + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._unpack_latents + def _unpack_latents( + latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1 + ) -> torch.Tensor: + # Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the 
effective feature dimensions) + # are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of + # what happens in the `_pack_latents` method. + batch_size = latents.size(0) + latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size) + latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3) + return latents + + def check_inputs(self, video, height, width, latents, tone_map_compression_ratio): + if height % self.vae_spatial_compression_ratio != 0 or width % self.vae_spatial_compression_ratio != 0: + raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") + + if video is not None and latents is not None: + raise ValueError("Only one of `video` or `latents` can be provided.") + if video is None and latents is None: + raise ValueError("One of `video` or `latents` has to be provided.") + + if not (0 <= tone_map_compression_ratio <= 1): + raise ValueError("`tone_map_compression_ratio` must be in the range [0, 1]") + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + video: Optional[List[PipelineImageInput]] = None, + height: int = 512, + width: int = 768, + num_frames: int = 121, + spatial_patch_size: int = 1, + temporal_patch_size: int = 1, + latents: Optional[torch.Tensor] = None, + latents_normalized: bool = False, + decode_timestep: Union[float, List[float]] = 0.0, + decode_noise_scale: Optional[Union[float, List[float]]] = None, + adain_factor: float = 0.0, + tone_map_compression_ratio: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + video (`List[PipelineImageInput]`, *optional*) + The video to be upsampled (such as a LTX 2.0 first stage output). If not supplied, `latents` should be + supplied. + height (`int`, *optional*, defaults to `512`): + The height in pixels of the input video (not the generated video, which will have a larger resolution). + width (`int`, *optional*, defaults to `768`): + The width in pixels of the input video (not the generated video, which will have a larger resolution). + num_frames (`int`, *optional*, defaults to `121`): + The number of frames in the input video. + spatial_patch_size (`int`, *optional*, defaults to `1`): + The spatial patch size of the video latents. Used when `latents` is supplied if unpacking is necessary. + temporal_patch_size (`int`, *optional*, defaults to `1`): + The temporal patch size of the video latents. Used when `latents` is supplied if unpacking is + necessary. + latents (`torch.Tensor`, *optional*): + Pre-generated video latents. This can be supplied in place of the `video` argument. Can either be a + patch sequence of shape `(batch_size, seq_len, hidden_dim)` or a video latent of shape `(batch_size, + latent_channels, latent_frames, latent_height, latent_width)`. + latents_normalized (`bool`, *optional*, defaults to `False`) + If `latents` are supplied, whether the `latents` are normalized using the VAE latent mean and std. If + `True`, the `latents` will be denormalized before being supplied to the latent upsampler. + decode_timestep (`float`, defaults to `0.0`): + The timestep at which generated video is decoded. 
+ decode_noise_scale (`float`, defaults to `None`): + The interpolation factor between random noise and denoised latents at the decode timestep. + adain_factor (`float`, *optional*, defaults to `0.0`): + Adaptive Instance Normalization (AdaIN) blending factor between the upsampled and original latents. + Should be in [-10.0, 10.0]; supplying 0.0 (the default) means that AdaIN is not performed. + tone_map_compression_ratio (`float`, *optional*, defaults to `0.0`): + The compression strength for tone mapping, which will reduce the dynamic range of the latent values. + This is useful for regularizing high-variance latents or for conditioning outputs during generation. + Should be in [0, 1], where 0.0 (the default) means tone mapping is not applied and 1.0 corresponds to + the full compression effect. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ltx.LTXPipelineOutput`] instead of a plain tuple. + + Examples: + + Returns: + [`~pipelines.ltx.LTXPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ltx.LTXPipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is the upsampled video. + """ + + self.check_inputs( + video=video, + height=height, + width=width, + latents=latents, + tone_map_compression_ratio=tone_map_compression_ratio, + ) + + if video is not None: + # Batched video input is not yet tested/supported. TODO: take a look later + batch_size = 1 + else: + batch_size = latents.shape[0] + device = self._execution_device + + if video is not None: + num_frames = len(video) + if num_frames % self.vae_temporal_compression_ratio != 1: + num_frames = ( + num_frames // self.vae_temporal_compression_ratio * self.vae_temporal_compression_ratio + 1 + ) + video = video[:num_frames] + logger.warning( + f"Video length expected to be of the form `k * {self.vae_temporal_compression_ratio} + 1` but is {len(video)}. Truncating to {num_frames} frames." 
+ ) + video = self.video_processor.preprocess_video(video, height=height, width=width) + video = video.to(device=device, dtype=torch.float32) + + latents_supplied = latents is not None + latents = self.prepare_latents( + video=video, + batch_size=batch_size, + num_frames=num_frames, + height=height, + width=width, + spatial_patch_size=spatial_patch_size, + temporal_patch_size=temporal_patch_size, + dtype=torch.float32, + device=device, + generator=generator, + latents=latents, + ) + + if latents_supplied and latents_normalized: + latents = self._denormalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + latents = latents.to(self.latent_upsampler.dtype) + latents_upsampled = self.latent_upsampler(latents) + + if adain_factor > 0.0: + latents = self.adain_filter_latent(latents_upsampled, latents, adain_factor) + else: + latents = latents_upsampled + + if tone_map_compression_ratio > 0.0: + latents = self.tone_map_latents(latents, tone_map_compression_ratio) + + if output_type == "latent": + latents = self._normalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + video = latents + else: + if not self.vae.config.timestep_conditioning: + timestep = None + else: + noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype) + if not isinstance(decode_timestep, list): + decode_timestep = [decode_timestep] * batch_size + if decode_noise_scale is None: + decode_noise_scale = decode_timestep + elif not isinstance(decode_noise_scale, list): + decode_noise_scale = [decode_noise_scale] * batch_size + + timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype) + decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[ + :, None, None, None, None + ] + latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise + + video = self.vae.decode(latents, timestep, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return LTXPipelineOutput(frames=video)
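
A quick shape sanity check of the converted spatial latent upsampler may help review. This is a rough sketch that assumes the LTX 2.0 defaults used above (128 latent channels, 32x spatial / 8x temporal VAE compression, 2x rational spatial scale) and an illustrative 768x512, 121-frame first-stage output; batch size and resolution are placeholders, not part of this PR.

import torch

from diffusers.pipelines.ltx2 import LTX2LatentUpsamplerModel

# Same config as get_ltx2_spatial_latent_upsampler_config("2.0") in the conversion script.
upsampler = LTX2LatentUpsamplerModel(
    in_channels=128,
    mid_channels=1024,
    num_blocks_per_stage=4,
    dims=3,
    spatial_upsample=True,
    temporal_upsample=False,
    rational_spatial_scale=2.0,
)

# A 768x512, 121-frame first-stage video with 32x spatial / 8x temporal VAE compression gives
# latents of shape [B, 128, (121 - 1) // 8 + 1, 512 // 32, 768 // 32] = [B, 128, 16, 16, 24].
latents = torch.randn(1, 128, 16, 16, 24)

with torch.no_grad():
    upsampled = upsampler(latents)

print(upsampled.shape)  # torch.Size([1, 128, 16, 32, 48]): spatial dims doubled, frames unchanged

The frame count stays the same because `temporal_upsample` is `False` for the 2.0 spatial upscaler configuration.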