Skip to content

Commit ee4ab23

Browse files
[SD3 dreambooth-lora training] small updates + bug fixes (#9682)
* add latent caching + smol updates
* update license
* replace with free_memory
* add --upcast_before_saving to allow saving transformer weights in lower precision
* fix models to accumulate
* fix mixed precision issue as proposed in #9565
* smol update to readme
* style
* fix caching latents
* style
* add tests for latent caching
* style
* fix latent caching

---------

Co-authored-by: Sayak Paul <[email protected]>
1 parent cef4f65 commit ee4ab23

File tree

4 files changed

+107
-24
lines changed

4 files changed

+107
-24
lines changed

examples/dreambooth/README_sd3.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ accelerate launch train_dreambooth_lora_sd3.py \
136136
--resolution=512 \
137137
--train_batch_size=1 \
138138
--gradient_accumulation_steps=4 \
139-
--learning_rate=1e-5 \
139+
--learning_rate=4e-4 \
140140
--report_to="wandb" \
141141
--lr_scheduler="constant" \
142142
--lr_warmup_steps=0 \

examples/dreambooth/test_dreambooth_lora_sd3.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,39 @@ def test_dreambooth_lora_text_encoder_sd3(self):
103103
)
104104
self.assertTrue(starts_with_expected_prefix)
105105

