
Commit cef4f65

sayakpaul, saqlain2204, asomoza, hlky, and yiyixuxu authored
[LoRA] log a warning when there are missing keys in the LoRA loading. (#9622)
* log a warning when there are missing keys in the LoRA loading.
* handle missing keys and unexpected keys better.
* add tests
* fix-copies.
* updates
* tests
* concat warning.
* Add Differential Diffusion to Kolors (#9423)
  * Added diff diff support for Kolors img2img
  * Fixed relative imports
  * Fixed relative imports
  * Added diff diff support for Kolors
  * Fixed import issues
  * Added map
  * Fixed import issues
  * Fixed naming issues
  * Added diff diff support for Kolors img2img pipeline
  * Removed example docstrings
  * Added map input
  * Updated latents (Co-authored-by: Álvaro Somoza <[email protected]>)
  * Updated `original_with_noise` (Co-authored-by: Álvaro Somoza <[email protected]>)
  * Improved code quality
* FluxMultiControlNetModel (#9647)
* tests
* Update src/diffusers/loaders/lora_pipeline.py (Co-authored-by: YiYi Xu <[email protected]>)
* fix

Co-authored-by: M Saqlain <[email protected]>
Co-authored-by: Álvaro Somoza <[email protected]>
Co-authored-by: hlky <[email protected]>
Co-authored-by: YiYi Xu <[email protected]>
1 parent 29a2c5d commit cef4f65

File tree

4 files changed (+211, -35 lines)


src/diffusers/loaders/lora_pipeline.py

Lines changed: 84 additions & 20 deletions
```diff
@@ -1358,14 +1358,30 @@ def load_lora_into_transformer(
         inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs)
         incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs)

+        warn_msg = ""
         if incompatible_keys is not None:
-            # check only for unexpected keys
+            # Check only for unexpected keys.
             unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
             if unexpected_keys:
-                logger.warning(
-                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
-                    f" {unexpected_keys}. "
-                )
+                lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
+                if lora_unexpected_keys:
+                    warn_msg = (
+                        f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
+                        f" {', '.join(lora_unexpected_keys)}. "
+                    )
+
+            # Filter missing keys specific to the current adapter.
+            missing_keys = getattr(incompatible_keys, "missing_keys", None)
+            if missing_keys:
+                lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
+                if lora_missing_keys:
+                    warn_msg += (
+                        f"Loading adapter weights from state_dict led to missing keys in the model:"
+                        f" {', '.join(lora_missing_keys)}."
+                    )
+
+        if warn_msg:
+            logger.warning(warn_msg)

         # Offload back.
         if is_model_cpu_offload:
@@ -1932,14 +1948,30 @@ def load_lora_into_transformer(
         inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs)
         incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs)

+        warn_msg = ""
         if incompatible_keys is not None:
-            # check only for unexpected keys
+            # Check only for unexpected keys.
             unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
             if unexpected_keys:
-                logger.warning(
-                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
-                    f" {unexpected_keys}. "
-                )
+                lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
+                if lora_unexpected_keys:
+                    warn_msg = (
+                        f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
+                        f" {', '.join(lora_unexpected_keys)}. "
+                    )
+
+            # Filter missing keys specific to the current adapter.
+            missing_keys = getattr(incompatible_keys, "missing_keys", None)
+            if missing_keys:
+                lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
+                if lora_missing_keys:
+                    warn_msg += (
+                        f"Loading adapter weights from state_dict led to missing keys in the model:"
+                        f" {', '.join(lora_missing_keys)}."
+                    )
+
+        if warn_msg:
+            logger.warning(warn_msg)

         # Offload back.
         if is_model_cpu_offload:
@@ -2279,14 +2311,30 @@ def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, ada
         inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
         incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)

+        warn_msg = ""
         if incompatible_keys is not None:
-            # check only for unexpected keys
+            # Check only for unexpected keys.
             unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
             if unexpected_keys:
-                logger.warning(
-                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
-                    f" {unexpected_keys}. "
-                )
+                lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
+                if lora_unexpected_keys:
+                    warn_msg = (
+                        f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
+                        f" {', '.join(lora_unexpected_keys)}. "
+                    )
+
+            # Filter missing keys specific to the current adapter.
+            missing_keys = getattr(incompatible_keys, "missing_keys", None)
+            if missing_keys:
+                lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
+                if lora_missing_keys:
+                    warn_msg += (
+                        f"Loading adapter weights from state_dict led to missing keys in the model:"
+                        f" {', '.join(lora_missing_keys)}."
+                    )
+
+        if warn_msg:
+            logger.warning(warn_msg)

         # Offload back.
         if is_model_cpu_offload:
@@ -2717,14 +2765,30 @@ def load_lora_into_transformer(
         inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs)
         incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs)

+        warn_msg = ""
         if incompatible_keys is not None:
-            # check only for unexpected keys
+            # Check only for unexpected keys.
             unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
             if unexpected_keys:
-                logger.warning(
-                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
-                    f" {unexpected_keys}. "
-                )
+                lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
+                if lora_unexpected_keys:
+                    warn_msg = (
+                        f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
+                        f" {', '.join(lora_unexpected_keys)}. "
+                    )
+
+            # Filter missing keys specific to the current adapter.
+            missing_keys = getattr(incompatible_keys, "missing_keys", None)
+            if missing_keys:
+                lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
+                if lora_missing_keys:
+                    warn_msg += (
+                        f"Loading adapter weights from state_dict led to missing keys in the model:"
+                        f" {', '.join(lora_missing_keys)}."
+                    )
+
+        if warn_msg:
+            logger.warning(warn_msg)

         # Offload back.
         if is_model_cpu_offload:
```
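The core of the change is the adapter-scoped filtering of PEFT's `incompatible_keys` report before warning: only keys that belong to a LoRA layer *and* to the adapter being loaded are surfaced. A minimal sketch of how that filter behaves (the key strings below are hypothetical examples, not taken from a real checkpoint):

```python
# Sketch of the filtering predicate used in the diff above; the key names
# are illustrative of what set_peft_model_state_dict might report.
adapter_name = "default_0"

reported_missing = [
    "blocks.0.attn.to_q.lora_A.default_0.weight",  # LoRA key for this adapter -> warn
    "blocks.0.attn.to_q.base_layer.weight",        # base-model key -> ignored
    "proj_out.lora_B.other_adapter.weight",        # different adapter -> ignored
]

# The same predicate is applied to both missing_keys and unexpected_keys.
lora_missing_keys = [k for k in reported_missing if "lora_" in k and adapter_name in k]

assert lora_missing_keys == ["blocks.0.attn.to_q.lora_A.default_0.weight"]
print(f"Loading adapter weights from state_dict led to missing keys in the model: {', '.join(lora_missing_keys)}.")
```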

src/diffusers/loaders/unet.py

Lines changed: 21 additions & 5 deletions
```diff
@@ -354,14 +354,30 @@ def _process_lora(
         inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
         incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)

+        warn_msg = ""
         if incompatible_keys is not None:
-            # check only for unexpected keys
+            # Check only for unexpected keys.
             unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
             if unexpected_keys:
-                logger.warning(
-                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
-                    f" {unexpected_keys}. "
-                )
+                lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
+                if lora_unexpected_keys:
+                    warn_msg = (
+                        f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
+                        f" {', '.join(lora_unexpected_keys)}. "
+                    )
+
+            # Filter missing keys specific to the current adapter.
+            missing_keys = getattr(incompatible_keys, "missing_keys", None)
+            if missing_keys:
+                lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
+                if lora_missing_keys:
+                    warn_msg += (
+                        f"Loading adapter weights from state_dict led to missing keys in the model:"
+                        f" {', '.join(lora_missing_keys)}."
+                    )
+
+        if warn_msg:
+            logger.warning(warn_msg)

         return is_model_cpu_offload, is_sequential_cpu_offload
```
tests/lora/test_lora_layers_flux.py

Lines changed: 17 additions & 8 deletions
```diff
@@ -27,6 +27,7 @@
 from diffusers.utils.testing_utils import (
     floats_tensor,
     is_peft_available,
+    numpy_cosine_similarity_distance,
     require_peft_backend,
     require_torch_gpu,
     slow,
@@ -166,7 +167,7 @@ def test_modify_padding_mode(self):
 @slow
 @require_torch_gpu
 @require_peft_backend
-@unittest.skip("We cannot run inference on this model with the current CI hardware")
+# @unittest.skip("We cannot run inference on this model with the current CI hardware")
 # TODO (DN6, sayakpaul): move these tests to a beefier GPU
 class FluxLoRAIntegrationTests(unittest.TestCase):
     """internal note: The integration slices were obtained on audace.
@@ -208,9 +209,11 @@ def test_flux_the_last_ben(self):
             generator=torch.manual_seed(self.seed),
         ).images
         out_slice = out[0, -3:, -3:, -1].flatten()
-        expected_slice = np.array([0.1719, 0.1719, 0.1699, 0.1719, 0.1719, 0.1738, 0.1641, 0.1621, 0.2090])
+        expected_slice = np.array([0.1855, 0.1855, 0.1836, 0.1855, 0.1836, 0.1875, 0.1777, 0.1758, 0.2246])

-        assert np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)
+        max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
+
+        assert max_diff < 1e-3

     def test_flux_kohya(self):
         self.pipeline.load_lora_weights("Norod78/brain-slug-flux")
@@ -230,7 +233,9 @@ def test_flux_kohya(self):
         out_slice = out[0, -3:, -3:, -1].flatten()
         expected_slice = np.array([0.6367, 0.6367, 0.6328, 0.6367, 0.6328, 0.6289, 0.6367, 0.6328, 0.6484])

-        assert np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)
+        max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
+
+        assert max_diff < 1e-3

     def test_flux_kohya_with_text_encoder(self):
         self.pipeline.load_lora_weights("cocktailpeanut/optimus", weight_name="optimus.safetensors")
@@ -248,9 +253,11 @@ def test_flux_kohya_with_text_encoder(self):
         ).images

         out_slice = out[0, -3:, -3:, -1].flatten()
-        expected_slice = np.array([0.4023, 0.4043, 0.4023, 0.3965, 0.3984, 0.3984, 0.3906, 0.3906, 0.4219])
+        expected_slice = np.array([0.4023, 0.4023, 0.4023, 0.3965, 0.3984, 0.3965, 0.3926, 0.3906, 0.4219])

-        assert np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)
+        max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
+
+        assert max_diff < 1e-3

     def test_flux_xlabs(self):
         self.pipeline.load_lora_weights("XLabs-AI/flux-lora-collection", weight_name="disney_lora.safetensors")
@@ -268,6 +275,8 @@ def test_flux_xlabs(self):
             generator=torch.manual_seed(self.seed),
         ).images
         out_slice = out[0, -3:, -3:, -1].flatten()
-        expected_slice = np.array([0.3984, 0.4199, 0.4453, 0.4102, 0.4375, 0.4590, 0.4141, 0.4355, 0.4980])
+        expected_slice = np.array([0.3965, 0.4180, 0.4434, 0.4082, 0.4375, 0.4590, 0.4141, 0.4375, 0.4980])
+
+        max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)

-        assert np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)
+        assert max_diff < 1e-3
```
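These tests swap the strict elementwise `np.allclose` check for `numpy_cosine_similarity_distance`, which tolerates the small per-element drift seen across GPU generations while still catching real regressions. A sketch of what such a helper computes (this is an assumption about its behavior, not the exact `diffusers.utils.testing_utils` source):

```python
import numpy as np

def cosine_similarity_distance(a: np.ndarray, b: np.ndarray) -> float:
    # 1 - cosine similarity of the flattened arrays; small values mean the two
    # outputs point in nearly the same direction, which is more robust to
    # hardware-dependent per-element jitter than np.allclose with a tight atol.
    similarity = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
    return float(1.0 - similarity)

out_slice = np.array([0.3965, 0.4180, 0.4434, 0.4082, 0.4375, 0.4590, 0.4141, 0.4375, 0.4980])
expected = out_slice + 1e-3  # simulate small hardware-dependent drift
assert cosine_similarity_distance(expected, out_slice) < 1e-3
```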

tests/lora/utils.py

Lines changed: 89 additions & 2 deletions
```diff
@@ -27,8 +27,10 @@
     LCMScheduler,
     UNet2DConditionModel,
 )
+from diffusers.utils import logging
 from diffusers.utils.import_utils import is_peft_available
 from diffusers.utils.testing_utils import (
+    CaptureLogger,
     floats_tensor,
     require_peft_backend,
     require_peft_version_greater,
@@ -219,10 +221,18 @@ def _get_modules_to_save(self, pipe, has_denoiser=False):
         modules_to_save = {}
         lora_loadable_modules = self.pipeline_class._lora_loadable_modules

-        if "text_encoder" in lora_loadable_modules and hasattr(pipe, "text_encoder"):
+        if (
+            "text_encoder" in lora_loadable_modules
+            and hasattr(pipe, "text_encoder")
+            and getattr(pipe.text_encoder, "peft_config", None) is not None
+        ):
             modules_to_save["text_encoder"] = pipe.text_encoder

-        if "text_encoder_2" in lora_loadable_modules and hasattr(pipe, "text_encoder_2"):
+        if (
+            "text_encoder_2" in lora_loadable_modules
+            and hasattr(pipe, "text_encoder_2")
+            and getattr(pipe.text_encoder_2, "peft_config", None) is not None
+        ):
             modules_to_save["text_encoder_2"] = pipe.text_encoder_2

         if has_denoiser:
@@ -1747,6 +1757,83 @@ def test_simple_inference_with_dora(self):
             "DoRA lora should change the output",
         )

+    def test_missing_keys_warning(self):
+        scheduler_cls = self.scheduler_classes[0]
+        # Skip text encoder check for now as that is handled with `transformers`.
+        components, _, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+        denoiser.add_adapter(denoiser_lora_config)
+        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+            self.pipeline_class.save_lora_weights(
+                save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
+            )
+            pipe.unload_lora_weights()
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+            state_dict = torch.load(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), weights_only=True)
+
+        # To make things dynamic since we cannot settle with a single key for all the models where we
+        # offer PEFT support.
+        missing_key = [k for k in state_dict if "lora_A" in k][0]
+        del state_dict[missing_key]
+
+        logger = (
+            logging.get_logger("diffusers.loaders.unet")
+            if self.unet_kwargs is not None
+            else logging.get_logger("diffusers.loaders.lora_pipeline")
+        )
+        logger.setLevel(30)
+        with CaptureLogger(logger) as cap_logger:
+            pipe.load_lora_weights(state_dict)
+
+        # Since the missing key won't contain the adapter name ("default_0").
+        # Also strip out the component prefix (such as "unet." from `missing_key`).
+        component = list({k.split(".")[0] for k in state_dict})[0]
+        self.assertTrue(missing_key.replace(f"{component}.", "") in cap_logger.out.replace("default_0.", ""))
+
+    def test_unexpected_keys_warning(self):
+        scheduler_cls = self.scheduler_classes[0]
+        # Skip text encoder check for now as that is handled with `transformers`.
+        components, _, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+        denoiser.add_adapter(denoiser_lora_config)
+        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+            self.pipeline_class.save_lora_weights(
+                save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
+            )
+            pipe.unload_lora_weights()
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+            state_dict = torch.load(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), weights_only=True)
+
+        unexpected_key = [k for k in state_dict if "lora_A" in k][0] + ".diffusers_cat"
+        state_dict[unexpected_key] = torch.tensor(1.0, device=torch_device)
+
+        logger = (
+            logging.get_logger("diffusers.loaders.unet")
+            if self.unet_kwargs is not None
+            else logging.get_logger("diffusers.loaders.lora_pipeline")
+        )
+        logger.setLevel(30)
+        with CaptureLogger(logger) as cap_logger:
+            pipe.load_lora_weights(state_dict)
+
+        self.assertTrue(".diffusers_cat" in cap_logger.out)
+
     @unittest.skip("This is failing for now - need to investigate")
     def test_simple_inference_with_text_denoiser_lora_unfused_torch_compile(self):
         """
```
