diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index cf0a1588f39b..8fb749d328c9 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -1929,6 +1929,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers, clip_skip): if args.cache_latents: latents_cache = [] + # Store vae config before potential deletion + vae_scaling_factor = vae.config.scaling_factor for batch in tqdm(train_dataloader, desc="Caching latents"): with torch.no_grad(): batch["pixel_values"] = batch["pixel_values"].to( @@ -1940,6 +1942,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers, clip_skip): del vae if torch.cuda.is_available(): torch.cuda.empty_cache() + else: + vae_scaling_factor = vae.config.scaling_factor # Scheduler and math around the number of training steps. # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation. @@ -2109,13 +2113,13 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): model_input = vae.encode(pixel_values).latent_dist.sample() if latents_mean is None and latents_std is None: - model_input = model_input * vae.config.scaling_factor + model_input = model_input * vae_scaling_factor if args.pretrained_vae_model_name_or_path is None: model_input = model_input.to(weight_dtype) else: latents_mean = latents_mean.to(device=model_input.device, dtype=model_input.dtype) latents_std = latents_std.to(device=model_input.device, dtype=model_input.dtype) - model_input = (model_input - latents_mean) * vae.config.scaling_factor / latents_std + model_input = (model_input - latents_mean) * vae_scaling_factor / latents_std model_input = model_input.to(dtype=weight_dtype) # Sample noise that we'll add to the latents