diff --git a/README.md b/README.md index a757b16..2f6c4c0 100644 --- a/README.md +++ b/README.md @@ -1 +1,209 @@ -# OmniBooth \ No newline at end of file +# OmniBooth + +> OmniBooth: Learning Latent Control for Image Synthesis with Multi-modal Instruction
+> [Leheng Li](https://len-li.github.io), Weichao Qiu, Xu Yan, Jing He, Kaiqiang Zhou, Yingjie CAI, Qing LIAN, Bingbing Liu, Ying-Cong Chen + +OmniBooth is a project focused on synthesizing image data following multi-modal instruction. Users can use text or image to control instance generation. This repository provides tools and scripts to process, train, and generate synthetic image data using COCO dataset, or self-designed data. + +#### [Project Page](https://len-li.github.io/omnibooth-web) | [Paper](https://arxiv.org/) | [Video](https://len-li.github.io/omnibooth-web/videos/teaser-user-draw.mp4) | [Checkpoint](https://huggingface.co/lilelife/Omnibooth) + + + +## Table of Contents + + - [Installation](#installation) + - [Prepare Dataset](#prepare-dataset) + - [Prepare Checkpoint](#prepare-checkpoint) + - [Train](#train) + - [Inference](#inference) + - [Behavior analysis](#behavior-analysis) + - [Data sturture](#instance-data-structure) + + + + + + + +## Installation + +To get started with OmniBooth, follow these steps: + +1. **Clone the repository:** + ```bash + git clone https://github.com/Len-Li/OmniBooth.git + cd OmniBooth + ``` + +2. **Set up a environment :** + ```bash + pip install torch torchvision transformers + pip install diffusers==0.26.0.dev0 + # We use a old version of diffusers, please take care of it. + + pip install albumentations pycocotools + pip install git+https://github.com/cocodataset/panopticapi.git + ``` + + + + +## Prepare Dataset + +You can skip this step if you just want to run a demo generation. I've prepared demo mask in `data/instance_dataset` for generation. Please see [Inference](#inference). + +To train OmniBooth, follow the steps below: + +1. **Download the [COCONut](https://github.com/bytedance/coconut_cvpr2024/blob/main/preparing_datasets.md) dataset:** + + We use COCONut-S split. + Please download the COCONut-S file and relabeled-COCO-val from [here](https://github.com/bytedance/coconut_cvpr2024?tab=readme-ov-file#dataset-splits) and put it in `data/coconut_dataset` folder. I recommend to use [Kaggle](https://www.kaggle.com/datasets/xueqingdeng/coconut) link. + + +2. **Download the COCO dataset:** + ``` + cd data/coconut_dataset + mkdir coco && cd coco + + wget http://images.cocodataset.org/zips/train2017.zip + wget http://images.cocodataset.org/zips/val2017.zip + wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip + + unzip train2017.zip && unzip val2017.zip + unzip annotations_trainval2017.zip + ``` + + + + + After preparation, you will be able to see the following directory structure: + + ``` + OmniBooth/ + ├── data/ + │ ├── instance_dataset/ + │ ├── coconut_dataset/ + │ │ ├── coco/ + │ │ ├── coconut_s/ + | | ├── relabeled_coco_val/ + │ │ ├── annotations/ + │ │ │ ├── coconut_s.json + │ │ │ ├── relabeled_coco_val.json + │ │ │ ├── my-train.json + │ │ │ ├── my-val.json + ``` + + + +## Prepare Checkpoint +Our model is based on [stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0). We additionaly use [sdxl-vae-fp16-fix](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix) to avoid numerical issue in VAE decoding. Please download the two models and put them at `./OmniBooth/ckp/`. + +Our checkpoint of OmniBooth is released in [huggingface](https://huggingface.co/lilelife/OmniBooth). If you want to use our model to run inference. Please put them at `./OmniBooth/ckp/`. + +## Train + + ```bash + bash train.sh + ``` +The details of the script are as follows: +```bash +export MODEL_DIR="./ckp/stable-diffusion-xl-base-1.0" +export VAE_DIR="./ckp/sdxl-vae-fp16-fix" + +export EXP_NAME="omnibooth_train" +export OUTPUT_DIR="./ckp/$EXP_NAME" + +accelerate launch --gpu_ids 0, --num_processes 1 --main_process_port 3226 train.py \ + --pretrained_model_name_or_path=$MODEL_DIR \ + --pretrained_vae_model_name_or_path=$VAE_DIR \ + --output_dir=$OUTPUT_DIR \ + --width=1024 \ + --height=1024 \ + --patch_size=364 \ + --learning_rate=4e-5 \ + --num_train_epochs=12 \ + --train_batch_size=1 \ + --mulscale_batch_size=2 \ + --mixed_precision="fp16" \ + --num_validation_images=2 \ + --validation_steps=500 \ + --checkpointing_steps=5000 \ + --checkpoints_total_limit=10 \ + --ctrl_channel=1024 \ + --use_sdxl=True \ + --enable_xformers_memory_efficient_attention \ + --report_to='wandb' \ + --resume_from_checkpoint="latest" \ + --tracker_project_name="omnibooth-demo" +``` + +The training process will take 3 days to complete using 8 NVIDIA A100. We use batchsize=2, image height set as 1024, image width follows the ground-truth image ratio. It will take 65GB memory for each GPU. + +## Inference + +```bash +bash infer.sh +``` +You will find generated images at `./vis_dir/`. The image is shown as follows: +![image](./ckp/plane.jpg) + + +## Behavior analysis + +1. The text instruction is not perfect, it is applicable to descriptions of attributes like color, but it is difficult to provide more granular descriptions. Scaling the data and model can help with this problem. +2. The image instruction may result in generated images with washed-out colors, possibly due to brightness augmentation. This can be adjusted by editing global prompt: ‘a brighter image’. +3. Video Dataset. Ideally, we should use video datasets to train image-instructed generation, similar to Anydoor. However, in our multi-modal setting, the cost of obtaining video datasets + tracking annotations + panoptic annotations is relatively high, so we only trained our model on the single-view COCO dataset. If you plan to expand the training data to video datasets, please let me know. + + +## Instance data structure + +I provide several instance mask datasets for inference in `data/instance_dataset`. This data is converted from coco dataset. The data structure is as follows: + +``` +# use data/instance_dataset/plane as an example +0000_mask.png +0000.png +0001_mask.png +0001.png +0002_mask.png +0002.png +... +prompt_dict.json +``` +The mask file is a binary mask that indicate the instance location. The image file is the optional image reference. Turn the `--text_or_img=img` to use it. + + The `prompt_dict.json` is a dictionary contains instance prompt and global_prompt. The prompt is a string that describes the instance or global image. For example, `"prompt_dict.json"` is as follows: + +```json +{ + "prompt_0": "a plane is silhouetted against a cloudy sky", + "prompt_1": "a road", + "prompt_2": "a pavement of merged", + "global_prompt": "large mustard yellow commercial airplane parked in the airport" +} +``` + + +## Acknowledgment +Additionally, we express our gratitude to the authors of the following opensource projects: + +- [Diffusers controlnet example](https://github.com/huggingface/diffusers/tree/main/examples/controlnet) (ControlNet training script) +- [COCONut](https://github.com/bytedance/coconut_cvpr2024) (Panoptic mask annotation) +- [SyntheOcc](https://len-li.github.io/syntheocc-web/) (Network structure) + + + +## BibTeX + +```bibtex +@inproceedings{li2024OmniBooth, + title={OmniBooth: Synthesize Geometric Controlled Street View Images through 3D Semantic MPIs}, + author={Li, Leheng and Qiu, Weichao and Chen, Ying-Cong et.al.}, + booktitle={arxiv preprint}, + year={2024} + } +``` + +This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. + + diff --git a/args_file.py b/args_file.py new file mode 100644 index 0000000..72e3246 --- /dev/null +++ b/args_file.py @@ -0,0 +1,404 @@ +import argparse + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a ControlNet training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default="", + # required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--dataroot_path", + type=str, + default='./data/coconut_dataset', + help="The location of nuScenes dataset.", + ) + parser.add_argument( + "--gen_train_or_val", + type=str, + default='train', + help="Use which model to run inference.", + ) + parser.add_argument( + "--text_or_img", + type=str, + default='mix', + help="Use which model to run inference.", + ) + parser.add_argument( + "--model_path_infer", + type=str, + default=None, + help="Use which model to run inference.", + ) + parser.add_argument( + "--save_img_path", + type=str, + default=None, + help="Path to the saved image generated by models.", + ) + parser.add_argument( + "--use_sdxl", + type=bool, + default=False, + help="Path to the saved image generated by models.", + ) + parser.add_argument( + "--use_cbgs", + type=bool, + default=False, + help="Path to the saved image generated by models.", + ) + parser.add_argument( + "--cfg_scale", + type=float, + default=7.0, + help="Path to the saved image generated by models.", + ) + parser.add_argument( + "--curr_gpu", + type=int, + default=0, + help="Path to the saved image generated by models.", + ) + parser.add_argument( + "--controlnet_model_name_or_path", + type=str, + default=None, + help="Path to pretrained controlnet model or model identifier from huggingface.co/models." + " If not specified controlnet weights are initialized from unet.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--output_dir", + type=str, + default="controlnet-model", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--patch_size", + type=int, + default=224, + ) + parser.add_argument( + "--width", + type=int, + default=800, + ) + parser.add_argument( + "--height", + type=int, + default=448, + ) + parser.add_argument( + "--ctrl_channel", + type=int, + default=257, + ) + parser.add_argument( + "--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--mulscale_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. " + "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference." + "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components." + "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step" + "instructions." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=3, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-6, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant_with_warmup", + # default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + # default=0, + # default=1, + default=4, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + ) + parser.add_argument( + "--set_grads_to_none", + action="store_true", + help=( + "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain" + " behaviors, so disable this argument if it causes any problems. More info:" + " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html" + ), + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--train_data_dir", + type=str, + default=None, + help=( + "A folder containing the training data. Folder contents must follow the structure described in" + " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" + " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." + ), + ) + parser.add_argument( + "--pretrained_vae_model_name_or_path", + type=str, + default=None, + help="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.", + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ), + ) + parser.add_argument( + "--proportion_empty_prompts", + type=float, + default=0, + help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).", + ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + nargs="+", + help=( + "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`." + " Provide either a matching number of `--validation_image`s, a single `--validation_image`" + " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s." + ), + ) + parser.add_argument( + "--validation_image", + type=str, + default=None, + nargs="+", + help=( + "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`" + " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a" + " a single `--validation_prompt` to be used with all `--validation_image`s, or a single" + " `--validation_image` that will be used with all `--validation_prompt`s." + ), + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=2, + help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair", + ) + parser.add_argument( + "--validation_steps", + type=int, + default=100, + help=( + "Run validation every X steps. Validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`" + " and logging the images." + ), + ) + parser.add_argument( + "--tracker_project_name", + type=str, + default="train_controlnet", + help=( + "The `project_name` argument passed to Accelerator.init_trackers for" + " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" + ), + ) + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + + if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1: + raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].") + + + if args.resolution % 8 != 0: + raise ValueError( + "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder." + ) + + return args diff --git a/ckp/plane.jpg b/ckp/plane.jpg new file mode 100644 index 0000000..7713ab3 Binary files /dev/null and b/ckp/plane.jpg differ diff --git a/data/instance_dataset/000000175438/0000.png b/data/instance_dataset/000000175438/0000.png new file mode 100644 index 0000000..5356f16 Binary files /dev/null and b/data/instance_dataset/000000175438/0000.png differ diff --git a/data/instance_dataset/000000175438/0000_mask.png b/data/instance_dataset/000000175438/0000_mask.png new file mode 100644 index 0000000..f6b157b Binary files /dev/null and b/data/instance_dataset/000000175438/0000_mask.png differ diff --git a/data/instance_dataset/000000175438/0001.png b/data/instance_dataset/000000175438/0001.png new file mode 100644 index 0000000..507a08d Binary files /dev/null and b/data/instance_dataset/000000175438/0001.png differ diff --git a/data/instance_dataset/000000175438/0001_mask.png b/data/instance_dataset/000000175438/0001_mask.png new file mode 100644 index 0000000..95267f4 Binary files /dev/null and b/data/instance_dataset/000000175438/0001_mask.png differ diff --git a/data/instance_dataset/000000175438/0002.png b/data/instance_dataset/000000175438/0002.png new file mode 100644 index 0000000..cca61ab Binary files /dev/null and b/data/instance_dataset/000000175438/0002.png differ diff --git a/data/instance_dataset/000000175438/0002_mask.png b/data/instance_dataset/000000175438/0002_mask.png new file mode 100644 index 0000000..23d5c61 Binary files /dev/null and b/data/instance_dataset/000000175438/0002_mask.png differ diff --git a/data/instance_dataset/000000175438/0003.png b/data/instance_dataset/000000175438/0003.png new file mode 100644 index 0000000..219b7bc Binary files /dev/null and b/data/instance_dataset/000000175438/0003.png differ diff --git a/data/instance_dataset/000000175438/0003_mask.png b/data/instance_dataset/000000175438/0003_mask.png new file mode 100644 index 0000000..ceccb34 Binary files /dev/null and b/data/instance_dataset/000000175438/0003_mask.png differ diff --git a/data/instance_dataset/000000175438/0004.png b/data/instance_dataset/000000175438/0004.png new file mode 100644 index 0000000..e9e7428 Binary files /dev/null and b/data/instance_dataset/000000175438/0004.png differ diff --git a/data/instance_dataset/000000175438/0004_mask.png b/data/instance_dataset/000000175438/0004_mask.png new file mode 100644 index 0000000..0a10465 Binary files /dev/null and b/data/instance_dataset/000000175438/0004_mask.png differ diff --git a/data/instance_dataset/000000175438/0005.png b/data/instance_dataset/000000175438/0005.png new file mode 100644 index 0000000..a0d50bf Binary files /dev/null and b/data/instance_dataset/000000175438/0005.png differ diff --git a/data/instance_dataset/000000175438/0005_mask.png b/data/instance_dataset/000000175438/0005_mask.png new file mode 100644 index 0000000..9eeb8b8 Binary files /dev/null and b/data/instance_dataset/000000175438/0005_mask.png differ diff --git a/data/instance_dataset/000000175438/0006.png b/data/instance_dataset/000000175438/0006.png new file mode 100644 index 0000000..f30ebe1 Binary files /dev/null and b/data/instance_dataset/000000175438/0006.png differ diff --git a/data/instance_dataset/000000175438/0006_mask.png b/data/instance_dataset/000000175438/0006_mask.png new file mode 100644 index 0000000..4515ffd Binary files /dev/null and b/data/instance_dataset/000000175438/0006_mask.png differ diff --git a/data/instance_dataset/000000175438/0007.png b/data/instance_dataset/000000175438/0007.png new file mode 100644 index 0000000..14d81ce Binary files /dev/null and b/data/instance_dataset/000000175438/0007.png differ diff --git a/data/instance_dataset/000000175438/0007_mask.png b/data/instance_dataset/000000175438/0007_mask.png new file mode 100644 index 0000000..2a00c62 Binary files /dev/null and b/data/instance_dataset/000000175438/0007_mask.png differ diff --git a/data/instance_dataset/000000175438/0008.png b/data/instance_dataset/000000175438/0008.png new file mode 100644 index 0000000..143c300 Binary files /dev/null and b/data/instance_dataset/000000175438/0008.png differ diff --git a/data/instance_dataset/000000175438/0008_mask.png b/data/instance_dataset/000000175438/0008_mask.png new file mode 100644 index 0000000..c375798 Binary files /dev/null and b/data/instance_dataset/000000175438/0008_mask.png differ diff --git a/data/instance_dataset/000000175438/0009.png b/data/instance_dataset/000000175438/0009.png new file mode 100644 index 0000000..20213ee Binary files /dev/null and b/data/instance_dataset/000000175438/0009.png differ diff --git a/data/instance_dataset/000000175438/0009_mask.png b/data/instance_dataset/000000175438/0009_mask.png new file mode 100644 index 0000000..8da624d Binary files /dev/null and b/data/instance_dataset/000000175438/0009_mask.png differ diff --git a/data/instance_dataset/000000175438/0010.png b/data/instance_dataset/000000175438/0010.png new file mode 100644 index 0000000..dcba3f1 Binary files /dev/null and b/data/instance_dataset/000000175438/0010.png differ diff --git a/data/instance_dataset/000000175438/0010_mask.png b/data/instance_dataset/000000175438/0010_mask.png new file mode 100644 index 0000000..c37e83d Binary files /dev/null and b/data/instance_dataset/000000175438/0010_mask.png differ diff --git a/data/instance_dataset/000000175438/0011.png b/data/instance_dataset/000000175438/0011.png new file mode 100644 index 0000000..ef61b3d Binary files /dev/null and b/data/instance_dataset/000000175438/0011.png differ diff --git a/data/instance_dataset/000000175438/0011_mask.png b/data/instance_dataset/000000175438/0011_mask.png new file mode 100644 index 0000000..94291aa Binary files /dev/null and b/data/instance_dataset/000000175438/0011_mask.png differ diff --git a/data/instance_dataset/000000175438/0012.png b/data/instance_dataset/000000175438/0012.png new file mode 100644 index 0000000..74b5334 Binary files /dev/null and b/data/instance_dataset/000000175438/0012.png differ diff --git a/data/instance_dataset/000000175438/0012_mask.png b/data/instance_dataset/000000175438/0012_mask.png new file mode 100644 index 0000000..60ac181 Binary files /dev/null and b/data/instance_dataset/000000175438/0012_mask.png differ diff --git a/data/instance_dataset/000000175438/0013.png b/data/instance_dataset/000000175438/0013.png new file mode 100644 index 0000000..bc4a5d4 Binary files /dev/null and b/data/instance_dataset/000000175438/0013.png differ diff --git a/data/instance_dataset/000000175438/0013_mask.png b/data/instance_dataset/000000175438/0013_mask.png new file mode 100644 index 0000000..e7cce9c Binary files /dev/null and b/data/instance_dataset/000000175438/0013_mask.png differ diff --git a/data/instance_dataset/000000175438/0014.png b/data/instance_dataset/000000175438/0014.png new file mode 100644 index 0000000..b0be9d7 Binary files /dev/null and b/data/instance_dataset/000000175438/0014.png differ diff --git a/data/instance_dataset/000000175438/0014_mask.png b/data/instance_dataset/000000175438/0014_mask.png new file mode 100644 index 0000000..c11b51c Binary files /dev/null and b/data/instance_dataset/000000175438/0014_mask.png differ diff --git a/data/instance_dataset/000000175438/0015.png b/data/instance_dataset/000000175438/0015.png new file mode 100644 index 0000000..b9598cf Binary files /dev/null and b/data/instance_dataset/000000175438/0015.png differ diff --git a/data/instance_dataset/000000175438/0015_mask.png b/data/instance_dataset/000000175438/0015_mask.png new file mode 100644 index 0000000..edb9184 Binary files /dev/null and b/data/instance_dataset/000000175438/0015_mask.png differ diff --git a/data/instance_dataset/000000175438/0016.png b/data/instance_dataset/000000175438/0016.png new file mode 100644 index 0000000..f3cb56a Binary files /dev/null and b/data/instance_dataset/000000175438/0016.png differ diff --git a/data/instance_dataset/000000175438/0016_mask.png b/data/instance_dataset/000000175438/0016_mask.png new file mode 100644 index 0000000..41f8e84 Binary files /dev/null and b/data/instance_dataset/000000175438/0016_mask.png differ diff --git a/data/instance_dataset/000000175438/0017.png b/data/instance_dataset/000000175438/0017.png new file mode 100644 index 0000000..e6a4a4f Binary files /dev/null and b/data/instance_dataset/000000175438/0017.png differ diff --git a/data/instance_dataset/000000175438/0017_mask.png b/data/instance_dataset/000000175438/0017_mask.png new file mode 100644 index 0000000..290e616 Binary files /dev/null and b/data/instance_dataset/000000175438/0017_mask.png differ diff --git a/data/instance_dataset/000000175438/0018.png b/data/instance_dataset/000000175438/0018.png new file mode 100644 index 0000000..dcfa73f Binary files /dev/null and b/data/instance_dataset/000000175438/0018.png differ diff --git a/data/instance_dataset/000000175438/0018_mask.png b/data/instance_dataset/000000175438/0018_mask.png new file mode 100644 index 0000000..70573f5 Binary files /dev/null and b/data/instance_dataset/000000175438/0018_mask.png differ diff --git a/data/instance_dataset/000000175438/0019.png b/data/instance_dataset/000000175438/0019.png new file mode 100644 index 0000000..c217fc2 Binary files /dev/null and b/data/instance_dataset/000000175438/0019.png differ diff --git a/data/instance_dataset/000000175438/0019_mask.png b/data/instance_dataset/000000175438/0019_mask.png new file mode 100644 index 0000000..626c345 Binary files /dev/null and b/data/instance_dataset/000000175438/0019_mask.png differ diff --git a/data/instance_dataset/000000175438/0020.png b/data/instance_dataset/000000175438/0020.png new file mode 100644 index 0000000..94c9771 Binary files /dev/null and b/data/instance_dataset/000000175438/0020.png differ diff --git a/data/instance_dataset/000000175438/0020_mask.png b/data/instance_dataset/000000175438/0020_mask.png new file mode 100644 index 0000000..714d831 Binary files /dev/null and b/data/instance_dataset/000000175438/0020_mask.png differ diff --git a/data/instance_dataset/000000175438/0021.png b/data/instance_dataset/000000175438/0021.png new file mode 100644 index 0000000..a93bb71 Binary files /dev/null and b/data/instance_dataset/000000175438/0021.png differ diff --git a/data/instance_dataset/000000175438/0021_mask.png b/data/instance_dataset/000000175438/0021_mask.png new file mode 100644 index 0000000..452ea10 Binary files /dev/null and b/data/instance_dataset/000000175438/0021_mask.png differ diff --git a/data/instance_dataset/000000175438/0022.png b/data/instance_dataset/000000175438/0022.png new file mode 100644 index 0000000..acb366f Binary files /dev/null and b/data/instance_dataset/000000175438/0022.png differ diff --git a/data/instance_dataset/000000175438/0022_mask.png b/data/instance_dataset/000000175438/0022_mask.png new file mode 100644 index 0000000..01165dd Binary files /dev/null and b/data/instance_dataset/000000175438/0022_mask.png differ diff --git a/data/instance_dataset/000000175438/0023.png b/data/instance_dataset/000000175438/0023.png new file mode 100644 index 0000000..07052c8 Binary files /dev/null and b/data/instance_dataset/000000175438/0023.png differ diff --git a/data/instance_dataset/000000175438/0023_mask.png b/data/instance_dataset/000000175438/0023_mask.png new file mode 100644 index 0000000..90332a5 Binary files /dev/null and b/data/instance_dataset/000000175438/0023_mask.png differ diff --git a/data/instance_dataset/000000175438/0024.png b/data/instance_dataset/000000175438/0024.png new file mode 100644 index 0000000..d57ecbe Binary files /dev/null and b/data/instance_dataset/000000175438/0024.png differ diff --git a/data/instance_dataset/000000175438/0024_mask.png b/data/instance_dataset/000000175438/0024_mask.png new file mode 100644 index 0000000..0765333 Binary files /dev/null and b/data/instance_dataset/000000175438/0024_mask.png differ diff --git a/data/instance_dataset/000000175438/0025.png b/data/instance_dataset/000000175438/0025.png new file mode 100644 index 0000000..44c3e23 Binary files /dev/null and b/data/instance_dataset/000000175438/0025.png differ diff --git a/data/instance_dataset/000000175438/0025_mask.png b/data/instance_dataset/000000175438/0025_mask.png new file mode 100644 index 0000000..809ec06 Binary files /dev/null and b/data/instance_dataset/000000175438/0025_mask.png differ diff --git a/data/instance_dataset/000000175438/prompt_dict.json b/data/instance_dataset/000000175438/prompt_dict.json new file mode 100644 index 0000000..88d7bb6 --- /dev/null +++ b/data/instance_dataset/000000175438/prompt_dict.json @@ -0,0 +1 @@ +{"prompt_0": "a sky of other of merged", "prompt_1": "a road", "prompt_2": "a house with a chimney and a tree in the background", "prompt_3": "a street with a clock tower in the middle", "prompt_4": "a pavement of merged", "prompt_5": "a metal gate with a black fence and a gate", "prompt_6": "a clock with roman numerals on the face", "prompt_7": "a traffic light", "prompt_8": "a traffic light with three lights on it", "prompt_9": "a traffic light with three lights on it", "prompt_10": "a blue car with a white stripe on the side", "prompt_11": "a car", "prompt_12": "a blue car", "prompt_13": "a car", "prompt_14": "a car is parked in front of a building", "prompt_15": "a blue car is shown in this image", "prompt_16": "a silver car", "prompt_17": "a car", "prompt_18": "a black car with its doors open", "prompt_19": "a car", "prompt_20": "a car", "prompt_21": "a car", "prompt_22": "a blue car", "prompt_23": "a car", "prompt_24": "a white and blue car with a yellow and orange stripe", "prompt_25": "a small car is shown", "global_prompt": "A large cock sitting in the middle of a street."} \ No newline at end of file diff --git a/data/instance_dataset/000000176232/0000.png b/data/instance_dataset/000000176232/0000.png new file mode 100644 index 0000000..f6c2778 Binary files /dev/null and b/data/instance_dataset/000000176232/0000.png differ diff --git a/data/instance_dataset/000000176232/0000_mask.png b/data/instance_dataset/000000176232/0000_mask.png new file mode 100644 index 0000000..abb717a Binary files /dev/null and b/data/instance_dataset/000000176232/0000_mask.png differ diff --git a/data/instance_dataset/000000176232/0001.png b/data/instance_dataset/000000176232/0001.png new file mode 100644 index 0000000..614f77e Binary files /dev/null and b/data/instance_dataset/000000176232/0001.png differ diff --git a/data/instance_dataset/000000176232/0001_mask.png b/data/instance_dataset/000000176232/0001_mask.png new file mode 100644 index 0000000..9ee8a95 Binary files /dev/null and b/data/instance_dataset/000000176232/0001_mask.png differ diff --git a/data/instance_dataset/000000176232/0002.png b/data/instance_dataset/000000176232/0002.png new file mode 100644 index 0000000..9872515 Binary files /dev/null and b/data/instance_dataset/000000176232/0002.png differ diff --git a/data/instance_dataset/000000176232/0002_mask.png b/data/instance_dataset/000000176232/0002_mask.png new file mode 100644 index 0000000..2230282 Binary files /dev/null and b/data/instance_dataset/000000176232/0002_mask.png differ diff --git a/data/instance_dataset/000000176232/0003.png b/data/instance_dataset/000000176232/0003.png new file mode 100644 index 0000000..3881215 Binary files /dev/null and b/data/instance_dataset/000000176232/0003.png differ diff --git a/data/instance_dataset/000000176232/0003_mask.png b/data/instance_dataset/000000176232/0003_mask.png new file mode 100644 index 0000000..dfe2478 Binary files /dev/null and b/data/instance_dataset/000000176232/0003_mask.png differ diff --git a/data/instance_dataset/000000176232/prompt_dict.json b/data/instance_dataset/000000176232/prompt_dict.json new file mode 100644 index 0000000..9d5e135 --- /dev/null +++ b/data/instance_dataset/000000176232/prompt_dict.json @@ -0,0 +1 @@ +{"prompt_0": "a wall of other of merged", "prompt_1": "a bunch of purple flowers are in a vase", "prompt_2": "a paper of merged", "prompt_3": "a copper vase with a striped design", "global_prompt": "A decorative flower vase with lavender in it."} \ No newline at end of file diff --git a/data/instance_dataset/000000177213/0000.png b/data/instance_dataset/000000177213/0000.png new file mode 100644 index 0000000..6c92ae2 Binary files /dev/null and b/data/instance_dataset/000000177213/0000.png differ diff --git a/data/instance_dataset/000000177213/0000_mask.png b/data/instance_dataset/000000177213/0000_mask.png new file mode 100644 index 0000000..e44654a Binary files /dev/null and b/data/instance_dataset/000000177213/0000_mask.png differ diff --git a/data/instance_dataset/000000177213/0001.png b/data/instance_dataset/000000177213/0001.png new file mode 100644 index 0000000..97c98d8 Binary files /dev/null and b/data/instance_dataset/000000177213/0001.png differ diff --git a/data/instance_dataset/000000177213/0001_mask.png b/data/instance_dataset/000000177213/0001_mask.png new file mode 100644 index 0000000..4b7771b Binary files /dev/null and b/data/instance_dataset/000000177213/0001_mask.png differ diff --git a/data/instance_dataset/000000177213/0002.png b/data/instance_dataset/000000177213/0002.png new file mode 100644 index 0000000..5ad865e Binary files /dev/null and b/data/instance_dataset/000000177213/0002.png differ diff --git a/data/instance_dataset/000000177213/0002_mask.png b/data/instance_dataset/000000177213/0002_mask.png new file mode 100644 index 0000000..8ff8c63 Binary files /dev/null and b/data/instance_dataset/000000177213/0002_mask.png differ diff --git a/data/instance_dataset/000000177213/0003.png b/data/instance_dataset/000000177213/0003.png new file mode 100644 index 0000000..f7ff1ef Binary files /dev/null and b/data/instance_dataset/000000177213/0003.png differ diff --git a/data/instance_dataset/000000177213/0003_mask.png b/data/instance_dataset/000000177213/0003_mask.png new file mode 100644 index 0000000..694d5d3 Binary files /dev/null and b/data/instance_dataset/000000177213/0003_mask.png differ diff --git a/data/instance_dataset/000000177213/0004.png b/data/instance_dataset/000000177213/0004.png new file mode 100644 index 0000000..d2f41c9 Binary files /dev/null and b/data/instance_dataset/000000177213/0004.png differ diff --git a/data/instance_dataset/000000177213/0004_mask.png b/data/instance_dataset/000000177213/0004_mask.png new file mode 100644 index 0000000..1fb96ad Binary files /dev/null and b/data/instance_dataset/000000177213/0004_mask.png differ diff --git a/data/instance_dataset/000000177213/0005.png b/data/instance_dataset/000000177213/0005.png new file mode 100644 index 0000000..e2fe079 Binary files /dev/null and b/data/instance_dataset/000000177213/0005.png differ diff --git a/data/instance_dataset/000000177213/0005_mask.png b/data/instance_dataset/000000177213/0005_mask.png new file mode 100644 index 0000000..ab060c3 Binary files /dev/null and b/data/instance_dataset/000000177213/0005_mask.png differ diff --git a/data/instance_dataset/000000177213/0006.png b/data/instance_dataset/000000177213/0006.png new file mode 100644 index 0000000..8a23045 Binary files /dev/null and b/data/instance_dataset/000000177213/0006.png differ diff --git a/data/instance_dataset/000000177213/0006_mask.png b/data/instance_dataset/000000177213/0006_mask.png new file mode 100644 index 0000000..6d5d0f8 Binary files /dev/null and b/data/instance_dataset/000000177213/0006_mask.png differ diff --git a/data/instance_dataset/000000177213/0007.png b/data/instance_dataset/000000177213/0007.png new file mode 100644 index 0000000..9d899a9 Binary files /dev/null and b/data/instance_dataset/000000177213/0007.png differ diff --git a/data/instance_dataset/000000177213/0007_mask.png b/data/instance_dataset/000000177213/0007_mask.png new file mode 100644 index 0000000..c625703 Binary files /dev/null and b/data/instance_dataset/000000177213/0007_mask.png differ diff --git a/data/instance_dataset/000000177213/0008.png b/data/instance_dataset/000000177213/0008.png new file mode 100644 index 0000000..80315b5 Binary files /dev/null and b/data/instance_dataset/000000177213/0008.png differ diff --git a/data/instance_dataset/000000177213/0008_mask.png b/data/instance_dataset/000000177213/0008_mask.png new file mode 100644 index 0000000..2ce47fa Binary files /dev/null and b/data/instance_dataset/000000177213/0008_mask.png differ diff --git a/data/instance_dataset/000000177213/0009.png b/data/instance_dataset/000000177213/0009.png new file mode 100644 index 0000000..a99727c Binary files /dev/null and b/data/instance_dataset/000000177213/0009.png differ diff --git a/data/instance_dataset/000000177213/0009_mask.png b/data/instance_dataset/000000177213/0009_mask.png new file mode 100644 index 0000000..5162cdc Binary files /dev/null and b/data/instance_dataset/000000177213/0009_mask.png differ diff --git a/data/instance_dataset/000000177213/prompt_dict.json b/data/instance_dataset/000000177213/prompt_dict.json new file mode 100644 index 0000000..ec52365 --- /dev/null +++ b/data/instance_dataset/000000177213/prompt_dict.json @@ -0,0 +1 @@ +{"prompt_0": "a floor of other of merged", "prompt_1": "a wall of other of merged", "prompt_2": "a paper of merged", "prompt_3": "a person", "prompt_4": "a cup", "prompt_5": "a cup", "prompt_6": "a fork", "prompt_7": "a pizza with cheese and basil leaves on top", "prompt_8": "a knife and orange flames", "prompt_9": "a dining table", "global_prompt": "Small pizza sits on a plate on a restaurant table. "} \ No newline at end of file diff --git a/data/instance_dataset/000000178618/0000.png b/data/instance_dataset/000000178618/0000.png new file mode 100644 index 0000000..2e9ab72 Binary files /dev/null and b/data/instance_dataset/000000178618/0000.png differ diff --git a/data/instance_dataset/000000178618/0000_mask.png b/data/instance_dataset/000000178618/0000_mask.png new file mode 100644 index 0000000..aeef580 Binary files /dev/null and b/data/instance_dataset/000000178618/0000_mask.png differ diff --git a/data/instance_dataset/000000178618/0001.png b/data/instance_dataset/000000178618/0001.png new file mode 100644 index 0000000..9d12912 Binary files /dev/null and b/data/instance_dataset/000000178618/0001.png differ diff --git a/data/instance_dataset/000000178618/0001_mask.png b/data/instance_dataset/000000178618/0001_mask.png new file mode 100644 index 0000000..3982f79 Binary files /dev/null and b/data/instance_dataset/000000178618/0001_mask.png differ diff --git a/data/instance_dataset/000000178618/0002.png b/data/instance_dataset/000000178618/0002.png new file mode 100644 index 0000000..54c417b Binary files /dev/null and b/data/instance_dataset/000000178618/0002.png differ diff --git a/data/instance_dataset/000000178618/0002_mask.png b/data/instance_dataset/000000178618/0002_mask.png new file mode 100644 index 0000000..3b8b542 Binary files /dev/null and b/data/instance_dataset/000000178618/0002_mask.png differ diff --git a/data/instance_dataset/000000178618/0003.png b/data/instance_dataset/000000178618/0003.png new file mode 100644 index 0000000..203c185 Binary files /dev/null and b/data/instance_dataset/000000178618/0003.png differ diff --git a/data/instance_dataset/000000178618/0003_mask.png b/data/instance_dataset/000000178618/0003_mask.png new file mode 100644 index 0000000..400ebe0 Binary files /dev/null and b/data/instance_dataset/000000178618/0003_mask.png differ diff --git a/data/instance_dataset/000000178618/0004.png b/data/instance_dataset/000000178618/0004.png new file mode 100644 index 0000000..3c1ea00 Binary files /dev/null and b/data/instance_dataset/000000178618/0004.png differ diff --git a/data/instance_dataset/000000178618/0004_mask.png b/data/instance_dataset/000000178618/0004_mask.png new file mode 100644 index 0000000..7e1db7d Binary files /dev/null and b/data/instance_dataset/000000178618/0004_mask.png differ diff --git a/data/instance_dataset/000000178618/0005.png b/data/instance_dataset/000000178618/0005.png new file mode 100644 index 0000000..12d08e3 Binary files /dev/null and b/data/instance_dataset/000000178618/0005.png differ diff --git a/data/instance_dataset/000000178618/0005_mask.png b/data/instance_dataset/000000178618/0005_mask.png new file mode 100644 index 0000000..10ec4c8 Binary files /dev/null and b/data/instance_dataset/000000178618/0005_mask.png differ diff --git a/data/instance_dataset/000000178618/prompt_dict.json b/data/instance_dataset/000000178618/prompt_dict.json new file mode 100644 index 0000000..65fd375 --- /dev/null +++ b/data/instance_dataset/000000178618/prompt_dict.json @@ -0,0 +1 @@ +{"prompt_0": "a sky of other of merged", "prompt_1": "a dirt of merged", "prompt_2": "a tree of merged", "prompt_3": "elephant with tusks walking", "prompt_4": "a elephant", "prompt_5": "a brown elephant with a long trunk", "global_prompt": "A dust cloud has formed in front of an elephant."} \ No newline at end of file diff --git a/data/instance_dataset/000000187362/0000.png b/data/instance_dataset/000000187362/0000.png new file mode 100644 index 0000000..2a52989 Binary files /dev/null and b/data/instance_dataset/000000187362/0000.png differ diff --git a/data/instance_dataset/000000187362/0000_mask.png b/data/instance_dataset/000000187362/0000_mask.png new file mode 100644 index 0000000..12d77a5 Binary files /dev/null and b/data/instance_dataset/000000187362/0000_mask.png differ diff --git a/data/instance_dataset/000000187362/0001.png b/data/instance_dataset/000000187362/0001.png new file mode 100644 index 0000000..7906c9c Binary files /dev/null and b/data/instance_dataset/000000187362/0001.png differ diff --git a/data/instance_dataset/000000187362/0001_mask.png b/data/instance_dataset/000000187362/0001_mask.png new file mode 100644 index 0000000..2259853 Binary files /dev/null and b/data/instance_dataset/000000187362/0001_mask.png differ diff --git a/data/instance_dataset/000000187362/0002.png b/data/instance_dataset/000000187362/0002.png new file mode 100644 index 0000000..b6534ff Binary files /dev/null and b/data/instance_dataset/000000187362/0002.png differ diff --git a/data/instance_dataset/000000187362/0002_mask.png b/data/instance_dataset/000000187362/0002_mask.png new file mode 100644 index 0000000..f52534f Binary files /dev/null and b/data/instance_dataset/000000187362/0002_mask.png differ diff --git a/data/instance_dataset/000000187362/0003.png b/data/instance_dataset/000000187362/0003.png new file mode 100644 index 0000000..caf3810 Binary files /dev/null and b/data/instance_dataset/000000187362/0003.png differ diff --git a/data/instance_dataset/000000187362/0003_mask.png b/data/instance_dataset/000000187362/0003_mask.png new file mode 100644 index 0000000..0c9c68b Binary files /dev/null and b/data/instance_dataset/000000187362/0003_mask.png differ diff --git a/data/instance_dataset/000000187362/0004.png b/data/instance_dataset/000000187362/0004.png new file mode 100644 index 0000000..a0b94cc Binary files /dev/null and b/data/instance_dataset/000000187362/0004.png differ diff --git a/data/instance_dataset/000000187362/0004_mask.png b/data/instance_dataset/000000187362/0004_mask.png new file mode 100644 index 0000000..836aa6a Binary files /dev/null and b/data/instance_dataset/000000187362/0004_mask.png differ diff --git a/data/instance_dataset/000000187362/prompt_dict.json b/data/instance_dataset/000000187362/prompt_dict.json new file mode 100644 index 0000000..2535b37 --- /dev/null +++ b/data/instance_dataset/000000187362/prompt_dict.json @@ -0,0 +1 @@ +{"prompt_0": "a sky of other of merged", "prompt_1": "a sand", "prompt_2": "a sea", "prompt_3": "a person", "prompt_4": "a surfboard", "global_prompt": "A man and his surfboard survey a stormy sea"} \ No newline at end of file diff --git a/data/instance_dataset/dog/0000.png b/data/instance_dataset/dog/0000.png new file mode 100644 index 0000000..d6dc128 Binary files /dev/null and b/data/instance_dataset/dog/0000.png differ diff --git a/data/instance_dataset/dog/0000_mask.png b/data/instance_dataset/dog/0000_mask.png new file mode 100644 index 0000000..512bc83 Binary files /dev/null and b/data/instance_dataset/dog/0000_mask.png differ diff --git a/data/instance_dataset/dog/0001.png b/data/instance_dataset/dog/0001.png new file mode 100644 index 0000000..3d35fcf Binary files /dev/null and b/data/instance_dataset/dog/0001.png differ diff --git a/data/instance_dataset/dog/0001_mask.png b/data/instance_dataset/dog/0001_mask.png new file mode 100644 index 0000000..965f108 Binary files /dev/null and b/data/instance_dataset/dog/0001_mask.png differ diff --git a/data/instance_dataset/dog/0002.png b/data/instance_dataset/dog/0002.png new file mode 100644 index 0000000..8eba4f6 Binary files /dev/null and b/data/instance_dataset/dog/0002.png differ diff --git a/data/instance_dataset/dog/0002_mask.png b/data/instance_dataset/dog/0002_mask.png new file mode 100644 index 0000000..9c45495 Binary files /dev/null and b/data/instance_dataset/dog/0002_mask.png differ diff --git a/data/instance_dataset/dog/prompt_dict.json b/data/instance_dataset/dog/prompt_dict.json new file mode 100644 index 0000000..3340379 --- /dev/null +++ b/data/instance_dataset/dog/prompt_dict.json @@ -0,0 +1 @@ +{"prompt_0": "a dog is running in the grass", "prompt_1": "a black and tan dog running on a green background", "prompt_2": "a frisbee", "global_prompt": "A husky dog has an orange frisbee in it's mouth."} \ No newline at end of file diff --git a/data/instance_dataset/ff_instance_1/0000_mask.png b/data/instance_dataset/ff_instance_1/0000_mask.png new file mode 100644 index 0000000..a258c26 Binary files /dev/null and b/data/instance_dataset/ff_instance_1/0000_mask.png differ diff --git a/data/instance_dataset/ff_instance_1/0001_mask.png b/data/instance_dataset/ff_instance_1/0001_mask.png new file mode 100644 index 0000000..679be0b Binary files /dev/null and b/data/instance_dataset/ff_instance_1/0001_mask.png differ diff --git a/data/instance_dataset/ff_instance_1/0002_mask.png b/data/instance_dataset/ff_instance_1/0002_mask.png new file mode 100644 index 0000000..c9c475e Binary files /dev/null and b/data/instance_dataset/ff_instance_1/0002_mask.png differ diff --git a/data/instance_dataset/ff_instance_1/0003_mask.png b/data/instance_dataset/ff_instance_1/0003_mask.png new file mode 100644 index 0000000..eeb9c1a Binary files /dev/null and b/data/instance_dataset/ff_instance_1/0003_mask.png differ diff --git a/data/instance_dataset/ff_instance_1/0004_mask.png b/data/instance_dataset/ff_instance_1/0004_mask.png new file mode 100644 index 0000000..e5710c4 Binary files /dev/null and b/data/instance_dataset/ff_instance_1/0004_mask.png differ diff --git a/data/instance_dataset/ff_instance_1/grass.png b/data/instance_dataset/ff_instance_1/grass.png new file mode 100644 index 0000000..bc32d21 Binary files /dev/null and b/data/instance_dataset/ff_instance_1/grass.png differ diff --git a/data/instance_dataset/ff_instance_1/prompt_dict.json b/data/instance_dataset/ff_instance_1/prompt_dict.json new file mode 100644 index 0000000..0a08feb --- /dev/null +++ b/data/instance_dataset/ff_instance_1/prompt_dict.json @@ -0,0 +1 @@ +{"prompt_0": "blue sky", "prompt_1": "a green grass land", "prompt_2": "A road", "prompt_3": "a mountain", "prompt_4": "a human waving hand", "global_prompt": "A person standing at the foot of the mountain waving. There are roads and grasslands next to it, with beautiful mountains and a blue sky"} \ No newline at end of file diff --git a/data/instance_dataset/plane/0000.png b/data/instance_dataset/plane/0000.png new file mode 100644 index 0000000..e9f5157 Binary files /dev/null and b/data/instance_dataset/plane/0000.png differ diff --git a/data/instance_dataset/plane/0000_mask.png b/data/instance_dataset/plane/0000_mask.png new file mode 100644 index 0000000..e99a0c1 Binary files /dev/null and b/data/instance_dataset/plane/0000_mask.png differ diff --git a/data/instance_dataset/plane/0001.png b/data/instance_dataset/plane/0001.png new file mode 100644 index 0000000..14f09ff Binary files /dev/null and b/data/instance_dataset/plane/0001.png differ diff --git a/data/instance_dataset/plane/0001_mask.png b/data/instance_dataset/plane/0001_mask.png new file mode 100644 index 0000000..ddedd3c Binary files /dev/null and b/data/instance_dataset/plane/0001_mask.png differ diff --git a/data/instance_dataset/plane/0002.png b/data/instance_dataset/plane/0002.png new file mode 100644 index 0000000..1b6652c Binary files /dev/null and b/data/instance_dataset/plane/0002.png differ diff --git a/data/instance_dataset/plane/0002_mask.png b/data/instance_dataset/plane/0002_mask.png new file mode 100644 index 0000000..fb38ed3 Binary files /dev/null and b/data/instance_dataset/plane/0002_mask.png differ diff --git a/data/instance_dataset/plane/0003.png b/data/instance_dataset/plane/0003.png new file mode 100644 index 0000000..df0c569 Binary files /dev/null and b/data/instance_dataset/plane/0003.png differ diff --git a/data/instance_dataset/plane/0003_mask.png b/data/instance_dataset/plane/0003_mask.png new file mode 100644 index 0000000..d059061 Binary files /dev/null and b/data/instance_dataset/plane/0003_mask.png differ diff --git a/data/instance_dataset/plane/0004.png b/data/instance_dataset/plane/0004.png new file mode 100644 index 0000000..b17953e Binary files /dev/null and b/data/instance_dataset/plane/0004.png differ diff --git a/data/instance_dataset/plane/0004_mask.png b/data/instance_dataset/plane/0004_mask.png new file mode 100644 index 0000000..98c7e33 Binary files /dev/null and b/data/instance_dataset/plane/0004_mask.png differ diff --git a/data/instance_dataset/plane/0005.png b/data/instance_dataset/plane/0005.png new file mode 100644 index 0000000..a6233a9 Binary files /dev/null and b/data/instance_dataset/plane/0005.png differ diff --git a/data/instance_dataset/plane/0005_mask.png b/data/instance_dataset/plane/0005_mask.png new file mode 100644 index 0000000..450bc12 Binary files /dev/null and b/data/instance_dataset/plane/0005_mask.png differ diff --git a/data/instance_dataset/plane/0006.png b/data/instance_dataset/plane/0006.png new file mode 100644 index 0000000..e4412bf Binary files /dev/null and b/data/instance_dataset/plane/0006.png differ diff --git a/data/instance_dataset/plane/0006_mask.png b/data/instance_dataset/plane/0006_mask.png new file mode 100644 index 0000000..26f6859 Binary files /dev/null and b/data/instance_dataset/plane/0006_mask.png differ diff --git a/data/instance_dataset/plane/0007.png b/data/instance_dataset/plane/0007.png new file mode 100644 index 0000000..e2f9a2f Binary files /dev/null and b/data/instance_dataset/plane/0007.png differ diff --git a/data/instance_dataset/plane/0007_mask.png b/data/instance_dataset/plane/0007_mask.png new file mode 100644 index 0000000..18c99a7 Binary files /dev/null and b/data/instance_dataset/plane/0007_mask.png differ diff --git a/data/instance_dataset/plane/prompt_dict.json b/data/instance_dataset/plane/prompt_dict.json new file mode 100644 index 0000000..7f3d749 --- /dev/null +++ b/data/instance_dataset/plane/prompt_dict.json @@ -0,0 +1 @@ +{"prompt_0": "a plane is silhouetted against a cloudy sky", "prompt_1": "a road", "prompt_2": "a pavement of merged", "prompt_3": "a grass of merged", "prompt_4": "a mountain of merged", "prompt_5": "a large yellow airplane with blue letters", "prompt_6": "a large white and orange airplane flying in the air", "prompt_7": "a airplane", "global_prompt": "large mustard yellow commercial airplane parked in the airport"} \ No newline at end of file diff --git a/infer_coco.py b/infer_coco.py new file mode 100644 index 0000000..d7b88d8 --- /dev/null +++ b/infer_coco.py @@ -0,0 +1,385 @@ +import torch +from diffusers import UniPCMultistepScheduler + +import cv2 +import os +import tqdm +import pickle +import numpy as np +from torchvision.io import read_image +import torch.nn.functional as F +from torchvision import utils + +import torchvision.transforms.functional as tf + +import torch.distributed as dist +import torch.multiprocessing as mp +from torchvision.ops import masks_to_boxes + +# torch.set_default_device('cuda') + +from diffusers import ( + UNet2DConditionModel, + AutoencoderKL, +) + +from models.controlnet1x1 import ControlNetModel1x1 as ControlNetModel + +from models.pipeline_controlnet_sd_xl import ( + StableDiffusionXLControlNetPipeline as StableDiffusionXLControlNetPipeline, +) +# from models.pipeline_controlnet_1x1_4dunet import ( +# StableDiffusionControlNetPipeline1x1 as StableDiffusionControlNetPipeline, +# ) +from models.dino_model import FrozenDinoV2Encoder + +# from models.unet_2d_condition_multiview import UNet2DConditionModelMultiview + +from args_file import parse_args +from transformers import AutoTokenizer +from utils.dataset_nusmtv import CocoNutImg as NuScenesDataset + +from transformers import CLIPTextModel, CLIPTextModelWithProjection + +from sfast.compilers.diffusion_pipeline_compiler import compile, CompilationConfig + +config = CompilationConfig.Default() +# xformers and Triton are suggested for achieving best performance. +try: + import xformers + + config.enable_xformers = True +except ImportError: + print("xformers not installed, skip") + +try: + import triton + + config.enable_triton = True +except ImportError: + print("Triton not installed, skip") + +args = parse_args() + + +ckp_path = "./exp/out_sd21_cbgs_loss/" + +if args.model_path_infer is not None: + ckp_path = args.model_path_infer + +if "checkpoint" not in ckp_path: + dirs = os.listdir(ckp_path) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + ckp_path = os.path.join(ckp_path, dirs[-1]) if len(dirs) > 0 else ckp_path + +height = 448 +width = 800 + +generator = torch.manual_seed(666) +# generator = torch.manual_seed(0) + + +tokenizer = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer_2", + revision=args.revision, + use_fast=False, +) + +# text_encoder = CLIPTextModel.from_pretrained( +text_encoder = CLIPTextModelWithProjection.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="text_encoder_2", + revision=args.revision, + variant=args.variant, +) + +val_dataset = NuScenesDataset(args, tokenizer, args.gen_train_or_val) + +save_path = "vis_dir/out_sd21_cbgs_loss2_40/samples" + +if args.save_img_path is not None: + save_path = args.save_img_path + +os.makedirs(save_path, exist_ok=True) + +print(ckp_path, save_path) + + +def tokenize_captions(examples, tokenizer, is_train=True): + captions = [] + for caption in examples: + captions.append(caption) + + inputs = tokenizer( + captions, + max_length=tokenizer.model_max_length, + padding="max_length", + truncation=True, + return_tensors="pt", + ) + return inputs.input_ids + + +def run_inference(rank, world_size, pred_results, input_datas, pipe, args): + # if rank is not None: + # # dist.init_process_group("gloo", rank=rank, world_size=world_size) + # dist.init_process_group("nccl", rank=rank, world_size=world_size) + # else: + # rank = 0 + print(rank) + # torch.set_default_device(rank) + + pipe.to("cuda") + dino_encoder.to("cuda") + text_encoder.to("cuda") + # pipe.to(rank) + # dino_encoder.to(rank) + weight_dtype = torch.float16 + all_list = input_datas[rank] + + config.enable_cuda_graph = False + # config.enable_cuda_graph = True + # pipe.controlnet = compile(pipe.controlnet, config) + # pipe = compile(pipe, config) + + with torch.no_grad(): + + # for img_idx in [ + # 0, + # 122, + # 555, + # ]: + + for img_idx in tqdm.tqdm(all_list): + batch = val_dataset.__getitem__(img_idx) + mtv_condition = batch["ctrl_img"] # [None] + validation_prompts = batch["prompts"] + # validation_prompts = ['A photorealistic image.' + x for x in validation_prompts] + + curr_h, curr_w = batch["pixel_values"].shape[-2:] + # import ipdb; ipdb.set_trace() + # print(curr_h, curr_w) + + # utils.save_image(batch['patches'][0][1][0], 'output.jpg') # C H W 0~1 + + prompt_fea = torch.zeros((*batch["ctrl_img"].shape, args.ctrl_channel)).to( + "cuda", dtype=weight_dtype + ) + + for curr_b, curr_ins_prompt in enumerate(batch["input_ids_ins"]): + curr_ins_prompt = ["anything"] + curr_ins_prompt + input_ids = tokenize_captions(curr_ins_prompt, tokenizer).cuda() + with torch.cuda.amp.autocast(): + text_features = text_encoder(input_ids, return_dict=True)[ + "text_embeds" + # "pooler_output" + ] + text_features = controlnet.text_adapter(text_features).to( + prompt_fea + ) + # import ipdb; ipdb.set_trace() + + for curr_ins_id in range(len(curr_ins_prompt)): + prompt_fea[curr_b][batch["ctrl_img"][curr_b] == curr_ins_id] = ( + text_features[curr_ins_id] + ) + + if 0: + # for curr_b, curr_ins_img in enumerate(batch["patches"]): + curr_ins_id, curr_ins_patch = curr_ins_img[0], curr_ins_img[1].to( + prompt_fea + ) + if curr_ins_id.shape[0] > 0: + + with torch.cuda.amp.autocast(): + image_features = dino_encoder(curr_ins_patch) + image_features = controlnet.dino_adapter(image_features).to( + prompt_fea + ) + + for id_ins, curr_ins in enumerate(curr_ins_id.tolist()): + all_vector = image_features[id_ins] + global_vector = all_vector[0:1] + + down_s = args.patch_size // 14 + + patch_vector = ( + all_vector[1 : 1 + down_s * down_s] + .view(1, down_s, down_s, -1) + .permute(0, 3, 1, 2) + ) + curr_mask = batch["ctrl_img"][curr_b] == curr_ins + + if curr_mask.max() < 1: + continue + + # import ipdb; ipdb.set_trace() + curr_box = masks_to_boxes(curr_mask[None])[0].int().tolist() + height, width = ( + curr_box[3] - curr_box[1], + curr_box[2] - curr_box[0], + ) + + x = torch.linspace(-1, 1, height) + y = torch.linspace(-1, 1, width) + + xx, yy = torch.meshgrid(x, y) + grid = torch.stack((xx, yy), dim=2).to(patch_vector)[None] + + warp_fea = F.grid_sample( + patch_vector, + grid, + mode="bilinear", + padding_mode="reflection", + align_corners=True, + )[0].permute(1, 2, 0) + + # import ipdb; ipdb.set_trace() + + small_mask = curr_mask[ + curr_box[1] : curr_box[3], curr_box[0] : curr_box[2] + ] + + curr_pix_num = small_mask.sum().item() + all_ins = np.arange(0, curr_pix_num) + sel_ins = np.random.choice( + all_ins, size=(curr_pix_num // 10,), replace=True + ) + # import ipdb; ipdb.set_trace() + warp_fea[small_mask][sel_ins] = global_vector + + prompt_fea[curr_b][ + curr_box[1] : curr_box[3], curr_box[0] : curr_box[2] + ][small_mask] = warp_fea[small_mask] + + mtv_condition = prompt_fea.permute(0, 3, 1, 2) + + # validation_prompts = ['show a photorealistic street view image'] + images_tensor = [] + + for _ in range(args.num_validation_images): + with torch.autocast("cuda"): + # import ipdb; ipdb.set_trace() + image = pipe( + prompt=validation_prompts, + image=mtv_condition, + num_inference_steps=30, + # num_inference_steps=20, + generator=generator, + height=curr_h, + width=curr_w, + controlnet_conditioning_scale=1.0, + guidance_scale=args.cfg_scale, + ).images # [0] + image = torch.cat([torch.tensor(np.array(ii)) for ii in image], 1) + + images_tensor.append(image) + + # import ipdb; ipdb.set_trace() + + # raw_img = batch['pixel_values'].permute(1,2,0) * 255 + # gen_img = torch.cat(images_tensor, 1) + # gen_img = torch.cat([raw_img, gen_img], 1) + + raw_img = ( + batch["pixel_values"] + .permute(2, 0, 3, 1) + .reshape(images_tensor[0].shape) + * 255 + ) + gen_img = torch.cat(images_tensor, 1) + + raw_w, raw_h = batch['patches'][0][4] + + gen_img = tf.resize(gen_img.permute(2,0,1), (raw_h, raw_w)).permute(1,2,0) + + # gen_img = torch.cat([raw_img, gen_img], 1) + + out_path = os.path.join( + save_path, + batch['patches'][0][3], + # f"val_{img_idx:06d}.jpg", + ) + + # import ipdb; ipdb.set_trace() + + cv2.imwrite( + out_path, cv2.cvtColor(gen_img.cpu().numpy(), cv2.COLOR_RGB2BGR) + ) + + +if __name__ == "__main__": + os.system("export NCCL_SOCKET_IFNAME=eth1") + + from torch.multiprocessing import Manager + + # world_size = 8 + world_size = 4 + + all_len = len(val_dataset) + # all_len = 500 + + all_list = np.arange(0, all_len, 1) + # all_list = np.arange(0, all_len, 8) + + all_len_sel = all_list.shape[0] + val_len = all_len_sel // world_size * world_size + # import ipdb; ipdb.set_trace() + + all_list_filter = all_list[:val_len] + + all_list_filter = np.split(all_list_filter, world_size) + + input_datas = {} + for i in range(world_size): + input_datas[i] = list(all_list_filter[i]) + print(len(input_datas[i])) + + input_datas[0] += list(all_list[val_len:]) + # input_datas[0] += list(np.all_list(val_len, all_len_sel)) + + global dino_encoder + + dino_encoder = FrozenDinoV2Encoder() + + controlnet = ControlNetModel.from_pretrained( + ckp_path, subfolder="controlnet", torch_dtype=torch.float16 + ) + # unet = UNet2DConditionModel.from_pretrained( + # ckp_path, subfolder="unet", torch_dtype=torch.float16 + # ) + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="unet", + revision=args.revision, + variant=args.variant, + ) + vae_path = "/hpc2hdd/home/lli181/long_video/animate-anything/download/AI-ModelScope/sdxl-vae-fp16-fix" + + vae = AutoencoderKL.from_pretrained( + vae_path, + subfolder=None, + revision=args.revision, + variant=args.variant, + ) + + pipe = StableDiffusionXLControlNetPipeline.from_pretrained( + args.pretrained_model_name_or_path, vae=vae, unet=unet, controlnet=controlnet, torch_dtype=torch.float16 + ) + + # speed up diffusion process with faster scheduler and memory optimization + pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) + + pipe.enable_xformers_memory_efficient_attention() + pipe.set_progress_bar_config(disable=True) + from diffusers.models.attention_processor import AttnProcessor2_0 + + pipe.unet.set_attn_processor(AttnProcessor2_0()) + + run_inference(args.curr_gpu, 1, None, input_datas, pipe, args) + + # with Manager() as manager: + # pred_results = manager.list() + # mp.spawn(run_inference, nprocs=world_size, args=(world_size,pred_results,input_datas,pipe,args,), join=True) diff --git a/infer_coco.sh b/infer_coco.sh new file mode 100644 index 0000000..c88f7b7 --- /dev/null +++ b/infer_coco.sh @@ -0,0 +1,40 @@ +export WANDB_DISABLED=True + +export MODEL_DIR="/hpc2hdd/home/lli181/long_video/animate-anything/download/AI-ModelScope/stable-diffusion-xl-base-1.0" + +source /hpc2ssd/softwares/anaconda3/bin/activate pyt2 + +export HF_HUB_OFFLINE=True + + +# export EXP_NAME="out_coconut_dino_text_gridsam" +# export EXP_NAME="out_coconut_sdxl" +export EXP_NAME="out_coconut_sdxl_relu" +# export EXP_NAME="out_coconut_dino_text" +# export EXP_NAME="out_coconut_vith_img" +# export EXP_NAME="out_coconut_vith_anything" +export OUTPUT_DIR="./exp/$EXP_NAME" +# export SAVE_IMG_DIR="vis_dir/$EXP_NAME/dreambooth" +export SAVE_IMG_DIR="vis_dir/$EXP_NAME/val2017" +# export SAVE_IMG_DIR="vis_dir/$EXP_NAME/samples_coco_one" +export TRAIN_OR_VAL="val" + + + + + +# CUDA_VISIBLE_DEVICES=0 python infer_val_dino_db.py --save_img_path=$SAVE_IMG_DIR --model_path_infer=$OUTPUT_DIR --curr_gpu=0 --ctrl_channel=1024 --width=1024 --height=1024 --gen_train_or_val=$TRAIN_OR_VAL + +# CUDA_VISIBLE_DEVICES=0 python infer_xl_dino_coco.py --save_img_path=$SAVE_IMG_DIR --model_path_infer=$OUTPUT_DIR --curr_gpu=0 --ctrl_channel=1024 --width=1024 --height=1024 --gen_train_or_val=$TRAIN_OR_VAL --mulscale_batch_size=1 --patch_size=364 --num_validation_images=1 --cfg_scale=7 --pretrained_model_name_or_path=$MODEL_DIR + + +CUDA_VISIBLE_DEVICES=0 python infer_xl_dino_coco.py --save_img_path=$SAVE_IMG_DIR --model_path_infer=$OUTPUT_DIR --curr_gpu=0 --ctrl_channel=1024 --width=1024 --height=1024 --gen_train_or_val=$TRAIN_OR_VAL --mulscale_batch_size=1 --patch_size=364 --num_validation_images=1 --cfg_scale=7 --pretrained_model_name_or_path=$MODEL_DIR & CUDA_VISIBLE_DEVICES=1 python infer_xl_dino_coco.py --save_img_path=$SAVE_IMG_DIR --model_path_infer=$OUTPUT_DIR --curr_gpu=1 --ctrl_channel=1024 --width=1024 --height=1024 --gen_train_or_val=$TRAIN_OR_VAL --mulscale_batch_size=1 --patch_size=364 --num_validation_images=1 --cfg_scale=7 --pretrained_model_name_or_path=$MODEL_DIR & CUDA_VISIBLE_DEVICES=2 python infer_xl_dino_coco.py --save_img_path=$SAVE_IMG_DIR --model_path_infer=$OUTPUT_DIR --curr_gpu=2 --ctrl_channel=1024 --width=1024 --height=1024 --gen_train_or_val=$TRAIN_OR_VAL --mulscale_batch_size=1 --patch_size=364 --num_validation_images=1 --cfg_scale=7 --pretrained_model_name_or_path=$MODEL_DIR & CUDA_VISIBLE_DEVICES=3 python infer_xl_dino_coco.py --save_img_path=$SAVE_IMG_DIR --model_path_infer=$OUTPUT_DIR --curr_gpu=3 --ctrl_channel=1024 --width=1024 --height=1024 --gen_train_or_val=$TRAIN_OR_VAL --mulscale_batch_size=1 --patch_size=364 --num_validation_images=1 --cfg_scale=7 --pretrained_model_name_or_path=$MODEL_DIR + +# & CUDA_VISIBLE_DEVICES=4 python infer_xl_dino_coco.py --save_img_path=$SAVE_IMG_DIR --model_path_infer=$OUTPUT_DIR --curr_gpu=4 --ctrl_channel=1024 --width=1024 --height=1024 --gen_train_or_val=$TRAIN_OR_VAL --mulscale_batch_size=1 --patch_size=364 --num_validation_images=1 --cfg_scale=7 --pretrained_model_name_or_path=$MODEL_DIR & CUDA_VISIBLE_DEVICES=5 python infer_xl_dino_coco.py --save_img_path=$SAVE_IMG_DIR --model_path_infer=$OUTPUT_DIR --curr_gpu=5 --ctrl_channel=1024 --width=1024 --height=1024 --gen_train_or_val=$TRAIN_OR_VAL --mulscale_batch_size=1 --patch_size=364 --num_validation_images=1 --cfg_scale=7 --pretrained_model_name_or_path=$MODEL_DIR & CUDA_VISIBLE_DEVICES=6 python infer_xl_dino_coco.py --save_img_path=$SAVE_IMG_DIR --model_path_infer=$OUTPUT_DIR --curr_gpu=6 --ctrl_channel=1024 --width=1024 --height=1024 --gen_train_or_val=$TRAIN_OR_VAL --mulscale_batch_size=1 --patch_size=364 --num_validation_images=1 --cfg_scale=7 --pretrained_model_name_or_path=$MODEL_DIR & CUDA_VISIBLE_DEVICES=7 python infer_xl_dino_coco.py --save_img_path=$SAVE_IMG_DIR --model_path_infer=$OUTPUT_DIR --curr_gpu=7 --ctrl_channel=1024 --width=1024 --height=1024 --gen_train_or_val=$TRAIN_OR_VAL --mulscale_batch_size=1 --patch_size=364 --num_validation_images=1 --cfg_scale=7 --pretrained_model_name_or_path=$MODEL_DIR + + + +# CUDA_VISIBLE_DEVICES=0 python infer_val_dino_db.py --save_img_path=$SAVE_IMG_DIR --model_path_infer=$OUTPUT_DIR --curr_gpu=0 --ctrl_channel=1024 --width=1024 --height=1024 --gen_train_or_val=$TRAIN_OR_VAL & CUDA_VISIBLE_DEVICES=1 python infer_val_dino_db.py --save_img_path=$SAVE_IMG_DIR --model_path_infer=$OUTPUT_DIR --curr_gpu=1 --ctrl_channel=1024 --gen_train_or_val=$TRAIN_OR_VAL & CUDA_VISIBLE_DEVICES=2 python infer_val_dino_db.py --save_img_path=$SAVE_IMG_DIR --model_path_infer=$OUTPUT_DIR --curr_gpu=2 --ctrl_channel=1024 --gen_train_or_val=$TRAIN_OR_VAL & CUDA_VISIBLE_DEVICES=3 python infer_val_dino_db.py --save_img_path=$SAVE_IMG_DIR --model_path_infer=$OUTPUT_DIR --curr_gpu=3 --ctrl_channel=1024 --gen_train_or_val=$TRAIN_OR_VAL + + + diff --git a/infer_instance.py b/infer_instance.py new file mode 100644 index 0000000..eb9bf1e --- /dev/null +++ b/infer_instance.py @@ -0,0 +1,321 @@ +import os +import cv2 +import tqdm +import numpy as np +import torch +import torch.nn.functional as F +from torchvision import utils +from diffusers.utils import load_image + +import torch.distributed as dist +import torch.multiprocessing as mp +from torchvision.ops import masks_to_boxes + +from diffusers import UniPCMultistepScheduler, UNet2DConditionModel, AutoencoderKL +from diffusers.models.attention_processor import AttnProcessor2_0 + +from models.controlnet1x1 import ControlNetModel1x1 as ControlNetModel + +from models.pipeline_controlnet_sd_xl import ( + StableDiffusionXLControlNetPipeline as StableDiffusionXLControlNetPipeline, +) + +from models.dino_model import FrozenDinoV2Encoder + +from args_file import parse_args +from transformers import AutoTokenizer +from utils.datasets import InstanceDataset as OmniboothDataset + +from transformers import CLIPTextModel, CLIPTextModelWithProjection + + +args = parse_args() + + +if args.model_path_infer is not None: + ckp_path = args.model_path_infer + +if "checkpoint" not in ckp_path: + dirs = os.listdir(ckp_path) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + ckp_path = os.path.join(ckp_path, dirs[-1]) if len(dirs) > 0 else ckp_path + + +# generator = torch.manual_seed(666) +generator = torch.manual_seed(0) + + +tokenizer = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer_2", + revision=args.revision, + use_fast=False, +) + +text_encoder = CLIPTextModelWithProjection.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="text_encoder_2", + revision=args.revision, + variant=args.variant, +) + +val_dataset = OmniboothDataset(args, tokenizer, args.gen_train_or_val) + + +if args.save_img_path is not None: + save_path = args.save_img_path + +os.makedirs(save_path, exist_ok=True) + + +def tokenize_captions(examples, tokenizer, is_train=True): + captions = [] + for caption in examples: + captions.append(caption) + + inputs = tokenizer( + captions, + max_length=tokenizer.model_max_length, + padding="max_length", + truncation=True, + return_tensors="pt", + ) + return inputs.input_ids + + +def run_inference(rank, world_size, pred_results, input_datas, pipe, args): + # uncomment it if use ddp + # if rank is not None: + # # dist.init_process_group("gloo", rank=rank, world_size=world_size) + # dist.init_process_group("nccl", rank=rank, world_size=world_size) + # else: + # rank = 0 + print(rank) + # torch.set_default_device(rank) + + pipe.to("cuda") + dino_encoder.to("cuda") + text_encoder.to("cuda") + weight_dtype = torch.float16 + all_list = input_datas[rank] + + # pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin") + # pipe.set_ip_adapter_scale(0.6) + # ipimage = load_image("/hpc2hdd/home/lli181/long_video/occ_exp/ctrl_instance/saia/vis_dir/ins_exp/instance_editimg/zelda.jpg") + + with torch.no_grad(): + + for img_idx in tqdm.tqdm(all_list): + batch = val_dataset.__getitem__(img_idx) + mtv_condition = batch["ctrl_img"] + validation_prompts = batch["prompts"] + + curr_h, curr_w = batch["pixel_values"].shape[-2:] + + prompt_fea = torch.zeros((*batch["ctrl_img"].shape, args.ctrl_channel)).to( + "cuda", dtype=weight_dtype + ) + + if args.text_or_img == "text" or args.text_or_img == "mix": + + for curr_b, curr_ins_prompt in enumerate(batch["input_ids_ins"]): + curr_ins_prompt = ["anything"] + curr_ins_prompt + input_ids = tokenize_captions(curr_ins_prompt, tokenizer).cuda() + with torch.cuda.amp.autocast(): + text_features = text_encoder(input_ids, return_dict=True)[ + "text_embeds" + # "pooler_output" + ] + text_features = controlnet.text_adapter(text_features).to( + prompt_fea + ) + + for curr_ins_id in range(len(curr_ins_prompt)): + prompt_fea[curr_b][batch["ctrl_img"][curr_b] == curr_ins_id] = ( + text_features[curr_ins_id] + ) + if args.text_or_img == "img" or args.text_or_img == "mix": + + for curr_b, curr_ins_img in enumerate(batch["patches"]): + curr_ins_id, curr_ins_patch = curr_ins_img[0], curr_ins_img[1].to( + prompt_fea + ) + if curr_ins_id.shape[0] > 0: + + with torch.cuda.amp.autocast(): + image_features = dino_encoder(curr_ins_patch) + image_features = controlnet.dino_adapter(image_features).to( + prompt_fea + ) + + for id_ins, curr_ins in enumerate(curr_ins_id.tolist()): + all_vector = image_features[id_ins] + global_vector = all_vector[0:1] + + down_s = args.patch_size // 14 + + patch_vector = ( + all_vector[1 : 1 + down_s * down_s] + .view(1, down_s, down_s, -1) + .permute(0, 3, 1, 2) + ) + curr_mask = batch["ctrl_img"][curr_b] == curr_ins + + if curr_mask.max() < 1: + continue + + curr_box = masks_to_boxes(curr_mask[None])[0].int().tolist() + height, width = ( + curr_box[3] - curr_box[1], + curr_box[2] - curr_box[0], + ) + + x = torch.linspace(-1, 1, height) + y = torch.linspace(-1, 1, width) + + xx, yy = torch.meshgrid(x, y) + grid = torch.stack((xx, yy), dim=2).to(patch_vector)[None] + + warp_fea = F.grid_sample( + patch_vector, + grid, + mode="bilinear", + padding_mode="reflection", + align_corners=True, + )[0].permute(1, 2, 0) + + small_mask = curr_mask[ + curr_box[1] : curr_box[3], curr_box[0] : curr_box[2] + ] + + curr_pix_num = small_mask.sum().item() + all_ins = np.arange(0, curr_pix_num) + sel_ins = np.random.choice( + # all_ins, size=(curr_pix_num // 1,), replace=True + all_ins, + size=(curr_pix_num // 10,), + replace=True, + ) + warp_fea[small_mask][sel_ins] = global_vector + + prompt_fea[curr_b][ + curr_box[1] : curr_box[3], curr_box[0] : curr_box[2] + ][small_mask] = warp_fea[small_mask] + + mtv_condition = prompt_fea.permute(0, 3, 1, 2) + + images_tensor = [] + + for _ in range(args.num_validation_images): + with torch.autocast("cuda"): + image = pipe( + prompt=validation_prompts, + image=mtv_condition, + # ip_adapter_image=ipimage, + num_inference_steps=30, + # num_inference_steps=20, + generator=generator, + height=curr_h, + width=curr_w, + controlnet_conditioning_scale=1.0, + guidance_scale=args.cfg_scale, + ).images + image = torch.cat([torch.tensor(np.array(ii)) for ii in image], 1) + + images_tensor.append(image) + + raw_img = ( + batch["pixel_values"] + .permute(2, 0, 3, 1) + .reshape(images_tensor[0].shape) + * 255 + ) + gen_img = torch.cat(images_tensor, 1) + + out_path = os.path.join( + save_path, + *batch["patches"][0][3].split("/")[-1:], + # f"val_{img_idx:06d}.jpg", + ) + + out_path = out_path[:-3] + "png" + + cv2.imwrite( + out_path, cv2.cvtColor(gen_img.cpu().numpy(), cv2.COLOR_RGB2BGR) + ) + + +if __name__ == "__main__": + os.system("export NCCL_SOCKET_IFNAME=eth1") + + from torch.multiprocessing import Manager + + # world_size = 4 + world_size = 1 + + all_len = len(val_dataset) + + all_list = np.arange(0, all_len, 1) + + all_len_sel = all_list.shape[0] + val_len = all_len_sel // world_size * world_size + + all_list_filter = all_list[:val_len] + + all_list_filter = np.split(all_list_filter, world_size) + + input_datas = {} + for i in range(world_size): + input_datas[i] = list(all_list_filter[i]) + print(len(input_datas[i])) + + input_datas[0] += list(all_list[val_len:]) + + global dino_encoder + + dino_encoder = FrozenDinoV2Encoder() + + controlnet = ControlNetModel.from_pretrained( + ckp_path, + subfolder="controlnet", + torch_dtype=torch.float16, + text_adapter_channel=1280, + ) + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="unet", + revision=args.revision, + variant=args.variant, + ) + + vae_path = args.pretrained_vae_model_name_or_path + + vae = AutoencoderKL.from_pretrained( + vae_path, + subfolder=None, + revision=args.revision, + variant=args.variant, + ) + + pipe = StableDiffusionXLControlNetPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=vae, + unet=unet, + controlnet=controlnet, + torch_dtype=torch.float16, + ) + + # speed up diffusion process with faster scheduler and memory optimization + pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) + + pipe.enable_xformers_memory_efficient_attention() + pipe.set_progress_bar_config(disable=True) + + pipe.unet.set_attn_processor(AttnProcessor2_0()) + + run_inference(args.curr_gpu, 1, None, input_datas, pipe, args) + + # with Manager() as manager: + # pred_results = manager.list() + # mp.spawn(run_inference, nprocs=world_size, args=(world_size,pred_results,input_datas,pipe,args,), join=True) diff --git a/infer_instance.sh b/infer_instance.sh new file mode 100644 index 0000000..e78061f --- /dev/null +++ b/infer_instance.sh @@ -0,0 +1,32 @@ +export WANDB_DISABLED=True +export HF_HUB_OFFLINE=True + +export MODEL_DIR="./ckp/stable-diffusion-xl-base-1.0" +export VAE_DIR="./ckp/sdxl-vae-fp16-fix" + + + + +export EXP_NAME="omnibooth_train" +export OUTPUT_DIR="./ckp/$EXP_NAME" +export SAVE_IMG_DIR="./vis_dir" +export TRAIN_OR_VAL="val" + + + + + +CUDA_VISIBLE_DEVICES=0 python infer_instance.py --save_img_path=$SAVE_IMG_DIR --model_path_infer=$OUTPUT_DIR --curr_gpu=0 --ctrl_channel=1024 --width=1024 --height=1024 --patch_size=364 --gen_train_or_val=$TRAIN_OR_VAL --pretrained_model_name_or_path=$MODEL_DIR --pretrained_vae_model_name_or_path=$VAE_DIR --text_or_img=text --cfg_scale=7.5 --num_validation_images=3 + + + + + + + + + + + + + diff --git a/models/__pycache__/blocks.cpython-310.pyc b/models/__pycache__/blocks.cpython-310.pyc new file mode 100644 index 0000000..0803252 Binary files /dev/null and b/models/__pycache__/blocks.cpython-310.pyc differ diff --git a/models/__pycache__/controlnet1x1.cpython-310.pyc b/models/__pycache__/controlnet1x1.cpython-310.pyc new file mode 100644 index 0000000..f32c374 Binary files /dev/null and b/models/__pycache__/controlnet1x1.cpython-310.pyc differ diff --git a/models/__pycache__/controlnet1x1.cpython-311.pyc b/models/__pycache__/controlnet1x1.cpython-311.pyc new file mode 100644 index 0000000..c8be029 Binary files /dev/null and b/models/__pycache__/controlnet1x1.cpython-311.pyc differ diff --git a/models/__pycache__/controlnet_down.cpython-310.pyc b/models/__pycache__/controlnet_down.cpython-310.pyc new file mode 100644 index 0000000..e0f3bb1 Binary files /dev/null and b/models/__pycache__/controlnet_down.cpython-310.pyc differ diff --git a/models/__pycache__/controlnet_down.cpython-311.pyc b/models/__pycache__/controlnet_down.cpython-311.pyc new file mode 100644 index 0000000..5c88701 Binary files /dev/null and b/models/__pycache__/controlnet_down.cpython-311.pyc differ diff --git a/models/__pycache__/dino_model.cpython-310.pyc b/models/__pycache__/dino_model.cpython-310.pyc new file mode 100644 index 0000000..38cfeaa Binary files /dev/null and b/models/__pycache__/dino_model.cpython-310.pyc differ diff --git a/models/__pycache__/pipeline_controlnet_1x1_4dunet.cpython-310.pyc b/models/__pycache__/pipeline_controlnet_1x1_4dunet.cpython-310.pyc new file mode 100644 index 0000000..cefde54 Binary files /dev/null and b/models/__pycache__/pipeline_controlnet_1x1_4dunet.cpython-310.pyc differ diff --git a/models/__pycache__/pipeline_controlnet_1x1_4dunet.cpython-311.pyc b/models/__pycache__/pipeline_controlnet_1x1_4dunet.cpython-311.pyc new file mode 100644 index 0000000..ecb85c0 Binary files /dev/null and b/models/__pycache__/pipeline_controlnet_1x1_4dunet.cpython-311.pyc differ diff --git a/models/__pycache__/pipeline_controlnet_sd_xl.cpython-310.pyc b/models/__pycache__/pipeline_controlnet_sd_xl.cpython-310.pyc new file mode 100644 index 0000000..536a578 Binary files /dev/null and b/models/__pycache__/pipeline_controlnet_sd_xl.cpython-310.pyc differ diff --git a/models/__pycache__/unet_2d_condition_multiview.cpython-310.pyc b/models/__pycache__/unet_2d_condition_multiview.cpython-310.pyc new file mode 100644 index 0000000..32121be Binary files /dev/null and b/models/__pycache__/unet_2d_condition_multiview.cpython-310.pyc differ diff --git a/models/controlnet1x1.py b/models/controlnet1x1.py new file mode 100644 index 0000000..fcf5ed2 --- /dev/null +++ b/models/controlnet1x1.py @@ -0,0 +1,956 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn import functional as F + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import FromOriginalControlNetMixin +from diffusers.utils import BaseOutput, logging +from diffusers.models.attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + AttentionProcessor, + AttnAddedKVProcessor, + AttnProcessor, +) +from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.unets.unet_2d_blocks import ( + CrossAttnDownBlock2D, + DownBlock2D, + UNetMidBlock2D, + UNetMidBlock2DCrossAttn, + get_down_block, +) +from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class ControlNetOutput(BaseOutput): + """ + The output of [`ControlNetModel`]. + + Args: + down_block_res_samples (`tuple[torch.Tensor]`): + A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should + be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be + used to condition the original UNet's downsampling activations. + mid_down_block_re_sample (`torch.Tensor`): + The activation of the midde block (the lowest sample resolution). Each tensor should be of shape + `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`. + Output can be used to condition the original UNet's middle block activation. + """ + + down_block_res_samples: Tuple[torch.Tensor] + mid_block_res_sample: torch.Tensor + + +class ControlNetConditioningEmbedding(nn.Module): + """ + Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN + [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized + training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the + convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides + (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full + model) to encode image-space conditions ... into feature maps ..." + """ + + def __init__( + self, + conditioning_embedding_channels: int, + conditioning_channels: int = 3, + block_out_channels: Tuple[int, ...] = (16, 32, 96, 256), + ): + super().__init__() + + self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) + + self.blocks = nn.ModuleList([]) + + for i in range(len(block_out_channels) - 1): + channel_in = block_out_channels[i] + channel_out = block_out_channels[i + 1] + self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)) + self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) + + self.conv_out = zero_module( + nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1) + ) + + def forward(self, conditioning): + embedding = self.conv_in(conditioning) + embedding = F.silu(embedding) + + for block in self.blocks: + embedding = block(embedding) + embedding = F.silu(embedding) + + embedding = self.conv_out(embedding) + + return embedding + + +class ControlNetConditioningEmbeddingNodown(nn.Module): + """ + Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN + [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized + training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the + convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides + (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full + model) to encode image-space conditions ... into feature maps ..." + """ + + def __init__( + self, + conditioning_embedding_channels: int, + conditioning_channels: int = 3, + block_out_channels: Tuple[int, ...] = (16, 32, 96, 256), + ): + super().__init__() + + # self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=1, padding=0) + + # self.blocks = nn.ModuleList([]) + + # for i in range(len(block_out_channels) - 1): + # channel_in = block_out_channels[i] + # channel_out = block_out_channels[i + 1] + # self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=1, padding=0)) + # self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=1, padding=0, stride=1)) + + # self.conv_out = zero_module( + # nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=1, padding=0) + # ) + + self.position_encoder = nn.Sequential( + nn.Conv2d(conditioning_channels, conditioning_embedding_channels, kernel_size=1, stride=1, padding=0), + nn.ReLU(), + # zero_module(nn.Conv2d(conditioning_embedding_channels, conditioning_embedding_channels, kernel_size=1, stride=1, padding=0)), + nn.Conv2d(conditioning_embedding_channels, conditioning_embedding_channels, kernel_size=1, stride=1, padding=0), + ) + + def forward(self, conditioning): + # embedding = self.conv_in(conditioning) + # embedding = F.silu(embedding) + + # for block in self.blocks: + # embedding = block(embedding) + # embedding = F.silu(embedding) + + # embedding = self.conv_out(embedding) + embedding = self.position_encoder(conditioning) + + return embedding + + +class ControlNetModel1x1(ModelMixin, ConfigMixin, FromOriginalControlNetMixin): + """ + A ControlNet model. + + Args: + in_channels (`int`, defaults to 4): + The number of channels in the input sample. + flip_sin_to_cos (`bool`, defaults to `True`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, defaults to 0): + The frequency shift to apply to the time embedding. + down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`): + block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, defaults to 2): + The number of layers per block. + downsample_padding (`int`, defaults to 1): + The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, defaults to 1): + The scale factor to use for the mid block. + act_fn (`str`, defaults to "silu"): + The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups to use for the normalization. If None, normalization and activation layers is skipped + in post-processing. + norm_eps (`float`, defaults to 1e-5): + The epsilon to use for the normalization. + cross_attention_dim (`int`, defaults to 1280): + The dimension of the cross attention features. + transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + encoder_hid_dim (`int`, *optional*, defaults to None): + If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` + dimension to `cross_attention_dim`. + encoder_hid_dim_type (`str`, *optional*, defaults to `None`): + If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text + embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. + attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8): + The dimension of the attention heads. + use_linear_projection (`bool`, defaults to `False`): + class_embed_type (`str`, *optional*, defaults to `None`): + The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None, + `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. + addition_embed_type (`str`, *optional*, defaults to `None`): + Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or + "text". "text" will use the `TextTimeEmbedding` layer. + num_class_embeds (`int`, *optional*, defaults to 0): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. + upcast_attention (`bool`, defaults to `False`): + resnet_time_scale_shift (`str`, defaults to `"default"`): + Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`. + projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`): + The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when + `class_embed_type="projection"`. + controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`): + The channel order of conditional image. Will convert to `rgb` if it's `bgr`. + conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`): + The tuple of output channel for each block in the `conditioning_embedding` layer. + global_pool_conditions (`bool`, defaults to `False`): + TODO(Patrick) - unused parameter. + addition_embed_type_num_heads (`int`, defaults to 64): + The number of heads to use for the `TextTimeEmbedding` layer. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 4, + conditioning_channels: int = 3, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str, ...] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1280, + transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1, + encoder_hid_dim: Optional[int] = None, + encoder_hid_dim_type: Optional[str] = None, + attention_head_dim: Union[int, Tuple[int, ...]] = 8, + num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + addition_embed_type: Optional[str] = None, + addition_time_embed_dim: Optional[int] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + projection_class_embeddings_input_dim: Optional[int] = None, + controlnet_conditioning_channel_order: str = "rgb", + conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256), + global_pool_conditions: bool = False, + addition_embed_type_num_heads: int = 64, + text_adapter_channel: int = 1024, + ): + super().__init__() + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = num_attention_heads or attention_head_dim + + # Check inputs + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + + # input + conv_in_kernel = 3 + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + # time + time_embed_dim = block_out_channels[0] * 4 + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + ) + + if encoder_hid_dim_type is None and encoder_hid_dim is not None: + encoder_hid_dim_type = "text_proj" + self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) + logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") + + if encoder_hid_dim is None and encoder_hid_dim_type is not None: + raise ValueError( + f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." + ) + + if encoder_hid_dim_type == "text_proj": + self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) + elif encoder_hid_dim_type == "text_image_proj": + # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)` + self.encoder_hid_proj = TextImageProjection( + text_embed_dim=encoder_hid_dim, + image_embed_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim, + ) + + elif encoder_hid_dim_type is not None: + raise ValueError( + f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." + ) + else: + self.encoder_hid_proj = None + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. + self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + else: + self.class_embedding = None + + if addition_embed_type == "text": + if encoder_hid_dim is not None: + text_time_embedding_from_dim = encoder_hid_dim + else: + text_time_embedding_from_dim = cross_attention_dim + + self.add_embedding = TextTimeEmbedding( + text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads + ) + elif addition_embed_type == "text_image": + # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)` + self.add_embedding = TextImageTimeEmbedding( + text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim + ) + elif addition_embed_type == "text_time": + self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + + elif addition_embed_type is not None: + raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") + + # control net conditioning embedding + # self.controlnet_cond_embedding = ControlNetConditioningEmbedding( + self.controlnet_cond_embedding = ControlNetConditioningEmbeddingNodown( + conditioning_embedding_channels=block_out_channels[0], + block_out_channels=conditioning_embedding_out_channels, + conditioning_channels=conditioning_channels, + ) + + self.down_blocks = nn.ModuleList([]) + self.controlnet_down_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + # down + output_channel = block_out_channels[0] + + controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + transformer_layers_per_block=transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads[i], + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + downsample_padding=downsample_padding, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + self.down_blocks.append(down_block) + + for _ in range(layers_per_block): + controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + if not is_final_block: + controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + # mid + mid_block_channel = block_out_channels[-1] + + controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_mid_block = controlnet_block + + if mid_block_type == "UNetMidBlock2DCrossAttn": + self.mid_block = UNetMidBlock2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block[-1], + in_channels=mid_block_channel, + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + elif mid_block_type == "UNetMidBlock2D": + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + num_layers=0, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_groups=norm_num_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + add_attention=False, + ) + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + self.text_adapter = nn.Sequential( + nn.Linear(text_adapter_channel, conditioning_channels), + # nn.Linear(conditioning_channels, conditioning_channels), + nn.ReLU(), + # nn.Linear(conditioning_channels, conditioning_channels), + ) + self.dino_adapter = nn.Sequential( + nn.Linear(conditioning_channels, conditioning_channels), + nn.ReLU(), + # nn.Linear(conditioning_channels, conditioning_channels), + ) + + def text_forward(self, features): + return self.text_adapter(features) + + def dino_forward(self, features): + return self.dino_adapter(features) + + + + @classmethod + def from_unet( + cls, + unet: UNet2DConditionModel, + controlnet_conditioning_channel_order: str = "rgb", + conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256), + load_weights_from_unet: bool = True, + conditioning_channels: int = 3, + text_adapter_channel: int = 1024, + ): + r""" + Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`]. + + Parameters: + unet (`UNet2DConditionModel`): + The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied + where applicable. + """ + transformer_layers_per_block = ( + unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1 + ) + encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None + encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None + addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None + addition_time_embed_dim = ( + unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None + ) + + controlnet = cls( + encoder_hid_dim=encoder_hid_dim, + encoder_hid_dim_type=encoder_hid_dim_type, + addition_embed_type=addition_embed_type, + addition_time_embed_dim=addition_time_embed_dim, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=unet.config.in_channels, + flip_sin_to_cos=unet.config.flip_sin_to_cos, + freq_shift=unet.config.freq_shift, + down_block_types=unet.config.down_block_types, + only_cross_attention=unet.config.only_cross_attention, + block_out_channels=unet.config.block_out_channels, + layers_per_block=unet.config.layers_per_block, + downsample_padding=unet.config.downsample_padding, + mid_block_scale_factor=unet.config.mid_block_scale_factor, + act_fn=unet.config.act_fn, + norm_num_groups=unet.config.norm_num_groups, + norm_eps=unet.config.norm_eps, + cross_attention_dim=unet.config.cross_attention_dim, + attention_head_dim=unet.config.attention_head_dim, + num_attention_heads=unet.config.num_attention_heads, + use_linear_projection=unet.config.use_linear_projection, + class_embed_type=unet.config.class_embed_type, + num_class_embeds=unet.config.num_class_embeds, + upcast_attention=unet.config.upcast_attention, + resnet_time_scale_shift=unet.config.resnet_time_scale_shift, + projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim, + mid_block_type=unet.config.mid_block_type, + controlnet_conditioning_channel_order=controlnet_conditioning_channel_order, + conditioning_embedding_out_channels=conditioning_embedding_out_channels, + conditioning_channels=conditioning_channels, + text_adapter_channel=text_adapter_channel, + ) + + if load_weights_from_unet: + controlnet.conv_in.load_state_dict(unet.conv_in.state_dict()) + controlnet.time_proj.load_state_dict(unet.time_proj.state_dict()) + controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict()) + + if controlnet.class_embedding: + controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict()) + + controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict()) + controlnet.mid_block.load_state_dict(unet.mid_block.state_dict()) + + return controlnet + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnAddedKVProcessor() + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice + def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None: + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module splits the input tensor in slices to compute attention in + several steps. This is useful for saving some memory in exchange for a small decrease in speed. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If + `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_sliceable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_sliceable_dims(module) + + num_sliceable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_sliceable_layers * [1] + + slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value: bool = False) -> None: + if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + controlnet_cond: torch.FloatTensor, + conditioning_scale: float = 1.0, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guess_mode: bool = False, + return_dict: bool = True, + ) -> Union[ControlNetOutput, Tuple[Tuple[torch.FloatTensor, ...], torch.FloatTensor]]: + """ + The [`ControlNetModel`] forward method. + + Args: + sample (`torch.FloatTensor`): + The noisy input tensor. + timestep (`Union[torch.Tensor, float, int]`): + The number of timesteps to denoise an input. + encoder_hidden_states (`torch.Tensor`): + The encoder hidden states. + controlnet_cond (`torch.FloatTensor`): + The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. + conditioning_scale (`float`, defaults to `1.0`): + The scale factor for ControlNet outputs. + class_labels (`torch.Tensor`, *optional*, defaults to `None`): + Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. + timestep_cond (`torch.Tensor`, *optional*, defaults to `None`): + Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the + timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep + embeddings. + attention_mask (`torch.Tensor`, *optional*, defaults to `None`): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + added_cond_kwargs (`dict`): + Additional conditions for the Stable Diffusion XL UNet. + cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`): + A kwargs dictionary that if specified is passed along to the `AttnProcessor`. + guess_mode (`bool`, defaults to `False`): + In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if + you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended. + return_dict (`bool`, defaults to `True`): + Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple. + + Returns: + [`~models.controlnet.ControlNetOutput`] **or** `tuple`: + If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is + returned where the first element is the sample tensor. + """ + # check channel order + channel_order = self.config.controlnet_conditioning_channel_order + + if channel_order == "rgb": + # in rgb order by default + ... + elif channel_order == "bgr": + controlnet_cond = torch.flip(controlnet_cond, dims=[1]) + else: + raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}") + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + aug_emb = None + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + emb = emb + class_emb + + if self.config.addition_embed_type is not None: + if self.config.addition_embed_type == "text": + aug_emb = self.add_embedding(encoder_hidden_states) + + elif self.config.addition_embed_type == "text_time": + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + + emb = emb + aug_emb if aug_emb is not None else emb + # print(666, controlnet_cond.shape, sample.shape) + # 2. pre-process + sample = self.conv_in(sample) + + # controlnet_cond = torch.zeros([1, 256, 56, 100]).to(sample) + + controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) + # print(controlnet_cond.shape, sample.shape) + # import ipdb; ipdb.set_trace() + # # ipdb> controlnet_cond.shape + # # torch.Size([6, 320, 56, 100]) + # # ipdb> sample.shape + # # torch.Size([12, 320, 56, 100]) + # controlnet_cond = torch.cat([controlnet_cond] * 2) + + sample = sample + controlnet_cond + # 666 torch.Size([1, 256, 448, 800]) torch.Size([1, 4, 56, 100]) + # torch.Size([1, 320, 56, 100]) torch.Size([1, 320, 56, 100]) + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # 4. mid + if self.mid_block is not None: + if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample = self.mid_block(sample, emb) + + # 5. Control net blocks + + controlnet_down_block_res_samples = () + + for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): + down_block_res_sample = controlnet_block(down_block_res_sample) + controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,) + + down_block_res_samples = controlnet_down_block_res_samples + + mid_block_res_sample = self.controlnet_mid_block(sample) + + # 6. scaling + if guess_mode and not self.config.global_pool_conditions: + scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0 + scales = scales * conditioning_scale + down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)] + mid_block_res_sample = mid_block_res_sample * scales[-1] # last one + else: + down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples] + mid_block_res_sample = mid_block_res_sample * conditioning_scale + + if self.config.global_pool_conditions: + down_block_res_samples = [ + torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples + ] + mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True) + + if not return_dict: + return (down_block_res_samples, mid_block_res_sample) + + return ControlNetOutput( + down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample + ) + + +def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module diff --git a/models/dino_model.py b/models/dino_model.py new file mode 100644 index 0000000..18f5069 --- /dev/null +++ b/models/dino_model.py @@ -0,0 +1,65 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.checkpoint import checkpoint +import torchvision.transforms as T +import open_clip +from PIL import Image +from open_clip.transform import image_transform +import sys + + + +sys.path.append("./models/dinov2") +from models.dinov2 import hubconf + + +DINOv2_weight_path = '/hpc2hdd/home/lli181/long_video/occ_exp/ctrl_instance/dinov2_vitl14_reg4_pretrain.pth' + + +class FrozenDinoV2Encoder(nn.Module): + """ + Uses the DINOv2 encoder for image + """ + def __init__(self, device="cuda", freeze=True): + super().__init__() + + dinov2 = hubconf.dinov2_vitl14_reg(pretrained=False) + state_dict = torch.load(DINOv2_weight_path, map_location='cpu') + dinov2.load_state_dict(state_dict, strict=True) + self.model = dinov2.to(device, dtype=torch.float16) + + self.device = device + if freeze: + self.freeze() + self.image_mean = torch.tensor([0.485, 0.456, 0.406]).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) + self.image_std = torch.tensor([0.229, 0.224, 0.225]).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) + + + def freeze(self): + self.model.eval() + for param in self.model.parameters(): + param.requires_grad = False + + def forward(self, image): + if isinstance(image,list): + image = torch.cat(image,0) + + image = (image.to(self.device) - self.image_mean.to(self.device)) / self.image_std.to(self.device) + features = self.model.forward_features(image) + reg_tokens = features["x_norm_regtokens"] + patch_tokens = features["x_norm_patchtokens"] + image_features = features["x_norm_clstoken"].unsqueeze(1) + + hint = torch.cat([image_features, patch_tokens],1) # 8,257,1024 + # hint = self.projector(hint) + return hint + + def encode(self, image): + return self(image) + +if __name__ == '__main__': + torch.cuda.set_device(0) + model = FrozenDinoV2Encoder(device='cuda',freeze=True) + image = torch.randn(1,3,224,224) + hint = model(image) \ No newline at end of file diff --git a/models/pipeline_controlnet_sd_xl.py b/models/pipeline_controlnet_sd_xl.py new file mode 100644 index 0000000..3e04e97 --- /dev/null +++ b/models/pipeline_controlnet_sd_xl.py @@ -0,0 +1,1476 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from diffusers.utils.import_utils import is_invisible_watermark_available + +from diffusers.image_processor import PipelineImageInput, VaeImageProcessor +from diffusers.loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from .controlnet1x1 import ControlNetModel1x1 as ControlNetModel + +from diffusers.models.attention_processor import ( + AttnProcessor2_0, + LoRAAttnProcessor2_0, + LoRAXFormersAttnProcessor, + XFormersAttnProcessor, +) +from diffusers.models.lora import adjust_lora_scale_text_encoder +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput + + +if is_invisible_watermark_available(): + from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker + +from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # !pip install opencv-python transformers accelerate + >>> from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, AutoencoderKL + >>> from diffusers.utils import load_image + >>> import numpy as np + >>> import torch + + >>> import cv2 + >>> from PIL import Image + + >>> prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting" + >>> negative_prompt = "low quality, bad quality, sketches" + + >>> # download an image + >>> image = load_image( + ... "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png" + ... ) + + >>> # initialize the models and pipeline + >>> controlnet_conditioning_scale = 0.5 # recommended for good generalization + >>> controlnet = ControlNetModel.from_pretrained( + ... "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16 + ... ) + >>> vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16) + >>> pipe = StableDiffusionXLControlNetPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, vae=vae, torch_dtype=torch.float16 + ... ) + >>> pipe.enable_model_cpu_offload() + + >>> # get canny image + >>> image = np.array(image) + >>> image = cv2.Canny(image, 100, 200) + >>> image = image[:, :, None] + >>> image = np.concatenate([image, image, image], axis=2) + >>> canny_image = Image.fromarray(image) + + >>> # generate image + >>> image = pipe( + ... prompt, controlnet_conditioning_scale=controlnet_conditioning_scale, image=canny_image + ... ).images[0] + ``` +""" + + +class StableDiffusionXLControlNetPipeline( + DiffusionPipeline, + TextualInversionLoaderMixin, + StableDiffusionXLLoraLoaderMixin, + IPAdapterMixin, + FromSingleFileMixin, +): + r""" + Pipeline for text-to-image generation using Stable Diffusion XL with ControlNet guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + text_encoder_2 ([`~transformers.CLIPTextModelWithProjection`]): + Second frozen text-encoder + ([laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + tokenizer_2 ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + controlnet ([`ControlNetModel`] or `List[ControlNetModel]`): + Provides additional conditioning to the `unet` during the denoising process. If you set multiple + ControlNets as a list, the outputs from each ControlNet are added together to create one combined + additional conditioning. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): + Whether the negative prompt embeddings should always be set to 0. Also see the config of + `stabilityai/stable-diffusion-xl-base-1-0`. + add_watermarker (`bool`, *optional*): + Whether to use the [invisible_watermark](https://github.com/ShieldMnt/invisible-watermark/) library to + watermark output images. If not defined, it defaults to `True` if the package is installed; otherwise no + watermarker is used. + """ + + # leave controlnet out on purpose because it iterates with unet + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae" + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "feature_extractor", + "image_encoder", + ] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], + scheduler: KarrasDiffusionSchedulers, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: Optional[bool] = None, + feature_extractor: CLIPImageProcessor = None, + image_encoder: CLIPVisionModelWithProjection = None, + ): + super().__init__() + + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetModel(controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None + + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt: str, + prompt_2: Optional[str] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + negative_prompt_2: Optional[str] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: procecss multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt): + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + image_embeds = [] + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0) + single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0) + + if self.do_classifier_free_guidance: + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) + single_image_embeds = single_image_embeds.to(device) + + image_embeds.append(single_image_embeds) + + return image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + prompt_2, + image, + callback_steps, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + callback_on_step_end_tensor_inputs=None, + ): + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + # `prompt` needs more sophisticated handling when there are multiple + # conditionings. + if isinstance(self.controlnet, MultiControlNetModel): + if isinstance(prompt, list): + logger.warning( + f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" + " prompts. The conditionings will be fixed across the prompts." + ) + + # Check `image` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + # if 1: + self.check_image(image, prompt, prompt_embeds) + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if not isinstance(image, list): + raise TypeError("For multiple controlnets: `image` must be type `list`") + + # When `image` is a nested list: + # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) + elif any(isinstance(i, list) for i in image): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif len(image) != len(self.controlnet.nets): + raise ValueError( + f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." + ) + + for image_ in image: + self.check_image(image_, prompt, prompt_embeds) + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + # if 1: + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if isinstance(controlnet_conditioning_scale, list): + if any(isinstance(i, list) for i in controlnet_conditioning_scale): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + assert False + + if not isinstance(control_guidance_start, (tuple, list)): + control_guidance_start = [control_guidance_start] + + if not isinstance(control_guidance_end, (tuple, list)): + control_guidance_end = [control_guidance_end] + + if len(control_guidance_start) != len(control_guidance_end): + raise ValueError( + f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." + ) + + if isinstance(self.controlnet, MultiControlNetModel): + if len(control_guidance_start) != len(self.controlnet.nets): + raise ValueError( + f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." + ) + + for start, end in zip(control_guidance_start, control_guidance_end): + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + LoRAXFormersAttnProcessor, + LoRAAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu + def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): + r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497. + + The suffixes after the scaling factors represent the stages where they are being applied. + + Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values + that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. + + Args: + s1 (`float`): + Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to + mitigate "oversmoothing effect" in the enhanced denoising process. + s2 (`float`): + Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to + mitigate "oversmoothing effect" in the enhanced denoising process. + b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. + b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. + """ + if not hasattr(self, "unet"): + raise ValueError("The pipeline must have `unet` for using FreeU.") + self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu + def disable_freeu(self): + """Disables the FreeU mechanism if enabled.""" + self.unet.disable_freeu() + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32): + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + timesteps (`torch.Tensor`): + generate embedding vectors at these timesteps + embedding_dim (`int`, *optional*, defaults to 512): + dimension of the embeddings to generate + dtype: + data type of the generated embeddings + + Returns: + `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)` + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + image: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + guess_mode: bool = False, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, + original_size: Tuple[int, int] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Tuple[int, int] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + negative_target_size: Optional[Tuple[int, int]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders. + image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be + accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height + and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in + `init`, images must be passed as a list such that each element of the list can be correctly batched for + input to a single ControlNet. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 5.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. This is sent to `tokenizer_2` + and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, pooled text embeddings are generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt + weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input + argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + guess_mode (`bool`, *optional*, defaults to `False`): + The ControlNet encoder tries to recognize the content of the input image even if you remove all + prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended. + control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeine class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned containing the output images. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + image, + callback_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + + raw_image = image.clone() + image = torch.zeros((batch_size, 1, 448, 800)).to(device) + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + # import ipdb; ipdb.set_trace() + # AttributeError: 'ControlNetModel1x1' object has no attribute 'nets' + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + # 3.1 Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt, + prompt_2, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + + # 3.2 Encode ip_adapter_image + if ip_adapter_image is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, device, batch_size * num_images_per_prompt + ) + + # 4. Prepare image + if isinstance(controlnet, ControlNetModel): + image = self.prepare_image( + image=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + height, width = image.shape[-2:] + elif isinstance(controlnet, MultiControlNetModel): + images = [] + + for image_ in image: + image_ = self.prepare_image( + image=image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + images.append(image_) + + image = images + height, width = image[0].shape[-2:] + else: + assert False + + if self.do_classifier_free_guidance and not guess_mode: + raw_image = torch.cat([raw_image] * 2) + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + self._num_timesteps = len(timesteps) + + # 6. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6.5 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + + # 7.2 Prepare added time ids & embeddings + if isinstance(image, list): + original_size = original_size or image[0].shape[-2:] + else: + original_size = original_size or image.shape[-2:] + target_size = target_size or (height, width) + + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + negative_add_time_ids = add_time_ids + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + is_unet_compiled = is_compiled_module(self.unet) + is_controlnet_compiled = is_compiled_module(self.controlnet) + is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Relevant thread: + # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 + if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: + torch._inductor.cudagraph_mark_step_begin() + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + # controlnet(s) inference + if guess_mode and self.do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. + control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + controlnet_added_cond_kwargs = { + "text_embeds": add_text_embeds.chunk(2)[1], + "time_ids": add_time_ids.chunk(2)[1], + } + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + controlnet_added_cond_kwargs = added_cond_kwargs + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=raw_image.to(image), + # controlnet_cond=image, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + added_cond_kwargs=controlnet_added_cond_kwargs, + return_dict=False, + ) + + if guess_mode and self.do_classifier_free_guidance: + # Infered ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. + down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + + if ip_adapter_image is not None: + added_cond_kwargs["image_embeds"] = image_embeds + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + if not output_type == "latent": + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/train.py b/train.py new file mode 100644 index 0000000..fc621e0 --- /dev/null +++ b/train.py @@ -0,0 +1,1089 @@ +import argparse +import functools +import gc +import cv2 +import logging +import math +import os +import pickle +import random +import shutil +from pathlib import Path + +import accelerate +import numpy as np +import torch +import torch.nn.functional as F +import torchvision.transforms.functional as tf +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from datasets import load_dataset +from huggingface_hub import create_repo, upload_folder +from packaging import version +from PIL import Image +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import AutoTokenizer, PretrainedConfig + +import diffusers +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + UNet2DConditionModel, + UniPCMultistepScheduler, +) +from diffusers.optimization import get_scheduler +from diffusers.utils import check_min_version, is_wandb_available, make_image_grid +from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.torch_utils import is_compiled_module + +from models.controlnet1x1 import ControlNetModel1x1 as ControlNetModel +from models.pipeline_controlnet_sd_xl import ( + StableDiffusionXLControlNetPipeline as StableDiffusionXLControlNetPipeline, +) + +from diffusers.optimization import get_scheduler +from diffusers.utils import check_min_version, is_wandb_available +from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.torch_utils import is_compiled_module + +from utils.datasets import CocoNutImgDataset as OmniboothDataset + +from models.dino_model import FrozenDinoV2Encoder +from torchvision import utils +from torchvision.ops import masks_to_boxes + + + +if is_wandb_available(): + import wandb + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +# check_min_version("0.26.0.dev0") + +logger = get_logger(__name__) + + +def tokenize_captions(examples, tokenizer, is_train=True): + captions = [] + for caption in examples: + captions.append(caption) + + inputs = tokenizer( + captions, + max_length=tokenizer.model_max_length, + padding="max_length", + truncation=True, + return_tensors="pt", + ) + return inputs.input_ids + + + +def tokenize_captions_sdxl(args, prompt_batch, tokenizer, is_train=True): + tokenizer, text_encoders = tokenizer + + original_size = (args.width, args.height) + target_size = (args.width, args.height) + crops_coords_top_left = (0, 0) + + prompt_embeds, pooled_prompt_embeds = encode_prompt( + prompt_batch, + text_encoders, + tokenizer, + args.proportion_empty_prompts, + is_train, + ) + add_text_embeds = pooled_prompt_embeds + + # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_time_ids = torch.tensor([add_time_ids]).to(prompt_embeds) + + add_time_ids = add_time_ids.repeat(len(prompt_batch), 1) + + return { + "prompt_ids": prompt_embeds, + "unet_added_conditions": { + "text_embeds": add_text_embeds, + "time_ids": add_time_ids, + }, + } + +# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt +def encode_prompt( + prompt_batch, + text_encoders, + tokenizers, + proportion_empty_prompts, + is_train=True, +): + prompt_embeds_list = [] + + captions = [] + for caption in prompt_batch: + if random.random() < proportion_empty_prompts: + captions.append("") + elif isinstance(caption, str): + captions.append(caption) + elif isinstance(caption, (list, np.ndarray)): + # take a random caption if there are multiple + captions.append(random.choice(caption) if is_train else caption[0]) + + with torch.no_grad(): + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + text_inputs = tokenizer( + captions, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_embeds = text_encoder( + text_input_ids.to(text_encoder.device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1) + return prompt_embeds, pooled_prompt_embeds + +@torch.no_grad() +def log_validation( + vae, unet, controlnet, args, accelerator, weight_dtype, step, val_dataset +): + logger.info("Running validation... ") + + controlnet = accelerator.unwrap_model(controlnet) + + pipeline = StableDiffusionXLControlNetPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=vae, + unet=unet, + controlnet=controlnet, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + if args.enable_xformers_memory_efficient_attention: + pipeline.enable_xformers_memory_efficient_attention() + + if args.seed is None: + generator = None + else: + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + + for img_idx in [ + 0, + 1, + 2, + ]: + # for img_idx in [0, 122, 1179]: + # for img_idx in range(2): + batch = val_dataset.__getitem__(img_idx) + mtv_condition = batch["ctrl_img"] # [None] + validation_prompts = batch["prompts"] + + curr_h, curr_w = batch["pixel_values"].shape[-2:] + + prompt_fea = torch.zeros((*batch["ctrl_img"].shape, args.ctrl_channel)).to( + accelerator.device, dtype=weight_dtype + ) + + for curr_b, curr_ins_prompt in enumerate(batch["input_ids_ins"]): + if len(curr_ins_prompt) > 0: + curr_ins_prompt = ["anything"] + curr_ins_prompt + input_ids = tokenize_captions(curr_ins_prompt, tokenizer_two).cuda() + with torch.cuda.amp.autocast(): + text_features = text_encoder_infer(input_ids, return_dict=True)[ + "text_embeds" + # "pooler_output" + ] + text_features = controlnet.text_adapter(text_features).to( + prompt_fea + ) + + for curr_ins_id in range(len(curr_ins_prompt)): + prompt_fea[curr_b][batch["ctrl_img"][curr_b] == curr_ins_id] = ( + text_features[curr_ins_id] + ) + + for curr_b, curr_ins_img in enumerate(batch["patches"]): + curr_ins_id, curr_ins_patch = curr_ins_img[0], curr_ins_img[1].to( + accelerator.device, dtype=weight_dtype + ) + if curr_ins_id.shape[0] > 0: + + with torch.cuda.amp.autocast(): + image_features = dino_encoder(curr_ins_patch) + image_features = controlnet.dino_adapter(image_features).to( + prompt_fea + ) + + for id_ins, curr_ins in enumerate(curr_ins_id.tolist()): + all_vector = image_features[id_ins] + global_vector = all_vector[0:1] + + down_s = args.patch_size // 14 + + patch_vector = ( + all_vector[1 : 1 + down_s * down_s] + .view(1, down_s, down_s, -1) + .permute(0, 3, 1, 2) + ) + curr_mask = batch["ctrl_img"][curr_b] == curr_ins + + if curr_mask.max() < 1: + continue + + curr_box = masks_to_boxes(curr_mask[None])[0].int().tolist() + height, width = ( + curr_box[3] - curr_box[1], + curr_box[2] - curr_box[0], + ) + + x = torch.linspace(-1, 1, height) + y = torch.linspace(-1, 1, width) + + xx, yy = torch.meshgrid(x, y) + grid = torch.stack((xx, yy), dim=2).to(patch_vector)[None] + + warp_fea = F.grid_sample( + patch_vector, + grid, + mode="bilinear", + padding_mode="reflection", + align_corners=True, + )[0].permute(1, 2, 0) + + small_mask = curr_mask[ + curr_box[1] : curr_box[3], curr_box[0] : curr_box[2] + ] + + curr_pix_num = small_mask.sum().item() + all_ins = np.arange(0, curr_pix_num) + sel_ins = np.random.choice( + all_ins, size=(curr_pix_num // 10,), replace=True + ) + # import ipdb; ipdb.set_trace() + warp_fea[small_mask][sel_ins] = global_vector + + prompt_fea[curr_b][ + curr_box[1] : curr_box[3], curr_box[0] : curr_box[2] + ][small_mask] = warp_fea[small_mask] + + mtv_condition = prompt_fea.permute(0, 3, 1, 2) + + images_tensor = [] + + for _ in range(args.num_validation_images): + with torch.autocast("cuda"): + image = pipeline( + prompt=validation_prompts, + image=mtv_condition, + num_inference_steps=30, + generator=generator, + height=curr_h, + width=curr_w, + controlnet_conditioning_scale=1.0, + guidance_scale=args.cfg_scale, + ).images # [0] + image = torch.cat([torch.tensor(np.array(ii)) for ii in image], 1) + + images_tensor.append(image) + + raw_img = ( + batch["pixel_values"].permute(2, 0, 3, 1).reshape(images_tensor[0].shape) + * 255 + ) + gen_img = torch.cat(images_tensor, 0) + gen_img = torch.cat([raw_img, gen_img], 0) + + out_path = os.path.join( + args.output_dir, + f"step_{step:06d}_{img_idx:04d}.jpg", + ) + + cv2.imwrite(out_path, cv2.cvtColor(gen_img.cpu().numpy(), cv2.COLOR_RGB2BGR)) + + del controlnet + del pipeline + + return None + +def import_model_class_from_model_name_or_path( + pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" +): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, subfolder=subfolder, revision=revision + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "CLIPTextModelWithProjection": + from transformers import CLIPTextModelWithProjection + + return CLIPTextModelWithProjection + else: + raise ValueError(f"{model_class} is not supported.") + + +def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None): + img_str = "" + if image_logs is not None: + img_str = "You can find some example images below.\n" + for i, log in enumerate(image_logs): + images = log["images"] + validation_prompt = log["validation_prompt"] + validation_image = log["validation_image"] + validation_image.save(os.path.join(repo_folder, "image_control.png")) + img_str += f"prompt: {validation_prompt}\n" + images = [validation_image] + images + make_image_grid(images, 1, len(images)).save( + os.path.join(repo_folder, f"images_{i}.png") + ) + img_str += f"![images_{i})](./images_{i}.png)\n" + + yaml = f""" +--- +license: openrail++ +base_model: {base_model} +tags: +- stable-diffusion-xl +- stable-diffusion-xl-diffusers +- text-to-image +- diffusers +- controlnet +inference: true +--- + """ + model_card = f""" +# controlnet-{repo_id} + +These are controlnet weights trained on {base_model} with new type of conditioning. +{img_str} +""" + + with open(os.path.join(repo_folder, "README.md"), "w") as f: + f.write(yaml + model_card) + + +def main(args): + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration( + project_dir=args.output_dir, logging_dir=logging_dir + ) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, + exist_ok=True, + token=args.hub_token, + ).repo_id + + # Load the tokenizers + tokenizer_one = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + use_fast=False, + ) + + global tokenizer_two, text_encoder_infer + + tokenizer_two = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer_2", + revision=args.revision, + use_fast=False, + ) + + # import correct text encoder classes + text_encoder_cls_one = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision + ) + text_encoder_cls_two = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2" + ) + + # Load scheduler and models + noise_scheduler = DDPMScheduler.from_pretrained( + args.pretrained_model_name_or_path, subfolder="scheduler" + ) + text_encoder_one = text_encoder_cls_one.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="text_encoder", + revision=args.revision, + variant=args.variant, + ) + text_encoder_infer = text_encoder_cls_two.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="text_encoder_2", + revision=args.revision, + variant=args.variant, + ) + text_encoder_two = text_encoder_cls_two.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="text_encoder_2", + revision=args.revision, + variant=args.variant, + ) + vae_path = ( + args.pretrained_model_name_or_path + if args.pretrained_vae_model_name_or_path is None + else args.pretrained_vae_model_name_or_path + ) + args.pretrained_vae_model_name_or_path = vae_path + vae = AutoencoderKL.from_pretrained( + vae_path, + subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, + revision=args.revision, + variant=args.variant, + ) + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="unet", + revision=args.revision, + variant=args.variant, + ) + # import ipdb; ipdb.set_trace() + + if args.controlnet_model_name_or_path: + logger.info("Loading existing controlnet weights") + controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path) + else: + logger.info("Initializing controlnet weights from unet") + controlnet = ControlNetModel.from_unet( + unet, conditioning_channels=args.ctrl_channel, text_adapter_channel=1280 + ) + # controlnet = ControlNetModel.from_unet(unet) + + # Resuming unet + if args.resume_from_checkpoint == "latest": + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming unet from checkpoint {path}") + # import ipdb; ipdb.set_trace() + # unet = unet.from_pretrained( + # os.path.join(args.output_dir, path), subfolder="unet" + # ) + + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + i = len(weights) - 1 + + while len(weights) > 0: + weights.pop() + model = models[i] + + sub_dir = "controlnet" + model.save_pretrained(os.path.join(output_dir, sub_dir)) + + i -= 1 + + def load_model_hook(models, input_dir): + while len(models) > 0: + # pop models so that they are not loaded again + model = models.pop() + + # load diffusers style into model + load_model = ControlNetModel.from_pretrained( + input_dir, subfolder="controlnet" + ) + model.register_to_config(**load_model.config) + + model.load_state_dict(load_model.state_dict()) + del load_model + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + vae.requires_grad_(False) + unet.requires_grad_(False) + text_encoder_one.requires_grad_(False) + text_encoder_two.requires_grad_(False) + text_encoder_infer.requires_grad_(False) + controlnet.train() + + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warn( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + ) + unet.enable_xformers_memory_efficient_attention() + controlnet.enable_xformers_memory_efficient_attention() + else: + raise ValueError( + "xformers is not available. Make sure it is installed correctly" + ) + + if args.gradient_checkpointing: + controlnet.enable_gradient_checkpointing() + unet.enable_gradient_checkpointing() + + # Check that all trainable models are in full precision + low_precision_error_string = ( + " Please make sure to always have all model weights in full float32 precision when starting training - even if" + " doing mixed precision training, copy of the weights should still be float32." + ) + + if unwrap_model(controlnet).dtype != torch.float32: + raise ValueError( + f"Controlnet loaded as datatype {unwrap_model(controlnet).dtype}. {low_precision_error_string}" + ) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate + * args.gradient_accumulation_steps + * args.train_batch_size + * accelerator.num_processes + ) + + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + # Optimizer creation + params_to_optimize = list(controlnet.parameters()) + + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + + global dino_encoder + dino_encoder = FrozenDinoV2Encoder() + + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # as these models are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move vae, unet and text_encoder to device and cast to weight_dtype + # The VAE is in float32 to avoid NaN losses. + if args.pretrained_vae_model_name_or_path is not None: + vae.to(accelerator.device, dtype=weight_dtype) + else: + vae.to(accelerator.device, dtype=torch.float32) + unet.to(accelerator.device, dtype=weight_dtype) + + # text_encoder = text_encoder_one.to(accelerator.device, dtype=weight_dtype) + text_encoder_infer.to(accelerator.device, dtype=weight_dtype) + text_encoder_one.to(accelerator.device, dtype=weight_dtype) + text_encoder_two.to(accelerator.device, dtype=weight_dtype) + + + text_encoders = [text_encoder_one, text_encoder_two] + tokenizers = [tokenizer_one, tokenizer_two] + + + gc.collect() + torch.cuda.empty_cache() + + tokenizer = [tokenizers, text_encoders] + train_dataset = OmniboothDataset(args, tokenizer, "train") + val_dataset = OmniboothDataset(args, tokenizer, "val") + + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + shuffle=True, + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil( + len(train_dataloader) / args.gradient_accumulation_steps + ) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. + controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + controlnet, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil( + len(train_dataloader) / args.gradient_accumulation_steps + ) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_config = dict(vars(args)) + + # tensorboard cannot handle list types for config + tracker_config.pop("validation_prompt") + tracker_config.pop("validation_image") + + accelerator.init_trackers(args.tracker_project_name, config=tracker_config) + + # Train! + total_batch_size = ( + args.train_batch_size + * accelerator.num_processes + * args.gradient_accumulation_steps + ) + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info( + f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}" + ) + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + image_logs = None + for epoch in range(first_epoch, args.num_train_epochs): + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(controlnet): + + batch["pixel_values"] = batch["pixel_values"][0] + batch["ctrl_img"] = batch["ctrl_img"][0] + batch["prompts"] = [x[0] for x in batch["prompts"]] + + prompt_info = tokenize_captions_sdxl(args, batch["prompts"], tokenizer) + batch.update(prompt_info) + + prompt_fea = torch.zeros( + (*batch["ctrl_img"].shape, args.ctrl_channel) + ).to(accelerator.device, dtype=weight_dtype) + + for curr_b, curr_ins_prompt in enumerate(batch["input_ids_ins"]): + if len(curr_ins_prompt) > 0: + curr_ins_prompt = ["anything"] + [x[0] for x in curr_ins_prompt] + input_ids = tokenize_captions(curr_ins_prompt, tokenizer_two).cuda() + with torch.cuda.amp.autocast(): + text_features = text_encoder_infer(input_ids, return_dict=True)[ + "text_embeds" + # "pooler_output" + ] + text_features = controlnet.module.text_adapter( + text_features + ).to(prompt_fea) + + for curr_ins_id in range(len(curr_ins_prompt)): + prompt_fea[curr_b][ + batch["ctrl_img"][curr_b] == curr_ins_id + ] = text_features[curr_ins_id] + + for curr_b, curr_ins_img in enumerate(batch["patches"]): + curr_ins_id, curr_ins_patch = curr_ins_img[0], curr_ins_img[1].to( + weight_dtype + ) + + if curr_ins_id.shape[1] > 0: + with torch.cuda.amp.autocast(): + image_features = dino_encoder( + curr_ins_patch.reshape((-1, *curr_ins_patch.shape[2:])) + ) + image_features = controlnet.module.dino_adapter( + image_features + ).to(prompt_fea) + + for id_ins, curr_ins in enumerate(curr_ins_id[0].tolist()): + all_vector = image_features[id_ins] + global_vector = all_vector[0:1] + + down_s = args.patch_size // 14 + + patch_vector = ( + all_vector[1 : 1 + down_s * down_s] + .view(1, down_s, down_s, -1) + .permute(0, 3, 1, 2) + ) + + curr_mask = batch["ctrl_img"][curr_b] == curr_ins + + if curr_mask.max() < 1: + continue + + curr_box = masks_to_boxes(curr_mask[None])[0].int().tolist() + height, width = ( + curr_box[3] - curr_box[1], + curr_box[2] - curr_box[0], + ) + + x = torch.linspace(-1, 1, height) + y = torch.linspace(-1, 1, width) + + xx, yy = torch.meshgrid(x, y) + grid = torch.stack((xx, yy), dim=2).to(patch_vector)[None] + + warp_fea = F.grid_sample( + patch_vector, + grid, + mode="bilinear", + padding_mode="reflection", + align_corners=True, + )[0].permute(1, 2, 0) + + # import ipdb; ipdb.set_trace() + + small_mask = curr_mask[ + curr_box[1] : curr_box[3], curr_box[0] : curr_box[2] + ] + + curr_pix_num = small_mask.sum().item() + all_ins = np.arange(0, curr_pix_num) + sel_ins = np.random.choice( + all_ins, size=(curr_pix_num // 10,), replace=True + ) + warp_fea[small_mask][sel_ins] = global_vector + + prompt_fea[curr_b][ + curr_box[1] : curr_box[3], curr_box[0] : curr_box[2] + ][small_mask] = warp_fea[small_mask] + + batch["conditioning_pixel_values"] = prompt_fea.permute(0, 3, 1, 2) + + # Convert images to latent space + if args.pretrained_vae_model_name_or_path is not None: + pixel_values = batch["pixel_values"].to(dtype=weight_dtype) + else: + pixel_values = batch["pixel_values"] + latents = vae.encode(pixel_values).latent_dist.sample() + latents = latents * vae.config.scaling_factor + if args.pretrained_vae_model_name_or_path is None: + latents = latents.to(weight_dtype) + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + + # Sample a random timestep for each image + timesteps = torch.randint( + 0, + noise_scheduler.config.num_train_timesteps, + (bsz,), + device=latents.device, + ) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps).to( + dtype=weight_dtype + ) + + # ControlNet conditioning. + controlnet_image = batch["conditioning_pixel_values"].to( + dtype=weight_dtype + ) + # import ipdb; ipdb.set_trace() + down_block_res_samples, mid_block_res_sample = controlnet( + noisy_latents, + timesteps, + encoder_hidden_states=batch["prompt_ids"], + added_cond_kwargs=batch["unet_added_conditions"], + controlnet_cond=controlnet_image, + return_dict=False, + ) + + # Predict the noise residual + model_pred = unet( + noisy_latents, + timesteps, + encoder_hidden_states=batch["prompt_ids"].to(dtype=weight_dtype), + added_cond_kwargs=batch["unet_added_conditions"], + down_block_additional_residuals=[ + sample.to(dtype=weight_dtype) + for sample in down_block_res_samples + ], + mid_block_additional_residual=mid_block_res_sample.to( + dtype=weight_dtype + ), + return_dict=False, + )[0] + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError( + f"Unknown prediction type {noise_scheduler.config.prediction_type}" + ) + # loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + + def weighted_mse_loss(input, target, weight): + return torch.mean(weight * (input - target) ** 2) + + fore_value = ( + -0.5 * (1 + np.cos(np.pi * global_step / args.max_train_steps)) + + 2.0 + ) + + edge_w = torch.cat([x[2] for x in batch["patches"]], 0) + weight_mask = torch.ones_like(edge_w).to(weight_dtype) + + weight_mask[edge_w != 0] *= fore_value + # import ipdb; ipdb.set_trace() + + loss = weighted_mse_loss( + model_pred.float(), + target.float(), + weight_mask.unsqueeze(1).repeat(1, target.shape[1], 1, 1), + ) + + accelerator.backward(loss) + if accelerator.sync_gradients: + params_to_clip = controlnet.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=args.set_grads_to_none) + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [ + d for d in checkpoints if d.startswith("checkpoint") + ] + checkpoints = sorted( + checkpoints, key=lambda x: int(x.split("-")[1]) + ) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = ( + len(checkpoints) - args.checkpoints_total_limit + 1 + ) + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info( + f"removing checkpoints: {', '.join(removing_checkpoints)}" + ) + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join( + args.output_dir, removing_checkpoint + ) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join( + args.output_dir, f"checkpoint-{global_step}" + ) + accelerator.save_state(save_path) + + # # save unet + # accelerator.unwrap_model(unet).save_pretrained( + # os.path.join(save_path, "unet") + # ) + + logger.info(f"Saved state to {save_path}") + + if global_step % args.validation_steps == 0: + image_logs = log_validation( + vae, + unet, + controlnet, + args, + accelerator, + weight_dtype, + global_step, + val_dataset, + ) + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + # Create the pipeline using using the trained modules and save it. + accelerator.wait_for_everyone() + if accelerator.is_main_process: + controlnet = unwrap_model(controlnet) + controlnet.save_pretrained(args.output_dir) + unwrap_model(unet).save_pretrained( + os.path.join(args.output_dir, f"checkpoint-{global_step}", "unet") + ) + + if args.push_to_hub: + save_model_card( + repo_id, + image_logs=image_logs, + base_model=args.pretrained_model_name_or_path, + repo_folder=args.output_dir, + ) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + from args_file import parse_args + + args = parse_args() + main(args) diff --git a/train.sh b/train.sh new file mode 100644 index 0000000..4ee1919 --- /dev/null +++ b/train.sh @@ -0,0 +1,39 @@ +export WANDB_DISABLED=True +export HF_HUB_OFFLINE=True + +export MODEL_DIR="./ckp/stable-diffusion-xl-base-1.0" +export VAE_DIR="./ckp/sdxl-vae-fp16-fix" + + + +export EXP_NAME="omnibooth_train" +export OUTPUT_DIR="./ckp/$EXP_NAME" + + + + + +# accelerate launch --gpu_ids 0,1,2,3,4,5,6,7 --num_processes 8 --main_process_port 3226 train.py \ +accelerate launch --gpu_ids 0, --num_processes 1 --main_process_port 3226 train.py \ + --pretrained_model_name_or_path=$MODEL_DIR \ + --pretrained_vae_model_name_or_path=$VAE_DIR \ + --output_dir=$OUTPUT_DIR \ + --width=1024 \ + --height=1024 \ + --patch_size=364 \ + --learning_rate=4e-5 \ + --num_train_epochs=12 \ + --train_batch_size=1 \ + --mulscale_batch_size=2 \ + --mixed_precision="fp16" \ + --num_validation_images=2 \ + --validation_steps=500 \ + --checkpointing_steps=5000 \ + --checkpoints_total_limit=10 \ + --ctrl_channel=1024 \ + --use_sdxl=True \ + --enable_xformers_memory_efficient_attention \ + --report_to='wandb' \ + --resume_from_checkpoint="latest" \ + --tracker_project_name="omnibooth-demo" + diff --git a/utils/coco_category.py b/utils/coco_category.py new file mode 100644 index 0000000..b87de2e --- /dev/null +++ b/utils/coco_category.py @@ -0,0 +1,137 @@ + +# from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES +COCO_CATEGORIES = [ + {"color": [220, 20, 60], "isthing": 1, "id": 1, "name": "person"}, + {"color": [119, 11, 32], "isthing": 1, "id": 2, "name": "bicycle"}, + {"color": [0, 0, 142], "isthing": 1, "id": 3, "name": "car"}, + {"color": [0, 0, 230], "isthing": 1, "id": 4, "name": "motorcycle"}, + {"color": [106, 0, 228], "isthing": 1, "id": 5, "name": "airplane"}, + {"color": [0, 60, 100], "isthing": 1, "id": 6, "name": "bus"}, + {"color": [0, 80, 100], "isthing": 1, "id": 7, "name": "train"}, + {"color": [0, 0, 70], "isthing": 1, "id": 8, "name": "truck"}, + {"color": [0, 0, 192], "isthing": 1, "id": 9, "name": "boat"}, + {"color": [250, 170, 30], "isthing": 1, "id": 10, "name": "traffic light"}, + {"color": [100, 170, 30], "isthing": 1, "id": 11, "name": "fire hydrant"}, + {"color": [220, 220, 0], "isthing": 1, "id": 13, "name": "stop sign"}, + {"color": [175, 116, 175], "isthing": 1, "id": 14, "name": "parking meter"}, + {"color": [250, 0, 30], "isthing": 1, "id": 15, "name": "bench"}, + {"color": [165, 42, 42], "isthing": 1, "id": 16, "name": "bird"}, + {"color": [255, 77, 255], "isthing": 1, "id": 17, "name": "cat"}, + {"color": [0, 226, 252], "isthing": 1, "id": 18, "name": "dog"}, + {"color": [182, 182, 255], "isthing": 1, "id": 19, "name": "horse"}, + {"color": [0, 82, 0], "isthing": 1, "id": 20, "name": "sheep"}, + {"color": [120, 166, 157], "isthing": 1, "id": 21, "name": "cow"}, + {"color": [110, 76, 0], "isthing": 1, "id": 22, "name": "elephant"}, + {"color": [174, 57, 255], "isthing": 1, "id": 23, "name": "bear"}, + {"color": [199, 100, 0], "isthing": 1, "id": 24, "name": "zebra"}, + {"color": [72, 0, 118], "isthing": 1, "id": 25, "name": "giraffe"}, + {"color": [255, 179, 240], "isthing": 1, "id": 27, "name": "backpack"}, + {"color": [0, 125, 92], "isthing": 1, "id": 28, "name": "umbrella"}, + {"color": [209, 0, 151], "isthing": 1, "id": 31, "name": "handbag"}, + {"color": [188, 208, 182], "isthing": 1, "id": 32, "name": "tie"}, + {"color": [0, 220, 176], "isthing": 1, "id": 33, "name": "suitcase"}, + {"color": [255, 99, 164], "isthing": 1, "id": 34, "name": "frisbee"}, + {"color": [92, 0, 73], "isthing": 1, "id": 35, "name": "skis"}, + {"color": [133, 129, 255], "isthing": 1, "id": 36, "name": "snowboard"}, + {"color": [78, 180, 255], "isthing": 1, "id": 37, "name": "sports ball"}, + {"color": [0, 228, 0], "isthing": 1, "id": 38, "name": "kite"}, + {"color": [174, 255, 243], "isthing": 1, "id": 39, "name": "baseball bat"}, + {"color": [45, 89, 255], "isthing": 1, "id": 40, "name": "baseball glove"}, + {"color": [134, 134, 103], "isthing": 1, "id": 41, "name": "skateboard"}, + {"color": [145, 148, 174], "isthing": 1, "id": 42, "name": "surfboard"}, + {"color": [255, 208, 186], "isthing": 1, "id": 43, "name": "tennis racket"}, + {"color": [197, 226, 255], "isthing": 1, "id": 44, "name": "bottle"}, + {"color": [171, 134, 1], "isthing": 1, "id": 46, "name": "wine glass"}, + {"color": [109, 63, 54], "isthing": 1, "id": 47, "name": "cup"}, + {"color": [207, 138, 255], "isthing": 1, "id": 48, "name": "fork"}, + {"color": [151, 0, 95], "isthing": 1, "id": 49, "name": "knife"}, + {"color": [9, 80, 61], "isthing": 1, "id": 50, "name": "spoon"}, + {"color": [84, 105, 51], "isthing": 1, "id": 51, "name": "bowl"}, + {"color": [74, 65, 105], "isthing": 1, "id": 52, "name": "banana"}, + {"color": [166, 196, 102], "isthing": 1, "id": 53, "name": "apple"}, + {"color": [208, 195, 210], "isthing": 1, "id": 54, "name": "sandwich"}, + {"color": [255, 109, 65], "isthing": 1, "id": 55, "name": "orange"}, + {"color": [0, 143, 149], "isthing": 1, "id": 56, "name": "broccoli"}, + {"color": [179, 0, 194], "isthing": 1, "id": 57, "name": "carrot"}, + {"color": [209, 99, 106], "isthing": 1, "id": 58, "name": "hot dog"}, + {"color": [5, 121, 0], "isthing": 1, "id": 59, "name": "pizza"}, + {"color": [227, 255, 205], "isthing": 1, "id": 60, "name": "donut"}, + {"color": [147, 186, 208], "isthing": 1, "id": 61, "name": "cake"}, + {"color": [153, 69, 1], "isthing": 1, "id": 62, "name": "chair"}, + {"color": [3, 95, 161], "isthing": 1, "id": 63, "name": "couch"}, + {"color": [163, 255, 0], "isthing": 1, "id": 64, "name": "potted plant"}, + {"color": [119, 0, 170], "isthing": 1, "id": 65, "name": "bed"}, + {"color": [0, 182, 199], "isthing": 1, "id": 67, "name": "dining table"}, + {"color": [0, 165, 120], "isthing": 1, "id": 70, "name": "toilet"}, + {"color": [183, 130, 88], "isthing": 1, "id": 72, "name": "tv"}, + {"color": [95, 32, 0], "isthing": 1, "id": 73, "name": "laptop"}, + {"color": [130, 114, 135], "isthing": 1, "id": 74, "name": "mouse"}, + {"color": [110, 129, 133], "isthing": 1, "id": 75, "name": "remote"}, + {"color": [166, 74, 118], "isthing": 1, "id": 76, "name": "keyboard"}, + {"color": [219, 142, 185], "isthing": 1, "id": 77, "name": "cell phone"}, + {"color": [79, 210, 114], "isthing": 1, "id": 78, "name": "microwave"}, + {"color": [178, 90, 62], "isthing": 1, "id": 79, "name": "oven"}, + {"color": [65, 70, 15], "isthing": 1, "id": 80, "name": "toaster"}, + {"color": [127, 167, 115], "isthing": 1, "id": 81, "name": "sink"}, + {"color": [59, 105, 106], "isthing": 1, "id": 82, "name": "refrigerator"}, + {"color": [142, 108, 45], "isthing": 1, "id": 84, "name": "book"}, + {"color": [196, 172, 0], "isthing": 1, "id": 85, "name": "clock"}, + {"color": [95, 54, 80], "isthing": 1, "id": 86, "name": "vase"}, + {"color": [128, 76, 255], "isthing": 1, "id": 87, "name": "scissors"}, + {"color": [201, 57, 1], "isthing": 1, "id": 88, "name": "teddy bear"}, + {"color": [246, 0, 122], "isthing": 1, "id": 89, "name": "hair drier"}, + {"color": [191, 162, 208], "isthing": 1, "id": 90, "name": "toothbrush"}, + {"color": [255, 255, 128], "isthing": 0, "id": 92, "name": "banner"}, + {"color": [147, 211, 203], "isthing": 0, "id": 93, "name": "blanket"}, + {"color": [150, 100, 100], "isthing": 0, "id": 95, "name": "bridge"}, + {"color": [168, 171, 172], "isthing": 0, "id": 100, "name": "cardboard"}, + {"color": [146, 112, 198], "isthing": 0, "id": 107, "name": "counter"}, + {"color": [210, 170, 100], "isthing": 0, "id": 109, "name": "curtain"}, + {"color": [92, 136, 89], "isthing": 0, "id": 112, "name": "door-stuff"}, + {"color": [218, 88, 184], "isthing": 0, "id": 118, "name": "floor-wood"}, + {"color": [241, 129, 0], "isthing": 0, "id": 119, "name": "flower"}, + {"color": [217, 17, 255], "isthing": 0, "id": 122, "name": "fruit"}, + {"color": [124, 74, 181], "isthing": 0, "id": 125, "name": "gravel"}, + {"color": [70, 70, 70], "isthing": 0, "id": 128, "name": "house"}, + {"color": [255, 228, 255], "isthing": 0, "id": 130, "name": "light"}, + {"color": [154, 208, 0], "isthing": 0, "id": 133, "name": "mirror-stuff"}, + {"color": [193, 0, 92], "isthing": 0, "id": 138, "name": "net"}, + {"color": [76, 91, 113], "isthing": 0, "id": 141, "name": "pillow"}, + {"color": [255, 180, 195], "isthing": 0, "id": 144, "name": "platform"}, + {"color": [106, 154, 176], "isthing": 0, "id": 145, "name": "playingfield"}, + {"color": [230, 150, 140], "isthing": 0, "id": 147, "name": "railroad"}, + {"color": [60, 143, 255], "isthing": 0, "id": 148, "name": "river"}, + {"color": [128, 64, 128], "isthing": 0, "id": 149, "name": "road"}, + {"color": [92, 82, 55], "isthing": 0, "id": 151, "name": "roof"}, + {"color": [254, 212, 124], "isthing": 0, "id": 154, "name": "sand"}, + {"color": [73, 77, 174], "isthing": 0, "id": 155, "name": "sea"}, + {"color": [255, 160, 98], "isthing": 0, "id": 156, "name": "shelf"}, + {"color": [255, 255, 255], "isthing": 0, "id": 159, "name": "snow"}, + {"color": [104, 84, 109], "isthing": 0, "id": 161, "name": "stairs"}, + {"color": [169, 164, 131], "isthing": 0, "id": 166, "name": "tent"}, + {"color": [225, 199, 255], "isthing": 0, "id": 168, "name": "towel"}, + {"color": [137, 54, 74], "isthing": 0, "id": 171, "name": "wall-brick"}, + {"color": [135, 158, 223], "isthing": 0, "id": 175, "name": "wall-stone"}, + {"color": [7, 246, 231], "isthing": 0, "id": 176, "name": "wall-tile"}, + {"color": [107, 255, 200], "isthing": 0, "id": 177, "name": "wall-wood"}, + {"color": [58, 41, 149], "isthing": 0, "id": 178, "name": "water-other"}, + {"color": [183, 121, 142], "isthing": 0, "id": 180, "name": "window-blind"}, + {"color": [255, 73, 97], "isthing": 0, "id": 181, "name": "window-other"}, + {"color": [107, 142, 35], "isthing": 0, "id": 184, "name": "tree-merged"}, + {"color": [190, 153, 153], "isthing": 0, "id": 185, "name": "fence-merged"}, + {"color": [146, 139, 141], "isthing": 0, "id": 186, "name": "ceiling-merged"}, + {"color": [70, 130, 180], "isthing": 0, "id": 187, "name": "sky-other-merged"}, + {"color": [134, 199, 156], "isthing": 0, "id": 188, "name": "cabinet-merged"}, + {"color": [209, 226, 140], "isthing": 0, "id": 189, "name": "table-merged"}, + {"color": [96, 36, 108], "isthing": 0, "id": 190, "name": "floor-other-merged"}, + {"color": [96, 96, 96], "isthing": 0, "id": 191, "name": "pavement-merged"}, + {"color": [64, 170, 64], "isthing": 0, "id": 192, "name": "mountain-merged"}, + {"color": [152, 251, 152], "isthing": 0, "id": 193, "name": "grass-merged"}, + {"color": [208, 229, 228], "isthing": 0, "id": 194, "name": "dirt-merged"}, + {"color": [206, 186, 171], "isthing": 0, "id": 195, "name": "paper-merged"}, + {"color": [152, 161, 64], "isthing": 0, "id": 196, "name": "food-other-merged"}, + {"color": [116, 112, 0], "isthing": 0, "id": 197, "name": "building-other-merged"}, + {"color": [0, 114, 143], "isthing": 0, "id": 198, "name": "rock-merged"}, + {"color": [102, 102, 156], "isthing": 0, "id": 199, "name": "wall-other-merged"}, + {"color": [250, 141, 255], "isthing": 0, "id": 200, "name": "rug-merged"}, +] \ No newline at end of file diff --git a/utils/coco_meta.py b/utils/coco_meta.py new file mode 100644 index 0000000..99a3d35 --- /dev/null +++ b/utils/coco_meta.py @@ -0,0 +1,806 @@ +COCO_META = [ + { + 'color': [0,0,0], + 'isthing':0, + 'id':0, + 'name': 'background' + }, + { + 'color': [220, 20, 60], + 'isthing': 1, + 'id': 1, + 'name': 'person' + }, + { + 'color': [119, 11, 32], + 'isthing': 1, + 'id': 2, + 'name': 'bicycle' + }, + { + 'color': [0, 0, 142], + 'isthing': 1, + 'id': 3, + 'name': 'car' + }, + { + 'color': [0, 0, 230], + 'isthing': 1, + 'id': 4, + 'name': 'motorcycle' + }, + { + 'color': [106, 0, 228], + 'isthing': 1, + 'id': 5, + 'name': 'airplane' + }, + { + 'color': [0, 60, 100], + 'isthing': 1, + 'id': 6, + 'name': 'bus' + }, + { + 'color': [0, 80, 100], + 'isthing': 1, + 'id': 7, + 'name': 'train' + }, + { + 'color': [0, 0, 70], + 'isthing': 1, + 'id': 8, + 'name': 'truck' + }, + { + 'color': [0, 0, 192], + 'isthing': 1, + 'id': 9, + 'name': 'boat' + }, + { + 'color': [250, 170, 30], + 'isthing': 1, + 'id': 10, + 'name': 'traffic light' + }, + { + 'color': [100, 170, 30], + 'isthing': 1, + 'id': 11, + 'name': 'fire hydrant' + }, + { + 'color': [220, 220, 0], + 'isthing': 1, + 'id': 13, + 'name': 'stop sign' + }, + { + 'color': [175, 116, 175], + 'isthing': 1, + 'id': 14, + 'name': 'parking meter' + }, + { + 'color': [250, 0, 30], + 'isthing': 1, + 'id': 15, + 'name': 'bench' + }, + { + 'color': [165, 42, 42], + 'isthing': 1, + 'id': 16, + 'name': 'bird' + }, + { + 'color': [255, 77, 255], + 'isthing': 1, + 'id': 17, + 'name': 'cat' + }, + { + 'color': [0, 226, 252], + 'isthing': 1, + 'id': 18, + 'name': 'dog' + }, + { + 'color': [182, 182, 255], + 'isthing': 1, + 'id': 19, + 'name': 'horse' + }, + { + 'color': [0, 82, 0], + 'isthing': 1, + 'id': 20, + 'name': 'sheep' + }, + { + 'color': [120, 166, 157], + 'isthing': 1, + 'id': 21, + 'name': 'cow' + }, + { + 'color': [110, 76, 0], + 'isthing': 1, + 'id': 22, + 'name': 'elephant' + }, + { + 'color': [174, 57, 255], + 'isthing': 1, + 'id': 23, + 'name': 'bear' + }, + { + 'color': [199, 100, 0], + 'isthing': 1, + 'id': 24, + 'name': 'zebra' + }, + { + 'color': [72, 0, 118], + 'isthing': 1, + 'id': 25, + 'name': 'giraffe' + }, + { + 'color': [255, 179, 240], + 'isthing': 1, + 'id': 27, + 'name': 'backpack' + }, + { + 'color': [0, 125, 92], + 'isthing': 1, + 'id': 28, + 'name': 'umbrella' + }, + { + 'color': [209, 0, 151], + 'isthing': 1, + 'id': 31, + 'name': 'handbag' + }, + { + 'color': [188, 208, 182], + 'isthing': 1, + 'id': 32, + 'name': 'tie' + }, + { + 'color': [0, 220, 176], + 'isthing': 1, + 'id': 33, + 'name': 'suitcase' + }, + { + 'color': [255, 99, 164], + 'isthing': 1, + 'id': 34, + 'name': 'frisbee' + }, + { + 'color': [92, 0, 73], + 'isthing': 1, + 'id': 35, + 'name': 'skis' + }, + { + 'color': [133, 129, 255], + 'isthing': 1, + 'id': 36, + 'name': 'snowboard' + }, + { + 'color': [78, 180, 255], + 'isthing': 1, + 'id': 37, + 'name': 'sports ball' + }, + { + 'color': [0, 228, 0], + 'isthing': 1, + 'id': 38, + 'name': 'kite' + }, + { + 'color': [174, 255, 243], + 'isthing': 1, + 'id': 39, + 'name': 'baseball bat' + }, + { + 'color': [45, 89, 255], + 'isthing': 1, + 'id': 40, + 'name': 'baseball glove' + }, + { + 'color': [134, 134, 103], + 'isthing': 1, + 'id': 41, + 'name': 'skateboard' + }, + { + 'color': [145, 148, 174], + 'isthing': 1, + 'id': 42, + 'name': 'surfboard' + }, + { + 'color': [255, 208, 186], + 'isthing': 1, + 'id': 43, + 'name': 'tennis racket' + }, + { + 'color': [197, 226, 255], + 'isthing': 1, + 'id': 44, + 'name': 'bottle' + }, + { + 'color': [171, 134, 1], + 'isthing': 1, + 'id': 46, + 'name': 'wine glass' + }, + { + 'color': [109, 63, 54], + 'isthing': 1, + 'id': 47, + 'name': 'cup' + }, + { + 'color': [207, 138, 255], + 'isthing': 1, + 'id': 48, + 'name': 'fork' + }, + { + 'color': [151, 0, 95], + 'isthing': 1, + 'id': 49, + 'name': 'knife' + }, + { + 'color': [9, 80, 61], + 'isthing': 1, + 'id': 50, + 'name': 'spoon' + }, + { + 'color': [84, 105, 51], + 'isthing': 1, + 'id': 51, + 'name': 'bowl' + }, + { + 'color': [74, 65, 105], + 'isthing': 1, + 'id': 52, + 'name': 'banana' + }, + { + 'color': [166, 196, 102], + 'isthing': 1, + 'id': 53, + 'name': 'apple' + }, + { + 'color': [208, 195, 210], + 'isthing': 1, + 'id': 54, + 'name': 'sandwich' + }, + { + 'color': [255, 109, 65], + 'isthing': 1, + 'id': 55, + 'name': 'orange' + }, + { + 'color': [0, 143, 149], + 'isthing': 1, + 'id': 56, + 'name': 'broccoli' + }, + { + 'color': [179, 0, 194], + 'isthing': 1, + 'id': 57, + 'name': 'carrot' + }, + { + 'color': [209, 99, 106], + 'isthing': 1, + 'id': 58, + 'name': 'hot dog' + }, + { + 'color': [5, 121, 0], + 'isthing': 1, + 'id': 59, + 'name': 'pizza' + }, + { + 'color': [227, 255, 205], + 'isthing': 1, + 'id': 60, + 'name': 'donut' + }, + { + 'color': [147, 186, 208], + 'isthing': 1, + 'id': 61, + 'name': 'cake' + }, + { + 'color': [153, 69, 1], + 'isthing': 1, + 'id': 62, + 'name': 'chair' + }, + { + 'color': [3, 95, 161], + 'isthing': 1, + 'id': 63, + 'name': 'couch' + }, + { + 'color': [163, 255, 0], + 'isthing': 1, + 'id': 64, + 'name': 'potted plant' + }, + { + 'color': [119, 0, 170], + 'isthing': 1, + 'id': 65, + 'name': 'bed' + }, + { + 'color': [0, 182, 199], + 'isthing': 1, + 'id': 67, + 'name': 'dining table' + }, + { + 'color': [0, 165, 120], + 'isthing': 1, + 'id': 70, + 'name': 'toilet' + }, + { + 'color': [183, 130, 88], + 'isthing': 1, + 'id': 72, + 'name': 'tv' + }, + { + 'color': [95, 32, 0], + 'isthing': 1, + 'id': 73, + 'name': 'laptop' + }, + { + 'color': [130, 114, 135], + 'isthing': 1, + 'id': 74, + 'name': 'mouse' + }, + { + 'color': [110, 129, 133], + 'isthing': 1, + 'id': 75, + 'name': 'remote' + }, + { + 'color': [166, 74, 118], + 'isthing': 1, + 'id': 76, + 'name': 'keyboard' + }, + { + 'color': [219, 142, 185], + 'isthing': 1, + 'id': 77, + 'name': 'cell phone' + }, + { + 'color': [79, 210, 114], + 'isthing': 1, + 'id': 78, + 'name': 'microwave' + }, + { + 'color': [178, 90, 62], + 'isthing': 1, + 'id': 79, + 'name': 'oven' + }, + { + 'color': [65, 70, 15], + 'isthing': 1, + 'id': 80, + 'name': 'toaster' + }, + { + 'color': [127, 167, 115], + 'isthing': 1, + 'id': 81, + 'name': 'sink' + }, + { + 'color': [59, 105, 106], + 'isthing': 1, + 'id': 82, + 'name': 'refrigerator' + }, + { + 'color': [142, 108, 45], + 'isthing': 1, + 'id': 84, + 'name': 'book' + }, + { + 'color': [196, 172, 0], + 'isthing': 1, + 'id': 85, + 'name': 'clock' + }, + { + 'color': [95, 54, 80], + 'isthing': 1, + 'id': 86, + 'name': 'vase' + }, + { + 'color': [128, 76, 255], + 'isthing': 1, + 'id': 87, + 'name': 'scissors' + }, + { + 'color': [201, 57, 1], + 'isthing': 1, + 'id': 88, + 'name': 'teddy bear' + }, + { + 'color': [246, 0, 122], + 'isthing': 1, + 'id': 89, + 'name': 'hair drier' + }, + { + 'color': [191, 162, 208], + 'isthing': 1, + 'id': 90, + 'name': 'toothbrush' + }, + { + 'color': [255, 255, 128], + 'isthing': 0, + 'id': 92, + 'name': 'banner' + }, + { + 'color': [147, 211, 203], + 'isthing': 0, + 'id': 93, + 'name': 'blanket' + }, + { + 'color': [150, 100, 100], + 'isthing': 0, + 'id': 95, + 'name': 'bridge' + }, + { + 'color': [168, 171, 172], + 'isthing': 0, + 'id': 100, + 'name': 'cardboard' + }, + { + 'color': [146, 112, 198], + 'isthing': 0, + 'id': 107, + 'name': 'counter' + }, + { + 'color': [210, 170, 100], + 'isthing': 0, + 'id': 109, + 'name': 'curtain' + }, + { + 'color': [92, 136, 89], + 'isthing': 0, + 'id': 112, + 'name': 'door-stuff' + }, + { + 'color': [218, 88, 184], + 'isthing': 0, + 'id': 118, + 'name': 'floor-wood' + }, + { + 'color': [241, 129, 0], + 'isthing': 0, + 'id': 119, + 'name': 'flower' + }, + { + 'color': [217, 17, 255], + 'isthing': 0, + 'id': 122, + 'name': 'fruit' + }, + { + 'color': [124, 74, 181], + 'isthing': 0, + 'id': 125, + 'name': 'gravel' + }, + { + 'color': [70, 70, 70], + 'isthing': 0, + 'id': 128, + 'name': 'house' + }, + { + 'color': [255, 228, 255], + 'isthing': 0, + 'id': 130, + 'name': 'light' + }, + { + 'color': [154, 208, 0], + 'isthing': 0, + 'id': 133, + 'name': 'mirror-stuff' + }, + { + 'color': [193, 0, 92], + 'isthing': 0, + 'id': 138, + 'name': 'net' + }, + { + 'color': [76, 91, 113], + 'isthing': 0, + 'id': 141, + 'name': 'pillow' + }, + { + 'color': [255, 180, 195], + 'isthing': 0, + 'id': 144, + 'name': 'platform' + }, + { + 'color': [106, 154, 176], + 'isthing': 0, + 'id': 145, + 'name': 'playingfield' + }, + { + 'color': [230, 150, 140], + 'isthing': 0, + 'id': 147, + 'name': 'railroad' + }, + { + 'color': [60, 143, 255], + 'isthing': 0, + 'id': 148, + 'name': 'river' + }, + { + 'color': [128, 64, 128], + 'isthing': 0, + 'id': 149, + 'name': 'road' + }, + { + 'color': [92, 82, 55], + 'isthing': 0, + 'id': 151, + 'name': 'roof' + }, + { + 'color': [254, 212, 124], + 'isthing': 0, + 'id': 154, + 'name': 'sand' + }, + { + 'color': [73, 77, 174], + 'isthing': 0, + 'id': 155, + 'name': 'sea' + }, + { + 'color': [255, 160, 98], + 'isthing': 0, + 'id': 156, + 'name': 'shelf' + }, + { + 'color': [255, 255, 255], + 'isthing': 0, + 'id': 159, + 'name': 'snow' + }, + { + 'color': [104, 84, 109], + 'isthing': 0, + 'id': 161, + 'name': 'stairs' + }, + { + 'color': [169, 164, 131], + 'isthing': 0, + 'id': 166, + 'name': 'tent' + }, + { + 'color': [225, 199, 255], + 'isthing': 0, + 'id': 168, + 'name': 'towel' + }, + { + 'color': [137, 54, 74], + 'isthing': 0, + 'id': 171, + 'name': 'wall-brick' + }, + { + 'color': [135, 158, 223], + 'isthing': 0, + 'id': 175, + 'name': 'wall-stone' + }, + { + 'color': [7, 246, 231], + 'isthing': 0, + 'id': 176, + 'name': 'wall-tile' + }, + { + 'color': [107, 255, 200], + 'isthing': 0, + 'id': 177, + 'name': 'wall-wood' + }, + { + 'color': [58, 41, 149], + 'isthing': 0, + 'id': 178, + 'name': 'water-other' + }, + { + 'color': [183, 121, 142], + 'isthing': 0, + 'id': 180, + 'name': 'window-blind' + }, + { + 'color': [255, 73, 97], + 'isthing': 0, + 'id': 181, + 'name': 'window-other' + }, + { + 'color': [107, 142, 35], + 'isthing': 0, + 'id': 184, + 'name': 'tree-merged' + }, + { + 'color': [190, 153, 153], + 'isthing': 0, + 'id': 185, + 'name': 'fence-merged' + }, + { + 'color': [146, 139, 141], + 'isthing': 0, + 'id': 186, + 'name': 'ceiling-merged' + }, + { + 'color': [70, 130, 180], + 'isthing': 0, + 'id': 187, + 'name': 'sky-other-merged' + }, + { + 'color': [134, 199, 156], + 'isthing': 0, + 'id': 188, + 'name': 'cabinet-merged' + }, + { + 'color': [209, 226, 140], + 'isthing': 0, + 'id': 189, + 'name': 'table-merged' + }, + { + 'color': [96, 36, 108], + 'isthing': 0, + 'id': 190, + 'name': 'floor-other-merged' + }, + { + 'color': [96, 96, 96], + 'isthing': 0, + 'id': 191, + 'name': 'pavement-merged' + }, + { + 'color': [64, 170, 64], + 'isthing': 0, + 'id': 192, + 'name': 'mountain-merged' + }, + { + 'color': [152, 251, 152], + 'isthing': 0, + 'id': 193, + 'name': 'grass-merged' + }, + { + 'color': [208, 229, 228], + 'isthing': 0, + 'id': 194, + 'name': 'dirt-merged' + }, + { + 'color': [206, 186, 171], + 'isthing': 0, + 'id': 195, + 'name': 'paper-merged' + }, + { + 'color': [152, 161, 64], + 'isthing': 0, + 'id': 196, + 'name': 'food-other-merged' + }, + { + 'color': [116, 112, 0], + 'isthing': 0, + 'id': 197, + 'name': 'building-other-merged' + }, + { + 'color': [0, 114, 143], + 'isthing': 0, + 'id': 198, + 'name': 'rock-merged' + }, + { + 'color': [102, 102, 156], + 'isthing': 0, + 'id': 199, + 'name': 'wall-other-merged' + }, + { + 'color': [250, 141, 255], + 'isthing': 0, + 'id': 200, + 'name': 'rug-merged' + }, +] \ No newline at end of file diff --git a/utils/datasets.py b/utils/datasets.py new file mode 100644 index 0000000..1d54e16 --- /dev/null +++ b/utils/datasets.py @@ -0,0 +1,839 @@ +import torch +import pandas as pd +import os +import cv2 +import copy +import json +import glob +import torch +import random +import pickle +import numpy as np +from torchvision import transforms +from PIL import Image +import torchvision.transforms.functional as tf +from torchvision.io import read_image +from scipy import sparse +from panopticapi.utils import rgb2id +from pycocotools.coco import COCO + +import albumentations as A +from albumentations.pytorch import ToTensorV2 + + +class CocoNutImgDataset(torch.utils.data.Dataset): + def __init__(self, args, tokenizer, trainorval): + self.args = copy.deepcopy(args) + self.trainorval = trainorval + + dataroot = args.dataroot_path + self.dataroot = dataroot + + if trainorval == "val": + data_file = os.path.join( + dataroot, "annotations/relabeled_coco_val.json" + ) + ann_file_captions = f"{dataroot}/coco/annotations/captions_val2017.json" + self.img_dir = f"{dataroot}/coco/val2017/" + self.instance_prompt = f"{dataroot}/annotations/my-val.json" + self.mask_dir = f"{dataroot}/relabeled_coco_val/relabeled_coco_val/" + elif trainorval == "train": + data_file = os.path.join(dataroot, "annotations/coconut_s.json") + ann_file_captions = f"{dataroot}/coco/annotations/captions_train2017.json" + self.img_dir = f"{dataroot}/coco/train2017/" + self.instance_prompt = f"{dataroot}/annotations/my-train.json" + self.mask_dir = f"{dataroot}/coconut_s/coconut_s/" + + with open(self.instance_prompt) as file: # hw + self.instance_prompt = json.load(file) + self.instance_prompt = self.instance_prompt[trainorval] + + self.coco_caption = COCO(ann_file_captions) + + self.dataset = [] + + tmp_dict = {} + + for keys, values in self.instance_prompt.items(): + curr_dict = {} + curr_dict["height"] = values[-1][0] + curr_dict["width"] = values[-1][1] + curr_dict["file_name"] = keys.replace("png", "jpg") + curr_dict["id"] = int(keys[:-4]) + self.dataset.append(curr_dict) + + self.length = len(self.dataset) + print(f"data scale: {self.length}") + + transforms_list = [ + # transforms.Resize((self.args.height, self.args.width), interpolation=transforms.InterpolationMode.BILINEAR), + transforms.ToTensor(), + ] + + if trainorval == "train": + transforms_list.append(transforms.Normalize([0.5], [0.5])) + + if self.args.num_validation_images != 1 and trainorval == "val": + self.args.mulscale_batch_size = 3 + + self.image_transforms = transforms.Compose(transforms_list) + + mask_transforms_list = [ + transforms.Resize( + (self.args.height // 8, self.args.width // 8), + interpolation=transforms.InterpolationMode.NEAREST, + ), + transforms.ToTensor(), + ] + + self.mask_transforms = transforms.Compose(mask_transforms_list) + + self.dino_transforms = A.Compose( + [ + A.Resize(self.args.patch_size, self.args.patch_size), + A.HorizontalFlip(), + A.RandomBrightnessContrast( + brightness_limit=0.1, contrast_limit=0.1, p=0.3 + ), + ToTensorV2(), + ] + ) + self.dino_transforms_noflip = A.Compose( + [ + A.Resize(self.args.patch_size, self.args.patch_size), + A.RandomBrightnessContrast( + brightness_limit=0.1, contrast_limit=0.1, p=0.3 + ), + ToTensorV2(), + ] + ) + + if self.args.use_sdxl: + self.tokenizer, self.text_encoders = tokenizer + else: + self.tokenizer = tokenizer + + self.weight_dtype = torch.float16 + self.make_ratio_dict() + + def make_ratio_dict(self): + number = 30 + + ratio_dict = {} + ratio_list = np.linspace(0.5, 2.0, number) + for i in range(number): + ratio_dict[i] = [] + + for img_idx in range(len(self.dataset)): + curr_anno = self.dataset[img_idx] + + width = curr_anno["width"] + height = curr_anno["height"] + + curr_ratio = width / height + + curr_idx = np.argmin(np.abs(ratio_list - curr_ratio)) + + ratio_dict[curr_idx].append(img_idx) + + self.ratio_list = ratio_list + self.ratio_dict = ratio_dict + + def tokenize_captions(self, examples, is_train=True): + captions = [] + for caption in examples: + if random.random() < self.args.proportion_empty_prompts: + captions.append("") + elif isinstance(caption, str): + captions.append(caption) + elif isinstance(caption, (list, np.ndarray)): + # take a random caption if there are multiple + captions.append(random.choice(caption) if is_train else caption[0]) + else: + raise ValueError( + f"Caption column should contain either strings or lists of strings." + ) + inputs = self.tokenizer( + captions, + max_length=self.tokenizer.model_max_length, + padding="max_length", + truncation=True, + return_tensors="pt", + ) + return inputs.input_ids + + def tokenize_captions_sdxl(self, prompt_batch, is_train=True): + + original_size = (self.args.width, self.args.height) + target_size = (self.args.width, self.args.height) + crops_coords_top_left = (0, 0) + + prompt_embeds, pooled_prompt_embeds = self.encode_prompt( + prompt_batch, + self.text_encoders, + self.tokenizer, + self.args.proportion_empty_prompts, + is_train, + ) + add_text_embeds = pooled_prompt_embeds + + # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_time_ids = torch.tensor([add_time_ids]) + + add_time_ids = add_time_ids.repeat(len(prompt_batch), 1) + + return { + "prompt_ids": prompt_embeds, + "unet_added_conditions": { + "text_embeds": add_text_embeds, + "time_ids": add_time_ids, + }, + } + + # Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt_batch, + text_encoders, + tokenizers, + proportion_empty_prompts, + is_train=True, + ): + prompt_embeds_list = [] + + captions = [] + for caption in prompt_batch: + if random.random() < proportion_empty_prompts: + captions.append("") + elif isinstance(caption, str): + captions.append(caption) + elif isinstance(caption, (list, np.ndarray)): + # take a random caption if there are multiple + captions.append(random.choice(caption) if is_train else caption[0]) + + with torch.no_grad(): + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + text_inputs = tokenizer( + captions, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_embeds = text_encoder( + text_input_ids.to(text_encoder.device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1) + return prompt_embeds, pooled_prompt_embeds + + def __len__(self): + "Denotes the total number of samples" + if self.args.use_cbgs and self.trainorval == "train": + return len(self.sample_indices) + return self.length + + def __getitem__(self, idx): + if self.args.use_cbgs and self.trainorval == "train": + idx = self.sample_indices[idx] + return self.get_data_dict(idx) + + @torch.no_grad() + def get_data_dict(self, index): + + curr_anno = self.dataset[index] + + sample_list = [curr_anno] + + width = curr_anno["width"] + height = curr_anno["height"] + + curr_ratio = width / height + curr_ratio_idx = np.argmin(np.abs(self.ratio_list - curr_ratio)) + + curr_ratio_list = self.ratio_dict[curr_ratio_idx] + + remanding_size = self.args.mulscale_batch_size - 1 + if remanding_size > 0: + other_anno = random.sample(curr_ratio_list, remanding_size) + sample_list = sample_list + [self.dataset[x] for x in other_anno] + else: + sample_list = sample_list + + mul_pixel_values = [] + mul_ctrl_img = [] + mul_input_ids = [] + mul_input_ids_ins = [] + mul_prompts = [] + mul_patches = [] + + curr_height = self.args.height + curr_width = curr_height * self.ratio_list.tolist()[curr_ratio_idx] + curr_width = round(curr_width / 8) * 8 + + for curr_anno in sample_list: + + file_name = curr_anno["file_name"] + + raw_loca = self.img_dir + file_name.replace("png", "jpg") + mask_loca = self.mask_dir + file_name.replace("jpg", "png") + + try: + instance_captions = self.instance_prompt[ + file_name.replace("jpg", "png") + ] + + # load panoptic mask + img_mask = np.asarray(Image.open(mask_loca), dtype=np.uint32) + img_mask = rgb2id(img_mask) + ins_num = img_mask.max() + img_mask = Image.fromarray(img_mask.astype("uint32")) + + except: + # print('no prompt') + # print(file_name) + continue + + img = Image.open(raw_loca).convert("RGB") + img = img.resize((curr_width, curr_height)) + img_mask = img_mask.resize((curr_width, curr_height), Image.NEAREST) + + use_patch = 1 + if use_patch: + patches = [] + sel_ins = [] + img_np_raw = np.array(img, dtype=np.uint8) + mask_np_raw = np.array(img_mask, dtype=np.uint8) + ins_num = min(30, ins_num) + all_ins = np.arange(1, ins_num + 1).tolist() + + for id_ins, curr_ins in enumerate(all_ins): + # continue + if self.args.text_or_img == "mix" or self.args.text_or_img == "img": + # keep at least 2 image sample + if np.random.randint(0, 100) > 50 and id_ins >= 2: + continue + + mask_np = copy.deepcopy(mask_np_raw) + img_np = copy.deepcopy(img_np_raw) + img_np[mask_np != curr_ins] = 255 + mask_np[mask_np != curr_ins] = 0 + + mask_pil = Image.fromarray(mask_np.astype("uint8")) + box = mask_pil.getbbox() + if ( + box is None or (box[2] - box[0]) * (box[3] - box[1]) < 256 + ) and len(patches) != 0: + continue + + img_pil = Image.fromarray(img_np.astype("uint8")) + + cropped_img = img_pil.crop(box) + if "sign" in instance_captions[0][id_ins]: + cropped_img = self.dino_transforms_noflip( + image=np.array(cropped_img) + ) + else: + cropped_img = self.dino_transforms(image=np.array(cropped_img)) + cropped_img = cropped_img["image"] / 255 + # cropped_img = self.args.img_preprocess(cropped_img) + + patches.append(cropped_img[None]) + sel_ins.append(curr_ins) + + if len(patches) > 0: + patches = torch.cat(patches, dim=0) + else: + patches = torch.zeros( + (0, 3, self.args.patch_size, self.args.patch_size) + ) + + edges = ( + cv2.Canny( + np.array(img.resize((curr_width // 8, curr_height // 8))), + 100, + 200, + ) + / 255 + ) + + patches = [ + torch.tensor(sel_ins), + patches, + edges, + file_name, + (width, height), + ] + + img_mask = img_mask.resize( + (curr_width // 8, curr_height // 8), Image.NEAREST + ) + + img = self.image_transforms(img) # [None] + + fea_mask = torch.tensor(np.array(img_mask)) + + img_id = curr_anno["id"] + # img_id = curr_anno['image_id'] + ann_ids_captions = self.coco_caption.getAnnIds( + imgIds=[img_id], iscrowd=None + ) + anns_caption = self.coco_caption.loadAnns(ann_ids_captions)[0]["caption"] + anns_caption = [anns_caption] + + if self.args.use_sdxl: + pass + else: + input_ids = self.tokenize_captions(anns_caption) + mul_input_ids.append(input_ids) + + mul_pixel_values.append(img[None]) + mul_ctrl_img.append(fea_mask[None]) + mul_input_ids_ins.append(instance_captions[0]) + mul_prompts.append(anns_caption[0]) + mul_patches.append(patches) + + mul_pixel_values = torch.cat(mul_pixel_values, 0) + mul_ctrl_img = torch.cat(mul_ctrl_img, 0) + + data_dict = { + "pixel_values": mul_pixel_values, + "ctrl_img": mul_ctrl_img, + # "input_ids": mul_input_ids, + "input_ids_ins": mul_input_ids_ins, + "prompts": mul_prompts, + "patches": mul_patches, + } + + if self.args.use_sdxl: + pass + else: + mul_input_ids = torch.cat(mul_input_ids, 0) + data_dict["input_ids"] = mul_input_ids + + return data_dict + + def get_cat_ids(self, idx): + """Get category distribution of single scene. + + Args: + idx (int): Index of the data_info. + + Returns: + dict[list]: for each category, if the current scene + contains such boxes, store a list containing idx, + otherwise, store empty list. + """ + info = self.dataset[idx] + if self.use_valid_flag: + mask = info["valid_flag"] + gt_names = set(info["gt_names"][mask]) + else: + gt_names = set(info["gt_names"]) + + cat_ids = [] + fore_flag = 0 + for name in gt_names: + if name in self.CLASSES: + cat_ids.append(self.cat2id[name]) + fore_flag = 1 + if fore_flag == 0: + # model background as two objects + for _ in range(120): + cat_ids.append(self.cat2id["background"]) + return cat_ids + + def _get_sample_indices(self): + """Load annotations from ann_file. + + Args: + ann_file (str): Path of the annotation file. + + Returns: + list[dict]: List of annotations after class sampling. + """ + class_sample_idxs = {cat_id: [] for cat_id in self.cat2id.values()} + for idx in range(len(self.dataset)): + sample_cat_ids = self.get_cat_ids(idx) + for cat_id in sample_cat_ids: + class_sample_idxs[cat_id].append(idx) + duplicated_samples = sum([len(v) for _, v in class_sample_idxs.items()]) + class_distribution = { + k: len(v) / duplicated_samples for k, v in class_sample_idxs.items() + } + for key, value in class_sample_idxs.items(): + print(key, len(value)) + + sample_indices = [] + + frac = 1.0 / len(self.CLASSES) + ratios = [frac / v for v in class_distribution.values()] + for cls_inds, ratio in zip(list(class_sample_idxs.values()), ratios): + sample_indices += np.random.choice( + cls_inds, int(len(cls_inds) * ratio) + ).tolist() + return sample_indices + + +# vis coconut image +class InstanceDataset(torch.utils.data.Dataset): + def __init__(self, args, tokenizer, trainorval): + self.args = args + self.trainorval = trainorval + + dataroot = args.dataroot_path + self.dataroot = dataroot + + # self.dataset = ["./data/instance_dataset/ff_instance_1"] + self.dataset = ["./data/instance_dataset/plane"] + + self.dataset = sorted(self.dataset) + + transforms_list = [ + # transforms.Resize((self.args.height, self.args.width), interpolation=transforms.InterpolationMode.BILINEAR), + transforms.ToTensor(), + ] + + if trainorval == "train": + transforms_list.append(transforms.Normalize([0.5], [0.5])) + + self.image_transforms = transforms.Compose(transforms_list) + + mask_transforms_list = [ + transforms.Resize( + (self.args.height // 8, self.args.width // 8), + interpolation=transforms.InterpolationMode.NEAREST, + ), + transforms.ToTensor(), + ] + + self.mask_transforms = transforms.Compose(mask_transforms_list) + + self.dino_transforms = A.Compose( + [ + A.Resize(self.args.patch_size, self.args.patch_size), + A.HorizontalFlip(), + A.RandomBrightnessContrast(p=0.5), + ToTensorV2(), + ] + ) + self.dino_transforms_noflip = A.Compose( + [ + A.Resize(self.args.patch_size, self.args.patch_size), + # A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=1), + ToTensorV2(), + ] + ) + + if self.args.use_sdxl: + self.tokenizer, self.text_encoders = tokenizer + else: + self.tokenizer = tokenizer + + self.weight_dtype = torch.float16 + + def tokenize_captions(self, examples, is_train=True): + captions = [] + for caption in examples: + if random.random() < self.args.proportion_empty_prompts: + captions.append("") + elif isinstance(caption, str): + captions.append(caption) + elif isinstance(caption, (list, np.ndarray)): + # take a random caption if there are multiple + captions.append(random.choice(caption) if is_train else caption[0]) + else: + raise ValueError( + f"Caption column should contain either strings or lists of strings." + ) + inputs = self.tokenizer( + captions, + max_length=self.tokenizer.model_max_length, + padding="max_length", + truncation=True, + return_tensors="pt", + ) + return inputs.input_ids + + def tokenize_captions_sdxl(self, prompt_batch, is_train=True): + + original_size = (self.args.width, self.args.height) + target_size = (self.args.width, self.args.height) + crops_coords_top_left = (0, 0) + + prompt_embeds, pooled_prompt_embeds = self.encode_prompt( + prompt_batch, + self.text_encoders, + self.tokenizer, + self.args.proportion_empty_prompts, + is_train, + ) + add_text_embeds = pooled_prompt_embeds + + # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_time_ids = torch.tensor([add_time_ids]) + + add_time_ids = add_time_ids.repeat(len(prompt_batch), 1) + + return { + "prompt_ids": prompt_embeds, + "unet_added_conditions": { + "text_embeds": add_text_embeds, + "time_ids": add_time_ids, + }, + } + + # Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt_batch, + text_encoders, + tokenizers, + proportion_empty_prompts, + is_train=True, + ): + prompt_embeds_list = [] + + captions = [] + for caption in prompt_batch: + if random.random() < proportion_empty_prompts: + captions.append("") + elif isinstance(caption, str): + captions.append(caption) + elif isinstance(caption, (list, np.ndarray)): + # take a random caption if there are multiple + captions.append(random.choice(caption) if is_train else caption[0]) + + with torch.no_grad(): + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + text_inputs = tokenizer( + captions, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_embeds = text_encoder( + text_input_ids.to(text_encoder.device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1) + return prompt_embeds, pooled_prompt_embeds + + def __len__(self): + "Denotes the total number of samples" + return len(self.dataset) + + def __getitem__(self, idx): + return self.get_data_dict(idx) + + @torch.no_grad() + def get_data_dict(self, index): + + curr_anno = self.dataset[index] + + mul_pixel_values = [] + mul_ctrl_img = [] + mul_input_ids = [] + mul_input_ids_ins = [] + mul_prompts = [] + mul_patches = [] + + file_name = curr_anno.split("/")[-1] + ".png" + + first_mask = np.array(Image.open(curr_anno + "/0000_mask.png")) + raw_h, raw_w = first_mask.shape[:2] + ratio = raw_w / raw_h + + curr_width = self.args.height * ratio + curr_height = self.args.height + + curr_width = round(curr_width / 8) * 8 + + with open(curr_anno + "/prompt_dict.json") as f: + prompt_dict = json.load(f) + + ins_num = len(prompt_dict.keys()) - 1 + + img = np.zeros((curr_height, curr_width, 3), dtype=np.uint8) + + img_mask = np.zeros_like(np.array(img, dtype=np.uint8))[:, :, 0] + ins_rgb = [] + ins_rgb_id = [] + for p_id in range(ins_num): + mask_loca = curr_anno + f"/{p_id:04d}_mask.png" + ins_mask = np.array( + Image.open(mask_loca).resize((curr_width, curr_height)), dtype=np.uint8 + ) + if len(ins_mask.shape) != 2: + ins_mask = ins_mask[:, :, 0] + + if 'ff' in mask_loca: + img_mask[ins_mask != 255] = p_id + 1 # other mask, fg0~1, bg1 + else: + img_mask[ins_mask != 0] = p_id + 1 # coco mask, fg1 bg0 + + insrgb_loca = curr_anno + f"/{p_id:04d}.png" + if os.path.exists(insrgb_loca): + ins_rgb.append(Image.open(insrgb_loca).convert("RGB")) + ins_rgb_id.append(p_id + 1) + else: + ins_rgb.append(Image.open(mask_loca).convert("RGB")) # fake img patch + + img_mask = Image.fromarray(img_mask) + + instance_captions = [prompt_dict[f"prompt_{p_id}"] for p_id in range(ins_num)] + + patches = [] + + use_patch = 1 + if use_patch: + sel_ins = [] + img_np_raw = np.array(img, dtype=np.uint8) + mask_np_raw = np.array(img_mask, dtype=np.uint8) + # ins_num = min(30, ins_num) + # all_ins = np.arange(1, ins_num + 1).tolist() + all_ins = ins_rgb_id + + for id_ins, curr_ins in enumerate(all_ins): + + cropped_img = ins_rgb[curr_ins - 1] + + cropped_img = self.dino_transforms_noflip(image=np.array(cropped_img)) + + cropped_img = cropped_img["image"] / 255 + patches.append(cropped_img[None]) + sel_ins.append(curr_ins) + + if len(patches) > 0: + patches = torch.cat(patches, dim=0) + else: + patches = torch.zeros( + (0, 3, self.args.patch_size, self.args.patch_size) + ) + # patches = [torch.tensor(sel_ins), patches] + patches = [ + torch.tensor(sel_ins), + patches, + None, + file_name, + (curr_width, curr_height), + ] + + img_mask = img_mask.resize((curr_width // 8, curr_height // 8), Image.NEAREST) + + img = self.image_transforms(img) # [None] + + fea_mask = torch.tensor(np.array(img_mask)) + + anns_caption = [prompt_dict["global_prompt"]] + + if self.args.use_sdxl: + input_ids = self.tokenize_captions_sdxl(anns_caption) + else: + input_ids = self.tokenize_captions(anns_caption) + + mul_pixel_values.append(img[None]) + mul_ctrl_img.append(fea_mask[None]) + mul_input_ids.append(input_ids) + mul_input_ids_ins.append(instance_captions) + mul_prompts.append(anns_caption[0]) + mul_patches.append(patches) + + mul_pixel_values = torch.cat(mul_pixel_values, 0) + mul_ctrl_img = torch.cat(mul_ctrl_img, 0) + mul_input_ids = torch.cat(mul_input_ids, 0) + + data_dict = { + "pixel_values": mul_pixel_values, + "ctrl_img": mul_ctrl_img, + "input_ids": mul_input_ids, + "input_ids_ins": mul_input_ids_ins, + "prompts": mul_prompts, + "patches": mul_patches, + } + + return data_dict + + def get_cat_ids(self, idx): + """Get category distribution of single scene. + + Args: + idx (int): Index of the data_info. + + Returns: + dict[list]: for each category, if the current scene + contains such boxes, store a list containing idx, + otherwise, store empty list. + """ + info = self.dataset[idx] + if self.use_valid_flag: + mask = info["valid_flag"] + gt_names = set(info["gt_names"][mask]) + else: + gt_names = set(info["gt_names"]) + + cat_ids = [] + fore_flag = 0 + for name in gt_names: + if name in self.CLASSES: + cat_ids.append(self.cat2id[name]) + fore_flag = 1 + if fore_flag == 0: + # model background as two objects + for _ in range(120): + cat_ids.append(self.cat2id["background"]) + return cat_ids + + def _get_sample_indices(self): + """Load annotations from ann_file. + + Args: + ann_file (str): Path of the annotation file. + + Returns: + list[dict]: List of annotations after class sampling. + """ + class_sample_idxs = {cat_id: [] for cat_id in self.cat2id.values()} + for idx in range(len(self.dataset)): + sample_cat_ids = self.get_cat_ids(idx) + for cat_id in sample_cat_ids: + class_sample_idxs[cat_id].append(idx) + duplicated_samples = sum([len(v) for _, v in class_sample_idxs.items()]) + class_distribution = { + k: len(v) / duplicated_samples for k, v in class_sample_idxs.items() + } + for key, value in class_sample_idxs.items(): + print(key, len(value)) + + sample_indices = [] + + frac = 1.0 / len(self.CLASSES) + ratios = [frac / v for v in class_distribution.values()] + for cls_inds, ratio in zip(list(class_sample_idxs.values()), ratios): + sample_indices += np.random.choice( + cls_inds, int(len(cls_inds) * ratio) + ).tolist() + return sample_indices