
Commit f864a9a

[Flux Kontext] Support Fal Kontext LoRA (#11823)
* initial commit
* initial commit
* initial commit
* fix import
* fix prefix
* remove print
* Apply style fixes

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent d6fa329 commit f864a9a

File tree: 2 files changed, +234 −0 lines

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 222 additions & 0 deletions
@@ -1346,6 +1346,228 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
     return converted_state_dict
 
 
+def _convert_fal_kontext_lora_to_diffusers(original_state_dict):
+    converted_state_dict = {}
+    original_state_dict_keys = list(original_state_dict.keys())
+    num_layers = 19
+    num_single_layers = 38
+    inner_dim = 3072
+    mlp_ratio = 4.0
+
+    # double transformer blocks
+    for i in range(num_layers):
+        block_prefix = f"transformer_blocks.{i}."
+        original_block_prefix = "base_model.model."
+
+        for lora_key in ["lora_A", "lora_B"]:
+            # norms
+            converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.weight"] = original_state_dict.pop(
+                f"{original_block_prefix}double_blocks.{i}.img_mod.lin.{lora_key}.weight"
+            )
+            if f"{original_block_prefix}double_blocks.{i}.img_mod.lin.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.bias"] = original_state_dict.pop(
+                    f"{original_block_prefix}double_blocks.{i}.img_mod.lin.{lora_key}.bias"
+                )
+
+            converted_state_dict[f"{block_prefix}norm1_context.linear.{lora_key}.weight"] = original_state_dict.pop(
+                f"{original_block_prefix}double_blocks.{i}.txt_mod.lin.{lora_key}.weight"
+            )
+
+            # Q, K, V
+            if lora_key == "lora_A":
+                sample_lora_weight = original_state_dict.pop(
+                    f"{original_block_prefix}double_blocks.{i}.img_attn.qkv.{lora_key}.weight"
+                )
+                converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_lora_weight])
+                converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_lora_weight])
+                converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_lora_weight])
+
+                context_lora_weight = original_state_dict.pop(
+                    f"{original_block_prefix}double_blocks.{i}.txt_attn.qkv.{lora_key}.weight"
+                )
+                converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat(
+                    [context_lora_weight]
+                )
+                converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat(
+                    [context_lora_weight]
+                )
+                converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat(
+                    [context_lora_weight]
+                )
+            else:
+                sample_q, sample_k, sample_v = torch.chunk(
+                    original_state_dict.pop(
+                        f"{original_block_prefix}double_blocks.{i}.img_attn.qkv.{lora_key}.weight"
+                    ),
+                    3,
+                    dim=0,
+                )
+                converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_q])
+                converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_k])
+                converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_v])
+
+                context_q, context_k, context_v = torch.chunk(
+                    original_state_dict.pop(
+                        f"{original_block_prefix}double_blocks.{i}.txt_attn.qkv.{lora_key}.weight"
+                    ),
+                    3,
+                    dim=0,
+                )
+                converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat([context_q])
+                converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat([context_k])
+                converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat([context_v])
+
+            if f"{original_block_prefix}double_blocks.{i}.img_attn.qkv.{lora_key}.bias" in original_state_dict_keys:
+                sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
+                    original_state_dict.pop(f"{original_block_prefix}double_blocks.{i}.img_attn.qkv.{lora_key}.bias"),
+                    3,
+                    dim=0,
+                )
+                converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([sample_q_bias])
+                converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([sample_k_bias])
+                converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([sample_v_bias])
+
+            if f"{original_block_prefix}double_blocks.{i}.txt_attn.qkv.{lora_key}.bias" in original_state_dict_keys:
+                context_q_bias, context_k_bias, context_v_bias = torch.chunk(
+                    original_state_dict.pop(f"{original_block_prefix}double_blocks.{i}.txt_attn.qkv.{lora_key}.bias"),
+                    3,
+                    dim=0,
+                )
+                converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.bias"] = torch.cat([context_q_bias])
+                converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.bias"] = torch.cat([context_k_bias])
+                converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.bias"] = torch.cat([context_v_bias])
+
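A note on the Q/K/V handling above: `lora_A` of a fused QKV projection can be copied unchanged to `to_q`/`to_k`/`to_v` because the fused layer's low-rank update is `B @ A`; chunking the rows of `lora_B` chunks the update without touching `A`. A minimal sketch, assuming toy shapes (not the real 3072-dim model):

import torch

r, d_in, d_out = 4, 8, 6       # toy LoRA rank, input dim, per-projection output dim
A = torch.randn(r, d_in)       # lora_A: shared by Q, K, and V
B = torch.randn(3 * d_out, r)  # lora_B: fused rows laid out as [Q; K; V]

fused_delta = B @ A                                 # the fused layer's low-rank update
B_q, B_k, B_v = torch.chunk(B, 3, dim=0)            # what the converter does for lora_B
d_q, d_k, d_v = torch.chunk(fused_delta, 3, dim=0)  # per-projection updates

assert torch.allclose(d_q, B_q @ A)  # to_q receives (A, B_q)
assert torch.allclose(d_k, B_k @ A)  # to_k receives (A, B_k)
assert torch.allclose(d_v, B_v @ A)  # to_v receives (A, B_v)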
+            # ff img_mlp
+            converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.weight"] = original_state_dict.pop(
+                f"{original_block_prefix}double_blocks.{i}.img_mlp.0.{lora_key}.weight"
+            )
+            if f"{original_block_prefix}double_blocks.{i}.img_mlp.0.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.bias"] = original_state_dict.pop(
+                    f"{original_block_prefix}double_blocks.{i}.img_mlp.0.{lora_key}.bias"
+                )
+
+            converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.weight"] = original_state_dict.pop(
+                f"{original_block_prefix}double_blocks.{i}.img_mlp.2.{lora_key}.weight"
+            )
+            if f"{original_block_prefix}double_blocks.{i}.img_mlp.2.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.bias"] = original_state_dict.pop(
+                    f"{original_block_prefix}double_blocks.{i}.img_mlp.2.{lora_key}.bias"
+                )
+
+            converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.weight"] = original_state_dict.pop(
+                f"{original_block_prefix}double_blocks.{i}.txt_mlp.0.{lora_key}.weight"
+            )
+            if f"{original_block_prefix}double_blocks.{i}.txt_mlp.0.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.bias"] = original_state_dict.pop(
+                    f"{original_block_prefix}double_blocks.{i}.txt_mlp.0.{lora_key}.bias"
+                )
+
+            converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.weight"] = original_state_dict.pop(
+                f"{original_block_prefix}double_blocks.{i}.txt_mlp.2.{lora_key}.weight"
+            )
+            if f"{original_block_prefix}double_blocks.{i}.txt_mlp.2.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.bias"] = original_state_dict.pop(
+                    f"{original_block_prefix}double_blocks.{i}.txt_mlp.2.{lora_key}.bias"
+                )
+
+            # output projections.
+            converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.weight"] = original_state_dict.pop(
+                f"{original_block_prefix}double_blocks.{i}.img_attn.proj.{lora_key}.weight"
+            )
+            if f"{original_block_prefix}double_blocks.{i}.img_attn.proj.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.bias"] = original_state_dict.pop(
+                    f"{original_block_prefix}double_blocks.{i}.img_attn.proj.{lora_key}.bias"
+                )
+            converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.weight"] = original_state_dict.pop(
+                f"{original_block_prefix}double_blocks.{i}.txt_attn.proj.{lora_key}.weight"
+            )
+            if f"{original_block_prefix}double_blocks.{i}.txt_attn.proj.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.bias"] = original_state_dict.pop(
+                    f"{original_block_prefix}double_blocks.{i}.txt_attn.proj.{lora_key}.bias"
+                )
+
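For reference, a sketch of the double-block key remapping the code above implements (illustrative subset as a Python dict; `{i}` stands for the block index, fal keys additionally carry the `base_model.model.` prefix, and each entry applies to the `.lora_A`/`.lora_B` weight pairs):

DOUBLE_BLOCK_MAP = {
    "double_blocks.{i}.img_mod.lin":   "transformer_blocks.{i}.norm1.linear",
    "double_blocks.{i}.txt_mod.lin":   "transformer_blocks.{i}.norm1_context.linear",
    "double_blocks.{i}.img_attn.qkv":  "transformer_blocks.{i}.attn.to_{q,k,v}",
    "double_blocks.{i}.txt_attn.qkv":  "transformer_blocks.{i}.attn.add_{q,k,v}_proj",
    "double_blocks.{i}.img_mlp.0":     "transformer_blocks.{i}.ff.net.0.proj",
    "double_blocks.{i}.img_mlp.2":     "transformer_blocks.{i}.ff.net.2",
    "double_blocks.{i}.txt_mlp.0":     "transformer_blocks.{i}.ff_context.net.0.proj",
    "double_blocks.{i}.txt_mlp.2":     "transformer_blocks.{i}.ff_context.net.2",
    "double_blocks.{i}.img_attn.proj": "transformer_blocks.{i}.attn.to_out.0",
    "double_blocks.{i}.txt_attn.proj": "transformer_blocks.{i}.attn.to_add_out",
}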
+    # single transformer blocks
+    for i in range(num_single_layers):
+        block_prefix = f"single_transformer_blocks.{i}."
+
+        for lora_key in ["lora_A", "lora_B"]:
+            # norm.linear <- single_blocks.0.modulation.lin
+            converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.weight"] = original_state_dict.pop(
+                f"{original_block_prefix}single_blocks.{i}.modulation.lin.{lora_key}.weight"
+            )
+            if f"{original_block_prefix}single_blocks.{i}.modulation.lin.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.bias"] = original_state_dict.pop(
+                    f"{original_block_prefix}single_blocks.{i}.modulation.lin.{lora_key}.bias"
+                )
+
+            # Q, K, V, mlp
+            mlp_hidden_dim = int(inner_dim * mlp_ratio)
+            split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
+
+            if lora_key == "lora_A":
+                lora_weight = original_state_dict.pop(
+                    f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.weight"
+                )
+                converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([lora_weight])
+                converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([lora_weight])
+                converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([lora_weight])
+                converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([lora_weight])
+
+                if f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys:
+                    lora_bias = original_state_dict.pop(
+                        f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.bias"
+                    )
+                    converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([lora_bias])
+                    converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([lora_bias])
+                    converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([lora_bias])
+                    converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([lora_bias])
+            else:
+                q, k, v, mlp = torch.split(
+                    original_state_dict.pop(f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.weight"),
+                    split_size,
+                    dim=0,
+                )
+                converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([q])
+                converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([k])
+                converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([v])
+                converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([mlp])
+
+                if f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys:
+                    q_bias, k_bias, v_bias, mlp_bias = torch.split(
+                        original_state_dict.pop(f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.bias"),
+                        split_size,
+                        dim=0,
+                    )
+                    converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([q_bias])
+                    converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([k_bias])
+                    converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([v_bias])
+                    converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([mlp_bias])
+
+            # output projections.
+            converted_state_dict[f"{block_prefix}proj_out.{lora_key}.weight"] = original_state_dict.pop(
+                f"{original_block_prefix}single_blocks.{i}.linear2.{lora_key}.weight"
+            )
+            if f"{original_block_prefix}single_blocks.{i}.linear2.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}proj_out.{lora_key}.bias"] = original_state_dict.pop(
+                    f"{original_block_prefix}single_blocks.{i}.linear2.{lora_key}.bias"
+                )
+
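The single blocks fuse Q, K, V and the MLP input projection into one `linear1`, hence the four-way `torch.split` with sizes `(3072, 3072, 3072, 12288)`. A toy check of the row layout (the rank is an arbitrary placeholder):

import torch

inner_dim, mlp_ratio, rank = 3072, 4.0, 2
split_size = (inner_dim, inner_dim, inner_dim, int(inner_dim * mlp_ratio))
lora_B = torch.randn(sum(split_size), rank)  # 21504 rows, as stored in the checkpoint

q, k, v, mlp = torch.split(lora_B, split_size, dim=0)
assert q.shape == (3072, rank) and mlp.shape == (12288, rank)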
+    for lora_key in ["lora_A", "lora_B"]:
+        converted_state_dict[f"proj_out.{lora_key}.weight"] = original_state_dict.pop(
+            f"{original_block_prefix}final_layer.linear.{lora_key}.weight"
+        )
+        if f"{original_block_prefix}final_layer.linear.{lora_key}.bias" in original_state_dict_keys:
+            converted_state_dict[f"proj_out.{lora_key}.bias"] = original_state_dict.pop(
+                f"{original_block_prefix}final_layer.linear.{lora_key}.bias"
+            )
+
+    if len(original_state_dict) > 0:
+        raise ValueError(f"`original_state_dict` should be empty at this point but has {original_state_dict.keys()=}.")
+
+    for key in list(converted_state_dict.keys()):
+        converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
+
+    return converted_state_dict
+
+
 def _convert_hunyuan_video_lora_to_diffusers(original_state_dict):
     converted_state_dict = {k: original_state_dict.pop(k) for k in list(original_state_dict.keys())}
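The last loop namespaces every converted key under `transformer.` so the loader can route the weights to the pipeline's transformer component. A toy illustration of the rename:

converted = {"transformer_blocks.0.attn.to_q.lora_A.weight": "w"}
for key in list(converted.keys()):
    converted[f"transformer.{key}"] = converted.pop(key)
assert list(converted) == ["transformer.transformer_blocks.0.attn.to_q.lora_A.weight"]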

src/diffusers/loaders/lora_pipeline.py

Lines changed: 12 additions & 0 deletions
@@ -41,6 +41,7 @@
 )
 from .lora_conversion_utils import (
     _convert_bfl_flux_control_lora_to_diffusers,
+    _convert_fal_kontext_lora_to_diffusers,
     _convert_hunyuan_video_lora_to_diffusers,
     _convert_kohya_flux_lora_to_diffusers,
     _convert_musubi_wan_lora_to_diffusers,
@@ -2062,6 +2063,17 @@ def lora_state_dict(
                 return_metadata=return_lora_metadata,
             )
 
+        is_fal_kontext = any("base_model" in k for k in state_dict)
+        if is_fal_kontext:
+            state_dict = _convert_fal_kontext_lora_to_diffusers(state_dict)
+            return cls._prepare_outputs(
+                state_dict,
+                metadata=metadata,
+                alphas=None,
+                return_alphas=return_alphas,
+                return_metadata=return_lora_metadata,
+            )
+
         # For state dicts like
         # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA
         keys = list(state_dict.keys())
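With this hook in place, `lora_state_dict` detects a fal-trained Kontext LoRA by the `base_model` prefix in its keys and converts it before any of the other Flux converters run. A hypothetical usage sketch, assuming the `FluxKontextPipeline` shipped with Kontext support; the LoRA repo id is a placeholder, not a real checkpoint:

import torch
from diffusers import FluxKontextPipeline

pipe = FluxKontextPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights("your-org/your-fal-kontext-lora")  # placeholder repo id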
