[LoRA] add LoRA support to LTX-2 #12933
Conversation
    _lora_loadable_modules = ["transformer", "connectors"]
    transformer_name = TRANSFORMER_NAME
    connectors_name = LTX2_CONNECTOR_NAME
The distilled LoRA checkpoint needs the connectors component of the pipeline to be loaded with LoRA too.
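A minimal sketch of the idea, assuming the checkpoint uses a `connectors.`-style key prefix (hypothetical helper and key layout, not the PR's actual implementation):

```python
from typing import Dict, Tuple

import torch


def split_ltx2_lora_state_dict(
    state_dict: Dict[str, torch.Tensor],
) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
    # The "connectors." prefix is an assumption made for illustration: route connector
    # LoRA weights to the `connectors` component and everything else to the transformer.
    connectors_sd = {k: v for k, v in state_dict.items() if k.startswith("connectors.")}
    transformer_sd = {k: v for k, v in state_dict.items() if k not in connectors_sd}
    return transformer_sd, connectors_sd
```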
        hotswap=hotswap,
    )
    if connectors_peft_state_dict:
        self.load_lora_into_transformer(
This should be fine.
Maybe it would make sense to rename load_lora_into_transformer to load_lora_into_modules or have separate functions load_lora_into_transformer and load_lora_into_connectors analogous to how StableDiffusionLoraLoaderMixin does it? IIUC load_lora_weights is the intended entry point, so renaming/refactoring this method shouldn't disrupt how LTX2LoraLoaderMixin is used.
It's not too bad this way, I guess, because the connectors component is a mini transformer in itself. load_lora_into_modules would be a bit inexplicit, which we want to avoid.
    @classmethod
    def load_lora_into_transformer(
The signature has the prefix argument added, which is why it differs from the rest of the bunch.
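An illustrative stub of what that looks like, assuming a signature along these lines (the exact arguments in the PR may differ):

```python
from typing import Optional


class _LoaderSketch:
    """Illustrative stub only, not the actual LTX2LoraLoaderMixin."""

    @classmethod
    def load_lora_into_transformer(
        cls,
        state_dict: dict,
        transformer,
        adapter_name: Optional[str] = None,
        hotswap: bool = False,
        prefix: Optional[str] = None,  # the extra argument discussed above
    ):
        # Filtering keys by `prefix` is what lets the same classmethod load LoRA layers
        # into either the transformer or the connectors component.
        state_dict = {
            k: v for k, v in state_dict.items()
            if prefix is None or k.startswith(f"{prefix}.")
        }
        # ... hand the filtered state dict to the PEFT injection logic ...
```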
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
    if not is_correct_format:
        raise ValueError("Invalid LoRA checkpoint.")
Suggested change:
        if not is_correct_format:
-           raise ValueError("Invalid LoRA checkpoint.")
+           raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `lora`.")
nit, non-blocking: Could we have a more informative error message here? I'm not sure if the suggestion is exactly correct but I think we should give an indication of what a valid LoRA checkpoint should look like.
Suggestion looks correct and it applies module-wide. Maybe this could be clubbed with #12933 (comment) if you want to take a crack?
    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
    def fuse_lora(
        self,
        components: List[str] = ["transformer"],
Suggested change:
-       components: List[str] = ["transformer"],
+       components: List[str] = ["transformer", "connectors"],
Should the default argument for components include "connectors"?
Also, I think it might make sense to refactor this into something like:

    def fuse_lora(
        self,
        components: List[str] = [],
        lora_scale: float = 1.0,
        safe_fusing: bool = False,
        adapter_names: Optional[List[str]] = None,
        **kwargs,
    ):
        if len(components) == 0:
            components = self._lora_loadable_modules
        # Rest of implementation same as before
        ...

Furthermore, since _lora_loadable_modules exists in the parent class LoraBaseMixin, we could push this logic into LoraBaseMixin.fuse_lora and not have to keep overriding fuse_lora in the child classes.
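A rough sketch of the base-class variant being suggested (illustrative only, not the current diffusers implementation):

```python
from typing import List, Optional


class LoraBaseMixinSketch:
    # Each pipeline mixin would still declare its own loadable modules.
    _lora_loadable_modules: List[str] = []

    def fuse_lora(self, components: Optional[List[str]] = None, **kwargs):
        # Defaulting to all loadable modules here would let child classes stop
        # overriding fuse_lora just to change the default `components`.
        components = list(components) if components else list(self._lora_loadable_modules)
        # ... rest of the existing fusing logic over `components` ...
```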
We can keep the default as is and then follow the rest. But on second thought, we would rather have the users pass the component names they want to fuse explicitly.
Perhaps, we can add validation checks like:

    if not components:
        raise ValueError
    if any(c not in self._lora_loadable_modules for c in components):
        raise ValueError

Something like this?
Furthermore, I assume connectors shouldn't be a default component to be fused. It's probably only going to be applicable for the distilled LoRA checkpoint but not to other checkpoints (like the IC LoRAs LTX2 made available; example).
Does this make sense?
> Perhaps, we can add validation checks like:
I think those validation checks are already implemented:
From diffusers/src/diffusers/loaders/lora_base.py, lines 593 to 601 at 02c7adc:

    if len(components) == 0:
        raise ValueError("`components` cannot be an empty list.")

    # Need to retrieve the names as `adapter_names` can be None. So we cannot directly use it
    # in `self._merged_adapters = self._merged_adapters | merged_adapter_names`.
    merged_adapter_names = set()
    for fuse_component in components:
        if fuse_component not in self._lora_loadable_modules:
            raise ValueError(f"{fuse_component} is not found in {self._lora_loadable_modules=}.")
In my opinion the ideal behavior is that fuse_lora() can be called without any arguments, and this should attempt to fuse all (active?) LoRAs in all possible modules; that is, all modules in _lora_loadable_modules. If a module in _lora_loadable_modules doesn't have any LoRAs which target it, this will be handled gracefully (presumably as a no-op). A user can then explicitly specify components to fuse if they want finer control (for example, they only want the LoRAs to be fused on some modules). unfuse_lora should then have the analogous behavior.
However, I'm not sure how feasible this is for the current implementation, and it's possible that my view is mistaken (maybe there are some factors I haven't considered?).
> In my opinion the ideal behavior is that fuse_lora() can be called without any arguments, and this should attempt to fuse all (active?) LoRAs in all possible modules; that is, all modules in _lora_loadable_modules.
Actually, there can be several considerations:
- Load one LoRA in say, text encoder.
- Load two LoRAs into the DiT, fuse them, keep them unfused, etc.
This is the reason why we allow the users to pass components explicitly. I think this makes the feature more configurable.
> If a module in _lora_loadable_modules doesn't have any LoRAs which target it, this will be handled gracefully (presumably as a no-op).
It will be a no-op, but that is an implicit behaviour which I would like to avoid.
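For illustration, explicit `components` usage under that design might look like this (repo ids and adapter names are made up; `pipe` is assumed to be an already-instantiated LoRA-capable diffusers pipeline):

```python
# Load two DiT LoRAs under distinct adapter names.
pipe.load_lora_weights("some-org/dit-lora-a", adapter_name="a")
pipe.load_lora_weights("some-org/dit-lora-b", adapter_name="b")

# Fuse only the DiT LoRAs; other loadable modules (e.g. connectors) are left untouched
# unless the user names them explicitly.
pipe.fuse_lora(components=["transformer"], adapter_names=["a", "b"])
pipe.unfuse_lora(components=["transformer"])
```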
    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
    def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
Suggested change:
-   def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
+   def unfuse_lora(self, components: List[str] = ["transformer", "connectors"], **kwargs):
See #12933 (comment).
See #12933 (comment).
    def test_modify_padding_mode(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in LTX2.")
Perhaps we could have a flag like supports_text_encoder_lora in PeftLoraLoaderMixinTests that can skip all of the text encoder LoRA tests at once?
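A possible shape for that flag, sketched with illustrative class and test names (the real PeftLoraLoaderMixinTests has many more text encoder tests):

```python
import unittest


class PeftLoraLoaderMixinTestsSketch:
    # Pipelines without a LoRA-capable text encoder would flip this to False in their subclass.
    supports_text_encoder_lora: bool = True

    def test_simple_inference_with_text_lora(self):
        # Concrete test classes mix this in together with unittest.TestCase,
        # so skipTest is available at runtime.
        if not self.supports_text_encoder_lora:
            self.skipTest("Text encoder LoRA is not supported for this pipeline.")
        ...  # rest of the text encoder LoRA test


class LTX2LoRATestsSketch(PeftLoraLoaderMixinTestsSketch, unittest.TestCase):
    # Flips the flag once instead of decorating every text encoder test with @unittest.skip.
    supports_text_encoder_lora = False
```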
Sure. Do you want to take a crack at that?
dg845 left a comment:
Thanks! Left a few questions.
This PR also allows loading IC LoRAs: https://huggingface.co/Lightricks/LTX-2-19b-IC-LoRA-Canny-Control. @linoytsaban / @asomoza do you want to give this a try? I think you would first need to make the I2V pipeline class (…).
What does this PR do?
Adds support for loading non-diffusers LoRAs into LTX2Pipeline, more specifically this checkpoint or any other LoRA checkpoint obtained with the trainer shipped by the official codebase. The said LoRA is crucial for reducing the number of inference steps and also seems to be crucial for the two-stage pipeline as implemented here.
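A hedged usage sketch of loading such a LoRA with this PR; the repo ids below are placeholders, not the actual checkpoints linked above:

```python
import torch
from diffusers import LTX2Pipeline

# Placeholder repo ids for the base model and the distilled LoRA checkpoint.
pipe = LTX2Pipeline.from_pretrained("<base-ltx2-repo>", torch_dtype=torch.bfloat16)
pipe.load_lora_weights("<distilled-lora-repo>", adapter_name="distilled")
pipe.to("cuda")

# With this PR, the LoRA layers land in both `transformer` and `connectors`,
# since both are listed in `_lora_loadable_modules`.
```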
I decided to give this LoRA a try on the single-stage T2V pipeline and I am getting decent results:
video_distilled.mp4
Code
Note that we need the following diff on pipeline_ltx2.py as well. The sigmas above are pre-computed and come from here.
Cc: @asomoza @linoytsaban for awareness.