48 changes: 47 additions & 1 deletion scripts/convert_ltx2_to_diffusers.py
@@ -16,7 +16,7 @@
LTX2Pipeline,
LTX2VideoTransformer3DModel,
)
-from diffusers.pipelines.ltx2 import LTX2TextConnectors, LTX2Vocoder
+from diffusers.pipelines.ltx2 import LTX2LatentUpsamplerModel, LTX2TextConnectors, LTX2Vocoder
from diffusers.utils.import_utils import is_accelerate_available


@@ -577,6 +577,33 @@ def convert_ltx2_vocoder(original_state_dict: Dict[str, Any], version: str) -> D
return vocoder


def get_ltx2_spatial_latent_upsampler_config(version: str):
if version == "2.0":
config = {
"in_channels": 128,
"mid_channels": 1024,
"num_blocks_per_stage": 4,
"dims": 3,
"spatial_upsample": True,
"temporal_upsample": False,
"rational_spatial_scale": 2.0,
}
else:
raise ValueError(f"Unsupported version: {version}")
return config


def convert_ltx2_spatial_latent_upsampler(
original_state_dict: Dict[str, Any], config: Dict[str, Any], dtype: torch.dtype
):
with init_empty_weights():
latent_upsampler = LTX2LatentUpsamplerModel(**config)

latent_upsampler.load_state_dict(original_state_dict, strict=True, assign=True)
latent_upsampler.to(dtype)
return latent_upsampler


def load_original_checkpoint(args, filename: Optional[str]) -> Dict[str, Any]:
if args.original_state_dict_repo_id is not None:
ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=filename)
@@ -682,13 +709,20 @@ def get_args():
type=str,
help="HF Hub id for the LTX 2.0 text tokenizer",
)
parser.add_argument(
"--latent_upsampler_filename",
default="rc1/ltx-2-spatial-upscaler-x2-1.0-rc1.safetensors",
type=str,
help="Latent upsampler filename",
)

parser.add_argument("--vae", action="store_true", help="Whether to convert the video VAE model")
parser.add_argument("--audio_vae", action="store_true", help="Whether to convert the audio VAE model")
parser.add_argument("--dit", action="store_true", help="Whether to convert the DiT model")
parser.add_argument("--connectors", action="store_true", help="Whether to convert the connector model")
parser.add_argument("--vocoder", action="store_true", help="Whether to convert the vocoder model")
parser.add_argument("--text_encoder", action="store_true", help="Whether to conver the text encoder")
parser.add_argument("--latent_upsampler", action="store_true", help="Whether to convert the latent upsampler")
parser.add_argument(
"--full_pipeline",
action="store_true",
@@ -788,6 +822,18 @@ def main(args):
if not args.full_pipeline:
tokenizer.save_pretrained(os.path.join(args.output_path, "tokenizer"))

if args.latent_upsampler:
original_latent_upsampler_ckpt = load_hub_or_local_checkpoint(
repo_id=args.original_state_dict_repo_id, filename=args.latent_upsampler_filename
)
latent_upsampler_config = get_ltx2_spatial_latent_upsampler_config(args.version)
latent_upsampler = convert_ltx2_spatial_latent_upsampler(
original_latent_upsampler_ckpt,
latent_upsampler_config,
dtype=vae_dtype,
)
latent_upsampler.save_pretrained(os.path.join(args.output_path, "latent_upsampler"))

if args.full_pipeline:
scheduler = FlowMatchEulerDiscreteScheduler(
use_dynamic_shifting=True,
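The new conversion path can also be exercised outside `main()`. A minimal sketch, assuming the script's directory is on `PYTHONPATH` so its helpers are importable, and with a placeholder repo id and output path (the filename matches the new argparse default):

```python
import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

from convert_ltx2_to_diffusers import (
    convert_ltx2_spatial_latent_upsampler,
    get_ltx2_spatial_latent_upsampler_config,
)

# "your-org/ltx2-original-weights" is a placeholder; pass your own repo id.
ckpt_path = hf_hub_download(
    repo_id="your-org/ltx2-original-weights",
    filename="rc1/ltx-2-spatial-upscaler-x2-1.0-rc1.safetensors",
)
state_dict = load_file(ckpt_path)

config = get_ltx2_spatial_latent_upsampler_config(version="2.0")
latent_upsampler = convert_ltx2_spatial_latent_upsampler(state_dict, config, dtype=torch.bfloat16)
latent_upsampler.save_pretrained("output/latent_upsampler")
```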
174 changes: 174 additions & 0 deletions scripts/ltx2_test_latent_upsampler.py
@@ -0,0 +1,174 @@
import argparse
import gc
import os

import torch

from diffusers import AutoencoderKLLTX2Video
from diffusers.pipelines.ltx2 import LTX2ImageToVideoPipeline, LTX2LatentUpsamplePipeline, LTX2LatentUpsamplerModel
from diffusers.pipelines.ltx2.export_utils import encode_video
from diffusers.utils import load_image


def parse_args():
parser = argparse.ArgumentParser()

parser.add_argument("--model_id", type=str, default="diffusers-internal-dev/new-ltx-model")
parser.add_argument("--revision", type=str, default="main")

parser.add_argument("--image_path", required=True, type=str)
parser.add_argument(
"--prompt",
type=str,
default=(
"An astronaut hatches from a fragile egg on the surface of the Moon, the shell cracking and peeling apart "
"in gentle low-gravity motion. Fine lunar dust lifts and drifts outward with each movement, floating in "
"slow arcs before settling back onto the ground. The astronaut pushes free in a deliberate, weightless "
"motion, small fragments of the egg tumbling and spinning through the air. In the background, the deep "
"darkness of space subtly shifts as stars glide with the camera's movement, emphasizing vast depth and "
"scale. The camera performs a smooth, cinematic slow push-in, with natural parallax between the foreground "
"dust, the astronaut, and the distant starfield. Ultra-realistic detail, physically accurate low-gravity "
"motion, cinematic lighting, and a breath-taking, movie-like shot."
),
)
parser.add_argument(
"--negative_prompt",
type=str,
default=(
"shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion "
"artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static."
),
)

parser.add_argument("--num_inference_steps", type=int, default=40)
parser.add_argument("--height", type=int, default=512)
parser.add_argument("--width", type=int, default=768)
parser.add_argument("--num_frames", type=int, default=121)
parser.add_argument("--frame_rate", type=float, default=25.0)
parser.add_argument("--guidance_scale", type=float, default=3.0)
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--apply_scheduler_fix", action="store_true")

parser.add_argument("--device", type=str, default="cuda:0")
parser.add_argument("--dtype", type=str, default="bf16")
parser.add_argument("--cpu_offload", action="store_true")
parser.add_argument("--vae_tiling", action="store_true")
parser.add_argument("--use_video_latents", action="store_true")

parser.add_argument(
"--output_dir",
type=str,
default="samples",
help="Output directory for generated video",
)
parser.add_argument(
"--output_filename",
type=str,
default="ltx2_i2v_video_upsampled.mp4",
help="Filename of the exported generated video",
)

args = parser.parse_args()
args.dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float32
return args


def main(args):
pipeline = LTX2ImageToVideoPipeline.from_pretrained(
args.model_id,
revision=args.revision,
torch_dtype=args.dtype,
)
if args.cpu_offload:
pipeline.enable_model_cpu_offload()
else:
pipeline.to(device=args.device)

image = load_image(args.image_path)

first_stage_output_type = "pil"
if args.use_video_latents:
first_stage_output_type = "latent"

video, audio = pipeline(
image=image,
prompt=args.prompt,
negative_prompt=args.negative_prompt,
height=args.height,
width=args.width,
num_frames=args.num_frames,
frame_rate=args.frame_rate,
num_inference_steps=args.num_inference_steps,
guidance_scale=args.guidance_scale,
generator=torch.Generator(device=args.device).manual_seed(args.seed),
output_type=first_stage_output_type,
return_dict=False,
)

if args.use_video_latents:
# Manually convert the audio latents to a waveform
audio = audio.to(pipeline.audio_vae.dtype)
audio = pipeline.audio_vae.decode(audio, return_dict=False)[0]
audio = pipeline.vocoder(audio)

# Get some pipeline configs for upsampling
spatial_patch_size = pipeline.transformer_spatial_patch_size
temporal_patch_size = pipeline.transformer_temporal_patch_size

# upsample_pipeline = LTX2LatentUpsamplePipeline.from_pretrained(
# args.model_id, revision=args.revision, torch_dtype=args.dtype,
# )
output_sampling_rate = pipeline.vocoder.config.output_sampling_rate
del pipeline  # Free the first-stage pipeline before loading the upsampler to avoid OOM
torch.cuda.empty_cache()
gc.collect()

vae = AutoencoderKLLTX2Video.from_pretrained(
args.model_id,
subfolder="vae",
revision=args.revision,
torch_dtype=args.dtype,
)
latent_upsampler = LTX2LatentUpsamplerModel.from_pretrained(
args.model_id,
subfolder="latent_upsampler",
revision=args.revision,
torch_dtype=args.dtype,
)
upsample_pipeline = LTX2LatentUpsamplePipeline(vae=vae, latent_upsampler=latent_upsampler)
upsample_pipeline.to(device=args.device)
if args.vae_tiling:
upsample_pipeline.vae.enable_tiling()

upsample_kwargs = {
"height": args.height,
"width": args.width,
"output_type": "np",
"return_dict": False,
}
if args.use_video_latents:
upsample_kwargs["latents"] = video
upsample_kwargs["num_frames"] = args.num_frames
upsample_kwargs["spatial_patch_size"] = spatial_patch_size
upsample_kwargs["temporal_patch_size"] = temporal_patch_size
else:
upsample_kwargs["video"] = video

video = upsample_pipeline(**upsample_kwargs)[0]

# Convert video to uint8 (but keep as NumPy array)
video = (video * 255).round().astype("uint8")
video = torch.from_numpy(video)

encode_video(
video[0],
fps=args.frame_rate,
audio=audio[0].float().cpu(),
audio_sample_rate=output_sampling_rate, # should be 24000
output_path=os.path.join(args.output_dir, args.output_filename),
)


if __name__ == "__main__":
args = parse_args()
main(args)
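Stripped of the argparse plumbing, the second stage the script exercises reduces to a few lines. A sketch assuming a converted checkpoint at a placeholder model id, with `video` holding PIL frames from a prior `LTX2ImageToVideoPipeline` run using `output_type="pil"`:

```python
import torch

from diffusers import AutoencoderKLLTX2Video
from diffusers.pipelines.ltx2 import LTX2LatentUpsamplePipeline, LTX2LatentUpsamplerModel

model_id = "your-org/ltx2-diffusers-checkpoint"  # placeholder
vae = AutoencoderKLLTX2Video.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.bfloat16)
latent_upsampler = LTX2LatentUpsamplerModel.from_pretrained(
    model_id, subfolder="latent_upsampler", torch_dtype=torch.bfloat16
)
pipe = LTX2LatentUpsamplePipeline(vae=vae, latent_upsampler=latent_upsampler)
pipe.to("cuda")

# Pixel-space input path; pass latents/num_frames/patch sizes instead for the latent path.
frames = pipe(video=video, height=512, width=768, output_type="np", return_dict=False)[0]
```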
2 changes: 2 additions & 0 deletions src/diffusers/__init__.py
@@ -542,6 +542,7 @@
"LongCatImageEditPipeline",
"LongCatImagePipeline",
"LTX2ImageToVideoPipeline",
"LTX2LatentUpsamplePipeline",
"LTX2Pipeline",
"LTXConditionPipeline",
"LTXI2VLongMultiPromptPipeline",
@@ -1263,6 +1264,7 @@
LongCatImageEditPipeline,
LongCatImagePipeline,
LTX2ImageToVideoPipeline,
LTX2LatentUpsamplePipeline,
LTX2Pipeline,
LTXConditionPipeline,
LTXI2VLongMultiPromptPipeline,
4 changes: 2 additions & 2 deletions src/diffusers/pipelines/__init__.py
@@ -290,7 +290,7 @@
"LTXLatentUpsamplePipeline",
"LTXI2VLongMultiPromptPipeline",
]
_import_structure["ltx2"] = ["LTX2Pipeline", "LTX2ImageToVideoPipeline"]
_import_structure["ltx2"] = ["LTX2Pipeline", "LTX2ImageToVideoPipeline", "LTX2LatentUpsamplePipeline"]
_import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"]
_import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"]
_import_structure["lucy"] = ["LucyEditPipeline"]
@@ -738,7 +738,7 @@
LTXLatentUpsamplePipeline,
LTXPipeline,
)
-from .ltx2 import LTX2ImageToVideoPipeline, LTX2Pipeline
+from .ltx2 import LTX2ImageToVideoPipeline, LTX2LatentUpsamplePipeline, LTX2Pipeline
from .lucy import LucyEditPipeline
from .lumina import LuminaPipeline, LuminaText2ImgPipeline
from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline
4 changes: 4 additions & 0 deletions src/diffusers/pipelines/ltx2/__init__.py
@@ -23,8 +23,10 @@
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["connectors"] = ["LTX2TextConnectors"]
_import_structure["latent_upsampler"] = ["LTX2LatentUpsamplerModel"]
_import_structure["pipeline_ltx2"] = ["LTX2Pipeline"]
_import_structure["pipeline_ltx2_image2video"] = ["LTX2ImageToVideoPipeline"]
_import_structure["pipeline_ltx2_latent_upsample"] = ["LTX2LatentUpsamplePipeline"]
_import_structure["vocoder"] = ["LTX2Vocoder"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -36,8 +38,10 @@
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .connectors import LTX2TextConnectors
from .latent_upsampler import LTX2LatentUpsamplerModel
from .pipeline_ltx2 import LTX2Pipeline
from .pipeline_ltx2_image2video import LTX2ImageToVideoPipeline
from .pipeline_ltx2_latent_upsample import LTX2LatentUpsamplePipeline
from .vocoder import LTX2Vocoder

else:
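With both the lazy `_import_structure` entries and the eager `TYPE_CHECKING` imports registered above, the new classes resolve from either level of the package; a quick check:

```python
# Top-level export added in src/diffusers/__init__.py
from diffusers import LTX2LatentUpsamplePipeline

# Subpackage exports added in src/diffusers/pipelines/ltx2/__init__.py
from diffusers.pipelines.ltx2 import LTX2LatentUpsamplePipeline, LTX2LatentUpsamplerModel
```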