38 changes: 19 additions & 19 deletions src/diffusers/loaders/lora_pipeline.py
@@ -214,7 +214,7 @@ def load_lora_weights(
 
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
-            raise ValueError("Invalid LoRA checkpoint.")
+            raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `lora`.")
Suggested change (review comment):
-            raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `lora`.")
+            raise ValueError('Invalid LoRA checkpoint. Make sure all LoRA param names contain the "lora" substring.')
 
         self.load_lora_into_unet(
             state_dict,
@@ -641,7 +641,7 @@ def load_lora_weights(
 
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
-            raise ValueError("Invalid LoRA checkpoint.")
+            raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `lora`.")
 
         self.load_lora_into_unet(
             state_dict,
@@ -1081,7 +1081,7 @@ def load_lora_weights(
 
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
-            raise ValueError("Invalid LoRA checkpoint.")
+            raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `lora`.")
 
         self.load_lora_into_transformer(
             state_dict,
@@ -1377,7 +1377,7 @@ def load_lora_weights(
 
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
-            raise ValueError("Invalid LoRA checkpoint.")
+            raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `lora`.")
 
         self.load_lora_into_transformer(
             state_dict,
@@ -1659,7 +1659,7 @@ def load_lora_weights(
         )
 
         if not (has_lora_keys or has_norm_keys):
-            raise ValueError("Invalid LoRA checkpoint.")
+            raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `lora`.")
 
         transformer_lora_state_dict = {
             k: state_dict.get(k)
@@ -2506,7 +2506,7 @@ def load_lora_weights(
 
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
-            raise ValueError("Invalid LoRA checkpoint.")
+            raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `lora`.")
 
         self.load_lora_into_transformer(
             state_dict,
@@ -2703,7 +2703,7 @@ def load_lora_weights(
 
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
-            raise ValueError("Invalid LoRA checkpoint.")
+            raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `lora`.")
 
         self.load_lora_into_transformer(
             state_dict,
@@ -2906,7 +2906,7 @@ def load_lora_weights(
 
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
-            raise ValueError("Invalid LoRA checkpoint.")
+            raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `lora`.")
 
         self.load_lora_into_transformer(
             state_dict,
@@ -3333,7 +3333,7 @@ def load_lora_weights(
 
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
-            raise ValueError("Invalid LoRA checkpoint.")
+            raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `lora`.")
 
         self.load_lora_into_transformer(
             state_dict,
@@ -3536,7 +3536,7 @@ def load_lora_weights(
 
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
-            raise ValueError("Invalid LoRA checkpoint.")
+            raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `lora`.")
 
         self.load_lora_into_transformer(
             state_dict,
@@ -3740,7 +3740,7 @@ def load_lora_weights(
 
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
-            raise ValueError("Invalid LoRA checkpoint.")
+            raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `lora`.")
 
         self.load_lora_into_transformer(
             state_dict,
@@ -3940,7 +3940,7 @@ def load_lora_weights(
 
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
-            raise ValueError("Invalid LoRA checkpoint.")
+            raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `lora`.")
 
         self.load_lora_into_transformer(
             state_dict,
@@ -4194,7 +4194,7 @@ def load_lora_weights(
         )
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
-            raise ValueError("Invalid LoRA checkpoint.")
+            raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `lora`.")
 
         load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
         if load_into_transformer_2:
@@ -4471,7 +4471,7 @@ def load_lora_weights(
         )
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
-            raise ValueError("Invalid LoRA checkpoint.")
+            raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `lora`.")
 
         load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
         if load_into_transformer_2:
@@ -4691,7 +4691,7 @@ def load_lora_weights(
 
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
-            raise ValueError("Invalid LoRA checkpoint.")
+            raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `lora`.")
 
         self.load_lora_into_transformer(
             state_dict,
@@ -4894,7 +4894,7 @@ def load_lora_weights(
 
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
-            raise ValueError("Invalid LoRA checkpoint.")
+            raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `lora`.")
 
         self.load_lora_into_transformer(
             state_dict,
@@ -5100,7 +5100,7 @@ def load_lora_weights(
 
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
-            raise ValueError("Invalid LoRA checkpoint.")
+            raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `lora`.")
 
         self.load_lora_into_transformer(
             state_dict,
@@ -5306,7 +5306,7 @@ def load_lora_weights(
 
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
-            raise ValueError("Invalid LoRA checkpoint.")
+            raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `lora`.")
 
         self.load_lora_into_transformer(
             state_dict,
@@ -5509,7 +5509,7 @@ def load_lora_weights(
 
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
-            raise ValueError("Invalid LoRA checkpoint.")
+            raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `lora`.")
 
         self.load_lora_into_transformer(
             state_dict,
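For context, every hunk above changes the same one-line guard: `is_correct_format` is a plain substring test over the checkpoint's keys. A minimal standalone sketch of that behavior (the helper name and example keys are illustrative, not part of the diffusers API):

def check_lora_state_dict(state_dict):
    # Same guard as in the diff: every parameter name must contain "lora".
    is_correct_format = all("lora" in key for key in state_dict.keys())
    if not is_correct_format:
        raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `lora`.")


# A PEFT-style key such as "...lora_A.weight" passes; a base-model key does not.
check_lora_state_dict({"unet.mid_block.attn.to_q.lora_A.weight": None})  # ok
try:
    check_lora_state_dict({"unet.mid_block.attn.to_q.weight": None})  # no "lora" substring
except ValueError as err:
    print(err)  # prints the new, more descriptive message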
22 changes: 2 additions & 20 deletions tests/lora/test_lora_layers_auraflow.py
@@ -76,6 +76,8 @@ class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     text_encoder_target_modules = ["q", "k", "v", "o"]
     denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0", "linear_1"]
 
+    supports_text_encoder_loras = False
+
     @property
     def output_shape(self):
         return (1, 8, 8, 3)
@@ -114,23 +116,3 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(se
     @unittest.skip("Not supported in AuraFlow.")
     def test_modify_padding_mode(self):
         pass
-
-    @unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
-    def test_simple_inference_with_partial_text_lora(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
-    def test_simple_inference_with_text_lora(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
-    def test_simple_inference_with_text_lora_and_scale(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
-    def test_simple_inference_with_text_lora_fused(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
-    def test_simple_inference_with_text_lora_save_load(self):
-        pass
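This file and the five that follow each replace hand-written `@unittest.skip` overrides (five per class, six for LTX2) with a single `supports_text_encoder_loras = False` class attribute. Presumably the shared `PeftLoraLoaderMixinTests` base now consults this flag and skips the text-encoder LoRA tests centrally; the exact mechanism is not shown in this diff, so the class names and skip wiring below are assumptions, a sketch of the pattern rather than the actual mixin code:

import unittest


class BaseLoRATests:
    # Capability flag consulted by the shared tests; pipelines without
    # text-encoder LoRA support override it to False, as the diffs do.
    supports_text_encoder_loras = True

    def test_simple_inference_with_text_lora(self):
        if not self.supports_text_encoder_loras:
            raise unittest.SkipTest("Text encoder LoRA is not supported.")
        # ... actual test body would run here ...


class AuraFlowStyleLoRATests(BaseLoRATests, unittest.TestCase):
    supports_text_encoder_loras = False  # all text-encoder LoRA tests auto-skip

One flag replaces many boilerplate skips, and any text-encoder LoRA test later added to the mixin is skipped automatically for these pipelines.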
22 changes: 2 additions & 20 deletions tests/lora/test_lora_layers_cogvideox.py
@@ -87,6 +87,8 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
 
     text_encoder_target_modules = ["q", "k", "v", "o"]
 
+    supports_text_encoder_loras = False
+
     @property
     def output_shape(self):
         return (1, 9, 16, 16, 3)
@@ -147,26 +149,6 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(se
     def test_modify_padding_mode(self):
         pass
-
-    @unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
-    def test_simple_inference_with_partial_text_lora(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
-    def test_simple_inference_with_text_lora(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
-    def test_simple_inference_with_text_lora_and_scale(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
-    def test_simple_inference_with_text_lora_fused(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
-    def test_simple_inference_with_text_lora_save_load(self):
-        pass
 
     @unittest.skip("Not supported in CogVideoX.")
     def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
         pass
22 changes: 2 additions & 20 deletions tests/lora/test_lora_layers_cogview4.py
@@ -85,6 +85,8 @@ class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
         "text_encoder",
     )
 
+    supports_text_encoder_loras = False
+
     @property
     def output_shape(self):
         return (1, 32, 32, 3)
@@ -162,23 +164,3 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(se
     @unittest.skip("Not supported in CogView4.")
     def test_modify_padding_mode(self):
         pass
-
-    @unittest.skip("Text encoder LoRA is not supported in CogView4.")
-    def test_simple_inference_with_partial_text_lora(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in CogView4.")
-    def test_simple_inference_with_text_lora(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in CogView4.")
-    def test_simple_inference_with_text_lora_and_scale(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in CogView4.")
-    def test_simple_inference_with_text_lora_fused(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in CogView4.")
-    def test_simple_inference_with_text_lora_save_load(self):
-        pass
22 changes: 2 additions & 20 deletions tests/lora/test_lora_layers_flux2.py
@@ -66,6 +66,8 @@ class Flux2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     text_encoder_cls, text_encoder_id = Mistral3ForConditionalGeneration, "hf-internal-testing/tiny-mistral3-diffusers"
     denoiser_target_modules = ["to_qkv_mlp_proj", "to_k"]
 
+    supports_text_encoder_loras = False
+
     @property
     def output_shape(self):
         return (1, 8, 8, 3)
@@ -146,23 +148,3 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(se
     @unittest.skip("Not supported in Flux2.")
     def test_modify_padding_mode(self):
         pass
-
-    @unittest.skip("Text encoder LoRA is not supported in Flux2.")
-    def test_simple_inference_with_partial_text_lora(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in Flux2.")
-    def test_simple_inference_with_text_lora(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in Flux2.")
-    def test_simple_inference_with_text_lora_and_scale(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in Flux2.")
-    def test_simple_inference_with_text_lora_fused(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in Flux2.")
-    def test_simple_inference_with_text_lora_save_load(self):
-        pass
22 changes: 2 additions & 20 deletions tests/lora/test_lora_layers_hunyuanvideo.py
@@ -117,6 +117,8 @@ class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
         "text_encoder_2",
     )
 
+    supports_text_encoder_loras = False
+
     @property
     def output_shape(self):
         return (1, 9, 32, 32, 3)
@@ -172,26 +174,6 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(se
     def test_modify_padding_mode(self):
         pass
-
-    @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
-    def test_simple_inference_with_partial_text_lora(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
-    def test_simple_inference_with_text_lora(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
-    def test_simple_inference_with_text_lora_and_scale(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
-    def test_simple_inference_with_text_lora_fused(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
-    def test_simple_inference_with_text_lora_save_load(self):
-        pass
 
 
 @nightly
 @require_torch_accelerator
Expand Down
26 changes: 2 additions & 24 deletions tests/lora/test_lora_layers_ltx2.py
@@ -150,6 +150,8 @@ class LTX2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
 
     denoiser_target_modules = ["to_q", "to_k", "to_out.0"]
 
+    supports_text_encoder_loras = False
+
     @property
     def output_shape(self):
         return (1, 5, 32, 32, 3)
@@ -267,27 +269,3 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(se
     @unittest.skip("Not supported in LTX2.")
     def test_modify_padding_mode(self):
         pass
-
-    @unittest.skip("Text encoder LoRA is not supported in LTX2.")
-    def test_simple_inference_with_partial_text_lora(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in LTX2.")
-    def test_simple_inference_with_text_lora(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in LTX2.")
-    def test_simple_inference_with_text_lora_and_scale(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in LTX2.")
-    def test_simple_inference_with_text_lora_fused(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in LTX2.")
-    def test_simple_inference_with_text_lora_save_load(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in LTX2.")
-    def test_simple_inference_save_pretrained_with_text_lora(self):
-        pass