diff --git a/docs/source/en/api/loaders/lora.md b/docs/source/en/api/loaders/lora.md
index 7911bc2b2332..bbae6a9020af 100644
--- a/docs/source/en/api/loaders/lora.md
+++ b/docs/source/en/api/loaders/lora.md
@@ -33,6 +33,7 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
 - [`QwenImageLoraLoaderMixin`] provides similar functions for [Qwen Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/qwen).
 - [`ZImageLoraLoaderMixin`] provides similar functions for [Z-Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/zimage).
 - [`Flux2LoraLoaderMixin`] provides similar functions for [Flux2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux2).
+- [`LTX2LoraLoaderMixin`] provides similar functions for [LTX-2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/ltx2).
 - [`LoraBaseMixin`] provides a base class with several utility methods to fuse, unfuse, unload, LoRAs and more.
 
 > [!TIP]
@@ -62,6 +63,10 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
 
 [[autodoc]] loaders.lora_pipeline.Flux2LoraLoaderMixin
 
+## LTX2LoraLoaderMixin
+
+[[autodoc]] loaders.lora_pipeline.LTX2LoraLoaderMixin
+
 ## CogVideoXLoraLoaderMixin
 
 [[autodoc]] loaders.lora_pipeline.CogVideoXLoraLoaderMixin
diff --git a/docs/source/en/api/pipelines/ltx2.md b/docs/source/en/api/pipelines/ltx2.md
index 231e3112a907..4c6860daf024 100644
--- a/docs/source/en/api/pipelines/ltx2.md
+++ b/docs/source/en/api/pipelines/ltx2.md
@@ -14,6 +14,10 @@
 
 # LTX-2
 
