Commit 65329ae

[advanced dreambooth lora sdxl script] new features + bug fixes (#6691)
* add noise_offset param
* micro conditioning - wip
* image processing adjusted and moved to support micro conditioning
* change time ids to be computed inside train loop
* change time ids to be computed inside train loop
* change time ids to be computed inside train loop
* time ids shape fix
* move token replacement of validation prompt to the same section of instance prompt and class prompt
* add offset noise to sd15 advanced script
* fix token loading during validation
* fix token loading during validation in sdxl script
* a little clean
* style
* a little clean
* style
* sdxl script - a little clean + minor path fix; sd 1.5 script - change default resolution value
* ad 1.5 script - minor path fix
* fix missing comma in code example in model card
* clean up commented lines
* style
* remove time ids computed outside training loop - no longer used now that we utilize micro-conditioning, as all time ids are now computed inside the training loop
* style
* [WIP] - added draft readme, building off of examples/dreambooth/README.md
* readme
* readme
* readme
* readme
* readme
* readme
* readme
* readme
* removed --crops_coords_top_left from CLI args
* style
* fix missing shape bug due to missing RGB if statement
* add blog mention at the start of the reamde as well
* Update examples/advanced_diffusion_training/README.md
  Co-authored-by: Sayak Paul <[email protected]>
* change note to render nicely as well

---------

Co-authored-by: Sayak Paul <[email protected]>

1 parent 02338c9 commit 65329ae

4 files changed: +386 −55 lines changed

examples/advanced_diffusion_training/README.md

Lines changed: 244 additions & 0 deletions

# Advanced diffusion training examples

## Train Dreambooth LoRA with Stable Diffusion XL

> [!TIP]
> 💡 This example follows the techniques and recommended practices covered in the blog post: [LoRA training scripts of the world, unite!](https://huggingface.co/blog/sdxl_lora_advanced_script). Make sure to check it out before starting 🤗
[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text-to-image models like Stable Diffusion given just a few (3-5) images of a subject.

LoRA - Low-Rank Adaptation of Large Language Models - was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.
In a nutshell, LoRA adapts pretrained models by adding pairs of rank-decomposition matrices to existing weights and training **only** those newly added weights. This has a couple of advantages:

- Previous pretrained weights are kept frozen, so the model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114).
- Rank-decomposition matrices have significantly fewer parameters than the original model, which means that trained LoRA weights are easily portable.
- LoRA attention layers let you control the extent to which the model is adapted towards new training images via a `scale` parameter.

[cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository.

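To make the rank-decomposition idea concrete, here is a minimal, illustrative sketch of a LoRA-style linear layer. It is not the script's actual implementation (training in these scripts goes through the PEFT library), and the names `LoRALinear`, `rank`, and `scale` are ours:

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Wraps a frozen linear layer and adds a trainable low-rank update."""

    def __init__(self, base: nn.Linear, rank: int = 8, scale: float = 1.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad_(False)  # pretrained weights stay frozen
        # pair of rank-decomposition matrices: (in_features -> rank) and (rank -> out_features)
        self.lora_down = nn.Linear(base.in_features, rank, bias=False)
        self.lora_up = nn.Linear(rank, base.out_features, bias=False)
        nn.init.zeros_(self.lora_up.weight)  # adapter starts as a no-op
        self.scale = scale  # controls how strongly the adapter steers the model

    def forward(self, x):
        return self.base(x) + self.scale * self.lora_up(self.lora_down(x))

layer = LoRALinear(nn.Linear(768, 768), rank=8)
out = layer(torch.randn(1, 768))  # only lora_down / lora_up receive gradients
```
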
The `train_dreambooth_lora_sdxl_advanced.py` script shows how to implement DreamBooth LoRA, combining the training process shown in `train_dreambooth_lora_sdxl.py` with advanced features and techniques, inspired and built upon contributions by [Nataniel Ruiz](https://twitter.com/natanielruizg): [Dreambooth](https://dreambooth.github.io), [Rinon Gal](https://twitter.com/RinonGal): [Textual Inversion](https://textual-inversion.github.io), [Ron Mokady](https://twitter.com/MokadyRon): [Pivotal Tuning](https://arxiv.org/abs/2106.05744), [Simo Ryu](https://twitter.com/cloneofsimo): [cog-sdxl](https://github.com/replicate/cog-sdxl), [Kohya](https://twitter.com/kohya_tech/): [sd-scripts](https://github.com/kohya-ss/sd-scripts), [The Last Ben](https://twitter.com/__TheBen): [fast-stable-diffusion](https://github.com/TheLastBen/fast-stable-diffusion) ❤️
> [!NOTE]
> 💡 If this is your first time training a Dreambooth LoRA, congrats! 🥳
> You might want to familiarize yourself more with the techniques: [Dreambooth blog](https://huggingface.co/blog/dreambooth), [Using LoRA for Efficient Stable Diffusion Fine-Tuning blog](https://huggingface.co/blog/lora)

📚 Read more about the advanced features and best practices in this community-derived blog post: [LoRA training scripts of the world, unite!](https://huggingface.co/blog/sdxl_lora_advanced_script)

## Running locally with PyTorch

### Installing the dependencies

Before running the scripts, make sure to install the library's training dependencies:

**Important**

To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date, as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:

```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e .
```

Then cd into the `examples/advanced_diffusion_training` folder and run

```bash
pip install -r requirements.txt
```

And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:

```bash
accelerate config
```

Or for a default accelerate configuration without answering questions about your environment:

```bash
accelerate config default
```

Or, if your environment doesn't support an interactive shell (e.g. a notebook):

```python
from accelerate.utils import write_basic_config
write_basic_config()
```

When running `accelerate config`, specifying torch compile mode to True can give dramatic speedups.
Note also that we use the PEFT library as the backend for LoRA training, so make sure to have `peft>=0.6.0` installed in your environment.

### Pivotal Tuning
**Training with text encoder(s)**

Alongside the UNet, LoRA fine-tuning of the text encoders is also supported. In addition to the text encoder optimization available with `train_dreambooth_lora_sdxl_advanced.py`, the advanced script also supports **pivotal tuning**.
[Pivotal tuning](https://huggingface.co/blog/sdxl_lora_advanced_script#pivotal-tuning) combines Textual Inversion with regular diffusion fine-tuning -
we insert new tokens into the text encoders of the model, instead of reusing existing ones.
We then optimize the newly-inserted token embeddings to represent the new concept.

To do so, just specify `--train_text_encoder_ti` while launching training (for regular text encoder optimization, use `--train_text_encoder`).
Please keep the following points in mind:

* SDXL has two text encoders, so we fine-tune both using LoRA.
* When not fine-tuning the text encoders, we ALWAYS precompute the text embeddings to save memory.

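For intuition, here is a rough sketch of the token-insertion step that pivotal tuning performs (illustrative only; the advanced script handles all of this for you when you pass `--train_text_encoder_ti`, and the `openai/clip-vit-large-patch14` checkpoint below is just a stand-in for one of SDXL's text encoders):

```python
from transformers import CLIPTextModel, CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

# 1. insert brand-new tokens instead of reusing existing ones
new_tokens = ["<s0>", "<s1>"]
tokenizer.add_tokens(new_tokens)
text_encoder.resize_token_embeddings(len(tokenizer))

# 2. only the embedding rows of the newly inserted tokens get optimized;
#    the script masks gradients so the rest of the embedding matrix stays frozen
new_token_ids = tokenizer.convert_tokens_to_ids(new_tokens)
embedding_matrix = text_encoder.get_input_embeddings().weight
print(embedding_matrix[new_token_ids].shape)  # (2, hidden_dim) - the rows to be trained
```
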
### 3D icon example

Now let's get our dataset. For this example we will use some cool images of 3D rendered icons: https://huggingface.co/datasets/linoyts/3d_icon.

Let's first download it locally:

```python
from huggingface_hub import snapshot_download

local_dir = "./3d_icon"
snapshot_download(
    "LinoyTsaban/3d_icon",
    local_dir=local_dir, repo_type="dataset",
    ignore_patterns=".gitattributes",
)
```

Let's review some of the advanced features we're going to be using for this example:

- **custom captions**:
To use custom captioning, first ensure that you have the datasets library installed; if not, you can install it with:
```bash
pip install datasets
```

Now we'll simply specify the name of the dataset and caption column (in this case it's "prompt"):

```
--dataset_name=./3d_icon
--caption_column=prompt
```

You can also load a dataset straight from the Hub by specifying its name in `dataset_name` (see the dataset-loading sketch after this list).
Look [here](https://huggingface.co/blog/sdxl_lora_advanced_script#custom-captioning) for more info on creating/loading your own caption dataset.

- **optimizer**: for this example, we'll use [prodigy](https://huggingface.co/blog/sdxl_lora_advanced_script#adaptive-optimizers) - an adaptive optimizer
- **pivotal tuning**
- **min SNR gamma**

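For reference, this is roughly what the caption-column setup amounts to. This is a sketch, not the script's exact loading code: it assumes the locally downloaded `./3d_icon` folder can be read as an image folder and exposes a `prompt` column, as used in this example:

```python
from datasets import load_dataset

# the script accepts either a Hub dataset name or a local folder via --dataset_name
dataset = load_dataset("imagefolder", data_dir="./3d_icon", split="train")

# --caption_column tells the script which column holds the per-image captions
print(dataset.column_names)   # expected to include 'prompt' for this dataset
print(dataset[0]["prompt"])   # caption used instead of a single shared instance prompt
```
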
**Now, we can launch training:**

```bash
export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
export DATASET_NAME="./3d_icon"
export OUTPUT_DIR="3d-icon-SDXL-LoRA"
export VAE_PATH="madebyollin/sdxl-vae-fp16-fix"

accelerate launch train_dreambooth_lora_sdxl_advanced.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --pretrained_vae_model_name_or_path=$VAE_PATH \
  --dataset_name=$DATASET_NAME \
  --instance_prompt="3d icon in the style of TOK" \
  --validation_prompt="a TOK icon of an astronaut riding a horse, in the style of TOK" \
  --output_dir=$OUTPUT_DIR \
  --caption_column="prompt" \
  --mixed_precision="bf16" \
  --resolution=1024 \
  --train_batch_size=3 \
  --repeats=1 \
  --report_to="wandb" \
  --gradient_accumulation_steps=1 \
  --gradient_checkpointing \
  --learning_rate=1.0 \
  --text_encoder_lr=1.0 \
  --optimizer="prodigy" \
  --train_text_encoder_ti \
  --train_text_encoder_ti_frac=0.5 \
  --snr_gamma=5.0 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --rank=8 \
  --max_train_steps=1000 \
  --checkpointing_steps=2000 \
  --seed="0" \
  --push_to_hub
```

To better track our training experiments, we're using the following flags in the command above:

* `report_to="wandb"` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
* `validation_prompt` and `validation_epochs` allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.

Our experiments were conducted on a single 40GB A100 GPU.

### Inference

Once training is done, we can perform inference like so:

1. Start by loading the UNet LoRA weights:
```python
import torch
from huggingface_hub import hf_hub_download, upload_file
from diffusers import DiffusionPipeline
from diffusers.models import AutoencoderKL
from safetensors.torch import load_file

username = "linoyts"
repo_id = f"{username}/3d-icon-SDXL-LoRA"

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
).to("cuda")


pipe.load_lora_weights(repo_id, weight_name="pytorch_lora_weights.safetensors")
```

2. Now we load the pivotal tuning embeddings:

```python
text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
tokenizers = [pipe.tokenizer, pipe.tokenizer_2]

embedding_path = hf_hub_download(repo_id=repo_id, filename="3d-icon-SDXL-LoRA_emb.safetensors", repo_type="model")

state_dict = load_file(embedding_path)
# load embeddings of text_encoder 1 (CLIP ViT-L/14)
pipe.load_textual_inversion(state_dict["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
# load embeddings of text_encoder 2 (CLIP ViT-G/14)
pipe.load_textual_inversion(state_dict["clip_g"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
```

3. Let's generate images:

```python
instance_token = "<s0><s1>"
prompt = f"a {instance_token} icon of an orange llama eating ramen, in the style of {instance_token}"

image = pipe(prompt=prompt, num_inference_steps=25, cross_attention_kwargs={"scale": 1.0}).images[0]
image.save("llama.png")
```

### Comfy UI / AUTOMATIC1111 Inference
The new script fully supports textual inversion loading with Comfy UI and AUTOMATIC1111 formats!

**AUTOMATIC1111 / SD.Next** \
In AUTOMATIC1111/SD.Next we will load a LoRA and a textual embedding at the same time.
- *LoRA*: Besides the diffusers format, the script will also train a WebUI compatible LoRA. It is generated as `{your_lora_name}.safetensors`. You can then include it in your `models/Lora` directory.
- *Embedding*: the embedding is the same for diffusers and WebUI. You can download your `{lora_name}_emb.safetensors` file from a trained model, and include it in your `embeddings` directory.

You can then run inference by prompting `a y2k_emb webpage about the movie Mean Girls <lora:y2k:0.9>`. You can use the `y2k_emb` token normally, including increasing its weight by doing `(y2k_emb:1.2)`.

**ComfyUI** \
In ComfyUI we will load a LoRA and a textual embedding at the same time.
- *LoRA*: Besides the diffusers format, the script will also train a ComfyUI compatible LoRA. It is generated as `{your_lora_name}.safetensors`. You can then include it in your `models/Lora` directory. Then you will load the LoRALoader node and hook that up with your model and CLIP. [Official guide for loading LoRAs](https://comfyanonymous.github.io/ComfyUI_examples/lora/)
- *Embedding*: the embedding is the same for diffusers and WebUI. You can download your `{lora_name}_emb.safetensors` file from a trained model, and include it in your `models/embeddings` directory and use it in your prompts like `embedding:y2k_emb`. [Official guide for loading embeddings](https://comfyanonymous.github.io/ComfyUI_examples/textual_inversion_embeddings/).

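If you prefer to fetch both files programmatically before dropping them into the folders above, something like the following works. This is a sketch: the exact filenames follow the `{your_lora_name}.safetensors` / `{lora_name}_emb.safetensors` pattern described above and depend on your output directory and repo name, so treat the ones below as placeholders:

```python
from huggingface_hub import hf_hub_download

repo_id = "linoyts/3d-icon-SDXL-LoRA"  # your trained model repo

# WebUI/ComfyUI-compatible LoRA -> drop into models/Lora
lora_path = hf_hub_download(repo_id=repo_id, filename="3d-icon-SDXL-LoRA.safetensors")
# textual inversion embeddings -> drop into embeddings/ (WebUI) or models/embeddings (ComfyUI)
emb_path = hf_hub_download(repo_id=repo_id, filename="3d-icon-SDXL-LoRA_emb.safetensors")
print(lora_path, emb_path)
```
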
### Specifying a better VAE

SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument, namely `--pretrained_vae_model_name_or_path`, that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).

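At inference time you can load the same fixed VAE and pass it to the pipeline. A short sketch using the `AutoencoderKL` class already imported in the inference example above:

```python
import torch
from diffusers import DiffusionPipeline
from diffusers.models import AutoencoderKL

# fp16-friendly VAE that avoids SDXL's numerical instability in half precision
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    vae=vae,
    torch_dtype=torch.float16,
    variant="fp16",
).to("cuda")
```

### Tips and Tricks
Check out [these recommended practices](https://huggingface.co/blog/sdxl_lora_advanced_script#additional-good-practices)
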
## Running on Colab Notebook
Check out [this notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_DreamBooth_LoRA_advanced_example.ipynb) to train using the advanced features (including pivotal tuning), and [this notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_DreamBooth_LoRA_.ipynb) to train on a free Colab, using some of the advanced features (excluding pivotal tuning).

requirements.txt

Lines changed: 7 additions & 0 deletions

```
accelerate>=0.16.0
torchvision
transformers>=4.25.1
ftfy
tensorboard
Jinja2
peft==0.7.0
```

examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py

Lines changed: 28 additions & 9 deletions

```diff
@@ -119,10 +119,9 @@ def save_model_card(
         diffusers_imports_pivotal = """from huggingface_hub import hf_hub_download
 from safetensors.torch import load_file
         """
-        diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename='{embeddings_filename}.safetensors' repo_type="model")
+        diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename='{embeddings_filename}.safetensors', repo_type="model")
 state_dict = load_file(embedding_path)
 pipeline.load_textual_inversion(state_dict["clip_l"], token=[{ti_keys}], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)
-pipeline.load_textual_inversion(state_dict["clip_g"], token=[{ti_keys}], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)
         """
         webui_example_pivotal = f"""- *Embeddings*: download **[`{embeddings_filename}.safetensors` here 💾](/{repo_id}/blob/main/{embeddings_filename}.safetensors)**.
     - Place it on it on your `embeddings` folder
```

```diff
@@ -389,7 +388,7 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--resolution",
         type=int,
-        default=1024,
+        default=512,
         help=(
             "The resolution for input images, all the images in the train/validation dataset will be resized to this"
             " resolution"
```

```diff
@@ -645,6 +644,7 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
     )
+    parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
     parser.add_argument(
         "--rank",
         type=int,
```

```diff
@@ -745,10 +745,11 @@ def initialize_new_tokens(self, inserting_toks: List[str]):
 
             idx += 1
 
+    # copied from train_dreambooth_lora_sdxl_advanced.py
     def save_embeddings(self, file_path: str):
         assert self.train_ids is not None, "Initialize new tokens before saving embeddings."
         tensors = {}
-        # text_encoder_0 - CLIP ViT-L/14, text_encoder_1 - CLIP ViT-G/14
+        # text_encoder_0 - CLIP ViT-L/14, text_encoder_1 - CLIP ViT-G/14 - TODO - change for sd
         idx_to_text_encoder_name = {0: "clip_l", 1: "clip_g"}
         for idx, text_encoder in enumerate(self.text_encoders):
             assert text_encoder.text_model.embeddings.token_embedding.weight.data.shape[0] == len(
```

```diff
@@ -1634,6 +1635,11 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
 
                 # Sample noise that we'll add to the latents
                 noise = torch.randn_like(model_input)
+                if args.noise_offset:
+                    # https://www.crosslabs.org//blog/diffusion-with-offset-noise
+                    noise += args.noise_offset * torch.randn(
+                        (model_input.shape[0], model_input.shape[1], 1, 1), device=model_input.device
+                    )
                 bsz = model_input.shape[0]
                 # Sample a random timestep for each image
                 timesteps = torch.randint(
```
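For context on the hunk above: offset noise adds a per-sample, per-channel constant shift to the Gaussian noise, which is reported to help the model learn very dark or very bright images. A standalone sketch of the same computation, with made-up latent shapes:

```python
import torch

noise_offset = 0.05                       # value of the new --noise_offset flag
model_input = torch.randn(2, 4, 64, 64)   # fake latents: (batch, channels, height, width)

noise = torch.randn_like(model_input)
if noise_offset:
    # a (batch, channels, 1, 1) tensor broadcasts one constant over each latent channel
    noise += noise_offset * torch.randn(
        (model_input.shape[0], model_input.shape[1], 1, 1), device=model_input.device
    )
```
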

```diff
@@ -1788,6 +1794,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                 pipeline = StableDiffusionPipeline.from_pretrained(
                     args.pretrained_model_name_or_path,
                     vae=vae,
+                    tokenizer=tokenizer_one,
                     text_encoder=accelerator.unwrap_model(text_encoder_one),
                     unet=accelerator.unwrap_model(unet),
                     revision=args.revision,
```

```diff
@@ -1860,6 +1867,11 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
             unet_lora_layers=unet_lora_layers,
             text_encoder_lora_layers=text_encoder_lora_layers,
         )
+
+        if args.train_text_encoder_ti:
+            embeddings_path = f"{args.output_dir}/{args.output_dir}_emb.safetensors"
+            embedding_handler.save_embeddings(embeddings_path)
+
         images = []
         if args.validation_prompt and args.num_validation_images > 0:
             # Final inference
```

```diff
@@ -1895,6 +1907,18 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
             # load attention processors
             pipeline.load_lora_weights(args.output_dir)
 
+            # load new tokens
+            if args.train_text_encoder_ti:
+                state_dict = load_file(embeddings_path)
+                all_new_tokens = []
+                for key, value in token_abstraction_dict.items():
+                    all_new_tokens.extend(value)
+                pipeline.load_textual_inversion(
+                    state_dict["clip_l"],
+                    token=all_new_tokens,
+                    text_encoder=pipeline.text_encoder,
+                    tokenizer=pipeline.tokenizer,
+                )
             # run inference
             pipeline = pipeline.to(accelerator.device)
             generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
```

```diff
@@ -1917,11 +1941,6 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                         }
                     )
 
-        if args.train_text_encoder_ti:
-            embedding_handler.save_embeddings(
-                f"{args.output_dir}/{args.output_dir}_emb.safetensors",
-            )
-
         # Conver to WebUI format
         lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors")
         peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict)
```