106+
def test_dreambooth_lora_latent_caching(self):
107+
with tempfile.TemporaryDirectory() as tmpdir:
108+
test_args = f"""
109+
{self.script_path}
110+
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
111+
--instance_data_dir {self.instance_data_dir}
112+
--instance_prompt {self.instance_prompt}
113+
--resolution 64
114+
--train_batch_size 1
115+
--gradient_accumulation_steps 1
116+
--max_train_steps 2
117+
--cache_latents
118+
--learning_rate 5.0e-04
119+
--scale_lr
120+
--lr_scheduler constant
121+
--lr_warmup_steps 0
122+
--output_dir {tmpdir}
123+
""".split()
124+
125+
run_command(self._launch_args + test_args)
126+
# save_pretrained smoke test
127+
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
128+
129+
# make sure the state_dict has the correct naming in the parameters.
130+
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
131+
is_lora = all("lora" in k for k in lora_state_dict.keys())
132+
self.assertTrue(is_lora)
133+
134+
# when not training the text encoder, all the parameters in the state dict should start
135+
# with `"transformer"` in their names.
136+
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
137+
self.assertTrue(starts_with_transformer)
138+
106139
def test_dreambooth_lora_sd3_checkpointing_checkpoints_total_limit(self):
107140
with tempfile.TemporaryDirectory() as tmpdir:
108141
test_args = f"""

examples/dreambooth/train_dreambooth_lora_sd3.py

Lines changed: 63 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def save_model_card(
140140
model_card = load_or_create_model_card(
141141
repo_id_or_path=repo_id,
142142
from_training=True,
143-
license="openrail++",
143+
license="other",
144144
base_model=base_model,
145145
prompt=instance_prompt,
146146
model_description=model_description,
@@ -186,7 +186,7 @@ def log_validation(
186186
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
187187
f" {args.validation_prompt}."
188188
)
189-
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
189+
pipeline = pipeline.to(accelerator.device)
190190
pipeline.set_progress_bar_config(disable=True)
191191

192192
# run inference
@@ -608,6 +608,12 @@ def parse_args(input_args=None):
608608
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
609609
),
610610
)
611+
parser.add_argument(
612+
"--cache_latents",
613+
action="store_true",
614+
default=False,
615+
help="Cache the VAE latents",
616+
)
611617
parser.add_argument(
612618
"--report_to",
613619
type=str,
@@ -628,6 +634,15 @@ def parse_args(input_args=None):
628634
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
629635
),
630636
)
637+
parser.add_argument(
638+
"--upcast_before_saving",
639+
action="store_true",
640+
default=False,
641+
help=(
642+
"Whether to upcast the trained transformer layers to float32 before saving (at the end of training). "
643+
"Defaults to precision dtype used for training to save memory"
644+
),
645+
)
631646
parser.add_argument(
632647
"--prior_generation_precision",
633648
type=str,
@@ -1394,6 +1409,16 @@ def load_model_hook(models, input_dir):
13941409
logger.warning(
13951410
"Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
13961411
)
1412+
if args.train_text_encoder and args.text_encoder_lr:
1413+
logger.warning(
1414+
f"Learning rates were provided both for the transformer and the text encoder- e.g. text_encoder_lr:"
1415+
f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
1416+
f"When using prodigy only learning_rate is used as the initial learning rate."
1417+
)
1418+
# changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be
1419+
# --learning_rate
1420+
params_to_optimize[1]["lr"] = args.learning_rate
1421+
params_to_optimize[2]["lr"] = args.learning_rate
13971422

13981423
optimizer = optimizer_class(
13991424
params_to_optimize,
@@ -1440,6 +1465,9 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
14401465
pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
14411466
return prompt_embeds, pooled_prompt_embeds
14421467

1468+
# If no type of tuning is done on the text_encoder and custom instance prompts are NOT
1469+
# provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
1470+
# the redundant encoding.
14431471
if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
14441472
instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings(
14451473
args.instance_prompt, text_encoders, tokenizers
@@ -1484,6 +1512,21 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
14841512
tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
14851513
tokens_three = torch.cat([tokens_three, class_tokens_three], dim=0)
14861514

1515+
vae_config_shift_factor = vae.config.shift_factor
1516+
vae_config_scaling_factor = vae.config.scaling_factor
1517+
if args.cache_latents:
1518+
latents_cache = []
1519+
for batch in tqdm(train_dataloader, desc="Caching latents"):
1520+
with torch.no_grad():
1521+
batch["pixel_values"] = batch["pixel_values"].to(
1522+
accelerator.device, non_blocking=True, dtype=weight_dtype
1523+
)
1524+
latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
1525+
1526+
if args.validation_prompt is None:
1527+
del vae
1528+
free_memory()
1529+
14871530
# Scheduler and math around the number of training steps.
14881531
overrode_max_train_steps = False
14891532
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -1500,7 +1543,6 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
15001543
power=args.lr_power,
15011544
)
15021545

1503-
# Prepare everything with our `accelerator`.
15041546
# Prepare everything with our `accelerator`.
15051547
if args.train_text_encoder:
15061548
(
@@ -1607,8 +1649,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
16071649

16081650
for step, batch in enumerate(train_dataloader):
16091651
models_to_accumulate = [transformer]
1652+
if args.train_text_encoder:
1653+
models_to_accumulate.extend([text_encoder_one, text_encoder_two])
16101654
with accelerator.accumulate(models_to_accumulate):
1611-
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
16121655
prompts = batch["prompts"]
16131656

16141657
# encode batch prompts when custom prompts are provided for each image -
@@ -1639,8 +1682,13 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
16391682
)
16401683

16411684
# Convert images to latent space
1642-
model_input = vae.encode(pixel_values).latent_dist.sample()
1643-
model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor
1685+
if args.cache_latents:
1686+
model_input = latents_cache[step].sample()
1687+
else:
1688+
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
1689+
model_input = vae.encode(pixel_values).latent_dist.sample()
1690+
1691+
model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
16441692
model_input = model_input.to(dtype=weight_dtype)
16451693

16461694
# Sample noise that we'll add to the latents
@@ -1773,6 +1821,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
17731821
text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders(
17741822
text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three
17751823
)
1824+
text_encoder_one.to(weight_dtype)
1825+
text_encoder_two.to(weight_dtype)
17761826
pipeline = StableDiffusion3Pipeline.from_pretrained(
17771827
args.pretrained_model_name_or_path,
17781828
vae=vae,
@@ -1793,15 +1843,18 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
17931843
epoch=epoch,
17941844
torch_dtype=weight_dtype,
17951845
)
1796-
1797-
del text_encoder_one, text_encoder_two, text_encoder_three
1798-
free_memory()
1846+
if not args.train_text_encoder:
1847+
del text_encoder_one, text_encoder_two, text_encoder_three
1848+
free_memory()
17991849

18001850
# Save the lora layers
18011851
accelerator.wait_for_everyone()
18021852
if accelerator.is_main_process:
18031853
transformer = unwrap_model(transformer)
1804-
transformer = transformer.to(torch.float32)
1854+
if args.upcast_before_saving:
1855+
transformer.to(torch.float32)
1856+
else:
1857+
transformer = transformer.to(weight_dtype)
18051858
transformer_lora_layers = get_peft_model_state_dict(transformer)
18061859

18071860
if args.train_text_encoder:

examples/dreambooth/train_dreambooth_sd3.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
import argparse
1717
import copy
18-
import gc
1918
import itertools
2019
import logging
2120
import math
@@ -51,7 +50,7 @@
5150
StableDiffusion3Pipeline,
5251
)
5352
from diffusers.optimization import get_scheduler
54-
from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3
53+
from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, free_memory
5554
from diffusers.utils import (
5655
check_min_version,
5756
is_wandb_available,
@@ -119,7 +118,7 @@ def save_model_card(
119118
model_card = load_or_create_model_card(
120119
repo_id_or_path=repo_id,
121120
from_training=True,
122-
license="openrail++",
121+
license="other",
123122
base_model=base_model,
124123
prompt=instance_prompt,
125124
model_description=model_description,
@@ -164,7 +163,7 @@ def log_validation(
164163
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
165164
f" {args.validation_prompt}."
166165
)
167-
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
166+
pipeline = pipeline.to(accelerator.device)
168167
pipeline.set_progress_bar_config(disable=True)
169168

170169
# run inference
@@ -190,8 +189,7 @@ def log_validation(
190189
)
191190

192191
del pipeline
193-
if torch.cuda.is_available():
194-
torch.cuda.empty_cache()
192+
free_memory()
195193

196194
return images
197195

@@ -1065,8 +1063,7 @@ def main(args):
10651063
image.save(image_filename)
10661064

10671065
del pipeline
1068-
if torch.cuda.is_available():
1069-
torch.cuda.empty_cache()
1066+
free_memory()
10701067

10711068
# Handle the repository creation
10721069
if accelerator.is_main_process:
@@ -1386,9 +1383,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
13861383
del tokenizers, text_encoders
13871384
# Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection
13881385
del text_encoder_one, text_encoder_two, text_encoder_three
1389-
gc.collect()
1390-
if torch.cuda.is_available():
1391-
torch.cuda.empty_cache()
1386+
free_memory()
13921387

13931388
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
13941389
# pack the statically computed variables appropriately here. This is so that we don't
@@ -1708,6 +1703,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
17081703
text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders(
17091704
text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three
17101705
)
1706+
text_encoder_one.to(weight_dtype)
1707+
text_encoder_two.to(weight_dtype)
1708+
text_encoder_three.to(weight_dtype)
17111709
pipeline = StableDiffusion3Pipeline.from_pretrained(
17121710
args.pretrained_model_name_or_path,
17131711
vae=vae,
@@ -1730,8 +1728,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
17301728
)
17311729
if not args.train_text_encoder:
17321730
del text_encoder_one, text_encoder_two, text_encoder_three
1733-
torch.cuda.empty_cache()
1734-
gc.collect()
1731+
free_memory()
17351732

17361733
# Save the lora layers
17371734
accelerator.wait_for_everyone()

0 commit comments

Comments
 (0)