+<div class="flex flex-wrap space-x-1">
+  <img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
+</div>
+ LTX-2 is a DiT-based audio-video foundation model designed to generate synchronized video and audio within a single model. It brings together the core building blocks of modern video generation, with open weights and a focus on practical, local execution. You can find all the original LTX-Video checkpoints under the [Lightricks](https://huggingface.co/Lightricks) organization. diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index ace4e8543a1c..bdd4dbbcd4b5 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -67,6 +67,7 @@ def text_encoder_attn_modules(text_encoder): "SD3LoraLoaderMixin", "AuraFlowLoraLoaderMixin", "StableDiffusionXLLoraLoaderMixin", + "LTX2LoraLoaderMixin", "LTXVideoLoraLoaderMixin", "LoraLoaderMixin", "FluxLoraLoaderMixin", @@ -121,6 +122,7 @@ def text_encoder_attn_modules(text_encoder): HunyuanVideoLoraLoaderMixin, KandinskyLoraLoaderMixin, LoraLoaderMixin, + LTX2LoraLoaderMixin, LTXVideoLoraLoaderMixin, Lumina2LoraLoaderMixin, Mochi1LoraLoaderMixin, diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index 2e87f757c352..8f7309d4ed1e 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -2140,6 +2140,54 @@ def _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict, non_diffusers_pref return converted_state_dict +def _convert_non_diffusers_ltx2_lora_to_diffusers(state_dict, non_diffusers_prefix="diffusion_model"): + # Remove the prefix + state_dict = {k: v for k, v in state_dict.items() if k.startswith(f"{non_diffusers_prefix}.")} + converted_state_dict = {k.removeprefix(f"{non_diffusers_prefix}."): v for k, v in state_dict.items()} + + if non_diffusers_prefix == "diffusion_model": + rename_dict = { + "patchify_proj": "proj_in", + "audio_patchify_proj": "audio_proj_in", + "av_ca_video_scale_shift_adaln_single": "av_cross_attn_video_scale_shift", + "av_ca_a2v_gate_adaln_single": "av_cross_attn_video_a2v_gate", + "av_ca_audio_scale_shift_adaln_single": "av_cross_attn_audio_scale_shift", + "av_ca_v2a_gate_adaln_single": "av_cross_attn_audio_v2a_gate", + "scale_shift_table_a2v_ca_video": "video_a2v_cross_attn_scale_shift_table", + "scale_shift_table_a2v_ca_audio": "audio_a2v_cross_attn_scale_shift_table", + "q_norm": "norm_q", + "k_norm": "norm_k", + } + else: + rename_dict = {"aggregate_embed": "text_proj_in"} + + # Apply renaming + renamed_state_dict = {} + for key, value in converted_state_dict.items(): + new_key = key[:] + for old_pattern, new_pattern in rename_dict.items(): + new_key = new_key.replace(old_pattern, new_pattern) + renamed_state_dict[new_key] = value + + # Handle adaln_single -> time_embed and audio_adaln_single -> audio_time_embed + final_state_dict = {} + for key, value in renamed_state_dict.items(): + if key.startswith("adaln_single."): + new_key = key.replace("adaln_single.", "time_embed.") + final_state_dict[new_key] = value + elif key.startswith("audio_adaln_single."): + new_key = key.replace("audio_adaln_single.", "audio_time_embed.") + final_state_dict[new_key] = value + else: + final_state_dict[key] = value + + # Add transformer prefix + prefix = "transformer" if non_diffusers_prefix == "diffusion_model" else "connectors" + final_state_dict = {f"{prefix}.{k}": v for k, v in final_state_dict.items()} + + return final_state_dict + + def _convert_non_diffusers_qwen_lora_to_diffusers(state_dict): has_diffusion_model = any(k.startswith("diffusion_model.") for k 
in state_dict) if has_diffusion_model: diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 03a2fe9f3f8e..5fc650a80d87 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -48,6 +48,7 @@ _convert_non_diffusers_flux2_lora_to_diffusers, _convert_non_diffusers_hidream_lora_to_diffusers, _convert_non_diffusers_lora_to_diffusers, + _convert_non_diffusers_ltx2_lora_to_diffusers, _convert_non_diffusers_ltxv_lora_to_diffusers, _convert_non_diffusers_lumina2_lora_to_diffusers, _convert_non_diffusers_qwen_lora_to_diffusers, @@ -74,6 +75,7 @@ TEXT_ENCODER_NAME = "text_encoder" UNET_NAME = "unet" TRANSFORMER_NAME = "transformer" +LTX2_CONNECTOR_NAME = "connectors" _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX = {"x_embedder": "in_channels"} @@ -3011,6 +3013,233 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): super().unfuse_lora(components=components, **kwargs) +class LTX2LoraLoaderMixin(LoraBaseMixin): + r""" + Load LoRA layers into [`LTX2VideoTransformer3DModel`]. Specific to [`LTX2Pipeline`]. + """ + + _lora_loadable_modules = ["transformer", "connectors"] + transformer_name = TRANSFORMER_NAME + connectors_name = LTX2_CONNECTOR_NAME + + @classmethod + @validate_hf_hub_args + def lora_state_dict( + cls, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + **kwargs, + ): + r""" + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details. + """ + # Load the main state dict first which has the LoRA layers for either of + # transformer and text encoder or both. + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + return_lora_metadata = kwargs.pop("return_lora_metadata", False) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} + + state_dict, metadata = _fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + + is_dora_scale_present = any("dora_scale" in k for k in state_dict) + if is_dora_scale_present: + warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." 
+ logger.warning(warn_msg) + state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} + + final_state_dict = state_dict + is_non_diffusers_format = any(k.startswith("diffusion_model.") for k in state_dict) + has_connector = any(k.startswith("text_embedding_projection.") for k in state_dict) + if is_non_diffusers_format: + final_state_dict = _convert_non_diffusers_ltx2_lora_to_diffusers(state_dict) + if has_connector: + connectors_state_dict = _convert_non_diffusers_ltx2_lora_to_diffusers( + state_dict, "text_embedding_projection" + ) + final_state_dict.update(connectors_state_dict) + out = (final_state_dict, metadata) if return_lora_metadata else final_state_dict + return out + + def load_lora_weights( + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + adapter_name: Optional[str] = None, + hotswap: bool = False, + **kwargs, + ): + """ + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # if a dict is passed, copy it instead of modifying it inplace + if isinstance(pretrained_model_name_or_path_or_dict, dict): + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() + + # First, ensure that the checkpoint is a compatible one and can be successfully loaded. + kwargs["return_lora_metadata"] = True + state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + + is_correct_format = all("lora" in key for key in state_dict.keys()) + if not is_correct_format: + raise ValueError("Invalid LoRA checkpoint.") + + transformer_peft_state_dict = { + k: v for k, v in state_dict.items() if k.startswith(f"{self.transformer_name}.") + } + connectors_peft_state_dict = {k: v for k, v in state_dict.items() if k.startswith(f"{self.connectors_name}.")} + self.load_lora_into_transformer( + transformer_peft_state_dict, + transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, + adapter_name=adapter_name, + metadata=metadata, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + if connectors_peft_state_dict: + self.load_lora_into_transformer( + connectors_peft_state_dict, + transformer=getattr(self, self.connectors_name) + if not hasattr(self, "connectors") + else self.connectors, + adapter_name=adapter_name, + metadata=metadata, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + prefix=self.connectors_name, + ) + + @classmethod + def load_lora_into_transformer( + cls, + state_dict, + transformer, + adapter_name=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, + metadata=None, + prefix: str = "transformer", + ): + """ + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details. + """ + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # Load the layers corresponding to transformer. 
+ logger.info(f"Loading {prefix}.") + transformer.load_lora_adapter( + state_dict, + network_alphas=None, + adapter_name=adapter_name, + metadata=metadata, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + prefix=prefix, + ) + + @classmethod + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights + def save_lora_weights( + cls, + save_directory: Union[str, os.PathLike], + transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + transformer_lora_adapter_metadata: Optional[dict] = None, + ): + r""" + See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information. + """ + lora_layers = {} + lora_metadata = {} + + if transformer_lora_layers: + lora_layers[cls.transformer_name] = transformer_lora_layers + lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata + + if not lora_layers: + raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.") + + cls._save_lora_weights( + save_directory=save_directory, + lora_layers=lora_layers, + lora_metadata=lora_metadata, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + ) + + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora + def fuse_lora( + self, + components: List[str] = ["transformer"], + lora_scale: float = 1.0, + safe_fusing: bool = False, + adapter_names: Optional[List[str]] = None, + **kwargs, + ): + r""" + See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details. + """ + super().fuse_lora( + components=components, + lora_scale=lora_scale, + safe_fusing=safe_fusing, + adapter_names=adapter_names, + **kwargs, + ) + + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora + def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): + r""" + See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details. + """ + super().unfuse_lora(components=components, **kwargs) + + class SanaLoraLoaderMixin(LoraBaseMixin): r""" Load LoRA layers into [`SanaTransformer2DModel`]. Specific to [`SanaPipeline`]. 
diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 140018bdd34a..16f1a5d1ec7e 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -67,6 +67,8 @@ "QwenImageTransformer2DModel": lambda model_cls, weights: weights, "Flux2Transformer2DModel": lambda model_cls, weights: weights, "ZImageTransformer2DModel": lambda model_cls, weights: weights, + "LTX2VideoTransformer3DModel": lambda model_cls, weights: weights, + "LTX2TextConnectors": lambda model_cls, weights: weights, } diff --git a/src/diffusers/pipelines/ltx2/connectors.py b/src/diffusers/pipelines/ltx2/connectors.py index 2608c2783f7e..22ca42d37902 100644 --- a/src/diffusers/pipelines/ltx2/connectors.py +++ b/src/diffusers/pipelines/ltx2/connectors.py @@ -5,6 +5,7 @@ import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import PeftAdapterMixin from ...models.attention import FeedForward from ...models.modeling_utils import ModelMixin from ...models.transformers.transformer_ltx2 import LTX2Attention, LTX2AudioVideoAttnProcessor @@ -252,7 +253,7 @@ def forward( return hidden_states, attention_mask -class LTX2TextConnectors(ModelMixin, ConfigMixin): +class LTX2TextConnectors(ModelMixin, PeftAdapterMixin, ConfigMixin): """ Text connector stack used by LTX 2.0 to process the packed text encoder hidden states for both the video and audio streams. diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py index 99d6b71ec3d7..9cf847926347 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py @@ -21,7 +21,7 @@ from transformers import Gemma3ForConditionalGeneration, GemmaTokenizer, GemmaTokenizerFast from ...callbacks import MultiPipelineCallbacks, PipelineCallback -from ...loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin +from ...loaders import FromSingleFileMixin, LTX2LoraLoaderMixin from ...models.autoencoders import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video from ...models.transformers import LTX2VideoTransformer3DModel from ...schedulers import FlowMatchEulerDiscreteScheduler @@ -184,7 +184,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): return noise_cfg -class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin): +class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin): r""" Pipeline for text-to-video generation. diff --git a/tests/lora/test_lora_layers_ltx2.py b/tests/lora/test_lora_layers_ltx2.py new file mode 100644 index 000000000000..886ae70b7d46 --- /dev/null +++ b/tests/lora/test_lora_layers_ltx2.py @@ -0,0 +1,293 @@ +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import sys +import unittest + +import torch +from transformers import AutoTokenizer, Gemma3ForConditionalGeneration + +from diffusers import ( + AutoencoderKLLTX2Audio, + AutoencoderKLLTX2Video, + FlowMatchEulerDiscreteScheduler, + LTX2Pipeline, + LTX2VideoTransformer3DModel, +) +from diffusers.pipelines.ltx2 import LTX2TextConnectors +from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder +from diffusers.utils.import_utils import is_peft_available + +from ..testing_utils import floats_tensor, require_peft_backend + + +if is_peft_available(): + from peft import LoraConfig + + +sys.path.append(".") + +from .utils import PeftLoraLoaderMixinTests # noqa: E402 + + +@require_peft_backend +class LTX2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): + pipeline_class = LTX2Pipeline + scheduler_cls = FlowMatchEulerDiscreteScheduler + scheduler_kwargs = {} + + transformer_kwargs = { + "in_channels": 4, + "out_channels": 4, + "patch_size": 1, + "patch_size_t": 1, + "num_attention_heads": 2, + "attention_head_dim": 8, + "cross_attention_dim": 16, + "audio_in_channels": 4, + "audio_out_channels": 4, + "audio_num_attention_heads": 2, + "audio_attention_head_dim": 4, + "audio_cross_attention_dim": 8, + "num_layers": 1, + "qk_norm": "rms_norm_across_heads", + "caption_channels": 32, + "rope_double_precision": False, + "rope_type": "split", + } + transformer_cls = LTX2VideoTransformer3DModel + + vae_kwargs = { + "in_channels": 3, + "out_channels": 3, + "latent_channels": 4, + "block_out_channels": (8,), + "decoder_block_out_channels": (8,), + "layers_per_block": (1,), + "decoder_layers_per_block": (1, 1), + "spatio_temporal_scaling": (True,), + "decoder_spatio_temporal_scaling": (True,), + "decoder_inject_noise": (False, False), + "downsample_type": ("spatial",), + "upsample_residual": (False,), + "upsample_factor": (1,), + "timestep_conditioning": False, + "patch_size": 1, + "patch_size_t": 1, + "encoder_causal": True, + "decoder_causal": False, + } + vae_cls = AutoencoderKLLTX2Video + + audio_vae_kwargs = { + "base_channels": 4, + "output_channels": 2, + "ch_mult": (1,), + "num_res_blocks": 1, + "attn_resolutions": None, + "in_channels": 2, + "resolution": 32, + "latent_channels": 2, + "norm_type": "pixel", + "causality_axis": "height", + "dropout": 0.0, + "mid_block_add_attention": False, + "sample_rate": 16000, + "mel_hop_length": 160, + "is_causal": True, + "mel_bins": 8, + } + audio_vae_cls = AutoencoderKLLTX2Audio + + vocoder_kwargs = { + "in_channels": 16, # output_channels * mel_bins = 2 * 8 + "hidden_channels": 32, + "out_channels": 2, + "upsample_kernel_sizes": [4, 4], + "upsample_factors": [2, 2], + "resnet_kernel_sizes": [3], + "resnet_dilations": [[1, 3, 5]], + "leaky_relu_negative_slope": 0.1, + "output_sampling_rate": 16000, + } + vocoder_cls = LTX2Vocoder + + connectors_kwargs = { + "caption_channels": 32, # Will be set dynamically from text_encoder + "text_proj_in_factor": 2, # Will be set dynamically from text_encoder + "video_connector_num_attention_heads": 4, + "video_connector_attention_head_dim": 8, + "video_connector_num_layers": 1, + "video_connector_num_learnable_registers": None, + "audio_connector_num_attention_heads": 4, + "audio_connector_attention_head_dim": 8, + "audio_connector_num_layers": 1, + "audio_connector_num_learnable_registers": None, + "connector_rope_base_seq_len": 32, + "rope_theta": 10000.0, + "rope_double_precision": False, + "causal_temporal_positioning": False, + "rope_type": "split", + } + connectors_cls = LTX2TextConnectors + + tokenizer_cls, 
tokenizer_id = AutoTokenizer, "hf-internal-testing/tiny-gemma3" + text_encoder_cls, text_encoder_id = ( + Gemma3ForConditionalGeneration, + "hf-internal-testing/tiny-gemma3", + ) + + denoiser_target_modules = ["to_q", "to_k", "to_out.0"] + + @property + def output_shape(self): + return (1, 5, 32, 32, 3) + + def get_dummy_inputs(self, with_generator=True): + batch_size = 1 + sequence_length = 16 + num_channels = 4 + num_frames = 5 + num_latent_frames = 2 + latent_height = 8 + latent_width = 8 + + generator = torch.manual_seed(0) + noise = floats_tensor((batch_size, num_latent_frames, num_channels, latent_height, latent_width)) + input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator) + + pipeline_inputs = { + "prompt": "a robot dancing", + "num_frames": num_frames, + "num_inference_steps": 2, + "guidance_scale": 1.0, + "height": 32, + "width": 32, + "frame_rate": 25.0, + "max_sequence_length": sequence_length, + "output_type": "np", + } + if with_generator: + pipeline_inputs.update({"generator": generator}) + + return noise, input_ids, pipeline_inputs + + def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None): + # Override to instantiate LTX2-specific components (connectors, audio_vae, vocoder) + torch.manual_seed(0) + text_encoder = self.text_encoder_cls.from_pretrained(self.text_encoder_id) + tokenizer = self.tokenizer_cls.from_pretrained(self.tokenizer_id) + + # Update caption_channels and text_proj_in_factor based on text_encoder config + transformer_kwargs = self.transformer_kwargs.copy() + transformer_kwargs["caption_channels"] = text_encoder.config.text_config.hidden_size + + connectors_kwargs = self.connectors_kwargs.copy() + connectors_kwargs["caption_channels"] = text_encoder.config.text_config.hidden_size + connectors_kwargs["text_proj_in_factor"] = text_encoder.config.text_config.num_hidden_layers + 1 + + torch.manual_seed(0) + transformer = self.transformer_cls(**transformer_kwargs) + + torch.manual_seed(0) + vae = self.vae_cls(**self.vae_kwargs) + vae.use_framewise_encoding = False + vae.use_framewise_decoding = False + + torch.manual_seed(0) + audio_vae = self.audio_vae_cls(**self.audio_vae_kwargs) + + torch.manual_seed(0) + vocoder = self.vocoder_cls(**self.vocoder_kwargs) + + torch.manual_seed(0) + connectors = self.connectors_cls(**connectors_kwargs) + + if scheduler_cls is None: + scheduler_cls = self.scheduler_cls + scheduler = scheduler_cls(**self.scheduler_kwargs) + + rank = 4 + lora_alpha = rank if lora_alpha is None else lora_alpha + + text_lora_config = LoraConfig( + r=rank, + lora_alpha=lora_alpha, + target_modules=self.text_encoder_target_modules, + init_lora_weights=False, + use_dora=use_dora, + ) + + denoiser_lora_config = LoraConfig( + r=rank, + lora_alpha=lora_alpha, + target_modules=["to_q", "to_k", "to_v", "to_out.0"], + init_lora_weights=False, + use_dora=use_dora, + ) + + pipeline_components = { + "transformer": transformer, + "vae": vae, + "audio_vae": audio_vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "connectors": connectors, + "vocoder": vocoder, + } + + return pipeline_components, text_lora_config, denoiser_lora_config + + def test_simple_inference_with_text_lora_denoiser_fused_multi(self): + super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3) + + def test_simple_inference_with_text_denoiser_lora_unfused(self): + super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3) + + 
@unittest.skip("Not supported in LTX2.") + def test_simple_inference_with_text_denoiser_block_scale(self): + pass + + @unittest.skip("Not supported in LTX2.") + def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): + pass + + @unittest.skip("Not supported in LTX2.") + def test_modify_padding_mode(self): + pass + + @unittest.skip("Text encoder LoRA is not supported in LTX2.") + def test_simple_inference_with_partial_text_lora(self): + pass + + @unittest.skip("Text encoder LoRA is not supported in LTX2.") + def test_simple_inference_with_text_lora(self): + pass + + @unittest.skip("Text encoder LoRA is not supported in LTX2.") + def test_simple_inference_with_text_lora_and_scale(self): + pass + + @unittest.skip("Text encoder LoRA is not supported in LTX2.") + def test_simple_inference_with_text_lora_fused(self): + pass + + @unittest.skip("Text encoder LoRA is not supported in LTX2.") + def test_simple_inference_with_text_lora_save_load(self): + pass + + @unittest.skip("Text encoder LoRA is not supported in LTX2.") + def test_simple_inference_save_pretrained_with_text_lora(self): + pass