Skip to content

Conversation

@dg845
Copy link
Collaborator

@dg845 dg845 commented Jan 6, 2026

What does this PR do?

This PR adds pipelines for the LTX 2.0 video generation model (code, weights). LTX 2.0 is an audio-video foundation model that generates videos with synced audio; it supports generation tasks such as text-to-video (T2V), text-image-to-video (TI2V), and more.

An example usage script for I2V is as follows:

import torch
from diffusers.pipelines.ltx2 import LTX2ImageToVideoPipeline
from diffusers.pipelines.ltx2.export_utils import encode_video
from diffusers.utils import load_image

pipe = LTX2ImageToVideoPipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
)
prompt = "An astronaut hatches from a fragile egg on the surface of the Moon, the shell cracking and peeling apart in gentle low-gravity motion. Fine lunar dust lifts and drifts outward with each movement, floating in slow arcs before settling back onto the ground. The astronaut pushes free in a deliberate, weightless motion, small fragments of the egg tumbling and spinning through the air. In the background, the deep darkness of space subtly shifts as stars glide with the camera's movement, emphasizing vast depth and scale. The camera performs a smooth, cinematic slow push-in, with natural parallax between the foreground dust, the astronaut, and the distant starfield. Ultra-realistic detail, physically accurate low-gravity motion, cinematic lighting, and a breath-taking, movie-like shot."
negative_prompt = "shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static."

frame_rate = 24.0
video, audio = pipe(
    image=image,
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=768,
    height=512,
    num_frames=121,
    frame_rate=frame_rate,
    num_inference_steps=40,
    guidance_scale=4.0,
    output_type="np",
    return_dict=False,
)
video = (video * 255).round().astype("uint8")
video = torch.from_numpy(video)

encode_video(
    video[0],
    fps=frame_rate,
    audio=audio[0].float().cpu(),
    audio_sample_rate=pipe.vocoder.config.output_sampling_rate,
    output_path="ltx2_sample.mp4",
)

Note that LTX 2.0 video generation uses a lot of memory; it is necessary to use CPU offloading even for an A100 which has 80 GB VRAM (assuming no other memory optimizations other than bf16 inference are used).

Here is an I2V sample from the above:

ltx2_i2v_sample.mp4

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@yiyixuxu
@sayakpaul
@ofirbb

dg845 and others added 30 commits December 12, 2025 07:52
LTX 2.0 Vocoder Implementation
LTX 2.0 Video VAE Implementation
Copy link
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some small comments.

num_rope_elems = num_pos_dims * 2

# 3. Create a 1D grid of frequencies for RoPE
freqs_dtype = torch.float64 if self.double_precision else torch.float32
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(nit): we could keep the self.freqs_dtype inside the init to skip doing it multiple times.

Comment on lines +1187 to +1190
video_cross_attn_rotary_emb = self.cross_attn_rope(video_coords[:, 0:1, :], device=hidden_states.device)
audio_cross_attn_rotary_emb = self.cross_attn_audio_rope(
audio_coords[:, 0:1, :], device=audio_hidden_states.device
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(nit): would be nice to have a comment about the small indexing going on there.

@yiyixuxu yiyixuxu mentioned this pull request Jan 7, 2026
@sayakpaul sayakpaul requested a review from yiyixuxu January 7, 2026 12:13
sayakpaul and others added 9 commits January 7, 2026 15:46
* Initial implementation of LTX 2.0 latent upsampling pipeline

* Add new LTX 2.0 spatial latent upsampler logic

* Add test script for LTX 2.0 latent upsampling

* Add option to enable VAE tiling in upsampling test script

* Get latent upsampler working with video latents

* Fix typo in BlurDownsample

* Add latent upsample pipeline docstring and example

* Remove deprecated pipeline VAE slicing/tiling methods

* make style and make quality

* When returning latents, return unpacked and denormalized latents for T2V and I2V

* Add model_cpu_offload_seq for latent upsampling pipeline

---------

Co-authored-by: Daniel Gu <[email protected]>
Copy link
Collaborator

@yiyixuxu yiyixuxu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks!

@dg845
Copy link
Collaborator Author

dg845 commented Jan 8, 2026

Merging as the CI failures are unrelated.

@dg845 dg845 merged commit c10bdd9 into main Jan 8, 2026
10 of 12 checks passed
@hannalaguilar
Copy link

hannalaguilar commented Jan 8, 2026

What does this PR do?

This PR adds pipelines for the LTX 2.0 video generation model (code, weights). LTX 2.0 is an audio-video foundation model that generates videos with synced audio; it supports generation tasks such as text-to-video (T2V), text-image-to-video (TI2V), and more.

An example usage script for I2V is as follows:

import torch
from diffusers.pipelines.ltx2 import LTX2ImageToVideoPipeline
from diffusers.pipelines.ltx2.export_utils import encode_video
from diffusers.utils import load_image

pipe = LTX2ImageToVideoPipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

prompt = "An astronaut hatches from a fragile egg on the surface of the Moon, the shell cracking and peeling apart in gentle low-gravity motion. Fine lunar dust lifts and drifts outward with each movement, floating in slow arcs before settling back onto the ground. The astronaut pushes free in a deliberate, weightless motion, small fragments of the egg tumbling and spinning through the air. In the background, the deep darkness of space subtly shifts as stars glide with the camera's movement, emphasizing vast depth and scale. The camera performs a smooth, cinematic slow push-in, with natural parallax between the foreground dust, the astronaut, and the distant starfield. Ultra-realistic detail, physically accurate low-gravity motion, cinematic lighting, and a breath-taking, movie-like shot."
negative_prompt = "shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static."

frame_rate = 24.0
video, audio = pipe(
	prompt=prompt,
	negative_prompt=negative_prompt,
	width=768,
	height=512,
	num_frames=121,
	frame_rate=frame_rate,
	num_inference_steps=40,
	guidance_scale=4.0,
	output_type="np",
	return_dict=False,
)
video = (video * 255).round().astype("uint8")
video = torch.from_numpy(video)

encode_video(
	video[0],
	fps=frame_rate,
	audio=audio[0].float().cpu(),
	audio_sample_rate=pipe.vocoder.config.output_sampling_rate,
	output_path="ltx2_sample.mp4",
)

Note that LTX 2.0 video generation uses a lot of memory; it is necessary to use CPU offloading even for an A100 which has 80 GB VRAM (assuming no other memory optimizations other than bf16 inference are used).

Here is an I2V sample from the above:

ltx2_i2v_sample.mp4

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

@yiyixuxu @sayakpaul @ofirbb

In this example, an `image is not being passed to the pipeline. Should be:

image = load_image("image.png")
video, audio = pipe(
	prompt=prompt,
	negative_prompt=negative_prompt,
        image=image,
	width=768,
	height=512,
	num_frames=121,
	frame_rate=frame_rate,
	num_inference_steps=40,
	guidance_scale=4.0,
	output_type="np",
	return_dict=False,
)

@dg845 dg845 deleted the ltx-2-transformer branch January 8, 2026 22:01
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

8 participants