-
Notifications
You must be signed in to change notification settings - Fork 6.7k
[LoRA] add LoRA support to LTX-2 #12933
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
1acc160
efdfab6
e7164fc
596d978
527b89d
9c95963
8a4bd37
9abdd93
be5644e
397b656
84a6573
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -48,6 +48,7 @@ | |||||||||||||||||||||||
| _convert_non_diffusers_flux2_lora_to_diffusers, | ||||||||||||||||||||||||
| _convert_non_diffusers_hidream_lora_to_diffusers, | ||||||||||||||||||||||||
| _convert_non_diffusers_lora_to_diffusers, | ||||||||||||||||||||||||
| _convert_non_diffusers_ltx2_lora_to_diffusers, | ||||||||||||||||||||||||
| _convert_non_diffusers_ltxv_lora_to_diffusers, | ||||||||||||||||||||||||
| _convert_non_diffusers_lumina2_lora_to_diffusers, | ||||||||||||||||||||||||
| _convert_non_diffusers_qwen_lora_to_diffusers, | ||||||||||||||||||||||||
|
|
@@ -74,6 +75,7 @@ | |||||||||||||||||||||||
| TEXT_ENCODER_NAME = "text_encoder" | ||||||||||||||||||||||||
| UNET_NAME = "unet" | ||||||||||||||||||||||||
| TRANSFORMER_NAME = "transformer" | ||||||||||||||||||||||||
| LTX2_CONNECTOR_NAME = "connectors" | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX = {"x_embedder": "in_channels"} | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
|
|
@@ -3011,6 +3013,233 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): | |||||||||||||||||||||||
| super().unfuse_lora(components=components, **kwargs) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| class LTX2LoraLoaderMixin(LoraBaseMixin): | ||||||||||||||||||||||||
| r""" | ||||||||||||||||||||||||
| Load LoRA layers into [`LTX2VideoTransformer3DModel`]. Specific to [`LTX2Pipeline`]. | ||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| _lora_loadable_modules = ["transformer", "connectors"] | ||||||||||||||||||||||||
| transformer_name = TRANSFORMER_NAME | ||||||||||||||||||||||||
| connectors_name = LTX2_CONNECTOR_NAME | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| @classmethod | ||||||||||||||||||||||||
| @validate_hf_hub_args | ||||||||||||||||||||||||
| def lora_state_dict( | ||||||||||||||||||||||||
| cls, | ||||||||||||||||||||||||
| pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], | ||||||||||||||||||||||||
| **kwargs, | ||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||
| r""" | ||||||||||||||||||||||||
| See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details. | ||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||
| # Load the main state dict first which has the LoRA layers for either of | ||||||||||||||||||||||||
| # transformer and text encoder or both. | ||||||||||||||||||||||||
| cache_dir = kwargs.pop("cache_dir", None) | ||||||||||||||||||||||||
| force_download = kwargs.pop("force_download", False) | ||||||||||||||||||||||||
| proxies = kwargs.pop("proxies", None) | ||||||||||||||||||||||||
| local_files_only = kwargs.pop("local_files_only", None) | ||||||||||||||||||||||||
| token = kwargs.pop("token", None) | ||||||||||||||||||||||||
| revision = kwargs.pop("revision", None) | ||||||||||||||||||||||||
| subfolder = kwargs.pop("subfolder", None) | ||||||||||||||||||||||||
| weight_name = kwargs.pop("weight_name", None) | ||||||||||||||||||||||||
| use_safetensors = kwargs.pop("use_safetensors", None) | ||||||||||||||||||||||||
| return_lora_metadata = kwargs.pop("return_lora_metadata", False) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| allow_pickle = False | ||||||||||||||||||||||||
| if use_safetensors is None: | ||||||||||||||||||||||||
| use_safetensors = True | ||||||||||||||||||||||||
| allow_pickle = True | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| state_dict, metadata = _fetch_state_dict( | ||||||||||||||||||||||||
| pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, | ||||||||||||||||||||||||
| weight_name=weight_name, | ||||||||||||||||||||||||
| use_safetensors=use_safetensors, | ||||||||||||||||||||||||
| local_files_only=local_files_only, | ||||||||||||||||||||||||
| cache_dir=cache_dir, | ||||||||||||||||||||||||
| force_download=force_download, | ||||||||||||||||||||||||
| proxies=proxies, | ||||||||||||||||||||||||
| token=token, | ||||||||||||||||||||||||
| revision=revision, | ||||||||||||||||||||||||
| subfolder=subfolder, | ||||||||||||||||||||||||
| user_agent=user_agent, | ||||||||||||||||||||||||
| allow_pickle=allow_pickle, | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| is_dora_scale_present = any("dora_scale" in k for k in state_dict) | ||||||||||||||||||||||||
| if is_dora_scale_present: | ||||||||||||||||||||||||
| warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." | ||||||||||||||||||||||||
| logger.warning(warn_msg) | ||||||||||||||||||||||||
| state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| final_state_dict = state_dict | ||||||||||||||||||||||||
| is_non_diffusers_format = any(k.startswith("diffusion_model.") for k in state_dict) | ||||||||||||||||||||||||
| has_connector = any(k.startswith("text_embedding_projection.") for k in state_dict) | ||||||||||||||||||||||||
| if is_non_diffusers_format: | ||||||||||||||||||||||||
| final_state_dict = _convert_non_diffusers_ltx2_lora_to_diffusers(state_dict) | ||||||||||||||||||||||||
| if has_connector: | ||||||||||||||||||||||||
| connectors_state_dict = _convert_non_diffusers_ltx2_lora_to_diffusers( | ||||||||||||||||||||||||
| state_dict, "text_embedding_projection" | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
| final_state_dict.update(connectors_state_dict) | ||||||||||||||||||||||||
| out = (final_state_dict, metadata) if return_lora_metadata else final_state_dict | ||||||||||||||||||||||||
| return out | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| def load_lora_weights( | ||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||
| pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], | ||||||||||||||||||||||||
| adapter_name: Optional[str] = None, | ||||||||||||||||||||||||
| hotswap: bool = False, | ||||||||||||||||||||||||
| **kwargs, | ||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||
| See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details. | ||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||
| if not USE_PEFT_BACKEND: | ||||||||||||||||||||||||
| raise ValueError("PEFT backend is required for this method.") | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) | ||||||||||||||||||||||||
| if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): | ||||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||||
| "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| # if a dict is passed, copy it instead of modifying it inplace | ||||||||||||||||||||||||
| if isinstance(pretrained_model_name_or_path_or_dict, dict): | ||||||||||||||||||||||||
| pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| # First, ensure that the checkpoint is a compatible one and can be successfully loaded. | ||||||||||||||||||||||||
| kwargs["return_lora_metadata"] = True | ||||||||||||||||||||||||
| state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| is_correct_format = all("lora" in key for key in state_dict.keys()) | ||||||||||||||||||||||||
| if not is_correct_format: | ||||||||||||||||||||||||
| raise ValueError("Invalid LoRA checkpoint.") | ||||||||||||||||||||||||
|
Comment on lines
+3117
to
+3118
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
nit, non-blocking: Could we have a more informative error message here? I'm not sure if the suggestion is exactly correct but I think we should give an indication of what a valid LoRA checkpoint should look like.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Suggestion looks correct and it applies module-wide. Maybe this could clubbed with #12933 (comment) if you want to take a crack? |
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| transformer_peft_state_dict = { | ||||||||||||||||||||||||
| k: v for k, v in state_dict.items() if k.startswith(f"{self.transformer_name}.") | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
| connectors_peft_state_dict = {k: v for k, v in state_dict.items() if k.startswith(f"{self.connectors_name}.")} | ||||||||||||||||||||||||
| self.load_lora_into_transformer( | ||||||||||||||||||||||||
| transformer_peft_state_dict, | ||||||||||||||||||||||||
| transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, | ||||||||||||||||||||||||
| adapter_name=adapter_name, | ||||||||||||||||||||||||
| metadata=metadata, | ||||||||||||||||||||||||
| _pipeline=self, | ||||||||||||||||||||||||
| low_cpu_mem_usage=low_cpu_mem_usage, | ||||||||||||||||||||||||
| hotswap=hotswap, | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
| if connectors_peft_state_dict: | ||||||||||||||||||||||||
| self.load_lora_into_transformer( | ||||||||||||||||||||||||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should be fine.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe it would make sense to rename
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's not too bad this way I guess because the |
||||||||||||||||||||||||
| connectors_peft_state_dict, | ||||||||||||||||||||||||
| transformer=getattr(self, self.connectors_name) | ||||||||||||||||||||||||
| if not hasattr(self, "connectors") | ||||||||||||||||||||||||
| else self.connectors, | ||||||||||||||||||||||||
| adapter_name=adapter_name, | ||||||||||||||||||||||||
| metadata=metadata, | ||||||||||||||||||||||||
| _pipeline=self, | ||||||||||||||||||||||||
| low_cpu_mem_usage=low_cpu_mem_usage, | ||||||||||||||||||||||||
| hotswap=hotswap, | ||||||||||||||||||||||||
| prefix=self.connectors_name, | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| @classmethod | ||||||||||||||||||||||||
| def load_lora_into_transformer( | ||||||||||||||||||||||||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The signature has the |
||||||||||||||||||||||||
| cls, | ||||||||||||||||||||||||
| state_dict, | ||||||||||||||||||||||||
| transformer, | ||||||||||||||||||||||||
| adapter_name=None, | ||||||||||||||||||||||||
| _pipeline=None, | ||||||||||||||||||||||||
| low_cpu_mem_usage=False, | ||||||||||||||||||||||||
| hotswap: bool = False, | ||||||||||||||||||||||||
| metadata=None, | ||||||||||||||||||||||||
| prefix: str = "transformer", | ||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||
| See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details. | ||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||
| if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): | ||||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||||
| "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| # Load the layers corresponding to transformer. | ||||||||||||||||||||||||
| logger.info(f"Loading {prefix}.") | ||||||||||||||||||||||||
| transformer.load_lora_adapter( | ||||||||||||||||||||||||
| state_dict, | ||||||||||||||||||||||||
| network_alphas=None, | ||||||||||||||||||||||||
| adapter_name=adapter_name, | ||||||||||||||||||||||||
| metadata=metadata, | ||||||||||||||||||||||||
| _pipeline=_pipeline, | ||||||||||||||||||||||||
| low_cpu_mem_usage=low_cpu_mem_usage, | ||||||||||||||||||||||||
| hotswap=hotswap, | ||||||||||||||||||||||||
| prefix=prefix, | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| @classmethod | ||||||||||||||||||||||||
| # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights | ||||||||||||||||||||||||
| def save_lora_weights( | ||||||||||||||||||||||||
| cls, | ||||||||||||||||||||||||
| save_directory: Union[str, os.PathLike], | ||||||||||||||||||||||||
| transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, | ||||||||||||||||||||||||
| is_main_process: bool = True, | ||||||||||||||||||||||||
| weight_name: str = None, | ||||||||||||||||||||||||
| save_function: Callable = None, | ||||||||||||||||||||||||
| safe_serialization: bool = True, | ||||||||||||||||||||||||
| transformer_lora_adapter_metadata: Optional[dict] = None, | ||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||
| r""" | ||||||||||||||||||||||||
| See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information. | ||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||
| lora_layers = {} | ||||||||||||||||||||||||
| lora_metadata = {} | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| if transformer_lora_layers: | ||||||||||||||||||||||||
| lora_layers[cls.transformer_name] = transformer_lora_layers | ||||||||||||||||||||||||
| lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| if not lora_layers: | ||||||||||||||||||||||||
| raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.") | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| cls._save_lora_weights( | ||||||||||||||||||||||||
| save_directory=save_directory, | ||||||||||||||||||||||||
| lora_layers=lora_layers, | ||||||||||||||||||||||||
| lora_metadata=lora_metadata, | ||||||||||||||||||||||||
| is_main_process=is_main_process, | ||||||||||||||||||||||||
| weight_name=weight_name, | ||||||||||||||||||||||||
| save_function=save_function, | ||||||||||||||||||||||||
| safe_serialization=safe_serialization, | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora | ||||||||||||||||||||||||
| def fuse_lora( | ||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||
| components: List[str] = ["transformer"], | ||||||||||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Should the default argument for Also, I think it might make sense to refactor this into something like def fuse_lora(
self,
components: List[str] = [],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[List[str]] = None,
**kwargs,
):
if len(components) == 0:
components = self._lora_loadable_modules
# Rest of implementation same as before
...Furthermore, since
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can keep the default as is and then follow the rest. But on a second thought, we would rather want to have the users pass the component names they want to fuse explicitly. Perhaps, we can add validation checks like: if not components:
raise ValueError
if any(c not in self._lora_loadable_modules for c in components):
raise ValueErrorSomething like this?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Furthermore, I assume Does this make sense?
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I think those validation checks are already implemented: diffusers/src/diffusers/loaders/lora_base.py Lines 593 to 601 in 02c7adc
In my opinion the ideal behavior is that However, I'm not sure how feasible this is for the current implementation, and it's possible that my view is mistaken (maybe there are some factors I haven't considered?).
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Actually, there can be several considerations:
This is the reason why we allow the users to pass
It will be a no-op but that is an implicit behaviour which I would like to avoid |
||||||||||||||||||||||||
| lora_scale: float = 1.0, | ||||||||||||||||||||||||
| safe_fusing: bool = False, | ||||||||||||||||||||||||
| adapter_names: Optional[List[str]] = None, | ||||||||||||||||||||||||
| **kwargs, | ||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||
| r""" | ||||||||||||||||||||||||
| See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details. | ||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||
| super().fuse_lora( | ||||||||||||||||||||||||
| components=components, | ||||||||||||||||||||||||
| lora_scale=lora_scale, | ||||||||||||||||||||||||
| safe_fusing=safe_fusing, | ||||||||||||||||||||||||
| adapter_names=adapter_names, | ||||||||||||||||||||||||
| **kwargs, | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora | ||||||||||||||||||||||||
| def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): | ||||||||||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
See #12933 (comment).
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. See #12933 (comment) |
||||||||||||||||||||||||
| r""" | ||||||||||||||||||||||||
| See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details. | ||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||
| super().unfuse_lora(components=components, **kwargs) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| class SanaLoraLoaderMixin(LoraBaseMixin): | ||||||||||||||||||||||||
| r""" | ||||||||||||||||||||||||
| Load LoRA layers into [`SanaTransformer2DModel`]. Specific to [`SanaPipeline`]. | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The distilled LoRA checkpoint needs the
connectorscomponent of the pipeline to be loaded with LoRA too.