
Commit 144704f (parent: cf3c65c)

Fix loading of Kohya's Flux.2 dev LoRA
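With this fix in place, a Flux.2 LoRA trained with Kohya's sd-scripts should load through the usual diffusers LoRA entry point. A minimal usage sketch, assuming the Flux.2 pipeline class shipped by diffusers; the model id and LoRA filename below are placeholders:

    import torch
    from diffusers import Flux2Pipeline

    # Placeholder model id and LoRA file; substitute your own.
    pipe = Flux2Pipeline.from_pretrained(
        "black-forest-labs/FLUX.2-dev", torch_dtype=torch.bfloat16
    )
    pipe.load_lora_weights("path/to/kohya_flux2_lora.safetensors")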

2 files changed (+76, -25 lines)

src/diffusers/loaders/lora_conversion_utils.py (66 additions, 25 deletions)
@@ -358,7 +358,10 @@ def _get_alpha_name(lora_name_alpha, diffusers_name, alpha):
 
 # The utilities under `_convert_kohya_flux_lora_to_diffusers()`
 # are adapted from https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
-def _convert_kohya_flux_lora_to_diffusers(state_dict):
+def _convert_kohya_flux_lora_to_diffusers(
+    state_dict,
+    version_flux2=False,
+):
     def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):
         if sds_key + ".lora_down.weight" not in sds_sd:
             return
@@ -449,7 +452,15 @@ def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
 
     def _convert_sd_scripts_to_ai_toolkit(sds_sd):
         ait_sd = {}
-        for i in range(19):
+
+        max_num_double_blocks, max_num_single_blocks = -1, -1
+        for key in list(sds_sd.keys()):
+            if key.startswith("lora_unet_double_blocks_"):
+                max_num_double_blocks = max(max_num_double_blocks, int(key.split("_")[4]))
+            if key.startswith("lora_unet_single_blocks_"):
+                max_num_single_blocks = max(max_num_single_blocks, int(key.split("_")[4]))
+
+        for i in range(max_num_double_blocks + 1):
             _convert_to_ai_toolkit(
                 sds_sd,
                 ait_sd,
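The hunk above replaces the hardcoded Flux.1 block counts (19 double blocks, 38 single blocks) with counts inferred from the checkpoint keys, which is what lets the same converter serve Flux.2's different block layout. The block index sits at position 4 when a Kohya key is split on underscores; a small sketch with a made-up key of that shape:

    # Hypothetical Kohya sd-scripts key, following the lora_unet_* naming scheme.
    key = "lora_unet_double_blocks_7_img_attn_qkv.lora_down.weight"
    parts = key.split("_")
    # ['lora', 'unet', 'double', 'blocks', '7', 'img', 'attn', 'qkv.lora', 'down.weight']
    block_index = int(parts[4])
    assert block_index == 7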
@@ -470,13 +481,21 @@ def _convert_sd_scripts_to_ai_toolkit(sds_sd):
                 sds_sd,
                 ait_sd,
                 f"lora_unet_double_blocks_{i}_img_mlp_0",
-                f"transformer.transformer_blocks.{i}.ff.net.0.proj",
+                (
+                    f"transformer.transformer_blocks.{i}.ff.linear_in"
+                    if version_flux2
+                    else f"transformer.transformer_blocks.{i}.ff.net.0.proj"
+                ),
             )
             _convert_to_ai_toolkit(
                 sds_sd,
                 ait_sd,
                 f"lora_unet_double_blocks_{i}_img_mlp_2",
-                f"transformer.transformer_blocks.{i}.ff.net.2",
+                (
+                    f"transformer.transformer_blocks.{i}.ff.linear_out"
+                    if version_flux2
+                    else f"transformer.transformer_blocks.{i}.ff.net.2"
+                ),
             )
             _convert_to_ai_toolkit(
                 sds_sd,
@@ -504,13 +523,21 @@ def _convert_sd_scripts_to_ai_toolkit(sds_sd):
                 sds_sd,
                 ait_sd,
                 f"lora_unet_double_blocks_{i}_txt_mlp_0",
-                f"transformer.transformer_blocks.{i}.ff_context.net.0.proj",
+                (
+                    f"transformer.transformer_blocks.{i}.ff_context.linear_in"
+                    if version_flux2
+                    else f"transformer.transformer_blocks.{i}.ff_context.net.0.proj"
+                ),
             )
             _convert_to_ai_toolkit(
                 sds_sd,
                 ait_sd,
                 f"lora_unet_double_blocks_{i}_txt_mlp_2",
-                f"transformer.transformer_blocks.{i}.ff_context.net.2",
+                (
+                    f"transformer.transformer_blocks.{i}.ff_context.linear_out"
+                    if version_flux2
+                    else f"transformer.transformer_blocks.{i}.ff_context.net.2"
+                ),
             )
             _convert_to_ai_toolkit(
                 sds_sd,
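The two hunks above retarget the double-block feed-forward layers: Flux.1's FeedForward module exposes its projections as net.0.proj and net.2, while the Flux.2 blocks name them linear_in and linear_out. The same mapping, summarized as a hypothetical lookup table (not the code the commit uses):

    # Hypothetical summary of the per-version feed-forward targets.
    FF_TARGETS = {
        "img_mlp_0": ("ff.net.0.proj", "ff.linear_in"),
        "img_mlp_2": ("ff.net.2", "ff.linear_out"),
        "txt_mlp_0": ("ff_context.net.0.proj", "ff_context.linear_in"),
        "txt_mlp_2": ("ff_context.net.2", "ff_context.linear_out"),
    }

    def ff_target(suffix: str, version_flux2: bool) -> str:
        flux1_name, flux2_name = FF_TARGETS[suffix]
        return flux2_name if version_flux2 else flux1_name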
@@ -519,25 +546,39 @@ def _convert_sd_scripts_to_ai_toolkit(sds_sd):
                 f"transformer.transformer_blocks.{i}.norm1_context.linear",
             )
 
-        for i in range(38):
-            _convert_to_ai_toolkit_cat(
-                sds_sd,
-                ait_sd,
-                f"lora_unet_single_blocks_{i}_linear1",
-                [
-                    f"transformer.single_transformer_blocks.{i}.attn.to_q",
-                    f"transformer.single_transformer_blocks.{i}.attn.to_k",
-                    f"transformer.single_transformer_blocks.{i}.attn.to_v",
-                    f"transformer.single_transformer_blocks.{i}.proj_mlp",
-                ],
-                dims=[3072, 3072, 3072, 12288],
-            )
-            _convert_to_ai_toolkit(
-                sds_sd,
-                ait_sd,
-                f"lora_unet_single_blocks_{i}_linear2",
-                f"transformer.single_transformer_blocks.{i}.proj_out",
-            )
+        for i in range(max_num_single_blocks + 1):
+            if version_flux2:
+                _convert_to_ai_toolkit(
+                    sds_sd,
+                    ait_sd,
+                    f"lora_unet_single_blocks_{i}_linear1",
+                    f"transformer.single_transformer_blocks.{i}.attn.to_qkv_mlp_proj",
+                )
+                _convert_to_ai_toolkit(
+                    sds_sd,
+                    ait_sd,
+                    f"lora_unet_single_blocks_{i}_linear2",
+                    f"transformer.single_transformer_blocks.{i}.attn.to_out",
+                )
+            else:
+                _convert_to_ai_toolkit_cat(
+                    sds_sd,
+                    ait_sd,
+                    f"lora_unet_single_blocks_{i}_linear1",
+                    [
+                        f"transformer.single_transformer_blocks.{i}.attn.to_q",
+                        f"transformer.single_transformer_blocks.{i}.attn.to_k",
+                        f"transformer.single_transformer_blocks.{i}.attn.to_v",
+                        f"transformer.single_transformer_blocks.{i}.proj_mlp",
+                    ],
+                    dims=[3072, 3072, 3072, 12288],
+                )
+                _convert_to_ai_toolkit(
+                    sds_sd,
+                    ait_sd,
+                    f"lora_unet_single_blocks_{i}_linear2",
+                    f"transformer.single_transformer_blocks.{i}.proj_out",
+                )
         _convert_to_ai_toolkit(
             sds_sd,
             ait_sd,
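This hunk carries the behavioral core of the fix. For Flux.1, the fused single-block linear1 must be split into four diffusers targets (to_q, to_k, to_v, proj_mlp) along the listed dims, which is what _convert_to_ai_toolkit_cat does; Flux.2's diffusers blocks keep the projection fused as attn.to_qkv_mlp_proj, so a plain one-to-one rename suffices. A sketch of the split the Flux.1 path needs, using torch and illustrative shapes (rank 16, the 3072/12288 widths from the diff); the real helper also handles alpha scaling:

    import torch

    rank = 16
    dims = [3072, 3072, 3072, 12288]  # q, k, v, mlp-in output widths

    # For a fused linear1, lora_up maps rank -> sum(dims); lora_down is shared.
    lora_down = torch.randn(rank, 3072)
    lora_up = torch.randn(sum(dims), rank)

    # Splitting lora_up along dim 0 gives one (width, rank) factor per target;
    # each chunk paired with the shared lora_down forms a standalone LoRA.
    up_q, up_k, up_v, up_mlp = torch.split(lora_up, dims, dim=0)
    assert up_q.shape == (3072, rank) and up_mlp.shape == (12288, rank)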

src/diffusers/loaders/lora_pipeline.py (10 additions, 0 deletions)
@@ -5472,6 +5472,16 @@ def lora_state_dict(
             logger.warning(warn_msg)
             state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
 
+        is_kohya = any(".lora_down.weight" in k for k in state_dict)
+        if is_kohya:
+            state_dict = _convert_kohya_flux_lora_to_diffusers(
+                state_dict,
+                version_flux2=True,
+            )
+            # Kohya already takes care of scaling the LoRA parameters with alpha.
+            for k in state_dict:
+                assert "alpha" not in k, f"Found key with alpha: {k}"
+
         is_ai_toolkit = any(k.startswith("diffusion_model.") for k in state_dict)
         if is_ai_toolkit:
             state_dict = _convert_non_diffusers_flux2_lora_to_diffusers(state_dict)
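The assertion encodes an invariant of the conversion: Kohya checkpoints carry a per-layer alpha, and the converter folds it into the weight factors instead of passing the alpha keys through, so none should remain afterwards. A sketch of that folding under the standard LoRA convention, with illustrative tensors (the converter may distribute the scale between the two factors differently, but only the product matters):

    import torch

    rank, in_features, out_features = 16, 3072, 3072
    alpha = 8.0

    down = torch.randn(rank, in_features)
    up = torch.randn(out_features, rank)

    # The effective LoRA delta is (alpha / rank) * up @ down, so splitting the
    # scale evenly across both factors preserves the product while letting the
    # alpha entry be dropped from the state dict.
    scale = alpha / rank
    down_scaled = down * scale**0.5
    up_scaled = up * scale**0.5
    assert torch.allclose(up_scaled @ down_scaled, scale * (up @ down), atol=1e-4)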
