@@ -140,7 +140,7 @@ def save_model_card(
     model_card = load_or_create_model_card(
         repo_id_or_path=repo_id,
         from_training=True,
-        license="openrail++",
+        license="other",
         base_model=base_model,
         prompt=instance_prompt,
         model_description=model_description,
@@ -186,7 +186,7 @@ def log_validation(
         f"Running validation... \nGenerating {args.num_validation_images} images with prompt:"
         f" {args.validation_prompt}."
     )
-    pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
+    pipeline = pipeline.to(accelerator.device)
     pipeline.set_progress_bar_config(disable=True)

     # run inference
@@ -608,6 +608,12 @@ def parse_args(input_args=None):
             " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
         ),
     )
+    parser.add_argument(
+        "--cache_latents",
+        action="store_true",
+        default=False,
+        help="Cache the VAE latents",
+    )
     parser.add_argument(
         "--report_to",
         type=str,
@@ -628,6 +634,15 @@ def parse_args(input_args=None):
             " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
         ),
     )
+    parser.add_argument(
+        "--upcast_before_saving",
+        action="store_true",
+        default=False,
+        help=(
+            "Whether to upcast the trained transformer layers to float32 before saving (at the end of training). "
+            "Defaults to precision dtype used for training to save memory"
+        ),
+    )
     parser.add_argument(
         "--prior_generation_precision",
         type=str,
@@ -1394,6 +1409,16 @@ def load_model_hook(models, input_dir):
             logger.warning(
                 "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
             )
+        if args.train_text_encoder and args.text_encoder_lr:
+            logger.warning(
+                f"Learning rates were provided both for the transformer and the text encoder - e.g. text_encoder_lr:"
+                f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
+                f"When using prodigy only learning_rate is used as the initial learning rate."
+            )
+            # changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be
+            # --learning_rate
+            params_to_optimize[1]["lr"] = args.learning_rate
+            params_to_optimize[2]["lr"] = args.learning_rate

         optimizer = optimizer_class(
             params_to_optimize,
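
The index-based override above only makes sense given a particular layout of `params_to_optimize`. A minimal sketch of that assumed layout follows; the parameter lists and values are placeholders, not the script's actual variables:

```python
# Assumed layout of `params_to_optimize`: the transformer group at index 0,
# the two text-encoder groups at indices 1 and 2, which is why the Prodigy
# branch rewrites entries 1 and 2. Placeholder values throughout.
transformer_params, te_one_params, te_two_params = [], [], []

params_to_optimize = [
    {"params": transformer_params, "lr": 1.0},   # transformer
    {"params": te_one_params, "lr": 1e-5},       # text encoder one
    {"params": te_two_params, "lr": 1e-5},       # text encoder two
]

# With Prodigy, only one initial learning rate is meaningful, so the
# text-encoder groups are reset to --learning_rate:
params_to_optimize[1]["lr"] = 1.0
params_to_optimize[2]["lr"] = 1.0
```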
@@ -1440,6 +1465,9 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
         pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
         return prompt_embeds, pooled_prompt_embeds

+    # If no type of tuning is done on the text_encoder and custom instance prompts are NOT
+    # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
+    # the redundant encoding.
     if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
         instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings(
             args.instance_prompt, text_encoders, tokenizers
@@ -1484,6 +1512,21 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                 tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
                 tokens_three = torch.cat([tokens_three, class_tokens_three], dim=0)

+    vae_config_shift_factor = vae.config.shift_factor
+    vae_config_scaling_factor = vae.config.scaling_factor
+    if args.cache_latents:
+        latents_cache = []
+        for batch in tqdm(train_dataloader, desc="Caching latents"):
+            with torch.no_grad():
+                batch["pixel_values"] = batch["pixel_values"].to(
+                    accelerator.device, non_blocking=True, dtype=weight_dtype
+                )
+                latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
+
+        if args.validation_prompt is None:
+            del vae
+            free_memory()
+
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
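
Two details of the caching hunk are worth noting: the VAE's `shift_factor` and `scaling_factor` are copied into plain variables first so they remain available after `del vae`, and the cache stores each batch's posterior distribution rather than a fixed sample, so every step can still draw a fresh latent. A minimal sketch of the latter point (the `DiagonalGaussianDistribution` import path assumes a recent diffusers layout):

```python
import torch
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution

# Stand-in for a real `vae.encode(pixels)` output: a (B, 2*C, H, W) tensor of
# concatenated means and log-variances (SD3's VAE uses C = 16 latent channels).
posterior_params = torch.randn(1, 2 * 16, 8, 8)
dist = DiagonalGaussianDistribution(posterior_params)

# Each .sample() call draws a different latent, so caching the distribution
# keeps the per-step sampling noise that training expects.
assert not torch.equal(dist.sample(), dist.sample())
```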
@@ -1500,7 +1543,6 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
         power=args.lr_power,
     )

-    # Prepare everything with our `accelerator`.
     # Prepare everything with our `accelerator`.
     if args.train_text_encoder:
         (
@@ -1607,8 +1649,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):

         for step, batch in enumerate(train_dataloader):
             models_to_accumulate = [transformer]
+            if args.train_text_encoder:
+                models_to_accumulate.extend([text_encoder_one, text_encoder_two])
             with accelerator.accumulate(models_to_accumulate):
-                pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
                 prompts = batch["prompts"]

                 # encode batch prompts when custom prompts are provided for each image -
@@ -1639,8 +1682,13 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 )

                 # Convert images to latent space
-                model_input = vae.encode(pixel_values).latent_dist.sample()
-                model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor
+                if args.cache_latents:
+                    model_input = latents_cache[step].sample()
+                else:
+                    pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
+                    model_input = vae.encode(pixel_values).latent_dist.sample()
+
+                model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
                 model_input = model_input.to(dtype=weight_dtype)

                 # Sample noise that we'll add to the latents
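
For reference, the shift-and-scale normalization above has a straightforward inverse that is applied before decoding latents back to pixels. A small self-contained sketch; the constants are SD3 VAE config values quoted for illustration, while the real script reads them from `vae.config`:

```python
import torch

def to_model_space(z, shift, scale):
    # Training-side normalization, as in the hunk above.
    return (z - shift) * scale

def to_vae_space(latents, shift, scale):
    # Inverse, applied before vae.decode() at inference time.
    return latents / scale + shift

z = torch.randn(1, 16, 8, 8)
shift, scale = 0.0609, 1.5305
assert torch.allclose(to_vae_space(to_model_space(z, shift, scale), shift, scale), z)
```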
@@ -1773,6 +1821,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders(
                     text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three
                 )
+                text_encoder_one.to(weight_dtype)
+                text_encoder_two.to(weight_dtype)
                 pipeline = StableDiffusion3Pipeline.from_pretrained(
                     args.pretrained_model_name_or_path,
                     vae=vae,
@@ -1793,15 +1843,18 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                     epoch=epoch,
                     torch_dtype=weight_dtype,
                 )
-
-                del text_encoder_one, text_encoder_two, text_encoder_three
-                free_memory()
+                if not args.train_text_encoder:
+                    del text_encoder_one, text_encoder_two, text_encoder_three
+                    free_memory()

     # Save the lora layers
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
         transformer = unwrap_model(transformer)
-        transformer = transformer.to(torch.float32)
+        if args.upcast_before_saving:
+            transformer.to(torch.float32)
+        else:
+            transformer = transformer.to(weight_dtype)
         transformer_lora_layers = get_peft_model_state_dict(transformer)

         if args.train_text_encoder:
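
The save-time dtype branch above trades precision against checkpoint size: float32 preserves the most precision, while the training dtype (e.g. bf16) halves the file footprint. A minimal illustrative helper, assuming nothing beyond plain PyTorch and not part of the script:

```python
import torch

def cast_for_saving(model: torch.nn.Module, upcast: bool, weight_dtype=torch.bfloat16):
    # float32 for maximum precision, or the training dtype to keep files small.
    return model.to(torch.float32 if upcast else weight_dtype)

layer = torch.nn.Linear(4, 4).to(torch.bfloat16)
assert cast_for_saving(layer, upcast=False).weight.dtype == torch.bfloat16
assert cast_for_saving(layer, upcast=True).weight.dtype == torch.float32
```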