Commit 5518d07

Fix loading of Kohya's Flux.2 dev LoRA
1 parent cf3c65c commit 5518d07
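
For context, a minimal sketch of the intended use case after this fix. The repo id, file path, and the assumption that the resolved pipeline class exposes `load_lora_weights()` are mine, not taken from the commit:

```python
import torch
from diffusers import DiffusionPipeline

# Assumed checkpoint repo id and LoRA path (placeholders).
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.2-dev", torch_dtype=torch.bfloat16
)
# With this fix, a Kohya/sd-scripts checkpoint (lora_unet_* keys with
# .lora_down/.lora_up weights) is converted to the diffusers naming scheme
# before the LoRA is applied.
pipe.load_lora_weights("path/to/kohya_flux2_lora.safetensors")
```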

2 files changed: +69 -20 lines changed


src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 59 additions & 20 deletions
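
The conversion targets change when `version_flux2` is set. A few module-path mappings implied by the hunks below, collected here for readability (illustrative only, block 0 shown, not exhaustive):

```python
# Kohya (sd-scripts) module names -> diffusers Flux.2 module names, as wired up
# in _convert_sd_scripts_to_ai_toolkit when version_flux2=True.
flux2_key_examples = {
    "lora_unet_double_blocks_0_img_mlp_0": "transformer.transformer_blocks.0.ff.linear_in",
    "lora_unet_double_blocks_0_img_mlp_2": "transformer.transformer_blocks.0.ff.linear_out",
    "lora_unet_double_blocks_0_txt_mlp_0": "transformer.transformer_blocks.0.ff_context.linear_in",
    "lora_unet_double_blocks_0_txt_mlp_2": "transformer.transformer_blocks.0.ff_context.linear_out",
    "lora_unet_single_blocks_0_linear1": "transformer.single_transformer_blocks.0.attn.to_qkv_mlp_proj",
    "lora_unet_single_blocks_0_linear2": "transformer.single_transformer_blocks.0.attn.to_out",
}
```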
@@ -358,7 +358,10 @@ def _get_alpha_name(lora_name_alpha, diffusers_name, alpha):
 
 # The utilities under `_convert_kohya_flux_lora_to_diffusers()`
 # are adapted from https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
-def _convert_kohya_flux_lora_to_diffusers(state_dict):
+def _convert_kohya_flux_lora_to_diffusers(
+    state_dict,
+    version_flux2=False,
+):
     def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):
         if sds_key + ".lora_down.weight" not in sds_sd:
             return
@@ -449,7 +452,15 @@ def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
 
     def _convert_sd_scripts_to_ai_toolkit(sds_sd):
         ait_sd = {}
-        for i in range(19):
+
+        max_num_double_blocks, max_num_single_blocks = -1, -1
+        for key in list(sds_sd.keys()):
+            if key.startswith("lora_unet_double_blocks_"):
+                max_num_double_blocks = max(max_num_double_blocks, int(key.split("_")[4]))
+            if key.startswith("lora_unet_single_blocks_"):
+                max_num_single_blocks = max(max_num_single_blocks, int(key.split("_")[4]))
+
+        for i in range(max_num_double_blocks + 1):
             _convert_to_ai_toolkit(
                 sds_sd,
                 ait_sd,
@@ -470,13 +481,21 @@ def _convert_sd_scripts_to_ai_toolkit(sds_sd):
                 sds_sd,
                 ait_sd,
                 f"lora_unet_double_blocks_{i}_img_mlp_0",
-                f"transformer.transformer_blocks.{i}.ff.net.0.proj",
+                (
+                    f"transformer.transformer_blocks.{i}.ff.linear_in"
+                    if version_flux2 else
+                    f"transformer.transformer_blocks.{i}.ff.net.0.proj"
+                ),
             )
             _convert_to_ai_toolkit(
                 sds_sd,
                 ait_sd,
                 f"lora_unet_double_blocks_{i}_img_mlp_2",
-                f"transformer.transformer_blocks.{i}.ff.net.2",
+                (
+                    f"transformer.transformer_blocks.{i}.ff.linear_out"
+                    if version_flux2 else
+                    f"transformer.transformer_blocks.{i}.ff.net.2"
+                ),
             )
             _convert_to_ai_toolkit(
                 sds_sd,
@@ -504,13 +523,21 @@ def _convert_sd_scripts_to_ai_toolkit(sds_sd):
                 sds_sd,
                 ait_sd,
                 f"lora_unet_double_blocks_{i}_txt_mlp_0",
-                f"transformer.transformer_blocks.{i}.ff_context.net.0.proj",
+                (
+                    f"transformer.transformer_blocks.{i}.ff_context.linear_in"
+                    if version_flux2 else
+                    f"transformer.transformer_blocks.{i}.ff_context.net.0.proj"
+                ),
             )
             _convert_to_ai_toolkit(
                 sds_sd,
                 ait_sd,
                 f"lora_unet_double_blocks_{i}_txt_mlp_2",
-                f"transformer.transformer_blocks.{i}.ff_context.net.2",
+                (
+                    f"transformer.transformer_blocks.{i}.ff_context.linear_out"
+                    if version_flux2 else
+                    f"transformer.transformer_blocks.{i}.ff_context.net.2"
+                ),
             )
             _convert_to_ai_toolkit(
                 sds_sd,
@@ -519,24 +546,36 @@ def _convert_sd_scripts_to_ai_toolkit(sds_sd):
                 f"transformer.transformer_blocks.{i}.norm1_context.linear",
             )
 
-        for i in range(38):
-            _convert_to_ai_toolkit_cat(
-                sds_sd,
-                ait_sd,
-                f"lora_unet_single_blocks_{i}_linear1",
-                [
-                    f"transformer.single_transformer_blocks.{i}.attn.to_q",
-                    f"transformer.single_transformer_blocks.{i}.attn.to_k",
-                    f"transformer.single_transformer_blocks.{i}.attn.to_v",
-                    f"transformer.single_transformer_blocks.{i}.proj_mlp",
-                ],
-                dims=[3072, 3072, 3072, 12288],
-            )
+        for i in range(max_num_single_blocks + 1):
+            if version_flux2:
+                _convert_to_ai_toolkit(
+                    sds_sd,
+                    ait_sd,
+                    f"lora_unet_single_blocks_{i}_linear1",
+                    f"transformer.single_transformer_blocks.{i}.attn.to_qkv_mlp_proj",
+                )
+            else:
+                _convert_to_ai_toolkit_cat(
+                    sds_sd,
+                    ait_sd,
+                    f"lora_unet_single_blocks_{i}_linear1",
+                    [
+                        f"transformer.single_transformer_blocks.{i}.attn.to_q",
+                        f"transformer.single_transformer_blocks.{i}.attn.to_k",
+                        f"transformer.single_transformer_blocks.{i}.attn.to_v",
+                        f"transformer.single_transformer_blocks.{i}.proj_mlp",
+                    ],
+                    dims=[3072, 3072, 3072, 12288],
+                )
             _convert_to_ai_toolkit(
                 sds_sd,
                 ait_sd,
                 f"lora_unet_single_blocks_{i}_linear2",
-                f"transformer.single_transformer_blocks.{i}.proj_out",
+                (
+                    f"transformer.single_transformer_blocks.{i}.attn.to_out"
+                    if version_flux2 else
+                    f"transformer.single_transformer_blocks.{i}.proj_out"
+                ),
             )
             _convert_to_ai_toolkit(
                 sds_sd,
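
The loop bounds are no longer hard-coded (19 double blocks, 38 single blocks); they are now inferred from the highest block index present in the checkpoint. A standalone illustration of that detection on a toy Kohya-style state dict (keys shortened, tensor values omitted):

```python
# Toy sd-scripts state dict; only the key names matter for block-count detection.
toy_sd = {
    "lora_unet_double_blocks_0_img_mlp_0.lora_down.weight": None,
    "lora_unet_double_blocks_7_img_mlp_0.lora_down.weight": None,
    "lora_unet_single_blocks_37_linear1.lora_down.weight": None,
}

max_double, max_single = -1, -1
for key in toy_sd:
    if key.startswith("lora_unet_double_blocks_"):
        max_double = max(max_double, int(key.split("_")[4]))
    if key.startswith("lora_unet_single_blocks_"):
        max_single = max(max_single, int(key.split("_")[4]))

print(max_double, max_single)  # 7 37 -> the loops run over range(8) and range(38)
```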

src/diffusers/loaders/lora_pipeline.py

Lines changed: 10 additions & 0 deletions
@@ -5472,6 +5472,16 @@ def lora_state_dict(
             logger.warning(warn_msg)
             state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
 
+        is_kohya = any(".lora_down.weight" in k for k in state_dict)
+        if is_kohya:
+            state_dict = _convert_kohya_flux_lora_to_diffusers(
+                state_dict,
+                version_flux2=True,
+            )
+            # Kohya already takes care of scaling the LoRA parameters with alpha.
+            for k in state_dict:
+                assert "alpha" not in k, f"Found key with alpha: {k}"
+
         is_ai_toolkit = any(k.startswith("diffusion_model.") for k in state_dict)
         if is_ai_toolkit:
             state_dict = _convert_non_diffusers_flux2_lora_to_diffusers(state_dict)
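
The new branch in `lora_state_dict` routes Kohya-format checkpoints through the converter before the existing ai-toolkit path. A minimal sketch of the detection heuristic, using made-up keys for illustration:

```python
def looks_like_kohya(state_dict):
    # Kohya / sd-scripts checkpoints store LoRA factors as .lora_down / .lora_up.
    return any(".lora_down.weight" in k for k in state_dict)

print(looks_like_kohya({"lora_unet_double_blocks_0_img_mlp_0.lora_down.weight": 0}))  # True
print(looks_like_kohya({"diffusion_model.example.lora_A.weight": 0}))  # False (handled by the ai-toolkit path)
```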

0 commit comments
