5 changes: 5 additions & 0 deletions docs/source/en/api/loaders/lora.md
@@ -33,6 +33,7 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
- [`QwenImageLoraLoaderMixin`] provides similar functions for [Qwen Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/qwen).
- [`ZImageLoraLoaderMixin`] provides similar functions for [Z-Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/zimage).
- [`Flux2LoraLoaderMixin`] provides similar functions for [Flux2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux2).
- [`LTX2LoraLoaderMixin`] provides similar functions for [LTX-2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/ltx2).
- [`LoraBaseMixin`] provides a base class with several utility methods to fuse, unfuse, unload LoRAs, and more.

> [!TIP]
@@ -62,6 +63,10 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi

[[autodoc]] loaders.lora_pipeline.Flux2LoraLoaderMixin

## LTX2LoraLoaderMixin

[[autodoc]] loaders.lora_pipeline.LTX2LoraLoaderMixin

## CogVideoXLoraLoaderMixin

[[autodoc]] loaders.lora_pipeline.CogVideoXLoraLoaderMixin
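For context, a rough sketch of how the new mixin's `lora_state_dict` classmethod could be used to inspect a converted checkpoint; the repo id and weight name below are placeholders, not real artifacts:

```python
# Hedged sketch: peek at the converted state dict of an LTX-2 LoRA.
# "your-org/ltx2-lora" and the weight name are hypothetical placeholders.
from diffusers import LTX2Pipeline

state_dict = LTX2Pipeline.lora_state_dict(
    "your-org/ltx2-lora",
    weight_name="pytorch_lora_weights.safetensors",
)

# Non-diffusers ("diffusion_model.*") checkpoints come back re-keyed under
# "transformer." and, when the LoRA also targets the text connectors, "connectors.".
print(sorted({key.split(".")[0] for key in state_dict}))
```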
4 changes: 4 additions & 0 deletions docs/source/en/api/pipelines/ltx2.md
@@ -14,6 +14,10 @@

# LTX-2

<div class="flex flex-wrap space-x-1">
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
</div>

LTX-2 is a DiT-based audio-video foundation model designed to generate synchronized video and audio within a single model. It brings together the core building blocks of modern video generation, with open weights and a focus on practical, local execution.

You can find all the original LTX-Video checkpoints under the [Lightricks](https://huggingface.co/Lightricks) organization.
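A rough usage sketch for loading a LoRA into the pipeline; the repo ids are placeholders and the generation call is reduced to a prompt only:

```python
# Hedged sketch: repo ids are placeholders and generation arguments are minimal;
# see the pipeline documentation for the full call signature.
import torch
from diffusers import LTX2Pipeline

pipe = LTX2Pipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16)  # placeholder id
pipe.load_lora_weights("your-org/ltx2-style-lora", adapter_name="style")  # placeholder id
pipe.to("cuda")

output = pipe(prompt="a red panda playing a tiny drum kit")
```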
2 changes: 2 additions & 0 deletions src/diffusers/loaders/__init__.py
@@ -67,6 +67,7 @@ def text_encoder_attn_modules(text_encoder):
"SD3LoraLoaderMixin",
"AuraFlowLoraLoaderMixin",
"StableDiffusionXLLoraLoaderMixin",
"LTX2LoraLoaderMixin",
"LTXVideoLoraLoaderMixin",
"LoraLoaderMixin",
"FluxLoraLoaderMixin",
@@ -121,6 +122,7 @@ def text_encoder_attn_modules(text_encoder):
HunyuanVideoLoraLoaderMixin,
KandinskyLoraLoaderMixin,
LoraLoaderMixin,
LTX2LoraLoaderMixin,
LTXVideoLoraLoaderMixin,
Lumina2LoraLoaderMixin,
Mochi1LoraLoaderMixin,
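With the export entries above in place, the mixin should be importable directly from `diffusers.loaders`; a quick sanity-check sketch:

```python
# Relies only on the lazy-import entries added in this file.
from diffusers.loaders import LTX2LoraLoaderMixin

print(LTX2LoraLoaderMixin._lora_loadable_modules)  # expected: ["transformer", "connectors"]
```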
48 changes: 48 additions & 0 deletions src/diffusers/loaders/lora_conversion_utils.py
@@ -2140,6 +2140,54 @@ def _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict, non_diffusers_pref
return converted_state_dict


def _convert_non_diffusers_ltx2_lora_to_diffusers(state_dict, non_diffusers_prefix="diffusion_model"):
# Remove the prefix
state_dict = {k: v for k, v in state_dict.items() if k.startswith(f"{non_diffusers_prefix}.")}
converted_state_dict = {k.removeprefix(f"{non_diffusers_prefix}."): v for k, v in state_dict.items()}

if non_diffusers_prefix == "diffusion_model":
rename_dict = {
"patchify_proj": "proj_in",
"audio_patchify_proj": "audio_proj_in",
"av_ca_video_scale_shift_adaln_single": "av_cross_attn_video_scale_shift",
"av_ca_a2v_gate_adaln_single": "av_cross_attn_video_a2v_gate",
"av_ca_audio_scale_shift_adaln_single": "av_cross_attn_audio_scale_shift",
"av_ca_v2a_gate_adaln_single": "av_cross_attn_audio_v2a_gate",
"scale_shift_table_a2v_ca_video": "video_a2v_cross_attn_scale_shift_table",
"scale_shift_table_a2v_ca_audio": "audio_a2v_cross_attn_scale_shift_table",
"q_norm": "norm_q",
"k_norm": "norm_k",
}
else:
rename_dict = {"aggregate_embed": "text_proj_in"}

# Apply renaming
renamed_state_dict = {}
for key, value in converted_state_dict.items():
new_key = key[:]
for old_pattern, new_pattern in rename_dict.items():
new_key = new_key.replace(old_pattern, new_pattern)
renamed_state_dict[new_key] = value

# Handle adaln_single -> time_embed and audio_adaln_single -> audio_time_embed
final_state_dict = {}
for key, value in renamed_state_dict.items():
if key.startswith("adaln_single."):
new_key = key.replace("adaln_single.", "time_embed.")
final_state_dict[new_key] = value
elif key.startswith("audio_adaln_single."):
new_key = key.replace("audio_adaln_single.", "audio_time_embed.")
final_state_dict[new_key] = value
else:
final_state_dict[key] = value

# Add transformer prefix
prefix = "transformer" if non_diffusers_prefix == "diffusion_model" else "connectors"
final_state_dict = {f"{prefix}.{k}": v for k, v in final_state_dict.items()}

return final_state_dict
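
To make the key remapping concrete, here is an illustrative toy run of the converter above (the keys and shapes are made up; it assumes the function is called from within this module):

```python
# Illustrative only: toy LoRA keys run through the converter to show the renaming
# and re-prefixing it performs for the transformer and connectors paths.
import torch

toy_state_dict = {
    "diffusion_model.patchify_proj.lora_A.weight": torch.zeros(4, 8),
    "diffusion_model.adaln_single.linear.lora_B.weight": torch.zeros(8, 4),
    "text_embedding_projection.aggregate_embed.lora_A.weight": torch.zeros(4, 8),
}

transformer_sd = _convert_non_diffusers_ltx2_lora_to_diffusers(toy_state_dict)
connectors_sd = _convert_non_diffusers_ltx2_lora_to_diffusers(toy_state_dict, "text_embedding_projection")

print(sorted(transformer_sd))
# ['transformer.proj_in.lora_A.weight', 'transformer.time_embed.linear.lora_B.weight']
print(sorted(connectors_sd))
# ['connectors.text_proj_in.lora_A.weight']
```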


def _convert_non_diffusers_qwen_lora_to_diffusers(state_dict):
has_diffusion_model = any(k.startswith("diffusion_model.") for k in state_dict)
if has_diffusion_model:
229 changes: 229 additions & 0 deletions src/diffusers/loaders/lora_pipeline.py
@@ -48,6 +48,7 @@
_convert_non_diffusers_flux2_lora_to_diffusers,
_convert_non_diffusers_hidream_lora_to_diffusers,
_convert_non_diffusers_lora_to_diffusers,
_convert_non_diffusers_ltx2_lora_to_diffusers,
_convert_non_diffusers_ltxv_lora_to_diffusers,
_convert_non_diffusers_lumina2_lora_to_diffusers,
_convert_non_diffusers_qwen_lora_to_diffusers,
@@ -74,6 +75,7 @@
TEXT_ENCODER_NAME = "text_encoder"
UNET_NAME = "unet"
TRANSFORMER_NAME = "transformer"
LTX2_CONNECTOR_NAME = "connectors"

_MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX = {"x_embedder": "in_channels"}

@@ -3011,6 +3013,233 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
super().unfuse_lora(components=components, **kwargs)


class LTX2LoraLoaderMixin(LoraBaseMixin):
r"""
Load LoRA layers into [`LTX2VideoTransformer3DModel`]. Specific to [`LTX2Pipeline`].
"""

_lora_loadable_modules = ["transformer", "connectors"]
transformer_name = TRANSFORMER_NAME
connectors_name = LTX2_CONNECTOR_NAME
Member Author:
The distilled LoRA checkpoint needs the connectors component of the pipeline to be loaded with LoRA too.


@classmethod
@validate_hf_hub_args
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
**kwargs,
):
r"""
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
return_lora_metadata = kwargs.pop("return_lora_metadata", False)

allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True

user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
local_files_only=local_files_only,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
allow_pickle=allow_pickle,
)

is_dora_scale_present = any("dora_scale" in k for k in state_dict)
if is_dora_scale_present:
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

final_state_dict = state_dict
is_non_diffusers_format = any(k.startswith("diffusion_model.") for k in state_dict)
has_connector = any(k.startswith("text_embedding_projection.") for k in state_dict)
if is_non_diffusers_format:
final_state_dict = _convert_non_diffusers_ltx2_lora_to_diffusers(state_dict)
if has_connector:
connectors_state_dict = _convert_non_diffusers_ltx2_lora_to_diffusers(
state_dict, "text_embedding_projection"
)
final_state_dict.update(connectors_state_dict)
out = (final_state_dict, metadata) if return_lora_metadata else final_state_dict
return out

def load_lora_weights(
self,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
adapter_name: Optional[str] = None,
hotswap: bool = False,
**kwargs,
):
"""
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")

low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)

# if a dict is passed, copy it instead of modifying it inplace
if isinstance(pretrained_model_name_or_path_or_dict, dict):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
kwargs["return_lora_metadata"] = True
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
Comment on lines +3117 to +3118

Collaborator:
Suggested change:
-        if not is_correct_format:
-            raise ValueError("Invalid LoRA checkpoint.")
+        if not is_correct_format:
+            raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `lora`.")

nit, non-blocking: Could we have a more informative error message here? I'm not sure if the suggestion is exactly correct but I think we should give an indication of what a valid LoRA checkpoint should look like.

Member Author:
Suggestion looks correct and it applies module-wide. Maybe this could be clubbed with #12933 (comment) if you want to take a crack?


transformer_peft_state_dict = {
k: v for k, v in state_dict.items() if k.startswith(f"{self.transformer_name}.")
}
connectors_peft_state_dict = {k: v for k, v in state_dict.items() if k.startswith(f"{self.connectors_name}.")}
self.load_lora_into_transformer(
transformer_peft_state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
if connectors_peft_state_dict:
self.load_lora_into_transformer(
Member Author:
This should be fine.

Collaborator:
Maybe it would make sense to rename load_lora_into_transformer to load_lora_into_modules, or have separate functions load_lora_into_transformer and load_lora_into_connectors, analogous to how StableDiffusionLoraLoaderMixin does it? IIUC load_lora_weights is the intended entry point, so renaming/refactoring this method shouldn't disrupt how LTX2LoraLoaderMixin is used.

Member Author:
It's not too bad this way, I guess, because the connectors component is a mini transformer in itself. load_lora_into_modules would be a bit inexplicit, which we want to avoid.

connectors_peft_state_dict,
transformer=getattr(self, self.connectors_name)
if not hasattr(self, "connectors")
else self.connectors,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
prefix=self.connectors_name,
)

@classmethod
def load_lora_into_transformer(
Member Author:
The signature has the prefix argument added, which is why it differs from the rest of the bunch.

cls,
state_dict,
transformer,
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
metadata=None,
prefix: str = "transformer",
):
"""
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)

# Load the layers corresponding to transformer.
logger.info(f"Loading {prefix}.")
transformer.load_lora_adapter(
state_dict,
network_alphas=None,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
prefix=prefix,
)

@classmethod
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
def save_lora_weights(
cls,
save_directory: Union[str, os.PathLike],
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
transformer_lora_adapter_metadata: Optional[dict] = None,
):
r"""
See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
"""
lora_layers = {}
lora_metadata = {}

if transformer_lora_layers:
lora_layers[cls.transformer_name] = transformer_lora_layers
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata

if not lora_layers:
raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")

cls._save_lora_weights(
save_directory=save_directory,
lora_layers=lora_layers,
lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
)

# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
def fuse_lora(
self,
components: List[str] = ["transformer"],
Collaborator:
Suggested change:
-        components: List[str] = ["transformer"],
+        components: List[str] = ["transformer", "connectors"],

Should the default argument for components include "connectors"?

Also, I think it might make sense to refactor this into something like

    def fuse_lora(
        self,
        components: List[str] = [],
        lora_scale: float = 1.0,
        safe_fusing: bool = False,
        adapter_names: Optional[List[str]] = None,
        **kwargs,
    ):
        if len(components) == 0:
            components = self._lora_loadable_modules
        # Rest of implementation same as before
        ...

Furthermore, since _lora_loadable_modules exists in the parent class LoraBaseMixin, we could push this logic into LoraBaseMixin.fuse_lora and not have to keep overriding fuse_lora in the child classes.

Member Author:
We can keep the default as is and then follow the rest. But on second thought, we would rather want to have the users pass the component names they want to fuse explicitly.

Perhaps we can add validation checks like:

    if not components:
        raise ValueError

    if any(c not in self._lora_loadable_modules for c in components):
        raise ValueError

Something like this?

Member Author:
Furthermore, I assume connectors shouldn't be a default component to be fused. It's probably only going to be applicable to the distilled LoRA checkpoint but not to other checkpoints (like the IC LoRAs LTX2 made available; example).

Does this make sense?

Collaborator:
> Perhaps we can add validation checks like:

I think those validation checks are already implemented:

    if len(components) == 0:
        raise ValueError("`components` cannot be an empty list.")
    # Need to retrieve the names as `adapter_names` can be None. So we cannot directly use it
    # in `self._merged_adapters = self._merged_adapters | merged_adapter_names`.
    merged_adapter_names = set()
    for fuse_component in components:
        if fuse_component not in self._lora_loadable_modules:
            raise ValueError(f"{fuse_component} is not found in {self._lora_loadable_modules=}.")

In my opinion the ideal behavior is that fuse_lora() can be called without any arguments, and this should attempt to fuse all (active?) LoRAs in all possible modules; that is, all modules in _lora_loadable_modules. If a module in _lora_loadable_modules doesn't have any LoRAs which target it, this will be handled gracefully (presumably as a no-op). A user can then explicitly specify components to fuse if they want finer control (for example, they only want the LoRAs to be fused on some modules). unfuse_lora should then have the analogous behavior.

However, I'm not sure how feasible this is for the current implementation, and it's possible that my view is mistaken (maybe there are some factors I haven't considered?).

Member Author:
> In my opinion the ideal behavior is that fuse_lora() can be called without any arguments, and this should attempt to fuse all (active?) LoRAs in all possible modules; that is, all modules in _lora_loadable_modules.

Actually, there can be several considerations:

  • Load one LoRA in, say, the text encoder.
  • Load two LoRAs into the DiT, fuse them, keep them unfused, etc.

This is the reason why we allow the users to pass components explicitly. I think this makes the feature more configurable.

> If a module in _lora_loadable_modules doesn't have any LoRAs which target it, this will be handled gracefully (presumably as a no-op).

It will be a no-op, but that is an implicit behaviour which I would like to avoid.

lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[List[str]] = None,
**kwargs,
):
r"""
See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
"""
super().fuse_lora(
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)

# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
Collaborator:
Suggested change:
-    def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
+    def unfuse_lora(self, components: List[str] = ["transformer", "connectors"], **kwargs):

See #12933 (comment).

r"""
See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
"""
super().unfuse_lora(components=components, **kwargs)


class SanaLoraLoaderMixin(LoraBaseMixin):
r"""
Load LoRA layers into [`SanaTransformer2DModel`]. Specific to [`SanaPipeline`].
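Following the fuse_lora discussion above, `connectors` is deliberately not a default fuse target, so a caller that wants it fused has to name it explicitly. A hedged sketch with placeholder repo ids:

```python
# Sketch: fuse into both the DiT and the text connectors, then drop the LoRA layers.
# Repo ids are placeholders; only checkpoints whose LoRA actually targets the
# connectors (e.g. the distilled LoRA) need "connectors" in the list.
import torch
from diffusers import LTX2Pipeline

pipe = LTX2Pipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16)  # placeholder id
pipe.load_lora_weights("your-org/ltx2-distilled-lora")  # placeholder id

pipe.fuse_lora(components=["transformer", "connectors"], lora_scale=1.0)
pipe.unload_lora_weights()
```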
2 changes: 2 additions & 0 deletions src/diffusers/loaders/peft.py
@@ -67,6 +67,8 @@
"QwenImageTransformer2DModel": lambda model_cls, weights: weights,
"Flux2Transformer2DModel": lambda model_cls, weights: weights,
"ZImageTransformer2DModel": lambda model_cls, weights: weights,
"LTX2VideoTransformer3DModel": lambda model_cls, weights: weights,
"LTX2TextConnectors": lambda model_cls, weights: weights,
}


3 changes: 2 additions & 1 deletion src/diffusers/pipelines/ltx2/connectors.py
@@ -5,6 +5,7 @@
import torch.nn.functional as F

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...models.attention import FeedForward
from ...models.modeling_utils import ModelMixin
from ...models.transformers.transformer_ltx2 import LTX2Attention, LTX2AudioVideoAttnProcessor
@@ -252,7 +253,7 @@ def forward(
return hidden_states, attention_mask


class LTX2TextConnectors(ModelMixin, ConfigMixin):
class LTX2TextConnectors(ModelMixin, PeftAdapterMixin, ConfigMixin):
"""
Text connector stack used by LTX 2.0 to process the packed text encoder hidden states for both the video and audio
streams.
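Mixing in `PeftAdapterMixin` is what gives the connectors module a `load_lora_adapter` method, which `load_lora_weights` calls under the hood. A minimal sketch, assuming `pipe` is an already-loaded `LTX2Pipeline` and `connectors_state_dict` holds LoRA weights keyed with a `connectors.` prefix (as produced by the converter):

```python
# Hedged sketch of the direct, lower-level call; normally load_lora_weights does this for you.
pipe.connectors.load_lora_adapter(
    connectors_state_dict,
    adapter_name="distilled",  # hypothetical adapter name
    prefix="connectors",       # keys are filtered by and stripped of this prefix
)
```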
4 changes: 2 additions & 2 deletions src/diffusers/pipelines/ltx2/pipeline_ltx2.py
@@ -21,7 +21,7 @@
from transformers import Gemma3ForConditionalGeneration, GemmaTokenizer, GemmaTokenizerFast

from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin
from ...loaders import FromSingleFileMixin, LTX2LoraLoaderMixin
from ...models.autoencoders import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video
from ...models.transformers import LTX2VideoTransformer3DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -184,7 +184,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
return noise_cfg


class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
r"""
Pipeline for text-to-video generation.
