diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..a0777b5
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,9 @@
+MIT License
+
+Copyright (c) 2022 Bahjat Kawar, Jiaming Song, Stefano Ermon, Michael Elad
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..df3d8be
--- /dev/null
+++ b/README.md
@@ -0,0 +1,95 @@
+# JPEG Artifact Correction using Denoising Diffusion Restoration Models
+
+[arXiv](https://arxiv.org/abs/2209.11888) | [PDF](https://arxiv.org/pdf/2209.11888.pdf)
+
+[Bahjat Kawar](https://bahjat-kawar.github.io/)<sup>\* 1</sup>, [Jiaming Song](http://tsong.me)<sup>\* 2</sup>, [Stefano Ermon](http://cs.stanford.edu/~ermon)<sup>3</sup>, [Michael Elad](https://elad.cs.technion.ac.il/)<sup>1</sup>
+<sup>1</sup> Technion, <sup>2</sup> NVIDIA, <sup>3</sup> Stanford University, \*Equal contribution.
+
+
+![DDRM-JPEG demo](assets/ddrm-jpeg-demo.png)
+
+We extend DDRM (Denoising Diffusion Restoration Models) for the problems of JPEG artifact correction and image dequantization.
+
+## Running the Experiments
+The code has been tested on PyTorch 1.8 and PyTorch 1.10. Please refer to `environment.yml` for the conda environment specification used to run the code (it can be created with conda or mamba; an example is shown below). The codebase is heavily based on the [original DDRM codebase](https://github.com/bahjat-kawar/ddrm).
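+
+For example, the environment defined in `environment.yml` (named `ddrm` in that file) can be created and activated as follows:
+```
+conda env create -f environment.yml
+conda activate ddrm
+```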
+
+### Pretrained models
+We use pretrained models from [https://github.com/openai/guided-diffusion](https://github.com/openai/guided-diffusion), [https://github.com/pesser/pytorch_diffusion](https://github.com/pesser/pytorch_diffusion), and [https://github.com/ermongroup/SDEdit](https://github.com/ermongroup/SDEdit).
+
+We use 1,000 images from the ImageNet validation set for comparison with other methods. The list of images is taken from [https://github.com/XingangPan/deep-generative-prior/](https://github.com/XingangPan/deep-generative-prior/).
+
+The models and datasets are placed in the `exp/` folder as follows:
+```bash
+<exp> # a folder named by the argument `--exp` given to main.py
+├── datasets # all dataset files
+│ ├── celeba # all CelebA files
+│ ├── imagenet # all ImageNet files
+│ ├── ood # out of distribution ImageNet images
+│ ├── ood_bedroom # out of distribution bedroom images
+│ ├── ood_cat # out of distribution cat images
+│ └── ood_celeba # out of distribution CelebA images
+├── logs # contains checkpoints and samples produced during training
+│ ├── celeba
+│ │ └── celeba_hq.ckpt # the checkpoint file for CelebA-HQ
+│ ├── diffusion_models_converted
+│   │   └── ema_diffusion_lsun_<category>_model
+│ │ └── model-x.ckpt # the checkpoint file saved at the x-th training iteration
+│ ├── imagenet # ImageNet checkpoint files
+│ │ ├── 256x256_classifier.pt
+│ │ ├── 256x256_diffusion.pt
+│ │ ├── 256x256_diffusion_uncond.pt
+│ │ ├── 512x512_classifier.pt
+│ │ └── 512x512_diffusion.pt
+├── image_samples # contains generated samples
+└── imagenet_val_1k.txt # list of the 1,000 ImageNet validation images used for evaluation
+```
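+
+For example, assuming the guided-diffusion ImageNet checkpoints linked above have been downloaded to the current directory, they can be placed as follows (paths taken from the tree above):
+```
+mkdir -p exp/logs/imagenet
+mv 256x256_diffusion_uncond.pt exp/logs/imagenet/
+```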
+
+### Sampling from the model
+
+The general command to sample from the model is as follows:
+```
+python main.py --ni --config {CONFIG}.yml --doc {DATASET} -i {IMAGE_FOLDER} --timesteps {STEPS} --init_timestep {INIT_T} --eta {ETA} --etaB {ETA_B} --deg {DEGRADATION} --num_avg_samples {NUM_AVG}
+```
+where the options are as follows:
+- `ETA` is the eta hyperparameter in the paper. (default: `1`)
+- `ETA_B` is the eta_b hyperparameter in the paper. (default: `0.4`)
+- `STEPS` controls how many timesteps are used in the process. (default: `20`)
+- `INIT_T` controls the timestep to start sampling from. (default: `300`)
+- `NUM_AVG` is the number of samples per input to average for the final result. (default: `1`)
+- `DEGRADATION` is the type of degradation used. (One of: `quant` for dequantization, or `jpegXX` for JPEG with quality factor `XX`, e.g. `jpeg80`)
+- `CONFIG` is the name of the config file (see `configs/` for a list), including hyperparameters such as batch size and network architectures.
+- `DATASET` is the name of the dataset used, to determine where the checkpoint file is found.
+- `IMAGE_FOLDER` is the name of the folder the resulting images will be placed in. (default: `images`)
+
+For example, to use the default settings from the paper on the ImageNet 256x256 dataset, for JPEG artifact correction with QF=80, averaging 8 samples per input:
+```
+python main.py --ni --config imagenet_256.yml --doc imagenet -i imagenet --deg jpeg80 --num_avg_samples 8
+```
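+Similarly, a hypothetical dequantization run on CelebA-HQ might look like the following (this assumes the CelebA-HQ checkpoint is placed under `logs/celeba/` as in the tree above, so that `--doc celeba` locates it):
+```
+python main.py --ni --config celeba_hq.yml --doc celeba -i celeba_quant --deg quant
+```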
+The generated images are placed in the `<exp>/image_samples/{IMAGE_FOLDER}` folder, where `orig_{id}.png`, `y0_{id}.png`, and `{id}_-1.png` are the original, degraded, and restored images, respectively.
+
+The config files contain a setting (`out_of_dist` under `data`) that controls whether to test on samples from the model's training distribution or on out-of-distribution images.
+
+## References and Acknowledgements
+```
+@inproceedings{kawar2022jpeg,
+ title={JPEG Artifact Correction using Denoising Diffusion Restoration Models},
+ author={Bahjat Kawar and Jiaming Song and Stefano Ermon and Michael Elad},
+ booktitle={Neural Information Processing Systems (NeurIPS) Workshop on Score-Based Methods},
+ year={2022}
+}
+```
+
+```
+@inproceedings{kawar2022denoising,
+ title={Denoising Diffusion Restoration Models},
+ author={Bahjat Kawar and Michael Elad and Stefano Ermon and Jiaming Song},
+ booktitle={Advances in Neural Information Processing Systems},
+ year={2022}
+}
+```
+
+This implementation is based on / inspired by [https://github.com/bahjat-kawar/ddrm](https://github.com/bahjat-kawar/ddrm).
+
+## License
+
+The code is released under the MIT License.
\ No newline at end of file
diff --git a/assets/ddrm-jpeg-demo.png b/assets/ddrm-jpeg-demo.png
new file mode 100644
index 0000000..7557e4e
Binary files /dev/null and b/assets/ddrm-jpeg-demo.png differ
diff --git a/configs/bedroom.yml b/configs/bedroom.yml
new file mode 100644
index 0000000..aa63ed3
--- /dev/null
+++ b/configs/bedroom.yml
@@ -0,0 +1,51 @@
+data:
+ dataset: "LSUN"
+ category: "bedroom"
+ image_size: 256
+ channels: 3
+ logit_transform: false
+ uniform_dequantization: false
+ gaussian_dequantization: false
+ random_flip: true
+ rescaled: true
+ num_workers: 32
+ out_of_dist: true
+
+model:
+ type: "simple"
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult: [1, 1, 2, 2, 4, 4]
+ num_res_blocks: 2
+ attn_resolutions: [16, ]
+ dropout: 0.0
+ var_type: fixedsmall
+ ema_rate: 0.999
+ ema: True
+ resamp_with_conv: True
+
+diffusion:
+ beta_schedule: linear
+ beta_start: 0.0001
+ beta_end: 0.02
+ num_diffusion_timesteps: 1000
+
+training:
+ batch_size: 64
+ n_epochs: 10000
+ n_iters: 5000000
+ snapshot_freq: 5000
+ validation_freq: 2000
+
+sampling:
+ batch_size: 6
+ last_only: True
+
+optim:
+ weight_decay: 0.000
+ optimizer: "Adam"
+ lr: 0.00002
+ beta1: 0.9
+ amsgrad: false
+ eps: 0.00000001
diff --git a/configs/cat.yml b/configs/cat.yml
new file mode 100644
index 0000000..b46a171
--- /dev/null
+++ b/configs/cat.yml
@@ -0,0 +1,51 @@
+data:
+ dataset: "LSUN"
+ category: "cat"
+ image_size: 256
+ channels: 3
+ logit_transform: false
+ uniform_dequantization: false
+ gaussian_dequantization: false
+ random_flip: true
+ rescaled: true
+ num_workers: 32
+ out_of_dist: false
+
+model:
+ type: "simple"
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult: [1, 1, 2, 2, 4, 4]
+ num_res_blocks: 2
+ attn_resolutions: [16, ]
+ dropout: 0.0
+ var_type: fixedsmall
+ ema_rate: 0.999
+ ema: True
+ resamp_with_conv: True
+
+diffusion:
+ beta_schedule: linear
+ beta_start: 0.0001
+ beta_end: 0.02
+ num_diffusion_timesteps: 1000
+
+training:
+ batch_size: 64
+ n_epochs: 10000
+ n_iters: 5000000
+ snapshot_freq: 5000
+ validation_freq: 2000
+
+sampling:
+ batch_size: 32
+ last_only: True
+
+optim:
+ weight_decay: 0.000
+ optimizer: "Adam"
+ lr: 0.00002
+ beta1: 0.9
+ amsgrad: false
+ eps: 0.00000001
diff --git a/configs/celeba_hq.yml b/configs/celeba_hq.yml
new file mode 100644
index 0000000..1eaba48
--- /dev/null
+++ b/configs/celeba_hq.yml
@@ -0,0 +1,36 @@
+data:
+ dataset: "CelebA_HQ"
+ category: ""
+ image_size: 256
+ channels: 3
+ logit_transform: false
+ uniform_dequantization: false
+ gaussian_dequantization: false
+ random_flip: true
+ rescaled: true
+ num_workers: 32
+ out_of_dist: True
+
+model:
+ type: "simple"
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult: [1, 1, 2, 2, 4, 4]
+ num_res_blocks: 2
+ attn_resolutions: [16, ]
+ dropout: 0.0
+ var_type: fixedsmall
+ ema_rate: 0.999
+ ema: True
+ resamp_with_conv: True
+
+diffusion:
+ beta_schedule: linear
+ beta_start: 0.0001
+ beta_end: 0.02
+ num_diffusion_timesteps: 1000
+
+sampling:
+ batch_size: 4
+ last_only: True
\ No newline at end of file
diff --git a/configs/church.yml b/configs/church.yml
new file mode 100644
index 0000000..56e68bc
--- /dev/null
+++ b/configs/church.yml
@@ -0,0 +1,51 @@
+data:
+ dataset: "LSUN"
+ category: "church_outdoor"
+ image_size: 256
+ channels: 3
+ logit_transform: false
+ uniform_dequantization: false
+ gaussian_dequantization: false
+ random_flip: true
+ rescaled: true
+ num_workers: 32
+ out_of_dist: true
+
+model:
+ type: "simple"
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult: [1, 1, 2, 2, 4, 4]
+ num_res_blocks: 2
+ attn_resolutions: [16, ]
+ dropout: 0.0
+ var_type: fixedsmall
+ ema_rate: 0.999
+ ema: True
+ resamp_with_conv: True
+
+diffusion:
+ beta_schedule: linear
+ beta_start: 0.0001
+ beta_end: 0.02
+ num_diffusion_timesteps: 1000
+
+training:
+ batch_size: 64
+ n_epochs: 10000
+ n_iters: 5000000
+ snapshot_freq: 5000
+ validation_freq: 2000
+
+sampling:
+ batch_size: 6
+ last_only: True
+
+optim:
+ weight_decay: 0.000
+ optimizer: "Adam"
+ lr: 0.00002
+ beta1: 0.9
+ amsgrad: false
+ eps: 0.00000001
diff --git a/configs/imagenet_256.yml b/configs/imagenet_256.yml
new file mode 100644
index 0000000..49de992
--- /dev/null
+++ b/configs/imagenet_256.yml
@@ -0,0 +1,43 @@
+data:
+ dataset: "ImageNet"
+ image_size: 256
+ channels: 3
+ logit_transform: false
+ uniform_dequantization: false
+ gaussian_dequantization: false
+ random_flip: true
+ rescaled: true
+ num_workers: 32
+ subset_1k: True
+ out_of_dist: False
+
+model:
+ type: "openai"
+ in_channels: 3
+ out_channels: 3
+ num_channels: 256
+ num_heads: 4
+ num_res_blocks: 2
+ attention_resolutions: "32,16,8"
+ dropout: 0.0
+ resamp_with_conv: True
+ learn_sigma: True
+ use_scale_shift_norm: true
+ use_fp16: true
+ resblock_updown: true
+ num_heads_upsample: -1
+ var_type: 'fixedsmall'
+ num_head_channels: 64
+ image_size: 256
+ class_cond: false
+ use_new_attention_order: false
+
+diffusion:
+ beta_schedule: linear
+ beta_start: 0.0001
+ beta_end: 0.02
+ num_diffusion_timesteps: 1000
+
+sampling:
+ batch_size: 8
+ last_only: True
diff --git a/configs/imagenet_256_cc.yml b/configs/imagenet_256_cc.yml
new file mode 100644
index 0000000..12bc54f
--- /dev/null
+++ b/configs/imagenet_256_cc.yml
@@ -0,0 +1,55 @@
+data:
+ dataset: "ImageNet"
+ image_size: 256
+ channels: 3
+ logit_transform: false
+ uniform_dequantization: false
+ gaussian_dequantization: false
+ random_flip: true
+ rescaled: true
+ num_workers: 32
+ subset_1k: False
+ out_of_dist: False
+
+model:
+ type: "openai"
+ in_channels: 3
+ out_channels: 3
+ num_channels: 256
+ num_heads: 4
+ num_res_blocks: 2
+ attention_resolutions: "32,16,8"
+ dropout: 0.0
+ resamp_with_conv: True
+ learn_sigma: True
+ use_scale_shift_norm: true
+ use_fp16: true
+ resblock_updown: true
+ num_heads_upsample: -1
+ var_type: 'fixedsmall'
+ num_head_channels: 64
+ image_size: 256
+ class_cond: True
+ use_new_attention_order: false
+
+classifier:
+ image_size: 256
+ classifier_attention_resolutions: "32,16,8"
+ classifier_depth: 2
+ classifier_pool: "attention"
+ classifier_resblock_updown: True
+ classifier_width: 128
+ classifier_use_scale_shift_norm: True
+ classifier_scale: 1.0
+ classifier_use_fp16: True
+
+
+diffusion:
+ beta_schedule: linear
+ beta_start: 0.0001
+ beta_end: 0.02
+ num_diffusion_timesteps: 1000
+
+sampling:
+ batch_size: 8
+ last_only: True
\ No newline at end of file
diff --git a/configs/imagenet_512_cc.yml b/configs/imagenet_512_cc.yml
new file mode 100644
index 0000000..e5b3b22
--- /dev/null
+++ b/configs/imagenet_512_cc.yml
@@ -0,0 +1,55 @@
+data:
+ dataset: "ImageNet"
+ image_size: 512
+ channels: 3
+ logit_transform: false
+ uniform_dequantization: false
+ gaussian_dequantization: false
+ random_flip: true
+ rescaled: true
+ num_workers: 32
+ subset_1k: False
+ out_of_dist: False
+
+model:
+ type: "openai"
+ in_channels: 3
+ out_channels: 3
+ num_channels: 256
+ num_heads: 4
+ num_res_blocks: 2
+ attention_resolutions: "32,16,8"
+ dropout: 0.0
+ resamp_with_conv: True
+ learn_sigma: True
+ use_scale_shift_norm: true
+ use_fp16: false
+ resblock_updown: true
+ num_heads_upsample: -1
+ var_type: 'fixedsmall'
+ num_head_channels: 64
+ image_size: 512
+ class_cond: True
+ use_new_attention_order: false
+
+classifier:
+ image_size: 512
+ classifier_attention_resolutions: "32,16,8"
+ classifier_depth: 2
+ classifier_pool: "attention"
+ classifier_resblock_updown: True
+ classifier_width: 128
+ classifier_use_scale_shift_norm: True
+ classifier_scale: 1.0
+ classifier_use_fp16: false
+
+
+diffusion:
+ beta_schedule: linear
+ beta_start: 0.0001
+ beta_end: 0.02
+ num_diffusion_timesteps: 1000
+
+sampling:
+ batch_size: 1
+ last_only: True
\ No newline at end of file
diff --git a/datasets/__init__.py b/datasets/__init__.py
new file mode 100644
index 0000000..5e8fafd
--- /dev/null
+++ b/datasets/__init__.py
@@ -0,0 +1,223 @@
+import os
+import torch
+import numbers
+import torchvision.transforms as transforms
+import torchvision.transforms.functional as F
+from datasets.celeba import CelebA
+from datasets.lsun import LSUN
+from torch.utils.data import Subset
+import numpy as np
+import torchvision
+from PIL import Image
+from functools import partial
+
+class Crop(object):
+ def __init__(self, x1, x2, y1, y2):
+ self.x1 = x1
+ self.x2 = x2
+ self.y1 = y1
+ self.y2 = y2
+
+ def __call__(self, img):
+ return F.crop(img, self.x1, self.y1, self.x2 - self.x1, self.y2 - self.y1)
+
+ def __repr__(self):
+ return self.__class__.__name__ + "(x1={}, x2={}, y1={}, y2={})".format(
+ self.x1, self.x2, self.y1, self.y2
+ )
+
+def center_crop_arr(pil_image, image_size = 256):
+ # Imported from openai/guided-diffusion
+ while min(*pil_image.size) >= 2 * image_size:
+ pil_image = pil_image.resize(
+ tuple(x // 2 for x in pil_image.size), resample=Image.BOX
+ )
+
+ scale = image_size / min(*pil_image.size)
+ pil_image = pil_image.resize(
+ tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
+ )
+
+ arr = np.array(pil_image)
+ crop_y = (arr.shape[0] - image_size) // 2
+ crop_x = (arr.shape[1] - image_size) // 2
+ return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size]
+
+
+def get_dataset(args, config):
+ if config.data.random_flip is False:
+ tran_transform = test_transform = transforms.Compose(
+ [transforms.Resize(config.data.image_size), transforms.ToTensor()]
+ )
+ else:
+ tran_transform = transforms.Compose(
+ [
+ transforms.Resize(config.data.image_size),
+ transforms.RandomHorizontalFlip(p=0.5),
+ transforms.ToTensor(),
+ ]
+ )
+ test_transform = transforms.Compose(
+ [transforms.Resize(config.data.image_size), transforms.ToTensor()]
+ )
+
+ if config.data.dataset == "CELEBA":
+ cx = 89
+ cy = 121
+ x1 = cy - 64
+ x2 = cy + 64
+ y1 = cx - 64
+ y2 = cx + 64
+ if config.data.random_flip:
+ dataset = CelebA(
+ root=os.path.join(args.exp, "datasets", "celeba"),
+ split="train",
+ transform=transforms.Compose(
+ [
+ Crop(x1, x2, y1, y2),
+ transforms.Resize(config.data.image_size),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ ]
+ ),
+ download=True,
+ )
+ else:
+ dataset = CelebA(
+ root=os.path.join(args.exp, "datasets", "celeba"),
+ split="train",
+ transform=transforms.Compose(
+ [
+ Crop(x1, x2, y1, y2),
+ transforms.Resize(config.data.image_size),
+ transforms.ToTensor(),
+ ]
+ ),
+ download=True,
+ )
+
+ test_dataset = CelebA(
+ root=os.path.join(args.exp, "datasets", "celeba"),
+ split="test",
+ transform=transforms.Compose(
+ [
+ Crop(x1, x2, y1, y2),
+ transforms.Resize(config.data.image_size),
+ transforms.ToTensor(),
+ ]
+ ),
+ download=True,
+ )
+
+ elif config.data.dataset == "LSUN":
+ if config.data.out_of_dist:
+ dataset = torchvision.datasets.ImageFolder(
+ os.path.join(args.exp, 'datasets', "ood_{}".format(config.data.category)),
+ transform=transforms.Compose([partial(center_crop_arr, image_size=config.data.image_size),
+ transforms.ToTensor()])
+ )
+ test_dataset = dataset
+ else:
+ train_folder = "{}_train".format(config.data.category)
+ val_folder = "{}_val".format(config.data.category)
+ test_dataset = LSUN(
+ root=os.path.join(args.exp, "datasets", "lsun"),
+ classes=[val_folder],
+ transform=transforms.Compose(
+ [
+ transforms.Resize(config.data.image_size),
+ transforms.CenterCrop(config.data.image_size),
+ transforms.ToTensor(),
+ ]
+ )
+ )
+ dataset = test_dataset
+
+ elif config.data.dataset == "CelebA_HQ" or config.data.dataset == 'FFHQ':
+ if config.data.out_of_dist:
+ dataset = torchvision.datasets.ImageFolder(
+ os.path.join(args.exp, "datasets", "ood_celeba"),
+ transform=transforms.Compose([transforms.Resize([config.data.image_size, config.data.image_size]),
+ transforms.ToTensor()])
+ )
+ test_dataset = dataset
+ else:
+ dataset = torchvision.datasets.ImageFolder(
+ os.path.join(args.exp, "datasets", "celeba_hq"),
+ transform=transforms.Compose([transforms.Resize([config.data.image_size, config.data.image_size]),
+ transforms.ToTensor()])
+ )
+ num_items = len(dataset)
+ indices = list(range(num_items))
+ random_state = np.random.get_state()
+ np.random.seed(2019)
+ np.random.shuffle(indices)
+ np.random.set_state(random_state)
+ train_indices, test_indices = (
+ indices[: int(num_items * 0.9)],
+ indices[int(num_items * 0.9) :],
+ )
+ test_dataset = Subset(dataset, test_indices)
+
+ elif config.data.dataset == 'ImageNet':
+ # only use validation dataset here
+
+ if config.data.subset_1k:
+ from datasets.imagenet_subset import ImageDataset
+ dataset = ImageDataset(os.path.join(args.exp, 'datasets', 'imagenet', 'imagenet'),
+ os.path.join(args.exp, 'imagenet_val_1k.txt'),
+ image_size=config.data.image_size,
+ normalize=False)
+ test_dataset = dataset
+ elif config.data.out_of_dist:
+ dataset = torchvision.datasets.ImageFolder(
+ os.path.join(args.exp, 'datasets', 'ood'),
+ transform=transforms.Compose([partial(center_crop_arr, image_size=config.data.image_size),
+ transforms.ToTensor()])
+ )
+ test_dataset = dataset
+ else:
+ dataset = torchvision.datasets.ImageNet(
+ os.path.join(args.exp, 'datasets', 'imagenet'), split='val',
+ transform=transforms.Compose([partial(center_crop_arr, image_size=config.data.image_size),
+ transforms.ToTensor()])
+ )
+ test_dataset = dataset
+ else:
+ dataset, test_dataset = None, None
+
+ return dataset, test_dataset
+
+
+def logit_transform(image, lam=1e-6):
+ image = lam + (1 - 2 * lam) * image
+ return torch.log(image) - torch.log1p(-image)
+
+
+def data_transform(config, X):
+ if config.data.uniform_dequantization:
+ X = X / 256.0 * 255.0 + torch.rand_like(X) / 256.0
+ if config.data.gaussian_dequantization:
+ X = X + torch.randn_like(X) * 0.01
+
+ if config.data.rescaled:
+ X = 2 * X - 1.0
+ elif config.data.logit_transform:
+ X = logit_transform(X)
+
+ if hasattr(config, "image_mean"):
+ return X - config.image_mean.to(X.device)[None, ...]
+
+ return X
+
+
+def inverse_data_transform(config, X):
+ if hasattr(config, "image_mean"):
+ X = X + config.image_mean.to(X.device)[None, ...]
+
+ if config.data.logit_transform:
+ X = torch.sigmoid(X)
+ elif config.data.rescaled:
+ X = (X + 1.0) / 2.0
+
+ return torch.clamp(X, 0.0, 1.0)
diff --git a/datasets/__pycache__/__init__.cpython-37.pyc b/datasets/__pycache__/__init__.cpython-37.pyc
new file mode 100644
index 0000000..c0cb5a0
Binary files /dev/null and b/datasets/__pycache__/__init__.cpython-37.pyc differ
diff --git a/datasets/__pycache__/__init__.cpython-38.pyc b/datasets/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000..67a54a2
Binary files /dev/null and b/datasets/__pycache__/__init__.cpython-38.pyc differ
diff --git a/datasets/__pycache__/__init__.cpython-39.pyc b/datasets/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000..1df55d5
Binary files /dev/null and b/datasets/__pycache__/__init__.cpython-39.pyc differ
diff --git a/datasets/__pycache__/celeba.cpython-37.pyc b/datasets/__pycache__/celeba.cpython-37.pyc
new file mode 100644
index 0000000..cbc878c
Binary files /dev/null and b/datasets/__pycache__/celeba.cpython-37.pyc differ
diff --git a/datasets/__pycache__/celeba.cpython-38.pyc b/datasets/__pycache__/celeba.cpython-38.pyc
new file mode 100644
index 0000000..293b074
Binary files /dev/null and b/datasets/__pycache__/celeba.cpython-38.pyc differ
diff --git a/datasets/__pycache__/celeba.cpython-39.pyc b/datasets/__pycache__/celeba.cpython-39.pyc
new file mode 100644
index 0000000..c5cf2d8
Binary files /dev/null and b/datasets/__pycache__/celeba.cpython-39.pyc differ
diff --git a/datasets/__pycache__/imagenet_subset.cpython-38.pyc b/datasets/__pycache__/imagenet_subset.cpython-38.pyc
new file mode 100644
index 0000000..aad4c4d
Binary files /dev/null and b/datasets/__pycache__/imagenet_subset.cpython-38.pyc differ
diff --git a/datasets/__pycache__/imagenet_subset.cpython-39.pyc b/datasets/__pycache__/imagenet_subset.cpython-39.pyc
new file mode 100644
index 0000000..5217f35
Binary files /dev/null and b/datasets/__pycache__/imagenet_subset.cpython-39.pyc differ
diff --git a/datasets/__pycache__/lsun.cpython-37.pyc b/datasets/__pycache__/lsun.cpython-37.pyc
new file mode 100644
index 0000000..9b530b9
Binary files /dev/null and b/datasets/__pycache__/lsun.cpython-37.pyc differ
diff --git a/datasets/__pycache__/lsun.cpython-38.pyc b/datasets/__pycache__/lsun.cpython-38.pyc
new file mode 100644
index 0000000..37caafb
Binary files /dev/null and b/datasets/__pycache__/lsun.cpython-38.pyc differ
diff --git a/datasets/__pycache__/lsun.cpython-39.pyc b/datasets/__pycache__/lsun.cpython-39.pyc
new file mode 100644
index 0000000..0b03166
Binary files /dev/null and b/datasets/__pycache__/lsun.cpython-39.pyc differ
diff --git a/datasets/__pycache__/utils.cpython-37.pyc b/datasets/__pycache__/utils.cpython-37.pyc
new file mode 100644
index 0000000..03ef98a
Binary files /dev/null and b/datasets/__pycache__/utils.cpython-37.pyc differ
diff --git a/datasets/__pycache__/utils.cpython-38.pyc b/datasets/__pycache__/utils.cpython-38.pyc
new file mode 100644
index 0000000..6b65080
Binary files /dev/null and b/datasets/__pycache__/utils.cpython-38.pyc differ
diff --git a/datasets/__pycache__/utils.cpython-39.pyc b/datasets/__pycache__/utils.cpython-39.pyc
new file mode 100644
index 0000000..cbd7487
Binary files /dev/null and b/datasets/__pycache__/utils.cpython-39.pyc differ
diff --git a/datasets/__pycache__/vision.cpython-37.pyc b/datasets/__pycache__/vision.cpython-37.pyc
new file mode 100644
index 0000000..87ba15f
Binary files /dev/null and b/datasets/__pycache__/vision.cpython-37.pyc differ
diff --git a/datasets/__pycache__/vision.cpython-38.pyc b/datasets/__pycache__/vision.cpython-38.pyc
new file mode 100644
index 0000000..4104b7d
Binary files /dev/null and b/datasets/__pycache__/vision.cpython-38.pyc differ
diff --git a/datasets/__pycache__/vision.cpython-39.pyc b/datasets/__pycache__/vision.cpython-39.pyc
new file mode 100644
index 0000000..d56ed2d
Binary files /dev/null and b/datasets/__pycache__/vision.cpython-39.pyc differ
diff --git a/datasets/celeba.py b/datasets/celeba.py
new file mode 100644
index 0000000..1c466dc
--- /dev/null
+++ b/datasets/celeba.py
@@ -0,0 +1,163 @@
+import torch
+import os
+import PIL
+from .vision import VisionDataset
+from .utils import download_file_from_google_drive, check_integrity
+
+
+class CelebA(VisionDataset):
+    """`Large-scale CelebFaces Attributes (CelebA) Dataset <http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html>`_ Dataset.
+
+ Args:
+ root (string): Root directory where images are downloaded to.
+ split (string): One of {'train', 'valid', 'test'}.
+ Accordingly dataset is selected.
+ target_type (string or list, optional): Type of target to use, ``attr``, ``identity``, ``bbox``,
+ or ``landmarks``. Can also be a list to output a tuple with all specified target types.
+ The targets represent:
+ ``attr`` (np.array shape=(40,) dtype=int): binary (0, 1) labels for attributes
+ ``identity`` (int): label for each person (data points with the same identity are the same person)
+ ``bbox`` (np.array shape=(4,) dtype=int): bounding box (x, y, width, height)
+ ``landmarks`` (np.array shape=(10,) dtype=int): landmark points (lefteye_x, lefteye_y, righteye_x,
+ righteye_y, nose_x, nose_y, leftmouth_x, leftmouth_y, rightmouth_x, rightmouth_y)
+ Defaults to ``attr``.
+        transform (callable, optional): A function/transform that takes in a PIL image
+ and returns a transformed version. E.g, ``transforms.ToTensor``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+ download (bool, optional): If true, downloads the dataset from the internet and
+ puts it in root directory. If dataset is already downloaded, it is not
+ downloaded again.
+ """
+
+ base_folder = "celeba"
+    # There currently does not appear to be an easy way to extract 7z in python (without introducing additional
+ # dependencies). The "in-the-wild" (not aligned+cropped) images are only in 7z, so they are not available
+ # right now.
+ file_list = [
+ # File ID MD5 Hash Filename
+ ("0B7EVK8r0v71pZjFTYXZWM3FlRnM", "00d2c5bc6d35e252742224ab0c1e8fcb", "img_align_celeba.zip"),
+ # ("0B7EVK8r0v71pbWNEUjJKdDQ3dGc", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_align_celeba_png.7z"),
+ # ("0B7EVK8r0v71peklHb0pGdDl6R28", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_celeba.7z"),
+ ("0B7EVK8r0v71pblRyaVFSWGxPY0U", "75e246fa4810816ffd6ee81facbd244c", "list_attr_celeba.txt"),
+ ("1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS", "32bd1bd63d3c78cd57e08160ec5ed1e2", "identity_CelebA.txt"),
+ ("0B7EVK8r0v71pbThiMVRxWXZ4dU0", "00566efa6fedff7a56946cd1c10f1c16", "list_bbox_celeba.txt"),
+ ("0B7EVK8r0v71pd0FJY3Blby1HUTQ", "cc24ecafdb5b50baae59b03474781f8c", "list_landmarks_align_celeba.txt"),
+ # ("0B7EVK8r0v71pTzJIdlJWdHczRlU", "063ee6ddb681f96bc9ca28c6febb9d1a", "list_landmarks_celeba.txt"),
+ ("0B7EVK8r0v71pY0NSMzRuSXJEVkk", "d32c9cbf5e040fd4025c592c306e6668", "list_eval_partition.txt"),
+ ]
+
+ def __init__(self, root,
+ split="train",
+ target_type="attr",
+ transform=None, target_transform=None,
+ download=False):
+ import pandas
+ super(CelebA, self).__init__(root)
+ self.split = split
+ if isinstance(target_type, list):
+ self.target_type = target_type
+ else:
+ self.target_type = [target_type]
+ self.transform = transform
+ self.target_transform = target_transform
+
+ if download:
+ self.download()
+
+ if not self._check_integrity():
+ raise RuntimeError('Dataset not found or corrupted.' +
+ ' You can use download=True to download it')
+
+ self.transform = transform
+ self.target_transform = target_transform
+
+ if split.lower() == "train":
+ split = 0
+ elif split.lower() == "valid":
+ split = 1
+ elif split.lower() == "test":
+ split = 2
+ else:
+ raise ValueError('Wrong split entered! Please use split="train" '
+ 'or split="valid" or split="test"')
+
+ with open(os.path.join(self.root, self.base_folder, "list_eval_partition.txt"), "r") as f:
+ splits = pandas.read_csv(f, delim_whitespace=True, header=None, index_col=0)
+
+ with open(os.path.join(self.root, self.base_folder, "identity_CelebA.txt"), "r") as f:
+ self.identity = pandas.read_csv(f, delim_whitespace=True, header=None, index_col=0)
+
+ with open(os.path.join(self.root, self.base_folder, "list_bbox_celeba.txt"), "r") as f:
+ self.bbox = pandas.read_csv(f, delim_whitespace=True, header=1, index_col=0)
+
+ with open(os.path.join(self.root, self.base_folder, "list_landmarks_align_celeba.txt"), "r") as f:
+ self.landmarks_align = pandas.read_csv(f, delim_whitespace=True, header=1)
+
+ with open(os.path.join(self.root, self.base_folder, "list_attr_celeba.txt"), "r") as f:
+ self.attr = pandas.read_csv(f, delim_whitespace=True, header=1)
+
+ mask = (splits[1] == split)
+ self.filename = splits[mask].index.values
+ self.identity = torch.as_tensor(self.identity[mask].values)
+ self.bbox = torch.as_tensor(self.bbox[mask].values)
+ self.landmarks_align = torch.as_tensor(self.landmarks_align[mask].values)
+ self.attr = torch.as_tensor(self.attr[mask].values)
+ self.attr = (self.attr + 1) // 2 # map from {-1, 1} to {0, 1}
+
+ def _check_integrity(self):
+ for (_, md5, filename) in self.file_list:
+ fpath = os.path.join(self.root, self.base_folder, filename)
+ _, ext = os.path.splitext(filename)
+ # Allow original archive to be deleted (zip and 7z)
+ # Only need the extracted images
+ if ext not in [".zip", ".7z"] and not check_integrity(fpath, md5):
+ return False
+
+ # Should check a hash of the images
+ return os.path.isdir(os.path.join(self.root, self.base_folder, "img_align_celeba"))
+
+ def download(self):
+ import zipfile
+
+ if self._check_integrity():
+ print('Files already downloaded and verified')
+ return
+
+ for (file_id, md5, filename) in self.file_list:
+ download_file_from_google_drive(file_id, os.path.join(self.root, self.base_folder), filename, md5)
+
+ with zipfile.ZipFile(os.path.join(self.root, self.base_folder, "img_align_celeba.zip"), "r") as f:
+ f.extractall(os.path.join(self.root, self.base_folder))
+
+ def __getitem__(self, index):
+ X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index]))
+
+ target = []
+ for t in self.target_type:
+ if t == "attr":
+ target.append(self.attr[index, :])
+ elif t == "identity":
+ target.append(self.identity[index, 0])
+ elif t == "bbox":
+ target.append(self.bbox[index, :])
+ elif t == "landmarks":
+ target.append(self.landmarks_align[index, :])
+ else:
+ raise ValueError("Target type \"{}\" is not recognized.".format(t))
+ target = tuple(target) if len(target) > 1 else target[0]
+
+ if self.transform is not None:
+ X = self.transform(X)
+
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ return X, target
+
+ def __len__(self):
+ return len(self.attr)
+
+ def extra_repr(self):
+ lines = ["Target type: {target_type}", "Split: {split}"]
+ return '\n'.join(lines).format(**self.__dict__)
diff --git a/datasets/imagenet_subset.py b/datasets/imagenet_subset.py
new file mode 100644
index 0000000..bcb7eec
--- /dev/null
+++ b/datasets/imagenet_subset.py
@@ -0,0 +1,102 @@
+import torch.utils.data as data
+import torchvision.transforms as transforms
+from PIL import Image
+
+class CenterCropLongEdge(object):
+ """Crops the given PIL Image on the long edge.
+ Args:
+ size (sequence or int): Desired output size of the crop. If size is an
+ int instead of sequence like (h, w), a square crop (size, size) is
+ made.
+ """
+
+ def __call__(self, img):
+ """
+ Args:
+ img (PIL Image): Image to be cropped.
+ Returns:
+ PIL Image: Cropped image.
+ """
+ return transforms.functional.center_crop(img, min(img.size))
+
+ def __repr__(self):
+ return self.__class__.__name__
+
+def pil_loader(path):
+ # open path as file to avoid ResourceWarning
+ # (https://github.com/python-pillow/Pillow/issues/835)
+ with open(path, 'rb') as f:
+ img = Image.open(f)
+ return img.convert('RGB')
+
+
+def accimage_loader(path):
+ import accimage
+ try:
+ return accimage.Image(path)
+ except IOError:
+ # Potentially a decoding problem, fall back to PIL.Image
+ return pil_loader(path)
+
+def default_loader(path):
+ from torchvision import get_image_backend
+ if get_image_backend() == 'accimage':
+ return accimage_loader(path)
+ else:
+ return pil_loader(path)
+
+class ImageDataset(data.Dataset):
+
+ def __init__(self,
+ root_dir,
+ meta_file,
+ transform=None,
+ image_size=128,
+ normalize=True):
+ self.root_dir = root_dir
+ if transform is not None:
+ self.transform = transform
+ else:
+ norm_mean = [0.5, 0.5, 0.5]
+ norm_std = [0.5, 0.5, 0.5]
+ if normalize:
+ self.transform = transforms.Compose([
+ CenterCropLongEdge(),
+ transforms.Resize(image_size),
+ transforms.ToTensor(),
+ transforms.Normalize(norm_mean, norm_std)
+ ])
+ else:
+ self.transform = transforms.Compose([
+ CenterCropLongEdge(),
+ transforms.Resize(image_size),
+ transforms.ToTensor()
+ ])
+ with open(meta_file) as f:
+ lines = f.readlines()
+ print("building dataset from %s" % meta_file)
+ self.num = len(lines)
+ self.metas = []
+ self.classifier = None
+ suffix = ".jpeg"
+ for line in lines:
+ line_split = line.rstrip().split()
+ if len(line_split) == 2:
+ self.metas.append((line_split[0] + suffix, int(line_split[1])))
+ else:
+ self.metas.append((line_split[0] + suffix, -1))
+ print("read meta done")
+
+ def __len__(self):
+ return self.num
+
+ def __getitem__(self, idx):
+ filename = self.root_dir + '/' + self.metas[idx][0]
+ cls = self.metas[idx][1]
+ img = default_loader(filename)
+
+ # transform
+ if self.transform is not None:
+ img = self.transform(img)
+
+ return img, cls #, self.metas[idx][0]
\ No newline at end of file
diff --git a/datasets/lsun.py b/datasets/lsun.py
new file mode 100644
index 0000000..31dd381
--- /dev/null
+++ b/datasets/lsun.py
@@ -0,0 +1,176 @@
+from .vision import VisionDataset
+from PIL import Image
+import os
+import os.path
+import io
+from collections.abc import Iterable
+import pickle
+from torchvision.datasets.utils import verify_str_arg, iterable_to_str
+
+
+class LSUNClass(VisionDataset):
+ def __init__(self, root, transform=None, target_transform=None):
+ import lmdb
+
+ super(LSUNClass, self).__init__(
+ root, transform=transform, target_transform=target_transform
+ )
+
+ self.env = lmdb.open(
+ root,
+ max_readers=1,
+ readonly=True,
+ lock=False,
+ readahead=False,
+ meminit=False,
+ )
+ with self.env.begin(write=False) as txn:
+ self.length = txn.stat()["entries"]
+ root_split = root.split("/")
+ cache_file = os.path.join("/".join(root_split[:-1]), f"_cache_{root_split[-1]}")
+ if os.path.isfile(cache_file):
+ self.keys = pickle.load(open(cache_file, "rb"))
+ else:
+ with self.env.begin(write=False) as txn:
+ self.keys = [key for key, _ in txn.cursor()]
+ pickle.dump(self.keys, open(cache_file, "wb"))
+
+ def __getitem__(self, index):
+ img, target = None, None
+ env = self.env
+ with env.begin(write=False) as txn:
+ imgbuf = txn.get(self.keys[index])
+
+ buf = io.BytesIO()
+ buf.write(imgbuf)
+ buf.seek(0)
+ img = Image.open(buf).convert("RGB")
+
+ if self.transform is not None:
+ img = self.transform(img)
+
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ return img, target
+
+ def __len__(self):
+ return self.length
+
+
+class LSUN(VisionDataset):
+ """
+    `LSUN <https://www.yf.io/p/lsun>`_ dataset.
+
+ Args:
+ root (string): Root directory for the database files.
+ classes (string or list): One of {'train', 'val', 'test'} or a list of
+            categories to load. e.g. ['bedroom_train', 'church_outdoor_train'].
+        transform (callable, optional): A function/transform that takes in a PIL image
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+ """
+
+ def __init__(self, root, classes="train", transform=None, target_transform=None):
+ super(LSUN, self).__init__(
+ root, transform=transform, target_transform=target_transform
+ )
+ self.classes = self._verify_classes(classes)
+
+ # for each class, create an LSUNClassDataset
+ self.dbs = []
+ for c in self.classes:
+ self.dbs.append(
+ LSUNClass(root=root + "/" + c + "_lmdb", transform=transform)
+ )
+
+ self.indices = []
+ count = 0
+ for db in self.dbs:
+ count += len(db)
+ self.indices.append(count)
+
+ self.length = count
+
+ def _verify_classes(self, classes):
+ categories = [
+ "bedroom",
+ "bridge",
+ "church_outdoor",
+ "classroom",
+ "conference_room",
+ "dining_room",
+ "kitchen",
+ "living_room",
+ "restaurant",
+ "tower",
+ "cat",
+ ]
+ dset_opts = ["train", "val", "test"]
+
+ try:
+ verify_str_arg(classes, "classes", dset_opts)
+ if classes == "test":
+ classes = [classes]
+ else:
+ classes = [c + "_" + classes for c in categories]
+ except ValueError:
+ if not isinstance(classes, Iterable):
+ msg = (
+ "Expected type str or Iterable for argument classes, "
+ "but got type {}."
+ )
+ raise ValueError(msg.format(type(classes)))
+
+ classes = list(classes)
+ msg_fmtstr = (
+ "Expected type str for elements in argument classes, "
+ "but got type {}."
+ )
+ for c in classes:
+ verify_str_arg(c, custom_msg=msg_fmtstr.format(type(c)))
+ c_short = c.split("_")
+ category, dset_opt = "_".join(c_short[:-1]), c_short[-1]
+
+ msg_fmtstr = "Unknown value '{}' for {}. Valid values are {{{}}}."
+ msg = msg_fmtstr.format(
+ category, "LSUN class", iterable_to_str(categories)
+ )
+ verify_str_arg(category, valid_values=categories, custom_msg=msg)
+
+ msg = msg_fmtstr.format(dset_opt, "postfix", iterable_to_str(dset_opts))
+ verify_str_arg(dset_opt, valid_values=dset_opts, custom_msg=msg)
+
+ return classes
+
+ def __getitem__(self, index):
+ """
+ Args:
+ index (int): Index
+
+ Returns:
+ tuple: Tuple (image, target) where target is the index of the target category.
+ """
+ target = 0
+ sub = 0
+ for ind in self.indices:
+ if index < ind:
+ break
+ target += 1
+ sub = ind
+
+ db = self.dbs[target]
+ index = index - sub
+
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ img, _ = db[index]
+ return img, target
+
+ def __len__(self):
+ return self.length
+
+ def extra_repr(self):
+ return "Classes: {classes}".format(**self.__dict__)
diff --git a/datasets/utils.py b/datasets/utils.py
new file mode 100644
index 0000000..1d02194
--- /dev/null
+++ b/datasets/utils.py
@@ -0,0 +1,186 @@
+import os
+import os.path
+import hashlib
+import errno
+from torch.utils.model_zoo import tqdm
+
+
+def gen_bar_updater():
+ pbar = tqdm(total=None)
+
+ def bar_update(count, block_size, total_size):
+ if pbar.total is None and total_size:
+ pbar.total = total_size
+ progress_bytes = count * block_size
+ pbar.update(progress_bytes - pbar.n)
+
+ return bar_update
+
+
+def check_integrity(fpath, md5=None):
+ if md5 is None:
+ return True
+ if not os.path.isfile(fpath):
+ return False
+ md5o = hashlib.md5()
+ with open(fpath, 'rb') as f:
+ # read in 1MB chunks
+ for chunk in iter(lambda: f.read(1024 * 1024), b''):
+ md5o.update(chunk)
+ md5c = md5o.hexdigest()
+ if md5c != md5:
+ return False
+ return True
+
+
+def makedir_exist_ok(dirpath):
+ """
+ Python2 support for os.makedirs(.., exist_ok=True)
+ """
+ try:
+ os.makedirs(dirpath)
+ except OSError as e:
+ if e.errno == errno.EEXIST:
+ pass
+ else:
+ raise
+
+
+def download_url(url, root, filename=None, md5=None):
+ """Download a file from a url and place it in root.
+
+ Args:
+ url (str): URL to download file from
+ root (str): Directory to place downloaded file in
+ filename (str, optional): Name to save the file under. If None, use the basename of the URL
+ md5 (str, optional): MD5 checksum of the download. If None, do not check
+ """
+ from six.moves import urllib
+
+ root = os.path.expanduser(root)
+ if not filename:
+ filename = os.path.basename(url)
+ fpath = os.path.join(root, filename)
+
+ makedir_exist_ok(root)
+
+ # downloads file
+ if os.path.isfile(fpath) and check_integrity(fpath, md5):
+ print('Using downloaded and verified file: ' + fpath)
+ else:
+ try:
+ print('Downloading ' + url + ' to ' + fpath)
+ urllib.request.urlretrieve(
+ url, fpath,
+ reporthook=gen_bar_updater()
+ )
+ except OSError:
+ if url[:5] == 'https':
+ url = url.replace('https:', 'http:')
+ print('Failed download. Trying https -> http instead.'
+ ' Downloading ' + url + ' to ' + fpath)
+ urllib.request.urlretrieve(
+ url, fpath,
+ reporthook=gen_bar_updater()
+ )
+
+
+def list_dir(root, prefix=False):
+ """List all directories at a given root
+
+ Args:
+ root (str): Path to directory whose folders need to be listed
+ prefix (bool, optional): If true, prepends the path to each result, otherwise
+ only returns the name of the directories found
+ """
+ root = os.path.expanduser(root)
+ directories = list(
+ filter(
+ lambda p: os.path.isdir(os.path.join(root, p)),
+ os.listdir(root)
+ )
+ )
+
+ if prefix is True:
+ directories = [os.path.join(root, d) for d in directories]
+
+ return directories
+
+
+def list_files(root, suffix, prefix=False):
+ """List all files ending with a suffix at a given root
+
+ Args:
+ root (str): Path to directory whose folders need to be listed
+ suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
+ It uses the Python "str.endswith" method and is passed directly
+ prefix (bool, optional): If true, prepends the path to each result, otherwise
+ only returns the name of the files found
+ """
+ root = os.path.expanduser(root)
+ files = list(
+ filter(
+ lambda p: os.path.isfile(os.path.join(root, p)) and p.endswith(suffix),
+ os.listdir(root)
+ )
+ )
+
+ if prefix is True:
+ files = [os.path.join(root, d) for d in files]
+
+ return files
+
+
+def download_file_from_google_drive(file_id, root, filename=None, md5=None):
+    """Download a file from Google Drive and place it in root.
+
+ Args:
+ file_id (str): id of file to be downloaded
+ root (str): Directory to place downloaded file in
+ filename (str, optional): Name to save the file under. If None, use the id of the file.
+ md5 (str, optional): MD5 checksum of the download. If None, do not check
+ """
+ # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url
+ import requests
+ url = "https://docs.google.com/uc?export=download"
+
+ root = os.path.expanduser(root)
+ if not filename:
+ filename = file_id
+ fpath = os.path.join(root, filename)
+
+ makedir_exist_ok(root)
+
+ if os.path.isfile(fpath) and check_integrity(fpath, md5):
+ print('Using downloaded and verified file: ' + fpath)
+ else:
+ session = requests.Session()
+
+ response = session.get(url, params={'id': file_id}, stream=True)
+ token = _get_confirm_token(response)
+
+ if token:
+ params = {'id': file_id, 'confirm': token}
+ response = session.get(url, params=params, stream=True)
+
+ _save_response_content(response, fpath)
+
+
+def _get_confirm_token(response):
+ for key, value in response.cookies.items():
+ if key.startswith('download_warning'):
+ return value
+
+ return None
+
+
+def _save_response_content(response, destination, chunk_size=32768):
+ with open(destination, "wb") as f:
+ pbar = tqdm(total=None)
+ progress = 0
+ for chunk in response.iter_content(chunk_size):
+ if chunk: # filter out keep-alive new chunks
+ f.write(chunk)
+ progress += len(chunk)
+ pbar.update(progress - pbar.n)
+ pbar.close()
diff --git a/datasets/vision.py b/datasets/vision.py
new file mode 100644
index 0000000..0f156c1
--- /dev/null
+++ b/datasets/vision.py
@@ -0,0 +1,84 @@
+import os
+import torch
+import torch.utils.data as data
+
+
+class VisionDataset(data.Dataset):
+ _repr_indent = 4
+
+ def __init__(self, root, transforms=None, transform=None, target_transform=None):
+ if isinstance(root, torch._six.string_classes):
+ root = os.path.expanduser(root)
+ self.root = root
+
+ has_transforms = transforms is not None
+ has_separate_transform = transform is not None or target_transform is not None
+ if has_transforms and has_separate_transform:
+ raise ValueError("Only transforms or transform/target_transform can "
+ "be passed as argument")
+
+ # for backwards-compatibility
+ self.transform = transform
+ self.target_transform = target_transform
+
+ if has_separate_transform:
+ transforms = StandardTransform(transform, target_transform)
+ self.transforms = transforms
+
+ def __getitem__(self, index):
+ raise NotImplementedError
+
+ def __len__(self):
+ raise NotImplementedError
+
+ def __repr__(self):
+ head = "Dataset " + self.__class__.__name__
+ body = ["Number of datapoints: {}".format(self.__len__())]
+ if self.root is not None:
+ body.append("Root location: {}".format(self.root))
+ body += self.extra_repr().splitlines()
+ if hasattr(self, 'transform') and self.transform is not None:
+ body += self._format_transform_repr(self.transform,
+ "Transforms: ")
+ if hasattr(self, 'target_transform') and self.target_transform is not None:
+ body += self._format_transform_repr(self.target_transform,
+ "Target transforms: ")
+ lines = [head] + [" " * self._repr_indent + line for line in body]
+ return '\n'.join(lines)
+
+ def _format_transform_repr(self, transform, head):
+ lines = transform.__repr__().splitlines()
+ return (["{}{}".format(head, lines[0])] +
+ ["{}{}".format(" " * len(head), line) for line in lines[1:]])
+
+ def extra_repr(self):
+ return ""
+
+
+class StandardTransform(object):
+ def __init__(self, transform=None, target_transform=None):
+ self.transform = transform
+ self.target_transform = target_transform
+
+ def __call__(self, input, target):
+ if self.transform is not None:
+ input = self.transform(input)
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+ return input, target
+
+ def _format_transform_repr(self, transform, head):
+ lines = transform.__repr__().splitlines()
+ return (["{}{}".format(head, lines[0])] +
+ ["{}{}".format(" " * len(head), line) for line in lines[1:]])
+
+ def __repr__(self):
+ body = [self.__class__.__name__]
+ if self.transform is not None:
+ body += self._format_transform_repr(self.transform,
+ "Transform: ")
+ if self.target_transform is not None:
+ body += self._format_transform_repr(self.target_transform,
+ "Target transform: ")
+
+ return '\n'.join(body)
diff --git a/environment.yml b/environment.yml
new file mode 100644
index 0000000..47aeb41
--- /dev/null
+++ b/environment.yml
@@ -0,0 +1,141 @@
+name: ddrm
+channels:
+ - pytorch
+ - conda-forge
+ - defaults
+dependencies:
+ - _libgcc_mutex=0.1=conda_forge
+ - _openmp_mutex=4.5=1_gnu
+ - absl-py=1.0.0=pyhd8ed1ab_0
+ - aiohttp=3.8.1=py39h3811e60_0
+ - aiosignal=1.2.0=pyhd8ed1ab_0
+ - async-timeout=4.0.2=pyhd8ed1ab_0
+ - attrs=21.4.0=pyhd8ed1ab_0
+ - backcall=0.2.0=pyh9f0ad1d_0
+ - backports=1.0=py_2
+ - backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0
+ - blas=1.0=mkl
+ - blinker=1.4=py_1
+ - brotlipy=0.7.0=py39h3811e60_1003
+ - bzip2=1.0.8=h7b6447c_0
+ - c-ares=1.18.1=h7f98852_0
+ - ca-certificates=2021.10.26=h06a4308_2
+ - cachetools=4.2.4=pyhd8ed1ab_0
+ - certifi=2021.10.8=py39hf3d152e_1
+ - cffi=1.15.0=py39h4bc2ebd_0
+ - charset-normalizer=2.0.9=pyhd8ed1ab_0
+ - click=8.0.3=py39hf3d152e_1
+ - colorama=0.4.4=pyh9f0ad1d_0
+ - cryptography=36.0.1=py39h95dcef6_0
+ - cudatoolkit=11.3.1=h2bc3f7f_2
+ - dataclasses=0.8=pyhc8e2a94_3
+ - decorator=5.1.0=pyhd8ed1ab_0
+ - ffmpeg=4.3=hf484d3e_0
+ - freetype=2.11.0=h70c0345_0
+ - frozenlist=1.2.0=py39h3811e60_1
+ - giflib=5.2.1=h7b6447c_0
+ - gmp=6.2.1=h2531618_2
+ - gnutls=3.6.15=he1e5248_0
+ - google-auth=2.3.3=pyh6c4a22f_0
+ - google-auth-oauthlib=0.4.6=pyhd8ed1ab_0
+ - grpcio=1.42.0=py39hce63b2e_0
+ - idna=3.3=pyhd3eb1b0_0
+ - importlib-metadata=4.10.0=py39hf3d152e_0
+ - intel-openmp=2021.4.0=h06a4308_3561
+ - ipdb=0.13.9=pyhd8ed1ab_0
+ - ipython=7.30.1=py39hf3d152e_0
+ - jedi=0.18.1=py39hf3d152e_0
+ - jpeg=9d=h7f8727e_0
+ - lame=3.100=h7b6447c_0
+ - lcms2=2.12=h3be6417_0
+ - ld_impl_linux-64=2.36.1=hea4e1c9_2
+ - libblas=3.9.0=12_linux64_mkl
+ - libcblas=3.9.0=12_linux64_mkl
+ - libffi=3.4.2=h7f98852_5
+ - libgcc-ng=11.2.0=h1d223b6_11
+ - libgfortran-ng=11.2.0=h69a702a_11
+ - libgfortran5=11.2.0=h5c6108e_11
+ - libgomp=11.2.0=h1d223b6_11
+ - libiconv=1.15=h63c8f33_5
+ - libidn2=2.3.2=h7f8727e_0
+ - liblapack=3.9.0=12_linux64_mkl
+ - libpng=1.6.37=hbc83047_0
+ - libprotobuf=3.17.2=h780b84a_1
+ - libstdcxx-ng=11.2.0=he4da1e4_11
+ - libtasn1=4.16.0=h27cfd23_0
+ - libtiff=4.2.0=h85742a9_0
+ - libunistring=0.9.10=h27cfd23_0
+ - libuv=1.40.0=h7b6447c_0
+ - libwebp=1.2.0=h89dd481_0
+ - libwebp-base=1.2.0=h27cfd23_0
+ - lmdb=0.9.29=h2531618_0
+ - lz4-c=1.9.3=h295c915_1
+ - markdown=3.3.6=pyhd8ed1ab_0
+ - matplotlib-inline=0.1.3=pyhd8ed1ab_0
+ - mkl=2021.4.0=h06a4308_640
+ - mkl-service=2.4.0=py39h7f8727e_0
+ - mkl_fft=1.3.1=py39hd3c417c_0
+ - mkl_random=1.2.2=py39h51133e4_0
+ - multidict=5.2.0=py39h3811e60_1
+ - ncurses=6.2=h58526e2_4
+ - nettle=3.7.3=hbbd107a_1
+ - numpy=1.21.2=py39h20f2e39_0
+ - numpy-base=1.21.2=py39h79a1101_0
+ - oauthlib=3.1.1=pyhd8ed1ab_0
+ - olefile=0.46=pyhd3eb1b0_0
+ - openh264=2.1.1=h4ff587b_0
+ - openssl=1.1.1l=h7f98852_0
+ - parso=0.8.3=pyhd8ed1ab_0
+ - pexpect=4.8.0=pyh9f0ad1d_2
+ - pickleshare=0.7.5=py39hde42818_1002
+ - pillow=8.4.0=py39h5aabda8_0
+ - pip=21.2.4=py39h06a4308_0
+ - prompt-toolkit=3.0.24=pyha770c72_0
+ - protobuf=3.17.2=py39he80948d_0
+ - ptyprocess=0.7.0=pyhd3deb0d_0
+ - pyasn1=0.4.8=py_0
+ - pyasn1-modules=0.2.8=py_0
+ - pycparser=2.21=pyhd8ed1ab_0
+ - pygments=2.11.1=pyhd8ed1ab_0
+ - pyjwt=2.3.0=pyhd8ed1ab_1
+ - pyopenssl=21.0.0=pyhd8ed1ab_0
+ - pysocks=1.7.1=py39hf3d152e_4
+ - python=3.9.7=hb7a2778_3_cpython
+ - python-lmdb=1.2.1=py39h2531618_1
+ - python_abi=3.9=2_cp39
+ - pytorch=1.10.1=py3.9_cuda11.3_cudnn8.2.0_0
+ - pytorch-mutex=1.0=cuda
+ - pyu2f=0.1.5=pyhd8ed1ab_0
+ - pyyaml=6.0=py39h3811e60_3
+ - readline=8.1=h27cfd23_0
+ - requests=2.26.0=pyhd8ed1ab_1
+ - requests-oauthlib=1.3.0=pyh9f0ad1d_0
+ - rsa=4.8=pyhd8ed1ab_0
+ - scipy=1.7.3=py39hee8e79c_0
+ - setuptools=58.0.4=py39h06a4308_0
+ - six=1.16.0=pyhd3eb1b0_0
+ - sqlite=3.37.0=hc218d9a_0
+ - tensorboard=2.7.0=pyhd8ed1ab_0
+ - tensorboard-data-server=0.6.0=py39h95dcef6_1
+ - tensorboard-plugin-wit=1.8.0=pyh44b312d_0
+ - tk=8.6.11=h1ccaba5_0
+ - torchaudio=0.10.1=py39_cu113
+ - torchvision=0.11.2=py39_cu113
+ - tqdm=4.62.3=pyhd8ed1ab_0
+ - traitlets=5.1.1=pyhd8ed1ab_0
+ - typing-extensions=3.10.0.2=hd3eb1b0_0
+ - typing_extensions=3.10.0.2=pyh06a4308_0
+ - tzdata=2021e=hda174b7_0
+ - urllib3=1.26.7=pyhd8ed1ab_0
+ - wcwidth=0.2.5=pyh9f0ad1d_2
+ - werkzeug=2.0.2=pyhd3eb1b0_0
+ - wheel=0.37.0=pyhd3eb1b0_1
+ - xz=5.2.5=h7b6447c_0
+ - yaml=0.2.5=h516909a_0
+ - yarl=1.7.2=py39h3811e60_1
+ - zipp=3.6.0=pyhd8ed1ab_0
+ - zlib=1.2.11=h7f8727e_4
+ - zstd=1.4.9=haebb681_0
+ - pip:
+ - torch-fidelity==0.3.0
+
diff --git a/functions/__init__.py b/functions/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/functions/__pycache__/__init__.cpython-38.pyc b/functions/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000..a855d51
Binary files /dev/null and b/functions/__pycache__/__init__.cpython-38.pyc differ
diff --git a/functions/__pycache__/__init__.cpython-39.pyc b/functions/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000..eec5b2f
Binary files /dev/null and b/functions/__pycache__/__init__.cpython-39.pyc differ
diff --git a/functions/__pycache__/ckpt_util.cpython-38.pyc b/functions/__pycache__/ckpt_util.cpython-38.pyc
new file mode 100644
index 0000000..dec903c
Binary files /dev/null and b/functions/__pycache__/ckpt_util.cpython-38.pyc differ
diff --git a/functions/__pycache__/ckpt_util.cpython-39.pyc b/functions/__pycache__/ckpt_util.cpython-39.pyc
new file mode 100644
index 0000000..d36df38
Binary files /dev/null and b/functions/__pycache__/ckpt_util.cpython-39.pyc differ
diff --git a/functions/__pycache__/dct.cpython-38.pyc b/functions/__pycache__/dct.cpython-38.pyc
new file mode 100644
index 0000000..6c8c20b
Binary files /dev/null and b/functions/__pycache__/dct.cpython-38.pyc differ
diff --git a/functions/__pycache__/dct.cpython-39.pyc b/functions/__pycache__/dct.cpython-39.pyc
new file mode 100644
index 0000000..93705c7
Binary files /dev/null and b/functions/__pycache__/dct.cpython-39.pyc differ
diff --git a/functions/__pycache__/denoising.cpython-38.pyc b/functions/__pycache__/denoising.cpython-38.pyc
new file mode 100644
index 0000000..4872c4a
Binary files /dev/null and b/functions/__pycache__/denoising.cpython-38.pyc differ
diff --git a/functions/__pycache__/denoising.cpython-39.pyc b/functions/__pycache__/denoising.cpython-39.pyc
new file mode 100644
index 0000000..4bcf8f3
Binary files /dev/null and b/functions/__pycache__/denoising.cpython-39.pyc differ
diff --git a/functions/__pycache__/jpeg_torch.cpython-38.pyc b/functions/__pycache__/jpeg_torch.cpython-38.pyc
new file mode 100644
index 0000000..e5065d6
Binary files /dev/null and b/functions/__pycache__/jpeg_torch.cpython-38.pyc differ
diff --git a/functions/__pycache__/jpeg_torch.cpython-39.pyc b/functions/__pycache__/jpeg_torch.cpython-39.pyc
new file mode 100644
index 0000000..b7b0ee2
Binary files /dev/null and b/functions/__pycache__/jpeg_torch.cpython-39.pyc differ
diff --git a/functions/ckpt_util.py b/functions/ckpt_util.py
new file mode 100644
index 0000000..129650b
--- /dev/null
+++ b/functions/ckpt_util.py
@@ -0,0 +1,72 @@
+import os, hashlib
+import requests
+from tqdm import tqdm
+
+URL_MAP = {
+ "cifar10": "https://heibox.uni-heidelberg.de/f/869980b53bf5416c8a28/?dl=1",
+ "ema_cifar10": "https://heibox.uni-heidelberg.de/f/2e4f01e2d9ee49bab1d5/?dl=1",
+ "lsun_bedroom": "https://heibox.uni-heidelberg.de/f/f179d4f21ebc4d43bbfe/?dl=1",
+ "ema_lsun_bedroom": "https://heibox.uni-heidelberg.de/f/b95206528f384185889b/?dl=1",
+ "lsun_cat": "https://heibox.uni-heidelberg.de/f/fac870bd988348eab88e/?dl=1",
+ "ema_lsun_cat": "https://heibox.uni-heidelberg.de/f/0701aac3aa69457bbe34/?dl=1",
+ "lsun_church": "https://heibox.uni-heidelberg.de/f/2711a6f712e34b06b9d8/?dl=1",
+ "ema_lsun_church": "https://heibox.uni-heidelberg.de/f/44ccb50ef3c6436db52e/?dl=1",
+}
+CKPT_MAP = {
+ "cifar10": "diffusion_cifar10_model/model-790000.ckpt",
+ "ema_cifar10": "ema_diffusion_cifar10_model/model-790000.ckpt",
+ "lsun_bedroom": "diffusion_lsun_bedroom_model/model-2388000.ckpt",
+ "ema_lsun_bedroom": "ema_diffusion_lsun_bedroom_model/model-2388000.ckpt",
+ "lsun_cat": "diffusion_lsun_cat_model/model-1761000.ckpt",
+ "ema_lsun_cat": "ema_diffusion_lsun_cat_model/model-1761000.ckpt",
+ "lsun_church": "diffusion_lsun_church_model/model-4432000.ckpt",
+ "ema_lsun_church": "ema_diffusion_lsun_church_model/model-4432000.ckpt",
+}
+MD5_MAP = {
+ "cifar10": "82ed3067fd1002f5cf4c339fb80c4669",
+ "ema_cifar10": "1fa350b952534ae442b1d5235cce5cd3",
+ "lsun_bedroom": "f70280ac0e08b8e696f42cb8e948ff1c",
+ "ema_lsun_bedroom": "1921fa46b66a3665e450e42f36c2720f",
+ "lsun_cat": "bbee0e7c3d7abfb6e2539eaf2fb9987b",
+ "ema_lsun_cat": "646f23f4821f2459b8bafc57fd824558",
+ "lsun_church": "eb619b8a5ab95ef80f94ce8a5488dae3",
+ "ema_lsun_church": "fdc68a23938c2397caba4a260bc2445f",
+}
+
+
+def download(url, local_path, chunk_size=1024):
+ os.makedirs(os.path.split(local_path)[0], exist_ok=True)
+ with requests.get(url, stream=True) as r:
+ total_size = int(r.headers.get("content-length", 0))
+ with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
+ with open(local_path, "wb") as f:
+ for data in r.iter_content(chunk_size=chunk_size):
+ if data:
+ f.write(data)
+ pbar.update(chunk_size)
+
+
+def md5_hash(path):
+ with open(path, "rb") as f:
+ content = f.read()
+ return hashlib.md5(content).hexdigest()
+
+
+def get_ckpt_path(name, root=None, check=False, prefix='exp'):
+ if 'church_outdoor' in name:
+ name = name.replace('church_outdoor', 'church')
+ assert name in URL_MAP
+ # Modify the path when necessary
+ cachedir = os.environ.get("XDG_CACHE_HOME", os.path.join(prefix, "logs/"))
+ root = (
+ root
+ if root is not None
+ else os.path.join(cachedir, "diffusion_models_converted")
+ )
+ path = os.path.join(root, CKPT_MAP[name])
+ if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
+ print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
+ download(URL_MAP[name], path)
+ md5 = md5_hash(path)
+ assert md5 == MD5_MAP[name], md5
+ return path
diff --git a/functions/dct.py b/functions/dct.py
new file mode 100644
index 0000000..e6d8480
--- /dev/null
+++ b/functions/dct.py
@@ -0,0 +1,207 @@
+import numpy as np
+import torch
+import torch.nn as nn
+
+
+def dct1(x):
+ """
+ Discrete Cosine Transform, Type I
+ :param x: the input signal
+ :return: the DCT-I of the signal over the last dimension
+ """
+ x_shape = x.shape
+ x = x.view(-1, x_shape[-1])
+
+    return torch.view_as_real(torch.fft.rfft(torch.cat([x, x.flip([1])[:, 1:-1]], dim=1)))[:, :, 0].view(*x_shape)
+
+
+def idct1(X):
+ """
+ The inverse of DCT-I, which is just a scaled DCT-I
+    Our definition of idct1 is such that idct1(dct1(x)) == x
+ :param X: the input signal
+ :return: the inverse DCT-I of the signal over the last dimension
+ """
+ n = X.shape[-1]
+ return dct1(X) / (2 * (n - 1))
+
+
+def dct(x, norm=None):
+ """
+ Discrete Cosine Transform, Type II (a.k.a. the DCT)
+ For the meaning of the parameter `norm`, see:
+ https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
+ :param x: the input signal
+ :param norm: the normalization, None or 'ortho'
+ :return: the DCT-II of the signal over the last dimension
+ """
+ x_shape = x.shape
+ N = x_shape[-1]
+ x = x.contiguous().view(-1, N)
+
+ v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1)
+
+ Vc = torch.view_as_real(torch.fft.fft(v, dim=1))
+
+ k = - torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N)
+ W_r = torch.cos(k)
+ W_i = torch.sin(k)
+
+ V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i
+
+ if norm == 'ortho':
+ V[:, 0] /= np.sqrt(N) * 2
+ V[:, 1:] /= np.sqrt(N / 2) * 2
+
+ V = 2 * V.view(*x_shape)
+
+ return V
+
+
+def idct(X, norm=None):
+ """
+ The inverse to DCT-II, which is a scaled Discrete Cosine Transform, Type III
+ Our definition of idct is that idct(dct(x)) == x
+ For the meaning of the parameter `norm`, see:
+ https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
+ :param X: the input signal
+ :param norm: the normalization, None or 'ortho'
+ :return: the inverse DCT-II of the signal over the last dimension
+ """
+
+ x_shape = X.shape
+ N = x_shape[-1]
+
+ X_v = X.contiguous().view(-1, x_shape[-1]) / 2
+
+ if norm == 'ortho':
+ X_v[:, 0] *= np.sqrt(N) * 2
+ X_v[:, 1:] *= np.sqrt(N / 2) * 2
+
+ k = torch.arange(x_shape[-1], dtype=X.dtype, device=X.device)[None, :] * np.pi / (2 * N)
+ W_r = torch.cos(k)
+ W_i = torch.sin(k)
+
+ V_t_r = X_v
+ V_t_i = torch.cat([X_v[:, :1] * 0, -X_v.flip([1])[:, :-1]], dim=1)
+
+ V_r = V_t_r * W_r - V_t_i * W_i
+ V_i = V_t_r * W_i + V_t_i * W_r
+
+ V = torch.cat([V_r.unsqueeze(2), V_i.unsqueeze(2)], dim=2)
+
+ v = torch.fft.irfft(torch.view_as_complex(V), n=V.shape[1], dim=1)
+ x = v.new_zeros(v.shape)
+ x[:, ::2] += v[:, :N - (N // 2)]
+ x[:, 1::2] += v.flip([1])[:, :N // 2]
+
+ return x.view(*x_shape)
+
+
+def dct_2d(x, norm=None):
+ """
+    2-dimensional Discrete Cosine Transform, Type II (a.k.a. the DCT)
+ For the meaning of the parameter `norm`, see:
+ https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
+ :param x: the input signal
+ :param norm: the normalization, None or 'ortho'
+ :return: the DCT-II of the signal over the last 2 dimensions
+ """
+ X1 = dct(x, norm=norm)
+ X2 = dct(X1.transpose(-1, -2), norm=norm)
+ return X2.transpose(-1, -2)
+
+
+def idct_2d(X, norm=None):
+ """
+ The inverse to 2D DCT-II, which is a scaled Discrete Cosine Transform, Type III
+ Our definition of idct is that idct_2d(dct_2d(x)) == x
+ For the meaning of the parameter `norm`, see:
+ https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
+ :param X: the input signal
+ :param norm: the normalization, None or 'ortho'
+    :return: the inverse DCT-II of the signal over the last 2 dimensions
+ """
+ x1 = idct(X, norm=norm)
+ x2 = idct(x1.transpose(-1, -2), norm=norm)
+ return x2.transpose(-1, -2)
+
+
+def dct_3d(x, norm=None):
+ """
+    3-dimensional Discrete Cosine Transform, Type II (a.k.a. the DCT)
+ For the meaning of the parameter `norm`, see:
+ https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
+ :param x: the input signal
+ :param norm: the normalization, None or 'ortho'
+ :return: the DCT-II of the signal over the last 3 dimensions
+ """
+ X1 = dct(x, norm=norm)
+ X2 = dct(X1.transpose(-1, -2), norm=norm)
+ X3 = dct(X2.transpose(-1, -3), norm=norm)
+ return X3.transpose(-1, -3).transpose(-1, -2)
+
+
+def idct_3d(X, norm=None):
+ """
+ The inverse to 3D DCT-II, which is a scaled Discrete Cosine Transform, Type III
+ Our definition of idct is that idct_3d(dct_3d(x)) == x
+ For the meaning of the parameter `norm`, see:
+ https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
+ :param X: the input signal
+ :param norm: the normalization, None or 'ortho'
+    :return: the inverse DCT-II of the signal over the last 3 dimensions
+ """
+ x1 = idct(X, norm=norm)
+ x2 = idct(x1.transpose(-1, -2), norm=norm)
+ x3 = idct(x2.transpose(-1, -3), norm=norm)
+ return x3.transpose(-1, -3).transpose(-1, -2)
+
+
+class LinearDCT(nn.Linear):
+ """Implement any DCT as a linear layer; in practice this executes around
+ 50x faster on GPU. Unfortunately, the DCT matrix is stored, which will
+ increase memory usage.
+ :param in_features: size of expected input
+ :param type: which dct function in this file to use"""
+ def __init__(self, in_features, type, norm=None, bias=False):
+ self.type = type
+ self.N = in_features
+ self.norm = norm
+ super(LinearDCT, self).__init__(in_features, in_features, bias=bias)
+
+ def reset_parameters(self):
+ # initialise using dct function
+ I = torch.eye(self.N)
+ if self.type == 'dct1':
+ self.weight.data = dct1(I).data.t()
+ elif self.type == 'idct1':
+ self.weight.data = idct1(I).data.t()
+ elif self.type == 'dct':
+ self.weight.data = dct(I, norm=self.norm).data.t()
+ elif self.type == 'idct':
+ self.weight.data = idct(I, norm=self.norm).data.t()
+ self.weight.requires_grad = False # don't learn this!
+
+
+def apply_linear_2d(x, linear_layer):
+ """Can be used with a LinearDCT layer to do a 2D DCT.
+ :param x: the input signal
+ :param linear_layer: any PyTorch Linear layer
+ :return: result of linear layer applied to last 2 dimensions
+ """
+ X1 = linear_layer(x)
+ X2 = linear_layer(X1.transpose(-1, -2))
+ return X2.transpose(-1, -2)
+
+
+def apply_linear_3d(x, linear_layer):
+ """Can be used with a LinearDCT layer to do a 3D DCT.
+ :param x: the input signal
+ :param linear_layer: any PyTorch Linear layer
+ :return: result of linear layer applied to last 3 dimensions
+ """
+ X1 = linear_layer(x)
+ X2 = linear_layer(X1.transpose(-1, -2))
+ X3 = linear_layer(X2.transpose(-1, -3))
+ return X3.transpose(-1, -3).transpose(-1, -2)
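
As a quick sanity check (illustrative, not part of the patch), the functional and linear-layer DCTs above agree, and the orthonormal transform round-trips:

```python
import torch
from functions.dct import dct_2d, idct_2d, LinearDCT, apply_linear_2d

x = torch.randn(4, 8, 8)

# idct_2d inverts dct_2d (up to float32 tolerance).
assert torch.allclose(idct_2d(dct_2d(x, norm='ortho'), norm='ortho'), x, atol=1e-4)

# The LinearDCT layer reproduces the functional DCT on 8x8 blocks,
# which is how jpeg_torch.py applies the block transform.
lin = LinearDCT(8, 'dct', norm='ortho')
assert torch.allclose(apply_linear_2d(x, lin), dct_2d(x, norm='ortho'), atol=1e-4)
```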
diff --git a/functions/denoising.py b/functions/denoising.py
new file mode 100644
index 0000000..1fc517e
--- /dev/null
+++ b/functions/denoising.py
@@ -0,0 +1,54 @@
+import torch
+from tqdm import tqdm
+import torchvision.utils as tvu
+import os
+
+from functions.jpeg_torch import jpeg_decode as jd, jpeg_encode as je
+
+
+def compute_alpha(beta, t):
+ beta = torch.cat([torch.zeros(1).to(beta.device), beta], dim=0)
+ a = (1 - beta).cumprod(dim=0).index_select(0, t + 1).view(-1, 1, 1, 1)
+ return a
+
+def jpeg_steps(x, seq, model, b, y_0, etaB, etaA, etaC, cls_fn=None, classes=None, jpeg_qf=None):
+ from functools import partial
+ jpeg_decode = partial(jd, qf = jpeg_qf)
+ jpeg_encode = partial(je, qf = jpeg_qf)
+ with torch.no_grad():
+ n = x.size(0)
+ seq_next = [-1] + list(seq[:-1])
+ x0_preds = []
+
+ a_init = compute_alpha(b, (torch.ones(n) * seq[-1]).to(x.device).long())
+
+ xs = [a_init.sqrt() * y_0 + (1 - a_init).sqrt() * torch.randn_like(x)]
+ for i, j in tqdm(zip(reversed(seq), reversed(seq_next))):
+ t = (torch.ones(n) * i).to(x.device)
+ next_t = (torch.ones(n) * j).to(x.device)
+ at = compute_alpha(b, t.long())
+ at_next = compute_alpha(b, next_t.long())
+ xt = xs[-1].to('cuda')
+ if cls_fn == None:
+ et = model(xt, t)
+ else:
+ et = model(xt, t, classes)
+ et = et[:, :3]
+ et = et - (1 - at).sqrt()[0,0,0,0] * cls_fn(x,t,classes)
+
+ if et.size(1) == 6:
+ et = et[:, :3]
+
+ x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()
+
+ sigma = (1 - at).sqrt()[0, 0, 0, 0] / at.sqrt()[0, 0, 0, 0]
+ sigma_next = (1 - at_next).sqrt()[0, 0, 0, 0] / at_next.sqrt()[0, 0, 0, 0]
+
+ xt_next = x0_t
+ xt_next = xt_next - jpeg_decode(jpeg_encode(xt_next)) + jpeg_decode(jpeg_encode(y_0))
+
+ xt_next = etaB * at_next.sqrt() * xt_next + (1 - etaB) * at_next.sqrt() * x0_t + etaA * (1 - at_next).sqrt() * torch.randn_like(xt_next) + (1 - etaA) * et * (1 - at_next).sqrt()
+
+ x0_preds.append(x0_t.to('cpu'))
+ xs.append(xt_next.to('cpu'))
+ return xs, x0_preds
\ No newline at end of file
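
The key line in `jpeg_steps` is the consistency step: roughly, it keeps the part of the current estimate that the JPEG operator cannot see and swaps its JPEG-consistent component for that of the observation. A standalone sketch with dummy tensors (illustrative shapes and quality factor, not part of the patch):

```python
import torch
from functions.jpeg_torch import jpeg_encode, jpeg_decode

qf = 10
x0_t = torch.rand(1, 3, 256, 256) * 2 - 1   # stand-in for the current x0 estimate, in [-1, 1]
y_0 = torch.rand(1, 3, 256, 256) * 2 - 1    # stand-in for the JPEG-degraded observation

# Mirrors: xt_next = x0_t - jpeg_decode(jpeg_encode(x0_t)) + jpeg_decode(jpeg_encode(y_0))
x_consistent = x0_t - jpeg_decode(jpeg_encode(x0_t, qf), qf) + jpeg_decode(jpeg_encode(y_0, qf), qf)
print(x_consistent.shape)  # torch.Size([1, 3, 256, 256])
```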
diff --git a/functions/jpeg_torch.py b/functions/jpeg_torch.py
new file mode 100644
index 0000000..5589cba
--- /dev/null
+++ b/functions/jpeg_torch.py
@@ -0,0 +1,186 @@
+import torch
+import torch.nn as nn
+from .dct import LinearDCT, apply_linear_2d
+import torch.nn.functional as F
+
+
+def torch_rgb2ycbcr(x):
+ # Assume x is a batch of size (N x C x H x W)
+ v = torch.tensor([[.299, .587, .114], [-.1687, -.3313, .5], [.5, -.4187, -.0813]]).to(x.device)
+ ycbcr = torch.tensordot(x, v, dims=([1], [1])).transpose(3, 2).transpose(2, 1)
+ ycbcr[:,1:] += 128
+ return ycbcr
+
+
+def torch_ycbcr2rgb(x):
+ # Assume x is a batch of size (N x C x H x W)
+ v = torch.tensor([[ 1.00000000e+00, -3.68199903e-05, 1.40198758e+00],
+ [ 1.00000000e+00, -3.44113281e-01, -7.14103821e-01],
+ [ 1.00000000e+00, 1.77197812e+00, -1.34583413e-04]]).to(x.device)
+ x[:, 1:] -= 128
+ rgb = torch.tensordot(x, v, dims=([1], [1])).transpose(3, 2).transpose(2, 1)
+ return rgb
+
+def chroma_subsample(x):
+ return x[:, 0:1, :, :], x[:, 1:, ::2, ::2]
+
+
+def general_quant_matrix(qf = 10):
+ q1 = torch.tensor([
+ 16, 11, 10, 16, 24, 40, 51, 61,
+ 12, 12, 14, 19, 26, 58, 60, 55,
+ 14, 13, 16, 24, 40, 57, 69, 56,
+ 14, 17, 22, 29, 51, 87, 80, 62,
+ 18, 22, 37, 56, 68, 109, 103, 77,
+ 24, 35, 55, 64, 81, 104, 113, 92,
+ 49, 64, 78, 87, 103, 121, 120, 101,
+ 72, 92, 95, 98, 112, 100, 103, 99
+ ])
+ q2 = torch.tensor([
+ 17, 18, 24, 47, 99, 99, 99, 99,
+ 18, 21, 26, 66, 99, 99, 99, 99,
+ 24, 26, 56, 99, 99, 99, 99, 99,
+ 47, 66, 99, 99, 99, 99, 99, 99,
+ 99, 99, 99, 99, 99, 99, 99, 99,
+ 99, 99, 99, 99, 99, 99, 99, 99,
+ 99, 99, 99, 99, 99, 99, 99, 99,
+ 99, 99, 99, 99, 99, 99, 99, 99
+ ])
+ s = (5000 / qf) if qf < 50 else (200 - 2 * qf)
+ q1 = torch.floor((s * q1 + 50) / 100)
+ q1[q1 <= 0] = 1
+ q1[q1 > 255] = 255
+ q2 = torch.floor((s * q2 + 50) / 100)
+ q2[q2 <= 0] = 1
+ q2[q2 > 255] = 255
+ return q1, q2
+
+
+def quantization_matrix(qf):
+ return general_quant_matrix(qf)
+ # q1 = torch.tensor([[ 80, 55, 50, 80, 120, 200, 255, 255],
+ # [ 60, 60, 70, 95, 130, 255, 255, 255],
+ # [ 70, 65, 80, 120, 200, 255, 255, 255],
+ # [ 70, 85, 110, 145, 255, 255, 255, 255],
+ # [ 90, 110, 185, 255, 255, 255, 255, 255],
+ # [120, 175, 255, 255, 255, 255, 255, 255],
+ # [245, 255, 255, 255, 255, 255, 255, 255],
+ # [255, 255, 255, 255, 255, 255, 255, 255]])
+ # q2 = torch.tensor([[ 85, 90, 120, 235, 255, 255, 255, 255],
+ # [ 90, 105, 130, 255, 255, 255, 255, 255],
+ # [120, 130, 255, 255, 255, 255, 255, 255],
+ # [235, 255, 255, 255, 255, 255, 255, 255],
+ # [255, 255, 255, 255, 255, 255, 255, 255],
+ # [255, 255, 255, 255, 255, 255, 255, 255],
+ # [255, 255, 255, 255, 255, 255, 255, 255],
+ # [255, 255, 255, 255, 255, 255, 255, 255]])
+ # return q1, q2
+
+def jpeg_encode(x, qf):
+ # Assume x is a batch of size (N x C x H x W)
+ # [-1, 1] to [0, 255]
+ x = (x + 1) / 2 * 255
+ n_batch, _, n_size, _ = x.shape
+
+ x = torch_rgb2ycbcr(x)
+ x_luma, x_chroma = chroma_subsample(x)
+ unfold = nn.Unfold(kernel_size=(8, 8), stride=(8, 8))
+ x_luma = unfold(x_luma).transpose(2, 1)
+ x_chroma = unfold(x_chroma).transpose(2, 1)
+
+ x_luma = x_luma.reshape(-1, 8, 8) - 128
+ x_chroma = x_chroma.reshape(-1, 8, 8) - 128
+
+ dct_layer = LinearDCT(8, 'dct', norm='ortho')
+ dct_layer.to(x_luma.device)
+ x_luma = apply_linear_2d(x_luma, dct_layer)
+ x_chroma = apply_linear_2d(x_chroma, dct_layer)
+
+ x_luma = x_luma.view(-1, 1, 8, 8)
+ x_chroma = x_chroma.view(-1, 2, 8, 8)
+
+ q1, q2 = quantization_matrix(qf)
+ q1 = q1.to(x_luma.device)
+ q2 = q2.to(x_luma.device)
+ x_luma /= q1.view(1, 8, 8)
+ x_chroma /= q2.view(1, 8, 8)
+
+ x_luma = x_luma.round()
+ x_chroma = x_chroma.round()
+
+ x_luma = x_luma.reshape(n_batch, (n_size // 8) ** 2, 64).transpose(2, 1)
+ x_chroma = x_chroma.reshape(n_batch, (n_size // 16) ** 2, 64 * 2).transpose(2, 1)
+
+ fold = nn.Fold(output_size=(n_size, n_size), kernel_size=(8, 8), stride=(8, 8))
+ x_luma = fold(x_luma)
+ fold = nn.Fold(output_size=(n_size // 2, n_size // 2), kernel_size=(8, 8), stride=(8, 8))
+ x_chroma = fold(x_chroma)
+
+ return [x_luma, x_chroma]
+
+
+
+def jpeg_decode(x, qf):
+ # Assume x[0] is a batch of size (N x 1 x H x W) (luma)
+ # Assume x[1:] is a batch of size (N x 2 x H/2 x W/2) (chroma)
+ x_luma, x_chroma = x
+ n_batch, _, n_size, _ = x_luma.shape
+ unfold = nn.Unfold(kernel_size=(8, 8), stride=(8, 8))
+ x_luma = unfold(x_luma).transpose(2, 1)
+ x_luma = x_luma.reshape(-1, 1, 8, 8)
+ x_chroma = unfold(x_chroma).transpose(2, 1)
+ x_chroma = x_chroma.reshape(-1, 2, 8, 8)
+
+ q1, q2 = quantization_matrix(qf)
+ q1 = q1.to(x_luma.device)
+ q2 = q2.to(x_luma.device)
+ x_luma *= q1.view(1, 8, 8)
+ x_chroma *= q2.view(1, 8, 8)
+
+ x_luma = x_luma.reshape(-1, 8, 8)
+ x_chroma = x_chroma.reshape(-1, 8, 8)
+
+ dct_layer = LinearDCT(8, 'idct', norm='ortho')
+ dct_layer.to(x_luma.device)
+ x_luma = apply_linear_2d(x_luma, dct_layer)
+ x_chroma = apply_linear_2d(x_chroma, dct_layer)
+
+ x_luma = (x_luma + 128).reshape(n_batch, (n_size // 8) ** 2, 64).transpose(2, 1)
+ x_chroma = (x_chroma + 128).reshape(n_batch, (n_size // 16) ** 2, 64 * 2).transpose(2, 1)
+
+ fold = nn.Fold(output_size=(n_size, n_size), kernel_size=(8, 8), stride=(8, 8))
+ x_luma = fold(x_luma)
+ fold = nn.Fold(output_size=(n_size // 2, n_size // 2), kernel_size=(8, 8), stride=(8, 8))
+ x_chroma = fold(x_chroma)
+
+ x_chroma_repeated = torch.zeros(n_batch, 2, n_size, n_size, device = x_luma.device)
+ x_chroma_repeated[:, :, 0::2, 0::2] = x_chroma
+ x_chroma_repeated[:, :, 0::2, 1::2] = x_chroma
+ x_chroma_repeated[:, :, 1::2, 0::2] = x_chroma
+ x_chroma_repeated[:, :, 1::2, 1::2] = x_chroma
+
+ x = torch.cat([x_luma, x_chroma_repeated], dim=1)
+
+ x = torch_ycbcr2rgb(x)
+
+ # [0, 255] to [-1, 1]
+ x = x / 255 * 2 - 1
+
+ return x
+
+def quantization_encode(x, qf):
+    qf = 32  # the passed qf is ignored; a fixed quantization step of 32 is used
+    # to int
+ x = (x + 1) / 2
+ x = x * 255
+ x = x.int()
+ # quantize
+ x = x // qf
+ #to float
+ x = x.float()
+ x = x / (255/qf)
+ x = (x * 2) - 1
+ return x
+
+def quantization_decode(x, qf):
+ return x
\ No newline at end of file
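
A round-trip sketch of the tensor-based JPEG pipeline above (illustrative; the input height and width must be divisible by 16 because of the 2x chroma subsampling followed by 8x8 blocking):

```python
import torch
from functions.jpeg_torch import jpeg_encode, jpeg_decode, general_quant_matrix

# Lower quality factors use coarser quantization tables.
q_luma_10, _ = general_quant_matrix(qf=10)
q_luma_90, _ = general_quant_matrix(qf=90)
print(q_luma_10.max().item(), q_luma_90.max().item())

x = torch.rand(2, 3, 64, 64) * 2 - 1   # batch in [-1, 1]; H and W divisible by 16
coeffs = jpeg_encode(x, qf=30)         # [luma, chroma] grids of quantized DCT coefficients
x_rec = jpeg_decode(coeffs, qf=30)     # approximate JPEG round trip, back in [-1, 1]
print(x_rec.shape)                     # torch.Size([2, 3, 64, 64])
```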
diff --git a/guided_diffusion/__pycache__/fp16_util.cpython-38.pyc b/guided_diffusion/__pycache__/fp16_util.cpython-38.pyc
new file mode 100644
index 0000000..1319da8
Binary files /dev/null and b/guided_diffusion/__pycache__/fp16_util.cpython-38.pyc differ
diff --git a/guided_diffusion/__pycache__/fp16_util.cpython-39.pyc b/guided_diffusion/__pycache__/fp16_util.cpython-39.pyc
new file mode 100644
index 0000000..3e69d38
Binary files /dev/null and b/guided_diffusion/__pycache__/fp16_util.cpython-39.pyc differ
diff --git a/guided_diffusion/__pycache__/logger.cpython-38.pyc b/guided_diffusion/__pycache__/logger.cpython-38.pyc
new file mode 100644
index 0000000..87138b9
Binary files /dev/null and b/guided_diffusion/__pycache__/logger.cpython-38.pyc differ
diff --git a/guided_diffusion/__pycache__/logger.cpython-39.pyc b/guided_diffusion/__pycache__/logger.cpython-39.pyc
new file mode 100644
index 0000000..65664a8
Binary files /dev/null and b/guided_diffusion/__pycache__/logger.cpython-39.pyc differ
diff --git a/guided_diffusion/__pycache__/nn.cpython-38.pyc b/guided_diffusion/__pycache__/nn.cpython-38.pyc
new file mode 100644
index 0000000..c4dbc2d
Binary files /dev/null and b/guided_diffusion/__pycache__/nn.cpython-38.pyc differ
diff --git a/guided_diffusion/__pycache__/nn.cpython-39.pyc b/guided_diffusion/__pycache__/nn.cpython-39.pyc
new file mode 100644
index 0000000..de18e75
Binary files /dev/null and b/guided_diffusion/__pycache__/nn.cpython-39.pyc differ
diff --git a/guided_diffusion/__pycache__/script_util.cpython-38.pyc b/guided_diffusion/__pycache__/script_util.cpython-38.pyc
new file mode 100644
index 0000000..ff34250
Binary files /dev/null and b/guided_diffusion/__pycache__/script_util.cpython-38.pyc differ
diff --git a/guided_diffusion/__pycache__/script_util.cpython-39.pyc b/guided_diffusion/__pycache__/script_util.cpython-39.pyc
new file mode 100644
index 0000000..4b93a64
Binary files /dev/null and b/guided_diffusion/__pycache__/script_util.cpython-39.pyc differ
diff --git a/guided_diffusion/__pycache__/unet.cpython-38.pyc b/guided_diffusion/__pycache__/unet.cpython-38.pyc
new file mode 100644
index 0000000..cfb53e9
Binary files /dev/null and b/guided_diffusion/__pycache__/unet.cpython-38.pyc differ
diff --git a/guided_diffusion/__pycache__/unet.cpython-39.pyc b/guided_diffusion/__pycache__/unet.cpython-39.pyc
new file mode 100644
index 0000000..c7c51ae
Binary files /dev/null and b/guided_diffusion/__pycache__/unet.cpython-39.pyc differ
diff --git a/guided_diffusion/fp16_util.py b/guided_diffusion/fp16_util.py
new file mode 100644
index 0000000..35a3f46
--- /dev/null
+++ b/guided_diffusion/fp16_util.py
@@ -0,0 +1,236 @@
+"""
+Helpers to train with 16-bit precision.
+"""
+
+import numpy as np
+import torch as th
+import torch.nn as nn
+from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+
+from . import logger
+
+INITIAL_LOG_LOSS_SCALE = 20.0
+
+
+def convert_module_to_f16(l):
+ """
+ Convert primitive modules to float16.
+ """
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
+ l.weight.data = l.weight.data.half()
+ if l.bias is not None:
+ l.bias.data = l.bias.data.half()
+
+
+def convert_module_to_f32(l):
+ """
+ Convert primitive modules to float32, undoing convert_module_to_f16().
+ """
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
+ l.weight.data = l.weight.data.float()
+ if l.bias is not None:
+ l.bias.data = l.bias.data.float()
+
+
+def make_master_params(param_groups_and_shapes):
+ """
+ Copy model parameters into a (differently-shaped) list of full-precision
+ parameters.
+ """
+ master_params = []
+ for param_group, shape in param_groups_and_shapes:
+ master_param = nn.Parameter(
+ _flatten_dense_tensors(
+ [param.detach().float() for (_, param) in param_group]
+ ).view(shape)
+ )
+ master_param.requires_grad = True
+ master_params.append(master_param)
+ return master_params
+
+
+def model_grads_to_master_grads(param_groups_and_shapes, master_params):
+ """
+ Copy the gradients from the model parameters into the master parameters
+ from make_master_params().
+ """
+ for master_param, (param_group, shape) in zip(
+ master_params, param_groups_and_shapes
+ ):
+ master_param.grad = _flatten_dense_tensors(
+ [param_grad_or_zeros(param) for (_, param) in param_group]
+ ).view(shape)
+
+
+def master_params_to_model_params(param_groups_and_shapes, master_params):
+ """
+ Copy the master parameter data back into the model parameters.
+ """
+ # Without copying to a list, if a generator is passed, this will
+ # silently not copy any parameters.
+ for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes):
+ for (_, param), unflat_master_param in zip(
+ param_group, unflatten_master_params(param_group, master_param.view(-1))
+ ):
+ param.detach().copy_(unflat_master_param)
+
+
+def unflatten_master_params(param_group, master_param):
+ return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group])
+
+
+def get_param_groups_and_shapes(named_model_params):
+ named_model_params = list(named_model_params)
+ scalar_vector_named_params = (
+ [(n, p) for (n, p) in named_model_params if p.ndim <= 1],
+ (-1),
+ )
+ matrix_named_params = (
+ [(n, p) for (n, p) in named_model_params if p.ndim > 1],
+ (1, -1),
+ )
+ return [scalar_vector_named_params, matrix_named_params]
+
+
+def master_params_to_state_dict(
+ model, param_groups_and_shapes, master_params, use_fp16
+):
+ if use_fp16:
+ state_dict = model.state_dict()
+ for master_param, (param_group, _) in zip(
+ master_params, param_groups_and_shapes
+ ):
+ for (name, _), unflat_master_param in zip(
+ param_group, unflatten_master_params(param_group, master_param.view(-1))
+ ):
+ assert name in state_dict
+ state_dict[name] = unflat_master_param
+ else:
+ state_dict = model.state_dict()
+ for i, (name, _value) in enumerate(model.named_parameters()):
+ assert name in state_dict
+ state_dict[name] = master_params[i]
+ return state_dict
+
+
+def state_dict_to_master_params(model, state_dict, use_fp16):
+ if use_fp16:
+ named_model_params = [
+ (name, state_dict[name]) for name, _ in model.named_parameters()
+ ]
+ param_groups_and_shapes = get_param_groups_and_shapes(named_model_params)
+ master_params = make_master_params(param_groups_and_shapes)
+ else:
+ master_params = [state_dict[name] for name, _ in model.named_parameters()]
+ return master_params
+
+
+def zero_master_grads(master_params):
+ for param in master_params:
+ param.grad = None
+
+
+def zero_grad(model_params):
+ for param in model_params:
+ # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group
+ if param.grad is not None:
+ param.grad.detach_()
+ param.grad.zero_()
+
+
+def param_grad_or_zeros(param):
+ if param.grad is not None:
+ return param.grad.data.detach()
+ else:
+ return th.zeros_like(param)
+
+
+class MixedPrecisionTrainer:
+ def __init__(
+ self,
+ *,
+ model,
+ use_fp16=False,
+ fp16_scale_growth=1e-3,
+ initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE,
+ ):
+ self.model = model
+ self.use_fp16 = use_fp16
+ self.fp16_scale_growth = fp16_scale_growth
+
+ self.model_params = list(self.model.parameters())
+ self.master_params = self.model_params
+ self.param_groups_and_shapes = None
+ self.lg_loss_scale = initial_lg_loss_scale
+
+ if self.use_fp16:
+ self.param_groups_and_shapes = get_param_groups_and_shapes(
+ self.model.named_parameters()
+ )
+ self.master_params = make_master_params(self.param_groups_and_shapes)
+ self.model.convert_to_fp16()
+
+ def zero_grad(self):
+ zero_grad(self.model_params)
+
+ def backward(self, loss: th.Tensor):
+ if self.use_fp16:
+ loss_scale = 2 ** self.lg_loss_scale
+ (loss * loss_scale).backward()
+ else:
+ loss.backward()
+
+ def optimize(self, opt: th.optim.Optimizer):
+ if self.use_fp16:
+ return self._optimize_fp16(opt)
+ else:
+ return self._optimize_normal(opt)
+
+ def _optimize_fp16(self, opt: th.optim.Optimizer):
+ logger.logkv_mean("lg_loss_scale", self.lg_loss_scale)
+ model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params)
+ grad_norm, param_norm = self._compute_norms(grad_scale=2 ** self.lg_loss_scale)
+ if check_overflow(grad_norm):
+ self.lg_loss_scale -= 1
+ logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}")
+ zero_master_grads(self.master_params)
+ return False
+
+ logger.logkv_mean("grad_norm", grad_norm)
+ logger.logkv_mean("param_norm", param_norm)
+
+ self.master_params[0].grad.mul_(1.0 / (2 ** self.lg_loss_scale))
+ opt.step()
+ zero_master_grads(self.master_params)
+ master_params_to_model_params(self.param_groups_and_shapes, self.master_params)
+ self.lg_loss_scale += self.fp16_scale_growth
+ return True
+
+ def _optimize_normal(self, opt: th.optim.Optimizer):
+ grad_norm, param_norm = self._compute_norms()
+ logger.logkv_mean("grad_norm", grad_norm)
+ logger.logkv_mean("param_norm", param_norm)
+ opt.step()
+ return True
+
+ def _compute_norms(self, grad_scale=1.0):
+ grad_norm = 0.0
+ param_norm = 0.0
+ for p in self.master_params:
+ with th.no_grad():
+ param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2
+ if p.grad is not None:
+ grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2
+ return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm)
+
+ def master_params_to_state_dict(self, master_params):
+ return master_params_to_state_dict(
+ self.model, self.param_groups_and_shapes, master_params, self.use_fp16
+ )
+
+ def state_dict_to_master_params(self, state_dict):
+ return state_dict_to_master_params(self.model, state_dict, self.use_fp16)
+
+
+def check_overflow(value):
+ return (value == float("inf")) or (value == -float("inf")) or (value != value)
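
A minimal fp32-mode sketch of the trainer interface (illustrative; `use_fp16=True` additionally requires the wrapped model to implement `convert_to_fp16()`, as the UNet below does). Note that `optimize()` logs through `guided_diffusion.logger`, so a logger output directory is created as a side effect.

```python
import torch as th
import torch.nn as nn
from guided_diffusion.fp16_util import MixedPrecisionTrainer

model = nn.Linear(16, 1)                       # any module; fp16 mode expects a guided-diffusion UNet
trainer = MixedPrecisionTrainer(model=model)   # use_fp16 defaults to False
opt = th.optim.Adam(trainer.master_params, lr=1e-4)

x, y = th.randn(8, 16), th.randn(8, 1)
trainer.zero_grad()
loss = ((model(x) - y) ** 2).mean()
trainer.backward(loss)
trainer.optimize(opt)                          # steps the optimizer and logs grad/param norms
```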
diff --git a/guided_diffusion/logger.py b/guided_diffusion/logger.py
new file mode 100644
index 0000000..b1d856d
--- /dev/null
+++ b/guided_diffusion/logger.py
@@ -0,0 +1,495 @@
+"""
+Logger copied from OpenAI baselines to avoid extra RL-based dependencies:
+https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py
+"""
+
+import os
+import sys
+import shutil
+import os.path as osp
+import json
+import time
+import datetime
+import tempfile
+import warnings
+from collections import defaultdict
+from contextlib import contextmanager
+
+DEBUG = 10
+INFO = 20
+WARN = 30
+ERROR = 40
+
+DISABLED = 50
+
+
+class KVWriter(object):
+ def writekvs(self, kvs):
+ raise NotImplementedError
+
+
+class SeqWriter(object):
+ def writeseq(self, seq):
+ raise NotImplementedError
+
+
+class HumanOutputFormat(KVWriter, SeqWriter):
+ def __init__(self, filename_or_file):
+ if isinstance(filename_or_file, str):
+ self.file = open(filename_or_file, "wt")
+ self.own_file = True
+ else:
+ assert hasattr(filename_or_file, "read"), (
+ "expected file or str, got %s" % filename_or_file
+ )
+ self.file = filename_or_file
+ self.own_file = False
+
+ def writekvs(self, kvs):
+ # Create strings for printing
+ key2str = {}
+ for (key, val) in sorted(kvs.items()):
+ if hasattr(val, "__float__"):
+ valstr = "%-8.3g" % val
+ else:
+ valstr = str(val)
+ key2str[self._truncate(key)] = self._truncate(valstr)
+
+ # Find max widths
+ if len(key2str) == 0:
+ print("WARNING: tried to write empty key-value dict")
+ return
+ else:
+ keywidth = max(map(len, key2str.keys()))
+ valwidth = max(map(len, key2str.values()))
+
+ # Write out the data
+ dashes = "-" * (keywidth + valwidth + 7)
+ lines = [dashes]
+ for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()):
+ lines.append(
+ "| %s%s | %s%s |"
+ % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val)))
+ )
+ lines.append(dashes)
+ self.file.write("\n".join(lines) + "\n")
+
+ # Flush the output to the file
+ self.file.flush()
+
+ def _truncate(self, s):
+ maxlen = 30
+ return s[: maxlen - 3] + "..." if len(s) > maxlen else s
+
+ def writeseq(self, seq):
+ seq = list(seq)
+ for (i, elem) in enumerate(seq):
+ self.file.write(elem)
+ if i < len(seq) - 1: # add space unless this is the last one
+ self.file.write(" ")
+ self.file.write("\n")
+ self.file.flush()
+
+ def close(self):
+ if self.own_file:
+ self.file.close()
+
+
+class JSONOutputFormat(KVWriter):
+ def __init__(self, filename):
+ self.file = open(filename, "wt")
+
+ def writekvs(self, kvs):
+ for k, v in sorted(kvs.items()):
+ if hasattr(v, "dtype"):
+ kvs[k] = float(v)
+ self.file.write(json.dumps(kvs) + "\n")
+ self.file.flush()
+
+ def close(self):
+ self.file.close()
+
+
+class CSVOutputFormat(KVWriter):
+ def __init__(self, filename):
+ self.file = open(filename, "w+t")
+ self.keys = []
+ self.sep = ","
+
+ def writekvs(self, kvs):
+ # Add our current row to the history
+ extra_keys = list(kvs.keys() - self.keys)
+ extra_keys.sort()
+ if extra_keys:
+ self.keys.extend(extra_keys)
+ self.file.seek(0)
+ lines = self.file.readlines()
+ self.file.seek(0)
+ for (i, k) in enumerate(self.keys):
+ if i > 0:
+ self.file.write(",")
+ self.file.write(k)
+ self.file.write("\n")
+ for line in lines[1:]:
+ self.file.write(line[:-1])
+ self.file.write(self.sep * len(extra_keys))
+ self.file.write("\n")
+ for (i, k) in enumerate(self.keys):
+ if i > 0:
+ self.file.write(",")
+ v = kvs.get(k)
+ if v is not None:
+ self.file.write(str(v))
+ self.file.write("\n")
+ self.file.flush()
+
+ def close(self):
+ self.file.close()
+
+
+class TensorBoardOutputFormat(KVWriter):
+ """
+ Dumps key/value pairs into TensorBoard's numeric format.
+ """
+
+ def __init__(self, dir):
+ os.makedirs(dir, exist_ok=True)
+ self.dir = dir
+ self.step = 1
+ prefix = "events"
+ path = osp.join(osp.abspath(dir), prefix)
+ import tensorflow as tf
+ from tensorflow.python import pywrap_tensorflow
+ from tensorflow.core.util import event_pb2
+ from tensorflow.python.util import compat
+
+ self.tf = tf
+ self.event_pb2 = event_pb2
+ self.pywrap_tensorflow = pywrap_tensorflow
+ self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path))
+
+ def writekvs(self, kvs):
+ def summary_val(k, v):
+ kwargs = {"tag": k, "simple_value": float(v)}
+ return self.tf.Summary.Value(**kwargs)
+
+ summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()])
+ event = self.event_pb2.Event(wall_time=time.time(), summary=summary)
+ event.step = (
+ self.step
+ ) # is there any reason why you'd want to specify the step?
+ self.writer.WriteEvent(event)
+ self.writer.Flush()
+ self.step += 1
+
+ def close(self):
+ if self.writer:
+ self.writer.Close()
+ self.writer = None
+
+
+def make_output_format(format, ev_dir, log_suffix=""):
+ os.makedirs(ev_dir, exist_ok=True)
+ if format == "stdout":
+ return HumanOutputFormat(sys.stdout)
+ elif format == "log":
+ return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix))
+ elif format == "json":
+ return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix))
+ elif format == "csv":
+ return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix))
+ elif format == "tensorboard":
+ return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix))
+ else:
+ raise ValueError("Unknown format specified: %s" % (format,))
+
+
+# ================================================================
+# API
+# ================================================================
+
+
+def logkv(key, val):
+ """
+ Log a value of some diagnostic
+ Call this once for each diagnostic quantity, each iteration
+ If called many times, last value will be used.
+ """
+ get_current().logkv(key, val)
+
+
+def logkv_mean(key, val):
+ """
+ The same as logkv(), but if called many times, values averaged.
+ """
+ get_current().logkv_mean(key, val)
+
+
+def logkvs(d):
+ """
+ Log a dictionary of key-value pairs
+ """
+ for (k, v) in d.items():
+ logkv(k, v)
+
+
+def dumpkvs():
+ """
+ Write all of the diagnostics from the current iteration
+ """
+ return get_current().dumpkvs()
+
+
+def getkvs():
+ return get_current().name2val
+
+
+def log(*args, level=INFO):
+ """
+ Write the sequence of args, with no separators, to the console and output files (if you've configured an output file).
+ """
+ get_current().log(*args, level=level)
+
+
+def debug(*args):
+ log(*args, level=DEBUG)
+
+
+def info(*args):
+ log(*args, level=INFO)
+
+
+def warn(*args):
+ log(*args, level=WARN)
+
+
+def error(*args):
+ log(*args, level=ERROR)
+
+
+def set_level(level):
+ """
+ Set logging threshold on current logger.
+ """
+ get_current().set_level(level)
+
+
+def set_comm(comm):
+ get_current().set_comm(comm)
+
+
+def get_dir():
+ """
+ Get directory that log files are being written to.
+ will be None if there is no output directory (i.e., if you didn't call start)
+ """
+ return get_current().get_dir()
+
+
+record_tabular = logkv
+dump_tabular = dumpkvs
+
+
+@contextmanager
+def profile_kv(scopename):
+ logkey = "wait_" + scopename
+ tstart = time.time()
+ try:
+ yield
+ finally:
+ get_current().name2val[logkey] += time.time() - tstart
+
+
+def profile(n):
+ """
+ Usage:
+ @profile("my_func")
+ def my_func(): code
+ """
+
+ def decorator_with_name(func):
+ def func_wrapper(*args, **kwargs):
+ with profile_kv(n):
+ return func(*args, **kwargs)
+
+ return func_wrapper
+
+ return decorator_with_name
+
+
+# ================================================================
+# Backend
+# ================================================================
+
+
+def get_current():
+ if Logger.CURRENT is None:
+ _configure_default_logger()
+
+ return Logger.CURRENT
+
+
+class Logger(object):
+ DEFAULT = None # A logger with no output files. (See right below class definition)
+ # So that you can still log to the terminal without setting up any output files
+ CURRENT = None # Current logger being used by the free functions above
+
+ def __init__(self, dir, output_formats, comm=None):
+ self.name2val = defaultdict(float) # values this iteration
+ self.name2cnt = defaultdict(int)
+ self.level = INFO
+ self.dir = dir
+ self.output_formats = output_formats
+ self.comm = comm
+
+ # Logging API, forwarded
+ # ----------------------------------------
+ def logkv(self, key, val):
+ self.name2val[key] = val
+
+ def logkv_mean(self, key, val):
+ oldval, cnt = self.name2val[key], self.name2cnt[key]
+ self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1)
+ self.name2cnt[key] = cnt + 1
+
+ def dumpkvs(self):
+ if self.comm is None:
+ d = self.name2val
+ else:
+ d = mpi_weighted_mean(
+ self.comm,
+ {
+ name: (val, self.name2cnt.get(name, 1))
+ for (name, val) in self.name2val.items()
+ },
+ )
+ if self.comm.rank != 0:
+ d["dummy"] = 1 # so we don't get a warning about empty dict
+ out = d.copy() # Return the dict for unit testing purposes
+ for fmt in self.output_formats:
+ if isinstance(fmt, KVWriter):
+ fmt.writekvs(d)
+ self.name2val.clear()
+ self.name2cnt.clear()
+ return out
+
+ def log(self, *args, level=INFO):
+ if self.level <= level:
+ self._do_log(args)
+
+ # Configuration
+ # ----------------------------------------
+ def set_level(self, level):
+ self.level = level
+
+ def set_comm(self, comm):
+ self.comm = comm
+
+ def get_dir(self):
+ return self.dir
+
+ def close(self):
+ for fmt in self.output_formats:
+ fmt.close()
+
+ # Misc
+ # ----------------------------------------
+ def _do_log(self, args):
+ for fmt in self.output_formats:
+ if isinstance(fmt, SeqWriter):
+ fmt.writeseq(map(str, args))
+
+
+def get_rank_without_mpi_import():
+ # check environment variables here instead of importing mpi4py
+ # to avoid calling MPI_Init() when this module is imported
+ for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]:
+ if varname in os.environ:
+ return int(os.environ[varname])
+ return 0
+
+
+def mpi_weighted_mean(comm, local_name2valcount):
+ """
+ Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110
+ Perform a weighted average over dicts that are each on a different node
+ Input: local_name2valcount: dict mapping key -> (value, count)
+ Returns: key -> mean
+ """
+ all_name2valcount = comm.gather(local_name2valcount)
+ if comm.rank == 0:
+ name2sum = defaultdict(float)
+ name2count = defaultdict(float)
+ for n2vc in all_name2valcount:
+ for (name, (val, count)) in n2vc.items():
+ try:
+ val = float(val)
+ except ValueError:
+ if comm.rank == 0:
+ warnings.warn(
+ "WARNING: tried to compute mean on non-float {}={}".format(
+ name, val
+ )
+ )
+ else:
+ name2sum[name] += val * count
+ name2count[name] += count
+ return {name: name2sum[name] / name2count[name] for name in name2sum}
+ else:
+ return {}
+
+
+def configure(dir=None, format_strs=None, comm=None, log_suffix=""):
+ """
+ If comm is provided, average all numerical stats across that comm
+ """
+ if dir is None:
+ dir = os.getenv("OPENAI_LOGDIR")
+ if dir is None:
+ dir = osp.join(
+ tempfile.gettempdir(),
+ datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"),
+ )
+ assert isinstance(dir, str)
+ dir = os.path.expanduser(dir)
+ os.makedirs(os.path.expanduser(dir), exist_ok=True)
+
+ rank = get_rank_without_mpi_import()
+ if rank > 0:
+ log_suffix = log_suffix + "-rank%03i" % rank
+
+ if format_strs is None:
+ if rank == 0:
+ format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",")
+ else:
+ format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",")
+ format_strs = filter(None, format_strs)
+ output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs]
+
+ Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm)
+ if output_formats:
+ log("Logging to %s" % dir)
+
+
+def _configure_default_logger():
+ configure()
+ Logger.DEFAULT = Logger.CURRENT
+
+
+def reset():
+ if Logger.CURRENT is not Logger.DEFAULT:
+ Logger.CURRENT.close()
+ Logger.CURRENT = Logger.DEFAULT
+ log("Reset logger")
+
+
+@contextmanager
+def scoped_configure(dir=None, format_strs=None, comm=None):
+ prevlogger = Logger.CURRENT
+ configure(dir=dir, format_strs=format_strs, comm=comm)
+ try:
+ yield
+ finally:
+ Logger.CURRENT.close()
+ Logger.CURRENT = prevlogger
+
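A small usage sketch (hypothetical output directory): `configure` picks the output formats, `logkv`/`logkv_mean` accumulate values for the current iteration, and `dumpkvs` flushes them to every configured writer.

```python
from guided_diffusion import logger

logger.configure(dir="exp/logs/demo", format_strs=["stdout", "csv"])
for step in range(3):
    logger.logkv("step", step)
    logger.logkv_mean("loss", 1.0 / (step + 1))
    logger.dumpkvs()   # writes a table to stdout and a row to progress.csv
```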
diff --git a/guided_diffusion/nn.py b/guided_diffusion/nn.py
new file mode 100644
index 0000000..a4cd59c
--- /dev/null
+++ b/guided_diffusion/nn.py
@@ -0,0 +1,170 @@
+"""
+Various utilities for neural networks.
+"""
+
+import math
+
+import torch as th
+import torch.nn as nn
+
+
+# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
+class SiLU(nn.Module):
+ def forward(self, x):
+ return x * th.sigmoid(x)
+
+
+class GroupNorm32(nn.GroupNorm):
+ def forward(self, x):
+ return super().forward(x.float()).type(x.dtype)
+
+
+def conv_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D convolution module.
+ """
+ if dims == 1:
+ return nn.Conv1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.Conv2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.Conv3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def linear(*args, **kwargs):
+ """
+ Create a linear module.
+ """
+ return nn.Linear(*args, **kwargs)
+
+
+def avg_pool_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D average pooling module.
+ """
+ if dims == 1:
+ return nn.AvgPool1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.AvgPool2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.AvgPool3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def update_ema(target_params, source_params, rate=0.99):
+ """
+ Update target parameters to be closer to those of source parameters using
+ an exponential moving average.
+
+ :param target_params: the target parameter sequence.
+ :param source_params: the source parameter sequence.
+ :param rate: the EMA rate (closer to 1 means slower).
+ """
+ for targ, src in zip(target_params, source_params):
+ targ.detach().mul_(rate).add_(src, alpha=1 - rate)
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def scale_module(module, scale):
+ """
+ Scale the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().mul_(scale)
+ return module
+
+
+def mean_flat(tensor):
+ """
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def normalization(channels):
+ """
+ Make a standard normalization layer.
+
+ :param channels: number of input channels.
+ :return: an nn.Module for normalization.
+ """
+ return GroupNorm32(32, channels)
+
+
+def timestep_embedding(timesteps, dim, max_period=10000):
+ """
+ Create sinusoidal timestep embeddings.
+
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an [N x dim] Tensor of positional embeddings.
+ """
+ half = dim // 2
+ freqs = th.exp(
+ -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
+ ).to(device=timesteps.device)
+ args = timesteps[:, None].float() * freqs[None]
+ embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
+
+def checkpoint(func, inputs, params, flag):
+ """
+ Evaluate a function without caching intermediate activations, allowing for
+ reduced memory at the expense of extra compute in the backward pass.
+
+ :param func: the function to evaluate.
+ :param inputs: the argument sequence to pass to `func`.
+ :param params: a sequence of parameters `func` depends on but does not
+ explicitly take as arguments.
+ :param flag: if False, disable gradient checkpointing.
+ """
+ if flag:
+ args = tuple(inputs) + tuple(params)
+ return CheckpointFunction.apply(func, len(inputs), *args)
+ else:
+ return func(*inputs)
+
+
+class CheckpointFunction(th.autograd.Function):
+ @staticmethod
+ def forward(ctx, run_function, length, *args):
+ ctx.run_function = run_function
+ ctx.input_tensors = list(args[:length])
+ ctx.input_params = list(args[length:])
+ with th.no_grad():
+ output_tensors = ctx.run_function(*ctx.input_tensors)
+ return output_tensors
+
+ @staticmethod
+ def backward(ctx, *output_grads):
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
+ with th.enable_grad():
+ # Fixes a bug where the first op in run_function modifies the
+ # Tensor storage in place, which is not allowed for detach()'d
+ # Tensors.
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+ output_tensors = ctx.run_function(*shallow_copies)
+ input_grads = th.autograd.grad(
+ output_tensors,
+ ctx.input_tensors + ctx.input_params,
+ output_grads,
+ allow_unused=True,
+ )
+ del ctx.input_tensors
+ del ctx.input_params
+ del output_tensors
+ return (None, None) + input_grads
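
For example (illustrative), the sinusoidal embedding maps a batch of timestep indices to an `[N x dim]` feature tensor that the UNet consumes:

```python
import torch as th
from guided_diffusion.nn import timestep_embedding

t = th.tensor([0, 250, 999])            # one (possibly fractional) timestep per batch element
emb = timestep_embedding(t, dim=128)    # cos/sin features at geometrically spaced frequencies
print(emb.shape)                        # torch.Size([3, 128])
```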
diff --git a/guided_diffusion/script_util.py b/guided_diffusion/script_util.py
new file mode 100644
index 0000000..6df6933
--- /dev/null
+++ b/guided_diffusion/script_util.py
@@ -0,0 +1,453 @@
+import argparse
+import inspect
+
+#from . import gaussian_diffusion as gd  # needed by create_gaussian_diffusion() below
+#from .respace import SpacedDiffusion, space_timesteps  # these modules are not included in this diff
+from .unet import SuperResModel, UNetModel, EncoderUNetModel
+
+NUM_CLASSES = 1000
+
+
+def diffusion_defaults():
+ """
+ Defaults for image and classifier training.
+ """
+ return dict(
+ learn_sigma=False,
+ diffusion_steps=1000,
+ noise_schedule="linear",
+ timestep_respacing="",
+ use_kl=False,
+ predict_xstart=False,
+ rescale_timesteps=False,
+ rescale_learned_sigmas=False,
+ )
+
+
+def classifier_defaults():
+ """
+ Defaults for classifier models.
+ """
+ return dict(
+ image_size=64,
+ classifier_use_fp16=False,
+ classifier_width=128,
+ classifier_depth=2,
+ classifier_attention_resolutions="32,16,8", # 16
+ classifier_use_scale_shift_norm=True, # False
+ classifier_resblock_updown=True, # False
+ classifier_pool="attention",
+ )
+
+
+def model_and_diffusion_defaults():
+ """
+ Defaults for image training.
+ """
+ res = dict(
+ image_size=64,
+ num_channels=128,
+ num_res_blocks=2,
+ num_heads=4,
+ num_heads_upsample=-1,
+ num_head_channels=-1,
+ attention_resolutions="16,8",
+ channel_mult="",
+ dropout=0.0,
+ class_cond=False,
+ use_checkpoint=False,
+ use_scale_shift_norm=True,
+ resblock_updown=False,
+ use_fp16=False,
+ use_new_attention_order=False,
+ )
+ res.update(diffusion_defaults())
+ return res
+
+
+def classifier_and_diffusion_defaults():
+ res = classifier_defaults()
+ res.update(diffusion_defaults())
+ return res
+
+
+def create_model_and_diffusion(
+ image_size,
+ class_cond,
+ learn_sigma,
+ num_channels,
+ num_res_blocks,
+ channel_mult,
+ num_heads,
+ num_head_channels,
+ num_heads_upsample,
+ attention_resolutions,
+ dropout,
+ diffusion_steps,
+ noise_schedule,
+ timestep_respacing,
+ use_kl,
+ predict_xstart,
+ rescale_timesteps,
+ rescale_learned_sigmas,
+ use_checkpoint,
+ use_scale_shift_norm,
+ resblock_updown,
+ use_fp16,
+ use_new_attention_order,
+):
+ model = create_model(
+ image_size,
+ num_channels,
+ num_res_blocks,
+ channel_mult=channel_mult,
+ learn_sigma=learn_sigma,
+ class_cond=class_cond,
+ use_checkpoint=use_checkpoint,
+ attention_resolutions=attention_resolutions,
+ num_heads=num_heads,
+ num_head_channels=num_head_channels,
+ num_heads_upsample=num_heads_upsample,
+ use_scale_shift_norm=use_scale_shift_norm,
+ dropout=dropout,
+ resblock_updown=resblock_updown,
+ use_fp16=use_fp16,
+ use_new_attention_order=use_new_attention_order,
+ )
+ diffusion = create_gaussian_diffusion(
+ steps=diffusion_steps,
+ learn_sigma=learn_sigma,
+ noise_schedule=noise_schedule,
+ use_kl=use_kl,
+ predict_xstart=predict_xstart,
+ rescale_timesteps=rescale_timesteps,
+ rescale_learned_sigmas=rescale_learned_sigmas,
+ timestep_respacing=timestep_respacing,
+ )
+ return model, diffusion
+
+
+def create_model(
+ image_size,
+ num_channels,
+ num_res_blocks,
+ channel_mult="",
+ learn_sigma=False,
+ class_cond=False,
+ use_checkpoint=False,
+ attention_resolutions="16",
+ num_heads=1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ dropout=0,
+ resblock_updown=False,
+ use_fp16=False,
+ use_new_attention_order=False,
+ **kwargs
+):
+ if channel_mult == "":
+ if image_size == 512:
+ channel_mult = (0.5, 1, 1, 2, 2, 4, 4)
+ elif image_size == 256:
+ channel_mult = (1, 1, 2, 2, 4, 4)
+ elif image_size == 128:
+ channel_mult = (1, 1, 2, 3, 4)
+ elif image_size == 64:
+ channel_mult = (1, 2, 3, 4)
+ else:
+ raise ValueError(f"unsupported image size: {image_size}")
+ else:
+ channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(","))
+
+ attention_ds = []
+ for res in attention_resolutions.split(","):
+ attention_ds.append(image_size // int(res))
+
+ return UNetModel(
+ image_size=image_size,
+ in_channels=3,
+ model_channels=num_channels,
+ out_channels=(3 if not learn_sigma else 6),
+ num_res_blocks=num_res_blocks,
+ attention_resolutions=tuple(attention_ds),
+ dropout=dropout,
+ channel_mult=channel_mult,
+ num_classes=(NUM_CLASSES if class_cond else None),
+ use_checkpoint=use_checkpoint,
+ use_fp16=use_fp16,
+ num_heads=num_heads,
+ num_head_channels=num_head_channels,
+ num_heads_upsample=num_heads_upsample,
+ use_scale_shift_norm=use_scale_shift_norm,
+ resblock_updown=resblock_updown,
+ use_new_attention_order=use_new_attention_order,
+ )
+
+
+def create_classifier_and_diffusion(
+ image_size,
+ classifier_use_fp16,
+ classifier_width,
+ classifier_depth,
+ classifier_attention_resolutions,
+ classifier_use_scale_shift_norm,
+ classifier_resblock_updown,
+ classifier_pool,
+ learn_sigma,
+ diffusion_steps,
+ noise_schedule,
+ timestep_respacing,
+ use_kl,
+ predict_xstart,
+ rescale_timesteps,
+ rescale_learned_sigmas,
+):
+ classifier = create_classifier(
+ image_size,
+ classifier_use_fp16,
+ classifier_width,
+ classifier_depth,
+ classifier_attention_resolutions,
+ classifier_use_scale_shift_norm,
+ classifier_resblock_updown,
+ classifier_pool,
+ )
+ diffusion = create_gaussian_diffusion(
+ steps=diffusion_steps,
+ learn_sigma=learn_sigma,
+ noise_schedule=noise_schedule,
+ use_kl=use_kl,
+ predict_xstart=predict_xstart,
+ rescale_timesteps=rescale_timesteps,
+ rescale_learned_sigmas=rescale_learned_sigmas,
+ timestep_respacing=timestep_respacing,
+ )
+ return classifier, diffusion
+
+
+def create_classifier(
+ image_size,
+ classifier_use_fp16,
+ classifier_width,
+ classifier_depth,
+ classifier_attention_resolutions,
+ classifier_use_scale_shift_norm,
+ classifier_resblock_updown,
+ classifier_pool,
+):
+ if image_size == 512:
+ channel_mult = (0.5, 1, 1, 2, 2, 4, 4)
+ elif image_size == 256:
+ channel_mult = (1, 1, 2, 2, 4, 4)
+ elif image_size == 128:
+ channel_mult = (1, 1, 2, 3, 4)
+ elif image_size == 64:
+ channel_mult = (1, 2, 3, 4)
+ else:
+ raise ValueError(f"unsupported image size: {image_size}")
+
+ attention_ds = []
+ for res in classifier_attention_resolutions.split(","):
+ attention_ds.append(image_size // int(res))
+
+ return EncoderUNetModel(
+ image_size=image_size,
+ in_channels=3,
+ model_channels=classifier_width,
+ out_channels=1000,
+ num_res_blocks=classifier_depth,
+ attention_resolutions=tuple(attention_ds),
+ channel_mult=channel_mult,
+ use_fp16=classifier_use_fp16,
+ num_head_channels=64,
+ use_scale_shift_norm=classifier_use_scale_shift_norm,
+ resblock_updown=classifier_resblock_updown,
+ pool=classifier_pool,
+ )
+
+
+def sr_model_and_diffusion_defaults():
+ res = model_and_diffusion_defaults()
+ res["large_size"] = 256
+ res["small_size"] = 64
+ arg_names = inspect.getfullargspec(sr_create_model_and_diffusion)[0]
+ for k in res.copy().keys():
+ if k not in arg_names:
+ del res[k]
+ return res
+
+
+def sr_create_model_and_diffusion(
+ large_size,
+ small_size,
+ class_cond,
+ learn_sigma,
+ num_channels,
+ num_res_blocks,
+ num_heads,
+ num_head_channels,
+ num_heads_upsample,
+ attention_resolutions,
+ dropout,
+ diffusion_steps,
+ noise_schedule,
+ timestep_respacing,
+ use_kl,
+ predict_xstart,
+ rescale_timesteps,
+ rescale_learned_sigmas,
+ use_checkpoint,
+ use_scale_shift_norm,
+ resblock_updown,
+ use_fp16,
+):
+ model = sr_create_model(
+ large_size,
+ small_size,
+ num_channels,
+ num_res_blocks,
+ learn_sigma=learn_sigma,
+ class_cond=class_cond,
+ use_checkpoint=use_checkpoint,
+ attention_resolutions=attention_resolutions,
+ num_heads=num_heads,
+ num_head_channels=num_head_channels,
+ num_heads_upsample=num_heads_upsample,
+ use_scale_shift_norm=use_scale_shift_norm,
+ dropout=dropout,
+ resblock_updown=resblock_updown,
+ use_fp16=use_fp16,
+ )
+ diffusion = create_gaussian_diffusion(
+ steps=diffusion_steps,
+ learn_sigma=learn_sigma,
+ noise_schedule=noise_schedule,
+ use_kl=use_kl,
+ predict_xstart=predict_xstart,
+ rescale_timesteps=rescale_timesteps,
+ rescale_learned_sigmas=rescale_learned_sigmas,
+ timestep_respacing=timestep_respacing,
+ )
+ return model, diffusion
+
+
+def sr_create_model(
+ large_size,
+ small_size,
+ num_channels,
+ num_res_blocks,
+ learn_sigma,
+ class_cond,
+ use_checkpoint,
+ attention_resolutions,
+ num_heads,
+ num_head_channels,
+ num_heads_upsample,
+ use_scale_shift_norm,
+ dropout,
+ resblock_updown,
+ use_fp16,
+):
+ _ = small_size # hack to prevent unused variable
+
+ if large_size == 512:
+ channel_mult = (1, 1, 2, 2, 4, 4)
+ elif large_size == 256:
+ channel_mult = (1, 1, 2, 2, 4, 4)
+ elif large_size == 64:
+ channel_mult = (1, 2, 3, 4)
+ else:
+ raise ValueError(f"unsupported large size: {large_size}")
+
+ attention_ds = []
+ for res in attention_resolutions.split(","):
+ attention_ds.append(large_size // int(res))
+
+ return SuperResModel(
+ image_size=large_size,
+ in_channels=3,
+ model_channels=num_channels,
+ out_channels=(3 if not learn_sigma else 6),
+ num_res_blocks=num_res_blocks,
+ attention_resolutions=tuple(attention_ds),
+ dropout=dropout,
+ channel_mult=channel_mult,
+ num_classes=(NUM_CLASSES if class_cond else None),
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=num_head_channels,
+ num_heads_upsample=num_heads_upsample,
+ use_scale_shift_norm=use_scale_shift_norm,
+ resblock_updown=resblock_updown,
+ use_fp16=use_fp16,
+ )
+
+
+def create_gaussian_diffusion(
+ *,
+ steps=1000,
+ learn_sigma=False,
+ sigma_small=False,
+ noise_schedule="linear",
+ use_kl=False,
+ predict_xstart=False,
+ rescale_timesteps=False,
+ rescale_learned_sigmas=False,
+ timestep_respacing="",
+):
+ betas = gd.get_named_beta_schedule(noise_schedule, steps)
+ if use_kl:
+ loss_type = gd.LossType.RESCALED_KL
+ elif rescale_learned_sigmas:
+ loss_type = gd.LossType.RESCALED_MSE
+ else:
+ loss_type = gd.LossType.MSE
+ if not timestep_respacing:
+ timestep_respacing = [steps]
+ return SpacedDiffusion(
+ use_timesteps=space_timesteps(steps, timestep_respacing),
+ betas=betas,
+ model_mean_type=(
+ gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
+ ),
+ model_var_type=(
+ (
+ gd.ModelVarType.FIXED_LARGE
+ if not sigma_small
+ else gd.ModelVarType.FIXED_SMALL
+ )
+ if not learn_sigma
+ else gd.ModelVarType.LEARNED_RANGE
+ ),
+ loss_type=loss_type,
+ rescale_timesteps=rescale_timesteps,
+ )
+
+
+def add_dict_to_argparser(parser, default_dict):
+ for k, v in default_dict.items():
+ v_type = type(v)
+ if v is None:
+ v_type = str
+ elif isinstance(v, bool):
+ v_type = str2bool
+ parser.add_argument(f"--{k}", default=v, type=v_type)
+
+
+def args_to_dict(args, keys):
+ return {k: getattr(args, k) for k in keys}
+
+
+def str2bool(v):
+ """
+ https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
+ """
+ if isinstance(v, bool):
+ return v
+ if v.lower() in ("yes", "true", "t", "y", "1"):
+ return True
+ elif v.lower() in ("no", "false", "f", "n", "0"):
+ return False
+ else:
+ raise argparse.ArgumentTypeError("boolean value expected")
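
A sketch of building the 64x64 unconditional UNet from the defaults above (illustrative; `create_model_and_diffusion` and `create_gaussian_diffusion` additionally require the `gaussian_diffusion`/`respace` modules whose imports are commented out at the top of this file):

```python
from guided_diffusion.script_util import model_and_diffusion_defaults, create_model

cfg = model_and_diffusion_defaults()   # 64x64, 128 channels, unconditional
model = create_model(
    cfg["image_size"], cfg["num_channels"], cfg["num_res_blocks"],
    channel_mult=cfg["channel_mult"],
    learn_sigma=cfg["learn_sigma"],
    class_cond=cfg["class_cond"],
    attention_resolutions=cfg["attention_resolutions"],
    num_heads=cfg["num_heads"],
    num_head_channels=cfg["num_head_channels"],
    use_scale_shift_norm=cfg["use_scale_shift_norm"],
    use_fp16=cfg["use_fp16"],
)
print(sum(p.numel() for p in model.parameters()))
```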
diff --git a/guided_diffusion/unet.py b/guided_diffusion/unet.py
new file mode 100644
index 0000000..e80f050
--- /dev/null
+++ b/guided_diffusion/unet.py
@@ -0,0 +1,895 @@
+from abc import abstractmethod
+
+import math
+
+import numpy as np
+import torch as th
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .fp16_util import convert_module_to_f16, convert_module_to_f32
+from .nn import (
+ checkpoint,
+ conv_nd,
+ linear,
+ avg_pool_nd,
+ zero_module,
+ normalization,
+ timestep_embedding,
+)
+
+
+class AttentionPool2d(nn.Module):
+ """
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
+ """
+
+ def __init__(
+ self,
+ spacial_dim: int,
+ embed_dim: int,
+ num_heads_channels: int,
+ output_dim: int = None,
+ ):
+ super().__init__()
+ self.positional_embedding = nn.Parameter(
+ th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5
+ )
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
+ self.num_heads = embed_dim // num_heads_channels
+ self.attention = QKVAttention(self.num_heads)
+
+ def forward(self, x):
+ b, c, *_spatial = x.shape
+ x = x.reshape(b, c, -1) # NC(HW)
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
+ x = self.qkv_proj(x)
+ x = self.attention(x)
+ x = self.c_proj(x)
+ return x[:, :, 0]
+
+
+class TimestepBlock(nn.Module):
+ """
+ Any module where forward() takes timestep embeddings as a second argument.
+ """
+
+ @abstractmethod
+ def forward(self, x, emb):
+ """
+ Apply the module to `x` given `emb` timestep embeddings.
+ """
+
+
+class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
+ """
+ A sequential module that passes timestep embeddings to the children that
+ support it as an extra input.
+ """
+
+ def forward(self, x, emb):
+ for layer in self:
+ if isinstance(layer, TimestepBlock):
+ x = layer(x, emb)
+ else:
+ x = layer(x)
+ return x
+
+
+class Upsample(nn.Module):
+ """
+ An upsampling layer with an optional convolution.
+
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ upsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ if use_conv:
+ self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ if self.dims == 3:
+ x = F.interpolate(
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
+ )
+ else:
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
+ if self.use_conv:
+ x = self.conv(x)
+ return x
+
+
+class Downsample(nn.Module):
+ """
+ A downsampling layer with an optional convolution.
+
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ downsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ stride = 2 if dims != 3 else (1, 2, 2)
+ if use_conv:
+ self.op = conv_nd(
+ dims, self.channels, self.out_channels, 3, stride=stride, padding=1
+ )
+ else:
+ assert self.channels == self.out_channels
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ return self.op(x)
+
+
+class ResBlock(TimestepBlock):
+ """
+ A residual block that can optionally change the number of channels.
+
+ :param channels: the number of input channels.
+ :param emb_channels: the number of timestep embedding channels.
+ :param dropout: the rate of dropout.
+ :param out_channels: if specified, the number of out channels.
+ :param use_conv: if True and out_channels is specified, use a spatial
+ convolution instead of a smaller 1x1 convolution to change the
+ channels in the skip connection.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
+ :param up: if True, use this block for upsampling.
+ :param down: if True, use this block for downsampling.
+ """
+
+ def __init__(
+ self,
+ channels,
+ emb_channels,
+ dropout,
+ out_channels=None,
+ use_conv=False,
+ use_scale_shift_norm=False,
+ dims=2,
+ use_checkpoint=False,
+ up=False,
+ down=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.emb_channels = emb_channels
+ self.dropout = dropout
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_checkpoint = use_checkpoint
+ self.use_scale_shift_norm = use_scale_shift_norm
+
+ self.in_layers = nn.Sequential(
+ normalization(channels),
+ nn.SiLU(),
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
+ )
+
+ self.updown = up or down
+
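+        # When the block itself resamples, both the residual branch (h_upd) and the
+        # skip branch (x_upd) are up/down-sampled so the two paths keep matching shapes in _forward().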
+ if up:
+ self.h_upd = Upsample(channels, False, dims)
+ self.x_upd = Upsample(channels, False, dims)
+ elif down:
+ self.h_upd = Downsample(channels, False, dims)
+ self.x_upd = Downsample(channels, False, dims)
+ else:
+ self.h_upd = self.x_upd = nn.Identity()
+
+ self.emb_layers = nn.Sequential(
+ nn.SiLU(),
+ linear(
+ emb_channels,
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
+ ),
+ )
+ self.out_layers = nn.Sequential(
+ normalization(self.out_channels),
+ nn.SiLU(),
+ nn.Dropout(p=dropout),
+ zero_module(
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
+ ),
+ )
+
+ if self.out_channels == channels:
+ self.skip_connection = nn.Identity()
+ elif use_conv:
+ self.skip_connection = conv_nd(
+ dims, channels, self.out_channels, 3, padding=1
+ )
+ else:
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+
+ def forward(self, x, emb):
+ """
+ Apply the block to a Tensor, conditioned on a timestep embedding.
+
+ :param x: an [N x C x ...] Tensor of features.
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ return checkpoint(
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
+ )
+
+ def _forward(self, x, emb):
+ if self.updown:
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
+ h = in_rest(x)
+ h = self.h_upd(h)
+ x = self.x_upd(x)
+ h = in_conv(h)
+ else:
+ h = self.in_layers(x)
+ emb_out = self.emb_layers(emb).type(h.dtype)
+ while len(emb_out.shape) < len(h.shape):
+ emb_out = emb_out[..., None]
+ if self.use_scale_shift_norm:
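+            # FiLM-style conditioning: the timestep embedding predicts a per-channel
+            # (scale, shift) pair that modulates the normalized activations.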
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
+ scale, shift = th.chunk(emb_out, 2, dim=1)
+ h = out_norm(h) * (1 + scale) + shift
+ h = out_rest(h)
+ else:
+ h = h + emb_out
+ h = self.out_layers(h)
+ return self.skip_connection(x) + h
+
+
+class AttentionBlock(nn.Module):
+ """
+ An attention block that allows spatial positions to attend to each other.
+
+ Originally ported from here, but adapted to the N-d case.
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
+ """
+
+ def __init__(
+ self,
+ channels,
+ num_heads=1,
+ num_head_channels=-1,
+ use_checkpoint=False,
+ use_new_attention_order=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ if num_head_channels == -1:
+ self.num_heads = num_heads
+ else:
+ assert (
+ channels % num_head_channels == 0
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+ self.num_heads = channels // num_head_channels
+ self.use_checkpoint = use_checkpoint
+ self.norm = normalization(channels)
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
+ if use_new_attention_order:
+ # split qkv before split heads
+ self.attention = QKVAttention(self.num_heads)
+ else:
+ # split heads before split qkv
+ self.attention = QKVAttentionLegacy(self.num_heads)
+
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
+
+ def forward(self, x):
+ return checkpoint(self._forward, (x,), self.parameters(), True)
+
+ def _forward(self, x):
+ b, c, *spatial = x.shape
+ x = x.reshape(b, c, -1)
+ qkv = self.qkv(self.norm(x))
+ h = self.attention(qkv)
+ h = self.proj_out(h)
+ return (x + h).reshape(b, c, *spatial)
+
+
+def count_flops_attn(model, _x, y):
+ """
+ A counter for the `thop` package to count the operations in an
+ attention operation.
+ Meant to be used like:
+ macs, params = thop.profile(
+ model,
+            inputs=(inputs, timesteps),
+ custom_ops={QKVAttention: QKVAttention.count_flops},
+ )
+ """
+ b, c, *spatial = y[0].shape
+ num_spatial = int(np.prod(spatial))
+ # We perform two matmuls with the same number of ops.
+ # The first computes the weight matrix, the second computes
+ # the combination of the value vectors.
+ matmul_ops = 2 * b * (num_spatial ** 2) * c
+ model.total_ops += th.DoubleTensor([matmul_ops])
+
+
+class QKVAttentionLegacy(nn.Module):
+ """
+    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping.
+ """
+
+ def __init__(self, n_heads):
+ super().__init__()
+ self.n_heads = n_heads
+
+ def forward(self, qkv):
+ """
+ Apply QKV attention.
+
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
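+        # Legacy ordering: fold the heads into the batch dimension first, then split q, k, v.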
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ weight = th.einsum(
+ "bct,bcs->bts", q * scale, k * scale
+ ) # More stable with f16 than dividing afterwards
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+ a = th.einsum("bts,bcs->bct", weight, v)
+ return a.reshape(bs, -1, length)
+
+ @staticmethod
+ def count_flops(model, _x, y):
+ return count_flops_attn(model, _x, y)
+
+
+class QKVAttention(nn.Module):
+ """
+ A module which performs QKV attention and splits in a different order.
+ """
+
+ def __init__(self, n_heads):
+ super().__init__()
+ self.n_heads = n_heads
+
+ def forward(self, qkv):
+ """
+ Apply QKV attention.
+
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
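+        # New ordering: split q, k, v first, then fold the heads into the batch dimension.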
+ q, k, v = qkv.chunk(3, dim=1)
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ weight = th.einsum(
+ "bct,bcs->bts",
+ (q * scale).view(bs * self.n_heads, ch, length),
+ (k * scale).view(bs * self.n_heads, ch, length),
+ ) # More stable with f16 than dividing afterwards
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+ a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
+ return a.reshape(bs, -1, length)
+
+ @staticmethod
+ def count_flops(model, _x, y):
+ return count_flops_attn(model, _x, y)
+
+
+class UNetModel(nn.Module):
+ """
+ The full UNet model with attention and timestep embedding.
+
+ :param in_channels: channels in the input Tensor.
+ :param model_channels: base channel count for the model.
+ :param out_channels: channels in the output Tensor.
+ :param num_res_blocks: number of residual blocks per downsample.
+ :param attention_resolutions: a collection of downsample rates at which
+ attention will take place. May be a set, list, or tuple.
+ For example, if this contains 4, then at 4x downsampling, attention
+ will be used.
+ :param dropout: the dropout probability.
+ :param channel_mult: channel multiplier for each level of the UNet.
+ :param conv_resample: if True, use learned convolutions for upsampling and
+ downsampling.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param num_classes: if specified (as an int), then this model will be
+ class-conditional with `num_classes` classes.
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
+ :param num_heads: the number of attention heads in each attention layer.
+    :param num_head_channels: if specified, ignore num_heads and instead use
+ a fixed channel width per attention head.
+ :param num_heads_upsample: works with num_heads to set a different number
+ of heads for upsampling. Deprecated.
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
+ :param resblock_updown: use residual blocks for up/downsampling.
+ :param use_new_attention_order: use a different attention pattern for potentially
+ increased efficiency.
+ """
+
+ def __init__(
+ self,
+ image_size,
+ in_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ num_classes=None,
+ use_checkpoint=False,
+ use_fp16=False,
+ num_heads=1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ use_new_attention_order=False,
+ **kwargs
+ ):
+ super().__init__()
+
+ if num_heads_upsample == -1:
+ num_heads_upsample = num_heads
+
+ self.image_size = image_size
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ self.num_res_blocks = num_res_blocks
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.num_classes = num_classes
+ self.use_checkpoint = use_checkpoint
+ self.dtype = th.float16 if use_fp16 else th.float32
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+
+ time_embed_dim = model_channels * 4
+ self.time_embed = nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+
+ if self.num_classes is not None:
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+
+ ch = input_ch = int(channel_mult[0] * model_channels)
+ self.input_blocks = nn.ModuleList(
+ [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))]
+ )
+ self._feature_size = ch
+ input_block_chans = [ch]
+ ds = 1
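+        # `ds` tracks the cumulative downsampling factor; attention blocks are inserted
+        # whenever it matches an entry in `attention_resolutions`.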
+ for level, mult in enumerate(channel_mult):
+ for _ in range(num_res_blocks):
+ layers = [
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=int(mult * model_channels),
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = int(mult * model_channels)
+ if ds in attention_resolutions:
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=num_head_channels,
+ use_new_attention_order=use_new_attention_order,
+ )
+ )
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ )
+ if resblock_updown
+ else Downsample(
+ ch, conv_resample, dims=dims, out_channels=out_ch
+ )
+ )
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ ds *= 2
+ self._feature_size += ch
+
+ self.middle_block = TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=num_head_channels,
+ use_new_attention_order=use_new_attention_order,
+ ),
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ )
+ self._feature_size += ch
+
+ self.output_blocks = nn.ModuleList([])
+ for level, mult in list(enumerate(channel_mult))[::-1]:
+ for i in range(num_res_blocks + 1):
+ ich = input_block_chans.pop()
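+                # U-Net skip connection: the matching encoder feature map (ich channels)
+                # is concatenated to the decoder input before this ResBlock.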
+ layers = [
+ ResBlock(
+ ch + ich,
+ time_embed_dim,
+ dropout,
+ out_channels=int(model_channels * mult),
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = int(model_channels * mult)
+ if ds in attention_resolutions:
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads_upsample,
+ num_head_channels=num_head_channels,
+ use_new_attention_order=use_new_attention_order,
+ )
+ )
+ if level and i == num_res_blocks:
+ out_ch = ch
+ layers.append(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ up=True,
+ )
+ if resblock_updown
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+ )
+ ds //= 2
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+
+ self.out = nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)),
+ )
+
+ def convert_to_fp16(self):
+ """
+ Convert the torso of the model to float16.
+ """
+ self.input_blocks.apply(convert_module_to_f16)
+ self.middle_block.apply(convert_module_to_f16)
+ self.output_blocks.apply(convert_module_to_f16)
+
+ def convert_to_fp32(self):
+ """
+ Convert the torso of the model to float32.
+ """
+ self.input_blocks.apply(convert_module_to_f32)
+ self.middle_block.apply(convert_module_to_f32)
+ self.output_blocks.apply(convert_module_to_f32)
+
+ def forward(self, x, timesteps, y=None):
+ """
+ Apply the model to an input batch.
+
+ :param x: an [N x C x ...] Tensor of inputs.
+ :param timesteps: a 1-D batch of timesteps.
+ :param y: an [N] Tensor of labels, if class-conditional.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ assert (y is not None) == (
+ self.num_classes is not None
+ ), "must specify y if and only if the model is class-conditional"
+
+ hs = []
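+        # Sinusoidal timestep features are projected by a two-layer MLP (time_embed)
+        # into the embedding that conditions every ResBlock.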
+ emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
+
+ if self.num_classes is not None:
+ assert y.shape == (x.shape[0],)
+ emb = emb + self.label_emb(y)
+
+ h = x.type(self.dtype)
+ for module in self.input_blocks:
+ h = module(h, emb)
+ hs.append(h)
+ h = self.middle_block(h, emb)
+ for module in self.output_blocks:
+ h = th.cat([h, hs.pop()], dim=1)
+ h = module(h, emb)
+ h = h.type(x.dtype)
+ return self.out(h)
+
+
+class SuperResModel(UNetModel):
+ """
+ A UNetModel that performs super-resolution.
+
+ Expects an extra kwarg `low_res` to condition on a low-resolution image.
+ """
+
+ def __init__(self, image_size, in_channels, *args, **kwargs):
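+        # Input channels are doubled because forward() concatenates the upsampled
+        # low-resolution image to x along the channel dimension.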
+ super().__init__(image_size, in_channels * 2, *args, **kwargs)
+
+ def forward(self, x, timesteps, low_res=None, **kwargs):
+ _, _, new_height, new_width = x.shape
+ upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
+ x = th.cat([x, upsampled], dim=1)
+ return super().forward(x, timesteps, **kwargs)
+
+
+class EncoderUNetModel(nn.Module):
+ """
+ The half UNet model with attention and timestep embedding.
+
+    For usage, see UNetModel.
+ """
+
+ def __init__(
+ self,
+ image_size,
+ in_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ use_checkpoint=False,
+ use_fp16=False,
+ num_heads=1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ use_new_attention_order=False,
+ pool="adaptive",
+ ):
+ super().__init__()
+
+ if num_heads_upsample == -1:
+ num_heads_upsample = num_heads
+
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ self.num_res_blocks = num_res_blocks
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.use_checkpoint = use_checkpoint
+ self.dtype = th.float16 if use_fp16 else th.float32
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+
+ time_embed_dim = model_channels * 4
+ self.time_embed = nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+
+ ch = int(channel_mult[0] * model_channels)
+ self.input_blocks = nn.ModuleList(
+ [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))]
+ )
+ self._feature_size = ch
+ input_block_chans = [ch]
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for _ in range(num_res_blocks):
+ layers = [
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=int(mult * model_channels),
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = int(mult * model_channels)
+ if ds in attention_resolutions:
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=num_head_channels,
+ use_new_attention_order=use_new_attention_order,
+ )
+ )
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ )
+ if resblock_updown
+ else Downsample(
+ ch, conv_resample, dims=dims, out_channels=out_ch
+ )
+ )
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ ds *= 2
+ self._feature_size += ch
+
+ self.middle_block = TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=num_head_channels,
+ use_new_attention_order=use_new_attention_order,
+ ),
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ )
+ self._feature_size += ch
+ self.pool = pool
+ if pool == "adaptive":
+ self.out = nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ nn.AdaptiveAvgPool2d((1, 1)),
+ zero_module(conv_nd(dims, ch, out_channels, 1)),
+ nn.Flatten(),
+ )
+ elif pool == "attention":
+ assert num_head_channels != -1
+ self.out = nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ AttentionPool2d(
+ (image_size // ds), ch, num_head_channels, out_channels
+ ),
+ )
+ elif pool == "spatial":
+ self.out = nn.Sequential(
+ nn.Linear(self._feature_size, 2048),
+ nn.ReLU(),
+ nn.Linear(2048, self.out_channels),
+ )
+ elif pool == "spatial_v2":
+ self.out = nn.Sequential(
+ nn.Linear(self._feature_size, 2048),
+ normalization(2048),
+ nn.SiLU(),
+ nn.Linear(2048, self.out_channels),
+ )
+ else:
+ raise NotImplementedError(f"Unexpected {pool} pooling")
+
+ def convert_to_fp16(self):
+ """
+ Convert the torso of the model to float16.
+ """
+ self.input_blocks.apply(convert_module_to_f16)
+ self.middle_block.apply(convert_module_to_f16)
+
+ def convert_to_fp32(self):
+ """
+ Convert the torso of the model to float32.
+ """
+ self.input_blocks.apply(convert_module_to_f32)
+ self.middle_block.apply(convert_module_to_f32)
+
+ def forward(self, x, timesteps):
+ """
+ Apply the model to an input batch.
+
+ :param x: an [N x C x ...] Tensor of inputs.
+ :param timesteps: a 1-D batch of timesteps.
+ :return: an [N x K] Tensor of outputs.
+ """
+ emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
+
+ results = []
+ h = x.type(self.dtype)
+ for module in self.input_blocks:
+ h = module(h, emb)
+ if self.pool.startswith("spatial"):
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
+ h = self.middle_block(h, emb)
+ if self.pool.startswith("spatial"):
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
+ h = th.cat(results, axis=-1)
+ return self.out(h)
+ else:
+ h = h.type(x.dtype)
+ return self.out(h)
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..de0a036
--- /dev/null
+++ b/main.py
@@ -0,0 +1,175 @@
+import argparse
+import traceback
+import shutil
+import logging
+import yaml
+import sys
+import os
+import torch
+import numpy as np
+import torch.utils.tensorboard as tb
+
+from runners.diffusion import Diffusion
+
+torch.set_printoptions(sci_mode=False)
+
+
+def parse_args_and_config():
+ parser = argparse.ArgumentParser(description=globals()["__doc__"])
+
+ parser.add_argument(
+ "--config", type=str, required=True, help="Path to the config file"
+ )
+ parser.add_argument("--seed", type=int, default=1234, help="Random seed")
+ parser.add_argument(
+ "--exp", type=str, default="exp", help="Path for saving running related data."
+ )
+ parser.add_argument(
+ "--doc",
+ type=str,
+ required=True,
+ help="A string for documentation purpose. "
+ "Will be the name of the log folder.",
+ )
+ parser.add_argument(
+ "--comment", type=str, default="", help="A string for experiment comment"
+ )
+ parser.add_argument(
+ "--verbose",
+ type=str,
+ default="info",
+ help="Verbose level: info | debug | warning | critical",
+ )
+ parser.add_argument(
+ "--sample",
+ action="store_true",
+ help="Whether to produce samples from the model",
+ )
+ parser.add_argument(
+ "-i",
+ "--image_folder",
+ type=str,
+ default="images",
+ help="The folder name of samples",
+ )
+ parser.add_argument(
+ "--ni",
+ action="store_true",
+ help="No interaction. Suitable for Slurm Job launcher",
+ )
+ parser.add_argument(
+        "--timesteps", type=int, default=20, help="Number of diffusion sampling steps"
+ )
+ parser.add_argument(
+ "--deg", type=str, required=True, help="Degradation"
+ )
+ parser.add_argument(
+ "--num_avg_samples", type=int, default=1, help="Number of samples to average per input"
+ )
+ parser.add_argument(
+        "--eta", type=float, default=1.0, help="DDRM eta hyperparameter"
+ )
+ parser.add_argument(
+        "--etaB", type=float, default=0.4, help="DDRM eta_b hyperparameter"
+ )
+    parser.add_argument(
+        '--subset_start', type=int, default=-1, help="Index of the first dataset image to process (-1 for the full set)"
+    )
+    parser.add_argument(
+        '--subset_end', type=int, default=-1, help="Index one past the last dataset image to process (-1 for the full set)"
+    )
+    parser.add_argument(
+        '--init_timestep', type=int, default=300, help="Diffusion timestep from which sampling starts"
+    )
+
+ args = parser.parse_args()
+ args.log_path = os.path.join(args.exp, "logs", args.doc)
+
+ # parse config file
+ with open(os.path.join("configs", args.config), "r") as f:
+ config = yaml.safe_load(f)
+ new_config = dict2namespace(config)
+
+ tb_path = os.path.join(args.exp, "tensorboard", args.doc)
+
+ level = getattr(logging, args.verbose.upper(), None)
+ if not isinstance(level, int):
+ raise ValueError("level {} not supported".format(args.verbose))
+
+ handler1 = logging.StreamHandler()
+ formatter = logging.Formatter(
+ "%(levelname)s - %(filename)s - %(asctime)s - %(message)s"
+ )
+ handler1.setFormatter(formatter)
+ logger = logging.getLogger()
+ logger.addHandler(handler1)
+ logger.setLevel(level)
+
+ os.makedirs(os.path.join(args.exp, "image_samples"), exist_ok=True)
+ args.image_folder = os.path.join(
+ args.exp, "image_samples", args.image_folder
+ )
+ if not os.path.exists(args.image_folder):
+ os.makedirs(args.image_folder)
+ else:
+ overwrite = False
+ if args.ni:
+ overwrite = True
+ else:
+ response = input(
+ f"Image folder {args.image_folder} already exists. Overwrite? (Y/N)"
+ )
+ if response.upper() == "Y":
+ overwrite = True
+
+ if overwrite:
+ shutil.rmtree(args.image_folder)
+ os.makedirs(args.image_folder)
+ else:
+ print("Output image folder exists. Program halted.")
+ sys.exit(0)
+
+ # add device
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ logging.info("Using device: {}".format(device))
+ new_config.device = device
+
+ # set random seed
+ torch.manual_seed(args.seed)
+ np.random.seed(args.seed)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed_all(args.seed)
+
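+    # cudnn benchmarking autotunes convolution algorithms; appropriate here since input shapes stay fixed within a run.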
+ torch.backends.cudnn.benchmark = True
+
+ return args, new_config
+
+
+def dict2namespace(config):
+ namespace = argparse.Namespace()
+ for key, value in config.items():
+ if isinstance(value, dict):
+ new_value = dict2namespace(value)
+ else:
+ new_value = value
+ setattr(namespace, key, new_value)
+ return namespace
+
+
+def main():
+ args, config = parse_args_and_config()
+ logging.info("Writing log file to {}".format(args.log_path))
+ logging.info("Exp instance id = {}".format(os.getpid()))
+ logging.info("Exp comment = {}".format(args.comment))
+
+ try:
+ runner = Diffusion(args, config)
+ runner.sample()
+ except Exception:
+ logging.error(traceback.format_exc())
+
+ return 0
+
+
+if __name__ == "__main__":
+ sys.exit(main())
diff --git a/models/__pycache__/diffusion.cpython-37.pyc b/models/__pycache__/diffusion.cpython-37.pyc
new file mode 100644
index 0000000..16c4847
Binary files /dev/null and b/models/__pycache__/diffusion.cpython-37.pyc differ
diff --git a/models/__pycache__/diffusion.cpython-38.pyc b/models/__pycache__/diffusion.cpython-38.pyc
new file mode 100644
index 0000000..8a8a0d8
Binary files /dev/null and b/models/__pycache__/diffusion.cpython-38.pyc differ
diff --git a/models/__pycache__/diffusion.cpython-39.pyc b/models/__pycache__/diffusion.cpython-39.pyc
new file mode 100644
index 0000000..6f37c3b
Binary files /dev/null and b/models/__pycache__/diffusion.cpython-39.pyc differ
diff --git a/models/diffusion.py b/models/diffusion.py
new file mode 100644
index 0000000..db53ac1
--- /dev/null
+++ b/models/diffusion.py
@@ -0,0 +1,341 @@
+import math
+import torch
+import torch.nn as nn
+
+
+def get_timestep_embedding(timesteps, embedding_dim):
+    """
+    Build sinusoidal timestep embeddings.
+
+    This matches the implementation in Denoising Diffusion Probabilistic Models
+    (originally from Fairseq/tensor2tensor), and differs slightly from the
+    description in Section 3.5 of "Attention Is All You Need".
+    """
+ assert len(timesteps.shape) == 1
+
+ half_dim = embedding_dim // 2
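+    # Frequencies form a geometric progression from 1 down to 1/10000, as in the
+    # Transformer sinusoidal positional encoding.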
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
+ emb = emb.to(device=timesteps.device)
+ emb = timesteps.float()[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
+ return emb
+
+
+def nonlinearity(x):
+ # swish
+ return x*torch.sigmoid(x)
+
+
+def Normalize(in_channels):
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class Upsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ self.conv = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ x = torch.nn.functional.interpolate(
+ x, scale_factor=2.0, mode="nearest")
+ if self.with_conv:
+ x = self.conv(x)
+ return x
+
+
+class Downsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=3,
+ stride=2,
+ padding=0)
+
+ def forward(self, x):
+ if self.with_conv:
+ pad = (0, 1, 0, 1)
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+ x = self.conv(x)
+ else:
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
+ return x
+
+
+class ResnetBlock(nn.Module):
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
+ dropout, temb_channels=512):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+
+ self.norm1 = Normalize(in_channels)
+ self.conv1 = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ self.temb_proj = torch.nn.Linear(temb_channels,
+ out_channels)
+ self.norm2 = Normalize(out_channels)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = torch.nn.Conv2d(out_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ else:
+ self.nin_shortcut = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ def forward(self, x, temb):
+ h = x
+ h = self.norm1(h)
+ h = nonlinearity(h)
+ h = self.conv1(h)
+
+ h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
+
+ h = self.norm2(h)
+ h = nonlinearity(h)
+ h = self.dropout(h)
+ h = self.conv2(h)
+
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ x = self.conv_shortcut(x)
+ else:
+ x = self.nin_shortcut(x)
+
+ return x+h
+
+
+class AttnBlock(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.k = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.v = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.proj_out = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b, c, h, w = q.shape
+ q = q.reshape(b, c, h*w)
+ q = q.permute(0, 2, 1) # b,hw,c
+ k = k.reshape(b, c, h*w) # b,c,hw
+ w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+ w_ = w_ * (int(c)**(-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = v.reshape(b, c, h*w)
+ w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
+ # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+ h_ = torch.bmm(v, w_)
+ h_ = h_.reshape(b, c, h, w)
+
+ h_ = self.proj_out(h_)
+
+ return x+h_
+
+
+class Model(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ ch, out_ch, ch_mult = config.model.ch, config.model.out_ch, tuple(config.model.ch_mult)
+ num_res_blocks = config.model.num_res_blocks
+ attn_resolutions = config.model.attn_resolutions
+ dropout = config.model.dropout
+ in_channels = config.model.in_channels
+ resolution = config.data.image_size
+ resamp_with_conv = config.model.resamp_with_conv
+ num_timesteps = config.diffusion.num_diffusion_timesteps
+
+ if config.model.type == 'bayesian':
+ self.logvar = nn.Parameter(torch.zeros(num_timesteps))
+
+ self.ch = ch
+ self.temb_ch = self.ch*4
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ # timestep embedding
+ self.temb = nn.Module()
+ self.temb.dense = nn.ModuleList([
+ torch.nn.Linear(self.ch,
+ self.temb_ch),
+ torch.nn.Linear(self.temb_ch,
+ self.temb_ch),
+ ])
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(in_channels,
+ self.ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ curr_res = resolution
+ in_ch_mult = (1,)+ch_mult
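+        # Prepending 1 means level i reads ch*in_ch_mult[i] input channels and writes ch*ch_mult[i].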
+ self.down = nn.ModuleList()
+ block_in = None
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch*in_ch_mult[i_level]
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(AttnBlock(block_in))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions-1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = AttnBlock(block_in)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch*ch_mult[i_level]
+ skip_in = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks+1):
+ if i_block == self.num_res_blocks:
+ skip_in = ch*in_ch_mult[i_level]
+ block.append(ResnetBlock(in_channels=block_in+skip_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(AttnBlock(block_in))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x, t):
+ assert x.shape[2] == x.shape[3] == self.resolution
+
+ # timestep embedding
+ temb = get_timestep_embedding(t, self.ch)
+ temb = self.temb.dense[0](temb)
+ temb = nonlinearity(temb)
+ temb = self.temb.dense[1](temb)
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions-1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks+1):
+ h = self.up[i_level].block[i_block](
+ torch.cat([h, hs.pop()], dim=1), temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
diff --git a/runners/__init__.py b/runners/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/runners/__pycache__/__init__.cpython-37.pyc b/runners/__pycache__/__init__.cpython-37.pyc
new file mode 100644
index 0000000..8d9e1e8
Binary files /dev/null and b/runners/__pycache__/__init__.cpython-37.pyc differ
diff --git a/runners/__pycache__/__init__.cpython-38.pyc b/runners/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000..f802c24
Binary files /dev/null and b/runners/__pycache__/__init__.cpython-38.pyc differ
diff --git a/runners/__pycache__/__init__.cpython-39.pyc b/runners/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000..6c36bb9
Binary files /dev/null and b/runners/__pycache__/__init__.cpython-39.pyc differ
diff --git a/runners/__pycache__/diffusion.cpython-37.pyc b/runners/__pycache__/diffusion.cpython-37.pyc
new file mode 100644
index 0000000..30c4930
Binary files /dev/null and b/runners/__pycache__/diffusion.cpython-37.pyc differ
diff --git a/runners/__pycache__/diffusion.cpython-38.pyc b/runners/__pycache__/diffusion.cpython-38.pyc
new file mode 100644
index 0000000..9b5360d
Binary files /dev/null and b/runners/__pycache__/diffusion.cpython-38.pyc differ
diff --git a/runners/__pycache__/diffusion.cpython-39.pyc b/runners/__pycache__/diffusion.cpython-39.pyc
new file mode 100644
index 0000000..2b1d7db
Binary files /dev/null and b/runners/__pycache__/diffusion.cpython-39.pyc differ
diff --git a/runners/diffusion.py b/runners/diffusion.py
new file mode 100644
index 0000000..9bc93b1
--- /dev/null
+++ b/runners/diffusion.py
@@ -0,0 +1,293 @@
+import os
+import logging
+import time
+import glob
+
+import numpy as np
+import tqdm
+import torch
+import torch.utils.data as data
+
+from models.diffusion import Model
+from datasets import get_dataset, data_transform, inverse_data_transform
+from functions.ckpt_util import get_ckpt_path, download
+from functions.denoising import jpeg_steps
+
+import torchvision.utils as tvu
+
+from guided_diffusion.unet import UNetModel
+from guided_diffusion.script_util import create_model, create_classifier, classifier_defaults, args_to_dict
+import random
+
+def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
+ def sigmoid(x):
+ return 1 / (np.exp(-x) + 1)
+
+ if beta_schedule == "quad":
+ betas = (
+ np.linspace(
+ beta_start ** 0.5,
+ beta_end ** 0.5,
+ num_diffusion_timesteps,
+ dtype=np.float64,
+ )
+ ** 2
+ )
+ elif beta_schedule == "linear":
+ betas = np.linspace(
+ beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
+ )
+ elif beta_schedule == "const":
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
+ elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
+ betas = 1.0 / np.linspace(
+ num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
+ )
+ elif beta_schedule == "sigmoid":
+ betas = np.linspace(-6, 6, num_diffusion_timesteps)
+ betas = sigmoid(betas) * (beta_end - beta_start) + beta_start
+ else:
+ raise NotImplementedError(beta_schedule)
+ assert betas.shape == (num_diffusion_timesteps,)
+ return betas
+
+
+class Diffusion(object):
+ def __init__(self, args, config, device=None):
+ self.args = args
+ self.config = config
+ if device is None:
+ device = (
+ torch.device("cuda")
+ if torch.cuda.is_available()
+ else torch.device("cpu")
+ )
+ self.device = device
+
+ self.model_var_type = config.model.var_type
+ betas = get_beta_schedule(
+ beta_schedule=config.diffusion.beta_schedule,
+ beta_start=config.diffusion.beta_start,
+ beta_end=config.diffusion.beta_end,
+ num_diffusion_timesteps=config.diffusion.num_diffusion_timesteps,
+ )
+ betas = self.betas = torch.from_numpy(betas).float().to(self.device)
+ self.num_timesteps = betas.shape[0]
+
+ alphas = 1.0 - betas
+ alphas_cumprod = alphas.cumprod(dim=0)
+ alphas_cumprod_prev = torch.cat(
+ [torch.ones(1).to(device), alphas_cumprod[:-1]], dim=0
+ )
+ self.alphas_cumprod_prev = alphas_cumprod_prev
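+        # Variance of the DDPM posterior q(x_{t-1} | x_t, x_0):
+        # beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t).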
+ posterior_variance = (
+ betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
+ )
+ if self.model_var_type == "fixedlarge":
+ self.logvar = betas.log()
+ # torch.cat(
+ # [posterior_variance[1:2], betas[1:]], dim=0).log()
+ elif self.model_var_type == "fixedsmall":
+ self.logvar = posterior_variance.clamp(min=1e-20).log()
+
+ def sample(self):
+ cls_fn = None
+ if self.config.model.type == 'simple':
+ model = Model(self.config)
+ # This used the pretrained DDPM model, see https://github.com/pesser/pytorch_diffusion
+ if self.config.data.dataset == "CIFAR10":
+ name = "cifar10"
+ elif self.config.data.dataset == "LSUN":
+ name = f"lsun_{self.config.data.category}"
+ elif self.config.data.dataset == 'CelebA_HQ':
+ name = 'celeba_hq'
+ else:
+ raise ValueError
+ if name != 'celeba_hq':
+ ckpt = get_ckpt_path(f"ema_{name}", prefix=self.args.exp)
+ print("Loading checkpoint {}".format(ckpt))
+ elif name == 'celeba_hq':
+ #ckpt = '~/.cache/diffusion_models_converted/celeba_hq.ckpt'
+ ckpt = os.path.join(self.args.exp, "logs/celeba/celeba_hq.ckpt")
+ if not os.path.exists(ckpt):
+ download('https://image-editing-test-12345.s3-us-west-2.amazonaws.com/checkpoints/celeba_hq.ckpt', ckpt)
+ else:
+ raise ValueError
+ model.load_state_dict(torch.load(ckpt, map_location=self.device))
+ model.to(self.device)
+ model = torch.nn.DataParallel(model)
+
+ elif self.config.model.type == 'openai':
+ config_dict = vars(self.config.model)
+ model = create_model(**config_dict)
+ if self.config.model.use_fp16:
+ model.convert_to_fp16()
+ if self.config.model.class_cond:
+ ckpt = os.path.join(self.args.exp, 'logs/imagenet/%dx%d_diffusion.pt' % (self.config.data.image_size, self.config.data.image_size))
+ if not os.path.exists(ckpt):
+                    # Class-conditional models pair with the %dx%d_diffusion.pt checkpoint (not the _uncond one).
+                    download('https://openaipublic.blob.core.windows.net/diffusion/jul-2021/%dx%d_diffusion.pt' % (self.config.data.image_size, self.config.data.image_size), ckpt)
+ else:
+ ckpt = os.path.join(self.args.exp, "logs/imagenet/256x256_diffusion_uncond.pt")
+ if not os.path.exists(ckpt):
+ download('https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion_uncond.pt', ckpt)
+
+
+ model.load_state_dict(torch.load(ckpt, map_location=self.device))
+ model.to(self.device)
+ model.eval()
+ model = torch.nn.DataParallel(model)
+
+ if self.config.model.class_cond:
+ ckpt = os.path.join(self.args.exp, 'logs/imagenet/%dx%d_classifier.pt' % (self.config.data.image_size, self.config.data.image_size))
+ if not os.path.exists(ckpt):
+ image_size = self.config.data.image_size
+                    download('https://openaipublic.blob.core.windows.net/diffusion/jul-2021/%dx%d_classifier.pt' % (image_size, image_size), ckpt)
+ classifier = create_classifier(**args_to_dict(self.config.classifier, classifier_defaults().keys()))
+ classifier.load_state_dict(torch.load(ckpt, map_location=self.device))
+ classifier.to(self.device)
+ if self.config.classifier.classifier_use_fp16:
+ classifier.convert_to_fp16()
+ classifier.eval()
+ classifier = torch.nn.DataParallel(classifier)
+
+ import torch.nn.functional as F
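+                # Classifier guidance: cond_fn returns the gradient of log p(y | x_t) w.r.t. x_t,
+                # scaled by classifier_scale, which the sampler can use to steer generation toward class y.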
+ def cond_fn(x, t, y):
+ with torch.enable_grad():
+ x_in = x.detach().requires_grad_(True)
+ logits = classifier(x_in, t)
+ log_probs = F.log_softmax(logits, dim=-1)
+ selected = log_probs[range(len(logits)), y.view(-1)]
+ return torch.autograd.grad(selected.sum(), x_in)[0] * self.config.classifier.classifier_scale
+ cls_fn = cond_fn
+
+ self.sample_sequence(model, cls_fn)
+
+ def sample_sequence(self, model, cls_fn=None):
+ args, config = self.args, self.config
+
+        # get the original images and the corrupted measurements y_0
+ dataset, test_dataset = get_dataset(args, config)
+
+ device_count = torch.cuda.device_count()
+
+ if args.subset_start >= 0 and args.subset_end > 0:
+ assert args.subset_end > args.subset_start
+ test_dataset = torch.utils.data.Subset(test_dataset, range(args.subset_start, args.subset_end))
+ else:
+ args.subset_start = 0
+ args.subset_end = len(test_dataset)
+
+ print(f'Dataset has size {len(test_dataset)}')
+
+ def seed_worker(worker_id):
+ worker_seed = args.seed % 2**32
+ np.random.seed(worker_seed)
+ random.seed(worker_seed)
+
+ g = torch.Generator()
+ g.manual_seed(args.seed)
+ val_loader = data.DataLoader(
+ test_dataset,
+ batch_size=config.sampling.batch_size,
+ shuffle=False,
+ num_workers=config.data.num_workers,
+ worker_init_fn=seed_worker,
+ generator=g,
+ )
+
+
+ ## get degradation operator ##
+ deg = args.deg
+ if deg[:4] == 'jpeg':
+ jpeg_qf = int(deg[4:])
+ from functions.jpeg_torch import jpeg_decode as jd, jpeg_encode as je
+ from functools import partial
+ jpeg_decode = partial(jd, qf = jpeg_qf)
+ jpeg_encode = partial(je, qf = jpeg_qf)
+ elif deg == 'quant':
+ from functions.jpeg_torch import quantization_decode as qd, quantization_encode as qe
+ from functools import partial
+            jpeg_qf = 0  # no quality factor for plain dequantization; also keeps the later jpeg_qf reference defined
+            jpeg_decode = partial(qd, qf = jpeg_qf)
+            jpeg_encode = partial(qe, qf = jpeg_qf)
+ else:
+ print("ERROR: degradation type not supported")
+ quit()
+ sigma_0 = 0 #no gaussian noise in jpeg measurements
+
+ print(f'Start from {args.subset_start}')
+ idx_init = args.subset_start
+ idx_so_far = args.subset_start
+ avg_psnr = 0.0
+ avg_psnr_y = 0.0
+ pbar = tqdm.tqdm(val_loader)
+ for x_orig, classes in pbar:
+ x_orig = x_orig.to(self.device)
+ x_orig = data_transform(self.config, x_orig)
+
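+            # y_0 is the degraded measurement: the clean image passed through JPEG (or quantization)
+            # encoding and decoding.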
+ y_0 = jpeg_decode(jpeg_encode(x_orig))
+ pinv_y_0 = y_0
+
+ for i in range(len(pinv_y_0)):
+ tvu.save_image(
+ inverse_data_transform(config, pinv_y_0[i]), os.path.join(self.args.image_folder, f"y0_{idx_so_far + i}.png")
+ )
+ tvu.save_image(
+ inverse_data_transform(config, x_orig[i]), os.path.join(self.args.image_folder, f"orig_{idx_so_far + i}.png")
+ )
+
+ ##Begin DDRM
+ x = 0
+ num_avg_samples = args.num_avg_samples
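+            # Average several independent DDRM reconstructions of the same input
+            # (each starting from a different Gaussian draw) to reduce sampling variance.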
+ for _ in range(num_avg_samples):
+ local_x = torch.randn(
+ y_0.shape[0],
+ config.data.channels,
+ config.data.image_size,
+ config.data.image_size,
+ device=self.device,
+ )
+
+ # NOTE: This means that we are producing each predicted x0, not x_{t-1} at timestep t.
+ with torch.no_grad():
+ local_x, _ = self.sample_image_jpeg(local_x, model, y_0, last=False, cls_fn=cls_fn, classes=classes, jpeg_qf=jpeg_qf)
+
+ local_x = [inverse_data_transform(config, y) for y in local_x]
+                if x == 0:
+                    x = local_x
+                else:
+                    x = [x[i] + local_x[i] for i in range(len(local_x))]
+ x = [item / num_avg_samples for item in x]
+
+ for i in [-1]: # range(len(x)):
+ for j in range(x[i].size(0)):
+ tvu.save_image(
+ x[i][j], os.path.join(self.args.image_folder, f"{idx_so_far + j}_{i}.png")
+ )
+ if i == len(x)-1 or i == -1:
+ orig = inverse_data_transform(config, x_orig[j])
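+                        # PSNR is computed on images in [0, 1], so the peak signal value is 1.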
+ mse = torch.mean((x[i][j].to(self.device) - orig) ** 2)
+ psnr = 10 * torch.log10(1 / mse)
+ avg_psnr += psnr
+ if deg[:4] == 'jpeg':
+ y_inv = inverse_data_transform(config, pinv_y_0[j])
+ mse = torch.mean((y_inv - orig) ** 2)
+ psnr = 10 * torch.log10(1 / mse)
+ avg_psnr_y += psnr
+
+ idx_so_far += y_0.shape[0]
+
+ pbar.set_description("PSNR: %.2f" % (avg_psnr / (idx_so_far - idx_init)))
+
+ avg_psnr = avg_psnr / (idx_so_far - idx_init)
+ print("Total Average PSNR: %.2f" % avg_psnr)
+ avg_psnr_y = avg_psnr_y / (idx_so_far - idx_init)
+ print("Total Average PSNR of JPEG: %.2f" % avg_psnr_y)
+ print("Number of samples: %d" % (idx_so_far - idx_init))
+
+ def sample_image_jpeg(self, x, model, y_0, last=True, cls_fn=None, classes=None, jpeg_qf=None):
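+        # Select roughly `--timesteps` evenly spaced diffusion timesteps in [0, init_timestep)
+        # as the accelerated DDRM sampling schedule.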
+ skip = self.args.init_timestep // self.args.timesteps
+ seq = range(0, self.args.init_timestep, skip)
+
+        x = jpeg_steps(x, seq, model, self.betas, y_0,
+                       etaB=self.args.etaB, etaA=self.args.eta, etaC=self.args.eta,
+                       cls_fn=cls_fn, classes=classes, jpeg_qf=jpeg_qf)
+ if last:
+ x = x[0][-1]
+ return x
\ No newline at end of file