Commit fbdf26b

[dreambooth lora sdxl] add sdxl micro conditioning (#6795)

* add micro conditioning
* remove redundant lines
* style
* fix missing 's'
* fix missing shape bug due to missing RGB if statement
* remove redundant if, change arg order

---------

Co-authored-by: Sayak Paul <[email protected]>
1 parent 13001ee commit fbdf26b
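For reference, SDXL's micro-conditioning feeds the UNet a six-value "time ids" vector built from the image's original size, the crop's top-left corner, and the target resolution; this commit replaces a single static vector with per-image values recorded during preprocessing. A minimal sketch of the packing, using invented example values (not taken from the repo):

import torch

# Invented example: a 1024x768 (height x width) source photo, randomly cropped
# with its top-left corner at (0, 128), trained at a 1024x1024 target resolution.
original_size = (1024, 768)
crops_coords_top_left = (0, 128)
target_size = (1024, 1024)

# Same packing as compute_time_ids in this diff: six scalars in a fixed order.
add_time_ids = torch.tensor([list(original_size + crops_coords_top_left + target_size)])
print(add_time_ids)        # tensor([[1024,  768,    0,  128, 1024, 1024]])
print(add_time_ids.shape)  # torch.Size([1, 6])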


examples/dreambooth/train_dreambooth_lora_sdxl.py

Lines changed: 67 additions & 33 deletions
@@ -19,6 +19,7 @@
 import logging
 import math
 import os
+import random
 import shutil
 import warnings
 from pathlib import Path
@@ -40,6 +41,7 @@
 from PIL.ImageOps import exif_transpose
 from torch.utils.data import Dataset
 from torchvision import transforms
+from torchvision.transforms.functional import crop
 from tqdm.auto import tqdm
 from transformers import AutoTokenizer, PretrainedConfig

@@ -304,18 +306,6 @@ def parse_args(input_args=None):
             " resolution"
         ),
     )
-    parser.add_argument(
-        "--crops_coords_top_left_h",
-        type=int,
-        default=0,
-        help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
-    )
-    parser.add_argument(
-        "--crops_coords_top_left_w",
-        type=int,
-        default=0,
-        help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
-    )
     parser.add_argument(
         "--center_crop",
         default=False,
@@ -325,6 +315,11 @@ def parse_args(input_args=None):
             " cropped. The images will be resized to the resolution first before cropping."
         ),
     )
+    parser.add_argument(
+        "--random_flip",
+        action="store_true",
+        help="whether to randomly flip images horizontally",
+    )
     parser.add_argument(
         "--train_text_encoder",
         action="store_true",
@@ -669,6 +664,41 @@ def __init__(
         self.instance_images = []
         for img in instance_images:
             self.instance_images.extend(itertools.repeat(img, repeats))
+
+        # image processing to prepare for using SD-XL micro-conditioning
+        self.original_sizes = []
+        self.crop_top_lefts = []
+        self.pixel_values = []
+        train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
+        train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
+        train_flip = transforms.RandomHorizontalFlip(p=1.0)
+        train_transforms = transforms.Compose(
+            [
+                transforms.ToTensor(),
+                transforms.Normalize([0.5], [0.5]),
+            ]
+        )
+        for image in self.instance_images:
+            image = exif_transpose(image)
+            if not image.mode == "RGB":
+                image = image.convert("RGB")
+            self.original_sizes.append((image.height, image.width))
+            image = train_resize(image)
+            if args.random_flip and random.random() < 0.5:
+                # flip
+                image = train_flip(image)
+            if args.center_crop:
+                y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
+                x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
+                image = train_crop(image)
+            else:
+                y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
+                image = crop(image, y1, x1, h, w)
+            crop_top_left = (y1, x1)
+            self.crop_top_lefts.append(crop_top_left)
+            image = train_transforms(image)
+            self.pixel_values.append(image)
+
         self.num_instance_images = len(self.instance_images)
         self._length = self.num_instance_images
 
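As a rough standalone sketch of what the new __init__ preprocessing does per image (a blank in-memory PIL image stands in for a real instance image, and the resolution/center_crop/random_flip parameters stand in for the script's args):

import random

from PIL import Image
from PIL.ImageOps import exif_transpose
from torchvision import transforms
from torchvision.transforms.functional import crop


def preprocess(image, resolution=512, center_crop=False, random_flip=True):
    # Mirror of the dataset loop: record the original size, resize, optionally flip,
    # crop while remembering where the crop landed, then normalize to [-1, 1].
    image = exif_transpose(image)
    if image.mode != "RGB":
        image = image.convert("RGB")
    original_size = (image.height, image.width)

    image = transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR)(image)
    if random_flip and random.random() < 0.5:
        image = transforms.RandomHorizontalFlip(p=1.0)(image)

    if center_crop:
        y1 = max(0, int(round((image.height - resolution) / 2.0)))
        x1 = max(0, int(round((image.width - resolution) / 2.0)))
        image = transforms.CenterCrop(resolution)(image)
    else:
        y1, x1, h, w = transforms.RandomCrop(resolution).get_params(image, (resolution, resolution))
        image = crop(image, y1, x1, h, w)

    pixel_values = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])(image)
    return pixel_values, original_size, (y1, x1)


# Toy 768x1024 image (PIL's constructor takes (width, height)).
pixels, original_size, crop_top_left = preprocess(Image.new("RGB", (768, 1024)))
print(pixels.shape, original_size, crop_top_left)  # torch.Size([3, 512, 512]) (1024, 768) (top, left)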

@@ -698,12 +728,12 @@ def __len__(self):
 
     def __getitem__(self, index):
         example = {}
-        instance_image = self.instance_images[index % self.num_instance_images]
-        instance_image = exif_transpose(instance_image)
-
-        if not instance_image.mode == "RGB":
-            instance_image = instance_image.convert("RGB")
-        example["instance_images"] = self.image_transforms(instance_image)
+        instance_image = self.pixel_values[index % self.num_instance_images]
+        original_size = self.original_sizes[index % self.num_instance_images]
+        crop_top_left = self.crop_top_lefts[index % self.num_instance_images]
+        example["instance_images"] = instance_image
+        example["original_size"] = original_size
+        example["crop_top_left"] = crop_top_left
 
         if self.custom_instance_prompts:
             caption = self.custom_instance_prompts[index % self.num_instance_images]
@@ -730,6 +760,8 @@ def __getitem__(self, index):
 def collate_fn(examples, with_prior_preservation=False):
     pixel_values = [example["instance_images"] for example in examples]
     prompts = [example["instance_prompt"] for example in examples]
+    original_sizes = [example["original_size"] for example in examples]
+    crop_top_lefts = [example["crop_top_left"] for example in examples]
 
     # Concat class and instance examples for prior preservation.
     # We do this to avoid doing two forward passes.
@@ -740,7 +772,12 @@ def collate_fn(examples, with_prior_preservation=False):
     pixel_values = torch.stack(pixel_values)
     pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
 
-    batch = {"pixel_values": pixel_values, "prompts": prompts}
+    batch = {
+        "pixel_values": pixel_values,
+        "prompts": prompts,
+        "original_sizes": original_sizes,
+        "crop_top_lefts": crop_top_lefts,
+    }
     return batch
 
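The collated batch now carries the per-image metadata alongside the pixels. A small sketch of its shape, using two fabricated examples (all values invented for illustration):

import torch

# Two fabricated dataset examples with the fields __getitem__ now returns.
examples = [
    {"instance_images": torch.randn(3, 512, 512), "instance_prompt": "a photo of sks dog",
     "original_size": (1024, 768), "crop_top_left": (0, 128)},
    {"instance_images": torch.randn(3, 512, 512), "instance_prompt": "a photo of sks dog",
     "original_size": (800, 800), "crop_top_left": (37, 12)},
]

batch = {
    "pixel_values": torch.stack([e["instance_images"] for e in examples]).float(),  # (2, 3, 512, 512)
    "prompts": [e["instance_prompt"] for e in examples],
    "original_sizes": [e["original_size"] for e in examples],   # list of (height, width)
    "crop_top_lefts": [e["crop_top_left"] for e in examples],   # list of (top, left)
}
print(batch["pixel_values"].shape, batch["original_sizes"], batch["crop_top_lefts"])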

@@ -1233,11 +1270,9 @@ def load_model_hook(models, input_dir):
     # pooled text embeddings
     # time ids
 
-    def compute_time_ids():
+    def compute_time_ids(original_size, crops_coords_top_left):
         # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
-        original_size = (args.resolution, args.resolution)
         target_size = (args.resolution, args.resolution)
-        crops_coords_top_left = (args.crops_coords_top_left_h, args.crops_coords_top_left_w)
         add_time_ids = list(original_size + crops_coords_top_left + target_size)
         add_time_ids = torch.tensor([add_time_ids])
         add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
@@ -1254,9 +1289,6 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
             pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
         return prompt_embeds, pooled_prompt_embeds
 
-    # Handle instance prompt.
-    instance_time_ids = compute_time_ids()
-
     # If no type of tuning is done on the text_encoder and custom instance prompts are NOT
     # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
     # the redundant encoding.
@@ -1267,7 +1299,6 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
 
     # Handle class prompt for prior-preservation.
     if args.with_prior_preservation:
-        class_time_ids = compute_time_ids()
         if not args.train_text_encoder:
             class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings(
                 args.class_prompt, text_encoders, tokenizers
@@ -1282,9 +1313,6 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
     # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
     # pack the statically computed variables appropriately here. This is so that we don't
     # have to pass them to the dataloader.
-    add_time_ids = instance_time_ids
-    if args.with_prior_preservation:
-        add_time_ids = torch.cat([add_time_ids, class_time_ids], dim=0)
 
     if not train_dataset.custom_instance_prompts:
         if not args.train_text_encoder:
@@ -1436,18 +1464,24 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                 # (this is the forward diffusion process)
                 noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
 
+                # time ids
+                add_time_ids = torch.cat(
+                    [
+                        compute_time_ids(original_size=s, crops_coords_top_left=c)
+                        for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])
+                    ]
+                )
+
                 # Calculate the elements to repeat depending on the use of prior-preservation and custom captions.
                 if not train_dataset.custom_instance_prompts:
                     elems_to_repeat_text_embeds = bsz // 2 if args.with_prior_preservation else bsz
-                    elems_to_repeat_time_ids = bsz // 2 if args.with_prior_preservation else bsz
                 else:
                     elems_to_repeat_text_embeds = 1
-                    elems_to_repeat_time_ids = bsz // 2 if args.with_prior_preservation else bsz
 
                 # Predict the noise residual
                 if not args.train_text_encoder:
                     unet_added_conditions = {
-                        "time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1),
+                        "time_ids": add_time_ids,
                         "text_embeds": unet_add_text_embeds.repeat(elems_to_repeat_text_embeds, 1),
                     }
                     prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
14591493
return_dict=False,
14601494
)[0]
14611495
else:
1462-
unet_added_conditions = {"time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1)}
1496+
unet_added_conditions = {"time_ids": add_time_ids}
14631497
prompt_embeds, pooled_prompt_embeds = encode_prompt(
14641498
text_encoders=[text_encoder_one, text_encoder_two],
14651499
tokenizers=None,
