@@ -19,6 +19,7 @@
 import logging
 import math
 import os
+import random
 import shutil
 import warnings
 from pathlib import Path
@@ -40,6 +41,7 @@
 from PIL.ImageOps import exif_transpose
 from torch.utils.data import Dataset
 from torchvision import transforms
+from torchvision.transforms.functional import crop
 from tqdm.auto import tqdm
 from transformers import AutoTokenizer, PretrainedConfig

@@ -304,18 +306,6 @@ def parse_args(input_args=None):
             " resolution"
         ),
     )
-    parser.add_argument(
-        "--crops_coords_top_left_h",
-        type=int,
-        default=0,
-        help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
-    )
-    parser.add_argument(
-        "--crops_coords_top_left_w",
-        type=int,
-        default=0,
-        help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
-    )
     parser.add_argument(
         "--center_crop",
         default=False,
@@ -325,6 +315,11 @@ def parse_args(input_args=None):
             " cropped. The images will be resized to the resolution first before cropping."
         ),
     )
+    parser.add_argument(
+        "--random_flip",
+        action="store_true",
+        help="whether to randomly flip images horizontally",
+    )
     parser.add_argument(
         "--train_text_encoder",
         action="store_true",
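
The fixed --crops_coords_top_left_h/--crops_coords_top_left_w options are removed because the crop origin is now recorded per image in the dataset (see the hunks below); --random_flip is a plain boolean switch. A minimal stand-alone sketch of the flag behaviour (this parser is illustrative, not the script's full parse_args):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--resolution", type=int, default=1024)
parser.add_argument("--center_crop", default=False, action="store_true")
parser.add_argument("--random_flip", action="store_true", help="whether to randomly flip images horizontally")

args = parser.parse_args(["--random_flip"])
print(args.random_flip, args.center_crop)  # True False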
@@ -669,6 +664,41 @@ def __init__(
         self.instance_images = []
         for img in instance_images:
             self.instance_images.extend(itertools.repeat(img, repeats))
+
+        # image processing to prepare for using SD-XL micro-conditioning
+        self.original_sizes = []
+        self.crop_top_lefts = []
+        self.pixel_values = []
+        train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
+        train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
+        train_flip = transforms.RandomHorizontalFlip(p=1.0)
+        train_transforms = transforms.Compose(
+            [
+                transforms.ToTensor(),
+                transforms.Normalize([0.5], [0.5]),
+            ]
+        )
+        for image in self.instance_images:
+            image = exif_transpose(image)
+            if not image.mode == "RGB":
+                image = image.convert("RGB")
+            self.original_sizes.append((image.height, image.width))
+            image = train_resize(image)
+            if args.random_flip and random.random() < 0.5:
+                # flip
+                image = train_flip(image)
+            if args.center_crop:
+                y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
+                x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
+                image = train_crop(image)
+            else:
+                y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
+                image = crop(image, y1, x1, h, w)
+            crop_top_left = (y1, x1)
+            self.crop_top_lefts.append(crop_top_left)
+            image = train_transforms(image)
+            self.pixel_values.append(image)
+
         self.num_instance_images = len(self.instance_images)
         self._length = self.num_instance_images

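The bookkeeping above relies on torchvision's crop conventions: RandomCrop.get_params returns (top, left, height, width) and crop() takes the same order, so (y1, x1) is exactly the crops_coords_top_left pair SDXL conditions on. A minimal stand-alone sketch of the same resize-then-crop step (the dummy image and sizes are illustrative):

from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import crop

resolution = 1024
image = Image.new("RGB", (1536, 1280))  # stand-in for a training image (width=1536, height=1280)
original_size = (image.height, image.width)  # recorded before any resizing

train_resize = transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR)
train_crop = transforms.RandomCrop(resolution)

image = train_resize(image)  # shorter side becomes `resolution`
y1, x1, h, w = train_crop.get_params(image, (resolution, resolution))  # (top, left, height, width)
image = crop(image, y1, x1, h, w)
crop_top_left = (y1, x1)  # later embedded as the crops_coords_top_left micro-condition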
@@ -698,12 +728,12 @@ def __len__(self):

     def __getitem__(self, index):
         example = {}
-        instance_image = self.instance_images[index % self.num_instance_images]
-        instance_image = exif_transpose(instance_image)
-
-        if not instance_image.mode == "RGB":
-            instance_image = instance_image.convert("RGB")
-        example["instance_images"] = self.image_transforms(instance_image)
+        instance_image = self.pixel_values[index % self.num_instance_images]
+        original_size = self.original_sizes[index % self.num_instance_images]
+        crop_top_left = self.crop_top_lefts[index % self.num_instance_images]
+        example["instance_images"] = instance_image
+        example["original_size"] = original_size
+        example["crop_top_left"] = crop_top_left

         if self.custom_instance_prompts:
             caption = self.custom_instance_prompts[index % self.num_instance_images]
@@ -730,6 +760,8 @@ def __getitem__(self, index):
 def collate_fn(examples, with_prior_preservation=False):
     pixel_values = [example["instance_images"] for example in examples]
     prompts = [example["instance_prompt"] for example in examples]
+    original_sizes = [example["original_size"] for example in examples]
+    crop_top_lefts = [example["crop_top_left"] for example in examples]

     # Concat class and instance examples for prior preservation.
     # We do this to avoid doing two forward passes.
@@ -740,7 +772,12 @@ def collate_fn(examples, with_prior_preservation=False):
     pixel_values = torch.stack(pixel_values)
     pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

-    batch = {"pixel_values": pixel_values, "prompts": prompts}
+    batch = {
+        "pixel_values": pixel_values,
+        "prompts": prompts,
+        "original_sizes": original_sizes,
+        "crop_top_lefts": crop_top_lefts,
+    }
     return batch

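Each example now carries its own conditioning metadata, and collate_fn keeps the new fields as plain Python lists (only pixel_values is stacked into a tensor). A rough sketch of the expected output with the collate_fn above in scope, prior preservation off, and made-up sizes and prompts:

import torch

examples = [
    {
        "instance_images": torch.zeros(3, 1024, 1024),
        "instance_prompt": "a photo of sks dog",
        "original_size": (1280, 1536),
        "crop_top_left": (0, 112),
    },
    {
        "instance_images": torch.zeros(3, 1024, 1024),
        "instance_prompt": "a photo of sks dog",
        "original_size": (768, 1024),
        "crop_top_left": (0, 0),
    },
]
batch = collate_fn(examples)
# batch["pixel_values"].shape -> torch.Size([2, 3, 1024, 1024])
# batch["original_sizes"]     -> [(1280, 1536), (768, 1024)]
# batch["crop_top_lefts"]     -> [(0, 112), (0, 0)]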
@@ -1233,11 +1270,9 @@ def load_model_hook(models, input_dir):
     # pooled text embeddings
     # time ids

-    def compute_time_ids():
+    def compute_time_ids(original_size, crops_coords_top_left):
         # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
-        original_size = (args.resolution, args.resolution)
         target_size = (args.resolution, args.resolution)
-        crops_coords_top_left = (args.crops_coords_top_left_h, args.crops_coords_top_left_w)
         add_time_ids = list(original_size + crops_coords_top_left + target_size)
         add_time_ids = torch.tensor([add_time_ids])
         add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
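
With original_size and crops_coords_top_left passed in, each call packs (orig_h, orig_w, crop_top, crop_left, target_h, target_w) into a (1, 6) tensor. A stand-alone equivalent without the accelerator/device handling (the helper name is illustrative):

import torch

def time_ids_sketch(original_size, crops_coords_top_left, target_size=(1024, 1024)):
    # order matters: original size, then crop top-left, then target size
    return torch.tensor([list(original_size + crops_coords_top_left + target_size)])

print(time_ids_sketch((1280, 1536), (0, 112)))
# tensor([[1280, 1536,    0,  112, 1024, 1024]])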
@@ -1254,9 +1289,6 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
         pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
         return prompt_embeds, pooled_prompt_embeds

-    # Handle instance prompt.
-    instance_time_ids = compute_time_ids()
-
     # If no type of tuning is done on the text_encoder and custom instance prompts are NOT
     # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
     # the redundant encoding.
@@ -1267,7 +1299,6 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):

     # Handle class prompt for prior-preservation.
     if args.with_prior_preservation:
-        class_time_ids = compute_time_ids()
         if not args.train_text_encoder:
             class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings(
                 args.class_prompt, text_encoders, tokenizers
@@ -1282,9 +1313,6 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
     # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
     # pack the statically computed variables appropriately here. This is so that we don't
     # have to pass them to the dataloader.
-    add_time_ids = instance_time_ids
-    if args.with_prior_preservation:
-        add_time_ids = torch.cat([add_time_ids, class_time_ids], dim=0)

     if not train_dataset.custom_instance_prompts:
         if not args.train_text_encoder:
@@ -1436,18 +1464,24 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                 # (this is the forward diffusion process)
                 noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)

+                # time ids
+                add_time_ids = torch.cat(
+                    [
+                        compute_time_ids(original_size=s, crops_coords_top_left=c)
+                        for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])
+                    ]
+                )
+
                 # Calculate the elements to repeat depending on the use of prior-preservation and custom captions.
                 if not train_dataset.custom_instance_prompts:
                     elems_to_repeat_text_embeds = bsz // 2 if args.with_prior_preservation else bsz
-                    elems_to_repeat_time_ids = bsz // 2 if args.with_prior_preservation else bsz
                 else:
                     elems_to_repeat_text_embeds = 1
-                    elems_to_repeat_time_ids = bsz // 2 if args.with_prior_preservation else bsz

                 # Predict the noise residual
                 if not args.train_text_encoder:
                     unet_added_conditions = {
-                        "time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1),
+                        "time_ids": add_time_ids,
                         "text_embeds": unet_add_text_embeds.repeat(elems_to_repeat_text_embeds, 1),
                     }
                     prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
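
Because compute_time_ids returns a (1, 6) tensor per sample, concatenating over the batch already yields one row of micro-conditions per image, so the old add_time_ids.repeat(elems_to_repeat_time_ids, 1) of a single shared row is no longer needed. Roughly (values are illustrative):

import torch

per_sample = [
    torch.tensor([[1280, 1536, 0, 112, 1024, 1024]]),  # instance image
    torch.tensor([[768, 1024, 0, 0, 1024, 1024]]),     # class image (prior preservation)
]
add_time_ids = torch.cat(per_sample)  # shape (2, 6): one row per image in the batch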
@@ -1459,7 +1493,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                         return_dict=False,
                     )[0]
                 else:
-                    unet_added_conditions = {"time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1)}
+                    unet_added_conditions = {"time_ids": add_time_ids}
                     prompt_embeds, pooled_prompt_embeds = encode_prompt(
                         text_encoders=[text_encoder_one, text_encoder_two],
                         tokenizers=None,