diff --git a/docs/source/en/api/pipelines/bria_fibo_edit.md b/docs/source/en/api/pipelines/bria_fibo_edit.md
new file mode 100644
index 000000000000..b46dd78cdb90
--- /dev/null
+++ b/docs/source/en/api/pipelines/bria_fibo_edit.md
@@ -0,0 +1,33 @@
+
+
+# Bria Fibo Edit
+
+Fibo Edit is an 8B-parameter image-to-image model that introduces a new paradigm of structured control: it operates on JSON inputs paired with source images to enable deterministic and repeatable editing workflows.
+Featuring native masking for granular precision, it moves beyond simple prompt-based diffusion to offer explicit, interpretable control optimized for production environments.
+Its lightweight architecture is designed for deep customization, empowering researchers to build specialized "Edit" models for domain-specific tasks while delivering top-tier aesthetic quality.
+
+## Usage
+_As the model is gated, before using it with diffusers you first need to go to the [Bria Fibo Hugging Face page](https://huggingface.co/briaai/Fibo-Edit), fill in the form, and accept the gate. Once you have been granted access, log in so that your system knows you’ve accepted the gate._
+
+Use the command below to log in:
+
+```bash
+hf auth login
+```
+
+
+## BriaFiboEditPipeline
+
+[[autodoc]] BriaFiboEditPipeline
+  - all
+  - __call__
\ No newline at end of file
diff --git a/examples/dreambooth/README_fibo_edit.md b/examples/dreambooth/README_fibo_edit.md
new file mode 100644
index 000000000000..f2e5e88be7d8
--- /dev/null
+++ b/examples/dreambooth/README_fibo_edit.md
@@ -0,0 +1,87 @@
+# DreamBooth LoRA training example for Bria Fibo Edit
+
+[DreamBooth](https://huggingface.co/papers/2208.12242) is a method to personalize text-to-image models given just a few images of a subject.
+
+The `train_dreambooth_fibo_edit.py` script shows how to implement LoRA fine-tuning for [Bria Fibo Edit](https://huggingface.co/briaai/Fibo-edit), an image editing model.
+
+## Running locally with PyTorch
+
+### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install -e .
+```
+
+Then cd into the `examples/dreambooth` folder and run:
+```bash
+pip install -r requirements_fibo_edit.txt
+```
+
+And initialize an [Accelerate](https://github.com/huggingface/accelerate/) environment:
+
+```bash
+accelerate config default
+```
+
+### Dataset format
+
+The training script expects a dataset with the following columns:
+- `input_image`: Source image (before editing)
+- `image`: Target image (after editing)
+- `caption`: Edit instruction in JSON format
+
+You can use a Hugging Face dataset via `--dataset_name` or a local directory via `--instance_data_dir`.
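+
+If you want to assemble such a dataset programmatically, here is a minimal sketch using the 🤗 Datasets library (`pip install datasets`). Only the three column names above come from this README; the tiny generated images, the JSON key `edit_instruction`, and the Hub repository name are illustrative placeholders rather than a schema defined by Fibo Edit.
+
+```python
+# Hypothetical sketch: build a one-row Fibo Edit training dataset.
+# Images are stored as raw PNG bytes and the caption as a JSON string.
+import io
+import json
+
+from datasets import Dataset
+from PIL import Image
+
+
+def png_bytes(color):
+    # Encode a tiny solid-color image as PNG bytes (stand-in for a real photo).
+    buffer = io.BytesIO()
+    Image.new("RGB", (64, 64), color).save(buffer, format="PNG")
+    return buffer.getvalue()
+
+
+dataset = Dataset.from_dict(
+    {
+        "input_image": [png_bytes("gray")],  # source image (before editing)
+        "image": [png_bytes("orange")],  # target image (after editing)
+        "caption": [json.dumps({"edit_instruction": "turn the sky into a sunset"})],
+    }
+)
+# dataset.push_to_hub("your-username/fibo-edit-dataset")  # then train with --dataset_name=your-username/fibo-edit-dataset
+```
+
+Storing the images as raw bytes matches the `open_image_from_binary` helper that the training script uses to decode dataset entries.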
+
+### Training
+
+```bash
+export MODEL_NAME="briaai/Fibo-edit"
+export DATASET_NAME="your-dataset"
+export OUTPUT_DIR="fibo-edit-dreambooth-lora"
+
+accelerate launch train_dreambooth_fibo_edit.py \
+  --pretrained_model_name_or_path=$MODEL_NAME \
+  --dataset_name=$DATASET_NAME \
+  --output_dir=$OUTPUT_DIR \
+  --mixed_precision="bf16" \
+  --train_batch_size=1 \
+  --gradient_accumulation_steps=4 \
+  --learning_rate=1e-4 \
+  --lr_scheduler="cosine_with_warmup" \
+  --lr_warmup_steps=100 \
+  --max_train_steps=1500 \
+  --lora_rank=128 \
+  --checkpointing_steps=250 \
+  --seed=10
+```
+
+### Key arguments
+
+| Argument | Default | Description |
+|----------|---------|-------------|
+| `--lora_rank` | 128 | LoRA rank for fine-tuning |
+| `--learning_rate` | 1e-4 | Initial learning rate |
+| `--lr_scheduler` | cosine_with_warmup | Learning rate scheduler |
+| `--optimizer` | AdamW | Optimizer (AdamW or prodigy) |
+| `--gradient_checkpointing` | 1 | Enable gradient checkpointing to save memory |
+| `--mixed_precision` | bf16 | Mixed precision training mode |
+
+### Resume from checkpoint
+
+To resume training from a checkpoint:
+
+```bash
+accelerate launch train_dreambooth_fibo_edit.py \
+  ... \
+  --resume_from_checkpoint="latest"
+```
+
+Or pass an explicit checkpoint path:
+
+```bash
+--resume_from_checkpoint="/path/to/checkpoint_500"
+```
diff --git a/examples/dreambooth/requirements_fibo_edit.txt b/examples/dreambooth/requirements_fibo_edit.txt
new file mode 100644
index 000000000000..fdede0e1d573
--- /dev/null
+++ b/examples/dreambooth/requirements_fibo_edit.txt
@@ -0,0 +1,7 @@
+accelerate>=0.31.0
+torchvision
+transformers>=4.41.2
+peft>=0.11.1
+ujson
+Pillow
+tqdm
diff --git a/examples/dreambooth/test_dreambooth_fibo_edit.py b/examples/dreambooth/test_dreambooth_fibo_edit.py
new file mode 100644
index 000000000000..20ef9d0ebf2d
--- /dev/null
+++ b/examples/dreambooth/test_dreambooth_fibo_edit.py
@@ -0,0 +1,445 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +import logging +import os +import sys +import tempfile +import unittest + +import safetensors +import torch + + +sys.path.append("..") +from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402 + + +logging.basicConfig(level=logging.DEBUG) + +logger = logging.getLogger() +stream_handler = logging.StreamHandler(sys.stdout) +logger.addHandler(stream_handler) + + +class DreamBoothFiboEditUnitTests(unittest.TestCase): + """Unit tests for helper functions in train_dreambooth_fibo_edit.py""" + + def test_find_closest_resolution(self): + """Test the find_closest_resolution function for aspect ratio selection.""" + # Import the function from the training script + sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + from train_dreambooth_fibo_edit import RESOLUTIONS_1k, find_closest_resolution + + # Test square image (1:1 aspect ratio) + width, height = find_closest_resolution(1024, 1024) + self.assertEqual((width, height), (1024, 1024)) + + # Test landscape image + width, height = find_closest_resolution(1920, 1080) + # 1920/1080 = 1.778, closest to 1.750 which maps to (1344, 768) + self.assertIn((width, height), list(RESOLUTIONS_1k.values())) + self.assertGreater(width, height) # Should be landscape + + # Test portrait image + width, height = find_closest_resolution(1080, 1920) + # 1080/1920 = 0.5625, closest to 0.67 which maps to (832, 1248) + self.assertIn((width, height), list(RESOLUTIONS_1k.values())) + self.assertLess(width, height) # Should be portrait + + def test_clean_json_caption_valid(self): + """Test clean_json_caption with valid JSON.""" + sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + from train_dreambooth_fibo_edit import clean_json_caption + + # Test valid JSON + valid_json = '{"prompt": "a photo of a cat", "style": "realistic"}' + result = clean_json_caption(valid_json) + self.assertIsInstance(result, str) + # Should be valid JSON after cleaning + import json + + parsed = json.loads(result) + self.assertEqual(parsed["prompt"], "a photo of a cat") + + def test_clean_json_caption_invalid(self): + """Test clean_json_caption with invalid JSON raises ValueError.""" + sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + from train_dreambooth_fibo_edit import clean_json_caption + + # Test invalid JSON + invalid_json = "not a valid json" + with self.assertRaises(ValueError): + clean_json_caption(invalid_json) + + def test_create_attention_matrix(self): + """Test the create_attention_matrix function.""" + sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + from train_dreambooth_fibo_edit import create_attention_matrix + + # Create a simple attention mask + attention_mask = torch.tensor([[1, 1, 0], [1, 1, 1]], dtype=torch.float32) + result = create_attention_matrix(attention_mask) + + # Check output shape + self.assertEqual(result.shape, (2, 3, 3)) + + # Check that 1s map to 0 (keep) and 0s map to -inf (ignore) + # First batch: [1,1,0] -> matrix where positions with mask 0 become -inf + self.assertEqual(result[0, 0, 0].item(), 0.0) # 1*1 = 1 -> 0 + self.assertEqual(result[0, 0, 2].item(), float("-inf")) # 1*0 = 0 -> -inf + self.assertEqual(result[0, 2, 0].item(), float("-inf")) # 0*1 = 0 -> -inf + + def test_pad_embedding(self): + """Test the pad_embedding function.""" + sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + from train_dreambooth_fibo_edit import pad_embedding + + # Create a sample embedding + batch_size, seq_len, dim = 2, 5, 64 + embedding = torch.randn(batch_size, seq_len, dim) + 
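+        # Pad the 5-token embeddings up to max_tokens and verify the padded positions are masked out.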
max_tokens = 10 + + padded_embedding, attention_mask = pad_embedding(embedding, max_tokens) + + # Check shapes + self.assertEqual(padded_embedding.shape, (batch_size, max_tokens, dim)) + self.assertEqual(attention_mask.shape, (batch_size, max_tokens)) + + # Check attention mask values + # First seq_len positions should be 1 (unmasked) + self.assertTrue(torch.all(attention_mask[:, :seq_len] == 1)) + # Remaining positions should be 0 (masked/padded) + self.assertTrue(torch.all(attention_mask[:, seq_len:] == 0)) + + # Check that original content is preserved + self.assertTrue(torch.allclose(padded_embedding[:, :seq_len, :], embedding)) + + def test_shifted_logit_normal_timestep_sampler(self): + """Test the ShiftedLogitNormalTimestepSampler.""" + sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + from train_dreambooth_fibo_edit import ShiftedLogitNormalTimestepSampler + + sampler = ShiftedLogitNormalTimestepSampler(std=1.0) + + batch_size = 10 + seq_length = 1024 + + timesteps = sampler.sample(batch_size, seq_length) + + # Check output shape + self.assertEqual(timesteps.shape, (batch_size,)) + + # Check that all timesteps are in valid range [0, 1] + self.assertTrue(torch.all(timesteps >= 0)) + self.assertTrue(torch.all(timesteps <= 1)) + + def test_uniform_timestep_sampler(self): + """Test the UniformTimestepSampler.""" + sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + from train_dreambooth_fibo_edit import UniformTimestepSampler + + sampler = UniformTimestepSampler(min_value=0.0, max_value=1.0) + + batch_size = 100 + + timesteps = sampler.sample(batch_size) + + # Check output shape + self.assertEqual(timesteps.shape, (batch_size,)) + + # Check that all timesteps are in valid range [0, 1] + self.assertTrue(torch.all(timesteps >= 0)) + self.assertTrue(torch.all(timesteps <= 1)) + + def test_shifted_stretched_logit_normal_timestep_sampler(self): + """Test the ShiftedStretchedLogitNormalTimestepSampler.""" + sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + from train_dreambooth_fibo_edit import ShiftedStretchedLogitNormalTimestepSampler + + sampler = ShiftedStretchedLogitNormalTimestepSampler(std=1.0, uniform_prob=0.1) + + batch_size = 100 + seq_length = 1024 + + timesteps = sampler.sample(batch_size, seq_length) + + # Check output shape + self.assertEqual(timesteps.shape, (batch_size,)) + + # Check that all timesteps are in valid range [0, 1] + self.assertTrue(torch.all(timesteps >= 0)) + self.assertTrue(torch.all(timesteps <= 1)) + + def test_resolutions_1k_coverage(self): + """Test that RESOLUTIONS_1k covers common aspect ratios.""" + sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + from train_dreambooth_fibo_edit import RESOLUTIONS_1k + + # Check that we have expected aspect ratios + aspect_ratios = list(RESOLUTIONS_1k.keys()) + + # Should have portrait (< 1), square (= 1), and landscape (> 1) ratios + portrait_ratios = [r for r in aspect_ratios if r < 1] + square_ratios = [r for r in aspect_ratios if r == 1] + landscape_ratios = [r for r in aspect_ratios if r > 1] + + self.assertGreater(len(portrait_ratios), 0, "Should have portrait aspect ratios") + self.assertEqual(len(square_ratios), 1, "Should have exactly one square ratio") + self.assertGreater(len(landscape_ratios), 0, "Should have landscape aspect ratios") + + # All resolutions should be divisible by 16 (for VAE) + for ratio, (w, h) in RESOLUTIONS_1k.items(): + self.assertEqual(w % 16, 0, f"Width {w} for ratio {ratio} should be divisible by 16") + self.assertEqual(h % 16, 
0, f"Height {h} for ratio {ratio} should be divisible by 16") + + +@unittest.skipUnless( + os.environ.get("RUN_SLOW", "0") == "1", + "Slow tests require RUN_SLOW=1 environment variable and a tiny test model", +) +class DreamBoothLoRAFiboEdit(ExamplesTestsAccelerate): + """ + Integration tests for train_dreambooth_fibo_edit.py. + + NOTE: These tests require a tiny test model at 'hf-internal-testing/tiny-bria-fibo-edit-pipe' + or the pretrained_model_name_or_path to be updated to point to an available tiny model. + + To run these tests, set RUN_SLOW=1 and ensure the test model is available. + """ + + # NOTE: Update this path once a tiny test model is available + pretrained_model_name_or_path = "hf-internal-testing/tiny-bria-fibo-edit-pipe" + script_path = "examples/dreambooth/train_dreambooth_fibo_edit.py" + # For fibo-edit, we need a dataset with source images, target images, and JSON captions + # Using a test dataset that should be set up for this purpose + dataset_name = "hf-internal-testing/fibo-edit-test-dataset" + + def test_dreambooth_lora_fibo_edit(self): + """Test basic LoRA training for Fibo Edit.""" + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path {self.pretrained_model_name_or_path} + --dataset_name {self.dataset_name} + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --learning_rate 1e-4 + --lr_scheduler constant + --lr_warmup_steps 0 + --lora_rank 4 + --output_dir {tmpdir} + --mixed_precision no + --checkpointing_steps 1000 + """.split() + + run_command(self._launch_args + test_args) + + # Check that checkpoint was saved + checkpoint_dirs = [d for d in os.listdir(tmpdir) if d.startswith("checkpoint")] + # Either we have checkpoints or final output + self.assertTrue( + len(checkpoint_dirs) > 0 or os.path.exists(os.path.join(tmpdir, "checkpoint_final")), + "Expected checkpoint directories to be created", + ) + + def test_dreambooth_lora_fibo_edit_with_gradient_checkpointing(self): + """Test LoRA training with gradient checkpointing enabled.""" + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path {self.pretrained_model_name_or_path} + --dataset_name {self.dataset_name} + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --learning_rate 1e-4 + --lr_scheduler constant + --lr_warmup_steps 0 + --lora_rank 4 + --gradient_checkpointing 1 + --output_dir {tmpdir} + --mixed_precision no + --checkpointing_steps 1000 + """.split() + + run_command(self._launch_args + test_args) + + # Check that checkpoint was saved + checkpoint_dirs = [d for d in os.listdir(tmpdir) if d.startswith("checkpoint")] + self.assertTrue( + len(checkpoint_dirs) > 0 or os.path.exists(os.path.join(tmpdir, "checkpoint_final")), + "Expected checkpoint directories to be created", + ) + + def test_dreambooth_lora_fibo_edit_checkpointing(self): + """Test checkpointing functionality.""" + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path {self.pretrained_model_name_or_path} + --dataset_name {self.dataset_name} + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 4 + --learning_rate 1e-4 + --lr_scheduler constant + --lr_warmup_steps 0 + --lora_rank 4 + --checkpointing_steps 2 + --output_dir {tmpdir} + --mixed_precision no + """.split() + + run_command(self._launch_args + test_args) + + # Check that intermediate checkpoints were created 
+ checkpoint_dirs = {d for d in os.listdir(tmpdir) if d.startswith("checkpoint")} + # Should have checkpoint at step 2 (and possibly final) + self.assertTrue( + "checkpoint_2" in checkpoint_dirs or "checkpoint_final" in checkpoint_dirs, + f"Expected checkpoints, found: {checkpoint_dirs}", + ) + + def test_dreambooth_lora_fibo_edit_resume_from_checkpoint(self): + """Test resume from checkpoint functionality.""" + with tempfile.TemporaryDirectory() as tmpdir: + # First training run + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path {self.pretrained_model_name_or_path} + --dataset_name {self.dataset_name} + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --learning_rate 1e-4 + --lr_scheduler constant + --lr_warmup_steps 0 + --lora_rank 4 + --checkpointing_steps 2 + --output_dir {tmpdir} + --mixed_precision no + """.split() + + run_command(self._launch_args + test_args) + + # Check that checkpoint was created + checkpoint_dirs = [d for d in os.listdir(tmpdir) if d.startswith("checkpoint")] + self.assertGreater(len(checkpoint_dirs), 0, "Expected checkpoint to be created") + + # Resume training + resume_args = f""" + {self.script_path} + --pretrained_model_name_or_path {self.pretrained_model_name_or_path} + --dataset_name {self.dataset_name} + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 4 + --learning_rate 1e-4 + --lr_scheduler constant + --lr_warmup_steps 0 + --lora_rank 4 + --checkpointing_steps 2 + --resume_from_checkpoint latest + --output_dir {tmpdir} + --mixed_precision no + """.split() + + run_command(self._launch_args + resume_args) + + # Check that training continued and created more checkpoints + final_checkpoint_dirs = [d for d in os.listdir(tmpdir) if d.startswith("checkpoint")] + self.assertGreater( + len(final_checkpoint_dirs), + len(checkpoint_dirs), + "Expected more checkpoints after resuming", + ) + + def test_dreambooth_lora_fibo_edit_different_lr_schedulers(self): + """Test different learning rate schedulers.""" + schedulers = ["constant", "cosine_with_warmup"] + + for scheduler in schedulers: + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path {self.pretrained_model_name_or_path} + --dataset_name {self.dataset_name} + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --learning_rate 1e-4 + --lr_scheduler {scheduler} + --lr_warmup_steps 1 + --lora_rank 4 + --output_dir {tmpdir} + --mixed_precision no + --checkpointing_steps 1000 + """.split() + + run_command(self._launch_args + test_args) + + # Check that training completed + checkpoint_dirs = [d for d in os.listdir(tmpdir) if d.startswith("checkpoint")] + self.assertTrue( + len(checkpoint_dirs) > 0 or os.path.exists(os.path.join(tmpdir, "checkpoint_final")), + f"Expected checkpoints with scheduler {scheduler}", + ) + + def test_dreambooth_lora_fibo_edit_lora_weights_structure(self): + """Test that LoRA weights have the correct structure.""" + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path {self.pretrained_model_name_or_path} + --dataset_name {self.dataset_name} + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --learning_rate 1e-4 + --lr_scheduler constant + --lr_warmup_steps 0 + --lora_rank 4 + --output_dir {tmpdir} + --mixed_precision no + --checkpointing_steps 1 + """.split() + + run_command(self._launch_args + test_args) + + # Find the 
checkpoint directory + checkpoint_dirs = [d for d in os.listdir(tmpdir) if d.startswith("checkpoint")] + self.assertGreater(len(checkpoint_dirs), 0, "Expected checkpoint to be created") + + # Check for LoRA weights file in checkpoint + checkpoint_path = os.path.join(tmpdir, checkpoint_dirs[0]) + lora_weights_path = os.path.join(checkpoint_path, "pytorch_lora_weights.safetensors") + + if os.path.exists(lora_weights_path): + # Load and verify LoRA weights + lora_state_dict = safetensors.torch.load_file(lora_weights_path) + + # Check that all keys contain "lora" + is_lora = all("lora" in k for k in lora_state_dict.keys()) + self.assertTrue(is_lora, "All LoRA state dict keys should contain 'lora'") + + # Check that all keys start with "transformer" + starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys()) + self.assertTrue(starts_with_transformer, "All keys should start with 'transformer'") + + +if __name__ == "__main__": + unittest.main() diff --git a/examples/dreambooth/train_dreambooth_fibo_edit.py b/examples/dreambooth/train_dreambooth_fibo_edit.py new file mode 100644 index 000000000000..8f346f79daa4 --- /dev/null +++ b/examples/dreambooth/train_dreambooth_fibo_edit.py @@ -0,0 +1,1452 @@ +import abc +import argparse +import io +import itertools +import json +import logging +import os +import random +from datetime import datetime +from pathlib import Path +from typing import List, Union + +import torch +import transformers +import ujson +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ( + DistributedDataParallelKwargs, + ProjectConfiguration, + set_seed, +) +from huggingface_hub import HfFolder +from peft import LoraConfig, set_peft_model_state_dict +from peft.utils import get_peft_model_state_dict +from PIL import Image +from PIL.ImageOps import exif_transpose +from torch.optim import Optimizer +from torch.optim.lr_scheduler import LambdaLR +from torch.utils.data import Dataset +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer, get_cosine_schedule_with_warmup + +import diffusers +from diffusers import AutoencoderKLWan, BriaFiboEditPipeline +from diffusers.loaders.lora_pipeline import FluxLoraLoaderMixin +from diffusers.models.transformers.transformer_bria_fibo import ( + BriaFiboTransformer2DModel, +) +from diffusers.optimization import get_scheduler +from diffusers.training_utils import cast_training_params +from diffusers.utils import convert_unet_state_dict_to_peft + + +# Set Logger +logger = get_logger(__name__, log_level="INFO") + + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default="briaai/Fibo-edit", + required=False, + ) + parser.add_argument( + "--output_dir", + type=str, + default="fibo-edit-dreambooth-lora", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--seed", + type=int, + default=10, + help="A seed for reproducible training.", + ) + parser.add_argument( + "--max_sequence_length", + type=int, + default=3000, + help="Maximum sequence length to use with with the T5 text encoder", + ) + parser.add_argument( + "--train_batch_size", + type=int, + default=1, + help="Batch size (per device) for the training dataloader.", + ) + parser.add_argument( + "--max_train_steps", + type=int, + default=1501, + 
help="Total number of training steps to perform.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=4, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="cosine_with_warmup", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup", "cosine_with_warmup", "constant_with_warmup_cosine_decay"' + ), + ) + parser.add_argument( + "--constant_steps", + type=int, + default=-1, + help=("Amount of constsnt lr steps"), + ) + parser.add_argument( + "--lr_warmup_steps", + type=int, + default=100, + help="Number of steps for the warmup in the lr scheduler.", + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + default=True, + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument( + "--adam_beta1", + type=float, + default=0.9, + help="The beta1 parameter for the Adam and Prodigy optimizers.", + ) + parser.add_argument( + "--adam_beta2", + type=float, + default=0.999, + help="The beta2 parameter for the Adam and Prodigy optimizers.", + ) + parser.add_argument( + "--adam_weight_decay", + type=float, + default=1e-3, + help="Weight decay to use.", + ) + parser.add_argument( + "--adam_epsilon", + type=float, + default=1e-15, + help="Epsilon value for the Adam optimizer and Prodigy optimizers.", + ) + parser.add_argument( + "--optimizer", + type=str, + default="AdamW", + help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'), + ) + parser.add_argument( + "--prodigy_beta3", + type=float, + default=None, + help="coefficients for computing the Prodigy stepsize using running averages. If set to None, " + "uses the value of square root of beta2. Ignored if optimizer is adamW", + ) + parser.add_argument( + "--prodigy_decouple", + type=bool, + default=True, + help="Use AdamW style decoupled weight decay", + ) + parser.add_argument( + "--prodigy_use_bias_correction", + type=bool, + default=True, + help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW", + ) + parser.add_argument( + "--prodigy_safeguard_warmup", + type=bool, + default=True, + help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. " + "Ignored if optimizer is adamW", + ) + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default="bf16", + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. 
Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=250, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--drop_rate_cfg", + type=float, + default=0.0, + help="Rate for Classifier Free Guidance dropping.", + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default="no", + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_checkpointing", + type=int, + default=1, + required=False, + help="Path to pretrained ELLA model", + ) + parser.add_argument( + "--lora_rank", + type=int, + default=128, + ) + + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--instance_data_dir", + type=str, + default=None, + help=("A folder containing the training data. "), + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument( + "--image_column", + type=str, + default="image", + help="The column of the dataset containing the target image. 
By " + "default, the standard Image Dataset maps out 'file_name' " + "to 'image'.", + ) + parser.add_argument( + "--caption_column", + type=str, + default="caption", + help="The column of the dataset containing the instance prompt for each image", + ) + ( + parser.add_argument( + "--repeats", + type=int, + default=1, + help="How many times to repeat the training data.", + ), + ) + parser.add_argument( + "--input_image_column", + type=str, + default="input_image", + help="The column of the dataset containing the source image.", + ) + args = parser.parse_args() + return args + + +# Resolution mapping for dynamic aspect ratio selection +RESOLUTIONS_1k = { + 0.67: (832, 1248), + 0.778: (896, 1152), + 0.883: (960, 1088), + 1.000: (1024, 1024), + 1.133: (1088, 960), + 1.286: (1152, 896), + 1.462: (1216, 832), + 1.600: (1280, 800), + 1.750: (1344, 768), +} + + +def find_closest_resolution(image_width, image_height): + """Find the closest aspect ratio from RESOLUTIONS_1k and return the target dimensions.""" + image_aspect = image_width / image_height + aspect_ratios = list(RESOLUTIONS_1k.keys()) + closest_ratio = min(aspect_ratios, key=lambda x: abs(x - image_aspect)) + return RESOLUTIONS_1k[closest_ratio] + + +def create_attention_matrix(attention_mask): + attention_matrix = torch.einsum("bi,bj->bij", attention_mask, attention_mask) + # convert to 0 - keep, -inf ignore + attention_matrix = torch.where( + attention_matrix == 1, 0.0, -torch.inf + ) # Apply -inf to ignored tokens for nulling softmax score + return attention_matrix + + +@torch.no_grad() +def get_smollm_prompt_embeds( + tokenizer: AutoTokenizer, + text_encoder: AutoModelForCausalLM, + prompts: Union[str, List[str]] = None, + max_sequence_length: int = 2048, +): + prompts = [prompts] if isinstance(prompts, str) else prompts + bot_token_id = 128000 # same as Llama + + if prompts[0] == "": + bs = len(prompts) + assert all(p == "" for p in prompts) + text_input_ids = torch.zeros([bs, 1], dtype=torch.int64, device=text_encoder.device) + bot_token_id + attention_mask = torch.ones([bs, 1], dtype=torch.int64, device=text_encoder.device) + else: + text_inputs = tokenizer( + prompts, + padding="longest", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids.to(text_encoder.device) + attention_mask = text_inputs.attention_mask.to(text_encoder.device) + + if len(prompts) == 1: + assert (attention_mask == 1).all() + + hidden_states = text_encoder( + text_input_ids, attention_mask=attention_mask, output_hidden_states=True + ).hidden_states + # We need a 4096 dim so since we have 2048 we take last 2 layers + prompt_embeds = torch.concat([hidden_states[-1], hidden_states[-2]], dim=-1) + + return prompt_embeds, hidden_states, attention_mask + + +def open_image_from_binary(binary_data): + return Image.open(io.BytesIO(binary_data)) + + +def pad_embedding(prompt_embeds, max_tokens): + # Pads a tensor which is not masked, i.e. 
the "initial" tensor mask is 1's + # We extend the tokens to max tokens and provide a mask to differentiate the masked areas + b, seq_len, dim = prompt_embeds.shape + padding = torch.zeros( + (b, max_tokens - seq_len, dim), + dtype=prompt_embeds.dtype, + device=prompt_embeds.device, + ) + attentions_mask = torch.zeros((b, max_tokens), dtype=prompt_embeds.dtype, device=prompt_embeds.device) + attentions_mask[:, :seq_len] = 1 # original tensor is not masked + prompt_embeds = torch.concat([prompt_embeds, padding], dim=1) + + return prompt_embeds, attentions_mask + + +class ShiftedLogitNormalTimestepSampler: + """ + Samples timesteps from a shifted logit-normal distribution, + where the shift is determined by the sequence length. + """ + + def __init__(self, std: float = 1.0): + self.std = std + + def sample(self, batch_size: int, seq_length: int, device: torch.device = None) -> torch.Tensor: + """Sample timesteps for a batch from a shifted logit-normal distribution. + + Args: + batch_size: Number of timesteps to sample + seq_length: Length of the sequence being processed, used to determine the shift + device: Device to place the samples on + + Returns: + Tensor of shape (batch_size,) containing timesteps sampled from a shifted + logit-normal distribution, where the shift is determined by seq_length + """ + shift = self._get_shift_for_sequence_length(seq_length) + normal_samples = torch.randn((batch_size,), device=device) * self.std + shift + sigmas = torch.sigmoid(normal_samples) + return sigmas + + def sample_for(self, batch: torch.Tensor) -> torch.Tensor: + """Sample timesteps for a specific batch tensor. + + Args: + batch: Input tensor of shape (batch_size, seq_length, ...) + + Returns: + Tensor of shape (batch_size,) containing timesteps sampled from a shifted + logit-normal distribution, where the shift is determined by the sequence length + of the input batch + + Raises: + ValueError: If the input batch does not have 3 dimensions + """ + if batch.ndim != 3: + raise ValueError(f"Batch should have 3 dimensions, got {batch.ndim}") + + batch_size, seq_length, _ = batch.shape + return self.sample(batch_size, seq_length, device=batch.device) + + @staticmethod + def _get_shift_for_sequence_length( + seq_length: int, + min_tokens: int = 256, + max_tokens: int = 4096, + min_shift: float = 0.5, + max_shift: float = 1.15, + ) -> float: + # Calculate the shift value for a given sequence length using linear interpolation + # between min_shift and max_shift based on sequence length. + m = (max_shift - min_shift) / (max_tokens - min_tokens) # Calculate slope + b = min_shift - m * min_tokens # Calculate y-intercept + shift = m * seq_length + b # Apply linear equation y = mx + b + return shift + + +class TimestepSampler(abc.ABC): + """Base class for timestep samplers. + + Timestep samplers are used to sample timesteps for diffusion models. + They should implement both sample() and sample_for() methods. + """ + + def sample( + self, + batch_size: int, + seq_length: int | None = None, + device: torch.device = None, + ) -> torch.Tensor: + """Sample timesteps for a batch. + + Args: + batch_size: Number of timesteps to sample + seq_length: (optional) Length of the sequence being processed + device: Device to place the samples on + + Returns: + Tensor of shape (batch_size,) containing timesteps + """ + raise NotImplementedError + + def sample_for(self, batch: torch.Tensor) -> torch.Tensor: + """Sample timesteps for a specific batch tensor. + + Args: + batch: Input tensor of shape (batch_size, seq_length, ...) 
+ + Returns: + Tensor of shape (batch_size,) containing timesteps + """ + raise NotImplementedError + + +class UniformTimestepSampler(TimestepSampler): + """Samples timesteps uniformly between min_value and max_value (default 0 and 1).""" + + def __init__(self, min_value: float = 0.0, max_value: float = 1.0): + self.min_value = min_value + self.max_value = max_value + + def sample( + self, + batch_size: int, + seq_length: int | None = None, + device: torch.device = None, + ) -> torch.Tensor: # noqa: ARG002 + return torch.rand(batch_size, device=device) * (self.max_value - self.min_value) + self.min_value + + def sample_for(self, batch: torch.Tensor) -> torch.Tensor: + if batch.ndim != 3: + raise ValueError(f"Batch should have 3 dimensions, got {batch.ndim}") + + batch_size, seq_length, _ = batch.shape + return self.sample(batch_size, device=batch.device) + + +class ShiftedStretchedLogitNormalTimestepSampler: + """ + Samples timesteps from a stretched logit-normal distribution, + where the shift is determined by the sequence length. + """ + + def __init__(self, std: float = 1.0, uniform_prob: float = 0.1): + self.std = std + self.shifted_logit_normal_sampler = ShiftedLogitNormalTimestepSampler(std=std) + self.uniform_sampler = UniformTimestepSampler() + self.uniform_prob = uniform_prob + + def sample(self, batch_size: int, seq_length: int, device: torch.device = None) -> torch.Tensor: + # Determine which sampler to use for each batch element + should_use_uniform = torch.rand(batch_size, device=device) < self.uniform_prob + + # Initialize an empty tensor for the results + timesteps = torch.empty(batch_size, device=device) + + # Sample from uniform sampler where should_use_uniform is True + num_uniform = should_use_uniform.sum().item() + if num_uniform > 0: + timesteps[should_use_uniform] = self.uniform_sampler.sample( + batch_size=num_uniform, seq_length=seq_length, device=device + ) + + # Sample from shifted logit-normal sampler where should_use_uniform is False + should_use_shifted = ~should_use_uniform + num_shifted = should_use_shifted.sum().item() + if num_shifted > 0: + timesteps[should_use_shifted] = self.shifted_logit_normal_sampler.sample( + batch_size=num_shifted, seq_length=seq_length, device=device + ) + return timesteps + + def sample_for(self, batch: torch.Tensor) -> torch.Tensor: + """Sample timesteps for a specific batch tensor. + + Args: + batch: Input tensor of shape (batch_size, seq_length, ...) + + Returns: + Tensor of shape (batch_size,) containing timesteps + + Raises: + ValueError: If the input batch does not have 3 dimensions + """ + if batch.ndim != 3: + raise ValueError(f"Batch should have 3 dimensions, got {batch.ndim}") + + batch_size, seq_length, _ = batch.shape + return self.sample(batch_size=batch_size, seq_length=seq_length, device=batch.device) + + +class DreamBoothDataset(Dataset): + """ + A dataset to prepare the instance images with the prompts for fine-tuning the model. + Images are dynamically resized and center-cropped to the closest aspect ratio from RESOLUTIONS_1k. + """ + + def __init__( + self, + instance_data_root, + repeats=1, + ): + self.custom_instance_prompts = None + + # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory, + # we load the training data using load_dataset + if args.dataset_name is not None: + try: + from datasets import load_dataset + except ImportError: + raise ImportError( + "You are trying to load your data using the datasets library. 
If you wish to train using custom " + "captions please install the datasets library: `pip install datasets`. If you wish to load a " + "local folder containing images only, specify --instance_data_dir instead." + ) + # Downloading and loading a dataset from the hub. + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + ) + # Preprocessing the datasets. + column_names = dataset["train"].column_names + + # 6. Get the column names for input/target. + if args.image_column is None: + image_column = column_names[0] + logger.info(f"image column defaulting to {image_column}") + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + instance_images = dataset["train"][image_column] + + if args.input_image_column is None: + input_image_column = column_names[0] + logger.info(f"source image column defaulting to {input_image_column}") + else: + input_image_column = args.input_image_column + if input_image_column not in column_names: + raise ValueError( + f"`--input_image_column` value '{args.input_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + source_images = dataset["train"][input_image_column] + + if args.caption_column is None: + logger.info( + "No caption column provided. If your dataset " + "contains captions/prompts for the images, make sure to specify the " + "column as --caption_column" + ) + self.custom_instance_prompts = None + else: + if args.caption_column not in column_names: + raise ValueError( + f"`--caption_column` value '{args.caption_column}' not found in dataset columns. 
Dataset columns are: {', '.join(column_names)}" + ) + custom_instance_prompts = dataset["train"][args.caption_column] + # create final list of captions according to --repeats + self.custom_instance_prompts = [] + for caption in custom_instance_prompts: + # Validate and normalize the JSON caption (raises error if invalid) + cleaned_caption = clean_json_caption(caption) + self.custom_instance_prompts.extend(itertools.repeat(cleaned_caption, repeats)) + else: + self.instance_data_root = Path(instance_data_root) + if not self.instance_data_root.exists(): + raise ValueError("Instance images root doesn't exists.") + + instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir()) if path.is_file()] + self.custom_instance_prompts = None + + self.instance_images = [] + for img in instance_images: + img = open_image_from_binary(img) + self.instance_images.extend(itertools.repeat(img, repeats)) + + self.num_instance_images = len(self.instance_images) + self._length = self.num_instance_images + + self.source_images = [] + for img in source_images: + img = open_image_from_binary(img) + self.source_images.extend(itertools.repeat(img, repeats)) + + self.num_source_images = len(self.source_images) + self._length = self.num_source_images + # Normalization transform (applied after resize/crop) + self.to_tensor_normalize = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), + ] + ) + + def __len__(self): + return self._length + + def _process_image(self, image): + image = exif_transpose(image) + if not image.mode == "RGB": + image = image.convert("RGB") + return image + + def __getitem__(self, index): + example = {} + # Get the original image + instance_image = self.instance_images[index % self.num_instance_images] + source_image = self.source_images[index % self.num_source_images] + instance_image = self._process_image(instance_image) + source_image = self._process_image(source_image) + + # Get image dimensions and find closest resolution + img_width, img_height = instance_image.size + target_width, target_height = find_closest_resolution(img_width, img_height) + + # Resize and center crop to target dimensions + # Calculate scale factor to ensure we can center crop to target dimensions + target_aspect = target_width / target_height + img_aspect = img_width / img_height + + if img_aspect > target_aspect: + # Image is wider than target, resize based on height + scale = target_height / img_height + else: + # Image is taller than target, resize based on width + scale = target_width / img_width + + new_width = int(img_width * scale) + new_height = int(img_height * scale) + + # Resize maintaining aspect ratio + instance_image = transforms.Resize( + (new_height, new_width), interpolation=transforms.InterpolationMode.BILINEAR + )(instance_image) + source_image = transforms.Resize((new_height, new_width), interpolation=transforms.InterpolationMode.BILINEAR)( + source_image + ) + + # Center crop to exact target dimensions + instance_image = transforms.CenterCrop((target_height, target_width))(instance_image) + source_image = transforms.CenterCrop((target_height, target_width))(source_image) + # Convert to tensor and normalize + instance_image = self.to_tensor_normalize(instance_image) + source_image = self.to_tensor_normalize(source_image) + example["instance_images"] = instance_image + example["source_images"] = source_image + example["target_width"] = target_width + example["target_height"] = target_height + + if 
self.custom_instance_prompts: + caption = self.custom_instance_prompts[index % self.num_instance_images] + if caption: + example["instance_prompt"] = caption + else: + raise ValueError("Caption cannot be empty when custom_instance_prompts is provided") + else: + raise ValueError( + "Captions must be provided via --caption_column when using --dataset_name, or via dataset metadata when loading from directory" + ) + + return example + + +def clean_json_caption(caption): + """Validate and normalize JSON caption format. Raises ValueError if caption is not valid JSON.""" + try: + caption = json.loads(caption) + return ujson.dumps(caption, escape_forward_slashes=False) + except (json.JSONDecodeError, TypeError) as e: + raise ValueError( + f"Caption must be in valid JSON format. Error: {e}. Caption: {caption[:100] if len(str(caption)) > 100 else caption}" + ) + + +def add_lora(transformer, lora_rank): + target_modules = [ + # HF Lora Layers + "attn.to_k", + "attn.to_q", + "attn.to_v", + "attn.to_out.0", + "attn.add_k_proj", + "attn.add_q_proj", + "attn.add_v_proj", + "attn.to_add_out", + "ff.net.0.proj", + "ff.net.2", + "ff_context.net.0.proj", + "ff_context.net.2", + "proj_mlp", + # + layers that exist on ostris ai-toolkit / replicate trainer + "norm1_context.linear", + "norm1.linear", + "norm.linear", + "proj_out", + ] + transformer_lora_config = LoraConfig( + r=lora_rank, + lora_alpha=lora_rank, + init_lora_weights="gaussian", + target_modules=target_modules, + ) + transformer.add_adapter(transformer_lora_config) + + +def set_lora_training(accelerator, transformer, lora_rank): + add_lora(transformer, lora_rank) + + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + transformer_lora_layers_to_save = None + + for model in models: + if isinstance(model, type(accelerator.unwrap_model(transformer))): + transformer_lora_layers_to_save = get_peft_model_state_dict(model) + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + FluxLoraLoaderMixin.save_lora_weights( + output_dir, + transformer_lora_layers=transformer_lora_layers_to_save, + ) + + def load_model_hook(models, input_dir): + transformer_ = None + + while len(models) > 0: + model = models.pop() + + if isinstance(model, type(accelerator.unwrap_model(transformer))): + transformer_ = model + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + load_lora(transformer=transformer_, input_dir=input_dir) + # Make sure the trainable params are in float32. 
This is again needed since the base models + cast_training_params([transformer_], dtype=torch.float32) + + if accelerator: + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + + +def load_lora(transformer, input_dir): + lora_state_dict = FluxLoraLoaderMixin.lora_state_dict(input_dir) + + transformer_state_dict = { + f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.") + } + transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) + incompatible_keys = set_peft_model_state_dict(transformer, transformer_state_dict, adapter_name="default") + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + raise Exception( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: {unexpected_keys}. " + ) + + +# Not really cosine but with decay +def get_cosine_schedule_with_warmup_and_decay( + optimizer: Optimizer, + num_warmup_steps: int, + num_training_steps: int, + last_epoch: int = -1, + constant_steps=-1, + eps=1e-5, +) -> LambdaLR: + """ + Create a schedule with a learning rate that decreases following the values of the cosine function between the + initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the + initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + num_periods (`float`, *optional*, defaults to 0.5): + The number of periods of the cosine function in a schedule (the default is to just decrease from the max + value to 0 following a half-cosine). + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + constant_steps (`int`): + The total number of constant lr steps following a warmup + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 
+ """ + if constant_steps <= 0: + constant_steps = num_training_steps - num_warmup_steps + + def lr_lambda(current_step): + # Accelerate sends current_step*num_processes + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + elif current_step < num_warmup_steps + constant_steps: + return 1 + + return max( + eps, + float(num_training_steps - current_step) + / float(max(1, num_training_steps - num_warmup_steps - constant_steps)), + ) + + return LambdaLR(optimizer, lr_lambda, last_epoch) + + +def get_lr_scheduler(name, optimizer, num_warmup_steps, num_training_steps, constant_steps): + if name != "constant_with_warmup_cosine_decay": + return get_scheduler( + name=name, + optimizer=optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + ) + + # Using custom warmup+constant+decay scheduler + return get_cosine_schedule_with_warmup_and_decay( + optimizer=optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + constant_steps=constant_steps, + ) + + +def load_checkpoint(accelerator, args): + # Load from local checkpoint that sage maker synced to s3 prefix + global_step = 0 + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("_")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print(f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.") + args.resume_from_checkpoint = None + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path), map_location="cpu") + global_step = int(path.split("_")[-1]) + + return global_step + + +def collate_fn(examples): + pixel_values = [example["instance_images"] for example in examples] + input_images = [example["source_images"] for example in examples] + captions = [example["instance_prompt"] for example in examples] + # Get target dimensions (assuming batch_size=1, so we can get from first example) + target_width = examples[0]["target_width"] + target_height = examples[0]["target_height"] + + input_images = torch.stack(input_images) + input_images = input_images.to(memory_format=torch.contiguous_format).float() + + pixel_values = torch.stack(pixel_values) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + return pixel_values, input_images, captions, target_width, target_height + + +def get_accelerator(args): + logging_dir = os.path.join(args.output_dir, args.logging_dir) + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + project_config=accelerator_project_config, + kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)], + ) + + # Set huggingface token key if provided + with accelerator.main_process_first(): + if accelerator.is_local_main_process: + if os.environ.get("HF_API_TOKEN"): + HfFolder.save_token(os.environ.get("HF_API_TOKEN")) + + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + 
diffusers.utils.logging.set_verbosity_error() + + return accelerator + + +def main(args): + try: + cuda_version = torch.version.cuda + print(f"PyTorch CUDA Version: {cuda_version}") + except Exception as e: + print(f"Error checking CUDA version: {e}") + raise e + + args = parse_args() + + RANK = int(os.environ.get("RANK", 0)) + + seed = args.seed + RANK + set_seed(seed) + random.seed(seed) + + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # as these models are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if args.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif args.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Set accelerator with fsdp/data-parallel + accelerator = get_accelerator(args) + + WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1)) + TOTAL_BATCH_NO_ACC = args.train_batch_size + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + logger.info(f"TORCH_VERSION {torch.__version__}") + logger.info(f"DIFFUSERS_VERSION {diffusers.__version__}") + + logger.info("using precompted datasets") + + transformer = BriaFiboTransformer2DModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="transformer", + low_cpu_mem_usage=False, # critical: avoid meta tensors + weight_dtype=weight_dtype, + ) + transformer = transformer.to(accelerator.device).eval() + total_num_layers = transformer.config["num_layers"] + transformer.config["num_single_layers"] + + logger.info(f"Using precision of {weight_dtype}") + if args.lora_rank > 0: + logger.info(f"Using LORA with rank {args.lora_rank}") + transformer.requires_grad_(False) + transformer.to(dtype=weight_dtype) + set_lora_training(accelerator, transformer, args.lora_rank) + # Upcast trainable parameters (LoRA) into fp32 for mixed precision training + cast_training_params([transformer], dtype=torch.float32) + else: + transformer.requires_grad_(True) + assert transformer.dtype == torch.float32 + + if args.gradient_checkpointing: + transformer.enable_gradient_checkpointing() + + get_prompt_embeds_lambda = get_smollm_prompt_embeds + print("Loading smolLM text encoder") + + tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer") + text_encoder = ( + AutoModelForCausalLM.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="text_encoder", + dtype=weight_dtype, + ) + .to(accelerator.device) + .eval() + .requires_grad_(False) + ) + + vae_model = AutoencoderKLWan.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae") + + vae_model = vae_model.to(accelerator.device).requires_grad_(False) + # Read vae config + vae_config = vae_model.config + vae_config["shift_factor"] = ( + torch.tensor(vae_model.config["latents_mean"]).reshape((1, 48, 1, 1)).to(device=accelerator.device) + ) + vae_config["scaling_factor"] = 1 / torch.tensor(vae_model.config["latents_std"]).reshape((1, 48, 1, 1)).to( + device=accelerator.device + ) + vae_config["compression_rate"] = 16 + vae_config["latent_channels"] = 48 + + def get_prompt_embeds(prompts): + prompt_embeddings, text_encoder_layers, attentions_masks = get_prompt_embeds_lambda( + tokenizer, + text_encoder, + prompts=prompts, + max_sequence_length=args.max_sequence_length, + ) + return 
prompt_embeddings, text_encoder_layers, attentions_masks + + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + # Initialize the optimizer + if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"): + logger.warning( + f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]." + "Defaulting to prodigy" + ) + args.optimizer = "prodigy" + + if args.lora_rank > 0: + parameters = list(filter(lambda p: p.requires_grad, transformer.parameters())) + else: + parameters = transformer.parameters() + + if args.optimizer.lower() == "adamw": + optimizer_cls = torch.optim.AdamW + optimizer = optimizer_cls( + parameters, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + elif args.optimizer.lower() == "prodigy": + try: + import prodigyopt + except ImportError: + raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`") + + optimizer_cls = prodigyopt.Prodigy + + if args.learning_rate <= 0.1: + logger.warning( + "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0" + ) + + optimizer = optimizer_cls( + parameters, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + beta3=args.prodigy_beta3, + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + decouple=args.prodigy_decouple, + use_bias_correction=args.prodigy_use_bias_correction, + safeguard_warmup=args.prodigy_safeguard_warmup, + ) + + if args.lr_scheduler == "cosine_with_warmup": + lr_scheduler = get_cosine_schedule_with_warmup( + optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + ) + else: + lr_scheduler = get_lr_scheduler( + name=args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + constant_steps=args.constant_steps * accelerator.num_processes, + ) + + transformer, optimizer, lr_scheduler = accelerator.prepare(transformer, optimizer, lr_scheduler) + fibo_edit_pipeline = BriaFiboEditPipeline( + transformer=transformer, + scheduler=lr_scheduler, + vae=vae_model, + text_encoder=text_encoder, + tokenizer=tokenizer, + ) + logger.info("***** Running training *****") + + logger.info(f"diffusers version: {diffusers.__version__}") + + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + + if args.resume_from_checkpoint != "no": + global_step = load_checkpoint(accelerator, args) + logger.info(f"Using {args.optimizer} with lr: {args.learning_rate}, beta2: {args.adam_beta2}") + + # Only show the progress bar once on each machine. 
+ progress_bar = tqdm( + range(global_step, args.max_train_steps), + disable=not accelerator.is_local_main_process, + ) + progress_bar.set_description("Steps") + + now = datetime.now() + times_arr = [] + # Init dynamic scheduler (resolution will be determined per batch) + noise_scheduler = ShiftedStretchedLogitNormalTimestepSampler() + + # encode null prompt "" + null_conditioning, null_conditioning_layers, _ = get_prompt_embeds([""]) + logger.info("Using empty prompt for null embeddings") + assert null_conditioning.shape[0] == 1 + null_conditioning = null_conditioning.repeat(args.train_batch_size, 1, 1).to(dtype=torch.float32) + null_conditioning_layers = [ + layer.repeat(args.train_batch_size, 1, 1).to(dtype=torch.float32) for layer in null_conditioning_layers + ] + + vae_scale_factor = ( + 2 ** (len(vae_config["block_out_channels"]) - 1) + if "compression_rate" not in vae_config + else vae_config["compression_rate"] + ) + transformer.train() + train_loss = 0.0 + generator = torch.Generator(device=accelerator.device).manual_seed(seed) + + # Dataset and DataLoaders creation: + train_dataset = DreamBoothDataset( + instance_data_root=args.instance_data_dir, + repeats=args.repeats, + ) + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_size=args.train_batch_size, + shuffle=True, + collate_fn=collate_fn, + num_workers=args.dataloader_num_workers, + ) + + iter_ = iter(train_dataloader) + for step in range( + global_step * args.gradient_accumulation_steps, + args.max_train_steps * args.gradient_accumulation_steps, + ): + have_batch = False + + while not have_batch: + try: + fetch_time = datetime.now() + batch = next(iter_) + fetch_time = datetime.now() - fetch_time + have_batch = True + except StopIteration: + iter_ = iter(train_dataloader) + logger.info(f"Rank {RANK} reinit iterator") + + target_pixel_values, input_pixel_values, captions, target_width, target_height = ( + batch # Get batch with dynamic resolution + ) + height, width = target_height, target_width + target_latents, target_latent_image_ids = fibo_edit_pipeline.prepare_image_latents( + image=target_pixel_values, + batch_size=args.train_batch_size, + num_channels_latents=vae_config["latent_channels"], + height=height, + width=width, + dtype=torch.float32, + device=accelerator.device, + generator=generator, + ) + # target latents of the target image should be 0 + target_latent_image_ids[..., 0] = 0 + + context_latents, context_latent_image_ids = fibo_edit_pipeline.prepare_image_latents( + image=input_pixel_values, + batch_size=args.train_batch_size, + num_channels_latents=vae_config["latent_channels"], + height=height, + width=width, + dtype=torch.float32, + device=accelerator.device, + generator=generator, + ) + + # Get Captions + encoder_hidden_states, text_encoder_layers, prompt_attention_mask = get_prompt_embeds(captions) + text_encoder_layers = list(text_encoder_layers) + # make sure that the number of text encoder layers is equal to the total number of layers in the transformer + assert len(text_encoder_layers) <= total_num_layers + text_encoder_layers = text_encoder_layers + [text_encoder_layers[-1]] * ( + total_num_layers - len(text_encoder_layers) + ) + null_conditioning_layers = null_conditioning_layers + [null_conditioning_layers[-1]] * ( + total_num_layers - len(null_conditioning_layers) + ) + + target_pixel_values = target_pixel_values.to(device=accelerator.device, dtype=torch.float32) + input_pixel_values = input_pixel_values.to(device=accelerator.device, dtype=torch.float32) + 
encoder_hidden_states = encoder_hidden_states.to(device=accelerator.device, dtype=torch.float32) + prompt_attention_mask = prompt_attention_mask.to(device=accelerator.device, dtype=torch.float32) + + # create attention mask for the target and context latents + target_latents_attention_mask = torch.ones( + [target_latents.shape[0], target_latents.shape[1]], + dtype=target_latents.dtype, + device=target_latents.device, + ) + + context_latents_attention_mask = torch.ones( + [context_latents.shape[0], context_latents.shape[1]], + dtype=context_latents.dtype, + device=context_latents.device, + ) + + attention_mask = torch.cat( + [prompt_attention_mask, target_latents_attention_mask, context_latents_attention_mask], dim=1 + ) + + with accelerator.accumulate(transformer): + # Sample noise that we'll add to the latents + noise = torch.randn_like(target_latents) + + bsz = target_pixel_values.shape[0] + + seq_len = (height // vae_scale_factor) * (width // vae_scale_factor) + + sigmas = noise_scheduler.sample(bsz, seq_len, device=accelerator.device) + timesteps = sigmas * 1000 + while len(sigmas.shape) < len(noise.shape): + sigmas = sigmas.unsqueeze(-1) + noisy_latents = sigmas * noise + (1.0 - sigmas) * target_latents + + # input for rope positional embeddings for text + num_text_tokens = encoder_hidden_states.shape[1] + text_ids = torch.zeros(num_text_tokens, 3).to(device=accelerator.device, dtype=encoder_hidden_states.dtype) + + # Sample masks for the edit prompts. + if args.drop_rate_cfg > 0: + null_embedding, null_attention_mask = pad_embedding(null_conditioning, max_tokens=num_text_tokens) + # null embedding for 10% of the images + random_p = torch.rand(bsz, device=target_latents.device, generator=generator) + + prompt_mask = random_p < args.drop_rate_cfg + + prompt_mask = prompt_mask.reshape(bsz, 1, 1) + encoder_hidden_states = torch.where(prompt_mask, null_embedding, encoder_hidden_states) + + text_encoder_layers = [ + torch.where( + prompt_mask, + pad_embedding(null_conditioning_layers[i], max_tokens=num_text_tokens)[0], + text_encoder_layers[i], + ) + for i in range(len(text_encoder_layers)) + ] + + prompt_mask = prompt_mask.reshape(bsz, 1) + prompt_attention_mask = torch.where(prompt_mask, null_attention_mask, prompt_attention_mask) + + # Get the target for loss depending on the prediction type + target = noise - target_latents # V pred + latent_height = int(height) // vae_scale_factor + latent_width = int(width) // vae_scale_factor + + patched_latent_image_ids = fibo_edit_pipeline._prepare_latent_image_ids( + noisy_latents.shape[0], + latent_height, + latent_width, + accelerator.device, + noisy_latents.dtype, + ) + + latent_attention_mask = torch.ones( + [noisy_latents.shape[0], noisy_latents.shape[1]], + dtype=target_latents.dtype, + device=target_latents.device, + ) + patched_latent_image_ids = torch.cat([patched_latent_image_ids, context_latent_image_ids], dim=0) + noisy_latents = torch.cat([noisy_latents, context_latents], dim=1) + attention_mask = torch.cat( + [prompt_attention_mask, latent_attention_mask, context_latents_attention_mask], dim=1 + ) + + # Prepare attention_matrix + attention_mask = create_attention_matrix(attention_mask) # batch, seq => batch, seq, seq + + attention_mask = attention_mask.unsqueeze(dim=1) # for brodoacast to attention heads + joint_attention_kwargs = {"attention_mask": attention_mask} + + model_pred = transformer( + hidden_states=noisy_latents, + timestep=timesteps, + encoder_hidden_states=encoder_hidden_states, # [batch,128,height/patch*width/patch] + 
text_encoder_layers=text_encoder_layers, + txt_ids=text_ids, + img_ids=patched_latent_image_ids, + return_dict=False, + joint_attention_kwargs=joint_attention_kwargs, + )[0] + model_pred = model_pred[:, : target_latents.shape[1]] + # Un-Patchify latent (4 -> 1) + loss_coeff = WORLD_SIZE / TOTAL_BATCH_NO_ACC + + denoising_loss = torch.mean( + ((model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), + 1, + ).sum() + denoising_loss = loss_coeff * denoising_loss + + loss = denoising_loss + + train_loss += accelerator.gather(loss.detach()).mean().item() / args.gradient_accumulation_steps + + # Backpropagate + accelerator.backward(loss) + + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(parameters, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if accelerator.is_main_process: + logger.info(f"train_loss: {train_loss}") + after = datetime.now() - now + now = datetime.now() + + times_arr += [after.total_seconds()] + + train_loss = 0.0 + + if (global_step - 1) % args.checkpointing_steps == 0 and (global_step - 1) > 0: + save_path = os.path.join(args.output_dir, f"checkpoint_{global_step - 1}") + + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + now = datetime.now() + + if global_step == args.max_train_steps: + save_path = os.path.join(args.output_dir, "checkpoint_final") + + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + now = datetime.now() + + logs = {"step_loss": loss.detach().item()} + + progress_bar.set_postfix(**logs) + if global_step >= args.max_train_steps: + break + + # Create the pipeline using the trained modules and save it. 
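+    # NOTE: illustrative sketch, not part of the original script. One possible way to export the trained
+    # LoRA layers in the diffusers format, assuming `set_lora_training` attaches standard peft adapters and
+    # that `BriaFiboEditPipeline` exposes the FluxLoraLoaderMixin-style `save_lora_weights`:
+    #
+    #     from peft.utils import get_peft_model_state_dict
+    #
+    #     from diffusers.utils import convert_state_dict_to_diffusers
+    #
+    #     if accelerator.is_main_process and args.lora_rank > 0:
+    #         unwrapped_transformer = accelerator.unwrap_model(transformer)
+    #         transformer_lora_layers = convert_state_dict_to_diffusers(
+    #             get_peft_model_state_dict(unwrapped_transformer)
+    #         )
+    #         BriaFiboEditPipeline.save_lora_weights(
+    #             save_directory=args.output_dir,
+    #             transformer_lora_layers=transformer_lora_layers,
+    #         )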
+ logger.info("Waiting for everyone :)") + accelerator.wait_for_everyone() + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index bc59497d1db7..3977335f5138 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -455,6 +455,7 @@ "AuraFlowPipeline", "BlipDiffusionControlNetPipeline", "BlipDiffusionPipeline", + "BriaFiboEditPipeline", "BriaFiboPipeline", "BriaPipeline", "ChromaImg2ImgPipeline", @@ -1179,6 +1180,7 @@ AudioLDM2UNet2DConditionModel, AudioLDMPipeline, AuraFlowPipeline, + BriaFiboEditPipeline, BriaFiboPipeline, BriaPipeline, ChromaImg2ImgPipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index b94319ffcbdc..eb09cd645efb 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -128,7 +128,7 @@ "AnimateDiffVideoToVideoControlNetPipeline", ] _import_structure["bria"] = ["BriaPipeline"] - _import_structure["bria_fibo"] = ["BriaFiboPipeline"] + _import_structure["bria_fibo"] = ["BriaFiboPipeline", "BriaFiboEditPipeline"] _import_structure["flux2"] = ["Flux2Pipeline"] _import_structure["flux"] = [ "FluxControlPipeline", @@ -594,7 +594,7 @@ from .aura_flow import AuraFlowPipeline from .blip_diffusion import BlipDiffusionPipeline from .bria import BriaPipeline - from .bria_fibo import BriaFiboPipeline + from .bria_fibo import BriaFiboEditPipeline, BriaFiboPipeline from .chroma import ChromaImg2ImgPipeline, ChromaPipeline from .chronoedit import ChronoEditPipeline from .cogvideo import ( diff --git a/src/diffusers/pipelines/bria_fibo/__init__.py b/src/diffusers/pipelines/bria_fibo/__init__.py index 206a463b394b..8dd77270902c 100644 --- a/src/diffusers/pipelines/bria_fibo/__init__.py +++ b/src/diffusers/pipelines/bria_fibo/__init__.py @@ -23,6 +23,8 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_bria_fibo"] = ["BriaFiboPipeline"] + _import_structure["pipeline_bria_fibo_edit"] = ["BriaFiboEditPipeline"] + if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -33,6 +35,7 @@ from ...utils.dummy_torch_and_transformers_objects import * else: from .pipeline_bria_fibo import BriaFiboPipeline + from .pipeline_bria_fibo_edit import BriaFiboEditPipeline else: import sys diff --git a/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo_edit.py b/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo_edit.py new file mode 100644 index 000000000000..88f01a97399d --- /dev/null +++ b/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo_edit.py @@ -0,0 +1,1132 @@ +# Copyright (c) Bria.ai. All rights reserved. +# +# This file is licensed under the Creative Commons Attribution-NonCommercial 4.0 International Public License (CC-BY-NC-4.0). +# You may obtain a copy of the license at https://creativecommons.org/licenses/by-nc/4.0/ +# +# You are free to share and adapt this material for non-commercial purposes provided you give appropriate credit, +# indicate if changes were made, and do not use the material for commercial purposes. +# +# See the license for further details. 
+ +import json +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from PIL import Image +from transformers import AutoTokenizer +from transformers.models.smollm3.modeling_smollm3 import SmolLM3ForCausalLM + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FluxLoraLoaderMixin +from ...models.autoencoders.autoencoder_kl_wan import AutoencoderKLWan +from ...models.transformers.transformer_bria_fibo import BriaFiboTransformer2DModel +from ...pipelines.bria_fibo.pipeline_output import BriaFiboPipelineOutput +from ...pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +PipelineMaskInput = Union[ + torch.FloatTensor, Image.Image, List[Image.Image], List[torch.FloatTensor], np.ndarray, List[np.ndarray] +] + +# TODO: Update example docstring +EXAMPLE_DOC_STRING = """ + Example: + ```python + import torch + from diffusers import BriaFiboEditPipeline + from diffusers.modular_pipelines import ModularPipeline + + torch.set_grad_enabled(False) + vlm_pipe = ModularPipelineBlocks.from_pretrained("briaai/FIBO-VLM-prompt-to-JSON", trust_remote_code=True) + vlm_pipe = vlm_pipe.init_pipeline() + + pipe = BriaFiboEditPipeline.from_pretrained( + "briaai/fibo-edit", + torch_dtype=torch.bfloat16, + ) + pipe.to("cuda") + + output = vlm_pipe( + prompt="A hyper-detailed, ultra-fluffy owl sitting in the trees at night, looking directly at the camera with wide, adorable, expressive eyes. Its feathers are soft and voluminous, catching the cool moonlight with subtle silver highlights. The owl's gaze is curious and full of charm, giving it a whimsical, storybook-like personality." + ) + json_prompt_generate = json.loads(output.values["json_prompt"]) + + image = Image.open("image_generate.png") + + edit_prompt = "Make the owl to be a cat" + + json_prompt_generate["edit_instruction"] = edit_prompt + + results_generate = pipe( + prompt=json_prompt_generate, num_inference_steps=50, guidance_scale=3.5, image=image, output_type="np" + ) + ``` +""" + +PREFERRED_RESOLUTION = { + 256 * 256: [(208, 304), (224, 288), (256, 256), (288, 224), (304, 208), (320, 192), (336, 192)], + 512 * 512: [ + (416, 624), + (432, 592), + (464, 560), + (512, 512), + (544, 480), + (576, 448), + (592, 432), + (608, 416), + (624, 416), + (640, 400), + (672, 384), + (704, 368), + ], + 1024 * 1024: [ + (832, 1248), + (880, 1184), + (912, 1136), + (1024, 1024), + (1136, 912), + (1184, 880), + (1216, 848), + (1248, 832), + (1248, 832), + (1264, 816), + (1296, 800), + (1360, 768), + ], +} + + +def is_valid_edit_json(json_input: str | dict): + """ + Check if the input is a valid JSON string or dict with an "edit_instruction" key. + + Args: + json_input (`str` or `dict`): + The JSON string or dict to check. + + Returns: + `bool`: True if the input is a valid JSON string or dict with an "edit_instruction" key, False otherwise. 
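+
+    Example:
+        A dict like `{"edit_instruction": "make the sky pink"}` (or the same content as a JSON string) is
+        accepted; a plain caption string without an "edit_instruction" key is not.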
+ """ + try: + if isinstance(json_input, str) and "edit_instruction" in json_input: + json.loads(json_input) + return True + elif isinstance(json_input, dict) and "edit_instruction" in json_input: + return True + else: + return False + except json.JSONDecodeError: + return False + + +def is_valid_mask(mask: PipelineMaskInput): + """ + Check if the mask is a valid mask. + """ + if isinstance(mask, torch.Tensor): + return True + elif isinstance(mask, Image.Image): + return True + elif isinstance(mask, list): + return all(isinstance(m, (torch.Tensor, Image.Image, np.ndarray)) for m in mask) + elif isinstance(mask, np.ndarray): + return mask.ndim in [2, 3] and mask.min() >= 0 and mask.max() <= 1 + else: + return False + + +def get_mask_size(mask: PipelineMaskInput): + """ + Get the size of the mask. + """ + if isinstance(mask, torch.Tensor): + return mask.shape[-2:] + elif isinstance(mask, Image.Image): + return mask.size[::-1] # (height, width) + elif isinstance(mask, list): + return [get_mask_size(m) for m in mask] + elif isinstance(mask, np.ndarray): + return mask.shape[-2:] + else: + return None + + +def get_image_size(image: PipelineImageInput): + """ + Get the size of the image. + """ + if isinstance(image, torch.Tensor): + return image.shape[-2:] + elif isinstance(image, Image.Image): + return image.size[::-1] # (height, width) + elif isinstance(image, list): + return [get_image_size(i) for i in image] + else: + return None + + +def paste_mask_on_image(mask: PipelineMaskInput, image: PipelineImageInput): + """convert mask and image to PIL Images and paste the mask on the image""" + if isinstance(mask, torch.Tensor): + if mask.ndim == 3 and mask.shape[0] == 1: + mask = mask.squeeze(0) + mask = Image.fromarray((mask.cpu().numpy() * 255).astype(np.uint8)) + elif isinstance(mask, Image.Image): + pass + elif isinstance(mask, list): + mask = mask[0] + if isinstance(mask, torch.Tensor): + if mask.ndim == 3 and mask.shape[0] == 1: + mask = mask.squeeze(0) + mask = Image.fromarray((mask.cpu().numpy() * 255).astype(np.uint8)) + elif isinstance(mask, np.ndarray): + mask = Image.fromarray((mask * 255).astype(np.uint8)) + elif isinstance(mask, np.ndarray): + mask = Image.fromarray((mask * 255).astype(np.uint8)) + + if isinstance(image, torch.Tensor): + if image.ndim == 3: + image = image.permute(1, 2, 0) + image = Image.fromarray((image.cpu().numpy() * 255).astype(np.uint8)) + elif isinstance(image, Image.Image): + pass + elif isinstance(image, list): + image = image[0] + if isinstance(image, torch.Tensor): + if image.ndim == 3: + image = image.permute(1, 2, 0) + image = Image.fromarray((image.cpu().numpy() * 255).astype(np.uint8)) + elif isinstance(image, np.ndarray): + image = Image.fromarray((image * 255).astype(np.uint8)) + elif isinstance(image, np.ndarray): + image = Image.fromarray((image * 255).astype(np.uint8)) + + mask = mask.convert("L") + image = image.convert("RGB") + gray_color = (128, 128, 128) + gray_img = Image.new("RGB", image.size, gray_color) + image = Image.composite(gray_img, image, mask) + return image + + +class BriaFiboEditPipeline(DiffusionPipeline, FluxLoraLoaderMixin): + r""" + Args: + transformer (`BriaFiboTransformer2DModel`): + The transformer model for 2D diffusion modeling. + scheduler (`FlowMatchEulerDiscreteScheduler` or `KarrasDiffusionSchedulers`): + Scheduler to be used with `transformer` to denoise the encoded latents. + vae (`AutoencoderKLWan`): + Variational Auto-Encoder for encoding and decoding images to and from latent representations. 
+ text_encoder (`SmolLM3ForCausalLM`): + Text encoder for processing input prompts. + tokenizer (`AutoTokenizer`): + Tokenizer used for processing the input text prompts for the text_encoder. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + transformer: BriaFiboTransformer2DModel, + scheduler: Union[FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers], + vae: AutoencoderKLWan, + text_encoder: SmolLM3ForCausalLM, + tokenizer: AutoTokenizer, + ): + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + + self.vae_scale_factor = 16 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) # * 2) + self.default_sample_size = 32 # 64 + + def get_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + max_sequence_length: int = 2048, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + if not prompt: + raise ValueError("`prompt` must be a non-empty string or list of strings.") + + batch_size = len(prompt) + bot_token_id = 128000 + + text_encoder_device = device if device is not None else torch.device("cpu") + if not isinstance(text_encoder_device, torch.device): + text_encoder_device = torch.device(text_encoder_device) + + if all(p == "" for p in prompt): + input_ids = torch.full((batch_size, 1), bot_token_id, dtype=torch.long, device=text_encoder_device) + attention_mask = torch.ones_like(input_ids) + else: + tokenized = self.tokenizer( + prompt, + padding="longest", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + input_ids = tokenized.input_ids.to(text_encoder_device) + attention_mask = tokenized.attention_mask.to(text_encoder_device) + + if any(p == "" for p in prompt): + empty_rows = torch.tensor([p == "" for p in prompt], dtype=torch.bool, device=text_encoder_device) + input_ids[empty_rows] = bot_token_id + attention_mask[empty_rows] = 1 + + encoder_outputs = self.text_encoder( + input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + ) + hidden_states = encoder_outputs.hidden_states + + prompt_embeds = torch.cat([hidden_states[-1], hidden_states[-2]], dim=-1) + prompt_embeds = prompt_embeds.to(device=device, dtype=dtype) + + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + hidden_states = tuple( + layer.repeat_interleave(num_images_per_prompt, dim=0).to(device=device) for layer in hidden_states + ) + attention_mask = attention_mask.repeat_interleave(num_images_per_prompt, dim=0).to(device=device) + + return prompt_embeds, hidden_states, attention_mask + + @staticmethod + def pad_embedding(prompt_embeds, max_tokens, attention_mask=None): + # Pad embeddings to `max_tokens` while preserving the mask of real tokens. 
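+        # Returns the (possibly) padded embeddings together with an attention mask that is 1 over the
+        # original tokens and 0 over the appended padding positions.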
+ batch_size, seq_len, dim = prompt_embeds.shape + + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_len), dtype=prompt_embeds.dtype, device=prompt_embeds.device) + else: + attention_mask = attention_mask.to(device=prompt_embeds.device, dtype=prompt_embeds.dtype) + + if max_tokens < seq_len: + raise ValueError("`max_tokens` must be greater or equal to the current sequence length.") + + if max_tokens > seq_len: + pad_length = max_tokens - seq_len + padding = torch.zeros( + (batch_size, pad_length, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device + ) + prompt_embeds = torch.cat([prompt_embeds, padding], dim=1) + + mask_padding = torch.zeros( + (batch_size, pad_length), dtype=prompt_embeds.dtype, device=prompt_embeds.device + ) + attention_mask = torch.cat([attention_mask, mask_padding], dim=1) + + return prompt_embeds, attention_mask + + def encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + guidance_scale: float = 5, + negative_prompt: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 3000, + lora_scale: Optional[float] = None, + ): + r""" + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + guidance_scale (`float`): + Guidance scale for classifier free guidance. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. 
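+            max_sequence_length (`int`, *optional*, defaults to 3000):
+                Maximum number of text tokens kept when tokenizing the prompt; longer prompts are truncated.
+            lora_scale (`float`, *optional*):
+                Scale applied to the LoRA layers of the text encoder (if LoRA layers are loaded) while encoding.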
+ """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + prompt_attention_mask = None + negative_prompt_attention_mask = None + if prompt_embeds is None: + prompt_embeds, prompt_layers, prompt_attention_mask = self.get_prompt_embeds( + prompt=prompt, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + prompt_embeds = prompt_embeds.to(dtype=self.transformer.dtype) + prompt_layers = [tensor.to(dtype=self.transformer.dtype) for tensor in prompt_layers] + + if guidance_scale > 1: + if isinstance(negative_prompt, list) and negative_prompt[0] is None: + negative_prompt = "" + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + + negative_prompt_embeds, negative_prompt_layers, negative_prompt_attention_mask = self.get_prompt_embeds( + prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.transformer.dtype) + negative_prompt_layers = [tensor.to(dtype=self.transformer.dtype) for tensor in negative_prompt_layers] + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + # Pad to longest + if prompt_attention_mask is not None: + prompt_attention_mask = prompt_attention_mask.to(device=prompt_embeds.device, dtype=prompt_embeds.dtype) + + if negative_prompt_embeds is not None: + if negative_prompt_attention_mask is not None: + negative_prompt_attention_mask = negative_prompt_attention_mask.to( + device=negative_prompt_embeds.device, dtype=negative_prompt_embeds.dtype + ) + max_tokens = max(negative_prompt_embeds.shape[1], prompt_embeds.shape[1]) + + prompt_embeds, prompt_attention_mask = self.pad_embedding( + prompt_embeds, max_tokens, attention_mask=prompt_attention_mask + ) + prompt_layers = [self.pad_embedding(layer, max_tokens)[0] for layer in prompt_layers] + + negative_prompt_embeds, negative_prompt_attention_mask = self.pad_embedding( + negative_prompt_embeds, max_tokens, attention_mask=negative_prompt_attention_mask + ) + negative_prompt_layers = [self.pad_embedding(layer, max_tokens)[0] for layer in negative_prompt_layers] + else: + max_tokens = prompt_embeds.shape[1] + prompt_embeds, prompt_attention_mask = self.pad_embedding( + prompt_embeds, max_tokens, attention_mask=prompt_attention_mask + ) + negative_prompt_layers = None + + dtype = self.text_encoder.dtype + text_ids = torch.zeros(prompt_embeds.shape[0], max_tokens, 3).to(device=device, dtype=dtype) + + return ( + prompt_embeds, + negative_prompt_embeds, + text_ids, + prompt_attention_mask, + negative_prompt_attention_mask, + prompt_layers, + negative_prompt_layers, + ) + + @property + def guidance_scale(self): + return self._guidance_scale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. 
+ + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @staticmethod + # Based on diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + height = height // vae_scale_factor + width = width // vae_scale_factor + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + def _unpack_latents_no_patch(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + height = height // vae_scale_factor + width = width // vae_scale_factor + + latents = latents.view(batch_size, height, width, channels) + latents = latents.permute(0, 3, 1, 2) + + return latents + + @staticmethod + def _pack_latents_no_patch(latents, batch_size, num_channels_latents, height, width): + latents = latents.permute(0, 2, 3, 1) + latents = latents.reshape(batch_size, height * width, num_channels_latents) + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + do_patching=False, + ): + height = int(height) // self.vae_scale_factor + width = int(width) // self.vae_scale_factor + + shape = (batch_size, num_channels_latents, height, width) + + if latents is not None: + latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + return latents.to(device=device, dtype=dtype), latent_image_ids + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + if do_patching: + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + else: + latents = self._pack_latents_no_patch(latents, batch_size, num_channels_latents, height, width) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + + return latents, latent_image_ids + + @staticmethod + def _prepare_attention_mask(attention_mask): + attention_matrix = torch.einsum("bi,bj->bij", attention_mask, attention_mask) + + # convert to 0 - keep, -inf ignore + attention_matrix = torch.where( + attention_matrix == 1, 0.0, -torch.inf + ) # Apply -inf to ignored tokens for nulling softmax score + return attention_matrix + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Optional[PipelineImageInput] = None, + mask: Optional[PipelineMaskInput] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 30, + timesteps: List[int] = None, + seed: Optional[int] = None, + guidance_scale: float = 5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 3000, + do_patching=False, + _auto_resize: bool = True, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`PIL.Image.Image` or `torch.FloatTensor`, *optional*): + The image to guide the image generation. If not defined, the pipeline will generate an image from + scratch. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + seed (`int`, *optional*): + A seed used to make generation deterministic. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). 
+                `guidance_scale` is defined as `w` of equation 2.
+                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+                `guidance_scale > 1`. Higher guidance scale encourages the model to generate images that are
+                closely linked to the text `prompt`, usually at the expense of lower image quality.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
+                is less than `1`).
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a
+                latents tensor will be generated by sampling using the supplied random `generator`.
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
+                not provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`BriaFiboPipelineOutput`] instead of a plain tuple.
+            joint_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in
+                the `._callback_tensor_inputs` attribute of your pipeline class.
+            max_sequence_length (`int`, *optional*, defaults to 3000): Maximum sequence length to use with the `prompt`.
+            do_patching (`bool`, *optional*, defaults to `False`): Whether to pack the latents into 2x2 patches before denoising.
+        Examples:
+        Returns:
+            [`~pipelines.flux.BriaFiboPipelineOutput`] or `tuple`: [`~pipelines.flux.BriaFiboPipelineOutput`] if
+            `return_dict` is True, otherwise a `tuple`.
When returning a tuple, the first element is a list with the + generated images. + """ + + if height is None or width is None: + if image is not None: + image_height, image_width = self.image_processor.get_default_height_width(image) + if _auto_resize: + image_width, image_height = min( + PREFERRED_RESOLUTION[1024 * 1024], + key=lambda size: abs(size[0] / size[1] - image_width / image_height), + ) + width, height = image_width, image_height + else: + raise ValueError("You must provide either an image or both height and width.") + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + seed=seed, + image=image, + mask=mask, + prompt=prompt, + height=height, + width=width, + prompt_embeds=prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + if mask is not None and image is not None: + image = paste_mask_on_image(mask, image) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + + if prompt is not None and is_valid_edit_json(prompt): + prompt = json.dumps(prompt) + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + generator = torch.Generator(device=device).manual_seed(seed) if seed is not None else None + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + + ( + prompt_embeds, + negative_prompt_embeds, + text_ids, + prompt_attention_mask, + negative_prompt_attention_mask, + prompt_layers, + negative_prompt_layers, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + guidance_scale=guidance_scale, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + device=device, + max_sequence_length=max_sequence_length, + num_images_per_prompt=num_images_per_prompt, + lora_scale=lora_scale, + ) + prompt_batch_size = prompt_embeds.shape[0] + + if guidance_scale > 1: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_layers = [ + torch.cat([negative_prompt_layers[i], prompt_layers[i]], dim=0) for i in range(len(prompt_layers)) + ] + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + total_num_layers_transformer = len(self.transformer.transformer_blocks) + len( + self.transformer.single_transformer_blocks + ) + if len(prompt_layers) >= total_num_layers_transformer: + # remove first layers + prompt_layers = prompt_layers[len(prompt_layers) - total_num_layers_transformer :] + else: + # duplicate last layer + prompt_layers = prompt_layers + [prompt_layers[-1]] * (total_num_layers_transformer - len(prompt_layers)) + + # Preprocess image + if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels): + image = self.image_processor.resize(image, height, width) + image = self.image_processor.preprocess(image, height, width) + + # 5. 
Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + if do_patching: + num_channels_latents = int(num_channels_latents / 4) + + latents, latent_image_ids = self.prepare_latents( + prompt_batch_size, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + do_patching, + ) + + if image is not None: + image_latents, image_ids = self.prepare_image_latents( + image=image, + batch_size=batch_size * num_images_per_prompt, + num_channels_latents=num_channels_latents, + height=height, + width=width, + dtype=prompt_embeds.dtype, + device=device, + generator=generator, + ) + latent_image_ids = torch.cat([latent_image_ids, image_ids], dim=0) # dim 0 is sequence dimension + else: + image_latents = None + + latent_attention_mask = torch.ones( + [latents.shape[0], latents.shape[1]], dtype=latents.dtype, device=latents.device + ) + if guidance_scale > 1: + latent_attention_mask = latent_attention_mask.repeat(2, 1) + + if image_latents is None: + attention_mask = torch.cat([prompt_attention_mask, latent_attention_mask], dim=1) + else: + image_latent_attention_mask = torch.ones( + [image_latents.shape[0], image_latents.shape[1]], + dtype=image_latents.dtype, + device=image_latents.device, + ) + if guidance_scale > 1: + image_latent_attention_mask = image_latent_attention_mask.repeat(2, 1) + attention_mask = torch.cat( + [prompt_attention_mask, latent_attention_mask, image_latent_attention_mask], dim=1 + ) + + attention_mask = self.create_attention_matrix(attention_mask) # batch, seq => batch, seq, seq + attention_mask = attention_mask.unsqueeze(dim=1).to(dtype=self.transformer.dtype) # for head broadcasting + + if self._joint_attention_kwargs is None: + self._joint_attention_kwargs = {} + self._joint_attention_kwargs["attention_mask"] = attention_mask + + # Adapt scheduler to dynamic shifting (resolution dependent) + + if do_patching: + seq_len = (height // (self.vae_scale_factor * 2)) * (width // (self.vae_scale_factor * 2)) + else: + seq_len = (height // self.vae_scale_factor) * (width // self.vae_scale_factor) + + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + + mu = calculate_shift( + seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + + # Init sigmas and timesteps according to shift size + # This changes the scheduler in-place according to the dynamic scheduling + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps=num_inference_steps, + device=device, + timesteps=None, + sigmas=sigmas, + mu=mu, + ) + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # Support old different diffusers versions + if len(latent_image_ids.shape) == 3: + latent_image_ids = latent_image_ids[0] + + if len(text_ids.shape) == 3: + text_ids = text_ids[0] + + # 6. 
Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = latents + + if image_latents is not None: + latent_model_input = torch.cat([latent_model_input, image_latents], dim=1) + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latent_model_input] * 2) if guidance_scale > 1 else latent_model_input + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]).to( + device=latent_model_input.device, dtype=latent_model_input.dtype + ) + + # This is predicts "v" from flow-matching or eps from diffusion + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + text_encoder_layers=prompt_layers, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + txt_ids=text_ids, + img_ids=latent_image_ids, + )[0] + + # perform guidance + if guidance_scale > 1: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred[:, : latents.shape[1], ...], t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + if do_patching: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + else: + latents = self._unpack_latents_no_patch(latents, height, width, self.vae_scale_factor) + + latents = latents.unsqueeze(dim=2) + latents_device = latents[0].device + latents_dtype = latents[0].dtype + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents_device, latents_dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents_device, latents_dtype + ) + latents_scaled = [latent / latents_std + latents_mean for latent in latents] + latents_scaled = torch.cat(latents_scaled, dim=0) + image = [] + for scaled_latent in latents_scaled: + curr_image = self.vae.decode(scaled_latent.unsqueeze(0), return_dict=False)[0] + curr_image = self.image_processor.postprocess(curr_image.squeeze(dim=2), output_type=output_type) + image.append(curr_image) + if len(image) == 1: + image = image[0] + else: + image = np.stack(image, axis=0) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return 
BriaFiboPipelineOutput(images=image) + + def prepare_image_latents( + self, + image: torch.Tensor, + batch_size: int, + num_channels_latents: int, + height: int, + width: int, + dtype: torch.dtype, + device: torch.device, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + ): + image = image.to(device=device, dtype=dtype) + + height = int(height) // self.vae_scale_factor + width = int(width) // self.vae_scale_factor + + # scaling + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1).to(device, dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + device, dtype + ) + + image_latents_cthw = self.vae.encode(image.unsqueeze(2)).latent_dist.mean + latents_scaled = [(latent - latents_mean) * latents_std for latent in image_latents_cthw] + image_latents_cthw = torch.concat(latents_scaled, dim=0) + image_latents_bchw = image_latents_cthw[:, :, 0, :, :] + + image_latent_height, image_latent_width = image_latents_bchw.shape[2:] + image_latents_bsd = self._pack_latents_no_patch( + latents=image_latents_bchw, + batch_size=batch_size, + num_channels_latents=num_channels_latents, + height=image_latent_height, + width=image_latent_width, + ) + # breakpoint() + image_ids = self._prepare_latent_image_ids( + batch_size=batch_size, height=image_latent_height, width=image_latent_width, device=device, dtype=dtype + ) + # image ids are the same as latent ids with the first dimension set to 1 instead of 0 + image_ids[..., 0] = 1 + return image_latents_bsd, image_ids + + def check_inputs( + self, + prompt, + seed, + image, + mask, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if seed is not None and not isinstance(seed, int): + raise ValueError("Seed must be an integer") + if image is not None and not isinstance(image, (torch.Tensor, Image.Image, list)): + raise ValueError("Image must be a valid image") + if image is None and mask is not None: + raise ValueError("If mask is provided, image must also be provided") + + if mask is not None and not is_valid_mask(mask): + raise ValueError("Mask must be a valid mask") + + if mask is not None and image is not None and not (get_mask_size(mask) == get_image_size(image)): + raise ValueError("Mask and image must have the same size") + + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+ ) + elif prompt is not None and not is_valid_edit_json(prompt): + raise ValueError(f"`prompt` has to be a valid JSON string or dict but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if max_sequence_length is not None and max_sequence_length > 3000: + raise ValueError(f"`max_sequence_length` cannot be greater than 3000 but is {max_sequence_length}") + + def create_attention_matrix(self, attention_mask): + attention_matrix = torch.einsum("bi,bj->bij", attention_mask, attention_mask) + + # convert to 0 - keep, -inf ignore + attention_matrix = torch.where( + attention_matrix == 1, 0.0, -torch.inf + ) # Apply -inf to ignored tokens for nulling softmax score + return attention_matrix diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index da32b7ad8df0..3a48c6a47914 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -602,6 +602,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class BriaFiboEditPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class BriaPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/pipelines/bria_fibo_edit/__init__.py b/tests/pipelines/bria_fibo_edit/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/bria_fibo_edit/test_pipeline_bria_fibo_edit.py b/tests/pipelines/bria_fibo_edit/test_pipeline_bria_fibo_edit.py new file mode 100644 index 000000000000..5376c4b5e03f --- /dev/null +++ b/tests/pipelines/bria_fibo_edit/test_pipeline_bria_fibo_edit.py @@ -0,0 +1,192 @@ +# Copyright 2024 Bria AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import numpy as np +import torch +from PIL import Image +from transformers import AutoTokenizer +from transformers.models.smollm3.modeling_smollm3 import SmolLM3Config, SmolLM3ForCausalLM + +from diffusers import ( + AutoencoderKLWan, + BriaFiboEditPipeline, + FlowMatchEulerDiscreteScheduler, +) +from diffusers.models.transformers.transformer_bria_fibo import BriaFiboTransformer2DModel +from tests.pipelines.test_pipelines_common import PipelineTesterMixin + +from ...testing_utils import ( + enable_full_determinism, + torch_device, +) + + +enable_full_determinism() + + +class BriaFiboPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = BriaFiboEditPipeline + params = frozenset(["prompt", "height", "width", "guidance_scale"]) + batch_params = frozenset(["prompt"]) + test_xformers_attention = False + test_layerwise_casting = False + test_group_offloading = False + supports_dduf = False + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = BriaFiboTransformer2DModel( + patch_size=1, + in_channels=16, + num_layers=1, + num_single_layers=1, + attention_head_dim=8, + num_attention_heads=2, + joint_attention_dim=64, + text_encoder_dim=32, + pooled_projection_dim=None, + axes_dims_rope=[0, 4, 4], + ) + + vae = AutoencoderKLWan( + base_dim=80, + decoder_base_dim=128, + dim_mult=[1, 2, 4, 4], + dropout=0.0, + in_channels=12, + latents_mean=[0.0] * 16, + latents_std=[1.0] * 16, + is_residual=True, + num_res_blocks=2, + out_channels=12, + patch_size=2, + scale_factor_spatial=16, + scale_factor_temporal=4, + temperal_downsample=[False, True, True], + z_dim=16, + ) + scheduler = FlowMatchEulerDiscreteScheduler() + text_encoder = SmolLM3ForCausalLM(SmolLM3Config(hidden_size=32)) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + components = { + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "transformer": transformer, + "vae": vae, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device="cpu").manual_seed(seed) + inputs = { + "prompt": '{"text": "A painting of a squirrel eating a burger","edit_instruction": "A painting of a squirrel eating a burger"}', + "negative_prompt": "bad, ugly", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 5.0, + "height": 192, + "width": 336, + "output_type": "np", + } + image = Image.new("RGB", (336, 192), (255, 255, 255)) + inputs["image"] = image + return inputs + + @unittest.skip(reason="will not be supported due to dim-fusion") + def test_encode_prompt_works_in_isolation(self): + pass + + @unittest.skip(reason="Batching is not supported yet") + def test_num_images_per_prompt(self): + pass + + @unittest.skip(reason="Batching is not supported yet") + def test_inference_batch_consistent(self): + pass + + @unittest.skip(reason="Batching is not supported yet") + def test_inference_batch_single_identical(self): + pass + + def test_bria_fibo_different_prompts(self): + pipe = self.pipeline_class(**self.get_dummy_components()) + pipe = pipe.to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + output_same_prompt = pipe(**inputs).images[0] + + inputs = self.get_dummy_inputs(torch_device) + inputs["prompt"] = {"edit_instruction": "a different prompt"} + output_different_prompts = pipe(**inputs).images[0] + + max_diff = np.abs(output_same_prompt - 
output_different_prompts).max() + assert max_diff > 1e-6 + + def test_image_output_shape(self): + pipe = self.pipeline_class(**self.get_dummy_components()) + pipe = pipe.to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + height_width_pairs = [(32, 32), (64, 64), (32, 64)] + for height, width in height_width_pairs: + expected_height = height + expected_width = width + + inputs.update({"height": height, "width": width}) + image = pipe(**inputs).images[0] + output_height, output_width, _ = image.shape + assert (output_height, output_width) == (expected_height, expected_width) + + def test_bria_fibo_edit_mask(self): + pipe = self.pipeline_class(**self.get_dummy_components()) + pipe = pipe.to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + mask = Image.fromarray((np.ones((192, 336)) * 255).astype(np.uint8), mode="L") + + inputs.update({"mask": mask}) + output = pipe(**inputs).images[0] + + assert output.shape == (192, 336, 3) + + def test_bria_fibo_edit_mask_image_size_mismatch(self): + pipe = self.pipeline_class(**self.get_dummy_components()) + pipe = pipe.to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + mask = Image.fromarray((np.ones((64, 64)) * 255).astype(np.uint8), mode="L") + + inputs.update({"mask": mask}) + with self.assertRaisesRegex(ValueError, "Mask and image must have the same size"): + pipe(**inputs) + + def test_bria_fibo_edit_mask_no_image(self): + pipe = self.pipeline_class(**self.get_dummy_components()) + pipe = pipe.to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + mask = Image.fromarray((np.ones((32, 32)) * 255).astype(np.uint8), mode="L") + + # Remove image from inputs if it's there (it shouldn't be by default from get_dummy_inputs) + inputs.pop("image", None) + inputs.update({"mask": mask}) + + with self.assertRaisesRegex(ValueError, "If mask is provided, image must also be provided"): + pipe(**inputs)
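+
+    def test_bria_fibo_edit_invalid_prompt(self):
+        # Illustrative sketch (not part of the original test suite): `check_inputs` is expected to reject
+        # prompts that are not valid edit JSON, i.e. strings or dicts without an "edit_instruction" key.
+        pipe = self.pipeline_class(**self.get_dummy_components())
+        pipe = pipe.to(torch_device)
+        inputs = self.get_dummy_inputs(torch_device)
+
+        inputs["prompt"] = "a plain caption with no edit instruction key"
+
+        with self.assertRaisesRegex(ValueError, "valid JSON string or dict"):
+            pipe(**inputs)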