Describe the bug
I used the latest diffusers (`0.35.0.dev0`, built from source) to fine-tune FLUX.1-Kontext with LoRA. However, when I attempted to load the saved LoRA weights, the loader warned that `Loading adapter weights from state_dict led to unexpected keys found in the model`.

The same `0.35.0.dev0` build then produced identical warnings when loading the weights from a previous LoRA fine-tune of FLUX.1-dev. After I downgraded diffusers and ran the exact same loading code on those FLUX.1-dev LoRA weights, everything loaded correctly with no warning.
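All of the unexpected keys point at the `proj_out` modules of the single transformer blocks (full list in the logs below). A minimal sketch for inspecting what the saved file actually contains, using the checkpoint path from my run (adjust as needed):

```python
from safetensors import safe_open

path = "/data/runs/20250703-133203/ckpt/160000/pytorch_lora_weights.safetensors"
with safe_open(path, framework="pt") as f:
    keys = list(f.keys())

# The warning in 0.35.0.dev0 flags exactly the proj_out LoRA A/B pairs
proj_out_keys = [k for k in keys if "proj_out" in k]
print(f"{len(keys)} LoRA tensors total, {len(proj_out_keys)} on proj_out modules")
```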
Reproduction
The code to save the LoRA weights and model weights:
```python
import copy
import gc
import os

import lightning as L
import torch
from diffusers import FluxPipeline
from peft import LoraConfig, get_peft_model_state_dict

# Qwen25VL_7b_Embedder, Qwen2Connector, and load_state_dict are
# project-specific modules, omitted here for brevity.


class EditModel(L.LightningModule):
    def __init__(
        self,
        config: dict,
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
    ):
        # Initialize the LightningModule
        super().__init__()
        self.config = config
        self.train_config = config["train"]
        self.model_config = config["model"]
        self.optimizer_config = self.train_config["optimizer"]
        self.lr_config = self.train_config["lr_scheduler"]
        self.lora_config = self.train_config["lora_config"]
        self.use_lora = self.train_config["use_lora"]
        self.lora_path = self.train_config["lora_path"]
        self.use_step = self.train_config["use_step"]
        self.step_path = self.train_config["step_path"]
        self.train_connector_only = self.train_config["train_connector_only"]
        assert not (
            self.use_lora and self.train_connector_only
        ), "Cannot use both LoRA and train_connector_only"

        flux_pipe_id = self.model_config["flux_path"]
        qwen2vl_model_path = self.model_config["qwen2vl_model_path"]
        max_length = self.model_config["max_length"]

        # Load MLLM
        self.qwen2vl_encoder = Qwen25VL_7b_Embedder(
            qwen2vl_model_path,
            device=device,
            max_length=max_length,
            dtype=dtype,
        )
        self.qwen2vl_encoder.requires_grad_(False).eval()

        # Load Connector
        self.connector = Qwen2Connector()
        if self.use_step:
            self.connector = load_state_dict(self.connector, self.step_path, "cpu")
        # self.connector.train()
        self.connector.requires_grad_(False).eval()

        # Load the Flux pipeline
        self.flux_pipe: FluxPipeline = (
            FluxPipeline.from_pretrained(flux_pipe_id).to(dtype=dtype).to(device)
        )
        self.transformer = self.flux_pipe.transformer
        self.transformer.gradient_checkpointing = self.train_config[
            "gradient_checkpointing"
        ]
        self.transformer.train()
        noise_scheduler = self.flux_pipe.scheduler
        self.noise_scheduler_copy = copy.deepcopy(noise_scheduler)

        # Freeze the Flux pipeline
        # self.flux_pipe.text_encoder.requires_grad_(False).eval()
        # self.flux_pipe.text_encoder_2.requires_grad_(False).eval()
        self.flux_pipe.text_encoder = None
        self.flux_pipe.text_encoder_2 = None
        torch.cuda.empty_cache()
        gc.collect()
        self.flux_pipe.vae.requires_grad_(False).eval()

        # Initialize LoRA layers
        if self.use_lora:
            self.lora_layers = self.init_lora(self.lora_path, self.lora_config)

        self.to(device).to(dtype)

    def init_lora(self, lora_path: str, lora_config: dict):
        assert lora_path or lora_config
        if lora_path:
            # TODO: Implement this
            raise NotImplementedError
        else:
            self.transformer.add_adapter(LoraConfig(**lora_config))
            # TODO: Check if this is correct (p.requires_grad);
            # only the freshly injected LoRA parameters require grad here
            lora_layers = filter(
                lambda p: p.requires_grad, self.transformer.parameters()
            )
            return list(lora_layers)

    def save_weights(self, path: str):
        os.makedirs(path, exist_ok=True)
        if self.train_connector_only:
            torch.save(self.connector.state_dict(), f"{path}/connector.pth")
        elif self.use_lora:
            # Writes pytorch_lora_weights.safetensors into `path`
            FluxPipeline.save_lora_weights(
                save_directory=path,
                transformer_lora_layers=get_peft_model_state_dict(self.transformer),
                safe_serialization=True,
            )
            torch.save(self.connector.state_dict(), f"{path}/connector.pth")
        else:
            torch.save(
                {
                    "transformer": self.transformer.state_dict(),
                    "connector": self.connector.state_dict(),
                },
                f"{path}/model.pth",
            )
```
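For context, the checkpoint loaded below was produced by calling `save_weights` on the trained module. The call here is a hypothetical reconstruction of the call site (the actual invocation happens inside my training loop):

```python
# Hypothetical call site: `model` is the trained EditModel from above.
# With use_lora=True this writes pytorch_lora_weights.safetensors and
# connector.pth into the step directory.
model.save_weights("/data/runs/20250703-133203/ckpt/160000")
```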
The code to load the LoRA weights for my fine-tuned FLUX.1-dev model:
```python
import torch
from diffusers.pipelines import FluxPipeline

flux_pipe: FluxPipeline = (
    FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev")
    .to(dtype=torch.bfloat16)
    .to("cuda")
)
flux_pipe.load_lora_weights(
    "/data/runs/20250703-133203/ckpt/160000",
    weight_name="pytorch_lora_weights.safetensors",
)
```
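A quick way to double-check what actually got attached after loading (a sketch; `get_list_adapters` is the diffusers helper that reports the registered adapters per pipeline component):

```python
# Expect something like {"transformer": ["default_0"]} after a successful load
print(flux_pipe.get_list_adapters())
```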
Logs
When using `diffusers==0.32.2`:

```
Loading pipeline components...:  14%|██▊                 | 1/7 [00:00<00:01,  5.55it/s]You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
Loading checkpoint shards: 100%|████████████████████| 2/2 [00:00<00:00,  2.23it/s]
Loading pipeline components...: 100%|████████████████████| 7/7 [00:06<00:00,  1.15it/s]
```
When using `diffusers==0.34.0`:

```
Loading checkpoint shards: 100%|████████████████████| 3/3 [00:02<00:00,  1.03it/s]
Loading pipeline components...:  14%|██▊                 | 1/7 [00:03<00:19,  3.23s/it]You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
Loading checkpoint shards: 100%|████████████████████| 2/2 [00:01<00:00,  1.26it/s]
Loading pipeline components...: 100%|████████████████████| 7/7 [00:05<00:00,  1.19it/s]
No LoRA keys associated to CLIPTextModel found with the prefix='text_encoder'. This is safe to ignore if LoRA state dict didn't originally have any CLIPTextModel related params. You can also try specifying `prefix=None` to resolve the warning. Otherwise, open an issue if you think it's unexpected: https://github.com/huggingface/diffusers/issues/new
```
When using `diffusers==0.35.0.dev0`:

```
Loading checkpoint shards: 100%|████████████████████| 2/2 [00:00<00:00,  2.26it/s]
Loading checkpoint shards: 100%|████████████████████| 3/3 [00:03<00:00,  1.18s/it]
Loading pipeline components...:  29%|█████▋              | 2/7 [00:04<00:13,  2.62s/it]You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
Loading pipeline components...: 100%|████████████████████| 7/7 [00:05<00:00,  1.29it/s]
Loading adapter weights from state_dict led to unexpected keys found in the model: single_transformer_blocks.0.proj_out.lora_A.default_0.weight, single_transformer_blocks.0.proj_out.lora_B.default_0.weight, single_transformer_blocks.1.proj_out.lora_A.default_0.weight, single_transformer_blocks.1.proj_out.lora_B.default_0.weight, single_transformer_blocks.2.proj_out.lora_A.default_0.weight, single_transformer_blocks.2.proj_out.lora_B.default_0.weight, single_transformer_blocks.3.proj_out.lora_A.default_0.weight, single_transformer_blocks.3.proj_out.lora_B.default_0.weight, single_transformer_blocks.4.proj_out.lora_A.default_0.weight, single_transformer_blocks.4.proj_out.lora_B.default_0.weight, single_transformer_blocks.5.proj_out.lora_A.default_0.weight, single_transformer_blocks.5.proj_out.lora_B.default_0.weight,
single_transformer_blocks.6.proj_out.lora_A.default_0.weight, single_transformer_blocks.6.proj_out.lora_B.default_0.weight, single_transformer_blocks.7.proj_out.lora_A.default_0.weight, single_transformer_blocks.7.proj_out.lora_B.default_0.weight, single_transformer_blocks.8.proj_out.lora_A.default_0.weight, single_transformer_blocks.8.proj_out.lora_B.default_0.weight, single_transformer_blocks.9.proj_out.lora_A.default_0.weight, single_transformer_blocks.9.proj_out.lora_B.default_0.weight, single_transformer_blocks.10.proj_out.lora_A.default_0.weight, single_transformer_blocks.10.proj_out.lora_B.default_0.weight, single_transformer_blocks.11.proj_out.lora_A.default_0.weight, single_transformer_blocks.11.proj_out.lora_B.default_0.weight, single_transformer_blocks.12.proj_out.lora_A.default_0.weight, single_transformer_blocks.12.proj_out.lora_B.default_0.weight, single_transformer_blocks.13.proj_out.lora_A.default_0.weight, single_transformer_blocks.13.proj_out.lora_B.default_0.weight, single_transformer_blocks.14.proj_out.lora_A.default_0.weight, single_transformer_blocks.14.proj_out.lora_B.default_0.weight, single_transformer_blocks.15.proj_out.lora_A.default_0.weight, single_transformer_blocks.15.proj_out.lora_B.default_0.weight, single_transformer_blocks.16.proj_out.lora_A.default_0.weight, single_transformer_blocks.16.proj_out.lora_B.default_0.weight, single_transformer_blocks.17.proj_out.lora_A.default_0.weight, single_transformer_blocks.17.proj_out.lora_B.default_0.weight, single_transformer_blocks.18.proj_out.lora_A.default_0.weight, single_transformer_blocks.18.proj_out.lora_B.default_0.weight, single_transformer_blocks.19.proj_out.lora_A.default_0.weight, single_transformer_blocks.19.proj_out.lora_B.default_0.weight, single_transformer_blocks.20.proj_out.lora_A.default_0.weight, single_transformer_blocks.20.proj_out.lora_B.default_0.weight, single_transformer_blocks.21.proj_out.lora_A.default_0.weight, single_transformer_blocks.21.proj_out.lora_B.default_0.weight, single_transformer_blocks.22.proj_out.lora_A.default_0.weight, single_transformer_blocks.22.proj_out.lora_B.default_0.weight, single_transformer_blocks.23.proj_out.lora_A.default_0.weight, single_transformer_blocks.23.proj_out.lora_B.default_0.weight, single_transformer_blocks.24.proj_out.lora_A.default_0.weight, single_transformer_blocks.24.proj_out.lora_B.default_0.weight, single_transformer_blocks.25.proj_out.lora_A.default_0.weight,
single_transformer_blocks.25.proj_out.lora_B.default_0.weight, single_transformer_blocks.26.proj_out.lora_A.default_0.weight, single_transformer_blocks.26.proj_out.lora_B.default_0.weight, single_transformer_blocks.27.proj_out.lora_A.default_0.weight, single_transformer_blocks.27.proj_out.lora_B.default_0.weight, single_transformer_blocks.28.proj_out.lora_A.default_0.weight, single_transformer_blocks.28.proj_out.lora_B.default_0.weight, single_transformer_blocks.29.proj_out.lora_A.default_0.weight, single_transformer_blocks.29.proj_out.lora_B.default_0.weight, single_transformer_blocks.30.proj_out.lora_A.default_0.weight, single_transformer_blocks.30.proj_out.lora_B.default_0.weight, single_transformer_blocks.31.proj_out.lora_A.default_0.weight, single_transformer_blocks.31.proj_out.lora_B.default_0.weight, single_transformer_blocks.32.proj_out.lora_A.default_0.weight, single_transformer_blocks.32.proj_out.lora_B.default_0.weight, single_transformer_blocks.33.proj_out.lora_A.default_0.weight, single_transformer_blocks.33.proj_out.lora_B.default_0.weight, single_transformer_blocks.34.proj_out.lora_A.default_0.weight, single_transformer_blocks.34.proj_out.lora_B.default_0.weight, single_transformer_blocks.35.proj_out.lora_A.default_0.weight, single_transformer_blocks.35.proj_out.lora_B.default_0.weight, single_transformer_blocks.36.proj_out.lora_A.default_0.weight, single_transformer_blocks.36.proj_out.lora_B.default_0.weight, single_transformer_blocks.37.proj_out.lora_A.default_0.weight, single_transformer_blocks.37.proj_out.lora_B.default_0.weight.
No LoRA keys associated to CLIPTextModel found with the prefix='text_encoder'. This is safe to ignore if LoRA state dict didn't originally have any CLIPTextModel related params. You can also try specifying `prefix=None` to resolve the warning. Otherwise, open an issue if you think it's unexpected: https://github.com/huggingface/diffusers/issues/new
```
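To gauge whether the warning is cosmetic or the `proj_out` weights really go unused, one possible check (a sketch, reusing `flux_pipe` from the loading snippet above; the prompt and seed are arbitrary):

```python
import torch

prompt = "a photo of a cat"

# Same seed with the adapter active vs. disabled; if the LoRA weights were
# actually applied, the two images should differ.
gen = torch.Generator("cuda").manual_seed(0)
with_lora = flux_pipe(prompt, generator=gen).images[0]

flux_pipe.disable_lora()  # toggles the loaded LoRA off without unloading it
gen = torch.Generator("cuda").manual_seed(0)
without_lora = flux_pipe(prompt, generator=gen).images[0]
```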
System Info
diffusers: latest build from source (`0.35.0.dev0`); the comparison runs above used `diffusers==0.34.0` and `diffusers==0.32.2`.