@@ -358,7 +358,10 @@ def _get_alpha_name(lora_name_alpha, diffusers_name, alpha):
358358
359359# The utilities under `_convert_kohya_flux_lora_to_diffusers()`
360360# are adapted from https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
361- def _convert_kohya_flux_lora_to_diffusers (state_dict ):
361+ def _convert_kohya_flux_lora_to_diffusers (
362+ state_dict ,
363+ version_flux2 = False ,
364+ ):
362365 def _convert_to_ai_toolkit (sds_sd , ait_sd , sds_key , ait_key ):
363366 if sds_key + ".lora_down.weight" not in sds_sd :
364367 return
@@ -449,7 +452,15 @@ def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
449452
450453 def _convert_sd_scripts_to_ai_toolkit (sds_sd ):
451454 ait_sd = {}
452- for i in range (19 ):
455+
456+ max_num_double_blocks , max_num_single_blocks = - 1 , - 1
457+ for key in list (sds_sd .keys ()):
458+ if key .startswith ("lora_unet_double_blocks_" ):
459+ max_num_double_blocks = max (max_num_double_blocks , int (key .split ("_" )[4 ]))
460+ if key .startswith ("lora_unet_single_blocks_" ):
461+ max_num_single_blocks = max (max_num_single_blocks , int (key .split ("_" )[4 ]))
462+
463+ for i in range (max_num_double_blocks + 1 ):
453464 _convert_to_ai_toolkit (
454465 sds_sd ,
455466 ait_sd ,
@@ -470,13 +481,21 @@ def _convert_sd_scripts_to_ai_toolkit(sds_sd):
470481 sds_sd ,
471482 ait_sd ,
472483 f"lora_unet_double_blocks_{ i } _img_mlp_0" ,
473- f"transformer.transformer_blocks.{ i } .ff.net.0.proj" ,
484+ (
485+ f"transformer.transformer_blocks.{ i } .ff.linear_in"
486+ if version_flux2 else
487+ f"transformer.transformer_blocks.{ i } .ff.net.0.proj"
488+ ),
474489 )
475490 _convert_to_ai_toolkit (
476491 sds_sd ,
477492 ait_sd ,
478493 f"lora_unet_double_blocks_{ i } _img_mlp_2" ,
479- f"transformer.transformer_blocks.{ i } .ff.net.2" ,
494+ (
495+ f"transformer.transformer_blocks.{ i } .ff.linear_out"
496+ if version_flux2 else
497+ f"transformer.transformer_blocks.{ i } .ff.net.2"
498+ ),
480499 )
481500 _convert_to_ai_toolkit (
482501 sds_sd ,
@@ -504,13 +523,21 @@ def _convert_sd_scripts_to_ai_toolkit(sds_sd):
504523 sds_sd ,
505524 ait_sd ,
506525 f"lora_unet_double_blocks_{ i } _txt_mlp_0" ,
507- f"transformer.transformer_blocks.{ i } .ff_context.net.0.proj" ,
526+ (
527+ f"transformer.transformer_blocks.{ i } .ff_context.linear_in"
528+ if version_flux2 else
529+ f"transformer.transformer_blocks.{ i } .ff_context.net.0.proj"
530+ ),
508531 )
509532 _convert_to_ai_toolkit (
510533 sds_sd ,
511534 ait_sd ,
512535 f"lora_unet_double_blocks_{ i } _txt_mlp_2" ,
513- f"transformer.transformer_blocks.{ i } .ff_context.net.2" ,
536+ (
537+ f"transformer.transformer_blocks.{ i } .ff_context.linear_out"
538+ if version_flux2 else
539+ f"transformer.transformer_blocks.{ i } .ff_context.net.2"
540+ ),
514541 )
515542 _convert_to_ai_toolkit (
516543 sds_sd ,
@@ -519,24 +546,36 @@ def _convert_sd_scripts_to_ai_toolkit(sds_sd):
519546 f"transformer.transformer_blocks.{ i } .norm1_context.linear" ,
520547 )
521548
522- for i in range (38 ):
523- _convert_to_ai_toolkit_cat (
524- sds_sd ,
525- ait_sd ,
526- f"lora_unet_single_blocks_{ i } _linear1" ,
527- [
528- f"transformer.single_transformer_blocks.{ i } .attn.to_q" ,
529- f"transformer.single_transformer_blocks.{ i } .attn.to_k" ,
530- f"transformer.single_transformer_blocks.{ i } .attn.to_v" ,
531- f"transformer.single_transformer_blocks.{ i } .proj_mlp" ,
532- ],
533- dims = [3072 , 3072 , 3072 , 12288 ],
534- )
549+ for i in range (max_num_single_blocks + 1 ):
550+ if version_flux2 :
551+ _convert_to_ai_toolkit (
552+ sds_sd ,
553+ ait_sd ,
554+ f"lora_unet_single_blocks_{ i } _linear1" ,
555+ f"transformer.single_transformer_blocks.{ i } .attn.to_qkv_mlp_proj" ,
556+ )
557+ else :
558+ _convert_to_ai_toolkit_cat (
559+ sds_sd ,
560+ ait_sd ,
561+ f"lora_unet_single_blocks_{ i } _linear1" ,
562+ [
563+ f"transformer.single_transformer_blocks.{ i } .attn.to_q" ,
564+ f"transformer.single_transformer_blocks.{ i } .attn.to_k" ,
565+ f"transformer.single_transformer_blocks.{ i } .attn.to_v" ,
566+ f"transformer.single_transformer_blocks.{ i } .proj_mlp" ,
567+ ],
568+ dims = [3072 , 3072 , 3072 , 12288 ],
569+ )
535570 _convert_to_ai_toolkit (
536571 sds_sd ,
537572 ait_sd ,
538573 f"lora_unet_single_blocks_{ i } _linear2" ,
539- f"transformer.single_transformer_blocks.{ i } .proj_out" ,
574+ (
575+ f"transformer.single_transformer_blocks.{ i } .attn.to_out"
576+ if version_flux2 else
577+ f"transformer.single_transformer_blocks.{ i } .proj_out"
578+ ),
540579 )
541580 _convert_to_ai_toolkit (
542581 sds_sd ,
0 commit comments