# Advanced diffusion training examples

## Train Dreambooth LoRA with Stable Diffusion XL
> [!TIP]
> 💡 This example follows the techniques and recommended practices covered in the blog post: [LoRA training scripts of the world, unite!](https://huggingface.co/blog/sdxl_lora_advanced_script). Make sure to check it out before starting 🤗

[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text-to-image models like Stable Diffusion given just a few (3-5) images of a subject.

LoRA - Low-Rank Adaptation of Large Language Models - was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.
In a nutshell, LoRA adapts pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has a couple of advantages:
- Previous pretrained weights are kept frozen so that the model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114).
- Rank-decomposition matrices have significantly fewer parameters than the original model, which means that trained LoRA weights are easily portable.
- LoRA attention layers allow control over the extent to which the model is adapted toward new training images via a `scale` parameter (sketched right after this list).
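
For intuition, here is a minimal sketch of the idea - an illustration under simplified assumptions, not how this script (or the PEFT backend) actually implements it. A frozen linear layer is augmented with a trainable low-rank pair `A`, `B`, and `scale` controls the adapter's influence:

```python
# Illustrative LoRA sketch: frozen weight + scaled low-rank update (not the script's code)
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    def __init__(self, base: nn.Linear, rank: int = 8, scale: float = 1.0):
        super().__init__()
        self.base = base.requires_grad_(False)  # pretrained weights stay frozen
        self.A = nn.Parameter(torch.randn(rank, base.in_features) * 0.01)  # down-projection
        self.B = nn.Parameter(torch.zeros(base.out_features, rank))        # up-projection, zero-init
        self.scale = scale  # how strongly the adapter steers the model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # frozen path + scaled rank-decomposition update
        return self.base(x) + self.scale * (x @ self.A.T @ self.B.T)
```

Only `A` and `B` receive gradients, and with `scale=0` the output is exactly that of the pretrained layer - which is why LoRA weights are small and portable.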
[cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in
the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository.

The `train_dreambooth_lora_sdxl_advanced.py` script shows how to implement DreamBooth LoRA, combining the training process shown in `train_dreambooth_lora_sdxl.py` with
advanced features and techniques, inspired by and built upon contributions by [Nataniel Ruiz](https://twitter.com/natanielruizg): [Dreambooth](https://dreambooth.github.io), [Rinon Gal](https://twitter.com/RinonGal): [Textual Inversion](https://textual-inversion.github.io), [Ron Mokady](https://twitter.com/MokadyRon): [Pivotal Tuning](https://arxiv.org/abs/2106.05744), [Simo Ryu](https://twitter.com/cloneofsimo): [cog-sdxl](https://github.com/replicate/cog-sdxl),
[Kohya](https://twitter.com/kohya_tech/): [sd-scripts](https://github.com/kohya-ss/sd-scripts), [The Last Ben](https://twitter.com/__TheBen): [fast-stable-diffusion](https://github.com/TheLastBen/fast-stable-diffusion) ❤️
> [!NOTE]
> 💡 If this is your first time training a DreamBooth LoRA, congrats! 🥳
> You might want to familiarize yourself more with the techniques: [Dreambooth blog](https://huggingface.co/blog/dreambooth), [Using LoRA for Efficient Stable Diffusion Fine-Tuning blog](https://huggingface.co/blog/lora)

📚 Read more about the advanced features and best practices in this community-derived blog post: [LoRA training scripts of the world, unite!](https://huggingface.co/blog/sdxl_lora_advanced_script)


## Running locally with PyTorch

### Installing the dependencies

Before running the scripts, make sure to install the library's training dependencies:

**Important**

To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date, as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e .
```

Then cd into the `examples/advanced_diffusion_training` folder and run
```bash
pip install -r requirements.txt
```

And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:

```bash
accelerate config
```

Or for a default accelerate configuration without answering questions about your environment:

```bash
accelerate config default
```

Or if your environment doesn't support an interactive shell (e.g., a notebook):

```python
from accelerate.utils import write_basic_config
write_basic_config()
```

When running `accelerate config`, specifying torch compile mode to True can yield dramatic speedups.
Note also that we use the PEFT library as the backend for LoRA training; make sure to have `peft>=0.6.0` installed in your environment.
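
A quick way to verify the installed version (a trivial sanity check, nothing script-specific):

```python
# Confirm the PEFT backend meets the minimum version required for LoRA training
import peft

print(peft.__version__)  # should print 0.6.0 or newer
```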

### Pivotal Tuning
**Training with text encoder(s)**

Alongside the UNet, LoRA fine-tuning of the text encoders is also supported. In addition to the text encoder optimization
available with `train_dreambooth_lora_sdxl.py`, the advanced script also supports **pivotal tuning**.
[Pivotal tuning](https://huggingface.co/blog/sdxl_lora_advanced_script#pivotal-tuning) combines Textual Inversion with regular diffusion fine-tuning -
we insert new tokens into the text encoders of the model, instead of reusing existing ones.
We then optimize the newly-inserted token embeddings to represent the new concept.

To do so, just specify `--train_text_encoder_ti` while launching training (for regular text encoder optimization, use `--train_text_encoder`).
Please keep the following points in mind (a minimal sketch of the token-insertion idea follows this list):

* SDXL has two text encoders, so we fine-tune both using LoRA.
* When not fine-tuning the text encoders, we ALWAYS precompute the text embeddings to save memory.
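
The following is a rough, illustrative sketch of the token-insertion step using the standard 🤗 Transformers API - not the training script's exact code:

```python
# Illustrative sketch of pivotal tuning's token insertion (not the script's exact code)
from transformers import CLIPTextModel, CLIPTokenizer

model_id = "stabilityai/stable-diffusion-xl-base-1.0"
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")

# insert brand-new concept tokens instead of reusing existing vocabulary
tokenizer.add_tokens(["<s0>", "<s1>"])
text_encoder.resize_token_embeddings(len(tokenizer))

# freeze the encoder, then make only the token-embedding table trainable;
# in practice, gradients for all pre-existing embedding rows are also masked out
text_encoder.requires_grad_(False)
text_encoder.get_input_embeddings().requires_grad_(True)
```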


### 3D icon example

Now let's get our dataset. For this example we will use some cool images of 3D-rendered icons: https://huggingface.co/datasets/linoyts/3d_icon.

Let's first download it locally:

```python
from huggingface_hub import snapshot_download

local_dir = "./3d_icon"
snapshot_download(
    "LinoyTsaban/3d_icon",
    local_dir=local_dir, repo_type="dataset",
    ignore_patterns=".gitattributes",
)
```

Let's review some of the advanced features we're going to be using for this example:
- **custom captions**:
To use custom captioning, first ensure that you have the datasets library installed; if you don't, you can install it with:
```bash
pip install datasets
```

Now we'll simply specify the name of the dataset and caption column (in this case it's "prompt"):

```
--dataset_name=./3d_icon
--caption_column=prompt
```

You can also load a dataset straight from the Hugging Face Hub by specifying its name in `dataset_name`.
Look [here](https://huggingface.co/blog/sdxl_lora_advanced_script#custom-captioning) for more info on creating/loading your own caption dataset; a quick way to inspect one is shown below.
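
For example, a minimal sketch for peeking at a caption column with 🤗 Datasets (the column names below are assumptions based on this particular dataset; verify against your own):

```python
# Inspect the caption column of a Hub dataset before training
from datasets import load_dataset

dataset = load_dataset("linoyts/3d_icon", split="train")
print(dataset.column_names)   # expect something like ['image', 'prompt', ...]
print(dataset[0]["prompt"])   # the caption paired with the first training image
```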

- **optimizer**: for this example, we'll use [prodigy](https://huggingface.co/blog/sdxl_lora_advanced_script#adaptive-optimizers) - an adaptive optimizer
- **pivotal tuning**
- **min SNR gamma** (see the weighting sketch right after this list)
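
As referenced above, here's a minimal sketch of Min-SNR-gamma loss weighting - illustrative only, assuming ε-prediction and an `snr` tensor holding each sampled timestep's signal-to-noise ratio:

```python
import torch

def min_snr_weight(snr: torch.Tensor, gamma: float = 5.0) -> torch.Tensor:
    # clamp the per-timestep SNR at gamma, then divide by SNR:
    # high-SNR (low-noise) timesteps get down-weighted, which stabilizes training
    return torch.clamp(snr, max=gamma) / snr
```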

**Now, we can launch training:**

```bash
export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
export DATASET_NAME="./3d_icon"
export OUTPUT_DIR="3d-icon-SDXL-LoRA"
export VAE_PATH="madebyollin/sdxl-vae-fp16-fix"

accelerate launch train_dreambooth_lora_sdxl_advanced.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --pretrained_vae_model_name_or_path=$VAE_PATH \
  --dataset_name=$DATASET_NAME \
  --instance_prompt="3d icon in the style of TOK" \
  --validation_prompt="a TOK icon of an astronaut riding a horse, in the style of TOK" \
  --output_dir=$OUTPUT_DIR \
  --caption_column="prompt" \
  --mixed_precision="bf16" \
  --resolution=1024 \
  --train_batch_size=3 \
  --repeats=1 \
  --report_to="wandb" \
  --gradient_accumulation_steps=1 \
  --gradient_checkpointing \
  --learning_rate=1.0 \
  --text_encoder_lr=1.0 \
  --optimizer="prodigy" \
  --train_text_encoder_ti \
  --train_text_encoder_ti_frac=0.5 \
  --snr_gamma=5.0 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --rank=8 \
  --max_train_steps=1000 \
  --checkpointing_steps=2000 \
  --seed="0" \
  --push_to_hub
```

To better track our training experiments, we're using the following flags in the command above:

* `report_to="wandb"` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
* `validation_prompt` and `validation_epochs` allow the script to do a few validation inference runs, letting us qualitatively check whether training is progressing as expected.

Our experiments were conducted on a single 40GB A100 GPU.


### Inference

Once training is done, we can perform inference like so:
1. Start by loading the UNet LoRA weights:
```python
import torch
from huggingface_hub import hf_hub_download
from diffusers import DiffusionPipeline
from safetensors.torch import load_file

username = "linoyts"
repo_id = f"{username}/3d-icon-SDXL-LoRA"

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
).to("cuda")

pipe.load_lora_weights(repo_id, weight_name="pytorch_lora_weights.safetensors")
```
2. Now we load the pivotal tuning embeddings:

```python
embedding_path = hf_hub_download(repo_id=repo_id, filename="3d-icon-SDXL-LoRA_emb.safetensors", repo_type="model")

state_dict = load_file(embedding_path)
# load embeddings of text_encoder 1 (CLIP ViT-L/14)
pipe.load_textual_inversion(state_dict["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
# load embeddings of text_encoder 2 (CLIP ViT-G/14)
pipe.load_textual_inversion(state_dict["clip_g"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
```

3. Let's generate images:

```python
instance_token = "<s0><s1>"
prompt = f"a {instance_token} icon of an orange llama eating ramen, in the style of {instance_token}"

image = pipe(prompt=prompt, num_inference_steps=25, cross_attention_kwargs={"scale": 1.0}).images[0]
image.save("llama.png")
```

### Comfy UI / AUTOMATIC1111 Inference
The new script fully supports textual inversion loading with Comfy UI and AUTOMATIC1111 formats!

**AUTOMATIC1111 / SD.Next** \
In AUTOMATIC1111/SD.Next we will load a LoRA and a textual embedding at the same time.
- *LoRA*: Besides the diffusers format, the script will also train a WebUI-compatible LoRA. It is generated as `{your_lora_name}.safetensors`. You can then include it in your `models/Lora` directory.
- *Embedding*: the embedding is the same for diffusers and WebUI. You can download your `{lora_name}_emb.safetensors` file from a trained model, and include it in your `embeddings` directory.

You can then run inference by prompting `a y2k_emb webpage about the movie Mean Girls <lora:y2k:0.9>`. You can use the `y2k_emb` token normally, including increasing its weight by doing `(y2k_emb:1.2)`.

**ComfyUI** \
In ComfyUI we will load a LoRA and a textual embedding at the same time.
- *LoRA*: Besides the diffusers format, the script will also train a ComfyUI-compatible LoRA. It is generated as `{your_lora_name}.safetensors`. You can then include it in your `models/Lora` directory. Then you will load the LoRALoader node and hook that up with your model and CLIP. [Official guide for loading LoRAs](https://comfyanonymous.github.io/ComfyUI_examples/lora/)
- *Embedding*: the embedding is the same for diffusers and WebUI. You can download your `{lora_name}_emb.safetensors` file from a trained model, and include it in your `models/embeddings` directory and use it in your prompts like `embedding:y2k_emb`. [Official guide for loading embeddings](https://comfyanonymous.github.io/ComfyUI_examples/textual_inversion_embeddings/).

### Specifying a better VAE

SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument, `--pretrained_vae_model_name_or_path`, that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
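
If you trained with a custom VAE, you will typically want the same VAE at inference time too. A minimal sketch, building on the inference setup above:

```python
# Swap in the fp16-fix VAE when loading the pipeline for inference
import torch
from diffusers import AutoencoderKL, DiffusionPipeline

vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    vae=vae,
    torch_dtype=torch.float16,
    variant="fp16",
).to("cuda")
```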


### Tips and Tricks
Check out [these recommended practices](https://huggingface.co/blog/sdxl_lora_advanced_script#additional-good-practices)

## Running on Colab Notebook
Check out [this notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_DreamBooth_LoRA_advanced_example.ipynb)
to train using the advanced features (including pivotal tuning), and [this notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_DreamBooth_LoRA_.ipynb) to train on a free Colab, using some of the advanced features (excluding pivotal tuning).