@@ -1346,6 +1346,228 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
    return converted_state_dict


+ def _convert_fal_kontext_lora_to_diffusers(original_state_dict):
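+     """
+     Converts a FAL Kontext LoRA state dict into the diffusers naming scheme.
+
+     The source keys are wrapped with the PEFT prefix "base_model.model." and follow the
+     original BFL block naming (double_blocks / single_blocks); the converted keys target
+     the diffusers Flux transformer (transformer_blocks / single_transformer_blocks).
+     """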
+     converted_state_dict = {}
+     original_state_dict_keys = list(original_state_dict.keys())
+     num_layers = 19
+     num_single_layers = 38
+     inner_dim = 3072
+     mlp_ratio = 4.0
+     # All source keys carry the PEFT wrapper prefix.
+     original_block_prefix = "base_model.model."
+
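+     # Illustrative mapping performed below (key shown is hypothetical):
+     #   base_model.model.double_blocks.0.img_attn.qkv.lora_B.weight
+     #     -> transformer.transformer_blocks.0.attn.to_q.lora_B.weight (plus to_k / to_v chunks)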
+     # double transformer blocks
+     for i in range(num_layers):
+         block_prefix = f"transformer_blocks.{i}."
+
+         for lora_key in ["lora_A", "lora_B"]:
+             # norms
+             converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.weight"] = original_state_dict.pop(
+                 f"{original_block_prefix}double_blocks.{i}.img_mod.lin.{lora_key}.weight"
+             )
+             if f"{original_block_prefix}double_blocks.{i}.img_mod.lin.{lora_key}.bias" in original_state_dict_keys:
+                 converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.bias"] = original_state_dict.pop(
+                     f"{original_block_prefix}double_blocks.{i}.img_mod.lin.{lora_key}.bias"
+                 )
+
+             converted_state_dict[f"{block_prefix}norm1_context.linear.{lora_key}.weight"] = original_state_dict.pop(
+                 f"{original_block_prefix}double_blocks.{i}.txt_mod.lin.{lora_key}.weight"
+             )
+
+             # Q, K, V
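+             # lora_A maps the input to the low-rank space, so the fused qkv down-projection can
+             # be reused as-is for to_q, to_k and to_v. lora_B maps rank -> 3 * inner_dim, so it
+             # is chunked into separate q, k and v up-projections below.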
+             if lora_key == "lora_A":
+                 sample_lora_weight = original_state_dict.pop(
+                     f"{original_block_prefix}double_blocks.{i}.img_attn.qkv.{lora_key}.weight"
+                 )
+                 converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_lora_weight])
+                 converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_lora_weight])
+                 converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_lora_weight])
+
+                 context_lora_weight = original_state_dict.pop(
+                     f"{original_block_prefix}double_blocks.{i}.txt_attn.qkv.{lora_key}.weight"
+                 )
+                 converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat(
+                     [context_lora_weight]
+                 )
+                 converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat(
+                     [context_lora_weight]
+                 )
+                 converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat(
+                     [context_lora_weight]
+                 )
+             else:
+                 sample_q, sample_k, sample_v = torch.chunk(
+                     original_state_dict.pop(
+                         f"{original_block_prefix}double_blocks.{i}.img_attn.qkv.{lora_key}.weight"
+                     ),
+                     3,
+                     dim=0,
+                 )
+                 converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_q])
+                 converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_k])
+                 converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_v])
+
+                 context_q, context_k, context_v = torch.chunk(
+                     original_state_dict.pop(
+                         f"{original_block_prefix}double_blocks.{i}.txt_attn.qkv.{lora_key}.weight"
+                     ),
+                     3,
+                     dim=0,
+                 )
+                 converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat([context_q])
+                 converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat([context_k])
+                 converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat([context_v])
+
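+             # Bias terms are optional: many LoRA checkpoints ship only lora_A/lora_B weights,
+             # so every bias key is checked before it is popped.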
+             if f"{original_block_prefix}double_blocks.{i}.img_attn.qkv.{lora_key}.bias" in original_state_dict_keys:
+                 sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
+                     original_state_dict.pop(f"{original_block_prefix}double_blocks.{i}.img_attn.qkv.{lora_key}.bias"),
+                     3,
+                     dim=0,
+                 )
+                 converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([sample_q_bias])
+                 converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([sample_k_bias])
+                 converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([sample_v_bias])
+
+             if f"{original_block_prefix}double_blocks.{i}.txt_attn.qkv.{lora_key}.bias" in original_state_dict_keys:
+                 context_q_bias, context_k_bias, context_v_bias = torch.chunk(
+                     original_state_dict.pop(f"{original_block_prefix}double_blocks.{i}.txt_attn.qkv.{lora_key}.bias"),
+                     3,
+                     dim=0,
+                 )
+                 converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.bias"] = torch.cat([context_q_bias])
+                 converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.bias"] = torch.cat([context_k_bias])
+                 converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.bias"] = torch.cat([context_v_bias])
+
+             # ff img_mlp
+             converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.weight"] = original_state_dict.pop(
+                 f"{original_block_prefix}double_blocks.{i}.img_mlp.0.{lora_key}.weight"
+             )
+             if f"{original_block_prefix}double_blocks.{i}.img_mlp.0.{lora_key}.bias" in original_state_dict_keys:
+                 converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.bias"] = original_state_dict.pop(
+                     f"{original_block_prefix}double_blocks.{i}.img_mlp.0.{lora_key}.bias"
+                 )
+
+             converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.weight"] = original_state_dict.pop(
+                 f"{original_block_prefix}double_blocks.{i}.img_mlp.2.{lora_key}.weight"
+             )
+             if f"{original_block_prefix}double_blocks.{i}.img_mlp.2.{lora_key}.bias" in original_state_dict_keys:
+                 converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.bias"] = original_state_dict.pop(
+                     f"{original_block_prefix}double_blocks.{i}.img_mlp.2.{lora_key}.bias"
+                 )
+
+             converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.weight"] = original_state_dict.pop(
+                 f"{original_block_prefix}double_blocks.{i}.txt_mlp.0.{lora_key}.weight"
+             )
+             if f"{original_block_prefix}double_blocks.{i}.txt_mlp.0.{lora_key}.bias" in original_state_dict_keys:
+                 converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.bias"] = original_state_dict.pop(
+                     f"{original_block_prefix}double_blocks.{i}.txt_mlp.0.{lora_key}.bias"
+                 )
+
+             converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.weight"] = original_state_dict.pop(
+                 f"{original_block_prefix}double_blocks.{i}.txt_mlp.2.{lora_key}.weight"
+             )
+             if f"{original_block_prefix}double_blocks.{i}.txt_mlp.2.{lora_key}.bias" in original_state_dict_keys:
+                 converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.bias"] = original_state_dict.pop(
+                     f"{original_block_prefix}double_blocks.{i}.txt_mlp.2.{lora_key}.bias"
+                 )
+
+             # output projections.
+             converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.weight"] = original_state_dict.pop(
+                 f"{original_block_prefix}double_blocks.{i}.img_attn.proj.{lora_key}.weight"
+             )
+             if f"{original_block_prefix}double_blocks.{i}.img_attn.proj.{lora_key}.bias" in original_state_dict_keys:
+                 converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.bias"] = original_state_dict.pop(
+                     f"{original_block_prefix}double_blocks.{i}.img_attn.proj.{lora_key}.bias"
+                 )
+             converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.weight"] = original_state_dict.pop(
+                 f"{original_block_prefix}double_blocks.{i}.txt_attn.proj.{lora_key}.weight"
+             )
+             if f"{original_block_prefix}double_blocks.{i}.txt_attn.proj.{lora_key}.bias" in original_state_dict_keys:
+                 converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.bias"] = original_state_dict.pop(
+                     f"{original_block_prefix}double_blocks.{i}.txt_attn.proj.{lora_key}.bias"
+                 )
+
+     # single transformer blocks
+     for i in range(num_single_layers):
+         block_prefix = f"single_transformer_blocks.{i}."
+
+         for lora_key in ["lora_A", "lora_B"]:
+             # norm.linear <- single_blocks.0.modulation.lin
+             converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.weight"] = original_state_dict.pop(
+                 f"{original_block_prefix}single_blocks.{i}.modulation.lin.{lora_key}.weight"
+             )
+             if f"{original_block_prefix}single_blocks.{i}.modulation.lin.{lora_key}.bias" in original_state_dict_keys:
+                 converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.bias"] = original_state_dict.pop(
+                     f"{original_block_prefix}single_blocks.{i}.modulation.lin.{lora_key}.bias"
+                 )
+
+             # Q, K, V, mlp
+             mlp_hidden_dim = int(inner_dim * mlp_ratio)
+             split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
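+             # single_blocks.*.linear1 fuses q, k, v and the MLP input projection: the lora_B
+             # up-projection is split as (3072, 3072, 3072, 12288) along dim 0, while the shared
+             # lora_A down-projection is reused for all four targets.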
+
+             if lora_key == "lora_A":
+                 lora_weight = original_state_dict.pop(
+                     f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.weight"
+                 )
+                 converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([lora_weight])
+                 converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([lora_weight])
+                 converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([lora_weight])
+                 converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([lora_weight])
+
+                 if f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys:
+                     lora_bias = original_state_dict.pop(f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.bias")
+                     converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([lora_bias])
+                     converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([lora_bias])
+                     converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([lora_bias])
+                     converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([lora_bias])
+             else:
+                 q, k, v, mlp = torch.split(
+                     original_state_dict.pop(f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.weight"),
+                     split_size,
+                     dim=0,
+                 )
+                 converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([q])
+                 converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([k])
+                 converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([v])
+                 converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([mlp])
+
+                 if f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys:
+                     q_bias, k_bias, v_bias, mlp_bias = torch.split(
+                         original_state_dict.pop(f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.bias"),
+                         split_size,
+                         dim=0,
+                     )
+                     converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([q_bias])
+                     converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([k_bias])
+                     converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([v_bias])
+                     converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([mlp_bias])
+
+             # output projections.
+             converted_state_dict[f"{block_prefix}proj_out.{lora_key}.weight"] = original_state_dict.pop(
+                 f"{original_block_prefix}single_blocks.{i}.linear2.{lora_key}.weight"
+             )
+             if f"{original_block_prefix}single_blocks.{i}.linear2.{lora_key}.bias" in original_state_dict_keys:
+                 converted_state_dict[f"{block_prefix}proj_out.{lora_key}.bias"] = original_state_dict.pop(
+                     f"{original_block_prefix}single_blocks.{i}.linear2.{lora_key}.bias"
+                 )
+
+     for lora_key in ["lora_A", "lora_B"]:
+         converted_state_dict[f"proj_out.{lora_key}.weight"] = original_state_dict.pop(
+             f"{original_block_prefix}final_layer.linear.{lora_key}.weight"
+         )
+         if f"{original_block_prefix}final_layer.linear.{lora_key}.bias" in original_state_dict_keys:
+             converted_state_dict[f"proj_out.{lora_key}.bias"] = original_state_dict.pop(
+                 f"{original_block_prefix}final_layer.linear.{lora_key}.bias"
+             )
+
+     if len(original_state_dict) > 0:
+         raise ValueError(f"`original_state_dict` should be empty at this point but has {original_state_dict.keys()=}.")
+
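+     # Namespace every converted key under "transformer." so the diffusers LoRA loader routes
+     # the weights to the transformer component.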
+     for key in list(converted_state_dict.keys()):
+         converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
+
+     return converted_state_dict
+
+
def _convert_hunyuan_video_lora_to_diffusers(original_state_dict):
    converted_state_dict = {k: original_state_dict.pop(k) for k in list(original_state_dict.keys())}