diff --git a/examples/dreambooth/train_dreambooth_lora_flux2.py b/examples/dreambooth/train_dreambooth_lora_flux2.py index 6bba0b94b1b2..317ed2c2b2e1 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2.py @@ -1228,7 +1228,7 @@ def main(args): else {"device": accelerator.device, "dtype": weight_dtype} ) - is_fsdp = accelerator.state.fsdp_plugin is not None + is_fsdp = getattr(accelerator.state, "fsdp_plugin", None) is not None if not is_fsdp: transformer.to(**transformer_to_kwargs) diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py index 5f6a1afb410c..419821e8a8e5 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py @@ -1178,7 +1178,7 @@ def main(args): else {"device": accelerator.device, "dtype": weight_dtype} ) - is_fsdp = accelerator.state.fsdp_plugin is not None + is_fsdp = getattr(accelerator.state, "fsdp_plugin", None) is not None if not is_fsdp: transformer.to(**transformer_to_kwargs)